Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support re-instantiation for custom DataLoader in Lightning #10680

Merged
merged 55 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
95d6889
move helpers to the bottom
awaelchli Nov 1, 2021
642b9c0
update docs for wrappers
awaelchli Nov 1, 2021
2c1bcfd
rename iterator variable
awaelchli Nov 1, 2021
29ae286
mention iterable in the docstring
awaelchli Nov 1, 2021
a10df79
update type
awaelchli Nov 1, 2021
134122e
add comment, improve readability
awaelchli Nov 1, 2021
16fc44d
add typing for generator
awaelchli Nov 1, 2021
8b352e2
update docs for LiteDataLoader
awaelchli Nov 1, 2021
cc2673a
every Python object has a dict
awaelchli Nov 1, 2021
61bfb09
wrap_init code improvement
awaelchli Nov 1, 2021
f047032
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2021
5e8d88e
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 2, 2021
f5b19b7
add changes from master
awaelchli Nov 2, 2021
53b15af
change order for review
awaelchli Nov 2, 2021
98834e4
fix iterator
awaelchli Nov 2, 2021
e4dd939
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2021
4b29301
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 2, 2021
c3456a7
simplify reference to old_init
awaelchli Nov 2, 2021
d60cf92
inline code
awaelchli Nov 2, 2021
0563c8c
add docs
awaelchli Nov 2, 2021
3e3d7c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2021
7411a62
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 3, 2021
23fab9e
wip
awaelchli Nov 3, 2021
ccfbf56
wip
awaelchli Nov 3, 2021
cf5923b
wip
awaelchli Nov 3, 2021
66961f3
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 5, 2021
a11c358
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2021
cf6131e
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 5, 2021
5f8e67c
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 18, 2021
6c05ab0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2021
265b8b9
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 22, 2021
c9a05e0
update docs
awaelchli Nov 22, 2021
6cc836b
update signature
awaelchli Nov 22, 2021
9c53ddc
use init reference directly
awaelchli Nov 22, 2021
bb40f54
remove
awaelchli Nov 22, 2021
5b4c1d7
remove
awaelchli Nov 22, 2021
0c69d3a
update message
awaelchli Nov 22, 2021
da5f1a0
unused imprts
awaelchli Nov 22, 2021
527fe52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2021
30eb696
move utilities and patch Lightning dataloader methods
awaelchli Nov 22, 2021
83d1884
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2021
ccfcf4f
update test
awaelchli Nov 23, 2021
9df013c
Merge remote-tracking branch 'origin/feature/patch-dataloader-init' i…
awaelchli Nov 23, 2021
fa1145a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2021
79e91c2
Merge branch 'master' into feature/patch-dataloader-init
awaelchli Nov 23, 2021
9c29200
unused import
awaelchli Nov 23, 2021
4ddb496
add tests
awaelchli Nov 23, 2021
2821b24
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2021
a1a4e98
update changelog
awaelchli Nov 23, 2021
a3c5a8c
Merge remote-tracking branch 'origin/feature/patch-dataloader-init' i…
awaelchli Nov 23, 2021
6d704e7
save subclasses
awaelchli Nov 24, 2021
7aa0efa
add a comment
awaelchli Nov 24, 2021
bb513c0
simplify test
awaelchli Nov 24, 2021
6a324cb
Merge branch 'master' into feature/patch-dataloader-init
justusschock Nov 24, 2021
dec535f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 24, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))


-
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))


-
Expand Down
14 changes: 7 additions & 7 deletions pytorch_lightning/lite/lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import (
_LiteDataLoader,
_LiteModule,
_LiteOptimizer,
_replace_dataloader_init_method,
)
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, _update_dataloader, has_iterable_dataset
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
_update_dataloader,
has_iterable_dataset,
)
from pytorch_lightning.utilities.device_parser import _parse_devices
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
Expand Down
49 changes: 1 addition & 48 deletions pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from contextlib import contextmanager
from itertools import chain
from typing import Any, Callable, Generator, Iterator, Optional, Set, Type, Union
from typing import Any, Callable, Generator, Iterator, Optional, Union

import torch
from torch import nn as nn
Expand Down Expand Up @@ -110,49 +106,6 @@ def _convert_float_tensor(t: Tensor) -> Tensor:
return output


def _wrap_init(init: Callable) -> Callable:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"""Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
:class:`~torch.utils.data.DataLoader`."""

@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
params = dict(inspect.signature(obj.__init__).parameters)
params.pop("args", None)
params.pop("kwargs", None)
for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
setattr(obj, arg_name, arg_value)
init(obj, *args, **kwargs)

return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
"""Returns a list of all classes that inherit directly or indirectly from the given class."""
subclasses = set()

def recurse(cl: Type[Any]) -> None:
for subclass in cl.__subclasses__():
subclasses.add(subclass)
recurse(subclass)

recurse(cls)
return subclasses


@contextmanager
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
for subclass in _get_all_subclasses(DataLoader):
subclass._old_init = subclass.__init__
subclass.__init__ = _wrap_init(subclass.__init__)
yield
for subclass in _get_all_subclasses(DataLoader):
subclass.__init__ = subclass._old_init
del subclass._old_init


