Skip to content

Commit

Permalink
Fix onnx import bugs (apache#4750)
Browse files Browse the repository at this point in the history
* Fix onnx import bugs

Fix onnx attributes of string type incorrect handling
Merge symmetric padding of Conv to symmetric form

* Only merge symmetric padding for conv2d
  • Loading branch information
kice authored and alexwong committed Feb 28, 2020
1 parent 0771b02 commit e4b8b35
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,14 @@ def _impl_v1(cls, inputs, attr, params):
msg = 'Value {} in attribute "auto_pad" of operator Conv is invalid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['auto_pad']))
attr.pop('auto_pad')
elif len(attr['kernel_shape']) == 2:
sym_pad = True
padding = attr['pads']
for i in range(0, len(padding), 2):
sym_pad = sym_pad and padding[i] == padding[i + 1]

if sym_pad:
attr['pads'] = padding[0::2]

out = AttrCvt(
op_name=dimension_picker('conv'),
Expand Down Expand Up @@ -505,7 +513,7 @@ def _impl_v1(cls, inputs, attr, params):
for i in range(dims):
pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width
pad_mode = attr.get('mode', 'constant').decode('utf-8')
pad_mode = attr.get('mode', b'constant').decode('utf-8')
if pad_mode in ['constant', 'edge', 'reflect']:
attr['pad_mode'] = pad_mode
attr.pop('mode', None)
Expand All @@ -528,7 +536,7 @@ def _impl_v2(cls, inputs, attr, params):
for i in range(dims):
pad_width.append((pads[i], pads[i+dims]))
attr['pad_width'] = pad_width
pad_mode = attr.get('mode', 'constant').decode('utf-8')
pad_mode = attr.get('mode', b'constant').decode('utf-8')
if pad_mode in ['constant', 'edge', 'reflect']:
attr['pad_mode'] = pad_mode
attr.pop('mode', None)
Expand Down Expand Up @@ -620,7 +628,7 @@ class DepthToSpace(OnnxOpConverter):
def _impl_v11(cls, inputs, attr, params):

block_size = int(attr['blocksize'])
mode = attr.get("mode", "DCR")
mode = attr.get('mode', b'DCR').decode('utf-8')
return _op.nn.depth_to_space(inputs[0], block_size, mode=mode)


Expand Down

0 comments on commit e4b8b35

Please sign in to comment.