
Refactor component package #654

Merged
Changes shown from 11 commits. Full commit list (31 commits):
- 35338b6 Update component spec schema validation (mrchtr, Nov 16, 2023)
- a269e3c Update component spec tests to validate new component spec (mrchtr, Nov 16, 2023)
- ad0dab6 Add additional fields to json schema (mrchtr, Nov 16, 2023)
- 7b91535 Update manifest json schema for validation (mrchtr, Nov 16, 2023)
- 5d1bf5e Update manifest creation (mrchtr, Nov 17, 2023)
- d8ecd01 Reduce PR to core module (mrchtr, Nov 21, 2023)
- 12c78ca Addresses comments (mrchtr, Nov 21, 2023)
- c1cad60 Restructure test directory (mrchtr, Nov 21, 2023)
- fd0699c Remove additional fields in common.json (mrchtr, Nov 21, 2023)
- 0f8117f Test structure (mrchtr, Nov 21, 2023)
- 7e8a1d6 Refactor component package (mrchtr, Nov 21, 2023)
- 9f67c61 Update src/fondant/core/component_spec.py (mrchtr, Nov 21, 2023)
- 40955bf Update src/fondant/core/manifest.py (mrchtr, Nov 21, 2023)
- 6b246a4 Update src/fondant/core/component_spec.py (mrchtr, Nov 21, 2023)
- 8ef38d9 Update src/fondant/core/manifest.py (mrchtr, Nov 21, 2023)
- e8c8135 Update src/fondant/core/schema.py (mrchtr, Nov 21, 2023)
- df9a60e Addresses comments (mrchtr, Nov 21, 2023)
- 2256118 Addresses comments (mrchtr, Nov 21, 2023)
- 3042fb5 Addresses comments (mrchtr, Nov 21, 2023)
- 8fa8be7 Update src/fondant/core/manifest.py (mrchtr, Nov 21, 2023)
- 25eb492 Addresses comments (mrchtr, Nov 22, 2023)
- c0fb47a Merge branch 'feature/implement-new-dataset-format' into feautre/refa… (mrchtr, Nov 22, 2023)
- 0701662 Addresses comments (mrchtr, Nov 22, 2023)
- 365ca6d Update test examples (mrchtr, Nov 22, 2023)
- 4dc7dc7 Update src/fondant/core/manifest.py (mrchtr, Nov 22, 2023)
- a60ca3e addresses comments (mrchtr, Nov 22, 2023)
- d2182a0 Merge feature/implement-new-dataset-format into feature/refactore-com… (mrchtr, Nov 22, 2023)
- 43a7b68 Addressing comments regarding data_io (mrchtr, Nov 23, 2023)
- 83a5de6 Merge feature/redesign-dataset-format-and-interface into feature/refa… (mrchtr, Nov 23, 2023)
- 5ac5e42 Update tests (mrchtr, Nov 23, 2023)
- 6616bf2 Remove set_index on during merging (mrchtr, Nov 23, 2023)
156 changes: 55 additions & 101 deletions src/fondant/component/data_io.py
@@ -6,7 +6,7 @@
from dask.diagnostics import ProgressBar
from dask.distributed import Client

from fondant.core.component_spec import ComponentSpec, ComponentSubset
from fondant.core.component_spec import ComponentSpec
from fondant.core.manifest import Manifest

logger = logging.getLogger(__name__)
@@ -82,35 +82,7 @@ def partition_loaded_dataframe(self, dataframe: dd.DataFrame) -> dd.DataFrame:

return dataframe

def _load_subset(self, subset_name: str, fields: t.List[str]) -> dd.DataFrame:
"""
Function that loads a subset from the manifest as a Dask dataframe.

Args:
subset_name: the name of the subset to load
fields: the fields to load from the subset

Returns:
The subset as a dask dataframe
"""
subset = self.manifest.subsets[subset_name]
remote_path = subset.location

logger.info(f"Loading subset {subset_name} with fields {fields}...")

subset_df = dd.read_parquet(
remote_path,
columns=fields,
calculate_divisions=True,
)

# add subset prefix to columns
subset_df = subset_df.rename(
columns={col: subset_name + "_" + col for col in subset_df.columns},
)

return subset_df

# TODO: probably not needed anymore!
def _load_index(self) -> dd.DataFrame:
"""
Function that loads the index from the manifest as a Dask dataframe.
@@ -121,9 +93,10 @@ def _load_index(self) -> dd.DataFrame:
# get index subset from the manifest
index = self.manifest.index
# get remote path
remote_path = index.location
remote_path = index["location"]

# load index from parquet, expecting id and source columns
# TODO: reduce dataframe to index loading? .loc[:, []]?
return dd.read_parquet(remote_path, calculate_divisions=True)

def load_dataframe(self) -> dd.DataFrame:
@@ -135,20 +108,34 @@ def load_dataframe(self) -> dd.DataFrame:
The Dask dataframe with the field columns in the format (<subset>_<column_name>)
as well as the index columns.
"""
# load index into dataframe
dataframe = self._load_index()
for name, subset in self.component_spec.consumes.items():
fields = list(subset.fields.keys())
subset_df = self._load_subset(name, fields)
# left joins -> filter on index
dataframe = dd.merge(
dataframe,
subset_df,
left_index=True,
right_index=True,
how="left",
dataframe = None
field_mapping = self.manifest.field_mapping
for location, fields in field_mapping.items():
partial_df = dd.read_parquet(
location,
columns=fields,
index="id",
calculate_divisions=True,
)

if dataframe is None:
# ensure that the index is set correctly and divisions are known.
dataframe = partial_df
else:
dask_divisions = partial_df.set_index("id").divisions
unique_divisions = list(dict.fromkeys(list(dask_divisions)))

# apply set index to both dataframes
partial_df = partial_df.set_index("id", divisions=unique_divisions)
dataframe = dataframe.set_index("id", divisions=unique_divisions)

dataframe = dataframe.merge(
partial_df,
how="left",
left_index=True,
right_index=True,
)

dataframe = self.partition_loaded_dataframe(dataframe)

logging.info(f"Columns of dataframe: {list(dataframe.columns)}")
@@ -170,79 +157,46 @@ def write_dataframe(
dataframe: dd.DataFrame,
dask_client: t.Optional[Client] = None,
) -> None:
write_tasks = []
columns_to_produce = [
column_name for column_name, field in self.component_spec.produces.items()
]

dataframe.index = dataframe.index.rename("id")
# validation that all columns are in the dataframe
self.validate_dataframe_columns(dataframe, columns_to_produce)

# Turn index into an empty dataframe so we can write it
index_df = dataframe.index.to_frame().drop(columns=["id"])
write_index_task = self._write_subset(
index_df,
subset_name="index",
subset_spec=self.component_spec.index,
)
write_tasks.append(write_index_task)

for subset_name, subset_spec in self.component_spec.produces.items():
subset_df = self._extract_subset_dataframe(
dataframe,
subset_name=subset_name,
subset_spec=subset_spec,
)
write_subset_task = self._write_subset(
subset_df,
subset_name=subset_name,
subset_spec=subset_spec,
)
write_tasks.append(write_subset_task)
dataframe = dataframe[columns_to_produce]
write_task = self._write_dataframe(dataframe)

with ProgressBar():
logging.info("Writing data...")
# alternative implementation possible: futures = client.compute(...)
dd.compute(*write_tasks, scheduler=dask_client)
dd.compute(write_task, scheduler=dask_client)

@staticmethod
def _extract_subset_dataframe(
dataframe: dd.DataFrame,
*,
subset_name: str,
subset_spec: ComponentSubset,
) -> dd.DataFrame:
"""Create subset dataframe to save with the original field name as the column name."""
# Create a new dataframe with only the columns needed for the output subset
subset_columns = [f"{subset_name}_{field}" for field in subset_spec.fields]
try:
subset_df = dataframe[subset_columns]
except KeyError as e:
def validate_dataframe_columns(dataframe: dd.DataFrame, columns: t.List[str]):
"""Validates that all columns are available in the dataset."""
missing_fields = []
for col in columns:
if col not in dataframe.columns:
missing_fields.append(col)

if missing_fields:
msg = (
f"Field {e.args[0]} defined in output subset {subset_name} "
f"Fields {missing_fields} defined in output dataset "
f"but not found in dataframe"
)
raise ValueError(
msg,
)

# Remove the subset prefix from the column names
subset_df = subset_df.rename(
columns={col: col[(len(f"{subset_name}_")) :] for col in subset_columns},
def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar:
"""Create dataframe writing task."""
location = (
self.manifest.base_path + "/" + self.component_spec.component_folder_name
)

return subset_df

def _write_subset(
self,
dataframe: dd.DataFrame,
*,
subset_name: str,
subset_spec: ComponentSubset,
) -> dd.core.Scalar:
if subset_name == "index":
location = self.manifest.index.location
else:
location = self.manifest.subsets[subset_name].location

schema = {field.name: field.type.value for field in subset_spec.fields.values()}

schema = {
field.name: field.type.value
for field in self.component_spec.produces.values()
}
return self._create_write_task(dataframe, location=location, schema=schema)

@staticmethod
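
For context, the data_io refactor above replaces the old per-subset loading with a single pass over the manifest's field mapping: every parquet location is read with only the columns it contributes, indexed on "id", and left-joined onto the accumulated dataframe. Below is a minimal, self-contained sketch of that loading pattern; the field_mapping dict and the storage paths in the usage note are hypothetical stand-ins for what the Fondant manifest would supply, and the division re-alignment step from the diff is omitted for brevity.

import typing as t

import dask.dataframe as dd


def load_dataframe(field_mapping: t.Dict[str, t.List[str]]) -> dd.DataFrame:
    """Read each parquet location and left-join it onto the accumulated frame on the shared "id" index."""
    dataframe = None
    for location, fields in field_mapping.items():
        # Only read the columns this location contributes, and keep divisions for fast merges.
        partial_df = dd.read_parquet(
            location,
            columns=fields,
            index="id",
            calculate_divisions=True,
        )
        if dataframe is None:
            dataframe = partial_df
        else:
            dataframe = dataframe.merge(
                partial_df,
                how="left",
                left_index=True,
                right_index=True,
            )
    return dataframe


# Hypothetical usage: two write locations, each contributing columns keyed by "id".
# dataframe = load_dataframe({
#     "gs://bucket/pipeline/load_from_hub": ["image", "caption"],
#     "gs://bucket/pipeline/embed_text": ["embedding"],
# })
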
36 changes: 11 additions & 25 deletions src/fondant/component/executor.py
@@ -491,42 +491,25 @@ def optional_fondant_arguments() -> t.List[str]:
@staticmethod
def wrap_transform(transform: t.Callable, *, spec: ComponentSpec) -> t.Callable:
"""Factory that creates a function to wrap the component transform function. The wrapper:
- Converts the columns to hierarchical format before passing the dataframe to the
transform function
- Removes extra columns from the returned dataframe which are not defined in the component
spec `produces` section
- Sorts the columns from the returned dataframe according to the order in the component
spec `produces` section to match the order in the `meta` argument passed to Dask's
`map_partitions`.
- Flattens the returned dataframe columns.

Args:
transform: Transform method to wrap
spec: Component specification to base behavior on
"""

def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame:
# Switch to hierarchical columns
dataframe.columns = pd.MultiIndex.from_tuples(
tuple(column.split("_")) for column in dataframe.columns
)

# Call transform method
dataframe = transform(dataframe)

# Drop columns not in specification
columns = [
(subset_name, field)
for subset_name, subset in spec.produces.items()
for field in subset.fields
]
dataframe = dataframe[columns]

# Switch to flattened columns
dataframe.columns = [
"_".join(column) for column in dataframe.columns.to_flat_index()
]
return dataframe
columns = [name for name, field in spec.produces.items()]

return dataframe[columns]

return wrapped_transform

@@ -552,11 +535,8 @@ def _execute_component(

# Create meta dataframe with expected format
meta_dict = {"id": pd.Series(dtype="object")}
for subset_name, subset in self.spec.produces.items():
for field_name, field in subset.fields.items():
meta_dict[f"{subset_name}_{field_name}"] = pd.Series(
dtype=pd.ArrowDtype(field.type.value),
)
for field_name, field in self.spec.produces.items():
meta_dict[field_name] = pd.Series(dtype=pd.ArrowDtype(field.type.value))
meta_df = pd.DataFrame(meta_dict).set_index("id")

wrapped_transform = self.wrap_transform(component.transform, spec=self.spec)
@@ -569,12 +549,16 @@

# Clear divisions if component spec indicates that the index is changed
if self._infer_index_change():
# TODO: might causing issues for merging components
# to guarantee fast merging of large dataframes we need to keep the division information
dataframe.clear_divisions()

return dataframe

# TODO: fix in #244
def _infer_index_change(self) -> bool:
"""Infer if this component changes the index based on its component spec."""
"""
if not self.spec.accepts_additional_subsets:
return True
if not self.spec.outputs_additional_subsets:
@@ -585,6 +569,8 @@ def _infer_index_change(self) -> bool:
return any(
not subset.additional_fields for subset in self.spec.produces.values()
)
"""
return False


class DaskWriteExecutor(Executor[DaskWriteComponent]):
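
For context, the executor change above removes the subset-aware column juggling: the wrapped transform no longer converts columns to a hierarchical format, it simply calls the user transform and keeps the produced columns in spec order. A minimal sketch follows, under the assumption that the spec reduces to an ordered list of produced column names (the real code derives this from spec.produces).

import typing as t

import pandas as pd


def wrap_transform(
    transform: t.Callable[[pd.DataFrame], pd.DataFrame],
    *,
    produces: t.List[str],
) -> t.Callable[[pd.DataFrame], pd.DataFrame]:
    """Return a wrapper that applies the transform and keeps only the produced columns, in order."""

    def wrapped_transform(dataframe: pd.DataFrame) -> pd.DataFrame:
        dataframe = transform(dataframe)
        # Dropping extra columns and fixing their order keeps the output aligned with the
        # meta dataframe passed to Dask's map_partitions.
        return dataframe[produces]

    return wrapped_transform


# Hypothetical usage:
# wrapped = wrap_transform(
#     lambda df: df.assign(caption_length=df["caption"].str.len()),
#     produces=["caption", "caption_length"],
# )
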
52 changes: 7 additions & 45 deletions src/fondant/core/component_spec.py
@@ -66,34 +66,6 @@ def kubeflow_type(self) -> str:
return lookup[self.type]


class ComponentSubset:
"""
Class representing a Fondant Component subset.

Args:
specification: the part of the component json representing the subset
"""

def __init__(self, specification: t.Dict[str, t.Any]) -> None:
self._specification = specification

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._specification!r})"

@property
def fields(self) -> t.Mapping[str, Field]:
return types.MappingProxyType(
{
name: Field(name=name, type=Type.from_json(field))
for name, field in self._specification["fields"].items()
},
)

@property
def additional_fields(self) -> bool:
return self._specification.get("additionalFields", True)


class ComponentSpec:
"""
Class representing a Fondant component specification.
@@ -191,38 +163,28 @@ def tags(self) -> t.List[str]:

@property
def index(self):
return ComponentSubset({"fields": {}})
return Field(name="index", location=self._specification["index"].location)

@property
def consumes(self) -> t.Mapping[str, ComponentSubset]:
def consumes(self) -> t.Mapping[str, Field]:
"""The subsets consumed by the component as an immutable mapping."""
return types.MappingProxyType(
{
name: ComponentSubset(subset)
for name, subset in self._specification.get("consumes", {}).items()
if name != "additionalSubsets"
name: Field(name=name, type=Type.from_json(field))
for name, field in self._specification["consumes"].items()
},
)

@property
def produces(self) -> t.Mapping[str, ComponentSubset]:
def produces(self) -> t.Mapping[str, Field]:
"""The subsets produced by the component as an immutable mapping."""
return types.MappingProxyType(
{
name: ComponentSubset(subset)
for name, subset in self._specification.get("produces", {}).items()
if name != "additionalSubsets"
name: Field(name=name, type=Type.from_json(field))
for name, field in self._specification["produces"].items()
},
)

@property
def accepts_additional_subsets(self) -> bool:
return self._specification.get("consumes", {}).get("additionalSubsets", True)

@property
def outputs_additional_subsets(self) -> bool:
return self._specification.get("produces", {}).get("additionalSubsets", True)

@property
def args(self) -> t.Mapping[str, Argument]:
args = self.default_arguments
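
For context, the spec change above drops ComponentSubset entirely: consumes and produces now map field names directly to Field objects built from a flat spec section. A minimal sketch of that shape follows, using a simplified Field stand-in rather than fondant.core.schema.

import types
import typing as t
from dataclasses import dataclass


@dataclass
class Field:
    """Simplified stand-in for fondant.core.schema.Field."""

    name: str
    type: str  # the real spec wraps pyarrow types via Type.from_json


def consumes(specification: t.Dict[str, t.Any]) -> t.Mapping[str, Field]:
    """Expose the "consumes" section of a spec as a read-only name -> Field mapping."""
    return types.MappingProxyType(
        {
            name: Field(name=name, type=field["type"])
            for name, field in specification["consumes"].items()
        },
    )


# Hypothetical usage with a flattened spec section (no subsets):
# fields = consumes({"consumes": {"image": {"type": "binary"}, "caption": {"type": "string"}}})
# fields["caption"].type  -> "string"
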