enhance(client): enhance dataset build from huggingface for subsets (#…
tianweidut authored Aug 10, 2023
1 parent 5a0837e commit 524f82f
Showing 7 changed files with 143 additions and 62 deletions.
15 changes: 10 additions & 5 deletions client/starwhale/api/_impl/dataset/model.py
@@ -1188,21 +1188,22 @@ def from_huggingface(
cls,
name: str,
repo: str,
subset: str | None = None,
subsets: t.List[str] | None = None,
split: str | None = None,
revision: str = "main",
alignment_size: int | str = D_ALIGNMENT_SIZE,
volume_size: int | str = D_FILE_VOLUME_SIZE,
mode: DatasetChangeMode | str = DatasetChangeMode.PATCH,
cache: bool = True,
tags: t.List[str] | None = None,
add_info: bool = True,
) -> Dataset:
"""Create a new dataset from huggingface datasets.
Arguments:
name: (str, required) The dataset name you would like to use.
repo: (str, required) The huggingface datasets repo name.
subset: (str, optional) The subset name. If the huggingface dataset has multiple subsets, you must specify the subset name.
subsets: (list(str), optional) The list of subset names. If no subset names are specified, all subsets will be built.
split: (str, optional) The split name. If the split name is not specified, all splits will be built.
revision: (str, optional) The huggingface datasets revision. The default value is `main`. The option value accepts tag name, or branch name, or commit hash.
alignment_size: (int|str, optional) The blob alignment size. The default value is 128.
@@ -1211,6 +1212,10 @@ def from_huggingface(
mode: (str|DatasetChangeMode, optional) The dataset change mode. The default value is `patch`. Mode choices are `patch` and `overwrite`.
cache: (bool, optional) Whether to use huggingface dataset cache(download + local hf dataset). The default value is True.
tags: (list(str), optional) The tags for the dataset version.
add_info: (bool, optional) Whether to add huggingface dataset info to the dataset rows;
currently the subset and split are added, stored in the _hf_subset and _hf_split fields respectively.
The default value is True.
Returns:
A Dataset Object
@@ -1224,21 +1229,21 @@ def from_huggingface(
```python
from starwhale import Dataset
myds = Dataset.from_huggingface("mmlu", "cais/mmlu", subset="anatomy", split="auxiliary_train", revision="7456cfb")
myds = Dataset.from_huggingface("mmlu", "cais/mmlu", subsets=["anatomy"], split="auxiliary_train", revision="7456cfb")
```
"""
from starwhale.integrations.huggingface import iter_dataset

StandaloneTag.check_tags_validation(tags)

# TODO: support auto build all subset datasets
# TODO: support huggingface dataset info
data_items = iter_dataset(
repo=repo,
subset=subset,
subsets=subsets,
split=split,
revision=revision,
cache=cache,
add_info=add_info,
)

with cls.dataset(name) as ds:
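Below is a rough usage sketch (not part of the commit) of the new `subsets` and `add_info` parameters. The subset names, the row-key layout (`<subset>/<split>/<index>`), and the `ds[key].features` row access are assumptions inferred from this diff rather than documented behavior.

```python
from starwhale import Dataset

# Build several "mmlu" subsets in one pass; omitting `subsets` lets the builder
# enumerate every config name in the repo. With add_info=True (the default),
# each row records its origin in the _hf_subset and _hf_split fields.
myds = Dataset.from_huggingface(
    "mmlu",
    "cais/mmlu",
    subsets=["anatomy", "astronomy"],
    split="auxiliary_train",
    revision="7456cfb",
)

row = myds["anatomy/auxiliary_train/0"]  # assumed key layout: <subset>/<split>/<index>
print(row.features["_hf_subset"], row.features["_hf_split"])
```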
25 changes: 19 additions & 6 deletions client/starwhale/core/dataset/cli.py
@@ -154,10 +154,12 @@ def dataset_cmd(ctx: click.Context) -> None:
)
@optgroup.group("\n ** Huggingface Build Source Configurations")
@optgroup.option( # type: ignore[no-untyped-call]
"hf_subset",
"hf_subsets",
"--subset",
multiple=True,
help=(
"Huggingface dataset subset name. If the huggingface dataset has multiple subsets, you must specify the subset name."
"Huggingface dataset subset name. If the subset name is not specified, the all subsets will be built."
"The option can be used multiple times."
),
)
@optgroup.option( # type: ignore[no-untyped-call]
@@ -176,13 +178,22 @@ def dataset_cmd(ctx: click.Context) -> None:
"Version of the dataset script to load. Defaults to 'main'. The option value accepts tag name, or branch name, or commit hash."
),
)
@optgroup.option( # type: ignore[no-untyped-call]
"hf_info",
"--add-hf-info/--no-add-hf-info",
is_flag=True,
default=True,
show_default=True,
help="Whether to add huggingface dataset info to the dataset rows, currently support to add subset and split into the dataset rows."
"subset uses _hf_subset field name, split uses _hf_split field name.",
)
@optgroup.option( # type: ignore[no-untyped-call]
"hf_cache",
"--cache/--no-cache",
is_flag=True,
default=True,
show_default=True,
help=("Whether to use huggingface dataset cache(download + local hf dataset)."),
help="Whether to use huggingface dataset cache(download + local hf dataset).",
)
@click.pass_obj
def _build(
@@ -204,10 +215,11 @@ def _build(
field_selector: str,
mode: str,
hf_repo: str,
hf_subset: str,
hf_subsets: t.List[str],
hf_split: str,
hf_revision: str,
hf_cache: bool,
hf_info: bool,
tags: t.List[str],
) -> None:
"""Build Starwhale Dataset.
@@ -320,19 +332,20 @@ def _build(
config.do_validate()
view.build(_workdir, config, mode=mode_type, tags=tags)
elif hf_repo:
_candidate_name = (f"{hf_repo}/{hf_subset or ''}").strip("/").replace("/", "-")
_candidate_name = (f"{hf_repo}").strip("/").replace("/", "-")
view.build_from_huggingface(
hf_repo,
name=name or _candidate_name,
project_uri=project,
volume_size=volume_size,
alignment_size=alignment_size,
subset=hf_subset,
subsets=hf_subsets,
split=hf_split,
revision=hf_revision,
mode=mode_type,
cache=hf_cache,
tags=tags,
add_info=hf_info,
)
else:
yaml_path = Path(dataset_yaml)
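For completeness, a hedged command-line sketch of the new options. Only `--subset`, `--add-hf-info/--no-add-hf-info`, and `--cache/--no-cache` appear in this hunk; the other flag spellings (`--huggingface`, `--name`, `--split`, `--revision`) are assumptions.

```console
# Assumed invocation: --subset may be repeated, omitting it builds all subsets,
# and --add-hf-info (on by default) records _hf_subset/_hf_split on every row.
swcli dataset build --huggingface cais/mmlu --name mmlu \
    --subset anatomy --subset astronomy \
    --split auxiliary_train --revision 7456cfb \
    --add-hf-info --no-cache
```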
9 changes: 6 additions & 3 deletions client/starwhale/core/dataset/model.py
@@ -101,14 +101,15 @@ def build_from_json_file(
def build_from_huggingface(
self,
repo: str,
subset: str | None = None,
subsets: t.List[str] | None = None,
split: str | None = None,
revision: str = "main",
alignment_size: int | str = D_ALIGNMENT_SIZE,
volume_size: int | str = D_FILE_VOLUME_SIZE,
mode: DatasetChangeMode = DatasetChangeMode.PATCH,
cache: bool = True,
tags: t.List[str] | None = None,
add_info: bool = True,
) -> None:
raise NotImplementedError

@@ -271,28 +272,30 @@ def list(
def build_from_huggingface(
self,
repo: str,
subset: str | None = None,
subsets: t.List[str] | None = None,
split: str | None = None,
revision: str = "main",
alignment_size: int | str = D_ALIGNMENT_SIZE,
volume_size: int | str = D_FILE_VOLUME_SIZE,
mode: DatasetChangeMode = DatasetChangeMode.PATCH,
cache: bool = True,
tags: t.List[str] | None = None,
add_info: bool = True,
) -> None:
from starwhale.api._impl.dataset.model import Dataset as SDKDataset

ds = SDKDataset.from_huggingface(
name=self.name,
repo=repo,
subset=subset,
subsets=subsets,
split=split,
revision=revision,
alignment_size=alignment_size,
volume_size=volume_size,
mode=mode,
cache=cache,
tags=tags,
add_info=add_info,
)
console.print(
f":hibiscus: congratulation! dataset build from https://huggingface.co/datasets/{repo} has been built. You can run "
6 changes: 4 additions & 2 deletions client/starwhale/core/dataset/view.py
@@ -167,12 +167,13 @@ def build_from_huggingface(
project_uri: str,
alignment_size: int | str,
volume_size: int | str,
subset: str | None = None,
subsets: t.List[str] | None = None,
split: str | None = None,
revision: str = "main",
mode: DatasetChangeMode = DatasetChangeMode.PATCH,
cache: bool = True,
tags: t.List[str] | None = None,
add_info: bool = True,
) -> None:
dataset_uri = cls.prepare_build_bundle(
project=project_uri,
@@ -183,14 +184,15 @@ def build_from_huggingface(
ds = Dataset.get_dataset(dataset_uri)
ds.build_from_huggingface(
repo=repo,
subset=subset,
subsets=subsets,
split=split,
revision=revision,
alignment_size=alignment_size,
volume_size=volume_size,
mode=mode,
cache=cache,
tags=tags,
add_info=add_info,
)

@classmethod
73 changes: 51 additions & 22 deletions client/starwhale/integrations/huggingface/dataset.py
@@ -77,7 +77,9 @@ def _transform_to_starwhale(data: t.Any, feature: t.Any) -> t.Any:
return data


def _iter_dataset(ds: hf_datasets.Dataset) -> t.Iterator[t.Tuple[int, t.Dict]]:
def _iter_dataset(
ds: hf_datasets.Dataset, subset: str, split: str | None, add_info: bool = True
) -> t.Iterator[t.Tuple[int, t.Dict]]:
for i in range(len(ds)):
item = {}
for k, v in ds[i].items():
@@ -86,41 +88,68 @@ def _iter_dataset(ds: hf_datasets.Dataset) -> t.Iterator[t.Tuple[int, t.Dict]]:
# TODO: support inner ClassLabel
if isinstance(feature, hf_datasets.ClassLabel):
item[f"{k}__classlabel__"] = feature.names[v]
if add_info:
if "_hf_subset" in item or "_hf_split" in item:
raise RuntimeError(
f"Dataset {subset} has already contains _hf_subset or _hf_split field, {item.keys()}"
)
item["_hf_subset"] = subset
item["_hf_split"] = split

yield i, item


def iter_dataset(
repo: str,
subset: str | None = None,
subsets: t.List[str] | None = None,
split: str | None = None,
revision: str = "main",
cache: bool = True,
add_info: bool = True,
) -> t.Iterator[t.Tuple[str, t.Dict]]:
download_mode = (
hf_datasets.DownloadMode.REUSE_DATASET_IF_EXISTS
if cache
else hf_datasets.DownloadMode.FORCE_REDOWNLOAD
)

ds = hf_datasets.load_dataset(
repo,
subset,
split=split,
revision=revision,
download_config = hf_datasets.DownloadConfig(
max_retries=10,
num_proc=min(8, os.cpu_count() or 8),
download_mode=download_mode,
)

if isinstance(ds, hf_datasets.DatasetDict):
for _split, _ds in ds.items():
for _key, _data in _iter_dataset(_ds):
yield f"{_split}/{_key}", _data
elif isinstance(ds, hf_datasets.Dataset):
for _key, _data in _iter_dataset(ds):
if split:
_s_key = f"{split}/{_key}"
else:
_s_key = str(_key)
yield _s_key, _data
else:
raise RuntimeError(f"Unknown dataset type: {type(ds)}")
if not subsets:
subsets = hf_datasets.get_dataset_config_names(
repo,
revision=revision,
download_mode=download_mode,
download_config=download_config,
)

if not subsets:
raise RuntimeError(f"Dataset {repo} has no any valid config names")

for subset in subsets:
ds = hf_datasets.load_dataset(
repo,
subset,
split=split,
revision=revision,
download_mode=download_mode,
download_config=download_config,
)

if isinstance(ds, hf_datasets.DatasetDict):
for _ds_split, _ds in ds.items():
for _key, _data in _iter_dataset(
_ds, subset, _ds_split, add_info=add_info
):
yield f"{subset}/{_ds_split}/{_key}", _data
elif isinstance(ds, hf_datasets.Dataset):
for _key, _data in _iter_dataset(ds, subset, split, add_info=add_info):
if split:
_s_key = f"{subset}/{split}/{_key}"
else:
_s_key = f"{subset}/{_key}"
yield _s_key, _data
else:
raise RuntimeError(f"Unknown dataset type: {type(ds)}")
4 changes: 2 additions & 2 deletions client/tests/core/test_dataset.py
@@ -189,7 +189,7 @@ def test_build_from_huggingface(self, m_hf: MagicMock) -> None:
assert call_args
assert call_args[0][0] == "mnist"
assert call_args[1]["name"] == "huggingface-test"
assert call_args[1]["subset"] is None
assert len(call_args[1]["subsets"]) == 0
assert not call_args[1]["cache"]

DatasetTermView.build_from_huggingface(
@@ -198,7 +198,7 @@ def test_build_from_huggingface(self, m_hf: MagicMock) -> None:
project_uri="self",
alignment_size="128",
volume_size="128M",
subset="sub1",
subsets=["sub1"],
split="train",
revision="main",
)