diff --git a/tools/cal_results.py b/tools/cal_results.py index 40898fc..de44fe8 100644 --- a/tools/cal_results.py +++ b/tools/cal_results.py @@ -10,13 +10,15 @@ def parse_args(): parser = argparse.ArgumentParser( description='MMDet3D test (and eval) a model') parser.add_argument('results_file', help='the results json file') - parser.add_argument('anno_file', help='annoations json file') + parser.add_argument('ann_file', help='annoations json file') parser.add_argument('--iou_thr', type=list, default=[0.25, 0.5], help='the IoU threshold during evaluation') + args = parser.parse_args() + return args def ground_eval(gt_annos, det_annos, iou_thr): @@ -52,10 +54,9 @@ def ground_eval(gt_annos, det_annos, iou_thr): hard = gt_anno['is_hard'] unique = gt_anno['is_unique'] - box_index = scores.argsort(dim=-1, descending=True)[:10] - top_bbox = bboxes[box_index] + top_bboxes = bboxes[:10] - iou = top_bbox.overlaps(top_bbox, gt_bboxes) # (num_query, 1) + iou = top_bboxes.overlaps(top_bboxes, gt_bboxes) # (num_query, 1) for t in iou_thr: threshold = iou > t