diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index f65446691023..47bd0cdaeec8 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -603,7 +603,7 @@ def _impl(inputs, attr, params, mod): out = AttrCvt( op_name=_dimension_picker('conv', surfix="_transpose" if opname == 'conv_transpose' else ""), - ignores=['explicit_paddings'], + ignores=['explicit_paddings', 'Tshape'], transforms={ 'kernel_shape': 'kernel_size', 'data_format': 'data_layout', @@ -2046,6 +2046,7 @@ def _impl(inputs, attr, params, mod): 'Conv2D' : _conv('conv'), 'Conv2DBackpropInput' : _conv('conv_transpose'), 'Conv3D' : _conv3d('conv'), + 'Conv3DBackpropInputV2' : _conv3d('conv_transpose'), 'Cos' : AttrCvt('cos'), 'Cosh' : AttrCvt('cosh'), 'CropAndResize' : _crop_and_resize(), diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 252c4cd8904a..1a4294b52bfc 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -533,6 +533,92 @@ def test_forward_convolution3d(): _test_convolution3d('conv', [4, 17, 17, 17, 12], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC') +####################################################################### +# Convolution3D Transpose +# ----------------------- + +def _test_convolution3d_transpose(data_shape, filter_shape, strides, + padding, output_shape, data_format='NCDHW'): + """ One iteration of 3D convolution transpose with given shapes and attributes """ + + dtype = 'float32' + data_array = np.random.uniform(size=data_shape).astype(dtype) + filter_array = np.random.uniform(size=filter_shape).astype(dtype) + if data_format == 'NDHWC': + strides = [1] + strides + [1] + else: + strides = [1, 1] + strides + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data_shape, dtype=dtype) + in_filter = constant_op.constant( + filter_array, shape=filter_shape, dtype=dtype) + + nn_ops.conv3d_transpose(in_data, + in_filter, + output_shape=output_shape, + strides=strides, + padding=padding, + data_format=data_format) + + compare_tf_with_tvm(data_array, 'Placeholder:0', 'conv3d_transpose:0', cuda_layout="NDHWC") + + +def test_forward_convolution3d_transpose(): + if is_gpu_available(): + _test_convolution3d_transpose(data_shape=[1, 10, 8, 8, 8], + filter_shape=[1, 1, 1, 6, 10], + strides=[1, 1, 1], + padding='VALID', + output_shape=[1, 6, 8, 8, 8]) + + _test_convolution3d_transpose(data_shape=[4, 9, 8, 8, 8], + filter_shape=[1, 1, 1, 6, 9], + strides=[1, 1, 1], + padding='VALID', + output_shape=[4, 6, 8, 8, 8]) + + _test_convolution3d_transpose(data_shape=[1, 3, 8, 8, 8], + filter_shape=[1, 1, 1, 6, 3], + strides=[2, 2, 2], + padding='SAME', + output_shape=[1, 6, 15, 15, 15]) + + _test_convolution3d_transpose(data_shape=[1, 16, 8, 8, 8], + filter_shape=[3, 3, 3, 6, 16], + strides=[3, 3, 3], + padding='VALID', + output_shape=[1, 6, 24, 24, 24]) + + _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 10], + filter_shape=[1, 1, 1, 6, 10], + strides=[1, 1, 1], + padding='VALID', + output_shape=[1, 8, 8, 8, 6], + data_format='NDHWC') + + _test_convolution3d_transpose(data_shape=[4, 8, 8, 8, 9], + filter_shape=[1, 1, 1, 6, 9], + strides=[1, 1, 1], + padding='VALID', + output_shape=[4, 8, 8, 8, 6], + data_format='NDHWC') + + _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 3], + filter_shape=[1, 1, 1, 6, 3], + strides=[2, 2, 2], + padding='SAME', + output_shape=[1, 15, 15, 15, 6], + data_format='NDHWC') + + _test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 16], + filter_shape=[3, 3, 3, 6, 16], + strides=[3, 3, 3], + padding='VALID', + output_shape=[1, 24, 24, 24, 6], + data_format='NDHWC') + + ####################################################################### # BiasAdd # ----------- @@ -3728,6 +3814,7 @@ def test_forward_spop(): # NN test_forward_convolution() test_forward_convolution3d() + test_forward_convolution3d_transpose() test_forward_pooling() test_forward_concat_v2() test_forward_lrn()