Fix false DeprecationWarning in Module.state_dict
Fixes pytorch#75404

TODO:
- [x] add tests
Pull Request resolved: pytorch#75507
Approved by: https://github.com/jbschlosser
lkct authored and pytorchmergebot committed May 4, 2022
1 parent 429a80d commit b8776e1
Showing 4 changed files with 42 additions and 60 deletions.
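
For context: the warning reported in pytorch#75404 fired even when `destination` was passed as a keyword argument. After this commit, only genuinely positional arguments to `Module.state_dict` produce the deprecation message, and a plain dict works as the destination because the `_metadata` bookkeeping is now guarded with `hasattr`. A rough sketch of the expected behaviour on a build that includes this fix (the assertions are illustrative, not part of the PR):

```python
import warnings

import torch.nn as nn

linear = nn.Linear(5, 5)

# Keyword destination: discouraged, but should no longer emit the deprecation message.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    linear.state_dict(destination=dict())
assert not any("deprecat" in str(w.message).lower() for w in caught)

# Positional arguments remain deprecated and still warn.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    linear.state_dict({}, "prefix.")
assert any("deprecat" in str(w.message).lower() for w in caught)
```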
3 changes: 3 additions & 0 deletions test/test_nn.py
@@ -6170,6 +6170,9 @@ def test_state_dict(self):
self.assertEqual(state_dict['weight'].data_ptr(), l.weight.data_ptr())
self.assertEqual(state_dict['bias'].data_ptr(), l.bias.data_ptr())

# Reference https://github.com/pytorch/pytorch/pull/75507#issuecomment-1110291545
self.assertNotWarn(lambda: l.state_dict(destination=dict()), "Should not warn kwarg destination w/o _metadata")

def test_load_state_dict(self):
l = nn.Linear(5, 5)
block = nn.Module()
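
`assertNotWarn` in the added test is a helper from PyTorch's internal test harness. Outside the test suite, roughly the same check can be written with the standard library; this stand-in is a sketch, not the project's actual utility:

```python
import warnings

import torch.nn as nn


def assert_not_warn(fn, msg=""):
    # Fail if fn() emits any warning at all, mirroring the intent of assertNotWarn.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        fn()
    assert not caught, msg or f"unexpected warnings: {[str(w.message) for w in caught]}"


l = nn.Linear(5, 5)
assert_not_warn(lambda: l.state_dict(destination=dict()),
                "Should not warn kwarg destination w/o _metadata")
```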
2 changes: 1 addition & 1 deletion torch/distributed/nn/api/remote_module.py
@@ -362,7 +362,7 @@ def register_forward_pre_hook(self, hook: Callable[..., None]) -> RemovableHandle
def register_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: # type: ignore[return]
_raise_not_supported(self.register_forward_hook.__name__)

def state_dict(self, destination=None, prefix="", keep_vars=False):
def state_dict(self, *args, **kwargs):
_raise_not_supported(self.state_dict.__name__)

def load_state_dict(
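The stub's signature is widened to `*args, **kwargs` so that any call style, positional or keyword-only, reaches the explicit "not supported" error instead of failing with an unrelated `TypeError` about unexpected keyword arguments. A generic sketch of that pattern follows; the class and error type below are placeholders, not the actual RemoteModule internals:

```python
class _UnsupportedStateDictMixin:
    def state_dict(self, *args, **kwargs):
        # Accept every argument combination (destination=..., prefix=..., keep_vars=..., or
        # positional) purely so the caller always gets the explicit error below.
        raise RuntimeError(
            f"Method ``{self.state_dict.__name__}`` not supported for RemoteModule"
        )
```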
1 change: 0 additions & 1 deletion torch/jit/_script.py
@@ -893,7 +893,6 @@ def _get_methods(cls):
"double",
"half",
"state_dict",
"_state_dict_impl",
"_save_to_state_dict",
"load_state_dict",
"_load_from_state_dict",
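`_state_dict_impl` disappears from this list because the helper no longer exists on `nn.Module` after this commit (its body is inlined into `state_dict`, see the module.py diff below), so there is nothing left for the ScriptModule machinery to expose under that name.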
96 changes: 38 additions & 58 deletions torch/nn/modules/module.py
@@ -1296,75 +1296,48 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()

def _state_dict_impl(self, destination, prefix, keep_vars):
r"""Holds the actual implementation of
:meth:`~torch.nn.Module.state_dict`, with recursive calls for
descendants of this module.
In rare cases, users can call this directly to provide a custom
:attr:`destination`.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
keep_vars (bool): whether NOT to return buffers detached from
autograd
"""
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination

# TODO: Deprecated, destination is becoming private. Remove this signature when BC allows
# See https://github.com/pytorch/pytorch/issues/72778#issuecomment-1039263869
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination = TypeVar('T_destination', bound=Dict[str, Any])

@overload
def state_dict(self, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
...

@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
...

# TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
# Also remove the logic for arg parsing together.
def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
r"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
This can be called as
.. function:: state_dict(*, prefix='', keep_vars=False)
:noindex:
.. function:: state_dict(destination, prefix='', keep_vars=False)
:noindex:
.. warning::
Currently ``state_dict()`` also accepts positional arguments for
``destination``, ``prefix`` and ``keep_vars`` in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
.. warning::
The second signature is deprecated and should not be used. It's only
temporarily kept for backward compatibility and will be removed in
a future release. Use the first signature instead.
Please avoid the use of argument ``destination`` as it is not
designed for end-users.
Args:
destination (dict, optional): Deprecated. This dict is returned
with the module state saved in it. It should also have an
attribute ``_metadata: dict`` to save metadata of the module
state. If it's not provided, an ``OrderedDict`` is created and
returned. Default: ``None``
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in dict. Default: ``''``
names to compose the keys in state_dict. Default: ``''``.
keep_vars (bool, optional): by default the :class:`~torch.Tensor` s
returned in the state dict are detached from autograd. If it's
set to ``True``, detaching is not performed. Default: ``False``
set to ``True``, detaching will not be performed.
Default: ``False``.
Returns:
dict:
@@ -1377,30 +1350,37 @@ def state_dict(self, *args, destination=None, prefix='', keep_vars=False):
"""

# TODO: positional args parsing is just for BC. Remove on transition to kwargs-only
warn_msg = []
# TODO: Remove `args` and the parsing logic when BC allows.
if len(args) > 0:
warn_msg.append('positional arguments')
if destination is None:
destination = args[0]
if len(args) > 1 and prefix == '':
prefix = args[1]
if len(args) > 2 and keep_vars is False:
keep_vars = args[2]
# DeprecationWarning is ignored by default
warnings.warn(
"Positional args are being deprecated, use kwargs instead. Refer to "
"https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict"
" for details.")

if destination is not None:
warn_msg.append('argument "destination"')
else:
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()

if warn_msg:
# DeprecationWarning is ignored by default
warnings.warn(
" and ".join(warn_msg) + " are deprecated. nn.Module.state_dict will not accept them in the future. "
"Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.")
local_metadata = dict(version=self._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata

return self._state_dict_impl(destination, prefix, keep_vars)
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if module is not None:
module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination

def _register_load_state_dict_pre_hook(self, hook, with_module=False):
r"""These hooks will be called with arguments: `state_dict`, `prefix`,
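For readers tracing the new control flow: the recursion that used to live in `_state_dict_impl` is now inlined at the end of `state_dict`. A stripped-down, illustrative sketch of that flow is below; it leans on internal attributes (`_modules`, `_version`, `_save_to_state_dict`) purely for illustration and omits positional-arg parsing, extra state, and the `_state_dict_hooks` pass:

```python
from collections import OrderedDict

import torch.nn as nn


def collect_state(module: nn.Module, destination=None, prefix: str = ""):
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    # Metadata is recorded only when the destination supports it, which is what
    # lets a plain dict be passed as the destination without raising.
    if hasattr(destination, "_metadata"):
        destination._metadata[prefix[:-1]] = dict(version=module._version)

    module._save_to_state_dict(destination, prefix, keep_vars=False)
    for name, child in module._modules.items():
        if child is not None:
            collect_state(child, destination, prefix + name + ".")
    return destination


print(list(collect_state(nn.Linear(2, 2)).keys()))  # ['weight', 'bias']
```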
