Skip to content

Commit

Permalink
Merge pull request #8 from ceciliapeng2011/cecilia/ops_matrix_nms
Browse files Browse the repository at this point in the history
fixes: 1. normalized support. 2. sort by score before keep_top_k insi…
  • Loading branch information
luo-cheng2021 authored Jun 24, 2021
2 parents 950e830 + fb096de commit 9179f5a
Showing 1 changed file with 42 additions and 23 deletions.
65 changes: 42 additions & 23 deletions ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ namespace ngraph
float y2 = 0.0f;
};

static float intersectionOverUnion(const Rectangle& boxI, const Rectangle& boxJ)
static float intersectionOverUnion(const Rectangle& boxI, const Rectangle& boxJ, const bool normalized)
{
float areaI = (boxI.y2 - boxI.y1) * (boxI.x2 - boxI.x1);
float areaJ = (boxJ.y2 - boxJ.y1) * (boxJ.x2 - boxJ.x1);
const float norm = static_cast<float>(normalized == false);

float areaI = (boxI.y2 - boxI.y1 + norm) * (boxI.x2 - boxI.x1 + norm);
float areaJ = (boxJ.y2 - boxJ.y1 + norm) * (boxJ.x2 - boxJ.x1 + norm);

if (areaI <= 0.0f || areaJ <= 0.0f)
{
Expand All @@ -56,8 +58,8 @@ namespace ngraph
float intersection_xmax = std::min(boxI.x2, boxJ.x2);

float intersection_area =
std::max(intersection_ymax - intersection_ymin, 0.0f) *
std::max(intersection_xmax - intersection_xmin, 0.0f);
std::max(intersection_ymax - intersection_ymin + norm, 0.0f) *
std::max(intersection_xmax - intersection_xmin + norm, 0.0f);

return intersection_area / (areaI + areaJ - intersection_area);
}
Expand Down Expand Up @@ -241,7 +243,7 @@ namespace ngraph
continue;
}

// sort by score
// sort by score in current class
std::partial_sort(candidate_boxes.begin(),
candidate_boxes.begin() + candiate_size,
candidate_boxes.end(),
Expand Down Expand Up @@ -270,7 +272,7 @@ namespace ngraph
--j)
{
float iou = multiclass_nms_v8::intersectionOverUnion(
next_candidate.box, selected[j].box);
next_candidate.box, selected[j].box, normalized);
next_candidate.score *= func(iou, adaptive_threshold);

if (iou >= adaptive_threshold)
Expand Down Expand Up @@ -312,22 +314,18 @@ namespace ngraph
num_dets += selected.size();
} // for each class

/* sort inside batch element */
if (sort_result_type == op::v8::MulticlassNms::SortResultType::SCORE)
{
std::sort(selected_boxes.begin(),
selected_boxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (
(l.batch_index == r.batch_index) &&
((l.score > r.score) ||
((std::fabs(l.score - r.score) < 1e-6) &&
l.class_index < r.class_index) ||
((std::fabs(l.score - r.score) < 1e-6) &&
l.class_index == r.class_index && l.index < r.index)));
});
}
// in case of "NONE" and "CLASSID", pass through
// sort inside batch element before go through keep_top_k
std::sort(selected_boxes.begin(),
selected_boxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (
(l.batch_index == r.batch_index) &&
((l.score > r.score) ||
((std::fabs(l.score - r.score) < 1e-6) &&
l.class_index < r.class_index) ||
((std::fabs(l.score - r.score) < 1e-6) &&
l.class_index == r.class_index && l.index < r.index)));
});

// threshold keep_top_k for each batch element
if (keep_top_k > -1 && keep_top_k < num_dets)
Expand All @@ -336,6 +334,27 @@ namespace ngraph
selected_boxes.resize(num_dets);
}

// sort
if (!sort_result_across_batch)
{
if (sort_result_type == op::v8::MulticlassNms::SortResultType::CLASSID)
{
std::sort(selected_boxes.begin(),
selected_boxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (
(l.batch_index == r.batch_index) &&
((l.class_index < r.class_index) ||
((l.class_index == r.class_index) &&
l.score > r.score) ||
((std::fabs(l.score - r.score) <= 1e-6) &&
l.class_index == r.class_index && l.index < r.index)));
});
}
// in case of "SCORE", pass through, as,
// it has already gurranteed.
}

*valid_outputs++ = num_dets;
for (auto& v : selected_boxes)
{
Expand Down

0 comments on commit 9179f5a

Please sign in to comment.