From 11c01d39a8bf5bdd7086ea772c07bea7f8b44dda Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Fri, 28 Jan 2022 20:12:46 -0800 Subject: [PATCH] rewrite CFG analysis in stages, support ADTs --- src/relay/backend/vm/plan_memory.cc | 636 +++++++++++++--------------- 1 file changed, 302 insertions(+), 334 deletions(-) diff --git a/src/relay/backend/vm/plan_memory.cc b/src/relay/backend/vm/plan_memory.cc index e25d626424b4..c6e04687751d 100644 --- a/src/relay/backend/vm/plan_memory.cc +++ b/src/relay/backend/vm/plan_memory.cc @@ -25,6 +25,7 @@ #include #include "../../../support/arena.h" +#include "../../op/memory/device_copy.h" #include "../../transforms/device_aware_visitors.h" #include "../../transforms/let_list.h" @@ -32,24 +33,70 @@ namespace tvm { namespace relay { namespace transform { -using support::LinkedList; -using support::LinkNode; - using VarSet = std::unordered_set; class ControlFlowGraph { public: + struct Node; + struct BasicBlock; + + using NodePtr = std::shared_ptr; + using BasicBlockPtr = std::shared_ptr; + + struct BasicBlock { + // The nodes of the basic block. + std::vector nodes; + // The predecessor basic blocks. + std::vector pred; + // The successor basic blocks. + std::vector succ; + + static BasicBlockPtr Make() { return std::make_shared(); } + }; + struct Node { - LinkedList pred; - LinkedList succ; + // The basic block this node belongs to. + BasicBlockPtr parent; + // The index into the parent basic block where this node is. + size_t index; + // The expr corresponding to this node. Expr expr; + + // Returns whether or not this node is the last one in the parent basic block. + bool IsLast() const { return index == parent->nodes.size() - 1; } + + // Returns the successor nodes of this node. + std::vector GetSucc() const { + std::vector succ; + if (IsLast()) { + for (const BasicBlockPtr& succ_block : parent->succ) { + succ.push_back(succ_block->nodes[0]); + } + } else { + succ.push_back(parent->nodes[index + 1]); + } + return succ; + } + + // Creates a node with the given expr and pushes it to the end of the parent basic block. + static NodePtr Make(BasicBlockPtr parent, Expr expr) { + NodePtr n = std::make_shared(); + n->parent = parent; + n->expr = expr; + n->index = parent->nodes.size(); + parent->nodes.push_back(n); + return n; + } }; - std::unordered_map let_map; - std::vector reverse_post_order; - // Node* entry; + BasicBlockPtr entry; - static ControlFlowGraph Create(support::Arena* arena, const Expr& body); + // Let expressions are never shared in ANF (unlike vars), so this is an injection. + std::unordered_map let_map; + + std::vector reverse_post_order; + + static ControlFlowGraph Create(const Expr& body); private: class Creator; @@ -57,341 +104,280 @@ class ControlFlowGraph { using NodeList = std::vector; -class ControlFlowGraph::Creator : private ExprFunctor { +class ControlFlowGraph::Creator : private ExprFunctor { public: - Creator(support::Arena* arena) : arena_(arena) {} + Creator() {} ControlFlowGraph Create(const Expr& body) { - VisitExpr(body, {}); + cfg_.entry = BasicBlock::Make(); + VisitExpr(body, cfg_.entry); return std::move(cfg_); } private: - support::Arena* arena_; ControlFlowGraph cfg_; - std::unordered_set visited_; bool in_func_ = false; - void Succ(Node* from, Node* to) { - auto succ_link = arena_->make>(); - succ_link->value = to; - from->succ.Push(succ_link); - - auto pred_link = arena_->make>(); - pred_link->value = from; - to->pred.Push(pred_link); + void Succ(BasicBlockPtr from, BasicBlockPtr to) { + from->succ.push_back(to); + to->pred.push_back(from); } -#define DEFAULT_CFG(OP) \ - Node* VisitExpr_(const OP* op, const NodeList& preds) final { \ - Node* n = arena_->make(); \ - n->expr = GetRef(op); \ - for (Node * pred : preds) { \ - Succ(pred, n); \ - } \ - cfg_.reverse_post_order.push_back(n); \ - return n; \ +#define DEFAULT_CFG(OP) \ + void VisitExpr_(const OP* op, BasicBlockPtr parent) final { \ + NodePtr n = Node::Make(parent, GetRef(op)); \ + cfg_.reverse_post_order.push_back(n); \ } - Node* VisitExpr_(const FunctionNode* f, const NodeList& preds) final { + void VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) final { ICHECK(!in_func_) << "nested functions not supported by CFG analysis"; in_func_ = true; if (f->HasNonzeroAttr(attr::kClosure)) { ICHECK(f->body.as()); - return VisitExpr(Downcast(f->body)->body, {}); + return VisitExpr(Downcast(f->body)->body, parent); } - // cfg_.entry = arena_->make(); - // Succ(cfg_.entry, VisitExpr(f->body)); - - return VisitExpr(f->body, {}); + return VisitExpr(f->body, parent); } - Node* VisitExpr_(const LetNode* let_node, const NodeList& let_preds) final { + void VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) final { Expr expr = GetRef(let_node); - NodeList preds = let_preds; while (const LetNode* inner_let_node = expr.as()) { - Node* curr_node = arena_->make(); - curr_node->expr = expr; + NodePtr curr_node = Node::Make(parent, expr); + ICHECK(!cfg_.let_map.count(expr)); cfg_.let_map[expr] = curr_node; - - // 2 predecessors if last let bound value was an If, else 1 - for (Node* pred : preds) { - Succ(pred, curr_node); - } - cfg_.reverse_post_order.push_back(curr_node); + if (const IfNode* ite = AsIgnoringOnDevice(inner_let_node->value)) { - Node* t_node = VisitExpr(ite->true_branch, {curr_node}); - Node* f_node = VisitExpr(ite->false_branch, {curr_node}); - preds = {t_node, f_node}; - } else { - preds = {curr_node}; + // Create the basic blocks for each branch and mark them as successors to the current block. + BasicBlockPtr t_block = BasicBlock::Make(); + BasicBlockPtr f_block = BasicBlock::Make(); + Succ(parent, t_block); + Succ(parent, f_block); + + VisitExpr(ite->true_branch, t_block); + VisitExpr(ite->false_branch, f_block); + + // All subsequent bindings (and/or the body expr) will be in a new basic block. + BasicBlockPtr next = BasicBlock::Make(); + Succ(t_block, next); + Succ(f_block, next); + parent = next; + } else if (const MatchNode* match = AsIgnoringOnDevice(inner_let_node->value)) { + // Same as above but one for each pattern. + std::vector clause_blocks; + BasicBlockPtr next = BasicBlock::Make(); + for (const Clause& clause : match->clauses) { + BasicBlockPtr clause_block = BasicBlock::Make(); + Succ(parent, clause_block); + Succ(clause_block, next); + VisitExpr(clause->rhs, clause_block); + } + parent = next; } + expr = inner_let_node->body; } - Node* body_node = VisitExpr(expr, preds); + VisitExpr(expr, parent); + } - return body_node; + void VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) { + // TODO(@altanh): is there a way of making this work? + LOG(FATAL) << "If expressions should be bound to variables."; + } + + void VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) { + // TODO(@altanh): same as If + LOG(FATAL) << "Match expressions should be bound to variables."; } DEFAULT_CFG(VarNode); DEFAULT_CFG(GlobalVarNode); DEFAULT_CFG(ConstantNode); - DEFAULT_CFG(IfNode); DEFAULT_CFG(CallNode); DEFAULT_CFG(OpNode); DEFAULT_CFG(TupleNode); DEFAULT_CFG(TupleGetItemNode); }; -ControlFlowGraph ControlFlowGraph::Create(support::Arena* arena, const Expr& body) { - return Creator(arena).Create(body); -} - -// -class LivenessAnalyzer : private ExprFunctor { - private: - using CFG = ControlFlowGraph; +ControlFlowGraph ControlFlowGraph::Create(const Expr& body) { return Creator().Create(body); } +// NOTE: for If exprs, only the condition is included (not the branches). Similarly, for Match +// exprs only the value being deconstructed is included. +class VarUseCollector : public ExprFunctor { public: - LivenessAnalyzer() {} + VarSet VisitExpr_(const VarNode* var_node) { return {GetRef(var_node)}; } - // see https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm - void ComputeLiveness(const Expr& expr) { - cfg_ = CFG::Create(&arena_, expr); - VisitExpr(expr, nullptr); - - bool did_work = true; + VarSet VisitExpr_(const CallNode* call_node) { + VarSet use = VisitExpr(call_node->op); + for (const Expr& arg : call_node->args) { + VarSet arg_use = VisitExpr(arg); + use.insert(arg_use.begin(), arg_use.end()); + } + return use; + } - auto visitor = [&](CFG::Node* n) { - VarSet old_in_n = this->live_in_[n]; - VarSet old_out_n = this->live_out_[n]; + VarSet VisitExpr_(const TupleNode* tuple_node) { + VarSet use; + for (const Expr& field : tuple_node->fields) { + VarSet field_use = VisitExpr(field); + use.insert(field_use.begin(), field_use.end()); + } + return use; + } - this->live_in_[n] = this->use_[n]; - for (const Var& v : this->live_out_[n]) { - if (!v.same_as(this->def_[n])) { - this->live_in_[n].insert(v); - } - } + VarSet VisitExpr_(const TupleGetItemNode* get_node) { return VisitExpr(get_node->tuple); } - this->live_out_[n] = VarSet(); - auto s = n->succ.head; - while (s) { - CFG::Node* s_node = s->value; - this->live_out_[n].insert(this->live_in_[s_node].begin(), this->live_in_[s_node].end()); - s = s->next; - } + VarSet VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); } - if (!SetEqual(old_in_n, this->live_in_[n])) { - did_work = true; - } else if (!SetEqual(old_out_n, this->live_out_[n])) { - did_work = true; - } - }; + VarSet VisitExpr_(const MatchNode* match_node) { return VisitExpr(match_node->data); } - while (did_work) { - did_work = false; - for (auto it = cfg_.reverse_post_order.rbegin(); it != cfg_.reverse_post_order.rend(); ++it) { - visitor(*it); - } - } - } + VarSet VisitExpr_(const ConstructorNode* cons_node) { return {}; } - private: - CFG::Node* VisitExpr_(const LetNode* let_node, CFG::Node* cfg_node) override { - Expr expr = GetRef(let_node); - ICHECK(!cfg_node || cfg_node == cfg_.let_map[expr]) << cfg_node->expr << std::endl << - std::endl << cfg_.let_map[expr]->expr; + VarSet VisitExpr_(const GlobalVarNode* gvar_node) { return {}; } - while (const auto* inner_let_node = expr.as()) { - const Var& var = inner_let_node->var; - const Expr& value = inner_let_node->value; - - ICHECK(!cfg_node || cfg_node == cfg_.let_map[expr]) << cfg_node->expr << std::endl << - std::endl << cfg_.let_map[expr]->expr; cfg_node = cfg_.let_map[expr]; - - // ICHECK(!alias_.count(var)); - // if (value.as()) { - // Var rhs = Downcast(value); - // ICHECK(alias_.count(rhs)); - // alias_[var] = alias_[rhs]; - // alias_[rhs]->insert(var); - // } else { - // alias_[var] = std::make_shared(); - // alias_[var]->insert(var); - // } - - cfg_node = cfg_.let_map[expr]; - def_[cfg_node] = var; - - if (const IfNode* ite = AsIgnoringOnDevice(value)) { - VisitExpr_(ite, cfg_node); - - // there should be exactly two successors: the true branch then the false branch - ICHECK(cfg_node->succ.head); - ICHECK(cfg_node->succ.head->next); - ICHECK(!cfg_node->succ.head->next->next); - CFG::Node* t_entry = cfg_node->succ.head->value; - CFG::Node* f_entry = cfg_node->succ.head->next->value; - - CFG::Node* t_exit = VisitExpr(ite->true_branch, t_entry); - CFG::Node* f_exit = VisitExpr(ite->false_branch, f_entry); - - // each branch should have exactly one succcessor, and it should be the same - ICHECK(t_exit->succ.head && !t_exit->succ.head->next); - ICHECK(f_exit->succ.head && !f_exit->succ.head->next); - ICHECK(t_exit->succ.head->value == f_exit->succ.head->value); - cfg_node = t_exit->succ.head->value; - if (inner_let_node->body.as()) { - ICHECK(cfg_node->expr.same_as(inner_let_node->body)); - ICHECK(cfg_.let_map[inner_let_node->body] == cfg_node); - } - } else { - VisitExpr(value, cfg_node); - - // normal bindings should have just one successor - ICHECK(cfg_node->succ.head && !cfg_node->succ.head->next); - cfg_node = cfg_node->succ.head->value; - if (inner_let_node->body.as()) { - ICHECK(cfg_node->expr.same_as(inner_let_node->body)); - ICHECK(cfg_.let_map[inner_let_node->body] == cfg_node); - } - } + VarSet VisitExpr_(const ConstantNode* const_node) { return {}; } - expr = inner_let_node->body; - } + VarSet VisitExpr_(const OpNode* op_node) { return {}; } +}; - return VisitExpr(expr, cfg_node); - } +struct UseDefAnalysis { + using CFG = ControlFlowGraph; - CFG::Node* VisitExpr_(const IfNode* if_node, CFG::Node* cfg_node) override { - VisitExpr(if_node->cond, cfg_node); - return cfg_node; - } + std::unordered_map use; + std::unordered_map def; - CFG::Node* VisitExpr_(const TupleNode* tuple_node, CFG::Node* cfg_node) override { - for (const Expr& field : tuple_node->fields) { - VisitExpr(field, cfg_node); - } - return cfg_node; - } + VarUseCollector use_collector; - CFG::Node* VisitExpr_(const TupleGetItemNode* get_node, CFG::Node* cfg_node) override { - VisitExpr(get_node->tuple, cfg_node); - return cfg_node; - } + static UseDefAnalysis Analyze(const CFG& cfg) { + UseDefAnalysis a; - CFG::Node* VisitExpr_(const GlobalVarNode* global_var_node, CFG::Node* cfg_node) override {return cfg_node;} + std::vector worklist = {cfg.entry}; + while (!worklist.empty()) { + CFG::BasicBlockPtr block = worklist.back(); + worklist.pop_back(); - CFG::Node* VisitExpr_(const VarNode* var_node, CFG::Node* cfg_node) override { - Var var = GetRef(var_node); - // ICHECK(alias_.count(var)); - use_[cfg_node].insert(var); - return cfg_node; - } + for (const CFG::NodePtr& node : block->nodes) { + if (const LetNode* let_node = AsIgnoringOnDevice(node->expr)) { + a.use[node] = a.use_collector.VisitExpr(let_node->value); + a.def[node] = let_node->var; + } else { + a.use[node] = a.use_collector.VisitExpr(node->expr); + a.def[node] = Var(); + } + } - CFG::Node* VisitExpr_(const ConstantNode* const_node, CFG::Node* cfg_node) override {return cfg_node;} + for (const CFG::BasicBlockPtr& s : block->succ) { + worklist.push_back(s); + } + } - CFG::Node* VisitExpr_(const OpNode* op_node, CFG::Node* cfg_node) override {return cfg_node;} + return a; + } +}; - CFG::Node* VisitExpr_(const CallNode* call_node, CFG::Node* cfg_node) override { - VisitExpr(call_node->op, cfg_node); - for (const Expr& arg : call_node->args) { - VisitExpr(arg, cfg_node); +bool SetEqual(const VarSet& a, const VarSet& b) { + if (a.size() != b.size()) { + return false; + } + for (auto& xa : a) { + if (!b.count(xa)) { + return false; } - return cfg_node; } + return true; +} - CFG::Node* VisitExpr_(const FunctionNode* func_node, CFG::Node* cfg_node) override { - ICHECK(!used_); - used_ = true; +struct LivenessAnalysis { + using CFG = ControlFlowGraph; - if (func_node->HasNonzeroAttr(attr::kPrimitive)) { - return nullptr; - } + std::unordered_map live_in; + std::unordered_map live_out; - // TODO(@altanh): figure out the closure nesting thing - // ICHECK(!function_node->HasNonzeroAttr(attr::kClosure)) << "closures not supported yet"; - ICHECK(func_depth_ == 0) << "nested functions should have been transformed away"; + static LivenessAnalysis Analyze(const ControlFlowGraph& cfg, const UseDefAnalysis& use_def) { + LivenessAnalysis a; + bool did_work = true; - Expr body = func_node->body; - if (func_node->HasNonzeroAttr(attr::kClosure)) { - ICHECK(body.as()); - body = Downcast(func_node->body)->body; - } + auto visitor = [&](const CFG::NodePtr n) { + VarSet old_in_n = a.live_in[n]; + VarSet old_out_n = a.live_out[n]; - ++func_depth_; - VisitExpr(body, nullptr); - --func_depth_; + a.live_in[n] = use_def.use.at(n); + for (const Var& v : a.live_out[n]) { + if (!v.same_as(use_def.def.at(n))) { + a.live_in[n].insert(v); + } + } - return nullptr; - } + a.live_out[n] = VarSet(); + for (const CFG::NodePtr& s : n->GetSucc()) { + a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end()); + } - bool SetEqual(const VarSet& a, const VarSet& b) { - if (a.size() != b.size()) { - return false; - } - for (auto& xa : a) { - if (!b.count(xa)) { - return false; + if (!SetEqual(old_in_n, a.live_in[n])) { + did_work = true; + } else if (!SetEqual(old_out_n, a.live_out[n])) { + did_work = true; + } + }; + + while (did_work) { + did_work = false; + for (auto it = cfg.reverse_post_order.rbegin(); it != cfg.reverse_post_order.rend(); ++it) { + visitor(*it); } } - return true; + + return a; } +}; - private: - friend class MemoryPlanner; +class KillInserter : public ExprMutator { + public: + KillInserter(const ControlFlowGraph* cfg, const LivenessAnalysis* lva) : cfg_(cfg), lva_(lva) {} - bool used_ = false; + Expr VisitExpr_(const LetNode* let_node) override { + Expr expr = GetRef(let_node); + LetList ll; - support::Arena arena_; - CFG cfg_; + while (const LetNode* inner_let_node = expr.as()) { + ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value)); - CFG::Node* cfg_node_; + ICHECK(!inner_let_node->value.as()) << "aliasing should have been eliminated."; + ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG"; - // v in use_[n] means v is read in n. - // NOTE: the use set of an If expression does not include the branches, just the condition. - // This lets us pretend the IR is composed of basic blocks (seq of bindings) + unstructured - // control flow, which is what most data-flow algorithms assume as input. - // std::unordered_map use_; - std::unordered_map use_; + const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr); - // def_[n] = v means n is a node "let v = ...;" - // TODO(@altanh): pretty sure this can be removed since we don't allow binding the same var twice - // (unless I'm misremembering). - // std::unordered_map def_; - std::unordered_map def_; + const VarSet& li = lva_->live_in.at(n); + const VarSet& lo = lva_->live_out.at(n); - // y in alias_[x] means y is an alias of x, created by let binding y to an alias of x. - // NOTE: x in alias_[x] for all x. - // std::unordered_map, ObjectPtrHash, ObjectPtrEqual> alias_; + // Killed vars = live in - live out. + VarSet kills; + for (const Var& v : li) { + if (!lo.count(v)) { + kills.insert(v); + } + } - // Maps node -> {successor expr/basic block}. - // NOTE: a pair of bindings without control flow, e.g. e = "b0; b1; body", results in a linear - // successor e -> b1. If expressions on the other hand, e.g. - // e = "let x = if (cond) { true_b } else { false_b }; body" have branching - // e -> {true_b, false_b} -> body. - // std::unordered_map, ObjectPtrHash, ObjectPtrEqual> succ_; + for (const Var& v : kills) { + ll.Push(Call(Op::Get("memory.kill"), {v})); + } - // Maps node -> {v: Var | v is live before node} - std::unordered_map live_in_; - // Maps node -> {v: Var | v is live after node} - std::unordered_map live_out_; + expr = inner_let_node->body; + } - size_t func_depth_ = 0; + return ll.Get(VisitExpr(expr)); + } - const VarSet empty_set_; + private: + const ControlFlowGraph* cfg_; + const LivenessAnalysis* lva_; }; -// TODO(@altanh): figure out if letrec is a problem -// FIXME(@altanh): device_copy can be aliasing when src == dst - class AliasEliminator : public MixedModeMutator { public: Expr VisitExpr_(const LetNode* let_node) override { @@ -399,20 +385,47 @@ class AliasEliminator : public MixedModeMutator { LetList ll; std::vector bound_vars; + auto set_alias = [&](const Var& alias, const VarNode* alias_of_n) { + Var alias_of = GetRef(alias_of_n); + if (alias_.count(alias_of)) { + alias_[alias] = alias_[alias_of]; + } else { + alias_[alias] = alias_of; + } + bound_vars.push_back(alias); + }; + while (const LetNode* inner_let_node = expr.as()) { const Var& var = inner_let_node->var; const Expr& val = inner_let_node->value; + bool aliased = false; ICHECK(!alias_.count(var)); - if (val.as()) { - ICHECK(alias_.count(Downcast(val))); - alias_[var] = alias_[Downcast(val)]; - } else { - alias_[var] = var; + + if (const VarNode* alias_of_n = AsIgnoringOnDevice(val)) { + set_alias(var, alias_of_n); + aliased = true; + } else if (AsIgnoringOnDevice(val)) { + // Copying to the same device is aliasing. + Expr unwrapped = IgnoreOnDevice(val); + DeviceCopyProps copy_props = GetDeviceCopyProps(unwrapped); + if (copy_props.body.defined()) { + if (copy_props.src_virtual_device->device_type() == + copy_props.dst_virtual_device->device_type() && + copy_props.src_virtual_device->virtual_device_id == + copy_props.dst_virtual_device->virtual_device_id) { + Expr to_copy = Downcast(unwrapped)->args[0]; + if (const VarNode* alias_of_n = to_copy.as()) { + set_alias(var, alias_of_n); + aliased = true; + } + } + } + } + + if (!aliased) { ll.Push(var, VisitExpr(val)); } - bound_vars.push_back(var); - expr = inner_let_node->body; } @@ -428,99 +441,56 @@ class AliasEliminator : public MixedModeMutator { Expr VisitExpr_(const VarNode* var_node) override { Var var = GetRef(var_node); - ICHECK(alias_.count(var)); - return alias_[var]; + if (alias_.count(var)) { + return alias_[var]; + } + return var; } Expr VisitExpr_(const FunctionNode* func_node) override { - for (const Var& param : func_node->params) { - alias_[param] = param; - } - Expr new_body = VisitExpr(func_node->body); - Expr result = GetRef(func_node); + Function result = GetRef(func_node); if (!new_body.same_as(func_node->body)) { result = Function(func_node->params, new_body, func_node->ret_type, func_node->type_params, func_node->attrs, func_node->span); } - - for (const Var& param : func_node->params) { - size_t erased = alias_.erase(param); - ICHECK(erased); - } - return result; } - private: - std::unordered_map alias_; -}; - -class MemoryPlanner : public ExprMutator { - public: - MemoryPlanner() {} - - Expr PlanMemory(const Expr& e) { - lva_.ComputeLiveness(e); - return VisitExpr(e); - } - - Expr VisitExpr_(const LetNode* let_node) override { - Expr expr = GetRef(let_node); - LetList ll; - - while (const LetNode* inner_let_node = expr.as()) { - ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value)); - - ICHECK(!inner_let_node->value.as()); - - ICHECK(lva_.cfg_.let_map.count(expr)); - - ControlFlowGraph::Node* n = lva_.cfg_.let_map[expr]; - - auto& li = lva_.live_in_[n]; - auto& lo = lva_.live_out_[n]; - - // std::cout << "let " << inner_let_node->var->name_hint() << " = ...;" << std::endl; - // std::cout << " live in:"; - // for (auto& v : li) { - // std::cout << " " << v->name_hint(); - // } - // std::cout << std::endl << " live out:"; - // for (auto& v : lo) { - // std::cout << " " << v->name_hint(); - // } - // std::cout << std::endl << std::endl; - - // killed vars = live in - live out - VarSet kills; - for (auto& v : li) { - if (!lo.count(v)) { - kills.insert(v); + // The only register-level aliasing that occurs in Match expressions is when + // the deconstructed expression is a Var, and the matched pattern is also a Var. + Expr VisitExpr_(const MatchNode* match_node) override { + if (const VarNode* data_var = AsIgnoringOnDevice(match_node->data)) { + std::vector new_clauses; + for (const Clause& clause : match_node->clauses) { + const PatternVarNode* pv_node = nullptr; + if ((pv_node = clause->lhs.as())) { + alias_[pv_node->var] = GetRef(data_var); + } + new_clauses.push_back(Clause(clause->lhs, VisitExpr(clause->rhs))); + if (pv_node) { + alias_.erase(pv_node->var); } } - - for (auto& v : kills) { - ll.Push(Call(Op::Get("memory.kill"), {v})); - } - - expr = inner_let_node->body; + return Match(match_node->data, new_clauses, match_node->complete, match_node->span); + } else { + return ExprMutator::VisitExpr_(match_node); } - - return ll.Get(VisitExpr(expr)); } private: - LivenessAnalyzer lva_; - ControlFlowGraph::Node* curr_node_; + std::unordered_map alias_; }; Pass VMPlanMemory() { auto pass_func = [](Function f, IRModule m, PassContext pc) -> Function { - AliasEliminator el; - MemoryPlanner mp; - Expr nf = mp.PlanMemory(el.Mutate(f)); - return Downcast(nf); + f = Downcast(AliasEliminator().Mutate(f)); + ControlFlowGraph cfg = ControlFlowGraph::Create(f); + UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); + LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); + KillInserter ki(&cfg, &lva); + Function nf = Downcast(ki.Mutate(f)); + return nf; }; return CreateFunctionPass(pass_func, 0, "VMPlanMemory", {}); } @@ -530,5 +500,3 @@ TVM_REGISTER_GLOBAL("relay._transform.VMPlanMemory").set_body_typed(VMPlanMemory } // namespace transform } // namespace relay } // namespace tvm - -// class LivenessAnalyzer : public DeviceA