diff --git a/tests/python/unittest/test_arith_solve_linear_inequality.py b/tests/python/unittest/test_arith_solve_linear_inequality.py index acdabecc89f4..80618dd7929e 100644 --- a/tests/python/unittest/test_arith_solve_linear_inequality.py +++ b/tests/python/unittest/test_arith_solve_linear_inequality.py @@ -156,8 +156,16 @@ def test_multi_equal(): assert solution.ranges[x].extent == 1 assert len(solution.relations) == 3 assert ir.structural_equal(solution.relations[0], x == z * y) - assert ir.structural_equal(solution.relations[1], z*y - 6 <= 0) - assert ir.structural_equal(solution.relations[2], 6 - z*y <= 0) + + assert isinstance(solution.relations[1], tvm.tir.LE) + assert solution.relations[1].b == 0 + assert isinstance(solution.relations[2], tvm.tir.LE) + assert solution.relations[2].b == 0 + # (z*y - 6) <= 0 && (6 - z*y) <= 0 + ana = tvm.arith.Analyzer() + assert ana.simplify(solution.relations[1].a + solution.relations[2].a) == 0 + assert ir.structural_equal(solution.relations[1].a, (z*y - 6)) or \ + ir.structural_equal(solution.relations[2].a, (z*y - 6)) solution = arith.solve_linear_inequalities(problem, [x, y, z], deskew_range=True) assert solution.src_to_dst[y] == y