-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #236 from mlcommons/patch-latest-od-nms-1.1
Patch the latest NMS implementation from TF master
- Loading branch information
Showing
3 changed files
with
388 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright 2020 The MLPerf Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
########################################################################## | ||
|
||
licenses(["notice"]) # Apache 2.0 | ||
|
||
package(default_visibility = ["//visibility:public"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,365 @@ | ||
diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc | ||
index be28b99f33a..5f1af9427e9 100644 | ||
--- a/tensorflow/lite/kernels/detection_postprocess.cc | ||
+++ b/tensorflow/lite/kernels/detection_postprocess.cc | ||
@@ -98,7 +98,6 @@ struct OpData { | ||
// Indices of Temporary tensors | ||
int decoded_boxes_index; | ||
int scores_index; | ||
- int active_candidate_index; | ||
}; | ||
|
||
void* Init(TfLiteContext* context, const char* buffer, size_t length) { | ||
@@ -126,7 +125,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { | ||
op_data->scale_values.w = m["w_scale"].AsFloat(); | ||
context->AddTensors(context, 1, &op_data->decoded_boxes_index); | ||
context->AddTensors(context, 1, &op_data->scores_index); | ||
- context->AddTensors(context, 1, &op_data->active_candidate_index); | ||
return op_data; | ||
} | ||
|
||
@@ -205,10 +203,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { | ||
|
||
// Temporary tensors | ||
TfLiteIntArrayFree(node->temporaries); | ||
- node->temporaries = TfLiteIntArrayCreate(3); | ||
+ node->temporaries = TfLiteIntArrayCreate(2); | ||
node->temporaries->data[0] = op_data->decoded_boxes_index; | ||
node->temporaries->data[1] = op_data->scores_index; | ||
- node->temporaries->data[2] = op_data->active_candidate_index; | ||
|
||
// decoded_boxes | ||
TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index]; | ||
@@ -225,14 +222,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { | ||
{input_class_predictions->dims->data[1], | ||
input_class_predictions->dims->data[2]}); | ||
|
||
- // active_candidate | ||
- TfLiteTensor* active_candidate = | ||
- &context->tensors[op_data->active_candidate_index]; | ||
- active_candidate->type = kTfLiteUInt8; | ||
- active_candidate->allocation_type = kTfLiteArenaRw; | ||
- SetTensorSizes(context, active_candidate, | ||
- {input_box_encodings->dims->data[1]}); | ||
- | ||
return kTfLiteOk; | ||
} | ||
|
||
@@ -434,8 +423,8 @@ float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes, | ||
// Complexity is O(N^2) pairwise comparison between boxes | ||
TfLiteStatus NonMaxSuppressionSingleClassHelper( | ||
TfLiteContext* context, TfLiteNode* node, OpData* op_data, | ||
- const std::vector<float>& scores, std::vector<int>* selected, | ||
- int max_detections) { | ||
+ const std::vector<float>& scores, int max_detections, | ||
+ std::vector<int>* selected) { | ||
const TfLiteTensor* input_box_encodings; | ||
TF_LITE_ENSURE_OK(context, | ||
GetInputSafe(context, node, kInputTensorBoxEncodings, | ||
@@ -473,14 +462,8 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( | ||
const int num_boxes_kept = num_scores_kept; | ||
const int output_size = std::min(num_boxes_kept, max_detections); | ||
selected->clear(); | ||
- TfLiteTensor* active_candidate = | ||
- &context->tensors[op_data->active_candidate_index]; | ||
- TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes); | ||
int num_active_candidate = num_boxes_kept; | ||
- uint8_t* active_box_candidate = (active_candidate->data.uint8); | ||
- for (int row = 0; row < num_boxes_kept; row++) { | ||
- active_box_candidate[row] = 1; | ||
- } | ||
+ std::vector<uint8_t> active_box_candidate(num_boxes_kept, 1); | ||
|
||
for (int i = 0; i < num_boxes_kept; ++i) { | ||
if (num_active_candidate == 0 || selected->size() >= output_size) break; | ||
@@ -508,6 +491,109 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper( | ||
return kTfLiteOk; | ||
} | ||
|
||
+// Pairs a candidate detection's score with the flattened index it was taken | ||
+// from, so the two stay together while results are merged by score. | ||
+struct BoxInfo { | ||
+ // Flattened position, computed by the producer as | ||
+ // box * num_classes_with_background + class column + label_offset. | ||
+ int index; | ||
+ // Class score of the candidate identified by `index`. | ||
+ float score; | ||
+}; | ||
+ | ||
+// Bundles the inputs shared by every per-class NMS invocation so they can be | ||
+// passed to ComputeNMSResult() and to worker tasks as a single argument. | ||
+// Field order matters: the struct is filled by positional aggregate | ||
+// initialization. | ||
+struct NMSTaskParam { | ||
+ // Caller retains ownership of `context`, `node`, `op_data`, `scores` and | ||
+ // `num_selected`, and must keep them alive longer than this instance. | ||
+ TfLiteContext* context; | ||
+ TfLiteNode* node; | ||
+ OpData* op_data; | ||
+ // Score matrix read row by row: one row per box, with consecutive rows | ||
+ // `num_classes_with_background` floats apart. | ||
+ const float* scores; | ||
+ | ||
+ // Exclusive upper bound of the class columns that worker tasks iterate over. | ||
+ int num_classes; | ||
+ // Number of score rows read for each class column. | ||
+ int num_boxes; | ||
+ // Added to a class column to locate that class within a score row. | ||
+ int label_offset; | ||
+ // Stride, in floats, between consecutive rows of `scores`. | ||
+ int num_classes_with_background; | ||
+ // Detection limit handed to the single-class NMS helper. | ||
+ int num_detections_per_class; | ||
+ // Cap on the number of merged results kept across all classes. | ||
+ int max_detections; | ||
+ // NOTE(review): not read or written by any code visible in this patch; | ||
+ // confirm it is still needed. | ||
+ std::vector<int>& num_selected; | ||
+}; | ||
+ | ||
+// Merges the two adjacent runs boxes[0, mid_index) and | ||
+// boxes[mid_index, end_index), each already sorted by descending score, into | ||
+// a single descending run, in place. | ||
+// | ||
+// The comparator has to be a strict weak ordering, hence `>` and not `>=`: | ||
+// with `>=`, comp(a, a) would be true, which violates the requirements of | ||
+// std::inplace_merge and is undefined behavior. With `>` the merge is stable, | ||
+// so on equal scores the boxes of the first run stay ahead of the second. | ||
+void InplaceMergeBoxInfo(std::vector<BoxInfo>& boxes, int mid_index, | ||
+                         int end_index) { | ||
+  std::inplace_merge( | ||
+      boxes.begin(), boxes.begin() + mid_index, boxes.begin() + end_index, | ||
+      [](const BoxInfo& a, const BoxInfo& b) { return a.score > b.score; }); | ||
+} | ||
+ | ||
+// Runs single-class NMS for every class column in [col_begin, col_end] and | ||
+// folds the survivors into `resulted_sorted_box_info`, which is kept sorted | ||
+// by descending score. | ||
+// | ||
+// `sorted_indices_size` is in/out: on entry it is the number of valid entries | ||
+// at the front of `resulted_sorted_box_info`; on return it is the new count, | ||
+// never more than `max_detections`. The vector must already be large enough | ||
+// to hold `sorted_indices_size` entries plus one class worth of selections. | ||
+// Returns the helper's status as soon as a class fails, kTfLiteOk otherwise. | ||
+TfLiteStatus ComputeNMSResult(const NMSTaskParam& nms_task_param, int col_begin, | ||
+                              int col_end, int& sorted_indices_size, | ||
+                              std::vector<BoxInfo>& resulted_sorted_box_info) { | ||
+  const int box_count = nms_task_param.num_boxes; | ||
+  const int row_stride = nms_task_param.num_classes_with_background; | ||
+  std::vector<float> column_scores(box_count); | ||
+  std::vector<int> kept; | ||
+  kept.reserve(nms_task_param.num_detections_per_class); | ||
+ | ||
+  for (int col = col_begin; col <= col_end; ++col) { | ||
+    // Gather this class's score for every box into a contiguous buffer. | ||
+    const int score_column = col + nms_task_param.label_offset; | ||
+    for (int row = 0; row < box_count; ++row) { | ||
+      column_scores[row] = | ||
+          nms_task_param.scores[row * row_stride + score_column]; | ||
+    } | ||
+ | ||
+    // Non-maximal suppression restricted to the current class. | ||
+    kept.clear(); | ||
+    const TfLiteStatus nms_status = NonMaxSuppressionSingleClassHelper( | ||
+        nms_task_param.context, nms_task_param.node, nms_task_param.op_data, | ||
+        column_scores, nms_task_param.num_detections_per_class, &kept); | ||
+    TF_LITE_ENSURE_OK(nms_task_param.context, nms_status); | ||
+ | ||
+    const int kept_count = static_cast<int>(kept.size()); | ||
+    if (kept_count == 0) { | ||
+      continue; | ||
+    } | ||
+ | ||
+    // Append the new selections right behind the current sorted prefix. | ||
+    for (int i = 0; i < kept_count; ++i) { | ||
+      BoxInfo& slot = resulted_sorted_box_info[sorted_indices_size + i]; | ||
+      slot.score = column_scores[kept[i]]; | ||
+      slot.index = kept[i] * row_stride + score_column; | ||
+    } | ||
+ | ||
+    // Prefix and appended run are both sorted by score: merge them in place, | ||
+    // then keep at most `max_detections` entries. | ||
+    const int merged_end = sorted_indices_size + kept_count; | ||
+    InplaceMergeBoxInfo(resulted_sorted_box_info, sorted_indices_size, | ||
+                        merged_end); | ||
+    sorted_indices_size = | ||
+        std::min(merged_end, nms_task_param.max_detections); | ||
+  } | ||
+  return kTfLiteOk; | ||
+} | ||
+ | ||
+// Thread-pool task that runs per-class NMS on a dynamically assigned set of | ||
+// class columns and accumulates its best boxes, sorted by descending score, | ||
+// in sorted_box_info[0, sorted_indices_size). | ||
+// | ||
+// Work distribution: the task starts on `col_begin`; every further column is | ||
+// claimed from the shared `next_col` counter. The counter is expected to | ||
+// start at the first column that was not handed out as a `col_begin` (i.e. | ||
+// the number of tasks). It is therefore read with fetch_add (post-increment | ||
+// semantics) so that this first column is actually claimed; a pre-increment | ||
+// would skip it and leave one class unprocessed whenever there are more | ||
+// classes than tasks. | ||
+struct NonMaxSuppressionWorkerTask : cpu_backend_threadpool::Task { | ||
+  // `nms_task_param` and `next_col` are held by reference and must outlive | ||
+  // the task. | ||
+  NonMaxSuppressionWorkerTask(NMSTaskParam& nms_task_param, | ||
+                              std::atomic<int>& next_col, int col_begin) | ||
+      : nms_task_param(nms_task_param), | ||
+        next_col(next_col), | ||
+        col_begin(col_begin), | ||
+        sorted_indices_size(0) {} | ||
+  void Run() override { | ||
+    // Room for the retained results plus one class worth of new selections. | ||
+    sorted_box_info.resize(nms_task_param.num_detections_per_class + | ||
+                           nms_task_param.max_detections); | ||
+    for (int col = col_begin; col < nms_task_param.num_classes; | ||
+         col = next_col.fetch_add(1)) { | ||
+      // NOTE(review): a failure only stops this task; the status is not | ||
+      // propagated to whoever scheduled it. | ||
+      if (ComputeNMSResult(nms_task_param, col, col, sorted_indices_size, | ||
+                           sorted_box_info) != kTfLiteOk) { | ||
+        break; | ||
+      } | ||
+    } | ||
+  } | ||
+  NMSTaskParam& nms_task_param; | ||
+  // Atomic counter shared by all tasks: holds the next column that has not | ||
+  // been claimed yet. Each task takes its next column from it after finishing | ||
+  // the work for 'col_begin'. | ||
+  std::atomic<int>& next_col; | ||
+  const int col_begin; | ||
+  // Number of valid entries at the front of `sorted_box_info`. | ||
+  int sorted_indices_size; | ||
+  std::vector<BoxInfo> sorted_box_info; | ||
+}; | ||
+ | ||
// This function implements a regular version of Non Maximal Suppression (NMS) | ||
// for multiple classes where | ||
// 1) we do NMS separately for each class across all anchors and | ||
@@ -549,7 +635,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, | ||
|
||
const int num_boxes = input_box_encodings->dims->data[1]; | ||
const int num_classes = op_data->num_classes; | ||
- const int num_detections_per_class = op_data->detections_per_class; | ||
+ const int num_detections_per_class = | ||
+ std::min(op_data->detections_per_class, op_data->max_detections); | ||
const int max_detections = op_data->max_detections; | ||
const int num_classes_with_background = | ||
input_class_predictions->dims->data[2]; | ||
@@ -557,73 +644,70 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, | ||
int label_offset = num_classes_with_background - num_classes; | ||
TF_LITE_ENSURE(context, num_detections_per_class > 0); | ||
|
||
- // For each class, perform non-max suppression. | ||
- std::vector<float> class_scores(num_boxes); | ||
- | ||
- std::vector<int> box_indices_after_regular_non_max_suppression( | ||
- num_boxes + max_detections); | ||
- std::vector<float> scores_after_regular_non_max_suppression(num_boxes + | ||
- max_detections); | ||
- | ||
- int size_of_sorted_indices = 0; | ||
- std::vector<int> sorted_indices; | ||
- sorted_indices.resize(num_boxes + max_detections); | ||
- std::vector<float> sorted_values; | ||
- sorted_values.resize(max_detections); | ||
- | ||
- for (int col = 0; col < num_classes; col++) { | ||
- for (int row = 0; row < num_boxes; row++) { | ||
- // Get scores of boxes corresponding to all anchors for single class | ||
- class_scores[row] = | ||
- *(scores + row * num_classes_with_background + col + label_offset); | ||
- } | ||
- // Perform non-maximal suppression on single class | ||
- std::vector<int> selected; | ||
- TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper( | ||
- context, node, op_data, class_scores, &selected, | ||
- num_detections_per_class)); | ||
- // Add selected indices from non-max suppression of boxes in this class | ||
- int output_index = size_of_sorted_indices; | ||
- for (const auto& selected_index : selected) { | ||
- box_indices_after_regular_non_max_suppression[output_index] = | ||
- (selected_index * num_classes_with_background + col + label_offset); | ||
- scores_after_regular_non_max_suppression[output_index] = | ||
- class_scores[selected_index]; | ||
- output_index++; | ||
- } | ||
- // Sort the max scores among the selected indices | ||
- // Get the indices for top scores | ||
- int num_indices_to_sort = std::min(output_index, max_detections); | ||
- DecreasingPartialArgSort(scores_after_regular_non_max_suppression.data(), | ||
- output_index, num_indices_to_sort, | ||
- sorted_indices.data()); | ||
- | ||
- // Copy values to temporary vectors | ||
- for (int row = 0; row < num_indices_to_sort; row++) { | ||
- int temp = sorted_indices[row]; | ||
- sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp]; | ||
- sorted_values[row] = scores_after_regular_non_max_suppression[temp]; | ||
+ int sorted_indices_size = 0; | ||
+ std::vector<BoxInfo> box_info_after_regular_non_max_suppression( | ||
+ max_detections + num_detections_per_class); | ||
+ std::vector<int> num_selected(num_classes); | ||
+ | ||
+ NMSTaskParam nms_task_param{context, | ||
+ node, | ||
+ op_data, | ||
+ scores, | ||
+ num_classes, | ||
+ num_boxes, | ||
+ label_offset, | ||
+ num_classes_with_background, | ||
+ num_detections_per_class, | ||
+ max_detections, | ||
+ num_selected}; | ||
+ | ||
+ int num_threads = | ||
+ CpuBackendContext::GetFromContext(context)->max_num_threads(); | ||
+ if (num_threads == 1) { | ||
+ // For each class, perform non-max suppression. | ||
+ TF_LITE_ENSURE_OK( | ||
+ context, ComputeNMSResult(nms_task_param, /* col_begin= */ 0, | ||
+ num_classes - 1, sorted_indices_size, | ||
+ box_info_after_regular_non_max_suppression)); | ||
+ } else { | ||
+ std::atomic<int> next_col(num_threads); | ||
+ std::vector<NonMaxSuppressionWorkerTask> tasks; | ||
+ tasks.reserve(num_threads); | ||
+ for (int i = 0; i < num_threads; ++i) { | ||
+ tasks.emplace_back( | ||
+ NonMaxSuppressionWorkerTask(nms_task_param, next_col, i)); | ||
} | ||
- // Copy scores and indices from temporary vectors | ||
- for (int row = 0; row < num_indices_to_sort; row++) { | ||
- box_indices_after_regular_non_max_suppression[row] = sorted_indices[row]; | ||
- scores_after_regular_non_max_suppression[row] = sorted_values[row]; | ||
+ cpu_backend_threadpool::Execute(tasks.size(), tasks.data(), | ||
+ CpuBackendContext::GetFromContext(context)); | ||
+ | ||
+ // Merge results from tasks. | ||
+ for (int j = 0; j < tasks.size(); ++j) { | ||
+ if (tasks[j].sorted_indices_size == 0) { | ||
+ continue; | ||
+ } | ||
+ memcpy(&box_info_after_regular_non_max_suppression[sorted_indices_size], | ||
+ &tasks[j].sorted_box_info[0], | ||
+ sizeof(BoxInfo) * tasks[j].sorted_indices_size); | ||
+ InplaceMergeBoxInfo(box_info_after_regular_non_max_suppression, | ||
+ sorted_indices_size, | ||
+ sorted_indices_size + tasks[j].sorted_indices_size); | ||
+ sorted_indices_size = std::min( | ||
+ sorted_indices_size + tasks[j].sorted_indices_size, max_detections); | ||
} | ||
- size_of_sorted_indices = num_indices_to_sort; | ||
} | ||
|
||
// Allocate output tensors | ||
for (int output_box_index = 0; output_box_index < max_detections; | ||
output_box_index++) { | ||
- if (output_box_index < size_of_sorted_indices) { | ||
+ if (output_box_index < sorted_indices_size) { | ||
const int anchor_index = floor( | ||
- box_indices_after_regular_non_max_suppression[output_box_index] / | ||
+ box_info_after_regular_non_max_suppression[output_box_index].index / | ||
num_classes_with_background); | ||
const int class_index = | ||
- box_indices_after_regular_non_max_suppression[output_box_index] - | ||
+ box_info_after_regular_non_max_suppression[output_box_index].index - | ||
anchor_index * num_classes_with_background - label_offset; | ||
const float selected_score = | ||
- scores_after_regular_non_max_suppression[output_box_index]; | ||
+ box_info_after_regular_non_max_suppression[output_box_index].score; | ||
// detection_boxes | ||
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32); | ||
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32); | ||
@@ -644,9 +728,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context, | ||
GetTensorData<float>(detection_scores)[output_box_index] = 0.0f; | ||
} | ||
} | ||
- GetTensorData<float>(num_detections)[0] = size_of_sorted_indices; | ||
- box_indices_after_regular_non_max_suppression.clear(); | ||
- scores_after_regular_non_max_suppression.clear(); | ||
+ GetTensorData<float>(num_detections)[0] = sorted_indices_size; | ||
+ box_info_after_regular_non_max_suppression.clear(); | ||
return kTfLiteOk; | ||
} | ||
|
||
@@ -702,11 +785,12 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, | ||
std::vector<float> max_scores; | ||
max_scores.resize(num_boxes); | ||
std::vector<int> sorted_class_indices; | ||
- sorted_class_indices.resize(num_boxes * num_classes); | ||
+ sorted_class_indices.resize(num_boxes * num_categories_per_anchor); | ||
for (int row = 0; row < num_boxes; row++) { | ||
const float* box_scores = | ||
scores + row * num_classes_with_background + label_offset; | ||
- int* class_indices = sorted_class_indices.data() + row * num_classes; | ||
+ int* class_indices = | ||
+ sorted_class_indices.data() + row * num_categories_per_anchor; | ||
DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor, | ||
class_indices); | ||
max_scores[row] = box_scores[class_indices[0]]; | ||
@@ -714,14 +798,14 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context, | ||
// Perform non-maximal suppression on max scores | ||
std::vector<int> selected; | ||
TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper( | ||
- context, node, op_data, max_scores, &selected, op_data->max_detections)); | ||
+ context, node, op_data, max_scores, op_data->max_detections, &selected)); | ||
// Allocate output tensors | ||
int output_box_index = 0; | ||
for (const auto& selected_index : selected) { | ||
const float* box_scores = | ||
scores + selected_index * num_classes_with_background + label_offset; | ||
- const int* class_indices = | ||
- sorted_class_indices.data() + selected_index * num_classes; | ||
+ const int* class_indices = sorted_class_indices.data() + | ||
+ selected_index * num_categories_per_anchor; | ||
|
||
for (int col = 0; col < num_categories_per_anchor; ++col) { | ||
int box_offset = max_categories_per_anchor * output_box_index + col; |