diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 6aa6a9c7d8037..5342faf06f77e 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -101,6 +101,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Avoided requiring the FairScale package to use precision with the fsdp native strategy ([#14092](https://github.com/Lightning-AI/lightning/pull/14092)) +- Fixed not preserving set attributes on `DataLoader` and `BatchSampler` when instantiated inside `*_dataloader` hooks ([#14212](https://github.com/Lightning-AI/lightning/pull/14212)) + + ## [1.7.1] - 2022-08-09 ### Fixed diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 981eed30635f6..ca45a4011fcdd 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -35,7 +35,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, - _replace_init_method, + _replace_dunder_methods, _update_dataloader, has_iterable_dataset, ) @@ -403,9 +403,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() - with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method( - BatchSampler - ): + with self._strategy.model_sharded_context(), _replace_dunder_methods( + DataLoader, "dataset" + ), _replace_dunder_methods(BatchSampler): return run_method(*args, **kwargs) def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index f56c095dc12c1..b254c5df16ca5 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs +from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -248,7 +248,9 @@ def _convert_to_poptorch_loader( dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type] ) opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts - dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs) + dataloader = _reinstantiate_wrapped_cls( + dataloader, opts, *dl_args, explicit_cls=poptorch.DataLoader, **dl_kwargs + ) return dataloader def _handle_gradient_accumulation_steps(self) -> None: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 6e592b9f6d310..e20eac2ffae57 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, 
_is_dataloader_shuffled, - _replace_init_method, + _replace_dunder_methods, _update_dataloader, has_iterable_dataset, has_len_all_ranks, @@ -428,9 +428,11 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat """ source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source") - with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler): + with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as - # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning + # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning. + # Also, it records all attribute setting and deletion using patched `__setattr__` and `__delattr__` + # methods so that the re-instantiated object is as close to the original as possible. dataloader = source.dataloader() if isinstance(dataloader, tuple): dataloader = list(dataloader) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index b625a046f6122..b0c9307cec8e1 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -37,7 +37,7 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.apply_func import _is_dataclass_instance from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler -from pytorch_lightning.utilities.enums import _FaultTolerantMode +from pytorch_lightning.utilities.enums import _FaultTolerantMode, LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.meta import _get_all_subclasses from pytorch_lightning.utilities.rank_zero import rank_zero_warn @@ -49,6 +49,18 @@ warning_cache = WarningCache() +class _WrapAttrTag(LightningEnum): + SET = "set" + DEL = "del" + + def __call__(self, *args): + if self == self.SET: + fn = setattr + else: + fn = delattr + return fn(*args) + + def _extract_batch_size(batch: BType) -> Generator[int, None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: @@ -189,27 +201,7 @@ def _update_dataloader( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None ) -> DataLoader: dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode) - dl_cls = type(dataloader) - try: - dataloader = dl_cls(*dl_args, **dl_kwargs) - except TypeError as e: - # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass - # `__init__` arguments map to one `DataLoader.__init__` argument - import re - - match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) - if not match: - # an unexpected `TypeError`, continue failure - raise - argument = match.groups()[0] - message = ( - f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument" - f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" - f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." - f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." - " This argument was automatically passed to your DataLoader by PyTorch Lightning." 
- ) - raise MisconfigurationException(message) from e + dataloader = _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs) return dataloader @@ -375,7 +367,7 @@ def _dataloader_init_kwargs_resolve_sampler( "this, expose an argument `sampler` in the `__init__` method of your custom class." ) - batch_sampler = batch_sampler_cls(*args, **kwargs) + batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs) else: try: batch_sampler = batch_sampler_cls( @@ -450,6 +442,37 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) +def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[Type] = None, **kwargs: Any) -> Any: + constructor = type(orig_object) if explicit_cls is None else explicit_cls + + try: + result = constructor(*args, **kwargs) + except TypeError as e: + # improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass + # `__init__` arguments map to one `DataLoader.__init__` argument + import re + + match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + argument = match.groups()[0] + message = ( + f"The {constructor.__name__} implementation has an error where more than one `__init__` argument" + f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing" + f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`." + f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key." + " This argument was automatically passed to your object by PyTorch Lightning." + ) + raise MisconfigurationException(message) from e + + attrs_record = getattr(orig_object, "__pl_attrs_record", list()) + for args, fn in attrs_record: + fn(result, *args) + + return result + + def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable: """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @@ -458,6 +481,8 @@ def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: # We need to inspect `init`, as inspecting `obj.__init__` # can lead to inspecting the wrong function with multiple inheritance + old_inside_init = getattr(obj, "__pl_inside_init", False) + object.__setattr__(obj, "__pl_inside_init", True) params = inspect.signature(init).parameters parameters_defaults = OrderedDict( @@ -475,45 +500,82 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: } if not hasattr(obj, "__pl_saved_args"): - obj.__pl_saved_args = args - obj.__pl_saved_kwargs = kwargs - obj.__pl_saved_arg_names = param_names - obj.__pl_saved_default_kwargs = default_kwargs + object.__setattr__(obj, "__pl_saved_args", args) + object.__setattr__(obj, "__pl_saved_kwargs", kwargs) + object.__setattr__(obj, "__pl_saved_arg_names", param_names) + object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs) # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class) # so that we can be sure, that it will not get changed anymore. 
# That is why we are setting this in every `__init__` if store_explicit_arg is not None: if store_explicit_arg in param_names: - setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)]) + object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)]) elif store_explicit_arg in kwargs: - setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg]) + object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg]) init(obj, *args, **kwargs) + object.__setattr__(obj, "__pl_inside_init", old_inside_init) + + return wrapper + + +def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable: + """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and + :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" + + @functools.wraps(method) + def wrapper(obj: Any, *args: Any): + # First, let's find out if we're the first in inheritance chain calling the patched method. + name, *_ = args + prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method")) + first_call = not (prev_call_name == name and prev_call_method == tag) + + # Then mark the current called method + object.__setattr__(obj, "__pl_current_call", (name, tag)) + + # call original method + method(obj, *args) + if first_call and not getattr(obj, "__pl_inside_init", True): + # and save the value it was called with to the internal list, + # if we're outside of __init__ and the original call did not fail and we're the first call + attrs_record = getattr(obj, "__pl_attrs_record", list()) + attrs_record.append((args, tag)) + object.__setattr__(obj, "__pl_attrs_record", attrs_record) + object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method)) return wrapper @contextmanager -def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: +def _replace_dunder_methods(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. - It patches the ``__init__`` method. + It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods. """ classes = _get_all_subclasses(base_cls) | {base_cls} for cls in classes: # Check that __init__ belongs to the class # https://stackoverflow.com/a/5253424 if "__init__" in cls.__dict__: - cls._old_init = cls.__init__ + cls.__old__init__ = cls.__init__ cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) + + # we want at least one setattr/delattr in the chain to be patched and it can happen, that none of the subclasses + # implement `__setattr__`/`__delattr__`. 
Therefore, we are always patching the `base_cls` + for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)): + if patch_fn_name in cls.__dict__ or cls is base_cls: + saved_name = f"__old{patch_fn_name}" + setattr(cls, saved_name, getattr(cls, patch_fn_name)) + setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag)) yield for cls in classes: - # Check that _old_init belongs to the class - # https://stackoverflow.com/a/5253424 - if "_old_init" in cls.__dict__: - cls.__init__ = cls._old_init - del cls._old_init + for patched_name in ("__setattr__", "__delattr__", "__init__"): + # Check that __old__{init,setattr,delattr} belongs to the class + # https://stackoverflow.com/a/5253424 + if f"__old{patched_name}" in cls.__dict__: + setattr(cls, patched_name, getattr(cls, f"__old{patched_name}")) + delattr(cls, f"__old{patched_name}") def _wrap_with_capture_dataset(dataset: Dataset) -> Dataset: diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index 86a0a5a82195a..d45046f249d54 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -177,7 +177,7 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 -@mock.patch("pytorch_lightning.lite.lite._replace_init_method") +@mock.patch("pytorch_lightning.lite.lite._replace_dunder_methods") def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method.""" diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index cc70417988616..9e3d04ae65560 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -13,9 +13,10 @@ from pytorch_lightning.utilities.data import ( _dataloader_init_kwargs_resolve_sampler, _get_dataloader_init_args_and_kwargs, - _replace_init_method, + _replace_dunder_methods, _replace_value_in_saved_args, _update_dataloader, + _WrapAttrTag, extract_batch_size, get_len, has_iterable_dataset, @@ -144,10 +145,10 @@ def __init__(self, foo, *args, **kwargs): super().__init__(foo, *args, **kwargs) dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) - with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): + with pytest.raises(MisconfigurationException, match="implementation has an error.*`dataset`"): _update_dataloader(dataloader, dataloader.sampler) - with _replace_init_method(DataLoader, "dataset"): + with _replace_dunder_methods(DataLoader, "dataset"): dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) new_dataloader = _update_dataloader(dataloader, dataloader.sampler) assert isinstance(new_dataloader, BadStandaloneGoodHookImpl) @@ -159,7 +160,7 @@ def __init__(self, randomize, *args, **kwargs): super().__init__(*args, shuffle=randomize, **kwargs) dataloader = BadImpl(False, []) - with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"): + with pytest.raises(MisconfigurationException, match="implementation has an error.*`shuffle`"): _update_dataloader(dataloader, dataloader.sampler) class GoodImpl(DataLoader): @@ -173,28 +174,33 @@ def __init__(self, randomize, *args, **kwargs): assert isinstance(new_dataloader, GoodImpl) -def test_replace_init_method_multiple_loaders_without_init(): +def test_replace_dunder_methods_multiple_loaders_without_init(): """In case of a 
class, that inherits from a class that we are patching, but doesn't define its own `__init__` - method (the one we are wrapping), it can happen, that `hasattr(cls, "_old_init")` is True because of parent + method (the one we are wrapping), it can happen, that `hasattr(cls, "__old__init__")` is True because of parent class, but it is impossible to delete, because that method is owned by parent class. Furthermore, the error occurred only sometimes because it depends on the order in which we are iterating over a set of classes we are patching. This test simulates the behavior by generating a sufficient number of dummy classes, which do not define `__init__` - and are children of `DataLoader`. We are testing that a) context manager `_replace_init_method` exits cleanly, and - b) the mechanism checking for presence of `_old_init` works as expected. + and are children of `DataLoader`. We are testing that a) context manager `_replace_dunder_methods` exits cleanly, and + b) the mechanism checking for presence of `__old__init__` works as expected. """ classes = [DataLoader] for i in range(100): classes.append(type(f"DataLoader_{i}", (random.choice(classes),), {})) - with _replace_init_method(DataLoader, "dataset"): + before = {cls: cls.__init__ for cls in classes} + + with _replace_dunder_methods(DataLoader, "dataset"): for cls in classes[1:]: # First one is `DataLoader` - assert "_old_init" not in cls.__dict__ - assert hasattr(cls, "_old_init") + assert "__old__init__" not in cls.__dict__ + assert hasattr(cls, "__old__init__") + + assert "__old__init__" in DataLoader.__dict__ + assert hasattr(DataLoader, "__old__init__") - assert "_old_init" in DataLoader.__dict__ - assert hasattr(DataLoader, "_old_init") + for cls in classes: + assert before[cls] == cls.__init__ class DataLoaderSubclass1(DataLoader): @@ -322,8 +328,8 @@ def __init__(self, dataset, **kwargs): pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"), ], ) -def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): - with _replace_init_method(DataLoader, "dataset"): +def test_replace_dunder_methods_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): + with _replace_dunder_methods(DataLoader, "dataset"): dataloader = cls(*args, **kwargs) assert dataloader.__pl_saved_args == args @@ -360,12 +366,12 @@ def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, c assert dataloader_value == value -def test_replace_init_method_extra_kwargs(): +def test_replace_dunder_methods_extra_kwargs(): class LoaderSubclass(DataLoader): def __init__(self, dataset, *args, batch_size=10, **kwargs): super().__init__(dataset, *args, batch_size=batch_size, **kwargs) - with _replace_init_method(DataLoader, "dataset"): + with _replace_dunder_methods(DataLoader, "dataset"): dataloader = LoaderSubclass(range(10)) assert dataloader.__pl_saved_args == (range(10),) @@ -375,6 +381,90 @@ def __init__(self, dataset, *args, batch_size=10, **kwargs): assert dataloader.__dataset == range(10) +def test_replace_dunder_methods_attrs(): + """This test checks that all the calls from setting and deleting attributes within `_replace_dunder_methods` + are correctly preserved even after reinstantiation.
+ + It also includes a custom `__setattr__`. + """ + + class Loader(DataLoader): + def __setattr__(self, attr, val): + if attr == "custom_arg": + val = val + 2 + super().__setattr__(attr, val) + + with _replace_dunder_methods(DataLoader, "dataset"): + dataloader = Loader(range(10)) + dataloader.custom_arg = 5 + dataloader.my_arg = 10 + dataloader.another_arg = 100 + del dataloader.dataset + try: + del dataloader.abc_arg + except AttributeError: + pass + + assert dataloader.__pl_saved_args == (range(10),) + assert dataloader.__pl_saved_kwargs == {} + assert dataloader.__pl_saved_arg_names == ("dataset",) + assert dataloader.__dataset == range(10) + assert dataloader.custom_arg == 7 + assert dataloader.my_arg == 10 + assert dataloader.another_arg == 100 + assert not hasattr(dataloader, "dataset") + assert dataloader.__pl_attrs_record == [ + (("custom_arg", 5), _WrapAttrTag.SET), + (("my_arg", 10), _WrapAttrTag.SET), + (("another_arg", 100), _WrapAttrTag.SET), + (("dataset",), _WrapAttrTag.DEL), + ] + + dataloader = _update_dataloader(dataloader, dataloader.sampler) + assert dataloader.custom_arg == 7 + assert dataloader.my_arg == 10 + assert dataloader.another_arg == 100 + assert not hasattr(dataloader, "dataset") + + +def test_replace_dunder_methods_restore_methods(): + """This test checks whether all dunder methods are restored to their original versions.""" + + class Init(DataLoader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + class SetAttr(DataLoader): + def __setattr__(self, *args): + return super().__setattr__(*args) + + class DelAttr(DataLoader): + def __delattr__(self, *args): + return super().__delattr__(*args) + + class InitAndSetAttr(Init, SetAttr): + pass + + class InitAndDelAttr(Init, DelAttr): + pass + + class SetAttrAndDelAttr(SetAttr, DelAttr): + pass + + class AllDunder(Init, SetAttr, DelAttr): + pass + + before = dict() + for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder): + before[cls] = {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__} + + with _replace_dunder_methods(DataLoader, "dataset"): + pass + + for cls in (Init, SetAttr, DelAttr, InitAndSetAttr, InitAndDelAttr, SetAttrAndDelAttr, AllDunder): + assert before[cls] == {"init": cls.__init__, "setattr": cls.__setattr__, "delattr": cls.__delattr__} + + @pytest.mark.parametrize("predicting", [True, False]) def test_custom_batch_sampler(predicting): """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to @@ -391,8 +481,8 @@ def __init__(self, sampler, extra_arg, drop_last=True): super().__init__(sampler, 10, drop_last) sampler = RandomSampler(range(10)) - with _replace_init_method(BatchSampler): - # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + with _replace_dunder_methods(BatchSampler): + # instantiate within `_replace_dunder_methods` context manager, simulating `*_dataloader` hooks + batch_sampler =
MyBatchSampler(sampler, "random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler) @@ -437,8 +527,8 @@ def __init__(self, sampler, extra_arg): super().__init__(sampler, 10, False) sampler = RandomSampler(range(10)) - with _replace_init_method(BatchSampler): - # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + with _replace_dunder_methods(BatchSampler): + # instantiate within `_replace_dunder_methods` context manager, simulating `*_dataloader` hooks + batch_sampler = MyBatchSampler(sampler, "random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler) @@ -464,8 +554,8 @@ def __init__(self, extra_arg): self.extra_arg = extra_arg super().__init__(RandomSampler(range(10)), 10, False) - with _replace_init_method(BatchSampler): - # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + with _replace_dunder_methods(BatchSampler): + # instantiate within `_replace_dunder_methods` context manager, simulating `*_dataloader` hooks + batch_sampler = MyBatchSampler("random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler)
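
The patch above records constructor arguments (`__pl_saved_args`/`__pl_saved_kwargs`) and post-`__init__` attribute mutations (`__pl_attrs_record`) so that `_reinstantiate_wrapped_cls` can later rebuild an equivalent `DataLoader`/`BatchSampler`. Below is a minimal, self-contained sketch of that record-and-replay idea under simplified assumptions: `record_mutations` and `rebuild` are hypothetical stand-ins for `_replace_dunder_methods` and `_reinstantiate_wrapped_cls`, and the `Loader` class is a plain object rather than a real `DataLoader`.

```python
import functools
from contextlib import contextmanager
from typing import Any


@contextmanager
def record_mutations(cls: type):
    """Patch ``__init__``/``__setattr__``/``__delattr__`` on ``cls`` to record calls (illustrative only)."""
    old_init, old_setattr, old_delattr = cls.__init__, cls.__setattr__, cls.__delattr__

    @functools.wraps(old_init)
    def init(self, *args: Any, **kwargs: Any) -> None:
        # Bookkeeping attributes are written with object.__setattr__ so they don't get recorded themselves.
        object.__setattr__(self, "_saved_args", (args, kwargs))
        object.__setattr__(self, "_attrs_record", [])
        object.__setattr__(self, "_inside_init", True)
        old_init(self, *args, **kwargs)
        object.__setattr__(self, "_inside_init", False)

    def set_attr(self, name: str, value: Any) -> None:
        old_setattr(self, name, value)
        if not getattr(self, "_inside_init", True):  # only record mutations made after __init__ finished
            self._attrs_record.append(("set", name, value))

    def del_attr(self, name: str) -> None:
        old_delattr(self, name)
        if not getattr(self, "_inside_init", True):
            self._attrs_record.append(("del", name))

    cls.__init__, cls.__setattr__, cls.__delattr__ = init, set_attr, del_attr
    try:
        yield
    finally:
        # restore the original dunder methods on exit, mirroring the context manager in the patch
        cls.__init__, cls.__setattr__, cls.__delattr__ = old_init, old_setattr, old_delattr


def rebuild(obj: Any) -> Any:
    """Re-instantiate ``obj`` from its recorded constructor args and replay later mutations."""
    args, kwargs = obj._saved_args
    new = type(obj)(*args, **kwargs)
    for record in obj._attrs_record:
        if record[0] == "set":
            setattr(new, record[1], record[2])
        else:
            delattr(new, record[1])
    return new


class Loader:
    def __init__(self, dataset):
        self.dataset = dataset


with record_mutations(Loader):
    loader = Loader(range(10))
    loader.num_workers = 4  # set after __init__, so it is recorded and replayed on rebuild

clone = rebuild(loader)
assert clone.dataset == range(10) and clone.num_workers == 4
```

In the actual change, the bookkeeping attributes (`__pl_saved_args`, `__pl_attrs_record`, `__pl_inside_init`, `__pl_current_call`) are likewise written through `object.__setattr__`, so Lightning's own writes never pass through the patched `__setattr__` and end up in the record.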