Skip to content

Commit

Permalink
refactor!: remove parser structs from ColumnOperationError
Browse files Browse the repository at this point in the history
  • Loading branch information
iajoiner committed Nov 7, 2024
1 parent 5763f53 commit c5b70bb
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 57 deletions.
74 changes: 36 additions & 38 deletions crates/proof-of-sql/src/base/database/column_type_operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use crate::base::{
math::decimal::{DecimalError, Precision},
};
use alloc::{format, string::ToString};
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
// For decimal type manipulation please refer to
// https://learn.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql?view=sql-server-ver16

Expand All @@ -19,7 +18,6 @@ use proof_of_sql_parser::intermediate_ast::BinaryOperator;
pub fn try_add_subtract_column_types(
lhs: ColumnType,
rhs: ColumnType,
_operator: BinaryOperator,
) -> ColumnOperationResult<ColumnType> {
if !lhs.is_numeric() || !rhs.is_numeric() {
return Err(ColumnOperationError::BinaryOperationInvalidColumnType {
Expand Down Expand Up @@ -180,87 +178,87 @@ mod test {
// lhs and rhs are integers with the same precision
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::TinyInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::TinyInt;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::SmallInt;
assert_eq!(expected, actual);

// lhs and rhs are integers with different precision
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::SmallInt;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Int;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Int;
assert_eq!(expected, actual);

// lhs is an integer and rhs is a scalar
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::Scalar;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Scalar;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Scalar;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Scalar;
assert_eq!(expected, actual);

// lhs is a decimal with nonnegative scale and rhs is an integer
let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let rhs = ColumnType::TinyInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2);
assert_eq!(expected, actual);

let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2);
assert_eq!(expected, actual);

// lhs and rhs are both decimals with nonnegative scale
let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3);
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3);
assert_eq!(expected, actual);

// lhs is an integer and rhs is a decimal with negative scale
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0);
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0);
assert_eq!(expected, actual);

// lhs and rhs are both decimals one of which has negative scale
let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13);
let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5);
assert_eq!(expected, actual);

// lhs and rhs are both decimals both with negative scale
// and with result having maximum precision
let lhs = ColumnType::Decimal75(Precision::new(74).unwrap(), -13);
let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), -14);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13);
assert_eq!(expected, actual);
}
Expand All @@ -270,21 +268,21 @@ mod test {
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));

let lhs = ColumnType::VarChar;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));
}
Expand All @@ -294,7 +292,7 @@ mod test {
let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 4);
let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 4);
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidPrecision { .. }
})
Expand All @@ -303,7 +301,7 @@ mod test {
let lhs = ColumnType::Int;
let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 10);
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Add),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidPrecision { .. }
})
Expand All @@ -315,87 +313,87 @@ mod test {
// lhs and rhs are integers with the same precision
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::TinyInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::TinyInt;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::SmallInt;
assert_eq!(expected, actual);

// lhs and rhs are integers with different precision
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::SmallInt;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Int;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Int;
assert_eq!(expected, actual);

// lhs is an integer and rhs is a scalar
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::Scalar;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Scalar;
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Scalar;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Scalar;
assert_eq!(expected, actual);

// lhs is a decimal and rhs is an integer
let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let rhs = ColumnType::TinyInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2);
assert_eq!(expected, actual);

let lhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let rhs = ColumnType::SmallInt;
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(11).unwrap(), 2);
assert_eq!(expected, actual);

// lhs and rhs are both decimals with nonnegative scale
let lhs = ColumnType::Decimal75(Precision::new(20).unwrap(), 3);
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), 2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(21).unwrap(), 3);
assert_eq!(expected, actual);

// lhs is an integer and rhs is a decimal with negative scale
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0);
assert_eq!(expected, actual);

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::Decimal75(Precision::new(10).unwrap(), -2);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(13).unwrap(), 0);
assert_eq!(expected, actual);

// lhs and rhs are both decimals one of which has negative scale
let lhs = ColumnType::Decimal75(Precision::new(40).unwrap(), -13);
let rhs = ColumnType::Decimal75(Precision::new(15).unwrap(), 5);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(59).unwrap(), 5);
assert_eq!(expected, actual);

