Avoid raising the sampler warning if num_replicas=1 (#14097)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: otaj <[email protected]>
3 people authored Aug 12, 2022
1 parent 807f9d8 commit c8e22b4
Showing 3 changed files with 15 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -89,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))


+- Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))
+
+
 - Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
10 changes: 7 additions & 3 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -298,10 +298,14 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional

         # update docs too once this is resolved
         trainer_fn = self.trainer.state.fn
-        if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
+        if (
+            isinstance(sampler, DistributedSampler)
+            and sampler.num_replicas > 1
+            and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING)
+        ):
             rank_zero_warn(
-                f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
-                " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
+                f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`, it is"
+                " recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets evaluated"
                 " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
                 " some samples to make sure all devices have same batch size in case of uneven inputs.",
                 category=PossibleUserWarning,
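To illustrate the condition this commit relaxes, here is a minimal sketch, assuming only `torch` (it is not part of the commit): with `num_replicas=1` a `DistributedSampler` yields every index exactly once, so the duplication the warning describes cannot occur, while with more than one replica and an uneven dataset the sampler pads by repeating indices.

```python
# Minimal sketch: why the warning is skipped when num_replicas == 1.
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))

# One replica: all 10 indices, each exactly once, so no warning is needed.
single = list(DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False))
assert sorted(single) == list(range(10))

# Two replicas over 9 samples: the sampler pads to an even split by
# repeating an index, so one sample is evaluated twice. This is the
# situation the warning guards against.
uneven = TensorDataset(torch.arange(9))
rank0 = list(DistributedSampler(uneven, num_replicas=2, rank=0, shuffle=False))
rank1 = list(DistributedSampler(uneven, num_replicas=2, rank=1, shuffle=False))
assert len(rank0) == len(rank1) == 5   # 10 indices are drawn for 9 samples
assert len(set(rank0 + rank1)) == 9    # one index appears twice
```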
9 changes: 5 additions & 4 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -526,19 +526,20 @@ def test_invalid_hook_passed_in_datahook_selector():
         dh_selector.get_instance("setup")


-def test_eval_distributed_sampler_warning(tmpdir):
+@pytest.mark.parametrize("devices, warn_context", [(1, no_warning_call), (2, pytest.warns)])
+def test_eval_distributed_sampler_warning(devices, warn_context):
     """Test that a warning is raised when `DistributedSampler` is used with evaluation."""

     model = BoringModel()
-    trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
+    trainer = Trainer(strategy="ddp", devices=devices, accelerator="cpu")
     trainer._data_connector.attach_data(model)

     trainer.state.fn = TrainerFn.VALIDATING
-    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+    with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_val_dataloader(model)

     trainer.state.fn = TrainerFn.TESTING
-    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+    with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_test_dataloader(model)


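The parametrized test passes either `pytest.warns` or a `no_warning_call` helper through the `warn_context` parameter. The real helper lives in Lightning's test utilities and may differ; as a rough, assumption-laden sketch, such a helper can mirror the `pytest.warns` signature while asserting that no matching warning fires:

```python
# Hedged sketch of a no_warning_call-style context manager, not the actual
# Lightning implementation. It accepts the same (category, match) arguments
# as pytest.warns but raises if a matching warning is emitted.
import re
import warnings
from contextlib import contextmanager

@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        if issubclass(w.category, expected_warning) and (
            match is None or re.search(match, str(w.message))
        ):
            raise AssertionError(f"Unexpected warning raised: {w.message}")
```

With a helper like this, the parametrization `(1, no_warning_call)` / `(2, pytest.warns)` exercises both the silenced single-replica path and the warning path with a single test body.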
