Skip to content

Commit

Permalink
feat: implement postprocessing::OrderByExpr (#46)
Browse files Browse the repository at this point in the history
# Rationale for this change
We need to have a native Rust implementation of postprocessing for
`ORDER BY` so that we can remove polars & the `transformation` module.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked Jira ticket then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->

# What changes are included in this PR?
- add `math::permutation::Permutation`
- add `OwnedColumn::try_permute`
- add `sql::postprocessing::OrderByExpr`
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->

# Are these changes tested?
Yes
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
  • Loading branch information
iajoiner authored Jul 15, 2024
1 parent d70dfcd commit ba6e53a
Show file tree
Hide file tree
Showing 10 changed files with 514 additions and 5 deletions.
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ mod test_accessor_utility;
pub use test_accessor_utility::{make_random_test_accessor_data, RandomTestAccessorDescriptor};

mod owned_column;
pub(crate) use owned_column::compare_indexes_by_owned_columns_with_direction;
pub use owned_column::OwnedColumn;
mod owned_table;
pub use owned_table::OwnedTable;
Expand Down
132 changes: 130 additions & 2 deletions crates/proof-of-sql/src/base/database/owned_column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,18 @@
/// converting to the final result in either Arrow format or JSON.
/// This is the analog of an arrow Array.
use super::ColumnType;
use crate::base::{math::decimal::Precision, scalar::Scalar};
use proof_of_sql_parser::posql_time::{timezone::PoSQLTimeZone, unit::PoSQLTimeUnit};
use crate::base::{
math::{
decimal::Precision,
permutation::{Permutation, PermutationError},
},
scalar::Scalar,
};
use core::cmp::Ordering;
use proof_of_sql_parser::{
intermediate_ast::OrderByDirection,
posql_time::{timezone::PoSQLTimeZone, unit::PoSQLTimeUnit},
};

#[derive(Debug, PartialEq, Clone, Eq)]
#[non_exhaustive]
Expand Down Expand Up @@ -46,6 +56,25 @@ impl<S: Scalar> OwnedColumn<S> {
}
}

/// Returns the column with its entries permutated
pub fn try_permute(&self, permutation: &Permutation) -> Result<Self, PermutationError> {
Ok(match self {
OwnedColumn::Boolean(col) => OwnedColumn::Boolean(permutation.try_apply(col)?),
OwnedColumn::SmallInt(col) => OwnedColumn::SmallInt(permutation.try_apply(col)?),
OwnedColumn::Int(col) => OwnedColumn::Int(permutation.try_apply(col)?),
OwnedColumn::BigInt(col) => OwnedColumn::BigInt(permutation.try_apply(col)?),
OwnedColumn::VarChar(col) => OwnedColumn::VarChar(permutation.try_apply(col)?),
OwnedColumn::Int128(col) => OwnedColumn::Int128(permutation.try_apply(col)?),
OwnedColumn::Decimal75(precision, scale, col) => {
OwnedColumn::Decimal75(*precision, *scale, permutation.try_apply(col)?)
}
OwnedColumn::Scalar(col) => OwnedColumn::Scalar(permutation.try_apply(col)?),
OwnedColumn::TimestampTZ(tu, tz, col) => {
OwnedColumn::TimestampTZ(*tu, *tz, permutation.try_apply(col)?)
}
})
}

/// Returns the sliced column.
pub fn slice(&self, start: usize, end: usize) -> Self {
match self {
Expand All @@ -64,6 +93,7 @@ impl<S: Scalar> OwnedColumn<S> {
}
}
}

/// Returns true if the column is empty.
pub fn is_empty(&self) -> bool {
match self {
Expand Down Expand Up @@ -161,3 +191,101 @@ impl<S: Scalar> OwnedColumn<S> {
}
}
}

/// Compares the tuples (order_by_pairs[0][i], order_by_pairs[1][i], ...) and
/// (order_by_pairs[0][j], order_by_pairs[1][j], ...) in lexicographic order.
/// Note that direction flips the ordering.
pub(crate) fn compare_indexes_by_owned_columns_with_direction<S: Scalar>(
order_by_pairs: &[(OwnedColumn<S>, OrderByDirection)],
i: usize,
j: usize,
) -> Ordering {
order_by_pairs
.iter()
.map(|(col, direction)| {
let ordering = match col {
OwnedColumn::Boolean(col) => col[i].cmp(&col[j]),
OwnedColumn::SmallInt(col) => col[i].cmp(&col[j]),
OwnedColumn::Int(col) => col[i].cmp(&col[j]),
OwnedColumn::BigInt(col) => col[i].cmp(&col[j]),
OwnedColumn::Int128(col) => col[i].cmp(&col[j]),
OwnedColumn::Decimal75(_, _, col) => col[i].cmp(&col[j]),
OwnedColumn::Scalar(col) => col[i].cmp(&col[j]),
OwnedColumn::VarChar(col) => col[i].cmp(&col[j]),
OwnedColumn::TimestampTZ(_, _, col) => col[i].cmp(&col[j]),
};
match direction {
OrderByDirection::Asc => ordering,
OrderByDirection::Desc => ordering.reverse(),
}
})
.find(|&ord| ord != Ordering::Equal)
.unwrap_or(Ordering::Equal)
}

