Lite: Fix DataLoader shuffling when using DistributedSampler (#15931)
awaelchli and pre-commit-ci[bot] authored Dec 8, 2022
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 904323b commit 3004f13
Showing 4 changed files with 36 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/lightning_lite/CHANGELOG.md
@@ -42,7 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

--
+- Fixed `shuffle=False` having no effect when using DDP/DistributedSampler ([#15931](https://github.com/Lightning-AI/lightning/issues/15931))
+


 ## [1.8.3] - 2022-11-22
3 changes: 2 additions & 1 deletion src/lightning_lite/lite.py
@@ -25,7 +25,7 @@
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler

 from lightning_lite.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_lite.accelerators.accelerator import Accelerator
@@ -582,6 +582,7 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:

     @staticmethod
     def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler:
+        kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler))
         kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
         return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
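Note: the one-line fix works because of how PyTorch builds a DataLoader's default sampler: passing shuffle=True creates a RandomSampler, while the default shuffle=False creates a SequentialSampler. Since torch.utils.data.DistributedSampler itself defaults to shuffle=True, omitting this setdefault meant every wrapped dataloader was shuffled regardless of the user's setting. A minimal sketch of the detection logic, in plain PyTorch with no Lite involved:

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))

# DataLoader(shuffle=True) builds a RandomSampler under the hood ...
assert isinstance(DataLoader(dataset, shuffle=True).sampler, RandomSampler)
# ... while the default (shuffle=False) builds a SequentialSampler,
# so the sampler type recovers the user's shuffle intent after the fact.
assert isinstance(DataLoader(dataset).sampler, SequentialSampler)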
30 changes: 28 additions & 2 deletions tests/tests_lite/test_lite.py
@@ -23,7 +23,7 @@
 from tests_lite.helpers.runif import RunIf
 from tests_lite.helpers.utils import no_warning_call
 from torch import nn
-from torch.utils.data import DataLoader, DistributedSampler, Sampler
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset

 from lightning_lite.lite import LightningLite
 from lightning_lite.plugins import Precision
@@ -40,7 +40,7 @@
 from lightning_lite.strategies.strategy import _Sharded
 from lightning_lite.utilities import _StrategyType
 from lightning_lite.utilities.exceptions import MisconfigurationException
-from lightning_lite.utilities.seed import pl_worker_init_function
+from lightning_lite.utilities.seed import pl_worker_init_function, seed_everything
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer

@@ -384,6 +384,32 @@ def test_setup_dataloaders_distributed_sampler_not_needed():
     assert lite_dataloader.sampler is custom_sampler


+def test_setup_dataloaders_distributed_sampler_shuffle():
+    """Test that the DataLoader(shuffle=True|False) setting gets carried over correctly into the distributed
+    sampler."""
+    lite = LightningLite(accelerator="cpu", strategy="ddp_spawn", devices=2)
+    # no lite.launch(): pretend we are on rank 0 now
+
+    dataset = TensorDataset(torch.arange(8))
+
+    # shuffling turned off
+    no_shuffle_dataloaders = [
+        DataLoader(dataset),
+        DataLoader(dataset, shuffle=False),
+        DataLoader(dataset, sampler=SequentialSampler(dataset)),
+    ]
+    for dataloader in no_shuffle_dataloaders:
+        dataloader = lite.setup_dataloaders(dataloader)
+        assert list(t[0].item() for t in iter(dataloader)) == [0, 2, 4, 6]
+
+    # shuffling turned on
+    shuffle_dataloaders = [DataLoader(dataset, shuffle=True), DataLoader(dataset, sampler=RandomSampler(dataset))]
+    for dataloader in shuffle_dataloaders:
+        seed_everything(1)
+        dataloader = lite.setup_dataloaders(dataloader)
+        assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1]
+
+
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_seed_everything():
     """Test that seed everything is static and sets the worker init function on the dataloader."""
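Note: the expected indices in the unshuffled case follow directly from how DistributedSampler partitions a dataset: with shuffle=False it deals indices out round-robin across replicas, so with two devices rank 0 receives the even positions. A small sketch of that behavior, runnable without a process group because num_replicas and rank are passed explicitly:

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))

# With shuffle=False, indices are dealt round-robin across replicas:
# rank 0 sees [0, 2, 4, 6] and rank 1 sees [1, 3, 5, 7].
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
assert list(sampler) == [0, 2, 4, 6]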
8 changes: 4 additions & 4 deletions tests/tests_lite/test_parity.py
@@ -203,7 +203,7 @@ def test_boring_lite_model_ddp_spawn(precision, strategy, devices, accelerator,
 )
 def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir):
     LightningLite.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 4))
+    train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True)
     model = BoringModel()
     num_epochs = 1
     state_dict = deepcopy(model.state_dict())
@@ -214,13 +214,13 @@ def test_boring_lite_model_ddp(precision, strategy, devices, accelerator, tmpdir
     lite_model_state_dict = model.state_dict()

     for w_pure, w_lite in zip(state_dict.values(), lite_model_state_dict.values()):
-        assert not torch.equal(w_pure.cpu(), w_lite.cpu())
+        assert not torch.allclose(w_pure.cpu(), w_lite.cpu())

     LightningLite.seed_everything(42)
-    train_dataloader = DataLoader(RandomDataset(32, 4))
+    train_dataloader = DataLoader(RandomDataset(32, 4), shuffle=True)
     model = BoringModel()
     run(lite.global_rank, model, train_dataloader, num_epochs, precision, accelerator, tmpdir)
     pure_model_state_dict = model.state_dict()

     for w_pure, w_lite in zip(pure_model_state_dict.values(), lite_model_state_dict.values()):
-        assert torch.equal(w_pure.cpu(), w_lite.cpu())
+        torch.testing.assert_close(w_pure.cpu(), w_lite.cpu())
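Note: the parity test also swaps exact comparisons for tolerance-based ones, presumably because enabling real shuffling makes bit-exact equality between the two training runs fragile. torch.equal demands bitwise-identical tensors, whereas torch.allclose and torch.testing.assert_close permit small floating-point drift (and assert_close raises a descriptive error instead of returning False); the negated check also becomes stronger, since trained weights must now differ from the initial ones by more than the tolerance. A quick illustration of the difference:

import torch

a = torch.tensor([1.0, 2.0])
b = a + 1e-7  # tiny drift, e.g. from a different reduction order

assert not torch.equal(a, b)      # exact comparison fails on any drift
assert torch.allclose(a, b)       # tolerance-based check passes
torch.testing.assert_close(a, b)  # passes too; raises a diagnostic error on real mismatches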
