diff --git a/crates/proof-of-sql/src/base/database/column_type_operation.rs b/crates/proof-of-sql/src/base/database/column_type_operation.rs index 2c496b95c..a425f73a8 100644 --- a/crates/proof-of-sql/src/base/database/column_type_operation.rs +++ b/crates/proof-of-sql/src/base/database/column_type_operation.rs @@ -4,7 +4,6 @@ use crate::base::{ math::decimal::{DecimalError, Precision}, }; use alloc::{format, string::ToString}; -use proof_of_sql_parser::intermediate_ast::BinaryOperator; // For decimal type manipulation please refer to // https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16 @@ -19,7 +18,6 @@ use proof_of_sql_parser::intermediate_ast::BinaryOperator; pub fn try_add_subtract_column_types( lhs: ColumnType, rhs: ColumnType, - _operator: BinaryOperator, ) -> ColumnOperationResult { if !lhs.is_numeric() || !rhs.is_numeric() { return Err(ColumnOperationError::BinaryOperationInvalidColumnType { @@ -180,79 +178,79 @@ mod test { // lhs and rhs are integers with the same precision let lhs = ColumnType::TinyInt; let rhs = ColumnType::TinyInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::TinyInt; assert_eq!(expected, actual); let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::SmallInt; assert_eq!(expected, actual); // lhs and rhs are integers with different precision let lhs = ColumnType::TinyInt; let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::SmallInt; assert_eq!(expected, actual); let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Int; assert_eq!(expected, actual); // lhs is an integer and rhs is a scalar let lhs = ColumnType::TinyInt; let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Scalar; assert_eq!(expected, actual); let lhs = ColumnType::SmallInt; let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Scalar; assert_eq!(expected, actual); // lhs is a decimal with nonnegative scale and rhs is an integer let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::TinyInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); assert_eq!(expected, actual); let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); assert_eq!(expected, actual); // lhs and rhs are both decimals with nonnegative scale let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3); assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale let lhs = ColumnType::TinyInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); assert_eq!(expected, actual); let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); assert_eq!(expected, actual); // lhs and rhs are both decimals one of which has negative scale let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5); assert_eq!(expected, actual); @@ -260,7 +258,7 @@ mod test { // and with result having maximum precision let lhs = ColumnType::Decimal75(Precision::new(74).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), -14); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13); assert_eq!(expected, actual); } @@ -270,21 +268,21 @@ mod test { let lhs = ColumnType::TinyInt; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::VarChar; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); } @@ -294,7 +292,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 4); let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 4); assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -303,7 +301,7 @@ mod test { let lhs = ColumnType::Int; let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 10); assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -315,79 +313,79 @@ mod test { // lhs and rhs are integers with the same precision let lhs = ColumnType::TinyInt; let rhs = ColumnType::TinyInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::TinyInt; assert_eq!(expected, actual); let lhs = ColumnType::SmallInt; let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::SmallInt; assert_eq!(expected, actual); // lhs and rhs are integers with different precision let lhs = ColumnType::TinyInt; let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::SmallInt; assert_eq!(expected, actual); let lhs = ColumnType::SmallInt; let rhs = ColumnType::Int; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Int; assert_eq!(expected, actual); // lhs is an integer and rhs is a scalar let lhs = ColumnType::TinyInt; let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Scalar; assert_eq!(expected, actual); let lhs = ColumnType::SmallInt; let rhs = ColumnType::Scalar; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Scalar; assert_eq!(expected, actual); // lhs is a decimal and rhs is an integer let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::TinyInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); assert_eq!(expected, actual); let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); let rhs = ColumnType::SmallInt; - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2); assert_eq!(expected, actual); // lhs and rhs are both decimals with nonnegative scale let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3); let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3); assert_eq!(expected, actual); // lhs is an integer and rhs is a decimal with negative scale let lhs = ColumnType::TinyInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); assert_eq!(expected, actual); let lhs = ColumnType::SmallInt; let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0); assert_eq!(expected, actual); // lhs and rhs are both decimals one of which has negative scale let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5); assert_eq!(expected, actual); @@ -395,7 +393,7 @@ mod test { // and with result having maximum precision let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13); let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), -14); - let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap(); + let actual = try_add_subtract_column_types(lhs, rhs).unwrap(); let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13); assert_eq!(expected, actual); } @@ -405,21 +403,21 @@ mod test { let lhs = ColumnType::TinyInt; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::SmallInt; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); let lhs = ColumnType::VarChar; let rhs = ColumnType::VarChar; assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. }) )); } @@ -429,7 +427,7 @@ mod test { let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 0); let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 1); assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) @@ -438,7 +436,7 @@ mod test { let lhs = ColumnType::Int128; let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 12); assert!(matches!( - try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract), + try_add_subtract_column_types(lhs, rhs), Err(ColumnOperationError::DecimalConversionError { source: DecimalError::InvalidPrecision { .. } }) diff --git a/crates/proof-of-sql/src/base/database/owned_column_operation.rs b/crates/proof-of-sql/src/base/database/owned_column_operation.rs index dcd512aef..a8a495115 100644 --- a/crates/proof-of-sql/src/base/database/owned_column_operation.rs +++ b/crates/proof-of-sql/src/base/database/owned_column_operation.rs @@ -16,6 +16,7 @@ use crate::base::{ }, scalar::Scalar, }; +use alloc::string::ToString; use core::ops::{Add, Div, Mul, Sub}; impl OwnedColumn { diff --git a/crates/proof-of-sql/src/base/database/slice_decimal_operation.rs b/crates/proof-of-sql/src/base/database/slice_decimal_operation.rs index d18c92c23..d6b3cd8ef 100644 --- a/crates/proof-of-sql/src/base/database/slice_decimal_operation.rs +++ b/crates/proof-of-sql/src/base/database/slice_decimal_operation.rs @@ -13,7 +13,6 @@ use alloc::vec::Vec; use core::{cmp::Ordering, fmt::Debug}; use num_bigint::BigInt; use num_traits::Zero; -use proof_of_sql_parser::intermediate_ast::BinaryOperator; /// Check whether a numerical slice is equal to a decimal one. /// /// Note that we do not check for length equality here. @@ -273,8 +272,7 @@ where T0: Copy, T1: Copy, { - let new_column_type = - try_add_subtract_column_types(left_column_type, right_column_type, BinaryOperator::Add)?; + let new_column_type = try_add_subtract_column_types(left_column_type, right_column_type)?; let new_precision_value = new_column_type .precision_value() .expect("numeric columns have precision"); @@ -332,11 +330,7 @@ where T0: Copy, T1: Copy, { - let new_column_type = try_add_subtract_column_types( - left_column_type, - right_column_type, - BinaryOperator::Subtract, - )?; + let new_column_type = try_add_subtract_column_types(left_column_type, right_column_type)?; let new_precision_value = new_column_type .precision_value() .expect("numeric columns have precision"); diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 0ce3b138b..dda1f6101 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -305,12 +305,9 @@ pub(crate) fn type_check_binary_operation( | (ColumnType::TimestampTZ(_, _), ColumnType::TimestampTZ(_, _)) ) } - BinaryOperator::Add => { - try_add_subtract_column_types(*left_dtype, *right_dtype, BinaryOperator::Add).is_ok() - } + BinaryOperator::Add => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(), BinaryOperator::Subtract => { - try_add_subtract_column_types(*left_dtype, *right_dtype, BinaryOperator::Subtract) - .is_ok() + try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok() } BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(), BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(), diff --git a/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs index 8f1be6f99..75e6836d1 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs @@ -44,12 +44,7 @@ impl ProofExpr for AddSubtractExpr { } fn data_type(&self) -> ColumnType { - let operator = if self.is_subtract { - BinaryOperator::Subtract - } else { - BinaryOperator::Add - }; - try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type(), operator) + try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type()) .expect("Failed to add/subtract column types") }