inference huge image
jbwang1997 committed Feb 25, 2022
1 parent 4722bd6 commit 4e39a45
Showing 9 changed files with 316 additions and 2 deletions.
58 changes: 58 additions & 0 deletions demo/huge_image_demo.py
@@ -0,0 +1,58 @@
# Copyright (c) OpenMMLab. All rights reserved.
from argparse import ArgumentParser

from mmdet.apis import init_detector, show_result_pyplot

from mmrotate.apis import inference_detector_by_patches


def parse_args():
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--patch_sizes',
        type=int,
        nargs='+',
        default=[1024],
        help='The sizes of patches')
    parser.add_argument(
        '--patch_steps',
        type=int,
        nargs='+',
        default=[824],
        help='The steps between two patches')
    parser.add_argument(
        '--img_ratios',
        type=float,
        nargs='+',
        default=[1.0],
        help='Image resizing ratios for multi-scale detecting')
    parser.add_argument(
        '--merge_iou_thr',
        type=float,
        default=0.1,
        help='IoU threshold for merging results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    args = parser.parse_args()
    return args


def main(args):
    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # test a single image
    result = inference_detector_by_patches(model, args.img, args.patch_sizes,
                                           args.patch_steps, args.img_ratios,
                                           args.merge_iou_thr)
    # show the results
    show_result_pyplot(model, args.img, result, score_thr=args.score_thr)


if __name__ == '__main__':
    args = parse_args()
    main(args)
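
For reference, the same flow can be used programmatically instead of through the demo script. The snippet below is a minimal sketch; the config, checkpoint, and image paths are hypothetical placeholders, and a CUDA device is assumed:

from mmdet.apis import init_detector, show_result_pyplot
from mmrotate.apis import inference_detector_by_patches

# hypothetical paths; replace with a real rotated-detector config and checkpoint
model = init_detector('configs/some_rotated_detector.py',
                      'checkpoints/some_rotated_detector.pth',
                      device='cuda:0')
# split the huge image into 1024x1024 patches with a step of 824,
# then merge the patch detections with rotated NMS at IoU 0.1
result = inference_detector_by_patches(model, 'demo/huge_image.jpg',
                                       sizes=[1024], steps=[824],
                                       ratios=[1.0], merge_iou_thr=0.1)
show_result_pyplot(model, 'demo/huge_image.jpg', result, score_thr=0.3)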
3 changes: 2 additions & 1 deletion mmrotate/apis/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_detector_by_patches
from .train import train_detector

__all__ = ['train_detector']
__all__ = ['inference_detector_by_patches', 'train_detector']
90 changes: 90 additions & 0 deletions mmrotate/apis/inference.py
@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
import torch
from mmcv.ops import RoIPool
from mmcv.parallel import collate, scatter
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose

from mmrotate.core import get_multiscale_patch, merge_results, slide_window


def inference_detector_by_patches(model,
                                  img,
                                  sizes,
                                  steps,
                                  ratios,
                                  merge_iou_thr,
                                  bs=1):
    """Inference patches with the detector.

    Split the huge image into patches, inference them with the detector, and
    finally merge the patch results on the huge image by NMS.

    Args:
        model (nn.Module): The loaded detector.
        img (str | ndarray): Either an image file path or a loaded image.
        sizes (list): The sizes of patches.
        steps (list): The steps between two patches.
        ratios (list): Image resizing ratios for multi-scale detecting.
        merge_iou_thr (float): IoU threshold for merging results.
        bs (int): Batch size, must be greater than or equal to 1.

    Returns:
        list[np.ndarray]: Detection results of the huge image after merging
            all patch results.
    """
    assert bs >= 1, 'The batch size must be greater than or equal to 1'
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    cfg = cfg.copy()
    # set loading pipeline type
    cfg.data.test.pipeline[0].type = 'LoadPatchFromImage'
    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)

    if not isinstance(img, np.ndarray):
        img = mmcv.imread(img)
    height, width = img.shape[:2]
    sizes, steps = get_multiscale_patch(sizes, steps, ratios)
    windows = slide_window(width, height, sizes, steps)

    # prepare patch data
    patch_datas = []
    for window in windows:
        data = dict(img=img, win=window.tolist())
        # build the data pipeline
        data = test_pipeline(data)
        patch_datas.append(data)

    results = []
    start = 0
    while True:
        data = patch_datas[start:start + bs]
        data = collate(data, samples_per_gpu=len(data))
        # just get the actual data from DataContainer
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
        data['img'] = [img.data[0] for img in data['img']]
        if next(model.parameters()).is_cuda:
            # scatter to specified GPU
            data = scatter(data, [device])[0]
        else:
            for m in model.modules():
                assert not isinstance(
                    m, RoIPool
                ), 'CPU inference with RoIPool is not supported currently.'

        # forward the model
        with torch.no_grad():
            results.extend(model(return_loss=False, rescale=True, **data))

        if start + bs >= len(patch_datas):
            break
        start += bs

    results = merge_results(
        results, windows[:, :2], iou_thr=merge_iou_thr, device=device)
    return results
1 change: 1 addition & 0 deletions mmrotate/core/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor import * # noqa: F401, F403
from .bbox import * # noqa: F401, F403
from .patch import * # noqa: F401, F403
from .post_processing import * # noqa: F401, F403
from .visualization import * # noqa: F401, F403
5 changes: 5 additions & 0 deletions mmrotate/core/patch/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .merge_results import merge_results
from .split import get_multiscale_patch, slide_window

__all__ = ['merge_results', 'get_multiscale_patch', 'slide_window']
37 changes: 37 additions & 0 deletions mmrotate/core/patch/merge_results.py
@@ -0,0 +1,37 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcv.ops import nms_rotated


def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
    """Merge patch results via nms.

    Args:
        results (list[np.ndarray]): A list of patch results.
        offsets (np.ndarray): Positions of the left-top points of patches.
        iou_thr (float): The IoU threshold of NMS.
        device (str): The device to call nms.

    Returns:
        list[np.ndarray]: Detection results after merging.
    """
    assert len(results) == offsets.shape[0], 'The `results` should have ' \
        'the same length as `offsets`.'
    merged_results = []
    for results_pre_cls in zip(*results):
        tran_dets = []
        for dets, offset in zip(results_pre_cls, offsets):
            dets[:, :2] += offset
            tran_dets.append(dets)
        tran_dets = np.concatenate(tran_dets, axis=0)

        if tran_dets.size == 0:
            merged_results.append(tran_dets)
        else:
            tran_dets = torch.from_numpy(tran_dets)
            tran_dets = tran_dets.to(device)
            nms_dets, _ = nms_rotated(tran_dets[:, :5], tran_dets[:, -1],
                                      iou_thr)
            merged_results.append(nms_dets.cpu().numpy())
    return merged_results
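
To illustrate the merging step, here is a toy sketch (assuming mmcv with the `nms_rotated` op is installed). Two patches report the same object in patch-local coordinates; after adding the patch offsets the two boxes coincide in the huge image and NMS keeps a single detection. The coordinates and scores are made-up example values:

import numpy as np
from mmrotate.core import merge_results

# one class per patch, detections as (cx, cy, w, h, angle, score) in patch coordinates
patch1 = [np.array([[900., 900., 100., 50., 0.3, 0.9]], dtype=np.float32)]
patch2 = [np.array([[76., 76., 100., 50., 0.3, 0.8]], dtype=np.float32)]
# left-top offsets of the two patches in the huge image
offsets = np.array([[0., 0.], [824., 824.]], dtype=np.float32)

merged = merge_results([patch1, patch2], offsets, iou_thr=0.1, device='cpu')
print(merged[0])  # a single box centred at (900, 900), the higher-scoring one survives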
75 changes: 75 additions & 0 deletions mmrotate/core/patch/split.py
@@ -0,0 +1,75 @@
# Copyright (c) OpenMMLab. All rights reserved.
from itertools import product
from math import ceil

import numpy as np


def get_multiscale_patch(sizes, steps, ratios):
    """Get multiscale patch sizes and steps.

    Args:
        sizes (list): A list of patch sizes.
        steps (list): A list of steps to slide patches.
        ratios (list): Multiscale ratios. Each size and step is divided by
            every ratio to generate patches at new scales.

    Returns:
        new_sizes (list): A list of multiscale patch sizes.
        new_steps (list): A list of steps corresponding to new_sizes.
    """
    assert len(sizes) == len(steps), 'The length of `sizes` and `steps` ' \
        'should be the same.'
    new_sizes, new_steps = [], []
    size_steps = list(zip(sizes, steps))
    for (size, step), ratio in product(size_steps, ratios):
        new_sizes.append(int(size / ratio))
        new_steps.append(int(step / ratio))
    return new_sizes, new_steps


def slide_window(width, height, sizes, steps, img_rate_thr=0.6):
    """Slide windows in images and get window positions.

    Args:
        width (int): The width of the image.
        height (int): The height of the image.
        sizes (list): List of window sizes.
        steps (list): List of window steps.
        img_rate_thr (float): Threshold of window area divided by image area.

    Returns:
        np.ndarray: Information of valid windows.
    """
    assert 1 >= img_rate_thr >= 0, 'The `img_rate_thr` should lie in 0~1'
    windows = []
    # Sliding windows.
    for size, step in zip(sizes, steps):
        assert size > step, 'Size should be larger than step'

        x_num = 1 if width <= size else ceil((width - size) / step + 1)
        x_start = [step * i for i in range(x_num)]
        if len(x_start) > 1 and x_start[-1] + size > width:
            x_start[-1] = width - size

        y_num = 1 if height <= size else ceil((height - size) / step + 1)
        y_start = [step * i for i in range(y_num)]
        if len(y_start) > 1 and y_start[-1] + size > height:
            y_start[-1] = height - size

        start = np.array(list(product(x_start, y_start)), dtype=np.int64)
        windows.append(np.concatenate([start, start + size], axis=1))
    windows = np.concatenate(windows, axis=0)

    # Calculate the rate of image part in each window.
    img_in_wins = windows.copy()
    img_in_wins[:, 0::2] = np.clip(img_in_wins[:, 0::2], 0, width)
    img_in_wins[:, 1::2] = np.clip(img_in_wins[:, 1::2], 0, height)
    img_areas = (img_in_wins[:, 2] - img_in_wins[:, 0]) * \
        (img_in_wins[:, 3] - img_in_wins[:, 1])
    win_areas = (windows[:, 2] - windows[:, 0]) * \
        (windows[:, 3] - windows[:, 1])
    img_rates = img_areas / win_areas
    if not (img_rates >= img_rate_thr).any():
        img_rates[img_rates == img_rates.max()] = 1
    return windows[img_rates >= img_rate_thr]
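
As a quick sanity check of the window layout, the sketch below uses toy numbers at a single scale: sliding 1024x1024 windows with a step of 824 over a 2048x2048 image gives starts [0, 824, 1024] on each axis (the last start is clipped so the window ends at the border), i.e. a 3x3 grid of 9 windows:

from mmrotate.core import get_multiscale_patch, slide_window

sizes, steps = get_multiscale_patch(sizes=[1024], steps=[824], ratios=[1.0])
windows = slide_window(width=2048, height=2048, sizes=sizes, steps=steps)
print(windows.shape)  # (9, 4): each row is (x_start, y_start, x_stop, y_stop)
print(windows[0])     # [   0    0 1024 1024]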
3 changes: 2 additions & 1 deletion mmrotate/datasets/pipelines/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .loading import LoadPatchFromImage
from .transforms import PolyRandomRotate, RRandomFlip, RResize

__all__ = ['RResize', 'RRandomFlip', 'PolyRandomRotate']
__all__ = ['LoadPatchFromImage', 'RResize', 'RRandomFlip', 'PolyRandomRotate']
46 changes: 46 additions & 0 deletions mmrotate/datasets/pipelines/loading.py
@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
from mmdet.datasets.pipelines import LoadImageFromFile

from ..builder import ROTATED_PIPELINES


@ROTATED_PIPELINES.register_module()
class LoadPatchFromImage(LoadImageFromFile):
    """Load a patch from the huge image.

    Similar to :obj:`LoadImageFromFile`, but only reserves a patch of
    ``results['img']`` according to ``results['win']``.
    """

    def __call__(self, results):
        """Call functions to add image meta information.

        Args:
            results (dict): Result dict with the huge image in
                ``results['img']`` and the patch window in ``results['win']``.

        Returns:
            dict: The dict contains the loaded patch and meta information.
        """

        img = results['img']
        x_start, y_start, x_stop, y_stop = results['win']
        width = x_stop - x_start
        height = y_stop - y_start

        patch = img[y_start:y_stop, x_start:x_stop]
        if height > patch.shape[0] or width > patch.shape[1]:
            patch = mmcv.impad(patch, shape=(height, width))

        if self.to_float32:
            patch = patch.astype(np.float32)

        results['filename'] = None
        results['ori_filename'] = None
        results['img'] = patch
        results['img_shape'] = patch.shape
        results['ori_shape'] = patch.shape
        results['img_fields'] = ['img']
        return results
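
A toy sketch of the transform on its own (synthetic image, window chosen larger than the image so that the padding branch kicks in; the array values are placeholders):

import numpy as np
from mmrotate.datasets.pipelines import LoadPatchFromImage

img = np.zeros((600, 800, 3), dtype=np.uint8)    # a small stand-in for a huge image
results = dict(img=img, win=[0, 0, 1024, 1024])  # window extends past the image border
patch = LoadPatchFromImage()(results)['img']
print(patch.shape)  # (1024, 1024, 3): the cropped region padded to the window size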
