Skip to content

Commit

Permalink
Fix old type inference invariant
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed Jan 16, 2019
1 parent 35c3043 commit 6f273d6
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<< "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}

n->entry_func = GlobalVarNode::make("main");
return Module(n);
}
Expand Down Expand Up @@ -104,7 +105,6 @@ Expr ModuleNode::EntryPoint() {
Module ModuleNode::FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) {
GlobalVar main = GlobalVarNode::make("main");
auto mod = ModuleNode::make({});
auto func_node = expr.as<FunctionNode>();
Function func;
Expand All @@ -113,7 +113,7 @@ Module ModuleNode::FromExpr(
} else {
func = FunctionNode::make({}, expr, Type(), {}, {});
}
mod->Add(main, func);
mod->Add(mod->entry_func, func);
return mod;
}

Expand Down
11 changes: 10 additions & 1 deletion src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,16 @@ Expr InferType(const Expr& expr, const Module& mod_ref) {
// NB(@jroesch): By adding the expression to the module we will
// type check it anyway; afterwards we can just recover type
// from the type-checked function to avoid doing unnecessary work.
return mod->Lookup(mod->entry_func);

Function e = mod->Lookup(mod->entry_func);

// FromExpr wraps a naked expression as a function, we will unbox
// it here.
if (auto func = expr.as<FunctionNode>()) {
return e;
} else {
return e->body;
}
} else {
auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr);
CHECK(WellFormed(e));
Expand Down

0 comments on commit 6f273d6

Please sign in to comment.