From 5acdaeef16458866e9a723e1c3950d05376ea4e3 Mon Sep 17 00:00:00 2001 From: optima2005 Date: Wed, 13 Nov 2019 05:22:36 +0000 Subject: [PATCH] add transformation from NHWC to NCHW to compatible with TVM conv2d_transpose implementation --- python/tvm/relay/frontend/tensorflow.py | 12 ++++++++++-- tests/python/frontend/tensorflow/test_forward.py | 9 +++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 93b5cf7d0f36..d86e2b7a2fe6 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -196,8 +196,16 @@ def _impl(inputs, attr, params): flip_layout = False if opname == 'conv_transpose' and attr['data_format'] == 'NHWC': - raise NotImplementedError( \ - "conv2d_transpose with NHWC layout is not implemented.") + # transform to NCHW for TVM backend compatible and set 'flip_layout' + # to have output flip back to NHWC + tmp_shape = attr['_input_shapes'][inputs[2]] + tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)] + inputs[2] = _op.transpose(inputs[2], axes=(0, 3, 1, 2)) + attr['_input_shapes'][inputs[2]] = tmp_shape + attr['strides'][1], attr['strides'][2], attr['strides'][3] = \ + attr['strides'][3], attr['strides'][1], attr['strides'][2] + attr['data_format'] = 'NCHW' + flip_layout = True inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 17b168424823..dd520d3c601e 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -379,6 +379,15 @@ def test_forward_convolution(): _test_convolution('depthwise', [4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC') _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC') + _test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', + 'NHWC', [4, 8, 8, 176]) + _test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', + 'NHWC', [4, 17, 17, 19]) + _test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', + 'NHWC', [4, 17, 17, 124]) + _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', + 'NHWC', [4, 17, 17, 12]) + ####################################################################### # BiasAdd