Unify operation and component spec (#741)
Fixes #728 

We now construct the operation spec once during the compilation process,
serialize it, and pass it to the component as an argument. The `produces`
and `consumes` arguments no longer need to be passed in separately.

Attributes that are part of the component spec, such as the arguments and
the previous index, can still be accessed through the operation spec,
since the component spec is one of its attributes.
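
Roughly, the new flow looks like this (a minimal sketch: `OperationSpec.from_dict`, `args`, `previous_index`, `inner_produces`, and `field.type.value` are taken from the diff below; the helper functions themselves are illustrative only):

```python
import json

from fondant.core.component_spec import OperationSpec


def load_operation_spec(serialized_spec: str) -> OperationSpec:
    """Deserialize the operation spec that the compiler serialized to JSON."""
    return OperationSpec.from_dict(json.loads(serialized_spec))


def describe(operation_spec: OperationSpec) -> None:
    """Show which component spec attributes stay reachable through the operation spec."""
    print(operation_spec.args)            # the component's arguments
    print(operation_spec.previous_index)  # index column of the previous component

    # Schemas are read from the operation spec instead of being passed in as
    # separate `produces`/`consumes` arguments:
    for field_name, field in operation_spec.inner_produces.items():
        print(field_name, field.type.value)
```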
PhilippeMoussalli authored Jan 2, 2024
1 parent 1be3e6b commit 9d91ccd
Showing 12 changed files with 213 additions and 213 deletions.
1 change: 0 additions & 1 deletion components/load_from_csv/src/main.py
@@ -23,7 +23,6 @@ def __init__(
     ) -> None:
         """
         Args:
             spec: the component spec
-            produces: The schema the component should produce
             dataset_uri: The remote path to the csv file/folder containing the dataset
             column_separator: Separator to use when parsing csv
8 changes: 4 additions & 4 deletions components/load_with_llamahub/src/main.py
@@ -7,7 +7,7 @@
 import dask.dataframe as dd
 import pandas as pd
 from fondant.component import DaskLoadComponent
-from fondant.core.component_spec import ComponentSpec
+from fondant.core.component_spec import OperationSpec
 from llama_index import download_loader
 
 logger = logging.getLogger(__name__)
@@ -16,7 +16,7 @@
 class LlamaHubReader(DaskLoadComponent):
     def __init__(
         self,
-        spec: ComponentSpec,
+        spec: OperationSpec,
         *,
         loader_class: str,
         loader_kwargs: dict,
@@ -73,7 +73,7 @@ def _set_unique_index(dataframe: pd.DataFrame, partition_info=None):
 
         def _get_meta_df() -> pd.DataFrame:
             meta_dict = {"id": pd.Series(dtype="object")}
-            for field_name, field in self.spec.produces.items():
+            for field_name, field in self.spec.inner_produces.items():
                 meta_dict[field_name] = pd.Series(
                     dtype=pd.ArrowDtype(field.type.value),
                 )
@@ -95,7 +95,7 @@ def load(self) -> dd.DataFrame:
 
         doc_dict = defaultdict(list)
         for d, document in enumerate(documents):
-            for column in self.spec.produces:
+            for column in self.spec.inner_produces:
                 if column == "text":
                     doc_dict["text"].append(document.text)
                 else:
2 changes: 1 addition & 1 deletion src/fondant/component/data_io.py
@@ -199,7 +199,7 @@ def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar:
         """Create dataframe writing task."""
         location = (
             f"{self.manifest.base_path}/{self.manifest.pipeline_name}/"
-            f"{self.manifest.run_id}/{self.operation_spec.specification.component_folder_name}"
+            f"{self.manifest.run_id}/{self.operation_spec.component_folder_name}"
         )
 
         schema = {
69 changes: 26 additions & 43 deletions src/fondant/component/executor.py
@@ -16,7 +16,6 @@
 import dask
 import dask.dataframe as dd
 import pandas as pd
-import pyarrow as pa
 from dask.distributed import Client, LocalCluster
 from fsspec import open as fs_open
 
@@ -28,9 +27,8 @@
     PandasTransformComponent,
 )
 from fondant.component.data_io import DaskDataLoader, DaskDataWriter
-from fondant.core.component_spec import Argument, ComponentSpec, OperationSpec
+from fondant.core.component_spec import Argument, OperationSpec
 from fondant.core.manifest import Manifest, Metadata
-from fondant.core.schema import Type
 
 dask.config.set({"dataframe.convert-string": False})
 logger = logging.getLogger(__name__)
@@ -41,7 +39,7 @@ class Executor(t.Generic[Component]):
    An executor executes a Component.
 
    Args:
-        spec: The specification of the Component to be executed.
+        operation_spec: The operation spec of the component to be executed.
        cache: Flag indicating whether to use caching for intermediate results.
        input_manifest_path: The path to the input manifest file.
        output_manifest_path: The path to the output manifest file.
@@ -55,11 +53,13 @@ class Executor(t.Generic[Component]):
            (default is "local").
        client_kwargs: Additional keyword arguments dict which will be used to
            initialise the dask client, allowing for advanced configuration.
+       previous_index: The name of the index column of the previous component.
+           Used to remove all previous fields if the component changes the index.
    """
 
    def __init__(
        self,
-        spec: ComponentSpec,
+        operation_spec: OperationSpec,
        *,
        cache: bool,
        input_manifest_path: t.Union[str, Path],
@@ -69,18 +69,16 @@ def __init__(
         input_partition_rows: int,
         cluster_type: t.Optional[str] = None,
         client_kwargs: t.Optional[dict] = None,
-        consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
-        produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
+        previous_index: t.Optional[str] = None,
     ) -> None:
-        self.spec = spec
+        self.operation_spec = operation_spec
         self.cache = cache
         self.input_manifest_path = input_manifest_path
         self.output_manifest_path = output_manifest_path
         self.metadata = Metadata.from_dict(metadata)
         self.user_arguments = user_arguments
         self.input_partition_rows = input_partition_rows
-
-        self.operation_spec = OperationSpec(spec, consumes=consumes, produces=produces)
+        self.previous_index = previous_index
 
         if cluster_type == "local":
             client_kwargs = client_kwargs or {
@@ -113,48 +111,42 @@ def __init__(
     def from_args(cls) -> "Executor":
         """Create an executor from a passed argument containing the specification as a dict."""
         parser = argparse.ArgumentParser()
-        parser.add_argument("--component_spec", type=json.loads)
+        parser.add_argument("--operation_spec", type=json.loads)
         parser.add_argument("--cache", type=lambda x: bool(strtobool(x)))
         parser.add_argument("--input_partition_rows", type=int)
         parser.add_argument("--cluster_type", type=str)
         parser.add_argument("--client_kwargs", type=json.loads)
-        parser.add_argument("--consumes", type=cls._parse_mapping)
-        parser.add_argument("--produces", type=cls._parse_mapping)
         args, _ = parser.parse_known_args()
 
-        if "component_spec" not in args:
-            msg = "Error: The --component_spec argument is required."
+        if "operation_spec" not in args:
+            msg = "Error: The --operation_spec argument is required."
             raise ValueError(msg)
 
-        component_spec = ComponentSpec(args.component_spec)
+        operation_spec = OperationSpec.from_dict(args.operation_spec)
 
         return cls.from_spec(
-            component_spec,
+            operation_spec,
             cache=args.cache,
             input_partition_rows=args.input_partition_rows,
             cluster_type=args.cluster_type,
             client_kwargs=args.client_kwargs,
-            consumes=args.consumes,
-            produces=args.produces,
         )

     @classmethod
     def from_spec(
         cls,
-        component_spec: ComponentSpec,
+        operation_spec: OperationSpec,
         *,
         cache: bool,
         input_partition_rows: int,
         cluster_type: t.Optional[str],
         client_kwargs: t.Optional[dict],
-        consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
-        produces: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
     ) -> "Executor":
         """Create an executor from a component spec."""
-        args_dict = vars(cls._add_and_parse_args(component_spec))
+        args_dict = vars(cls._add_and_parse_args(operation_spec))
 
         for argument in [
-            "component_spec",
+            "operation_spec",
             "input_partition_rows",
             "cache",
             "cluster_type",
@@ -170,7 +162,7 @@ def from_spec(
         metadata = json.loads(metadata) if metadata else {}
 
         return cls(
-            component_spec,
+            operation_spec,
             input_manifest_path=input_manifest_path,
             output_manifest_path=output_manifest_path,
             cache=cache,
@@ -179,24 +171,13 @@
             input_partition_rows=input_partition_rows,
             cluster_type=cluster_type,
             client_kwargs=client_kwargs,
-            consumes=consumes,
-            produces=produces,
+            previous_index=operation_spec.previous_index,
         )
 
-    @staticmethod
-    def _parse_mapping(json_mapping: str) -> t.Mapping:
-        """Parse a json mapping to a Python mapping with Fondant types."""
-        mapping = json.loads(json_mapping)
-
-        for key, value in mapping.items():
-            if isinstance(value, dict):
-                mapping[key] = Type.from_json(value).value
-        return mapping
-
     @classmethod
-    def _add_and_parse_args(cls, spec: ComponentSpec):
+    def _add_and_parse_args(cls, operation_spec: OperationSpec):
         parser = argparse.ArgumentParser()
-        component_arguments = cls._get_component_arguments(spec)
+        component_arguments = cls._get_component_arguments(operation_spec)
 
         for arg in component_arguments.values():
             if arg.name in cls.optional_fondant_arguments():
@@ -232,17 +213,19 @@ def optional_fondant_arguments() -> t.List[str]:
         return []
 
     @staticmethod
-    def _get_component_arguments(spec: ComponentSpec) -> t.Dict[str, Argument]:
+    def _get_component_arguments(
+        operation_spec: OperationSpec,
+    ) -> t.Dict[str, Argument]:
         """
         Get the component arguments as a dictionary representation containing both input and output
         arguments of a component.
         Args:
-            spec: the component spec
+            operation_spec: the operation spec
         Returns:
             Input and output arguments of the component.
         """
         component_arguments: t.Dict[str, Argument] = {}
-        component_arguments.update(spec.args)
+        component_arguments.update(operation_spec.args)
         return component_arguments
 
     @abstractmethod
@@ -580,7 +563,7 @@ def _execute_component(
         )
 
         # Clear divisions if component spec indicates that the index is changed
-        if self.spec.previous_index is not None:
+        if self.previous_index is not None:
             dataframe.clear_divisions()
 
         return dataframe
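
For component authors, a minimal hedged sketch of the new signature (only `OperationSpec`, `DaskLoadComponent`, `inner_produces`, and `field.type.value` come from the diffs above; the component name and its load logic are illustrative):

```python
import dask.dataframe as dd
import pandas as pd
from fondant.component import DaskLoadComponent
from fondant.core.component_spec import OperationSpec


class MyReader(DaskLoadComponent):  # illustrative component name
    def __init__(self, spec: OperationSpec, *, dataset_uri: str) -> None:
        # The operation spec arrives as a single argument; no separate
        # `produces`/`consumes` arguments are passed in anymore.
        self.spec = spec
        self.dataset_uri = dataset_uri

    def load(self) -> dd.DataFrame:
        # Build an empty dataframe with the schema this operation produces,
        # read from the operation spec rather than from a produces argument.
        meta = {
            field_name: pd.Series(dtype=pd.ArrowDtype(field.type.value))
            for field_name, field in self.spec.inner_produces.items()
        }
        return dd.from_pandas(pd.DataFrame(meta), npartitions=1)
```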