From 0e429986c651726b96158f5e9ffce27960e11b75 Mon Sep 17 00:00:00 2001 From: Ryan Mukherjee Date: Tue, 30 May 2023 06:02:50 +0200 Subject: [PATCH] avoid unnecessary workers with sequential `CombinedLoader ` (#17639) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí Co-authored-by: Adrian Wälchli (cherry picked from commit c3ad7568e114c4cd357274f8f86c563005fb850c) --- src/lightning/pytorch/CHANGELOG.md | 3 ++ .../pytorch/utilities/combined_loader.py | 15 ++++++-- tests/tests_pytorch/loops/test_loops.py | 6 ++-- .../utilities/test_combined_loader.py | 34 +++++++++++++++++++ 4 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 649b0f487f81e..9484188f60e89 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -14,6 +14,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- `CombinedLoader` only starts DataLoader workers when necessary when operating in sequential mode ([#17639](https://github.com/Lightning-AI/lightning/pull/17639)) + + - Fixed a potential bug with uploading model checkpoints to Neptune.ai by uploading files from stream ([#17430](https://github.com/Lightning-AI/lightning/pull/17430)) diff --git a/src/lightning/pytorch/utilities/combined_loader.py b/src/lightning/pytorch/utilities/combined_loader.py index f8e5ccc7577d5..98e37253fa2ef 100644 --- a/src/lightning/pytorch/utilities/combined_loader.py +++ b/src/lightning/pytorch/utilities/combined_loader.py @@ -108,7 +108,7 @@ def limits(self, limits: Optional[List[Union[int, float]]]) -> None: self._limits = limits def __next__(self) -> Tuple[Any, int, int]: - n = len(self.iterators) + n = len(self.iterables) if n == 0 or self._iterator_idx >= n: raise StopIteration @@ -120,7 +120,7 @@ def __next__(self) -> Tuple[Any, int, int]: raise StopIteration try: - out = next(self.iterators[self._iterator_idx]) + out = next(self.iterators[0]) index = self._idx self._idx += 1 # batch, batch_idx, dataloader_idx @@ -131,9 +131,9 @@ def __next__(self) -> Tuple[Any, int, int]: return self.__next__() def __iter__(self) -> Self: - super().__iter__() self._iterator_idx = 0 self._idx = 0 + self._load_current_iterator() return self def reset(self) -> None: @@ -141,9 +141,18 @@ def reset(self) -> None: self._iterator_idx = 0 self._idx = 0 + def _load_current_iterator(self) -> None: + # Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily + if self._iterator_idx < len(self.iterables): + self.iterators = [iter(self.iterables[self._iterator_idx])] + else: + # No more iterables to step through, return an empty list + self.iterators = [] + def _use_next_iterator(self) -> None: self._iterator_idx += 1 self._idx = 0 + self._load_current_iterator() class _MaxSize(_ModeIterator[List]): diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 15efb392a2ba8..e38d4459e2ed3 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -844,8 +844,7 @@ def _get_iterator(self): # iterable check 0, # epoch ends - 1, - # teardown + 0, 1, ] else: @@ -855,9 +854,8 @@ def _get_iterator(self): # iterable check 0, # epoch ends + 0, 1, 2, - # teardown - 3, ] assert val_dataloader.shutdown_workers_epochs == expected diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index ae08ada88ba84..9239d9f60c958 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -306,6 +306,40 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader assert idx == expected - 1 +@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"]) +def test_combined_loader_simultaneous_workers(mode): + """Test `CombinedLoader` to check how it initializes dataloader workers.""" + + class TestDataLoader(DataLoader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.workers_active = False + + def _get_iterator(self): + self.workers_active = True + return super()._get_iterator() + + def _shutdown_workers(self): + self.workers_active = False + super()._shutdown_workers() + + loaders = [ + TestDataLoader(range(10), batch_size=2, num_workers=0), + TestDataLoader(range(20), batch_size=2, num_workers=0), + ] + combined_loader = CombinedLoader(loaders, mode) + # Start the dataloader + _ = iter(combined_loader) + + workers_active = [] + for loader in loaders: + workers_active.append(loader.workers_active) + + # Sequential only starts the first dataloader, other modes start both + expected = [True, False] if mode == "sequential" else [True, True] + assert workers_active == expected + + @pytest.mark.parametrize( ("limits", "expected"), [