Skip to content

Commit

Permalink
feat: add sort_merge_join
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 15, 2024
1 parent 5eeefc9 commit 311c8e8
Show file tree
Hide file tree
Showing 3 changed files with 373 additions and 3 deletions.
357 changes: 355 additions & 2 deletions crates/proof-of-sql/src/base/database/join_util.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use super::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp, Table, TableOptions};
use crate::base::scalar::Scalar;
use super::{
apply_column_to_indexes,
order_by_util::{compare_indexes_by_columns, compare_single_row_of_tables},
Column, ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp, Table, TableOperationError,
TableOperationResult, TableOptions,
};
use crate::base::{
map::{IndexMap, IndexSet},
scalar::Scalar,
};
use bumpalo::Bump;
use core::cmp::Ordering;
use itertools::Itertools;
use proof_of_sql_parser::Identifier;

/// Compute the CROSS JOIN / cartesian product of two tables.
///
Expand Down Expand Up @@ -34,6 +45,133 @@ pub fn cross_join<'a, S: Scalar>(
.expect("Table creation should not fail")
}

/// Compute the JOIN of two tables using a sort-merge join.
///
/// Currently we only support INNER JOINs and only support joins on equalities.
/// # Panics
/// The function panics if we feed in incorrect data (e.g. Num of rows in `left` and some column of `left_on` being different).
#[allow(clippy::too_many_lines)]
pub fn sort_merge_join<'a, S: Scalar>(
left: &Table<'a, S>,
right: &Table<'a, S>,
left_on: &[Column<'a, S>],
right_on: &[Column<'a, S>],
left_selected_column_ident_aliases: &[(Identifier, Identifier)],
right_selected_column_ident_aliases: &[(Identifier, Identifier)],
alloc: &'a Bump,
) -> TableOperationResult<Table<'a, S>> {
let left_num_rows = left.num_rows();
let right_num_rows = right.num_rows();
// Check that result aliases are unique
let aliases = left_selected_column_ident_aliases
.iter()
.map(|(_, alias)| alias)
.chain(
right_selected_column_ident_aliases
.iter()
.map(|(_, alias)| alias),
)
.collect::<IndexSet<_>>();
if aliases.len()
!= left_selected_column_ident_aliases.len() + right_selected_column_ident_aliases.len()
{
return Err(TableOperationError::DuplicateColumn);
}
// Check that the number of rows is good
for column in left_on {
assert_eq!(column.len(), left_num_rows);
}
for column in right_on {
assert_eq!(column.len(), right_num_rows);
}
// First of all sort the tables by the columns we are joining on
let left_indexes =
(0..left.num_rows()).sorted_unstable_by(|&a, &b| compare_indexes_by_columns(left_on, a, b));
let right_indexes = (0..right.num_rows())
.sorted_unstable_by(|&a, &b| compare_indexes_by_columns(right_on, a, b));
// Collect the indexes of the rows that match
let mut left_iter = left_indexes.into_iter().peekable();
let mut right_iter = right_indexes.into_iter().peekable();
let mut index_pairs = Vec::<(usize, usize)>::new();
while let (Some(&left_index), Some(&right_index)) = (left_iter.peek(), right_iter.peek()) {
match compare_single_row_of_tables(left_on, right_on, left_index, right_index)? {
Ordering::Less => {
left_iter.next();
}
Ordering::Greater => {
right_iter.next();
}
Ordering::Equal => {
// Collect all matching indexes from the left table
let left_group: Vec<_> = left_iter
.clone()
.take_while(|&item| {
compare_indexes_by_columns(left_on, left_index, item) == Ordering::Equal
})
.collect();
// Collect all matching indexes from the right table
let right_group: Vec<_> = right_iter
.clone()
.take_while(|item| {
compare_indexes_by_columns(right_on, right_index, *item) == Ordering::Equal
})
.collect();
// Collect indexes
let matched_index_pairs = left_group
.iter()
.copied()
.cartesian_product(right_group.iter().copied());
index_pairs.extend(matched_index_pairs);
// Move the iterators to the next group
left_iter.nth(left_group.len() - 1);
right_iter.nth(right_group.len() - 1);
}
}
}
// Now we have the indexes of the rows that match, we can create the new table
let (left_indexes, right_indexes): (Vec<usize>, Vec<usize>) = index_pairs.into_iter().unzip();
let num_rows = left_indexes.len();
let result_columns = left_selected_column_ident_aliases
.iter()
.map(
|(ident, alias)| -> TableOperationResult<(Identifier, Column<'a, S>)> {
Ok((
*alias,
apply_column_to_indexes(
left.inner_table().get(ident).ok_or(
TableOperationError::ColumnDoesNotExist {
column_ident: *ident,
},
)?,
alloc,
&left_indexes,
)?,
))
},
)
.chain(right_selected_column_ident_aliases.iter().map(
|(ident, alias)| -> TableOperationResult<(Identifier, Column<'a, S>)> {
Ok((
*alias,
apply_column_to_indexes(
right.inner_table().get(ident).ok_or(
TableOperationError::ColumnDoesNotExist {
column_ident: *ident,
},
)?,
alloc,
&right_indexes,
)?,
))
},
))
.collect::<TableOperationResult<IndexMap<_, _>>>()?;
Ok(
Table::<'a, S>::try_new_with_options(result_columns, TableOptions::new(Some(num_rows)))
.expect("Table creation should not fail"),
)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -240,4 +378,219 @@ mod tests {
assert_eq!(result.num_rows(), 0);
assert_eq!(result.num_columns(), 0);
}

#[test]
fn we_can_do_sort_merge_join_on_two_tables() {
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])),
(b, Column::Int(&[3_i32, 5, 9, 4, 5, 7])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])),
(b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let left_on = vec![Column::Int(&[3_i32, 5, 9, 4, 5, 7])];
let right_on = vec![Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])];
let left_selected_column_ident_aliases = vec![(a, a), (b, b)];
let right_selected_column_ident_aliases = vec![(c, c)];
let result = sort_merge_join(
&left,
&right,
&left_on,
&right_on,
&left_selected_column_ident_aliases,
&right_selected_column_ident_aliases,
&bump,
)
.unwrap();
assert_eq!(result.num_rows(), 5);
assert_eq!(result.num_columns(), 3);
assert_eq!(
result.inner_table()[&a].as_smallint().unwrap(),
&[1_i16, 2, 2, 3, 3]
);
assert_eq!(
result.inner_table()[&b].as_int().unwrap(),
&[4_i32, 5, 5, 5, 5]
);
assert_eq!(
result.inner_table()[&c].as_bigint().unwrap(),
&[7_i64, 8, 9, 8, 9]
);
}

