diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 65f2375341c1..53f2f3c73afa 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -46,6 +46,7 @@ List of operators
    topi.reinterpret
    topi.transpose
    topi.flip
+   topi.reverse_sequence
    topi.strided_slice
    topi.expand_dims
    topi.reshape
@@ -152,6 +153,7 @@ topi
 .. autofunction:: topi.reinterpret
 .. autofunction:: topi.transpose
 .. autofunction:: topi.flip
+.. autofunction:: topi.reverse_sequence
 .. autofunction:: topi.strided_slice
 .. autofunction:: topi.expand_dims
 .. autofunction:: topi.reshape
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index cef96ef65931..86e0c0de4f95 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -132,6 +132,7 @@ This level enables additional math and transform operators.
    tvm.relay.repeat
    tvm.relay.tile
    tvm.relay.reverse
+   tvm.relay.reverse_sequence
    tvm.relay.unravel_index
    tvm.relay.sparse_to_dense
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index cbc60340d924..750a8a43163c 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -194,6 +194,20 @@ struct ReverseAttrs : public tvm::AttrsNode<ReverseAttrs> {
   }
 };  // struct ReverseAttrs
 
+/*! \brief Attributes used in reverse_sequence operators */
+struct ReverseSequenceAttrs : public tvm::AttrsNode<ReverseSequenceAttrs> {
+  Integer seq_axis;
+  Integer batch_axis;
+
+  TVM_DECLARE_ATTRS(ReverseSequenceAttrs, "relay.attrs.ReverseSequenceAttrs") {
+    TVM_ATTR_FIELD(seq_axis).set_default(1).describe(
+        "The seq axis along which to reverse elements.");
+    TVM_ATTR_FIELD(batch_axis)
+        .set_default(0)
+        .describe("The batch axis along which to slice the tensor.");
+  }
+};  // struct ReverseSequenceAttrs
+
 /*! \brief Attributes used in squeeze operators */
 struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
   // use axis to make the name numpy compatible.
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 1d8842d69d12..f77c3b5f4caa 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -819,6 +819,21 @@ def _mx_reverse(inputs, attrs):
     return _op.reverse(inputs[0], **new_attrs)
 
 
+def _mx_sequence_reverse(inputs, attrs):
+    new_attrs = {}
+    use_seq_lengths = attrs.get_bool("use_sequence_length")
+    if not use_seq_lengths:
+        assert len(inputs) == 1
+        new_attrs["axis"] = attrs.get_int("axis")
+        return _op.reverse(inputs[0], **new_attrs)
+
+    assert len(inputs) == 2
+    new_attrs["seq_axis"] = attrs.get_int("axis")
+    # MXNet assumes batch_axis to be 1.
+ new_attrs["batch_axis"] = 1 + return _op.reverse_sequence(inputs[0], inputs[1], **new_attrs) + + def _mx_roi_align(inputs, attrs): new_attrs = {} new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size") @@ -2078,6 +2093,7 @@ def impl(inputs, input_types): "take" : _mx_take, "gather_nd" : _mx_gather_nd, "reverse" : _mx_reverse, + "SequenceReverse" : _mx_sequence_reverse, "squeeze" : _mx_squeeze, "broadcast_axis": _mx_broadcast_axis, "broadcast_axes": _mx_broadcast_axis, diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index b79d8e1289e5..2fc82d74a08d 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -130,6 +130,7 @@ def __init__(self, model, subgraph, exp_tab): 'RESIZE_NEAREST_NEIGHBOR': self.convert_resize_nearest_neighbor, 'ROUND': self.convert_round, 'RSQRT': self.convert_rsqrt, + 'REVERSE_SEQUENCE': self.convert_reverse_sequence, 'SELECT': self.convert_select, 'SHAPE': self.convert_shape, 'SIN': self.convert_sin, @@ -2002,6 +2003,33 @@ def convert_transpose(self, op): return out + def convert_reverse_sequence(self, op): + """Convert TFLite REVERSE_SEQUENCE""" + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.ReverseSequenceOptions import ReverseSequenceOptions + except ImportError: + raise ImportError("The tflite package must be installed") + + if self.is_quantized(op): + raise tvm.error.OpNotImplemented( + 'TFLite does not support quantized REVERSE_SEQUENCE operator yet.') + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + in_expr = self.get_tensor_expr(input_tensors[0]) + length_expr = self.get_tensor_expr(input_tensors[1]) + + assert op.BuiltinOptionsType() == BuiltinOptions.ReverseSequenceOptions + op_options = op.BuiltinOptions() + options = ReverseSequenceOptions() + options.Init(op_options.Bytes, op_options.Pos) + batch_axis = options.BatchDim() + seq_axis = options.SeqDim() + + return _op.reverse_sequence(in_expr, length_expr, seq_axis, batch_axis) + def convert_cast(self, op): """Convert TFLite CAST""" try: @@ -2700,14 +2728,10 @@ def has_expr(self, input_tensor_idx): return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx)) def get_tensor_expr(self, tensor): - """ Returns constant expr for constant else a tensor expr""" + """ Return the Relay expr for tensor. """ if self.has_expr(tensor.tensor_idx): - # In most cases, we can assume that TOCO fuses elemwise operators - # with constants - it means both will be tensors. expr = self.get_expr(tensor.tensor_idx) else: - # However, in some corner cases, the elemwise operator is not fused, - # we can receive as constant. 
             type_str = self.get_tensor_type_str(tensor.tensor.Type())
             expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index f134b8251afa..d104c1b1c2f8 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -41,6 +41,7 @@
 _reg.register_injective_schedule("full_like")
 _reg.register_injective_schedule("arange")
 _reg.register_injective_schedule("reverse")
+_reg.register_injective_schedule("reverse_sequence")
 _reg.register_injective_schedule("cast")
 _reg.register_injective_schedule("cast_like")
 _reg.register_injective_schedule("reinterpret")
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index 429c4f1b9940..6c3dfaf2cd0f 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -227,6 +227,10 @@ class TileAttrs(Attrs):
 class ReverseAttrs(Attrs):
     """Attributes used in reverse operators"""
 
+@tvm._ffi.register_object("relay.attrs.ReverseSequenceAttrs")
+class ReverseSequenceAttrs(Attrs):
+    """Attributes used in reverse_sequence operators"""
+
 
 @tvm._ffi.register_object("relay.attrs.SqueezeAttrs")
 class SqueezeAttrs(Attrs):
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 05958fc39196..a37226ea4f58 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -515,6 +515,53 @@ def reverse(data, axis):
     return _make.reverse(data, axis)
 
 
+def reverse_sequence(data, seq_lengths, seq_axis=1, batch_axis=0):
+    """Reverse the tensor for variable length slices.
+    Input is first sliced along the batch axis and then the elements are reversed
+    along the seq axis.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The tensor to be reversed.
+
+    seq_lengths : relay.Expr
+        A 1D Tensor with length data.dims[batch_axis].
+        Must be one of the following types: int32, int64.
+        If seq_lengths[i] > data.dims[seq_axis], it is rounded to data.dims[seq_axis].
+        If seq_lengths[i] < 1, it is rounded to 1.
+
+    seq_axis : int, optional
+        The axis along which the elements will be reversed. Default is 1.
+
+    batch_axis : int, optional
+        The axis along which the tensor will be sliced. Default is 0.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result, of the same shape and type as the input.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        x = [[0, 1, 2, 3],
+             [4, 5, 6, 7],
+             [8, 9, 10, 11],
+             [12, 13, 14, 15]]
+        relay.reverse_sequence(x, [1, 2, 3, 4], 0, 1) = [[0, 5, 10, 15],
+                                                         [4, 1, 6, 11],
+                                                         [8, 9, 2, 7],
+                                                         [12, 13, 14, 3]]
+
+        relay.reverse_sequence(x, [1, 2, 3, 4], 1, 0) = [[0, 1, 2, 3],
+                                                         [5, 4, 6, 7],
+                                                         [10, 9, 8, 11],
+                                                         [15, 14, 13, 12]]
+    """
+    return _make.reverse_sequence(data, seq_lengths, seq_axis, batch_axis)
+
+
 def where(condition, x, y):
     """Selecting elements from either x or y depending on the value of the
     condition.
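For reviewers checking the docstring examples above, here is a minimal NumPy sketch of the intended semantics. It is not part of the patch; the helper name `reverse_sequence_ref` is hypothetical:

```python
import numpy as np

def reverse_sequence_ref(data, seq_lengths, seq_axis=1, batch_axis=0):
    """NumPy reference sketch: slice along batch_axis, then reverse the
    first seq_lengths[i] elements of each slice along seq_axis.
    Lengths are clamped to [1, data.shape[seq_axis]]."""
    out = data.copy()
    dim = data.shape[seq_axis]
    for i in range(data.shape[batch_axis]):
        length = int(np.clip(seq_lengths[i], 1, dim))
        dst = [slice(None)] * data.ndim
        dst[batch_axis] = i
        dst[seq_axis] = slice(0, length)
        src = [slice(None)] * data.ndim
        src[batch_axis] = i
        src[seq_axis] = slice(length - 1, None, -1)  # reversed prefix
        out[tuple(dst)] = data[tuple(src)]
    return out

x = np.arange(16).reshape(4, 4)
# Reproduces the second docstring example: each row reversed up to its length.
print(reverse_sequence_ref(x, [1, 2, 3, 4], seq_axis=1, batch_axis=0))
```

The same routine reproduces the first docstring example with `seq_axis=0, batch_axis=1`.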
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 2a7e4e21e68b..ee5e291e3d53 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1397,7 +1397,8 @@ Array<te::Tensor> ReverseCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                  const Type& out_type) {
   const ReverseAttrs* param = attrs.as<ReverseAttrs>();
   CHECK(param != nullptr);
-  return {topi::flip(inputs[0], param->axis)};
+  // pass an empty seq_lengths tensor to reverse_sequence
+  return {topi::reverse_sequence(inputs[0], te::Tensor(), param->axis)};
 }
 
 Expr MakeReverse(Expr data, int axis) {
@@ -1423,6 +1424,96 @@ RELAY_REGISTER_OP("reverse")
     .set_attr<FTVMCompute>("FTVMCompute", ReverseCompute)
     .set_attr<TOpPattern>("TOpPattern", kInjective);
 
+// reverse sequence operator
+TVM_REGISTER_NODE_TYPE(ReverseSequenceAttrs);
+
+bool ReverseSequenceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  // `types` contains: [data, seq_lengths, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "reverse_sequence: expect input type to be TensorType but get " << types[0];
+    return false;
+  }
+
+  const auto* seq_lengths = types[1].as<TensorTypeNode>();
+  if (seq_lengths == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "reverse_sequence: expect input type to be TensorType but get " << types[1];
+    return false;
+  }
+
+  const int seq_lengths_dim = static_cast<int>(seq_lengths->shape.size());
+  CHECK(seq_lengths_dim == 1) << "For reverse_sequence, seq_lengths must be a 1D vector";
+  CHECK(seq_lengths->dtype.is_int())
+      << "For reverse_sequence, seq_lengths must be a tensor of integers";
+
+  const auto* param = attrs.as<ReverseSequenceAttrs>();
+  const int ndim = static_cast<int>(data->shape.size());
+  int batch_axis = param->batch_axis;
+  CHECK(-ndim <= batch_axis && batch_axis < ndim)
+      << "reverse_sequence only accepts `batch_axis` in [-data.ndim, data.ndim - 1]"
+      << ", but got batch_axis = " << batch_axis << ", and data.ndim = " << ndim;
+
+  if (batch_axis < 0) {
+    batch_axis = static_cast<int>(data->shape.size()) + batch_axis;
+  }
+  CHECK(reporter->Assert(seq_lengths->shape[0] == data->shape[batch_axis]))
+      << "For reverse_sequence, seq_lengths size should match the dimension of the batch axis"
+      << ", but got dimension of batch_axis = " << data->shape[batch_axis]
+      << ", and seq_lengths size = " << seq_lengths->shape[0];
+
+  const int seq_axis = param->seq_axis;
+  CHECK(-ndim <= seq_axis && seq_axis < ndim)
+      << "reverse_sequence only accepts `seq_axis` in [-data.ndim, data.ndim - 1]"
+      << ", but got seq_axis = " << seq_axis << ", and data.ndim = " << ndim;
+
+  reporter->Assign(types[2], types[0]);
+  return true;
+}
+
+Array<te::Tensor> ReverseSequenceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                         const Type& out_type) {
+  const ReverseSequenceAttrs* param = attrs.as<ReverseSequenceAttrs>();
+  CHECK(param != nullptr);
+  return {topi::reverse_sequence(inputs[0], inputs[1], param->seq_axis, param->batch_axis)};
+}
+
+Expr MakeReverseSequence(Expr data, Expr seq_lengths, int seq_axis, int batch_axis) {
+  auto attrs = make_object<ReverseSequenceAttrs>();
+  attrs->seq_axis = seq_axis;
+  attrs->batch_axis = batch_axis;
+  static const Op& op = Op::Get("reverse_sequence");
+  return Call(op, {data, seq_lengths}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op._make.reverse_sequence").set_body_typed(MakeReverseSequence);
+
+RELAY_REGISTER_OP("reverse_sequence")
+    .describe(R"code(Reverses the tensor for variable length slices.
+Input is first sliced along the batch axis and then the elements are reversed along the seq axis.
+
+- **data**: The input data to the operator.
+
+- **seq_lengths**: A 1D Tensor with length data.dims[batch_axis].
+
+- **seq_axis**: The axis along which the elements will be reversed. Default is 1.
+
+- **batch_axis**: The axis along which the tensor will be sliced. Default is 0.
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .set_attrs_type<ReverseSequenceAttrs>()
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("seq_lengths", "Tensor", "A 1D Tensor with length data.dims[batch_axis]")
+    .set_support_level(3)
+    .add_type_rel("ReverseSequence", ReverseSequenceRel)
+    .set_attr<FTVMCompute>("FTVMCompute", ReverseSequenceCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
 // where operator
 bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 00c077f0d2e0..8b3e04be0379 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -472,6 +472,39 @@ def verify(x_shape, y_shape, axes):
     verify((3, 4), (2, 3), (0))
     verify((3, 4), (2, 3), (-1))
 
+def test_forward_sequence_reverse():
+    def verify(shape, seq_lengths, use_seq_lengths, seq_axis):
+        data_np = np.random.uniform(size=shape).astype("float32")
+
+        ref_res_args = [mx.nd.array(data_np), None, use_seq_lengths, seq_axis]
+        mx_sym_args = [mx.sym.var("data"), None, use_seq_lengths, seq_axis]
+        from_mxnet_args = [{"data": shape}, {"data": "float32"}]
+        in_data = [data_np]
+
+        if use_seq_lengths and seq_lengths:
+            seq_lengths_np = np.array(seq_lengths).astype("int32")
+            ref_res_args[1] = mx.nd.array(seq_lengths_np)
+            mx_sym_args[1] = mx.sym.var("seq_lengths")
+            from_mxnet_args[0].update({"seq_lengths": seq_lengths_np.shape})
+            from_mxnet_args[1].update({"seq_lengths": "int32"})
+            in_data.append(seq_lengths_np)
+
+        ref_res = mx.nd.SequenceReverse(*ref_res_args)
+        mx_sym = mx.sym.SequenceReverse(*mx_sym_args)
+        mod, _ = relay.frontend.from_mxnet(mx_sym, *from_mxnet_args)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(*in_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
+
+    verify((3, 4), [1, 2, 3, 1], True, 0)
+    verify((3, 4), None, False, 0)
+    verify((3, 5, 5, 6), [1, 2, 3, 1, 3], True, 0)
+    # MXNet accepts axis 0 only
+    # verify((3, 4, 5, 6), None, False, 2)
+
 def test_forward_l2_normalize():
     data = mx.sym.var('data')
     mx_sym = mx.sym.L2Normalization(data, mode="channel")
@@ -1232,6 +1265,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
     test_forward_scalar_ops()
     test_forward_slice_like()
     test_forward_slice_axis()
+    test_forward_sequence_reverse()
     test_forward_l2_normalize()
     test_forward_shape_array()
     test_forward_squeeze()
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 4adc3e074832..166eb2740edb 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -2082,6 +2082,32 @@ def test_forward_spacetodepth():
     _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]).astype("float32"), 2)
     _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]).astype("float32"), 4)
 
+
+#######################################################################
+# ReverseSequence
+# ---------------
+
+def _test_reverse_sequence(shape, dtype, seq_lengths, batch_axis, seq_axis):
+    """ One iteration of reverse_sequence operation with given data and attributes """
+
+    data = np.random.uniform(0, 100, size=shape).astype(dtype)
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(dtype=dtype, name="input", shape=shape)
+        out = tf.reverse_sequence(in_data, seq_lengths=seq_lengths, batch_axis=batch_axis,
+                                  seq_axis=seq_axis)
+
+        compare_tflite_with_tvm(data, 'input', [in_data], [out])
+
+
+def test_forward_reverse_sequence():
+    if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
+        _test_reverse_sequence([4, 3], "float32", [3, 2, 1], 1, 0)
+        _test_reverse_sequence([4, 3], "float32", [3, 2, 1, 3], 0, 1)
+        _test_reverse_sequence([2, 3, 3, 3], "float32", [2, 3, 2], 2, 1)
+        _test_reverse_sequence([2, 4, 6, 4, 5], "float32", [5, 3], 0, 2)
+        _test_reverse_sequence([2, 4, 6, 4, 5], "float32", [5, 3, 1, 4], 3, 2)
+
+
 #######################################################################
 # Sparse To Dense
 # ---------------
@@ -2602,6 +2628,7 @@ def test_forward_mediapipe_hand_landmark():
     test_forward_stridedslice()
     test_forward_depthtospace()
     test_forward_spacetodepth()
+    test_forward_reverse_sequence()
     test_forward_sparse_to_dense()
     test_forward_select()
     test_forward_quantize_dequantize()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index f50a69278402..f3e28dbfeb58 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -663,6 +663,73 @@ def verify_reverse(dshape, axis):
     verify_reverse((2, 3, 4), -1)
 
 
+def test_reverse_sequence():
+    def verify_reverse_sequence(x_data, seq_lengths, batch_axis, seq_axis, ref_res):
+        seq_lengths_data = np.array(seq_lengths).astype("int32")
+        x = relay.var("x", relay.TensorType(x_data.shape, str(x_data.dtype)))
+        z = relay.reverse_sequence(x, relay.const(seq_lengths_data), seq_axis, batch_axis)
+        zz = run_infer_type(z)
+        assert zz.checked_type == x.type_annotation
+        func = relay.Function([x], z)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 5, 10, 15],
+              [4, 1, 6, 11],
+              [8, 9, 2, 7],
+              [12, 13, 14, 3]]
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 1, 0, np.array(result))
+    verify_reverse_sequence(indata, [1, 2, 3, 4], -1, 0, np.array(result))
+    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32"))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 1, 2, 3],
+              [5, 4, 6, 7],
+              [10, 9, 8, 11],
+              [15, 14, 13, 12]]
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, 1, np.array(result))
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, -1, np.array(result))
+    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32"))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 1, 2, 3],
+              [4, 5, 6, 7],
+              [8, 9, 10, 11],
+              [15, 14, 13, 12]]
+    verify_reverse_sequence(indata, [-1, 0, 1, 5], 0, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
+    result = [[[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
+               [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
+               [[0, 1, 2], [3, 4, 5], [6, 7, 8]]],
+              [[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
+               [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
+               [[27, 28, 29], [30, 31, 32], [33, 34, 35]]]]
+    verify_reverse_sequence(indata, [3, 3], 0, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
+    result = [[[[9, 10, 11], [21, 22, 23], [15, 16, 17]],
+               [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
+               [[18, 19, 20], [3, 4, 5], [24, 25, 26]]],
+              [[[36, 37, 38], [48, 49, 50], [42, 43, 44]],
+               [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
+               [[45, 46, 47], [30, 31, 32], [51, 52, 53]]]]
+    verify_reverse_sequence(indata, [2, 3, 2], 2, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = []
+    with pytest.raises(Exception) as execinfo:
+        verify_reverse_sequence(indata, [2, 3, 2, 4, 5], 1, 0, np.array(result))
+
+    assert "For reverse_sequence, seq_lengths size should match the dimension of the batch axis," \
+           " but got dimension of batch_axis = 4, and seq_lengths size = 5" in execinfo.value.args[0]
+
+
 def test_scatter():
     def ref_scatter(data, indices, updates, axis=0):
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index 794796702d00..e0e455667889 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -150,44 +150,72 @@ inline Tensor transpose(const Tensor& x, Array<Integer> axes, std::string name =
 }
 
 /*!
- * \brief flip/reverse elements of an array in a particular axis
+ * \brief Reverse the tensor for variable length slices.
+ * Input is first sliced along the batch axis and then the elements are reversed
+ * along the seq axis.
 *
 * \param x The input tensor
- * \param axis The axis along which the tensors will be reveresed
- * (allows negative indices)
+ * \param seq_lengths A 1D Tensor with length x.dims[batch_axis]. An undefined Tensor()
+ * may be passed; if seq_lengths is not defined, batch_axis is ignored and the tensor is
+ * reversed along seq_axis.
+ * \param seq_axis The axis along which the elements will be reversed
+ * \param batch_axis The axis along which the tensor will be sliced
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
- * \return A Tensor whose op member is the reverse operation
+ * \return A Tensor whose op member is the reverse_sequence operation
 */
-inline Tensor flip(const Tensor& x, int axis = 0, std::string name = "T_flip",
-                   std::string tag = kInjective) {
+inline Tensor reverse_sequence(const Tensor& x, const Tensor& seq_lengths, int seq_axis = 1,
+                               int batch_axis = 0, std::string name = "T_reverse_sequence",
+                               std::string tag = kInjective) {
   size_t src_tensor_dim = x->shape.size();
-  int axis_inp = axis;
+  int seq_axis_inp = seq_axis;
 
-  if (axis < 0) {
-    axis = static_cast<int>(x->shape.size()) + axis;
+  if (seq_lengths.defined()) {
+    size_t seq_lengths_dim = seq_lengths->shape.size();
+    int batch_axis_inp = batch_axis;
+    if (batch_axis < 0) {
+      batch_axis = static_cast<int>(x->shape.size()) + batch_axis;
+    }
+
+    CHECK(seq_lengths_dim == 1) << "seq_lengths must be a 1D vector";
+
+    // Check the batch_axis bounds before indexing x->shape with it.
+    CHECK((0 <= batch_axis) && (batch_axis < static_cast<int>(x->shape.size())))
+        << "batch_axis=" << batch_axis_inp << " is invalid for the "
+        << static_cast<int>(x->shape.size()) << "-dimensional input tensor";
+
+    CHECK(GetConstInt(seq_lengths->shape[0]) == GetConstInt(x->shape[batch_axis]))
+        << "For reverse_sequence, seq_lengths size should match the dimension of the batch axis"
+        << ", but got dimension of batch_axis = " << GetConstInt(x->shape[batch_axis])
+        << ", and seq_lengths size = " << GetConstInt(seq_lengths->shape[0]);
  }
 
-  CHECK((0 <= axis) && (axis < static_cast<int>(x->shape.size())))
-    << "axis=" << axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
+  if (seq_axis < 0) {
+    seq_axis = static_cast<int>(x->shape.size()) + seq_axis;
+  }
+  CHECK((0 <= seq_axis) && (seq_axis < static_cast<int>(x->shape.size())))
+      << "seq_axis=" << seq_axis_inp << " is invalid for the " << static_cast<int>(x->shape.size())
      << "-dimensional input tensor";
 
-  // Reverse the Input Tensor in the axis specified
-  return compute(
-      x->shape,
-      [&](const Array<Var>& indices) {
-        Array<PrimExpr> real_indices;
-        for (size_t i = 0; i < src_tensor_dim; ++i) {
-          if (i == static_cast<size_t>(axis)) {
-            real_indices.push_back(x->shape[i] - indices[i] - 1);
-          } else {
-            real_indices.push_back(indices[i]);
-          }
+  // Flip the index along seq_axis; when seq_lengths is defined, only the first
+  // seq_lengths[batch] elements of each slice are reversed.
+  auto func = [&](const Array<Var>& indices) {
+    Array<PrimExpr> real_indices;
+    for (size_t i = 0; i < src_tensor_dim; ++i) {
+      if (i == static_cast<size_t>(seq_axis)) {
+        if (seq_lengths.defined()) {
+          auto len = seq_lengths(indices[batch_axis]);
+          auto idx = if_then_else(
+              len <= 1 || len <= indices[i], indices[i],
+              if_then_else(len > x->shape[i], x->shape[i] - 1 - indices[i], len - 1 - indices[i]));
+          real_indices.push_back(idx);
+        } else {
+          real_indices.push_back(x->shape[i] - 1 - indices[i]);
        }
-        return x(real_indices);
-      },
-      name, tag);
+      } else {
+        real_indices.push_back(indices[i]);
+      }
+    }
+    return x(real_indices);
+  };
+
+  return compute(x->shape, func, name, tag);
 }
 
 /*!
diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index f1bcccd9fde8..a8c8b1400208 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -131,6 +131,37 @@ def flip(a, axis=0):
     """
     return cpp.flip(a, axis)
 
+
+def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0):
+    """Reverse the tensor for variable length slices.
+    Input is first sliced along the batch axis and then the elements are reversed
+    along the seq axis.
+
+    Parameters
+    ----------
+    a : tvm.te.Tensor
+        The tensor to be reversed.
+
+    seq_lengths : tvm.te.Tensor
+        A 1D Tensor with length a.dims[batch_axis].
+        Must be one of the following types: int32, int64.
+        If seq_lengths[i] > a.dims[seq_axis], it is rounded to a.dims[seq_axis].
+        If seq_lengths[i] < 1, it is rounded to 1.
+
+    seq_axis : int, optional
+        The axis along which the elements will be reversed. Default is 1.
+
+    batch_axis : int, optional
+        The axis along which the tensor will be sliced. Default is 0.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The computed result, of the same shape and type as the input.
+    """
+    return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis)
+
+
 def strided_slice(a, begin, end, strides=None, slice_mode="end"):
     """Slice of an array.
diff --git a/topi/src/transform.cc b/topi/src/transform.cc
index 2791ff7dab1d..4308784f80c4 100644
--- a/topi/src/transform.cc
+++ b/topi/src/transform.cc
@@ -40,7 +40,12 @@ TVM_REGISTER_GLOBAL("topi.transpose").set_body([](TVMArgs args, TVMRetValue* rv)
 });
 
 TVM_REGISTER_GLOBAL("topi.flip").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = flip(args[0], args[1]);
+  // pass an empty seq_lengths tensor to reverse_sequence
+  *rv = reverse_sequence(args[0], Tensor(), args[1]);
+});
+
+TVM_REGISTER_GLOBAL("topi.reverse_sequence").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = reverse_sequence(args[0], args[1], args[2], args[3]);
 });
 
 TVM_REGISTER_GLOBAL("topi.reshape").set_body([](TVMArgs args, TVMRetValue* rv) {
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 96df101b092e..b0aee6a3d899 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -16,6 +16,7 @@
 # under the License.
 """Test code for broadcasting operators."""
 import numpy as np
+import pytest
 import tvm
 from tvm import te
 import topi
@@ -289,6 +290,85 @@ def check_device(device):
     for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
+
+def test_reverse_sequence():
+    def verify_reverse_sequence(in_data, seq_lengths, batch_axis, seq_axis, ref_res):
+        seq_lengths = np.array(seq_lengths).astype("int32")
+        A = te.placeholder(shape=in_data.shape, name="A", dtype=str(in_data.dtype))
+        B = te.placeholder(shape=seq_lengths.shape, name="B", dtype=str(seq_lengths.dtype))
+        C = topi.reverse_sequence(A, B, seq_axis, batch_axis)
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            print("Running on target: %s" % device)
+            with tvm.target.create(device):
+                s = topi.testing.get_injective_schedule(device)(C)
+
+            foo = tvm.build(s, [A, B, C], device, name="reverse_sequence")
+
+            data_nd = tvm.nd.array(in_data, ctx)
+            seq_lengths_nd = tvm.nd.array(seq_lengths, ctx)
+            out_nd = tvm.nd.empty(in_data.shape, ctx=ctx, dtype=A.dtype)
+            foo(data_nd, seq_lengths_nd, out_nd)
+            tvm.testing.assert_allclose(out_nd.asnumpy(), ref_res)
+
+        for device in get_all_backend():
+            check_device(device)
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 5, 10, 15],
+              [4, 1, 6, 11],
+              [8, 9, 2, 7],
+              [12, 13, 14, 3]]
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 1, 0, np.array(result))
+    verify_reverse_sequence(indata, [1, 2, 3, 4], -1, 0, np.array(result))
+    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32"))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 1, 2, 3],
+              [5, 4, 6, 7],
+              [10, 9, 8, 11],
+              [15, 14, 13, 12]]
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, 1, np.array(result))
+    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, -1, np.array(result))
+    verify_reverse_sequence(indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32"))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = [[0, 1, 2, 3],
+              [4, 5, 6, 7],
+              [8, 9, 10, 11],
+              [15, 14, 13, 12]]
+    verify_reverse_sequence(indata, [-1, 0, 1, 5], 0, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
+    result = [[[[18, 19, 20], [21, 22, 23], [24, 25, 26]],
+               [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
+               [[0, 1, 2], [3, 4, 5], [6, 7, 8]]],
+              [[[45, 46, 47], [48, 49, 50], [51, 52, 53]],
+               [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
+               [[27, 28, 29], [30, 31, 32], [33, 34, 35]]]]
+    verify_reverse_sequence(indata, [3, 3], 0, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
+    result = [[[[9, 10, 11], [21, 22, 23], [15, 16, 17]],
+               [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
+               [[18, 19, 20], [3, 4, 5], [24, 25, 26]]],
+              [[[36, 37, 38], [48, 49, 50], [42, 43, 44]],
+               [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
+               [[45, 46, 47], [30, 31, 32], [51, 52, 53]]]]
+    verify_reverse_sequence(indata, [2, 3, 2], 2, 1, np.array(result))
+
+    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
+    result = []
+    with pytest.raises(Exception) as execinfo:
+        verify_reverse_sequence(indata, [2, 3, 2, 4, 5], 1, 0, np.array(result))
+
+    assert "For reverse_sequence, seq_lengths size should match the dimension of the batch axis," \
+           " but got dimension of batch_axis = 4, and seq_lengths size = 5" in execinfo.value.args[0]
+
+
 def verify_take(src_shape, indices_src, axis=None, mode="clip"):
     src_dtype = "float32"
     indices_dtype = "int32"
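As a quick end-to-end sanity check of the new Relay operator, here is a usage sketch written in the same style as the tests above (it mirrors the executor API the patch's own tests use; treat it as illustrative rather than part of the patch):

```python
import numpy as np
import tvm
from tvm import relay

# Build a small module around relay.reverse_sequence.
x = relay.var("x", shape=(4, 4), dtype="int32")
lengths = relay.const(np.array([1, 2, 3, 4], dtype="int32"))
z = relay.reverse_sequence(x, lengths, seq_axis=1, batch_axis=0)
mod = tvm.IRModule.from_expr(relay.Function([x], z))

# Run it with the debug executor on CPU.
data = np.arange(16, dtype="int32").reshape(4, 4)
intrp = relay.create_executor("debug", mod=mod, ctx=tvm.cpu(0), target="llvm")
print(intrp.evaluate()(data).asnumpy())
# Each row is reversed up to its length:
# [[ 0  1  2  3]
#  [ 5  4  6  7]
#  [10  9  8 11]
#  [15 14 13 12]]
```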