Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TFLite] Add transpose_conv to TFLite parser #4440

Merged
merged 1 commit into from
Dec 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
# pylint: disable=invalid-name, unused-argument, too-many-lines
"""Tensorflow lite frontend."""
from __future__ import absolute_import as _abs
import math
Expand Down Expand Up @@ -96,6 +96,7 @@ def __init__(self, model, subgraph, exp_tab):
'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd,
'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd,
'PRELU': self.convert_prelu,
'TRANSPOSE_CONV': self.convert_transpose_conv,
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -1370,6 +1371,84 @@ def convert_prelu(self, op):

return out

def convert_transpose_conv(self, op):
"""Convert TFLite TRANSPOSE_CONV"""
# The tflite modules are imported inside the function, so the tflite package
# is only needed when this converter is actually called.
try:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.TensorType import TensorType
from tflite.Operator import Operator
from tflite.TransposeConvOptions import TransposeConvOptions
from tflite.Padding import Padding
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
# The three inputs are read below as: [0] output_shape, [1] weights, [2] data.
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 3, "input tensors length should be 3"

# Input (data) Tensor. NHWC layout
input_tensor = input_tensors[2]
# Only the last (channel) dimension of the input shape is used here.
_, _, _, input_c = input_tensor.tensor.ShapeAsNumpy()
# Weights tensor. TFLite uses OHWI layout
weights_tensor = input_tensors[1]
out_channels, kernel_h, kernel_w, in_channels = weights_tensor.tensor.ShapeAsNumpy()
assert input_c == in_channels, \
"Input channel in the filter should match to channel in the input"
# output_shape Tensor. NHWC layout
output_shape_tensor = input_tensors[0]

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
# The output tensor's type string is later passed as out_dtype to the Relay op.
output_tensor_type = output_tensor.tensor.Type()
output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)

assert op.BuiltinOptionsType() == BuiltinOptions.TransposeConvOptions
op_options = op.BuiltinOptions()
# Initialize a TransposeConvOptions accessor from the operator's
# builtin-options bytes and position.
deconv_options = TransposeConvOptions()
deconv_options.Init(op_options.Bytes, op_options.Pos)

padding = deconv_options.Padding()
stride_h = deconv_options.StrideH()
stride_w = deconv_options.StrideW()
assert padding in (Padding.VALID, Padding.SAME), \
'Padding format {} is not supported for operator TRANSPOSE_CONV'.format(padding)

# Data
in_expr = self.get_expr(input_tensor.tensor_idx)

# Weights
weights_tensor_type = weights_tensor.tensor.Type()
# weights tensor type should be UINT8 (quantization) or FLOAT32
assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
weight_value_ohwi = self.get_tensor_value(weights_tensor)
# The Relay op is called with kernel_layout="OIHW", but the weights themselves
# are supplied in IOHW order: axes (3, 0, 1, 2) move the I axis of OHWI to the front.
weight_value_iohw = np.transpose(weight_value_ohwi, (3, 0, 1, 2))
weight_expr_iohw = self.exp_tab.new_const(weight_value_iohw, dtype=weight_tensor_type_str)

# Output shape value
# NOTE(review): output_shape is read as a concrete value here, so it presumably
# has to be a constant tensor in the model - confirm.
output_shape_value = self.get_tensor_value(output_shape_tensor)
# Relay expects filter output channel to match to output tensor channel.
# Index 3 is the channel entry of the NHWC output_shape.
assert out_channels == output_shape_value[3], \
"Output channel in the filter should match to channel in the output_shape"

# The TF frontend supports 'SAME' padding for 1x1 kernels only; do the same here.
# No padding argument is passed to the Relay op, so SAME padding with any other
# kernel size is rejected rather than converted.
if padding == Padding.SAME:
assert (kernel_h, kernel_w) == (1, 1), \
"SAME padding is supported for kernel (1,1) only"
Comment on lines +1437 to +1440
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the kernel (kh, kw) is 3x3, what is the current error message?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where? in TF or TFLite?

Copy link
Contributor Author

@apivovarov apivovarov Nov 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in TFLite it is


  File "test_forward.py", line 519, in test_forward_transpose_conv
    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], 'SAME')

  File "test_forward.py", line 504, in _test_transpose_conv
    compare_tflite_with_tvm(data_array, 'Placeholder:0', [in_data], [out])

  File "test_forward.py", line 162, in compare_tflite_with_tvm
    num_output=len(out_names), out_names=out_names)

  File "test_forward.py", line 75, in run_tvm_graph
    dtype_dict=dtype_dict)

  File "/home/dlc/workplace/apivovarov/incubator-tvm/python/tvm/relay/frontend/tflite.py", line 1572, in from_tflite
    op_converter.convert_op_to_relay()

  File "/home/dlc/workplace/apivovarov/incubator-tvm/python/tvm/relay/frontend/tflite.py", line 125, in convert_op_to_relay
    ret = self.convert_map[op_code_str](op)

  File "/home/dlc/workplace/apivovarov/incubator-tvm/python/tvm/relay/frontend/tflite.py", line 1440, in convert_transpose_conv
    "SAME padding is supported for kernel (1,1) only"

