# train_flickr27.py
# Fine-tunes an OS2D model on the Flickr Logos 27 dataset (forked from aosokin/os2d).
import os
import argparse
import copy
import math
import time
import datetime  # needed for the timestamp in the NaN-gradient dump below
import pickle
import random
from collections import OrderedDict

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch
import torchvision.transforms as transforms
from torchvision.utils import save_image

from os2d.modeling.model import build_os2d_from_config
from os2d.data.dataloader import build_eval_dataloaders_from_cfg, build_train_dataloader_from_config
from os2d.engine.train import trainval_loop
from os2d.engine.evaluate import evaluate
from os2d.engine.optimization import create_optimizer, setup_lr, get_learning_rate, set_learning_rate
from os2d.engine.augmentation import DataAugmentation
from os2d.config import cfg
from os2d.structures.transforms import TransformList, crop
from os2d.structures.feature_map import FeatureMapSize
from os2d.structures.bounding_box import BoxList
import os2d.structures.transforms as transforms_boxes
import os2d.utils.visualization as visualizer
from os2d.utils import (log_meters, checkpoint_model, set_random_seed, add_to_meters_in_dict,
                        print_meters, get_trainable_parameters, mkdir, save_config, setup_logger,
                        get_data_path, read_image, get_image_size_after_resize_preserving_aspect_ratio)
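
# Dataset layout: full scene images, one folder of query (class) images per brand,
# a CSV with one annotated box per row, and the OS2D checkpoint used for initialization.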
imgspath = '../../data/flickr_logos_27_dataset/flickr_logos_27_dataset_images'
querypath = '../../data/flickr_logos_27_dataset/flickr_logos_27_class_images'
anns_path = '../../data/flickr_logos_27_dataset/annotations.csv'
cpt_path = "alternating-models/checkpoint_iter_12500.pth"
output_path = 'flickr27-models'
anns_df = pd.read_csv(anns_path)
train_df = anns_df
cfg.is_cuda = torch.cuda.is_available()
cfg.train.batch_size = 1
# set this to use faster convolutions
if cfg.is_cuda:
    assert torch.cuda.is_available(), "Do not have available GPU, but cfg.is_cuda == 1"
    torch.backends.cudnn.benchmark = True
# random seed
set_random_seed(cfg.random_seed, cfg.is_cuda)
# Model
cfg.init.model = cpt_path
net, box_coder, criterion, img_normalization, optimizer_state = build_os2d_from_config(cfg)
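
# Standard torchvision preprocessing using the normalization statistics returned with the model.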
transform_image = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(img_normalization["mean"], img_normalization["std"])
])
parameters = get_trainable_parameters(net)
optimizer = create_optimizer(parameters, cfg.train.optim, optimizer_state)
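
# Data augmentation for training images: random 600x600 crops with scale and aspect-ratio
# jitter plus color distortion; label (query) images are not randomly cropped.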
data_augmentation = DataAugmentation(random_flip_batches=False,
                                     random_crop_size=FeatureMapSize(w=600, h=600),
                                     random_crop_scale=0.39215686274509803,
                                     jitter_aspect_ratio=0.9,
                                     scale_jitter=0.7,
                                     random_color_distortion=True,
                                     random_crop_label_images=False,
                                     min_box_coverage=0.7)
cfg.output.save_iter = 809
cfg.output.path = output_path
cfg.eval.iter = 5000
cfg.train.optim.max_iter = 100000  # alternatively: len(train_df['imageid'])
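
# Batches are built by slicing the sorted unique image ids of the annotation table;
# with batch_size = 1 each batch holds a single scene image.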
def get_litw_batch(i_batch):
    idx_0 = i_batch * cfg.train.batch_size
    idxs = range(idx_0, idx_0 + cfg.train.batch_size)
    ids = [np.unique(train_df['imageid'])[idx] for idx in idxs]
    batch_data = _prepare_batch(ids)
    return batch_data
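
# Main training loop: a simplified take on os2d.engine.train.trainval_loop that reads batches
# from the Flickr-27 dataframe, trains one batch at a time, and checkpoints the model every
# cfg.output.save_iter iterations.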
def trainval_loop2():
    # setup the learning rate schedule
    _, anneal_lr_func = setup_lr(optimizer, None, cfg.train.optim.anneal_lr, cfg.eval.iter)
    # save initial model
    if cfg.output.path:
        checkpoint_model(net, optimizer, cfg.output.path, cfg.is_cuda, i_iter=0)
    # start training
    i_epoch = 0
    i_batch = 0  # to start a new epoch at the first iteration
    ts = []
    t0 = time.time()
    for i_iter in range(0, cfg.train.optim.max_iter):
        t1 = time.time()
        ts.append(t1 - t0)
        t0 = t1
        if i_iter % 25 == 0:
            print(i_iter, np.mean(ts))
        # restart dataloader if needed
        if i_batch >= len(np.unique(train_df['imageid'])) // cfg.train.batch_size:
            i_epoch += 1
            i_batch = 0
        # get data for training
        t_start_loading = time.time()
        try:
            batch_data = get_litw_batch(i_batch)
        except Exception:
            print(f"skipping {i_batch}")
            i_batch += 1
            continue
        i_batch += 1
        t_data_loading = time.time() - t_start_loading
        # train on one batch
        meters = train_one_batch(batch_data, net, cfg, criterion, optimizer)
        meters["loading_time"] = t_data_loading
        # save intermediate model
        if cfg.output.path and cfg.output.save_iter and i_iter % cfg.output.save_iter == 0:
            print("Saving...")
            checkpoint_model(net, optimizer, cfg.output.path, cfg.is_cuda, i_iter=i_iter)
    # save the final model
    if cfg.output.path:
        checkpoint_model(net, optimizer, cfg.output.path, cfg.is_cuda, i_iter=cfg.train.optim.max_iter)
def prepare_batch_data(batch_data, is_cuda):
    """Helper function to parse batch_data and put tensors on a GPU.
    Used in train_one_batch
    """
    images, class_images, loc_targets, class_targets, class_ids, class_image_sizes, \
        batch_box_inverse_transform, batch_boxes, batch_img_size = \
        batch_data
    if is_cuda:
        images = images.cuda()
        class_images = [im.cuda() for im in class_images]
        loc_targets = loc_targets.cuda()
        class_targets = class_targets.cuda()
    return images, class_images, loc_targets, class_targets, class_ids, class_image_sizes, \
        batch_box_inverse_transform, batch_boxes, batch_img_size
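
# One optimization step: forward pass of the scene images against the class (query) images,
# OS2D loss on localization and recognition scores, gradient clipping, and a diagnostic dump
# of the full state if the gradient norm turns out to be NaN.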
def train_one_batch(batch_data, net, cfg, criterion, optimizer):
    t_start_batch = time.time()
    net.train(freeze_bn_in_extractor=cfg.train.model.freeze_bn,
              freeze_transform_params=cfg.train.model.freeze_transform,
              freeze_bn_transform=cfg.train.model.freeze_bn_transform)
    optimizer.zero_grad()
    images, class_images, loc_targets, class_targets, class_ids, class_image_sizes, \
        batch_box_inverse_transform, batch_boxes, batch_img_size = \
        prepare_batch_data(batch_data, cfg.is_cuda)
    # print('training batch...')
    # fig = plt.figure()
    # plt.imshow(images[0].permute(1, 2, 0))
    # plt.show()
    # fig = plt.figure()
    # plt.imshow(class_images[0].permute(1, 2, 0))
    # plt.show()
    loc_scores, class_scores, class_scores_transform_detached, fm_sizes, corners = \
        net(images, class_images,
            train_mode=True,
            fine_tune_features=cfg.train.model.train_features)
    cls_targets_remapped, ious_anchor, ious_anchor_corrected = \
        box_coder.remap_anchor_targets(loc_scores, batch_img_size, class_image_sizes, batch_boxes)
    losses = criterion(loc_scores, loc_targets,
                       class_scores, class_targets,
                       cls_targets_remapped=cls_targets_remapped,
                       cls_preds_for_neg=class_scores_transform_detached if not cfg.train.model.train_transform_on_negs else None)
    main_loss = losses["loss"]
    main_loss.backward()
    # save full grad
    grad = OrderedDict()
    for name, param in net.named_parameters():
        if param.requires_grad and param.grad is not None:
            grad[name] = param.grad.clone().cpu()
    grad_norm = torch.nn.utils.clip_grad_norm_(get_trainable_parameters(net), cfg.train.optim.max_grad_norm, norm_type=2)
    # save error state if grad appears to be nan
    if math.isnan(grad_norm):
        # remove some unsavable objects
        batch_data = [b for b in batch_data]
        batch_data[6] = None
        data_nan = {"batch_data": batch_data, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict(),
                    "cfg": cfg, "grad": grad}
        time_stamp = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
        dump_file = "error_nan_appeared-" + time_stamp + ".pth"
        if cfg.output.path:
            dump_file = os.path.join(cfg.output.path, dump_file)
        print("gradient is NaN. Saving dump to {}".format(dump_file))
        torch.save(data_nan, dump_file)
    else:
        optimizer.step()
    # convert everything to numbers
    meters = OrderedDict()
    for l in losses:
        meters[l] = losses[l].mean().item()
    meters["grad_norm"] = grad_norm
    meters["batch_time"] = time.time() - t_start_batch
    return meters
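
# For every class in the batch, sample one random query (label) image from that class's folder.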
def get_class_images_and_sizes(class_ids):
    class_images = []
    class_image_sizes = []
    all_class_ids = []
    for class_id in class_ids:
        images = []
        # for img in os.listdir(f'{querypath}/{class_id}'):
        #     if image[-4:] == '.jpg':
        #         images.append(Image.open(f'{querypath}/{class_id}/{image}'))
        # print(class_id)
        images = [Image.open(f'{querypath}/{class_id}/{image}').convert("RGB")
                  for image in os.listdir(f'{querypath}/{class_id}') if image[-4:] == '.jpg']
        choice = random.choice(images)
        # fig = plt.figure()
        # plt.imshow(choice)
        # plt.show()
        class_images.append(choice)
    class_image_sizes = [FeatureMapSize(img=img) for img in class_images]
    return class_images, class_image_sizes
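
# Transform for query (label) images: optional flip to match the batch-level flip, color
# distortion, resize with target_size=240 while preserving the aspect ratio, then tensor
# conversion and normalization.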
def _transform_image_gt(img, do_augmentation=True, hflip=False, vflip=False, do_resize=True):
    # batch level data augmentation
    img, _ = transforms_boxes.transpose(img, hflip=hflip, vflip=vflip, boxes=None, transform_list=None)
    if do_augmentation:
        # color distortion
        img = data_augmentation.random_distort(img)
        # random crop
        img = data_augmentation.random_crop_label_image(img)
    # resize image
    if do_resize:
        random_interpolation = data_augmentation.random_interpolation if do_augmentation else False
        # get the new size - while preserving aspect ratio
        size_old = FeatureMapSize(img=img)
        h, w = get_image_size_after_resize_preserving_aspect_ratio(h=size_old.h, w=size_old.w,
                                                                   target_size=240)
        size_new = FeatureMapSize(w=w, h=h)
        img, _ = transforms_boxes.resize(img, target_size=size_new, random_interpolation=random_interpolation)
    transforms_th = [transforms.ToTensor()]
    if img_normalization is not None:
        transforms_th += [transforms.Normalize(img_normalization["mean"], img_normalization["std"])]
    img = transforms.Compose(transforms_th)(img)
    return img
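
# Annotation rows hold class ids and box corners (lx, ty, rx, by); the multiplication by the
# image width/height below assumes those coordinates are stored relative (in [0, 1]) in the CSV.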
def get_boxes_from_image_dataframe(image_data, image_size):
    if not image_data.empty:
        # get the labels
        label_ids_global = torch.tensor(list(image_data["classid"]), dtype=torch.long)
        # get the boxes
        boxes = image_data[["lx", "ty", "rx", "by"]].to_numpy()
        # renorm boxes using the image size
        boxes[:, 0] *= image_size.w
        boxes[:, 2] *= image_size.w
        boxes[:, 1] *= image_size.h
        boxes[:, 3] *= image_size.h
        boxes = torch.FloatTensor(boxes)
        boxes = BoxList(boxes, image_size=image_size, mode="xyxy")
    else:
        boxes = BoxList.create_empty(image_size)
        label_ids_global = torch.tensor([], dtype=torch.long)
        difficult_flag = torch.tensor([], dtype=torch.bool)
    boxes.add_field("labels", label_ids_global)
    boxes.add_field("labels_original", label_ids_global)
    return boxes
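
# Map dataset-wide class ids to local indices within this batch's class_ids list; ids that are
# not among the batch classes are mapped to -1.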
def convert_label_ids_global_to_local(label_ids_global, class_ids):
    label_ids_local = []  # local indices w.r.t. batch_class_images
    if label_ids_global is not None:
        for label_id in label_ids_global:
            label_id = label_id.item()
            label_ids_local.append(class_ids.index(label_id) if label_id in class_ids else -1)
    label_ids_local = torch.tensor(label_ids_local, dtype=torch.long)
    return label_ids_local
def update_box_labels_to_local(boxes, class_ids):
    label_ids_global = boxes.get_field("labels")
    label_ids_local = convert_label_ids_global_to_local(label_ids_global, class_ids)
    boxes.add_field("labels", label_ids_local)
def _transform_image_to_pyramid(image_id, boxes=None,
                                do_augmentation=True, hflip=False, vflip=False,
                                pyramid_scales=(1,),
                                mined_data=None):
    # print(image_id)
    img = Image.open(f'{imgspath}/{image_id}').convert("RGB")
    img_size = FeatureMapSize(img=img)
    num_pyramid_levels = len(pyramid_scales)
    if boxes is None:
        boxes = BoxList.create_empty(img_size)
    mask_cutoff_boxes = torch.zeros(len(boxes), dtype=torch.bool)
    mask_difficult_boxes = torch.zeros(len(boxes), dtype=torch.bool)
    box_inverse_transform = TransformList()
    # batch level data augmentation
    img, boxes = transforms_boxes.transpose(img, hflip=hflip, vflip=vflip,
                                            boxes=boxes,
                                            transform_list=box_inverse_transform)
    if do_augmentation:
        if data_augmentation.do_random_crop:
            img, boxes, mask_cutoff_boxes, mask_difficult_boxes = \
                data_augmentation.random_crop(img,
                                              boxes=boxes,
                                              transform_list=box_inverse_transform)
            img, boxes = transforms_boxes.resize(img, target_size=data_augmentation.random_crop_size,
                                                 random_interpolation=data_augmentation.random_interpolation,
                                                 boxes=boxes,
                                                 transform_list=box_inverse_transform)
        # color distortion
        img = data_augmentation.random_distort(img)
    random_interpolation = data_augmentation.random_interpolation
    img_size = FeatureMapSize(img=img)
    pyramid_sizes = [FeatureMapSize(w=int(img_size.w * s), h=int(img_size.h * s)) for s in pyramid_scales]
    img_pyramid = []
    boxes_pyramid = []
    pyramid_box_inverse_transform = []
    for p_size in pyramid_sizes:
        box_inverse_transform_this_scale = copy.deepcopy(box_inverse_transform)
        p_img, p_boxes = transforms_boxes.resize(img, target_size=p_size, random_interpolation=random_interpolation,
                                                 boxes=boxes,
                                                 transform_list=box_inverse_transform_this_scale)
        pyramid_box_inverse_transform.append(box_inverse_transform_this_scale)
        img_pyramid.append(p_img)
        boxes_pyramid.append(p_boxes)
    transforms_th = [transforms.ToTensor()]
    if img_normalization is not None:
        transforms_th += [transforms.Normalize(img_normalization["mean"], img_normalization["std"])]
    for i_p in range(num_pyramid_levels):
        img_pyramid[i_p] = transforms.Compose(transforms_th)(img_pyramid[i_p])
    return img_pyramid, boxes_pyramid, mask_cutoff_boxes, mask_difficult_boxes, pyramid_box_inverse_transform
def _transform_image(image_id, boxes=None, do_augmentation=True, hflip=False, vflip=False, mined_data=None):
    img_pyramid, boxes_pyramid, mask_cutoff_boxes, mask_difficult_boxes, pyramid_box_inverse_transform = \
        _transform_image_to_pyramid(image_id, boxes=boxes,
                                    do_augmentation=do_augmentation, hflip=hflip, vflip=vflip,
                                    pyramid_scales=(1,), mined_data=mined_data)
    return img_pyramid[0], boxes_pyramid[0], mask_cutoff_boxes, mask_difficult_boxes, pyramid_box_inverse_transform[0]
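
# Assemble one training batch: pick the classes annotated in the selected images, transform one
# random query image per class, transform every scene image together with its boxes, and encode
# the boxes into OS2D localization/classification targets with box_coder.encode.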
def _prepare_batch(image_ids, use_all_labels=False):
    batch_images = []
    batch_class_images = []
    batch_loc_targets = []
    batch_class_targets = []
    # flag to use hard neg mining
    use_mined_data = False
    # collect labels for this batch
    batch_data = train_df[train_df['imageid'].isin(image_ids)]
    class_ids = batch_data["classid"].unique()
    brands = batch_data["name"].unique()
    # select labels for mined hardnegs
    mined_labels = []
    # randomly prune label images if too many
    max_batch_labels = class_ids.size + len(mined_labels) + 1
    class_ids = np.unique(class_ids)
    np.random.shuffle(class_ids)
    class_ids = class_ids[:max_batch_labels - len(mined_labels)]
    class_ids = np.concatenate((class_ids, np.array(mined_labels).astype(class_ids.dtype)), axis=0)
    class_ids = list(class_ids)  # alternatively: sorted(list(class_ids))
    # decide on batch level data augmentation
    batch_vflip = random.random() < 0.5 if data_augmentation.batch_random_vflip else False
    batch_hflip = random.random() < 0.5 if data_augmentation.batch_random_hflip else False
    # prepare class images
    num_classes = len(class_ids)
    class_images, class_image_sizes = get_class_images_and_sizes(brands)
    batch_class_images = [_transform_image_gt(img, hflip=batch_hflip, vflip=batch_vflip) for img in class_images]
    # get the image sizes after resize in _transform_image_gt, format - width, height
    class_image_sizes = [FeatureMapSize(img=img) for img in batch_class_images]
    # prepare images and boxes
    img_size = None
    batch_box_inverse_transform = []
    batch_boxes = []
    batch_img_size = []
    for image_id in image_ids:
        # get annotation
        fm_size = FeatureMapSize(img=Image.open(f'{imgspath}/{image_id}').convert("RGB"))
        boxes = get_boxes_from_image_dataframe(batch_data[batch_data['imageid'] == image_id], fm_size)
        # convert global indices to local
        # if use_global_labels==False then local indices will be w.r.t. labels in this batch
        # if use_global_labels==True then local indices will be w.r.t. labels in the whole dataset (not class_ids)
        update_box_labels_to_local(boxes, class_ids)
        # prepare image and boxes: convert image to tensor, data augmentation: some boxes might be cut off the image
        image_mined_data = None
        img, boxes, mask_cutoff_boxes, mask_difficult_boxes, box_inverse_transform = \
            _transform_image(image_id, boxes, hflip=batch_hflip, vflip=batch_vflip, mined_data=image_mined_data)
        if boxes.has_field("difficult"):
            old_difficult = boxes.get_field("difficult")
            boxes.add_field("difficult", old_difficult | mask_difficult_boxes)
        boxes.get_field("labels")[mask_cutoff_boxes] = -2
        # check image size in this batch
        if img_size is None:
            img_size = FeatureMapSize(img=img)
        else:
            assert img_size == FeatureMapSize(img=img), "Images in a batch should be of the same size"
        loc_targets, class_targets = box_coder.encode(boxes, img_size, num_classes)
        batch_loc_targets.append(loc_targets)
        batch_class_targets.append(class_targets)
        batch_images.append(img)
        batch_box_inverse_transform.append([box_inverse_transform])
        batch_boxes.append(boxes)
        batch_img_size.append(img_size)
    # stack data
    batch_images = torch.stack(batch_images, 0)
    batch_loc_targets = torch.stack(batch_loc_targets, 0)
    batch_class_targets = torch.stack(batch_class_targets, 0)
    return batch_images, batch_class_images, batch_loc_targets, batch_class_targets, class_ids, class_image_sizes, \
        batch_box_inverse_transform, batch_boxes, batch_img_size
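
# Entry point: run the fine-tuning loop defined above.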
trainval_loop2()