From 524f82fa8f0e5a12ff05349b58c77377857e8b8d Mon Sep 17 00:00:00 2001
From: tianwei
Date: Thu, 10 Aug 2023 14:00:43 +0800
Subject: [PATCH] enhance(client): enhance dataset build from huggingface for subsets (#2608)

---
 client/starwhale/api/_impl/dataset/model.py   | 15 ++--
 client/starwhale/core/dataset/cli.py          | 25 +++++--
 client/starwhale/core/dataset/model.py        |  9 ++-
 client/starwhale/core/dataset/view.py         |  6 +-
 .../integrations/huggingface/dataset.py       | 73 +++++++++++++------
 client/tests/core/test_dataset.py             |  4 +-
 client/tests/sdk/test_dataset_sdk.py          | 73 +++++++++++++------
 7 files changed, 143 insertions(+), 62 deletions(-)

diff --git a/client/starwhale/api/_impl/dataset/model.py b/client/starwhale/api/_impl/dataset/model.py
index d36a79fb46..6b9d4d5aaf 100644
--- a/client/starwhale/api/_impl/dataset/model.py
+++ b/client/starwhale/api/_impl/dataset/model.py
@@ -1188,7 +1188,7 @@ def from_huggingface(
         cls,
         name: str,
         repo: str,
-        subset: str | None = None,
+        subsets: t.List[str] | None = None,
         split: str | None = None,
         revision: str = "main",
         alignment_size: int | str = D_ALIGNMENT_SIZE,
@@ -1196,13 +1196,14 @@ def from_huggingface(
         mode: DatasetChangeMode | str = DatasetChangeMode.PATCH,
         cache: bool = True,
         tags: t.List[str] | None = None,
+        add_info: bool = True,
     ) -> Dataset:
         """Create a new dataset from huggingface datasets.
 
         Arguments:
             name: (str, required) The dataset name you would like to use.
             repo: (str, required) The huggingface datasets repo name.
-            subset: (str, optional) The subset name. If the huggingface dataset has multiple subsets, you must specify the subset name.
+            subsets: (list(str), optional) The list of subset names. If the subset names are not specified, all subsets of the dataset will be built.
             split: (str, optional) The split name. If the split name is not specified, the all splits dataset will be built.
             revision: (str, optional) The huggingface datasets revision. The default value is `main`. The option value accepts tag name, or branch name, or commit hash.
             alignment_size: (int|str, optional) The blob alignment size. The default value is 128.
@@ -1211,6 +1212,10 @@ def from_huggingface(
             mode: (str|DatasetChangeMode, optional) The dataset change mode. The default value is `patch`. Mode choices are `patch` and `overwrite`.
             cache: (bool, optional) Whether to use huggingface dataset cache(download + local hf dataset). The default value is True.
             tags: (list(str), optional) The tags for the dataset version.
+            add_info: (bool, optional) Whether to add huggingface dataset info to the dataset rows.
+                Currently the subset and split names are added to each row:
+                the subset uses the _hf_subset field name and the split uses the _hf_split field name.
+                The default value is True.
 
         Returns:
             A Dataset Object
@@ -1224,21 +1229,21 @@ def from_huggingface(
         ```python
         from starwhale import Dataset
-        myds = Dataset.from_huggingface("mmlu", "cais/mmlu", subset="anatomy", split="auxiliary_train", revision="7456cfb")
+        myds = Dataset.from_huggingface("mmlu", "cais/mmlu", subsets=["anatomy"], split="auxiliary_train", revision="7456cfb")
         ```
         """
 
         from starwhale.integrations.huggingface import iter_dataset
 
         StandaloneTag.check_tags_validation(tags)
 
-        # TODO: support auto build all subset datasets
         # TODO: support huggingface dataset info
         data_items = iter_dataset(
             repo=repo,
-            subset=subset,
+            subsets=subsets,
             split=split,
             revision=revision,
             cache=cache,
+            add_info=add_info,
         )
 
         with cls.dataset(name) as ds:
diff --git a/client/starwhale/core/dataset/cli.py b/client/starwhale/core/dataset/cli.py
index 297bf6b1d0..ebf9316346 100644
--- a/client/starwhale/core/dataset/cli.py
+++ b/client/starwhale/core/dataset/cli.py
@@ -154,10 +154,12 @@ def dataset_cmd(ctx: click.Context) -> None:
 )
 @optgroup.group("\n ** Huggingface Build Source Configurations")
 @optgroup.option(  # type: ignore[no-untyped-call]
-    "hf_subset",
+    "hf_subsets",
     "--subset",
+    multiple=True,
     help=(
-        "Huggingface dataset subset name. If the huggingface dataset has multiple subsets, you must specify the subset name."
+        "Huggingface dataset subset name. If no subset name is specified, all subsets will be built. "
+        "The option can be used multiple times."
     ),
 )
 @optgroup.option(  # type: ignore[no-untyped-call]
@@ -176,13 +178,22 @@ def dataset_cmd(ctx: click.Context) -> None:
         "Version of the dataset script to load. Defaults to 'main'. The option value accepts tag name, or branch name, or commit hash."
     ),
 )
