diff --git a/CHANGELOG.md b/CHANGELOG.md index c64be1885d322..57b3ed7e97a96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,18 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611)) - Fixed `SignalConnector._has_already_handler` check for callable type ([#10483](https://github.com/PyTorchLightning/pytorch-lightning/pull/10483)) - Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839)) - - - Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369)) - - - Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837)) - - -- - - -- +- Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870)) ## [1.5.4] - 2021-11-30 diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index 58e65233dfe81..e5fa46fe05836 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -26,7 +26,7 @@ def __init__(self) -> None: self._dl_max_batches: Optional[int] = None self._num_dataloaders: Optional[int] = None self._warning_cache = WarningCache() - self._all_batch_indices: List[int] = [] + self._seen_batch_indices: List[List[int]] = [] @property def done(self) -> bool: @@ -44,8 +44,8 @@ def connect(self, **kwargs: "Loop") -> None: def reset(self) -> None: """Resets the loops internal state.""" - self._all_batch_indices: List[int] = [] - self.predictions: List[Any] = [] + self._seen_batch_indices = [] + self.predictions = [] self.batch_progress.reset_on_run() def on_run_start( @@ -68,6 +68,7 @@ def on_run_start( void(dataloader_iter, dataloader_idx) self._dl_max_batches = dl_max_batches self._num_dataloaders = num_dataloaders + self._seen_batch_indices = self._get_batch_indices(dataloader_idx) self.return_predictions = return_predictions def advance( @@ -88,6 +89,10 @@ def advance( return_predictions: whether to return the obtained predictions """ batch_idx, batch = next(dataloader_iter) + self._seen_batch_indices = self._get_batch_indices(dataloader_idx) + # we need to truncate the list of batch indicies due to prefetching in the dataloader and Lightning + self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)] + if batch is None: raise StopIteration @@ -99,13 +104,10 @@ def advance( with self.trainer.profiler.profile("predict_step"): self._predict_step(batch, batch_idx, dataloader_idx) - def on_run_end(self) -> Tuple[List[Any], List[int]]: + def on_run_end(self) -> Tuple[List[Any], List[List[int]]]: """Returns the predictions and the corresponding batch indices.""" - predictions = self.predictions - all_batch_indices = self._all_batch_indices - # free memory - self.predictions = [] - self._all_batch_indices = [] + predictions, all_batch_indices = self.predictions, self._seen_batch_indices + self.predictions, self._seen_batch_indices = [], [] # free memory return predictions, all_batch_indices def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: @@ -121,7 +123,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx) # extract batch_indices and store them - self._store_batch_indices(dataloader_idx) + self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else [] model_ref = self.trainer.lightning_module @@ -160,12 +162,12 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict step_kwargs["dataloader_idx"] = dataloader_idx return step_kwargs - def _store_batch_indices(self, dataloader_idx: int) -> None: - """Stores the batch indices if the predictions should be stored.""" + def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]: + """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our + :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`.""" batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler - if isinstance(batch_sampler, IndexBatchSamplerWrapper): - self.current_batch_indices = batch_sampler.batch_indices - if self.should_store_predictions: - self._all_batch_indices.append(batch_sampler.batch_indices) - else: - warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.") + if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions: + return batch_sampler.seen_batch_indices + + warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.") + return [] diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 0cf392dd44775..835d7f87040c1 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import itertools -from typing import Any, Iterator, List, Optional +from typing import Any, Iterator, List import torch from torch.nn.parallel import DistributedDataParallel @@ -20,6 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.utilities import rank_zero_deprecation class LightningDistributedModule(_LightningModuleWrapperBase): @@ -119,12 +120,31 @@ class IndexBatchSamplerWrapper: """This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices.""" def __init__(self, sampler: BatchSampler) -> None: + self.seen_batch_indices: List[List[int]] = [] self._sampler = sampler - self.batch_indices: Optional[List[int]] = None + self._batch_indices: List[int] = [] + + @property + def batch_indices(self) -> List[int]: + rank_zero_deprecation( + "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in" + " v1.7. Access the full list `seen_batch_indices` instead." + ) + return self._batch_indices + + @batch_indices.setter + def batch_indices(self, indices: List[int]) -> None: + rank_zero_deprecation( + "The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in" + " v1.7. Access the full list `seen_batch_indices` instead." + ) + self._batch_indices = indices def __iter__(self) -> Iterator[List[int]]: + self.seen_batch_indices = [] for batch in self._sampler: - self.batch_indices = batch + self._batch_indices = batch + self.seen_batch_indices.append(batch) yield batch def __len__(self) -> int: diff --git a/tests/callbacks/test_prediction_writer.py b/tests/callbacks/test_prediction_writer.py index 75e0dbd31ec79..2cd3738ca875f 100644 --- a/tests/callbacks/test_prediction_writer.py +++ b/tests/callbacks/test_prediction_writer.py @@ -11,54 +11,98 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import ANY, call, Mock import pytest +from torch.utils.data import DataLoader from pytorch_lightning import Trainer from pytorch_lightning.callbacks import BasePredictionWriter from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers import BoringModel +from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf -def test_prediction_writer(tmpdir): - class CustomPredictionWriter(BasePredictionWriter): - def __init__(self, writer_interval: str): - super().__init__(writer_interval) +class DummyPredictionWriter(BasePredictionWriter): + def write_on_batch_end(self, *args, **kwargs): + pass - self.write_on_batch_end_called = False - self.write_on_epoch_end_called = False + def write_on_epoch_end(self, *args, **kwargs): + pass - def write_on_batch_end(self, *args, **kwargs): - self.write_on_batch_end_called = True - - def write_on_epoch_end(self, *args, **kwargs): - self.write_on_epoch_end_called = True +def test_prediction_writer_invalid_write_interval(): + """Test that configuring an unknown interval name raises an error.""" with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"): - CustomPredictionWriter("something") + DummyPredictionWriter("something") + + +def test_prediction_writer_hook_call_intervals(tmpdir): + """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined + interval.""" + DummyPredictionWriter.write_on_batch_end = Mock() + DummyPredictionWriter.write_on_epoch_end = Mock() + + dataloader = DataLoader(RandomDataset(32, 64)) model = BoringModel() - cb = CustomPredictionWriter("batch_and_epoch") + cb = DummyPredictionWriter("batch_and_epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - results = trainer.predict(model, dataloaders=model.train_dataloader()) + results = trainer.predict(model, dataloaders=dataloader) assert len(results) == 4 - assert cb.write_on_batch_end_called - assert cb.write_on_epoch_end_called + assert cb.write_on_batch_end.call_count == 4 + assert cb.write_on_epoch_end.call_count == 1 - cb = CustomPredictionWriter("batch_and_epoch") + DummyPredictionWriter.write_on_batch_end.reset_mock() + DummyPredictionWriter.write_on_epoch_end.reset_mock() + + cb = DummyPredictionWriter("batch_and_epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) - assert cb.write_on_batch_end_called - assert cb.write_on_epoch_end_called + trainer.predict(model, dataloaders=dataloader, return_predictions=False) + assert cb.write_on_batch_end.call_count == 4 + assert cb.write_on_epoch_end.call_count == 1 + + DummyPredictionWriter.write_on_batch_end.reset_mock() + DummyPredictionWriter.write_on_epoch_end.reset_mock() - cb = CustomPredictionWriter("batch") + cb = DummyPredictionWriter("batch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) - assert cb.write_on_batch_end_called - assert not cb.write_on_epoch_end_called + trainer.predict(model, dataloaders=dataloader, return_predictions=False) + assert cb.write_on_batch_end.call_count == 4 + assert cb.write_on_epoch_end.call_count == 0 + + DummyPredictionWriter.write_on_batch_end.reset_mock() + DummyPredictionWriter.write_on_epoch_end.reset_mock() - cb = CustomPredictionWriter("epoch") + cb = DummyPredictionWriter("epoch") trainer = Trainer(limit_predict_batches=4, callbacks=cb) - trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False) - assert not cb.write_on_batch_end_called - assert cb.write_on_epoch_end_called + trainer.predict(model, dataloaders=dataloader, return_predictions=False) + assert cb.write_on_batch_end.call_count == 0 + assert cb.write_on_epoch_end.call_count == 1 + + +@pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))]) +def test_prediction_writer_batch_indices(tmpdir, num_workers): + DummyPredictionWriter.write_on_batch_end = Mock() + DummyPredictionWriter.write_on_epoch_end = Mock() + + dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers) + model = BoringModel() + writer = DummyPredictionWriter("batch_and_epoch") + trainer = Trainer(limit_predict_batches=4, callbacks=writer) + trainer.predict(model, dataloaders=dataloader) + + writer.write_on_batch_end.assert_has_calls( + [ + call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0), + call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0), + call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0), + call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0), + ] + ) + + writer.write_on_epoch_end.assert_has_calls( + [ + call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]), + ] + ) diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index e0e51575f5f22..62ec4d8d5490a 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -13,6 +13,7 @@ # limitations under the License. """Test deprecated functionality which will be removed in v1.7.0.""" from unittest import mock +from unittest.mock import Mock import pytest @@ -22,6 +23,7 @@ from pytorch_lightning.callbacks.progress import ProgressBar from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger +from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from tests.callbacks.test_callbacks import OldStatefulCallback from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel @@ -448,3 +450,12 @@ def test_v1_7_0_deprecate_lr_sch_names(tmpdir): with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"): assert lr_monitor.lr_sch_names == ["lr-SGD"] + + +def test_v1_7_0_index_batch_sampler_wrapper_batch_indices(): + sampler = IndexBatchSamplerWrapper(Mock()) + with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"): + _ = sampler.batch_indices + + with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"): + sampler.batch_indices = [] diff --git a/tests/overrides/test_distributed.py b/tests/overrides/test_distributed.py index c8d982bd733fe..e425859fe34df 100644 --- a/tests/overrides/test_distributed.py +++ b/tests/overrides/test_distributed.py @@ -54,9 +54,7 @@ def test_index_batch_sampler(tmpdir): assert batch_sampler.batch_size == index_batch_sampler.batch_size assert batch_sampler.drop_last == index_batch_sampler.drop_last assert batch_sampler.sampler is sampler - - for batch in index_batch_sampler: - assert index_batch_sampler.batch_indices == batch + assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices def test_index_batch_sampler_methods():