Skip to content

Commit

Permalink
Implement clip operation for quantized relu6, relu1
Browse files Browse the repository at this point in the history
* add 'clip' as in the quantized fused operations
* remove redundant assertions and imports
  • Loading branch information
inadob committed Jun 4, 2020
1 parent 1c2435d commit 91c3631
Showing 1 changed file with 46 additions and 26 deletions.
72 changes: 46 additions & 26 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,12 +691,6 @@ def _hard_swish(data):

def convert_relu6(self, op):
    """Convert TFLite ReLU6.

    Clips the input to the range [0, 6]. For quantized tensors the clip
    bounds are computed directly in the quantized integer domain (no
    dequantize/quantize round trip), clamped to the input dtype's value
    range, and the result is requantized to the output tensor's
    quantization parameters.

    Parameters
    ----------
    op : tflite.Operator
        TFLite ReLU6 operator.

    Returns
    -------
    out : relay.Expr
        The converted Relay expression.
    """
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    input_tensor = input_tensors[0]
    # NOTE(review): the next four lines were collapsed context in the diff
    # view; restored from the surrounding code — confirm against the file.
    in_expr = self.get_expr(input_tensor.tensor_idx)

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"
    output_tensor = output_tensors[0]

    if input_tensor.qnn_params:
        # Quantize a float value to a quantized integer value.
        def quantize(x):
            return float(int(round(x / input_tensor.qnn_params['scale'])) +
                         input_tensor.qnn_params['zero_point'])

        # Get min/max of the input dtype. This will be used to ensure that
        # clip a_min/a_max are not beyond the dtype range.
        input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
        qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value)
        qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value)

        out = _op.clip(in_expr,
                       a_min=max(qmin, quantize(0)),
                       a_max=min(qmax, quantize(6.0)))
    else:
        out = _op.clip(in_expr, a_min=0, a_max=6)

    if output_tensor.qnn_params:
        # Rescale from the input's quantization params to the output's.
        output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
        out = _qnn.op.requantize(out,
                                 input_scale=input_tensor.qnn_params['scale'],
                                 input_zero_point=input_tensor.qnn_params['zero_point'],
                                 output_scale=output_tensor.qnn_params['scale'],
                                 output_zero_point=output_tensor.qnn_params['zero_point'],
                                 out_dtype=output_tensor_type_str)

    return out

def convert_leaky_relu(self, op):
"""Convert TFLite LEAKY_RELU"""
try:
from tflite.Operator import Operator
from tflite.BuiltinOptions import BuiltinOptions
from tflite.LeakyReluOptions import LeakyReluOptions
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
Expand All @@ -749,12 +761,6 @@ def convert_leaky_relu(self, op):

def convert_relu_n1_to_1(self, op):
    """Convert TFLite RELU_N1_TO_1.

    Clips the input to the range [-1, 1]. For quantized tensors the clip
    bounds are computed directly in the quantized integer domain (no
    dequantize/quantize round trip), clamped to the input dtype's value
    range, and the result is requantized to the output tensor's
    quantization parameters.

    Parameters
    ----------
    op : tflite.Operator
        TFLite RELU_N1_TO_1 operator.

    Returns
    -------
    out : relay.Expr
        The converted Relay expression.
    """
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 1, "input tensors length should be 1"
    input_tensor = input_tensors[0]
    # NOTE(review): the next four lines were collapsed context in the diff
    # view; restored from the surrounding code — confirm against the file.
    in_expr = self.get_expr(input_tensor.tensor_idx)

    output_tensors = self.get_output_tensors(op)
    assert len(output_tensors) == 1, "output tensors length should be 1"
    output_tensor = output_tensors[0]

    if input_tensor.qnn_params:
        # Quantize a float value to a quantized integer value.
        def quantize(x):
            return float(int(round(x / input_tensor.qnn_params['scale'])) +
                         input_tensor.qnn_params['zero_point'])

        # Get min/max of the input dtype. This will be used to ensure that
        # clip a_min/a_max are not beyond the dtype range.
        input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
        qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value)
        qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value)

        out = _op.clip(in_expr,
                       a_min=max(qmin, quantize(-1.0)),
                       a_max=min(qmax, quantize(1.0)))
    else:
        out = _op.clip(in_expr, a_min=-1, a_max=1)

    if output_tensor.qnn_params:
        # Rescale from the input's quantization params to the output's.
        output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
        out = _qnn.op.requantize(out,
                                 input_scale=input_tensor.qnn_params['scale'],
                                 input_zero_point=input_tensor.qnn_params['zero_point'],
                                 output_scale=output_tensor.qnn_params['scale'],
                                 output_zero_point=output_tensor.qnn_params['zero_point'],
                                 out_dtype=output_tensor_type_str)

    return out

def convert_log_softmax(self, op):
"""Convert TFLite LOG_SOFTMAX"""
try:
from tflite.Operator import Operator
except ImportError:
raise ImportError("The tflite package must be installed")

assert isinstance(op, Operator)
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
input_tensor = input_tensors[0]
Expand Down

0 comments on commit 91c3631

Please sign in to comment.