Add custom dataloader support with Lite (#10279)
tchaton authored Nov 1, 2021
1 parent 828b531 commit facaff9
Showing 5 changed files with 144 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -119,6 +119,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164))
* Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162))
* Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175))
* Made the `_LiteDataLoader` an iterator and added support for custom dataloaders ([#10279](https://github.com/PyTorchLightning/pytorch-lightning/pull/10279))
- Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
- Added `ckpt_path` argument for `Trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
32 changes: 19 additions & 13 deletions pytorch_lightning/lite/lite.py
@@ -25,7 +25,12 @@
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.lite.wrappers import (
_LiteDataLoader,
_LiteModule,
_LiteOptimizer,
_replace_dataloader_init_method,
)
from pytorch_lightning.plugins import (
DDPShardedPlugin,
DDPSpawnPlugin,
@@ -183,7 +188,7 @@ def setup(

def setup_dataloaders(
self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> Union[DataLoader, List[DataLoader], Iterable]:
) -> Union[Iterable, List[Iterable]]:
"""Setup one or multiple dataloaders for accelerated training. If you need different settings for each
dataloader, call this method individually for each one.
@@ -208,7 +213,7 @@ def setup_dataloaders(

def _setup_dataloader(
self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> Union[Iterable, DataLoader]:
) -> Iterable:
"""Setup a single dataloader for accelerated training.
Args:
@@ -233,17 +238,18 @@ def _setup_dataloader(
)
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
device = self.device if move_to_device else None
if isinstance(self._strategy, TPUSpawnPlugin):
dataloader = DataLoader(**kwargs)
else:
dataloader = _LiteDataLoader(device=device, **kwargs)

dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
try:
dataloader = type(dataloader)(**dataloader_kwargs)
except TypeError:
dataloader_kwargs.pop("dataset")
dataloader = type(dataloader)(**dataloader_kwargs)
# add worker_init_fn for correct seeding in worker processes
TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)

return self._strategy.process_dataloader(dataloader)
return _LiteDataLoader(
dataloader=self._strategy.process_dataloader(dataloader),
device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
)

def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
@@ -400,7 +406,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
return run_method(*args, **kwargs)

def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
with self._strategy.model_sharded_context():
with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
return run_method(*args, **kwargs)

def _set_plugin_specific_precision_variables(self) -> None:
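For context, here is a minimal usage sketch (not part of this commit) of what the change to `setup_dataloaders` enables: a `DataLoader` subclass with an extra required constructor argument, set up inside `run()`. The names `RangeLoader` and `MyLite` are illustrative only, and the snippet assumes default CPU settings.

```python
import torch
from torch.utils.data import DataLoader

from pytorch_lightning.lite import LightningLite


class RangeLoader(DataLoader):
    """Hypothetical custom dataloader with an extra required ``length`` argument."""

    def __init__(self, length: int, *args, **kwargs):
        super().__init__(range(length), *args, **kwargs)


class MyLite(LightningLite):
    def run(self):
        # Inside ``run`` the ``_replace_dataloader_init_method`` context manager is active,
        # so ``length`` is captured and ``setup_dataloaders`` can re-create the loader
        # with the appropriate sampler.
        loader = self.setup_dataloaders(RangeLoader(8, batch_size=4))
        for batch in loader:
            # The returned ``_LiteDataLoader`` moves each batch to ``self.device``.
            assert batch.device == self.device


MyLite().run()
```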
71 changes: 61 additions & 10 deletions pytorch_lightning/lite/wrappers.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, Iterator, Optional, Union
import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type, Union

import torch
from torch import nn as nn
@@ -100,27 +103,75 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
return output


class _LiteDataLoader(DataLoader):
def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None:
"""The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds
additional features such as moving the data to the device automatically.
def _wrap_init(f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
params = dict(inspect.signature(module._old_init).parameters)
params.pop("args")
params.pop("kwargs")
for init_name, init_arg in zip(params, args):
setattr(module, init_name, init_arg)
f(module, *args, **kwargs)

return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
subclass_list = []

def recurse(cl: Type[Any]) -> None:
for subclass in cl.__subclasses__():
subclass_list.append(subclass)
recurse(subclass)

recurse(cls)
return set(subclass_list)


def _enable_class(cls: Type[Any]) -> None:
cls._old_init = cls.__init__
cls.__init__ = _wrap_init(cls.__init__)


def _disable_class(cls: Type[Any]) -> None:
cls.__init__ = cls._old_init
del cls._old_init


@contextmanager
def _replace_dataloader_init_method() -> Generator:
"""This context manager is used to support custom :class:`~torch.utils.data.DataLoader."""
for subclass in _get_all_subclasses(DataLoader):
_enable_class(subclass)
yield
for subclass in _get_all_subclasses(DataLoader):
_disable_class(subclass)


class _LiteDataLoader:
def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None:
"""The LiteDataLoader is an extension of an Iterator. It would move the data to the device automatically if
the device is specified.
Args:
dataloader: The current dataloader to be used.
device: The device to which the data should be moved. By default the device is `None` and no data
transfers will be made (identical behavior to :class:`~torch.utils.data.DataLoader`).
**dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts.
"""
super().__init__(**dl_kwargs)
super().__init__()
self.__dict__.update(getattr(dataloader, "__dict__", {}))
self._dataloader = dataloader
self._device = device

@property
def device(self) -> Optional[torch.device]:
return self._device

def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
iterator = super().__iter__()
dataloader_iter = iter(self._dataloader)
if self._device is None:
return iterator
return dataloader_iter

for item in iterator:
for item in dataloader_iter:
yield move_data_to_device(item, self._device)
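The mechanism above is deliberately simple: while the context manager is active, `__init__` of every `DataLoader` subclass is patched so that named positional arguments are stored as instance attributes before the real constructor runs, which is what later allows the loader to be re-instantiated with a new sampler. A self-contained sketch of the same idea follows (`record_init_args` and `RangeLoader` are illustrative names, not part of this commit; the real `_wrap_init` reads the signature from `_old_init` instead).

```python
import functools
import inspect

from torch.utils.data import DataLoader


def record_init_args(init):
    # Store every named positional argument on the instance before running the
    # original ``__init__`` (same idea as ``_wrap_init`` in this commit).
    @functools.wraps(init)
    def wrapper(obj, *args, **kwargs):
        params = dict(inspect.signature(init).parameters)
        for name in ("self", "args", "kwargs"):
            params.pop(name, None)
        for arg_name, arg_value in zip(params, args):
            setattr(obj, arg_name, arg_value)
        init(obj, *args, **kwargs)

    return wrapper


class RangeLoader(DataLoader):
    def __init__(self, length: int, *args, **kwargs):
        super().__init__(range(length), *args, **kwargs)


RangeLoader.__init__ = record_init_args(RangeLoader.__init__)
loader = RangeLoader(4, batch_size=2)
# The captured constructor argument is now available for re-instantiation.
assert loader.length == 4
```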
51 changes: 51 additions & 0 deletions tests/lite/test_lite.py
@@ -164,6 +164,57 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


def test_setup_custom_dataloaders():
"""Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
lite = EmptyLite()

class CustomDataLoader(DataLoader):
def __init__(self, value: int = 2, *args, **kwargs):
self.value = value
super().__init__(range(value), *args, **kwargs)

dataloader = CustomDataLoader(2, batch_size=2)

# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
assert lite_dataloader.value == 2
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))

class CustomDataLoader2(DataLoader):
def __init__(self, range, *args, **kwargs):
self.range = range
super().__init__(range, *args, **kwargs)

dataloader = CustomDataLoader2(range(2), batch_size=2)

# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))

class CustomDataLoader(DataLoader):
def __init__(self, value: int, *args, **kwargs):
super().__init__(range(value), *args, **kwargs)

class LiteWithCustomDataLoader(LightningLite):
def run(self):
# This doesn't fail as the context manager would save all the arguments provided
# to the dataloaders.
dataloader = CustomDataLoader(2, batch_size=2)
self.setup_dataloaders(dataloader)

LiteWithCustomDataLoader().run()

with pytest.raises(
MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
):
dataloader = CustomDataLoader(2, batch_size=2)
lite_dataloader = lite.setup_dataloaders(dataloader)


def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
lite = EmptyLite()
14 changes: 12 additions & 2 deletions tests/lite/test_wrappers.py
@@ -15,6 +15,7 @@

import pytest
import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -73,8 +74,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
sample1 = torch.tensor(1, device=src_device)
sample2 = {"data": torch.tensor(2, device=src_device)}
sample3 = {"data": torch.tensor(3, device=src_device)}
data = [sample0, sample1, sample2, sample3]
lite_dataloader = _LiteDataLoader(device=dest_device, dataset=data, batch_size=2)
dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=dest_device)
iterator = iter(lite_dataloader)

batch0 = next(iterator)
@@ -83,6 +84,15 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
batch1 = next(iterator)
assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))

with pytest.raises(StopIteration):
batch1 = next(iterator)

lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device)
iterator = iter(lite_dataloader)

batch0 = next(iterator)
assert batch0 == 0


def test_lite_optimizer_wraps():
"""Test that the LiteOptimizer fully wraps the optimizer."""
