Skip to content

Commit

Permalink
[IE][VPU] Fix NMS DTS (openvinotoolkit#2880)
Browse files Browse the repository at this point in the history
Add a new constructor to fix absent NMS-5 inputs that will be introduced after openvinotoolkit#2450 will be merged.
  • Loading branch information
andreybakalin97 authored and sdurawa committed Nov 6, 2020
1 parent d9d8790 commit 6cd0e35
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <ngraph/node.hpp>
#include <legacy/ngraph_ops/nms_ie.hpp>
#include <ngraph/opsets/opset5.hpp>

#include <memory>
#include <vector>
Expand All @@ -17,6 +18,8 @@ class StaticShapeNonMaxSuppression : public ngraph::op::NonMaxSuppressionIE3 {
static constexpr NodeTypeInfo type_info{"StaticShapeNonMaxSuppression", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }

explicit StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms);

StaticShapeNonMaxSuppression(const Output<Node>& boxes,
const Output<Node>& scores,
const Output<Node>& maxOutputBoxesPerClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@ namespace ngraph { namespace vpu { namespace op {

constexpr NodeTypeInfo StaticShapeNonMaxSuppression::type_info;

StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms)
: StaticShapeNonMaxSuppression(
nms.input_value(0),
nms.input_value(1),
nms.get_input_size() > 2 ? nms.input_value(2) : ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}),
nms.get_input_size() > 3 ? nms.input_value(3) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
nms.get_input_size() > 4 ? nms.input_value(4) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
nms.get_input_size() > 5 ? nms.input_value(5) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}),
nms.get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0,
nms.get_sort_result_descending(),
nms.get_output_type()) {}

StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(
const Output<Node>& boxes,
const Output<Node>& scores,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,7 @@ void dynamicToStaticNonMaxSuppression(std::shared_ptr<ngraph::Node> node) {
VPU_THROW_UNLESS(nms, "dynamicToStaticNonMaxSuppression transformation for {} of type {} expects {} as node for replacement",
node->get_friendly_name(), node->get_type_info(), ngraph::opset5::NonMaxSuppression::type_info);

auto staticShapeNMS = std::make_shared<ngraph::vpu::op::StaticShapeNonMaxSuppression>(
nms->input_value(0),
nms->input_value(1),
nms->input_value(2),
nms->input_value(3),
nms->input_value(4),
nms->input_value(5),
nms->get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0,
nms->get_sort_result_descending(),
nms->get_output_type());
auto staticShapeNMS = std::make_shared<ngraph::vpu::op::StaticShapeNonMaxSuppression>(*nms);

auto dsrIndices = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(
staticShapeNMS->output(0), staticShapeNMS->output(2));
Expand Down

0 comments on commit 6cd0e35

Please sign in to comment.