Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Oct 15, 2017
1 parent aa0a2b3 commit 4f7980e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 21 deletions.
24 changes: 3 additions & 21 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def __init__(self, prefix=None, params=None):
self._name = self._prefix[:-1] if self._prefix.endswith('_') else self._prefix
self._scope = _BlockScope(self)
self._children = []
self._collected_params = None

def __repr__(self):
s = '{name}(\n{modstr}\n)'
Expand All @@ -189,31 +188,18 @@ def __setattr__(self, name, value):
type1=type(existing),
type2=type(value)))
if isinstance(existing, Block):
self._unfreeze_children()
for i, c in enumerate(self._children):
if c is existing:
self._children[i] = value
else:
if isinstance(value, Block):
self._unfreeze_children()
self.register_child(value)

super(Block, self).__setattr__(name, value)

def _alias(self):
return self.__class__.__name__.lower()

def _is_children_frozen(self):
return isinstance(self._children, tuple)

def _freeze_children(self):
if not self._is_children_frozen():
self._children = tuple(self._children)

def _unfreeze_children(self):
if self._is_children_frozen():
self._children = list(self._children)

@property
def prefix(self):
"""Prefix of this :py:class:`Block`."""
Expand Down Expand Up @@ -242,15 +228,10 @@ def params(self):
def collect_params(self):
"""Returns a :py:class:`ParameterDict` containing this :py:class:`Block` and all of its
children's Parameters."""
if self._is_children_frozen() and self._collected_params is not None:
return self._collected_params

self._freeze_children()
ret = ParameterDict(self._params.prefix)
ret.update(self.params)
for cld in self._children:
ret.update(cld.collect_params())
self._collected_params = ret
return ret

def save_params(self, filename):
Expand Down Expand Up @@ -282,7 +263,6 @@ def load_params(self, filename, ctx, allow_missing=False,
def register_child(self, block):
"""Registers block as a child of self. :py:class:`Block` s assigned to self as
attributes will be registered automatically."""
self._unfreeze_children()
self._children.append(block)

def initialize(self, init=initializer.Uniform(), ctx=None, verbose=False):
Expand Down Expand Up @@ -352,7 +332,9 @@ def __init__(self, prefix=None, params=None):
def __setattr__(self, name, value):
"""Registers parameters."""
super(HybridBlock, self).__setattr__(name, value)
if isinstance(value, Parameter):
if isinstance(value, HybridBlock):
self._clear_cached_op()
elif isinstance(value, Parameter):
assert name not in self._reg_params or \
not isinstance(self._reg_params[name], Parameter), \
"Overriding Parameter attribute %s is not allowed. " \
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,19 @@ def test_hybrid_stale_cache():
net.add(mx.gluon.nn.Flatten())
assert net(mx.nd.ones((2,3,5))).shape == (2, 30)

net = mx.gluon.nn.HybridSequential()
with net.name_scope():
net.fc1 = mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', flatten=False)
net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', flatten=False)
net.hybridize()
net.initialize()
net(mx.nd.ones((2,3,5)))

net.fc2 = mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones',
flatten=True)
net.initialize()
assert net(mx.nd.ones((2,3,5))).shape == (2, 10)


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 4f7980e

Please sign in to comment.