-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path main_1gpu.py
110 lines (101 loc) · 4.46 KB
/
main_1gpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import time
import torch
import random
import torch.nn as nn
import torch.optim as optim
from dataloader import DataLoader
from arguments import ArgumentParser
from modelfactory import ModelFactory
from utils import save_signature, test, fill_net, reset_lin_comb
# --- Setup: configuration, model, data and flat-parameter bookkeeping --------
max_acc = 0
args = ArgumentParser()
CrossEntropy = nn.CrossEntropyLoss()
train_net, test_net = ModelFactory(args)
trainloader, testloader = DataLoader(args)

# Coefficients that are actually trained (one per basis network).
alpha = torch.zeros(args.num_alpha, requires_grad=True, device="cuda:0")

# Materialise the parameter list once; every per-layer table below is built
# from it, so all of them share the same (module registration) order.
_net_params = list(train_net.parameters())

# theta: all parameters of train_net concatenated into one flat vector.
with torch.no_grad():
    theta = torch.cat([p.flatten() for p in _net_params])

# In this script net_optimizer is only used to zero train_net's gradients;
# its learning rate is never applied.
net_optimizer = optim.SGD(train_net.parameters(), lr=1.)

# Allocate directly on the current CUDA device (the device `.cuda()` would
# pick) instead of creating on the CPU and copying over.
lin_comb_net = torch.zeros(theta.shape, device="cuda")
layer_cnt = len(_net_params)
shapes = [list(p.shape) for p in _net_params]
lengths = [p.numel() for p in _net_params]
perm = list(range(args.num_alpha))
# One row per coefficient in the current window, each row a flat network.
basis_net = torch.zeros(args.window, theta.shape[0], device="cuda")
dummy_net = [torch.zeros(p.shape, device="cuda") for p in _net_params]
# Reusable buffer for the flattened dL/dtheta of each step.
grads = torch.zeros(theta.shape, device='cuda:0')
saving_path = f"{args.save_path}_{args.task}_{args.model}_{args.num_alpha}"
# --- Initialise alpha: resume from a saved signature or start fresh ----------
if args.resume == 'True':
    with torch.no_grad():
        alpha = torch.load(saving_path + '/lr.pt').cuda()
        if 'resnet' in args.model:
            # BatchNorm running statistics are stored as two flat vectors
            # (all BatchNorm2d layers concatenated in module order); slice
            # them back into each layer's buffers.
            means = torch.load(saving_path + '/means.pt')
            bn_vars = torch.load(saving_path + '/vars.pt')  # avoid shadowing builtin vars()
            offset = 0
            for bn in train_net.modules():
                if isinstance(bn, nn.BatchNorm2d):
                    width = bn.running_var.shape[0]
                    bn.running_mean.copy_(means[offset:offset + width])
                    bn.running_var.copy_(bn_vars[offset:offset + width])
                    offset += width
else:
    with torch.no_grad():
        # Fresh start: only the first coefficient is non-zero.
        alpha[0] = 1.

# Build the flat weight vector from the current alpha and record the
# starting accuracy as the baseline the training loop must beat.
lin_comb_net = reset_lin_comb(args, alpha, lin_comb_net, theta, layer_cnt, shapes, dummy_net, basis_net, lengths)
max_acc = test(train_net, test_net, lin_comb_net, testloader, lengths, shapes)
# --- Training -----------------------------------------------------------------
# Evaluate mode runs zero epochs, so only the accuracy measured above is
# reported. (A bare `epochs = 0` was never read, so training still ran.)
num_epochs = 0 if args.evaluate else args.epoch
for e in range(num_epochs):
    # Pick a random window of args.window coefficients to optimise this epoch.
    random.shuffle(perm)
    idx = perm[:args.window]
    # presumably writes the basis networks for `idx` into basis_net — TODO confirm
    fill_net(args, idx, layer_cnt, shapes, dummy_net, basis_net, lengths)
    with torch.no_grad():
        # Contribution of everything outside the window; fixed for the epoch.
        rest_of_net = lin_comb_net - torch.matmul(basis_net.T, alpha[idx]).T
    # Fresh optimizer (and momentum buffers) for every window.
    # NOTE(review): weight_decay/momentum act on the whole `alpha` tensor, so
    # entries outside `idx` also change during the epoch while rest_of_net
    # stays fixed; reset_lin_comb below presumably re-syncs — confirm.
    optimizer = torch.optim.SGD([alpha], lr=args.lr, momentum=.9, weight_decay=1e-4)
    for i, data in enumerate(trainloader):
        t1 = time.time()
        optimizer.zero_grad()
        net_optimizer.zero_grad()
        imgs, labels = data
        imgs = imgs.cuda()
        labels = labels.cuda()
        with torch.no_grad():
            # Current flat weights, computed once per step (not once per layer)
            # and without building an autograd graph through alpha.
            select_subnet = torch.matmul(basis_net.T, alpha[idx]).T
            flat_weights = select_subnet + rest_of_net
            start_ind = 0
            for p, length, shape in zip(train_net.parameters(), lengths, shapes):
                p.copy_(flat_weights[start_ind:start_ind + length].view(shape))
                start_ind += length
        loss = CrossEntropy(train_net(imgs), labels)
        if i % args.log_rate == 0:
            print("Epoch:", e, "\tIteration:", i, "\tLoss:", round(loss.item(), 4), "\tTime:", round((time.time() - t1) * 1000, 2), 'ms')
        loss.backward()
        with torch.no_grad():
            # Flatten dL/dtheta into `grads`; a parameter that received no
            # gradient (grad is None) contributes zeros instead of crashing.
            start_ind = 0
            for p, length in zip(train_net.parameters(), lengths):
                chunk = grads[start_ind:start_ind + length]
                if p.grad is None:
                    chunk.zero_()
                else:
                    chunk.copy_(p.grad.flatten())
                start_ind += length
            # Chain rule through theta = basis_net.T @ alpha[idx] + const:
            # only the windowed coefficients get a non-zero gradient.
            if alpha.grad is None:
                alpha.grad = torch.zeros_like(alpha)
            alpha.grad[idx] = torch.matmul(grads, basis_net.T)
        optimizer.step()
    # Fold the updated window back into the flat weight vector, then rebuild.
    with torch.no_grad():
        lin_comb_net.copy_(rest_of_net + torch.matmul(basis_net.T, alpha[idx]).T)
    lin_comb_net = reset_lin_comb(args, alpha, lin_comb_net, theta, layer_cnt, shapes, dummy_net, basis_net, lengths)
    t1 = time.time()
    acc = test(train_net, test_net, lin_comb_net, testloader, lengths, shapes)
    if max_acc <= acc:
        max_acc = acc
        if 'resnet' in args.model:
            # Persist BatchNorm running stats alongside alpha, flattened in
            # module order (the resume path slices them back the same way).
            means = []
            bn_vars = []
            for bn in train_net.modules():
                if isinstance(bn, nn.BatchNorm2d):
                    means.append(bn.running_mean)
                    bn_vars.append(bn.running_var)
            save_signature(args, alpha, saving_path, torch.cat(means), torch.cat(bn_vars))
        else:
            save_signature(args, alpha, saving_path)
    print("Acc:", round(acc, 4), "\tMax Acc:", round(max_acc, 4), "\tTime:", round(time.time() - t1, 3), 's')
# --- Wrap-up: optionally dump train_net's current weights, report best acc ----
# NOTE(review): this saves whatever train_net holds at exit, which is not
# necessarily the best-accuracy signature saved during training — confirm.
if args.save_model:
    final_state = train_net.state_dict()
    torch.save(final_state, "final_model.pt")
print(max_acc)