Skip to content

Commit

Permalink
[Relay][VM] Add ReshapeTensor instruction in the VM to replace the re…
Browse files Browse the repository at this point in the history
…shape op (#6089)

* [VM] Add reshape tensor instruction

* update

* lint

* fix

* fix
  • Loading branch information
icemelon authored Jul 21, 2020
1 parent d8c9bb1 commit 526b5a5
Show file tree
Hide file tree
Showing 15 changed files with 270 additions and 32 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relay/attrs/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ struct ShapeFuncAttrs : public tvm::AttrsNode<ShapeFuncAttrs> {
}
};

/*!
* \brief Attributes for VM reshape_tensor operator.
*/
struct ReshapeTensorAttrs : public tvm::AttrsNode<ReshapeTensorAttrs> {
Array<PrimExpr> newshape;

TVM_DECLARE_ATTRS(ReshapeTensorAttrs, "relay.attrs.ReshapeTensorAttrs") {
TVM_ATTR_FIELD(newshape).describe("The new shape of output tensor");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_VM_H_
14 changes: 14 additions & 0 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ enum class Opcode {
Fatal = 15U,
AllocStorage = 16U,
ShapeOf = 17U,
ReshapeTensor = 18U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -249,6 +250,10 @@ struct Instruction {
struct /* ShapeOf Operands */ {
RegName tensor;
} shape_of;
struct /* ReshapeTensor Operands */ {
RegName tensor;
RegName newshape;
} reshape_tensor;
};

/*!
Expand Down Expand Up @@ -401,6 +406,15 @@ struct Instruction {
*/
static Instruction ShapeOf(RegName tensor, RegName dst);

/*!
* \brief Reshape the tensor given the new shape.
* \param tensor The input tensor.
* \param newshape The shape tensor.
* \param dst The destination to store the output tensor with new shape.
* \return The reshape tensor instruction.
*/
static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,9 @@ def lower_call(call, inputs, target):
new_fields.append(field)
ret_type = _ty.TupleType(new_fields)

is_dyn = _ty.type_has_any(call.checked_type)
is_dyn = _ty.is_dynamic(call.checked_type)
for arg in call.args:
is_dyn = is_dyn or _ty.type_has_any(arg.checked_type)
is_dyn = is_dyn or _ty.is_dynamic(arg.checked_type)

# check if in the AutoTVM tracing mode, and disable if op is not in wanted list
env = autotvm.task.TaskExtractEnv.current
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import tvm.runtime.vm as vm_rt
from tvm import autotvm
from tvm.relay import expr as _expr
from tvm.relay.ty import type_has_any
from tvm.relay.ty import is_dynamic
from tvm.relay.backend.interpreter import Executor
from . import _vm

Expand Down Expand Up @@ -257,7 +257,7 @@ def _make_executor(self, expr=None):
def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs)
ret_type = self.mod["main"].checked_type.ret_type
if type_has_any(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
self.target):
raise ValueError(
"Virtual Machine only supports dynamic graphs on CPU, got output type",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _make_executor(self, expr=None):
if expr:
self.mod["main"] = expr
ret_type = self.mod["main"].checked_type.ret_type
if _ty.type_has_any(ret_type):
if _ty.is_dynamic(ret_type):
raise ValueError("Graph Runtime only supports static graphs, got output type",
ret_type)
num_outputs = len(ret_type.fields) if isinstance(ret_type, _ty.TupleType) else 1
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/relay/op/vm/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,20 @@ def shape_func(func, inputs, outputs, is_inputs):
The shape function expression.
"""
return _ffi_api.shape_func(func, inputs, outputs, is_inputs)


def reshape_tensor(data, shape, newshape):
"""Invoke the VM ReshapeTensor instruction.
Parameters
----------
data : tvm.relay.Expr
The input data.
shape : tvm.relay.Expr
The newshape tensor.
newshape : List[tvm.ir.PrimExpr]
The new shape.
"""
return _ffi_api.reshape_tensor(data, shape, newshape)
80 changes: 62 additions & 18 deletions python/tvm/relay/transform/memory_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
A pass for manifesting explicit memory allocations.
"""
import numpy as np
from ..expr_functor import ExprMutator
from ..expr_functor import ExprVisitor, ExprMutator
from ..scope_builder import ScopeBuilder
from . import transform
from .. import op
Expand All @@ -38,13 +38,39 @@ def is_primitive(call):
return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \
hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1


class CheckReshapeOnly(ExprVisitor):
"""A pass to check if the fused op contains only reshape ops."""
def __init__(self):
super().__init__()
self._reshape_ops = [op.get("reshape"), op.get("contrib_reverse_reshape"),
op.get("dyn.reshape")]
self.reshape_only = True

def visit_call(self, call):
if not self.reshape_only:
return
if call.op not in self._reshape_ops:
self.reshape_only = False
for arg in call.args:
self.visit(arg)


def is_reshape_only(func):
"""Check if the primitive function contains only reshape ops."""
check = CheckReshapeOnly()
check.visit(func)
return check.reshape_only


class ManifestAllocPass(ExprMutator):
"""A pass for explicitly manifesting all memory allocations in Relay."""

def __init__(self, target_host):
self.invoke_tvm = op.vm.invoke_tvm_op
self.shape_func = op.vm.shape_func
self.shape_of = op.vm.shape_of
self.reshape_tensor = op.vm.reshape_tensor
self.scopes = [ScopeBuilder()]
self.target_host = target_host
self.default_context = cpu(0)
Expand Down Expand Up @@ -121,8 +147,8 @@ def visit_let(self, let):

return scope.get()

def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
"""Generate the code for invoking a TVM op with a dynamic shape."""
def emit_shape_func(self, scope, func, new_args):
"""Insert the shape function given a primitive function."""
shape_func_ins = []
engine = compile_engine.get()
cfunc = engine.lower_shape_func(func, self.target_host)
Expand Down Expand Up @@ -165,9 +191,14 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
expr.Tuple(out_shapes), is_inputs)

scope.let("shape_func", shape_call)
return out_shapes

def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
"""Generate the code for invoking a TVM op with a dynamic shape."""
out_shapes = self.emit_shape_func(scope, func, new_args)

storages = []
for out_shape, out_type in zip(out_shapes, out_types):
for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)):
size = self.compute_storage_in_relay(
out_shape, out_type.dtype)
alignment = self.compute_alignment(out_type.dtype)
Expand All @@ -191,8 +222,18 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type):
scope.let("", invoke)
return to_tuple_type(ret_type, tuple_outs.fields)

