Commit 44cdbca

Allowed setting attributes on DataLoader and BatchSampler when instantiated inside `*_dataloader` hooks (#14212)
otaj authored Aug 17, 2022
1 parent 909e7e7 commit 44cdbca
Showing 7 changed files with 229 additions and 70 deletions.
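As context for the diff below, here is a minimal, hedged sketch (not part of this commit; `RandomDataset`, `DemoModel`, and `my_custom_flag` are hypothetical names) of the behaviour being fixed: an attribute set on a `DataLoader` after it is constructed inside a `*_dataloader` hook is now recorded and re-applied when Lightning re-instantiates the loader, for example to inject a `DistributedSampler`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    """Hypothetical toy dataset used only for illustration."""

    def __len__(self) -> int:
        return 64

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.randn(32)


class DemoModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self) -> DataLoader:
        loader = DataLoader(RandomDataset(), batch_size=8)
        # Before #14212, this attribute was lost whenever Lightning rebuilt the loader
        # (e.g. to inject a DistributedSampler); it is now captured and replayed.
        loader.my_custom_flag = True
        return loader
```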
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Avoided requiring the FairScale package to use precision with the fsdp native strategy ([#14092](https://github.com/Lightning-AI/lightning/pull/14092))


- Fixed not preserving set attributes on `DataLoader` and `BatchSampler` when instantiated inside `*_dataloader` hooks ([#14212](https://github.com/Lightning-AI/lightning/pull/14212))


## [1.7.1] - 2022-08-09

### Fixed
8 changes: 4 additions & 4 deletions src/pytorch_lightning/lite/lite.py
@@ -35,7 +35,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_init_method,
_replace_dunder_methods,
_update_dataloader,
has_iterable_dataset,
)
@@ -403,9 +403,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:

def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()
with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method(
BatchSampler
):
with self._strategy.model_sharded_context(), _replace_dunder_methods(
DataLoader, "dataset"
), _replace_dunder_methods(BatchSampler):
return run_method(*args, **kwargs)

def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
6 changes: 4 additions & 2 deletions src/pytorch_lightning/strategies/ipu.py
@@ -31,7 +31,7 @@
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -248,7 +248,9 @@ def _convert_to_poptorch_loader(
dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type]
)
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
dataloader = _reinstantiate_wrapped_cls(
dataloader, opts, *dl_args, explicit_cls=poptorch.DataLoader, **dl_kwargs
)
return dataloader

def _handle_gradient_accumulation_steps(self) -> None:
8 changes: 5 additions & 3 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -31,7 +31,7 @@
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_is_dataloader_shuffled,
_replace_init_method,
_replace_dunder_methods,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
@@ -428,9 +428,11 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
"""
source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")

with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler):
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning.
# Also, it records all attribute setting and deletion using patched `__setattr__` and `__delattr__`
# methods so that the re-instantiated object is as close to the original as possible.
dataloader = source.dataloader()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
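A hedged sketch of what the context manager referenced above records, using the private helper on its own rather than the way the connector calls it (`CustomLoader` and `shift` are hypothetical): the arguments passed to `__init__` end up saved on the instance so the loader can be rebuilt later.

```python
from torch.utils.data import DataLoader

# Private helper shown in this diff (src/pytorch_lightning/utilities/data.py).
from pytorch_lightning.utilities.data import _replace_dunder_methods


class CustomLoader(DataLoader):
    def __init__(self, dataset, shift: int = 0, **kwargs):
        self.shift = shift
        super().__init__(dataset, **kwargs)


with _replace_dunder_methods(DataLoader, "dataset"):
    loader = CustomLoader(range(10), shift=1, batch_size=2)

# The patched __init__ stored the constructor call for later re-instantiation.
assert loader.__pl_saved_args == (range(10),)
assert loader.__pl_saved_kwargs == {"shift": 1, "batch_size": 2}
# `__pl_saved_arg_names` and `__pl_saved_default_kwargs` are captured as well (not asserted here).
```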
136 changes: 99 additions & 37 deletions src/pytorch_lightning/utilities/data.py
@@ -37,7 +37,7 @@
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import _is_dataclass_instance
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.meta import _get_all_subclasses
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -49,6 +49,18 @@
warning_cache = WarningCache()


class _WrapAttrTag(LightningEnum):
SET = "set"
DEL = "del"

def __call__(self, *args):
if self == self.SET:
fn = setattr
else:
fn = delattr
return fn(*args)


def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
if isinstance(batch, Tensor):
if batch.ndim == 0:
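A small, hedged illustration of the `_WrapAttrTag` enum added above: calling a tag applies the corresponding builtin (`setattr` or `delattr`), which is how recorded attribute operations are replayed onto a re-instantiated object.

```python
from types import SimpleNamespace

from pytorch_lightning.utilities.data import _WrapAttrTag

obj = SimpleNamespace()
_WrapAttrTag.SET(obj, "foo", 42)  # equivalent to setattr(obj, "foo", 42)
assert obj.foo == 42
_WrapAttrTag.DEL(obj, "foo")      # equivalent to delattr(obj, "foo")
assert not hasattr(obj, "foo")
```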
@@ -189,27 +201,7 @@ def _update_dataloader(
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> DataLoader:
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
dl_cls = type(dataloader)
try:
dataloader = dl_cls(*dl_args, **dl_kwargs)
except TypeError as e:
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
# `__init__` arguments map to one `DataLoader.__init__` argument
import re

match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
if not match:
# an unexpected `TypeError`, continue failure
raise
argument = match.groups()[0]
message = (
f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
" This argument was automatically passed to your DataLoader by PyTorch Lightning."
)
raise MisconfigurationException(message) from e
dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
return dataloader


@@ -375,7 +367,7 @@ def _dataloader_init_kwargs_resolve_sampler(
"this, expose an argument `sampler` in the `__init__` method of your custom class."
)

batch_sampler = batch_sampler_cls(*args, **kwargs)
batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
else:
try:
batch_sampler = batch_sampler_cls(
@@ -450,6 +442,37 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any:
constructor = type(orig_object) if explicit_cls is None else explicit_cls

try:
result = constructor(*args, **kwargs)
except TypeError as e:
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
# `__init__` arguments map to one `DataLoader.__init__` argument
import re

match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
if not match:
# an unexpected `TypeError`, continue failure
raise
argument = match.groups()[0]
message = (
f"The {constructor.__name__} implementation has an error where more than one `__init__` argument"
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
" This argument was automatically passed to your object by PyTorch Lightning."
)
raise MisconfigurationException(message) from e

attrs_record = getattr(orig_object, "__pl_attrs_record", list())
for args, fn in attrs_record:
fn(result, *args)

return result


def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
"""Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
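A hedged sketch of the round trip performed by `_reinstantiate_wrapped_cls` above, calling the private helpers directly rather than the way Lightning itself does (Lightning passes arguments resolved via `_get_dataloader_init_args_and_kwargs`, and `explicit_cls` lets a strategy substitute a different class, as the IPU strategy does with `poptorch.DataLoader`). The `custom_tag` attribute is hypothetical.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _reinstantiate_wrapped_cls, _replace_dunder_methods

with _replace_dunder_methods(DataLoader, "dataset"):
    loader = DataLoader(range(10), batch_size=2)
    loader.custom_tag = "train"  # recorded by the patched __setattr__

# Rebuild from the captured constructor arguments; the recorded attribute is replayed.
rebuilt = _reinstantiate_wrapped_cls(loader, *loader.__pl_saved_args, **loader.__pl_saved_kwargs)
assert isinstance(rebuilt, DataLoader) and rebuilt.batch_size == 2
assert rebuilt.custom_tag == "train"
```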
@@ -458,6 +481,8 @@ def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None)
def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
# We need to inspect `init`, as inspecting `obj.__init__`
# can lead to inspecting the wrong function with multiple inheritance
old_inside_init = getattr(obj, "__pl_inside_init", False)
object.__setattr__(obj, "__pl_inside_init", True)
params = inspect.signature(init).parameters

parameters_defaults = OrderedDict(
@@ -475,45 +500,82 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
}

if not hasattr(obj, "__pl_saved_args"):
obj.__pl_saved_args = args
obj.__pl_saved_kwargs = kwargs
obj.__pl_saved_arg_names = param_names
obj.__pl_saved_default_kwargs = default_kwargs
object.__setattr__(obj, "__pl_saved_args", args)
object.__setattr__(obj, "__pl_saved_kwargs", kwargs)
object.__setattr__(obj, "__pl_saved_arg_names", param_names)
object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs)

# We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
# so that we can be sure, that it will not get changed anymore.
# That is why we are setting this in every `__init__`
if store_explicit_arg is not None:
if store_explicit_arg in param_names:
setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
elif store_explicit_arg in kwargs:
setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])

init(obj, *args, **kwargs)
object.__setattr__(obj, "__pl_inside_init", old_inside_init)

return wrapper


def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
"""Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

@functools.wraps(method)
def wrapper(obj: Any, *args: Any):
# First, let's find out if we're the first in inheritance chain calling the patched method.
name, *_ = args
prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
first_call = not (prev_call_name == name and prev_call_method == tag)

# Then mark the current called method
object.__setattr__(obj, "__pl_current_call", (name, tag))

# call original method
method(obj, *args)
if first_call and not getattr(obj, "__pl_inside_init", True):
# and save the value it was called with to the internal list,
# if we're outside of __init__ and the original call did not fail and we're the first call
attrs_record = getattr(obj, "__pl_attrs_record", list())
attrs_record.append((args, tag))
object.__setattr__(obj, "__pl_attrs_record", attrs_record)
object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method))

return wrapper
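A hedged illustration of the guards in the wrapper above: assignments made while `__init__` is still running (the `__pl_inside_init` flag is set) are not recorded, whereas assignments on the finished object are. `FlaggedLoader` and its attributes are hypothetical.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_dunder_methods


class FlaggedLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.set_during_init = True  # happens inside __init__ -> not recorded


with _replace_dunder_methods(DataLoader, "dataset"):
    loader = FlaggedLoader(range(8), batch_size=2)
    loader.set_after_init = True  # happens after __init__ finished -> recorded

recorded = [args[0] for args, _ in loader.__pl_attrs_record]
assert recorded == ["set_after_init"]
```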


@contextmanager
def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
It patches the ``__init__`` method.
It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.
"""
classes = _get_all_subclasses(base_cls) | {base_cls}
for cls in classes:
# Check that __init__ belongs to the class
# https://stackoverflow.com/a/5253424
if "__init__" in cls.__dict__:
cls._old_init = cls.__init__
cls.__old__init__ = cls.__init__
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)

# we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses
# implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls`
for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)):
if patch_fn_name in cls.__dict__ or cls is base_cls:
saved_name = f"__old{patch_fn_name}"
setattr(cls, saved_name, getattr(cls, patch_fn_name))
setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag))
yield
for cls in classes:
# Check that _old_init belongs to the class
# https://stackoverflow.com/a/5253424
if "_old_init" in cls.__dict__:
cls.__init__ = cls._old_init
del cls._old_init
for patched_name in ("__setattr__", "__delattr__", "__init__"):
# Check that __old__{init,setattr,delattr} belongs to the class
# https://stackoverflow.com/a/5253424
if f"__old{patched_name}" in cls.__dict__:
setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
delattr(cls, f"__old{patched_name}")


def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
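Finally, a hedged sanity check that the patching performed by `_replace_dunder_methods` is scoped to the `with` block: `DataLoader.__init__` is swapped for the wrapper inside the context manager and the original function is restored on exit.

```python
from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_dunder_methods

original_init = DataLoader.__dict__["__init__"]

with _replace_dunder_methods(DataLoader, "dataset"):
    assert DataLoader.__dict__["__init__"] is not original_init  # patched wrapper in place

assert DataLoader.__dict__["__init__"] is original_init  # original restored on exit
```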
2 changes: 1 addition & 1 deletion tests/tests_pytorch/lite/test_lite.py
@@ -177,7 +177,7 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


@mock.patch("pytorch_lightning.lite.lite._replace_init_method")
@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""
