Address PR (pytorch#952) comments
SvenDS9 committed Jan 23, 2023
1 parent 573f948 commit 135f539
Showing 2 changed files with 26 additions and 12 deletions.
test/test_huggingface_datasets.py (7 changes: 6 additions & 1 deletion)
@@ -34,12 +34,17 @@ def test_huggingface_hubreader(self, mock_load_dataset):

datapipe = HuggingFaceHubReader("lhoestq/demo1", revision="branch", streaming=False, use_auth_token=True)

elem = next(iter(datapipe))
iterator = iter(datapipe)
elem = next(iterator)
assert type(elem) is dict
assert elem["id"] == "7bd227d9-afc9-11e6-aba1-c4b301cdf627"
assert elem["package_name"] == "com.mantz_it.rfanalyzer"
mock_load_dataset.assert_called_with(
path="lhoestq/demo1", streaming=False, split="train", revision="branch", use_auth_token=True
)
with self.assertRaises(StopIteration):
next(iterator)
next(iterator)


if __name__ == "__main__":
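For context on the first change: storing the iterator (instead of calling next(iter(datapipe))) lets the test assert that a finite, non-streaming datapipe is exhausted once its elements have been consumed. A minimal standalone sketch of the same pattern, using IterableWrapper as a stand-in for the mocked reader (the data here is illustrative only):

from torchdata.datapipes.iter import IterableWrapper

# Stand-in for a finite, non-streaming datapipe such as the mocked HuggingFaceHubReader above.
datapipe = IterableWrapper([{"id": "a"}, {"id": "b"}])

iterator = iter(datapipe)
elem = next(iterator)    # consume the first element, as the test does
rest = list(iterator)    # drain whatever remains

try:
    next(iterator)       # an exhausted datapipe raises StopIteration
except StopIteration:
    print("datapipe exhausted")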
torchdata/datapipes/iter/load/huggingface.py (31 changes: 20 additions & 11 deletions)
@@ -3,9 +3,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Any, Iterator, Tuple
import warnings
from typing import Any, Iterator, Tuple, Union

from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper
@@ -17,20 +16,27 @@


def _get_response_from_huggingface_hub(
path: str, streaming=True, split="train", revision="main", **config_kwargs
dataset: str,
split: Union[str, datasets.Split] = "train",
revision: Union[str, datasets.Version] = "main",
streaming: bool = True,
**config_kwargs,
) -> Iterator[Any]:
hf_dataset = datasets.load_dataset(path=path, streaming=streaming, split=split, revision=revision, **config_kwargs)
hf_dataset = datasets.load_dataset(
path=dataset, streaming=streaming, split=split, revision=revision, **config_kwargs
)
return iter(hf_dataset)


class HuggingFaceHubReaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
r"""
Takes in dataset names and returns an Iterable HuggingFace dataset.
Args format and meaning are the same as https://huggingface.co/docs/datasets/loading.
Contrary to their implementation, default behavior differs in the following:
split is set to "train".
revision is set to "main".
streaming is set to True.
Contrary to their implementation, default behavior differs in the following (this will be changed in version 0.7):
split is set to "train".
revision is set to "main".
streaming is set to True.
Args:
source_datapipe: a DataPipe that contains dataset names which will be accepted by the HuggingFace datasets library
Example:
@@ -67,11 +73,14 @@ def __init__(
"to install the package"
)

self.path = dataset
self.dataset = dataset
self.config_kwargs = config_kwargs
warnings.warn(
"default behavior of HuggingFaceHubReader will change in version 0.7", DeprecationWarning, stacklevel=2
)

def __iter__(self) -> Iterator[Any]:
return _get_response_from_huggingface_hub(path=self.path, **self.config_kwargs)
return _get_response_from_huggingface_hub(dataset=self.dataset, **self.config_kwargs)

def __len__(self) -> int:
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
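Taken together with the docstring above, a hedged usage sketch of the reader after this commit (the dataset name, revision, and field names are borrowed from the test file; running it requires the datasets package and network access, and the defaults split="train", revision="main", streaming=True apply unless overridden):

import warnings

from torchdata.datapipes.iter import HuggingFaceHubReader

# The constructor now warns that its default behavior will change in torchdata 0.7,
# so the warning is captured here instead of being printed to stderr.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    datapipe = HuggingFaceHubReader("lhoestq/demo1", revision="branch", streaming=False)

for w in caught:
    print(f"{w.category.__name__}: {w.message}")

# Iterate the hub dataset; field names follow the test above.
for row in datapipe:
    print(row["id"], row["package_name"])
    break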
