diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 3209c263..ad11aff2 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -12,8 +12,10 @@
 # limitations under the License.

 import random
+from copy import deepcopy
 from typing import Any, Dict, Iterator, List, Optional, Sequence

+import numpy as np
 from torch.utils.data import IterableDataset

 from litdata.streaming.dataset import StreamingDataset
@@ -36,15 +38,36 @@ class CombinedStreamingDataset(IterableDataset):
     """

     def __init__(
-        self, datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None
+        self,
+        datasets: List[StreamingDataset],
+        seed: int = 42,
+        weights: Optional[Sequence[float]] = None,
+        iterate_over_all: bool = True,
     ) -> None:
+        """
+        Arguments:
+            datasets: The list of StreamingDataset to use.
+            seed: The random seed to initialize the sampler.
+            weights: The sampling ratio for the datasets.
+            iterate_over_all: When True, the combined dataset iterates over all the datasets.
+                Otherwise, it stops as soon as one raises a StopIteration.
+        """
+
         self._check_datasets(datasets)

         self._seed = seed
         self._datasets = datasets
         self._weights = weights
+        self._iterate_over_all = iterate_over_all
+
         num_datasets = len(datasets)

+        if iterate_over_all and weights:
+            raise ValueError(
+                "When `iterate_over_all` is set to True, the weights argument shouldn't be provided."
+                " Instead, it is computed as the inverse of the number of datasets."
+            )
+
         if weights is None:
             # Inversely weighted based on length
             self._weights = [1 / float(num_datasets)] * num_datasets
@@ -56,6 +79,15 @@ def __init__(
         self._num_samples_yielded: Optional[List[int]] = None
         self._current_epoch = 0

+    def __len__(self) -> Optional[int]:
+        if self._iterate_over_all:
+            return self._get_total_length()
+        return None
+
+    # total length of the datasets
+    def _get_total_length(self) -> int:
+        return sum(len(d) for d in self._datasets)
+
     def set_epoch(self, current_epoch: int) -> None:
         """Set the current epoch to the datasets on epoch starts.
@@ -95,6 +127,7 @@ def __iter__(self) -> Iterator[Any]:
             self._weights,
             self._use_streaming_dataloader,
             num_samples_yielded,
+            self._iterate_over_all,
         )
         return self._iterator

@@ -132,14 +165,18 @@ def __init__(
         seed: int,
         weights: Sequence[float],
         use_streaming_dataloader: bool,
-        num_samples_yielded: Optional[Any] = None,
+        num_samples_yielded: Any,
+        iterate_over_all: bool = False,
     ) -> None:
         self._datasets = datasets
         self._dataset_iters = [iter(dataset) for dataset in datasets]
         self._dataset_indexes = list(range(len(datasets)))
-        self._num_samples_yielded = [0 for _ in range(len(datasets))]
-        self._weights = weights
+        self._num_samples_yielded = num_samples_yielded or [0 for _ in range(len(datasets))]
+        self._original_weights = deepcopy(weights)
+        self._weights = deepcopy(weights)
         self._rng = random.Random(seed)
+        self._iterate_over_all = iterate_over_all
+        self._is_done = False

         if num_samples_yielded is not None:
             self._num_samples_yielded = num_samples_yielded
@@ -147,16 +184,43 @@ def __init__(
             self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)

         self._use_streaming_dataloader = use_streaming_dataloader

     def __next__(self) -> Any:
+        if self._iterate_over_all:
+            while True:
+                try:
+                    if len(self._dataset_indexes) > 1:
+                        dataset_index = self._get_dataset_index()
+                    elif len(self._dataset_indexes) == 1:
+                        dataset_index = self._dataset_indexes[0]
+                    return self._get_sample(dataset_index)
+                except StopIteration as e:
+                    if len(self._dataset_indexes) == 1:
+                        self._dataset_indexes = list(range(len(self._datasets)))
+                        self._weights = deepcopy(self._original_weights)
+                        raise e
+
+                    # drop the exhausted dataset and renormalize the remaining weights
+                    position = self._dataset_indexes.index(dataset_index)
+                    self._dataset_indexes.pop(position)
+                    self._weights.pop(position)
+                    self._weights = (np.asarray(self._weights) / np.sum(self._weights)).tolist()
+
+        # stop as soon as one dataset raises a StopIteration
+        return self._get_sample(self._get_dataset_index())
+
+    def _get_dataset_index(self) -> int:
         # randomly select a dataset index
         (dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
+        return dataset_index
+
+    def _get_sample(self, dataset_index: int) -> Any:
+        # get the sample
+        sample = next(self._dataset_iters[dataset_index])

         # keep track the sample was fetched
         self._num_samples_yielded[dataset_index] += 1

-        sample = next(self._dataset_iters[dataset_index])
-
         # return a new sample
         if self._use_streaming_dataloader:
             return {
diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py
index 279a6fae..e8ad7c4b 100644
--- a/tests/streaming/test_combined.py
+++ b/tests/streaming/test_combined.py
@@ -18,19 +18,25 @@ def _check_datasets(self, datasets) -> None:


 def test_combined_dataset_num_samples_yield():
-    dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5))
+    dataset = TestCombinedStreamingDataset(
+        [range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5), iterate_over_all=False
+    )
     dataset_iter = iter(dataset)

     data = list(dataset_iter)
     assert data == [0, 0, 1, 2, -1, -2, -3, 3, 4, 5, 6, -4, 7, 8, -5, -6, 9, -7, -8]

