From 2c7202f3518c87c57a88a40745be866997d98e13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 9 Nov 2021 16:37:44 +0100
Subject: [PATCH] Resolve workers being forcefully deleted with
 `persistent_workers=True` (#10434)

---
 CHANGELOG.md                            |  1 +
 pytorch_lightning/utilities/fetching.py |  6 +++---
 tests/loops/test_loops.py               | 18 ++++++++++++------
 3 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c96074a6f640f9..45f5efcc562164 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed deadlocks for distributed training with `RichProgressBar` ([#10428](https://github.com/PyTorchLightning/pytorch-lightning/pull/10428))
 - Fixed an issue where the model wrapper in Lite converted non-floating point tensors to float ([#10429](https://github.com/PyTorchLightning/pytorch-lightning/pull/10429))
 - Fixed an issue with inferring the dataset type in fault-tolerant training ([#10432](https://github.com/PyTorchLightning/pytorch-lightning/pull/10432))
+- Fixed dataloader workers with `persistent_workers` being deleted on every iteration ([#10434](https://github.com/PyTorchLightning/pytorch-lightning/pull/10434))


 ## [1.5.0] - 2021-11-02

diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py
index fd9baf3e9c4f15..9b80d2f9874c72 100644
--- a/pytorch_lightning/utilities/fetching.py
+++ b/pytorch_lightning/utilities/fetching.py
@@ -206,15 +206,15 @@ def reset(self) -> None:
         self.batches: List = []
         self.fetched: int = 0
         self.done: bool = False
+
+    def teardown(self) -> None:
+        self.reset()
         if isinstance(self.dataloader, CombinedLoader):
             self.dataloader.reset()
         if isinstance(self.dataloader, DataLoader):
             CombinedLoader._shutdown_workers_and_reset_iterator(self.dataloader)
         self.dataloader_iter = None

-    def teardown(self) -> None:
-        self.reset()
-

 class DataFetcher(AbstractDataFetcher):

diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py
index dd390ab4939d5b..bad9a717d16294 100644
--- a/tests/loops/test_loops.py
+++ b/tests/loops/test_loops.py
@@ -912,21 +912,25 @@ def val_dataloader(self):


 @RunIf(min_torch="1.8.0")
-@pytest.mark.parametrize("persistent_workers", (True, False))
+@pytest.mark.parametrize("persistent_workers", (False, True))
 def test_workers_are_shutdown(tmpdir, persistent_workers):
     # `num_workers == 1` uses `_MultiProcessingDataLoaderIter`
     # `persistent_workers` makes sure `self._iterator` gets set on the `DataLoader` instance

     class _TestMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
-        def __init__(self, *args, dataloader: DataLoader, **kwargs):
+        def __init__(self, *args, dataloader, **kwargs):
             super().__init__(*args, **kwargs)
             self.dataloader = dataloader

         def _shutdown_workers(self):
-            setattr(self.dataloader, "has_shutdown_workers", True)
+            self.dataloader.count_shutdown_workers += 1
             super()._shutdown_workers()

     class TestDataLoader(DataLoader):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.count_shutdown_workers = 0
+
         def _get_iterator(self):
             if self.num_workers == 0:
                 return super()._get_iterator()
@@ -937,10 +941,12 @@ def _get_iterator(self):
     train_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)
     val_dataloader = TestDataLoader(RandomDataset(32, 64), num_workers=1, persistent_workers=persistent_workers)

+    max_epochs = 3
     model = BoringModel()
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2)
+    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=max_epochs)
     trainer.fit(model, train_dataloader, val_dataloader)
-    assert train_dataloader.has_shutdown_workers
-    assert val_dataloader.has_shutdown_workers
+    assert train_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs)
+    # at the end of the sanity check, the val workers are shut down once more
+    assert val_dataloader.count_shutdown_workers == (2 if persistent_workers else max_epochs + 1)
     assert train_dataloader._iterator is None
     assert val_dataloader._iterator is None
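
For context, a minimal sketch (not part of the patch above) of the DataLoader behaviour the fix preserves: with `persistent_workers=True`, PyTorch caches the loader's `_iterator` and keeps its worker processes alive across epochs, so the fetcher should only shut them down in `teardown()` rather than on every `reset()`. The `TinyDataset` helper below is a hypothetical stand-in for the `RandomDataset` used in the test.

import torch
from torch.utils.data import DataLoader, Dataset


class TinyDataset(Dataset):
    """Hypothetical stand-in for the ``RandomDataset`` helper used in the test."""

    def __init__(self, size: int, length: int) -> None:
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.data[index]

    def __len__(self) -> int:
        return len(self.data)


if __name__ == "__main__":
    loader = DataLoader(TinyDataset(32, 64), num_workers=1, persistent_workers=True)

    for _ in loader:  # first epoch spawns the worker and caches the iterator
        pass
    first_iterator = loader._iterator

    for _ in loader:  # second epoch reuses the cached iterator and its worker process
        pass

    # With persistent workers the iterator object is reused, i.e. no worker was
    # torn down between epochs; shutting it down belongs in a final teardown step.
    assert loader._iterator is first_iterator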