From 3889d9d82591c677ed907ddb740a944853bf175d Mon Sep 17 00:00:00 2001 From: danish87 Date: Sun, 3 Sep 2023 09:18:54 +0200 Subject: [PATCH] roi_scores_multiclass --- pcdet/models/roi_heads/roi_head_template.py | 4 +++- tools/cfgs/kitti_models/pv_rcnn_ssl_60.yaml | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pcdet/models/roi_heads/roi_head_template.py b/pcdet/models/roi_heads/roi_head_template.py index 0b841a774..66bb31918 100644 --- a/pcdet/models/roi_heads/roi_head_template.py +++ b/pcdet/models/roi_heads/roi_head_template.py @@ -119,7 +119,7 @@ def proposal_layer(self, batch_dict, nms_config): rois = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE, batch_box_preds.shape[-1])) roi_scores = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE)) roi_labels = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE), dtype=torch.long) - + roi_scores_multiclass = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE, batch_cls_preds.shape[-1])) for index in range(batch_size): if batch_dict.get('batch_index', None) is not None: assert batch_cls_preds.shape.__len__() == 2 @@ -142,9 +142,11 @@ def proposal_layer(self, batch_dict, nms_config): rois[index, :len(selected), :] = box_preds[selected] roi_scores[index, :len(selected)] = cur_roi_scores[selected] roi_labels[index, :len(selected)] = cur_roi_labels[selected] + roi_scores_multiclass[index, :len(selected), :] = cls_preds[selected] batch_dict['rois'] = rois batch_dict['roi_scores'] = roi_scores + batch_dict['roi_scores_multiclass'] = roi_scores_multiclass batch_dict['roi_labels'] = roi_labels + 1 batch_dict['has_class_labels'] = True if batch_cls_preds.shape[-1] > 1 else False batch_dict.pop('batch_index', None) diff --git a/tools/cfgs/kitti_models/pv_rcnn_ssl_60.yaml b/tools/cfgs/kitti_models/pv_rcnn_ssl_60.yaml index 4ac96f4c4..18c0fd5bd 100644 --- a/tools/cfgs/kitti_models/pv_rcnn_ssl_60.yaml +++ b/tools/cfgs/kitti_models/pv_rcnn_ssl_60.yaml @@ -162,7 +162,7 @@ MODEL: ENABLE_VIS: False ENABLE_PROTO_CONTRASTIVE_LOSS: False PROTO_CONTRASTIVE_LOSS_WEIGHT: 1.0 - ENABLE_SOFT_TEACHER: True + ENABLE_SOFT_TEACHER: False ENABLE_ULB_CLS_DIST_LOSS: False ENABLE_EVAL: True METRICS_PRED_TYPES: [roi_pl_gt, pl_gt_metrics_before_filtering] @@ -221,7 +221,7 @@ MODEL: REG_FG_THRESH: 0.55 UNLABELED_REG_FG_THRESH: [0.55, 0.55, 0.55] - UNLABELED_SAMPLER_TYPE: subsample_unlabeled_rois_default + UNLABELED_SAMPLER_TYPE: subsample_labeled_rois #subsample_unlabeled_rois_default UNLABELED_SAMPLE_EASY_BG: False UNLABELED_SHARP_RCNN_CLS_LABELS: True UNLABELED_USE_CALIBRATED_IOUS: True