From 27e22420cb2373f1bbb20eeb96458c06e1687f70 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 31 Oct 2021 19:42:38 +0000 Subject: [PATCH 01/19] update --- pytorch_lightning/lite/lite.py | 11 +++-------- pytorch_lightning/lite/wrappers.py | 28 ++++++++++++++++------------ tests/lite/test_wrappers.py | 6 +++--- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 7d0ff6a436b61..448ed5859e075 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -234,16 +234,11 @@ def _setup_dataloader( sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs) kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler) - device = self.device if move_to_device else None - if isinstance(self._strategy, TPUSpawnPlugin): - dataloader = DataLoader(**kwargs) - else: - dataloader = _LiteDataLoader(device=device, **kwargs) - + dataloader = type(dataloader)(**kwargs) # add worker_init_fn for correct seeding in worker processes TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank) - - return self._strategy.process_dataloader(dataloader) + dataloader = self._strategy.process_dataloader(dataloader) + return _LiteDataLoader(iterator=dataloader, device=self.device if move_to_device else None) def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None: """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you. diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 3dd387319ae68..38997de627c25 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -100,17 +100,19 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return output -class _LiteDataLoader(DataLoader): - def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None: - """The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds - additional features such as moving the data to the device automatically. +class _LiteDataLoader(Iterator): + def __init__(self, iterator: Union[Iterator, DataLoader], device: Optional[torch.device] = None) -> None: + """The LiteDataLoader is an extension of Iterator. + It would move move the data to the device automatically if the device is specified Args: + iterator: The current iterator. device: The device to which the data should be moved. By default the device is `None` and no data transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`). - **dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts. 
""" - super().__init__(**dl_kwargs) + super().__init__() + self.__dict__.update(getattr(iterator, "__dict__", {})) + self._iterator = iterator self._device = device @property @@ -118,9 +120,11 @@ def device(self) -> Optional[torch.device]: return self._device def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: - iterator = super().__iter__() - if self._device is None: - return iterator - - for item in iterator: - yield move_data_to_device(item, self._device) + self._iterator_iter = iter(self._iterator) + return self + + def __next__(self) -> Any: + item = next(self._iterator_iter) + if self._device: + item = move_data_to_device(item, self._device) + return item diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 3e2e9ac7a9f9a..7acbee500d5bf 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -15,7 +15,7 @@ import pytest import torch - +from torch.utils.data.dataloader import DataLoader from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from tests.helpers.runif import RunIf @@ -73,8 +73,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device): sample1 = torch.tensor(1, device=src_device) sample2 = {"data": torch.tensor(2, device=src_device)} sample3 = {"data": torch.tensor(3, device=src_device)} - data = [sample0, sample1, sample2, sample3] - lite_dataloader = _LiteDataLoader(device=dest_device, dataset=data, batch_size=2) + data = DataLoader([sample0, sample1, sample2, sample3], batch_size=2) + lite_dataloader = _LiteDataLoader(iterator=data, device=dest_device) iterator = iter(lite_dataloader) batch0 = next(iterator) From dbc4615dea21c81c9d19db8a0b80d36f1fee288d Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 31 Oct 2021 19:43:44 +0000 Subject: [PATCH 02/19] update --- pytorch_lightning/lite/lite.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 448ed5859e075..992ac9313b913 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -233,12 +233,12 @@ def _setup_dataloader( ) sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs) - kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler) - dataloader = type(dataloader)(**kwargs) + dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler) + dataloader = type(dataloader)(**dataloader_kwargs) # add worker_init_fn for correct seeding in worker processes TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank) - dataloader = self._strategy.process_dataloader(dataloader) - return _LiteDataLoader(iterator=dataloader, device=self.device if move_to_device else None) + return _LiteDataLoader( + iterator=self._strategy.process_dataloader(dataloader), device=self.device if move_to_device else None) def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None: """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you. 
From b83af51b4689c973656ff3b7209390ccd91ff3bc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 31 Oct 2021 19:45:30 +0000 Subject: [PATCH 03/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/lite/lite.py | 3 ++- pytorch_lightning/lite/wrappers.py | 4 ++-- tests/lite/test_wrappers.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 992ac9313b913..c0f5c5e84f9e7 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -238,7 +238,8 @@ def _setup_dataloader( # add worker_init_fn for correct seeding in worker processes TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank) return _LiteDataLoader( - iterator=self._strategy.process_dataloader(dataloader), device=self.device if move_to_device else None) + iterator=self._strategy.process_dataloader(dataloader), device=self.device if move_to_device else None + ) def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None: """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you. diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 38997de627c25..bcd857096a1b1 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -102,8 +102,8 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class _LiteDataLoader(Iterator): def __init__(self, iterator: Union[Iterator, DataLoader], device: Optional[torch.device] = None) -> None: - """The LiteDataLoader is an extension of Iterator. - It would move move the data to the device automatically if the device is specified + """The LiteDataLoader is an extension of Iterator. It would move move the data to the device automatically + if the device is specified. Args: iterator: The current iterator. 
diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 7acbee500d5bf..2388f4d983b04 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -16,6 +16,7 @@ import pytest import torch from torch.utils.data.dataloader import DataLoader + from pytorch_lightning.lite import LightningLite from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from tests.helpers.runif import RunIf From e9af786c93ba3063e315bfecc7cf691bdcf65295 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 31 Oct 2021 19:45:50 +0000 Subject: [PATCH 04/19] update --- pytorch_lightning/lite/wrappers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 38997de627c25..06560a40a7d13 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -124,7 +124,11 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: return self def __next__(self) -> Any: - item = next(self._iterator_iter) - if self._device: - item = move_data_to_device(item, self._device) - return item + try: + item = next(self._iterator_iter) + if self._device: + item = move_data_to_device(item, self._device) + return item + except StopIteration as e: + self._iterator_iter = None + raise e From 3daa1b03fa27510eb6e3f7b696c24bfc0f7d4b77 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 31 Oct 2021 19:47:34 +0000 Subject: [PATCH 05/19] drop reference to iterator --- tests/lite/test_wrappers.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 7acbee500d5bf..9f5ed22e7c55f 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from unittest.mock import ANY, Mock - +from contextlib import suppress import pytest import torch from torch.utils.data.dataloader import DataLoader @@ -83,6 +83,11 @@ def test_lite_dataloader_device_placement(src_device, dest_device): batch1 = next(iterator) assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device)) + assert lite_dataloader._iterator_iter + with suppress(StopIteration): + batch1 = next(iterator) + assert lite_dataloader._iterator_iter is None + def test_lite_optimizer_wraps(): """Test that the LiteOptimizer fully wraps the optimizer.""" From aaaea5963a976682fd3c219928ced7cececbbb50 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 31 Oct 2021 19:48:52 +0000 Subject: [PATCH 06/19] update --- pytorch_lightning/lite/wrappers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index c3f9009f15f9c..5dbdf24a9ce3e 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -102,11 +102,11 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class _LiteDataLoader(Iterator): def __init__(self, iterator: Union[Iterator, DataLoader], device: Optional[torch.device] = None) -> None: - """The LiteDataLoader is an extension of Iterator. It would move move the data to the device automatically - if the device is specified. + """The LiteDataLoader is an extension of an Iterator. It would move move the data to the device + automatically if the device is specified. Args: - iterator: The current iterator. + iterator: The current iterator to be used. 
device: The device to which the data should be moved. By default the device is `None` and no data transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`). """ From 5927281b043b3420658919f21454456609049087 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 31 Oct 2021 19:49:13 +0000 Subject: [PATCH 07/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/lite/test_wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index cbb54fa607df9..5727cc4a6810c 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import ANY, Mock from contextlib import suppress +from unittest.mock import ANY, Mock + import pytest import torch from torch.utils.data.dataloader import DataLoader From d78d9e06e55114d81efcaae7829a99b04dae5a9c Mon Sep 17 00:00:00 2001 From: tchaton Date: Sun, 31 Oct 2021 20:04:22 +0000 Subject: [PATCH 08/19] update --- pytorch_lightning/lite/wrappers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 5dbdf24a9ce3e..7916b27d5393f 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Generator, Iterator, Optional, Union +from typing import Any, Callable, Iterable, Iterator, Optional, Union import torch from torch import nn as nn @@ -101,7 +101,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class _LiteDataLoader(Iterator): - def __init__(self, iterator: Union[Iterator, DataLoader], device: Optional[torch.device] = None) -> None: + def __init__(self, iterator: Union[Iterable[Any], DataLoader], device: Optional[torch.device] = None) -> None: """The LiteDataLoader is an extension of an Iterator. It would move move the data to the device automatically if the device is specified. 
@@ -114,17 +114,19 @@ def __init__(self, iterator: Union[Iterator, DataLoader], device: Optional[torch self.__dict__.update(getattr(iterator, "__dict__", {})) self._iterator = iterator self._device = device + self._iterator_iter: Optional[Iterator] = None @property def device(self) -> Optional[torch.device]: return self._device - def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: + def __iter__(self) -> "_LiteDataLoader": self._iterator_iter = iter(self._iterator) return self def __next__(self) -> Any: try: + assert self._iterator_iter item = next(self._iterator_iter) if self._device: item = move_data_to_device(item, self._device) From 2c8d0e2598bf9be32f6a542e3585592cf5ae4290 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 12:20:47 +0000 Subject: [PATCH 09/19] update --- pytorch_lightning/lite/lite.py | 9 +++++---- pytorch_lightning/lite/wrappers.py | 21 +++++++++++---------- tests/lite/test_lite.py | 21 +++++++++++++++++++++ tests/lite/test_wrappers.py | 14 ++++++++++---- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index c0f5c5e84f9e7..02374a0f351c6 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -183,7 +183,7 @@ def setup( def setup_dataloaders( self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True - ) -> Union[DataLoader, List[DataLoader], Iterable]: + ) -> Union[_LiteDataLoader, List[_LiteDataLoader]]: """Setup one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one. @@ -208,7 +208,7 @@ def setup_dataloaders( def _setup_dataloader( self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True - ) -> Union[Iterable, DataLoader]: + ) -> _LiteDataLoader: """Setup a single dataloader for accelerated training. Args: @@ -238,7 +238,8 @@ def _setup_dataloader( # add worker_init_fn for correct seeding in worker processes TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank) return _LiteDataLoader( - iterator=self._strategy.process_dataloader(dataloader), device=self.device if move_to_device else None + dataloader=self._strategy.process_dataloader(dataloader), + device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None, ) def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None: diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 7916b27d5393f..5b5b394c4bb0a 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Iterable, Iterator, Optional, Union +from typing import Any, Callable, Iterator, Optional import torch from torch import nn as nn @@ -101,36 +101,37 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class _LiteDataLoader(Iterator): - def __init__(self, iterator: Union[Iterable[Any], DataLoader], device: Optional[torch.device] = None) -> None: + def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None: """The LiteDataLoader is an extension of an Iterator. It would move move the data to the device automatically if the device is specified. Args: - iterator: The current iterator to be used. + dataloader: The current dataloader to be used. device: The device to which the data should be moved. By default the device is `None` and no data transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`). """ super().__init__() - self.__dict__.update(getattr(iterator, "__dict__", {})) - self._iterator = iterator + self.__dict__.update(getattr(dataloader, "__dict__", {})) + self._dataloader = dataloader self._device = device - self._iterator_iter: Optional[Iterator] = None + self._dataloader_iter: Optional[Iterator] = None @property def device(self) -> Optional[torch.device]: return self._device def __iter__(self) -> "_LiteDataLoader": - self._iterator_iter = iter(self._iterator) + self._dataloader_iter = iter(self._dataloader) return self def __next__(self) -> Any: try: - assert self._iterator_iter - item = next(self._iterator_iter) + assert self._dataloader_iter + item = next(self._dataloader_iter) if self._device: item = move_data_to_device(item, self._device) return item except StopIteration as e: - self._iterator_iter = None + # drop the reference to the dataloader iterator. 
+ self._dataloader_iter = None raise e diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 916e0aa542b32..bf1288db27b37 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -164,6 +164,27 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 +def test_setup_custom_dataloaders(): + """Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader.""" + lite = EmptyLite() + + class CustomDataLoader(DataLoader): + def __init__(self, value: int = 2, *args, **kwargs): + self.value = value + kwargs["dataset"] = range(value) + super().__init__(*args, **kwargs) + + dataloader = CustomDataLoader(2, batch_size=2) + + # single dataloader + lite_dataloader = lite.setup_dataloaders(dataloader) + assert lite_dataloader._dataloader + assert lite_dataloader._dataloader_iter is None + assert lite_dataloader.value == 2 + batch0 = next(iter(lite_dataloader)) + assert torch.equal(batch0, torch.tensor([0, 1])) + + def test_setup_dataloaders_twice_fails(): """Test that calling setup_dataloaders with a dataloader that is already wrapped fails.""" lite = EmptyLite() diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 5727cc4a6810c..fd621abc274a0 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -75,8 +75,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device): sample1 = torch.tensor(1, device=src_device) sample2 = {"data": torch.tensor(2, device=src_device)} sample3 = {"data": torch.tensor(3, device=src_device)} - data = DataLoader([sample0, sample1, sample2, sample3], batch_size=2) - lite_dataloader = _LiteDataLoader(iterator=data, device=dest_device) + dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2) + lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=dest_device) iterator = iter(lite_dataloader) batch0 = next(iterator) @@ -85,10 +85,16 @@ def test_lite_dataloader_device_placement(src_device, dest_device): batch1 = next(iterator) assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device)) - assert lite_dataloader._iterator_iter + assert lite_dataloader._dataloader_iter with suppress(StopIteration): batch1 = next(iterator) - assert lite_dataloader._iterator_iter is None + assert lite_dataloader._dataloader_iter is None + + lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device) + iterator = iter(lite_dataloader) + + batch0 = next(iterator) + assert batch0 == 0 def test_lite_optimizer_wraps(): From 7c94c8e799f4f0dc444abc20d7c251f8643827c1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 12:23:12 +0000 Subject: [PATCH 10/19] update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 38f465bb5d584..34bc01840139c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -119,6 +119,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
* Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164)) * Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162)) * Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175)) + * Make the `_LiteDataLoader` an iterator and add supports for custom dataloader ([#10279](https://github.com/PyTorchLightning/pytorch-lightning/pull/10279)) - Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170)) - Added `ckpt_path` argument for `Trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061)) - Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222)) From 91a7795d3e3f6309e117371708ea2f5905619bab Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 12:35:35 +0000 Subject: [PATCH 11/19] update --- pytorch_lightning/lite/lite.py | 6 +++++- tests/lite/test_lite.py | 17 +++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 02374a0f351c6..252307e5fab95 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -234,7 +234,11 @@ def _setup_dataloader( sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs) dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler) - dataloader = type(dataloader)(**dataloader_kwargs) + try: + dataloader = type(dataloader)(**dataloader_kwargs) + except TypeError: + dataloader_kwargs.pop("dataset") + dataloader = type(dataloader)(**dataloader_kwargs) # add worker_init_fn for correct seeding in worker processes TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank) return _LiteDataLoader( diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index bf1288db27b37..303fc731f2020 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -171,8 +171,7 @@ def test_setup_custom_dataloaders(): class CustomDataLoader(DataLoader): def __init__(self, value: int = 2, *args, **kwargs): self.value = value - kwargs["dataset"] = range(value) - super().__init__(*args, **kwargs) + super().__init__(range(value), *args, **kwargs) dataloader = CustomDataLoader(2, batch_size=2) @@ -184,6 +183,20 @@ def __init__(self, value: int = 2, *args, **kwargs): batch0 = next(iter(lite_dataloader)) assert torch.equal(batch0, torch.tensor([0, 1])) + class CustomDataLoader(DataLoader): + def __init__(self, range, *args, **kwargs): + self.range = range + super().__init__(range, *args, **kwargs) + + dataloader = CustomDataLoader(range(2), batch_size=2) + + # single dataloader + lite_dataloader = lite.setup_dataloaders(dataloader) + assert lite_dataloader._dataloader + assert lite_dataloader._dataloader_iter is None + batch0 = next(iter(lite_dataloader)) + assert torch.equal(batch0, torch.tensor([0, 1])) + def test_setup_dataloaders_twice_fails(): """Test that calling setup_dataloaders with a dataloader that is already wrapped fails.""" From 68857f91ddb71b01041ace02c3f73e4eb3e2a113 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 13:26:10 +0000 Subject: [PATCH 12/19] update --- pytorch_lightning/lite/lite.py | 9 ++++-- pytorch_lightning/lite/wrappers.py | 
50 +++++++++++++++++++++++++++++- tests/lite/test_lite.py | 21 +++++++++++-- 3 files changed, 75 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 252307e5fab95..5f5676824a48d 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -25,7 +25,12 @@ from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler from pytorch_lightning.accelerators.accelerator import Accelerator -from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer +from pytorch_lightning.lite.wrappers import ( + _LiteDataLoader, + _LiteModule, + _LiteOptimizer, + _replace_dataloader_init_function, +) from pytorch_lightning.plugins import ( DDPShardedPlugin, DDPSpawnPlugin, @@ -401,7 +406,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: return run_method(*args, **kwargs) def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: - with self._strategy.model_sharded_context(): + with self._strategy.model_sharded_context(), _replace_dataloader_init_function(): return run_method(*args, **kwargs) def _set_plugin_specific_precision_variables(self) -> None: diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 5b5b394c4bb0a..bf3a42ca65afc 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Iterator, Optional +import functools +import inspect +from contextlib import contextmanager +from typing import Any, Callable, Generator, Iterator, Optional import torch from torch import nn as nn @@ -100,6 +103,51 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return output +def _wrap_init(f): + @functools.wraps(f) + def wrapper(module, *args, **kwargs): + params = dict(inspect.signature(module._old_init).parameters) + params.pop("args") + params.pop("kwargs") + for init_name, init_arg in zip(params, args): + setattr(module, init_name, init_arg) + f(module, *args, **kwargs) + + return wrapper + + +# https://stackoverflow.com/a/63851681/9201239 +def _get_all_subclasses(cls): + subclass_list = [] + + def recurse(cl): + for subclass in cl.__subclasses__(): + subclass_list.append(subclass) + recurse(subclass) + + recurse(cls) + return set(subclass_list) + + +def _enable_class(cls): + cls._old_init = cls.__init__ + cls.__init__ = _wrap_init(cls.__init__) + + +def _disable_class(cls): + cls.__init__ = cls._old_init + + +@contextmanager +def _replace_dataloader_init_function() -> Generator: + """This context manager is used to support custom :class:`~torch.utils.data.DataLoader.""" + for subclass in _get_all_subclasses(DataLoader): + _enable_class(subclass) + yield + for subclass in _get_all_subclasses(DataLoader): + _disable_class(subclass) + + class _LiteDataLoader(Iterator): def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None: """The LiteDataLoader is an extension of an Iterator. 
It would move move the data to the device diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 303fc731f2020..d3f1fd9e80db9 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -183,12 +183,12 @@ def __init__(self, value: int = 2, *args, **kwargs): batch0 = next(iter(lite_dataloader)) assert torch.equal(batch0, torch.tensor([0, 1])) - class CustomDataLoader(DataLoader): + class CustomDataLoader2(DataLoader): def __init__(self, range, *args, **kwargs): self.range = range super().__init__(range, *args, **kwargs) - dataloader = CustomDataLoader(range(2), batch_size=2) + dataloader = CustomDataLoader2(range(2), batch_size=2) # single dataloader lite_dataloader = lite.setup_dataloaders(dataloader) @@ -197,6 +197,23 @@ def __init__(self, range, *args, **kwargs): batch0 = next(iter(lite_dataloader)) assert torch.equal(batch0, torch.tensor([0, 1])) + class CustomDataLoader(DataLoader): + def __init__(self, value: int, *args, **kwargs): + super().__init__(range(value), *args, **kwargs) + + class LiteWithCustomDataLoader(LightningLite): + def run(self): + dataloader = CustomDataLoader(2, batch_size=2) + self.setup_dataloaders(dataloader) + + LiteWithCustomDataLoader().run() + + with pytest.raises( + MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance" + ): + dataloader = CustomDataLoader(2, batch_size=2) + lite_dataloader = lite.setup_dataloaders(dataloader) + def test_setup_dataloaders_twice_fails(): """Test that calling setup_dataloaders with a dataloader that is already wrapped fails.""" From 0f10e2733024b313e2ecf30bd9fd68b179415b90 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 13:27:55 +0000 Subject: [PATCH 13/19] update --- tests/lite/test_lite.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index d3f1fd9e80db9..e2f33ab6feb87 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -203,6 +203,8 @@ def __init__(self, value: int, *args, **kwargs): class LiteWithCustomDataLoader(LightningLite): def run(self): + # This doesn't fail as the context manager would save all the arguments provided + # to the dataloaders. 
dataloader = CustomDataLoader(2, batch_size=2) self.setup_dataloaders(dataloader) From ef60ccb9d7f8f7dd9ff8a517771e8e34e93c057b Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 13:35:01 +0000 Subject: [PATCH 14/19] update --- pytorch_lightning/lite/wrappers.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index bf3a42ca65afc..101e4fca7f728 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -14,7 +14,7 @@ import functools import inspect from contextlib import contextmanager -from typing import Any, Callable, Generator, Iterator, Optional +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type import torch from torch import nn as nn @@ -103,9 +103,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: return output -def _wrap_init(f): +def _wrap_init(f: Callable) -> Callable: @functools.wraps(f) - def wrapper(module, *args, **kwargs): + def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None: params = dict(inspect.signature(module._old_init).parameters) params.pop("args") params.pop("kwargs") @@ -117,10 +117,10 @@ def wrapper(module, *args, **kwargs): # https://stackoverflow.com/a/63851681/9201239 -def _get_all_subclasses(cls): +def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]: subclass_list = [] - def recurse(cl): + def recurse(cl: Type[Any]) -> None: for subclass in cl.__subclasses__(): subclass_list.append(subclass) recurse(subclass) @@ -129,12 +129,12 @@ def recurse(cl): return set(subclass_list) -def _enable_class(cls): +def _enable_class(cls: Type[Any]) -> None: cls._old_init = cls.__init__ cls.__init__ = _wrap_init(cls.__init__) -def _disable_class(cls): +def _disable_class(cls: Type[Any]) -> None: cls.__init__ = cls._old_init @@ -149,7 +149,7 @@ def _replace_dataloader_init_function() -> Generator: class _LiteDataLoader(Iterator): - def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None: + def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None: """The LiteDataLoader is an extension of an Iterator. It would move move the data to the device automatically if the device is specified. From 6dffc538660f8926ea675afe5a0b5150162cf8eb Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 13:40:07 +0000 Subject: [PATCH 15/19] update --- pytorch_lightning/lite/lite.py | 6 +++--- pytorch_lightning/lite/wrappers.py | 26 +++++++++----------------- tests/lite/test_wrappers.py | 5 +---- 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 5f5676824a48d..f93083eff89c3 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -188,7 +188,7 @@ def setup( def setup_dataloaders( self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True - ) -> Union[_LiteDataLoader, List[_LiteDataLoader]]: + ) -> Union[Iterable, List[Iterable]]: """Setup one or multiple dataloaders for accelerated training. 
If you need different settings for each dataloader, call this method individually for each one. @@ -213,7 +213,7 @@ def setup_dataloaders( def _setup_dataloader( self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True - ) -> _LiteDataLoader: + ) -> Iterable: """Setup a single dataloader for accelerated training. Args: diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 101e4fca7f728..a4d42652ac2ee 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -14,7 +14,7 @@ import functools import inspect from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type, Union import torch from torch import nn as nn @@ -148,7 +148,7 @@ def _replace_dataloader_init_function() -> Generator: _disable_class(subclass) -class _LiteDataLoader(Iterator): +class _LiteDataLoader: def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None: """The LiteDataLoader is an extension of an Iterator. It would move move the data to the device automatically if the device is specified. @@ -168,18 +168,10 @@ def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) def device(self) -> Optional[torch.device]: return self._device - def __iter__(self) -> "_LiteDataLoader": - self._dataloader_iter = iter(self._dataloader) - return self - - def __next__(self) -> Any: - try: - assert self._dataloader_iter - item = next(self._dataloader_iter) - if self._device: - item = move_data_to_device(item, self._device) - return item - except StopIteration as e: - # drop the reference to the dataloader iterator. - self._dataloader_iter = None - raise e + def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: + dataloader_iter = iter(self._dataloader) + if self._device is None: + return dataloader_iter + + for item in dataloader_iter: + yield move_data_to_device(item, self._device) diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index fd621abc274a0..4dd7b4a890648 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from contextlib import suppress from unittest.mock import ANY, Mock import pytest @@ -85,10 +84,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device): batch1 = next(iterator) assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device)) - assert lite_dataloader._dataloader_iter - with suppress(StopIteration): + with pytest.raises(StopIteration): batch1 = next(iterator) - assert lite_dataloader._dataloader_iter is None lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device) iterator = iter(lite_dataloader) From 02f2b20fd371e51eca99635639238d16548d811f Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 17:31:58 +0000 Subject: [PATCH 16/19] update on comments --- pytorch_lightning/lite/lite.py | 4 ++-- pytorch_lightning/lite/wrappers.py | 3 +-- tests/lite/test_lite.py | 2 -- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index f93083eff89c3..2e6f10d356fe0 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -29,7 +29,7 @@ _LiteDataLoader, _LiteModule, _LiteOptimizer, - _replace_dataloader_init_function, + _replace_dataloader_init_method, ) from pytorch_lightning.plugins import ( DDPShardedPlugin, @@ -406,7 +406,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: return run_method(*args, **kwargs) def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: - with self._strategy.model_sharded_context(), _replace_dataloader_init_function(): + with self._strategy.model_sharded_context(), _replace_dataloader_init_method(): return run_method(*args, **kwargs) def _set_plugin_specific_precision_variables(self) -> None: diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index a4d42652ac2ee..bed7c23bad94e 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -139,7 +139,7 @@ def _disable_class(cls: Type[Any]) -> None: @contextmanager -def _replace_dataloader_init_function() -> Generator: +def _replace_dataloader_init_method() -> Generator: """This context manager is used to support custom :class:`~torch.utils.data.DataLoader.""" for subclass in _get_all_subclasses(DataLoader): _enable_class(subclass) @@ -162,7 +162,6 @@ def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) self.__dict__.update(getattr(dataloader, "__dict__", {})) self._dataloader = dataloader self._device = device - self._dataloader_iter: Optional[Iterator] = None @property def device(self) -> Optional[torch.device]: diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index e2f33ab6feb87..8eac30f9cf823 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -178,7 +178,6 @@ def __init__(self, value: int = 2, *args, **kwargs): # single dataloader lite_dataloader = lite.setup_dataloaders(dataloader) assert lite_dataloader._dataloader - assert lite_dataloader._dataloader_iter is None assert lite_dataloader.value == 2 batch0 = next(iter(lite_dataloader)) assert torch.equal(batch0, torch.tensor([0, 1])) @@ -193,7 +192,6 @@ def __init__(self, range, *args, **kwargs): # single dataloader lite_dataloader = lite.setup_dataloaders(dataloader) assert lite_dataloader._dataloader - assert lite_dataloader._dataloader_iter is None batch0 = next(iter(lite_dataloader)) assert torch.equal(batch0, torch.tensor([0, 1])) From 45c9b8ec6379a8442ace553e8f28f08a9ac8a608 Mon Sep 17 
00:00:00 2001 From: tchaton Date: Mon, 1 Nov 2021 17:43:32 +0000 Subject: [PATCH 17/19] update --- pytorch_lightning/lite/wrappers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index bed7c23bad94e..b7a03deadab0b 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -136,6 +136,7 @@ def _enable_class(cls: Type[Any]) -> None: def _disable_class(cls: Type[Any]) -> None: cls.__init__ = cls._old_init + del cls._old_init @contextmanager From 009a02e3a9ac69b466d18fc76d16e347a730ca77 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 1 Nov 2021 23:18:16 +0530 Subject: [PATCH 18/19] Update pytorch_lightning/lite/wrappers.py --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index b7a03deadab0b..d0eab91e2c7a7 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -151,7 +151,7 @@ def _replace_dataloader_init_method() -> Generator: class _LiteDataLoader: def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None: - """The LiteDataLoader is an extension of an Iterator. It would move move the data to the device + """The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if the device is specified. Args: From 18ee42f4059be3d2cd5b2e153d9ccc3320ff4764 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Nov 2021 17:49:31 +0000 Subject: [PATCH 19/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/lite/wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index d0eab91e2c7a7..d9acba70bcba1 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -151,8 +151,8 @@ def _replace_dataloader_init_method() -> Generator: class _LiteDataLoader: def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None: - """The LiteDataLoader is an extension of an Iterator. It would move the data to the device - automatically if the device is specified. + """The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if + the device is specified. Args: dataloader: The current dataloader to be used.
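
Patches 12-17 introduce the mechanism that makes arbitrary DataLoader subclasses re-creatable: while the user's run() executes, the __init__ of every DataLoader subclass is temporarily wrapped so its positional constructor arguments are recorded on the instance, which is what later lets type(dataloader)(**dataloader_kwargs) succeed with a DistributedSampler injected. The sketch below shows that trick in isolation; it is simplified, and record_dataloader_init_args is a hypothetical name standing in for the _wrap_init / _get_all_subclasses / _replace_dataloader_init_method helpers in the diffs above.

    # Simplified sketch of the __init__-patching context manager; names are
    # hypothetical, not the exact pytorch_lightning implementation.
    import functools
    import inspect
    from contextlib import contextmanager
    from typing import Any, Callable, Generator, Set, Type

    from torch.utils.data import DataLoader


    def _record_init_args(init: Callable) -> Callable:
        # Positional parameter names of the original __init__, minus self/*args/**kwargs.
        param_names = [
            name for name in inspect.signature(init).parameters
            if name not in ("self", "args", "kwargs")
        ]

        @functools.wraps(init)
        def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
            # Remember the positional arguments so the dataloader can be rebuilt later.
            for name, value in zip(param_names, args):
                setattr(obj, name, value)
            init(obj, *args, **kwargs)

        return wrapper


    def _all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
        found: Set[Type[Any]] = set()
        for sub in cls.__subclasses__():
            found.add(sub)
            found |= _all_subclasses(sub)
        return found


    @contextmanager
    def record_dataloader_init_args() -> Generator[None, None, None]:
        originals = {cls: cls.__init__ for cls in _all_subclasses(DataLoader)}
        for cls, init in originals.items():
            cls.__init__ = _record_init_args(init)
        try:
            yield
        finally:
            for cls, init in originals.items():
                cls.__init__ = init

    # Hypothetical usage: inside the context manager, CustomLoader(10, batch_size=2)
    # also records loader.value = 10, so the class can be re-instantiated later.

This is roughly why, after patch 12, _run_with_sharded_context enters the context manager around run(), and why the CustomDataLoader created inside LiteWithCustomDataLoader.run() can be set up in test_lite.py while the one created outside the context manager still raises the "Trying to inject `DistributedSampler`" error.
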