Skip to content

Commit

Permalink
[Relay, Topi] [Frontend][TFLite, MXNet] ReverseSequence operator (apa…
Browse files Browse the repository at this point in the history
…che#5495)

* TFLite reverse_sequence op

* TFLite add_n implementation

* reverse_sequence implementation

* reverse_sequence implementation

* reverse sequence

* TOPI,Relay,TFLite - Reverse Sequence

Signed-off-by: maheshambule <[email protected]>

* Reverse Sequence small fixes

Signed-off-by: maheshambule <[email protected]>

* lint fixes

Signed-off-by: maheshambule <[email protected]>

* TFLite reverse_sequence op

Signed-off-by: maheshambule

* MXNet SequenceReverse implementation

* clang format

* clang format

* review comment fixes
  • Loading branch information
maheshambule authored and zhiics committed Jul 2, 2020
1 parent f58d50a commit f29e15b
Show file tree
Hide file tree
Showing 16 changed files with 504 additions and 32 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ List of operators
topi.reinterpret
topi.transpose
topi.flip
topi.reverse_sequence
topi.strided_slice
topi.expand_dims
topi.reshape
Expand Down Expand Up @@ -152,6 +153,7 @@ topi
.. autofunction:: topi.reinterpret
.. autofunction:: topi.transpose
.. autofunction:: topi.flip
.. autofunction:: topi.reverse_sequence
.. autofunction:: topi.strided_slice
.. autofunction:: topi.expand_dims
.. autofunction:: topi.reshape
Expand Down
1 change: 1 addition & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ This level enables additional math and transform operators.
tvm.relay.repeat
tvm.relay.tile
tvm.relay.reverse
tvm.relay.reverse_sequence
tvm.relay.unravel_index
tvm.relay.sparse_to_dense

Expand Down
14 changes: 14 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,20 @@ struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
}
}; // struct ReverseAttrs

/*! \brief Attributes used in reverse_sequence operators */
struct ReverseSequenceAttrs : public tvm::AttrsNode<ReverseSequenceAttrs> {
Integer seq_axis;
Integer batch_axis;

TVM_DECLARE_ATTRS(ReverseSequenceAttrs, "relay.attrs.ReverseSequenceAttrs") {
TVM_ATTR_FIELD(seq_axis).set_default(1).describe(
"The seq axis along which to reverse elements.");
TVM_ATTR_FIELD(batch_axis)
.set_default(0)
.describe("The batch axis along which to slice the tensor.");
}
}; // struct ReverseSequenceAttrs

/*! \brief Attributes used in squeeze operators */
struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
// use axis to make the name numpy compatible.
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,21 @@ def _mx_reverse(inputs, attrs):
return _op.reverse(inputs[0], **new_attrs)


def _mx_sequence_reverse(inputs, attrs):
new_attrs = {}
use_seq_lengths = attrs.get_bool("use_sequence_length")
if not use_seq_lengths:
assert len(inputs) == 1
new_attrs["axis"] = attrs.get_int("axis")
return _op.reverse(inputs[0], **new_attrs)

assert len(inputs) == 2
new_attrs["seq_axis"] = attrs.get_int("axis")
# MXNet assumes batch_axis as 1.
new_attrs["batch_axis"] = 1
return _op.reverse_sequence(inputs[0], inputs[1], **new_attrs)


def _mx_roi_align(inputs, attrs):
new_attrs = {}
new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
Expand Down Expand Up @@ -2078,6 +2093,7 @@ def impl(inputs, input_types):
"take" : _mx_take,
"gather_nd" : _mx_gather_nd,
"reverse" : _mx_reverse,
"SequenceReverse" : _mx_sequence_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
"broadcast_axes": _mx_broadcast_axis,
Expand Down
34 changes: 29 additions & 5 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(self, model, subgraph, exp_tab):
'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor,
'ROUND': self.convert_round,
'RSQRT': self.convert_rsqrt,
'REVERSE_SEQUENCE': self.convert_reverse_sequence,
'SELECT': self.convert_select,
'SHAPE': self.convert_shape,
'SIN': self.convert_sin,
Expand Down Expand Up @@ -2002,6 +2003,33 @@ def convert_transpose(self, op):

return out

def convert_reverse_sequence(self, op):
"""Convert TFLite REVERSE_SEQUENCE"""
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ReverseSequenceOptions import ReverseSequenceOptions
except ImportError:
raise ImportError("The tflite package must be installed")

if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFLite does not support quantized REVERSE_SEQUENCE operator yet.')

input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

in_expr = self.get_tensor_expr(input_tensors[0])
length_expr = self.get_tensor_expr(input_tensors[1])

assert op.BuiltinOptionsType() == BuiltinOptions.ReverseSequenceOptions
op_options = op.BuiltinOptions()
options = ReverseSequenceOptions()
options.Init(op_options.Bytes, op_options.Pos)
batch_axis = options.BatchDim()
seq_axis = options.SeqDim()

