Skip to content

Commit

Permalink
[ONNX] ArgMin/ArgMax support for select_last_index (#5661)
Browse files Browse the repository at this point in the history
* Add Reverse Op to opset

* Worksave with Reverse path

* Add last_index support

* refactor argminmax factory

* Remove old xfail, add new one

* Fix proto file for argmax

* Rewrite test for select_last_index

* Add CPU tests to Manifest

* Update manifest

* Remove Reverse from opset7

* Refactor arg_min_max factory

* Added example comment in arg_min_max

* Codestyle changes
  • Loading branch information
Jan Iwaszkiewicz authored May 25, 2021
1 parent 39c08e4 commit c3ca8d0
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 71 deletions.
7 changes: 0 additions & 7 deletions ngraph/frontend/onnx_import/src/op/argmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@ namespace ngraph
{
OutputVector argmax(const Node& node)
{
const auto select_last_index =
node.get_attribute_value<std::int64_t>("select_last_index", 0);
CHECK_VALID_NODE(node,
select_last_index == 0,
"Mode 'select_last_index=1' is not supported by current "
"implementation of ArgMax");

const utils::ArgMinMaxFactory arg_factory(node);
return {arg_factory.make_arg_max()};
}
Expand Down
7 changes: 0 additions & 7 deletions ngraph/frontend/onnx_import/src/op/argmin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@ namespace ngraph
{
OutputVector argmin(const Node& node)
{
const auto select_last_index =
node.get_attribute_value<std::int64_t>("select_last_index", 0);
CHECK_VALID_NODE(node,
select_last_index == 0,
"Mode 'select_last_index=1' is not supported by current "
"implementation of ArgMin");

const utils::ArgMinMaxFactory arg_factory(node);
return {arg_factory.make_arg_min()};
}
Expand Down
74 changes: 69 additions & 5 deletions ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "utils/arg_min_max_factory.hpp"
#include "default_opset.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/validation_util.hpp"

namespace ngraph
Expand All @@ -14,9 +15,11 @@ namespace ngraph
{
ArgMinMaxFactory::ArgMinMaxFactory(const Node& node)
: m_keep_dims{node.get_attribute_value<std::int64_t>("keepdims", 1)}
, m_input_node{node.get_ng_inputs().at(0)}
, m_axis{node.get_attribute_value<std::int64_t>("axis", 0)}
, m_select_last_index{
node.get_attribute_value<std::int64_t>("select_last_index", 0)}
{
m_input_node = node.get_ng_inputs().at(0);
}

std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_arg_max() const
Expand All @@ -34,19 +37,80 @@ namespace ngraph
{
const auto k_node =
default_opset::Constant::create(ngraph::element::i64, Shape{}, {1});

if (m_select_last_index == 1)
{
// Example (ArgMin):
// The goal is to get the index of the last occurence of the
// minimum value present in given input tensor.
//
// Input: [1, 2, 1, 3, 4, 4]
// Expected output: [2]
//
// Top-K is always returning the "most-left" result. The trick is to
// reverse input to find the "most-right" occurence which is equal to
// the last occurence in the original input.
// reverse = [4, 4, 3, 1, 2, 1]
//
// Run TopK on reversed tensor, in the example output with index values
// is equal to:
// topk->output(1) = 3
//
// Using ShapeOf and Gather on input obtain length of the input tensor
// along axis, in the example this is equal to:
// dims_on_axis = 6
//
// Now using two Substract ops calculate resulting index:
// res_index = dims_on_axis - topk->output(1) = 6 - 3 = 3
// result = res_index - 1 = 3 - 1 = 2

const auto axis_node =
default_opset::Constant::create(ngraph::element::i64, Shape{1}, {m_axis});
const auto reverse = std::make_shared<opset1::Reverse>(
m_input_node, axis_node, opset1::Reverse::Mode::INDEX);

const auto topk = std::make_shared<default_opset::TopK>(
reverse, k_node, m_axis, mode, default_opset::TopK::SortType::NONE);

const auto data_shape = std::make_shared<default_opset::ShapeOf>(m_input_node);
const auto dims_on_axis = std::make_shared<default_opset::Gather>(
data_shape,
axis_node,
default_opset::Constant::create(ngraph::element::i64, Shape{}, {0}));

const auto res_index = std::make_shared<default_opset::Subtract>(
dims_on_axis,
std::make_shared<default_opset::Convert>(topk->output(1), element::i64));
const auto result = std::make_shared<default_opset::Subtract>(
res_index,
default_opset::Constant::create(ngraph::element::i64, Shape{1}, {1}));

if (m_keep_dims == 0)
{
const auto axis_to_remove = default_opset::Constant::create(
element::u64, Shape{}, {topk->get_axis()});

return std::make_shared<default_opset::Squeeze>(result, axis_to_remove);
}

return result;
}

const auto topk = std::make_shared<default_opset::TopK>(
m_input_node, k_node, m_axis, mode, default_opset::TopK::SortType::NONE);

const auto result =
std::make_shared<default_opset::Convert>(topk->output(1), element::i64);

if (m_keep_dims == 0)
{
const auto axis_to_remove =
default_opset::Constant::create(element::u64, Shape{}, {topk->get_axis()});
const auto reshaped_indices =
std::make_shared<default_opset::Squeeze>(topk->output(1), axis_to_remove);

return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64);
return std::make_shared<default_opset::Squeeze>(result, axis_to_remove);
}
return std::make_shared<default_opset::Convert>(topk->output(1), element::i64);

return result;
}
} // namespace utils
} // namespace onnx_import
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace ngraph
const std::int64_t m_keep_dims;
Output<ngraph::Node> m_input_node;
std::int64_t m_axis;
std::int64_t m_select_last_index;
};

} // namespace utils
Expand Down
2 changes: 1 addition & 1 deletion ngraph/python/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True):
xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:"
"FakeQuantize_xxx has non const input on 1 port")
xfail_issue_46762 = xfail_test(reason="Incorrect result of Minimum op if uint data type is used")
xfail_issue_46765 = xfail_test(reason="select_last_index attribute is not supported by ArgMin and ArgMax")
xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")
xfail_issue_51993 = xfail_test(reason="PRelu supports only 1D tensor for 'slope' input broadcasted"
"by channel")
xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape")

