From 7bc55c38a72a3ee0a7cfa101b5a3bbb31aafc648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 19 Dec 2022 22:57:15 +0100 Subject: [PATCH] Distributed sampling parity between Lite and PyTorch (#16101) --- src/lightning_lite/CHANGELOG.md | 5 ++- src/lightning_lite/lite.py | 4 +- src/lightning_lite/utilities/distributed.py | 5 +++ src/lightning_lite/wrappers.py | 8 ++++ tests/tests_lite/test_lite.py | 41 ++++++++++++++++++++- tests/tests_lite/test_wrappers.py | 23 +++++++++++- 6 files changed, 82 insertions(+), 4 deletions(-) diff --git a/src/lightning_lite/CHANGELOG.md b/src/lightning_lite/CHANGELOG.md index d331abc41c3ca..7353f226c3541 100644 --- a/src/lightning_lite/CHANGELOG.md +++ b/src/lightning_lite/CHANGELOG.md @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Merged the implementation of `DDPSpawnStrategy` into `DDPStrategy` and removed `DDPSpawnStrategy` ([#14952](https://github.com/Lightning-AI/lightning/issues/14952)) +- The dataloader wrapper returned from `.setup_dataloaders()` now calls `.set_epoch()` on the distributed sampler if one is used ([#16101](https://github.com/Lightning-AI/lightning/issues/16101)) + ### Deprecated @@ -46,7 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Restored sampling parity between PyTorch and Lite dataloaders when using the `DistributedSampler` ([#16101](https://github.com/Lightning-AI/lightning/issues/16101)) + ## [1.8.4] - 2022-12-08 diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index b7ba6acf52d17..163791cb1959f 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -25,7 +25,7 @@ from lightning_utilities.core.rank_zero import rank_zero_warn from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler from lightning_lite.plugins import Precision # avoid circular imports: # isort: split from lightning_lite.accelerators.accelerator import Accelerator @@ -583,6 +583,8 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler: kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler)) kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) + if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)): + return DistributedSampler(dataloader.dataset, **kwargs) return DistributedSamplerWrapper(dataloader.sampler, **kwargs) def _prepare_run_method(self) -> None: diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index b70c56360126a..b80b18f42e71e 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -294,6 +294,11 @@ class DistributedSamplerWrapper(DistributedSampler): Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed mode if sampler replacement is enabled. + + Note: + The purpose of this wrapper is to take care of sharding the sampler indices. It is up to the underlying + sampler to handle randomness and shuffling. The ``shuffle`` and ``seed`` arguments on this wrapper won't + have any effect. """ def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None: diff --git a/src/lightning_lite/wrappers.py b/src/lightning_lite/wrappers.py index 03e84ca4659d6..e3d3b7ae8d629 100644 --- a/src/lightning_lite/wrappers.py +++ b/src/lightning_lite/wrappers.py @@ -151,6 +151,7 @@ def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None self.__dict__.update(dataloader.__dict__) self._dataloader = dataloader self._device = device + self._num_iter_calls = 0 @property def device(self) -> Optional[torch.device]: @@ -160,6 +161,13 @@ def __len__(self) -> int: return len(self._dataloader) def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: + if hasattr(self._dataloader.sampler, "set_epoch"): + # Without setting the epoch, the distributed sampler would return the same indices every time, even when + # shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler. + # In Lite, we take care of this boilerplate code. + self._dataloader.sampler.set_epoch(self._num_iter_calls) + self._num_iter_calls += 1 + iterator = iter(self._dataloader) if self._device is None: yield from iterator diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 28e2c926e7861..3314860e1c9b0 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -405,7 +405,46 @@ def test_setup_dataloaders_distributed_sampler_shuffle(): for dataloader in shuffle_dataloaders: seed_everything(1) dataloader = lite.setup_dataloaders(dataloader) - assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1] + assert list(t[0].item() for t in iter(dataloader)) == [5, 2, 7, 1] + + +@pytest.mark.parametrize("shuffle", [True, False]) +@pytest.mark.parametrize("batch_size", [1, 2, 3]) +def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size): + """Test that the distributed sampler setup in Lite leads to the same sequence of data as in raw PyTorch.""" + torch.manual_seed(1) + lite = LightningLite(accelerator="cpu", strategy="ddp", devices=2) + # no lite.launch(): pretend we are on rank 0 now + + dataset = torch.arange(10) + torch_dataloader = DataLoader( + dataset, + sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle), + batch_size=batch_size, + ) + lite_dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size) + lite_dataloader = lite.setup_dataloaders(lite_dataloader) + + def fetch_epoch(loader): + iterator = iter(loader) + # we fetch 2 batches per epoch + return torch.cat((next(iterator), next(iterator))) + + # 1st epoch + # PyTorch users needs to set the epoch, while in Lite it gets handled automatically + torch_dataloader.sampler.set_epoch(0) + torch_data = fetch_epoch(torch_dataloader) + lite_data = fetch_epoch(lite_dataloader) + assert torch.equal(torch_data, lite_data) + + # 2nd epoch + # PyTorch users needs to set the epoch, while in Lite it gets handled automatically + torch_dataloader.sampler.set_epoch(1) + torch_data = fetch_epoch(torch_dataloader) + lite_data = fetch_epoch(lite_dataloader) + assert torch.equal(torch_data, lite_data) + assert torch_dataloader.sampler.epoch == 1 + assert lite_dataloader._dataloader.sampler.epoch == 1 @mock.patch.dict(os.environ, {}, clear=True) diff --git a/tests/tests_lite/test_wrappers.py b/tests/tests_lite/test_wrappers.py index 3e529b63425b4..15e97614a7176 100644 --- a/tests/tests_lite/test_wrappers.py +++ b/tests/tests_lite/test_wrappers.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import call, Mock import pytest import torch from tests_lite.helpers.runif import RunIf +from torch.utils.data import DistributedSampler from torch.utils.data.dataloader import DataLoader from lightning_lite.lite import LightningLite @@ -230,6 +231,26 @@ def test_lite_dataloader_device_placement(src_device_str, dest_device_str): assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device)) +def test_lite_dataloader_distributed_sampler_set_epoch(): + """Test that the LiteDataLoader calls `set_epoch()` on the wrapped sampler if applicable.""" + sampler = DistributedSampler(range(3), num_replicas=2, rank=0) + sampler.set_epoch = Mock() + dataloader = DataLoader(range(3), sampler=sampler) + lite_dataloader = _LiteDataLoader(dataloader) + iterator_epoch_0 = iter(lite_dataloader) + dataloader.sampler.set_epoch.assert_not_called() + next(iterator_epoch_0) + # .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader + assert dataloader.sampler.set_epoch.call_args_list == [call(0)] + next(iterator_epoch_0) + assert dataloader.sampler.set_epoch.call_args_list == [call(0)] + iterator_epoch_1 = iter(lite_dataloader) + assert dataloader.sampler.set_epoch.call_args_list == [call(0)] + next(iterator_epoch_1) + # with every new iterator call, the epoch increases + assert dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)] + + def test_lite_optimizer_wraps(): """Test that the LiteOptimizer fully wraps the optimizer.""" optimizer_cls = torch.optim.SGD