Skip to content

Commit

Permalink
[ONNX]ReduceL1, ReduceL2, ReduceSumSquare, ReduceLogSum ops added (ap…
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and trevor-m committed Jun 18, 2020
1 parent e98b06e commit 5d62a36
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 120 deletions.
70 changes: 70 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,72 @@ class ReduceLogSumExp(Reduce):
"""
name = 'logsumexp'


class ReduceSumSquare(OnnxOpConverter):
""" Operator converter for ReduceSumSquare.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'axes' in attr:
axis = attr.get('axes', 0)
else:
axis_len = len(infer_shape(inputs[0]))
axis = list(range(axis_len))
attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
inputs[0] = inputs[0] * inputs[0]

return AttrCvt("sum")(inputs, attr)


class ReduceL1(OnnxOpConverter):
""" Operator converter for ReduceL1.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'axes' in attr:
axis = attr.get('axes', 0)
else:
axis_len = len(infer_shape(inputs[0]))
axis = list(range(axis_len))
attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
inputs[0] = _op.abs(inputs[0])

return AttrCvt("sum")(inputs, attr)


class ReduceL2(OnnxOpConverter):
""" Operator converter for ReduceL2.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'axes' in attr:
axis = attr.get('axes', 0)
else:
axis_len = len(infer_shape(inputs[0]))
axis = list(range(axis_len))
attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
inputs[0] = inputs[0] * inputs[0]
out = AttrCvt("sum")(inputs, attr)

return _op.sqrt(out)


class ReduceLogSum(OnnxOpConverter):
""" Operator converter for ReduceLogSum.
"""
@classmethod
def _impl_v1(cls, inputs, attr, params):
if 'axes' in attr:
axis = attr.get('axes', 0)
else:
axis_len = len(infer_shape(inputs[0]))
axis = list(range(axis_len))
attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
out = AttrCvt("sum")(inputs, attr)

return _op.log(out)


class ArgMax(OnnxOpConverter):
""" Operator converter for ArgMax.
"""
Expand Down Expand Up @@ -1740,6 +1806,10 @@ def _get_convert_map(opset):
'ReduceMean': ReduceMean.get_converter(opset),
'ReduceProd': ReduceProd.get_converter(opset),
'ReduceLogSumExp': ReduceLogSumExp.get_converter(opset),
'ReduceLogSum': ReduceLogSum.get_converter(opset),
'ReduceSumSquare': ReduceSumSquare.get_converter(opset),
'ReduceL1': ReduceL1.get_converter(opset),
'ReduceL2': ReduceL2.get_converter(opset),

#defs/sorting
'ArgMax': ArgMax.get_converter(opset),
Expand Down
182 changes: 62 additions & 120 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1364,125 +1364,71 @@ def test_pad():
np.float32), [0, 0, 1, 1, 0, 0, 1, 1], 'reflect')


