Skip to content

Commit

Permalink
fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed May 14, 2019
1 parent 2a18f74 commit a94bb1f
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 41 deletions.
2 changes: 1 addition & 1 deletion docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ Level 4 Definitions
.. autofunction:: tvm.relay.greater_equal
.. autofunction:: tvm.relay.less
.. autofunction:: tvm.relay.less_equal
.. autofunction:: tvm.relay.all
.. autofunction:: tvm.relay.logical_and
.. autofunction:: tvm.relay.logical_or
.. autofunction:: tvm.relay.logical_not
Expand Down Expand Up @@ -310,7 +311,6 @@ Level 6 Definitions

Level 10 Definitions
--------------------
.. autofunction:: tvm.relay.all
.. autofunction:: tvm.relay.broadcast_to_like
.. autofunction:: tvm.relay.collapse_sum_like
.. autofunction:: tvm.relay.slice_like
Expand Down
41 changes: 31 additions & 10 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.
Returns
-------
Expand Down Expand Up @@ -69,7 +69,7 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.
Returns
-------
Expand Down Expand Up @@ -100,7 +100,7 @@ def sum(data, axis=None, keepdims=False, exclude=False):
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.
Returns
-------
Expand All @@ -112,12 +112,12 @@ def sum(data, axis=None, keepdims=False, exclude=False):


def all(data, axis=None, keepdims=False, exclude=False):
"""Computes the logical and of array elements over given axes.
"""Computes the logical AND of boolean array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
The input boolean tensor
axis : None or int or tuple of int
Axis or axes along which a sum is performed. The default, axis=None,
Expand All @@ -131,12 +131,33 @@ def all(data, axis=None, keepdims=False, exclude=False):
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
Examples
--------
.. code-block:: python
data = relay.Constant(tvm.nd.array([[[ True, True, True],
[ True, True, True],
[False, True, False]],
[[ True, False, False],
[ True, True, False],
[False, True, True]]]))
relay.all(data, axis=1)
# [[False, True, False],
# [False, False, False]]
relay.all(data, axis=0)
# [[ True, False, False],
# [ True, True, False],
# [False, True, False]]
"""
axis = [axis] if axis and isinstance(axis, int) else axis
return _make.all(data, axis, keepdims, exclude)
Expand All @@ -162,7 +183,7 @@ def max(data, axis=None, keepdims=False, exclude=False):
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.
Returns
-------
Expand Down Expand Up @@ -194,7 +215,7 @@ def min(data, axis=None, keepdims=False, exclude=False):
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.
Returns
-------
Expand Down Expand Up @@ -225,7 +246,7 @@ def mean(data, axis=None, keepdims=False, exclude=False):
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.
Returns
-------
Expand Down Expand Up @@ -256,7 +277,7 @@ def prod(data, axis=None, keepdims=False, exclude=False):
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
NOT in axis instead.
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ Array<Tensor> AllCompute(const Attrs& attrs,


RELAY_REGISTER_REDUCE_OP("all")
.describe(R"code(Computes the logical AND of array elements over given axes.
.describe(R"code(Computes the logical AND of boolean array elements over given axes.
Example::
Expand Down
24 changes: 0 additions & 24 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,34 +208,10 @@ def test_shape_of():
tvm.testing.assert_allclose(op_res.asnumpy(),
np.array(shape).astype('int32'))

def test_contrib_all():
def verify_contrib_all(shape, axis=None, keepdims=False, exclude=False):
x = relay.var("x", relay.TensorType(shape, "bool"))
z = relay.all(x, axis, keepdims, exclude)
zz = relay.ir_pass.infer_type(z)
if axis:
assert "axis=" in z.astext()
if keepdims:
assert "keepdims=" in z.astext()
if exclude:
assert "exclude=" in z.astext()

func = relay.Function([x], z)
x_data = np.random.choice([True, False], size=shape)
ref_res = np.all(x_data, axis=axis, keepdims=keepdims)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_contrib_all((2, 3, 4), axis=(0,))
verify_contrib_all((2, 3, 4, 5, 6), axis=(2, 3), keepdims=True)

if __name__ == "__main__":
test_collapse_sum_like()
test_broadcast_to_like()
test_slice_like()
test_reverse_reshape()
test_batch_matmul()
test_shape_of()
test_contrib_all()
7 changes: 6 additions & 1 deletion tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def test_where():
def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
test_func = funcs[0]
ref_func = funcs[1]
dtype = "bool" if ref_func in [np.all] else dtype

x = relay.var("x", relay.TensorType(data, dtype))
z = test_func(x, axis, keepdims, exclude)
Expand All @@ -155,7 +156,9 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32")
return

func = relay.Function([x], z)
x_data = np.random.uniform(size=data).astype(dtype)
x_data = np.random.choice([True, False], size=data) if ref_func in [np.all] \
else np.random.uniform(size=data).astype(dtype)

if ref_func in [np.sum]:
ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
elif ref_func in [np.max, np.min, np.mean, np.prod]:
Expand Down Expand Up @@ -194,6 +197,7 @@ def _wrapper(data, axis=None, keepdims=False):
[relay.min, np.min],
[relay.mean, np.mean],
[relay.prod, np.prod],
[relay.all, np.all],
[relay.argmin, _with_keepdims(np.argmin)],
[relay.argmax, _with_keepdims(np.argmax)]]:
verify_reduce(func, (d1, d2, d3, d4), None, False, False, ())
Expand All @@ -203,6 +207,7 @@ def _wrapper(data, axis=None, keepdims=False):
verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
verify_reduce(func, (2, 3, 4), -1, True, False, (2, 3, 1))
verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
verify_reduce(func, (4, 4, 3), None, False, False, ())
verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
Expand Down
6 changes: 3 additions & 3 deletions topi/include/topi/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,9 @@ inline Tensor collapse_sum(const Tensor& data, Array<Expr> target_shape) {
* \brief Creates an operation that computes the logical AND of elements
* over a given axis
*
* \param data The input tensor
* \param axis The axis to perform logical AND over. If axis is empty, the
* operation will perform logical AND over all elements of the array.
* \param data The input boolean tensor
* \param axis The axes to reduce. If axis is empty, the operation will
* perform logical AND over all elements of the array.
* \param keepdims If this is set to true, the axes which are reduced are
* left in the result as dimensions with size one. This enables the result
* to broadcast correctly against the input array.
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def all(data, axis=None, keepdims=False):
Parameters
----------
data : tvm.Tensor
The input tvm tensor
The input tvm boolean tensor
axis : None or int or tuple of int
Axis or axes along which a logical AND is performed.
Expand Down

0 comments on commit a94bb1f

Please sign in to comment.