Skip to content

Commit

Permalink
[Relay] Fix Type Arguments not Attached (#6385)
Browse files Browse the repository at this point in the history
  • Loading branch information
hypercubestart authored Sep 3, 2020
1 parent 5262765 commit 17d39fb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,14 +369,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,

// Build a subsitituion map up from the function type and type arguments.
// Eventually allow the type vars to be passed in.
CHECK(fn_ty->type_params.size() == ty_args.size())
<< "number of type parameters does not match expected";
for (size_t i = 0; i < ty_args.size(); ++i) {
subst_map.Set(fn_ty->type_params[i], ty_args[i]);
}

for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) {
subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType));
}

Type ret_type = fn_ty->ret_type;

// If the function type is incomplete, place a new IncompleteType
Expand Down Expand Up @@ -445,6 +443,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
<< "Expected " << fn_ty_node->type_params.size() << "but got "
<< type_args.size());
}
for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) {
type_args.push_back(IncompleteType(TypeKind::kType));
}

FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);

Expand Down
14 changes: 14 additions & 0 deletions tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,5 +372,19 @@ def test_if():
ft = run_infer_type(top)
tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32'))

def test_type_arg_infer():
code = """
#[version = "0.0.5"]
def @id[A](%x: A) -> A {
%x
}
def @main(%f: float32) -> float32 {
@id(%f)
}
"""
mod = tvm.parser.fromtext(code)
mod = transform.InferType()(mod)
tvm.ir.assert_structural_equal(mod['main'].body.type_args, [relay.TensorType((), 'float32')])

if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 17d39fb

Please sign in to comment.