diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index 1888cd9f1d18..0a7803311f0b 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -90,7 +90,7 @@ class IRBuilder(object): n = tvm.var("n") A = ib.allocate("float32", n, name="A") with ib.for_range(0, n, name="i") as i: - with ib.if_scope((i % 2) == 0): + with ib.if_scope((i % 2).equal(0)): A[i] = A[i] + 1 # The result stmt. stmt = ib.get() diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index 0ef7ed5cf58c..170ed089addf 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -25,7 +25,7 @@ def test_if(): n = tvm.var("n") A = ib.pointer("float32", name="A") with ib.for_range(0, n, name="i") as i: - with ib.if_scope((i % 2) == 0): + with ib.if_scope((i % 2).equal(0)): A[i] = A[i] + 1 with ib.else_scope(): A[0] = A[i] + 2 @@ -34,6 +34,7 @@ def test_if(): assert isinstance(body, tvm.stmt.For) body = body.body assert isinstance(body, tvm.stmt.IfThenElse) + assert isinstance(body.condition, tvm.expr.EQ) assert isinstance(body.then_case.index, tvm.expr.Var) assert body.else_case.index.value == 0