From b0875cc1140b147dfae2678a9d7686e43318d400 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Nov 2021 01:01:04 +0100
Subject: [PATCH 1/7] dataloader code improvements and docs

---
 pytorch_lightning/lite/wrappers.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index ff95e89d1d2cf..bb6f9c2a8c0ed 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -157,29 +157,27 @@ def _replace_dataloader_init_method() -> Generator:


 class _LiteDataLoader:
-    def __init__(self, dataloader: Union[Iterable, DataLoader], device: Optional[torch.device] = None) -> None:
-        """The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if
-        the device is specified.
+    def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
+        """The LiteDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the
+        device automatically if the device is specified.

         Args:
-            dataloader: The current dataloader to be used.
+            dataloader: The dataloader to wrap
             device: The device to which the data should be moved. By default the device is `None` and no data
                 transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
         """
         super().__init__()
-        self.__dict__.update(getattr(dataloader, "__dict__", {}))
+        self.__dict__.update(dataloader.__dict__)
         self._dataloader = dataloader
         self._device = device

-    def __len__(self) -> Union[int, float]:
-        if isinstance(self._dataloader, Sized):
-            return len(self._dataloader)
-        return float("inf")
-
     @property
     def device(self) -> Optional[torch.device]:
         return self._device

+    def __len__(self) -> int:
+        return len(self._dataloader)
+
     def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
         iterator = iter(self._dataloader)
         if self._device is None:

From bae170de7e84809724cae6af29d0ecc4061b2bc2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Nov 2021 01:10:24 +0100
Subject: [PATCH 2/7] remove unused imports

---
 pytorch_lightning/lite/wrappers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index bb6f9c2a8c0ed..b91c2e10c3d0f 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -15,7 +15,7 @@
 import inspect
 from contextlib import contextmanager
 from itertools import chain
-from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, Optional, Set, Type, Union

 import torch
 from torch import nn as nn
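Patch 1 above turns `_LiteDataLoader` into a thin proxy around a real `DataLoader`: it copies the wrapped loader's `__dict__`, delegates `__len__` unconditionally (hence the dropped `Sized` check), and moves batches to a device during iteration. Below is a minimal, self-contained sketch of that pattern. The class name `_DeviceDataLoader` and the bare `.to()` call are illustrative assumptions, not the code in this PR; the actual class relies on Lightning's `move_data_to_device`, which also handles nested collections of tensors.

```python
from typing import Any, Iterator, Optional

import torch
from torch.utils.data import DataLoader


class _DeviceDataLoader:
    """Illustrative stand-in for _LiteDataLoader: wraps a DataLoader and
    optionally moves each batch to a device while iterating."""

    def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
        # Copying the wrapped loader's __dict__ lets attribute lookups such as
        # .batch_size or .dataset resolve on the wrapper as if it were the
        # original DataLoader, without subclassing it.
        self.__dict__.update(dataloader.__dict__)
        self._dataloader = dataloader
        self._device = device

    def __len__(self) -> int:
        # Delegation is unconditional: the wrapped object is now required to
        # be a DataLoader, so it is always sized.
        return len(self._dataloader)

    def __iter__(self) -> Iterator[Any]:
        for batch in self._dataloader:
            # The real implementation uses Lightning's move_data_to_device;
            # a bare .to() keeps this sketch self-contained but only covers
            # plain tensor batches.
            yield batch.to(self._device) if self._device is not None else batch
```

Because the `__dict__` copy exposes attributes without inheritance, static type checkers still see a class unrelated to `DataLoader`; that tension is what the later `cast` patch works around.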
From f6b11ed2eda5d627e6c779f2ea2aafa80d9686aa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Nov 2021 22:07:13 +0100
Subject: [PATCH 3/7] Update pytorch_lightning/lite/wrappers.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/lite/wrappers.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index b91c2e10c3d0f..938eb72afe622 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -166,7 +166,6 @@ def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None
             device: The device to which the data should be moved. By default the device is `None` and no data
                 transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
         """
-        super().__init__()
         self.__dict__.update(dataloader.__dict__)
         self._dataloader = dataloader
         self._device = device

From 24a85ced9c84d62ff0118a92da441fae486c021e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Nov 2021 22:21:01 +0100
Subject: [PATCH 4/7] update type

---
 pytorch_lightning/lite/lite.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index bb07c763156aa..ba545c38a4917 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -188,7 +188,7 @@ def setup(

     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[Iterable, List[Iterable]]:
+    ) -> Union[_LiteDataLoader, List[_LiteDataLoader]]:
         """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
         dataloader, call this method individually for each one.

@@ -213,7 +213,7 @@

     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Iterable:
+    ) -> _LiteDataLoader:
         """Setup a single dataloader for accelerated training.

         Args:

From a60037de22838d9ce135920dbd3560d59fbcfaae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Nov 2021 22:37:27 +0100
Subject: [PATCH 5/7] typing hell

---
 pytorch_lightning/lite/lite.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index ba545c38a4917..e554a3e61495a 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -188,7 +188,7 @@ def setup(

     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[_LiteDataLoader, List[_LiteDataLoader]]:
+    ) -> Union[DataLoader, List[DataLoader]]:
         """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
         dataloader, call this method individually for each one.

@@ -213,7 +213,7 @@

     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> _LiteDataLoader:
+    ) -> DataLoader:
         """Setup a single dataloader for accelerated training.

         Args:
@@ -246,7 +246,9 @@

         dataloader = self._strategy.process_dataloader(dataloader)
         device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
-        return _LiteDataLoader(dataloader=dataloader, device=device)
+        lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
+        lite_dataloader = cast(DataLoader, lite_dataloader)
+        return lite_dataloader

     def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
         """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
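The `cast(DataLoader, ...)` added in patch 5 resolves exactly the tension noted earlier: the wrapper behaves like a `DataLoader` but does not subclass it, so keeping the public `DataLoader` return annotation (patch 5 reverts patch 4's `_LiteDataLoader` annotations) requires a hint for the type checker. A short sketch of the trick, reusing the hypothetical `_DeviceDataLoader` from the sketch after patch 2:

```python
from typing import cast

import torch
from torch.utils.data import DataLoader


def _wrap_for_device(dataloader: DataLoader, device: torch.device) -> DataLoader:
    wrapped = _DeviceDataLoader(dataloader, device=device)
    # typing.cast is a pure no-op at runtime: it returns its second argument
    # unchanged and only tells the static type checker to treat the wrapper
    # as a DataLoader from here on.
    return cast(DataLoader, wrapped)
```

The runtime object is unchanged; only mypy's view of it is, which keeps the private wrapper type out of the public signature.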
From b9bb5c320d90c115b86a3dca9e6ad834953ef3be Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 19 Nov 2021 22:43:34 +0100
Subject: [PATCH 6/7] drop unused import

---
 pytorch_lightning/lite/lite.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index e554a3e61495a..e0f82018d3b6e 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast

 import torch
 import torch.nn as nn

From 413373c344477a0015439923f41b1e62bdaed5f1 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 19 Nov 2021 21:45:34 +0000
Subject: [PATCH 7/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/lite/lite.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index e0f82018d3b6e..c26783c4d8bb9 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
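The net effect of the series: `setup_dataloaders` returns objects typed as `DataLoader` that transparently move batches to the target device. A usage sketch, assuming the `LightningLite` API of this release line; the toy dataset, batch size, and accelerator choice are made up for illustration:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self) -> None:
        dataset = TensorDataset(torch.randn(64, 3))
        # setup_dataloaders wraps the loader; with move_to_device=True (the
        # default), each batch is placed on self.device as it is produced.
        dataloader = self.setup_dataloaders(DataLoader(dataset, batch_size=8))
        for (batch,) in dataloader:
            assert batch.device == self.device


Lite(accelerator="cpu").run()
```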