
fill shape after set_data
szha committed Nov 2, 2017
1 parent 8592e1c commit 4950b8c
Showing 10 changed files with 140 additions and 48 deletions.
51 changes: 27 additions & 24 deletions python/mxnet/gluon/block.py
@@ -325,7 +325,7 @@ def __init__(self, prefix=None, params=None):
self._reg_params = {}
self._cached_graph = ()
self._cached_op = None
self._cached_params = None
self._cached_call_params = ()
self._out_format = None
self._in_format = None
self._active = False
@@ -363,34 +363,34 @@ def _get_graph(self, *args):

def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
name2pos = {var.name: i for i, var in enumerate(inputs)}
self._cached_op = ndarray.CachedOp(out)

params = dict(self.collect_params().items())
self._cached_params = [params.get(name, None) for name in out.list_inputs()]
assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
if args and not isinstance(args[0], Symbol):
self._finish_deferred_init(*args)

params = dict(self.collect_params()._params)
assert len(params) + len(inputs) == len(out.list_inputs()), \
"Wrong number of inputs."

name2pos = {var.name: i for i, var in enumerate(inputs)}
self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
if name not in params]
self._cached_call_params = tuple((False, params[name].data) if name in params
else (True, name2pos[name])
for name in out.list_inputs())

def _finish_deferred_init(self, *args):
self.infer_shape(*args)
for i in self.collect_params().values():
i._finish_deferred_init()

def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)

try:
cargs = [i.data() if i else None for i in self._cached_params]
except DeferredInitializationError:
self.infer_shape(*args)
for i in self._cached_params:
if i is not None:
i._finish_deferred_init()
cargs = [i.data() if i else None for i in self._cached_params]

args, fmt = _flatten(args)
assert fmt == self._in_format, "Invalid input format"
for i, j in self._in_idx:
cargs[i] = args[j]

cargs = (args[p] if is_arg else p()
for is_arg, p in self._cached_call_params)
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
@@ -399,6 +399,7 @@ def _call_cached_op(self, *args):
def _clear_cached_op(self):
self._cached_graph = ()
self._cached_op = None
self._cached_call_params = ()

def register_child(self, block):
if not isinstance(block, HybridBlock):
@@ -462,14 +463,12 @@ def forward(self, x, *args):
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
if self._active:
return self._call_cached_op(x, *args)
try:
if self._active:
return self._call_cached_op(x, *args)
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
except DeferredInitializationError:
self.infer_shape(x, *args)
for i in self.collect_params().values():
i._finish_deferred_init()
self._finish_deferred_init(x, *args)
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
return self.hybrid_forward(ndarray, x, *args, **params)

@@ -559,7 +558,11 @@ def __init__(self, outputs, inputs, params=None):
def forward(self, x, *args):
if isinstance(x, NDArray):
with x.context:
return self._call_cached_op(x, *args)
try:
return self._call_cached_op(x, *args)
except DeferredInitializationError:
self._finish_deferred_init(x, *args)
return self._call_cached_op(x, *args)

assert isinstance(x, Symbol), \
"HybridBlock requires the first argument to forward be either " \
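Taken together, the block.py changes let a hybridized block be called before its parameter shapes are fully known: the first forward pass infers the missing dimensions, finishes the deferred initialization, and only then builds and runs the cached op. A minimal sketch of that flow, assuming a standard mxnet/gluon install from this era (the sketch itself is not part of the commit):

```python
import mxnet as mx
from mxnet import gluon, nd

net = gluon.nn.Dense(10)   # in_units not given, so the weight starts out with shape (10, 0)
net.hybridize()
net.initialize()           # initialization is deferred until the input shape is seen

y = net(nd.ones((4, 20)))  # first call: infer shapes, finish deferred init, build the cached op

print(net.weight.shape)    # (10, 20) once the input dimension has been inferred
```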
6 changes: 4 additions & 2 deletions python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
@@ -131,8 +131,10 @@ def __repr__(self):
s += ', {_conv_layout}'
s += ')'
attrs = self.__dict__
mapping = ('{_in_channels} -> {_hidden_channels}'.format(**attrs) if self._in_channels
else self._hidden_channels)
shape = self.i2h_weight.shape
in_channels = shape[1 if self._channel_axis == 1 else -1]
mapping = ('{0} -> {1}'.format(in_channels, shape[0]) if in_channels
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**attrs)
10 changes: 6 additions & 4 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -207,10 +207,11 @@ def hybrid_forward(self, F, x, weight, bias=None):

def __repr__(self):
s = '{name}({layout}, {act})'
shape = self.weight.shape
return s.format(name=self.__class__.__name__,
act=self.act if self.act else 'linear',
layout='{0} -> {1}'.format(self._in_units, self._units) if self._in_units
else self._units)
layout='{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])


class Activation(HybridBlock):
@@ -360,8 +361,9 @@ def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):

def __repr__(self):
s = '{name}({content}'
if hasattr(self, 'in_channels'):
s += ', in_channels={0}'.format(self.in_channels)
in_channels = self.gamma.shape[0]
if in_channels:
s += ', in_channels={0}'.format(in_channels)
s += ')'
return s.format(name=self.__class__.__name__,
content=', '.join(['='.join([k, v.__repr__()])
6 changes: 3 additions & 3 deletions python/mxnet/gluon/nn/conv_layers.py
@@ -153,10 +153,10 @@ def __repr__(self):
if self.bias is None:
s += ', bias=False'
s += ')'
shape = self.weight.shape
return s.format(name=self.__class__.__name__,
mapping=self._channels if not self._in_channels
else '{0} -> {1}'.format(self._in_channels,
self._channels),
mapping=shape[0] if not shape[1]
else '{0} -> {1}'.format(shape[1], shape[0]),
**self._kwargs)


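The __repr__ updates above for Dense, BatchNorm, and the convolution layers (and for the RNN cells further down in this diff) all switch from stored constructor attributes to the parameter's current shape, so the printed mapping picks up whatever has been inferred so far. A rough illustration, with the exact output strings approximate:

```python
import mxnet as mx
from mxnet import gluon, nd

dense = gluon.nn.Dense(10)   # in_units unknown at construction
dense.initialize()
print(dense)                 # roughly "Dense(10, linear)" -- only the output units are known
dense(nd.ones((2, 20)))      # the first forward pass infers in_units = 20
print(dense)                 # roughly "Dense(20 -> 10, linear)"
```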
7 changes: 5 additions & 2 deletions python/mxnet/gluon/parameter.py
@@ -171,11 +171,12 @@ def _check_and_get(self, arr_list, ctx):
def _load_init(self, data, ctx):
"""(Re)initializes by loading from data."""
if self.shape:
for i, j in zip(self.shape, data.shape):
assert i == 0 or i == j, \
for self_dim, data_dim in zip(self.shape, data.shape):
assert self_dim == 0 or self_dim == data_dim, \
"Failed loading Parameter %s from saved params: " \
"shape incompatible expacted %s vs saved %s"%(
self.name, str(self.shape), str(data.shape))
self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
if self.dtype:
assert np.dtype(self.dtype).type == data.dtype, \
"Failed loading Parameter %s from saved params: " \
@@ -344,6 +345,8 @@ def set_data(self, data):
"Parameter %s has not been initialized"%self.name
for arr in self.list_data():
arr[:] = data
if not self.shape or np.prod(self.shape) <= 0:
self.shape = data.shape

def data(self, ctx=None):
"""Returns a copy of this parameter on one context. Must have been
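The parameter.py changes are the core of the commit: _load_init and set_data now fill in any zero entries of self.shape from the loaded or assigned array, which is what allows load_params to restore a network whose shapes were never inferred by a forward pass. A hedged sketch that pokes the internal _load_init helper directly (load_params is the public entry point; calling _load_init by hand is for illustration only):

```python
import mxnet as mx
from mxnet import gluon, nd

p = gluon.Parameter('weight', shape=(10, 0))   # second dimension not yet known
p._load_init(nd.ones((10, 20)), mx.cpu())      # internal path taken by load_params
print(p.shape)                                 # (10, 20): the zero dim is filled from the data
```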
47 changes: 36 additions & 11 deletions python/mxnet/gluon/rnn/rnn_cell.py
@@ -111,17 +111,6 @@ def __init__(self, prefix=None, params=None):
self._modified = False
self.reset()

def __repr__(self):
s = '{name}({mapping}'
if hasattr(self, '_activation'):
s += ', {_activation}'
s += ')'
mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
else self._hidden_size)
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def reset(self):
"""Reset before re-using the cell for another graph."""
self._init_counter = -1
@@ -355,6 +344,18 @@ def state_info(self, batch_size=0):
def _alias(self):
return 'rnn'

def __repr__(self):
s = '{name}({mapping}'
if hasattr(self, '_activation'):
s += ', {_activation}'
s += ')'
shape = self.i2h_weight.shape
mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
prefix = 't%d_'%self._counter
@@ -453,6 +454,18 @@ def state_info(self, batch_size=0):
def _alias(self):
return 'lstm'

def __repr__(self):
s = '{name}({mapping}'
if hasattr(self, '_activation'):
s += ', {_activation}'
s += ')'
shape = self.i2h_weight.shape
mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
prefix = 't%d_'%self._counter
@@ -551,6 +564,18 @@ def state_info(self, batch_size=0):
def _alias(self):
return 'gru'

def __repr__(self):
s = '{name}({mapping}'
if hasattr(self, '_activation'):
s += ', {_activation}'
s += ')'
shape = self.i2h_weight.shape
mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
# pylint: disable=too-many-locals
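For the recurrent cells, the single __repr__ on the base class (which relied on _input_size) is replaced by per-cell implementations that derive the mapping from i2h_weight.shape, so the repr also benefits from shape filling. A rough sketch of the observable effect (the exact strings are approximate; note that for an LSTM the first weight dimension is 4x the hidden size):

```python
from mxnet import gluon, nd

cell = gluon.rnn.LSTMCell(100)                 # input_size left unspecified
cell.initialize()
out, states = cell(nd.ones((32, 45)), cell.begin_state(batch_size=32))
print(cell.i2h_weight.shape)                   # (400, 45): 4 gates x 100 hidden, 45 inputs
print(cell)                                    # mapping now reads roughly "45 -> 400"
```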
5 changes: 3 additions & 2 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -89,8 +89,9 @@ def __repr__(self):
if self._dir == 2:
s += ', bidirectional'
s += ')'
mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
else self._hidden_size)
shape = self.i2h_weight[0].shape
mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)
36 changes: 36 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -553,8 +553,44 @@ def test_lambda():
assert_almost_equal(out1.asnumpy(), out3.asnumpy())


def test_fill_shape_deferred():
net = nn.HybridSequential()
with net.name_scope():
net.add(nn.Conv2D(64, kernel_size=2, padding=1),
nn.BatchNorm(),
nn.Dense(10))
net.hybridize()
net.initialize()
net(mx.nd.ones((2,3,5,7)))
assert net[0].weight.shape[1] == 3, net[0].weight.shape[1]
assert net[1].gamma.shape[0] == 64, net[1].gamma.shape[0]
assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]


def test_fill_shape_load():
ctx = mx.context.current_context()
net1 = nn.HybridSequential()
with net1.name_scope():
net1.add(nn.Conv2D(64, kernel_size=2, padding=1),
nn.BatchNorm(),
nn.Dense(10))
net1.hybridize()
net1.initialize(ctx=ctx)
net1(mx.nd.ones((2,3,5,7), ctx))
net1.save_params('net_fill.params')

net2 = nn.HybridSequential()
with net2.name_scope():
net2.add(nn.Conv2D(64, kernel_size=2, padding=1),
nn.BatchNorm(),
nn.Dense(10))
net2.hybridize()
net2.initialize()
net2.load_params('net_fill.params', ctx)
assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1]
assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0]
assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]


if __name__ == '__main__':
import nose
7 changes: 7 additions & 0 deletions tests/python/unittest/test_gluon_contrib.py
@@ -94,6 +94,13 @@ def test_convgru():
check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


def test_conv_fill_shape():
cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
cell.hybridize()
check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]


def test_vardrop():
def check_vardrop(drop_inputs, drop_states, drop_outputs):
cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
13 changes: 13 additions & 0 deletions tests/python/unittest/test_gluon_rnn.py
@@ -274,6 +274,19 @@ def test_rnn_layers():
with mx.autograd.record():
net(mx.nd.ones((2, 3, 10))).backward()

def test_cell_fill_shape():
cell = gluon.rnn.LSTMCell(10)
cell.hybridize()
check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]

def test_layer_fill_shape():
layer = gluon.rnn.LSTM(10)
layer.hybridize()
check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
print(layer)
assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1]


if __name__ == '__main__':
import nose
