-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathtrain_net.py
510 lines (389 loc) · 18 KB
/
train_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import os
import sys
import numpy as np
import copy
from collections import OrderedDict
import torch
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.engine import (
default_argument_parser,
default_setup,
default_writers,
launch
)
from detectron2.solver import build_lr_scheduler
from detectron2.utils.events import EventStorage
from detectron2.utils.logger import setup_logger
# Module-level logger; setup() later attaches handlers to the "cubercnn" logger.
logger = logging.getLogger("cubercnn")
# Do not write .pyc bytecode files for modules imported from here on.
sys.dont_write_bytecode = True
# Make the current working directory importable (presumably so the
# `cubercnn` package imported below resolves when run from the repo
# root -- TODO confirm).
sys.path.append(os.getcwd())
# Print floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)
from cubercnn.solver import build_optimizer, freeze_bn, PeriodicCheckpointerOnlyOne
from cubercnn.config import get_cfg_defaults
from cubercnn.data import (
load_omni3d_json,
DatasetMapper3D,
build_detection_train_loader,
build_detection_test_loader,
get_omni3d_categories,
simple_register
)
from cubercnn.evaluation import (
Omni3DEvaluator, Omni3Deval,
Omni3DEvaluationHelper,
inference_on_dataset
)
from cubercnn.modeling.proposal_generator import RPNWithIgnore
from cubercnn.modeling.roi_heads import ROIHeads3D
from cubercnn.modeling.meta_arch import RCNN3D, build_model
from cubercnn.modeling.backbone import build_dla_from_vision_fpn_backbone
from cubercnn import util, vis, data
import cubercnn.vis.logperf as utils_logperf
# Upper bound on how many times main() rebuilds the model and calls
# do_train() before giving up with "Training failed".
MAX_TRAINING_ATTEMPTS = 10
def do_test(cfg, model, iteration='final', storage=None):
    """
    Run inference and Omni3D evaluation on every dataset in cfg.DATASETS.TEST.

    Each dataset is run through the model individually. On the main process
    its predictions are added to a shared Omni3DEvaluationHelper, saved,
    evaluated and visualized; once all datasets are done the helper
    summarizes the collected results.

    Args:
        cfg: config node; TEST.*, DATASETS.TEST, OUTPUT_DIR,
            INPUT.MIN_SIZE_TEST and MODEL.ROI_CUBE_HEAD.LOSS_W_3D are read.
        model: model passed to inference_on_dataset.
        iteration: label used for the output folder name ('iter_<label>')
            and forwarded to the evaluation helper and the visualizer.
        storage: accepted for call compatibility; not used in this function.

    Returns:
        None.
    """
    # Start from the config-derived filters and override them for testing.
    eval_filters = data.get_filter_settings_from_cfg(cfg)
    eval_filters['visibility_thres'] = cfg.TEST.VISIBILITY_THRES
    eval_filters['truncation_thres'] = cfg.TEST.TRUNCATION_THRES
    eval_filters['min_height_thres'] = 0.0625
    eval_filters['max_depth'] = 1e8

    test_sets = cfg.DATASETS.TEST

    # A zero 3D loss weight is treated as a 2D-only model.
    is_2d_only = cfg.MODEL.ROI_CUBE_HEAD.LOSS_W_3D == 0.0

    out_root = os.path.join(cfg.OUTPUT_DIR, "inference", 'iter_{}'.format(iteration))

    helper = Omni3DEvaluationHelper(
        test_sets,
        eval_filters,
        out_root,
        iter_label=iteration,
        only_2d=is_2d_only,
    )

    # Cycle through each dataset and test them individually.
    # The helper keeps track of each per-image evaluation result,
    # so that it doesn't need to be re-computed for the collective.
    for test_set in test_sets:

        # Distributed Cube R-CNN inference (runs on every process).
        loader = build_detection_test_loader(cfg, test_set)
        predictions = inference_on_dataset(model, loader)

        # Everything below is main-process only.
        if not comm.is_main_process():
            continue

        # Individual dataset evaluation.
        helper.add_predictions(test_set, predictions)
        helper.save_predictions(test_set)
        helper.evaluate(test_set)

        # Optionally, visualize some instances.
        # NOTE(review): torch.load unpickles this file; it is presumably the
        # one written by save_predictions() just above -- confirm.
        set_dir = os.path.join(out_root, test_set)
        saved_instances = torch.load(os.path.join(set_dir, 'instances_predictions.pth'))
        vis_summary = vis.visualize_from_instances(
            saved_instances,
            loader.dataset,
            test_set,
            cfg.INPUT.MIN_SIZE_TEST,
            set_dir,
            MetadataCatalog.get('omni3d_model').thing_classes,
            iteration,
        )
        logger.info(vis_summary)

    # Summarize each Omni3D Evaluation metric.
    if comm.is_main_process():
        helper.summarize_all()
