Skip to content

Commit

Permalink
Merge pull request #236 from mlcommons/patch-latest-od-nms-1.1
Browse files Browse the repository at this point in the history
Patch the latest NMS implementation from TF master
  • Loading branch information
ptt2panda authored Jan 27, 2022
2 parents ed23419 + 1e984ed commit 7fd5b8d
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 0 deletions.
5 changes: 5 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ http_archive(

http_archive(
name = "org_tensorflow",
patch_args = ["-p1"],
patches = [
# An improvement from TensorFlow 2.7 that Pixel 6 needs
"//patches:fast_nms.diff",
],
sha256 = "40d3203ab5f246d83bae328288a24209a2b85794f1b3e2cd0329458d8e7c1985",
strip_prefix = "tensorflow-2.6.0",
urls = [
Expand Down
18 changes: 18 additions & 0 deletions patches/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright 2020 The MLPerf Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##########################################################################

licenses(["notice"]) # Apache 2.0

package(default_visibility = ["//visibility:public"])
365 changes: 365 additions & 0 deletions patches/fast_nms.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,365 @@
diff --git a/tensorflow/lite/kernels/detection_postprocess.cc b/tensorflow/lite/kernels/detection_postprocess.cc
index be28b99f33a..5f1af9427e9 100644
--- a/tensorflow/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/lite/kernels/detection_postprocess.cc
@@ -98,7 +98,6 @@ struct OpData {
// Indices of Temporary tensors
int decoded_boxes_index;
int scores_index;
- int active_candidate_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -126,7 +125,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
op_data->scale_values.w = m["w_scale"].AsFloat();
context->AddTensors(context, 1, &op_data->decoded_boxes_index);
context->AddTensors(context, 1, &op_data->scores_index);
- context->AddTensors(context, 1, &op_data->active_candidate_index);
return op_data;
}

@@ -205,10 +203,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {

// Temporary tensors
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(3);
+ node->temporaries = TfLiteIntArrayCreate(2);
node->temporaries->data[0] = op_data->decoded_boxes_index;
node->temporaries->data[1] = op_data->scores_index;
- node->temporaries->data[2] = op_data->active_candidate_index;

// decoded_boxes
TfLiteTensor* decoded_boxes = &context->tensors[op_data->decoded_boxes_index];
@@ -225,14 +222,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
{input_class_predictions->dims->data[1],
input_class_predictions->dims->data[2]});

- // active_candidate
- TfLiteTensor* active_candidate =
- &context->tensors[op_data->active_candidate_index];
- active_candidate->type = kTfLiteUInt8;
- active_candidate->allocation_type = kTfLiteArenaRw;
- SetTensorSizes(context, active_candidate,
- {input_box_encodings->dims->data[1]});
-
return kTfLiteOk;
}

@@ -434,8 +423,8 @@ float ComputeIntersectionOverUnion(const TfLiteTensor* decoded_boxes,
// Complexity is O(N^2) pairwise comparison between boxes
TfLiteStatus NonMaxSuppressionSingleClassHelper(
TfLiteContext* context, TfLiteNode* node, OpData* op_data,
- const std::vector<float>& scores, std::vector<int>* selected,
- int max_detections) {
+ const std::vector<float>& scores, int max_detections,
+ std::vector<int>* selected) {
const TfLiteTensor* input_box_encodings;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensorBoxEncodings,
@@ -473,14 +462,8 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
const int num_boxes_kept = num_scores_kept;
const int output_size = std::min(num_boxes_kept, max_detections);
selected->clear();
- TfLiteTensor* active_candidate =
- &context->tensors[op_data->active_candidate_index];
- TF_LITE_ENSURE(context, (active_candidate->dims->data[0]) == num_boxes);
int num_active_candidate = num_boxes_kept;
- uint8_t* active_box_candidate = (active_candidate->data.uint8);
- for (int row = 0; row < num_boxes_kept; row++) {
- active_box_candidate[row] = 1;
- }
+ std::vector<uint8_t> active_box_candidate(num_boxes_kept, 1);

for (int i = 0; i < num_boxes_kept; ++i) {
if (num_active_candidate == 0 || selected->size() >= output_size) break;
@@ -508,6 +491,109 @@ TfLiteStatus NonMaxSuppressionSingleClassHelper(
return kTfLiteOk;
}

+struct BoxInfo {  // One detection kept by NMS: where its score lives plus the score itself.
+ int index;  // Flat offset into the scores buffer: box * num_classes_with_background + class column + label_offset.
+ float score;  // Class score copied from that position; merged lists are ordered by this, highest first.
+};
+
+struct NMSTaskParam {  // Bundle of inputs shared by every per-class NMS computation.
+ // Caller retains the ownership of `context`, `node`, `op_data` and `scores`.
+ // Caller should ensure their lifetime is longer than NMSTaskParam instance.
+ TfLiteContext* context;  // Forwarded to NonMaxSuppressionSingleClassHelper and used for status checks.
+ TfLiteNode* node;  // Forwarded to NonMaxSuppressionSingleClassHelper.
+ OpData* op_data;  // Forwarded to NonMaxSuppressionSingleClassHelper.
+ const float* scores;  // Base of the score buffer; one box every num_classes_with_background floats.
+
+ int num_classes;  // Upper bound (exclusive) on the class columns a worker task iterates.
+ int num_boxes;  // Number of per-class scores gathered for each class.
+ int label_offset;  // Added to the class column when addressing `scores`.
+ int num_classes_with_background;  // Stride between consecutive boxes in `scores`.
+ int num_detections_per_class;  // Limit passed to the single-class NMS helper.
+ int max_detections;  // Cap on the number of merged results kept.
+ std::vector<int>& num_selected;  // NOTE(review): not read or written by the code visible in this patch; confirm it is still needed.
+};
+
+void InplaceMergeBoxInfo(std::vector<BoxInfo>& boxes, int mid_index,  // Merges [0, mid_index) and [mid_index, end_index) of `boxes` in place,
+ int end_index) {  // both already sorted by descending score, into one descending range.
+ std::inplace_merge(
+ boxes.begin(), boxes.begin() + mid_index, boxes.begin() + end_index,
+ [](const BoxInfo& a, const BoxInfo& b) { return a.score > b.score; });  // Strict '>': std::inplace_merge needs a strict weak ordering ('>=' is not one); ties keep the earlier range first.
+}
+
+TfLiteStatus ComputeNMSResult(const NMSTaskParam& nms_task_param, int col_begin,  // Runs single-class NMS for class columns col_begin..col_end and merges the picks.
+ int col_end, int& sorted_indices_size,  // sorted_indices_size: in/out count of valid entries at the front of the result buffer.
+ std::vector<BoxInfo>& resulted_sorted_box_info) {  // Must hold at least sorted_indices_size + num_detections_per_class entries.
+ std::vector<float> class_scores(nms_task_param.num_boxes);  // Scratch reused for every class.
+ std::vector<int> selected;
+ selected.reserve(nms_task_param.num_detections_per_class);
+
+ for (int col = col_begin; col <= col_end; ++col) {  // col_end is inclusive.
+ const float* scores_base =
+ nms_task_param.scores + col + nms_task_param.label_offset;
+ for (int row = 0; row < nms_task_param.num_boxes; row++) {
+ // Get scores of boxes corresponding to all anchors for single class
+ class_scores[row] = *scores_base;
+ scores_base += nms_task_param.num_classes_with_background;  // Same class column, next box.
+ }
+
+ // Perform non-maximal suppression on single class
+ selected.clear();
+ TF_LITE_ENSURE_OK(
+ nms_task_param.context,
+ NonMaxSuppressionSingleClassHelper(
+ nms_task_param.context, nms_task_param.node, nms_task_param.op_data,
+ class_scores, nms_task_param.num_detections_per_class, &selected));
+ if (selected.empty()) {
+ continue;  // Nothing kept for this class; the merged prefix is unchanged.
+ }
+
+ for (int i = 0; i < selected.size(); ++i) {  // Append this class's picks right after the merged prefix.
+ resulted_sorted_box_info[sorted_indices_size + i].score =
+ class_scores[selected[i]];
+ resulted_sorted_box_info[sorted_indices_size + i].index =
+ (selected[i] * nms_task_param.num_classes_with_background + col +  // Flat offset of this (box, class) pair in `scores`.
+ nms_task_param.label_offset);
+ }
+
+ // In-place merge the original boxes and new selected boxes which are both
+ // sorted by scores.
+ InplaceMergeBoxInfo(resulted_sorted_box_info, sorted_indices_size,
+ sorted_indices_size + selected.size());
+
+ sorted_indices_size =
+ std::min(sorted_indices_size + static_cast<int>(selected.size()),  // Keep at most max_detections merged entries.
+ nms_task_param.max_detections);
+ }
+ return kTfLiteOk;
+}
+
+struct NonMaxSuppressionWorkerTask : cpu_backend_threadpool::Task {  // Thread-pool task: runs per-class NMS on col_begin, then on columns claimed from next_col.
+ NonMaxSuppressionWorkerTask(NMSTaskParam& nms_task_param,
+ std::atomic<int>& next_col, int col_begin)
+ : nms_task_param(nms_task_param),
+ next_col(next_col),
+ col_begin(col_begin),
+ sorted_indices_size(0) {}
+ void Run() override {
+ sorted_box_info.resize(nms_task_param.num_detections_per_class +  // Room for the merged prefix plus one class's new picks.
+ nms_task_param.max_detections);
+ for (int col = col_begin; col < nms_task_param.num_classes;
+ col = (next_col++)) {  // Post-increment claims the current value of next_col; pre-increment skipped the first unclaimed class column.
+ if (ComputeNMSResult(nms_task_param, col, col, sorted_indices_size,
+ sorted_box_info) != kTfLiteOk) {
+ break;  // NOTE(review): the failure status is dropped here and the caller cannot observe it; confirm this is intended.
+ }
+ }
+ }
+ NMSTaskParam& nms_task_param;  // Shared parameters; not owned by the task.
+ // A shared atomic variable across threads, representing the next col this
+ // task will work on after completing the work for 'col_begin'
+ std::atomic<int>& next_col;
+ const int col_begin;  // First class column this task processes.
+ int sorted_indices_size;  // Number of valid entries at the front of sorted_box_info.
+ std::vector<BoxInfo> sorted_box_info;  // This task's merged results, highest score first.
+};
+
// This function implements a regular version of Non Maximal Suppression (NMS)
// for multiple classes where
// 1) we do NMS separately for each class across all anchors and
@@ -549,7 +635,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,

const int num_boxes = input_box_encodings->dims->data[1];
const int num_classes = op_data->num_classes;
- const int num_detections_per_class = op_data->detections_per_class;
+ const int num_detections_per_class =
+ std::min(op_data->detections_per_class, op_data->max_detections);
const int max_detections = op_data->max_detections;
const int num_classes_with_background =
input_class_predictions->dims->data[2];
@@ -557,73 +644,70 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
int label_offset = num_classes_with_background - num_classes;
TF_LITE_ENSURE(context, num_detections_per_class > 0);

- // For each class, perform non-max suppression.
- std::vector<float> class_scores(num_boxes);
-
- std::vector<int> box_indices_after_regular_non_max_suppression(
- num_boxes + max_detections);
- std::vector<float> scores_after_regular_non_max_suppression(num_boxes +
- max_detections);
-
- int size_of_sorted_indices = 0;
- std::vector<int> sorted_indices;
- sorted_indices.resize(num_boxes + max_detections);
- std::vector<float> sorted_values;
- sorted_values.resize(max_detections);
-
- for (int col = 0; col < num_classes; col++) {
- for (int row = 0; row < num_boxes; row++) {
- // Get scores of boxes corresponding to all anchors for single class
- class_scores[row] =
- *(scores + row * num_classes_with_background + col + label_offset);
- }
- // Perform non-maximal suppression on single class
- std::vector<int> selected;
- TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
- context, node, op_data, class_scores, &selected,
- num_detections_per_class));
- // Add selected indices from non-max suppression of boxes in this class
- int output_index = size_of_sorted_indices;
- for (const auto& selected_index : selected) {
- box_indices_after_regular_non_max_suppression[output_index] =
- (selected_index * num_classes_with_background + col + label_offset);
- scores_after_regular_non_max_suppression[output_index] =
- class_scores[selected_index];
- output_index++;
- }
- // Sort the max scores among the selected indices
- // Get the indices for top scores
- int num_indices_to_sort = std::min(output_index, max_detections);
- DecreasingPartialArgSort(scores_after_regular_non_max_suppression.data(),
- output_index, num_indices_to_sort,
- sorted_indices.data());
-
- // Copy values to temporary vectors
- for (int row = 0; row < num_indices_to_sort; row++) {
- int temp = sorted_indices[row];
- sorted_indices[row] = box_indices_after_regular_non_max_suppression[temp];
- sorted_values[row] = scores_after_regular_non_max_suppression[temp];
+ int sorted_indices_size = 0;
+ std::vector<BoxInfo> box_info_after_regular_non_max_suppression(
+ max_detections + num_detections_per_class);
+ std::vector<int> num_selected(num_classes);
+
+ NMSTaskParam nms_task_param{context,
+ node,
+ op_data,
+ scores,
+ num_classes,
+ num_boxes,
+ label_offset,
+ num_classes_with_background,
+ num_detections_per_class,
+ max_detections,
+ num_selected};
+
+ int num_threads =
+ CpuBackendContext::GetFromContext(context)->max_num_threads();
+ if (num_threads == 1) {
+ // For each class, perform non-max suppression.
+ TF_LITE_ENSURE_OK(
+ context, ComputeNMSResult(nms_task_param, /* col_begin= */ 0,
+ num_classes - 1, sorted_indices_size,
+ box_info_after_regular_non_max_suppression));
+ } else {
+ std::atomic<int> next_col(num_threads);
+ std::vector<NonMaxSuppressionWorkerTask> tasks;
+ tasks.reserve(num_threads);
+ for (int i = 0; i < num_threads; ++i) {
+ tasks.emplace_back(
+ NonMaxSuppressionWorkerTask(nms_task_param, next_col, i));
}
- // Copy scores and indices from temporary vectors
- for (int row = 0; row < num_indices_to_sort; row++) {
- box_indices_after_regular_non_max_suppression[row] = sorted_indices[row];
- scores_after_regular_non_max_suppression[row] = sorted_values[row];
+ cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
+ CpuBackendContext::GetFromContext(context));
+
+ // Merge results from tasks.
+ for (int j = 0; j < tasks.size(); ++j) {
+ if (tasks[j].sorted_indices_size == 0) {
+ continue;
+ }
+ memcpy(&box_info_after_regular_non_max_suppression[sorted_indices_size],
+ &tasks[j].sorted_box_info[0],
+ sizeof(BoxInfo) * tasks[j].sorted_indices_size);
+ InplaceMergeBoxInfo(box_info_after_regular_non_max_suppression,
+ sorted_indices_size,
+ sorted_indices_size + tasks[j].sorted_indices_size);
+ sorted_indices_size = std::min(
+ sorted_indices_size + tasks[j].sorted_indices_size, max_detections);
}
- size_of_sorted_indices = num_indices_to_sort;
}

// Allocate output tensors
for (int output_box_index = 0; output_box_index < max_detections;
output_box_index++) {
- if (output_box_index < size_of_sorted_indices) {
+ if (output_box_index < sorted_indices_size) {
const int anchor_index = floor(
- box_indices_after_regular_non_max_suppression[output_box_index] /
+ box_info_after_regular_non_max_suppression[output_box_index].index /
num_classes_with_background);
const int class_index =
- box_indices_after_regular_non_max_suppression[output_box_index] -
+ box_info_after_regular_non_max_suppression[output_box_index].index -
anchor_index * num_classes_with_background - label_offset;
const float selected_score =
- scores_after_regular_non_max_suppression[output_box_index];
+ box_info_after_regular_non_max_suppression[output_box_index].score;
// detection_boxes
TF_LITE_ENSURE_EQ(context, detection_boxes->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, decoded_boxes->type, kTfLiteFloat32);
@@ -644,9 +728,8 @@ TfLiteStatus NonMaxSuppressionMultiClassRegularHelper(TfLiteContext* context,
GetTensorData<float>(detection_scores)[output_box_index] = 0.0f;
}
}
- GetTensorData<float>(num_detections)[0] = size_of_sorted_indices;
- box_indices_after_regular_non_max_suppression.clear();
- scores_after_regular_non_max_suppression.clear();
+ GetTensorData<float>(num_detections)[0] = sorted_indices_size;
+ box_info_after_regular_non_max_suppression.clear();
return kTfLiteOk;
}

@@ -702,11 +785,12 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
std::vector<float> max_scores;
max_scores.resize(num_boxes);
std::vector<int> sorted_class_indices;
- sorted_class_indices.resize(num_boxes * num_classes);
+ sorted_class_indices.resize(num_boxes * num_categories_per_anchor);
for (int row = 0; row < num_boxes; row++) {
const float* box_scores =
scores + row * num_classes_with_background + label_offset;
- int* class_indices = sorted_class_indices.data() + row * num_classes;
+ int* class_indices =
+ sorted_class_indices.data() + row * num_categories_per_anchor;
DecreasingPartialArgSort(box_scores, num_classes, num_categories_per_anchor,
class_indices);
max_scores[row] = box_scores[class_indices[0]];
@@ -714,14 +798,14 @@ TfLiteStatus NonMaxSuppressionMultiClassFastHelper(TfLiteContext* context,
// Perform non-maximal suppression on max scores
std::vector<int> selected;
TF_LITE_ENSURE_STATUS(NonMaxSuppressionSingleClassHelper(
- context, node, op_data, max_scores, &selected, op_data->max_detections));
+ context, node, op_data, max_scores, op_data->max_detections, &selected));
// Allocate output tensors
int output_box_index = 0;
for (const auto& selected_index : selected) {
const float* box_scores =
scores + selected_index * num_classes_with_background + label_offset;
- const int* class_indices =
- sorted_class_indices.data() + selected_index * num_classes;
+ const int* class_indices = sorted_class_indices.data() +
+ selected_index * num_categories_per_anchor;

for (int col = 0; col < num_categories_per_anchor; ++col) {
int box_offset = max_categories_per_anchor * output_box_index + col;

0 comments on commit 7fd5b8d

Please sign in to comment.