[Relay/TOPI][Op] Add TopK operator (apache#3256)
* init impl for topk

* Fix cpu for topk

* init cuda impl for topk

* Add cuda for topk

* fix

* Add doc

* update doc

* lint

* lint

* lint

* x

* fix warning

* [Relay] Add TopK in tf converter

* Add frontend converter

* fix
icemelon authored and wweic committed Jun 27, 2019
1 parent 383a1fa commit be585d3
Showing 24 changed files with 904 additions and 148 deletions.
4 changes: 4 additions & 0 deletions docs/api/python/topi.rst
@@ -99,6 +99,8 @@ List of operators
topi.shape
topi.layout_transform
topi.image.resize
topi.argsort
topi.topk


List of schedules
@@ -163,6 +165,8 @@ topi
.. autofunction:: topi.tile
.. autofunction:: topi.shape
.. autofunction:: topi.layout_transform
.. autofunction:: topi.argsort
.. autofunction:: topi.topk

topi.nn
~~~~~~~
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
@@ -172,6 +172,7 @@ This level enables additional math and transform operators.
:nosignatures:

tvm.relay.argsort
tvm.relay.topk


**Level 10: Temporary Operators**
@@ -309,6 +310,7 @@ Level 5 Definitions
Level 6 Definitions
-------------------
.. autofunction:: tvm.relay.argsort
.. autofunction:: tvm.relay.topk


Level 10 Definitions
25 changes: 25 additions & 0 deletions include/tvm/relay/attrs/algorithm.h
@@ -48,6 +48,31 @@ struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
}
};

struct TopKAttrs : public tvm::AttrsNode<TopKAttrs> {
int k;
int axis;
bool is_ascend;
std::string ret_type;
DataType dtype;

TVM_DECLARE_ATTRS(TopKAttrs, "relay.attrs.TopkAttrs") {
TVM_ATTR_FIELD(k).set_default(1)
.describe("Number of top elements to select");
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Axis along which to sort the input tensor.");
TVM_ATTR_FIELD(ret_type).set_default("both")
.describe("The return type [both, values, indices]."
"both - return both top k data and indices."
"values - return top k data only."
"indices - return top k indices only.");
TVM_ATTR_FIELD(is_ascend).set_default(false)
.describe("Whether to sort in ascending or descending order."
"By default, sort in descending order");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("Data type of the output indices.");
}
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_
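
The attribute fields above surface one-to-one as keyword arguments of the Relay Python API. A minimal sketch of the mapping, assuming a TVM build that includes this patch (variable names and shapes are illustrative):

```python
import tvm
from tvm import relay

# Each keyword below corresponds to a field of TopKAttrs.
x = relay.var("x", shape=(2, 8), dtype="float32")

both = relay.topk(x, k=3, axis=-1, ret_type="both", is_ascend=False,
                  dtype="int32")                   # (values, indices) pair
vals = relay.topk(x, k=3, ret_type="values")       # top-k data only
inds = relay.topk(x, k=3, ret_type="indices")      # top-k indices only
```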
16 changes: 16 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
@@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs):
return _op.argsort(inputs[0], **new_attrs)


def _mx_topk(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["k"] = attrs.get_int("k", 1)
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
ret_type = attrs.get_str("ret_typ", "indices")
if ret_type == "mask":
raise tvm.error.OpAttributeUnimplemented(
"Attribute ret_type=mask is not supported in topk operator")
new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.topk(inputs[0], **new_attrs)


def _mx_rnn_param_concat(inputs, _):
# We don't need to concatenate RNN params because we will unravel the RNN op
return [inputs]
@@ -914,6 +929,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"topk" : _mx_topk,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output,
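
Note how the converter renames MXNet's `ret_typ` values: `"value"` becomes Relay's `"values"`, and `"mask"` is rejected. A hedged sketch of the conversion path, assuming `mxnet` is installed (symbol and shape choices are illustrative):

```python
import mxnet as mx
from tvm import relay

data = mx.sym.Variable("data")
sym = mx.sym.topk(data, k=3, axis=-1, ret_typ="indices")

# Depending on the TVM version of this era, this returns a module or a
# function, together with the parameter dict.
mod, params = relay.frontend.from_mxnet(sym, shape={"data": (2, 8)})
```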
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1082,6 +1082,20 @@ def _impl(inputs, attr, params):
return _get_relay_op('log')(add_out)
return _impl

def _topk():
def _impl(inputs, attr, params):
k = int(params.pop(inputs.pop(1).name_hint).asnumpy())
if k < 1:
raise tvm.error.OpAttributeInvalid(
'Attribute k must be positive in operator TopKV2')
if attr['sorted'] is False:
raise tvm.error.OpAttributeUnimplemented(
'Attribute sorted=False is not supported in operator TopKV2')
return AttrCvt(op_name='topk',
ignores=['sorted'],
extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr)
return _impl

def _logical(name):
def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
@@ -1271,6 +1285,7 @@ def _impl(inputs, attr, params):
'Sum' : _sum(),
'Tanh' : AttrCvt('tanh'),
'Tile' : _tile(),
'TopKV2' : _topk(),
'Transpose' : _transpose(),
'Unpack' : _unpack(),

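
Unlike MXNet, TensorFlow's `TopKV2` passes `k` as a second input tensor rather than an attribute, which is why the converter pops it from `params` and folds it into the Relay attribute; only `sorted=True` graphs are accepted. A hedged sketch using the TF 1.x API of the time (names and shapes are illustrative):

```python
import tensorflow as tf
from tvm import relay

with tf.Graph().as_default():
    inp = tf.placeholder(tf.float32, shape=(2, 8), name="input")
    tf.nn.top_k(inp, k=3, sorted=True, name="topk")  # emits a TopKV2 node
    graph_def = tf.get_default_graph().as_graph_def()

# Return values vary slightly across TVM versions of this era.
mod, params = relay.frontend.from_tensorflow(graph_def,
                                             shape={"input": (2, 8)})
```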
30 changes: 25 additions & 5 deletions python/tvm/relay/op/_algorithm.py
@@ -35,11 +35,31 @@ def compute_argsort(attrs, inputs, _, target):
"""Compute definition of argsort"""
axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = str(attrs.dtype)
return [
topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
dtype=dtype, flag=False)
]
dtype = attrs.dtype
return [topi.argsort(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]


register_pattern("argsort", OpPattern.OPAQUE)


@register_schedule("topk")
def schedule_topk(_, outs, target):
"""Schedule definition of argsort"""
with target:
return topi.generic.schedule_topk(outs)


@register_compute("topk")
def compute_topk(attrs, inputs, _, target):
"""Compute definition of argsort"""
k = get_const_int(attrs.k)
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = attrs.dtype
out = topi.topk(inputs[0], k, axis, ret_type, is_ascend, dtype)
out = out if isinstance(out, list) else [out]
return out


register_pattern("topk", OpPattern.OPAQUE)
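
The compute rule lowers directly to `topi.topk` and the schedule dispatches on the active target. A minimal TOPI-level sketch under an LLVM target, assuming the standalone `topi` package that shipped with TVM at the time:

```python
import tvm
import topi

data = tvm.placeholder((2, 8), name="data", dtype="float32")
# ret_type="both" makes topi.topk return [values, indices].
values, indices = topi.topk(data, k=3, axis=-1, ret_type="both",
                            is_ascend=False, dtype="int32")
with tvm.target.create("llvm"):
    s = topi.generic.schedule_topk([values, indices])
f = tvm.build(s, [data, values, indices], "llvm")
```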
44 changes: 42 additions & 2 deletions python/tvm/relay/op/algorithm.py
@@ -17,8 +17,9 @@
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
from . import _make
from ..expr import TupleWrapper

def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
def argsort(data, axis=-1, is_ascend=1, dtype="int32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
@@ -37,11 +38,50 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
Whether to sort in ascending or descending order.
dtype : string, optional
DType of the output indices.
The data type of the output indices.
Returns
-------
out : relay.Expr
Tensor with same shape as data.
"""
return _make.argsort(data, axis, is_ascend, dtype)


def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"):
"""Get the top k elements in an input tensor along the given axis.
ret_type specifies the return type, which can be one of ("both", "values", "indices").
Parameters
----------
data : relay.Expr
The input data tensor.
k : int, optional
Number of top elements to select. Return all elements if k < 1.
axis : int, optional
Axis along which to sort the input tensor.
ret_type : str, optional
The return type [both, values, indices].
"both": return both top k data and indices.
"values": return top k data only.
"indices": return top k indices only.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
The data type of the indices output.
Returns
-------
out : relay.Expr or List[relay.Expr]
The computed result.
"""
out = _make.topk(data, k, axis, ret_type, is_ascend, dtype)
if ret_type == "both":
return TupleWrapper(out, 2)
return out
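
Because `ret_type="both"` yields a `TupleWrapper`, the two outputs unpack naturally in Python. A hedged end-to-end sketch; executor details vary slightly across TVM versions of this era:

```python
import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 8), dtype="float32")
values, indices = relay.topk(x, k=3, ret_type="both")
func = relay.Function([x], relay.Tuple([values, indices]))

ex = relay.create_executor(ctx=tvm.cpu(), target="llvm")
res = ex.evaluate(func)(np.random.rand(2, 8).astype("float32"))
# res[0]: the top-3 values per row; res[1]: their indices along axis -1
```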
2 changes: 1 addition & 1 deletion python/tvm/relay/op/nn/nn.py
@@ -401,7 +401,7 @@ def upsampling(data,
with data of shape (n, c, h, w)
out will have a shape (n, c, h*scale, w*scale)
method indicates the algorithm to be used while calculating ghe out value
method indicates the algorithm to be used while calculating the out value
and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR")
Parameters
6 changes: 3 additions & 3 deletions python/tvm/relay/op/transform.py
@@ -218,9 +218,9 @@ def take(data, indices, axis=None, mode="clip"):
the flattened input array is used.
mode : str, optional
Specifies how out-of-bound indices will behave.
clip - clip to the range (default)
wrap - wrap around the indices
Specifies how out-of-bound indices will behave [clip, wrap].
clip: clip to the range (default).
wrap: wrap around the indices.
Returns
-------
2 changes: 1 addition & 1 deletion src/codegen/build_module.cc
@@ -83,7 +83,7 @@ Target CreateTarget(const std::string& target_name,
t->device_type = kDLGPU;
t->keys_array.push_back(ir::StringImm::make("cuda"));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 512;
t->max_num_threads = 1024;
t->thread_warp_size = 32;
} else if (target_name == "rocm" || target_name == "opencl") {
// For now assume rocm schedule for opencl
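
The last hunk raises the default CUDA thread cap from 512 to 1024 threads per block. A quick way to confirm the new default from Python, assuming a build with this patch:

```python
import tvm

target = tvm.target.create("cuda")
print(target.max_num_threads)  # expected: 1024 after this change
```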