Skip to content

Commit

Permalink
[Relay][Frontend] Add reverse op to relay (apache#2800)
Browse files Browse the repository at this point in the history
* start adding reverse

* reverse updated

* reverse uses topi::flip

* typo fixed

* comment addressed

* exp simplified
  • Loading branch information
Laurawly authored and wweic committed Mar 20, 2019
1 parent d07c747 commit 2d2c610
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ This level enables additional math and transform operators.
tvm.relay.stack
tvm.relay.repeat
tvm.relay.tile
tvm.relay.reverse


**Level 4: Broadcast and Reductions**
Expand Down Expand Up @@ -229,6 +230,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.stack
.. autofunction:: tvm.relay.repeat
.. autofunction:: tvm.relay.tile
.. autofunction:: tvm.relay.reverse


Level 4 Definitions
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,15 @@ struct TileAttrs : public tvm::AttrsNode<TileAttrs> {
}
}; // struct TileAttrs

/*! \brief Attributes used in reverse operators */
struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
Integer axis;
TVM_DECLARE_ATTRS(ReverseAttrs, "relay.attrs.ReverseAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis along which to reverse elements.");
}
}; // struct ReverseAttrs

/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
// use axis to make the name numpy compatible.
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,13 @@ def _mx_tile(inputs, attrs):
return _op.tile(inputs[0], **new_attrs)


def _mx_reverse(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis")
return _op.reverse(inputs[0], **new_attrs)


def _mx_roi_align(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
Expand Down Expand Up @@ -612,6 +619,7 @@ def _mx_l2_normalize(inputs, attrs):
"_arange" : _mx_arange,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"reverse" : _mx_reverse,
"BlockGrad" : _mx_BlockGrad,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
_reg.register_schedule("full", schedule_injective)
_reg.register_schedule("full_like", schedule_injective)
_reg.register_schedule("arange", schedule_injective)
_reg.register_schedule("reverse", schedule_injective)
_reg.register_schedule("repeat", schedule_broadcast)
_reg.register_schedule("tile", schedule_broadcast)
_reg.register_schedule("cast", schedule_injective)
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,35 @@ def tile(data, reps):
return _make.tile(data, reps)


def reverse(data, axis):
"""Reverses the order of elements along given axis while preserving array shape.
By default, repeat flattens the input array into 1-D and then repeats the elements.
Parameters
----------
data : relay.Expr
The input data to the operator.
axis: int
The axis along which to reverse elements.
Returns
-------
ret : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
x = [[1., 2.], [3., 4.]]
relay.reverse(x, axis=0) = [[3., 4.], [1., 2.]]
relay.reverse(x, axis=1) = [[2., 1.], [4., 3.]]
"""
return _make.reverse(data, axis)


def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the
condition.
Expand Down
67 changes: 65 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1086,8 +1086,8 @@ Array<Tensor> RepeatCompute(const Attrs& attrs,
}

Expr MakeRepeat(Expr data,
int repeats,
int axis) {
int repeats,
int axis) {
auto attrs = make_node<RepeatAttrs>();
attrs->repeats = repeats;
attrs->axis = axis;
Expand Down Expand Up @@ -1204,6 +1204,69 @@ RELAY_REGISTER_OP("tile")
.set_attr<FTVMCompute>("FTVMCompute", TileCompute)
.set_attr<TOpPattern>("TOpPattern", kBroadcast);

// reverse operator
TVM_REGISTER_NODE_TYPE(ReverseAttrs);

bool ReverseRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "reverse: expect input type to be TensorType but get "
<< types[0];
return false;
}
const auto* param = attrs.as<ReverseAttrs>();
const int ndim = static_cast<int>(data->shape.size());
const int axis = param->axis;
CHECK(-ndim <= axis && axis < ndim)
<< "reverse only accepts `axis` in [-data.ndim, data.ndim - 1]"
<< ", but got axis = " << axis
<< ", and data.ndim = " << ndim;
reporter->Assign(types[1], types[0]);
return true;
}

Array<Tensor> ReverseCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const ReverseAttrs *param = attrs.as<ReverseAttrs>();
CHECK(param != nullptr);
return { topi::flip(inputs[0], param->axis) };
}

Expr MakeReverse(Expr data,
int axis) {
auto attrs = make_node<ReverseAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("reverse");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.reverse")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReverse, args, rv);
});

RELAY_REGISTER_OP("reverse")
.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
- **data**: The input data to the operator.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.Reverse")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reverse", ReverseRel)
.set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// where operator
bool WhereRel(const Array<Type>& types,
int num_inputs,
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,25 @@ def verify_arange(start, stop, step):
verify_arange(20, 1, -1.5)


def test_reverse():
def verify_reverse(dshape, axis):
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.reverse(x, axis=axis)
zz = relay.ir_pass.infer_type(z)

func = relay.Function([x], z)
x_data = np.random.uniform(low=-1, high=1, size=dshape).astype("float32")
ref_res = np.flip(x_data, axis)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_reverse((2, 3, 4), 1)
verify_reverse((4, 7), 0)
verify_reverse((2, 3, 4), -1)


if __name__ == "__main__":
test_cast()
test_zeros_ones()
Expand All @@ -515,3 +534,4 @@ def verify_arange(start, stop, step):
test_squeeze_bad_axes_infer_type()
test_split_infer_type()
test_arange()
test_reverse()

0 comments on commit 2d2c610

Please sign in to comment.