diff --git a/src/relay/transforms/type_infer.cc b/src/relay/transforms/type_infer.cc index 7182f0e96f0f..e110737d6226 100644 --- a/src/relay/transforms/type_infer.cc +++ b/src/relay/transforms/type_infer.cc @@ -369,14 +369,12 @@ class TypeInferencer : private ExprFunctor, // Build a subsitituion map up from the function type and type arguments. // Eventually allow the type vars to be passed in. + CHECK(fn_ty->type_params.size() == ty_args.size()) + << "number of type parameters does not match expected"; for (size_t i = 0; i < ty_args.size(); ++i) { subst_map.Set(fn_ty->type_params[i], ty_args[i]); } - for (size_t i = ty_args.size(); i < fn_ty->type_params.size(); ++i) { - subst_map.Set(fn_ty->type_params[i], IncompleteType(Kind::kType)); - } - Type ret_type = fn_ty->ret_type; // If the function type is incomplete, place a new IncompleteType @@ -445,6 +443,9 @@ class TypeInferencer : private ExprFunctor, << "Expected " << fn_ty_node->type_params.size() << "but got " << type_args.size()); } + for (size_t i = type_args.size(); i < fn_ty_node->type_params.size(); i++) { + type_args.push_back(IncompleteType(TypeKind::kType)); + } FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index cc4748c92b00..70e0c3fd5127 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -372,5 +372,19 @@ def test_if(): ft = run_infer_type(top) tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32')) +def test_type_arg_infer(): + code = """ +#[version = "0.0.5"] +def @id[A](%x: A) -> A { + %x +} +def @main(%f: float32) -> float32 { + @id(%f) +} +""" + mod = tvm.parser.fromtext(code) + mod = transform.InferType()(mod) + tvm.ir.assert_structural_equal(mod['main'].body.type_args, [relay.TensorType((), 'float32')]) + if __name__ == "__main__": pytest.main([__file__])