// lhs and rhs are both decimals both with negative scale
// and with result having maximum precision
let lhs = ColumnType::Decimal75(Precision::new(61).unwrap(), -13);
let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), -14);
let actual = try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract).unwrap();
let actual = try_add_subtract_column_types(lhs, rhs).unwrap();
let expected = ColumnType::Decimal75(Precision::new(75).unwrap(), -13);
assert_eq!(expected, actual);
}
Expand All @@ -405,21 +403,21 @@ mod test {
let lhs = ColumnType::TinyInt;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));

let lhs = ColumnType::SmallInt;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));

let lhs = ColumnType::VarChar;
let rhs = ColumnType::VarChar;
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::BinaryOperationInvalidColumnType { .. })
));
}
Expand All @@ -429,7 +427,7 @@ mod test {
let lhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 0);
let rhs = ColumnType::Decimal75(Precision::new(73).unwrap(), 1);
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidPrecision { .. }
})
Expand All @@ -438,7 +436,7 @@ mod test {
let lhs = ColumnType::Int128;
let rhs = ColumnType::Decimal75(Precision::new(75).unwrap(), 12);
assert!(matches!(
try_add_subtract_column_types(lhs, rhs, BinaryOperator::Subtract),
try_add_subtract_column_types(lhs, rhs),
Err(ColumnOperationError::DecimalConversionError {
source: DecimalError::InvalidPrecision { .. }
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use crate::base::{
},
scalar::Scalar,
};
use alloc::string::ToString;
use core::ops::{Add, Div, Mul, Sub};

impl<S: Scalar> OwnedColumn<S> {
Expand Down
10 changes: 2 additions & 8 deletions crates/proof-of-sql/src/base/database/slice_decimal_operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use alloc::vec::Vec;
use core::{cmp::Ordering, fmt::Debug};
use num_bigint::BigInt;
use num_traits::Zero;
use proof_of_sql_parser::intermediate_ast::BinaryOperator;
/// Check whether a numerical slice is equal to a decimal one.
///
/// Note that we do not check for length equality here.
Expand Down Expand Up @@ -273,8 +272,7 @@ where
T0: Copy,
T1: Copy,
{
let new_column_type =
try_add_subtract_column_types(left_column_type, right_column_type, BinaryOperator::Add)?;
let new_column_type = try_add_subtract_column_types(left_column_type, right_column_type)?;
let new_precision_value = new_column_type
.precision_value()
.expect("numeric columns have precision");
Expand Down Expand Up @@ -332,11 +330,7 @@ where
T0: Copy,
T1: Copy,
{
let new_column_type = try_add_subtract_column_types(
left_column_type,
right_column_type,
BinaryOperator::Subtract,
)?;
let new_column_type = try_add_subtract_column_types(left_column_type, right_column_type)?;
let new_precision_value = new_column_type
.precision_value()
.expect("numeric columns have precision");
Expand Down
7 changes: 2 additions & 5 deletions crates/proof-of-sql/src/sql/parse/query_context_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,9 @@ pub(crate) fn type_check_binary_operation(
| (ColumnType::TimestampTZ(_, _), ColumnType::TimestampTZ(_, _))
)
}
BinaryOperator::Add => {
try_add_subtract_column_types(*left_dtype, *right_dtype, BinaryOperator::Add).is_ok()
}
BinaryOperator::Add => try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok(),
BinaryOperator::Subtract => {
try_add_subtract_column_types(*left_dtype, *right_dtype, BinaryOperator::Subtract)
.is_ok()
try_add_subtract_column_types(*left_dtype, *right_dtype).is_ok()
}
BinaryOperator::Multiply => try_multiply_column_types(*left_dtype, *right_dtype).is_ok(),
BinaryOperator::Division => left_dtype.is_numeric() && right_dtype.is_numeric(),
Expand Down
7 changes: 1 addition & 6 deletions crates/proof-of-sql/src/sql/proof_exprs/add_subtract_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ impl ProofExpr for AddSubtractExpr {
}

fn data_type(&self) -> ColumnType {
let operator = if self.is_subtract {
BinaryOperator::Subtract
} else {
BinaryOperator::Add
};
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type(), operator)
try_add_subtract_column_types(self.lhs.data_type(), self.rhs.data_type())
.expect("Failed to add/subtract column types")
}

Expand Down

0 comments on commit c5b70bb

Please sign in to comment.