Fix retrieval of batch indices when dataloader num_workers > 0 (#10870)
Co-authored-by: Rohit Gupta <[email protected]>
awaelchli and rohitgr7 committed Dec 2, 2021
1 parent f26f637 commit 84bdcd4
Showing 6 changed files with 129 additions and 63 deletions.
11 changes: 1 addition & 10 deletions CHANGELOG.md
@@ -12,18 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
- Fixed `SignalConnector._has_already_handler` check for callable type ([#10483](https://github.com/PyTorchLightning/pytorch-lightning/pull/10483))
- Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839))


- Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))


- Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))


-


-
- Fixed a bug that caused incorrect batch indices to be passed to the `BasePredictionWriter` hooks when using a dataloader with `num_workers > 0` ([#10870](https://github.com/PyTorchLightning/pytorch-lightning/pull/10870))


## [1.5.4] - 2021-11-30
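The changelog entry above summarizes the user-facing bug: with `num_workers > 0`, the batch indices handed to the `BasePredictionWriter` hooks could be wrong. A minimal, hypothetical sketch of that scenario (the writer class, model, and dataset names are placeholders, not part of this commit; the hook signature matches the calls asserted in the tests added below):

```python
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BasePredictionWriter


class IndexLoggingWriter(BasePredictionWriter):
    """Hypothetical writer that records which dataset indices each predicted batch came from."""

    def __init__(self):
        super().__init__(write_interval="batch")

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # Before this fix, `batch_indices` could be empty or stale whenever the
        # dataloader ran with `num_workers > 0`.
        print(batch_idx, batch_indices)


# Usage sketch (assuming `MyModel` and `MyDataset` are defined elsewhere):
# trainer = Trainer(limit_predict_batches=4, callbacks=[IndexLoggingWriter()])
# trainer.predict(MyModel(), DataLoader(MyDataset(), batch_size=4, num_workers=2))
```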
38 changes: 20 additions & 18 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -26,7 +26,7 @@ def __init__(self) -> None:
self._dl_max_batches: Optional[int] = None
self._num_dataloaders: Optional[int] = None
self._warning_cache = WarningCache()
self._all_batch_indices: List[int] = []
self._seen_batch_indices: List[List[int]] = []

@property
def done(self) -> bool:
@@ -44,8 +44,8 @@ def connect(self, **kwargs: "Loop") -> None:

def reset(self) -> None:
"""Resets the loops internal state."""
self._all_batch_indices: List[int] = []
self.predictions: List[Any] = []
self._seen_batch_indices = []
self.predictions = []
self.batch_progress.reset_on_run()

def on_run_start(
@@ -68,6 +68,7 @@ def on_run_start(
void(dataloader_iter, dataloader_idx)
self._dl_max_batches = dl_max_batches
self._num_dataloaders = num_dataloaders
self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
self.return_predictions = return_predictions

def advance(
@@ -88,6 +89,10 @@ def advance(
return_predictions: whether to return the obtained predictions
"""
batch_idx, batch = next(dataloader_iter)
self._seen_batch_indices = self._get_batch_indices(dataloader_idx)
# we need to truncate the list of batch indices due to prefetching in the dataloader and Lightning
self._seen_batch_indices = self._seen_batch_indices[: (self.batch_progress.current.completed + 1)]

if batch is None:
raise StopIteration

@@ -99,13 +104,10 @@
with self.trainer.profiler.profile("predict_step"):
self._predict_step(batch, batch_idx, dataloader_idx)

def on_run_end(self) -> Tuple[List[Any], List[int]]:
def on_run_end(self) -> Tuple[List[Any], List[List[int]]]:
"""Returns the predictions and the corresponding batch indices."""
predictions = self.predictions
all_batch_indices = self._all_batch_indices
# free memory
self.predictions = []
self._all_batch_indices = []
predictions, all_batch_indices = self.predictions, self._seen_batch_indices
self.predictions, self._seen_batch_indices = [], [] # free memory
return predictions, all_batch_indices

def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
@@ -121,7 +123,7 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

# extract batch_indices and store them
self._store_batch_indices(dataloader_idx)
self.current_batch_indices = self._seen_batch_indices[batch_idx] if self._seen_batch_indices else []

model_ref = self.trainer.lightning_module

@@ -160,12 +162,12 @@ def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict
step_kwargs["dataloader_idx"] = dataloader_idx
return step_kwargs

def _store_batch_indices(self, dataloader_idx: int) -> None:
"""Stores the batch indices if the predictions should be stored."""
def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
"""Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
:class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
if isinstance(batch_sampler, IndexBatchSamplerWrapper):
self.current_batch_indices = batch_sampler.batch_indices
if self.should_store_predictions:
self._all_batch_indices.append(batch_sampler.batch_indices)
else:
warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
if isinstance(batch_sampler, IndexBatchSamplerWrapper) and self.should_store_predictions:
return batch_sampler.seen_batch_indices

warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
return []
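The truncation added in `advance` above compensates for prefetching: the wrapped batch sampler runs ahead of the loop when the dataloader (or Lightning's prefetcher) fetches batches early, so only the indices belonging to batches the loop has actually reached are exposed. A plain-Python sketch of the slicing, with illustrative values that are not taken from the commit:

```python
# The sampler has already yielded indices for three prefetched batches, but the
# loop has completed one batch and is processing the next one.
seen_batch_indices = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
completed = 1                                  # batch_progress.current.completed
visible = seen_batch_indices[: completed + 1]  # completed batches plus the current one
assert visible == [[0, 1, 2, 3], [4, 5, 6, 7]]
```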
26 changes: 23 additions & 3 deletions pytorch_lightning/overrides/distributed.py
@@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Any, Iterator, List, Optional
from typing import Any, Iterator, List

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, DistributedSampler, Sampler

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.utilities import rank_zero_deprecation


class LightningDistributedModule(_LightningModuleWrapperBase):
@@ -119,12 +120,31 @@ class IndexBatchSamplerWrapper:
"""This class is used to wrap a :class:`torch.utils.data.BatchSampler` and capture its indices."""

def __init__(self, sampler: BatchSampler) -> None:
self.seen_batch_indices: List[List[int]] = []
self._sampler = sampler
self.batch_indices: Optional[List[int]] = None
self._batch_indices: List[int] = []

@property
def batch_indices(self) -> List[int]:
rank_zero_deprecation(
"The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in"
" v1.7. Access the full list `seen_batch_indices` instead."
)
return self._batch_indices

@batch_indices.setter
def batch_indices(self, indices: List[int]) -> None:
rank_zero_deprecation(
"The attribute `IndexBatchSamplerWrapper.batch_indices` was deprecated in v1.5 and will be removed in"
" v1.7. Access the full list `seen_batch_indices` instead."
)
self._batch_indices = indices

def __iter__(self) -> Iterator[List[int]]:
self.seen_batch_indices = []
for batch in self._sampler:
self.batch_indices = batch
self._batch_indices = batch
self.seen_batch_indices.append(batch)
yield batch

def __len__(self) -> int:
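The wrapper now accumulates every batch of indices it yields in `seen_batch_indices`, while the old single-batch `batch_indices` attribute is kept behind a deprecation warning. A short usage sketch, mirroring the updated test in `tests/overrides/test_distributed.py` below:

```python
from torch.utils.data import BatchSampler, SequentialSampler

from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper

# Wrap a plain BatchSampler and iterate it once; the wrapper records what it yielded.
batch_sampler = BatchSampler(SequentialSampler(range(8)), batch_size=4, drop_last=False)
wrapper = IndexBatchSamplerWrapper(batch_sampler)
assert list(wrapper) == wrapper.seen_batch_indices  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```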
102 changes: 73 additions & 29 deletions tests/callbacks/test_prediction_writer.py
@@ -11,54 +11,98 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import ANY, call, Mock

import pytest
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BasePredictionWriter
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


def test_prediction_writer(tmpdir):
class CustomPredictionWriter(BasePredictionWriter):
def __init__(self, writer_interval: str):
super().__init__(writer_interval)
class DummyPredictionWriter(BasePredictionWriter):
def write_on_batch_end(self, *args, **kwargs):
pass

self.write_on_batch_end_called = False
self.write_on_epoch_end_called = False
def write_on_epoch_end(self, *args, **kwargs):
pass

def write_on_batch_end(self, *args, **kwargs):
self.write_on_batch_end_called = True

def write_on_epoch_end(self, *args, **kwargs):
self.write_on_epoch_end_called = True

def test_prediction_writer_invalid_write_interval():
"""Test that configuring an unknown interval name raises an error."""
with pytest.raises(MisconfigurationException, match=r"`write_interval` should be one of \['batch"):
CustomPredictionWriter("something")
DummyPredictionWriter("something")


def test_prediction_writer_hook_call_intervals(tmpdir):
"""Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
interval."""
DummyPredictionWriter.write_on_batch_end = Mock()
DummyPredictionWriter.write_on_epoch_end = Mock()

dataloader = DataLoader(RandomDataset(32, 64))

model = BoringModel()
cb = CustomPredictionWriter("batch_and_epoch")
cb = DummyPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
results = trainer.predict(model, dataloaders=model.train_dataloader())
results = trainer.predict(model, dataloaders=dataloader)
assert len(results) == 4
assert cb.write_on_batch_end_called
assert cb.write_on_epoch_end_called
assert cb.write_on_batch_end.call_count == 4
assert cb.write_on_epoch_end.call_count == 1

cb = CustomPredictionWriter("batch_and_epoch")
DummyPredictionWriter.write_on_batch_end.reset_mock()
DummyPredictionWriter.write_on_epoch_end.reset_mock()

cb = DummyPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert cb.write_on_batch_end_called
assert cb.write_on_epoch_end_called
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 4
assert cb.write_on_epoch_end.call_count == 1

DummyPredictionWriter.write_on_batch_end.reset_mock()
DummyPredictionWriter.write_on_epoch_end.reset_mock()

cb = CustomPredictionWriter("batch")
cb = DummyPredictionWriter("batch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert cb.write_on_batch_end_called
assert not cb.write_on_epoch_end_called
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 4
assert cb.write_on_epoch_end.call_count == 0

DummyPredictionWriter.write_on_batch_end.reset_mock()
DummyPredictionWriter.write_on_epoch_end.reset_mock()

cb = CustomPredictionWriter("epoch")
cb = DummyPredictionWriter("epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=cb)
trainer.predict(model, dataloaders=model.train_dataloader(), return_predictions=False)
assert not cb.write_on_batch_end_called
assert cb.write_on_epoch_end_called
trainer.predict(model, dataloaders=dataloader, return_predictions=False)
assert cb.write_on_batch_end.call_count == 0
assert cb.write_on_epoch_end.call_count == 1


@pytest.mark.parametrize("num_workers", [0, pytest.param(2, marks=RunIf(slow=True))])
def test_prediction_writer_batch_indices(tmpdir, num_workers):
DummyPredictionWriter.write_on_batch_end = Mock()
DummyPredictionWriter.write_on_epoch_end = Mock()

dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers)
model = BoringModel()
writer = DummyPredictionWriter("batch_and_epoch")
trainer = Trainer(limit_predict_batches=4, callbacks=writer)
trainer.predict(model, dataloaders=dataloader)

writer.write_on_batch_end.assert_has_calls(
[
call(trainer, model, ANY, [0, 1, 2, 3], ANY, 0, 0),
call(trainer, model, ANY, [4, 5, 6, 7], ANY, 1, 0),
call(trainer, model, ANY, [8, 9, 10, 11], ANY, 2, 0),
call(trainer, model, ANY, [12, 13, 14, 15], ANY, 3, 0),
]
)

writer.write_on_epoch_end.assert_has_calls(
[
call(trainer, model, ANY, [[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]]),
]
)
11 changes: 11 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Test deprecated functionality which will be removed in v1.7.0."""
from unittest import mock
from unittest.mock import Mock

import pytest

@@ -22,6 +23,7 @@
from pytorch_lightning.callbacks.progress import ProgressBar
from pytorch_lightning.callbacks.xla_stats_monitor import XLAStatsMonitor
from pytorch_lightning.loggers import LoggerCollection, TestTubeLogger
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from tests.callbacks.test_callbacks import OldStatefulCallback
from tests.deprecated_api import _soft_unimport_module
from tests.helpers import BoringModel
@@ -448,3 +450,12 @@ def test_v1_7_0_deprecate_lr_sch_names(tmpdir):

with pytest.deprecated_call(match="`LearningRateMonitor.lr_sch_names` has been deprecated in v1.5"):
assert lr_monitor.lr_sch_names == ["lr-SGD"]


def test_v1_7_0_index_batch_sampler_wrapper_batch_indices():
sampler = IndexBatchSamplerWrapper(Mock())
with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
_ = sampler.batch_indices

with pytest.deprecated_call(match="was deprecated in v1.5 and will be removed in v1.7"):
sampler.batch_indices = []
4 changes: 1 addition & 3 deletions tests/overrides/test_distributed.py
@@ -54,9 +54,7 @@ def test_index_batch_sampler(tmpdir):
assert batch_sampler.batch_size == index_batch_sampler.batch_size
assert batch_sampler.drop_last == index_batch_sampler.drop_last
assert batch_sampler.sampler is sampler

for batch in index_batch_sampler:
assert index_batch_sampler.batch_indices == batch
assert list(index_batch_sampler) == index_batch_sampler.seen_batch_indices


def test_index_batch_sampler_methods():
