Skip to content

Commit

Permalink
test: remove models and model from input output identifiers (#367)
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy authored Jun 30, 2023
1 parent cedbe2a commit a86fd7c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 28 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- Remove `model` and `models` for input and output identifiers in tests. Replace by `shared` instead. ([#367](https://github.com/Substra/substra/pull/367))

## [0.45.0](https://github.com/Substra/substra/releases/tag/0.45.0) - 2023-06-12

### Added
Expand Down
10 changes: 5 additions & 5 deletions tests/data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_predictions(path):
def train(inputs, outputs, task_properties):
X = inputs['{InputIdentifiers.datasamples}'][0]
y = inputs['{InputIdentifiers.datasamples}'][1]
models_path = inputs.get('{InputIdentifiers.models}', [])
models_path = inputs.get('{InputIdentifiers.shared}', [])
models = [_load_model(model_path) for model_path in models_path]
print(f'Train, get X: {{X}}, y: {{y}}, models: {{models}}')
Expand All @@ -95,12 +95,12 @@ def train(inputs, outputs, task_properties):
res = dict(value=avg + err)
print(f'Train, return {{res}}')
_save_model(res, outputs['{OutputIdentifiers.model}'])
_save_model(res, outputs['{OutputIdentifiers.shared}'])
@tools.register
def predict(inputs, outputs, task_properties):
X = inputs['{InputIdentifiers.datasamples}'][0]
model = _load_model(inputs['{InputIdentifiers.model}'])
model = _load_model(inputs['{InputIdentifiers.shared}'])
res = [x * model['value'] for x in X]
print(f'Predict, get X: {{X}}, model: {{model}}, return {{res}}')
Expand Down Expand Up @@ -129,14 +129,14 @@ def _save_predictions(y_pred, path):
@tools.register
def aggregate(inputs, outputs, task_properties):
models_path = inputs.get('{InputIdentifiers.models}', [])
models_path = inputs.get('{InputIdentifiers.shared}', [])
models = [_load_model(model_path) for model_path in models_path]
print(f'Aggregate models: {{models}}')
values = [m['value'] for m in models]
avg = sum(values) / len(values)
res = dict(value=avg)
print(f'Aggregate result: {{res}}')
_save_model(res, outputs['{OutputIdentifiers.model}'])
_save_model(res, outputs['{OutputIdentifiers.shared}'])
@tools.register
def predict(inputs, outputs, task_properties):
Expand Down
41 changes: 19 additions & 22 deletions tests/fl_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ class FunctionCategory(str, Enum):
class InputIdentifiers(str, Enum):
local = "local"
shared = "shared"
model = "model"
models = "models"
predictions = "predictions"
performance = "performance"
opener = "opener"
Expand All @@ -36,7 +34,6 @@ class InputIdentifiers(str, Enum):
class OutputIdentifiers(str, Enum):
local = "local"
shared = "shared"
model = "model"
predictions = "predictions"
performance = "performance"

Expand All @@ -45,7 +42,7 @@ class FLFunctionInputs(list, Enum):
"""Substra function inputs by function category based on the InputIdentifiers"""

FUNCTION_AGGREGATE = [
FunctionInputSpec(identifier=InputIdentifiers.models, kind=AssetKind.model.value, optional=False, multiple=True)
FunctionInputSpec(identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=False, multiple=True)
]
FUNCTION_SIMPLE = [
FunctionInputSpec(
Expand All @@ -57,7 +54,7 @@ class FLFunctionInputs(list, Enum):
FunctionInputSpec(
identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False
),
FunctionInputSpec(identifier=InputIdentifiers.models, kind=AssetKind.model.value, optional=True, multiple=True),
FunctionInputSpec(identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=True, multiple=True),
]
FUNCTION_COMPOSITE = [
FunctionInputSpec(
Expand Down Expand Up @@ -85,7 +82,7 @@ class FLFunctionInputs(list, Enum):
identifier=InputIdentifiers.opener, kind=AssetKind.data_manager.value, optional=False, multiple=False
),
FunctionInputSpec(
identifier=InputIdentifiers.model, kind=AssetKind.model.value, optional=False, multiple=False
identifier=InputIdentifiers.shared, kind=AssetKind.model.value, optional=False, multiple=False
),
]
FUNCTION_PREDICT_COMPOSITE = [
Expand Down Expand Up @@ -125,10 +122,10 @@ class FLFunctionOutputs(list, Enum):
"""Substra function outputs by function category based on the OutputIdentifiers"""

FUNCTION_AGGREGATE = [
FunctionOutputSpec(identifier=OutputIdentifiers.model, kind=AssetKind.model.value, multiple=False)
FunctionOutputSpec(identifier=OutputIdentifiers.shared, kind=AssetKind.model.value, multiple=False)
]
FUNCTION_SIMPLE = [
FunctionOutputSpec(identifier=OutputIdentifiers.model, kind=AssetKind.model.value, multiple=False)
FunctionOutputSpec(identifier=OutputIdentifiers.shared, kind=AssetKind.model.value, multiple=False)
]
FUNCTION_COMPOSITE = [
FunctionOutputSpec(identifier=OutputIdentifiers.local, kind=AssetKind.model.value, multiple=False),
Expand Down Expand Up @@ -168,9 +165,9 @@ def task(opener_key, data_sample_keys):
def trains_to_train(model_keys):
return [
InputRef(
identifier=InputIdentifiers.models,
identifier=InputIdentifiers.shared,
parent_task_key=model_key,
parent_task_output_identifier=OutputIdentifiers.model,
parent_task_output_identifier=OutputIdentifiers.shared,
)
for model_key in model_keys
]
Expand All @@ -179,9 +176,9 @@ def trains_to_train(model_keys):
def trains_to_aggregate(model_keys):
return [
InputRef(
identifier=InputIdentifiers.models,
identifier=InputIdentifiers.shared,
parent_task_key=model_key,
parent_task_output_identifier=OutputIdentifiers.model,
parent_task_output_identifier=OutputIdentifiers.shared,
)
for model_key in model_keys
]
Expand All @@ -190,9 +187,9 @@ def trains_to_aggregate(model_keys):
def train_to_predict(model_key):
return [
InputRef(
identifier=InputIdentifiers.model,
identifier=InputIdentifiers.shared,
parent_task_key=model_key,
parent_task_output_identifier=OutputIdentifiers.model,
parent_task_output_identifier=OutputIdentifiers.shared,
)
]

Expand Down Expand Up @@ -252,15 +249,15 @@ def aggregate_to_shared(model_key):
InputRef(
identifier=InputIdentifiers.shared,
parent_task_key=model_key,
parent_task_output_identifier=OutputIdentifiers.model,
parent_task_output_identifier=OutputIdentifiers.shared,
)
]

@staticmethod
def composites_to_aggregate(model_keys):
return [
InputRef(
identifier=InputIdentifiers.models,
identifier=InputIdentifiers.shared,
parent_task_key=model_key,
parent_task_output_identifier=OutputIdentifiers.shared,
)
Expand All @@ -271,17 +268,17 @@ def composites_to_aggregate(model_keys):
def aggregate_to_predict(model_key):
return [
InputRef(
identifier=InputIdentifiers.models,
identifier=InputIdentifiers.shared,
parent_task_key=model_key,
parent_task_output_identifier=OutputIdentifiers.model,
parent_task_output_identifier=OutputIdentifiers.shared,
)
]

@staticmethod
def local_to_aggregate(model_key):
return [
InputRef(
identifier=InputIdentifiers.models,
identifier=InputIdentifiers.shared,
parent_task_key=model_key,
parent_task_output_identifier=OutputIdentifiers.local,
)
Expand All @@ -291,7 +288,7 @@ def local_to_aggregate(model_key):
def shared_to_aggregate(model_key):
return [
InputRef(
identifier=InputIdentifiers.models,
identifier=InputIdentifiers.shared,
parent_task_key=model_key,
parent_task_output_identifier=OutputIdentifiers.shared,
)
Expand All @@ -310,11 +307,11 @@ class FLTaskOutputGenerator:

@staticmethod
def traintask(authorized_ids=None):
return {OutputIdentifiers.model: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))}
return {OutputIdentifiers.shared: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))}

@staticmethod
def aggregatetask(authorized_ids=None):
return {OutputIdentifiers.model: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))}
return {OutputIdentifiers.shared: ComputeTaskOutputSpec(permissions=_permission_from_ids(authorized_ids))}

@staticmethod
def predicttask(authorized_ids=None):
Expand Down
2 changes: 1 addition & 1 deletion tests/sdk/local/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_tasks_extra_fields(self, asset_factory, clients):

assert len(predicttask.inputs) == 3 # data sample + opener + input task

model_ref = [x for x in predicttask.inputs if x.identifier == InputIdentifiers.model]
model_ref = [x for x in predicttask.inputs if x.identifier == InputIdentifiers.shared]
assert len(model_ref) == 1
assert model_ref[0].parent_task_key == traintask_key

Expand Down

0 comments on commit a86fd7c

Please sign in to comment.