From 231882a26ad72cea5266dc2aceaf2342e2118b80 Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Wed, 4 Jan 2023 09:40:26 -0800 Subject: [PATCH] [Contrib][Sort] Faster Top-K Implementation (#13599) This is a simple rewrite of hand-coded top-k function used for CPU targets. The old implementation sorted each axis and then took the biggest k elements. The new implementation does a single pass of each axis, keeping a min heap to store the top-k elements up to that point. If n is the size of the array, and we want to find top k, the old implementation has runtime in O(nlogn) with additional memory O(n) to store the sorted array. The new implementation is O(n log k), and in practice is probably amortized to O(n / k * log k) in many scenarios and only requires O(k). Note n >> k most of the time. In practice this new kernel led to a 20x speedup over existing one. On a Xeon Platinum 8370C CPU @ 2.80GHz for input shape [1, 3050] with k = 15, the latency went from 200us --> ~10us. There is probably more room for shaving off a little more time on the scale of a single us's, however I have determined it to not be worth it. --- src/runtime/contrib/sort/sort.cc | 91 +++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 24 deletions(-) diff --git a/src/runtime/contrib/sort/sort.cc b/src/runtime/contrib/sort/sort.cc index 8ea2f4b60cdf..bfb174a9206e 100644 --- a/src/runtime/contrib/sort/sort.cc +++ b/src/runtime/contrib/sort/sort.cc @@ -34,13 +34,25 @@ namespace contrib { using namespace runtime; -template +template bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { + if constexpr (stable_comparison) { + if (lhs.second == rhs.second) { + return lhs.first < rhs.first; + } + } + return lhs.second < rhs.second; } -template +template bool CompareDescend(const std::pair& lhs, const std::pair& rhs) { + if constexpr (stable_comparison) { + if (lhs.second == rhs.second) { + return lhs.first < rhs.first; + } + } + return lhs.second > rhs.second; } @@ -49,18 +61,14 @@ struct float16 { float to_float() const { return __extendXfYf2__(bits); } -}; -template <> -bool CompareAscend(const std::pair& lhs, const std::pair& rhs) { - return lhs.second.to_float() < rhs.second.to_float(); -} - -template <> -bool CompareDescend(const std::pair& lhs, - const std::pair& rhs) { - return lhs.second.to_float() > rhs.second.to_float(); -} + inline bool operator==(const float16& rhs) const { return to_float() == rhs.to_float(); } + inline bool operator!=(const float16& rhs) const { return to_float() != rhs.to_float(); } + inline bool operator<(const float16& rhs) const { return to_float() < rhs.to_float(); } + inline bool operator>(const float16& rhs) const { return to_float() > rhs.to_float(); } + inline bool operator<=(const float16& rhs) const { return to_float() <= rhs.to_float(); } + inline bool operator>=(const float16& rhs) const { return to_float() >= rhs.to_float(); } +}; // Argsort implemented C library sort for nms. // Return indices of sorted tensor. @@ -346,7 +354,12 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i (out_values == nullptr) ? nullptr : static_cast(out_values->data); IndicesType* indices_ptr = (out_indices == nullptr) ? nullptr : static_cast(out_indices->data); - std::vector> sorter; + + // Maintain a min/max containing the top-k elements + std::vector> running_heap; + + // Need +1 when inserting new element before maintaining heap invariant + running_heap.reserve(k + 1); int axis_mul_before = 1; int axis_mul_after = 1; @@ -363,26 +376,56 @@ void topk(DLTensor* input, DLTensor* out_values, DLTensor* out_indices, int k, i for (int i = 0; i < axis_mul_before; ++i) { for (int j = 0; j < axis_mul_after; ++j) { - sorter.clear(); + running_heap.clear(); int64_t src_base_idx = i * input->shape[axis] * axis_mul_after + j; int64_t dst_base_idx = i * k * axis_mul_after + j; - for (int64_t kk = 0; kk < input->shape[axis]; ++kk) { - int64_t full_idx = src_base_idx + kk * axis_mul_after; - sorter.emplace_back(std::make_pair(kk, data_ptr[full_idx])); + + // Start by creating min/max heap with fixed-k elements + int cur_axis_index = 0; + for (; cur_axis_index < k && cur_axis_index < input->shape[axis]; cur_axis_index++) { + int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; + running_heap.emplace_back(std::make_pair(cur_axis_index, data_ptr[full_idx])); + } + if (!is_ascend) { + std::make_heap(running_heap.begin(), running_heap.end(), CompareDescend); + } else { + std::make_heap(running_heap.begin(), running_heap.end(), CompareAscend); + } + + // Iterate through all elements, adding to heap along the way + for (; cur_axis_index < input->shape[axis]; cur_axis_index++) { + int64_t full_idx = src_base_idx + cur_axis_index * axis_mul_after; + std::pair cur_val = {cur_axis_index, data_ptr[full_idx]}; + + // Eq. to cur_val.second > running_heap.second + if (!is_ascend && CompareDescend(cur_val, running_heap[0])) { + running_heap.push_back(cur_val); + std::push_heap(running_heap.begin(), running_heap.end(), CompareDescend); + std::pop_heap(running_heap.begin(), running_heap.end(), CompareDescend); + running_heap.pop_back(); + } else if (is_ascend && CompareAscend(cur_val, running_heap[0])) { + running_heap.push_back(cur_val); + std::push_heap(running_heap.begin(), running_heap.end(), CompareAscend); + std::pop_heap(running_heap.begin(), running_heap.end(), CompareAscend); + running_heap.pop_back(); + } } + + // finally sort heap and deliver results if (is_ascend) { - std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + std::stable_sort(running_heap.begin(), running_heap.end(), CompareAscend); } else { - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + std::stable_sort(running_heap.begin(), running_heap.end(), CompareDescend); } - int64_t cnt = k > 0 ? k : input->shape[axis]; - for (int64_t kk = 0; kk < cnt; ++kk) { + + for (uint32_t kk = 0; kk < running_heap.size(); ++kk) { if (indices_ptr != nullptr) { indices_ptr[dst_base_idx + kk * axis_mul_after] = - static_cast(sorter[kk].first); + static_cast(running_heap[kk].first); } if (values_ptr != nullptr) { - values_ptr[dst_base_idx + kk * axis_mul_after] = static_cast(sorter[kk].second); + values_ptr[dst_base_idx + kk * axis_mul_after] = + static_cast(running_heap[kk].second); } } }