From 9851c8398066e4881ccd3b276ce231506035e180 Mon Sep 17 00:00:00 2001 From: Ramana Radhakrishnan Date: Wed, 5 Jun 2019 18:19:13 +0100 Subject: [PATCH] Add support for overloading comparison operations in relay (#2910) (#3168) --- python/tvm/relay/expr.py | 32 +++++++++++++++++++++++++++ tests/python/relay/test_cmp_op.py | 36 +++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 tests/python/relay/test_cmp_op.py diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 98b4a83e09de..8e7f95c4dc26 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -70,6 +70,38 @@ def astype(self, dtype): def __neg__(self): return _op_make.negative(self) + def __lt__(self, other): + if isinstance(other, Expr): + return _op_make.less(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __gt__(self, other): + if isinstance(other, Expr): + return _op_make.greater(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __ge__(self, other): + if isinstance(other, Expr): + return _op_make.greater_equal(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + + def __le__(self, other): + if isinstance(other, Expr): + return _op_make.less_equal(self, other) + elif isinstance(other, _Number): + raise TypeError('convert "%s" with `const` first' % str(other)) + else: + raise TypeError("type %s not supported" % str(type(other))) + def __add__(self, other): if isinstance(other, Expr): return _op_make.add(self, other) diff --git a/tests/python/relay/test_cmp_op.py b/tests/python/relay/test_cmp_op.py new file mode 100644 index 000000000000..d096eec598b7 --- /dev/null +++ b/tests/python/relay/test_cmp_op.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from tvm import relay +a = relay.Var("a") +b = relay.expr.const (1.0, dtype='float32') + +c = a < b +d = relay.less (a, b) +assert (c.astext() == d.astext()) + +c = a > b +d = relay.greater (a, b) +assert (c.astext() == d.astext()) + +c = (a >= b) +d = relay.greater_equal(a, b) +assert (c.astext() == d.astext()) + +c = (a <= b) +d = relay.less_equal(a, b) +assert (c.astext() == d.astext())