diff --git a/components/load_from_hf_hub/src/main.py b/components/load_from_hf_hub/src/main.py index d4c3c5c15..25fd6f989 100644 --- a/components/load_from_hf_hub/src/main.py +++ b/components/load_from_hf_hub/src/main.py @@ -2,24 +2,31 @@ import logging import typing as t +import dask import dask.dataframe as dd import pandas as pd from fondant.component import DaskLoadComponent +from fondant.component_spec import ComponentSpec logger = logging.getLogger(__name__) +dask.config.set({"dataframe.convert-string": False}) + class LoadFromHubComponent(DaskLoadComponent): - def __init__(self, *_, - dataset_name: str, - column_name_mapping: dict, - image_column_names: t.Optional[list], - n_rows_to_load: t.Optional[int], - index_column:t.Optional[str], + def __init__(self, + spec: ComponentSpec, + *_, + dataset_name: str, + column_name_mapping: dict, + image_column_names: t.Optional[list], + n_rows_to_load: t.Optional[int], + index_column: t.Optional[str], ) -> None: """ Args: + spec: the component spec dataset_name: name of the dataset to load. column_name_mapping: Mapping of the consumed hub dataset to fondant column names image_column_names: A list containing the original hub image column names. Used to @@ -34,6 +41,7 @@ def __init__(self, *_, self.image_column_names = image_column_names self.n_rows_to_load = n_rows_to_load self.index_column = index_column + self.spec = spec def load(self) -> dd.DataFrame: # 1) Load data, read as Dask dataframe @@ -74,14 +82,24 @@ def _set_unique_index(dataframe: pd.DataFrame, partition_info=None): """Function that sets a unique index based on the partition and row number.""" dataframe["id"] = 1 dataframe["id"] = ( - str(partition_info["number"]) - + "_" - + (dataframe.id.cumsum()).astype(str) + str(partition_info["number"]) + + "_" + + (dataframe.id.cumsum()).astype(str) ) dataframe.index = dataframe.pop("id") return dataframe - dask_df = dask_df.map_partitions(_set_unique_index, meta=dask_df.head()) + def _get_meta_df() -> pd.DataFrame: + meta_dict = {"id": pd.Series(dtype="object")} + for subset_name, subset in self.spec.produces.items(): + for field_name, field in subset.fields.items(): + meta_dict[f"{subset_name}_{field_name}"] = pd.Series( + dtype=pd.ArrowDtype(field.type.value), + ) + return pd.DataFrame(meta_dict).set_index("id") + + meta = _get_meta_df() + dask_df = dask_df.map_partitions(_set_unique_index, meta=meta) else: logger.info(f"Setting `{self.index_column}` as index") dask_df = dask_df.set_index(self.index_column, drop=True)