From e1c41805fd50e2b7792d163ac8418709b5d00479 Mon Sep 17 00:00:00 2001 From: liuyanyi Date: Sat, 26 Feb 2022 14:06:27 +0800 Subject: [PATCH] Add hrsc dataset. Move eval_map to mmrotate.core.evaluation --- configs/_base_/datasets/hrsc.py | 53 +++ ...ated_retinanet_obb_r50_fpn_3x_hrsc_le90.py | 93 ++++++ mmrotate/core/__init__.py | 1 + mmrotate/core/evaluation/__init__.py | 5 + mmrotate/core/evaluation/eval_map.py | 311 ++++++++++++++++++ mmrotate/core/evaluation/recall.py | 187 +++++++++++ mmrotate/datasets/__init__.py | 3 +- mmrotate/datasets/hrsc.py | 264 +++++++++++++++ 8 files changed, 916 insertions(+), 1 deletion(-) create mode 100644 configs/_base_/datasets/hrsc.py create mode 100644 configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_3x_hrsc_le90.py create mode 100644 mmrotate/core/evaluation/__init__.py create mode 100644 mmrotate/core/evaluation/eval_map.py create mode 100644 mmrotate/core/evaluation/recall.py create mode 100644 mmrotate/datasets/hrsc.py diff --git a/configs/_base_/datasets/hrsc.py b/configs/_base_/datasets/hrsc.py new file mode 100644 index 000000000..10e5ba406 --- /dev/null +++ b/configs/_base_/datasets/hrsc.py @@ -0,0 +1,53 @@ +# dataset settings +dataset_type = 'HRSCDataset' +data_root = '/datasets/hrsc/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RResize', img_scale=(800, 800)), + dict(type='RRandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(800, 800), + flip=False, + transforms=[ + dict(type='RResize'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + ann_file=data_root + 'ImageSets/trainval.txt', + classwise=False, + ann_subdir=data_root + 'FullDataSet/Annotations/', + img_subdir=data_root + 'FullDataSet/AllImages/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=data_root + 'ImageSets/trainval.txt', + classwise=False, + ann_subdir=data_root + 'FullDataSet/Annotations/', + img_subdir=data_root + 'FullDataSet/AllImages/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=data_root + 'ImageSets/trainval.txt', + classwise=False, + ann_subdir=data_root + 'FullDataSet/Annotations/', + img_subdir=data_root + 'FullDataSet/AllImages/', + pipeline=test_pipeline)) diff --git a/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_3x_hrsc_le90.py b/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_3x_hrsc_le90.py new file mode 100644 index 000000000..394160765 --- /dev/null +++ b/configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_3x_hrsc_le90.py @@ -0,0 +1,93 @@ +_base_ = [ + '../_base_/datasets/hrsc.py', '../_base_/schedules/schedule_3x.py', + '../_base_/default_runtime.py' +] +fp16 = dict(loss_scale=dict(init_scale=512)) + +angle_version = 'le90' +model = dict( + type='RotatedRetinaNet', + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + zero_init_residual=False, + norm_cfg=dict(type='BN', requires_grad=True), + 
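+        # BN stays in eval mode (norm_eval below) so the pretrained
+        # ImageNet statistics are kept frozen during fine-tuning.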
norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + start_level=1, + add_extra_convs='on_input', + num_outs=5), + bbox_head=dict( + type='RotatedRetinaHead', + num_classes=31, + in_channels=256, + stacked_convs=4, + feat_channels=256, + assign_by_circumhbbox=None, + anchor_generator=dict( + type='RotatedAnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[1.0, 0.5, 2.0], + strides=[8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHAOBBoxCoder', + angle_range=angle_version, + norm_factor=None, + edge_swap=True, + proj_xy=True, + target_means=(.0, .0, .0, .0, .0), + target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + train_cfg=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.4, + min_pos_iou=0, + ignore_iof_thr=-1, + iou_calculator=dict(type='RBboxOverlaps2D')), + allowed_border=-1, + pos_weight=-1, + debug=False), + test_cfg=dict( + nms_pre=2000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(iou_thr=0.1), + max_per_img=2000)) + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='RResize', img_scale=(800, 800)), + dict( + type='RRandomFlip', + flip_ratio=[0.25, 0.25, 0.25], + direction=['horizontal', 'vertical', 'diagonal'], + version=angle_version), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +data = dict( + train=dict(pipeline=train_pipeline, version=angle_version), + val=dict(version=angle_version), + test=dict(version=angle_version)) diff --git a/mmrotate/core/__init__.py b/mmrotate/core/__init__.py index 126eb69ea..c8164cf94 100644 --- a/mmrotate/core/__init__.py +++ b/mmrotate/core/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .anchor import * # noqa: F401, F403 from .bbox import * # noqa: F401, F403 +from .evaluation import * # noqa: F401, F403 from .patch import * # noqa: F401, F403 from .post_processing import * # noqa: F401, F403 from .visualization import * # noqa: F401, F403 diff --git a/mmrotate/core/evaluation/__init__.py b/mmrotate/core/evaluation/__init__.py new file mode 100644 index 000000000..702cc570c --- /dev/null +++ b/mmrotate/core/evaluation/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .eval_map import eval_map +from .recall import eval_recalls + +__all__ = ['eval_map', 'eval_recalls'] diff --git a/mmrotate/core/evaluation/eval_map.py b/mmrotate/core/evaluation/eval_map.py new file mode 100644 index 000000000..1fbe4a460 --- /dev/null +++ b/mmrotate/core/evaluation/eval_map.py @@ -0,0 +1,311 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from multiprocessing import Pool + +import numpy as np +import torch +from mmcv.ops import box_iou_rotated +from mmcv.utils import print_log +from mmdet.core import average_precision +from terminaltables import AsciiTable + + +def tpfp_default(det_bboxes, + gt_bboxes, + gt_bboxes_ignore=None, + iou_thr=0.5, + area_ranges=None): + """Check if detected bboxes are true positive or false positive. 
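+
+    Each detection is taken in descending score order and compared with its
+    highest-rotated-IoU ground truth: if that gt is not yet covered and the
+    IoU reaches ``iou_thr``, the detection is a true positive; otherwise it
+    is a false positive.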
+
+    Args:
+        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 6),
+            in (cx, cy, w, h, angle, score) format.
+        gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 5).
+        gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+            of shape (k, 5). Default: None
+        iou_thr (float): IoU threshold to be considered as matched.
+            Default: 0.5.
+        area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
+            in the format [(min1, max1), (min2, max2), ...]. Default: None.
+
+    Returns:
+        tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
+        each array is (num_scales, m).
+    """
+    # an indicator of ignored gts
+    det_bboxes = np.array(det_bboxes)
+    gt_ignore_inds = np.concatenate(
+        (np.zeros(gt_bboxes.shape[0], dtype=bool),
+         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+    # stack gt_bboxes and gt_bboxes_ignore for convenience
+    gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
+    num_dets = det_bboxes.shape[0]
+    num_gts = gt_bboxes.shape[0]
+    if area_ranges is None:
+        area_ranges = [(None, None)]
+    num_scales = len(area_ranges)
+    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
+    # a certain scale
+    tp = np.zeros((num_scales, num_dets), dtype=np.float32)
+    fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+
+    # if there is no gt bboxes in this image, then all det bboxes
+    # within area range are false positives
+    if gt_bboxes.shape[0] == 0:
+        if area_ranges == [(None, None)]:
+            fp[...] = 1
+        else:
+            raise NotImplementedError
+        return tp, fp
+
+    ious = box_iou_rotated(
+        torch.from_numpy(det_bboxes).float(),
+        torch.from_numpy(gt_bboxes).float()).numpy()
+    # for each det, the max iou with all gts
+    ious_max = ious.max(axis=1)
+    # for each det, which gt overlaps most with it
+    ious_argmax = ious.argmax(axis=1)
+    # sort all dets in descending order by scores
+    sort_inds = np.argsort(-det_bboxes[:, -1])
+    for k, (min_area, max_area) in enumerate(area_ranges):
+        gt_covered = np.zeros(num_gts, dtype=bool)
+        # if no area range is specified, gt_area_ignore is all False
+        if min_area is None:
+            gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
+        else:
+            raise NotImplementedError
+        for i in sort_inds:
+            if ious_max[i] >= iou_thr:
+                matched_gt = ious_argmax[i]
+                if not (gt_ignore_inds[matched_gt]
+                        or gt_area_ignore[matched_gt]):
+                    if not gt_covered[matched_gt]:
+                        gt_covered[matched_gt] = True
+                        tp[k, i] = 1
+                    else:
+                        fp[k, i] = 1
+                # otherwise ignore this detected bbox, tp = 0, fp = 0
+            elif min_area is None:
+                fp[k, i] = 1
+            else:
+                bbox = det_bboxes[i, :5]
+                area = bbox[2] * bbox[3]
+                if area >= min_area and area < max_area:
+                    fp[k, i] = 1
+    return tp, fp
+
+
+def get_cls_results(det_results, annotations, class_id):
+    """Get det results and gt information of a certain class.
+
+    Args:
+        det_results (list[list]): Same as `eval_map()`.
+        annotations (list[dict]): Same as `eval_map()`.
+        class_id (int): ID of a specific class.
+
+    Returns:
+        tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes
+    """
+    cls_dets = [img_res[class_id] for img_res in det_results]
+
+    cls_gts = []
+    cls_gts_ignore = []
+    for ann in annotations:
+        gt_inds = ann['labels'] == class_id
+        cls_gts.append(ann['bboxes'][gt_inds, :])
+
+        if ann.get('labels_ignore', None) is not None:
+            ignore_inds = ann['labels_ignore'] == class_id
+            cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
+        else:
+            cls_gts_ignore.append(np.zeros((0, 5), dtype=np.float32))
+
+    return cls_dets, cls_gts, cls_gts_ignore
+
+
+def eval_map(det_results,
+             annotations,
+             scale_ranges=None,
+             iou_thr=0.5,
+             use_07_metric=True,
+             dataset=None,
+             logger=None,
+             nproc=4):
+    """Evaluate mAP of a dataset.
+
+    Args:
+        det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
+            The outer list indicates images, and the inner list indicates
+            per-class detected bboxes.
+        annotations (list[dict]): Ground truth annotations where each item of
+            the list indicates an image. Keys of annotations are:
+
+            - `bboxes`: numpy array of shape (n, 5)
+            - `labels`: numpy array of shape (n, )
+            - `bboxes_ignore` (optional): numpy array of shape (k, 5)
+            - `labels_ignore` (optional): numpy array of shape (k, )
+        scale_ranges (list[tuple] | None): Range of scales to be evaluated,
+            in the format [(min1, max1), (min2, max2), ...]. A range of
+            (32, 64) means the area range between (32**2, 64**2).
+            Default: None.
+        iou_thr (float): IoU threshold to be considered as matched.
+            Default: 0.5.
+        use_07_metric (bool): Whether to use the voc07 metric. Default: True.
+        dataset (list[str] | str | None): Dataset name or dataset classes,
+            there are minor differences in metrics for different datasets, e.g.
+            "voc07", "imagenet_det", etc. Default: None.
+        logger (logging.Logger | str | None): The way to print the mAP
+            summary. See `mmcv.utils.print_log()` for details. Default: None.
+        nproc (int): Processes used for computing TP and FP.
+            Default: 4.
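+
+    Note:
+        With ``use_07_metric=True`` the per-class AP is computed with the
+        VOC07 11-point interpolation; otherwise the exact area under the
+        precision-recall curve is used.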
+ + Returns: + tuple: (mAP, [dict, dict, ...]) + """ + assert len(det_results) == len(annotations) + + num_imgs = len(det_results) + num_scales = len(scale_ranges) if scale_ranges is not None else 1 + num_classes = len(det_results[0]) # positive class num + area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges] + if scale_ranges is not None else None) + + pool = Pool(nproc) + eval_results = [] + for i in range(num_classes): + # get gt and det bboxes of this class + cls_dets, cls_gts, cls_gts_ignore = get_cls_results( + det_results, annotations, i) + + # compute tp and fp for each image with multiple processes + tpfp = pool.starmap( + tpfp_default, + zip(cls_dets, cls_gts, cls_gts_ignore, + [iou_thr for _ in range(num_imgs)], + [area_ranges for _ in range(num_imgs)])) + tp, fp = tuple(zip(*tpfp)) + # calculate gt number of each scale + # ignored gts or gts beyond the specific scale are not counted + num_gts = np.zeros(num_scales, dtype=int) + for _, bbox in enumerate(cls_gts): + if area_ranges is None: + num_gts[0] += bbox.shape[0] + else: + gt_areas = bbox[:, 2] * bbox[:, 3] + for k, (min_area, max_area) in enumerate(area_ranges): + num_gts[k] += np.sum((gt_areas >= min_area) + & (gt_areas < max_area)) + # sort all det bboxes by score, also sort tp and fp + cls_dets = np.vstack(cls_dets) + num_dets = cls_dets.shape[0] + sort_inds = np.argsort(-cls_dets[:, -1]) + tp = np.hstack(tp)[:, sort_inds] + fp = np.hstack(fp)[:, sort_inds] + # calculate recall and precision with tp and fp + tp = np.cumsum(tp, axis=1) + fp = np.cumsum(fp, axis=1) + eps = np.finfo(np.float32).eps + recalls = tp / np.maximum(num_gts[:, np.newaxis], eps) + precisions = tp / np.maximum((tp + fp), eps) + # calculate AP + if scale_ranges is None: + recalls = recalls[0, :] + precisions = precisions[0, :] + num_gts = num_gts.item() + mode = 'area' if not use_07_metric else '11points' + ap = average_precision(recalls, precisions, mode) + eval_results.append({ + 'num_gts': num_gts, + 'num_dets': num_dets, + 'recall': recalls, + 'precision': precisions, + 'ap': ap + }) + pool.close() + if scale_ranges is not None: + # shape (num_classes, num_scales) + all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results]) + all_num_gts = np.vstack( + [cls_result['num_gts'] for cls_result in eval_results]) + mean_ap = [] + for i in range(num_scales): + if np.any(all_num_gts[:, i] > 0): + mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean()) + else: + mean_ap.append(0.0) + else: + aps = [] + for cls_result in eval_results: + if cls_result['num_gts'] > 0: + aps.append(cls_result['ap']) + mean_ap = np.array(aps).mean().item() if aps else 0.0 + + print_map_summary( + mean_ap, eval_results, dataset, area_ranges, logger=logger) + + return mean_ap, eval_results + + +def print_map_summary(mean_ap, + results, + dataset=None, + scale_ranges=None, + logger=None): + """Print mAP and results of each class. + + A table will be printed to show the gts/dets/recall/AP of each class and + the mAP. + Args: + mean_ap (float): Calculated from `eval_map()`. + results (list[dict]): Calculated from `eval_map()`. + dataset (list[str] | str | None): Dataset name or dataset classes. + scale_ranges (list[tuple] | None): Range of scales to be evaluated. + logger (logging.Logger | str | None): The way to print the mAP + summary. See `mmcv.utils.print_log()` for details. Default: None. 
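+
+    Example:
+        >>> # minimal single-class input; all numbers are illustrative
+        >>> results = [dict(num_gts=1, num_dets=1, ap=1.0,
+        ...                 recall=np.array([1.0]),
+        ...                 precision=np.array([1.0]))]
+        >>> print_map_summary(1.0, results, dataset=['ship'])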
+ """ + + if logger == 'silent': + return + + if isinstance(results[0]['ap'], np.ndarray): + num_scales = len(results[0]['ap']) + else: + num_scales = 1 + + if scale_ranges is not None: + assert len(scale_ranges) == num_scales + + num_classes = len(results) + + recalls = np.zeros((num_scales, num_classes), dtype=np.float32) + aps = np.zeros((num_scales, num_classes), dtype=np.float32) + num_gts = np.zeros((num_scales, num_classes), dtype=int) + for i, cls_result in enumerate(results): + if cls_result['recall'].size > 0: + recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1] + aps[:, i] = cls_result['ap'] + num_gts[:, i] = cls_result['num_gts'] + + if dataset is None: + label_names = [str(i) for i in range(num_classes)] + else: + label_names = dataset + + if not isinstance(mean_ap, list): + mean_ap = [mean_ap] + + header = ['class', 'gts', 'dets', 'recall', 'ap'] + for i in range(num_scales): + if scale_ranges is not None: + print_log(f'Scale range {scale_ranges[i]}', logger=logger) + table_data = [header] + for j in range(num_classes): + row_data = [ + label_names[j], num_gts[i, j], results[j]['num_dets'], + f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}' + ] + table_data.append(row_data) + table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}']) + table = AsciiTable(table_data) + table.inner_footing_row_border = True + print_log('\n' + table.table, logger=logger) diff --git a/mmrotate/core/evaluation/recall.py b/mmrotate/core/evaluation/recall.py new file mode 100644 index 000000000..f58c2ed4a --- /dev/null +++ b/mmrotate/core/evaluation/recall.py @@ -0,0 +1,187 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections.abc import Sequence + +import numpy as np +from mmcv.utils import print_log +from terminaltables import AsciiTable + +from ..bbox import rbbox_overlaps + + +def _recalls(all_ious, proposal_nums, thrs): + img_num = all_ious.shape[0] + total_gt_num = sum([ious.shape[0] for ious in all_ious]) + + _ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32) + for k, proposal_num in enumerate(proposal_nums): + tmp_ious = np.zeros(0) + for i in range(img_num): + ious = all_ious[i][:, :proposal_num].copy() + gt_ious = np.zeros((ious.shape[0])) + if ious.size == 0: + tmp_ious = np.hstack((tmp_ious, gt_ious)) + continue + for j in range(ious.shape[0]): + gt_max_overlaps = ious.argmax(axis=1) + max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps] + gt_idx = max_ious.argmax() + gt_ious[j] = max_ious[gt_idx] + box_idx = gt_max_overlaps[gt_idx] + ious[gt_idx, :] = -1 + ious[:, box_idx] = -1 + tmp_ious = np.hstack((tmp_ious, gt_ious)) + _ious[k, :] = tmp_ious + + _ious = np.fliplr(np.sort(_ious, axis=1)) + recalls = np.zeros((proposal_nums.size, thrs.size)) + for i, thr in enumerate(thrs): + recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num) + + return recalls + + +def set_recall_param(proposal_nums, iou_thrs): + """Check proposal_nums and iou_thrs and set correct format.""" + if isinstance(proposal_nums, Sequence): + _proposal_nums = np.array(proposal_nums) + elif isinstance(proposal_nums, int): + _proposal_nums = np.array([proposal_nums]) + else: + _proposal_nums = proposal_nums + + if iou_thrs is None: + _iou_thrs = np.array([0.5]) + elif isinstance(iou_thrs, Sequence): + _iou_thrs = np.array(iou_thrs) + elif isinstance(iou_thrs, float): + _iou_thrs = np.array([iou_thrs]) + else: + _iou_thrs = iou_thrs + + return _proposal_nums, _iou_thrs + + +def eval_recalls(gts, + proposals, + proposal_nums=None, + iou_thrs=0.5, + logger=None): + 
"""Calculate recalls. + + Args: + gts (list[ndarray]): a list of arrays of shape (n, 4) + proposals (list[ndarray]): a list of arrays of shape (k, 4) or (k, 5) + proposal_nums (int | Sequence[int]): Top N proposals to be evaluated. + iou_thrs (float | Sequence[float]): IoU thresholds. Default: 0.5. + logger (logging.Logger | str | None): The way to print the recall + summary. See `mmcv.utils.print_log()` for details. Default: None. + + Returns: + ndarray: recalls of different ious and proposal nums + """ + + img_num = len(gts) + assert img_num == len(proposals) + proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs) + all_ious = [] + for i in range(img_num): + if proposals[i].ndim == 2 and proposals[i].shape[1] == 6: + scores = proposals[i][:, 5] + sort_idx = np.argsort(scores)[::-1] + img_proposal = proposals[i][sort_idx, :] + else: + img_proposal = proposals[i] + prop_num = min(img_proposal.shape[0], proposal_nums[-1]) + if gts[i] is None or gts[i].shape[0] == 0: + ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32) + else: + ious = rbbox_overlaps(gts[i], img_proposal[:prop_num, :5]) + all_ious.append(ious) + all_ious = np.array(all_ious) + recalls = _recalls(all_ious, proposal_nums, iou_thrs) + + print_recall_summary(recalls, proposal_nums, iou_thrs, logger=logger) + return recalls + + +def print_recall_summary(recalls, + proposal_nums, + iou_thrs, + row_idxs=None, + col_idxs=None, + logger=None): + """Print recalls in a table. + + Args: + recalls (ndarray): calculated from `bbox_recalls` + proposal_nums (ndarray or list): top N proposals + iou_thrs (ndarray or list): iou thresholds + row_idxs (ndarray): which rows(proposal nums) to print + col_idxs (ndarray): which cols(iou thresholds) to print + logger (logging.Logger | str | None): The way to print the recall + summary. See `mmcv.utils.print_log()` for details. Default: None. + """ + proposal_nums = np.array(proposal_nums, dtype=np.int32) + iou_thrs = np.array(iou_thrs) + if row_idxs is None: + row_idxs = np.arange(proposal_nums.size) + if col_idxs is None: + col_idxs = np.arange(iou_thrs.size) + row_header = [''] + iou_thrs[col_idxs].tolist() + table_data = [row_header] + for i, num in enumerate(proposal_nums[row_idxs]): + row = [f'{val:.3f}' for val in recalls[row_idxs[i], col_idxs].tolist()] + row.insert(0, num) + table_data.append(row) + table = AsciiTable(table_data) + print_log('\n' + table.table, logger=logger) + + +def plot_num_recall(recalls, proposal_nums): + """Plot Proposal_num-Recalls curve. + + Args: + recalls(ndarray or list): shape (k,) + proposal_nums(ndarray or list): same shape as `recalls` + """ + if isinstance(proposal_nums, np.ndarray): + _proposal_nums = proposal_nums.tolist() + else: + _proposal_nums = proposal_nums + if isinstance(recalls, np.ndarray): + _recalls = recalls.tolist() + else: + _recalls = recalls + + import matplotlib.pyplot as plt + f = plt.figure() + plt.plot([0] + _proposal_nums, [0] + _recalls) + plt.xlabel('Proposal num') + plt.ylabel('Recall') + plt.axis([0, proposal_nums.max(), 0, 1]) + f.show() + + +def plot_iou_recall(recalls, iou_thrs): + """Plot IoU-Recalls curve. 
+ + Args: + recalls(ndarray or list): shape (k,) + iou_thrs(ndarray or list): same shape as `recalls` + """ + if isinstance(iou_thrs, np.ndarray): + _iou_thrs = iou_thrs.tolist() + else: + _iou_thrs = iou_thrs + if isinstance(recalls, np.ndarray): + _recalls = recalls.tolist() + else: + _recalls = recalls + + import matplotlib.pyplot as plt + f = plt.figure() + plt.plot(_iou_thrs + [1.0], _recalls + [0.]) + plt.xlabel('IoU') + plt.ylabel('Recall') + plt.axis([iou_thrs.min(), 1, 0, 1]) + f.show() diff --git a/mmrotate/datasets/__init__.py b/mmrotate/datasets/__init__.py index 559c42c34..05dd2849c 100644 --- a/mmrotate/datasets/__init__.py +++ b/mmrotate/datasets/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .builder import build_dataset # noqa: F401, F403 from .dota import DOTADataset # noqa: F401, F403 +from .hrsc import HRSCDataset # noqa: F401, F403 from .pipelines import * # noqa: F401, F403 from .sar import SARDataset # noqa: F401, F403 -__all__ = ['SARDataset', 'DOTADataset', 'build_dataset'] +__all__ = ['SARDataset', 'DOTADataset', 'build_dataset', 'HRSCDataset'] diff --git a/mmrotate/datasets/hrsc.py b/mmrotate/datasets/hrsc.py new file mode 100644 index 000000000..035605f56 --- /dev/null +++ b/mmrotate/datasets/hrsc.py @@ -0,0 +1,264 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import xml.etree.ElementTree as ET +from collections import OrderedDict + +import mmcv +import numpy as np +from mmcv import print_log +from mmdet.datasets import CustomDataset +from PIL import Image + +from mmrotate.core.bbox import obb2poly_np, poly2obb_np +from mmrotate.core.evaluation import eval_map, eval_recalls +from .builder import ROTATED_DATASETS + + +@ROTATED_DATASETS.register_module() +class HRSCDataset(CustomDataset): + """HRSC dataset for detection. + + Args: + ann_file (str): Annotation file path. + pipeline (list[dict]): Processing pipeline. + img_subdir (str): Subdir where images are stored. Default: JPEGImages. + ann_subdir (str): Subdir where annotations are. Default: Annotations. + classwise (bool): Whether to use all classes or only ship. + version (str, optional): Angle representations. Defaults to 'oc'. 
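+
+    Example:
+        >>> # paths mirror configs/_base_/datasets/hrsc.py in this patch
+        >>> dataset = HRSCDataset(
+        ...     ann_file='/datasets/hrsc/ImageSets/trainval.txt',
+        ...     img_subdir='/datasets/hrsc/FullDataSet/AllImages/',
+        ...     ann_subdir='/datasets/hrsc/FullDataSet/Annotations/',
+        ...     classwise=False,
+        ...     version='le90',
+        ...     pipeline=[])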
+ """ + + CLASSES = None + HRSC_CLASS = ('ship', ) + HRSC_CLASSES = ('ship', 'aircraft carrier', 'warcraft', 'merchant ship', + 'Nimitz', 'Enterprise', 'Arleigh Burke', 'WhidbeyIsland', + 'Perry', 'Sanantonio', 'Ticonderoga', 'Kitty Hawk', + 'Kuznetsov', 'Abukuma', 'Austen', 'Tarawa', 'Blue Ridge', + 'Container', 'OXo|--)', 'Car carrier([]==[])', + 'Hovercraft', 'yacht', 'CntShip(_|.--.--|_]=', 'Cruise', + 'submarine', 'lute', 'Medical', 'Car carrier(======|', + 'Ford-class', 'Midway-class', 'Invincible-class') + HRSC_CLASSES_ID = ('01', '02', '03', '04', '05', '06', '07', '08', '09', + '10', '11', '12', '13', '14', '15', '16', '17', '18', + '19', '20', '22', '24', '25', '26', '27', '28', '29', + '30', '31', '32', '33') + + def __init__(self, + ann_file, + pipeline, + img_subdir='JPEGImages', + ann_subdir='Annotations', + classwise=False, + version='oc', + **kwargs): + self.img_subdir = img_subdir + self.ann_subdir = ann_subdir + self.classwise = classwise + self.version = version + if self.classwise: + HRSCDataset.CLASSES = self.HRSC_CLASSES + self.catid2label = { + ('1' + '0' * 6 + cls_id): i + for i, cls_id in enumerate(self.HRSC_CLASSES_ID) + } + else: + HRSCDataset.CLASSES = self.HRSC_CLASS + # self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)} + super(HRSCDataset, self).__init__(ann_file, pipeline, **kwargs) + + def load_annotations(self, ann_file): + """Load annotation from XML style ann_file. + + Args: + ann_file (str): Path of Imageset file. + + Returns: + list[dict]: Annotation info from XML file. + """ + + data_infos = [] + img_ids = mmcv.list_from_file(ann_file) + for img_id in img_ids: + data_info = {} + + filename = osp.join(self.img_subdir, f'{img_id}.bmp') + data_info['filename'] = filename + xml_path = osp.join(self.img_prefix, self.ann_subdir, + f'{img_id}.xml') + tree = ET.parse(xml_path) + root = tree.getroot() + + width = int(root.find('Img_SizeWidth').text) + height = int(root.find('Img_SizeHeight').text) + + if width is None or height is None: + img_path = osp.join(self.img_prefix, filename) + img = Image.open(img_path) + width, height = img.size + data_info['width'] = width + data_info['height'] = height + data_info['ann'] = {} + gt_bboxes = [] + gt_labels = [] + gt_polygons = [] + gt_headers = [] + gt_bboxes_ignore = [] + gt_labels_ignore = [] + gt_polygons_ignore = [] + gt_headers_ignore = [] + + for obj in root.findall('HRSC_Objects/HRSC_Object'): + if self.classwise: + class_id = obj.find('Class_ID').text + if class_id not in self.CLASSES_ID: + continue + label = self.catid2label[class_id] + else: + label = 0 + + try: + # Add an extra score to use obb2poly_np + bbox = np.array([[ + float(obj.find('mbox_cx').text), + float(obj.find('mbox_cy').text), + float(obj.find('mbox_w').text), + float(obj.find('mbox_h').text), + float(obj.find('mbox_ang').text), 0 + ]], + dtype=np.float32) + + polygon = obb2poly_np(bbox, + 'le90')[0, :-1].astype(np.float32) + if self.version != 'le90': + bbox = np.array( + poly2obb_np(polygon, self.version), + dtype=np.float32) + else: + bbox = bbox[0, :-1] + head = np.array([ + int(obj.find('header_x').text), + int(obj.find('header_y').text) + ], + dtype=np.int64) + except: # noqa: E722 + continue + + gt_bboxes.append(bbox) + gt_labels.append(label) + gt_polygons.append(polygon) + gt_headers.append(head) + + if gt_bboxes: + data_info['ann']['bboxes'] = np.array( + gt_bboxes, dtype=np.float32) + data_info['ann']['labels'] = np.array( + gt_labels, dtype=np.int64) + data_info['ann']['polygons'] = np.array( + gt_polygons, 
+                    dtype=np.float32)
+                data_info['ann']['headers'] = np.array(
+                    gt_headers, dtype=np.int64)
+            else:
+                data_info['ann']['bboxes'] = np.zeros((0, 5), dtype=np.float32)
+                data_info['ann']['labels'] = np.array([], dtype=np.int64)
+                data_info['ann']['polygons'] = np.zeros((0, 8),
+                                                        dtype=np.float32)
+                data_info['ann']['headers'] = np.zeros((0, 2),
+                                                       dtype=np.int64)
+
+            if gt_polygons_ignore:
+                data_info['ann']['bboxes_ignore'] = np.array(
+                    gt_bboxes_ignore, dtype=np.float32)
+                data_info['ann']['labels_ignore'] = np.array(
+                    gt_labels_ignore, dtype=np.int64)
+                data_info['ann']['polygons_ignore'] = np.array(
+                    gt_polygons_ignore, dtype=np.float32)
+                data_info['ann']['headers_ignore'] = np.array(
+                    gt_headers_ignore, dtype=np.int64)
+            else:
+                data_info['ann']['bboxes_ignore'] = np.zeros((0, 5),
+                                                             dtype=np.float32)
+                data_info['ann']['labels_ignore'] = np.array([],
+                                                             dtype=np.int64)
+                data_info['ann']['polygons_ignore'] = np.zeros(
+                    (0, 8), dtype=np.float32)
+                data_info['ann']['headers_ignore'] = np.zeros((0, 2),
+                                                              dtype=np.int64)
+
+            data_infos.append(data_info)
+        return data_infos
+
+    def _filter_imgs(self):
+        """Filter images without ground truths."""
+        valid_inds = []
+        for i, data_info in enumerate(self.data_infos):
+            if data_info['ann']['labels'].size > 0:
+                valid_inds.append(i)
+        return valid_inds
+
+    def evaluate(self,
+                 results,
+                 metric='mAP',
+                 logger=None,
+                 proposal_nums=(100, 300, 1000),
+                 iou_thr=0.5,
+                 scale_ranges=None,
+                 use_07_metric=True,
+                 nproc=4):
+        """Evaluate the dataset.
+
+        Args:
+            results (list): Testing results of the dataset.
+            metric (str | list[str]): Metrics to be evaluated.
+            logger (logging.Logger | None | str): Logger used for printing
+                related information during evaluation. Default: None.
+            proposal_nums (Sequence[int]): Proposal number used for evaluating
+                recalls, such as recall@100, recall@1000.
+                Default: (100, 300, 1000).
+            iou_thr (float | list[float]): IoU threshold. It must be a float
+                when evaluating mAP, and can be a list when evaluating recall.
+                Default: 0.5.
+            scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
+                Default: None.
+            use_07_metric (bool): Whether to use the voc07 metric.
+                Default: True.
+            nproc (int): Processes used for computing TP and FP.
+                Default: 4.
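+
+        Returns:
+            OrderedDict[str, float]: ``mAP`` plus per-threshold ``AP<iou>``
+                entries when ``metric='mAP'``, or ``recall@<num>@<iou>`` (and
+                ``AR@<num>``) entries when ``metric='recall'``.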
+ """ + if not isinstance(metric, str): + assert len(metric) == 1 + metric = metric[0] + allowed_metrics = ['mAP', 'recall'] + if metric not in allowed_metrics: + raise KeyError(f'metric {metric} is not supported') + + annotations = [self.get_ann_info(i) for i in range(len(self))] + eval_results = OrderedDict() + iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr + if metric == 'mAP': + assert isinstance(iou_thrs, list) + mean_aps = [] + for iou_thr in iou_thrs: + print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}') + mean_ap, _ = eval_map( + results, + annotations, + scale_ranges=scale_ranges, + iou_thr=iou_thr, + use_07_metric=use_07_metric, + dataset=self.CLASSES, + logger=logger, + nproc=nproc) + mean_aps.append(mean_ap) + eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3) + eval_results['mAP'] = sum(mean_aps) / len(mean_aps) + eval_results.move_to_end('mAP', last=False) + elif metric == 'recall': + gt_bboxes = [ann['bboxes'] for ann in annotations] + recalls = eval_recalls( + gt_bboxes, results, proposal_nums, iou_thrs, logger=logger) + for i, num in enumerate(proposal_nums): + for j, iou_thr in enumerate(iou_thrs): + eval_results[f'recall@{num}@{iou_thr}'] = recalls[i, j] + if recalls.shape[1] > 1: + ar = recalls.mean(axis=1) + for i, num in enumerate(proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + return eval_results