Skip to content

Commit

Permalink
feat: add sort_merge_join
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 13, 2024
1 parent 5eeefc9 commit f40b136
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 3 deletions.
191 changes: 189 additions & 2 deletions crates/proof-of-sql/src/base/database/join_util.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
use super::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp, Table, TableOptions};
use crate::base::scalar::Scalar;
use super::{
apply_column_to_indexes,
order_by_util::{compare_indexes_by_columns, compare_single_row_of_tables},
Column, ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp, Table, TableOperationError,
TableOperationResult, TableOptions,
};
use crate::base::{
map::{IndexMap, IndexSet},
scalar::Scalar,
};
use bumpalo::Bump;
use core::cmp::Ordering;
use itertools::Itertools;
use proof_of_sql_parser::Identifier;

/// Compute the CROSS JOIN / cartesian product of two tables.
///
Expand Down Expand Up @@ -34,6 +45,130 @@ pub fn cross_join<'a, S: Scalar>(
.expect("Table creation should not fail")
}

/// Compute the JOIN of two tables using a sort-merge join.
///
/// Currently we only support INNER JOINs and only support joins on equalities.
/// # Panics
/// The function panics if we feed in incorrect data (e.g. Num of rows in `left` and some column of `left_on` being different).
pub fn sort_merge_join<'a, S: Scalar>(
left: &Table<'a, S>,
right: &Table<'a, S>,
left_on: &[Column<'a, S>],
right_on: &[Column<'a, S>],
left_selected_column_ident_aliases: &[(Identifier, Identifier)],
right_selected_column_ident_aliases: &[(Identifier, Identifier)],
alloc: &'a Bump,
) -> TableOperationResult<Table<'a, S>> {
let left_num_rows = left.num_rows();
let right_num_rows = right.num_rows();
// Check that result aliases are unique
let aliases = left_selected_column_ident_aliases
.iter()
.map(|(_, alias)| alias)
.chain(
right_selected_column_ident_aliases
.iter()
.map(|(_, alias)| alias),
)
.collect::<IndexSet<_>>();
if aliases.len()
!= left_selected_column_ident_aliases.len() + right_selected_column_ident_aliases.len()
{
return Err(TableOperationError::DuplicateColumn);
}
// Check that the number of rows is good
for column in left_on {
assert_eq!(column.len(), left_num_rows);
}
for column in right_on {
assert_eq!(column.len(), right_num_rows);
}
// First of all sort the tables by the columns we are joining on
let left_indexes =
(0..left.num_rows()).sorted_unstable_by(|&a, &b| compare_indexes_by_columns(left_on, a, b));
let right_indexes = (0..right.num_rows())
.sorted_unstable_by(|&a, &b| compare_indexes_by_columns(right_on, a, b));
// Collect the indexes of the rows that match
let mut left_iter = left_indexes.into_iter().peekable();
let mut right_iter = right_indexes.into_iter().peekable();
let mut index_pairs = Vec::<(usize, usize)>::new();
while let (Some(&left_index), Some(&right_index)) = (left_iter.peek(), right_iter.peek()) {
match compare_single_row_of_tables(left_on, right_on, left_index, right_index)? {
Ordering::Less => {
left_iter.next();
}
Ordering::Greater => {
right_iter.next();
}
Ordering::Equal => {
// Collect all matching indexes from the left table
let left_group: Vec<_> = left_iter
.by_ref()
.take_while(|item| {
compare_indexes_by_columns(left_on, left_index, *item) == Ordering::Equal
})
.collect();
// Collect all matching indexes from the right table
let right_group: Vec<_> = right_iter
.by_ref()
.take_while(|item| {
compare_indexes_by_columns(right_on, right_index, *item) == Ordering::Equal
})
.collect();
// Collect indexes
let matched_index_pairs = left_group
.iter()
.copied()
.cartesian_product(right_group.iter().copied());
dbg!(&matched_index_pairs);
index_pairs.extend(matched_index_pairs);
}
}
}
// Now we have the indexes of the rows that match, we can create the new table
let (left_indexes, right_indexes): (Vec<usize>, Vec<usize>) = index_pairs.into_iter().unzip();
let num_rows = left_indexes.len();
let result_columns = left_selected_column_ident_aliases
.iter()
.map(
|(ident, alias)| -> TableOperationResult<(Identifier, Column<'a, S>)> {
Ok((
*alias,
apply_column_to_indexes(
left.inner_table().get(ident).ok_or(
TableOperationError::ColumnDoesNotExist {
column_ident: *ident,
},
)?,
alloc,
&left_indexes,
)?,
))
},
)
.chain(right_selected_column_ident_aliases.iter().map(
|(ident, alias)| -> TableOperationResult<(Identifier, Column<'a, S>)> {
Ok((
*alias,
apply_column_to_indexes(
right.inner_table().get(ident).ok_or(
TableOperationError::ColumnDoesNotExist {
column_ident: *ident,
},
)?,
alloc,
&right_indexes,
)?,
))
},
))
.collect::<TableOperationResult<IndexMap<_, _>>>()?;
Ok(
Table::<'a, S>::try_new_with_options(result_columns, TableOptions::new(Some(num_rows)))
.expect("Table creation should not fail"),
)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -240,4 +375,56 @@ mod tests {
assert_eq!(result.num_rows(), 0);
assert_eq!(result.num_columns(), 0);
}

#[test]
fn we_can_do_sort_merge_join_on_two_tables() {
    let bump = Bump::new();
    let a = "a".parse().unwrap();
    let b = "b".parse().unwrap();
    let c = "c".parse().unwrap();
    let left = Table::<'_, TestScalar>::try_from_iter_with_options(
        vec![
            (a, Column::SmallInt(&[1_i16, 2, 3, 7])),
            (b, Column::Int(&[4_i32, 5, 5, 7])),
        ],
        TableOptions::default(),
    )
    .expect("Table creation should not fail");
    let right = Table::<'_, TestScalar>::try_from_iter_with_options(
        vec![
            (c, Column::BigInt(&[7_i64, 8, 9, 2])),
            (b, Column::Int(&[4_i32, 5, 5, 8])),
        ],
        TableOptions::default(),
    )
    .expect("Table creation should not fail");
    // The join keys are the `b` columns of the two tables.
    let left_on = vec![Column::Int(&[4_i32, 5, 5, 7])];
    let right_on = vec![Column::Int(&[4_i32, 5, 5, 8])];
    let left_selected_column_ident_aliases = vec![(a, a), (b, b)];
    let right_selected_column_ident_aliases = vec![(c, c)];
    let result = sort_merge_join(
        &left,
        &right,
        &left_on,
        &right_on,
        &left_selected_column_ident_aliases,
        &right_selected_column_ident_aliases,
        &bump,
    )
    .unwrap();
    // Key 4 matches once (1 x 1) and key 5 matches four times (2 x 2);
    // keys 7 (left only) and 8 (right only) have no partner.
    assert_eq!(result.num_rows(), 5);
    assert_eq!(result.num_columns(), 3);
    assert_eq!(
        result.inner_table()[&a].as_smallint().unwrap(),
        &[1_i16, 2, 2, 3, 3]
    );
    assert_eq!(
        result.inner_table()[&b].as_int().unwrap(),
        &[4_i32, 5, 5, 5, 5]
    );
    assert_eq!(
        result.inner_table()[&c].as_bigint().unwrap(),
        &[7_i64, 8, 9, 8, 9]
    );
}
}
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub(super) use column_comparison_operation::{
};

