diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 621115120677..03f5b1bea6cd 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -113,3 +113,4 @@ We do encourage everyone to work anything they are interested in. - [Cody Hao Yu](https://github.com/comaniac) - [Chris Nuernberger](https://github.com/cnuernber) - [Shoubhik Bhattacharya](https://github.com/shoubhik) +- [Neo Chien](https://github.com/cchung100m) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index eed3d8192593..162cc36a89e5 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -81,7 +81,8 @@ def __init__(self, model, subgraph, exp_tab): 'PAD': self.convert_pad, 'PACK': self.convert_pack, 'LOGISTIC': self.convert_logistic, - 'SPLIT': self.convert_split + 'SPLIT': self.convert_split, + 'TRANSPOSE': self.convert_transpose } def check_unsupported_ops(self): @@ -743,6 +744,31 @@ def convert_split(self, op): return out + def convert_transpose(self, op): + """transpose implementation.""" + try: + from tflite.Operator import Operator + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + input_tensor = input_tensors[0] + input_tensor_idx = input_tensor.tensor_idx + + in_expr = self.get_expr(input_tensor_idx) + + # axis + in_axis = tuple(self.get_tensor_value(input_tensors[1])) + + if not in_axis: + out = _op.transpose(in_expr) + else: + out = _op.transpose(in_expr, in_axis) + + return out + def convert_pool2d(self, op, pool_type): """pool2d implementation.""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 2c356d8c9156..a78225cd5646 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -201,6 +201,35 @@ def test_forward_split(): _test_split((1, 3, 6, 5), -2, 3, 'float32') _test_split((1, 3, 5, 6), -1, 3, 'float32') +####################################################################### +# transpose +# --------- + + +def _test_forward_transpose(ishape, axes=()): + data = np.random.uniform(size=ishape).astype(np.float32) + + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + + if not axes: + out = array_ops.transpose(in_data) + else: + out = array_ops.transpose(in_data, axes) + + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + + +def test_forward_transpose(): + _test_forward_transpose((2, 2)) + _test_forward_transpose((2, 3, 4)) + _test_forward_transpose((7, 8, 8, 10)) + _test_forward_transpose((2, 3, 4), (1, 2, 0)) + _test_forward_transpose((2, 3, 4), (0, 1, 2)) + _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2)) + _test_forward_transpose((2, 3, 4, 5), ()) + + ####################################################################### # Pooling # ------- @@ -823,6 +852,10 @@ def test_forward_ssd_mobilenet_v1(): if __name__ == '__main__': # Split test_forward_split() + + # Transpose + test_forward_transpose() + # Transforms test_forward_concatenation() test_forward_pad()