-
Notifications
You must be signed in to change notification settings - Fork 18
/
nerfdet.py
executable file
·447 lines (406 loc) · 18.9 KB
/
nerfdet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from mmdet.models.detectors import BaseDetector
from ..model_utils.render_ray import render_rays
from ..model_utils.nerf_mlp import VanillaNeRFRadianceField
from ..model_utils.projection import Projector
from ..model_utils.save_rendered_img import save_rendered_img
from mmdet3d.core import bbox3d2result
import os
@DETECTORS.register_module()
class nerfdet(BaseDetector):
    """NeRF-Det multi-view 3D detector with an auxiliary NeRF branch.

    Per-view 2D features are back-projected onto a voxel grid and fused
    across views (a mean plus an ``exp(-variance)`` weight). With
    ``nerf_density=True`` the fused mean is re-weighted by an opacity queried
    from ``nerf_mlp``; otherwise the plain multi-view mean is used. The
    result goes through ``neck_3d`` and ``bbox_head``. When ray data is
    supplied, rays are rendered with ``render_rays`` and supervised with an
    RGB loss and, optionally, a depth loss.
    """

    def __init__(self,
                 backbone,
                 neck,
                 neck_3d,
                 bbox_head,
                 n_voxels,
                 voxel_size,
                 head_2d=None,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None,
                 aabb=None,
                 near_far_range=None,
                 N_samples=40,
                 N_rand=4096,
                 depth_supervise=False,
                 use_nerf_mask=True,
                 nerf_sample_view=3,
                 nerf_mode="volume",
                 squeeze_scale=4,
                 rgb_supervision=True,
                 nerf_density=False,
                 render_testing=False):
        """Build the detector from mmdet-style config dicts.

        Args:
            backbone, neck, neck_3d, bbox_head, head_2d: config dicts given
                to the mmdet builders. ``bbox_head`` is updated in place with
                ``train_cfg`` / ``test_cfg``. ``neck['out_channels']`` sizes
                all NeRF-related layers.
            n_voxels, voxel_size: voxel-grid resolution and per-axis voxel
                size (3 values each).
            aabb, near_far_range, N_samples, N_rand, nerf_sample_view,
                render_testing: forwarded unchanged to ``render_rays``.
            depth_supervise: add an L1 loss on rendered depth.
            use_nerf_mask: weight the NeRF losses by the render mask.
            nerf_mode: 'volume' or 'image'; selects which features
                ``render_rays`` receives.
            squeeze_scale: divisor from ``neck['out_channels']`` to the NeRF
                feature width.
            rgb_supervision: add the RGB (novel-view synthesis) loss.
            nerf_density: re-weight the detection volume by NeRF opacity
                (requires ray data in every forward pass).
        """
        super().__init__()
        self.backbone = build_backbone(backbone)
        self.neck = build_neck(neck)
        self.neck_3d = build_neck(neck_3d)
        bbox_head.update(train_cfg=train_cfg)
        bbox_head.update(test_cfg=test_cfg)
        self.bbox_head = build_head(bbox_head)
        self.bbox_head.voxel_size = voxel_size
        self.head_2d = build_head(head_2d) if head_2d is not None else None
        self.n_voxels = n_voxels
        self.voxel_size = voxel_size
        self.train_cfg = train_cfg
        self.test_cfg = test_cfg
        self.init_weights(pretrained=pretrained)
        self.aabb = aabb
        self.near_far_range = near_far_range
        self.N_samples = N_samples
        self.N_rand = N_rand
        self.depth_supervise = depth_supervise
        self.projector = Projector()
        self.squeeze_scale = squeeze_scale
        self.use_nerf_mask = use_nerf_mask
        self.rgb_supervision = rgb_supervision
        out_channels = neck["out_channels"]
        nerf_feature_dim = out_channels // squeeze_scale
        self.nerf_mlp = VanillaNeRFRadianceField(
            net_depth=4,  # The depth of the MLP.
            net_width=256,  # The width of the MLP.
            skip_layer=3,  # The layer to add skip layers to.
            feature_dim=nerf_feature_dim + 6,  # + RGB original img
            net_depth_condition=1,  # The depth of the second part of MLP.
            net_width_condition=128,
        )
        self.nerf_mode = nerf_mode
        self.nerf_density = nerf_density
        self.nerf_sample_view = nerf_sample_view
        self.render_testing = render_testing
        # NOTE(review): ``self.cov`` and ``self.mapping_2d`` are not referenced
        # anywhere in this file; presumably kept so existing checkpoints still
        # match the state_dict keys. Module creation order is unchanged on
        # purpose (parameter init order / state_dict order).
        self.cov = nn.Sequential(
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, 1, kernel_size=1),
        )
        self.mean_mapping = nn.Sequential(
            nn.Conv3d(out_channels, nerf_feature_dim // 2, kernel_size=1))
        self.cov_mapping = nn.Sequential(
            nn.Conv3d(out_channels, nerf_feature_dim // 2, kernel_size=1))
        self.mapping = nn.Sequential(
            nn.Linear(out_channels, nerf_feature_dim // 2))
        self.mapping_2d = nn.Sequential(
            nn.Conv2d(out_channels, nerf_feature_dim // 2, kernel_size=1))

    def init_weights(self, pretrained=None):
        """Initialise the backbone (optionally from ``pretrained``), necks and heads."""
        super().init_weights(pretrained)
        self.backbone.init_weights(pretrained=pretrained)
        self.neck.init_weights()
        self.neck_3d.init_weights()
        self.bbox_head.init_weights()
        if self.head_2d is not None:
            self.head_2d.init_weights()

    @staticmethod
    def _build_ray_batch(kwargs):
        """Collect the ray-rendering inputs from the data-pipeline kwargs.

        Returns:
            dict or None: None when ``kwargs`` has no 'raydirs' entry;
            otherwise a dict with keys ray_o, ray_d, gt_rgb, gt_depth,
            nerf_sizes and denorm_images.
        """
        if "raydirs" not in kwargs:
            return None
        return {
            'ray_o': kwargs['lightpos'],
            'ray_d': kwargs['raydirs'],
            'gt_rgb': kwargs['gt_images'],
            'gt_depth': kwargs['gt_depths'],
            'nerf_sizes': kwargs['nerf_sizes'],
            'denorm_images': kwargs['denorm_images'],
        }

    def extract_feat(self, img, img_metas, mode, depth=None, ray_batch=None):
        """Build one fused 3D feature volume per sample and run the 3D neck.

        Args:
            img: image tensor whose first two dims are (batch, n_views); they
                are flattened before the 2D backbone.
            img_metas: per-sample meta dicts (reads 'lidar2img', 'img_shape',
                'ori_shape').
            mode: 'train' or 'test'. In 'test' mode the first output of
                ``head_2d`` (if any) is used as predicted camera angles.
            depth: optional depth maps, assumed (batch, n_views, H, W); when
                given, only voxels near the observed depth are kept.
            ray_batch: optional dict from ``_build_ray_batch``; when given,
                rays are rendered for every sample.

        Returns:
            tuple: (neck_3d output, stacked per-voxel view counts, head_2d
            output or None, list of ``render_rays`` results, list kept for
            interface compatibility that is never filled).

        Raises:
            ValueError: if ``nerf_density`` is set but ``ray_batch`` is None
                (the density branch needs ``ray_batch['denorm_images']``), or
                if ``nerf_mode`` is neither 'volume' nor 'image' while ray
                data is given.
        """
        if self.nerf_density and ray_batch is None:
            raise ValueError(
                'nerf_density=True requires ray_batch: the density branch '
                'back-projects ray_batch["denorm_images"].')
        if ray_batch is not None and self.nerf_mode not in ('volume', 'image'):
            raise ValueError(f'Unsupported nerf_mode: {self.nerf_mode!r}')

        batch_size = img.shape[0]
        img = img.reshape([-1] + list(img.shape)[2:])
        if depth is not None:
            assert depth.shape[0] == batch_size
        x = self.backbone(img)
        features_2d = self.head_2d.forward(x[-1], img_metas) if self.head_2d is not None else None
        x = self.neck(x)[0]
        x = x.reshape([batch_size, -1] + list(x.shape[1:]))
        stride = img.shape[-1] / x.shape[-1]
        assert stride == 4  # may be removed in the future
        stride = int(stride)
        # The voxel grid resolution/size is identical for every sample.
        n_voxels = torch.tensor(self.n_voxels)
        voxel_size = torch.tensor(self.voxel_size)
        volumes, valids = [], []
        rgb_preds = []
        densitys = []  # never appended to; returned for interface compatibility
        for sample_idx, (feature, img_meta) in enumerate(zip(x, img_metas)):
            # Keep depth's batch dim so each sample is paired with its own
            # depth maps rather than the first sample's.
            sample_depth = depth[sample_idx] if depth is not None else None
            # use predicted pitch and roll for SUNRGBDTotal test
            angles = features_2d[0] if features_2d is not None and mode == 'test' else None
            projection = self._compute_projection(img_meta, stride, angles).to(x.device)
            points = get_points(
                n_voxels=n_voxels,
                voxel_size=voxel_size,
                origin=torch.tensor(img_meta['lidar2img']['origin'])
            ).to(x.device)
            height = img_meta['img_shape'][0] // stride
            width = img_meta['img_shape'][1] // stride
            cropped = feature[:, :, :height, :width]
            volume, valid = backproject(
                cropped, points, projection, sample_depth, self.voxel_size)

            # Fuse views: mean feature and exp(-variance) per voxel.
            # valid becomes the number of views that see each voxel.
            volume_sum = volume.sum(dim=0)
            valid = valid.sum(dim=0)
            unseen = valid[0] == 0
            # TODO: Maintain a mask and use a learnable token to fill in the unobserved place.
            volume_mean = volume_sum / (valid + 1e-8)
            volume_mean[:, unseen] = .0
            volume_cov = torch.sum((volume - volume_mean.unsqueeze(0)) ** 2, dim=0) / (valid + 1e-8)
            volume_cov[:, unseen] = 1e6
            # be careful here, the smaller the cov, the larger the weight.
            volume_cov = torch.exp(-volume_cov)  # default setting

            if ray_batch is not None:
                if self.nerf_mode == 'volume':
                    mean_volume = self.mean_mapping(volume_mean.unsqueeze(0))
                    cov_volume = self.cov_mapping(volume_cov.unsqueeze(0))
                    feature_2d = cropped
                else:  # 'image' (validated above)
                    mean_volume = None
                    cov_volume = None
                    n_v, n_c, feat_h, feat_w = cropped.shape
                    feature_2d = cropped.view(n_v, n_c, -1).permute(0, 2, 1).contiguous()
                    feature_2d = self.mapping(feature_2d).permute(0, 2, 1).contiguous().view(
                        n_v, -1, feat_h, feat_w)
                denorm_images = ray_batch['denorm_images']
                denorm_images = denorm_images.reshape([-1] + list(denorm_images.shape)[2:])
                # NOTE(review): denorm_images is flattened over the whole batch,
                # so this path only lines up with the per-sample projections
                # when batch_size == 1 -- confirm before training with larger batches.
                rgb_projection = self._compute_projection(
                    img_meta, stride=1, angles=None).to(x.device)
                rgb_volume, _ = backproject(
                    denorm_images[:, :, :img_meta['img_shape'][0], :img_meta['img_shape'][1]],
                    points,
                    rgb_projection,
                    sample_depth,
                    self.voxel_size)
                ret = render_rays(
                    ray_batch,
                    mean_volume,
                    cov_volume,
                    feature_2d,
                    denorm_images,
                    self.aabb,
                    self.near_far_range,
                    self.N_samples,
                    self.N_rand,
                    self.nerf_mlp,
                    img_meta,
                    self.projector,
                    self.nerf_mode,
                    self.nerf_sample_view,
                    is_train=mode == "train",
                    render_testing=self.render_testing,
                )
                rgb_preds.append(ret)

            if self.nerf_density:
                # self.mapping adds its bias even at voxels no view observed,
                # so mapping_volume_mean is not zero there ("0 bias issue");
                # those voxels are zeroed in the final volume below.
                n_v, n_c, n_x_voxels, n_y_voxels, n_z_voxels = volume.shape
                n_points = n_x_voxels * n_y_voxels * n_z_voxels
                flat_volume = volume.view(n_v, n_c, -1).permute(0, 2, 1).contiguous()
                mapping_volume = self.mapping(flat_volume).permute(0, 2, 1).contiguous().view(
                    n_v, -1, n_x_voxels, n_y_voxels, n_z_voxels)
                mapping_volume = torch.cat([rgb_volume, mapping_volume], dim=1)
                mapping_volume_mean = mapping_volume.sum(dim=0) / (valid + 1e-8)
                mapping_volume_cov = (mapping_volume - mapping_volume_mean.unsqueeze(0)) ** 2
                mapping_volume_cov = torch.sum(mapping_volume_cov, dim=0) / (valid + 1e-8)
                mapping_volume_cov[:, unseen] = 1e6
                mapping_volume_cov = torch.exp(-mapping_volume_cov)  # default setting
                # NOTE(review): mean/cov are 4-D (C, X, Y, Z) here, so dim=1 is
                # the X axis; after the view below the channels come out
                # interleaved as [mean_0, cov_0, mean_1, cov_1, ...]. Left as is
                # because trained nerf_mlp weights depend on this ordering --
                # confirm it matches the feature layout used in render_rays.
                global_volume = torch.cat([mapping_volume_mean, mapping_volume_cov], dim=1)
                global_volume = global_volume.view(-1, n_points).permute(1, 0).contiguous()
                flat_points = points.view(3, -1).permute(1, 0).contiguous()
                density = self.nerf_mlp.query_density(flat_points, global_volume)
                alpha = 1 - torch.exp(-density)  # density -> opacity
                volume = alpha.view(1, n_x_voxels, n_y_voxels, n_z_voxels) * volume_mean
                volume[:, unseen] = .0
            else:
                # Without the density branch the detector consumes the plain
                # multi-view mean (C, X, Y, Z); appending the raw per-view
                # volume would give neck_3d a 6-D input.
                volume = volume_mean
            volumes.append(volume)
            valids.append(valid)
        x = torch.stack(volumes)
        valids = torch.stack(valids)
        x = self.neck_3d(x)
        return x, valids, features_2d, rgb_preds, densitys

    def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d, **kwargs):
        """Compute detection losses, plus NeRF losses when ray data is present.

        The RGB loss needs ``rgb_supervision`` and the depth loss needs
        ``depth_supervise``; both are only added when ``kwargs`` carries ray
        data, because without rendered rays they would be plain ints, not tensors.
        """
        ray_batch = self._build_ray_batch(kwargs)
        x, valids, features_2d, rgb_preds, _ = self.extract_feat(
            img, img_metas, 'train', ray_batch=ray_batch)
        losses = self.bbox_head.forward_train(x, valids.float(), img_metas, gt_bboxes_3d, gt_labels_3d)
        if self.head_2d is not None:
            losses.update(self.head_2d.loss(*features_2d, img_metas))
        if ray_batch is not None:
            if self.rgb_supervision:
                losses.update(self.nvs_loss_func(rgb_preds))
            if self.depth_supervise:
                losses.update(self.depth_loss_func(rgb_preds))
        return losses

    def nvs_loss_func(self, rgb_pred):
        """Squared RGB error of the coarse render, summed over samples.

        With ``use_nerf_mask`` the error is summed over masked rays and
        divided by the mask total; otherwise it is a plain mean.
        """
        loss = 0
        for ret in rgb_pred:
            rgb = ret['outputs_coarse']['rgb']
            gt = ret['gt_rgb']
            masks = ret['outputs_coarse']['mask']
            if self.use_nerf_mask:
                loss += torch.sum(masks.unsqueeze(-1) * (rgb - gt) ** 2) / (masks.sum() + 1e-6)
            else:
                loss += torch.mean((rgb - gt) ** 2)
        return dict(loss_nvs=loss)

    def depth_loss_func(self, rgb_pred):
        """L1 error of the coarse rendered depth, summed over samples (masked like ``nvs_loss_func``)."""
        loss = 0
        for ret in rgb_pred:
            depth = ret['outputs_coarse']['depth']
            gt = ret['gt_depth'].squeeze(-1)
            masks = ret['outputs_coarse']['mask']
            if self.use_nerf_mask:
                loss += torch.sum(masks * torch.abs(depth - gt)) / (masks.sum() + 1e-6)
            else:
                loss += torch.mean(torch.abs(depth - gt))
        return dict(loss_depth=loss)

    def forward_test(self, img, img_metas, **kwargs):
        """Run ``simple_test``, passing ray data when the pipeline provides it."""
        return self.simple_test(img, img_metas, ray_batch=self._build_ray_batch(kwargs))

    def simple_test(self, img, img_metas, depth=None, ray_batch=None, evaluate_nerf=False):
        """Detect 3D boxes for a batch; optionally hand renders to ``save_rendered_img``.

        Returns:
            list: one ``bbox3d2result`` dict per sample, extended with
            'angles' and 'layout' when ``head_2d`` exists.
        """
        x, valids, features_2d, rgb_preds, _ = self.extract_feat(
            img, img_metas, 'test', depth, ray_batch)
        if evaluate_nerf:
            # Its (psnr, ssim, rmse) return value is not used here.
            save_rendered_img(img_metas, rgb_preds)
        x = self.bbox_head(x)
        bbox_list = self.bbox_head.get_bboxes(*x, valids.float(), img_metas)
        bbox_results = [
            bbox3d2result(det_bboxes, det_scores, det_labels)
            for det_bboxes, det_scores, det_labels in bbox_list
        ]
        if self.head_2d is not None:
            angles, layouts = self.head_2d.get_bboxes(*features_2d, img_metas)
            for i in range(len(img)):
                bbox_results[i]['angles'] = angles[i]
                bbox_results[i]['layout'] = layouts[i]
        return bbox_results

    def aug_test(self, imgs, img_metas):
        """Test-time augmentation is not implemented."""
        pass

    def show_results(self, *args, **kwargs):
        """Visualisation is not implemented."""
        pass

    @staticmethod
    def _compute_projection(img_meta, stride, angles):
        """Return stacked per-view projections ``intrinsic @ extrinsic[:3]``.

        The 3x3 intrinsic is scaled (first two rows) by
        ``ori_shape[0] / (img_shape[0] / stride)`` to match the feature map.
        Extrinsics come from ``img_meta['lidar2img']['extrinsic']``, or from
        ``get_extrinsics`` on each predicted angle when ``angles`` is given.
        """
        intrinsic = torch.tensor(img_meta['lidar2img']['intrinsic'][:3, :3])
        ratio = img_meta['ori_shape'][0] / (img_meta['img_shape'][0] / stride)
        intrinsic[:2] /= ratio
        # use predicted pitch and roll for SUNRGBDTotal test
        if angles is not None:
            extrinsics = [get_extrinsics(angle).to(intrinsic.device) for angle in angles]
        else:
            extrinsics = map(torch.tensor, img_meta['lidar2img']['extrinsic'])
        return torch.stack([intrinsic @ extrinsic[:3] for extrinsic in extrinsics])
@torch.no_grad()
def get_points(n_voxels, voxel_size, origin):
    """Return world coordinates of every voxel in a regular 3D grid.

    Args:
        n_voxels: tensor with the voxel count along x, y and z.
        voxel_size: tensor with the voxel edge length along x, y and z.
        origin: tensor with the point-cloud centre; the grid's first voxel
            sits at ``origin - n_voxels / 2 * voxel_size``.

    Returns:
        Tensor of shape (3, n_x, n_y, n_z) where ``[:, i, j, k]`` equals
        ``(i, j, k) * voxel_size + origin - n_voxels / 2 * voxel_size``.
    """
    # Integer index grids for x (width), y (depth) and z (height).
    axes = [torch.arange(n_voxels[axis]) for axis in range(3)]
    index_grid = torch.stack(torch.meshgrid(axes))
    corner = origin - n_voxels / 2. * voxel_size
    scale = voxel_size.view(3, 1, 1, 1)
    return index_grid * scale + corner.view(3, 1, 1, 1)
# modify from https://github.com/magicleap/Atlas/blob/master/atlas/model.py
def backproject(features, points, projection, depth, voxel_size):
    """Sample per-image 2D features at the projections of 3D voxel centres.

    Args:
        features: (n_images, n_channels, height, width) feature maps.
        points: voxel centres of shape (3, n_x, n_y, n_z).
        projection: per-image projection matrices (n_images, 3, 4) applied to
            homogeneous voxel centres.
        depth: optional per-image depth maps (n_images, H, W). They are
            bilinearly resized to (height, width), and a voxel is kept only if
            its camera z lies strictly within ``voxel_size[-1]`` of the depth
            at its pixel.
        voxel_size: per-axis voxel size; only the last entry is read.

    Returns:
        (volume, valid): volume is (n_images, n_channels, n_x, n_y, n_z) with
        zeros where a voxel is not visible; valid is the matching boolean
        mask of shape (n_images, 1, n_x, n_y, n_z).
    """
    n_images, n_channels, height, width = features.shape
    grid_shape = points.shape[-3:]
    flat = points.view(1, 3, -1).expand(n_images, 3, -1)
    homogeneous = torch.cat((flat, torch.ones_like(flat[:, :1])), dim=1)
    projected = torch.bmm(projection, homogeneous)
    cam_z = projected[:, 2]
    # Nearest pixel for every voxel; points behind the camera are rejected below.
    u = (projected[:, 0] / cam_z).round().long()
    v = (projected[:, 1] / cam_z).round().long()
    valid = (u >= 0) & (v >= 0) & (u < width) & (v < height) & (cam_z > 0)

    if depth is not None:
        depth = F.interpolate(depth.unsqueeze(1), size=(height, width), mode="bilinear").squeeze(1)
        tol = voxel_size[-1]
        for i in range(n_images):
            hit = valid[i]
            observed = depth[i, v[i, hit], u[i, hit]]
            near_surface = cam_z.clone() > 0
            near_surface[i, hit] = (cam_z[i, hit] > observed - tol) & (cam_z[i, hit] < observed + tol)
            valid = valid & near_surface

    volume = torch.zeros((n_images, n_channels, homogeneous.shape[-1]), device=features.device)
    for i in range(n_images):
        hit = valid[i]
        volume[i, :, hit] = features[i, :, v[i, hit], u[i, hit]]
    return (volume.view(n_images, n_channels, *grid_shape),
            valid.view(n_images, 1, *grid_shape))
# for SUNRGBDTotal test
def get_extrinsics(angles):
    """Build a 4x4 extrinsic matrix from a (pitch, roll) pair; yaw is fixed at 0.

    The Euler rotation is remapped into the Total3DUnderstanding axis
    convention and then into the DepthInstance3DBoxes convention (column
    permutation plus a sign flip of the last row). The translation is zero
    and the bottom-right entry is 1. The result has the dtype/device of
    ``angles``.
    """
    pitch, roll = angles
    yaw = angles.new_zeros(())
    cy, sy = torch.cos(yaw), torch.sin(yaw)
    cp, sp = torch.cos(pitch), torch.sin(pitch)
    cr, sr = torch.cos(roll), torch.sin(roll)
    rot = torch.stack([
        torch.stack([cy * cp, sy * sr - cy * cr * sp, cr * sy + cy * sp * sr]),
        torch.stack([sp, cp * cr, -cp * sr]),
        torch.stack([-cp * sy, cy * sr + cr * sy * sp, cy * cr - sy * sp * sr]),
    ])
    # follow Total3DUnderstanding
    axis_swap = angles.new_tensor([[0., 0., 1.], [0., -1., 0.], [-1., 0., 0.]])
    rot = axis_swap @ rot.T
    # follow DepthInstance3DBoxes
    rot = rot[:, [2, 0, 1]]
    rot[2] *= -1
    extrinsic = angles.new_zeros((4, 4))
    extrinsic[:3, :3] = rot
    extrinsic[3, 3] = 1.
    return extrinsic