Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][VM] Add shape_of instruction #5855

Merged
merged 7 commits into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ enum class Opcode {
LoadConsti = 14U,
Fatal = 15U,
AllocStorage = 16U,
ShapeOf = 17U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -245,6 +246,9 @@ struct Instruction {
/*! \brief The hint of the dtype. */
DLDataType dtype_hint;
} alloc_storage;
struct /* ShapeOf Operands */ {
RegName tensor;
} shape_of;
};

/*!
Expand Down Expand Up @@ -389,6 +393,14 @@ struct Instruction {
static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint,
RegName dst);

/*!
* \brief Get the shape of an input tensor.
* \param tensor The input tensor.
* \param dst The destination to store the shape of the given tensor.
* \return The shape of instruction.
*/
static Instruction ShapeOf(RegName tensor, RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .tensor import *
from .transform import *
from .algorithm import *
from .vm import *
from . import nn
from . import annotation
from . import memory
Expand Down
20 changes: 20 additions & 0 deletions python/tvm/relay/op/vm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=wildcard-import
"""Dialect operators for Relay VM."""
from __future__ import absolute_import as _abs
from . import vm
20 changes: 20 additions & 0 deletions python/tvm/relay/op/vm/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for relay.op.vm"""
import tvm._ffi

tvm._ffi._init_api("relay.op.vm", __name__)
35 changes: 35 additions & 0 deletions python/tvm/relay/op/vm/vm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
"""Dialect operators for Relay VM."""
from . import _ffi_api


def shape_of(expr):
"""Invoke a function to get the shape of a tensor.
Parameters
----------
expr : tvm.relay.Expr
The expr used to evaluate its tensor shape.
Returns
-------
result : tvm.relay.Expr
The expression with the evaluated tensor shape.
"""
return _ffi_api.shape_of(expr)
4 changes: 1 addition & 3 deletions python/tvm/relay/transform/memory_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ManifestAllocPass(ExprMutator):
def __init__(self, target_host):
self.invoke_tvm = op.memory.invoke_tvm_op
self.shape_func = op.memory.shape_func
self.shape_of = op.vm.shape_of
self.scopes = [ScopeBuilder()]
self.target_host = target_host
self.default_context = cpu(0)
Expand All @@ -53,9 +54,6 @@ def __init__(self, target_host):
def current_scope(self):
return self.scopes[-1]

def shape_of(self, e):
return op.shape_of(e, self.compute_dtype)

def visit_tuple(self, tup):
scope = self.current_scope()
new_fields = []
Expand Down
13 changes: 13 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::Invoke:
case Opcode::AllocClosure:
case Opcode::AllocStorage:
case Opcode::ShapeOf:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register_ = instr.dst;
Expand Down Expand Up @@ -584,6 +585,18 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto outputs = Downcast<Tuple>(args[2]);
EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
})
.Match("vm.shape_of",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 1U);
// Get the attributes.
const auto* shape_of_attrs = attrs.as<ShapeOfAttrs>();
CHECK(shape_of_attrs) << "Must be the shape_of attrs";
CHECK_EQ(shape_of_attrs->dtype.bits(), 64)
<< "The dtype of shape of must be int64, but got"
<< DLDataType2String(shape_of_attrs->dtype);
this->VisitExpr(args[0]);
Emit(Instruction::ShapeOf(last_register_, NewRegister()));
})
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
Expand Down
14 changes: 0 additions & 14 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,20 +396,6 @@ RELAY_REGISTER_UNARY_OP("bitwise_not")
// shape_of
TVM_REGISTER_NODE_TYPE(ShapeOfAttrs);

bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
if (tt == nullptr) {
return false;
}
const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr);
auto rank_shape = RankShape(tt->shape);
reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
return true;
}

Array<te::Tensor> ShapeOfCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
CHECK_EQ(inputs.size(), 1);
Expand Down
15 changes: 15 additions & 0 deletions src/relay/op/type_relations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "./type_relations.h"

#include <tvm/arith/analyzer.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/tir/op.h>
Expand Down Expand Up @@ -146,5 +147,19 @@ Array<IndexExpr> RankShape(const Array<IndexExpr>& shape) {
}
}

bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(num_inputs, 1);
auto tt = types[0].as<TensorTypeNode>();
if (tt == nullptr) {
return false;
}
const auto* param = attrs.as<ShapeOfAttrs>();
CHECK(param != nullptr);
auto rank_shape = RankShape(tt->shape);
reporter->Assign(types[1], TensorType(rank_shape, param->dtype));
return true;
}

} // namespace relay
} // namespace tvm
12 changes: 12 additions & 0 deletions src/relay/op/type_relations.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ bool IdentityCompRel(const Array<Type>& types, int num_inputs, const Attrs& attr

Array<IndexExpr> RankShape(const Array<IndexExpr>& shape);

/*!
* \brief The shape of type relation.
*
* \param types The input and output types to the relation.
* \param num_inputs The number of input arguments.
* \param attrs The attributes
* \param reporter The reporter.
* \return true whether relation has been resolved.
*/
bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);

} // namespace relay
} // namespace tvm

Expand Down
58 changes: 58 additions & 0 deletions src/relay/op/vm/vm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/op/vm/vm.cc
* \brief Dialect operators for Relay VM.
*/

#include <topi/elemwise.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/runtime/data_type.h>

#include "../../transforms/infer_layout_util.h"
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
namespace relay {

RELAY_REGISTER_OP("vm.shape_of")
.describe(R"code(Get the shape of an input tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("tensor", "Tensor", "The input tensor")
.add_type_rel("ShapeOf", ShapeOfRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed([](Expr expr) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = DataType::Int(64);
static const Op& op = Op::Get("vm.shape_of");
return Call(op, {expr}, Attrs(attrs), {});
});

} // namespace relay
} // namespace tvm
4 changes: 3 additions & 1 deletion src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class ConstantFolder : public ExprMutator {
: executor_(executor),
module_(module),
shape_of_op_(Op::Get("shape_of")),
vm_shape_of_op_(Op::Get("vm.shape_of")),
invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
shape_func_op_(Op::Get("memory.shape_func")),
alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
Expand Down Expand Up @@ -123,7 +124,7 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
if (call->op == shape_of_op_) {
if (call->op == shape_of_op_ || call->op == vm_shape_of_op_) {
return EvaluateShapeOf(res, origin_args, call->attrs);
}

Expand Down Expand Up @@ -166,6 +167,7 @@ class ConstantFolder : public ExprMutator {

// Cache the following ops for equivalence checking in this pass.
const Op& shape_of_op_;
const Op& vm_shape_of_op_;
const Op& invoke_tvm_op_;
const Op& shape_func_op_;
const Op& alloc_tensor_op_;
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.pc_offset);
break;
}
case Opcode::ShapeOf: {
// Number of fields = 2
fields.assign({instr.shape_of.tensor, instr.dst});
break;
}
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
Expand Down Expand Up @@ -683,6 +688,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Goto(instr.fields[0]);
}
case Opcode::ShapeOf: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
Expand Down
Loading