This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 82
/
trainer.py
719 lines (611 loc) · 27.5 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import os
import time
import logging
import torch
from torch.nn.parallel import DistributedDataParallel
from fvcore.nn.precise_bn import get_bn_modules
import numpy as np
from collections import OrderedDict
import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.engine import DefaultTrainer, SimpleTrainer, TrainerBase
from detectron2.engine.train_loop import AMPTrainer
from detectron2.utils.events import EventStorage
from detectron2.evaluation import COCOEvaluator, verify_results, PascalVOCDetectionEvaluator, DatasetEvaluators
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.engine import hooks
from detectron2.structures.boxes import Boxes
from detectron2.structures.instances import Instances
from detectron2.utils.env import TORCH_VERSION
from detectron2.data import MetadataCatalog
from ubteacher.data.build import (
build_detection_semisup_train_loader,
build_detection_test_loader,
build_detection_semisup_train_loader_two_crops,
)
from ubteacher.data.dataset_mapper import DatasetMapperTwoCropSeparate
from ubteacher.engine.hooks import LossEvalHook
from ubteacher.modeling.meta_arch.ts_ensemble import EnsembleTSModel
from ubteacher.checkpoint.detection_checkpoint import DetectionTSCheckpointer
from ubteacher.solver.build import build_lr_scheduler
# Supervised-only Trainer
class BaselineTrainer(DefaultTrainer):
    """Supervised-only trainer.

    Trains a single detector through the model's ``"supervised"`` branch and
    saves/loads it with detectron2's ``DetectionCheckpointer``.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): training config. Worker settings are rescaled with
                ``DefaultTrainer.auto_scale_workers`` for the current world size.
        """
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        net = self.build_model(cfg)
        optim = self.build_optimizer(cfg, net)
        train_loader = self.build_train_loader(cfg)

        # Only wrap in DDP when more than one process takes part in training.
        if comm.get_world_size() > 1:
            net = DistributedDataParallel(
                net, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )

        TrainerBase.__init__(self)
        inner_trainer_cls = AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer
        self._trainer = inner_trainer_cls(net, train_loader, optim)

        self.scheduler = self.build_lr_scheduler(cfg, optim)
        self.checkpointer = DetectionCheckpointer(
            net,
            cfg.OUTPUT_DIR,
            optimizer=optim,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg
        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """Restore training state from ``cfg.OUTPUT_DIR`` or load initial weights.

        When ``resume`` is true and a checkpoint exists, all checkpointed state
        is loaded and training continues after the saved iteration. Otherwise
        the weights from ``cfg.MODEL.WEIGHTS`` are loaded and ``start_iter``
        is left untouched.

        Args:
            resume (bool): whether to attempt resuming from the last checkpoint.
        """
        checkpoint = self.checkpointer.resume_or_load(
            self.cfg.MODEL.WEIGHTS, resume=resume
        )
        if resume and self.checkpointer.has_checkpoint():
            # The stored iteration already finished, so continue at the next one
            # (-1 + 1 == 0 when the checkpoint carries no iteration).
            self.start_iter = checkpoint.get("iteration", -1) + 1

        if not isinstance(self.model, DistributedDataParallel):
            return
        # Every rank takes the first rank's model state and start iteration,
        # because other machines may not be able to read the checkpoint file.
        if TORCH_VERSION >= (1, 7):
            self.model._sync_params_and_buffers()
        self.start_iter = comm.all_gather(self.start_iter)[0]

    def train_loop(self, start_iter: int, max_iter: int):
        """Run hooks and ``run_step`` for iterations ``[start_iter, max_iter)``.

        Args:
            start_iter (int): first iteration to run.
            max_iter (int): iteration bound (exclusive).
        """
        log = logging.getLogger(__name__)
        log.info("Starting training from iteration {}".format(start_iter))
        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter
        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            except Exception:
                log.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    def run_step(self):
        """Do one supervised optimization step on the next training batch."""
        self._trainer.iter = self.iter
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"

        tic = time.perf_counter()
        batch = next(self._trainer._data_loader_iter)
        data_time = time.perf_counter() - tic

        record_dict, _, _, _ = self.model(batch, branch="supervised")

        # Mean number of ground-truth instances per image in this batch.
        record_dict["bbox_num/gt_bboxes"] = (
            sum((len(sample["instances"]) for sample in batch), 0.0) / len(batch)
        )

        # Optimize every "loss*" entry except keys ending in "val".
        losses = sum(
            value
            for key, value in record_dict.items()
            if key.startswith("loss") and not key.endswith("val")
        )

        record_dict["data_time"] = data_time
        self._write_metrics(record_dict)

        self.optimizer.zero_grad()
        losses.backward()
        self.optimizer.step()

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Create the evaluator matching the dataset's ``evaluator_type``.

        Returns a ``PascalVOCDetectionEvaluator`` for ``"pascal_voc"``, or a
        ``COCOEvaluator`` writing to ``output_folder`` for ``"coco"``.
        The default folder is ``<OUTPUT_DIR>/inference``.

        Raises:
            NotImplementedError: for any other evaluator type.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type

        if evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)

        evaluators = []
        if evaluator_type == "coco":
            evaluators.append(COCOEvaluator(dataset_name, output_dir=output_folder))

        if not evaluators:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        if len(evaluators) == 1:
            return evaluators[0]
        return DatasetEvaluators(evaluators)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the labeled training loader (default mapper)."""
        return build_detection_semisup_train_loader(cfg, mapper=None)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """
        Returns:
            iterable: the test loader for ``dataset_name``.
        """
        return build_detection_test_loader(cfg, dataset_name)

    def build_hooks(self):
        """Assemble the training hooks.

        The hooks are: timing, LR scheduling, optional PreciseBN, a periodic
        checkpointer (main process only), evaluation, and a periodic writer
        (main process only).

        Returns:
            list[HookBase]: the hooks; the PreciseBN slot is ``None`` when disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0

        hook_list = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
        ]
        use_precise_bn = cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
        hook_list.append(
            hooks.PreciseBN(
                cfg.TEST.EVAL_PERIOD,
                self.model,
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if use_precise_bn
            else None
        )

        if comm.is_main_process():
            hook_list.append(
                hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)
            )

        def test_and_save_results():
            self._last_eval_results = self.test(self.cfg, self.model)
            return self._last_eval_results

        hook_list.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))

        if comm.is_main_process():
            hook_list.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return hook_list

    def _write_metrics(self, metrics_dict: dict):
        """Gather scalar metrics from all workers and log them on the main process.

        Args:
            metrics_dict (dict): metric name -> tensor or number.
        """
        scalars = {
            name: (val.detach().cpu().item() if isinstance(val, torch.Tensor) else float(val))
            for name, val in metrics_dict.items()
        }
        # Collect every worker's metrics (DDP-style training is assumed).
        gathered = comm.gather(scalars)
        if not comm.is_main_process():
            return

        if "data_time" in gathered[0]:
            # The slowest worker determines the effective data latency.
            self.storage.put_scalar(
                "data_time", np.max([worker.pop("data_time") for worker in gathered])
            )

        averaged = {
            name: np.mean([worker[name] for worker in gathered])
            for name in gathered[0].keys()
        }
        total = sum(val for name, val in averaged.items() if name[:4] == "loss")
        self.storage.put_scalar("total_loss", total)

        if len(averaged) > 1:
            self.storage.put_scalars(**averaged)
