diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index b20cd8dfaaf47..fdfaef9259164 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -280,34 +280,34 @@ RELAY_REGISTER_OP("shape_of") .set_attr("FTVMCompute", ShapeOfCompute); -TVM_REGISTER_NODE_TYPE(NumElementsAttrs); +TVM_REGISTER_NODE_TYPE(SizeAttrs); -bool NumElementsRel(const Array& types, +bool SizeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { CHECK_EQ(num_inputs, 1); auto tt = types[0].as(); CHECK(tt != nullptr); - const auto* param = attrs.as(); + const auto* param = attrs.as(); CHECK(param != nullptr); reporter->Assign(types[1], TensorTypeNode::make({1}, param->dtype)); return true; } -Array NumElementsCompute(const Attrs& attrs, +Array SizeCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, const Target& target) { CHECK_EQ(inputs.size(), 1); - const auto* param = attrs.as(); + const auto* param = attrs.as(); CHECK(param != nullptr); return Array{topi::size(inputs[0], param->dtype)}; } TVM_REGISTER_API("relay.op.contrib._make.size") .set_body_typed([](Expr data, DataType dtype) { - auto attrs = make_node(); + auto attrs = make_node(); attrs->dtype = dtype; static const Op& op = Op::Get("contrib.ndarray_size"); return CallNode::make(op, {data}, Attrs(attrs), {}); @@ -318,15 +318,15 @@ RELAY_REGISTER_OP("contrib.ndarray_size") )code" TVM_ADD_FILELINE) .set_num_inputs(1) -.set_attrs_type_key("relay.attrs.NumElementsAttrs") +.set_attrs_type_key("relay.attrs.SizeAttrs") .add_argument("data", "Tensor", "The input tensor.") -.add_type_rel("NumElements", NumElementsRel) +.add_type_rel("Size", SizeRel) .set_attr("TOpIsStateful", false) .set_attr("TOpPattern", kInjective) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_support_level(10) -.set_attr("FTVMCompute", NumElementsCompute); +.set_attr("FTVMCompute", SizeCompute); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 52b6cc8023ce9..f3520f3650a39 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -218,8 +218,8 @@ def test_shape_of(): def test_ndarray_size(): def verify_ndarray_size(shape): x = relay.var("x", shape=shape) - func = relay.Function([x], relay.op.contrib.num_elements(x)) - func = relay.ir_pass.infer_type(func) + func = relay.Function([x], relay.op.contrib.ndarray_size(x)) + func = run_infer_type(func) x_data = np.random.uniform(size=shape).astype("float32") ref_res = np.size(x_data)