Allowed setting attributes on DataLoader and BatchSampler when instantiated inside *_dataloader hooks #14212

Merged — 11 commits, Aug 17, 2022
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Avoided requiring the FairScale package to use precision with the fsdp native strategy ([#14092](https://github.com/Lightning-AI/lightning/pull/14092))


- Fixed not preserving set attributes on `DataLoader` and `BatchSampler` when instantiated inside `*_dataloader` hooks ([#14212](https://github.com/Lightning-AI/lightning/pull/14212))


## [1.7.1] - 2022-08-09

### Fixed
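To illustrate the fixed behaviour, here is a minimal, hypothetical sketch (the `TaggedDataLoader` subclass and `my_tag` attribute are made up, and only the relevant hook is shown; a complete `LightningModule` would also define `training_step` and `configure_optimizers`): an attribute assigned to a `DataLoader` created inside a `*_dataloader` hook should now survive Lightning's re-instantiation of that loader, e.g. when a distributed sampler is injected.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TaggedDataLoader(DataLoader):
    """Hypothetical DataLoader subclass used only for this illustration."""


class MyModel(pl.LightningModule):
    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 3))
        loader = TaggedDataLoader(dataset, batch_size=8)
        # Set after construction: previously this attribute could be lost when
        # Lightning rebuilt the loader; with this PR the patched __setattr__
        # records it and replays it onto the re-instantiated loader.
        loader.my_tag = "train"
        return loader
```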
8 changes: 4 additions & 4 deletions src/pytorch_lightning/lite/lite.py
@@ -35,7 +35,7 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_init_method,
_replace_dunder_methods,
_update_dataloader,
has_iterable_dataset,
)
@@ -403,9 +403,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:

def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
self._strategy.setup_environment()
with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method(
BatchSampler
):
with self._strategy.model_sharded_context(), _replace_dunder_methods(
DataLoader, "dataset"
), _replace_dunder_methods(BatchSampler):
return run_method(*args, **kwargs)

def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
6 changes: 4 additions & 2 deletions src/pytorch_lightning/strategies/ipu.py
@@ -31,7 +31,7 @@
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -248,7 +248,9 @@ def _convert_to_poptorch_loader(
dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type]
)
opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts
dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs)
dataloader = _reinstantiate_wrapped_cls(
dataloader, opts, *dl_args, explicit_cls=poptorch.DataLoader, **dl_kwargs
)
return dataloader

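A rough restatement of the class-selection idea behind the `explicit_cls` argument used above (error handling and attribute replay omitted; `rebuild_as` is a made-up name, not the Lightning helper):

```python
from typing import Any, Optional, Type


def rebuild_as(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any:
    # Rebuild with the original object's class unless an explicit target class
    # is given, mirroring how the IPU strategy swaps in poptorch.DataLoader.
    constructor = type(orig_object) if explicit_cls is None else explicit_cls
    return constructor(*args, **kwargs)
```

In the hunk above, the captured `dl_args` and `dl_kwargs` are forwarded unchanged while the constructor is switched to `poptorch.DataLoader`.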
def _handle_gradient_accumulation_steps(self) -> None:
8 changes: 5 additions & 3 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -31,7 +31,7 @@
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_is_dataloader_shuffled,
_replace_init_method,
_replace_dunder_methods,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
@@ -428,9 +428,11 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
"""
source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")

with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler):
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning.
# Also, it records all attribute setting and deletion using patched `__setattr__` and `__delattr__`
# methods so that the re-instantiated object is as close to the original as possible.
dataloader = source.dataloader()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
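A hedged sketch of what that comment describes, relying on the internal helper and attribute names introduced by this PR (`_replace_dunder_methods`, `__pl_saved_*`, `__pl_attrs_record`); these are private and may change between releases:

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, TensorDataset

from pytorch_lightning.utilities.data import _replace_dunder_methods

dataset = TensorDataset(torch.randn(16, 2))

with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler):
    loader = DataLoader(dataset, batch_size=4)
    loader.custom_flag = True  # recorded by the patched __setattr__

# Constructor arguments and post-init attribute changes are stored on the
# instance so an equivalent loader can be rebuilt later.
print(loader.__pl_saved_kwargs)   # {'batch_size': 4}
print(loader.__pl_attrs_record)   # one recorded SET entry for 'custom_flag'
```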
139 changes: 102 additions & 37 deletions src/pytorch_lightning/utilities/data.py
@@ -37,7 +37,7 @@
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.apply_func import _is_dataclass_instance
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.meta import _get_all_subclasses
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
@@ -49,6 +49,18 @@
warning_cache = WarningCache()


class _WrapAttrTag(LightningEnum):
SET = "set"
DEL = "del"

def __call__(self, *args):
if self == self.SET:
fn = setattr
else:
fn = delattr
return fn(*args)

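For readers unfamiliar with callable enum members, a self-contained stand-in (plain `enum.Enum` instead of `LightningEnum`) showing the dispatch performed above: calling the tag applies `setattr` or `delattr` to its arguments.

```python
from enum import Enum


class WrapAttrTag(Enum):
    SET = "set"
    DEL = "del"

    def __call__(self, *args):
        # SET -> setattr(obj, name, value); DEL -> delattr(obj, name)
        fn = setattr if self is WrapAttrTag.SET else delattr
        return fn(*args)


class Box:
    pass


box = Box()
WrapAttrTag.SET(box, "value", 42)   # equivalent to setattr(box, "value", 42)
assert box.value == 42
WrapAttrTag.DEL(box, "value")       # equivalent to delattr(box, "value")
assert not hasattr(box, "value")
```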

def _extract_batch_size(batch: BType) -> Generator[int, None, None]:
if isinstance(batch, Tensor):
if batch.ndim == 0:
@@ -189,27 +201,7 @@ def _update_dataloader(
dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> DataLoader:
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
dl_cls = type(dataloader)
try:
dataloader = dl_cls(*dl_args, **dl_kwargs)
except TypeError as e:
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
# `__init__` arguments map to one `DataLoader.__init__` argument
import re

match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
if not match:
# an unexpected `TypeError`, continue failure
raise
argument = match.groups()[0]
message = (
f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
" This argument was automatically passed to your DataLoader by PyTorch Lightning."
)
raise MisconfigurationException(message) from e
dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
return dataloader


@@ -375,7 +367,7 @@ def _dataloader_init_kwargs_resolve_sampler(
"this, expose an argument `sampler` in the `__init__` method of your custom class."
)

batch_sampler = batch_sampler_cls(*args, **kwargs)
batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
else:
try:
batch_sampler = batch_sampler_cls(
@@ -450,6 +442,40 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any:
if explicit_cls is None:
constructor = type(orig_object)
else:
constructor = explicit_cls

try:
result = constructor(*args, **kwargs)
except TypeError as e:
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
# `__init__` arguments map to one `DataLoader.__init__` argument
import re

match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
if not match:
# an unexpected `TypeError`, continue failure
raise
argument = match.groups()[0]
message = (
f"The {constructor.__name__} implementation has an error where more than one `__init__` argument"
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
" This argument was automatically passed to your object by PyTorch Lightning."
)
raise MisconfigurationException(message) from e

attrs_record = getattr(orig_object, "__pl_attrs_record", list())
for args, fn in attrs_record:
fn(result, *args)

return result

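As a sanity check on the regex used above, a small standalone snippet (the `RenamedDataLoader` subclass is made up) that reproduces the kind of `TypeError` being detected — a custom argument mapping onto the parent's `dataset` parameter while `**kwargs` is forwarded unfiltered — and extracts the offending argument name:

```python
import re

import torch
from torch.utils.data import DataLoader, TensorDataset


class RenamedDataLoader(DataLoader):
    # Buggy pattern described in the error message: `my_dataset` maps onto the
    # parent's `dataset` argument while **kwargs is forwarded unfiltered, so a
    # `dataset` key in kwargs collides with the positional value.
    def __init__(self, my_dataset, **kwargs):
        super().__init__(my_dataset, **kwargs)


ds = TensorDataset(torch.randn(4, 2))
try:
    # Lightning re-instantiates with `dataset=...` in kwargs, colliding with `my_dataset`.
    RenamedDataLoader(ds, dataset=ds)
except TypeError as e:
    match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
    print(match.groups()[0] if match else "no match")  # prints: dataset
```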

def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
"""Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
@@ -458,6 +484,8 @@ def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None)
def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
# We need to inspect `init`, as inspecting `obj.__init__`
# can lead to inspecting the wrong function with multiple inheritance
old_inside_init = getattr(obj, "__pl_inside_init", False)
object.__setattr__(obj, "__pl_inside_init", True)
params = inspect.signature(init).parameters

parameters_defaults = OrderedDict(
@@ -475,45 +503,82 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
}

if not hasattr(obj, "__pl_saved_args"):
obj.__pl_saved_args = args
obj.__pl_saved_kwargs = kwargs
obj.__pl_saved_arg_names = param_names
obj.__pl_saved_default_kwargs = default_kwargs
object.__setattr__(obj, "__pl_saved_args", args)
object.__setattr__(obj, "__pl_saved_kwargs", kwargs)
object.__setattr__(obj, "__pl_saved_arg_names", param_names)
object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs)

# We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
# so that we can be sure, that it will not get changed anymore.
# That is why we are setting this in every `__init__`
if store_explicit_arg is not None:
if store_explicit_arg in param_names:
setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
elif store_explicit_arg in kwargs:
setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])

init(obj, *args, **kwargs)
object.__setattr__(obj, "__pl_inside_init", old_inside_init)

return wrapper

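The name/default bookkeeping in the wrapper above relies on `inspect.signature`; a small, standalone illustration (toy function and values, not the exact Lightning code) of how positional argument names and unpassed defaults can be recovered from a signature:

```python
import inspect
from collections import OrderedDict


def example_init(self, dataset, batch_size=1, shuffle=False, *args, **kwargs):
    pass


params = inspect.signature(example_init).parameters

# Map each named parameter (skipping `self` and *args/**kwargs) to its default.
parameters_defaults = OrderedDict(
    (param.name, param.default)
    for param in params.values()
    if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
)

# Suppose the constructor was called as example_init(obj, my_dataset, batch_size=4):
passed_args, passed_kwargs = ("my_dataset",), {"batch_size": 4}
param_names = list(parameters_defaults)[: len(passed_args)]

# Defaults for parameters that were neither passed positionally nor as keywords.
default_kwargs = {
    name: default
    for name, default in parameters_defaults.items()
    if name not in passed_kwargs and name not in param_names and default is not inspect.Parameter.empty
}

print(param_names)      # ['dataset']
print(default_kwargs)   # {'shuffle': False}
```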

def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
"""Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
:class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

@functools.wraps(method)
def wrapper(obj: Any, *args: Any):
# First, let's find out if we're the first in inheritance chain calling the patched method.
name, *_ = args
prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
first_call = not (prev_call_name == name and prev_call_method == tag)

# Then mark the current called method
object.__setattr__(obj, "__pl_current_call", (name, tag))

# call original method
method(obj, *args)
if first_call and not getattr(obj, "__pl_inside_init", True):
# and save the value it was called with to the internal list,
# if we're outside of __init__ and the original call did not fail and we're the first call
attrs_record = getattr(obj, "__pl_attrs_record", list())
attrs_record.append((args, tag))
object.__setattr__(obj, "__pl_attrs_record", attrs_record)
object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method))

return wrapper


@contextmanager
def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.

It patches the ``__init__`` method.
It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.
"""
classes = _get_all_subclasses(base_cls) | {base_cls}
for cls in classes:
# Check that __init__ belongs to the class
# https://stackoverflow.com/a/5253424
if "__init__" in cls.__dict__:
cls._old_init = cls.__init__
cls.__old__init__ = cls.__init__
cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)

# we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses
# implement `__setattr__`/`__delattr__`. Therefore, we are always patching the `base_cls`
for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)):
if patch_fn_name in cls.__dict__ or cls is base_cls:
saved_name = f"__old{patch_fn_name}"
setattr(cls, saved_name, getattr(cls, patch_fn_name))
setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag))
yield
for cls in classes:
# Check that _old_init belongs to the class
# https://stackoverflow.com/a/5253424
if "_old_init" in cls.__dict__:
cls.__init__ = cls._old_init
del cls._old_init
for patched_name in ("__setattr__", "__delattr__", "__init__"):
# Check that __old__{init,setattr,delattr} belongs to the class
# https://stackoverflow.com/a/5253424
if f"__old{patched_name}" in cls.__dict__:
setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
delattr(cls, f"__old{patched_name}")


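Finally, a compact, self-contained toy (illustrative names only, not the Lightning helpers) of the overall patch–capture–restore flow the context manager implements: patch `__init__` and `__setattr__` on a class inside a `with` block, record construction arguments and post-init attribute writes, restore the originals on exit, and later rebuild an equivalent instance.

```python
from contextlib import contextmanager


@contextmanager
def record_instances(cls):
    orig_init, orig_setattr = cls.__init__, cls.__setattr__

    def patched_init(self, *args, **kwargs):
        # Remember the constructor call and suppress attribute recording while
        # __init__ runs (mirrors the __pl_inside_init flag above).
        object.__setattr__(self, "_saved", (args, kwargs))
        object.__setattr__(self, "_in_init", True)
        orig_init(self, *args, **kwargs)
        object.__setattr__(self, "_in_init", False)

    def patched_setattr(self, name, value):
        orig_setattr(self, name, value)
        if not getattr(self, "_in_init", True):
            record = getattr(self, "_attrs", [])
            record.append((name, value))
            object.__setattr__(self, "_attrs", record)

    cls.__init__, cls.__setattr__ = patched_init, patched_setattr
    try:
        yield
    finally:
        cls.__init__, cls.__setattr__ = orig_init, orig_setattr


class Loader:
    def __init__(self, source, batch_size=1):
        self.source, self.batch_size = source, batch_size


with record_instances(Loader):
    loader = Loader("data.txt", batch_size=8)
    loader.shuffle = True  # recorded because it happens outside __init__

# Rebuild an equivalent object and replay the recorded attribute writes.
args, kwargs = loader._saved
rebuilt = Loader(*args, **kwargs)
for name, value in getattr(loader, "_attrs", []):
    setattr(rebuilt, name, value)
assert rebuilt.batch_size == 8 and rebuilt.shuffle is True
```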
def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset:
2 changes: 1 addition & 1 deletion tests/tests_pytorch/lite/test_lite.py
@@ -177,7 +177,7 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


@mock.patch("pytorch_lightning.lite.lite._replace_init_method")
@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""
