This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

gluon improvement #8152

Closed · wants to merge 1 commit
51 changes: 27 additions & 24 deletions python/mxnet/gluon/block.py
@@ -325,7 +325,7 @@ def __init__(self, prefix=None, params=None):
self._reg_params = {}
self._cached_graph = ()
self._cached_op = None
self._cached_params = None
self._cached_call_params = ()
self._out_format = None
self._in_format = None
self._active = False
@@ -363,34 +363,34 @@ def _get_graph(self, *args):

def _build_cache(self, *args):
inputs, out = self._get_graph(*args)
name2pos = {var.name: i for i, var in enumerate(inputs)}
self._cached_op = ndarray.CachedOp(out)

params = dict(self.collect_params().items())
self._cached_params = [params.get(name, None) for name in out.list_inputs()]
assert len(params) + len(self._cached_graph[0]) == len(out.list_inputs()), \
if args and not isinstance(args[0], Symbol):
self._finish_deferred_init(*args)

params = dict(self.collect_params()._params)
assert len(params) + len(inputs) == len(out.list_inputs()), \
"Wrong number of inputs."

name2pos = {var.name: i for i, var in enumerate(inputs)}
self._in_idx = [(i, name2pos[name]) for i, name in enumerate(out.list_inputs())
if name not in params]
self._cached_call_params = tuple((False, params[name].data) if name in params
else (True, name2pos[name])
for name in out.list_inputs())

def _finish_deferred_init(self, *args):
self.infer_shape(*args)
for i in self.collect_params().values():
i._finish_deferred_init()

def _call_cached_op(self, *args):
if self._cached_op is None:
self._build_cache(*args)

try:
cargs = [i.data() if i else None for i in self._cached_params]
except DeferredInitializationError:
self.infer_shape(*args)
for i in self._cached_params:
if i is not None:
i._finish_deferred_init()
cargs = [i.data() if i else None for i in self._cached_params]

args, fmt = _flatten(args)
assert fmt == self._in_format, "Invalid input format"
for i, j in self._in_idx:
cargs[i] = args[j]

cargs = (args[p] if is_arg else p()
for is_arg, p in self._cached_call_params)
out = self._cached_op(*cargs)
if isinstance(out, NDArray):
out = [out]
@@ -399,6 +399,7 @@ def _call_cached_op(self, *args):
def _clear_cached_op(self):
self._cached_graph = ()
self._cached_op = None
self._cached_call_params = ()

def register_child(self, block):
if not isinstance(block, HybridBlock):
@@ -462,14 +463,12 @@ def forward(self, x, *args):
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
if self._active:
return self._call_cached_op(x, *args)
try:
if self._active:
return self._call_cached_op(x, *args)
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
except DeferredInitializationError:
self.infer_shape(x, *args)
for i in self.collect_params().values():
i._finish_deferred_init()
self._finish_deferred_init(x, *args)
params = {i: j.data(ctx) for i, j in self._reg_params.items()}
return self.hybrid_forward(ndarray, x, *args, **params)

@@ -559,7 +558,11 @@ def __init__(self, outputs, inputs, params=None):
def forward(self, x, *args):
if isinstance(x, NDArray):
with x.context:
return self._call_cached_op(x, *args)
try:
return self._call_cached_op(x, *args)
except DeferredInitializationError:
self._finish_deferred_init(x, *args)
return self._call_cached_op(x, *args)

assert isinstance(x, Symbol), \
"HybridBlock requires the first argument to forward be either " \
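For orientation, a minimal sketch (not part of the patch, assuming the Gluon API of this MXNet version) of the behaviour the block.py change targets: a hybridized block whose parameter shapes are still deferred can be called directly, and the first forward pass infers the missing shapes and finishes the deferred initialization before the cached op runs.

```python
import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(10)       # in_units left unspecified, so the weight shape is deferred
net.hybridize()          # forward goes through the cached-op path edited above
net.initialize()         # initialization stays deferred until shapes are known

x = mx.nd.ones((2, 7))
y = net(x)               # first call infers shapes and finishes deferred init
print(net.weight.shape)  # expected (10, 7): (units, inferred in_units)
```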
6 changes: 4 additions & 2 deletions python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
@@ -131,8 +131,10 @@ def __repr__(self):
s += ', {_conv_layout}'
s += ')'
attrs = self.__dict__
mapping = ('{_in_channels} -> {_hidden_channels}'.format(**attrs) if self._in_channels
else self._hidden_channels)
shape = self.i2h_weight.shape
in_channels = shape[1 if self._channel_axis == 1 else -1]
mapping = ('{0} -> {1}'.format(in_channels, shape[0]) if in_channels
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**attrs)
10 changes: 6 additions & 4 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -207,10 +207,11 @@ def hybrid_forward(self, F, x, weight, bias=None):

def __repr__(self):
s = '{name}({layout}, {act})'
shape = self.weight.shape
return s.format(name=self.__class__.__name__,
act=self.act if self.act else 'linear',
layout='{0} -> {1}'.format(self._in_units, self._units) if self._in_units
else self._units)
layout='{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])


class Activation(HybridBlock):
@@ -360,8 +361,9 @@ def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):

def __repr__(self):
s = '{name}({content}'
if hasattr(self, 'in_channels'):
s += ', in_channels={0}'.format(self.in_channels)
in_channels = self.gamma.shape[0]
if in_channels:
s += ', in_channels={0}'.format(in_channels)
s += ')'
return s.format(name=self.__class__.__name__,
content=', '.join(['='.join([k, v.__repr__()])
6 changes: 3 additions & 3 deletions python/mxnet/gluon/nn/conv_layers.py
@@ -153,10 +153,10 @@ def __repr__(self):
if self.bias is None:
s += ', bias=False'
s += ')'
shape = self.weight.shape
return s.format(name=self.__class__.__name__,
mapping=self._channels if not self._in_channels
else '{0} -> {1}'.format(self._in_channels,
self._channels),
mapping=shape[0] if not shape[1]
else '{0} -> {1}'.format(shape[1], shape[0]),
**self._kwargs)


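The __repr__ edits in conv_rnn_cell.py, basic_layers.py, and conv_layers.py above (and the ones added to the RNN cells further down) all follow the same pattern: read the input size from the weight's current shape rather than a constructor attribute, so the mapping picks up shapes filled in by deferred initialization. An illustrative sketch of the intended effect (assuming the patch is applied; the exact repr text may differ):

```python
import mxnet as mx
from mxnet.gluon import nn

conv = nn.Conv2D(64, kernel_size=2, padding=1)  # in_channels not specified
conv.initialize()
print(conv)   # weight.shape[1] is still 0, so only the output channel count is shown

conv(mx.nd.ones((2, 3, 5, 7)))  # first forward fills the deferred shape
print(conv)   # mapping is now read from conv.weight.shape, e.g. "3 -> 64"
```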
7 changes: 5 additions & 2 deletions python/mxnet/gluon/parameter.py
@@ -171,11 +171,12 @@ def _check_and_get(self, arr_list, ctx):
def _load_init(self, data, ctx):
"""(Re)initializes by loading from data."""
if self.shape:
for i, j in zip(self.shape, data.shape):
assert i == 0 or i == j, \
for self_dim, data_dim in zip(self.shape, data.shape):
assert self_dim == 0 or self_dim == data_dim, \
"Failed loading Parameter %s from saved params: " \
"shape incompatible expacted %s vs saved %s"%(
self.name, str(self.shape), str(data.shape))
self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
if self.dtype:
assert np.dtype(self.dtype).type == data.dtype, \
"Failed loading Parameter %s from saved params: " \
@@ -344,6 +345,8 @@ def set_data(self, data):
"Parameter %s has not been initialized"%self.name
for arr in self.list_data():
arr[:] = data
if not self.shape or np.prod(self.shape) <= 0:
self.shape = data.shape

def data(self, ctx=None):
"""Returns a copy of this parameter on one context. Must have been
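The parameter.py change relaxes shape checking when loading and setting data: a dimension recorded as 0 means "unknown" and is filled in from the saved (or assigned) array instead of being rejected. A one-line restatement of the filling rule as a sketch, using hypothetical shapes:

```python
declared = (0, 7)    # parameter declared with an unknown first dimension
saved = (128, 7)     # shape of the array found in the saved params file
filled = tuple(d if d != 0 else s for d, s in zip(declared, saved))
print(filled)        # (128, 7)
```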
47 changes: 36 additions & 11 deletions python/mxnet/gluon/rnn/rnn_cell.py
@@ -111,17 +111,6 @@ def __init__(self, prefix=None, params=None):
self._modified = False
self.reset()

def __repr__(self):
s = '{name}({mapping}'
if hasattr(self, '_activation'):
s += ', {_activation}'
s += ')'
mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
else self._hidden_size)
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def reset(self):
"""Reset before re-using the cell for another graph."""
self._init_counter = -1
@@ -355,6 +344,18 @@ def state_info(self, batch_size=0):
def _alias(self):
return 'rnn'

def __repr__(self):
s = '{name}({mapping}'
if hasattr(self, '_activation'):
s += ', {_activation}'
s += ')'
shape = self.i2h_weight.shape
mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
prefix = 't%d_'%self._counter
@@ -453,6 +454,18 @@ def state_info(self, batch_size=0):
def _alias(self):
return 'lstm'

def __repr__(self):
s = '{name}({mapping}'
if hasattr(self, '_activation'):
s += ', {_activation}'
s += ')'
shape = self.i2h_weight.shape
mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
prefix = 't%d_'%self._counter
@@ -551,6 +564,18 @@ def state_info(self, batch_size=0):
def _alias(self):
return 'gru'

def __repr__(self):
s = '{name}({mapping}'
if hasattr(self, '_activation'):
s += ', {_activation}'
s += ')'
shape = self.i2h_weight.shape
mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)

def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
# pylint: disable=too-many-locals
5 changes: 3 additions & 2 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -89,8 +89,9 @@ def __repr__(self):
if self._dir == 2:
s += ', bidirectional'
s += ')'
mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
else self._hidden_size)
shape = self.i2h_weight[0].shape
mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
Contributor review comment on this line: It should show None -> xx when shape is not defined yet

else shape[0])
return s.format(name=self.__class__.__name__,
mapping=mapping,
**self.__dict__)
36 changes: 36 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -553,8 +553,44 @@ def test_lambda():
assert_almost_equal(out1.asnumpy(), out3.asnumpy())


def test_fill_shape_deferred():
net = nn.HybridSequential()
with net.name_scope():
net.add(nn.Conv2D(64, kernel_size=2, padding=1),
nn.BatchNorm(),
nn.Dense(10))
net.hybridize()
net.initialize()
net(mx.nd.ones((2,3,5,7)))
assert net[0].weight.shape[1] == 3, net[0].weight.shape[1]
assert net[1].gamma.shape[0] == 64, net[1].gamma.shape[0]
assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]


def test_fill_shape_load():
ctx = mx.context.current_context()
net1 = nn.HybridSequential()
with net1.name_scope():
net1.add(nn.Conv2D(64, kernel_size=2, padding=1),
nn.BatchNorm(),
nn.Dense(10))
net1.hybridize()
net1.initialize(ctx=ctx)
net1(mx.nd.ones((2,3,5,7), ctx))
net1.save_params('net_fill.params')

net2 = nn.HybridSequential()
with net2.name_scope():
net2.add(nn.Conv2D(64, kernel_size=2, padding=1),
nn.BatchNorm(),
nn.Dense(10))
net2.hybridize()
net2.initialize()
net2.load_params('net_fill.params', ctx)
assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1]
assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0]
assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]


if __name__ == '__main__':
import nose
7 changes: 7 additions & 0 deletions tests/python/unittest/test_gluon_contrib.py
@@ -94,6 +94,13 @@ def test_convgru():
check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


def test_conv_fill_shape():
cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
cell.hybridize()
check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]


def test_vardrop():
def check_vardrop(drop_inputs, drop_states, drop_outputs):
cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
13 changes: 13 additions & 0 deletions tests/python/unittest/test_gluon_rnn.py
@@ -274,6 +274,19 @@ def test_rnn_layers():
with mx.autograd.record():
net(mx.nd.ones((2, 3, 10))).backward()

def test_cell_fill_shape():
cell = gluon.rnn.LSTMCell(10)
cell.hybridize()
check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]

def test_layer_fill_shape():
layer = gluon.rnn.LSTM(10)
layer.hybridize()
check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
print(layer)
assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1]


if __name__ == '__main__':
import nose