Skip to content

Commit

Permalink
vm shape_of
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Jun 23, 2020
1 parent 9b32292 commit 326243a
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto outputs = Downcast<Tuple>(args[2]);
EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
})
.Match("memory.shape_of",
.Match("vm.shape_of",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 1U);
// Get the attributes.
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ RELAY_REGISTER_OP("memory.shape_func")
return {topi::identity(inputs[0])};
});

RELAY_REGISTER_OP("memory.shape_of")
RELAY_REGISTER_OP("vm.shape_of")
.describe(R"code(Get the shape of an input tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
Expand All @@ -442,7 +442,7 @@ RELAY_REGISTER_OP("memory.shape_of")
TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_of").set_body_typed([](Expr expr) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = DataType::Int(64);
static const Op& op = Op::Get("memory.shape_of");
static const Op& op = Op::Get("vm.shape_of");
return Call(op, {expr}, Attrs(attrs), {});
});

Expand Down
6 changes: 3 additions & 3 deletions src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class ConstantFolder : public ExprMutator {
: executor_(executor),
module_(module),
shape_of_op_(Op::Get("shape_of")),
memory_shape_of_op_(Op::Get("memory.shape_of")),
vm_shape_of_op_(Op::Get("vm.shape_of")),
invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
shape_func_op_(Op::Get("memory.shape_func")),
alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
Expand Down Expand Up @@ -124,7 +124,7 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
if (call->op == shape_of_op_ || call->op == memory_shape_of_op_) {
if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
return EvaluateShapeOf(res, origin_args, call->attrs);
}

Expand Down Expand Up @@ -167,7 +167,7 @@ class ConstantFolder : public ExprMutator {

// Cache the following ops for equivalence checking in this pass.
const Op& shape_of_op_;
const Op& memory_shape_of_op_;
const Op& vm_shape_of_op_;
const Op& invoke_tvm_op_;
const Op& shape_func_op_;
const Op& alloc_tensor_op_;
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_vm_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def test_mobilenet():
run_network(mod, params)


def test_memory_shape_of():
def test_vm_shape_of():
x = relay.var('x', shape=(relay.Any(), relay.Any(), relay.Any()), dtype="float32")
relu_x = relay.nn.relu(x)
data = np.random.uniform(size=(2, 3, 4)).astype('float32')
Expand All @@ -310,4 +310,4 @@ def test_memory_shape_of():
test_closure()
test_resnet()
test_mobilenet()
test_memory_shape_of()
test_vm_shape_of()

0 comments on commit 326243a

Please sign in to comment.