# losses.py
import torch
import torch.nn as nn
import torch.nn.functional as torch_f
from utils.perceptual_loss import PerceptualLoss
import utils.pytorch_ssim as pytorch_ssim
from utils.losses_util import bone_direction_loss, tsa_pose_loss, calc_laplacian_loss, edge_length_loss, iou #image_l1_loss, iou_loss, ChamferLoss,
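# This module provides two entry points that both return a dict of weighted loss
# terms keyed by loss name: the deprecated functional `loss_func` below and the
# `LossFunction` class, used via its __call__ method.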
# !deprecated: kept for reference; superseded by the LossFunction class below.
def loss_func(examples, outputs, loss_used, dat_name, args) -> dict:
    loss_dic = {}
    device = examples['imgs'].device
    # heatmap integral loss: estimated 2d joints -> openpose 2d joints
    if 'hm_integral' in loss_used and ('open_2dj' in examples) and ('open_2dj_con' in examples) and ('hm_j2d_list' in outputs):
        hm_j2d_list = outputs['hm_j2d_list']
        hm_integral_loss = torch.zeros(1).to(device)
        for hm_j2d in hm_j2d_list:
            hm_2dj_distance = torch.sqrt(torch.sum((examples['open_2dj']-hm_j2d)**2,2))#[b,21]
            open_2dj_con_hm = examples['open_2dj_con'].squeeze(2)
            hm_integral_loss += (torch.sum(hm_2dj_distance.mul(open_2dj_con_hm**2))/torch.sum((open_2dj_con_hm**2)))
        loss_dic['hm_integral'] = args.lambda_hm * hm_integral_loss
    else:
        loss_dic['hm_integral'] = torch.zeros(1)
    # (used in full supervision) estimated 2d joints -> gt 2d joints
    if 'hm_integral_gt' in loss_used and ('j2d_gt' in examples) and ('hm_j2d_list' in outputs):
        hm_j2d_list = outputs['hm_j2d_list']
        hm_integral_loss = torch.zeros(1).to(device)
        for hm_j2d in hm_j2d_list:
            hm_2dj_distance0 = torch.sqrt(torch.sum((examples['j2d_gt']-hm_j2d)**2,2))#[b,21]
            open_2dj_con_hm0 = torch.ones_like(hm_2dj_distance0) # set confidence as 1
            hm_integral_loss += torch.sum(hm_2dj_distance0.mul(open_2dj_con_hm0**2))/torch.sum((open_2dj_con_hm0**2))
        loss_dic['hm_integral_gt'] = args.lambda_hm * hm_integral_loss
    else:
        loss_dic['hm_integral_gt'] = torch.zeros(1)
    # (used in full supervision) 2d joint loss: projected 2d joints -> gt 2d joints
    if 'j2d_gt' in examples and ('j2d' in outputs):
        joint_2d_loss = torch_f.mse_loss(examples['j2d_gt'], outputs['j2d'])
        joint_2d_loss = args.lambda_j2d_gt * joint_2d_loss
        loss_dic['joint_2d'] = joint_2d_loss
    else:
        loss_dic['joint_2d'] = torch.zeros(1)
    # openpose 2d joint loss: projected 2d joints -> openpose 2d joints
    if 'open_2dj' in loss_used and ('open_2dj' in examples) and ('open_2dj_con' in examples) and ('j2d' in outputs):
        open_2dj_distance = torch.sqrt(torch.sum((examples['open_2dj']-outputs['j2d'])**2,2))
        # truncated penalty: quadratic below 5 px, linear above
        open_2dj_distance = torch.where(open_2dj_distance<5, open_2dj_distance**2/10,open_2dj_distance-2.5)
        keypoint_weights = torch.tensor([[2,1,1,1,1.5,1,1,1,1.5,1,1,1,1.5,1,1,1,1.5,1,1,1,1.5]]).to(device).float()
        open_2dj_con0 = examples['open_2dj_con'].squeeze(2)
        open_2dj_con0 = open_2dj_con0.mul(keypoint_weights)
        open_2dj_loss = (torch.sum(open_2dj_distance.mul(open_2dj_con0**2))/torch.sum((open_2dj_con0**2)))
        open_2dj_loss = args.lambda_j2d * open_2dj_loss
        loss_dic['open_2dj'] = open_2dj_loss
    else:
        loss_dic['open_2dj'] = torch.zeros(1)
    # openpose 2d joint loss --- simplified (plain MSE) version
    if "open_2dj_de" in loss_used and ('open_2dj' in examples) and ('j2d' in outputs):
        open_2dj_loss = torch_f.mse_loss(examples['open_2dj'],outputs['j2d'])
        open_2dj_loss = args.lambda_j2d_de * open_2dj_loss
        loss_dic["open_2dj_de"] = open_2dj_loss
    else:
        loss_dic["open_2dj_de"] = torch.zeros(1)
    # (used in full supervision) 3D joint loss & bone scale loss: 3d joints -> gt 3d joints
    if 'joints' in outputs and 'joints' in examples:
        joint_3d_loss = torch_f.mse_loss(outputs['joints'], examples['joints'])
        joint_3d_loss = args.lambda_j3d * joint_3d_loss
        loss_dic["joint_3d"] = joint_3d_loss
        # root-relative version (centered on joint 9)
        joint_3d_loss_norm = torch_f.mse_loss((outputs['joints']-outputs['joints'][:,9].unsqueeze(1)),(examples['joints']-examples['joints'][:,9].unsqueeze(1)))
        joint_3d_loss_norm = args.lambda_j3d_norm * joint_3d_loss_norm
        loss_dic["joint_3d_norm"] = joint_3d_loss_norm
    else:
        loss_dic["joint_3d"] = torch.zeros(1)
        loss_dic["joint_3d_norm"] = torch.zeros(1)
    # bone direction loss: 2d bones -> openpose bones
    if 'open_bone_direc' in loss_used and ('open_2dj' in examples) and ('open_2dj_con' in examples) and ('j2d' in outputs):
        open_bone_direc_loss = bone_direction_loss(outputs['j2d'], examples['open_2dj'], examples['open_2dj_con'])
        open_bone_direc_loss = args.lambda_bone_direc * open_bone_direc_loss
        loss_dic['open_bone_direc'] = open_bone_direc_loss
    else:
        loss_dic['open_bone_direc'] = torch.zeros(1)
    # (used in full supervision) projected 2d bones -> gt 2d bones
    if 'bone_direc' in loss_used and ('j2d_gt' in examples) and ('j2d' in outputs):
        j2d_con = torch.ones_like(examples['j2d_gt'][:,:,0]).unsqueeze(-1)
        bone_direc_loss = bone_direction_loss(outputs['j2d'], examples['j2d_gt'], j2d_con)
        bone_direc_loss = args.lambda_bone_direc * bone_direc_loss
        loss_dic['bone_direc'] = bone_direc_loss
    else:
        loss_dic['bone_direc'] = torch.zeros(1)
    # 2d-3d keypoint consistency loss: projected 2d joints -> estimated 2d joints
    if ('hm_j2d_list' in outputs) and ('j2d' in outputs):
        hm_j2d_list = outputs['hm_j2d_list']
        kp_cons_distance = torch.sqrt(torch.sum((hm_j2d_list[-1]-outputs['j2d'])**2,2))
        kp_cons_distance = torch.where(kp_cons_distance<5, kp_cons_distance**2/10,kp_cons_distance-2.5)
        kp_cons_loss = torch.mean(kp_cons_distance)
        kp_cons_loss = args.lambda_kp_cons * kp_cons_loss
        loss_dic['kp_cons'] = kp_cons_loss
    else:
        loss_dic['kp_cons'] = torch.zeros(1)
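    # The keypoint terms above share one pattern: per-joint distances are truncated
    # (quadratic below 5 px, linear above) and then averaged with confidence weights
    # as sum(dist * conf**2) / sum(conf**2), so low-confidence OpenPose detections
    # contribute less to the loss.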
    # mean scale regularization term
    if 'mscale' in loss_used and ('joints' in outputs):# and "joints" not in examples:
        out_bone_length = torch.sqrt(torch.sum((outputs['joints'][:,9, :] - outputs['joints'][:,10, :])**2,1))#check
        crit = nn.L1Loss()
        mscale_loss = crit(out_bone_length,torch.ones_like(out_bone_length)*0.0282)#check
        mscale_loss = args.lambda_mscale * mscale_loss
        loss_dic['mscale'] = mscale_loss
    else:
        loss_dic['mscale'] = torch.zeros(1)
    # (used in full supervision) GT scale loss
    if 'scale' in loss_used and ('joints' in outputs) and 'scales' in examples:
        if dat_name == 'FreiHand':
            cal_scale = torch.sqrt(torch.sum((outputs['joints'][:,9]-outputs['joints'][:,10])**2,1))
            scale_loss = torch_f.mse_loss(cal_scale, examples['scales'].to(device))
            scale_loss = args.lambda_scale * scale_loss
            loss_dic['scale'] = scale_loss
    else:
        loss_dic['scale'] = torch.zeros(1)
    # MANO pose regularization terms
    if 'tsa_poses' in outputs:
        pose_loss = tsa_pose_loss(outputs['tsa_poses'])
        pose_loss = args.lambda_pose * pose_loss
        loss_dic['tsa_poses'] = pose_loss
    else:
        loss_dic['tsa_poses'] = torch.zeros(1)
    # mesh texture regularization term: penalize texel values more than 2 std away from the per-image mean
    if 'mtex' in loss_used and ('textures' in outputs) and ('texture_con' in examples):
        textures = outputs['textures']
        std = torch.std(textures.view(textures.shape[0],-1,3),dim=1)#[b,3]
        mean = torch.mean(textures.view(textures.shape[0],-1,3),dim=1)
        textures_reg = (torch.where(textures>(mean.view(-1,1,1,1,1,3)+2*std.view(-1,1,1,1,1,3)),textures-mean.view(-1,1,1,1,1,3),torch.zeros_like(textures))+torch.where(textures<(mean.view(-1,1,1,1,1,3)-2*std.view(-1,1,1,1,1,3)),-textures+mean.view(-1,1,1,1,1,3),torch.zeros_like(textures))).squeeze()
        # note: the numerator weights by texture_con*2 while the denominator uses texture_con**2
        textures_reg = torch.sum(torch.mean(torch.mean(torch.mean(textures_reg,1),1),1).mul(examples['texture_con']*2))/torch.sum(examples['texture_con']**2)
        textures_reg = args.lambda_tex_reg * textures_reg
        loss_dic['mtex'] = textures_reg
    else:
        loss_dic['mtex'] = torch.zeros(1)
    # photometric loss
    if 're_img' in outputs and ('re_sil' in outputs) and ('texture_con' in examples):
        maskRGBs = outputs['maskRGBs']#examples['imgs'].mul((outputs['re_sil']>0).float().unsqueeze(1).repeat(1,3,1,1))
        re_img = outputs['re_img']
        crit = nn.L1Loss()
        # texture loss: rendered img -> masked original img
        #texture_loss = crit(re_img, maskRGBs).cpu()
        texture_con_this = examples['texture_con'].view(-1,1,1,1).repeat(1,re_img.shape[1],re_img.shape[2],re_img.shape[3])
        texture_loss = (torch.sum(torch.abs(re_img-maskRGBs).mul(texture_con_this**2))/torch.sum((texture_con_this**2)))
        texture_loss = args.lambda_texture * texture_loss
        loss_dic['texture'] = texture_loss
        # mean rgb loss
        #loss_mean_rgb = torch_f.mse_loss(torch.mean(maskRGBs),torch.mean(re_img)).cpu()
        loss_mean_rgb = (torch.sum(torch.abs(torch.mean(re_img.view(re_img.shape[0],-1),1)-torch.mean(maskRGBs.view(maskRGBs.shape[0],-1),1)).mul(examples['texture_con']**2))/torch.sum((examples['texture_con']**2)))
        loss_mean_rgb = args.lambda_mrgb * loss_mean_rgb
        loss_dic['mrgb'] = loss_mean_rgb
        # ssim texture loss
        ssim_tex = pytorch_ssim.ssim(re_img, maskRGBs)
        loss_ssim_tex = 1 - ssim_tex
        loss_ssim_tex = args.lambda_ssim_tex * loss_ssim_tex
        loss_dic['ssim_tex'] = loss_ssim_tex
        # ssim texture-depth loss: ssim between rendered img and rendered depth. ??? is it reasonable?
        ssim_tex_depth = pytorch_ssim.ssim(re_img, outputs['re_depth'].unsqueeze(1).repeat(1,3,1,1))
        loss_ssim_tex_depth = 1 - ssim_tex_depth
        loss_ssim_tex_depth = args.lambda_ssim_tex * loss_ssim_tex_depth
        loss_dic['ssim_tex_depth'] = loss_ssim_tex_depth
        # ssim depth loss: ssim between masked original img and rendered depth. ???
        ssim_inrgb_depth = pytorch_ssim.ssim(maskRGBs, outputs['re_depth'].unsqueeze(1).repeat(1,3,1,1))
        loss_ssim_inrgb_depth = 1 - ssim_inrgb_depth
        loss_ssim_inrgb_depth = args.lambda_ssim_tex * loss_ssim_inrgb_depth
        loss_dic['ssim_inrgb_depth'] = loss_ssim_inrgb_depth
    else:
        loss_dic['texture'] = torch.zeros(1)
        loss_dic['mrgb'] = torch.zeros(1)
    # (full supervision) silhouette loss: rendered sil -> gt sil
    if 're_sil' in outputs and 'segms_gt' in examples:
        crit = nn.L1Loss()
        sil_loss = crit(outputs['re_sil'], examples['segms_gt'].float())
        loss_dic['sil'] = args.lambda_silhouette * sil_loss
    else:
        loss_dic['sil'] = torch.zeros(1)
    # perceptual loss: rendered img -> gt img. not used at all.
    # if 'perc_features' in outputs and ('texture_con' in examples):
    #     perc_features = outputs['perc_features']
    #     batch_size = perc_features[0].shape[0]
    #     loss_percep_batch = torch.mean(torch.abs(perc_features[0]-perc_features[2]),1)+torch.mean(torch.abs(perc_features[1]-perc_features[3]).reshape(batch_size,-1),1)
    #     loss_percep = torch.sum(loss_percep_batch.mul( examples['texture_con']**2))/torch.sum(( examples['texture_con']**2))
    #     loss_percep = args.lambda_percep * loss_percep
    #     loss_dic['loss_percep'] = loss_percep
    # else:
    #     loss_dic['loss_percep'] = torch.zeros(1)
    # mesh laplacian regularization term
    if 'faces' in outputs and 'vertices' in outputs:
        # triangle_loss_fn = LaplacianLoss(torch.autograd.Variable(outputs['faces'][0]).cpu(),outputs['vertices'][0])
        # why [0]???
        # triangle_loss = triangle_loss_fn(outputs['vertices'])
        triangle_loss = calc_laplacian_loss(outputs['faces'], outputs['vertices'])
        triangle_loss = args.lambda_laplacian * triangle_loss
        loss_dic['triangle'] = triangle_loss
    else:
        loss_dic['triangle'] = torch.zeros(1)
    # mean shape regularization: drive shape parameters towards 0
    if 'shape' in outputs:
        shape_loss = torch_f.mse_loss(outputs['shape'], torch.zeros_like(outputs['shape']).to(device))
        shape_loss = args.lambda_shape * shape_loss
        loss_dic['mshape'] = shape_loss
    else:
        loss_dic['mshape'] = torch.zeros(1)
    return loss_dic
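# A minimal sketch of how a training loop might reduce the returned dict to a single
# scalar for backprop. The helper name `sum_losses` is illustrative and not an
# existing API of this repo; it assumes every value is a 0-dim or 1-element tensor,
# as produced above.
def sum_losses(loss_dic, device='cpu'):
    total = torch.zeros(1).to(device)
    for term in loss_dic.values():
        total = total + term.to(device)
    return total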
class LossFunction():
    def __init__(self):
        super(LossFunction, self).__init__()
        self.perceptual_loss = PerceptualLoss()

    def __call__(self, examples, outputs, loss_used, dat_name, args) -> dict:
        loss_dic = {}
        device = examples['imgs'].device
        if args.base_loss_fn == 'L1':
            base_loss_fn = nn.L1Loss()
        elif args.base_loss_fn == 'L2':
            base_loss_fn = torch_f.mse_loss
        else:
            raise ValueError("Unsupported base_loss_fn: {}".format(args.base_loss_fn))
        # (used in full supervision) 2d joint loss: projected 2d joints -> gt 2d joints
        if 'joint_2d' in loss_used:
            assert 'j2d_gt' in examples and ('j2d' in outputs), "Using joint_2d in losses, but j2d_gt or j2d are not provided."
            joint_2d_loss = base_loss_fn(examples['j2d_gt'], outputs['j2d'])
            joint_2d_loss = args.lambda_j2d_gt * joint_2d_loss
            loss_dic['joint_2d'] = joint_2d_loss
        # (used in full supervision) 3D joint loss & bone scale loss: 3d joints -> gt 3d joints
        if 'joint_3d' in loss_used:
            assert 'joints' in outputs and 'joints' in examples, "Using joint_3d in losses, but joints or joints_gt are not provided."
            joint_3d_loss = base_loss_fn(outputs['joints'], examples['joints'])
            joint_3d_loss = args.lambda_j3d * joint_3d_loss
            loss_dic["joint_3d"] = joint_3d_loss
            # joint_3d_loss_norm = base_loss_fn((outputs['joints']-outputs['joints'][:,9].unsqueeze(1)),(examples['joints']-examples['joints'][:,9].unsqueeze(1)))
            # joint_3d_loss_norm = args.lambda_j3d_norm * joint_3d_loss_norm
            # loss_dic["joint_3d_norm"] = joint_3d_loss_norm
        # (used in full supervision) 3D vertex loss: predicted verts -> gt verts
        if 'vert_3d' in loss_used:
            assert 'mano_verts' in outputs and 'verts' in examples, "Using vert_3d in losses, but verts or verts_gt are not provided."
            vert_3d_loss = base_loss_fn(outputs['mano_verts'], examples['verts'])
            vert_3d_loss = args.lambda_vert_3d * vert_3d_loss
            loss_dic["vert_3d"] = vert_3d_loss
        # (used in full supervision) projected 2d bones -> gt 2d bones
        if 'bone_direc' in loss_used:
            assert ('j2d_gt' in examples) and ('j2d' in outputs), "Using bone_direc, but j2d_gt is not provided or j2d is not predicted."
            j2d_con = torch.ones_like(examples['j2d_gt'][:,:,0]).unsqueeze(-1)
            bone_direc_loss = bone_direction_loss(outputs['j2d'], examples['j2d_gt'], j2d_con)
            bone_direc_loss = args.lambda_bone_direc * bone_direc_loss
            loss_dic['bone_direc'] = bone_direc_loss
        # (used in full supervision) 3d bones -> gt 3d bones
        if 'bone_direc_3d' in loss_used:
            assert ('joints' in examples) and ('joints' in outputs), "Using bone_direc_3d, but joints are not provided or not predicted."
            j3d_con = torch.ones_like(examples['joints'][:,:,0]).unsqueeze(-1)
            bone_direc_loss_3d = bone_direction_loss(outputs['joints'], examples['joints'], j3d_con)
            bone_direc_loss_3d = args.lambda_bone_direc_3d * bone_direc_loss_3d
            loss_dic['bone_direc_3d'] = bone_direc_loss_3d
        # (used in full supervision) edge length loss over mesh faces: predicted verts -> gt verts
        if 'edge_length' in loss_used:
            assert ('mano_verts' in outputs) and ('verts' in examples) and ('mano_faces' in outputs), "Using edge_length, but verts or faces are not provided."
            edge_len_loss = edge_length_loss(outputs['mano_verts'], examples['verts'], outputs['mano_faces'])
            edge_len_loss = args.lambda_edge_len * edge_len_loss
            loss_dic['edge_length'] = edge_len_loss
        # mean scale regularization term: keep the joint-9 to joint-10 bone near the reference value 0.0282
        if 'mscale' in loss_used:
            assert ('joints' in outputs), "Using mscale, but joints are not predicted."
            out_bone_length = torch.sqrt(torch.sum((outputs['joints'][:,9, :] - outputs['joints'][:,10, :])**2,1))#check
            crit = nn.L1Loss()
            mscale_loss = crit(out_bone_length,torch.ones_like(out_bone_length)*0.0282)#check
            mscale_loss = args.lambda_mscale * mscale_loss
            loss_dic['mscale'] = mscale_loss
        # (used in full supervision) GT scale loss
        if 'scale' in loss_used:
            assert ('joints' in outputs) and 'scales' in examples, "Using scale as loss, but joints are not predicted or scales are not provided."
            if dat_name in ('FreiHand', 'RHD'):  # identical handling for both datasets
                cal_scale = torch.sqrt(torch.sum((outputs['joints'][:,9]-outputs['joints'][:,10])**2,1))
                scale_loss = torch_f.mse_loss(cal_scale, examples['scales'].to(device))
                scale_loss = args.lambda_scale * scale_loss
                loss_dic['scale'] = scale_loss
        # self-supervised photometric loss
        if 're_img' in outputs and ('re_sil' in outputs) and ('texture_con' in examples):
            maskRGBs = outputs['maskRGBs']#examples['imgs'].mul((outputs['re_sil']>0).float().unsqueeze(1).repeat(1,3,1,1))
            re_img = outputs['re_img']
            crit = nn.L1Loss()
            # texture loss: rendered img -> masked original img
            #texture_loss = crit(re_img, maskRGBs).cpu()
            texture_con_this = examples['texture_con'].view(-1,1,1,1).repeat(1,re_img.shape[1],re_img.shape[2],re_img.shape[3])
            texture_loss = (torch.sum(torch.abs(re_img-maskRGBs).mul(texture_con_this**2))/torch.sum((texture_con_this**2)))
            texture_loss = args.lambda_texture * texture_loss
            loss_dic['texture_self'] = texture_loss
            # mean rgb loss
            #loss_mean_rgb = torch_f.mse_loss(torch.mean(maskRGBs),torch.mean(re_img)).cpu()
            loss_mean_rgb = (torch.sum(torch.abs(torch.mean(re_img.reshape(re_img.shape[0],-1),1)-torch.mean(maskRGBs.reshape(maskRGBs.shape[0],-1),1)).mul(examples['texture_con']**2))/torch.sum((examples['texture_con']**2)))
            loss_mean_rgb = args.lambda_mrgb * loss_mean_rgb
            loss_dic['mrgb_self'] = loss_mean_rgb
            # ssim texture loss
            ssim_tex = pytorch_ssim.ssim(re_img, maskRGBs)
            loss_ssim_tex = 1 - ssim_tex
            loss_ssim_tex = args.lambda_ssim_tex * loss_ssim_tex
            loss_dic['ssim_tex_self'] = loss_ssim_tex
            # ssim texture depth loss: ssim between rendered img -- rendered depth. ??? is it reasonable?
            # ssim_tex_depth = pytorch_ssim.ssim(re_img, outputs['re_depth'].unsqueeze(1).repeat(1,3,1,1))
            # loss_ssim_tex_depth = 1 - ssim_tex_depth
            # loss_ssim_tex_depth = args.lambda_ssim_tex * loss_ssim_tex_depth
            # loss_dic['ssim_tex_depth_self'] = loss_ssim_tex_depth
            # ssim depth loss: ssim between masked original img -- rendered depth. ???
            # ssim_inrgb_depth = pytorch_ssim.ssim(maskRGBs, outputs['re_depth'].unsqueeze(1).repeat(1,3,1,1))
            # loss_ssim_inrgb_depth = 1 - ssim_inrgb_depth
            # loss_ssim_inrgb_depth = args.lambda_ssim_tex * loss_ssim_inrgb_depth
            # loss_dic['ssim_inrgb_depth'] = loss_ssim_inrgb_depth
        # photometric loss against the gt segmentation mask
        if 're_img' in outputs and ('re_sil' in outputs) and ('segms_gt' in examples):
            # maskRGBs = outputs['maskRGBs']#examples['imgs'].mul((outputs['re_sil']>0).float().unsqueeze(1).repeat(1,3,1,1))
            maskRGBs = examples['segms_gt'].unsqueeze(1) * examples['imgs'] #examples['imgs'].mul((outputs['re_sil']>0).float().unsqueeze(1).repeat(1,3,1,1))
            # re_img needs to be masked by re_sil!!!
            re_img = outputs['re_img'] * (outputs['re_sil']/255.0).repeat(1,3,1,1)
            crit = nn.L1Loss()
            # texture loss: rendered img -> masked original img
            texture_loss = crit(re_img, maskRGBs)
            # texture_loss = torch.sum(torch.abs(re_img-maskRGBs))
            texture_loss = args.lambda_texture * texture_loss
            loss_dic['texture'] = texture_loss
            # mean rgb loss
            loss_mean_rgb = torch_f.mse_loss(torch.mean(maskRGBs),torch.mean(re_img))
            # loss_mean_rgb = (torch.sum(torch.abs(torch.mean(re_img.view(re_img.shape[0],-1),1)-torch.mean(maskRGBs.view(maskRGBs.shape[0],-1),1)).mul(examples['texture_con']**2))/torch.sum((examples['texture_con']**2)))
            loss_mean_rgb = args.lambda_mrgb * loss_mean_rgb
            loss_dic['mrgb'] = loss_mean_rgb
            # ssim texture loss
            ssim_tex = pytorch_ssim.ssim(re_img, maskRGBs)
            loss_ssim_tex = 1 - ssim_tex
            loss_ssim_tex = args.lambda_ssim_tex * loss_ssim_tex
            loss_dic['ssim_tex'] = loss_ssim_tex
            # ssim texture depth loss: ssim between rendered img -- rendered depth. ??? is it reasonable?
            # ssim_tex_depth = pytorch_ssim.ssim(re_img, outputs['re_depth'].unsqueeze(1).repeat(1,3,1,1))
            # loss_ssim_tex_depth = 1 - ssim_tex_depth
            # loss_ssim_tex_depth = args.lambda_ssim_tex * loss_ssim_tex_depth
            # loss_dic['ssim_tex_depth'] = loss_ssim_tex_depth
            # ssim depth loss: ssim between masked original img -- rendered depth. ???
            # ssim_inrgb_depth = pytorch_ssim.ssim(maskRGBs, outputs['re_depth'].unsqueeze(1).repeat(1,3,1,1))
            # loss_ssim_inrgb_depth = 1 - ssim_inrgb_depth
            # loss_ssim_inrgb_depth = args.lambda_ssim_tex * loss_ssim_inrgb_depth
            # loss_dic['ssim_inrgb_depth'] = loss_ssim_inrgb_depth
        # perceptual loss: rendered hand composited over the real background vs. the original image
        if 'perceptual' in loss_used:
            loss_percep = self.perceptual_loss(outputs['re_img'] * examples['segms_gt'].unsqueeze(1) + examples['imgs'] * (1 - examples['segms_gt'].unsqueeze(1)), examples['imgs'])
            loss_percep = args.lambda_percep * loss_percep
            loss_dic['perceptual'] = loss_percep
        # (full supervision) silhouette loss: rendered sil -> gt sil
        if 'sil' in loss_used:
            assert 're_sil' in outputs and 'segms_gt' in examples, 'silhouette loss needs rendered sil and gt sil'
            crit = nn.L1Loss()
            sil_loss = crit(outputs['re_sil'], examples['segms_gt'].unsqueeze(1).float())
            loss_dic['sil'] = args.lambda_silhouette * sil_loss
        # (full supervision) IoU loss: rendered sil -> gt sil
        if 'iou' in loss_used:
            assert 're_sil' in outputs and 'segms_gt' in examples, 'iou loss needs rendered sil and gt sil'
            iou_loss = iou(outputs['re_sil'], examples['segms_gt'].unsqueeze(1).float())
            loss_dic['iou'] = args.lambda_iou * iou_loss
        # perceptual loss: rendered img -> gt img. not used at all.
        # if 'perc_features' in outputs and ('texture_con' in examples):
        #     perc_features = outputs['perc_features']
        #     batch_size = perc_features[0].shape[0]
        #     loss_percep_batch = torch.mean(torch.abs(perc_features[0]-perc_features[2]),1)+torch.mean(torch.abs(perc_features[1]-perc_features[3]).reshape(batch_size,-1),1)
        #     loss_percep = torch.sum(loss_percep_batch.mul( examples['texture_con']**2))/torch.sum(( examples['texture_con']**2))
        #     loss_percep = args.lambda_percep * loss_percep
        #     loss_dic['loss_percep'] = loss_percep
        # else:
        #     loss_dic['loss_percep'] = torch.zeros(1)
        # mesh laplacian regularization term
        if 'triangle' in loss_used:
            assert 'faces' in outputs and 'verts' in outputs, "Using triangle as loss, but faces or verts are not predicted."
            # triangle_loss_fn = LaplacianLoss(torch.autograd.Variable(outputs['faces'][0]).cpu(),outputs['vertices'][0])
            # why [0]???
            # triangle_loss = triangle_loss_fn(outputs['vertices'])
            triangle_loss = calc_laplacian_loss(outputs['faces'], outputs['verts'])
            triangle_loss = args.lambda_laplacian * triangle_loss
            loss_dic['triangle'] = triangle_loss
        # shape regularization: drive shape parameters towards 0
        if 'mshape' in loss_used:
            assert 'shape_params' in outputs, "Using mshape as loss, but shape_params are not predicted."
            shape_loss = torch_f.mse_loss(outputs['shape_params'], torch.zeros_like(outputs['shape_params']).to(device))
            shape_loss = args.lambda_shape * shape_loss
            loss_dic['mshape'] = shape_loss
        # pose regularization: drive pose parameters towards 0
        if 'mpose' in loss_used:
            assert 'pose_params' in outputs, "Using mpose as loss, but pose_params are not predicted."
            pose_loss = torch_f.mse_loss(outputs['pose_params'], torch.zeros_like(outputs['pose_params']).to(device))
            # pose_loss = outputs['pose_params'].pow(2).sum(dim=-1).sqrt().mean()*10
            pose_loss = args.lambda_pose * pose_loss
            loss_dic['mpose'] = pose_loss
        # mesh texture regularization term
        if 'mtex' in loss_used and ('texture_params' in outputs):
            assert 'texture_params' in outputs, "Using mtex as loss, but texture_params are not predicted."
            texture_loss = torch_f.mse_loss(outputs['texture_params'], torch.zeros_like(outputs['texture_params']).to(device))
            textures_reg = args.lambda_tex_reg * texture_loss
            loss_dic['mtex'] = textures_reg
        return loss_dic
    def MSE_loss(self, pred, label=0) -> torch.Tensor:
        loss = (pred.contiguous() - label) ** 2
        return loss.mean()

    def L1_loss(self, pred, label=0) -> torch.Tensor:
        loss = torch.abs(pred.contiguous() - label)
        return loss.mean()
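# Usage sketch (illustrative only): the tensor shapes and the argparse-style `args`
# below are assumptions chosen to exercise a single fully-supervised term; real
# callers pass the network outputs and the dataloader batch instead.
if __name__ == '__main__':
    from argparse import Namespace
    args = Namespace(base_loss_fn='L1', lambda_j2d_gt=1.0)
    examples = {'imgs': torch.zeros(2, 3, 224, 224), 'j2d_gt': torch.zeros(2, 21, 2)}
    outputs = {'j2d': torch.zeros(2, 21, 2)}
    loss_fn = LossFunction()
    loss_dic = loss_fn(examples, outputs, loss_used=['joint_2d'], dat_name='FreiHand', args=args)
    print({k: float(v) for k, v in loss_dic.items()})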