From 792aef4f352f9ab9ed61cb60e5ae16a15c880503 Mon Sep 17 00:00:00 2001 From: kice Date: Tue, 11 Feb 2020 16:44:48 -0500 Subject: [PATCH] Fix onnx import bugs (#4750) * Fix onnx import bugs Fix onnx attributes of string type incorrect handling Merge symmetric padding of Conv to symmetric form * Only merge symmetric padding for conv2d --- python/tvm/relay/frontend/onnx.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index dba5c8cb3574..9ecd950e3a3c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -342,6 +342,14 @@ def _impl_v1(cls, inputs, attr, params): msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.' raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad'])) attr.pop('auto_pad') + elif len(attr['kernel_shape']) == 2: + sym_pad = True + padding = attr['pads'] + for i in range(0, len(padding), 2): + sym_pad = sym_pad and padding[i] == padding[i + 1] + + if sym_pad: + attr['pads'] = padding[0::2] out = AttrCvt( op_name=dimension_picker('conv'), @@ -505,7 +513,7 @@ def _impl_v1(cls, inputs, attr, params): for i in range(dims): pad_width.append((pads[i], pads[i+dims])) attr['pad_width'] = pad_width - pad_mode = attr.get('mode', 'constant').decode('utf-8') + pad_mode = attr.get('mode', b'constant').decode('utf-8') if pad_mode in ['constant', 'edge', 'reflect']: attr['pad_mode'] = pad_mode attr.pop('mode', None) @@ -528,7 +536,7 @@ def _impl_v2(cls, inputs, attr, params): for i in range(dims): pad_width.append((pads[i], pads[i+dims])) attr['pad_width'] = pad_width - pad_mode = attr.get('mode', 'constant').decode('utf-8') + pad_mode = attr.get('mode', b'constant').decode('utf-8') if pad_mode in ['constant', 'edge', 'reflect']: attr['pad_mode'] = pad_mode attr.pop('mode', None) @@ -620,7 +628,7 @@ class DepthToSpace(OnnxOpConverter): def _impl_v11(cls, inputs, attr, params): block_size = int(attr['blocksize']) - mode = attr.get("mode", "DCR") + mode = attr.get('mode', b'DCR').decode('utf-8') return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)