Add test for consume name to name mapping (#867)
Fix #863
mrchtr authored Feb 27, 2024
1 parent 76e19bf commit 88c6ea8
Showing 9 changed files with 231 additions and 167 deletions.
4 changes: 3 additions & 1 deletion examples/sample_pipeline/pipeline.py
@@ -62,4 +62,6 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
arguments={"arg_x": "value_x"},
)

dataset.write(ref="write_to_file", arguments={"path": "/data/export"})
dataset.write(
ref="write_to_file", arguments={"path": "/data/export"}, consumes={"text": "text"}
)
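For context on the change above: the `consumes` argument maps a field name in the component's `consumes` section (the key) to a column name in the dataset (the value). A minimal sketch of a non-identity mapping, assuming a hypothetical dataset whose text column is named `raw_text` while the write component consumes a field named `text`:

    # Hypothetical example: the dataset column is "raw_text", but the component
    # spec consumes a field named "text" (operation field -> dataset field).
    dataset.write(
        ref="write_to_file",
        arguments={"path": "/data/export"},
        consumes={"text": "raw_text"},
    )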
10 changes: 6 additions & 4 deletions src/fondant/component/data_io.py
@@ -100,7 +100,7 @@ def load_dataframe(self) -> dd.DataFrame:
DEFAULT_INDEX_NAME,
)

for field_name in self.operation_spec.outer_consumes:
for field_name in self.operation_spec.consumes_from_dataset:
location = self.manifest.get_field_location(field_name)
field_mapping[location].append(field_name)

@@ -130,7 +130,9 @@ def load_dataframe(self) -> dd.DataFrame:
msg = "No data could be loaded"
raise RuntimeError(msg)

if consumes_mapping := self.operation_spec._mappings["consumes"]:
if (
consumes_mapping := self.operation_spec.operation_consumes_to_dataset_consumes
):
dataframe = dataframe.rename(
columns={
v: k for k, v in consumes_mapping.items() if isinstance(v, str)
@@ -160,7 +162,7 @@ def write_dataframe(
dataframe.index = dataframe.index.rename(DEFAULT_INDEX_NAME)

# validation that all columns are in the dataframe
expected_columns = list(self.operation_spec.inner_produces)
expected_columns = list(self.operation_spec.operation_produces)
self.validate_dataframe_columns(dataframe, expected_columns)

dataframe = dataframe[expected_columns]
@@ -202,7 +204,7 @@ def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar:

schema = {
field.name: field.type.value
for field in self.operation_spec.outer_produces.values()
for field in self.operation_spec.produces_to_dataset.values()
}
return self._create_write_task(dataframe, location=location, schema=schema)

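The `load_dataframe` change above inverts the new `operation_consumes_to_dataset_consumes` mapping so that, after loading, dataset column names are replaced by the column names the operation expects. A minimal pandas sketch of that inversion, with hypothetical column names standing in for real dataset fields:

    import pandas as pd

    # Hypothetical consumes mapping: operation field -> dataset field
    consumes_mapping = {"text": "raw_text"}

    df = pd.DataFrame({"raw_text": ["a", "b"]})

    # Invert the mapping (dataset field -> operation field), keeping only
    # string values, mirroring the rename in load_dataframe above.
    df = df.rename(
        columns={v: k for k, v in consumes_mapping.items() if isinstance(v, str)}
    )
    # df now exposes the column as "text", the name the component works with.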
10 changes: 6 additions & 4 deletions src/fondant/component/executor.py
@@ -284,8 +284,8 @@ def _run_execution(

component: Component = component_cls(**self.user_arguments)

component.consumes = self.operation_spec.inner_consumes
component.produces = self.operation_spec.inner_produces
component.consumes = self.operation_spec.operation_consumes
component.produces = self.operation_spec.operation_produces

state = component.setup()

@@ -495,7 +495,7 @@ def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame:
dataframe = transform(dataframe)

# Drop columns not in specification
columns = [name for name, field in operation_spec.inner_produces.items()]
columns = [
name for name, field in operation_spec.operation_produces.items()
]

return dataframe[columns]

@@ -523,7 +525,7 @@ def _execute_component(

# Create meta dataframe with expected format
meta_dict = {"id": pd.Series(dtype="object")}
for field_name, field in self.operation_spec.inner_produces.items():
for field_name, field in self.operation_spec.operation_produces.items():
meta_dict[field_name] = pd.Series(dtype=pd.ArrowDtype(field.type.value))
meta_df = pd.DataFrame(meta_dict).set_index("id")

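The `wrapped_transform` change above keeps only the columns declared in the operation's `produces` schema after the user's transform runs. A minimal pandas sketch of that column filtering, with hypothetical column names:

    import pandas as pd

    # Hypothetical operation produces schema (field name -> type)
    operation_produces = {"text": "string", "language": "string"}

    df = pd.DataFrame({"text": ["a"], "language": ["en"], "debug_info": ["x"]})

    # Drop any produced column that is not in the specification,
    # mirroring wrapped_transform above.
    df = df[list(operation_produces)]
    # df now contains only "text" and "language".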
@@ -21,7 +21,7 @@ def test_pdf_reader():

for path in pdf_path:
component = PDFReader(
produces=dict(spec.inner_produces),
produces=dict(spec.operation_produces),
pdf_path=path,
n_rows_to_load=None,
index_column=None,
115 changes: 74 additions & 41 deletions src/fondant/core/component_spec.py
@@ -353,10 +353,10 @@ def __init__(
}
self._validate_mappings()

self._inner_consumes: t.Optional[t.Mapping[str, Field]] = None
self._outer_consumes: t.Optional[t.Mapping[str, Field]] = None
self._inner_produces: t.Optional[t.Mapping[str, Field]] = None
self._outer_produces: t.Optional[t.Mapping[str, Field]] = None
self._operation_consumes: t.Optional[t.Mapping[str, Field]] = None
self._consumes_from_dataset: t.Optional[t.Mapping[str, Field]] = None
self._operation_produces: t.Optional[t.Mapping[str, Field]] = None
self._produces_to_dataset: t.Optional[t.Mapping[str, Field]] = None

def to_dict(self) -> dict:
def _dump_mapping(
@@ -414,8 +414,10 @@ def _validate_mappings(self) -> None:
msg = f"Unexpected type {type(value)} received for key {key} in {name} mapping"
raise InvalidPipelineDefinition(msg)

def _inner_mapping(self, name: str) -> t.Mapping[str, Field]:
"""Calculate the "inner mapping" of the operation. This is the mapping that the component
def _dataset_schema_to_operation_schema(self, name: str) -> t.Mapping[str, Field]:
"""Calculate the operations schema based on dataset schema.
Maps dataset fields to the fields that the operation will receive.
This is the mapping that the component
`transform` (or equivalent) method will receive. This is calculated by starting from the
component spec section, and updating it with any string to type mappings from the
argument mapping.
@@ -453,73 +455,104 @@ def _inner_mapping(self, name: str) -> t.Mapping[str, Field]:

return types.MappingProxyType(mapping)

def _outer_mapping(self, name: str) -> t.Mapping[str, Field]:
"""Calculate the "outer mapping" of the operation. This is the mapping that the dataIO
needs to read / write. This is calculated by starting from the "inner mapping" updating it
def _operation_schema_to_dataset_schema(self, name: str) -> t.Mapping[str, Field]:
"""
Maps the operation's fields to the dataset fields which will be read or written by DataIO.
This is calculated by starting from the operation schema and updating it
with any string to string mappings from the argument mapping.
Args:
name: "consumes" or "produces"
"""
spec_mapping = getattr(self, f"inner_{name}")
args_mapping = self._mappings[name]
operations_schema = getattr(self, f"operation_{name}")

if not args_mapping:
return spec_mapping
# Mapping from operation field names to dataset field names (or type overrides)
operation_to_dataset_mapping = self._mappings[name]

mapping = dict(spec_mapping)
if not operation_to_dataset_mapping:
return operations_schema

for key, value in args_mapping.items():
if not isinstance(value, str):
mapping = dict(operations_schema)

for (
operations_column_name,
dataset_column_name_or_type,
) in operation_to_dataset_mapping.items():
# If the value is not a string, it means it's a type, so we skip it
if not isinstance(dataset_column_name_or_type, str):
continue

if key in spec_mapping:
mapping[value] = Field(name=value, type=mapping.pop(key).type)
if operations_column_name in operations_schema:
mapping[dataset_column_name_or_type] = Field(
name=dataset_column_name_or_type,
type=mapping.pop(operations_column_name).type,
)
else:
msg = (
f"Received a string value for key `{key}` in the `{name}` "
f"argument passed to the operation, but `{key}` is not defined in "
f"the `{name}` section of the component spec."
f"Received a string value for key `{operations_column_name}` in the `{name}` "
f"argument passed to the operation, but `{operations_column_name}` is not "
f"defined in the `{name}` section of the component spec."
)
raise InvalidPipelineDefinition(msg)

return types.MappingProxyType(mapping)

@property
def inner_consumes(self) -> t.Mapping[str, Field]:
"""The "inner" `consumes` mapping which the component `transform` (or equivalent) method
def operation_consumes(self) -> t.Mapping[str, Field]:
"""The operations `consumes` mapping which the component `transform` (or equivalent) method
will receive.
"""
if self._inner_consumes is None:
self._inner_consumes = self._inner_mapping("consumes")
if self._operation_consumes is None:
self._operation_consumes = self._dataset_schema_to_operation_schema(
"consumes",
)

return self._inner_consumes
return self._operation_consumes

@property
def outer_consumes(self) -> t.Mapping[str, Field]:
"""The "outer" `consumes` mapping which the dataIO needs to read / write."""
if self._outer_consumes is None:
self._outer_consumes = self._outer_mapping("consumes")
def consumes_from_dataset(self) -> t.Mapping[str, Field]:
"""Defines which fields of the dataset are consumed by the operation."""
if self._consumes_from_dataset is None:
self._consumes_from_dataset = self._operation_schema_to_dataset_schema(
"consumes",
)

return self._outer_consumes
return self._consumes_from_dataset

@property
def inner_produces(self) -> t.Mapping[str, Field]:
"""The "inner" `produces` mapping which the component `transform` (or equivalent) method
def operation_produces(self) -> t.Mapping[str, Field]:
"""The operations `produces` mapping which the component `transform` (or equivalent) method
will receive.
"""
if self._inner_produces is None:
self._inner_produces = self._inner_mapping("produces")
if self._operation_produces is None:
self._operation_produces = self._dataset_schema_to_operation_schema(
"produces",
)

return self._inner_produces
return self._operation_produces

@property
def outer_produces(self) -> t.Mapping[str, Field]:
"""The "outer" `produces` mapping which the dataIO needs to read / write."""
if self._outer_produces is None:
self._outer_produces = self._outer_mapping("produces")
def produces_to_dataset(self) -> t.Mapping[str, Field]:
"""The produces schema used by data_io to write the dataset."""
if self._produces_to_dataset is None:
self._produces_to_dataset = self._operation_schema_to_dataset_schema(
"produces",
)

return self._produces_to_dataset

return self._outer_produces
@property
def operation_consumes_to_dataset_consumes(self):
"""
The consumes name mapping. The key is the name of the field in the operation; the value is
the name of the field in the dataset.
E.g.:
{
"OperationField": "DatasetField"
}
"""
return self._mappings["consumes"]

@property
def component_name(self) -> str:
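In short, the renamed properties distinguish two views of the same fields: `operation_consumes` / `operation_produces` describe what the component's `transform` (or equivalent) method sees, while `consumes_from_dataset` / `produces_to_dataset` describe what DataIO reads from and writes to the dataset. A minimal dict-based sketch of the name translation performed by `_operation_schema_to_dataset_schema`, with hypothetical field names and plain strings standing in for `Field` objects and types:

    # Hypothetical operation schema (operation field -> type)
    operation_schema = {"text": "string", "width": "int32"}

    # Hypothetical mapping passed to the operation (operation field -> dataset field);
    # non-string values are type overrides and are skipped, as in the code above.
    consumes = {"text": "raw_text"}

    dataset_schema = dict(operation_schema)
    for op_name, ds_name in consumes.items():
        if isinstance(ds_name, str) and op_name in operation_schema:
            dataset_schema[ds_name] = dataset_schema.pop(op_name)

    print(dataset_schema)  # {"width": "int32", "raw_text": "string"}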
2 changes: 1 addition & 1 deletion src/fondant/core/manifest.py
@@ -271,7 +271,7 @@ def evolve( # : PLR0912 (too many branches)
evolved_manifest.remove_field(field_name)

# Add or update all produced fields defined in the component spec
for name, field in operation_spec.outer_produces.items():
for name, field in operation_spec.produces_to_dataset.items():
# If field was not part of the input manifest, add field to output manifest.
# If field was part of the input manifest and got produced by the component, update
# the manifest field.
75 changes: 65 additions & 10 deletions src/fondant/pipeline/pipeline.py
@@ -1,4 +1,5 @@
"""This module defines classes to represent a Fondant Pipeline."""
import copy
import datetime
import hashlib
import inspect
@@ -145,13 +146,18 @@ def __init__(
cache: t.Optional[bool] = True,
resources: t.Optional[Resources] = None,
component_dir: t.Optional[Path] = None,
dataset_fields: t.Optional[t.Mapping[str, Field]] = None,
) -> None:
self.image = image
self.component_spec = component_spec
self.input_partition_rows = input_partition_rows
self.cache = self._configure_caching_from_image_tag(cache)
self.component_dir = component_dir

if consumes is None:
consumes = self._infer_consumes(component_spec, dataset_fields)
consumes = self._validate_consumes(consumes, component_spec, dataset_fields)

self.operation_spec = OperationSpec(
self.component_spec,
consumes=consumes,
@@ -186,10 +192,6 @@ def from_component_yaml(cls, path, fields=None, **kwargs) -> "ComponentOp":
component_dir / cls.COMPONENT_SPEC_NAME,
)

# If consumes is not defined in the pipeline, we will try to infer it
if kwargs.get("consumes") is None:
kwargs["consumes"] = cls._infer_consumes(component_spec, fields)

image = Image(
base_image=component_spec.image,
)
@@ -198,11 +200,16 @@ def from_component_yaml(cls, path, fields=None, **kwargs) -> "ComponentOp":
image=image,
component_spec=component_spec,
component_dir=component_dir,
dataset_fields=fields,
**kwargs,
)

@classmethod
def _infer_consumes(cls, component_spec, dataset_fields):
def _infer_consumes(
cls,
component_spec,
dataset_fields,
) -> t.Union[t.Optional[t.Dict[str, str]], t.Optional[t.Dict[str, pa.DataType]]]:
"""Infer the consumes section of the component spec."""
if component_spec.consumes_is_defined is False:
msg = (
@@ -230,6 +237,57 @@ def _infer_consumes(cls, component_spec, dataset_fields):
# in the component spec
return {k: v.type.value for k, v in component_spec.consumes.items()}

@classmethod
def _validate_consumes(
cls,
consumes: t.Optional[t.Dict[str, str]],
component_spec: ComponentSpec,
dataset_fields: t.Optional[t.Mapping[str, Field]],
) -> t.Union[t.Optional[t.Dict[str, str]], t.Optional[t.Dict[str, pa.DataType]]]:
"""
Validate the `consumes` argument against the component spec and the dataset fields.
Every column in `consumes` must be present in the dataset fields, and must be
defined in the ComponentSpec unless additionalProperties is set to True, in which
case its type is inferred from the dataset fields.
"""
if consumes is None or dataset_fields is None:
return consumes

validated_consumes = copy.deepcopy(consumes)

for operations_column_name, dataset_column_name_or_type in consumes.items():
# Dataset column name is part of the dataset fields
if (
isinstance(dataset_column_name_or_type, str)
and dataset_column_name_or_type not in dataset_fields.keys()
):
msg = (
f"The dataset does not contain the column {dataset_column_name_or_type} "
f"required by the component {component_spec.name}."
)
raise InvalidPipelineDefinition(msg)

# If operations column name is not in the component spec, but additional properties
# are true we will infer the correct type from the dataset fields
if (
isinstance(dataset_column_name_or_type, str)
and operations_column_name not in component_spec.consumes.keys()
):
if component_spec.consumes_additional_properties:
validated_consumes[operations_column_name] = dataset_fields[
operations_column_name
].type.value
else:
msg = (
f"Received a string value for key `{operations_column_name}` in the "
f"`consumes` argument passed to the operation, "
f"but `{operations_column_name}` is not defined in the `consumes` "
f"section of the component spec."
)
raise InvalidPipelineDefinition(msg)

return validated_consumes

@classmethod
def from_ref(
cls,
@@ -251,13 +309,10 @@ def from_ref(
if issubclass(ref, LightweightComponent):
component_spec = ref.get_component_spec()

# If consumes is not defined in the pipeline, we will try to infer it
if kwargs.get("consumes") is None:
kwargs["consumes"] = cls._infer_consumes(component_spec, fields)

operation = cls(
ref.image(),
component_spec,
dataset_fields=fields,
**kwargs,
)
else:
@@ -565,7 +620,7 @@ def _validate_pipeline_definition(self, run_id: str):
for (
component_field_name,
component_field,
) in operation_spec.outer_consumes.items():
) in operation_spec.consumes_from_dataset.items():
if component_field_name not in manifest.fields:
msg = (
f"Component '{component_op.component_name}' is trying to invoke the"
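Taken together, `_infer_consumes` and `_validate_consumes` mean the `consumes` mapping is checked when the pipeline is defined: every dataset column it references must exist in the dataset, and every operation column must either be declared in the component spec's `consumes` section or be allowed through `additionalProperties`, in which case its type is taken from the dataset fields. A minimal dict-based sketch of that rule, with plain dicts and string type names standing in for the spec, `Field`, and pyarrow type objects:

    # Hypothetical inputs
    dataset_fields = {"text": "string", "language": "string"}  # dataset column -> type
    spec_consumes = {"text": "string"}       # declared in the component spec
    additional_properties = True             # spec allows extra consumed columns

    # "language" is not declared in the spec, so its type is inferred from the dataset.
    consumes = {"text": "text", "language": "language"}

    validated = dict(consumes)
    for op_name, ds_name in consumes.items():
        if isinstance(ds_name, str) and ds_name not in dataset_fields:
            # The real code raises InvalidPipelineDefinition here.
            raise ValueError(f"The dataset does not contain the column {ds_name}")
        if isinstance(ds_name, str) and op_name not in spec_consumes:
            if additional_properties:
                # Replace the name mapping by the type taken from the dataset,
                # mirroring _validate_consumes above.
                validated[op_name] = dataset_fields[op_name]
            else:
                raise ValueError(f"`{op_name}` is not defined in the component spec")

    print(validated)  # {"text": "text", "language": "string"}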