#[cfg(test)]
mod test {
use super::*;
use crate::base::{math::decimal::Precision, scalar::Curve25519Scalar};
use proof_of_sql_parser::intermediate_ast::OrderByDirection;

#[test]
fn we_can_slice_a_column() {
let col: OwnedColumn<Curve25519Scalar> = OwnedColumn::Int128(vec![1, 2, 3, 4, 5]);
assert_eq!(col.slice(1, 4), OwnedColumn::Int128(vec![2, 3, 4]));
}

#[test]
fn we_can_permute_a_column() {
let col: OwnedColumn<Curve25519Scalar> = OwnedColumn::Int128(vec![1, 2, 3, 4, 5]);
let permutation = Permutation::try_new(vec![1, 3, 4, 0, 2]).unwrap();
assert_eq!(
col.try_permute(&permutation).unwrap(),
OwnedColumn::Int128(vec![2, 4, 5, 1, 3])
);
}

#[test]
fn we_can_compare_columns() {
let col1: OwnedColumn<Curve25519Scalar> = OwnedColumn::SmallInt(vec![1, 1, 2, 1, 1]);
let col2: OwnedColumn<Curve25519Scalar> = OwnedColumn::VarChar(
["b", "b", "a", "b", "a"]
.iter()
.map(|s| s.to_string())
.collect(),
);
let col3: OwnedColumn<Curve25519Scalar> = OwnedColumn::Decimal75(
Precision::new(70).unwrap(),
20,
[1, 2, 2, 1, 2]
.iter()
.map(|&i| Curve25519Scalar::from(i))
.collect(),
);
let order_by_pairs = vec![
(col1, OrderByDirection::Asc),
(col2, OrderByDirection::Desc),
(col3, OrderByDirection::Asc),
];
// Equal on col1 and col2, less on col3
assert_eq!(
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, 0, 1),
Ordering::Less
);
// Less on col1
assert_eq!(
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, 0, 2),
Ordering::Less
);
// Equal on all 3 columns
assert_eq!(
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, 0, 3),
Ordering::Equal
);
// Equal on col1, greater on col2 reversed
assert_eq!(
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, 1, 4),
Ordering::Less
)
}
}
4 changes: 3 additions & 1 deletion crates/proof-of-sql/src/base/math/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Handles parsing between decimal tokens received from the lexer into native `Decimal75` Proof of SQL type.
//! This module defines math utilities used in Proof of SQL.
/// Handles parsing between decimal tokens received from the lexer into native `Decimal75` Proof of SQL type.
pub mod decimal;
#[cfg(test)]
mod decimal_tests;
mod log;
pub(crate) use log::log2_up;
pub(crate) mod permutation;
115 changes: 115 additions & 0 deletions crates/proof-of-sql/src/base/math/permutation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use thiserror::Error;

/// An error that occurs when working with permutations
#[derive(Error, Debug, PartialEq, Eq)]
pub enum PermutationError {
/// The permutation is invalid
#[error("Permutation is invalid {0}")]
InvalidPermutation(String),
/// Application of a permutation to a slice with an incorrect length
#[error("Application of a permutation to a slice with a different length {permutation_size} != {slice_length}")]
PermutationSizeMismatch {
permutation_size: usize,
slice_length: usize,
},
}

/// Permutation of [0, 1, 2, ..., n-1]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Permutation {
/// The permutation
permutation: Vec<usize>,
}

