diff --git a/library/alloc/src/collections/binary_heap.rs b/library/alloc/src/collections/binary_heap.rs index bf9f7432fb536..a201af0103070 100644 --- a/library/alloc/src/collections/binary_heap.rs +++ b/library/alloc/src/collections/binary_heap.rs @@ -652,6 +652,43 @@ impl BinaryHeap { unsafe { self.sift_up(start, pos) }; } + /// Rebuild assuming data[0..start] is still a proper heap. + fn rebuild_tail(&mut self, start: usize) { + if start == self.len() { + return; + } + + let tail_len = self.len() - start; + + #[inline(always)] + fn log2_fast(x: usize) -> usize { + (usize::BITS - x.leading_zeros() - 1) as usize + } + + // `rebuild` takes O(self.len()) operations + // and about 2 * self.len() comparisons in the worst case + // while repeating `sift_up` takes O(tail_len * log(start)) operations + // and about 1 * tail_len * log_2(start) comparisons in the worst case, + // assuming start >= tail_len. For larger heaps, the crossover point + // no longer follows this reasoning and was determined empirically. + let better_to_rebuild = if start < tail_len { + true + } else if self.len() <= 2048 { + 2 * self.len() < tail_len * log2_fast(start) + } else { + 2 * self.len() < tail_len * 11 + }; + + if better_to_rebuild { + self.rebuild(); + } else { + for i in start..self.len() { + // SAFETY: The index `i` is always less than self.len(). + unsafe { self.sift_up(0, i) }; + } + } + } + fn rebuild(&mut self) { let mut n = self.len() / 2; while n > 0 { @@ -689,37 +726,11 @@ impl BinaryHeap { swap(self, other); } - if other.is_empty() { - return; - } - - #[inline(always)] - fn log2_fast(x: usize) -> usize { - (usize::BITS - x.leading_zeros() - 1) as usize - } + let start = self.data.len(); - // `rebuild` takes O(len1 + len2) operations - // and about 2 * (len1 + len2) comparisons in the worst case - // while `extend` takes O(len2 * log(len1)) operations - // and about 1 * len2 * log_2(len1) comparisons in the worst case, - // assuming len1 >= len2. For larger heaps, the crossover point - // no longer follows this reasoning and was determined empirically. - #[inline] - fn better_to_rebuild(len1: usize, len2: usize) -> bool { - let tot_len = len1 + len2; - if tot_len <= 2048 { - 2 * tot_len < len2 * log2_fast(len1) - } else { - 2 * tot_len < len2 * 11 - } - } + self.data.append(&mut other.data); - if better_to_rebuild(self.len(), other.len()) { - self.data.append(&mut other.data); - self.rebuild(); - } else { - self.extend(other.drain()); - } + self.rebuild_tail(start); } /// Returns an iterator which retrieves elements in heap order. @@ -770,12 +781,22 @@ impl BinaryHeap { /// assert_eq!(heap.into_sorted_vec(), [-10, 2, 4]) /// ``` #[unstable(feature = "binary_heap_retain", issue = "71503")] - pub fn retain(&mut self, f: F) + pub fn retain(&mut self, mut f: F) where F: FnMut(&T) -> bool, { - self.data.retain(f); - self.rebuild(); + let mut first_removed = self.len(); + let mut i = 0; + self.data.retain(|e| { + let keep = f(e); + if !keep && i < first_removed { + first_removed = i; + } + i += 1; + keep + }); + // data[0..first_removed] is untouched, so we only need to rebuild the tail: + self.rebuild_tail(first_removed); } } diff --git a/library/alloc/tests/binary_heap.rs b/library/alloc/tests/binary_heap.rs index ce794a9a4afa2..a7913dcd28740 100644 --- a/library/alloc/tests/binary_heap.rs +++ b/library/alloc/tests/binary_heap.rs @@ -386,10 +386,23 @@ fn assert_covariance() { #[test] fn test_retain() { - let mut a = BinaryHeap::from(vec![-10, -5, 1, 2, 4, 13]); - a.retain(|x| x % 2 == 0); + let mut a = BinaryHeap::from(vec![100, 10, 50, 1, 2, 20, 30]); + a.retain(|&x| x != 2); - assert_eq!(a.into_sorted_vec(), [-10, 2, 4]) + // Check that 20 moved into 10's place. + assert_eq!(a.clone().into_vec(), [100, 20, 50, 1, 10, 30]); + + a.retain(|_| true); + + assert_eq!(a.clone().into_vec(), [100, 20, 50, 1, 10, 30]); + + a.retain(|&x| x < 50); + + assert_eq!(a.clone().into_vec(), [30, 20, 10, 1]); + + a.retain(|_| false); + + assert!(a.is_empty()); } // old binaryheap failed this test