-
Notifications
You must be signed in to change notification settings - Fork 20
/
train.py
132 lines (93 loc) · 5.51 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
'''Train script.

Builds the train/test datasets and data loaders from the runtime arguments,
trains the ResNet50 model (which returns a superclass and a subclass
prediction per image) with the sum of the two losses provided by
`HierarchicalLossNetwork`, evaluates on the test set after every training
pass, plots the loss/accuracy history into `args.graphs_folder`, and finally
saves the model's state dict as `dhc.pth`.
'''
import os
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchsummary import summary
from torchvision import transforms
from level_dict import hierarchy
from runtime_args import args
from load_dataset import LoadDataset
from model import resnet50
from model.hierarchical_loss import HierarchicalLossNetwork
from helper import calculate_accuracy
from plot import plot_loss_acc
# Use the first CUDA device only when the user asked for the GPU *and* CUDA is
# actually available; otherwise fall back to the CPU.
use_cuda = args.device == 'gpu' and torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else 'cpu')

# Folder that receives the loss/accuracy graphs. `exist_ok=True` avoids the
# check-then-create race of a separate `os.path.exists` test.
os.makedirs(args.graphs_folder, exist_ok=True)
# Augmentation applied to training images only. `RandomAffine` is left on its
# default (nearest-neighbour) resampling: the former `resample=0` keyword
# requested exactly that default and is no longer accepted by newer
# torchvision releases.
train_transform = transforms.Compose([
    transforms.RandomAffine(40, scale=(.85, 1.15), shear=0),
    transforms.RandomHorizontalFlip(),
    transforms.RandomPerspective(distortion_scale=0.2),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.ToTensor(),
])

train_dataset = LoadDataset(
    image_size=args.img_size,
    image_depth=args.img_depth,
    csv_path=args.train_csv,
    cifar_metafile=args.metafile,
    transform=train_transform,
)

# Test images are only converted to tensors; no augmentation is applied.
test_dataset = LoadDataset(
    image_size=args.img_size,
    image_depth=args.img_depth,
    csv_path=args.test_csv,
    cifar_metafile=args.metafile,
    transform=transforms.ToTensor(),
)

# NOTE(review): both loaders take their `shuffle` value straight from
# `args.no_shuffle` -- presumably a store_false style flag; confirm in
# runtime_args. The test loader deliberately keeps using the same flag here.
train_generator = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=args.no_shuffle,
    num_workers=args.num_workers,
)
test_generator = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    shuffle=args.no_shuffle,
    num_workers=args.num_workers,
)
# Move the model to the target device *before* building the optimizer, so the
# optimizer is constructed from the parameters as they live on that device.
model = resnet50.ResNet50()
model = model.to(device)
optimizer = Adam(model.parameters(), lr=args.learning_rate)

# Provides the two loss terms (`calculate_lloss` / `calculate_dloss`) that are
# summed during training and evaluation.
HLN = HierarchicalLossNetwork(metafile_path=args.metafile, hierarchical_labels=hierarchy, device=device)

# Per-epoch history; one value is appended to each list per epoch and the
# lists are handed to the plotting helper.
train_epoch_loss = []
train_epoch_superclass_accuracy = []
train_epoch_subclass_accuracy = []

test_epoch_loss = []
test_epoch_superclass_accuracy = []
test_epoch_subclass_accuracy = []
def _mean(values):
    """Return the mean of *values*; an empty list yields 0 instead of raising."""
    return sum(values) / max(len(values), 1)