return _op.reverse_sequence(in_expr, length_expr, seq_axis, batch_axis)

def convert_cast(self, op):
"""Convert TFLite CAST"""
try:
Expand Down Expand Up @@ -2700,14 +2728,10 @@ def has_expr(self, input_tensor_idx):
return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))

def get_tensor_expr(self, tensor):
""" Returns constant expr for constant else a tensor expr"""
""" Return the Relay expr for tensor. """
if self.has_expr(tensor.tensor_idx):
# In most cases, we can assume that TOCO fuses elemwise operators
# with constants - it means both will be tensors.
expr = self.get_expr(tensor.tensor_idx)
else:
# However, in some corner cases, the elemwise operator is not fused,
# we can receive as constant.
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_reg.register_injective_schedule("full_like")
_reg.register_injective_schedule("arange")
_reg.register_injective_schedule("reverse")
_reg.register_injective_schedule("reverse_sequence")
_reg.register_injective_schedule("cast")
_reg.register_injective_schedule("cast_like")
_reg.register_injective_schedule("reinterpret")
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/relay/op/op_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ class TileAttrs(Attrs):
class ReverseAttrs(Attrs):
"""Attributes used in reverse operators"""

@tvm._ffi.register_object("relay.attrs.ReverseSequenceAttrs")
class ReverseSequenceAttrs(Attrs):
"""Attributes used in reverse sequence operators"""


@tvm._ffi.register_object("relay.attrs.SqueezeAttrs")
class SqueezeAttrs(Attrs):
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,53 @@ def reverse(data, axis):
return _make.reverse(data, axis)


def reverse_sequence(data, seq_lengths, seq_axis=1, batch_axis=0):
"""Reverse the tensor for variable length slices.
Input is first sliced along batch axis and then elements are reversed along seq axis.
Parameters
----------
data : relay.Expr
The tensor to be reversed.
seq_lengths : relay.Expr
A 1D Tensor with length a.dims[batch_axis]
Must be one of the following types: int32, int64
if seq_lengths[i] > a.dims[seq_axis], it is rounded to a.dims[seq_axis]
if seq_lengths[i] < 1, it is rounded to 1
seq_axis : int, optional
The axis along which the elements will be reversed. Default is 1.
batch_axis : int, optional
The axis along which the tensor will be sliced. Default is 0.
Returns
-------
ret : relay.Expr
The computed result of same shape and type as of input.
Examples
--------
.. code-block:: python
x = [[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15]]
relay.reverse(x, [1, 2, 3, 4], 0, 1) = [[0, 5, 10, 15],
[4, 1, 6, 11],
[8, 9, 2, 7],
[12, 13, 14, 3]]
relay.reverse(x, [1, 2, 3, 4], 1, 0) = [[0, 1, 2, 3],
[5, 4, 6, 7],
[10, 9, 8, 11],
[15, 14, 13, 12]]
"""
return _make.reverse_sequence(data, seq_lengths, seq_axis, batch_axis)


