-
Notifications
You must be signed in to change notification settings - Fork 566
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4722bd6
commit 4e39a45
Showing
9 changed files
with
316 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from argparse import ArgumentParser | ||
|
||
from mmdet.apis import init_detector, show_result_pyplot | ||
|
||
from mmrotate.apis import inference_detector_by_patches | ||
|
||
|
||
def parse_args(argv=None):
    """Parse the command-line options of the huge-image inference demo.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            None, in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options ``img``, ``config``,
            ``checkpoint``, ``patch_sizes``, ``patch_steps``, ``img_ratios``,
            ``merge_iou_thr``, ``device`` and ``score_thr``.
    """
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--patch_sizes',
        type=int,
        nargs='+',
        default=[1024],
        help='The sizes of patches')
    parser.add_argument(
        '--patch_steps',
        type=int,
        nargs='+',
        default=[824],
        help='The steps between two patches')
    parser.add_argument(
        '--img_ratios',
        type=float,
        nargs='+',
        default=[1.0],
        help='Image resizing ratios for multi-scale detecting')
    parser.add_argument(
        '--merge_iou_thr',
        type=float,
        default=0.1,
        help='IoU threshold for merging results')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='bbox score threshold')
    # ``argv=None`` keeps the original behaviour of reading ``sys.argv``.
    return parser.parse_args(argv)
|
||
|
||
def main(args):
    """Run patch-wise detection on a single image and display the result.

    Args:
        args (argparse.Namespace): Parsed options; the attributes ``config``,
            ``checkpoint``, ``device``, ``img``, ``patch_sizes``,
            ``patch_steps``, ``img_ratios``, ``merge_iou_thr`` and
            ``score_thr`` are read.
    """
    # Build the detector from the config file and the checkpoint file.
    detector = init_detector(args.config, args.checkpoint, device=args.device)

    # Run the detector on the image patch by patch.
    patch_options = (args.patch_sizes, args.patch_steps, args.img_ratios,
                     args.merge_iou_thr)
    detections = inference_detector_by_patches(detector, args.img,
                                               *patch_options)

    # Visualise the detections on the input image.
    show_result_pyplot(
        detector, args.img, detections, score_thr=args.score_thr)
|
||
|
||
# Script entry point: parse the command line and run the demo.
if __name__ == '__main__':
    main(parse_args())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .inference import inference_detector_by_patches | ||
