[FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess #4543

Merged: 10 commits, Feb 13, 2020
python/tvm/relay/frontend/tflite.py: 197 additions, 0 deletions
@@ -121,6 +121,7 @@ def __init__(self, model, subgraph, exp_tab):
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_OR': self.convert_logical_or,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess
}

def check_unsupported_ops(self):
@@ -168,6 +169,10 @@ def get_op_code_str(self, op):
op_code_str = self.builtin_op_code[op_code_id]
if op_code_id == BuiltinOperator.CUSTOM:
# Custom operator
custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode()
if custom_op_code_str == b'TFLite_Detection_PostProcess':
return "DETECTION_POSTPROCESS"

raise NotImplementedError("Custom operators are currently not supported")
return op_code_str

@@ -1806,6 +1811,113 @@ def convert_transpose_conv(self, op):

return out

def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
_option_names = [
"w_scale",
"max_detections",
"_output_quantized",
"detections_per_class",
"x_scale",
"nms_score_threshold",
"num_classes",
"max_classes_per_detection",
"use_regular_nms",
"y_scale",
"h_scale",
"_support_output_type_float_in_quantized_op",
"nms_iou_threshold"
]

custom_options = get_custom_options(op, _option_names)
if custom_options["use_regular_nms"]:
raise tvm.error.OpAttributeUnImplemented(
"use_regular_nms=True is not yet supported for operator {}."
.format("TFLite_Detection_PostProcess")
)

inputs = self.get_input_tensors(op)
assert len(inputs) == 3, "inputs length should be 3"
cls_pred = self.get_expr(inputs[1].tensor_idx)
loc_prob = self.get_expr(inputs[0].tensor_idx)
anchor_values = self.get_tensor_value(inputs[2])
anchor_boxes = len(anchor_values)
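# one anchor per row; e.g. the 300x300 SSD MobileNet models use 1917 anchors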
anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)

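# the postprocess computation runs in float, so dequantize any quantized inputs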
if inputs[0].qnn_params:
loc_prob = _qnn.op.dequantize(data=loc_prob,
input_scale=inputs[0].qnn_params['scale'],
input_zero_point=inputs[0].qnn_params['zero_point'])
if inputs[1].qnn_params:
cls_pred = _qnn.op.dequantize(data=cls_pred,
input_scale=inputs[1].qnn_params['scale'],
input_zero_point=inputs[1].qnn_params['zero_point'])
if inputs[2].qnn_params:
anchor_expr = _qnn.op.dequantize(data=anchor_expr,
input_scale=inputs[2].qnn_params['scale'],
input_zero_point=inputs[2].qnn_params['zero_point'])

# reshape the cls_pred and loc_prob tensors so
# they can be consumed by multibox_transform_loc
cls_pred = _op.transpose(cls_pred, [0, 2, 1])
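# -> (batch, num_classes, num_anchors), the layout multibox_transform_loc expects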
# loc_prob coords are in yxhw format
# need to convert to xywh
loc_coords = _op.split(loc_prob, 4, axis=2)
loc_prob = _op.concatenate(
[loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
)
loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])
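# flattened to (batch, num_anchors * 4), as multibox_transform_loc expects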

# anchor coords are in yxhw format
# need to convert to ltrb
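# e.g. an anchor (y, x, h, w) = (0.5, 0.5, 0.2, 0.4) becomes (l, t, r, b) = (0.3, 0.4, 0.7, 0.6)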
anchor_coords = _op.split(anchor_expr, 4, axis=1)
anchor_y = anchor_coords[0]
anchor_x = anchor_coords[1]
anchor_h = anchor_coords[2]
anchor_w = anchor_coords[3]
plus_half = _expr.const(0.5, dtype='float32')
minus_half = _expr.const(-0.5, dtype='float32')
anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1)
anchor_expr = _op.expand_dims(anchor_expr, 0)

# attributes for multibox_transform_loc
multibox_transform_loc_attrs = {}
multibox_transform_loc_attrs["clip"] = False
multibox_transform_loc_attrs["threshold"] = custom_options["nms_score_threshold"]
multibox_transform_loc_attrs["variances"] = (
1 / custom_options["x_scale"],
1 / custom_options["y_scale"],
1 / custom_options["w_scale"],
1 / custom_options["h_scale"],
)

# attributes for non_max_suppression
non_max_suppression_attrs = {}
non_max_suppression_attrs["return_indices"] = False
non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"]
non_max_suppression_attrs["force_suppress"] = False
non_max_suppression_attrs["top_k"] = anchor_boxes
non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
non_max_suppression_attrs["invalid_to_bottom"] = False

ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob,
anchor_expr, **multibox_transform_loc_attrs)
ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs)
ret = _op.vision.get_valid_counts(ret, 0)
valid_count = ret[0]
# the output needs some reshaping to match tflite
ret = _op.split(ret[1], 6, axis=2)
cls_ids = ret[0]
scores = ret[1]
boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2)
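# order the outputs as TFLite does: boxes, class ids, scores, number of valid detections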
ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
return ret

def get_expr(self, input_tensor_idx):
return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))

@@ -1877,6 +1989,91 @@ def get_tensor_name(subgraph, tensor_idx):
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def get_custom_options(op, option_names):
"""Get the options of a custom operator.

This implements partial flexbuffer deserialization to be able
to read custom options. It is not intended to be a general
purpose flexbuffer deserializer and as such only supports a
limited number of types and assumes the data is a flat map.

Parameters
----------
op:
A custom TFLite operator.
option_names: list
A complete list of the custom option names.

Returns
-------
options: dict
A dictionary of the custom options.

"""
import struct
from enum import IntEnum

class _FlexBufferType(IntEnum):
"""Flexbuffer type schema from flexbuffers.h"""
FBT_NULL = 0
FBT_INT = 1
FBT_UINT = 2
FBT_FLOAT = 3
# Types above stored inline, types below store an offset.
FBT_KEY = 4
FBT_STRING = 5
FBT_INDIRECT_INT = 6
FBT_INDIRECT_UINT = 7
FBT_INDIRECT_FLOAT = 8
FBT_MAP = 9
FBT_VECTOR = 10 # Untyped.
FBT_VECTOR_INT = 11 # Typed any size (stores no type table).
FBT_VECTOR_UINT = 12
FBT_VECTOR_FLOAT = 13
FBT_VECTOR_KEY = 14
FBT_VECTOR_STRING = 15
FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field).
FBT_VECTOR_UINT2 = 17
FBT_VECTOR_FLOAT2 = 18
FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field).
FBT_VECTOR_UINT3 = 20
FBT_VECTOR_FLOAT3 = 21
FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field).
FBT_VECTOR_UINT4 = 23
FBT_VECTOR_FLOAT4 = 24
FBT_BLOB = 25
FBT_BOOL = 26
FBT_VECTOR_BOOL = 36 # Allows the same type-to-vector-type conversion as above

buffer = op.CustomOptionsAsNumpy().tobytes()
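# A flexbuffer ends with its root value followed by the root type and byte
# width; for this flat map the last three bytes are taken to be (offset back
# to the map's value vector, root type, root width), hence buffer[-3] below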
value_vector_offset = buffer[-3]
buffer = buffer[:-3]
num_bytes = 4 # Assume all values are stored with 32-bit width
value_vector_size = struct.unpack(
"<i", buffer[-value_vector_offset - num_bytes:-value_vector_offset]
)[0]
type_offset = value_vector_size
types = buffer[-type_offset:]
values = []
for i, t in enumerate(types):
flex_type = _FlexBufferType(t >> 2)
value_offset = -value_vector_offset + i*num_bytes
value_bytes = buffer[value_offset:value_offset+num_bytes]
if flex_type == _FlexBufferType.FBT_BOOL:
value = bool(value_bytes[0])
elif flex_type == _FlexBufferType.FBT_INT:
value = struct.unpack("<i", value_bytes)[0]
elif flex_type == _FlexBufferType.FBT_UINT:
value = struct.unpack("<I", value_bytes)[0]
elif flex_type == _FlexBufferType.FBT_FLOAT:
value = struct.unpack("<f", value_bytes)[0]
else:
# guard against silently reusing a stale value for unhandled types
raise ValueError("Unsupported flexbuffer type: {}".format(flex_type))

values.append(value)

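# flexbuffer maps store their keys lexicographically sorted, which is why
# sorting the expected option names aligns each name with its parsed value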
custom_options = dict(zip(sorted(option_names), values))
return custom_options
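
As a minimal usage sketch (not part of the diff): this is how the converter above consumes the helper. Here `op` is assumed to be a tflite Operator whose custom options hold a flat flexbuffer map, as TFLite_Detection_PostProcess produces, and the option names are the ones listed in convert_detection_postprocess:

option_names = [
"y_scale", "x_scale", "h_scale", "w_scale",
"max_detections", "detections_per_class", "max_classes_per_detection",
"nms_score_threshold", "nms_iou_threshold", "num_classes",
"use_regular_nms", "_output_quantized",
"_support_output_type_float_in_quantized_op",
]
options = get_custom_options(op, option_names)
# e.g. options["nms_iou_threshold"] -> 0.6, options["max_detections"] -> 10,
# depending on how the model was exported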


def from_tflite(model, shape_dict, dtype_dict):
"""Convert from tflite model into compatible relay Function.

tests/python/frontend/tflite/test_forward.py: 48 additions, 0 deletions
@@ -1353,6 +1353,51 @@ def test_forward_fully_connected():
_test_fully_connected([5, 1, 1, 150], [150, 100], [100])


#######################################################################
# Custom Operators
# ----------------

def test_detection_postprocess():
tf_model_file = tf_testing.get_workload_official(
"http://download.tensorflow.org/models/object_detection/"
"ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
Review thread:

Member: Would you mind adding more models like ssd_mobilenetv1?

mbaret (author): I can do, but where would you like me to pull it from? I see that ssd mobilenet v1 without the post process op is hosted under "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/", would it be possible to host the version with the post process op here as well?

Member: If possible, we'd like to pull the model from the related official website, for example https://www.tensorflow.org/lite/models/object_detection/overview for ssd mobilenet v1.

mbaret (author): OK - I did see that model, but weirdly it was a .zip, not a tar as with most other hosted models. I'll see if I can open another PR to extend get_workload_official to zips and then will add the test.

mbaret (author): The test looks non-trivial to add because quite a small difference in the convolutional part of the network can result in significant changes to the ordering of the output tensor (e.g. we might see a different detection at the cut-off threshold). I'm not sure what the best way is to proceed, do you have any thoughts?

Member: Alright, we could remove the ssd mobilenet model because of this limitation, but we should still keep the unit testing of detection postprocess. After we resolve the limitation, we could add the ssd mobilenet testing back. Moreover, we could then remove the atol=1 of test_qconv2d and so on, because we would get exactly the same result as tflite. Does that make sense to you?

mbaret (author): This test is a bit misleading because it doesn't actually run ssd mobilenet, it just tests the postprocess op. I couldn't find a way to create the op using the tflite python API, so what I did instead was take a model that has it and run it through the tflite converter with the converter inputs set to the inputs of the postprocess op rather than the input to the network. This has the net effect of producing a single postprocess op (exactly what the test below does), so this should already be a unit test (and it passes). I can add the end-to-end tests if/when we resolve the QNN accuracy issue. I'll open an RFC shortly to describe why rounding is particularly significant in the case of this operator.

@FrozenGene (Jan 27, 2020): I think if we could view the TOCO source code, maybe we could find how to construct detection_postprocess. Please refer to our _test_prelu comment; I once wrote up what pattern tflite produces for prelu. However, the current way is acceptable too in my opinion.

mbaret (author): I've written a discuss post here: 5528.

@sjoshi30 (Feb 20, 2020): @mbaret How did you set the converter input to the inputs of the postprocess op? When I do that it gives me this error:
tensorflow/lite/toco/model_cmdline_flags.cc:263] Check failed: mean_values.size() == model_flags->input_arrays_size()
The postprocess op has more than one input ('raw_outputs/box_encodings', 'raw_outputs/class_predictions'), plus the anchors constant.

"ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb"
)
converter = tf.lite.TFLiteConverter.from_frozen_graph(
tf_model_file,
input_arrays=["raw_outputs/box_encodings", "raw_outputs/class_predictions"],
output_arrays=[
"TFLite_Detection_PostProcess",
"TFLite_Detection_PostProcess:1",
"TFLite_Detection_PostProcess:2",
"TFLite_Detection_PostProcess:3"
],
input_shapes={
"raw_outputs/box_encodings": (1, 1917, 4),
"raw_outputs/class_predictions": (1, 1917, 91),
},
)
converter.allow_custom_ops = True
converter.inference_type = tf.lite.constants.FLOAT
tflite_model = converter.convert()
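# the converted model now contains only the single custom postprocess op
# (see the review thread above)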
np.random.seed(0)
box_encodings = np.random.uniform(size=(1, 1917, 4)).astype('float32')
class_predictions = np.random.uniform(size=(1, 1917, 91)).astype('float32')
tflite_output = run_tflite_graph(tflite_model, [box_encodings, class_predictions])
tvm_output = run_tvm_graph(tflite_model, [box_encodings, class_predictions],
["raw_outputs/box_encodings", "raw_outputs/class_predictions"], num_output=4)
# check valid count is the same
assert tvm_output[3] == tflite_output[3]
valid_count = tvm_output[3][0]
tvm_boxes = tvm_output[0][0][:valid_count]
tvm_classes = tvm_output[1][0][:valid_count]
tvm_scores = tvm_output[2][0][:valid_count]
# check the output data is correct
tvm.testing.assert_allclose(np.squeeze(tvm_boxes), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(np.squeeze(tvm_classes), np.squeeze(tflite_output[1]), rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(np.squeeze(tvm_scores), np.squeeze(tflite_output[2]), rtol=1e-5, atol=1e-5)


#######################################################################
# Mobilenet
# ---------
@@ -1573,6 +1618,9 @@ def test_forward_mediapipe_hand_landmark():
# Logical
test_all_logical()

# Detection_PostProcess
test_detection_postprocess()

# End to End
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()