def do_train(cfg, model, dataset_id_to_unknown_cats, dataset_id_to_src, resume=False):
    """
    Run the training loop, optionally guarding against a diverging loss.

    When cfg.MODEL.STABILIZE > 0 ("stabilization"), an iteration's optimizer
    step is skipped if the reduced loss is non-finite, exceeds TOLERANCE
    times its rolling mean, or any gradient contains inf/NaN. If the share
    of skipped iterations reaches cfg.MODEL.STABILIZE (and enough iterations
    have run relative to SOLVER.CHECKPOINT_PERIOD), the function gives up
    and returns False so the caller can restart.

    When cfg.MODEL.STABILIZE <= 0 stabilization is off: no iteration is
    skipped, no restart is requested, and periodic checkpointing always
    runs.

    Args:
        cfg: config node (SOLVER.*, TEST.EVAL_PERIOD, MODEL.* are read).
        model: the model to train; called as model(batch) and expected to
            return a dict of loss tensors.
        dataset_id_to_unknown_cats: attached to the training DatasetMapper3D
            as `dataset_id_to_unknown_cats`.
        dataset_id_to_src: forwarded to build_detection_train_loader.
        resume: forwarded to checkpointer.resume_or_load.

    Returns:
        True when training reached cfg.SOLVER.MAX_ITER, False when training
        was aborted because too many iterations diverged.

    Note:
        Flags are moved to the GPU with .cuda() before the all-reduce, so a
        CUDA device is required.
    """
    max_iter = cfg.SOLVER.MAX_ITER
    do_eval = cfg.TEST.EVAL_PERIOD > 0

    # Stabilization is considered enabled only for a positive threshold.
    stabilize_enabled = cfg.MODEL.STABILIZE > 0

    model.train()

    optimizer = build_optimizer(cfg, model)
    scheduler = build_lr_scheduler(cfg, optimizer)

    # bookkeeping
    checkpointer = DetectionCheckpointer(model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler)
    periodic_checkpointer = PeriodicCheckpointerOnlyOne(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter)
    writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else []

    # create the dataloader
    data_mapper = DatasetMapper3D(cfg, is_train=True)
    data_loader = build_detection_train_loader(cfg, mapper=data_mapper, dataset_id_to_src=dataset_id_to_src)

    # give the mapper access to dataset_ids
    data_mapper.dataset_id_to_unknown_cats = dataset_id_to_unknown_cats

    if cfg.MODEL.WEIGHTS_PRETRAIN != '':
        # load ONLY the model, no checkpointables.
        checkpointer.load(cfg.MODEL.WEIGHTS_PRETRAIN, checkpointables=[])

    # determine the starting iteration, if resuming
    start_iter = (checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1)
    iteration = start_iter

    logger.info("Starting training from iteration {}".format(start_iter))

    if not cfg.MODEL.USE_BN:
        freeze_bn(model)

    world_size = comm.get_world_size()

    # if the loss diverges for more than the below TOLERANCE
    # as a percent of the iterations, the training will stop.
    # This is only enabled if "STABILIZE" is on, which
    # prevents a single example from exploding the training.
    iterations_success = 0
    iterations_explode = 0

    # when loss > recent_loss * TOLERANCE, then it could be a
    # diverging/failing model, which we should skip all updates for.
    TOLERANCE = 4.0

    GAMMA = 0.02        # rolling average weight gain
    recent_loss = None  # stores the most recent loss magnitude

    data_iter = iter(data_loader)

    # model.parameters() is surprisingly expensive at 150ms, so cache it
    named_params = list(model.named_parameters())

    with EventStorage(start_iter) as storage:

        while True:

            # `batch` rather than `data`, so the imported cubercnn.data
            # module is not shadowed inside this function.
            batch = next(data_iter)
            storage.iter = iteration

            # forward
            loss_dict = model(batch)
            losses = sum(loss_dict.values())

            # reduce
            loss_dict_reduced = {k: v.item() for k, v in allreduce_dict(loss_dict).items()}
            losses_reduced = sum(loss_dict_reduced.values())

            # sync up
            comm.synchronize()

            if recent_loss is None:
                # init recent loss fairly high
                recent_loss = losses_reduced*2.0

            # Is stabilization enabled, and loss high or non-finite?
            # (NaN is non-finite, so np.isfinite covers it.)
            diverging_model = stabilize_enabled and \
                (losses_reduced > recent_loss*TOLERANCE or not np.isfinite(losses_reduced))

            if diverging_model:
                # clip and warn the user.
                losses = losses.clip(0, 1)
                logger.warning('Skipping gradient update due to higher than normal loss {:.2f} vs. rolling mean {:.2f}, Dict-> {}'.format(
                    losses_reduced, recent_loss, loss_dict_reduced
                ))
            else:
                # compute rolling average of loss
                recent_loss = recent_loss * (1-GAMMA) + losses_reduced*GAMMA

            if comm.is_main_process():
                # send loss scalars to tensorboard.
                storage.put_scalars(total_loss=losses_reduced, **loss_dict_reduced)

            # backward and step
            optimizer.zero_grad()
            losses.backward()

            # if the loss is not too high,
            # we still want to check gradients.
            if not diverging_model and stabilize_enabled:
                for _, param in named_params:

                    if param.grad is not None:
                        diverging_model = torch.isnan(param.grad).any() or torch.isinf(param.grad).any()

                    if diverging_model:
                        logger.warning('Skipping gradient update due to inf/nan detection, loss is {}'.format(loss_dict_reduced))
                        break

            # convert exploded to a float, then allreduce it,
            # if any process gradients have exploded then we skip together.
            diverging_model = torch.tensor(float(diverging_model)).cuda()

            if world_size > 1:
                dist.all_reduce(diverging_model)

            # sync up
            comm.synchronize()

            if diverging_model > 0:
                optimizer.zero_grad()
                iterations_explode += 1
            else:
                optimizer.step()
                storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
                iterations_success += 1

            total_iterations = iterations_success + iterations_explode
            explode_ratio = iterations_explode / total_iterations

            # Only retry if we have trained sufficiently long relative
            # to the latest checkpoint, which we would otherwise revert back to.
            # The check is gated on stabilization being enabled: with a
            # threshold of 0 the ratio test (0 >= 0) would always pass and
            # force a restart even though nothing has diverged.
            retry = stabilize_enabled \
                and explode_ratio >= cfg.MODEL.STABILIZE \
                and (total_iterations > cfg.SOLVER.CHECKPOINT_PERIOD*1/2)

            # Important for dist training. Convert to a float, then allreduce it,
            # if any process gradients have exploded then we must skip together.
            retry = torch.tensor(float(retry)).cuda()

            if world_size > 1:
                dist.all_reduce(retry)

            # sync up
            comm.synchronize()

            # any processes need to retry
            if retry > 0:

                # instead of failing, try to resume the iteration instead.
                logger.warning('!! Restarting training at {} iters. Exploding loss {:d}% of iters !!'.format(
                    iteration, int(100*(iterations_explode / (iterations_success + iterations_explode)))
                ))

                # send these to garbage, for ideally a cleaner restart.
                del data_mapper
                del data_loader
                del optimizer
                del checkpointer
                del periodic_checkpointer
                return False

            scheduler.step()

            # Evaluate only when the loss is not diverging.
            if not (diverging_model > 0) and \
                    (do_eval and ((iteration + 1) % cfg.TEST.EVAL_PERIOD) == 0 and iteration != (max_iter - 1)):

                logger.info('Starting test for iteration {}'.format(iteration+1))
                do_test(cfg, model, iteration=iteration+1, storage=storage)
                comm.synchronize()

                # re-apply the BN freeze after evaluation, as done before
                # the loop.
                if not cfg.MODEL.USE_BN:
                    freeze_bn(model)

            # Flush events
            if iteration - start_iter > 5 and ((iteration + 1) % 20 == 0 or iteration == max_iter - 1):
                for writer in writers:
                    writer.write()

            # Do not bother checkpointing if there is potential for a diverging model.
            # With stabilization off there is no divergence tracking, so
            # checkpointing must not be blocked by the ratio test
            # (0 < 0.5*0 would otherwise never hold).
            if not (diverging_model > 0) and \
                    (not stabilize_enabled or explode_ratio < 0.5*cfg.MODEL.STABILIZE):
                periodic_checkpointer.step(iteration)

            iteration += 1

            if iteration >= max_iter:
                break

    # success
    return True
