Skip to content

Commit

Permalink
[RELAY]Reduce ops sum/max/min/mean/prod (apache#1927)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and Wei Chen committed Feb 19, 2019
1 parent 3ec1191 commit 76aa7a8
Show file tree
Hide file tree
Showing 4 changed files with 308 additions and 29 deletions.
10 changes: 10 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ This level enables additional math and transform operators.
tvm.relay.where
tvm.relay.argmax
tvm.relay.argmin
tvm.relay.sum
tvm.relay.max
tvm.relay.min
tvm.relay.mean
tvm.relay.prod


**Level 5: Vision/Image Operators**
Expand Down Expand Up @@ -187,6 +192,11 @@ Level 4 Definitions
.. autofunction:: tvm.relay.where
.. autofunction:: tvm.relay.argmax
.. autofunction:: tvm.relay.argmin
.. autofunction:: tvm.relay.sum
.. autofunction:: tvm.relay.max
.. autofunction:: tvm.relay.min
.. autofunction:: tvm.relay.mean
.. autofunction:: tvm.relay.prod


Level 5 Definitions
Expand Down
152 changes: 150 additions & 2 deletions python/tvm/relay/op/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def argmax(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""

return _make.argmax(data, axis, keepdims, exclude)

def argmin(data, axis=None, keepdims=False, exclude=False):
Expand Down Expand Up @@ -60,5 +59,154 @@ def argmin(data, axis=None, keepdims=False, exclude=False):
result : relay.Expr
The computed result.
"""

return _make.argmin(data, axis, keepdims, exclude)


def sum(data, axis=None, keepdims=False, exclude=False):
"""Computes the sum of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.sum(data, axis, keepdims, exclude)


def max(data, axis=None, keepdims=False, exclude=False):
""" Computes the max of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.max(data, axis, keepdims, exclude)


def min(data, axis=None, keepdims=False, exclude=False):
"""Computes the min of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.min(data, axis, keepdims, exclude)


def mean(data, axis=None, keepdims=False, exclude=False):
"""Computes the mean of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.mean(data, axis, keepdims, exclude)


def prod(data, axis=None, keepdims=False, exclude=False):
"""Computes the products of array elements over given axes.
Parameters
----------
data : relay.Expr
The input data
axis : None or int or tuple of int
Axis or axes along which a argmin operation is performed.
The default, axis=None, will find the indices of minimum element all of the elements of
the input array. If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
exclude : bool
If `exclude` is true, reduction will be performed on the axes that are
NOT in axis instead.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.prod(data, axis, keepdims, exclude)
116 changes: 111 additions & 5 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <tvm/relay/op.h>
#include <numeric>
#include <limits>
#include "../op_common.h"
#include "../type_relations.h"

namespace tvm {
Expand All @@ -19,7 +20,7 @@ struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
bool exclude;

TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {
TVM_ATTR_FIELD(axis).set_default(Array<IndexExpr>({}))
TVM_ATTR_FIELD(axis).set_default(NullValue<Array<IndexExpr>>())
.describe(R"code(The axis or axes along which to perform the reduction.
The default, `axis=()`, will compute over all elements into a
Expand Down Expand Up @@ -158,10 +159,7 @@ bool ArgReduceRel(const Array<Type>& types,
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr> in_shape;
for (auto i : data->shape) {
in_shape.push_back(i);
}
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);

const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);
Expand All @@ -172,6 +170,31 @@ bool ArgReduceRel(const Array<Type>& types,
return true;
}

/*!
* \brief ReduceRel Output type and shape relation evaluation function.
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved. true if this relation has been resolved.
*/
bool ReduceRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
CHECK(static_cast<int>(data->shape.size()) != 0);
std::vector<IndexExpr>&& in_shape = AsVector(data->shape);

const ReduceAttrs* param = attrs.as<ReduceAttrs>();
CHECK(param != nullptr);

// assign output type and shape
auto oshape = ReduceShapeImpl(in_shape, param, reporter);
reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
return true;
}

#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
Expand Down Expand Up @@ -213,5 +236,88 @@ values over a given axis.
.set_support_level(4)
.add_type_rel("ArgReduce", ArgReduceRel);


RELAY_REGISTER_REDUCE_OP("sum")
.describe(R"code(Computes the sum of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
sum(data, axis=1)
[[ 4. 8.]
[ 10. 9.]
[ 21. 6.]]
sum(data, axis=[1,2])
[ 12. 19. 27.]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("max")
.describe(R"code(Computes the max of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("min")
.describe(R"code(Computes the min of array elements over given axes.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("mean")
.describe(R"code(Computes the mean of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data)
[3.22]
mean(data, axis=[1,2])
[ 2. 3.16666667 4.5]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);


RELAY_REGISTER_REDUCE_OP("prod")
.describe(R"code(Computes the products of array elements over given axes.
Example::
data = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
mean(data, axis=1)
[35562240]
mean(data, axis=[1,2])
[ 36 480 2058]
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.ReduceAttrs")
.set_support_level(4)
.add_type_rel("Reduce", ReduceRel);

} // namespace relay
} // namespace tvm
Loading

0 comments on commit 76aa7a8

Please sign in to comment.