From 144a86ff9be2acf02c45ad25d92288d0a22eb4b0 Mon Sep 17 00:00:00 2001 From: Vamshi Maskuri <117595548+varshith257@users.noreply.github.com> Date: Tue, 12 Nov 2024 12:48:57 +0530 Subject: [PATCH] refactor!: PoSQLBinaryOP to use sqlparser::ast::BinaryOP --- .../base/database/expression_evaluation.rs | 26 ++++++--- .../database/expression_evaluation_error.rs | 6 +++ .../src/sql/parse/dyn_proof_expr_builder.rs | 26 ++++++--- crates/proof-of-sql/src/sql/parse/error.rs | 7 +++ .../src/sql/parse/query_context_builder.rs | 54 +++++++++++-------- .../src/sql/proof_exprs/comparison_util.rs | 14 ++--- .../src/sql/proof_exprs/dyn_proof_expr.rs | 17 +++--- 7 files changed, 96 insertions(+), 54 deletions(-) diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation.rs b/crates/proof-of-sql/src/base/database/expression_evaluation.rs index f4e81424c..8d0e4ef1f 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation.rs @@ -9,9 +9,10 @@ use crate::base::{ }; use alloc::{format, string::ToString, vec}; use proof_of_sql_parser::{ - intermediate_ast::{BinaryOperator, Expression, Literal, UnaryOperator}, + intermediate_ast::{Expression, Literal, UnaryOperator}, Identifier, }; +use sqlparser::ast::BinaryOperator; impl OwnedTable { /// Evaluate an expression on the table. @@ -19,7 +20,9 @@ impl OwnedTable { match expr { Expression::Column(identifier) => self.evaluate_column(identifier), Expression::Literal(lit) => self.evaluate_literal(lit), - Expression::Binary { op, left, right } => self.evaluate_binary_expr(*op, left, right), + Expression::Binary { op, left, right } => { + self.evaluate_binary_expr((*op).into(), left, right) + } Expression::Unary { op, expr } => self.evaluate_unary_expr(*op, expr), _ => Err(ExpressionEvaluationError::Unsupported { expression: format!("Expression {expr:?} is not supported yet"), @@ -77,6 +80,7 @@ impl OwnedTable { } } + #[allow(clippy::needless_pass_by_value)] fn evaluate_binary_expr( &self, op: BinaryOperator, @@ -88,13 +92,19 @@ impl OwnedTable { match op { BinaryOperator::And => Ok(left.element_wise_and(&right)?), BinaryOperator::Or => Ok(left.element_wise_or(&right)?), - BinaryOperator::Equal => Ok(left.element_wise_eq(&right)?), - BinaryOperator::GreaterThanOrEqual => Ok(left.element_wise_ge(&right)?), - BinaryOperator::LessThanOrEqual => Ok(left.element_wise_le(&right)?), - BinaryOperator::Add => Ok((left + right)?), - BinaryOperator::Subtract => Ok((left - right)?), + BinaryOperator::Eq => Ok(left.element_wise_eq(&right)?), + BinaryOperator::GtEq => Ok(left.element_wise_ge(&right)?), + BinaryOperator::LtEq => Ok(left.element_wise_le(&right)?), + BinaryOperator::Plus => Ok((left + right)?), + BinaryOperator::Minus => Ok((left - right)?), BinaryOperator::Multiply => Ok((left * right)?), - BinaryOperator::Division => Ok((left / right)?), + BinaryOperator::Divide => Ok((left / right)?), + _ => { + // Handle unsupported binary operations + Err(ExpressionEvaluationError::UnsupportedBinaryOperator { + operator: format!("{op:?}"), + }) + } } } } diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation_error.rs b/crates/proof-of-sql/src/base/database/expression_evaluation_error.rs index c49d953be..02f1313b6 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation_error.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation_error.rs @@ -30,6 +30,12 @@ pub enum ExpressionEvaluationError { /// The underlying source error source: DecimalError, }, + #[snafu(display("Unsupported binary operator: {operator}"))] + /// Unsupported binary operation + UnsupportedBinaryOperator { + /// The unsupported binary operator + operator: String, + }, } /// Result type for expression evaluation diff --git a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs index 50550f8c2..160b65069 100644 --- a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs @@ -19,10 +19,11 @@ use crate::{ }; use alloc::{borrow::ToOwned, boxed::Box, format, string::ToString}; use proof_of_sql_parser::{ - intermediate_ast::{AggregationOperator, BinaryOperator, Expression, Literal, UnaryOperator}, + intermediate_ast::{AggregationOperator, Expression, Literal, UnaryOperator}, posql_time::{PoSQLTimeUnit, PoSQLTimestampError}, Identifier, }; +use sqlparser::ast::BinaryOperator; /// Builder that enables building a `proofs::sql::proof_exprs::DynProofExpr` from /// a `proof_of_sql_parser::intermediate_ast::Expression`. @@ -59,7 +60,9 @@ impl DynProofExprBuilder<'_> { match expr { Expression::Column(identifier) => self.visit_column(*identifier), Expression::Literal(lit) => self.visit_literal(lit), - Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right), + Expression::Binary { op, left, right } => { + self.visit_binary_expr((*op).into(), left, right) + } Expression::Unary { op, expr } => self.visit_unary_expr(*op, expr), Expression::Aggregation { op, expr } => self.visit_aggregate_expr(*op, expr), _ => Err(ConversionError::Unprovable { @@ -139,6 +142,7 @@ impl DynProofExprBuilder<'_> { } } + #[allow(clippy::needless_pass_by_value)] fn visit_binary_expr( &self, op: BinaryOperator, @@ -156,27 +160,27 @@ impl DynProofExprBuilder<'_> { let right = self.visit_expr(right); DynProofExpr::try_new_or(left?, right?) } - BinaryOperator::Equal => { + BinaryOperator::Eq => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_equals(left?, right?) } - BinaryOperator::GreaterThanOrEqual => { + BinaryOperator::GtEq => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_inequality(left?, right?, false) } - BinaryOperator::LessThanOrEqual => { + BinaryOperator::LtEq => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_inequality(left?, right?, true) } - BinaryOperator::Add => { + BinaryOperator::Plus => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_add(left?, right?) } - BinaryOperator::Subtract => { + BinaryOperator::Minus => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_subtract(left?, right?) @@ -186,9 +190,15 @@ impl DynProofExprBuilder<'_> { let right = self.visit_expr(right); DynProofExpr::try_new_multiply(left?, right?) } - BinaryOperator::Division => Err(ConversionError::Unprovable { + BinaryOperator::Divide => Err(ConversionError::Unprovable { error: format!("Binary operator {op:?} is not supported at this location"), }), + _ => { + // Handle unsupported binary operations + Err(ConversionError::UnsupportedBinaryOperator { + operator: format!("{op:?}"), + }) + } } } diff --git a/crates/proof-of-sql/src/sql/parse/error.rs b/crates/proof-of-sql/src/sql/parse/error.rs index c608093dc..33ef668a8 100644 --- a/crates/proof-of-sql/src/sql/parse/error.rs +++ b/crates/proof-of-sql/src/sql/parse/error.rs @@ -139,6 +139,13 @@ pub enum ConversionError { /// The underlying error error: String, }, + + #[snafu(display("Unsupported binary operator: {operator}"))] + /// Unsupported binary operation + UnsupportedBinaryOperator { + /// The binary operator that is unsupported + operator: String, + }, } impl From for ConversionError { diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index dda1f6101..bbe7cb4f6 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -9,15 +9,15 @@ use crate::base::{ BigDecimalExt, }, }; -use alloc::{boxed::Box, string::ToString, vec::Vec}; +use alloc::{boxed::Box, format, string::ToString, vec::Vec}; use proof_of_sql_parser::{ intermediate_ast::{ - AggregationOperator, AliasedResultExpr, BinaryOperator, Expression, Literal, OrderBy, - SelectResultExpr, Slice, TableExpression, UnaryOperator, + AggregationOperator, AliasedResultExpr, Expression, Literal, OrderBy, SelectResultExpr, + Slice, TableExpression, UnaryOperator, }, Identifier, ResourceId, }; - +use sqlparser::ast::BinaryOperator; pub struct QueryContextBuilder<'a> { context: QueryContext, schema_accessor: &'a dyn SchemaAccessor, @@ -138,7 +138,9 @@ impl<'a> QueryContextBuilder<'a> { Expression::Literal(literal) => self.visit_literal(literal), Expression::Column(_) => self.visit_column_expr(expr), Expression::Unary { op, expr } => self.visit_unary_expr(*op, expr), - Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right), + Expression::Binary { op, left, right } => { + self.visit_binary_expr((*op).into(), left, right) + } Expression::Aggregation { op, expr } => self.visit_agg_expr(*op, expr), } } @@ -154,6 +156,7 @@ impl<'a> QueryContextBuilder<'a> { self.visit_column_identifier(identifier) } + #[allow(clippy::needless_pass_by_value)] fn visit_binary_expr( &mut self, op: BinaryOperator, @@ -162,17 +165,23 @@ impl<'a> QueryContextBuilder<'a> { ) -> ConversionResult { let left_dtype = self.visit_expr(left)?; let right_dtype = self.visit_expr(right)?; - check_dtypes(left_dtype, right_dtype, op)?; + check_dtypes(left_dtype, right_dtype, op.clone())?; match op { BinaryOperator::And | BinaryOperator::Or - | BinaryOperator::Equal - | BinaryOperator::GreaterThanOrEqual - | BinaryOperator::LessThanOrEqual => Ok(ColumnType::Boolean), + | BinaryOperator::Eq + | BinaryOperator::GtEq + | BinaryOperator::LtEq => Ok(ColumnType::Boolean), BinaryOperator::Multiply - | BinaryOperator::Division - | BinaryOperator::Subtract - | BinaryOperator::Add => Ok(left_dtype), + | BinaryOperator::Divide + | BinaryOperator::Minus + | BinaryOperator::Plus => Ok(left_dtype), + _ => { + // Handle unsupported binary operations + Err(ConversionError::UnsupportedBinaryOperator { + operator: format!("{op:?}"), + }) + } } } @@ -264,7 +273,7 @@ impl<'a> QueryContextBuilder<'a> { pub(crate) fn type_check_binary_operation( left_dtype: &ColumnType, right_dtype: &ColumnType, - binary_operator: BinaryOperator, + binary_operator: &BinaryOperator, ) -> bool { match binary_operator { BinaryOperator::And | BinaryOperator::Or => { @@ -273,7 +282,7 @@ pub(crate) fn type_check_binary_operation( (ColumnType::Boolean, ColumnType::Boolean) ) } - BinaryOperator::Equal => { + BinaryOperator::Eq => { matches!( (left_dtype, right_dtype), (ColumnType::VarChar, ColumnType::VarChar) @@ -283,7 +292,7 @@ pub(crate) fn type_check_binary_operation( | (ColumnType::Scalar, _) ) || (left_dtype.is_numeric() && right_dtype.is_numeric()) } - BinaryOperator::GreaterThanOrEqual | BinaryOperator::LessThanOrEqual => { + BinaryOperator::GtEq | BinaryOperator::LtEq => { if left_dtype == &ColumnType::VarChar || right_dtype == &ColumnType::VarChar { return false; } @@ -305,21 +314,24 @@ pub(crate) fn type_check_binary_operation( | (ColumnType::TimestampTZ(_, _), ColumnType::TimestampTZ(_, _)) ) } - BinaryOperator::Add => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(), - BinaryOperator::Subtract => { - try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok() - } + BinaryOperator::Plus => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(), + BinaryOperator::Minus => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(), BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(), - BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(), + BinaryOperator::Divide => left_dtype.is_numeric() && right_dtype.is_numeric(), + _ => { + // Handle unsupported binary operations + false + } } } +#[allow(clippy::needless_pass_by_value)] fn check_dtypes( left_dtype: ColumnType, right_dtype: ColumnType, binary_operator: BinaryOperator, ) -> ConversionResult<()> { - if type_check_binary_operation(&left_dtype, &right_dtype, binary_operator) { + if type_check_binary_operation(&left_dtype, &right_dtype, &binary_operator) { Ok(()) } else { Err(ConversionError::DataTypeMismatch { diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index bbb9ab882..f8b7dcf22 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -10,9 +10,9 @@ use crate::{ use alloc::string::ToString; use bumpalo::Bump; use core::cmp::{max, Ordering}; -use proof_of_sql_parser::intermediate_ast::BinaryOperator; #[cfg(feature = "rayon")] use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +use sqlparser::ast::BinaryOperator; #[allow(clippy::unnecessary_wraps)] fn unchecked_subtract_impl<'a, S: Scalar>( @@ -48,11 +48,11 @@ pub fn scale_and_subtract_literal( let lhs_type = lhs.column_type(); let rhs_type = rhs.column_type(); let operator = if is_equal { - BinaryOperator::Equal + BinaryOperator::Eq } else { - BinaryOperator::LessThanOrEqual + BinaryOperator::LtEq }; - if !type_check_binary_operation(&lhs_type, &rhs_type, operator) { + if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) { return Err(ConversionError::DataTypeMismatch { left_type: lhs_type.to_string(), right_type: rhs_type.to_string(), @@ -121,11 +121,11 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( let lhs_type = lhs.column_type(); let rhs_type = rhs.column_type(); let operator = if is_equal { - BinaryOperator::Equal + BinaryOperator::Eq } else { - BinaryOperator::LessThanOrEqual + BinaryOperator::LtEq }; - if !type_check_binary_operation(&lhs_type, &rhs_type, operator) { + if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) { return Err(ConversionError::DataTypeMismatch { left_type: lhs_type.to_string(), right_type: rhs_type.to_string(), diff --git a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs index 06c46578c..b36d5cf1f 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs @@ -18,8 +18,9 @@ use crate::{ use alloc::{boxed::Box, string::ToString}; use bumpalo::Bump; use core::fmt::Debug; -use proof_of_sql_parser::intermediate_ast::{AggregationOperator, BinaryOperator}; +use proof_of_sql_parser::intermediate_ast::AggregationOperator; use serde::{Deserialize, Serialize}; +use sqlparser::ast::BinaryOperator; /// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -76,7 +77,7 @@ impl DynProofExpr { pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Equal) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Eq) { Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs)))) } else { Err(ConversionError::DataTypeMismatch { @@ -93,11 +94,7 @@ impl DynProofExpr { ) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation( - &lhs_datatype, - &rhs_datatype, - BinaryOperator::LessThanOrEqual, - ) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::LtEq) { Ok(Self::Inequality(InequalityExpr::new( Box::new(lhs), Box::new(rhs), @@ -115,7 +112,7 @@ impl DynProofExpr { pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Add) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Plus) { Ok(Self::AddSubtract(AddSubtractExpr::new( Box::new(lhs), Box::new(rhs), @@ -133,7 +130,7 @@ impl DynProofExpr { pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Subtract) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Minus) { Ok(Self::AddSubtract(AddSubtractExpr::new( Box::new(lhs), Box::new(rhs), @@ -151,7 +148,7 @@ impl DynProofExpr { pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Multiply) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Multiply) { Ok(Self::Multiply(MultiplyExpr::new( Box::new(lhs), Box::new(rhs),