engine.py
# ------------------------------------------------------------------------
# QAHOI
# Copyright (c) 2021 Junwen Chen. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
# Copyright (c) Hitachi, Ltd. All Rights Reserved.
# ------------------------------------------------------------------------
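"""Training and evaluation engine for QAHOI.

Defines `train_one_epoch`, which runs a single training pass over a data
loader, and `evaluate_hoi`, which evaluates HOI detection on HICO-DET or
V-COCO with the dataset-specific evaluators.
"""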
import math
import os
import sys
from typing import Iterable
import numpy as np
import copy
import itertools
import torch
import util.misc as utils
from datasets.hico_eval import HICOEvaluator
from datasets.vcoco_eval import VCOCOEvaluator
from loguru import logger


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    if hasattr(criterion, 'loss_labels'):
        metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    else:
        metric_logger.add_meter('obj_class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        # torch.cuda.empty_cache()
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        # The total loss is the weighted sum of every loss term that has an entry in weight_dict.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            # A non-finite loss skips the current batch instead of aborting the run.
            print("Loss is {}, skipping this batch".format(loss_value))
            print(loss_dict_reduced)
            optimizer.zero_grad()
            continue
            # sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        if hasattr(criterion, 'loss_labels'):
            metric_logger.update(class_error=loss_dict_reduced['class_error'])
        else:
            metric_logger.update(obj_class_error=loss_dict_reduced['obj_class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    if utils.get_rank() == 0:
        logger.info("\nAveraged stats: {}".format(metric_logger))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
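

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): how a training driver would
# typically call train_one_epoch. The names `model`, `criterion`,
# `data_loader_train`, `optimizer`, `lr_scheduler`, and `args.clip_max_norm`
# mirror a DETR-style main script that is not shown here, so they are
# assumptions rather than the project's actual entry point.
#
#   for epoch in range(args.start_epoch, args.epochs):
#       train_stats = train_one_epoch(model, criterion, data_loader_train,
#                                     optimizer, device, epoch,
#                                     max_norm=args.clip_max_norm)
#       lr_scheduler.step()
# ---------------------------------------------------------------------------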


@torch.no_grad()
def evaluate_hoi(dataset_file, model, postprocessors, data_loader, subject_category_id, device, out_dir, epoch, args):
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    preds = []
    gts = []
    indices = []
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)

        outputs = model(samples)
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['hoi'](outputs, orig_target_sizes)

        preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results))))
        # The targets are deep-copied before gathering to avoid a runtime error.
        gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets)))))

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()

    # With distributed evaluation the same image can be gathered more than once,
    # so keep only the first occurrence of each image id.
    img_ids = [img_gts['id'] for img_gts in gts]
    _, indices = np.unique(img_ids, return_index=True)
    preds = [img_preds for i, img_preds in enumerate(preds) if i in indices]
    gts = [img_gts for i, img_gts in enumerate(gts) if i in indices]

    if dataset_file == 'hico':
        evaluator = HICOEvaluator(preds, gts, args.hoi_path, out_dir, epoch,
                                  use_nms=args.use_nms, nms_thresh=args.nms_thresh)
        rank = utils.get_rank()
        stats = evaluator.evaluation_default()
        if rank == 0:
            logger.info('\n--------------------\ndefault mAP: {}\ndefault mAP rare: {}\ndefault mAP non-rare: {}\n--------------------'.format(
                stats['mAP_def'], stats['mAP_def_rare'], stats['mAP_def_non_rare']))
        stats_ko = evaluator.evaluation_ko()
        if rank == 0:
            logger.info('\n--------------------\nko mAP: {}\nko mAP rare: {}\nko mAP non-rare: {}\n--------------------'.format(
                stats_ko['mAP_ko'], stats_ko['mAP_ko_rare'], stats_ko['mAP_ko_non_rare']))
        stats.update(stats_ko)
        if args.eval_extra:
            evaluator.evaluation_extra()
    elif dataset_file == 'vcoco':
        evaluator = VCOCOEvaluator(preds, gts, subject_category_id, data_loader.dataset.correct_mat,
                                   use_nms=args.use_nms, nms_thresh=args.nms_thresh)
        stats = evaluator.evaluate()

    return stats
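

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): how an evaluation driver would
# typically call evaluate_hoi after building the model, postprocessors, and
# validation data loader. The names `postprocessors`, `data_loader_val`, and
# the `args.*` fields mirror a DETR-style main script that is not shown here,
# so they are assumptions.
#
#   test_stats = evaluate_hoi(args.dataset_file, model, postprocessors,
#                             data_loader_val, args.subject_category_id,
#                             device, args.output_dir, epoch, args)
#   if utils.get_rank() == 0:
#       logger.info(test_stats)
# ---------------------------------------------------------------------------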