Migrate to simple_parsing #5572

Open: wants to merge 1 commit into master
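This PR migrates the `tfds build` CLI from hand-written argparse wiring to simple_parsing dataclasses. For orientation, here is a minimal standalone sketch of the pattern the diff below adopts (names are illustrative, not code from this PR): simple_parsing generates one flag per dataclass field and, as the dataclasses in this diff rely on, derives per-flag help text from the `Attributes:` section of the class docstring.

```python
# Minimal sketch of the simple_parsing style adopted in this PR.
# Names are illustrative; only the pattern mirrors the diff below.
import dataclasses

import simple_parsing


@dataclasses.dataclass(frozen=True, kw_only=True)
class BuildArgs:
  """Build a dataset.

  Attributes:
    datasets: Name(s) of the dataset(s) to build.
    overwrite: Delete any pre-existing data before building.
  """

  datasets: list[str] = simple_parsing.field(
      positional=True, default_factory=list, nargs='*'
  )
  overwrite: bool = False


if __name__ == '__main__':
  # Parses e.g. `prog mnist cifar10 --overwrite` into a typed dataclass.
  args = simple_parsing.parse(BuildArgs)
  print(args.datasets, args.overwrite)
```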
171 changes: 86 additions & 85 deletions tensorflow_datasets/scripts/cli/build.py
@@ -15,7 +15,7 @@
 
 """`tfds build` command."""
 
-import argparse
+import dataclasses
 import functools
 import importlib
 import itertools
@@ -25,84 +25,84 @@
 from typing import Any, Dict, Iterator, Optional, Tuple, Type, Union
 
 from absl import logging
+import simple_parsing
 import tensorflow_datasets as tfds
 from tensorflow_datasets.scripts.cli import cli_utils
 
 # pylint: disable=logging-fstring-interpolation
 
 
-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command."""
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
-      nargs='*',
-      help=(
-          'Name(s) of the dataset(s) to build. Default to current dir. '
-          'See https://www.tensorflow.org/datasets/cli for accepted values.'
-      ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
-  )
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class _AutomationGroup:
+  """Used by automated scripts.
 
-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
+  Attributes:
+    exclude_datasets: If set, generate all datasets except the one defined here.
+      Comma separated list of datasets to exclude.
+    experimental_latest_version: Build the latest Version(experiments=...)
+      available rather than default version.
+  """
 
-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
-  )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  exclude_datasets: list[str] = cli_utils.comma_separated_list_field()
+  experimental_latest_version: bool = False
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class CmdArgs:
+  """Commands for downloading and preparing datasets.
+
+  Attributes:
+    datasets: Name(s) of the dataset(s) to build. Default to current dir. See
+      https://www.tensorflow.org/datasets/cli for accepted values.
+    datasets_keyword: Datasets can also be provided as keyword argument.
+    debug: Debug & tests options.
+    path: Paths options.
+    generation: Generation options.
+    publish: Publishing options.
+    automation: Automation options.
+  """
+
+  datasets: list[str] = simple_parsing.field(
+      positional=True, default_factory=list, nargs='*'
+  )
+  datasets_keyword: list[str] = simple_parsing.field(
+      alias='datasets', default_factory=list, nargs='*'
   )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  debug: cli_utils.DebugGroup = simple_parsing.field(prefix='')
+  path: cli_utils.PathGroup = simple_parsing.field(prefix='')
+  generation: cli_utils.GenerationGroup = simple_parsing.field(prefix='')
+  publish: cli_utils.PublishGroup = simple_parsing.field(prefix='')
+  automation: _AutomationGroup = simple_parsing.field(prefix='')
 
-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self):
+    _build_datasets(self)
 
 
-def _build_datasets(args: argparse.Namespace) -> None:
+def _build_datasets(args: CmdArgs) -> None:
   """Build the given datasets."""
   # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+  if args.generation.imports:
+    list(importlib.import_module(m) for m in args.generation.imports)
 
   # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
+  datasets = args.datasets + args.datasets_keyword
+  if (
+      args.automation.exclude_datasets
+  ):  # Generate all datasets if `--exclude_datasets` set
     if datasets:
       raise ValueError("--exclude_datasets can't be used with `datasets`")
     datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
+        args.automation.exclude_datasets
    )
     datasets = sorted(datasets)  # `set` is not deterministic
   else:
     datasets = datasets or ['']  # Empty string for default
 
   # Import builder classes
   builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
+      _get_builder_cls_and_kwargs(
+          dataset, has_imports=bool(args.generation.imports)
+      )
       for dataset in datasets
   ]
 
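The `debug`, `path`, `generation`, and `publish` groups reuse dataclasses from `cli_utils`, with `prefix=''` presumably keeping their flags unprefixed (`--data_dir` rather than `--path.data_dir`) so the existing CLI surface is unchanged. `cli_utils.comma_separated_list_field` is not shown in this diff; below is a hypothetical sketch of what such a helper could look like, assuming simple_parsing's documented `field()` wrapper forwards argparse keyword arguments like `type`.

```python
# Hypothetical sketch only: the real cli_utils.comma_separated_list_field
# is not part of this diff. Assumes simple_parsing.field() forwards
# argparse kwargs such as `type` to add_argument.
import dataclasses

import simple_parsing


def comma_separated_list_field(**kwargs):
  """A list[str] field parsed from one comma-separated CLI value."""
  return simple_parsing.field(
      default_factory=list,
      type=lambda value: value.split(',') if value else [],
      **kwargs,
  )


@dataclasses.dataclass(frozen=True, kw_only=True)
class AutomationGroup:
  exclude_datasets: list[str] = comma_separated_list_field()


# `--exclude_datasets mnist,cifar10` would then parse to
# AutomationGroup(exclude_datasets=['mnist', 'cifar10']).
```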
@@ -112,19 +112,20 @@ def _build_datasets(args: argparse.Namespace) -> None:
       for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
   ))
   process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
+      _download if args.generation.download_only else _download_and_prepare,
+      args,
   )
 
-  if args.num_processes == 1:
+  if args.generation.num_processes == 1:
     for builder in builders:
       process_builder_fn(builder)
   else:
-    with multiprocessing.Pool(args.num_processes) as pool:
+    with multiprocessing.Pool(args.generation.num_processes) as pool:
       pool.map(process_builder_fn, builders)
 
 
 def _make_builders(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder_cls: Type[tfds.core.DatasetBuilder],
     builder_kwargs: Dict[str, Any],
 ) -> Iterator[tfds.core.DatasetBuilder]:
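In the hunk above, `pool.map` can only pass a single argument to its worker, so the shared `args` is bound up front with `functools.partial`; the same bound function also serves the sequential single-process path. A standalone sketch of that fan-out pattern (not TFDS code):

```python
# Standalone sketch of the partial + Pool.map pattern above; not TFDS code.
import functools
import multiprocessing


def process(shared: str, item: int) -> str:
  return f'{shared}:{item}'


if __name__ == '__main__':
  process_fn = functools.partial(process, 'config')  # bind the shared arg
  with multiprocessing.Pool(2) as pool:
    # Pool.map pickles process_fn and calls it with one item at a time.
    print(pool.map(process_fn, [1, 2, 3]))  # ['config:1', 'config:2', 'config:3']
```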
@@ -139,7 +140,7 @@ def _make_builders(
     Initialized dataset builders.
   """
   # Eventually overwrite version
-  if args.experimental_latest_version:
+  if args.automation.experimental_latest_version:
     if 'version' in builder_kwargs:
       raise ValueError(
           "Can't have both `--experimental_latest` and version set (`:1.0.0`)"
@@ -150,19 +151,19 @@ def _make_builders(
   builder_kwargs['config'] = _get_config_name(
       builder_cls=builder_cls,
       config_kwarg=builder_kwargs.get('config'),
-      config_name=args.config,
-      config_idx=args.config_idx,
+      config_name=args.generation.config,
+      config_idx=args.generation.config_idx,
   )
 
-  if args.file_format:
-    builder_kwargs['file_format'] = args.file_format
+  if args.generation.file_format:
+    builder_kwargs['file_format'] = args.generation.file_format
 
   make_builder = functools.partial(
       _make_builder,
       builder_cls,
-      overwrite=args.overwrite,
-      fail_if_exists=args.fail_if_exists,
-      data_dir=args.data_dir,
+      overwrite=args.debug.overwrite,
+      fail_if_exists=args.debug.fail_if_exists,
+      data_dir=args.path.data_dir,
       **builder_kwargs,
   )
 
@@ -301,7 +302,7 @@ def _make_builder(
 
 
 def _download(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Downloads all files of the given builder."""
@@ -323,7 +324,7 @@ def _download(
   if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
     max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS
 
-  download_dir = args.download_dir or os.path.join(
+  download_dir = args.path.download_dir or os.path.join(
       builder._data_dir_root, 'downloads'  # pylint: disable=protected-access
   )
   dl_manager = tfds.download.DownloadManager(
@@ -345,51 +346,51 @@ def _download(
 
 
 def _download_and_prepare(
-    args: argparse.Namespace,
+    args: CmdArgs,
     builder: tfds.core.DatasetBuilder,
 ) -> None:
   """Generate a single builder."""
   cli_utils.download_and_prepare(
       builder=builder,
       download_config=_make_download_config(args, dataset_name=builder.name),
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
+      download_dir=args.path.download_dir,
+      publish_dir=args.publish.publish_dir,
+      skip_if_published=args.publish.skip_if_published,
+      overwrite=args.debug.overwrite,
   )
 
 
 def _make_download_config(
-    args: argparse.Namespace,
+    args: CmdArgs,
     dataset_name: str,
 ) -> tfds.download.DownloadConfig:
   """Generate the download and prepare configuration."""
   # Load the download config
-  manual_dir = args.manual_dir
-  if args.add_name_to_manual_dir:
+  manual_dir = args.path.manual_dir
+  if args.path.add_name_to_manual_dir:
     manual_dir = manual_dir / dataset_name
 
   kwargs = {}
-  if args.max_shard_size_mb:
-    kwargs['max_shard_size'] = args.max_shard_size_mb << 20
-  if args.download_config:
-    kwargs.update(json.loads(args.download_config))
+  if args.generation.max_shard_size_mb:
+    kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
+  if args.generation.download_config:
+    kwargs.update(json.loads(args.generation.download_config))
 
   if 'download_mode' in kwargs:
     kwargs['download_mode'] = tfds.download.GenerateMode(
         kwargs['download_mode']
     )
   else:
     kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
-  if args.update_metadata_only:
+  if args.generation.update_metadata_only:
     kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO
 
   dl_config = tfds.download.DownloadConfig(
-      extract_dir=args.extract_dir,
+      extract_dir=args.path.extract_dir,
       manual_dir=manual_dir,
-      max_examples_per_split=args.max_examples_per_split,
-      register_checksums=args.register_checksums,
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
       **kwargs,
   )
 
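Two details worth noting in `_make_download_config`: the shard-size flag is given in MiB and shifted into bytes, and `--download_config` accepts a raw JSON object that is merged into the `DownloadConfig` kwargs. A small worked example (flag values illustrative):

```python
# Worked example of the conversions above; flag values are illustrative.
import json

max_shard_size_mb = 512
assert (max_shard_size_mb << 20) == 512 * 1024 * 1024  # MiB -> bytes

# e.g. --download_config '{"download_mode": "force_redownload"}'
kwargs = {}
kwargs.update(json.loads('{"download_mode": "force_redownload"}'))
assert kwargs['download_mode'] == 'force_redownload'
```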
@@ -400,9 +401,9 @@ def _make_download_config(
     beam = None
 
   if beam is not None:
-    if args.beam_pipeline_options:
+    if args.generation.beam_pipeline_options:
       dl_config.beam_options = beam.options.pipeline_options.PipelineOptions(
-          flags=[f'--{opt}' for opt in args.beam_pipeline_options.split(',')]
+          flags=[f'--{opt}' for opt in args.generation.beam_pipeline_options]
       )
 
   return dl_config
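Since `beam_pipeline_options` is now a list field rather than a comma-separated string, the `.split(',')` moves into the field itself and the comprehension only prepends `--`. For example (option names illustrative of typical Beam flags):

```python
# Illustrative: how the comprehension above turns the parsed list into
# Beam pipeline flags.
opts = ['runner=DirectRunner', 'direct_num_workers=8']
flags = [f'--{opt}' for opt in opts]
assert flags == ['--runner=DirectRunner', '--direct_num_workers=8']
```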
2 changes: 1 addition & 1 deletion tensorflow_datasets/scripts/cli/build_test.py
@@ -316,7 +316,7 @@ def test_download_only():
 )
 def test_make_download_config(args: str, download_config_kwargs):
   args = main._parse_flags(f'tfds build x {args}'.split())
-  actual = build_lib._make_download_config(args, dataset_name='x')
+  actual = build_lib._make_download_config(args.command, dataset_name='x')
   # Ignore the beam runner
   actual = actual.replace(beam_runner=None)
   expected = tfds.download.DownloadConfig(**download_config_kwargs)
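The test now passes `args.command` rather than `args`, implying `_parse_flags` returns a namespace whose `command` attribute holds the parsed `CmdArgs`. A hypothetical sketch of such wiring, assuming simple_parsing's documented `subparsers()` helper (the real `main._parse_flags` is not part of this diff):

```python
# Hypothetical sketch; the real main._parse_flags is not shown in this PR.
# Assumes simple_parsing.subparsers(), which selects a dataclass per
# subcommand name.
import dataclasses

import simple_parsing

from tensorflow_datasets.scripts.cli import build as build_lib


@dataclasses.dataclass
class Args:
  command: build_lib.CmdArgs = simple_parsing.subparsers(
      {'build': build_lib.CmdArgs}
  )


def _parse_flags(argv: list[str]) -> Args:
  parser = simple_parsing.ArgumentParser()
  parser.add_arguments(Args, dest='args')
  return parser.parse_args(argv[1:]).args  # argv[0] is the program name
```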