Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI, Relay] ROI Pool operator #2811

Merged
merged 1 commit into from
Mar 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,26 @@ struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
}
};

/*! \brief Attributes used in roi_pool operators */
struct ROIPoolAttrs : public tvm::AttrsNode<ROIPoolAttrs> {
Array<IndexExpr> pooled_size;
double spatial_scale;
std::string layout;
TVM_DECLARE_ATTRS(ROIPoolAttrs, "relay.attrs.ROIPoolAttrs") {
TVM_ATTR_FIELD(pooled_size).describe("Output size of roi pool.");
TVM_ATTR_FIELD(spatial_scale)
.describe(
"Ratio of input feature map height (or w) to raw image height (or w). "
"Equals the reciprocal of total stride in convolutional layers, which should be "
"in range (0.0, 1.0]");
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
"Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
}
};

/*! \brief Attributes used in yolo reorg operators */
struct YoloReorgAttrs : public tvm::AttrsNode<YoloReorgAttrs> {
Integer stride;
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,14 @@ def _mx_roi_align(inputs, attrs):
return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs)


def _mx_roi_pooling(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
new_attrs["spatial_scale"] = attrs.get_float("spatial_scale")
new_attrs["layout"] = "NCHW"
return _op.vision.roi_pool(inputs[0], inputs[1], **new_attrs)


def _mx_proposal(inputs, attrs):
new_attrs = {}
new_attrs["scales"] = attrs.get_float_tuple("scales", (4.0, 8.0, 16.0, 32.0))
Expand Down Expand Up @@ -496,6 +504,7 @@ def _mx_proposal(inputs, attrs):
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
"_contrib_MultiBoxDetection" : _mx_multibox_detection,
"_contrib_ROIAlign" : _mx_roi_align,
"ROIPooling" : _mx_roi_pooling,
"_contrib_Proposal" : _mx_proposal,
"_contrib_MultiProposal" : _mx_proposal,
# List of missing operators that are present in NNVMv1
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relay/op/vision/_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,22 @@ def schedule_roi_align(_, outs, target):

reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("vision.roi_pool")
def compute_roi_pool(attrs, inputs, _, target):
"""Compute definition of roi_pool"""
assert attrs.layout == "NCHW"
return [topi.vision.rcnn.roi_pool_nchw(
inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
spatial_scale=attrs.spatial_scale)]

@reg.register_schedule("vision.roi_pool")
def schedule_roi_pool(_, outs, target):
"""Schedule definition of roi_pool"""
with target:
return topi.generic.vision.schedule_roi_pool(outs)

reg.register_pattern("vision.roi_pool", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("vision.proposal")
def compute_proposal(attrs, inputs, _, target):
"""Compute definition of proposal"""
Expand Down
27 changes: 27 additions & 0 deletions python/tvm/relay/op/vision/rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,33 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='N
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)


def roi_pool(data, rois, pooled_size, spatial_scale, layout='NCHW'):
"""ROI pool operator.

Parameters
----------
data : relay.Expr
4-D tensor with shape [batch, channel, height, width]

rois : relay.Expr
2-D tensor with shape [num_roi, 5]. The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end]

pooled_size : list/tuple of two ints
output size

spatial_scale : float
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal
of total stride in convolutional layers, which should be in range (0.0, 1.0]

Returns
-------
output : relay.Expr
4-D tensor with shape [num_roi, channel, pooled_size, pooled_size]
"""
return _make.roi_pool(data, rois, pooled_size, spatial_scale, layout)


def proposal(cls_prob,
bbox_pred,
im_info,
Expand Down
52 changes: 52 additions & 0 deletions src/relay/op/vision/rcnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,58 @@ RELAY_REGISTER_OP("vision.roi_align")
.set_support_level(5)
.add_type_rel("ROIAlign", ROIAlignRel);

TVM_REGISTER_NODE_TYPE(ROIPoolAttrs);

bool ROIPoolRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
auto roi_pool_attrs = attrs.as<ROIPoolAttrs>();
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* rois = types[1].as<TensorTypeNode>();
const auto& dshape = data->shape;
const auto& rshape = rois->shape;
CHECK(roi_pool_attrs);
CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
CHECK_EQ(roi_pool_attrs->layout, "NCHW") << "ROI Pool only supports NCHW layout";
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_pool_attrs->pooled_size[0], roi_pool_attrs->pooled_size[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}

Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
std::string layout) {
auto attrs = make_node<ROIPoolAttrs>();
attrs->pooled_size = pooled_size;
attrs->spatial_scale = spatial_scale;
attrs->layout = layout;
static const Op& op = Op::Get("vision.roi_pool");
return CallNode::make(op, {data, rois}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.vision._make.roi_pool")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeROIPool, args, rv);
});

RELAY_REGISTER_OP("vision.roi_pool")
.describe(R"doc(ROI Pool operator.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **rois**: 2D array of shape (num_roi, 5). The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end].
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("rois", "Tensor", "The input rois")
.set_support_level(5)
.add_type_rel("ROIPool", ROIPoolRel);

TVM_REGISTER_NODE_TYPE(ProposalAttrs);

bool ProposalRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down
33 changes: 33 additions & 0 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,38 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_
verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2)


def test_roi_pool():
def verify_roi_pool(data_shape, rois_shape, pooled_size, spatial_scale):
data = relay.var("data", relay.ty.TensorType(data_shape, "float32"))
rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32"))
z = relay.vision.roi_pool(data, rois, pooled_size=(pooled_size, pooled_size),
spatial_scale=spatial_scale, layout="NCHW")
zz = relay.ir_pass.infer_type(z)

batch, channel, in_size, _ = data_shape
num_roi = rois_shape[0]
assert zz.checked_type == relay.ty.TensorType(
(num_roi, channel, pooled_size, pooled_size), "float32")

func = relay.Function([data, rois], z)
func = relay.ir_pass.infer_type(func)
np_data = np.random.uniform(size=data_shape).astype("float32")
np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size
np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi).astype('float32')
ref_res = topi.testing.roi_pool_nchw_python(np_data, np_rois, pooled_size=pooled_size,
spatial_scale=spatial_scale)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(np_data, np_rois)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(np_data, np_rois)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-4)

