From 934540532687894452bd9ba0fd437d47897fdab3 Mon Sep 17 00:00:00 2001 From: Ina_Dobreva Date: Fri, 12 Jun 2020 17:51:50 +0100 Subject: [PATCH] Fix floating value quantization for RELU6 and RELU1 --- python/tvm/relay/frontend/tflite.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index 4f2b95f60b53..b79d8e1289e5 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -702,8 +702,9 @@ def convert_relu6(self, op): if input_tensor.qnn_params: # Quantize a float value to an quantized integer value - quantize = lambda x: float(int(round(x / input_tensor.qnn_params['scale'])) + \ - input_tensor.qnn_params['zero_point']) + scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point']) + quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val) # Get min/max of the input dtype. This will be used to ensure that # clip a_min/a_max are not beyond the dtype range. @@ -772,8 +773,9 @@ def convert_relu_n1_to_1(self, op): if input_tensor.qnn_params: # Quantize a float value to an quantized integer value - quantize = lambda x: float(int(round(x / input_tensor.qnn_params['scale'])) + \ - input_tensor.qnn_params['zero_point']) + scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale']) + zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point']) + quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val) # Get min/max of the input dtype. This will be used to ensure that # clip a_min/a_max are not beyond the dtype range.