Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][VM] Relay VM memory liveness/lifetime analysis #10026

Merged
merged 21 commits into from
Feb 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ TVM_DLL Pass RelayToTIRTargetHook();
*/
TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device);

/*!
* \brief A pass for manifesting variable lifetimes by inserting kill operations when variables
* become dead. This pass should be run after ManifestAlloc, and should not be run more than once.
*
* \return The pass.
*/
TVM_DLL Pass ManifestLifetimes();

/*!
* \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p VirtualDevice on
* which every Relay sub-expression should run and the result stored. Captures the result of that
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/runtime/vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ enum class Opcode {
ShapeOf = 17U,
ReshapeTensor = 18U,
DeviceCopy = 19U,
KillRegister = 20U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -386,6 +387,8 @@ struct Instruction {
static Instruction DeviceCopy(RegName src, Index src_device_index, Index dst_device_index,
RegName dst);

static Instruction KillRegister(RegName dst);

Instruction();
Instruction(const Instruction& instr);
Instruction& operator=(const Instruction& instr);
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,14 @@ def PlanDevices(config):
return _ffi_api.PlanDevices(config)


def ManifestLifetimes():
"""
Manifest the lifetimes of variables after allocations have been manifested, by inserting kill
operations once variables become dead.
"""
return _ffi_api.ManifestLifetimes()


def FoldExplicitPadding():
"""
FoldExplicitPadding finds explict padding before an op that can support
Expand Down
10 changes: 8 additions & 2 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
case Opcode::Ret:
case Opcode::Goto:
case Opcode::Fatal:
case Opcode::KillRegister:
altanh marked this conversation as resolved.
Show resolved Hide resolved
break;
}
instructions_.push_back(instr);
Expand Down Expand Up @@ -647,8 +648,10 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister()));
})
.Match("memory.kill",
[](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
LOG(FATAL) << "memory.kill is not yet supported";
[this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
ICHECK_EQ(args.size(), 1u);
this->VisitExpr(args[0]);
Emit(Instruction::KillRegister(this->last_register_));
});
matcher(GetRef<Call>(call_node));
return;
Expand Down Expand Up @@ -993,6 +996,9 @@ transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_de
// Compute away possibly introduced constant computation.
pass_seqs.push_back(transform::FoldConstant());

// Insert kills to free memory.
pass_seqs.push_back(transform::ManifestLifetimes());

// Lift constants to the top-level of the block to simplify VM code generation.
// TODO(@icemelon9, @jroesch): Remove this pass for now because some
// instructions need to access to constant
Expand Down
Loading