verify_roi_pool((1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0)
verify_roi_pool((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5)


def test_proposal():
def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs):
cls_prob = relay.var("cls_prob", relay.ty.TensorType(np_cls_prob.shape, "float32"))
Expand Down Expand Up @@ -413,6 +445,7 @@ def verify_yolo_reorg(shape, stride):
test_multibox_transform_loc()
test_nms()
test_roi_align()
test_roi_pool()
test_proposal()
test_yolo_reorg_infer_shape()
test_yolo_reorg()
4 changes: 4 additions & 0 deletions topi/python/topi/cuda/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def schedule_multibox_detection(outs):
def schedule_roi_align(outs):
return schedule_pool(outs, 'NCHW')

@generic.schedule_roi_pool.register(["cuda", "gpu"])
def schedule_roi_pool(outs):
return schedule_pool(outs, 'NCHW')

@generic.schedule_proposal.register(["cuda", "gpu"])
def schedule_proposal(outs):
"""Schedule for proposal operator.
Expand Down
17 changes: 17 additions & 0 deletions topi/python/topi/generic/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,23 @@ def schedule_roi_align(outs):
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_roi_pool(outs):
"""Schedule for roi_align

Parameters
----------
outs: Array of Tensor
The computation graph description of roi_pool
in the format of an array of tensors.

Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_proposal(outs):
"""Schedule for proposal operator.
Expand Down
1 change: 1 addition & 0 deletions topi/python/topi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .bilinear_resize_python import bilinear_resize_python
from .reorg_python import reorg_python
from .roi_align_python import roi_align_nchw_python
from .roi_pool_python import roi_pool_nchw_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python
Expand Down
47 changes: 47 additions & 0 deletions topi/python/topi/testing/roi_pool_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# pylint: disable=invalid-name, too-many-nested-blocks
"Roi pool in python"
import math
import numpy as np

def roi_pool_nchw_python(a_np, rois_np, pooled_size, spatial_scale):
"""Roi pool in python"""
_, channel, height, width = a_np.shape
num_roi = rois_np.shape[0]
b_np = np.zeros((num_roi, channel, pooled_size, pooled_size), dtype=a_np.dtype)

if isinstance(pooled_size, int):
pooled_size_h = pooled_size_w = pooled_size
else:
pooled_size_h, pooled_size_w = pooled_size

for i in range(num_roi):
roi = rois_np[i]
batch_index = int(roi[0])
roi_start_w = int(round(roi[1] * spatial_scale))
roi_start_h = int(round(roi[2] * spatial_scale))
roi_end_w = int(round(roi[3] * spatial_scale))
roi_end_h = int(round(roi[4] * spatial_scale))
roi_h = max(roi_end_h - roi_start_h + 1, 1)
roi_w = max(roi_end_w - roi_start_w + 1, 1)

bin_h = float(roi_h) / pooled_size_h
bin_w = float(roi_w) / pooled_size_w

for ph in range(pooled_size_h):
for pw in range(pooled_size_w):
hstart = int(math.floor(ph * bin_h))
wstart = int(math.floor(pw * bin_w))
hend = int(math.ceil((ph + 1) * bin_h))
wend = int(math.ceil((pw + 1) * bin_w))
hstart = min(max(hstart + roi_start_h, 0), height)
hend = min(max(hend + roi_start_h, 0), height)
wstart = min(max(wstart + roi_start_w, 0), width)
wend = min(max(wend + roi_start_w, 0), width)
is_empty = (hend <= hstart) or (wend <= wstart)

for c in range(channel):
if is_empty:
b_np[i, c, ph, pw] = 0.
else:
b_np[i, c, ph, pw] = np.max(a_np[batch_index, c, hstart:hend, wstart:wend])
return b_np
1 change: 1 addition & 0 deletions topi/python/topi/vision/rcnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint: disable=wildcard-import
"""Faster R-CNN and Mask R-CNN operators"""
from .roi_align import *
from .roi_pool import *
from .proposal import *
77 changes: 77 additions & 0 deletions topi/python/topi/vision/rcnn/roi_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# pylint: disable=invalid-name
"""ROI pool operator"""
import tvm
from ...util import get_const_tuple

@tvm.target.generic_func
def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
"""ROI pool operator in NCHW layout.

Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, height, width]

