Skip to content

Commit

Permalink
Fix a few bugs: (#13)
Browse files Browse the repository at this point in the history
1. Don't add relay main function to list of lowered TIR functions
2. Don't skip visiting call to relay function in graph runtime codegen
  • Loading branch information
csullivan authored Feb 3, 2021
1 parent 554d2c5 commit 7b0287e
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 14 deletions.
1 change: 0 additions & 1 deletion python/tvm/relay/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def build(mod, target, target_host=None):
"""
if target_host == "":
target_host = None
import pdb; pdb.set_trace()
return tvm.driver.build(mod, target=target, target_host=target_host)


Expand Down
3 changes: 0 additions & 3 deletions src/relay/backend/graph_runtime_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,6 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G

ret.lowered_funcs = lowered_module.per_target_module;
std::cout << "Modules: " << ret.lowered_funcs << std::endl;
auto it = ret.lowered_funcs.begin();
(*it).second->Update(main_module);
ret.external_mods = lowered_module.external_mods;
return ret;
}
Expand Down Expand Up @@ -410,7 +408,6 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator<std::vector<G
// LOG(FATAL) << "TVM only support calls to primitive functions "
// << "(i.e functions composed of fusable operator invocations)";
// }
std::cout << PrettyPrint(call) << std::endl;

if (auto op_node = call->op.as<OpNode>()) {
if (op_node->name != "prim_fn_call") {
Expand Down
12 changes: 2 additions & 10 deletions src/relay/backend/tir_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -887,17 +887,9 @@ class LowerTensorExpr : public ExprMutator {
}

// Process inputs.
bool skip_first;
Array<Expr> args;
for (auto arg : expr->args) {
// The first input is a function, not a tensor.
if (skip_first) {
skip_first = false;
args.push_back(arg);
continue;
}

args.push_back(VisitExpr(arg));
for (size_t i = 0; i < expr->args.size(); i++) {
args.push_back(VisitExpr(expr->args[i]));
}

Target target;
Expand Down

0 comments on commit 7b0287e

Please sign in to comment.