From c6aaa50f0cb9428278c9253230dae14214dad165 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 13:02:53 +0100 Subject: [PATCH 01/13] parity test --- src/lightning_lite/lite.py | 4 ++- tests/tests_lite/utilities/test_data.py | 34 ++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py index b7ba6acf52d17..163791cb1959f 100644 --- a/src/lightning_lite/lite.py +++ b/src/lightning_lite/lite.py @@ -25,7 +25,7 @@ from lightning_utilities.core.rank_zero import rank_zero_warn from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler from lightning_lite.plugins import Precision # avoid circular imports: # isort: split from lightning_lite.accelerators.accelerator import Accelerator @@ -583,6 +583,8 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler: kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler)) kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0))) + if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)): + return DistributedSampler(dataloader.dataset, **kwargs) return DistributedSamplerWrapper(dataloader.sampler, **kwargs) def _prepare_run_method(self) -> None: diff --git a/tests/tests_lite/utilities/test_data.py b/tests/tests_lite/utilities/test_data.py index 23a84901e40e3..5d40808dcc7ad 100644 --- a/tests/tests_lite/utilities/test_data.py +++ b/tests/tests_lite/utilities/test_data.py @@ -4,8 +4,9 @@ import pytest import torch from tests_lite.helpers.models import RandomDataset, RandomIterableDataset -from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler, DistributedSampler +from lightning_lite import LightningLite, seed_everything from lightning_lite.utilities.data import ( _dataloader_init_kwargs_resolve_sampler, _get_dataloader_init_args_and_kwargs, @@ -524,3 +525,34 @@ def __init__(self, indices=None, **kwargs): dataloader = ArrayAttributeDataloader(dataset) dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler) assert dl_kwargs["indices"] is dataloader.indices + + +@pytest.mark.parametrize("shuffle", [True, False]) +def test_distributed_sampler_parity(shuffle): + torch.manual_seed(1) + dataset = RandomDataset(1, 5) + + torch_dataloader = DataLoader( + dataset, + sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle), + ) + lite_dataloader = DataLoader(dataset, shuffle=shuffle) + + lite = LightningLite(accelerator="cpu", strategy="ddp", devices=2) + # no `lite.launch()`, we pretend we are on rank 0 + lite_dataloader = lite.setup_dataloaders(lite_dataloader) + + def fetch_data(loader): + # epoch 0 + iterator = iter(loader) + data0 = next(iterator) + data1 = next(iterator) + # epoch 1 + iterator = iter(loader) + data2 = next(iterator) + data3 = next(iterator) + return torch.stack((data0, data1, data2, data3)) + + torch_data = fetch_data(torch_dataloader) + lite_data = fetch_data(lite_dataloader) + assert torch.equal(torch_data, lite_data) From 38db1abd35dd10467f06184b78a114df851366ca Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 13:27:51 
+0100 Subject: [PATCH 02/13] update --- src/lightning_lite/utilities/distributed.py | 5 ++- src/lightning_lite/wrappers.py | 10 +++++- tests/tests_lite/test_lite.py | 35 +++++++++++++++++++++ tests/tests_lite/utilities/test_data.py | 34 +------------------- 4 files changed, 49 insertions(+), 35 deletions(-) diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index b70c56360126a..9a92cb7728dc7 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -301,4 +301,7 @@ def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) def __iter__(self) -> Iterator: self.dataset.reset() - return (self.dataset[index] for index in super().__iter__()) + + it = (self.dataset[index] for index in super().__iter__()) + self.epoch += 1 + return it diff --git a/src/lightning_lite/wrappers.py b/src/lightning_lite/wrappers.py index 03e84ca4659d6..ea951f73d0391 100644 --- a/src/lightning_lite/wrappers.py +++ b/src/lightning_lite/wrappers.py @@ -19,7 +19,7 @@ from torch import Tensor from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, DistributedSampler from lightning_lite.plugins import Precision from lightning_lite.plugins.precision.utils import _convert_fp_tensor @@ -151,6 +151,7 @@ def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None self.__dict__.update(dataloader.__dict__) self._dataloader = dataloader self._device = device + self._epoch = 0 @property def device(self) -> Optional[torch.device]: @@ -160,6 +161,13 @@ def __len__(self) -> int: return len(self._dataloader) def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: + if hasattr(self._dataloader.sampler, "set_epoch"): + # Without setting the epoch, the distributed sampler would return the same indices every time, even when + # shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler. + # In Lite, we take care of this boilerplate code. 
+ self._dataloader.sampler.set_epoch(self._epoch) + self._epoch += 1 + iterator = iter(self._dataloader) if self._device is None: yield from iterator diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 28e2c926e7861..1b3f039ce82f0 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -42,6 +42,7 @@ from lightning_lite.utilities.seed import pl_worker_init_function, seed_everything from lightning_lite.utilities.warnings import PossibleUserWarning from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from tests_lite.helpers.models import RandomDataset class EmptyLite(LightningLite): @@ -408,6 +409,40 @@ def test_setup_dataloaders_distributed_sampler_shuffle(): assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1] +@pytest.mark.parametrize("shuffle", [True, False]) +def test_setup_dataloaders_distributed_sampler_parity(shuffle): + """Test that the distributed sampler setup in Lite leads to the same sequence of data as in raw PyTorch.""" + torch.manual_seed(1) + lite = LightningLite(accelerator="cpu", strategy="ddp", devices=2) + # no lite.launch(): pretend we are on rank 0 now + + dataset = RandomDataset(1, 5) + torch_dataloader = DataLoader( + dataset, + sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle), + ) + lite_dataloader = lite.setup_dataloaders(DataLoader(dataset, shuffle=shuffle)) + + def fetch_epoch(loader): + iterator = iter(loader) + return torch.stack((next(iterator), next(iterator))) + + # 1st epoch + torch_dataloader.sampler.set_epoch(0) + torch_data = fetch_epoch(torch_dataloader) + lite_data = fetch_epoch(lite_dataloader) + assert torch.equal(torch_data, lite_data) + + # 2nd epoch + torch_dataloader.sampler.set_epoch(1) + torch_data = fetch_epoch(torch_dataloader) + lite_data = fetch_epoch(lite_dataloader) + + assert lite_dataloader._dataloader.sampler.epoch == 1 + + assert torch.equal(torch_data, lite_data) + + @mock.patch.dict(os.environ, {}, clear=True) def test_seed_everything(): """Test that seed everything is static and sets the worker init function on the dataloader.""" diff --git a/tests/tests_lite/utilities/test_data.py b/tests/tests_lite/utilities/test_data.py index 5d40808dcc7ad..23a84901e40e3 100644 --- a/tests/tests_lite/utilities/test_data.py +++ b/tests/tests_lite/utilities/test_data.py @@ -4,9 +4,8 @@ import pytest import torch from tests_lite.helpers.models import RandomDataset, RandomIterableDataset -from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler, DistributedSampler +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler -from lightning_lite import LightningLite, seed_everything from lightning_lite.utilities.data import ( _dataloader_init_kwargs_resolve_sampler, _get_dataloader_init_args_and_kwargs, @@ -525,34 +524,3 @@ def __init__(self, indices=None, **kwargs): dataloader = ArrayAttributeDataloader(dataset) dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler) assert dl_kwargs["indices"] is dataloader.indices - - -@pytest.mark.parametrize("shuffle", [True, False]) -def test_distributed_sampler_parity(shuffle): - torch.manual_seed(1) - dataset = RandomDataset(1, 5) - - torch_dataloader = DataLoader( - dataset, - sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle), - ) - lite_dataloader = DataLoader(dataset, shuffle=shuffle) - - lite = LightningLite(accelerator="cpu", strategy="ddp", devices=2) - # no 
`lite.launch()`, we pretend we are on rank 0 - lite_dataloader = lite.setup_dataloaders(lite_dataloader) - - def fetch_data(loader): - # epoch 0 - iterator = iter(loader) - data0 = next(iterator) - data1 = next(iterator) - # epoch 1 - iterator = iter(loader) - data2 = next(iterator) - data3 = next(iterator) - return torch.stack((data0, data1, data2, data3)) - - torch_data = fetch_data(torch_dataloader) - lite_data = fetch_data(lite_dataloader) - assert torch.equal(torch_data, lite_data) From c2094e52b42ac5dc660178e983b98de7504ee094 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 13:33:46 +0100 Subject: [PATCH 03/13] parametrize --- tests/tests_lite/test_lite.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 1b3f039ce82f0..da8adc0422adf 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -410,34 +410,39 @@ def test_setup_dataloaders_distributed_sampler_shuffle(): @pytest.mark.parametrize("shuffle", [True, False]) -def test_setup_dataloaders_distributed_sampler_parity(shuffle): +@pytest.mark.parametrize("batch_size", [1, 2]) +def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size): """Test that the distributed sampler setup in Lite leads to the same sequence of data as in raw PyTorch.""" torch.manual_seed(1) lite = LightningLite(accelerator="cpu", strategy="ddp", devices=2) # no lite.launch(): pretend we are on rank 0 now - dataset = RandomDataset(1, 5) + dataset = RandomDataset(2, 6) torch_dataloader = DataLoader( dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle), + batch_size=batch_size, ) - lite_dataloader = lite.setup_dataloaders(DataLoader(dataset, shuffle=shuffle)) + lite_dataloader = lite.setup_dataloaders(DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)) def fetch_epoch(loader): iterator = iter(loader) return torch.stack((next(iterator), next(iterator))) # 1st epoch + # PyTorch users needs to set the epoch, while in Lite it gets handled automatically torch_dataloader.sampler.set_epoch(0) torch_data = fetch_epoch(torch_dataloader) lite_data = fetch_epoch(lite_dataloader) assert torch.equal(torch_data, lite_data) # 2nd epoch + # PyTorch users needs to set the epoch, while in Lite it gets handled automatically torch_dataloader.sampler.set_epoch(1) torch_data = fetch_epoch(torch_dataloader) lite_data = fetch_epoch(lite_dataloader) + assert torch_dataloader.sampler.epoch == 1 assert lite_dataloader._dataloader.sampler.epoch == 1 assert torch.equal(torch_data, lite_data) From 28feb787cf7625bbf610f522b097dbc6834fb71a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 13:48:18 +0100 Subject: [PATCH 04/13] parity --- tests/tests_lite/test_lite.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index da8adc0422adf..3560e7cebc277 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -42,7 +42,6 @@ from lightning_lite.utilities.seed import pl_worker_init_function, seed_everything from lightning_lite.utilities.warnings import PossibleUserWarning from lightning_lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer -from tests_lite.helpers.models import RandomDataset class EmptyLite(LightningLite): @@ -410,24 +409,26 @@ def test_setup_dataloaders_distributed_sampler_shuffle(): @pytest.mark.parametrize("shuffle", [True, False]) 
-@pytest.mark.parametrize("batch_size", [1, 2]) +@pytest.mark.parametrize("batch_size", [1, 2, 3]) def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size): """Test that the distributed sampler setup in Lite leads to the same sequence of data as in raw PyTorch.""" torch.manual_seed(1) lite = LightningLite(accelerator="cpu", strategy="ddp", devices=2) # no lite.launch(): pretend we are on rank 0 now - dataset = RandomDataset(2, 6) + dataset = torch.rand(10, 1) torch_dataloader = DataLoader( dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle), batch_size=batch_size, ) - lite_dataloader = lite.setup_dataloaders(DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)) + lite_dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size) + lite_dataloader = lite.setup_dataloaders(lite_dataloader) def fetch_epoch(loader): iterator = iter(loader) - return torch.stack((next(iterator), next(iterator))) + # we fetch 2 batches per epoch + return torch.cat([next(iterator).flatten() for _ in range(2)]) # 1st epoch # PyTorch users needs to set the epoch, while in Lite it gets handled automatically @@ -441,12 +442,10 @@ def fetch_epoch(loader): torch_dataloader.sampler.set_epoch(1) torch_data = fetch_epoch(torch_dataloader) lite_data = fetch_epoch(lite_dataloader) - + assert torch.equal(torch_data, lite_data) assert torch_dataloader.sampler.epoch == 1 assert lite_dataloader._dataloader.sampler.epoch == 1 - assert torch.equal(torch_data, lite_data) - @mock.patch.dict(os.environ, {}, clear=True) def test_seed_everything(): From 8437a2c1992f131469b6d226c382e36f2eeb279d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 14:32:48 +0100 Subject: [PATCH 05/13] add clarification for limitation in wrapper --- src/lightning_lite/utilities/distributed.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/lightning_lite/utilities/distributed.py b/src/lightning_lite/utilities/distributed.py index 9a92cb7728dc7..b80b18f42e71e 100644 --- a/src/lightning_lite/utilities/distributed.py +++ b/src/lightning_lite/utilities/distributed.py @@ -294,6 +294,11 @@ class DistributedSamplerWrapper(DistributedSampler): Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed mode if sampler replacement is enabled. + + Note: + The purpose of this wrapper is to take care of sharding the sampler indices. It is up to the underlying + sampler to handle randomness and shuffling. The ``shuffle`` and ``seed`` arguments on this wrapper won't + have any effect. 
""" def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None: @@ -301,7 +306,4 @@ def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) def __iter__(self) -> Iterator: self.dataset.reset() - - it = (self.dataset[index] for index in super().__iter__()) - self.epoch += 1 - return it + return (self.dataset[index] for index in super().__iter__()) From 0e3083f29331e2b3e37a3baedff6d21e8b992d1c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 14:35:26 +0100 Subject: [PATCH 06/13] update --- tests/tests_lite/test_lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 3560e7cebc277..9151e8341533c 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -405,7 +405,7 @@ def test_setup_dataloaders_distributed_sampler_shuffle(): for dataloader in shuffle_dataloaders: seed_everything(1) dataloader = lite.setup_dataloaders(dataloader) - assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1] + assert list(t[0].item() for t in iter(dataloader)) == [5, 2, 7, 1] @pytest.mark.parametrize("shuffle", [True, False]) From bafb91901496b20dd48bbde362562f51058efc6d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 14:38:31 +0100 Subject: [PATCH 07/13] changel0g --- src/lightning_lite/CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/lightning_lite/CHANGELOG.md b/src/lightning_lite/CHANGELOG.md index d331abc41c3ca..7353f226c3541 100644 --- a/src/lightning_lite/CHANGELOG.md +++ b/src/lightning_lite/CHANGELOG.md @@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Merged the implementation of `DDPSpawnStrategy` into `DDPStrategy` and removed `DDPSpawnStrategy` ([#14952](https://github.com/Lightning-AI/lightning/issues/14952)) +- The dataloader wrapper returned from `.setup_dataloaders()` now calls `.set_epoch()` on the distributed sampler if one is used ([#16101](https://github.com/Lightning-AI/lightning/issues/16101)) + ### Deprecated @@ -46,7 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed -- +- Restored sampling parity between PyTorch and Lite dataloaders when using the `DistributedSampler` ([#16101](https://github.com/Lightning-AI/lightning/issues/16101)) + ## [1.8.4] - 2022-12-08 From 28151dab4d70696b54c9254ca220542b9f2accff Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 14:43:22 +0100 Subject: [PATCH 08/13] unused import --- src/lightning_lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning_lite/wrappers.py b/src/lightning_lite/wrappers.py index ea951f73d0391..ab71b3d602ea2 100644 --- a/src/lightning_lite/wrappers.py +++ b/src/lightning_lite/wrappers.py @@ -19,7 +19,7 @@ from torch import Tensor from torch.nn.modules.module import _IncompatibleKeys from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data import DataLoader from lightning_lite.plugins import Precision from lightning_lite.plugins.precision.utils import _convert_fp_tensor From 03bd562d5144580c4a533efa9fa4a7bf0ddd95bc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 17 Dec 2022 19:58:24 +0100 Subject: [PATCH 09/13] one more test --- tests/tests_lite/test_wrappers.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/tests_lite/test_wrappers.py b/tests/tests_lite/test_wrappers.py index 3e529b63425b4..e661c3365a28a 100644 --- a/tests/tests_lite/test_wrappers.py +++ b/tests/tests_lite/test_wrappers.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import Mock, call import pytest import torch from tests_lite.helpers.runif import RunIf +from torch.utils.data import DistributedSampler from torch.utils.data.dataloader import DataLoader from lightning_lite.lite import LightningLite @@ -230,6 +231,26 @@ def test_lite_dataloader_device_placement(src_device_str, dest_device_str): assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device)) +def test_lite_dataloader_distributed_sampler_set_epoch(): + """Test that the LiteDataLoader calls `set_epoch()` on the wrapped sampler if applicable.""" + sampler = DistributedSampler(range(3), num_replicas=2, rank=0) + sampler.set_epoch = Mock() + dataloader = DataLoader(range(3), sampler=sampler) + lite_dataloader = _LiteDataLoader(dataloader) + iterator_epoch_0 = iter(lite_dataloader) + dataloader.sampler.set_epoch.assert_not_called() + next(iterator_epoch_0) + # .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader + assert dataloader.sampler.set_epoch.call_args_list == [call(0)] + next(iterator_epoch_0) + assert dataloader.sampler.set_epoch.call_args_list == [call(0)] + iterator_epoch_1 = iter(lite_dataloader) + assert dataloader.sampler.set_epoch.call_args_list == [call(0)] + next(iterator_epoch_1) + # with every new iterator call, the epoch increases + assert dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)] + + def test_lite_optimizer_wraps(): """Test that the LiteOptimizer fully wraps the optimizer.""" optimizer_cls = torch.optim.SGD From 2afb8674715349717149063bb83d786a6118b597 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 17 Dec 2022 18:59:42 +0000 Subject: [PATCH 10/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see 
https://pre-commit.ci --- tests/tests_lite/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/test_wrappers.py b/tests/tests_lite/test_wrappers.py index e661c3365a28a..15e97614a7176 100644 --- a/tests/tests_lite/test_wrappers.py +++ b/tests/tests_lite/test_wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock, call +from unittest.mock import call, Mock import pytest import torch From 1ddc894920eefc332af42ecb5139b640960c30f1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 19 Dec 2022 22:14:03 +0100 Subject: [PATCH 11/13] arange --- tests/tests_lite/test_lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 9151e8341533c..7ca47466d6192 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -416,7 +416,7 @@ def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size): lite = LightningLite(accelerator="cpu", strategy="ddp", devices=2) # no lite.launch(): pretend we are on rank 0 now - dataset = torch.rand(10, 1) + dataset = torch.arange(10) torch_dataloader = DataLoader( dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle), From 501b37e5a96c9e9c36eca9e44ad840083b0736d6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 19 Dec 2022 22:15:09 +0100 Subject: [PATCH 12/13] simpler --- tests/tests_lite/test_lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_lite/test_lite.py b/tests/tests_lite/test_lite.py index 7ca47466d6192..3314860e1c9b0 100644 --- a/tests/tests_lite/test_lite.py +++ b/tests/tests_lite/test_lite.py @@ -428,7 +428,7 @@ def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size): def fetch_epoch(loader): iterator = iter(loader) # we fetch 2 batches per epoch - return torch.cat([next(iterator).flatten() for _ in range(2)]) + return torch.cat((next(iterator), next(iterator))) # 1st epoch # PyTorch users needs to set the epoch, while in Lite it gets handled automatically From 1355fe460a27970a97c0a8e0046b176e5c4f0e5b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 19 Dec 2022 22:16:58 +0100 Subject: [PATCH 13/13] rename epoch --- src/lightning_lite/wrappers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning_lite/wrappers.py b/src/lightning_lite/wrappers.py index ab71b3d602ea2..e3d3b7ae8d629 100644 --- a/src/lightning_lite/wrappers.py +++ b/src/lightning_lite/wrappers.py @@ -151,7 +151,7 @@ def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None self.__dict__.update(dataloader.__dict__) self._dataloader = dataloader self._device = device - self._epoch = 0 + self._num_iter_calls = 0 @property def device(self) -> Optional[torch.device]: @@ -165,8 +165,8 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: # Without setting the epoch, the distributed sampler would return the same indices every time, even when # shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler. # In Lite, we take care of this boilerplate code. 
- self._dataloader.sampler.set_epoch(self._epoch) - self._epoch += 1 + self._dataloader.sampler.set_epoch(self._num_iter_calls) + self._num_iter_calls += 1 iterator = iter(self._dataloader) if self._device is None:
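
For context on the behavior this patch series establishes, below is a minimal, illustrative sketch (not part of the patches themselves) of the raw-PyTorch boilerplate that the Lite dataloader wrapper now performs automatically. It reuses the `num_replicas=2, rank=0` setup from the parity test, which lets `DistributedSampler` run without an initialized process group; the dataset and batch size are arbitrary example values.

import torch
from torch.utils.data import DataLoader, DistributedSampler

# A tensor works as a map-style dataset (it has __getitem__ and __len__),
# mirroring `torch.arange(10)` used in the parity test above.
dataset = torch.arange(10)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(2):
    # In raw PyTorch, this call is required once per epoch; without it, a
    # shuffling DistributedSampler yields the identical order every epoch.
    sampler.set_epoch(epoch)
    batches = [batch for batch in loader]

With Lite, the dataloader returned by `setup_dataloaders()` performs the equivalent of `set_epoch()` each time a new iterator is created (as tested in `test_lite_dataloader_distributed_sampler_set_epoch` above), so the manual bookkeeping in the loop is no longer needed.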