Skip to content

Commit

Permalink
Merge pull request #5 from ceciliapeng2011/cecilia/multiclass_nms/alg…
Browse files Browse the repository at this point in the history
…o_impl

Cecilia/multiclass nms/algo impl
  • Loading branch information
luo-cheng2021 authored Jun 16, 2021
2 parents 8559224 + 7727aa5 commit c740425
Show file tree
Hide file tree
Showing 3 changed files with 1,034 additions and 350 deletions.
2 changes: 1 addition & 1 deletion ngraph/core/include/ngraph/op/multiclass_nms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace ngraph
MulticlassNms(const Output<Node>& boxes,
const Output<Node>& scores,
const SortResultType sort_result_type = SortResultType::NONE,
bool sort_result_across_batch = true,
bool sort_result_across_batch = false,
const ngraph::element::Type& output_type = ngraph::element::i64,
const float iou_threshold = 0.0f,
const float score_threshold = 0.0f,
Expand Down
164 changes: 145 additions & 19 deletions ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,64 @@ namespace ngraph
return score < rhs.score || (score == rhs.score && index > rhs.index);
}

inline bool operator>(const BoxInfo& rhs) const
{
return !(score < rhs.score || (score == rhs.score && index > rhs.index));
}

Rectangle box;
int64_t index = 0;
int64_t suppress_begin_index = 0;
int64_t batch_index = 0;
int64_t class_index = 0;
float score = 0.0f;
};

inline std::ostream& operator<<(std::ostream& s, const Rectangle& b)
{
s << "Rectangle{";
s << b.x1 << ", ";
s << b.y1 << ", ";
s << b.x2 << ", ";
s << b.y2;
s << "}";
return s;
}

inline std::ostream& operator<<(std::ostream& s, const BoxInfo& b)
{
s << "BoxInfo{";
s << b.batch_index << ", ";
s << b.class_index << ", ";
s << b.index << ", ";
s << b.box << ", ";
s << b.score;
s << "}";
return s;
}
} // namespace

template<typename T>
void print_queue(T q) { // NB: pass by value so the print uses a copy
std::cout << "\n{";
while(!q.empty())
{
std::cout << q.top() << ", ";
q.pop();
}
std::cout << "}\n";
}

template<typename T>
void print_list(T &q) {
std::cout << "\n{";
for(auto& v : q)
{
std::cout << v << ", ";
}
std::cout << "}\n";
}

