diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 5ac0de4335f7f..c9da9ddb58c37 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -961,6 +961,50 @@ def _impl(inputs, attr, params): return AttrCvt(op_name=name)(inputs, attr) return _impl +def _space_to_batch_nd(): + def _impl(inputs, attr, params): + input_node = inputs[0] + input_shape = attr['_input_shapes'][input_node] + block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist() + paddings = params.pop(inputs[2].name_hint).asnumpy().tolist() + N = len(input_shape) + M = len(block_shape) + batch = input_shape[0] + remaining_shape_length = N - M - 1 + paddings = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length + # From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d: + + # Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings + # to produce padded of shape padded_shape. + padded = tvm.relay.nn.pad(input_node, pad_width=paddings) + + # Reshape padded to reshaped_padded of shape: + # [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ..., + # padded_shape[M] / block_shape[M-1], block_shape[M-1]] + remaining_shape + shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2] + reshaped_padded = tvm.relay.reshape(padded, newshape=shape1) + # can be simplified, but clearer correspondence to the above + reshaped_padded_dim = 1 + 2 * M + remaining_shape_length + + # Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape: + # block_shape + [batch] + [padded_shape[1] / block_shape[0], ..., + # padded_shape[M] / block_shape[M-1]] + remaining_shape + axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \ + list(range(2 * M + 1, reshaped_padded_dim)) + permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes) + permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0] + + # Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension, + # producing an output tensor of shape: + # [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ..., + # padded_shape[M] / block_shape[M-1]] + remaining_shape + shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1:] + reshaped_permuted_reshaped_padded = tvm.relay.reshape(permuted_reshaped_padded, newshape=shape2) + + return reshaped_permuted_reshaped_padded + return _impl + + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1035,6 +1079,7 @@ def _impl(inputs, attr, params): 'Split' : _split(False), 'SplitV' : _split(True), 'Unpack' : _unpack(), + 'SpaceToBatchND' : _space_to_batch_nd(), } def _LSTMBlockCell(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 10368ea3d9aba..d3de33f992e74 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -206,6 +206,19 @@ def test_forward_pooling(): dilation_rate=[1, 1], strides=[2, 1]) + # Tests involving SpaceToBatchND + _test_pooling(input_shape=[1, 1, 2, 1], + window_shape=[1, 1], + padding='VALID', + pooling_type=pool_type, + dilation_rate=[1, 2]) + + _test_pooling(input_shape=[1, 2, 1], + window_shape=[1], + padding='VALID', + pooling_type=pool_type, + dilation_rate=[2]) + ####################################################################### # Convolution # ----------- @@ -214,12 +227,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, dilations, strides, padding, data_format): """ One iteration of convolution with given shapes and attributes """ - total_size_1 = 1 - total_size_2 = 1 - for s in tensor_in_sizes: - total_size_1 *= s - for s in filter_in_sizes: - total_size_2 *= s + total_size_1 = np.prod(tensor_in_sizes) + total_size_2 = np.prod(filter_in_sizes) # Initializes the input tensor with array containing incrementing # numbers from 1. data_array = [f * 1.0 for f in range(1, total_size_1 + 1)] @@ -238,6 +247,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes, nn_ops.conv2d(in_data, in_filter, strides=strides, + dilations=dilations, padding=padding, data_format=data_format)