diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 8ea2f4b60cdf..bfb174a9206e 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -34,13 +34,25 @@ namespace contrib { using namespace runtime; -template +template bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { + if constexpr (stable_comparison) { + if (lhs.second == rhs.second) { + return lhs.first < rhs.first; + } + } + return lhs.second < rhs.second; } -template +template bool CompareDescend(const std::pair& lhs, const std::pair& rhs) { + if constexpr (stable_comparison) { + if (lhs.second == rhs.second) { + return lhs.first < rhs.first; + } + } + return lhs.second > rhs.second; } @@ -49,18 +61,14 @@ struct float16 { float to_float() const { return __extendXfYf2__(bits); } -}; -template <> -bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { - return lhs.second.to_float() < rhs.second.to_float(); -} - -template <> -bool CompareDescend(const std::pair& lhs, - const std::pair& rhs) { - return lhs.second.to_float() > rhs.second.to_float(); -} + inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); } + inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); } + inline bool operator<(const float16& rhs) const { return to_float() < rhs.to_float(); } + inline bool operator>(const float16& rhs) const { return to_float() > rhs.to_float(); } + inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); } + inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); } +}; // Argsort implemented C library sort for nms. // Return indices of sorted tensor. @@ -346,7 +354,12 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i (out_values == nullptr) ? nullptr : static_cast(out_values->data); IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : static_cast(out_indices->data); - std::vector> sorter; + + // Maintain a min/max containing the top-k elements + std::vector> running_heap; + + // Need +1 when inserting new element before maintaining heap invariant + running_heap.reserve(k + 1); int axis_mul_before = 1; int axis_mul_after = 1; @@ -363,26 +376,56 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i for (int i = 0; i < axis_mul_before; ++i) { for (int j = 0; j < axis_mul_after; ++j) { - sorter.clear(); + running_heap.clear(); int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t dst_base_idx = i * k * axis_mul_after + j; - for (int64_t kk = 0; kk < input->shape[axis]; ++kk) { - int64_t full_idx = src_base_idx + kk * axis_mul_after; - sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx])); + + // Start by creating min/max heap with fixed-k elements + int cur_axis_index = 0; + for (; cur_axis_index < k && cur_axis_index < input->shape[axis]; cur_axis_index++) { + int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; + running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx])); + } + if (!is_ascend) { + std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend); + } else { + std::make_heap(running_heap.begin(), running_heap.end(), CompareAscend); + } + + // Iterate through all elements, adding to heap along the way + for (; cur_axis_index < input->shape[axis]; cur_axis_index++) { + int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; + std::pair cur_val = {cur_axis_index, data_ptr[full_idx]}; + + // Eq. to cur_val.second > running_heap.second + if (!is_ascend && CompareDescend(cur_val, running_heap[0])) { + running_heap.push_back(cur_val); + std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend); + std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend); + running_heap.pop_back(); + } else if (is_ascend && CompareAscend(cur_val, running_heap[0])) { + running_heap.push_back(cur_val); + std::push_heap(running_heap.begin(), running_heap.end(), CompareAscend); + std::pop_heap(running_heap.begin(), running_heap.end(), CompareAscend); + running_heap.pop_back(); + } } + + // finally sort heap and deliver results if (is_ascend) { - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend); } else { - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend); } - int64_t cnt = k > 0 ? k : input->shape[axis]; - for (int64_t kk = 0; kk < cnt; ++kk) { + + for (uint32_t kk = 0; kk < running_heap.size(); ++kk) { if (indices_ptr != nullptr) { indices_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].first); + static_cast(running_heap[kk].first); } if (values_ptr != nullptr) { - values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast(sorter[kk].second); + values_ptr[dst_base_idx + kk * axis_mul_after] = + static_cast(running_heap[kk].second); } } }