diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 35c857a3d77f..93b5cf7d0f36 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -195,10 +195,16 @@ def _impl(inputs, attr, params): attr['data_format'] = attr['data_format'].decode("utf-8") flip_layout = False + if opname == 'conv_transpose' and attr['data_format'] == 'NHWC': + raise NotImplementedError( \ + "conv2d_transpose with NHWC layout is not implemented.") + + inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2] + # NCHW Layout require weights transpose if attr['data_format'] == 'NCHW': tmp_shape = attr['_input_shapes'][inputs[1]] - if opname == 'conv': + if opname in ['conv', 'conv_transpose']: tmp_shape = [tmp_shape[ii] for ii in (3, 2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) else: @@ -206,13 +212,13 @@ def _impl(inputs, attr, params): inputs[1] = _op.transpose(inputs[1], axes=(2, 3, 0, 1)) attr['_input_shapes'][inputs[1]] = tmp_shape - input_shape = attr['_input_shapes'][inputs[0]] + input_shape = attr['_input_shapes'][inputs_data] weights_shape = attr['_input_shapes'][inputs[1]] if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC": input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)] - inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2)) - if opname == 'conv': + inputs_data = _op.transpose(inputs_data, axes=(0, 3, 1, 2)) + if opname in ['conv', 'conv_transpose']: weights_shape = [weights_shape[ii] for ii in (3, 2, 0, 1)] inputs[1] = _op.transpose(inputs[1], axes=(3, 2, 0, 1)) else: @@ -228,6 +234,8 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (weights_shape[0], weights_shape[1]) if opname == 'conv': attr['channels'] = weights_shape[3] + elif opname == 'conv_transpose': + attr['channels'] = weights_shape[2] else: attr['channels'] = input_shape[3] * depth_mult @@ -239,6 +247,8 @@ def _impl(inputs, attr, params): attr['kernel_shape'] = (weights_shape[2], weights_shape[3]) if opname == 'conv': attr['channels'] = weights_shape[0] + elif opname == 'conv_transpose': + attr['channels'] = weights_shape[1] else: attr['channels'] = input_shape[1] * depth_mult if attr['channels'] < 0: @@ -279,17 +289,17 @@ def _impl(inputs, attr, params): if attr['data_format'] == 'NHWC': - inputs[0] = _op.nn.pad(data=inputs[0], - pad_width=((0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]), - (0, 0))) + inputs_data = _op.nn.pad(data=inputs_data, + pad_width=((0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]), + (0, 0))) else: - inputs[0] = _op.nn.pad(data=inputs[0], - pad_width=((0, 0), - (0, 0), - (pad_v[0], pad_v[1]), - (pad_h[0], pad_h[1]))) + inputs_data = _op.nn.pad(data=inputs_data, + pad_width=((0, 0), + (0, 0), + (pad_v[0], pad_v[1]), + (pad_h[0], pad_h[1]))) attr['padding'] = [0, 0] @@ -299,27 +309,30 @@ def _impl(inputs, attr, params): raise tvm.error.OpAttributeInvalid(msg.format(attr['padding'])) if 'kernel_layout' not in attr: - if opname == 'conv': + if opname in ['conv', 'conv_transpose']: attr['kernel_layout'] = 'HWIO' if attr['data_format'] == 'NHWC' else 'OIHW' else: attr['kernel_layout'] = 'HWOI' if attr['data_format'] == 'NHWC' else 'OIHW' - use_bias = len(inputs) == 3 + use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4) channel_axis = 1 if attr['data_format'] == "NCHW" else 3 # Ignore the new attributes from TF2.0, for now. out = AttrCvt( - op_name=_dimension_picker('conv'), + op_name=_dimension_picker('conv', \ + surfix="_transpose" if opname == 'conv_transpose' else ""), ignores=['explicit_paddings'], transforms={ 'kernel_shape': 'kernel_size', 'data_format': 'data_layout', 'dilations': ('dilation', (0, 0)), 'group': ('groups', 1)}, - custom_check=_dimension_constraint())([inputs[0], inputs[1]], attr) + custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr) if use_bias: - out = _op.nn.bias_add(out, inputs[2], axis=channel_axis) + out = _op.nn.bias_add(out, + inputs[2] if opname != 'conv_transpose' else inputs[3], + axis=channel_axis) if flip_layout: out = _op.transpose(out, axes=(0, 2, 3, 1)) @@ -1403,6 +1416,7 @@ def _impl(inputs, attr, params): 'Concat' : _concat(), 'ConcatV2' : _concatV2(), 'Conv2D' : _conv('conv'), + 'Conv2DBackpropInput' : _conv('conv_transpose'), 'CropAndResize' : _crop_and_resize(), 'DecodeJpeg' : _decode_image(), 'DepthwiseConv2dNative' : _conv('depthwise'), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 17db2f5cc9a8..17b168424823 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -295,7 +295,8 @@ def test_forward_pooling(): def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, - dilations, strides, padding, data_format): + dilations, strides, padding, data_format, + deconv_output_shape=[]): """ One iteration of convolution with given shapes and attributes """ total_size_1 = np.prod(tensor_in_sizes) @@ -326,6 +327,17 @@ def _test_convolution(opname, tensor_in_sizes, filter_in_sizes, compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), 'Placeholder:0', 'Conv2D:0') + elif opname == 'conv_transpose': + nn_ops.conv2d_transpose(in_data, + in_filter, + output_shape=deconv_output_shape, + strides=strides, + dilations=dilations, + padding=padding, + data_format=data_format) + + compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'), + 'Placeholder:0', 'conv2d_transpose:0') else: nn_ops.depthwise_conv2d_native(in_data, in_filter, @@ -349,6 +361,14 @@ def test_forward_convolution(): _test_convolution('depthwise', [4, 124, 17, 17], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NCHW') _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NCHW') _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW') + _test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 176, 8, 8]) + _test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', + 'NCHW', [4, 19, 17, 17]) + _test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', + 'NCHW', [4, 124, 17, 17]) + _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', + 'NCHW', [4, 12, 17, 17]) _test_convolution('conv', [4, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution('conv', [4, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')