[TFLite] TFLite 2.x parser quantization support.
anijain2305 committed Jun 23, 2020
1 parent 8931cfa commit 0bbef96
Showing 2 changed files with 321 additions and 38 deletions.
157 changes: 120 additions & 37 deletions python/tvm/relay/frontend/tflite.py
@@ -243,10 +243,30 @@ def get_tensors(self, tensors_idx_list):
qnn_params = None
tflite_qnn_params = tensor.Quantization()
if tflite_qnn_params is not None:
scale = float(tflite_qnn_params.ScaleAsNumpy())
zero_point = int(tflite_qnn_params.ZeroPointAsNumpy())
tflite_scale = tflite_qnn_params.ScaleAsNumpy()
if isinstance(tflite_scale, np.ndarray) and tflite_scale.size == 1:
scale = float(tflite_scale[0])
elif isinstance(tflite_scale, np.ndarray):
scale = tflite_scale
elif isinstance(tflite_scale, int):
scale = float(tflite_scale)
else:
raise NotImplementedError("Quantized type {} not supported"
.format(type(tflite_scale)))

tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy()
if isinstance(tflite_zero_point, np.ndarray):
zero_point = tflite_zero_point
assert all(x == zero_point[0] for x in zero_point)
zero_point = int(zero_point[0])
elif isinstance(tflite_zero_point, int):
zero_point = int(tflite_zero_point)
else:
raise NotImplementedError("Quantized type {} not supported"
.format(type(tflite_zero_point)))

# Check that the scale and zero points are valid.
if scale != 0 or zero_point != 0:
if isinstance(scale, float) and scale != 0 or isinstance(scale, np.ndarray) or zero_point != 0:
qnn_params = dict()
qnn_params['scale'] = relay.const(scale, 'float32')
qnn_params['zero_point'] = relay.const(zero_point, 'int32')
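A minimal standalone sketch of what the new branching accepts, using made-up example values rather than a real TFLite buffer: per-channel weight tensors carry a scale vector, while their zero points are expected to be identical and collapse to a single int.

    import numpy as np

    # Hypothetical per-channel quantization record for a conv weight tensor:
    # one scale per output channel, zero points all equal (0 for int8 weights).
    tflite_scale = np.array([0.021, 0.017, 0.034], dtype=np.float32)
    tflite_zero_point = np.zeros(3, dtype=np.int64)

    scale = tflite_scale if tflite_scale.size > 1 else float(tflite_scale[0])
    assert all(int(z) == int(tflite_zero_point[0]) for z in tflite_zero_point)
    zero_point = int(tflite_zero_point[0])
    print(scale, zero_point)  # -> [0.021 0.017 0.034] 0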
@@ -262,6 +282,12 @@ def get_tensor_value(self, tensor_wrapper):
except ImportError:
raise ImportError("The tflite package must be installed")

if tensor_wrapper.tensor.Type() == TensorType.INT8:
val = np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int8)
if len(val) == 1:
return val
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int8).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
if tensor_wrapper.tensor.Type() == TensorType.UINT8:
return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
tensor_wrapper.tensor.ShapeAsNumpy())
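The new INT8 branch simply reinterprets the flatbuffer byte buffer as int8 and reshapes it to the tensor shape. A toy illustration with a hand-built buffer standing in for buffer.DataAsNumpy() and tensor.ShapeAsNumpy():

    import numpy as np

    raw = np.array([1, -2, 3, -4, 5, -6], dtype=np.int8).tobytes()  # stand-in for buffer.DataAsNumpy()
    shape = (2, 3)                                                   # stand-in for tensor.ShapeAsNumpy()

    values = np.frombuffer(raw, dtype=np.int8).reshape(shape)
    print(values)
    # [[ 1 -2  3]
    #  [-4  5 -6]]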
@@ -307,12 +333,22 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
rhs_scale = rhs_tensor.qnn_params['scale']
lhs_zero_point = lhs_tensor.qnn_params['zero_point']
rhs_zero_point = rhs_tensor.qnn_params['zero_point']
lhs_scale_value = get_scalar_from_constant(lhs_scale)
rhs_scale_value = get_scalar_from_constant(rhs_scale)
lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
return lhs_scale_value == rhs_scale_value and \
lhs_zero_point_value == rhs_zero_point_value
try:
lhs_scale_value = get_scalar_from_constant(lhs_scale)
rhs_scale_value = get_scalar_from_constant(rhs_scale)
lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
return lhs_scale_value == rhs_scale_value and \
lhs_zero_point_value == rhs_zero_point_value
except:
lhs_scale_value = get_vector_from_constant(lhs_scale)
rhs_scale_value = get_vector_from_constant(rhs_scale)
lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
return lhs_scale_value == rhs_scale_value and \
lhs_zero_point_value == rhs_zero_point_value



def is_quantized(self, op):
"""Check if an input tensor is quantized."""
@@ -1605,7 +1641,7 @@ def convert_fully_connected(self, op):

# weight tensor type should be UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

if self.has_expr(weight_tensor.tensor_idx):
@@ -1796,7 +1832,7 @@ def convert_conv(self, op, conv_type):

# weight tensor type should be UINT8 (quantization) or FLOAT32
weight_tensor_type = weight_tensor.tensor.Type()
assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)

in_expr = self.get_expr(input_tensor_idx)
@@ -1854,31 +1890,60 @@ def convert_conv(self, op, conv_type):
# Handle fused activation.
if output_tensor.qnn_params:
# Calculate the intermediate scale and zero point of the int32 output.
data_scale = input_tensor.qnn_params['scale']
weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
weight_scale_val = get_scalar_from_constant(weight_scale)
new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')
try:
data_scale = input_tensor.qnn_params['scale']
weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
weight_scale_val = get_scalar_from_constant(weight_scale)
new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')

# Finally requantize
out = _qnn.op.requantize(out,
input_scale=new_input_scale,
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
out = self.convert_qnn_fused_activation_function(\
expr=out,
fused_activation_fn=fused_activation_fn,
scale=output_scale_val,
zero_point=output_zero_point_val,
dtype=output_tensor_type_str)
except:
data_scale = input_tensor.qnn_params['scale']
weight_scale = weight_tensor.qnn_params['scale']
data_scale_val = get_scalar_from_constant(data_scale)
weight_scale_val = get_vector_from_constant(weight_scale)
new_input_scale_val = data_scale_val * weight_scale_val
new_input_scale = relay.const(new_input_scale_val, 'float32')
new_input_zero_point = relay.const(0, 'int32')

# Finally requantize
out = _qnn.op.requantize(out,
input_scale=new_input_scale,
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
out = self.convert_qnn_fused_activation_function(\
expr=out,
fused_activation_fn=fused_activation_fn,
scale=output_scale_val,
zero_point=output_zero_point_val,
dtype=output_tensor_type_str)

# Finally requantize
out = _qnn.op.requantize(out,
input_scale=new_input_scale,
input_zero_point=new_input_zero_point,
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)

# Call activation function
output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
output_zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
out = self.convert_qnn_fused_activation_function(\
expr=out,
fused_activation_fn=fused_activation_fn,
scale=output_scale_val,
zero_point=output_zero_point_val,
dtype=output_tensor_type_str)
else:
out = self.convert_fused_activation_function(out, fused_activation_fn)
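In both branches above, the int32 convolution output is treated as having scale data_scale * weight_scale and zero point 0 before it is requantized to the output tensor's quantization; with per-channel (TFLite 2.x) weights the intermediate scale becomes a vector. A toy arithmetic sketch with assumed values:

    import numpy as np

    data_scale_val = 0.5                                         # per-tensor input scale (assumed)
    weight_scale_val = np.array([0.02, 0.04], dtype=np.float32)  # per-output-channel weight scales (assumed)
    new_input_scale_val = data_scale_val * weight_scale_val      # scale of the int32 accumulator, zero point 0
    print(new_input_scale_val)  # -> [0.01 0.02]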

@@ -2566,17 +2631,27 @@ def convert_quantize(self, op):
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
in_expr = self.get_expr(input_tensor.tensor_idx)

output_tensors = self.get_output_tensors(op)
assert len(output_tensors) == 1, "output tensors length should be 1"
output_tensor = output_tensors[0]
output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())

# The output must be quantized
assert output_tensor.qnn_params
# Quantize the input
out = self.quantize(in_expr, output_tensor)

# TFLite Quantize op can also act as Requantize op
if input_tensor_type_str == "float32":
out = self.quantize(in_expr, output_tensor)
else:
out = _qnn.op.requantize(in_expr,
input_scale=input_tensor.qnn_params['scale'],
input_zero_point=input_tensor.qnn_params['zero_point'],
output_scale=output_tensor.qnn_params['scale'],
output_zero_point=output_tensor.qnn_params['zero_point'],
out_dtype=output_tensor_type_str)
return out
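When the input is already quantized, the TFLite Quantize op only re-maps integers from one (scale, zero point) pair to another, which is what the requantize path above expresses. A toy arithmetic sketch with assumed values, independent of the qnn.op.requantize implementation:

    q_in, in_scale, in_zp = 12, 0.1, 0                    # input quantization (assumed values)
    out_scale, out_zp = 0.05, 10                          # output quantization (assumed values)
    real_value = (q_in - in_zp) * in_scale                # dequantize: 1.2
    q_out = int(round(real_value / out_scale)) + out_zp   # requantize: 34
    print(q_out)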

def convert_dequantize(self, op):
@@ -2710,7 +2785,6 @@ def get_tensor_expr(self, tensor):
# we can receive as constant.
type_str = self.get_tensor_type_str(tensor.tensor.Type())
expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)

return expr


@@ -2723,6 +2797,15 @@ def get_scalar_from_constant(expr):
"value must be float32/int32"
return np.asscalar(value)

def get_vector_from_constant(expr):
""" Returns scalar value from Relay constant scalar. """
assert isinstance(expr, _expr.Constant)
value = expr.data.asnumpy()
assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
"value must be float32/int32"
return value
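A possible usage sketch, assuming a TVM build that includes this commit so the helper is importable from the TFLite frontend module:

    import numpy as np
    from tvm import relay
    from tvm.relay.frontend.tflite import get_vector_from_constant

    scales = relay.const(np.array([0.1, 0.2], dtype=np.float32), 'float32')
    print(get_vector_from_constant(scales))  # -> [0.1 0.2]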



def build_str_map(obj):
"""Build string map of TFLite enum int value