class _LiteDataLoader:
def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
"""The LiteDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the
Expand Down
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
Expand Down Expand Up @@ -430,7 +431,10 @@ def request_dataloader(

hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = source.dataloader()
with _replace_dataloader_init_method():
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
# attributes on the instance in case the dataloader needs to be re-instantiated later by Ligtning
dataloader = source.dataloader()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.training_type_plugin.barrier("get_dataloaders")
Expand Down
49 changes: 48 additions & 1 deletion pytorch_lightning/utilities/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import os
from contextlib import contextmanager
from functools import partial
from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Union
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
Expand Down Expand Up @@ -305,3 +308,47 @@ def _dataloader_init_kwargs_resolve_sampler(
def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


def _wrap_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
:class:`~torch.utils.data.DataLoader`."""

@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
params = dict(inspect.signature(obj.__init__).parameters)
params.pop("args", None)
params.pop("kwargs", None)
for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
setattr(obj, arg_name, arg_value)
init(obj, *args, **kwargs)

return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
"""Returns a list of all classes that inherit directly or indirectly from the given class."""
subclasses = set()

def recurse(cl: Type[Any]) -> None:
for subclass in cl.__subclasses__():
subclasses.add(subclass)
recurse(subclass)

recurse(cls)
return subclasses


@contextmanager
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
subclasses = _get_all_subclasses(DataLoader)
for subclass in subclasses:
subclass._old_init = subclass.__init__
subclass.__init__ = _wrap_init(subclass.__init__)
yield
for subclass in subclasses:
subclass.__init__ = subclass._old_init
del subclass._old_init
34 changes: 9 additions & 25 deletions tests/lite/test_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,32 +164,16 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


def test_setup_dataloaders_with_custom_type():
"""Test that Lite intercepts arguments passed to custom subclasses of torch.utils.DataLoader and sets them as
attributes."""

class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute1 = attribute1
super().__init__(*args, **kwargs)

class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute1, attribute2, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute2 = attribute2
super().__init__(attribute1, *args, **kwargs)

class LiteWithCustomDataLoader(LightningLite):
@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""

class Lite(LightningLite):
def run(self):
dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"
assert dataloader.attribute2 == "attribute2"
lite_dataloader = self.setup_dataloaders(dataloader)
assert lite_dataloader.attribute1 == "attribute1"
assert lite_dataloader.attribute2 == "attribute2"

LiteWithCustomDataLoader().run()
ctx_manager().__enter__.assert_called_once()

Lite().run()
ctx_manager().__exit__.assert_called_once()


def test_setup_dataloaders_raises_for_unknown_custom_args():
Expand Down
29 changes: 9 additions & 20 deletions tests/trainer/test_data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,16 @@


@RunIf(skip_windows=True)
@pytest.mark.parametrize("mode", (1, 2, 3))
@pytest.mark.parametrize("mode", (1, 2))
def test_replace_distributed_sampler(tmpdir, mode):
class IndexedRandomDataset(RandomDataset):
def __getitem__(self, index):
return self.data[index]

class CustomDataLoader(DataLoader):
def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)

class FailureCustomDataLoader(DataLoader):
def __init__(self, num_features, dataset, *args, **kwargs):
# argument `num_features` unused on purpose
# it gets automatically captured by _replace_dataloader_init_method()
super().__init__(dataset, *args, **kwargs)

class CustomBatchSampler(BatchSampler):
Expand All @@ -59,11 +56,11 @@ def on_test_start(self) -> None:
dataloader = self.trainer.test_dataloaders[0]
assert isinstance(dataloader, CustomDataLoader)
batch_sampler = dataloader.batch_sampler
if self._mode == 2:
if self._mode == 1:
assert isinstance(batch_sampler, CustomBatchSampler)
# the batch_size is set on the batch sampler
assert dataloader.batch_size is None
elif self._mode == 3:
elif self._mode == 2:
assert type(batch_sampler) is BatchSampler
assert dataloader.batch_size == self._mode
assert batch_sampler.batch_size == self._mode
Expand All @@ -74,15 +71,12 @@ def on_test_start(self) -> None:
def create_dataset(self):
dataset = IndexedRandomDataset(32, 64)
if self._mode == 1:
# this case will raise an error
return FailureCustomDataLoader(32, dataset)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
if self._mode == 2:
# with a custom batch sampler
batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True)
batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=1, drop_last=True)
return CustomDataLoader(32, dataset, batch_sampler=batch_sampler)
elif self._mode == 3:
elif self._mode == 2:
# with no batch sampler provided
return CustomDataLoader(32, dataset, batch_size=3, drop_last=True)
return CustomDataLoader(32, dataset, batch_size=2, drop_last=True)

def test_dataloader(self):
return [self.create_dataset()] * self._numbers_test_dataloaders
Expand All @@ -93,12 +87,7 @@ def test_dataloader(self):
trainer = Trainer(
default_root_dir=tmpdir, limit_test_batches=2, strategy="ddp_find_unused_parameters_false", num_processes=1
)
if mode == 1:
match = escape("missing attributes are ['num_features']")
with pytest.raises(MisconfigurationException, match=match):
trainer.test(model)
else:
trainer.test(model)
trainer.test(model)


class TestSpawnBoringModel(BoringModel):
Expand Down
27 changes: 27 additions & 0 deletions tests/utilities/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import (
_replace_dataloader_init_method,
extract_batch_size,
get_len,
has_iterable_dataset,
Expand Down Expand Up @@ -112,3 +113,29 @@ def test_has_len_all_rank():
assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model)

assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)


def test_replace_dataloader_init_method():
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"""Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
sets them as attributes."""

class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute1 = attribute1
super().__init__(*args, **kwargs)

class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute1, attribute2, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute2 = attribute2
super().__init__(attribute1, *args, **kwargs)

with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"

with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"
assert dataloader.attribute2 == "attribute2"