diff --git a/scenic/dataset_lib/dataset_utils.py b/scenic/dataset_lib/dataset_utils.py index 14d94b2c..943fca26 100644 --- a/scenic/dataset_lib/dataset_utils.py +++ b/scenic/dataset_lib/dataset_utils.py @@ -22,7 +22,7 @@ import dataclasses import functools import itertools -from typing import Any, Callable, Dict, Iterator, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Iterator, Optional, Sequence, Union from absl import logging from flax.training import common_utils @@ -35,6 +35,8 @@ PyTree = Any DatasetIterator = Union[Iterator[Any], Dict[str, Iterator[Any]]] DatasetIteratorProvider = Callable[[], DatasetIterator] +DatasetIteratorType = DatasetIterator | DatasetIteratorProvider +DatasetType = Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] @dataclasses.dataclass(frozen=True) @@ -68,14 +70,23 @@ class Dataset: classification tasks, `num_classes` is used for the configuring head of the model. """ - train_iter: DatasetIterator | DatasetIteratorProvider | None = None - valid_iter: DatasetIterator | DatasetIteratorProvider | None = None - test_iter: DatasetIterator | DatasetIteratorProvider | None = None + train_iter: DatasetIteratorType | None = None + valid_iter: DatasetIteratorType | None = None + test_iter: DatasetIteratorType | None = None meta_data: Dict[str, Any] = dataclasses.field(default_factory=dict) - train_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None - valid_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None - test_ds: Union[tf.data.Dataset, Dict[str, tf.data.Dataset]] | None = None + train_ds: DatasetType | None = None + valid_ds: DatasetType | None = None + test_ds: DatasetType | None = None + + # Multiple dataset support. + train_multi_iter: List[DatasetIteratorType] | None = None + valid_multi_iter: List[DatasetIteratorType] | None = None + test_multi_iter: List[DatasetIteratorType] | None = None + + train_multi_ds: List[DatasetType] | None = None + valid_multi_ds: List[DatasetType] | None = None + test_multi_ds: List[DatasetType] | None = None def maybe_pad_batch(batch: Dict[str, PyTree],