#[test]
fn we_can_do_sort_merge_join_on_two_tables_with_empty_results() {
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])),
(b, Column::Int(&[3_i32, 15, 9, 14, 15, 7])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])),
(b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let left_on = vec![Column::Int(&[3_i32, 15, 9, 14, 15, 7])];
let right_on = vec![Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])];
let left_selected_column_ident_aliases = vec![(a, a), (b, b)];
let right_selected_column_ident_aliases = vec![(c, c)];
let result = sort_merge_join(
&left,
&right,
&left_on,
&right_on,
&left_selected_column_ident_aliases,
&right_selected_column_ident_aliases,
&bump,
)
.unwrap();
assert_eq!(result.num_rows(), 0);
assert_eq!(result.num_columns(), 3);
assert_eq!(result.inner_table()[&a].as_smallint().unwrap(), &[0_i16; 0]);
assert_eq!(result.inner_table()[&b].as_int().unwrap(), &[0_i32; 0]);
assert_eq!(result.inner_table()[&c].as_bigint().unwrap(), &[0_i64; 0]);
}

#[allow(clippy::too_many_lines)]
#[test]
fn we_can_do_sort_merge_join_on_tables_with_no_rows() {
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();

// Right table has no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[8_i16, 2, 5, 1, 3, 7])),
(b, Column::Int(&[3_i32, 15, 9, 14, 15, 7])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[0_i64; 0])),
(b, Column::Int(&[0_i32; 0])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let left_on = vec![Column::Int(&[3_i32, 15, 9, 14, 15, 7])];
let right_on = vec![Column::Int(&[0_i32; 0])];
let left_selected_column_ident_aliases = vec![(a, a), (b, b)];
let right_selected_column_ident_aliases = vec![(c, c)];
let result = sort_merge_join(
&left,
&right,
&left_on,
&right_on,
&left_selected_column_ident_aliases,
&right_selected_column_ident_aliases,
&bump,
)
.unwrap();
assert_eq!(result.num_rows(), 0);
assert_eq!(result.num_columns(), 3);
assert_eq!(result.inner_table()[&a].as_smallint().unwrap(), &[0_i16; 0]);
assert_eq!(result.inner_table()[&b].as_int().unwrap(), &[0_i32; 0]);
assert_eq!(result.inner_table()[&c].as_bigint().unwrap(), &[0_i64; 0]);

// Left table has no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[0_i16; 0])),
(b, Column::Int(&[0_i32; 0])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[1_i64, 2, 7, 8, 9, 7, 2])),
(b, Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let left_on = vec![Column::Int(&[0_i32; 0])];
let right_on = vec![Column::Int(&[10_i32, 11, 6, 5, 5, 4, 8])];
let left_selected_column_ident_aliases = vec![(a, a), (b, b)];
let right_selected_column_ident_aliases = vec![(c, c)];
let result = sort_merge_join(
&left,
&right,
&left_on,
&right_on,
&left_selected_column_ident_aliases,
&right_selected_column_ident_aliases,
&bump,
)
.unwrap();
assert_eq!(result.num_rows(), 0);
assert_eq!(result.num_columns(), 3);
assert_eq!(result.inner_table()[&a].as_smallint().unwrap(), &[0_i16; 0]);
assert_eq!(result.inner_table()[&b].as_int().unwrap(), &[0_i32; 0]);
assert_eq!(result.inner_table()[&c].as_bigint().unwrap(), &[0_i64; 0]);

// Both tables have no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[0_i16; 0])),
(b, Column::Int(&[0_i32; 0])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[0_i64; 0])),
(b, Column::Int(&[0_i32; 0])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let left_on = vec![Column::Int(&[0_i32; 0])];
let right_on = vec![Column::Int(&[0_i32; 0])];
let left_selected_column_ident_aliases = vec![(a, a), (b, b)];
let right_selected_column_ident_aliases = vec![(c, c)];
let result = sort_merge_join(
&left,
&right,
&left_on,
&right_on,
&left_selected_column_ident_aliases,
&right_selected_column_ident_aliases,
&bump,
)
.unwrap();
assert_eq!(result.num_rows(), 0);
assert_eq!(result.num_columns(), 3);
assert_eq!(result.inner_table()[&a].as_smallint().unwrap(), &[0_i16; 0]);
assert_eq!(result.inner_table()[&b].as_int().unwrap(), &[0_i32; 0]);
assert_eq!(result.inner_table()[&c].as_bigint().unwrap(), &[0_i64; 0]);
}
}
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub(super) use column_comparison_operation::{
};

mod column_index_operation;
pub(super) use column_index_operation::apply_column_to_indexes;

mod column_repetition_operation;
pub(super) use column_repetition_operation::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp};
Expand Down
Loading

0 comments on commit 311c8e8

Please sign in to comment.