diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 66d0ff326ce07..360bdbd8db210 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -319,6 +319,40 @@ def dequantize(self, expr, tensor):
                                          input_zero_point=tensor.qnn_params['zero_point'])
         return dequantized
 
+
+    def convert_qnn_fused_activation_function(self, expr, fused_activation_fn,
+                                              scale, zero_point, dtype):
+        """Convert TFLite fused activation function. The expr is an input quantized tensor with
+        its scale and zero point."""
+        try:
+            from tflite.ActivationFunctionType import ActivationFunctionType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        # Map a float value into the quantized integer domain.
+        quantize = lambda value: (value / scale) + zero_point
+
+        # The input expr is a quantized tensor with its scale and zero point. We calculate
+        # suitable clip thresholds from that scale and zero point.
+        if fused_activation_fn == ActivationFunctionType.NONE:
+            return expr
+        elif fused_activation_fn == ActivationFunctionType.RELU6:
+            return _op.clip(expr,
+                            a_min=quantize(0),
+                            a_max=quantize(6))
+        elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
+            return _op.clip(expr,
+                            a_min=quantize(-1),
+                            a_max=quantize(1))
+        elif fused_activation_fn == ActivationFunctionType.RELU:
+            return _op.clip(expr,
+                            a_min=quantize(0),
+                            a_max=float(tvm.tir.op.max_value(dtype).value))
+
+        fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
+        raise tvm.error.OpNotImplemented(
+            'Quantized activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))
+
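# --- Illustrative sketch (annotation, not part of the patch) ----------------
# How the quantize lambda above derives clip bounds: the real-valued
# activation limits are mapped into the quantized domain as
# value / scale + zero_point. The scale/zero_point below are made-up
# example quantization parameters, not taken from any real model.
scale, zero_point = 6.0 / 255.0, 0
quantize = lambda value: (value / scale) + zero_point
a_min, a_max = quantize(0), quantize(6)   # RELU6 bounds in the quantized domain
print(a_min, a_max)                       # 0.0 255.0 -> clip(q, 0, 255) == float RELU6
# ----------------------------------------------------------------------------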
     def convert_conv2d(self, op):
         """Convert TFLite conv2d"""
         return self.convert_conv(op, "conv2d")
@@ -452,17 +486,16 @@ def convert_l2_normalization(self, op):
         if self.is_quantized(op):
             raise tvm.error.OpNotImplemented(
                 'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
+        # TFL uses only the default epsilon value
         out = _op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1])
 
         # if we have fused activation fn
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if not output_tensor.qnn_params:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
-            else:
-                raise tvm.error.OpNotImplemented(
-                    'TFLite quantized L2_NORMALIZATION operator\
-                    with fused activation function is not supported yet.')
+        if output_tensor.qnn_params:
+            raise tvm.error.OpNotImplemented(
+                'TFLite quantized L2_NORMALIZATION operator is not supported yet.')
+        else:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
 
         return out
 
@@ -639,14 +672,20 @@ def convert_concatenation(self, op):
                 output_zero_point=output_tensor.qnn_params['zero_point'],
                 axis=concatenation_axis)
 
-        # if we have activation fn
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if not output_tensor.qnn_params:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
-            else:
-                raise tvm.error.OpNotImplemented(
-                    'Operator {} with fused activation is not supported yet.'
-                    .format('qnn.op.concatenate'))
+        # Handle fused activations
+        if output_tensor.qnn_params:
+            scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
+            zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+            out = self.convert_qnn_fused_activation_function(
+                expr=out,
+                fused_activation_fn=fused_activation_fn,
+                scale=scale_val,
+                zero_point=zero_point_val,
+                dtype=output_tensor_type_str)
+        else:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
+
         return out
 
     def _convert_unary_elemwise(self, relay_op, op):
@@ -854,13 +893,20 @@ def _convert_elemwise(self, relay_op, op):
             op_options = op.BuiltinOptions()
             options.Init(op_options.Bytes, op_options.Pos)
             fused_activation_fn = options.FusedActivationFunction()
-            # if we have activation fn
-            if fused_activation_fn != ActivationFunctionType.NONE:
-                if output_tensor.qnn_params:
-                    raise tvm.error.OpNotImplemented(
-                        'Elemwise operators with fused activation are not supported yet.')
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
+            # Handle fused activations
+            if output_tensor.qnn_params:
+                scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
+                zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+                output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+                out = self.convert_qnn_fused_activation_function(
+                    expr=out,
+                    fused_activation_fn=fused_activation_fn,
+                    scale=scale_val,
+                    zero_point=zero_point_val,
+                    dtype=output_tensor_type_str)
+            else:
+                out = self.convert_fused_activation_function(out, fused_activation_fn)
 
         return out
 
     def convert_add(self, op):
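# --- Illustrative sketch (annotation, not part of the patch) ----------------
# For concat/elemwise ops the result is already in the output tensor's
# quantized domain, so the clip uses the *output* scale and zero point.
# Emulated with numpy under hypothetical quant params (scale=0.05, zp=128):
import numpy as np

scale, zero_point = 0.05, 128
q = np.array([0, 100, 200, 255], dtype=np.int32)      # quantized op output
a_min, a_max = 0 / scale + zero_point, 6 / scale + zero_point
q_act = np.clip(q, a_min, a_max)                      # [128, 128, 200, 248]
print((q_act - zero_point) * scale)                   # real values clipped to [0, 6]
# ----------------------------------------------------------------------------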
@@ -1365,15 +1411,6 @@ def convert_fully_connected(self, op):
                                        dtype=bias_tensor_type_str)
             out = _op.nn.bias_add(out, bias_expr)
 
-        # If we have fused activations
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if not output_tensor.qnn_params:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
-            else:
-                raise tvm.error.OpNotImplemented(
-                    'Operator {} with fused activation is not supported yet.'
-                    .format('qnn.op.dense'))
-
         # Finally if the dense is quantized. Add a requantize at the end.
         if output_tensor.qnn_params:
             data_scale = input_tensor.qnn_params['scale']
@@ -1383,12 +1420,24 @@ def convert_fully_connected(self, op):
             new_input_scale_val = data_scale_val * weight_scale_val
             new_input_scale = relay.const(new_input_scale_val, 'float32')
             new_input_zero_point = relay.const(0, 'int32')
+
+            # Call activation function
+            out = self.convert_qnn_fused_activation_function(
+                expr=out,
+                fused_activation_fn=fused_activation_fn,
+                scale=new_input_scale_val,
+                zero_point=0,
+                dtype='int32')
+
+            # Requantize
             out = _qnn.op.requantize(out,
                                      input_scale=new_input_scale,
                                      input_zero_point=new_input_zero_point,
                                      output_scale=output_tensor.qnn_params['scale'],
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      out_dtype=output_tensor_type_str)
+        else:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
 
         return out
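# --- Illustrative sketch (annotation, not part of the patch) ----------------
# Why the int32 accumulator of a quantized dense/conv has
# scale = data_scale * weight_scale and zero point 0: the accumulator is
# sum((d - d_zp) * (w - w_zp)), whose real value is acc * (ds * ws).
# Requantizing to the output params (all values below are hypothetical):
import numpy as np

data_scale, weight_scale = 0.5, 0.25
out_scale, out_zero_point = 0.1, 3
acc = np.array([8, 40], dtype=np.int32)            # int32 accumulator
real = acc * (data_scale * weight_scale)           # [1.0, 5.0]
requant = np.round(real / out_scale) + out_zero_point
print(requant)                                     # [13., 53.]
# ----------------------------------------------------------------------------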
@@ -1424,18 +1473,20 @@ def convert_fused_activation_function(self, in_expr, fused_activation_fn):
             from tflite.ActivationFunctionType import ActivationFunctionType
         except ImportError:
             raise ImportError("The tflite package must be installed")
-        assert fused_activation_fn != ActivationFunctionType.NONE
-        if fused_activation_fn == ActivationFunctionType.RELU6:
+
+        if fused_activation_fn == ActivationFunctionType.NONE:
+            return in_expr
+        elif fused_activation_fn == ActivationFunctionType.RELU6:
             return _op.clip(in_expr, a_min=0, a_max=6)
-        if fused_activation_fn == ActivationFunctionType.RELU:
+        elif fused_activation_fn == ActivationFunctionType.RELU:
             return _op.nn.relu(in_expr)
-        if fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
+        elif fused_activation_fn == ActivationFunctionType.RELU_N1_TO_1:
             return _op.clip(in_expr, a_min=-1, a_max=1)
-        if fused_activation_fn == ActivationFunctionType.TANH:
+        elif fused_activation_fn == ActivationFunctionType.TANH:
             return _op.tanh(in_expr)
         fused_activation_fn_str = self.activation_fn_type[fused_activation_fn]
         raise tvm.error.OpNotImplemented(
-            'Operator {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))
+            'Fused activation {} is not supported for frontend TFLite.'.format(fused_activation_fn_str))
 
     def convert_conv(self, op, conv_type):
         """convolution implementation."""
@@ -1572,17 +1623,9 @@ def convert_conv(self, op, conv_type):
                 channel_axis = 3
             out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
 
-        # If we have fused activations
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if not output_tensor.qnn_params:
-                out = self.convert_fused_activation_function(out, fused_activation_fn)
-            else:
-                raise tvm.error.OpNotImplemented(
-                    'Operator {} with fused activation is not supported yet.'
-                    .format('qnn.op.conv2d'))
-
-        # Finally if the conv is quantized. Add a requantize at the end.
+        # Handle fused activation.
         if output_tensor.qnn_params:
+            # Calculate the intermediate scale and zero point of the int32 output.
             data_scale = input_tensor.qnn_params['scale']
             weight_scale = weight_tensor.qnn_params['scale']
             data_scale_val = get_scalar_from_constant(data_scale)
@@ -1590,12 +1633,24 @@ def convert_conv(self, op, conv_type):
             new_input_scale_val = data_scale_val * weight_scale_val
             new_input_scale = relay.const(new_input_scale_val, 'float32')
             new_input_zero_point = relay.const(0, 'int32')
+
+            # Call activation function
+            out = self.convert_qnn_fused_activation_function(
+                expr=out,
+                fused_activation_fn=fused_activation_fn,
+                scale=new_input_scale_val,
+                zero_point=0,
+                dtype='int32')
+
+            # Finally requantize
             out = _qnn.op.requantize(out,
                                      input_scale=new_input_scale,
                                      input_zero_point=new_input_zero_point,
                                      output_scale=output_tensor.qnn_params['scale'],
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
                                      out_dtype=output_tensor_type_str)
+        else:
+            out = self.convert_fused_activation_function(out, fused_activation_fn)
 
         return out
 
@@ -1835,13 +1890,19 @@ def convert_pool2d(self, op, pool_type):
             raise tvm.error.OpNotImplemented(
                 'Operator {} is not supported for frontend TFLite.'.format(pool_type + ' pool'))
 
-        # If we have fused activations
-        if fused_activation_fn != ActivationFunctionType.NONE:
-            if input_tensor.qnn_params:
-                raise tvm.error.OpNotImplemented(
-                    'Operator {} with fused activation is not supported yet.'
-                    .format('qnn.op.pool2d'))
+        # Handle fused activations
+        if output_tensor.qnn_params:
+            scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
+            zero_point_val = get_scalar_from_constant(output_tensor.qnn_params['zero_point'])
+            out = self.convert_qnn_fused_activation_function(
+                expr=out,
+                fused_activation_fn=fused_activation_fn,
+                scale=scale_val,
+                zero_point=zero_point_val,
+                dtype=output_tensor_type_str)
+        else:
             out = self.convert_fused_activation_function(out, fused_activation_fn)
+
         return out
 
     def convert_pad(self, op):
diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py
index 1a231eb1aaedf..8cbab38f19941 100644
--- a/python/tvm/relay/testing/tf.py
+++ b/python/tvm/relay/testing/tf.py
@@ -184,10 +184,15 @@ def get_workload_official(model_url, model_sub_path):
 
     dir_path = os.path.dirname(model_path)
     import tarfile
+    import zipfile
 
     if model_path.endswith("tgz") or model_path.endswith("gz"):
         tar = tarfile.open(model_path)
         tar.extractall(path=dir_path)
         tar.close()
+    elif model_path.endswith("zip"):
+        zip_object = zipfile.ZipFile(model_path)
+        zip_object.extractall(path=dir_path)
+        zip_object.close()
     else:
         raise RuntimeError('Could not decompress the file: ' + model_path)
     return os.path.join(dir_path, model_sub_path)
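# --- Illustrative sketch (annotation, not part of the patch) ----------------
# Context-manager form of the extraction dispatch added above; the helper
# name _extract is hypothetical. Using `with` closes the archive even when
# extractall raises.
import tarfile
import zipfile

def _extract(model_path, dir_path):
    if model_path.endswith(("tgz", "gz")):        # tarfile reads .tgz/.tar.gz
        with tarfile.open(model_path) as tar:
            tar.extractall(path=dir_path)
    elif model_path.endswith("zip"):
        with zipfile.ZipFile(model_path) as zip_object:
            zip_object.extractall(path=dir_path)
    else:
        raise RuntimeError('Could not decompress the file: ' + model_path)
# ----------------------------------------------------------------------------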
"rb") as f: + tflite_model_buf = f.read() + + np.random.seed(0) + data = np.random.uniform(size=(1, 300, 300, 3)).astype('uint8') + tflite_output = run_tflite_graph(tflite_model_buf, data) + tvm_output = run_tvm_graph(tflite_model_buf, data, 'normalized_input_image_tensor', num_output=4) + + # Check all output shapes are equal + assert all([tvm_tensor.shape == tflite_tensor.shape \ + for (tvm_tensor, tflite_tensor) in zip(tvm_output, tflite_output)]) + + # Check valid count is the same + assert tvm_output[3] == tflite_output[3] + valid_count = tvm_output[3][0] + + # For boxes that do not have any detections, TFLite puts random values. Therefore, we compare + # tflite and tvm tensors for only valid boxes. + for i in range(0, valid_count): + # Check bounding box co-ords + tvm.testing.assert_allclose(np.squeeze(tvm_output[0][0][i]), np.squeeze(tflite_output[0][0][i]), + rtol=1e-5, atol=1e-5) + # Check the class + tvm.testing.assert_allclose(np.squeeze(tvm_output[1][0][i]), np.squeeze(tflite_output[1][0][i]), + rtol=1e-5, atol=1e-5) + # Check the score + tvm.testing.assert_allclose(np.squeeze(tvm_output[2][0][i]), np.squeeze(tflite_output[2][0][i]), + rtol=1e-5, atol=1e-5) + + + ####################################################################### # SSD Mobilenet # ------------- @@ -1945,7 +1988,7 @@ def test_forward_coco_ssd_mobilenet_v1(): with open(tflite_model_file, "rb") as f: tflite_model_buf = f.read() - + np.random.seed(0) data = np.random.uniform(size=(1, 300, 300, 3)).astype('float32') tflite_output = run_tflite_graph(tflite_model_buf, data)