
Commit

Stream from Hugging Face instead of downloading and preparing everything.

PiperOrigin-RevId: 657212303
marcenacp authored and The TensorFlow Datasets Authors committed Jul 30, 2024
1 parent 2123db7 commit 290ee7e
Showing 2 changed files with 48 additions and 25 deletions.
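
For context, the change switches the conversion pipeline from reading a locally prepared Hugging Face dataset to streaming examples straight from the Hub. A minimal sketch of the two approaches against the public `datasets` API (the dataset name is only illustrative, not part of this commit):

    import datasets as hf_datasets

    # Before: download and prepare the whole dataset locally, then read it.
    builder = hf_datasets.load_dataset_builder('ag_news')  # illustrative dataset
    builder.download_and_prepare()          # materializes every split on disk
    ds = builder.as_dataset(split='train')  # random-access Dataset

    # After: stream examples on demand; nothing is prepared up front.
    streamed = hf_datasets.load_dataset('ag_news', split='train', streaming=True)
    first = next(iter(streamed))            # fetches only what is needed

Streaming avoids the disk and wall-clock cost of download_and_prepare, at the price of forward-only iteration, which the shard-writing changes below account for.
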
First changed file:

@@ -32,6 +32,7 @@
 import itertools
 import multiprocessing
 import os
+import time
 from typing import Any, Dict, Optional, Union
 
 from absl import logging
@@ -108,9 +109,24 @@ class _ShardInfo:
   num_exceptions: int
 
 
+def _load_dataset(
+    hf_builder: hf_datasets.DatasetBuilder,
+    split: str,
+) -> hf_datasets.Dataset:
+  """Efficiently loads a HuggingFace iterable dataset from its builder."""
+  if hf_builder.repo_id is None:
+    return hf_builder.as_dataset(split=split)
+  return hf_datasets.load_dataset(
+      hf_builder.repo_id or hf_builder.cache_dir,
+      hf_builder.config_id,
+      split=split,
+      streaming=True,
+  )
+
+
 def _write_shard(
     shard_spec: _ShardSpec,
-    hf_builder,
+    hf_builder: hf_datasets.DatasetBuilder,
     example_writer,
     features: feature_lib.FeaturesDict,
     ignore_hf_errors: bool,
@@ -136,12 +152,19 @@ def _write_shard(
   def get_serialized_examples_iter():
     nonlocal num_bytes
     nonlocal num_exceptions
-    dataset = hf_builder.as_dataset(
-        split=shard_spec.shard_split, run_post_process=False
+    dataset = _load_dataset(
+        hf_builder,
+        shard_spec.hf_split,
     )
-    for i in range(shard_spec.num_examples):
+    dataset = iter(dataset)
+    # Skipping the first `start_index` examples. `streaming=True` returns an
+    # iterable dataset, so we cannot jump to a specific index. This is not too
+    # costly because it takes <0.5 ms/element in the wikipedia dataset.
+    for _ in range(shard_spec.start_index):
+      next(dataset)
+    for _ in range(shard_spec.num_examples):
       try:
-        hf_value = dataset[i]
+        hf_value = next(dataset)
       except Exception:  # pylint: disable=broad-exception-caught
         num_exceptions += 1
         if ignore_hf_errors:
@@ -155,6 +178,7 @@ def get_serialized_examples_iter():
       num_bytes += len(serialized_example)
       yield serialized_example
 
+  start = time.time()
   example_writer.write(
       os.fspath(shard_spec.path),
       tqdm_utils.tqdm(
@@ -166,6 +190,11 @@ def get_serialized_examples_iter():
           mininterval=1.0,
       ),
   )
+  logging.info(
+      'Generated %s examples in %s seconds',
+      shard_spec.num_examples,
+      time.time() - start,
+  )
 
   return _ShardInfo(
       num_bytes=num_bytes,
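
Because a streamed split is forward-only, each shard writer has to advance the iterator past its own `start_index` before reading `num_examples` records, which is what the two loops above do with repeated `next()` calls. A small self-contained sketch of the same skip-then-take pattern (an illustration of the idea, not the library code):

    import itertools

    def take_shard(examples, start_index, num_examples):
      """Yields the slice [start_index, start_index + num_examples) of an iterable."""
      # islice consumes the first `start_index` items, then yields the next
      # `num_examples` items, mirroring the two loops in _write_shard.
      return itertools.islice(iter(examples), start_index, start_index + num_examples)

    # Pretend the streamed split yields the integers 0..99.
    assert list(take_shard(range(100), start_index=10, num_examples=5)) == [10, 11, 12, 13, 14]
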
@@ -247,6 +276,7 @@ def __init__(
     self._builder_config = self._converted_builder_config
     self.generation_errors = []
     self._ignore_hf_errors = ignore_hf_errors
+    login_to_hf(self._hf_hub_token)
 
   @property
   def builder_config(self) -> Optional[Any]:
@@ -257,14 +287,6 @@ def _create_builder_config(
   ) -> Optional[dataset_builder.BuilderConfig]:
     return self._converted_builder_config
 
-  @functools.lru_cache(maxsize=1)
-  def _hf_download_and_prepare(self):
-    login_to_hf(self._hf_hub_token)
-    self._hf_builder.download_and_prepare(
-        num_proc=self._hf_num_proc,
-        verification_mode=self._verification_mode,
-    )
-
   @property
   def _hf_info(self) -> hf_datasets.DatasetInfo:
     """Retrieves the dataset info from the HuggingFace Datasets."""
@@ -278,11 +300,18 @@ def _hf_hub_info(self) -> huggingface_hub.hf_api.DatasetInfo:
     )
 
   def _hf_features(self) -> hf_datasets.Features:
-    if not self._hf_info.features:
-      # We need to download and prepare the data to know its features.
-      self._hf_download_and_prepare()
-
-    return self._hf_info.features
+    # Return the features from the builder info.
+    if self._hf_info.features:
+      return self._hf_info.features
+    # Return the features from the first split.
+    for split in self._hf_info.splits:
+      ds = _load_dataset(
+          self._hf_builder,
+          split,
+      )
+      if hasattr(ds, 'info') and ds.info.features:
+        return ds.info.features
+    raise ValueError('No features found in the dataset.')
 
   def _info(self) -> dataset_info_lib.DatasetInfo:
     return dataset_info_lib.DatasetInfo(
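
The rewritten `_hf_features` first trusts the features recorded in the builder's info and, only if those are missing, opens each split as a stream and reads the features carried by the stream itself. A hedged sketch of that probing step with the public `datasets` API (dataset name is illustrative):

    import datasets as hf_datasets

    # Open a split as a stream and read its feature schema from the stream's
    # info, without downloading or preparing anything.
    ds = hf_datasets.load_dataset('ag_news', split='train', streaming=True)
    if ds.info is not None and ds.info.features is not None:
      print(dict(ds.info.features))  # e.g. {'text': Value(...), 'label': ClassLabel(...)}
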
@@ -309,7 +338,6 @@ def _generate_splits(
   ) -> Sequence[splits_lib.SplitInfo]:
     """Prepares the dataset by writing to shards directly."""
     del dl_manager, download_config  # Unused.
-    self._hf_download_and_prepare()
 
     shard_specs_by_split: dict[str, Sequence[_ShardSpec]] = {}
     for hf_split, hf_split_info in self._hf_info.splits.items():
Second changed file:

@@ -62,6 +62,7 @@ def mock_load_dataset_builder(tmp_path):
   with mock.patch.object(
       hf_datasets, 'load_dataset_builder', return_value=hf_builder
   ) as load_dataset_builder:
+    hf_builder.download_and_prepare()
     yield load_dataset_builder
 
 
@@ -133,12 +134,6 @@ def test_download_and_prepare(builder):
   assert len(ds['train_clean']) == 2
 
 
-def test_all_parameters_are_passed_down_to_hf(builder):
-  builder._hf_builder.download_and_prepare.assert_called_once_with(
-      verification_mode='no_checks', num_proc=100
-  )
-
-
 def test_hf_features(builder):
   assert builder._hf_features() == {
       'number': hf_datasets.Value('int64'),
