diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 920c15b2e7563..da273265ae339 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -146,7 +146,7 @@ void ModuleNode::Update(const Module& mod) { Module ModuleNode::FromExpr( const Expr& expr, const tvm::Map& global_funcs) { - auto mod = ModuleNode::make(global_funcs); + auto mod = ModuleNode::make(global_funcs, {}); auto func_node = expr.as(); Function func; if (func_node) { diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index dc84ffb5a3304..207600a5f760c 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -118,7 +118,7 @@ class TypeInferencer : private ExprFunctor, // Perform unification on two types and report the error at the expression // or the span of the expression. - Type Unify(const Type& t1, const Type& t2, const Expr& expr) { + Type Unify(const Type& t1, const Type& t2, const NodeRef& expr) { // TODO(tqchen, jroesch): propagate span to solver try { return solver_.Unify(t1, t2, expr); @@ -148,7 +148,7 @@ class TypeInferencer : private ExprFunctor, return ret; } - void ReportFatalError(const Expr& expr, const Error& err) { + void ReportFatalError(const NodeRef& expr, const Error& err) { CHECK(this->current_func_.defined()); this->err_reporter.ReportAt(this->current_func_, expr, err); this->err_reporter.RenderErrors(this->mod_); @@ -214,7 +214,7 @@ class TypeInferencer : private ExprFunctor, unknown_args.push_back(IncompleteTypeNode::make(Kind::kType)); } Type expected = TypeCallNode::make(con->con->belong_to, unknown_args); - Type unified = Unify(t, expected, con->span); + Type unified = Unify(t, expected, GetRef(con)); auto* tc = unified.as(); CHECK(tc) << "must be type call"; @@ -259,7 +259,7 @@ class TypeInferencer : private ExprFunctor, // if the definition is a function literal, permit recursion bool is_functional_literal = let->value.as() != nullptr; if (is_functional_literal) { - type_map_[op->var].checked_type = IncompleteTypeNode::make(Kind::kType); + type_map_[let->var].checked_type = IncompleteTypeNode::make(Kind::kType); } Type vtype = GetType(let->value); @@ -685,7 +685,7 @@ Expr InferType(const Expr& expr, const Module& mod_ref) { } else { auto e = TypeInferencer(mod_ref, mod_ref->entry_func).Infer(expr); CHECK(WellFormed(e)); - auto free_tvars = FreeTypeVars(e); + auto free_tvars = FreeTypeVars(e, mod_ref); CHECK(free_tvars.size() == 0) << "Found unbound type variables in " << e << ": " << free_tvars; EnsureCheckedType(e); diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 10e1d5390eec1..d6c412c5d950b 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -510,7 +510,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver") } else if (name == "AddConstraint") { return TypedPackedFunc([solver](TypeConstraint c) { Expr e = VarNode::make("dummy_var", - IncompleteTypeNode::make(TypeVarNode::Kind::kType)); + IncompleteTypeNode::make(Kind::kType)); return solver->AddConstraint(c, e); }); } else {