Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][OP] ROI Align #2618

Merged
merged 1 commit into from
Feb 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,30 @@ struct NMSAttrs : public tvm::AttrsNode<NMSAttrs>{
}
};

/*! \brief Attributes used in roi_align operators */
struct ROIAlignAttrs : public tvm::AttrsNode<ROIAlignAttrs> {
Array<IndexExpr> pooled_size;
double spatial_scale;
int sample_ratio;
std::string layout;
TVM_DECLARE_ATTRS(ROIAlignAttrs, "relay.attrs.ROIAlignAttrs") {
TVM_ATTR_FIELD(pooled_size).describe("Output size of roi align.");
TVM_ATTR_FIELD(spatial_scale)
.describe(
"Ratio of input feature map height (or w) to raw image height (or w). "
"Equals the reciprocal of total stride in convolutional layers, which should be "
"in range (0.0, 1.0]");
TVM_ATTR_FIELD(sample_ratio)
.set_default(-1)
.describe("Optional sampling ratio of ROI align, using adaptive size by default.");
TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
"Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VISION_H_
10 changes: 10 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,15 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1)


def _mx_roi_align(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
new_attrs["spatial_scale"] = attrs.get_float("spatial_scale")
new_attrs["sample_ratio"] = attrs.get_int("sample_ratio", -1)
new_attrs["layout"] = "NCHW"
return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs)


# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
Expand Down Expand Up @@ -357,6 +366,7 @@ def _mx_multibox_detection(inputs, attrs):
# vision
"_contrib_MultiBoxPrior" : _mx_multibox_prior,
"_contrib_MultiBoxDetection" : _mx_multibox_detection,
"_contrib_ROIAlign" : _mx_roi_align,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@

from .multibox import *
from .nms import *
from .rcnn import *
from . import _multibox
from . import _rcnn
23 changes: 23 additions & 0 deletions python/tvm/relay/op/vision/_rcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# pylint: disable=invalid-name, unused-argument
"""Faster R-CNN and Mask R-CNN operations."""
import topi
from topi.util import get_const_tuple
from .. import op as reg
from ..op import OpPattern


@reg.register_compute("vision.roi_align")
def compute_roi_align(attrs, inputs, _, target):
"""Compute definition of roi_align"""
assert attrs.layout == "NCHW"
return [topi.vision.rcnn.roi_align_nchw(
inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size),
spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio)]

@reg.register_schedule("vision.roi_align")
def schedule_roi_align(_, outs, target):
"""Schedule definition of roi_align"""
with target:
return topi.generic.vision.schedule_roi_align(outs)

reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
32 changes: 32 additions & 0 deletions python/tvm/relay/op/vision/rcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Faster R-CNN and Mask R-CNN operations."""
from . import _make


def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='NCHW'):
"""ROI align operator.

Parameters
----------
data : relay.Expr
4-D tensor with shape [batch, channel, height, width]

rois : relay.Expr
2-D tensor with shape [num_roi, 5]. The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end]

pooled_size : list/tuple of two ints
output size

spatial_scale : float
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal
of total stride in convolutional layers, which should be in range (0.0, 1.0]

sample_ratio : int
Optional sampling ratio of ROI align, using adaptive size by default.

Returns
-------
output : relay.Expr
4-D tensor with shape [num_roi, channel, pooled_size, pooled_size]
"""
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)
67 changes: 67 additions & 0 deletions src/relay/op/vision/rcnn_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*!
* Copyright (c) 2019 by Contributors
* \file rcnn_op.cc
* \brief Faster RCNN and Mask RCNN operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/vision.h>

namespace tvm {
namespace relay {

TVM_REGISTER_NODE_TYPE(ROIAlignAttrs);

bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
auto roi_align_attrs = attrs.as<ROIAlignAttrs>();
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* rois = types[1].as<TensorTypeNode>();
const auto& dshape = data->shape;
const auto& rshape = rois->shape;
CHECK(roi_align_attrs);
CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
CHECK_EQ(roi_align_attrs->layout, "NCHW") << "ROI Align only supports NCHW layout";
// assign output type
std::vector<IndexExpr> oshape(
{rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
return true;
}

Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
int sample_ratio, std::string layout) {
auto attrs = make_node<ROIAlignAttrs>();
attrs->pooled_size = pooled_size;
attrs->spatial_scale = spatial_scale;
attrs->sample_ratio = sample_ratio;
attrs->layout = layout;
static const Op& op = Op::Get("vision.roi_align");
return CallNode::make(op, {data, rois}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.vision._make.roi_align")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeROIAlign, args, rv);
});

RELAY_REGISTER_OP("vision.roi_align")
.describe(R"doc(ROI Align operator.

- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, channels, height, width) if `layout` is `NCHW`.
- **rois**: 2D array of shape (num_roi, 5). The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end].
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("rois", "Tensor", "The input rois")
.set_support_level(5)
masahi marked this conversation as resolved.
Show resolved Hide resolved
.add_type_rel("ROIAlign", ROIAlignRel);

} // namespace relay
} // namespace tvm
35 changes: 35 additions & 0 deletions tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,44 @@ def test_threshold():
test_threshold()


def test_roi_align():
def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ratio):
data = relay.var("data", relay.ty.TensorType(data_shape, "float32"))
rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32"))
z = relay.vision.roi_align(data, rois, pooled_size=(pooled_size, pooled_size),
spatial_scale=spatial_scale, sample_ratio=sample_ratio,
layout="NCHW")
zz = relay.ir_pass.infer_type(z)

batch, channel, in_size, _ = data_shape
num_roi = rois_shape[0]
assert zz.checked_type == relay.ty.TensorType(
(num_roi, channel, pooled_size, pooled_size), "float32")

func = relay.Function([data, rois], z)
func = relay.ir_pass.infer_type(func)
np_data = np.random.uniform(size=data_shape).astype("float32")
np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size
np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi)
ref_res = topi.testing.roi_align_nchw_python(np_data, np_rois, pooled_size=pooled_size,
spatial_scale=spatial_scale,
sample_ratio=sample_ratio)
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(np_data, np_rois)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4)
intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
op_res2 = intrp2.evaluate(func)(np_data, np_rois)
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-4)

verify_roi_align((1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0, sample_ratio=-1)
verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2)


if __name__ == "__main__":
test_resize_infer_type()
test_resize()
test_multibox_prior()
test_multibox_transform_loc()
test_nms()
test_roi_align()
4 changes: 2 additions & 2 deletions topi/python/topi/vision/rcnn/roi_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def _sample(i, c, ph, pw):
if sample_ratio > 0:
roi_bin_grid_h = roi_bin_grid_w = tvm.const(sample_ratio, 'int32')
else:
roi_bin_grid_h = tvm.ceil(roi_h / pooled_size).astype('int32')
roi_bin_grid_w = tvm.ceil(roi_w / pooled_size).astype('int32')
roi_bin_grid_h = tvm.ceil(roi_h / pooled_size_h).astype('int32')
roi_bin_grid_w = tvm.ceil(roi_w / pooled_size_w).astype('int32')

count = roi_bin_grid_h * roi_bin_grid_w
rh = tvm.reduce_axis((0, roi_bin_grid_h))
Expand Down