diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index f20c443e8404..e16da29fdf8c 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -99,6 +99,7 @@ This level enables additional math and transform operators.
    tvm.relay.stack
    tvm.relay.repeat
    tvm.relay.tile
+   tvm.relay.reverse
 
 
 **Level 4: Broadcast and Reductions**
@@ -229,6 +230,7 @@ Level 3 Definitions
 .. autofunction:: tvm.relay.stack
 .. autofunction:: tvm.relay.repeat
 .. autofunction:: tvm.relay.tile
+.. autofunction:: tvm.relay.reverse
 
 
 Level 4 Definitions
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 5382017d8c1c..326c9f06eb79 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -146,6 +146,15 @@ struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
   }
 };  // struct TileAttrs
 
+/*! \brief Attributes used in reverse operators */
+struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
+  Integer axis;
+  TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
+        .describe("The axis along which to reverse elements.");
+  }
+};  // struct ReverseAttrs
+
 /*! \brief Attributes used in squeeze operators */
 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   // use axis to make the name numpy compatible.
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 93bd8efc6752..42f3e8a7c148 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -422,6 +422,13 @@ def _mx_tile(inputs, attrs):
     return _op.tile(inputs[0], **new_attrs)
 
 
+def _mx_reverse(inputs, attrs):
+    assert len(inputs) == 1
+    new_attrs = {}
+    new_attrs["axis"] = attrs.get_int("axis")
+    return _op.reverse(inputs[0], **new_attrs)
+
+
 def _mx_roi_align(inputs, attrs):
     new_attrs = {}
     new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
@@ -612,6 +619,7 @@ def _mx_l2_normalize(inputs, attrs):
     "_arange"       : _mx_arange,
     "repeat"        : _mx_repeat,
     "tile"          : _mx_tile,
+    "reverse"       : _mx_reverse,
     "BlockGrad"     : _mx_BlockGrad,
     "SoftmaxOutput" : _mx_softmax_output,
     "SoftmaxActivation" : _mx_softmax_activation,
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 2b43c21f8e10..72fbca967555 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -19,6 +19,7 @@
 _reg.register_schedule("full", schedule_injective)
 _reg.register_schedule("full_like", schedule_injective)
 _reg.register_schedule("arange", schedule_injective)
+_reg.register_schedule("reverse", schedule_injective)
 _reg.register_schedule("repeat", schedule_broadcast)
 _reg.register_schedule("tile", schedule_broadcast)
 _reg.register_schedule("cast", schedule_injective)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index b77269843c91..37aace5afe4a 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -385,6 +385,35 @@ def tile(data, reps):
     return _make.tile(data, reps)
 
 
+def reverse(data, axis):
+    """Reverses the order of elements along the given axis while preserving array shape.
+    A negative `axis` counts from the last dimension.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    axis: int
+        The axis along which to reverse elements.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[1., 2.], [3., 4.]]
+        relay.reverse(x, axis=0) = [[3., 4.], [1., 2.]]
+
+        relay.reverse(x, axis=1) = [[2., 1.], [4., 3.]]
+    """
+    return _make.reverse(data, axis)
+
+
 def where(condition, x, y):
     """Selecting elements from either x or y depending on the value of the
     condition.
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 7aa98e3fd87a..36b93ee5d39f 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1086,8 +1086,8 @@ Array<Tensor> RepeatCompute(const Attrs& attrs,
 }
 
 Expr MakeRepeat(Expr data,
-               int repeats,
-               int axis) {
+                int repeats,
+                int axis) {
   auto attrs = make_node<RepeatAttrs>();
   attrs->repeats = repeats;
   attrs->axis = axis;
@@ -1204,6 +1204,69 @@ RELAY_REGISTER_OP("tile")
 .set_attr<FTVMCompute>("FTVMCompute", TileCompute)
 .set_attr<TOpPattern>("TOpPattern", kBroadcast);
 
+// reverse operator
+TVM_REGISTER_NODE_TYPE(ReverseAttrs);
+
+bool ReverseRel(const Array<Type>& types,
+                int num_inputs,
+                const Attrs& attrs,
+                const TypeReporter& reporter) {
+  // `types` contains: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "reverse: expect input type to be TensorType but get "
+        << types[0];
+    return false;
+  }
+  const auto* param = attrs.as<ReverseAttrs>();
+  const int ndim = static_cast<int>(data->shape.size());
+  const int axis = param->axis;
+  CHECK(-ndim <= axis && axis < ndim)
+    << "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]"
+    << ", but got axis = " << axis
+    << ", and data.ndim = " << ndim;
+  reporter->Assign(types[1], types[0]);
+  return true;
+}
+
+Array<Tensor> ReverseCompute(const Attrs& attrs,
+                             const Array<Tensor>& inputs,
+                             const Type& out_type,
+                             const Target& target) {
+  const ReverseAttrs *param = attrs.as<ReverseAttrs>();
+  CHECK(param != nullptr);
+  return { topi::flip(inputs[0], param->axis) };
+}
+
+Expr MakeReverse(Expr data,
+                 int axis) {
+  auto attrs = make_node<ReverseAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("reverse");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.reverse")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeReverse, args, rv);
+});
+
+RELAY_REGISTER_OP("reverse")
+.describe(R"code(Reverses the order of elements along the given `axis` while preserving array shape.
+
+- **data**: The input data to the operator.
+ +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.Reverse") +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(3) +.add_type_rel("Reverse", ReverseRel) +.set_attr("FTVMCompute", ReverseCompute) +.set_attr("TOpPattern", kInjective); + // where operator bool WhereRel(const Array& types, int num_inputs, diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index e762c7d3a1a0..eee0bcfab008 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -491,6 +491,25 @@ def verify_arange(start, stop, step): verify_arange(20, 1, -1.5) +def test_reverse(): + def verify_reverse(dshape, axis): + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.reverse(x, axis=axis) + zz = relay.ir_pass.infer_type(z) + + func = relay.Function([x], z) + x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32") + ref_res = np.flip(x_data, axis) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_reverse((2, 3, 4), 1) + verify_reverse((4, 7), 0) + verify_reverse((2, 3, 4), -1) + + if __name__ == "__main__": test_cast() test_zeros_ones() @@ -515,3 +534,4 @@ def verify_arange(start, stop, step): test_squeeze_bad_axes_infer_type() test_split_infer_type() test_arange() + test_reverse()