Skip to content

Commit

Permalink
feat: apply normalization for Identifier -> Ident
Browse files Browse the repository at this point in the history
  • Loading branch information
varshith257 committed Dec 8, 2024
1 parent 579a79d commit 1863ff2
Show file tree
Hide file tree
Showing 10 changed files with 112 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,6 @@ fn we_can_convert_between_owned_table_and_record_batch() {
);
}

#[test]
fn we_cannot_convert_a_record_batch_if_it_has_repeated_column_names() {
let record_batch = record_batch!(
"a" => [0_i64; 0],
"A" => [0_i128; 0],
);
assert!(matches!(
OwnedTable::<TestScalar>::try_from(record_batch),
Err(OwnedArrowConversionError::DuplicateIdentifiers)
));
}

#[test]
#[should_panic(expected = "not implemented: Cannot convert Scalar type to arrow type")]
fn we_panic_when_converting_an_owned_table_with_a_scalar_column() {
Expand Down
73 changes: 37 additions & 36 deletions crates/proof-of-sql/src/base/database/join_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ pub fn cross_join<'a, S: Scalar>(
Table::<'a, S>::try_from_iter_with_options(
left.inner_table()
.iter()
.map(|(&ident, column)| {
.map(|(ident, column)| {
(
ident,
ident.clone(),
ColumnRepeatOp::column_op(column, alloc, right_num_rows),
)
})
.chain(right.inner_table().iter().map(|(&ident, column)| {
.chain(right.inner_table().iter().map(|(ident, column)| {
(
ident,
ident.clone(),
ElementwiseRepeatOp::column_op(column, alloc, left_num_rows),
)
})),
Expand All @@ -38,26 +38,27 @@ pub fn cross_join<'a, S: Scalar>(
mod tests {
use super::*;
use crate::base::{database::Column, scalar::test_scalar::TestScalar};
use sqlparser::ast::Ident;

#[test]
fn we_can_do_cross_joins() {
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();
let d = "d".parse().unwrap();
let a: Ident = "a".into();
let b: Ident = "b".into();
let c: Ident = "c".into();
let d: Ident = "d".into();
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[1_i16, 2, 3])),
(b, Column::Int(&[4_i32, 5, 6])),
(a.clone(), Column::SmallInt(&[1_i16, 2, 3])),
(b.clone(), Column::Int(&[4_i32, 5, 6])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[7_i64, 8, 9])),
(d, Column::Int128(&[10_i128, 11, 12])),
(c.clone(), Column::BigInt(&[7_i64, 8, 9])),
(d.clone(), Column::Int128(&[10_i128, 11, 12])),
],
TableOptions::default(),
)
Expand Down Expand Up @@ -86,24 +87,24 @@ mod tests {
#[test]
fn we_can_do_cross_joins_if_one_table_has_no_rows() {
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();
let d = "d".parse().unwrap();
let a: Ident = "a".into();
let b: Ident = "b".into();
let c: Ident = "c".into();
let d: Ident = "d".into();

// Right table has no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[1_i16, 2, 3])),
(b, Column::Int(&[4_i32, 5, 6])),
(a.clone(), Column::SmallInt(&[1_i16, 2, 3])),
(b.clone(), Column::Int(&[4_i32, 5, 6])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[0_i64; 0])),
(d, Column::Int128(&[0_i128; 0])),
(c.clone(), Column::BigInt(&[0_i64; 0])),
(d.clone(), Column::Int128(&[0_i128; 0])),
],
TableOptions::default(),
)
Expand All @@ -119,16 +120,16 @@ mod tests {
// Left table has no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[0_i16; 0])),
(b, Column::Int(&[0_i32; 0])),
(a.clone(), Column::SmallInt(&[0_i16; 0])),
(b.clone(), Column::Int(&[0_i32; 0])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[7_i64, 8, 9])),
(d, Column::Int128(&[10_i128, 11, 12])),
(c.clone(), Column::BigInt(&[7_i64, 8, 9])),
(d.clone(), Column::Int128(&[10_i128, 11, 12])),
],
TableOptions::default(),
)
Expand All @@ -144,16 +145,16 @@ mod tests {
// Both tables have no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[0_i16; 0])),
(b, Column::Int(&[0_i32; 0])),
(a.clone(), Column::SmallInt(&[0_i16; 0])),
(b.clone(), Column::Int(&[0_i32; 0])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[0_i64; 0])),
(d, Column::Int128(&[0_i128; 0])),
(c.clone(), Column::BigInt(&[0_i64; 0])),
(d.clone(), Column::Int128(&[0_i128; 0])),
],
TableOptions::default(),
)
Expand All @@ -171,18 +172,18 @@ mod tests {
fn we_can_do_cross_joins_if_one_table_has_no_columns() {
// Left table has no columns
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();
let d = "d".parse().unwrap();
let a: Ident = "a".into();
let b: Ident = "b".into();
let c: Ident = "c".into();
let d: Ident = "d".into();
let left =
Table::<'_, TestScalar>::try_from_iter_with_options(vec![], TableOptions::new(Some(2)))
.expect("Table creation should not fail");

let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[7_i64, 8])),
(d, Column::Int128(&[10_i128, 11])),
(c.clone(), Column::BigInt(&[7_i64, 8])),
(d.clone(), Column::Int128(&[10_i128, 11])),
],
TableOptions::default(),
)
Expand All @@ -203,8 +204,8 @@ mod tests {
// Right table has no columns
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[1_i16, 2])),
(b, Column::Int(&[4_i32, 5])),
(a.clone(), Column::SmallInt(&[1_i16, 2])),
(b.clone(), Column::Int(&[4_i32, 5])),
],
TableOptions::default(),
)
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/base/database/owned_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl<'a, S: Scalar> From<&Table<'a, S>> for OwnedTable<S> {
value
.inner_table()
.iter()
.map(|(name, column)| (*name, OwnedColumn::from(column))),
.map(|(name, column)| (name.clone(), OwnedColumn::from(column))),
)
.expect("Tables should not have columns with differing lengths")
}
Expand Down
29 changes: 27 additions & 2 deletions crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ use super::{
use crate::base::{
commitment::{CommitmentEvaluationProof, VecCommitmentExt},
map::IndexMap,
sqlparser::normalize_ident,
};
use alloc::{string::String, vec::Vec};
use ark_std::hash::BuildHasherDefault;
use bumpalo::Bump;
use proof_of_sql_parser::{Identifier, ResourceId};
use sqlparser::ast::Ident;

/// A test accessor that uses [`OwnedTable`] as the underlying table type.
/// Note: this is intended for testing and examples. It is not optimized for performance, so should not be used for benchmarks or production use-cases.
pub struct OwnedTableTestAccessor<'a, CP: CommitmentEvaluationProof> {
Expand Down Expand Up @@ -48,7 +50,30 @@ impl<CP: CommitmentEvaluationProof> TestAccessor<CP::Commitment>
}

