Add custom dataloader support with Lite #10279

Merged. 22 commits merged on Nov 1, 2021. Showing changes from 18 commits.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -119,6 +119,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Updated precision attributes in `DeepSpeedPlugin` ([#10164](https://github.com/PyTorchLightning/pytorch-lightning/pull/10164))
* Added the ability to return a result from rank 0 in `DDPSpawnPlugin.spawn` ([#10162](https://github.com/PyTorchLightning/pytorch-lightning/pull/10162))
* Added `pytorch_lightning.lite` package ([#10175](https://github.com/PyTorchLightning/pytorch-lightning/pull/10175))
* Made the `_LiteDataLoader` an iterator and added support for custom dataloaders ([#10279](https://github.com/PyTorchLightning/pytorch-lightning/pull/10279))
- Added `use_omegaconf` argument to `save_hparams_to_yaml` plugin ([#9170](https://github.com/PyTorchLightning/pytorch-lightning/pull/9170))
- Added `ckpt_path` argument for `Trainer.fit()` ([#10061](https://github.com/PyTorchLightning/pytorch-lightning/pull/10061))
- Added `auto_device_count` method to `Accelerators` ([#10222](https://github.com/PyTorchLightning/pytorch-lightning/pull/10222))
32 changes: 19 additions & 13 deletions pytorch_lightning/lite/lite.py
@@ -25,7 +25,12 @@
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.lite.wrappers import (
_LiteDataLoader,
_LiteModule,
_LiteOptimizer,
_replace_dataloader_init_function,
)
from pytorch_lightning.plugins import (
DDPShardedPlugin,
DDPSpawnPlugin,
@@ -183,7 +188,7 @@ def setup(

def setup_dataloaders(
self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> Union[DataLoader, List[DataLoader], Iterable]:
) -> Union[Iterable, List[Iterable]]:
"""Setup one or multiple dataloaders for accelerated training. If you need different settings for each
dataloader, call this method individually for each one.

@@ -208,7 +213,7 @@ def setup_dataloaders(

def _setup_dataloader(
self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
) -> Union[Iterable, DataLoader]:
) -> Iterable:
"""Setup a single dataloader for accelerated training.

Args:
@@ -233,17 +238,18 @@ def _setup_dataloader(
)
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
device = self.device if move_to_device else None
if isinstance(self._strategy, TPUSpawnPlugin):
dataloader = DataLoader(**kwargs)
else:
dataloader = _LiteDataLoader(device=device, **kwargs)

dataloader_kwargs = TrainerDataLoadingMixin._get_dataloader_init_kwargs(dataloader, sampler)
try:
dataloader = type(dataloader)(**dataloader_kwargs)
except TypeError:
dataloader_kwargs.pop("dataset")
dataloader = type(dataloader)(**dataloader_kwargs)
# add worker_init_fn for correct seeding in worker processes
TrainerDataLoadingMixin._auto_add_worker_init_fn(dataloader, self.global_rank)

return self._strategy.process_dataloader(dataloader)
return _LiteDataLoader(
dataloader=self._strategy.process_dataloader(dataloader),
device=self.device if move_to_device and not isinstance(self._strategy, TPUSpawnPlugin) else None,
)

def backward(self, tensor: Tensor, *args: Any, model: Optional[_LiteModule] = None, **kwargs: Any) -> None:
"""Replaces ``loss.backward()`` in your training loop. Handles precision and automatically for you.
@@ -400,7 +406,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
return run_method(*args, **kwargs)

def _run_with_sharded_context(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
with self._strategy.model_sharded_context():
with self._strategy.model_sharded_context(), _replace_dataloader_init_function():
return run_method(*args, **kwargs)

def _set_plugin_specific_precision_variables(self) -> None:
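As context for the hunk above: `_setup_dataloader` now re-instantiates the user's own dataloader class instead of always constructing a plain `DataLoader` or `_LiteDataLoader` subclass. Below is a condensed sketch of that retry logic; the `rebuild_dataloader` name is hypothetical (the PR inlines this in `_setup_dataloader`), and `dataloader_kwargs` stands in for the dict returned by `TrainerDataLoadingMixin._get_dataloader_init_kwargs`.

from typing import Any, Dict

from torch.utils.data import DataLoader


def rebuild_dataloader(dataloader: DataLoader, dataloader_kwargs: Dict[str, Any]) -> DataLoader:
    # Re-create the user's DataLoader subclass with the (possibly distributed) sampler injected.
    try:
        return type(dataloader)(**dataloader_kwargs)
    except TypeError:
        # Custom subclasses that build their dataset internally may not accept a `dataset`
        # keyword, so retry without it; the remaining constructor arguments were captured
        # as instance attributes by the patched __init__ (see wrappers.py below).
        dataloader_kwargs.pop("dataset")
        return type(dataloader)(**dataloader_kwargs)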
71 changes: 61 additions & 10 deletions pytorch_lightning/lite/wrappers.py
@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Generator, Iterator, Optional, Union
import functools
import inspect
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Type, Union

import torch
from torch import nn as nn
@@ -100,27 +103,75 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
return output


class _LiteDataLoader(DataLoader):
def __init__(self, device: Optional[torch.device] = None, **dl_kwargs: Any) -> None:
"""The LiteDataLoader is an extension of the PyTorch :class:`~torch.utils.data.DataLoader` that adds
additional features such as moving the data to the device automatically.
def _wrap_init(f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
params = dict(inspect.signature(module._old_init).parameters)
params.pop("args")
params.pop("kwargs")
for init_name, init_arg in zip(params, args):
setattr(module, init_name, init_arg)
f(module, *args, **kwargs)

return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
subclass_list = []

def recurse(cl: Type[Any]) -> None:
for subclass in cl.__subclasses__():
subclass_list.append(subclass)
recurse(subclass)

recurse(cls)
return set(subclass_list)


def _enable_class(cls: Type[Any]) -> None:
cls._old_init = cls.__init__
cls.__init__ = _wrap_init(cls.__init__)


def _disable_class(cls: Type[Any]) -> None:
cls.__init__ = cls._old_init


@contextmanager
def _replace_dataloader_init_function() -> Generator:
"""This context manager is used to support custom :class:`~torch.utils.data.DataLoader."""
for subclass in _get_all_subclasses(DataLoader):
_enable_class(subclass)
yield
for subclass in _get_all_subclasses(DataLoader):
_disable_class(subclass)


class _LiteDataLoader:
def __init__(self, dataloader: Iterable, device: Optional[torch.device] = None) -> None:
"""The LiteDataLoader is an extension of an Iterator. It would move move the data to the device
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
automatically if the device is specified.

Args:
dataloader: The current dataloader to be used.
device: The device to which the data should be moved. By default the device is `None` and no data
transfers will be made (identical behavior as :class:`~torch.utils.data.DataLoader`).
**dl_kwargs: Accepts all arguments that the PyTorch :class:`~torch.utils.data.DataLoader` accepts.
"""
super().__init__(**dl_kwargs)
super().__init__()
self.__dict__.update(getattr(dataloader, "__dict__", {}))
self._dataloader = dataloader
self._device = device
self._dataloader_iter: Optional[Iterator] = None

@property
def device(self) -> Optional[torch.device]:
return self._device

def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:

Contributor: Won't this possibly get overridden by the __dict__.update above?

Contributor: The __* methods aren't affected by a direct update to __dict__.

iterator = super().__iter__()
dataloader_iter = iter(self._dataloader)
if self._device is None:
return iterator
return dataloader_iter

for item in iterator:
for item in dataloader_iter:
yield move_data_to_device(item, self._device)
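To make the patching mechanism above concrete, here is a hedged, standalone sketch of what `_replace_dataloader_init_function` enables: while the context manager is active, positional constructor arguments of any `DataLoader` subclass are recorded as attributes on the instance, which is what later lets `_setup_dataloader` rebuild the loader with a new sampler. The `RangeLoader` class is invented for this illustration.

from torch.utils.data import DataLoader

from pytorch_lightning.lite.wrappers import _replace_dataloader_init_function


class RangeLoader(DataLoader):
    # Illustrative subclass that builds its own dataset from a positional argument.
    def __init__(self, length: int, *args, **kwargs):
        super().__init__(range(length), *args, **kwargs)


with _replace_dataloader_init_function():
    loader = RangeLoader(4, batch_size=2)

# The wrapped __init__ stored the positional argument under its parameter name,
# so type(loader)(**dataloader_kwargs) can later re-create an equivalent loader.
assert loader.length == 4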
53 changes: 53 additions & 0 deletions tests/lite/test_lite.py
@@ -164,6 +164,59 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


def test_setup_custom_dataloaders():
"""Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
lite = EmptyLite()

class CustomDataLoader(DataLoader):
def __init__(self, value: int = 2, *args, **kwargs):
self.value = value
super().__init__(range(value), *args, **kwargs)

dataloader = CustomDataLoader(2, batch_size=2)

# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
assert lite_dataloader._dataloader_iter is None
assert lite_dataloader.value == 2
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))

class CustomDataLoader2(DataLoader):
def __init__(self, range, *args, **kwargs):
self.range = range
super().__init__(range, *args, **kwargs)

dataloader = CustomDataLoader2(range(2), batch_size=2)

# single dataloader
lite_dataloader = lite.setup_dataloaders(dataloader)
assert lite_dataloader._dataloader
assert lite_dataloader._dataloader_iter is None
batch0 = next(iter(lite_dataloader))
assert torch.equal(batch0, torch.tensor([0, 1]))

class CustomDataLoader(DataLoader):
def __init__(self, value: int, *args, **kwargs):
super().__init__(range(value), *args, **kwargs)

class LiteWithCustomDataLoader(LightningLite):
def run(self):
# This doesn't fail as the context manager would save all the arguments provided
# to the dataloaders.
dataloader = CustomDataLoader(2, batch_size=2)
self.setup_dataloaders(dataloader)

LiteWithCustomDataLoader().run()

with pytest.raises(
MisconfigurationException, match="Trying to inject `DistributedSampler` into the `CustomDataLoader` instance"
):
dataloader = CustomDataLoader(2, batch_size=2)
lite_dataloader = lite.setup_dataloaders(dataloader)


def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
lite = EmptyLite()
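The tests above drive the feature from inside a run method. A rough end-to-end sketch of the user-facing flow, not taken from the PR (the `Lite` subclass, the toy dataset, and the assertion are illustrative):

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        # setup_dataloaders wraps the loader in a _LiteDataLoader that moves each
        # batch onto self.device while it is being iterated.
        dataloader = DataLoader(torch.arange(8), batch_size=2)
        dataloader = self.setup_dataloaders(dataloader)
        for batch in dataloader:
            assert batch.device == self.device


Lite().run()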
14 changes: 12 additions & 2 deletions tests/lite/test_wrappers.py
@@ -15,6 +15,7 @@

import pytest
import torch
from torch.utils.data.dataloader import DataLoader

from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -73,8 +74,8 @@ def test_lite_dataloader_device_placement(src_device, dest_device):
sample1 = torch.tensor(1, device=src_device)
sample2 = {"data": torch.tensor(2, device=src_device)}
sample3 = {"data": torch.tensor(3, device=src_device)}
data = [sample0, sample1, sample2, sample3]
lite_dataloader = _LiteDataLoader(device=dest_device, dataset=data, batch_size=2)
dataloader = DataLoader([sample0, sample1, sample2, sample3], batch_size=2)
lite_dataloader = _LiteDataLoader(dataloader=dataloader, device=dest_device)
iterator = iter(lite_dataloader)

batch0 = next(iterator)
batch1 = next(iterator)
assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))

with pytest.raises(StopIteration):
batch1 = next(iterator)

lite_dataloader = _LiteDataLoader(dataloader=[sample0, sample1, sample2, sample3], device=dest_device)
iterator = iter(lite_dataloader)

batch0 = next(iterator)
assert batch0 == 0


def test_lite_optimizer_wraps():
"""Test that the LiteOptimizer fully wraps the optimizer."""
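The device-placement test above relies on `move_data_to_device` handling nested containers. A minimal illustration of that behaviour, assuming the utility is importable from `pytorch_lightning.utilities` (the sample batch is invented for this example):

import torch

from pytorch_lightning.utilities import move_data_to_device

batch = {"data": torch.tensor([2, 3]), "meta": ["id-0", "id-1"]}
moved = move_data_to_device(batch, torch.device("cpu"))
# Tensors nested inside dicts and lists are moved recursively; non-tensor leaves pass through unchanged.
assert moved["data"].device.type == "cpu"
assert moved["meta"] == ["id-0", "id-1"]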