-
Notifications
You must be signed in to change notification settings - Fork 16
/
show.py
47 lines (40 loc) · 1.44 KB
/
show.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
#!/usr/bin/env python -u
# -*- coding: utf-8 -*-
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def show_params(nnet, fid=None):
    """Print each parameter's name and size, then the total element count.

    Everything printed to stdout is mirrored to ``fid`` (line for line, in the
    same format) when ``fid`` is not ``None``; ``fid`` is flushed at the end.

    Args:
        nnet: model object exposing ``named_parameters()`` (presumably a
            ``torch.nn.Module``).
        fid: optional writable text stream; ``None`` disables file logging.

    Returns:
        int: total number of scalar elements over all parameters of ``nnet``.
    """
    # print() joins its arguments with a single space; build the same string
    # so the file copy matches stdout and lines up with the 98-wide footer.
    header = " ".join(["=" * 40, "Model Parameters", "=" * 40])
    footer = "=" * 98

    print(header)
    if fid is not None:
        fid.write(header + "\n")

    num_params = 0
    # named_parameters() on the root module already recurses into every
    # submodule, so there is no need to scan named_modules() for the '' entry.
    for name, param in nnet.named_parameters():
        print(name, param.size())
        if fid is not None:
            # Separate name and size with a space, exactly as print() does.
            fid.write(str(name) + " " + str(param.size()) + "\n")
        # numel() is the product of the dimensions (1 for a 0-dim tensor).
        num_params += param.numel()

    summary = "[*] Parameter Size: {}".format(num_params)
    print(summary)
    print(footer)
    if fid is not None:
        fid.write(summary + "\n")
        fid.write(footer + "\n")
        fid.flush()
    return num_params
def show_model(nnet, fid=None):
    """Print the textual structure of ``nnet`` between banner lines.

    The structure is ``str(nnet)``. Everything printed to stdout is mirrored
    to ``fid`` when it is not ``None``; ``fid`` is flushed at the end.

    Args:
        nnet: the model whose ``str()`` representation is logged.
        fid: optional writable text stream; ``None`` disables file logging.
    """
    # print() joins its arguments with a single space; build the same string
    # so the file copy matches stdout and lines up with the 98-wide footer.
    header = " ".join(["=" * 40, "Model Structures", "=" * 40])
    footer = "=" * 98
    # The root module (the '' entry of named_modules()) is nnet itself.
    structure = str(nnet)

    print(header)
    if fid is not None:
        fid.write(header + "\n")
    print(structure)
    if fid is not None:
        # print() appends a newline; do the same so the footer starts on its
        # own line instead of being glued to the last line of the repr.
        fid.write(structure + "\n")
    print(footer)
    if fid is not None:
        fid.write(footer + "\n")
        fid.flush()