diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a519c6fd8b44..79ad9176ed27 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -88,6 +88,7 @@ def __init__(self, model, subgraph, exp_tab): 'RELU':self.convert_relu, 'SPLIT': self.convert_split, 'TRANSPOSE': self.convert_transpose, + 'CAST': self.convert_cast, 'TILE': self.convert_tile, 'BATCH_TO_SPACE_ND': self.convert_batch_to_space_nd, 'SPACE_TO_BATCH_ND': self.convert_space_to_batch_nd @@ -181,6 +182,9 @@ def get_tensor_value(self, tensor_wrapper): if tensor_wrapper.tensor.Type() == TensorType.INT32: return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape( tensor_wrapper.tensor.ShapeAsNumpy()) + if tensor_wrapper.tensor.Type() == TensorType.INT64: + return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape( + tensor_wrapper.tensor.ShapeAsNumpy()) raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_wrapper.tensor.Type()))) @@ -197,6 +201,8 @@ def get_tensor_type_str(self, tensor_type): return "float32" if tensor_type == TensorType.INT32: return "int32" + if tensor_type == TensorType.INT64: + return "int64" raise NotImplementedError("Tensor type {} is currently not supported" .format(str(tensor_type))) @@ -840,6 +846,31 @@ def convert_transpose(self, op): return out + def convert_cast(self, op): + """Convert TFLite CAST""" + try: + from tflite.Operator import Operator + from tflite.BuiltinOptions import BuiltinOptions + from tflite.CastOptions import CastOptions + except ImportError: + raise ImportError("The tflite package must be installed") + + assert isinstance(op, Operator) + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + input_tensor = input_tensors[0] + in_expr = self.get_expr(input_tensor.tensor_idx) + + assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions + op_options = op.BuiltinOptions() + cast_options = CastOptions() + cast_options.Init(op_options.Bytes, op_options.Pos) + cast_dtype = cast_options.OutDataType() + + out = _op.cast(in_expr, self.get_tensor_type_str(cast_dtype)) + + return out + def convert_tile(self, op): """tile implementation.""" try: diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 670e85ba8384..e4013d31935c 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -230,6 +230,24 @@ def test_forward_transpose(): _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2)) _test_forward_transpose((2, 3, 4, 5), ()) +####################################################################### +# Cast +# -------- + +def _test_cast(data, cast_dtype): + """ One iteration of CAST """ + with tf.Graph().as_default(): + in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype) + out = math_ops.cast(in_data, cast_dtype) + compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out]) + + +def test_forward_cast(): + """ CAST """ + _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32) + _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8) + _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64) + ####################################################################### # tile # --------- @@ -1013,6 +1031,9 @@ def test_forward_ssd_mobilenet_v1(): # Transpose test_forward_transpose() + # Cast + test_forward_cast() + # Tile test_forward_tile()