From a6a22a2eda8c6e16b72de95004a2a0bf1ad05392 Mon Sep 17 00:00:00 2001 From: Jan Iwaszkiewicz Date: Tue, 25 May 2021 19:37:07 +0200 Subject: [PATCH] [ONNX] ArgMin/ArgMax support for select_last_index (#5661) * Add Reverse Op to opset * Worksave with Reverse path * Add last_index support * refactor argminmax factory * Remove old xfail, add new one * Fix proto file for argmax * Rewrite test for select_last_index * Add CPU tests to Manifest * Update manifest * Remove Reverse from opset7 * Refactor arg_min_max factory * Added example comment in arg_min_max * Codestyle changes --- ngraph/frontend/onnx_import/src/op/argmax.cpp | 7 -- ngraph/frontend/onnx_import/src/op/argmin.cpp | 7 -- .../src/utils/arg_min_max_factory.cpp | 74 +++++++++++++++++-- .../src/utils/arg_min_max_factory.hpp | 1 + ngraph/python/tests/__init__.py | 2 +- ngraph/python/tests/test_onnx/test_backend.py | 20 +---- .../onnx/argmax_select_last_index.prototxt | 3 + ngraph/test/onnx/onnx_import.in.cpp | 49 ++++-------- ngraph/test/runtime/ie/unit_test.manifest | 2 + 9 files changed, 94 insertions(+), 71 deletions(-) diff --git a/ngraph/frontend/onnx_import/src/op/argmax.cpp b/ngraph/frontend/onnx_import/src/op/argmax.cpp index e91983528e20b3..1356626c5c3562 100644 --- a/ngraph/frontend/onnx_import/src/op/argmax.cpp +++ b/ngraph/frontend/onnx_import/src/op/argmax.cpp @@ -26,13 +26,6 @@ namespace ngraph { OutputVector argmax(const Node& node) { - const auto select_last_index = - node.get_attribute_value("select_last_index", 0); - CHECK_VALID_NODE(node, - select_last_index == 0, - "Mode 'select_last_index=1' is not supported by current " - "implementation of ArgMax"); - const utils::ArgMinMaxFactory arg_factory(node); return {arg_factory.make_arg_max()}; } diff --git a/ngraph/frontend/onnx_import/src/op/argmin.cpp b/ngraph/frontend/onnx_import/src/op/argmin.cpp index 983a8dae8ca099..53478f48a59ccf 100644 --- a/ngraph/frontend/onnx_import/src/op/argmin.cpp +++ b/ngraph/frontend/onnx_import/src/op/argmin.cpp @@ -26,13 +26,6 @@ namespace ngraph { OutputVector argmin(const Node& node) { - const auto select_last_index = - node.get_attribute_value("select_last_index", 0); - CHECK_VALID_NODE(node, - select_last_index == 0, - "Mode 'select_last_index=1' is not supported by current " - "implementation of ArgMin"); - const utils::ArgMinMaxFactory arg_factory(node); return {arg_factory.make_arg_min()}; } diff --git a/ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.cpp b/ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.cpp index b0378fae0f497b..65bba632482156 100644 --- a/ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.cpp +++ b/ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.cpp @@ -4,6 +4,7 @@ #include "utils/arg_min_max_factory.hpp" #include "default_opset.hpp" +#include "ngraph/opsets/opset1.hpp" #include "ngraph/validation_util.hpp" namespace ngraph @@ -14,9 +15,11 @@ namespace ngraph { ArgMinMaxFactory::ArgMinMaxFactory(const Node& node) : m_keep_dims{node.get_attribute_value("keepdims", 1)} + , m_input_node{node.get_ng_inputs().at(0)} , m_axis{node.get_attribute_value("axis", 0)} + , m_select_last_index{ + node.get_attribute_value("select_last_index", 0)} { - m_input_node = node.get_ng_inputs().at(0); } std::shared_ptr ArgMinMaxFactory::make_arg_max() const @@ -34,19 +37,80 @@ namespace ngraph { const auto k_node = default_opset::Constant::create(ngraph::element::i64, Shape{}, {1}); + + if (m_select_last_index == 1) + { + // Example (ArgMin): + // The goal is to get the index of the last occurence of the + // minimum value present in given input tensor. + // + // Input: [1, 2, 1, 3, 4, 4] + // Expected output: [2] + // + // Top-K is always returning the "most-left" result. The trick is to + // reverse input to find the "most-right" occurence which is equal to + // the last occurence in the original input. + // reverse = [4, 4, 3, 1, 2, 1] + // + // Run TopK on reversed tensor, in the example output with index values + // is equal to: + // topk->output(1) = 3 + // + // Using ShapeOf and Gather on input obtain length of the input tensor + // along axis, in the example this is equal to: + // dims_on_axis = 6 + // + // Now using two Substract ops calculate resulting index: + // res_index = dims_on_axis - topk->output(1) = 6 - 3 = 3 + // result = res_index - 1 = 3 - 1 = 2 + + const auto axis_node = + default_opset::Constant::create(ngraph::element::i64, Shape{1}, {m_axis}); + const auto reverse = std::make_shared( + m_input_node, axis_node, opset1::Reverse::Mode::INDEX); + + const auto topk = std::make_shared( + reverse, k_node, m_axis, mode, default_opset::TopK::SortType::NONE); + + const auto data_shape = std::make_shared(m_input_node); + const auto dims_on_axis = std::make_shared( + data_shape, + axis_node, + default_opset::Constant::create(ngraph::element::i64, Shape{}, {0})); + + const auto res_index = std::make_shared( + dims_on_axis, + std::make_shared(topk->output(1), element::i64)); + const auto result = std::make_shared( + res_index, + default_opset::Constant::create(ngraph::element::i64, Shape{1}, {1})); + + if (m_keep_dims == 0) + { + const auto axis_to_remove = default_opset::Constant::create( + element::u64, Shape{}, {topk->get_axis()}); + + return std::make_shared(result, axis_to_remove); + } + + return result; + } + const auto topk = std::make_shared( m_input_node, k_node, m_axis, mode, default_opset::TopK::SortType::NONE); + const auto result = + std::make_shared(topk->output(1), element::i64); + if (m_keep_dims == 0) { const auto axis_to_remove = default_opset::Constant::create(element::u64, Shape{}, {topk->get_axis()}); - const auto reshaped_indices = - std::make_shared(topk->output(1), axis_to_remove); - return std::make_shared(reshaped_indices, element::i64); + return std::make_shared(result, axis_to_remove); } - return std::make_shared(topk->output(1), element::i64); + + return result; } } // namespace utils } // namespace onnx_import diff --git a/ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.hpp b/ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.hpp index fadd83a083cf7a..b76cd649761158 100644 --- a/ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.hpp +++ b/ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.hpp @@ -39,6 +39,7 @@ namespace ngraph const std::int64_t m_keep_dims; Output m_input_node; std::int64_t m_axis; + std::int64_t m_select_last_index; }; } // namespace utils diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py index c0b0b181e0516c..b84ec773584908 100644 --- a/ngraph/python/tests/__init__.py +++ b/ngraph/python/tests/__init__.py @@ -111,12 +111,12 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True): xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:" "FakeQuantize_xxx has non const input on 1 port") xfail_issue_46762 = xfail_test(reason="Incorrect result of Minimum op if uint data type is used") -xfail_issue_46765 = xfail_test(reason="select_last_index attribute is not supported by ArgMin and ArgMax") xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64") xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot") xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output") xfail_issue_51993 = xfail_test(reason="PRelu supports only 1D tensor for 'slope' input broadcasted" "by channel") +xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape") # Model MSFT issues: xfail_issue_37957 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:" diff --git a/ngraph/python/tests/test_onnx/test_backend.py b/ngraph/python/tests/test_onnx/test_backend.py index ebfd747406ea20..d411919e97785e 100644 --- a/ngraph/python/tests/test_onnx/test_backend.py +++ b/ngraph/python/tests/test_onnx/test_backend.py @@ -50,7 +50,6 @@ xfail_issue_45180, xfail_issue_45344, xfail_issue_46762, - xfail_issue_46765, xfail_issue_47323, xfail_issue_47337, xfail_issue_48052, @@ -60,7 +59,8 @@ xfail_issue_49753, xfail_issue_49754, xfail_issue_52463, - xfail_issue_51993) + xfail_issue_51993, + xfail_issue_55760) def expect_fail(test_case_path, xfail): # type: (str) -> None @@ -165,23 +165,11 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None "OnnxBackendNodeModelTest.test_min_uint16_cpu", "OnnxBackendNodeModelTest.test_min_uint32_cpu", "OnnxBackendNodeModelTest.test_min_uint64_cpu"), - (xfail_issue_46765, + (xfail_issue_55760, "OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmax_keepdims_example_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmax_no_keepdims_example_select_last_index_cpu", "OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmin_keepdims_example_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmin_no_keepdims_example_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmax_default_axis_example_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmax_default_axis_random_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmax_keepdims_random_select_last_index_cpu", "OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_random_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmax_no_keepdims_random_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmin_default_axis_example_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmin_default_axis_random_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu", - "OnnxBackendNodeModelTest.test_argmin_no_keepdims_random_select_last_index_cpu"), + "OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu"), (xfail_issue_38091, "OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"), (xfail_issue_52463, diff --git a/ngraph/test/models/onnx/argmax_select_last_index.prototxt b/ngraph/test/models/onnx/argmax_select_last_index.prototxt index ba73f63c1952ba..d47a89d14a0a0e 100644 --- a/ngraph/test/models/onnx/argmax_select_last_index.prototxt +++ b/ngraph/test/models/onnx/argmax_select_last_index.prototxt @@ -50,6 +50,9 @@ graph { dim { dim_value: 3 } + dim { + dim_value: 1 + } } } } diff --git a/ngraph/test/onnx/onnx_import.in.cpp b/ngraph/test/onnx/onnx_import.in.cpp index bf9551a5cfbf73..6c0438afac145e 100644 --- a/ngraph/test/onnx/onnx_import.in.cpp +++ b/ngraph/test/onnx/onnx_import.in.cpp @@ -2503,45 +2503,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_float) NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmax_select_last_index) { - try - { - auto function = onnx_import::import_onnx_model( - file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_select_last_index.prototxt")); - FAIL() << "Expected exception was not thrown"; - } - catch (const ngraph::ngraph_error& e) - { - EXPECT_HAS_SUBSTRING( - e.what(), - std::string( - "Mode 'select_last_index=1' is not supported by current implementation of ArgMax")); - } - catch (...) - { - FAIL() << "Expected OnnxNodeValidationFailure exception was not thrown"; - } + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_select_last_index.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input(Shape{4, 3}, {1, 1, 1, 0.5, 3, 4, 0.5, 1, 1.1, 0, 3, 0}); + test_case.add_expected_output(Shape{1, 3}, {0, 3, 1}); + test_case.run(); } NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_select_last_index) { - try - { - auto function = onnx_import::import_onnx_model( - file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_select_last_index.prototxt")); - FAIL() << "Expected exception was not thrown"; - } - catch (const ngraph::ngraph_error& e) - { - EXPECT_HAS_SUBSTRING( - e.what(), - std::string( - "Mode 'select_last_index=1' is not supported by current implementation of ArgMin")); - std::string what{e.what()}; - } - catch (...) - { - FAIL() << "Expected OnnxNodeValidationFailure exception was not thrown"; - } + auto function = onnx_import::import_onnx_model( + file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_select_last_index.prototxt")); + + auto test_case = test::TestCase(function); + test_case.add_input(Shape{4, 3}, {1, 1, 1, 2, 3, 4, 2, 1, 1.1, 3, 3, 8}); + test_case.add_expected_output(Shape{4}, {2, 0, 1, 1}); + test_case.run(); } NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k) diff --git a/ngraph/test/runtime/ie/unit_test.manifest b/ngraph/test/runtime/ie/unit_test.manifest index 8c4ae1f893db0f..25537e25b398fb 100644 --- a/ngraph/test/runtime/ie/unit_test.manifest +++ b/ngraph/test/runtime/ie/unit_test.manifest @@ -131,6 +131,8 @@ arg_max_dyn_shape # Result mismatch onnx_model_argmax_float onnx_model_argmin_float +onnx_model_argmax_select_last_index +onnx_model_argmin_select_last_index # Constant has zero dimension that is not allowable onnx_dyn_shapes_transpose