-
Notifications
You must be signed in to change notification settings - Fork 25
/
main.py
56 lines (44 loc) · 1.63 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
Learning from Between-class Examples for Deep Sound Recognition.
Yuji Tokozume, Yoshitaka Ushiku, and Tatsuya Harada
"""
import sys
import os
import chainer
import opts
import models
import dataset
from train import Trainer
def main():
    """Script entry point: parse options, select the GPU, and train each split.

    Reads the command-line configuration via ``opts.parse()``, makes the GPU
    ``options.gpu`` the current Chainer CUDA device, then calls ``train`` once
    per entry in ``options.splits``, printing a banner before each one.
    """
    options = opts.parse()

    # Make the chosen GPU current so later device transfers target it.
    device = chainer.cuda.get_device_from_id(options.gpu)
    device.use()

    for split_id in options.splits:
        banner = '+-- Split {} --+'.format(split_id)
        print(banner)
        train(options, split_id)
def _checkpoint_path(opt, split):
    """Return the .npz checkpoint path for *split* inside the ``opt.save`` directory."""
    return os.path.join(opt.save, 'model_split{}.npz'.format(split))


def train(opt, split):
    """Build a network for one split, then evaluate a checkpoint or train it.

    The network class is looked up by name (``opt.netType``) in ``models`` and
    instantiated with ``opt.nClasses``. It is optimized with Nesterov
    accelerated gradient plus a weight-decay hook.

    When ``opt.testOnly`` is set, the weights for this split are loaded from
    ``<opt.save>/model_split<split>.npz``. A single validation pass is then
    reported, and the function returns without training.

    Otherwise the network is trained for ``opt.nEpochs`` epochs, with a
    one-line summary written to stdout after each epoch. The final weights
    are saved to the same path unless ``opt.save`` is the string ``'None'``.

    Args:
        opt: parsed options object. The fields read here are netType,
            nClasses, LR, momentum, weightDecay, testOnly, save and nEpochs.
            The whole object is also passed to ``dataset.setup`` and
            ``Trainer``.
        split: split identifier, passed to ``dataset.setup`` and embedded
            in the checkpoint file name.
    """
    net = getattr(models, opt.netType)(opt.nClasses)
    net.to_gpu()

    solver = chainer.optimizers.NesterovAG(lr=opt.LR, momentum=opt.momentum)
    solver.setup(net)
    solver.add_hook(chainer.optimizer.WeightDecay(opt.weightDecay))

    train_iter, val_iter = dataset.setup(opt, split)
    trainer = Trainer(net, solver, train_iter, val_iter, opt)

    if opt.testOnly:
        # Evaluation-only mode: restore saved weights and report validation top-1.
        chainer.serializers.load_npz(_checkpoint_path(opt, split), trainer.model)
        print('| Val: top1 {:.2f}'.format(trainer.val()))
        return

    summary_fmt = '| Epoch: {}/{} | Train: LR {} Loss {:.3f} top1 {:.2f} | Val: top1 {:.2f}\n'
    last_epoch = opt.nEpochs
    for epoch_idx in range(1, last_epoch + 1):
        epoch_loss, epoch_top1 = trainer.train(epoch_idx)
        val_score = trainer.val()
        # Carriage return + ANSI erase-line on stderr. Presumably this clears
        # an in-place progress line written during trainer.train(); confirm.
        sys.stderr.write('\r\033[K')
        sys.stdout.write(summary_fmt.format(
            epoch_idx, last_epoch, trainer.optimizer.lr,
            epoch_loss, epoch_top1, val_score))
        sys.stdout.flush()

    # The literal string 'None' (not the None object) disables checkpointing.
    if opt.save != 'None':
        chainer.serializers.save_npz(_checkpoint_path(opt, split), net)
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()