Skip to content

Commit

Permalink
[Relay] Remove DeviceMap from LowerTE (#8788)
Browse files Browse the repository at this point in the history
* [Relay] Switch the graph, VM and AOT executors to use the merged
device_planner.cc from #9038, and finally remove DeviceMap from the
LowerTE Pass.

- We retire analysis/context_analysis.cc and
  transforms/device_annotation.cc (and their tests). That
  includes the CollectDeviceInfo, CollectDeviceAnnotationOps and
  ContextAnalysis entry points. These are all subsumed by the
  PlanDevices pass and the device aware visitors.
- The following passes now use the new 'Device Aware' visitors to
  recover the device for every Relay sub-expression:
     - backend/aot_executor_codegen.cc (AOTOnDemandAllocator)
     - backend/graph_plan_memory.cc (StorageAllocaBaseVisitor etc)
     - backend/te_compiler.cc (LowerTensorExprMutator)
     - transforms/memory_alloc.cc (DialectRewriter)
     - backend/vm/compiler.cc (VMFunctionCompiler)
- The following passes/utils must maintain the device information
  encoded by the device planner within "on_device" annotations and
  "param_device_types"/"result_device_type" function attributes:
     - backend/vm/lambda_lift.cc (LambdaLifter)
     - transforms/to_a_normal_form.cc (Fill)
     - ir/expr_functior.cc (Bind)
- Remove a lot ad-hoc 'homogeneous' vs 'hetrogeneous' conditionals
  in favor of just asking for the device. Also removed a lot of ad-doc
  encodings of the 'default' device.
- We no longer need to run device-planning twice (before and after
  lowering). Device planning is also decoupled from memory planning.
- The LowerTE Pass no longer needs an expression-to-device side table
  (which was the problem which kicked this series of PRs off in the first place).

* [checkpoint] Revert unnecessary changes

- Started down multi-target handling in interpreter but didn't finish
- Some one-off debug stuff

* [checkpoint] TODO's for default device logic
  • Loading branch information
mbs-octoml authored Oct 4, 2021
1 parent 2f02b1e commit 779a506
Show file tree
Hide file tree
Showing 46 changed files with 1,157 additions and 3,276 deletions.
29 changes: 0 additions & 29 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*!
* \brief Collect the device anntation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);

/*!
* \brief Finds cases that the given match expression does not catch, if any.
*
Expand Down Expand Up @@ -268,17 +250,6 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod);
*/
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);

/*!
* \brief Analyze the device context of each IR node in a given relay module.
*
* \param mod The module for analysis.
* \param default_device The default device used by unassigned IR nodes.
*
* \return The mapping between an IR node and its associated device.
*/
TVM_DLL std::unordered_map<Expr, Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
ContextAnalysis(const IRModule& mod, const Device& default_device);

} // namespace relay
} // namespace tvm

Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,12 @@ TVM_DLL Pass ToANormalForm();
/*!
* \brief ToANormalForm but on incomplete graph.
*
* \param maybe_mod optional module holding definitions for global vars in \p expr
* \param expr the graph.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);
TVM_DLL Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& expr);

/*!
* \brief Turn an expression into continuation passing style(CPS).
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ struct VMFunction {
/*! \brief The size of the frame for this function */
Index register_file_size;
/*! \brief The device type of each parameter for this function. */
std::vector<Index> params_device_type;
std::vector<DLDeviceType> params_device_type;

VMFunction(const std::string& name, std::vector<std::string> params,
const std::vector<Instruction>& instructions, Index register_file_size,
const std::vector<Index> params_device_type = {})
const std::vector<DLDeviceType> params_device_type = {})
: name(name),
params(params),
instructions(instructions),
Expand Down
49 changes: 0 additions & 49 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,6 @@
from .feature import Feature


def context_analysis(mod, default_device):
"""Analyze the device context information of each IR node in a Relay
program.
Parameters
----------
mod : tvm.IRModule
The input module.
default_device : tvm.runtime.Device
The default context allocated to an IR node.
"""
return _ffi_api.ContextAnalysis(mod, default_device)


