Skip to content

Commit

Permalink
Allow stable sort in TopK when sorting by indices (openvinotoolkit#16811
Browse files Browse the repository at this point in the history
)

* Allow stable sort in TopK when sorting by indices

* Clarification of stable sorting by index and unblocked test

* XFAIL the test again

* Clarification of sorting by indices

* Revert of changes in previous versions op TopK (spec)
  • Loading branch information
Tomasz Dołbniak authored Apr 13, 2023
1 parent 9c6d287 commit dcf6fb1
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/ops/sort/TopK_1.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ So for each slice `input[i1, ...., i(axis-1), :, i(axis+1), ..., iN]` which repr

Sorting and minimum/maximum are controlled by `sort` and `mode` attributes:
* *mode*=`max`, *sort*=`value` - descending by value
* *mode*=`max`, *sort*=`index` - descending by index
* *mode*=`max`, *sort*=`index` - ascending by index
* *mode*=`max`, *sort*=`none` - undefined
* *mode*=`min`, *sort*=`value` - ascending by value
* *mode*=`min`, *sort*=`index` - ascending by index
Expand Down
35 changes: 32 additions & 3 deletions docs/ops/sort/TopK_11.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

* *stable*

* **Description**: Specifies whether the equivalent elements should maintain their relative order from the input tensor. Takes effect only if the `sort` attribute is set to `value`.
* **Description**: Specifies whether the equivalent elements should maintain their relative order from the input tensor. Takes effect only if the `sort` attribute is set to `value` or `index`.
* **Range of values**: `true` of `false`
* **Type**: `boolean`
* **Default value**: `false`
Expand Down Expand Up @@ -77,12 +77,41 @@ Sorting and minimum/maximum are controlled by `sort` and `mode` attributes with
* *sort*=`value`, *mode*=`max`, *stable*=`true` - descending by value, relative order of equal elements guaranteed to be maintained
* *sort*=`value`, *mode*=`min`, *stable*=`false` - ascending by value, relative order of equal elements not guaranteed to be maintained
* *sort*=`value`, *mode*=`min`, *stable*=`true` - ascending by value, relative order of equal elements guaranteed to be maintained
* *sort*=`index`, *mode*=`max` - descending by index
* *sort*=`index`, *mode*=`min` - ascending by index
* *sort*=`index`, *mode*=`max`, *stable*=`false` - ascending by index, relative order of equal elements not guaranteed to be maintained
* *sort*=`index`, *mode*=`max`, *stable*=`true` - ascending by index, relative order of equal elements guaranteed to be maintained
* *sort*=`index`, *mode*=`min`, *stable*=`false` - ascending by index, relative order of equal elements not guaranteed to be maintained
* *sort*=`index`, *mode*=`min`, *stable*=`true` - ascending by index, relative order of equal elements guaranteed to be maintained
* *sort*=`none` , *mode*=`max` - undefined
* *sort*=`none` , *mode*=`min` - undefined

The relative order of equivalent elements is only preserved if the *stable* attribute is set to `true`. This makes the implementation use stable sorting algorithm during the computation of TopK elements. Otherwise the output order is undefined.
The "by index" order means that the input tensor's elements are still sorted by value but their order in the output tensor is additionally determined by the indices of those elements in the input tensor. This might yield multiple correct results though. For example if the input tensor contains the following elements:

input = [5, 3, 1, 2, 5, 5]

and when TopK is configured the following way:

mode = min
sort = index
k = 4

then the 3 following results are correct:

output_values = [5, 3, 1, 2]
output_indices = [0, 1, 2, 3]

output_values = [3, 1, 2, 5]
output_indices = [1, 2, 3, 4]

output_values = [3, 1, 2, 5]
output_indices = [1, 2, 3, 5]

When the `stable` attribute is additionally set to `true`, the example above will only have a single correct solution:

output_values = [5, 3, 1, 2]
output_indices = [0, 1, 2, 3]

The indices are always sorted ascendingly when `sort == index` for any given TopK node. Setting `sort == index` and `mode == max` means gthat the values are first sorted in the descending order but the indices which affect the order of output elements are sorted ascendingly.

**Example**

Expand Down
2 changes: 1 addition & 1 deletion docs/ops/sort/TopK_3.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ So for each slice `input[i1, ...., i(axis-1), :, i(axis+1), ..., iN]` which repr

Sorting and minimum/maximum are controlled by `sort` and `mode` attributes:
* *mode*=`max`, *sort*=`value` - descending by value
* *mode*=`max`, *sort*=`index` - descending by index
* *mode*=`max`, *sort*=`index` - ascending by index
* *mode*=`max`, *sort*=`none` - undefined
* *mode*=`min`, *sort*=`value` - ascending by value
* *mode*=`min`, *sort*=`index` - ascending by index
Expand Down
9 changes: 4 additions & 5 deletions src/core/src/op/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,10 @@ void ov::op::v11::TopK::validate_and_infer_types() {
OV_OP_SCOPE(v11_TopK_validate_and_infer_types);

if (m_stable) {
NODE_VALIDATION_CHECK(
this,
m_sort == TopKSortType::SORT_VALUES,
"Stable sort can only be used when TopK's sorting mode is set to 'VALUE'. Current sorting mode = ",
AttributeAdapter<TopKSortType>(m_sort).get());
NODE_VALIDATION_CHECK(this,
m_sort == TopKSortType::SORT_VALUES || m_sort == TopKSortType::SORT_INDICES,
"Stable sort can only be used when TopK's sorting mode is set to 'VALUE' or 'INDEX'.",
AttributeAdapter<TopKSortType>(m_sort).get());
}

util::TopKBase::validate_and_infer_types();
Expand Down
16 changes: 1 addition & 15 deletions src/core/tests/type_prop/top_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,20 +420,6 @@ TEST(type_prop, top_k_partial_value) {
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({{0, 200}}));
}

TEST(type_prop, topk_v11_stable_sort_by_indices) {
const auto data = std::make_shared<Parameter>(element::f32, Shape{2, 3, 4});
const auto k = Constant::create(element::u32, Shape{}, {1});
OV_EXPECT_THROW(const auto op = std::make_shared<ov::op::v11::TopK>(data,
k,
1,
op::TopKMode::MAX,
op::TopKSortType::SORT_INDICES,
element::i32,
true),
NodeValidationFailure,
HasSubstr("Stable sort can only be used when TopK's sorting mode is set to 'VALUE'"));
}

TEST(type_prop, topk_v11_stable_sort_by_none) {
const auto data = std::make_shared<Parameter>(element::f32, Shape{2, 3, 4});
const auto k = Constant::create(element::u32, Shape{}, {1});
Expand All @@ -445,5 +431,5 @@ TEST(type_prop, topk_v11_stable_sort_by_none) {
element::i64,
true),
NodeValidationFailure,
HasSubstr("Stable sort can only be used when TopK's sorting mode is set to 'VALUE'"));
HasSubstr("Stable sort can only be used when TopK's sorting mode is set to 'VALUE' or 'INDEX'"));
}

0 comments on commit dcf6fb1

Please sign in to comment.