From 5b69298944d6ce1e1cd444b36da328b4eb4f52df Mon Sep 17 00:00:00 2001 From: Philippe Moussalli Date: Tue, 30 Jan 2024 15:56:40 +0100 Subject: [PATCH] implement PR feedback --- components/load_from_hf_hub/Dockerfile | 2 +- src/fondant/pipeline/lightweight_component.py | 27 ++++++++++++++----- src/fondant/pipeline/pipeline.py | 2 +- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/components/load_from_hf_hub/Dockerfile b/components/load_from_hf_hub/Dockerfile index 72525d884..919611e7f 100644 --- a/components/load_from_hf_hub/Dockerfile +++ b/components/load_from_hf_hub/Dockerfile @@ -12,7 +12,7 @@ RUN pip3 install --no-cache-dir -r requirements.txt # Install Fondant # This is split from other requirements to leverage caching ARG FONDANT_VERSION=main -RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@${FONDANT_VERSION} +RUN pip3 install fondant[component,aws,azure,gcp]@git+https://github.com/ml6team/fondant@d87efb9d37fbec8e86b5fc20a6ab480ff67895af # Set the working directory to the component folder WORKDIR /component/src diff --git a/src/fondant/pipeline/lightweight_component.py b/src/fondant/pipeline/lightweight_component.py index 474d96db3..bce83619b 100644 --- a/src/fondant/pipeline/lightweight_component.py +++ b/src/fondant/pipeline/lightweight_component.py @@ -73,35 +73,50 @@ def consumes(cls) -> t.Optional[t.Dict[str, t.Any]]: pass @classmethod - def modify_consumes_spec(cls, apply_consumes, consumes_spec): + def modify_spec_consumes( + cls, + spec_consumes: t.Dict[str, t.Any], + apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]], + ): """Modify fields based on the consumes argument in the 'apply' method.""" if apply_consumes: for k, v in apply_consumes.items(): if isinstance(v, str): - consumes_spec[k] = consumes_spec.pop(v) + spec_consumes[k] = spec_consumes.pop(v) else: msg = ( f"Invalid data type for field `{k}` in the `apply_consumes` " - f"argument. Only string and pa.DataType are allowed." + f"argument. Only string types are allowed." ) raise ValueError( msg, ) - return consumes_spec + return spec_consumes @classmethod - def get_consumes_spec( + def get_spec_consumes( cls, dataset_fields: t.Mapping[str, Field], apply_consumes: t.Optional[t.Dict[str, t.Union[str, pa.DataType]]] = None, ): + """ + Function that get the consumes spec for the component based on the dataset fields and + the apply_consumes argument. + + Args: + dataset_fields: The fields of the dataset. + apply_consumes: The consumes argument in the apply method. + + Returns: + The consumes spec for the component. + """ consumes = cls.consumes() if consumes is None: # Get consumes spec from the dataset spec_consumes = {k: v.type.to_dict() for k, v in dataset_fields.items()} - spec_consumes = cls.modify_consumes_spec(apply_consumes, spec_consumes) + spec_consumes = cls.modify_spec_consumes(spec_consumes, apply_consumes) logger.warning( "No consumes defined. Consumes will be inferred from the dataset." diff --git a/src/fondant/pipeline/pipeline.py b/src/fondant/pipeline/pipeline.py index 5e8152947..46062c6d2 100644 --- a/src/fondant/pipeline/pipeline.py +++ b/src/fondant/pipeline/pipeline.py @@ -228,7 +228,7 @@ def from_ref( description = ref.__doc__ or "lightweight component" consumes_spec = ( - ref.get_consumes_spec(fields, kwargs["consumes"]) + ref.get_spec_consumes(fields, kwargs["consumes"]) if fields else {"additionalProperties": True} )