Skip to content

Commit

Permalink
Addressing reviews.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Jul 1, 2020
1 parent cef3a85 commit 0518b31
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 30 deletions.
54 changes: 27 additions & 27 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,40 +296,40 @@ def get_tensors(self, tensors_idx_list):
return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
return return_list

def get_tensor_value(self, tensor_wrapper):
"""Get tensor buffer value from given tensor wrapper"""

def get_tensor_type_as_numpy(self, tensor_wrapper):
"""Returns np.dtype out of TensorType"""
assert isinstance(tensor_wrapper, TensorWrapper)

try:
from tflite.TensorType import TensorType
return {TensorType.UINT8: np.uint8,
TensorType.INT8: np.int8,
TensorType.FLOAT32: np.float32,
TensorType.INT32: np.int32,
TensorType.INT64: np.int64,
TensorType.BOOL: np.bool_}[tensor_wrapper.tensor.Type()]
except ImportError:
raise ImportError("The tflite package must be installed")
except KeyError:
raise NotImplementedError("Tensor type '{}' currently not supported"
.format(tensor_wrapper.tensor.Type()))

# Read the data from the buffer. Also extract the shape.
# The shape is used later to reshape the data.

def get_tensor_value(self, tensor_wrapper):
"""Get tensor buffer value from given tensor wrapper"""
assert isinstance(tensor_wrapper, TensorWrapper)

dtype = self.get_tensor_type_as_numpy(tensor_wrapper)
data = tensor_wrapper.buffer.DataAsNumpy()
shape = tensor_wrapper.tensor.ShapeAsNumpy()

# When TFLite buffer is of size 1 (scalar), then TFLite tensor shape is set to 0.
# Therefore, we set the shape to 1 for numpy reshape to work. Set shape to 1 if the data is
# a scalar type
if data.size == 1 and isinstance(shape, int) and shape == 0:
shape = (1,)

if tensor_wrapper.tensor.Type() == TensorType.INT8:
return np.frombuffer(data, dtype=np.int8).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.UINT8:
return np.frombuffer(data, dtype=np.uint8).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.INT32:
return np.frombuffer(data, dtype=np.int32).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.INT64:
return np.frombuffer(data, dtype=np.int64).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
return np.frombuffer(data, dtype=np.float32).reshape(shape)
if tensor_wrapper.tensor.Type() == TensorType.BOOL:
return np.frombuffer(data, dtype=np.bool).reshape(shape)
raise NotImplementedError("Tensor type {} is currently not supported"
.format(str(tensor_wrapper.tensor.Type())))

if tensor_wrapper.tensor.ShapeLength() != 0:
shape = tensor_wrapper.tensor.ShapeAsNumpy()
else:
shape = []

return np.frombuffer(data, dtype=dtype).reshape(shape)


def get_tensor_type_str(self, tensor_type):
"""Get tensor type string representation when given TFLite tensor type"""
Expand Down Expand Up @@ -728,7 +728,7 @@ def convert_relu(self, op):
zero_point=zero_point_val,
dtype=output_tensor_type_str)
else:
out = _op.clip(in_expr, a_min=0, a_max=6)
out = _op.nn.relu(in_expr)

if output_tensor.qnn_params:
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
Expand Down
5 changes: 2 additions & 3 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def _quantize_keras_model(keras_model, representative_data_gen):
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
return converter.convert()


Expand Down Expand Up @@ -903,7 +902,7 @@ def test_forward_convolution():
_test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
_test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
# dephtwise convolution with single input channel
# depthwise convolution with single input channel
_test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)

# TFLite2 quantized convolution testing
Expand Down Expand Up @@ -1814,7 +1813,7 @@ def _test_quantize_dequantize(data):

# Keras model to force TFLite converter to insert 2 TFLite quantize ops.
# First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize.
# Second TLite quantize op converts int8 tensor to int8 tensor - Qnn requantize.
# Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize.
data_in = tf.keras.layers.Input(shape=data.shape[1:])
relu = tf.keras.layers.ReLU()(data_in)
add = tf.keras.layers.Add()([data_in, relu])
Expand Down

0 comments on commit 0518b31

Please sign in to comment.