# Unbiased Teacher Trainer
class UBTeacherTrainer(DefaultTrainer):
    """Unbiased Teacher semi-supervised trainer.

    Holds a student model (optimized, DDP-wrapped when world size > 1) and a
    teacher model (not wrapped, updated from the student by
    ``_update_teacher_model``).

    - Burn-in: for the first ``cfg.SEMISUPNET.BURN_UP_STEP`` iterations only
      labeled data is used.
    - After burn-in: the teacher pseudo-labels the weakly augmented unlabeled
      images by score thresholding. Those labels then supervise the student
      on the strongly augmented views.
    """

    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode): training config. Worker settings are rescaled with
                ``DefaultTrainer.auto_scale_workers`` for the current world size.
        """
        cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
        data_loader = self.build_train_loader(cfg)

        # create a student model
        model = self.build_model(cfg)
        optimizer = self.build_optimizer(cfg, model)

        # create a teacher model (it gets no optimizer)
        model_teacher = self.build_model(cfg)
        self.model_teacher = model_teacher

        # For training, wrap the student with DDP. The teacher is only used
        # for inference, so it is not wrapped.
        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
            )

        TrainerBase.__init__(self)
        self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
            model, data_loader, optimizer
        )
        self.scheduler = self.build_lr_scheduler(cfg, optimizer)

        # Ensemble teacher and student so one checkpointer saves/loads both.
        ensem_ts_model = EnsembleTSModel(model_teacher, model)
        self.checkpointer = DetectionTSCheckpointer(
            ensem_ts_model,
            cfg.OUTPUT_DIR,
            optimizer=optimizer,
            scheduler=self.scheduler,
        )
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.cfg = cfg

        self.register_hooks(self.build_hooks())

    def resume_or_load(self, resume=True):
        """Restore training state from ``cfg.OUTPUT_DIR`` or load initial weights.

        If ``resume`` is true and ``cfg.OUTPUT_DIR`` contains a last checkpoint,
        all checkpointed state (e.g. optimizer and scheduler) is loaded and the
        iteration counter continues from it. ``cfg.MODEL.WEIGHTS`` is not used
        in that case. Otherwise the weights in ``cfg.MODEL.WEIGHTS`` are loaded
        and training starts from ``start_iter``.

        Args:
            resume (bool): whether to attempt resuming.
        """
        checkpoint = self.checkpointer.resume_or_load(
            self.cfg.MODEL.WEIGHTS, resume=resume
        )
        if resume and self.checkpointer.has_checkpoint():
            # The checkpoint stores the training iteration that just finished,
            # so we start at the next one (iter zero if none is stored).
            self.start_iter = checkpoint.get("iteration", -1) + 1
        if isinstance(self.model, DistributedDataParallel):
            # Broadcast loaded data/model from the first rank, because other
            # machines may not have access to the checkpoint file.
            if TORCH_VERSION >= (1, 7):
                self.model._sync_params_and_buffers()
            self.start_iter = comm.all_gather(self.start_iter)[0]

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Create the evaluator matching the dataset's ``evaluator_type``.

        Raises:
            NotImplementedError: for evaluator types other than coco / pascal_voc.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "coco":
            evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder))
        elif evaluator_type == "pascal_voc":
            return PascalVOCDetectionEvaluator(dataset_name)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        elif len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the loader that yields labeled/unlabeled data in two crops.

        Each batch holds a strong and a weak augmentation of each image.
        """
        mapper = DatasetMapperTwoCropSeparate(cfg, True)
        return build_detection_semisup_train_loader_two_crops(cfg, mapper)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer):
        """Use ubteacher's LR scheduler factory instead of detectron2's."""
        return build_lr_scheduler(cfg, optimizer)

    def train(self):
        """Run training, then verify and return the last evaluation results.

        Returns:
            The results of the most recent teacher evaluation on the main
            process (set by the teacher eval hook from ``build_hooks``).
            Returns ``None`` if no evaluation ran, or on other processes.
        """
        self.train_loop(self.start_iter, self.max_iter)
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

    def train_loop(self, start_iter: int, max_iter: int):
        """Run hooks and one semi-supervised step per iteration.

        Covers iterations ``[start_iter, max_iter)``.
        """
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))
        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter
        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step_full_semisup()
                    self.after_step()
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    # =====================================================
    # ================== Pseudo-labeling ==================
    # =====================================================

    def threshold_bbox(self, proposal_bbox_inst, thres=0.7, proposal_type="roih"):
        """Keep only the predictions above ``thres`` as pseudo ground truth.

        Args:
            proposal_bbox_inst (Instances): predictions for one image.
            thres (float): threshold. For ``"rpn"`` it is compared against
                ``objectness_logits``; for ``"roih"`` against ``scores``.
            proposal_type (str): ``"rpn"`` or ``"roih"``.

        Returns:
            Instances: new instances with the same image size.
            - ``"rpn"``: carries ``gt_boxes`` and ``objectness_logits``.
            - ``"roih"``: carries ``gt_boxes``, ``gt_classes`` and ``scores``.

        Raises:
            ValueError: for any other ``proposal_type``. Previously an unknown
                type surfaced as an UnboundLocalError.
        """
        if proposal_type == "rpn":
            valid_map = proposal_bbox_inst.objectness_logits > thres
            new_proposal_inst = Instances(proposal_bbox_inst.image_size)
            new_proposal_inst.gt_boxes = Boxes(
                proposal_bbox_inst.proposal_boxes.tensor[valid_map, :]
            )
            new_proposal_inst.objectness_logits = proposal_bbox_inst.objectness_logits[
                valid_map
            ]
        elif proposal_type == "roih":
            valid_map = proposal_bbox_inst.scores > thres
            new_proposal_inst = Instances(proposal_bbox_inst.image_size)
            new_proposal_inst.gt_boxes = Boxes(
                proposal_bbox_inst.pred_boxes.tensor[valid_map, :]
            )
            new_proposal_inst.gt_classes = proposal_bbox_inst.pred_classes[valid_map]
            new_proposal_inst.scores = proposal_bbox_inst.scores[valid_map]
        else:
            raise ValueError("Unknown proposal type: {}".format(proposal_type))
        return new_proposal_inst

    def process_pseudo_label(
        self, proposals_rpn_unsup_k, cur_threshold, proposal_type, psedo_label_method=""
    ):
        """Turn per-image predictions into pseudo-labels.

        Args:
            proposals_rpn_unsup_k (list[Instances]): per-image predictions
                (RPN proposals or ROI-head detections, see ``proposal_type``).
            cur_threshold (float): score threshold passed to ``threshold_bbox``.
            proposal_type (str): ``"rpn"`` or ``"roih"``.
            psedo_label_method (str): only ``"thresholding"`` is supported.

        Returns:
            tuple: ``(list_instances, num_proposal_output)``.
            - ``list_instances``: the filtered per-image instances.
            - ``num_proposal_output``: the mean number of kept boxes per image
              (0.0 for an empty input).

        Raises:
            ValueError: for an unsupported ``psedo_label_method``.
        """
        list_instances = []
        num_proposal_output = 0.0
        for proposal_bbox_inst in proposals_rpn_unsup_k:
            if psedo_label_method == "thresholding":
                proposal_bbox_inst = self.threshold_bbox(
                    proposal_bbox_inst, thres=cur_threshold, proposal_type=proposal_type
                )
            else:
                raise ValueError("Unkown pseudo label boxes methods")
            num_proposal_output += len(proposal_bbox_inst)
            list_instances.append(proposal_bbox_inst)
        # Avoid ZeroDivisionError on an empty batch.
        if len(proposals_rpn_unsup_k) > 0:
            num_proposal_output = num_proposal_output / len(proposals_rpn_unsup_k)
        return list_instances, num_proposal_output

    def remove_label(self, label_data):
        """Drop the ``"instances"`` entry from each sample dict in place.

        Returns the same list.
        """
        for label_datum in label_data:
            label_datum.pop("instances", None)
        return label_data

    def add_label(self, unlabled_data, label):
        """Attach ``label[i]`` as ``"instances"`` of ``unlabled_data[i]`` in place.

        Returns the same list.
        """
        for unlabel_datum, lab_inst in zip(unlabled_data, label):
            unlabel_datum["instances"] = lab_inst
        return unlabled_data

    # =====================================================
    # =================== Training Flow ===================
    # =====================================================

    def run_step_full_semisup(self):
        """Do one burn-in (supervised) or semi-supervised optimization step."""
        self._trainer.iter = self.iter
        assert self.model.training, "[UBTeacherTrainer] model was changed to eval mode!"
        start = time.perf_counter()
        data = next(self._trainer._data_loader_iter)
        # data_q and data_k come from different augmentations (q: strong, k: weak):
        # label_strong, label_weak, unlabeled_strong, unlabeled_weak
        label_data_q, label_data_k, unlabel_data_q, unlabel_data_k = data
        data_time = time.perf_counter() - start

        # remove unlabeled data labels
        unlabel_data_q = self.remove_label(unlabel_data_q)
        unlabel_data_k = self.remove_label(unlabel_data_k)

        if self.iter < self.cfg.SEMISUPNET.BURN_UP_STEP:
            # Burn-in stage: supervised training on strong + weak labeled data.
            label_data_q.extend(label_data_k)
            record_dict, _, _, _ = self.model(label_data_q, branch="supervised")

            loss_dict = {}
            for key in record_dict.keys():
                if key[:4] == "loss":
                    loss_dict[key] = record_dict[key] * 1
            losses = sum(loss_dict.values())
        else:
            if self.iter == self.cfg.SEMISUPNET.BURN_UP_STEP:
                # keep_rate=0 copies the whole student into the teacher
                self._update_teacher_model(keep_rate=0.00)
            elif (
                self.iter - self.cfg.SEMISUPNET.BURN_UP_STEP
            ) % self.cfg.SEMISUPNET.TEACHER_UPDATE_ITER == 0:
                self._update_teacher_model(keep_rate=self.cfg.SEMISUPNET.EMA_KEEP_RATE)

            record_dict = {}
            # Generate pseudo-labels with the teacher. It is deliberately not
            # switched to eval mode: no gradient is computed for it here, and
            # its batch-norm layers are not updated either.
            with torch.no_grad():
                (
                    _,
                    proposals_rpn_unsup_k,
                    proposals_roih_unsup_k,
                    _,
                ) = self.model_teacher(unlabel_data_k, branch="unsup_data_weak")

            cur_threshold = self.cfg.SEMISUPNET.BBOX_THRESHOLD

            joint_proposal_dict = {}
            joint_proposal_dict["proposals_rpn"] = proposals_rpn_unsup_k
            pesudo_proposals_rpn_unsup_k, _ = self.process_pseudo_label(
                proposals_rpn_unsup_k, cur_threshold, "rpn", "thresholding"
            )
            joint_proposal_dict["proposals_pseudo_rpn"] = pesudo_proposals_rpn_unsup_k
            # Pseudo-labeling for the ROI head (bbox location/objectness)
            pesudo_proposals_roih_unsup_k, _ = self.process_pseudo_label(
                proposals_roih_unsup_k, cur_threshold, "roih", "thresholding"
            )
            joint_proposal_dict["proposals_pseudo_roih"] = pesudo_proposals_roih_unsup_k

            # add pseudo-labels (from ROI-head predictions) to unlabeled data
            unlabel_data_q = self.add_label(
                unlabel_data_q, joint_proposal_dict["proposals_pseudo_roih"]
            )
            unlabel_data_k = self.add_label(
                unlabel_data_k, joint_proposal_dict["proposals_pseudo_roih"]
            )

            all_label_data = label_data_q + label_data_k
            all_unlabel_data = unlabel_data_q

            record_all_label_data, _, _, _ = self.model(
                all_label_data, branch="supervised"
            )
            record_dict.update(record_all_label_data)

            # Losses on pseudo-labeled data are recorded with a "_pseudo" suffix.
            record_all_unlabel_data, _, _, _ = self.model(
                all_unlabel_data, branch="supervised"
            )
            new_record_all_unlabel_data = {}
            for key in record_all_unlabel_data.keys():
                new_record_all_unlabel_data[key + "_pseudo"] = record_all_unlabel_data[key]
            record_dict.update(new_record_all_unlabel_data)

            # weight losses
            loss_dict = {}
            for key in record_dict.keys():
                if key[:4] == "loss":
                    if key == "loss_rpn_loc_pseudo" or key == "loss_box_reg_pseudo":
                        # pseudo bbox regression losses are weighted 0
                        loss_dict[key] = record_dict[key] * 0
                    elif key[-6:] == "pseudo":  # unsupervised loss
                        loss_dict[key] = (
                            record_dict[key] * self.cfg.SEMISUPNET.UNSUP_LOSS_WEIGHT
                        )
                    else:  # supervised loss
                        loss_dict[key] = record_dict[key] * 1
            losses = sum(loss_dict.values())

        metrics_dict = record_dict
        metrics_dict["data_time"] = data_time
        self._write_metrics(metrics_dict)

        self.optimizer.zero_grad()
        losses.backward()
        self.optimizer.step()

    def _write_metrics(self, metrics_dict: dict):
        """Gather scalar metrics from all workers and log them on the main process.

        Args:
            metrics_dict (dict): metric name -> tensor or number.
        """
        metrics_dict = {
            k: v.detach().cpu().item() if isinstance(v, torch.Tensor) else float(v)
            for k, v in metrics_dict.items()
        }
        # Gather metrics among all workers for logging. This assumes
        # DDP-style training.
        all_metrics_dict = comm.gather(metrics_dict)

        if comm.is_main_process():
            if "data_time" in all_metrics_dict[0]:
                # data_time among workers can have high variance; the actual
                # latency is the maximum among workers.
                data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
                self.storage.put_scalar("data_time", data_time)

            # average the remaining metrics
            metrics_dict = {
                k: np.mean([x[k] for x in all_metrics_dict])
                for k in all_metrics_dict[0].keys()
            }

            loss_dict = {}
            for key in metrics_dict.keys():
                if key[:4] == "loss":
                    loss_dict[key] = metrics_dict[key]

            total_losses_reduced = sum(loss for loss in loss_dict.values())
            self.storage.put_scalar("total_loss", total_losses_reduced)

            if len(metrics_dict) > 1:
                self.storage.put_scalars(**metrics_dict)

    @torch.no_grad()
    def _update_teacher_model(self, keep_rate=0.996):
        """EMA-update the teacher from the student.

        Each entry becomes ``student * (1 - keep_rate) + teacher * keep_rate``,
        so ``keep_rate=0`` copies the student outright.

        Raises:
            Exception: if a teacher state_dict key has no student counterpart.
        """
        if comm.get_world_size() > 1:
            # Drop the 7-character "module." prefix that DDP adds to state_dict keys.
            student_model_dict = {
                key[7:]: value for key, value in self.model.state_dict().items()
            }
        else:
            student_model_dict = self.model.state_dict()

        new_teacher_dict = OrderedDict()
        for key, value in self.model_teacher.state_dict().items():
            if key in student_model_dict:
                new_teacher_dict[key] = (
                    student_model_dict[key] * (1 - keep_rate) + value * keep_rate
                )
            else:
                raise Exception("{} is not found in student model".format(key))

        self.model_teacher.load_state_dict(new_teacher_dict)

    @torch.no_grad()
    def _copy_main_model(self):
        """Load the student's full state into the teacher."""
        if comm.get_world_size() > 1:
            # Strip the "module." prefix added by the DDP wrapper.
            rename_model_dict = {
                key[7:]: value for key, value in self.model.state_dict().items()
            }
            self.model_teacher.load_state_dict(rename_model_dict)
        else:
            self.model_teacher.load_state_dict(self.model.state_dict())

    @classmethod
    def build_test_loader(cls, cfg, dataset_name):
        """Build the test loader for ``dataset_name``."""
        return build_detection_test_loader(cfg, dataset_name)

    def build_hooks(self):
        """Assemble the training hooks.

        The hooks are: timing, LR scheduling, optional PreciseBN, a periodic
        checkpointer (main process only), student and teacher evaluation,
        and a periodic writer (main process only).

        Returns:
            list[HookBase]: the hooks; the PreciseBN slot is ``None`` when disabled.
        """
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(self.optimizer, self.scheduler),
            hooks.PreciseBN(
                # Run at the same freq as (but before) evaluation.
                cfg.TEST.EVAL_PERIOD,
                self.model,
                # Build a new data loader to not affect training
                self.build_train_loader(cfg),
                cfg.TEST.PRECISE_BN.NUM_ITER,
            )
            if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
            else None,
        ]

        # Do PreciseBN before the checkpointer, because it updates the model
        # and the result needs to be saved by the checkpointer.
        # This is not always the best: if checkpointing has a different
        # frequency, some checkpoints may have more precise statistics than others.
        if comm.is_main_process():
            ret.append(
                hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)
            )

        def test_and_save_results_student():
            self._last_eval_results_student = self.test(self.cfg, self.model)
            _last_eval_results_student = {
                k + "_student": self._last_eval_results_student[k]
                for k in self._last_eval_results_student.keys()
            }
            return _last_eval_results_student

        def test_and_save_results_teacher():
            self._last_eval_results_teacher = self.test(self.cfg, self.model_teacher)
            # Also record these as the trainer's final results so that
            # ``train()`` can verify and return them.
            self._last_eval_results = self._last_eval_results_teacher
            return self._last_eval_results_teacher

        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results_student))
        ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results_teacher))

        if comm.is_main_process():
            # run writers in the end, so that evaluation metrics are written
            ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
        return ret