Skip to content

Commit

Permalink
[Frontend][TFLite] Add MIRROR_PAD operator (apache#4822)
Browse files Browse the repository at this point in the history
  • Loading branch information
wyc-ruiker authored and alexwong committed Feb 28, 2020
1 parent 97248b4 commit cbb12fc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
43 changes: 41 additions & 2 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, model, subgraph, exp_tab):
'SUM': self._convert_reduce_sum,
'FULLY_CONNECTED': self.convert_fully_connected,
'PAD': self.convert_pad,
'MIRROR_PAD': self.convert_mirror_pad,
'PACK': self.convert_pack,
'UNPACK': self.convert_unpack,
'LOGISTIC': self.convert_logistic,
Expand Down Expand Up @@ -1472,7 +1473,7 @@ def convert_pad(self, op):
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

# TFLite only support CONSTANT mode and does not support constant_values parameter.
# TFLite PAD only support CONSTANT mode and does not support constant_values parameter.
# tensor
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)
Expand All @@ -1482,10 +1483,48 @@ def convert_pad(self, op):
# convert list of lists to tuple of tuples
paddings = tuple(tuple(l) for l in pad_list)

# Use default pad_value 0 because TFLite does not support constant_values parameter
# Use default pad_value 0 because TFLite PAD does not support constant_values parameter
out = _op.nn.pad(in_expr, paddings)
return out

def convert_mirror_pad(self, op):
"""Convert TFLite MIRROR_PAD"""
try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions
from tflite.MirrorPadOptions import MirrorPadOptions
except ImportError:
raise ImportError("The tflite package must be installed")

# the quantized form MirrorPad is not yet implemented in TFLite.
if self.is_quantized(op):
raise tvm.error.OpNotImplemented(
'TFlite quantized MIRROR_PAD operator is not supported yet.')

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 2, "input tensors length should be 2"

# tensor
input_tensor = input_tensors[0]
in_expr = self.get_expr(input_tensor.tensor_idx)

# paddings
pad_list = self.get_tensor_value(input_tensors[1])
# convert list of lists to tuple of tuples
paddings = tuple(tuple(l) for l in pad_list)

assert op.BuiltinOptionsType() == BuiltinOptions.MirrorPadOptions
op_options = op.BuiltinOptions()
mirror_pad_options = MirrorPadOptions()
mirror_pad_options.Init(op_options.Bytes, op_options.Pos)
mode_byte = mirror_pad_options.Mode()

mode = "REFLECT" if mode_byte == 0 else "SYMMETRIC"
out = _op.nn.mirror_pad(in_expr, paddings, mode)

return out

def convert_pack(self, op):
"""Convert TFLite pack"""
try:
Expand Down
8 changes: 6 additions & 2 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,15 +1139,15 @@ def test_forward_squeeze():
# Pad
# ---

def _test_pad(data):
def _test_pad(data, mode="CONSTANT"):
""" One iteration of PAD """

assert len(data) == 2

# Test with tensor and constant
with tf.Graph().as_default():
in_data = [array_ops.placeholder(shape=data[0].shape, dtype=data[0].dtype, name='in')]
out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype))
out = array_ops.pad(in_data[0], ops.convert_to_tensor(data[1], dtype=data[1].dtype), mode=mode)
compare_tflite_with_tvm([data[0]], ['in:0'], in_data, [out])


Expand All @@ -1161,6 +1161,10 @@ def test_forward_pad():
np.array([[1, 1], [2, 2]], dtype=np.int32)])
_test_pad([np.arange(1.0, 4.0, dtype=np.float32).reshape((1, 3)),
np.array([[1, 1], [2, 2]], dtype=np.int32)])
_test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="REFLECT")
_test_pad([np.arange(1.0, 7.0, dtype=np.float32).reshape((2, 3)),
np.array([[1, 1], [2, 2]], dtype=np.int32)], mode="SYMMETRIC")


#######################################################################
Expand Down

0 comments on commit cbb12fc

Please sign in to comment.