Skip to content

Commit

Permalink
implement PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippeMoussalli committed Jan 30, 2024
1 parent 85f0994 commit 5b69298
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
2 changes: 1 addition & 1 deletion components/load_from_hf_hub/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ RUN pip3 install --no-cache-dir -r requirements.txt
# Install Fondant
# This is split from other requirements to leverage caching
ARG FONDANT_VERSION=main
RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION}
RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@d87efb9d37fbec8e86b5fc20a6ab480ff67895af

# Set the working directory to the component folder
WORKDIR /component/src
Expand Down
27 changes: 21 additions & 6 deletions src/fondant/pipeline/lightweight_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,35 +73,50 @@ def consumes(cls) -> t.Optional[t.Dict[str, t.Any]]:
pass

@classmethod
def modify_consumes_spec(cls, apply_consumes, consumes_spec):
def modify_spec_consumes(
cls,
spec_consumes: t.Dict[str, t.Any],
apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]],
):
"""Modify fields based on the consumes argument in the 'apply' method."""
if apply_consumes:
for k, v in apply_consumes.items():
if isinstance(v, str):
consumes_spec[k] = consumes_spec.pop(v)
spec_consumes[k] = spec_consumes.pop(v)
else:
msg = (
f"Invalid data type for field `{k}` in the `apply_consumes` "
f"argument. Only string and pa.DataType are allowed."
f"argument. Only string types are allowed."
)
raise ValueError(
msg,
)
return consumes_spec
return spec_consumes

@classmethod
def get_consumes_spec(
def get_spec_consumes(
cls,
dataset_fields: t.Mapping[str, Field],
apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None,
):
"""
Function that gets the consumes spec for the component based on the dataset fields and
the apply_consumes argument.
Args:
dataset_fields: The fields of the dataset.
apply_consumes: The consumes argument in the apply method.
Returns:
The consumes spec for the component.
"""
consumes = cls.consumes()

if consumes is None:
# Get consumes spec from the dataset
spec_consumes = {k: v.type.to_dict() for k, v in dataset_fields.items()}

spec_consumes = cls.modify_consumes_spec(apply_consumes, spec_consumes)
spec_consumes = cls.modify_spec_consumes(spec_consumes, apply_consumes)

logger.warning(
"No consumes defined. Consumes will be inferred from the dataset."
Expand Down
2 changes: 1 addition & 1 deletion src/fondant/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def from_ref(
description = ref.__doc__ or "lightweight component"

consumes_spec = (
ref.get_consumes_spec(fields, kwargs["consumes"])
ref.get_spec_consumes(fields, kwargs["consumes"])
if fields
else {"additionalProperties": True}
)
Expand Down

0 comments on commit 5b69298

Please sign in to comment.