from .train import train_detector | ||
|
||
__all__ = ['train_detector'] | ||
__all__ = ['inference_detector_by_patches', 'train_detector'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import mmcv | ||
import numpy as np | ||
import torch | ||
from mmcv.ops import RoIPool | ||
from mmcv.parallel import collate, scatter | ||
from mmdet.datasets import replace_ImageToTensor | ||
from mmdet.datasets.pipelines import Compose | ||
|
||
from mmrotate.core import get_multiscale_patch, merge_results, slide_window | ||
|
||
|
||
def inference_detector_by_patches(model,
                                  img,
                                  sizes,
                                  steps,
                                  ratios,
                                  merge_iou_thr,
                                  bs=1):
    """Inference patches with the detector.

    Split a huge image into patches and inference them with the detector.
    Finally, merge patch results on one huge image by nms.

    Args:
        model (nn.Module): The loaded detector.
        img (str | ndarray): Either an image file or a loaded image.
        sizes (list): The sizes of patches.
        steps (list): The steps between two patches.
        ratios (list): Image resizing ratios for multi-scale detecting.
        merge_iou_thr (float): IoU threshold for merging results.
        bs (int): Batch size, must be greater than or equal to 1.

    Returns:
        list[np.ndarray]: Detection results of the whole image, i.e. the
            output of ``merge_results`` applied to the patch results.
    """
    import copy

    assert bs >= 1, 'The batch size must greater than or equal to 1'

    first_param = next(model.parameters())
    device = first_param.device  # model device
    on_cuda = first_param.is_cuda
    if not on_cuda:
        # Checked once up front instead of once per batch.
        for m in model.modules():
            assert not isinstance(
                m, RoIPool
            ), 'CPU inference with RoIPool is not supported currently.'

    # Work on a deep copy of the test pipeline: ``cfg.copy()`` is shallow, so
    # changing the loading step on it would also rewrite ``model.cfg`` and
    # affect every later use of the model's config.
    pipeline_cfg = copy.deepcopy(model.cfg.data.test.pipeline)
    # set loading pipeline type
    pipeline_cfg[0].type = 'LoadPatchFromImage'
    pipeline_cfg = replace_ImageToTensor(pipeline_cfg)
    test_pipeline = Compose(pipeline_cfg)

    if not isinstance(img, np.ndarray):
        img = mmcv.imread(img)
    height, width = img.shape[:2]
    sizes, steps = get_multiscale_patch(sizes, steps, ratios)
    windows = slide_window(width, height, sizes, steps)

    results = []
    for start in range(0, len(windows), bs):
        # Prepare only the patches of the current batch so that the
        # preprocessed data of all windows is never held at the same time.
        batch = [
            test_pipeline(dict(img=img, win=window.tolist()))
            for window in windows[start:start + bs]
        ]
        data = collate(batch, samples_per_gpu=len(batch))
        # just get the actual data from DataContainer
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
        data['img'] = [patch.data[0] for patch in data['img']]
        if on_cuda:
            # scatter to specified GPU
            data = scatter(data, [device])[0]

        # forward the model
        with torch.no_grad():
            results.extend(model(return_loss=False, rescale=True, **data))

    # The first two columns of each window are its left-top corner, used as
    # the offset of the patch inside the whole image.
    return merge_results(
        results, windows[:, :2], iou_thr=merge_iou_thr, device=device)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .anchor import * # noqa: F401, F403 | ||
from .bbox import * # noqa: F401, F403 | ||
from .patch import * # noqa: F401, F403 | ||
from .post_processing import * # noqa: F401, F403 | ||
from .visualization import * # noqa: F401, F403 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .merge_results import merge_results | ||
from .split import get_multiscale_patch, slide_window | ||
|
||
__all__ = ['merge_results', 'get_multiscale_patch', 'slide_window'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import numpy as np | ||
import torch | ||
from mmcv.ops import nms_rotated | ||
|
||
|
||
def merge_results(results, offsets, iou_thr=0.1, device='cpu'):
    """Merge patch results via nms.

    For every class, the detections of all patches are shifted by the offset
    of their patch, concatenated and filtered with rotated nms. The first
    five columns of each detection array are passed to nms as boxes and the
    last column as scores. The input arrays are left unmodified.

    Args:
        results (list[list[np.ndarray]]): A list of patches results, each
            being a list with one detection array per class.
        offsets (np.ndarray): Positions of the left top points of patches.
        iou_thr (float): The IoU threshold of NMS.
        device (str): The device to call nms.

    Returns:
        list[np.ndarray]: Detection results after merging.
    """
    assert len(results) == offsets.shape[0], 'The `results` should has the ' \
        'same length with `offsets`.'
    merged_results = []
    for results_per_cls in zip(*results):
        tran_dets = []
        for dets, offset in zip(results_per_cls, offsets):
            # Shift a copy so that the caller's arrays keep their
            # patch-local coordinates.
            dets = dets.copy()
            dets[:, :2] += offset
            tran_dets.append(dets)
        tran_dets = np.concatenate(tran_dets, axis=0)

        if tran_dets.size == 0:
            # Nothing detected for this class: skip nms entirely.
            merged_results.append(tran_dets)
        else:
            tran_dets = torch.from_numpy(tran_dets)
            tran_dets = tran_dets.to(device)
            nms_dets, _ = nms_rotated(tran_dets[:, :5], tran_dets[:, -1],
                                      iou_thr)
            merged_results.append(nms_dets.cpu().numpy())
    return merged_results
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from itertools import product | ||
from math import ceil | ||
|
||
import numpy as np | ||
|
||
|
||
def get_multiscale_patch(sizes, steps, ratios):
    """Get multiscale patch sizes and steps.

    Args:
        sizes (list): A list of patch sizes.
        steps (list): A list of steps to slide patches.
        ratios (list): Multiscale ratios. Each size and step is divided by
            each ratio to generate patches in new scales.

    Returns:
        new_sizes (list): A list of multiscale patch sizes.
        new_steps (list): A list of steps corresponding to new_sizes.
    """
    assert len(sizes) == len(steps), 'The length of `sizes` and `steps` ' \
        'should be the same.'
    new_sizes, new_steps = [], []
    # Every (size, step) pair is combined with every ratio; the results are
    # truncated to integers.
    for (size, step), ratio in product(zip(sizes, steps), ratios):
        new_sizes.append(int(size / ratio))
        new_steps.append(int(step / ratio))
    return new_sizes, new_steps
|
||
|
||
def _window_starts(length, size, step):
    """Return the start positions of windows sliding along one image axis."""
    num = 1 if length <= size else ceil((length - size) / step + 1)
    starts = [step * i for i in range(num)]
    # Pull the last window back so that it ends exactly at the image border.
    if len(starts) > 1 and starts[-1] + size > length:
        starts[-1] = length - size
    return starts


def slide_window(width, height, sizes, steps, img_rate_thr=0.6):
    """Slide windows in images and get window position.

    Args:
        width (int): The width of the image.
        height (int): The height of the image.
        sizes (list): List of window's sizes.
        steps (list): List of window's steps.
        img_rate_thr (float): Threshold of the image area inside a window
            divided by the window area.

    Returns:
        np.ndarray: Information of valid windows, one row
            ``(x_start, y_start, x_stop, y_stop)`` per window.
    """
    assert 1 >= img_rate_thr >= 0, 'The `img_rate_thr` should lie in 0~1'
    windows = []
    # Sliding windows.
    for size, step in zip(sizes, steps):
        assert size > step, 'Size should large than step'

        x_start = _window_starts(width, size, step)
        y_start = _window_starts(height, size, step)

        start = np.array(list(product(x_start, y_start)), dtype=np.int64)
        windows.append(np.concatenate([start, start + size], axis=1))
    windows = np.concatenate(windows, axis=0)

    # Calculate the rate of image part in each window.
    img_in_wins = windows.copy()
    img_in_wins[:, 0::2] = np.clip(img_in_wins[:, 0::2], 0, width)
    img_in_wins[:, 1::2] = np.clip(img_in_wins[:, 1::2], 0, height)
    img_areas = (img_in_wins[:, 2] - img_in_wins[:, 0]) * \
        (img_in_wins[:, 3] - img_in_wins[:, 1])
    win_areas = (windows[:, 2] - windows[:, 0]) * \
        (windows[:, 3] - windows[:, 1])
    img_rates = img_areas / win_areas
    # Keep at least the best covered window(s) when none reaches the
    # threshold, so the result is never empty.
    if not (img_rates >= img_rate_thr).any():
        img_rates[img_rates == img_rates.max()] = 1
    return windows[img_rates >= img_rate_thr]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .loading import LoadPatchFromImage | ||
from .transforms import PolyRandomRotate, RRandomFlip, RResize | ||
|
||
__all__ = ['RResize', 'RRandomFlip', 'PolyRandomRotate'] | ||
__all__ = ['LoadPatchFromImage', 'RResize', 'RRandomFlip', 'PolyRandomRotate'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import mmcv | ||
import numpy as np | ||
from mmdet.datasets.pipelines import LoadImageFromFile | ||
|
||
from ..builder import ROTATED_PIPELINES | ||
|
||
|
||
@ROTATED_PIPELINES.register_module()
class LoadPatchFromImage(LoadImageFromFile):
    """Load a patch from a huge image.

    Similar with :obj:`LoadImageFromFile`, but only keeps the patch of
    ``results['img']`` selected by ``results['win']``.
    """

    def __call__(self, results):
        """Crop the window out of the image and add image meta information.

        Args:
            results (dict): Result dict with the loaded image in
                ``results['img']`` and the window
                ``(x_start, y_start, x_stop, y_stop)`` in ``results['win']``.

        Returns:
            dict: The dict contains the cropped patch and meta information.
        """
        image = results['img']
        left, top, right, bottom = results['win']
        win_h = bottom - top
        win_w = right - left

        patch = image[top:bottom, left:right]
        # The slice is smaller than the window when the window reaches
        # beyond the image; pad it to the full window size in that case.
        if patch.shape[0] < win_h or patch.shape[1] < win_w:
            patch = mmcv.impad(patch, shape=(win_h, win_w))

        if self.to_float32:
            patch = patch.astype(np.float32)

        # The patch is not read from a file of its own, so no file name is
        # recorded; both shapes describe the patch itself.
        results.update(
            filename=None,
            ori_filename=None,
            img=patch,
            img_shape=patch.shape,
            ori_shape=patch.shape,
            img_fields=['img'])
        return results