[TOPI][CUDA] Add faster-rcnn proposal op
vinx13 committed Jan 18, 2019
1 parent 237dbf2 commit 38f22a4
Showing 8 changed files with 570 additions and 0 deletions.
1 change: 1 addition & 0 deletions topi/python/topi/cuda/__init__.py
@@ -18,3 +18,4 @@
from . import ssd
from .ssd import *
from .nms import *
from .rcnn import *
3 changes: 3 additions & 0 deletions topi/python/topi/cuda/rcnn/__init__.py
@@ -0,0 +1,3 @@
# pylint: disable=wildcard-import
"""Faster R-CNN and Mask R-CNN operators"""
from .proposal import *
355 changes: 355 additions & 0 deletions topi/python/topi/cuda/rcnn/proposal.py
@@ -0,0 +1,355 @@
# pylint: disable=invalid-name, singleton-comparison
"""Proposal operator"""
import math
import tvm
from ...vision.rcnn import proposal, generate_anchor, reg_bbox, reg_iou
from ...util import get_const_tuple, get_const_int


def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, ratios,
feature_stride, rpn_min_size, iou_loss):
"""Predict bounding boxes based on anchors, scores and deltas.
Parameters
----------
cls_prob_buf : tvm.schedule.Buffer
4-D with shape [batch, 2 * num_anchors, height, width]
bbox_pred_buf : tvm.schedule.Buffer
4-D with shape [batch, 4 * num_anchors, height, width]
im_info_buf : tvm.schedule.Buffer
2-D with shape [batch, 3]
out_buf : tvm.schedule.Buffer
3-D with shape [batch, num_bbox, 5]
The last dimension is in format of [w_start, h_start, w_end, h_end, score]
scales : list/tuple of float
Scales of anchor windows.
ratios : list/tuple of float
Ratios of anchor windows.
feature_stride : int
The size of the receptive field of each unit in the convolution layer of the RPN, for example
the product of all strides prior to this layer.
rpn_min_size : int
Minimum height or width of a proposal.
iou_loss : bool
Whether to use IoU loss for bounding-box regression.
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch, num_anchors, height, width = get_const_tuple(cls_prob_buf.shape)
num_anchors //= 2
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = (batch * height * width) // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
tid = bx * max_threads + tx
ib = tvm.ir_builder.create()
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)

p_score = ib.buffer_ptr(cls_prob_buf)
p_delta = ib.buffer_ptr(bbox_pred_buf)
p_im_info = ib.buffer_ptr(im_info_buf)
p_out = ib.buffer_ptr(out_buf)

with ib.if_scope(tid < batch * height * width):
w = tid % width
h = (tid // width) % height
b = tid // width // height

for k in range(num_anchors):
out_index = tid * num_anchors + k
ratio = ratios[k // len(scales)]
scale = scales[k % len(scales)]
anchor = generate_anchor(ratio, scale, feature_stride)
im_height = p_im_info[b * 3]
im_width = p_im_info[b * 3 + 1]
x1 = anchor[0] + w * feature_stride
y1 = anchor[1] + h * feature_stride
x2 = anchor[2] + w * feature_stride
y2 = anchor[3] + h * feature_stride

delta = [p_delta[((((b * num_anchors + k) * 4 + i) * height + h) * width + w)]
for i in range(4)]
regression_func = reg_iou if iou_loss else reg_bbox
pred_x1, pred_y1, pred_x2, pred_y2 = regression_func(x1, y1, x2, y2, *delta)

pred_x1 = tvm.max(tvm.min(pred_x1, im_width - 1.0), 0.0)
pred_y1 = tvm.max(tvm.min(pred_y1, im_height - 1.0), 0.0)
pred_x2 = tvm.max(tvm.min(pred_x2, im_width - 1.0), 0.0)
pred_y2 = tvm.max(tvm.min(pred_y2, im_height - 1.0), 0.0)

real_height = (im_height / feature_stride).astype('int32')
real_width = (im_width / feature_stride).astype('int32')

bbox_w = pred_x2 - pred_x1 + 1.0
bbox_h = pred_y2 - pred_y1 + 1.0
min_size = p_im_info[b * 3 + 2] * rpn_min_size

pred_score = p_score[((b * num_anchors * 2 + num_anchors + k) * height + h) * width + w]
pred_score = tvm.select(tvm.any(h >= real_height, w >= real_width), -1.0, pred_score)

p_out[out_index * 5 + 0] = pred_x1
p_out[out_index * 5 + 1] = pred_y1
p_out[out_index * 5 + 2] = pred_x2
p_out[out_index * 5 + 3] = pred_y2
p_out[out_index * 5 + 4] = pred_score

with ib.if_scope(tvm.any(bbox_w < min_size, bbox_h < min_size)):
p_out[out_index * 5 + 0] -= min_size / 2.0
p_out[out_index * 5 + 1] -= min_size / 2.0
p_out[out_index * 5 + 2] += min_size / 2.0
p_out[out_index * 5 + 3] += min_size / 2.0
p_out[out_index * 5 + 4] = -1.0

return ib.get()
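
# --- Illustrative sketch (not part of this commit) -------------------------
# `reg_bbox` and `reg_iou` are imported from topi.vision.rcnn and are not shown
# in this diff. The reference below assumes `reg_bbox` implements the standard
# Faster R-CNN box decoding, where the deltas (dx, dy, dw, dh) shift the
# anchor's center and rescale its size.
import math  # already imported at the top of this file; repeated so the sketch is self-contained

def reg_bbox_reference(x1, y1, x2, y2, dx, dy, dw, dh):
    """Decode one anchor with one set of regression deltas (assumed formula)."""
    w = x2 - x1 + 1.0
    h = y2 - y1 + 1.0
    cx = x1 + 0.5 * (w - 1.0)
    cy = y1 + 0.5 * (h - 1.0)
    pred_cx = dx * w + cx             # move the center by a fraction of the anchor size
    pred_cy = dy * h + cy
    pred_w = math.exp(dw) * w         # rescale the size exponentially
    pred_h = math.exp(dh) * h
    return (pred_cx - 0.5 * (pred_w - 1.0), pred_cy - 0.5 * (pred_h - 1.0),
            pred_cx + 0.5 * (pred_w - 1.0), pred_cy + 0.5 * (pred_h - 1.0))
# ----------------------------------------------------------------------------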


def argsort_ir(data_buf, out_index_buf):
"""Batched odd-even transposition sort.
Parameters
----------
data_buf : tvm.schedule.Buffer
2-D with shape [batch, num_bbox]
out_index_buf : tvm.schedule.Buffer
2-D with shape [batch, num_bbox]. Indices of data in sorted (descending) order.
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch, num_bbox = get_const_tuple(data_buf.shape)
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
p_data = ib.buffer_ptr(data_buf)
index_out = ib.buffer_ptr(out_index_buf)
nthread_tx = max_threads
nthread_bx = num_bbox // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * nthread_tx + tx

with ib.for_range(0, batch, for_type="unroll") as b:
start = b * num_bbox
with ib.if_scope(tid < num_bbox):
index_out[start + tid] = tid

with ib.for_range(0, num_bbox) as k:
with ib.if_scope(tid < (num_bbox + 1) // 2):
offset = start + 2 * tid + (k % 2)
with ib.if_scope(
tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
temp_data[0] = p_data[offset]
p_data[offset] = p_data[offset + 1]
p_data[offset + 1] = temp_data[0]
temp_index[0] = index_out[offset]
index_out[offset] = index_out[offset + 1]
index_out[offset + 1] = temp_index[0]

return ib.get()
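
# --- Illustrative sketch (not part of this commit) -------------------------
# Pure-Python reference for the odd-even transposition sort emitted above: each
# round compares disjoint adjacent pairs (even-offset pairs on even rounds,
# odd-offset pairs on odd rounds), so every comparison in a round can run on
# its own thread; after num_bbox rounds each row is sorted in descending order
# of score.
def odd_even_argsort_reference(scores):
    """Return indices that sort a 1-D list of scores in descending order."""
    data = list(scores)
    index = list(range(len(data)))
    n = len(data)
    for k in range(n):
        for tid in range((n + 1) // 2):      # one "thread" per pair
            i = 2 * tid + (k % 2)
            if i + 1 < n and data[i] < data[i + 1]:
                data[i], data[i + 1] = data[i + 1], data[i]
                index[i], index[i + 1] = index[i + 1], index[i]
    return index

assert odd_even_argsort_reference([0.1, 0.9, 0.5]) == [1, 2, 0]
# ----------------------------------------------------------------------------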


def nms_ir(sorted_bbox_buf, out_buf, nms_threshold):
"""Non-maximum supression.
Parameters
----------
sorted_bbox_buf : tvm.schedule.Buffer
3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
[w_start, h_start, w_end, h_end, score].
out_buf : tvm.schedule.Buffer
2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.
nms_threshold : float
Non-maximum suppression threshold.
Returns
-------
stmt : Stmt
The result IR statement.
"""
def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
"""Calculate overlap of two boxes.
"""
w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
- tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]) + 1.0)
h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
- tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]) + 1.0)
i = w * h
u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx] + 1.0) * \
(out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1] + 1.0) + \
(out_tensor[box_b_idx + 2] - out_tensor[box_b_idx] + 1.0) * \
(out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1] + 1.0) - i
return i / u

batch, num_bbox = get_const_tuple(out_buf.shape)
max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib = tvm.ir_builder.create()
p_data = ib.buffer_ptr(sorted_bbox_buf)
p_out = ib.buffer_ptr(out_buf)
nthread_tx = max_threads
nthread_bx = num_bbox // max_threads + 1
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
j = bx * max_threads + tx
with ib.for_range(0, batch, for_type="unroll", name="n") as b:
start = b * num_bbox
with ib.if_scope(j < num_bbox):
p_out[start + j] = False

with ib.for_range(0, num_bbox - 1) as i:
with ib.if_scope(tvm.all(j < num_bbox, j > i, p_out[start + i] == False)):
iou = calculate_overlap(p_data, (start + i) * 5, (start + j) * 5)
with ib.if_scope(iou > nms_threshold):
p_out[start + j] = True
return ib.get()
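
# --- Illustrative sketch (not part of this commit) -------------------------
# Plain-Python equivalent of `calculate_overlap`: intersection over union with
# the inclusive-pixel (+1) convention used throughout this file. A box j is
# flagged for removal when some higher-scored box i that is itself kept
# overlaps it by more than nms_threshold.
def iou_reference(box_a, box_b):
    """box = (x1, y1, x2, y2); returns IoU using the +1 pixel convention."""
    w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]) + 1.0)
    h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]) + 1.0)
    inter = w * h
    area_a = (box_a[2] - box_a[0] + 1.0) * (box_a[3] - box_a[1] + 1.0)
    area_b = (box_b[2] - box_b[0] + 1.0) * (box_b[3] - box_b[1] + 1.0)
    return inter / (area_a + area_b - inter)

assert iou_reference((0, 0, 9, 9), (0, 0, 9, 9)) == 1.0
# ----------------------------------------------------------------------------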


def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf):
"""Copy output after applying nms to continuous memory.
Parameters
----------
sorted_bbox_buf : tvm.schedule.Buffer
3-D with shape [batch, num_bbox, 5]. The last dimension is in format of
[w_start, h_start, w_end, h_end, score].
remove_mask_buf : tvm.schedule.Buffer
2-D with shape [batch, num_bbox]. Boolean mask of whether a bounding box should be removed.
out_buf : tvm.schedule.Buffer
2-D with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
[batch_index, w_start, h_start, w_end, h_end].
Returns
-------
stmt : Stmt
The result IR statement.
"""
batch, num_bbox, _ = get_const_tuple(sorted_bbox_buf.shape)
rpn_post_nms_top_n = get_const_int(out_buf.shape[0]) // batch
nthread_tx = batch
tx = tvm.thread_axis("threadIdx.x")
ib = tvm.ir_builder.create()
ib.scope_attr(tx, "thread_extent", nthread_tx)
i = ib.allocate('int32', (1,), 'i', scope='local')
i[0] = 0
p_sorted_bbox = ib.buffer_ptr(sorted_bbox_buf)
p_remove = ib.buffer_ptr(remove_mask_buf)
p_out = ib.buffer_ptr(out_buf)
b = tx

nkeep = ib.allocate('int32', (1,), 'nkeep', scope='local')
nkeep[0] = 0 # number of bbox after nms

with ib.for_range(0, num_bbox) as j:
with ib.if_scope(p_remove[b * num_bbox + j] == False):
nkeep[0] += 1
with ib.if_scope(nkeep[0] > 0):
with ib.for_range(0, tvm.ceil(
tvm.const(rpn_post_nms_top_n, 'float32') / nkeep[0]).astype('int32')):
with ib.for_range(0, num_bbox) as j:
offset_j = (b * num_bbox + j) * 5
offset_i = (b * rpn_post_nms_top_n + i[0]) * 5
with ib.if_scope(tvm.all(i[0] < rpn_post_nms_top_n,
p_remove[(b*num_bbox+j)] == False)):
p_out[offset_i] = tvm.expr.Cast('float32', b)
with ib.for_range(0, 4, for_type='unroll') as k:
p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k]
i[0] = i[0] + 1

body = ib.get()
return body
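
# --- Illustrative sketch (not part of this commit) -------------------------
# What prepare_output_ir computes for one batch element, in plain Python:
# surviving boxes are copied out in score order, and when fewer than
# rpn_post_nms_top_n boxes survive NMS the survivor list is cycled until the
# fixed-size output is full (the role of the ceil(top_n / nkeep) loop above).
def prepare_output_reference(sorted_bbox, removed, top_n, batch_index):
    """sorted_bbox: list of [x1, y1, x2, y2, score]; removed: list of bool."""
    kept = [box[:4] for box, removed_flag in zip(sorted_bbox, removed) if not removed_flag]
    out = []
    while kept and len(out) < top_n:
        for box in kept:
            if len(out) == top_n:
                break
            out.append([float(batch_index)] + box)
    return out
# ----------------------------------------------------------------------------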


@proposal.register("cuda")
def proposal_cuda(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold,
rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss):
"""Proposal operator.
Parameters
----------
cls_prob : tvm.Tensor
4-D with shape [batch, 2 * num_anchors, height, width]
bbox_pred : tvm.Tensor
4-D with shape [batch, 4 * num_anchors, height, width]
im_info : tvm.Tensor
2-D with shape [batch, 3]
scales : list/tuple of float
Scales of anchor windows.
ratios : list/tuple of float
Ratios of anchor windows.
feature_stride : int
The size of the receptive field of each unit in the convolution layer of the RPN, for example
the product of all strides prior to this layer.
threshold : float
Non-maximum suppression threshold.
rpn_pre_nms_top_n : int
Number of top-scoring boxes to keep before applying NMS. -1 to use all boxes.
rpn_post_nms_top_n : int
Number of top scoring boxes to keep after applying NMS to RPN proposals.
rpn_min_size : int
Minimum height or width of a proposal.
iou_loss : bool
Whether to use IoU loss for bounding-box regression.
Returns
-------
out : tvm.Tensor
2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of
[batch_index, w_start, h_start, w_end, h_end].
"""

batch, _, height, width = get_const_tuple(cls_prob.shape)
num_anchors = len(scales) * len(ratios)
num_bbox = height * width * num_anchors
rpn_pre_nms_top_n = min(rpn_pre_nms_top_n, num_bbox) if rpn_pre_nms_top_n > 0 else num_bbox

bbox = tvm.extern((batch, num_bbox, 5), [cls_prob, bbox_pred, im_info], lambda ins, outs:
predict_bbox_ir(ins[0], ins[1], ins[2], outs[0], scales, ratios,
feature_stride, rpn_min_size, iou_loss),
dtype=bbox_pred.dtype)
score = tvm.compute((batch, num_bbox), lambda b, i: bbox[b, i, 4], tag='bbox_score')
sorted_index = tvm.extern([score.shape], [score],
lambda ins, outs: argsort_ir(ins[0], outs[0]),
dtype='int32')
sorted_bbox = tvm.compute((batch, rpn_pre_nms_top_n, 5),
lambda b, i, j: bbox[b, sorted_index[b, i], j], tag='sorted_bbox')
nms_remove_mask = tvm.extern((batch, rpn_pre_nms_top_n), [sorted_bbox],
lambda ins, outs: nms_ir(ins[0], outs[0], threshold),
dtype='bool')
nms_out = tvm.extern((batch * rpn_post_nms_top_n, 5), [sorted_bbox, nms_remove_mask],
lambda ins, outs: prepare_output_ir(ins[0], ins[1], outs[0]),
dtype=sorted_bbox.dtype)
return nms_out
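
# --- Shape walk-through (illustrative, not part of this commit) -------------
# With a hypothetical setting of batch=1, scales=(8, 16, 32), ratios=(0.5, 1, 2)
# (so num_anchors = 9), a 14x14 feature map, rpn_pre_nms_top_n=6000 and
# rpn_post_nms_top_n=300, the intermediate tensors have these shapes:
#   bbox            : (1, 1764, 5)    # num_bbox = 14 * 14 * 9 = 1764
#   score           : (1, 1764)
#   sorted_index    : (1, 1764)
#   sorted_bbox     : (1, 1764, 5)    # rpn_pre_nms_top_n is clipped to num_bbox
#   nms_remove_mask : (1, 1764)
#   nms_out         : (300, 5)        # batch * rpn_post_nms_top_n rows
# -----------------------------------------------------------------------------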
29 changes: 29 additions & 0 deletions topi/python/topi/cuda/vision.py
@@ -151,3 +151,32 @@ def schedule_multibox_detection(outs):
@generic.schedule_roi_align.register(["cuda", "gpu"])
def schedule_roi_align(outs):
return schedule_pool(outs, 'NCHW')

@generic.schedule_proposal.register(["cuda", "gpu"])
def schedule_proposal(outs):
"""Schedule for proposal operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of proposal
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
if op.tag in ['bbox_score', 'sorted_bbox']:
_schedule_injective(op, s)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
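
For reference, the new op can be exercised end to end through the generic dispatchers registered above. The snippet below is a minimal sketch, not part of this commit, assuming the TVM-0.5-era Python APIs (tvm.placeholder, tvm.target.create, tvm.build) together with made-up shapes and hyper-parameters:

import numpy as np
import tvm
import topi

# Hypothetical sizes: batch=1, 3 scales x 3 ratios = 9 anchors, 14x14 feature map.
cls_prob = tvm.placeholder((1, 2 * 9, 14, 14), name='cls_prob')
bbox_pred = tvm.placeholder((1, 4 * 9, 14, 14), name='bbox_pred')
im_info = tvm.placeholder((1, 3), name='im_info')

with tvm.target.create('cuda'):
    out = topi.vision.rcnn.proposal(cls_prob, bbox_pred, im_info,
                                    scales=(8.0, 16.0, 32.0), ratios=(0.5, 1.0, 2.0),
                                    feature_stride=16, threshold=0.7,
                                    rpn_pre_nms_top_n=6000, rpn_post_nms_top_n=300,
                                    rpn_min_size=16, iou_loss=False)
    s = topi.generic.schedule_proposal(out)

f = tvm.build(s, [cls_prob, bbox_pred, im_info, out], 'cuda')
ctx = tvm.gpu(0)
out_nd = tvm.nd.empty((300, 5), 'float32', ctx)
f(tvm.nd.array(np.random.uniform(size=(1, 18, 14, 14)).astype('float32'), ctx),
  tvm.nd.array(np.random.uniform(size=(1, 36, 14, 14)).astype('float32'), ctx),
  tvm.nd.array(np.array([[224.0, 224.0, 1.0]], dtype='float32'), ctx),  # [height, width, scale]
  out_nd)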