diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 1b5318a83412..b357a2fbff30 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -984,6 +984,91 @@ def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) return _impl +def _space_to_batch_nd(): + def _impl(inputs, attr, params): + input_node = inputs[0] + input_shape = attr['_input_shapes'][input_node] + block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() + paddings = params.pop(inputs[2].name_hint).asnumpy().tolist() + N = len(input_shape) + M = len(block_shape) + batch = input_shape[0] + remaining_shape_length = N - M - 1 + paddings = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length + # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d: + # Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings + # to produce padded of shape padded_shape. + padded = tvm.relay.nn.pad(input_node, pad_width=paddings) + # Reshape padded to reshaped_padded of shape: + # [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ..., + # padded_shape[M] / block_shape[M-1], block_shape[M-1]] + remaining_shape + shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2] + reshaped_padded = tvm.relay.reshape(padded, newshape=shape1) + # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape: + # block_shape + [batch] + [padded_shape[1] / block_shape[0], ..., + # padded_shape[M] / block_shape[M-1]] + remaining_shape + axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ + list(range(1 + 2 * M, 1 + 2 * M + remaining_shape_length)) + permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) + permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0] + # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, + # producing an output tensor of shape: + # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., + # padded_shape[M] / block_shape[M-1]] + remaining_shape + shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1:] + reshaped_permuted_reshaped_padded = tvm.relay.reshape(permuted_reshaped_padded, + newshape=shape2) + return reshaped_permuted_reshaped_padded + + return _impl + + +def _batch_to_space_nd(): + def _impl(inputs, attr, params): + input_node = inputs[0] + input_shape = attr['_input_shapes'][input_node] + block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() + crops = params.pop(inputs[2].name_hint).asnumpy().tolist() + M = len(block_shape) + batch = input_shape[0] + # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d: + # Reshape input to reshaped of shape: + # [block_shape[0], ..., block_shape[M-1], batch / prod(block_shape), + # input_shape[1], ..., input_shape[N-1]] + shape1 = block_shape + [batch // np.prod(block_shape)] + input_shape[1:] + reshaped = tvm.relay.reshape(input_node, newshape=shape1) + # Permute dimensions of reshaped to produce permuted of shape + # [batch / prod(block_shape), input_shape[1], block_shape[0], ..., + # input_shape[M], block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]] + axes = [M] + [axis for i in range(M) for axis in [M + i + 1, i]] + \ + list(range(2 * M + 1, len(shape1))) + permuted = tvm.relay.transpose(reshaped, axes=axes) + # Reshape permuted to produce reshaped_permuted of shape + # [batch / prod(block_shape), input_shape[1] * block_shape[0], ..., + # input_shape[M] * block_shape[M-1], input_shape[M+1], ..., input_shape[N-1]] + shape2 = [0] + [-3] * M + [-2] + reshaped_permuted = tvm.relay.reshape(permuted, newshape=shape2) + # Crop the start and end of dimensions [1, ..., M] of reshaped_permuted according to crops + # to produce the output of shape: + # [batch / prod(block_shape), input_shape[1] * block_shape[0] - crops[0,0] - crops[0,1], + # ..., input_shape[M] * block_shape[M-1] - crops[M-1,0] - crops[M-1,1], + # input_shape[M+1], ..., input_shape[N-1]] + reshaped_permuted_shape = _infer_out_shapes(reshaped_permuted, params)[0] + cropped = reshaped_permuted + for axis in range(1, M+1): + crop = crops[axis - 1] + if crop != [0, 0]: + indices = tvm.relay.arange( + crop[0], + reshaped_permuted_shape[axis] - crop[1], + dtype='int32' + ) + cropped = tvm.relay.take(cropped, indices=indices, axis=axis) + + return cropped + + return _impl + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1060,6 +1145,8 @@ def _impl(inputs, attr, params): 'Split' : _split(False), 'SplitV' : _split(True), 'Unpack' : _unpack(), + 'SpaceToBatchND' : _space_to_batch_nd(), + 'BatchToSpaceND' : _batch_to_space_nd(), } def _LSTMBlockCell(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 35dca8008dfc..7e7c1510c60b 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -161,6 +161,7 @@ def is_gpu_available(): else: return False + ####################################################################### # Pooling # ------- @@ -221,6 +222,19 @@ def test_forward_pooling(): dilation_rate=[1, 1], strides=[2, 1]) + # Tests involving SpaceToBatchND + _test_pooling(input_shape=[1, 1, 2, 1], + window_shape=[1, 1], + padding='VALID', + pooling_type=pool_type, + dilation_rate=[1, 2]) + + _test_pooling(input_shape=[1, 2, 1], + window_shape=[1], + padding='VALID', + pooling_type=pool_type, + dilation_rate=[2]) + ####################################################################### # Convolution # ----------- @@ -229,12 +243,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, dilations, strides, padding, data_format): """ One iteration of convolution with given shapes and attributes """ - total_size_1 = 1 - total_size_2 = 1 - for s in tensor_in_sizes: - total_size_1 *= s - for s in filter_in_sizes: - total_size_2 *= s + total_size_1 = np.prod(tensor_in_sizes) + total_size_2 = np.prod(filter_in_sizes) # Initializes the input tensor with array containing incrementing # numbers from 1. data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] @@ -253,6 +263,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, nn_ops.conv2d(in_data, in_filter, strides=strides, + dilations=dilations, padding=padding, data_format=data_format) @@ -271,6 +282,116 @@ def test_forward_convolution(): _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC') _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC') +####################################################################### +# SpaceToBatchND +# -------------- +def _test_space_to_batch_nd(input_shape, block_shape, paddings, dtype='int32'): + data = np.random.uniform(0, 5, size=input_shape).astype(dtype) + + with tf.Graph().as_default(): + in_data = tf.placeholder(shape=input_shape, dtype=dtype) + out = tf.space_to_batch_nd(in_data, block_shape, paddings) + + compare_tf_with_tvm(data, in_data.name, out.name) + +def test_forward_space_to_batch_nd(): + # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d + _test_space_to_batch_nd( + input_shape=[1, 2, 2, 1], + block_shape=[2, 2], + paddings=[[0, 0], [0, 0]] + ) + + _test_space_to_batch_nd( + input_shape=[1, 2, 2, 3], + block_shape=[2, 2], + paddings=[[0, 0], [0, 0]] + ) + + _test_space_to_batch_nd( + input_shape=[1, 4, 4, 1], + block_shape=[2, 2], + paddings=[[0, 0], [0, 0]] + ) + + _test_space_to_batch_nd( + input_shape=[2, 2, 4, 1], + block_shape=[2, 2], + paddings=[[0, 0], [2, 0]], + dtype='int64' + ) + + # pylint: disable=line-too-long + # https://github.com/tensorflow/tensorflow/blob/24f578/tensorflow/python/kernel_tests/spacetobatch_op_test.py + _test_space_to_batch_nd( + input_shape=[2, 3], + block_shape=[2], + paddings=[[1, 0]], + dtype='float32' + ) + + _test_space_to_batch_nd( + input_shape=[2, 3, 2], + block_shape=[2], + paddings=[[1, 0]], + dtype='float64' + ) + +####################################################################### +# BatchToSpaceND +# -------------- +def _test_batch_to_space_nd(input_shape, block_shape, crops, dtype='int32'): + data = np.random.uniform(0, 5, size=input_shape).astype(dtype) + + with tf.Graph().as_default(): + in_data = tf.placeholder(shape=input_shape, dtype=dtype) + out = tf.batch_to_space_nd(in_data, block_shape, crops) + + compare_tf_with_tvm(data, in_data.name, out.name) + +def test_forward_batch_to_space_nd(): + # test cases: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-to-space-n-d + _test_batch_to_space_nd( + input_shape=[4, 1, 1, 1], + block_shape=[2, 2], + crops=[[0, 0], [0, 0]] + ) + + _test_batch_to_space_nd( + input_shape=[4, 1, 1, 3], + block_shape=[2, 2], + crops=[[0, 0], [0, 0]] + ) + + _test_batch_to_space_nd( + input_shape=[4, 2, 2, 1], + block_shape=[2, 2], + crops=[[0, 0], [0, 0]] + ) + + _test_batch_to_space_nd( + input_shape=[8, 1, 3, 1], + block_shape=[2, 2], + crops=[[0, 0], [2, 0]], + dtype='int64' + ) + + # pylint: disable=line-too-long + # https://github.com/tensorflow/tensorflow/blob/24f578/tensorflow/python/kernel_tests/batchtospace_op_test.py + _test_batch_to_space_nd( + input_shape=[18, 2, 1, 2], + block_shape=[2, 3], + crops=[[1, 1], [0, 0]], + dtype='float32' + ) + + _test_batch_to_space_nd( + input_shape=[20, 5, 8, 7], + block_shape=[2, 2], + crops=[[1, 1], [1, 1]], + dtype='float64' + ) + ####################################################################### # Reshape # ------- @@ -1312,6 +1433,8 @@ def test_forward_rel_ops(): _test_forward_concat_v2() test_forward_lrn() test_forward_l2_normalize() + test_forward_space_to_batch_nd() + test_forward_batch_to_space_nd() # End to End test_forward_inception_v3()