Skip to content

Commit

Permalink
v0.12 regression: Fix registration of children for Block (apache#8277)
Browse files Browse the repository at this point in the history
* Fix Block not registering children

If the attribute was already set to something different than Block (e.g. None),
it was not being registered.

* fix if / elif for block children registration

* trigger test

* Add fix from apache#8152

* Add tests from apache#8152
  • Loading branch information
leezu authored and crazy-cat committed Oct 26, 2017
1 parent eb4e78d commit 1f69b70
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,10 @@ def __setattr__(self, name, value):
for i, c in enumerate(self._children):
if c is existing:
self._children[i] = value
else:
if isinstance(value, Block):
elif isinstance(value, Block):
self.register_child(value)
elif isinstance(value, Block):
self.register_child(value)

super(Block, self).__setattr__(name, value)

Expand Down Expand Up @@ -332,6 +333,8 @@ def __init__(self, prefix=None, params=None):
def __setattr__(self, name, value):
"""Registers parameters."""
super(HybridBlock, self).__setattr__(name, value)
if isinstance(value, HybridBlock):
self._clear_cached_op()
if isinstance(value, Parameter):
assert name not in self._reg_params or \
not isinstance(self._reg_params[name], Parameter), \
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,20 @@ def test_hybrid_stale_cache():
net.add(mx.gluon.nn.Flatten())
assert net(mx.nd.ones((2,3,5))).shape == (2, 30)

net = mx.gluon.nn.HybridSequential()
with net.name_scope():
net.fc1 = mx.gluon.nn.Dense(10, weight_initializer='zeros',
bias_initializer='ones', flatten=False)
net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros',
bias_initializer='ones', flatten=False)
net.hybridize()
net.initialize()
net(mx.nd.ones((2,3,5)))

net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros',
bias_initializer='ones', flatten=True)
net.initialize()
assert net(mx.nd.ones((2,3,5))).shape == (2, 10)

if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 1f69b70

Please sign in to comment.