From 5f7263e1838404e46e58446b172858e41c206907 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 May 2024 13:05:03 +0000 Subject: [PATCH 1/7] update --- src/litdata/streaming/combined.py | 71 ++++++++++++++++++++++++++++--- tests/streaming/test_combined.py | 31 ++++++++------ 2 files changed, 85 insertions(+), 17 deletions(-) diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index 3209c263..78fbe8b1 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -14,6 +14,7 @@ import random from typing import Any, Dict, Iterator, List, Optional, Sequence +import numpy as np from torch.utils.data import IterableDataset from litdata.streaming.dataset import StreamingDataset @@ -36,15 +37,38 @@ class CombinedStreamingDataset(IterableDataset): """ def __init__( - self, datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None + self, + datasets: List[StreamingDataset], + seed: int = 42, + weights: Optional[Sequence[float]] = None, + iterate_over_all: bool = False, ) -> None: + """ " + Arguments: + datasets: The list of the StreamingDataset to use. + seed: The random seed to initialize the sampler + weights: The sampling ratio for the datasets + iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets. + Otherwise, it stops as soon as one raises a StopIteration. + """ + self._check_datasets(datasets) self._seed = seed self._datasets = datasets self._weights = weights + self._iterate_over_all = iterate_over_all + num_datasets = len(datasets) + if iterate_over_all and weights: + raise ValueError( + "When `iterate_over_all` is set to True, the weights argument shouldn't be provided.", + " Instead, it will be computed from the inverse of the dataset length.", + ) + + self._iterate_over_all = iterate_over_all + if weights is None: # Inversely weighted based on length self._weights = [1 / float(num_datasets)] * num_datasets @@ -56,6 +80,15 @@ def __init__( self._num_samples_yielded: Optional[List[int]] = None self._current_epoch = 0 + def __len__(self) -> Optional[int]: + if self._iterate_over_all: + return self._get_total_length() + return None + + # total length of the datasets + def _get_total_length(self) -> int: + return sum(len(d) for d in self._datasets) + def set_epoch(self, current_epoch: int) -> None: """Set the current epoch to the datasets on epoch starts. @@ -95,6 +128,7 @@ def __iter__(self) -> Iterator[Any]: self._weights, self._use_streaming_dataloader, num_samples_yielded, + self._iterate_over_all, ) return self._iterator @@ -132,14 +166,17 @@ def __init__( seed: int, weights: Sequence[float], use_streaming_dataloader: bool, - num_samples_yielded: Optional[Any] = None, + num_samples_yielded: any, + iterate_over_all: bool = False, ) -> None: self._datasets = datasets self._dataset_iters = [iter(dataset) for dataset in datasets] self._dataset_indexes = list(range(len(datasets))) - self._num_samples_yielded = [0 for _ in range(len(datasets))] + self._num_samples_yielded = num_samples_yielded or [0 for _ in range(len(datasets))] self._weights = weights self._rng = random.Random(seed) + self._iterate_over_all = iterate_over_all + self._is_done = False if num_samples_yielded is not None: self._num_samples_yielded = num_samples_yielded @@ -147,16 +184,40 @@ def __init__( self._rng.choices(self._dataset_indexes, weights=self._weights, k=1) self._use_streaming_dataloader = use_streaming_dataloader + self._is_done = False def __next__(self) -> Any: + if self._iterate_over_all: + while True: + try: + if len(self._dataset_indexes) > 1: + dataset_index = self._get_dataset_index() + elif len(self._dataset_indexes) == 1: + dataset_index = self._dataset_indexes[0] + return self._get_sample(dataset_index) + except StopIteration as e: + if len(self._dataset_indexes) == 1: + raise e + + self._dataset_indexes.pop(dataset_index) + self._weights.pop(dataset_index) + self._weights /= np.sum(self._weights) + + # stop on the first iteration + return self._get_sample(self._get_dataset_index()) + + def _get_dataset_index(self) -> int: # randomly select a dataset index (dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1) + return dataset_index + + def _get_sample(self, dataset_index: int) -> Any: + # get the sample + sample = next(self._dataset_iters[dataset_index]) # keep track the sample was fetched self._num_samples_yielded[dataset_index] += 1 - sample = next(self._dataset_iters[dataset_index]) - # return a new sample if self._use_streaming_dataloader: return { diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py index 279a6fae..1f49d60e 100644 --- a/tests/streaming/test_combined.py +++ b/tests/streaming/test_combined.py @@ -40,6 +40,13 @@ def test_combined_dataset_num_samples_yield(): assert dataset._iterator._num_samples_yielded == [2, 4] +def test_combined_dataset_num_samples_yield_iterate_over_all(): + dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, iterate_over_all=True) + assert len(dataset) == 20 + dataset_iter = iter(dataset) + assert len(e for e in dataset_iter) == 20 + + class TestStatefulDataset: def __init__(self, size, step): self.size = size @@ -454,7 +461,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): { "dataset": { "0": { - "num_samples_yielded": 9, + "num_samples_yielded": 8, "num_workers": 3, "batch_size": 2, "current_epoch": 1, @@ -482,12 +489,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, "current_epoch": 0, "latest_worker_idx": 2, - "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]}, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]}, }, { "dataset": { "0": { - "num_samples_yielded": 11, + "num_samples_yielded": 9, "num_workers": 3, "batch_size": 2, "current_epoch": 1, @@ -515,12 +522,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, "current_epoch": 0, "latest_worker_idx": 0, - "num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]}, + "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]}, }, { "dataset": { "0": { - "num_samples_yielded": 13, + "num_samples_yielded": 10, "num_workers": 3, "batch_size": 2, "current_epoch": 1, @@ -548,7 +555,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, "current_epoch": 0, "latest_worker_idx": 1, - "num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]}, + "num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]}, }, ] @@ -721,7 +728,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): { "dataset": { "0": { - "num_samples_yielded": 9, + "num_samples_yielded": 8, "num_workers": 3, "batch_size": 2, "current_epoch": 2, @@ -749,12 +756,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, "current_epoch": 1, "latest_worker_idx": 2, - "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [3, 1]}, + "num_samples_yielded": {0: [3, 1], 1: [3, 1], 2: [2, 1]}, }, { "dataset": { "0": { - "num_samples_yielded": 11, + "num_samples_yielded": 9, "num_workers": 3, "batch_size": 2, "current_epoch": 2, @@ -782,12 +789,12 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, "current_epoch": 1, "latest_worker_idx": 0, - "num_samples_yielded": {0: [5, 1], 1: [3, 1], 2: [3, 1]}, + "num_samples_yielded": {0: [4, 1], 1: [3, 1], 2: [2, 1]}, }, { "dataset": { "0": { - "num_samples_yielded": 13, + "num_samples_yielded": 10, "num_workers": 3, "batch_size": 2, "current_epoch": 2, @@ -815,7 +822,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): }, "current_epoch": 1, "latest_worker_idx": 1, - "num_samples_yielded": {0: [5, 1], 1: [5, 1], 2: [3, 1]}, + "num_samples_yielded": {0: [4, 1], 1: [4, 1], 2: [2, 1]}, }, ] From e1b786757998efec4d945ea6c294888b53b32659 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 May 2024 13:48:14 +0000 Subject: [PATCH 2/7] update --- src/litdata/streaming/combined.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index 78fbe8b1..865cd1ac 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -12,6 +12,7 @@ # limitations under the License. import random +from copy import deepcopy from typing import Any, Dict, Iterator, List, Optional, Sequence import numpy as np @@ -41,7 +42,7 @@ def __init__( datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None, - iterate_over_all: bool = False, + iterate_over_all: bool = True, ) -> None: """ " Arguments: @@ -173,7 +174,8 @@ def __init__( self._dataset_iters = [iter(dataset) for dataset in datasets] self._dataset_indexes = list(range(len(datasets))) self._num_samples_yielded = num_samples_yielded or [0 for _ in range(len(datasets))] - self._weights = weights + self._original_weights = deepcopy(weights) + self._weights = deepcopy(weights) self._rng = random.Random(seed) self._iterate_over_all = iterate_over_all self._is_done = False @@ -197,6 +199,8 @@ def __next__(self) -> Any: return self._get_sample(dataset_index) except StopIteration as e: if len(self._dataset_indexes) == 1: + self._dataset_indexes = list(range(len(self._datasets))) + self._weights = deepcopy(self._original_weights) raise e self._dataset_indexes.pop(dataset_index) From 735091ca28548ec52784444921408b43ba806c0f Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 May 2024 14:00:56 +0000 Subject: [PATCH 3/7] update --- tests/streaming/test_combined.py | 57 +++++++++++++++++++++++--------- 1 file changed, 42 insertions(+), 15 deletions(-) diff --git a/tests/streaming/test_combined.py b/tests/streaming/test_combined.py index 1f49d60e..e8ad7c4b 100644 --- a/tests/streaming/test_combined.py +++ b/tests/streaming/test_combined.py @@ -18,19 +18,25 @@ def _check_datasets(self, datasets) -> None: def test_combined_dataset_num_samples_yield(): - dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5)) + dataset = TestCombinedStreamingDataset( + [range(10), range(0, -10, -1)], 42, weights=(0.5, 0.5), iterate_over_all=False + ) dataset_iter = iter(dataset) data = list(dataset_iter) assert data == [0, 0, 1, 2, -1, -2, -3, 3, 4, 5, 6, -4, 7, 8, -5, -6, 9, -7, -8] - dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5)) + dataset = TestCombinedStreamingDataset( + [range(10), range(0, -10, -1)], 37, weights=(0.5, 0.5), iterate_over_all=False + ) dataset_iter = iter(dataset) data = list(dataset_iter) assert data == [0, 0, -1, -2, -3, -4, -5, 1, -6, 2, -7, -8, 3, 4, -9, 5] - dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5)) + dataset = TestCombinedStreamingDataset( + [range(10), range(0, -10, -1)], 23, weights=(0.5, 0.5), iterate_over_all=False + ) dataset_iter = iter(dataset) data = [next(dataset_iter) for _ in range(5)] @@ -43,8 +49,8 @@ def test_combined_dataset_num_samples_yield(): def test_combined_dataset_num_samples_yield_iterate_over_all(): dataset = TestCombinedStreamingDataset([range(10), range(0, -10, -1)], 42, iterate_over_all=True) assert len(dataset) == 20 - dataset_iter = iter(dataset) - assert len(e for e in dataset_iter) == 20 + samples = list(dataset) + assert len(samples) == 20 class TestStatefulDataset: @@ -76,14 +82,20 @@ def load_state_dict(self, state_dict): def test_combined_dataset_state_dict(): dataset = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], + 42, + weights=(0.5, 0.5), + iterate_over_all=False, ) assert dataset.state_dict(0, 1) == {} dataset_iter = iter(dataset) assert dataset.state_dict(0, 1) == {"0": {"counter": 0}, "1": {"counter": 0}} dataset2 = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], + 42, + weights=(0.5, 0.5), + iterate_over_all=False, ) assert dataset2.state_dict(0, 1) == {} @@ -118,7 +130,10 @@ def test_combined_dataset_state_dict(): ] dataset2 = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], + 42, + weights=(0.5, 0.5), + iterate_over_all=False, ) assert dataset2.state_dict(0, 1) == {} dataset2_iter = iter(dataset2) @@ -143,7 +158,7 @@ def test_combined_dataset_state_dict(): ], ) def test_combined_dataset_normalizes_weights(weights, expected): - combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, seed=1) + combined_dataset = TestCombinedStreamingDataset([[1], [2, 3]], weights=weights, iterate_over_all=False, seed=1) assert combined_dataset._weights == expected @@ -166,21 +181,27 @@ def set_epoch(self, current_epoch): def test_combined_dataset(): dataset1 = SimpleDataset(0, 10) dataset2 = SimpleDataset(10, 20) - dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[1.0, 0.0], seed=12345) + dataset = TestCombinedStreamingDataset( + datasets=[dataset1, dataset2], weights=[1.0, 0.0], iterate_over_all=False, seed=12345 + ) res = list(dataset) assert res == list(range(0, 10)) dataset1 = SimpleDataset(0, 10) dataset2 = SimpleDataset(10, 20) - dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.0, 1.0], seed=12345) + dataset = TestCombinedStreamingDataset( + datasets=[dataset1, dataset2], weights=[0.0, 1.0], iterate_over_all=False, seed=12345 + ) res = list(dataset) assert res == list(range(10, 20)) dataset1 = SimpleDataset(0, 10) dataset2 = SimpleDataset(10, 20) - dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) + dataset = TestCombinedStreamingDataset( + datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345 + ) res = list(dataset) assert 9 in res or 19 in res @@ -190,7 +211,9 @@ def test_combined_dataset(): dataset1 = SimpleDataset(0, 10) dataset2 = SimpleDataset(10, 20) - dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) + dataset = TestCombinedStreamingDataset( + datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345 + ) dataloader = DataLoader(dataset, batch_size=2, num_workers=1) dataloader_iter = iter(dataloader) assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1])) @@ -200,7 +223,9 @@ def test_combined_dataset(): def test_combined_dataset_with_dataloader_and_one_worker(batch_size): dataset1 = SimpleDataset(0, 10) dataset2 = SimpleDataset(10, 20) - dataset = TestCombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) + dataset = TestCombinedStreamingDataset( + datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345 + ) dataloader = StreamingDataLoader(dataset, num_workers=1, batch_size=batch_size, prefetch_factor=1) dataloader_iter = iter(dataloader) @@ -267,7 +292,9 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir): dataset1 = StreamingDataset(input_dir=Dir(cache_dir_1, data_dir_1), shuffle=True) dataset2 = StreamingDataset(input_dir=Dir(cache_dir_2, data_dir_2), shuffle=True) - dataset = CombinedStreamingDataset(datasets=[dataset1, dataset2], weights=[0.5, 0.5], seed=12345) + dataset = CombinedStreamingDataset( + datasets=[dataset1, dataset2], weights=[0.5, 0.5], iterate_over_all=False, seed=12345 + ) dataloader = StreamingDataLoader(dataset, num_workers=3, batch_size=2) assert dataset1.current_epoch == 1 From ccdda1b245e300499c38a7d9c32027e5b4feabd3 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 May 2024 14:12:35 +0000 Subject: [PATCH 4/7] update --- status.json | 2 +- tests/streaming/test_dataloader.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/status.json b/status.json index 72aa5495..1464eea5 100644 --- a/status.json +++ b/status.json @@ -1 +1 @@ -{ "progress": "20.0%" } +{ "progress": "50.0%" } diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index d719bce6..d0b79e96 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -49,7 +49,10 @@ def _check_datasets(self, datasets) -> None: def test_streaming_dataloader(): dataset = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], + 42, + weights=(0.5, 0.5), + iterate_over_all=False, ) dataloader = StreamingDataLoader(dataset, batch_size=2) dataloader_iter = iter(dataloader) @@ -87,7 +90,10 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch): monkeypatch.setattr(streaming_dataloader_module, "_VIZ_TRACKER_AVAILABLE", True) dataset = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], + 42, + weights=(0.5, 0.5), + iterate_over_all=False, ) dataloader = StreamingDataLoader( dataset, batch_size=2, profile_batches=profile, profile_dir=str(tmpdir), num_workers=1 @@ -102,7 +108,7 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch): def test_dataloader_shuffle(): dataset = TestCombinedStreamingDataset( - [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5) + [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5), iterate_over_all=False ) assert dataset._datasets[0].shuffle is None assert dataset._datasets[1].shuffle is None From 1d432dacc9ee0de50a76a0e9c448a341731feb10 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 May 2024 14:14:34 +0000 Subject: [PATCH 5/7] update --- status.json | 1 - 1 file changed, 1 deletion(-) delete mode 100644 status.json diff --git a/status.json b/status.json deleted file mode 100644 index 1464eea5..00000000 --- a/status.json +++ /dev/null @@ -1 +0,0 @@ -{ "progress": "50.0%" } From 8fe8423ba0ce7a9976692deda9d2987b0cc0f990 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 May 2024 14:25:14 +0000 Subject: [PATCH 6/7] update --- tests/streaming/test_dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/streaming/test_dataloader.py b/tests/streaming/test_dataloader.py index d0b79e96..d75097a6 100644 --- a/tests/streaming/test_dataloader.py +++ b/tests/streaming/test_dataloader.py @@ -80,7 +80,7 @@ def test_streaming_dataloader(): "dataset": {"0": {"counter": 10}, "1": {"counter": 9}}, "current_epoch": 0, "latest_worker_idx": 0, - "num_samples_yielded": {0: [11, 9]}, + "num_samples_yielded": {0: [10, 9]}, } From e4ca2578c8c7a73133eb93e7e3ff6b610e70de6b Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 7 May 2024 14:33:00 +0000 Subject: [PATCH 7/7] update --- src/litdata/streaming/combined.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index 865cd1ac..ad11aff2 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -167,7 +167,7 @@ def __init__( seed: int, weights: Sequence[float], use_streaming_dataloader: bool, - num_samples_yielded: any, + num_samples_yielded: Any, iterate_over_all: bool = False, ) -> None: self._datasets = datasets