+@optgroup.option(  # type: ignore[no-untyped-call]
+    "hf_info",
+    "--add-hf-info/--no-add-hf-info",
+    is_flag=True,
+    default=True,
+    show_default=True,
+    help="Whether to add huggingface dataset info to the dataset rows; currently the subset and split are added. "
+    "The subset uses the _hf_subset field name, the split uses the _hf_split field name.",
+)
 @optgroup.option(  # type: ignore[no-untyped-call]
     "hf_cache",
     "--cache/--no-cache",
     is_flag=True,
     default=True,
     show_default=True,
-    help=("Whether to use huggingface dataset cache(download + local hf dataset)."),
+    help="Whether to use huggingface dataset cache(download + local hf dataset).",
 )
 @click.pass_obj
 def _build(
@@ -204,10 +215,11 @@ def _build(
     field_selector: str,
     mode: str,
     hf_repo: str,
-    hf_subset: str,
+    hf_subsets: t.List[str],
     hf_split: str,
     hf_revision: str,
     hf_cache: bool,
+    hf_info: bool,
     tags: t.List[str],
 ) -> None:
     """Build Starwhale Dataset.
@@ -320,19 +332,20 @@ def _build(
         config.do_validate()
         view.build(_workdir, config, mode=mode_type, tags=tags)
     elif hf_repo:
-        _candidate_name = (f"{hf_repo}/{hf_subset or ''}").strip("/").replace("/", "-")
+        _candidate_name = (f"{hf_repo}").strip("/").replace("/", "-")
         view.build_from_huggingface(
             hf_repo,
             name=name or _candidate_name,
             project_uri=project,
             volume_size=volume_size,
             alignment_size=alignment_size,
-            subset=hf_subset,
+            subsets=hf_subsets,
             split=hf_split,
             revision=hf_revision,
             mode=mode_type,
             cache=hf_cache,
             tags=tags,
+            add_info=hf_info,
         )
     else:
         yaml_path = Path(dataset_yaml)
diff --git a/client/starwhale/core/dataset/model.py b/client/starwhale/core/dataset/model.py
index 5ef6980777..da4ca2cd95 100644
--- a/client/starwhale/core/dataset/model.py
+++ b/client/starwhale/core/dataset/model.py
@@ -101,7 +101,7 @@ def build_from_json_file(
     def build_from_huggingface(
         self,
         repo: str,
-        subset: str | None = None,
+        subsets: t.List[str] | None = None,
         split: str | None = None,
         revision: str = "main",
         alignment_size: int | str = D_ALIGNMENT_SIZE,
@@ -109,6 +109,7 @@ def build_from_huggingface(
         mode: DatasetChangeMode = DatasetChangeMode.PATCH,
         cache: bool = True,
         tags: t.List[str] | None = None,
+        add_info: bool = True,
     ) -> None:
         raise NotImplementedError
 
@@ -271,7 +272,7 @@ def list(
     def build_from_huggingface(
         self,
         repo: str,
-        subset: str | None = None,
+        subsets: t.List[str] | None = None,
         split: str | None = None,
         revision: str = "main",
         alignment_size: int | str = D_ALIGNMENT_SIZE,
@@ -279,13 +280,14 @@ def build_from_huggingface(
         mode: DatasetChangeMode = DatasetChangeMode.PATCH,
         cache: bool = True,
         tags: t.List[str] | None = None,
+        add_info: bool = True,
     ) -> None:
         from starwhale.api._impl.dataset.model import Dataset as SDKDataset
 
         ds = SDKDataset.from_huggingface(
             name=self.name,
             repo=repo,
-            subset=subset,
+            subsets=subsets,
             split=split,
             revision=revision,
             alignment_size=alignment_size,
@@ -293,6 +295,7 @@ def build_from_huggingface(
             mode=mode,
             cache=cache,
             tags=tags,
+            add_info=add_info,
         )
         console.print(
             f":hibiscus: congratulation! dataset build from https://huggingface.co/datasets/{repo} has been built. You can run "
diff --git a/client/starwhale/core/dataset/view.py b/client/starwhale/core/dataset/view.py
index 484f5e33e5..9743c1246f 100644
--- a/client/starwhale/core/dataset/view.py
+++ b/client/starwhale/core/dataset/view.py
@@ -167,12 +167,13 @@ def build_from_huggingface(
         project_uri: str,
         alignment_size: int | str,
         volume_size: int | str,
-        subset: str | None = None,
+        subsets: t.List[str] | None = None,
         split: str | None = None,
         revision: str = "main",
         mode: DatasetChangeMode = DatasetChangeMode.PATCH,
         cache: bool = True,
         tags: t.List[str] | None = None,
+        add_info: bool = True,
     ) -> None:
         dataset_uri = cls.prepare_build_bundle(
             project=project_uri,
@@ -183,7 +184,7 @@ def build_from_huggingface(
         ds = Dataset.get_dataset(dataset_uri)
         ds.build_from_huggingface(
             repo=repo,
-            subset=subset,
+            subsets=subsets,
             split=split,
             revision=revision,
             alignment_size=alignment_size,
@@ -191,6 +192,7 @@ def build_from_huggingface(
             mode=mode,
             cache=cache,
             tags=tags,
+            add_info=add_info,
         )
 
     @classmethod
diff --git a/client/starwhale/integrations/huggingface/dataset.py b/client/starwhale/integrations/huggingface/dataset.py
index a13aa9b957..bcf2c132e6 100644
--- a/client/starwhale/integrations/huggingface/dataset.py
+++ b/client/starwhale/integrations/huggingface/dataset.py
@@ -77,7 +77,9 @@ def _transform_to_starwhale(data: t.Any, feature: t.Any) -> t.Any:
     return data
 
 
-def _iter_dataset(ds: hf_datasets.Dataset) -> t.Iterator[t.Tuple[int, t.Dict]]:
+def _iter_dataset(
+    ds: hf_datasets.Dataset, subset: str, split: str | None, add_info: bool = True
+) -> t.Iterator[t.Tuple[int, t.Dict]]:
     for i in range(len(ds)):
         item = {}
         for k, v in ds[i].items():
@@ -86,41 +88,68 @@ def _iter_dataset(ds: hf_datasets.Dataset) -> t.Iterator[t.Tuple[int, t.Dict]]:
                 # TODO: support inner ClassLabel
                 if isinstance(feature, hf_datasets.ClassLabel):
                     item[f"{k}__classlabel__"] = feature.names[v]
+        if add_info:
+            if "_hf_subset" in item or "_hf_split" in item:
+                raise RuntimeError(
+                    f"Dataset {subset} already contains the _hf_subset or _hf_split field: {item.keys()}"
+                )
+            item["_hf_subset"] = subset
+            item["_hf_split"] = split
+
         yield i, item
 
 
 def iter_dataset(
     repo: str,
-    subset: str | None = None,
+    subsets: t.List[str] | None = None,
     split: str | None = None,
     revision: str = "main",
     cache: bool = True,
+    add_info: bool = True,
 ) -> t.Iterator[t.Tuple[str, t.Dict]]:
     download_mode = (
         hf_datasets.DownloadMode.REUSE_DATASET_IF_EXISTS
         if cache
        else hf_datasets.DownloadMode.FORCE_REDOWNLOAD
     )
-
-    ds = hf_datasets.load_dataset(
-        repo,
-        subset,
-        split=split,
-        revision=revision,
+    download_config = hf_datasets.DownloadConfig(
+        max_retries=10,
         num_proc=min(8, os.cpu_count() or 8),
-        download_mode=download_mode,
     )
-    if isinstance(ds, hf_datasets.DatasetDict):
-        for _split, _ds in ds.items():
-            for _key, _data in _iter_dataset(_ds):
-                yield f"{_split}/{_key}", _data
-    elif isinstance(ds, hf_datasets.Dataset):
-        for _key, _data in _iter_dataset(ds):
-            if split:
-                _s_key = f"{split}/{_key}"
-            else:
-                _s_key = str(_key)
-            yield _s_key, _data
-    else:
-        raise RuntimeError(f"Unknown dataset type: {type(ds)}")
+    if not subsets:
+        subsets = hf_datasets.get_dataset_config_names(
+            repo,
+            revision=revision,
+            download_mode=download_mode,
+            download_config=download_config,
+        )
+
+    if not subsets:
+        raise RuntimeError(f"Dataset {repo} has no valid config names")
+
+    for subset in subsets:
+        ds = hf_datasets.load_dataset(
+            repo,
+            subset,
+            split=split,
+            revision=revision,
+            download_mode=download_mode,
+            download_config=download_config,
+        )
+
+        if isinstance(ds, hf_datasets.DatasetDict):
+            for _ds_split, _ds in ds.items():
+                for _key, _data in _iter_dataset(
+                    _ds, subset, _ds_split, add_info=add_info
+                ):
+                    yield f"{subset}/{_ds_split}/{_key}", _data
+        elif isinstance(ds, hf_datasets.Dataset):
+            for _key, _data in _iter_dataset(ds, subset, split, add_info=add_info):
+                if split:
+                    _s_key = f"{subset}/{split}/{_key}"
+                else:
+                    _s_key = f"{subset}/{_key}"
+                yield _s_key, _data
+        else:
+            raise RuntimeError(f"Unknown dataset type: {type(ds)}")
diff --git a/client/tests/core/test_dataset.py b/client/tests/core/test_dataset.py
index 30cebf8882..99ed49efa5 100644
--- a/client/tests/core/test_dataset.py
+++ b/client/tests/core/test_dataset.py
@@ -189,7 +189,7 @@ def test_build_from_huggingface(self, m_hf: MagicMock) -> None:
         assert call_args
         assert call_args[0][0] == "mnist"
         assert call_args[1]["name"] == "huggingface-test"
-        assert call_args[1]["subset"] is None
+        assert len(call_args[1]["subsets"]) == 0
         assert not call_args[1]["cache"]
 
         DatasetTermView.build_from_huggingface(
@@ -198,7 +198,7 @@ def test_build_from_huggingface(self, m_hf: MagicMock) -> None:
             project_uri="self",
             alignment_size="128",
             volume_size="128M",
-            subset="sub1",
+            subsets=["sub1"],
             split="train",
             revision="main",
         )
diff --git a/client/tests/sdk/test_dataset_sdk.py b/client/tests/sdk/test_dataset_sdk.py
index 5ff13397f5..4ca120bc2e 100644
--- a/client/tests/sdk/test_dataset_sdk.py
+++ b/client/tests/sdk/test_dataset_sdk.py
@@ -2015,8 +2015,13 @@ def test_compound_data(self) -> None:
         assert transform_data["sequence_dict"]["int"] == 1
         assert transform_data["sequence_dict"]["list_int"] == [1, 1, 1]
 
+    @patch(
+        "starwhale.integrations.huggingface.dataset.hf_datasets.get_dataset_config_names"
+    )
     @patch("starwhale.integrations.huggingface.dataset.hf_datasets.load_dataset")
-    def test_build_dataset(self, m_load_dataset: MagicMock) -> None:
+    def test_build_dataset(
+        self, m_load_dataset: MagicMock, m_get_config_names: MagicMock
+    ) -> None:
         import datasets as hf_datasets
 
         complex_data = {
@@ -2091,46 +2096,70 @@ def test_build_dataset(self, m_load_dataset: MagicMock) -> None:
             }
         )
 
+        m_get_config_names.return_value = ["simple"]
         m_load_dataset.return_value = hf_simple_ds
         Dataset.from_huggingface(
-            name="simple", repo="simple", split="train", tags=["hf-0", "hf-1"]
+            name="simple",
+            repo="simple",
+            split="train",
+            tags=["hf-0", "hf-1"],
+            add_info=True,
         )
         simple_ds = dataset("simple")
 
         assert len(simple_ds) == 2
-        assert simple_ds["train/0"].features.int == 1
-        assert simple_ds["train/0"].features.float == 1.0
-        assert simple_ds["train/0"].features.str == "test1"
-        assert simple_ds["train/0"].features.bin == b"test1"
-        large_str = simple_ds["train/0"].features["large_str"]
-        large_bin = simple_ds["train/0"].features["large_bin"]
+        assert simple_ds["simple/train/0"].features.int == 1
+        assert simple_ds["simple/train/0"].features.float == 1.0
+        assert simple_ds["simple/train/0"].features.str == "test1"
+        assert simple_ds["simple/train/0"].features.bin == b"test1"
+        assert simple_ds["simple/train/0"].features["_hf_subset"] == "simple"
+        assert simple_ds["simple/train/0"].features["_hf_split"] == "train"
+        large_str = simple_ds["simple/train/0"].features["large_str"]
+        large_bin = simple_ds["simple/train/0"].features["large_bin"]
 
         assert isinstance(large_str, str)
         assert isinstance(large_bin, bytes)
         assert large_str == "test1" * 20
         assert large_bin == b"test1" * 20
 
-        assert simple_ds["train/1"].features.int == 2
simple_ds["train/1"].features["large_bin"] == b"test2" * 20 + assert simple_ds["simple/train/1"].features.int == 2 + assert simple_ds["simple/train/1"].features["large_bin"] == b"test2" * 20 + m_get_config_names.return_value = ["complex"] m_load_dataset.return_value = hf_complex_ds - complex_ds = Dataset.from_huggingface(name="complex", repo="complex") + complex_ds = Dataset.from_huggingface( + name="complex", repo="complex", add_info=False + ) assert len(complex_ds) == 1 - assert complex_ds["0"].features.list_int == [1, 2, 3] - assert complex_ds["0"].features.seq_img[0].shape == [1, 1, 3] - assert complex_ds["0"].features.seq_img[1].shape == [1, 1, 1] - assert complex_ds["0"].features.seq_dict["int"] == [1] - assert complex_ds["0"].features.seq_dict["str"] == ["test"] - assert complex_ds["0"].features.img.shape == [1, 1, 3] - assert complex_ds["0"].features.class_label == 1 - _audio = complex_ds["0"].features.audio + assert complex_ds["complex/0"].features.list_int == [1, 2, 3] + assert complex_ds["complex/0"].features.seq_img[0].shape == [1, 1, 3] + assert complex_ds["complex/0"].features.seq_img[1].shape == [1, 1, 1] + assert complex_ds["complex/0"].features.seq_dict["int"] == [1] + assert complex_ds["complex/0"].features.seq_dict["str"] == ["test"] + assert complex_ds["complex/0"].features.img.shape == [1, 1, 3] + assert complex_ds["complex/0"].features.class_label == 1 + assert "_hf_subset" not in complex_ds["complex/0"].features + assert "_hf_split" not in complex_ds["complex/0"].features + _audio = complex_ds["complex/0"].features.audio assert isinstance(_audio, Audio) assert _audio.display_name == "simple.wav" assert _audio.mime_type == MIMEType.WAV + m_get_config_names.return_value = ["mixed"] m_load_dataset.return_value = hf_mixed_ds mixed_ds = Dataset.from_huggingface(name="mixed", repo="mixed") assert len(mixed_ds) == 3 - assert mixed_ds["complex/0"].features.list_int == [1, 2, 3] - assert mixed_ds["simple/0"].features.int == 1 - assert mixed_ds["simple/1"].features.int == 2 + assert mixed_ds["mixed/complex/0"].features.list_int == [1, 2, 3] + assert mixed_ds["mixed/simple/0"].features.int == 1 + assert mixed_ds["mixed/simple/1"].features.int == 2 + + m_get_config_names.return_value = ["simple", "mixed"] + m_load_dataset.side_effect = [hf_simple_ds, hf_mixed_ds] + multi_subsets_ds = Dataset.from_huggingface(name="multi", repo="multi") + assert len(multi_subsets_ds) == 5 + assert multi_subsets_ds["simple/0"].features["_hf_subset"] == "simple" + assert "_hf_split" not in multi_subsets_ds["simple/0"].features + assert multi_subsets_ds["mixed/simple/0"].features["_hf_subset"] == "mixed" + assert multi_subsets_ds["mixed/simple/0"].features["_hf_split"] == "simple" + assert multi_subsets_ds["mixed/complex/0"].features["_hf_subset"] == "mixed" + assert multi_subsets_ds["mixed/complex/0"].features["_hf_split"] == "complex"