
support shared parameter in summary (#11379)
szha authored Jun 29, 2018
1 parent 2594fca commit 36f2aae
Showing 2 changed files with 12 additions and 2 deletions.
python/mxnet/gluon/block.py (10 additions, 0 deletions)
@@ -563,6 +563,7 @@ def summary(self, *inputs):
             :class:`mxnet.ndarray.NDArray` is supported.
         """
         summary = OrderedDict()
+        seen = set()
         hooks = []
 
         def _get_shape_str(args):
@@ -611,9 +612,14 @@ def _summary_hook(block, _, outputs):
 
             params = 0
             summary[m_key]['trainable'] = 0
+            summary[m_key]['shared'] = 0
             for p in block._reg_params.values():
                 params += p.data().size
                 summary[m_key]['trainable'] += 0 if p.grad_req == 'null' else p.data().size
+                if p in seen:
+                    summary[m_key]['shared'] += p.data().size
+                else:
+                    seen.add(p)
             summary[m_key]['n_params'] = params
 
         from .nn.basic_layers import Sequential, HybridSequential
@@ -624,6 +630,7 @@ def _summary_hook(block, _, outputs):
         summary['Input']['output_shape'] = _get_shape_str(inputs)
         summary['Input']['n_params'] = 0
         summary['Input']['trainable'] = 0
+        summary['Input']['shared'] = 0
 
         try:
             self.apply(_register_summary_hook)
@@ -635,16 +642,19 @@ def _summary_hook(block, _, outputs):
             print('='*80)
             total_params = 0
             trainable_params = 0
+            shared_params = 0
             for layer in summary:
                 print(line_format.format(layer,
                                          str(summary[layer]['output_shape']),
                                          summary[layer]['n_params']))
                 total_params += summary[layer]['n_params']
                 trainable_params += summary[layer]['trainable']
+                shared_params += summary[layer]['shared']
             print('='*80)
             print('Total params: ' + str(total_params))
             print('Trainable params: ' + str(trainable_params))
             print('Non-trainable params: ' + str(total_params - trainable_params))
+            print('Shared params: ' + str(shared_params))
             print('-'*80)
         finally:
             for h in hooks:
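
Note on the counting logic: in the hook, `params += p.data().size` still counts a parameter once per block that registers it, so a weight reused by two layers contributes twice to "Total params"; the new `seen` set tallies that second (and later) occurrence under "Shared params". A minimal standalone sketch of the same first-seen/seen-again pattern, in plain Python rather than MXNet code (`FakeParam` and `count_params` are hypothetical names used only for illustration):

# Illustrative sketch only, not MXNet code.
class FakeParam:
    """Stands in for a gluon Parameter; only carries an element count."""
    def __init__(self, size):
        self.size = size

def count_params(layers):
    """layers: list of per-layer parameter lists (the same object may repeat)."""
    seen = set()
    total = shared = 0
    for layer_params in layers:
        for p in layer_params:
            total += p.size       # counted once per layer that registers it
            if p in seen:         # same object already seen on an earlier layer
                shared += p.size
            else:
                seen.add(p)
    return total, shared

w = FakeParam(1200)               # one weight reused by two layers
b = FakeParam(40)
print(count_params([[w], [w, b]]))  # -> (2440, 1200)
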
tests/python/unittest/test_gluon.py (2 additions, 2 deletions)
@@ -1264,9 +1264,9 @@ def test_summary():
 
     net2 = nn.Sequential()
     with net2.name_scope():
-        net2.add(nn.Embedding(10, 20))
+        net2.add(nn.Embedding(40, 30))
         net2.add(gluon.rnn.LSTM(30))
-        net2.add(nn.Dense(40, flatten=False))
+        net2.add(nn.Dense(40, flatten=False, params=net2[0].params))
     net2.initialize()
     net2.summary(mx.nd.ones((80, 32)))
 
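
The updated test exercises the new counter by letting the final Dense layer reuse the embedding's parameters. A usage sketch along the same lines (MXNet 1.x Gluon API assumed; the exact counts printed depend on the layers chosen):

# Usage sketch mirroring the updated test (MXNet 1.x Gluon API assumed).
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn

net = nn.Sequential()
with net.name_scope():
    net.add(nn.Embedding(40, 30))          # weight shape (40, 30)
    net.add(gluon.rnn.LSTM(30))
    # Dense(40, flatten=False) on a 30-dim input also has a (40, 30) weight,
    # so passing params=net[0].params lets it reuse the embedding's weight.
    net.add(nn.Dense(40, flatten=False, params=net[0].params))
net.initialize()

# The summary table now ends with a "Shared params" line reporting how much
# of the total comes from the reused weight.
net.summary(mx.nd.ones((80, 32)))
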
