diff --git a/crates/proof-of-sql/src/base/database/join_util.rs b/crates/proof-of-sql/src/base/database/join_util.rs index 7cb614a9e..659370104 100644 --- a/crates/proof-of-sql/src/base/database/join_util.rs +++ b/crates/proof-of-sql/src/base/database/join_util.rs @@ -1,6 +1,17 @@ -use super::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp, Table, TableOptions}; -use crate::base::scalar::Scalar; +use super::{ + apply_column_to_indexes, + order_by_util::{compare_indexes_by_columns, compare_single_row_of_tables}, + Column, ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp, Table, TableOperationError, + TableOperationResult, TableOptions, +}; +use crate::base::{ + map::{IndexMap, IndexSet}, + scalar::Scalar, +}; use bumpalo::Bump; +use core::cmp::Ordering; +use itertools::Itertools; +use proof_of_sql_parser::Identifier; /// Compute the CROSS JOIN / cartesian product of two tables. /// @@ -34,6 +45,145 @@ pub fn cross_join<'a, S: Scalar>( .expect("Table creation should not fail") } +/// Compute the JOIN of two tables using a sort-merge join. +/// +/// Currently we only support INNER JOINs and only support joins on equalities. +/// # Panics +/// The function panics if we feed in incorrect data (e.g. Num of rows in `left` and some column of `left_on` being different). +#[allow(clippy::too_many_lines)] +pub fn sort_merge_join<'a, S: Scalar>( + left: &Table<'a, S>, + right: &Table<'a, S>, + left_on: &[Column<'a, S>], + right_on: &[Column<'a, S>], + left_selected_column_ident_aliases: &[(Identifier, Identifier)], + right_selected_column_ident_aliases: &[(Identifier, Identifier)], + alloc: &'a Bump, +) -> TableOperationResult> { + let left_num_rows = left.num_rows(); + let right_num_rows = right.num_rows(); + // Check that result aliases are unique + let aliases = left_selected_column_ident_aliases + .iter() + .map(|(_, alias)| alias) + .chain( + right_selected_column_ident_aliases + .iter() + .map(|(_, alias)| alias), + ) + .collect::>(); + if aliases.len() + != left_selected_column_ident_aliases.len() + right_selected_column_ident_aliases.len() + { + return Err(TableOperationError::DuplicateColumn); + } + // Check that the number of rows is good + for column in left_on { + assert_eq!(column.len(), left_num_rows); + } + for column in right_on { + assert_eq!(column.len(), right_num_rows); + } + // First of all sort the tables by the columns we are joining on + let left_indexes = + (0..left.num_rows()).sorted_unstable_by(|&a, &b| compare_indexes_by_columns(left_on, a, b)); + let right_indexes = (0..right.num_rows()) + .sorted_unstable_by(|&a, &b| compare_indexes_by_columns(right_on, a, b)); + // Collect the indexes of the rows that match + let mut left_iter = left_indexes.into_iter().peekable(); + let mut right_iter = right_indexes.into_iter().peekable(); + let index_pairs = core::iter::from_fn(|| { + // If we have reached the end of either table, we are done + let (&left_index, &right_index) = (left_iter.peek()?, right_iter.peek()?); + match compare_single_row_of_tables(left_on, right_on, left_index, right_index).ok()? { + Ordering::Less => { + // Move left forward, return no pairs for this iteration + left_iter.next(); + Some(Vec::new()) + } + Ordering::Greater => { + // Move right forward, return no pairs for this iteration + right_iter.next(); + Some(Vec::new()) + } + Ordering::Equal => { + // Identify groups of equal keys from both sides + let left_group = left_iter + .clone() + .take_while(|&lidx| { + compare_indexes_by_columns(left_on, left_index, lidx) == Ordering::Equal + }) + .collect::>(); + + let right_group = right_iter + .clone() + .take_while(|&ridx| { + compare_indexes_by_columns(right_on, right_index, ridx) == Ordering::Equal + }) + .collect::>(); + + // All combinations of left_group x right_group + let pairs = left_group + .iter() + .cartesian_product(right_group.iter()) + .map(|(&l, &r)| (l, r)) + .collect::>(); + + // Advance the iterators past the groups + left_iter.nth(left_group.len() - 1); + right_iter.nth(right_group.len() - 1); + + Some(pairs) + } + } + }) + // Flatten out the Vec> from above into a single Vec + .flatten() + .collect::>(); + // Now we have the indexes of the rows that match, we can create the new table + let (left_indexes, right_indexes): (Vec, Vec) = index_pairs.into_iter().unzip(); + let num_rows = left_indexes.len(); + let result_columns = left_selected_column_ident_aliases + .iter() + .map( + |(ident, alias)| -> TableOperationResult<(Identifier, Column<'a, S>)> { + Ok(( + *alias, + apply_column_to_indexes( + left.inner_table().get(ident).ok_or( + TableOperationError::ColumnDoesNotExist { + column_ident: *ident, + }, + )?, + alloc, + &left_indexes, + )?, + )) + }, + ) + .chain(right_selected_column_ident_aliases.iter().map( + |(ident, alias)| -> TableOperationResult<(Identifier, Column<'a, S>)> { + Ok(( + *alias, + apply_column_to_indexes( + right.inner_table().get(ident).ok_or( + TableOperationError::ColumnDoesNotExist { + column_ident: *ident, + }, + )?, + alloc, + &right_indexes, + )?, + )) + }, + )) + .collect::>>()?; + Ok( + Table::<'a, S>::try_new_with_options(result_columns, TableOptions::new(Some(num_rows))) + .expect("Table creation should not fail"), + ) +} + #[cfg(test)] mod tests { use super::*; @@ -240,4 +390,299 @@ mod tests { assert_eq!(result.num_rows(), 0); assert_eq!(result.num_columns(), 0); } + + #[test] + fn we_can_do_sort_merge_join_on_two_tables() { + let bump = Bump::new(); + let a = "a".parse().unwrap(); + let b = "b".parse().unwrap(); + let c = "c".parse().unwrap(); + let left = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])), + (b, Column::Int(&[3_i32, 5, 9, 4, 5, 7])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let right = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])), + (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let left_on = vec![Column::Int(&[3_i32, 5, 9, 4, 5, 7])]; + let right_on = vec![Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])]; + let left_selected_column_ident_aliases = vec![(a, a), (b, b)]; + let right_selected_column_ident_aliases = vec![(c, c)]; + let result = sort_merge_join( + &left, + &right, + &left_on, + &right_on, + &left_selected_column_ident_aliases, + &right_selected_column_ident_aliases, + &bump, + ) + .unwrap(); + assert_eq!(result.num_rows(), 5); + assert_eq!(result.num_columns(), 3); + assert_eq!( + result.inner_table()[&a].as_smallint().unwrap(), + &[1_i16, 2, 2, 3, 3] + ); + assert_eq!( + result.inner_table()[&b].as_int().unwrap(), + &[4_i32, 5, 5, 5, 5] + ); + assert_eq!( + result.inner_table()[&c].as_bigint().unwrap(), + &[7_i64, 8, 9, 8, 9] + ); + } + + #[test] + fn we_can_do_sort_merge_join_on_two_tables_with_empty_results() { + let bump = Bump::new(); + let a = "a".parse().unwrap(); + let b = "b".parse().unwrap(); + let c = "c".parse().unwrap(); + let left = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])), + (b, Column::Int(&[3_i32, 15, 9, 14, 15, 7])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let right = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])), + (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let left_on = vec![Column::Int(&[3_i32, 15, 9, 14, 15, 7])]; + let right_on = vec![Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])]; + let left_selected_column_ident_aliases = vec![(a, a), (b, b)]; + let right_selected_column_ident_aliases = vec![(c, c)]; + let result = sort_merge_join( + &left, + &right, + &left_on, + &right_on, + &left_selected_column_ident_aliases, + &right_selected_column_ident_aliases, + &bump, + ) + .unwrap(); + assert_eq!(result.num_rows(), 0); + assert_eq!(result.num_columns(), 3); + assert_eq!(result.inner_table()[&a].as_smallint().unwrap(), &[0_i16; 0]); + assert_eq!(result.inner_table()[&b].as_int().unwrap(), &[0_i32; 0]); + assert_eq!(result.inner_table()[&c].as_bigint().unwrap(), &[0_i64; 0]); + } + + #[allow(clippy::too_many_lines)] + #[test] + fn we_can_do_sort_merge_join_on_tables_with_no_rows() { + let bump = Bump::new(); + let a = "a".parse().unwrap(); + let b = "b".parse().unwrap(); + let c = "c".parse().unwrap(); + + // Right table has no rows + let left = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])), + (b, Column::Int(&[3_i32, 15, 9, 14, 15, 7])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let right = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (c, Column::BigInt(&[0_i64; 0])), + (b, Column::Int(&[0_i32; 0])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let left_on = vec![Column::Int(&[3_i32, 15, 9, 14, 15, 7])]; + let right_on = vec![Column::Int(&[0_i32; 0])]; + let left_selected_column_ident_aliases = vec![(a, a), (b, b)]; + let right_selected_column_ident_aliases = vec![(c, c)]; + let result = sort_merge_join( + &left, + &right, + &left_on, + &right_on, + &left_selected_column_ident_aliases, + &right_selected_column_ident_aliases, + &bump, + ) + .unwrap(); + assert_eq!(result.num_rows(), 0); + assert_eq!(result.num_columns(), 3); + assert_eq!(result.inner_table()[&a].as_smallint().unwrap(), &[0_i16; 0]); + assert_eq!(result.inner_table()[&b].as_int().unwrap(), &[0_i32; 0]); + assert_eq!(result.inner_table()[&c].as_bigint().unwrap(), &[0_i64; 0]); + + // Left table has no rows + let left = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (a, Column::SmallInt(&[0_i16; 0])), + (b, Column::Int(&[0_i32; 0])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let right = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])), + (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let left_on = vec![Column::Int(&[0_i32; 0])]; + let right_on = vec![Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])]; + let left_selected_column_ident_aliases = vec![(a, a), (b, b)]; + let right_selected_column_ident_aliases = vec![(c, c)]; + let result = sort_merge_join( + &left, + &right, + &left_on, + &right_on, + &left_selected_column_ident_aliases, + &right_selected_column_ident_aliases, + &bump, + ) + .unwrap(); + assert_eq!(result.num_rows(), 0); + assert_eq!(result.num_columns(), 3); + assert_eq!(result.inner_table()[&a].as_smallint().unwrap(), &[0_i16; 0]); + assert_eq!(result.inner_table()[&b].as_int().unwrap(), &[0_i32; 0]); + assert_eq!(result.inner_table()[&c].as_bigint().unwrap(), &[0_i64; 0]); + + // Both tables have no rows + let left = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (a, Column::SmallInt(&[0_i16; 0])), + (b, Column::Int(&[0_i32; 0])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let right = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (c, Column::BigInt(&[0_i64; 0])), + (b, Column::Int(&[0_i32; 0])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let left_on = vec![Column::Int(&[0_i32; 0])]; + let right_on = vec![Column::Int(&[0_i32; 0])]; + let left_selected_column_ident_aliases = vec![(a, a), (b, b)]; + let right_selected_column_ident_aliases = vec![(c, c)]; + let result = sort_merge_join( + &left, + &right, + &left_on, + &right_on, + &left_selected_column_ident_aliases, + &right_selected_column_ident_aliases, + &bump, + ) + .unwrap(); + assert_eq!(result.num_rows(), 0); + assert_eq!(result.num_columns(), 3); + assert_eq!(result.inner_table()[&a].as_smallint().unwrap(), &[0_i16; 0]); + assert_eq!(result.inner_table()[&b].as_int().unwrap(), &[0_i32; 0]); + assert_eq!(result.inner_table()[&c].as_bigint().unwrap(), &[0_i64; 0]); + } + + #[test] + fn we_can_not_do_sort_merge_join_with_duplicate_aliases() { + let bump = Bump::new(); + let a = "a".parse().unwrap(); + let b = "b".parse().unwrap(); + let c = "c".parse().unwrap(); + let left = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])), + (b, Column::Int(&[3_i32, 5, 9, 4, 5, 7])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let right = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])), + (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let left_on = vec![Column::Int(&[3_i32, 5, 9, 4, 5, 7])]; + let right_on = vec![Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])]; + let left_selected_column_ident_aliases = vec![(a, a), (b, b)]; + let right_selected_column_ident_aliases = vec![(b, b), (c, c)]; + let result = sort_merge_join( + &left, + &right, + &left_on, + &right_on, + &left_selected_column_ident_aliases, + &right_selected_column_ident_aliases, + &bump, + ); + assert_eq!(result, Err(TableOperationError::DuplicateColumn)); + } + + #[test] + fn we_can_not_do_sort_merge_join_with_wrong_column_identifiers() { + let bump = Bump::new(); + let a = "a".parse().unwrap(); + let b = "b".parse().unwrap(); + let c = "c".parse().unwrap(); + let not_a_column = "not_a_column".parse().unwrap(); + let left = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])), + (b, Column::Int(&[3_i32, 5, 9, 4, 5, 7])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let right = Table::<'_, TestScalar>::try_from_iter_with_options( + vec![ + (c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])), + (b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])), + ], + TableOptions::default(), + ) + .expect("Table creation should not fail"); + let left_on = vec![Column::Int(&[3_i32, 5, 9, 4, 5, 7])]; + let right_on = vec![Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])]; + let left_selected_column_ident_aliases = vec![(a, a), (b, b)]; + let right_selected_column_ident_aliases = vec![(not_a_column, c)]; + let result = sort_merge_join( + &left, + &right, + &left_on, + &right_on, + &left_selected_column_ident_aliases, + &right_selected_column_ident_aliases, + &bump, + ); + assert!(matches!( + result, + Err(TableOperationError::ColumnDoesNotExist { .. }) + )); + } } diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index 5e259f6a7..03f80f23c 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -26,6 +26,7 @@ pub(super) use column_comparison_operation::{ }; mod column_index_operation; +pub(super) use column_index_operation::apply_column_to_indexes; mod column_repetition_operation; pub(super) use column_repetition_operation::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp}; diff --git a/crates/proof-of-sql/src/base/database/table_operation_error.rs b/crates/proof-of-sql/src/base/database/table_operation_error.rs index b0ec13d91..b631d9003 100644 --- a/crates/proof-of-sql/src/base/database/table_operation_error.rs +++ b/crates/proof-of-sql/src/base/database/table_operation_error.rs @@ -1,6 +1,7 @@ -use crate::base::database::{ColumnField, ColumnType}; +use super::{ColumnField, ColumnOperationError, ColumnType}; use alloc::vec::Vec; use core::result::Result; +use proof_of_sql_parser::Identifier; use snafu::Snafu; /// Errors from operations on tables. @@ -26,6 +27,21 @@ pub enum TableOperationError { /// The right-hand side data type right_type: ColumnType, }, + /// Errors related to a column that does not exist in a table. + #[snafu(display("Column {column_ident:?} does not exist in table"))] + ColumnDoesNotExist { + /// The nonexistent column identifier + column_ident: Identifier, + }, + /// Errors related to duplicate columns in a table. + #[snafu(display("Some column is duplicated in table"))] + DuplicateColumn, + /// Errors due to bad column operations. + #[snafu(transparent)] + ColumnOperationError { + /// The underlying `ColumnOperationError` + source: ColumnOperationError, + }, } /// Result type for table operations