diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 73dbfc10fed7..9e6da5b26510 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -325,7 +325,7 @@ def __init__(self, prefix=None, params=None):
         self._reg_params = {}
         self._cached_graph = ()
         self._cached_op = None
-        self._cached_params = None
+        self._cached_call_params = ()
         self._out_format = None
         self._in_format = None
         self._active = False
@@ -363,34 +363,34 @@ def _get_graph(self, *args):
 
     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
+        name2pos = {var.name: i for i, var in enumerate(inputs)}
         self._cached_op = ndarray.CachedOp(out)
-        params = dict(self.collect_params().items())
-        self._cached_params = [params.get(name, None) for name in out.list_inputs()]
-        assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
+        if args and not isinstance(args[0], Symbol):
+            self._finish_deferred_init(*args)
+
+        params = dict(self.collect_params()._params)
+        assert len(params) + len(inputs) == len(out.list_inputs()), \
             "Wrong number of inputs."
-        name2pos = {var.name: i for i, var in enumerate(inputs)}
-        self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
-                        if name not in params]
+        self._cached_call_params = tuple((False, params[name].data) if name in params
+                                         else (True, name2pos[name])
+                                         for name in out.list_inputs())
+
+    def _finish_deferred_init(self, *args):
+        self.infer_shape(*args)
+        for i in self.collect_params().values():
+            i._finish_deferred_init()
 
     def _call_cached_op(self, *args):
         if self._cached_op is None:
             self._build_cache(*args)
 
-        try:
-            cargs = [i.data() if i else None for i in self._cached_params]
-        except DeferredInitializationError:
-            self.infer_shape(*args)
-            for i in self._cached_params:
-                if i is not None:
-                    i._finish_deferred_init()
-            cargs = [i.data() if i else None for i in self._cached_params]
-
         args, fmt = _flatten(args)
         assert fmt == self._in_format, "Invalid input format"
-        for i, j in self._in_idx:
-            cargs[i] = args[j]
+
+        cargs = (args[p] if is_arg else p() for is_arg, p in self._cached_call_params)
         out = self._cached_op(*cargs)
         if isinstance(out, NDArray):
             out = [out]
@@ -399,6 +399,7 @@ def _call_cached_op(self, *args):
     def _clear_cached_op(self):
         self._cached_graph = ()
         self._cached_op = None
+        self._cached_call_params = ()
 
     def register_child(self, block):
         if not isinstance(block, HybridBlock):
@@ -462,14 +463,12 @@ def forward(self, x, *args):
         :py:class:`NDArray` or :py:class:`Symbol`."""
         if isinstance(x, NDArray):
             with x.context as ctx:
-                if self._active:
-                    return self._call_cached_op(x, *args)
                 try:
+                    if self._active:
+                        return self._call_cached_op(x, *args)
                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 except DeferredInitializationError:
-                    self.infer_shape(x, *args)
-                    for i in self.collect_params().values():
-                        i._finish_deferred_init()
+                    self._finish_deferred_init(x, *args)
                     params = {i: j.data(ctx) for i, j in self._reg_params.items()}
                 return self.hybrid_forward(ndarray, x, *args, **params)
 
@@ -559,7 +558,11 @@ def __init__(self, outputs, inputs, params=None):
     def forward(self, x, *args):
         if isinstance(x, NDArray):
             with x.context:
-                return self._call_cached_op(x, *args)
+                try:
+                    return self._call_cached_op(x, *args)
+                except DeferredInitializationError:
+                    self._finish_deferred_init(x, *args)
+                    return self._call_cached_op(x, *args)
 
         assert isinstance(x, Symbol), \
             "HybridBlock requires the first argument to forward be either " \
diff --git a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
index cbb3f1a43b37..748314b1be22 100644
--- a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
@@ -131,8 +131,10 @@ def __repr__(self):
             s += ', {_conv_layout}'
         s += ')'
         attrs = self.__dict__
-        mapping = ('{_in_channels} -> {_hidden_channels}'.format(**attrs) if self._in_channels
-                   else self._hidden_channels)
+        shape = self.i2h_weight.shape
+        in_channels = shape[1 if self._channel_axis == 1 else -1]
+        mapping = ('{0} -> {1}'.format(in_channels, shape[0]) if in_channels
+                   else shape[0])
         return s.format(name=self.__class__.__name__,
                         mapping=mapping, **attrs)
 
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index e9fb2ffc62ef..fadfc9f8f3cc 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -207,10 +207,11 @@ def hybrid_forward(self, F, x, weight, bias=None):
 
     def __repr__(self):
         s = '{name}({layout}, {act})'
+        shape = self.weight.shape
         return s.format(name=self.__class__.__name__,
                         act=self.act if self.act else 'linear',
-                        layout='{0} -> {1}'.format(self._in_units, self._units) if self._in_units
-                        else self._units)
+                        layout='{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                        else shape[0])
 
 
 class Activation(HybridBlock):
@@ -360,8 +361,9 @@ def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
 
     def __repr__(self):
         s = '{name}({content}'
-        if hasattr(self, 'in_channels'):
-            s += ', in_channels={0}'.format(self.in_channels)
+        in_channels = self.gamma.shape[0]
+        if in_channels:
+            s += ', in_channels={0}'.format(in_channels)
         s += ')'
         return s.format(name=self.__class__.__name__,
                         content=', '.join(['='.join([k, v.__repr__()])
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 8dcdbc3b040f..f7d1b97ce9dd 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -153,10 +153,10 @@ def __repr__(self):
         if self.bias is None:
             s += ', bias=False'
         s += ')'
+        shape = self.weight.shape
         return s.format(name=self.__class__.__name__,
-                        mapping=self._channels if not self._in_channels
-                        else '{0} -> {1}'.format(self._in_channels,
-                                                 self._channels),
+                        mapping=shape[0] if not shape[1]
+                        else '{0} -> {1}'.format(shape[1], shape[0]),
                         **self._kwargs)
 
 
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index c73aee28a1aa..c42fbaa1fa44 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -171,11 +171,12 @@ def _check_and_get(self, arr_list, ctx):
     def _load_init(self, data, ctx):
         """(Re)initializes by loading from data."""
        if self.shape:
-            for i, j in zip(self.shape, data.shape):
-                assert i == 0 or i == j, \
+            for self_dim, data_dim in zip(self.shape, data.shape):
+                assert self_dim == 0 or self_dim == data_dim, \
                     "Failed loading Parameter %s from saved params: " \
                     "shape incompatible expacted %s vs saved %s"%(
                         self.name, str(self.shape), str(data.shape))
+            self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
         if self.dtype:
             assert np.dtype(self.dtype).type == data.dtype, \
                 "Failed loading Parameter %s from saved params: " \
@@ -344,6 +345,8 @@ def set_data(self, data):
             "Parameter %s has not been initialized"%self.name
         for arr in self.list_data():
             arr[:] = data
+        if not self.shape or np.prod(self.shape) <= 0:
+            self.shape = data.shape
 
     def data(self, ctx=None):
         """Returns a copy of this parameter on one context. Must have been
diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 9d318eb092ea..57c94f5fa2e9 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -111,17 +111,6 @@ def __init__(self, prefix=None, params=None):
         self._modified = False
         self.reset()
 
-    def __repr__(self):
-        s = '{name}({mapping}'
-        if hasattr(self, '_activation'):
-            s += ', {_activation}'
-        s += ')'
-        mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
-                   else self._hidden_size)
-        return s.format(name=self.__class__.__name__,
-                        mapping=mapping,
-                        **self.__dict__)
-
     def reset(self):
         """Reset before re-using the cell for another graph."""
         self._init_counter = -1
@@ -355,6 +344,18 @@ def state_info(self, batch_size=0):
     def _alias(self):
         return 'rnn'
 
+    def __repr__(self):
+        s = '{name}({mapping}'
+        if hasattr(self, '_activation'):
+            s += ', {_activation}'
+        s += ')'
+        shape = self.i2h_weight.shape
+        mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                   else shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         prefix = 't%d_'%self._counter
@@ -453,6 +454,18 @@ def state_info(self, batch_size=0):
     def _alias(self):
         return 'lstm'
 
+    def __repr__(self):
+        s = '{name}({mapping}'
+        if hasattr(self, '_activation'):
+            s += ', {_activation}'
+        s += ')'
+        shape = self.i2h_weight.shape
+        mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                   else shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         prefix = 't%d_'%self._counter
@@ -551,6 +564,18 @@ def state_info(self, batch_size=0):
     def _alias(self):
         return 'gru'
 
+    def __repr__(self):
+        s = '{name}({mapping}'
+        if hasattr(self, '_activation'):
+            s += ', {_activation}'
+        s += ')'
+        shape = self.i2h_weight.shape
+        mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                   else shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         # pylint: disable=too-many-locals
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 2d7c0087891a..f8986491d4ef 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -89,8 +89,9 @@ def __repr__(self):
         if self._dir == 2:
             s += ', bidirectional'
         s += ')'
-        mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
-                   else self._hidden_size)
+        shape = self.i2h_weight[0].shape
+        mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                   else shape[0])
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 6f9966b7211c..df0af34dfef5 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -553,8 +553,44 @@ def test_lambda():
     assert_almost_equal(out1.asnumpy(), out3.asnumpy())
 
 
+def test_fill_shape_deferred():
+    net = nn.HybridSequential()
+    with net.name_scope():
+        net.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                nn.BatchNorm(),
+                nn.Dense(10))
+    net.hybridize()
+    net.initialize()
+    net(mx.nd.ones((2,3,5,7)))
+    assert net[0].weight.shape[1] == 3, net[0].weight.shape[1]
+    assert net[1].gamma.shape[0] == 64, net[1].gamma.shape[0]
+    assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]
+
+
+def test_fill_shape_load():
+    ctx = mx.context.current_context()
+    net1 = nn.HybridSequential()
+    with net1.name_scope():
+        net1.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                 nn.BatchNorm(),
+                 nn.Dense(10))
+    net1.hybridize()
+    net1.initialize(ctx=ctx)
+    net1(mx.nd.ones((2,3,5,7), ctx))
+    net1.save_params('net_fill.params')
+
+    net2 = nn.HybridSequential()
+    with net2.name_scope():
+        net2.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                 nn.BatchNorm(),
+                 nn.Dense(10))
+    net2.hybridize()
+    net2.initialize()
+    net2.load_params('net_fill.params', ctx)
+    assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1]
+    assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0]
+    assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
+
+
 if __name__ == '__main__':
     import nose
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index c99836c5aa22..07b8956988fd 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -94,6 +94,13 @@ def test_convgru():
     check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50),
                    out_shape=(1, 100, 18, 28, 48))
 
+def test_conv_fill_shape():
+    cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
+    cell.hybridize()
+    check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
+    assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]
+
+
 def test_vardrop():
     def check_vardrop(drop_inputs, drop_states, drop_outputs):
         cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index f71ac1809f64..228884219258 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -274,6 +274,19 @@ def test_rnn_layers():
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()
 
 
+def test_cell_fill_shape():
+    cell = gluon.rnn.LSTMCell(10)
+    cell.hybridize()
+    check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
+    assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
+
+def test_layer_fill_shape():
+    layer = gluon.rnn.LSTM(10)
+    layer.hybridize()
+    check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
+    print(layer)
+    assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1]
+
 if __name__ == '__main__':
     import nose