fix trainer when the model involves share_parameters (apache#18880)
* fix trainer when using shared_param

* add unittest
ZiyueHuang authored Aug 8, 2020
1 parent cf908fd commit 706c369
Showing 2 changed files with 45 additions and 5 deletions.
10 changes: 5 additions & 5 deletions python/mxnet/gluon/trainer.py
@@ -398,25 +398,25 @@ def _allreduce_grads(self):
             return
         for i, param in enumerate(self._params):
             if param.grad_req != 'null':
-
+                idx = self._param2idx[param._uuid]
                 grad_list = param.list_grad()
                 # sparse gradients, call push and pull separately
                 if grad_list[0].stype != 'default':
-                    self._kvstore.push(i, grad_list, priority=-i)
+                    self._kvstore.push(idx, grad_list, priority=-i)
                     if param._stype == 'default':
                         if self._update_on_kvstore:
                             pull_list = param.list_data()
                         else:
                             pull_list = param.list_grad()
-                        self._kvstore.pull(i, pull_list, priority=-i,
+                        self._kvstore.pull(idx, pull_list, priority=-i,
                                            ignore_sparse=self._distributed)
                 else:
                     # allreduce dense gradients if not update_on_kvstore,
                     # otherwise push dense gradients, pull dense weights
                     if self._update_on_kvstore:
-                        self._kvstore.pushpull(i, grad_list, out=param.list_data(), priority=-i)
+                        self._kvstore.pushpull(idx, grad_list, out=param.list_data(), priority=-i)
                     else:
-                        self._kvstore.pushpull(i, grad_list, priority=-i)
+                        self._kvstore.pushpull(idx, grad_list, priority=-i)

     def update(self, batch_size, ignore_stale_grad=False):
         """Makes one step of parameter update.
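The substance of the fix: _allreduce_grads now looks up the kvstore key through self._param2idx[param._uuid] instead of reusing the loop index i. A plausible reading of the diff is that, when a model shares parameters, the list the trainer iterates over contains each parameter only once, while the kvstore keys were assigned while walking the full, non-deduplicated parameter list, so the loop index and the registered key can drift apart. The standalone sketch below illustrates that mismatch; the names uuid_of, param2idx, and unique are invented for the example and are not Trainer internals.

# Standalone illustration of the index mismatch; not MXNet code.
params = ['dense1.weight', 'dense2.weight', 'dense3.weight']
# dense2 shares dense1's parameter, so both map to the same uuid 'a'.
uuid_of = {'dense1.weight': 'a', 'dense2.weight': 'a', 'dense3.weight': 'b'}

param2idx = {}   # uuid -> key, assigned while walking the full parameter list
unique = []      # deduplicated list, analogous to what the reduce loop iterates
for i, name in enumerate(params):
    uid = uuid_of[name]
    if uid in param2idx:
        continue              # shared copy: keep only the first occurrence
    param2idx[uid] = i
    unique.append(name)

for i, name in enumerate(unique):
    print(i, param2idx[uuid_of[name]])
# prints "0 0" then "1 2": the loop index (1) no longer matches the key (2),
# which is why the fix pulls the key from param2idx instead of using i.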
40 changes: 40 additions & 0 deletions tests/python/unittest/test_gluon_trainer.py
@@ -359,3 +359,43 @@ def test_trainer_allreduce_hybridsequential():
     out = net(mx.nd.ones((1, 1), ctx=ctx))
     out.backward()
     trainer.allreduce_grads()
+
+
+def test_trainer_share_parameters():
+    class Net(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Net, self).__init__(**kwargs)
+            self.dense1 = gluon.nn.Dense(5, in_units=2, use_bias=False)
+            params = self.dense1.collect_params()
+            self.dense2 = gluon.nn.Dense(5, in_units=2,
+                                         use_bias=False).share_parameters(params)
+            self.dense3 = gluon.nn.Dense(5, in_units=5, use_bias=False)
+
+        def forward(self, x):
+            hidden = self.dense1(x) + self.dense2(x)
+            out = self.dense3(hidden)
+            return out
+
+    net = Net()
+    ctxes = [mx.cpu(0), mx.cpu(1)]
+    net.initialize(mx.init.One(), ctx=ctxes)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1})
+    data = mx.nd.array([[1, 1], [1, 1]])
+    xs = gluon.utils.split_and_load(data, ctxes)
+    ys = []
+    with mx.autograd.record():
+        for x in xs:
+            y = net(x)
+            ys.append(y)
+    for y in ys:
+        y.backward()
+    trainer.step(1)
+    params = net.collect_params()
+    shared_params = []
+    for param in params.values():
+        p = param.data(mx.cpu(0)).asnumpy()
+        if p.shape[1] == 2:
+            shared_params.append(p)
+
+    assert((shared_params[0] == shared_params[1]).all())
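The added test can be run on its own, for example with pytest tests/python/unittest/test_gluon_trainer.py::test_trainer_share_parameters, assuming a CPU build of MXNet that includes this change.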
