Call set_epoch for distributed batch samplers (#13396)
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
3 people authored Jun 29, 2022
1 parent 43635a9 commit 2dd332f
Showing 8 changed files with 103 additions and 32 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -292,6 +292,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `pytorch_lightning.utilities.distributed.gather_all_tensors` to handle tensors of different dimensions ([#12630](https://github.com/PyTorchLightning/pytorch-lightning/pull/12630))


- The loops now call `.set_epoch()` also on batch samplers if the dataloader has one wrapped in a distributed sampler ([#13396](https://github.com/PyTorchLightning/pytorch-lightning/pull/13396))


-


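Illustrative sketch (not part of this commit): a user-defined batch sampler of the kind the changelog entry above targets. The class name ``EpochShuffledBatchSampler`` and its seeding scheme are hypothetical; the relevant detail is that it exposes a ``set_epoch`` method, which the fit, evaluation, and prediction loops now call once per epoch, so the ordering below actually changes between epochs. In real distributed runs the batch sampler of interest typically wraps a ``DistributedSampler``; a plain epoch-aware batch sampler stands in here to keep the sketch self-contained.

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.demos.boring_classes import RandomDataset


class EpochShuffledBatchSampler:
    """Hypothetical batch sampler whose shuffling depends on the current epoch."""

    def __init__(self, dataset_size: int, batch_size: int) -> None:
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # The Lightning loops forward the current epoch here after this change.
        self.epoch = epoch

    def __iter__(self):
        generator = torch.Generator().manual_seed(self.epoch)
        indices = torch.randperm(self.dataset_size, generator=generator).tolist()
        for start in range(0, self.dataset_size, self.batch_size):
            yield indices[start : start + self.batch_size]

    def __len__(self) -> int:
        return (self.dataset_size + self.batch_size - 1) // self.batch_size


dataset = RandomDataset(32, 64)
dataloader = DataLoader(dataset, batch_sampler=EpochShuffledBatchSampler(len(dataset), batch_size=8))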
11 changes: 3 additions & 8 deletions src/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -27,6 +27,7 @@
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -161,14 +162,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
self._has_run = True

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
- dataloader = self.current_dataloader
- if (
- dataloader is not None
- and getattr(dataloader, "sampler", None)
- and callable(getattr(dataloader.sampler, "set_epoch", None))
- ):
- # set seed for distributed sampler (enables shuffling for each epoch)
- dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
+ if self.current_dataloader is not None:
+ _set_sampler_epoch(self.current_dataloader, self.trainer.fit_loop.epoch_progress.current.processed)

super().on_advance_start(*args, **kwargs)

10 changes: 3 additions & 7 deletions src/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -5,6 +5,7 @@

from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.loops.utilities import _set_sampler_epoch
from pytorch_lightning.strategies import DDPSpawnStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
@@ -90,13 +91,8 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
"""Predicts one entire dataloader."""
void(*args, **kwargs)
dataloader = self.current_dataloader
- if (
- dataloader is not None
- and getattr(dataloader, "sampler", None)
- and callable(getattr(dataloader.sampler, "set_epoch", None))
- ):
- # set seed for distributed sampler (enables shuffling for each epoch)
- dataloader.sampler.set_epoch(self.trainer.fit_loop.epoch_progress.current.processed)
+ if dataloader is not None:
+ _set_sampler_epoch(dataloader, self.trainer.fit_loop.epoch_progress.current.processed)
dataloader = self.trainer.strategy.process_dataloader(dataloader)
dataloader_iter = enumerate(dataloader)
dl_max_batches = self.max_batches[self.current_dataloader_idx]
9 changes: 3 additions & 6 deletions src/pytorch_lightning/loops/fit_loop.py
@@ -21,7 +21,7 @@
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
- from pytorch_lightning.loops.utilities import _is_max_limit_reached
+ from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
@@ -232,11 +232,8 @@ def on_advance_start(self) -> None: # type: ignore[override]
# reset outputs here instead of in `reset` as they are not accumulated between epochs
self._outputs = []

- if self.trainer.train_dataloader is not None and callable(
- getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
- ):
- # set seed for distributed sampler (enables shuffling for each epoch)
- self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)
+ if self.trainer.train_dataloader is not None:
+ _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)

# changing gradient according accumulation_scheduler
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
14 changes: 14 additions & 0 deletions src/pytorch_lightning/loops/utilities.py
@@ -22,6 +22,7 @@
import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loops import Loop
@@ -220,3 +221,16 @@ def _reset_progress(loop: Loop) -> None:
def _v1_8_output_format(fx: Callable) -> bool:
parameters = inspect.signature(fx).parameters
return "new_format" in parameters and parameters["new_format"].default is True


def _set_sampler_epoch(dataloader: DataLoader, epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.
Every PyTorch dataloader has either a sampler or a batch sampler, and if it is wrapped by a
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""
for sampler_name in ("sampler", "batch_sampler"):
sampler = getattr(dataloader, sampler_name, None)
if sampler is not None and callable(getattr(sampler, "set_epoch", None)):
sampler.set_epoch(epoch)
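Usage sketch for the new helper (illustrative, not part of the commit). ``num_replicas`` and ``rank`` are passed explicitly so the snippet runs without initializing a process group; with the default ``shuffle=True``, a ``DistributedSampler`` derives its permutation from ``seed + epoch``, which is why forwarding the epoch matters.

from torch.utils.data import DataLoader, DistributedSampler

from pytorch_lightning.demos.boring_classes import RandomDataset
from pytorch_lightning.loops.utilities import _set_sampler_epoch

dataset = RandomDataset(32, 64)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)  # shuffle=True by default
dataloader = DataLoader(dataset, sampler=sampler, batch_size=8)

_set_sampler_epoch(dataloader, 0)  # checks ``dataloader.sampler`` and ``dataloader.batch_sampler`` and calls ``set_epoch`` where it exists
epoch_0_order = list(iter(sampler))
_set_sampler_epoch(dataloader, 1)
epoch_1_order = list(iter(sampler))
# With overwhelming probability the two orderings differ, so each epoch sees a fresh shuffle.
assert epoch_0_order != epoch_1_order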
7 changes: 6 additions & 1 deletion src/pytorch_lightning/trainer/supporters.py
@@ -438,9 +438,14 @@ class DataLoaderDict(dict):

@property
def sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of samplers extracting from loaders."""
"""Return a collections of samplers extracted from loaders."""
return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "sampler", None)

@property
def batch_sampler(self) -> Union[Iterable, Sequence, Mapping]:
"""Return a collections of batch samplers extracted from loaders."""
return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "batch_sampler", None)

def _wrap_loaders_max_size_cycle(self) -> Any:
"""Wraps all loaders to make sure they are cycled until the longest loader is exhausted.
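For context (illustrative, not part of the commit): the new ``batch_sampler`` property mirrors the existing ``sampler`` property by mapping ``getattr(..., "batch_sampler", None)`` over the wrapped loader collection. A standalone sketch of the same call on a made-up dict of loaders:

from torch.utils.data import DataLoader

from pytorch_lightning.demos.boring_classes import RandomDataset
from pytorch_lightning.utilities.apply_func import apply_to_collection

loaders = {
    "a": DataLoader(RandomDataset(32, 64), batch_size=4),
    "b": DataLoader(RandomDataset(32, 64), batch_size=8),
}
# Returns {"a": <BatchSampler ...>, "b": <BatchSampler ...>}, one entry per DataLoader.
batch_samplers = apply_to_collection(loaders, DataLoader, getattr, "batch_sampler", None)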
57 changes: 48 additions & 9 deletions tests/tests_pytorch/loops/test_evaluation_loop.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
- from unittest.mock import Mock
+ from unittest.mock import call, Mock

import torch
from torch.utils.data.dataloader import DataLoader
- from torch.utils.data.sampler import RandomSampler
+ from torch.utils.data.sampler import BatchSampler, RandomSampler

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
@@ -44,9 +44,8 @@ def test_on_evaluation_epoch_end(eval_epoch_end_mock, tmpdir):
assert eval_epoch_end_mock.call_count == 4


- def test_set_epoch_called_eval_predict(tmpdir):
- """Tests that set_epoch (if the sampler has one) is called on the DataLoader during evaluation and
- prediction."""
+ def test_evaluation_loop_sampler_set_epoch_called(tmpdir):
+ """Tests that set_epoch is called on the dataloader's sampler (if any) during training and validation."""

def _get_dataloader():
dataset = RandomDataset(32, 64)
@@ -56,20 +55,60 @@ def _get_dataloader():

model = BoringModel()
trainer = Trainer(
- default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=2, enable_model_summary=False
+ default_root_dir=tmpdir,
+ limit_train_batches=1,
+ limit_val_batches=1,
+ max_epochs=2,
+ enable_model_summary=False,
+ enable_checkpointing=False,
+ logger=False,

train_dataloader = _get_dataloader()
val_dataloader = _get_dataloader()
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
# One for each epoch
assert train_dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]
# One for each epoch + sanity check
assert val_dataloader.sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]

val_dataloader = _get_dataloader()
trainer.validate(model, val_dataloader)
assert val_dataloader.sampler.set_epoch.call_args_list == [call(2)]


def test_evaluation_loop_batch_sampler_set_epoch_called(tmpdir):
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during training and validation."""

def _get_dataloader():
dataset = RandomDataset(32, 64)
sampler = RandomSampler(dataset)
batch_sampler = BatchSampler(sampler, 2, True)
batch_sampler.set_epoch = Mock()
return DataLoader(dataset, batch_sampler=batch_sampler)

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=2,
enable_model_summary=False,
enable_checkpointing=False,
logger=False,
)

train_dataloader = _get_dataloader()
val_dataloader = _get_dataloader()
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
# One for each epoch
- assert train_dataloader.sampler.set_epoch.call_count == 2
+ assert train_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(1)]
# One for each epoch + sanity check
- assert val_dataloader.sampler.set_epoch.call_count == 3
+ assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(0), call(0), call(1)]

val_dataloader = _get_dataloader()
trainer.validate(model, val_dataloader)
- assert val_dataloader.sampler.set_epoch.call_count == 1
+ assert val_dataloader.batch_sampler.set_epoch.call_args_list == [call(2)]


@mock.patch(
24 changes: 23 additions & 1 deletion tests/tests_pytorch/loops/test_utilities.py
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock

import pytest
import torch

- from pytorch_lightning.loops.utilities import _extract_hiddens, _v1_8_output_format
+ from pytorch_lightning.loops.utilities import _extract_hiddens, _set_sampler_epoch, _v1_8_output_format
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -61,3 +63,23 @@ def training_epoch_end(outputs, new_format=True):
...

assert _v1_8_output_format(training_epoch_end)


def test_set_sampler_epoch():
# No samplers
dataloader = Mock()
dataloader.sampler = None
dataloader.batch_sampler = None
_set_sampler_epoch(dataloader, 55)

# set_epoch not callable
dataloader = Mock()
dataloader.sampler.set_epoch = None
dataloader.batch_sampler.set_epoch = None
_set_sampler_epoch(dataloader, 55)

# set_epoch callable
dataloader = Mock()
_set_sampler_epoch(dataloader, 55)
dataloader.sampler.set_epoch.assert_called_once_with(55)
dataloader.batch_sampler.set_epoch.assert_called_once_with(55)
