forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
195 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,6 @@ | |
|
||
from .multibox import * | ||
from .nms import * | ||
from .rcnn import * | ||
from . import _multibox | ||
from . import _rcnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# pylint: disable=invalid-name, unused-argument | ||
"""Faster R-CNN and Mask R-CNN operations.""" | ||
import topi | ||
from topi.util import get_const_tuple | ||
from .. import op as reg | ||
from ..op import OpPattern | ||
|
||
|
||
@reg.register_compute("vision.roi_align") | ||
def compute_roi_align(attrs, inputs, _, target): | ||
"""Compute definition of roi_align""" | ||
assert attrs.layout == "NCHW" | ||
return [topi.vision.rcnn.roi_align_nchw( | ||
inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size), | ||
spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio)] | ||
|
||
@reg.register_schedule("vision.roi_align") | ||
def schedule_roi_align(_, outs, target): | ||
"""Schedule definition of roi_align""" | ||
with target: | ||
return topi.generic.vision.schedule_roi_align(outs) | ||
|
||
reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
"""Faster R-CNN and Mask R-CNN operations.""" | ||
from . import _make | ||
|
||
|
||
def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='NCHW'): | ||
"""ROI align operator. | ||
Parameters | ||
---------- | ||
data : relay.Expr | ||
4-D tensor with shape [batch, channel, height, width] | ||
rois : relay.Expr | ||
2-D tensor with shape [num_roi, 5]. The last dimension should be in format of | ||
[batch_index, w_start, h_start, w_end, h_end] | ||
pooled_size : list/tuple of two ints | ||
output size | ||
spatial_scale : float | ||
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal | ||
of total stride in convolutional layers, which should be in range (0.0, 1.0] | ||
sample_ratio : int | ||
Optional sampling ratio of ROI align, using adaptive size by default. | ||
Returns | ||
------- | ||
output : relay.Expr | ||
4-D tensor with shape [num_roi, channel, pooled_size, pooled_size] | ||
""" | ||
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/*! | ||
* Copyright (c) 2019 by Contributors | ||
* \file rcnn_op.cc | ||
* \brief Faster RCNN and Mask RCNN operators | ||
*/ | ||
#include <tvm/relay/op.h> | ||
#include <tvm/relay/op_attr_types.h> | ||
#include <tvm/relay/attrs/vision.h> | ||
|
||
namespace tvm { | ||
namespace relay { | ||
|
||
TVM_REGISTER_NODE_TYPE(ROIAlignAttrs); | ||
|
||
bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, | ||
const TypeReporter& reporter) { | ||
auto roi_align_attrs = attrs.as<ROIAlignAttrs>(); | ||
CHECK_EQ(types.size(), 3); | ||
const auto* data = types[0].as<TensorTypeNode>(); | ||
const auto* rois = types[1].as<TensorTypeNode>(); | ||
const auto& dshape = data->shape; | ||
const auto& rshape = rois->shape; | ||
CHECK(roi_align_attrs); | ||
CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D."; | ||
CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D."; | ||
CHECK_EQ(roi_align_attrs->layout, "NCHW") << "ROI Align only supports NCHW layout"; | ||
// assign output type | ||
std::vector<IndexExpr> oshape( | ||
{rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]}); | ||
reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); | ||
return true; | ||
} | ||
|
||
Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale, | ||
int sample_ratio, std::string layout) { | ||
auto attrs = make_node<ROIAlignAttrs>(); | ||
attrs->pooled_size = pooled_size; | ||
attrs->spatial_scale = spatial_scale; | ||
attrs->sample_ratio = sample_ratio; | ||
attrs->layout = layout; | ||
static const Op& op = Op::Get("vision.roi_align"); | ||
return CallNode::make(op, {data, rois}, Attrs(attrs), {}); | ||
} | ||
|
||
TVM_REGISTER_API("relay.op.vision._make.roi_align") | ||
.set_body([](const TVMArgs& args, TVMRetValue* rv) { | ||
runtime::detail::unpack_call<Expr, 6>(MakeROIAlign, args, rv); | ||
}); | ||
|
||
RELAY_REGISTER_OP("vision.roi_align") | ||
.describe(R"doc(ROI Align operator. | ||
- **data**: This depends on the `layout` parameter. Input is 4D array of shape | ||
(batch_size, channels, height, width) if `layout` is `NCHW`. | ||
- **rois**: 2D array of shape (num_roi, 5). The last dimension should be in format of | ||
[batch_index, w_start, h_start, w_end, h_end]. | ||
- **out**: This depends on the `layout` parameter. Output is 4D array of shape | ||
(num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`. | ||
)doc" TVM_ADD_FILELINE) | ||
.set_num_inputs(2) | ||
.add_argument("data", "Tensor", "The input tensor.") | ||
.add_argument("rois", "Tensor", "The input rois") | ||
.set_support_level(5) | ||
.add_type_rel("ROIAlign", ROIAlignRel); | ||
|
||
} // namespace relay | ||
} // namespace tvm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters