
[ONNX] ArgMin/ArgMax support for select_last_index #5661

Merged · 14 commits · May 25, 2021
7 changes: 0 additions & 7 deletions ngraph/frontend/onnx_import/src/op/argmax.cpp
@@ -26,13 +26,6 @@ namespace ngraph
{
OutputVector argmax(const Node& node)
{
const auto select_last_index =
node.get_attribute_value<std::int64_t>("select_last_index", 0);
CHECK_VALID_NODE(node,
select_last_index == 0,
"Mode 'select_last_index=1' is not supported by current "
"implementation of ArgMax");

const utils::ArgMinMaxFactory arg_factory(node);
return {arg_factory.make_arg_max()};
}
7 changes: 0 additions & 7 deletions ngraph/frontend/onnx_import/src/op/argmin.cpp
@@ -26,13 +26,6 @@ namespace ngraph
{
OutputVector argmin(const Node& node)
{
const auto select_last_index =
node.get_attribute_value<std::int64_t>("select_last_index", 0);
CHECK_VALID_NODE(node,
select_last_index == 0,
"Mode 'select_last_index=1' is not supported by current "
"implementation of ArgMin");

const utils::ArgMinMaxFactory arg_factory(node);
return {arg_factory.make_arg_min()};
}
74 changes: 69 additions & 5 deletions ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.cpp
@@ -4,6 +4,7 @@

#include "utils/arg_min_max_factory.hpp"
#include "default_opset.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/validation_util.hpp"

namespace ngraph
@@ -14,9 +15,11 @@ namespace ngraph
{
ArgMinMaxFactory::ArgMinMaxFactory(const Node& node)
: m_keep_dims{node.get_attribute_value<std::int64_t>("keepdims", 1)}
, m_input_node{node.get_ng_inputs().at(0)}
, m_axis{node.get_attribute_value<std::int64_t>("axis", 0)}
, m_select_last_index{
node.get_attribute_value<std::int64_t>("select_last_index", 0)}
{
m_input_node = node.get_ng_inputs().at(0);
}

std::shared_ptr<ngraph::Node> ArgMinMaxFactory::make_arg_max() const
@@ -34,19 +37,80 @@
{
const auto k_node =
default_opset::Constant::create(ngraph::element::i64, Shape{}, {1});

if (m_select_last_index == 1)
{
// Example (ArgMin):
// The goal is to get the index of the last occurrence of the
// minimum value present in the given input tensor.
//
// Input: [1, 2, 1, 3, 4, 4]
// Expected output: [2]
//
// TopK always returns the left-most result. The trick is to
// reverse the input so that the right-most occurrence, which equals
// the last occurrence in the original input, becomes the left-most one.
// reverse = [4, 4, 3, 1, 2, 1]
//
// Run TopK on the reversed tensor; in the example its index output
// equals:
// topk->output(1) = 3
//
// Using ShapeOf and Gather on the input, obtain the length of the
// input tensor along the axis; in the example this equals:
// dims_on_axis = 6
//
// Finally, two Subtract ops compute the resulting index:
// res_index = dims_on_axis - topk->output(1) = 6 - 3 = 3
// result = res_index - 1 = 3 - 1 = 2

const auto axis_node =
default_opset::Constant::create(ngraph::element::i64, Shape{1}, {m_axis});
const auto reverse = std::make_shared<opset1::Reverse>(
m_input_node, axis_node, opset1::Reverse::Mode::INDEX);

const auto topk = std::make_shared<default_opset::TopK>(
reverse, k_node, m_axis, mode, default_opset::TopK::SortType::NONE);

const auto data_shape = std::make_shared<default_opset::ShapeOf>(m_input_node);
const auto dims_on_axis = std::make_shared<default_opset::Gather>(
data_shape,
axis_node,
default_opset::Constant::create(ngraph::element::i64, Shape{}, {0}));

const auto res_index = std::make_shared<default_opset::Subtract>(
dims_on_axis,
std::make_shared<default_opset::Convert>(topk->output(1), element::i64));
const auto result = std::make_shared<default_opset::Subtract>(
res_index,
default_opset::Constant::create(ngraph::element::i64, Shape{1}, {1}));

if (m_keep_dims == 0)
{
const auto axis_to_remove = default_opset::Constant::create(
element::u64, Shape{}, {topk->get_axis()});

return std::make_shared<default_opset::Squeeze>(result, axis_to_remove);
}

return result;
}

const auto topk = std::make_shared<default_opset::TopK>(
m_input_node, k_node, m_axis, mode, default_opset::TopK::SortType::NONE);

const auto result =
std::make_shared<default_opset::Convert>(topk->output(1), element::i64);

if (m_keep_dims == 0)
{
const auto axis_to_remove =
default_opset::Constant::create(element::u64, Shape{}, {topk->get_axis()});
const auto reshaped_indices =
std::make_shared<default_opset::Squeeze>(topk->output(1), axis_to_remove);

return std::make_shared<default_opset::Convert>(reshaped_indices, element::i64);
return std::make_shared<default_opset::Squeeze>(result, axis_to_remove);
}
return std::make_shared<default_opset::Convert>(topk->output(1), element::i64);

return result;
}
} // namespace utils
} // namespace onnx_import
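For reference, the index arithmetic described in the comment above can be checked outside the graph. The snippet below is a minimal plain C++ sketch of the same mapping (result = dims_on_axis - reversed_index - 1), using std::min_element over a reversed range in place of the nGraph Reverse/TopK nodes; it is only an illustration of the trick, not the nGraph implementation.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// Last occurrence of the minimum: search the reversed sequence (the left-most
// hit there is the right-most hit in the original order), then map the index back.
std::int64_t argmin_last(const std::vector<float>& data)
{
    const auto rev_min = std::min_element(data.rbegin(), data.rend());
    const std::int64_t rev_index = std::distance(data.rbegin(), rev_min);
    // res_index = dims_on_axis - reversed index; result = res_index - 1
    return static_cast<std::int64_t>(data.size()) - rev_index - 1;
}

int main()
{
    const std::vector<float> input{1, 2, 1, 3, 4, 4};
    std::cout << argmin_last(input) << '\n'; // prints 2, as in the comment above
}
```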
ngraph/frontend/onnx_import/src/utils/arg_min_max_factory.hpp
@@ -39,6 +39,7 @@ namespace ngraph
const std::int64_t m_keep_dims;
Output<ngraph::Node> m_input_node;
std::int64_t m_axis;
std::int64_t m_select_last_index;
};

} // namespace utils
2 changes: 1 addition & 1 deletion ngraph/python/tests/__init__.py
@@ -113,12 +113,12 @@ def xfail_test(reason="Mark the test as expected to fail", strict=True):
xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:"
"FakeQuantize_xxx has non const input on 1 port")
xfail_issue_46762 = xfail_test(reason="Incorrect result of Minimum op if uint data type is used")
xfail_issue_46765 = xfail_test(reason="select_last_index attribute is not supported by ArgMin and ArgMax")
xfail_issue_47323 = xfail_test(reason="RuntimeError: The plugin does not support FP64")
xfail_issue_47337 = xfail_test(reason="RuntimeError: Unsupported dynamic ops: v1::OneHot")
xfail_issue_33593 = xfail_test(reason="Current implementation of MaxPool doesn't support indices output")
xfail_issue_51993 = xfail_test(reason="PRelu supports only 1D tensor for 'slope' input broadcasted"
"by channel")
xfail_issue_55760 = xfail_test(reason="RuntimeError: Reversed axis have axes above the source space shape")

# Model MSFT issues:
xfail_issue_37957 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:"
20 changes: 4 additions & 16 deletions ngraph/python/tests/test_onnx/test_backend.py
@@ -50,7 +50,6 @@
xfail_issue_45180,
xfail_issue_45344,
xfail_issue_46762,
xfail_issue_46765,
xfail_issue_47323,
xfail_issue_47337,
xfail_issue_48052,
@@ -60,7 +59,8 @@
xfail_issue_49753,
xfail_issue_49754,
xfail_issue_52463,
xfail_issue_51993)
xfail_issue_51993,
xfail_issue_55760)


def expect_fail(test_case_path, xfail): # type: (str) -> None
@@ -165,23 +165,11 @@ def expect_fail(test_case_path, xfail): # type: (str) -> None
"OnnxBackendNodeModelTest.test_min_uint16_cpu",
"OnnxBackendNodeModelTest.test_min_uint32_cpu",
"OnnxBackendNodeModelTest.test_min_uint64_cpu"),
(xfail_issue_46765,
(xfail_issue_55760,
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_no_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_no_keepdims_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_default_axis_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_default_axis_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_negative_axis_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmax_no_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_default_axis_example_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_default_axis_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu",
"OnnxBackendNodeModelTest.test_argmin_no_keepdims_random_select_last_index_cpu"),
"OnnxBackendNodeModelTest.test_argmin_negative_axis_keepdims_random_select_last_index_cpu"),
(xfail_issue_38091,
"OnnxBackendNodeModelTest.test_gather_negative_indices_cpu"),
(xfail_issue_52463,
3 changes: 3 additions & 0 deletions ngraph/test/models/onnx/argmax_select_last_index.prototxt
@@ -50,6 +50,9 @@ graph {
dim {
dim_value: 3
}
dim {
dim_value: 1
}
}
}
}
49 changes: 14 additions & 35 deletions ngraph/test/onnx/onnx_import.in.cpp
@@ -2503,45 +2503,24 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_float)

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmax_select_last_index)
{
try
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_select_last_index.prototxt"));
FAIL() << "Expected exception was not thrown";
}
catch (const ngraph::ngraph_error& e)
{
EXPECT_HAS_SUBSTRING(
e.what(),
std::string(
"Mode 'select_last_index=1' is not supported by current implementation of ArgMax"));
}
catch (...)
{
FAIL() << "Expected OnnxNodeValidationFailure exception was not thrown";
}
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmax_select_last_index.prototxt"));

auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{4, 3}, {1, 1, 1, 0.5, 3, 4, 0.5, 1, 1.1, 0, 3, 0});
test_case.add_expected_output<std::int64_t>(Shape{1, 3}, {0, 3, 1});
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_argmin_select_last_index)
{
try
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_select_last_index.prototxt"));
FAIL() << "Expected exception was not thrown";
}
catch (const ngraph::ngraph_error& e)
{
EXPECT_HAS_SUBSTRING(
e.what(),
std::string(
"Mode 'select_last_index=1' is not supported by current implementation of ArgMin"));
}
catch (...)
{
FAIL() << "Expected OnnxNodeValidationFailure exception was not thrown";
}
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/argmin_select_last_index.prototxt"));

auto test_case = test::TestCase<TestEngine>(function);
test_case.add_input<float>(Shape{4, 3}, {1, 1, 1, 2, 3, 4, 2, 1, 1.1, 3, 3, 8});
test_case.add_expected_output<std::int64_t>(Shape{4}, {2, 0, 1, 1});
test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, onnx_model_top_k)
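For reference, the expected outputs in the two tests above can be reproduced with a small standalone check. The reduction axes (0 for ArgMax, 1 for ArgMin) are inferred here from the expected output shapes rather than read from the .prototxt models, so treat them as an assumption; the loops simply keep the last index at which the extreme value occurs.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    // ArgMax over axis 0 of the 4x3 input, keeping the last index of the maximum.
    const std::vector<std::vector<float>> argmax_in{
        {1.f, 1.f, 1.f}, {0.5f, 3.f, 4.f}, {0.5f, 1.f, 1.1f}, {0.f, 3.f, 0.f}};
    for (std::size_t col = 0; col < 3; ++col)
    {
        std::size_t best = 0;
        for (std::size_t row = 0; row < argmax_in.size(); ++row)
            if (argmax_in[row][col] >= argmax_in[best][col]) // ">=" keeps the last occurrence
                best = row;
        std::cout << best << ' '; // prints: 0 3 1
    }
    std::cout << '\n';

    // ArgMin over axis 1 of the 4x3 input, keeping the last index of the minimum.
    const std::vector<std::vector<float>> argmin_in{
        {1.f, 1.f, 1.f}, {2.f, 3.f, 4.f}, {2.f, 1.f, 1.1f}, {3.f, 3.f, 8.f}};
    for (const auto& row : argmin_in)
    {
        std::size_t best = 0;
        for (std::size_t col = 0; col < row.size(); ++col)
            if (row[col] <= row[best]) // "<=" keeps the last occurrence
                best = col;
        std::cout << best << ' '; // prints: 2 0 1 1
    }
    std::cout << '\n';
}
```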
2 changes: 2 additions & 0 deletions ngraph/test/runtime/ie/unit_test.manifest
@@ -131,6 +131,8 @@ arg_max_dyn_shape
# Result mismatch
onnx_model_argmax_float
onnx_model_argmin_float
onnx_model_argmax_select_last_index
onnx_model_argmin_select_last_index

# Constant has zero dimension that is not allowable
onnx_dyn_shapes_transpose