diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index bf1938b1481e..8c602694ad54 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -79,6 +79,7 @@ def __init__(self, model, subgraph, exp_tab): 'REDUCE_PROD': self._convert_reduce_prod, 'FULLY_CONNECTED': self.convert_fully_connected, 'PAD': self.convert_pad, + 'PACK': self.convert_pack, 'LOGISTIC': self.convert_logistic, } @@ -789,6 +790,33 @@ def convert_pad(self, op): out = _op.nn.pad(in_expr, paddings) return out + def convert_pack(self, op): + """Convert TFLite pack""" + try: + from tflite.BuiltinOptions import BuiltinOptions + from tflite.Operator import Operator + from tflite.PackOptions import PackOptions + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) >= 1, "input tensors should greater than 1" + in_exprs = [self.get_expr(input_tensor.tensor_idx) for input_tensor in input_tensors] + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors should be 1" + + assert op.BuiltinOptionsType() == BuiltinOptions.PackOptions + op_options = op.BuiltinOptions() + pack_options = PackOptions() + pack_options.Init(op_options.Bytes, op_options.Pos) + pack_axis = pack_options.Axis() + + in_exprs_reshaped = [_op.expand_dims(i, axis=pack_axis, num_newaxis=1) for i in in_exprs] + out = _op.concatenate(in_exprs_reshaped, pack_axis) + return out + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 577e2dc56ab0..d88e7a320fed 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -581,6 +581,41 @@ def test_forward_pad(): np.array([[1, 1], [2, 2]], dtype=np.int32)]) +####################################################################### +# Pack +# ------------- + +def _test_pack(data, axis): + """ One iteration of pack """ + + assert len(data) >= 1 + + with tf.Graph().as_default(): + in_data = [ + array_ops.placeholder(shape=tensor.shape, dtype=tensor.dtype, name="in_{}".format(idx)) + for idx, tensor in enumerate(data)] + out = array_ops.pack(in_data, axis=axis) + name = ["in_{}:0".format(idx) for idx in range(len(data))] + + compare_tflite_with_tvm(data, name, in_data, [out]) + + +def test_forward_pack(): + """ Pack """ + _test_pack( + [np.arange(6).reshape((1, 2, 1, 3)), + np.arange(6).reshape((1, 2, 1, 3))], 1) + + _test_pack( + [np.arange(6).reshape((3, 2)), + np.arange(6).reshape((3, 2))], 1) + + _test_pack( + [np.arange(6).reshape((2, 1, 1, 3)), + np.arange(6).reshape((2, 1, 1, 3)), + np.arange(6).reshape((2, 1, 1, 3))], 1) + + ####################################################################### # Logistic # -------- @@ -750,6 +785,7 @@ def test_forward_ssd_mobilenet_v1(): # Transforms test_forward_concatenation() test_forward_pad() + test_forward_pack() test_forward_reshape() test_all_resize() test_forward_squeeze()