Skip to content

Commit

Permalink
[FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess (#4543)
Browse files Browse the repository at this point in the history
* [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess

This adds support for the custom operator
TFLite_Detection_PostProcess which is commonly used in
object detection networks such as SSD Mobilenet. It
only adds support for the case where use_regular_nms = False.

Change-Id: I819b253c0eb6f0fa55da65d2634e09359b888828

* Added a test for the tflite custom op

Change-Id: Ie5baa092deae9a8bcffd2ebd9f6d346b90e58afd

* Removed trailing comma

Change-Id: Ib08f02b5f1a59a883048bfb36e4321152cd2e7f2

* Added spaces between divide

Change-Id: If1171fc03d211a809cedeb800804394972af4060

* Formatted comment

Change-Id: I3ce7e69b8d2c73aec57369c1c64ea1eec07f087b

* Reduced line length in test

Change-Id: I49eaafc3369070f8f3e85fbb965ad20972096c68

* Set random seed for test

Change-Id: I542a787d11422ea83c52147b2cb1144fcef0dd77

* Fixes to style

Change-Id: I2971b8ecebe08c882b2481a99f67cfbe515e0b1f

* Assert for incorrect number of inputs

Change-Id: I393f3b3b62be73e427498d98456fb1d5a214e0af

* Change comparison to pass linting

The linter was updated, so I needed to fix
a small style issue.

Change-Id: Ia3c954565a00de92e7fb1912eae9ed9875d60c7c
  • Loading branch information
mbaret authored Feb 13, 2020
1 parent 51a265a commit 70c6382
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 0 deletions.
197 changes: 197 additions & 0 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(self, model, subgraph, exp_tab):
'SQUARED_DIFFERENCE': self.convert_squared_difference,
'LOGICAL_AND': self.convert_logical_and,
'LOGICAL_OR': self.convert_logical_or,
'DETECTION_POSTPROCESS': self.convert_detection_postprocess
}

def check_unsupported_ops(self):
Expand Down Expand Up @@ -168,6 +169,10 @@ def get_op_code_str(self, op):
op_code_str = self.builtin_op_code[op_code_id]
if op_code_id == BuiltinOperator.CUSTOM:
# Custom operator
custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode()
if custom_op_code_str == b'TFLite_Detection_PostProcess':
return "DETECTION_POSTPROCESS"

raise NotImplementedError("Custom operators are currently not supported")
return op_code_str

Expand Down Expand Up @@ -1814,6 +1819,113 @@ def convert_transpose_conv(self, op):

return out

def convert_detection_postprocess(self, op):
    """Convert the custom operator TFLite_Detection_PostProcess.

    Lowers the op to a pipeline of relay vision operators
    (multibox_transform_loc -> non_max_suppression -> get_valid_counts)
    and then reorders the result so the four outputs
    (boxes, class ids, scores, valid count) match TFLite's layout.
    Only use_regular_nms=False is supported.
    """
    # Names of the op's flexbuffer-encoded custom options.
    # get_custom_options matches decoded values to these names in
    # sorted order, so the order of this list is irrelevant.
    _option_names = [
        "w_scale",
        "max_detections",
        "_output_quantized",
        "detections_per_class",
        "x_scale",
        "nms_score_threshold",
        "num_classes",
        "max_classes_per_detection",
        "use_regular_nms",
        "y_scale",
        "h_scale",
        "_support_output_type_float_in_quantized_op",
        "nms_iou_threshold"
    ]

    custom_options = get_custom_options(op, _option_names)
    # Only the "fast" NMS variant is implemented.
    if custom_options["use_regular_nms"]:
        raise tvm.error.OpAttributeUnImplemented(
            "use_regular_nms=True is not yet supported for operator {}."
            .format("TFLite_Detection_PostProcess")
        )

    # Inputs: 0 = box location encodings, 1 = class predictions,
    # 2 = anchors (a constant tensor).
    inputs = self.get_input_tensors(op)
    assert len(inputs) == 3, "inputs length should be 3"
    cls_pred = self.get_expr(inputs[1].tensor_idx)
    loc_prob = self.get_expr(inputs[0].tensor_idx)
    # Anchors are materialized as a constant so their count is known
    # at conversion time (used for the reshape and NMS top_k below).
    anchor_values = self.get_tensor_value(inputs[2])
    anchor_boxes = len(anchor_values)
    anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type())
    anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type)

    # The downstream vision ops work on float data, so dequantize any
    # quantized inputs first.
    if inputs[0].qnn_params:
        loc_prob = _qnn.op.dequantize(data=loc_prob,
                                      input_scale=inputs[0].qnn_params['scale'],
                                      input_zero_point=inputs[0].qnn_params['zero_point'])
    if inputs[1].qnn_params:
        cls_pred = _qnn.op.dequantize(data=cls_pred,
                                      input_scale=inputs[1].qnn_params['scale'],
                                      input_zero_point=inputs[1].qnn_params['zero_point'])
    if inputs[2].qnn_params:
        anchor_expr = _qnn.op.dequantize(data=anchor_expr,
                                         input_scale=inputs[2].qnn_params['scale'],
                                         input_zero_point=inputs[2].qnn_params['zero_point'])

    # reshape the cls_pred and loc_prob tensors so
    # they can be consumed by multibox_transform_loc
    cls_pred = _op.transpose(cls_pred, [0, 2, 1])
    # loc_prob coords are in yxhw format
    # need to convert to xywh (swap the first two of the four slices)
    loc_coords = _op.split(loc_prob, 4, axis=2)
    loc_prob = _op.concatenate(
        [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2
    )
    # flatten to the (1, num_anchors * 4) layout multibox_transform_loc expects
    loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4])

    # anchor coords are in yxhw format
    # need to convert to ltrb: corner = centre -/+ half the extent
    anchor_coords = _op.split(anchor_expr, 4, axis=1)
    anchor_y = anchor_coords[0]
    anchor_x = anchor_coords[1]
    anchor_h = anchor_coords[2]
    anchor_w = anchor_coords[3]
    plus_half = _expr.const(0.5, dtype='float32')
    minus_half = _expr.const(-0.5, dtype='float32')
    anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half))
    anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half))
    anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half))
    anchor_b = _op.add(anchor_y, _op.multiply(anchor_h, plus_half))
    anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1)
    anchor_expr = _op.expand_dims(anchor_expr, 0)

    # attributes for multibox_transform_loc
    multibox_transform_loc_attrs = {}
    multibox_transform_loc_attrs["clip"] = False
    multibox_transform_loc_attrs["threshold"] = custom_options["nms_score_threshold"]
    # relay variances are the reciprocal of TFLite's scale options
    multibox_transform_loc_attrs["variances"] = (
        1 / custom_options["x_scale"],
        1 / custom_options["y_scale"],
        1 / custom_options["w_scale"],
        1 / custom_options["h_scale"],
    )

    # attributes for non_max_suppression
    non_max_suppression_attrs = {}
    non_max_suppression_attrs["return_indices"] = False
    non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"]
    non_max_suppression_attrs["force_suppress"] = False
    # top_k = all anchors, i.e. no pre-NMS candidate truncation
    non_max_suppression_attrs["top_k"] = anchor_boxes
    non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"]
    non_max_suppression_attrs["invalid_to_bottom"] = False

    ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob,
                                            anchor_expr, **multibox_transform_loc_attrs)
    ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs)
    ret = _op.vision.get_valid_counts(ret, 0)
    valid_count = ret[0]
    # the output needs some reshaping to match tflite
    # split the 6-wide rows into [id, score, coord0..coord3]
    ret = _op.split(ret[1], 6, axis=2)
    cls_ids = ret[0]
    scores = ret[1]
    # swap coordinate pairs back to y-first ordering for TFLite
    boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2)
    ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4)
    return ret

