From 08fa79134bda45fc6595ef648125705679f7812f Mon Sep 17 00:00:00 2001
From: Yong Wu
Date: Mon, 20 May 2019 11:56:22 -0700
Subject: [PATCH] [Relay][TOPI] operator All (#3124)

* [Relay][TOPI] operator All

* Update tests/python/frontend/tensorflow/test_forward.py

Co-Authored-By: yongwww <55wuyong@163.com>

* fix comments

* change to level 4
---
 docs/api/python/topi.rst                 |  2 +
 docs/langref/relay_op.rst                |  2 +
 include/tvm/expr_operator.h              |  7 ++
 python/tvm/relay/frontend/tensorflow.py  | 12 ++++
 python/tvm/relay/op/_reduce.py           |  1 +
 python/tvm/relay/op/reduce.py            | 66 +++++++++++++++++--
 src/lang/expr_operator.cc                | 10 +++
 src/relay/op/tensor/reduce.cc            | 37 +++++++++++
 .../frontend/tensorflow/test_forward.py  | 12 ++++
 tests/python/relay/test_op_level4.py     |  7 +-
 topi/include/topi/reduction.h            | 21 ++++++
 topi/python/topi/reduction.py            | 25 +++++++
 topi/src/topi.cc                         |  5 ++
 topi/tests/python/test_topi_reduce.py    | 47 +++++++----
 14 files changed, 232 insertions(+), 22 deletions(-)

diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index eaa5dacd678e..0b217d4fe3af 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -88,6 +88,7 @@ List of operators
    topi.not_equal
    topi.greater_equal
    topi.less_equal
+   topi.all
    topi.logical_and
    topi.logical_or
    topi.logical_not
@@ -140,6 +141,7 @@ topi
 .. autofunction:: topi.gather_nd
 .. autofunction:: topi.full
 .. autofunction:: topi.full_like
+.. autofunction:: topi.all
 .. autofunction:: topi.max
 .. autofunction:: topi.sum
 .. autofunction:: topi.min
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index cd5677293571..836f8f30bfa8 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -135,6 +135,7 @@ This level enables additional math and transform operators.
    tvm.relay.greater_equal
    tvm.relay.less
    tvm.relay.less_equal
+   tvm.relay.all
    tvm.relay.logical_and
    tvm.relay.logical_or
    tvm.relay.logical_not
@@ -277,6 +278,7 @@ Level 4 Definitions
 .. autofunction:: tvm.relay.greater_equal
 .. autofunction:: tvm.relay.less
 .. autofunction:: tvm.relay.less_equal
+.. autofunction:: tvm.relay.all
 .. autofunction:: tvm.relay.logical_and
 .. autofunction:: tvm.relay.logical_or
 .. autofunction:: tvm.relay.logical_not
diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h
index 2e1348e00470..f289bdd810d5 100644
--- a/include/tvm/expr_operator.h
+++ b/include/tvm/expr_operator.h
@@ -428,6 +428,13 @@ TVM_DLL Expr abs(Expr x);
  */
 TVM_DLL Expr sum(Expr source, Array<IterVar> axis);
 
+/*!
+ * \brief logical AND of source expression over axis
+ * \param source The source expression.
+ * \param axis List of iteration variables that will be used for reduction.
+ */
+TVM_DLL Expr all(Expr source, Array<IterVar> axis);
+
 /*!
  * \brief max of source expression over axis
  * \param source The source expression.
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 11026b9e5ad8..7fe82ea7eac1 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -767,6 +767,17 @@ def _impl(inputs, attr, params):
                        ignores=['name', 'Tidx'])([inputs[0]], attr)
     return _impl
 
+def _reduce_all():
+    def _impl(inputs, attr, params):
+        axis = params.pop(inputs[1].name_hint).asnumpy()
+        axis = tuple(axis)
+        return AttrCvt(
+            op_name='all',
+            extras={'axis': axis},
+            transforms={'keep_dims':'keepdims'},
+            ignores=['name', 'Tidx'])([inputs[0]], attr)
+    return _impl
+
 def _square():
     def _impl(inputs, attr, params):
         return _op.multiply(inputs[0], inputs[0])
@@ -1180,6 +1191,7 @@ def _impl(inputs, attr, params):
 # for N to 1 mapping, currently not supported(?)
 _convert_map = {
     'Add'                               : _elemwise('add'),
+    'All'                               : _reduce_all(),
     'ArgMax'                            : _argx(_op.argmax, 'argmax'),
     'ArgMin'                            : _argx(_op.argmin, 'argmin'),
     'AvgPool'                           : _pooling('avg_pool'),
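The converter above recovers the reduction axes from the node's second input (a constant, hence available in `params`) and renames TensorFlow's `keep_dims` attribute to Relay's `keepdims`. For reference, a minimal sketch of a graph it handles, modeled on the test added later in this patch; API names are the 2019-era TF 1.x / TVM ones, and the exact return signature of `from_tensorflow` around this commit is an assumption here:

    import numpy as np
    import tensorflow as tf
    from tvm import relay

    # A TF 1.x graph containing a single `All` node.
    tf.reset_default_graph()
    in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
    tf.reduce_all(in_data, axis=(0, 1), name="all")
    graph_def = tf.get_default_graph().as_graph_def()

    # The frontend maps the `All` node to the new relay `all` operator;
    # the (sym, params) unpacking is assumed for this era of the API.
    sym, params = relay.frontend.from_tensorflow(graph_def,
                                                 shape={'in_data': (5, 7, 11)})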
diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py
index b97e3a8ce993..b7c9a79a8ad9 100644
--- a/python/tvm/relay/op/_reduce.py
+++ b/python/tvm/relay/op/_reduce.py
@@ -30,6 +30,7 @@ def _schedule_reduce(_, outs, target):
 _reg.register_schedule("argmax", _schedule_reduce)
 _reg.register_schedule("argmin", _schedule_reduce)
 _reg.register_schedule("sum", _schedule_reduce)
+_reg.register_schedule("all", _schedule_reduce)
 _reg.register_schedule("max", _schedule_reduce)
 _reg.register_schedule("min", _schedule_reduce)
 _reg.register_schedule("prod", _schedule_reduce)
diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py
index 9d58a92041f3..0f2594600b0a 100644
--- a/python/tvm/relay/op/reduce.py
+++ b/python/tvm/relay/op/reduce.py
@@ -39,7 +39,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-        NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -69,7 +69,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-        NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -100,7 +100,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-        NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -111,6 +111,58 @@ def sum(data, axis=None, keepdims=False, exclude=False):
     return _make.sum(data, axis, keepdims, exclude)
 
 
+def all(data, axis=None, keepdims=False, exclude=False):
+    """Computes the logical AND of boolean array elements over given axes.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input boolean tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which the logical AND is performed. The default,
+        axis=None, will perform the logical AND over all elements of the input
+        array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as
+        dimensions with size one. With this option, the result will broadcast
+        correctly against the input array.
+
+    exclude : bool
+        If `exclude` is true, reduction will be performed on the axes that are
+        NOT in axis instead.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        data = relay.Constant(tvm.nd.array([[[ True,  True,  True],
+                                             [ True,  True,  True],
+                                             [False,  True, False]],
+                                            [[ True, False, False],
+                                             [ True,  True, False],
+                                             [False,  True,  True]]]))
+
+        relay.all(data, axis=1)
+        # [[False,  True, False],
+        # [False, False, False]]
+
+        relay.all(data, axis=0)
+        # [[ True, False, False],
+        # [ True,  True, False],
+        # [False,  True, False]]
+
+    """
+    axis = [axis] if isinstance(axis, int) else axis
+    return _make.all(data, axis, keepdims, exclude)
+
+
 def max(data, axis=None, keepdims=False, exclude=False):
     """ Computes the max of array elements over given axes.
 
@@ -131,7 +183,7 @@ def max(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-        NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -163,7 +215,7 @@ def min(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-        NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
@@ -194,7 +246,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-        NOT in axis instead.
+        NOT in axis instead.
 
     Returns
    -------
@@ -225,7 +277,7 @@ def prod(data, axis=None, keepdims=False, exclude=False):
 
     exclude : bool
         If `exclude` is true, reduction will be performed on the axes that are
-        NOT in axis instead.
+        NOT in axis instead.
 
     Returns
     -------
diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc
index 4504ee23f812..8537f17b763c 100644
--- a/src/lang/expr_operator.cc
+++ b/src/lang/expr_operator.cc
@@ -393,6 +393,16 @@ Expr sum(Expr source, Array<IterVar> rdom) {
   return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
+Expr all(Expr source, Array<IterVar> rdom) {
+  CHECK(source.type().is_bool());
+  Var x("x", source.type()), y("y", source.type());
+  Expr result = ir::And::make(x, y);
+  Expr identity_element = make_const(source.type(), true);
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+}
+
 Expr max(Expr source, Array<IterVar> rdom) {
   Var x("x", source.type()), y("y", source.type());
   Expr result = ir::Max::make(x, y);
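The combiner built in `all` above pairs logical AND with the identity element `true`, so empty reduction domains yield `true`. It has the same shape as a reducer declared from Python with `tvm.comm_reducer`; a sketch of that analogue, illustrative only and not part of the patch (2019-era API names):

    import tvm

    # Logical AND as a commutative reduction with identity `true`,
    # mirroring the CommReducerNode constructed in expr_operator.cc.
    all_reducer = tvm.comm_reducer(lambda x, y: tvm.all(x, y),
                                   lambda t: tvm.const(True, dtype=t),
                                   name="all")

    n = tvm.var("n")
    A = tvm.placeholder((n, 8), dtype="bool", name="A")
    k = tvm.reduce_axis((0, 8), name="k")
    B = tvm.compute((n,), lambda i: all_reducer(A[i, k], axis=k), name="B")

Note that the Python helper `tvm.all(x, y)` used for the combine step simply builds an `And` expression over its arguments; it is distinct from the C++ reduction `tvm::all` that this patch defines.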
diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc
index a4ebd1e8d050..647e4d0f4f90 100644
--- a/src/relay/op/tensor/reduce.cc
+++ b/src/relay/op/tensor/reduce.cc
@@ -355,6 +355,43 @@ Example::
 .set_attr<TOpPattern>("TOpPattern", kCommReduce);
 
 
+Array<Tensor> AllCompute(const Attrs& attrs,
+                         const Array<Tensor>& inputs,
+                         const Type& out_type,
+                         const Target& target) {
+  return ReduceCompute(attrs, inputs, out_type, target, topi::all);
+}
+
+
+RELAY_REGISTER_REDUCE_OP("all")
+.describe(R"code(Computes the logical AND of boolean array elements over given axes.
+
+Example::
+
+  data = [[[ True,  True,  True],
+           [ True,  True,  True],
+           [False,  True, False]],
+          [[ True, False, False],
+           [ True,  True, False],
+           [False,  True,  True]]]
+
+  all(data, axis=1)
+  [[False,  True, False],
+   [False, False, False]]
+
+  all(data, axis=0)
+  [[ True, False, False],
+   [ True,  True, False],
+   [False,  True, False]]
+
+)code" TVM_ADD_FILELINE)
+.set_attrs_type_key("relay.attrs.ReduceAttrs")
+.set_support_level(4)
+.add_type_rel("Reduce", ReduceRel)
+.set_attr<FTVMCompute>("FTVMCompute", AllCompute)
+.set_attr<TOpPattern>("TOpPattern", kCommReduce);
+
+
 Array<Tensor> MaxCompute(const Attrs& attrs,
                          const Array<Tensor>& inputs,
                          const Type& out_type,
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index e4626e0d60ff..023cdf5eb261 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1597,6 +1597,17 @@ def check_mean(ishape, **kwargs):
     check_mean((10, 8, 16, 32), axis=(2,3))
     check_mean((10, 8, 16, 32), axis=(1,2), keepdims=True)
 
+#######################################################################
+# All
+# ---
+def test_forward_all():
+    """Test the All operator."""
+    np_data = np.random.choice([True, False], size=(5, 7, 11))
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data")
+    tf.reduce_all(in_data, name="all")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0')
+
 #######################################################################
 # Relational operators
 # --------------------
@@ -1718,6 +1729,7 @@ def test_placeholder():
     test_forward_reduce()
     test_forward_mean()
     test_forward_reduce_prod()
+    test_forward_all()
 
     # General
     test_forward_multi_input()
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index 0e44bf851dc4..aac4a6d4af16 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -138,6 +138,7 @@ def test_where():
 def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
     test_func = funcs[0]
     ref_func = funcs[1]
+    dtype = "bool" if ref_func in [np.all] else dtype
 
     x = relay.var("x", relay.TensorType(data, dtype))
     z = test_func(x, axis, keepdims, exclude)
@@ -155,7 +156,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
         return
 
     func = relay.Function([x], z)
-    x_data = np.random.uniform(size=data).astype(dtype)
+    x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \
+        else np.random.uniform(size=data).astype(dtype)
+
     if ref_func in [np.sum]:
         ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
     elif ref_func in [np.max, np.min, np.mean, np.prod]:
@@ -194,6 +197,7 @@ def _wrapper(data, axis=None, keepdims=False):
                  [relay.min, np.min],
                  [relay.mean, np.mean],
                  [relay.prod, np.prod],
+                 [relay.all, np.all],
                  [relay.argmin, _with_keepdims(np.argmin)],
                  [relay.argmax, _with_keepdims(np.argmax)]]:
         verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
@@ -203,6 +207,7 @@ def _wrapper(data, axis=None, keepdims=False):
         verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
         verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
         verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
+        verify_reduce(func, (2, 3, 4), -1, True, False, (2, 3, 1))
         verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
         verify_reduce(func, (4, 4, 3), None, False, False, ())
         verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
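The level-4 tests above drive `relay.all` through the interpreter via the `verify_reduce` helper; a standalone sketch of the same path, using the era's `relay.create_executor` / `asnumpy` API:

    import numpy as np
    import tvm
    from tvm import relay

    # Reduce a (2, 3) boolean tensor along axis 1 with logical AND.
    x = relay.var("x", shape=(2, 3), dtype="bool")
    func = relay.Function([x], relay.all(x, axis=1))

    x_data = np.array([[True, True, False],
                       [True, True, True]])

    intrp = relay.create_executor("debug", ctx=tvm.cpu(0), target="llvm")
    print(intrp.evaluate(func)(x_data).asnumpy())  # -> [False  True]

With `exclude=True` the meaning of `axis` flips: the reduction runs over every axis not listed, so `relay.all(x, axis=[0], exclude=True)` on this input also reduces along axis 1.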
diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h
index b24c4577c4e5..09d1b4b1b33e 100644
--- a/topi/include/topi/reduction.h
+++ b/topi/include/topi/reduction.h
@@ -368,6 +368,27 @@ inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
   return DoCommReduce(data, tvm::sum, target_shape, reduce_axes, squeeze_axes);
 }
 
+/*!
+* \brief Creates an operation that computes the logical AND of elements
+* over a given axis
+*
+* \param data The input boolean tensor
+* \param axis The axes to reduce. If axis is empty, the operation will
+* perform logical AND over all elements of the array.
+* \param keepdims If this is set to true, the axes which are reduced are
+* left in the result as dimensions with size one. This enables the result
+* to broadcast correctly against the input array.
+* \param atleast1d Whether the output needs to be at least 1-dimensional.
+*
+* \return A Tensor whose op member is the all operation
+*/
+inline Tensor all(const Tensor& data,
+                  const Array<Integer>& axis,
+                  bool keepdims = false,
+                  bool atleast1d = false) {
+  return CommReduce(data, axis, tvm::all, keepdims, atleast1d);
+}
+
 /*!
 * \brief Creates an operation that finds the minimum of elements over
 * a given axis.
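A small sketch of driving this new TOPI reduction end to end on CPU, modeled on the updated `test_topi_reduce.py` below (the `import topi`, `tvm.target.create`, and generic-schedule calls follow the 2019-era API):

    import numpy as np
    import tvm
    import topi

    # Declare the reduction and schedule it with the generic reduce schedule.
    A = tvm.placeholder((2, 3), dtype="bool", name="A")
    B = topi.all(A, axis=1)
    with tvm.target.create("llvm"):
        s = topi.generic.schedule_reduce(B)
    f = tvm.build(s, [A, B], "llvm", name="all")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.array([[True, False, True],
                               [True, True, True]]), ctx)
    b = tvm.nd.array(np.empty((2,), dtype="bool"), ctx)
    f(a, b)
    print(b.asnumpy())  # -> [False  True]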
diff --git a/topi/python/topi/reduction.py b/topi/python/topi/reduction.py
index ce1326b78162..5079bf474deb 100644
--- a/topi/python/topi/reduction.py
+++ b/topi/python/topi/reduction.py
@@ -65,6 +65,31 @@ def sum(data, axis=None, keepdims=False):
     return cpp.sum(data, axis, keepdims)
 
 
+def all(data, axis=None, keepdims=False):
+    """Logical AND of array elements over a given axis or a list of axes
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tvm boolean tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which a logical AND is performed.
+        The default, axis=None, will perform logical AND over all elements of the input array.
+        If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return cpp.all(data, axis, keepdims)
+
+
 def max(data, axis=None, keepdims=False):
     """Maximum of array elements over a given axis or a list of axes
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 1585d877b625..d3e0bc938f7c 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -265,6 +265,11 @@ TVM_REGISTER_GLOBAL("topi.prod")
   *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
   });
 
+TVM_REGISTER_GLOBAL("topi.all")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]);
+  });
+
 /* Ops from transform.h */
 TVM_REGISTER_GLOBAL("topi.expand_dims")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py
index 1882cbd7f896..6e6470dad588 100644
--- a/topi/tests/python/test_topi_reduce.py
+++ b/topi/tests/python/test_topi_reduce.py
@@ -50,6 +50,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
     out_dtype = dtype
     if type == "sum":
         B = topi.sum(A1, axis=axis, keepdims=keepdims)
+    elif type == "all":
+        B = topi.all(A, axis=axis, keepdims=keepdims)
     elif type == "max":
         B = topi.max(A1, axis=axis, keepdims=keepdims)
     elif type == "min":
@@ -74,10 +76,16 @@ def check_device(device):
         foo = tvm.build(s, [A, B], device, name=type)
 
         # Test
-        in_npy = np.random.uniform(size=in_shape).astype(dtype)
-        in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
+        if dtype == 'bool':
+            in_npy_map = in_npy = np.random.choice([True, False], size=in_shape)
+        else:
+            in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
+            in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
+
         if type == "sum":
             out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
+        elif type == "all" and dtype == 'bool':
+            out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
         elif type == "max":
             out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
         elif type == "min":
@@ -113,26 +121,37 @@ def check_device(device):
 
 def test_reduce_map():
+    verify_reduce_map_ele(in_shape=(32,), axis=0, keepdims=False, type="argmax")
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                          axis=(1, 2, 3),
-                          keepdims=True,
-                          type="sum")
+                          axis=(1, 2, 3),
+                          keepdims=True,
+                          type="sum")
+    verify_reduce_map_ele(in_shape=(2, 3),
+                          axis=None,
+                          keepdims=True,
+                          type="all",
+                          dtype='bool')
     verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
-                          axis=(1,),
-                          keepdims=False,
-                          type="max")
+                          axis=(1,),
+                          keepdims=False,
+                          type="max")
+    verify_reduce_map_ele(in_shape=(32, 128, 24),
+                          axis=None,
+                          keepdims=True,
+                          type="sum")
     verify_reduce_map_ele(in_shape=(32, 128, 24),
-                        axis=None,
-                        keepdims=True,
-                        type="sum")
+                          axis=None,
+                          keepdims=True,
+                          dtype='bool',
+                          type="all")
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                          axis=(0, 2),
-                          keepdims=False,
-                          type="min")
+                          axis=(0, 2),
+                          keepdims=False,
+                          type="min")
     verify_reduce_map_ele(in_shape=(32, 128),
                           axis=1,
                           keepdims=True,