Skip to content

Commit

Permalink
[RELAY][VM] Add shape_of instruction
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Jun 19, 2020
1 parent 52bf113 commit 0e3c9ae
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 74 deletions.
12 changes: 12 additions & 0 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ enum class Opcode {
LoadConsti = 14U,
Fatal = 15U,
AllocStorage = 16U,
ShapeOf = 17U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -245,6 +246,9 @@ struct Instruction {
/*! \brief The hint of the dtype. */
DLDataType dtype_hint;
} alloc_storage;
struct /* ShapeOf Operands */ {
RegName tensor;
} shape_of;
};

/*!
Expand Down Expand Up @@ -389,6 +393,14 @@ struct Instruction {
static Instruction AllocStorage(RegName size, RegName alignment, DLDataType dtype_hint,
RegName dst);

/*!
* \brief Get the shape of an input tensor.
* \param tensor The input tensor.
* \param dst The destination to store the shape of the given tensor.
* \return The shape of instruction.
*/
static Instruction ShapeOf(RegName tensor, RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/op/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ def shape_func(func, inputs, outputs, dependent=False):
"""
return _make.shape_func(func, inputs, outputs, dependent)

def shape_of(expr):
    """Create a call to the memory.shape_of operator, which evaluates at
    runtime to the shape of the given tensor.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression whose tensor shape is taken.

    Returns
    -------
    result : tvm.relay.Expr
        A call node that evaluates to the shape of the tensor of `expr`.
    """
    return _make.shape_of(expr)

def flatten_tuple_type(ty):
"""Return a sequence of the types contained in the tuple type in order.
Expand Down
4 changes: 1 addition & 3 deletions python/tvm/relay/transform/memory_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ManifestAllocPass(ExprMutator):
def __init__(self, target_host):
self.invoke_tvm = op.memory.invoke_tvm_op
self.shape_func = op.memory.shape_func
self.shape_of = op.memory.shape_of
self.scopes = [ScopeBuilder()]
self.target_host = target_host
self.default_context = cpu(0)
Expand All @@ -53,9 +54,6 @@ def __init__(self, target_host):
def current_scope(self):
return self.scopes[-1]

def shape_of(self, e):
return op.shape_of(e, self.compute_dtype)

def visit_tuple(self, tup):
scope = self.current_scope()
new_fields = []
Expand Down
14 changes: 14 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
case Opcode::Invoke:
case Opcode::AllocClosure:
case Opcode::AllocStorage:
case Opcode::ShapeOf:
case Opcode::Move:
case Opcode::InvokeClosure:
last_register_ = instr.dst;
Expand Down Expand Up @@ -584,6 +585,19 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
auto outputs = Downcast<Tuple>(args[2]);
EmitShapeFunc(shape_func, inputs->fields, outputs->fields);
})
.Match("memory.shape_of",
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
CHECK_EQ(args.size(), 1U);
// Get the attributes.
auto shape_of_attrs = attrs.as<ShapeOfAttrs>();
CHECK(shape_of_attrs != nullptr) << "must be the shape_of attrs";
CHECK(shape_of_attrs->dtype.bits() == 64)
<< "The dtype of shape of must be int64, but got"
<< DLDataType2String(shape_of_attrs->dtype);
this->VisitExpr(args[0]);
auto src_reg = last_register_;
Emit(Instruction::ShapeOf(src_reg, NewRegister()));
})
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
Expand Down
23 changes: 23 additions & 0 deletions src/relay/op/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
namespace tvm {
namespace relay {

// Forward declare the shape_of type relation function.
bool ShapeOfRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter);

TVM_REGISTER_NODE_TYPE(AllocStorageAttrs);
TVM_REGISTER_NODE_TYPE(AllocTensorAttrs);
TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
Expand Down Expand Up @@ -423,5 +427,24 @@ RELAY_REGISTER_OP("memory.shape_func")
return {topi::identity(inputs[0])};
});

// Registration of the memory.shape_of operator.
// It is marked opaque / non-computational: the VM compiler pattern-matches
// calls to this op and emits a dedicated ShapeOf instruction instead of
// lowering it to generated kernel code.
RELAY_REGISTER_OP("memory.shape_of")
    .describe(R"code(Get the shape of an input tensor.
)code" TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("tensor", "Tensor", "The input tensor")
    // ShapeOfRel is forward-declared at the top of this file; the relation
    // itself is defined alongside the regular shape_of op.
    .add_type_rel("ShapeOf", ShapeOfRel)
    .set_support_level(10)
    .set_attr<TOpPattern>("TOpPattern", kOpaque)
    .set_attr<TOpIsStateful>("TOpIsStateful", false)
    .set_attr<TNonComputational>("TNonComputational", true)
    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

