diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..212566614 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +* text=auto \ No newline at end of file diff --git a/crates/proof-of-sql-parser/src/posql_time/unit.rs b/crates/proof-of-sql-parser/src/posql_time/unit.rs index 16be4cfb0..0f6f08c85 100644 --- a/crates/proof-of-sql-parser/src/posql_time/unit.rs +++ b/crates/proof-of-sql-parser/src/posql_time/unit.rs @@ -1,4 +1,5 @@ use super::PoSQLTimestampError; +use crate::alloc::string::ToString; use core::fmt; use serde::{Deserialize, Serialize}; @@ -16,6 +17,21 @@ pub enum PoSQLTimeUnit { Nanosecond, } +impl PoSQLTimeUnit { + /// Converts a precision value into a corresponding `PoSQLTimeUnit`. + pub fn from_precision(precision: u64) -> Result { + match precision { + 0 => Ok(PoSQLTimeUnit::Second), + 3 => Ok(PoSQLTimeUnit::Millisecond), + 6 => Ok(PoSQLTimeUnit::Microsecond), + 9 => Ok(PoSQLTimeUnit::Nanosecond), + _ => Err(PoSQLTimestampError::UnsupportedPrecision { + error: precision.to_string(), + }), + } + } +} + impl From for u64 { fn from(value: PoSQLTimeUnit) -> u64 { match value { diff --git a/crates/proof-of-sql-parser/src/sqlparser.rs b/crates/proof-of-sql-parser/src/sqlparser.rs index 72643b568..736c45f47 100644 --- a/crates/proof-of-sql-parser/src/sqlparser.rs +++ b/crates/proof-of-sql-parser/src/sqlparser.rs @@ -5,10 +5,16 @@ use crate::{ OrderBy as PoSqlOrderBy, OrderByDirection, SelectResultExpr, SetExpression, TableExpression, UnaryOperator as PoSqlUnaryOperator, }, + posql_time::{PoSQLTimeUnit, PoSQLTimeZone}, Identifier, ResourceId, SelectStatement, }; -use alloc::{boxed::Box, string::ToString, vec}; +use alloc::{ + boxed::Box, + string::{String, ToString}, + vec, +}; use core::fmt::Display; +use serde::{Deserialize, Serialize}; use sqlparser::ast::{ BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, Ident, ObjectName, Offset, OffsetRows, OrderByExpr, Query, Select, SelectItem, SetExpr, 
TableFactor, @@ -28,6 +34,79 @@ fn id(id: Identifier) -> Expr { Expr::Identifier(id.into()) } +#[must_use] +/// New `AliasedResultExpr` using sqlparser types +/// Represents an aliased SQL expression, e.g., `a + 1 AS alias`. +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)] +pub struct SqlAliasedResultExpr { + /// The SQL expression being aliased (e.g., `a + 1`). + pub expr: Box, + /// The alias for the expression (e.g., `alias` in `a + 1 AS alias`). + pub alias: Ident, +} + +impl SqlAliasedResultExpr { + /// Creates a new `SqlAliasedResultExpr`. + pub fn new(expr: Box, alias: Ident) -> Self { + Self { expr, alias } + } + + /// Try to get the identifier of the expression if it is a column + /// Otherwise, return None + #[must_use] + pub fn try_as_identifier(&self) -> Option<&Ident> { + match self.expr.as_ref() { + Expr::Identifier(identifier) => Some(identifier), + _ => None, + } + } +} + +/// Provides an extension for the `TimezoneInfo` type for offsets. +pub trait TimezoneInfoExt { + /// Retrieve the offset in seconds for `TimezoneInfo`. + fn offset(&self, timezone_str: Option<&str>) -> i32; +} + +impl TimezoneInfoExt for TimezoneInfo { + fn offset(&self, timezone_str: Option<&str>) -> i32 { + match self { + TimezoneInfo::None => PoSQLTimeZone::utc().offset(), + TimezoneInfo::WithTimeZone => match timezone_str { + Some(tz_str) => PoSQLTimeZone::try_from(&Some(tz_str.into())) + .unwrap_or_else(|_| PoSQLTimeZone::utc()) + .offset(), + None => PoSQLTimeZone::utc().offset(), + }, + _ => panic!("Offsets are not applicable for WithoutTimeZone or Tz variants."), + } + } +} + +/// Utility function to create a `Timestamp` expression. 
+pub fn timestamp_to_expr( + value: &str, + time_unit: PoSQLTimeUnit, + timezone: TimezoneInfo, +) -> Result { + let time_unit_as_u64 = u64::from(time_unit); + + Ok(Expr::TypedString { + data_type: DataType::Timestamp(Some(time_unit_as_u64), timezone), + value: value.to_string(), + }) +} + +/// Parses [`PoSQLTimeZone`] into a `TimezoneInfo`. +impl From for TimezoneInfo { + fn from(posql_timezone: PoSQLTimeZone) -> Self { + match posql_timezone.offset() { + 0 => TimezoneInfo::None, + _ => TimezoneInfo::WithTimeZone, + } + } +} + impl From for Ident { fn from(id: Identifier) -> Self { Ident::new(id.as_str()) @@ -125,7 +204,7 @@ impl From for OrderByExpr { impl From for Expr { fn from(expr: Expression) -> Self { match expr { - Expression::Literal(literal) => literal.into(), + Expression::Literal(literal) => Expr::from(literal), Expression::Column(identifier) => id(identifier), Expression::Unary { op, expr } => Expr::UnaryOp { op: op.into(), @@ -268,6 +347,11 @@ mod test { "select timestamp '2024-11-07T04:55:12.345+03:00' as time from t;", "select timestamp(3) '2024-11-07 01:55:12.345 UTC' as time from t;", ); + + check_posql_intermediate_ast_to_sqlparser_equivalence( + "select timestamp '2024-11-07T04:55:12+00:00' as time from t;", + "select timestamp(0) '2024-11-07 04:55:12 UTC' as time from t;", + ); } // Check that PoSQL intermediate AST can be converted to SQL parser AST and that the two are equal. 
diff --git a/crates/proof-of-sql/benches/bench_append_rows.rs b/crates/proof-of-sql/benches/bench_append_rows.rs index 0f353a884..150570385 100644 --- a/crates/proof-of-sql/benches/bench_append_rows.rs +++ b/crates/proof-of-sql/benches/bench_append_rows.rs @@ -24,8 +24,9 @@ use proof_of_sql::{ DoryCommitment, DoryProverPublicSetup, DoryScalar, ProverSetup, PublicParameters, }, }; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; use rand::Rng; +use sqlparser::ast::TimezoneInfo; /// Bench dory performance when appending rows to a table. This includes the computation of /// commitments. Chose the number of columns to randomly generate across supported `PoSQL` @@ -121,7 +122,7 @@ pub fn generate_random_owned_table( "timestamptz" => columns.push(timestamptz( &*identifier, PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, vec![rng.gen::(); num_rows], )), _ => unreachable!(), diff --git a/crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs b/crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs index 31f4f1d8e..e4143a201 100644 --- a/crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs +++ b/crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs @@ -202,60 +202,69 @@ impl ArrayRefExt for ArrayRef { } } // Handle all possible TimeStamp TimeUnit instances - DataType::Timestamp(time_unit, tz) => match time_unit { - ArrowTimeUnit::Second => { - if let Some(array) = self.as_any().downcast_ref::() { - Ok(Column::TimestampTZ( - PoSQLTimeUnit::Second, - PoSQLTimeZone::try_from(tz)?, - &array.values()[range.start..range.end], - )) - } else { - Err(ArrowArrayToColumnConversionError::UnsupportedType { - datatype: self.data_type().clone(), - }) + DataType::Timestamp(time_unit, tz) => { + let timezone = PoSQLTimeZone::try_from(tz)?; + match time_unit { + ArrowTimeUnit::Second => { + if let Some(array) = 
self.as_any().downcast_ref::() { + Ok(Column::TimestampTZ( + PoSQLTimeUnit::Second, + timezone.into(), + &array.values()[range.start..range.end], + )) + } else { + Err(ArrowArrayToColumnConversionError::UnsupportedType { + datatype: self.data_type().clone(), + }) + } } - } - ArrowTimeUnit::Millisecond => { - if let Some(array) = self.as_any().downcast_ref::() { - Ok(Column::TimestampTZ( - PoSQLTimeUnit::Millisecond, - PoSQLTimeZone::try_from(tz)?, - &array.values()[range.start..range.end], - )) - } else { - Err(ArrowArrayToColumnConversionError::UnsupportedType { - datatype: self.data_type().clone(), - }) + ArrowTimeUnit::Millisecond => { + if let Some(array) = + self.as_any().downcast_ref::() + { + Ok(Column::TimestampTZ( + PoSQLTimeUnit::Millisecond, + timezone.into(), + &array.values()[range.start..range.end], + )) + } else { + Err(ArrowArrayToColumnConversionError::UnsupportedType { + datatype: self.data_type().clone(), + }) + } } - } - ArrowTimeUnit::Microsecond => { - if let Some(array) = self.as_any().downcast_ref::() { - Ok(Column::TimestampTZ( - PoSQLTimeUnit::Microsecond, - PoSQLTimeZone::try_from(tz)?, - &array.values()[range.start..range.end], - )) - } else { - Err(ArrowArrayToColumnConversionError::UnsupportedType { - datatype: self.data_type().clone(), - }) + ArrowTimeUnit::Microsecond => { + if let Some(array) = + self.as_any().downcast_ref::() + { + Ok(Column::TimestampTZ( + PoSQLTimeUnit::Microsecond, + timezone.into(), + &array.values()[range.start..range.end], + )) + } else { + Err(ArrowArrayToColumnConversionError::UnsupportedType { + datatype: self.data_type().clone(), + }) + } } - } - ArrowTimeUnit::Nanosecond => { - if let Some(array) = self.as_any().downcast_ref::() { - Ok(Column::TimestampTZ( - PoSQLTimeUnit::Nanosecond, - PoSQLTimeZone::try_from(tz)?, - &array.values()[range.start..range.end], - )) - } else { - Err(ArrowArrayToColumnConversionError::UnsupportedType { - datatype: self.data_type().clone(), - }) + ArrowTimeUnit::Nanosecond => 
{ + if let Some(array) = + self.as_any().downcast_ref::() + { + Ok(Column::TimestampTZ( + PoSQLTimeUnit::Nanosecond, + timezone.into(), + &array.values()[range.start..range.end], + )) + } else { + Err(ArrowArrayToColumnConversionError::UnsupportedType { + datatype: self.data_type().clone(), + }) + } } } - }, + } DataType::Utf8 => { if let Some(array) = self.as_any().downcast_ref::() { let vals = alloc @@ -292,6 +301,7 @@ mod tests { use alloc::sync::Arc; use arrow::array::Decimal256Builder; use core::str::FromStr; + use sqlparser::ast::TimezoneInfo; #[test] fn we_can_convert_timestamp_array_normal_range() { @@ -305,7 +315,7 @@ mod tests { let result = array.to_column::(&alloc, &(1..3), None); assert_eq!( result.unwrap(), - Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3]) + Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3]) ); } @@ -323,7 +333,7 @@ mod tests { .unwrap(); assert_eq!( result, - Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[]) + Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[]) ); } @@ -339,7 +349,7 @@ mod tests { let result = array.to_column::(&alloc, &(1..1), None); assert_eq!( result.unwrap(), - Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[]) + Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[]) ); } @@ -1006,7 +1016,7 @@ mod tests { .unwrap(); assert_eq!( result, - Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[..]) + Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[..]) ); } @@ -1076,7 +1086,7 @@ mod tests { array .to_column::(&alloc, &(1..3), None) .unwrap(), - Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3]) + Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3]) ); } @@ -1134,7 +1144,7 @@ mod tests { .unwrap(); assert_eq!( result, - Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[]) + 
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[]) ); } } diff --git a/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs b/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs index dbd218383..578d4d9c6 100644 --- a/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs +++ b/crates/proof-of-sql/src/base/arrow/column_arrow_conversions.rs @@ -59,7 +59,7 @@ impl TryFrom for ColumnType { }; Ok(ColumnType::TimestampTZ( posql_time_unit, - PoSQLTimeZone::try_from(&timezone_option)?, + PoSQLTimeZone::try_from(&timezone_option)?.into(), )) } DataType::Utf8 => Ok(ColumnType::VarChar), diff --git a/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs index 9e8df184a..89893621d 100644 --- a/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs +++ b/crates/proof-of-sql/src/base/arrow/owned_and_arrow_conversions.rs @@ -238,7 +238,7 @@ impl TryFrom<&ArrayRef> for OwnedColumn { let timestamps = array.values().iter().copied().collect::>(); Ok(OwnedColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::try_from(timezone)?, + PoSQLTimeZone::try_from(timezone)?.into(), timestamps, )) } @@ -252,7 +252,7 @@ impl TryFrom<&ArrayRef> for OwnedColumn { let timestamps = array.values().iter().copied().collect::>(); Ok(OwnedColumn::TimestampTZ( PoSQLTimeUnit::Millisecond, - PoSQLTimeZone::try_from(timezone)?, + PoSQLTimeZone::try_from(timezone)?.into(), timestamps, )) } @@ -266,7 +266,7 @@ impl TryFrom<&ArrayRef> for OwnedColumn { let timestamps = array.values().iter().copied().collect::>(); Ok(OwnedColumn::TimestampTZ( PoSQLTimeUnit::Microsecond, - PoSQLTimeZone::try_from(timezone)?, + PoSQLTimeZone::try_from(timezone)?.into(), timestamps, )) } @@ -280,7 +280,7 @@ impl TryFrom<&ArrayRef> for OwnedColumn { let timestamps = array.values().iter().copied().collect::>(); Ok(OwnedColumn::TimestampTZ( PoSQLTimeUnit::Nanosecond, - 
PoSQLTimeZone::try_from(timezone)?, + PoSQLTimeZone::try_from(timezone)?.into(), timestamps, )) } diff --git a/crates/proof-of-sql/src/base/commitment/column_bounds.rs b/crates/proof-of-sql/src/base/commitment/column_bounds.rs index 29051afb6..32d5abca4 100644 --- a/crates/proof-of-sql/src/base/commitment/column_bounds.rs +++ b/crates/proof-of-sql/src/base/commitment/column_bounds.rs @@ -312,7 +312,8 @@ mod tests { }; use alloc::{string::String, vec}; use itertools::Itertools; - use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; + use proof_of_sql_parser::posql_time::PoSQLTimeUnit; + use sqlparser::ast::TimezoneInfo; #[test] fn we_can_construct_bounds_by_method() { @@ -563,7 +564,7 @@ mod tests { let timestamp_column = OwnedColumn::::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, vec![1_i64, 2, 3, 4], ); let committable_timestamp_column = CommittableColumn::from(×tamp_column); diff --git a/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs b/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs index a93a861ad..e2fe41bc0 100644 --- a/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs +++ b/crates/proof-of-sql/src/base/commitment/column_commitment_metadata.rs @@ -186,7 +186,8 @@ mod tests { scalar::test_scalar::TestScalar, }; use alloc::string::String; - use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; + use proof_of_sql_parser::posql_time::PoSQLTimeUnit; + use sqlparser::ast::TimezoneInfo; #[test] fn we_can_construct_metadata() { @@ -257,12 +258,12 @@ mod tests { assert_eq!( ColumnCommitmentMetadata::try_new( - ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()), + ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None), ColumnBounds::TimestampTZ(Bounds::Empty), ) .unwrap(), ColumnCommitmentMetadata { - column_type: ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()), + column_type: 
ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None), bounds: ColumnBounds::TimestampTZ(Bounds::Empty), } ); @@ -399,7 +400,7 @@ mod tests { let timestamp_column: OwnedColumn = OwnedColumn::::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [1i64, 2, 3, 4, 5].to_vec(), ); let committable_timestamp_column = CommittableColumn::from(×tamp_column); @@ -407,7 +408,7 @@ mod tests { ColumnCommitmentMetadata::from_column(&committable_timestamp_column); assert_eq!( timestamp_metadata.column_type(), - &ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()) + &ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None) ); if let ColumnBounds::TimestampTZ(Bounds::Sharp(bounds)) = timestamp_metadata.bounds() { assert_eq!(bounds.min(), &1); @@ -583,7 +584,7 @@ mod tests { 1_625_072_400, 1_625_065_000, ]; - let timezone = PoSQLTimeZone::utc(); + let timezone = TimezoneInfo::None; let timeunit = PoSQLTimeUnit::Second; let timestamp_column_a = CommittableColumn::TimestampTZ(timeunit, timezone, ×[..2]); let timestamp_metadata_a = ColumnCommitmentMetadata::from_column(×tamp_column_a); @@ -609,7 +610,7 @@ mod tests { 1_625_072_400, 1_625_065_000, ]; - let timezone = PoSQLTimeZone::utc(); + let timezone = TimezoneInfo::None; let timeunit = PoSQLTimeUnit::Second; let timestamp_column_a = CommittableColumn::TimestampTZ(timeunit, timezone, ×[..2]); @@ -960,12 +961,12 @@ mod tests { .is_err()); let timestamp_tz_metadata_a = ColumnCommitmentMetadata { - column_type: ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()), + column_type: ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None), bounds: ColumnBounds::TimestampTZ(Bounds::Empty), }; let timestamp_tz_metadata_b = ColumnCommitmentMetadata { - column_type: ColumnType::TimestampTZ(PoSQLTimeUnit::Millisecond, PoSQLTimeZone::utc()), + column_type: ColumnType::TimestampTZ(PoSQLTimeUnit::Millisecond, TimezoneInfo::None), bounds: 
ColumnBounds::TimestampTZ(Bounds::Empty), }; diff --git a/crates/proof-of-sql/src/base/commitment/committable_column.rs b/crates/proof-of-sql/src/base/commitment/committable_column.rs index 9526e5bbf..4a1f37d6b 100644 --- a/crates/proof-of-sql/src/base/commitment/committable_column.rs +++ b/crates/proof-of-sql/src/base/commitment/committable_column.rs @@ -7,7 +7,8 @@ use crate::base::{ use alloc::vec::Vec; #[cfg(feature = "blitzar")] use blitzar::sequence::Sequence; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::TimezoneInfo; /// Column data in "committable form". /// @@ -42,7 +43,7 @@ pub enum CommittableColumn<'a> { /// Column of limbs for committing to scalars, hashed from a `VarChar` column. VarChar(Vec<[u64; 4]>), /// Borrowed Timestamp column with Timezone, mapped to `i64`. - TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone, &'a [i64]), + TimestampTZ(PoSQLTimeUnit, TimezoneInfo, &'a [i64]), /// Borrowed byte column, mapped to `u8`. This is not a `PoSQL` /// type, we need this to commit to words in the range check. 
RangeCheckWord(&'a [u8]), @@ -258,24 +259,24 @@ mod tests { fn we_can_get_type_and_length_of_timestamp_column() { // empty case let committable_column = - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[]); + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[]); assert_eq!(committable_column.len(), 0); assert!(committable_column.is_empty()); assert_eq!( committable_column.column_type(), - ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()) + ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None) ); let committable_column = CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[12, 34, 56], ); assert_eq!(committable_column.len(), 3); assert!(!committable_column.is_empty()); assert_eq!( committable_column.column_type(), - ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()) + ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None) ); } @@ -455,28 +456,24 @@ mod tests { // empty case let from_borrowed_column = CommittableColumn::from(&Column::::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[], )); assert_eq!( from_borrowed_column, - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[]) + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[]) ); // non-empty case let timestamps = [1_625_072_400, 1_625_076_000, 1_625_083_200]; let from_borrowed_column = CommittableColumn::from(&Column::::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, ×tamps, )); assert_eq!( from_borrowed_column, - CommittableColumn::TimestampTZ( - PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), - ×tamps - ) + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, ×tamps) ); } @@ -660,30 +657,26 @@ mod tests { // empty case let owned_column = OwnedColumn::::TimestampTZ( PoSQLTimeUnit::Second, - 
PoSQLTimeZone::utc(), + TimezoneInfo::None, Vec::new(), ); let from_owned_column = CommittableColumn::from(&owned_column); assert_eq!( from_owned_column, - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[]) + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[]) ); // non-empty case let timestamps = vec![1_625_072_400, 1_625_076_000, 1_625_083_200]; let owned_column = OwnedColumn::::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, timestamps.clone(), ); let from_owned_column = CommittableColumn::from(&owned_column); assert_eq!( from_owned_column, - CommittableColumn::TimestampTZ( - PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), - ×tamps - ) + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, ×tamps) ); } @@ -1013,7 +1006,7 @@ mod tests { fn we_can_commit_to_timestamp_column_through_committable_column() { // Empty case let committable_column = - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[]); + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[]); let sequence = Sequence::from(&committable_column); let mut commitment_buffer = [CompressedRistretto::default()]; compute_curve25519_commitments(&mut commitment_buffer, &[sequence], 0); @@ -1021,11 +1014,8 @@ mod tests { // Non-empty case let timestamps = [1_625_072_400, 1_625_076_000, 1_625_083_200]; - let committable_column = CommittableColumn::TimestampTZ( - PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), - ×tamps, - ); + let committable_column = + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, ×tamps); let sequence_actual = Sequence::from(&committable_column); let sequence_expected = Sequence::from(timestamps.as_slice()); diff --git a/crates/proof-of-sql/src/base/database/column.rs b/crates/proof-of-sql/src/base/database/column.rs index 25d3b0646..3c13b7dec 100644 --- a/crates/proof-of-sql/src/base/database/column.rs +++ 
b/crates/proof-of-sql/src/base/database/column.rs @@ -1,19 +1,24 @@ -use super::{LiteralValue, OwnedColumn, TableRef}; -use crate::base::{ - math::decimal::Precision, - scalar::{Scalar, ScalarExt}, - slice_ops::slice_cast_with, +use super::{OwnedColumn, TableRef}; +use crate::{ + alloc::string::ToString, + base::{ + math::{decimal::Precision, i256::I256}, + scalar::{Scalar, ScalarExt}, + slice_ops::slice_cast_with, + }, + sql::parse::ConversionError, }; -use alloc::vec::Vec; +use alloc::{format, string::String, vec::Vec}; use bumpalo::Bump; use core::{ fmt, fmt::{Display, Formatter}, mem::size_of, }; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; use serde::{Deserialize, Serialize}; -use sqlparser::ast::Ident; +use snafu::Snafu; +use sqlparser::ast::{DataType, ExactNumberInfo, Expr as SqlExpr, Ident, TimezoneInfo, Value}; /// Represents a read-only view of a column in an in-memory, /// column-oriented database. @@ -49,7 +54,7 @@ pub enum Column<'a, S: Scalar> { /// - the first element maps to the stored `TimeUnit` /// - the second element maps to a timezone /// - the third element maps to columns of timeunits since unix epoch - TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone, &'a [i64]), + TimestampTZ(PoSQLTimeUnit, TimezoneInfo, &'a [i64]), } impl<'a, S: Scalar> Column<'a, S> { @@ -97,43 +102,114 @@ impl<'a, S: Scalar> Column<'a, S> { } /// Generate a constant column from a literal value with a given length + /// + /// # Panics + /// - Panics if the precision or scale for a decimal value cannot fit into `u8` or `i8`, respectively. + /// - Panics if creating a `Precision` object fails. 
pub fn from_literal_with_length( - literal: &LiteralValue, + expr: &SqlExpr, length: usize, alloc: &'a Bump, - ) -> Self { - match literal { - LiteralValue::Boolean(value) => { - Column::Boolean(alloc.alloc_slice_fill_copy(length, *value)) - } - LiteralValue::TinyInt(value) => { - Column::TinyInt(alloc.alloc_slice_fill_copy(length, *value)) - } - LiteralValue::SmallInt(value) => { - Column::SmallInt(alloc.alloc_slice_fill_copy(length, *value)) + ) -> Result { + match expr { + // Boolean value + SqlExpr::Value(Value::Boolean(value)) => { + Ok(Column::Boolean(alloc.alloc_slice_fill_copy(length, *value))) } - LiteralValue::Int(value) => Column::Int(alloc.alloc_slice_fill_copy(length, *value)), - LiteralValue::BigInt(value) => { - Column::BigInt(alloc.alloc_slice_fill_copy(length, *value)) - } - LiteralValue::Int128(value) => { - Column::Int128(alloc.alloc_slice_fill_copy(length, *value)) - } - LiteralValue::Scalar(value) => { - Column::Scalar(alloc.alloc_slice_fill_copy(length, (*value).into())) - } - LiteralValue::Decimal75(precision, scale, value) => Column::Decimal75( - *precision, - *scale, - alloc.alloc_slice_fill_copy(length, value.into_scalar()), - ), - LiteralValue::TimeStampTZ(tu, tz, value) => { - Column::TimestampTZ(*tu, *tz, alloc.alloc_slice_fill_copy(length, *value)) + + // Numeric values + SqlExpr::Value(Value::Number(value, _)) => { + let n = value + .parse::() + .map_err(|_| ColumnError::InvalidNumberFormat { + value: value.clone(), + })?; + + if let Ok(n_i8) = i8::try_from(n) { + Ok(Column::TinyInt(alloc.alloc_slice_fill_copy(length, n_i8))) + } else if let Ok(n_i16) = i16::try_from(n) { + Ok(Column::SmallInt(alloc.alloc_slice_fill_copy(length, n_i16))) + } else if let Ok(n_i32) = i32::try_from(n) { + Ok(Column::Int(alloc.alloc_slice_fill_copy(length, n_i32))) + } else { + Ok(Column::BigInt(alloc.alloc_slice_fill_copy(length, n))) + } } - LiteralValue::VarChar(string) => Column::VarChar(( + + // String values + 
SqlExpr::Value(Value::SingleQuotedString(string)) => Ok(Column::VarChar(( alloc.alloc_slice_fill_with(length, |_| alloc.alloc_str(string) as &str), alloc.alloc_slice_fill_copy(length, S::from(string)), - )), + ))), + + // Typed string literals + SqlExpr::TypedString { data_type, value } => match data_type { + // Decimal values + DataType::Decimal(ExactNumberInfo::PrecisionAndScale(precision, scale)) => { + let i256_value = + I256::from_string(value).map_err(|_| ColumnError::InvalidDecimal { + value: value.clone(), + })?; + let precision_u8 = + u8::try_from(*precision).expect("Precision must fit into u8"); + let scale_i8 = i8::try_from(*scale).expect("Scale must fit into i8"); + let precision_obj = + Precision::new(precision_u8).expect("Failed to create Precision"); + + Ok(Column::Decimal75( + precision_obj, + scale_i8, + alloc.alloc_slice_fill_copy(length, i256_value.into_scalar()), + )) + } + + // Timestamp values + DataType::Timestamp(Some(precision), tz) => { + let time_unit = + PoSQLTimeUnit::from_precision(*precision).unwrap_or(PoSQLTimeUnit::Second); + let timestamp_value = + value + .parse::() + .map_err(|_| ColumnError::InvalidNumberFormat { + value: value.clone(), + })?; + Ok(Column::TimestampTZ( + time_unit, + *tz, + alloc.alloc_slice_fill_copy(length, timestamp_value), + )) + } + DataType::Custom(_, _) if data_type.to_string() == "scalar" => { + let scalar_str = value.strip_prefix("scalar:").ok_or_else(|| { + ColumnError::InvalidScalarFormat { + value: value.clone(), + } + })?; + let limbs: Vec = scalar_str + .split(',') + .map(|x| { + x.parse::() + .map_err(|_| ColumnError::InvalidScalarFormat { + value: value.clone(), + }) + }) + .collect::, ColumnError>>()?; + if limbs.len() != 4 { + return Err(ColumnError::InvalidScalarFormat { + value: value.clone(), + }); + } + Ok(Column::Scalar( + alloc.alloc_slice_fill_copy(length, value.clone().into()), + )) + } + _ => Err(ColumnError::UnsupportedDataType { + data_type: format!("{expr:?}"), + }), + }, + _ => 
Err(ColumnError::UnsupportedDataType { + data_type: format!("{expr:?}"), + }), } } @@ -280,6 +356,51 @@ impl<'a, S: Scalar> Column<'a, S> { } } +/// Represents errors that can occur while working with columns. +#[derive(Snafu, Debug, PartialEq, Eq)] +pub enum ColumnError { + #[snafu(display("Invalid number format: {value}"))] + /// Error for invalid number format. + InvalidNumberFormat { + /// The invalid number value that caused the error. + value: String, + }, + + #[snafu(display("Unsupported data type: {data_type}"))] + /// Error for unsupported data types. + UnsupportedDataType { + /// The unsupported data type as a string. + data_type: String, + }, + + #[snafu(display("Scalar parsing error: {value}"))] + /// Error for scalar parsing failures. + InvalidScalarFormat { + /// The scalar value that caused the parsing error. + value: String, + }, + + #[snafu(display("Invalid decimal format: {value}"))] + /// Error for invalid decimal format. + InvalidDecimal { + /// The invalid decimal value that caused the error. + value: String, + }, + + #[snafu(display("Conversion error: {source}"))] + /// Error for column operation conversion issues. + ConversionError { + /// The underlying conversion error that occurred. + source: ConversionError, + }, +} + +impl From for ColumnError { + fn from(error: ConversionError) -> Self { + ColumnError::ConversionError { source: error } + } +} + /// Represents the supported data types of a column in an in-memory, /// column-oriented database. 
/// @@ -313,7 +434,7 @@ pub enum ColumnType { Decimal75(Precision, i8), /// Mapped to i64 #[serde(alias = "TIMESTAMP", alias = "timestamp")] - TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone), + TimestampTZ(PoSQLTimeUnit, TimezoneInfo), /// Mapped to `S` #[serde(alias = "SCALAR", alias = "scalar")] Scalar, @@ -565,9 +686,9 @@ mod tests { #[test] fn column_type_serializes_to_string() { - let column_type = ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()); + let column_type = ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None); let serialized = serde_json::to_string(&column_type).unwrap(); - assert_eq!(serialized, r#"{"TimestampTZ":["Second",{"offset":0}]}"#); + assert_eq!(serialized, r#"{"TimestampTZ":["Second","None"]}"#); let column_type = ColumnType::Boolean; let serialized = serde_json::to_string(&column_type).unwrap(); @@ -609,9 +730,9 @@ mod tests { #[test] fn we_can_deserialize_columns_from_valid_strings() { let expected_column_type = - ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()); + ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None); let deserialized: ColumnType = - serde_json::from_str(r#"{"TimestampTZ":["Second",{"offset":0}]}"#).unwrap(); + serde_json::from_str(r#"{"TimestampTZ":["Second","None"]}"#).unwrap(); assert_eq!(deserialized, expected_column_type); let expected_column_type = ColumnType::Boolean; @@ -1064,7 +1185,7 @@ mod tests { assert_eq!(column.column_type().bit_size(), 256); let column: Column<'_, DoryScalar> = - Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[1, 2, 3]); + Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[1, 2, 3]); assert_eq!(column.column_type().byte_size(), 8); assert_eq!(column.column_type().bit_size(), 64); } diff --git a/crates/proof-of-sql/src/base/database/columnar_value.rs b/crates/proof-of-sql/src/base/database/columnar_value.rs index 301361149..5f130914f 100644 --- 
a/crates/proof-of-sql/src/base/database/columnar_value.rs +++ b/crates/proof-of-sql/src/base/database/columnar_value.rs @@ -1,9 +1,10 @@ use crate::base::{ - database::{Column, ColumnType, LiteralValue}, + database::{Column, ColumnError, ColumnType, ExprExt}, scalar::Scalar, }; use bumpalo::Bump; use snafu::Snafu; +use sqlparser::ast::Expr as SqlExpr; /// The result of evaluating an expression. /// @@ -13,7 +14,7 @@ pub enum ColumnarValue<'a, S: Scalar> { /// A [ `ColumnarValue::Column` ] is a list of values. Column(Column<'a, S>), /// A [ `ColumnarValue::Literal` ] is a single value with indeterminate size. - Literal(LiteralValue), + Literal(SqlExpr), } /// Errors from operations on [`ColumnarValue`]s. @@ -26,6 +27,19 @@ pub enum ColumnarValueError { /// The length we attempted to convert the `[ColumnarValue::Column]` to attempt_to_convert_length: usize, }, + + /// Error during column conversion. + #[snafu(display("Error during column conversion: {source}"))] + ColumnConversionError { + /// The underlying column error. 
+ source: ColumnError, + }, +} + +impl From for ColumnarValueError { + fn from(error: ColumnError) -> Self { + ColumnarValueError::ColumnConversionError { source: error } + } } impl<'a, S: Scalar> ColumnarValue<'a, S> { @@ -34,7 +48,7 @@ impl<'a, S: Scalar> ColumnarValue<'a, S> { pub fn column_type(&self) -> ColumnType { match self { Self::Column(column) => column.column_type(), - Self::Literal(literal) => literal.column_type(), + Self::Literal(expr) => expr.column_type(), } } @@ -55,8 +69,8 @@ impl<'a, S: Scalar> ColumnarValue<'a, S> { }) } } - Self::Literal(literal) => { - Ok(Column::from_literal_with_length(literal, num_rows, alloc)) + Self::Literal(expr) => { + Column::from_literal_with_length(expr, num_rows, alloc).map_err(Into::into) } } } @@ -67,13 +81,14 @@ mod tests { use super::*; use crate::base::scalar::test_scalar::TestScalar; use core::convert::Into; + use sqlparser::ast::{Expr as SqlExpr, Value}; #[test] fn we_can_get_column_type_of_columnar_values() { let column = ColumnarValue::Column(Column::::Int(&[1, 2, 3])); assert_eq!(column.column_type(), ColumnType::Int); - let column = ColumnarValue::::Literal(LiteralValue::Boolean(true)); + let column = ColumnarValue::::Literal(SqlExpr::Value(Value::Boolean(true))); assert_eq!(column.column_type(), ColumnType::Boolean); } @@ -85,12 +100,16 @@ mod tests { let column = columnar_value.into_column(3, &bump).unwrap(); assert_eq!(column, Column::Int(&[1, 2, 3])); - let columnar_value = ColumnarValue::::Literal(LiteralValue::Boolean(false)); + let columnar_value = + ColumnarValue::::Literal(SqlExpr::Value(Value::Boolean(false))); let column = columnar_value.into_column(5, &bump).unwrap(); assert_eq!(column, Column::Boolean(&[false; 5])); // Check whether it works if `num_rows` is 0 - let columnar_value = ColumnarValue::::Literal(LiteralValue::TinyInt(2)); + let columnar_value = ColumnarValue::::Literal(SqlExpr::Value(Value::Number( + "2".to_string(), + false, + ))); let column = columnar_value.into_column(0, 
&bump).unwrap(); assert_eq!(column, Column::TinyInt(&[])); diff --git a/crates/proof-of-sql/src/base/database/expr_utility.rs b/crates/proof-of-sql/src/base/database/expr_utility.rs new file mode 100644 index 000000000..374d1bbcd --- /dev/null +++ b/crates/proof-of-sql/src/base/database/expr_utility.rs @@ -0,0 +1,170 @@ +use alloc::{boxed::Box, vec}; +use proof_of_sql_parser::{intermediate_ast::Literal, sqlparser::SqlAliasedResultExpr}; +use sqlparser::ast::{ + BinaryOperator, Expr, Function, FunctionArg, FunctionArgExpr, Ident, ObjectName, UnaryOperator, +}; + +/// Compute the sum of an expression +#[must_use] +pub fn sum(expr: Expr) -> Expr { + Expr::Function(Function { + name: ObjectName(vec![Ident::new("SUM")]), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(*Box::new(expr)))], + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: vec![], + }) +} + +/// Get column from name +/// +/// # Panics +/// +/// This function will panic if the name cannot be parsed into a valid column expression as valid [Identifier]s. 
+#[must_use] +pub fn col(name: &str) -> Expr { + Expr::Identifier(name.into()) +} + +/// Compute the maximum of an expression +#[must_use] +pub fn max(expr: Expr) -> Expr { + Expr::Function(Function { + name: ObjectName(vec![Ident::new("MAX")]), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(*Box::new(expr)))], + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: vec![], + }) +} + +/// Construct a new `Expr` A + B +#[must_use] +pub fn add(left: Expr, right: Expr) -> Expr { + Expr::BinaryOp { + op: BinaryOperator::Plus, + left: Box::new(left), + right: Box::new(right), + } +} + +/// Construct a new `Expr` A - B +#[must_use] +pub fn sub(left: Expr, right: Expr) -> Expr { + Expr::BinaryOp { + op: BinaryOperator::Minus, + left: Box::new(left), + right: Box::new(right), + } +} + +/// Get literal from value +pub fn lit(literal: L) -> Expr +where + L: Into, +{ + Expr::from(literal.into()) +} + +/// Count the amount of non-null entries of an expression +#[must_use] +pub fn count(expr: Expr) -> Expr { + Expr::Function(Function { + name: ObjectName(vec![Ident::new("COUNT")]), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(*Box::new(expr)))], + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: vec![], + }) +} + +/// Count the rows +#[must_use] +pub fn count_all() -> Expr { + count(Expr::Wildcard) +} + +/// Construct a new `Expr` representing A * B +#[must_use] +pub fn mul(left: Expr, right: Expr) -> Expr { + Expr::BinaryOp { + left: Box::new(left), + op: BinaryOperator::Multiply, + right: Box::new(right), + } +} + +/// Compute the minimum of an expression +#[must_use] +pub fn min(expr: Expr) -> Expr { + Expr::Function(Function { + name: ObjectName(vec![Ident::new("MIN")]), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(*Box::new(expr)))], + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: vec![], + 
}) +} + +/// Construct a new `Expr` for NOT P +#[must_use] +pub fn not(expr: Expr) -> Expr { + Expr::UnaryOp { + op: UnaryOperator::Not, + expr: Box::new(expr), + } +} + +/// Construct a new `Expr` for A >= B +#[must_use] +pub fn ge(left: Expr, right: Expr) -> Expr { + Expr::BinaryOp { + left: Box::new(left), + op: BinaryOperator::GtEq, + right: Box::new(right), + } +} + +/// Construct a new `Expr` for A == B +#[must_use] +pub fn equal(left: Expr, right: Expr) -> Expr { + Expr::BinaryOp { + left: Box::new(left), + op: BinaryOperator::Eq, + right: Box::new(right), + } +} + +/// Construct a new `Expr` for P OR Q +#[must_use] +pub fn or(left: Expr, right: Expr) -> Expr { + Expr::BinaryOp { + left: Box::new(left), + op: BinaryOperator::Or, + right: Box::new(right), + } +} + +/// An expression with an alias, i.e., EXPR AS ALIAS +/// +/// # Panics +/// +/// This function will panic if the `alias` cannot be parsed as a valid [Identifier]. +pub fn aliased_expr(expr: Expr, alias: &str) -> SqlAliasedResultExpr { + SqlAliasedResultExpr { + expr: Box::new(expr), + alias: Ident::new(alias), + } +} diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation.rs b/crates/proof-of-sql/src/base/database/expression_evaluation.rs index d9df43097..01c33bebe 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation.rs @@ -1,26 +1,24 @@ use super::{ExpressionEvaluationError, ExpressionEvaluationResult}; use crate::base::{ database::{OwnedColumn, OwnedTable}, - math::{ - decimal::{try_convert_intermediate_decimal_to_scalar, DecimalError, Precision}, - BigDecimalExt, - }, + math::decimal::{try_convert_intermediate_decimal_to_scalar, DecimalError, Precision}, scalar::Scalar, }; use alloc::{format, string::ToString, vec}; -use proof_of_sql_parser::intermediate_ast::{Expression, Literal}; -use sqlparser::ast::{BinaryOperator, Ident, UnaryOperator}; +use bigdecimal::BigDecimal; +use 
proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::{ + BinaryOperator, DataType, ExactNumberInfo, Expr, Ident, UnaryOperator, Value, +}; impl OwnedTable { /// Evaluate an expression on the table. - pub fn evaluate(&self, expr: &Expression) -> ExpressionEvaluationResult> { + pub fn evaluate(&self, expr: &Expr) -> ExpressionEvaluationResult> { match expr { - Expression::Column(identifier) => self.evaluate_column(&Ident::from(*identifier)), - Expression::Literal(lit) => self.evaluate_literal(lit), - Expression::Binary { op, left, right } => { - self.evaluate_binary_expr(&(*op).into(), left, right) - } - Expression::Unary { op, expr } => self.evaluate_unary_expr((*op).into(), expr), + Expr::Identifier(ident) => self.evaluate_column(ident), + Expr::Value(_) | Expr::TypedString { .. } => self.evaluate_literal(expr), + Expr::BinaryOp { op, left, right } => self.evaluate_binary_expr(op, left, right), + Expr::UnaryOp { op, expr } => self.evaluate_unary_expr(*op, expr), _ => Err(ExpressionEvaluationError::Unsupported { expression: format!("Expression {expr:?} is not supported yet"), }), @@ -36,37 +34,81 @@ impl OwnedTable { })? .clone()) } - - fn evaluate_literal(&self, lit: &Literal) -> ExpressionEvaluationResult> { + /// Evaluates a literal expression and returns its corresponding column representation. + /// + /// # Panics + /// + /// This function will panic if: + /// - `BigDecimal::parse_bytes` fails to parse a valid decimal string. + /// - `Precision::try_from` fails due to invalid precision or scale values. 
+ fn evaluate_literal(&self, value: &Expr) -> ExpressionEvaluationResult> { let len = self.num_rows(); - match lit { - Literal::Boolean(b) => Ok(OwnedColumn::Boolean(vec![*b; len])), - Literal::BigInt(i) => Ok(OwnedColumn::BigInt(vec![*i; len])), - Literal::Int128(i) => Ok(OwnedColumn::Int128(vec![*i; len])), - Literal::Decimal(d) => { - let raw_scale = d.scale(); - let scale = raw_scale - .try_into() - .map_err(|_| DecimalError::InvalidScale { - scale: raw_scale.to_string(), + match value { + Expr::Value(Value::Boolean(b)) => Ok(OwnedColumn::Boolean(vec![*b; len])), + Expr::Value(Value::Number(n, _)) => { + let num = n + .parse::() + .map_err(|_| DecimalError::InvalidDecimal { + error: format!("Invalid number: {n}"), })?; - let precision = Precision::try_from(d.precision())?; - let scalar = try_convert_intermediate_decimal_to_scalar(d, precision, scale)?; - Ok(OwnedColumn::Decimal75(precision, scale, vec![scalar; len])) + if num >= i128::from(i64::MIN) && num <= i128::from(i64::MAX) { + Ok(OwnedColumn::BigInt(vec![num.try_into().unwrap(); len])) + } else { + Ok(OwnedColumn::Int128(vec![num; len])) + } } - Literal::VarChar(s) => Ok(OwnedColumn::VarChar(vec![s.clone(); len])), - Literal::Timestamp(its) => Ok(OwnedColumn::TimestampTZ( - its.timeunit(), - its.timezone(), - vec![its.timestamp().timestamp(); len], - )), + Expr::Value(Value::SingleQuotedString(s)) => { + Ok(OwnedColumn::VarChar(vec![s.clone(); len])) + } + Expr::TypedString { data_type, value } => match data_type { + DataType::Decimal(ExactNumberInfo::PrecisionAndScale(raw_precision, raw_scale)) => { + let decimal = BigDecimal::parse_bytes(value.as_bytes(), 10).unwrap(); + let precision = Precision::try_from(*raw_precision).map_err(|_| { + DecimalError::InvalidPrecision { + error: raw_precision.to_string(), + } + })?; + let scale = + i8::try_from(*raw_scale).map_err(|_| DecimalError::InvalidScale { + scale: raw_scale.to_string(), + })?; + let scalar = + 
try_convert_intermediate_decimal_to_scalar(&decimal, precision, scale)?; + Ok(OwnedColumn::Decimal75(precision, scale, vec![scalar; len])) + } + DataType::Timestamp(Some(time_unit), time_zone) => { + let time_unit = PoSQLTimeUnit::from_precision(*time_unit).map_err(|err| { + DecimalError::InvalidDecimal { + error: format!("Invalid time unit precision: {err}"), + } + })?; + + let timestamp_value = + value + .parse::() + .map_err(|_| DecimalError::InvalidDecimal { + error: format!("Invalid timestamp value: {value}"), + })?; + Ok(OwnedColumn::TimestampTZ( + time_unit, + *time_zone, + vec![timestamp_value; len], + )) + } + _ => Err(ExpressionEvaluationError::Unsupported { + expression: "Unsupported TypedString data type".to_string(), + }), + }, + _ => Err(ExpressionEvaluationError::Unsupported { + expression: "Unsupported expression type".to_string(), + }), } } fn evaluate_unary_expr( &self, op: UnaryOperator, - expr: &Expression, + expr: &Expr, ) -> ExpressionEvaluationResult> { let column = self.evaluate(expr)?; match op { @@ -81,8 +123,8 @@ impl OwnedTable { fn evaluate_binary_expr( &self, op: &BinaryOperator, - left: &Expression, - right: &Expression, + left: &Expr, + right: &Expr, ) -> ExpressionEvaluationResult> { let left = self.evaluate(left)?; let right = self.evaluate(right)?; diff --git a/crates/proof-of-sql/src/base/database/expression_evaluation_test.rs b/crates/proof-of-sql/src/base/database/expression_evaluation_test.rs index 1123dbea4..125743682 100644 --- a/crates/proof-of-sql/src/base/database/expression_evaluation_test.rs +++ b/crates/proof-of-sql/src/base/database/expression_evaluation_test.rs @@ -9,9 +9,10 @@ use crate::base::{ use bigdecimal::BigDecimal; use proof_of_sql_parser::{ intermediate_ast::Literal, - posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestamp}, + posql_time::{PoSQLTimeUnit, PoSQLTimestamp}, utility::*, }; +use sqlparser::ast::{Expr, TimezoneInfo}; #[test] fn we_can_evaluate_a_simple_literal() { @@ -20,13 +21,15 @@ fn 
we_can_evaluate_a_simple_literal() { // "Space and Time" in Hebrew let expr = lit("מרחב וזמן".to_string()); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column = OwnedColumn::VarChar(vec!["מרחב וזמן".to_string(); 5]); assert_eq!(actual_column, expected_column); // Is Proof of SQL in production? let expr = lit(true); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column = OwnedColumn::Boolean(vec![true; 5]); assert_eq!(actual_column, expected_column); @@ -35,19 +38,21 @@ fn we_can_evaluate_a_simple_literal() { let expr = lit(Literal::Timestamp( PoSQLTimestamp::try_from(timestamp).unwrap(), )); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); // UNIX timestamp for 2022-03-01T00:00:00Z let actual_timestamp = 1_646_092_800; let expected_column = OwnedColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, vec![actual_timestamp; 5], ); assert_eq!(actual_column, expected_column); // A group of people has about 0.67 cats per person let expr = lit("0.67".parse::().unwrap()); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column = OwnedColumn::Decimal75(Precision::new(2).unwrap(), 2, vec![67.into(); 5]); assert_eq!(actual_column, expected_column); } @@ -60,12 +65,14 @@ fn we_can_evaluate_a_simple_column() { varchar("john", ["John", "Juan", "João", "Jean", "Jean"]), ]); let expr = col("bigints"); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column = 
OwnedColumn::BigInt(vec![i64::MIN, -1, 0, 1, i64::MAX]); assert_eq!(actual_column, expected_column); let expr = col("john"); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column = OwnedColumn::VarChar( ["John", "Juan", "João", "Jean", "Jean"] .iter() @@ -81,8 +88,9 @@ fn we_can_not_evaluate_a_nonexisting_column() { owned_table([varchar("cats", ["Chloe", "Margaret", "Prudence", "Lucy"])]); // "not_a_column" is not a column in the table let expr = col("not_a_column"); + let sql_expr: Expr = (*expr).into(); assert!(matches!( - table.evaluate(&expr), + table.evaluate(&sql_expr), Err(ExpressionEvaluationError::ColumnNotFound { .. }) )); } @@ -101,20 +109,23 @@ fn we_can_evaluate_a_logical_expression() { // Find words that are not proper nouns let expr = not(col("is_proper_noun")); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column = OwnedColumn::Boolean(vec![false, false, true, true, false]); assert_eq!(actual_column, expected_column); // Which Czech and Slovak words agree? 
let expr = equal(col("cz"), col("sk")); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column: OwnedColumn = OwnedColumn::Boolean(vec![false, false, false, true, false]); assert_eq!(actual_column, expected_column); // Find words shared among Slovak, Croatian and Slovenian let expr = and(equal(col("sk"), col("hr")), equal(col("hr"), col("sl"))); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column: OwnedColumn = OwnedColumn::Boolean(vec![false, false, true, false, false]); assert_eq!(actual_column, expected_column); @@ -124,7 +135,8 @@ fn we_can_evaluate_a_logical_expression() { equal(col("pl"), col("cz")), not(equal(col("pl"), col("sl"))), ); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column: OwnedColumn = OwnedColumn::Boolean(vec![false, true, false, false, false]); assert_eq!(actual_column, expected_column); @@ -134,7 +146,8 @@ fn we_can_evaluate_a_logical_expression() { col("is_proper_noun"), and(equal(col("hr"), col("sl")), equal(col("hr"), col("sk"))), ); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column: OwnedColumn = OwnedColumn::Boolean(vec![true, true, true, false, true]); assert_eq!(actual_column, expected_column); @@ -152,13 +165,15 @@ fn we_can_evaluate_an_arithmetic_expression() { // Subtract 1 from the bigints let expr = sub(col("bigints"), lit(1)); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column = OwnedColumn::BigInt(vec![-9, -5, -1, 3, 7]); 
assert_eq!(actual_column, expected_column); // Add bigints to the smallints and multiply the sum by the ints let expr = mul(add(col("bigints"), col("smallints")), col("ints")); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_column = OwnedColumn::BigInt(vec![40, 10, 0, 10, 40]); assert_eq!(actual_column, expected_column); @@ -167,7 +182,8 @@ fn we_can_evaluate_an_arithmetic_expression() { col("smallints"), mul(col("decimals"), lit("0.75".parse::().unwrap())), ); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_scalars = [-2000, -925, 150, 1225, 2300] .iter() .map(|&x| x.into()) @@ -180,7 +196,8 @@ fn we_can_evaluate_an_arithmetic_expression() { div(col("decimals"), lit("2.5".parse::().unwrap())), col("int128s"), ); - let actual_column = table.evaluate(&expr).unwrap(); + let sql_expr: Expr = (*expr).into(); + let actual_column = table.evaluate(&sql_expr).unwrap(); let expected_scalars = [-16_000_000, -7_960_000, 80000, 8_120_000, 16_160_000] .iter() .map(|&x| x.into()) @@ -199,8 +216,9 @@ fn we_cannot_evaluate_expressions_if_column_operation_errors_out() { // NOT doesn't work on varchar let expr = not(col("language")); + let sql_expr: Expr = (*expr).into(); assert!(matches!( - table.evaluate(&expr), + table.evaluate(&sql_expr), Err(ExpressionEvaluationError::ColumnOperationError { source: ColumnOperationError::UnaryOperationInvalidColumnType { .. } }) @@ -208,8 +226,9 @@ fn we_cannot_evaluate_expressions_if_column_operation_errors_out() { // NOT doesn't work on bigint let expr = not(col("bigints")); + let sql_expr: Expr = (*expr).into(); assert!(matches!( - table.evaluate(&expr), + table.evaluate(&sql_expr), Err(ExpressionEvaluationError::ColumnOperationError { source: ColumnOperationError::UnaryOperationInvalidColumnType { .. 
} }) @@ -217,8 +236,9 @@ fn we_cannot_evaluate_expressions_if_column_operation_errors_out() { // + doesn't work on varchar let expr = add(col("sarah"), col("bigints")); + let sql_expr: Expr = (*expr).into(); assert!(matches!( - table.evaluate(&expr), + table.evaluate(&sql_expr), Err(ExpressionEvaluationError::ColumnOperationError { source: ColumnOperationError::BinaryOperationInvalidColumnType { .. } }) @@ -226,8 +246,9 @@ fn we_cannot_evaluate_expressions_if_column_operation_errors_out() { // i64::MIN - 1 overflows let expr = sub(col("bigints"), lit(1)); + let sql_expr: Expr = (*expr).into(); assert!(matches!( - table.evaluate(&expr), + table.evaluate(&sql_expr), Err(ExpressionEvaluationError::ColumnOperationError { source: ColumnOperationError::IntegerOverflow { .. } }) @@ -235,8 +256,9 @@ fn we_cannot_evaluate_expressions_if_column_operation_errors_out() { // We can't divide by zero let expr = div(col("bigints"), lit(0)); + let sql_expr: Expr = (*expr).clone().into(); assert!(matches!( - table.evaluate(&expr), + table.evaluate(&sql_expr), Err(ExpressionEvaluationError::ColumnOperationError { source: ColumnOperationError::DivisionByZero }) diff --git a/crates/proof-of-sql/src/base/database/literal_value.rs b/crates/proof-of-sql/src/base/database/literal_value.rs index d4cde1eb3..54faac1c6 100644 --- a/crates/proof-of-sql/src/base/database/literal_value.rs +++ b/crates/proof-of-sql/src/base/database/literal_value.rs @@ -1,78 +1,100 @@ -use crate::base::{ - database::ColumnType, - math::{decimal::Precision, i256::I256}, - scalar::Scalar, +use crate::{ + alloc::string::ToString, + base::{ + database::ColumnType, + math::{decimal::Precision, i256::I256}, + scalar::Scalar, + }, }; -use alloc::string::String; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; -use serde::{Deserialize, Serialize}; +use alloc::vec::Vec; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::{DataType, ExactNumberInfo, Expr, Value}; -/// Represents a 
literal value. -/// -/// Note: The types here should correspond to native SQL database types. -/// See `` for -/// a description of the native types used by Apache Ignite. -#[derive(Debug, Eq, PartialEq, Clone, Serialize, Deserialize)] -#[non_exhaustive] -pub enum LiteralValue { - /// Boolean literals - Boolean(bool), - /// i8 literals - TinyInt(i8), - /// i16 literals - SmallInt(i16), - /// i32 literals - Int(i32), - /// i64 literals - BigInt(i64), +/// A trait for SQL expressions that provides functionality to retrieve their associated column type. +/// This trait is primarily used to map SQL expressions to their corresponding [`ColumnType`]. +pub trait ExprExt { + /// Determines the [`ColumnType`] associated with the expression. + fn column_type(&self) -> ColumnType; +} - /// String literals - /// - the first element maps to the str value. - /// - the second element maps to the str hash (see [`crate::base::scalar::Scalar`]). - VarChar(String), - /// i128 literals - Int128(i128), - /// Decimal literals with a max width of 252 bits - /// - the backing store maps to the type [`crate::base::scalar::Curve25519Scalar`] - Decimal75(Precision, i8, I256), - /// Scalar literals. The underlying `[u64; 4]` is the limbs of the canonical form of the literal - Scalar([u64; 4]), - /// `TimeStamp` defined over a unit (s, ms, ns, etc) and timezone with backing store - /// mapped to i64, which is time units since unix epoch - TimeStampTZ(PoSQLTimeUnit, PoSQLTimeZone, i64), +/// A trait for SQL expressions that allows converting them into scalar values. +/// This trait provides functionality to interpret SQL expressions as scalars +pub trait ToScalar { + /// Converts the SQL expression into a scalar value of the specified type. 
+ fn to_scalar(&self) -> S; } -impl LiteralValue { +impl ExprExt for Expr { /// Provides the column type associated with the column #[must_use] - pub fn column_type(&self) -> ColumnType { + fn column_type(&self) -> ColumnType { match self { - Self::Boolean(_) => ColumnType::Boolean, - Self::TinyInt(_) => ColumnType::TinyInt, - Self::SmallInt(_) => ColumnType::SmallInt, - Self::Int(_) => ColumnType::Int, - Self::BigInt(_) => ColumnType::BigInt, - Self::VarChar(_) => ColumnType::VarChar, - Self::Int128(_) => ColumnType::Int128, - Self::Scalar(_) => ColumnType::Scalar, - Self::Decimal75(precision, scale, _) => ColumnType::Decimal75(*precision, *scale), - Self::TimeStampTZ(tu, tz, _) => ColumnType::TimestampTZ(*tu, *tz), + Expr::Value(Value::Boolean(_)) => ColumnType::Boolean, + Expr::Value(Value::Number(value, _)) => { + let n = value.parse::().unwrap_or_else(|err| { + panic!("Failed to parse '{value}' as a number. Error: {err}"); + }); + if i8::try_from(n).is_ok() { + ColumnType::TinyInt + } else if i16::try_from(n).is_ok() { + ColumnType::SmallInt + } else if i32::try_from(n).is_ok() { + ColumnType::Int + } else { + ColumnType::BigInt + } + } + Expr::Value(Value::SingleQuotedString(_)) => ColumnType::VarChar, + Expr::TypedString { data_type, .. 
} => match data_type { + DataType::Decimal(ExactNumberInfo::PrecisionAndScale(p, s)) => { + let precision = u8::try_from(*p).expect("Precision must fit into u8"); + let scale = i8::try_from(*s).expect("Scale must fit into i8"); + let precision_obj = + Precision::new(precision).expect("Failed to create Precision"); + ColumnType::Decimal75(precision_obj, scale) + } + DataType::Timestamp(Some(precision), tz) => { + let tu = + PoSQLTimeUnit::from_precision(*precision).unwrap_or(PoSQLTimeUnit::Second); + ColumnType::TimestampTZ(tu, *tz) + } + DataType::Custom(_, _) if data_type.to_string() == "scalar" => ColumnType::Scalar, + _ => unimplemented!("Mapping for {:?} is not implemented", data_type), + }, + _ => unimplemented!("Mapping for {:?} is not implemented", self), } } +} +impl ToScalar for Expr { /// Converts the literal to a scalar - pub(crate) fn to_scalar(&self) -> S { + fn to_scalar(&self) -> S { match self { - Self::Boolean(b) => b.into(), - Self::TinyInt(i) => i.into(), - Self::SmallInt(i) => i.into(), - Self::Int(i) => i.into(), - Self::BigInt(i) => i.into(), - Self::VarChar(str) => str.into(), - Self::Decimal75(_, _, i) => i.into_scalar(), - Self::Int128(i) => i.into(), - Self::Scalar(limbs) => (*limbs).into(), - Self::TimeStampTZ(_, _, time) => time.into(), + Expr::Value(Value::Boolean(b)) => b.into(), + Expr::Value(Value::Number(n, _)) => n + .parse::() + .unwrap_or_else(|_| panic!("Invalid number: {n}")) + .into(), + Expr::Value(Value::SingleQuotedString(s)) => s.into(), + Expr::TypedString { data_type, value } if data_type.to_string() == "scalar" => { + let scalar_str = value.strip_prefix("scalar:").unwrap(); + let limbs: Vec = scalar_str + .split(',') + .map(|x| x.parse::().unwrap()) + .collect(); + assert!(limbs.len() == 4, "Scalar must have exactly 4 limbs"); + S::from([limbs[0], limbs[1], limbs[2], limbs[3]]) + } + Expr::TypedString { data_type, value } => match data_type { + DataType::Timestamp(_, _) => value.parse::().unwrap().into(), + 
DataType::Decimal(_) => { + let i256_value = I256::from_string(value) + .unwrap_or_else(|_| panic!("Failed to parse '{value}' as a decimal")); + i256_value.into_scalar() + } + _ => unimplemented!("Conversion for {:?} is not implemented.", data_type), + }, + _ => unimplemented!("Conversion for {:?} is not implemented", self), } } } diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index 03f80f23c..30584c1c7 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -12,6 +12,9 @@ mod slice_operation; mod slice_decimal_operation; +/// util functions for `Expr` tests +pub mod expr_utility; + mod column_type_operation; pub use column_type_operation::{ try_add_subtract_column_types, try_divide_column_types, try_multiply_column_types, @@ -40,8 +43,9 @@ pub use table_operation_error::{TableOperationError, TableOperationResult}; mod columnar_value; pub use columnar_value::ColumnarValue; -mod literal_value; -pub use literal_value::LiteralValue; +/// TODO: add docs +pub mod literal_value; +pub use literal_value::{ExprExt, ToScalar}; mod table_ref; #[cfg(feature = "arrow")] @@ -126,3 +130,5 @@ mod order_by_util_test; #[allow(dead_code)] pub(crate) mod join_util; + +pub use column::ColumnError; diff --git a/crates/proof-of-sql/src/base/database/owned_column.rs b/crates/proof-of-sql/src/base/database/owned_column.rs index ad5b96ffc..779f5aaca 100644 --- a/crates/proof-of-sql/src/base/database/owned_column.rs +++ b/crates/proof-of-sql/src/base/database/owned_column.rs @@ -16,8 +16,9 @@ use alloc::{ vec::Vec, }; use itertools::Itertools; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; use serde::{Deserialize, Serialize}; +use sqlparser::ast::TimezoneInfo; #[derive(Debug, PartialEq, Clone, Eq, Serialize, Deserialize)] #[non_exhaustive] @@ -42,7 +43,7 @@ pub enum OwnedColumn { /// Scalar columns 
Scalar(Vec), /// Timestamp columns - TimestampTZ(PoSQLTimeUnit, PoSQLTimeZone, Vec), + TimestampTZ(PoSQLTimeUnit, TimezoneInfo, Vec), } impl OwnedColumn { diff --git a/crates/proof-of-sql/src/base/database/owned_table.rs b/crates/proof-of-sql/src/base/database/owned_table.rs index 2218c5e42..167b39bfc 100644 --- a/crates/proof-of-sql/src/base/database/owned_table.rs +++ b/crates/proof-of-sql/src/base/database/owned_table.rs @@ -8,7 +8,6 @@ use itertools::{EitherOrBoth, Itertools}; use serde::{Deserialize, Serialize}; use snafu::Snafu; use sqlparser::ast::Ident; - /// An error that occurs when working with tables. #[derive(Snafu, Debug, PartialEq, Eq)] pub enum OwnedTableError { @@ -198,7 +197,8 @@ mod tests { scalar::test_scalar::TestScalar, }; use bumpalo::Bump; - use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; + use proof_of_sql_parser::posql_time::PoSQLTimeUnit; + use sqlparser::ast::TimezoneInfo; #[test] fn test_conversion_from_table_to_owned_table() { @@ -229,7 +229,7 @@ mod tests { borrowed_timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX], &alloc, ), @@ -247,7 +247,7 @@ mod tests { timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX], ), ]); diff --git a/crates/proof-of-sql/src/base/database/owned_table_test.rs b/crates/proof-of-sql/src/base/database/owned_table_test.rs index 183e5c870..a3bbc1436 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_test.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_test.rs @@ -6,8 +6,9 @@ use crate::{ }, proof_primitive::dory::DoryScalar, }; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; -use sqlparser::ast::Ident; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::{Ident, TimezoneInfo}; + #[test] fn we_can_create_an_owned_table_with_no_columns() { let table 
= OwnedTable::::try_new(IndexMap::default()).unwrap(); @@ -44,7 +45,7 @@ fn we_can_create_an_owned_table_with_data() { timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX], ), ]); @@ -53,7 +54,7 @@ fn we_can_create_an_owned_table_with_data() { Ident::new("time_stamp"), OwnedColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX].into(), ), ); @@ -111,7 +112,7 @@ fn we_get_inequality_between_tables_with_differing_column_order() { timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0; 0], ), ]); @@ -123,7 +124,7 @@ fn we_get_inequality_between_tables_with_differing_column_order() { timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0; 0], ), ]); @@ -139,7 +140,7 @@ fn we_get_inequality_between_tables_with_differing_data() { timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [1_625_072_400], ), ]); @@ -151,7 +152,7 @@ fn we_get_inequality_between_tables_with_differing_data() { timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [1_625_076_000], ), ]); diff --git a/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs b/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs index 0eefd4cf3..f566e4eb8 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_test_accessor_test.rs @@ -10,7 +10,8 @@ use crate::base::{ database::owned_table_utility::*, scalar::test_scalar::TestScalar, }; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::TimezoneInfo; #[test] fn we_can_query_the_length_of_a_table() { @@ -55,7 +56,7 @@ fn 
we_can_access_the_columns_of_a_table() { timestamptz( "time", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::WithTimeZone, [4, 5, 6, 5], ), ]); @@ -116,7 +117,7 @@ fn we_can_access_the_columns_of_a_table() { let column = ColumnRef::new( table_ref_2, "time".into(), - ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()), + ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::WithTimeZone), ); match accessor.get_column(column) { Column::TimestampTZ(_, _, col) => assert_eq!(col.to_vec(), vec![4, 5, 6, 5]), diff --git a/crates/proof-of-sql/src/base/database/owned_table_utility.rs b/crates/proof-of-sql/src/base/database/owned_table_utility.rs index 97a695ca4..d87ab703e 100644 --- a/crates/proof-of-sql/src/base/database/owned_table_utility.rs +++ b/crates/proof-of-sql/src/base/database/owned_table_utility.rs @@ -16,8 +16,8 @@ use super::{OwnedColumn, OwnedTable}; use crate::base::scalar::Scalar; use alloc::string::String; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; -use sqlparser::ast::Ident; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::{Ident, TimezoneInfo}; /// Creates an [`OwnedTable`] from a list of `(Ident, OwnedColumn)` pairs. 
/// This is a convenience wrapper around [`OwnedTable::try_from_iter`] primarily for use in tests and @@ -242,18 +242,17 @@ pub fn decimal75( /// use proof_of_sql::base::{database::owned_table_utility::*, /// scalar::Curve25519Scalar, /// }; -/// use proof_of_sql_parser::{ -/// posql_time::{PoSQLTimeZone, PoSQLTimeUnit}}; +/// use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +/// use sqlparser::ast::TimezoneInfo; /// /// let result = owned_table::([ -/// timestamptz("event_time", PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), vec![1625072400, 1625076000, 1625079600]), +/// timestamptz("event_time", PoSQLTimeUnit::Second, TimezoneInfo::WithTimeZone, vec![1625072400, 1625076000, 1625079600]), /// ]); /// ``` - pub fn timestamptz( name: impl Into, time_unit: PoSQLTimeUnit, - timezone: PoSQLTimeZone, + timezone: TimezoneInfo, data: impl IntoIterator, ) -> (Ident, OwnedColumn) { ( diff --git a/crates/proof-of-sql/src/base/database/table_test.rs b/crates/proof-of-sql/src/base/database/table_test.rs index 6f9eaef13..b145c985b 100644 --- a/crates/proof-of-sql/src/base/database/table_test.rs +++ b/crates/proof-of-sql/src/base/database/table_test.rs @@ -4,8 +4,8 @@ use crate::base::{ scalar::test_scalar::TestScalar, }; use bumpalo::Bump; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; -use sqlparser::ast::Ident; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::{Ident, TimezoneInfo}; #[test] fn we_can_create_a_table_with_no_columns_specifying_row_count() { let table = @@ -151,7 +151,7 @@ fn we_can_create_a_table_with_data() { borrowed_timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX], &alloc, ), @@ -162,7 +162,7 @@ fn we_can_create_a_table_with_data() { let time_stamp_data = alloc.alloc_slice_copy(&[0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]); expected_table.insert( Ident::new("time_stamp"), - Column::TimestampTZ(PoSQLTimeUnit::Second, 
PoSQLTimeZone::utc(), time_stamp_data), + Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, time_stamp_data), ); let bigint_data = alloc.alloc_slice_copy(&[0_i64, 1, 2, 3, 4, 5, 6, i64::MIN, i64::MAX]); @@ -206,7 +206,7 @@ fn we_get_inequality_between_tables_with_differing_column_order() { borrowed_timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0_i64; 0], &alloc, ), @@ -220,7 +220,7 @@ fn we_get_inequality_between_tables_with_differing_column_order() { borrowed_timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0_i64; 0], &alloc, ), @@ -241,7 +241,7 @@ fn we_get_inequality_between_tables_with_differing_data() { borrowed_timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [1_625_072_400], &alloc, ), @@ -255,7 +255,7 @@ fn we_get_inequality_between_tables_with_differing_data() { borrowed_timestamptz( "time_stamp", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [1_625_076_000], &alloc, ), diff --git a/crates/proof-of-sql/src/base/database/table_test_accessor_test.rs b/crates/proof-of-sql/src/base/database/table_test_accessor_test.rs index a594fe151..c310b0147 100644 --- a/crates/proof-of-sql/src/base/database/table_test_accessor_test.rs +++ b/crates/proof-of-sql/src/base/database/table_test_accessor_test.rs @@ -11,7 +11,8 @@ use crate::base::{ scalar::test_scalar::TestScalar, }; use bumpalo::Bump; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::TimezoneInfo; #[test] fn we_can_query_the_length_of_a_table() { @@ -67,7 +68,7 @@ fn we_can_access_the_columns_of_a_table() { borrowed_timestamptz( "time", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [4, 5, 6, 5], &alloc, ), @@ -129,7 +130,7 @@ fn we_can_access_the_columns_of_a_table() { let column = ColumnRef::new( table_ref_2, 
"time".into(), - ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()), + ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None), ); match accessor.get_column(column) { Column::TimestampTZ(_, _, col) => assert_eq!(col.to_vec(), vec![4, 5, 6, 5]), diff --git a/crates/proof-of-sql/src/base/database/table_utility.rs b/crates/proof-of-sql/src/base/database/table_utility.rs index 25458bac3..dcaed3911 100644 --- a/crates/proof-of-sql/src/base/database/table_utility.rs +++ b/crates/proof-of-sql/src/base/database/table_utility.rs @@ -19,8 +19,8 @@ use super::{Column, Table, TableOptions}; use crate::base::scalar::Scalar; use alloc::{string::String, vec::Vec}; use bumpalo::Bump; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; -use sqlparser::ast::Ident; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::{Ident, TimezoneInfo}; /// Creates an [`Table`] from a list of `(Ident, Column)` pairs. /// This is a convenience wrapper around [`Table::try_from_iter`] primarily for use in tests and @@ -301,19 +301,18 @@ pub fn borrowed_decimal75( /// use proof_of_sql::base::{database::table_utility::*, /// scalar::Curve25519Scalar, /// }; -/// use proof_of_sql_parser::{ -/// posql_time::{PoSQLTimeZone, PoSQLTimeUnit}}; -/// +/// use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +/// use sqlparser::ast::TimezoneInfo; /// let alloc = Bump::new(); /// let result = table::([ -/// borrowed_timestamptz("event_time", PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), vec![1625072400, 1625076000, 1625079600], &alloc), +/// borrowed_timestamptz("event_time", PoSQLTimeUnit::Second, TimezoneInfo::None,vec![1625072400, 1625076000, 1625079600], &alloc), /// ]); /// ``` pub fn borrowed_timestamptz( name: impl Into, time_unit: PoSQLTimeUnit, - timezone: PoSQLTimeZone, + timezone: TimezoneInfo, data: impl IntoIterator, alloc: &Bump, ) -> (Ident, Column<'_, S>) { diff --git a/crates/proof-of-sql/src/base/math/big_decimal_ext.rs 
b/crates/proof-of-sql/src/base/math/big_decimal_ext.rs index 34ded4040..19e86c70b 100644 --- a/crates/proof-of-sql/src/base/math/big_decimal_ext.rs +++ b/crates/proof-of-sql/src/base/math/big_decimal_ext.rs @@ -3,6 +3,7 @@ use bigdecimal::BigDecimal; use num_bigint::BigInt; pub trait BigDecimalExt { + #[allow(dead_code)] fn precision(&self) -> u64; fn scale(&self) -> i64; fn try_into_bigint_with_precision_and_scale( @@ -14,6 +15,7 @@ pub trait BigDecimalExt { impl BigDecimalExt for BigDecimal { /// Get the precision of the fixed-point representation of this intermediate decimal. #[must_use] + #[allow(dead_code)] fn precision(&self) -> u64 { self.normalized().digits() } diff --git a/crates/proof-of-sql/src/base/math/i256.rs b/crates/proof-of-sql/src/base/math/i256.rs index b088ce082..3df8de65b 100644 --- a/crates/proof-of-sql/src/base/math/i256.rs +++ b/crates/proof-of-sql/src/base/math/i256.rs @@ -1,4 +1,5 @@ use crate::base::scalar::Scalar; +use alloc::{fmt, format, string::String, vec::Vec}; use ark_ff::BigInteger; use serde::{Deserialize, Serialize}; @@ -44,7 +45,31 @@ impl I256 { num_bigint::Sign::Plus | num_bigint::Sign::NoSign => Self(limbs), } } + + /// Creates an `I256` instance from a string representation of a number. + pub fn from_string(value: &str) -> Result { + let bigint = num_bigint::BigInt::parse_bytes(value.as_bytes(), 10) + .ok_or_else(|| format!("Failed to parse '{value}' as a valid number"))?; + Ok(Self::from_num_bigint(&bigint)) + } + + /// Converts `I256` into a little-endian byte array for compatibility with `BigInt`. 
+ fn as_bytes_le(&self) -> Vec { let mut bytes = Vec::with_capacity(32); for limb in &self.0 { bytes.extend_from_slice(&limb.to_le_bytes()); } bytes } } + +impl fmt::Display for I256 { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let bigint = num_bigint::BigInt::from_signed_bytes_le(&self.as_bytes_le()); + write!(f, "{bigint}") + } +} + impl From for I256 { fn from(value: i32) -> Self { let abs = Self([value.unsigned_abs().into(), 0, 0, 0]); diff --git a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/error.rs b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/error.rs index 61e0c1208..0b5c1d564 100644 --- a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/error.rs +++ b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/error.rs @@ -1,3 +1,4 @@ +use alloc::string::String; use snafu::Snafu; /// Errors that can occur during proof plan serialization. @@ -22,4 +23,8 @@ pub enum ProofPlanSerializationError { /// Error indicating that the column was not found. #[snafu(display("Column not found"))] ColumnNotFound, + + /// Error indicating an invalid number format.
+ #[snafu(display("Invalid number format: {value:?}"))] + InvalidNumberFormat { value: String }, } diff --git a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_expr.rs b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_expr.rs index b7f04d738..45cef39ca 100644 --- a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_expr.rs +++ b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_expr.rs @@ -4,11 +4,12 @@ use super::{ DynProofPlanSerializer, ProofPlanSerializationError, }; use crate::{ - base::{database::LiteralValue, scalar::Scalar}, + base::scalar::Scalar, evm_compatibility::primitive_serialize_ext::PrimitiveSerializeExt, sql::proof_exprs::{ColumnExpr, DynProofExpr, EqualsExpr, LiteralExpr}, }; use snafu::OptionExt; +use sqlparser::ast::{Expr, Value}; impl DynProofPlanSerializer { pub(super) fn serialize_dyn_proof_expr( @@ -45,10 +46,17 @@ impl DynProofPlanSerializer { self, literal_expr: &LiteralExpr, ) -> Result { - match literal_expr.value { - LiteralValue::BigInt(value) => Ok(self - .serialize_u8(BIGINT_TYPE_NUM) - .serialize_scalar(value.into())), + match &literal_expr.value { + Expr::Value(Value::Number(value, _)) => { + let parsed_value = value.parse::().map_err(|_| { + ProofPlanSerializationError::InvalidNumberFormat { + value: value.clone(), + } + })?; + Ok(self + .serialize_u8(BIGINT_TYPE_NUM) + .serialize_scalar(parsed_value.into())) + } _ => NotSupportedSnafu.fail(), } } @@ -75,6 +83,7 @@ mod tests { }; use core::iter; use itertools::Itertools; + use sqlparser::ast::{Expr as SqlExpr, Value}; #[test] fn we_can_serialize_a_column_expr() { @@ -135,7 +144,8 @@ mod tests { // Serialization of a big int literal should result in a byte with the big int type number, // followed by the big int value in big-endian form, padded with leading zeros to 32 bytes. 
- let literal_bigint_expr = LiteralExpr::new(LiteralValue::BigInt(4200)); + let literal_bigint_expr = + LiteralExpr::new(SqlExpr::Value(Value::Number("4200".to_string(), false))); let bigint_bytes = serializer .clone() .serialize_literal_expr(&literal_bigint_expr) @@ -164,7 +174,8 @@ mod tests { // Serialization of a small int literal should result in an error // because only big int literals are supported so far - let literal_smallint_expr = LiteralExpr::new(LiteralValue::SmallInt(4200)); + let literal_smallint_expr = + LiteralExpr::new(SqlExpr::Value(Value::Number("4200".to_string(), false))); let result = serializer .clone() .serialize_literal_expr(&literal_smallint_expr); @@ -186,7 +197,10 @@ mod tests { .unwrap(); let lhs = DynProofExpr::Column(ColumnExpr::new(column_0_ref)); - let rhs = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); + let rhs = DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value(Value::Number( + "4200".to_string(), + false, + )))); let lhs_bytes = serializer .clone() .serialize_dyn_proof_expr(&lhs) @@ -239,7 +253,10 @@ mod tests { .unwrap(); let lhs = DynProofExpr::Column(ColumnExpr::new(column_0_ref)); - let rhs = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); + let rhs = DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value(Value::Number( + "4200".to_string(), + false, + )))); let expr = DynProofExpr::And(AndExpr::new(Box::new(lhs.clone()), Box::new(rhs.clone()))); let result = serializer.clone().serialize_dyn_proof_expr(&expr); assert!(matches!( diff --git a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_plan.rs b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_plan.rs index a6b45db24..2937ca896 100644 --- a/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_plan.rs +++ b/crates/proof-of-sql/src/evm_compatibility/dyn_proof_plan_serializer/serialize_proof_plan.rs @@ -69,18 +69,22 @@ impl 
DynProofPlanSerializer { mod tests { use super::*; use crate::{ - base::{database::LiteralValue, map::indexset, scalar::test_scalar::TestScalar}, + base::{map::indexset, scalar::test_scalar::TestScalar}, sql::proof_exprs::{DynProofExpr, LiteralExpr}, }; use core::iter; use itertools::Itertools; + use sqlparser::ast::{Expr as SqlExpr, Value}; #[test] fn we_can_serialize_an_aliased_dyn_proof_expr() { let serializer = DynProofPlanSerializer::::try_new(indexset! {}, indexset! {}).unwrap(); - let expr = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); + let expr = DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value(Value::Number( + "4200".to_string(), + false, + )))); let expr_bytes = serializer .clone() .serialize_dyn_proof_expr(&expr) @@ -150,9 +154,18 @@ mod tests { DynProofPlanSerializer::::try_new(indexset! { table_ref }, indexset! {}) .unwrap(); - let expr_a = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); - let expr_b = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); - let expr_c = DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(4200))); + let expr_a = DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value(Value::Number( + "4200".to_string(), + false, + )))); + let expr_b = DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value(Value::Number( + "4200".to_string(), + false, + )))); + let expr_c = DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value(Value::Number( + "4200".to_string(), + false, + )))); let aliased_expr_0 = AliasedDynProofExpr { expr: expr_a.clone(), alias: "alias_0".into(), diff --git a/crates/proof-of-sql/src/evm_compatibility/serialize_query_expr.rs b/crates/proof-of-sql/src/evm_compatibility/serialize_query_expr.rs index 5a97ed2d2..769c3a1fc 100644 --- a/crates/proof-of-sql/src/evm_compatibility/serialize_query_expr.rs +++ b/crates/proof-of-sql/src/evm_compatibility/serialize_query_expr.rs @@ -4,7 +4,6 @@ use crate::{ sql::{parse::QueryExpr, proof::ProofPlan}, }; use 
alloc::vec::Vec; - /// Serializes a `QueryExpr` into a vector of bytes. /// /// This function takes a `QueryExpr` and attempts to serialize it into a vector of bytes. @@ -45,7 +44,7 @@ pub fn serialize_query_expr( mod tests { use crate::{ base::{ - database::{ColumnRef, ColumnType, LiteralValue}, + database::{ColumnRef, ColumnType}, map::indexset, scalar::test_scalar::TestScalar, }, @@ -66,6 +65,7 @@ mod tests { }; use core::iter; use itertools::Itertools; + use sqlparser::ast::{Expr as SqlExpr, Value}; #[test] fn we_can_generate_serialized_proof_plan_for_query_expr() { @@ -74,11 +74,17 @@ mod tests { let plan = DynProofPlan::Filter(FilterExec::new( vec![AliasedDynProofExpr { - expr: DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(1001))), + expr: DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value(Value::Number( + "1001".to_string(), + false, + )))), alias: identifier_alias, }], TableExpr { table_ref }, - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(1002))), + DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value(Value::Number( + "1002".to_string(), + false, + )))), )); // Serializing a query expression without postprocessing steps should succeed and @@ -136,9 +142,9 @@ mod tests { TableExpr { table_ref }, DynProofExpr::Equals(EqualsExpr::new( Box::new(DynProofExpr::Column(ColumnExpr::new(column_ref_a))), - Box::new(DynProofExpr::Literal(LiteralExpr::new( - LiteralValue::BigInt(5), - ))), + Box::new(DynProofExpr::Literal(LiteralExpr::new(SqlExpr::Value( + Value::Number("5".to_string(), false), + )))), )), )); diff --git a/crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs b/crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs index a79973857..d914fc99c 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/blitzar_metadata_table.rs @@ -291,7 +291,8 @@ pub fn create_blitzar_metadata_tables( mod tests { use super::*; use
crate::base::math::decimal::Precision; - use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; + use proof_of_sql_parser::posql_time::PoSQLTimeUnit; + use sqlparser::ast::TimezoneInfo; fn assert_blitzar_metadata( committable_columns: &[CommittableColumn], @@ -633,7 +634,7 @@ mod tests { CommittableColumn::Decimal75(Precision::new(1).unwrap(), 0, vec![[6, 0, 0, 0]]), CommittableColumn::Scalar(vec![[7, 0, 0, 0]]), CommittableColumn::VarChar(vec![[8, 0, 0, 0]]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[9]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[9]), CommittableColumn::Boolean(&[true]), ]; @@ -667,7 +668,7 @@ mod tests { CommittableColumn::Decimal75(Precision::new(1).unwrap(), 0, vec![[6, 0, 0, 0]]), CommittableColumn::Scalar(vec![[7, 0, 0, 0]]), CommittableColumn::VarChar(vec![[8, 0, 0, 0]]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[9]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[9]), CommittableColumn::Boolean(&[true]), ]; @@ -709,7 +710,7 @@ mod tests { CommittableColumn::Decimal75(Precision::new(1).unwrap(), 0, vec![[6, 0, 0, 0]]), CommittableColumn::Scalar(vec![[7, 0, 0, 0]]), CommittableColumn::VarChar(vec![[8, 0, 0, 0]]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[9]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[9]), CommittableColumn::Boolean(&[true]), ]; diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dory_compute_commitments_test.rs b/crates/proof-of-sql/src/proof_primitive/dory/dory_compute_commitments_test.rs index b73c956bc..847525052 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dory_compute_commitments_test.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dory_compute_commitments_test.rs @@ -7,7 +7,8 @@ use crate::{ use ark_ec::pairing::Pairing; use ark_std::test_rng; use num_traits::Zero; 
-use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::TimezoneInfo; #[test] fn we_can_compute_a_dory_commitment_with_int128_values() { @@ -414,7 +415,7 @@ fn we_can_compute_a_dory_commitment_with_mixed_committable_columns_with_fewer_ro CommittableColumn::VarChar(vec![[16, 0, 0, 0]]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[17, 18, 19, 20], ), ], @@ -491,7 +492,7 @@ fn we_can_compute_a_dory_commitment_with_mixed_committable_columns_with_an_offse CommittableColumn::VarChar(vec![[16, 0, 0, 0]]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[17, 18, 19, 20], ), ], @@ -567,7 +568,7 @@ fn we_can_compute_a_dory_commitment_with_mixed_committable_columns_with_signed_v CommittableColumn::VarChar(vec![[16, 0, 0, 0]]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[-18, -17, 17, 18], ), ], @@ -656,7 +657,7 @@ fn we_can_compute_a_dory_commitment_with_mixed_committable_columns_with_an_offse CommittableColumn::VarChar(vec![[16, 0, 0, 0]]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[-18, -17, 17, 18], ), ], diff --git a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_compute_commitments_test.rs b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_compute_commitments_test.rs index 25f81f7c3..952cb482e 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_compute_commitments_test.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/dynamic_dory_compute_commitments_test.rs @@ -6,7 +6,8 @@ use crate::{ }; use ark_ec::pairing::Pairing; use num_traits::Zero; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::TimezoneInfo; #[test] fn 
we_can_handle_calling_with_an_empty_committable_column() { @@ -247,7 +248,7 @@ fn we_can_compute_a_dynamic_dory_commitment_with_mixed_committable_columns() { CommittableColumn::VarChar(vec![[16, 0, 0, 0]]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[17, 18, 19, 20], ), ], @@ -322,7 +323,7 @@ fn we_can_compute_a_dynamic_dory_commitment_with_mixed_committable_columns_with_ CommittableColumn::VarChar(vec![[16, 0, 0, 0]]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[17, 18, 19, 20], ), ], @@ -397,7 +398,7 @@ fn we_can_compute_a_dynamic_dory_commitment_with_mixed_committable_columns_with_ CommittableColumn::VarChar(vec![[16, 0, 0, 0]]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[-18, -17, 17, 18], ), ], @@ -485,7 +486,7 @@ fn we_can_compute_a_dynamic_dory_commitment_with_mixed_committable_columns_with_ CommittableColumn::VarChar(vec![[16, 0, 0, 0]]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, &[-18, -17, 17, 18], ), ], diff --git a/crates/proof-of-sql/src/proof_primitive/dory/pack_scalars.rs b/crates/proof-of-sql/src/proof_primitive/dory/pack_scalars.rs index c39bfcbcc..9c5c0f5b4 100644 --- a/crates/proof-of-sql/src/proof_primitive/dory/pack_scalars.rs +++ b/crates/proof-of-sql/src/proof_primitive/dory/pack_scalars.rs @@ -468,7 +468,8 @@ pub fn bit_table_and_scalars_for_packed_msm( mod tests { use super::*; use crate::base::math::decimal::Precision; - use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; + use proof_of_sql_parser::posql_time::PoSQLTimeUnit; + use sqlparser::ast::TimezoneInfo; #[test] fn we_can_get_a_bit_table() { @@ -491,7 +492,7 @@ mod tests { CommittableColumn::Scalar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0], [4, 0, 0, 0]]), CommittableColumn::VarChar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]), 
CommittableColumn::Boolean(&[true, false]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[1]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::WithTimeZone, &[1]), ]; let offset = 0; @@ -526,7 +527,7 @@ mod tests { CommittableColumn::Scalar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0], [4, 0, 0, 0]]), CommittableColumn::VarChar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]), CommittableColumn::Boolean(&[true, false]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[1]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::WithTimeZone, &[1]), ]; let offset = 1; @@ -561,7 +562,7 @@ mod tests { CommittableColumn::Scalar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0], [4, 0, 0, 0]]), CommittableColumn::VarChar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]), CommittableColumn::Boolean(&[true, false]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[1]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::WithTimeZone, &[1]), ]; let offset = 0; @@ -601,7 +602,7 @@ mod tests { CommittableColumn::Scalar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0], [4, 0, 0, 0]]), CommittableColumn::VarChar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]), CommittableColumn::Boolean(&[true, false]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[1]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::WithTimeZone, &[1]), ]; let offset = 2; @@ -641,7 +642,7 @@ mod tests { CommittableColumn::Scalar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0], [4, 0, 0, 0]]), CommittableColumn::VarChar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]), CommittableColumn::Boolean(&[true, false]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[1]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::WithTimeZone, &[1]), ]; let offset = 0; @@ -681,7 +682,7 
@@ mod tests { CommittableColumn::Scalar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0], [4, 0, 0, 0]]), CommittableColumn::VarChar(vec![[1, 0, 0, 0], [2, 0, 0, 0], [3, 0, 0, 0]]), CommittableColumn::Boolean(&[true, false]), - CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[1]), + CommittableColumn::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::WithTimeZone, &[1]), ]; let offset = 1; @@ -1023,7 +1024,7 @@ mod tests { CommittableColumn::Boolean(&[true, false, true, false, true]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::WithTimeZone, &[1, 2, 3, 4, 5], ), ]; @@ -1062,7 +1063,7 @@ mod tests { CommittableColumn::Boolean(&[true, false, true, false, true]), CommittableColumn::TimestampTZ( PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::WithTimeZone, &[1, 2, 3, 4, 5], ), ]; diff --git a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs index a0e76a031..3a299e920 100644 --- a/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/dyn_proof_expr_builder.rs @@ -1,28 +1,14 @@ use super::ConversionError; use crate::{ - base::{ - database::{ColumnRef, LiteralValue}, - map::IndexMap, - math::{ - decimal::{DecimalError, Precision}, - i256::I256, - BigDecimalExt, - }, - }, - sql::{ - parse::{ - dyn_proof_expr_builder::DecimalError::{InvalidPrecision, InvalidScale}, - ConversionError::DecimalConversionError, - }, - proof_exprs::{ColumnExpr, DynProofExpr, ProofExpr}, - }, + base::{database::ColumnRef, map::IndexMap, math::i256::I256}, + sql::proof_exprs::{ColumnExpr, DynProofExpr, ProofExpr}, }; -use alloc::{borrow::ToOwned, boxed::Box, format, string::ToString}; -use proof_of_sql_parser::{ - intermediate_ast::{AggregationOperator, Expression, Literal}, - posql_time::{PoSQLTimeUnit, PoSQLTimestampError}, +use alloc::{boxed::Box, format, string::ToString, vec, 
vec::Vec}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::{ + BinaryOperator, DataType, ExactNumberInfo, Expr, FunctionArg, FunctionArgExpr, Ident, + ObjectName, UnaryOperator, Value, }; -use sqlparser::ast::{BinaryOperator, Ident, UnaryOperator}; /// Builder that enables building a `proofs::sql::proof_exprs::DynProofExpr` from /// a `proof_of_sql_parser::intermediate_ast::Expression`. @@ -46,8 +32,8 @@ impl<'a> DynProofExprBuilder<'a> { in_agg_scope: true, } } - /// Builds a `proofs::sql::proof_exprs::DynProofExpr` from a `proof_of_sql_parser::intermediate_ast::Expression` - pub fn build(&self, expr: &Expression) -> Result { + /// Builds a `proofs::sql::proof_exprs::DynProofExpr` from a `sqlparser::ast::Expr` + pub fn build(&self, expr: &Expr) -> Result { self.visit_expr(expr) } } @@ -55,15 +41,24 @@ impl<'a> DynProofExprBuilder<'a> { #[allow(clippy::match_wildcard_for_single_variants)] // Private interface impl DynProofExprBuilder<'_> { - fn visit_expr(&self, expr: &Expression) -> Result { + fn visit_expr(&self, expr: &Expr) -> Result { match expr { - Expression::Column(identifier) => self.visit_column((*identifier).into()), - Expression::Literal(lit) => self.visit_literal(lit), - Expression::Binary { op, left, right } => { - self.visit_binary_expr(&(*op).into(), left, right) + Expr::Identifier(identifier) => self.visit_column(identifier.clone()), + Expr::Value(value) => self.visit_literal(&Expr::Value(value.clone())), + Expr::BinaryOp { op, left, right } => { + self.visit_binary_expr(op, left.as_ref(), right.as_ref()) + } + Expr::UnaryOp { op, expr } => self.visit_unary_expr(*op, expr.as_ref()), + Expr::Function(function) => { + if let Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(inner_expr))) = + function.args.first() + { + return self.visit_aggregate_expr(&function.name.to_string(), inner_expr); + } + Err(ConversionError::Unprovable { + error: format!("Function {function:?} has unsupported arguments"), + }) } - Expression::Unary { 
op, expr } => self.visit_unary_expr((*op).into(), expr), - Expression::Aggregation { op, expr } => self.visit_aggregate_expr(*op, expr), _ => Err(ConversionError::Unprovable { error: format!("Expression {expr:?} is not supported yet"), }), @@ -81,60 +76,96 @@ impl DynProofExprBuilder<'_> { ))) } + /// Converts a `Expr` into a `DynProofExpr` + /// + /// # Panics + /// - Panics if: + /// - `u8::try_from` for precision fails (precision out of range). + /// - `i8::try_from` for scale fails (scale out of range). + /// - A scalar string does not contain exactly 4 limbs. + /// - Parsing scalar limbs fails. + /// + /// # Examples + /// ``` + /// let expr = Expr::Value(Value::Boolean(true)); + /// let dyn_expr = visit_literal(&expr).unwrap(); + /// ``` #[allow(clippy::unused_self)] - fn visit_literal(&self, lit: &Literal) -> Result { - match lit { - Literal::Boolean(b) => Ok(DynProofExpr::new_literal(LiteralValue::Boolean(*b))), - Literal::BigInt(i) => Ok(DynProofExpr::new_literal(LiteralValue::BigInt(*i))), - Literal::Int128(i) => Ok(DynProofExpr::new_literal(LiteralValue::Int128(*i))), - Literal::Decimal(d) => { - let raw_scale = d.scale(); - let scale = raw_scale.try_into().map_err(|_| InvalidScale { - scale: raw_scale.to_string(), - })?; - let precision = - Precision::try_from(d.precision()).map_err(|_| DecimalConversionError { - source: InvalidPrecision { - error: d.precision().to_string(), - }, - })?; - Ok(DynProofExpr::new_literal(LiteralValue::Decimal75( - precision, - scale, - I256::from_num_bigint( - &d.try_into_bigint_with_precision_and_scale(precision.value(), scale)?, - ), - ))) + fn visit_literal(&self, expr: &Expr) -> Result { + match expr { + Expr::Value(Value::Boolean(b)) => { + Ok(DynProofExpr::new_literal(Expr::Value(Value::Boolean(*b)))) } - Literal::VarChar(s) => Ok(DynProofExpr::new_literal(LiteralValue::VarChar(s.clone()))), - Literal::Timestamp(its) => { - let timestamp = match its.timeunit() { - PoSQLTimeUnit::Nanosecond => { - 
its.timestamp().timestamp_nanos_opt().ok_or_else(|| { - PoSQLTimestampError::UnsupportedPrecision{ error: "Timestamp out of range: - Valid nanosecond timestamps must be between 1677-09-21T00:12:43.145224192 - and 2262-04-11T23:47:16.854775807.".to_owned() + Expr::Value(Value::Number(value, _)) => value.parse::().map_or_else( + |_| { + Err(ConversionError::InvalidNumberFormat { + value: value.clone(), + }) + }, + |n| { + let number_expr = Expr::Value(Value::Number(n.to_string(), false)); + Ok(DynProofExpr::new_literal(number_expr)) + }, + ), + Expr::Value(Value::SingleQuotedString(s)) => Ok(DynProofExpr::new_literal( + Expr::Value(Value::SingleQuotedString(s.clone())), + )), + Expr::TypedString { data_type, value } => match data_type { + DataType::Decimal(ExactNumberInfo::PrecisionAndScale(precision, scale)) => { + let parsed_value = I256::from_string(value).map_err(|_| { + ConversionError::InvalidDecimalFormat { + value: value.clone(), + precision: u8::try_from(*precision) + .expect("Precision must fit into u8"), + scale: i8::try_from(*scale).expect("Scale must fit into i8"), } - })? 
- } - PoSQLTimeUnit::Microsecond => its.timestamp().timestamp_micros(), - PoSQLTimeUnit::Millisecond => its.timestamp().timestamp_millis(), - PoSQLTimeUnit::Second => its.timestamp().timestamp(), - }; - - Ok(DynProofExpr::new_literal(LiteralValue::TimeStampTZ( - its.timeunit(), - its.timezone(), - timestamp, - ))) - } + })?; + Ok(DynProofExpr::new_literal(Expr::TypedString { + data_type: DataType::Decimal(ExactNumberInfo::PrecisionAndScale( + *precision, *scale, + )), + value: parsed_value.to_string(), + })) + } + DataType::Timestamp(Some(precision), tz) => { + let time_unit = + PoSQLTimeUnit::from_precision(*precision).unwrap_or(PoSQLTimeUnit::Second); + let parsed_value = value.parse::().map_err(|_| { + ConversionError::InvalidTimestampFormat { + value: value.clone(), + } + })?; + Ok(DynProofExpr::new_literal(Expr::TypedString { + data_type: DataType::Timestamp(Some(time_unit.into()), *tz), + value: parsed_value.to_string(), + })) + } + DataType::Custom(_, _) if data_type.to_string() == "scalar" => { + let scalar_str = value.strip_prefix("scalar:").unwrap_or_default(); + let limbs: Vec = scalar_str + .split(',') + .map(|x| x.parse::().unwrap_or_default()) + .collect(); + assert!(limbs.len() == 4, "Scalar must have exactly 4 limbs"); + Ok(DynProofExpr::new_literal(Expr::TypedString { + data_type: DataType::Custom(ObjectName(vec![]), vec![]), + value: format!("{},{},{},{}", limbs[0], limbs[1], limbs[2], limbs[3]), + })) + } + _ => Err(ConversionError::UnsupportedDataType { + data_type: data_type.to_string(), + }), + }, + _ => Err(ConversionError::UnsupportedLiteral { + literal: format!("{expr:?}"), + }), } } fn visit_unary_expr( &self, op: UnaryOperator, - expr: &Expression, + expr: &Expr, ) -> Result { let expr = self.visit_expr(expr); match op { @@ -149,8 +180,8 @@ impl DynProofExprBuilder<'_> { fn visit_binary_expr( &self, op: &BinaryOperator, - left: &Expression, - right: &Expression, + left: &Expr, + right: &Expr, ) -> Result { match op { BinaryOperator::And 
=> { @@ -205,28 +236,23 @@ impl DynProofExprBuilder<'_> { } } - fn visit_aggregate_expr( - &self, - op: AggregationOperator, - expr: &Expression, - ) -> Result { + fn visit_aggregate_expr(&self, op: &str, expr: &Expr) -> Result { if self.in_agg_scope { return Err(ConversionError::InvalidExpression { expression: "nested aggregations are invalid".to_string(), }); } let expr = DynProofExprBuilder::new_agg(self.column_mapping).visit_expr(expr)?; + match (op, expr.data_type().is_numeric()) { - (AggregationOperator::Count, _) | (AggregationOperator::Sum, true) => { - Ok(DynProofExpr::new_aggregate(op, expr)) - } - (AggregationOperator::Sum, false) => Err(ConversionError::InvalidExpression { + ("COUNT", _) | ("SUM", true) => Ok(DynProofExpr::new_aggregate(op, expr)?), + ("SUM", false) => Err(ConversionError::InvalidExpression { expression: format!( - "Aggregation operator {op:?} doesn't work with non-numeric types" + "Aggregation operator {op} doesn't work with non-numeric types" ), }), _ => Err(ConversionError::Unprovable { - error: format!("Aggregation operator {op:?} is not supported at this location"), + error: format!("Aggregation operator {op} is not supported at this location"), }), } } diff --git a/crates/proof-of-sql/src/sql/parse/enriched_expr.rs b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs index 1772cbf3b..251254ca5 100644 --- a/crates/proof-of-sql/src/sql/parse/enriched_expr.rs +++ b/crates/proof-of-sql/src/sql/parse/enriched_expr.rs @@ -4,8 +4,9 @@ use crate::{ sql::proof_exprs::DynProofExpr, }; use alloc::boxed::Box; -use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, Expression}; -use sqlparser::ast::Ident; +use proof_of_sql_parser::sqlparser::SqlAliasedResultExpr; +use sqlparser::ast::{Expr, Ident}; + /// Enriched expression /// /// An enriched expression consists of an `proof_of_sql_parser::intermediate_ast::AliasedResultExpr` @@ -13,7 +14,7 @@ use sqlparser::ast::Ident; /// If the `DynProofExpr` is `None`, the `EnrichedExpr` is not 
provable. pub struct EnrichedExpr { /// The remaining expression after the provable expression plan has been extracted. - pub residue_expression: AliasedResultExpr, + pub residue_expression: SqlAliasedResultExpr, /// The extracted provable expression plan if it exists. pub dyn_proof_expr: Option, } @@ -24,7 +25,10 @@ impl EnrichedExpr { /// If the expression is not provable, the `dyn_proof_expr` will be `None`. /// Otherwise the `dyn_proof_expr` will contain the provable expression plan /// and the `residue_expression` will contain the remaining expression. - pub fn new(expression: AliasedResultExpr, column_mapping: &IndexMap) -> Self { + pub fn new( + expression: SqlAliasedResultExpr, + column_mapping: &IndexMap, + ) -> Self { // TODO: Using new_agg (ironically) disables aggregations in `QueryExpr` for now. // Re-enable aggregations when we add `GroupByExec` generalizations. let res_dyn_proof_expr = @@ -33,8 +37,8 @@ impl EnrichedExpr { Ok(dyn_proof_expr) => { let alias = expression.alias; Self { - residue_expression: AliasedResultExpr { - expr: Box::new(Expression::Column(alias)), + residue_expression: SqlAliasedResultExpr { + expr: Box::new(Expr::Identifier(alias.clone())), alias, }, dyn_proof_expr: Some(dyn_proof_expr), @@ -54,7 +58,7 @@ impl EnrichedExpr { pub fn get_alias(&self) -> Option { self.residue_expression .try_as_identifier() - .map(|identifier| Ident::new(identifier.as_str())) + .map(|identifier| Ident::new(identifier.value.as_str())) } /// Is the `EnrichedExpr` provable diff --git a/crates/proof-of-sql/src/sql/parse/error.rs b/crates/proof-of-sql/src/sql/parse/error.rs index 6e738ee5e..2c0237522 100644 --- a/crates/proof-of-sql/src/sql/parse/error.rs +++ b/crates/proof-of-sql/src/sql/parse/error.rs @@ -40,6 +40,27 @@ pub enum ConversionError { actual: ColumnType, }, + #[snafu(display("Unsupported expression: {error}"))] + /// The expression is unsupported + UnsupportedExpr { + /// The error for unsupported expression + error: String, + }, + + 
#[snafu(display("Invalid precision value: {precision}"))] + /// Precision value is invalid + InvalidPrecision { + /// The invalid precision value + precision: String, + }, + + #[snafu(display("Invalid scale value: {scale}"))] + /// Scale value is invalid + InvalidScale { + /// The invalid scale value + scale: String, + }, + #[snafu(display("Left side has '{left_type}' type but right side has '{right_type}' type"))] /// Data types do not match DataTypeMismatch { @@ -153,6 +174,54 @@ pub enum ConversionError { /// The underlying error message error: String, }, + + #[snafu(display("Invalid number format: {value:?}"))] + /// Represents an error due to an invalid number format. + InvalidNumberFormat { + /// The invalid number value as a string. + value: String, + }, + + #[snafu(display( + "Invalid decimal format: {value:?} with precision {precision} and scale {scale}" + ))] + /// Represents an error due to an invalid decimal format. + InvalidDecimalFormat { + /// The invalid decimal value as a string. + value: String, + /// The precision of the decimal value. + precision: u8, + /// The scale of the decimal value. + scale: i8, + }, + + #[snafu(display("Unsupported literal type: {literal:?}"))] + /// The literal type is not supported. + UnsupportedLiteral { + /// The unsupported literal type as a string. + literal: String, + }, + + #[snafu(display("Unsupported data type: {data_type:?}"))] + /// The data type is not supported. + UnsupportedDataType { + /// The unsupported data type as a string. + data_type: String, + }, + + #[snafu(display("Invalid timestamp format: {value:?}"))] + /// The timestamp format is invalid. + InvalidTimestampFormat { + /// The invalid timestamp value as a string. + value: String, + }, + + #[snafu(display("Timestamp out of range: {value:?}"))] + /// The timestamp value is out of the allowed range. + TimestampOutOfRange { + /// The out-of-range timestamp value as a string. 
+ value: String, + }, } impl From for ConversionError { diff --git a/crates/proof-of-sql/src/sql/parse/filter_exec_builder.rs b/crates/proof-of-sql/src/sql/parse/filter_exec_builder.rs index 5512d051c..836531d27 100644 --- a/crates/proof-of-sql/src/sql/parse/filter_exec_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/filter_exec_builder.rs @@ -1,7 +1,7 @@ use super::{where_expr_builder::WhereExprBuilder, ConversionError, EnrichedExpr}; use crate::{ base::{ - database::{ColumnRef, LiteralValue, TableRef}, + database::{ColumnRef, TableRef}, map::IndexMap, }, sql::{ @@ -12,7 +12,8 @@ use crate::{ use alloc::{boxed::Box, vec, vec::Vec}; use itertools::Itertools; use proof_of_sql_parser::intermediate_ast::Expression; -use sqlparser::ast::Ident; +use sqlparser::ast::{Expr, Ident, Value}; + pub struct FilterExecBuilder { table_expr: Option, where_expr: Option, @@ -56,7 +57,7 @@ impl FilterExecBuilder { if let Some(plan) = &enriched_expr.dyn_proof_expr { self.filter_result_expr_list.push(AliasedDynProofExpr { expr: plan.clone(), - alias: enriched_expr.residue_expression.alias.into(), + alias: enriched_expr.residue_expression.alias.clone(), }); } else { has_nonprovable_column = true; @@ -82,7 +83,7 @@ impl FilterExecBuilder { self.filter_result_expr_list, self.table_expr.expect("Table expr is required"), self.where_expr - .unwrap_or_else(|| DynProofExpr::new_literal(LiteralValue::Boolean(true))), + .unwrap_or_else(|| DynProofExpr::new_literal(Expr::Value(Value::Boolean(true)))), ) } } diff --git a/crates/proof-of-sql/src/sql/parse/query_context.rs b/crates/proof-of-sql/src/sql/parse/query_context.rs index 3c1b7c551..bda044424 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context.rs @@ -1,6 +1,6 @@ use crate::{ base::{ - database::{ColumnRef, LiteralValue, TableRef}, + database::{ColumnRef, TableRef}, map::{IndexMap, IndexSet}, }, sql::{ @@ -10,10 +10,11 @@ use crate::{ }, }; use alloc::{borrow::ToOwned, 
boxed::Box, string::ToString, vec::Vec}; -use proof_of_sql_parser::intermediate_ast::{ - AggregationOperator, AliasedResultExpr, Expression, OrderBy, Slice, +use proof_of_sql_parser::{ + intermediate_ast::{Expression, OrderBy, Slice}, + sqlparser::SqlAliasedResultExpr, }; -use sqlparser::ast::Ident; +use sqlparser::ast::{Expr, Ident, Value}; #[derive(Default, Debug)] pub struct QueryContext { @@ -28,7 +29,7 @@ pub struct QueryContext { group_by_exprs: Vec, where_expr: Option>, result_column_set: IndexSet, - res_aliased_exprs: Vec, + res_aliased_exprs: Vec, column_mapping: IndexMap, first_result_col_out_agg_scope: Option, } @@ -117,7 +118,7 @@ impl QueryContext { } #[allow(clippy::missing_panics_doc, clippy::unnecessary_wraps)] - pub fn push_aliased_result_expr(&mut self, expr: AliasedResultExpr) -> ConversionResult<()> { + pub fn push_aliased_result_expr(&mut self, expr: SqlAliasedResultExpr) -> ConversionResult<()> { assert!(&self.has_visited_group_by, "Group by must be visited first"); self.res_aliased_exprs.push(expr); @@ -160,7 +161,7 @@ impl QueryContext { /// /// Will panic if: /// - `self.res_aliased_exprs` is empty, triggering the assertion `assert!(!self.res_aliased_exprs.is_empty(), "empty aliased exprs")`. 
- pub fn get_aliased_result_exprs(&self) -> ConversionResult<&[AliasedResultExpr]> { + pub fn get_aliased_result_exprs(&self) -> ConversionResult<&[SqlAliasedResultExpr]> { assert!(!self.res_aliased_exprs.is_empty(), "empty aliased exprs"); // We need to check that each column alias is unique @@ -200,7 +201,7 @@ impl QueryContext { for by_expr in &self.order_by_exprs { self.res_aliased_exprs .iter() - .find(|col| col.alias == by_expr.expr) + .find(|col| col.alias == by_expr.expr.into()) .ok_or(ConversionError::InvalidOrderBy { alias: by_expr.expr.as_str().to_string(), })?; @@ -236,7 +237,7 @@ impl TryFrom<&QueryContext> for Option { fn try_from(value: &QueryContext) -> Result, Self::Error> { let where_clause = WhereExprBuilder::new(&value.column_mapping) .build(value.where_expr.clone())? - .unwrap_or_else(|| DynProofExpr::new_literal(LiteralValue::Boolean(true))); + .unwrap_or_else(|| DynProofExpr::new_literal(Expr::Value(Value::Boolean(true)))); let table = value.table.map(|table_ref| TableExpr { table_ref }).ok_or( ConversionError::InvalidExpression { expression: "QueryContext has no table_ref".to_owned(), @@ -275,8 +276,8 @@ impl TryFrom<&QueryContext> for Option { .iter() .zip(res_group_by_columns.iter()) .all(|(ident, res)| { - if let Expression::Column(res_ident) = *res.expr { - Ident::from(res_ident) == *ident + if let Expr::Identifier(res_ident) = &*res.expr { + *res_ident == *ident } else { false } @@ -286,19 +287,20 @@ impl TryFrom<&QueryContext> for Option { let sum_expr = sum_expr_columns .iter() .map(|res| { - if let Expression::Aggregation { - op: AggregationOperator::Sum, - .. 
- } = (*res.expr).clone() - { - let res_dyn_proof_expr = - DynProofExprBuilder::new(&value.column_mapping).build(&res.expr); - res_dyn_proof_expr - .ok() - .map(|dyn_proof_expr| AliasedDynProofExpr { - alias: res.alias.into(), - expr: dyn_proof_expr, - }) + if let Expr::Function(function) = &*res.expr { + let function_name = function.name.to_string().to_uppercase(); + if function_name == "SUM" { + let res_dyn_proof_expr = + DynProofExprBuilder::new(&value.column_mapping).build(&res.expr); + res_dyn_proof_expr + .ok() + .map(|dyn_proof_expr| AliasedDynProofExpr { + alias: res.alias.clone(), + expr: dyn_proof_expr, + }) + } else { + None + } } else { None } @@ -307,13 +309,12 @@ impl TryFrom<&QueryContext> for Option { // Check count(*) let count_column = &value.res_aliased_exprs[num_result_columns - 1]; - let count_column_compliant = matches!( - *count_column.expr, - Expression::Aggregation { - op: AggregationOperator::Count, - .. - } - ); + let count_column_compliant = if let Expr::Function(function) = &*count_column.expr { + let function_name = function.name.to_string().to_uppercase(); + function_name == "COUNT" + } else { + false + }; if !group_by_compliance || sum_expr.is_none() || !count_column_compliant { return Ok(None); @@ -321,7 +322,7 @@ impl TryFrom<&QueryContext> for Option { Ok(Some(GroupByExec::new( group_by_exprs, sum_expr.expect("the none case was just checked"), - count_column.alias.into(), + count_column.alias.clone(), table, where_clause, ))) diff --git a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs index 708cb8236..bb9472245 100644 --- a/crates/proof-of-sql/src/sql/parse/query_context_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/query_context_builder.rs @@ -4,26 +4,26 @@ use crate::base::{ try_add_subtract_column_types, try_multiply_column_types, ColumnRef, ColumnType, SchemaAccessor, TableRef, }, - math::{ - decimal::{DecimalError, Precision}, - BigDecimalExt, - }, 
+ math::decimal::Precision, }; use alloc::{boxed::Box, format, string::ToString, vec::Vec}; use proof_of_sql_parser::{ intermediate_ast::{ - AggregationOperator, AliasedResultExpr, Expression, Literal, OrderBy, SelectResultExpr, - Slice, TableExpression, + AggregationOperator, Expression, OrderBy, SelectResultExpr, Slice, TableExpression, }, + posql_time::PoSQLTimeUnit, Identifier, ResourceId, }; -use sqlparser::ast::{BinaryOperator, UnaryOperator}; +use sqlparser::ast::{ + BinaryOperator, DataType, ExactNumberInfo, Expr, FunctionArg, FunctionArgExpr, UnaryOperator, + Value, +}; pub struct QueryContextBuilder<'a> { context: QueryContext, schema_accessor: &'a dyn SchemaAccessor, } +use proof_of_sql_parser::sqlparser::SqlAliasedResultExpr; use sqlparser::ast::Ident; - // Public interface impl<'a> QueryContextBuilder<'a> { pub fn new(schema_accessor: &'a dyn SchemaAccessor) -> Self { @@ -55,7 +55,8 @@ impl<'a> QueryContextBuilder<'a> { mut where_expr: Option>, ) -> ConversionResult { if let Some(expr) = where_expr.as_deref_mut() { - self.visit_expr(expr)?; + let sql_expr: sqlparser::ast::Expr = (*expr).clone().into(); + self.visit_expr(&sql_expr)?; } self.context.set_where_expr(where_expr); Ok(self) @@ -69,7 +70,11 @@ impl<'a> QueryContextBuilder<'a> { for column in result_exprs { match column { SelectResultExpr::ALL => self.visit_select_all_expr()?, - SelectResultExpr::AliasedResultExpr(expr) => self.visit_aliased_expr(expr)?, + SelectResultExpr::AliasedResultExpr(expr) => { + let converted_expr: Box = Box::new((*expr.expr).clone().into()); + let sql_expr = SqlAliasedResultExpr::new(converted_expr, expr.alias.into()); + self.visit_aliased_expr(sql_expr)?; + } } } self.context.toggle_result_scope(); @@ -116,53 +121,88 @@ impl<'a> QueryContextBuilder<'a> { fn visit_select_all_expr(&mut self) -> ConversionResult<()> { for (column_name, _) in self.lookup_schema() { - let column_identifier = Identifier::try_from(column_name).map_err(|e| { - 
ConversionError::IdentifierConversionError { - error: format!("Failed to convert Ident to Identifier: {e}"), - } - })?; - let col_expr = Expression::Column(column_identifier); - self.visit_aliased_expr(AliasedResultExpr::new(col_expr, column_identifier))?; + let column_identifier = Ident { + value: column_name.to_string(), + quote_style: None, + }; + let col_expr = Expr::Identifier(column_identifier.clone()); + self.visit_aliased_expr(SqlAliasedResultExpr::new( + Box::new(col_expr), + column_identifier, + ))?; } Ok(()) } - fn visit_aliased_expr(&mut self, aliased_expr: AliasedResultExpr) -> ConversionResult<()> { + fn visit_aliased_expr(&mut self, aliased_expr: SqlAliasedResultExpr) -> ConversionResult<()> { self.visit_expr(&aliased_expr.expr)?; self.context.push_aliased_result_expr(aliased_expr)?; Ok(()) } /// Visits the expression and returns its data type. - fn visit_expr(&mut self, expr: &Expression) -> ConversionResult { + fn visit_expr(&mut self, expr: &Expr) -> ConversionResult { match expr { - Expression::Wildcard => Ok(ColumnType::BigInt), // Since COUNT(*) = COUNT(1) - Expression::Literal(literal) => self.visit_literal(literal), - Expression::Column(_) => self.visit_column_expr(expr), - Expression::Unary { op, expr } => self.visit_unary_expr((*op).into(), expr), - Expression::Binary { op, left, right } => { - self.visit_binary_expr(&(*op).into(), left, right) + Expr::Wildcard => Ok(ColumnType::BigInt), // Since COUNT(*) = COUNT(1) + Expr::Value(_) => self.visit_literal(expr), + Expr::Identifier(_) | Expr::CompoundIdentifier(_) | Expr::QualifiedWildcard(_) => { + self.visit_column_expr(expr) } - Expression::Aggregation { op, expr } => self.visit_agg_expr(*op, expr), + Expr::UnaryOp { op, expr } => self.visit_unary_expr(*op, expr), + Expr::BinaryOp { op, left, right } => self.visit_binary_expr(&op.clone(), left, right), + Expr::Function(function) => { + let function_name = function.name.to_string().to_uppercase(); + match function_name.as_str() { + "SUM" | 
"COUNT" | "MAX" | "MIN" | "FIRST" => { + let agg_op = match function_name.as_str() { + "SUM" => AggregationOperator::Sum, + "COUNT" => AggregationOperator::Count, + "MAX" => AggregationOperator::Max, + "MIN" => AggregationOperator::Min, + "FIRST" => AggregationOperator::First, + _ => unreachable!(), + }; + if let Some(arg) = function.args.first() { + if let FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) = arg { + self.visit_agg_expr(agg_op, expr) + } else { + Err(ConversionError::Unprovable { + error: "Aggregation with named arguments is not supported." + .to_string(), + }) + } + } else { + Err(ConversionError::Unprovable { + error: "Aggregation function requires at least one argument." + .to_string(), + }) + } + } + _ => Err(ConversionError::Unprovable { + error: format!("Unsupported function: {function_name}"), + }), + } + } + _ => Err(ConversionError::UnsupportedExpr { + error: format!("Unsupported expression: {expr:?}"), + }), } } /// # Panics /// Panics if the expression is not a column expression. 
- fn visit_column_expr(&mut self, expr: &Expression) -> ConversionResult { - let identifier = match expr { - Expression::Column(identifier) => *identifier, + fn visit_column_expr(&mut self, expr: &Expr) -> ConversionResult { + match expr { + Expr::Identifier(identifier) => self.visit_column_identifier(identifier), _ => panic!("Must be a column expression"), - }; - - self.visit_column_identifier(&identifier.into()) + } } fn visit_binary_expr( &mut self, op: &BinaryOperator, - left: &Expression, - right: &Expression, + left: &Expr, + right: &Expr, ) -> ConversionResult { let left_dtype = self.visit_expr(left)?; let right_dtype = self.visit_expr(right)?; @@ -186,11 +226,7 @@ impl<'a> QueryContextBuilder<'a> { } } - fn visit_unary_expr( - &mut self, - op: UnaryOperator, - expr: &Expression, - ) -> ConversionResult { + fn visit_unary_expr(&mut self, op: UnaryOperator, expr: &Expr) -> ConversionResult { match op { UnaryOperator::Not => { let dtype = self.visit_expr(expr)?; @@ -212,7 +248,7 @@ impl<'a> QueryContextBuilder<'a> { fn visit_agg_expr( &mut self, op: AggregationOperator, - expr: &Expression, + expr: &Expr, ) -> ConversionResult { self.context.set_in_agg_scope(true)?; @@ -236,24 +272,54 @@ impl<'a> QueryContextBuilder<'a> { } } + /// # Panics + /// This function will panic if the precision value cannot be wrapped #[allow(clippy::unused_self)] - fn visit_literal(&self, literal: &Literal) -> Result { - match literal { - Literal::Boolean(_) => Ok(ColumnType::Boolean), - Literal::BigInt(_) => Ok(ColumnType::BigInt), - Literal::Int128(_) => Ok(ColumnType::Int128), - Literal::VarChar(_) => Ok(ColumnType::VarChar), - Literal::Decimal(d) => { - let precision = Precision::try_from(d.precision())?; - let scale = d.scale(); - Ok(ColumnType::Decimal75( - precision, - scale.try_into().map_err(|_| DecimalError::InvalidScale { - scale: scale.to_string(), - })?, - )) + fn visit_literal(&self, expr: &Expr) -> Result { + match expr { + Expr::Value(Value::Boolean(_)) => 
Ok(ColumnType::Boolean), + Expr::Value(Value::Number(value, _)) => { + let n = + value + .parse::() + .map_err(|_| ConversionError::InvalidNumberFormat { + value: value.clone(), + })?; + if n >= i128::from(i64::MIN) && n <= i128::from(i64::MAX) { + Ok(ColumnType::BigInt) + } else { + Ok(ColumnType::Int128) + } } - Literal::Timestamp(its) => Ok(ColumnType::TimestampTZ(its.timeunit(), its.timezone())), + Expr::Value(Value::SingleQuotedString(_)) => Ok(ColumnType::VarChar), + Expr::TypedString { data_type, .. } => match data_type { + DataType::Decimal(ExactNumberInfo::PrecisionAndScale(precision, scale)) => { + let precision = u8::try_from(*precision).map_err(|_| { + ConversionError::InvalidPrecision { + precision: precision.to_string(), + } + })?; + let scale = + i8::try_from(*scale).map_err(|_| ConversionError::InvalidScale { + scale: scale.to_string(), + })?; + Ok(ColumnType::Decimal75( + Precision::new(precision).unwrap(), + scale, + )) + } + DataType::Timestamp(Some(precision), tz) => { + let time_unit = + PoSQLTimeUnit::from_precision(*precision).unwrap_or(PoSQLTimeUnit::Second); + Ok(ColumnType::TimestampTZ(time_unit, *tz)) + } + _ => Err(ConversionError::UnsupportedDataType { + data_type: data_type.to_string(), + }), + }, + _ => Err(ConversionError::UnsupportedLiteral { + literal: format!("{expr:?}"), + }), } } diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index 29a84d0f9..5f4a4afb7 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -1,7 +1,13 @@ use super::ConversionError; use crate::{ base::{ - database::{ColumnType, TableRef, TestSchemaAccessor}, + database::{ + expr_utility::{ + add as padd, aliased_expr, col, count, count_all, lit, max, min, mul as pmul, + sub as psub, sum, + }, + ColumnType, TableRef, TestSchemaAccessor, + }, map::{indexmap, IndexMap, IndexSet}, }, sql::{ @@ -15,10 +21,10 @@ use 
itertools::Itertools; use proof_of_sql_parser::{ intermediate_ast::OrderByDirection::*, sql::SelectStatementParser, - utility::{ - add as padd, aliased_expr, col, count, count_all, lit, max, min, mul as pmul, sub as psub, - sum, - }, + // utility::{ + // add as padd, aliased_expr, col, count, count_all, lit, max, min, mul as pmul, sub as psub, + // sum, + // }, }; use sqlparser::ast::Ident; diff --git a/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs b/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs index 4201e68eb..3510ebba8 100644 --- a/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs +++ b/crates/proof-of-sql/src/sql/parse/where_expr_builder.rs @@ -30,7 +30,8 @@ impl<'a> WhereExprBuilder<'a> { ) -> Result, ConversionError> { where_expr .map(|where_expr| { - let expr_plan = self.builder.build(&where_expr)?; + let converted_expr = (*where_expr).into(); + let expr_plan = self.builder.build(&converted_expr)?; // Ensure that the expression is a boolean expression match expr_plan.data_type() { ColumnType::Boolean => Ok(expr_plan), diff --git a/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs b/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs index 560362df4..ecff0dfec 100644 --- a/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/where_expr_builder_tests.rs @@ -1,6 +1,6 @@ use crate::{ base::{ - database::{ColumnRef, ColumnType, LiteralValue, TestSchemaAccessor}, + database::{ColumnRef, ColumnType, TestSchemaAccessor}, map::{indexmap, IndexMap}, math::decimal::Precision, }, @@ -12,11 +12,11 @@ use crate::{ use bigdecimal::BigDecimal; use core::str::FromStr; use proof_of_sql_parser::{ - posql_time::{PoSQLTimeUnit, PoSQLTimeZone, PoSQLTimestamp}, + posql_time::{PoSQLTimeUnit, PoSQLTimestamp}, utility::*, SelectStatement, }; -use sqlparser::ast::Ident; +use sqlparser::ast::{Expr, Ident, TimezoneInfo, Value}; /// # Panics /// @@ -61,7 +61,7 @@ fn 
get_column_mappings_for_testing() -> IndexMap { ColumnRef::new( tab_ref, "timestamp_second_column".into(), - ColumnType::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc()), + ColumnType::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None), ), ); column_mapping.insert( @@ -69,7 +69,7 @@ fn get_column_mappings_for_testing() -> IndexMap { ColumnRef::new( tab_ref, "timestamp_millisecond_column".into(), - ColumnType::TimestampTZ(PoSQLTimeUnit::Millisecond, PoSQLTimeZone::utc()), + ColumnType::TimestampTZ(PoSQLTimeUnit::Millisecond, TimezoneInfo::None), ), ); column_mapping.insert( @@ -77,7 +77,7 @@ fn get_column_mappings_for_testing() -> IndexMap { ColumnRef::new( tab_ref, "timestamp_microsecond_column".into(), - ColumnType::TimestampTZ(PoSQLTimeUnit::Microsecond, PoSQLTimeZone::utc()), + ColumnType::TimestampTZ(PoSQLTimeUnit::Microsecond, TimezoneInfo::None), ), ); column_mapping.insert( @@ -85,7 +85,7 @@ fn get_column_mappings_for_testing() -> IndexMap { ColumnRef::new( tab_ref, "timestamp_nanosecond_column".into(), - ColumnType::TimestampTZ(PoSQLTimeUnit::Nanosecond, PoSQLTimeZone::utc()), + ColumnType::TimestampTZ(PoSQLTimeUnit::Nanosecond, TimezoneInfo::None), ), ); column_mapping @@ -149,7 +149,10 @@ fn we_can_directly_check_whether_bigint_columns_ge_int128() { "bigint_column".into(), ColumnType::BigInt, ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), + DynProofExpr::Literal(LiteralExpr::new(Expr::Value(Value::Number( + "-12345".to_string(), + false, + )))), false, ) .unwrap(); @@ -171,7 +174,10 @@ fn we_can_directly_check_whether_bigint_columns_le_int128() { "bigint_column".into(), ColumnType::BigInt, ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::Int128(-12345))), + DynProofExpr::Literal(LiteralExpr::new(Expr::Value(Value::Number( + "-12345".to_string(), + false, + )))), true, ) .unwrap(); diff --git a/crates/proof-of-sql/src/sql/postprocessing/error.rs b/crates/proof-of-sql/src/sql/postprocessing/error.rs 
index 054b07358..3497b6aa9 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/error.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/error.rs @@ -41,6 +41,12 @@ pub enum PostprocessingError { /// The underlying error message error: String, }, + /// Unsupported expression encountered during postprocessing + #[snafu(display("Unsupported expression: {error}"))] + UnsupportedExpr { + /// The underlying error message + error: String, + }, /// Errors in aggregate columns #[snafu(transparent)] AggregateColumnsError { diff --git a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs index 7bc701ba8..889e8d9fd 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing.rs @@ -4,27 +4,30 @@ use crate::base::{ map::{indexmap, IndexMap, IndexSet}, scalar::Scalar, }; -use alloc::{boxed::Box, format, string::ToString, vec, vec::Vec}; +use alloc::{ + boxed::Box, + format, + string::{String, ToString}, + vec, + vec::Vec, +}; use bumpalo::Bump; use itertools::{izip, Itertools}; -use proof_of_sql_parser::{ - intermediate_ast::{AggregationOperator, AliasedResultExpr, Expression}, - Identifier, -}; +use proof_of_sql_parser::sqlparser::SqlAliasedResultExpr; use serde::{Deserialize, Serialize}; -use sqlparser::ast::Ident; +use sqlparser::ast::{Expr, Ident}; /// A group by expression #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct GroupByPostprocessing { - /// A list of `AliasedResultExpr` that exclusively use identifiers in the group by clause or results of aggregation expressions - remainder_exprs: Vec, + /// A list of `SqlAliasedResultExpr` that exclusively use identifiers in the group by clause or results of aggregation expressions + remainder_exprs: Vec, /// A list of identifiers in the group by clause group_by_identifiers: Vec, /// A list of aggregation expressions - 
aggregation_exprs: Vec<(AggregationOperator, Expression, Ident)>, + aggregation_exprs: Vec<(String, Expr, Ident)>, } /// Check whether multiple layers of aggregation exist within the same GROUP BY clause @@ -32,31 +35,43 @@ pub struct GroupByPostprocessing { /// /// If the context is within an aggregation function, then any aggregation function is considered nested. /// Otherwise we need two layers of aggregation functions to be nested. -fn contains_nested_aggregation(expr: &Expression, is_agg: bool) -> bool { +fn contains_nested_aggregation(expr: &Expr, is_agg: bool) -> bool { match expr { - Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => false, - Expression::Aggregation { expr, .. } => is_agg || contains_nested_aggregation(expr, true), - Expression::Binary { left, right, .. } => { + Expr::Function(function) => { + is_agg + || function.args.iter().any(|arg| match arg { + sqlparser::ast::FunctionArg::Unnamed( + sqlparser::ast::FunctionArgExpr::Expr(arg_expr), + ) + | sqlparser::ast::FunctionArg::Named { + arg: sqlparser::ast::FunctionArgExpr::Expr(arg_expr), + .. + } => contains_nested_aggregation(arg_expr, true), + _ => false, + }) + } + Expr::BinaryOp { left, right, .. } => { contains_nested_aggregation(left, is_agg) || contains_nested_aggregation(right, is_agg) } - Expression::Unary { expr, .. } => contains_nested_aggregation(expr, is_agg), + Expr::UnaryOp { expr, .. } => contains_nested_aggregation(expr, is_agg), + _ => false, } } /// Get identifiers NOT in aggregate functions -fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet { +fn get_free_identifiers_from_expr(expr: &Expr) -> IndexSet { match expr { - Expression::Column(identifier) => IndexSet::from_iter([(*identifier).into()]), - Expression::Literal(_) | Expression::Aggregation { .. } | Expression::Wildcard => { - IndexSet::default() - } - Expression::Binary { left, right, .. 
} => { + Expr::Identifier(identifier) => IndexSet::from_iter([identifier.clone()]), + // Expr::Value(_) | Expr::Function(_) | Expr::Wildcard => IndexSet::default(), + Expr::BinaryOp { left, right, .. } => { let mut left_identifiers = get_free_identifiers_from_expr(left); let right_identifiers = get_free_identifiers_from_expr(right); left_identifiers.extend(right_identifiers); left_identifiers } - Expression::Unary { expr, .. } => get_free_identifiers_from_expr(expr), + Expr::UnaryOp { expr, .. } => get_free_identifiers_from_expr(expr), + + _ => IndexSet::default(), } } @@ -70,66 +85,60 @@ fn get_free_identifiers_from_expr(expr: &Expression) -> IndexSet { /// Will panic if the key for an aggregation expression cannot be parsed as a valid identifier /// or if there are issues retrieving an identifier from the map. fn get_aggregate_and_remainder_expressions( - expr: Expression, - aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>, -) -> Result { + expr: Expr, + aggregation_expr_map: &mut IndexMap<(String, Expr), Ident>, +) -> Result { match expr { - Expression::Column(_) | Expression::Literal(_) | Expression::Wildcard => Ok(expr), - Expression::Aggregation { op, expr } => { - let key = (op, (*expr)); + Expr::Identifier(_) | Expr::Value(_) | Expr::Wildcard => Ok(expr), + Expr::Function(function) => { + let key = (function.name.to_string(), Expr::Function(function.clone())); + if let Some(ident) = aggregation_expr_map.get(&key) { - let identifier = Identifier::try_from(ident.clone()).map_err(|e| { - PostprocessingError::IdentifierConversionError { - error: format!("Failed to convert Ident to Identifier: {e}"), - } - })?; - Ok(Expression::Column(identifier)) + Ok(Expr::Identifier(ident.clone())) } else { let new_ident = Ident { value: format!("__col_agg_{}", aggregation_expr_map.len()), quote_style: None, }; - let new_identifier = Identifier::try_from(new_ident.clone()).map_err(|e| { - PostprocessingError::IdentifierConversionError { - error: 
format!("Failed to convert Ident to Identifier: {e}"), - } - })?; - - aggregation_expr_map.insert(key, new_ident); - Ok(Expression::Column(new_identifier)) + aggregation_expr_map.insert(key, new_ident.clone()); + Ok(Expr::Identifier(new_ident)) } } - Expression::Binary { op, left, right } => { + Expr::BinaryOp { op, left, right } => { let left_remainder = - get_aggregate_and_remainder_expressions(*left, aggregation_expr_map); + get_aggregate_and_remainder_expressions(*left, aggregation_expr_map)?; let right_remainder = - get_aggregate_and_remainder_expressions(*right, aggregation_expr_map); - Ok(Expression::Binary { + get_aggregate_and_remainder_expressions(*right, aggregation_expr_map)?; + Ok(Expr::BinaryOp { op, - left: Box::new(left_remainder?), - right: Box::new(right_remainder?), + left: Box::new(left_remainder), + right: Box::new(right_remainder), }) } - Expression::Unary { op, expr } => { - let remainder = get_aggregate_and_remainder_expressions(*expr, aggregation_expr_map); - Ok(Expression::Unary { + + Expr::UnaryOp { op, expr } => { + let remainder = get_aggregate_and_remainder_expressions(*expr, aggregation_expr_map)?; + Ok(Expr::UnaryOp { op, - expr: Box::new(remainder?), + expr: Box::new(remainder), }) } + _ => Err(PostprocessingError::UnsupportedExpr { + error: format!("Expression {expr:?} is not supported yet"), + }), } } -/// Given an `AliasedResultExpr`, check if it is legitimate and if so grab the relevant aggregation expression +/// Given an `SqlAliasedResultExpr`, check if it is legitimate and if so grab the relevant aggregation expression /// # Panics /// /// Will panic if there is an issue retrieving the first element from the difference of free identifiers and group-by identifiers, indicating a logical inconsistency in the identifiers. 
fn check_and_get_aggregation_and_remainder( - expr: AliasedResultExpr, + expr: SqlAliasedResultExpr, group_by_identifiers: &[Ident], - aggregation_expr_map: &mut IndexMap<(AggregationOperator, Expression), Ident>, -) -> PostprocessingResult { + aggregation_expr_map: &mut IndexMap<(String, Expr), Ident>, +) -> PostprocessingResult { let free_identifiers = get_free_identifiers_from_expr(&expr.expr); let group_by_identifier_set = group_by_identifiers .iter() @@ -141,10 +150,11 @@ fn check_and_get_aggregation_and_remainder( }); } if free_identifiers.is_subset(&group_by_identifier_set) { - let remainder = get_aggregate_and_remainder_expressions(*expr.expr, aggregation_expr_map); - Ok(AliasedResultExpr { + let remainder_expr = + get_aggregate_and_remainder_expressions(*expr.expr, aggregation_expr_map)?; + Ok(SqlAliasedResultExpr { alias: expr.alias, - expr: Box::new(remainder?), + expr: Box::new(remainder_expr), }) } else { let diff = free_identifiers @@ -163,12 +173,11 @@ impl GroupByPostprocessing { /// Create a new group by expression containing the group by and aggregation expressions pub fn try_new( by_ids: Vec, - aliased_exprs: Vec, + aliased_exprs: Vec, ) -> PostprocessingResult { - let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Ident> = - IndexMap::default(); + let mut aggregation_expr_map: IndexMap<(String, Expr), Ident> = IndexMap::default(); // Look for aggregation expressions and check for non-aggregation expressions that contain identifiers not in the group by clause - let remainder_exprs: Vec = aliased_exprs + let remainder_exprs: Vec = aliased_exprs .into_iter() .map(|aliased_expr| -> PostprocessingResult<_> { check_and_get_aggregation_and_remainder( @@ -177,7 +186,7 @@ impl GroupByPostprocessing { &mut aggregation_expr_map, ) }) - .collect::>>()?; + .collect::>>()?; let group_by_identifiers = Vec::from_iter(IndexSet::from_iter(by_ids)); Ok(Self { remainder_exprs, @@ -196,14 +205,13 @@ impl GroupByPostprocessing { } /// Get 
remainder expressions for SELECT - #[must_use] - pub fn remainder_exprs(&self) -> &[AliasedResultExpr] { + pub fn remainder_exprs(&self) -> &[SqlAliasedResultExpr] { &self.remainder_exprs } /// Get aggregation expressions #[must_use] - pub fn aggregation_exprs(&self) -> &[(AggregationOperator, Expression, Ident)] { + pub fn aggregation_exprs(&self) -> &[(String, Expr, Ident)] { &self.aggregation_exprs } } @@ -219,11 +227,11 @@ impl PostprocessingStep for GroupByPostprocessing { .iter() .map(|(agg_op, expr, id)| -> PostprocessingResult<_> { let evaluated_owned_column = owned_table.evaluate(expr)?; - Ok((*agg_op, (id.clone(), evaluated_owned_column))) + Ok((agg_op.to_string(), (id.clone(), evaluated_owned_column))) }) .process_results(|iter| { iter.fold( - IndexMap::<_, Vec<_>>::default(), + IndexMap::>::default(), |mut lookup, (key, val)| { lookup.entry(key).or_default().push(val); lookup @@ -243,32 +251,40 @@ impl PostprocessingStep for GroupByPostprocessing { Ok(Column::::from_owned_column(column, &alloc)) }) .collect::>>()?; + // TODO: Allow a filter let selection_in = vec![true; owned_table.num_rows()]; - let (sum_identifiers, sum_columns): (Vec<_>, Vec<_>) = evaluated_columns - .get(&AggregationOperator::Sum) - .map_or((vec![], vec![]), |tuple| { - tuple - .iter() - .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) - .unzip() - }); - let (max_identifiers, max_columns): (Vec<_>, Vec<_>) = evaluated_columns - .get(&AggregationOperator::Max) - .map_or((vec![], vec![]), |tuple| { - tuple - .iter() - .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) - .unzip() - }); - let (min_identifiers, min_columns): (Vec<_>, Vec<_>) = evaluated_columns - .get(&AggregationOperator::Min) - .map_or((vec![], vec![]), |tuple| { - tuple - .iter() - .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) - .unzip() - }); + + let (sum_identifiers, sum_columns): (Vec<_>, Vec<_>) = + evaluated_columns + .get("Sum") + .map_or((vec![], 
vec![]), |tuple| { + tuple + .iter() + .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) + .unzip() + }); + + let (max_identifiers, max_columns): (Vec<_>, Vec<_>) = + evaluated_columns + .get("Max") + .map_or((vec![], vec![]), |tuple| { + tuple + .iter() + .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) + .unzip() + }); + + let (min_identifiers, min_columns): (Vec<_>, Vec<_>) = + evaluated_columns + .get("Min") + .map_or((vec![], vec![]), |tuple| { + tuple + .iter() + .map(|(id, c)| (id.clone(), Column::::from_owned_column(c, &alloc))) + .unzip() + }); + let aggregation_results = aggregate_columns( &alloc, &group_by_ins, @@ -277,6 +293,7 @@ impl PostprocessingStep for GroupByPostprocessing { &min_columns, &selection_in, )?; + // Finally do another round of evaluation to get the final result // Gather the results into a new OwnedTable let group_by_outs = aggregation_results @@ -284,6 +301,7 @@ impl PostprocessingStep for GroupByPostprocessing { .iter() .zip(self.group_by_identifiers.iter()) .map(|(column, id)| Ok((id.clone(), OwnedColumn::from(column)))); + let sum_outs = izip!( aggregation_results.sum_columns, sum_identifiers, @@ -295,6 +313,7 @@ impl PostprocessingStep for GroupByPostprocessing { OwnedColumn::try_from_scalars(c_out, c_in.column_type())?, )) }); + let max_outs = izip!( aggregation_results.max_columns, max_identifiers, @@ -306,6 +325,7 @@ impl PostprocessingStep for GroupByPostprocessing { OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?, )) }); + let min_outs = izip!( aggregation_results.min_columns, min_identifiers, @@ -317,13 +337,14 @@ impl PostprocessingStep for GroupByPostprocessing { OwnedColumn::try_from_option_scalars(c_out, c_in.column_type())?, )) }); + //TODO: When we have NULLs we need to differentiate between count(1) and count(expression) let count_column = OwnedColumn::BigInt(aggregation_results.count_column.to_vec()); - let count_outs = evaluated_columns - 
.get(&AggregationOperator::Count) - .into_iter() - .flatten() - .map(|(id, _)| -> PostprocessingResult<_> { Ok((id.clone(), count_column.clone())) }); + let count_outs = + evaluated_columns.get("Count").into_iter().flatten().map( + |(id, _)| -> PostprocessingResult<_> { Ok((id.clone(), count_column.clone())) }, + ); + let new_owned_table: OwnedTable = group_by_outs .into_iter() .chain(sum_outs) @@ -331,6 +352,7 @@ impl PostprocessingStep for GroupByPostprocessing { .chain(min_outs) .chain(count_outs) .process_results(|iter| OwnedTable::try_from_iter(iter))??; + // If there are no columns at all we need to have the count column so that we can handle // queries such as `SELECT 1 FROM table` let target_table = if new_owned_table.is_empty() { @@ -338,12 +360,14 @@ impl PostprocessingStep for GroupByPostprocessing { } else { new_owned_table }; + let result = self .remainder_exprs .iter() .map(|aliased_expr| -> PostprocessingResult<_> { - let column = target_table.evaluate(&aliased_expr.expr)?; - let alias: Ident = aliased_expr.alias.into(); + let expr_as_expr: Expr = (*aliased_expr.expr).clone(); + let column = target_table.evaluate(&expr_as_expr)?; + let alias: Ident = aliased_expr.alias.clone(); Ok((alias, column)) }) .process_results(|iter| OwnedTable::try_from_iter(iter))??; @@ -359,32 +383,32 @@ mod tests { #[test] fn we_can_detect_nested_aggregation() { // SUM(SUM(a)) - let expr = sum(sum(col("a"))); + let expr = (*sum(sum(col("a")))).into(); assert!(contains_nested_aggregation(&expr, false)); assert!(contains_nested_aggregation(&expr, true)); // MAX(a) + SUM(b) - let expr = add(max(col("a")), sum(col("b"))); + let expr = (*add(max(col("a")), sum(col("b")))).into(); assert!(!contains_nested_aggregation(&expr, false)); assert!(contains_nested_aggregation(&expr, true)); // a + SUM(b) - let expr = add(col("a"), sum(col("b"))); + let expr = (*add(col("a"), sum(col("b")))).into(); assert!(!contains_nested_aggregation(&expr, false)); 
assert!(contains_nested_aggregation(&expr, true)); // SUM(a) + b - SUM(2 * c) - let expr = sub(add(sum(col("a")), col("b")), sum(mul(lit(2), col("c")))); + let expr = (*sub(add(sum(col("a")), col("b")), sum(mul(lit(2), col("c"))))).into(); assert!(!contains_nested_aggregation(&expr, false)); assert!(contains_nested_aggregation(&expr, true)); // a + COUNT(SUM(a)) - let expr = add(col("a"), count(sum(col("a")))); + let expr = (*add(col("a"), count(sum(col("a"))))).into(); assert!(contains_nested_aggregation(&expr, false)); assert!(contains_nested_aggregation(&expr, true)); // a + b + 1 - let expr = add(add(col("a"), col("b")), lit(1)); + let expr = (*add(add(col("a"), col("b")), lit(1))).into(); assert!(!contains_nested_aggregation(&expr, false)); assert!(!contains_nested_aggregation(&expr, true)); } @@ -392,31 +416,31 @@ mod tests { #[test] fn we_can_get_free_identifiers_from_expr() { // Literal - let expr = lit("Not an identifier"); + let expr = (*lit("Not an identifier")).into(); let expected: IndexSet = IndexSet::default(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); // a + b + 1 - let expr = add(add(col("a"), col("b")), lit(1)); + let expr = (*add(add(col("a"), col("b")), lit(1))).into(); let expected: IndexSet = ["a".into(), "b".into()].into_iter().collect(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); // ! 
(a == b || c >= a) - let expr = not(or(equal(col("a"), col("b")), ge(col("c"), col("a")))); + let expr = (*not(or(equal(col("a"), col("b")), ge(col("c"), col("a"))))).into(); let expected: IndexSet = ["a".into(), "b".into(), "c".into()].into_iter().collect(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); // SUM(a + b) * 2 - let expr = mul(sum(add(col("a"), col("b"))), lit(2)); + let expr = (*mul(sum(add(col("a"), col("b"))), lit(2))).into(); let expected: IndexSet = IndexSet::default(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); // (COUNT(a + b) + c) * d - let expr = mul(add(count(add(col("a"), col("b"))), col("c")), col("d")); + let expr = (*mul(add(count(add(col("a"), col("b"))), col("c")), col("d"))).into(); let expected: IndexSet = ["c".into(), "d".into()].into_iter().collect(); let actual = get_free_identifiers_from_expr(&expr); assert_eq!(actual, expected); @@ -424,63 +448,59 @@ mod tests { #[test] fn we_can_get_aggregate_and_remainder_expressions() { - let mut aggregation_expr_map: IndexMap<(AggregationOperator, Expression), Ident> = - IndexMap::default(); + let mut aggregation_expr_map: IndexMap<(String, Expr), Ident> = IndexMap::default(); + // SUM(a) + b let expr = add(sum(col("a")), col("b")); let remainder_expr = - get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map); + get_aggregate_and_remainder_expressions((*expr).into(), &mut aggregation_expr_map); assert_eq!( - aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))], + aggregation_expr_map[&("SUM".to_string(), (*col("a")).into())], "__col_agg_0".into() ); - assert_eq!(remainder_expr, Ok(*add(col("__col_agg_0"), col("b")))); + let expected_remainder: Expr = (*add(col("__col_agg_0"), col("b"))).into(); + assert_eq!(remainder_expr, Ok(expected_remainder)); assert_eq!(aggregation_expr_map.len(), 1); // SUM(a) + SUM(b) let expr = add(sum(col("a")), sum(col("b"))); let remainder_expr = - 
get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map); + get_aggregate_and_remainder_expressions((*expr).into(), &mut aggregation_expr_map); assert_eq!( - aggregation_expr_map[&(AggregationOperator::Sum, *col("a"))], + aggregation_expr_map[&("SUM".to_string(), (*col("a")).into())], "__col_agg_0".into() ); assert_eq!( - aggregation_expr_map[&(AggregationOperator::Sum, *col("b"))], + aggregation_expr_map[&("SUM".to_string(), (*col("b")).into())], "__col_agg_1".into() ); - assert_eq!( - remainder_expr, - Ok(*add(col("__col_agg_0"), col("__col_agg_1"))) - ); - assert_eq!(aggregation_expr_map.len(), 2); + let expected_remainder: Expr = (*add(col("__col_agg_0"), col("__col_agg_1"))).into(); + assert_eq!(remainder_expr, Ok(expected_remainder)); // MAX(a + 1) + MIN(2 * b - 4) + c let expr = add( add( - max(col("a") + lit(1)), + max(add(col("a"), lit(1))), min(sub(mul(lit(2), col("b")), lit(4))), ), col("c"), ); let remainder_expr = - get_aggregate_and_remainder_expressions(*expr, &mut aggregation_expr_map); + get_aggregate_and_remainder_expressions((*expr).into(), &mut aggregation_expr_map); assert_eq!( - aggregation_expr_map[&(AggregationOperator::Max, *add(col("a"), lit(1)))], + aggregation_expr_map[&("MAX".to_string(), (*add(col("a"), lit(1))).into())], "__col_agg_2".into() ); assert_eq!( aggregation_expr_map[&( - AggregationOperator::Min, - *sub(mul(lit(2), col("b")), lit(4)) + "MIN".to_string(), + (*sub(mul(lit(2), col("b")), lit(4))).into() )], "__col_agg_3".into() ); - assert_eq!( - remainder_expr, - Ok(*add(add(col("__col_agg_2"), col("__col_agg_3")), col("c"))) - ); - assert_eq!(aggregation_expr_map.len(), 4); + let expected_remainder: Expr = + (*add(add(col("__col_agg_2"), col("__col_agg_3")), col("c"))).into(); + assert_eq!(remainder_expr, Ok(expected_remainder)); // COUNT(2 * a) * 2 + SUM(b) + 1 let expr = add( @@ -488,18 +508,16 @@ mod tests { lit(1), ); let remainder_expr = - get_aggregate_and_remainder_expressions(*expr, &mut 
aggregation_expr_map); + get_aggregate_and_remainder_expressions((*expr).into(), &mut aggregation_expr_map); assert_eq!( - aggregation_expr_map[&(AggregationOperator::Count, *mul(lit(2), col("a")))], + aggregation_expr_map[&("COUNT".to_string(), (*mul(lit(2), col("a"))).into())], "__col_agg_4".into() ); - assert_eq!( - remainder_expr, - Ok(*add( - add(mul(col("__col_agg_4"), lit(2)), col("__col_agg_1")), - lit(1) - )) - ); - assert_eq!(aggregation_expr_map.len(), 5); + let expected_remainder: Expr = (*add( + add(mul(col("__col_agg_4"), lit(2)), col("__col_agg_1")), + lit(1), + )) + .into(); + assert_eq!(remainder_expr, Ok(expected_remainder)); } } diff --git a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs index 3af8fae7e..4ce6858b5 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/group_by_postprocessing_test.rs @@ -1,6 +1,6 @@ use crate::{ base::{ - database::{owned_table_utility::*, OwnedTable}, + database::{expr_utility::*, owned_table_utility::*, OwnedTable}, scalar::Curve25519Scalar, }, sql::postprocessing::{ @@ -9,7 +9,7 @@ use crate::{ }, }; use bigdecimal::BigDecimal; -use proof_of_sql_parser::{intermediate_ast::AggregationOperator, utility::*}; +use sqlparser::ast::Ident; #[test] fn we_cannot_have_invalid_group_bys() { // Column in result but not in group by or aggregation @@ -48,17 +48,16 @@ fn we_can_make_group_by_postprocessing() { aliased_expr(col("__col_agg_1"), "c1"), ] ); - assert_eq!( - res.aggregation_exprs(), - &[ - (AggregationOperator::Sum, *col("a"), "__col_agg_0".into()), - ( - AggregationOperator::Sum, - *add(col("b"), col("a")), - "__col_agg_1".into() - ), - ] - ); + + let expected_aggregation_exprs = vec![ + ("SUM".to_string(), col("a"), Ident::new("__col_agg_0")), + ( + "SUM".to_string(), + add(col("b"), col("a")), + Ident::new("__col_agg_1"), + ), 
+ ]; + assert_eq!(res.aggregation_exprs(), &expected_aggregation_exprs); } #[allow(clippy::too_many_lines)] diff --git a/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing.rs b/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing.rs index 9437c5daf..7a30935b4 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing.rs @@ -5,21 +5,21 @@ use crate::base::{ scalar::Scalar, }; use alloc::vec::Vec; -use proof_of_sql_parser::intermediate_ast::AliasedResultExpr; +use proof_of_sql_parser::sqlparser::SqlAliasedResultExpr; use serde::{Deserialize, Serialize}; -use sqlparser::ast::Ident; +use sqlparser::ast::{Expr, Ident}; /// The select expression used to select, reorder, and apply alias transformations #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct SelectPostprocessing { /// The aliased result expressions we select - aliased_result_exprs: Vec, + aliased_result_exprs: Vec, } impl SelectPostprocessing { /// Create a new `SelectPostprocessing` node. 
#[must_use] - pub fn new(aliased_result_exprs: Vec) -> Self { + pub fn new(aliased_result_exprs: Vec) -> Self { Self { aliased_result_exprs, } @@ -34,8 +34,9 @@ impl PostprocessingStep for SelectPostprocessing { .iter() .map( |aliased_result_expr| -> PostprocessingResult<(Ident, OwnedColumn)> { - let result_column = owned_table.evaluate(&aliased_result_expr.expr)?; - Ok((aliased_result_expr.alias.into(), result_column)) + let sql_expr: Expr = (*aliased_result_expr.expr).clone(); + let result_column = owned_table.evaluate(&sql_expr)?; + Ok((aliased_result_expr.alias.clone(), result_column)) }, ) .collect::>()?; diff --git a/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing_test.rs b/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing_test.rs index f80265096..fbebc3aeb 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing_test.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/select_postprocessing_test.rs @@ -1,11 +1,10 @@ use crate::{ base::{ - database::{owned_table_utility::*, OwnedTable}, + database::{expr_utility::*, owned_table_utility::*, OwnedTable}, scalar::Curve25519Scalar, }, sql::postprocessing::{apply_postprocessing_steps, test_utility::*, OwnedTablePostprocessing}, }; -use proof_of_sql_parser::utility::*; #[test] fn we_can_filter_out_owned_table_columns() { diff --git a/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs b/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs index 24f8904b4..44ededcd1 100644 --- a/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs +++ b/crates/proof-of-sql/src/sql/postprocessing/test_utility.rs @@ -1,11 +1,14 @@ use super::*; -use proof_of_sql_parser::intermediate_ast::{AliasedResultExpr, OrderBy, OrderByDirection}; +use proof_of_sql_parser::{ + intermediate_ast::{OrderBy, OrderByDirection}, + sqlparser::SqlAliasedResultExpr, +}; use sqlparser::ast::Ident; #[must_use] pub fn group_by_postprocessing( cols: &[&str], - result_exprs: 
&[AliasedResultExpr], + result_exprs: &[SqlAliasedResultExpr], ) -> OwnedTablePostprocessing { let ids: Vec = cols.iter().map(|col| (*col).into()).collect(); OwnedTablePostprocessing::new_group_by( @@ -18,7 +21,7 @@ pub fn group_by_postprocessing( /// /// This function may panic if the internal structures cannot be created properly, although this is unlikely under normal circumstances. #[must_use] -pub fn select_expr(result_exprs: &[AliasedResultExpr]) -> OwnedTablePostprocessing { +pub fn select_expr(result_exprs: &[SqlAliasedResultExpr]) -> OwnedTablePostprocessing { OwnedTablePostprocessing::new_select(SelectPostprocessing::new(result_exprs.to_vec())) } diff --git a/crates/proof-of-sql/src/sql/proof/query_proof.rs b/crates/proof-of-sql/src/sql/proof/query_proof.rs index 6b6e4c773..b41facb72 100644 --- a/crates/proof-of-sql/src/sql/proof/query_proof.rs +++ b/crates/proof-of-sql/src/sql/proof/query_proof.rs @@ -24,6 +24,7 @@ use alloc::{string::String, vec, vec::Vec}; use bumpalo::Bump; use core::cmp; use num_traits::Zero; +use proof_of_sql_parser::sqlparser::TimezoneInfoExt; use serde::{Deserialize, Serialize}; /// Return the row number range of tables referenced in the Query @@ -514,9 +515,9 @@ fn extend_transcript_with_owned_table( OwnedColumn::Scalar(col) => { transcript.extend_as_be(col.iter().map(|&s| Into::<[u64; 4]>::into(s))); } - OwnedColumn::TimestampTZ(po_sqltime_unit, po_sqltime_zone, col) => { + OwnedColumn::TimestampTZ(po_sqltime_unit, timezone_info, col) => { transcript.extend_as_be([u64::from(*po_sqltime_unit)]); - transcript.extend_as_be([po_sqltime_zone.offset()]); + transcript.extend_as_be([timezone_info.offset(Some("+00:00"))]); transcript.extend_as_be_from_refs(col); } } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/aggregation.rs b/crates/proof-of-sql/src/sql/proof_exprs/aggregation.rs new file mode 100644 index 000000000..33f02be53 --- /dev/null +++ b/crates/proof-of-sql/src/sql/proof_exprs/aggregation.rs @@ -0,0 +1,33 @@ +use 
alloc::{ + fmt, + fmt::{Display, Formatter}, +}; +use serde::{Deserialize, Serialize}; + +/// Aggregation operators +#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)] +/// Aggregation operators +pub enum AggOperator { + /// Maximum + Max, + /// Minimum + Min, + /// Sum + Sum, + /// Count + Count, + /// Return the first value + First, +} + +impl Display for AggOperator { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self { + AggOperator::Max => write!(f, "max"), + AggOperator::Min => write!(f, "min"), + AggOperator::Sum => write!(f, "sum"), + AggOperator::Count => write!(f, "count"), + AggOperator::First => write!(f, "first"), + } + } +} diff --git a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs index 92601f512..ac73b2e44 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/comparison_util.rs @@ -1,16 +1,19 @@ use crate::{ base::{ - database::{Column, ColumnarValue, LiteralValue}, + database::{ + literal_value::{ExprExt, ToScalar}, + Column, ColumnError, ColumnarValue, + }, math::decimal::{DecimalError, Precision}, scalar::{Scalar, ScalarExt}, slice_ops, }, sql::parse::{type_check_binary_operation, ConversionError, ConversionResult}, }; -use alloc::string::ToString; +use alloc::{format, string::ToString, vec}; use bumpalo::Bump; use core::cmp::{max, Ordering}; -use sqlparser::ast::BinaryOperator; +use sqlparser::ast::{BinaryOperator, DataType, Expr as SqlExpr, ObjectName}; /// Scale LHS and RHS to the same scale if at least one of them is decimal /// and take the difference. This function is used for comparisons. @@ -20,8 +23,8 @@ use sqlparser::ast::BinaryOperator; /// or if we have precision overflow issues. 
#[allow(clippy::cast_sign_loss)] pub fn scale_and_subtract_literal( - lhs: &LiteralValue, - rhs: &LiteralValue, + lhs: &SqlExpr, + rhs: &SqlExpr, lhs_scale: i8, rhs_scale: i8, is_equal: bool, @@ -156,37 +159,48 @@ pub(crate) fn scale_and_subtract_columnar_value<'a, S: Scalar>( lhs_scale: i8, rhs_scale: i8, is_equal: bool, -) -> ConversionResult> { +) -> Result, ColumnError> { match (lhs, rhs) { (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { - Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( - alloc, lhs, rhs, lhs_scale, rhs_scale, is_equal, - )?))) + Ok(ColumnarValue::Column(Column::Scalar( + scale_and_subtract(alloc, lhs, rhs, lhs_scale, rhs_scale, is_equal) + .map_err(|err| ColumnError::ConversionError { source: err })?, + ))) } (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { - Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( - alloc, - Column::from_literal_with_length(&lhs, rhs.len(), alloc), - rhs, - lhs_scale, - rhs_scale, - is_equal, - )?))) + Ok(ColumnarValue::Column(Column::Scalar( + scale_and_subtract( + alloc, + Column::from_literal_with_length(&lhs, rhs.len(), alloc)?, + rhs, + lhs_scale, + rhs_scale, + is_equal, + ) + .map_err(|err| ColumnError::ConversionError { source: err })?, + ))) } (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { - Ok(ColumnarValue::Column(Column::Scalar(scale_and_subtract( - alloc, - lhs, - Column::from_literal_with_length(&rhs, lhs.len(), alloc), - lhs_scale, - rhs_scale, - is_equal, - )?))) + Ok(ColumnarValue::Column(Column::Scalar( + scale_and_subtract( + alloc, + lhs, + Column::from_literal_with_length(&rhs, lhs.len(), alloc)?, + lhs_scale, + rhs_scale, + is_equal, + ) + .map_err(|err| ColumnError::ConversionError { source: err })?, + ))) } (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { - Ok(ColumnarValue::Literal(LiteralValue::Scalar( - scale_and_subtract_literal::(&lhs, &rhs, lhs_scale, rhs_scale, is_equal)?.into(), - ))) + let result_scalar 
= + scale_and_subtract_literal::(&lhs, &rhs, lhs_scale, rhs_scale, is_equal) + .map_err(|err| ColumnError::ConversionError { source: err })?; + Ok(ColumnarValue::Literal(SqlExpr::TypedString { + data_type: DataType::Custom(ObjectName(vec![]), vec![]), + value: format!("scalar:{result_scalar}"), + })) } } } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs index efe540b52..a02ba8bf3 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/dyn_proof_expr.rs @@ -4,7 +4,7 @@ use super::{ }; use crate::{ base::{ - database::{Column, ColumnRef, ColumnType, LiteralValue, Table}, + database::{Column, ColumnRef, ColumnType, Table}, map::{IndexMap, IndexSet}, proof::ProofError, scalar::Scalar, @@ -14,12 +14,12 @@ use crate::{ proof::{FinalRoundBuilder, VerificationBuilder}, }, }; -use alloc::{boxed::Box, string::ToString}; +use alloc::{boxed::Box, format, string::ToString}; use bumpalo::Bump; use core::fmt::Debug; use proof_of_sql_parser::intermediate_ast::AggregationOperator; use serde::{Deserialize, Serialize}; -use sqlparser::ast::BinaryOperator; +use sqlparser::ast::{BinaryOperator, Expr as SqlExpr}; /// Enum of AST column expression types that implement `ProofExpr`. Is itself a `ProofExpr`. 
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] @@ -69,8 +69,8 @@ impl DynProofExpr { Ok(Self::Not(NotExpr::new(Box::new(expr)))) } /// Create CONST expression - pub fn new_literal(value: LiteralValue) -> Self { - Self::Literal(LiteralExpr::new(value)) + pub fn new_literal(expr: SqlExpr) -> Self { + Self::Literal(LiteralExpr::new(expr)) } /// Create a new equals expression pub fn try_new_equals(lhs: DynProofExpr, rhs: DynProofExpr) -> ConversionResult { @@ -161,8 +161,22 @@ impl DynProofExpr { } /// Create a new aggregate expression - pub fn new_aggregate(op: AggregationOperator, expr: DynProofExpr) -> Self { - Self::Aggregate(AggregateExpr::new(op, Box::new(expr))) + pub fn new_aggregate(op: &str, expr: DynProofExpr) -> Result { + let aggregation_operator = match op.to_uppercase().as_str() { + "SUM" => AggregationOperator::Sum, + "COUNT" => AggregationOperator::Count, + "MAX" => AggregationOperator::Max, + "MIN" => AggregationOperator::Min, + _ => { + return Err(ConversionError::Unprovable { + error: format!("Unsupported aggregation operator: {op}"), + }) + } + }; + Ok(Self::Aggregate(AggregateExpr::new( + aggregation_operator, + Box::new(expr), + ))) } /// Check that the plan has the correct data type diff --git a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs index fa83572d3..9bcbd4793 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/inequality_expr_test.rs @@ -2,8 +2,8 @@ use crate::{ base::{ commitment::InnerProductProof, database::{ - owned_table_utility::*, table_utility::*, Column, LiteralValue, OwnedTable, - OwnedTableTestAccessor, TableTestAccessor, TestAccessor, + owned_table_utility::*, table_utility::*, Column, OwnedTable, OwnedTableTestAccessor, + TableTestAccessor, TestAccessor, }, scalar::{Curve25519Scalar, Scalar, ScalarExt}, }, @@ -16,19 +16,20 @@ use crate::{ }; use bumpalo::Bump; 
use itertools::{multizip, MultiUnzip}; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; use rand::{ distributions::{Distribution, Uniform}, rngs::StdRng, }; use rand_core::SeedableRng; +use sqlparser::ast::{DataType, Expr, TimezoneInfo}; #[test] fn we_can_compare_columns_with_small_timestamp_values_gte() { let data: OwnedTable = owned_table([timestamptz( "a", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::WithTimeZone, vec![-1, 0, 1], )]); let t = "sxt.t".parse().unwrap(); @@ -38,11 +39,10 @@ fn we_can_compare_columns_with_small_timestamp_values_gte() { tab(t), gte( column(t, "a", &accessor), - DynProofExpr::new_literal(LiteralValue::TimeStampTZ( - PoSQLTimeUnit::Nanosecond, - PoSQLTimeZone::utc(), - 1, - )), + DynProofExpr::new_literal(Expr::TypedString { + data_type: DataType::Timestamp(Some(1), TimezoneInfo::WithTimeZone), + value: "1".to_string(), + }), ), ); @@ -51,7 +51,7 @@ fn we_can_compare_columns_with_small_timestamp_values_gte() { let expected_res = owned_table([timestamptz( "a", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::WithTimeZone, vec![1], )]); assert_eq!(res, expected_res); @@ -62,7 +62,7 @@ fn we_can_compare_columns_with_small_timestamp_values_lte() { let data: OwnedTable = owned_table([timestamptz( "a", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::WithTimeZone, vec![-1, 0, 1], )]); let t = "sxt.t".parse().unwrap(); @@ -72,11 +72,10 @@ fn we_can_compare_columns_with_small_timestamp_values_lte() { tab(t), lte( column(t, "a", &accessor), - DynProofExpr::new_literal(LiteralValue::TimeStampTZ( - PoSQLTimeUnit::Nanosecond, - PoSQLTimeZone::utc(), - 1, - )), + DynProofExpr::new_literal(Expr::TypedString { + data_type: DataType::Timestamp(Some(1), TimezoneInfo::WithTimeZone), + value: "1".to_string(), + }), ), ); @@ -85,7 +84,7 @@ fn we_can_compare_columns_with_small_timestamp_values_lte() { let expected_res = 
owned_table([timestamptz( "a", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::WithTimeZone, vec![-1, 0], )]); assert_eq!(res, expected_res); diff --git a/crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs b/crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs index e4cd8cf0e..6fc4f0e62 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/literal_expr.rs @@ -1,7 +1,7 @@ use super::ProofExpr; use crate::{ base::{ - database::{Column, ColumnRef, ColumnType, LiteralValue, Table}, + database::{Column, ColumnRef, ColumnType, ExprExt, Table, ToScalar}, map::{IndexMap, IndexSet}, proof::ProofError, scalar::Scalar, @@ -11,7 +11,7 @@ use crate::{ }; use bumpalo::Bump; use serde::{Deserialize, Serialize}; - +use sqlparser::ast::Expr; /// Provable CONST expression /// /// This node allows us to easily represent queries like @@ -25,12 +25,12 @@ use serde::{Deserialize, Serialize}; /// changes, and the performance is sufficient for present. 
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct LiteralExpr { - pub(crate) value: LiteralValue, + pub(crate) value: Expr, } impl LiteralExpr { /// Create literal expression - pub fn new(value: LiteralValue) -> Self { + pub fn new(value: Expr) -> Self { Self { value } } } @@ -52,7 +52,7 @@ impl ProofExpr for LiteralExpr { log::log_memory_usage("End"); - res + res.expect("Failed to evaluate literal expression") } #[tracing::instrument(name = "LiteralExpr::prover_evaluate", level = "debug", skip_all)] @@ -69,7 +69,7 @@ impl ProofExpr for LiteralExpr { log::log_memory_usage("End"); - res + res.expect("Failed to evaluate literal expression") } fn verifier_evaluate( diff --git a/crates/proof-of-sql/src/sql/proof_exprs/mod.rs b/crates/proof-of-sql/src/sql/proof_exprs/mod.rs index 298ad945d..318e33073 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/mod.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/mod.rs @@ -14,6 +14,7 @@ mod add_subtract_expr_test; mod aggregate_expr; pub(crate) use aggregate_expr::AggregateExpr; +mod aggregation; mod multiply_expr; use multiply_expr::MultiplyExpr; diff --git a/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs b/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs index 79f16167a..c5ddca301 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/numerical_util.rs @@ -1,15 +1,17 @@ use crate::base::{ - database::{Column, ColumnarValue, LiteralValue}, + database::{literal_value::ToScalar, Column, ColumnError, ColumnarValue}, scalar::{Scalar, ScalarExt}, }; +use alloc::{format, vec}; use bumpalo::Bump; use core::cmp::Ordering; +use sqlparser::ast::{DataType, Expr as SqlExpr, ObjectName}; #[allow(clippy::cast_sign_loss)] /// Add or subtract two literals together. 
pub(crate) fn add_subtract_literals( - lhs: &LiteralValue, - rhs: &LiteralValue, + lhs: &SqlExpr, + rhs: &SqlExpr, lhs_scale: i8, rhs_scale: i8, is_subtract: bool, @@ -73,42 +75,46 @@ pub(crate) fn add_subtract_columnar_values<'a, S: Scalar>( rhs_scale: i8, alloc: &'a Bump, is_subtract: bool, -) -> ColumnarValue<'a, S> { +) -> Result, ColumnError> { match (lhs, rhs) { (ColumnarValue::Column(lhs), ColumnarValue::Column(rhs)) => { - ColumnarValue::Column(Column::Scalar(add_subtract_columns( + Ok(ColumnarValue::Column(Column::Scalar(add_subtract_columns( lhs, rhs, lhs_scale, rhs_scale, alloc, is_subtract, - ))) + )))) } (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { - ColumnarValue::Column(Column::Scalar(add_subtract_columns( - Column::from_literal_with_length(&lhs, rhs.len(), alloc), + Ok(ColumnarValue::Column(Column::Scalar(add_subtract_columns( + Column::from_literal_with_length(&lhs, rhs.len(), alloc)?, rhs, lhs_scale, rhs_scale, alloc, is_subtract, - ))) + )))) } (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { - ColumnarValue::Column(Column::Scalar(add_subtract_columns( + let rhs_column = Column::from_literal_with_length(&rhs, lhs.len(), alloc)?; + Ok(ColumnarValue::Column(Column::Scalar(add_subtract_columns( lhs, - Column::from_literal_with_length(&rhs, lhs.len(), alloc), + rhs_column, lhs_scale, rhs_scale, alloc, is_subtract, - ))) + )))) } (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { - ColumnarValue::Literal(LiteralValue::Scalar( - add_subtract_literals::(&lhs, &rhs, lhs_scale, rhs_scale, is_subtract).into(), - )) + let result_scalar = + add_subtract_literals::(&lhs, &rhs, lhs_scale, rhs_scale, is_subtract); + Ok(ColumnarValue::Literal(SqlExpr::TypedString { + data_type: DataType::Custom(ObjectName(vec![]), vec![]), + value: format!("scalar:{result_scalar}"), + })) } } } @@ -146,20 +152,23 @@ pub(crate) fn multiply_columnar_values<'a, S: Scalar>( ColumnarValue::Column(Column::Scalar(multiply_columns(lhs, rhs, 
alloc))) } (ColumnarValue::Literal(lhs), ColumnarValue::Column(rhs)) => { - let lhs_scalar = lhs.to_scalar::(); + let lhs_scalar = (*lhs).to_scalar::(); let result = alloc.alloc_slice_fill_with(rhs.len(), |i| lhs_scalar * rhs.scalar_at(i).unwrap()); ColumnarValue::Column(Column::Scalar(result)) } (ColumnarValue::Column(lhs), ColumnarValue::Literal(rhs)) => { - let rhs_scalar = rhs.to_scalar(); + let rhs_scalar = (*rhs).to_scalar(); let result = alloc.alloc_slice_fill_with(lhs.len(), |i| lhs.scalar_at(i).unwrap() * rhs_scalar); ColumnarValue::Column(Column::Scalar(result)) } (ColumnarValue::Literal(lhs), ColumnarValue::Literal(rhs)) => { - let result = lhs.to_scalar::() * rhs.to_scalar(); - ColumnarValue::Literal(LiteralValue::Scalar(result.into())) + let result = (*lhs).to_scalar::() * (*rhs).to_scalar(); + ColumnarValue::Literal(SqlExpr::TypedString { + data_type: DataType::Custom(ObjectName(vec![]), vec![]), + value: format!("scalar:{result}"), + }) } } } diff --git a/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs b/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs index baa8b44e5..22c2adac8 100644 --- a/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs +++ b/crates/proof-of-sql/src/sql/proof_exprs/test_utility.rs @@ -1,11 +1,10 @@ use super::{AliasedDynProofExpr, ColumnExpr, DynProofExpr, TableExpr}; use crate::base::{ - database::{ColumnRef, LiteralValue, SchemaAccessor, TableRef}, - math::{decimal::Precision, i256::I256}, + database::{ColumnRef, SchemaAccessor, TableRef}, + math::i256::I256, scalar::Scalar, }; -use proof_of_sql_parser::intermediate_ast::AggregationOperator; -use sqlparser::ast::Ident; +use sqlparser::ast::{DataType, ExactNumberInfo, Expr, Ident, ObjectName, Value}; pub fn col_ref(tab: TableRef, name: &str, accessor: &impl SchemaAccessor) -> ColumnRef { let name: Ident = name.into(); @@ -86,43 +85,52 @@ pub fn multiply(left: DynProofExpr, right: DynProofExpr) -> DynProofExpr { } pub fn const_bool(val: bool) -> DynProofExpr 
{ - DynProofExpr::new_literal(LiteralValue::Boolean(val)) + DynProofExpr::new_literal(Expr::Value(Value::Boolean(val))) } pub fn const_smallint(val: i16) -> DynProofExpr { - DynProofExpr::new_literal(LiteralValue::SmallInt(val)) + DynProofExpr::new_literal(Expr::Value(Value::Number(val.to_string(), false))) } pub fn const_int(val: i32) -> DynProofExpr { - DynProofExpr::new_literal(LiteralValue::Int(val)) + DynProofExpr::new_literal(Expr::Value(Value::Number(val.to_string(), false))) } pub fn const_bigint(val: i64) -> DynProofExpr { - DynProofExpr::new_literal(LiteralValue::BigInt(val)) + DynProofExpr::new_literal(Expr::Value(Value::Number(val.to_string(), false))) } pub fn const_int128(val: i128) -> DynProofExpr { - DynProofExpr::new_literal(LiteralValue::Int128(val)) + DynProofExpr::new_literal(Expr::Value(Value::Number(val.to_string(), false))) } pub fn const_varchar(val: &str) -> DynProofExpr { - DynProofExpr::new_literal(LiteralValue::VarChar(val.to_string())) + DynProofExpr::new_literal(Expr::Value(Value::SingleQuotedString(val.to_string()))) } /// Create a constant scalar value. Used if we don't want to specify column types. pub fn const_scalar>(val: T) -> DynProofExpr { - DynProofExpr::new_literal(LiteralValue::Scalar(val.into().into())) + let scalar_str = format!("scalar:{}", val.into()); + + DynProofExpr::new_literal(Expr::TypedString { + data_type: DataType::Custom(ObjectName(vec![Ident::new("scalar")]), vec![]), + value: scalar_str, + }) } /// # Panics /// Panics if: /// - `Precision::new(precision)` fails, meaning the provided precision is invalid. 
pub fn const_decimal75>(precision: u8, scale: i8, val: T) -> DynProofExpr { - DynProofExpr::new_literal(LiteralValue::Decimal75( - Precision::new(precision).unwrap(), - scale, - val.into(), - )) + let decimal_value = val.into(); + let decimal_str = format!("{decimal_value}e{scale}"); + DynProofExpr::new_literal(Expr::TypedString { + data_type: DataType::Decimal(ExactNumberInfo::PrecisionAndScale( + u64::from(precision), + i64::from(scale).try_into().unwrap(), + )), + value: decimal_str, + }) } pub fn tab(tab: TableRef) -> TableExpr { @@ -208,7 +216,8 @@ pub fn cols_expr(tab: TableRef, names: &[&str], accessor: &impl SchemaAccessor) /// - `alias.parse()` fails to parse the provided alias string. pub fn sum_expr(expr: DynProofExpr, alias: &str) -> AliasedDynProofExpr { AliasedDynProofExpr { - expr: DynProofExpr::new_aggregate(AggregationOperator::Sum, expr), + expr: DynProofExpr::new_aggregate("SUM", expr) + .expect("Failed to create aggregate expression"), alias: alias.into(), } } diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs index 3a74c0e6e..caad49733 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test.rs @@ -3,8 +3,7 @@ use crate::{ base::{ database::{ owned_table_utility::*, table_utility::*, ColumnField, ColumnRef, ColumnType, - LiteralValue, OwnedTable, OwnedTableTestAccessor, TableRef, TableTestAccessor, - TestAccessor, + OwnedTable, OwnedTableTestAccessor, TableRef, TableTestAccessor, TestAccessor, }, map::{indexmap, IndexMap, IndexSet}, math::decimal::Precision, @@ -21,7 +20,7 @@ use crate::{ use blitzar::proof::InnerProductProof; use bumpalo::Bump; use proof_of_sql_parser::ResourceId; -use sqlparser::ast::Ident; +use sqlparser::ast::{Expr, Ident, Value}; #[test] fn we_can_correctly_fetch_the_query_result_schema() { @@ -54,7 +53,10 @@ fn we_can_correctly_fetch_the_query_result_schema() { 
Ident::new("c"), ColumnType::BigInt, ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(123))), + DynProofExpr::Literal(LiteralExpr::new(Expr::Value(Value::Number( + "123".to_string(), + false, + )))), ) .unwrap(), ); @@ -102,7 +104,10 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { Ident::new("f"), ColumnType::BigInt, ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(45))), + DynProofExpr::Literal(LiteralExpr::new(Expr::Value(Value::Number( + "45".to_string(), + false, + )))), ) .unwrap(), DynProofExpr::try_new_equals( @@ -111,7 +116,10 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { Ident::new("c"), ColumnType::BigInt, ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(-2))), + DynProofExpr::Literal(LiteralExpr::new(Expr::Value(Value::Number( + "-2".to_string(), + false, + )))), ) .unwrap(), ), @@ -121,7 +129,10 @@ fn we_can_correctly_fetch_all_the_referenced_columns() { Ident::new("b"), ColumnType::BigInt, ))), - DynProofExpr::Literal(LiteralExpr::new(LiteralValue::BigInt(3))), + DynProofExpr::Literal(LiteralExpr::new(Expr::Value(Value::Number( + "3".to_string(), + false, + )))), ) .unwrap(), )), diff --git a/crates/proof-of-sql/tests/timestamp_integration_tests.rs b/crates/proof-of-sql/tests/timestamp_integration_tests.rs index d45f14da7..855aaaf0b 100644 --- a/crates/proof-of-sql/tests/timestamp_integration_tests.rs +++ b/crates/proof-of-sql/tests/timestamp_integration_tests.rs @@ -11,7 +11,8 @@ use proof_of_sql::{ }, sql::{parse::QueryExpr, proof::VerifiableQueryResult}, }; -use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone}; +use proof_of_sql_parser::posql_time::PoSQLTimeUnit; +use sqlparser::ast::TimezoneInfo; #[test] fn we_can_prove_a_basic_query_containing_rfc3339_timestamp_with_dory() { @@ -34,7 +35,7 @@ fn we_can_prove_a_basic_query_containing_rfc3339_timestamp_with_dory() { timestamptz( "times", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, 
[i64::MIN, 0, i64::MAX], ), ]), @@ -60,7 +61,7 @@ fn we_can_prove_a_basic_query_containing_rfc3339_timestamp_with_dory() { let expected_result = owned_table([timestamptz( "times", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [0], )]); assert_eq!(owned_table_result, expected_result); @@ -81,7 +82,7 @@ fn run_timestamp_query_test( owned_table([timestamptz( "times", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, test_timestamps, )]), 0, @@ -100,7 +101,7 @@ fn run_timestamp_query_test( let expected_result = owned_table([timestamptz( "times", PoSQLTimeUnit::Second, - PoSQLTimeZone::utc(), + TimezoneInfo::None, expected_timestamps, )]); @@ -396,7 +397,7 @@ fn we_can_prove_timestamp_inequality_queries_with_multiple_columns() { timestamptz( "a", PoSQLTimeUnit::Nanosecond, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [ i64::MIN, 2, @@ -411,7 +412,7 @@ fn we_can_prove_timestamp_inequality_queries_with_multiple_columns() { timestamptz( "b", PoSQLTimeUnit::Nanosecond, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [ i64::MAX, -2, @@ -447,13 +448,13 @@ fn we_can_prove_timestamp_inequality_queries_with_multiple_columns() { timestamptz( "a", PoSQLTimeUnit::Nanosecond, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [i64::MIN, -1, -2], ), timestamptz( "b", PoSQLTimeUnit::Nanosecond, - PoSQLTimeZone::utc(), + TimezoneInfo::None, [i64::MAX, -1, 1], ), boolean("res", [true, true, true]),