diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 0bc7923648ff..0975a33450c8 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -683,6 +683,21 @@ def _mx_argsort(inputs, attrs): return _op.argsort(inputs[0], **new_attrs) +def _mx_topk(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["k"] = attrs.get_int("k", 1) + new_attrs["axis"] = attrs.get_int("axis", -1) + new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True) + ret_type = attrs.get_str("ret_typ", "indices") + if ret_type == "mask": + raise tvm.error.OpAttributeUnimplemented( + "Attribute ret_type=mask is not supported in topk operator") + new_attrs["ret_type"] = "values" if ret_type == "value" else ret_type + new_attrs["dtype"] = attrs.get_str("dtype", "float32") + return _op.topk(inputs[0], **new_attrs) + + def _mx_rnn_param_concat(inputs, _): # We don't need to concatenate RNN params because we will unravel the RNN op return [inputs] @@ -914,6 +929,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): "shape_array" : _mx_shape_array, "Embedding" : _mx_embedding, "argsort" : _mx_argsort, + "topk" : _mx_topk, "SoftmaxOutput" : _mx_softmax_output, "SoftmaxActivation" : _mx_softmax_activation, "LinearRegressionOutput" : _mx_linear_regression_output, diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 87cb91523a3f..307fb20693f4 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1084,11 +1084,16 @@ def _impl(inputs, attr, params): def _topk(): def _impl(inputs, attr, params): - k = params.pop(inputs.pop(1).name_hint).asnumpy() - is_ascend = False if attr['sorted'] else None # TODO: arg value for original ordering + k = int(params.pop(inputs.pop(1).name_hint).asnumpy()) + if k < 1: + raise tvm.error.OpAttributeInvalid( + 'Attribute k must be positive in operator TopKV2') + if attr['sorted'] is False: + raise tvm.error.OpAttributeUnimplemented( + 'Attribute sorted=False is not supported in operator TopKV2') return AttrCvt(op_name='topk', ignores=['sorted'], - extras={'k': int(k), 'is_ascend': is_ascend})(inputs, attr) + extras={'k': k, 'is_ascend': False, 'dtype': 'int32'})(inputs, attr) return _impl def _logical(name): diff --git a/python/tvm/relay/op/algorithm.py b/python/tvm/relay/op/algorithm.py index d0336095d264..6f875919df4c 100644 --- a/python/tvm/relay/op/algorithm.py +++ b/python/tvm/relay/op/algorithm.py @@ -17,8 +17,9 @@ """Classic algorithm operation""" from __future__ import absolute_import as _abs from . import _make +from ..expr import TupleWrapper -def argsort(data, axis=-1, is_ascend=1, dtype="float32"): +def argsort(data, axis=-1, is_ascend=1, dtype="int32"): """Performs sorting along the given axis and returns an array of indicies having same shape as an input array that index data in sorted order. @@ -47,7 +48,7 @@ def argsort(data, axis=-1, is_ascend=1, dtype="float32"): return _make.argsort(data, axis, is_ascend, dtype) -def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): +def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int32"): """Get the top k elements in an input tensor along the given axis. ret_type specifies the return type, can be one of ("both", "values", "indices"). @@ -80,4 +81,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): out : relay.Expr or List[relay.Expr] The computed result. """ - return _make.topk(data, k, axis, ret_type, is_ascend, dtype) + out = _make.topk(data, k, axis, ret_type, is_ascend, dtype) + if ret_type == "both": + return TupleWrapper(out, 2) + return out diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 50a25a9aff61..7569257830af 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -608,6 +608,45 @@ def verify(xshape, yshape, offset=None): verify((5, 32, 40, 40), (5, 32, 25, 25)) verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5)) +def test_forward_argsort(): + def verify(shape, axis, is_ascend, dtype="float32"): + x_np = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.argsort(mx.nd.array(x_np), axis=axis, is_ascend=is_ascend, dtype=dtype) + mx_sym = mx.sym.argsort(mx.sym.var("x"), axis=axis, is_ascend=is_ascend, dtype=dtype) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((2, 3, 4), axis=0, is_ascend=False) + verify((1, 4, 6), axis=1, is_ascend=True) + verify((3, 5, 6), axis=-3, is_ascend=False, dtype="int32") + +def test_forward_topk(): + def verify(shape, k, axis, ret_type, is_ascend=False, dtype="float32"): + x_np = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.topk(mx.nd.array(x_np), k=k, axis=axis, ret_typ=ret_type, + is_ascend=is_ascend, dtype=dtype) + mx_sym = mx.sym.topk(mx.sym.var("x"), k=k, axis=axis, ret_typ=ret_type, + is_ascend=is_ascend, dtype=dtype) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(x_np) + if isinstance(ref_res, list): + assert len(op_res) == len(ref_res) + for i, t in enumerate(op_res): + tvm.testing.assert_allclose(t.asnumpy(), ref_res[i].asnumpy()) + else: + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((3, 4), k=1, axis=0, ret_type="both") + verify((3, 4), k=1, axis=-1, ret_type="indices") + verify((3, 5, 6), k=2, axis=2, ret_type="value") + verify((3, 5, 6), k=2, axis=1, ret_type="value", is_ascend=True) + verify((3, 5, 6), k=0, axis=2, ret_type="both", dtype="int32") + if __name__ == '__main__': test_forward_mlp() @@ -650,3 +689,5 @@ def verify(xshape, yshape, offset=None): test_forward_bilinear_resize() test_forward_rnn_layer() test_forward_Crop() + test_forward_argsort() + test_forward_topk() diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 59428d34f5cd..eebb73c95b1b 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -758,20 +758,18 @@ def test_forward_split(): # TopKV2 # ------ -def _test_forward_top_k_v2(in_shape, k, sorted, dtype): - np_data = np.random.uniform(-100, 100, size=in_shape).astype(dtype) - +def _test_forward_top_k_v2(in_shape, k): + np_data = np.random.uniform(-100, 100, size=in_shape).astype("float32") tf.reset_default_graph() - in_data = tf.placeholder(dtype, in_shape, name="in_data") - tf.math.top_k(in_data, k, sorted, name='TopK') + in_data = tf.placeholder("float32", in_shape, name="in_data") + tf.math.top_k(in_data, k, name='TopK') compare_tf_with_tvm([np_data], ['in_data:0'], 'TopK:0') def test_forward_top_k_v2(): - #_test_forward_top_k_v2((3,), 0, True, 'int32') - _test_forward_top_k_v2((3,), 1, True, 'float32') - _test_forward_top_k_v2((3,), 3, True, 'float32') - _test_forward_top_k_v2((3, 5, 7), 3, True, 'float32') - #_test_forward_top_k_v2((3, 5, 13), 11, False, 'float32') + _test_forward_top_k_v2((3,), 1) + _test_forward_top_k_v2((3,), 3) + _test_forward_top_k_v2((3, 5, 7), 3) + _test_forward_top_k_v2((3, 5, 7), 3) #######################################################################