Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Update huge image inference #55

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added demo/dota_demo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
69 changes: 69 additions & 0 deletions demo/huge_image_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Inference on huge images.

Example:
```
python demo/huge_image_demo.py \
demo/dota_demo.jpg \
configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_v3.py \
work_dirs/oriented_rcnn_r50_fpn_1x_dota_v3/epoch_12.pth \
```
""" # nowq

from argparse import ArgumentParser

from mmdet.apis import init_detector, show_result_pyplot

from mmrotate.apis import inference_detector_by_patches


def parse_args():
parser = ArgumentParser()
parser.add_argument('img', help='Image file')
parser.add_argument('config', help='Config file')
parser.add_argument('checkpoint', help='Checkpoint file')
parser.add_argument(
'--patch_sizes',
type=int,
nargs='+',
default=[1024],
help='The sizes of patches')
parser.add_argument(
'--patch_steps',
type=int,
nargs='+',
default=[824],
help='The steps between two patches')
parser.add_argument(
'--img_ratios',
type=float,
nargs='+',
default=[1.0],
help='Image resizing ratios for multi-scale detecting')
parser.add_argument(
'--merge_iou_thr',
type=float,
default=0.1,
help='IoU threshould for merging results')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
args = parser.parse_args()
return args


def main(args):
# build the model from a config file and a checkpoint file
model = init_detector(args.config, args.checkpoint, device=args.device)
# test a huge image by patches
result = inference_detector_by_patches(model, args.img, args.patch_sizes,
args.patch_steps, args.img_ratios,
args.merge_iou_thr)
# show the results
show_result_pyplot(model, args.img, result, score_thr=args.score_thr)


if __name__ == '__main__':
args = parse_args()
main(args)
8 changes: 5 additions & 3 deletions demo/image_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
"""Inference on single image.

Example:
```
python demo/image_demo.py \
demo/demo.jpg \
configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_v3.py \
work_dirs/oriented_rcnn_r50_fpn_1x_dota_v3/epoch_12.pth \
demo/vis.jpg
configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_dota_v3.py \
work_dirs/oriented_rcnn_r50_fpn_1x_dota_v3/epoch_12.pth \
demo/vis.jpg
```
""" # nowq

from argparse import ArgumentParser
Expand Down
3 changes: 2 additions & 1 deletion mmrotate/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_detector_by_patches
from .train import train_detector

__all__ = ['train_detector']
__all__ = ['inference_detector_by_patches', 'train_detector']
89 changes: 89 additions & 0 deletions mmrotate/apis/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose

from mmrotate.core import get_multiscale_patch, merge_results, slide_window


def inference_detector_by_patches(model,
img,
sizes,
steps,
ratios,
merge_iou_thr,
bs=1):
"""inference patches with the detector.

Split huge image(s) into patches and inference them with the detector.
Finally, merge patch results on one huge image by nms.

Args:
model (nn.Module): The loaded detector.
img (str | ndarray or): Either an image file or loaded image.
sizes (list): The sizes of patches.
steps (list): The steps between two patches.
ratios (list): Image resizing ratios for multi-scale detecting.
merge_iou_thr (float): IoU threshold for merging results.
bs (int): Batch size, must greater than or equal to 1.

Returns:
list[np.ndarray]: Detection results.
"""
assert bs >= 1, 'The batch size must greater than or equal to 1'
cfg = model.cfg
device = next(model.parameters()).device # model device
cfg = cfg.copy()
# set loading pipeline type
cfg.data.test.pipeline[0].type = 'LoadPatchFromImage'
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
test_pipeline = Compose(cfg.data.test.pipeline)

if not isinstance(img, np.ndarray):
img = mmcv.imread(img)
height, width = img.shape[:2]
sizes, steps = get_multiscale_patch(sizes, steps, ratios)
windows = slide_window(width, height, sizes, steps)

# prepare patch data
patch_datas = []
for window in windows:
data = dict(img=img, win=window.tolist())
# build the data pipeline
data = test_pipeline(data)
patch_datas.append(data)

results = []
start = 0
while True:
data = patch_datas[start:start + bs]
data = collate(data, samples_per_gpu=len(data))
# just get the actual data from DataContainer
data['img_metas'] = [
img_metas.data[0] for img_metas in data['img_metas']
]
data['img'] = [img.data[0] for img in data['img']]
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
else:
for m in model.modules():
assert not isinstance(
m, RoIPool
), 'CPU inference with RoIPool is not supported currently.'

# forward the model
with torch.no_grad():
results.extend(model(return_loss=False, rescale=True, **data))

if start + bs >= len(patch_datas):
break
start += bs

results = merge_results(
results, windows[:, :2], iou_thr=merge_iou_thr, device=device)
return results
1 change: 1 addition & 0 deletions mmrotate/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor import * # noqa: F401, F403
from .bbox import * # noqa: F401, F403
from .patch import * # noqa: F401, F403
from .post_processing import * # noqa: F401, F403
from .visualization import * # noqa: F401, F403
5 changes: 5 additions & 0 deletions mmrotate/core/patch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .merge_results import merge_results
from .split import get_multiscale_patch, slide_window

__all__ = ['merge_results', 'get_multiscale_patch', 'slide_window']
37 changes: 37 additions & 0 deletions mmrotate/core/patch/merge_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.ops import nms_rotated


def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
"""Merge patch results via nms.

