diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 6310e3bfcf29..c3da195d9c8e 100644 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -457,7 +457,7 @@ def get_name(node): def infer_type(node, mod=None): """A method to infer the type of an intermediate node in the relay graph.""" if isinstance(mod, IRModule): - mod["main"] = _function.Function([], node) + mod["main"] = _function.Function(tvm.relay.analysis.free_vars(node), node) mod = _transform.InferType()(mod) entry = mod["main"] ret = entry.body