fn add_table(&mut self, table_ref: TableRef, data: Self::Table, table_offset: usize) {
self.tables.insert(table_ref, (data, table_offset));
// Normalize the table reference (schema and object name)
let normalized_table_ref = TableRef::new(ResourceId::new(
Identifier::try_from(Ident::new(normalize_ident(
table_ref.resource_id().schema().into(),
)))
.expect("Failed to convert Ident to Identifier"),
Identifier::try_from(Ident::new(normalize_ident(
table_ref.resource_id().object_name().into(),
)))
.expect("Failed to convert Ident to Identifier"),
));
// Normalize column names within the table
let normalized_data = {
let mut normalized_table =
IndexMap::with_capacity_and_hasher(0, BuildHasherDefault::default());
for (ident, column) in data.into_inner() {
let normalized_ident = Ident::new(normalize_ident(ident));
normalized_table.insert(normalized_ident, column);
}
OwnedTable::try_new(normalized_table).expect("Column lengths must match")
};

self.tables
.insert(normalized_table_ref, (normalized_data, table_offset));
}
///
/// # Panics
Expand Down
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod serialize;
pub(crate) use serialize::{impl_serde_for_ark_serde_checked, impl_serde_for_ark_serde_unchecked};
pub(crate) mod map;
pub(crate) mod slice_ops;
pub(crate) mod sqlparser;

mod rayon_cfg;
pub(crate) use rayon_cfg::if_rayon;
9 changes: 9 additions & 0 deletions crates/proof-of-sql/src/base/sqlparser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use sqlparser::ast::Ident;

