Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cecilia/multiclass nms/algo impl #5

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ngraph/core/include/ngraph/op/multiclass_nms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace ngraph
MulticlassNms(const Output<Node>& boxes,
const Output<Node>& scores,
const SortResultType sort_result_type = SortResultType::NONE,
bool sort_result_across_batch = true,
bool sort_result_across_batch = false,
const ngraph::element::Type& output_type = ngraph::element::i64,
const float iou_threshold = 0.0f,
const float score_threshold = 0.0f,
Expand Down
164 changes: 145 additions & 19 deletions ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,64 @@ namespace ngraph
return score < rhs.score || (score == rhs.score && index > rhs.index);
}

inline bool operator>(const BoxInfo& rhs) const
{
return !(score < rhs.score || (score == rhs.score && index > rhs.index));
}

Rectangle box;
int64_t index = 0;
int64_t suppress_begin_index = 0;
int64_t batch_index = 0;
int64_t class_index = 0;
float score = 0.0f;
};

inline std::ostream& operator<<(std::ostream& s, const Rectangle& b)
{
s << "Rectangle{";
s << b.x1 << ", ";
s << b.y1 << ", ";
s << b.x2 << ", ";
s << b.y2;
s << "}";
return s;
}

inline std::ostream& operator<<(std::ostream& s, const BoxInfo& b)
{
s << "BoxInfo{";
s << b.batch_index << ", ";
s << b.class_index << ", ";
s << b.index << ", ";
s << b.box << ", ";
s << b.score;
s << "}";
return s;
}
} // namespace

template<typename T>
void print_queue(T q) { // NB: pass by value so the print uses a copy
std::cout << "\n{";
while(!q.empty())
{
std::cout << q.top() << ", ";
q.pop();
}
std::cout << "}\n";
}

template<typename T>
void print_list(T &q) {
std::cout << "\n{";
for(auto& v : q)
{
std::cout << v << ", ";
}
std::cout << "}\n";
}

