From 8c50596e4cf401bfc2facb19da0bd957e552eab3 Mon Sep 17 00:00:00 2001 From: mbarrett97 <55580676+mbarrett97@users.noreply.github.com> Date: Thu, 13 Feb 2020 02:07:58 +0000 Subject: [PATCH] [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess (#4543) * [FRONTEND][TFLITE] Add support for TFLite_Detection_PostProcess This adds support for the custom operator TFLite_Detection_PostProcess which is commonly used in object detection networks such as SSD Mobilenet. It only supports the case where use_regular_nms = False. Change-Id: I819b253c0eb6f0fa55da65d2634e09359b888828 * Added a test for the tflite custom op Change-Id: Ie5baa092deae9a8bcffd2ebd9f6d346b90e58afd * Removed trailing comma Change-Id: Ib08f02b5f1a59a883048bfb36e4321152cd2e7f2 * Added spaces between divide Change-Id: If1171fc03d211a809cedeb800804394972af4060 * Formatted comment Change-Id: I3ce7e69b8d2c73aec57369c1c64ea1eec07f087b * Reduced line length in test Change-Id: I49eaafc3369070f8f3e85fbb965ad20972096c68 * Set random seed for test Change-Id: I542a787d11422ea83c52147b2cb1144fcef0dd77 * Fixes to style Change-Id: I2971b8ecebe08c882b2481a99f67cfbe515e0b1f * Assert for incorrect number of inputs Change-Id: I393f3b3b62be73e427498d98456fb1d5a214e0af * Change comparison to pass linting The linter was updated, so I needed to fix a small style issue as a result. 
Change-Id: Ia3c954565a00de92e7fb1912eae9ed9875d60c7c --- python/tvm/relay/frontend/tflite.py | 197 +++++++++++++++++++ tests/python/frontend/tflite/test_forward.py | 48 +++++ 2 files changed, 245 insertions(+) diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index a0b0c0fce5268..d889631a4cd8c 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -121,6 +121,7 @@ def __init__(self, model, subgraph, exp_tab): 'SQUARED_DIFFERENCE': self.convert_squared_difference, 'LOGICAL_AND': self.convert_logical_and, 'LOGICAL_OR': self.convert_logical_or, + 'DETECTION_POSTPROCESS': self.convert_detection_postprocess } def check_unsupported_ops(self): @@ -168,6 +169,10 @@ def get_op_code_str(self, op): op_code_str = self.builtin_op_code[op_code_id] if op_code_id == BuiltinOperator.CUSTOM: # Custom operator + custom_op_code_str = self.model.OperatorCodes(op_code_list_idx).CustomCode() + if custom_op_code_str == b'TFLite_Detection_PostProcess': + return "DETECTION_POSTPROCESS" + raise NotImplementedError("Custom operators are currently not supported") return op_code_str @@ -1814,6 +1819,113 @@ def convert_transpose_conv(self, op): return out + def convert_detection_postprocess(self, op): + """Convert TFLite_Detection_PostProcess""" + _option_names = [ + "w_scale", + "max_detections", + "_output_quantized", + "detections_per_class", + "x_scale", + "nms_score_threshold", + "num_classes", + "max_classes_per_detection", + "use_regular_nms", + "y_scale", + "h_scale", + "_support_output_type_float_in_quantized_op", + "nms_iou_threshold" + ] + + custom_options = get_custom_options(op, _option_names) + if custom_options["use_regular_nms"]: + raise tvm.error.OpAttributeUnImplemented( + "use_regular_nms=True is not yet supported for operator {}." 
+ .format("TFLite_Detection_PostProcess") + ) + + inputs = self.get_input_tensors(op) + assert len(inputs) == 3, "inputs length should be 3" + cls_pred = self.get_expr(inputs[1].tensor_idx) + loc_prob = self.get_expr(inputs[0].tensor_idx) + anchor_values = self.get_tensor_value(inputs[2]) + anchor_boxes = len(anchor_values) + anchor_type = self.get_tensor_type_str(inputs[2].tensor.Type()) + anchor_expr = self.exp_tab.new_const(anchor_values, dtype=anchor_type) + + if inputs[0].qnn_params: + loc_prob = _qnn.op.dequantize(data=loc_prob, + input_scale=inputs[0].qnn_params['scale'], + input_zero_point=inputs[0].qnn_params['zero_point']) + if inputs[1].qnn_params: + cls_pred = _qnn.op.dequantize(data=cls_pred, + input_scale=inputs[1].qnn_params['scale'], + input_zero_point=inputs[1].qnn_params['zero_point']) + if inputs[2].qnn_params: + anchor_expr = _qnn.op.dequantize(data=anchor_expr, + input_scale=inputs[2].qnn_params['scale'], + input_zero_point=inputs[2].qnn_params['zero_point']) + + # reshape the cls_pred and loc_prob tensors so + # they can be consumed by multibox_transform_loc + cls_pred = _op.transpose(cls_pred, [0, 2, 1]) + # loc_prob coords are in yxhw format + # need to convert to xywh + loc_coords = _op.split(loc_prob, 4, axis=2) + loc_prob = _op.concatenate( + [loc_coords[1], loc_coords[0], loc_coords[3], loc_coords[2]], axis=2 + ) + loc_prob = _op.reshape(loc_prob, [1, anchor_boxes*4]) + + # anchor coords are in yxhw format + # need to convert to ltrb + anchor_coords = _op.split(anchor_expr, 4, axis=1) + anchor_y = anchor_coords[0] + anchor_x = anchor_coords[1] + anchor_h = anchor_coords[2] + anchor_w = anchor_coords[3] + plus_half = _expr.const(0.5, dtype='float32') + minus_half = _expr.const(-0.5, dtype='float32') + anchor_l = _op.add(anchor_x, _op.multiply(anchor_w, minus_half)) + anchor_r = _op.add(anchor_x, _op.multiply(anchor_w, plus_half)) + anchor_t = _op.add(anchor_y, _op.multiply(anchor_h, minus_half)) + anchor_b = _op.add(anchor_y, 
_op.multiply(anchor_h, plus_half)) + anchor_expr = _op.concatenate([anchor_l, anchor_t, anchor_r, anchor_b], axis=1) + anchor_expr = _op.expand_dims(anchor_expr, 0) + + # attributes for multibox_transform_loc + multibox_transform_loc_attrs = {} + multibox_transform_loc_attrs["clip"] = False + multibox_transform_loc_attrs["threshold"] = custom_options["nms_score_threshold"] + multibox_transform_loc_attrs["variances"] = ( + 1 / custom_options["x_scale"], + 1 / custom_options["y_scale"], + 1 / custom_options["w_scale"], + 1 / custom_options["h_scale"], + ) + + # attributes for non_max_suppression + non_max_suppression_attrs = {} + non_max_suppression_attrs["return_indices"] = False + non_max_suppression_attrs["iou_threshold"] = custom_options["nms_iou_threshold"] + non_max_suppression_attrs["force_suppress"] = False + non_max_suppression_attrs["top_k"] = anchor_boxes + non_max_suppression_attrs["max_output_size"] = custom_options["max_detections"] + non_max_suppression_attrs["invalid_to_bottom"] = False + + ret = _op.vision.multibox_transform_loc(cls_pred, loc_prob, + anchor_expr, **multibox_transform_loc_attrs) + ret = _op.vision.non_max_suppression(ret[0], ret[1], **non_max_suppression_attrs) + ret = _op.vision.get_valid_counts(ret, 0) + valid_count = ret[0] + # the output needs some reshaping to match tflite + ret = _op.split(ret[1], 6, axis=2) + cls_ids = ret[0] + scores = ret[1] + boxes = _op.concatenate([ret[3], ret[2], ret[5], ret[4]], axis=2) + ret = _expr.TupleWrapper(_expr.Tuple([boxes, cls_ids, scores, valid_count]), size=4) + return ret + def get_expr(self, input_tensor_idx): return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx)) @@ -1885,6 +1997,91 @@ def get_tensor_name(subgraph, tensor_idx): return subgraph.Tensors(tensor_idx).Name().decode("utf-8") +def get_custom_options(op, option_names): + """Get the options of a custom operator. + + This implements partial flexbuffer deserialization to be able + to read custom options. 
It is not intended to be a general + purpose flexbuffer deserializer and as such only supports a + limited number of types and assumes the data is a flat map. + + Parameters + ---------- + op: + A custom TFlite operator. + option_names: list + A complete list of the custom option names. + + Returns + ------- + options: dict + A dictionary of the custom options. + + """ + import struct + from enum import IntEnum + + class _FlexBufferType(IntEnum): + """Flexbuffer type schema from flexbuffers.h""" + FBT_NULL = 0 + FBT_INT = 1 + FBT_UINT = 2 + FBT_FLOAT = 3 + # Types above stored inline, types below store an offset. + FBT_KEY = 4 + FBT_STRING = 5 + FBT_INDIRECT_INT = 6 + FBT_INDIRECT_UINT = 7 + FBT_INDIRECT_FLOAT = 8 + FBT_MAP = 9 + FBT_VECTOR = 10 # Untyped. + FBT_VECTOR_INT = 11 # Typed any size (stores no type table). + FBT_VECTOR_UINT = 12 + FBT_VECTOR_FLOAT = 13 + FBT_VECTOR_KEY = 14 + FBT_VECTOR_STRING = 15 + FBT_VECTOR_INT2 = 16 # Typed tuple (no type table, no size field). + FBT_VECTOR_UINT2 = 17 + FBT_VECTOR_FLOAT2 = 18 + FBT_VECTOR_INT3 = 19 # Typed triple (no type table, no size field). + FBT_VECTOR_UINT3 = 20 + FBT_VECTOR_FLOAT3 = 21 + FBT_VECTOR_INT4 = 22 # Typed quad (no type table, no size field). + FBT_VECTOR_UINT4 = 23 + FBT_VECTOR_FLOAT4 = 24 + FBT_BLOB = 25 + FBT_BOOL = 26 + FBT_VECTOR_BOOL = 36 # To Allow the same type of conversion of type to vector type + + buffer = op.CustomOptionsAsNumpy().tobytes() + value_vector_offset = buffer[-3] + buffer = buffer[:-3] + num_bytes = 4 # Assume all values are stored in 32 bit width + value_vector_size = struct.unpack( + "> 2) + value_offset = -value_vector_offset + i*num_bytes + value_bytes = buffer[value_offset:value_offset+num_bytes] + if flex_type == _FlexBufferType.FBT_BOOL: + value = bool(value_bytes[0]) + if flex_type == _FlexBufferType.FBT_INT: + value = struct.unpack("