From 6c1979fd9a00535347c453a80390591fcf83f762 Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Fri, 19 Jul 2019 15:06:34 -0700
Subject: [PATCH] [TOPI][RELAY] Add op Size (#3094)

---
 docs/api/python/topi.rst                  |  2 +
 docs/langref/relay_op.rst                 |  2 +
 include/tvm/operation.h                   |  2 +-
 include/tvm/relay/attrs/transform.h       | 11 +++++
 python/tvm/api.py                         |  2 +-
 python/tvm/relay/frontend/tensorflow.py   |  1 +
 python/tvm/relay/op/contrib/_contrib.py   |  5 +-
 python/tvm/relay/op/contrib/contrib.py    | 18 +++++++
 src/relay/op/tensor/unary.cc              | 49 +++++++++++++++++++
 .../frontend/tensorflow/test_forward.py   | 17 +++++++
 tests/python/relay/test_op_level10.py     | 19 +++++++
 topi/include/topi/transform.h             | 23 +++++++++
 topi/python/topi/transform.py             | 21 +++++++-
 topi/src/topi.cc                          |  5 ++
 topi/tests/python/test_topi_transform.py  | 28 +++++++++++
 15 files changed, 201 insertions(+), 4 deletions(-)

diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 367ad1a4c0d3..9ac8bb1fd084 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -97,6 +97,7 @@ List of operators
    topi.repeat
    topi.tile
    topi.shape
+   topi.ndarray_size
    topi.layout_transform
    topi.image.resize
    topi.argsort
@@ -165,6 +166,7 @@ topi
 .. autofunction:: topi.repeat
 .. autofunction:: topi.tile
 .. autofunction:: topi.shape
+.. autofunction:: topi.ndarray_size
 .. autofunction:: topi.layout_transform
 .. autofunction:: topi.argsort
 .. autofunction:: topi.topk
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index ccdb3e8af8fa..dad5eb89a053 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -186,6 +186,7 @@ This level support backpropagation of broadcast operators. It is temporary.
    tvm.relay.collapse_sum_like
    tvm.relay.slice_like
    tvm.relay.shape_of
+   tvm.relay.contrib.ndarray_size
    tvm.relay.layout_transform
    tvm.relay.device_copy
    tvm.relay.annotation.on_device
@@ -320,6 +321,7 @@ Level 10 Definitions
 .. autofunction:: tvm.relay.collapse_sum_like
 .. autofunction:: tvm.relay.slice_like
 .. autofunction:: tvm.relay.shape_of
+.. autofunction:: tvm.relay.contrib.ndarray_size
 .. autofunction:: tvm.relay.layout_transform
 .. autofunction:: tvm.relay.device_copy
 .. autofunction:: tvm.relay.annotation.on_device
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index 99d218ea1dd3..b950aa952f04 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -59,7 +59,7 @@ class OperationNode : public ir::FunctionBaseNode {
   std::string name;
   /*! \brief optional tag of the operation */
   std::string tag;
-  /*! \brief addtitional attributes of the operation*/
+  /*! \brief additional attributes of the operation*/
   Map<std::string, NodeRef> attrs;
   /*! \return name of the operation */
   const std::string& func_name() const final {
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index d09441d73eff..e43fd5f7a2e7 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -287,6 +287,17 @@ struct SequenceMaskAttrs : public tvm::AttrsNode<SequenceMaskAttrs> {
   }
 };  // struct SequenceMaskAttrs.
 
+/*! \brief Attributes for ndarray_size operator */
+struct NdarraySizeAttrs : public tvm::AttrsNode<NdarraySizeAttrs> {
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(NdarraySizeAttrs, "relay.attrs.NdarraySizeAttrs") {
+    TVM_ATTR_FIELD(dtype)
+      .describe("Target data type")
+      .set_default(NullValue<DataType>());
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 2fac4f5b44d1..cbc3459f3338 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -275,7 +275,7 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
         The name hint of the tensor
 
     tag: str, optional
-        Additonal tag information about the compute.
+        Additional tag information about the compute.
 
     attrs: dict, optional
         The additional auxiliary attributes about the compute.
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index a90264488ae7..8605edf1d4a6 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1383,6 +1383,7 @@ def _impl(inputs, attr, params):
     'Shape'                             : _shape(),
     'Sigmoid'                           : AttrCvt('sigmoid'),
     'Sign'                              : AttrCvt('sign'),
+    'Size'                              : AttrCvt('ndarray_size'),
     'Slice'                             : _slice(),
     'Softmax'                           : _softmax(),
     'Softplus'                          : _softplus(),
diff --git a/python/tvm/relay/op/contrib/_contrib.py b/python/tvm/relay/op/contrib/_contrib.py
index f0df75648467..4b5588024411 100644
--- a/python/tvm/relay/op/contrib/_contrib.py
+++ b/python/tvm/relay/op/contrib/_contrib.py
@@ -20,7 +20,7 @@
 
 import topi
 from .. import op as reg
-from ..op import OpPattern
+from ..op import schedule_injective, OpPattern
 
 
 # adaptive_max_pool2d
@@ -41,3 +41,6 @@ def schedule_adaptive_avg_pool2d(_, outs, target):
         return topi.generic.schedule_adaptive_pool(outs)
 
 reg.register_pattern("contrib.adaptive_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+# relay.contrib.ndarray_size
+reg.register_schedule("contrib.ndarray_size", schedule_injective)
diff --git a/python/tvm/relay/op/contrib/contrib.py b/python/tvm/relay/op/contrib/contrib.py
index 1f073d4aae45..7114b7e712db 100644
--- a/python/tvm/relay/op/contrib/contrib.py
+++ b/python/tvm/relay/op/contrib/contrib.py
@@ -111,3 +111,21 @@ def adaptive_avg_pool2d(data,
     """
     output_size = [] or output_size
     return _make.adaptive_avg_pool2d(data, output_size, layout)
+
+def ndarray_size(data, dtype="int32"):
+    """Get the number of elements of the input tensor.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input tensor.
+
+    dtype : str, optional
+        The target data type.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The number of elements of the input tensor.
+ """ + return _make.ndarray_size(data, dtype) diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index b723137a3a8e..60e53784649b 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -279,5 +279,54 @@ RELAY_REGISTER_OP("shape_of") .set_support_level(10) .set_attr("FTVMCompute", ShapeOfCompute); + +TVM_REGISTER_NODE_TYPE(NdarraySizeAttrs); + +bool NdarraySizeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(num_inputs, 1); + auto tt = types[0].as(); + CHECK(tt != nullptr); + const auto* param = attrs.as(); + CHECK(param != nullptr); + reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype)); + return true; +} + +Array NdarraySizeCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + CHECK_EQ(inputs.size(), 1); + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array{topi::ndarray_size(inputs[0], param->dtype)}; +} + +TVM_REGISTER_API("relay.op.contrib._make.ndarray_size") +.set_body_typed([](Expr data, DataType dtype) { + auto attrs = make_node(); + attrs->dtype = dtype; + static const Op& op = Op::Get("contrib.ndarray_size"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +}); + +RELAY_REGISTER_OP("contrib.ndarray_size") +.describe(R"code(Returns a tensor representing the number of elements of input tensor. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.NdarraySizeAttrs") +.add_argument("data", "Tensor", "The input tensor.") +.add_type_rel("NdarraySize", NdarraySizeRel) +.set_attr("TOpIsStateful", false) +.set_attr("TOpPattern", kInjective) +.set_attr("FInferCorrectLayout", +ElemwiseArbitraryLayout) +.set_support_level(10) +.set_attr("FTVMCompute", NdarraySizeCompute); + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index cf0f8f0123e5..6c9824e4ed13 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1933,6 +1933,22 @@ def check_mean(ishape, **kwargs): check_mean((10, 8, 16, 32), axis=(2, 3)) check_mean((10, 8, 16, 32), axis=(1, 2), keepdims=True) +####################################################################### +# Size +# ---- +def test_forward_size(): + def check_size(ishape): + np_input = np.random.uniform(size=ishape).astype(np.float32) + with tf.Graph().as_default(): + input = tf.placeholder(shape=np_input.shape, dtype=np_input.dtype, name='input') + tf.size(input, name='size') + compare_tf_with_tvm([np_input], ['input:0'], 'size:0') + + if tf.__version__ < LooseVersion('1.1'): + check_size((10, 8, 16, 32)) + check_size((10,)) + check_size(()) + ####################################################################### # All, Max, Min # ------------- @@ -2087,6 +2103,7 @@ def test_placeholder(): test_forward_depthtospace() test_forward_squeeze() test_forward_pack() + test_forward_size() test_forward_broadcast_to() test_forward_fill() test_forward_crop() diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 046da8de5fe8..f3520f3650a3 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -215,6 +215,23 @@ def test_shape_of(): tvm.testing.assert_allclose(op_res.asnumpy(), np.array(shape).astype('int32')) +def test_ndarray_size(): + def verify_ndarray_size(shape): + x = relay.var("x", shape=shape) + 
+        func = relay.Function([x], relay.op.contrib.ndarray_size(x))
+        func = run_infer_type(func)
+
+        x_data = np.random.uniform(size=shape).astype("float32")
+        ref_res = np.size(x_data)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(),
+                                            ref_res)
+    verify_ndarray_size((2, 3, 5))
+    verify_ndarray_size((2, 3, 5, 7))
+
 def verify_adaptive_pool2d(dshape, out_size, pool_type, layout="NCHW", dtype="float32"):
     def start_index(index, odim, idim):
         return int(np.floor(index * idim / odim))
@@ -288,3 +305,5 @@ def _verify(data_shape, mask_value, axis, dtype, itype):
     test_batch_matmul()
     test_shape_of()
     test_sequence_mask()
+    test_ndarray_size()
+
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index 43711dadc273..c9a05098fc89 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -1223,5 +1223,28 @@ inline Tensor shape(const Tensor& src,
   }, name, tag);
 }
 
+/*!
+ * \brief Get the number of elements of the input tensor.
+ * \param src the input tensor.
+ * \param dtype the data type of the result.
+ * \param name output tensor name.
+ * \param tag output tensor tag.
+ * \return 1-D tensor holding the number of elements of src.
+ */
+inline Tensor ndarray_size(const Tensor& src,
+                           const Type& dtype,
+                           const std::string& name = "ndarray_size",
+                           const std::string& tag = kInjective) {
+  int ndim = static_cast<int>(src->shape.size());
+  Array<Expr> out_ndarray_size = {1};
+  return compute(out_ndarray_size, [&](const Array<Var>& indices) {
+    Expr ret = 1;
+    for (int i = 0; i < ndim; ++i) {
+      ret *= src->shape[i];
+    }
+    return tvm::cast(dtype, ret);
+  }, name, tag);
+}
+
 }  // namespace topi
 #endif  // TOPI_TRANSFORM_H_
diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index 738754e91cbc..fc32403065d9 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -425,7 +425,7 @@ def shape(array, dtype="int32"):
     Parameters
     ----------
     array : tvm.Tensor
-        The source tenosr.
+        The source tensor.
 
     dtype : str, optional
         The target data type.
@@ -477,3 +477,22 @@ def sequence_mask(data, valid_length, mask_value=0, axis=0):
         "only support data.ndim >= 2, received data.shape = {}".format(data.shape)
     assert axis == 0 or axis == 1, "only support axis = 0, 1, received axis = {}".format(axis)
     return cpp.sequence_mask(data, valid_length, mask_value, axis)
+
+
+def ndarray_size(array, dtype="int32"):
+    """Get the number of elements of the input array.
+
+    Parameters
+    ----------
+    array : tvm.Tensor
+        The source tensor.
+
+    dtype : str, optional
+        The target data type.
+
+    Returns
+    -------
+    result : tvm.Tensor
+        The number of elements of the input array.
+ """ + return cpp.ndarray_size(array, dtype) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 688cc9fc8354..44134d7c2d67 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -311,6 +311,11 @@ TVM_REGISTER_GLOBAL("topi.shape") *rv = shape(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.ndarray_size") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = ndarray_size(args[0], args[1]); +}); + TVM_REGISTER_GLOBAL("topi.split") .set_body([](TVMArgs args, TVMRetValue *rv) { if (args[1].type_code() == kDLInt || args[1].type_code() == kDLUInt) { diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py index 9d69734139a6..7f2c73e00390 100644 --- a/topi/tests/python/test_topi_transform.py +++ b/topi/tests/python/test_topi_transform.py @@ -649,6 +649,33 @@ def check_device(device): for backend in get_all_backend(): check_device(backend) +def test_ndarray_size(): + in_shape = (5, 11, 7) + dtype = "int32" + A = tvm.placeholder(shape=in_shape, dtype="float32", name="A") + B = topi.ndarray_size(A, dtype) + + input = np.random.uniform(size=in_shape).astype(A.dtype) + output = np.asarray(np.size(input)).astype(dtype) + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + tvm_input = tvm.nd.array(input, ctx=ctx) + tvm_output = tvm.nd.empty((1,), ctx=ctx, dtype=B.dtype) + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(B) + f = tvm.build(s, [A, B], device, name="ndarray_size") + f(tvm_input, tvm_output) + tvm.testing.assert_allclose(tvm_output.asnumpy(), output) + + for backend in get_all_backend(): + check_device(backend) + + if __name__ == "__main__": test_strided_slice() test_concatenate() @@ -668,3 +695,4 @@ def check_device(device): test_tile() test_shape() test_sequence_mask() + test_ndarray_size()