// FFI hook used by python/tvm/relay/op/memory/memory.py to construct a call
// to memory.shape_of; the result dtype is fixed to int64.
TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_of").set_body_typed([](Expr expr) {
  static const Op& op = Op::Get("memory.shape_of");
  auto attrs = make_object<ShapeOfAttrs>();
  attrs->dtype = DataType::Int(64);
  return Call(op, {expr}, Attrs(attrs), {});
});

} // namespace relay
} // namespace tvm
4 changes: 3 additions & 1 deletion src/relay/transforms/fold_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class ConstantFolder : public ExprMutator {
: executor_(executor),
module_(module),
shape_of_op_(Op::Get("shape_of")),
memory_shape_of_op_(Op::Get("memory.shape_of")),
invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
shape_func_op_(Op::Get("memory.shape_func")),
alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
Expand Down Expand Up @@ -123,7 +124,7 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
if (call->op == shape_of_op_) {
if (call->op == shape_of_op_ || call->op == memory_shape_of_op_) {
return EvaluateShapeOf(res, origin_args, call->attrs);
}

Expand Down Expand Up @@ -166,6 +167,7 @@ class ConstantFolder : public ExprMutator {

// Cache the following ops for equivalence checking in this pass.
const Op& shape_of_op_;
const Op& memory_shape_of_op_;
const Op& invoke_tvm_op_;
const Op& shape_func_op_;
const Op& alloc_tensor_op_;
Expand Down
10 changes: 10 additions & 0 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
fields.push_back(instr.pc_offset);
break;
}
case Opcode::ShapeOf: {
// Number of fields = 2
fields.assign({instr.shape_of.tensor, instr.dst});
break;
}
default:
LOG(FATAL) << "Invalid opcode" << static_cast<int>(instr.op);
break;
Expand Down Expand Up @@ -683,6 +688,11 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
DCHECK_EQ(instr.fields.size(), 1U);
return Instruction::Goto(instr.fields[0]);
}
case Opcode::ShapeOf: {
// Number of fields = 2
DCHECK_EQ(instr.fields.size(), 2U);
return Instruction::ShapeOf(instr.fields[0], instr.fields[1]);
}
default:
LOG(FATAL) << "Invalid opcode" << instr.opcode;
return Instruction();
Expand Down
31 changes: 31 additions & 0 deletions src/runtime/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ Instruction::Instruction(const Instruction& instr) {
case Opcode::AllocStorage:
this->alloc_storage = instr.alloc_storage;
return;
case Opcode::ShapeOf:
this->shape_of.tensor = instr.shape_of.tensor;
return;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
Expand Down Expand Up @@ -239,6 +242,9 @@ Instruction& Instruction::operator=(const Instruction& instr) {
case Opcode::AllocStorage:
this->alloc_storage = instr.alloc_storage;
return *this;
case Opcode::ShapeOf:
this->shape_of.tensor = instr.shape_of.tensor;
return *this;
default:
std::ostringstream out;
out << "Invalid instruction " << static_cast<int>(instr.op);
Expand All @@ -258,6 +264,7 @@ Instruction::~Instruction() {
case Opcode::Goto:
case Opcode::LoadConsti:
case Opcode::AllocStorage:
case Opcode::ShapeOf:
case Opcode::Fatal:
return;
case Opcode::AllocTensor:
Expand Down Expand Up @@ -351,6 +358,14 @@ Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType
return instr;
}

/*!
 * \brief Build a ShapeOf instruction that reads the tensor in register
 * `tensor` and stores its shape in register `dst`.
 *
 * Note: the declaration in vm.h spells the second parameter `RegName dst`;
 * the definition previously used `Index dst`. The two must alias for the
 * original to link, so use `RegName` here to match the declared interface.
 */
Instruction Instruction::ShapeOf(RegName tensor, RegName dst) {
  Instruction instr;
  instr.op = Opcode::ShapeOf;
  instr.dst = dst;
  instr.shape_of.tensor = tensor;
  return instr;
}

Instruction Instruction::AllocADT(Index tag, Index num_fields,
const std::vector<RegName>& datatype_fields, Index dst) {
Instruction instr;
Expand Down Expand Up @@ -586,6 +601,10 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
<< DLDataType2String(instr.alloc_storage.dtype_hint);
break;
}
case Opcode::ShapeOf: {
os << "shape_of $" << instr.dst << " $" << instr.shape_of.tensor;
break;
}
default:
LOG(FATAL) << "should never hit this case" << static_cast<int>(instr.op);
break;
Expand Down Expand Up @@ -1040,6 +1059,18 @@ void VirtualMachine::RunLoop() {
pc_++;
goto main_loop;
}
case Opcode::ShapeOf: {
auto input = ReadRegister(instr.shape_of.tensor);
NDArray input_array = Downcast<NDArray>(input);
int ndim = input_array->ndim;
auto out_tensor = NDArray::Empty({ndim}, {kDLInt, 64, 1}, {kDLCPU, 0});
for (int i = 0; i < ndim; ++i) {
reinterpret_cast<int64_t*>(out_tensor->data)[i] = input_array->shape[i];
}
WriteRegister(instr.dst, out_tensor);
pc_++;
goto main_loop;
}
case Opcode::Ret: {
// If we have hit the point from which we started
// running, we should return to the caller breaking
Expand Down
Loading

0 comments on commit 0e3c9ae

Please sign in to comment.