Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implemented multiclass_nms reference. #2

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ namespace ngraph
const ngraph::element::Type output_type,
const std::vector<float>& selected_outputs,
const std::vector<int64_t>& selected_indices,
int64_t valid_outputs);
const std::vector<int64_t>& valid_outputs,
const ngraph::element::Type selected_scores_type);
} // namespace reference
} // namespace runtime
} // namespace ngraph
271 changes: 239 additions & 32 deletions ngraph/core/reference/src/runtime/reference/multiclass_nms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cmath>
#include <queue>
#include <vector>
#include <numeric>
#include "ngraph/runtime/reference/multiclass_nms.hpp"
#include "ngraph/shape.hpp"

Expand All @@ -23,20 +24,20 @@ namespace ngraph
{
/// Axis-aligned box described by two corner points.
///
/// The members are declared in the order x1, y1, x2, y2 on purpose: the NMS
/// kernel in this file reinterprets a raw `float` box buffer as an array of
/// `Rectangle`, so the declaration order and the absence of padding are part
/// of the contract.
struct Rectangle
{
    /// \param x_left   x coordinate of the first corner
    /// \param y_left   y coordinate of the first corner
    /// \param x_right  x coordinate of the second corner
    /// \param y_right  y coordinate of the second corner
    Rectangle(float x_left, float y_left, float x_right, float y_right)
        : x1{x_left}
        , y1{y_left}
        , x2{x_right}
        , y2{y_right}
    {
    }

    /// Yields the degenerate box (0, 0, 0, 0).
    Rectangle() = default;

    float x1 = 0.0f;
    float y1 = 0.0f;
    float x2 = 0.0f;
    float y2 = 0.0f;
};

// Guards the buffer reinterpretation described above: four tightly packed floats.
static_assert(sizeof(Rectangle) == 4 * sizeof(float),
              "Rectangle must be exactly four packed floats");

static float intersectionOverUnion(const Rectangle& boxI, const Rectangle& boxJ)
Expand All @@ -63,34 +64,33 @@ namespace ngraph

/// One row of the "selected indices" output: the position of a kept box in
/// the boxes tensor flattened over (batch, box).
///
/// The struct holds exactly one `int64_t` because the kernel in this file
/// writes rows through a `SelectedIndex*` obtained from the raw `int64_t`
/// output buffer.
struct SelectedIndex
{
    /// \param batch_idx  index of the batch the box belongs to
    /// \param box_idx    index of the box inside that batch
    /// \param num_box    number of boxes per batch (the row stride)
    SelectedIndex(int64_t batch_idx, int64_t box_idx, int64_t num_box)
        : flattened_index(batch_idx * num_box + box_idx)
    {
    }

    /// Yields index 0; used as the filler row for unused output slots.
    SelectedIndex() = default;

    int64_t flattened_index = 0;
};

/// One row of the "selected outputs" tensor:
/// (class_index, box_score, xmin, ymin, xmax, ymax).
///
/// The struct is exactly six floats because the kernel in this file writes
/// rows through a `SelectedOutput*` obtained from the raw `float` output
/// buffer, and the post-processing step copies six floats per selected box.
struct SelectedOutput
{
    /// \param class_idx  class the box was selected for (stored as float)
    /// \param score      score of the selected box
    /// \param x1,y1      first corner of the box
    /// \param x2,y2      second corner of the box
    SelectedOutput(float class_idx, float score, float x1, float y1, float x2, float y2)
        : class_index{class_idx}
        , box_score{score}
        , xmin{x1}
        , ymin{y1}
        , xmax{x2}
        , ymax{y2}
    {
    }

    /// Yields an all-zero row. Every member carries a default initializer so
    /// that a default-constructed row never exposes indeterminate floats.
    SelectedOutput() = default;

    float class_index = 0.0f;
    float box_score = 0.0f;
    float xmin = 0.0f;
    float ymin = 0.0f;
    float xmax = 0.0f;
    float ymax = 0.0f;
};

struct BoxInfo
Expand Down Expand Up @@ -142,45 +142,252 @@ namespace ngraph
// NOTE(review): only the tail of this function's parameter list is visible here;
// the name and the leading parameters sit in collapsed context above.
const Shape& selected_indices_shape,
int64_t* valid_outputs)
{
// NOTE(review): the next three statements look like leftovers of the earlier
// stub implementation (valid_outputs is written per batch further down) —
// confirm they are absent from the merged file.
BoxInfo info;
intersectionOverUnion(Rectangle{}, Rectangle{});
*valid_outputs = 0;
// Hard-threshold weight: 1 keeps the candidate's score, 0 zeroes it.
auto func = [iou_threshold](float iou) {
return iou <= iou_threshold ? 1.0f : 0.0f;
};

// boxes shape: {num_batches, num_boxes, 4}
// scores shape: {num_batches, num_classes, num_boxes}
int64_t num_batches = static_cast<int64_t>(scores_data_shape[0]);
int64_t num_classes = static_cast<int64_t>(scores_data_shape[1]);
int64_t num_boxes = static_cast<int64_t>(boxes_data_shape[1]);

// The raw output buffers are written row by row through these typed views.
// NOTE(review): this relies on SelectedIndex / SelectedOutput having exactly
// the in-memory layout of one output row — confirm against the struct definitions.
SelectedIndex* selected_indices_ptr =
reinterpret_cast<SelectedIndex*>(selected_indices);
SelectedOutput* selected_scores_ptr =
reinterpret_cast<SelectedOutput*>(selected_outputs);

// Upper bound on boxes kept per (batch, class) pair.
// NOTE(review): if nms_top_k is a signed value that may be negative, this cast
// wraps to a very large limit — confirm that is the intended "no limit" handling.
size_t boxes_per_class = static_cast<size_t>(nms_top_k);

// Collects the kept boxes of every class of every batch, in processing order.
std::vector<BoxInfo> filteredBoxes;

for (int64_t batch = 0; batch < num_batches; batch++)
{
// View this batch's 4-float box records as Rectangle objects.
const float* boxesPtr = boxes_data + batch * num_boxes * 4;
Rectangle* r = reinterpret_cast<Rectangle*>(const_cast<float*>(boxesPtr));

int64_t num_dets = 0;

for (int64_t class_idx = 0; class_idx < num_classes; class_idx++)
{
const float* scoresPtr =
scores_data + batch * (num_classes * num_boxes) + class_idx * num_boxes;

std::vector<BoxInfo> candidate_boxes;
candidate_boxes.reserve(num_boxes);

// Only boxes scoring strictly above score_threshold become candidates.
for (int64_t box_idx = 0; box_idx < num_boxes; box_idx++)
{
if (scoresPtr[box_idx] > score_threshold)
{
candidate_boxes.emplace_back(
r[box_idx], box_idx, scoresPtr[box_idx], 0, batch, class_idx);
}
}

// Pop order is decided by BoxInfo::operator<, which is not visible here —
// presumably highest score first; verify against BoxInfo.
std::priority_queue<BoxInfo> sorted_boxes(std::less<BoxInfo>(),
std::move(candidate_boxes));

std::vector<BoxInfo> selected;
// Get the next box with top score, filter by iou_threshold

BoxInfo next_candidate;
float original_score;

while (!sorted_boxes.empty() && selected.size() < boxes_per_class)
{
next_candidate = sorted_boxes.top();
original_score = next_candidate.score;
sorted_boxes.pop();

// Compare against already selected boxes, newest first, stopping at the
// index up to which this candidate has been checked before.
bool should_hard_suppress = false;
for (int64_t j = static_cast<int64_t>(selected.size()) - 1;
j >= next_candidate.suppress_begin_index;
--j)
{
float iou =
intersectionOverUnion(next_candidate.box, selected[j].box);
next_candidate.score *= func(iou);

// NOTE(review): func keeps the score when iou == iou_threshold, yet this
// test suppresses on equality — confirm which boundary is intended.
if (iou >= iou_threshold)
{
should_hard_suppress = true;
break;
}

if (next_candidate.score <= score_threshold)
{
break;
}
}

// Remember how far this candidate has been checked in case it is re-queued.
next_candidate.suppress_begin_index = selected.size();

if (!should_hard_suppress)
{
// Unchanged score: no selected box affected the candidate, so keep it.
if (next_candidate.score == original_score)
{
selected.push_back(next_candidate);
continue;
}
// Score was reduced but is still above the threshold: re-queue it.
if (next_candidate.score > score_threshold)
{
sorted_boxes.push(next_candidate);
}
}
}

for (const auto& box_info : selected)
{
filteredBoxes.push_back(box_info);
}
// NOTE(review): filteredBoxes is never cleared and spans all earlier classes
// and batches, so adding its full size on every class iteration counts
// previously kept boxes again; `selected.size()` looks like the intended
// increment — confirm and fix.
num_dets += filteredBoxes.size();
}

// Store this batch's count, then advance to the next batch's slot.
*valid_outputs++ = num_dets;
}

// Hard-coded to false, so the cross-batch sort below never runs in this revision.
bool sort_result_across_batch = false; // TODO

if (sort_result_across_batch)
{
// Order: score descending, then batch, class and box index ascending.
std::sort(filteredBoxes.begin(),
filteredBoxes.end(),
[](const BoxInfo& l, const BoxInfo& r) {
return (l.score > r.score) ||
(l.score == r.score && l.batch_index < r.batch_index) ||
(l.score == r.score && l.batch_index == r.batch_index &&
l.class_index < r.class_index) ||
(l.score == r.score && l.batch_index == r.batch_index &&
l.class_index == r.class_index && l.index < r.index);
});
}

// Never write more rows than the first dimension of selected_indices_shape allows.
size_t max_num_of_selected_indices = selected_indices_shape[0];
size_t output_size = std::min(filteredBoxes.size(), max_num_of_selected_indices);

// Emit one index row and one output row per kept box.
size_t idx;
for (idx = 0; idx < output_size; idx++)
{
const auto& box_info = filteredBoxes[idx];
SelectedIndex selected_index{
box_info.batch_index, box_info.index, num_boxes};
SelectedOutput selected_score{static_cast<float>(box_info.class_index),
box_info.score,
box_info.box.x1, box_info.box.y1,
box_info.box.x2, box_info.box.y2};

selected_indices_ptr[idx] = selected_index;
selected_scores_ptr[idx] = selected_score;
}

// Zero-fill the remaining rows of both output buffers.
SelectedIndex selected_index_filler{0, 0, 0};
SelectedOutput selected_score_filler{0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
for (; idx < max_num_of_selected_indices; idx++)
{
selected_indices_ptr[idx] = selected_index_filler;
selected_scores_ptr[idx] = selected_score_filler;
}
}

/// \brief Publishes the NMS results into the output host tensors.
///
/// Sets element type and shape of up to three outputs and copies the data:
///   - outputs[0] "selected_outputs": {num_selected, 6}, type selected_scores_type
///   - outputs[1] "selected_indices": {num_selected},    type output_type
///   - outputs[2] "selected_num":     {valid_outputs.size()}, type output_type
/// where num_selected is the sum of all entries of valid_outputs.
///
/// \param outputs               Output tensors; only the ones present are filled.
/// \param output_type           Element type for the index/count outputs. i64 is
///                              copied as is; any other type is written as int32.
/// \param selected_outputs      Row-major rows of six floats per selected box.
/// \param selected_indices      One index per selected box.
/// \param valid_outputs         Number of selected boxes per batch.
/// \param selected_scores_type  Element type for outputs[0] (bf16, f16 or f32).
void multiclass_nms_postprocessing(const HostTensorVector& outputs,
                                   const ngraph::element::Type output_type,
                                   const std::vector<float>& selected_outputs,
                                   const std::vector<int64_t>& selected_indices,
                                   const std::vector<int64_t>& valid_outputs,
                                   const ngraph::element::Type selected_scores_type)
{
    // Seed with int64_t: std::accumulate sums in the type of the initial value,
    // so a plain `0` would add the 64-bit counts up in `int`.
    const int64_t num_selected =
        std::accumulate(valid_outputs.begin(), valid_outputs.end(), int64_t{0});
    const size_t num_rows = static_cast<size_t>(num_selected);
    const size_t num_of_outputs = outputs.size();

    /* shape & type */

    outputs[0]->set_element_type(selected_scores_type); // "selected_outputs"
    outputs[0]->set_shape(Shape{num_rows, 6});

    if (num_of_outputs >= 2)
    {
        outputs[1]->set_element_type(output_type); // "selected_indices"
        outputs[1]->set_shape(Shape{num_rows});
    }

    if (num_of_outputs >= 3)
    {
        outputs[2]->set_element_type(output_type); // "selected_num"
        outputs[2]->set_shape(Shape{valid_outputs.size()});
    }

    /* data */

    const size_t selected_outputs_size = num_rows * 6;

    switch (selected_scores_type)
    {
    case element::Type_t::bf16:
    {
        bfloat16* scores_ptr = outputs[0]->get_data_ptr<bfloat16>();
        for (size_t i = 0; i < selected_outputs_size; ++i)
        {
            scores_ptr[i] = bfloat16(selected_outputs[i]);
        }
    }
    break;
    case element::Type_t::f16:
    {
        float16* scores_ptr = outputs[0]->get_data_ptr<float16>();
        for (size_t i = 0; i < selected_outputs_size; ++i)
        {
            scores_ptr[i] = float16(selected_outputs[i]);
        }
    }
    break;
    case element::Type_t::f32:
    {
        float* scores_ptr = outputs[0]->get_data_ptr<float>();
        memcpy(scores_ptr, selected_outputs.data(), selected_outputs_size * sizeof(float));
    }
    break;
    default:
        // Any other element type leaves outputs[0] shaped but unwritten.
        // NOTE(review): consider rejecting unsupported types explicitly.
        break;
    }

    if (num_of_outputs < 2)
    {
        return;
    }

    const size_t selected_indices_size = num_rows;

    if (output_type == ngraph::element::i64)
    {
        int64_t* indices_ptr = outputs[1]->get_data_ptr<int64_t>();
        memcpy(indices_ptr, selected_indices.data(), selected_indices_size * sizeof(int64_t));
    }
    else
    {
        // Non-i64 index outputs are narrowed element by element to int32.
        int32_t* indices_ptr = outputs[1]->get_data_ptr<int32_t>();
        for (size_t i = 0; i < selected_indices_size; ++i)
        {
            indices_ptr[i] = static_cast<int32_t>(selected_indices[i]);
        }
    }

    if (num_of_outputs < 3)
    {
        return;
    }

    if (output_type == ngraph::element::i64)
    {
        int64_t* valid_outputs_ptr = outputs[2]->get_data_ptr<int64_t>();
        memcpy(valid_outputs_ptr, valid_outputs.data(), valid_outputs.size() * sizeof(int64_t));
    }
    else
    {
        int32_t* valid_outputs_ptr = outputs[2]->get_data_ptr<int32_t>();
        for (size_t i = 0; i < valid_outputs.size(); ++i)
        {
            valid_outputs_ptr[i] = static_cast<int32_t>(valid_outputs[i]);
        }
    }
}
} // namespace reference
Expand Down
Loading