Skip to content

Commit

Permalink
dedup tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 11, 2022
1 parent e49d500 commit dfaf496
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/relay/backend/task_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
auto opt_mod = seq(std::move(mod));

Array<ExtractedTask> tasks;
PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) {
std::unordered_set<tec::CCacheKey> cache_;
PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
Function relay_func = Downcast<Function>(exp);
if (relay_func->HasNonzeroAttr(attr::kPrimitive)) {
tec::CCacheKey cache_key(relay_func, target);
if (relay_func->HasNonzeroAttr(attr::kPrimitive) && cache_.find(cache_key) == cache_.end()) {
Array<te::Tensor> outputs;
std::string fused_name;
std::tie(outputs, fused_name) =
Expand All @@ -59,6 +61,7 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
auto relay_mod = IRModule({{prim_fn_var, relay_func}});
auto tir_mod = IRModule({{prim_fn_var, prim_func}});
tasks.push_back(ExtractedTask(prim_fn_var->name_hint, relay_mod, target, {tir_mod}));
cache_.insert(cache_key);
}
}
});
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
bool return_inputs) {
LowerToTECompute lower_te_compute(target);
auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; });
auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name;});
// Following ScheduleBuilder, remove placeholder ops from outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
Expand Down

0 comments on commit dfaf496

Please sign in to comment.