From 804fc9d7e9a264e82120cd53aeb4354dafae80a8 Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 15 Aug 2022 16:35:40 +0200 Subject: [PATCH 1/8] Replace __setattr__ and __delattr__ --- src/pytorch_lightning/lite/lite.py | 8 +- src/pytorch_lightning/strategies/ipu.py | 4 +- .../trainer/connectors/data_connector.py | 8 +- src/pytorch_lightning/utilities/data.py | 139 +++++++++++++----- tests/tests_pytorch/utilities/test_data.py | 93 +++++++++--- 5 files changed, 183 insertions(+), 69 deletions(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 981eed30635f6..ca45a4011fcdd 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -35,7 +35,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, - _replace_init_method, + _replace_dunder_methods, _update_dataloader, has_iterable_dataset, ) @@ -403,9 +403,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() - with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method( - BatchSampler - ): + with self._strategy.model_sharded_context(), _replace_dunder_methods( + DataLoader, "dataset" + ), _replace_dunder_methods(BatchSampler): return run_method(*args, **kwargs) def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index f56c095dc12c1..bf051d33b7aae 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs +from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -248,7 +248,7 @@ def _convert_to_poptorch_loader( dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type] ) opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts - dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs) + dataloader = _reinstantiate_wrapped_cls(dataloader, opts, *dl_args, poptorch.DataLoader, **dl_kwargs) return dataloader def _handle_gradient_accumulation_steps(self) -> None: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 6e592b9f6d310..e20eac2ffae57 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _is_dataloader_shuffled, - _replace_init_method, + _replace_dunder_methods, _update_dataloader, has_iterable_dataset, has_len_all_ranks, @@ -428,9 +428,11 @@ def _request_dataloader(self, stage: RunningStage) -> 
Union[DataLoader, List[Dat """ source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source") - with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler): + with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as - # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning + # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning. + # Also, it records all attribute setting and deletion using patched `__setattr__` and `__delattr__` + # methods so that the re-instantiated object is as close to the original as possible. dataloader = source.dataloader() if isinstance(dataloader, tuple): dataloader = list(dataloader) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index b625a046f6122..b67369adb487d 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -37,7 +37,7 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.apply_func import _is_dataclass_instance from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler -from pytorch_lightning.utilities.enums import _FaultTolerantMode +from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.meta import _get_all_subclasses from pytorch_lightning.utilities.rank_zero import rank_zero_warn @@ -49,6 +49,18 @@ warning_cache = WarningCache() +class _WrapAttrTag(LightningEnum): + SET = "set" + DEL = "del" + + def __call__(self, *args): + if self == self.SET: + fn = setattr + else: + fn = delattr + return fn(*args) + + def _extract_batch_size(batch: BType) -> Generator[int, None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: @@ -189,27 +201,7 @@ def _update_dataloader( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None ) -> DataLoader: dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode) - dl_cls = type(dataloader) - try: - dataloader = dl_cls(*dl_args, **dl_kwargs) - except TypeError as e: - # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass - # `__init__` arguments map to one `DataLoader.__init__` argument - import re - - match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) - if not match: - # an unexpected `TypeError`, continue failure - raise - argument = match.groups()[0] - message = ( - f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument" - f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" - f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." - f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." - " This argument was automatically passed to your DataLoader by PyTorch Lightning." 
- ) - raise MisconfigurationException(message) from e + dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs) return dataloader @@ -375,7 +367,7 @@ def _dataloader_init_kwargs_resolve_sampler( "this, expose an argument `sampler` in the `__init__` method of your custom class." ) - batch_sampler = batch_sampler_cls(*args, **kwargs) + batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs) else: try: batch_sampler = batch_sampler_cls( @@ -450,6 +442,40 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) +def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any: + if explicit_cls is None: + constructor = type(orig_object) + else: + constructor = explicit_cls + + try: + result = constructor(*args, **kwargs) + except TypeError as e: + # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass + # `__init__` arguments map to one `DataLoader.__init__` argument + import re + + match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + argument = match.groups()[0] + message = ( + f"The {constructor.__name__} implementation has an error where more than one `__init__` argument" + f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" + f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." + f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." + " This argument was automatically passed to your object by PyTorch Lightning." + ) + raise MisconfigurationException(message) from e + + attrs_record = getattr(orig_object, "__pl_attrs_record", list()) + for args, fn in attrs_record: + fn(result, *args) + + return result + + def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable: """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @@ -458,6 +484,8 @@ def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: # We need to inspect `init`, as inspecting `obj.__init__` # can lead to inspecting the wrong function with multiple inheritance + old_inside_init = getattr(obj, "__pl_inside_init", False) + object.__setattr__(obj, "__pl_inside_init", True) params = inspect.signature(init).parameters parameters_defaults = OrderedDict( @@ -475,45 +503,82 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: } if not hasattr(obj, "__pl_saved_args"): - obj.__pl_saved_args = args - obj.__pl_saved_kwargs = kwargs - obj.__pl_saved_arg_names = param_names - obj.__pl_saved_default_kwargs = default_kwargs + object.__setattr__(obj, "__pl_saved_args", args) + object.__setattr__(obj, "__pl_saved_kwargs", kwargs) + object.__setattr__(obj, "__pl_saved_arg_names", param_names) + object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs) # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class) # so that we can be sure, that it will not get changed anymore. 
# That is why we are setting this in every `__init__` if store_explicit_arg is not None: if store_explicit_arg in param_names: - setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)]) + object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)]) elif store_explicit_arg in kwargs: - setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg]) + object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg]) init(obj, *args, **kwargs) + object.__setattr__(obj, "__pl_inside_init", old_inside_init) + + return wrapper + + +def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable: + """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and + :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" + + @functools.wraps(method) + def wrapper(obj: Any, *args: Any): + # First, let's find out if we're the first in inheritance chain calling the patched method. + name, *_ = args + prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method")) + first_call = not (prev_call_name == name and prev_call_method == tag) + + # Then mark the current called method + object.__setattr__(obj, "__pl_current_call", (name, tag)) + + # call original method + method(obj, *args) + if first_call and not getattr(obj, "__pl_inside_init", True): + # and save the value it was called with to the internal list, + # if we're outside of __init__ and the original call did not fail and we're the first call + attrs_record = getattr(obj, "__pl_attrs_record", list()) + attrs_record.append((args, tag)) + object.__setattr__(obj, "__pl_attrs_record", attrs_record) + object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method)) return wrapper @contextmanager -def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: +def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. - It patches the ``__init__`` method. + It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. """ classes = _get_all_subclasses(base_cls) | {base_cls} for cls in classes: # Check that __init__ belongs to the class # https://stackoverflow.com/a/5253424 if "__init__" in cls.__dict__: - cls._old_init = cls.__init__ + cls.__old__init__ = cls.__init__ cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) + + # we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses + # implement `__setattr__`/`__delattr__`. 
Therefore, we are always patching the `base_cls` + for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)): + if patch_fn_name in cls.__dict__ or cls is base_cls: + saved_name = f"__old{patch_fn_name}" + setattr(cls, saved_name, getattr(cls, patch_fn_name)) + setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag)) yield for cls in classes: - # Check that _old_init belongs to the class - # https://stackoverflow.com/a/5253424 - if "_old_init" in cls.__dict__: - cls.__init__ = cls._old_init - del cls._old_init + for patched_name in ("__setattr__", "__delattr__", "__init__"): + # Check that __old__{init,setattr,delattr} belongs to the class + # https://stackoverflow.com/a/5253424 + if f"__old{patched_name}" in cls.__dict__: + setattr(cls, patched_name, getattr(cls, f"__old{patched_name}")) + delattr(cls, f"__old{patched_name}") def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset: diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index cc70417988616..5ad3706ea5def 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -13,9 +13,10 @@ from pytorch_lightning.utilities.data import ( _dataloader_init_kwargs_resolve_sampler, _get_dataloader_init_args_and_kwargs, - _replace_init_method, + _replace_dunder_methods, _replace_value_in_saved_args, _update_dataloader, + _WrapAttrTag, extract_batch_size, get_len, has_iterable_dataset, @@ -144,10 +145,10 @@ def __init__(self, foo, *args, **kwargs): super().__init__(foo, *args, **kwargs) dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) - with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): + with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"): _update_dataloader(dataloader, dataloader.sampler) - with _replace_init_method(DataLoader, "dataset"): + with _replace_dunder_methods(DataLoader, "dataset"): dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) new_dataloader = _update_dataloader(dataloader, dataloader.sampler) assert isinstance(new_dataloader, BadStandaloneGoodHookImpl) @@ -159,7 +160,7 @@ def __init__(self, randomize, *args, **kwargs): super().__init__(*args, shuffle=randomize, **kwargs) dataloader = BadImpl(False, []) - with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"): + with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"): _update_dataloader(dataloader, dataloader.sampler) class GoodImpl(DataLoader): @@ -173,28 +174,28 @@ def __init__(self, randomize, *args, **kwargs): assert isinstance(new_dataloader, GoodImpl) -def test_replace_init_method_multiple_loaders_without_init(): +def test_replace_dunder_methods_multiple_loaders_without_init(): """In case of a class, that inherits from a class that we are patching, but doesn't define its own `__init__` - method (the one we are wrapping), it can happen, that `hasattr(cls, "_old_init")` is True because of parent + method (the one we are wrapping), it can happen, that `hasattr(cls, "__old__init__")` is True because of parent class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error occured only sometimes because it depends on the order in which we are iterating over a set of classes we are patching. 
This test simulates the behavior by generating sufficient number of dummy classes, which do not define `__init__` - and are children of `DataLoader`. We are testing that a) context manager `_replace_init_method` exits cleanly, and - b) the mechanism checking for presence of `_old_init` works as expected. + and are children of `DataLoader`. We are testing that a) context manager `_replace_dunder_method` exits cleanly, and + b) the mechanism checking for presence of `__old__init__` works as expected. """ classes = [DataLoader] for i in range(100): classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {})) - with _replace_init_method(DataLoader, "dataset"): + with _replace_dunder_methods(DataLoader, "dataset"): for cls in classes[1:]: # First one is `DataLoader` - assert "_old_init" not in cls.__dict__ - assert hasattr(cls, "_old_init") + assert "__old__init__" not in cls.__dict__ + assert hasattr(cls, "__old__init__") - assert "_old_init" in DataLoader.__dict__ - assert hasattr(DataLoader, "_old_init") + assert "__old__init__" in DataLoader.__dict__ + assert hasattr(DataLoader, "__old__init__") class DataLoaderSubclass1(DataLoader): @@ -322,8 +323,8 @@ def __init__(self, dataset, **kwargs): pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"), ], ) -def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): - with _replace_init_method(DataLoader, "dataset"): +def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): + with _replace_dunder_methods(DataLoader, "dataset"): dataloader = cls(*args, **kwargs) assert dataloader.__pl_saved_args == args @@ -360,12 +361,12 @@ def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, c assert dataloader_value == value -def test_replace_init_method_extra_kwargs(): +def test_replace_dunder_methods_extra_kwargs(): class LoaderSubclass(DataLoader): def __init__(self, dataset, *args, batch_size=10, **kwargs): super().__init__(dataset, *args, batch_size=batch_size, **kwargs) - with _replace_init_method(DataLoader, "dataset"): + with _replace_dunder_methods(DataLoader, "dataset"): dataloader = LoaderSubclass(range(10)) assert dataloader.__pl_saved_args == (range(10),) @@ -375,6 +376,52 @@ def __init__(self, dataset, *args, batch_size=10, **kwargs): assert dataloader.__dataset == range(10) +def test_replace_dunder_methods_attrs(): + """This test checks, that all the calls from setting and deleting attributes within `_replace_dunder_methods` + are correctly preserved even after reinstantiation. 
+ + It also includes a custom `__setattr__` + """ + + class Loader(DataLoader): + def __setattr__(self, attr, val): + if attr == "custom_arg": + val = val + 2 + super().__setattr__(attr, val) + + with _replace_dunder_methods(DataLoader, "dataset"): + dataloader = Loader(range(10)) + dataloader.custom_arg = 5 + dataloader.my_arg = 10 + dataloader.another_arg = 100 + del dataloader.dataset + try: + del dataloader.abc_arg + except AttributeError: + pass + + assert dataloader.__pl_saved_args == (range(10),) + assert dataloader.__pl_saved_kwargs == {} + assert dataloader.__pl_saved_arg_names == ("dataset",) + assert dataloader.__dataset == range(10) + assert dataloader.custom_arg == 7 + assert dataloader.my_arg == 10 + assert dataloader.another_arg == 100 + assert not hasattr(dataloader, "dataset") + assert dataloader.__pl_attrs_record == [ + (("custom_arg", 5), _WrapAttrTag.SET), + (("my_arg", 10), _WrapAttrTag.SET), + (("another_arg", 100), _WrapAttrTag.SET), + (("dataset",), _WrapAttrTag.DEL), + ] + + dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert dataloader.custom_arg == 7 + assert dataloader.my_arg == 10 + assert dataloader.another_arg == 100 + assert not hasattr(dataloader, "dataset") + + @pytest.mark.parametrize("predicting", [True, False]) def test_custom_batch_sampler(predicting): """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to @@ -391,8 +438,8 @@ def __init__(self, sampler, extra_arg, drop_last=True): super().__init__(sampler, 10, drop_last) sampler = RandomSampler(range(10)) - with _replace_init_method(BatchSampler): - # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + with _replace_dunder_methods(BatchSampler): + # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks batch_sampler = MyBatchSampler(sampler, "random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler) @@ -437,8 +484,8 @@ def __init__(self, sampler, extra_arg): super().__init__(sampler, 10, False) sampler = RandomSampler(range(10)) - with _replace_init_method(BatchSampler): - # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + with _replace_dunder_methods(BatchSampler): + # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks batch_sampler = MyBatchSampler(sampler, "random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler) @@ -464,8 +511,8 @@ def __init__(self, extra_arg): self.extra_arg = extra_arg super().__init__(RandomSampler(range(10)), 10, False) - with _replace_init_method(BatchSampler): - # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + with _replace_dunder_methods(BatchSampler): + # instantiate within `_replace_dunder_method` context manager, simulating `*_dataloader` hooks batch_sampler = MyBatchSampler("random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler) From 8cce95a605a7f99596a7e9bbdfeb96993966a216 Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 15 Aug 2022 16:50:11 +0200 Subject: [PATCH 2/8] changelog --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 6aa6a9c7d8037..5342faf06f77e 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -101,6 +101,9 @@ The format is based on [Keep a 
Changelog](http://keepachangelog.com/en/1.0.0/). - Avoided requiring the FairScale package to use precision with the fsdp native strategy ([#14092](https://github.com/Lightning-AI/lightning/pull/14092)) +- Fixed not preserving set attributes on `DataLoader` and `BatchSampler` when instantiated inside `*_dataloader` hooks ([#14212](https://github.com/Lightning-AI/lightning/pull/14212)) + + ## [1.7.1] - 2022-08-09 ### Fixed From 1f7b98ffaeb66e65e2a096c9a17d0c763c49ee6e Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 15 Aug 2022 17:16:51 +0200 Subject: [PATCH 3/8] fix ipu tests --- src/pytorch_lightning/strategies/ipu.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index bf051d33b7aae..b254c5df16ca5 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -248,7 +248,9 @@ def _convert_to_poptorch_loader( dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type] ) opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts - dataloader = _reinstantiate_wrapped_cls(dataloader, opts, *dl_args, poptorch.DataLoader, **dl_kwargs) + dataloader = _reinstantiate_wrapped_cls( + dataloader, opts, *dl_args, explicit_cls=poptorch.DataLoader, **dl_kwargs + ) return dataloader def _handle_gradient_accumulation_steps(self) -> None: From 9c8f4c8af7bf1289f2ff0056edbcaefc3e8f0c17 Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 15 Aug 2022 17:34:51 +0200 Subject: [PATCH 4/8] fix lite test --- tests/tests_pytorch/lite/test_lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index 86a0a5a82195a..d45046f249d54 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -177,7 +177,7 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 -@mock.patch("pytorch_lightning.lite.lite._replace_init_method") +@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods") def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method.""" From 700f7b5e6374c69b867147a3ee183f73963a12b5 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 16 Aug 2022 08:52:52 +0200 Subject: [PATCH 5/8] added one extra test --- tests/tests_pytorch/utilities/test_data.py | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index 5ad3706ea5def..9e3d04ae65560 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -189,6 +189,8 @@ def test_replace_dunder_methods_multiple_loaders_without_init(): for i in range(100): classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {})) + before = {cls: cls.__init__ for cls in classes} + with _replace_dunder_methods(DataLoader, "dataset"): for cls in classes[1:]: # First one is `DataLoader` assert "__old__init__" not in cls.__dict__ @@ -197,6 +199,9 @@ def test_replace_dunder_methods_multiple_loaders_without_init(): assert "__old__init__" in DataLoader.__dict__ assert hasattr(DataLoader, "__old__init__") + for cls in classes: + assert before[cls] == cls.__init__ + class DataLoaderSubclass1(DataLoader): def __init__(self, attribute1, *args, **kwargs): @@ -422,6 +427,44 @@ 
def __setattr__(self, attr, val): assert not hasattr(dataloader, "dataset") +def test_replace_dunder_methods_restore_methods(): + """This tests checks whether are all dunder methods restored to their original versions.""" + + class Init(DataLoader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + class SetAttr(DataLoader): + def __setattr__(self, *args): + return super().__setattr__(*args) + + class DelAttr(DataLoader): + def __delattr__(self, *args): + return super().__delattr__(*args) + + class InitAndSetAttr(Init, SetAttr): + pass + + class InitAndDelAttr(Init, DelAttr): + pass + + class SetAttrAndDelAttr(SetAttr, DelAttr): + pass + + class AllDunder(Init, SetAttr, DelAttr): + pass + + before = dict() + for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder): + before[cls] = {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__} + + with _replace_dunder_methods(DataLoader, "dataset"): + pass + + for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder): + assert before[cls] == {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__} + + @pytest.mark.parametrize("predicting", [True, False]) def test_custom_batch_sampler(predicting): """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to From 14ab4d75de65dc1755205502fe5408db217baef2 Mon Sep 17 00:00:00 2001 From: otaj <6065855+otaj@users.noreply.github.com> Date: Tue, 16 Aug 2022 15:38:43 +0200 Subject: [PATCH 6/8] Update src/pytorch_lightning/utilities/data.py Co-authored-by: Jirka Borovec --- src/pytorch_lightning/utilities/data.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index b67369adb487d..b0c9307cec8e1 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -443,10 +443,7 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any: - if explicit_cls is None: - constructor = type(orig_object) - else: - constructor = explicit_cls + constructor = type(orig_object) if explicit_cls is None else explicit_cls try: result = constructor(*args, **kwargs) From 9b2c783fa644edb0fe7c7daeb377388ffde1635f Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 16 Aug 2022 17:15:04 +0200 Subject: [PATCH 7/8] . From cd12665e8c2c190eca15f549e6536e6f6dd8ae1e Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 17 Aug 2022 13:15:42 +0200 Subject: [PATCH 8/8] .
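The comment added in data_connector.py above summarizes the mechanism this series introduces: while `_replace_dunder_methods` is active, every `__init__` call on the patched classes saves its arguments on the instance, and every attribute set or delete performed after `__init__` has returned is appended to a per-instance record (`__pl_attrs_record`); `_reinstantiate_wrapped_cls` later rebuilds the object from the saved constructor arguments and replays that record. The following is a minimal standalone sketch of the idea, not the Lightning implementation: the names `track_instantiation`, `reinstantiate`, `_saved_ctor` and `_attr_record` are made up for illustration, `__delattr__` handling is omitted, and only the base class is patched, whereas the patch walks all subclasses via `_get_all_subclasses` and also guards against nested dunder calls.

import functools
from contextlib import contextmanager


def _wrap_init(init):
    @functools.wraps(init)
    def wrapper(obj, *args, **kwargs):
        # remember the constructor arguments the first (outermost) time __init__ runs
        if not hasattr(obj, "_saved_ctor"):
            object.__setattr__(obj, "_saved_ctor", (args, kwargs))
        object.__setattr__(obj, "_inside_init", True)
        init(obj, *args, **kwargs)
        object.__setattr__(obj, "_inside_init", False)

    return wrapper


def _wrap_setattr(original_setattr):
    @functools.wraps(original_setattr)
    def wrapper(obj, name, value):
        original_setattr(obj, name, value)
        # only record assignments made after __init__ has finished
        if not getattr(obj, "_inside_init", True):
            record = getattr(obj, "_attr_record", [])
            record.append((name, value))
            object.__setattr__(obj, "_attr_record", record)

    return wrapper


@contextmanager
def track_instantiation(base_cls):
    """Patch __init__/__setattr__ of base_cls so its instances can be rebuilt later."""
    wrappers = {"__init__": _wrap_init, "__setattr__": _wrap_setattr}
    originals = {name: base_cls.__dict__.get(name) for name in wrappers}  # None if inherited
    for name, wrap in wrappers.items():
        setattr(base_cls, name, wrap(getattr(base_cls, name)))
    try:
        yield
    finally:
        for name, original in originals.items():
            if original is not None:
                setattr(base_cls, name, original)
            else:
                # the dunder was only inherited, so drop our patch instead of re-assigning it
                delattr(base_cls, name)


def reinstantiate(obj):
    """Re-create obj from its saved constructor args, then replay later attribute changes."""
    args, kwargs = obj._saved_ctor
    new_obj = type(obj)(*args, **kwargs)
    for name, value in getattr(obj, "_attr_record", []):
        setattr(new_obj, name, value)
    return new_obj


class Loader:
    def __init__(self, dataset, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size


if __name__ == "__main__":
    with track_instantiation(Loader):
        loader = Loader(range(10), batch_size=4)
        loader.num_workers = 2  # set after __init__, still preserved across re-instantiation

    copy = reinstantiate(loader)
    assert (copy.dataset, copy.batch_size, copy.num_workers) == (range(10), 4, 2)

The `base_cls.__dict__.get(name)` bookkeeping mirrors the `"__old__init__" in cls.__dict__` check in the patch: a dunder that was only inherited must not be restored by assignment, otherwise the class would be left with a stray entry in its `__dict__` after the context manager exits.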
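The `TypeError` special-casing kept in `_reinstantiate_wrapped_cls` targets subclasses that map a custom constructor argument onto a parent `__init__` argument while also forwarding unfiltered `**kwargs`, so the same parent argument arrives twice on re-instantiation. Below is a small self-contained repro of that failure mode, using made-up stand-ins (`Base`, `BadSubclass`) rather than `DataLoader`, together with the same regex the patch uses to name the offending argument:

import re


class Base:
    def __init__(self, dataset, shuffle=False, **kwargs):
        self.dataset, self.shuffle = dataset, shuffle


class BadSubclass(Base):
    def __init__(self, randomize, **kwargs):
        # `randomize` is forwarded as `shuffle`, but `shuffle` may also still be in **kwargs
        super().__init__(shuffle=randomize, **kwargs)


try:
    # re-instantiation passes the captured kwargs back, including `shuffle`
    BadSubclass(randomize=False, dataset=range(5), shuffle=True)
except TypeError as e:
    match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
    print(match.groups()[0] if match else e)  # -> shuffle

Filtering `shuffle` out of `**kwargs` in `BadSubclass.__init__` (or not accepting `**kwargs` at all) avoids the error, which is what the raised `MisconfigurationException` asks the user to do.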