Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid unnecessary workers with sequential CombinedLoader #17639

Merged
merged 12 commits into from
May 30, 2023
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- `CombinedLoader` only starts DataLoader workers when necessary when operating in sequential mode ([#17639](https://github.com/Lightning-AI/lightning/pull/17639))


- Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))


Expand Down
15 changes: 12 additions & 3 deletions src/lightning/pytorch/utilities/combined_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def limits(self, limits: Optional[List[Union[int, float]]]) -> None:
self._limits = limits

def __next__(self) -> Tuple[Any, int, int]:
n = len(self.iterators)
n = len(self.iterables)
if n == 0 or self._iterator_idx >= n:
raise StopIteration

Expand All @@ -120,7 +120,7 @@ def __next__(self) -> Tuple[Any, int, int]:
raise StopIteration

try:
out = next(self.iterators[self._iterator_idx])
out = next(self.iterators[0])
index = self._idx
self._idx += 1
# batch, batch_idx, dataloader_idx
Expand All @@ -131,19 +131,28 @@ def __next__(self) -> Tuple[Any, int, int]:
return self.__next__()

def __iter__(self) -> Self:
super().__iter__()
self._iterator_idx = 0
self._idx = 0
self._load_current_iterator()
return self

def reset(self) -> None:
super().reset()
self._iterator_idx = 0
self._idx = 0

def _load_current_iterator(self) -> None:
# Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily
if self._iterator_idx < len(self.iterables):
self.iterators = [iter(self.iterables[self._iterator_idx])]
else:
# No more iterables to step through, return an empty list
self.iterators = []

def _use_next_iterator(self) -> None:
self._iterator_idx += 1
self._idx = 0
self._load_current_iterator()


class _MaxSize(_ModeIterator[List]):
Expand Down
6 changes: 2 additions & 4 deletions tests/tests_pytorch/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,8 +844,7 @@ def _get_iterator(self):
# iterable check
0,
# epoch ends
1,
# teardown
0,
1,
mukhery marked this conversation as resolved.
Show resolved Hide resolved
]
else:
Expand All @@ -855,9 +854,8 @@ def _get_iterator(self):
# iterable check
0,
# epoch ends
0,
1,
2,
# teardown
3,
]
assert val_dataloader.shutdown_workers_epochs == expected
34 changes: 34 additions & 0 deletions tests/tests_pytorch/utilities/test_combined_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,40 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloader
assert idx == expected - 1


@pytest.mark.parametrize("mode", ["min_size", "max_size_cycle", "max_size", "sequential"])
def test_combined_loader_simultaneous_workers(mode):
"""Test `CombinedLoader` to check how it initializes dataloader workers."""

class TestDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.workers_active = False

def _get_iterator(self):
self.workers_active = True
return super()._get_iterator()

def _shutdown_workers(self):
self.workers_active = False
super()._shutdown_workers()

loaders = [
TestDataLoader(range(10), batch_size=2, num_workers=0),
TestDataLoader(range(20), batch_size=2, num_workers=0),
]
combined_loader = CombinedLoader(loaders, mode)
# Start the dataloader
_ = iter(combined_loader)

workers_active = []
for loader in loaders:
workers_active.append(loader.workers_active)

# Sequential only starts the first dataloader, other modes start both
expected = [True, False] if mode == "sequential" else [True, True]
assert workers_active == expected


@pytest.mark.parametrize(
("limits", "expected"),
[
Expand Down