Skip to content

Commit

Permalink
Fix output dataframe path (#675)
Browse files (browse the repository at this point in the history)
The output dataframe was not written to the correct path:


![image](https://github.com/ml6team/fondant/assets/20990866/3ed6b5d6-9f7b-4c35-85ab-8d6d269c43fa)

It might make sense to centralize this output-path construction somewhere,
but I'm not sure where.
  • Loading branch information
RobbeSneyders authored Nov 27, 2023
1 parent f87217a commit 8b54505
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/fondant/component/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ def validate_dataframe_columns(dataframe: dd.DataFrame, columns: t.List[str]):
def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar:
"""Create dataframe writing task."""
location = (
self.manifest.base_path + "/" + self.component_spec.component_folder_name
f"{self.manifest.base_path}/{self.manifest.pipeline_name}/"
f"{self.manifest.run_id}/{self.component_spec.component_folder_name}"
)

schema = {
field.name: field.type.value
for field in self.component_spec.produces.values()
Expand Down
11 changes: 8 additions & 3 deletions tests/component/test_data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,19 @@ def test_write_dataset(
"""Test writing out subsets."""
# Dictionary specifying the expected subsets to write and their column names
columns = ["Name", "HP", "Type 1", "Type 2"]
with tmp_path_factory.mktemp("temp") as fn:
with tmp_path_factory.mktemp("temp") as temp_dir:
# override the base path of the manifest with the temp dir
manifest.update_metadata("base_path", str(fn))
manifest.update_metadata("base_path", str(temp_dir))
data_writer = DaskDataWriter(manifest=manifest, component_spec=component_spec)
# write dataframe to temp dir
data_writer.write_dataframe(dataframe, dask_client)
# read written data and assert
dataframe = dd.read_parquet(fn)
dataframe = dd.read_parquet(
temp_dir
/ manifest.pipeline_name
/ manifest.run_id
/ component_spec.component_folder_name,
)
assert len(dataframe) == NUMBER_OF_TEST_ROWS
assert list(dataframe.columns) == columns
assert dataframe.index.name == "id"
Expand Down

0 comments on commit 8b54505

Please sign in to comment.