diff --git a/demo/huge_image_demo.py b/demo/huge_image_demo.py
new file mode 100644
index 000000000..f73ab0661
--- /dev/null
+++ b/demo/huge_image_demo.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from argparse import ArgumentParser
+
+from mmdet.apis import init_detector, show_result_pyplot
+
+from mmrotate.apis import inference_detector_by_patches
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument('img', help='Image file')
+    parser.add_argument('config', help='Config file')
+    parser.add_argument('checkpoint', help='Checkpoint file')
+    parser.add_argument(
+        '--patch_sizes',
+        type=int,
+        nargs='+',
+        default=[1024],
+        help='The sizes of patches')
+    parser.add_argument(
+        '--patch_steps',
+        type=int,
+        nargs='+',
+        default=[824],
+        help='The steps between two patches')
+    parser.add_argument(
+        '--img_ratios',
+        type=float,
+        nargs='+',
+        default=[1.0],
+        help='Image resizing ratios for multi-scale detecting')
+    parser.add_argument(
+        '--merge_iou_thr',
+        type=float,
+        default=0.1,
+        help='IoU threshold for merging results')
+    parser.add_argument(
+        '--device', default='cuda:0', help='Device used for inference')
+    parser.add_argument(
+        '--score-thr', type=float, default=0.3, help='bbox score threshold')
+    args = parser.parse_args()
+    return args
+
+
+def main(args):
+    # build the model from a config file and a checkpoint file
+    model = init_detector(args.config, args.checkpoint, device=args.device)
+    # test a single image
+    result = inference_detector_by_patches(model, args.img, args.patch_sizes,
+                                           args.patch_steps, args.img_ratios,
+                                           args.merge_iou_thr)
+    # show the results
+    show_result_pyplot(model, args.img, result, score_thr=args.score_thr)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    main(args)
diff --git a/mmrotate/apis/__init__.py b/mmrotate/apis/__init__.py
index 9a53f1db2..de13731d9 100644
--- a/mmrotate/apis/__init__.py
+++ b/mmrotate/apis/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .inference import inference_detector_by_patches
 from .train import train_detector
 
-__all__ = ['train_detector']
+__all__ = ['inference_detector_by_patches', 'train_detector']
diff --git a/mmrotate/apis/inference.py b/mmrotate/apis/inference.py
new file mode 100644
index 000000000..06fc00c23
--- /dev/null
+++ b/mmrotate/apis/inference.py
@@ -0,0 +1,90 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import torch
+from mmcv.ops import RoIPool
+from mmcv.parallel import collate, scatter
+from mmdet.datasets import replace_ImageToTensor
+from mmdet.datasets.pipelines import Compose
+
+from mmrotate.core import get_multiscale_patch, merge_results, slide_window
+
+
+def inference_detector_by_patches(model,
+                                  img,
+                                  sizes,
+                                  steps,
+                                  ratios,
+                                  merge_iou_thr,
+                                  bs=1):
+    """Inference patches with the detector.
+
+    Split a huge image into patches and run the detector on each patch.
+    Finally, merge the patch results on the huge image by NMS.
+
+    Args:
+        model (nn.Module): The loaded detector.
+        img (str | ndarray): Either an image file or a loaded image.
+        sizes (list): The sizes of patches.
+        steps (list): The steps between two patches.
+        ratios (list): Image resizing ratios for multi-scale detecting.
+        merge_iou_thr (float): IoU threshold for merging results.
+        bs (int): Batch size, must be greater than or equal to 1.
+
+    Returns:
+        list[np.ndarray]: Detection results after merging patch results,
+        one array of shape (n, 6) per class.
+    """
+    assert bs >= 1, 'The batch size must be greater than or equal to 1'
+    cfg = model.cfg
+    device = next(model.parameters()).device  # model device
+    cfg = cfg.copy()
+    # set loading pipeline type
+    cfg.data.test.pipeline[0].type = 'LoadPatchFromImage'
+    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+    test_pipeline = Compose(cfg.data.test.pipeline)
+
+    if not isinstance(img, np.ndarray):
+        img = mmcv.imread(img)
+    height, width = img.shape[:2]
+    sizes, steps = get_multiscale_patch(sizes, steps, ratios)
+    windows = slide_window(width, height, sizes, steps)
+
+    # prepare patch data
+    patch_datas = []
+    for window in windows:
+        data = dict(img=img, win=window.tolist())
+        # build the data pipeline
+        data = test_pipeline(data)
+        patch_datas.append(data)
+
+    results = []
+    start = 0
+    while True:
+        data = patch_datas[start:start + bs]
+        data = collate(data, samples_per_gpu=len(data))
+        # just get the actual data from DataContainer
+        data['img_metas'] = [
+            img_metas.data[0] for img_metas in data['img_metas']
+        ]
+        data['img'] = [img.data[0] for img in data['img']]
+        if next(model.parameters()).is_cuda:
+            # scatter to specified GPU
+            data = scatter(data, [device])[0]
+        else:
+            for m in model.modules():
+                assert not isinstance(
+                    m, RoIPool
+                ), 'CPU inference with RoIPool is not supported currently.'
+
+        # forward the model
+        with torch.no_grad():
+            results.extend(model(return_loss=False, rescale=True, **data))
+
+        if start + bs >= len(patch_datas):
+            break
+        start += bs
+
+    results = merge_results(
+        results, windows[:, :2], iou_thr=merge_iou_thr, device=device)
+    return results
diff --git a/mmrotate/core/__init__.py b/mmrotate/core/__init__.py
index 1a55336e8..126eb69ea 100644
--- a/mmrotate/core/__init__.py
+++ b/mmrotate/core/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .anchor import *  # noqa: F401, F403
 from .bbox import *  # noqa: F401, F403
+from .patch import *  # noqa: F401, F403
 from .post_processing import *  # noqa: F401, F403
 from .visualization import *  # noqa: F401, F403
diff --git a/mmrotate/core/patch/__init__.py b/mmrotate/core/patch/__init__.py
new file mode 100644
index 000000000..5f112059b
--- /dev/null
+++ b/mmrotate/core/patch/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .merge_results import merge_results
+from .split import get_multiscale_patch, slide_window
+
+__all__ = ['merge_results', 'get_multiscale_patch', 'slide_window']
diff --git a/mmrotate/core/patch/merge_results.py b/mmrotate/core/patch/merge_results.py
new file mode 100644
index 000000000..5ef34c125
--- /dev/null
+++ b/mmrotate/core/patch/merge_results.py
@@ -0,0 +1,37 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+from mmcv.ops import nms_rotated
+
+
+def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
+    """Merge patch results via NMS.
+
+    Args:
+        results (list[list[np.ndarray]]): A list of patch results.
+        offsets (np.ndarray): Positions of the top-left corners of patches.
+        iou_thr (float): The IoU threshold of NMS.
+        device (str): The device on which to run NMS.
+
+    Returns:
+        list[np.ndarray]: Detection results after merging.
+    """
+    assert len(results) == offsets.shape[0], 'The `results` should have ' \
+        'the same length as `offsets`.'
+    merged_results = []
+    for results_per_cls in zip(*results):
+        tran_dets = []
+        for dets, offset in zip(results_per_cls, offsets):
+            dets[:, :2] += offset
+            tran_dets.append(dets)
+        tran_dets = np.concatenate(tran_dets, axis=0)
+
+        if tran_dets.size == 0:
+            merged_results.append(tran_dets)
+        else:
+            tran_dets = torch.from_numpy(tran_dets)
+            tran_dets = tran_dets.to(device)
+            nms_dets, _ = nms_rotated(tran_dets[:, :5], tran_dets[:, -1],
+                                      iou_thr)
+            merged_results.append(nms_dets.cpu().numpy())
+    return merged_results
diff --git a/mmrotate/core/patch/split.py b/mmrotate/core/patch/split.py
new file mode 100644
index 000000000..41b3d6457
--- /dev/null
+++ b/mmrotate/core/patch/split.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from itertools import product
+from math import ceil
+
+import numpy as np
+
+
+def get_multiscale_patch(sizes, steps, ratios):
+    """Get multiscale patch sizes and steps.
+
+    Args:
+        sizes (list): A list of patch sizes.
+        steps (list): A list of steps to slide patches.
+        ratios (list): Multiscale ratios. Each size and step is divided by
+            every ratio to generate patches at new scales.
+
+    Returns:
+        new_sizes (list): A list of multiscale patch sizes.
+        new_steps (list): A list of steps corresponding to new_sizes.
+    """
+    assert len(sizes) == len(steps), 'The length of `sizes` and `steps` ' \
+        'should be the same.'
+    new_sizes, new_steps = [], []
+    size_steps = list(zip(sizes, steps))
+    for (size, step), ratio in product(size_steps, ratios):
+        new_sizes.append(int(size / ratio))
+        new_steps.append(int(step / ratio))
+    return new_sizes, new_steps
+
+
+def slide_window(width, height, sizes, steps, img_rate_thr=0.6):
+    """Slide windows over the image and get window positions.
+
+    Args:
+        width (int): The width of the image.
+        height (int): The height of the image.
+        sizes (list): List of window sizes.
+        steps (list): List of window steps.
+        img_rate_thr (float): Minimum rate of image area inside a window.
+
+    Returns:
+        np.ndarray: Valid windows as (x_start, y_start, x_stop, y_stop).
+    """
+    assert 1 >= img_rate_thr >= 0, 'The `img_rate_thr` should lie in [0, 1]'
+    windows = []
+    # Sliding windows.
+    for size, step in zip(sizes, steps):
+        assert size > step, 'Size should be larger than step'
+
+        x_num = 1 if width <= size else ceil((width - size) / step + 1)
+        x_start = [step * i for i in range(x_num)]
+        if len(x_start) > 1 and x_start[-1] + size > width:
+            x_start[-1] = width - size
+
+        y_num = 1 if height <= size else ceil((height - size) / step + 1)
+        y_start = [step * i for i in range(y_num)]
+        if len(y_start) > 1 and y_start[-1] + size > height:
+            y_start[-1] = height - size
+
+        start = np.array(list(product(x_start, y_start)), dtype=np.int64)
+        windows.append(np.concatenate([start, start + size], axis=1))
+    windows = np.concatenate(windows, axis=0)
+
+    # Calculate the rate of image part in each window.
+    img_in_wins = windows.copy()
+    img_in_wins[:, 0::2] = np.clip(img_in_wins[:, 0::2], 0, width)
+    img_in_wins[:, 1::2] = np.clip(img_in_wins[:, 1::2], 0, height)
+    img_areas = (img_in_wins[:, 2] - img_in_wins[:, 0]) * \
+                (img_in_wins[:, 3] - img_in_wins[:, 1])
+    win_areas = (windows[:, 2] - windows[:, 0]) * \
+                (windows[:, 3] - windows[:, 1])
+    img_rates = img_areas / win_areas
+    if not (img_rates >= img_rate_thr).any():
+        img_rates[img_rates == img_rates.max()] = 1
+    return windows[img_rates >= img_rate_thr]
diff --git a/mmrotate/datasets/pipelines/__init__.py b/mmrotate/datasets/pipelines/__init__.py
index b7f88e52d..1a90302ce 100644
--- a/mmrotate/datasets/pipelines/__init__.py
+++ b/mmrotate/datasets/pipelines/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .loading import LoadPatchFromImage
 from .transforms import PolyRandomRotate, RRandomFlip, RResize
 
-__all__ = ['RResize', 'RRandomFlip', 'PolyRandomRotate']
+__all__ = ['LoadPatchFromImage', 'RResize', 'RRandomFlip', 'PolyRandomRotate']
diff --git a/mmrotate/datasets/pipelines/loading.py b/mmrotate/datasets/pipelines/loading.py
new file mode 100644
index 000000000..dbe9f65c5
--- /dev/null
+++ b/mmrotate/datasets/pipelines/loading.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+from mmdet.datasets.pipelines import LoadImageFromFile
+
+from ..builder import ROTATED_PIPELINES
+
+
+@ROTATED_PIPELINES.register_module()
+class LoadPatchFromImage(LoadImageFromFile):
+    """Load a patch from the huge image.
+
+    Similar to :obj:`LoadImageFromFile`, but only keeps the patch of
+    ``results['img']`` specified by ``results['win']``.
+    """
+
+    def __call__(self, results):
+        """Call function to crop the patch and add image meta information.
+
+        Args:
+            results (dict): Result dict with the huge image in
+                ``results['img']`` and the patch window in ``results['win']``.
+
+        Returns:
+            dict: The dict contains the loaded patch and meta information.
+        """
+
+        img = results['img']
+        x_start, y_start, x_stop, y_stop = results['win']
+        width = x_stop - x_start
+        height = y_stop - y_start
+
+        patch = img[y_start:y_stop, x_start:x_stop]
+        if height > patch.shape[0] or width > patch.shape[1]:
+            patch = mmcv.impad(patch, shape=(height, width))
+
+        if self.to_float32:
+            patch = patch.astype(np.float32)
+
+        results['filename'] = None
+        results['ori_filename'] = None
+        results['img'] = patch
+        results['img_shape'] = patch.shape
+        results['ori_shape'] = patch.shape
+        results['img_fields'] = ['img']
+        return results
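
A minimal usage sketch of the new top-level API, mirroring what
demo/huge_image_demo.py does. The config, checkpoint and image paths are
placeholders, not files added by this patch:

from mmdet.apis import init_detector

from mmrotate.apis import inference_detector_by_patches

# placeholder paths; substitute a real rotated-detector config/checkpoint
model = init_detector('some_rotated_config.py', 'some_checkpoint.pth',
                      device='cuda:0')
# 1024x1024 patches slid with an 824-pixel step (200 pixels of overlap),
# single scale; per-patch detections are merged by rotated NMS at IoU 0.1
result = inference_detector_by_patches(
    model, 'some_huge_image.jpg', sizes=[1024], steps=[824], ratios=[1.0],
    merge_iou_thr=0.1)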
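
The splitting utilities can be exercised on their own. A small sketch, with
the outputs worked out from the code above under the patch's defaults:

from mmrotate.core import get_multiscale_patch, slide_window

# each size/step is divided by every ratio, so ratio 0.5 doubles them
sizes, steps = get_multiscale_patch([1024], [824], [1.0, 0.5])
# sizes == [1024, 2048], steps == [824, 1648]

# windows over an image 2000 pixels wide and 1500 high; each row of the
# result is (x_start, y_start, x_stop, y_stop)
windows = slide_window(2000, 1500, sizes, steps)
# six 1024-pixel windows plus one 2048-pixel window covering the whole
# image (kept because the image fills ~72% of it, above img_rate_thr=0.6)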
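
LoadPatchFromImage crops the requested window and pads border patches back up
to the full window size. A quick check, assuming this patch is applied:

import numpy as np

from mmrotate.datasets.pipelines import LoadPatchFromImage

# a 1024x1024 window hanging over the right/bottom border of the image
img = np.zeros((1500, 2000, 3), dtype=np.uint8)
load_patch = LoadPatchFromImage()
results = load_patch(dict(img=img, win=[1648, 824, 2672, 1848]))
assert results['img'].shape == (1024, 1024, 3)  # cropped, then zero-padded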
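
Finally, a toy check of merge_results: two patches see the same object, and
after shifting each box by its patch offset the duplicate is suppressed by
rotated NMS. This assumes an mmcv build that ships the nms_rotated op:

import numpy as np

from mmrotate.core import merge_results

# a single class; each det is (cx, cy, w, h, angle, score) in patch coords
patch1 = [np.array([[100., 100., 40., 20., 0.1, 0.9]], dtype=np.float32)]
patch2 = [np.array([[20., 100., 40., 20., 0.1, 0.8]], dtype=np.float32)]
offsets = np.array([[0, 0], [80, 0]])  # top-left corner of each patch
merged = merge_results([patch1, patch2], offsets, iou_thr=0.1, device='cpu')
# both boxes land on (100, 100) in image coords; only the 0.9 one survives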