def _run_epoch(loader, training):
    """Run one full pass over *loader* and return its averaged metrics.

    Args:
        loader: iterable of batches; each batch is indexed with the keys
            'image', 'label_1' (superclass label) and 'label_2' (subclass
            label).
        training: when True the model is put in train mode, gradients are
            tracked and the optimizer is stepped after every batch; when
            False the model is put in eval mode and gradient tracking is
            disabled.

    Returns:
        A tuple ``(loss, superclass_accuracy, subclass_accuracy)`` where each
        entry is the plain (unweighted) mean of the per-batch values.
    """
    batch_losses = []
    batch_superclass_accuracy = []
    batch_subclass_accuracy = []

    # `train(False)` is the same as `eval()`.
    model.train(training)

    with torch.set_grad_enabled(training):
        # Wrap the loader itself (not `enumerate(loader)`) so tqdm can read
        # its length and display a total / ETA.
        for sample in tqdm(loader):
            batch_x = sample['image'].to(device)
            batch_y1 = sample['label_1'].to(device)
            batch_y2 = sample['label_2'].to(device)
            labels = [batch_y1, batch_y2]

            if training:
                optimizer.zero_grad()

            superclass_pred, subclass_pred = model(batch_x)
            prediction = [superclass_pred, subclass_pred]

            dloss = HLN.calculate_dloss(prediction, labels)
            lloss = HLN.calculate_lloss(prediction, labels)
            total_loss = lloss + dloss

            if training:
                total_loss.backward()
                optimizer.step()

            batch_losses.append(total_loss.item())
            # Detach so the accuracy computation never holds on to the graph.
            batch_superclass_accuracy.append(
                calculate_accuracy(predictions=superclass_pred.detach(), labels=batch_y1))
            batch_subclass_accuracy.append(
                calculate_accuracy(predictions=subclass_pred.detach(), labels=batch_y2))

    return _mean(batch_losses), _mean(batch_superclass_accuracy), _mean(batch_subclass_accuracy)


for epoch_idx in range(args.epoch):

    # ---- training pass -----------------------------------------------------
    train_loss, train_superclass_acc, train_subclass_acc = _run_epoch(train_generator, training=True)

    train_epoch_loss.append(train_loss)
    train_epoch_superclass_accuracy.append(train_superclass_acc)
    train_epoch_subclass_accuracy.append(train_subclass_acc)

    print(f'Training Loss at epoch {epoch_idx} : {train_loss}')
    print(f'Training Superclass accuracy at epoch {epoch_idx} : {train_superclass_acc}')
    print(f'Training Subclass accuracy at epoch {epoch_idx} : {train_subclass_acc}')

    # ---- evaluation pass ---------------------------------------------------
    test_loss, test_superclass_acc, test_subclass_acc = _run_epoch(test_generator, training=False)

    test_epoch_loss.append(test_loss)
    test_epoch_superclass_accuracy.append(test_superclass_acc)
    test_epoch_subclass_accuracy.append(test_subclass_acc)

    # Plot accuracy and loss graphs with the history collected so far.
    plot_loss_acc(path=args.graphs_folder, num_epoch=epoch_idx,
                  train_accuracies_superclass=train_epoch_superclass_accuracy,
                  train_accuracies_subclass=train_epoch_subclass_accuracy,
                  train_losses=train_epoch_loss,
                  test_accuracies_superclass=test_epoch_superclass_accuracy,
                  test_accuracies_subclass=test_epoch_subclass_accuracy,
                  test_losses=test_epoch_loss)

    print(f'Testing Loss at epoch {epoch_idx} : {test_loss}')
    print(f'Testing Superclass accuracy at epoch {epoch_idx} : {test_superclass_acc}')
    print(f'Testing Subclass accuracy at epoch {epoch_idx} : {test_subclass_acc}')
    print('-------------------------------------------------------------------------------------------')
# Save the trained weights as `dhc.pth` inside `args.model_save_path`.
# The path is joined with a proper separator: the previous
# `rstrip('/') + 'dhc.pth'` removed the separator and glued the file name onto
# the directory name (e.g. 'weights/' -> 'weightsdhc.pth').
# NOTE(review): any script that loads the weights using the old expression
# must be pointed at the new location.
# NOTE(review): indentation was lost in the reviewed copy; this is assumed to
# run once, after the last epoch -- confirm.
model_save_file = os.path.join(args.model_save_path, 'dhc.pth')
model_save_dir = os.path.dirname(model_save_file)
if model_save_dir:
    # Create the target folder if needed, as is done for the graphs folder.
    os.makedirs(model_save_dir, exist_ok=True)

torch.save(model.state_dict(), model_save_file)
print("Model saved!")