# Model MSFT issues:
xfail_issue_37957 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
Expand Down
20 changes: 4 additions & 16 deletions ngraph/python/tests/test_onnx/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
xfail_issue_45180,
xfail_issue_45344,
xfail_issue_46762,
xfail_issue_46765,
xfail_issue_47323,
xfail_issue_47337,
xfail_issue_48052,
Expand All @@ -60,7 +59,8 @@
xfail_issue_49753,
xfail_issue_49754,
xfail_issue_52463,
xfail_issue_51993)
xfail_issue_51993,
xfail_issue_55760)


def expect_fail(test_case_path, xfail): # type: (str) -> None
Expand Down Expand Up @@ -165,23 +165,11 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None
"OnnxBackendNodeModelTest.test_min_uint16_cpu",
"OnnxBackendNodeModelTest.test_min_uint32_cpu",
"OnnxBackendNodeModelTest.test_min_uint64_cpu"),
(xfail_issue_46765,
(xfail_issue_55760,
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_no_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_no_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_default_axis_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_default_axis_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_no_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_default_axis_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_default_axis_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_no_keepdims_random_select_last_index_cpu"),
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu"),
(xfail_issue_38091,
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"),
(xfail_issue_52463,
Expand Down
3 changes: 3 additions & 0 deletions ngraph/test/models/onnx/argmax_select_last_index.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ graph {
dim {
dim_value: 3
}
dim {
dim_value: 1
}
}
}
}
Expand Down
49 changes: 14 additions & 35 deletions ngraph/test/onnx/onnx_import.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2503,45 +2503,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_float)

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmax_select_last_index)
{
try
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_select_last_index.prototxt"));
FAIL() << "Expected exception was not thrown";
}
catch (const ngraph::ngraph_error& e)
{
EXPECT_HAS_SUBSTRING(
e.what(),
std::string(
"Mode 'select_last_index=1' is not supported by current implementation of ArgMax"));
}
catch (...)
{
FAIL() << "Expected OnnxNodeValidationFailure exception was not thrown";
}
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_select_last_index.prototxt"));

auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{4, 3}, {1, 1, 1, 0.5, 3, 4, 0.5, 1, 1.1, 0, 3, 0});
test_case.add_expected_output<std::int64_t>(Shape{1, 3}, {0, 3, 1});
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_select_last_index)
{
try
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_select_last_index.prototxt"));
FAIL() << "Expected exception was not thrown";
}
catch (const ngraph::ngraph_error& e)
{
EXPECT_HAS_SUBSTRING(
e.what(),
std::string(
"Mode 'select_last_index=1' is not supported by current implementation of ArgMin"));
std::string what{e.what()};
}
catch (...)
{
FAIL() << "Expected OnnxNodeValidationFailure exception was not thrown";
}
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_select_last_index.prototxt"));

auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{4, 3}, {1, 1, 1, 2, 3, 4, 2, 1, 1.1, 3, 3, 8});
test_case.add_expected_output<std::int64_t>(Shape{4}, {2, 0, 1, 1});
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k)
Expand Down
2 changes: 2 additions & 0 deletions ngraph/test/runtime/ie/unit_test.manifest
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ arg_max_dyn_shape
# Result mismatch
onnx_model_argmax_float
onnx_model_argmin_float
onnx_model_argmax_select_last_index
onnx_model_argmin_select_last_index

# Constant has zero dimension that is not allowable
onnx_dyn_shapes_transpose
Expand Down

0 comments on commit c3ca8d0

Please sign in to comment.