diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 42883f5f77da..11fb282abac5 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -94,6 +94,7 @@ This level enables additional math and transform operators. tvm.relay.full tvm.relay.full_like tvm.relay.cast + tvm.relay.split **Level 4: Broadcast and Reductions** @@ -198,6 +199,7 @@ Level 3 Definitions .. autofunction:: tvm.relay.full .. autofunction:: tvm.relay.full_like .. autofunction:: tvm.relay.cast +.. autofunction:: tvm.relay.split Level 4 Definitions diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index b0150c4ac3d9..dfad1013701f 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -106,6 +106,22 @@ struct SqueezeAttrs : public tvm::AttrsNode { } }; // struct SqueezeAttrs +struct SplitAttrs : public tvm::AttrsNode { + NodeRef indices_or_sections; + int axis; + + TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { + TVM_ATTR_FIELD(indices_or_sections) + .describe("Indices or sections to split into. Accepts an int or a tuple" + "If indices_or_sections is an integer, the input will be divided equally" + "along given axis. If such a split is not possible, an error is raised." + "If indices_or_sections is a tuple of sorted integers," + "the entries indicate where along axis the array is split."); + TVM_ATTR_FIELD(axis).set_default(0) + .describe("the axis to be splitted."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index f5d73dfbcbde..22265587b497 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -427,7 +427,7 @@ along which to split the array. return Array{ topi::split(inputs[0], indices, param.axis) }; } }) -.set_support_level(1); +.set_support_level(3); // cast DMLC_REGISTER_PARAMETER(CastParam); diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 655379066c74..0650a493d9a6 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -5,6 +5,7 @@ import numpy as _np from .base import RelayNode, register_relay_node from . import _make +from . import _expr from . import ty as _ty from .._ffi import base as _base from .. import nd as _nd @@ -284,6 +285,16 @@ def astuple(self): as an argument to an FFI function.""" return self.tuple_value + def astext(self): + """Get the text format of the tuple expression. + + Returns + ------- + text : str + The text format of the tuple expression. + """ + return _expr._text_print(self.tuple_value) + def __getitem__(self, index): if index >= len(self): raise IndexError("Tuple index out of range") diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 84e2398f0a9e..3cf139c7dd86 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1,6 +1,7 @@ """Transform operators.""" from . import _make +from ..expr import TupleWrapper def expand_dims(data, axis, num_newaxis=1): @@ -146,7 +147,7 @@ def take(data, indices, axis=None): Parameters ---------- - a : relay.Expr + data : relay.Expr The source array. indices : rely.Expr @@ -280,3 +281,35 @@ def collapse_sum_like(data, collapse_type): The resulting tensor. """ return _make.collapse_sum_like(data, collapse_type) + + +def split(data, indices_or_sections, axis=0): + """Split input tensor along axis by sections or indices. + + If indices_or_sections is an integer, the input will be divided equally + along given axis. If such a split is not possible, an error is raised. + + If indices_or_sections is a tuple of sorted integers, + the entries indicate where along axis the array is split. + + Parameters + ---------- + data : relay.Expr + The source array. + + indices_or_sections : int or tuple of int + Indices or sections to split into. Accepts an int or a tuple + + axis : int, optional + The axis over which to split. + + Returns + ------- + ret : relay.Tuple([relay.Expr, relay.Expr]) + The computed result. + """ + if isinstance(indices_or_sections, int): + ret_size = indices_or_sections + else: + ret_size = len(indices_or_sections) + 1 + return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size) diff --git a/src/lang/attr_functor.h b/src/lang/attr_functor.h index ef1d061015c3..9257ad3b5490 100644 --- a/src/lang/attr_functor.h +++ b/src/lang/attr_functor.h @@ -64,6 +64,7 @@ class AttrFunctor { virtual R VisitAttr_(const ir::Add* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Sub* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Mul* op, Args... args) ATTR_FUNCTOR_DEFAULT; + virtual R VisitAttr_(const ir::Div* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Mod* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Min* op, Args... args) ATTR_FUNCTOR_DEFAULT; virtual R VisitAttr_(const ir::Max* op, Args... args) ATTR_FUNCTOR_DEFAULT; @@ -96,6 +97,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(Add); ATTR_FUNCTOR_DISPATCH(Sub); ATTR_FUNCTOR_DISPATCH(Mul); + ATTR_FUNCTOR_DISPATCH(Div); ATTR_FUNCTOR_DISPATCH(Min); ATTR_FUNCTOR_DISPATCH(Max); ATTR_FUNCTOR_DISPATCH(GE); @@ -135,6 +137,7 @@ class AttrsEqualHandler : bool VisitAttr_(const ir::Add* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Sub* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Mul* lhs, const NodeRef& other) final; + bool VisitAttr_(const ir::Div* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Mod* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Min* lhs, const NodeRef& other) final; bool VisitAttr_(const ir::Max* lhs, const NodeRef& other) final; @@ -174,6 +177,7 @@ class AttrsHashHandler : size_t VisitAttr_(const ir::Add* op) final; size_t VisitAttr_(const ir::Sub* op) final; size_t VisitAttr_(const ir::Mul* op) final; + size_t VisitAttr_(const ir::Div* op) final; size_t VisitAttr_(const ir::Mod* op) final; size_t VisitAttr_(const ir::Min* op) final; size_t VisitAttr_(const ir::Max* op) final; diff --git a/src/lang/attrs.cc b/src/lang/attrs.cc index 9aa067c09679..3b273f4939ef 100644 --- a/src/lang/attrs.cc +++ b/src/lang/attrs.cc @@ -132,6 +132,7 @@ bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) TVM_DEFINE_ATTRS_BINOP_EQUAL(Add); TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub); TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul); +TVM_DEFINE_ATTRS_BINOP_EQUAL(Div); TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod); TVM_DEFINE_ATTRS_BINOP_EQUAL(Max); TVM_DEFINE_ATTRS_BINOP_EQUAL(Min); @@ -243,6 +244,7 @@ size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) { TVM_DEFINE_ATTRS_BINOP_HASH(Add); TVM_DEFINE_ATTRS_BINOP_HASH(Sub); TVM_DEFINE_ATTRS_BINOP_HASH(Mul); +TVM_DEFINE_ATTRS_BINOP_HASH(Div); TVM_DEFINE_ATTRS_BINOP_HASH(Mod); TVM_DEFINE_ATTRS_BINOP_HASH(Max); TVM_DEFINE_ATTRS_BINOP_HASH(Min); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 29dff1e4ba27..d7b4980f80b2 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -6,12 +6,14 @@ #include #include #include +#include #include #include "../op_common.h" namespace tvm { namespace relay { +using ir::IntImm; // relay.cast TVM_REGISTER_NODE_TYPE(CastAttrs); @@ -834,5 +836,100 @@ RELAY_REGISTER_OP("broadcast_to_like") .set_support_level(10) .add_type_rel("BroadCastToLike", BroadCastToLikeRel); +// Split +TVM_REGISTER_NODE_TYPE(SplitAttrs); + +bool SplitRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + CHECK(data != nullptr); + CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; + const auto param = attrs.as(); + CHECK(param != nullptr); + auto axis = param->axis; + if (axis < 0) { + axis += data->shape.size(); + } + CHECK_LT(axis, data->shape.size()) + << "axis should be within the input dimension range."; + CHECK_GT(axis, 0) + << "axis should be within the input dimension range."; + + if (const IntImm* sections = param->indices_or_sections.as()) { + CHECK(reporter->Assert(data->shape[axis] % + sections->value == make_zero(Int(64)))) + << "indices_or_sections need to be able to divide input.shape[axis]"; + std::vector fields; + for (int i = 0; i < sections->value; ++i) { + std::vector&& oshape = AsVector(data->shape); + oshape[axis] /= int32_t(sections->value); + auto vec_type = TensorTypeNode::make(oshape, data->dtype); + fields.push_back(vec_type); + } + reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); + } else { + auto indices = param->indices_or_sections.as()->data; + auto begin = IndexExpr(make_zero(Int(32))); + std::vector fields; + for (uint i = 0; i < indices.size(); ++i) { + CHECK(reporter->Assert(IndexExpr(indices[i]) > begin)) + << "indices_or_sections need to be a sorted ascending list"; + std::vector&& oshape = AsVector(data->shape); + oshape[axis] = IndexExpr(indices[i]) - begin; + begin = IndexExpr(indices[i]); + auto vec_type = TensorTypeNode::make(oshape, data->dtype); + fields.push_back(vec_type); + } + CHECK(reporter->Assert(begin < data->shape[axis])) + << "The sum of sections must match the input.shape[axis]"; + std::vector&& oshape = AsVector(data->shape); + oshape[axis] = data->shape[axis] - begin; + auto vec_type = TensorTypeNode::make(oshape, data->dtype); + fields.push_back(vec_type); + reporter->Assign(types[1], TupleTypeNode::make(Array(fields))); + } + return true; +} + +Expr MakeSplit(Expr data, + NodeRef indices_or_sections, + int axis) { + auto attrs = make_node(); + attrs->axis = axis; + attrs->indices_or_sections = std::move(indices_or_sections); + static const Op& op = Op::Get("split"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.split") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + if (args.type_codes[1] == kDLInt) { + *rv = MakeSplit(args[0], make_const(Int(64), int64_t(args[1])), args[2]); + } else { + *rv = MakeSplit(args[0], args[1], args[2]); + } +}); + +RELAY_REGISTER_OP("split") +.describe(R"code(Splits an array along a particular axis into multiple sub-arrays. + +Indices or sections to split into. Accepts an int or a tuple +If indices_or_sections is an integer, the input will be divided equally +along given axis. If such a split is not possible, an error is raised. + +If indices_or_sections is a tuple of sorted integers, +the entries indicate where along axis the array is split. + +)code" TVM_ADD_FILELINE) +.set_attrs_type_key("relay.attrs.SplitAttrs") +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(3) +.add_type_rel("Split", SplitRel); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 8ab3c41c079d..804d3c46ca36 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -107,6 +107,38 @@ def verify_take(dshape, indices_shape, oshape, axis=None): verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1) verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) +def test_split_infer_type(): + def verify_split(dshape, indices_or_sections, ret_type, axis=None): + x = relay.var("x", relay.ty.TensorType(dshape, "float32")) + y = relay.split(x, indices_or_sections, axis=axis) + y.astext() + yy = relay.ir_pass.infer_type(y.astuple()) + assert yy.checked_type == ret_type + + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") + axis = tvm.var("axis") + verify_split((5, 5, 2, 2), 5, + relay.ty.TupleType(tvm.convert([ + relay.ty.TensorType((5, 1, 2, 2), "float32"), + relay.ty.TensorType((5, 1, 2, 2), "float32"), + relay.ty.TensorType((5, 1, 2, 2), "float32"), + relay.ty.TensorType((5, 1, 2, 2), "float32"), + relay.ty.TensorType((5, 1, 2, 2), "float32")])), + axis=1) + verify_split((d1, d2, d3, d4), 4, + relay.ty.TupleType(tvm.convert([ + relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), + relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), + relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), + relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])), + axis=2) + verify_split((d1, d2, d3, d4), (2, 4, 7), + relay.ty.TupleType(tvm.convert([ + relay.ty.TensorType((d1, 2, d3, d4), "float32"), + relay.ty.TensorType((d1, 2, d3, d4), "float32"), + relay.ty.TensorType((d1, 3, d3, d4), "float32"), + relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])), + axis=1) def test_full(): # default settings: match input dtype @@ -161,3 +193,4 @@ def test_infer_type_leaky_relu(): test_infer_type_leaky_relu() test_squeeze_infer_type() test_squeeze_bad_axes_infer_type() + test_split_infer_type()