impl Permutation {
/// Create a new permutation without checks
///
/// Warning: This function does not check if the permutation is valid.
/// Only use this function if you are sure that the permutation is valid.
pub(crate) fn unchecked_new(permutation: Vec<usize>) -> Self {
Self { permutation }
}

/// Create a new permutation. If the permutation is invalid, return an error.
pub fn try_new(permutation: Vec<usize>) -> Result<Self, PermutationError> {
let length = permutation.len();
// Check for uniqueness
let mut elements = permutation.clone();
elements.sort_unstable();
elements.dedup();
if elements.len() < length {
Err(PermutationError::InvalidPermutation(format!(
"Permutation can not have duplicate elements: {:?}",
permutation
)))
}
// Check that no element is out of bounds
else if permutation.iter().any(|&i| i >= length) {
Err(PermutationError::InvalidPermutation(format!(
"Permutation can not have elements out of bounds: {:?}",
permutation
)))
} else {
Ok(Self { permutation })
}
}

/// Get the size of the permutation
pub fn size(&self) -> usize {
self.permutation.len()
}

/// Apply the permutation to the given slice
pub fn try_apply<T>(&self, slice: &[T]) -> Result<Vec<T>, PermutationError>
where
T: Clone,
{
if slice.len() != self.size() {
Err(PermutationError::PermutationSizeMismatch {
permutation_size: self.size(),
slice_length: slice.len(),
})
} else {
Ok(self.permutation.iter().map(|&i| slice[i].clone()).collect())
}
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_apply_permutation() {
let permutation = Permutation::try_new(vec![1, 0, 2]).unwrap();
assert_eq!(permutation.size(), 3);
assert_eq!(
permutation.try_apply(&["and", "Space", "Time"]).unwrap(),
vec!["Space", "and", "Time"]
);
}

#[test]
fn test_invalid_permutation() {
assert!(matches!(
Permutation::try_new(vec![1, 0, 0]),
Err(PermutationError::InvalidPermutation(_))
));
assert!(matches!(
Permutation::try_new(vec![1, 0, 3]),
Err(PermutationError::InvalidPermutation(_))
));
}

#[test]
fn test_permutation_size_mismatch() {
let permutation = Permutation::try_new(vec![1, 0, 2]).unwrap();
assert_eq!(
permutation.try_apply(&["Space", "Time"]),
Err(PermutationError::PermutationSizeMismatch {
permutation_size: 3,
slice_length: 2
})
);
}
}
3 changes: 3 additions & 0 deletions crates/proof-of-sql/src/sql/postprocessing/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ pub enum PostprocessingError {
/// Error in slicing due to slice index beyond usize
#[error("Error in slicing due to slice index beyond usize {0}")]
InvalidSliceIndex(i128),
/// Column not found
#[error("Column not found: {0}")]
ColumnNotFound(String),
}

/// Result type for postprocessing
Expand Down
6 changes: 5 additions & 1 deletion crates/proof-of-sql/src/sql/postprocessing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ pub use postprocessing_step::PostprocessingStep;
#[cfg(test)]
pub mod test_utility;

mod order_by_expr;
pub use order_by_expr::OrderByExpr;
#[cfg(test)]
mod order_by_expr_test;

mod slice_expr;
pub use slice_expr::SliceExpr;

#[cfg(test)]
mod slice_expr_test;
72 changes: 72 additions & 0 deletions crates/proof-of-sql/src/sql/postprocessing/order_by_expr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
use super::{PostprocessingError, PostprocessingResult, PostprocessingStep};
use crate::base::{
database::{compare_indexes_by_owned_columns_with_direction, OwnedColumn, OwnedTable},
math::permutation::Permutation,
scalar::Scalar,
};
use proof_of_sql_parser::intermediate_ast::{OrderBy, OrderByDirection};
use rayon::prelude::ParallelSliceMut;
use serde::{Deserialize, Serialize};

/// A node representing a list of `OrderBy` expressions.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct OrderByExpr<S: Scalar> {
by_exprs: Vec<OrderBy>,
_phantom: core::marker::PhantomData<S>,
}

impl<S: Scalar> OrderByExpr<S> {
/// Create a new `OrderByExpr` node.
pub fn new(by_exprs: Vec<OrderBy>) -> Self {
Self {
by_exprs,
_phantom: core::marker::PhantomData,
}
}
}

impl<S: Scalar> PostprocessingStep<S> for OrderByExpr<S> {
/// Apply the slice transformation to the given `OwnedTable`.
fn apply(&self, owned_table: OwnedTable<S>) -> PostprocessingResult<OwnedTable<S>> {
let mut indexes = (0..owned_table.num_rows()).collect::<Vec<_>>();
// Evaluate the columns by which we order
// Once we allow OrderBy for general aggregation-free expressions here we will need to call eval()
let order_by_pairs: Vec<(OwnedColumn<S>, OrderByDirection)> = self
.by_exprs
.iter()
.map(
|order_by| -> PostprocessingResult<(OwnedColumn<S>, OrderByDirection)> {
Ok((
owned_table
.inner_table()
.get(&order_by.expr)
.ok_or(PostprocessingError::ColumnNotFound(
order_by.expr.to_string(),
))?
.clone(),
order_by.direction,
))
},
)
.collect::<PostprocessingResult<Vec<(OwnedColumn<S>, OrderByDirection)>>>()?;
// Define the ordering
indexes.par_sort_unstable_by(|&a, &b| {
compare_indexes_by_owned_columns_with_direction(&order_by_pairs, a, b)
});
let permutation = Permutation::unchecked_new(indexes);
// Apply the ordering
Ok(
OwnedTable::<S>::try_from_iter(owned_table.into_inner().into_iter().map(
|(identifier, column)| {
(
identifier,
column
.try_permute(&permutation)
.expect("There should be no column length mismatch here"),
)
},
))
.expect("There should be no column length mismatch here"),
)
}
}
Loading

0 comments on commit ba6e53a

Please sign in to comment.