-
Notifications
You must be signed in to change notification settings - Fork 6
/
engine_mt.py
322 lines (264 loc) · 12.6 KB
/
engine_mt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import math
import sys
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy
import util.misc as misc
import util.lr_sched as lr_sched
import torch.nn.functional as F
from util.metric import *
def get_loss(outputs, targets, task, criterion=None):
    """Compute the training loss for a single task.

    The branch is selected by substring match on ``task``, checked in this
    order: 'class', 'segment_semantic', 'normal', then any of 'depth',
    'keypoint', 'reshading', 'edge', 'segment' (L1 loss), and finally
    everything else (L2 loss, e.g. curvature).

    Args:
        outputs: prediction for ``task``. Outside the 'class' and
            'segment_semantic' branches the last dimension is treated as the
            channel axis: it is dropped when its size is 1 and moved to
            position 1 with ``permute(0, 3, 1, 2)`` otherwise. The outputs are
            therefore presumably (B, H, W, C) — confirm against the model heads.
        targets: ground truth for ``task``. For 'normal' it is permuted with
            ``permute(0, 2, 3, 1)`` to line up with ``outputs``.
        task: task name.
        criterion: loss callable used for 'segment_semantic' tasks. Defaults to
            ``torch.nn.CrossEntropyLoss()``. The original code looked up a
            ``criterion`` name that is not defined anywhere in this file.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: for an L1-loss task whose last dimension is neither 1 nor 3
            (previously an UnboundLocalError).
    """
    if 'class' in task:
        return F.mse_loss(outputs, targets.squeeze(1))

    if 'segment_semantic' in task:
        if criterion is None:
            criterion = torch.nn.CrossEntropyLoss()
        return criterion(outputs, targets)

    if 'normal' in task:
        # 1 - cosine similarity along the channel axis; eps guards zero norms.
        eps = 0.000001
        target = targets.permute(0, 2, 3, 1)
        cosine = ((outputs * target).sum(-1)
                  / (torch.norm(outputs, p=2, dim=-1) + eps)
                  / (torch.norm(target, p=2, dim=-1) + eps))
        return (1 - cosine).mean()

    channels = outputs.shape[-1]

    if any(name in task for name in ('depth', 'keypoint', 'reshading', 'edge', 'segment')):
        if channels == 1:
            out = outputs.view(outputs.shape[:-1])
        elif channels == 3:
            out = outputs.permute(0, 3, 1, 2)
        else:
            raise ValueError(
                "get_loss: task {!r} expects outputs with 1 or 3 channels in the "
                "last dimension, got shape {}".format(task, tuple(outputs.shape)))
        return F.l1_loss(out, targets)

    # L2 loss for the remaining tasks (e.g. curvature).
    if channels == 1:
        out = outputs.view(outputs.shape[:-1])
    else:
        out = outputs.permute(0, 3, 1, 2)
    return F.mse_loss(out, targets)
