diff --git a/crates/proof-of-sql/Cargo.toml b/crates/proof-of-sql/Cargo.toml index 737050a54..8772aa4ff 100644 --- a/crates/proof-of-sql/Cargo.toml +++ b/crates/proof-of-sql/Cargo.toml @@ -50,7 +50,7 @@ serde = { workspace = true, features = ["serde_derive"] } serde_json = { workspace = true } sha2 = { workspace = true } snafu = { workspace = true } -sqlparser = { workspace = true } +sqlparser = { workspace = true, features = ["serde"] } sysinfo = {workspace = true } tiny-keccak = { workspace = true } tracing = { workspace = true, features = ["attributes"] } diff --git a/crates/proof-of-sql/benches/scaffold/benchmark_accessor.rs b/crates/proof-of-sql/benches/scaffold/benchmark_accessor.rs index 1046fc0e3..568ef9195 100644 --- a/crates/proof-of-sql/benches/scaffold/benchmark_accessor.rs +++ b/crates/proof-of-sql/benches/scaffold/benchmark_accessor.rs @@ -6,15 +6,14 @@ use proof_of_sql::base::{ SchemaAccessor, TableRef, }, }; -use proof_of_sql_parser::Identifier; - +use sqlparser::ast::Ident; #[derive(Default)] pub struct BenchmarkAccessor<'a, C: Commitment> { columns: IndexMap>, lengths: IndexMap, commitments: IndexMap, - column_types: IndexMap<(TableRef, Identifier), ColumnType>, - table_schemas: IndexMap>, + column_types: IndexMap<(TableRef, Ident), ColumnType>, + table_schemas: IndexMap>, } impl<'a, C: Commitment> BenchmarkAccessor<'a, C> { @@ -24,14 +23,14 @@ impl<'a, C: Commitment> BenchmarkAccessor<'a, C> { pub fn insert_table( &mut self, table_ref: TableRef, - columns: &[(Identifier, Column<'a, C::Scalar>)], + columns: &[(Ident, Column<'a, C::Scalar>)], setup: &C::PublicSetup<'_>, ) { self.table_schemas.insert( table_ref, columns .iter() - .map(|(id, col)| (*id, col.column_type())) + .map(|(id, col)| (id.clone(), col.column_type())) .collect(), ); @@ -45,15 +44,15 @@ impl<'a, C: Commitment> BenchmarkAccessor<'a, C> { let mut length = None; for (column, commitment) in columns.iter().zip(commitments) { self.columns.insert( - ColumnRef::new(table_ref, column.0, column.1.column_type()), + ColumnRef::new(table_ref, column.0.clone(), column.1.column_type()), column.1, ); self.commitments.insert( - ColumnRef::new(table_ref, column.0, column.1.column_type()), + ColumnRef::new(table_ref, column.0.clone(), column.1.column_type()), commitment, ); self.column_types - .insert((table_ref, column.0), column.1.column_type()); + .insert((table_ref, column.0.clone()), column.1.column_type()); if let Some(len) = length { assert!(len == column.1.len()); @@ -93,13 +92,13 @@ impl CommitmentAccessor for BenchmarkAccessor<'_, C> { } } impl SchemaAccessor for BenchmarkAccessor<'_, C> { - fn lookup_column(&self, table_ref: TableRef, column_id: Identifier) -> Option { + fn lookup_column(&self, table_ref: TableRef, column_id: Ident) -> Option { self.column_types.get(&(table_ref, column_id)).copied() } /// # Panics /// /// Will panic if the table reference does not exist in the table schemas map. - fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Identifier, ColumnType)> { + fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Ident, ColumnType)> { self.table_schemas.get(&table_ref).unwrap().clone() } } diff --git a/crates/proof-of-sql/benches/scaffold/mod.rs b/crates/proof-of-sql/benches/scaffold/mod.rs index a01619643..0e9aff1be 100644 --- a/crates/proof-of-sql/benches/scaffold/mod.rs +++ b/crates/proof-of-sql/benches/scaffold/mod.rs @@ -33,8 +33,7 @@ fn scaffold<'a, CP: CommitmentEvaluationProof>( &generate_random_columns(alloc, rng, columns, size), prover_setup, ); - let query = - QueryExpr::try_new(query.parse().unwrap(), "bench".parse().unwrap(), accessor).unwrap(); + let query = QueryExpr::try_new(query.parse().unwrap(), "bench".into(), accessor).unwrap(); let result = VerifiableQueryResult::new(query.proof_expr(), accessor, prover_setup); (query, result) } diff --git a/crates/proof-of-sql/benches/scaffold/random_util.rs b/crates/proof-of-sql/benches/scaffold/random_util.rs index a8971eb1b..cdc42775d 100644 --- a/crates/proof-of-sql/benches/scaffold/random_util.rs +++ b/crates/proof-of-sql/benches/scaffold/random_util.rs @@ -3,8 +3,8 @@ use proof_of_sql::base::{ database::{Column, ColumnType}, scalar::Scalar, }; -use proof_of_sql_parser::Identifier; use rand::Rng; +use sqlparser::ast::Ident; pub type OptionalRandBound = Option i64>; /// # Panics @@ -18,12 +18,12 @@ pub fn generate_random_columns<'a, S: Scalar>( rng: &mut impl Rng, columns: &[(&str, ColumnType, OptionalRandBound)], num_rows: usize, -) -> Vec<(Identifier, Column<'a, S>)> { +) -> Vec<(Ident, Column<'a, S>)> { columns .iter() .map(|(id, ty, bound)| { ( - id.parse().unwrap(), + Ident::new(*id), match (ty, bound) { (ColumnType::Boolean, _) => { Column::Boolean(alloc.alloc_slice_fill_with(num_rows, |_| rng.gen())) diff --git a/crates/proof-of-sql/examples/albums/main.rs b/crates/proof-of-sql/examples/albums/main.rs index 5e08ef872..ddcaec30a 100644 --- a/crates/proof-of-sql/examples/albums/main.rs +++ b/crates/proof-of-sql/examples/albums/main.rs @@ -38,8 +38,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "albums".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "albums".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/avocado-prices/avocado-prices.csv b/crates/proof-of-sql/examples/avocado-prices/avocado-prices.csv index 7750f7a46..0663d1284 100644 --- a/crates/proof-of-sql/examples/avocado-prices/avocado-prices.csv +++ b/crates/proof-of-sql/examples/avocado-prices/avocado-prices.csv @@ -1,4 +1,4 @@ -Year,Price +year,price 1990,96 1991,100 1992,269 diff --git a/crates/proof-of-sql/examples/avocado-prices/main.rs b/crates/proof-of-sql/examples/avocado-prices/main.rs index 257bbf933..c1a9c47a1 100644 --- a/crates/proof-of-sql/examples/avocado-prices/main.rs +++ b/crates/proof-of-sql/examples/avocado-prices/main.rs @@ -42,8 +42,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "avocado".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "avocado".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/books/main.rs b/crates/proof-of-sql/examples/books/main.rs index a506fb047..5b20c384c 100644 --- a/crates/proof-of-sql/examples/books/main.rs +++ b/crates/proof-of-sql/examples/books/main.rs @@ -38,8 +38,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "books".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "books".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/brands/brands.csv b/crates/proof-of-sql/examples/brands/brands.csv index f75cdf6a9..26cffcd47 100644 --- a/crates/proof-of-sql/examples/brands/brands.csv +++ b/crates/proof-of-sql/examples/brands/brands.csv @@ -1,4 +1,4 @@ -Name,Country,Founded,Revenue +name,country,founded,revenue Apple,United States,1976,365.82 Samsung,South Korea,1938,200.73 Microsoft,United States,1975,198.27 diff --git a/crates/proof-of-sql/examples/brands/main.rs b/crates/proof-of-sql/examples/brands/main.rs index f648c56e2..4ba531ff9 100644 --- a/crates/proof-of-sql/examples/brands/main.rs +++ b/crates/proof-of-sql/examples/brands/main.rs @@ -38,8 +38,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "brands".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "brands".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/census/census-income.csv b/crates/proof-of-sql/examples/census/census-income.csv index 0accbfc01..f2b709177 100644 --- a/crates/proof-of-sql/examples/census/census-income.csv +++ b/crates/proof-of-sql/examples/census/census-income.csv @@ -1,4 +1,4 @@ -Id,Geography,Id2,Households_Estimate_Total +id,geography,id2,households_estimate_total 0400000US01,Alabama,1,1837292 0400000US02,Alaska,2,250875 0400000US04,Arizona,4,2381501 diff --git a/crates/proof-of-sql/examples/census/main.rs b/crates/proof-of-sql/examples/census/main.rs index c3e82a986..a6af569ac 100644 --- a/crates/proof-of-sql/examples/census/main.rs +++ b/crates/proof-of-sql/examples/census/main.rs @@ -45,8 +45,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "census".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "census".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/countries/countries_gdp.csv b/crates/proof-of-sql/examples/countries/countries_gdp.csv index 397102f8f..521a01900 100644 --- a/crates/proof-of-sql/examples/countries/countries_gdp.csv +++ b/crates/proof-of-sql/examples/countries/countries_gdp.csv @@ -1,4 +1,4 @@ -Country,Continent,GDP,GDPP +country,continent,gdp,gdpp UnitedStates,NorthAmerica,21137,63543 China,Asia,14342,10261 Japan,Asia,5081,40293 diff --git a/crates/proof-of-sql/examples/countries/main.rs b/crates/proof-of-sql/examples/countries/main.rs index 9e1b05205..4fe584271 100644 --- a/crates/proof-of-sql/examples/countries/main.rs +++ b/crates/proof-of-sql/examples/countries/main.rs @@ -39,7 +39,7 @@ fn prove_and_verify_query( println!("Parsing the query: {sql}..."); let now = Instant::now(); let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "countries".parse().unwrap(), accessor).unwrap(); + QueryExpr::try_new(sql.parse().unwrap(), "countries".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/dinosaurs/main.rs b/crates/proof-of-sql/examples/dinosaurs/main.rs index e36c917d9..fc418f5af 100644 --- a/crates/proof-of-sql/examples/dinosaurs/main.rs +++ b/crates/proof-of-sql/examples/dinosaurs/main.rs @@ -39,7 +39,7 @@ fn prove_and_verify_query( println!("Parsing the query: {sql}..."); let now = Instant::now(); let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "dinosaurs".parse().unwrap(), accessor).unwrap(); + QueryExpr::try_new(sql.parse().unwrap(), "dinosaurs".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/dog_breeds/dog_breeds.csv b/crates/proof-of-sql/examples/dog_breeds/dog_breeds.csv index 20f55743e..90d2c2826 100644 --- a/crates/proof-of-sql/examples/dog_breeds/dog_breeds.csv +++ b/crates/proof-of-sql/examples/dog_breeds/dog_breeds.csv @@ -1,4 +1,4 @@ -Name,Origin,Size,Lifespan +name,origin,size,lifespan Labrador Retriever,Canada,Large,12 German Shepherd,Germany,Large,11 Chihuahua,Mexico,Small,14 diff --git a/crates/proof-of-sql/examples/dog_breeds/main.rs b/crates/proof-of-sql/examples/dog_breeds/main.rs index cc5aef3de..e7c58e998 100644 --- a/crates/proof-of-sql/examples/dog_breeds/main.rs +++ b/crates/proof-of-sql/examples/dog_breeds/main.rs @@ -35,12 +35,8 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = QueryExpr::try_new( - sql.parse().unwrap(), - "dog_breeds".parse().unwrap(), - accessor, - ) - .unwrap(); + let query_plan = + QueryExpr::try_new(sql.parse().unwrap(), "dog_breeds".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/hello_world/main.rs b/crates/proof-of-sql/examples/hello_world/main.rs index cfb564ef3..54c97d1b7 100644 --- a/crates/proof-of-sql/examples/hello_world/main.rs +++ b/crates/proof-of-sql/examples/hello_world/main.rs @@ -64,7 +64,7 @@ fn main() { let timer = start_timer("Parsing Query"); let query = QueryExpr::try_new( "SELECT b FROM table WHERE a = 2".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); diff --git a/crates/proof-of-sql/examples/plastics/main.rs b/crates/proof-of-sql/examples/plastics/main.rs index 873fa6ada..681bd1020 100644 --- a/crates/proof-of-sql/examples/plastics/main.rs +++ b/crates/proof-of-sql/examples/plastics/main.rs @@ -38,8 +38,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "plastics".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "plastics".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/plastics/plastics.csv b/crates/proof-of-sql/examples/plastics/plastics.csv index 9b793da0a..d1ba3508d 100644 --- a/crates/proof-of-sql/examples/plastics/plastics.csv +++ b/crates/proof-of-sql/examples/plastics/plastics.csv @@ -1,4 +1,4 @@ -Name,Code,Density,Biodegradable +name,code,density,biodegradable Polyethylene Terephthalate (PET),1,1.38,FALSE High-Density Polyethylene (HDPE),2,0.97,FALSE Polyvinyl Chloride (PVC),3,1.40,FALSE diff --git a/crates/proof-of-sql/examples/posql_db/commit_accessor.rs b/crates/proof-of-sql/examples/posql_db/commit_accessor.rs index b05e96050..5f6edf33f 100644 --- a/crates/proof-of-sql/examples/posql_db/commit_accessor.rs +++ b/crates/proof-of-sql/examples/posql_db/commit_accessor.rs @@ -55,7 +55,7 @@ impl SchemaAccessor for CommitAccessor { fn lookup_column( &self, table_ref: proof_of_sql::base::database::TableRef, - column_id: proof_of_sql_parser::Identifier, + column_id: sqlparser::ast::Ident, ) -> Option { self.inner.lookup_column(table_ref, column_id) } @@ -64,7 +64,7 @@ impl SchemaAccessor for CommitAccessor { &self, table_ref: proof_of_sql::base::database::TableRef, ) -> Vec<( - proof_of_sql_parser::Identifier, + sqlparser::ast::Ident, proof_of_sql::base::database::ColumnType, )> { self.inner.lookup_schema(table_ref) diff --git a/crates/proof-of-sql/examples/posql_db/csv_accessor.rs b/crates/proof-of-sql/examples/posql_db/csv_accessor.rs index f513db428..e5d401471 100644 --- a/crates/proof-of-sql/examples/posql_db/csv_accessor.rs +++ b/crates/proof-of-sql/examples/posql_db/csv_accessor.rs @@ -96,7 +96,7 @@ impl SchemaAccessor for CsvDataAccessor { fn lookup_column( &self, table_ref: TableRef, - column_id: proof_of_sql_parser::Identifier, + column_id: sqlparser::ast::Ident, ) -> Option { self.inner.lookup_column(table_ref, column_id) } @@ -104,7 +104,7 @@ impl SchemaAccessor for CsvDataAccessor { &self, table_ref: TableRef, ) -> Vec<( - proof_of_sql_parser::Identifier, + sqlparser::ast::Ident, proof_of_sql::base::database::ColumnType, )> { self.inner.lookup_schema(table_ref) diff --git a/crates/proof-of-sql/examples/posql_db/main.rs b/crates/proof-of-sql/examples/posql_db/main.rs index 40c9a14ae..d591651e2 100644 --- a/crates/proof-of-sql/examples/posql_db/main.rs +++ b/crates/proof-of-sql/examples/posql_db/main.rs @@ -203,7 +203,7 @@ fn main() { commit_accessor .lookup_schema(table_name) .iter() - .map(|(i, t)| Field::new(i.as_str(), t.into(), false)) + .map(|(i, t)| Field::new(i.value.as_str(), t.into(), false)) .collect::>(), ); let append_batch = @@ -233,15 +233,14 @@ fn main() { commit_accessor .lookup_schema(table) .iter() - .map(|(i, t)| Field::new(i.as_str(), t.into(), false)) + .map(|(i, t)| Field::new(i.value.as_str(), t.into(), false)) .collect::>(), ); csv_accessor .load_table(table, schema) .expect("Failed to load table"); } - let query = - QueryExpr::try_new(query, "example".parse().unwrap(), &commit_accessor).unwrap(); + let query = QueryExpr::try_new(query, "example".into(), &commit_accessor).unwrap(); let timer = start_timer("Generating Proof"); let proof = VerifiableQueryResult::::new( query.proof_expr(), @@ -265,8 +264,7 @@ fn main() { .load_commit(table_name) .expect("Failed to load commit"); } - let query = - QueryExpr::try_new(query, "example".parse().unwrap(), &commit_accessor).unwrap(); + let query = QueryExpr::try_new(query, "example".into(), &commit_accessor).unwrap(); let result: VerifiableQueryResult = postcard::from_bytes(&fs::read(file).expect("Failed to read proof")) .expect("Failed to deserialize proof"); diff --git a/crates/proof-of-sql/examples/posql_db/record_batch_accessor.rs b/crates/proof-of-sql/examples/posql_db/record_batch_accessor.rs index 08e25f4fe..fc84cb14a 100644 --- a/crates/proof-of-sql/examples/posql_db/record_batch_accessor.rs +++ b/crates/proof-of-sql/examples/posql_db/record_batch_accessor.rs @@ -8,8 +8,7 @@ use proof_of_sql::base::{ }, scalar::Scalar, }; -use proof_of_sql_parser::Identifier; - +use sqlparser::ast::Ident; #[derive(Default)] /// An implementation of a data accessor that uses a record batch as the underlying data source. /// @@ -31,7 +30,7 @@ impl DataAccessor for RecordBatchAccessor { .get(&column.table_ref()) .expect("Table not found."); let arrow_column = table - .column_by_name(column.column_id().as_str()) + .column_by_name(column.column_id().value.as_str()) .expect("Column not found."); let result = arrow_column .to_column(&self.alloc, &(0..table.num_rows()), None) @@ -58,12 +57,12 @@ impl MetadataAccessor for RecordBatchAccessor { } } impl SchemaAccessor for RecordBatchAccessor { - fn lookup_column(&self, table_ref: TableRef, column_id: Identifier) -> Option { + fn lookup_column(&self, table_ref: TableRef, column_id: Ident) -> Option { self.tables .get(&table_ref) .expect("Table not found.") .schema() - .column_with_name(column_id.as_str()) + .column_with_name(column_id.value.as_str()) .map(|(_, f)| { f.data_type() .clone() @@ -72,7 +71,7 @@ impl SchemaAccessor for RecordBatchAccessor { }) } - fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Identifier, ColumnType)> { + fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Ident, ColumnType)> { self.tables .get(&table_ref) .expect("Table not found.") @@ -81,7 +80,7 @@ impl SchemaAccessor for RecordBatchAccessor { .iter() .map(|field| { ( - field.name().parse().expect("Failed to parse field name."), + Ident::new(field.name()), field .data_type() .clone() diff --git a/crates/proof-of-sql/examples/programming_books/main.rs b/crates/proof-of-sql/examples/programming_books/main.rs index b8f09da9a..1b125e16d 100644 --- a/crates/proof-of-sql/examples/programming_books/main.rs +++ b/crates/proof-of-sql/examples/programming_books/main.rs @@ -35,12 +35,8 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = QueryExpr::try_new( - sql.parse().unwrap(), - "programming_books".parse().unwrap(), - accessor, - ) - .unwrap(); + let query_plan = + QueryExpr::try_new(sql.parse().unwrap(), "programming_books".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/rockets/main.rs b/crates/proof-of-sql/examples/rockets/main.rs index 331913f83..e4cfc89fa 100644 --- a/crates/proof-of-sql/examples/rockets/main.rs +++ b/crates/proof-of-sql/examples/rockets/main.rs @@ -38,8 +38,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "rockets".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "rockets".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/space/main.rs b/crates/proof-of-sql/examples/space/main.rs index 37f0356fd..88056f1e6 100644 --- a/crates/proof-of-sql/examples/space/main.rs +++ b/crates/proof-of-sql/examples/space/main.rs @@ -47,8 +47,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "space".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "space".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/space/planets.csv b/crates/proof-of-sql/examples/space/planets.csv index cc0d42a63..e249a9986 100644 --- a/crates/proof-of-sql/examples/space/planets.csv +++ b/crates/proof-of-sql/examples/space/planets.csv @@ -1,4 +1,4 @@ -Name,Distance,Dwarf,Density +name,distance,dwarf,density Mercury,36,FALSE,5400 Venus,67,FALSE,5200 Earth,93,FALSE,5500 diff --git a/crates/proof-of-sql/examples/space/space_travellers.csv b/crates/proof-of-sql/examples/space/space_travellers.csv index 47bce0fe5..1356b9a25 100644 --- a/crates/proof-of-sql/examples/space/space_travellers.csv +++ b/crates/proof-of-sql/examples/space/space_travellers.csv @@ -1,4 +1,4 @@ -Number,Name,Nationality,Date,Flight +number,name,nationality,date,flight 1,Yuri Gagarin,Soviet Union,1961-04-12T00:00:00Z,Vostok 1 2,Alan Shepard,United States,1961-05-05,Freedom 7 3,Virgil Grissom,United States,1961-07-21,Liberty Bell 7 diff --git a/crates/proof-of-sql/examples/stocks/main.rs b/crates/proof-of-sql/examples/stocks/main.rs index 7a34ec96e..0d1907f7a 100644 --- a/crates/proof-of-sql/examples/stocks/main.rs +++ b/crates/proof-of-sql/examples/stocks/main.rs @@ -38,8 +38,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "stocks".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "stocks".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/stocks/stocks.csv b/crates/proof-of-sql/examples/stocks/stocks.csv index 1e00ae80d..52a400471 100644 --- a/crates/proof-of-sql/examples/stocks/stocks.csv +++ b/crates/proof-of-sql/examples/stocks/stocks.csv @@ -1,4 +1,4 @@ -Symbol,Company,Sector,Price,Volume,MarketCap,PE_Ratio,DividendYield +symbol,company,sector,price,volume,marketcap,pe_ratio,dividendyield AAPL,Apple Inc.,Technology,175.50,52000000,2850.25,28.5,0.5 MSFT,Microsoft Corporation,Technology,325.75,25000000,2425.80,32.8,0.8 GOOGL,Alphabet Inc.,Technology,135.20,18000000,1720.40,25.2,0.0 diff --git a/crates/proof-of-sql/examples/sushi/fish.csv b/crates/proof-of-sql/examples/sushi/fish.csv index e0a14ebc0..d6a65c3b8 100644 --- a/crates/proof-of-sql/examples/sushi/fish.csv +++ b/crates/proof-of-sql/examples/sushi/fish.csv @@ -1,4 +1,4 @@ -nameEn,nameJa,kindEn,kindJa,pricePerPound +name_en,name_ja,kind_en,kind_ja,price_per_pound Tuna,Maguro,Lean Red Meat,Akami,25 Tuna,Maguro,Medium Fat Red Meat,Toro,65 Tuna,Maguro,Fatty Red Meat,Otoro,115 diff --git a/crates/proof-of-sql/examples/sushi/main.rs b/crates/proof-of-sql/examples/sushi/main.rs index 05e78bdd8..c3fbc0050 100644 --- a/crates/proof-of-sql/examples/sushi/main.rs +++ b/crates/proof-of-sql/examples/sushi/main.rs @@ -29,8 +29,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "sushi".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "sushi".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: print!("Generating proof..."); @@ -88,42 +87,42 @@ fn main() { ); prove_and_verify_query( - "SELECT COUNT(*) FROM fish WHERE nameEn = 'Tuna'", + "SELECT COUNT(*) FROM fish WHERE name_En = 'Tuna'", &accessor, &prover_setup, &verifier_setup, ); prove_and_verify_query( - "SELECT kindEn FROM fish WHERE kindJa = 'Otoro'", + "SELECT kind_En FROM fish WHERE kind_Ja = 'Otoro'", &accessor, &prover_setup, &verifier_setup, ); prove_and_verify_query( - "SELECT kindEn FROM fish WHERE kindJa = 'Otoro'", + "SELECT kind_En FROM fish WHERE kind_Ja = 'Otoro'", &accessor, &prover_setup, &verifier_setup, ); prove_and_verify_query( - "SELECT * FROM fish WHERE pricePerPound > 25 AND pricePerPound < 75", + "SELECT * FROM fish WHERE price_Per_Pound > 25 AND price_Per_Pound < 75", &accessor, &prover_setup, &verifier_setup, ); prove_and_verify_query( - "SELECT kindJa, COUNT(*) FROM fish GROUP BY kindJa", + "SELECT kind_Ja, COUNT(*) FROM fish GROUP BY kind_Ja", &accessor, &prover_setup, &verifier_setup, ); prove_and_verify_query( - "SELECT kindJa, pricePerPound FROM fish WHERE nameEn = 'Tuna' ORDER BY pricePerPound ASC", + "SELECT kind_Ja, price_Per_Pound FROM fish WHERE name_En = 'Tuna' ORDER BY price_Per_Pound ASC", &accessor, &prover_setup, &verifier_setup, diff --git a/crates/proof-of-sql/examples/tech_gadget_prices/main.rs b/crates/proof-of-sql/examples/tech_gadget_prices/main.rs index c886ef777..ad041fdfb 100644 --- a/crates/proof-of-sql/examples/tech_gadget_prices/main.rs +++ b/crates/proof-of-sql/examples/tech_gadget_prices/main.rs @@ -14,6 +14,7 @@ use proof_of_sql::{ sql::{parse::QueryExpr, proof::VerifiableQueryResult}, }; use rand::{rngs::StdRng, SeedableRng}; +use sqlparser::ast::Ident; use std::{error::Error, fs::File, time::Instant}; const DORY_SETUP_MAX_NU: usize = 8; @@ -27,7 +28,7 @@ fn prove_and_verify_query( ) -> Result<(), Box> { println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = QueryExpr::try_new(sql.parse()?, "tech_gadget_prices".parse()?, accessor)?; + let query_plan = QueryExpr::try_new(sql.parse()?, Ident::new("tech_gadget_prices"), accessor)?; println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); print!("Generating proof..."); diff --git a/crates/proof-of-sql/examples/tech_gadget_prices/tech_gadget_prices.csv b/crates/proof-of-sql/examples/tech_gadget_prices/tech_gadget_prices.csv index e03e8e90d..b6d24771b 100644 --- a/crates/proof-of-sql/examples/tech_gadget_prices/tech_gadget_prices.csv +++ b/crates/proof-of-sql/examples/tech_gadget_prices/tech_gadget_prices.csv @@ -1,4 +1,4 @@ -Name,Brand,Category,ReleaseYear,Price +name,brand,category,releaseyear,price iPhone 13,Apple,Smartphone,2021,799 Galaxy S21,Samsung,Smartphone,2021,799 PlayStation 5,Sony,Game Console,2020,499 diff --git a/crates/proof-of-sql/examples/vehicles/main.rs b/crates/proof-of-sql/examples/vehicles/main.rs index e06a9f974..b459d52c7 100644 --- a/crates/proof-of-sql/examples/vehicles/main.rs +++ b/crates/proof-of-sql/examples/vehicles/main.rs @@ -38,8 +38,7 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = - QueryExpr::try_new(sql.parse().unwrap(), "vehicles".parse().unwrap(), accessor).unwrap(); + let query_plan = QueryExpr::try_new(sql.parse().unwrap(), "vehicles".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/examples/wood_types/main.rs b/crates/proof-of-sql/examples/wood_types/main.rs index 30f364da8..1e0ce4259 100644 --- a/crates/proof-of-sql/examples/wood_types/main.rs +++ b/crates/proof-of-sql/examples/wood_types/main.rs @@ -38,12 +38,8 @@ fn prove_and_verify_query( // Parse the query: println!("Parsing the query: {sql}..."); let now = Instant::now(); - let query_plan = QueryExpr::try_new( - sql.parse().unwrap(), - "wood_types".parse().unwrap(), - accessor, - ) - .unwrap(); + let query_plan = + QueryExpr::try_new(sql.parse().unwrap(), "wood_types".into(), accessor).unwrap(); println!("Done in {} ms.", now.elapsed().as_secs_f64() * 1000.); // Generate the proof and result: diff --git a/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs b/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs index 5eade6cf3..dbd218383 100644 --- a/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs +++ b/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs @@ -71,7 +71,7 @@ impl TryFrom for ColumnType { impl From<&ColumnField> for Field { fn from(column_field: &ColumnField) -> Self { Field::new( - column_field.name().name(), + column_field.name().value.as_str(), (&column_field.data_type()).into(), false, ) diff --git a/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs index 74ad96839..3b92bf911 100644 --- a/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs +++ b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs @@ -32,9 +32,10 @@ use arrow::{ }; use proof_of_sql_parser::{ posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestampError}, - Identifier, ParseError, + ParseError, }; use snafu::Snafu; +use sqlparser::ast::Ident; #[derive(Snafu, Debug)] #[non_exhaustive] @@ -123,7 +124,9 @@ impl TryFrom> for RecordBatch { value .into_inner() .into_iter() - .map(|(identifier, owned_column)| (identifier, ArrayRef::from(owned_column))), + .map(|(identifier, owned_column)| { + (identifier.value, ArrayRef::from(owned_column)) + }), ) } } @@ -300,7 +303,7 @@ impl TryFrom for OwnedTable { .zip(value.columns()) .map(|(field, array_ref)| { let owned_column = OwnedColumn::try_from(array_ref)?; - let identifier = Identifier::try_new(field.name())?; //This may always succeed. + let identifier = Ident::new(field.name()); Ok((identifier, owned_column)) }) .collect(); diff --git a/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions_test.rs b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions_test.rs index b6d6a773c..de9db972e 100644 --- a/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions_test.rs +++ b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions_test.rs @@ -129,18 +129,6 @@ fn we_can_convert_between_owned_table_and_record_batch() { ); } -#[test] -fn we_cannot_convert_a_record_batch_if_it_has_repeated_column_names() { - let record_batch = record_batch!( - "a" => [0_i64; 0], - "A" => [0_i128; 0], - ); - assert!(matches!( - OwnedTable::::try_from(record_batch), - Err(OwnedArrowConversionError::DuplicateIdentifiers) - )); -} - #[test] #[should_panic(expected = "not implemented: Cannot convert Scalar type to arrow type")] fn we_panic_when_converting_an_owned_table_with_a_scalar_column() { diff --git a/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs b/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs index 6f24457cc..9dcfe89e1 100644 --- a/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs +++ b/crates/proof-of-sql/src/base/arrow/record_batch_conversion.rs @@ -12,7 +12,7 @@ use crate::base::{ }; use arrow::record_batch::RecordBatch; use bumpalo::Bump; -use proof_of_sql_parser::Identifier; +use sqlparser::ast::Ident; /// This function will return an error if: /// - The field name cannot be parsed into an [`Identifier`]. @@ -20,14 +20,14 @@ use proof_of_sql_parser::Identifier; pub fn batch_to_columns<'a, S: Scalar + 'a>( batch: &'a RecordBatch, alloc: &'a Bump, -) -> Result)>, RecordBatchToColumnsError> { +) -> Result)>, RecordBatchToColumnsError> { batch .schema() .fields() .into_iter() .zip(batch.columns()) .map(|(field, array)| { - let identifier: Identifier = field.name().parse()?; + let identifier: Ident = field.name().as_str().into(); let column: Column = array.to_column(alloc, &(0..array.len()), None)?; Ok((identifier, column)) }) @@ -115,12 +115,9 @@ mod tests { let b_scals = ["1".into(), "2".into(), "3".into()]; let columns = [ + (&"a".into(), &Column::::BigInt(&[1, 2, 3])), ( - &"a".parse().unwrap(), - &Column::::BigInt(&[1, 2, 3]), - ), - ( - &"b".parse().unwrap(), + &"b".into(), &Column::::VarChar((&["1", "2", "3"], &b_scals)), ), ]; @@ -142,12 +139,9 @@ mod tests { let b_scals2 = ["4".into(), "5".into(), "6".into()]; let columns2 = [ + (&"a".into(), &Column::::BigInt(&[4, 5, 6])), ( - &"a".parse().unwrap(), - &Column::::BigInt(&[4, 5, 6]), - ), - ( - &"b".parse().unwrap(), + &"b".into(), &Column::::VarChar((&["4", "5", "6"], &b_scals2)), ), ]; diff --git a/crates/proof-of-sql/src/base/commitment/column_commitment_metadata_map.rs b/crates/proof-of-sql/src/base/commitment/column_commitment_metadata_map.rs index ab0b90e4b..3e2a41ef3 100644 --- a/crates/proof-of-sql/src/base/commitment/column_commitment_metadata_map.rs +++ b/crates/proof-of-sql/src/base/commitment/column_commitment_metadata_map.rs @@ -4,11 +4,11 @@ use super::{ }; use crate::base::{database::ColumnField, map::IndexMap}; use alloc::string::{String, ToString}; -use proof_of_sql_parser::Identifier; use snafu::Snafu; +use sqlparser::ast::Ident; /// Mapping of column identifiers to column metadata used to associate metadata with commitments. -pub type ColumnCommitmentMetadataMap = IndexMap; +pub type ColumnCommitmentMetadataMap = IndexMap; /// During commitment operation, metadata indicates that operand tables cannot be the same. #[derive(Debug, Snafu)] @@ -28,7 +28,7 @@ pub enum ColumnCommitmentsMismatch { #[snafu(display( "column with identifier {id_a} cannot operate with column with identifier {id_b}" ))] - Identifier { + Ident { /// The first column identifier id_a: String, /// The second column identifier @@ -44,7 +44,7 @@ pub trait ColumnCommitmentMetadataMapExt { /// Construct this mapping from an iterator of column identifiers and columns. fn from_columns<'a>( - columns: impl IntoIterator)>, + columns: impl IntoIterator)>, ) -> Self where Self: Sized; @@ -66,7 +66,7 @@ impl ColumnCommitmentMetadataMapExt for ColumnCommitmentMetadataMap { .iter() .map(|f| { ( - f.name(), + f.name().clone(), ColumnCommitmentMetadata::from_column_type_with_max_bounds(f.data_type()), ) }) @@ -74,7 +74,7 @@ impl ColumnCommitmentMetadataMapExt for ColumnCommitmentMetadataMap { } fn from_columns<'a>( - columns: impl IntoIterator)>, + columns: impl IntoIterator)>, ) -> Self where Self: Sized, @@ -82,7 +82,10 @@ impl ColumnCommitmentMetadataMapExt for ColumnCommitmentMetadataMap { columns .into_iter() .map(|(identifier, column)| { - (*identifier, ColumnCommitmentMetadata::from_column(column)) + ( + identifier.clone(), + ColumnCommitmentMetadata::from_column(column), + ) }) .collect() } @@ -99,7 +102,7 @@ impl ColumnCommitmentMetadataMapExt for ColumnCommitmentMetadataMap { .zip(other) .map(|((identifier_a, metadata_a), (identifier_b, metadata_b))| { if identifier_a != identifier_b { - Err(ColumnCommitmentsMismatch::Identifier { + Err(ColumnCommitmentsMismatch::Ident { id_a: identifier_a.to_string(), id_b: identifier_b.to_string(), })?; @@ -122,7 +125,7 @@ impl ColumnCommitmentMetadataMapExt for ColumnCommitmentMetadataMap { .zip(other) .map(|((identifier_a, metadata_a), (identifier_b, metadata_b))| { if identifier_a != identifier_b { - Err(ColumnCommitmentsMismatch::Identifier { + Err(ColumnCommitmentsMismatch::Ident { id_a: identifier_a.to_string(), id_b: identifier_b.to_string(), })?; @@ -148,7 +151,7 @@ mod tests { fn metadata_map_from_owned_table( table: &OwnedTable, ) -> ColumnCommitmentMetadataMap { - let (identifiers, columns): (Vec<&Identifier>, Vec) = table + let (identifiers, columns): (Vec<&Ident>, Vec) = table .inner_table() .into_iter() .map(|(identifier, owned_column)| (identifier, CommittableColumn::from(owned_column))) @@ -176,7 +179,7 @@ mod tests { assert_eq!(metadata_map.len(), 4); let (index_0, metadata_0) = metadata_map.get_index(0).unwrap(); - assert_eq!(index_0, "bigint_column"); + assert_eq!(index_0.value.as_str(), "bigint_column"); assert_eq!(metadata_0.column_type(), &ColumnType::BigInt); if let ColumnBounds::BigInt(Bounds::Sharp(bounds)) = metadata_0.bounds() { assert_eq!(bounds.min(), &-5); @@ -186,7 +189,7 @@ mod tests { } let (index_1, metadata_1) = metadata_map.get_index(1).unwrap(); - assert_eq!(index_1, "int128_column"); + assert_eq!(index_1.value.as_str(), "int128_column"); assert_eq!(metadata_1.column_type(), &ColumnType::Int128); if let ColumnBounds::Int128(Bounds::Sharp(bounds)) = metadata_1.bounds() { assert_eq!(bounds.min(), &100); @@ -196,12 +199,12 @@ mod tests { } let (index_2, metadata_2) = metadata_map.get_index(2).unwrap(); - assert_eq!(index_2, "varchar_column"); + assert_eq!(index_2.value.as_str(), "varchar_column"); assert_eq!(metadata_2.column_type(), &ColumnType::VarChar); assert_eq!(metadata_2.bounds(), &ColumnBounds::NoOrder); let (index_3, metadata_3) = metadata_map.get_index(3).unwrap(); - assert_eq!(index_3, "scalar_column"); + assert_eq!(index_3.value.as_str(), "scalar_column"); assert_eq!(metadata_3.column_type(), &ColumnType::Scalar); assert_eq!(metadata_3.bounds(), &ColumnBounds::NoOrder); } @@ -258,7 +261,7 @@ mod tests { // Check metatadata for ordered columns is mostly the same (now bounded) let (index_0, metadata_0) = b_difference_a.get_index(0).unwrap(); - assert_eq!(index_0, "bigint_column"); + assert_eq!(index_0.value.as_str(), "bigint_column"); assert_eq!(metadata_0.column_type(), &ColumnType::BigInt); if let ColumnBounds::BigInt(Bounds::Bounded(bounds)) = metadata_0.bounds() { assert_eq!(bounds.min(), &-5); @@ -268,7 +271,7 @@ mod tests { } let (index_1, metadata_1) = b_difference_a.get_index(1).unwrap(); - assert_eq!(index_1, "int128_column"); + assert_eq!(index_1.value.as_str(), "int128_column"); assert_eq!(metadata_1.column_type(), &ColumnType::Int128); if let ColumnBounds::Int128(Bounds::Bounded(bounds)) = metadata_1.bounds() { assert_eq!(bounds.min(), &100); diff --git a/crates/proof-of-sql/src/base/commitment/column_commitments.rs b/crates/proof-of-sql/src/base/commitment/column_commitments.rs index 3a0e8b62f..7f4ceb6bb 100644 --- a/crates/proof-of-sql/src/base/commitment/column_commitments.rs +++ b/crates/proof-of-sql/src/base/commitment/column_commitments.rs @@ -12,9 +12,9 @@ use alloc::{ vec::Vec, }; use core::{iter, slice}; -use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; use snafu::Snafu; +use sqlparser::ast::Ident; /// Cannot create commitments with duplicate identifier. #[derive(Debug, Snafu)] @@ -99,7 +99,7 @@ impl ColumnCommitments { /// Returns the commitment with the given identifier. #[must_use] - pub fn get_commitment(&self, identifier: &Identifier) -> Option { + pub fn get_commitment(&self, identifier: &Ident) -> Option { self.column_metadata .get_index_of(identifier) .map(|index| self.commitments[index].clone()) @@ -107,7 +107,7 @@ impl ColumnCommitments { /// Returns the metadata for the commitment with the given identifier. #[must_use] - pub fn get_metadata(&self, identifier: &Identifier) -> Option<&ColumnCommitmentMetadata> { + pub fn get_metadata(&self, identifier: &Ident) -> Option<&ColumnCommitmentMetadata> { self.column_metadata.get(identifier) } @@ -118,7 +118,7 @@ impl ColumnCommitments { /// Returns [`ColumnCommitments`] to the provided columns using the given generator offset pub fn try_from_columns_with_offset<'a, COL>( - columns: impl IntoIterator, + columns: impl IntoIterator, offset: usize, setup: &C::PublicSetup<'_>, ) -> Result, DuplicateIdentifiers> @@ -140,7 +140,7 @@ impl ColumnCommitments { }) .collect::, _>>()?; - let (identifiers, committable_columns): (Vec<&Identifier>, Vec) = + let (identifiers, committable_columns): (Vec<&Ident>, Vec) = unique_columns .into_iter() .map(|(identifier, column)| { @@ -171,7 +171,7 @@ impl ColumnCommitments { #[allow(clippy::missing_panics_doc)] pub fn try_append_rows_with_offset<'a, COL>( &mut self, - columns: impl IntoIterator, + columns: impl IntoIterator, offset: usize, setup: &C::PublicSetup<'_>, ) -> Result<(), AppendColumnCommitmentsError> @@ -193,7 +193,7 @@ impl ColumnCommitments { }) .collect::, _>>()?; - let (identifiers, committable_columns): (Vec<&Identifier>, Vec) = + let (identifiers, committable_columns): (Vec<&Ident>, Vec) = unique_columns .into_iter() .map(|(identifier, column)| { @@ -218,7 +218,7 @@ impl ColumnCommitments { /// Add new columns to this [`ColumnCommitments`] using the given generator offset. pub fn try_extend_columns_with_offset<'a, COL>( &mut self, - columns: impl IntoIterator, + columns: impl IntoIterator, offset: usize, setup: &C::PublicSetup<'_>, ) -> Result<(), DuplicateIdentifiers> @@ -302,11 +302,11 @@ impl ColumnCommitments { /// Owning iterator for [`ColumnCommitments`]. pub type IntoIter = iter::Map< iter::Zip<::IntoIter, vec::IntoIter>, - fn(((Identifier, ColumnCommitmentMetadata), C)) -> (Identifier, ColumnCommitmentMetadata, C), + fn(((Ident, ColumnCommitmentMetadata), C)) -> (Ident, ColumnCommitmentMetadata, C), >; impl IntoIterator for ColumnCommitments { - type Item = (Identifier, ColumnCommitmentMetadata, C); + type Item = (Ident, ColumnCommitmentMetadata, C); type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { self.column_metadata @@ -320,12 +320,12 @@ impl IntoIterator for ColumnCommitments { pub type Iter<'a, C> = iter::Map< iter::Zip<<&'a ColumnCommitmentMetadataMap as IntoIterator>::IntoIter, slice::Iter<'a, C>>, fn( - ((&'a Identifier, &'a ColumnCommitmentMetadata), &'a C), - ) -> (&'a Identifier, &'a ColumnCommitmentMetadata, &'a C), + ((&'a Ident, &'a ColumnCommitmentMetadata), &'a C), + ) -> (&'a Ident, &'a ColumnCommitmentMetadata, &'a C), >; impl<'a, C> IntoIterator for &'a ColumnCommitments { - type Item = (&'a Identifier, &'a ColumnCommitmentMetadata, &'a C); + type Item = (&'a Ident, &'a ColumnCommitmentMetadata, &'a C); type IntoIter = Iter<'a, C>; fn into_iter(self) -> Self::IntoIter { self.column_metadata @@ -335,10 +335,8 @@ impl<'a, C> IntoIterator for &'a ColumnCommitments { } } -impl FromIterator<(Identifier, ColumnCommitmentMetadata, C)> for ColumnCommitments { - fn from_iter>( - iter: T, - ) -> Self { +impl FromIterator<(Ident, ColumnCommitmentMetadata, C)> for ColumnCommitments { + fn from_iter>(iter: T) -> Self { let (column_metadata, commitments) = iter .into_iter() .map(|(identifier, metadata, commitment)| ((identifier, metadata), commitment)) @@ -374,15 +372,18 @@ mod tests { assert!(column_commitments.column_metadata().is_empty()); // nonempty case - let bigint_id: Identifier = "bigint_column".parse().unwrap(); - let varchar_id: Identifier = "varchar_column".parse().unwrap(); - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); + let varchar_id: Ident = "varchar_column".into(); + let scalar_id: Ident = "scalar_column".into(); let owned_table = owned_table::([ - bigint(bigint_id, [1, 5, -5, 0]), + bigint(bigint_id.value.as_str(), [1, 5, -5, 0]), // "int128_column" => [100i128, 200, 300, 400], TODO: enable this column once blitzar // supports it - varchar(varchar_id, ["Lorem", "ipsum", "dolor", "sit"]), - scalar(scalar_id, [1000, 2000, -1000, 0]), + varchar( + varchar_id.value.as_str(), + ["Lorem", "ipsum", "dolor", "sit"], + ), + scalar(scalar_id.value.as_str(), [1000, 2000, -1000, 0]), ]); let column_commitments = @@ -449,15 +450,18 @@ mod tests { #[test] fn we_can_construct_column_commitments_from_iter() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); - let varchar_id: Identifier = "varchar_column".parse().unwrap(); - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); + let varchar_id: Ident = "varchar_column".into(); + let scalar_id: Ident = "scalar_column".into(); let owned_table = owned_table::([ - bigint(bigint_id, [1, 5, -5, 0]), + bigint(bigint_id.value.as_str(), [1, 5, -5, 0]), // "int128_column" => [100i128, 200, 300, 400], TODO: enable this column once blitzar // supports it - varchar(varchar_id, ["Lorem", "ipsum", "dolor", "sit"]), - scalar(scalar_id, [1000, 2000, -1000, 0]), + varchar( + varchar_id.value.as_str(), + ["Lorem", "ipsum", "dolor", "sit"], + ), + scalar(scalar_id.value.as_str(), [1000, 2000, -1000, 0]), ]); let column_commitments_from_columns = @@ -478,9 +482,9 @@ mod tests { } #[test] fn we_cannot_construct_commitments_with_duplicate_identifiers() { - let duplicate_identifier_a = "duplicate_identifier_a".parse().unwrap(); - let duplicate_identifier_b = "duplicate_identifier_b".parse().unwrap(); - let unique_identifier = "unique_identifier".parse().unwrap(); + let duplicate_identifier_a = "duplicate_identifier_a".into(); + let duplicate_identifier_b = "duplicate_identifier_b".into(); + let unique_identifier = "unique_identifier".into(); let empty_column = OwnedColumn::::BigInt(vec![]); @@ -548,13 +552,16 @@ mod tests { #[test] fn we_can_iterate_over_column_commitments() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); - let varchar_id: Identifier = "varchar_column".parse().unwrap(); - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); + let varchar_id: Ident = "varchar_column".into(); + let scalar_id: Ident = "scalar_column".into(); let owned_table = owned_table::([ - bigint(bigint_id, [1, 5, -5, 0]), - varchar(varchar_id, ["Lorem", "ipsum", "dolor", "sit"]), - scalar(scalar_id, [1000, 2000, -1000, 0]), + bigint(bigint_id.value.as_str(), [1, 5, -5, 0]), + varchar( + varchar_id.value.as_str(), + ["Lorem", "ipsum", "dolor", "sit"], + ), + scalar(scalar_id.value.as_str(), [1000, 2000, -1000, 0]), ]); let column_commitments = ColumnCommitments::::try_from_columns_with_offset( @@ -590,19 +597,19 @@ mod tests { #[test] fn we_can_append_rows_to_column_commitments() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let initial_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..2].to_vec()), - varchar(varchar_id, varchar_data[..2].to_vec()), - scalar(scalar_id, scalar_data[..2].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..2].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..2].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..2].to_vec()), ]); let mut column_commitments = @@ -614,9 +621,9 @@ mod tests { .unwrap(); let append_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[2..].to_vec()), - varchar(varchar_id, varchar_data[2..].to_vec()), - scalar(scalar_id, scalar_data[2..].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[2..].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[2..].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[2..].to_vec()), ]); column_commitments @@ -624,9 +631,9 @@ mod tests { .unwrap(); let total_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let expected_column_commitments = @@ -672,7 +679,7 @@ mod tests { assert!(matches!( base_commitments.try_append_rows_with_offset(table_diff_id.inner_table(), 4, &()), Err(AppendColumnCommitmentsError::Mismatch { - source: ColumnCommitmentsMismatch::Identifier { .. } + source: ColumnCommitmentsMismatch::Ident { .. } }) )); @@ -688,18 +695,18 @@ mod tests { #[test] fn we_can_extend_columns_to_column_commitments() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let initial_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), ]); let mut column_commitments = ColumnCommitments::::try_from_columns_with_offset( @@ -709,15 +716,16 @@ mod tests { ) .unwrap(); - let new_columns = owned_table::([scalar(scalar_id, scalar_data)]); + let new_columns = + owned_table::([scalar(scalar_id.value.as_str(), scalar_data)]); column_commitments .try_extend_columns_with_offset(new_columns.inner_table(), 0, &()) .unwrap(); let expected_columns = owned_table::([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let expected_commitments = ColumnCommitments::try_from_columns_with_offset(expected_columns.inner_table(), 0, &()) @@ -728,19 +736,19 @@ mod tests { #[test] fn we_can_add_column_commitments() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let columns_a: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..2].to_vec()), - varchar(varchar_id, varchar_data[..2].to_vec()), - scalar(scalar_id, scalar_data[..2].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..2].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..2].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..2].to_vec()), ]); let column_commitments_a = @@ -752,18 +760,18 @@ mod tests { .unwrap(); let columns_b: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[2..].to_vec()), - varchar(varchar_id, varchar_data[2..].to_vec()), - scalar(scalar_id, scalar_data[2..].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[2..].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[2..].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[2..].to_vec()), ]); let column_commitments_b = ColumnCommitments::try_from_columns_with_offset(columns_b.inner_table(), 2, &()) .unwrap(); let columns_sum: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let column_commitments_sum = ColumnCommitments::try_from_columns_with_offset(columns_sum.inner_table(), 0, &()) @@ -809,7 +817,7 @@ mod tests { .unwrap(); assert!(matches!( base_commitments.clone().try_add(commitments_diff_id), - Err(ColumnCommitmentsMismatch::Identifier { .. }) + Err(ColumnCommitmentsMismatch::Ident { .. }) )); let table_diff_len: OwnedTable = @@ -825,19 +833,19 @@ mod tests { #[test] fn we_can_sub_column_commitments() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let columns_subtrahend: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..2].to_vec()), - varchar(varchar_id, varchar_data[..2].to_vec()), - scalar(scalar_id, scalar_data[..2].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..2].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..2].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..2].to_vec()), ]); let column_commitments_subtrahend = @@ -849,9 +857,9 @@ mod tests { .unwrap(); let columns_minuend: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let column_commitments_minuend = ColumnCommitments::try_from_columns_with_offset(columns_minuend.inner_table(), 0, &()) @@ -862,9 +870,9 @@ mod tests { .unwrap(); let expected_difference_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[2..].to_vec()), - varchar(varchar_id, varchar_data[2..].to_vec()), - scalar(scalar_id, scalar_data[2..].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[2..].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[2..].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[2..].to_vec()), ]); let expected_difference = ColumnCommitments::try_from_columns_with_offset( expected_difference_columns.inner_table(), @@ -935,7 +943,7 @@ mod tests { .unwrap(); assert!(matches!( minuend_commitments.clone().try_sub(commitments_diff_id), - Err(ColumnCommitmentsMismatch::Identifier { .. }) + Err(ColumnCommitmentsMismatch::Ident { .. }) )); let table_diff_len: OwnedTable = owned_table([bigint("column_a", [1, 2])]); diff --git a/crates/proof-of-sql/src/base/commitment/query_commitments.rs b/crates/proof-of-sql/src/base/commitment/query_commitments.rs index 1fcd21d66..81339dcb6 100644 --- a/crates/proof-of-sql/src/base/commitment/query_commitments.rs +++ b/crates/proof-of-sql/src/base/commitment/query_commitments.rs @@ -7,7 +7,7 @@ use crate::base::{ map::IndexMap, }; use alloc::vec::Vec; -use proof_of_sql_parser::Identifier; +use sqlparser::ast::Ident; /// The commitments for all of the tables in a query. /// @@ -52,11 +52,14 @@ impl QueryCommitmentsExt for QueryCommitments { table_ref, TableCommitment::from_accessor_with_max_bounds( table_ref, - &accessor + accessor .lookup_schema(table_ref) .iter() - .filter_map(|c| columns.iter().find(|x| x.name() == c.0).copied()) - .collect::>(), + .filter_map(|c| { + columns.iter().find(|x| x.name() == c.0.clone()).cloned() + }) + .collect::>() + .as_slice(), accessor, ), ) @@ -97,7 +100,7 @@ impl SchemaAccessor for QueryCommitments { fn lookup_column( &self, table_ref: crate::base::database::TableRef, - column_id: Identifier, + column_id: Ident, ) -> Option { let table_commitment = self.get(&table_ref)?; @@ -113,14 +116,16 @@ impl SchemaAccessor for QueryCommitments { fn lookup_schema( &self, table_ref: crate::base::database::TableRef, - ) -> Vec<(Identifier, ColumnType)> { + ) -> Vec<(Ident, ColumnType)> { let table_commitment = self.get(&table_ref).unwrap(); table_commitment .column_commitments() .column_metadata() .iter() - .map(|(identifier, column_metadata)| (*identifier, *column_metadata.column_type())) + .map(|(identifier, column_metadata)| { + (identifier.clone(), *column_metadata.column_type()) + }) .collect() } } @@ -160,7 +165,7 @@ mod tests { let no_offset_id = "no.off".parse().unwrap(); let no_columns_commitment = TableCommitment::try_from_columns_with_offset( - Vec::<(&Identifier, &OwnedColumn)>::new(), + Vec::<(&Ident, &OwnedColumn)>::new(), 0, &(), ) @@ -169,7 +174,7 @@ mod tests { let no_rows_commitment = TableCommitment::try_from_columns_with_offset( [( - &"column_c".parse().unwrap(), + &"column_c".into(), &OwnedColumn::::BigInt(vec![]), )], 3, @@ -201,14 +206,18 @@ mod tests { #[allow(clippy::similar_names)] #[test] fn we_can_get_commitment_of_a_column() { - let column_a_id: Identifier = "column_a".parse().unwrap(); - let column_b_id: Identifier = "column_b".parse().unwrap(); + let column_a_id: Ident = "column_a".into(); + let column_b_id: Ident = "column_b".into(); let table_a: OwnedTable = owned_table([ - bigint(column_a_id, [1, 2, 3, 4]), - varchar(column_b_id, ["Lorem", "ipsum", "dolor", "sit"]), + bigint(column_a_id.value.as_str(), [1, 2, 3, 4]), + varchar( + column_b_id.value.as_str(), + ["Lorem", "ipsum", "dolor", "sit"], + ), ]); - let table_b: OwnedTable = owned_table([scalar(column_a_id, [1, 2])]); + let table_b: OwnedTable = + owned_table([scalar(column_a_id.value.as_str(), [1, 2])]); let table_a_commitment = TableCommitment::::from_owned_table_with_offset(&table_a, 2, &()); @@ -225,7 +234,7 @@ mod tests { assert_eq!( query_commitments.get_commitment(ColumnRef::new( table_a_id, - column_a_id, + column_a_id.clone(), ColumnType::BigInt )), table_a_commitment.column_commitments().commitments()[0] @@ -251,14 +260,18 @@ mod tests { #[allow(clippy::similar_names)] #[test] fn we_can_get_schema_of_tables() { - let column_a_id: Identifier = "column_a".parse().unwrap(); - let column_b_id: Identifier = "column_b".parse().unwrap(); + let column_a_id: Ident = "column_a".into(); + let column_b_id: Ident = "column_b".into(); let table_a: OwnedTable = owned_table([ - bigint(column_a_id, [1, 2, 3, 4]), - varchar(column_b_id, ["Lorem", "ipsum", "dolor", "sit"]), + bigint(column_a_id.value.as_str(), [1, 2, 3, 4]), + varchar( + column_b_id.value.as_str(), + ["Lorem", "ipsum", "dolor", "sit"], + ), ]); - let table_b: OwnedTable = owned_table([scalar(column_a_id, [1, 2])]); + let table_b: OwnedTable = + owned_table([scalar(column_a_id.value.as_str(), [1, 2])]); let table_a_commitment = TableCommitment::::from_owned_table_with_offset(&table_a, 2, &()); @@ -268,7 +281,7 @@ mod tests { let table_b_id = "table.b".parse().unwrap(); let no_columns_commitment = TableCommitment::try_from_columns_with_offset( - Vec::<(&Identifier, &OwnedColumn)>::new(), + Vec::<(&Ident, &OwnedColumn)>::new(), 0, &(), ) @@ -283,27 +296,27 @@ mod tests { assert_eq!( query_commitments - .lookup_column(table_a_id, column_a_id) + .lookup_column(table_a_id, column_a_id.clone()) .unwrap(), ColumnType::BigInt ); assert_eq!( query_commitments - .lookup_column(table_a_id, column_b_id) + .lookup_column(table_a_id, column_b_id.clone()) .unwrap(), ColumnType::VarChar ); assert_eq!( query_commitments.lookup_schema(table_a_id), vec![ - (column_a_id, ColumnType::BigInt), - (column_b_id, ColumnType::VarChar) + (column_a_id.clone(), ColumnType::BigInt), + (column_b_id.clone(), ColumnType::VarChar) ] ); assert_eq!( query_commitments - .lookup_column(table_b_id, column_a_id) + .lookup_column(table_b_id, column_a_id.clone()) .unwrap(), ColumnType::Scalar ); @@ -313,7 +326,7 @@ mod tests { ); assert_eq!( query_commitments.lookup_schema(table_b_id), - vec![(column_a_id, ColumnType::Scalar),] + vec![(column_a_id.clone(), ColumnType::Scalar),] ); assert_eq!( @@ -330,14 +343,20 @@ mod tests { let prover_setup = ProverSetup::from(&public_parameters); let setup = DoryProverPublicSetup::new(&prover_setup, 3); - let column_a_id: Identifier = "column_a".parse().unwrap(); - let column_b_id: Identifier = "column_b".parse().unwrap(); + let column_a_id: Ident = "column_a".into(); + let column_b_id: Ident = "column_b".into(); let table_a = owned_table([ - bigint(column_a_id, [1, 2, 3, 4]), - varchar(column_b_id, ["Lorem", "ipsum", "dolor", "sit"]), + bigint(column_a_id.value.as_str(), [1, 2, 3, 4]), + varchar( + column_b_id.value.as_str(), + ["Lorem", "ipsum", "dolor", "sit"], + ), + ]); + let table_b = owned_table([ + scalar(column_a_id.value.as_str(), [1, 2]), + int128(column_b_id.value.as_str(), [1, 2]), ]); - let table_b = owned_table([scalar(column_a_id, [1, 2]), int128(column_b_id, [1, 2])]); let mut table_a_commitment = TableCommitment::from_owned_table_with_offset(&table_a, 0, &setup); @@ -371,9 +390,9 @@ mod tests { let query_commitments = QueryCommitments::::from_accessor_with_max_bounds( [ - ColumnRef::new(table_a_id, column_a_id, ColumnType::BigInt), + ColumnRef::new(table_a_id, column_a_id.clone(), ColumnType::BigInt), ColumnRef::new(table_b_id, column_a_id, ColumnType::Scalar), - ColumnRef::new(table_a_id, column_b_id, ColumnType::VarChar), + ColumnRef::new(table_a_id, column_b_id.clone(), ColumnType::VarChar), ColumnRef::new(table_b_id, column_b_id, ColumnType::Int128), ], &accessor, diff --git a/crates/proof-of-sql/src/base/commitment/table_commitment.rs b/crates/proof-of-sql/src/base/commitment/table_commitment.rs index 8de64aa1e..dab7beb5d 100644 --- a/crates/proof-of-sql/src/base/commitment/table_commitment.rs +++ b/crates/proof-of-sql/src/base/commitment/table_commitment.rs @@ -8,9 +8,9 @@ use crate::base::{ }; use alloc::vec::Vec; use core::ops::Range; -use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; use snafu::Snafu; +use sqlparser::ast::Ident; /// Cannot create a [`TableCommitment`] with a negative range. #[derive(Debug, Snafu)] @@ -160,18 +160,17 @@ impl TableCommitment { /// /// Provided columns must have the same length and no duplicate identifiers. pub fn try_from_columns_with_offset<'a, COL>( - columns: impl IntoIterator, + columns: impl IntoIterator, offset: usize, setup: &C::PublicSetup<'_>, ) -> Result, TableCommitmentFromColumnsError> where COL: Into>, { - let (identifiers, committable_columns): (Vec<&Identifier>, Vec) = - columns - .into_iter() - .map(|(identifier, column)| (identifier, column.into())) - .unzip(); + let (identifiers, committable_columns): (Vec<&Ident>, Vec) = columns + .into_iter() + .map(|(identifier, column)| (identifier, column.into())) + .unzip(); let num_rows = num_rows_of_columns(&committable_columns)?; @@ -211,17 +210,16 @@ impl TableCommitment { /// Will error on a variety of mismatches, or if the provided columns have mixed length. pub fn try_append_rows<'a, COL>( &mut self, - columns: impl IntoIterator, + columns: impl IntoIterator, setup: &C::PublicSetup<'_>, ) -> Result<(), AppendTableCommitmentError> where COL: Into>, { - let (identifiers, committable_columns): (Vec<&Identifier>, Vec) = - columns - .into_iter() - .map(|(identifier, column)| (identifier, column.into())) - .unzip(); + let (identifiers, committable_columns): (Vec<&Ident>, Vec) = columns + .into_iter() + .map(|(identifier, column)| (identifier, column.into())) + .unzip(); let num_rows = num_rows_of_columns(&committable_columns)?; @@ -269,7 +267,7 @@ impl TableCommitment { /// Columns must have the same length as the current commitment and no duplicate identifiers. pub fn try_extend_columns<'a, COL>( &mut self, - columns: impl IntoIterator, + columns: impl IntoIterator, setup: &C::PublicSetup<'_>, ) -> Result<(), TableCommitmentFromColumnsError> where @@ -277,11 +275,10 @@ impl TableCommitment { { let num_rows = self.range.len(); - let (identifiers, committable_columns): (Vec<&Identifier>, Vec) = - columns - .into_iter() - .map(|(identifier, column)| (identifier, column.into())) - .unzip(); + let (identifiers, committable_columns): (Vec<&Ident>, Vec) = columns + .into_iter() + .map(|(identifier, column)| (identifier, column.into())) + .unzip(); let num_rows_of_new_columns = num_rows_of_columns(&committable_columns)?; if num_rows_of_new_columns != num_rows { @@ -402,8 +399,7 @@ mod tests { #[test] fn we_can_construct_table_commitment_from_columns_and_identifiers() { // no-columns case - let mut empty_columns_iter: IndexMap> = - IndexMap::default(); + let mut empty_columns_iter: IndexMap> = IndexMap::default(); let empty_table_commitment = TableCommitment::::try_from_columns_with_offset( &empty_columns_iter, @@ -420,7 +416,7 @@ mod tests { assert_eq!(empty_table_commitment.num_rows(), 0); // no-rows case - empty_columns_iter.insert("column_a".parse().unwrap(), OwnedColumn::BigInt(vec![])); + empty_columns_iter.insert("column_a".into(), OwnedColumn::BigInt(vec![])); let empty_table_commitment = TableCommitment::::try_from_columns_with_offset( &empty_columns_iter, @@ -467,9 +463,9 @@ mod tests { #[test] fn we_cannot_construct_table_commitment_from_duplicate_identifiers() { - let duplicate_identifier_a = "duplicate_identifier_a".parse().unwrap(); - let duplicate_identifier_b = "duplicate_identifier_b".parse().unwrap(); - let unique_identifier = "unique_identifier".parse().unwrap(); + let duplicate_identifier_a = "duplicate_identifier_a".into(); + let duplicate_identifier_b = "duplicate_identifier_b".into(); + let unique_identifier = "unique_identifier".into(); let empty_column = OwnedColumn::::BigInt(vec![]); @@ -525,9 +521,9 @@ mod tests { #[test] fn we_cannot_construct_table_commitment_from_columns_of_mixed_length() { - let column_id_a = "column_a".parse().unwrap(); - let column_id_b = "column_b".parse().unwrap(); - let column_id_c = "column_c".parse().unwrap(); + let column_id_a = "column_a".into(); + let column_id_b = "column_b".into(); + let column_id_c = "column_c".into(); let one_row_column = OwnedColumn::::BigInt(vec![1]); let two_row_column = OwnedColumn::::BigInt(vec![1, 2]); @@ -580,19 +576,19 @@ mod tests { #[test] fn we_can_append_rows_to_table_commitment() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let initial_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..2].to_vec()), - varchar(varchar_id, varchar_data[..2].to_vec()), - scalar(scalar_id, scalar_data[..2].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..2].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..2].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..2].to_vec()), ]); let mut table_commitment = @@ -605,9 +601,9 @@ mod tests { let mut table_commitment_clone = table_commitment.clone(); let append_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[2..].to_vec()), - varchar(varchar_id, varchar_data[2..].to_vec()), - scalar(scalar_id, scalar_data[2..].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[2..].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[2..].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[2..].to_vec()), ]); table_commitment @@ -615,9 +611,9 @@ mod tests { .unwrap(); let total_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let expected_table_commitment = @@ -668,8 +664,8 @@ mod tests { #[test] fn we_cannot_append_columns_with_duplicate_identifiers_to_table_commitment() { - let column_id_a = "column_a".parse().unwrap(); - let column_id_b = "column_b".parse().unwrap(); + let column_id_a = "column_a".into(); + let column_id_b = "column_b".into(); let column_data = OwnedColumn::::BigInt(vec![1, 2, 3]); @@ -705,11 +701,14 @@ mod tests { #[allow(clippy::similar_names)] #[test] fn we_cannot_append_columns_of_mixed_length_to_table_commitment() { - let column_id_a: Identifier = "column_a".parse().unwrap(); - let column_id_b: Identifier = "column_b".parse().unwrap(); + let column_id_a: Ident = "column_a".into(); + let column_id_b: Ident = "column_b".into(); let base_table: OwnedTable = owned_table([ - bigint(column_id_a, [1, 2, 3, 4]), - varchar(column_id_b, ["Lorem", "ipsum", "dolor", "sit"]), + bigint(column_id_a.value.as_str(), [1, 2, 3, 4]), + varchar( + column_id_b.value.as_str(), + ["Lorem", "ipsum", "dolor", "sit"], + ), ]); let mut table_commitment = @@ -744,18 +743,18 @@ mod tests { #[test] fn we_can_extend_columns_to_table_commitment() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let initial_columns: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), ]); let mut table_commitment = TableCommitment::::try_from_columns_with_offset( @@ -765,15 +764,16 @@ mod tests { ) .unwrap(); - let new_columns = owned_table::([scalar(scalar_id, scalar_data)]); + let new_columns = + owned_table::([scalar(scalar_id.value.as_str(), scalar_data)]); table_commitment .try_extend_columns(new_columns.inner_table(), &()) .unwrap(); let expected_columns = owned_table::([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let expected_table_commitment = TableCommitment::try_from_columns_with_offset(expected_columns.inner_table(), 2, &()) @@ -784,19 +784,19 @@ mod tests { #[test] fn we_can_add_table_commitments() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let columns_a: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..2].to_vec()), - varchar(varchar_id, varchar_data[..2].to_vec()), - scalar(scalar_id, scalar_data[..2].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..2].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..2].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..2].to_vec()), ]); let table_commitment_a = TableCommitment::::try_from_columns_with_offset( @@ -807,17 +807,17 @@ mod tests { .unwrap(); let columns_b: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[2..].to_vec()), - varchar(varchar_id, varchar_data[2..].to_vec()), - scalar(scalar_id, scalar_data[2..].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[2..].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[2..].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[2..].to_vec()), ]); let table_commitment_b = TableCommitment::try_from_columns_with_offset(columns_b.inner_table(), 2, &()).unwrap(); let columns_sum: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let table_commitment_sum = TableCommitment::try_from_columns_with_offset(columns_sum.inner_table(), 0, &()) @@ -929,19 +929,19 @@ mod tests { #[test] fn we_can_sub_table_commitments() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let columns_low: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..2].to_vec()), - varchar(varchar_id, varchar_data[..2].to_vec()), - scalar(scalar_id, scalar_data[..2].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..2].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..2].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..2].to_vec()), ]); let table_commitment_low = TableCommitment::::try_from_columns_with_offset( @@ -952,18 +952,18 @@ mod tests { .unwrap(); let columns_high: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[2..].to_vec()), - varchar(varchar_id, varchar_data[2..].to_vec()), - scalar(scalar_id, scalar_data[2..].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[2..].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[2..].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[2..].to_vec()), ]); let table_commitment_high = TableCommitment::try_from_columns_with_offset(columns_high.inner_table(), 2, &()) .unwrap(); let columns_all: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let table_commitment_all = TableCommitment::try_from_columns_with_offset(columns_all.inner_table(), 0, &()) @@ -1019,25 +1019,25 @@ mod tests { #[test] fn we_cannot_sub_noncontiguous_table_commitments() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let columns_minuend: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..].to_vec()), - varchar(varchar_id, varchar_data[..].to_vec()), - scalar(scalar_id, scalar_data[..].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..].to_vec()), ]); let columns_subtrahend: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..2].to_vec()), - varchar(varchar_id, varchar_data[..2].to_vec()), - scalar(scalar_id, scalar_data[..2].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..2].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..2].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..2].to_vec()), ]); let minuend_table_commitment = @@ -1091,19 +1091,19 @@ mod tests { #[test] fn we_cannot_sub_commitments_with_negative_difference() { - let bigint_id: Identifier = "bigint_column".parse().unwrap(); + let bigint_id: Ident = "bigint_column".into(); let bigint_data = [1i64, 5, -5, 0, 10]; - let varchar_id: Identifier = "varchar_column".parse().unwrap(); + let varchar_id: Ident = "varchar_column".into(); let varchar_data = ["Lorem", "ipsum", "dolor", "sit", "amet"]; - let scalar_id: Identifier = "scalar_column".parse().unwrap(); + let scalar_id: Ident = "scalar_column".into(); let scalar_data = [1000, 2000, 3000, -1000, 0]; let columns_low: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[..2].to_vec()), - varchar(varchar_id, varchar_data[..2].to_vec()), - scalar(scalar_id, scalar_data[..2].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[..2].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[..2].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[..2].to_vec()), ]); let table_commitment_low = TableCommitment::::try_from_columns_with_offset( @@ -1114,18 +1114,18 @@ mod tests { .unwrap(); let columns_high: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data[2..].to_vec()), - varchar(varchar_id, varchar_data[2..].to_vec()), - scalar(scalar_id, scalar_data[2..].to_vec()), + bigint(bigint_id.value.as_str(), bigint_data[2..].to_vec()), + varchar(varchar_id.value.as_str(), varchar_data[2..].to_vec()), + scalar(scalar_id.value.as_str(), scalar_data[2..].to_vec()), ]); let table_commitment_high = TableCommitment::try_from_columns_with_offset(columns_high.inner_table(), 2, &()) .unwrap(); let columns_all: OwnedTable = owned_table([ - bigint(bigint_id, bigint_data), - varchar(varchar_id, varchar_data), - scalar(scalar_id, scalar_data), + bigint(bigint_id.value.as_str(), bigint_data), + varchar(varchar_id.value.as_str(), varchar_data), + scalar(scalar_id.value.as_str(), scalar_data), ]); let table_commitment_all = TableCommitment::try_from_columns_with_offset(columns_all.inner_table(), 0, &()) @@ -1158,12 +1158,9 @@ mod tests { let b_scals = ["1".into(), "2".into(), "3".into()]; let columns = [ + (&"a".into(), &Column::::BigInt(&[1, 2, 3])), ( - &"a".parse().unwrap(), - &Column::::BigInt(&[1, 2, 3]), - ), - ( - &"b".parse().unwrap(), + &"b".into(), &Column::::VarChar((&["1", "2", "3"], &b_scals)), ), ]; @@ -1185,12 +1182,9 @@ mod tests { let b_scals2 = ["4".into(), "5".into(), "6".into()]; let columns2 = [ + (&"a".into(), &Column::::BigInt(&[4, 5, 6])), ( - &"a".parse().unwrap(), - &Column::::BigInt(&[4, 5, 6]), - ), - ( - &"b".parse().unwrap(), + &"b".into(), &Column::::VarChar((&["4", "5", "6"], &b_scals2)), ), ]; diff --git a/crates/proof-of-sql/src/base/database/accessor.rs b/crates/proof-of-sql/src/base/database/accessor.rs index 79f2cd3c9..e2795e428 100644 --- a/crates/proof-of-sql/src/base/database/accessor.rs +++ b/crates/proof-of-sql/src/base/database/accessor.rs @@ -5,7 +5,7 @@ use crate::base::{ scalar::Scalar, }; use alloc::vec::Vec; -use proof_of_sql_parser::Identifier; +use sqlparser::ast::Ident; /// Access metadata of a table span in a database. /// @@ -102,7 +102,7 @@ pub trait DataAccessor: MetadataAccessor { ) } else { Table::::try_from_iter(column_refs.into_iter().map(|column_ref| { - let column = self.get_column(*column_ref); + let column = self.get_column(column_ref.clone()); (column_ref.column_id(), column) })) } @@ -124,7 +124,7 @@ pub trait SchemaAccessor { /// /// Precondition 1: the table must exist and be tamperproof. /// Precondition 2: `table_ref` and `column_id` must always be lowercase. - fn lookup_column(&self, table_ref: TableRef, column_id: Identifier) -> Option; + fn lookup_column(&self, table_ref: TableRef, column_id: Ident) -> Option; /// Lookup all the column names and their data types in the specified table /// @@ -133,5 +133,5 @@ pub trait SchemaAccessor { /// /// Precondition 1: the table must exist and be tamperproof. /// Precondition 2: `table_name` must be lowercase. - fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Identifier, ColumnType)>; + fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Ident, ColumnType)>; } diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index 5cf68d590..25d3b0646 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ b/crates/proof-of-sql/src/base/database/column.rs @@ -11,11 +11,9 @@ use core::{ fmt::{Display, Formatter}, mem::size_of, }; -use proof_of_sql_parser::{ - posql_time::{PoSQLTimeUnit, PoSQLTimeZone}, - Identifier, -}; +use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; use serde::{Deserialize, Serialize}; +use sqlparser::ast::Ident; /// Represents a read-only view of a column in an in-memory, /// column-oriented database. @@ -492,9 +490,9 @@ impl Display for ColumnType { } /// Reference of a SQL column -#[derive(Debug, PartialEq, Eq, Clone, Hash, Copy, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct ColumnRef { - column_id: Identifier, + column_id: Ident, table_ref: TableRef, column_type: ColumnType, } @@ -502,7 +500,7 @@ pub struct ColumnRef { impl ColumnRef { /// Create a new `ColumnRef` from a table, column identifier and column type #[must_use] - pub fn new(table_ref: TableRef, column_id: Identifier, column_type: ColumnType) -> Self { + pub fn new(table_ref: TableRef, column_id: Ident, column_type: ColumnType) -> Self { Self { column_id, table_ref, @@ -518,8 +516,8 @@ impl ColumnRef { /// Returns the column identifier of this column #[must_use] - pub fn column_id(&self) -> Identifier { - self.column_id + pub fn column_id(&self) -> Ident { + self.column_id.clone() } /// Returns the column type of this column @@ -533,23 +531,23 @@ impl ColumnRef { /// of a column in a table. Namely: it's name and type. /// /// This is the analog of a `Field` in Apache Arrow. -#[derive(Debug, PartialEq, Eq, Clone, Hash, Copy, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Eq, Clone, Hash, Serialize, Deserialize)] pub struct ColumnField { - name: Identifier, + name: Ident, data_type: ColumnType, } impl ColumnField { /// Create a new `ColumnField` from a name and a type #[must_use] - pub fn new(name: Identifier, data_type: ColumnType) -> ColumnField { + pub fn new(name: Ident, data_type: ColumnType) -> ColumnField { ColumnField { name, data_type } } /// Returns the name of the column #[must_use] - pub fn name(&self) -> Identifier { - self.name + pub fn name(&self) -> Ident { + self.name.clone() } /// Returns the type of the column diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation.rs b/crates/proof-of-sql/src/base/database/expression_evaluation.rs index c5d2315b0..d9df43097 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation.rs @@ -8,17 +8,14 @@ use crate::base::{ scalar::Scalar, }; use alloc::{format, string::ToString, vec}; -use proof_of_sql_parser::{ - intermediate_ast::{Expression, Literal}, - Identifier, -}; -use sqlparser::ast::{BinaryOperator, UnaryOperator}; +use proof_of_sql_parser::intermediate_ast::{Expression, Literal}; +use sqlparser::ast::{BinaryOperator, Ident, UnaryOperator}; impl OwnedTable { /// Evaluate an expression on the table. pub fn evaluate(&self, expr: &Expression) -> ExpressionEvaluationResult> { match expr { - Expression::Column(identifier) => self.evaluate_column(identifier), + Expression::Column(identifier) => self.evaluate_column(&Ident::from(*identifier)), Expression::Literal(lit) => self.evaluate_literal(lit), Expression::Binary { op, left, right } => { self.evaluate_binary_expr(&(*op).into(), left, right) @@ -30,10 +27,7 @@ impl OwnedTable { } } - fn evaluate_column( - &self, - identifier: &Identifier, - ) -> ExpressionEvaluationResult> { + fn evaluate_column(&self, identifier: &Ident) -> ExpressionEvaluationResult> { Ok(self .inner_table() .get(identifier) diff --git a/crates/proof-of-sql/src/base/database/join_util.rs b/crates/proof-of-sql/src/base/database/join_util.rs index 7cb614a9e..254909a66 100644 --- a/crates/proof-of-sql/src/base/database/join_util.rs +++ b/crates/proof-of-sql/src/base/database/join_util.rs @@ -17,15 +17,15 @@ pub fn cross_join<'a, S: Scalar>( Table::<'a, S>::try_from_iter_with_options( left.inner_table() .iter() - .map(|(&ident, column)| { + .map(|(ident, column)| { ( - ident, + ident.clone(), ColumnRepeatOp::column_op(column, alloc, right_num_rows), ) }) - .chain(right.inner_table().iter().map(|(&ident, column)| { + .chain(right.inner_table().iter().map(|(ident, column)| { ( - ident, + ident.clone(), ElementwiseRepeatOp::column_op(column, alloc, left_num_rows), ) })), @@ -38,26 +38,27 @@ pub fn cross_join<'a, S: Scalar>( mod tests { use super::*; use crate::base::{database::Column, scalar::test_scalar::TestScalar}; + use sqlparser::ast::Ident; #[test] fn we_can_do_cross_joins() { let bump = Bump::new(); - let a = "a".parse().unwrap(); - let b = "b".parse().unwrap(); - let c = "c".parse().unwrap(); - let d = "d".parse().unwrap(); + let a: Ident = "a".into(); + let b: Ident = "b".into(); + let c: Ident = "c".into(); + let d: Ident = "d".into(); let left = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (a, Column::SmallInt(&[1_i16, 2, 3])), - (b, Column::Int(&[4_i32, 5, 6])), + (a.clone(), Column::SmallInt(&[1_i16, 2, 3])), + (b.clone(), Column::Int(&[4_i32, 5, 6])), ], TableOptions::default(), ) .expect("Table creation should not fail"); let right = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (c, Column::BigInt(&[7_i64, 8, 9])), - (d, Column::Int128(&[10_i128, 11, 12])), + (c.clone(), Column::BigInt(&[7_i64, 8, 9])), + (d.clone(), Column::Int128(&[10_i128, 11, 12])), ], TableOptions::default(), ) @@ -86,24 +87,24 @@ mod tests { #[test] fn we_can_do_cross_joins_if_one_table_has_no_rows() { let bump = Bump::new(); - let a = "a".parse().unwrap(); - let b = "b".parse().unwrap(); - let c = "c".parse().unwrap(); - let d = "d".parse().unwrap(); + let a: Ident = "a".into(); + let b: Ident = "b".into(); + let c: Ident = "c".into(); + let d: Ident = "d".into(); // Right table has no rows let left = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (a, Column::SmallInt(&[1_i16, 2, 3])), - (b, Column::Int(&[4_i32, 5, 6])), + (a.clone(), Column::SmallInt(&[1_i16, 2, 3])), + (b.clone(), Column::Int(&[4_i32, 5, 6])), ], TableOptions::default(), ) .expect("Table creation should not fail"); let right = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (c, Column::BigInt(&[0_i64; 0])), - (d, Column::Int128(&[0_i128; 0])), + (c.clone(), Column::BigInt(&[0_i64; 0])), + (d.clone(), Column::Int128(&[0_i128; 0])), ], TableOptions::default(), ) @@ -119,16 +120,16 @@ mod tests { // Left table has no rows let left = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (a, Column::SmallInt(&[0_i16; 0])), - (b, Column::Int(&[0_i32; 0])), + (a.clone(), Column::SmallInt(&[0_i16; 0])), + (b.clone(), Column::Int(&[0_i32; 0])), ], TableOptions::default(), ) .expect("Table creation should not fail"); let right = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (c, Column::BigInt(&[7_i64, 8, 9])), - (d, Column::Int128(&[10_i128, 11, 12])), + (c.clone(), Column::BigInt(&[7_i64, 8, 9])), + (d.clone(), Column::Int128(&[10_i128, 11, 12])), ], TableOptions::default(), ) @@ -144,16 +145,16 @@ mod tests { // Both tables have no rows let left = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (a, Column::SmallInt(&[0_i16; 0])), - (b, Column::Int(&[0_i32; 0])), + (a.clone(), Column::SmallInt(&[0_i16; 0])), + (b.clone(), Column::Int(&[0_i32; 0])), ], TableOptions::default(), ) .expect("Table creation should not fail"); let right = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (c, Column::BigInt(&[0_i64; 0])), - (d, Column::Int128(&[0_i128; 0])), + (c.clone(), Column::BigInt(&[0_i64; 0])), + (d.clone(), Column::Int128(&[0_i128; 0])), ], TableOptions::default(), ) @@ -171,18 +172,18 @@ mod tests { fn we_can_do_cross_joins_if_one_table_has_no_columns() { // Left table has no columns let bump = Bump::new(); - let a = "a".parse().unwrap(); - let b = "b".parse().unwrap(); - let c = "c".parse().unwrap(); - let d = "d".parse().unwrap(); + let a: Ident = "a".into(); + let b: Ident = "b".into(); + let c: Ident = "c".into(); + let d: Ident = "d".into(); let left = Table::<'_, TestScalar>::try_from_iter_with_options(vec![], TableOptions::new(Some(2))) .expect("Table creation should not fail"); let right = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (c, Column::BigInt(&[7_i64, 8])), - (d, Column::Int128(&[10_i128, 11])), + (c.clone(), Column::BigInt(&[7_i64, 8])), + (d.clone(), Column::Int128(&[10_i128, 11])), ], TableOptions::default(), ) @@ -203,8 +204,8 @@ mod tests { // Right table has no columns let left = Table::<'_, TestScalar>::try_from_iter_with_options( vec![ - (a, Column::SmallInt(&[1_i16, 2])), - (b, Column::Int(&[4_i32, 5])), + (a.clone(), Column::SmallInt(&[1_i16, 2])), + (b.clone(), Column::Int(&[4_i32, 5])), ], TableOptions::default(), ) diff --git a/crates/proof-of-sql/src/base/database/owned_table.rs b/crates/proof-of-sql/src/base/database/owned_table.rs index 8e6a7514a..971055219 100644 --- a/crates/proof-of-sql/src/base/database/owned_table.rs +++ b/crates/proof-of-sql/src/base/database/owned_table.rs @@ -5,9 +5,9 @@ use crate::base::{ }; use alloc::{vec, vec::Vec}; use itertools::{EitherOrBoth, Itertools}; -use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; use snafu::Snafu; +use sqlparser::ast::Ident; /// An error that occurs when working with tables. #[derive(Snafu, Debug, PartialEq, Eq)] @@ -37,11 +37,11 @@ pub(crate) enum TableCoercionError { /// This is the analog of an arrow [`RecordBatch`](arrow::record_batch::RecordBatch). #[derive(Debug, Clone, Eq, Serialize, Deserialize)] pub struct OwnedTable { - table: IndexMap>, + table: IndexMap>, } impl OwnedTable { /// Creates a new [`OwnedTable`]. - pub fn try_new(table: IndexMap>) -> Result { + pub fn try_new(table: IndexMap>) -> Result { if table.is_empty() { return Ok(Self { table }); } @@ -53,7 +53,7 @@ impl OwnedTable { } } /// Creates a new [`OwnedTable`]. - pub fn try_from_iter)>>( + pub fn try_from_iter)>>( iter: T, ) -> Result { Self::try_new(IndexMap::from_iter(iter)) @@ -118,16 +118,16 @@ impl OwnedTable { } /// Returns the columns of this table as an `IndexMap` #[must_use] - pub fn into_inner(self) -> IndexMap> { + pub fn into_inner(self) -> IndexMap> { self.table } /// Returns the columns of this table as an `IndexMap` #[must_use] - pub fn inner_table(&self) -> &IndexMap> { + pub fn inner_table(&self) -> &IndexMap> { &self.table } /// Returns the columns of this table as an Iterator - pub fn column_names(&self) -> impl Iterator { + pub fn column_names(&self) -> impl Iterator { self.table.keys() } @@ -158,9 +158,7 @@ impl PartialEq for OwnedTable { impl core::ops::Index<&str> for OwnedTable { type Output = OwnedColumn; fn index(&self, index: &str) -> &Self::Output { - self.table - .get(&index.parse::().unwrap()) - .unwrap() + self.table.get(&Ident::new(index)).unwrap() } } @@ -170,7 +168,7 @@ impl<'a, S: Scalar> From<&Table<'a, S>> for OwnedTable { value .inner_table() .iter() - .map(|(name, column)| (*name, OwnedColumn::from(column))), + .map(|(name, column)| (name.clone(), OwnedColumn::from(column))), ) .expect("Tables should not have columns with differing lengths") } @@ -301,8 +299,8 @@ mod tests { ]); let fields = vec![ - ColumnField::new("bigint".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("scalar".parse().unwrap(), ColumnType::Int), + ColumnField::new("bigint".into(), ColumnType::BigInt), + ColumnField::new("scalar".into(), ColumnType::Int), ]; let coerced_table = table.clone().try_coerce_with_fields(fields).unwrap(); @@ -325,8 +323,8 @@ mod tests { ]); let fields = vec![ - ColumnField::new("bigint".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("mismatch".parse().unwrap(), ColumnType::Int), + ColumnField::new("bigint".into(), ColumnType::BigInt), + ColumnField::new("mismatch".into(), ColumnType::Int), ]; let result = table.clone().try_coerce_with_fields(fields); @@ -343,10 +341,7 @@ mod tests { scalar("scalar", [0, 1, 2, 3, 4, 5, 6, 7, 8]), ]); - let fields = vec![ColumnField::new( - "bigint".parse().unwrap(), - ColumnType::BigInt, - )]; + let fields = vec![ColumnField::new("bigint".into(), ColumnType::BigInt)]; let result = table.clone().try_coerce_with_fields(fields); @@ -366,8 +361,8 @@ mod tests { ]); let fields = vec![ - ColumnField::new("bigint".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("scalar".parse().unwrap(), ColumnType::TinyInt), + ColumnField::new("bigint".into(), ColumnType::BigInt), + ColumnField::new("scalar".into(), ColumnType::TinyInt), ]; let result = table.clone().try_coerce_with_fields(fields); diff --git a/crates/proof-of-sql/src/base/database/owned_table_test.rs b/crates/proof-of-sql/src/base/database/owned_table_test.rs index 2095a8b7c..183e5c870 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_test.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_test.rs @@ -6,11 +6,8 @@ use crate::{ }, proof_primitive::dory::DoryScalar, }; -use proof_of_sql_parser::{ - posql_time::{PoSQLTimeUnit, PoSQLTimeZone}, - Identifier, -}; - +use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use sqlparser::ast::Ident; #[test] fn we_can_create_an_owned_table_with_no_columns() { let table = OwnedTable::::try_new(IndexMap::default()).unwrap(); @@ -26,26 +23,11 @@ fn we_can_create_an_empty_owned_table() { boolean("boolean", [true; 0]), ]); let mut table = IndexMap::default(); - table.insert( - Identifier::try_new("bigint").unwrap(), - OwnedColumn::BigInt(vec![]), - ); - table.insert( - Identifier::try_new("decimal").unwrap(), - OwnedColumn::Int128(vec![]), - ); - table.insert( - Identifier::try_new("varchar").unwrap(), - OwnedColumn::VarChar(vec![]), - ); - table.insert( - Identifier::try_new("scalar").unwrap(), - OwnedColumn::Scalar(vec![]), - ); - table.insert( - Identifier::try_new("boolean").unwrap(), - OwnedColumn::Boolean(vec![]), - ); + table.insert(Ident::new("bigint"), OwnedColumn::BigInt(vec![])); + table.insert(Ident::new("decimal"), OwnedColumn::Int128(vec![])); + table.insert(Ident::new("varchar"), OwnedColumn::VarChar(vec![])); + table.insert(Ident::new("scalar"), OwnedColumn::Scalar(vec![])); + table.insert(Ident::new("boolean"), OwnedColumn::Boolean(vec![])); assert_eq!(owned_table.into_inner(), table); } #[test] @@ -68,7 +50,7 @@ fn we_can_create_an_owned_table_with_data() { ]); let mut table = IndexMap::default(); table.insert( - Identifier::try_new("time_stamp").unwrap(), + Ident::new("time_stamp"), OwnedColumn::TimestampTZ( PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), @@ -76,15 +58,15 @@ fn we_can_create_an_owned_table_with_data() { ), ); table.insert( - Identifier::try_new("bigint").unwrap(), + Ident::new("bigint"), OwnedColumn::BigInt(vec![0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]), ); table.insert( - Identifier::try_new("decimal").unwrap(), + Ident::new("decimal"), OwnedColumn::Int128(vec![0_i128, 1, 2, 3, 4, 5, 6, i128::MIN, i128::MAX]), ); table.insert( - Identifier::try_new("varchar").unwrap(), + Ident::new("varchar"), OwnedColumn::VarChar(vec![ "0".to_string(), "1".to_string(), @@ -98,7 +80,7 @@ fn we_can_create_an_owned_table_with_data() { ]), ); table.insert( - Identifier::try_new("scalar").unwrap(), + Ident::new("scalar"), OwnedColumn::Scalar(vec![ DoryScalar::from(0), 1.into(), @@ -112,7 +94,7 @@ fn we_can_create_an_owned_table_with_data() { ]), ); table.insert( - Identifier::try_new("boolean").unwrap(), + Ident::new("boolean"), OwnedColumn::Boolean(vec![ true, false, true, false, true, false, true, false, true, ]), @@ -179,8 +161,8 @@ fn we_get_inequality_between_tables_with_differing_data() { fn we_cannot_create_an_owned_table_with_differing_column_lengths() { assert!(matches!( OwnedTable::::try_from_iter([ - ("a".parse().unwrap(), OwnedColumn::BigInt(vec![0])), - ("b".parse().unwrap(), OwnedColumn::BigInt(vec![])), + ("a".into(), OwnedColumn::BigInt(vec![0])), + ("b".into(), OwnedColumn::BigInt(vec![])), ]), Err(OwnedTableError::ColumnLengthMismatch) )); diff --git a/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs b/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs index 3217ee9dd..3ce25b5a0 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_test_accessor.rs @@ -8,8 +8,7 @@ use crate::base::{ }; use alloc::{string::String, vec::Vec}; use bumpalo::Bump; -use proof_of_sql_parser::Identifier; - +use sqlparser::ast::Ident; /// A test accessor that uses [`OwnedTable`] as the underlying table type. /// Note: this is intended for testing and examples. It is not optimized for performance, so should not be used for benchmarks or production use-cases. pub struct OwnedTableTestAccessor<'a, CP: CommitmentEvaluationProof> { @@ -61,7 +60,7 @@ impl TestAccessor .unwrap() .0 .column_names() - .map(proof_of_sql_parser::Identifier::as_str) + .map(|ident| ident.value.as_str()) .collect() } @@ -150,7 +149,7 @@ impl MetadataAccessor for OwnedTableTestAccessor< } } impl SchemaAccessor for OwnedTableTestAccessor<'_, CP> { - fn lookup_column(&self, table_ref: TableRef, column_id: Identifier) -> Option { + fn lookup_column(&self, table_ref: TableRef, column_id: Ident) -> Option { Some( self.tables .get(&table_ref)? @@ -164,14 +163,14 @@ impl SchemaAccessor for OwnedTableTestAccessor<'_ /// # Panics /// /// Will panic if the `table_ref` is not found in `self.tables`, indicating that an invalid reference was provided. - fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Identifier, ColumnType)> { + fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Ident, ColumnType)> { self.tables .get(&table_ref) .unwrap() .0 .inner_table() .iter() - .map(|(&id, col)| (id, col.column_type())) + .map(|(id, col)| (id.clone(), col.column_type())) .collect() } } diff --git a/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs b/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs index d0ff8b87a..0eefd4cf3 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs @@ -39,7 +39,7 @@ fn we_can_access_the_columns_of_a_table() { let data1 = owned_table([bigint("a", [1, 2, 3]), bigint("b", [4, 5, 6])]); accessor.add_table(table_ref_1, data1, 0_usize); - let column = ColumnRef::new(table_ref_1, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "b".into(), ColumnType::BigInt); match accessor.get_column(column) { Column::BigInt(col) => assert_eq!(col.to_vec(), vec![4, 5, 6]), _ => panic!("Invalid column type"), @@ -61,19 +61,19 @@ fn we_can_access_the_columns_of_a_table() { ]); accessor.add_table(table_ref_2, data2, 0_usize); - let column = ColumnRef::new(table_ref_1, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "a".into(), ColumnType::BigInt); match accessor.get_column(column) { Column::BigInt(col) => assert_eq!(col.to_vec(), vec![1, 2, 3]), _ => panic!("Invalid column type"), }; - let column = ColumnRef::new(table_ref_2, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_2, "b".into(), ColumnType::BigInt); match accessor.get_column(column) { Column::BigInt(col) => assert_eq!(col.to_vec(), vec![4, 5, 6, 5]), _ => panic!("Invalid column type"), }; - let column = ColumnRef::new(table_ref_2, "c128".parse().unwrap(), ColumnType::Int128); + let column = ColumnRef::new(table_ref_2, "c128".into(), ColumnType::Int128); match accessor.get_column(column) { Column::Int128(col) => assert_eq!(col.to_vec(), vec![1, 2, 3, 4]), _ => panic!("Invalid column type"), @@ -84,7 +84,7 @@ fn we_can_access_the_columns_of_a_table() { .iter() .map(core::convert::Into::into) .collect(); - let column = ColumnRef::new(table_ref_2, "varchar".parse().unwrap(), ColumnType::VarChar); + let column = ColumnRef::new(table_ref_2, "varchar".into(), ColumnType::VarChar); match accessor.get_column(column) { Column::VarChar((col, scals)) => { assert_eq!(col.to_vec(), col_slice); @@ -93,7 +93,7 @@ fn we_can_access_the_columns_of_a_table() { _ => panic!("Invalid column type"), }; - let column = ColumnRef::new(table_ref_2, "scalar".parse().unwrap(), ColumnType::Scalar); + let column = ColumnRef::new(table_ref_2, "scalar".into(), ColumnType::Scalar); match accessor.get_column(column) { Column::Scalar(col) => assert_eq!( col.to_vec(), @@ -107,7 +107,7 @@ fn we_can_access_the_columns_of_a_table() { _ => panic!("Invalid column type"), }; - let column = ColumnRef::new(table_ref_2, "boolean".parse().unwrap(), ColumnType::Boolean); + let column = ColumnRef::new(table_ref_2, "boolean".into(), ColumnType::Boolean); match accessor.get_column(column) { Column::Boolean(col) => assert_eq!(col.to_vec(), vec![true, false, true, false]), _ => panic!("Invalid column type"), @@ -115,7 +115,7 @@ fn we_can_access_the_columns_of_a_table() { let column = ColumnRef::new( table_ref_2, - "time".parse().unwrap(), + "time".into(), ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()), ); match accessor.get_column(column) { @@ -133,7 +133,7 @@ fn we_can_access_the_commitments_of_table_columns() { let data1 = owned_table([bigint("a", [1, 2, 3]), bigint("b", [4, 5, 6])]); accessor.add_table(table_ref_1, data1, 0_usize); - let column = ColumnRef::new(table_ref_1, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "b".into(), ColumnType::BigInt); assert_eq!( accessor.get_commitment(column), NaiveCommitment::compute_commitments( @@ -146,7 +146,7 @@ fn we_can_access_the_commitments_of_table_columns() { let data2 = owned_table([bigint("a", [1, 2, 3, 4]), bigint("b", [4, 5, 6, 5])]); accessor.add_table(table_ref_2, data2, 0_usize); - let column = ColumnRef::new(table_ref_1, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "a".into(), ColumnType::BigInt); assert_eq!( accessor.get_commitment(column), NaiveCommitment::compute_commitments( @@ -156,7 +156,7 @@ fn we_can_access_the_commitments_of_table_columns() { )[0] ); - let column = ColumnRef::new(table_ref_2, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_2, "b".into(), ColumnType::BigInt); assert_eq!( accessor.get_commitment(column), NaiveCommitment::compute_commitments( @@ -176,13 +176,13 @@ fn we_can_access_the_type_of_table_columns() { let data1 = owned_table([bigint("a", [1, 2, 3]), bigint("b", [4, 5, 6])]); accessor.add_table(table_ref_1, data1, 0_usize); - let column = ColumnRef::new(table_ref_1, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "b".into(), ColumnType::BigInt); assert_eq!( accessor.lookup_column(column.table_ref(), column.column_id()), Some(ColumnType::BigInt) ); - let column = ColumnRef::new(table_ref_1, "c".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "c".into(), ColumnType::BigInt); assert!(accessor .lookup_column(column.table_ref(), column.column_id()) .is_none()); @@ -190,19 +190,19 @@ fn we_can_access_the_type_of_table_columns() { let data2 = owned_table([bigint("a", [1, 2, 3, 4]), bigint("b", [4, 5, 6, 5])]); accessor.add_table(table_ref_2, data2, 0_usize); - let column = ColumnRef::new(table_ref_1, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "a".into(), ColumnType::BigInt); assert_eq!( accessor.lookup_column(column.table_ref(), column.column_id()), Some(ColumnType::BigInt) ); - let column = ColumnRef::new(table_ref_2, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_2, "b".into(), ColumnType::BigInt); assert_eq!( accessor.lookup_column(column.table_ref(), column.column_id()), Some(ColumnType::BigInt) ); - let column = ColumnRef::new(table_ref_2, "c".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_2, "c".into(), ColumnType::BigInt); assert!(accessor .lookup_column(column.table_ref(), column.column_id()) .is_none()); @@ -219,8 +219,8 @@ fn we_can_access_schema_and_column_names() { assert_eq!( accessor.lookup_schema(table_ref_1), vec![ - ("a".parse().unwrap(), ColumnType::BigInt), - ("b".parse().unwrap(), ColumnType::VarChar) + ("a".into(), ColumnType::BigInt), + ("b".into(), ColumnType::VarChar) ] ); assert_eq!(accessor.get_column_names(table_ref_1), vec!["a", "b"]); @@ -238,14 +238,14 @@ fn we_can_correctly_update_offsets() { let mut accessor2 = OwnedTableTestAccessor::::new_empty_with_setup(()); accessor2.add_table(table_ref, data, offset); - let column = ColumnRef::new(table_ref, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref, "a".into(), ColumnType::BigInt); assert_ne!( - accessor1.get_commitment(column), + accessor1.get_commitment(column.clone()), accessor2.get_commitment(column) ); - let column = ColumnRef::new(table_ref, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref, "b".into(), ColumnType::BigInt); assert_ne!( - accessor1.get_commitment(column), + accessor1.get_commitment(column.clone()), accessor2.get_commitment(column) ); @@ -254,14 +254,14 @@ fn we_can_correctly_update_offsets() { accessor1.update_offset(table_ref, offset); - let column = ColumnRef::new(table_ref, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref, "a".into(), ColumnType::BigInt); assert_eq!( - accessor1.get_commitment(column), + accessor1.get_commitment(column.clone()), accessor2.get_commitment(column) ); - let column = ColumnRef::new(table_ref, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref, "b".into(), ColumnType::BigInt); assert_eq!( - accessor1.get_commitment(column), + accessor1.get_commitment(column.clone()), accessor2.get_commitment(column) ); diff --git a/crates/proof-of-sql/src/base/database/owned_table_utility.rs b/crates/proof-of-sql/src/base/database/owned_table_utility.rs index 498029734..40f5ebd92 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_utility.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_utility.rs @@ -16,11 +16,8 @@ use super::{OwnedColumn, OwnedTable}; use crate::base::scalar::Scalar; use alloc::string::String; -use core::ops::Deref; -use proof_of_sql_parser::{ - posql_time::{PoSQLTimeUnit, PoSQLTimeZone}, - Identifier, -}; +use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use sqlparser::ast::Ident; /// Creates an [`OwnedTable`] from a list of `(Identifier, OwnedColumn)` pairs. /// This is a convenience wrapper around [`OwnedTable::try_from_iter`] primarily for use in tests and @@ -43,7 +40,7 @@ use proof_of_sql_parser::{ /// # Panics /// - Panics if converting the iterator into an `OwnedTable` fails. pub fn owned_table( - iter: impl IntoIterator)>, + iter: impl IntoIterator)>, ) -> OwnedTable { OwnedTable::try_from_iter(iter).unwrap() } @@ -60,11 +57,11 @@ pub fn owned_table( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn tinyint( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::TinyInt(data.into_iter().map(Into::into).collect()), ) } @@ -81,11 +78,11 @@ pub fn tinyint( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn smallint( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::SmallInt(data.into_iter().map(Into::into).collect()), ) } @@ -102,11 +99,11 @@ pub fn smallint( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn int( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::Int(data.into_iter().map(Into::into).collect()), ) } @@ -122,11 +119,11 @@ pub fn int( /// ``` #[allow(clippy::missing_panics_doc)] pub fn bigint( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::BigInt(data.into_iter().map(Into::into).collect()), ) } @@ -144,11 +141,11 @@ pub fn bigint( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn boolean( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::Boolean(data.into_iter().map(Into::into).collect()), ) } @@ -166,11 +163,11 @@ pub fn boolean( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn int128( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::Int128(data.into_iter().map(Into::into).collect()), ) } @@ -188,11 +185,11 @@ pub fn int128( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn scalar( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::Scalar(data.into_iter().map(Into::into).collect()), ) } @@ -210,11 +207,11 @@ pub fn scalar( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn varchar( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::VarChar(data.into_iter().map(Into::into).collect()), ) } @@ -233,13 +230,13 @@ pub fn varchar( /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. /// - Panics if creating the `Precision` from the specified precision value fails. pub fn decimal75( - name: impl Deref, + name: impl Into, precision: u8, scale: i8, data: impl IntoIterator>, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::Decimal75( crate::base::math::decimal::Precision::new(precision).unwrap(), scale, @@ -273,13 +270,13 @@ pub fn decimal75( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn timestamptz( - name: impl Deref, + name: impl Into, time_unit: PoSQLTimeUnit, timezone: PoSQLTimeZone, data: impl IntoIterator, -) -> (Identifier, OwnedColumn) { +) -> (Ident, OwnedColumn) { ( - name.parse().unwrap(), + name.into(), OwnedColumn::TimestampTZ(time_unit, timezone, data.into_iter().collect()), ) } diff --git a/crates/proof-of-sql/src/base/database/table.rs b/crates/proof-of-sql/src/base/database/table.rs index 5425815e4..c891d2f57 100644 --- a/crates/proof-of-sql/src/base/database/table.rs +++ b/crates/proof-of-sql/src/base/database/table.rs @@ -1,8 +1,8 @@ use super::{Column, ColumnField}; use crate::base::{map::IndexMap, scalar::Scalar}; use alloc::vec::Vec; -use proof_of_sql_parser::Identifier; use snafu::Snafu; +use sqlparser::ast::Ident; /// Options for creating a table. /// Inspired by [`RecordBatchOptions`](https://docs.rs/arrow/latest/arrow/record_batch/struct.RecordBatchOptions.html) @@ -42,18 +42,18 @@ pub enum TableError { /// This is the analog of an arrow [`RecordBatch`](arrow::record_batch::RecordBatch). #[derive(Debug, Clone, Eq)] pub struct Table<'a, S: Scalar> { - table: IndexMap>, + table: IndexMap>, row_count: usize, } impl<'a, S: Scalar> Table<'a, S> { /// Creates a new [`Table`] with the given columns and default [`TableOptions`]. - pub fn try_new(table: IndexMap>) -> Result { + pub fn try_new(table: IndexMap>) -> Result { Self::try_new_with_options(table, TableOptions::default()) } /// Creates a new [`Table`] with the given columns and with [`TableOptions`]. pub fn try_new_with_options( - table: IndexMap>, + table: IndexMap>, options: TableOptions, ) -> Result { match (table.is_empty(), options.row_count) { @@ -78,14 +78,14 @@ impl<'a, S: Scalar> Table<'a, S> { } /// Creates a new [`Table`] from an iterator of `(Identifier, Column)` pairs with default [`TableOptions`]. - pub fn try_from_iter)>>( + pub fn try_from_iter)>>( iter: T, ) -> Result { Self::try_from_iter_with_options(iter, TableOptions::default()) } /// Creates a new [`Table`] from an iterator of `(Identifier, Column)` pairs with [`TableOptions`]. - pub fn try_from_iter_with_options)>>( + pub fn try_from_iter_with_options)>>( iter: T, options: TableOptions, ) -> Result { @@ -109,12 +109,12 @@ impl<'a, S: Scalar> Table<'a, S> { } /// Returns the columns of this table as an `IndexMap` #[must_use] - pub fn into_inner(self) -> IndexMap> { + pub fn into_inner(self) -> IndexMap> { self.table } /// Returns the columns of this table as an `IndexMap` #[must_use] - pub fn inner_table(&self) -> &IndexMap> { + pub fn inner_table(&self) -> &IndexMap> { &self.table } /// Return the schema of this table as a `Vec` of `ColumnField`s @@ -122,11 +122,11 @@ impl<'a, S: Scalar> Table<'a, S> { pub fn schema(&self) -> Vec { self.table .iter() - .map(|(name, column)| ColumnField::new(*name, column.column_type())) + .map(|(name, column)| ColumnField::new(name.clone(), column.column_type())) .collect() } /// Returns the columns of this table as an Iterator - pub fn column_names(&self) -> impl Iterator { + pub fn column_names(&self) -> impl Iterator { self.table.keys() } /// Returns the columns of this table as an Iterator @@ -157,8 +157,6 @@ impl PartialEq for Table<'_, S> { impl<'a, S: Scalar> core::ops::Index<&str> for Table<'a, S> { type Output = Column<'a, S>; fn index(&self, index: &str) -> &Self::Output { - self.table - .get(&index.parse::().unwrap()) - .unwrap() + self.table.get(&Ident::new(index)).unwrap() } } diff --git a/crates/proof-of-sql/src/base/database/table_ref.rs b/crates/proof-of-sql/src/base/database/table_ref.rs index b4f090c48..4ac273a79 100644 --- a/crates/proof-of-sql/src/base/database/table_ref.rs +++ b/crates/proof-of-sql/src/base/database/table_ref.rs @@ -4,7 +4,8 @@ use core::{ fmt::{Display, Formatter}, str::FromStr, }; -use proof_of_sql_parser::{impl_serde_from_str, Identifier, ResourceId}; +use proof_of_sql_parser::{impl_serde_from_str, ResourceId}; +use sqlparser::ast::Ident; /// Expression for an SQL table #[derive(Debug, Clone, PartialEq, Eq, Hash, Copy)] @@ -21,14 +22,14 @@ impl TableRef { /// Returns the identifier of the schema #[must_use] - pub fn schema_id(&self) -> Identifier { - self.resource_id.schema() + pub fn schema_id(&self) -> Ident { + self.resource_id.schema().into() } /// Returns the identifier of the table #[must_use] - pub fn table_id(&self) -> Identifier { - self.resource_id.object_name() + pub fn table_id(&self) -> Ident { + self.resource_id.object_name().into() } /// Returns the underlying resource id of the table diff --git a/crates/proof-of-sql/src/base/database/table_test.rs b/crates/proof-of-sql/src/base/database/table_test.rs index e53f2d954..6f9eaef13 100644 --- a/crates/proof-of-sql/src/base/database/table_test.rs +++ b/crates/proof-of-sql/src/base/database/table_test.rs @@ -4,11 +4,8 @@ use crate::base::{ scalar::test_scalar::TestScalar, }; use bumpalo::Bump; -use proof_of_sql_parser::{ - posql_time::{PoSQLTimeUnit, PoSQLTimeZone}, - Identifier, -}; - +use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use sqlparser::ast::Ident; #[test] fn we_can_create_a_table_with_no_columns_specifying_row_count() { let table = @@ -27,16 +24,16 @@ fn we_can_create_a_table_with_no_columns_specifying_row_count() { #[test] fn we_can_create_a_table_with_default_options() { let table = Table::::try_new(indexmap! { - "a".parse().unwrap() => Column::BigInt(&[0, 1]), - "b".parse().unwrap() => Column::Int128(&[0, 1]), + "a".into() => Column::BigInt(&[0, 1]), + "b".into() => Column::Int128(&[0, 1]), }) .unwrap(); assert_eq!(table.num_columns(), 2); assert_eq!(table.num_rows(), 2); let table = Table::::try_new(indexmap! { - "a".parse().unwrap() => Column::BigInt(&[]), - "b".parse().unwrap() => Column::Int128(&[]), + "a".into() => Column::BigInt(&[]), + "b".into() => Column::Int128(&[]), }) .unwrap(); assert_eq!(table.num_columns(), 2); @@ -47,8 +44,8 @@ fn we_can_create_a_table_with_default_options() { fn we_can_create_a_table_with_specified_row_count() { let table = Table::::try_new_with_options( indexmap! { - "a".parse().unwrap() => Column::BigInt(&[0, 1]), - "b".parse().unwrap() => Column::Int128(&[0, 1]), + "a".into() => Column::BigInt(&[0, 1]), + "b".into() => Column::Int128(&[0, 1]), }, TableOptions::new(Some(2)), ) @@ -58,8 +55,8 @@ fn we_can_create_a_table_with_specified_row_count() { let table = Table::::try_new_with_options( indexmap! { - "a".parse().unwrap() => Column::BigInt(&[]), - "b".parse().unwrap() => Column::Int128(&[]), + "a".into() => Column::BigInt(&[]), + "b".into() => Column::Int128(&[]), }, TableOptions::new(Some(0)), ) @@ -72,8 +69,8 @@ fn we_can_create_a_table_with_specified_row_count() { fn we_cannot_create_a_table_with_differing_column_lengths() { assert!(matches!( Table::::try_from_iter([ - ("a".parse().unwrap(), Column::BigInt(&[0])), - ("b".parse().unwrap(), Column::BigInt(&[])), + ("a".into(), Column::BigInt(&[0])), + ("b".into(), Column::BigInt(&[])), ]), Err(TableError::ColumnLengthMismatch) )); @@ -84,8 +81,8 @@ fn we_cannot_create_a_table_with_column_length_different_from_specified_row_coun assert!(matches!( Table::::try_from_iter_with_options( [ - ("a".parse().unwrap(), Column::BigInt(&[0])), - ("b".parse().unwrap(), Column::BigInt(&[1])), + ("a".into(), Column::BigInt(&[0])), + ("b".into(), Column::BigInt(&[1])), ], TableOptions::new(Some(0)) ), @@ -117,17 +114,11 @@ fn we_can_create_an_empty_table_with_some_columns() { borrowed_boolean("boolean", [true; 0], &alloc), ]); let mut table = IndexMap::default(); - table.insert(Identifier::try_new("bigint").unwrap(), Column::BigInt(&[])); - table.insert(Identifier::try_new("decimal").unwrap(), Column::Int128(&[])); - table.insert( - Identifier::try_new("varchar").unwrap(), - Column::VarChar((&[], &[])), - ); - table.insert(Identifier::try_new("scalar").unwrap(), Column::Scalar(&[])); - table.insert( - Identifier::try_new("boolean").unwrap(), - Column::Boolean(&[]), - ); + table.insert(Ident::new("bigint"), Column::BigInt(&[])); + table.insert(Ident::new("decimal"), Column::Int128(&[])); + table.insert(Ident::new("varchar"), Column::VarChar((&[], &[]))); + table.insert(Ident::new("scalar"), Column::Scalar(&[])); + table.insert(Ident::new("boolean"), Column::Boolean(&[])); assert_eq!(borrowed_table.into_inner(), table); } @@ -170,21 +161,15 @@ fn we_can_create_a_table_with_data() { let time_stamp_data = alloc.alloc_slice_copy(&[0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]); expected_table.insert( - Identifier::try_new("time_stamp").unwrap(), + Ident::new("time_stamp"), Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), time_stamp_data), ); let bigint_data = alloc.alloc_slice_copy(&[0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]); - expected_table.insert( - Identifier::try_new("bigint").unwrap(), - Column::BigInt(bigint_data), - ); + expected_table.insert(Ident::new("bigint"), Column::BigInt(bigint_data)); let decimal_data = alloc.alloc_slice_copy(&[0_i128, 1, 2, 3, 4, 5, 6, i128::MIN, i128::MAX]); - expected_table.insert( - Identifier::try_new("decimal").unwrap(), - Column::Int128(decimal_data), - ); + expected_table.insert(Ident::new("decimal"), Column::Int128(decimal_data)); let varchar_data: Vec<&str> = ["0", "1", "2", "3", "4", "5", "6", "7", "8"] .iter() @@ -194,23 +179,17 @@ fn we_can_create_a_table_with_data() { let varchar_scalars: Vec = varchar_data.iter().map(Into::into).collect(); let varchar_scalars_slice = alloc.alloc_slice_clone(&varchar_scalars); expected_table.insert( - Identifier::try_new("varchar").unwrap(), + Ident::new("varchar"), Column::VarChar((varchar_str_slice, varchar_scalars_slice)), ); let scalar_data: Vec = (0..=8).map(TestScalar::from).collect(); let scalar_slice = alloc.alloc_slice_copy(&scalar_data); - expected_table.insert( - Identifier::try_new("scalar").unwrap(), - Column::Scalar(scalar_slice), - ); + expected_table.insert(Ident::new("scalar"), Column::Scalar(scalar_slice)); let boolean_data = alloc.alloc_slice_copy(&[true, false, true, false, true, false, true, false, true]); - expected_table.insert( - Identifier::try_new("boolean").unwrap(), - Column::Boolean(boolean_data), - ); + expected_table.insert(Ident::new("boolean"), Column::Boolean(boolean_data)); assert_eq!(borrowed_table.into_inner(), expected_table); } diff --git a/crates/proof-of-sql/src/base/database/table_test_accessor.rs b/crates/proof-of-sql/src/base/database/table_test_accessor.rs index 8053fa2d8..8d7e661a4 100644 --- a/crates/proof-of-sql/src/base/database/table_test_accessor.rs +++ b/crates/proof-of-sql/src/base/database/table_test_accessor.rs @@ -7,7 +7,7 @@ use crate::base::{ map::IndexMap, }; use alloc::vec::Vec; -use proof_of_sql_parser::Identifier; +use sqlparser::ast::Ident; /// A test accessor that uses [`Table`] as the underlying table type. /// Note: this is intended for testing and examples. It is not optimized for performance, so should not be used for benchmarks or production use-cases. @@ -55,7 +55,7 @@ impl<'a, CP: CommitmentEvaluationProof> TestAccessor for TableTe .unwrap() .0 .column_names() - .map(proof_of_sql_parser::Identifier::as_str) + .map(|ident| ident.value.as_str()) .collect() } @@ -122,7 +122,7 @@ impl MetadataAccessor for TableTestAccessor<'_, C } } impl SchemaAccessor for TableTestAccessor<'_, CP> { - fn lookup_column(&self, table_ref: TableRef, column_id: Identifier) -> Option { + fn lookup_column(&self, table_ref: TableRef, column_id: Ident) -> Option { Some( self.tables .get(&table_ref)? @@ -136,14 +136,14 @@ impl SchemaAccessor for TableTestAccessor<'_, CP> /// # Panics /// /// Will panic if the `table_ref` is not found in `self.tables`, indicating that an invalid reference was provided. - fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Identifier, ColumnType)> { + fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Ident, ColumnType)> { self.tables .get(&table_ref) .unwrap() .0 .inner_table() .iter() - .map(|(&id, col)| (id, col.column_type())) + .map(|(id, col)| (id.clone(), col.column_type())) .collect() } } diff --git a/crates/proof-of-sql/src/base/database/table_test_accessor_test.rs b/crates/proof-of-sql/src/base/database/table_test_accessor_test.rs index c1341f66c..a594fe151 100644 --- a/crates/proof-of-sql/src/base/database/table_test_accessor_test.rs +++ b/crates/proof-of-sql/src/base/database/table_test_accessor_test.rs @@ -51,7 +51,7 @@ fn we_can_access_the_columns_of_a_table() { ]); accessor.add_table(table_ref_1, data1, 0_usize); - let column = ColumnRef::new(table_ref_1, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "b".into(), ColumnType::BigInt); match accessor.get_column(column) { Column::BigInt(col) => assert_eq!(col.to_vec(), vec![4, 5, 6]), _ => panic!("Invalid column type"), @@ -74,19 +74,19 @@ fn we_can_access_the_columns_of_a_table() { ]); accessor.add_table(table_ref_2, data2, 0_usize); - let column = ColumnRef::new(table_ref_1, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "a".into(), ColumnType::BigInt); match accessor.get_column(column) { Column::BigInt(col) => assert_eq!(col.to_vec(), vec![1, 2, 3]), _ => panic!("Invalid column type"), }; - let column = ColumnRef::new(table_ref_2, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_2, "b".into(), ColumnType::BigInt); match accessor.get_column(column) { Column::BigInt(col) => assert_eq!(col.to_vec(), vec![4, 5, 6, 5]), _ => panic!("Invalid column type"), }; - let column = ColumnRef::new(table_ref_2, "c128".parse().unwrap(), ColumnType::Int128); + let column = ColumnRef::new(table_ref_2, "c128".into(), ColumnType::Int128); match accessor.get_column(column) { Column::Int128(col) => assert_eq!(col.to_vec(), vec![1, 2, 3, 4]), _ => panic!("Invalid column type"), @@ -97,7 +97,7 @@ fn we_can_access_the_columns_of_a_table() { .iter() .map(core::convert::Into::into) .collect(); - let column = ColumnRef::new(table_ref_2, "varchar".parse().unwrap(), ColumnType::VarChar); + let column = ColumnRef::new(table_ref_2, "varchar".into(), ColumnType::VarChar); match accessor.get_column(column) { Column::VarChar((col, scals)) => { assert_eq!(col.to_vec(), col_slice); @@ -106,7 +106,7 @@ fn we_can_access_the_columns_of_a_table() { _ => panic!("Invalid column type"), }; - let column = ColumnRef::new(table_ref_2, "scalar".parse().unwrap(), ColumnType::Scalar); + let column = ColumnRef::new(table_ref_2, "scalar".into(), ColumnType::Scalar); match accessor.get_column(column) { Column::Scalar(col) => assert_eq!( col.to_vec(), @@ -120,7 +120,7 @@ fn we_can_access_the_columns_of_a_table() { _ => panic!("Invalid column type"), }; - let column = ColumnRef::new(table_ref_2, "boolean".parse().unwrap(), ColumnType::Boolean); + let column = ColumnRef::new(table_ref_2, "boolean".into(), ColumnType::Boolean); match accessor.get_column(column) { Column::Boolean(col) => assert_eq!(col.to_vec(), vec![true, false, true, false]), _ => panic!("Invalid column type"), @@ -128,7 +128,7 @@ fn we_can_access_the_columns_of_a_table() { let column = ColumnRef::new( table_ref_2, - "time".parse().unwrap(), + "time".into(), ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()), ); match accessor.get_column(column) { @@ -150,7 +150,7 @@ fn we_can_access_the_commitments_of_table_columns() { ]); accessor.add_table(table_ref_1, data1, 0_usize); - let column = ColumnRef::new(table_ref_1, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "b".into(), ColumnType::BigInt); assert_eq!( accessor.get_commitment(column), NaiveCommitment::compute_commitments( @@ -166,7 +166,7 @@ fn we_can_access_the_commitments_of_table_columns() { ]); accessor.add_table(table_ref_2, data2, 0_usize); - let column = ColumnRef::new(table_ref_1, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "a".into(), ColumnType::BigInt); assert_eq!( accessor.get_commitment(column), NaiveCommitment::compute_commitments( @@ -176,7 +176,7 @@ fn we_can_access_the_commitments_of_table_columns() { )[0] ); - let column = ColumnRef::new(table_ref_2, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_2, "b".into(), ColumnType::BigInt); assert_eq!( accessor.get_commitment(column), NaiveCommitment::compute_commitments( @@ -200,13 +200,13 @@ fn we_can_access_the_type_of_table_columns() { ]); accessor.add_table(table_ref_1, data1, 0_usize); - let column = ColumnRef::new(table_ref_1, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "b".into(), ColumnType::BigInt); assert_eq!( accessor.lookup_column(column.table_ref(), column.column_id()), Some(ColumnType::BigInt) ); - let column = ColumnRef::new(table_ref_1, "c".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "c".into(), ColumnType::BigInt); assert!(accessor .lookup_column(column.table_ref(), column.column_id()) .is_none()); @@ -217,19 +217,19 @@ fn we_can_access_the_type_of_table_columns() { ]); accessor.add_table(table_ref_2, data2, 0_usize); - let column = ColumnRef::new(table_ref_1, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_1, "a".into(), ColumnType::BigInt); assert_eq!( accessor.lookup_column(column.table_ref(), column.column_id()), Some(ColumnType::BigInt) ); - let column = ColumnRef::new(table_ref_2, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_2, "b".into(), ColumnType::BigInt); assert_eq!( accessor.lookup_column(column.table_ref(), column.column_id()), Some(ColumnType::BigInt) ); - let column = ColumnRef::new(table_ref_2, "c".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref_2, "c".into(), ColumnType::BigInt); assert!(accessor .lookup_column(column.table_ref(), column.column_id()) .is_none()); @@ -250,8 +250,8 @@ fn we_can_access_schema_and_column_names() { assert_eq!( accessor.lookup_schema(table_ref_1), vec![ - ("a".parse().unwrap(), ColumnType::BigInt), - ("b".parse().unwrap(), ColumnType::VarChar) + ("a".into(), ColumnType::BigInt), + ("b".into(), ColumnType::VarChar) ] ); assert_eq!(accessor.get_column_names(table_ref_1), vec!["a", "b"]); @@ -273,14 +273,14 @@ fn we_can_correctly_update_offsets() { let mut accessor2 = TableTestAccessor::::new_empty_with_setup(()); accessor2.add_table(table_ref, data, offset); - let column = ColumnRef::new(table_ref, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref, "a".into(), ColumnType::BigInt); assert_ne!( - accessor1.get_commitment(column), + accessor1.get_commitment(column.clone()), accessor2.get_commitment(column) ); - let column = ColumnRef::new(table_ref, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref, "b".into(), ColumnType::BigInt); assert_ne!( - accessor1.get_commitment(column), + accessor1.get_commitment(column.clone()), accessor2.get_commitment(column) ); @@ -289,14 +289,14 @@ fn we_can_correctly_update_offsets() { accessor1.update_offset(table_ref, offset); - let column = ColumnRef::new(table_ref, "a".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref, "a".into(), ColumnType::BigInt); assert_eq!( - accessor1.get_commitment(column), + accessor1.get_commitment(column.clone()), accessor2.get_commitment(column) ); - let column = ColumnRef::new(table_ref, "b".parse().unwrap(), ColumnType::BigInt); + let column = ColumnRef::new(table_ref, "b".into(), ColumnType::BigInt); assert_eq!( - accessor1.get_commitment(column), + accessor1.get_commitment(column.clone()), accessor2.get_commitment(column) ); diff --git a/crates/proof-of-sql/src/base/database/table_utility.rs b/crates/proof-of-sql/src/base/database/table_utility.rs index 8c54028b4..445d7e2bd 100644 --- a/crates/proof-of-sql/src/base/database/table_utility.rs +++ b/crates/proof-of-sql/src/base/database/table_utility.rs @@ -19,11 +19,8 @@ use super::{Column, Table, TableOptions}; use crate::base::scalar::Scalar; use alloc::{string::String, vec::Vec}; use bumpalo::Bump; -use core::ops::Deref; -use proof_of_sql_parser::{ - posql_time::{PoSQLTimeUnit, PoSQLTimeZone}, - Identifier, -}; +use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use sqlparser::ast::Ident; /// Creates an [`Table`] from a list of `(Identifier, Column)` pairs. /// This is a convenience wrapper around [`Table::try_from_iter`] primarily for use in tests and @@ -49,7 +46,7 @@ use proof_of_sql_parser::{ /// # Panics /// - Panics if converting the iterator into an `Table<'a, S>` fails. pub fn table<'a, S: Scalar>( - iter: impl IntoIterator)>, + iter: impl IntoIterator)>, ) -> Table<'a, S> { Table::try_from_iter(iter).unwrap() } @@ -61,7 +58,7 @@ pub fn table<'a, S: Scalar>( /// # Panics /// - Panics if the given row count doesn't match the number of rows in any of the columns. pub fn table_with_row_count<'a, S: Scalar>( - iter: impl IntoIterator)>, + iter: impl IntoIterator)>, row_count: usize, ) -> Table<'a, S> { Table::try_from_iter_with_options(iter, TableOptions::new(Some(row_count))).unwrap() @@ -79,15 +76,15 @@ pub fn table_with_row_count<'a, S: Scalar>( /// ]); ///``` /// # Panics -/// - Panics if `name.parse()` fails to convert the name into an `Identifier`. +/// - Panics if `name.parse()()` fails to convert the name into an `Identifier`. pub fn borrowed_tinyint( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let transformed_data: Vec = data.into_iter().map(Into::into).collect(); let alloc_data = alloc.alloc_slice_copy(&transformed_data); - (name.parse().unwrap(), Column::TinyInt(alloc_data)) + (name.into(), Column::TinyInt(alloc_data)) } /// Creates a `(Identifier, Column)` pair for a smallint column. @@ -104,15 +101,15 @@ pub fn borrowed_tinyint( /// ``` /// /// # Panics -/// - Panics if `name.parse()` fails to convert the name into an `Identifier`. +/// - Panics if `name.parse()()` fails to convert the name into an `Identifier`. pub fn borrowed_smallint( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let transformed_data: Vec = data.into_iter().map(Into::into).collect(); let alloc_data = alloc.alloc_slice_copy(&transformed_data); - (name.parse().unwrap(), Column::SmallInt(alloc_data)) + (name.into(), Column::SmallInt(alloc_data)) } /// Creates a `(Identifier, Column)` pair for an int column. @@ -129,15 +126,15 @@ pub fn borrowed_smallint( /// ``` /// /// # Panics -/// - Panics if `name.parse()` fails to convert the name into an `Identifier`. +/// - Panics if `name.parse()()` fails to convert the name into an `Identifier`. pub fn borrowed_int( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let transformed_data: Vec = data.into_iter().map(Into::into).collect(); let alloc_data = alloc.alloc_slice_copy(&transformed_data); - (name.parse().unwrap(), Column::Int(alloc_data)) + (name.into(), Column::Int(alloc_data)) } /// Creates a `(Identifier, Column)` pair for a bigint column. @@ -156,13 +153,13 @@ pub fn borrowed_int( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn borrowed_bigint( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let transformed_data: Vec = data.into_iter().map(Into::into).collect(); let alloc_data = alloc.alloc_slice_copy(&transformed_data); - (name.parse().unwrap(), Column::BigInt(alloc_data)) + (name.into(), Column::BigInt(alloc_data)) } /// Creates a `(Identifier, Column)` pair for a boolean column. @@ -181,13 +178,13 @@ pub fn borrowed_bigint( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn borrowed_boolean( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let transformed_data: Vec = data.into_iter().map(Into::into).collect(); let alloc_data = alloc.alloc_slice_copy(&transformed_data); - (name.parse().unwrap(), Column::Boolean(alloc_data)) + (name.into(), Column::Boolean(alloc_data)) } /// Creates a `(Identifier, Column)` pair for an int128 column. @@ -206,13 +203,13 @@ pub fn borrowed_boolean( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn borrowed_int128( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let transformed_data: Vec = data.into_iter().map(Into::into).collect(); let alloc_data = alloc.alloc_slice_copy(&transformed_data); - (name.parse().unwrap(), Column::Int128(alloc_data)) + (name.into(), Column::Int128(alloc_data)) } /// Creates a `(Identifier, Column)` pair for a scalar column. @@ -231,13 +228,13 @@ pub fn borrowed_int128( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn borrowed_scalar( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let transformed_data: Vec = data.into_iter().map(Into::into).collect(); let alloc_data = alloc.alloc_slice_copy(&transformed_data); - (name.parse().unwrap(), Column::Scalar(alloc_data)) + (name.into(), Column::Scalar(alloc_data)) } /// Creates a `(Identifier, Column)` pair for a varchar column. @@ -255,10 +252,10 @@ pub fn borrowed_scalar( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn borrowed_varchar<'a, S: Scalar>( - name: impl Deref, + name: impl Into, data: impl IntoIterator>, alloc: &'a Bump, -) -> (Identifier, Column<'a, S>) { +) -> (Ident, Column<'a, S>) { let strings: Vec<&'a str> = data .into_iter() .map(|item| { @@ -269,10 +266,7 @@ pub fn borrowed_varchar<'a, S: Scalar>( let alloc_strings = alloc.alloc_slice_clone(&strings); let scalars: Vec = strings.iter().map(|s| (*s).into()).collect(); let alloc_scalars = alloc.alloc_slice_copy(&scalars); - ( - name.parse().unwrap(), - Column::VarChar((alloc_strings, alloc_scalars)), - ) + (name.into(), Column::VarChar((alloc_strings, alloc_scalars))) } /// Creates a `(Identifier, Column)` pair for a decimal75 column. @@ -291,16 +285,16 @@ pub fn borrowed_varchar<'a, S: Scalar>( /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. /// - Panics if creating the `Precision` from the specified precision value fails. pub fn borrowed_decimal75( - name: impl Deref, + name: impl Into, precision: u8, scale: i8, data: impl IntoIterator>, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let transformed_data: Vec = data.into_iter().map(Into::into).collect(); let alloc_data = alloc.alloc_slice_copy(&transformed_data); ( - name.parse().unwrap(), + name.into(), Column::Decimal75( crate::base::math::decimal::Precision::new(precision).unwrap(), scale, @@ -337,16 +331,16 @@ pub fn borrowed_decimal75( /// # Panics /// - Panics if `name.parse()` fails to convert the name into an `Identifier`. pub fn borrowed_timestamptz( - name: impl Deref, + name: impl Into, time_unit: PoSQLTimeUnit, timezone: PoSQLTimeZone, data: impl IntoIterator, alloc: &Bump, -) -> (Identifier, Column<'_, S>) { +) -> (Ident, Column<'_, S>) { let vec_data: Vec = data.into_iter().collect(); let alloc_data = alloc.alloc_slice_copy(&vec_data); ( - name.parse().unwrap(), + name.into(), Column::TimestampTZ(time_unit, timezone, alloc_data), ) } diff --git a/crates/proof-of-sql/src/base/database/test_schema_accessor.rs b/crates/proof-of-sql/src/base/database/test_schema_accessor.rs index 924e5292c..044f32d43 100644 --- a/crates/proof-of-sql/src/base/database/test_schema_accessor.rs +++ b/crates/proof-of-sql/src/base/database/test_schema_accessor.rs @@ -1,30 +1,29 @@ use super::{ColumnType, SchemaAccessor, TableRef}; use crate::base::map::IndexMap; -use proof_of_sql_parser::Identifier; - +use sqlparser::ast::Ident; /// A simple in-memory `SchemaAccessor` for testing intermediate AST -> Provable AST conversion. pub struct TestSchemaAccessor { - schemas: IndexMap>, + schemas: IndexMap>, } impl TestSchemaAccessor { /// Create a new `TestSchemaAccessor` with the given schema. - pub fn new(schemas: IndexMap>) -> Self { + pub fn new(schemas: IndexMap>) -> Self { Self { schemas } } } impl SchemaAccessor for TestSchemaAccessor { - fn lookup_column(&self, table_ref: TableRef, column_id: Identifier) -> Option { + fn lookup_column(&self, table_ref: TableRef, column_id: Ident) -> Option { self.schemas.get(&table_ref)?.get(&column_id).copied() } - fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Identifier, ColumnType)> { + fn lookup_schema(&self, table_ref: TableRef) -> Vec<(Ident, ColumnType)> { self.schemas .get(&table_ref) .unwrap_or(&IndexMap::default()) .iter() - .map(|(id, col)| (*id, *col)) + .map(|(id, col)| (id.clone(), *col)) .collect() } } @@ -39,11 +38,11 @@ mod tests { let table2: TableRef = TableRef::new("schema.table2".parse().unwrap()); TestSchemaAccessor::new(indexmap! { table1 => indexmap! { - "col1".parse().unwrap() => ColumnType::BigInt, - "col2".parse().unwrap() => ColumnType::VarChar, + "col1".into() => ColumnType::BigInt, + "col2".into() => ColumnType::VarChar, }, table2 => indexmap! { - "col1".parse().unwrap() => ColumnType::BigInt, + "col1".into() => ColumnType::BigInt, }, }) } @@ -55,35 +54,23 @@ mod tests { let table2: TableRef = TableRef::new("schema.table2".parse().unwrap()); let not_a_table: TableRef = TableRef::new("schema.not_a_table".parse().unwrap()); assert_eq!( - accessor.lookup_column(table1, "col1".parse().unwrap()), + accessor.lookup_column(table1, "col1".into()), Some(ColumnType::BigInt) ); assert_eq!( - accessor.lookup_column(table1, "col2".parse().unwrap()), + accessor.lookup_column(table1, "col2".into()), Some(ColumnType::VarChar) ); + assert_eq!(accessor.lookup_column(table1, "not_a_col".into()), None); assert_eq!( - accessor.lookup_column(table1, "not_a_col".parse().unwrap()), - None - ); - assert_eq!( - accessor.lookup_column(table2, "col1".parse().unwrap()), + accessor.lookup_column(table2, "col1".into()), Some(ColumnType::BigInt) ); + assert_eq!(accessor.lookup_column(table2, "col2".into()), None); + assert_eq!(accessor.lookup_column(not_a_table, "col1".into()), None); + assert_eq!(accessor.lookup_column(not_a_table, "col2".into()), None); assert_eq!( - accessor.lookup_column(table2, "col2".parse().unwrap()), - None - ); - assert_eq!( - accessor.lookup_column(not_a_table, "col1".parse().unwrap()), - None - ); - assert_eq!( - accessor.lookup_column(not_a_table, "col2".parse().unwrap()), - None - ); - assert_eq!( - accessor.lookup_column(not_a_table, "not_a_col".parse().unwrap()), + accessor.lookup_column(not_a_table, "not_a_col".into()), None ); } @@ -97,13 +84,13 @@ mod tests { assert_eq!( accessor.lookup_schema(table1), vec![ - ("col1".parse().unwrap(), ColumnType::BigInt), - ("col2".parse().unwrap(), ColumnType::VarChar), + ("col1".into(), ColumnType::BigInt), + ("col2".into(), ColumnType::VarChar), ] ); assert_eq!( accessor.lookup_schema(table2), - vec![("col1".parse().unwrap(), ColumnType::BigInt),] + vec![("col1".into(), ColumnType::BigInt),] ); assert_eq!(accessor.lookup_schema(not_a_table), vec![]); } diff --git a/crates/proof-of-sql/src/base/database/union_util.rs b/crates/proof-of-sql/src/base/database/union_util.rs index 9b37d1ace..14537ec1b 100644 --- a/crates/proof-of-sql/src/base/database/union_util.rs +++ b/crates/proof-of-sql/src/base/database/union_util.rs @@ -307,16 +307,16 @@ mod tests { // Column names don't matter let table0 = Table::<'_, TestScalar>::try_new_with_options( IndexMap::from_iter(vec![ - ("a".parse().unwrap(), Column::BigInt(&[1, 2, 3])), - ("b".parse().unwrap(), Column::BigInt(&[4, 5, 6])), + ("a".into(), Column::BigInt(&[1, 2, 3])), + ("b".into(), Column::BigInt(&[4, 5, 6])), ]), TableOptions::new(Some(3)), ) .unwrap(); let table1 = Table::<'_, TestScalar>::try_new_with_options( IndexMap::from_iter(vec![ - ("c".parse().unwrap(), Column::BigInt(&[7, 8, 9])), - ("d".parse().unwrap(), Column::BigInt(&[10, 11, 12])), + ("c".into(), Column::BigInt(&[7, 8, 9])), + ("d".into(), Column::BigInt(&[10, 11, 12])), ]), TableOptions::new(Some(3)), ) @@ -325,8 +325,8 @@ mod tests { &[table0, table1], &alloc, vec![ - ColumnField::new("e".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("f".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("e".into(), ColumnType::BigInt), + ColumnField::new("f".into(), ColumnType::BigInt), ], ) .unwrap(); @@ -334,8 +334,8 @@ mod tests { result, Table::<'_, TestScalar>::try_new_with_options( IndexMap::from_iter(vec![ - ("e".parse().unwrap(), Column::BigInt(&[1, 2, 3, 7, 8, 9])), - ("f".parse().unwrap(), Column::BigInt(&[4, 5, 6, 10, 11, 12])), + ("e".into(), Column::BigInt(&[1, 2, 3, 7, 8, 9])), + ("f".into(), Column::BigInt(&[4, 5, 6, 10, 11, 12])), ]), TableOptions::new(Some(6)), ) @@ -350,16 +350,16 @@ mod tests { // regardless of whether the tables have the same schema let table0 = Table::<'_, TestScalar>::try_new_with_options( IndexMap::from_iter(vec![ - ("a".parse().unwrap(), Column::BigInt(&[1, 2, 3])), - ("b".parse().unwrap(), Column::BigInt(&[4, 5, 6])), + ("a".into(), Column::BigInt(&[1, 2, 3])), + ("b".into(), Column::BigInt(&[4, 5, 6])), ]), TableOptions::new(Some(3)), ) .unwrap(); let table1 = Table::<'_, TestScalar>::try_new_with_options( IndexMap::from_iter(vec![ - ("c".parse().unwrap(), Column::BigInt(&[7, 8, 9])), - ("d".parse().unwrap(), Column::BigInt(&[10, 11, 12])), + ("c".into(), Column::BigInt(&[7, 8, 9])), + ("d".into(), Column::BigInt(&[10, 11, 12])), ]), TableOptions::new(Some(3)), ) @@ -368,8 +368,8 @@ mod tests { &[table0, table1], &alloc, vec![ - ColumnField::new("e".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("f".parse().unwrap(), ColumnType::Int), + ColumnField::new("e".into(), ColumnType::BigInt), + ColumnField::new("f".into(), ColumnType::Int), ], ); assert!(matches!( diff --git a/crates/proof-of-sql/src/base/mod.rs b/crates/proof-of-sql/src/base/mod.rs index 657b855d1..e421f4c53 100644 --- a/crates/proof-of-sql/src/base/mod.rs +++ b/crates/proof-of-sql/src/base/mod.rs @@ -19,6 +19,7 @@ mod serialize; pub(crate) use serialize::{impl_serde_for_ark_serde_checked, impl_serde_for_ark_serde_unchecked}; pub(crate) mod map; pub(crate) mod slice_ops; +pub(crate) mod sqlparser; mod rayon_cfg; pub(crate) use rayon_cfg::if_rayon; diff --git a/crates/proof-of-sql/src/base/sqlparser.rs b/crates/proof-of-sql/src/base/sqlparser.rs new file mode 100644 index 000000000..df545217c --- /dev/null +++ b/crates/proof-of-sql/src/base/sqlparser.rs @@ -0,0 +1,5 @@ +/// Construct an `Ident` from a string. +#[cfg(test)] +pub(crate) fn ident(name: &str) -> sqlparser::ast::Ident { + name.into() +} diff --git a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_expr.rs b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_expr.rs index 067112a79..b7f04d738 100644 --- a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_expr.rs +++ b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_expr.rs @@ -80,14 +80,14 @@ mod tests { fn we_can_serialize_a_column_expr() { let table_ref: TableRef = "namespace.table".parse().unwrap(); let column_0_ref: ColumnRef = - ColumnRef::new(table_ref, "column_0".parse().unwrap(), ColumnType::BigInt); + ColumnRef::new(table_ref, "column_0".into(), ColumnType::BigInt); let column_1_ref: ColumnRef = - ColumnRef::new(table_ref, "column_1".parse().unwrap(), ColumnType::BigInt); + ColumnRef::new(table_ref, "column_1".into(), ColumnType::BigInt); let column_2_ref: ColumnRef = - ColumnRef::new(table_ref, "column_2".parse().unwrap(), ColumnType::BigInt); + ColumnRef::new(table_ref, "column_2".into(), ColumnType::BigInt); let serializer = DynProofPlanSerializer::::try_new( indexset! {}, - indexset! { column_0_ref, column_1_ref }, + indexset! { column_0_ref.clone(), column_1_ref.clone() }, ) .unwrap(); @@ -178,10 +178,12 @@ mod tests { fn we_can_serialize_an_equals_expr() { let table_ref: TableRef = "namespace.table".parse().unwrap(); let column_0_ref: ColumnRef = - ColumnRef::new(table_ref, "column_0".parse().unwrap(), ColumnType::BigInt); - let serializer = - DynProofPlanSerializer::::try_new(indexset! {}, indexset! { column_0_ref }) - .unwrap(); + ColumnRef::new(table_ref, "column_0".into(), ColumnType::BigInt); + let serializer = DynProofPlanSerializer::::try_new( + indexset! {}, + indexset! { column_0_ref.clone() }, + ) + .unwrap(); let lhs = DynProofExpr::Column(ColumnExpr::new(column_0_ref)); let rhs = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); @@ -229,10 +231,12 @@ mod tests { fn we_cannot_serialize_an_unsupported_expr() { let table_ref: TableRef = "namespace.table".parse().unwrap(); let column_0_ref: ColumnRef = - ColumnRef::new(table_ref, "column_0".parse().unwrap(), ColumnType::BigInt); - let serializer = - DynProofPlanSerializer::::try_new(indexset! {}, indexset! { column_0_ref }) - .unwrap(); + ColumnRef::new(table_ref, "column_0".into(), ColumnType::BigInt); + let serializer = DynProofPlanSerializer::::try_new( + indexset! {}, + indexset! { column_0_ref.clone() }, + ) + .unwrap(); let lhs = DynProofExpr::Column(ColumnExpr::new(column_0_ref)); let rhs = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); diff --git a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_plan.rs b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_plan.rs index cfb418e83..a6b45db24 100644 --- a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_plan.rs +++ b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_plan.rs @@ -89,7 +89,7 @@ mod tests { let aliased_expr = AliasedDynProofExpr { expr: expr.clone(), - alias: "alias".parse().unwrap(), + alias: "alias".into(), }; let bytes = serializer .clone() @@ -155,11 +155,11 @@ mod tests { let expr_c = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); let aliased_expr_0 = AliasedDynProofExpr { expr: expr_a.clone(), - alias: "alias_0".parse().unwrap(), + alias: "alias_0".into(), }; let aliased_expr_1 = AliasedDynProofExpr { expr: expr_b.clone(), - alias: "alias_1".parse().unwrap(), + alias: "alias_1".into(), }; let table_expr = TableExpr { table_ref }; diff --git a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serializer.rs b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serializer.rs index c6110c3ec..4b55c4b98 100644 --- a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serializer.rs +++ b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serializer.rs @@ -87,12 +87,12 @@ mod tests { let table_ref_1: TableRef = "namespace.table1".parse().unwrap(); let table_ref_2: TableRef = "namespace.table2".parse().unwrap(); let column_ref_1: ColumnRef = - ColumnRef::new(table_ref_1, "column1".parse().unwrap(), ColumnType::BigInt); + ColumnRef::new(table_ref_1, "column1".into(), ColumnType::BigInt); let column_ref_2: ColumnRef = - ColumnRef::new(table_ref_2, "column2".parse().unwrap(), ColumnType::BigInt); + ColumnRef::new(table_ref_2, "column2".into(), ColumnType::BigInt); let table_refs = indexset! { table_ref_1, table_ref_2 }; - let column_refs = indexset! { column_ref_1, column_ref_2 }; + let column_refs = indexset! { column_ref_1.clone(), column_ref_2.clone() }; let serializer = DynProofPlanSerializer::::try_new(table_refs, column_refs).unwrap(); assert_eq!( @@ -125,7 +125,7 @@ mod tests { .map(|i| { ColumnRef::new( table_ref, - format!("column{i}").parse().unwrap(), + format!("column{i}").as_str().into(), ColumnType::BigInt, ) }) diff --git a/crates/proof-of-sql/src/evm_compatibility/serialize_query_expr.rs b/crates/proof-of-sql/src/evm_compatibility/serialize_query_expr.rs index cddac175c..5a97ed2d2 100644 --- a/crates/proof-of-sql/src/evm_compatibility/serialize_query_expr.rs +++ b/crates/proof-of-sql/src/evm_compatibility/serialize_query_expr.rs @@ -70,7 +70,7 @@ mod tests { #[test] fn we_can_generate_serialized_proof_plan_for_query_expr() { let table_ref = "namespace.table".parse().unwrap(); - let identifier_alias = "alias".parse().unwrap(); + let identifier_alias = "alias".into(); let plan = DynProofPlan::Filter(FilterExec::new( vec![AliasedDynProofExpr { @@ -121,9 +121,9 @@ mod tests { #[test] fn we_can_generate_serialized_proof_plan_for_simple_filter() { let table_ref = "namespace.table".parse().unwrap(); - let identifier_a = "a".parse().unwrap(); - let identifier_b = "b".parse().unwrap(); - let identifier_alias = "alias".parse().unwrap(); + let identifier_a = "a".into(); + let identifier_b = "b".into(); + let identifier_alias = "alias".into(); let column_ref_a = ColumnRef::new(table_ref, identifier_a, ColumnType::BigInt); let column_ref_b = ColumnRef::new(table_ref, identifier_b, ColumnType::BigInt); diff --git a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs index 6bf6e2b16..a0e76a031 100644 --- a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs @@ -21,27 +21,26 @@ use alloc::{borrow::ToOwned, boxed::Box, format, string::ToString}; use proof_of_sql_parser::{ intermediate_ast::{AggregationOperator, Expression, Literal}, posql_time::{PoSQLTimeUnit, PoSQLTimestampError}, - Identifier, }; -use sqlparser::ast::{BinaryOperator, UnaryOperator}; +use sqlparser::ast::{BinaryOperator, Ident, UnaryOperator}; /// Builder that enables building a `proofs::sql::proof_exprs::DynProofExpr` from /// a `proof_of_sql_parser::intermediate_ast::Expression`. pub struct DynProofExprBuilder<'a> { - column_mapping: &'a IndexMap, + column_mapping: &'a IndexMap, in_agg_scope: bool, } impl<'a> DynProofExprBuilder<'a> { /// Creates a new `DynProofExprBuilder` with the given column mapping. - pub fn new(column_mapping: &'a IndexMap) -> Self { + pub fn new(column_mapping: &'a IndexMap) -> Self { Self { column_mapping, in_agg_scope: false, } } /// Creates a new `DynProofExprBuilder` with the given column mapping and within aggregation scope. - pub(crate) fn new_agg(column_mapping: &'a IndexMap) -> Self { + pub(crate) fn new_agg(column_mapping: &'a IndexMap) -> Self { Self { column_mapping, in_agg_scope: true, @@ -58,7 +57,7 @@ impl<'a> DynProofExprBuilder<'a> { impl DynProofExprBuilder<'_> { fn visit_expr(&self, expr: &Expression) -> Result { match expr { - Expression::Column(identifier) => self.visit_column(*identifier), + Expression::Column(identifier) => self.visit_column((*identifier).into()), Expression::Literal(lit) => self.visit_literal(lit), Expression::Binary { op, left, right } => { self.visit_binary_expr(&(*op).into(), left, right) @@ -71,13 +70,14 @@ impl DynProofExprBuilder<'_> { } } - fn visit_column(&self, identifier: Identifier) -> Result { + fn visit_column(&self, identifier: Ident) -> Result { Ok(DynProofExpr::Column(ColumnExpr::new( - *self.column_mapping.get(&identifier).ok_or( - ConversionError::MissingColumnWithoutTable { + self.column_mapping + .get(&identifier) + .ok_or(ConversionError::MissingColumnWithoutTable { identifier: Box::new(identifier), - }, - )?, + })? + .clone(), ))) } diff --git a/crates/proof-of-sql/src/sql/parse/enriched_expr.rs b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs index 43938ac4e..1772cbf3b 100644 --- a/crates/proof-of-sql/src/sql/parse/enriched_expr.rs +++ b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs @@ -4,10 +4,8 @@ use crate::{ sql::proof_exprs::DynProofExpr, }; use alloc::boxed::Box; -use proof_of_sql_parser::{ - intermediate_ast::{AliasedResultExpr, Expression}, - Identifier, -}; +use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, Expression}; +use sqlparser::ast::Ident; /// Enriched expression /// /// An enriched expression consists of an `proof_of_sql_parser::intermediate_ast::AliasedResultExpr` @@ -26,10 +24,7 @@ impl EnrichedExpr { /// If the expression is not provable, the `dyn_proof_expr` will be `None`. /// Otherwise the `dyn_proof_expr` will contain the provable expression plan /// and the `residue_expression` will contain the remaining expression. - pub fn new( - expression: AliasedResultExpr, - column_mapping: &IndexMap, - ) -> Self { + pub fn new(expression: AliasedResultExpr, column_mapping: &IndexMap) -> Self { // TODO: Using new_agg (ironically) disables aggregations in `QueryExpr` for now. // Re-enable aggregations when we add `GroupByExec` generalizations. let res_dyn_proof_expr = @@ -56,8 +51,10 @@ impl EnrichedExpr { /// /// Since we plan to support unaliased expressions in the future, this method returns an `Option`. #[allow(dead_code)] - pub fn get_alias(&self) -> Option<&Identifier> { - self.residue_expression.try_as_identifier() + pub fn get_alias(&self) -> Option { + self.residue_expression + .try_as_identifier() + .map(|identifier| Ident::new(identifier.as_str())) } /// Is the `EnrichedExpr` provable diff --git a/crates/proof-of-sql/src/sql/parse/error.rs b/crates/proof-of-sql/src/sql/parse/error.rs index 908749c2d..6e738ee5e 100644 --- a/crates/proof-of-sql/src/sql/parse/error.rs +++ b/crates/proof-of-sql/src/sql/parse/error.rs @@ -8,8 +8,9 @@ use alloc::{ string::{String, ToString}, }; use core::result::Result; -use proof_of_sql_parser::{posql_time::PoSQLTimestampError, Identifier, ResourceId}; +use proof_of_sql_parser::{posql_time::PoSQLTimestampError, ResourceId}; use snafu::Snafu; +use sqlparser::ast::Ident; /// Errors from converting an intermediate AST into a provable AST. #[derive(Snafu, Debug, PartialEq, Eq)] @@ -18,7 +19,7 @@ pub enum ConversionError { /// The column is missing in the table MissingColumn { /// The missing column identifier - identifier: Box, + identifier: Box, /// The table resource id resource_id: Box, }, @@ -27,7 +28,7 @@ pub enum ConversionError { /// The column is missing (without table information) MissingColumnWithoutTable { /// The missing column identifier - identifier: Box, + identifier: Box, }, #[snafu(display("Expected '{expected}' but found '{actual}'"))] @@ -146,6 +147,12 @@ pub enum ConversionError { /// The operator that is unsupported message: String, }, + /// Errors in converting `Ident` to `Identifier` + #[snafu(display("Failed to convert `Ident` to `Identifier`: {error}"))] + IdentifierConversionError { + /// The underlying error message + error: String, + }, } impl From for ConversionError { diff --git a/crates/proof-of-sql/src/sql/parse/filter_exec_builder.rs b/crates/proof-of-sql/src/sql/parse/filter_exec_builder.rs index 4d59681db..5512d051c 100644 --- a/crates/proof-of-sql/src/sql/parse/filter_exec_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/filter_exec_builder.rs @@ -11,18 +11,18 @@ use crate::{ }; use alloc::{boxed::Box, vec, vec::Vec}; use itertools::Itertools; -use proof_of_sql_parser::{intermediate_ast::Expression, Identifier}; - +use proof_of_sql_parser::intermediate_ast::Expression; +use sqlparser::ast::Ident; pub struct FilterExecBuilder { table_expr: Option, where_expr: Option, filter_result_expr_list: Vec, - column_mapping: IndexMap, + column_mapping: IndexMap, } // Public interface impl FilterExecBuilder { - pub fn new(column_mapping: IndexMap) -> Self { + pub fn new(column_mapping: IndexMap) -> Self { Self { table_expr: None, where_expr: None, @@ -56,7 +56,7 @@ impl FilterExecBuilder { if let Some(plan) = &enriched_expr.dyn_proof_expr { self.filter_result_expr_list.push(AliasedDynProofExpr { expr: plan.clone(), - alias: enriched_expr.residue_expression.alias, + alias: enriched_expr.residue_expression.alias.into(), }); } else { has_nonprovable_column = true; @@ -68,8 +68,8 @@ impl FilterExecBuilder { for alias in self.column_mapping.keys().sorted() { let column_ref = self.column_mapping.get(alias).unwrap(); self.filter_result_expr_list.push(AliasedDynProofExpr { - expr: DynProofExpr::new_column(*column_ref), - alias: *alias, + expr: DynProofExpr::new_column(column_ref.clone()), + alias: alias.clone(), }); } } diff --git a/crates/proof-of-sql/src/sql/parse/query_context.rs b/crates/proof-of-sql/src/sql/parse/query_context.rs index 9d87661f0..3c1b7c551 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context.rs @@ -10,10 +10,10 @@ use crate::{ }, }; use alloc::{borrow::ToOwned, boxed::Box, string::ToString, vec::Vec}; -use proof_of_sql_parser::{ - intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression, OrderBy, Slice}, - Identifier, +use proof_of_sql_parser::intermediate_ast::{ + AggregationOperator, AliasedResultExpr, Expression, OrderBy, Slice, }; +use sqlparser::ast::Ident; #[derive(Default, Debug)] pub struct QueryContext { @@ -25,12 +25,12 @@ pub struct QueryContext { in_result_scope: bool, has_visited_group_by: bool, order_by_exprs: Vec, - group_by_exprs: Vec, + group_by_exprs: Vec, where_expr: Option>, - result_column_set: IndexSet, + result_column_set: IndexSet, res_aliased_exprs: Vec, - column_mapping: IndexMap, - first_result_col_out_agg_scope: Option, + column_mapping: IndexMap, + first_result_col_out_agg_scope: Option, } impl QueryContext { @@ -100,15 +100,15 @@ impl QueryContext { self.agg_counter > 0 || !self.group_by_exprs.is_empty() } - pub fn push_column_ref(&mut self, column: Identifier, column_ref: ColumnRef) { + pub fn push_column_ref(&mut self, column: Ident, column_ref: ColumnRef) { self.col_ref_counter += 1; - self.push_result_column_ref(column); + self.push_result_column_ref(column.clone()); self.column_mapping.insert(column, column_ref); } - fn push_result_column_ref(&mut self, column: Identifier) { + fn push_result_column_ref(&mut self, column: Ident) { if self.is_in_result_scope() { - self.result_column_set.insert(column); + self.result_column_set.insert(column.clone()); if !self.is_in_agg_scope() && self.first_result_col_out_agg_scope.is_none() { self.first_result_col_out_agg_scope = Some(column); @@ -124,13 +124,13 @@ impl QueryContext { Ok(()) } - pub fn set_group_by_exprs(&mut self, exprs: Vec) { + pub fn set_group_by_exprs(&mut self, exprs: Vec) { self.group_by_exprs = exprs; // Add the group by columns to the result column set // to ensure their integrity in the filter expression. for group_column in &self.group_by_exprs { - self.result_column_set.insert(*group_column); + self.result_column_set.insert(group_column.clone()); } self.has_visited_group_by = true; @@ -140,7 +140,7 @@ impl QueryContext { self.order_by_exprs = order_by_exprs; } - pub fn is_in_group_by_exprs(&self, column: &Identifier) -> ConversionResult { + pub fn is_in_group_by_exprs(&self, column: &Ident) -> ConversionResult { // Non-aggregated result column references must be included in the group by statement. if self.group_by_exprs.is_empty() || self.is_in_agg_scope() || !self.is_in_result_scope() { return Ok(false); @@ -184,7 +184,11 @@ impl QueryContext { && self.first_result_col_out_agg_scope.is_some() { return Err(ConversionError::InvalidGroupByColumnRef { - column: self.first_result_col_out_agg_scope.unwrap().to_string(), + column: self + .first_result_col_out_agg_scope + .as_ref() + .unwrap() + .to_string(), }); } @@ -209,15 +213,15 @@ impl QueryContext { &self.slice_expr } - pub fn get_group_by_exprs(&self) -> &[Identifier] { + pub fn get_group_by_exprs(&self) -> &[Ident] { &self.group_by_exprs } - pub fn get_result_column_set(&self) -> IndexSet { + pub fn get_result_column_set(&self) -> IndexSet { self.result_column_set.clone() } - pub fn get_column_mapping(&self) -> IndexMap { + pub fn get_column_mapping(&self) -> IndexMap { self.column_mapping.clone() } } @@ -247,10 +251,10 @@ impl TryFrom<&QueryContext> for Option { .column_mapping .get(expr) .ok_or(ConversionError::MissingColumn { - identifier: Box::new(*expr), + identifier: Box::new((expr).clone()), resource_id: Box::new(resource_id), }) - .map(|column_ref| ColumnExpr::new(*column_ref)) + .map(|column_ref| ColumnExpr::new(column_ref.clone())) }) .collect::, ConversionError>>()?; // For a query to be provable the result columns must be of one of three kinds below: @@ -272,7 +276,7 @@ impl TryFrom<&QueryContext> for Option { .zip(res_group_by_columns.iter()) .all(|(ident, res)| { if let Expression::Column(res_ident) = *res.expr { - res_ident == *ident + Ident::from(res_ident) == *ident } else { false } @@ -292,7 +296,7 @@ impl TryFrom<&QueryContext> for Option { res_dyn_proof_expr .ok() .map(|dyn_proof_expr| AliasedDynProofExpr { - alias: res.alias, + alias: res.alias.into(), expr: dyn_proof_expr, }) } else { @@ -317,7 +321,7 @@ impl TryFrom<&QueryContext> for Option { Ok(Some(GroupByExec::new( group_by_exprs, sum_expr.expect("the none case was just checked"), - count_column.alias, + count_column.alias.into(), table, where_clause, ))) diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 45c33caa1..708cb8236 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -22,6 +22,7 @@ pub struct QueryContextBuilder<'a> { context: QueryContext, schema_accessor: &'a dyn SchemaAccessor, } +use sqlparser::ast::Ident; // Public interface impl<'a> QueryContextBuilder<'a> { @@ -41,10 +42,9 @@ impl<'a> QueryContextBuilder<'a> { assert_eq!(table_expr.len(), 1); match *table_expr[0] { TableExpression::Named { table, schema } => { - self.context.set_table_ref(TableRef::new(ResourceId::new( - schema.unwrap_or(default_schema), - table, - ))); + let schema_identifier = schema.unwrap_or(default_schema); + self.context + .set_table_ref(TableRef::new(ResourceId::new(schema_identifier, table))); } } self @@ -87,12 +87,9 @@ impl<'a> QueryContextBuilder<'a> { self } - pub fn visit_group_by_exprs( - mut self, - group_by_exprs: Vec, - ) -> ConversionResult { + pub fn visit_group_by_exprs(mut self, group_by_exprs: Vec) -> ConversionResult { for id in &group_by_exprs { - self.visit_column_identifier(*id)?; + self.visit_column_identifier(id)?; } self.context.set_group_by_exprs(group_by_exprs); Ok(self) @@ -110,7 +107,7 @@ impl<'a> QueryContextBuilder<'a> { clippy::missing_panics_doc, reason = "The assertion ensures there is at least one column, and this is a fundamental requirement for schema retrieval." )] - fn lookup_schema(&self) -> Vec<(Identifier, ColumnType)> { + fn lookup_schema(&self) -> Vec<(Ident, ColumnType)> { let table_ref = self.context.get_table_ref(); let columns = self.schema_accessor.lookup_schema(*table_ref); assert!(!columns.is_empty(), "At least one column must exist"); @@ -119,8 +116,13 @@ impl<'a> QueryContextBuilder<'a> { fn visit_select_all_expr(&mut self) -> ConversionResult<()> { for (column_name, _) in self.lookup_schema() { - let col_expr = Expression::Column(column_name); - self.visit_aliased_expr(AliasedResultExpr::new(col_expr, column_name))?; + let column_identifier = Identifier::try_from(column_name).map_err(|e| { + ConversionError::IdentifierConversionError { + error: format!("Failed to convert Ident to Identifier: {e}"), + } + })?; + let col_expr = Expression::Column(column_identifier); + self.visit_aliased_expr(AliasedResultExpr::new(col_expr, column_identifier))?; } Ok(()) } @@ -153,7 +155,7 @@ impl<'a> QueryContextBuilder<'a> { _ => panic!("Must be a column expression"), }; - self.visit_column_identifier(identifier) + self.visit_column_identifier(&identifier.into()) } fn visit_binary_expr( @@ -255,18 +257,20 @@ impl<'a> QueryContextBuilder<'a> { } } - fn visit_column_identifier(&mut self, column_name: Identifier) -> ConversionResult { + fn visit_column_identifier(&mut self, column_name: &Ident) -> ConversionResult { let table_ref = self.context.get_table_ref(); - let column_type = self.schema_accessor.lookup_column(*table_ref, column_name); + let column_type = self + .schema_accessor + .lookup_column(*table_ref, column_name.clone()); let column_type = column_type.ok_or_else(|| ConversionError::MissingColumn { - identifier: Box::new(column_name), + identifier: Box::new(column_name.clone()), resource_id: Box::new(table_ref.resource_id()), })?; - let column = ColumnRef::new(*table_ref, column_name, column_type); + let column = ColumnRef::new(*table_ref, column_name.clone(), column_type); - self.context.push_column_ref(column_name, column); + self.context.push_column_ref(column_name.clone(), column); Ok(column_type) } diff --git a/crates/proof-of-sql/src/sql/parse/query_expr.rs b/crates/proof-of-sql/src/sql/parse/query_expr.rs index 1b5b733c8..798406c45 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr.rs @@ -5,14 +5,15 @@ use crate::{ parse::ConversionResult, postprocessing::{ GroupByPostprocessing, OrderByPostprocessing, OwnedTablePostprocessing, - SelectPostprocessing, SlicePostprocessing, + PostprocessingError, SelectPostprocessing, SlicePostprocessing, }, proof_plans::{DynProofPlan, GroupByExec}, }, }; -use alloc::{fmt, vec, vec::Vec}; +use alloc::{fmt, format, vec, vec::Vec}; use proof_of_sql_parser::{intermediate_ast::SetExpression, Identifier, SelectStatement}; use serde::{Deserialize, Serialize}; +use sqlparser::ast::Ident; #[derive(PartialEq, Serialize, Deserialize)] /// A `QueryExpr` represents a Proof of SQL query that can be executed against a database. @@ -34,6 +35,12 @@ impl fmt::Debug for QueryExpr { } } +pub fn convert_ident_to_identifier(ident: Ident) -> Result { + Identifier::try_from(ident).map_err(|e| PostprocessingError::IdentifierConversionError { + error: format!("Failed to convert Ident to Identifier: {e}"), + }) +} + impl QueryExpr { /// Creates a new `QueryExpr` with the given `DynProofPlan` and `OwnedTablePostprocessing`. #[must_use] @@ -47,7 +54,7 @@ impl QueryExpr { /// Parse an intermediate AST `SelectStatement` into a `QueryExpr`. pub fn try_new( ast: SelectStatement, - default_schema: Identifier, + default_schema: Ident, schema_accessor: &dyn SchemaAccessor, ) -> ConversionResult { let context = match *ast.expr { @@ -57,8 +64,8 @@ impl QueryExpr { where_expr, group_by, } => QueryContextBuilder::new(schema_accessor) - .visit_table_expr(&from, default_schema) - .visit_group_by_exprs(group_by)? + .visit_table_expr(&from, convert_ident_to_identifier(default_schema)?) + .visit_group_by_exprs(group_by.into_iter().map(Ident::from).collect())? .visit_result_exprs(result_exprs)? .visit_where_expr(where_expr)? .visit_order_by_exprs(ast.order_by) @@ -67,7 +74,6 @@ impl QueryExpr { }; let result_aliased_exprs = context.get_aliased_result_exprs()?.to_vec(); let group_by = context.get_group_by_exprs(); - // Figure out the basic postprocessing steps. let mut postprocessing = vec![]; let order_bys = context.get_order_by_exprs()?; diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index 0695e9d3e..177e925d6 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -19,8 +19,8 @@ use proof_of_sql_parser::{ add as padd, aliased_expr, col, count, count_all, lit, max, min, mul as pmul, sub as psub, sum, }, - Identifier, }; +use sqlparser::ast::Ident; /// # Panics /// @@ -40,7 +40,7 @@ fn invalid_query_to_provable_ast(table: TableRef, query: &str, accessor: &TestSc #[cfg(test)] pub fn schema_accessor_from_table_ref_with_schema( table: TableRef, - schema: IndexMap, + schema: IndexMap, ) -> TestSchemaAccessor { TestSchemaAccessor::new(indexmap! {table => schema}) } @@ -51,7 +51,7 @@ fn we_can_convert_an_ast_with_one_column() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 3", &accessor); @@ -72,7 +72,7 @@ fn we_can_convert_an_ast_with_one_column_and_i128_data() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::Int128, + "a".into() => ColumnType::Int128, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 3", &accessor); @@ -93,7 +93,7 @@ fn we_can_convert_an_ast_with_one_column_and_a_filter_by_a_string_literal() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::VarChar, + "a".into() => ColumnType::VarChar, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab where a = 'abc'", &accessor); @@ -114,8 +114,8 @@ fn we_cannot_convert_an_ast_with_duplicate_aliases() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, }, ); invalid_query_to_provable_ast( @@ -138,7 +138,7 @@ fn we_dont_have_duplicate_filter_result_expressions() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -163,9 +163,9 @@ fn we_can_convert_an_ast_with_two_columns() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, - "c".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "c".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a, b from sxt_tab where c = 123", &accessor); @@ -186,9 +186,9 @@ fn we_can_convert_an_ast_with_two_columns_and_arithmetic() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, - "c".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "c".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -219,8 +219,8 @@ fn we_can_parse_all_result_columns_with_select_star() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "b".parse().unwrap() => ColumnType::BigInt, - "a".parse().unwrap() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select * from sxt_tab where a = 3", &accessor); @@ -241,8 +241,8 @@ fn we_can_convert_an_ast_with_one_positive_cond() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab where b = +4", &accessor); @@ -263,8 +263,8 @@ fn we_can_convert_an_ast_with_one_not_equals_cond() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab where b <> +4", &accessor); @@ -285,8 +285,8 @@ fn we_can_convert_an_ast_with_one_negative_cond() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab where b <= -4", &accessor); @@ -307,9 +307,9 @@ fn we_can_convert_an_ast_with_cond_and() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, - "c".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "c".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -337,9 +337,9 @@ fn we_can_convert_an_ast_with_cond_or() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, - "c".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "c".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -370,9 +370,9 @@ fn we_can_convert_an_ast_with_conds_or_not() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, - "c".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "c".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -400,10 +400,10 @@ fn we_can_convert_an_ast_with_conds_not_and_or() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, - "c".parse().unwrap() => ColumnType::BigInt, - "f".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "c".into() => ColumnType::BigInt, + "f".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -443,7 +443,7 @@ fn we_can_convert_an_ast_with_the_min_i128_filter_value_and_const() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -471,7 +471,7 @@ fn we_can_convert_an_ast_with_the_max_i128_filter_value_and_const() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -499,8 +499,8 @@ fn we_can_convert_an_ast_using_an_aliased_column() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "b".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -531,7 +531,7 @@ fn we_cannot_convert_an_ast_with_a_nonexistent_column() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "b".parse().unwrap() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, }, ); invalid_query_to_provable_ast(t, "select * from sxt_tab where a = 3", &accessor); @@ -543,7 +543,7 @@ fn we_cannot_convert_an_ast_with_a_column_type_different_than_equal_literal() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "b".parse().unwrap() => ColumnType::VarChar, + "b".into() => ColumnType::VarChar, }, ); invalid_query_to_provable_ast(t, "select * from sxt_tab where b = 123", &accessor); @@ -555,7 +555,7 @@ fn we_can_convert_an_ast_with_a_schema() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from eth.sxt_tab where a = 3", &accessor); @@ -576,7 +576,7 @@ fn we_can_convert_an_ast_without_any_filter() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let expected_ast = QueryExpr::new( @@ -609,8 +609,8 @@ fn we_can_parse_order_by_with_a_single_column() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "b".parse().unwrap() => ColumnType::BigInt, - "a".parse().unwrap() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select * from sxt_tab where a = 3 order by b", &accessor); @@ -631,8 +631,8 @@ fn we_can_parse_order_by_with_multiple_columns() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "b".parse().unwrap() => ColumnType::BigInt, - "a".parse().unwrap() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -661,8 +661,8 @@ fn we_can_parse_order_by_referencing_an_alias_associated_with_column_b_but_with_ let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "name".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "name".into() => ColumnType::VarChar, }, ); let ast = query_to_provable_ast( @@ -690,7 +690,7 @@ fn we_cannot_parse_order_by_referencing_a_column_name_instead_of_an_alias() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, }, ); invalid_query_to_provable_ast( @@ -706,8 +706,8 @@ fn we_cannot_parse_order_by_referencing_invalid_aliased_expressions() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "b".parse().unwrap() => ColumnType::BigInt, - "a".parse().unwrap() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); // Note: While this operation is acceptable with PostgreSQL, we do not currently support it. @@ -723,8 +723,8 @@ fn we_cannot_parse_order_by_referencing_an_alias_name_associated_with_two_differ let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "name".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "name".into() => ColumnType::VarChar, }, ); invalid_query_to_provable_ast( @@ -762,8 +762,8 @@ fn we_can_parse_order_by_queries_with_the_same_column_name_appearing_more_than_o let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "name".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "name".into() => ColumnType::VarChar, }, ); for order_by in ["s", "d"] { @@ -798,7 +798,7 @@ fn we_can_parse_a_query_having_a_simple_limit_clause() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab limit 3", &accessor); @@ -819,7 +819,7 @@ fn slice_is_still_applied_when_limit_is_u64_max_and_offset_is_zero() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab offset 0", &accessor); @@ -840,7 +840,7 @@ fn we_can_parse_a_query_having_a_simple_positive_offset_clause() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab offset 7", &accessor); @@ -861,7 +861,7 @@ fn we_can_parse_a_query_having_a_negative_offset_clause() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab offset -7", &accessor); @@ -882,7 +882,7 @@ fn we_can_parse_a_query_having_a_simple_limit_and_offset_clause() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, + "a".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast(t, "select a from sxt_tab limit 55 offset 3", &accessor); @@ -907,8 +907,8 @@ fn we_can_parse_a_query_having_a_simple_limit_and_offset_clause_preceded_by_wher let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "boolean".parse().unwrap() => ColumnType::Boolean, + "a".into() => ColumnType::BigInt, + "boolean".into() => ColumnType::Boolean, }, ); let ast = query_to_provable_ast( @@ -945,8 +945,8 @@ fn we_can_do_provable_group_by() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -973,8 +973,8 @@ fn we_can_do_provable_group_by_without_sum() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -1001,9 +1001,9 @@ fn we_can_do_provable_group_by_with_two_group_by_columns() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "state".parse().unwrap() => ColumnType::VarChar, - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "state".into() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -1030,9 +1030,9 @@ fn we_can_do_provable_group_by_with_two_sums_and_filter() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "tax".parse().unwrap() => ColumnType::BigInt, - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "tax".into() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -1065,8 +1065,8 @@ fn we_can_group_by_without_using_aggregate_functions() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -1106,9 +1106,9 @@ fn group_by_expressions_are_parsed_before_an_order_by_referencing_an_aggregate_a let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "department_budget".parse().unwrap() => ColumnType::BigInt, - "salary".parse().unwrap() => ColumnType::BigInt, - "tax".parse().unwrap() => ColumnType::BigInt, + "department_budget".into() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "tax".into() => ColumnType::BigInt, }, ); @@ -1146,8 +1146,8 @@ fn we_cannot_parse_non_aggregated_or_non_group_by_columns_in_the_select_clause() let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); invalid_query_to_provable_ast( @@ -1163,8 +1163,8 @@ fn alias_references_are_not_allowed_in_the_group_by() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); invalid_query_to_provable_ast( @@ -1185,8 +1185,8 @@ fn order_by_cannot_reference_an_invalid_group_by_column() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); invalid_query_to_provable_ast( @@ -1207,8 +1207,8 @@ fn group_by_column_cannot_be_a_column_result_alias() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, }, ); invalid_query_to_provable_ast( @@ -1225,7 +1225,7 @@ fn we_can_have_aggregate_functions_without_a_group_by_clause() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "name".parse().unwrap() => ColumnType::VarChar, + "name".into() => ColumnType::VarChar, }, ); @@ -1245,9 +1245,9 @@ fn we_can_parse_a_query_having_group_by_with_the_same_name_as_the_aggregation_ex let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, - "bonus".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, + "bonus".into() => ColumnType::VarChar, }, ); let ast = query_to_provable_ast( @@ -1275,9 +1275,9 @@ fn count_aggregate_functions_can_be_used_with_non_numeric_columns() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, - "bonus".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, + "bonus".into() => ColumnType::VarChar, }, ); let ast = query_to_provable_ast( @@ -1309,9 +1309,9 @@ fn count_all_uses_the_first_group_by_identifier_as_default_result_column() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, - "bonus".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, + "bonus".into() => ColumnType::VarChar, }, ); let ast = query_to_provable_ast( @@ -1339,9 +1339,9 @@ fn aggregate_result_columns_cannot_reference_invalid_columns() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, - "bonus".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, + "bonus".into() => ColumnType::VarChar, }, ); invalid_query_to_provable_ast( @@ -1357,9 +1357,9 @@ fn we_can_use_the_same_result_columns_with_different_aliases_and_associate_it_wi let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, - "bonus".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, + "bonus".into() => ColumnType::VarChar, }, ); let ast = query_to_provable_ast( @@ -1390,10 +1390,10 @@ fn we_can_use_multiple_group_by_clauses_with_multiple_agg_and_non_agg_exprs() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "bonus".parse().unwrap() => ColumnType::BigInt, - "name".parse().unwrap() => ColumnType::VarChar, - "salary".parse().unwrap() => ColumnType::BigInt, - "tax".parse().unwrap() => ColumnType::BigInt, + "bonus".into() => ColumnType::BigInt, + "name".into() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "tax".into() => ColumnType::BigInt, }, ); let query_text = "select salary d1, max(tax), salary d2, sum(bonus) sum_bonus, count(name) count_s from sxt.employees group by salary, bonus, salary"; @@ -1427,10 +1427,10 @@ fn we_can_parse_a_simple_add_mul_sub_div_arithmetic_expressions_in_the_result_ex let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "a".parse().unwrap() => ColumnType::BigInt, - "f".parse().unwrap() => ColumnType::Int128, - "b".parse().unwrap() => ColumnType::BigInt, - "h".parse().unwrap() => ColumnType::Int128, + "a".into() => ColumnType::BigInt, + "f".into() => ColumnType::Int128, + "b".into() => ColumnType::BigInt, + "h".into() => ColumnType::Int128, }, ); // TODO: add `a / b as a_div_b` result expr once polars properly @@ -1472,10 +1472,10 @@ fn we_can_parse_multiple_arithmetic_expression_where_multiplication_has_preceden let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "c".parse().unwrap() => ColumnType::BigInt, - "f".parse().unwrap() => ColumnType::BigInt, - "g".parse().unwrap() => ColumnType::BigInt, - "h".parse().unwrap() => ColumnType::BigInt, + "c".into() => ColumnType::BigInt, + "f".into() => ColumnType::BigInt, + "g".into() => ColumnType::BigInt, + "h".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -1527,10 +1527,10 @@ fn we_can_parse_arithmetic_expression_within_aggregations_in_the_result_expr() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "c".parse().unwrap() => ColumnType::BigInt, - "f".parse().unwrap() => ColumnType::BigInt, - "g".parse().unwrap() => ColumnType::BigInt, - "k".parse().unwrap() => ColumnType::BigInt, + "c".into() => ColumnType::BigInt, + "f".into() => ColumnType::BigInt, + "g".into() => ColumnType::BigInt, + "k".into() => ColumnType::BigInt, }, ); let ast = query_to_provable_ast( @@ -1564,8 +1564,8 @@ fn we_cannot_use_non_grouped_columns_outside_agg() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "name".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "name".into() => ColumnType::VarChar, }, ); let identifier_not_in_agg_queries = vec![ @@ -1610,8 +1610,8 @@ fn varchar_column_is_not_compatible_with_integer_column() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "name".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "name".into() => ColumnType::VarChar, }, ); @@ -1658,8 +1658,8 @@ fn arithmetic_operations_are_not_allowed_with_varchar_column() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "name".parse().unwrap() => ColumnType::VarChar, - "position".parse().unwrap() => ColumnType::VarChar, + "name".into() => ColumnType::VarChar, + "position".into() => ColumnType::VarChar, }, ); @@ -1682,7 +1682,7 @@ fn varchar_column_is_not_allowed_within_numeric_aggregations() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "name".parse().unwrap() => ColumnType::VarChar, + "name".into() => ColumnType::VarChar, }, ); let sum_query = "select sum(name) from sxt.employees"; @@ -1722,7 +1722,7 @@ fn group_by_with_bigint_column_is_valid() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, }, ); let query_text = "select salary from sxt.employees group by salary"; @@ -1750,7 +1750,7 @@ fn group_by_with_decimal_column_is_valid() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::Int128, + "salary".into() => ColumnType::Int128, }, ); let query_text = "select salary from sxt.employees group by salary"; @@ -1778,7 +1778,7 @@ fn group_by_with_varchar_column_is_valid() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "name".parse().unwrap() => ColumnType::VarChar, + "name".into() => ColumnType::VarChar, }, ); let query_text = "select name from sxt.employees group by name"; @@ -1806,8 +1806,8 @@ fn we_can_use_arithmetic_outside_agg_expressions_while_using_group_by() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "tax".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, + "tax".into() => ColumnType::BigInt, }, ); let query_text = @@ -1851,8 +1851,8 @@ fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "bonus".parse().unwrap() => ColumnType::Int128, + "salary".into() => ColumnType::BigInt, + "bonus".into() => ColumnType::Int128, }, ); let query_text = "select 7 + max(salary) as max_i, min(salary + 777 * bonus) * -5 as min_d from sxt.employees"; @@ -1895,9 +1895,9 @@ fn count_aggregation_always_have_integer_type() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "name".parse().unwrap() => ColumnType::VarChar, - "salary".parse().unwrap() => ColumnType::BigInt, - "tax".parse().unwrap() => ColumnType::Int128, + "name".into() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "tax".into() => ColumnType::Int128, }, ); let query_text = @@ -1953,15 +1953,15 @@ fn select_wildcard_is_valid_with_group_by_exprs() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "employee_name".parse().unwrap() => ColumnType::VarChar, - "base_salary".parse().unwrap() => ColumnType::BigInt, - "annual_bonus".parse().unwrap() => ColumnType::Int128, - "manager_name".parse().unwrap() => ColumnType::VarChar, - "manager_salary".parse().unwrap() => ColumnType::BigInt, - "manager_bonus".parse().unwrap() => ColumnType::Int128, - "department_name".parse().unwrap() => ColumnType::VarChar, - "department_budget".parse().unwrap() => ColumnType::BigInt, - "department_headcount".parse().unwrap() => ColumnType::Int128, + "employee_name".into() => ColumnType::VarChar, + "base_salary".into() => ColumnType::BigInt, + "annual_bonus".into() => ColumnType::Int128, + "manager_name".into() => ColumnType::VarChar, + "manager_salary".into() => ColumnType::BigInt, + "manager_bonus".into() => ColumnType::Int128, + "department_name".into() => ColumnType::VarChar, + "department_budget".into() => ColumnType::BigInt, + "department_headcount".into() => ColumnType::Int128, }, ); @@ -1995,7 +1995,7 @@ fn nested_aggregations_are_not_supported() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, + "salary".into() => ColumnType::BigInt, }, ); @@ -2024,10 +2024,10 @@ fn select_group_and_order_by_preserve_the_column_order_reference() { let accessor = schema_accessor_from_table_ref_with_schema( t, indexmap! { - "salary".parse().unwrap() => ColumnType::BigInt, - "department".parse().unwrap() => ColumnType::BigInt, - "tax".parse().unwrap() => ColumnType::BigInt, - "name".parse().unwrap() => ColumnType::VarChar, + "salary".into() => ColumnType::BigInt, + "department".into() => ColumnType::BigInt, + "tax".into() => ColumnType::BigInt, + "name".into() => ColumnType::VarChar, }, ); let base_cols: [&str; N] = ["salary", "department", "tax", "name"]; // sorted because of `select: [cols = ... ]` @@ -2086,12 +2086,12 @@ fn query_expr_for_test_table(sql_text: &str) -> QueryExpr { let schema_accessor = schema_accessor_from_table_ref_with_schema( "test.table".parse().unwrap(), indexmap! { - "bigint_column".parse().unwrap() => ColumnType::BigInt, - "varchar_column".parse().unwrap() => ColumnType::VarChar, - "int128_column".parse().unwrap() => ColumnType::Int128, + "bigint_column".into() => ColumnType::BigInt, + "varchar_column".into() => ColumnType::VarChar, + "int128_column".into() => ColumnType::Int128, }, ); - let default_schema = "test".parse().unwrap(); + let default_schema = "test".into(); let select_statement = SelectStatementParser::new().parse(sql_text).unwrap(); QueryExpr::try_new(select_statement, default_schema, &schema_accessor).unwrap() } diff --git a/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs index 65356db8a..4201e68eb 100644 --- a/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs @@ -7,7 +7,8 @@ use crate::{ sql::proof_exprs::{DynProofExpr, ProofExpr}, }; use alloc::boxed::Box; -use proof_of_sql_parser::{intermediate_ast::Expression, Identifier}; +use proof_of_sql_parser::intermediate_ast::Expression; +use sqlparser::ast::Ident; /// Builder that enables building a `proof_of_sql::sql::proof_exprs::DynProofExpr` from a `proof_of_sql_parser::intermediate_ast::Expression` that is /// intended to be used as the where clause in a filter expression or group by expression. @@ -16,7 +17,7 @@ pub struct WhereExprBuilder<'a> { } impl<'a> WhereExprBuilder<'a> { /// Creates a new `WhereExprBuilder` with the given column mapping. - pub fn new(column_mapping: &'a IndexMap) -> Self { + pub fn new(column_mapping: &'a IndexMap) -> Self { Self { builder: DynProofExprBuilder::new(column_mapping), } diff --git a/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs b/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs index 662555028..734a21fbe 100644 --- a/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs @@ -3,6 +3,7 @@ use crate::{ database::{ColumnRef, ColumnType, LiteralValue, TestSchemaAccessor}, map::{indexmap, IndexMap}, math::decimal::Precision, + sqlparser::ident, }, sql::{ parse::{ConversionError, QueryExpr, WhereExprBuilder}, @@ -14,8 +15,9 @@ use core::str::FromStr; use proof_of_sql_parser::{ posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestamp}, utility::*, - Identifier, SelectStatement, + SelectStatement, }; +use sqlparser::ast::Ident; /// # Panics /// @@ -26,7 +28,7 @@ use proof_of_sql_parser::{ /// - The precision used for creating the `Decimal75` column type fails. The `Precision::new(7)` /// call is expected to succeed; however, if it encounters an invalid precision value, it will /// cause a panic when `unwrap()` is called. -fn get_column_mappings_for_testing() -> IndexMap { +fn get_column_mappings_for_testing() -> IndexMap { let tab_ref = "sxt.sxt_tab".parse().unwrap(); let mut column_mapping = IndexMap::default(); // Setup column mapping @@ -302,7 +304,7 @@ fn we_expect_an_error_while_trying_to_check_varchar_column_eq_decimal() { let t = "sxt.sxt_tab".parse().unwrap(); let accessor = TestSchemaAccessor::new(indexmap! { t => indexmap! { - "b".parse().unwrap() => ColumnType::VarChar, + "b".into() => ColumnType::VarChar, }, }); @@ -321,7 +323,7 @@ fn we_expect_an_error_while_trying_to_check_varchar_column_ge_decimal() { let t = "sxt.sxt_tab".parse().unwrap(); let accessor = TestSchemaAccessor::new(indexmap! { t => indexmap! { - "b".parse().unwrap() => ColumnType::VarChar, + "b".into() => ColumnType::VarChar, }, }); @@ -340,7 +342,7 @@ fn we_do_not_expect_an_error_while_trying_to_check_int128_column_eq_decimal_with let t = "sxt.sxt_tab".parse().unwrap(); let accessor = TestSchemaAccessor::new(indexmap! { t => indexmap! { - "b".parse().unwrap() => ColumnType::Int128, + "b".into() => ColumnType::Int128, }, }); @@ -357,7 +359,7 @@ fn we_do_not_expect_an_error_while_trying_to_check_bigint_column_eq_decimal_with let t = "sxt.sxt_tab".parse().unwrap(); let accessor = TestSchemaAccessor::new(indexmap! { t => indexmap! { - "b".parse().unwrap() => ColumnType::BigInt, + "b".into() => ColumnType::BigInt, }, }); diff --git a/crates/proof-of-sql/src/sql/postprocessing/error.rs b/crates/proof-of-sql/src/sql/postprocessing/error.rs index 6cc242c64..adce01894 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/error.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/error.rs @@ -1,6 +1,6 @@ use alloc::string::String; -use proof_of_sql_parser::Identifier; use snafu::Snafu; +use sqlparser::ast::Ident; /// Errors in postprocessing #[derive(Snafu, Debug, PartialEq, Eq)] @@ -33,7 +33,13 @@ pub enum PostprocessingError { #[snafu(display("Invalid group by: column '{column}' must not appear outside aggregate functions or `GROUP BY` clause."))] IdentifierNotInAggregationOperatorOrGroupByClause { /// The column identifier - column: Identifier, + column: Ident, + }, + /// Errors in converting `Ident` to `Identifier` + #[snafu(display("Failed to convert `Ident` to `Identifier`: {error}"))] + IdentifierConversionError { + /// The underlying error message + error: String, }, /// Errors in aggregate columns #[snafu(transparent)] diff --git a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs index e75d3b24f..5e79d0c9b 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs @@ -12,6 +12,7 @@ use proof_of_sql_parser::{ Identifier, }; use serde::{Deserialize, Serialize}; +use sqlparser::ast::Ident; /// A group by expression #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -20,10 +21,10 @@ pub struct GroupByPostprocessing { remainder_exprs: Vec, /// A list of identifiers in the group by clause - group_by_identifiers: Vec, + group_by_identifiers: Vec, /// A list of aggregation expressions - aggregation_exprs: Vec<(AggregationOperator, Expression, Identifier)>, + aggregation_exprs: Vec<(AggregationOperator, Expression, Ident)>, } /// Check whether multiple layers of aggregation exist within the same GROUP BY clause @@ -43,9 +44,9 @@ fn contains_nested_aggregation(expr: &Expression, is_agg: bool) -> bool { } /// Get identifiers NOT in aggregate functions -fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet { +fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet { match expr { - Expression::Column(identifier) => IndexSet::from_iter([*identifier]), + Expression::Column(identifier) => IndexSet::from_iter([(*identifier).into()]), Expression::Literal(_) | Expression::Aggregation { .. } | Expression::Wildcard => { IndexSet::default() } @@ -70,20 +71,33 @@ fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet { /// or if there are issues retrieving an identifier from the map. fn get_aggregate_and_remainder_expressions( expr: Expression, - aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Identifier>, -) -> Expression { + aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>, +) -> Result { match expr { - Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => expr, + Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => Ok(expr), Expression::Aggregation { op, expr } => { let key = (op, (*expr)); - if aggregation_expr_map.contains_key(&key) { - Expression::Column(*aggregation_expr_map.get(&key).unwrap()) + if let Some(ident) = aggregation_expr_map.get(&key) { + let identifier = Identifier::try_from(ident.clone()).map_err(|e| { + PostprocessingError::IdentifierConversionError { + error: format!("Failed to convert Ident to Identifier: {e}"), + } + })?; + Ok(Expression::Column(identifier)) } else { - let new_col_id = format!("__col_agg_{}", aggregation_expr_map.len()) - .parse() - .unwrap(); - aggregation_expr_map.insert(key, new_col_id); - Expression::Column(new_col_id) + let new_ident = Ident { + value: format!("__col_agg_{}", aggregation_expr_map.len()), + quote_style: None, + }; + + let new_identifier = Identifier::try_from(new_ident.clone()).map_err(|e| { + PostprocessingError::IdentifierConversionError { + error: format!("Failed to convert Ident to Identifier: {e}"), + } + })?; + + aggregation_expr_map.insert(key, new_ident); + Ok(Expression::Column(new_identifier)) } } Expression::Binary { op, left, right } => { @@ -91,18 +105,18 @@ fn get_aggregate_and_remainder_expressions( get_aggregate_and_remainder_expressions(*left, aggregation_expr_map); let right_remainder = get_aggregate_and_remainder_expressions(*right, aggregation_expr_map); - Expression::Binary { + Ok(Expression::Binary { op, - left: Box::new(left_remainder), - right: Box::new(right_remainder), - } + left: Box::new(left_remainder?), + right: Box::new(right_remainder?), + }) } Expression::Unary { op, expr } => { let remainder = get_aggregate_and_remainder_expressions(*expr, aggregation_expr_map); - Expression::Unary { + Ok(Expression::Unary { op, - expr: Box::new(remainder), - } + expr: Box::new(remainder?), + }) } } } @@ -113,13 +127,13 @@ fn get_aggregate_and_remainder_expressions( /// Will panic if there is an issue retrieving the first element from the difference of free identifiers and group-by identifiers, indicating a logical inconsistency in the identifiers. fn check_and_get_aggregation_and_remainder( expr: AliasedResultExpr, - group_by_identifiers: &[Identifier], - aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Identifier>, + group_by_identifiers: &[Ident], + aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>, ) -> PostprocessingResult { let free_identifiers = get_free_identifiers_from_expr(&expr.expr); let group_by_identifier_set = group_by_identifiers .iter() - .copied() + .cloned() .collect::>(); if contains_nested_aggregation(&expr.expr, false) { return Err(PostprocessingError::NestedAggregationInGroupByClause { @@ -130,7 +144,7 @@ fn check_and_get_aggregation_and_remainder( let remainder = get_aggregate_and_remainder_expressions(*expr.expr, aggregation_expr_map); Ok(AliasedResultExpr { alias: expr.alias, - expr: Box::new(remainder), + expr: Box::new(remainder?), }) } else { let diff = free_identifiers @@ -139,7 +153,7 @@ fn check_and_get_aggregation_and_remainder( .unwrap(); Err( PostprocessingError::IdentifierNotInAggregationOperatorOrGroupByClause { - column: *diff, + column: diff.clone(), }, ) } @@ -148,10 +162,10 @@ fn check_and_get_aggregation_and_remainder( impl GroupByPostprocessing { /// Create a new group by expression containing the group by and aggregation expressions pub fn try_new( - by_ids: Vec, + by_ids: Vec, aliased_exprs: Vec, ) -> PostprocessingResult { - let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Identifier> = + let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Ident> = IndexMap::default(); // Look for aggregation expressions and check for non-aggregation expressions that contain identifiers not in the group by clause let remainder_exprs: Vec = aliased_exprs @@ -177,7 +191,7 @@ impl GroupByPostprocessing { /// Get group by identifiers #[must_use] - pub fn group_by(&self) -> &[Identifier] { + pub fn group_by(&self) -> &[Ident] { &self.group_by_identifiers } @@ -189,7 +203,7 @@ impl GroupByPostprocessing { /// Get aggregation expressions #[must_use] - pub fn aggregation_exprs(&self) -> &[(AggregationOperator, Expression, Identifier)] { + pub fn aggregation_exprs(&self) -> &[(AggregationOperator, Expression, Ident)] { &self.aggregation_exprs } } @@ -205,7 +219,7 @@ impl PostprocessingStep for GroupByPostprocessing { .iter() .map(|(agg_op, expr, id)| -> PostprocessingResult<_> { let evaluated_owned_column = owned_table.evaluate(expr)?; - Ok((*agg_op, (*id, evaluated_owned_column))) + Ok((*agg_op, (id.clone(), evaluated_owned_column))) }) .process_results(|iter| { iter.fold( @@ -236,7 +250,7 @@ impl PostprocessingStep for GroupByPostprocessing { .map_or((vec![], vec![]), |tuple| { tuple .iter() - .map(|(id, c)| (*id, Column::::from_owned_column(c, &alloc))) + .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) .unzip() }); let (max_identifiers, max_columns): (Vec<_>, Vec<_>) = evaluated_columns @@ -244,7 +258,7 @@ impl PostprocessingStep for GroupByPostprocessing { .map_or((vec![], vec![]), |tuple| { tuple .iter() - .map(|(id, c)| (*id, Column::::from_owned_column(c, &alloc))) + .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) .unzip() }); let (min_identifiers, min_columns): (Vec<_>, Vec<_>) = evaluated_columns @@ -252,7 +266,7 @@ impl PostprocessingStep for GroupByPostprocessing { .map_or((vec![], vec![]), |tuple| { tuple .iter() - .map(|(id, c)| (*id, Column::::from_owned_column(c, &alloc))) + .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) .unzip() }); let aggregation_results = aggregate_columns( @@ -269,7 +283,7 @@ impl PostprocessingStep for GroupByPostprocessing { .group_by_columns .iter() .zip(self.group_by_identifiers.iter()) - .map(|(column, id)| Ok((*id, OwnedColumn::from(column)))); + .map(|(column, id)| Ok((id.clone(), OwnedColumn::from(column)))); let sum_outs = izip!( aggregation_results.sum_columns, sum_identifiers, @@ -309,7 +323,7 @@ impl PostprocessingStep for GroupByPostprocessing { .get(&AggregationOperator::Count) .into_iter() .flatten() - .map(|(id, _)| -> PostprocessingResult<_> { Ok((*id, count_column.clone())) }); + .map(|(id, _)| -> PostprocessingResult<_> { Ok((id.clone(), count_column.clone())) }); let new_owned_table: OwnedTable = group_by_outs .into_iter() .chain(sum_outs) @@ -320,7 +334,7 @@ impl PostprocessingStep for GroupByPostprocessing { // If there are no columns at all we need to have the count column so that we can handle // queries such as `SELECT 1 FROM table` let target_table = if new_owned_table.is_empty() { - OwnedTable::try_new(indexmap! {"__count__".parse().unwrap() => count_column})? + OwnedTable::try_new(indexmap! {"__count__".into() => count_column})? } else { new_owned_table }; @@ -329,7 +343,8 @@ impl PostprocessingStep for GroupByPostprocessing { .iter() .map(|aliased_expr| -> PostprocessingResult<_> { let column = target_table.evaluate(&aliased_expr.expr)?; - Ok((aliased_expr.alias, column)) + let alias: Ident = aliased_expr.alias.into(); + Ok((alias, column)) }) .process_results(|iter| OwnedTable::try_from_iter(iter))??; Ok(result) @@ -339,6 +354,7 @@ impl PostprocessingStep for GroupByPostprocessing { #[cfg(test)] mod tests { use super::*; + use crate::base::sqlparser::ident; use proof_of_sql_parser::utility::*; #[test] @@ -378,41 +394,38 @@ mod tests { fn we_can_get_free_identifiers_from_expr() { // Literal let expr = lit("Not an identifier"); - let expected: IndexSet = IndexSet::default(); + let expected: IndexSet = IndexSet::default(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); // a + b + 1 let expr = add(add(col("a"), col("b")), lit(1)); - let expected: IndexSet = [ident("a"), ident("b")].iter().copied().collect(); + let expected: IndexSet = [ident("a"), ident("b")].into_iter().collect(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); // ! (a == b || c >= a) let expr = not(or(equal(col("a"), col("b")), ge(col("c"), col("a")))); - let expected: IndexSet = [ident("a"), ident("b"), ident("c")] - .iter() - .copied() - .collect(); + let expected: IndexSet = [ident("a"), ident("b"), ident("c")].into_iter().collect(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); // SUM(a + b) * 2 let expr = mul(sum(add(col("a"), col("b"))), lit(2)); - let expected: IndexSet = IndexSet::default(); + let expected: IndexSet = IndexSet::default(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); // (COUNT(a + b) + c) * d let expr = mul(add(count(add(col("a"), col("b"))), col("c")), col("d")); - let expected: IndexSet = [ident("c"), ident("d")].iter().copied().collect(); + let expected: IndexSet = [ident("c"), ident("d")].into_iter().collect(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); } #[test] fn we_can_get_aggregate_and_remainder_expressions() { - let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Identifier> = + let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Ident> = IndexMap::default(); // SUM(a) + b let expr = add(sum(col("a")), col("b")); @@ -422,7 +435,7 @@ mod tests { aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))], ident("__col_agg_0") ); - assert_eq!(remainder_expr, *add(col("__col_agg_0"), col("b"))); + assert_eq!(remainder_expr, Ok(*add(col("__col_agg_0"), col("b")))); assert_eq!(aggregation_expr_map.len(), 1); // SUM(a) + SUM(b) @@ -437,7 +450,10 @@ mod tests { aggregation_expr_map[&(AggregationOperator::Sum, *col("b"))], ident("__col_agg_1") ); - assert_eq!(remainder_expr, *add(col("__col_agg_0"), col("__col_agg_1"))); + assert_eq!( + remainder_expr, + Ok(*add(col("__col_agg_0"), col("__col_agg_1"))) + ); assert_eq!(aggregation_expr_map.len(), 2); // MAX(a + 1) + MIN(2 * b - 4) + c @@ -463,7 +479,7 @@ mod tests { ); assert_eq!( remainder_expr, - *add(add(col("__col_agg_2"), col("__col_agg_3")), col("c")) + Ok(*add(add(col("__col_agg_2"), col("__col_agg_3")), col("c"))) ); assert_eq!(aggregation_expr_map.len(), 4); @@ -480,10 +496,10 @@ mod tests { ); assert_eq!( remainder_expr, - *add( + Ok(*add( add(mul(col("__col_agg_4"), lit(2)), col("__col_agg_1")), lit(1) - ) + )) ); assert_eq!(aggregation_expr_map.len(), 5); } diff --git a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs index f38cfa201..94614409e 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs @@ -2,6 +2,7 @@ use crate::{ base::{ database::{owned_table_utility::*, OwnedTable}, scalar::Curve25519Scalar, + sqlparser::ident, }, sql::postprocessing::{ apply_postprocessing_steps, group_by_postprocessing::*, test_utility::*, @@ -10,7 +11,6 @@ use crate::{ }; use bigdecimal::BigDecimal; use proof_of_sql_parser::{intermediate_ast::AggregationOperator, utility::*}; - #[test] fn we_cannot_have_invalid_group_bys() { // Column in result but not in group by or aggregation diff --git a/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs index 4f8b3f5f5..8303e1faf 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/order_by_postprocessing.rs @@ -34,10 +34,11 @@ impl PostprocessingStep for OrderByPostprocessing { .iter() .map( |order_by| -> PostprocessingResult<(OwnedColumn, OrderByDirection)> { + let identifier: sqlparser::ast::Ident = order_by.expr.into(); Ok(( owned_table .inner_table() - .get(&order_by.expr) + .get(&identifier) .ok_or(PostprocessingError::ColumnNotFound { column: order_by.expr.to_string(), })? diff --git a/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing.rs index fad453d25..9437c5daf 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing.rs @@ -5,8 +5,9 @@ use crate::base::{ scalar::Scalar, }; use alloc::vec::Vec; -use proof_of_sql_parser::{intermediate_ast::AliasedResultExpr, Identifier}; +use proof_of_sql_parser::intermediate_ast::AliasedResultExpr; use serde::{Deserialize, Serialize}; +use sqlparser::ast::Ident; /// The select expression used to select, reorder, and apply alias transformations #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -28,13 +29,13 @@ impl SelectPostprocessing { impl PostprocessingStep for SelectPostprocessing { /// Apply the select transformation to the given `OwnedTable`. fn apply(&self, owned_table: OwnedTable) -> PostprocessingResult> { - let cols: IndexMap> = self + let cols: IndexMap> = self .aliased_result_exprs .iter() .map( - |aliased_result_expr| -> PostprocessingResult<(Identifier, OwnedColumn)> { + |aliased_result_expr| -> PostprocessingResult<(Ident, OwnedColumn)> { let result_column = owned_table.evaluate(&aliased_result_expr.expr)?; - Ok((aliased_result_expr.alias, result_column)) + Ok((aliased_result_expr.alias.into(), result_column)) }, ) .collect::>()?; diff --git a/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs b/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs index dd421d0db..9ee210326 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs @@ -1,16 +1,14 @@ use super::*; -use proof_of_sql_parser::{ - intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection}, - utility::ident, - Identifier, -}; +use crate::base::sqlparser::ident; +use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection}; +use sqlparser::ast::Ident; #[must_use] pub fn group_by_postprocessing( cols: &[&str], result_exprs: &[AliasedResultExpr], ) -> OwnedTablePostprocessing { - let ids: Vec = cols.iter().map(|col| ident(col)).collect(); + let ids: Vec = cols.iter().map(|col| ident(col)).collect(); OwnedTablePostprocessing::new_group_by( GroupByPostprocessing::try_new(ids, result_exprs.to_vec()).unwrap(), ) diff --git a/crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs b/crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs index b1e8b9d2e..2053224d9 100644 --- a/crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs +++ b/crates/proof-of-sql/src/sql/proof/final_round_builder_test.rs @@ -78,8 +78,8 @@ fn we_can_form_the_provable_query_result() { let res = ProvableQueryResult::new(2, &[col1, col2]); let column_fields = vec![ - ColumnField::new("a".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("a".into(), ColumnType::BigInt), + ColumnField::new("b".into(), ColumnType::BigInt), ]; let res = RecordBatch::try_from( res.to_owned_table::(&column_fields) diff --git a/crates/proof-of-sql/src/sql/proof/provable_query_result_test.rs b/crates/proof-of-sql/src/sql/proof/provable_query_result_test.rs index 3540de040..ee9870545 100644 --- a/crates/proof-of-sql/src/sql/proof/provable_query_result_test.rs +++ b/crates/proof-of-sql/src/sql/proof/provable_query_result_test.rs @@ -17,7 +17,7 @@ use num_traits::Zero; fn we_can_convert_an_empty_provable_result_to_a_final_result() { let cols: [Column; 1] = [Column::BigInt(&[0_i64; 0])]; let res = ProvableQueryResult::new(0, &cols); - let column_fields = vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)]; + let column_fields = vec![ColumnField::new("a1".into(), ColumnType::BigInt)]; let res = RecordBatch::try_from( res.to_owned_table::(&column_fields) .unwrap(), @@ -44,8 +44,7 @@ fn we_can_evaluate_result_columns_as_mles() { let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); cols.len()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::BigInt); cols.len()]; let evals = res .evaluate(&evaluation_point, 2, &column_fields[..]) .unwrap(); @@ -62,8 +61,7 @@ fn we_can_evaluate_result_columns_with_no_rows() { let evaluation_point = []; let mut evaluation_vec = [Curve25519Scalar::ZERO; 0]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); cols.len()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::BigInt); cols.len()]; let evals = res .evaluate(&evaluation_point, 0, &column_fields[..]) .unwrap(); @@ -81,8 +79,7 @@ fn we_can_evaluate_multiple_result_columns_as_mles() { ]; let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); cols.len()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::BigInt); cols.len()]; let evals = res .evaluate(&evaluation_point, 2, &column_fields[..]) .unwrap(); @@ -105,8 +102,7 @@ fn we_can_evaluate_multiple_result_columns_as_mles_with_128_bits() { ]; let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::Int128); cols.len()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::Int128); cols.len()]; let evals = res .evaluate(&evaluation_point, 2, &column_fields[..]) .unwrap(); @@ -138,8 +134,7 @@ fn we_can_evaluate_multiple_result_columns_as_mles_with_scalar_columns() { ]; let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::Scalar); cols.len()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::Scalar); cols.len()]; let evals = res .evaluate(&evaluation_point, 2, &column_fields[..]) .unwrap(); @@ -163,8 +158,8 @@ fn we_can_evaluate_multiple_result_columns_as_mles_with_mixed_data_types() { let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); let column_fields = [ - ColumnField::new("a".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("a".parse().unwrap(), ColumnType::Int128), + ColumnField::new("a".into(), ColumnType::BigInt), + ColumnField::new("a".into(), ColumnType::Int128), ]; let evals = res .evaluate(&evaluation_point, 2, &column_fields[..]) @@ -189,8 +184,7 @@ fn evaluation_fails_if_extra_data_is_included() { ]; let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); cols.len()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::BigInt); cols.len()]; assert!(matches!( res.evaluate(&evaluation_point, 2, &column_fields[..]), Err(QueryError::MiscellaneousEvaluationError) @@ -207,8 +201,7 @@ fn evaluation_fails_if_the_result_cant_be_decoded() { ]; let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); res.num_columns()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::BigInt); res.num_columns()]; assert!(matches!( res.evaluate(&evaluation_point, 2, &column_fields[..]), Err(QueryError::Overflow) @@ -226,8 +219,7 @@ fn evaluation_fails_if_integer_overflow_happens() { ]; let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::Int); res.num_columns()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::Int); res.num_columns()]; assert!(matches!( res.evaluate(&evaluation_point, 2, &column_fields[..]), Err(QueryError::Overflow) @@ -245,8 +237,7 @@ fn evaluation_fails_if_data_is_missing() { ]; let mut evaluation_vec = [Curve25519Scalar::ZERO; 2]; compute_evaluation_vector(&mut evaluation_vec, &evaluation_point); - let column_fields = - vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt); res.num_columns()]; + let column_fields = vec![ColumnField::new("a".into(), ColumnType::BigInt); res.num_columns()]; assert!(matches!( res.evaluate(&evaluation_point, 2, &column_fields[..]), Err(QueryError::Overflow) @@ -257,7 +248,7 @@ fn evaluation_fails_if_data_is_missing() { fn we_can_convert_a_provable_result_to_a_final_result() { let cols: [Column; 1] = [Column::BigInt(&[10, 12])]; let res = ProvableQueryResult::new(2, &cols); - let column_fields = vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)]; + let column_fields = vec![ColumnField::new("a1".into(), ColumnType::BigInt)]; let res = RecordBatch::try_from( res.to_owned_table::(&column_fields) .unwrap(), @@ -277,7 +268,7 @@ fn we_can_convert_a_provable_result_to_a_final_result() { fn we_can_convert_a_provable_result_to_a_final_result_with_128_bits() { let cols: [Column; 1] = [Column::Int128(&[10, i128::MAX])]; let res = ProvableQueryResult::new(2, &cols); - let column_fields = vec![ColumnField::new("a1".parse().unwrap(), ColumnType::Int128)]; + let column_fields = vec![ColumnField::new("a1".into(), ColumnType::Int128)]; let res = RecordBatch::try_from( res.to_owned_table::(&column_fields) .unwrap(), @@ -307,7 +298,7 @@ fn we_can_convert_a_provable_result_to_a_final_result_with_252_bits() { let cols: [Column; 1] = [Column::Scalar(&values)]; let res = ProvableQueryResult::new(2, &cols); let column_fields = vec![ColumnField::new( - "a1".parse().unwrap(), + "a1".into(), ColumnType::Decimal75(Precision::new(75).unwrap(), 0), )]; let res = RecordBatch::try_from( @@ -352,11 +343,11 @@ fn we_can_convert_a_provable_result_to_a_final_result_with_mixed_data_types() { ]; let res = ProvableQueryResult::new(2, &cols); let column_fields = vec![ - ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("a2".parse().unwrap(), ColumnType::Int128), - ColumnField::new("a3".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("a1".into(), ColumnType::BigInt), + ColumnField::new("a2".into(), ColumnType::Int128), + ColumnField::new("a3".into(), ColumnType::VarChar), ColumnField::new( - "a4".parse().unwrap(), + "a4".into(), ColumnType::Decimal75(Precision::new(75).unwrap(), 0), ), ]; diff --git a/crates/proof-of-sql/src/sql/proof/query_proof.rs b/crates/proof-of-sql/src/sql/proof/query_proof.rs index 694c930f1..708b330d1 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof.rs @@ -93,7 +93,7 @@ impl QueryProof { let col_refs: IndexSet = total_col_refs .iter() .filter(|col_ref| col_ref.table_ref() == table_ref) - .copied() + .cloned() .collect(); (table_ref, accessor.get_table(table_ref, &col_refs)) }) @@ -342,7 +342,7 @@ impl QueryProof { let pcs_proof_commitments: Vec<_> = column_references .iter() - .map(|col| accessor.get_commitment(*col)) + .map(|col| accessor.get_commitment(col.clone())) .chain(self.commitments.iter().cloned()) .collect(); let evaluation_accessor: IndexMap<_, _> = column_references @@ -447,7 +447,7 @@ fn extend_transcript_with_owned_table( result: &OwnedTable, ) { for (name, column) in result.inner_table() { - transcript.extend_as_le_from_refs([name.as_str()]); + transcript.extend_as_le_from_refs([name.value.as_str()]); match column { OwnedColumn::Boolean(col) => transcript.extend_as_be(col.iter().map(|&b| u8::from(b))), OwnedColumn::TinyInt(col) => transcript.extend_as_be_from_refs(col), diff --git a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs index 0d043fb68..e88c7ba7b 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof_test.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof_test.rs @@ -16,8 +16,8 @@ use crate::{ sql::proof::{FirstRoundBuilder, QueryData, SumcheckSubpolynomialType}, }; use bumpalo::Bump; -use proof_of_sql_parser::Identifier; use serde::Serialize; +use sqlparser::ast::Ident; /// Type to allow us to prove and verify an artificial polynomial where we prove /// that every entry in the result is zero @@ -106,7 +106,7 @@ impl ProofPlan for TrivialTestProofPlan { /// /// This method will panic if the `ColumnField` cannot be created from the provided column name (e.g., if the name parsing fails). fn get_column_result_fields(&self) -> Vec { - vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)] + vec![ColumnField::new("a1".into(), ColumnType::BigInt)] } fn get_column_references(&self) -> IndexSet { indexset! {} @@ -281,7 +281,7 @@ impl ProverEvaluate for SquareTestProofPlan { .get(&TableRef::new("sxt.test".parse().unwrap())) .unwrap() .inner_table() - .get(&"x".parse::().unwrap()) + .get(&Ident::new("x")) .unwrap(); let res: &[_] = alloc.alloc_slice_copy(&self.res); builder.produce_intermediate_mle(res); @@ -307,7 +307,7 @@ impl ProofPlan for SquareTestProofPlan { * *accessor .get(&ColumnRef::new( "sxt.test".parse().unwrap(), - "x".parse().unwrap(), + "x".into(), ColumnType::BigInt, )) .unwrap(); @@ -323,12 +323,12 @@ impl ProofPlan for SquareTestProofPlan { )) } fn get_column_result_fields(&self) -> Vec { - vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)] + vec![ColumnField::new("a1".into(), ColumnType::BigInt)] } fn get_column_references(&self) -> IndexSet { indexset! {ColumnRef::new( "sxt.test".parse().unwrap(), - "x".parse().unwrap(), + "x".into(), ColumnType::BigInt, )} } @@ -460,7 +460,7 @@ impl ProverEvaluate for DoubleSquareTestProofPlan { .get(&TableRef::new("sxt.test".parse().unwrap())) .unwrap() .inner_table() - .get(&"x".parse::().unwrap()) + .get(&Ident::new("x")) .unwrap(); let res: &[_] = alloc.alloc_slice_copy(&self.res); let z: &[_] = alloc.alloc_slice_copy(&self.z); @@ -498,7 +498,7 @@ impl ProofPlan for DoubleSquareTestProofPlan { let x_eval = *accessor .get(&ColumnRef::new( "sxt.test".parse().unwrap(), - "x".parse().unwrap(), + "x".into(), ColumnType::BigInt, )) .unwrap(); @@ -524,12 +524,12 @@ impl ProofPlan for DoubleSquareTestProofPlan { )) } fn get_column_result_fields(&self) -> Vec { - vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)] + vec![ColumnField::new("a1".into(), ColumnType::BigInt)] } fn get_column_references(&self) -> IndexSet { indexset! {ColumnRef::new( "sxt.test".parse().unwrap(), - "x".parse().unwrap(), + "x".into(), ColumnType::BigInt, )} } @@ -671,7 +671,7 @@ impl ProverEvaluate for ChallengeTestProofPlan { .get(&TableRef::new("sxt.test".parse().unwrap())) .unwrap() .inner_table() - .get(&"x".parse::().unwrap()) + .get(&Ident::new("x")) .unwrap(); let res: &[_] = alloc.alloc_slice_copy(&[9, 25]); let alpha = builder.consume_post_result_challenge(); @@ -700,7 +700,7 @@ impl ProofPlan for ChallengeTestProofPlan { let x_eval = *accessor .get(&ColumnRef::new( "sxt.test".parse().unwrap(), - "x".parse().unwrap(), + "x".into(), ColumnType::BigInt, )) .unwrap(); @@ -716,12 +716,12 @@ impl ProofPlan for ChallengeTestProofPlan { )) } fn get_column_result_fields(&self) -> Vec { - vec![ColumnField::new("a1".parse().unwrap(), ColumnType::BigInt)] + vec![ColumnField::new("a1".into(), ColumnType::BigInt)] } fn get_column_references(&self) -> IndexSet { indexset! {ColumnRef::new( "sxt.test".parse().unwrap(), - "x".parse().unwrap(), + "x".into(), ColumnType::BigInt, )} } diff --git a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs index b95334b95..1ff118f5e 100644 --- a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs +++ b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test.rs @@ -13,6 +13,7 @@ use crate::{ map::{indexset, IndexMap, IndexSet}, proof::ProofError, scalar::Scalar, + sqlparser::ident, }, sql::proof::{FirstRoundBuilder, QueryData}, }; @@ -34,7 +35,8 @@ impl ProverEvaluate for EmptyTestQueryExpr { let zeros = vec![0_i64; self.length]; builder.produce_one_evaluation_length(self.length); table_with_row_count( - (1..=self.columns).map(|i| borrowed_bigint(format!("a{i}"), zeros.clone(), alloc)), + (1..=self.columns) + .map(|i| borrowed_bigint(ident(format!("a{i}").as_str()), zeros.clone(), alloc)), self.length, ) } @@ -51,7 +53,8 @@ impl ProverEvaluate for EmptyTestQueryExpr { .take(self.columns) .collect::>(); table_with_row_count( - (1..=self.columns).map(|i| borrowed_bigint(format!("a{i}"), zeros.clone(), alloc)), + (1..=self.columns) + .map(|i| borrowed_bigint(ident(format!("a{i}").as_str()), zeros.clone(), alloc)), self.length, ) } @@ -76,7 +79,7 @@ impl ProofPlan for EmptyTestQueryExpr { fn get_column_result_fields(&self) -> Vec { (1..=self.columns) - .map(|i| ColumnField::new(format!("a{i}").parse().unwrap(), ColumnType::BigInt)) + .map(|i| ColumnField::new(format!("a{i}").as_str().into(), ColumnType::BigInt)) .collect() } diff --git a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test_utility.rs b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test_utility.rs index b2c5c0758..461ae0c1a 100644 --- a/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test_utility.rs +++ b/crates/proof-of-sql/src/sql/proof/verifiable_query_result_test_utility.rs @@ -96,7 +96,7 @@ fn append_single_row_to_table(table: &OwnedTable) -> OwnedTable table .inner_table() .iter() - .map(|(name, col)| (*name, append_single_row_to_column(col))), + .map(|(name, col)| (name.clone(), append_single_row_to_column(col))), ) .expect("Failed to create table") } @@ -122,7 +122,7 @@ fn tamper_first_element_of_table(table: &OwnedTable) -> OwnedTable .enumerate() .map(|(i, (name, col))| { ( - *name, + name.clone(), if i == 0 { tamper_first_row_of_column(col) } else { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/aliased_dyn_proof_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/aliased_dyn_proof_expr.rs index 838b6f32b..72db25d50 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/aliased_dyn_proof_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/aliased_dyn_proof_expr.rs @@ -1,10 +1,10 @@ use super::DynProofExpr; -use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; +use sqlparser::ast::Ident; /// A `DynProofExpr` with an alias. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] pub struct AliasedDynProofExpr { pub expr: DynProofExpr, - pub alias: Identifier, + pub alias: Ident, } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs index 3f25d6dcb..5df96f167 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/column_expr.rs @@ -9,8 +9,8 @@ use crate::{ sql::proof::{FinalRoundBuilder, VerificationBuilder}, }; use bumpalo::Bump; -use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; +use sqlparser::ast::Ident; /// Provable expression for a column /// /// Note: this is currently limited to named column expressions. @@ -27,7 +27,7 @@ impl ColumnExpr { /// Return the column referenced by this [`ColumnExpr`] pub fn get_column_reference(&self) -> ColumnRef { - self.column_ref + self.column_ref.clone() } /// Wrap the column output name and its type within the [`ColumnField`] @@ -36,7 +36,7 @@ impl ColumnExpr { } /// Get the column identifier - pub fn column_id(&self) -> Identifier { + pub fn column_id(&self) -> Ident { self.column_ref.column_id() } @@ -99,6 +99,6 @@ impl ProofExpr for ColumnExpr { /// references in the `BoolExpr` or forwards the call to some /// subsequent `bool_expr` fn get_column_references(&self, columns: &mut IndexSet) { - columns.insert(self.column_ref); + columns.insert(self.column_ref.clone()); } } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs b/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs index 86f2f27f5..e084916ab 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs @@ -5,13 +5,14 @@ use crate::base::{ scalar::Scalar, }; use proof_of_sql_parser::intermediate_ast::AggregationOperator; +use sqlparser::ast::Ident; /// # Panics /// Panics if: /// - `name.parse()` fails, which means the provided string could not be parsed into the expected type (usually an `Identifier`). pub fn col_ref(tab: TableRef, name: &str, accessor: &impl SchemaAccessor) -> ColumnRef { - let name = name.parse().unwrap(); - let type_col = accessor.lookup_column(tab, name).unwrap(); + let name: Ident = name.into(); + let type_col = accessor.lookup_column(tab, name.clone()).unwrap(); ColumnRef::new(tab, name, type_col) } @@ -20,8 +21,8 @@ pub fn col_ref(tab: TableRef, name: &str, accessor: &impl SchemaAccessor) -> Col /// - `name.parse()` fails to parse the column name. /// - `accessor.lookup_column()` returns `None`, indicating the column is not found. pub fn column(tab: TableRef, name: &str, accessor: &impl SchemaAccessor) -> DynProofExpr { - let name = name.parse().unwrap(); - let type_col = accessor.lookup_column(tab, name).unwrap(); + let name: Ident = name.into(); + let type_col = accessor.lookup_column(tab, name.clone()).unwrap(); DynProofExpr::Column(ColumnExpr::new(ColumnRef::new(tab, name, type_col))) } @@ -138,7 +139,7 @@ pub fn tab(tab: TableRef) -> TableExpr { pub fn aliased_plan(expr: DynProofExpr, alias: &str) -> AliasedDynProofExpr { AliasedDynProofExpr { expr, - alias: alias.parse().unwrap(), + alias: alias.into(), } } @@ -154,7 +155,7 @@ pub fn aliased_col_expr_plan( ) -> AliasedDynProofExpr { AliasedDynProofExpr { expr: DynProofExpr::Column(ColumnExpr::new(col_ref(tab, old_name, accessor))), - alias: new_name.parse().unwrap(), + alias: new_name.into(), } } @@ -169,7 +170,7 @@ pub fn col_expr_plan( ) -> AliasedDynProofExpr { AliasedDynProofExpr { expr: DynProofExpr::Column(ColumnExpr::new(col_ref(tab, name, accessor))), - alias: name.parse().unwrap(), + alias: name.into(), } } @@ -212,6 +213,6 @@ pub fn cols_expr(tab: TableRef, names: &[&str], accessor: &impl SchemaAccessor) pub fn sum_expr(expr: DynProofExpr, alias: &str) -> AliasedDynProofExpr { AliasedDynProofExpr { expr: DynProofExpr::new_aggregate(AggregationOperator::Sum, expr), - alias: alias.parse().unwrap(), + alias: alias.into(), } } diff --git a/crates/proof-of-sql/src/sql/proof_plans/demo_mock_plan.rs b/crates/proof-of-sql/src/sql/proof_plans/demo_mock_plan.rs index 8b865b885..7a3352caa 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/demo_mock_plan.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/demo_mock_plan.rs @@ -42,7 +42,7 @@ impl ProofPlan for DemoMockPlan { } fn get_column_references(&self) -> IndexSet { - indexset! {self.column} + indexset! {self.column.clone()} } fn get_table_references(&self) -> IndexSet { @@ -91,11 +91,7 @@ mod tests { fn we_can_create_and_prove_a_demo_mock_plan() { let table_ref = "namespace.table_name".parse::().unwrap(); let table = owned_table([bigint("column_name", [0, 1, 2, 3])]); - let column_ref = ColumnRef::new( - table_ref, - "column_name".parse().unwrap(), - ColumnType::BigInt, - ); + let column_ref = ColumnRef::new(table_ref, "column_name".into(), ColumnType::BigInt); let plan = DemoMockPlan { column: column_ref }; let accessor = OwnedTableTestAccessor::::new_from_table( table_ref, diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs index 0bfece895..ed8d1a14b 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs @@ -115,7 +115,9 @@ where fn get_column_result_fields(&self) -> Vec { self.aliased_results .iter() - .map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type())) + .map(|aliased_expr| { + ColumnField::new(aliased_expr.alias.clone(), aliased_expr.expr.data_type()) + }) .collect() } @@ -171,7 +173,7 @@ impl ProverEvaluate for FilterExec { let res = Table::<'a, S>::try_from_iter_with_options( self.aliased_results .iter() - .map(|expr| expr.alias) + .map(|expr| expr.alias.clone()) .zip(filtered_columns), TableOptions::new(Some(output_length)), ) @@ -235,7 +237,7 @@ impl ProverEvaluate for FilterExec { let res = Table::<'a, S>::try_from_iter_with_options( self.aliased_results .iter() - .map(|expr| expr.alias) + .map(|expr| expr.alias.clone()) .zip(filtered_columns), TableOptions::new(Some(output_length)), ) diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs index 86b317051..4bb70961b 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs @@ -20,13 +20,14 @@ use crate::{ }; use blitzar::proof::InnerProductProof; use bumpalo::Bump; -use proof_of_sql_parser::{Identifier, ResourceId}; +use proof_of_sql_parser::ResourceId; +use sqlparser::ast::Ident; #[test] fn we_can_correctly_fetch_the_query_result_schema() { let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap()); - let a = Identifier::try_new("a").unwrap(); - let b = Identifier::try_new("b").unwrap(); + let a = Ident::new("a"); + let b = Ident::new("b"); let provable_ast = FilterExec::new( vec![ aliased_plan( @@ -50,7 +51,7 @@ fn we_can_correctly_fetch_the_query_result_schema() { DynProofExpr::try_new_equals( DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( table_ref, - Identifier::try_new("c").unwrap(), + Ident::new("c"), ColumnType::BigInt, ))), DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(123))), @@ -62,8 +63,8 @@ fn we_can_correctly_fetch_the_query_result_schema() { assert_eq!( column_fields, vec![ - ColumnField::new("a".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt) + ColumnField::new("a".into(), ColumnType::BigInt), + ColumnField::new("b".into(), ColumnType::BigInt) ] ); } @@ -71,8 +72,8 @@ fn we_can_correctly_fetch_the_query_result_schema() { #[test] fn we_can_correctly_fetch_all_the_referenced_columns() { let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap()); - let a = Identifier::try_new("a").unwrap(); - let f = Identifier::try_new("f").unwrap(); + let a = Ident::new("a"); + let f = Ident::new("f"); let provable_ast = FilterExec::new( vec![ aliased_plan( @@ -98,7 +99,7 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { DynProofExpr::try_new_equals( DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( table_ref, - Identifier::try_new("f").unwrap(), + Ident::new("f"), ColumnType::BigInt, ))), DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(45))), @@ -107,7 +108,7 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { DynProofExpr::try_new_equals( DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( table_ref, - Identifier::try_new("c").unwrap(), + Ident::new("c"), ColumnType::BigInt, ))), DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(-2))), @@ -117,7 +118,7 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { DynProofExpr::try_new_equals( DynProofExpr::Column(ColumnExpr::new(ColumnRef::new( table_ref, - Identifier::try_new("b").unwrap(), + Ident::new("b"), ColumnType::BigInt, ))), DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(3))), @@ -131,26 +132,10 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { assert_eq!( ref_columns, IndexSet::from_iter([ - ColumnRef::new( - table_ref, - Identifier::try_new("a").unwrap(), - ColumnType::BigInt - ), - ColumnRef::new( - table_ref, - Identifier::try_new("f").unwrap(), - ColumnType::BigInt - ), - ColumnRef::new( - table_ref, - Identifier::try_new("c").unwrap(), - ColumnType::BigInt - ), - ColumnRef::new( - table_ref, - Identifier::try_new("b").unwrap(), - ColumnType::BigInt - ) + ColumnRef::new(table_ref, Ident::new("a"), ColumnType::BigInt), + ColumnRef::new(table_ref, Ident::new("f"), ColumnType::BigInt), + ColumnRef::new(table_ref, Ident::new("c"), ColumnType::BigInt), + ColumnRef::new(table_ref, Ident::new("b"), ColumnType::BigInt) ]) ); @@ -199,11 +184,11 @@ fn we_can_get_an_empty_result_from_a_basic_filter_on_an_empty_table_using_first_ where_clause, ); let fields = &[ - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("c".parse().unwrap(), ColumnType::Int128), - ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("b".into(), ColumnType::BigInt), + ColumnField::new("c".into(), ColumnType::Int128), + ColumnField::new("d".into(), ColumnType::VarChar), ColumnField::new( - "e".parse().unwrap(), + "e".into(), ColumnType::Decimal75(Precision::new(75).unwrap(), 0), ), ]; @@ -248,11 +233,11 @@ fn we_can_get_an_empty_result_from_a_basic_filter_using_first_round_evaluate() { where_clause, ); let fields = &[ - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("c".parse().unwrap(), ColumnType::Int128), - ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("b".into(), ColumnType::BigInt), + ColumnField::new("c".into(), ColumnType::Int128), + ColumnField::new("d".into(), ColumnType::VarChar), ColumnField::new( - "e".parse().unwrap(), + "e".into(), ColumnType::Decimal75(Precision::new(1).unwrap(), 0), ), ]; @@ -328,11 +313,11 @@ fn we_can_get_the_correct_result_from_a_basic_filter_using_first_round_evaluate( where_clause, ); let fields = &[ - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("c".parse().unwrap(), ColumnType::Int128), - ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("b".into(), ColumnType::BigInt), + ColumnField::new("c".into(), ColumnType::Int128), + ColumnField::new("d".into(), ColumnType::VarChar), ColumnField::new( - "e".parse().unwrap(), + "e".into(), ColumnType::Decimal75(Precision::new(1).unwrap(), 0), ), ]; diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs index 09d6b9494..9ec4f5701 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs @@ -64,7 +64,7 @@ impl ProverEvaluate for DishonestFilterExec { let res = Table::<'a, S>::try_from_iter_with_options( self.aliased_results .iter() - .map(|expr| expr.alias) + .map(|expr| expr.alias.clone()) .zip(filtered_columns), TableOptions::new(Some(output_length)), ) @@ -132,7 +132,7 @@ impl ProverEvaluate for DishonestFilterExec { let res = Table::<'a, S>::try_from_iter_with_options( self.aliased_results .iter() - .map(|expr| expr.alias) + .map(|expr| expr.alias.clone()) .zip(filtered_columns), TableOptions::new(Some(output_length)), ) diff --git a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs index 4a9270985..3ca31d42b 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs @@ -25,8 +25,8 @@ use alloc::{boxed::Box, vec, vec::Vec}; use bumpalo::Bump; use core::iter; use num_traits::{One, Zero}; -use proof_of_sql_parser::Identifier; use serde::{Deserialize, Serialize}; +use sqlparser::ast::Ident; /// Provable expressions for queries of the form /// ```ignore @@ -43,7 +43,7 @@ use serde::{Deserialize, Serialize}; pub struct GroupByExec { pub(super) group_by_exprs: Vec, pub(super) sum_expr: Vec, - pub(super) count_alias: Identifier, + pub(super) count_alias: Ident, pub(super) table: TableExpr, pub(super) where_clause: DynProofExpr, } @@ -53,7 +53,7 @@ impl GroupByExec { pub fn new( group_by_exprs: Vec, sum_expr: Vec, - count_alias: Identifier, + count_alias: Ident, table: TableExpr, where_clause: DynProofExpr, ) -> Self { @@ -160,10 +160,10 @@ impl ProofPlan for GroupByExec { .iter() .map(|col| col.get_column_field()) .chain(self.sum_expr.iter().map(|aliased_expr| { - ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type()) + ColumnField::new(aliased_expr.alias.clone(), aliased_expr.expr.data_type()) })) .chain(iter::once(ColumnField::new( - self.count_alias, + self.count_alias.clone(), ColumnType::BigInt, ))) .collect() diff --git a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs index df10790d7..80e8e85b4 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs @@ -67,7 +67,9 @@ impl ProofPlan for ProjectionExec { fn get_column_result_fields(&self) -> Vec { self.aliased_results .iter() - .map(|aliased_expr| ColumnField::new(aliased_expr.alias, aliased_expr.expr.data_type())) + .map(|aliased_expr| { + ColumnField::new(aliased_expr.alias.clone(), aliased_expr.expr.data_type()) + }) .collect() } @@ -104,7 +106,7 @@ impl ProverEvaluate for ProjectionExec { let res = Table::<'a, S>::try_from_iter_with_options( self.aliased_results.iter().map(|aliased_expr| { ( - aliased_expr.alias, + aliased_expr.alias.clone(), aliased_expr.expr.result_evaluate(alloc, table), ) }), @@ -138,7 +140,7 @@ impl ProverEvaluate for ProjectionExec { let res = Table::<'a, S>::try_from_iter_with_options( self.aliased_results.iter().map(|aliased_expr| { ( - aliased_expr.alias, + aliased_expr.alias.clone(), aliased_expr.expr.prover_evaluate(builder, alloc, table), ) }), diff --git a/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs index c3396a491..519bcb0b1 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/projection_exec_test.rs @@ -19,13 +19,14 @@ use crate::{ }; use blitzar::proof::InnerProductProof; use bumpalo::Bump; -use proof_of_sql_parser::{Identifier, ResourceId}; +use proof_of_sql_parser::ResourceId; +use sqlparser::ast::Ident; #[test] fn we_can_correctly_fetch_the_query_result_schema() { let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap()); - let a = Identifier::try_new("a").unwrap(); - let b = Identifier::try_new("b").unwrap(); + let a = Ident::new("a"); + let b = Ident::new("b"); let provable_ast = ProjectionExec::new( vec![ aliased_plan( @@ -51,8 +52,8 @@ fn we_can_correctly_fetch_the_query_result_schema() { assert_eq!( column_fields, vec![ - ColumnField::new("a".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), + ColumnField::new("a".into(), ColumnType::BigInt), + ColumnField::new("b".into(), ColumnType::BigInt), ] ); } @@ -60,8 +61,8 @@ fn we_can_correctly_fetch_the_query_result_schema() { #[test] fn we_can_correctly_fetch_all_the_referenced_columns() { let table_ref = TableRef::new(ResourceId::try_new("sxt", "sxt_tab").unwrap()); - let a = Identifier::try_new("a").unwrap(); - let f = Identifier::try_new("f").unwrap(); + let a = Ident::new("a"); + let f = Ident::new("f"); let provable_ast = ProjectionExec::new( vec![ aliased_plan( @@ -89,16 +90,8 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { assert_eq!( ref_columns, IndexSet::from_iter([ - ColumnRef::new( - table_ref, - Identifier::try_new("a").unwrap(), - ColumnType::BigInt - ), - ColumnRef::new( - table_ref, - Identifier::try_new("f").unwrap(), - ColumnType::BigInt - ), + ColumnRef::new(table_ref, Ident::new("a"), ColumnType::BigInt), + ColumnRef::new(table_ref, Ident::new("f"), ColumnType::BigInt), ]) ); @@ -173,11 +166,11 @@ fn we_can_get_an_empty_result_from_a_basic_projection_on_an_empty_table_using_fi let expr: DynProofPlan = projection(cols_expr_plan(t, &["b", "c", "d", "e"], &accessor), tab(t)); let fields = &[ - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("c".parse().unwrap(), ColumnType::Int128), - ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("b".into(), ColumnType::BigInt), + ColumnField::new("c".into(), ColumnType::Int128), + ColumnField::new("d".into(), ColumnType::VarChar), ColumnField::new( - "e".parse().unwrap(), + "e".into(), ColumnType::Decimal75(Precision::new(75).unwrap(), 0), ), ]; @@ -259,11 +252,11 @@ fn we_can_get_the_correct_result_from_a_basic_projection_using_first_round_evalu tab(t), ); let fields = &[ - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("prod".parse().unwrap(), ColumnType::Int128), - ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("b".into(), ColumnType::BigInt), + ColumnField::new("prod".into(), ColumnType::Int128), + ColumnField::new("d".into(), ColumnType::VarChar), ColumnField::new( - "e".parse().unwrap(), + "e".into(), ColumnType::Decimal75(Precision::new(1).unwrap(), 0), ), ]; diff --git a/crates/proof-of-sql/src/sql/proof_plans/range_check_test_plan.rs b/crates/proof-of-sql/src/sql/proof_plans/range_check_test_plan.rs index f26be1fcc..446bd077e 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/range_check_test_plan.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/range_check_test_plan.rs @@ -62,7 +62,7 @@ impl ProofPlan for RangeCheckTestPlan { } fn get_column_references(&self) -> IndexSet { - indexset! {self.column} + indexset! {self.column.clone()} } #[doc = " Return all the tables referenced in the Query"] @@ -113,7 +113,7 @@ mod tests { let t = "sxt.t".parse().unwrap(); let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let ast = RangeCheckTestPlan { - column: ColumnRef::new(t, "a".parse().unwrap(), ColumnType::Scalar), + column: ColumnRef::new(t, "a".into(), ColumnType::Scalar), }; let verifiable_res = VerifiableQueryResult::::new(&ast, &accessor, &()); let _ = verifiable_res.verify(&ast, &accessor, &()); @@ -125,7 +125,7 @@ mod tests { let t = "sxt.t".parse().unwrap(); let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let ast = RangeCheckTestPlan { - column: ColumnRef::new(t, "a".parse().unwrap(), ColumnType::Scalar), + column: ColumnRef::new(t, "a".into(), ColumnType::Scalar), }; let verifiable_res = VerifiableQueryResult::::new(&ast, &accessor, &()); let res: Result< @@ -145,7 +145,7 @@ mod tests { let t = "sxt.t".parse().unwrap(); let accessor = OwnedTableTestAccessor::::new_from_table(t, data, 0, ()); let ast = RangeCheckTestPlan { - column: ColumnRef::new(t, "a".parse().unwrap(), ColumnType::Scalar), + column: ColumnRef::new(t, "a".into(), ColumnType::Scalar), }; let verifiable_res = VerifiableQueryResult::::new(&ast, &accessor, &()); let res: Result< diff --git a/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs index bf869b987..27248127d 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/slice_exec_test.rs @@ -94,11 +94,11 @@ fn we_can_get_an_empty_result_from_a_slice_on_an_empty_table_using_first_round_e ); let fields = &[ - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("c".parse().unwrap(), ColumnType::Int128), - ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("b".into(), ColumnType::BigInt), + ColumnField::new("c".into(), ColumnType::Int128), + ColumnField::new("d".into(), ColumnType::VarChar), ColumnField::new( - "e".parse().unwrap(), + "e".into(), ColumnType::Decimal75(Precision::new(75).unwrap(), 0), ), ]; @@ -148,11 +148,11 @@ fn we_can_get_an_empty_result_from_a_slice_using_first_round_evaluate() { ); let fields = &[ - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("c".parse().unwrap(), ColumnType::Int128), - ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("b".into(), ColumnType::BigInt), + ColumnField::new("c".into(), ColumnType::Int128), + ColumnField::new("d".into(), ColumnType::VarChar), ColumnField::new( - "e".parse().unwrap(), + "e".into(), ColumnType::Decimal75(Precision::new(1).unwrap(), 0), ), ]; @@ -236,11 +236,11 @@ fn we_can_get_the_correct_result_from_a_slice_using_first_round_evaluate() { None, ); let fields = &[ - ColumnField::new("b".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("c".parse().unwrap(), ColumnType::Int128), - ColumnField::new("d".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("b".into(), ColumnType::BigInt), + ColumnField::new("c".into(), ColumnType::Int128), + ColumnField::new("d".into(), ColumnType::VarChar), ColumnField::new( - "e".parse().unwrap(), + "e".into(), ColumnType::Decimal75(Precision::new(1).unwrap(), 0), ), ]; @@ -461,9 +461,9 @@ fn we_can_create_and_prove_a_slice_exec_on_top_of_a_table_exec() { table_exec( table_ref, vec![ - ColumnField::new("language_rank".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("language_name".parse().unwrap(), ColumnType::VarChar), - ColumnField::new("space_and_time".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("language_rank".into(), ColumnType::BigInt), + ColumnField::new("language_name".into(), ColumnType::VarChar), + ColumnField::new("space_and_time".into(), ColumnType::VarChar), ], ), 1, diff --git a/crates/proof-of-sql/src/sql/proof_plans/table_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/table_exec_test.rs index ad3eddab8..28d728c74 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/table_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/table_exec_test.rs @@ -15,7 +15,7 @@ fn we_can_create_and_prove_an_empty_table_exec() { let table_ref = TableRef::new("namespace.table_name".parse().unwrap()); let plan = table_exec( table_ref, - vec![ColumnField::new("a".parse().unwrap(), ColumnType::BigInt)], + vec![ColumnField::new("a".into(), ColumnType::BigInt)], ); let accessor = TableTestAccessor::::new_from_table( table_ref, @@ -37,9 +37,9 @@ fn we_can_create_and_prove_a_table_exec() { let plan = table_exec( table_ref, vec![ - ColumnField::new("language_rank".parse().unwrap(), ColumnType::BigInt), - ColumnField::new("language_name".parse().unwrap(), ColumnType::VarChar), - ColumnField::new("space_and_time".parse().unwrap(), ColumnType::VarChar), + ColumnField::new("language_rank".into(), ColumnType::BigInt), + ColumnField::new("language_name".into(), ColumnType::VarChar), + ColumnField::new("space_and_time".into(), ColumnType::VarChar), ], ); let accessor = TableTestAccessor::::new_from_table( diff --git a/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs b/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs index 28116fef4..407627d5a 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/test_utility.rs @@ -8,7 +8,7 @@ use crate::{ }; pub fn column_field(name: &str, column_type: ColumnType) -> ColumnField { - ColumnField::new(name.parse().unwrap(), column_type) + ColumnField::new(name.into(), column_type) } pub fn empty_exec() -> DynProofPlan { @@ -44,7 +44,7 @@ pub fn group_by( DynProofPlan::GroupBy(GroupByExec::new( group_by_exprs, sum_expr, - count_alias.parse().unwrap(), + count_alias.into(), table, where_clause, )) diff --git a/crates/proof-of-sql/tests/decimal_integration_tests.rs b/crates/proof-of-sql/tests/decimal_integration_tests.rs index 37c2847b3..3e0885221 100644 --- a/crates/proof-of-sql/tests/decimal_integration_tests.rs +++ b/crates/proof-of-sql/tests/decimal_integration_tests.rs @@ -29,12 +29,7 @@ fn run_query( accessor.add_table("sxt.table".parse().unwrap(), data, 0); - let query = QueryExpr::try_new( - query_str.parse().unwrap(), - "sxt".parse().unwrap(), - &accessor, - ) - .unwrap(); + let query = QueryExpr::try_new(query_str.parse().unwrap(), "sxt".into(), &accessor).unwrap(); let proof = VerifiableQueryResult::::new(query.proof_expr(), &accessor, &()); let owned_table_result = proof .verify(query.proof_expr(), &accessor, &()) diff --git a/crates/proof-of-sql/tests/integration_tests.rs b/crates/proof-of-sql/tests/integration_tests.rs index 16496e62e..d479ab118 100644 --- a/crates/proof-of-sql/tests/integration_tests.rs +++ b/crates/proof-of-sql/tests/integration_tests.rs @@ -30,7 +30,7 @@ fn we_can_prove_a_minimal_filter_query_with_curve25519() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE a;".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -61,7 +61,7 @@ fn we_can_prove_a_minimal_filter_query_with_dory() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE not a".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -93,7 +93,7 @@ fn we_can_prove_a_minimal_filter_query_with_dynamic_dory() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE not a".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -121,7 +121,7 @@ fn we_can_prove_a_basic_equality_query_with_curve25519() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE b = 1;".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -152,7 +152,7 @@ fn we_can_prove_a_basic_equality_query_with_dory() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE b = 1".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -180,7 +180,7 @@ fn we_can_prove_a_basic_inequality_query_with_curve25519() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE b >= 1;".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -211,7 +211,7 @@ fn we_can_prove_a_basic_query_containing_extrema_with_curve25519() { ); let query = QueryExpr::try_new( "SELECT * FROM table".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -254,7 +254,7 @@ fn we_can_prove_a_basic_query_containing_extrema_with_dory() { ); let query = QueryExpr::try_new( "SELECT * FROM table;".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -288,7 +288,7 @@ fn we_can_prove_a_query_with_arithmetic_in_where_clause_with_curve25519() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE b >= a + 1".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -320,7 +320,7 @@ fn we_can_prove_a_query_with_arithmetic_in_where_clause_with_dory() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE b > 1 - a;".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -353,7 +353,7 @@ fn we_can_prove_a_basic_equality_with_out_of_order_results_with_curve25519() { "select primes, amount from public.test_table where primes = 'abcd';" .parse() .unwrap(), - "public".parse().unwrap(), + "public".into(), &accessor, ) .unwrap(); @@ -386,7 +386,7 @@ fn we_can_prove_a_basic_inequality_query_with_dory() { ); let query = QueryExpr::try_new( "SELECT * FROM table WHERE b <= 0".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -415,11 +415,7 @@ fn decimal_type_issues_should_cause_provable_ast_to_fail() { let large_decimal = format!("0.{}", "1".repeat(75)); let query_string = format!("SELECT d0 + {large_decimal} as res FROM table;"); assert!(matches!( - QueryExpr::try_new( - query_string.parse().unwrap(), - "sxt".parse().unwrap(), - &accessor, - ), + QueryExpr::try_new(query_string.parse().unwrap(), "sxt".into(), &accessor,), Err(ConversionError::DataTypeMismatch { .. }) )); } @@ -446,7 +442,7 @@ fn we_can_prove_a_complex_query_with_curve25519() { "SELECT a + (b * c) + 1 as t, 45.7 as g, (a = b) or f as h, d0 * d1 + 1.4 as dr FROM table WHERE (a >= b) = (c < d) and (e = 'e') = f;" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -493,7 +489,7 @@ fn we_can_prove_a_complex_query_with_dory() { "SELECT 0.5 + a * b * c - d as res, 32 as g, (c >= d) and f as h, (a + 1) * (b + 1 + c + d + d0 - d1 + 0.5) as res2 FROM table WHERE (a < b) = (c <= d) and e <> 'f' and f and 100000 * d1 * d0 + a = 1.3" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -529,7 +525,7 @@ fn we_can_prove_a_minimal_group_by_query_with_curve25519() { "SELECT a, count(*) as c FROM table group by a" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -563,7 +559,7 @@ fn we_can_prove_a_basic_group_by_query_with_curve25519() { "SELECT a, sum(2 * b + 1) as d, count(*) as e FROM table WHERE c >= 0 group by a" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -626,7 +622,7 @@ fn we_can_prove_a_cat_group_by_query_with_curve25519() { "select human, sum(age + 0.1) as total_adjusted_cat_age, count(*) as num_cats from sxt.cats where is_female group by human order by human" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -698,7 +694,7 @@ fn we_can_prove_a_cat_group_by_query_with_dynamic_dory() { "select diff_from_ideal_weight, count(*) as num_cats from sxt.cats where is_female group by diff_from_ideal_weight order by diff_from_ideal_weight" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -741,7 +737,7 @@ fn we_can_prove_a_basic_group_by_query_with_dory() { "SELECT a, sum(2 * b + 1) as d, count(*) as e FROM table WHERE c >= 0 group by a" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -774,7 +770,7 @@ fn we_can_prove_a_query_with_overflow_with_curve25519() { ); let query = QueryExpr::try_new( "SELECT a + b as c from table".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -803,7 +799,7 @@ fn we_can_prove_a_query_with_overflow_with_dory() { ); let query = QueryExpr::try_new( "SELECT a - b as c from table".parse().unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -835,7 +831,7 @@ fn we_can_perform_arithmetic_and_conditional_operations_on_tinyint() { "SELECT a*b+b+c as result FROM table WHERE a>b OR c=4" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); diff --git a/crates/proof-of-sql/tests/timestamp_integration_tests.rs b/crates/proof-of-sql/tests/timestamp_integration_tests.rs index 6b0a71ebe..d45f14da7 100644 --- a/crates/proof-of-sql/tests/timestamp_integration_tests.rs +++ b/crates/proof-of-sql/tests/timestamp_integration_tests.rs @@ -44,7 +44,7 @@ fn we_can_prove_a_basic_query_containing_rfc3339_timestamp_with_dory() { "SELECT times FROM table WHERE times = timestamp '1970-01-01T00:00:00Z';" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap(); @@ -88,12 +88,7 @@ fn run_timestamp_query_test( ); // Parse and execute the query - let query = QueryExpr::try_new( - query_str.parse().unwrap(), - "sxt".parse().unwrap(), - &accessor, - ) - .unwrap(); + let query = QueryExpr::try_new(query_str.parse().unwrap(), "sxt".into(), &accessor).unwrap(); let proof = VerifiableQueryResult::::new(query.proof_expr(), &accessor, &()); @@ -435,7 +430,7 @@ fn we_can_prove_timestamp_inequality_queries_with_multiple_columns() { "select *, a <= b as res from TABLE where a <= b" .parse() .unwrap(), - "sxt".parse().unwrap(), + "sxt".into(), &accessor, ) .unwrap();