Skip to content

Commit

Permalink
[TFLite] Added check for dynamic range quantization (apache#7114)
Browse files Browse the repository at this point in the history
* [TFLite] Added check for dynamic range quantization

Added a check that prevents TFLite files optimized with "dynamic range quantization"
from being loaded, as this optimization is not fully supported.

https://www.tensorflow.org/lite/performance/post_training_quantization#dynamic_range_quantization

* linter

* linter

* unit test fix
  • Loading branch information
d-smirnov authored and alexwong committed Feb 11, 2021
1 parent 6100c49 commit 5ad9136
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
34 changes: 31 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,17 +176,45 @@ def __init__(self, model, subgraph, exp_tab):
def check_unsupported_ops(self):
    """Check unsupported TFLite ops in our converter.

    Walks every operator of the subgraph and collects two kinds of problems:

    * op codes that have no entry in ``self.convert_map``;
    * op codes that look like "dynamic range quantization": the first input
      and all outputs carry no qnn params, while at least one of the
      remaining inputs (the weights) does.

    Both problems are reported together in a single exception so the user
    sees the complete list in one pass.

    Raises
    ------
    tvm.error.OpNotImplemented
        If any operator is unsupported or is likely to use dynamic range
        quantization.
    """
    unsupported_ops_set = set()
    dynamic_range_ops_set = set()

    for op_idx in range(self.subgraph.OperatorsLength()):
        op = self.subgraph.Operators(op_idx)
        op_code_str = self.get_op_code_str(op)
        if op_code_str not in self.convert_map:
            unsupported_ops_set.add(op_code_str)
            continue

        # Trying to exclude "dynamic range quantization" optimized ops as not supported in TVM
        input_tensors = self.get_input_tensors(op)
        output_tensors = self.get_output_tensors(op)

        # Input 0 is treated as the data tensor, the remaining inputs as weights.
        data_is_quantized = any(t.qnn_params is not None for t in input_tensors[0:1])
        weights_are_quantized = any(t.qnn_params is not None for t in input_tensors[1:])
        output_is_quantized = any(t.qnn_params is not None for t in output_tensors)

        if weights_are_quantized and not data_is_quantized and not output_is_quantized:
            dynamic_range_ops_set.add(op_code_str)

    raise_msg = ""

    if unsupported_ops_set:
        msg = "The following operators are not supported in frontend " "TFLite: {}\n"
        # Sorted so the message is reproducible regardless of set ordering.
        ops = str(sorted(unsupported_ops_set)).strip("[,]")
        raise_msg += msg.format(ops)

    if dynamic_range_ops_set:
        msg = (
            "The following operators are likely to have dynamic range quantization: {}. "
            "If you are running an optimized graph, please turn off dynamic range quantization "
            "or use full integer quantization"
        )
        raise_msg += msg.format(str(sorted(dynamic_range_ops_set)).strip("[,]"))

    if raise_msg:
        raise tvm.error.OpNotImplemented(raise_msg)

def convert_op_to_relay(self):
"""Convert TFLite ops to relay ops"""
Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4156,6 +4156,27 @@ def test_forward_mediapipe_hand_landmark():
)


#######################################################################
# Test check for Tensorflow "dynamic range quantization" optimization
# --------------
def test_prevent_tensorflow_dynamic_range():
    """
    Should prevent running "dynamic range quantization" optimized TFLite graph
    """
    data_array = np.random.randint(0, 2, (1, 1024, 1024)).astype(dtype=np.float32)
    filter_array = np.random.randint(0, 2, (1024, 1024)).astype(dtype=np.float32)

    # Single bias-free Dense layer whose kernel is set explicitly to filter_array.
    data_in = tf.keras.layers.Input(shape=data_array.shape[1:])
    dense = tf.keras.layers.Dense(units=filter_array.shape[-1], use_bias=False)(data_in)
    keras_model = tf.keras.models.Model(data_in, dense)
    keras_model.layers[1].set_weights([filter_array])

    converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
    # NOTE(review): Optimize.DEFAULT with no representative dataset is expected to
    # produce a dynamic-range-quantized model (see the TFLite post-training
    # quantization guide) - confirm against the TensorFlow version in use.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    tflite_model = converter.convert()

    # The frontend must reject the model; the run result itself is irrelevant.
    with pytest.raises(tvm.error.OpNotImplemented):
        run_tvm_graph(tflite_model, data_array, data_in.name.replace(":0", ""))


#######################################################################
# Main
# ----
Expand Down

0 comments on commit 5ad9136

Please sign in to comment.