def setup(args):
    """
    Build the frozen config and register the train/test datasets.

    The config starts from get_cfg() plus the Cube R-CNN defaults, is merged
    with args.config_file and the args.opts overrides, then frozen. After
    the default detectron2 setup and logger setup, every training dataset is
    registered with filter_empty=True and every test-only dataset with
    filter_empty=False.

    Returns:
        The frozen config node.
    """
    cfg = get_cfg()
    get_cfg_defaults(cfg)

    # store locally if needed
    cfg_path = args.config_file
    handler = util.CubeRCNNHandler
    if cfg_path.startswith(handler.PREFIX):
        cfg_path = handler._get_local_path(handler, cfg_path)

    cfg.merge_from_file(cfg_path)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)

    setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="cubercnn")

    filters = data.get_filter_settings_from_cfg(cfg)

    for train_name in cfg.DATASETS.TRAIN:
        simple_register(train_name, filters, filter_empty=True)

    for test_name in cfg.DATASETS.TEST:
        # datasets already registered for training are not registered again
        if test_name in cfg.DATASETS.TRAIN:
            continue
        simple_register(test_name, filters, filter_empty=False)

    return cfg
def main(args):
    """
    Entry point run by detectron2's launch() on each process.

    In eval-only mode the category metadata is read from the
    'category_meta.json' next to the config file, the model weights are
    loaded and do_test() is run once.

    Otherwise the training datasets are loaded and joined, the model
    metadata is registered, per-dataset known/unknown categories and priors
    are computed, and do_train() is attempted up to MAX_TRAINING_ATTEMPTS
    times with a freshly built model each time. After a successful run the
    final model is evaluated with do_test().

    Args:
        args: parsed command-line namespace; config_file, opts, eval_only
            and resume are read here or in setup().

    Returns:
        The return value of do_test().

    Raises:
        ValueError: if every training attempt ended in divergence.
    """
    cfg = setup(args)

    logger.info('Preprocessing Training Datasets')

    filter_settings = data.get_filter_settings_from_cfg(cfg)

    priors = None

    # Only filled on the training path; bound up front so the names always
    # exist regardless of which branch runs.
    dataset_id_to_unknown_cats = {}
    dataset_id_to_src = {}

    if args.eval_only:
        category_path = os.path.join(util.file_parts(args.config_file)[0], 'category_meta.json')

        # store locally if needed
        if category_path.startswith(util.CubeRCNNHandler.PREFIX):
            category_path = util.CubeRCNNHandler._get_local_path(util.CubeRCNNHandler, category_path)

        metadata = util.load_json(category_path)

        # register the categories
        thing_classes = metadata['thing_classes']
        # JSON object keys are strings; convert them back to integer ids.
        id_map = {int(key): val for key, val in metadata['thing_dataset_id_to_contiguous_id'].items()}
        MetadataCatalog.get('omni3d_model').thing_classes = thing_classes
        MetadataCatalog.get('omni3d_model').thing_dataset_id_to_contiguous_id = id_map

    else:
        # setup and join the data.
        dataset_paths = [os.path.join('datasets', 'Omni3D', name + '.json') for name in cfg.DATASETS.TRAIN]
        datasets = data.Omni3D(dataset_paths, filter_settings=filter_settings)

        # determine the meta data given the datasets used.
        data.register_and_store_model_metadata(datasets, cfg.OUTPUT_DIR, filter_settings)

        thing_classes = MetadataCatalog.get('omni3d_model').thing_classes
        dataset_id_to_contiguous_id = MetadataCatalog.get('omni3d_model').thing_dataset_id_to_contiguous_id

        # It may be useful to keep track of which categories are annotated/known
        # for each dataset in use, in case a method wants to use this information.
        infos = datasets.dataset['info']

        # a single dataset stores one info dict rather than a list of them
        if isinstance(infos, dict):
            infos = [infos]

        possible_categories = set(range(cfg.MODEL.ROI_HEADS.NUM_CLASSES + 1))

        for info in infos:
            dataset_id = info['id']

            if dataset_id not in dataset_id_to_src:
                dataset_id_to_src[dataset_id] = info['source']

            # map the dataset's known category ids to contiguous training
            # ids, dropping any category the model does not use.
            known_category_training_ids = {
                dataset_id_to_contiguous_id[cat_id]
                for cat_id in info['known_category_ids']
                if cat_id in dataset_id_to_contiguous_id
            }

            # determine and store the unknown categories.
            unknown_categories = possible_categories - known_category_training_ids
            dataset_id_to_unknown_cats[dataset_id] = unknown_categories

            # log the per-dataset categories
            logger.info('Available categories for {}'.format(info['name']))
            logger.info([thing_classes[i] for i in (possible_categories & known_category_training_ids)])

        # compute priors given the training data.
        priors = util.compute_priors(cfg, datasets)

    # The training loops can attempt to train for N times.
    # This catches a divergence or other failure modes.
    # NOTE(review): every attempt passes resume=args.resume to do_train(),
    # so without --resume a retry presumably starts again from
    # cfg.MODEL.WEIGHTS rather than the latest checkpoint -- confirm this
    # is intended.
    remaining_attempts = MAX_TRAINING_ATTEMPTS
    while remaining_attempts > 0:

        # build the training model.
        model = build_model(cfg, priors=priors)

        if remaining_attempts == MAX_TRAINING_ATTEMPTS:
            # log the first attempt's settings.
            logger.info("Model:\n{}".format(model))

        if args.eval_only:
            # skip straight to eval mode
            DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
                cfg.MODEL.WEIGHTS, resume=args.resume
            )
            return do_test(cfg, model)

        # setup distributed training.
        distributed = comm.get_world_size() > 1
        if distributed:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()],
                broadcast_buffers=False, find_unused_parameters=True
            )

        # train full model, potentially with resume.
        if do_train(cfg, model, dataset_id_to_unknown_cats, dataset_id_to_src, resume=args.resume):
            break

        # allow restart when a model fails to train.
        remaining_attempts -= 1
        del model

    if remaining_attempts == 0:
        # Exit if the model could not finish without diverging.
        raise ValueError('Training failed')

    return do_test(cfg, model)
def allreduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary across all processes.

    The values are stacked into one tensor and combined with
    dist.all_reduce (default reduction op, i.e. a sum), so every process --
    not only rank 0 -- receives the same reduced results.

    Args:
        input_dict (dict): inputs to be reduced. The values are stacked with
            torch.stack, so they must be tensors of identical shape
            (scalar tensors in practice).
        average (bool): whether to do average or sum
    Returns:
        a dict with the same keys as input_dict, after reduction. When only
        one process is running, input_dict itself is returned unchanged.
    """
    world_size = comm.get_world_size()
    if world_size < 2:
        return input_dict

    with torch.no_grad():
        # sort the keys so that they are consistent across processes
        names = sorted(input_dict)
        values = torch.stack([input_dict[k] for k in names], dim=0)
        dist.all_reduce(values)
        if average:
            # all_reduce summed over every process, so divide by
            # world_size to get the mean.
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
if __name__ == "__main__":
    # Parse the command line with detectron2's default argument parser.
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)

    # Hand main() to detectron2's launcher together with the GPU/machine
    # settings from the command line; main() receives `args` as its only
    # argument.
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )