# main_pytorch.py
## Custom Imports
from src.p2_dataload import KaggleAmazonDataset
from src.p_neuro import Net, ResNet50, ResNet101, ResNet152, DenseNet121
from src.p3_neuroRNN import GRU_ResNet50, LSTM_ResNet50, Skip_LSTM_RN50
from src.p_training import train, snapshot
#from src.p2_validation import validate
from src.p_validation import validate
from src.p_model_selection import train_valid_split
from src.p_logger import setup_logs
#from src.p2_prediction import predict, output
from src.p_prediction import predict, output
from src.p_data_augmentation import ColorJitter, PowerPIL
# from src.p_metrics import SmoothF2Loss
from src.p2_loss import ConvolutedLoss
from src.p_sampler import SubsetSampler, balance_weights
## Utilities
import random
import logging
import time
from timeit import default_timer as timer
import os
## Libraries
import numpy as np
import math
## Torch
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from torchsample.transforms import Affine
from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler
############################################################################
####### CONTROL CENTER ############# STAR COMMAND #########################
## Experiment configuration
## Everything in this section is evaluated at import time, including the
## `.cuda()` calls on the model and on the loss.

# Network under training. The constructor argument (17) equals the number of
# entries in the loss-weight vector and in the class table further down.
model = ResNet50(17).cuda()
# Alternative architectures tried so far (uncomment exactly one to switch):
# model = ResNet152(17).cuda()
# model = GRU_ResNet50(17, 128, 2).cuda()
# model = LSTM_ResNet50(17, 128, 2).cuda()
# model = Skip_LSTM_RN50(17, 128, 2).cuda()

# Training schedule
epochs = 16
batch_size = 64

# Run name: a timestamp prefix followed by a free-form experiment tag.
run_name = (
    time.strftime("%Y-%m-%d_%H%M-")
    + "resnet50-L2reg-new-data"
)

## Per-channel normalization with the ImageNet mean/std (used for finetuning)
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# SGD over all model parameters, i.e. the whole model is finetuned.
# Per the original author's note, learning-rate decay is automated inside
# p_training rather than configured here.
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-2,
    momentum=0.9,
    weight_decay=0.0005,
)

# Loss function.
# Alternative: criterion = ConvolutedLoss()
# The weight vector has 17 entries laid out 4 / 4 / 4 / 5 like the class and
# frequency tables below; presumably it is index-aligned with them (rarer
# classes mostly carry the larger weights) -- TODO confirm against the order
# produced by the dataset's label encoder.
criterion = torch.nn.MultiLabelSoftMarginLoss(
    weight=torch.Tensor([
        1, 4, 2, 1,
        1, 3, 3, 3,
        4, 4, 1, 2,
        1, 1, 3, 4, 1,
    ])
).cuda()

# Reference tables kept from the original file.
# classes = [
#     'clear', 'cloudy', 'haze', 'partly_cloudy',
#     'agriculture', 'artisinal_mine', 'bare_ground', 'blooming',
#     'blow_down', 'conventional_mine', 'cultivation', 'habitation',
#     'primary', 'road', 'selective_logging', 'slash_burn', 'water'
# ]
## Frequency
# [28203,  2330, 2695, 7251,
#  12338,   339,  859,  332,
#     98,   100, 4477, 3662,
#  37840,  8076,  340,  209, 7262]

# Directory handed to the logger setup and to the snapshot save/load code.
save_dir = './snapshots'
####### CONTROL CENTER ############# STAR COMMAND #########################
############################################################################
if __name__ == "__main__":
    # Script entry point: train `model` for `epochs` epochs, validating and
    # snapshotting after every epoch, then reload the best snapshot and write
    # predictions for the test set.

    # Initiate timer
    global_timer = timer()

    # Setup logs
    logger = setup_logs(save_dir, run_name)

    # Setting random seeds for reproducibility.
    # (Caveat, some CuDNN algorithms are non-deterministic)
    # NOTE(review): `model` is constructed at module level, i.e. before these
    # seeds are set, so any random initialisation done by its constructor is
    # not covered by them.
    torch.manual_seed(1337)
    torch.cuda.manual_seed(1337)
    np.random.seed(1337)
    random.seed(1337)

    ##############################################################
    ## Loading the dataset

    # torchvision renamed RandomSizedCrop -> RandomResizedCrop and
    # Scale -> Resize; the old names were deprecated and later removed.
    # Prefer the new names and fall back to the old ones so the script runs
    # on either generation of torchvision.
    random_resized_crop_cls = (
        getattr(transforms, "RandomResizedCrop", None)
        or transforms.RandomSizedCrop
    )
    resize_cls = getattr(transforms, "Resize", None) or transforms.Scale

    ## Augmentation + Normalization for full training
    ds_transform_augmented = transforms.Compose([
        random_resized_crop_cls(224),
        PowerPIL(),
        transforms.ToTensor(),
        # ColorJitter(),  # Use PowerPIL instead, with PillowSIMD it's much more efficient
        normalize,
        # Affine(
        #     rotation_range = 15,
        #     translation_range = (0.2,0.2),
        #     shear_range = math.pi/6,
        #     zoom_range=(0.7,1.4)
        # )
    ])

    ## Normalization only for validation and test
    ds_transform_raw = transforms.Compose([
        resize_cls(224),
        transforms.ToTensor(),
        normalize,
    ])

    #### ######### ######## ########### #####
    # Two dataset objects over the same CSV and image folder; they differ only
    # in the transform applied (augmented for training, raw for validation).
    X_train = KaggleAmazonDataset(
        './data/train_v2.csv', './data/train-jpg/', '.jpg',
        ds_transform_augmented,
    )
    X_val = KaggleAmazonDataset(
        './data/train_v2.csv', './data/train-jpg/', '.jpg',
        ds_transform_raw,
    )

    # Resample the dataset
    # weights = balance_weights(X_train.getDF(), 'tags', X_train.getLabelEncoder())
    # weights = np.clip(weights,0.02,0.2) # We need to let the net view the most common classes or learning is too slow

    # Creating a validation split
    # (0.2 is presumably the validation fraction -- confirm in
    # src.p_model_selection.train_valid_split)
    train_idx, valid_idx = train_valid_split(X_train, 0.2)
    # weights[valid_idx] = 0
    # train_sampler = WeightedRandomSampler(weights, len(train_idx))
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetSampler(valid_idx)

    ###### ########## ########## ######## #########
    # Both loaders read the same underlying data; each sampler restricts its
    # loader to its own subset of indices.
    train_loader = DataLoader(
        X_train,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=4,
        pin_memory=True,
    )
    valid_loader = DataLoader(
        X_val,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=4,
        pin_memory=True,
    )

    ###########################################################
    ## Start training
    best_score = 0.
    for epoch in range(epochs):
        epoch_timer = timer()

        # Train and validate
        train(epoch, train_loader, model, criterion, optimizer)
        score, loss, threshold = validate(
            epoch, valid_loader, model, criterion, X_train.getLabelEncoder()
        )

        # Save: the snapshot is flagged as best only when this epoch's
        # validation score strictly beats every previous one.
        is_best = score > best_score
        best_score = max(score, best_score)
        snapshot(save_dir, run_name, is_best, {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': best_score,
            'optimizer': optimizer.state_dict(),
            'threshold': threshold,
            'val_loss': loss,
        })

        end_epoch_timer = timer()
        logger.info("#### End epoch {}, elapsed time: {}".format(epoch, end_epoch_timer - epoch_timer))

    ###########################################################
    ## Prediction
    X_test = KaggleAmazonDataset(
        './data/sample_submission_v2.csv', './data/test-jpg/', '.jpg',
        ds_transform_raw,
    )
    # No sampler or shuffle option is passed for the test loader.
    test_loader = DataLoader(
        X_test,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
    )

    # Load model from best iteration.
    # Presumably `snapshot` writes '<run_name>-model_best.pth' into save_dir
    # when is_best is true -- confirm in src.p_training. If no epoch ever
    # scored above 0 (or epochs == 0) that file may not exist and torch.load
    # will raise.
    logger.info('===> loading best model for prediction')
    checkpoint = torch.load(
        os.path.join(save_dir, run_name + '-model_best.pth')
    )
    model.load_state_dict(checkpoint['state_dict'])

    # Predict with the weights restored from the best checkpoint, using the
    # threshold and score stored alongside them.
    predictions = predict(test_loader, model)
    output(predictions,
           checkpoint['threshold'],
           X_test,
           X_train.getLabelEncoder(),
           './out',
           run_name,
           checkpoint['best_score'])  # TODO early_stopping and use best_score

    ##########################################################
    end_global_timer = timer()
    logger.info("################## Success #########################")
    logger.info("Total elapsed time: %s" % (end_global_timer - global_timer))