From e337e654ef58951f06919abf97003a322cf9d2ca Mon Sep 17 00:00:00 2001 From: Vamshi Maskuri <117595548+varshith257@users.noreply.github.com> Date: Wed, 13 Nov 2024 22:10:09 +0530 Subject: [PATCH] refactor!: PoSQLBinaryOP to use sqlparser::ast::BinaryOP --- .../base/database/expression_evaluation.rs | 25 ++++++---- .../src/sql/parse/dyn_proof_expr_builder.rs | 28 +++++++---- .../src/sql/parse/query_context_builder.rs | 47 ++++++++++++------- .../src/sql/proof_exprs/comparison_util.rs | 14 +++--- .../src/sql/proof_exprs/dyn_proof_expr.rs | 17 +++---- 5 files changed, 76 insertions(+), 55 deletions(-) diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation.rs b/crates/proof-of-sql/src/base/database/expression_evaluation.rs index 5400fa805..c5d2315b0 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation.rs @@ -9,10 +9,10 @@ use crate::base::{ }; use alloc::{format, string::ToString, vec}; use proof_of_sql_parser::{ - intermediate_ast::{BinaryOperator, Expression, Literal}, + intermediate_ast::{Expression, Literal}, Identifier, }; -use sqlparser::ast::UnaryOperator; +use sqlparser::ast::{BinaryOperator, UnaryOperator}; impl OwnedTable { /// Evaluate an expression on the table. @@ -20,7 +20,9 @@ impl OwnedTable { match expr { Expression::Column(identifier) => self.evaluate_column(identifier), Expression::Literal(lit) => self.evaluate_literal(lit), - Expression::Binary { op, left, right } => self.evaluate_binary_expr(*op, left, right), + Expression::Binary { op, left, right } => { + self.evaluate_binary_expr(&(*op).into(), left, right) + } Expression::Unary { op, expr } => self.evaluate_unary_expr((*op).into(), expr), _ => Err(ExpressionEvaluationError::Unsupported { expression: format!("Expression {expr:?} is not supported yet"), @@ -84,7 +86,7 @@ impl OwnedTable { fn evaluate_binary_expr( &self, - op: BinaryOperator, + op: &BinaryOperator, left: &Expression, right: &Expression, ) -> ExpressionEvaluationResult> { @@ -93,13 +95,16 @@ impl OwnedTable { match op { BinaryOperator::And => Ok(left.element_wise_and(&right)?), BinaryOperator::Or => Ok(left.element_wise_or(&right)?), - BinaryOperator::Equal => Ok(left.element_wise_eq(&right)?), - BinaryOperator::GreaterThanOrEqual => Ok(left.element_wise_ge(&right)?), - BinaryOperator::LessThanOrEqual => Ok(left.element_wise_le(&right)?), - BinaryOperator::Add => Ok(left.element_wise_add(&right)?), - BinaryOperator::Subtract => Ok(left.element_wise_sub(&right)?), + BinaryOperator::Eq => Ok(left.element_wise_eq(&right)?), + BinaryOperator::GtEq => Ok(left.element_wise_ge(&right)?), + BinaryOperator::LtEq => Ok(left.element_wise_le(&right)?), + BinaryOperator::Plus => Ok(left.element_wise_add(&right)?), + BinaryOperator::Minus => Ok(left.element_wise_sub(&right)?), BinaryOperator::Multiply => Ok(left.element_wise_mul(&right)?), - BinaryOperator::Division => Ok(left.element_wise_div(&right)?), + BinaryOperator::Divide => Ok(left.element_wise_div(&right)?), + _ => Err(ExpressionEvaluationError::Unsupported { + expression: format!("Binary operator '{op}' is not supported."), + }), } } } diff --git a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs index a50a271d0..6bf6e2b16 100644 --- a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs @@ -19,11 +19,11 @@ use crate::{ }; use alloc::{borrow::ToOwned, boxed::Box, format, string::ToString}; use proof_of_sql_parser::{ - intermediate_ast::{AggregationOperator, BinaryOperator, Expression, Literal}, + intermediate_ast::{AggregationOperator, Expression, Literal}, posql_time::{PoSQLTimeUnit, PoSQLTimestampError}, Identifier, }; -use sqlparser::ast::UnaryOperator; +use sqlparser::ast::{BinaryOperator, UnaryOperator}; /// Builder that enables building a `proofs::sql::proof_exprs::DynProofExpr` from /// a `proof_of_sql_parser::intermediate_ast::Expression`. @@ -60,7 +60,9 @@ impl DynProofExprBuilder<'_> { match expr { Expression::Column(identifier) => self.visit_column(*identifier), Expression::Literal(lit) => self.visit_literal(lit), - Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right), + Expression::Binary { op, left, right } => { + self.visit_binary_expr(&(*op).into(), left, right) + } Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr), Expression::Aggregation { op, expr } => self.visit_aggregate_expr(*op, expr), _ => Err(ConversionError::Unprovable { @@ -146,7 +148,7 @@ impl DynProofExprBuilder<'_> { fn visit_binary_expr( &self, - op: BinaryOperator, + op: &BinaryOperator, left: &Expression, right: &Expression, ) -> Result { @@ -161,27 +163,27 @@ impl DynProofExprBuilder<'_> { let right = self.visit_expr(right); DynProofExpr::try_new_or(left?, right?) } - BinaryOperator::Equal => { + BinaryOperator::Eq => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_equals(left?, right?) } - BinaryOperator::GreaterThanOrEqual => { + BinaryOperator::GtEq => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_inequality(left?, right?, false) } - BinaryOperator::LessThanOrEqual => { + BinaryOperator::LtEq => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_inequality(left?, right?, true) } - BinaryOperator::Add => { + BinaryOperator::Plus => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_add(left?, right?) } - BinaryOperator::Subtract => { + BinaryOperator::Minus => { let left = self.visit_expr(left); let right = self.visit_expr(right); DynProofExpr::try_new_subtract(left?, right?) @@ -191,9 +193,15 @@ impl DynProofExprBuilder<'_> { let right = self.visit_expr(right); DynProofExpr::try_new_multiply(left?, right?) } - BinaryOperator::Division => Err(ConversionError::Unprovable { + BinaryOperator::Divide => Err(ConversionError::Unprovable { error: format!("Binary operator {op:?} is not supported at this location"), }), + _ => { + // Handle unsupported binary operations + Err(ConversionError::UnsupportedOperation { + message: format!("{op:?}"), + }) + } } } diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 8e9bd3309..c7f21ac9a 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -12,12 +12,12 @@ use crate::base::{ use alloc::{boxed::Box, format, string::ToString, vec::Vec}; use proof_of_sql_parser::{ intermediate_ast::{ - AggregationOperator, AliasedResultExpr, BinaryOperator, Expression, Literal, OrderBy, - SelectResultExpr, Slice, TableExpression, + AggregationOperator, AliasedResultExpr, Expression, Literal, OrderBy, SelectResultExpr, + Slice, TableExpression, }, Identifier, ResourceId, }; -use sqlparser::ast::UnaryOperator; +use sqlparser::ast::{BinaryOperator, UnaryOperator}; pub struct QueryContextBuilder<'a> { context: QueryContext, schema_accessor: &'a dyn SchemaAccessor, @@ -138,7 +138,9 @@ impl<'a> QueryContextBuilder<'a> { Expression::Literal(literal) => self.visit_literal(literal), Expression::Column(_) => self.visit_column_expr(expr), Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr), - Expression::Binary { op, left, right } => self.visit_binary_expr(*op, left, right), + Expression::Binary { op, left, right } => { + self.visit_binary_expr(&(*op).into(), left, right) + } Expression::Aggregation { op, expr } => self.visit_agg_expr(*op, expr), } } @@ -156,7 +158,7 @@ impl<'a> QueryContextBuilder<'a> { fn visit_binary_expr( &mut self, - op: BinaryOperator, + op: &BinaryOperator, left: &Expression, right: &Expression, ) -> ConversionResult { @@ -166,13 +168,19 @@ impl<'a> QueryContextBuilder<'a> { match op { BinaryOperator::And | BinaryOperator::Or - | BinaryOperator::Equal - | BinaryOperator::GreaterThanOrEqual - | BinaryOperator::LessThanOrEqual => Ok(ColumnType::Boolean), + | BinaryOperator::Eq + | BinaryOperator::GtEq + | BinaryOperator::LtEq => Ok(ColumnType::Boolean), BinaryOperator::Multiply - | BinaryOperator::Division - | BinaryOperator::Subtract - | BinaryOperator::Add => Ok(left_dtype), + | BinaryOperator::Divide + | BinaryOperator::Minus + | BinaryOperator::Plus => Ok(left_dtype), + _ => { + // Handle unsupported binary operations + Err(ConversionError::UnsupportedOperation { + message: format!("{op:?}"), + }) + } } } @@ -268,7 +276,7 @@ impl<'a> QueryContextBuilder<'a> { pub(crate) fn type_check_binary_operation( left_dtype: &ColumnType, right_dtype: &ColumnType, - binary_operator: BinaryOperator, + binary_operator: &BinaryOperator, ) -> bool { match binary_operator { BinaryOperator::And | BinaryOperator::Or => { @@ -277,7 +285,7 @@ pub(crate) fn type_check_binary_operation( (ColumnType::Boolean, ColumnType::Boolean) ) } - BinaryOperator::Equal => { + BinaryOperator::Eq => { matches!( (left_dtype, right_dtype), (ColumnType::VarChar, ColumnType::VarChar) @@ -287,7 +295,7 @@ pub(crate) fn type_check_binary_operation( | (ColumnType::Scalar, _) ) || (left_dtype.is_numeric() && right_dtype.is_numeric()) } - BinaryOperator::GreaterThanOrEqual | BinaryOperator::LessThanOrEqual => { + BinaryOperator::GtEq | BinaryOperator::LtEq => { if left_dtype == &ColumnType::VarChar || right_dtype == &ColumnType::VarChar { return false; } @@ -309,19 +317,22 @@ pub(crate) fn type_check_binary_operation( | (ColumnType::TimestampTZ(_, _), ColumnType::TimestampTZ(_, _)) ) } - BinaryOperator::Add => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(), - BinaryOperator::Subtract => { + BinaryOperator::Plus | BinaryOperator::Minus => { try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok() } BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(), - BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(), + BinaryOperator::Divide => left_dtype.is_numeric() && right_dtype.is_numeric(), + _ => { + // Handle unsupported binary operations + false + } } } fn check_dtypes( left_dtype: ColumnType, right_dtype: ColumnType, - binary_operator: BinaryOperator, + binary_operator: &BinaryOperator, ) -> ConversionResult<()> { if type_check_binary_operation(&left_dtype, &right_dtype, binary_operator) { Ok(()) diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index bbb9ab882..f8b7dcf22 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -10,9 +10,9 @@ use crate::{ use alloc::string::ToString; use bumpalo::Bump; use core::cmp::{max, Ordering}; -use proof_of_sql_parser::intermediate_ast::BinaryOperator; #[cfg(feature = "rayon")] use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +use sqlparser::ast::BinaryOperator; #[allow(clippy::unnecessary_wraps)] fn unchecked_subtract_impl<'a, S: Scalar>( @@ -48,11 +48,11 @@ pub fn scale_and_subtract_literal( let lhs_type = lhs.column_type(); let rhs_type = rhs.column_type(); let operator = if is_equal { - BinaryOperator::Equal + BinaryOperator::Eq } else { - BinaryOperator::LessThanOrEqual + BinaryOperator::LtEq }; - if !type_check_binary_operation(&lhs_type, &rhs_type, operator) { + if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) { return Err(ConversionError::DataTypeMismatch { left_type: lhs_type.to_string(), right_type: rhs_type.to_string(), @@ -121,11 +121,11 @@ pub(crate) fn scale_and_subtract<'a, S: Scalar>( let lhs_type = lhs.column_type(); let rhs_type = rhs.column_type(); let operator = if is_equal { - BinaryOperator::Equal + BinaryOperator::Eq } else { - BinaryOperator::LessThanOrEqual + BinaryOperator::LtEq }; - if !type_check_binary_operation(&lhs_type, &rhs_type, operator) { + if !type_check_binary_operation(&lhs_type, &rhs_type, &operator) { return Err(ConversionError::DataTypeMismatch { left_type: lhs_type.to_string(), right_type: rhs_type.to_string(), diff --git a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs index cbf373e47..ef85d8e29 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs @@ -17,8 +17,9 @@ use crate::{ use alloc::{boxed::Box, string::ToString}; use bumpalo::Bump; use core::fmt::Debug; -use proof_of_sql_parser::intermediate_ast::{AggregationOperator, BinaryOperator}; +use proof_of_sql_parser::intermediate_ast::AggregationOperator; use serde::{Deserialize, Serialize}; +use sqlparser::ast::BinaryOperator; /// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`. #[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -75,7 +76,7 @@ impl DynProofExpr { pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Equal) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Eq) { Ok(Self::Equals(EqualsExpr::new(Box::new(lhs), Box::new(rhs)))) } else { Err(ConversionError::DataTypeMismatch { @@ -92,11 +93,7 @@ impl DynProofExpr { ) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation( - &lhs_datatype, - &rhs_datatype, - BinaryOperator::LessThanOrEqual, - ) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::LtEq) { Ok(Self::Inequality(InequalityExpr::new( Box::new(lhs), Box::new(rhs), @@ -114,7 +111,7 @@ impl DynProofExpr { pub fn try_new_add(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Add) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Plus) { Ok(Self::AddSubtract(AddSubtractExpr::new( Box::new(lhs), Box::new(rhs), @@ -132,7 +129,7 @@ impl DynProofExpr { pub fn try_new_subtract(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Subtract) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Minus) { Ok(Self::AddSubtract(AddSubtractExpr::new( Box::new(lhs), Box::new(rhs), @@ -150,7 +147,7 @@ impl DynProofExpr { pub fn try_new_multiply(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { let lhs_datatype = lhs.data_type(); let rhs_datatype = rhs.data_type(); - if type_check_binary_operation(&lhs_datatype, &rhs_datatype, BinaryOperator::Multiply) { + if type_check_binary_operation(&lhs_datatype, &rhs_datatype, &BinaryOperator::Multiply) { Ok(Self::Multiply(MultiplyExpr::new( Box::new(lhs), Box::new(rhs),