Add a beam writer that doesn't shuffle
PiperOrigin-RevId: 683197996
tomvdw authored and The TensorFlow Datasets Authors committed Oct 15, 2024
1 parent bc48d05 commit 0389eb5
Showing 7 changed files with 258 additions and 35 deletions.
7 changes: 6 additions & 1 deletion tensorflow_datasets/core/dataset_builder.py
@@ -762,6 +762,9 @@ def download_and_prepare(
self.info.download_size = dl_manager.downloaded_size
# Write DatasetInfo to disk, even if we haven't computed statistics.
self.info.write_to_directory(self.data_dir)
# The generated DatasetInfo contains references to `tmp_data_dir`
self.info.update_data_dir(self.data_dir)

@@ -1411,11 +1414,13 @@ def _get_filename_template(
self, split_name: str
) -> naming.ShardedFileTemplate:
"""Returns a filename template for the given split."""
if self.info.file_format is None:
raise ValueError("File format is not set!")
return naming.ShardedFileTemplate(
split=split_name,
dataset_name=self.name,
data_dir=self.data_path,
filetype_suffix=self.info.file_format.file_suffix, # pytype: disable=attribute-error
filetype_suffix=self.info.file_format.file_suffix,
)


48 changes: 31 additions & 17 deletions tensorflow_datasets/core/dataset_builder_beam_test.py
@@ -39,14 +39,16 @@ class DummyBeamDataset(dataset_builder.GeneratorBasedBuilder):
'valid_725': 725,
}

FEATURE_DICT = features.FeaturesDict({
'image': features.Image(shape=(16, 16, 1)),
'label': features.ClassLabel(names=['dog', 'cat']),
'id': tf.int32,
})

def _info(self):
return dataset_info.DatasetInfo(
builder=self,
features=features.FeaturesDict({
'image': features.Image(shape=(16, 16, 1)),
'label': features.ClassLabel(names=['dog', 'cat']),
'id': tf.int32,
}),
features=self.FEATURE_DICT,
supervised_keys=('x', 'x'),
metadata=dataset_info.BeamMetadataDict(),
)
@@ -71,6 +73,18 @@ def _generate_examples(self, num_examples):
return examples


class UnshuffledDummyBeamDataset(DummyBeamDataset):

def _info(self) -> dataset_info.DatasetInfo:
return dataset_info.DatasetInfo(
builder=self,
features=self.FEATURE_DICT,
supervised_keys=('x', 'x'),
metadata=dataset_info.BeamMetadataDict(),
disable_shuffling=True,
)


class CommonPipelineDummyBeamDataset(DummyBeamDataset):
EXPECTED_METADATA = {
'label_sum_1000': 500,
@@ -156,7 +170,12 @@ def make_default_config():


@pytest.mark.parametrize(
'dataset_cls', [DummyBeamDataset, CommonPipelineDummyBeamDataset]
'dataset_cls',
[
DummyBeamDataset,
CommonPipelineDummyBeamDataset,
UnshuffledDummyBeamDataset,
],
)
@pytest.mark.parametrize(
'make_dl_config',
@@ -178,17 +197,12 @@ def test_beam_datasets(
assert data_path.exists() # Dataset has been generated

# Check number of shards/generated files
_test_shards(
data_path,
pattern='%s-test.tfrecord-{:05}-of-{:05}' % dataset_name,
# Liquid sharding is not guaranteed to always use the same number.
num_shards=builder.info.splits['test'].num_shards,
)
_test_shards(
data_path,
pattern='%s-train.tfrecord-{:05}-of-{:05}' % dataset_name,
num_shards=1,
)
for split in ['test', 'train']:
_test_shards(
data_path,
pattern='%s-%s.tfrecord-{:05}-of-{:05}' % (dataset_name, split),
num_shards=builder.info.splits[split].num_shards,
)

ds = dataset_utils.as_numpy(builder.as_dataset())

50 changes: 50 additions & 0 deletions tensorflow_datasets/core/file_adapters.py
@@ -26,14 +26,18 @@
from typing import Any, ClassVar, Type, TypeVar

from etils import epy
from tensorflow_datasets.core.utils.lazy_imports_utils import apache_beam as beam
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_io
from tensorflow_datasets.core.utils.lazy_imports_utils import array_record_module
from tensorflow_datasets.core.utils.lazy_imports_utils import parquet as pq
from tensorflow_datasets.core.utils.lazy_imports_utils import pyarrow as pa
from tensorflow_datasets.core.utils.lazy_imports_utils import tensorflow as tf
from tensorflow_datasets.core.utils.lazy_imports_utils import tfrecordio

with epy.lazy_imports():
# pylint: disable=g-import-not-at-top
from etils import epath
from tensorflow_datasets.core import naming
from tensorflow_datasets.core.utils import file_utils
from tensorflow_datasets.core.utils import type_utils

@@ -167,6 +171,23 @@ def deserialize(cls, raw_example: bytes) -> Any:
"""
return tf.train.Example.FromString(raw_example)

@classmethod
def beam_sink(
cls,
filename_template: naming.ShardedFileTemplate,
num_shards: int | None = None,
) -> beam.PTransform:
"""Returns a Beam sink for writing examples in the given file format."""
raise NotImplementedError()

@classmethod
def num_examples(cls, filename: epath.PathLike) -> int:
"""Returns the number of examples in the given file."""
n = 0
for _ in cls.make_tf_data(filename):
n += 1
return n


class TfRecordFileAdapter(FileAdapter):
"""File adapter for TFRecord file format."""
@@ -205,6 +226,20 @@ def write_examples(
writer.write(serialized_example)
writer.flush()

@classmethod
def beam_sink(
cls,
filename_template: naming.ShardedFileTemplate,
num_shards: int | None = None,
) -> beam.PTransform:
"""Returns a Beam sink for writing examples in the given file format."""
file_path_prefix = filename_template.sharded_filepaths_pattern(
num_shards=num_shards, use_at_notation=True
).removesuffix('@*')
return tfrecordio.WriteToTFRecord(
file_path_prefix=file_path_prefix, num_shards=num_shards
)
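
As a rough illustration of how the new sink could be wired into a pipeline (my own sketch, not part of this commit; the output directory and dataset name are hypothetical), the TFRecord adapter's `beam_sink` can be applied to a PCollection of already-serialized examples:

import apache_beam as beam
from etils import epath
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import naming

# Hypothetical template; only the filetype suffix needs to match the adapter.
template = naming.ShardedFileTemplate(
    data_dir=epath.Path('/tmp/my_dataset/1.0.0'),
    dataset_name='my_dataset',
    split='train',
    filetype_suffix='tfrecord',
)
adapter = file_adapters.ADAPTER_FOR_FORMAT[file_adapters.FileFormat.TFRECORD]
with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([b'serialized-example-1', b'serialized-example-2'])
      | adapter.beam_sink(filename_template=template)
  )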


class RiegeliFileAdapter(FileAdapter):
"""File adapter for Riegeli file format."""
@@ -291,6 +326,21 @@ def write_examples(
writer.write(serialized_example)
writer.close()

@classmethod
def beam_sink(
cls,
filename_template: naming.ShardedFileTemplate,
num_shards: int | None = None,
) -> beam.PTransform:
"""Returns a Beam sink for writing examples in the given file format."""
return array_record_io.WriteToArrayRecord(
filename_template.sharded_filepaths_pattern(
num_shards=num_shards, use_at_notation=True
),
num_shards=num_shards,
record_writer_options='group_size:1',
)


class ParquetFileAdapter(FileAdapter):
"""File adapter for the [Parquet](https://parquet.apache.org) file format.
6 changes: 5 additions & 1 deletion tensorflow_datasets/core/naming.py
@@ -630,6 +630,7 @@ def sharded_filepaths_pattern(
self,
*,
num_shards: int | None = None,
use_at_notation: bool = False,
) -> str:
"""Returns a pattern describing all the file paths captured by this template.
@@ -641,21 +642,24 @@
Args:
num_shards: optional specification of the number of shards.
use_at_notation: whether to return @* in case `num_shards` is `None`.
Returns:
the pattern describing all shards captured by this template.
"""
a_filepath = self.sharded_filepath(shard_index=0, num_shards=1)
if num_shards:
replacement = f'@{num_shards}'
elif use_at_notation:
replacement = '@*'
else:
replacement = '*'
return _replace_shard_pattern(os.fspath(a_filepath), replacement)

def sharded_filenames(self, num_shards: int) -> list[str]:
return [path.name for path in self.sharded_filepaths(num_shards=num_shards)]

def replace(self, **kwargs: Any) -> 'ShardedFileTemplate':
def replace(self, **kwargs: Any) -> ShardedFileTemplate:
"""Returns a copy of the `ShardedFileTemplate` with updated attributes."""
return dataclasses.replace(self, **kwargs)

35 changes: 24 additions & 11 deletions tensorflow_datasets/core/split_builder.py
@@ -35,6 +35,7 @@
# pylint: disable=g-import-not-at-top
from tensorflow_datasets.core import example_serializer
from tensorflow_datasets.core import features as features_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import splits as splits_lib
from tensorflow_datasets.core import utils
@@ -530,17 +531,29 @@ def _build_from_pcollection(
) -> _SplitInfoFuture:
"""Split generator for `beam.PCollection`."""
# TODO(tfds): Should try to add support to `max_examples_per_split`
beam_writer = writer_lib.BeamWriter(
serializer=example_serializer.ExampleSerializer(
self._features.get_serialized_info()
),
filename_template=filename_template,
hash_salt=split_name,
disable_shuffling=disable_shuffling,
shard_config=self._shard_config,
example_writer=self._example_writer,
ignore_duplicates=self._ignore_duplicates,
)
if disable_shuffling:
beam_writer = writer_lib.NoShuffleBeamWriter(
serializer=example_serializer.ExampleSerializer(
self._features.get_serialized_info()
),
file_format=file_adapters.FileFormat.from_value(
filename_template.filetype_suffix
),
filename_template=filename_template,
)
else:
beam_writer = writer_lib.BeamWriter(
serializer=example_serializer.ExampleSerializer(
self._features.get_serialized_info()
),
filename_template=filename_template,
hash_salt=split_name,
disable_shuffling=disable_shuffling,
shard_config=self._shard_config,
example_writer=self._example_writer,
ignore_duplicates=self._ignore_duplicates,
)

def _encode_example(key_ex, encode_fn=self._features.encode_example):
# We do not access self._features in this function to avoid pickling the
88 changes: 88 additions & 0 deletions tensorflow_datasets/core/writer.py
@@ -717,3 +717,91 @@ def finalize(self) -> tuple[list[int], int]:
split_info_path.unlink()

return self._split_info["shard_lengths"], self._split_info["total_size"]


class NoShuffleBeamWriter:
"""Shuffles / writes Examples beam collection to sharded files."""

_OUTPUT_TAG_BUCKETS_LEN_SIZE = "tag_buckets_len_size"

def __init__(
self,
serializer: example_serializer.Serializer,
filename_template: naming.ShardedFileTemplate,
file_format: file_adapters.FileFormat,
):
"""Init BeamWriter.
Note that file "{filepath_prefix}.shard_lengths.json" is also created. It
contains a list with the number of examples in each final shard. Eg:
"[10,11,10,11]".
Args:
serializer: class that can serialize examples.
filename_template: template to format sharded filenames.
file_format: the file format to use.
"""
self._original_state = dict(
serializer=serializer,
filename_template=filename_template,
file_format=file_format,
)
self._file_format = file_format
self._file_adapter = file_adapters.ADAPTER_FOR_FORMAT[self._file_format]
self._filename_template = filename_template
self._serializer = serializer

@functools.lru_cache()
def _get_counter(self, name: str, namespace: str = "BeamWriter"):
return beam.metrics.Metrics.counter(namespace, name)

def inc_counter(self, name: str, value: int = 1) -> None:
self._get_counter(name).inc(value)

def __getstate__(self):
return self._original_state

def __setstate__(self, state):
self.__init__(**state)

def _serialize_example(
self,
key_example: tuple[hashing.HashKey, Example],
) -> bytes:
"""Returns (serialized_example)."""
_, example = key_example
self.inc_counter(name="serialized_examples")
return self._serializer.serialize_example(example)

def write_from_pcollection(self, examples_pcollection):
"""Returns PTransform to write (key, example) PCollection."""
return (
examples_pcollection
| "Serialize" >> beam.Map(self._serialize_example)
| "Write"
>> self._file_adapter.beam_sink(
filename_template=self._filename_template
)
)

def finalize(self) -> tuple[list[int], int]:
"""Returns the computed shard_lengths and total_size.
Returns:
List of length <number of shards> containing the number of examples stored
in each shard, and size of the files (in bytes).
"""
# We don't know the number of shards, the length of each shard, nor the
# total size, so we compute them here.
length_per_shard = {}
total_size_bytes = 0
prefix = epath.Path(self._filename_template.filepath_prefix())
for shard in self._filename_template.data_dir.glob(f"{prefix.name}*"):
length = self._file_adapter.num_examples(shard)
length_per_shard[shard] = length
total_size_bytes += shard.stat().length
shard_lengths: list[int] = []
for _, length in sorted(length_per_shard.items()):
shard_lengths.append(length)

return shard_lengths, total_size_bytes
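
A rough end-to-end sketch of how the new `NoShuffleBeamWriter` could be driven directly (my own example, not part of the commit; the features, dataset name and output directory are hypothetical):

import apache_beam as beam
import tensorflow as tf
from etils import epath
from tensorflow_datasets.core import example_serializer
from tensorflow_datasets.core import features as features_lib
from tensorflow_datasets.core import file_adapters
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import writer as writer_lib

features = features_lib.FeaturesDict({'id': tf.int32})  # hypothetical features
beam_writer = writer_lib.NoShuffleBeamWriter(
    serializer=example_serializer.ExampleSerializer(
        features.get_serialized_info()
    ),
    filename_template=naming.ShardedFileTemplate(
        data_dir=epath.Path('/tmp/my_dataset/1.0.0'),  # hypothetical path
        dataset_name='my_dataset',
        split='train',
        filetype_suffix='tfrecord',
    ),
    file_format=file_adapters.FileFormat.TFRECORD,
)

with beam.Pipeline() as pipeline:
  examples = pipeline | beam.Create(
      [(i, features.encode_example({'id': i})) for i in range(3)]
  )
  _ = beam_writer.write_from_pcollection(examples)

# After the pipeline has run, finalize() scans the written shards to recover
# the per-shard example counts and the total size in bytes.
shard_lengths, total_size = beam_writer.finalize()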