diff --git a/configs/_base_/models/fcaf3d.py b/configs/_base_/models/fcaf3d.py new file mode 100644 index 0000000000..b92bfadc77 --- /dev/null +++ b/configs/_base_/models/fcaf3d.py @@ -0,0 +1,16 @@ +model = dict( + type='MinkSingleStage3DDetector', + voxel_size=.01, + backbone=dict(type='MinkResNet', in_channels=3, depth=34), + head=dict( + type='FCAF3DHead', + in_channels=(64, 128, 256, 512), + out_channels=128, + voxel_size=.01, + pts_prune_threshold=100000, + pts_assign_threshold=27, + pts_center_threshold=18, + n_classes=18, + n_reg_outs=6), + train_cfg=dict(), + test_cfg=dict(nms_pre=1000, iou_thr=.5, score_thr=.01)) diff --git a/configs/fcaf3d/fcaf3d_8x2_s3dis-3d-5class.py b/configs/fcaf3d/fcaf3d_8x2_s3dis-3d-5class.py new file mode 100644 index 0000000000..6e17976fd9 --- /dev/null +++ b/configs/fcaf3d/fcaf3d_8x2_s3dis-3d-5class.py @@ -0,0 +1,104 @@ +_base_ = ['fcaf3d_8x2_scannet-3d-18class.py'] +n_points = 100000 + +model = dict(head=dict(n_classes=5)) + +dataset_type = 'S3DISDataset' +data_root = './data/s3dis/' +class_names = ('table', 'chair', 'sofa', 'bookcase', 'board') +train_area = [1, 2, 3, 4, 6] +test_area = 5 + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict(type='LoadAnnotations3D'), + dict(type='PointSample', num_points=n_points), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.087266, 0.087266], + scale_ratio_range=[.9, 1.1], + translation_std=[.1, .1, .1], + shift_height=False), + dict(type='NormalizePointsColor', color_mean=None), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='PointSample', num_points=n_points), + dict(type='NormalizePointsColor', color_mean=None), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=13, + dataset=dict( + type='ConcatDataset', + datasets=[ + dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + f's3dis_infos_Area_{i}.pkl', + pipeline=train_pipeline, + filter_empty_gt=True, + classes=class_names, + box_type_3d='Depth') for i in train_area + ], + separate_eval=False)), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + f's3dis_infos_Area_{test_area}.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + f's3dis_infos_Area_{test_area}.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) diff --git a/configs/fcaf3d/fcaf3d_8x2_scannet-3d-18class.py b/configs/fcaf3d/fcaf3d_8x2_scannet-3d-18class.py new file 
mode 100644 index 0000000000..d4e19b6b7b --- /dev/null +++ b/configs/fcaf3d/fcaf3d_8x2_scannet-3d-18class.py @@ -0,0 +1,105 @@ +_base_ = ['../_base_/models/fcaf3d.py', '../_base_/default_runtime.py'] +n_points = 100000 + +dataset_type = 'ScanNetDataset' +data_root = './data/scannet/' +class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', + 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', + 'garbagebin') +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict(type='LoadAnnotations3D'), + dict(type='GlobalAlignment', rotation_axis=2), + dict(type='PointSample', num_points=n_points), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.087266, 0.087266], + scale_ratio_range=[.9, 1.1], + translation_std=[.1, .1, .1], + shift_height=False), + dict(type='NormalizePointsColor', color_mean=None), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='PointSample', num_points=n_points), + dict(type='NormalizePointsColor', color_mean=None), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=10, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_train.pkl', + pipeline=train_pipeline, + filter_empty_gt=True, + classes=class_names, + box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) + +optimizer = dict(type='AdamW', lr=0.001, weight_decay=0.0001) +optimizer_config = dict(grad_clip=dict(max_norm=10, norm_type=2)) +lr_config = dict(policy='step', warmup=None, step=[8, 11]) +runner = dict(type='EpochBasedRunner', max_epochs=12) +custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)] diff --git a/configs/fcaf3d/fcaf3d_8x2_sunrgbd-3d-10class.py b/configs/fcaf3d/fcaf3d_8x2_sunrgbd-3d-10class.py new file mode 100644 index 0000000000..909d82a77b --- /dev/null +++ b/configs/fcaf3d/fcaf3d_8x2_sunrgbd-3d-10class.py @@ -0,0 +1,94 @@ +_base_ = ['fcaf3d_8x2_scannet-3d-18class.py'] +n_points = 100000 + +model = dict( + head=dict( + n_classes=10, n_reg_outs=8, bbox_loss=dict(type='RotatedIoU3DLoss'))) + +dataset_type = 'SUNRGBDDataset' +data_root = 'data/sunrgbd/' 
+class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser', + 'night_stand', 'bookshelf', 'bathtub') +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict(type='LoadAnnotations3D'), + dict(type='PointSample', num_points=n_points), + dict(type='RandomFlip3D', sync_2d=False, flip_ratio_bev_horizontal=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.523599, 0.523599], + scale_ratio_range=[0.85, 1.15], + translation_std=[.1, .1, .1], + shift_height=False), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='PointSample', num_points=n_points), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=3, + dataset=dict( + type=dataset_type, + modality=dict(use_camera=False, use_lidar=True), + data_root=data_root, + ann_file=data_root + 'sunrgbd_infos_train.pkl', + pipeline=train_pipeline, + filter_empty_gt=True, + classes=class_names, + box_type_3d='Depth')), + val=dict( + type=dataset_type, + modality=dict(use_camera=False, use_lidar=True), + data_root=data_root, + ann_file=data_root + 'sunrgbd_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + modality=dict(use_camera=False, use_lidar=True), + data_root=data_root, + ann_file=data_root + 'sunrgbd_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) diff --git a/mmdet3d/models/backbones/mink_resnet.py b/mmdet3d/models/backbones/mink_resnet.py index 35a79ce233..e4bae72a87 100644 --- a/mmdet3d/models/backbones/mink_resnet.py +++ b/mmdet3d/models/backbones/mink_resnet.py @@ -58,7 +58,7 @@ def __init__(self, depth, in_channels, num_stages=4, pool=True): for i, num_blocks in enumerate(stage_blocks): setattr( - self, f'layer{i}', + self, f'layer{i + 1}', self._make_layer(block, 64 * 2**i, stage_blocks[i], stride=2)) def init_weights(self): @@ -111,6 +111,6 @@ def forward(self, x): x = self.maxpool(x) outs = [] for i in range(self.num_stages): - x = getattr(self, f'layer{i}')(x) + x = getattr(self, f'layer{i + 1}')(x) outs.append(x) return outs diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py index 25008c95bc..154b7c4a2a 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -4,6 +4,7 @@ from .base_conv_bbox_head import BaseConvBboxHead from .base_mono3d_dense_head import BaseMono3DDenseHead from .centerpoint_head import CenterHead +from .fcaf3d_head import FCAF3DHead from .fcos_mono3d_head import FCOSMono3DHead from .free_anchor3d_head import FreeAnchor3DHead from .groupfree3d_head import GroupFree3DHead @@ -21,5 +22,5 @@ 'SSD3DHead', 
'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead', 'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead', 'GroupFree3DHead', 'PointRPNHead', 'SMOKEMono3DHead', 'PGDHead', - 'MonoFlexHead' + 'MonoFlexHead', 'FCAF3DHead' ] diff --git a/mmdet3d/models/dense_heads/fcaf3d_head.py b/mmdet3d/models/dense_heads/fcaf3d_head.py new file mode 100644 index 0000000000..d39fc55bec --- /dev/null +++ b/mmdet3d/models/dense_heads/fcaf3d_head.py @@ -0,0 +1,679 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from https://github.com/SamsungLabs/fcaf3d/blob/master/mmdet3d/models/dense_heads/fcaf3d_neck_with_head.py # noqa +try: + import MinkowskiEngine as ME +except ImportError: + import warnings + warnings.warn( + 'Please follow `getting_started.md` to install MinkowskiEngine.`') + +import torch +from mmcv.cnn import Scale, bias_init_with_prob +from mmcv.ops import nms3d, nms3d_normal +from mmcv.runner.base_module import BaseModule +from torch import nn + +from mmdet3d.core.bbox.structures import rotation_3d_in_axis +from mmdet3d.models import HEADS, build_loss +from mmdet.core import reduce_mean + + +@HEADS.register_module() +class FCAF3DHead(BaseModule): + r"""Bbox head of `FCAF3D `_. + Actually here we store both the sparse 3D FPN and a head. The neck and + the head can not be simply separated as pruning score on the i-th level + of FPN requires classification scores from i+1-th level of the head. + + Args: + n_classes (int): Number of classes. + in_channels (tuple[int]): Number of channels in input tensors. + out_channels (int): Number of channels in the neck output tensors. + n_reg_outs (int): Number of regression layer channels. + voxel_size (float): Voxel size in meters. + pts_prune_threshold (int): Pruning threshold on each feature level. + pts_assign_threshold (int): Box to location assigner parameter. + Assigner selects the maximum feature level with more locations + inside the box than pts_assign_threshold. + pts_center_threshold (int): Box to location assigner parameter. + After feature level for the box is determined, assigner selects + pts_center_threshold locations closest to the box center. + center_loss (dict, optional): Config of centerness loss. + bbox_loss (dict, optional): Config of bbox loss. + cls_loss (dict, optional): Config of classification loss. + train_cfg (dict, optional): Config for train stage. Defaults to None. + test_cfg (dict, optional): Config for test stage. Defaults to None. + init_cfg (dict, optional): Config for weight initialization. + Defaults to None. + """ + + def __init__(self, + n_classes, + in_channels, + out_channels, + n_reg_outs, + voxel_size, + pts_prune_threshold, + pts_assign_threshold, + pts_center_threshold, + center_loss=dict(type='CrossEntropyLoss', use_sigmoid=True), + bbox_loss=dict(type='AxisAlignedIoULoss'), + cls_loss=dict(type='FocalLoss'), + train_cfg=None, + test_cfg=None, + init_cfg=None): + super(FCAF3DHead, self).__init__(init_cfg) + self.voxel_size = voxel_size + self.pts_prune_threshold = pts_prune_threshold + self.pts_assign_threshold = pts_assign_threshold + self.pts_center_threshold = pts_center_threshold + self.center_loss = build_loss(center_loss) + self.bbox_loss = build_loss(bbox_loss) + self.cls_loss = build_loss(cls_loss) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self._init_layers(in_channels, out_channels, n_reg_outs, n_classes) + + @staticmethod + def _make_block(in_channels, out_channels): + """Construct Conv-Norm-Act block. + + Args: + in_channels (int): Number of input channels. 
+ out_channels (int): Number of output channels. + + Returns: + torch.nn.Module: With corresponding layers. + """ + return nn.Sequential( + ME.MinkowskiConvolution( + in_channels, out_channels, kernel_size=3, dimension=3), + ME.MinkowskiBatchNorm(out_channels), ME.MinkowskiELU()) + + @staticmethod + def _make_up_block(in_channels, out_channels): + """Construct DeConv-Norm-Act-Conv-Norm-Act block. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + + Returns: + torch.nn.Module: With corresponding layers. + """ + return nn.Sequential( + ME.MinkowskiGenerativeConvolutionTranspose( + in_channels, + out_channels, + kernel_size=2, + stride=2, + dimension=3), ME.MinkowskiBatchNorm(out_channels), + ME.MinkowskiELU(), + ME.MinkowskiConvolution( + out_channels, out_channels, kernel_size=3, dimension=3), + ME.MinkowskiBatchNorm(out_channels), ME.MinkowskiELU()) + + def _init_layers(self, in_channels, out_channels, n_reg_outs, n_classes): + """Initialize layers. + + Args: + in_channels (tuple[int]): Number of channels in input tensors. + out_channels (int): Number of channels in the neck output tensors. + n_reg_outs (int): Number of regression layer channels. + n_classes (int): Number of classes. + """ + # neck layers + self.pruning = ME.MinkowskiPruning() + for i in range(len(in_channels)): + if i > 0: + self.__setattr__( + f'up_block_{i}', + self._make_up_block(in_channels[i], in_channels[i - 1])) + self.__setattr__(f'out_block_{i}', + self._make_block(in_channels[i], out_channels)) + + # head layers + self.conv_center = ME.MinkowskiConvolution( + out_channels, 1, kernel_size=1, dimension=3) + self.conv_reg = ME.MinkowskiConvolution( + out_channels, n_reg_outs, kernel_size=1, dimension=3) + self.conv_cls = ME.MinkowskiConvolution( + out_channels, n_classes, kernel_size=1, bias=True, dimension=3) + self.scales = nn.ModuleList( + [Scale(1.) for _ in range(len(in_channels))]) + + def init_weights(self): + """Initialize weights.""" + nn.init.normal_(self.conv_center.kernel, std=.01) + nn.init.normal_(self.conv_reg.kernel, std=.01) + nn.init.normal_(self.conv_cls.kernel, std=.01) + nn.init.constant_(self.conv_cls.bias, bias_init_with_prob(.01)) + + def forward(self, x): + """Forward pass. + + Args: + x (list[Tensor]): Features from the backbone. + + Returns: + list[list[Tensor]]: Predictions of the head. + """ + center_preds, bbox_preds, cls_preds, points = [], [], [], [] + inputs = x + x = inputs[-1] + prune_score = None + for i in range(len(inputs) - 1, -1, -1): + if i < len(inputs) - 1: + x = self.__getattr__(f'up_block_{i + 1}')(x) + x = inputs[i] + x + x = self._prune(x, prune_score) + + out = self.__getattr__(f'out_block_{i}')(x) + center_pred, bbox_pred, cls_pred, point, prune_score = \ + self._forward_single(out, self.scales[i]) + center_preds.append(center_pred) + bbox_preds.append(bbox_pred) + cls_preds.append(cls_pred) + points.append(point) + return center_preds[::-1], bbox_preds[::-1], cls_preds[::-1], \ + points[::-1] + + def forward_train(self, x, gt_bboxes, gt_labels, input_metas): + """Forward pass of the train stage. + + Args: + x (list[SparseTensor]): Features from the backbone. + gt_bboxes (list[:obj:`BaseInstance3DBoxes`]): Ground truth + bboxes of each sample. + gt_labels(list[torch.Tensor]): Labels of each sample. + input_metas (list[dict]): Contains scene meta info for each sample. + + Returns: + dict: Centerness, bbox and classification loss values. 
+ """ + center_preds, bbox_preds, cls_preds, points = self(x) + return self._loss(center_preds, bbox_preds, cls_preds, points, + gt_bboxes, gt_labels, input_metas) + + def forward_test(self, x, input_metas): + """Forward pass of the test stage. + + Args: + x (list[SparseTensor]): Features from the backbone. + input_metas (list[dict]): Contains scene meta info for each sample. + + Returns: + list[list[Tensor]]: bboxes, scores and labels for each sample. + """ + center_preds, bbox_preds, cls_preds, points = self(x) + return self._get_bboxes(center_preds, bbox_preds, cls_preds, points, + input_metas) + + def _prune(self, x, scores): + """Prunes the tensor by score thresholding. + + Args: + x (SparseTensor): Tensor to be pruned. + scores (SparseTensor): Scores for thresholding. + + Returns: + SparseTensor: Pruned tensor. + """ + with torch.no_grad(): + coordinates = x.C.float() + interpolated_scores = scores.features_at_coordinates(coordinates) + prune_mask = interpolated_scores.new_zeros( + (len(interpolated_scores)), dtype=torch.bool) + for permutation in x.decomposition_permutations: + score = interpolated_scores[permutation] + mask = score.new_zeros((len(score)), dtype=torch.bool) + topk = min(len(score), self.pts_prune_threshold) + ids = torch.topk(score.squeeze(1), topk, sorted=False).indices + mask[ids] = True + prune_mask[permutation[mask]] = True + x = self.pruning(x, prune_mask) + return x + + def _forward_single(self, x, scale): + """Forward pass per level. + + Args: + x (SparseTensor): Per level neck output tensor. + scale (mmcv.cnn.Scale): Per level multiplication weight. + + Returns: + tuple[Tensor]: Per level head predictions. + """ + center_pred = self.conv_center(x).features + scores = self.conv_cls(x) + cls_pred = scores.features + prune_scores = ME.SparseTensor( + scores.features.max(dim=1, keepdim=True).values, + coordinate_map_key=scores.coordinate_map_key, + coordinate_manager=scores.coordinate_manager) + reg_final = self.conv_reg(x).features + reg_distance = torch.exp(scale(reg_final[:, :6])) + reg_angle = reg_final[:, 6:] + bbox_pred = torch.cat((reg_distance, reg_angle), dim=1) + + center_preds, bbox_preds, cls_preds, points = [], [], [], [] + for permutation in x.decomposition_permutations: + center_preds.append(center_pred[permutation]) + bbox_preds.append(bbox_pred[permutation]) + cls_preds.append(cls_pred[permutation]) + + points = x.decomposed_coordinates + for i in range(len(points)): + points[i] = points[i] * self.voxel_size + + return center_preds, bbox_preds, cls_preds, points, prune_scores + + def _loss_single(self, center_preds, bbox_preds, cls_preds, points, + gt_bboxes, gt_labels, input_meta): + """Per scene loss function. + + Args: + center_preds (list[Tensor]): Centerness predictions for all levels. + bbox_preds (list[Tensor]): Bbox predictions for all levels. + cls_preds (list[Tensor]): Classification predictions for all + levels. + points (list[Tensor]): Final location coordinates for all levels. + gt_bboxes (BaseInstance3DBoxes): Ground truth boxes. + gt_labels (Tensor): Ground truth labels. + input_meta (dict): Scene meta info. + + Returns: + tuple[Tensor]: Centerness, bbox, and classification loss values. 
+ """ + center_targets, bbox_targets, cls_targets = self._get_targets( + points, gt_bboxes, gt_labels) + + center_preds = torch.cat(center_preds) + bbox_preds = torch.cat(bbox_preds) + cls_preds = torch.cat(cls_preds) + points = torch.cat(points) + + # cls loss + pos_inds = torch.nonzero(cls_targets >= 0).squeeze(1) + n_pos = points.new_tensor(len(pos_inds)) + n_pos = max(reduce_mean(n_pos), 1.) + cls_loss = self.cls_loss(cls_preds, cls_targets, avg_factor=n_pos) + + # bbox and centerness losses + pos_center_preds = center_preds[pos_inds] + pos_bbox_preds = bbox_preds[pos_inds] + pos_center_targets = center_targets[pos_inds].unsqueeze(1) + pos_bbox_targets = bbox_targets[pos_inds] + # reduce_mean is outside if / else block to prevent deadlock + center_denorm = max( + reduce_mean(pos_center_targets.sum().detach()), 1e-6) + if len(pos_inds) > 0: + pos_points = points[pos_inds] + center_loss = self.center_loss( + pos_center_preds, pos_center_targets, avg_factor=n_pos) + bbox_loss = self.bbox_loss( + self._bbox_to_loss( + self._bbox_pred_to_bbox(pos_points, pos_bbox_preds)), + self._bbox_to_loss(pos_bbox_targets), + weight=pos_center_targets.squeeze(1), + avg_factor=center_denorm) + else: + center_loss = pos_center_preds.sum() + bbox_loss = pos_bbox_preds.sum() + return center_loss, bbox_loss, cls_loss + + def _loss(self, center_preds, bbox_preds, cls_preds, points, gt_bboxes, + gt_labels, input_metas): + """Per scene loss function. + + Args: + center_preds (list[list[Tensor]]): Centerness predictions for + all scenes. + bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes. + cls_preds (list[list[Tensor]]): Classification predictions for all + scenes. + points (list[list[Tensor]]): Final location coordinates for all + scenes. + gt_bboxes (list[BaseInstance3DBoxes]): Ground truth boxes for all + scenes. + gt_labels (list[Tensor]): Ground truth labels for all scenes. + input_metas (list[dict]): Meta infos for all scenes. + + Returns: + dict: Centerness, bbox, and classification loss values. + """ + center_losses, bbox_losses, cls_losses = [], [], [] + for i in range(len(input_metas)): + center_loss, bbox_loss, cls_loss = self._loss_single( + center_preds=[x[i] for x in center_preds], + bbox_preds=[x[i] for x in bbox_preds], + cls_preds=[x[i] for x in cls_preds], + points=[x[i] for x in points], + input_meta=input_metas[i], + gt_bboxes=gt_bboxes[i], + gt_labels=gt_labels[i]) + center_losses.append(center_loss) + bbox_losses.append(bbox_loss) + cls_losses.append(cls_loss) + return dict( + center_loss=torch.mean(torch.stack(center_losses)), + bbox_loss=torch.mean(torch.stack(bbox_losses)), + cls_loss=torch.mean(torch.stack(cls_losses))) + + def _get_bboxes_single(self, center_preds, bbox_preds, cls_preds, points, + input_meta): + """Generate boxes for a single scene. + + Args: + center_preds (list[Tensor]): Centerness predictions for all levels. + bbox_preds (list[Tensor]): Bbox predictions for all levels. + cls_preds (list[Tensor]): Classification predictions for all + levels. + points (list[Tensor]): Final location coordinates for all levels. + input_meta (dict): Scene meta info. + + Returns: + tuple[Tensor]: Predicted bounding boxes, scores and labels. 
+ """ + mlvl_bboxes, mlvl_scores = [], [] + for center_pred, bbox_pred, cls_pred, point in zip( + center_preds, bbox_preds, cls_preds, points): + scores = cls_pred.sigmoid() * center_pred.sigmoid() + max_scores, _ = scores.max(dim=1) + + if len(scores) > self.test_cfg.nms_pre > 0: + _, ids = max_scores.topk(self.test_cfg.nms_pre) + bbox_pred = bbox_pred[ids] + scores = scores[ids] + point = point[ids] + + bboxes = self._bbox_pred_to_bbox(point, bbox_pred) + mlvl_bboxes.append(bboxes) + mlvl_scores.append(scores) + + bboxes = torch.cat(mlvl_bboxes) + scores = torch.cat(mlvl_scores) + bboxes, scores, labels = self._single_scene_multiclass_nms( + bboxes, scores, input_meta) + return bboxes, scores, labels + + def _get_bboxes(self, center_preds, bbox_preds, cls_preds, points, + input_metas): + """Generate boxes for all scenes. + + Args: + center_preds (list[list[Tensor]]): Centerness predictions for + all scenes. + bbox_preds (list[list[Tensor]]): Bbox predictions for all scenes. + cls_preds (list[list[Tensor]]): Classification predictions for all + scenes. + points (list[list[Tensor]]): Final location coordinates for all + scenes. + input_metas (list[dict]): Meta infos for all scenes. + + Returns: + list[tuple[Tensor]]: Predicted bboxes, scores, and labels for + all scenes. + """ + results = [] + for i in range(len(input_metas)): + result = self._get_bboxes_single( + center_preds=[x[i] for x in center_preds], + bbox_preds=[x[i] for x in bbox_preds], + cls_preds=[x[i] for x in cls_preds], + points=[x[i] for x in points], + input_meta=input_metas[i]) + results.append(result) + return results + + @staticmethod + def _bbox_to_loss(bbox): + """Transform box to the axis-aligned or rotated iou loss format. + + Args: + bbox (Tensor): 3D box of shape (N, 6) or (N, 7). + + Returns: + Tensor: Transformed 3D box of shape (N, 6) or (N, 7). + """ + # rotated iou loss accepts (x, y, z, w, h, l, heading) + if bbox.shape[-1] != 6: + return bbox + + # axis-aligned case: x, y, z, w, h, l -> x1, y1, z1, x2, y2, z2 + return torch.stack( + (bbox[..., 0] - bbox[..., 3] / 2, bbox[..., 1] - bbox[..., 4] / 2, + bbox[..., 2] - bbox[..., 5] / 2, bbox[..., 0] + bbox[..., 3] / 2, + bbox[..., 1] + bbox[..., 4] / 2, bbox[..., 2] + bbox[..., 5] / 2), + dim=-1) + + @staticmethod + def _bbox_pred_to_bbox(points, bbox_pred): + """Transform predicted bbox parameters to bbox. + + Args: + points (Tensor): Final locations of shape (N, 3) + bbox_pred (Tensor): Predicted bbox parameters of shape (N, 6) + or (N, 8). + + Returns: + Tensor: Transformed 3D box of shape (N, 6) or (N, 7). 
+ """ + if bbox_pred.shape[0] == 0: + return bbox_pred + + x_center = points[:, 0] + (bbox_pred[:, 1] - bbox_pred[:, 0]) / 2 + y_center = points[:, 1] + (bbox_pred[:, 3] - bbox_pred[:, 2]) / 2 + z_center = points[:, 2] + (bbox_pred[:, 5] - bbox_pred[:, 4]) / 2 + + # dx_min, dx_max, dy_min, dy_max, dz_min, dz_max -> x, y, z, w, l, h + base_bbox = torch.stack([ + x_center, + y_center, + z_center, + bbox_pred[:, 0] + bbox_pred[:, 1], + bbox_pred[:, 2] + bbox_pred[:, 3], + bbox_pred[:, 4] + bbox_pred[:, 5], + ], -1) + + # axis-aligned case + if bbox_pred.shape[1] == 6: + return base_bbox + + # rotated case: ..., sin(2a)ln(q), cos(2a)ln(q) + scale = bbox_pred[:, 0] + bbox_pred[:, 1] + \ + bbox_pred[:, 2] + bbox_pred[:, 3] + q = torch.exp( + torch.sqrt( + torch.pow(bbox_pred[:, 6], 2) + torch.pow(bbox_pred[:, 7], 2))) + alpha = 0.5 * torch.atan2(bbox_pred[:, 6], bbox_pred[:, 7]) + return torch.stack( + (x_center, y_center, z_center, scale / (1 + q), scale / + (1 + q) * q, bbox_pred[:, 5] + bbox_pred[:, 4], alpha), + dim=-1) + + @staticmethod + def _get_face_distances(points, boxes): + """Calculate distances from point to box faces. + + Args: + points (Tensor): Final locations of shape (N_points, N_boxes, 3). + boxes (Tensor): 3D boxes of shape (N_points, N_boxes, 7) + + Returns: + Tensor: Face distances of shape (N_points, N_boxes, 6), + (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max). + """ + shift = torch.stack( + (points[..., 0] - boxes[..., 0], points[..., 1] - boxes[..., 1], + points[..., 2] - boxes[..., 2]), + dim=-1).permute(1, 0, 2) + shift = rotation_3d_in_axis( + shift, -boxes[0, :, 6], axis=2).permute(1, 0, 2) + centers = boxes[..., :3] + shift + dx_min = centers[..., 0] - boxes[..., 0] + boxes[..., 3] / 2 + dx_max = boxes[..., 0] + boxes[..., 3] / 2 - centers[..., 0] + dy_min = centers[..., 1] - boxes[..., 1] + boxes[..., 4] / 2 + dy_max = boxes[..., 1] + boxes[..., 4] / 2 - centers[..., 1] + dz_min = centers[..., 2] - boxes[..., 2] + boxes[..., 5] / 2 + dz_max = boxes[..., 2] + boxes[..., 5] / 2 - centers[..., 2] + return torch.stack((dx_min, dx_max, dy_min, dy_max, dz_min, dz_max), + dim=-1) + + @staticmethod + def _get_centerness(face_distances): + """Compute point centerness w.r.t containing box. + + Args: + face_distances (Tensor): Face distances of shape (B, N, 6), + (dx_min, dx_max, dy_min, dy_max, dz_min, dz_max). + + Returns: + Tensor: Centerness of shape (B, N). + """ + x_dims = face_distances[..., [0, 1]] + y_dims = face_distances[..., [2, 3]] + z_dims = face_distances[..., [4, 5]] + centerness_targets = x_dims.min(dim=-1)[0] / x_dims.max(dim=-1)[0] * \ + y_dims.min(dim=-1)[0] / y_dims.max(dim=-1)[0] * \ + z_dims.min(dim=-1)[0] / z_dims.max(dim=-1)[0] + return torch.sqrt(centerness_targets) + + @torch.no_grad() + def _get_targets(self, points, gt_bboxes, gt_labels): + """Compute targets for final locations for a single scene. + + Args: + points (list[Tensor]): Final locations for all levels. + gt_bboxes (BaseInstance3DBoxes): Ground truth boxes. + gt_labels (Tensor): Ground truth labels. + + Returns: + tuple[Tensor]: Centerness, bbox and classification + targets for all locations. 
+ """ + float_max = points[0].new_tensor(1e8) + n_levels = len(points) + levels = torch.cat([ + points[i].new_tensor(i).expand(len(points[i])) + for i in range(len(points)) + ]) + points = torch.cat(points) + gt_bboxes = gt_bboxes.to(points.device) + n_points = len(points) + n_boxes = len(gt_bboxes) + volumes = gt_bboxes.volume.unsqueeze(0).expand(n_points, n_boxes) + + # condition 1: point inside box + boxes = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), + dim=1) + boxes = boxes.expand(n_points, n_boxes, 7) + points = points.unsqueeze(1).expand(n_points, n_boxes, 3) + face_distances = self._get_face_distances(points, boxes) + inside_box_condition = face_distances.min(dim=-1).values > 0 + + # condition 2: positive points per level >= limit + # calculate positive points per scale + n_pos_points_per_level = [] + for i in range(n_levels): + n_pos_points_per_level.append( + torch.sum(inside_box_condition[levels == i], dim=0)) + # find best level + n_pos_points_per_level = torch.stack(n_pos_points_per_level, dim=0) + lower_limit_mask = n_pos_points_per_level < self.pts_assign_threshold + lower_index = torch.argmax(lower_limit_mask.int(), dim=0) - 1 + lower_index = torch.where(lower_index < 0, 0, lower_index) + all_upper_limit_mask = torch.all( + torch.logical_not(lower_limit_mask), dim=0) + best_level = torch.where(all_upper_limit_mask, n_levels - 1, + lower_index) + # keep only points with best level + best_level = best_level.expand(n_points, n_boxes) + levels = torch.unsqueeze(levels, 1).expand(n_points, n_boxes) + level_condition = best_level == levels + + # condition 3: limit topk points per box by centerness + centerness = self._get_centerness(face_distances) + centerness = torch.where(inside_box_condition, centerness, + torch.ones_like(centerness) * -1) + centerness = torch.where(level_condition, centerness, + torch.ones_like(centerness) * -1) + top_centerness = torch.topk( + centerness, + min(self.pts_center_threshold + 1, len(centerness)), + dim=0).values[-1] + topk_condition = centerness > top_centerness.unsqueeze(0) + + # condition 4: min volume box per point + volumes = torch.where(inside_box_condition, volumes, float_max) + volumes = torch.where(level_condition, volumes, float_max) + volumes = torch.where(topk_condition, volumes, float_max) + min_volumes, min_inds = volumes.min(dim=1) + + center_targets = centerness[torch.arange(n_points), min_inds] + bbox_targets = boxes[torch.arange(n_points), min_inds] + if not gt_bboxes.with_yaw: + bbox_targets = bbox_targets[:, :-1] + cls_targets = gt_labels[min_inds] + cls_targets = torch.where(min_volumes == float_max, -1, cls_targets) + return center_targets, bbox_targets, cls_targets + + def _single_scene_multiclass_nms(self, bboxes, scores, input_meta): + """Multi-class nms for a single scene. + + Args: + bboxes (Tensor): Predicted boxes of shape (N_boxes, 6) or + (N_boxes, 7). + scores (Tensor): Predicted scores of shape (N_boxes, N_classes). + input_meta (dict): Scene meta data. + + Returns: + tuple[Tensor]: Predicted bboxes, scores and labels. 
+ """ + n_classes = scores.shape[1] + with_yaw = bboxes.shape[1] == 7 + nms_bboxes, nms_scores, nms_labels = [], [], [] + for i in range(n_classes): + ids = scores[:, i] > self.test_cfg.score_thr + if not ids.any(): + continue + + class_scores = scores[ids, i] + class_bboxes = bboxes[ids] + if with_yaw: + nms_function = nms3d + else: + class_bboxes = torch.cat( + (class_bboxes, torch.zeros_like(class_bboxes[:, :1])), + dim=1) + nms_function = nms3d_normal + + nms_ids = nms_function(class_bboxes, class_scores, + self.test_cfg.iou_thr) + nms_bboxes.append(class_bboxes[nms_ids]) + nms_scores.append(class_scores[nms_ids]) + nms_labels.append( + bboxes.new_full( + class_scores[nms_ids].shape, i, dtype=torch.long)) + + if len(nms_bboxes): + nms_bboxes = torch.cat(nms_bboxes, dim=0) + nms_scores = torch.cat(nms_scores, dim=0) + nms_labels = torch.cat(nms_labels, dim=0) + else: + nms_bboxes = bboxes.new_zeros((0, bboxes.shape[1])) + nms_scores = bboxes.new_zeros((0, )) + nms_labels = bboxes.new_zeros((0, )) + + if with_yaw: + box_dim = 7 + else: + box_dim = 6 + nms_bboxes = nms_bboxes[:, :6] + nms_bboxes = input_meta['box_type_3d']( + nms_bboxes, + box_dim=box_dim, + with_yaw=with_yaw, + origin=(.5, .5, .5)) + + return nms_bboxes, nms_scores, nms_labels diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py index 1924b1232d..a7d5aa1cf0 100644 --- a/mmdet3d/models/detectors/__init__.py +++ b/mmdet3d/models/detectors/__init__.py @@ -7,6 +7,7 @@ from .h3dnet import H3DNet from .imvotenet import ImVoteNet from .imvoxelnet import ImVoxelNet +from .mink_single_stage import MinkSingleStage3DDetector from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN from .mvx_two_stage import MVXTwoStageDetector from .parta2 import PartA2 @@ -23,5 +24,5 @@ 'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet', 'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector', 'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D', - 'SASSD' + 'MinkSingleStage3DDetector', 'SASSD' ] diff --git a/mmdet3d/models/detectors/mink_single_stage.py b/mmdet3d/models/detectors/mink_single_stage.py new file mode 100644 index 0000000000..963eacd42c --- /dev/null +++ b/mmdet3d/models/detectors/mink_single_stage.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from https://github.com/SamsungLabs/fcaf3d/blob/master/mmdet3d/models/detectors/single_stage_sparse.py # noqa +try: + import MinkowskiEngine as ME +except ImportError: + import warnings + warnings.warn( + 'Please follow `getting_started.md` to install MinkowskiEngine.`') + +from mmdet3d.core import bbox3d2result +from mmdet3d.models import DETECTORS, build_backbone, build_head +from .base import Base3DDetector + + +@DETECTORS.register_module() +class MinkSingleStage3DDetector(Base3DDetector): + r"""Single stage detector based on MinkowskiEngine `GSDN + `_. + + Args: + backbone (dict): Config of the backbone. + head (dict): Config of the head. + voxel_size (float): Voxel size in meters. + train_cfg (dict, optional): Config for train stage. Defaults to None. + test_cfg (dict, optional): Config for test stage. Defaults to None. + init_cfg (dict, optional): Config for weight initialization. + Defaults to None. + pretrained (str, optional): Deprecated initialization parameter. + Defaults to None. 
+ """ + + def __init__(self, + backbone, + head, + voxel_size, + train_cfg=None, + test_cfg=None, + init_cfg=None, + pretrained=None): + super(MinkSingleStage3DDetector, self).__init__(init_cfg) + self.backbone = build_backbone(backbone) + head.update(train_cfg=train_cfg) + head.update(test_cfg=test_cfg) + self.head = build_head(head) + self.voxel_size = voxel_size + self.init_weights() + + def extract_feat(self, points): + """Extract features from points. + + Args: + points (list[Tensor]): Raw point clouds. + + Returns: + SparseTensor: Voxelized point clouds. + """ + coordinates, features = ME.utils.batch_sparse_collate( + [(p[:, :3] / self.voxel_size, p[:, 3:]) for p in points], + device=points[0].device) + x = ME.SparseTensor(coordinates=coordinates, features=features) + x = self.backbone(x) + return x + + def forward_train(self, points, gt_bboxes_3d, gt_labels_3d, img_metas): + """Forward of training. + + Args: + points (list[Tensor]): Raw point clouds. + gt_bboxes (list[BaseInstance3DBoxes]): Ground truth + bboxes of each sample. + gt_labels(list[torch.Tensor]): Labels of each sample. + img_metas (list[dict]): Contains scene meta infos. + + Returns: + dict: Centerness, bbox and classification loss values. + """ + x = self.extract_feat(points) + losses = self.head.forward_train(x, gt_bboxes_3d, gt_labels_3d, + img_metas) + return losses + + def simple_test(self, points, img_metas, *args, **kwargs): + """Test without augmentations. + + Args: + points (list[torch.Tensor]): Points of each sample. + img_metas (list[dict]): Contains scene meta infos. + + Returns: + list[dict]: Predicted 3d boxes. + """ + x = self.extract_feat(points) + bbox_list = self.head.forward_test(x, img_metas) + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in bbox_list + ] + return bbox_results + + def aug_test(self, points, img_metas, **kwargs): + """Test with augmentations. + + Args: + points (list[list[torch.Tensor]]): Points of each sample. + img_metas (list[dict]): Contains scene meta infos. + + Returns: + list[dict]: Predicted 3d boxes. + """ + raise NotImplementedError diff --git a/mmdet3d/models/losses/__init__.py b/mmdet3d/models/losses/__init__.py index dcdc69ab6d..c3c0943744 100644 --- a/mmdet3d/models/losses/__init__.py +++ b/mmdet3d/models/losses/__init__.py @@ -4,11 +4,12 @@ from .chamfer_distance import ChamferDistance, chamfer_distance from .multibin_loss import MultiBinLoss from .paconv_regularization_loss import PAConvRegularizationLoss +from .rotated_iou_loss import RotatedIoU3DLoss from .uncertain_smooth_l1_loss import UncertainL1Loss, UncertainSmoothL1Loss __all__ = [ 'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance', 'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss', 'PAConvRegularizationLoss', 'UncertainL1Loss', 'UncertainSmoothL1Loss', - 'MultiBinLoss' + 'MultiBinLoss', 'RotatedIoU3DLoss' ] diff --git a/mmdet3d/models/losses/axis_aligned_iou_loss.py b/mmdet3d/models/losses/axis_aligned_iou_loss.py index 428d7bb86c..9f315dd28e 100644 --- a/mmdet3d/models/losses/axis_aligned_iou_loss.py +++ b/mmdet3d/models/losses/axis_aligned_iou_loss.py @@ -9,17 +9,18 @@ @weighted_loss def axis_aligned_iou_loss(pred, target): - """Calculate the IoU loss (1-IoU) of two set of axis aligned bounding + """Calculate the IoU loss (1-IoU) of two sets of axis aligned bounding boxes. Note that predictions and targets are one-to-one corresponded. Args: - pred (torch.Tensor): Bbox predictions with shape [..., 3]. 
- target (torch.Tensor): Bbox targets (gt) with shape [..., 3]. + pred (torch.Tensor): Bbox predictions with shape [..., 6] + (x1, y1, z1, x2, y2, z2). + target (torch.Tensor): Bbox targets (gt) with shape [..., 6] + (x1, y1, z1, x2, y2, z2). Returns: torch.Tensor: IoU loss between predictions and targets. """ - axis_aligned_iou = AxisAlignedBboxOverlaps3D()( pred, target, is_aligned=True) iou_loss = 1 - axis_aligned_iou @@ -52,8 +53,10 @@ def forward(self, """Forward function of loss calculation. Args: - pred (torch.Tensor): Bbox predictions with shape [..., 3]. - target (torch.Tensor): Bbox targets (gt) with shape [..., 3]. + pred (torch.Tensor): Bbox predictions with shape [..., 6] + (x1, y1, z1, x2, y2, z2). + target (torch.Tensor): Bbox targets (gt) with shape [..., 6] + (x1, y1, z1, x2, y2, z2). weight (torch.Tensor | float, optional): Weight of loss. Defaults to None. avg_factor (int, optional): Average factor that is used to average diff --git a/mmdet3d/models/losses/rotated_iou_loss.py b/mmdet3d/models/losses/rotated_iou_loss.py new file mode 100644 index 0000000000..47c9139629 --- /dev/null +++ b/mmdet3d/models/losses/rotated_iou_loss.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.ops import diff_iou_rotated_3d +from torch import nn as nn + +from mmdet.models.losses.utils import weighted_loss +from ..builder import LOSSES + + +@weighted_loss +def rotated_iou_3d_loss(pred, target): + """Calculate the IoU loss (1-IoU) of two sets of rotated bounding boxes. + Note that predictions and targets are one-to-one corresponded. + + Args: + pred (torch.Tensor): Bbox predictions with shape [N, 7] + (x, y, z, w, l, h, alpha). + target (torch.Tensor): Bbox targets (gt) with shape [N, 7] + (x, y, z, w, l, h, alpha). + + Returns: + torch.Tensor: IoU loss between predictions and targets. + """ + iou_loss = 1 - diff_iou_rotated_3d(pred.unsqueeze(0), + target.unsqueeze(0))[0] + return iou_loss + + +@LOSSES.register_module() +class RotatedIoU3DLoss(nn.Module): + """Calculate the IoU loss (1-IoU) of rotated bounding boxes. + + Args: + reduction (str): Method to reduce losses. + The valid reduction method are none, sum or mean. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + """ + + def __init__(self, reduction='mean', loss_weight=1.0): + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None, + **kwargs): + """Forward function of loss calculation. + + Args: + pred (torch.Tensor): Bbox predictions with shape [..., 7] + (x, y, z, w, l, h, alpha). + target (torch.Tensor): Bbox targets (gt) with shape [..., 7] + (x, y, z, w, l, h, alpha). + weight (torch.Tensor | float, optional): Weight of loss. + Defaults to None. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): Method to reduce losses. + The valid reduction method are 'none', 'sum' or 'mean'. + Defaults to None. + + Returns: + torch.Tensor: IoU loss between predictions and targets. 
+ """ + if weight is not None and not torch.any(weight > 0): + return pred.sum() * weight.sum() # 0 + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + if weight is not None and weight.dim() > 1: + weight = weight.mean(-1) + loss = self.loss_weight * rotated_iou_3d_loss( + pred, + target, + weight, + reduction=reduction, + avg_factor=avg_factor, + **kwargs) + + return loss diff --git a/tests/test_metrics/test_losses.py b/tests/test_metrics/test_losses.py index bde44aa00e..c7f128c86b 100644 --- a/tests/test_metrics/test_losses.py +++ b/tests/test_metrics/test_losses.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import numpy as np import pytest import torch from torch import nn as nn @@ -172,3 +173,39 @@ def test_multibin_loss(): expected_multibin_loss = torch.tensor(2.1120) assert torch.allclose( output_multibin_loss, expected_multibin_loss, atol=1e-4) + + +def test_axis_aligned_iou_loss(): + from mmdet3d.models.losses import AxisAlignedIoULoss + + boxes1 = torch.tensor([[0., 0., 0., 1., 1., 1.], [0., 0., 0., 1., 1., 1.], + [0., 0., 0., 1., 1., 1.]]) + boxes2 = torch.tensor([[0., 0., 0., 1., 1., + 1.], [.5, .5, .5, 1.5, 1.5, 1.5], + [1., 1., 1., 2., 2., 2.]]) + + expect_ious = torch.tensor([[0., 14 / 15, 1.]]) + ious = AxisAlignedIoULoss(reduction='none')(boxes1, boxes2) + assert torch.allclose(ious, expect_ious, atol=1e-4) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_rotated_iou_3d_loss(): + # adapted from mmcv.tests.test_ops.test_diff_iou_rotated + from mmdet3d.models.losses import RotatedIoU3DLoss + + boxes1 = torch.tensor([[.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0]]).cuda() + boxes2 = torch.tensor([[.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 2., np.pi / 2], + [.5, .5, .5, 1., 1., 1., np.pi / 4], + [1., 1., 1., 1., 1., 1., .0], + [-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]).cuda() + + expect_ious = 1 - torch.tensor([[1., .5, .7071, 1 / 15, .0]]).cuda() + ious = RotatedIoU3DLoss(reduction='none')(boxes1, boxes2) + assert torch.allclose(ious, expect_ious, atol=1e-4) diff --git a/tests/test_models/test_heads/test_heads.py b/tests/test_models/test_heads/test_heads.py index 185d26bb5b..9fd0cfb43a 100644 --- a/tests/test_models/test_heads/test_heads.py +++ b/tests/test_models/test_heads/test_heads.py @@ -1562,3 +1562,96 @@ def test_monoflex_head(): assert cls_score[0].shape == torch.Size([2, 3, 32, 32]) assert out_reg[0].shape == torch.Size([2, 50, 32, 32]) + + +def test_fcaf3d_head(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + + try: + import MinkowskiEngine as ME + except ImportError: + pytest.skip('test requires MinkowskiEngine installation') + + _setup_seed(0) + + coordinates, features = [], [] + # batch of 2 point clouds + for i in range(2): + c = torch.from_numpy(np.random.rand(500, 3) * 100) + coordinates.append(c.float().cuda()) + f = torch.from_numpy(np.random.rand(500, 3)) + features.append(f.float().cuda()) + tensor_coordinates, tensor_features = ME.utils.sparse_collate( + coordinates, features) + x = ME.SparseTensor( + features=tensor_features, coordinates=tensor_coordinates) + + # backbone + conv1 = ME.MinkowskiConvolution( + 3, 64, kernel_size=3, stride=2, dimension=3).cuda() + conv2 = ME.MinkowskiConvolution( + 64, 128, kernel_size=3, stride=2, 
dimension=3).cuda() + conv3 = ME.MinkowskiConvolution( + 128, 256, kernel_size=3, stride=2, dimension=3).cuda() + conv4 = ME.MinkowskiConvolution( + 256, 512, kernel_size=3, stride=2, dimension=3).cuda() + + # backbone outputs of 4 levels + x1 = conv1(x) + x2 = conv2(x1) + x3 = conv3(x2) + x4 = conv4(x3) + x = (x1, x2, x3, x4) + + # build head + cfg = dict( + type='FCAF3DHead', + in_channels=(64, 128, 256, 512), + out_channels=128, + voxel_size=1., + pts_prune_threshold=1000, + pts_assign_threshold=27, + pts_center_threshold=18, + n_classes=18, + n_reg_outs=6) + test_cfg = mmcv.Config(dict(nms_pre=1000, iou_thr=.5, score_thr=.01)) + cfg.update(test_cfg=test_cfg) + head = build_head(cfg).cuda() + + # test forward train + gt_bboxes = [ + DepthInstance3DBoxes( + torch.tensor([[10., 10., 10., 10., 10., 10.], + [30., 30., 30., 30., 30., 30.]]), + box_dim=6, + with_yaw=False), + DepthInstance3DBoxes( + torch.tensor([[20., 20., 20., 20., 20., 20.], + [40., 40., 40., 40., 40., 40.]]), + box_dim=6, + with_yaw=False) + ] + gt_labels = [torch.tensor([2, 4]).cuda(), torch.tensor([3, 5]).cuda()] + img_metas = [ + dict(box_type_3d=DepthInstance3DBoxes), + dict(box_type_3d=DepthInstance3DBoxes) + ] + + losses = head.forward_train(x, gt_bboxes, gt_labels, img_metas) + assert torch.allclose( + losses['center_loss'].detach().cpu(), torch.tensor(0.7079), atol=1e-4) + assert torch.allclose( + losses['bbox_loss'].detach().cpu(), torch.tensor(0.9995), atol=1e-4) + assert torch.allclose( + losses['cls_loss'].detach().cpu(), torch.tensor(592.8), atol=1e-1) + + # test forward test + bbox_list = head.forward_test(x, img_metas) + assert len(bbox_list) == 2 + for bboxes, scores, labels in bbox_list: + n, dim = bboxes.tensor.shape + assert n > 0 + assert dim == 7 + assert scores.shape == torch.Size([n]) + assert labels.shape == torch.Size([n]) diff --git a/tests/test_runtime/test_config.py b/tests/test_runtime/test_config.py index 650b46df65..6a9d8f590d 100644 --- a/tests/test_runtime/test_config.py +++ b/tests/test_runtime/test_config.py @@ -47,6 +47,11 @@ def test_config_build_model(): if 'pretrained' in config_mod.model: config_mod.model['pretrained'] = None + # We skip detectors based on MinkowskiEngine as it is an external + # dependency and may not be installed by the user. + if config_fname.startswith('fcaf3d'): + continue + detector = build_model(config_mod.model) assert detector is not None
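Note (not part of the patch): as a quick sanity check of the new configs and registry entries, a minimal sketch along the following lines could be used to build the detector, assuming this patch is applied and MinkowskiEngine is installed. The config path comes from the files added above; the `build_model` helper is the one called in test_config.py, assumed here to be importable from mmdet3d.models.

from mmcv import Config
from mmdet3d.models import build_model

# Illustrative only: construct FCAF3D from the newly added ScanNet config.
cfg = Config.fromfile('configs/fcaf3d/fcaf3d_8x2_scannet-3d-18class.py')
model = build_model(cfg.model)
print(type(model).__name__)  # expected: MinkSingleStage3DDetector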