From 38f22a411b61c5faa0f933cd19f22bba47e64100 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Thu, 10 Jan 2019 16:49:31 +0800
Subject: [PATCH] [TOPI][CUDA] Add faster-rcnn proposal op

---
 topi/python/topi/cuda/__init__.py        |   1 +
 topi/python/topi/cuda/rcnn/__init__.py   |   3 +
 topi/python/topi/cuda/rcnn/proposal.py   | 355 +++++++++++++++++++++++
 topi/python/topi/cuda/vision.py          |  29 ++
 topi/python/topi/generic/vision.py       |  17 ++
 topi/python/topi/vision/rcnn/__init__.py |   1 +
 topi/python/topi/vision/rcnn/proposal.py |  96 ++++++
 topi/tests/python/test_topi_vision.py    |  68 +++++
 8 files changed, 570 insertions(+)
 create mode 100644 topi/python/topi/cuda/rcnn/__init__.py
 create mode 100644 topi/python/topi/cuda/rcnn/proposal.py
 create mode 100644 topi/python/topi/vision/rcnn/proposal.py

diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 28d2eb258beae..91c2235fcf700 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -18,3 +18,4 @@
 from . import ssd
 from .ssd import *
 from .nms import *
+from .rcnn import *
diff --git a/topi/python/topi/cuda/rcnn/__init__.py b/topi/python/topi/cuda/rcnn/__init__.py
new file mode 100644
index 0000000000000..dffea6f7483e8
--- /dev/null
+++ b/topi/python/topi/cuda/rcnn/__init__.py
@@ -0,0 +1,3 @@
+# pylint: disable=wildcard-import
+"""Faster R-CNN and Mask R-CNN operators"""
+from .proposal import *
diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py
new file mode 100644
index 0000000000000..39bb60b4cc090
--- /dev/null
+++ b/topi/python/topi/cuda/rcnn/proposal.py
@@ -0,0 +1,355 @@
+# pylint: disable=invalid-name, singleton-comparison
+"""Proposal operator"""
+import math
+import tvm
+from ...vision.rcnn import proposal, generate_anchor, reg_bbox, reg_iou
+from ...util import get_const_tuple, get_const_int
+
+
+def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
+                    feature_stride, rpn_min_size, iou_loss):
+    """Predict bounding boxes based on anchors, scores and deltas.
+
+    Parameters
+    ----------
+    cls_prob_buf : tvm.schedule.Buffer
+        4-D with shape [batch, 2 * num_anchors, height, width]
+
+    bbox_pred_buf : tvm.schedule.Buffer
+        4-D with shape [batch, 4 * num_anchors, height, width]
+
+    im_info_buf : tvm.schedule.Buffer
+        2-D with shape [batch, 3]
+
+    out_buf : tvm.schedule.Buffer
+        3-D with shape [batch, num_bbox, 5]
+        The last dimension is in format of [w_start, h_start, w_end, h_end, score]
+
+    scales : list/tuple of float
+        Scales of anchor windows.
+
+    ratios : list/tuple of float
+        Ratios of anchor windows.
+
+    feature_stride : int
+        The size of the receptive field of each unit in the convolution layer of the RPN,
+        for example the product of all strides prior to this layer.
+
+    rpn_min_size : int
+        Minimum height or width in proposal.
+
+    iou_loss : bool
+        Whether to use IoU loss.
+
+    Returns
+    -------
+    stmt : Stmt
+        The result IR statement.
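+
+    For reference, the flat output layout written below can be sketched in plain
+    Python (illustration only, not executed by this kernel); ``tid`` enumerates
+    the (b, h, w) positions and ``k`` the anchors::
+
+        tid = (b * height + h) * width + w
+        out_index = tid * num_anchors + k
+        # p_out[out_index * 5 : out_index * 5 + 5] holds
+        # [w_start, h_start, w_end, h_end, score] for anchor k at (b, h, w)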
+ """ + batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape) + num_anchors //= 2 + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = (batch * height * width) // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + tid = bx * max_threads + tx + ib = tvm.ir_builder.create() + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + + p_score = ib.buffer_ptr(cls_prob_buf) + p_delta = ib.buffer_ptr(bbox_pred_buf) + p_im_info = ib.buffer_ptr(im_info_buf) + p_out = ib.buffer_ptr(out_buf) + + with ib.if_scope(tid < batch * height * width): + w = tid % width + h = (tid // width) % height + b = tid // width // height + + for k in range(num_anchors): + out_index = tid * num_anchors + k + ratio = ratios[k // len(scales)] + scale = scales[k % len(scales)] + anchor = generate_anchor(ratio, scale, feature_stride) + im_height = p_im_info[b * 3] + im_width = p_im_info[b * 3 + 1] + x1 = anchor[0] + w * feature_stride + y1 = anchor[1] + h * feature_stride + x2 = anchor[2] + w * feature_stride + y2 = anchor[3] + h * feature_stride + + delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)] + for i in range(4)] + regression_func = reg_iou if iou_loss else reg_bbox + pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta) + + pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0) + pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0) + pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0) + pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0) + + real_height = (im_height / feature_stride).astype('int32') + real_width = (im_width / feature_stride).astype('int32') + + bbox_w = pred_x2 - pred_x1 + 1.0 + bbox_h = pred_y2 - pred_y1 + 1.0 + min_size = p_im_info[b * 3 + 2] * rpn_min_size + + pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w] + pred_score = tvm.select(tvm.any(h >= real_height, w >= real_width), -1.0, pred_score) + + p_out[out_index * 5 + 0] = pred_x1 + p_out[out_index * 5 + 1] = pred_y1 + p_out[out_index * 5 + 2] = pred_x2 + p_out[out_index * 5 + 3] = pred_y2 + p_out[out_index * 5 + 4] = pred_score + + with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)): + p_out[out_index * 5 + 0] -= min_size / 2.0 + p_out[out_index * 5 + 1] -= min_size / 2.0 + p_out[out_index * 5 + 2] += min_size / 2.0 + p_out[out_index * 5 + 3] += min_size / 2.0 + p_out[out_index * 5 + 4] = -1.0 + + return ib.get() + + +def argsort_ir(data_buf, out_index_buf): + """Batched odd-even transposition sort. + + Parameters + ---------- + data_buf : tvm.schedule.Buffer + 2-D with shape [batch, num_bbox] + + out_index_buf : tvm.schedule.Buffer + 2-D with shape [batch, num_bbox]. Indices of data in sorted order. + + Returns + ------- + stmt : Stmt + The result IR statement. 
+ """ + batch, num_bbox = get_const_tuple(data_buf.shape) + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib = tvm.ir_builder.create() + temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") + temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + p_data = ib.buffer_ptr(data_buf) + index_out = ib.buffer_ptr(out_index_buf) + nthread_tx = max_threads + nthread_bx = num_bbox // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * nthread_tx + tx + + with ib.for_range(0, batch, for_type="unroll") as b: + start = b * num_bbox + with ib.if_scope(tid < num_bbox): + index_out[start + tid] = tid + + with ib.for_range(0, num_bbox) as k: + with ib.if_scope(tid < (num_bbox + 1) // 2): + offset = start + 2 * tid + (k % 2) + with ib.if_scope( + tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])): + temp_data[0] = p_data[offset] + p_data[offset] = p_data[offset + 1] + p_data[offset + 1] = temp_data[0] + temp_index[0] = index_out[offset] + index_out[offset] = index_out[offset + 1] + index_out[offset + 1] = temp_index[0] + + return ib.get() + + +def nms_ir(sorted_bbox_buf, out_buf, nms_threshold): + """Non-maximum supression. + + Parameters + ---------- + sorted_bbox_buf : tvm.schedule.Buffer + 3-D with shape [batch, num_bbox, 5]. The last dimension is in format of + [w_start, h_start, w_end, h_end, score]. + + out_buf : tvm.schedule.Buffer + 2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed. + + nms_threshold : float + Non-maximum suppression threshold. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + def calculate_overlap(out_tensor, box_a_idx, box_b_idx): + """Calculate overlap of two boxes. 
+ """ + w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2]) + - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0) + h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3]) + - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0) + i = w * h + u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \ + (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \ + (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \ + (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i + return i / u + + batch, num_bbox = get_const_tuple(out_buf.shape) + max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads)) + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib = tvm.ir_builder.create() + p_data = ib.buffer_ptr(sorted_bbox_buf) + p_out = ib.buffer_ptr(out_buf) + nthread_tx = max_threads + nthread_bx = num_bbox // max_threads + 1 + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + j = bx * max_threads + tx + with ib.for_range(0, batch, for_type="unroll", name="n") as b: + start = b * num_bbox + with ib.if_scope(j < num_bbox): + p_out[start + j] = False + + with ib.for_range(0, num_bbox - 1) as i: + with ib.if_scope(tvm.all(j < num_bbox, j > i, p_out[start + i] == False)): + iou = calculate_overlap(p_data, (start + i) * 5, (start + j) * 5) + with ib.if_scope(iou > nms_threshold): + p_out[start + j] = True + return ib.get() + + +def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf): + """Copy output after applying nms to continuous memory. + + Parameters + ---------- + sorted_bbox_buf : tvm.schedule.Buffer + 3-D with shape [batch, num_bbox, 5]. The last dimension is in format of + [w_start, h_start, w_end, h_end, score]. + + remove_mask_buf : tvm.schedule.Buffer + 2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed. + + out_buf : tvm.schedule.Buffer + 2-D with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of + [batch_index, w_start, h_start, w_end, h_end]. + + Returns + ------- + stmt : Stmt + The result IR statement. 
+ """ + batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape) + rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch + nthread_tx = batch + tx = tvm.thread_axis("threadIdx.x") + ib = tvm.ir_builder.create() + ib.scope_attr(tx, "thread_extent", nthread_tx) + i = ib.allocate('int32', (1,), 'i', scope='local') + i[0] = 0 + p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf) + p_remove = ib.buffer_ptr(remove_mask_buf) + p_out = ib.buffer_ptr(out_buf) + b = tx + + nkeep = ib.allocate('int32', (1,), 'nkeep', scope='local') + nkeep[0] = 0 # number of bbox after nms + + with ib.for_range(0, num_bbox) as j: + with ib.if_scope(p_remove[b * num_bbox + j] == False): + nkeep[0] += 1 + with ib.if_scope(nkeep[0] > 0): + with ib.for_range(0, tvm.ceil( + tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[0]).astype('int32')): + with ib.for_range(0, num_bbox) as j: + offset_j = (b * num_bbox + j) * 5 + offset_i = (b * rpn_post_nms_top_n + i[0]) * 5 + with ib.if_scope(tvm.all(i[0] < rpn_post_nms_top_n, + p_remove[(b*num_bbox+j)] == False)): + p_out[offset_i] = tvm.expr.Cast('float32', b) + with ib.for_range(0, 4, for_type='unroll') as k: + p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k] + i[0] = i[0] + 1 + + body = ib.get() + return body + + +@proposal.register("cuda") +def proposal_cuda(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold, + rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss): + """Proposal operator. + + Parameters + ---------- + cls_prob : tvm.Tensor + 4-D with shape [batch, 2 * num_anchors, height, width] + + bbox_pred : tvm.Tensor + 4-D with shape [batch, 4 * num_anchors, height, width] + + im_info : tvm.Tensor + 2-D with shape [batch, 3] + + scales : list/tuple of float + Scales of anchor windoes. + + ratios : list/tuple of float + Ratios of anchor windoes. + + feature_stride : int + The size of the receptive field each unit in the convolution layer of the rpn, for example + the product of all stride's prior to this layer. + + threshold : float + Non-maximum suppression threshold. + + rpn_pre_nms_top_n : int + Number of top scoring boxes to apply NMS. -1 to use all boxes. + + rpn_post_nms_top_n : int + Number of top scoring boxes to keep after applying NMS to RPN proposals. + + rpn_min_size : int + Minimum height or width in proposal. + + iou_loss : bool + Usage of IoU loss. + + Returns + ------- + out : tvm.Tensor + 2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of + [batch_index, w_start, h_start, w_end, h_end]. 
+ """ + + batch, _, height, width = get_const_tuple(cls_prob.shape) + num_anchors = len(scales) * len(ratios) + num_bbox = height * width * num_anchors + rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox) if rpn_pre_nms_top_n > 0 else num_bbox + + bbox = tvm.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], lambda ins, outs: + predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios, + feature_stride, rpn_min_size, iou_loss), + dtype=bbox_pred.dtype) + score = tvm.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score') + sorted_index = tvm.extern([score.shape], [score], + lambda ins, outs: argsort_ir(ins[0], outs[0]), + dtype='int32') + sorted_bbox = tvm.compute((batch, rpn_pre_nms_top_n, 5), + lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox') + nms_remove_mask = tvm.extern((batch, rpn_pre_nms_top_n), [sorted_bbox], + lambda ins, outs: nms_ir(ins[0], outs[0], threshold), + dtype='bool') + nms_out = tvm.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask], + lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]), + dtype=sorted_bbox.dtype) + return nms_out diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index 83744593857b9..835c346f8321f 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -151,3 +151,32 @@ def schedule_multibox_detection(outs): @generic.schedule_roi_align.register(["cuda", "gpu"]) def schedule_roi_align(outs): return schedule_pool(outs, 'NCHW') + +@generic.schedule_proposal.register(["cuda", "gpu"]) +def schedule_proposal(outs): + """Schedule for proposal operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of roi_align + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + from .injective import _schedule_injective + def traverse(op): + if op.tag in ['bbox_score', 'sorted_bbox']: + _schedule_injective(op, s) + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + traverse(outs[0].op) + return s diff --git a/topi/python/topi/generic/vision.py b/topi/python/topi/generic/vision.py index a053727770e9f..8f903c11680d8 100644 --- a/topi/python/topi/generic/vision.py +++ b/topi/python/topi/generic/vision.py @@ -157,3 +157,20 @@ def schedule_roi_align(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + +@tvm.target.generic_func +def schedule_proposal(outs): + """Schedule for proposal operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of roi_align + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. 
+ """ + return _default_schedule(outs, False) diff --git a/topi/python/topi/vision/rcnn/__init__.py b/topi/python/topi/vision/rcnn/__init__.py index 3704d98908d61..4f6148e67b71a 100644 --- a/topi/python/topi/vision/rcnn/__init__.py +++ b/topi/python/topi/vision/rcnn/__init__.py @@ -1,3 +1,4 @@ # pylint: disable=wildcard-import """Faster R-CNN and Mask R-CNN operators""" from .roi_align import * +from .proposal import * diff --git a/topi/python/topi/vision/rcnn/proposal.py b/topi/python/topi/vision/rcnn/proposal.py new file mode 100644 index 0000000000000..d9b1f98e22868 --- /dev/null +++ b/topi/python/topi/vision/rcnn/proposal.py @@ -0,0 +1,96 @@ +# pylint: disable=invalid-name +"""Proposal operator""" +import math +import tvm + + +def generate_anchor(ratio, scale, base_size): + """Generate anchor""" + w = h = float(base_size) + x_ctr = 0.5 * (w - 1.) + y_ctr = 0.5 * (h - 1.) + size = w * h + size_ratios = math.floor(size / ratio) + new_w = math.floor(math.sqrt(size_ratios) + 0.5) * scale + new_h = math.floor((new_w / scale * ratio) + 0.5) * scale + return (x_ctr - 0.5 * (new_w - 1.0), y_ctr - 0.5 * (new_h - 1.0), + x_ctr + 0.5 * (new_w - 1.0), y_ctr + 0.5 * (new_h - 1.0)) + + +def reg_bbox(x1, y1, x2, y2, dx, dy, dw, dh): + """Bounding box regression function""" + bbox_w = x2 - x1 + 1.0 + bbox_h = y2 - y1 + 1.0 + ctr_x = x1 + 0.5 * (bbox_w - 1.0) + ctr_y = y1 + 0.5 * (bbox_h - 1.0) + + pred_ctr_x = dx * bbox_w + ctr_x + pred_ctr_y = dy * bbox_h + ctr_y + pred_w = tvm.exp(dw) * bbox_w + pred_h = tvm.exp(dh) * bbox_h + + pred_x1 = pred_ctr_x - 0.5 * (pred_w - 1.0) + pred_y1 = pred_ctr_y - 0.5 * (pred_h - 1.0) + pred_x2 = pred_ctr_x + 0.5 * (pred_w - 1.0) + pred_y2 = pred_ctr_y + 0.5 * (pred_h - 1.0) + return pred_x1, pred_y1, pred_x2, pred_y2 + + +def reg_iou(x1, y1, x2, y2, dx1, dy1, dx2, dy2): + """Bounding box regression function""" + pred_x1 = x1 + dx1 + pred_y1 = y1 + dy1 + pred_x2 = x2 + dx2 + pred_y2 = y2 + dy2 + return pred_x1, pred_y1, pred_x2, pred_y2 + + +@tvm.target.generic_func +def proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold, + rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss): + """Proposal operator. + + Parameters + ---------- + cls_prob : tvm.Tensor + 4-D with shape [batch, 2 * num_anchors, height, width] + + bbox_pred : tvm.Tensor + 4-D with shape [batch, 4 * num_anchors, height, width] + + im_info : tvm.Tensor + 2-D with shape [batch, 3] + + scales : list/tuple of float + Scales of anchor windoes. + + ratios : list/tuple of float + Ratios of anchor windoes. + + feature_stride : int + The size of the receptive field each unit in the convolution layer of the rpn, for example + the product of all stride's prior to this layer. + + threshold : float + Non-maximum suppression threshold. + + rpn_pre_nms_top_n : int + Number of top scoring boxes to apply NMS. -1 to use all boxes. + + rpn_post_nms_top_n : int + Number of top scoring boxes to keep after applying NMS to RPN proposals. + + rpn_min_size : int + Minimum height or width in proposal. + + iou_loss : bool + Usage of IoU loss. + + Returns + ------- + out : tvm.Tensor + 2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of + [batch_index, w_start, h_start, w_end, h_end]. 
+ """ + # pylint: disable=unused-argument + raise ValueError("missing register for topi.vision.rcnn.proposal") diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index ce0f706533df5..12557a329fd42 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -1,4 +1,5 @@ """Test code for vision package""" +from __future__ import print_function import math import numpy as np import tvm @@ -206,8 +207,75 @@ def test_roi_align(): verify_roi_align(4, 16, 32, 64, 7, 0.5, 2) +def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): + cls_prob = tvm.placeholder(np_cls_prob.shape) + bbox_pred = tvm.placeholder(np_bbox_pred.shape) + im_info = tvm.placeholder(np_im_info.shape, dtype='int32') + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + out = topi.vision.proposal(cls_prob, bbox_pred, im_info, **attrs) + s = topi.generic.schedule_proposal(out) + f = tvm.build(s, [cls_prob, bbox_pred, im_info, out], device) + tvm_cls_prob = tvm.nd.array(np_cls_prob, ctx=ctx) + tvm_bbox_pred = tvm.nd.array(np_bbox_pred, ctx=ctx) + tvm_im_info = tvm.nd.array(np_im_info, ctx=ctx) + tvm_out = tvm.nd.empty(ctx=ctx, shape=out.shape, dtype=out.dtype) + f(tvm_cls_prob, tvm_bbox_pred, tvm_im_info, tvm_out) + tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4) + + for device in ['cuda']: + check_device(device) + + +def test_proposal(): + attrs = {'scales': (0.5,),'ratios': (0.5,), + 'feature_stride': 16, + 'iou_loss': False, + 'rpn_min_size': 16, + 'threshold': 0.7, + 'rpn_pre_nms_top_n': 200, + 'rpn_post_nms_top_n': 4, + } + np_cls_prob = np.array([[ + [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]], + [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]] + ]], dtype='float32') + np_bbox_pred = np.array([[ + [[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]], + [[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]], + [[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]], + [[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]], + ]], dtype='float32') + np_im_info = np.array([[48, 48, 1]], dtype='int32') + np_out = np.array([ + [0., 0., 2.8451548,28.38012, 18.154846], + [0., 0., 15.354933, 41.96971, 41.245064], + [0., 18.019852, 1.0538368, 51.98015, 25.946163], + [0., 27.320923, -1.266357, 55., 24.666357] + ], dtype='float32') + + verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) + + np_out = np.array([ + [ 0., -5.25, -2.5, 21.75, 19.], + [ 0., 11.25, -2., 37.25, 18.5], + [ 0., 26.849998, -2.3000002, 53.45, 18.6], + [ 0., -4.95, 13.799999, 22.25, 35.5] + ], dtype='float32') + + attrs['iou_loss'] = True + verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) + + if __name__ == "__main__": test_nms() test_multibox_prior() test_multibox_detection() test_roi_align() + test_proposal()