-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmain.py
executable file
·54 lines (43 loc) · 1.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import os
import argparse
import copy
import opts
import utils
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as modelzoo
import models
import datasets.data_loader as loader
import back
#from tensorboard_logger import Logger
# Command-line parser built by the project-local ``opts`` module. It is created
# at import time (module level); main() calls parser.parse_args() on it.
parser = opts.myargparser()
def main():
    """Parse options, load teacher and student models, and run training.

    Steps, in order:
      1. Parse command-line options with the module-level ``parser``.
      2. Append ``opt.name`` to ``opt.logdir``.
      3. Build one teacher per entry of ``opt.teacher`` through
         ``models.teacherLoader``.
      4. Build the student and move it to the GPU when ``opt.cuda`` is set.
      5. Load the CIFAR-10 train/validation loaders.
      6. Hand everything to ``back.teacherStudent``.

    Side effects:
        Binds the module-global ``opt`` to the parsed options.
    """
    # Publish the parsed options as a module global (module attribute ``opt``).
    global opt
    teachers = []
    opt = parser.parse_args()
    # Per-run log directory: <logdir>/<name>.
    opt.logdir = os.path.join(opt.logdir, opt.name)
    # TensorBoard logging (tensorboard_logger) is currently disabled; a logger
    # would be created from opt.logdir at this point.
    print(opt)

    print('Loading models...')
    for t in opt.teacher:
        # NOTE(review): presumably each entry is a key of models.teacherLoader
        # (e.g. a teacher architecture name); confirm against opts.myargparser.
        teacher = models.teacherLoader[t](opt.cuda)
        teachers.append(teacher)
        # Report only after the loader returned, so the message is accurate.
        print('loaded teacher {}'.format(t))
    print('Done loading from other file')

    student = models.student.GetStudent()
    if opt.cuda:
        student = student.cuda()
    student = models.setup(student, opt)

    dataloader = loader.loadCIFAR10(opt)
    train_loader = dataloader['train_loader']
    val_loader = dataloader['val_loader']

    back.teacherStudent(train_loader, val_loader, teachers, student, opt)
if __name__ == "__main__":
    # Script entry point: call main() only when executed directly, not on import.
    main()