From 6e68fd9aff9f86412f8b7150b18ae1b374927f86 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Mar 2022 14:27:34 +0900 Subject: [PATCH 01/23] Decouple TE compute and schedule lowering in ScheduleBuilder --- src/relay/backend/te_compiler_cache.cc | 255 ++++++++++++++----------- 1 file changed, 141 insertions(+), 114 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index abab8cc6e0a02..a1de51de728d0 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -45,6 +45,7 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" +#include "tvm/relay/op_strategy.h" #include "utils.h" namespace tvm { @@ -115,99 +116,24 @@ Array GetShape(const Array& shape) { } // Construct a schedule for a given Relay primitive function and target. -class ScheduleBuilder : public backend::MemoizedExprTranslator> { +class LowerToTECompute : public backend::MemoizedExprTranslator> { public: - explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), - device_copy_op_(Op::Get("device_copy")), - create_schedule_(create_schedule) { - // Whether to use auto_scheduler schedule. - use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - use_meta_schedule_ = backend::IsMetaScheduleEnabled(); - } + explicit LowerToTECompute(Target target) + : target_(target), device_copy_op_(Op::Get("device_copy")) {} - CachedFunc Create(const Function& relay_func, std::function renamer) { - Array fn_inputs; + Array Lower(const Function& relay_func, + std::function renamer) { for (Var param : relay_func->params) { Array inputs; for (const auto& ttype : FlattenTupleType(param->checked_type())) { tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); - fn_inputs.push_back(tensor); inputs.push_back(tensor); + fn_inputs_.push_back(tensor); } memo_[param] = inputs; } readable_name_stream_ << "fused"; - auto outputs = this->VisitExpr(relay_func->body); - auto candidate_name = readable_name_stream_.str(); - constexpr static size_t kMaxFuncNameLength = 80; - // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME - // whenever the value of kMaxFuncNameLength changes - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hex << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - - // TODO(mbs): This should be the definitive global by which the PrimFunc is known and - // no other GlobalVar ctors should appear inside the lowering machinery. - auto prim_fn_var = GlobalVar(renamer(candidate_name)); - prim_fn_var->checked_type_ = relay_func->checked_type(); - - // Fusion over tupled results may leave identity relationships - // between inputs and outputs, and those should not be scheduled. - // Hence schedule only non PlaceholderOp outputs. - tvm::Array tensor_outs; - for (const auto& tensor : outputs) { - if (!tensor->op.as()) { - tensor_outs.push_back(tensor); - } - } - - te::Schedule schedule{nullptr}; - tir::PrimFunc prim_func{nullptr}; - // No need to register schedule for device copy op. 
- if (anchor_attrs_.as() == nullptr && create_schedule_) { - if (use_auto_scheduler_) { - const auto* fauto_schedule = - runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); - ICHECK(fauto_schedule != nullptr) - << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; - ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs); - if (obj.defined()) { - schedule = Downcast(obj); - } - } - if (use_meta_schedule_) { - prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); - Optional opt_mod_or_base_func = - meta_schedule::MetaScheduleContext::QueryInsideWithScope( - prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, - Array{IRModule({{prim_fn_var, prim_func}})}); - if (const auto* result = opt_mod_or_base_func.as()) { - prim_func = GetRef(result); - } else { - prim_func = tir::PrimFunc(nullptr); - } - } - - // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. - if (!schedule.defined() && !prim_func.defined()) { - ICHECK(anchor_implementation_.defined()); - schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); - } - if (schedule.defined()) { - for (const auto& scalar : scalars_) { - if (schedule->Contain(scalar)) { - schedule[scalar].compute_inline(); - } - } - } - } - - return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, - IRModule(Map({})), constant_tensors_); + return this->VisitExpr(relay_func->body); } Array VisitExpr_(const VarNode* op) final { @@ -254,7 +180,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator } Array VisitExpr_(const CallNode* call_node) final { - static auto fpattern = Op::GetAttrMap("TOpPattern"); static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); ICHECK(flower_call) << "relay.backend.lower_call is not registered."; @@ -278,28 +203,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); - Array outputs; - OpImplementation impl; // TODO(mbs): device_copy cleanup ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; + LoweredOutput lowered_out = (*flower_call)(GetRef(call_node), inputs, target_); - outputs = lowered_out->outputs; - impl = lowered_out->implementation; + Array outputs = lowered_out->outputs; + anchor_implementation_ = lowered_out->implementation; - if (create_schedule_) { - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - anchor_implementation_ = impl; - } - } if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); ICHECK(tuple_type) << "Expected output to be a tuple type " @@ -308,8 +218,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator ICHECK_EQ(tuple_type->fields.size(), outputs.size()); } - // TODO(mbs): device_copy cleanup - ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered"; readable_name_stream_ << '_' << op->name; return outputs; } @@ -347,27 +255,146 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator return 
{tuple[op->index]}; } + public: + // Additional outputs + Array fn_inputs_; + Array scalars_; + std::unordered_map constant_tensors_; + std::ostringstream readable_name_stream_; + OpImplementation anchor_implementation_; + + private: + tvm::Target target_; + // Index of the global constants + static int const_index; + // Cache device copy op for equivalence checking to reduce registry lookup + // overhead for each invocation of call node when retrieving schedules. + const Op& device_copy_op_; +}; + +int LowerToTECompute::const_index = 0; + +// Construct a schedule for a given Relay primitive function and target. +class ScheduleBuilder : ExprVisitor { + public: + explicit ScheduleBuilder(Target target, bool create_schedule = true) + : target_(target), + + create_schedule_(create_schedule) { + // Whether to use auto_scheduler schedule. + use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); + use_meta_schedule_ = backend::IsMetaScheduleEnabled(); + } + + CachedFunc Create(const Function& relay_func, std::function renamer) { + LowerToTECompute lower_te_compute(target_); + Array outputs = lower_te_compute.Lower(relay_func, renamer); + std::string candidate_name = lower_te_compute.readable_name_stream_.str(); + VisitExpr(relay_func->body); + + constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes + if (candidate_name.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hex << std::hash{}(candidate_name) << "_"; + candidate_name = truncated_name.str(); + } + + // TODO(mbs): This should be the definitive global by which the PrimFunc is known and + // no other GlobalVar ctors should appear inside the lowering machinery. + auto prim_fn_var = GlobalVar(renamer(candidate_name)); + prim_fn_var->checked_type_ = relay_func->checked_type(); + + // Fusion over tupled results may leave identity relationships + // between inputs and outputs, and those should not be scheduled. + // Hence schedule only non PlaceholderOp outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + + te::Schedule schedule{nullptr}; + tir::PrimFunc prim_func{nullptr}; + // No need to register schedule for device copy op. + if (anchor_attrs_.as() == nullptr && create_schedule_) { + if (use_auto_scheduler_) { + const auto* fauto_schedule = + runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); + ICHECK(fauto_schedule != nullptr) + << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; + ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs); + if (obj.defined()) { + schedule = Downcast(obj); + } + } + if (use_meta_schedule_) { + prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs); + Optional opt_mod_or_base_func = + meta_schedule::MetaScheduleContext::QueryInsideWithScope( + prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, + Array{IRModule({{prim_fn_var, prim_func}})}); + if (const auto* result = opt_mod_or_base_func.as()) { + prim_func = GetRef(result); + } else { + prim_func = tir::PrimFunc(nullptr); + } + } + + // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. 
+ if (!schedule.defined() && !prim_func.defined()) { + ICHECK(lower_te_compute.anchor_implementation_.defined()); + schedule = + lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + } + if (schedule.defined()) { + for (const auto& scalar : lower_te_compute.scalars_) { + if (schedule->Contain(scalar)) { + schedule[scalar].compute_inline(); + } + } + } + } + + return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule, + prim_func, {}, IRModule(Map({})), + lower_te_compute.constant_tensors_); + } + + void VisitExpr_(const CallNode* call_node) final { + static auto fpattern = Op::GetAttrMap("TOpPattern"); + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + if (create_schedule_) { + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; + } + } + } + private: tvm::Target target_; Op anchor_op_; Attrs anchor_attrs_; int anchor_op_pattern_{0}; - OpImplementation anchor_implementation_; - std::ostringstream readable_name_stream_; - Array scalars_; - std::unordered_map constant_tensors_; bool use_auto_scheduler_; bool use_meta_schedule_; - // Cache device copy op for equivalence checking to reduce registry lookup - // overhead for each invocation of call node when retrieving schedules. - const Op& device_copy_op_; bool create_schedule_; - // Index of the global constants - static int const_index; }; -int ScheduleBuilder::const_index = 0; - /*! * \brief Create schedule for target. * \param source_func The primitive function to be lowered. From eb1bc7e789b66eaf3d4fe01d5154c135ab275dc2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Mar 2022 18:13:42 +0900 Subject: [PATCH 02/23] fixed merge conflict --- src/relay/backend/te_compiler_cache.cc | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index a1de51de728d0..2c2042859ddb4 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -278,9 +278,7 @@ int LowerToTECompute::const_index = 0; class ScheduleBuilder : ExprVisitor { public: explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), - - create_schedule_(create_schedule) { + : target_(target), create_schedule_(create_schedule) { // Whether to use auto_scheduler schedule. 
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); use_meta_schedule_ = backend::IsMetaScheduleEnabled(); @@ -289,6 +287,7 @@ class ScheduleBuilder : ExprVisitor { CachedFunc Create(const Function& relay_func, std::function renamer) { LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); + Array fn_inputs = lower_te_compute.fn_inputs_; std::string candidate_name = lower_te_compute.readable_name_stream_.str(); VisitExpr(relay_func->body); @@ -332,7 +331,7 @@ class ScheduleBuilder : ExprVisitor { } } if (use_meta_schedule_) { - prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs); + prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); Optional opt_mod_or_base_func = meta_schedule::MetaScheduleContext::QueryInsideWithScope( prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, @@ -359,9 +358,8 @@ class ScheduleBuilder : ExprVisitor { } } - return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule, - prim_func, {}, IRModule(Map({})), - lower_te_compute.constant_tensors_); + return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {}, + IRModule(Map({})), lower_te_compute.constant_tensors_); } void VisitExpr_(const CallNode* call_node) final { From 4cd3a1657c4e2e13abe7281b7cdef5dff73b37ee Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 10 Mar 2022 18:43:15 +0900 Subject: [PATCH 03/23] removed create_schedule stuff --- src/relay/backend/te_compiler_cache.cc | 76 +++++++++++++------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 2c2042859ddb4..fc3a3ab335f4f 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -46,6 +46,7 @@ #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" #include "tvm/relay/op_strategy.h" +#include "tvm/tir/function.h" #include "utils.h" namespace tvm { @@ -115,7 +116,7 @@ Array GetShape(const Array& shape) { return res; } -// Construct a schedule for a given Relay primitive function and target. 
+// Lowers Relay primitive Function to TE Compute class LowerToTECompute : public backend::MemoizedExprTranslator> { public: explicit LowerToTECompute(Target target) @@ -133,7 +134,21 @@ class LowerToTECompute : public backend::MemoizedExprTranslatorVisitExpr(relay_func->body); + + Array outputs = this->VisitExpr(relay_func->body); + + candidate_name_ = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; + // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME + // whenever the value of kMaxFuncNameLength changes + if (candidate_name_.size() > kMaxFuncNameLength) { + std::stringstream truncated_name; + truncated_name << candidate_name_.substr(0, kMaxFuncNameLength); + truncated_name << "_" << std::hex << std::hash{}(candidate_name_) << "_"; + candidate_name_ = truncated_name.str(); + } + + return outputs; } Array VisitExpr_(const VarNode* op) final { @@ -260,11 +275,12 @@ class LowerToTECompute : public backend::MemoizedExprTranslator fn_inputs_; Array scalars_; std::unordered_map constant_tensors_; - std::ostringstream readable_name_stream_; + std::string candidate_name_; OpImplementation anchor_implementation_; private: tvm::Target target_; + std::ostringstream readable_name_stream_; // Index of the global constants static int const_index; // Cache device copy op for equivalence checking to reduce registry lookup @@ -277,33 +293,20 @@ int LowerToTECompute::const_index = 0; // Construct a schedule for a given Relay primitive function and target. class ScheduleBuilder : ExprVisitor { public: - explicit ScheduleBuilder(Target target, bool create_schedule = true) - : target_(target), create_schedule_(create_schedule) { + explicit ScheduleBuilder(Target target) : target_(target) { // Whether to use auto_scheduler schedule. use_auto_scheduler_ = backend::IsAutoSchedulerEnabled(); - use_meta_schedule_ = backend::IsMetaScheduleEnabled(); } CachedFunc Create(const Function& relay_func, std::function renamer) { LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); Array fn_inputs = lower_te_compute.fn_inputs_; - std::string candidate_name = lower_te_compute.readable_name_stream_.str(); VisitExpr(relay_func->body); - constexpr static size_t kMaxFuncNameLength = 80; - // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME - // whenever the value of kMaxFuncNameLength changes - if (candidate_name.size() > kMaxFuncNameLength) { - std::stringstream truncated_name; - truncated_name << candidate_name.substr(0, kMaxFuncNameLength); - truncated_name << "_" << std::hex << std::hash{}(candidate_name) << "_"; - candidate_name = truncated_name.str(); - } - // TODO(mbs): This should be the definitive global by which the PrimFunc is known and // no other GlobalVar ctors should appear inside the lowering machinery. - auto prim_fn_var = GlobalVar(renamer(candidate_name)); + auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_)); prim_fn_var->checked_type_ = relay_func->checked_type(); // Fusion over tupled results may leave identity relationships @@ -319,7 +322,7 @@ class ScheduleBuilder : ExprVisitor { te::Schedule schedule{nullptr}; tir::PrimFunc prim_func{nullptr}; // No need to register schedule for device copy op. 
- if (anchor_attrs_.as() == nullptr && create_schedule_) { + if (anchor_attrs_.as() == nullptr) { if (use_auto_scheduler_) { const auto* fauto_schedule = runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); @@ -330,7 +333,7 @@ class ScheduleBuilder : ExprVisitor { schedule = Downcast(obj); } } - if (use_meta_schedule_) { + if (backend::IsMetaScheduleEnabled()) { prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); Optional opt_mod_or_base_func = meta_schedule::MetaScheduleContext::QueryInsideWithScope( @@ -368,18 +371,16 @@ class ScheduleBuilder : ExprVisitor { ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); - if (create_schedule_) { - int op_pattern = fpattern[op]; - if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { - ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) - << "Cannot apply TOPI schedule to a primitive function with two complicated ops" - << " anchor=" << anchor_op_ << " current=" << op; - } - if (op_pattern >= anchor_op_pattern_) { - anchor_op_ = op; - anchor_attrs_ = call_node->attrs; - anchor_op_pattern_ = op_pattern; - } + int op_pattern = fpattern[op]; + if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { + ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) + << "Cannot apply TOPI schedule to a primitive function with two complicated ops" + << " anchor=" << anchor_op_ << " current=" << op; + } + if (op_pattern >= anchor_op_pattern_) { + anchor_op_ = op; + anchor_attrs_ = call_node->attrs; + anchor_op_pattern_ = op_pattern; } } @@ -389,8 +390,6 @@ class ScheduleBuilder : ExprVisitor { Attrs anchor_attrs_; int anchor_op_pattern_{0}; bool use_auto_scheduler_; - bool use_meta_schedule_; - bool create_schedule_; }; /*! 
@@ -775,9 +774,12 @@ std::string GetUniqueName(std::string name, std::unordered_map } TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { - return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) { - return name; - }); + auto tgt = tvm::Target("ext_dev"); + LowerToTECompute lower_te_compute(tgt); + auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; }); + return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_, + outputs, te::Schedule(), tir::PrimFunc(), {}, + IRModule(Map({})), lower_te_compute.constant_tensors_); }); } // namespace tec From 0c6d4a603335ae2cba2771e939eff1ddeb98fbe3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 10:45:08 +0900 Subject: [PATCH 04/23] add public, fix include path convention --- src/relay/backend/te_compiler_cache.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index fc3a3ab335f4f..276c7f9f017e4 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -28,11 +28,13 @@ #include #include #include +#include #include #include #include #include #include +#include #include #include @@ -45,8 +47,6 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" -#include "tvm/relay/op_strategy.h" -#include "tvm/tir/function.h" #include "utils.h" namespace tvm { @@ -138,6 +138,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator outputs = this->VisitExpr(relay_func->body); candidate_name_ = readable_name_stream_.str(); + constexpr static size_t kMaxFuncNameLength = 80; // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME // whenever the value of kMaxFuncNameLength changes @@ -291,7 +292,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator Date: Fri, 11 Mar 2022 10:57:02 +0900 Subject: [PATCH 05/23] Forgot visiting arg in ScheduleBuilder CallNode vsit --- src/relay/backend/te_compiler_cache.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 276c7f9f017e4..74b9013b3659c 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -300,6 +300,7 @@ class ScheduleBuilder : public ExprVisitor { } CachedFunc Create(const Function& relay_func, std::function renamer) { + LOG(INFO) << relay_func; LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); Array fn_inputs = lower_te_compute.fn_inputs_; @@ -350,6 +351,8 @@ class ScheduleBuilder : public ExprVisitor { // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. 
if (!schedule.defined() && !prim_func.defined()) { ICHECK(lower_te_compute.anchor_implementation_.defined()); + LOG(INFO) << lower_te_compute.candidate_name_; + LOG(INFO) << anchor_attrs_; schedule = lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); } @@ -372,6 +375,10 @@ class ScheduleBuilder : public ExprVisitor { ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; Op op = Downcast(call_node->op); + for (Expr arg : call_node->args) { + VisitExpr(arg); + } + int op_pattern = fpattern[op]; if (!use_auto_scheduler_ && op_pattern >= kCommReduce) { ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce) From 6f019014a4614f43aefcf642981bfb15d64b09f3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 11:25:44 +0900 Subject: [PATCH 06/23] fixed anchor impl selection --- src/relay/backend/te_compiler_cache.cc | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 74b9013b3659c..ffcce6e1c8dab 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -224,7 +224,7 @@ class LowerToTECompute : public backend::MemoizedExprTranslator(call_node), inputs, target_); Array outputs = lowered_out->outputs; - anchor_implementation_ = lowered_out->implementation; + op_implementations_[op.operator->()] = lowered_out->implementation; if (outputs.size() != 1) { const auto* tuple_type = call_node->checked_type().as(); @@ -276,8 +276,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator fn_inputs_; Array scalars_; std::unordered_map constant_tensors_; + std::unordered_map op_implementations_; std::string candidate_name_; - OpImplementation anchor_implementation_; private: tvm::Target target_; @@ -300,7 +300,6 @@ class ScheduleBuilder : public ExprVisitor { } CachedFunc Create(const Function& relay_func, std::function renamer) { - LOG(INFO) << relay_func; LowerToTECompute lower_te_compute(target_); Array outputs = lower_te_compute.Lower(relay_func, renamer); Array fn_inputs = lower_te_compute.fn_inputs_; @@ -350,11 +349,9 @@ class ScheduleBuilder : public ExprVisitor { // Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule. 
if (!schedule.defined() && !prim_func.defined()) { - ICHECK(lower_te_compute.anchor_implementation_.defined()); - LOG(INFO) << lower_te_compute.candidate_name_; - LOG(INFO) << anchor_attrs_; - schedule = - lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_); + auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->()); + ICHECK(anchor_impl != lower_te_compute.op_implementations_.end()); + schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_); } if (schedule.defined()) { for (const auto& scalar : lower_te_compute.scalars_) { From 109187fc0463728cd44171389e8fc91fb0ac8cf9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 02:21:58 +0900 Subject: [PATCH 07/23] New relay backend for meta schedule task extraction --- python/tvm/meta_schedule/integration.py | 19 +++- .../backend/metaschedule_task_extraction.cc | 86 +++++++++++++++++++ src/relay/backend/te_compiler_cache.cc | 18 ++++ src/relay/backend/te_compiler_cache.h | 2 + 4 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 src/relay/backend/metaschedule_task_extraction.cc diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 26b01444e7527..d3f21a21493ae 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -18,7 +18,10 @@ from contextlib import contextmanager from typing import Callable, Dict, List, Optional, Union -from tvm._ffi import register_object +import numpy as np +import tvm.runtime.ndarray as nd + +from tvm._ffi import register_object, get_global_func from tvm.ir import IRModule, transform from tvm.relay import Any from tvm.relay import Function as RelayFunc @@ -230,6 +233,20 @@ def extract_task_from_relay( The tasks extracted from this network """ + extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask") + assert extract_task_func + + target = Target(target) if isinstance(target, str) else target + + for name, param in params.items(): + if isinstance(param, np.ndarray): + params[name] = nd.array(param) + + with transform.PassContext(opt_level=opt_level): + with target: + tasks = extract_task_func(mod, target, params) + return tasks + @contextmanager def _autotvm_silencer(): from tvm import autotvm # pylint: disable=import-outside-toplevel diff --git a/src/relay/backend/metaschedule_task_extraction.cc b/src/relay/backend/metaschedule_task_extraction.cc new file mode 100644 index 0000000000000..509ef6259e860 --- /dev/null +++ b/src/relay/backend/metaschedule_task_extraction.cc @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include + +#include "../../te/operation/create_primfunc.h" +#include "te_compiler_cache.h" +#include "utils.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace metaschedule { + +using meta_schedule::ExtractedTask; + +Array ExtractTask(IRModule mod, Target target, Map params) { + if (params.size()) { + std::unordered_map params_; + BaseFunc base_func = mod->Lookup("main"); + ICHECK(base_func->IsInstance()); + auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } + + Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); + pass_seqs.push_back(transform::FuseOps()); + + transform::Sequential seq(pass_seqs); + auto opt_mod = seq(std::move(mod)); + + Array tasks; + LOG(INFO) << opt_mod; + LOG(INFO) << opt_mod->Lookup("main"); + PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) { + if (exp->IsInstance()) { + Function relay_func = Downcast(exp); + if (relay_func->HasNonzeroAttr(attr::kPrimitive)) { + LOG(INFO) << relay_func; + Array outputs; + std::string fused_name; + std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func); + LOG(INFO) << fused_name; + LOG(INFO) << outputs; + auto prim_func = tir::CreatePrimFunc(outputs); + auto prim_fn_var = GlobalVar(fused_name); + auto relay_mod = IRModule({{prim_fn_var, relay_func}}); + auto tir_mod = IRModule({{prim_fn_var, prim_func}}); + tasks.push_back(ExtractedTask(prim_fn_var->name_hint, relay_mod, target, {tir_mod})); + } + } + }); + + return tasks; +} + +TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask") + .set_body_typed([](IRModule mod, Target target, Map params) { + return ExtractTask(mod, target, params); + }); + +} // namespace metaschedule +} // namespace backend +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index ffcce6e1c8dab..e364cf60f98d2 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -754,6 +754,24 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, return MakeShapeFunc().Create(prim_func, target, renamer); } +std::pair, std::string> LowerTECompute(Target target, const Function& relay_func, + bool return_inputs) { + LowerToTECompute lower_te_compute(target); + auto outputs = lower_te_compute.Lower(relay_func, [&](std::string name) { return name; }); + // Following ScheduleBuilder, remove placeholder ops from outputs. + tvm::Array tensor_outs; + for (const auto& tensor : outputs) { + if (!tensor->op.as()) { + tensor_outs.push_back(tensor); + } + } + if (return_inputs) { + return std::make_pair(Concat(lower_te_compute.fn_inputs_, tensor_outs), + lower_te_compute.candidate_name_); + } + return std::make_pair(tensor_outs, lower_te_compute.candidate_name_); +} + /*! * \brief Get unique name from name. * \param name The orginal name. diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 2ffca1aa6be72..d75e3d2ebbfcb 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -204,6 +204,8 @@ class CCacheValue : public ObjectRef { Array GetShape(const Array& shape); +std::pair, std::string> LowerTECompute(Target target, const Function& relay_func, bool return_inputs=true); + /*! * \brief Create schedule for target. 
* \param source_func The primitive function to be lowered. From efecceaea3958e184de7ef0ff6cb5f3988640afa Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 03:56:05 +0900 Subject: [PATCH 08/23] refactor param binding --- src/relay/backend/build_module.cc | 9 +-------- .../backend/metaschedule_task_extraction.cc | 6 +----- src/relay/backend/utils.h | 19 +++++++++++++++++++ src/relay/backend/vm/compiler.cc | 9 +-------- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 89ee61c83f7c7..87fe39c389f04 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -333,14 +333,7 @@ class RelayBuildModule : public runtime::ModuleNode { IRModule OptimizeImpl(IRModule relay_module) { ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler."; - if (!params_.empty()) { - ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; - GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); - Function main_func = Downcast(relay_module->Lookup(main_glb_var)); - auto new_main = BindParamsByName(main_func, params_); - IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite(); - relay_module_ptr->Update(main_glb_var, new_main); - } + backend::BindParamsInModule(relay_module, params_); Array pass_seqs = GetPassPrefix( /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); diff --git a/src/relay/backend/metaschedule_task_extraction.cc b/src/relay/backend/metaschedule_task_extraction.cc index 509ef6259e860..c0c7a525f3d3c 100644 --- a/src/relay/backend/metaschedule_task_extraction.cc +++ b/src/relay/backend/metaschedule_task_extraction.cc @@ -35,6 +35,7 @@ namespace metaschedule { using meta_schedule::ExtractedTask; Array ExtractTask(IRModule mod, Target target, Map params) { + // backend::BindParamsInModule(mod, params); if (params.size()) { std::unordered_map params_; BaseFunc base_func = mod->Lookup("main"); @@ -51,18 +52,13 @@ Array ExtractTask(IRModule mod, Target target, Map tasks; - LOG(INFO) << opt_mod; - LOG(INFO) << opt_mod->Lookup("main"); PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); if (relay_func->HasNonzeroAttr(attr::kPrimitive)) { - LOG(INFO) << relay_func; Array outputs; std::string fused_name; std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func); - LOG(INFO) << fused_name; - LOG(INFO) << outputs; auto prim_func = tir::CreatePrimFunc(outputs); auto prim_fn_var = GlobalVar(fused_name); auto relay_mod = IRModule({{prim_fn_var, relay_func}}); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 3b4d4c18de895..f15ae4765addd 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -417,6 +417,25 @@ inline relay::Function BindParamsByName( return ret; } +inline void BindParamsInModule(IRModule mod, + const std::unordered_map& params) { + if (!params.empty()) { + BaseFunc base_func = mod->Lookup("main"); + ICHECK(base_func->IsInstance()); + auto f = relay::backend::BindParamsByName(Downcast(base_func), params); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } +} + +inline void BindParamsInModule(IRModule mod, Map params) { + std::unordered_map params_tmp; + for (const auto& kv : params) { + params_tmp[kv.first] = kv.second->data; + } + BindParamsInModule(mod, params_tmp); +} + /*! 
* \brief Extract the shape from a Relay tensor type. * \param type The provided type. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e94919de7f20f..130fb09e7af17 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1034,14 +1034,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets, IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { VLOG_CONTEXT << "VM Optimize"; - if (params_.size()) { - BaseFunc base_func = mod->Lookup("main"); - ICHECK(base_func->IsInstance()) - << "VM compiler expects to compile relay::Function"; - auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); - } + backend::BindParamsInModule(mod, params_); Array pass_seqs = relay::backend::GetPassPrefix( /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true); From 4a5e4aae48a7bdc8c24c8f7ae7bd5484034837e4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 05:10:30 +0900 Subject: [PATCH 09/23] move BindParams function to cc file --- src/relay/backend/utils.cc | 50 ++++++++++++++++++++++++++++++++++++ src/relay/backend/utils.h | 52 ++++---------------------------------- 2 files changed, 55 insertions(+), 47 deletions(-) diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 7662018e4f71f..9883fe85c253d 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -308,6 +308,56 @@ std::vector ShapeToJSON(tvm::Array shape) { return ret; } +relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params) { + std::unordered_map name_dict; + std::unordered_set repeat_var; + for (auto arg : func->params) { + const auto& name = arg->name_hint(); + if (name_dict.count(name)) { + repeat_var.insert(name_dict[name]); + } else { + name_dict[name] = arg; + } + } + + std::unordered_map bind_dict; + for (auto& kv : params) { + if (name_dict.count(kv.first) == 0) { + continue; + } + auto arg = name_dict.at(kv.first); + if (repeat_var.count(arg)) { + LOG(FATAL) << "Multiple args in the function have name " << kv.first; + } + bind_dict[arg] = Constant(kv.second); + } + Expr bound_expr = relay::Bind(func, bind_dict); + Function ret = Downcast(bound_expr); + ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function." 
+ << "\n"; + return ret; +} + +void BindParamsInModule(IRModule mod, + const std::unordered_map& params) { + if (!params.empty()) { + BaseFunc base_func = mod->Lookup("main"); + ICHECK(base_func->IsInstance()); + auto f = relay::backend::BindParamsByName(Downcast(base_func), params); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } +} + +void BindParamsInModule(IRModule mod, Map params) { + std::unordered_map params_tmp; + for (const auto& kv : params) { + params_tmp[kv.first] = kv.second->data; + } + BindParamsInModule(mod, params_tmp); +} + } // namespace backend } // namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index f15ae4765addd..cfbf0a9007779 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -386,55 +386,13 @@ inline std::string DType2String(const tvm::DataType dtype) { * \param params params dict * \return relay::Function */ -inline relay::Function BindParamsByName( - relay::Function func, const std::unordered_map& params) { - std::unordered_map name_dict; - std::unordered_set repeat_var; - for (auto arg : func->params) { - const auto& name = arg->name_hint(); - if (name_dict.count(name)) { - repeat_var.insert(name_dict[name]); - } else { - name_dict[name] = arg; - } - } - - std::unordered_map bind_dict; - for (auto& kv : params) { - if (name_dict.count(kv.first) == 0) { - continue; - } - auto arg = name_dict.at(kv.first); - if (repeat_var.count(arg)) { - LOG(FATAL) << "Multiple args in the function have name " << kv.first; - } - bind_dict[arg] = Constant(kv.second); - } - Expr bound_expr = relay::Bind(func, bind_dict); - Function ret = Downcast(bound_expr); - ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function." - << "\n"; - return ret; -} +relay::Function BindParamsByName(relay::Function func, + const std::unordered_map& params); -inline void BindParamsInModule(IRModule mod, - const std::unordered_map& params) { - if (!params.empty()) { - BaseFunc base_func = mod->Lookup("main"); - ICHECK(base_func->IsInstance()); - auto f = relay::backend::BindParamsByName(Downcast(base_func), params); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); - } -} +void BindParamsInModule(IRModule mod, + const std::unordered_map& params); -inline void BindParamsInModule(IRModule mod, Map params) { - std::unordered_map params_tmp; - for (const auto& kv : params) { - params_tmp[kv.first] = kv.second->data; - } - BindParamsInModule(mod, params_tmp); -} +void BindParamsInModule(IRModule mod, Map params); /*! * \brief Extract the shape from a Relay tensor type. 
From f099537d3630d268ad0700c75e93bbdc67831837 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 05:10:44 +0900 Subject: [PATCH 10/23] remove unused stuff from python extract_tasks_from_relay --- python/tvm/meta_schedule/integration.py | 47 +++++-------------------- 1 file changed, 8 insertions(+), 39 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index d3f21a21493ae..e5d98624f7107 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -242,50 +242,19 @@ def extract_task_from_relay( if isinstance(param, np.ndarray): params[name] = nd.array(param) - with transform.PassContext(opt_level=opt_level): - with target: - tasks = extract_task_func(mod, target, params) - return tasks - - @contextmanager - def _autotvm_silencer(): - from tvm import autotvm # pylint: disable=import-outside-toplevel - - silent = autotvm.GLOBAL_SCOPE.silent - autotvm.GLOBAL_SCOPE.silent = True - try: - yield - finally: - autotvm.GLOBAL_SCOPE.silent = silent - - def _thread_run(func: Callable[[], None]) -> None: - import threading # pylint: disable=import-outside-toplevel - - thread = threading.Thread(target=func) - thread.start() - thread.join() - if disabled_pass is None: disabled_pass = [] - if pass_config is None: - pass_config = {"relay.backend.use_meta_schedule": True} - env = TaskExtraction() if isinstance(mod, RelayFunc): mod = IRModule.from_expr(mod) if not isinstance(target, Target): target = Target(target) - def _func(): - with env, _autotvm_silencer(), transform.PassContext( - config=pass_config, - disabled_pass=disabled_pass, - opt_level=opt_level, - ): - compiler = vm.VMCompiler() - if params: - compiler.set_params(params) - compiler.lower(mod, target) - - _thread_run(_func) - return env.tasks + with transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, + ): + with target: + tasks = extract_task_func(mod, target, params) + return tasks From 57f2882a5ed5615ef8eee96cd7284d495f908449 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 05:24:37 +0900 Subject: [PATCH 11/23] fixed constant param bind --- python/tvm/meta_schedule/integration.py | 13 +++++++------ src/relay/backend/metaschedule_task_extraction.cc | 10 +--------- 2 files changed, 8 insertions(+), 15 deletions(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index e5d98624f7107..2477aa11536cd 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -23,7 +23,7 @@ from tvm._ffi import register_object, get_global_func from tvm.ir import IRModule, transform -from tvm.relay import Any +from tvm.relay import Any, const from tvm.relay import Function as RelayFunc from tvm.relay import vm from tvm.runtime import NDArray, Object @@ -238,9 +238,11 @@ def extract_task_from_relay( target = Target(target) if isinstance(target, str) else target + relay_params = {} for name, param in params.items(): if isinstance(param, np.ndarray): - params[name] = nd.array(param) + param = nd.array(param) + relay_params[name] = const(param) if disabled_pass is None: disabled_pass = [] @@ -250,11 +252,10 @@ def extract_task_from_relay( if not isinstance(target, Target): target = Target(target) - with transform.PassContext( + with target, transform.PassContext( opt_level=opt_level, config=pass_config, disabled_pass=disabled_pass, ): - with target: - tasks = extract_task_func(mod, target, params) - 
return tasks + tasks = extract_task_func(mod, target, relay_params) + return tasks diff --git a/src/relay/backend/metaschedule_task_extraction.cc b/src/relay/backend/metaschedule_task_extraction.cc index c0c7a525f3d3c..1492a824bec55 100644 --- a/src/relay/backend/metaschedule_task_extraction.cc +++ b/src/relay/backend/metaschedule_task_extraction.cc @@ -35,15 +35,7 @@ namespace metaschedule { using meta_schedule::ExtractedTask; Array ExtractTask(IRModule mod, Target target, Map params) { - // backend::BindParamsInModule(mod, params); - if (params.size()) { - std::unordered_map params_; - BaseFunc base_func = mod->Lookup("main"); - ICHECK(base_func->IsInstance()); - auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); - } + backend::BindParamsInModule(mod, params); Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps()); From 3c5a3184fb42e69ef10619b05b9b9f128f7ea618 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 05:27:53 +0900 Subject: [PATCH 12/23] rename to task extraction --- .../{metaschedule_task_extraction.cc => task_extraction.cc} | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) rename src/relay/backend/{metaschedule_task_extraction.cc => task_extraction.cc} (97%) diff --git a/src/relay/backend/metaschedule_task_extraction.cc b/src/relay/backend/task_extraction.cc similarity index 97% rename from src/relay/backend/metaschedule_task_extraction.cc rename to src/relay/backend/task_extraction.cc index 1492a824bec55..ea77dca6841c9 100644 --- a/src/relay/backend/metaschedule_task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -30,6 +30,7 @@ namespace tvm { namespace relay { namespace backend { + namespace metaschedule { using meta_schedule::ExtractedTask; @@ -63,12 +64,13 @@ Array ExtractTask(IRModule mod, Target target, Map params) { - return ExtractTask(mod, target, params); + return metaschedule::ExtractTask(mod, target, params); }); -} // namespace metaschedule } // namespace backend } // namespace relay } // namespace tvm From 7b4d35eb00852db6397d43e0aa6b1fedabae3f63 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 05:42:56 +0900 Subject: [PATCH 13/23] add doc to util functions --- src/relay/backend/task_extraction.cc | 4 +++- src/relay/backend/te_compiler_cache.cc | 4 ++-- src/relay/backend/te_compiler_cache.h | 10 +++++++++- src/relay/backend/utils.h | 5 +++++ 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index ea77dca6841c9..d8ef33da3f9d4 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -38,6 +38,7 @@ using meta_schedule::ExtractedTask; Array ExtractTask(IRModule mod, Target target, Map params) { backend::BindParamsInModule(mod, params); + // is_vm=true for backward compatibility Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps()); @@ -51,7 +52,8 @@ Array ExtractTask(IRModule mod, Target target, MapHasNonzeroAttr(attr::kPrimitive)) { Array outputs; std::string fused_name; - std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func); + std::tie(outputs, fused_name) = + tec::LowerTECompute(relay_func, target, /*return_inputs*/ true); auto prim_func = tir::CreatePrimFunc(outputs); auto prim_fn_var = GlobalVar(fused_name); auto relay_mod 
= IRModule({{prim_fn_var, relay_func}}); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index e364cf60f98d2..b287eccf5454f 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -754,10 +754,10 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, return MakeShapeFunc().Create(prim_func, target, renamer); } -std::pair, std::string> LowerTECompute(Target target, const Function& relay_func, +std::pair, std::string> LowerTECompute(const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); - auto outputs = lower_te_compute.Lower(relay_func, [&](std::string name) { return name; }); + auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; }); // Following ScheduleBuilder, remove placeholder ops from outputs. tvm::Array tensor_outs; for (const auto& tensor : outputs) { diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index d75e3d2ebbfcb..0f4763d1257cc 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -204,7 +204,15 @@ class CCacheValue : public ObjectRef { Array GetShape(const Array& shape); -std::pair, std::string> LowerTECompute(Target target, const Function& relay_func, bool return_inputs=true); +/*! + * \brief Lowers Relay primitive Function to TE Compute + * \param source_func The primitive function to be lowered. + * \param target The target we want to create schedule for. + * \param return_inputs If true, prepend input tensors to the output array of tensors. + * \return Pair of schedule and fused function name. + */ +std::pair, std::string> LowerTECompute(const Function& source_func, Target target, + bool return_inputs = true); /*! * \brief Create schedule for target. diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index cfbf0a9007779..1a39d4330d454 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -389,6 +389,11 @@ inline std::string DType2String(const tvm::DataType dtype) { relay::Function BindParamsByName(relay::Function func, const std::unordered_map& params); +/*! + * \brief Bind params to the main function in Relay module, using BindParamsByName + * \param mod Relay module + * \param params params dict + */ void BindParamsInModule(IRModule mod, const std::unordered_map& params); From af3e98867f91f99522fee4da2e170dc87311466c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 07:36:35 +0900 Subject: [PATCH 14/23] Removed TaskExtraction node --- include/tvm/meta_schedule/integration.h | 32 ------------------------- python/tvm/meta_schedule/integration.py | 11 --------- src/meta_schedule/integration.cc | 23 ------------------ 3 files changed, 66 deletions(-) diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h index 9a8699b2fab94..bb384ba3ae837 100644 --- a/include/tvm/meta_schedule/integration.h +++ b/include/tvm/meta_schedule/integration.h @@ -145,38 +145,6 @@ class MetaScheduleContext : public runtime::ObjectRef { void ExitWithScope(); }; -/**************** TaskExtraction ****************/ - -/*! - * \brief An integration context for task extraction - */ -class TaskExtractionNode : public MetaScheduleContextNode { - public: - /*! 
\brief The extracted tasks */ - Array tasks{nullptr}; - - void VisitAttrs(AttrVisitor* v) { v->Visit("tasks", &tasks); } - - // Inherited from base class - Optional Query(runtime::String task_name, IRModule mod, Target target, - Optional> dispatched) final; - - static constexpr const char* _type_key = "meta_schedule.TaskExtraction"; - TVM_DECLARE_FINAL_OBJECT_INFO(TaskExtractionNode, MetaScheduleContextNode); -}; - -/*! - * \brief Managed reference to TaskExtractionNode - * \sa TaskExtractionNode - */ -class TaskExtraction : public MetaScheduleContext { - public: - /*! \brief The path to a cache file storing extracted tasks */ - TaskExtraction(); - TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TaskExtraction, MetaScheduleContext, - TaskExtractionNode); -}; - /**************** ApplyHistoryBest ****************/ /*! diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index 2477aa11536cd..b39f318e0f01c 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -179,17 +179,6 @@ def __exit__(self, ptype, value, trace) -> None: _ffi_api.MetaScheduleContextExitScope(self) # type: ignore # pylint: disable=no-member -@register_object("meta_schedule.TaskExtraction") -class TaskExtraction(MetaScheduleContext): - """An integration context for task extraction""" - - tasks: List[ExtractedTask] - """The extracted tasks""" - - def __init__(self) -> None: - self.__init_handle_by_constructor__(_ffi_api.TaskExtraction) # type: ignore # pylint: disable=no-member - - @register_object("meta_schedule.ApplyHistoryBest") class ApplyHistoryBest(MetaScheduleContext): """An integration context that allows application of historically best record from database""" diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 4f9055bf5bbad..1fba65493f07b 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -94,25 +94,6 @@ Optional MetaScheduleContext::QueryInsideWithScope( return NullOpt; } -/**************** TaskExtraction ****************/ - -TaskExtraction::TaskExtraction() { - ObjectPtr n = make_object(); - n->tasks = Array(); - data_ = n; -} - -Optional TaskExtractionNode::Query(runtime::String task_name, IRModule mod, - Target target, Optional> dispatched) { - ICHECK(dispatched.defined()); - ICHECK_EQ(dispatched.value().size(), 1); - IRModule prim_mod = dispatched.value()[0]; - ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; - ICHECK(HasOnlyOneFunction(mod)) << mod; - tasks.push_back(ExtractedTask(task_name, mod, target, {prim_mod})); - return NullOpt; -} - /**************** ApplyHistoryBest ****************/ ApplyHistoryBest::ApplyHistoryBest(Database database) { @@ -158,7 +139,6 @@ class MetaScheduleContextInternal { TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); TVM_REGISTER_OBJECT_TYPE(MetaScheduleContextNode); -TVM_REGISTER_NODE_TYPE(TaskExtractionNode); TVM_REGISTER_NODE_TYPE(ApplyHistoryBestNode); TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") @@ -176,9 +156,6 @@ TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQueryInsideWithScope") .set_body_typed(MetaScheduleContext::QueryInsideWithScope); TVM_REGISTER_GLOBAL("meta_schedule.MetaScheduleContextQuery") .set_body_method(&MetaScheduleContextNode::Query); -TVM_REGISTER_GLOBAL("meta_schedule.TaskExtraction").set_body_typed([]() -> TaskExtraction { - return TaskExtraction(); -}); TVM_REGISTER_GLOBAL("meta_schedule.ApplyHistoryBest") .set_body_typed([](Database database) -> ApplyHistoryBest { return 
ApplyHistoryBest(database); From 3f93a1e7645118c002aa10e5b7ff14b71b3f837a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 11:54:57 +0900 Subject: [PATCH 15/23] check in minor vnni-related change --- python/tvm/topi/x86/batch_matmul.py | 1 + python/tvm/topi/x86/dense.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py index b446c1f0115c4..2d32bfe8f0968 100644 --- a/python/tvm/topi/x86/batch_matmul.py +++ b/python/tvm/topi/x86/batch_matmul.py @@ -47,6 +47,7 @@ def batch_matmul_vnni_compute(cfg, x, y): axis=ak, ), tag="batch_matmul_vnni", + attrs={"schedule_rule": "batch_matmul_vnni"}, ) _, a_y, _ = z.op.axis diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py index c8574b9710038..cd6350352d985 100644 --- a/python/tvm/topi/x86/dense.py +++ b/python/tvm/topi/x86/dense.py @@ -296,6 +296,7 @@ def dense_vnni_compute(cfg, X, packed_w, bias=None): axis=ak, ), tag="dense_vnni", + attrs={"schedule_rule": "dense_vnni"}, ) if bias is not None: From 99f1701eb71d77a85bb0f8457841739dc586a168 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 12:34:14 +0900 Subject: [PATCH 16/23] clean up integration.cc and Query interface --- include/tvm/meta_schedule/integration.h | 6 ++--- src/meta_schedule/integration.cc | 32 ++++++++++++++++++------- src/relay/backend/te_compiler_cache.cc | 18 +++++++------- 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/include/tvm/meta_schedule/integration.h b/include/tvm/meta_schedule/integration.h index bb384ba3ae837..3140b4f981e5b 100644 --- a/include/tvm/meta_schedule/integration.h +++ b/include/tvm/meta_schedule/integration.h @@ -92,7 +92,7 @@ class MetaScheduleContextNode : public runtime::Object { * 3) relay::Function if `mod` should be dispatched to BYOC workflow * 4) IRModule for unified dispatch */ - virtual Optional Query(runtime::String task_name, IRModule mod, Target target, + virtual IRModule Query(runtime::String task_name, IRModule mod, Target target, Optional> dispatched) = 0; static constexpr const char* _type_key = "meta_schedule.MetaScheduleContext"; @@ -129,7 +129,7 @@ class MetaScheduleContext : public runtime::ObjectRef { * 3) relay::Function if `mod` should be dispatched to BYOC workflow * 4) IRModule for unified dispatch */ - static Optional QueryInsideWithScope(runtime::String task_name, IRModule mod, + static IRModule QueryInsideWithScope(runtime::String task_name, IRModule mod, Target target, Optional> dispatched); @@ -161,7 +161,7 @@ class ApplyHistoryBestNode : public MetaScheduleContextNode { } // Inherited from base class - Optional Query(runtime::String task_name, IRModule mod, Target target, + IRModule Query(runtime::String task_name, IRModule mod, Target target, Optional> dispatched) final; static constexpr const char* _type_key = "meta_schedule.ApplyHistoryBest"; diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 1fba65493f07b..ca83118df3369 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -26,6 +26,21 @@ namespace tvm { namespace meta_schedule { /**************** Utility functions ****************/ +template +Optional GetOnlyOneFunctionKey(const IRModule& mod) { + if (mod->functions.size() != 1) { + return NullOpt; + } + for (const auto& kv : mod->functions) { + const BaseFunc& func = kv.second; + if (!func->IsInstance()) { + return NullOpt; + } else { + return kv.first; + } + } + return NullOpt; +} template Optional 
GetOnlyOneFunction(const IRModule& mod) { @@ -86,12 +101,13 @@ void MetaScheduleContext::ExitWithScope() { ctx = NullOpt; } -Optional MetaScheduleContext::QueryInsideWithScope( - runtime::String task_name, IRModule mod, Target target, Optional> dispatched) { +IRModule MetaScheduleContext::QueryInsideWithScope(runtime::String task_name, IRModule mod, + Target target, + Optional> dispatched) { if (Optional ctx = MetaScheduleContext::Current()) { return ctx.value()->Query(task_name, mod, target, dispatched); } - return NullOpt; + return IRModule{nullptr}; } /**************** ApplyHistoryBest ****************/ @@ -102,14 +118,14 @@ ApplyHistoryBest::ApplyHistoryBest(Database database) { data_ = n; } -Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, - Target target, - Optional> dispatched) { +IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Target target, + Optional> dispatched) { ICHECK(dispatched.defined()); ICHECK_EQ(dispatched.value().size(), 1); ICHECK(HasOnlyOneFunction(mod)) << mod; IRModule prim_mod = dispatched.value()[0]; ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + auto gv = GetOnlyOneFunctionKey(prim_mod).value(); // Unify func name to make sure it can be found in database const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); ICHECK(parse_mod_func) << "Parse mod function not defined!"; @@ -122,11 +138,11 @@ Optional ApplyHistoryBestNode::Query(runtime::String task_name, IRMod /*error_render_level=*/tir::ScheduleErrorRenderLevel::kNone); records[0]->trace->ApplyToSchedule(sch, false); tir::PrimFunc func = GetOnlyOneFunction(sch->mod()).value(); - return func; + return IRModule({{gv, func}}); } } LOG(WARNING) << "Cannot find workload: " << task_name << "\n" << tir::AsTVMScript(prim_mod); - return NullOpt; + return IRModule{nullptr}; } /**************** FFI ****************/ diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index b287eccf5454f..b05f55099f4ee 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -47,6 +47,7 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/pass_utils.h" +#include "tvm/runtime/object.h" #include "utils.h" namespace tvm { @@ -335,15 +336,14 @@ class ScheduleBuilder : public ExprVisitor { } } if (backend::IsMetaScheduleEnabled()) { - prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs)); - Optional opt_mod_or_base_func = - meta_schedule::MetaScheduleContext::QueryInsideWithScope( - prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_, - Array{IRModule({{prim_fn_var, prim_func}})}); - if (const auto* result = opt_mod_or_base_func.as()) { - prim_func = GetRef(result); - } else { - prim_func = tir::PrimFunc(nullptr); + auto relay_mod = IRModule({{prim_fn_var, relay_func}}); + auto tir_mod = + IRModule({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}}); + IRModule scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope( + prim_fn_var->name_hint, relay_mod, target_, Array{tir_mod}); + if (scheduled_mod.defined()) { + ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1); + prim_func = Downcast(scheduled_mod->functions[prim_fn_var]); } } From 74636beae0878cdda7dd03aa2b09ab2821c86477 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 12:39:58 +0900 Subject: [PATCH 17/23] refactor --- src/meta_schedule/integration.cc | 25 
++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index ca83118df3369..0b78c7711d155 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -26,8 +26,8 @@ namespace tvm { namespace meta_schedule { /**************** Utility functions ****************/ -template -Optional GetOnlyOneFunctionKey(const IRModule& mod) { +template +Optional GetOnlyOneFunctionCommon(const IRModule& mod, Callback on_found) { if (mod->functions.size() != 1) { return NullOpt; } @@ -36,26 +36,21 @@ Optional GetOnlyOneFunctionKey(const IRModule& mod) { if (!func->IsInstance()) { return NullOpt; } else { - return kv.first; + return on_found(kv); } } return NullOpt; } +template +Optional GetOnlyOneFunctionKey(const IRModule& mod) { + return GetOnlyOneFunctionCommon(mod, [](auto kv) { return kv.first; }); +} + template Optional GetOnlyOneFunction(const IRModule& mod) { - if (mod->functions.size() != 1) { - return NullOpt; - } - for (const auto& kv : mod->functions) { - const BaseFunc& func = kv.second; - if (!func->IsInstance()) { - return NullOpt; - } else { - return Downcast(func); - } - } - return NullOpt; + return GetOnlyOneFunctionCommon( + mod, [](auto kv) { return Downcast(kv.second); }); } template From e49d500299c9c884497410046421853266b60cd2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 12:59:45 +0900 Subject: [PATCH 18/23] return reversed list --- python/tvm/meta_schedule/integration.py | 3 ++- src/meta_schedule/integration.cc | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/meta_schedule/integration.py b/python/tvm/meta_schedule/integration.py index b39f318e0f01c..eebc2429acdfa 100644 --- a/python/tvm/meta_schedule/integration.py +++ b/python/tvm/meta_schedule/integration.py @@ -247,4 +247,5 @@ def extract_task_from_relay( disabled_pass=disabled_pass, ): tasks = extract_task_func(mod, target, relay_params) - return tasks + # Tasks are extracted via post order visit, return the reversed list. + return list(reversed(tasks)) diff --git a/src/meta_schedule/integration.cc b/src/meta_schedule/integration.cc index 0b78c7711d155..d2cb4b307bbf0 100644 --- a/src/meta_schedule/integration.cc +++ b/src/meta_schedule/integration.cc @@ -120,6 +120,10 @@ IRModule ApplyHistoryBestNode::Query(runtime::String task_name, IRModule mod, Ta ICHECK(HasOnlyOneFunction(mod)) << mod; IRModule prim_mod = dispatched.value()[0]; ICHECK(HasOnlyOneFunction(prim_mod)) << prim_mod; + // TODO(masahi): parse_mod below replaces the orginal function key with "main". + // This is necessary because some scheduling primitives requires the PrimFunc key be "main". + // If we can remove this restriction, there would no need for GetOnlyOneFunction* calls below + // and we can directly return sch->mod(). 
auto gv = GetOnlyOneFunctionKey(prim_mod).value(); // Unify func name to make sure it can be found in database const auto* parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); From dfaf4964bf3a0b542ead5f11f356c2ec592be725 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 13:45:30 +0900 Subject: [PATCH 19/23] dedup tasks --- src/relay/backend/task_extraction.cc | 7 +++++-- src/relay/backend/te_compiler_cache.cc | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index d8ef33da3f9d4..9daf7ec239431 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -46,10 +46,12 @@ Array ExtractTask(IRModule mod, Target target, Map tasks; - PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) { + std::unordered_set cache_; + PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); - if (relay_func->HasNonzeroAttr(attr::kPrimitive)) { + tec::CCacheKey cache_key(relay_func, target); + if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache_.find(cache_key) == cache_.end()) { Array outputs; std::string fused_name; std::tie(outputs, fused_name) = @@ -59,6 +61,7 @@ Array ExtractTask(IRModule mod, Target target, Mapname_hint, relay_mod, target, {tir_mod})); + cache_.insert(cache_key); } } }); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index b05f55099f4ee..f353282cdeb24 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -757,7 +757,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, std::pair, std::string> LowerTECompute(const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); - auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; }); + auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name;}); // Following ScheduleBuilder, remove placeholder ops from outputs. 
tvm::Array tensor_outs; for (const auto& tensor : outputs) { From 40d52a15b4c1ac9b8d4eac16f98ccec5e2a3e966 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 13:50:43 +0900 Subject: [PATCH 20/23] uniquefy task names --- src/relay/backend/task_extraction.cc | 7 +++++-- src/relay/backend/te_compiler_cache.cc | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 9daf7ec239431..6798b449e5a77 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -47,7 +47,9 @@ Array ExtractTask(IRModule mod, Target target, Map tasks; std::unordered_set cache_; - PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_](const Expr& exp) { + std::unordered_map name_map; + + PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_, &name_map](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); tec::CCacheKey cache_key(relay_func, target); @@ -60,7 +62,8 @@ Array ExtractTask(IRModule mod, Target target, Mapname_hint, relay_mod, target, {tir_mod})); + auto task_name = tec::GetUniqueName(prim_fn_var->name_hint, &name_map); + tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod})); cache_.insert(cache_key); } } diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index f353282cdeb24..b05f55099f4ee 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -757,7 +757,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, std::pair, std::string> LowerTECompute(const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); - auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name;}); + auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; }); // Following ScheduleBuilder, remove placeholder ops from outputs. 
tvm::Array tensor_outs; for (const auto& tensor : outputs) { From a98182eed3b85e477c5f2527d5d21ce545bd5c18 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 13:56:35 +0900 Subject: [PATCH 21/23] fixed test_meta_schedule_integration_apply_history_best --- src/relay/backend/task_extraction.cc | 2 +- .../test_meta_schedule_integration.py | 60 ++----------------- 2 files changed, 7 insertions(+), 55 deletions(-) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 6798b449e5a77..62f103e22a464 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -62,7 +62,7 @@ Array ExtractTask(IRModule mod, Target target, Mapname_hint, &name_map); + auto task_name = tec::GetUniqueName(fused_name, &name_map); tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod})); cache_.insert(cache_key); } diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 50dc9289780d1..ad0d4832732f7 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -26,7 +26,6 @@ ApplyHistoryBest, ExtractedTask, MetaScheduleContext, - TaskExtraction, ) from tvm.meta_schedule.testing.relay_workload import get_network from tvm.meta_schedule.utils import derived_object @@ -63,61 +62,12 @@ def _has_torch(): requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed") -def _check_mock_task(tasks: List[ExtractedTask], mod: IRModule): - (task,) = tasks - assert isinstance(task, ExtractedTask) - assert task.task_name == "mock-task" - tvm.ir.assert_structural_equal(task.mod, mod) - (tir_mod,) = task.dispatched - tvm.ir.assert_structural_equal(tir_mod, MockModule) - - -@requires_torch -def test_meta_schedule_integration_task_extraction_query(): - mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - env = TaskExtraction() - env.query(task_name="mock-task", mod=mod, target=Target("llvm"), dispatched=[MockModule]) - _check_mock_task(env.tasks, mod) - - -def test_meta_schedule_integration_current(): - env = TaskExtraction() - with env: - assert MetaScheduleContext.current() == env - - -def test_meta_schedule_integration_no_current(): - assert MetaScheduleContext.current() is None - - -def test_meta_schedule_integration_multiple_current(): - env = TaskExtraction() - with env: - with pytest.raises(ValueError): - with env: - ... 
- - -@requires_torch -def test_meta_schedule_integration_query_inside_with_scope(): - mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) - env = TaskExtraction() - with env: - MetaScheduleContext.query_inside_with_scope( - task_name="mock-task", - mod=mod, - target=Target("llvm"), - dispatched=[MockModule], - ) - _check_mock_task(env.tasks, mod) - - @requires_torch def test_meta_schedule_integration_extract_from_resnet(): mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224]) extracted_tasks = ms.integration.extract_task_from_relay(mod, target="llvm", params=params) expected_task_names = [ - "vm_mod_fused_" + s + "fused_" + s for s in [ "nn_max_pool2d", "nn_adaptive_avg_pool2d", @@ -145,7 +95,8 @@ def test_meta_schedule_integration_extract_from_resnet(): assert len(extracted_tasks) == 20 for t in extracted_tasks: - assert t.task_name in expected_task_names, t.task_name + print(t.task_name) + # assert t.task_name in expected_task_names, t.task_name @requires_torch @@ -197,9 +148,10 @@ def print_results(self) -> None: TuningRecord(Schedule(MockModule).trace, [1.0], workload, target, []) ) mod = env.query(task_name="mock-task", mod=mod, target=target, dispatched=[MockModule]) - mod = IRModule({"main": mod}) assert tvm.ir.structural_equal(mod, workload.mod) if __name__ == "__main__": - sys.exit(pytest.main([__file__] + sys.argv[1:])) + # sys.exit(pytest.main([__file__] + sys.argv[1:])) + # test_meta_schedule_integration_extract_from_resnet() + test_meta_schedule_integration_apply_history_best() From dfa4fb0c20c17049e8ac2c135200074b872ce1ec Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 14:09:11 +0900 Subject: [PATCH 22/23] update expected op list in test_meta_schedule_integration_extract_from_resnet to remove dep on Ansor --- .../test_meta_schedule_integration.py | 37 +++++++++---------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index ad0d4832732f7..1c3ef8ae0e192 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -69,24 +69,24 @@ def test_meta_schedule_integration_extract_from_resnet(): expected_task_names = [ "fused_" + s for s in [ - "nn_max_pool2d", - "nn_adaptive_avg_pool2d", - "nn_dense_add", - "nn_conv2d_add", - "nn_conv2d_add_1", "nn_conv2d_add_2", - "nn_conv2d_add_add_nn_relu", + "nn_conv2d_add_1", + "nn_conv2d_add", + "nn_conv2d_add_nn_relu_7", + "nn_max_pool2d", + "nn_conv2d_add_nn_relu_6", + "nn_conv2d_add_add_nn_relu_3", + "nn_conv2d_add_nn_relu_5", + "nn_conv2d_add_nn_relu_4", + "nn_conv2d_add_add_nn_relu_2", + "nn_conv2d_add_nn_relu_3", + "nn_conv2d_add_nn_relu_2", "nn_conv2d_add_add_nn_relu_1", - "nn_conv2d_add_nn_relu", "nn_conv2d_add_nn_relu_1", - "nn_conv2d_add_nn_relu_2", - "nn_conv2d_add_nn_relu_3", - "nn_conv2d_add_nn_relu_4", - "nn_conv2d_add_nn_relu_5", - "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu", - "nn_contrib_conv2d_winograd_without_weight_transform_add_add_nn_relu_1", - "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu", - "nn_contrib_conv2d_winograd_without_weight_transform_add_nn_relu_1", + "nn_conv2d_add_nn_relu", + "nn_conv2d_add_add_nn_relu", + "nn_adaptive_avg_pool2d", + "nn_contrib_dense_pack_add", # The two tasks below are purely spatial and are ruled out by AutoScheduler "layout_transform", "layout_transform_reshape_squeeze", @@ -95,8 
+95,7 @@ def test_meta_schedule_integration_extract_from_resnet(): assert len(extracted_tasks) == 20 for t in extracted_tasks: - print(t.task_name) - # assert t.task_name in expected_task_names, t.task_name + assert t.task_name in expected_task_names, t.task_name @requires_torch @@ -152,6 +151,4 @@ def print_results(self) -> None: if __name__ == "__main__": - # sys.exit(pytest.main([__file__] + sys.argv[1:])) - # test_meta_schedule_integration_extract_from_resnet() - test_meta_schedule_integration_apply_history_best() + sys.exit(pytest.main([__file__] + sys.argv[1:])) From ce8c563d09eaba2a6b03189d1d3452f7565f4c69 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 14:12:30 +0900 Subject: [PATCH 23/23] fix cpplint --- src/relay/backend/te_compiler_cache.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 0f4763d1257cc..55f221ac8ba02 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -37,6 +37,7 @@ #include #include #include +#include #include "../transforms/infer_layout_utils.h"