diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 06f4f0d61f34..0b9d555ca6fa 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -75,6 +75,7 @@ List of operators topi.stack topi.repeat topi.tile + topi.shape topi.layout_transform topi.image.resize @@ -136,6 +137,7 @@ topi .. autofunction:: topi.stack .. autofunction:: topi.repeat .. autofunction:: topi.tile +.. autofunction:: topi.shape .. autofunction:: topi.layout_transform topi.nn diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index f20c443e8404..997b7e34464e 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -154,6 +154,7 @@ This level support backpropagation of broadcast operators. It is temporary. tvm.relay.broadcast_to_like tvm.relay.collapse_sum_like tvm.relay.slice_like + tvm.relay.shape_of tvm.relay.layout_transform tvm.relay.device_copy tvm.relay.annotation.on_device @@ -273,6 +274,7 @@ Level 10 Definitions .. autofunction:: tvm.relay.broadcast_to_like .. autofunction:: tvm.relay.collapse_sum_like .. autofunction:: tvm.relay.slice_like +.. autofunction:: tvm.relay.shape_of .. autofunction:: tvm.relay.layout_transform .. autofunction:: tvm.relay.device_copy .. autofunction:: tvm.relay.annotation.on_device diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 5382017d8c1c..f9997202ac53 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -217,6 +217,7 @@ struct ClipAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for LayoutTransform operator */ struct LayoutTransformAttrs : public tvm::AttrsNode { std::string src_layout; std::string dst_layout; @@ -229,6 +230,17 @@ struct LayoutTransformAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes for ShapeOf operator */ +struct ShapeOfAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(ShapeOfAttrs, "relay.attrs.ShapeOfAttrs") { + TVM_ATTR_FIELD(dtype) + .describe("Target data type") + .set_default(NullValue()); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_TRANSFORM_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 93bd8efc6752..87ae6b6cb762 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -488,6 +488,19 @@ def _mx_l2_normalize(inputs, attrs): return _op.nn.l2_normalize(inputs[0], **new_attrs) +def _mx_shape_array(inputs, attrs): + assert len(inputs) == 1 + if attrs.get_int("lhs_begin", None) is not None: + raise RuntimeError("shape_array doesn't support lhs_begin") + if attrs.get_int("lhs_end", None) is not None: + raise RuntimeError("shape_array doesn't support lhs_end") + if attrs.get_int("rhs_begin", None) is not None: + raise RuntimeError("shape_array doesn't support rhs_begin") + if attrs.get_int("rhs_end", None) is not None: + raise RuntimeError("shape_array doesn't support rhs_end") + return _op.shape_of(inputs[0], dtype='int64') + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -613,6 +626,7 @@ def _mx_l2_normalize(inputs, attrs): "repeat" : _mx_repeat, "tile" : _mx_tile, "BlockGrad" : _mx_BlockGrad, + "shape_array" : _mx_shape_array, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, # vision diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index 7f8da03008d2..36dae03d1237 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -40,6 +40,7 @@ register_schedule("minimum", schedule_injective) register_schedule("right_shift", schedule_injective) register_schedule("left_shift", schedule_injective) +register_schedule("shape_of", schedule_injective) # zeros @register_compute("zeros") diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index e315f27dc593..ffbc7459648e 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -713,3 +713,22 @@ def device_copy(data, src_dev, dst_dev): raise ValueError("dst_dev is expected to be the type of TVMContext or " "str, but received %s" % (type(dst_dev))) return _make.device_copy(data, src_dev, dst_dev) + + +def shape_of(data, dtype="int32"): + """Get shape of a tensor. + + Parameters + ---------- + data : tvm.relay.Expr + The input tensor. + + dtype : str, optional + The target data type. + + Returns + ------- + result : tvm.relay.Expr + The shape tensor. + """ + return _make.shape_of(data, dtype) diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index cfcc130564c0..720344c3340d 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include "../type_relations.h" #include "../op_common.h" @@ -189,5 +190,56 @@ RELAY_REGISTER_UNARY_OP("logical_not") .set_support_level(4) .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::logical_not)); + +// shape_of +TVM_REGISTER_NODE_TYPE(ShapeOfAttrs); + +bool ShapeOfRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 1); + auto tt = types[0].as(); + CHECK(tt != nullptr); + const auto* param = attrs.as(); + CHECK(param != nullptr); + auto vector_out = tvm::Integer(tt->shape.size()); + reporter->Assign(types[1], TensorTypeNode::make({ vector_out }, param->dtype)); + return true; +} + +Array ShapeOfCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + CHECK_EQ(inputs.size(), 1); + const auto* param = attrs.as(); + CHECK(param != nullptr); + return {topi::shape(inputs[0], param->dtype)}; +} + +TVM_REGISTER_API("relay.op._make.shape_of") +.set_body_typed([](Expr data, DataType dtype) { + auto attrs = make_node(); + attrs->dtype = dtype; + static const Op& op = Op::Get("shape_of"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +}); + +RELAY_REGISTER_OP("shape_of") +.describe(R"code(Returns a tensor representing the shape of a tensor. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ShapeOfAttrs") +.add_argument("data", "Tensor", "The input tensor.") +.add_type_rel("ShapeOf", ShapeOfRel) +.set_attr("TOpIsStateful", false) +.set_attr("TOpPattern", kInjective) +.set_attr("FInferCorrectLayout", + ElemwiseArbitraryLayout) +.set_support_level(10) +.set_attr("FTVMCompute", ShapeOfCompute); + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_constant.cc b/src/relay/pass/fold_constant.cc index 60994cdd6ca9..f4dd067e2eb1 100644 --- a/src/relay/pass/fold_constant.cc +++ b/src/relay/pass/fold_constant.cc @@ -6,6 +6,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -71,6 +72,7 @@ class ConstantFolder : public ExprMutator { Expr VisitExpr_(const CallNode* call) final { static auto op_stateful = Op::GetAttr("TOpIsStateful"); + auto origin_args = call->args; Expr res = ExprMutator::VisitExpr_(call); call = res.as(); // We don't constant fold function with zero arguments. @@ -81,6 +83,10 @@ class ConstantFolder : public ExprMutator { if (op == nullptr) return res; // skip stateful ops. if (op_stateful.get(GetRef(op), false)) return res; + // Try to evaluate shape_of op + if (call->op.same_as(Op::Get("shape_of"))) { + return EvaluateShapeOf(res, origin_args, call->attrs); + } bool all_const_args = true; for (Expr arg : call->args) { if (!checker_.Check(arg)) { @@ -132,6 +138,42 @@ class ConstantFolder : public ExprMutator { expr = InferType(expr, Module(nullptr)); return ValueToExpr(executor_(expr)); } + // Evaluate shape_of op + Expr EvaluateShapeOf(Expr expr, Array args, Attrs attrs) { + Expr input = args[0]; + const auto* param = attrs.as(); + CHECK(param != nullptr); + tvm::Array ishape; + if (const ConstantNode* op = input.as()) { + ishape = op->tensor_type()->shape; + } else if (input->checked_type_.defined()) { + ishape = input->checked_type().as()->shape; + } else { + return expr; + } + // Get the constant shape + DLContext ctx; + ctx.device_type = kDLCPU; + ctx.device_id = 0; + auto val = runtime::NDArray::Empty( + {(int64_t)ishape.size()}, Type2TVMType(Int(32)), ctx); + int32_t* dims = static_cast(val->data); + using ::tvm::ir::IntImm; + for (size_t i = 0; i < ishape.size(); ++i) { + if (const IntImm* dim = ishape[i].as()) { + dims[i] = dim->value; + } else { + return expr; + } + } + Expr shape = ValueToExpr(TensorValueNode::make(val)); + // Cast the constant into correct dtype + auto cast_attrs = make_node(); + cast_attrs->dtype = param->dtype; + static const Op& cast_op = Op::Get("cast"); + Expr ret = CallNode::make(cast_op, {shape}, Attrs(cast_attrs), {}); + return ConstEvaluate(ret); + } }; diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 4679876c181b..e83f1e569545 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -380,6 +380,22 @@ def test_forward_l2_normalize(): verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5)) +def test_forward_shape_array(): + def verify(shape): + x_np = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.shape_array(mx.nd.array(x_np)) + mx_sym = mx.sym.shape_array(mx.sym.var("x")) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((1,)) + verify((3, 4, 5)) + verify((3, 4, 5, 6)) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -409,3 +425,4 @@ def test_forward_l2_normalize(): test_forward_slice_like() test_forward_slice_axis() test_forward_l2_normalize() + test_forward_shape_array() diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 7237cfbc3b87..1b1760692943 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -177,6 +177,20 @@ def test_batch_matmul(): verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20)) verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20)) +def test_shape_of(): + shape = (10, 5, 12) + x = relay.var("x", shape=shape) + func = relay.Function([x], relay.op.shape_of(x)) + func = relay.ir_pass.infer_type(func) + x_data = np.random.rand(*shape).astype('float32') + for target, ctx in ctx_list(): + # Because using graph executor, this op will be optimized after + # constant folding pass, here we only test with interpreter + for kind in ["debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), + np.array(shape).astype('int32')) if __name__ == "__main__": test_collapse_sum_like() @@ -184,3 +198,4 @@ def test_batch_matmul(): test_slice_like() test_reverse_reshape() test_batch_matmul() + test_shape_of() diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 6a63d88f052f..315a83a92a35 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -95,8 +95,34 @@ def expected(): assert relay.ir_pass.graph_equal(zz, zexpected) +def test_fold_shape_of(): + c_shape = (8, 9, 10) + def before(dtype): + x = relay.var("x", shape=c_shape, dtype="float32") + y = relay.var("y", shape=c_shape, dtype="float32") + z = relay.shape_of(x + y, dtype) + return relay.Function([x, y], z) + + def expected(dtype): + x = relay.var("x", shape=c_shape, dtype="float32") + y = relay.var("y", shape=c_shape, dtype="float32") + z = relay.const(np.array(c_shape).astype(dtype), dtype=dtype) + return relay.ir_pass.infer_type(relay.Function([x, y], z)) + + for dtype in ["int32", "float32"]: + zbefore = before(dtype) + zz = relay.ir_pass.fold_constant(zbefore) + assert relay.ir_pass.graph_equal(zz, zbefore) + + zz = relay.ir_pass.infer_type(zbefore) + zz = relay.ir_pass.fold_constant(zz) + zexpected = expected(dtype) + assert relay.ir_pass.graph_equal(zz, zexpected) + + if __name__ == "__main__": test_fold_const() test_fold_let() test_fold_tuple() test_fold_concat() + test_fold_shape_of() diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index b7e012f989b4..57d442dc9206 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -1081,5 +1081,28 @@ inline Tensor layout_transform(const Tensor& src, }, name, tag); } +/*! + * \brief Get the shape of input tensor. + * \param src the input tensor. + * \param name output tensor name. + * \param tag output tensor tag. + * \return Tensor of input shape. + */ +inline Tensor shape(const Tensor& src, + Type dtype, + const std::string name = "shape", + const std::string tag = kInjective) { + int ndim = static_cast(src->shape.size()); + Array out_shape{ndim}; + return compute(out_shape, [&](const Array& indices) { + auto idx = indices[0]; + Expr ret = 0; + for (int i = 0; i < ndim; ++i) { + ret = tvm::if_then_else(idx == i, src->shape[i], ret); + } + return tvm::cast(dtype, ret); + }, name, tag); +} + } // namespace topi #endif // TOPI_TRANSFORM_H_ diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py index 063556852d26..2c109cd92c52 100644 --- a/topi/python/topi/transform.py +++ b/topi/python/topi/transform.py @@ -393,3 +393,22 @@ def layout_transform(array, src_layout, dst_layout): the destination layout. """ return cpp.layout_transform(array, src_layout, dst_layout) + + +def shape(array, dtype="int32"): + """Get the shape of input array + + Parameters + ---------- + array : tvm.Tensor + The source tenosr. + + dtype : str, optional + The target data type. + + Returns + ------- + result : tvm.Tensor + The resulting tensor. + """ + return cpp.shape(array, dtype) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 14f92460fd25..366f835d808d 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -271,6 +271,11 @@ TVM_REGISTER_GLOBAL("topi.stack") *rv = stack(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.shape") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = shape(args[0], args[1]); +}); + TVM_REGISTER_GLOBAL("topi.split") .set_body([](TVMArgs args, TVMRetValue *rv) { if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { @@ -278,7 +283,7 @@ TVM_REGISTER_GLOBAL("topi.split") } else { *rv = split(args[0], args[1], args[2]); } - }); +}); TVM_REGISTER_GLOBAL("topi.layout_transform") .set_body([](TVMArgs args, TVMRetValue *rv) { diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 785da6fddbcf..ad557f0fcbfe 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -564,6 +564,33 @@ def check_device(device): check_device(backend) +def test_shape(): + in_shape = (8, 7, 13) + dtype = "int32" + A = tvm.placeholder(shape=in_shape, dtype="float32", name="A") + B = topi.shape(A, dtype) + + input = np.random.uniform(size=in_shape).astype(A.dtype) + output = np.asarray(in_shape).astype(dtype) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + tvm_input = tvm.nd.array(input, ctx) + tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=dtype) + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + f = tvm.build(s, [A, B], device, name="shape") + f(tvm_input, tvm_output) + tvm.testing.assert_allclose(tvm_output.asnumpy(), output) + + for backend in get_all_backend(): + check_device(backend) + + if __name__ == "__main__": test_strided_slice() test_concatenate() @@ -581,3 +608,4 @@ def check_device(device): test_layout_transform() test_repeat() test_tile() + test_shape()