
Commit

Allowed custom BatchSamplers when instantiated in *_dataloader hook (#13640)

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
3 people authored Jul 27, 2022
1 parent c58d351 commit 95f5f17
Showing 8 changed files with 317 additions and 81 deletions.
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -348,6 +348,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))

- Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook ([#13640](https://github.com/PyTorchLightning/pytorch-lightning/pull/13640))


- Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))

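For context, here is a minimal sketch of the pattern this entry enables: a custom `BatchSampler` subclass instantiated inside a `*_dataloader` hook. The names `MyBatchSampler`, `shuffle_within_batch`, and `LitModel` are hypothetical and not part of this commit.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

from pytorch_lightning import LightningModule


class MyBatchSampler(BatchSampler):
    """A custom batch sampler with one extra constructor argument."""

    def __init__(self, sampler, batch_size, drop_last, shuffle_within_batch=False):
        super().__init__(sampler, batch_size, drop_last)
        self.shuffle_within_batch = shuffle_within_batch


class LitModel(LightningModule):
    def train_dataloader(self):
        dataset = list(range(64))
        sampler = SequentialSampler(dataset)
        batch_sampler = MyBatchSampler(sampler, batch_size=8, drop_last=False)
        # With this change, Lightning records the arguments passed to BatchSampler
        # subclasses created inside *_dataloader hooks, so it can re-instantiate them
        # later (for example, to inject a DistributedSampler).
        return DataLoader(dataset, batch_sampler=batch_sampler)
```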
8 changes: 5 additions & 3 deletions src/pytorch_lightning/lite/lite.py
@@ -22,7 +22,7 @@
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -35,7 +35,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
_replace_init_method,
_update_dataloader,
has_iterable_dataset,
)
@@ -403,7 +403,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:

def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()
with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method(
BatchSampler
):
return run_method(*args, **kwargs)

def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
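In Lite, the practical effect is that a dataloader (and its custom batch sampler) built inside `run()` can later be rebuilt by `setup_dataloaders()`, since `run()` now executes with both `__init__` patches active. A minimal sketch; `MyLite` and `MyBatchSampler` are hypothetical names, and a single-process CPU setup is assumed.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

from pytorch_lightning.lite import LightningLite


class MyBatchSampler(BatchSampler):
    """Hypothetical custom batch sampler (same idea as in the earlier sketch)."""


class MyLite(LightningLite):
    def run(self):
        dataset = list(range(64))
        batch_sampler = MyBatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
        dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
        # run() executes inside the patched-__init__ context managers shown above, so the
        # constructor arguments of both the DataLoader and the custom batch sampler were
        # captured, and setup_dataloaders() can rebuild them for the active strategy.
        dataloader = self.setup_dataloaders(dataloader)
        for batch in dataloader:
            ...


MyLite(accelerator="cpu").run()
```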
6 changes: 3 additions & 3 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -17,7 +17,7 @@
from typing import Any, Callable, Collection, List, Optional, Tuple, Union
from weakref import proxy

from torch.utils.data import DataLoader, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import pytorch_lightning as pl
@@ -31,7 +31,7 @@
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_is_dataloader_shuffled,
_replace_dataloader_init_method,
_replace_init_method,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
@@ -424,7 +424,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
"""
source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")

with _replace_dataloader_init_method():
with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
dataloader = source.dataloader()
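The comment above describes the capture mechanism. Roughly, a dataloader created while the patch is active carries the saved constructor state as attributes; here is a sketch of what that looks like, assuming `_replace_init_method` is imported from `pytorch_lightning.utilities.data` as added in this commit:

```python
from torch.utils.data import BatchSampler, DataLoader

from pytorch_lightning.utilities.data import _replace_init_method

with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
    loader = DataLoader(range(10), batch_size=2)

# The patched __init__ stored everything needed to re-instantiate the dataloader later.
assert loader.__pl_saved_args == (range(10),)
assert loader.__pl_saved_arg_names == ("dataset",)
assert loader.__pl_saved_kwargs == {"batch_size": 2}
assert "shuffle" in loader.__pl_saved_default_kwargs  # defaulted, never passed explicitly
assert loader.__dataset == range(10)  # stored because store_explicit_arg="dataset"
```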
14 changes: 1 addition & 13 deletions src/pytorch_lightning/utilities/auto_restart.py
@@ -16,15 +16,7 @@
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union

from torch.utils.data import (
BatchSampler,
Dataset,
DistributedSampler,
get_worker_info,
RandomSampler,
Sampler,
SequentialSampler,
)
from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler
from torch.utils.data.dataloader import (
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
@@ -757,10 +749,6 @@ def _validate_map_dataset(dataloader: DataLoader) -> None:
if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

batch_sampler = getattr(dataloader, "batch_sampler", None)
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
raise TypeError("Fault-tolerance supports only a `BatchSampler`.")

if type(sampler) is DistributedSampler and sampler.shuffle:
raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
elif type(sampler) is RandomSampler:
160 changes: 124 additions & 36 deletions src/pytorch_lightning/utilities/data.py
@@ -14,6 +14,7 @@
import functools
import inspect
import os
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
@@ -220,11 +221,11 @@ def _get_dataloader_init_args_and_kwargs(
if not isinstance(dataloader, DataLoader):
raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

was_wrapped = hasattr(dataloader, "__pl_dl_args")
was_wrapped = hasattr(dataloader, "__pl_saved_args")
if was_wrapped:
dl_args = dataloader.__pl_dl_args
dl_kwargs = dataloader.__pl_dl_kwargs
arg_names = dataloader.__pl_dl_arg_names
dl_args = dataloader.__pl_saved_args
dl_kwargs = dataloader.__pl_saved_kwargs
arg_names = dataloader.__pl_saved_arg_names
original_dataset = dataloader.__dataset # we have this saved from _wrap_init_method
else:
# get the dataloader instance attributes
@@ -323,6 +324,9 @@ def _dataloader_init_kwargs_resolve_sampler(
If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
`FastForwardSampler`.
If there are multiple devices in IPU mode, it is necessary to disallow a `BatchSampler` that isn't instantiated
automatically, since `poptorch.DataLoader` will try to increase the `batch_size`.
"""
fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
batch_sampler = getattr(dataloader, "batch_sampler")
@@ -341,11 +345,59 @@
"when running on multiple IPU devices."
)
elif type(batch_sampler) is not BatchSampler or is_predicting:
batch_sampler = type(batch_sampler)(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
batch_sampler_cls = type(batch_sampler)
if hasattr(batch_sampler, "__pl_saved_args"):
args = batch_sampler.__pl_saved_args
kwargs = batch_sampler.__pl_saved_kwargs
default_kwargs = batch_sampler.__pl_saved_default_kwargs
arg_names = batch_sampler.__pl_saved_arg_names

if is_predicting:
success, args, kwargs = _replace_value_in_saved_args(
"drop_last", False, args, kwargs, default_kwargs, arg_names
)
if not success:
rank_zero_warn(
f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however "
f"it seems the class `{batch_sampler_cls.__qualname__}` does not support it. "
"Your predictions might be incomplete. To mitigate this, expose `drop_last` in "
"the `__init__` method of your custom class."
)

success, args, kwargs = _replace_value_in_saved_args(
"sampler", sampler, args, kwargs, default_kwargs, arg_names
)
if not success:
raise TypeError(
"Trying to inject a modified sampler into the batch sampler; however, it seems the class "
f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate "
"this, expose an argument `sampler` in the `__init__` method of your custom class."
)

batch_sampler = batch_sampler_cls(*args, **kwargs)
else:
try:
batch_sampler = batch_sampler_cls(
sampler,
batch_size=batch_sampler.batch_size,
drop_last=(False if is_predicting else batch_sampler.drop_last),
)
except TypeError as e:
import re

match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e))
if not match:
# an unexpected `TypeError`, continue failure
raise

# There could either be too few or too many arguments. Customizing the message based on this doesn't
# make much sense since our MisconfigurationException is going to be raised from the original one.
raise MisconfigurationException(
"We tried to re-instantiate your custom batch sampler and failed. "
"To mitigate this, either follow the API of `BatchSampler` or instantiate "
"your custom batch sampler inside `*_dataloader` hooks of your module."
) from e

if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

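As the error messages above indicate, re-instantiation works out of the box when a custom batch sampler keeps `sampler` and `drop_last` in its `__init__` signature and forwards them to the base class. A hedged sketch; the class name and the extra `bucket_size` argument are illustrative:

```python
from torch.utils.data import BatchSampler


class BucketingBatchSampler(BatchSampler):
    # Exposing `sampler`, `batch_size` and `drop_last` lets Lightning swap in a distributed
    # sampler and force `drop_last=False` when predicting; the extra `bucket_size` argument
    # is carried along through the captured args/kwargs.
    def __init__(self, sampler, batch_size, drop_last, bucket_size=32):
        super().__init__(sampler, batch_size, drop_last)
        self.bucket_size = bucket_size
```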
@@ -368,39 +420,73 @@ def _dataloader_init_kwargs_resolve_sampler(
return {"sampler": sampler, "shuffle": False, "batch_sampler": None}


def _replace_value_in_saved_args(
replace_key: str,
replace_value: Any,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
default_kwargs: Dict[str, Any],
arg_names: Tuple[str, ...],
) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
"""Tries to replace an argument value in a saved list of args and kwargs.
Returns a tuple indicating the success of the operation and the modified saved args and kwargs.
"""

if replace_key in arg_names:
replace_index = arg_names.index(replace_key)
args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :]
return True, args, kwargs
elif replace_key in kwargs or replace_key in default_kwargs:
kwargs[replace_key] = replace_value
return True, args, kwargs

return False, args, kwargs

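A small illustration of how this helper rewrites captured constructor state. The values are made up; `CustomBatchSampler` is a hypothetical class whose arguments were captured by the patched `__init__`:

```python
from torch.utils.data import SequentialSampler

from pytorch_lightning.utilities.data import _replace_value_in_saved_args

# Suppose a batch sampler was created as `CustomBatchSampler(sampler, 8, drop_last=True)`;
# the patched __init__ would then have captured roughly this state:
sampler = SequentialSampler(range(16))
saved_args = (sampler, 8)                    # positional arguments
saved_arg_names = ("sampler", "batch_size")  # their parameter names
saved_kwargs = {"drop_last": True}           # explicitly passed keyword arguments
saved_default_kwargs = {}                    # keyword arguments left at their defaults

# Injecting `drop_last=False` (as done when predicting) only needs to touch the kwargs:
ok, new_args, new_kwargs = _replace_value_in_saved_args(
    "drop_last", False, saved_args, saved_kwargs, saved_default_kwargs, saved_arg_names
)
assert ok and new_args == saved_args and new_kwargs == {"drop_last": False}
```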

def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


def _wrap_dataloader_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
of custom subclasses."""
def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
"""Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
# We need to inspect `init`, as inspecting `obj.__init__`
# can lead to inspecting the wrong function with multiple inheritance
params = inspect.signature(init).parameters
param_names = tuple(
param.name

parameters_defaults = OrderedDict(
(param.name, param.default)
for param in params.values()
if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
)
param_names = param_names[: len(args)]

if not hasattr(obj, "__pl_dl_args"):
obj.__pl_dl_args = args
obj.__pl_dl_kwargs = kwargs
obj.__pl_dl_arg_names = param_names
param_names = tuple(parameters_defaults)[: len(args)]

# We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
default_kwargs = {
name: value
for name, value in parameters_defaults.items()
if name not in kwargs and name not in param_names and value != inspect.Parameter.empty
}

if not hasattr(obj, "__pl_saved_args"):
obj.__pl_saved_args = args
obj.__pl_saved_kwargs = kwargs
obj.__pl_saved_arg_names = param_names
obj.__pl_saved_default_kwargs = default_kwargs

# We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
# so that we can be sure, that it will not get changed anymore.
# That is why we are setting this in every `__init__`
if "dataset" in param_names:
setattr(obj, "__dataset", args[param_names.index("dataset")])
elif "dataset" in kwargs:
setattr(obj, "__dataset", kwargs["dataset"])
if store_explicit_arg is not None:
if store_explicit_arg in param_names:
setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
elif store_explicit_arg in kwargs:
setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])

init(obj, *args, **kwargs)

Expand All @@ -422,15 +508,17 @@ def recurse(cl: Type[Any]) -> None:


@contextmanager
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
classes = _get_all_subclasses(DataLoader) | {DataLoader}
def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
It patches the ``__init__`` method.
"""
classes = _get_all_subclasses(base_cls) | {base_cls}
wrapped = set()
for cls in classes:
if cls.__init__ not in wrapped:
cls._old_init = cls.__init__
cls.__init__ = _wrap_dataloader_init(cls.__init__)
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
wrapped.add(cls.__init__)
yield
for cls in classes:
@@ -475,13 +563,13 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(


def _is_dataloader_shuffled(dataloader: object) -> bool:
if hasattr(dataloader, "__pl_dl_kwargs"):
if hasattr(dataloader, "__pl_saved_kwargs"):
# this attribute is not part of PyTorch's DataLoader, but could have been set by
# our `_replace_dataloader_init_method` context manager
if "shuffle" in dataloader.__pl_dl_kwargs:
return dataloader.__pl_dl_kwargs["shuffle"]
if "shuffle" in dataloader.__pl_dl_arg_names:
return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
# our `_replace_init_method` context manager
if "shuffle" in dataloader.__pl_saved_kwargs:
return dataloader.__pl_saved_kwargs["shuffle"]
if "shuffle" in dataloader.__pl_saved_arg_names:
return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
if isinstance(dataloader.dataset, IterableDataset):
# shuffling is useless with iterable datasets
return False
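Rounding off the `data.py` changes, here is a sketch of how the renamed attributes feed `_is_dataloader_shuffled`; it assumes both helpers are imported from `pytorch_lightning.utilities.data`:

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _is_dataloader_shuffled, _replace_init_method

with _replace_init_method(DataLoader, "dataset"):
    loader = DataLoader(range(10), shuffle=True)

# `shuffle=True` was recorded in `__pl_saved_kwargs`, so the check can answer directly,
# without inspecting the sampler.
assert _is_dataloader_shuffled(loader) is True
```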
7 changes: 4 additions & 3 deletions tests/tests_pytorch/lite/test_lite.py
@@ -177,16 +177,17 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method")
@mock.patch("pytorch_lightning.lite.lite._replace_init_method")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""

class Lite(LightningLite):
def run(self):
ctx_manager().__enter__.assert_called_once()
# One for BatchSampler, another for DataLoader
assert ctx_manager().__enter__.call_count == 2

Lite().run()
ctx_manager().__exit__.assert_called_once()
assert ctx_manager().__exit__.call_count == 2


def test_setup_dataloaders_raises_for_unknown_custom_args():
10 changes: 0 additions & 10 deletions tests/tests_pytorch/utilities/test_auto_restart.py
@@ -34,7 +34,6 @@
from torch.utils.data._utils.worker import _generate_state, get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.sampler import Sampler

import tests_pytorch.helpers.utils as tutils
from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
@@ -1177,15 +1176,6 @@ class CustomRandomSampler(RandomSampler):
with pytest.raises(TypeError, match="RandomSampler"):
_validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

class CustomBatchSampler(BatchSampler):
pass

sampler = Sampler(data())
batch_sampler = CustomBatchSampler(sampler, 2, False)
dl = DataLoader(data(), batch_sampler=batch_sampler)
with pytest.raises(TypeError, match="BatchSampler"):
_validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

class CustomIterable(IterableDataset):
pass

