diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 0ebcd201d56b..26a8a30dcdc5 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -168,7 +168,6 @@ def __init__(self, prefix=None, params=None): self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix self._scope = _BlockScope(self) self._children = [] - self._collected_params = None def __repr__(self): s = '{name}(\n{modstr}\n)' @@ -189,13 +188,11 @@ def __setattr__(self, name, value): type1=type(existing), type2=type(value))) if isinstance(existing, Block): - self._unfreeze_children() for i, c in enumerate(self._children): if c is existing: self._children[i] = value else: if isinstance(value, Block): - self._unfreeze_children() self.register_child(value) super(Block, self).__setattr__(name, value) @@ -203,17 +200,6 @@ def __setattr__(self, name, value): def _alias(self): return self.__class__.__name__.lower() - def _is_children_frozen(self): - return isinstance(self._children, tuple) - - def _freeze_children(self): - if not self._is_children_frozen(): - self._children = tuple(self._children) - - def _unfreeze_children(self): - if self._is_children_frozen(): - self._children = list(self._children) - @property def prefix(self): """Prefix of this :py:class:`Block`.""" @@ -242,15 +228,10 @@ def params(self): def collect_params(self): """Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its children's Parameters.""" - if self._is_children_frozen() and self._collected_params is not None: - return self._collected_params - - self._freeze_children() ret = ParameterDict(self._params.prefix) ret.update(self.params) for cld in self._children: ret.update(cld.collect_params()) - self._collected_params = ret return ret def save_params(self, filename): @@ -282,7 +263,6 @@ def load_params(self, filename, ctx, allow_missing=False, def register_child(self, block): """Registers block as a child of self. :py:class:`Block` s assigned to self as attributes will be registered automatically.""" - self._unfreeze_children() self._children.append(block) def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False): @@ -352,7 +332,9 @@ def __init__(self, prefix=None, params=None): def __setattr__(self, name, value): """Registers parameters.""" super(HybridBlock, self).__setattr__(name, value) - if isinstance(value, Parameter): + if isinstance(value, HybridBlock): + self._clear_cached_op() + elif isinstance(value, Parameter): assert name not in self._reg_params or \ not isinstance(self._reg_params[name], Parameter), \ "Overriding Parameter attribute %s is not allowed. " \ diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 60a0630c1665..8b157372f818 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -516,6 +516,19 @@ def test_hybrid_stale_cache(): net.add(mx.gluon.nn.Flatten()) assert net(mx.nd.ones((2,3,5))).shape == (2, 30) + net = mx.gluon.nn.HybridSequential() + with net.name_scope(): + net.fc1 = mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', flatten=False) + net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', flatten=False) + net.hybridize() + net.initialize() + net(mx.nd.ones((2,3,5))) + + net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', + flatten=True) + net.initialize() + assert net(mx.nd.ones((2,3,5))).shape == (2, 10) + if __name__ == '__main__': import nose