Avoid raising the sampler warning if num_replicas=1 #14097

Merged · 12 commits · Aug 12, 2022
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -79,6 +79,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))


- Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))


- Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))


10 changes: 7 additions & 3 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -298,10 +298,14 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional

         # update docs too once this is resolved
         trainer_fn = self.trainer.state.fn
-        if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
+        if (
+            isinstance(sampler, DistributedSampler)
+            and sampler.num_replicas > 1
+            and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING)
+        ):
             rank_zero_warn(
-                f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
-                " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
+                f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`, it is"
+                " recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets evaluated"
                 " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
                 " some samples to make sure all devices have same batch size in case of uneven inputs.",
                 category=PossibleUserWarning,
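
The new `sampler.num_replicas > 1` guard is what silences the warning for single-device runs: a `DistributedSampler` with one replica visits every sample exactly once, so the caveat about replicated samples does not apply. The snippet below is a minimal standalone sketch (not part of this PR; it passes `num_replicas`/`rank` explicitly so no process group is needed) of the behaviour the warning describes:

```python
from torch.utils.data import DistributedSampler

dataset = list(range(5))  # 5 samples: uneven when split across 2 replicas

# One replica: every index appears exactly once, so there is nothing to warn about.
single = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
print(list(single))  # [0, 1, 2, 3, 4]

# Two replicas: the sampler pads to an even split by repeating a sample.
rank0 = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
rank1 = DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=False)
print(list(rank0))  # [0, 2, 4]
print(list(rank1))  # [1, 3, 0] -- sample 0 is evaluated twice across ranks
```
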
@@ -526,19 +526,20 @@ def test_invalid_hook_passed_in_datahook_selector():
         dh_selector.get_instance("setup")


-def test_eval_distributed_sampler_warning(tmpdir):
+@pytest.mark.parametrize("devices, warn_context", [(1, no_warning_call), (2, pytest.warns)])
+def test_eval_distributed_sampler_warning(devices, warn_context):
     """Test that a warning is raised when `DistributedSampler` is used with evaluation."""

     model = BoringModel()
-    trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
+    trainer = Trainer(strategy="ddp", devices=devices, accelerator="cpu")
     trainer._data_connector.attach_data(model)

     trainer.state.fn = TrainerFn.VALIDATING
-    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+    with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_val_dataloader(model)

     trainer.state.fn = TrainerFn.TESTING
-    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+    with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_test_dataloader(model)


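The test is now parametrized over both cases: with `devices=2` the warning must be raised (`pytest.warns`), while with `devices=1` it must not (`no_warning_call` from Lightning's test helpers). As a rough sketch only (an assumption about the idea, not Lightning's actual helper), a `no_warning_call`-style context manager can be written like this:

```python
# Hedged sketch of a `no_warning_call`-style helper; the real helper lives in
# Lightning's test utilities and may differ in its details.
import re
import warnings
from contextlib import contextmanager


@contextmanager
def no_warning_call(expected_warning=Warning, match=None):
    """Fail if a warning of `expected_warning` whose message matches `match` was raised."""
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for warning in record:
        if issubclass(warning.category, expected_warning) and (
            match is None or re.search(match, str(warning.message))
        ):
            raise AssertionError(f"Unexpected warning raised: {warning.message}")
```

Because both helpers accept the warning category plus a `match` keyword, the same `with warn_context(PossibleUserWarning, match=...)` block works for either parametrization.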