diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index a192002825e6..a674265a88de 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -175,6 +175,17 @@ class TypeSolver::Unifier : public TypeFunctor { if (ulhs.same_as(urhs)) { return ulhs; } + + if (ulhs.as() && urhs.as()) { + solver_->shape_uf_.Set(urhs, ulhs); + return urhs; + } + + if (ulhs.as() && urhs.as()) { + solver_->shape_uf_.Set(ulhs, urhs); + return ulhs; + } + if (ulhs.as() || urhs.as()) { return Any(); } diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 6810d0b6a753..bf28ee1b2ff5 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -721,9 +721,7 @@ def _body(i, st): mod["main"] = func data = np.array(0.0, dtype='int32') ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32") - # TODO(@jroesch): After LambdaLift pass, TypeInfer pass will fail - # so currently we cannot run this test case on VM - for kind in ["debug"]: + for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(data) np.testing.assert_allclose(result.asnumpy(), ref) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 45916180c1d3..e5082db25cbc 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -21,6 +21,7 @@ from tvm import te from tvm import relay from tvm.relay import op, transform, analysis +from tvm.relay import Any def run_infer_type(expr, mod=None): @@ -362,6 +363,15 @@ def test_let_polymorphism(): tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) +def test_if(): + choice_t = relay.FuncType([], relay.scalar_type('bool')) + f = relay.Var('f', choice_t) + true_branch = relay.Var('True', relay.TensorType([Any(), 1], dtype='float32')) + false_branch = relay.Var('False', relay.TensorType([Any(), Any()], dtype='float32')) + top = relay.Function([true_branch, false_branch], relay.If(f(), true_branch, false_branch)) + ft = run_infer_type(top) + tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype='float32')) + if __name__ == "__main__": test_free_expr() test_dual_op() @@ -380,3 +390,4 @@ def test_let_polymorphism(): test_constructor_call() test_adt_match() test_let_polymorphism() + test_if()