-    dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5))
+    dataset = TestCombinedStreamingDataset(
+        [range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5), iterate_over_all=False
+    )
     dataset_iter = iter(dataset)

     data = list(dataset_iter)
     assert data == [0, 0, -1, -2, -3, -4, -5, 1, -6, 2, -7, -8, 3, 4, -9, 5]

-    dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5))
+    dataset = TestCombinedStreamingDataset(
+        [range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5), iterate_over_all=False
+    )
     dataset_iter = iter(dataset)

     data = [next(dataset_iter) for _ in range(5)]
@@ -40,6 +46,13 @@ def test_combined_dataset_num_samples_yield():
     assert dataset._iterator._num_samples_yielded == [2, 4]


+def test_combined_dataset_num_samples_yield_iterate_over_all():
+    dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, iterate_over_all=True)
+    assert len(dataset) == 20
+    samples = list(dataset)
+    assert len(samples) == 20
+
+
 class TestStatefulDataset:
     def __init__(self, size, step):
         self.size = size
@@ -69,14 +82,20 @@ def load_state_dict(self, state_dict):

 def test_combined_dataset_state_dict():
     dataset = TestCombinedStreamingDataset(
-        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
+        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
+        42,
+        weights=(0.5, 0.5),
+        iterate_over_all=False,
     )
     assert dataset.state_dict(0, 1) == {}
     dataset_iter = iter(dataset)
     assert dataset.state_dict(0, 1) == {"0": {"counter": 0}, "1": {"counter": 0}}

     dataset2 = TestCombinedStreamingDataset(
-        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
+        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
+        42,
+        weights=(0.5, 0.5),
+        iterate_over_all=False,
     )
     assert dataset2.state_dict(0, 1) == {}

@@ -111,7 +130,10 @@ def test_combined_dataset_state_dict():
     ]

     dataset2 = TestCombinedStreamingDataset(
-        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
+        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)],
+        42,
+        weights=(0.5, 0.5),
+        iterate_over_all=False,
     )
     assert dataset2.state_dict(0, 1) == {}
     dataset2_iter = iter(dataset2)
@@ -136,7 +158,7 @@ def test_combined_dataset_state_dict():
     ],
 )
 def test_combined_dataset_normalizes_weights(weights, expected):
-    combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, seed=1)
+    combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, iterate_over_all=False, seed=1)
     assert combined_dataset._weights == expected


@@ -159,21 +181,27 @@ def set_epoch(self, current_epoch):
 def test_combined_dataset():
     dataset1 = SimpleDataset(0, 10)
     dataset2 = SimpleDataset(10, 20)
-    dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345)
+    dataset = TestCombinedStreamingDataset(
+        datasets=[dataset1, dataset2], weights=[1.0, 0.0], iterate_over_all=False, seed=12345
+    )

     res = list(dataset)
     assert res == list(range(0, 10))

     dataset1 = SimpleDataset(0, 10)
     dataset2 = SimpleDataset(10, 20)
-    dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345)
+    dataset = TestCombinedStreamingDataset(
+        datasets=[dataset1, dataset2], weights=[0.0, 1.0], iterate_over_all=False, seed=12345
+    )

     res = list(dataset)
     assert res == list(range(10, 20))

     dataset1 = SimpleDataset(0, 10)
     dataset2 = SimpleDataset(10, 20)
-    dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
+    dataset = TestCombinedStreamingDataset(
+        datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
+    )

     res = list(dataset)
     assert 9 in res or 19 in res
