From 5d62a36487b27a94e591242d0dfc1625385d41e1 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 4 Jun 2020 07:36:21 +0530 Subject: [PATCH] [ONNX]ReduceL1, ReduceL2, ReduceSumSquare, ReduceLogSum ops added (#5721) --- python/tvm/relay/frontend/onnx.py | 70 ++++++++ tests/python/frontend/onnx/test_forward.py | 182 +++++++-------------- 2 files changed, 132 insertions(+), 120 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index be886838d740..399597891605 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1163,6 +1163,72 @@ class ReduceLogSumExp(Reduce): """ name = 'logsumexp' + +class ReduceSumSquare(OnnxOpConverter): + """ Operator converter for ReduceSumSquare. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + inputs[0] = inputs[0] * inputs[0] + + return AttrCvt("sum")(inputs, attr) + + +class ReduceL1(OnnxOpConverter): + """ Operator converter for ReduceL1. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + inputs[0] = _op.abs(inputs[0]) + + return AttrCvt("sum")(inputs, attr) + + +class ReduceL2(OnnxOpConverter): + """ Operator converter for ReduceL2. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + inputs[0] = inputs[0] * inputs[0] + out = AttrCvt("sum")(inputs, attr) + + return _op.sqrt(out) + + +class ReduceLogSum(OnnxOpConverter): + """ Operator converter for ReduceLogSum. + """ + @classmethod + def _impl_v1(cls, inputs, attr, params): + if 'axes' in attr: + axis = attr.get('axes', 0) + else: + axis_len = len(infer_shape(inputs[0])) + axis = list(range(axis_len)) + attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)} + out = AttrCvt("sum")(inputs, attr) + + return _op.log(out) + + class ArgMax(OnnxOpConverter): """ Operator converter for ArgMax. """ @@ -1740,6 +1806,10 @@ def _get_convert_map(opset): 'ReduceMean': ReduceMean.get_converter(opset), 'ReduceProd': ReduceProd.get_converter(opset), 'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset), + 'ReduceLogSum': ReduceLogSum.get_converter(opset), + 'ReduceSumSquare': ReduceSumSquare.get_converter(opset), + 'ReduceL1': ReduceL1.get_converter(opset), + 'ReduceL2': ReduceL2.get_converter(opset), #defs/sorting 'ArgMax': ArgMax.get_converter(opset), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index edc33b7079c5..01a54ae945e7 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1364,125 +1364,71 @@ def test_pad(): np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect') -def verify_reduce_x(name, indata, axis, keepdims): - indata = np.array(indata).astype(np.float32) - # numpy expect result - if name == 'ReduceMax': - outdata = np.maximum.reduce(indata, axis=axis, keepdims=keepdims == 1) - elif name == 'ReduceMin': - outdata = np.minimum.reduce(indata, axis=axis, keepdims=keepdims == 1) - elif name == 'ReduceSum': - outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1) - elif name == 'ReduceMean': - outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1) - elif name == 'ReduceLogSumExp': - def _np_log_sum_exp(x, axis, keepdims=False): - max_x = np.max(x, axis=axis, keepdims=True) - x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True)) - x = x + max_x - if not keepdims: - x = np.squeeze(x, axis=axis) - return x - outdata = _np_log_sum_exp(indata, axis=axis, keepdims=keepdims == 1) - else: - raise Exception('unsupport op: {}'.format(name)) - if len(np.asarray(outdata).shape) == 0: - outdata = np.asarray([outdata]) - # onnx graph - if axis is None: - node = helper.make_node(name, inputs=['input'], outputs=['output'], - keepdims=keepdims) +def verify_reduce_func(func, data, axis, keepdims): + inshape = data.shape + outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape + + if axis: + node = onnx.helper.make_node(func, + inputs=['x'], + outputs=['y'], + axes=axis, + keepdims=keepdims) else: - node = helper.make_node(name, inputs=['input'], outputs=['output'], - axes=axis, keepdims=keepdims) + node = onnx.helper.make_node(func, + inputs=['x'], + outputs=['y'], + keepdims=keepdims) + graph = helper.make_graph([node], - '{}_test'.format(name), - inputs=[helper.make_tensor_value_info("input", - TensorProto.FLOAT, list(indata.shape))], - outputs=[helper.make_tensor_value_info("output", - TensorProto.FLOAT, list(outdata.shape))]) - model = helper.make_model(graph, producer_name='{}_test'.format(name)) - # tvm result - for target, ctx in ctx_list(): - tvm_out = get_tvm_output( - model, indata, target, ctx, outdata.shape, 'float32') - tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5) + "reduce_test", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))], + outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))]) + + model = helper.make_model(graph, producer_name='reduce_test') + onnx_out = get_onnxruntime_output(model, data, 'float32') + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, data, target, ctx, outshape, 'float32') + tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5) -def test_reduce_max(): - verify_reduce_x("ReduceMax", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=1) - verify_reduce_x("ReduceMax", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=0) - verify_reduce_x("ReduceMax", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=1) - - -def test_reduce_min(): - verify_reduce_x("ReduceMin", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=1) - verify_reduce_x("ReduceMin", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=0) - verify_reduce_x("ReduceMin", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=1) - - -def test_reduce_sum(): - verify_reduce_x("ReduceSum", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=1) - verify_reduce_x("ReduceSum", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=0) - verify_reduce_x("ReduceSum", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=1) - - -def test_reduce_mean(): - verify_reduce_x("ReduceMean", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=1) - verify_reduce_x("ReduceMean", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=0) - verify_reduce_x("ReduceMean", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=1) - - -def test_reduce_logsumexp(): - - for keepdims in [True, False]: - verify_reduce_x("ReduceLogSumExp", - np.random.randn(3, 2, 2).astype(np.float32), - axis=None, keepdims=keepdims) - - verify_reduce_x("ReduceLogSumExp", - np.random.randn(3, 2, 3).astype(np.float32), - axis=None, keepdims=keepdims) - - verify_reduce_x("ReduceLogSumExp", - np.random.randn(3, 3, 3).astype(np.float32), - axis=(1,), keepdims=keepdims) - - verify_reduce_x("ReduceLogSumExp", - np.random.randn(3, 3, 3, 1).astype(np.float32), - axis=(1, 2), keepdims=keepdims) - - verify_reduce_x("ReduceLogSumExp", - np.random.randn(3, 3, 3, 1).astype(np.float32), - axis=(1), keepdims=keepdims) - - verify_reduce_x("ReduceLogSumExp", - np.random.randn(1, 3, 4, 1).astype(np.float32), - axis=(1), keepdims=keepdims) +def test_all_reduce_funcs(): + funcs = ["ReduceMax", + "ReduceMean", + "ReduceMin", + "ReduceProd", + "ReduceSum", + 'ReduceSumSquare', + "ReduceLogSum", + "ReduceLogSumExp", + "ReduceL1", + "ReduceL2"] + + for func in funcs: + for keepdims in [True, False]: + verify_reduce_func(func, + np.random.randn(3, 2, 2).astype(np.float32), + axis=None, keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(3, 2, 3).astype(np.float32), + axis=None, keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(3, 3, 3).astype(np.float32), + axis=(1,), keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1, 2), keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(3, 3, 3, 1).astype(np.float32), + axis=(1,), keepdims=keepdims) + + verify_reduce_func(func, + np.random.randn(1, 3, 4, 1).astype(np.float32), + axis=(1,), keepdims=keepdims) def verify_split(indata, outdatas, split, axis=0): @@ -2758,11 +2704,7 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ test_forward_arg_min_max() test_softmax() test_constantofshape() - test_reduce_max() - test_reduce_min() - test_reduce_sum() - test_reduce_mean() - test_reduce_logsumexp() + test_all_reduce_funcs() test_pad() test_split() test_binary_ops()