def where(condition, x, y):
"""Selecting elements from either x or y depending on the value of the
condition.
Expand Down
93 changes: 92 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1397,7 +1397,8 @@ Array<te::Tensor> ReverseCompute(const Attrs& attrs, const Array<te::Tensor>& in
const Type& out_type) {
const ReverseAttrs* param = attrs.as<ReverseAttrs>();
CHECK(param != nullptr);
return {topi::flip(inputs[0], param->axis)};
// pass empty seq_length tensor to reverse_sequence
return {topi::reverse_sequence(inputs[0], te::Tensor(), param->axis)};
}

Expr MakeReverse(Expr data, int axis) {
Expand All @@ -1423,6 +1424,96 @@ RELAY_REGISTER_OP("reverse")
.set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// reverse sequence operator
TVM_REGISTER_NODE_TYPE(ReverseSequenceAttrs);

bool ReverseSequenceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, seq_lengths, result]
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();

if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "reverse_sequence: expect input type to be TensorType but get " << types[0];
return false;
}

const auto* seq_lengths = types[1].as<TensorTypeNode>();
if (seq_lengths == nullptr) {
CHECK(types[1].as<IncompleteTypeNode>())
<< "reverse_sequence: expect input type to be TensorType but get " << types[1];
return false;
}

const int seq_lengths_dim = static_cast<int>(seq_lengths->shape.size());
CHECK(seq_lengths_dim == 1) << "For reverse_sequnece, seq_lengths must be a 1D vector";
CHECK(seq_lengths->dtype.is_int())
<< "For reverse_sequnece, seq_lengths must be tensor of integer";

const auto* param = attrs.as<ReverseSequenceAttrs>();
const int ndim = static_cast<int>(data->shape.size());
int batch_axis = param->batch_axis;
CHECK(-ndim <= batch_axis && batch_axis < ndim)
<< "reverse_sequence only accepts `batch_axis` in [-data.ndim, data.ndim - 1]"
<< ", but got batch_axis = " << batch_axis << ", and data.ndim = " << ndim;

if (batch_axis < 0) {
batch_axis = static_cast<int>(data->shape.size()) + batch_axis;
}
CHECK(reporter->Assert(seq_lengths->shape[0] == data->shape[batch_axis]))
<< "For reverse_sequnece seq_lengths size should match with dimension of batch axis"
<< ", but got dimension of batch_axis = " << data->shape[batch_axis]
<< ", and seq_length size = " << seq_lengths->shape[0];

const int seq_axis = param->seq_axis;
CHECK(-ndim <= seq_axis && seq_axis < ndim)
<< "reverse_sequnece only accepts `seq_axis` in [-data.ndim, data.ndim - 1]"
<< ", but got seq_axis = " << seq_axis << ", and data.ndim = " << ndim;

reporter->Assign(types[2], types[0]);
return true;
}

Array<te::Tensor> ReverseSequenceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const ReverseSequenceAttrs* param = attrs.as<ReverseSequenceAttrs>();
CHECK(param != nullptr);
return {topi::reverse_sequence(inputs[0], inputs[1], param->seq_axis, param->batch_axis)};
}

Expr MakeReverseSequence(Expr data, Expr seq_lengths, int seq_axis, int batch_axis) {
auto attrs = make_object<ReverseSequenceAttrs>();
attrs->seq_axis = seq_axis;
attrs->batch_axis = batch_axis;
static const Op& op = Op::Get("reverse_sequence");
return Call(op, {data, seq_lengths}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.reverse_sequence").set_body_typed(MakeReverseSequence);

RELAY_REGISTER_OP("reverse_sequence")
.describe(R"code(Reverses the tensor for variable length slices.
Input is first sliced along batch axis and then elements are reversed along seq axis.
- **data**: The input data to the operator.
- **seq_lengths**: A 1D Tensor with length data.dims[batch_axis].
- **seq_axis**: The axis along which the elements will be reversed. Default is 1.
- **batch_axis**: The axis along which the tensor will be sliced. Default is 0.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.set_attrs_type<ReverseSequenceAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("seq_lengths", "Tensor", "A 1D Tensor with length data.dims[batch_axis]")
.set_support_level(3)
.add_type_rel("ReverseSequence", ReverseSequenceRel)
.set_attr<FTVMCompute>("FTVMCompute", ReverseSequenceCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

// where operator
bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
Expand Down
34 changes: 34 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,39 @@ def verify(x_shape, y_shape, axes):
verify((3, 4), (2, 3), (0))
verify((3, 4), (2, 3), (-1))

def test_forward_sequence_reverse():
def verify(shape, seq_lengths, use_seq_lengths, seq_axis):
data_np = np.random.uniform(size=shape).astype("float32")

ref_res_args = [mx.nd.array(data_np), None, use_seq_lengths, seq_axis]
mx_sym_args = [mx.sym.var("data"), None, use_seq_lengths, seq_axis]
from_mxnet_args = [{"data": shape}, {"data": "float32"}]
in_data= [data_np]

if use_seq_lengths and seq_lengths:
seq_lengths_np = np.array(seq_lengths).astype("int32")
ref_res_args[1] = mx.nd.array(seq_lengths_np)
mx_sym_args[1] = mx.sym.var("seq_lengths")
from_mxnet_args[0].update({"seq_lengths": seq_lengths_np.shape})
from_mxnet_args[1].update({"seq_lengths": "int32"})
in_data.append(seq_lengths_np)

ref_res = mx.nd.SequenceReverse(*ref_res_args)
mx_sym = mx.sym.SequenceReverse(*mx_sym_args)
mod, _ = relay.frontend.from_mxnet(mx_sym, *from_mxnet_args)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(*in_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())

verify((3, 4), [1, 2, 3, 1], True, 0)
verify((3, 4), None, False, 0)
verify((3, 5, 5, 6), [1, 2, 3, 1, 3], True, 0)
# MXNet accepts axis value as 0 only
# verify((3, 4, 5, 6), None, False, 2)

def test_forward_l2_normalize():
data = mx.sym.var('data')
mx_sym = mx.sym.L2Normalization(data, mode="channel")
Expand Down Expand Up @@ -1232,6 +1265,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
test_forward_scalar_ops()
test_forward_slice_like()
test_forward_slice_axis()
test_forward_sequence_reverse()
test_forward_l2_normalize()
test_forward_shape_array()
test_forward_squeeze()
Expand Down
Loading

0 comments on commit f29e15b

Please sign in to comment.