@@ -183,7 +211,9 @@ def test_combined_dataset():

     dataset1 = SimpleDataset(0, 10)
     dataset2 = SimpleDataset(10, 20)
-    dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
+    dataset = TestCombinedStreamingDataset(
+        datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
+    )
     dataloader = DataLoader(dataset, batch_size=2, num_workers=1)
     dataloader_iter = iter(dataloader)
     assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1]))
@@ -193,7 +223,9 @@
 def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
     dataset1 = SimpleDataset(0, 10)
     dataset2 = SimpleDataset(10, 20)
-    dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
+    dataset = TestCombinedStreamingDataset(
+        datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
+    )
     dataloader = StreamingDataLoader(dataset, num_workers=1, batch_size=batch_size, prefetch_factor=1)
     dataloader_iter = iter(dataloader)
@@ -260,7 +292,9 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):

     dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True)
     dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True)
-    dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345)
+    dataset = CombinedStreamingDataset(
+        datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345
+    )
     dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2)

     assert dataset1.current_epoch == 1
@@ -454,7 +488,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
         {
             "dataset": {
                 "0": {
-                    "num_samples_yielded": 9,
+                    "num_samples_yielded": 8,
                     "num_workers": 3,
                     "batch_size": 2,
                     "current_epoch": 1,
@@ -482,12 +516,12 @@
             },
             "current_epoch": 0,
             "latest_worker_idx": 2,
-            "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
+            "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]},
         },
         {
             "dataset": {
                 "0": {
-                    "num_samples_yielded": 11,
+                    "num_samples_yielded": 9,
                     "num_workers": 3,
                     "batch_size": 2,
                     "current_epoch": 1,
@@ -515,12 +549,12 @@
             },
             "current_epoch": 0,
             "latest_worker_idx": 0,
-            "num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]},
+            "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]},
         },
         {
             "dataset": {
                 "0": {
-                    "num_samples_yielded": 13,
+                    "num_samples_yielded": 10,
                     "num_workers": 3,
                     "batch_size": 2,
                     "current_epoch": 1,
@@ -548,7 +582,7 @@
             },
             "current_epoch": 0,
             "latest_worker_idx": 1,
-            "num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]},
+            "num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]},
         },
     ]

@@ -721,7 +755,7 @@
         {
             "dataset": {
                 "0": {
-                    "num_samples_yielded": 9,
+                    "num_samples_yielded": 8,
                     "num_workers": 3,
                     "batch_size": 2,
                     "current_epoch": 2,
@@ -749,12 +783,12 @@
             },
             "current_epoch": 1,
             "latest_worker_idx": 2,
-            "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]},
+            "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]},
         },
         {
             "dataset": {
                 "0": {
-                    "num_samples_yielded": 11,
+                    "num_samples_yielded": 9,
                     "num_workers": 3,
                     "batch_size": 2,
                     "current_epoch": 2,
@@ -782,12 +816,12 @@
             },
             "current_epoch": 1,
             "latest_worker_idx": 0,
-            "num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]},
+            "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]},
         },
         {
             "dataset": {
                 "0": {
-                    "num_samples_yielded": 13,
+                    "num_samples_yielded": 10,
                     "num_workers": 3,
                     "batch_size": 2,
"current_epoch": 2, @@ -815,7 +849,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, "current_epoch": 1, "latest_worker_idx": 1, - "num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]}, + "num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]}, }, ] diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index d719bce6..d75097a6 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -49,7 +49,10 @@ def _check_datasets(self, datasets) -> None: def test_streaming_dataloader(): dataset = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], + 42, + weights=(0.5, 0.5), + iterate_over_all=False, ) dataloader = StreamingDataLoader(dataset, batch_size=2) dataloader_iter = iter(dataloader) @@ -77,7 +80,7 @@ def test_streaming_dataloader(): "dataset": {"0": {"counter": 10}, "1": {"counter": 9}}, "current_epoch": 0, "latest_worker_idx": 0, - "num_samples_yielded": {0: [11, 9]}, + "num_samples_yielded": {0: [10, 9]}, } @@ -87,7 +90,10 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch): monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True) dataset = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], + 42, + weights=(0.5, 0.5), + iterate_over_all=False, ) dataloader = StreamingDataLoader( dataset, batch_size=2, profile_batches=profile, profile_dir=str(tmpdir), num_workers=1 @@ -102,7 +108,7 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch): def test_dataloader_shuffle(): dataset = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5), iterate_over_all=False ) assert dataset._datasets[0].shuffle is None assert dataset._datasets[1].shuffle is None