From dfaf4964bf3a0b542ead5f11f356c2ec592be725 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 13:45:30 +0900 Subject: [PATCH] dedup tasks --- src/relay/backend/task_extraction.cc | 7 +++++-- src/relay/backend/te_compiler_cache.cc | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index d8ef33da3f9d4..9daf7ec239431 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -46,10 +46,12 @@ Array ExtractTask(IRModule mod, Target target, Map tasks; - PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) { + std::unordered_set cache_; + PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); - if (relay_func->HasNonzeroAttr(attr::kPrimitive)) { + tec::CCacheKey cache_key(relay_func, target); + if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache_.find(cache_key) == cache_.end()) { Array outputs; std::string fused_name; std::tie(outputs, fused_name) = @@ -59,6 +61,7 @@ Array ExtractTask(IRModule mod, Target target, Mapname_hint, relay_mod, target, {tir_mod})); + cache_.insert(cache_key); } } }); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index b05f55099f4ee..f353282cdeb24 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -757,7 +757,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, std::pair, std::string> LowerTECompute(const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); - auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; }); + auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name;}); // Following ScheduleBuilder, remove placeholder ops from outputs. tvm::Array tensor_outs; for (const auto& tensor : outputs) {