void multiclass_nms(const float* boxes_data,
const Shape& boxes_data_shape,
const float* scores_data,
Expand All @@ -143,8 +193,8 @@ namespace ngraph
const Shape& selected_indices_shape,
int64_t* valid_outputs)
{
auto func = [iou_threshold](float iou) {
return iou <= iou_threshold ? 1.0f : 0.0f;
auto func = [](float iou, float adaptive_threshold) {
return iou <= adaptive_threshold ? 1.0f : 0.0f;
};

// boxes shape: {num_batches, num_boxes, 4}
Expand All @@ -158,44 +208,73 @@ namespace ngraph
SelectedOutput* selected_scores_ptr =
reinterpret_cast<SelectedOutput*>(selected_outputs);

size_t boxes_per_class = static_cast<size_t>(nms_top_k);

std::vector<BoxInfo> filteredBoxes;
std::vector<BoxInfo> filteredBoxes; // container for the whole batch

for (int64_t batch = 0; batch < num_batches; batch++)
{
const float* boxesPtr = boxes_data + batch * num_boxes * 4;
Rectangle* r = reinterpret_cast<Rectangle*>(const_cast<float*>(boxesPtr));

int64_t num_dets = 0;
std::vector<BoxInfo> selected_boxes; // container for a batch element

for (int64_t class_idx = 0; class_idx < num_classes; class_idx++)
{
{
if(class_idx == background_class) continue;

auto adaptive_threshold = iou_threshold;

const float* scoresPtr =
scores_data + batch * (num_classes * num_boxes) + class_idx * num_boxes;

std::vector<BoxInfo> candidate_boxes;
candidate_boxes.reserve(num_boxes);

for (int64_t box_idx = 0; box_idx < num_boxes; box_idx++)
{
if (scoresPtr[box_idx] > score_threshold)
if (scoresPtr[box_idx] >= score_threshold) /* NOTE: ">=" instead of ">" used in PDPD */
{
candidate_boxes.emplace_back(
r[box_idx], box_idx, scoresPtr[box_idx], 0, batch, class_idx);
}
}

std::priority_queue<BoxInfo> sorted_boxes(std::less<BoxInfo>(),
std::move(candidate_boxes));
int candiate_size = candidate_boxes.size();

// threshold nms_top_k for each class
// NOTE: "nms_top_k" in PDPD not exactly equal to
// "max_output_boxes_per_class" in ONNX.
if (nms_top_k > -1 && nms_top_k < candiate_size)
{
candiate_size = nms_top_k;
}

if (candiate_size <= 0) // early drop
{
continue;
}

// sort by score
std::partial_sort(candidate_boxes.begin(),
candidate_boxes.begin() + candiate_size,
candidate_boxes.end(),
std::greater<BoxInfo>());

std::vector<BoxInfo> selected;
print_list(candidate_boxes);

std::priority_queue<BoxInfo> sorted_boxes(candidate_boxes.begin(),
candidate_boxes.begin() + candiate_size, std::less<BoxInfo>());

print_list(candidate_boxes);

print_queue(sorted_boxes);

std::vector<BoxInfo> selected; // container for a class
// Get the next box with top score, filter by iou_threshold

BoxInfo next_candidate;
float original_score;

while (!sorted_boxes.empty() && selected.size() < boxes_per_class)
while (!sorted_boxes.empty())
{
next_candidate = sorted_boxes.top();
original_score = next_candidate.score;
Expand All @@ -208,9 +287,9 @@ namespace ngraph
{
float iou =
intersectionOverUnion(next_candidate.box, selected[j].box);
next_candidate.score *= func(iou);
next_candidate.score *= func(iou, adaptive_threshold);

if (iou >= iou_threshold)
if (iou >= adaptive_threshold)
{
should_hard_suppress = true;
break;
Expand All @@ -226,6 +305,10 @@ namespace ngraph

if (!should_hard_suppress)
{
if(nms_eta < 1 && adaptive_threshold > 0.5)
{
adaptive_threshold *= nms_eta;
}
if (next_candidate.score == original_score)
{
selected.push_back(next_candidate);
Expand All @@ -240,17 +323,44 @@ namespace ngraph

for (const auto& box_info : selected)
{
filteredBoxes.push_back(box_info);
selected_boxes.push_back(box_info);
}
num_dets += filteredBoxes.size();
num_dets += selected.size();
} // for each class

/* sort inside batch element */
if (sort_result_type == op::v8::MulticlassNms::SortResultType::SCORE)
{
std::sort(selected_boxes.begin(),
selected_boxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return ((l.batch_index == r.batch_index) &&
((l.score > r.score) ||
((std::fabs(l.score - r.score) < 1e-6) && l.class_index < r.class_index) ||
((std::fabs(l.score - r.score) < 1e-6) && l.class_index == r.class_index && l.index < r.index)));
});
}
// in case of "NONE" and "CLASSID", pass through

// threshold keep_top_k for each batch element
if (keep_top_k > -1 && keep_top_k < num_dets)
{
num_dets = nms_top_k;
selected_boxes.resize(num_dets);
}

*valid_outputs++ = num_dets;
}
for(auto& v:selected_boxes)
{
filteredBoxes.push_back(v);
}
} // for each batch element

if (sort_result_across_batch)
{
std::sort(filteredBoxes.begin(),
{ /* sort across batch */
if (sort_result_type == op::v8::MulticlassNms::SortResultType::SCORE)
{
std::sort(filteredBoxes.begin(),
filteredBoxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (l.score > r.score) ||
Expand All @@ -260,8 +370,24 @@ namespace ngraph
(l.score == r.score && l.batch_index == r.batch_index &&
l.class_index == r.class_index && l.index < r.index);
});
}
else if(sort_result_type == op::v8::MulticlassNms::SortResultType::CLASSID)
{
std::sort(filteredBoxes.begin(),
filteredBoxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (l.class_index < r.class_index) ||
(l.class_index == r.class_index && l.batch_index < r.batch_index) ||
(l.class_index == r.class_index && l.batch_index == r.batch_index &&
l.score > r.score) ||
(l.class_index == r.class_index && l.batch_index == r.batch_index &&
l.score == r.score && l.index < r.index);
});
}
}

/* output */

size_t max_num_of_selected_indices = selected_indices_shape[0];
size_t output_size = std::min(filteredBoxes.size(), max_num_of_selected_indices);

Expand Down
Loading

0 comments on commit c740425

Please sign in to comment.