diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py index 54ea51ee18de..d9e7397ae71f 100644 --- a/python/tvm/_ffi/node.py +++ b/python/tvm/_ffi/node.py @@ -40,9 +40,7 @@ def __hash__(self): return _api_internal._raw_ptr(self) def __eq__(self, other): - if not isinstance(other, NodeBase): - return False - return self.__hash__() == other.__hash__() + return self.same_as(other) def __ne__(self, other): return not self.__eq__(other) @@ -67,6 +65,12 @@ def __setstate__(self, state): else: self.handle = None + def same_as(self, other): + """check object identity equality""" + if not isinstance(other, NodeBase): + return False + return self.__hash__() == other.__hash__() + def register_node(type_key=None): """register node type diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 39bd80c7467c..b265103360c6 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -138,9 +138,11 @@ def astype(self, dtype): return _make.static_cast(dtype, self) -class Expr(NodeBase, ExprOp): +class Expr(ExprOp, NodeBase): """Base class of all tvm Expressions""" - pass + # In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__ + # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ + __hash__ = NodeBase.__hash__ class ConstExpr(Expr): pass @@ -213,11 +215,19 @@ class Max(BinaryOpExpr): @register_node class EQ(CmpExpr): - pass + def __nonzero__(self): + return self.a.same_as(self.b) + + def __bool__(self): + return self.__nonzero__() @register_node class NE(CmpExpr): - pass + def __nonzero__(self): + return not self.a.same_as(self.b) + + def __bool__(self): + return self.__nonzero__() @register_node class LT(CmpExpr): diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index 0a7803311f0b..1888cd9f1d18 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -90,7 +90,7 @@ class IRBuilder(object): n = tvm.var("n") A = ib.allocate("float32", n, name="A") with ib.for_range(0, n, name="i") as i: - with ib.if_scope((i % 2).equal(0)): + with ib.if_scope((i % 2) == 0): A[i] = A[i] + 1 # The result stmt. stmt = ib.get() diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index 170ed089addf..c5cc192a3f33 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -25,7 +25,7 @@ def test_if(): n = tvm.var("n") A = ib.pointer("float32", name="A") with ib.for_range(0, n, name="i") as i: - with ib.if_scope((i % 2).equal(0)): + with ib.if_scope((i % 2) == 0): A[i] = A[i] + 1 with ib.else_scope(): A[0] = A[i] + 2