diff --git a/src/lightning_lite/CHANGELOG.md b/src/lightning_lite/CHANGELOG.md
index 62a0448315aaa..1cde4d79f4e6f 100644
--- a/src/lightning_lite/CHANGELOG.md
+++ b/src/lightning_lite/CHANGELOG.md
@@ -42,7 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed `shuffle=False` having no effect when using DDP/DistributedSampler ([#15931](https://github.com/Lightning-AI/lightning/issues/15931))
+
 
 ## [1.8.3] - 2022-11-22
diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py
index 5cd33f7c1baea..917cc48e2ff93 100644
--- a/src/lightning_lite/lite.py
+++ b/src/lightning_lite/lite.py
@@ -25,7 +25,7 @@
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler
 
 from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_lite.accelerators.accelerator import Accelerator
@@ -582,6 +582,7 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
 
     @staticmethod
     def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler:
+        kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler))
         kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
         return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
 
diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py
index ba28837ff515a..a14e1214adcfd 100644
--- a/tests/tests_lite/test_lite.py
+++ b/tests/tests_lite/test_lite.py
@@ -23,7 +23,7 @@
 from tests_lite.helpers.runif import RunIf
 from tests_lite.helpers.utils import no_warning_call
 from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler, Sampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset
 
 from lightning_lite.lite import LightningLite
 from lightning_lite.plugins import Precision
@@ -40,7 +40,7 @@
 from lightning_lite.strategies.strategy import _Sharded
 from lightning_lite.utilities import _StrategyType
 from lightning_lite.utilities.exceptions import MisconfigurationException
-from lightning_lite.utilities.seed import pl_worker_init_function
+from lightning_lite.utilities.seed import pl_worker_init_function, seed_everything
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -384,6 +384,32 @@ def test_setup_dataloaders_distributed_sampler_not_needed():
     assert lite_dataloader.sampler is custom_sampler
 
 
+def test_setup_dataloaders_distributed_sampler_shuffle():
+    """Test that the DataLoader(shuffle=True|False) setting gets carried over correctly into the distributed
+    sampler."""
+    lite = LightningLite(accelerator="cpu", strategy="ddp_spawn", devices=2)
+    # no lite.launch(): pretend we are on rank 0 now
+
+    dataset = TensorDataset(torch.arange(8))
+
+    # shuffling turned off
+    no_shuffle_dataloaders = [
+        DataLoader(dataset),
+        DataLoader(dataset, shuffle=False),
+        DataLoader(dataset, sampler=SequentialSampler(dataset)),
+    ]
+    for dataloader in no_shuffle_dataloaders:
+        dataloader = lite.setup_dataloaders(dataloader)
+        assert list(t[0].item() for t in iter(dataloader)) == [0, 2, 4, 6]
+
+    # shuffling turned on
+    shuffle_dataloaders = [DataLoader(dataset, shuffle=True), DataLoader(dataset, sampler=RandomSampler(dataset))]
+    for dataloader in shuffle_dataloaders:
+        seed_everything(1)
+        dataloader = lite.setup_dataloaders(dataloader)
+        assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1]
+
+
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_seed_everything():
     """Test that seed everything is static and sets the worker init function on the dataloader."""
diff --git a/tests/tests_lite/test_parity.py b/tests/tests_lite/test_parity.py
index b74a23438d0d6..52d602ca366d7 100644
--- a/tests/tests_lite/test_parity.py
+++ b/tests/tests_lite/test_parity.py
@@ -203,7 +203,7 @@ def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator,
 )
 def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir):
     LightningLite.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 4))
+    train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True)
     model = BoringModel()
     num_epochs = 1
     state_dict = deepcopy(model.state_dict())
@@ -214,13 +214,13 @@ def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir
     lite_model_state_dict = model.state_dict()
 
     for w_pure, w_lite in zip(state_dict.values(), lite_model_state_dict.values()):
-        assert not torch.equal(w_pure.cpu(), w_lite.cpu())
+        assert not torch.allclose(w_pure.cpu(), w_lite.cpu())
 
     LightningLite.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 4))
+    train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True)
     model = BoringModel()
     run(lite.global_rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir)
     pure_model_state_dict = model.state_dict()
 
     for w_pure, w_lite in zip(pure_model_state_dict.values(), lite_model_state_dict.values()):
-        assert torch.equal(w_pure.cpu(), w_lite.cpu())
+        torch.testing.assert_close(w_pure.cpu(), w_lite.cpu())