AssertionError: SAME padding is supported for kernel (1,1) only

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both should be the same, i.e. when our model has a 3x3 conv_transpose, what is the current error message from TVM?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean: if we remove this assert, what happens in TVM? If we add this assert, we cannot support 3x3 conv_transpose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in TF they only test kernel 1,1 for SAME padding
test_forward.py L381

test_forward.py L385

Other kernels will fail for SAME padding. More info on it. https://discuss.tvm.ai/t/why-we-only-support-kernel-1-1-for-tf-conv2d-transpose-same/4957

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we support non-1x1 conv_transpose too? I think this may be a good time to do it completely, for both TF and TFLite.

Copy link
Contributor Author

@apivovarov apivovarov Nov 28, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two types of padding in TF and TFLite - VALID and SAME.
The conv_transpose op does the opposite of what conv2d does.
The conv2d output size is the same as or smaller than its input: the bigger the kernel, the smaller the output.
conv_transpose is the opposite: the bigger the kernel, the bigger the output.
We can use any kernel with padding 'VALID' - 2x2, 3x3, etc.
But if we use padding 'SAME', then the output size should be the same as the input.
If the kernel is 1x1, then the output is the same as the input. So, kernel 1x1 is implicitly SAME.
If we increase the conv_transpose kernel size, then the output will have extra padding.
To make the output size the SAME as the input, we need to remove the padding from the output.
The model which needs the conv_transpose op is palm_detection.tflite. It uses VALID padding.
It looks like in most cases people use conv2d and conv_transpose with VALID padding.
I think it is fine if we merge the PR with a known limitation for SAME padding (it is not that frequently used anyway). We should probably wait until the TF frontend supports SAME padding for non-1x1 kernels and then do the same in the TFLite frontend.
Maybe they will decide to add a VALID/SAME padding field to the Relay op directly. In that case we will just pass the padding type to the Relay op as-is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok.


# Build the Relay transposed convolution. No explicit padding argument is
# passed, so the op's default padding applies.
# Shape values are cast to plain Python ints before being handed to Relay
# (presumably they are numpy integers coming from ShapeAsNumpy - confirm).
out = _op.nn.conv2d_transpose(in_expr, weight_expr_iohw,
strides=(stride_h, stride_w),
channels=int(out_channels),
kernel_size=(int(kernel_h), int(kernel_w)),
data_layout="NHWC",
kernel_layout="OIHW",
out_dtype=output_tensor_type_str)

return out

def get_expr(self, input_tensor_idx):
    """Return the expression stored in the expression table for a tensor.

    The tensor index is first resolved to a tensor name within this
    converter's subgraph; that name is then looked up in ``self.exp_tab``.
    """
    tensor_name = get_tensor_name(self.subgraph, input_tensor_idx)
    return self.exp_tab.get_expr(tensor_name)

Expand Down
55 changes: 55 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,60 @@ def test_forward_convolution():
_test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True)


#######################################################################
# Transpose Convolution
# ---------------------

def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides, padding):
    """Run one transpose convolution case and compare TFLite against TVM.

    Both the input tensor and the filter are filled with incrementing
    values starting from 1. ``strides`` holds the two spatial strides and
    is wrapped with leading and trailing 1s before being passed to TF.
    """
    # Number of elements in the input and in the filter.
    num_input_elems = int(np.prod(tensor_in_sizes))
    num_filter_elems = int(np.prod(filter_in_sizes))

    # Incrementing values 1.0, 2.0, ... for both tensors.
    input_values = [float(v) for v in range(1, num_input_elems + 1)]
    filter_values = [float(v) for v in range(1, num_filter_elems + 1)]
    input_array = np.reshape(input_values, tensor_in_sizes).astype('float32')

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
        in_filter = constant_op.constant(filter_values, shape=filter_in_sizes, dtype='float32')
        full_strides = [1] + strides + [1]
        # in_filter layout is HWOI
        out = nn_ops.conv2d_transpose(in_data,
                                      in_filter,
                                      output_shape=output_shape,
                                      strides=full_strides,
                                      padding=padding)
        compare_tflite_with_tvm(input_array, 'Placeholder:0', [in_data], [out])


def test_forward_transpose_conv():
    """Check TRANSPOSE_CONV conversion over several kernels, strides and paddings."""
    # Each entry: (input shape, filter shape, output shape, strides, padding)
    cases = [
        # kernel 3x3, padding VALID
        ([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], 'VALID'),
        ([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], 'VALID'),
        ([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], 'VALID'),
        # kernel 2x2, padding VALID
        ([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], 'VALID'),
        ([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], 'VALID'),
        ([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], 'VALID'),
        # kernel 1x1, padding VALID
        ([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'VALID'),
        ([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'VALID'),
        ([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'VALID'),
        # kernel 1x1, padding SAME
        ([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], 'SAME'),
        ([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], 'SAME'),
        ([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], 'SAME'),
    ]
    for in_shape, filter_shape, out_shape, strides, padding in cases:
        _test_transpose_conv(in_shape, filter_shape, out_shape, strides, padding)

apivovarov marked this conversation as resolved.
Show resolved Hide resolved

#######################################################################
# Reshape
# -------
Expand Down Expand Up @@ -1212,6 +1266,7 @@ def test_forward_mediapipe_hand_landmark():

# NN
test_forward_convolution()
test_forward_transpose_conv()
test_forward_logistic()
test_forward_pooling()
test_forward_softmax()
Expand Down