Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use Zac's quicksort algorithm in stdlib sorting #5940

Merged
merged 15 commits into from
Sep 11, 2024
2 changes: 1 addition & 1 deletion docs/docs/noir/concepts/data_types/arrays.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn main() {

### sort_via

Sorts the array with a custom comparison function
Sorts the array with a custom comparison function. The ordering function must return true if the first argument should be sorted to be before the second argument, otherwise it should return false.
jfecher marked this conversation as resolved.
Show resolved Hide resolved

```rust
fn sort_via(self, ordering: fn(T, T) -> bool) -> [T; N]
Expand Down
116 changes: 116 additions & 0 deletions noir_stdlib/src/array/check_shuffle.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use crate::cmp::Eq;

unconstrained fn __get_shuffle_indices<T, let N: u32>(lhs: [T; N], rhs: [T; N]) -> [Field; N] where T: Eq {
let mut shuffle_indices: [Field;N ] = [0; N];

let mut shuffle_mask: [bool; N] = [false; N];
for i in 0..N {
let mut found = false;
for j in 0..N {
if ((shuffle_mask[j] == false) & (!found)) {
if (lhs[i] == rhs[j]) {
found = true;
shuffle_indices[i] = j as Field;
shuffle_mask[j] = true;
}
}
if (found) {
continue;
}
}
assert(found == true, "check_shuffle, lhs and rhs arrays do not contain equivalent values");
}

shuffle_indices
}

unconstrained fn __get_index<let N: u32>(indices: [Field; N], idx: Field) -> Field {
let mut result = 0;
for i in 0..N {
if (indices[i] == idx) {
result = i as Field;
break;
}
}
result
}

pub(crate) fn check_shuffle<T, let N: u32>(lhs: [T; N], rhs: [T; N]) where T: Eq {
unsafe {
let shuffle_indices = __get_shuffle_indices(lhs, rhs);

for i in 0..N {
let idx = __get_index(shuffle_indices, i as Field);
assert_eq(shuffle_indices[idx], i as Field);
}
for i in 0..N {
let idx = shuffle_indices[i];
let expected = rhs[idx];
let result = lhs[i];
assert_eq(expected, result);
}
}
}

mod test {
use super::check_shuffle;
use crate::cmp::Eq;

struct CompoundStruct {
a: bool,
b: Field,
c: u64
}
impl Eq for CompoundStruct {
fn eq(self, other: Self) -> bool {
(self.a == other.a) & (self.b == other.b) & (self.c == other.c)
}
}

#[test]
fn test_shuffle() {
let lhs: [Field; 5] = [0, 1, 2, 3, 4];
let rhs: [Field; 5] = [2, 0, 3, 1, 4];
check_shuffle(lhs, rhs);
}

#[test]
fn test_shuffle_identity() {
let lhs: [Field; 5] = [0, 1, 2, 3, 4];
let rhs: [Field; 5] = [0, 1, 2, 3, 4];
check_shuffle(lhs, rhs);
}

#[test(should_fail_with = "check_shuffle, lhs and rhs arrays do not contain equivalent values")]
fn test_shuffle_fail() {
let lhs: [Field; 5] = [0, 1, 2, 3, 4];
let rhs: [Field; 5] = [0, 1, 2, 3, 5];
check_shuffle(lhs, rhs);
}

#[test(should_fail_with = "check_shuffle, lhs and rhs arrays do not contain equivalent values")]
fn test_shuffle_duplicates() {
let lhs: [Field; 5] = [0, 1, 2, 3, 4];
let rhs: [Field; 5] = [0, 1, 2, 3, 3];
check_shuffle(lhs, rhs);
}

#[test]
fn test_shuffle_compound_struct() {
let lhs: [CompoundStruct; 5] = [
CompoundStruct { a: false, b: 0, c: 12345 },
CompoundStruct { a: false, b: -100, c: 54321 },
CompoundStruct { a: true, b: 5, c: 0xffffffffffffffff },
CompoundStruct { a: true, b: 9814, c: 0xeeffee0011001133 },
CompoundStruct { a: false, b: 0x155, c: 0 }
];
let rhs: [CompoundStruct; 5] = [
CompoundStruct { a: false, b: 0x155, c: 0 },
CompoundStruct { a: false, b: 0, c: 12345 },
CompoundStruct { a: false, b: -100, c: 54321 },
CompoundStruct { a: true, b: 9814, c: 0xeeffee0011001133 },
CompoundStruct { a: true, b: 5, c: 0xffffffffffffffff }
];
check_shuffle(lhs, rhs);
}
}
101 changes: 36 additions & 65 deletions noir_stdlib/src/array.nr → noir_stdlib/src/array/mod.nr
Original file line number Diff line number Diff line change
@@ -1,63 +1,15 @@
use crate::cmp::Ord;
use crate::cmp::{Eq, Ord};
use crate::convert::From;
use crate::runtime::is_unconstrained;

mod check_shuffle;
mod quicksort;

impl<T, let N: u32> [T; N] {
/// Returns the length of the slice.
#[builtin(array_len)]
pub fn len(self) -> u32 {}

pub fn sort(self) -> Self where T: Ord {
self.sort_via(|a: T, b: T| a <= b)
}

pub fn sort_via<Env>(self, ordering: fn[Env](T, T) -> bool) -> Self {
let sorted_index = unsafe {
// Safety: These indices are asserted to be the sorted element indices via `find_index`
let sorted_index: [u32; N] = self.get_sorting_index(ordering);

for i in 0..N {
let pos = find_index(sorted_index, i);
assert(sorted_index[pos] == i);
}

sorted_index
};

// Sort the array using the indexes
let mut result = self;
for i in 0..N {
result[i] = self[sorted_index[i]];
}
// Ensure the array is sorted
for i in 0..N - 1 {
assert(ordering(result[i], result[i + 1]));
}

result
}

/// Returns the index of the elements in the array that would sort it, using the provided custom sorting function.
unconstrained fn get_sorting_index<Env>(self, ordering: fn[Env](T, T) -> bool) -> [u32; N] {
let mut result = [0; N];
let mut a = self;
for i in 0..N {
result[i] = i;
}
for i in 1..N {
for j in 0..i {
if ordering(a[i], a[j]) {
let old_a_j = a[j];
a[j] = a[i];
a[i] = old_a_j;
let old_j = result[j];
result[j] = result[i];
result[i] = old_j;
}
}
}
result
}

#[builtin(as_slice)]
pub fn as_slice(self) -> [T] {}

Expand Down Expand Up @@ -114,25 +66,44 @@ impl<T, let N: u32> [T; N] {
}
}

impl<T, let N: u32> [T; N] where T: Ord + Eq {
pub fn sort(self) -> Self {
self.sort_via(|a: T, b: T| a <= b)
}
}

impl<T, let N: u32> [T; N] where T: Eq {

/// Sorts the array using a custom predicate function `ordering`.
///
/// # Safety
/// The `ordering` function must be designed to return `true` for equal valued inputs
TomAFrench marked this conversation as resolved.
Show resolved Hide resolved
/// If this is not done, `sort_via` will fail to sort inputs with duplicated elements.
pub fn sort_via<Env>(self, ordering: fn[Env](T, T) -> bool) -> Self {
unsafe {
// Safety: `sorted` array is checked to be:
// a. a permutation of `input`'s elements
// b. satisfying the predicate `ordering`
let sorted = quicksort::quicksort(self, ordering);

if !is_unconstrained() {
for i in 0..N - 1 {
assert(ordering(sorted[i], sorted[i + 1]), "Array has not been sorted correctly according to `ordering`.");
}
check_shuffle::check_shuffle(self, sorted);
}
sorted
}
}
}

impl<let N: u32> [u8; N] {
/// Convert a sequence of bytes as-is into a string.
/// This function performs no UTF-8 validation or similar.
#[builtin(array_as_str_unchecked)]
pub fn as_str_unchecked(self) -> str<N> {}
}

// helper function used to look up the position of a value in an array of Field
// Note that function returns 0 if the value is not found
unconstrained fn find_index<let N: u32>(a: [u32; N], find: u32) -> u32 {
let mut result = 0;
for i in 0..a.len() {
if a[i] == find {
result = i;
}
}
result
}

impl<let N: u32> From<str<N>> for [u8; N] {
fn from(s: str<N>) -> Self {
s.as_bytes()
Expand Down
39 changes: 39 additions & 0 deletions noir_stdlib/src/array/quicksort.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
unconstrained fn partition<T, Env, let N: u32>(
arr: &mut [T; N],
low: u32,
high: u32,
sortfn: fn[Env](T, T) -> bool
) -> u32 {
let pivot = high;
let mut i = low;
for j in low..high {
if (sortfn(arr[j], arr[pivot])) {
let temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
i += 1;
}
}
let temp = arr[i];
arr[i] = arr[pivot];
arr[pivot] = temp;
i
}

unconstrained fn quicksort_recursive<T, Env, let N: u32>(arr: &mut [T; N], low: u32, high: u32, sortfn: fn[Env](T, T) -> bool) {
if low < high {
let pivot_index = partition(arr, low, high, sortfn);
if pivot_index > 0 {
quicksort_recursive(arr, low, pivot_index - 1, sortfn);
}
quicksort_recursive(arr, pivot_index + 1, high, sortfn);
}
}

unconstrained pub(crate) fn quicksort<T, Env, let N: u32>(_arr: [T; N], sortfn: fn[Env](T, T) -> bool) -> [T; N] {
let mut arr: [T; N] = _arr;
if arr.len() <= 1 {} else {
quicksort_recursive(&mut arr, 0, arr.len() - 1, sortfn);
}
arr
}
Loading