def verify_reduce_x(name, indata, axis, keepdims):
indata = np.array(indata).astype(np.float32)
# numpy expect result
if name == 'ReduceMax':
outdata = np.maximum.reduce(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceMin':
outdata = np.minimum.reduce(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceSum':
outdata = np.sum(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceMean':
outdata = np.mean(indata, axis=axis, keepdims=keepdims == 1)
elif name == 'ReduceLogSumExp':
def _np_log_sum_exp(x, axis, keepdims=False):
max_x = np.max(x, axis=axis, keepdims=True)
x = np.log(np.sum(np.exp(x - max_x), axis=axis, keepdims=True))
x = x + max_x
if not keepdims:
x = np.squeeze(x, axis=axis)
return x
outdata = _np_log_sum_exp(indata, axis=axis, keepdims=keepdims == 1)
else:
raise Exception('unsupport op: {}'.format(name))
if len(np.asarray(outdata).shape) == 0:
outdata = np.asarray([outdata])
# onnx graph
if axis is None:
node = helper.make_node(name, inputs=['input'], outputs=['output'],
keepdims=keepdims)
def verify_reduce_func(func, data, axis, keepdims):
inshape = data.shape
outshape = np.sum(data, axis=axis, keepdims=keepdims == 1).shape

if axis:
node = onnx.helper.make_node(func,
inputs=['x'],
outputs=['y'],
axes=axis,
keepdims=keepdims)
else:
node = helper.make_node(name, inputs=['input'], outputs=['output'],
axes=axis, keepdims=keepdims)
node = onnx.helper.make_node(func,
inputs=['x'],
outputs=['y'],
keepdims=keepdims)

graph = helper.make_graph([node],
'{}_test'.format(name),
inputs=[helper.make_tensor_value_info("input",
TensorProto.FLOAT, list(indata.shape))],
outputs=[helper.make_tensor_value_info("output",
TensorProto.FLOAT, list(outdata.shape))])
model = helper.make_model(graph, producer_name='{}_test'.format(name))
# tvm result
for target, ctx in ctx_list():
tvm_out = get_tvm_output(
model, indata, target, ctx, outdata.shape, 'float32')
tvm.testing.assert_allclose(outdata, tvm_out, rtol=1e-5, atol=1e-5)
"reduce_test",
inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, list(inshape))],
outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, list(outshape))])

model = helper.make_model(graph, producer_name='reduce_test')

onnx_out = get_onnxruntime_output(model, data, 'float32')
for target, ctx in ctx_list():
tvm_out = get_tvm_output(model, data, target, ctx, outshape, 'float32')
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)

def test_reduce_max():
verify_reduce_x("ReduceMax",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMax",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMax",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)


def test_reduce_min():
verify_reduce_x("ReduceMin",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMin",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMin",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)


def test_reduce_sum():
verify_reduce_x("ReduceSum",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceSum",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceSum",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)


def test_reduce_mean():
verify_reduce_x("ReduceMean",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=1)
verify_reduce_x("ReduceMean",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=0)
verify_reduce_x("ReduceMean",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=1)


def test_reduce_logsumexp():

for keepdims in [True, False]:
verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 3, 3, 1).astype(np.float32),
axis=(1, 2), keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(3, 3, 3, 1).astype(np.float32),
axis=(1), keepdims=keepdims)

verify_reduce_x("ReduceLogSumExp",
np.random.randn(1, 3, 4, 1).astype(np.float32),
axis=(1), keepdims=keepdims)
def test_all_reduce_funcs():
funcs = ["ReduceMax",
"ReduceMean",
"ReduceMin",
"ReduceProd",
"ReduceSum",
'ReduceSumSquare',
"ReduceLogSum",
"ReduceLogSumExp",
"ReduceL1",
"ReduceL2"]

for func in funcs:
for keepdims in [True, False]:
verify_reduce_func(func,
np.random.randn(3, 2, 2).astype(np.float32),
axis=None, keepdims=keepdims)

verify_reduce_func(func,
np.random.randn(3, 2, 3).astype(np.float32),
axis=None, keepdims=keepdims)

verify_reduce_func(func,
np.random.randn(3, 3, 3).astype(np.float32),
axis=(1,), keepdims=keepdims)

verify_reduce_func(func,
np.random.randn(3, 3, 3, 1).astype(np.float32),
axis=(1, 2), keepdims=keepdims)

verify_reduce_func(func,
np.random.randn(3, 3, 3, 1).astype(np.float32),
axis=(1,), keepdims=keepdims)

verify_reduce_func(func,
np.random.randn(1, 3, 4, 1).astype(np.float32),
axis=(1,), keepdims=keepdims)


def verify_split(indata, outdatas, split, axis=0):
Expand Down Expand Up @@ -2758,11 +2704,7 @@ def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_
test_forward_arg_min_max()
test_softmax()
test_constantofshape()
test_reduce_max()
test_reduce_min()
test_reduce_sum()
test_reduce_mean()
test_reduce_logsumexp()
test_all_reduce_funcs()
test_pad()
test_split()
test_binary_ops()
Expand Down

0 comments on commit 5d62a36

Please sign in to comment.