rois : tvm.Tensor
2-D with shape [num_roi, 5]. The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end]

pooled_size : int or list/tuple of two ints
output size, or [out_height, out_width]

spatial_scale : float
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal
of total stride in convolutional layers, which should be in range (0.0, 1.0]

Returns
-------
output : tvm.Tensor
4-D with shape [num_roi, channel, pooled_size, pooled_size]
"""
dtype = rois.dtype
_, channel, height, width = get_const_tuple(data.shape)
num_roi, _ = get_const_tuple(rois.shape)

if isinstance(pooled_size, int):
pooled_size_h = pooled_size_w = pooled_size
else:
pooled_size_h, pooled_size_w = pooled_size

def _pool(i, c, ph, pw):
roi = rois[i]
batch_index = roi[0].astype('int32')
roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1], roi[2], roi[3], roi[4]

roi_start_h = tvm.round(roi_start_h * spatial_scale).astype('int32')
roi_start_w = tvm.round(roi_start_w * spatial_scale).astype('int32')
roi_end_h = tvm.round(roi_end_h * spatial_scale).astype('int32')
roi_end_w = tvm.round(roi_end_w * spatial_scale).astype('int32')

# force malformed ROIs to be 1x1
roi_h = tvm.max(roi_end_h - roi_start_h + 1, tvm.const(1, 'int32'))
roi_w = tvm.max(roi_end_w - roi_start_w + 1, tvm.const(1, 'int32'))

bin_h = roi_h.astype(dtype) / pooled_size_h
bin_w = roi_w.astype(dtype) / pooled_size_w

# use epsilon to prevent floating point precision loss in floor/ceil
epsilon = tvm.const(0.00001, dtype)
hstart = tvm.floor(ph * bin_h + epsilon).astype('int32')
wstart = tvm.floor(pw * bin_w + epsilon).astype('int32')
hend = tvm.ceil((ph + 1) * bin_h - epsilon).astype('int32')
wend = tvm.ceil((pw + 1) * bin_w - epsilon).astype('int32')
hstart = tvm.min(tvm.max(hstart + roi_start_h, 0), height)
wstart = tvm.min(tvm.max(wstart + roi_start_w, 0), width)
hend = tvm.min(tvm.max(hend + roi_start_h, 0), height)
wend = tvm.min(tvm.max(wend + roi_start_w, 0), width)

non_empty = tvm.all(hstart < hend, wstart < wend)
min_value = lambda dtype: tvm.if_then_else(non_empty, tvm.min_value(dtype),
tvm.const(0.0, dtype))
# pylint: disable=unnecessary-lambda
_max = tvm.comm_reducer(lambda x, y: tvm.make._OpMax(x, y), min_value, name='max')
rh = tvm.reduce_axis((0, hend - hstart), 'rh')
rw = tvm.reduce_axis((0, wend - wstart), 'rw')
return _max(data[batch_index, c, hstart+rh, wstart+rw], axis=[rh, rw])

return tvm.compute((num_roi, channel, pooled_size_h, pooled_size_w), _pool, tag="pool,roi_pool")
Loading