Skip to content

Commit

Permalink
refactor: remove base::sqlparser::ident since into() is simpler
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Dec 16, 2024
1 parent be3853a commit 4d1b9cf
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 51 deletions.
1 change: 0 additions & 1 deletion crates/proof-of-sql/src/base/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ mod serialize;
pub(crate) use serialize::{impl_serde_for_ark_serde_checked, impl_serde_for_ark_serde_unchecked};
pub(crate) mod map;
pub(crate) mod slice_ops;
pub(crate) mod sqlparser;

mod rayon_cfg;
pub(crate) use rayon_cfg::if_rayon;
5 changes: 0 additions & 5 deletions crates/proof-of-sql/src/base/sqlparser.rs

This file was deleted.

41 changes: 20 additions & 21 deletions crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use crate::{
database::{ColumnRef, ColumnType, LiteralValue, TestSchemaAccessor},
map::{indexmap, IndexMap},
math::decimal::Precision,
sqlparser::ident,
},
sql::{
parse::{ConversionError, QueryExpr, WhereExprBuilder},
Expand Down Expand Up @@ -33,59 +32,59 @@ fn get_column_mappings_for_testing() -> IndexMap<Ident, ColumnRef> {
let mut column_mapping = IndexMap::default();
// Setup column mapping
column_mapping.insert(
ident("boolean_column"),
ColumnRef::new(tab_ref, ident("boolean_column"), ColumnType::Boolean),
"boolean_column".into(),
ColumnRef::new(tab_ref, "boolean_column".into(), ColumnType::Boolean),
);
column_mapping.insert(
ident("decimal_column"),
"decimal_column".into(),
ColumnRef::new(
tab_ref,
ident("decimal_column"),
"decimal_column".into(),
ColumnType::Decimal75(Precision::new(7).unwrap(), 2),
),
);
column_mapping.insert(
ident("int128_column"),
ColumnRef::new(tab_ref, ident("int128_column"), ColumnType::Int128),
"int128_column".into(),
ColumnRef::new(tab_ref, "int128_column".into(), ColumnType::Int128),
);
column_mapping.insert(
ident("bigint_column"),
ColumnRef::new(tab_ref, ident("bigint_column"), ColumnType::BigInt),
"bigint_column".into(),
ColumnRef::new(tab_ref, "bigint_column".into(), ColumnType::BigInt),
);

column_mapping.insert(
ident("varchar_column"),
ColumnRef::new(tab_ref, ident("varchar_column"), ColumnType::VarChar),
"varchar_column".into(),
ColumnRef::new(tab_ref, "varchar_column".into(), ColumnType::VarChar),
);
column_mapping.insert(
ident("timestamp_second_column"),
"timestamp_second_column".into(),
ColumnRef::new(
tab_ref,
ident("timestamp_second_column"),
"timestamp_second_column".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()),
),
);
column_mapping.insert(
ident("timestamp_millisecond_column"),
"timestamp_millisecond_column".into(),
ColumnRef::new(
tab_ref,
ident("timestamp_millisecond_column"),
"timestamp_millisecond_column".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Millisecond, PoSQLTimeZone::utc()),
),
);
column_mapping.insert(
ident("timestamp_microsecond_column"),
"timestamp_microsecond_column".into(),
ColumnRef::new(
tab_ref,
ident("timestamp_microsecond_column"),
"timestamp_microsecond_column".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Microsecond, PoSQLTimeZone::utc()),
),
);
column_mapping.insert(
ident("timestamp_nanosecond_column"),
"timestamp_nanosecond_column".into(),
ColumnRef::new(
tab_ref,
ident("timestamp_nanosecond_column"),
"timestamp_nanosecond_column".into(),
ColumnType::TimestampTZ(PoSQLTimeUnit::Nanosecond, PoSQLTimeZone::utc()),
),
);
Expand Down Expand Up @@ -147,7 +146,7 @@ fn we_can_directly_check_whether_bigint_columns_ge_int128() {
let expected = DynProofExpr::try_new_inequality(
DynProofExpr::Column(ColumnExpr::new(ColumnRef::new(
"sxt.sxt_tab".parse().unwrap(),
ident("bigint_column"),
"bigint_column".into(),
ColumnType::BigInt,
))),
DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))),
Expand All @@ -169,7 +168,7 @@ fn we_can_directly_check_whether_bigint_columns_le_int128() {
let expected = DynProofExpr::try_new_inequality(
DynProofExpr::Column(ColumnExpr::new(ColumnRef::new(
"sxt.sxt_tab".parse().unwrap(),
ident("bigint_column"),
"bigint_column".into(),
ColumnType::BigInt,
))),
DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,6 @@ impl<S: Scalar> PostprocessingStep<S> for GroupByPostprocessing {
#[cfg(test)]
mod tests {
use super::*;
use crate::base::sqlparser::ident;
use proof_of_sql_parser::utility::*;

#[test]
Expand Down Expand Up @@ -400,13 +399,13 @@ mod tests {

// a + b + 1
let expr = add(add(col("a"), col("b")), lit(1));
let expected: IndexSet<Ident> = [ident("a"), ident("b")].into_iter().collect();
let expected: IndexSet<Ident> = ["a".into(), "b".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);

// ! (a == b || c >= a)
let expr = not(or(equal(col("a"), col("b")), ge(col("c"), col("a"))));
let expected: IndexSet<Ident> = [ident("a"), ident("b"), ident("c")].into_iter().collect();
let expected: IndexSet<Ident> = ["a".into(), "b".into(), "c".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);

Expand All @@ -418,7 +417,7 @@ mod tests {

// (COUNT(a + b) + c) * d
let expr = mul(add(count(add(col("a"), col("b"))), col("c")), col("d"));
let expected: IndexSet<Ident> = [ident("c"), ident("d")].into_iter().collect();
let expected: IndexSet<Ident> = ["c".into(), "d".into()].into_iter().collect();
let actual = get_free_identifiers_from_expr(&expr);
assert_eq!(actual, expected);
}
Expand All @@ -433,7 +432,7 @@ mod tests {
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
ident("__col_agg_0")
"__col_agg_0".into()
);
assert_eq!(remainder_expr, Ok(*add(col("__col_agg_0"), col("b"))));
assert_eq!(aggregation_expr_map.len(), 1);
Expand All @@ -444,11 +443,11 @@ mod tests {
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))],
ident("__col_agg_0")
"__col_agg_0".into()
);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Sum, *col("b"))],
ident("__col_agg_1")
"__col_agg_1".into()
);
assert_eq!(
remainder_expr,
Expand All @@ -468,14 +467,14 @@ mod tests {
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Max, *add(col("a"), lit(1)))],
ident("__col_agg_2")
"__col_agg_2".into()
);
assert_eq!(
aggregation_expr_map[&(
AggregationOperator::Min,
*sub(mul(lit(2), col("b")), lit(4))
)],
ident("__col_agg_3")
"__col_agg_3".into()
);
assert_eq!(
remainder_expr,
Expand All @@ -492,7 +491,7 @@ mod tests {
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map);
assert_eq!(
aggregation_expr_map[&(AggregationOperator::Count, *mul(lit(2), col("a")))],
ident("__col_agg_4")
"__col_agg_4".into()
);
assert_eq!(
remainder_expr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::{
base::{
database::{owned_table_utility::*, OwnedTable},
scalar::Curve25519Scalar,
sqlparser::ident,
},
sql::postprocessing::{
apply_postprocessing_steps, group_by_postprocessing::*, test_utility::*,
Expand All @@ -15,15 +14,15 @@ use proof_of_sql_parser::{intermediate_ast::AggregationOperator, utility::*};
fn we_cannot_have_invalid_group_bys() {
// Column in result but not in group by or aggregation
let expr = add(sum(col("a")), col("b")); // b is not in group by or aggregation
let res = GroupByPostprocessing::try_new(vec![ident("a")], vec![aliased_expr(expr, "res")]);
let res = GroupByPostprocessing::try_new(vec!["a".into()], vec![aliased_expr(expr, "res")]);
assert!(matches!(
res,
Err(PostprocessingError::IdentifierNotInAggregationOperatorOrGroupByClause { .. })
));

// Nested aggregation
let expr = sum(max(col("a"))); // Nested aggregation
let res = GroupByPostprocessing::try_new(vec![ident("a")], vec![aliased_expr(expr, "res")]);
let res = GroupByPostprocessing::try_new(vec!["a".into()], vec![aliased_expr(expr, "res")]);
assert!(matches!(
res,
Err(PostprocessingError::NestedAggregationInGroupByClause { .. })
Expand All @@ -34,14 +33,14 @@ fn we_cannot_have_invalid_group_bys() {
fn we_can_make_group_by_postprocessing() {
// SELECT SUM(a) + 2 as c0, SUM(b + a) as c1 FROM tab GROUP BY a, b
let res = GroupByPostprocessing::try_new(
vec![ident("a"), ident("b")],
vec!["a".into(), "b".into()],
vec![
aliased_expr(add(sum(col("a")), lit(2)), "c0"),
aliased_expr(sum(add(col("b"), col("a"))), "c1"),
],
)
.unwrap();
assert_eq!(res.group_by(), &[ident("a"), ident("b")]);
assert_eq!(res.group_by(), &["a".into(), "b".into()]);
assert_eq!(
res.remainder_exprs(),
&[
Expand All @@ -52,11 +51,11 @@ fn we_can_make_group_by_postprocessing() {
assert_eq!(
res.aggregation_exprs(),
&[
(AggregationOperator::Sum, *col("a"), ident("__col_agg_0")),
(AggregationOperator::Sum, *col("a"), "__col_agg_0".into()),
(
AggregationOperator::Sum,
*add(col("b"), col("a")),
ident("__col_agg_1")
"__col_agg_1".into()
),
]
);
Expand Down
3 changes: 1 addition & 2 deletions crates/proof-of-sql/src/sql/postprocessing/test_utility.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::*;
use crate::base::sqlparser::ident;
use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection};
use sqlparser::ast::Ident;

Expand All @@ -8,7 +7,7 @@ pub fn group_by_postprocessing(
cols: &[&str],
result_exprs: &[AliasedResultExpr],
) -> OwnedTablePostprocessing {
let ids: Vec<Ident> = cols.iter().map(|col| ident(col)).collect();
let ids: Vec<Ident> = cols.iter().map(|col| (*col).into()).collect();
OwnedTablePostprocessing::new_group_by(
GroupByPostprocessing::try_new(ids, result_exprs.to_vec()).unwrap(),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ use crate::{
map::{indexset, IndexMap, IndexSet},
proof::ProofError,
scalar::Scalar,
sqlparser::ident,
},
sql::proof::{FirstRoundBuilder, QueryData},
};
use bumpalo::Bump;
use serde::Serialize;
use sqlparser::ast::Ident;

#[derive(Debug, Serialize, Default)]
pub(super) struct EmptyTestQueryExpr {
Expand All @@ -35,8 +35,9 @@ impl ProverEvaluate for EmptyTestQueryExpr {
let zeros = vec![0_i64; self.length];
builder.produce_one_evaluation_length(self.length);
table_with_row_count(
(1..=self.columns)
.map(|i| borrowed_bigint(ident(format!("a{i}").as_str()), zeros.clone(), alloc)),
(1..=self.columns).map(|i| {
borrowed_bigint(Ident::from(format!("a{i}").as_str()), zeros.clone(), alloc)
}),
self.length,
)
}
Expand All @@ -53,8 +54,9 @@ impl ProverEvaluate for EmptyTestQueryExpr {
.take(self.columns)
.collect::<Vec<_>>();
table_with_row_count(
(1..=self.columns)
.map(|i| borrowed_bigint(ident(format!("a{i}").as_str()), zeros.clone(), alloc)),
(1..=self.columns).map(|i| {
borrowed_bigint(Ident::from(format!("a{i}").as_str()), zeros.clone(), alloc)
}),
self.length,
)
}
Expand Down

0 comments on commit 4d1b9cf

Please sign in to comment.