From b8776e143f11f92268b639a574dee79f50365be0 Mon Sep 17 00:00:00 2001 From: lkct Date: Wed, 4 May 2022 20:08:23 +0000 Subject: [PATCH] Fix false DeprecationWarning in `Module.state_dict` Fixes #75404 TODO: - [x] add tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/75507 Approved by: https://github.com/jbschlosser --- test/test_nn.py | 3 + torch/distributed/nn/api/remote_module.py | 2 +- torch/jit/_script.py | 1 - torch/nn/modules/module.py | 96 +++++++++-------------- 4 files changed, 42 insertions(+), 60 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index c326023ac0808..7685da8b0837b 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -6170,6 +6170,9 @@ def test_state_dict(self): self.assertEqual(state_dict['weight'].data_ptr(), l.weight.data_ptr()) self.assertEqual(state_dict['bias'].data_ptr(), l.bias.data_ptr()) + # Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545 + self.assertNotWarn(lambda: l.state_dict(destination=dict()), "Should not warn kwarg destination w/o _metadata") + def test_load_state_dict(self): l = nn.Linear(5, 5) block = nn.Module() diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index 512e532730509..147f5ac70cae8 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -362,7 +362,7 @@ def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandl def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return] _raise_not_supported(self.register_forward_hook.__name__) - def state_dict(self, destination=None, prefix="", keep_vars=False): + def state_dict(self, *args, **kwargs): _raise_not_supported(self.state_dict.__name__) def load_state_dict( diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 73aca43a11990..1f8f5f23543b7 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -893,7 +893,6 @@ def _get_methods(cls): "double", "half", "state_dict", - "_state_dict_impl", "_save_to_state_dict", "load_state_dict", "_load_from_state_dict", diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index d142a240cd5f8..57f550f040039 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -1296,44 +1296,20 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: destination[extra_state_key] = self.get_extra_state() - def _state_dict_impl(self, destination, prefix, keep_vars): - r"""Holds the actual implementation of - :meth:`~torch.nn.Module.state_dict`, with recursive calls for - descendants of this module. - - In rare cases, users can call this directly to provide a custom - :attr:`destination`. - - Args: - destination (dict): a dict where state will be stored - prefix (str): the prefix for parameters and buffers used in this - module - keep_vars (bool): whether NOT to return buffers detached from - autograd - """ - destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars) - for name, module in self._modules.items(): - if module is not None: - module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars) - for hook in self._state_dict_hooks.values(): - hook_result = hook(self, destination, prefix, local_metadata) - if hook_result is not None: - destination = hook_result - return destination - - # TODO: Deprecated, destination is becoming private. Remove this signature when BC allows - # See https://github.com/pytorch/pytorch/issues/72778#issuecomment-1039263869 + # The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns + # back that same object. But if they pass nothing, an `OrederedDict` is created and returned. T_destination = TypeVar('T_destination', bound=Dict[str, Any]) @overload - def state_dict(self, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: + def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ... @overload def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ... + # TODO: Change `*args` to `*` and remove the copprespinding warning in docs when BC allows. + # Also remove the logic for arg parsing together. def state_dict(self, *args, destination=None, prefix='', keep_vars=False): r"""Returns a dictionary containing a whole state of the module. @@ -1341,30 +1317,27 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included. - This can be called as - - .. function:: state_dict(*, prefix='', keep_vars=False) - :noindex: - - .. function:: state_dict(destination, prefix='', keep_vars=False) - :noindex: + .. warning:: + Currently ``state_dict()`` also accepts positional arguments for + ``destination``, ``prefix`` and ``keep_vars`` in order. However, + this is being deprecated and keyword arguments will be enforced in + future releases. .. warning:: - The second signature is deprecated and should not be used. It's only - temporarily kept for backward compatibility and will be removed in - a future release. Use the first signature instead. + Please avoid the use of argument ``destination`` as it is not + designed for end-users. Args: - destination (dict, optional): Deprecated. This dict is returned - with the module state saved in it. It should also have an - attribute ``_metadata: dict`` to save metadata of the module - state. If it's not provided, an ``OrderedDict`` is created and - returned. Default: ``None`` + destination (dict, optional): If provided, the state of module will + be updated into the dict and the same object is returned. + Otherwise, an ``OrderedDict`` will be created and returned. + Default: ``None``. prefix (str, optional): a prefix added to parameter and buffer - names to compose the keys in dict. Default: ``''`` + names to compose the keys in state_dict. Default: ``''``. keep_vars (bool, optional): by default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd. If it's - set to ``True``, detaching is not performed. Default: ``False`` + set to ``True``, detaching will not be performed. + Default: ``False``. Returns: dict: @@ -1377,30 +1350,37 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False): """ - # TODO: positional args parsing is just for BC. Remove on transition to kwargs-only - warn_msg = [] + # TODO: Remove `args` and the parsing logic when BC allows. if len(args) > 0: - warn_msg.append('positional arguments') if destination is None: destination = args[0] if len(args) > 1 and prefix == '': prefix = args[1] if len(args) > 2 and keep_vars is False: keep_vars = args[2] + # DeprecationWarning is ignored by default + warnings.warn( + "Positional args are being deprecated, use kwargs instead. Refer to " + "https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict" + " for details.") - if destination is not None: - warn_msg.append('argument "destination"') - else: + if destination is None: destination = OrderedDict() destination._metadata = OrderedDict() - if warn_msg: - # DeprecationWarning is ignored by default - warnings.warn( - " and ".join(warn_msg) + " are deprecated. nn.Module.state_dict will not accept them in the future. " - "Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.") + local_metadata = dict(version=self._version) + if hasattr(destination, "_metadata"): + destination._metadata[prefix[:-1]] = local_metadata - return self._state_dict_impl(destination, prefix, keep_vars) + self._save_to_state_dict(destination, prefix, keep_vars) + for name, module in self._modules.items(): + if module is not None: + module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars) + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination def _register_load_state_dict_pre_hook(self, hook, with_module=False): r"""These hooks will be called with arguments: `state_dict`, `prefix`,