Skip to content

Commit

Permalink
refactor: update functions named unchecked_ (#436)
Browse files Browse the repository at this point in the history
# Rationale for this change

Some of our functions are named `unchecked_`.

# What changes are included in this PR?

See commits.

# Are these changes tested?
By existing tests.
  • Loading branch information
JayWhite2357 authored Dec 16, 2024
2 parents c8ac34f + 8933493 commit a9ba6e0
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 43 deletions.
24 changes: 18 additions & 6 deletions crates/proof-of-sql/src/base/math/permutation.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use crate::base::if_rayon;
use alloc::{format, string::String, vec::Vec};
use core::cmp::Ordering;
use itertools::Itertools;
#[cfg(feature = "rayon")]
use rayon::prelude::ParallelSliceMut;
use snafu::Snafu;

/// An error that occurs when working with permutations
Expand All @@ -23,12 +28,19 @@ pub struct Permutation {
}

impl Permutation {
/// Create a new permutation without checks
///
/// Warning: This function does not check if the permutation is valid.
/// Only use this function if you are sure that the permutation is valid.
pub(crate) fn unchecked_new(permutation: Vec<usize>) -> Self {
Self { permutation }
/// Create a new permutation from a comparison function with the given length
pub(crate) fn unchecked_new_from_cmp<F>(length: usize, cmp: F) -> Self
where
F: Fn(&usize, &usize) -> Ordering + Sync,
{
let mut indexes = (0..length).collect_vec();
if_rayon!(
indexes.par_sort_unstable_by(cmp),
indexes.sort_unstable_by(cmp)
);
Self {
permutation: indexes,
}
}

/// Create a new permutation. If the permutation is invalid, return an error.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,11 @@ use crate::base::{
database::{
order_by_util::compare_indexes_by_owned_columns_with_direction, OwnedColumn, OwnedTable,
},
if_rayon,
math::permutation::Permutation,
scalar::Scalar,
};
use alloc::{string::ToString, vec::Vec};
use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection};
#[cfg(feature = "rayon")]
use rayon::prelude::ParallelSliceMut;
use serde::{Deserialize, Serialize};

/// A node representing a list of `OrderBy` expressions.
Expand All @@ -30,7 +27,6 @@ impl OrderByPostprocessing {
impl<S: Scalar> PostprocessingStep<S> for OrderByPostprocessing {
/// Apply the slice transformation to the given `OwnedTable`.
fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
let mut indexes = (0..owned_table.num_rows()).collect::<Vec<_>>();
// Evaluate the columns by which we order
// Once we allow OrderBy for general aggregation-free expressions here we will need to call eval()
let order_by_pairs: Vec<(OwnedColumn<S>, OrderByDirection)> = self
Expand All @@ -52,15 +48,9 @@ impl<S: Scalar> PostprocessingStep<S> for OrderByPostprocessing {
)
.collect::<PostprocessingResult<Vec<(OwnedColumn<S>, OrderByDirection)>>>()?;
// Define the ordering
if_rayon!(
indexes.par_sort_unstable_by(|&a, &b| {
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b)
}),
indexes.sort_unstable_by(|&a, &b| {
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b)
})
);
let permutation = Permutation::unchecked_new(indexes);
let permutation = Permutation::unchecked_new_from_cmp(owned_table.num_rows(), |&a, &b| {
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b)
});
// Apply the ordering
Ok(
OwnedTable::<S>::try_from_iter(owned_table.into_inner().into_iter().map(
Expand Down
30 changes: 6 additions & 24 deletions crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,17 @@
use crate::{
base::{
database::{Column, ColumnarValue, LiteralValue},
if_rayon,
math::decimal::{DecimalError, Precision},
scalar::{Scalar, ScalarExt},
slice_ops,
},
sql::parse::{type_check_binary_operation, ConversionError, ConversionResult},
};
use alloc::string::ToString;
use bumpalo::Bump;
use core::cmp::{max, Ordering};
#[cfg(feature = "rayon")]
use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
use sqlparser::ast::BinaryOperator;

#[allow(clippy::unnecessary_wraps)]
fn unchecked_subtract_impl<'a, S: Scalar>(
alloc: &'a Bump,
lhs: &[S],
rhs: &[S],
table_length: usize,
) -> ConversionResult<&'a [S]> {
let result = alloc.alloc_slice_fill_default(table_length);
if_rayon!(result.par_iter_mut(), result.iter_mut())
.zip(lhs)
.zip(rhs)
.for_each(|((a, l), r)| {
*a = *l - *r;
});
Ok(result)
}

/// Scale LHS and RHS to the same scale if at least one of them is decimal
/// and take the difference. This function is used for comparisons.
///
Expand Down Expand Up @@ -155,12 +136,13 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>(
}
})?;
}
unchecked_subtract_impl(
alloc,
let result = alloc.alloc_slice_fill_default(lhs_len);
slice_ops::sub(
result,
&lhs.to_scalar_with_scaling(lhs_upscale),
&rhs.to_scalar_with_scaling(rhs_upscale),
lhs_len,
)
);
Ok(result)
}

#[allow(clippy::cast_sign_loss)]
Expand Down

0 comments on commit a9ba6e0

Please sign in to comment.