From cc93069063b83f5e1149bd07952c85f9232a4e17 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Wed, 18 Oct 2017 00:02:45 +0900 Subject: [PATCH] v0.12 regression: Fix registration of children for Block (#8277) * Fix Block not registering children If the attribute was already set to something different than Block (e.g. None), it was not being registered. * fix if / elif for block children registration * trigger test * Add fix from #8152 * Add tests from #8152 --- python/mxnet/gluon/block.py | 7 +++++-- tests/python/unittest/test_gluon.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index fb4ac8525299..73dbfc10fed7 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -191,9 +191,10 @@ def __setattr__(self, name, value): for i, c in enumerate(self._children): if c is existing: self._children[i] = value - else: - if isinstance(value, Block): + elif isinstance(value, Block): self.register_child(value) + elif isinstance(value, Block): + self.register_child(value) super(Block, self).__setattr__(name, value) @@ -332,6 +333,8 @@ def __init__(self, prefix=None, params=None): def __setattr__(self, name, value): """Registers parameters.""" super(HybridBlock, self).__setattr__(name, value) + if isinstance(value, HybridBlock): + self._clear_cached_op() if isinstance(value, Parameter): assert name not in self._reg_params or \ not isinstance(self._reg_params[name], Parameter), \ diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 60a0630c1665..c9bde39375d6 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -516,6 +516,20 @@ def test_hybrid_stale_cache(): net.add(mx.gluon.nn.Flatten()) assert net(mx.nd.ones((2,3,5))).shape == (2, 30) + net = mx.gluon.nn.HybridSequential() + with net.name_scope(): + net.fc1 = mx.gluon.nn.Dense(10, weight_initializer='zeros', + bias_initializer='ones', flatten=False) + net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', + bias_initializer='ones', flatten=False) + net.hybridize() + net.initialize() + net(mx.nd.ones((2,3,5))) + + net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', + bias_initializer='ones', flatten=True) + net.initialize() + assert net(mx.nd.ones((2,3,5))).shape == (2, 10) if __name__ == '__main__': import nose