Distributed sampling parity between Lite and PyTorch (#16101)
awaelchli authored and carmocca committed Dec 20, 2022
1 parent a2f0640 commit 222562a
Showing 6 changed files with 81 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/lightning_fabric/CHANGELOG.md
@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Merged the implementation of `DDPSpawnStrategy` into `DDPStrategy` and removed `DDPSpawnStrategy` ([#14952](https://github.com/Lightning-AI/lightning/issues/14952))


- The dataloader wrapper returned from `.setup_dataloaders()` now calls `.set_epoch()` on the distributed sampler if one is used ([#16101](https://github.com/Lightning-AI/lightning/issues/16101))


### Deprecated

-
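
A minimal sketch of what this changelog entry means in user code (hedged: the `Fabric(accelerator="cpu", strategy="ddp", devices=2)` setup mirrors the tests further down, and the dataset and batch size are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning_fabric.fabric import Fabric

fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
dataset = TensorDataset(torch.arange(8).float())
dataloader = fabric.setup_dataloaders(DataLoader(dataset, shuffle=True, batch_size=2))

for epoch in range(3):
    # No manual `dataloader.sampler.set_epoch(epoch)` call is needed anymore:
    # the returned wrapper forwards the epoch to the DistributedSampler on each `iter()`.
    for (batch,) in dataloader:
        ...
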
4 changes: 3 additions & 1 deletion src/lightning_fabric/fabric.py
@@ -25,7 +25,7 @@
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
-from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from lightning_fabric.plugins import Precision # avoid circular imports: # isort: split
from lightning_fabric.accelerators.accelerator import Accelerator
@@ -639,6 +639,8 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler:
kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler))
kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)):
return DistributedSampler(dataloader.dataset, **kwargs)
return DistributedSamplerWrapper(dataloader.sampler, **kwargs)

def _prepare_run_method(self) -> None:
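
A hedged standalone sketch of the dispatch this fabric.py hunk introduces (the `pick_sampler` helper name is illustrative, not Fabric API): dataloaders built with the default `RandomSampler` or `SequentialSampler` get a plain `DistributedSampler`, matching what a PyTorch user would construct by hand, while any custom sampler is wrapped so that only its indices are sharded.

from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from lightning_fabric.utilities.distributed import DistributedSamplerWrapper


def pick_sampler(dataloader: DataLoader, **kwargs):
    # Mirrors the branching above: built-in samplers -> DistributedSampler, custom -> wrapper.
    if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)):
        return DistributedSampler(dataloader.dataset, **kwargs)
    return DistributedSamplerWrapper(dataloader.sampler, **kwargs)


# `shuffle=True` gives the DataLoader a RandomSampler, so the first branch applies:
loader = DataLoader(range(10), shuffle=True)
print(type(pick_sampler(loader, num_replicas=2, rank=0)).__name__)  # DistributedSampler
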
5 changes: 5 additions & 0 deletions src/lightning_fabric/utilities/distributed.py
@@ -294,6 +294,11 @@ class DistributedSamplerWrapper(DistributedSampler):
Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed mode if
sampler replacement is enabled.

Note:
The purpose of this wrapper is to take care of sharding the sampler indices. It is up to the underlying
sampler to handle randomness and shuffling. The ``shuffle`` and ``seed`` arguments on this wrapper won't
have any effect.
"""

def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None:
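
To illustrate the note above (a hedged sketch; `WeightedRandomSampler` stands in for any custom sampler): the randomness lives entirely in the inner sampler, and the wrapper only distributes whatever indices that sampler produces across ranks.

import torch
from torch.utils.data import WeightedRandomSampler

from lightning_fabric.utilities.distributed import DistributedSamplerWrapper

# The inner sampler owns shuffling/randomness; the wrapper only shards its output.
weights = torch.tensor([0.1, 0.9, 0.5, 0.5])
inner = WeightedRandomSampler(weights, num_samples=4, replacement=True)
sharded = DistributedSamplerWrapper(inner, num_replicas=2, rank=0)

print(list(sharded))  # this rank's share of the indices drawn by `inner`
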
8 changes: 8 additions & 0 deletions src/lightning_fabric/wrappers.py
@@ -151,6 +151,7 @@ def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None
self.__dict__.update(dataloader.__dict__)
self._dataloader = dataloader
self._device = device
self._num_iter_calls = 0

@property
def device(self) -> Optional[torch.device]:
@@ -160,6 +161,13 @@ def __len__(self) -> int:
return len(self._dataloader)

def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
if hasattr(self._dataloader.sampler, "set_epoch"):
# Without setting the epoch, the distributed sampler would return the same indices every time, even when
# shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
# In Lite, we take care of this boilerplate code.
self._dataloader.sampler.set_epoch(self._num_iter_calls)
self._num_iter_calls += 1

iterator = iter(self._dataloader)
if self._device is None:
yield from iterator
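
For comparison, a hedged sketch of the raw-PyTorch boilerplate this `__iter__` change absorbs (mirroring the parity test below):

import torch
from torch.utils.data import DataLoader, DistributedSampler

dataset = torch.arange(10)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(2):
    # Without this call, every epoch would yield the same shuffled order.
    sampler.set_epoch(epoch)  # what `_FabricDataLoader.__iter__` now does automatically
    for batch in loader:
        ...
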
41 changes: 40 additions & 1 deletion tests/tests_fabric/test_fabric.py
@@ -405,7 +405,46 @@ def test_setup_dataloaders_distributed_sampler_shuffle():
for dataloader in shuffle_dataloaders:
seed_everything(1)
dataloader = lite.setup_dataloaders(dataloader)
-assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1]
+assert list(t[0].item() for t in iter(dataloader)) == [5, 2, 7, 1]


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("batch_size", [1, 2, 3])
def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size):
"""Test that the distributed sampler setup in Lite leads to the same sequence of data as in raw PyTorch."""
torch.manual_seed(1)
lite = Fabric(accelerator="cpu", strategy="ddp", devices=2)
# no lite.launch(): pretend we are on rank 0 now

dataset = torch.arange(10)
torch_dataloader = DataLoader(
dataset,
sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle),
batch_size=batch_size,
)
lite_dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)
lite_dataloader = lite.setup_dataloaders(lite_dataloader)

def fetch_epoch(loader):
iterator = iter(loader)
# we fetch 2 batches per epoch
return torch.cat((next(iterator), next(iterator)))

# 1st epoch
# PyTorch users need to set the epoch, while in Lite it gets handled automatically
torch_dataloader.sampler.set_epoch(0)
torch_data = fetch_epoch(torch_dataloader)
lite_data = fetch_epoch(lite_dataloader)
assert torch.equal(torch_data, lite_data)

# 2nd epoch
# PyTorch users need to set the epoch, while in Lite it gets handled automatically
torch_dataloader.sampler.set_epoch(1)
torch_data = fetch_epoch(torch_dataloader)
lite_data = fetch_epoch(lite_dataloader)
assert torch.equal(torch_data, lite_data)
assert torch_dataloader.sampler.epoch == 1
assert lite_dataloader._dataloader.sampler.epoch == 1


@mock.patch.dict(os.environ, {}, clear=True)
23 changes: 22 additions & 1 deletion tests/tests_fabric/test_wrappers.py
@@ -11,11 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import call, Mock

import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from torch.utils.data import DistributedSampler
from torch.utils.data.dataloader import DataLoader

from lightning_fabric.fabric import Fabric
@@ -230,6 +231,26 @@ def test_lite_dataloader_device_placement(src_device_str, dest_device_str):
assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))


def test_lite_dataloader_distributed_sampler_set_epoch():
"""Test that the LiteDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
sampler = DistributedSampler(range(3), num_replicas=2, rank=0)
sampler.set_epoch = Mock()
dataloader = DataLoader(range(3), sampler=sampler)
lite_dataloader = _FabricDataLoader(dataloader)
iterator_epoch_0 = iter(lite_dataloader)
dataloader.sampler.set_epoch.assert_not_called()
next(iterator_epoch_0)
# .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
next(iterator_epoch_0)
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
iterator_epoch_1 = iter(lite_dataloader)
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
next(iterator_epoch_1)
# with every new iterator call, the epoch increases
assert dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]


def test_lite_optimizer_wraps():
"""Test that the FabricOptimizer fully wraps the optimizer."""
optimizer_cls = torch.optim.SGD
