Skip to content

Commit

Permalink
update based on feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippeMoussalli committed Jan 23, 2024
1 parent 8c9d154 commit 0619c41
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 43 deletions.
78 changes: 45 additions & 33 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pyarrow as pa

from fondant.component import BaseComponent, Component
from fondant.core.schema import Field, Type
from fondant.core.schema import Field


@dataclass
Expand All @@ -32,7 +32,7 @@ def image(cls) -> Image:
raise NotImplementedError

@classmethod
def consumes(cls) -> t.Optional[t.Union[list, str]]:
def consumes(cls) -> t.Optional[list]:
pass

@classmethod
Expand All @@ -48,7 +48,7 @@ def lightweight_component(
*args,
extra_requires: t.Optional[t.List[str]] = None,
base_image: t.Optional[str] = None,
consumes: t.Optional[t.Union[list, str]] = None,
consumes: t.Optional[list] = None,
):
"""Decorator to enable a python component."""

Expand Down Expand Up @@ -126,6 +126,41 @@ def validate_abstract_methods_are_implemented(cls):
msg,
)

def modify_consumes_spec(apply_consumes, consumes_spec):
    """Modify fields based on the consumes argument in the 'apply' method.

    String values in ``apply_consumes`` are treated as the name of an existing
    entry in ``consumes_spec``: that entry is removed and re-inserted under the
    mapping key. ``pa.DataType`` values are accepted but leave the spec
    untouched.

    Args:
        apply_consumes: Optional mapping of field name to either a string or a
            ``pa.DataType``. ``None`` or an empty mapping is a no-op.
        consumes_spec: Mutable mapping of field name to its spec. It is
            modified in place.

    Returns:
        The same ``consumes_spec`` object, after applying the renames.

    Raises:
        KeyError: If a string value does not name an entry of ``consumes_spec``.
        ValueError: If a value is neither a string nor a ``pa.DataType``.
    """
    if not apply_consumes:
        return consumes_spec

    for k, v in apply_consumes.items():
        if isinstance(v, str):
            try:
                field_spec = consumes_spec.pop(v)
            except KeyError:
                # Keep the exception type of the bare `dict.pop` failure, but
                # tell the user which mapping entry could not be resolved.
                msg = (
                    f"Field `{v}` (to be consumed as `{k}`) is not available "
                    f"in the dataset."
                )
                raise KeyError(msg) from None
            consumes_spec[k] = field_spec
        elif isinstance(v, pa.DataType):
            # Type-valued entries do not rename anything in the spec.
            continue
        else:
            msg = (
                f"Invalid data type for field `{k}` in the `apply_consumes` "
                f"argument. Only string and pa.DataType are allowed."
            )
            raise ValueError(
                msg,
            )
    return consumes_spec

def filter_consumes_spec(python_component_consumes, consumes_spec):
    """Filter for values that are not in the user defined consumes list.

    Args:
        python_component_consumes: Optional list of field names to keep.
            ``None`` or an empty list disables filtering.
        consumes_spec: Mapping of field name to its spec.

    Returns:
        ``consumes_spec`` itself when no filtering is requested, otherwise a
        new dict holding only the requested fields, in the order they appear
        in ``consumes_spec``.

    Raises:
        ValueError: If a requested field is not present in ``consumes_spec``.
    """
    if not python_component_consumes:
        return consumes_spec

    # Validate every requested field before building the filtered result.
    for field_to_consume in python_component_consumes:
        if field_to_consume not in consumes_spec:
            msg = f"Field `{field_to_consume}` is not available in the dataset."
            raise ValueError(
                msg,
            )

    # Build the lookup set once instead of scanning the list for every key.
    requested = set(python_component_consumes)
    return {k: v for k, v in consumes_spec.items() if k in requested}

validate_abstract_methods_are_implemented(cls)
base_component_cls = get_base_cls(cls)
validate_signatures(base_component_cls, cls)
Expand All @@ -138,7 +173,7 @@ def image(cls) -> Image:
return image

@classmethod
def consumes(cls) -> t.Optional[t.Union[list, str]]:
def consumes(cls) -> t.Optional[list]:
return consumes

@classmethod
Expand All @@ -147,42 +182,19 @@ def get_consumes_spec(
dataset_fields: t.Mapping[str, Field],
apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]],
):
consumes = cls.consumes()

if consumes == "generic":
return {"additionalProperties": True}
python_component_consumes = cls.consumes()

# Get consumes spec from the dataset
consumes_spec = {k: v.type.to_dict() for k, v in dataset_fields.items()}

# Modify naming based on the consumes argument in the 'apply' method
if apply_consumes:
for k, v in apply_consumes.items():
if isinstance(v, str):
consumes_spec[k] = consumes_spec.pop(v)
elif isinstance(v, pa.DataType):
consumes_spec[k] = Type(v).to_dict()
else:
msg = (
f"Invalid data type for field `{k}` in the `apply_consumes` "
f"argument. Only string and pa.DataType are allowed."
)
raise ValueError(
msg,
)
consumes_spec = modify_consumes_spec(apply_consumes, consumes_spec)

# Filter for values that are not in the user defined consumes list
if consumes:
for field_to_consume in consumes:
if field_to_consume not in consumes_spec.keys():
msg = f"Field `{field_to_consume}` is not available in the dataset."
raise ValueError(
msg,
)

consumes_spec = {
k: v for k, v in consumes_spec.items() if k in consumes
}
consumes_spec = filter_consumes_spec(
python_component_consumes,
consumes_spec,
)

return consumes_spec

Expand Down
2 changes: 1 addition & 1 deletion tests/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_component_op(


def test_component_op_python_component(default_pipeline_args):
@lightweight_component(consumes="generic")
@lightweight_component()
class Foo(DaskLoadComponent):
def load(self) -> dd.DataFrame:
df = pd.DataFrame(
Expand Down
17 changes: 8 additions & 9 deletions tests/pipeline/test_python_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def load_pipeline():
@lightweight_component(
base_image="python:3.8-slim-buster",
extra_requires=["pandas", "dask"],
consumes="generic",
)
class CreateData(DaskLoadComponent):
def load(self) -> dd.DataFrame:
Expand Down Expand Up @@ -100,7 +99,7 @@ def test_lightweight_component_sdk(load_pipeline):
},
}

@lightweight_component(consumes="generic")
@lightweight_component
class AddN(PandasTransformComponent):
def __init__(self, n: int, **kwargs):
self.n = n
Expand All @@ -112,7 +111,7 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
_ = dataset.apply(
ref=AddN,
produces={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
consumes={"x": pa.int32(), "y": pa.int32(), "z": pa.int32()},
consumes=None,
arguments={"n": 1},
)
assert len(pipeline._graph.keys()) == 1 + 1
Expand All @@ -123,15 +122,15 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
"name": "AddN",
"image": "fondant:latest",
"description": "python component",
"consumes": {"additionalProperties": True},
"consumes": {
"x": {"type": "int32"},
"y": {"type": "int32"},
"z": {"type": "int32"},
},
"produces": {"additionalProperties": True},
"args": {"n": {"type": "int"}},
},
"consumes": {
"x": {"type": "int32"},
"y": {"type": "int32"},
"z": {"type": "int32"},
},
"consumes": {},
"produces": {
"x": {"type": "int32"},
"y": {"type": "int32"},
Expand Down

0 comments on commit 0619c41

Please sign in to comment.