API Improvement for paddle.nn.layer.state_dict: usability enhancement (#64358)
* update state_dict

* update

* fix test

* fix test

* fix test

* fix test

* update test

* update test
NKNaN authored May 24, 2024
1 parent 4aa1c20 commit 12ecf2e
Showing 2 changed files with 59 additions and 3 deletions.
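
In short, Layer.state_dict (along with the internal _state_dict_impl and to_static_state_dict) gains a keep_vars flag. With the default keep_vars=True, the returned dict holds the live parameter and buffer tensors, autograd state included; with keep_vars=False, every entry is detached from the autograd graph, mirroring the keep_vars flag on PyTorch's Module.state_dict (whose default is False). A minimal usage sketch, assuming this commit is applied; the nn.Linear model and the "weight" key are illustrative:

    import paddle
    from paddle import nn

    model = nn.Linear(4, 2)
    model(paddle.randn([3, 4])).sum().backward()  # populate .grad on the parameters

    live = model.state_dict()                   # keep_vars=True (default): live tensors
    frozen = model.state_dict(keep_vars=False)  # every entry is detach()-ed

    print(live["weight"].grad is not None)  # True: gradient still attached
    print(frozen["weight"].grad)            # None: detached from autograd
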
python/paddle/nn/layer/layers.py (18 additions, 3 deletions)
@@ -1871,6 +1871,7 @@ def _state_dict_impl(
         structured_name_prefix="",
         include_non_persistable_buffer=False,
         use_hook=True,
+        keep_vars=True,
     ):
         """
         Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
@@ -1880,23 +1881,30 @@ def _state_dict_impl(
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.
         """

         if destination is None:
             destination = collections.OrderedDict()
         for name, data in self._parameters.items():
             if data is not None:
-                destination[structured_name_prefix + name] = data
+                destination[structured_name_prefix + name] = (
+                    data if keep_vars else data.detach()
+                )
         for name, buffer in self._buffers.items():
             if not include_non_persistable_buffer:
                 if (
                     buffer is not None
                     and name not in self._non_persistable_buffer_names_set
                 ):
-                    destination[structured_name_prefix + name] = buffer
+                    destination[structured_name_prefix + name] = (
+                        buffer if keep_vars else buffer.detach()
+                    )
             else:
                 if buffer is not None:
-                    destination[structured_name_prefix + name] = buffer
+                    destination[structured_name_prefix + name] = (
+                        buffer if keep_vars else buffer.detach()
+                    )

         if include_sublayers:
             for layer_name, layer_item in self._sub_layers.items():
@@ -1909,6 +1917,7 @@ def _state_dict_impl(
                             structured_name_prefix + layer_name + ".",
                             include_non_persistable_buffer,
                             use_hook,
+                            keep_vars,
                         )
                     )
                     destination = destination_temp
@@ -1926,6 +1935,7 @@ def to_static_state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         '''
@@ -1935,6 +1945,7 @@ def to_static_state_dict(
             destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None.
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.

         Returns:
             dict, a dict contains all the parameters and persistable buffers.
@@ -1956,6 +1967,7 @@ def to_static_state_dict(
             structured_name_prefix=structured_name_prefix,
             include_non_persistable_buffer=True,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )

     def state_dict(
@@ -1964,6 +1976,7 @@ def state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         '''
         Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict
@@ -1972,6 +1985,7 @@ def state_dict(
             destination(dict, optional) : If provide, all the parameters and persistable buffers will be set to this dict . Default: None.
             include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True.
             use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True.
+            keep_vars(bool, optional) : If false, the returned tensors in the state dict are detached from autograd. Default: True.

         Returns:
             dict: a dict contains all the parameters and persistable buffers.
@@ -1993,6 +2007,7 @@ def state_dict(
             structured_name_prefix=structured_name_prefix,
             include_non_persistable_buffer=False,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )

     @framework.deprecate_stat_dict
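
The heart of the change is the "data if keep_vars else data.detach()" pattern, threaded recursively into every sublayer through the new trailing argument. For context (a sketch, not part of the commit): Tensor.detach() in Paddle returns a tensor that shares storage with the original but is severed from the autograd graph, so building the detached dict copies no data:

    import paddle

    w = paddle.create_parameter(shape=[2, 2], dtype="float32")
    d = w.detach()

    print(d.stop_gradient)  # True: a detached tensor never joins a backward pass
    print(d.grad)           # None: no gradient is carried over
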
test/legacy_test/test_state_dict_convert.py (41 additions, 0 deletions)
@@ -35,12 +35,14 @@ def state_dict(
         include_sublayers=True,
         structured_name_prefix="",
         use_hook=True,
+        keep_vars=True,
     ):
         st = super().state_dict(
             destination=destination,
             include_sublayers=include_sublayers,
             structured_name_prefix=structured_name_prefix,
             use_hook=use_hook,
+            keep_vars=keep_vars,
         )
         st["linear.new_weight"] = paddle.transpose(
             st.pop("linear.weight"), [1, 0]
@@ -75,6 +77,17 @@ def is_state_dict_equal(model1, model2):
     return True


+class MyModel3(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(100, 300)
+        buffer = paddle.to_tensor([0.0])
+        self.register_buffer("model_buffer", buffer, persistable=True)
+
+    def forward(self, x):
+        return self.linear(x)
+
+
 class TestStateDictConvert(unittest.TestCase):
     def test_main(self):
         model1 = MyModel()
@@ -97,5 +110,33 @@ def test_missing_keys_and_unexpected_keys(self):
         self.assertEqual(unexpected_keys[0], "unexpected_keys")


+class TestStateKeepVars(unittest.TestCase):
+    def test_true(self):
+        model = MyModel3()
+        x = paddle.randn([5, 100])
+        y = model(x)
+        y.backward()
+        st = model.state_dict()
+        has_grad = (
+            (st["linear.weight"].grad == model.linear.weight.grad).all()
+            and (st["linear.bias"].grad == model.linear.bias.grad).all()
+            and st["model_buffer"].grad == model.model_buffer.grad
+        )
+        self.assertEqual(has_grad, True)
+
+    def test_false(self):
+        model = MyModel3()
+        x = paddle.randn([5, 100])
+        y = model(x)
+        y.backward()
+        st = model.state_dict(keep_vars=False)
+        has_grad = (
+            (st["linear.weight"].grad is not None)
+            and (st["linear.bias"].grad is not None)
+            and (st["model_buffer"].grad is not None)
+        )
+        self.assertEqual(has_grad, False)
+
+
 if __name__ == "__main__":
     unittest.main()

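The two new tests pin down both sides of the contract: with the default, the dict entries alias the live parameters, so their gradients compare equal; with keep_vars=False, every .grad comes back None. The detached form is the natural choice for serialization, since nothing in the saved dict drags the autograd graph along. A hedged sketch of that workflow (the file name is arbitrary):

    import paddle
    from paddle import nn

    model = nn.Linear(100, 300)
    paddle.save(model.state_dict(keep_vars=False), "linear.pdparams")

    restored = nn.Linear(100, 300)
    restored.set_state_dict(paddle.load("linear.pdparams"))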