
Commit 0182896: fill shape after set_data

szha committed Oct 22, 2017
1 parent 062893f
Showing 9 changed files with 113 additions and 24 deletions.
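
Editorial note (not part of the commit): the user-visible effect of these changes is that parameter dimensions declared as 0 ("unknown") are filled in once real data is set or loaded, so blocks created without input sizes can report their inferred shapes afterwards. A minimal sketch, assuming MXNet with Gluon is installed:

import mxnet as mx
from mxnet.gluon import nn

net = nn.Dense(10)           # in_units not given; weight shape starts as (10, 0)
net.initialize()
net(mx.nd.ones((2, 3)))      # first forward pass infers in_units = 3
print(net.weight.shape)      # (10, 3) -- the unknown dimension is now filled in
print(net)                   # repr can now report the mapping, e.g. Dense(3 -> 10, linear)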
6 changes: 4 additions & 2 deletions python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py
@@ -131,8 +131,10 @@ def __repr__(self):
         s += ', {_conv_layout}'
         s += ')'
         attrs = self.__dict__
-        mapping = ('{_in_channels} -> {_hidden_channels}'.format(**attrs) if self._in_channels
-                   else self._hidden_channels)
+        shape = self.i2h_weight.shape
+        in_channels = shape[1 if self._channel_axis == 1 else -1]
+        mapping = ('{0} -> {1}'.format(in_channels, shape[0]) if in_channels
+                   else shape[0])
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **attrs)
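
Editorial note: the channel lookup above depends on the cell's data layout; a hypothetical standalone helper (not in the diff) mirroring that rule:

def conv_in_channels(i2h_weight_shape, channel_axis):
    # Channel-first layouts keep input channels at axis 1 of the i2h conv
    # weight; channel-last layouts keep them at the last axis.
    return i2h_weight_shape[1 if channel_axis == 1 else -1]

print(conv_in_channels((40, 3, 2, 2), 1))    # 3, e.g. an NCHW-style weight
print(conv_in_channels((40, 2, 2, 5), -1))   # 5, e.g. an NHWC-style weight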
10 changes: 6 additions & 4 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -205,10 +205,11 @@ def hybrid_forward(self, F, x, weight, bias=None):

     def __repr__(self):
         s = '{name}({layout}, {act})'
+        shape = self.weight.shape
         return s.format(name=self.__class__.__name__,
                         act=self.act if self.act else 'linear',
-                        layout='{0} -> {1}'.format(self._in_units, self._units) if self._in_units
-                               else self._units)
+                        layout='{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                               else shape[0])


 class Activation(HybridBlock):

@@ -358,8 +359,9 @@ def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):

     def __repr__(self):
         s = '{name}({content}'
-        if hasattr(self, 'in_channels'):
-            s += ', in_channels={0}'.format(self.in_channels)
+        in_channels = self.gamma.shape[0]
+        if in_channels:
+            s += ', in_channels={0}'.format(in_channels)
         s += ')'
         return s.format(name=self.__class__.__name__,
                         content=', '.join(['='.join([k, v.__repr__()])
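
Editorial note: a minimal sketch of the BatchNorm case (assuming MXNet is installed), where in_channels is omitted at construction and reported from gamma's shape after the first forward pass:

import mxnet as mx
from mxnet.gluon import nn

bn = nn.BatchNorm()              # in_channels not given
bn.initialize()
bn(mx.nd.ones((2, 64, 5, 5)))    # first forward pass infers 64 channels
print(bn.gamma.shape)            # (64,)
print(bn)                        # repr now includes in_channels=64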
6 changes: 3 additions & 3 deletions python/mxnet/gluon/nn/conv_layers.py
@@ -153,10 +153,10 @@ def __repr__(self):
         if self.bias is None:
             s += ', bias=False'
         s += ')'
+        shape = self.weight.shape
         return s.format(name=self.__class__.__name__,
-                        mapping=self._channels if not self._in_channels
-                                else '{0} -> {1}'.format(self._in_channels,
-                                                          self._channels),
+                        mapping=shape[0] if not shape[1]
+                                else '{0} -> {1}'.format(shape[1], shape[0]),
                         **self._kwargs)


7 changes: 5 additions & 2 deletions python/mxnet/gluon/parameter.py
@@ -171,11 +171,12 @@ def _check_and_get(self, arr_list, ctx):
     def _load_init(self, data, ctx):
         """(Re)initializes by loading from data."""
         if self.shape:
-            for i, j in zip(self.shape, data.shape):
-                assert i == 0 or i == j, \
+            for i, (self_dim, data_dim) in enumerate(zip(self.shape, data.shape)):
+                assert self_dim == 0 or self_dim == data_dim, \
                     "Failed loading Parameter %s from saved params: " \
                     "shape incompatible expected %s vs saved %s"%(
                         self.name, str(self.shape), str(data.shape))
+            self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
         if self.dtype:
             assert np.dtype(self.dtype).type == data.dtype, \
                 "Failed loading Parameter %s from saved params: " \

@@ -344,6 +345,8 @@ def set_data(self, data):
             "Parameter %s has not been initialized"%self.name
         for arr in self.list_data():
             arr[:] = data
+        if not self.shape or np.prod(self.shape) <= 0:
+            self.shape = data.shape

     def data(self, ctx=None):
         """Returns a copy of this parameter on one context. Must have been
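
Editorial note: the rule applied in _load_init and set_data above, restated as a small standalone sketch (a hypothetical helper, not part of the MXNet API): dimensions declared as 0 mean "unknown" and are filled from the incoming data, while every other dimension must match exactly.

def merge_shape(declared, actual):
    # Non-zero declared dimensions must match; zeros are filled from the data.
    for d, a in zip(declared, actual):
        assert d == 0 or d == a, \
            "shape incompatible: declared %s vs actual %s" % (declared, actual)
    return tuple(a if d == 0 else d for d, a in zip(declared, actual))

print(merge_shape((0, 5), (3, 5)))   # (3, 5)
print(merge_shape((2, 5), (2, 5)))   # (2, 5)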
47 changes: 36 additions & 11 deletions python/mxnet/gluon/rnn/rnn_cell.py
@@ -111,17 +111,6 @@ def __init__(self, prefix=None, params=None):
         self._modified = False
         self.reset()

-    def __repr__(self):
-        s = '{name}({mapping}'
-        if hasattr(self, '_activation'):
-            s += ', {_activation}'
-        s += ')'
-        mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
-                   else self._hidden_size)
-        return s.format(name=self.__class__.__name__,
-                        mapping=mapping,
-                        **self.__dict__)
-
     def reset(self):
         """Reset before re-using the cell for another graph."""
         self._init_counter = -1

@@ -355,6 +344,18 @@ def state_info(self, batch_size=0):
     def _alias(self):
         return 'rnn'

+    def __repr__(self):
+        s = '{name}({mapping}'
+        if hasattr(self, '_activation'):
+            s += ', {_activation}'
+        s += ')'
+        shape = self.i2h_weight.shape
+        mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                   else shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         prefix = 't%d_'%self._counter

@@ -453,6 +454,18 @@ def state_info(self, batch_size=0):
     def _alias(self):
         return 'lstm'

+    def __repr__(self):
+        s = '{name}({mapping}'
+        if hasattr(self, '_activation'):
+            s += ', {_activation}'
+        s += ')'
+        shape = self.i2h_weight.shape
+        mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                   else shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         prefix = 't%d_'%self._counter

@@ -551,6 +564,18 @@ def state_info(self, batch_size=0):
     def _alias(self):
         return 'gru'

+    def __repr__(self):
+        s = '{name}({mapping}'
+        if hasattr(self, '_activation'):
+            s += ', {_activation}'
+        s += ')'
+        shape = self.i2h_weight.shape
+        mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                   else shape[0])
+        return s.format(name=self.__class__.__name__,
+                        mapping=mapping,
+                        **self.__dict__)
+
     def hybrid_forward(self, F, inputs, states, i2h_weight,
                        h2h_weight, i2h_bias, h2h_bias):
         # pylint: disable=too-many-locals
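
Editorial note: __repr__ moves from the abstract base cell to the concrete cells because only those own an i2h_weight whose shape reflects the (possibly inferred) input size. A minimal sketch of the resulting behavior, assuming MXNet is installed:

import mxnet as mx
from mxnet.gluon import rnn

cell = rnn.LSTMCell(10)          # input_size not given
cell.initialize()
out, states = cell(mx.nd.ones((2, 7)), cell.begin_state(batch_size=2))
print(cell.i2h_weight.shape)     # (40, 7): 4 gates * 10 hidden units, 7 inferred inputs
print(cell)                      # the mapping is reported from i2h_weight's shape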
5 changes: 3 additions & 2 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -89,8 +89,9 @@ def __repr__(self):
         if self._dir == 2:
             s += ', bidirectional'
         s += ')'
-        mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size
-                   else self._hidden_size)
+        shape = self.i2h_weight[0].shape
+        mapping = ('{0} -> {1}'.format(shape[1], shape[0]) if shape[1]
+                   else shape[0])
         return s.format(name=self.__class__.__name__,
                         mapping=mapping,
                         **self.__dict__)
36 changes: 36 additions & 0 deletions tests/python/unittest/test_gluon.py
@@ -531,6 +531,42 @@ def test_hybrid_stale_cache():
     net.initialize()
     assert net(mx.nd.ones((2,3,5))).shape == (2, 10)

+def test_fill_shape_deferred():
+    net = nn.HybridSequential()
+    with net.name_scope():
+        net.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                nn.BatchNorm(),
+                nn.Dense(10))
+    net.hybridize()
+    net.initialize()
+    net(mx.nd.ones((2,3,5,7)))
+    assert net[0].weight.shape[1] == 3, net[0].weight.shape[1]
+    assert net[1].gamma.shape[0] == 64, net[1].gamma.shape[0]
+    assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1]
+
+def test_fill_shape_load():
+    net1 = nn.HybridSequential()
+    with net1.name_scope():
+        net1.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                 nn.BatchNorm(),
+                 nn.Dense(10))
+    net1.hybridize()
+    net1.initialize()
+    net1(mx.nd.ones((2,3,5,7)))
+    net1.save_params('net_fill.params')
+
+    net2 = nn.HybridSequential()
+    with net2.name_scope():
+        net2.add(nn.Conv2D(64, kernel_size=2, padding=1),
+                 nn.BatchNorm(),
+                 nn.Dense(10))
+    net2.hybridize()
+    net2.initialize()
+    net2.load_params('net_fill.params', mx.cpu())
+    assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1]
+    assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0]
+    assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1]
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
7 changes: 7 additions & 0 deletions tests/python/unittest/test_gluon_contrib.py
@@ -94,6 +94,13 @@ def test_convgru():
     check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48))


+def test_conv_fill_shape():
+    cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,))
+    cell.hybridize()
+    check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7)))
+    assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1]
+
+
 def test_vardrop():
     def check_vardrop(drop_inputs, drop_states, drop_outputs):
         cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'),
13 changes: 13 additions & 0 deletions tests/python/unittest/test_gluon_rnn.py
@@ -274,6 +274,19 @@ def test_rnn_layers():
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()

+def test_cell_fill_shape():
+    cell = gluon.rnn.LSTMCell(10)
+    cell.hybridize()
+    check_rnn_forward(cell, mx.nd.ones((2, 3, 7)))
+    assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1]
+
+def test_layer_fill_shape():
+    layer = gluon.rnn.LSTM(10)
+    layer.hybridize()
+    check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
+    print(layer)
+    assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1]
+

 if __name__ == '__main__':
     import nose
