diff --git a/src/fondant/component/data_io.py b/src/fondant/component/data_io.py index 7023c1ee2..79a181f8d 100644 --- a/src/fondant/component/data_io.py +++ b/src/fondant/component/data_io.py @@ -1,16 +1,19 @@ import logging import os import typing as t +from collections import defaultdict import dask.dataframe as dd from dask.diagnostics import ProgressBar from dask.distributed import Client -from fondant.core.component_spec import ComponentSpec, ComponentSubset +from fondant.core.component_spec import ComponentSpec from fondant.core.manifest import Manifest logger = logging.getLogger(__name__) +DEFAULT_INDEX_NAME = "id" + class DataIO: def __init__(self, *, manifest: Manifest, component_spec: ComponentSpec) -> None: @@ -82,73 +85,48 @@ def partition_loaded_dataframe(self, dataframe: dd.DataFrame) -> dd.DataFrame: return dataframe - def _load_subset(self, subset_name: str, fields: t.List[str]) -> dd.DataFrame: + def load_dataframe(self) -> dd.DataFrame: """ - Function that loads a subset from the manifest as a Dask dataframe. - - Args: - subset_name: the name of the subset to load - fields: the fields to load from the subset + Function that loads the subsets defined in the component spec as a single Dask dataframe for + the user. Returns: - The subset as a dask dataframe + The Dask dataframe with all columns defined in the manifest field mapping """ - subset = self.manifest.subsets[subset_name] - remote_path = subset.location - - logger.info(f"Loading subset {subset_name} with fields {fields}...") + dataframe = None + field_mapping = defaultdict(list) - subset_df = dd.read_parquet( - remote_path, - columns=fields, - calculate_divisions=True, + # Add index field to field mapping to guarantee start reading with the index dataframe + field_mapping[self.manifest.get_field_location(DEFAULT_INDEX_NAME)].append( + DEFAULT_INDEX_NAME, ) - # add subset prefix to columns - subset_df = subset_df.rename( - columns={col: subset_name + "_" + col for col in subset_df.columns}, - ) + for field_name in self.component_spec.consumes: + location = self.manifest.get_field_location(field_name) + field_mapping[location].append(field_name) - return subset_df - - def _load_index(self) -> dd.DataFrame: - """ - Function that loads the index from the manifest as a Dask dataframe. - - Returns: - The index as a dask dataframe - """ - # get index subset from the manifest - index = self.manifest.index - # get remote path - remote_path = index.location - - # load index from parquet, expecting id and source columns - return dd.read_parquet(remote_path, calculate_divisions=True) - - def load_dataframe(self) -> dd.DataFrame: - """ - Function that loads the subsets defined in the component spec as a single Dask dataframe for - the user. + for location, fields in field_mapping.items(): + if DEFAULT_INDEX_NAME in fields: + fields.remove(DEFAULT_INDEX_NAME) - Returns: - The Dask dataframe with the field columns in the format (_) - as well as the index columns. - """ - # load index into dataframe - dataframe = self._load_index() - for name, subset in self.component_spec.consumes.items(): - fields = list(subset.fields.keys()) - subset_df = self._load_subset(name, fields) - # left joins -> filter on index - dataframe = dd.merge( - dataframe, - subset_df, - left_index=True, - right_index=True, - how="left", + partial_df = dd.read_parquet( + location, + columns=fields, + index=DEFAULT_INDEX_NAME, + calculate_divisions=True, ) + if dataframe is None: + # ensure that the index is set correctly and divisions are known. 
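The loading path above groups the consumed fields by their parquet location and reads the location holding the index first. A minimal standalone sketch of that read-and-merge pattern, assuming illustrative locations and field names (the component derives both from the manifest; nothing below is the real manifest API):

from collections import defaultdict

import dask.dataframe as dd

DEFAULT_INDEX_NAME = "id"

# Hypothetical location -> fields mapping; in the real loader this is built
# from the manifest and the component spec's `consumes` section.
field_mapping = defaultdict(list)
field_mapping["/base_path/pipeline/run_1/component_1"].extend(["id", "Name", "HP"])
field_mapping["/base_path/pipeline/run_1/component_2"].extend(["Type 1", "Type 2"])

dataframe = None
for location, fields in field_mapping.items():
    # The index is read via `index=`, so it is dropped from the column list.
    columns = [f for f in fields if f != DEFAULT_INDEX_NAME]
    partial_df = dd.read_parquet(
        location,
        columns=columns,
        index=DEFAULT_INDEX_NAME,
        calculate_divisions=True,
    )
    if dataframe is None:
        # First read is the dataset containing the index.
        dataframe = partial_df
    else:
        # Left join on the shared index, so rows are filtered on the index dataset.
        dataframe = dataframe.merge(
            partial_df,
            how="left",
            left_index=True,
            right_index=True,
        )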
+ dataframe = partial_df + else: + dataframe = dataframe.merge( + partial_df, + how="left", + left_index=True, + right_index=True, + ) + dataframe = self.partition_loaded_dataframe(dataframe) logging.info(f"Columns of dataframe: {list(dataframe.columns)}") @@ -170,79 +148,48 @@ def write_dataframe( dataframe: dd.DataFrame, dask_client: t.Optional[Client] = None, ) -> None: - write_tasks = [] + columns_to_produce = [ + column_name for column_name, field in self.component_spec.produces.items() + ] - dataframe.index = dataframe.index.rename("id") + dataframe.index = dataframe.index.rename(DEFAULT_INDEX_NAME) - # Turn index into an empty dataframe so we can write it - index_df = dataframe.index.to_frame().drop(columns=["id"]) - write_index_task = self._write_subset( - index_df, - subset_name="index", - subset_spec=self.component_spec.index, - ) - write_tasks.append(write_index_task) + # validation that all columns are in the dataframe + self.validate_dataframe_columns(dataframe, columns_to_produce) - for subset_name, subset_spec in self.component_spec.produces.items(): - subset_df = self._extract_subset_dataframe( - dataframe, - subset_name=subset_name, - subset_spec=subset_spec, - ) - write_subset_task = self._write_subset( - subset_df, - subset_name=subset_name, - subset_spec=subset_spec, - ) - write_tasks.append(write_subset_task) + dataframe = dataframe[columns_to_produce] + write_task = self._write_dataframe(dataframe) with ProgressBar(): logging.info("Writing data...") - # alternative implementation possible: futures = client.compute(...) - dd.compute(*write_tasks, scheduler=dask_client) + dd.compute(write_task, scheduler=dask_client) @staticmethod - def _extract_subset_dataframe( - dataframe: dd.DataFrame, - *, - subset_name: str, - subset_spec: ComponentSubset, - ) -> dd.DataFrame: - """Create subset dataframe to save with the original field name as the column name.""" - # Create a new dataframe with only the columns needed for the output subset - subset_columns = [f"{subset_name}_{field}" for field in subset_spec.fields] - try: - subset_df = dataframe[subset_columns] - except KeyError as e: + def validate_dataframe_columns(dataframe: dd.DataFrame, columns: t.List[str]): + """Validates that all columns are available in the dataset.""" + missing_fields = [] + for col in columns: + if col not in dataframe.columns: + missing_fields.append(col) + + if missing_fields: msg = ( - f"Field {e.args[0]} defined in output subset {subset_name} " + f"Fields {missing_fields} defined in output dataset " f"but not found in dataframe" ) raise ValueError( msg, ) - # Remove the subset prefix from the column names - subset_df = subset_df.rename( - columns={col: col[(len(f"{subset_name}_")) :] for col in subset_columns}, + def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar: + """Create dataframe writing task.""" + location = ( + self.manifest.base_path + "/" + self.component_spec.component_folder_name ) - - return subset_df - - def _write_subset( - self, - dataframe: dd.DataFrame, - *, - subset_name: str, - subset_spec: ComponentSubset, - ) -> dd.core.Scalar: - if subset_name == "index": - location = self.manifest.index.location - else: - location = self.manifest.subsets[subset_name].location - - schema = {field.name: field.type.value for field in subset_spec.fields.values()} - + schema = { + field.name: field.type.value + for field in self.component_spec.produces.values() + } return self._create_write_task(dataframe, location=location, schema=schema) @staticmethod diff --git 
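On the write side a single task replaces the per-subset writes: the produced columns are selected, validated, and written to one dataset whose schema comes from the `produces` section. A condensed sketch with illustrative location and schema values; the real writer goes through `_create_write_task`, which `to_parquet(..., compute=False)` only approximates here:

import dask.dataframe as dd
import pandas as pd
import pyarrow as pa

# Illustrative stand-ins for the manifest base path, component folder, and produces spec.
location = "/tmp/base_path/component_1"
produces_schema = {"Name": pa.string(), "HP": pa.int32()}

dataframe = dd.from_pandas(
    pd.DataFrame({"id": [0, 1], "Name": ["a", "b"], "HP": [10, 20], "extra": [1, 2]}),
    npartitions=1,
).set_index("id")
# Mirror of the writer's index rename, which also covers unnamed indexes.
dataframe.index = dataframe.index.rename("id")

# Keep only the produced columns; the writer validates their presence first.
dataframe = dataframe[list(produces_schema)]

# Build the lazy write task and compute it once, as write_dataframe does.
write_task = dataframe.to_parquet(location, schema=produces_schema, compute=False)
dd.compute(write_task)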
a/src/fondant/component/executor.py b/src/fondant/component/executor.py index 3d4d6097f..d77200da8 100644 --- a/src/fondant/component/executor.py +++ b/src/fondant/component/executor.py @@ -491,14 +491,11 @@ def optional_fondant_arguments() -> t.List[str]: @staticmethod def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable: """Factory that creates a function to wrap the component transform function. The wrapper: - - Converts the columns to hierarchical format before passing the dataframe to the - transform function - Removes extra columns from the returned dataframe which are not defined in the component spec `produces` section - Sorts the columns from the returned dataframe according to the order in the component spec `produces` section to match the order in the `meta` argument passed to Dask's `map_partitions`. - - Flattens the returned dataframe columns. Args: transform: Transform method to wrap @@ -506,27 +503,13 @@ def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable: """ def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame: - # Switch to hierarchical columns - dataframe.columns = pd.MultiIndex.from_tuples( - tuple(column.split("_")) for column in dataframe.columns - ) - # Call transform method dataframe = transform(dataframe) # Drop columns not in specification - columns = [ - (subset_name, field) - for subset_name, subset in spec.produces.items() - for field in subset.fields - ] - dataframe = dataframe[columns] - - # Switch to flattened columns - dataframe.columns = [ - "_".join(column) for column in dataframe.columns.to_flat_index() - ] - return dataframe + columns = [name for name, field in spec.produces.items()] + + return dataframe[columns] return wrapped_transform @@ -552,11 +535,8 @@ def _execute_component( # Create meta dataframe with expected format meta_dict = {"id": pd.Series(dtype="object")} - for subset_name, subset in self.spec.produces.items(): - for field_name, field in subset.fields.items(): - meta_dict[f"{subset_name}_{field_name}"] = pd.Series( - dtype=pd.ArrowDtype(field.type.value), - ) + for field_name, field in self.spec.produces.items(): + meta_dict[field_name] = pd.Series(dtype=pd.ArrowDtype(field.type.value)) meta_df = pd.DataFrame(meta_dict).set_index("id") wrapped_transform = self.wrap_transform(component.transform, spec=self.spec) @@ -573,8 +553,10 @@ def _execute_component( return dataframe + # TODO: fix in #244 def _infer_index_change(self) -> bool: """Infer if this component changes the index based on its component spec.""" + """ if not self.spec.accepts_additional_subsets: return True if not self.spec.outputs_additional_subsets: @@ -585,6 +567,8 @@ def _infer_index_change(self) -> bool: return any( not subset.additional_fields for subset in self.spec.produces.values() ) + """ + return False class DaskWriteExecutor(Executor[DaskWriteComponent]): diff --git a/src/fondant/core/manifest.py b/src/fondant/core/manifest.py index fc750620d..013ce2b71 100644 --- a/src/fondant/core/manifest.py +++ b/src/fondant/core/manifest.py @@ -4,7 +4,6 @@ import pkgutil import types import typing as t -from collections import OrderedDict from dataclasses import asdict, dataclass from pathlib import Path @@ -146,7 +145,7 @@ def metadata(self) -> t.Dict[str, t.Any]: @property def index(self) -> Field: - return Field(name="Index", location=self._specification["index"]["location"]) + return Field(name="id", location=self._specification["index"]["location"]) def update_metadata(self, key: str, value: t.Any) -> None: 
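With the subset prefixes gone, the transform wrapper reduces to a column projection against the `produces` section, which also fixes the column order expected by the `meta` frame passed to `map_partitions`. A self-contained sketch of that wrapping behaviour, with the spec stood in by a plain list of produced column names rather than a real ComponentSpec:

import typing as t

import pandas as pd


def wrap_transform(transform: t.Callable, *, produces: t.List[str]) -> t.Callable:
    def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame:
        dataframe = transform(dataframe)
        # Drop columns not in the produces section and order the rest to match it.
        return dataframe[produces]

    return wrapped_transform


def transform(df: pd.DataFrame) -> pd.DataFrame:
    df["HP"] = df["HP"] * 2
    df["debug"] = "drop me"  # not produced, removed by the wrapper
    return df


wrapped = wrap_transform(transform, produces=["Name", "HP"])
out = wrapped(pd.DataFrame({"Name": ["a"], "HP": [10]}))
# out has exactly the columns ["Name", "HP"], in spec order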
self.metadata[key] = value @@ -155,43 +154,16 @@ def update_metadata(self, key: str, value: t.Any) -> None: def base_path(self) -> str: return self.metadata["base_path"] - @property - def field_mapping(self) -> t.Mapping[str, t.List[str]]: - """ - Retrieve a mapping of field locations to corresponding field names. - A dictionary where keys are field locations and values are lists - of column names. - - The method returns an immutable OrderedDict where the first dict element contains the - location of the dataframe with the index. This allows an efficient left join operation. - - Example: - { - "/base_path/component_1": ["Name", "HP"], - "/base_path/component_2": ["Type 1", "Type 2"], - } - """ - field_mapping = {} - for field_name, field in {"id": self.index, **self.fields}.items(): - location = ( - f"{self.base_path}/{self.pipeline_name}/{self.run_id}{field.location}" - ) - if location in field_mapping: - field_mapping[location].append(field_name) - else: - field_mapping[location] = [field_name] - - # Sort field mapping that the first dataset contains the index - sorted_keys = sorted( - field_mapping.keys(), - key=lambda key: "id" in field_mapping[key], - reverse=True, - ) - sorted_field_mapping = OrderedDict( - (key, field_mapping[key]) for key in sorted_keys - ) + def get_field_location(self, field_name: str): + """Return absolute path to the field location.""" + if field_name == "id": + return f"{self.base_path}/{self.pipeline_name}/{self.run_id}{self.index.location}" + if field_name not in self.fields: + msg = f"Field {field_name} is not available in the manifest." + raise ValueError(msg) - return types.MappingProxyType(sorted_field_mapping) + field = self.fields[field_name] + return f"{self.base_path}/{self.pipeline_name}/{self.run_id}{field.location}" @property def run_id(self) -> str: diff --git a/tests/component/examples/component_specs/arguments/component.yaml b/tests/component/examples/component_specs/arguments/component.yaml new file mode 100644 index 000000000..659ed0026 --- /dev/null +++ b/tests/component/examples/component_specs/arguments/component.yaml @@ -0,0 +1,68 @@ +name: Example component +description: This is an example component +image: example_component:latest + +args: + string_default_arg: + description: default string argument + type: str + default: foo + integer_default_arg: + description: default integer argument + type: int + default: 0 + float_default_arg: + description: default float argument + type: float + default: 3.14 + bool_false_default_arg: + description: default bool argument + type: bool + default: False + bool_true_default_arg: + description: default bool argument + type: bool + default: True + list_default_arg: + description: default list argument + type: list + default: ["foo", "bar"] + dict_default_arg: + description: default dict argument + type: dict + default: {"foo":1, "bar":2} + string_default_arg_none: + description: default string argument + type: str + default: None + integer_default_arg_none: + description: default integer argument + type: int + default: 0 + float_default_arg_none: + description: default float argument + type: float + default: 0.0 + bool_default_arg_none: + description: default bool argument + type: bool + default: False + list_default_arg_none: + description: default list argument + type: list + default: [] + dict_default_arg_none: + description: default dict argument + type: dict + default: {} + override_default_arg: + description: argument with default python value type that can be overriden + type: str + default: foo 
+ override_default_arg_with_none: + description: argument with default python type that can be overriden with None + type: str + optional_arg: + description: optional argument + type: str + default: None diff --git a/tests/component/examples/component_specs/arguments/component_default_args.yaml b/tests/component/examples/component_specs/arguments/component_default_args.yaml new file mode 100644 index 000000000..816211c04 --- /dev/null +++ b/tests/component/examples/component_specs/arguments/component_default_args.yaml @@ -0,0 +1,69 @@ +name: Example component +description: This is an example component +image: example_component:latest + +args: + string_default_arg: + description: default string argument + type: str + default: foo + integer_default_arg: + description: default integer argument + type: int + default: 1 + float_default_arg: + description: default float argument + type: float + default: 3.14 + bool_false_default_arg: + description: default bool argument + type: bool + default: False + bool_true_default_arg: + description: default bool argument + type: bool + default: True + list_default_arg: + description: default list argument + type: list + default: ["foo", "bar"] + dict_default_arg: + description: default dict argument + type: dict + default: {"foo":1, "bar":2} + string_default_arg_none: + description: default string argument + type: str + default: None + integer_default_arg_none: + description: default integer argument + type: int + default: None + float_default_arg_none: + description: default float argument + type: float + default: None + bool_default_arg_none: + description: default bool argument + type: bool + default: None + list_default_arg_none: + description: default list argument + type: list + default: None + dict_default_arg_none: + description: default dict argument + type: dict + default: None + override_default_arg: + description: argument with default python value type that can be overriden + type: str + default: foo + override_default_none_arg: + description: argument with default None value type that can be overriden with a valid python type + type: float + default: None + override_default_arg_with_none: + description: argument with default python type that can be overriden with None + type: str + diff --git a/tests/component/examples/component_specs/arguments/input_manifest.json b/tests/component/examples/component_specs/arguments/input_manifest.json new file mode 100644 index 000000000..9ee2494f9 --- /dev/null +++ b/tests/component/examples/component_specs/arguments/input_manifest.json @@ -0,0 +1,18 @@ +{ + "metadata": { + "pipeline_name": "example_pipeline", + "base_path": "tests/example_data/subsets_input/mock_base_path", + "run_id": "example_pipeline_123", + "component_id": "component_1", + "cache_key": "00" + }, + "index": { + "location": "/component_1" + }, + "fields": { + "data": { + "type": "binary", + "location": "/component_1" + } + } +} \ No newline at end of file diff --git a/tests/component/examples/component_specs/component.yaml b/tests/component/examples/component_specs/component.yaml new file mode 100644 index 000000000..973cc3e6b --- /dev/null +++ b/tests/component/examples/component_specs/component.yaml @@ -0,0 +1,23 @@ +name: Example component +description: This is an example component +image: example_component:latest + +consumes: + images_data: + type: binary + +produces: + images_data: + type: array + items: + type: float32 +additionalFields: false + + +args: + flag: + description: user argument + type: str + value: + description: 
integer value + type: int diff --git a/tests/component/examples/component_specs/input_manifest.json b/tests/component/examples/component_specs/input_manifest.json new file mode 100644 index 000000000..80fa0b91d --- /dev/null +++ b/tests/component/examples/component_specs/input_manifest.json @@ -0,0 +1,17 @@ +{ + "metadata": { + "pipeline_name": "test_pipeline", + "base_path": "/bucket", + "run_id": "test_pipeline_12345", + "component_id": "67890" + }, + "index": { + "location": "/example_component" + }, + "fields": { + "data": { + "location": "/example_component", + "type": "binary" + } + } +} \ No newline at end of file diff --git a/tests/component/examples/data/components/1.yaml b/tests/component/examples/data/components/1.yaml new file mode 100644 index 000000000..95e5e578f --- /dev/null +++ b/tests/component/examples/data/components/1.yaml @@ -0,0 +1,29 @@ +name: Test component 1 +description: This is an example component +image: example_component:latest + +consumes: + Name: + type: "string" + HP: + type: "int32" + + Type 1: + type: "string" + Type 2: + type: "string" + +produces: + Name: + type: "string" + HP: + type: "int32" + Type 1: + type: "string" + Type 2: + type: "string" + +args: + storage_args: + description: Storage arguments + type: str \ No newline at end of file diff --git a/tests/component/examples/data/manifest.json b/tests/component/examples/data/manifest.json new file mode 100644 index 000000000..cc579fef1 --- /dev/null +++ b/tests/component/examples/data/manifest.json @@ -0,0 +1,29 @@ +{ + "metadata": { + "pipeline_name": "test_pipeline", + "base_path": "tests/component/examples/data", + "run_id": "test_pipeline_12345", + "component_id": "67890" + }, + "index": { + "location": "/component_1" + }, + "fields": { + "Name": { + "type": "string", + "location": "/component_1" + }, + "HP": { + "type": "int32", + "location": "/component_1" + }, + "Type 1": { + "type": "string", + "location": "/component_2" + }, + "Type 2": { + "type": "string", + "location": "/component_2" + } + } +} \ No newline at end of file diff --git a/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.0.parquet b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.0.parquet new file mode 100644 index 000000000..fa5d96dad Binary files /dev/null and b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.0.parquet differ diff --git a/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.1.parquet b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.1.parquet new file mode 100644 index 000000000..0c86db04d Binary files /dev/null and b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.1.parquet differ diff --git a/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.2.parquet b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.2.parquet new file mode 100644 index 000000000..d226a4249 Binary files /dev/null and b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_1/part.2.parquet differ diff --git a/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.0.parquet b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.0.parquet new file mode 100644 index 000000000..80c4500be Binary files /dev/null and b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.0.parquet differ diff 
--git a/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.1.parquet b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.1.parquet new file mode 100644 index 000000000..2dd74184f Binary files /dev/null and b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.1.parquet differ diff --git a/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.2.parquet b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.2.parquet new file mode 100644 index 000000000..1ae8001c0 Binary files /dev/null and b/tests/component/examples/data/test_pipeline/test_pipeline_12345/component_2/part.2.parquet differ diff --git a/tests/component/examples/mock_base_path/example_pipeline/cache/42.txt b/tests/component/examples/mock_base_path/example_pipeline/cache/42.txt new file mode 100644 index 000000000..4a9ff8afc --- /dev/null +++ b/tests/component/examples/mock_base_path/example_pipeline/cache/42.txt @@ -0,0 +1 @@ +tests/component/examples/mock_base_path/example_pipeline/example_pipeline_2023/component_1/manifest.json \ No newline at end of file diff --git a/tests/component/examples/mock_base_path/example_pipeline/example_pipeline_2023/component_1/manifest.json b/tests/component/examples/mock_base_path/example_pipeline/example_pipeline_2023/component_1/manifest.json new file mode 100644 index 000000000..47c2fe949 --- /dev/null +++ b/tests/component/examples/mock_base_path/example_pipeline/example_pipeline_2023/component_1/manifest.json @@ -0,0 +1,31 @@ +{ + "metadata": { + "pipeline_name": "example_pipeline", + "base_path": "tests/example_data/subsets_input/mock_base_path", + "run_id": "example_pipeline_2023", + "component_id": "component_1", + "cache_key": "42" + }, + "index": { + "location": "/component_1" + }, + "fields": + { + "data": { + "type": "binary", + "location": "/component_1" + }, + "height": { + "type": "int32", + "location": "/component_1" + }, + "width": { + "type": "int32", + "location": "/component_1" + }, + "captions": { + "type": "string", + "location": "/component_1" + } + } +} \ No newline at end of file diff --git a/tests/test_component.py b/tests/component/test_component.py similarity index 99% rename from tests/test_component.py rename to tests/component/test_component.py index e5dcb3bc3..830ce2963 100644 --- a/tests/test_component.py +++ b/tests/component/test_component.py @@ -23,8 +23,8 @@ from fondant.core.component_spec import ComponentSpec from fondant.core.manifest import Manifest, Metadata -components_path = Path(__file__).parent / "example_specs/components" -base_path = Path(__file__).parent / "example_specs/mock_base_path" +components_path = Path(__file__).parent / "examples/component_specs" +base_path = Path(__file__).parent / "examples/mock_base_path" N_PARTITIONS = 2 diff --git a/tests/test_data_io.py b/tests/component/test_data_io.py similarity index 61% rename from tests/test_data_io.py rename to tests/component/test_data_io.py index 9ade4a329..30a4b7c10 100644 --- a/tests/test_data_io.py +++ b/tests/component/test_data_io.py @@ -8,8 +8,10 @@ from fondant.core.component_spec import ComponentSpec from fondant.core.manifest import Manifest -manifest_path = Path(__file__).parent / "example_data/manifest.json" -component_spec_path = Path(__file__).parent / "example_data/components/1.yaml" +manifest_path = Path(__file__).parent / "examples/data/manifest.json" +component_spec_path = ( + Path(__file__).parent / 
"examples/data/components/1.yaml" +) NUMBER_OF_TEST_ROWS = 151 @@ -37,33 +39,16 @@ def dataframe(manifest, component_spec): return data_loader.load_dataframe() -def test_load_index(manifest, component_spec): - """Test the loading of just the index.""" - data_loader = DaskDataLoader(manifest=manifest, component_spec=component_spec) - index_df = data_loader._load_index() - assert len(index_df) == NUMBER_OF_TEST_ROWS - assert index_df.index.name == "id" - - -def test_load_subset(manifest, component_spec): - """Test the loading of one field of a subset.""" - data_loader = DaskDataLoader(manifest=manifest, component_spec=component_spec) - subset_df = data_loader._load_subset(subset_name="types", fields=["Type 1"]) - assert len(subset_df) == NUMBER_OF_TEST_ROWS - assert list(subset_df.columns) == ["types_Type 1"] - assert subset_df.index.name == "id" - - def test_load_dataframe(manifest, component_spec): - """Test merging of subsets in a dataframe based on a component_spec.""" + """Test merging of fields in a dataframe based on a component_spec.""" dl = DaskDataLoader(manifest=manifest, component_spec=component_spec) dataframe = dl.load_dataframe() assert len(dataframe) == NUMBER_OF_TEST_ROWS assert list(dataframe.columns) == [ - "properties_Name", - "properties_HP", - "types_Type 1", - "types_Type 2", + "Name", + "HP", + "Type 1", + "Type 2", ] assert dataframe.index.name == "id" @@ -78,7 +63,7 @@ def test_load_dataframe_default(manifest, component_spec): def test_load_dataframe_rows(manifest, component_spec): - """Test merging of subsets in a dataframe based on a component_spec.""" + """Test merging of fields in a dataframe based on a component_spec.""" dl = DaskDataLoader( manifest=manifest, component_spec=component_spec, @@ -89,34 +74,7 @@ def test_load_dataframe_rows(manifest, component_spec): assert dataframe.npartitions == expected_partitions -def test_write_index( - tmp_path_factory, - dataframe, - manifest, - component_spec, - dask_client, -): - """Test writing out the index.""" - with tmp_path_factory.mktemp("temp") as fn: - # override the base path of the manifest with the temp dir - manifest.update_metadata("base_path", str(fn)) - data_writer = DaskDataWriter( - manifest=manifest, - component_spec=component_spec, - ) - # write out index to temp dir - data_writer.write_dataframe(dataframe, dask_client) - number_workers = os.cpu_count() - # read written data and assert - dataframe = dd.read_parquet(fn / "index") - assert len(dataframe) == NUMBER_OF_TEST_ROWS - assert dataframe.index.name == "id" - assert dataframe.npartitions in list( - range(number_workers - 1, number_workers + 2), - ) - - -def test_write_subsets( +def test_write_dataset( tmp_path_factory, dataframe, manifest, @@ -125,11 +83,7 @@ def test_write_subsets( ): """Test writing out subsets.""" # Dictionary specifying the expected subsets to write and their column names - subset_columns_dict = { - "index": [], - "properties": ["Name", "HP"], - "types": ["Type 1", "Type 2"], - } + columns = ["Name", "HP", "Type 1", "Type 2"] with tmp_path_factory.mktemp("temp") as fn: # override the base path of the manifest with the temp dir manifest.update_metadata("base_path", str(fn)) @@ -137,13 +91,13 @@ def test_write_subsets( # write dataframe to temp dir data_writer.write_dataframe(dataframe, dask_client) # read written data and assert - for subset, subset_columns in subset_columns_dict.items(): - dataframe = dd.read_parquet(fn / subset) - assert len(dataframe) == NUMBER_OF_TEST_ROWS - assert list(dataframe.columns) == subset_columns 
- assert dataframe.index.name == "id" + dataframe = dd.read_parquet(fn) + assert len(dataframe) == NUMBER_OF_TEST_ROWS + assert list(dataframe.columns) == columns + assert dataframe.index.name == "id" +# TODO: check if this is still needed? def test_write_reset_index( tmp_path_factory, dataframe, @@ -151,7 +105,7 @@ def test_write_reset_index( component_spec, dask_client, ): - """Test writing out the index and subsets that have no dask index and checking + """Test writing out the index and fields that have no dask index and checking if the id index was created. """ dataframe = dataframe.reset_index(drop=True) @@ -160,10 +114,8 @@ def test_write_reset_index( data_writer = DaskDataWriter(manifest=manifest, component_spec=component_spec) data_writer.write_dataframe(dataframe, dask_client) - - for subset in ["properties", "types", "index"]: - dataframe = dd.read_parquet(fn / subset) - assert dataframe.index.name == "id" + dataframe = dd.read_parquet(fn) + assert dataframe.index.name == "id" @pytest.mark.parametrize("partitions", list(range(1, 5))) @@ -189,29 +141,51 @@ def test_write_divisions( # noqa: PLR0913 data_writer.write_dataframe(dataframe, dask_client) - for target in ["properties", "types", "index"]: - dataframe = dd.read_parquet(fn / target) - assert dataframe.index.name == "id" - assert dataframe.npartitions == partitions + dataframe = dd.read_parquet(fn) + assert dataframe.index.name == "id" + assert dataframe.npartitions == partitions + + +def test_write_fields_invalid( + tmp_path_factory, + dataframe, + manifest, + component_spec, + dask_client, +): + """Test writing out fields but the dataframe columns are incomplete.""" + with tmp_path_factory.mktemp("temp") as fn: + # override the base path of the manifest with the temp dir + manifest.update_metadata("base_path", str(fn)) + # Drop one of the columns required in the output + dataframe = dataframe.drop(["Type 2"], axis=1) + data_writer = DaskDataWriter(manifest=manifest, component_spec=component_spec) + expected_error_msg = ( + r"Fields \['Type 2'\] defined in output dataset " + r"but not found in dataframe" + ) + with pytest.raises(ValueError, match=expected_error_msg): + data_writer.write_dataframe(dataframe, dask_client) -def test_write_subsets_invalid( +def test_write_fields_invalid_several_fields_missing( tmp_path_factory, dataframe, manifest, component_spec, dask_client, ): - """Test writing out subsets but the dataframe columns are incomplete.""" + """Test writing out fields but the dataframe columns are incomplete.""" with tmp_path_factory.mktemp("temp") as fn: # override the base path of the manifest with the temp dir manifest.update_metadata("base_path", str(fn)) # Drop one of the columns required in the output - dataframe = dataframe.drop(["types_Type 2"], axis=1) + dataframe = dataframe.drop(["Type 1"], axis=1) + dataframe = dataframe.drop(["Type 2"], axis=1) data_writer = DaskDataWriter(manifest=manifest, component_spec=component_spec) expected_error_msg = ( - r"Field \['types_Type 2'\] not in index defined in output subset " - r"types but not found in dataframe" + r"Fields \['Type 1', 'Type 2'\] defined in output dataset " + r"but not found in dataframe" ) with pytest.raises(ValueError, match=expected_error_msg): data_writer.write_dataframe(dataframe, dask_client) diff --git a/tests/core/test_manifest.py b/tests/core/test_manifest.py index 0b255b9df..c24d27c9c 100644 --- a/tests/core/test_manifest.py +++ b/tests/core/test_manifest.py @@ -1,6 +1,5 @@ import json import pkgutil -from collections import OrderedDict 
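The two new failure tests pin down the wording of the column check. The assertion they rely on can be reproduced in isolation like this, with a toy pandas frame instead of the fixtures and the bare message construction instead of the writer:

import re

import pandas as pd

df = pd.DataFrame({"Name": ["a"], "HP": [1], "Type 1": ["x"], "Type 2": ["y"]})
df = df.drop(["Type 1", "Type 2"], axis=1)

# Same message format as DaskDataWriter.validate_dataframe_columns.
missing = [c for c in ["Name", "HP", "Type 1", "Type 2"] if c not in df.columns]
message = f"Fields {missing} defined in output dataset but not found in dataframe"

expected_error_msg = (
    r"Fields \['Type 1', 'Type 2'\] defined in output dataset "
    r"but not found in dataframe"
)
assert re.search(expected_error_msg, message)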
from pathlib import Path import pytest @@ -226,21 +225,3 @@ def test_fields(): # delete a field manifest.remove_field(name="field_1") assert "field_1" not in manifest.fields - - -def test_field_mapping(valid_manifest): - """Test field mapping generation.""" - manifest = Manifest(valid_manifest) - manifest.add_or_update_field(Field(name="index", location="component2")) - field_mapping = manifest.field_mapping - assert field_mapping == OrderedDict( - { - "gs://bucket/test_pipeline/test_pipeline_12345/component2": [ - "id", - "height", - "width", - ], - "gs://bucket/test_pipeline/test_pipeline_12345/component1": ["images"], - "gs://bucket/test_pipeline/test_pipeline_12345/component3": ["caption"], - }, - ) diff --git a/tests/examples/example_data/raw/split.py b/tests/examples/example_data/raw/split.py index 6800ee323..ade466125 100644 --- a/tests/examples/example_data/raw/split.py +++ b/tests/examples/example_data/raw/split.py @@ -13,7 +13,7 @@ import dask.dataframe as dd data_path = Path(__file__).parent -output_path = Path(__file__).parent.parent / "subsets_input/" +output_path = Path(__file__).parent.parent def split_into_subsets(): @@ -22,17 +22,13 @@ def split_into_subsets(): master_df = master_df.set_index("id", sorted=True) master_df = master_df.repartition(divisions=[0, 50, 100, 151], force=True) - # create index subset - index_df = master_df.index.to_frame().drop(columns=["id"]) - index_df.to_parquet(output_path / "index") - # create properties subset properties_df = master_df[["Name", "HP"]] - properties_df.to_parquet(output_path / "properties") + properties_df.to_parquet(output_path / "component_1") # create types subset types_df = master_df[["Type 1", "Type 2"]] - types_df.to_parquet(output_path / "types") + types_df.to_parquet(output_path / "component_2") if __name__ == "__main__":
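For reference, the fixture layout written by the updated split.py is what the example manifest's field locations point at: each component folder holds only the columns it produced, with the index carried by the parquet index. A small sketch that generates an equivalent layout from toy data (the output path, divisions, and values are illustrative; the real script builds master_df from the raw example data):

from pathlib import Path

import dask.dataframe as dd
import pandas as pd

output_path = Path("/tmp/example_data")

# Toy stand-in for the raw source data used by the fixture script.
master_df = dd.from_pandas(
    pd.DataFrame(
        {
            "id": range(6),
            "Name": list("abcdef"),
            "HP": [10, 20, 30, 40, 50, 60],
            "Type 1": ["x"] * 6,
            "Type 2": ["y"] * 6,
        },
    ),
    npartitions=1,
)
master_df = master_df.set_index("id", sorted=True)
master_df = master_df.repartition(divisions=[0, 2, 4, 5], force=True)

# Columns are split per producing component; the manifest's field locations
# point at these folders, and no separate index dataset is written anymore.
master_df[["Name", "HP"]].to_parquet(output_path / "component_1")
master_df[["Type 1", "Type 2"]].to_parquet(output_path / "component_2")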