Skip to content

Commit

Permalink
refactor!: replac
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 12, 2024
1 parent ea47dae commit 044b954
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,6 @@ fn we_can_convert_between_owned_table_and_record_batch() {
);
}

/// A record batch whose column names differ only by ASCII case (`"a"` and `"A"`)
/// is expected to be rejected with `DuplicateIdentifiers` when converted into an
/// `OwnedTable`.
#[test]
fn we_cannot_convert_a_record_batch_if_it_has_repeated_column_names() {
    // Both columns are empty; only the names matter for this check.
    let batch_with_clashing_names = record_batch!(
        "a" => [0_i64; 0],
        "A" => [0_i128; 0],
    );
    let conversion = OwnedTable::<TestScalar>::try_from(batch_with_clashing_names);
    assert!(matches!(
        conversion,
        Err(OwnedArrowConversionError::DuplicateIdentifiers)
    ));
}

#[test]
#[should_panic(expected = "not implemented: Cannot convert Scalar type to arrow type")]
fn we_panic_when_converting_an_owned_table_with_a_scalar_column() {
Expand Down
73 changes: 37 additions & 36 deletions crates/proof-of-sql/src/base/database/join_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@ pub fn cross_join<'a, S: Scalar>(
Table::<'a, S>::try_from_iter_with_options(
left.inner_table()
.iter()
.map(|(&ident, column)| {
.map(|(ident, column)| {
(
ident,
ident.clone(),
ColumnRepeatOp::column_op(column, alloc, right_num_rows),
)
})
.chain(right.inner_table().iter().map(|(&ident, column)| {
.chain(right.inner_table().iter().map(|(ident, column)| {
(
ident,
ident.clone(),
ElementwiseRepeatOp::column_op(column, alloc, left_num_rows),
)
})),
Expand All @@ -38,26 +38,27 @@ pub fn cross_join<'a, S: Scalar>(
mod tests {
use super::*;
use crate::base::{database::Column, scalar::test_scalar::TestScalar};
use sqlparser::ast::Ident;

#[test]
fn we_can_do_cross_joins() {
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();
let d = "d".parse().unwrap();
let a: Ident = "a".into();
let b: Ident = "b".into();
let c: Ident = "c".into();
let d: Ident = "d".into();
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[1_i16, 2, 3])),
(b, Column::Int(&[4_i32, 5, 6])),
(a.clone(), Column::SmallInt(&[1_i16, 2, 3])),
(b.clone(), Column::Int(&[4_i32, 5, 6])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[7_i64, 8, 9])),
(d, Column::Int128(&[10_i128, 11, 12])),
(c.clone(), Column::BigInt(&[7_i64, 8, 9])),
(d.clone(), Column::Int128(&[10_i128, 11, 12])),
],
TableOptions::default(),
)
Expand Down Expand Up @@ -86,24 +87,24 @@ mod tests {
#[test]
fn we_can_do_cross_joins_if_one_table_has_no_rows() {
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();
let d = "d".parse().unwrap();
let a: Ident = "a".into();
let b: Ident = "b".into();
let c: Ident = "c".into();
let d: Ident = "d".into();

// Right table has no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[1_i16, 2, 3])),
(b, Column::Int(&[4_i32, 5, 6])),
(a.clone(), Column::SmallInt(&[1_i16, 2, 3])),
(b.clone(), Column::Int(&[4_i32, 5, 6])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[0_i64; 0])),
(d, Column::Int128(&[0_i128; 0])),
(c.clone(), Column::BigInt(&[0_i64; 0])),
(d.clone(), Column::Int128(&[0_i128; 0])),
],
TableOptions::default(),
)
Expand All @@ -119,16 +120,16 @@ mod tests {
// Left table has no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[0_i16; 0])),
(b, Column::Int(&[0_i32; 0])),
(a.clone(), Column::SmallInt(&[0_i16; 0])),
(b.clone(), Column::Int(&[0_i32; 0])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[7_i64, 8, 9])),
(d, Column::Int128(&[10_i128, 11, 12])),
(c.clone(), Column::BigInt(&[7_i64, 8, 9])),
(d.clone(), Column::Int128(&[10_i128, 11, 12])),
],
TableOptions::default(),
)
Expand All @@ -144,16 +145,16 @@ mod tests {
// Both tables have no rows
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[0_i16; 0])),
(b, Column::Int(&[0_i32; 0])),
(a.clone(), Column::SmallInt(&[0_i16; 0])),
(b.clone(), Column::Int(&[0_i32; 0])),
],
TableOptions::default(),
)
.expect("Table creation should not fail");
let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[0_i64; 0])),
(d, Column::Int128(&[0_i128; 0])),
(c.clone(), Column::BigInt(&[0_i64; 0])),
(d.clone(), Column::Int128(&[0_i128; 0])),
],
TableOptions::default(),
)
Expand All @@ -171,18 +172,18 @@ mod tests {
fn we_can_do_cross_joins_if_one_table_has_no_columns() {
// Left table has no columns
let bump = Bump::new();
let a = "a".parse().unwrap();
let b = "b".parse().unwrap();
let c = "c".parse().unwrap();
let d = "d".parse().unwrap();
let a: Ident = "a".into();
let b: Ident = "b".into();
let c: Ident = "c".into();
let d: Ident = "d".into();
let left =
Table::<'_, TestScalar>::try_from_iter_with_options(vec![], TableOptions::new(Some(2)))
.expect("Table creation should not fail");

let right = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(c, Column::BigInt(&[7_i64, 8])),
(d, Column::Int128(&[10_i128, 11])),
(c.clone(), Column::BigInt(&[7_i64, 8])),
(d.clone(), Column::Int128(&[10_i128, 11])),
],
TableOptions::default(),
)
Expand All @@ -203,8 +204,8 @@ mod tests {
// Right table has no columns
let left = Table::<'_, TestScalar>::try_from_iter_with_options(
vec![
(a, Column::SmallInt(&[1_i16, 2])),
(b, Column::Int(&[4_i32, 5])),
(a.clone(), Column::SmallInt(&[1_i16, 2])),
(b.clone(), Column::Int(&[4_i32, 5])),
],
TableOptions::default(),
)
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/base/database/owned_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ impl<'a, S: Scalar> From<&Table<'a, S>> for OwnedTable<S> {
value
.inner_table()
.iter()
.map(|(name, column)| (*name, OwnedColumn::from(column))),
.map(|(name, column)| (name.clone(), OwnedColumn::from(column))),
)
.expect("Tables should not have columns with differing lengths")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ fn we_can_access_the_columns_of_a_table() {

let column = ColumnRef::new(
table_ref_2,
"time".parse().unwrap(),
"time".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()),
);
match accessor.get_column(column) {
Expand Down
1 change: 1 addition & 0 deletions crates/proof-of-sql/src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ mod serialize;
pub(crate) use serialize::{impl_serde_for_ark_serde_checked, impl_serde_for_ark_serde_unchecked};
pub(crate) mod map;
pub(crate) mod slice_ops;
pub(crate) mod sqlparser;

mod rayon_cfg;
pub(crate) use rayon_cfg::if_rayon;
9 changes: 9 additions & 0 deletions crates/proof-of-sql/src/base/sqlparser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use sqlparser::ast::Ident;

/// Normalizes an owned identifier into its canonical string form.
///
/// * A quoted identifier (`quote_style` is `Some`) keeps its value exactly as written.
/// * An unquoted identifier is folded to ASCII lowercase; non-ASCII characters are
///   left unchanged.
///
/// The identifier is consumed and its existing string buffer is returned, so no
/// additional allocation is performed.
pub(crate) fn normalize_ident(id: Ident) -> alloc::string::String {
    let is_quoted = id.quote_style.is_some();
    let mut value = id.value;
    if !is_quoted {
        // The string is already owned, so fold the case in place instead of
        // allocating a second string with `to_ascii_lowercase`.
        value.make_ascii_lowercase();
    }
    value
}
6 changes: 4 additions & 2 deletions crates/proof-of-sql/src/sql/parse/query_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
base::{
database::{ColumnRef, LiteralValue, TableRef},
map::{IndexMap, IndexSet},
sqlparser::normalize_ident,
},
sql::{
parse::{ConversionError, ConversionResult, DynProofExprBuilder, WhereExprBuilder},
Expand Down Expand Up @@ -102,8 +103,9 @@ impl QueryContext {

pub fn push_column_ref(&mut self, column: Identifier, column_ref: ColumnRef) {
self.col_ref_counter += 1;
self.push_result_column_ref(column);
self.column_mapping.insert(column, column_ref);
let normalized_column = Ident::new(normalize_ident(column));
self.push_result_column_ref(normalized_column.clone());
self.column_mapping.insert(normalized_column, column_ref);
}

fn push_result_column_ref(&mut self, column: Identifier) {
Expand Down
26 changes: 19 additions & 7 deletions crates/proof-of-sql/src/sql/parse/query_context_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub struct QueryContextBuilder<'a> {
context: QueryContext,
schema_accessor: &'a dyn SchemaAccessor,
}
use crate::base::sqlparser::normalize_ident;
use sqlparser::ast::Ident;

// Public interface
impl<'a> QueryContextBuilder<'a> {
Expand Down Expand Up @@ -92,7 +94,7 @@ impl<'a> QueryContextBuilder<'a> {
group_by_exprs: Vec<Identifier>,
) -> ConversionResult<Self> {
for id in &group_by_exprs {
self.visit_column_identifier(*id)?;
self.visit_column_identifier(id)?;
}
self.context.set_group_by_exprs(group_by_exprs);
Ok(self)
Expand All @@ -112,8 +114,14 @@ impl<'a> QueryContextBuilder<'a> {
)]
fn lookup_schema(&self) -> Vec<(Identifier, ColumnType)> {
let table_ref = self.context.get_table_ref();
let columns = self.schema_accessor.lookup_schema(*table_ref);
let mut columns = self.schema_accessor.lookup_schema(*table_ref);
assert!(!columns.is_empty(), "At least one column must exist");
// Normalize all column names
for (ident, _) in &mut columns {
let normalized_name = normalize_ident(ident.clone());
*ident = Ident::new(normalized_name);
}

columns
}

Expand Down Expand Up @@ -153,7 +161,7 @@ impl<'a> QueryContextBuilder<'a> {
_ => panic!("Must be a column expression"),
};

self.visit_column_identifier(identifier)
self.visit_column_identifier(&identifier.into())
}

fn visit_binary_expr(
Expand Down Expand Up @@ -255,18 +263,22 @@ impl<'a> QueryContextBuilder<'a> {
}
}

fn visit_column_identifier(&mut self, column_name: Identifier) -> ConversionResult<ColumnType> {
fn visit_column_identifier(&mut self, column_name: &Ident) -> ConversionResult<ColumnType> {
let table_ref = self.context.get_table_ref();
let column_type = self.schema_accessor.lookup_column(*table_ref, column_name);
// Normalize the column name before looking it up
let normalized_column_name = Ident::new(normalize_ident(column_name.clone()));
let column_type = self
.schema_accessor
.lookup_column(*table_ref, normalized_column_name.clone());

let column_type = column_type.ok_or_else(|| ConversionError::MissingColumn {
identifier: Box::new(column_name),
resource_id: Box::new(table_ref.resource_id()),
})?;

let column = ColumnRef::new(*table_ref, column_name, column_type);
let column = ColumnRef::new(*table_ref, normalized_column_name.clone(), column_type);

self.context.push_column_ref(column_name, column);
self.context.push_column_ref(normalized_column_name, column);

Ok(column_type)
}
Expand Down
2 changes: 1 addition & 1 deletion crates/proof-of-sql/src/sql/proof/query_proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ fn extend_transcript_with_owned_table<S: Scalar, T: Transcript>(
result: &OwnedTable<S>,
) {
for (name, column) in result.inner_table() {
transcript.extend_as_le_from_refs([name.as_str()]);
transcript.extend_as_le_from_refs([name.value.as_str()]);
match column {
OwnedColumn::Boolean(col) => transcript.extend_as_be(col.iter().map(|&b| u8::from(b))),
OwnedColumn::TinyInt(col) => transcript.extend_as_be_from_refs(col),
Expand Down

0 comments on commit 044b954

Please sign in to comment.