From 333b21574198d99b2c371ba80047c37e94cc60b4 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 3 Apr 2020 16:57:42 +0000 Subject: [PATCH 01/17] add target to region --- src/relay/analysis/annotated_region_set.cc | 61 +++++++++++--------- src/relay/analysis/annotated_region_set.h | 37 +++++++----- tests/python/relay/test_annotated_regions.py | 18 ++++-- 3 files changed, 72 insertions(+), 44 deletions(-) diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index ad2b9e145789..b6e1dfc5223a 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -31,7 +32,7 @@ namespace relay { AnnotatedRegion AnnotatedRegionSetNode::GetRegion(const Expr& expr) const { for (auto candidate : regions_) { - if (candidate->nodes.find(expr) != candidate->nodes.end()) { + if (candidate->nodes_.find(expr) != candidate->nodes_.end()) { return candidate; } } @@ -45,26 +46,26 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src, } // Merge src to dest and erase src. - dest->nodes.insert(src->nodes.begin(), src->nodes.end()); - for (const auto& input : src->ins) { - dest->ins.push_back(input); + dest->nodes_.insert(src->nodes_.begin(), src->nodes_.end()); + for (const auto& input : src->ins_) { + dest->ins_.push_back(input); } - for (const auto& output : src->outs) { - dest->outs.push_back(output); + for (const auto& output : src->outs_) { + dest->outs_.push_back(output); } // if any of the outputs of src are inputs of dest, they become internal nodes // so remove them from outs std::vector ins_to_remove; - for (const auto& input : dest->ins) { + for (const auto& input : dest->ins_) { auto call = Downcast(input); - auto it = src->nodes.find(call->args[0]); - if (it != src->nodes.end()) { - dest->outs.remove(*it); + auto it = src->nodes_.find(call->args[0]); + if (it != src->nodes_.end()) { + dest->outs_.remove(*it); ins_to_remove.push_back(input); } } for (const auto& input : ins_to_remove) { - dest->ins.remove(input); + dest->ins_.remove(input); } regions_.erase(src); } @@ -74,25 +75,21 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) if (src.defined()) { MergeRegions(src, dest); } else { - dest->nodes.insert(expr); + dest->nodes_.insert(expr); } } -AnnotatedRegion AnnotatedRegionSetNode::MakeRegion() { +AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(std::string target) { auto ret = regions_.emplace(AnnotatedRegion()); - (*ret.first)->id = region_id_++; + (*ret.first)->id_ = region_id_++; + (*ret.first)->target_ = target; return *ret.first; } class AnnotatedRegionSet::Creator : public ExprVisitor { public: - Creator(const Op& region_begin_op, const Op& region_end_op) : - begin_op_(region_begin_op), end_op_(region_end_op) {} - - AnnotatedRegionSet Create(const Expr& expr) { - VisitExpr(expr); - return std::move(region_set_); - } + Creator(const Op& region_begin_op, const Op& region_end_op) + : begin_op_(region_begin_op), end_op_(region_end_op) {} void VisitExpr_(const CallNode* call) { auto op_node = call->op.as(); @@ -115,24 +112,36 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { << "Cannot find the corresponding region for start annotation:\n" << AsText(GetRef(call), false)); } - region->ins.push_back(GetRef(call)); + region->ins_.push_back(GetRef(call)); } else { CHECK_EQ(call->op, end_op_); // The annotation node is inserted on edge so it must have only one argument. CHECK_EQ(call->args.size(), 1U); + std::string target = call->attrs.as()->compiler; // Check if the argument already belongs to a region auto region = region_set_->GetRegion(call->args[0]); if (!region.defined()) { - region = region_set_->MakeRegion(); - region->nodes.insert(call->args[0]); + // Create a new region if the argument is not belonged to any regions yet. + region = region_set_->MakeRegion(target); + region->nodes_.insert(call->args[0]); + } + else { + // If the argument is belonged to a region, it must have the same target. + // Otherwise we should see a region_begin op. + CHECK_EQ(region->GetTarget(), target); } - region->nodes.insert(GetRef(call)); - region->outs.push_back(GetRef(call)); + region->nodes_.insert(GetRef(call)); + region->outs_.push_back(GetRef(call)); } ExprVisitor::VisitExpr_(call); } + AnnotatedRegionSet Create(const Expr& expr) { + VisitExpr(expr); + return std::move(region_set_); + } + void VisitExpr_(const TupleNode* op) { auto region = region_set_->GetRegion(GetRef(op)); if (region.defined()) { diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index 0b9301133d1c..cfd044e79776 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -49,33 +50,39 @@ class AnnotatedRegionSet; class AnnotatedRegionNode : public Object { public: void VisitAttrs(AttrVisitor* v) { - v->Visit("id", &id); - Array nodes_array(nodes.begin(), nodes.end()); + v->Visit("id", &id_); + v->Visit("target", &target_); + Array nodes_array(nodes_.begin(), nodes_.end()); v->Visit("nodes", &nodes_array); - Array args_array(ins.begin(), ins.end()); + Array args_array(ins_.begin(), ins_.end()); v->Visit("args", &args_array); - Array rets_array(outs.begin(), outs.end()); + Array rets_array(outs_.begin(), outs_.end()); v->Visit("rets", &rets_array); } /*! \brief Get the region ID. */ int GetID() const { - return id; + return id_; + } + + /*! \brief Get the region target. */ + std::string GetTarget() const { + return target_; } /*! \brief Get the region's inputs. */ std::list GetInputs() const { - return ins; + return ins_; } /*! \brief Get the region's outputs. */ std::list GetOutputs() const { - return outs; + return outs_; } /*! \brief Get the region's nodes. */ std::unordered_set GetNodes() const { - return nodes; + return nodes_; } static constexpr const char* _type_key = "relay.AnnotatedRegion"; @@ -83,13 +90,15 @@ class AnnotatedRegionNode : public Object { protected: /*! \brief The region ID. */ - int id{-1}; + int id_{-1}; + /*! \brief The target for this region. */ + std::string target_ = "default"; /*! \brief The inputs to this region. */ - std::list ins; + std::list ins_; /*! \brief The outputs of this region */ - std::list outs; + std::list outs_; /*! \brief Nodes in this region. */ - std::unordered_set nodes; + std::unordered_set nodes_; friend class AnnotatedRegionSet; friend class AnnotatedRegionSetNode; @@ -184,11 +193,11 @@ class AnnotatedRegionSetNode : public Object { void AddToRegion(AnnotatedRegion dest, const Expr& expr); /*! - * \brief Make a new region. + * \brief Make a new region for a target. * * \return The new region. */ - AnnotatedRegion MakeRegion(); + AnnotatedRegion MakeRegion(std::string target); std::unordered_set regions_; /*! \brief The next region ID to assign. */ diff --git a/tests/python/relay/test_annotated_regions.py b/tests/python/relay/test_annotated_regions.py index a24639867091..f3c157d296db 100644 --- a/tests/python/relay/test_annotated_regions.py +++ b/tests/python/relay/test_annotated_regions.py @@ -15,13 +15,15 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name +import tvm from tvm import relay from tvm.relay.op.annotation import compiler_begin, compiler_end -def check_region(region_set, args, nodes, rets): +def check_region(region_set, target, args, nodes, rets): region = region_set.get_region(args[0]) assert region + assert target == region.target assert set(args) == set(region.args) assert set(nodes) == set(region.nodes) assert set(rets) == set(region.rets) @@ -51,24 +53,28 @@ def test_region_set_creator_diamond(): assert len(region_set) == 4 check_region( region_set, + 'test_target', [cb_1], [cb_1, O_1, ce_1, ce_2], [ce_1, ce_2], ) check_region( region_set, + 'test_target', [cb_2], [cb_2, O_2, ce_3], [ce_3], ) check_region( region_set, + 'default', [cb_d], [cb_d, X, ce_d], [ce_d], ) check_region( region_set, + 'test_target', [cb_3, cb_4], [cb_3, cb_4, O_3, ce_4], [ce_4], @@ -88,7 +94,9 @@ def test_region_set_creator_merged(): cb_3 = compiler_begin(ce_3, 'test_target') cb_4 = compiler_begin(ce_d, 'test_target') O_3 = relay.add(cb_3, cb_4) - ce_4 = compiler_end(O_3, 'test_target') + O_4 = relay.add(cb_3, cb_4) + O_5 = relay.Tuple([O_3, O_4]) + ce_4 = compiler_end(O_5, 'test_target') merged = relay.Function([data], ce_4) region_set = relay.analysis.AnnotatedRegionSet(merged, @@ -97,20 +105,23 @@ def test_region_set_creator_merged(): assert len(region_set) == 3 check_region( region_set, + 'test_target', [cb_1], [cb_1, O_1, O_2, ce_2, ce_3], [ce_2, ce_3], ) check_region( region_set, + 'default', [cb_d], [cb_d, X, ce_d], [ce_d], ) check_region( region_set, + 'test_target', [cb_3, cb_4], - [cb_3, cb_4, O_3, ce_4], + [cb_3, cb_4, O_3, O_4, O_5, ce_4], [ce_4], ) @@ -118,4 +129,3 @@ def test_region_set_creator_merged(): if __name__ == "__main__": test_region_set_creator_diamond() test_region_set_creator_merged() - From 36d20dad06e3a15d74086f3c03b8afb82631f47e Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 3 Apr 2020 20:45:42 +0000 Subject: [PATCH 02/17] refactor annotate_target --- python/tvm/relay/transform/transform.py | 8 +- src/relay/transforms/annotate_target.cc | 250 ++++++++++-------- src/relay/transforms/partition_graph.cc | 37 ++- ...target.py => test_pass_annotate_target.py} | 21 +- 4 files changed, 198 insertions(+), 118 deletions(-) rename tests/python/relay/{test_annotate_target.py => test_pass_annotate_target.py} (93%) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index ce4ac79a88d0..fc0e3f833658 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -587,14 +587,14 @@ def PartitionGraph(): -def AnnotateTarget(target): +def AnnotateTarget(targets): """Annotate ops in an experession with a provied compiler/target and then use it for codegen. Parameters ---------- - target : String - The target compiler used for codegen. + target : List[String] + The list of target compilers used for codegen. Returns ------- @@ -602,7 +602,7 @@ def AnnotateTarget(target): The annotated pass that wrapps ops with subgraph_start and subgraph_end. """ - return _ffi_api.AnnotateTarget(target) + return _ffi_api.AnnotateTarget(targets) def Inline(): diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 44ef35a285f5..7a2a723a3f02 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -27,121 +27,139 @@ #include #include #include +#include namespace tvm { namespace relay { namespace annotate_target { -// Cache compiler_begin op for equivalence check. static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. class AnnotateTargetWrapper : public ExprMutator { public: - explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {} + AnnotateTargetWrapper(const Array targets) { + for (auto target : targets) { + targets_.push_back(target.data()); + } + } Expr Annotate(const Expr& expr) { - return InsertEnd(Mutate(expr)); + auto new_expr = Mutate(expr); + //std::cerr << AsText(new_expr); + return new_expr; } - bool IsSupported(const Expr& expr) { - if (expr->IsInstance()) { - Call call = Downcast(expr); - auto fannotate = Op::GetAttr("target." + target_); - if (call->op->IsInstance()) { - Op op = Downcast(call->op); - CHECK(op.defined()); - if (fannotate.count(op)) { - return fannotate[op](call->attrs, call->args); - } - } else if (call->op->IsInstance()) { - // handle composite functions - Function func = Downcast(call->op); - CHECK(func.defined()); - auto comp_name = func->GetAttr(attr::kComposite); - if (comp_name.defined()) { - std::string comp_name_str = comp_name; - size_t i = comp_name_str.find('.'); - if (i != std::string::npos) { - std::string target = comp_name_str.substr(0, i); - if (target == target_) return true; - } + /*! \brief This function 1) annotates a compiler end and a compiler begin to all arguments. + * The compiler end is based on the arg target while the compiler begin is based on the given + * target. If target is not given and all arguments are going to the same target, then we will + * use that target; otherwise we use default for this op. Note that all arg exprs must be + * available in op_expr_to_target before calling this function. + * + * \param args An array of arguments of the given node. + * \param target The target of the current node. + * \return A pair of target and annotated argument expressions. + */ + std::pair> AnnotateArgs(const Array args, + const std::string target = "") { + std::string ref_target = ""; + Array compiler_ends; + for (auto arg : args) { + if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { + std::string arg_target = op_expr_to_target_[arg]; + compiler_ends.push_back(InsertAnnotation(arg, arg_target, end_op)); + if (ref_target == "") { + ref_target = arg_target; + } else if (ref_target != arg_target) { + ref_target = "__inconsist__"; } + } else { + // Input vars. + compiler_ends.push_back(arg); } } - if (expr->IsInstance()) { - TupleGetItem get = Downcast(expr); - if (get->tuple->IsInstance() && - get->tuple.as()->op == compiler_begin_op) { - return true; - } + ref_target = (ref_target == "__inconsist__") ? "default" : ref_target; + + // Determine compiler begin target. + std::string op_target = (target == "")? ref_target: target; + + Array compiler_begins; + for (auto end : compiler_ends) { + compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op)); } - return false; + + return {op_target, compiler_begins}; } - Expr InsertEnd(const Expr& arg) { - if (IsSupported(arg)) { - const auto *end_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_end"); - CHECK(end_op); - Expr end = (*end_op)(arg, target_); - return end; - } - return arg; + Expr InsertAnnotation(const Expr& expr, const std::string target, const PackedFunc* ann_op) { + Expr new_op = (*ann_op)(expr, target); + new_op->checked_type_ = expr->checked_type_; + return new_op; } Expr VisitExpr_(const CallNode* cn) { - auto new_e = ExprMutator::VisitExpr_(cn); + Op op = Downcast(cn->op); + CHECK(op.defined()); - Call call = Downcast(new_e); + // Supported targets for this node. The order implies the priority. + std::vector supported_targets; - // add end annotations if the args are supported - Array compiler_ends; - for (const auto& it : call->args) { - compiler_ends.push_back(InsertEnd(it)); - } - call = Call(call->op, compiler_ends, call->attrs); - - // add begin annotations if the call node is supported - if (IsSupported(call)) { - tvm::Array compiler_begins; - const auto* begin_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - for (const auto& it : call->args) { - CHECK(begin_op); - Expr begin = (*begin_op)(it, target_); - compiler_begins.push_back(begin); + // Check which targets this op can be offloaded. + for (auto target : this->targets_) { + auto fannotate = Op::GetAttr("target." + target); + if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) { + supported_targets.push_back(target); } - call = Call(call->op, compiler_begins, call->attrs); } + supported_targets.push_back("default"); // Make default as the last option. + + // TODO(@comaniac, @zhiics): Now we simply assign this node to the target with + // the highest priority, but we should preserve all supported targets so that + // we can make a better decision. + std::string target = supported_targets[0]; + + // Visit and mutate arguments after the target of this op has been determined. + auto new_e = ExprMutator::VisitExpr_(cn); + Call call = Downcast(new_e); + + // Add annotations to each arg. + auto target_n_args = AnnotateArgs(call->args, target); + Array compiler_begins = std::get<1>(target_n_args); + // for (auto b : compiler_begins) { + // std::cerr << AsText(b); + // std::cerr << "===============\n"; + // } + //std::cerr << "*********************************************\n"; + call = Call(call->op, compiler_begins, call->attrs); + call->checked_type_ = cn->checked_type_; + + // Update the target map. + op_expr_to_target_[call] = target; return std::move(call); } Expr VisitExpr_(const TupleNode* op) { auto new_e = ExprMutator::VisitExpr_(op); + auto expr = Downcast(new_e); - auto tup = Downcast(new_e); - Array new_fields; - for (auto field : tup->fields) { - new_fields.push_back(InsertEnd(field)); - } - return Tuple(new_fields); + auto target_n_args = AnnotateArgs(expr->fields); + auto new_expr = Tuple(std::get<1>(target_n_args)); + op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + return new_expr; } Expr VisitExpr_(const TupleGetItemNode* op) { auto new_e = ExprMutator::VisitExpr_(op); + auto expr = Downcast(new_e); - auto get = Downcast(new_e); - if (IsSupported(get->tuple)) { - const auto* begin_op = - runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - CHECK(begin_op); - return TupleGetItem((*begin_op)(InsertEnd(get->tuple), target_), get->index); - } else { - return TupleGetItem(InsertEnd(get->tuple), get->index); - } + auto target_n_args = AnnotateArgs(Array({expr->tuple})); + + std::string target = std::get<0>(target_n_args); + auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index); + op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + return new_expr; } Expr VisitExpr_(const FunctionNode* fn) { @@ -154,76 +172,96 @@ class AnnotateTargetWrapper : public ExprMutator { } else { auto new_e = ExprMutator::VisitExpr_(fn); func = Downcast(new_e); - new_body = InsertEnd(func->body); + new_body = func->body; + if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { + new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], end_op); + op_expr_to_target_[new_body] = op_expr_to_target_[func->body]; + } } - - return Function( - func->params, - new_body, - func->ret_type, - func->type_params, - func->attrs); + return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); } Expr VisitExpr_(const LetNode* op) { auto new_e = ExprMutator::VisitExpr_(op); + auto expr = Downcast(new_e); + + std::vector args = {expr->value, expr->body}; + auto target_n_args = AnnotateArgs(Array(args)); - auto let = Downcast(new_e); - return Let( - let->var, - InsertEnd(let->value), - InsertEnd(let->body)); + std::string target = std::get<0>(target_n_args); + auto new_expr = Let(expr->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); + op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + return new_expr; } Expr VisitExpr_(const IfNode* op) { auto new_e = ExprMutator::VisitExpr_(op); + auto expr = Downcast(new_e); - auto iff = Downcast(new_e); - return If( - InsertEnd(iff->cond), - InsertEnd(iff->true_branch), - InsertEnd(iff->false_branch)); + std::vector args = {expr->cond, expr->true_branch, expr->false_branch}; + auto target_n_args = AnnotateArgs(Array(args)); + + std::string target = std::get<0>(target_n_args); + auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1], + std::get<1>(target_n_args)[2]); + op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + return new_expr; } Expr VisitExpr_(const RefCreateNode* op) { auto new_e = ExprMutator::VisitExpr_(op); + auto expr = Downcast(new_e); - auto create = Downcast(new_e); - return RefCreate(InsertEnd(create->value)); + auto target_n_args = AnnotateArgs(Array({expr->value})); + auto new_expr = RefCreate(std::get<1>(target_n_args)[0]); + op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + return new_expr; } Expr VisitExpr_(const RefReadNode* op) { auto new_e = ExprMutator::VisitExpr_(op); + auto expr = Downcast(new_e); + + auto target_n_args = AnnotateArgs(Array({expr->ref})); - auto read = Downcast(new_e); - return RefRead(InsertEnd(read->ref)); + auto new_expr = RefRead(std::get<1>(target_n_args)[0]); + op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + return new_expr; } Expr VisitExpr_(const RefWriteNode* op) { auto new_e = ExprMutator::VisitExpr_(op); + auto expr = Downcast(new_e); + + auto target_n_args = AnnotateArgs(Array({expr->ref, expr->value})); - auto write = Downcast(new_e); - return RefWrite( - InsertEnd(write->ref), - InsertEnd(write->value)); + std::string target = std::get<0>(target_n_args); + auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); + op_expr_to_target_[new_expr] = std::get<0>(target_n_args); + return new_expr; } private: - std::string target_; + /*! \brief The target backends for annotation. */ + std::vector targets_; + /*! \brief Maintain the decision of the target for each op expr. */ + std::unordered_map op_expr_to_target_; + const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); + const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); }; -Expr AnnotateTarget(const Expr& expr, const std::string& target) { - return AnnotateTargetWrapper(target).Annotate(expr); +Expr AnnotateTarget(const Expr& expr, const Array targets) { + return AnnotateTargetWrapper(targets).Annotate(expr); } } // namespace annotate_target namespace transform { -Pass AnnotateTarget(const std::string& target) { +Pass AnnotateTarget(const Array targets) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(relay::annotate_target::AnnotateTarget(f, target)); + return Downcast(relay::annotate_target::AnnotateTarget(f, targets)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "AnnotateTargetFunc", {"InferType"}); diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 8eeac1748a43..fa9c8c4f40a2 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -477,13 +477,48 @@ class Partitioner : public ExprMutator { IRModule module_; }; +class DefaultRemover : public ExprMutator { + public: + explicit DefaultRemover(const IRModule& module) : module_(module) {} + + IRModule Remove() { + auto glob_funcs = module_->functions; + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, + func->attrs); + module_->Update(pair.first, func); + } + } + return module_; + } + + Expr VisitExpr_(const CallNode* call) final { + auto attrs = call->attrs.as(); + if (attrs != nullptr && attrs->compiler == "default") { + return VisitExpr(call->args[0]); + } + return ExprMutator::VisitExpr_(call); + } + + private: + IRModule module_; +}; + } // namespace partitioning namespace transform { Pass PartitionGraph() { runtime::TypedPackedFunc part_func = - [=](IRModule m, PassContext pc) { return partitioning::Partitioner(m).Partition(); }; + [=](IRModule m, PassContext pc) { + // TODO(@comaniac, @zhiics): We should also handle the annotation with "default" attribute + // by treating them as un-annotated, but we don't have it yet. This workaround pass removes + // all "default" annotations and should be deleted in the future. + auto new_m = partitioning::DefaultRemover(m).Remove(); + return partitioning::Partitioner(new_m).Partition(); + }; auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); return Sequential({partitioned, InferType()}); } diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py similarity index 93% rename from tests/python/relay/test_annotate_target.py rename to tests/python/relay/test_pass_annotate_target.py index dd00d7ece7bd..52d8d06c670f 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -136,7 +136,7 @@ def expected(dtype, ishape, w1shape): def test_annotate(): mod = annotated(dtype, ishape, w1shape) - mod = transform.AnnotateTarget("dnnl")(mod) + mod = transform.AnnotateTarget([tvm.runtime.container.String("dnnl")])(mod) ref_mod = expected(dtype, ishape, w1shape) tvm.ir.assert_structural_equal(mod, ref_mod) @@ -208,14 +208,21 @@ def after(): r = relay.nn.relu(cb_1) ce_1 = relay.annotation.compiler_end(r, "test") ce_2 = relay.annotation.compiler_end(r, "test") - a_1 = relay.abs(ce_1) - a_2 = relay.abs(ce_2) - out = relay.add(a_1, a_2) - f = relay.Function([x], out) + cb_2 = relay.annotation.compiler_begin(ce_1, "default") + cb_3 = relay.annotation.compiler_begin(ce_2, "default") + a_1 = relay.abs(cb_2) + a_2 = relay.abs(cb_3) + ce_3 = relay.annotation.compiler_end(a_1, "default") + ce_4 = relay.annotation.compiler_end(a_2, "default") + cb_4 = relay.annotation.compiler_begin(ce_3, "default") + cb_5 = relay.annotation.compiler_begin(ce_4, "default") + out = relay.add(cb_4, cb_5) + ce_6 = relay.annotation.compiler_end(out, "default") + f = relay.Function([x], ce_6) mod = tvm.IRModule.from_expr(f) return mod - result = transform.AnnotateTarget("test")(before()) + result = transform.AnnotateTarget([tvm.runtime.container.String("test")])(before()) expected = transform.InferType()(after()) assert tvm.ir.structural_equal(expected, result) @@ -266,7 +273,7 @@ def after(): if __name__ == "__main__": - test_multiple_ends() test_extern_dnnl() #test_extern_dnnl_mobilenet() test_composite_function() + test_multiple_ends() From ac040a035a27ed623de500320c914c489040aa0e Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 3 Apr 2020 22:10:35 +0000 Subject: [PATCH 03/17] Make all unit test working --- python/tvm/relay/transform/transform.py | 6 +- .../transforms/merge_compiler_regions.cc | 200 +----------------- .../python/relay/test_pass_annotate_target.py | 81 ++++++- .../relay/test_pass_merge_compiler_regions.py | 40 ++-- 4 files changed, 115 insertions(+), 212 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index fc0e3f833658..918894f69603 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -593,7 +593,7 @@ def AnnotateTarget(targets): Parameters ---------- - target : List[String] + targets : str or List[str] The list of target compilers used for codegen. Returns @@ -602,7 +602,9 @@ def AnnotateTarget(targets): The annotated pass that wrapps ops with subgraph_start and subgraph_end. """ - return _ffi_api.AnnotateTarget(targets) + if isinstance(targets, str): + targets = [targets] + return _ffi_api.AnnotateTarget([tvm.runtime.container.String(t) for t in targets]) def Inline(): diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index 5253010d4bcd..0024fc003ccc 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -46,184 +46,13 @@ namespace tvm { namespace relay { -namespace partitioning { +namespace merge_compiler_region { // Cache compiler_begin and compiler_end annotation ops for equivalence check to // reduce registry lookup overhead. static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); -/*! \brief This is a pre-requisite pass to merge-supported pass. - * The AnnotateRestDefault pass will put "default" Compiler Annotations to - * nodes that are not annotated already. This is there to ensure that the - * user will not leave un-annotated nodes MergeCompilerRegions pass is run. - * Why? Because, MergeCompilerRegions pass assumes every node to be annotated. - */ -class AnnotateRestDefault : public ExprMutator { - public: - explicit AnnotateRestDefault(const Expr& expr) { - regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); - } - - Expr Annotate(const Expr& expr) { - // Its a function that is being passed on to annotate - func_ = Downcast(expr); - - // Corner Case CC1 : If the last node does not belong - // to a region node to add a compiler_end - auto region = regions_->GetRegion(func_->body); - auto mutated_expr = this->VisitExpr(expr); - if (!region.defined()) { - func_ = Downcast(mutated_expr); - // CC1 : add that compiler end after mutation - auto body = InsertEnd(func_->body); - func_ = Function(func_->params, body, body->checked_type_, {}, DictAttrs()); - return Downcast(func_); - } - return mutated_expr; - } - - /*! \brief This function adds compiler ends to nodes that - * don't belong to a region already (default). - * \param expr The expression to add a compiler end to. - * \return expr The expression with or without a compiler end added. - */ - Expr InsertEnd(const Expr& expr) { - if (annotated_nodes_.find(expr) == annotated_nodes_.end() && !expr->IsInstance() && - !expr->IsInstance()) { - const auto* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); - CHECK(end_op); - Expr end = (*end_op)(expr, target_); - return end; - } - return expr; - } - - /*! \brief This function adds compiler begins to nodes that - * don't belong to a region already (default). - * \param expr The expression to add a compiler begin to. - * \return expr The expression with or without a compiler begin added. - */ - Expr InsertBegin(const Expr& expr) { - const auto* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - CHECK(begin_op); - Expr begin = (*begin_op)(expr, target_); - annotated_nodes_.insert(begin); - return begin; - } - - Expr VisitExpr_(const CallNode* cn) final { - auto region = regions_->GetRegion(GetRef(cn)); - auto new_e = ExprMutator::VisitExpr_(cn); - Call call = Downcast(new_e); - - // Add compiler ends if the parent isn't annotated - Array args; - for (auto arg : call->args) { - args.push_back(InsertEnd(arg)); - } - - Expr updated_call = Call(call->op, args, call->attrs); - if (!region.defined()) { - // if the current node does not belong to annotated region - // annotate the all incoming edges (args) - // with "default" compiler_begin annotations. - Array compiler_begins; - for (auto arg : args) { - compiler_begins.push_back(InsertBegin(arg)); - } - updated_call = Call(call->op, compiler_begins, call->attrs); - } else { - annotated_nodes_.insert(updated_call); - } - return updated_call; - }; - - Expr VisitExpr_(const TupleNode* op) { - auto region = regions_->GetRegion(GetRef(op)); - auto new_e = ExprMutator::VisitExpr_(op); - Tuple tup = Downcast(new_e); - - Array fields; - for (auto field : tup->fields) { - fields.push_back(InsertEnd(field)); - } - - Expr updated_tuple = Tuple(fields); - if (!region.defined()) { - Array compiler_begins; - for (const auto& field : fields) { - compiler_begins.push_back(InsertBegin(field)); - } - updated_tuple = Tuple(compiler_begins); - } else { - annotated_nodes_.insert(updated_tuple); - } - return updated_tuple; - } - - Expr VisitExpr_(const TupleGetItemNode* op) { - auto region = regions_->GetRegion(GetRef(op)); - auto new_e = ExprMutator::VisitExpr_(op); - auto get = Downcast(new_e); - - auto updated_tuple = InsertEnd(get->tuple); - Expr updated_get = TupleGetItem(updated_tuple, get->index); - if (!region.defined()) { - updated_get = TupleGetItem(InsertBegin(updated_tuple), get->index); - } else { - annotated_nodes_.insert(updated_get); - } - return updated_get; - } - - Expr VisitExpr_(const IfNode* op) { - auto region = regions_->GetRegion(GetRef(op)); - auto new_e = ExprMutator::VisitExpr_(op); - auto iff = Downcast(new_e); - - if (!region.defined()) { - return If(InsertBegin(InsertEnd(iff->cond)), InsertBegin(InsertEnd(iff->true_branch)), - InsertBegin(InsertEnd(iff->false_branch))); - } else { - Expr updated_iff = - If(InsertEnd(iff->cond), InsertEnd(iff->true_branch), InsertEnd(iff->false_branch)); - annotated_nodes_.insert(updated_iff); - return updated_iff; - } - } - - Expr VisitExpr_(const LetNode* op) { - auto new_e = ExprMutator::VisitExpr_(op); - auto let = Downcast(new_e); - return Let(let->var, InsertEnd(let->value), InsertEnd(let->body)); - } - - Expr VisitExpr_(const RefCreateNode* op) { - auto new_e = ExprMutator::VisitExpr_(op); - auto create = Downcast(new_e); - return RefCreate(InsertEnd(create->value)); - } - - Expr VisitExpr_(const RefReadNode* op) { - auto new_e = ExprMutator::VisitExpr_(op); - auto read = Downcast(new_e); - return RefRead(InsertEnd(read->ref)); - } - - Expr VisitExpr_(const RefWriteNode* op) { - auto new_e = ExprMutator::VisitExpr_(op); - auto write = Downcast(new_e); - return RefWrite(InsertEnd(write->ref), InsertEnd(write->value)); - } - - private: - AnnotatedRegionSet regions_; - const std::string target_ = "default"; - Function func_; - std::unordered_set annotated_nodes_; -}; - class MergeAnnotations : public ExprMutator { public: explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} @@ -333,37 +162,30 @@ class RegionMerger : public ExprVisitor { }; Expr MergeCompilerRegions(const Expr& expr) { - // Annotate all the nodes that aren't annotated as 'default'. - AnnotateRestDefault anno_default(expr); - auto expr_all_annotated = anno_default.Annotate(expr); - // Create regions using the annotations. - AnnotatedRegionSet regions = - AnnotatedRegionSet::Create(expr_all_annotated, compiler_begin_op, compiler_end_op); + AnnotatedRegionSet regions = AnnotatedRegionSet::Create(expr, compiler_begin_op, compiler_end_op); - // By now, all the nodes have some sort of annotation. - // Region merger is an ExprVisitor that will update the - // AnnotatedRegionSet, merging all the regions that can be merged. + // Analyze the graph to explore the opportunities of merging regions. RegionMerger merger(regions); - merger.VisitExpr(expr_all_annotated); + merger.VisitExpr(expr); - // This updates the expression to remove annotations that are now - // 'internal' to a merged region. + // Remove annotations that are not in the region boundaries. MergeAnnotations merge_anno(regions); - return merge_anno.Mutate(expr_all_annotated); + auto new_expr = merge_anno.Mutate(expr); + return new_expr; } -} // namespace partitioning +} // namespace merge_compiler_region namespace transform { Pass MergeCompilerRegions() { runtime::TypedPackedFunc part_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(partitioning::MergeCompilerRegions(f)); + return Downcast(merge_compiler_region::MergeCompilerRegions(f)); }; - auto partitioned = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {}); - return Sequential({partitioned, InferType()}); + auto merged = CreateFunctionPass(part_func, 0, "MergeCompilerRegions", {}); + return Sequential({merged, InferType()}); } TVM_REGISTER_GLOBAL("relay._transform.MergeCompilerRegions") diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 52d8d06c670f..1daada429d08 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -136,7 +136,7 @@ def expected(dtype, ishape, w1shape): def test_annotate(): mod = annotated(dtype, ishape, w1shape) - mod = transform.AnnotateTarget([tvm.runtime.container.String("dnnl")])(mod) + mod = transform.AnnotateTarget("dnnl")(mod) ref_mod = expected(dtype, ishape, w1shape) tvm.ir.assert_structural_equal(mod, ref_mod) @@ -186,12 +186,11 @@ def test_extern_dnnl_mobilenet(): (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params) -@reg.register("nn.relu", "target.test") -def relu(attrs, args): - return True - - def test_multiple_ends(): + @reg.register("nn.relu", "target.test") + def relu(attrs, args): # pylint: disable=unused-variable + return True + def before(): x = relay.var("x", shape=(10, 10)) r = relay.nn.relu(x) @@ -222,7 +221,73 @@ def after(): mod = tvm.IRModule.from_expr(f) return mod - result = transform.AnnotateTarget([tvm.runtime.container.String("test")])(before()) + result = transform.AnnotateTarget("test")(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + +def test_type_propagation(): + target = "test_type_propagation" + + @reg.register("nn.relu", "target.test_type_propagation" + target) + def relu(attrs, args): # pylint: disable=unused-variable + return args[0].checked_type.dtype == "float32" + + def before(): + x = relay.var("x", shape=(10, 10)) + r = relay.nn.relu(x) + out = relay.nn.relu(r) + f = relay.Function([x], out) + mod = tvm.IRModule.from_expr(f) + return mod + + # If the type isn't propogated, then the relu checker function will fail to get the dtype. + assert transform.AnnotateTarget(target)(before()) + + +def test_tuple(): + target = "test_tuple" + + @reg.register("nn.relu", "target." + target) + def relu(attrs, args): # pylint: disable=unused-variable + return True + + @reg.register("concatenate", "target." + target) + def concatenate(attrs, args): # pylint: disable=unused-variable + return True + + """Test that TupleNode is included in annotation when surrounded by supported nodes.""" + def before(): + x = relay.var("x", shape=(10, 5)) + y = relay.var("y", shape=(10, 5)) + a_1 = relay.nn.relu(x) + a_2 = relay.nn.relu(y) + out = relay.concatenate((a_1, a_2), axis=1) + f = relay.Function([x, y], out) + mod = tvm.IRModule.from_expr(f) + return mod + + def after(): + x = relay.var("x", shape=(10, 5)) + y = relay.var("y", shape=(10, 5)) + cb_1 = relay.annotation.compiler_begin(x, target) + cb_2 = relay.annotation.compiler_begin(y, target) + a_1 = relay.nn.relu(cb_1) + a_2 = relay.nn.relu(cb_2) + ce_1 = relay.annotation.compiler_end(a_1, target) + ce_2 = relay.annotation.compiler_end(a_2, target) + cb_3 = relay.annotation.compiler_begin(ce_1, target) + cb_4 = relay.annotation.compiler_begin(ce_2, target) + tup = relay.Tuple([cb_3, cb_4]) + ce_3 = relay.annotation.compiler_end(tup, target) + cb_3 = relay.annotation.compiler_begin(ce_3, target) + out = relay.op.concatenate(cb_3, 1) + ce_4 = relay.annotation.compiler_end(out, target) + f = relay.Function([x, y], ce_4) + mod = tvm.IRModule.from_expr(f) + return mod + + result = transform.AnnotateTarget(target)(before()) expected = transform.InferType()(after()) assert tvm.ir.structural_equal(expected, result) @@ -277,3 +342,5 @@ def after(): #test_extern_dnnl_mobilenet() test_composite_function() test_multiple_ends() + test_type_propagation() + test_tuple() diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py index f316a41a88da..efa99da2c274 100644 --- a/tests/python/relay/test_pass_merge_compiler_regions.py +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -30,9 +30,9 @@ def test_diamond_graph_fanouts(): X = not supported by target O O - / \ / \ + / \\ / \\ O X --> O + + X - \ / \ / + \\ / \\ / O O Note that we can't just merge the three supported operators together, @@ -45,17 +45,20 @@ def diamond_graph_fanouts(): ce_1 = compiler_end(O_1, "test") ce_2 = compiler_end(O_1, "test") cb_2 = compiler_begin(ce_1, "test") + cb_3 = compiler_begin(ce_2, "default") O_2 = relay.nn.relu(cb_2) ce_3 = compiler_end(O_2, "test") - X = relay.tanh(ce_2) - cb_3 = compiler_begin(ce_3, "test") - cb_4 = compiler_begin(X, "test") - O_3 = relay.add(cb_3, cb_4) - ce_4 = compiler_end(O_3, "test") + X = relay.tanh(cb_3) + ce_4 = compiler_end(X, "default") + + cb_4 = compiler_begin(ce_3, "test") + cb_5 = compiler_begin(ce_4, "test") + O_3 = relay.add(cb_4, cb_5) + ce_5 = compiler_end(O_3, "test") - diamond = relay.Function([data], ce_4) + diamond = relay.Function([data], ce_5) return diamond def expected(): @@ -85,7 +88,7 @@ def test_example_graph(): """This tests the merging algorithm on the example used in the RFC. See the RFC here: https://discuss.tvm.ai/t/relay-improved-graph-partitioning-algorithm/5830 - Blue nodes are adds, red nodes are subtracts. + Blue nodes are adds (target: test), red nodes are subtracts (target: default). """ def annotated(): in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') @@ -112,21 +115,30 @@ def annotated(): node2 = relay.add(begin4, begin5) end2 = compiler_end(node2, "test") - node3 = relay.subtract(in_5, in_6) - node4 = relay.subtract(in_7, node3) + dbegin0 = compiler_begin(in_5, "default") + dbegin1 = compiler_begin(in_6, "default") + node3 = relay.subtract(dbegin0, dbegin1) + dbegin2 = compiler_begin(in_7, "default") + dend1 = compiler_end(node3, "default") + dbegin3 = compiler_begin(dend1, "default") + node4 = relay.subtract(dbegin2, dbegin3) + dend2 = compiler_end(node4, "default") begin6 = compiler_begin(end2, "test") - begin7 = compiler_begin(node4, "test") + begin7 = compiler_begin(dend2, "test") node5 = relay.add(begin6, begin7) end3 = compiler_end(node5, "test") end4 = compiler_end(node5, "test") - node6 = relay.subtract(in_8, end3) + dbegin4 = compiler_begin(in_8, "default") + dbegin5 = compiler_begin(end3, "default") + node6 = relay.subtract(dbegin4, dbegin5) begin8 = compiler_begin(in_9, "test") begin9 = compiler_begin(end4, "test") node7 = relay.add(begin8, begin9) end5 = compiler_end(node7, "test") - begin10 = compiler_begin(node6, "test") + dend3 = compiler_end(node6, "default") + begin10 = compiler_begin(dend3, "test") begin11 = compiler_begin(end5, "test") node8 = relay.add(begin10, begin11) end6 = compiler_end(node8, "test") From 10add0d357b317b40046b4eb0e520c1b5bbe4b1e Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 3 Apr 2020 22:33:06 +0000 Subject: [PATCH 04/17] quick fix --- tests/python/relay/test_pass_annotate_target.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 1daada429d08..638b79a991b0 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -229,7 +229,7 @@ def after(): def test_type_propagation(): target = "test_type_propagation" - @reg.register("nn.relu", "target.test_type_propagation" + target) + @reg.register("nn.relu", "target." + target) def relu(attrs, args): # pylint: disable=unused-variable return args[0].checked_type.dtype == "float32" @@ -281,7 +281,7 @@ def after(): tup = relay.Tuple([cb_3, cb_4]) ce_3 = relay.annotation.compiler_end(tup, target) cb_3 = relay.annotation.compiler_begin(ce_3, target) - out = relay.op.concatenate(cb_3, 1) + out = relay.op._make.concatenate(cb_3, 1) ce_4 = relay.annotation.compiler_end(out, target) f = relay.Function([x, y], ce_4) mod = tvm.IRModule.from_expr(f) From b89d8c25dcaa2896decc99694eb27c1ea21a213b Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Fri, 3 Apr 2020 22:48:24 +0000 Subject: [PATCH 05/17] enable BN, unit test failed --- tests/python/relay/test_pass_partition_graph.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 1968f34d31b6..ee440df7d3f8 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -17,6 +17,7 @@ """Unit tests for graph partitioning.""" import os import sys + import numpy as np import pytest @@ -26,8 +27,12 @@ from tvm import runtime from tvm.relay import transform from tvm.contrib import util -from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.relay import transform +from tvm.relay.backend import compile_engine from tvm.relay.expr_functor import ExprMutator +from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.runtime import container + # Leverage the pass manager to write a simple white list based annotator @transform.function_pass(opt_level=0) @@ -188,6 +193,7 @@ def update_lib(lib): return lib def check_vm_result(): + compile_engine.get().clear() with relay.build_config(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() @@ -199,6 +205,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) def check_graph_runtime_result(): + compile_engine.get().clear() with relay.build_config(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) lib = update_lib(lib) @@ -449,10 +456,11 @@ def test_extern_dnnl_mobilenet(): mod, params = relay.testing.mobilenet.get_workload( batch_size=1, dtype='float32') - op_list = ["nn.conv2d", "nn.dense", "nn.relu", "add"] - mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) - mod = WhiteListAnnotator(op_list, "dnnl")(mod) + mod = transform.AnnotateTarget(["dnnl"])(mod) + mod = transform.MergeCompilerRegions() mod = transform.PartitionGraph()(mod) + # FIXME(@comaniac): Still try to fuse global function when lowering + print(mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, From 0bd11284d0e3a5acdaea3262e9c1b484bd14e4e0 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Sat, 4 Apr 2020 00:04:32 +0000 Subject: [PATCH 06/17] quick fix fusion --- src/relay/transforms/fuse_ops.cc | 3 +++ tests/python/relay/test_pass_partition_graph.py | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index f646042962f0..fd07d925123c 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -199,6 +199,9 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // Post order tree void VisitExpr_(const FunctionNode* op) final { + if (op->GetAttr(attr::kCompiler).defined()) { + return; + } for (auto param : op->params) { this->Update(param, nullptr, kOpaque); } diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index ee440df7d3f8..72ba21f35075 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -457,10 +457,8 @@ def test_extern_dnnl_mobilenet(): batch_size=1, dtype='float32') mod = transform.AnnotateTarget(["dnnl"])(mod) - mod = transform.MergeCompilerRegions() + mod = transform.MergeCompilerRegions()(mod) mod = transform.PartitionGraph()(mod) - # FIXME(@comaniac): Still try to fuse global function when lowering - print(mod) i_data = np.random.uniform(0, 1, ishape).astype(dtype) ref_mod, params = relay.testing.mobilenet.get_workload(batch_size=1, From e8a41daee15bde46aace673ab3f5d70a848c56ae Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 3 Apr 2020 22:56:39 +0000 Subject: [PATCH 07/17] Fix vm test, unit test. Refactor annotate_target a bit. --- src/relay/transforms/annotate_target.cc | 62 +++++++------------ .../python/relay/test_pass_partition_graph.py | 1 + 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 7a2a723a3f02..01479f57062f 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -19,8 +19,8 @@ /*! * \file src/relay/transforms/annotate_target.cc - * \brief Wraps a call with compiler_begin and compiler_end to indicate that - * the op of this call node will use external compiler. + * \brief Wraps an expr with compiler_begin and compiler_end to indicate that + * this expr should be handled by the external compiler. */ #include @@ -39,30 +39,23 @@ static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); // be handled by a specific compiler. class AnnotateTargetWrapper : public ExprMutator { public: - AnnotateTargetWrapper(const Array targets) { - for (auto target : targets) { - targets_.push_back(target.data()); - } - } - - Expr Annotate(const Expr& expr) { - auto new_expr = Mutate(expr); - //std::cerr << AsText(new_expr); - return new_expr; - } + explicit AnnotateTargetWrapper(Array targets) + : targets_(std::move(targets)) {} - /*! \brief This function 1) annotates a compiler end and a compiler begin to all arguments. - * The compiler end is based on the arg target while the compiler begin is based on the given - * target. If target is not given and all arguments are going to the same target, then we will - * use that target; otherwise we use default for this op. Note that all arg exprs must be - * available in op_expr_to_target before calling this function. + /*! + * \brief This function annotates a compiler end and a compiler begin to all arguments. + * + * The compiler end is based on the arg target while the compiler begin is based on the given + * target. If target is not given and all arguments are going to the same target, then we will + * use that target; otherwise we use default for this op. Note that all arg exprs must be + * available in op_expr_to_target before calling this function. * * \param args An array of arguments of the given node. * \param target The target of the current node. * \return A pair of target and annotated argument expressions. */ - std::pair> AnnotateArgs(const Array args, - const std::string target = "") { + std::pair> AnnotateArgs(const Array& args, + const std::string& target = "") { std::string ref_target = ""; Array compiler_ends; for (auto arg : args) { @@ -72,27 +65,26 @@ class AnnotateTargetWrapper : public ExprMutator { if (ref_target == "") { ref_target = arg_target; } else if (ref_target != arg_target) { - ref_target = "__inconsist__"; + ref_target = "default"; } } else { // Input vars. compiler_ends.push_back(arg); } } - ref_target = (ref_target == "__inconsist__") ? "default" : ref_target; // Determine compiler begin target. - std::string op_target = (target == "")? ref_target: target; + std::string op_target = (target == "") ? ref_target : target; Array compiler_begins; - for (auto end : compiler_ends) { + for (const auto& end : compiler_ends) { compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op)); } return {op_target, compiler_begins}; } - Expr InsertAnnotation(const Expr& expr, const std::string target, const PackedFunc* ann_op) { + Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) { Expr new_op = (*ann_op)(expr, target); new_op->checked_type_ = expr->checked_type_; return new_op; @@ -106,8 +98,8 @@ class AnnotateTargetWrapper : public ExprMutator { std::vector supported_targets; // Check which targets this op can be offloaded. - for (auto target : this->targets_) { - auto fannotate = Op::GetAttr("target." + target); + for (const auto& target : this->targets_) { + auto fannotate = Op::GetAttr("target." + std::string(target)); if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) { supported_targets.push_back(target); } @@ -120,18 +112,12 @@ class AnnotateTargetWrapper : public ExprMutator { std::string target = supported_targets[0]; // Visit and mutate arguments after the target of this op has been determined. - auto new_e = ExprMutator::VisitExpr_(cn); - Call call = Downcast(new_e); + auto new_call = Downcast(ExprMutator::VisitExpr_(cn)); // Add annotations to each arg. - auto target_n_args = AnnotateArgs(call->args, target); + auto target_n_args = AnnotateArgs(new_call->args, target); Array compiler_begins = std::get<1>(target_n_args); - // for (auto b : compiler_begins) { - // std::cerr << AsText(b); - // std::cerr << "===============\n"; - // } - //std::cerr << "*********************************************\n"; - call = Call(call->op, compiler_begins, call->attrs); + Call call = Call(new_call->op, compiler_begins, new_call->attrs); call->checked_type_ = cn->checked_type_; // Update the target map. @@ -243,7 +229,7 @@ class AnnotateTargetWrapper : public ExprMutator { private: /*! \brief The target backends for annotation. */ - std::vector targets_; + Array targets_; /*! \brief Maintain the decision of the target for each op expr. */ std::unordered_map op_expr_to_target_; const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); @@ -251,7 +237,7 @@ class AnnotateTargetWrapper : public ExprMutator { }; Expr AnnotateTarget(const Expr& expr, const Array targets) { - return AnnotateTargetWrapper(targets).Annotate(expr); + return AnnotateTargetWrapper(targets).Mutate(expr); } } // namespace annotate_target diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 72ba21f35075..375bae95f92f 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -456,6 +456,7 @@ def test_extern_dnnl_mobilenet(): mod, params = relay.testing.mobilenet.get_workload( batch_size=1, dtype='float32') + mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) mod = transform.AnnotateTarget(["dnnl"])(mod) mod = transform.MergeCompilerRegions()(mod) mod = transform.PartitionGraph()(mod) From 5c9adff3f8f0ba616cd69a3c86638fe9600e0ba0 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 4 Apr 2020 01:09:25 +0000 Subject: [PATCH 08/17] revert fusion change --- src/relay/transforms/fuse_ops.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index fd07d925123c..f646042962f0 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -199,9 +199,6 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // Post order tree void VisitExpr_(const FunctionNode* op) final { - if (op->GetAttr(attr::kCompiler).defined()) { - return; - } for (auto param : op->params) { this->Update(param, nullptr, kOpaque); } From 8e3e0141a4327c30354d5ac844e81074a02a20bc Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Sat, 4 Apr 2020 23:58:20 +0000 Subject: [PATCH 09/17] style fix --- src/relay/transforms/annotate_target.cc | 47 ++++++++++--------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 01479f57062f..c7fcd67b6b38 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -33,7 +33,8 @@ namespace tvm { namespace relay { namespace annotate_target { -static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); +const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); +const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. @@ -90,7 +91,7 @@ class AnnotateTargetWrapper : public ExprMutator { return new_op; } - Expr VisitExpr_(const CallNode* cn) { + Expr VisitExpr_(const CallNode* cn) final { Op op = Downcast(cn->op); CHECK(op.defined()); @@ -126,7 +127,7 @@ class AnnotateTargetWrapper : public ExprMutator { return std::move(call); } - Expr VisitExpr_(const TupleNode* op) { + Expr VisitExpr_(const TupleNode* op) final { auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast(new_e); @@ -136,19 +137,17 @@ class AnnotateTargetWrapper : public ExprMutator { return new_expr; } - Expr VisitExpr_(const TupleGetItemNode* op) { + Expr VisitExpr_(const TupleGetItemNode* op) final { auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast(new_e); auto target_n_args = AnnotateArgs(Array({expr->tuple})); - - std::string target = std::get<0>(target_n_args); auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return new_expr; } - Expr VisitExpr_(const FunctionNode* fn) { + Expr VisitExpr_(const FunctionNode* fn) final { Function func; Expr new_body; // don't step into composite functions @@ -167,34 +166,29 @@ class AnnotateTargetWrapper : public ExprMutator { return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); } - Expr VisitExpr_(const LetNode* op) { + Expr VisitExpr_(const LetNode* op) final { auto new_e = ExprMutator::VisitExpr_(op); - auto expr = Downcast(new_e); - - std::vector args = {expr->value, expr->body}; - auto target_n_args = AnnotateArgs(Array(args)); + auto let = Downcast(new_e); - std::string target = std::get<0>(target_n_args); - auto new_expr = Let(expr->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); + auto target_n_args = AnnotateArgs({let->value, let->body}); + auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return new_expr; } - Expr VisitExpr_(const IfNode* op) { + Expr VisitExpr_(const IfNode* op) final { auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast(new_e); - std::vector args = {expr->cond, expr->true_branch, expr->false_branch}; - auto target_n_args = AnnotateArgs(Array(args)); - - std::string target = std::get<0>(target_n_args); + auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch}); + CHECK_EQ(std::get<1>(target_n_args).size(), 3U); auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1], std::get<1>(target_n_args)[2]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return new_expr; } - Expr VisitExpr_(const RefCreateNode* op) { + Expr VisitExpr_(const RefCreateNode* op) final { auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast(new_e); @@ -204,24 +198,21 @@ class AnnotateTargetWrapper : public ExprMutator { return new_expr; } - Expr VisitExpr_(const RefReadNode* op) { + Expr VisitExpr_(const RefReadNode* op) final { auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast(new_e); auto target_n_args = AnnotateArgs(Array({expr->ref})); - auto new_expr = RefRead(std::get<1>(target_n_args)[0]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return new_expr; } - Expr VisitExpr_(const RefWriteNode* op) { + Expr VisitExpr_(const RefWriteNode* op) final { auto new_e = ExprMutator::VisitExpr_(op); auto expr = Downcast(new_e); auto target_n_args = AnnotateArgs(Array({expr->ref, expr->value})); - - std::string target = std::get<0>(target_n_args); auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); return new_expr; @@ -232,11 +223,9 @@ class AnnotateTargetWrapper : public ExprMutator { Array targets_; /*! \brief Maintain the decision of the target for each op expr. */ std::unordered_map op_expr_to_target_; - const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); - const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); }; -Expr AnnotateTarget(const Expr& expr, const Array targets) { +Expr AnnotateTarget(const Expr& expr, const Array& targets) { return AnnotateTargetWrapper(targets).Mutate(expr); } @@ -244,7 +233,7 @@ Expr AnnotateTarget(const Expr& expr, const Array targets) { namespace transform { -Pass AnnotateTarget(const Array targets) { +Pass AnnotateTarget(const Array& targets) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(relay::annotate_target::AnnotateTarget(f, targets)); From cfb7608f227d45fb9191be874b07b6915bbfeaa3 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 6 Apr 2020 21:31:33 +0000 Subject: [PATCH 10/17] Refactor merge region pass --- .../transforms/merge_compiler_regions.cc | 138 +++++++++--------- .../relay/test_pass_merge_compiler_regions.py | 31 ++-- .../python/relay/test_pass_partition_graph.py | 1 + 3 files changed, 92 insertions(+), 78 deletions(-) diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index 0024fc003ccc..c13caf76b311 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -53,38 +53,6 @@ namespace merge_compiler_region { static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); -class MergeAnnotations : public ExprMutator { - public: - explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} - - Expr VisitExpr_(const CallNode* call) final { - // remove 'default' annotations - auto attrs = call->attrs.as(); - if (attrs != nullptr && attrs->compiler == "default") { - return VisitExpr(call->args[0]); - } - // Merge annotations which are now internal to a region. - // This happens if we see a compiler begin next to a - // compiler end and they're both in the same region. - if (call->op == compiler_begin_op) { - if (call->args[0]->IsInstance()) { - auto arg = Downcast(call->args[0]); - if (arg->op == compiler_end_op) { - auto region1 = regions_->GetRegion(GetRef(call)); - auto region2 = regions_->GetRegion(arg); - if (region1 == region2) { - return VisitExpr(arg->args[0]); - } - } - } - } - return ExprMutator::VisitExpr_(call); - } - - private: - AnnotatedRegionSet regions_; -}; - class RegionMerger : public ExprVisitor { public: explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {} @@ -92,62 +60,75 @@ class RegionMerger : public ExprVisitor { void VisitExpr_(const CallNode* call) final { if (call->op == compiler_end_op) { auto region = regions_->GetRegion(GetRef(call)); - if (merged_regions_.find(region->GetID()) != merged_regions_.end()) return; - // set the region target + + // Skip this region if it has been merged to the other region. + if (merged_regions_.find(region->GetID()) != merged_regions_.end()) { + return; + } + + // Check the region target. auto compiler_attrs = call->attrs.as(); - region_targets_[region->GetID()] = compiler_attrs->compiler; - // first look at the region args to determine the parent regions + CHECK_EQ(region->GetTarget(), compiler_attrs->compiler); + + // Visit the unmerged parent regions. for (const auto& arg : region->GetInputs()) { - // all args should be begin annotations + // Region inputs must be begin annotation, and the region of + // the begin annotation's argument is the parent region. auto begin = Downcast(arg); CHECK_EQ(begin->op, compiler_begin_op); - // the arguments of the begin annotations will be in the parent regions auto parent_region = regions_->GetRegion(begin->args[0]); - // if there is no parent region, move on - if (!parent_region.defined()) continue; - // merge the parent region if it hasn't been done already - if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { + + // Skip this region if it has been merged. + if (!parent_region.defined()) { + continue; + } else if (merged_regions_.find(parent_region->GetID()) == merged_regions_.end()) { VisitExpr(begin->args[0]); } } - // get the mergeable regions now all the parents have been visited + + // Collect unmerged parent regions. std::unordered_set mergeable_regions; for (const auto& arg : region->GetInputs()) { auto begin = Downcast(arg); CHECK_EQ(begin->op, compiler_begin_op); auto parent_region = regions_->GetRegion(begin->args[0]); - if (!parent_region.defined()) continue; + if (!parent_region.defined()) { + continue; + } mergeable_regions.insert(parent_region); } + + // Propogate all the parent restrictions to the current region. auto& region_restrictions = region_restrictions_[region->GetID()]; for (const auto& parent_region : mergeable_regions) { - // add all the parent restrictions to the current region auto parent_restrictions = region_restrictions_[parent_region->GetID()]; region_restrictions.insert(parent_restrictions.begin(), parent_restrictions.end()); } + for (const auto& parent_region : mergeable_regions) { - bool merged = false; - // check the parent region has the same target - if (region_targets_[parent_region->GetID()] == compiler_attrs->compiler) { - // check the parent region isn't in the restrictions - if (region_restrictions.find(parent_region->GetID()) == region_restrictions.end()) { - // merge the parent region into the current region - regions_->MergeRegions(parent_region, region); - // update the restrictions of all other regions to reflect the - // change in id - for (const auto& r : regions_) { - auto& restrictions = region_restrictions_[r->GetID()]; - if (restrictions.find(parent_region->GetID()) != restrictions.end()) { - restrictions.erase(parent_region->GetID()); - restrictions.insert(region->GetID()); - } - } - merged = true; + // Skip the parent region with a different target. + if (parent_region->GetTarget() != compiler_attrs->compiler) { + region_restrictions.insert(parent_region->GetID()); + continue; + } + + // Skip the parent region if it is in the restriction set. + if (region_restrictions.find(parent_region->GetID()) != region_restrictions.end()) { + continue; + } + + // Merge the parent region to the current one. + regions_->MergeRegions(parent_region, region); + + // Replace the parent region ID with the current region for all + // other regions' restriction sets. + for (const auto& r : regions_) { + auto& restrictions = region_restrictions_[r->GetID()]; + if (restrictions.find(parent_region->GetID()) != restrictions.end()) { + restrictions.erase(parent_region->GetID()); + restrictions.insert(region->GetID()); } } - // if the parent wasn't merged, add it as a restriction to the current - // region - if (!merged) region_restrictions.insert(parent_region->GetID()); } merged_regions_.insert(region->GetID()); } @@ -157,8 +138,31 @@ class RegionMerger : public ExprVisitor { private: AnnotatedRegionSet regions_; std::unordered_set merged_regions_; - std::map> region_restrictions_; - std::map region_targets_; + std::unordered_map> region_restrictions_; +}; +class MergeAnnotations : public ExprMutator { + public: + explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} + + Expr VisitExpr_(const CallNode* call) final { + // Merge annotations which are now internal to a region. + // This happens if we see a compiler begin next to a + // compiler end and they're both in the same region. + if (call->op == compiler_begin_op && call->args[0]->IsInstance()) { + auto arg = Downcast(call->args[0]); + if (arg->op == compiler_end_op) { + auto region1 = regions_->GetRegion(GetRef(call)); + auto region2 = regions_->GetRegion(arg); + if (region1 == region2) { + return VisitExpr(arg->args[0]); + } + } + } + return ExprMutator::VisitExpr_(call); + } + + private: + AnnotatedRegionSet regions_; }; Expr MergeCompilerRegions(const Expr& expr) { diff --git a/tests/python/relay/test_pass_merge_compiler_regions.py b/tests/python/relay/test_pass_merge_compiler_regions.py index efa99da2c274..7d7db357602e 100644 --- a/tests/python/relay/test_pass_merge_compiler_regions.py +++ b/tests/python/relay/test_pass_merge_compiler_regions.py @@ -69,14 +69,16 @@ def expected(): O_2 = relay.nn.relu(O_1) ce_3 = compiler_end(O_2, "test") - X = relay.tanh(ce_2) + cb_3 = compiler_begin(ce_2, "default") + X = relay.tanh(cb_3) + ce_4 = compiler_end(X, "default") - cb_3 = compiler_begin(ce_3, "test") - cb_4 = compiler_begin(X, "test") - O_3 = relay.add(cb_3, cb_4) - ce_4 = compiler_end(O_3, "test") + cb_4 = compiler_begin(ce_3, "test") + cb_5 = compiler_begin(ce_4, "test") + O_3 = relay.add(cb_4, cb_5) + ce_5 = compiler_end(O_3, "test") - func = relay.Function([data], ce_4) + func = relay.Function([data], ce_5) return func result = run_opt_pass(diamond_graph_fanouts(), relay.transform.MergeCompilerRegions()) @@ -171,20 +173,27 @@ def expected(): node1 = relay.add(begin2, begin3) node2 = relay.add(node0, node1) - node3 = relay.subtract(in_5, in_6) - node4 = relay.subtract(in_7, node3) + dbegin0 = compiler_begin(in_5, "default") + dbegin1 = compiler_begin(in_6, "default") + dbegin2 = compiler_begin(in_7, "default") + node3 = relay.subtract(dbegin0, dbegin1) + node4 = relay.subtract(dbegin2, node3) + dend0 = compiler_end(node4, "default") - begin4 = compiler_begin(node4, "test") + begin4 = compiler_begin(dend0, "test") begin5 = compiler_begin(in_9, "test") node5 = relay.add(node2, begin4) end1 = compiler_end(node5, "test") - node6 = relay.subtract(in_8, end1) + dbegin4 = compiler_begin(end1, "default") + dbegin5 = compiler_begin(in_8, "default") + node6 = relay.subtract(dbegin5, dbegin4) + dend1 = compiler_end(node6, "default") node7 = relay.add(begin5, node5) end2 = compiler_end(node7, "test") begin6 = compiler_begin(end2, "test") - begin7 = compiler_begin(node6, "test") + begin7 = compiler_begin(dend1, "test") node8 = relay.add(begin7, begin6) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 375bae95f92f..fb216822c295 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -858,6 +858,7 @@ def expected(): test_extern_ccompiler_default_ops() test_extern_ccompiler() test_extern_dnnl() + # TODO(@comaniac, @zhiics): Fix constant node and re-open this case. #test_extern_dnnl_mobilenet() test_function_lifting() test_function_lifting_inline() From 9e869c49a7b32021934ae1236e7d129cf2e786ac Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 6 Apr 2020 21:33:26 +0000 Subject: [PATCH 11/17] format --- src/relay/transforms/annotate_target.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index c7fcd67b6b38..24e771d8244f 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -40,8 +40,7 @@ const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.com // be handled by a specific compiler. class AnnotateTargetWrapper : public ExprMutator { public: - explicit AnnotateTargetWrapper(Array targets) - : targets_(std::move(targets)) {} + explicit AnnotateTargetWrapper(Array targets) : targets_(std::move(targets)) {} /*! * \brief This function annotates a compiler end and a compiler begin to all arguments. @@ -243,8 +242,7 @@ Pass AnnotateTarget(const Array& targets) { return transform::Sequential({func_pass, InferType()}, "AnnotateTarget"); } -TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget") -.set_body_typed(AnnotateTarget); +TVM_REGISTER_GLOBAL("relay._transform.AnnotateTarget").set_body_typed(AnnotateTarget); } // namespace transform From 451540304fc10cfaa0ca85c8b5e565b4bef3f8dd Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Tue, 7 Apr 2020 00:51:19 +0000 Subject: [PATCH 12/17] minor fix --- src/relay/transforms/merge_compiler_regions.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/relay/transforms/merge_compiler_regions.cc b/src/relay/transforms/merge_compiler_regions.cc index c13caf76b311..601be0f96bc4 100644 --- a/src/relay/transforms/merge_compiler_regions.cc +++ b/src/relay/transforms/merge_compiler_regions.cc @@ -92,10 +92,9 @@ class RegionMerger : public ExprVisitor { auto begin = Downcast(arg); CHECK_EQ(begin->op, compiler_begin_op); auto parent_region = regions_->GetRegion(begin->args[0]); - if (!parent_region.defined()) { - continue; + if (parent_region.defined()) { + mergeable_regions.insert(parent_region); } - mergeable_regions.insert(parent_region); } // Propogate all the parent restrictions to the current region. @@ -140,6 +139,7 @@ class RegionMerger : public ExprVisitor { std::unordered_set merged_regions_; std::unordered_map> region_restrictions_; }; + class MergeAnnotations : public ExprMutator { public: explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {} @@ -175,8 +175,7 @@ Expr MergeCompilerRegions(const Expr& expr) { // Remove annotations that are not in the region boundaries. MergeAnnotations merge_anno(regions); - auto new_expr = merge_anno.Mutate(expr); - return new_expr; + return merge_anno.Mutate(expr); } } // namespace merge_compiler_region From 41147a9d42aaee5b3c69521b4c50dc6e310f2008 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 7 Apr 2020 18:27:32 +0000 Subject: [PATCH 13/17] Skip e2e test --- src/relay/transforms/annotate_target.cc | 35 +++++++++++++++---- .../python/relay/test_pass_annotate_target.py | 4 +++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 24e771d8244f..012d192bfcff 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -91,17 +91,38 @@ class AnnotateTargetWrapper : public ExprMutator { } Expr VisitExpr_(const CallNode* cn) final { - Op op = Downcast(cn->op); - CHECK(op.defined()); - // Supported targets for this node. The order implies the priority. std::vector supported_targets; // Check which targets this op can be offloaded. - for (const auto& target : this->targets_) { - auto fannotate = Op::GetAttr("target." + std::string(target)); - if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) { - supported_targets.push_back(target); + if (cn->op->IsInstance()) { + // TVM operators: Check target specific op checking function and add to supported_targets + // if it is supported. + Op op = Downcast(cn->op); + CHECK(op.defined()); + for (const auto& target : this->targets_) { + auto fannotate = Op::GetAttr("target." + std::string(target)); + if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) { + supported_targets.push_back(target); + } + } + } else if (cn->op->IsInstance()) { + // Composite function: Add the target of a composite function to supported_targets + // if it is in the target list. + Function func = Downcast(cn->op); + CHECK(func.defined()); + auto comp_name = func->GetAttr(attr::kComposite); + if (comp_name.defined()) { + size_t i = comp_name->value.find('.'); + if (i != std::string::npos) { + std::string comp_target = comp_name->value.substr(0, i); + for (const auto& target : this->targets_) { + if (std::string(target) == comp_target) { + supported_targets.push_back(comp_target); + break; + } + } + } } } supported_targets.push_back("default"); // Make default as the last option. diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 638b79a991b0..49c12484b793 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -339,8 +339,12 @@ def after(): if __name__ == "__main__": test_extern_dnnl() +<<<<<<< HEAD #test_extern_dnnl_mobilenet() +======= +>>>>>>> Skip e2e test test_composite_function() + #test_extern_dnnl_mobilenet() test_multiple_ends() test_type_propagation() test_tuple() From 4a7bf3acb4b61d75787883e55f1360bb0849f3d6 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 8 Apr 2020 17:12:37 +0000 Subject: [PATCH 14/17] lint --- src/relay/analysis/annotated_region_set.cc | 3 +-- src/relay/transforms/annotate_target.cc | 14 +++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index b6e1dfc5223a..c70ef8ae242e 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -125,8 +125,7 @@ class AnnotatedRegionSet::Creator : public ExprVisitor { // Create a new region if the argument is not belonged to any regions yet. region = region_set_->MakeRegion(target); region->nodes_.insert(call->args[0]); - } - else { + } else { // If the argument is belonged to a region, it must have the same target. // Otherwise we should see a region_begin op. CHECK_EQ(region->GetTarget(), target); diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 012d192bfcff..cae3b639a973 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -154,7 +154,7 @@ class AnnotateTargetWrapper : public ExprMutator { auto target_n_args = AnnotateArgs(expr->fields); auto new_expr = Tuple(std::get<1>(target_n_args)); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); - return new_expr; + return std::move(new_expr); } Expr VisitExpr_(const TupleGetItemNode* op) final { @@ -164,7 +164,7 @@ class AnnotateTargetWrapper : public ExprMutator { auto target_n_args = AnnotateArgs(Array({expr->tuple})); auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); - return new_expr; + return std::move(new_expr); } Expr VisitExpr_(const FunctionNode* fn) final { @@ -193,7 +193,7 @@ class AnnotateTargetWrapper : public ExprMutator { auto target_n_args = AnnotateArgs({let->value, let->body}); auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); - return new_expr; + return std::move(new_expr); } Expr VisitExpr_(const IfNode* op) final { @@ -205,7 +205,7 @@ class AnnotateTargetWrapper : public ExprMutator { auto new_expr = If(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1], std::get<1>(target_n_args)[2]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); - return new_expr; + return std::move(new_expr); } Expr VisitExpr_(const RefCreateNode* op) final { @@ -215,7 +215,7 @@ class AnnotateTargetWrapper : public ExprMutator { auto target_n_args = AnnotateArgs(Array({expr->value})); auto new_expr = RefCreate(std::get<1>(target_n_args)[0]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); - return new_expr; + return std::move(new_expr); } Expr VisitExpr_(const RefReadNode* op) final { @@ -225,7 +225,7 @@ class AnnotateTargetWrapper : public ExprMutator { auto target_n_args = AnnotateArgs(Array({expr->ref})); auto new_expr = RefRead(std::get<1>(target_n_args)[0]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); - return new_expr; + return std::move(new_expr); } Expr VisitExpr_(const RefWriteNode* op) final { @@ -235,7 +235,7 @@ class AnnotateTargetWrapper : public ExprMutator { auto target_n_args = AnnotateArgs(Array({expr->ref, expr->value})); auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]); op_expr_to_target_[new_expr] = std::get<0>(target_n_args); - return new_expr; + return std::move(new_expr); } private: From 3da42ed2dba38d0fa41b18bc7b96a0eb89cd0115 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 8 Apr 2020 20:25:42 +0000 Subject: [PATCH 15/17] support AnnotateTarget multiple runs --- src/relay/transforms/annotate_target.cc | 68 +++++++++++++++---- .../python/relay/test_pass_annotate_target.py | 27 ++++++++ 2 files changed, 83 insertions(+), 12 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index cae3b639a973..34d45d0e389f 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -33,8 +33,12 @@ namespace tvm { namespace relay { namespace annotate_target { -const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); -const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); +static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin"); +static const Op& compiler_end_op = Op::Get("annotation.compiler_end"); + +const PackedFunc* make_begin_op = + runtime::Registry::Get("relay.op.annotation._make.compiler_begin"); +const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end"); // A helper class to insert annotation boundaries for a program region that will // be handled by a specific compiler. @@ -59,18 +63,32 @@ class AnnotateTargetWrapper : public ExprMutator { std::string ref_target = ""; Array compiler_ends; for (auto arg : args) { - if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { - std::string arg_target = op_expr_to_target_[arg]; - compiler_ends.push_back(InsertAnnotation(arg, arg_target, end_op)); - if (ref_target == "") { - ref_target = arg_target; - } else if (ref_target != arg_target) { - ref_target = "default"; + std::string arg_target = "defualt"; + const CallNode* call = arg.as(); + + if (call && call->op == compiler_begin_op) { + // Argument is already compiler begin node meaning that this is not the first time + // running this pass, so we simply remove it and will add a new one later. + CHECK_EQ(call->args.size(), 1U); + const CallNode* end = call->args[0].as(); + if (end->op == compiler_end_op) { + arg_target = end->attrs.as()->compiler; } + compiler_ends.push_back(call->args[0]); + } else if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) { + arg_target = op_expr_to_target_[arg]; + compiler_ends.push_back(InsertAnnotation(arg, arg_target, make_end_op)); } else { // Input vars. compiler_ends.push_back(arg); } + + // Maintain reference target in case the target of the current node is unassigned. + if (ref_target == "") { + ref_target = arg_target; + } else if (ref_target != arg_target) { + ref_target = "default"; + } } // Determine compiler begin target. @@ -78,7 +96,7 @@ class AnnotateTargetWrapper : public ExprMutator { Array compiler_begins; for (const auto& end : compiler_ends) { - compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op)); + compiler_begins.push_back(InsertAnnotation(end, op_target, make_begin_op)); } return {op_target, compiler_begins}; @@ -94,8 +112,34 @@ class AnnotateTargetWrapper : public ExprMutator { // Supported targets for this node. The order implies the priority. std::vector supported_targets; + auto op_node = cn->op.as(); + + // This graph has annotations, meaning that this is not the first time running this pass. + if (op_node && cn->op == compiler_begin_op) { + // Bypass compiler begin due to lack of target information. It will be processed + // when the following op handling arguments. + CHECK_EQ(cn->args.size(), 1U); + return VisitExpr(cn->args[0]); + } else if (op_node && cn->op == compiler_end_op) { + // Override compiler end with the new target. + CHECK_EQ(cn->args.size(), 1U); + auto input_expr = VisitExpr(cn->args[0]); + CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end()); + return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op); + } + + // Peek the first argument. If it is compiler begin then this node had annotated by + // another target before, so we also consider that target as a supported target. + const CallNode* first_arg_call = cn->args[0].as(); + if (first_arg_call && first_arg_call->op == compiler_begin_op) { + std::string arg_target = first_arg_call->attrs.as()->compiler; + if (arg_target != "default") { + supported_targets.push_back(arg_target); + } + } + // Check which targets this op can be offloaded. - if (cn->op->IsInstance()) { + if (op_node) { // TVM operators: Check target specific op checking function and add to supported_targets // if it is supported. Op op = Downcast(cn->op); @@ -179,7 +223,7 @@ class AnnotateTargetWrapper : public ExprMutator { func = Downcast(new_e); new_body = func->body; if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) { - new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], end_op); + new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op); op_expr_to_target_[new_body] = op_expr_to_target_[func->body]; } } diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 49c12484b793..a147824429d6 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -337,6 +337,32 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_multiple_runs(): + @reg.register("nn.relu", "target.A") + def relu(attrs, args): # pylint: disable=unused-variable + return True + + @reg.register("add", "target.B") + def add(attrs, args): # pylint: disable=unused-variable + return True + + def before(): + x = relay.var("x", shape=(10, 5)) + a_1 = relay.nn.relu(x) + a_2 = relay.abs(a_1) + a_3 = relay.nn.relu(a_1) + out = relay.add(a_2, a_3) + + f = relay.Function([x], out) + mod = tvm.IRModule.from_expr(f) + return mod + + mod = transform.AnnotateTarget("A")(before()) + mod = transform.AnnotateTarget("B")(mod) + expected = transform.AnnotateTarget(["A", "B"])(before()) + assert tvm.ir.structural_equal(expected, mod) + + if __name__ == "__main__": test_extern_dnnl() <<<<<<< HEAD @@ -348,3 +374,4 @@ def after(): test_multiple_ends() test_type_propagation() test_tuple() + test_multiple_runs() From e2b7a7e5fedb29fcbd1b3a0873ed33eb5084ac87 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 9 Apr 2020 05:58:25 +0000 Subject: [PATCH 16/17] Add HasAttr and revert DNNL codegen --- python/tvm/relay/op/contrib/dnnl.py | 9 ++- src/relay/backend/contrib/dnnl/codegen.cc | 71 ++++++++--------------- src/relay/backend/vm/compiler.cc | 13 ++--- src/relay/transforms/annotate_target.cc | 3 + src/runtime/contrib/dnnl/dnnl.cc | 6 +- src/runtime/contrib/dnnl/dnnl_kernel.h | 4 +- 6 files changed, 45 insertions(+), 61 deletions(-) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index 45a8c8331f72..1aa71921806d 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -56,10 +56,17 @@ def _func_wrapper(attrs, args): return _func_wrapper -_register_external_op_helper("nn.batch_norm") _register_external_op_helper("nn.conv2d") _register_external_op_helper("nn.dense") _register_external_op_helper("nn.relu") _register_external_op_helper("add") _register_external_op_helper("subtract") _register_external_op_helper("multiply") + + +@reg.register("nn.batch_norm", "target.dnnl") +def batch_norm(attrs, args): + """Check if the external DNNL codegen should be used. + FIXME(@zhiics, @comaniac): Turn off due to not support of multiple outputs. + """ + return False diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index cd6412ce451a..73711749d9c4 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -53,19 +53,12 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { } void VisitExpr_(const TupleGetItemNode* op) final { - VisitExpr(op->tuple); - CHECK(out_.size() > static_cast(op->index)); - - // Only keep the item we want for the child node. - // FIXME(@comaniac): The other items should still be requried for the primary outputs. - auto item = out_[op->index]; - out_.clear(); - out_.push_back(item); + // Do nothing } void VisitExpr_(const CallNode* call) final { std::ostringstream decl_stream; - + std::ostringstream buf_stream; // Args: ID std::vector args; @@ -103,45 +96,20 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { } } - // Analyze the output buffers - std::vector out_types; - if (call->checked_type()->IsInstance()) { - auto type_node = call->checked_type().as(); - for (auto field : type_node->fields) { - CHECK(field->IsInstance()); - out_types.push_back(field); - } - } else if (call->checked_type()->IsInstance()) { - CHECK(call->checked_type()->IsInstance()); - out_types.push_back(call->checked_type()); - } else { - LOG(FATAL) << "Unrecognized type node: " << AsText(call->checked_type(), false); - } - - out_.clear(); - for (auto out_type : out_types) { - const auto& dtype = GetDtypeString(out_type.as()); - - std::string out = "buf_" + std::to_string(buf_idx_++); - auto out_shape = GetShape(out_type); - int out_size = 1; - for (size_t i = 0; i < out_shape.size(); ++i) { - out_size *= out_shape[i]; - } - this->PrintIndents(); - std::ostringstream buf_stream; - buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; - buf_decl_.push_back(buf_stream.str()); - decl_stream << ", " << out; - - // Update output buffer - Output output; - output.name = out; - output.dtype = dtype; - output.need_copy = true; - output.size = out_size; - out_.push_back(output); + // Analyze the output buffer + auto type_node = call->checked_type().as(); + CHECK(type_node); + const auto& dtype = GetDtypeString(type_node); + std::string out = "buf_" + std::to_string(buf_idx_++); + auto out_shape = GetShape(call->checked_type()); + int out_size = 1; + for (size_t i = 0; i < out_shape.size(); ++i) { + out_size *= out_shape[i]; } + this->PrintIndents(); + buf_stream << "float* " << out << " = (float*)std::malloc(4 * " << out_size << ");"; + buf_decl_.push_back(buf_stream.str()); + decl_stream << ", " << out; // Attach attribute arguments for (size_t i = 0; i < args.size(); ++i) { @@ -149,6 +117,15 @@ class CodegenDNNL : public ExprVisitor, public CodegenCBase { } decl_stream << ");"; ext_func_body.push_back(decl_stream.str()); + + // Update output buffer + out_.clear(); + Output output; + output.name = out; + output.dtype = dtype; + output.need_copy = true; + output.size = out_size; + out_.push_back(output); } std::string JIT(void) { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e2b0fffec8bd..3e020bb27954 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -924,13 +924,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::InlinePrimitives()); - // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); - // Compute away possibly introduced constant computation. - pass_seqs.push_back(transform::FoldConstant()); - // Fuse the shape functions. - pass_seqs.push_back(transform::FuseOps()); - // Inline the functions that are lifted to the module scope. We perform this // pass after all other optimization passes but before the memory allocation // pass. This is because memory allocation pass will insert `invoke_tvm_op` @@ -938,6 +931,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // external codegen. pass_seqs.push_back(transform::Inline()); + // Manifest the allocations. + pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); // Manifest the allocations needed for the shape functions. pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 34d45d0e389f..8abb41299857 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -145,6 +145,9 @@ class AnnotateTargetWrapper : public ExprMutator { Op op = Downcast(cn->op); CHECK(op.defined()); for (const auto& target : this->targets_) { + if (!Op::HasAttr("target." + std::string(target))) { + continue; + } auto fannotate = Op::GetAttr("target." + std::string(target)); if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) { supported_targets.push_back(target); diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 4dc023f5a512..cc430b2c7c76 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -169,11 +169,9 @@ extern "C" void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, read_from_dnnl_memory(out, dst_memory); } -extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, float* variance, - float* out, float* new_mean, float* new_variance, int p_N_, int p_C_, +extern "C" void dnnl_bn(float* data, float* gamma, float* beta, float* mean, + float* variance, float* out, int p_N_, int p_C_, int p_H_, int p_W_, int p_E_) { - // FIXME(@comaniac): BN has 3 outputs: out, new_mean and new_variance, but we do not update - // the rest two because no one cares about them for now. Should update it in the future. using tag = memory::format_tag; using dt = memory::data_type; diff --git a/src/runtime/contrib/dnnl/dnnl_kernel.h b/src/runtime/contrib/dnnl/dnnl_kernel.h index cf474f9e6843..4d0b100b92ec 100644 --- a/src/runtime/contrib/dnnl/dnnl_kernel.h +++ b/src/runtime/contrib/dnnl/dnnl_kernel.h @@ -44,8 +44,8 @@ extern "C" TVM_DLL void dnnl_dense(float* data, float* weight, float* out, int p extern "C" TVM_DLL void dnnl_relu(float* data, float* out, int p_N_, int p_C_, int p_H_, int p_W_); extern "C" TVM_DLL void dnnl_bn(float* data, float* gamma, float* beta, float* mean, - float* variance, float* out, float* new_mean, float* new_variance, - int p_n_, int p_c_, int p_h_, int p_w_, int p_e_); + float* variance, float* out, int p_n_, int p_c_, int p_h_, int p_w_, + int p_e_); extern "C" TVM_DLL void dnnl_add(float* data, float* weight, float* out, int p_n_, int p_c_, int p_h_, int p_w_); From 13cfa83dd8149df957f184e9d4d5f907fdf8ada8 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 9 Apr 2020 20:38:14 +0000 Subject: [PATCH 17/17] address comment --- src/relay/analysis/annotated_region_set.cc | 2 +- src/relay/analysis/annotated_region_set.h | 2 +- src/relay/transforms/annotate_target.cc | 7 ++++--- tests/python/relay/test_pass_annotate_target.py | 4 ---- 4 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/relay/analysis/annotated_region_set.cc b/src/relay/analysis/annotated_region_set.cc index c70ef8ae242e..94c7621e60af 100644 --- a/src/relay/analysis/annotated_region_set.cc +++ b/src/relay/analysis/annotated_region_set.cc @@ -79,7 +79,7 @@ void AnnotatedRegionSetNode::AddToRegion(AnnotatedRegion dest, const Expr& expr) } } -AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(std::string target) { +AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) { auto ret = regions_.emplace(AnnotatedRegion()); (*ret.first)->id_ = region_id_++; (*ret.first)->target_ = target; diff --git a/src/relay/analysis/annotated_region_set.h b/src/relay/analysis/annotated_region_set.h index cfd044e79776..3bd569387d46 100644 --- a/src/relay/analysis/annotated_region_set.h +++ b/src/relay/analysis/annotated_region_set.h @@ -197,7 +197,7 @@ class AnnotatedRegionSetNode : public Object { * * \return The new region. */ - AnnotatedRegion MakeRegion(std::string target); + AnnotatedRegion MakeRegion(const std::string& target); std::unordered_set regions_; /*! \brief The next region ID to assign. */ diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 8abb41299857..44d7b54e9637 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -158,11 +158,12 @@ class AnnotateTargetWrapper : public ExprMutator { // if it is in the target list. Function func = Downcast(cn->op); CHECK(func.defined()); - auto comp_name = func->GetAttr(attr::kComposite); + auto comp_name = func->GetAttr(attr::kComposite); if (comp_name.defined()) { - size_t i = comp_name->value.find('.'); + std::string comp_name_str = comp_name; + size_t i = comp_name_str.find('.'); if (i != std::string::npos) { - std::string comp_target = comp_name->value.substr(0, i); + std::string comp_target = comp_name_str.substr(0, i); for (const auto& target : this->targets_) { if (std::string(target) == comp_target) { supported_targets.push_back(comp_target); diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index a147824429d6..705a2614674a 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -365,10 +365,6 @@ def before(): if __name__ == "__main__": test_extern_dnnl() -<<<<<<< HEAD - #test_extern_dnnl_mobilenet() -======= ->>>>>>> Skip e2e test test_composite_function() #test_extern_dnnl_mobilenet() test_multiple_ends()