diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 714d4340f1ba1..43a3e69ca160b 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -79,6 +79,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed epoch-end logging results not being reset after the end of the epoch ([#14061](https://github.com/Lightning-AI/lightning/pull/14061))
 
+- Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))
+
+
 - Fixed saving hyperparameters in a composition where the parent class is not a `LightningModule` or `LightningDataModule` ([#14151](https://github.com/Lightning-AI/lightning/pull/14151))
 
 
diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py
index 1de8bee90d18f..6e592b9f6d310 100644
--- a/src/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -298,10 +298,14 @@ def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional
         # update docs too once this is resolved
         trainer_fn = self.trainer.state.fn
-        if isinstance(sampler, DistributedSampler) and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING):
+        if (
+            isinstance(sampler, DistributedSampler)
+            and sampler.num_replicas > 1
+            and trainer_fn in (TrainerFn.VALIDATING, TrainerFn.TESTING)
+        ):
             rank_zero_warn(
-                f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`,"
-                " it is recommended to use `Trainer(devices=1)` to ensure each sample/batch gets evaluated"
+                f"Using `DistributedSampler` with the dataloaders. During `trainer.{trainer_fn.value}()`, it is"
+                " recommended to use `Trainer(devices=1, num_nodes=1)` to ensure each sample/batch gets evaluated"
                 " exactly once. Otherwise, multi-device settings use `DistributedSampler` that replicates"
                 " some samples to make sure all devices have same batch size in case of uneven inputs.",
                 category=PossibleUserWarning,
             )
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 7273d7719834e..379a3248a1535 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -526,19 +526,20 @@ def test_invalid_hook_passed_in_datahook_selector():
         dh_selector.get_instance("setup")
 
 
-def test_eval_distributed_sampler_warning(tmpdir):
+@pytest.mark.parametrize("devices, warn_context", [(1, no_warning_call), (2, pytest.warns)])
+def test_eval_distributed_sampler_warning(devices, warn_context):
     """Test that a warning is raised when `DistributedSampler` is used with evaluation."""
 
     model = BoringModel()
-    trainer = Trainer(strategy="ddp", devices=2, accelerator="cpu", fast_dev_run=True)
+    trainer = Trainer(strategy="ddp", devices=devices, accelerator="cpu")
     trainer._data_connector.attach_data(model)
 
     trainer.state.fn = TrainerFn.VALIDATING
-    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+    with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_val_dataloader(model)
 
     trainer.state.fn = TrainerFn.TESTING
-    with pytest.warns(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
+    with warn_context(PossibleUserWarning, match="multi-device settings use `DistributedSampler`"):
         trainer.reset_test_dataloader(model)
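
Note: the following standalone sketch (plain PyTorch, no Lightning `Trainer` involved) illustrates why the new `sampler.num_replicas > 1` guard makes sense: a `DistributedSampler` constructed with a single replica yields every sample exactly once, so the `PossibleUserWarning` above would be pure noise, whereas two or more replicas shard (and pad) the dataset, which is the case the warning targets. The dataset and variable names here are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(8, 2))

# num_replicas=1: the single "shard" is the whole dataset, so every sample is
# evaluated exactly once -- the situation the patch stops warning about.
single = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
assert len(list(DataLoader(dataset, sampler=single))) == len(dataset)

# num_replicas=2: each rank only sees its own shard (padded to equal size),
# which is why the warning is still raised for multi-device evaluation.
sharded = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
assert len(list(DataLoader(dataset, sampler=sharded))) == len(dataset) // 2
```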