diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b4d36306c85d..eba02e70c865 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -335,6 +335,23 @@ class Reciprocal(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): return _expr.const(1.0) / inputs[0] + +class Flatten(OnnxOpConverter): + """ Operator converter for Flatten. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get('axis', 1) + if axis == 1: + out = _op.nn.batch_flatten(inputs[0]) + else: + newshape = [0] * (axis + 1) + newshape[axis] = -1 + out = _op.reshape(inputs[0], list(newshape)) + return out + + class Reshape(OnnxOpConverter): """ Operator converter for Reshape. """ @@ -850,7 +867,7 @@ def _get_convert_map(opset): # 'InstanceNormalization' # 'LpNormalization' 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), - 'Flatten': Renamer('batch_flatten'), + 'Flatten': Flatten.get_converter(opset), 'LRN': LRN.get_converter(opset), # defs/reduction diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7be6bb611e9a..f867e73e8c08 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -211,6 +211,29 @@ def test_squeeze(): tvm.testing.assert_allclose(out_shape, tvm_out.shape) +def test_flatten(): + + in_shape = (1, 3, 4, 4) + axis = 1 + ref_shape = (1, 48) + + flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis = axis) + + graph = helper.make_graph([flatten_node], + "flatten_test", + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(ref_shape))]) + + model = helper.make_model(graph, producer_name='flatten_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('int32') + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + def test_unsqueeze(): in_shape = (3, 3) axis = (0, 3, 4) @@ -1046,6 +1069,7 @@ def test_LogSoftmax(): {'axis': 1}) if __name__ == '__main__': + test_flatten() test_reshape() test_shape() test_power()