Skip to content

Commit

Permalink
Remove subset usage in components
Browse files Browse the repository at this point in the history
  • Loading branch information
mrchtr committed Nov 20, 2023
1 parent b91c071 commit 3e5f45c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 40 deletions.
25 changes: 10 additions & 15 deletions components/load_from_hf_hub/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,12 @@ def get_columns_to_keep(self) -> t.List[str]:
else:
invert_column_name_mapping = {}

for subset_name, subset in self.spec.produces.items():
for field_name, field in subset.fields.items():
column_name = f"{subset_name}_{field_name}"
if (
invert_column_name_mapping
and column_name in invert_column_name_mapping
):
columns.append(invert_column_name_mapping[column_name])
else:
columns.append(column_name)
for field_name, field in self.spec.produces.items():
column_name = field_name
if invert_column_name_mapping and column_name in invert_column_name_mapping:
columns.append(invert_column_name_mapping[column_name])
else:
columns.append(column_name)

if self.index_column is not None:
columns.append(self.index_column)
Expand Down Expand Up @@ -99,11 +95,10 @@ def _set_unique_index(dataframe: pd.DataFrame, partition_info=None):

def _get_meta_df() -> pd.DataFrame:
meta_dict = {"id": pd.Series(dtype="object")}
for subset_name, subset in self.spec.produces.items():
for field_name, field in subset.fields.items():
meta_dict[f"{subset_name}_{field_name}"] = pd.Series(
dtype=pd.ArrowDtype(field.type.value),
)
for field_name, field in self.spec.produces.items():
meta_dict[field_name] = pd.Series(
dtype=pd.ArrowDtype(field.type.value),
)
return pd.DataFrame(meta_dict).set_index("id")

meta = _get_meta_df()
Expand Down
25 changes: 10 additions & 15 deletions components/load_from_parquet/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,12 @@ def get_columns_to_keep(self) -> t.List[str]:
else:
invert_column_name_mapping = {}

for subset_name, subset in self.spec.produces.items():
for field_name, field in subset.fields.items():
column_name = f"{subset_name}_{field_name}"
if (
invert_column_name_mapping
and column_name in invert_column_name_mapping
):
columns.append(invert_column_name_mapping[column_name])
else:
columns.append(column_name)
for field_name, field in self.spec.produces.items():
column_name = field_name
if invert_column_name_mapping and column_name in invert_column_name_mapping:
columns.append(invert_column_name_mapping[column_name])
else:
columns.append(column_name)

if self.index_column is not None:
columns.append(self.index_column)
Expand All @@ -85,11 +81,10 @@ def _set_unique_index(dataframe: pd.DataFrame, partition_info=None):

def _get_meta_df() -> pd.DataFrame:
meta_dict = {"id": pd.Series(dtype="object")}
for subset_name, subset in self.spec.produces.items():
for field_name, field in subset.fields.items():
meta_dict[f"{subset_name}_{field_name}"] = pd.Series(
dtype=pd.ArrowDtype(field.type.value),
)
for field_name, field in self.spec.produces.items():
meta_dict[field_name] = pd.Series(
dtype=pd.ArrowDtype(field.type.value),
)
return pd.DataFrame(meta_dict).set_index("id")

meta = _get_meta_df()
Expand Down
19 changes: 9 additions & 10 deletions components/write_to_hf_hub/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,15 @@ def write(
# Get columns to write and schema
write_columns = []
schema_dict = {}
for subset_name, subset in self.spec.consumes.items():
for field in subset.fields.values():
column_name = f"{subset_name}_{field.name}"
write_columns.append(column_name)
if self.image_column_names and column_name in self.image_column_names:
schema_dict[column_name] = datasets.Image()
else:
schema_dict[column_name] = generate_from_arrow_type(
field.type.value,
)
for field_name, field in self.spec.consumes.items():
column_name = field.name
write_columns.append(column_name)
if self.image_column_names and column_name in self.image_column_names:
schema_dict[column_name] = datasets.Image()
else:
schema_dict[column_name] = generate_from_arrow_type(
field.type.value,
)

schema = datasets.Features(schema_dict).arrow_schema
dataframe = dataframe[write_columns]
Expand Down

0 comments on commit 3e5f45c

Please sign in to comment.