// Normalize an owned identifier to a lowercase string unless the identifier is quoted.
pub(crate) fn normalize_ident(id: Ident) -> alloc::string::String {
match id.quote_style {
Some(_) => id.value,
None => id.value.to_ascii_lowercase(),
}
}
6 changes: 4 additions & 2 deletions crates/proof-of-sql/src/sql/parse/query_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
base::{
database::{ColumnRef, LiteralValue, TableRef},
map::{IndexMap, IndexSet},
sqlparser::normalize_ident,
},
sql::{
parse::{ConversionError, ConversionResult, DynProofExprBuilder, WhereExprBuilder},
Expand Down Expand Up @@ -102,8 +103,9 @@ impl QueryContext {

pub fn push_column_ref(&mut self, column: Ident, column_ref: ColumnRef) {
self.col_ref_counter += 1;
self.push_result_column_ref(column.clone());
self.column_mapping.insert(column, column_ref);
let normalized_column = Ident::new(normalize_ident(column));
self.push_result_column_ref(normalized_column.clone());
self.column_mapping.insert(normalized_column, column_ref);
}

fn push_result_column_ref(&mut self, column: Ident) {
Expand Down
23 changes: 16 additions & 7 deletions crates/proof-of-sql/src/sql/parse/query_context_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub struct QueryContextBuilder<'a> {
context: QueryContext,
schema_accessor: &'a dyn SchemaAccessor,
}
use crate::base::sqlparser::normalize_ident;
use sqlparser::ast::Ident;

// Public interface
Expand Down Expand Up @@ -91,7 +92,7 @@ impl<'a> QueryContextBuilder<'a> {

pub fn visit_group_by_exprs(mut self, group_by_exprs: Vec<Ident>) -> ConversionResult<Self> {
for id in &group_by_exprs {
self.visit_column_identifier(id.clone())?;
self.visit_column_identifier(id)?;
}
self.context.set_group_by_exprs(group_by_exprs);
Ok(self)
Expand All @@ -111,8 +112,14 @@ impl<'a> QueryContextBuilder<'a> {
)]
fn lookup_schema(&self) -> Vec<(Ident, ColumnType)> {
let table_ref = self.context.get_table_ref();
let columns = self.schema_accessor.lookup_schema(*table_ref);
let mut columns = self.schema_accessor.lookup_schema(*table_ref);
assert!(!columns.is_empty(), "At least one column must exist");
// Normalize all column names
for (ident, _) in &mut columns {
let normalized_name = normalize_ident(ident.clone());
*ident = Ident::new(normalized_name);
}

columns
}

Expand Down Expand Up @@ -156,7 +163,7 @@ impl<'a> QueryContextBuilder<'a> {
_ => panic!("Must be a column expression"),
};

self.visit_column_identifier(identifier.into())
self.visit_column_identifier(&identifier.into())
}

fn visit_binary_expr(
Expand Down Expand Up @@ -258,20 +265,22 @@ impl<'a> QueryContextBuilder<'a> {
}
}

fn visit_column_identifier(&mut self, column_name: Ident) -> ConversionResult<ColumnType> {
fn visit_column_identifier(&mut self, column_name: &Ident) -> ConversionResult<ColumnType> {
let table_ref = self.context.get_table_ref();
// Normalize the column name before looking it up
let normalized_column_name = Ident::new(normalize_ident(column_name.clone()));
let column_type = self
.schema_accessor
.lookup_column(*table_ref, column_name.clone());
.lookup_column(*table_ref, normalized_column_name.clone());

let column_type = column_type.ok_or_else(|| ConversionError::MissingColumn {
identifier: Box::new(column_name.clone()),
resource_id: Box::new(table_ref.resource_id()),
})?;

let column = ColumnRef::new(*table_ref, column_name.clone(), column_type);
let column = ColumnRef::new(*table_ref, normalized_column_name.clone(), column_type);

self.context.push_column_ref(column_name, column);
self.context.push_column_ref(normalized_column_name, column);

Ok(column_type)
}
Expand Down
8 changes: 2 additions & 6 deletions crates/proof-of-sql/src/sql/proof_plans/demo_mock_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl ProofPlan for DemoMockPlan {
}

fn get_column_references(&self) -> IndexSet<ColumnRef> {
indexset! {self.column}
indexset! {self.column.clone()}
}

fn get_table_references(&self) -> IndexSet<TableRef> {
Expand Down Expand Up @@ -98,11 +98,7 @@ mod tests {
fn we_can_create_and_prove_a_demo_mock_plan() {
let table_ref = "namespace.table_name".parse::<TableRef>().unwrap();
let table = owned_table([bigint("column_name", [0, 1, 2, 3])]);
let column_ref = ColumnRef::new(
table_ref,
"column_name".parse().unwrap(),
ColumnType::BigInt,
);
let column_ref = ColumnRef::new(table_ref, "column_name".into(), ColumnType::BigInt);
let plan = DemoMockPlan { column: column_ref };
let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(
table_ref,
Expand Down
Loading

0 comments on commit 1863ff2

Please sign in to comment.