void multiclass_nms(const float* boxes_data,
const Shape& boxes_data_shape,
const float* scores_data,
Expand All @@ -143,8 +193,8 @@ namespace ngraph
const Shape& selected_indices_shape,
int64_t* valid_outputs)
{
auto func = [iou_threshold](float iou) {
return iou <= iou_threshold ? 1.0f : 0.0f;
auto func = [](float iou, float adaptive_threshold) {
return iou <= adaptive_threshold ? 1.0f : 0.0f;
};

// boxes shape: {num_batches, num_boxes, 4}
Expand All @@ -158,44 +208,73 @@ namespace ngraph
SelectedOutput* selected_scores_ptr =
reinterpret_cast<SelectedOutput*>(selected_outputs);

size_t boxes_per_class = static_cast<size_t>(nms_top_k);

std::vector<BoxInfo> filteredBoxes;
std::vector<BoxInfo> filteredBoxes; // container for the whole batch

for (int64_t batch = 0; batch < num_batches; batch++)
{
const float* boxesPtr = boxes_data + batch * num_boxes * 4;
Rectangle* r = reinterpret_cast<Rectangle*>(const_cast<float*>(boxesPtr));

int64_t num_dets = 0;
std::vector<BoxInfo> selected_boxes; // container for a batch element

for (int64_t class_idx = 0; class_idx < num_classes; class_idx++)
{
{
if(class_idx == background_class) continue;

auto adaptive_threshold = iou_threshold;

const float* scoresPtr =
scores_data + batch * (num_classes * num_boxes) + class_idx * num_boxes;

std::vector<BoxInfo> candidate_boxes;
candidate_boxes.reserve(num_boxes);

for (int64_t box_idx = 0; box_idx < num_boxes; box_idx++)
{
if (scoresPtr[box_idx] > score_threshold)
if (scoresPtr[box_idx] >= score_threshold) /* NOTE: ">=" instead of ">" used in PDPD */
{
candidate_boxes.emplace_back(
r[box_idx], box_idx, scoresPtr[box_idx], 0, batch, class_idx);
}
}

std::priority_queue<BoxInfo> sorted_boxes(std::less<BoxInfo>(),
std::move(candidate_boxes));
int candiate_size = candidate_boxes.size();

// threshold nms_top_k for each class
// NOTE: "nms_top_k" in PDPD not exactly equal to
// "max_output_boxes_per_class" in ONNX.
if (nms_top_k > -1 && nms_top_k < candiate_size)
{
candiate_size = nms_top_k;
}

if (candiate_size <= 0) // early drop
{
continue;
}

// sort by score
std::partial_sort(candidate_boxes.begin(),
candidate_boxes.begin() + candiate_size,
candidate_boxes.end(),
std::greater<BoxInfo>());

std::vector<BoxInfo> selected;
print_list(candidate_boxes);

std::priority_queue<BoxInfo> sorted_boxes(candidate_boxes.begin(),
candidate_boxes.begin() + candiate_size, std::less<BoxInfo>());

print_list(candidate_boxes);

print_queue(sorted_boxes);

std::vector<BoxInfo> selected; // container for a class
// Get the next box with top score, filter by iou_threshold

BoxInfo next_candidate;
float original_score;

while (!sorted_boxes.empty() && selected.size() < boxes_per_class)
while (!sorted_boxes.empty())
{
next_candidate = sorted_boxes.top();
original_score = next_candidate.score;
Expand All @@ -208,9 +287,9 @@ namespace ngraph
{
float iou =
intersectionOverUnion(next_candidate.box, selected[j].box);
next_candidate.score *= func(iou);
next_candidate.score *= func(iou, adaptive_threshold);

if (iou >= iou_threshold)
if (iou >= adaptive_threshold)
{
should_hard_suppress = true;
break;
Expand All @@ -226,6 +305,10 @@ namespace ngraph

if (!should_hard_suppress)
{
if(nms_eta < 1 && adaptive_threshold > 0.5)
{
adaptive_threshold *= nms_eta;
}
if (next_candidate.score == original_score)
{
selected.push_back(next_candidate);
Expand All @@ -240,17 +323,44 @@ namespace ngraph

for (const auto& box_info : selected)
{
filteredBoxes.push_back(box_info);
selected_boxes.push_back(box_info);
}
num_dets += filteredBoxes.size();
num_dets += selected.size();
} // for each class

/* sort inside batch element */
if (sort_result_type == op::v8::MulticlassNms::SortResultType::SCORE)
{
std::sort(selected_boxes.begin(),
selected_boxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return ((l.batch_index == r.batch_index) &&
((l.score > r.score) ||
((std::fabs(l.score - r.score) < 1e-6) && l.class_index < r.class_index) ||
((std::fabs(l.score - r.score) < 1e-6) && l.class_index == r.class_index && l.index < r.index)));
});
}
// in case of "NONE" and "CLASSID", pass through

// threshold keep_top_k for each batch element
if (keep_top_k > -1 && keep_top_k < num_dets)
{
num_dets = nms_top_k;
selected_boxes.resize(num_dets);
}

*valid_outputs++ = num_dets;
}
for(auto& v:selected_boxes)
{
filteredBoxes.push_back(v);
}
} // for each batch element

if (sort_result_across_batch)
{
std::sort(filteredBoxes.begin(),
{ /* sort across batch */
if (sort_result_type == op::v8::MulticlassNms::SortResultType::SCORE)
{
std::sort(filteredBoxes.begin(),
filteredBoxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (l.score > r.score) ||
Expand All @@ -260,8 +370,24 @@ namespace ngraph
(l.score == r.score && l.batch_index == r.batch_index &&
l.class_index == r.class_index && l.index < r.index);
});
}
else if(sort_result_type == op::v8::MulticlassNms::SortResultType::CLASSID)
{
std::sort(filteredBoxes.begin(),
filteredBoxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (l.class_index < r.class_index) ||
(l.class_index == r.class_index && l.batch_index < r.batch_index) ||
(l.class_index == r.class_index && l.batch_index == r.batch_index &&
l.score > r.score) ||
(l.class_index == r.class_index && l.batch_index == r.batch_index &&
l.score == r.score && l.index < r.index);
});
}
}

/* output */

size_t max_num_of_selected_indices = selected_indices_shape[0];
size_t output_size = std::min(filteredBoxes.size(), max_num_of_selected_indices);

Expand Down
Loading