Skip to content

Commit

Permalink
Use simpler branchy swap logic in tiny merge sort
Browse files Browse the repository at this point in the history
Avoids the code duplication issue and results in
a smaller binary size, which, after all, is the
purpose of the feature.
  • Loading branch information
Voultapher committed Aug 29, 2024
1 parent ae57bdf commit 717e3aa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 42 deletions.
5 changes: 5 additions & 0 deletions core/src/slice/sort/shared/smallsort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,11 @@ where

/// Swap two values in the slice pointed to by `v_base` at the position `a_pos` and `b_pos` if the
/// value at position `b_pos` is less than the one at position `a_pos`.
///
/// Purposefully not marked `#[inline]`, despite us wanting it to be inlined for integer-like
/// types. `is_less` could be a huge function and we want to give the compiler the option to
/// not inline this function. Because this function is very perf critical, it should live in
/// the same module as the functions that use it.
unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
Expand Down
50 changes: 8 additions & 42 deletions core/src/slice/sort/stable/tiny.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Binary-size optimized mergesort inspired by <https://github.com/voultapher/tiny-sort-rs>.
use crate::mem::{ManuallyDrop, MaybeUninit};
use crate::mem::MaybeUninit;
use crate::ptr;
use crate::slice::sort::stable::merge;

Expand All @@ -27,49 +27,15 @@ pub fn mergesort<T, F: FnMut(&T, &T) -> bool>(

merge::merge(v, scratch, mid, is_less);
} else if len == 2 {
// Branchless swap the two elements. This reduces the recursion depth and improves
// perf significantly at a small binary-size cost. Trades ~10% perf boost for integers
// for ~50 bytes in the binary.

// SAFETY: We checked the len, the pointers we create are valid and don't overlap.
unsafe {
swap_if_less(v.as_mut_ptr(), 0, 1, is_less);
}
}
}

/// Swap two values in the slice pointed to by `v_base` at the position `a_pos` and `b_pos` if the
/// value at position `b_pos` is less than the one at position `a_pos`.
unsafe fn swap_if_less<T, F>(v_base: *mut T, a_pos: usize, b_pos: usize, is_less: &mut F)
where
F: FnMut(&T, &T) -> bool,
{
// SAFETY: the caller must guarantee that `a` and `b` each added to `v_base` yield valid
// pointers into `v_base`, and are properly aligned, and part of the same allocation.
unsafe {
let v_a = v_base.add(a_pos);
let v_b = v_base.add(b_pos);
let v_base = v.as_mut_ptr();
let v_a = v_base;
let v_b = v_base.add(1);

// PANIC SAFETY: if is_less panics, no scratch memory was created and the slice should still be
// in a well defined state, without duplicates.

// Important to only swap if it is more and not if it is equal. is_less should return false for
// equal, so we don't swap.
let should_swap = is_less(&*v_b, &*v_a);

// This is a branchless version of swap if.
// The equivalent code with a branch would be:
//
// if should_swap {
// ptr::swap(left, right, 1);
// }

// The goal is to generate cmov instructions here.
let left_swap = if should_swap { v_b } else { v_a };
let right_swap = if should_swap { v_a } else { v_b };

let right_swap_tmp = ManuallyDrop::new(ptr::read(right_swap));
ptr::copy(left_swap, v_a, 1);
ptr::copy_nonoverlapping(&*right_swap_tmp, v_b, 1);
if is_less(&*v_b, &*v_a) {
ptr::swap_nonoverlapping(v_a, v_b, 1);
}
}
}
}

0 comments on commit 717e3aa

Please sign in to comment.