From 08df02243d1c0697248aecfe4168840a6b05d834 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Fri, 13 Oct 2017 08:56:47 -0700 Subject: [PATCH 1/2] Add same_as to NodeBase 1. Most class inherited from NodeBase(Schedule, Stage, etc) still have the convenience of using '==' for object identity. And this is the right behavior for non-Expr classes. 2. subclasses of ExprOp now create EQ expression when '==' is used. `__nonzero__` and `__bool__` in EQ and NE is a comprise that in some cases object identity semantics is still useful, like in unit test. For instance: ```` assert a == b ```` "a == b" will create EQ expression, assert then calls `__nonzero__` of the result expression. `Expr.__nonzero__` throws exception since it prohibits evaluating IR expression. More complex case like: ```` assert a in b # b is dict ```` it will call `__eq__` on a and all keys of b, then `__bool__` on the result expression. This could not easily be done by same_as. --- python/tvm/_ffi/node.py | 10 +++++++--- python/tvm/expr.py | 14 +++++++++++--- python/tvm/ir_builder.py | 2 +- tests/python/unittest/test_ir_builder.py | 2 +- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/python/tvm/_ffi/node.py b/python/tvm/_ffi/node.py index 54ea51ee18de..d9e7397ae71f 100644 --- a/python/tvm/_ffi/node.py +++ b/python/tvm/_ffi/node.py @@ -40,9 +40,7 @@ def __hash__(self): return _api_internal._raw_ptr(self) def __eq__(self, other): - if not isinstance(other, NodeBase): - return False - return self.__hash__() == other.__hash__() + return self.same_as(other) def __ne__(self, other): return not self.__eq__(other) @@ -67,6 +65,12 @@ def __setstate__(self, state): else: self.handle = None + def same_as(self, other): + """check object identity equality""" + if not isinstance(other, NodeBase): + return False + return self.__hash__() == other.__hash__() + def register_node(type_key=None): """register node type diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 39bd80c7467c..0c43bc585d2a 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -138,7 +138,7 @@ def astype(self, dtype): return _make.static_cast(dtype, self) -class Expr(NodeBase, ExprOp): +class Expr(ExprOp, NodeBase): """Base class of all tvm Expressions""" pass @@ -213,11 +213,19 @@ class Max(BinaryOpExpr): @register_node class EQ(CmpExpr): - pass + def __nonzero__(self): + return self.a.same_as(self.b) + + def __bool__(self): + return self.__nonzero__() @register_node class NE(CmpExpr): - pass + def __nonzero__(self): + return not self.a.same_as(self.b) + + def __bool__(self): + return self.__nonzero__() @register_node class LT(CmpExpr): diff --git a/python/tvm/ir_builder.py b/python/tvm/ir_builder.py index 0a7803311f0b..1888cd9f1d18 100644 --- a/python/tvm/ir_builder.py +++ b/python/tvm/ir_builder.py @@ -90,7 +90,7 @@ class IRBuilder(object): n = tvm.var("n") A = ib.allocate("float32", n, name="A") with ib.for_range(0, n, name="i") as i: - with ib.if_scope((i % 2).equal(0)): + with ib.if_scope((i % 2) == 0): A[i] = A[i] + 1 # The result stmt. stmt = ib.get() diff --git a/tests/python/unittest/test_ir_builder.py b/tests/python/unittest/test_ir_builder.py index 170ed089addf..c5cc192a3f33 100644 --- a/tests/python/unittest/test_ir_builder.py +++ b/tests/python/unittest/test_ir_builder.py @@ -25,7 +25,7 @@ def test_if(): n = tvm.var("n") A = ib.pointer("float32", name="A") with ib.for_range(0, n, name="i") as i: - with ib.if_scope((i % 2).equal(0)): + with ib.if_scope((i % 2) == 0): A[i] = A[i] + 1 with ib.else_scope(): A[0] = A[i] + 2 From 3422a9e5ecf1ef5ed650aea4161657db424e4e98 Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Fri, 13 Oct 2017 14:20:46 -0700 Subject: [PATCH 2/2] Retain __hash__ from NodeBase in Python3 --- python/tvm/expr.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 0c43bc585d2a..b265103360c6 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -140,7 +140,9 @@ def astype(self, dtype): class Expr(ExprOp, NodeBase): """Base class of all tvm Expressions""" - pass + # In Python3, We have to explicity tell interpreter to retain __hash__ if we overide __eq__ + # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ + __hash__ = NodeBase.__hash__ class ConstExpr(Expr): pass