From 326243a9f387ef7075d7470c73811bd7cd8fd120 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 23 Jun 2020 15:55:02 +0000 Subject: [PATCH] vm shape_of --- src/relay/backend/vm/compiler.cc | 2 +- src/relay/op/memory/memory.cc | 4 ++-- src/relay/transforms/fold_constant.cc | 6 +++--- tests/python/relay/test_vm_serialization.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 1f37de21f582..3e9b7ef66d5d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -585,7 +585,7 @@ class VMFunctionCompiler : ExprFunctor { auto outputs = Downcast(args[2]); EmitShapeFunc(shape_func, inputs->fields, outputs->fields); }) - .Match("memory.shape_of", + .Match("vm.shape_of", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { CHECK_EQ(args.size(), 1U); // Get the attributes. diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 1e1491f0dec7..7094c6cff155 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -427,7 +427,7 @@ RELAY_REGISTER_OP("memory.shape_func") return {topi::identity(inputs[0])}; }); -RELAY_REGISTER_OP("memory.shape_of") +RELAY_REGISTER_OP("vm.shape_of") .describe(R"code(Get the shape of an input tensor. )code" TVM_ADD_FILELINE) .set_num_inputs(1) @@ -442,7 +442,7 @@ RELAY_REGISTER_OP("memory.shape_of") TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_of").set_body_typed([](Expr expr) { auto attrs = make_object(); attrs->dtype = DataType::Int(64); - static const Op& op = Op::Get("memory.shape_of"); + static const Op& op = Op::Get("vm.shape_of"); return Call(op, {expr}, Attrs(attrs), {}); }); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 52e3b1ac7e2f..50de8711a4c1 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -81,7 +81,7 @@ class ConstantFolder : public ExprMutator { : executor_(executor), module_(module), shape_of_op_(Op::Get("shape_of")), - memory_shape_of_op_(Op::Get("memory.shape_of")), + vm_shape_of_op_(Op::Get("vm.shape_of")), invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")), shape_func_op_(Op::Get("memory.shape_func")), alloc_tensor_op_(Op::Get("memory.alloc_tensor")), @@ -124,7 +124,7 @@ class ConstantFolder : public ExprMutator { // skip stateful ops. if (op_stateful.get(GetRef(op), false)) return res; // Try to evaluate shape_of op - if (call->op == shape_of_op_ || call->op == memory_shape_of_op_) { + if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) { return EvaluateShapeOf(res, origin_args, call->attrs); } @@ -167,7 +167,7 @@ class ConstantFolder : public ExprMutator { // Cache the following ops for equivalence checking in this pass. const Op& shape_of_op_; - const Op& memory_shape_of_op_; + const Op& vm_shape_of_op_; const Op& invoke_tvm_op_; const Op& shape_func_op_; const Op& alloc_tensor_op_; diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index 60946026acff..95e6c6f3c89e 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -284,7 +284,7 @@ def test_mobilenet(): run_network(mod, params) -def test_memory_shape_of(): +def test_vm_shape_of(): x = relay.var('x', shape=(relay.Any(), relay.Any(), relay.Any()), dtype="float32") relu_x = relay.nn.relu(x) data = np.random.uniform(size=(2, 3, 4)).astype('float32') @@ -310,4 +310,4 @@ def test_memory_shape_of(): test_closure() test_resnet() test_mobilenet() - test_memory_shape_of() + test_vm_shape_of()