diff --git a/docs/ops/sort/TopK_1.md b/docs/ops/sort/TopK_1.md index b1ad91b4b791f7..e3777356cb6951 100644 --- a/docs/ops/sort/TopK_1.md +++ b/docs/ops/sort/TopK_1.md @@ -59,7 +59,7 @@ So for each slice `input[i1, ...., i(axis-1), :, i(axis+1), ..., iN]` which repr Sorting and minimum/maximum are controlled by `sort` and `mode` attributes: * *mode*=`max`, *sort*=`value` - descending by value - * *mode*=`max`, *sort*=`index` - descending by index + * *mode*=`max`, *sort*=`index` - ascending by index * *mode*=`max`, *sort*=`none` - undefined * *mode*=`min`, *sort*=`value` - ascending by value * *mode*=`min`, *sort*=`index` - ascending by index diff --git a/docs/ops/sort/TopK_11.md b/docs/ops/sort/TopK_11.md index f96007704da53e..d770188f72cbba 100644 --- a/docs/ops/sort/TopK_11.md +++ b/docs/ops/sort/TopK_11.md @@ -31,7 +31,7 @@ * *stable* - * **Description**: Specifies whether the equivalent elements should maintain their relative order from the input tensor. Takes effect only if the `sort` attribute is set to `value`. + * **Description**: Specifies whether the equivalent elements should maintain their relative order from the input tensor. Takes effect only if the `sort` attribute is set to `value` or `index`. * **Range of values**: `true` of `false` * **Type**: `boolean` * **Default value**: `false` @@ -77,12 +77,41 @@ Sorting and minimum/maximum are controlled by `sort` and `mode` attributes with * *sort*=`value`, *mode*=`max`, *stable*=`true` - descending by value, relative order of equal elements guaranteed to be maintained * *sort*=`value`, *mode*=`min`, *stable*=`false` - ascending by value, relative order of equal elements not guaranteed to be maintained * *sort*=`value`, *mode*=`min`, *stable*=`true` - ascending by value, relative order of equal elements guaranteed to be maintained - * *sort*=`index`, *mode*=`max` - descending by index - * *sort*=`index`, *mode*=`min` - ascending by index + * *sort*=`index`, *mode*=`max`, *stable*=`false` - ascending by index, relative order of equal elements not guaranteed to be maintained + * *sort*=`index`, *mode*=`max`, *stable*=`true` - ascending by index, relative order of equal elements guaranteed to be maintained + * *sort*=`index`, *mode*=`min`, *stable*=`false` - ascending by index, relative order of equal elements not guaranteed to be maintained + * *sort*=`index`, *mode*=`min`, *stable*=`true` - ascending by index, relative order of equal elements guaranteed to be maintained * *sort*=`none` , *mode*=`max` - undefined * *sort*=`none` , *mode*=`min` - undefined The relative order of equivalent elements is only preserved if the *stable* attribute is set to `true`. This makes the implementation use stable sorting algorithm during the computation of TopK elements. Otherwise the output order is undefined. +The "by index" order means that the input tensor's elements are still sorted by value but their order in the output tensor is additionally determined by the indices of those elements in the input tensor. This might yield multiple correct results though. For example if the input tensor contains the following elements: + + input = [5, 3, 1, 2, 5, 5] + +and when TopK is configured the following way: + + mode = min + sort = index + k = 4 + +then the 3 following results are correct: + + output_values = [5, 3, 1, 2] + output_indices = [0, 1, 2, 3] + + output_values = [3, 1, 2, 5] + output_indices = [1, 2, 3, 4] + + output_values = [3, 1, 2, 5] + output_indices = [1, 2, 3, 5] + +When the `stable` attribute is additionally set to `true`, the example above will only have a single correct solution: + + output_values = [5, 3, 1, 2] + output_indices = [0, 1, 2, 3] + +The indices are always sorted ascendingly when `sort == index` for any given TopK node. Setting `sort == index` and `mode == max` means gthat the values are first sorted in the descending order but the indices which affect the order of output elements are sorted ascendingly. **Example** diff --git a/docs/ops/sort/TopK_3.md b/docs/ops/sort/TopK_3.md index 2ad37b24cfbb7d..e7f5c9ce41fc59 100644 --- a/docs/ops/sort/TopK_3.md +++ b/docs/ops/sort/TopK_3.md @@ -66,7 +66,7 @@ So for each slice `input[i1, ...., i(axis-1), :, i(axis+1), ..., iN]` which repr Sorting and minimum/maximum are controlled by `sort` and `mode` attributes: * *mode*=`max`, *sort*=`value` - descending by value - * *mode*=`max`, *sort*=`index` - descending by index + * *mode*=`max`, *sort*=`index` - ascending by index * *mode*=`max`, *sort*=`none` - undefined * *mode*=`min`, *sort*=`value` - ascending by value * *mode*=`min`, *sort*=`index` - ascending by index diff --git a/src/core/src/op/topk.cpp b/src/core/src/op/topk.cpp index 9bc85add54b3ac..a1555e24d126d5 100644 --- a/src/core/src/op/topk.cpp +++ b/src/core/src/op/topk.cpp @@ -323,11 +323,10 @@ void ov::op::v11::TopK::validate_and_infer_types() { OV_OP_SCOPE(v11_TopK_validate_and_infer_types); if (m_stable) { - NODE_VALIDATION_CHECK( - this, - m_sort == TopKSortType::SORT_VALUES, - "Stable sort can only be used when TopK's sorting mode is set to 'VALUE'. Current sorting mode = ", - AttributeAdapter(m_sort).get()); + NODE_VALIDATION_CHECK(this, + m_sort == TopKSortType::SORT_VALUES || m_sort == TopKSortType::SORT_INDICES, + "Stable sort can only be used when TopK's sorting mode is set to 'VALUE' or 'INDEX'.", + AttributeAdapter(m_sort).get()); } util::TopKBase::validate_and_infer_types(); diff --git a/src/core/tests/type_prop/top_k.cpp b/src/core/tests/type_prop/top_k.cpp index 9c718f4d7262e8..ed415b9653ee4f 100644 --- a/src/core/tests/type_prop/top_k.cpp +++ b/src/core/tests/type_prop/top_k.cpp @@ -420,20 +420,6 @@ TEST(type_prop, top_k_partial_value) { EXPECT_EQ(op->get_output_partial_shape(0), PartialShape({{0, 200}})); } -TEST(type_prop, topk_v11_stable_sort_by_indices) { - const auto data = std::make_shared(element::f32, Shape{2, 3, 4}); - const auto k = Constant::create(element::u32, Shape{}, {1}); - OV_EXPECT_THROW(const auto op = std::make_shared(data, - k, - 1, - op::TopKMode::MAX, - op::TopKSortType::SORT_INDICES, - element::i32, - true), - NodeValidationFailure, - HasSubstr("Stable sort can only be used when TopK's sorting mode is set to 'VALUE'")); -} - TEST(type_prop, topk_v11_stable_sort_by_none) { const auto data = std::make_shared(element::f32, Shape{2, 3, 4}); const auto k = Constant::create(element::u32, Shape{}, {1}); @@ -445,5 +431,5 @@ TEST(type_prop, topk_v11_stable_sort_by_none) { element::i64, true), NodeValidationFailure, - HasSubstr("Stable sort can only be used when TopK's sorting mode is set to 'VALUE'")); + HasSubstr("Stable sort can only be used when TopK's sorting mode is set to 'VALUE' or 'INDEX'")); }