mod column_index_operation;
pub(super) use column_index_operation::apply_column_to_indexes;

mod column_repetition_operation;
pub(super) use column_repetition_operation::{ColumnRepeatOp, ElementwiseRepeatOp, RepetitionOp};
Expand Down
18 changes: 17 additions & 1 deletion crates/proof-of-sql/src/base/database/table_operation_error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::base::database::{ColumnField, ColumnType};
use super::{ColumnField, ColumnOperationError, ColumnType};
use alloc::vec::Vec;
use core::result::Result;
use proof_of_sql_parser::Identifier;
use snafu::Snafu;

/// Errors from operations on tables.
Expand All @@ -26,6 +27,21 @@ pub enum TableOperationError {
/// The right-hand side data type
right_type: ColumnType,
},
/// Errors related to a column that does not exist in a table.
#[snafu(display("Column {column_ident:?} does not exist in table"))]
ColumnDoesNotExist {
/// The nonexistent column identifier
column_ident: Identifier,
},
/// Errors related to duplicate columns in a table.
#[snafu(display("Some column is duplicated in table"))]
DuplicateColumn,
/// Errors due to bad column operations.
#[snafu(transparent)]
ColumnOperationError {
/// The underlying `ColumnOperationError`
source: ColumnOperationError,
},
}

/// Result type for table operations
Expand Down

0 comments on commit f40b136

Please sign in to comment.