Skip to content

Commit

Permalink
Implement SpaceToBatchND in Tensorflow frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
alexeyr committed Apr 1, 2019
1 parent 3259e6b commit 285034e
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
45 changes: 45 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,50 @@ def _impl(inputs, attr, params):
return AttrCvt(op_name=name)(inputs, attr)
return _impl

def _space_to_batch_nd():
def _impl(inputs, attr, params):
input_node = inputs[0]
input_shape = attr['_input_shapes'][input_node]
block_shape = params.pop(inputs[1].name_hint).asnumpy().tolist()
paddings = params.pop(inputs[2].name_hint).asnumpy().tolist()
N = len(input_shape)
M = len(block_shape)
batch = input_shape[0]
remaining_shape_length = N - M - 1
paddings = [(0, 0)] + paddings + [(0, 0)] * remaining_shape_length
# From https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/space-to-batch-n-d:

# Zero-pad the start and end of dimensions [1, ..., M] of the input according to paddings
# to produce padded of shape padded_shape.
padded = tvm.relay.nn.pad(input_node, pad_width=paddings)

# Reshape padded to reshaped_padded of shape:
# [batch] + [padded_shape[1] / block_shape[0], block_shape[0], ...,
# padded_shape[M] / block_shape[M-1], block_shape[M-1]] + remaining_shape
shape1 = [batch] + [item for i in range(M) for item in [-4, -1, block_shape[i]]] + [-2]
reshaped_padded = tvm.relay.reshape(padded, newshape=shape1)
# can be simplified, but clearer correspondence to the above
reshaped_padded_dim = 1 + 2 * M + remaining_shape_length

# Permute dimensions of reshaped_padded to produce permuted_reshaped_padded of shape:
# block_shape + [batch] + [padded_shape[1] / block_shape[0], ...,
# padded_shape[M] / block_shape[M-1]] + remaining_shape
axes = [2 * i + 2 for i in range(M)] + [0] + [2 * i + 1 for i in range(M)] + \
list(range(2 * M + 1, reshaped_padded_dim))
permuted_reshaped_padded = tvm.relay.transpose(reshaped_padded, axes=axes)
permuted_reshaped_padded_shape = _infer_out_shapes(permuted_reshaped_padded, params)[0]

# Reshape permuted_reshaped_padded to flatten block_shape into the batch dimension,
# producing an output tensor of shape:
# [batch * prod(block_shape)] + [padded_shape[1] / block_shape[0], ...,
# padded_shape[M] / block_shape[M-1]] + remaining_shape
shape2 = [batch * np.prod(block_shape)] + list(permuted_reshaped_padded_shape)[M + 1:]
reshaped_permuted_reshaped_padded = tvm.relay.reshape(permuted_reshaped_padded, newshape=shape2)

return reshaped_permuted_reshaped_padded
return _impl


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -1035,6 +1079,7 @@ def _impl(inputs, attr, params):
'Split' : _split(False),
'SplitV' : _split(True),
'Unpack' : _unpack(),
'SpaceToBatchND' : _space_to_batch_nd(),
}

def _LSTMBlockCell():
Expand Down
22 changes: 16 additions & 6 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,19 @@ def test_forward_pooling():
dilation_rate=[1, 1],
strides=[2, 1])

# Tests involving SpaceToBatchND
_test_pooling(input_shape=[1, 1, 2, 1],
window_shape=[1, 1],
padding='VALID',
pooling_type=pool_type,
dilation_rate=[1, 2])

_test_pooling(input_shape=[1, 2, 1],
window_shape=[1],
padding='VALID',
pooling_type=pool_type,
dilation_rate=[2])

#######################################################################
# Convolution
# -----------
Expand All @@ -214,12 +227,8 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format):
""" One iteration of convolution with given shapes and attributes """

total_size_1 = 1
total_size_2 = 1
for s in tensor_in_sizes:
total_size_1 *= s
for s in filter_in_sizes:
total_size_2 *= s
total_size_1 = np.prod(tensor_in_sizes)
total_size_2 = np.prod(filter_in_sizes)
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
Expand All @@ -238,6 +247,7 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
nn_ops.conv2d(in_data,
in_filter,
strides=strides,
dilations=dilations,
padding=padding,
data_format=data_format)

Expand Down

0 comments on commit 285034e

Please sign in to comment.