def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited
Expand Down Expand Up @@ -268,40 +253,6 @@ def all_dtypes(expr):
return set(_ffi_api.all_dtypes(expr))


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.ir.expr, int]
A dictionary mapping tvm.relay.Expr to device type.
"""
return _ffi_api.CollectDeviceInfo(expr)


def collect_device_annotation_ops(expr):
"""Collect the device annotation ops for the given expression.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.Expr, int]
A dictionary mapping tvm.relay.Expr to device type where the keys are
annotation expressions.
"""
return _ffi_api.CollectDeviceAnnotationOps(expr)


def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
register_broadcast_schedule("fast_erf")
# a fake on_device schedule.
# this will not be used in actual computation
# as on_device will be removed during DeviceAnnotation pass
register_injective_schedule("on_device")


Expand Down
21 changes: 0 additions & 21 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,27 +544,6 @@ def MergeCompilerRegions():
return _ffi_api.MergeCompilerRegions()


def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_device`, mark which device an expression should be scheduled to.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.
Parameters
----------
fallback_device : int
The fallback device type. It is also used as the default device for
operators with no annotated device.
Returns
-------
ret: tvm.transform.Pass
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
return _ffi_api.RewriteDeviceAnnotation(fallback_device)


def ToANormalForm():
"""Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
Expand Down
48 changes: 24 additions & 24 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class Parser {
* Useful for matching optional tokens, effectively looksahead by one.
*/
bool WhenMatch(const TokenType& token_type) {
VLOG(1) << "Parser::WhenMatch: Peek() == " << Peek();
VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek();
if (Peek()->token_type == token_type) {
Consume(token_type);
return true;
Expand Down Expand Up @@ -594,7 +594,7 @@ class Parser {
template <typename R>
R WithSpan(std::function<R()> parser) {
auto start_span = Peek()->span;
VLOG(0) << "WithSpan: start_span = " << start_span;
VLOG(9) << "WithSpan: start_span = " << start_span;
R ast = parser();
if (ast.defined()) {
// The token at the head of the stream is now 1 past where we parsed. So we find its start
Expand All @@ -608,7 +608,7 @@ class Parser {
span_pos--;
}
auto end_token = tokens.at(span_pos);
VLOG(0) << "WithSpan: end_span = " << end_token->span;
VLOG(9) << "WithSpan: end_span = " << end_token->span;
ast->span = start_span.Merge(end_token->span);
}
return ast;
Expand Down Expand Up @@ -668,7 +668,7 @@ class Parser {
template <typename T>
Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse,
std::function<bool()> before_stop = nullptr) {
VLOG(0) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep)
VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep)
<< " stop=" << ToString(stop);
Match(start);

Expand All @@ -686,7 +686,7 @@ class Parser {
if (WhenMatch(stop)) {
return Array<T>();
} else {
VLOG(0) << "Parser::ParseSequence: parse first";
VLOG(9) << "Parser::ParseSequence: parse first";
auto data = parse();
Array<T> elements = {data};

Expand All @@ -695,7 +695,7 @@ class Parser {
// parse '( expr ',' * ')'
} else if (WhenMatch(sep)) {
while (true) {
VLOG(0) << "Parser::ParseSequence: parse element";
VLOG(9) << "Parser::ParseSequence: parse element";
if (WhenMatch(stop)) {
break;
} else {
Expand Down Expand Up @@ -893,12 +893,12 @@ class Parser {

/*! \brief Parse a single Relay expression. */
Expr ParseExpr() {
VLOG(0) << "Parser::ParseExpr";
VLOG(9) << "Parser::ParseExpr";
return WithSpan<Expr>([this] {
std::vector<Expr> exprs;

while (true) {
VLOG(0) << "Parser::ParseExpr: parsing a single expression";
VLOG(9) << "Parser::ParseExpr: parsing a single expression";
auto next = Peek();
switch (next->token_type) {
// For graph or let, match first rhs, then invoke ParseBindingExpr
Expand Down Expand Up @@ -1011,7 +1011,7 @@ class Parser {
// This ensures for n sequential bindings
// the call depth will be the same before
// and after parsing the n bindings.
VLOG(0) << "Parser::ParseBindingExpr";
VLOG(9) << "Parser::ParseBindingExpr";
std::vector<std::tuple<Var, Expr, Span>> bindings;
int scopes = 0;

Expand Down Expand Up @@ -1085,7 +1085,7 @@ class Parser {
* Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }.
*/
Function ParseFunctionDef() {
VLOG(0) << "Parser::ParseFunctionDef";
VLOG(9) << "Parser::ParseFunctionDef";
return WithSpan<Function>([&]() {
PushScope();
PushTypeScope();
Expand Down Expand Up @@ -1147,7 +1147,7 @@ class Parser {
/*! \brief Parse an if-expression. */
Expr ParseIf() {
return WithSpan<Expr>([&]() {
VLOG(0) << "Parser::ParseIf";
VLOG(9) << "Parser::ParseIf";
Consume(TokenType::kIf);

auto guard = WithSpan<Expr>([&] { return Parens<Expr>([&] { return ParseExpr(); }); });
Expand Down Expand Up @@ -1186,7 +1186,7 @@ class Parser {
* This function recursively parses a pattern.
*/
Pattern ParsePattern() {
VLOG(0) << "Parser::ParsePattern";
VLOG(9) << "Parser::ParsePattern";
auto next = Peek();
switch (next->token_type) {
case TokenType::kUnderscore: {
Expand Down Expand Up @@ -1249,7 +1249,7 @@ class Parser {
}

Expr ParseExprBinOp() {
VLOG(0) << "Parser::ParseExprBinOp";
VLOG(9) << "Parser::ParseExprBinOp";
return WithSpan<Expr>([this] {
// We must parse at least one expression, the default
// case is that there is no operator and we will fall
Expand Down Expand Up @@ -1333,7 +1333,7 @@ class Parser {
}

ObjectRef ParseAttributeValue() {
VLOG(0) << "Parser::ParseAttributeValue";
VLOG(9) << "Parser::ParseAttributeValue";
auto next = Peek();
switch (next->token_type) {
case TokenType::kFloat:
Expand Down Expand Up @@ -1375,7 +1375,7 @@ class Parser {
}

Map<String, ObjectRef> ParseAttrs() {
VLOG(0) << "Parser::ParseAttrs";
VLOG(9) << "Parser::ParseAttrs";
Map<String, ObjectRef> kwargs;
while (Peek()->token_type == TokenType::kIdentifier) {
auto key = GetHierarchicalName(ParseHierarchicalName().data);
Expand All @@ -1387,14 +1387,14 @@ class Parser {
kwargs.Set(key, value);
WhenMatch(TokenType::kComma);
}
VLOG(0) << "Parser::ParseAttrs: kwargs=" << kwargs;
VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs;
return kwargs;
}

Expr ParseCallArgs(Expr op) {
ICHECK(op.defined()) << "the operator must be defined";

VLOG(0) << "Parser::ParseCallArgs";
VLOG(9) << "Parser::ParseCallArgs";
Attrs attrs;
std::string op_key;
bool is_op = false;
Expand Down Expand Up @@ -1471,7 +1471,7 @@ class Parser {
}

Expr ParseCallExpr() {
VLOG(0) << "Parser::ParseCallExpr";
VLOG(9) << "Parser::ParseCallExpr";
return WithSpan<Expr>([this] {
Expr expr = ParseAtomicExpr();
// Parse as many call args as possible, building up expression
Expand Down Expand Up @@ -1500,7 +1500,7 @@ class Parser {
}

Expr GetOp(const std::string& op_name, const Span& span) {
VLOG(0) << "op_name=" << op_name << " span=" << span;
VLOG(9) << "op_name=" << op_name << " span=" << span;
try {
return Op::Get(op_name);
} catch (const Error& e) {
Expand All @@ -1513,7 +1513,7 @@ class Parser {
}

Expr ParseAtomicExpr() {
VLOG(0) << "Parser::ParseAtomicExpr";
VLOG(9) << "Parser::ParseAtomicExpr";
Expr expr = WithSpan<Expr>([this] {
auto next = Peek();
switch (next->token_type) {
Expand Down Expand Up @@ -1649,7 +1649,7 @@ class Parser {
auto token = Match(TokenType::kInteger);
auto index = token.ToNumber();
auto span = token->span.Merge(expr->span);
VLOG(0) << "Parser::ParseAtomicExpr: tuple get item";
VLOG(9) << "Parser::ParseAtomicExpr: tuple get item";
return relay::TupleGetItem(expr, index, span);
} else {
return expr;
Expand Down Expand Up @@ -1870,7 +1870,7 @@ class Parser {

Parser InitParser(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size();
VLOG(9) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size();
SourceName src_name = SourceName::Get(file_name);
Source source(src_name, file_content);

Expand Down Expand Up @@ -1909,7 +1909,7 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "ParseModule";
VLOG(9) << "ParseModule";
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
Expand All @@ -1923,7 +1923,7 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte
}

Expr ParseExpr(const std::string& file_name, const std::string& file_content) {
VLOG(0) << "ParseExpr";
VLOG(9) << "ParseExpr";
auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable());
parser.ParseSemVer(false);
parser.PushScope();
Expand Down
4 changes: 2 additions & 2 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ struct Tokenizer {
int line = this->line;
int col = this->col;
auto next = Peek();
VLOG(1) << "tvm::parser::TokenizeOnce: next=" << next;
VLOG(9) << "tvm::parser::TokenizeOnce: next=" << next;
if (next == '\n') {
auto token = NewToken(TokenType::kNewline);
Next();
Expand Down Expand Up @@ -550,7 +550,7 @@ struct Tokenizer {
}

void Tokenize() {
VLOG(0) << "tvm::parser::Tokenize";
VLOG(9) << "tvm::parser::Tokenize";
while (this->More()) {
auto token = TokenizeOnce();
ICHECK(token.defined());
Expand Down
Loading

0 comments on commit 779a506

Please sign in to comment.