Args:
results (list[np.ndarray]): A list of patches results.
offsets (np.ndarray): Positions of the left top points of patches.
iou_thr (float): The IoU threshold of NMS.
device (str): The device to call nms.

Retunrns:
list[np.ndarray]: Detection results after merging.
"""
assert len(results) == offsets.shape[0], 'The `results` should has the ' \
'same length with `offsets`.'
merged_results = []
for results_pre_cls in zip(*results):
tran_dets = []
for dets, offset in zip(results_pre_cls, offsets):
dets[:, :2] += offset
tran_dets.append(dets)
tran_dets = np.concatenate(tran_dets, axis=0)

if tran_dets.size == 0:
merged_results.append(tran_dets)
else:
tran_dets = torch.from_numpy(tran_dets)
tran_dets = tran_dets.to(device)
nms_dets, _ = nms_rotated(tran_dets[:, :5], tran_dets[:, -1],
iou_thr)
merged_results.append(nms_dets.cpu().numpy())
return merged_results
75 changes: 75 additions & 0 deletions mmrotate/core/patch/split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from math import ceil

import numpy as np


def get_multiscale_patch(sizes, steps, ratios):
"""Get multiscale patch sizes and steps.

Args:
sizes (list): A list of patch sizes.
steps (list): A list of steps to slide patches.
ratios (list): Multiscale ratios. devidie to each size and step and
generate patches in new scales.

Returns:
new_sizes (list): A list of multiscale patch sizes.
new_steps (list): A list of steps corresponding to new_sizes.
"""
assert len(sizes) == len(steps), 'The length of `sizes` and `steps`' \
'should be the same.'
new_sizes, new_steps = [], []
size_steps = list(zip(sizes, steps))
for (size, step), ratio in product(size_steps, ratios):
new_sizes.append(int(size / ratio))
new_steps.append(int(step / ratio))
return new_sizes, new_steps


def slide_window(width, height, sizes, steps, img_rate_thr=0.6):
"""Slide windows in images and get window position.

Args:
width (int): The width of the image.
height (int): The height of the image.
sizes (list): List of window's sizes.
steps (list): List of window's steps.
img_rate_thr (float): Threshold of window area divided by image area.

Returns:
np.ndarray: Information of valid windows.
"""
assert 1 >= img_rate_thr >= 0, 'The `in_rate_thr` should lie in 0~1'
windows = []
# Sliding windows.
for size, step in zip(sizes, steps):
assert size > step, 'Size should large than step'

x_num = 1 if width <= size else ceil((width - size) / step + 1)
x_start = [step * i for i in range(x_num)]
if len(x_start) > 1 and x_start[-1] + size > width:
x_start[-1] = width - size

y_num = 1 if height <= size else ceil((height - size) / step + 1)
y_start = [step * i for i in range(y_num)]
if len(y_start) > 1 and y_start[-1] + size > height:
y_start[-1] = height - size

start = np.array(list(product(x_start, y_start)), dtype=np.int64)
windows.append(np.concatenate([start, start + size], axis=1))
windows = np.concatenate(windows, axis=0)

# Calculate the rate of image part in each window.
img_in_wins = windows.copy()
img_in_wins[:, 0::2] = np.clip(img_in_wins[:, 0::2], 0, width)
img_in_wins[:, 1::2] = np.clip(img_in_wins[:, 1::2], 0, height)
img_areas = (img_in_wins[:, 2] - img_in_wins[:, 0]) * \
(img_in_wins[:, 3] - img_in_wins[:, 1])
win_areas = (windows[:, 2] - windows[:, 0]) * \
(windows[:, 3] - windows[:, 1])
img_rates = img_areas / win_areas
if not (img_rates >= img_rate_thr).any():
img_rates[img_rates == img_rates.max()] = 1
return windows[img_rates >= img_rate_thr]
3 changes: 2 additions & 1 deletion mmrotate/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .loading import LoadPatchFromImage
from .transforms import PolyRandomRotate, RRandomFlip, RResize

__all__ = ['RResize', 'RRandomFlip', 'PolyRandomRotate']
__all__ = ['LoadPatchFromImage', 'RResize', 'RRandomFlip', 'PolyRandomRotate']
45 changes: 45 additions & 0 deletions mmrotate/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets.pipelines import LoadImageFromFile

from ..builder import ROTATED_PIPELINES


@ROTATED_PIPELINES.register_module()
class LoadPatchFromImage(LoadImageFromFile):
"""Load an patch from the huge image.

Similar with :obj:`LoadImageFromFile`, but only reserve a patch of
``results['img']`` according to ``results['win']``.
"""

def __call__(self, results):
"""Call functions to add image meta information.

Args:
results (dict): Result dict with image in ``results['img']``.

Returns:
dict: The dict contains the loaded patch and meta information.
"""

img = results['img']
x_start, y_start, x_stop, y_stop = results['win']
width = x_stop - x_start
height = y_stop - y_start

patch = img[y_start:y_stop, x_start:x_stop]
if height > patch.shape[0] or width > patch.shape[1]:
patch = mmcv.impad(patch, shape=(height, width))

if self.to_float32:
patch = patch.astype(np.float32)

results['filename'] = None
results['ori_filename'] = None
results['img'] = patch
results['img_shape'] = patch.shape
results['ori_shape'] = patch.shape
results['img_fields'] = ['img']
return results