Skip to content

Commit

Permalink
Rollup merge of rust-lang#78681 - m-ou-se:binary-heap-retain, r=Amanieu
Browse files Browse the repository at this point in the history
Improve rebuilding behaviour of BinaryHeap::retain.

This changes `BinaryHeap::retain` such that it doesn't always fully rebuild the heap, but only rebuilds the parts for which that's necessary.

This makes use of the fact that retain gives out `&T`s and not `&mut T`s.

Retaining every element or removing only elements at the end results in no rebuilding at all. Retaining most elements results in only reordering the elements that got moved (those after the first removed element), using the same logic as was already used for `append`.

cc `@KodrAus` `@sfackler` - We briefly discussed this possibility in the meeting last week while we talked about stabilization of this function (rust-lang#71503).
  • Loading branch information
Dylan-DPC authored Apr 22, 2021
2 parents 5f1aeb5 + f5d72ab commit d6e5bba
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 35 deletions.
85 changes: 53 additions & 32 deletions library/alloc/src/collections/binary_heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,43 @@ impl<T: Ord> BinaryHeap<T> {
unsafe { self.sift_up(start, pos) };
}

/// Rebuild assuming data[0..start] is still a proper heap.
fn rebuild_tail(&mut self, start: usize) {
if start == self.len() {
return;
}

let tail_len = self.len() - start;

#[inline(always)]
fn log2_fast(x: usize) -> usize {
(usize::BITS - x.leading_zeros() - 1) as usize
}

// `rebuild` takes O(self.len()) operations
// and about 2 * self.len() comparisons in the worst case
// while repeating `sift_up` takes O(tail_len * log(start)) operations
// and about 1 * tail_len * log_2(start) comparisons in the worst case,
// assuming start >= tail_len. For larger heaps, the crossover point
// no longer follows this reasoning and was determined empirically.
let better_to_rebuild = if start < tail_len {
true
} else if self.len() <= 2048 {
2 * self.len() < tail_len * log2_fast(start)
} else {
2 * self.len() < tail_len * 11
};

if better_to_rebuild {
self.rebuild();
} else {
for i in start..self.len() {
// SAFETY: The index `i` is always less than self.len().
unsafe { self.sift_up(0, i) };
}
}
}

fn rebuild(&mut self) {
let mut n = self.len() / 2;
while n > 0 {
Expand Down Expand Up @@ -689,37 +726,11 @@ impl<T: Ord> BinaryHeap<T> {
swap(self, other);
}

if other.is_empty() {
return;
}

#[inline(always)]
fn log2_fast(x: usize) -> usize {
(usize::BITS - x.leading_zeros() - 1) as usize
}
let start = self.data.len();

// `rebuild` takes O(len1 + len2) operations
// and about 2 * (len1 + len2) comparisons in the worst case
// while `extend` takes O(len2 * log(len1)) operations
// and about 1 * len2 * log_2(len1) comparisons in the worst case,
// assuming len1 >= len2. For larger heaps, the crossover point
// no longer follows this reasoning and was determined empirically.
#[inline]
fn better_to_rebuild(len1: usize, len2: usize) -> bool {
let tot_len = len1 + len2;
if tot_len <= 2048 {
2 * tot_len < len2 * log2_fast(len1)
} else {
2 * tot_len < len2 * 11
}
}
self.data.append(&mut other.data);

if better_to_rebuild(self.len(), other.len()) {
self.data.append(&mut other.data);
self.rebuild();
} else {
self.extend(other.drain());
}
self.rebuild_tail(start);
}

/// Returns an iterator which retrieves elements in heap order.
Expand Down Expand Up @@ -770,12 +781,22 @@ impl<T: Ord> BinaryHeap<T> {
/// assert_eq!(heap.into_sorted_vec(), [-10, 2, 4])
/// ```
#[unstable(feature = "binary_heap_retain", issue = "71503")]
pub fn retain<F>(&mut self, f: F)
pub fn retain<F>(&mut self, mut f: F)
where
F: FnMut(&T) -> bool,
{
self.data.retain(f);
self.rebuild();
let mut first_removed = self.len();
let mut i = 0;
self.data.retain(|e| {
let keep = f(e);
if !keep && i < first_removed {
first_removed = i;
}
i += 1;
keep
});
// data[0..first_removed] is untouched, so we only need to rebuild the tail:
self.rebuild_tail(first_removed);
}
}

Expand Down
19 changes: 16 additions & 3 deletions library/alloc/tests/binary_heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,23 @@ fn assert_covariance() {

#[test]
fn test_retain() {
let mut a = BinaryHeap::from(vec![-10, -5, 1, 2, 4, 13]);
a.retain(|x| x % 2 == 0);
let mut a = BinaryHeap::from(vec![100, 10, 50, 1, 2, 20, 30]);
a.retain(|&x| x != 2);

assert_eq!(a.into_sorted_vec(), [-10, 2, 4])
// Check that 20 moved into 10's place.
assert_eq!(a.clone().into_vec(), [100, 20, 50, 1, 10, 30]);

a.retain(|_| true);

assert_eq!(a.clone().into_vec(), [100, 20, 50, 1, 10, 30]);

a.retain(|&x| x < 50);

assert_eq!(a.clone().into_vec(), [30, 20, 10, 1]);

a.retain(|_| false);

assert!(a.is_empty());
}

// old binaryheap failed this test
Expand Down

0 comments on commit d6e5bba

Please sign in to comment.