From 8ea39d2c8f68cc33273c3431a310a262e2240cf9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sun, 21 Nov 2021 02:33:13 +0100
Subject: [PATCH] LiteDataLoader code improvements and docs (#10625)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/lite/lite.py     | 10 ++++++----
 pytorch_lightning/lite/wrappers.py | 21 +++++++++------------
 2 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index ca88095dfc673..f5fdd0221cbe3 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -187,7 +187,7 @@ def setup(
 
     def setup_dataloaders(
         self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Union[Iterable, List[Iterable]]:
+    ) -> Union[DataLoader, List[DataLoader]]:
         """Setup one or multiple dataloaders for accelerated training. If you need different settings for each
         dataloader, call this method individually for each one.
 
@@ -212,7 +212,7 @@ def setup_dataloaders(
 
     def _setup_dataloader(
         self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
-    ) -> Iterable:
+    ) -> DataLoader:
         """Setup a single dataloader for accelerated training.
 
         Args:
@@ -245,7 +245,9 @@ def _setup_dataloader(
 
         dataloader = self._strategy.process_dataloader(dataloader)
         device = self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None
-        return _LiteDataLoader(dataloader=dataloader, device=device)
+        lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=device)
+        lite_dataloader = cast(DataLoader, lite_dataloader)
+        return lite_dataloader
 
     def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
         """Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py
index 6b8e44b610352..3cd2f5eb69712 100644
--- a/pytorch_lightning/lite/wrappers.py
+++ b/pytorch_lightning/lite/wrappers.py
@@ -15,7 +15,7 @@
 import inspect
 from contextlib import contextmanager
 from itertools import chain
-from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, Optional, Set, Type, Union
 
 import torch
 from torch import nn as nn
@@ -157,29 +157,26 @@ def _replace_dataloader_init_method() -> Generator:
 
 
 class _LiteDataLoader:
-    def __init__(self, dataloader: Union[Iterable, DataLoader], device: Optional[torch.device] = None) -> None:
-        """The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if
-        the device is specified.
+    def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
+        """The LiteDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the
+        device automatically if the device is specified.
 
         Args:
-            dataloader: The current dataloader to be used.
+            dataloader: The dataloader to wrap
             device: The device to which the data should be moved. By default the device is `None` and no data
                 transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
         """
-        super().__init__()
-        self.__dict__.update(getattr(dataloader, "__dict__", {}))
+        self.__dict__.update(dataloader.__dict__)
         self._dataloader = dataloader
         self._device = device
 
-    def __len__(self) -> Union[int, float]:
-        if isinstance(self._dataloader, Sized):
-            return len(self._dataloader)
-        return float("inf")
-
     @property
     def device(self) -> Optional[torch.device]:
         return self._device
 
+    def __len__(self) -> int:
+        return len(self._dataloader)
+
     def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
         iterator = iter(self._dataloader)
         if self._device is None:
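For reference, a minimal usage sketch of the API this patch touches: ``setup_dataloaders()`` is now annotated to return a ``DataLoader`` (hence the ``cast``), while the object handed back at runtime is a ``_LiteDataLoader`` that copies the wrapped loader's ``__dict__``, forwards ``__len__``, and moves each batch to the target device during iteration. The ``RandomDataset`` and ``MyLite`` names below are hypothetical and exist only for illustration; the rest uses the public ``LightningLite`` API as it appears in this diff.

# Hypothetical usage sketch; RandomDataset and MyLite are illustrative names only.
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning.lite import LightningLite


class RandomDataset(Dataset):
    def __len__(self) -> int:
        return 64

    def __getitem__(self, index: int) -> torch.Tensor:
        return torch.randn(32)


class MyLite(LightningLite):
    def run(self) -> None:
        # setup_dataloaders() is typed as returning a DataLoader, which keeps static
        # checkers happy, but at runtime it is a _LiteDataLoader wrapper.
        dataloader = self.setup_dataloaders(DataLoader(RandomDataset(), batch_size=4))
        assert len(dataloader) == 16  # __len__ is forwarded to the wrapped loader
        for batch in dataloader:
            # move_to_device defaults to True, so batches already live on self.device.
            assert batch.device == self.device


MyLite().run()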