diff --git a/mmrotate/core/__init__.py b/mmrotate/core/__init__.py index 126eb69ea..c8164cf94 100644 --- a/mmrotate/core/__init__.py +++ b/mmrotate/core/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .anchor import * # noqa: F401, F403 from .bbox import * # noqa: F401, F403 +from .evaluation import * # noqa: F401, F403 from .patch import * # noqa: F401, F403 from .post_processing import * # noqa: F401, F403 from .visualization import * # noqa: F401, F403 diff --git a/mmrotate/core/evaluation/__init__.py b/mmrotate/core/evaluation/__init__.py new file mode 100644 index 000000000..140bd14af --- /dev/null +++ b/mmrotate/core/evaluation/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .eval_map import eval_rbbox_map + +__all__ = ['eval_rbbox_map'] diff --git a/mmrotate/core/evaluation/eval_map.py b/mmrotate/core/evaluation/eval_map.py new file mode 100644 index 000000000..8c2d05c09 --- /dev/null +++ b/mmrotate/core/evaluation/eval_map.py @@ -0,0 +1,311 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing import get_context + +import numpy as np +import torch +from mmcv.ops import box_iou_rotated +from mmcv.utils import print_log +from mmdet.core import average_precision +from terminaltables import AsciiTable + + +def tpfp_default(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + iou_thr=0.5, + area_ranges=None): + """Check if detected bboxes are true positive or false positive. + + Args: + det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 6). + gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 5). + gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, + of shape (k, 5). Default: None + iou_thr (float): IoU threshold to be considered as matched. + Default: 0.5. + area_ranges (list[tuple] | None): Range of bbox areas to be evaluated, + in the format [(min1, max1), (min2, max2), ...]. Default: None. + + Returns: + tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of + each array is (num_scales, m). + """ + # an indicator of ignored gts + det_bboxes = np.array(det_bboxes) + gt_ignore_inds = np.concatenate( + (np.zeros(gt_bboxes.shape[0], dtype=np.bool), + np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool))) + # stack gt_bboxes and gt_bboxes_ignore for convenience + gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) + + num_dets = det_bboxes.shape[0] + num_gts = gt_bboxes.shape[0] + if area_ranges is None: + area_ranges = [(None, None)] + num_scales = len(area_ranges) + # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of + # a certain scale + tp = np.zeros((num_scales, num_dets), dtype=np.float32) + fp = np.zeros((num_scales, num_dets), dtype=np.float32) + + # if there is no gt bboxes in this image, then all det bboxes + # within area range are false positives + if gt_bboxes.shape[0] == 0: + if area_ranges == [(None, None)]: + fp[...] = 1 + else: + raise NotImplementedError + return tp, fp + + ious = box_iou_rotated( + torch.from_numpy(det_bboxes).float(), + torch.from_numpy(gt_bboxes).float()).numpy() + # for each det, the max iou with all gts + ious_max = ious.max(axis=1) + # for each det, which gt overlaps most with it + ious_argmax = ious.argmax(axis=1) + # sort all dets in descending order by scores + sort_inds = np.argsort(-det_bboxes[:, -1]) + for k, (min_area, max_area) in enumerate(area_ranges): + gt_covered = np.zeros(num_gts, dtype=bool) + # if no area range is specified, gt_area_ignore is all False + if min_area is None: + gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) + else: + raise NotImplementedError + for i in sort_inds: + if ious_max[i] >= iou_thr: + matched_gt = ious_argmax[i] + if not (gt_ignore_inds[matched_gt] + or gt_area_ignore[matched_gt]): + if not gt_covered[matched_gt]: + gt_covered[matched_gt] = True + tp[k, i] = 1 + else: + fp[k, i] = 1 + # otherwise ignore this detected bbox, tp = 0, fp = 0 + elif min_area is None: + fp[k, i] = 1 + else: + bbox = det_bboxes[i, :5] + area = bbox[2] * bbox[3] + if area >= min_area and area < max_area: + fp[k, i] = 1 + return tp, fp + + +def get_cls_results(det_results, annotations, class_id): + """Get det results and gt information of a certain class. + + Args: + det_results (list[list]): Same as `eval_map()`. + annotations (list[dict]): Same as `eval_map()`. + class_id (int): ID of a specific class. + + Returns: + tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes + """ + cls_dets = [img_res[class_id] for img_res in det_results] + + cls_gts = [] + cls_gts_ignore = [] + for ann in annotations: + gt_inds = ann['labels'] == class_id + cls_gts.append(ann['bboxes'][gt_inds, :]) + + if ann.get('labels_ignore', None) is not None: + ignore_inds = ann['labels_ignore'] == class_id + cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :]) + + else: + cls_gts_ignore.append(torch.zeros((0, 6), dtype=torch.float64)) + + return cls_dets, cls_gts, cls_gts_ignore + + +def eval_rbbox_map(det_results, + annotations, + scale_ranges=None, + iou_thr=0.5, + use_07_metric=True, + dataset=None, + logger=None, + nproc=4): + """Evaluate mAP of a rotated dataset. + + Args: + det_results (list[list]): [[cls1_det, cls2_det, ...], ...]. + The outer list indicates images, and the inner list indicates + per-class detected bboxes. + annotations (list[dict]): Ground truth annotations where each item of + the list indicates an image. Keys of annotations are: + + - `bboxes`: numpy array of shape (n, 5) + - `labels`: numpy array of shape (n, ) + - `bboxes_ignore` (optional): numpy array of shape (k, 5) + - `labels_ignore` (optional): numpy array of shape (k, ) + scale_ranges (list[tuple] | None): Range of scales to be evaluated, + in the format [(min1, max1), (min2, max2), ...]. A range of + (32, 64) means the area range between (32**2, 64**2). + Default: None. + iou_thr (float): IoU threshold to be considered as matched. + Default: 0.5. + use_07_metric (bool): Whether to use the voc07 metric. + dataset (list[str] | str | None): Dataset name or dataset classes, + there are minor differences in metrics for different datasets, e.g. + "voc07", "imagenet_det", etc. Default: None. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `mmcv.utils.print_log()` for details. Default: None. + nproc (int): Processes used for computing TP and FP. + Default: 4. + + Returns: + tuple: (mAP, [dict, dict, ...]) + """ + assert len(det_results) == len(annotations) + + num_imgs = len(det_results) + num_scales = len(scale_ranges) if scale_ranges is not None else 1 + num_classes = len(det_results[0]) # positive class num + area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges] + if scale_ranges is not None else None) + + pool = get_context('spawn').Pool(nproc) + eval_results = [] + for i in range(num_classes): + # get gt and det bboxes of this class + cls_dets, cls_gts, cls_gts_ignore = get_cls_results( + det_results, annotations, i) + + # compute tp and fp for each image with multiple processes + tpfp = pool.starmap( + tpfp_default, + zip(cls_dets, cls_gts, cls_gts_ignore, + [iou_thr for _ in range(num_imgs)], + [area_ranges for _ in range(num_imgs)])) + tp, fp = tuple(zip(*tpfp)) + # calculate gt number of each scale + # ignored gts or gts beyond the specific scale are not counted + num_gts = np.zeros(num_scales, dtype=int) + for _, bbox in enumerate(cls_gts): + if area_ranges is None: + num_gts[0] += bbox.shape[0] + else: + gt_areas = bbox[:, 2] * bbox[:, 3] + for k, (min_area, max_area) in enumerate(area_ranges): + num_gts[k] += np.sum((gt_areas >= min_area) + & (gt_areas < max_area)) + # sort all det bboxes by score, also sort tp and fp + cls_dets = np.vstack(cls_dets) + num_dets = cls_dets.shape[0] + sort_inds = np.argsort(-cls_dets[:, -1]) + tp = np.hstack(tp)[:, sort_inds] + fp = np.hstack(fp)[:, sort_inds] + # calculate recall and precision with tp and fp + tp = np.cumsum(tp, axis=1) + fp = np.cumsum(fp, axis=1) + eps = np.finfo(np.float32).eps + recalls = tp / np.maximum(num_gts[:, np.newaxis], eps) + precisions = tp / np.maximum((tp + fp), eps) + # calculate AP + if scale_ranges is None: + recalls = recalls[0, :] + precisions = precisions[0, :] + num_gts = num_gts.item() + mode = 'area' if not use_07_metric else '11points' + ap = average_precision(recalls, precisions, mode) + eval_results.append({ + 'num_gts': num_gts, + 'num_dets': num_dets, + 'recall': recalls, + 'precision': precisions, + 'ap': ap + }) + pool.close() + if scale_ranges is not None: + # shape (num_classes, num_scales) + all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results]) + all_num_gts = np.vstack( + [cls_result['num_gts'] for cls_result in eval_results]) + mean_ap = [] + for i in range(num_scales): + if np.any(all_num_gts[:, i] > 0): + mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean()) + else: + mean_ap.append(0.0) + else: + aps = [] + for cls_result in eval_results: + if cls_result['num_gts'] > 0: + aps.append(cls_result['ap']) + mean_ap = np.array(aps).mean().item() if aps else 0.0 + + print_map_summary( + mean_ap, eval_results, dataset, area_ranges, logger=logger) + + return mean_ap, eval_results + + +def print_map_summary(mean_ap, + results, + dataset=None, + scale_ranges=None, + logger=None): + """Print mAP and results of each class. + + A table will be printed to show the gts/dets/recall/AP of each class and + the mAP. + Args: + mean_ap (float): Calculated from `eval_map()`. + results (list[dict]): Calculated from `eval_map()`. + dataset (list[str] | str | None): Dataset name or dataset classes. + scale_ranges (list[tuple] | None): Range of scales to be evaluated. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `mmcv.utils.print_log()` for details. Default: None. + """ + + if logger == 'silent': + return + + if isinstance(results[0]['ap'], np.ndarray): + num_scales = len(results[0]['ap']) + else: + num_scales = 1 + + if scale_ranges is not None: + assert len(scale_ranges) == num_scales + + num_classes = len(results) + + recalls = np.zeros((num_scales, num_classes), dtype=np.float32) + aps = np.zeros((num_scales, num_classes), dtype=np.float32) + num_gts = np.zeros((num_scales, num_classes), dtype=int) + for i, cls_result in enumerate(results): + if cls_result['recall'].size > 0: + recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1] + aps[:, i] = cls_result['ap'] + num_gts[:, i] = cls_result['num_gts'] + + if dataset is None: + label_names = [str(i) for i in range(num_classes)] + else: + label_names = dataset + + if not isinstance(mean_ap, list): + mean_ap = [mean_ap] + + header = ['class', 'gts', 'dets', 'recall', 'ap'] + for i in range(num_scales): + if scale_ranges is not None: + print_log(f'Scale range {scale_ranges[i]}', logger=logger) + table_data = [header] + for j in range(num_classes): + row_data = [ + label_names[j], num_gts[i, j], results[j]['num_dets'], + f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}' + ] + table_data.append(row_data) + table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}']) + table = AsciiTable(table_data) + table.inner_footing_row_border = True + print_log('\n' + table.table, logger=logger) diff --git a/mmrotate/datasets/dota.py b/mmrotate/datasets/dota.py index dac5df710..5b504051d 100644 --- a/mmrotate/datasets/dota.py +++ b/mmrotate/datasets/dota.py @@ -8,18 +8,15 @@ import zipfile from collections import defaultdict from functools import partial -from multiprocessing import get_context import mmcv import numpy as np import torch -from mmcv.ops import box_iou_rotated, nms_rotated -from mmcv.utils import print_log -from mmdet.core.evaluation import average_precision +from mmcv.ops import nms_rotated from mmdet.datasets.custom import CustomDataset -from terminaltables import AsciiTable from mmrotate.core import obb2poly_np, poly2obb_np +from mmrotate.core.evaluation import eval_rbbox_map from .builder import ROTATED_DATASETS @@ -196,7 +193,7 @@ def evaluate(self, eval_results = {} if metric == 'mAP': assert isinstance(iou_thr, float) - mean_ap, _ = eval_map( + mean_ap, _ = eval_rbbox_map( results, annotations, scale_ranges=scale_ranges, @@ -330,311 +327,6 @@ def format_results(self, results, submission_dir=None, nproc=4, **kwargs): return result_files, tmp_dir -def eval_map(det_results, - annotations, - scale_ranges=None, - iou_thr=0.5, - dataset=None, - logger=None, - nproc=4): - """Evaluate mAP of a dataset. - - Args: - det_results (list[list]): [[cls1_det, cls2_det, ...], ...]. - The outer list indicates images, and the inner list indicates - per-class detected bboxes. - annotations (list[dict]): Ground truth annotations where each item of - the list indicates an image. Keys of annotations are: - - - `bboxes`: numpy array of shape (n, 4) - - `labels`: numpy array of shape (n, ) - - `bboxes_ignore` (optional): numpy array of shape (k, 4) - - `labels_ignore` (optional): numpy array of shape (k, ) - scale_ranges (list[tuple] | None): Range of scales to be evaluated, - in the format [(min1, max1), (min2, max2), ...]. A range of - (32, 64) means the area range between (32**2, 64**2). - Default: None. - iou_thr (float): IoU threshold to be considered as matched. - Default: 0.5. - dataset (list[str] | str | None): Dataset name or dataset classes, - there are minor differences in metrics for different datasets, e.g. - "voc07", "imagenet_det", etc. Default: None. - logger (logging.Logger | str | None): The way to print the mAP - summary. See `mmcv.utils.print_log()` for details. Default: None. - tpfp_fn (callable | None): The function used to determine true/ - false positives. If None, :func:`tpfp_default` is used as default - unless dataset is 'det' or 'vid' (:func:`tpfp_imagenet` in this - case). If it is given as a function, then this function is used - to evaluate tp & fp. Default None. - nproc (int): Processes used for computing TP and FP. - Default: 4. - - Returns: - tuple: (mAP, [dict, dict, ...]) - """ - assert len(det_results) == len(annotations) - - num_imgs = len(det_results) - num_scales = len(scale_ranges) if scale_ranges is not None else 1 - num_classes = len(det_results[0]) # positive class num - area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges] - if scale_ranges is not None else None) - - pool = get_context('spawn').Pool(nproc) - eval_results = [] - for i in range(num_classes): - # get gt and det bboxes of this class - cls_dets, cls_gts, cls_gts_ignore = get_cls_results( - det_results, annotations, i) - - # compute tp and fp for each image with multiple processes - tpfp = pool.starmap( - tpfp_default, - zip(cls_dets, cls_gts, cls_gts_ignore, - [iou_thr for _ in range(num_imgs)], - [area_ranges for _ in range(num_imgs)])) - tp, fp = tuple(zip(*tpfp)) - # calculate gt number of each scale - # ignored gts or gts beyond the specific scale are not counted - num_gts = np.zeros(num_scales, dtype=int) - for _, bbox in enumerate(cls_gts): - if area_ranges is None: - num_gts[0] += bbox.shape[0] - else: - gt_areas = (bbox[:, 2] - bbox[:, 0]) * ( - bbox[:, 3] - bbox[:, 1]) - for k, (min_area, max_area) in enumerate(area_ranges): - num_gts[k] += np.sum((gt_areas >= min_area) - & (gt_areas < max_area)) - # sort all det bboxes by score, also sort tp and fp - cls_dets = np.vstack(cls_dets) - num_dets = cls_dets.shape[0] - sort_inds = np.argsort(-cls_dets[:, -1]) - tp = np.hstack(tp)[:, sort_inds] - fp = np.hstack(fp)[:, sort_inds] - # calculate recall and precision with tp and fp - tp = np.cumsum(tp, axis=1) - fp = np.cumsum(fp, axis=1) - eps = np.finfo(np.float32).eps - recalls = tp / np.maximum(num_gts[:, np.newaxis], eps) - precisions = tp / np.maximum((tp + fp), eps) - # calculate AP - if scale_ranges is None: - recalls = recalls[0, :] - precisions = precisions[0, :] - num_gts = num_gts.item() - mode = 'area' if dataset != 'voc07' else '11points' - ap = average_precision(recalls, precisions, mode) - eval_results.append({ - 'num_gts': num_gts, - 'num_dets': num_dets, - 'recall': recalls, - 'precision': precisions, - 'ap': ap - }) - pool.close() - if scale_ranges is not None: - all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results]) - all_num_gts = np.vstack( - [cls_result['num_gts'] for cls_result in eval_results]) - mean_ap = [] - for i in range(num_scales): - if np.any(all_num_gts[:, i] > 0): - mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean()) - else: - mean_ap.append(0.0) - else: - aps = [] - for cls_result in eval_results: - if cls_result['num_gts'] > 0: - aps.append(cls_result['ap']) - mean_ap = np.array(aps).mean().item() if aps else 0.0 - - print_map_summary( - mean_ap, eval_results, dataset, area_ranges, logger=logger) - - return mean_ap, eval_results - - -def print_map_summary(mean_ap, - results, - dataset=None, - scale_ranges=None, - logger=None): - """Print mAP and results of each class. - - A table will be printed to show the gts/dets/recall/AP of each class and - the mAP. - - Args: - mean_ap (float): Calculated from `eval_map()`. - results (list[dict]): Calculated from `eval_map()`. - dataset (list[str] | str | None): Dataset name or dataset classes. - scale_ranges (list[tuple] | None): Range of scales to be evaluated. - logger (logging.Logger | str | None): The way to print the mAP - summary. See `mmcv.utils.print_log()` for details. Default: None. - """ - - if logger == 'silent': - return - - if isinstance(results[0]['ap'], np.ndarray): - num_scales = len(results[0]['ap']) - else: - num_scales = 1 - - if scale_ranges is not None: - assert len(scale_ranges) == num_scales - - num_classes = len(results) - - recalls = np.zeros((num_scales, num_classes), dtype=np.float32) - aps = np.zeros((num_scales, num_classes), dtype=np.float32) - num_gts = np.zeros((num_scales, num_classes), dtype=int) - for i, cls_result in enumerate(results): - if cls_result['recall'].size > 0: - recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1] - aps[:, i] = cls_result['ap'] - num_gts[:, i] = cls_result['num_gts'] - - if dataset is None: - label_names = [str(i) for i in range(num_classes)] - else: - label_names = dataset - - if not isinstance(mean_ap, list): - mean_ap = [mean_ap] - - header = ['class', 'gts', 'dets', 'recall', 'ap'] - for i in range(num_scales): - if scale_ranges is not None: - print_log(f'Scale range {scale_ranges[i]}', logger=logger) - table_data = [header] - for j in range(num_classes): - row_data = [ - label_names[j], num_gts[i, j], results[j]['num_dets'], - f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}' - ] - table_data.append(row_data) - table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}']) - table = AsciiTable(table_data) - table.inner_footing_row_border = True - print_log('\n' + table.table, logger=logger) - - -def tpfp_default(det_bboxes, - gt_bboxes, - gt_bboxes_ignore=None, - iou_thr=0.5, - area_ranges=None): - """Check if detected bboxes are true positive or false positive. - - Args: - det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 9). - gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 8). - gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image, - of shape (k, 8). Default: None - iou_thr (float): IoU threshold to be considered as matched. - Default: 0.5. - area_ranges (list[tuple] | None): Range of bbox areas to be evaluated, - in the format [(min1, max1), (min2, max2), ...]. Default: None. - - Returns: - tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of - each array is (num_scales, m). - """ - # an indicator of ignored gts - det_bboxes = np.array(det_bboxes) - gt_ignore_inds = np.concatenate( - (np.zeros(gt_bboxes.shape[0], dtype=np.bool), - np.ones(gt_bboxes_ignore.shape[0], dtype=np.bool))) - # stack gt_bboxes and gt_bboxes_ignore for convenience - gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore)) - num_dets = det_bboxes.shape[0] - num_gts = gt_bboxes.shape[0] - if area_ranges is None: - area_ranges = [(None, None)] - num_scales = len(area_ranges) - # tp and fp are of shape (num_scales, num_gts), each row is tp or fp of - # a certain scale - tp = np.zeros((num_scales, num_dets), dtype=np.float32) - fp = np.zeros((num_scales, num_dets), dtype=np.float32) - - # if there is no gt bboxes in this image, then all det bboxes - # within area range are false positives - if gt_bboxes.shape[0] == 0: - if area_ranges == [(None, None)]: - fp[...] = 1 - else: - raise NotImplementedError - return tp, fp - - ious = box_iou_rotated( - torch.from_numpy(det_bboxes).float(), - torch.from_numpy(gt_bboxes).float()).numpy() - # for each det, the max iou with all gts - ious_max = ious.max(axis=1) - # for each det, which gt overlaps most with it - ious_argmax = ious.argmax(axis=1) - # sort all dets in descending order by scores - sort_inds = np.argsort(-det_bboxes[:, -1]) - for k, (min_area, max_area) in enumerate(area_ranges): - gt_covered = np.zeros(num_gts, dtype=bool) - # if no area range is specified, gt_area_ignore is all False - if min_area is None: - gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool) - else: - raise NotImplementedError - for i in sort_inds: - if ious_max[i] >= iou_thr: - matched_gt = ious_argmax[i] - if not (gt_ignore_inds[matched_gt] - or gt_area_ignore[matched_gt]): - if not gt_covered[matched_gt]: - gt_covered[matched_gt] = True - tp[k, i] = 1 - else: - fp[k, i] = 1 - # otherwise ignore this detected bbox, tp = 0, fp = 0 - elif min_area is None: - fp[k, i] = 1 - else: - bbox = det_bboxes[i, :5] - area = bbox[2] * bbox[3] - if area >= min_area and area < max_area: - fp[k, i] = 1 - return tp, fp - - -def get_cls_results(det_results, annotations, class_id): - """Get det results and gt information of a certain class. - - Args: - det_results (list[list]): Same as `eval_map()`. - annotations (list[dict]): Same as `eval_map()`. - class_id (int): ID of a specific class. - - Returns: - tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes - """ - cls_dets = [img_res[class_id] for img_res in det_results] - - cls_gts = [] - cls_gts_ignore = [] - for ann in annotations: - gt_inds = ann['labels'] == class_id - cls_gts.append(ann['bboxes'][gt_inds, :]) - - if ann.get('labels_ignore', None) is not None: - ignore_inds = ann['labels_ignore'] == class_id - cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :]) - - else: - cls_gts_ignore.append(torch.zeros((0, 6), dtype=torch.float64)) - - return cls_dets, cls_gts, cls_gts_ignore - - def _merge_func(info, CLASSES, iou_thr): """Merging patch bboxes into full image.