diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 2b0e30b37863..8c1cc92fe009 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -186,17 +186,6 @@ class TypeSolver::Unifier : public TypeFunctor { if (ulhs.same_as(urhs)) { return ulhs; } - - if (ulhs.as() && urhs.as()) { - solver_->shape_uf_.Set(urhs, ulhs); - return urhs; - } - - if (ulhs.as() && urhs.as()) { - solver_->shape_uf_.Set(ulhs, urhs); - return ulhs; - } - if (ulhs.as() || urhs.as()) { return Any(); } diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 6758d96773a2..b518c31d3e62 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -22,7 +22,6 @@ from tvm import IRModule, te, relay, parser from tvm.relay import op, transform, analysis -from tvm.relay import Any def infer_mod(mod, annotate_spans=True): @@ -388,16 +387,6 @@ def test_let_polymorphism(): tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) -def test_if(): - choice_t = relay.FuncType([], relay.scalar_type("bool")) - f = relay.Var("f", choice_t) - true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32")) - false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32")) - top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch)) - ft = infer_expr(top) - tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32")) - - def test_type_arg_infer(): code = """ #[version = "0.0.5"] diff --git a/tests/python/topi/python/test_topi_sparse.py b/tests/python/topi/python/test_topi_sparse.py index 07af478a5087..9426eb7499df 100644 --- a/tests/python/topi/python/test_topi_sparse.py +++ b/tests/python/topi/python/test_topi_sparse.py @@ -466,7 +466,8 @@ def test_sparse_dense_padded_alter_op(): ), ) f = relay.Function([], mult) - f_ = relay.transform.AlterOpLayout()(tvm.IRModule.from_expr(f)) + f = relay.transform.InferType()(tvm.IRModule.from_expr(f)) + f_ = relay.transform.AlterOpLayout()(f) assert f_["main"].body.op.name == "nn.internal.sparse_dense_padded"