def emit_reshape_tensor(self, scope, func, new_args, ret_type):
if self.is_dynamic(ret_type):
out_shapes = self.emit_shape_func(scope, func, new_args)
shape_expr = out_shapes[0]
else:
# constant output shape
shape = [int(dim) for dim in ret_type.shape]
shape_expr = expr.const(shape, dtype=self.compute_dtype)
return self.reshape_tensor(new_args[0], shape_expr, ret_type.shape)

def is_dynamic(self, ret_type):
is_dynamic = ty.type_has_any(ret_type)
is_dynamic = ty.is_dynamic(ret_type)
# TODO(@jroesch): restore this code, more complex then it seems
# for arg in call.args:
# is_dynamic = is_dynamic or arg.checked_type.is_dynamic()
Expand All @@ -208,22 +249,25 @@ def visit_call(self, call):
ret_type = call.checked_type
out_types = flatten_tuple_type(ret_type)

if is_reshape_only(call.op):
# Handle fused op that only contains reshape op
return self.emit_reshape_tensor(scope, call.op, new_args, ret_type)

if self.is_dynamic(ret_type):
# Handle dynamic case.
return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type)
else:
# Handle static case.
outs = []
for i, out_ty in enumerate(out_types):
out = self.make_static_allocation(scope, out_ty, i)
outs.append(out)

output = expr.Tuple(outs)
invoke = self.invoke_tvm(call.op, ins, output)
scope.let("", invoke)
return to_tuple_type(ret_type, output.fields)
else:
return super().visit_call(call)

