Skip to content

Commit

Permalink
uniquefy task names
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 11, 2022
1 parent dfaf496 commit 40d52a1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions src/relay/backend/task_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta

Array<ExtractedTask> tasks;
std::unordered_set<tec::CCacheKey> cache_;
PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_](const Expr& exp) {
std::unordered_map<std::string, int> name_map;

PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks, &cache_, &name_map](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
Function relay_func = Downcast<Function>(exp);
tec::CCacheKey cache_key(relay_func, target);
Expand All @@ -60,7 +62,8 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
auto prim_fn_var = GlobalVar(fused_name);
auto relay_mod = IRModule({{prim_fn_var, relay_func}});
auto tir_mod = IRModule({{prim_fn_var, prim_func}});
tasks.push_back(ExtractedTask(prim_fn_var->name_hint, relay_mod, target, {tir_mod}));
auto task_name = tec::GetUniqueName(prim_fn_var->name_hint, &name_map);
tasks.push_back(ExtractedTask(task_name, relay_mod, target, {tir_mod}));
cache_.insert(cache_key);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
bool return_inputs) {
LowerToTECompute lower_te_compute(target);
auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name;});
auto outputs = lower_te_compute.Lower(source_func, [&](std::string name) { return name; });
// Following ScheduleBuilder, remove placeholder ops from outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
Expand Down

0 comments on commit 40d52a1

Please sign in to comment.