Skip to content

Commit

Permalink
[RELAY][VM] Add shape_of instruction (#5855)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored Jun 28, 2020
1 parent 17bd06a commit 7312934
Show file tree
Hide file tree
Showing 15 changed files with 268 additions and 88 deletions.
12 changes: 12 additions & 0 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ enum class Opcode {
LoadConsti = 14U,
Fatal = 15U,
AllocStorage = 16U,
ShapeOf = 17U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -245,6 +246,9 @@ struct Instruction {
/*! \brief The hint of the dtype. */
DLDataType dtype_hint;
} alloc_storage;
struct /* ShapeOf Operands */ {
RegName tensor;
} shape_of;
};

/*!
Expand Down Expand Up @@ -389,6 +393,14 @@ struct Instruction {
static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
RegName dst);

/*!
* \brief Get the shape of an input tensor.
* \param tensor The input tensor.
* \param dst The destination to store the shape of the given tensor.
* \return The shape of instruction.
*/
static Instruction ShapeOf(RegName tensor, RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .tensor import *
from .transform import *
from .algorithm import *
from .vm import *
from . import nn
from . import annotation
from . import memory
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/op/vm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Dialect operators for Relay VM."""
from __future__ import absolute_import as _abs
from . import vm
20 changes: 20 additions & 0 deletions python/tvm/relay/op/vm/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for relay.op.vm"""
import tvm._ffi

tvm._ffi._init_api("relay.op.vm", __name__)
35 changes: 35 additions & 0 deletions python/tvm/relay/op/vm/vm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
"""Dialect operators for Relay VM."""
from . import _ffi_api


def shape_of(expr):
"""Invoke a function to get the shape of a tensor.
Parameters
----------
expr : tvm.relay.Expr
The expr used to evaluate its tensor shape.
Returns
-------
result : tvm.relay.Expr
The expression with the evaluated tensor shape.
"""
return _ffi_api.shape_of(expr)
4 changes: 1 addition & 3 deletions python/tvm/relay/transform/memory_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ManifestAllocPass(ExprMutator):
def __init__(self, target_host):
self.invoke_tvm = op.memory.invoke_tvm_op
self.shape_func = op.memory.shape_func
self.shape_of = op.vm.shape_of
self.scopes = [ScopeBuilder()]
self.target_host = target_host
self.default_context = cpu(0)
Expand All @@ -53,9 +54,6 @@ def __init__(self, target_host):
def current_scope(self):
return self.scopes[-1]

def shape_of(self, e):
return op.shape_of(e, self.compute_dtype)

def visit_tuple(self, tup):
scope = self.current_scope()
new_fields = []
Expand Down
13 changes: 13 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::Invoke:
case Opcode::AllocClosure:
case Opcode::AllocStorage:
case Opcode::ShapeOf:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register_ = instr.dst;
Expand Down Expand Up @@ -588,6 +589,18 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto outputs = Downcast<Tuple>(args[2]);
EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
})
.Match("vm.shape_of",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 1U);
// Get the attributes.
const auto* shape_of_attrs = attrs.as<ShapeOfAttrs>();
CHECK(shape_of_attrs) << "Must be the shape_of attrs";
CHECK_EQ(shape_of_attrs->dtype.bits(), 64)
<< "The dtype of shape of must be int64, but got"
<< DLDataType2String(shape_of_attrs->dtype);
this->VisitExpr(args[0]);
Emit(Instruction::ShapeOf(last_register_, NewRegister()));
})
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
Expand Down
14 changes: 0 additions & 14 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,20 +396,6 @@ RELAY_REGISTER_UNARY_OP("bitwise_not")
// shape_of
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);

bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
if (tt == nullptr) {
return false;
}
const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr);
auto rank_shape = RankShape(tt->shape);
reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
return true;
}

Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 1);
Expand Down
15 changes: 15 additions & 0 deletions src/relay/op/type_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "./type_relations.h"

#include <tvm/arith/analyzer.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/tir/op.h>
Expand Down Expand Up @@ -146,5 +147,19 @@ Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
}
}

bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
if (tt == nullptr) {
return false;
}
const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr);
auto rank_shape = RankShape(tt->shape);
reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
return true;
}

} // namespace relay
} // namespace tvm
12 changes: 12 additions & 0 deletions src/relay/op/type_relations.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attr

Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);

/*!
* \brief The shape of type relation.
*
* \param types The input and output types to the relation.
* \param num_inputs The number of input arguments.
* \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/
bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);

} // namespace relay
} // namespace tvm

Expand Down
58 changes: 58 additions & 0 deletions src/relay/op/vm/vm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/op/vm/vm.cc
* \brief Dialect operators for Relay VM.
*/

#include <topi/elemwise.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/data_type.h>

#include "../../transforms/infer_layout_util.h"
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {

RELAY_REGISTER_OP("vm.shape_of")
.describe(R"code(Get the shape of an input tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("tensor", "Tensor", "The input tensor")
.add_type_rel("ShapeOf", ShapeOfRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = DataType::Int(64);
static const Op& op = Op::Get("vm.shape_of");
return Call(op, {expr}, Attrs(attrs), {});
});

} // namespace relay
} // namespace tvm
4 changes: 3 additions & 1 deletion src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class ConstantFolder : public ExprMutator {
: executor_(executor),
module_(module),
shape_of_op_(Op::Get("shape_of")),
vm_shape_of_op_(Op::Get("vm.shape_of")),
invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
shape_func_op_(Op::Get("memory.shape_func")),
alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
Expand Down Expand Up @@ -123,7 +124,7 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
if (call->op == shape_of_op_) {
if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
return EvaluateShapeOf(res, origin_args, call->attrs);
}

Expand Down Expand Up @@ -166,6 +167,7 @@ class ConstantFolder : public ExprMutator {

// Cache the following ops for equivalence checking in this pass.
const Op& shape_of_op_;
const Op& vm_shape_of_op_;
const Op& invoke_tvm_op_;
const Op& shape_func_op_;
const Op& alloc_tensor_op_;
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.pc_offset);
break;
}
case Opcode::ShapeOf: {
// Number of fields = 2
fields.assign({instr.shape_of.tensor, instr.dst});
break;
}
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
Expand Down Expand Up @@ -683,6 +688,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Goto(instr.fields[0]);
}
case Opcode::ShapeOf: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
Expand Down
Loading

0 comments on commit 7312934

Please sign in to comment.