def get_expr(self, input_tensor_idx):
    """Return the relay expression registered for the given tensor index."""
    tensor_name = get_tensor_name(self.subgraph, input_tensor_idx)
    return self.exp_tab.get_expr(tensor_name)

Expand Down Expand Up @@ -1885,6 +1997,91 @@ def get_tensor_name(subgraph, tensor_idx):
return subgraph.Tensors(tensor_idx).Name().decode("utf-8")


def get_custom_options(op, option_names):
    """Get the options of a custom operator.

    This implements partial flexbuffer deserialization to be able
    to read custom options. It is not intended to be a general
    purpose flexbuffer deserializer and as such only supports a
    limited number of types and assumes the data is a flat map
    whose values are all stored in 32 bit width.

    Parameters
    ----------
    op:
        A custom TFlite operator.
    option_names: list
        A complete list of the custom option names.

    Returns
    -------
    options: dict
        A dictionary of the custom options.

    Raises
    ------
    NotImplementedError
        If an option uses a flexbuffer type other than bool,
        int, uint or float.
    """
    import struct
    from enum import IntEnum

    class _FlexBufferType(IntEnum):
        """Flexbuffer type schema from flexbuffers.h"""
        FBT_NULL = 0
        FBT_INT = 1
        FBT_UINT = 2
        FBT_FLOAT = 3
        # Types above stored inline, types below store an offset.
        FBT_KEY = 4
        FBT_STRING = 5
        FBT_INDIRECT_INT = 6
        FBT_INDIRECT_UINT = 7
        FBT_INDIRECT_FLOAT = 8
        FBT_MAP = 9
        FBT_VECTOR = 10  # Untyped.
        FBT_VECTOR_INT = 11  # Typed any size (stores no type table).
        FBT_VECTOR_UINT = 12
        FBT_VECTOR_FLOAT = 13
        FBT_VECTOR_KEY = 14
        FBT_VECTOR_STRING = 15
        FBT_VECTOR_INT2 = 16  # Typed tuple (no type table, no size field).
        FBT_VECTOR_UINT2 = 17
        FBT_VECTOR_FLOAT2 = 18
        FBT_VECTOR_INT3 = 19  # Typed triple (no type table, no size field).
        FBT_VECTOR_UINT3 = 20
        FBT_VECTOR_FLOAT3 = 21
        FBT_VECTOR_INT4 = 22  # Typed quad (no type table, no size field).
        FBT_VECTOR_UINT4 = 23
        FBT_VECTOR_FLOAT4 = 24
        FBT_BLOB = 25
        FBT_BOOL = 26
        FBT_VECTOR_BOOL = 36  # To allow the same type of conversion of type to vector type

    buffer = op.CustomOptionsAsNumpy().tobytes()
    # The third-to-last byte holds the offset (back from the end of the
    # trimmed buffer) to the start of the value vector.
    value_vector_offset = buffer[-3]
    buffer = buffer[:-3]
    num_bytes = 4  # Assume all values are stored in 32 bit width
    # The int32 immediately preceding the value vector is its length,
    # which also equals the length of the trailing type table.
    value_vector_size = struct.unpack(
        "<i", buffer[-value_vector_offset - num_bytes:-value_vector_offset]
    )[0]
    type_offset = value_vector_size
    types = buffer[-type_offset:]
    values = []
    for i, t in enumerate(types):
        # Low 2 bits encode the value width; the rest is the type id.
        flex_type = _FlexBufferType(t >> 2)
        value_offset = -value_vector_offset + i * num_bytes
        value_bytes = buffer[value_offset:value_offset + num_bytes]
        # Use a single elif chain so an unsupported type cannot fall
        # through and silently reuse the previous iteration's value
        # (or raise NameError on the first iteration).
        if flex_type == _FlexBufferType.FBT_BOOL:
            value = bool(value_bytes[0])
        elif flex_type == _FlexBufferType.FBT_INT:
            value = struct.unpack("<i", value_bytes)[0]
        elif flex_type == _FlexBufferType.FBT_UINT:
            value = struct.unpack("<I", value_bytes)[0]
        elif flex_type == _FlexBufferType.FBT_FLOAT:
            value = struct.unpack("<f", value_bytes)[0]
        else:
            raise NotImplementedError(
                "Flexbuffer type {} is not supported for custom options"
                .format(flex_type.name)
            )
        values.append(value)

    # In a flat flexbuffer map the keys are stored sorted, so the
    # decoded values line up with the sorted option names.
    custom_options = dict(zip(sorted(option_names), values))
    return custom_options