def get_metric(outputs, targets, task):
    """Return a Python float evaluation metric for one task's prediction.

    - 'class' tasks: fraction of rows where the argmax of ``outputs`` matches
      the argmax of ``targets`` along the last dimension.
    - 'depth' tasks: ``compute_depth_errors(...)`` for 'depth_euclidean' only.
      That helper is not defined in this file and presumably comes from the
      ``util.metric`` wildcard import. All other depth tasks report 0.0.
    - 'curvature' tasks: mean squared error.
    - anything else: mean absolute error.

    Outside 'class', a last dimension of size 1 is dropped. For curvature a
    larger last dimension is moved to position 1; for the L1 fallback only a
    size-3 last dimension is moved.
    """
    if 'class' in task:
        hits = outputs.argmax(dim=-1) == targets.argmax(dim=-1)
        return hits.float().mean().item()

    channels = outputs.shape[-1]
    squeezed = channels == 1

    if 'depth' in task:
        if squeezed:
            outputs = outputs.view(outputs.shape[:-1])  # B, H, W
        if task != 'depth_euclidean':
            return 0.0
        return compute_depth_errors(outputs, targets).item()

    if 'curvature' in task:
        if squeezed:
            outputs = outputs.view(outputs.shape[:-1])
        elif channels > 1:
            outputs = outputs.permute(0, 3, 1, 2)
        return F.mse_loss(outputs, targets).item()

    if squeezed:
        outputs = outputs.view(outputs.shape[:-1])
    elif channels == 3:
        outputs = outputs.permute(0, 3, 1, 2)
    return F.l1_loss(outputs, targets).item()
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None, AWL=None,
                    args=None):
    """Run one training epoch of the multi-task model.

    Steps for each batch:

    1. Adjust the learning rate at the start of each accumulation window.
    2. Run the model, either once for all tasks (``model.module.taskGating``)
       or once per task.
    3. Compute per-task losses and metrics with ``get_loss`` / ``get_metric``.
    4. Combine the task losses with ``AWL`` and, when ``args.tasks > 2``, add
       the auxiliary ``z_loss``.
    5. Back-propagate through ``loss_scaler`` with gradient accumulation over
       ``args.accum_iter`` steps.

    Training stops with ``sys.exit(1)`` when a task loss, the auxiliary loss
    (checked only when ``args.tasks > 2``) or the combined loss is non-finite.

    Args:
        model: multi-task network. It is accessed through ``model.module``, so
            it is presumably wrapped (e.g. DistributedDataParallel) — confirm.
        criterion: not used by this function.
        data_loader: yields dicts keyed by task name; ``'rgb'`` is the input.
        optimizer: optimizer whose param-group learning rates are scheduled.
        device: device that inputs and targets are moved to.
        epoch: current epoch index (LR schedule and log x-axis).
        loss_scaler: callable doing backward and, on the last step of each
            accumulation window, the parameter update.
        max_norm: gradient-clipping value forwarded to ``loss_scaler``.
        mixup_fn: not used by this function.
        log_writer: optional writer with ``add_scalar`` / ``log_dir``.
        AWL: module combining the list of task losses. It must not be None,
            because ``AWL.train`` is called unconditionally.
        args: namespace with ``accum_iter``, ``cycle``, ``visualize``,
            ``visualizeimg``, ``exp_name``, ``img_types`` and ``tasks``.

    Returns:
        dict mapping each meter name to its global average over the epoch.
    """
    model.train(True)
    AWL.train(True)
    metric_logger = misc.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, data in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            progress = data_iter_step / len(data_loader) + epoch
            if args.cycle:
                lr_sched.adjust_cycle_learning_rate(optimizer, progress, args)
            else:
                lr_sched.adjust_learning_rate(optimizer, progress, args)

        samples = data['rgb'].to(device, non_blocking=True)
        z_loss = 0          # auxiliary loss returned by the model, accumulated per forward pass
        the_loss = {}       # task -> clamped loss tensor
        loss_list = []      # same tensors, in task order, for AWL
        tot_loss = 0        # plain sum of clamped task-loss values (logging only)
        the_metric = {}     # task -> float metric

        if args.visualize and misc.is_main_process():
            if data_iter_step > 0 and data_iter_step % 20 == 0:
                model.module.visualize(vis_head=True, vis_mlp=False, model_name=args.exp_name)

        with torch.cuda.amp.autocast():
            predict = {}
            if model.module.taskGating:
                # A single forward pass returns a dict of predictions keyed by task.
                outputs, aux_loss = model(samples, None)
                z_loss = z_loss + aux_loss
                for task in args.img_types:
                    if 'rgb' in task:
                        continue
                    predict[task] = outputs[task].detach().cpu()
                    targets = data[task].to(device, non_blocking=True)
                    task_loss = get_loss(outputs[task], targets, task)
                    if not math.isfinite(task_loss.item()):
                        print("Loss is {}, stopping training".format(task_loss.item()))
                        sys.exit(1)
                    task_loss = torch.clamp(task_loss, min=-1000, max=1000)
                    tot_loss = tot_loss + task_loss.item()
                    the_loss[task] = task_loss
                    loss_list.append(task_loss)
                    the_metric[task] = get_metric(outputs[task], targets, task)
            else:
                # One forward pass per task.
                for task in args.img_types:
                    if 'rgb' in task:
                        continue
                    outputs, aux_loss = model(samples, task)
                    z_loss = z_loss + aux_loss
                    predict[task] = outputs.detach().cpu()
                    targets = data[task].to(device, non_blocking=True)
                    task_loss = get_loss(outputs, targets, task)
                    # Check BEFORE clamping, as the taskGating branch does:
                    # clamping first turns +/-inf into +/-1000, so the guard
                    # would only ever catch NaN.
                    task_loss_value = task_loss.item()
                    if not math.isfinite(task_loss_value):
                        print("Task is {} Loss is {}, stopping training".format(task, task_loss_value))
                        sys.exit(1)
                    task_loss = torch.clamp(task_loss, min=-1000, max=1000)
                    tot_loss = tot_loss + task_loss.item()
                    the_loss[task] = task_loss
                    loss_list.append(task_loss)
                    the_metric[task] = get_metric(outputs, targets, task)
            if args.visualizeimg:
                image_visualize(args, data, predict)
            loss = AWL(loss_list)

        # Value of the combined task loss BEFORE z_loss is added (used for
        # the finiteness check and for logging).
        loss_value = loss.item()

        if args.tasks > 2:
            if torch.is_tensor(z_loss):
                if not math.isfinite(z_loss.item()):
                    print("ZLoss is {}, stopping training".format(z_loss.item()))
                    sys.exit(1)
                loss = loss + z_loss
        else:
            # NOTE(review): here the scaled z_loss is only logged, never added
            # to ``loss``. The flattened source also leaves it ambiguous whether
            # this ``else`` pairs with ``args.tasks > 2`` or with
            # ``torch.is_tensor(z_loss)``. Either reading affects logging only;
            # confirm against the upstream file.
            z_loss = z_loss * 0.00001

        z_loss_value = z_loss.item() if torch.is_tensor(z_loss) else z_loss
        the_loss_value = {key: value.item() for key, value in the_loss.items()}

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss = torch.clamp(loss, min=-1000, max=1000)
        loss /= accum_iter
        is_update_step = (data_iter_step + 1) % accum_iter == 0
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()
            # model.module.init_aux_statistics()

        torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        if model.module.ismoe:
            metric_logger.update(zloss=z_loss_value)
        for key, value in the_loss_value.items():
            metric_logger.meters[key].update(value)
        for key, value in the_metric.items():
            metric_logger.meters['met_' + key].update(value)

        max_lr = 0.
        for group in optimizer.param_groups:
            max_lr = max(max_lr, group["lr"])
        metric_logger.update(lr=max_lr)

        # Average the scalars across processes for logging.
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        tot_loss_reduce = misc.all_reduce_mean(tot_loss)
        if torch.is_tensor(z_loss):
            z_loss_value_reduce = misc.all_reduce_mean(z_loss_value)
        else:
            z_loss_value_reduce = 0
        the_loss_value_reduce = {key: misc.all_reduce_mean(value) for key, value in the_loss_value.items()}
        the_metric_reduce = {key: misc.all_reduce_mean(value) for key, value in the_metric.items()}

        if log_writer is not None and is_update_step:
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('z_loss', z_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('tot_loss', tot_loss_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)
            for key, value in the_loss_value_reduce.items():
                log_writer.add_scalar('multitask/' + key, value, epoch_1000x)
            for key, value in the_metric_reduce.items():
                log_writer.add_scalar('multitask_metric/' + key, value, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    print('params: ', AWL.params)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device, AWL, args):
    """Evaluate the multi-task model on ``data_loader`` without gradients.

    The forward pass matches ``train_one_epoch``: one gated pass when
    ``model.module.taskGating`` is set, otherwise one pass per task. For every
    batch it records:

    - ``loss``: the ``AWL``-combined loss;
    - ``tot_loss``: the plain sum of task losses;
    - each task loss, under the task name;
    - each task metric, under ``met_<task>``.

    Every meter is updated with ``n=batch_size``, so each global average is a
    per-sample average. Previously ``loss`` and ``tot_loss`` were averaged per
    batch while the per-task meters were averaged per sample, and the two
    disagree when the last batch is smaller.

    Args:
        data_loader: yields dicts keyed by task name; ``'rgb'`` is the input.
        model: multi-task network, accessed through ``model.module``
            (presumably a DDP wrapper — confirm).
        device: device that inputs and targets are moved to.
        AWL: module combining the list of task losses.
        args: namespace providing ``img_types``.

    Returns:
        dict mapping each meter name to its global average, after the stats
        have been synchronized across processes.
    """
    metric_logger = misc.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()
    AWL.eval()

    for data in metric_logger.log_every(data_loader, 10, header):
        samples = data['rgb'].to(device, non_blocking=True)
        batch_size = samples.shape[0]
        the_loss = {}
        loss_list = []
        the_metric = {}
        tot_loss = 0

        with torch.cuda.amp.autocast():
            if model.module.taskGating:
                # A single forward pass returns a dict of predictions keyed by task.
                outputs, _ = model(samples, None)
                for task in args.img_types:
                    if 'rgb' in task:
                        continue
                    targets = data[task].to(device, non_blocking=True)
                    task_loss = get_loss(outputs[task], targets, task)
                    tot_loss = tot_loss + task_loss.item()
                    the_loss[task] = task_loss
                    loss_list.append(task_loss)
                    the_metric[task] = get_metric(outputs[task], targets, task)
            else:
                # One forward pass per task.
                for task in args.img_types:
                    if 'rgb' in task:
                        continue
                    outputs, _ = model(samples, task)
                    targets = data[task].to(device, non_blocking=True)
                    task_loss = get_loss(outputs, targets, task)
                    tot_loss = tot_loss + task_loss.item()
                    the_loss[task] = task_loss
                    loss_list.append(task_loss)
                    the_metric[task] = get_metric(outputs, targets, task)
            loss = AWL(loss_list)

        # Weight every meter by the batch size so all global averages are per-sample.
        metric_logger.meters['loss'].update(loss.item(), n=batch_size)
        metric_logger.meters['tot_loss'].update(tot_loss, n=batch_size)
        for key, value in the_loss.items():
            metric_logger.meters[key].update(value.item(), n=batch_size)
        for key, value in the_metric.items():
            metric_logger.meters['met_' + key].update(value, n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('test Result: ', ' '.join(str(a) + ':' + str(b.global_avg) for (a, b) in metric_logger.meters.items()))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}