diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 245b3853ae90..f1fa5b42b31e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1470,6 +1470,22 @@ def _impl_v9(cls, inputs, attr, params): output = AttrCvt(op_name='argwhere')(inputs, attr, params) return _op.transpose(output, axes=(1, 0)) +class TopK(OnnxOpConverter): + """Operator converter for TopK + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if len(inputs) != 2: + raise ValueError("Expect 2 input only") + axis = attr.get("axis", -1) + largest = attr.get("largest", 1) + + if largest == 0: + raise ValueError("TVM only supports finding TopK largest elements") + + K = int(infer_value(inputs[1], params).asnumpy()[0]) + + return _op.topk(inputs[0], k=K, axis=axis) # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1573,8 +1589,11 @@ def _get_convert_map(opset): 'ReduceProd': ReduceProd.get_converter(opset), # 'ReduceProd' # 'ReduceLogSumExp' + + #defs/sorting 'ArgMax': ArgMax.get_converter(opset), 'ArgMin': ArgMin.get_converter(opset), + 'TopK': TopK.get_converter(opset), # defs/tensor 'Cast': Cast.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c06aa50538f4..6f4f2ae93048 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -2330,6 +2330,43 @@ def verify_nonzero(indata, outdata, dtype): result = np.array((np.nonzero(input_data))) # expected output [[0, 1, 2, 2], [0, 1, 0, 1]] verify_nonzero(input_data, result, dtype=np.int64) +def test_topk(): + def verify_topk(input_dims, K, axis=-1): + output_dims = list(input_dims) + output_dims[axis] = K + + node = helper.make_node('TopK', + inputs=['X', 'K'], + outputs=['Values', 'Indicies'], + axis=axis) + + graph = helper.make_graph([node], + "topk_test", + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), + helper.make_tensor_value_info("K", TensorProto.INT64, [1,])], + initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])], + outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), + helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)]) + + model = helper.make_model(graph, producer_name='topk_test') + + indata = np.random.uniform(-10, 10, input_dims).astype(np.float32) + onnx_out = get_onnxruntime_output(model, [indata, k]) + + for target, ctx in [('llvm', tvm.cpu())]: + tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims], + output_dtype=['float32', 'int64']) + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) + + for n in [12, 32]: + for shape in [[n], [n, n], [n, n, n]]: + for k in [1, 5, 10]: + verify_topk(shape, k) + + verify_topk([n, n, n], 5, 0) + verify_topk([n, n, n], 5, 1) + verify_topk([n, n, n], 5, 2) + if __name__ == '__main__': test_flatten() @@ -2392,3 +2429,4 @@ def verify_nonzero(indata, outdata, dtype): test_lstm() test_resize() test_nonzero() + test_topk()