def from_tflite(model, shape_dict, dtype_dict):
"""Convert from tflite model into compatible relay Function.
Expand Down
48 changes: 48 additions & 0 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,6 +1384,51 @@ def test_forward_fully_connected():
_test_fully_connected([5, 1, 1, 150], [150, 100], [100])


#######################################################################
# Custom Operators
# ----------------

def test_detection_postprocess():
    """End-to-end check of the custom TFLite_Detection_PostProcess operator."""
    box_input = "raw_outputs/box_encodings"
    class_input = "raw_outputs/class_predictions"
    pb_path = tf_testing.get_workload_official(
        "http://download.tensorflow.org/models/object_detection/"
        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03.tar.gz",
        "ssd_mobilenet_v2_quantized_300x300_coco_2019_01_03/tflite_graph.pb"
    )
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        pb_path,
        input_arrays=[box_input, class_input],
        output_arrays=[
            "TFLite_Detection_PostProcess",
            "TFLite_Detection_PostProcess:1",
            "TFLite_Detection_PostProcess:2",
            "TFLite_Detection_PostProcess:3"
        ],
        input_shapes={
            box_input: (1, 1917, 4),
            class_input: (1, 1917, 91),
        },
    )
    converter.allow_custom_ops = True
    converter.inference_type = tf.lite.constants.FLOAT
    flatbuffer = converter.convert()

    # Fixed seed so both runtimes see identical random inputs.
    np.random.seed(0)
    boxes_in = np.random.uniform(size=(1, 1917, 4)).astype('float32')
    classes_in = np.random.uniform(size=(1, 1917, 91)).astype('float32')
    tflite_out = run_tflite_graph(flatbuffer, [boxes_in, classes_in])
    tvm_out = run_tvm_graph(flatbuffer, [boxes_in, classes_in],
                            [box_input, class_input], num_output=4)
    # check valid count is the same
    assert tvm_out[3] == tflite_out[3]
    num_valid = tvm_out[3][0]
    # Compare boxes, class ids and scores over the valid detections only.
    for out_idx in range(3):
        tvm.testing.assert_allclose(np.squeeze(tvm_out[out_idx][0][:num_valid]),
                                    np.squeeze(tflite_out[out_idx]),
                                    rtol=1e-5, atol=1e-5)


#######################################################################
# Mobilenet
# ---------
Expand Down Expand Up @@ -1611,6 +1656,9 @@ def test_forward_mediapipe_hand_landmark():
# Logical
test_all_logical()

# Detection_PostProcess
test_detection_postprocess()

# End to End
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
Expand Down

0 comments on commit 70c6382

Please sign in to comment.