diff --git a/python/tvm/relay/op/vision/_rcnn.py b/python/tvm/relay/op/vision/_rcnn.py
index 6f5097df49d2..24d7de517494 100644
--- a/python/tvm/relay/op/vision/_rcnn.py
+++ b/python/tvm/relay/op/vision/_rcnn.py
@@ -26,6 +26,44 @@
 reg.register_strategy("vision.roi_align", strategy.roi_align_strategy)
 reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE)
 
+@reg.register_convert_op_layout("vision.roi_align")
+def convert_roi_align(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for roi_align op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current roi_align
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of layout strings
+        List of layouts defining our desired
+        layout for the data and rois inputs respectively.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import relay
+    data, rois = inputs
+    new_attrs = dict(attrs)
+    assert len(desired_layouts) == 2, \
+        "A desired layout is expected for both of vision.roi_align's inputs"
+    desired_data_layout, desired_rois_layout = map(str, desired_layouts)
+    assert desired_data_layout != "default", "Data layout cannot be default"
+    assert desired_rois_layout == "default", "Rois layout must be default"
+
+    new_attrs['layout'] = desired_data_layout
+    # rois layout not change
+    if desired_data_layout in ['NCHW', 'NHWC']:
+        return relay.vision.roi_align(data, rois, **new_attrs)
+
+    raise ValueError("Layout %s is not yet supported."
+                     % desired_data_layout)
 # roi_pool
 @reg.register_compute("vision.roi_pool")
 def compute_roi_pool(attrs, inputs, _):
diff --git a/python/tvm/relay/op/vision/rcnn.py b/python/tvm/relay/op/vision/rcnn.py
index 1798ae946dc0..f4edf91c2ace 100644
--- a/python/tvm/relay/op/vision/rcnn.py
+++ b/python/tvm/relay/op/vision/rcnn.py
@@ -24,7 +24,7 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='NCHW'):
     Parameters
     ----------
     data : relay.Expr
-        4-D tensor with shape [batch, channel, height, width]
+        4-D tensor with shape [batch, channel, height, width] or [batch, height, width, channel]
 
     rois : relay.Expr
         2-D tensor with shape [num_roi, 5]. The last dimension should be in format of
@@ -43,7 +43,8 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='NCHW'):
     Returns
     -------
     output : relay.Expr
-        4-D tensor with shape [num_roi, channel, pooled_size, pooled_size]
+        4-D tensor with shape [num_roi, channel, pooled_size, pooled_size] or
+        [num_roi, pooled_size, pooled_size, channel]
     """
     return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout)
 
diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc
index f7e1ecb82dcb..382e38d2189f 100644
--- a/src/relay/op/vision/rcnn_op.cc
+++ b/src/relay/op/vision/rcnn_op.cc
@@ -25,6 +25,8 @@
 #include <tvm/relay/op.h>
 #include <tvm/tir/op.h>
 
+#include "../../transforms/infer_layout_util.h"
+
 namespace tvm {
 namespace relay {
 
@@ -43,14 +45,43 @@ bool ROIAlignRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(roi_align_attrs);
   CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D.";
   CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D.";
-  CHECK_EQ(roi_align_attrs->layout, "NCHW") << "ROI Align only supports NCHW layout";
   // assign output type
-  std::vector<IndexExpr> oshape(
-      {rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]});
+  std::vector<IndexExpr> oshape;
+  if (roi_align_attrs->layout == "NCHW") {
+    oshape = {rshape[0], dshape[1], roi_align_attrs->pooled_size[0],
+              roi_align_attrs->pooled_size[1]};
+  } else {
+    CHECK_EQ(roi_align_attrs->layout, "NHWC") << "Unexpected ROI Align layout";
+    oshape = {rshape[0], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1],
+              dshape[3]};
+  }
+
   reporter->Assign(types[2], TensorType(oshape, data->dtype));
   return true;
 }
 
+template <typename T>
+Array<Array<Layout> > ROIAlignInferCorrectLayout(const Attrs& attrs,
+                                                 const Array<Layout>& new_in_layouts,
+                                                 const Array<Layout>& old_in_layouts,
+                                                 const Array<tvm::relay::Type>& old_in_types) {
+  // NOTE: Discard "const" qualifier here.
+  T* params = const_cast<T*>(attrs.as<T>());
+
+  if (new_in_layouts.defined()) {
+    // Set the roi_align with the new layout.
+    CHECK_EQ(new_in_layouts.size(), 2);
+  }
+
+  // We always make other operators to fit the layouts of roi_align layers and the roi_align
+  // transpose followed by layout_transpose, So this inference ignores all inputs.
+
+  // Layout inference needs to define the layout for all inputs and output data layouts.
+  // For roi_align, the second inputs is 2-D tensor with shape [num_roi, 5].
+  // So, we set the layouts as "N5".
+  return Array<Array<Layout> >{{params->layout, Layout("N5")}, {params->layout}};
+}
+
 Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spatial_scale,
                   int sample_ratio, String layout) {
   auto attrs = make_object<ROIAlignAttrs>();
@@ -78,7 +109,9 @@ RELAY_REGISTER_OP("vision.roi_align")
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("rois", "Tensor", "The input rois")
     .set_support_level(5)
-    .add_type_rel("ROIAlign", ROIAlignRel);
+    .add_type_rel("ROIAlign", ROIAlignRel)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
+                                   ROIAlignInferCorrectLayout<ROIAlignAttrs>);
 
 TVM_REGISTER_NODE_TYPE(ROIPoolAttrs);
 
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index e71cfdcf4ecc..09836d4581a5 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -713,6 +713,36 @@ def expected():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_roi_align_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        rois = relay.var('rois', shape=(32, 5))
+        y = relay.vision.roi_align(x, rois,
+                                   pooled_size=(14, 14),
+                                   spatial_scale=0.0625,
+                                   sample_ratio=2,
+                                   layout='NCHW')
+        y = relay.Function([x, rois], y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        rois = relay.var('rois', shape=(32, 5))
+        x = relay.layout_transform(x, 'NCHW', 'NHWC')
+        y = relay.vision.roi_align(x, rois,
+                                   pooled_size=(14, 14),
+                                   spatial_scale=0.0625,
+                                   sample_ratio=2,
+                                   layout='NHWC')
+        y = relay.layout_transform(y, 'NHWC', 'NCHW')
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout({'vision.roi_align': ['NHWC', 'default']}))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 def test_default_keyword():
     """ Check that the default keyword selects correct TVM default layout. """
@@ -845,5 +875,6 @@ def expected():
     test_qnn_conv_add_convert_layout()
     test_conv_convert_kernel_layout()
     test_conv_transpose_convert_layout()
+    test_roi_align_convert_layout()
     test_default_keyword()
     test_different_ops_convert_layout()