Distributed sampling parity between Lite and PyTorch #16101

Merged: 14 commits, Dec 19, 2022
5 changes: 4 additions & 1 deletion src/lightning_lite/CHANGELOG.md
@@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Merged the implementation of `DDPSpawnStrategy` into `DDPStrategy` and removed `DDPSpawnStrategy` ([#14952](https://github.com/Lightning-AI/lightning/issues/14952))


- The dataloader wrapper returned from `.setup_dataloaders()` now calls `.set_epoch()` on the distributed sampler if one is used ([#16101](https://github.com/Lightning-AI/lightning/issues/16101))


### Deprecated

@@ -46,7 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Restored sampling parity between PyTorch and Lite dataloaders when using the `DistributedSampler` ([#16101](https://github.com/Lightning-AI/lightning/issues/16101))



## [1.8.4] - 2022-12-08
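The Changed and Fixed entries above boil down to one behavior: a shuffling `DistributedSampler` repeats the same index order until its epoch is advanced, and Lite now advances it automatically. A minimal, runnable sketch of that behavior (single process with explicit `num_replicas`/`rank`, so no process group is needed; the dataset and numbers are illustrative only):

```python
# Why set_epoch() matters: without it, a shuffling DistributedSampler yields the same
# index order on every pass. Lite's wrapped dataloader now calls set_epoch() on each
# new iterator, so users no longer have to do this by hand.
import torch
from torch.utils.data import DistributedSampler

dataset = torch.arange(8)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)

epoch_0 = list(sampler)
epoch_0_again = list(sampler)     # no set_epoch() in between: identical order
assert epoch_0 == epoch_0_again

sampler.set_epoch(1)              # what Lite now does automatically per epoch
epoch_1 = list(sampler)
print(epoch_0, epoch_1)           # typically a different order once the epoch changes
```

In raw PyTorch the user keeps calling `sampler.set_epoch(epoch)` at the start of every epoch; the parity test added below checks that Lite's automatic handling produces exactly the same sequence of samples.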
4 changes: 3 additions & 1 deletion src/lightning_lite/lite.py
@@ -25,7 +25,7 @@
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from lightning_lite.plugins import Precision # avoid circular imports: # isort: split
from lightning_lite.accelerators.accelerator import Accelerator
@@ -583,6 +583,8 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler:
kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler))
kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)):
return DistributedSampler(dataloader.dataset, **kwargs)
return DistributedSamplerWrapper(dataloader.sampler, **kwargs)

def _prepare_run_method(self) -> None:
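The `shuffle` default above is inferred from the sampler type because `DataLoader` itself translates its `shuffle` flag into a sampler. A small, runnable check of that assumption (plain PyTorch, nothing Lite-specific):

```python
# DataLoader(shuffle=True) installs a RandomSampler and DataLoader(shuffle=False) a
# SequentialSampler, so inspecting the sampler type recovers the user's shuffle intent.
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

data = list(range(10))
assert isinstance(DataLoader(data, shuffle=True).sampler, RandomSampler)
assert isinstance(DataLoader(data, shuffle=False).sampler, SequentialSampler)
```

The new branch then keeps exact PyTorch parity for these two default samplers by substituting a plain `DistributedSampler`, while any custom sampler still goes through `DistributedSamplerWrapper`.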
5 changes: 5 additions & 0 deletions src/lightning_lite/utilities/distributed.py
@@ -294,6 +294,11 @@ class DistributedSamplerWrapper(DistributedSampler):

Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed mode if
sampler replacement is enabled.

Note:
The purpose of this wrapper is to take care of sharding the sampler indices. It is up to the underlying
sampler to handle randomness and shuffling. The ``shuffle`` and ``seed`` arguments on this wrapper won't
have any effect.
"""

def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None:
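A hedged sketch of the Note above: the wrapper shards whatever the inner sampler produces and leaves ordering to that sampler. `ReversedSampler` is made up for illustration, and the example assumes `num_replicas`/`rank` are forwarded to the underlying `DistributedSampler`, as the call site in `lite.py` above suggests:

```python
# The wrapper only splits the inner sampler's output across ranks; randomness and
# ordering remain the inner sampler's responsibility.
from torch.utils.data import Sampler

from lightning_lite.utilities.distributed import DistributedSamplerWrapper


class ReversedSampler(Sampler):
    """Deterministic toy sampler that yields indices in reverse order."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __iter__(self):
        return iter(range(self.size - 1, -1, -1))

    def __len__(self) -> int:
        return self.size


inner = ReversedSampler(8)
rank_0 = list(DistributedSamplerWrapper(inner, num_replicas=2, rank=0))
rank_1 = list(DistributedSamplerWrapper(inner, num_replicas=2, rank=1))
assert len(rank_0) == len(rank_1) == 4
assert sorted(rank_0 + rank_1) == list(range(8))  # together the ranks cover every index once
```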
8 changes: 8 additions & 0 deletions src/lightning_lite/wrappers.py
@@ -151,6 +151,7 @@ def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None
self.__dict__.update(dataloader.__dict__)
self._dataloader = dataloader
self._device = device
self._epoch = 0

@property
def device(self) -> Optional[torch.device]:
@@ -160,6 +161,13 @@ def __len__(self) -> int:
return len(self._dataloader)

def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
if hasattr(self._dataloader.sampler, "set_epoch"):
# Without setting the epoch, the distributed sampler would return the same indices every time, even when
# shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
# In Lite, we take care of this boilerplate code.
self._dataloader.sampler.set_epoch(self._epoch)
self._epoch += 1

iterator = iter(self._dataloader)
if self._device is None:
yield from iterator
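Because the hook above only checks `hasattr(sampler, "set_epoch")`, any sampler exposing that method benefits, not just `torch`'s `DistributedSampler`. A standalone sketch of that duck typing (users normally obtain the wrapper from `setup_dataloaders()` rather than constructing `_LiteDataLoader` directly; `EpochAwareSampler` is hypothetical):

```python
# Each call to iter() on the Lite wrapper invokes set_epoch(n) with an increasing n
# before the first batch is fetched, so an epoch-aware sampler reorders automatically.
from torch.utils.data import DataLoader, Sampler

from lightning_lite.wrappers import _LiteDataLoader


class EpochAwareSampler(Sampler):
    """Toy sampler that rotates its index order by the current epoch."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def __iter__(self):
        return iter((i + self.epoch) % self.size for i in range(self.size))

    def __len__(self) -> int:
        return self.size


loader = _LiteDataLoader(DataLoader(range(4), sampler=EpochAwareSampler(4)))
print([batch.item() for batch in loader])  # epoch 0: [0, 1, 2, 3]
print([batch.item() for batch in loader])  # epoch 1: [1, 2, 3, 0], no manual set_epoch()
```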
41 changes: 40 additions & 1 deletion tests/tests_lite/test_lite.py
@@ -405,7 +405,46 @@ def test_setup_dataloaders_distributed_sampler_shuffle():
for dataloader in shuffle_dataloaders:
seed_everything(1)
dataloader = lite.setup_dataloaders(dataloader)
assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1]
assert list(t[0].item() for t in iter(dataloader)) == [5, 2, 7, 1]


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("batch_size", [1, 2, 3])
def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size):
"""Test that the distributed sampler setup in Lite leads to the same sequence of data as in raw PyTorch."""
torch.manual_seed(1)
lite = LightningLite(accelerator="cpu", strategy="ddp", devices=2)
# no lite.launch(): pretend we are on rank 0 now

dataset = torch.rand(10, 1)
torch_dataloader = DataLoader(
dataset,
sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle),
batch_size=batch_size,
)
lite_dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)
lite_dataloader = lite.setup_dataloaders(lite_dataloader)

def fetch_epoch(loader):
iterator = iter(loader)
# we fetch 2 batches per epoch
return torch.cat([next(iterator).flatten() for _ in range(2)])

# 1st epoch
# PyTorch users need to set the epoch, while in Lite it gets handled automatically
torch_dataloader.sampler.set_epoch(0)
torch_data = fetch_epoch(torch_dataloader)
lite_data = fetch_epoch(lite_dataloader)
assert torch.equal(torch_data, lite_data)

# 2nd epoch
# PyTorch users need to set the epoch, while in Lite it gets handled automatically
torch_dataloader.sampler.set_epoch(1)
torch_data = fetch_epoch(torch_dataloader)
lite_data = fetch_epoch(lite_dataloader)
assert torch.equal(torch_data, lite_data)
assert torch_dataloader.sampler.epoch == 1
assert lite_dataloader._dataloader.sampler.epoch == 1


@mock.patch.dict(os.environ, {}, clear=True)
23 changes: 22 additions & 1 deletion tests/tests_lite/test_wrappers.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from unittest.mock import call, Mock

import pytest
import torch
from tests_lite.helpers.runif import RunIf
from torch.utils.data import DistributedSampler
from torch.utils.data.dataloader import DataLoader

from lightning_lite.lite import LightningLite
@@ -230,6 +231,26 @@ def test_lite_dataloader_device_placement(src_device_str, dest_device_str):
assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))


def test_lite_dataloader_distributed_sampler_set_epoch():
"""Test that the LiteDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
sampler = DistributedSampler(range(3), num_replicas=2, rank=0)
sampler.set_epoch = Mock()
dataloader = DataLoader(range(3), sampler=sampler)
lite_dataloader = _LiteDataLoader(dataloader)
iterator_epoch_0 = iter(lite_dataloader)
dataloader.sampler.set_epoch.assert_not_called()
next(iterator_epoch_0)
# .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
next(iterator_epoch_0)
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
iterator_epoch_1 = iter(lite_dataloader)
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
next(iterator_epoch_1)
# with every new iterator call, the epoch increases
assert dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]


def test_lite_optimizer_wraps():
"""Test that the LiteOptimizer fully wraps the optimizer."""
optimizer_cls = torch.optim.SGD