From 6cd0e35a04265ec0a78506f5ae61bdf5e73a4252 Mon Sep 17 00:00:00 2001 From: Andrew Bakalin Date: Thu, 5 Nov 2020 13:33:16 +0300 Subject: [PATCH] [IE][VPU] Fix NMS DTS (#2880) Add a new constructor to fix absent NMS-5 inputs that will be introduced after #2450 will be merged. --- .../static_shape_non_maximum_suppression.hpp | 3 +++ .../static_shape_non_maximum_suppression.cpp | 12 ++++++++++++ .../dynamic_to_static_shape_non_max_suppression.cpp | 11 +---------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/inference-engine/src/vpu/common/include/vpu/ngraph/operations/static_shape_non_maximum_suppression.hpp b/inference-engine/src/vpu/common/include/vpu/ngraph/operations/static_shape_non_maximum_suppression.hpp index acc0b898cf26dd..a31115d18e3021 100644 --- a/inference-engine/src/vpu/common/include/vpu/ngraph/operations/static_shape_non_maximum_suppression.hpp +++ b/inference-engine/src/vpu/common/include/vpu/ngraph/operations/static_shape_non_maximum_suppression.hpp @@ -6,6 +6,7 @@ #include #include +#include #include #include @@ -17,6 +18,8 @@ class StaticShapeNonMaxSuppression : public ngraph::op::NonMaxSuppressionIE3 { static constexpr NodeTypeInfo type_info{"StaticShapeNonMaxSuppression", 0}; const NodeTypeInfo& get_type_info() const override { return type_info; } + explicit StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms); + StaticShapeNonMaxSuppression(const Output& boxes, const Output& scores, const Output& maxOutputBoxesPerClass, diff --git a/inference-engine/src/vpu/common/src/ngraph/operations/static_shape_non_maximum_suppression.cpp b/inference-engine/src/vpu/common/src/ngraph/operations/static_shape_non_maximum_suppression.cpp index 94e1a964ca6719..18c8b3bca7bd8d 100644 --- a/inference-engine/src/vpu/common/src/ngraph/operations/static_shape_non_maximum_suppression.cpp +++ b/inference-engine/src/vpu/common/src/ngraph/operations/static_shape_non_maximum_suppression.cpp @@ -12,6 +12,18 @@ namespace ngraph { namespace vpu { namespace op { constexpr NodeTypeInfo StaticShapeNonMaxSuppression::type_info; +StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression(const ngraph::opset5::NonMaxSuppression& nms) + : StaticShapeNonMaxSuppression( + nms.input_value(0), + nms.input_value(1), + nms.get_input_size() > 2 ? nms.input_value(2) : ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}), + nms.get_input_size() > 3 ? nms.input_value(3) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}), + nms.get_input_size() > 4 ? nms.input_value(4) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}), + nms.get_input_size() > 5 ? nms.input_value(5) : ngraph::opset5::Constant::create(ngraph::element::f32, ngraph::Shape{}, {.0f}), + nms.get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0, + nms.get_sort_result_descending(), + nms.get_output_type()) {} + StaticShapeNonMaxSuppression::StaticShapeNonMaxSuppression( const Output& boxes, const Output& scores, diff --git a/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_non_max_suppression.cpp b/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_non_max_suppression.cpp index 02145ca889ab4a..1a369393fcb7cb 100644 --- a/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_non_max_suppression.cpp +++ b/inference-engine/src/vpu/common/src/ngraph/transformations/dynamic_to_static_shape_non_max_suppression.cpp @@ -21,16 +21,7 @@ void dynamicToStaticNonMaxSuppression(std::shared_ptr node) { VPU_THROW_UNLESS(nms, "dynamicToStaticNonMaxSuppression transformation for {} of type {} expects {} as node for replacement", node->get_friendly_name(), node->get_type_info(), ngraph::opset5::NonMaxSuppression::type_info); - auto staticShapeNMS = std::make_shared( - nms->input_value(0), - nms->input_value(1), - nms->input_value(2), - nms->input_value(3), - nms->input_value(4), - nms->input_value(5), - nms->get_box_encoding() == ngraph::opset5::NonMaxSuppression::BoxEncodingType::CENTER ? 1 : 0, - nms->get_sort_result_descending(), - nms->get_output_type()); + auto staticShapeNMS = std::make_shared(*nms); auto dsrIndices = std::make_shared( staticShapeNMS->output(0), staticShapeNMS->output(2));