# Handle static case.
outs = []
for i, out_ty in enumerate(out_types):
out = self.make_static_allocation(scope, out_ty, i)
outs.append(out)

output = expr.Tuple(outs)
invoke = self.invoke_tvm(call.op, ins, output)
scope.let("", invoke)
return to_tuple_type(ret_type, output.fields)
return super().visit_call(call)


@transform.function_pass(opt_level=0)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

Any = _ffi_api.Any

def type_has_any(tensor_type):
"""Check whether type has any as a shape.
def is_dynamic(tensor_type):
"""Check whether type has any or symbolic variables as a shape.
tensor_type : Type
The type to be inspected
Expand Down
2 changes: 1 addition & 1 deletion src/relay/analysis/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ struct IsDynamicVisitor : public TypeVisitor {
bool is_dyn{false};
void VisitType_(const TensorTypeNode* tt) {
for (auto dim : tt->shape) {
if (dim.as<AnyNode>()) {
if (dim.as<tir::IntImmNode>() == nullptr) {
is_dyn = true;
break;
}
Expand Down
10 changes: 10 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::AllocClosure:
case Opcode::AllocStorage:
case Opcode::ShapeOf:
case Opcode::ReshapeTensor:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register_ = instr.dst;
Expand Down Expand Up @@ -601,6 +602,15 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
this->VisitExpr(args[0]);
Emit(Instruction::ShapeOf(last_register_, NewRegister()));
})
.Match("vm.reshape_tensor",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 2u);
this->VisitExpr(args[0]);
auto tensor_reg = last_register_;
this->VisitExpr(args[1]);
auto shape_reg = last_register_;
Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister()));
})
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
Expand Down
2 changes: 2 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,8 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
infer_dim = indexdiv(infer_dim, oshape[i]);
}
}
arith::Analyzer ana;
infer_dim = ana.Simplify(infer_dim);
oshape.Set(infer_idx, infer_dim);
}

Expand Down
37 changes: 37 additions & 0 deletions src/relay/op/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
namespace tvm {
namespace relay {

// vm.shape_func
TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);

RELAY_REGISTER_OP("vm.shape_of")
Expand Down Expand Up @@ -133,6 +134,7 @@ RELAY_REGISTER_OP("vm.shape_func")
return {topi::identity(inputs[0])};
});

// vm.invoke_tvm_op
bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4u);
Expand Down Expand Up @@ -181,5 +183,40 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op")
return {topi::identity(inputs[0])};
});

// vm.reshape
TVM_REGISTER_NODE_TYPE(ReshapeTensorAttrs);

bool ReshapeTensorRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3u);
auto reshape_attrs = attrs.as<ReshapeTensorAttrs>();
CHECK(reshape_attrs);
auto tt = types[0].as<TensorTypeNode>();
CHECK(tt) << "input must be tensor type";
reporter->Assign(types[2], TensorType(reshape_attrs->newshape, tt->dtype));
return true;
}

RELAY_REGISTER_OP("vm.reshape_tensor")
.describe(R"code(Use VM reshape_tensor instruction to reshape the tensor.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor")
.add_argument("shape", "Tensor", "The output shape tensor")
.add_type_rel("ReshapeTensor", ReshapeTensorRel)
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor")
.set_body_typed([](Expr data, Expr shape, Array<PrimExpr> newshape) {
static const Op& op = Op::Get("vm.reshape_tensor");
auto attrs = make_object<ReshapeTensorAttrs>();
attrs->newshape = std::move(newshape);
return Call(op, {data, shape}, Attrs(attrs), {});
});

} // namespace relay
} // namespace tvm
10 changes: 10 additions & 0 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.assign({instr.shape_of.tensor, instr.dst});
break;
}
case Opcode::ReshapeTensor: {
// Number of fields = 3
fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst});
break;
}
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
Expand Down Expand Up @@ -693,6 +698,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
}
case Opcode::ReshapeTensor: {
// Number of fields = 3
DCHECK_EQ(instr.fields.size(), 3U);
return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
Expand Down
Loading

0 comments on commit 526b5a5

Please sign in to comment.