Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: Expression to use Expr #456

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* text=auto
16 changes: 16 additions & 0 deletions crates/proof-of-sql-parser/src/posql_time/unit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::PoSQLTimestampError;
use crate::alloc::string::ToString;
use core::fmt;
use serde::{Deserialize, Serialize};

Expand All @@ -16,6 +17,21 @@ pub enum PoSQLTimeUnit {
Nanosecond,
}

impl PoSQLTimeUnit {
/// Converts a precision value into a corresponding `PoSQLTimeUnit`.
pub fn from_precision(precision: u64) -> Result<Self, PoSQLTimestampError> {
match precision {
0 => Ok(PoSQLTimeUnit::Second),
3 => Ok(PoSQLTimeUnit::Millisecond),
6 => Ok(PoSQLTimeUnit::Microsecond),
9 => Ok(PoSQLTimeUnit::Nanosecond),
_ => Err(PoSQLTimestampError::UnsupportedPrecision {
error: precision.to_string(),
}),
}
}
}

impl From<PoSQLTimeUnit> for u64 {
fn from(value: PoSQLTimeUnit) -> u64 {
match value {
Expand Down
88 changes: 86 additions & 2 deletions crates/proof-of-sql-parser/src/sqlparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@ use crate::{
OrderBy as PoSqlOrderBy, OrderByDirection, SelectResultExpr, SetExpression,
TableExpression, UnaryOperator as PoSqlUnaryOperator,
},
posql_time::{PoSQLTimeUnit, PoSQLTimeZone},
Identifier, ResourceId, SelectStatement,
};
use alloc::{boxed::Box, string::ToString, vec};
use alloc::{
boxed::Box,
string::{String, ToString},
vec,
};
use core::fmt::Display;
use serde::{Deserialize, Serialize};
use sqlparser::ast::{
BinaryOperator, DataType, Expr, Function, FunctionArg, FunctionArgExpr, GroupByExpr, Ident,
ObjectName, Offset, OffsetRows, OrderByExpr, Query, Select, SelectItem, SetExpr, TableFactor,
Expand All @@ -28,6 +34,79 @@ fn id(id: Identifier) -> Expr {
Expr::Identifier(id.into())
}

#[must_use]
/// New `AliasedResultExpr` using sqlparser types
/// Represents an aliased SQL expression, e.g., `a + 1 AS alias`.
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct SqlAliasedResultExpr {
/// The SQL expression being aliased (e.g., `a + 1`).
pub expr: Box<Expr>,
/// The alias for the expression (e.g., `alias` in `a + 1 AS alias`).
pub alias: Ident,
}

impl SqlAliasedResultExpr {
/// Creates a new `SqlAliasedResultExpr`.
pub fn new(expr: Box<Expr>, alias: Ident) -> Self {
Self { expr, alias }
}

/// Try to get the identifier of the expression if it is a column
/// Otherwise, return None
#[must_use]
pub fn try_as_identifier(&self) -> Option<&Ident> {
match self.expr.as_ref() {
Expr::Identifier(identifier) => Some(identifier),
_ => None,
}
}
}

/// Provides an extension for the `TimezoneInfo` type for offsets.
pub trait TimezoneInfoExt {
/// Retrieve the offset in seconds for `TimezoneInfo`.
fn offset(&self, timezone_str: Option<&str>) -> i32;
}

impl TimezoneInfoExt for TimezoneInfo {
fn offset(&self, timezone_str: Option<&str>) -> i32 {
match self {
TimezoneInfo::None => PoSQLTimeZone::utc().offset(),
TimezoneInfo::WithTimeZone => match timezone_str {
Some(tz_str) => PoSQLTimeZone::try_from(&Some(tz_str.into()))
.unwrap_or_else(|_| PoSQLTimeZone::utc())
.offset(),
None => PoSQLTimeZone::utc().offset(),
},
_ => panic!("Offsets are not applicable for WithoutTimeZone or Tz variants."),
}
}
}

/// Utility function to create a `Timestamp` expression.
pub fn timestamp_to_expr(
value: &str,
time_unit: PoSQLTimeUnit,
timezone: TimezoneInfo,
) -> Result<Expr, String> {
let time_unit_as_u64 = u64::from(time_unit);

Ok(Expr::TypedString {
data_type: DataType::Timestamp(Some(time_unit_as_u64), timezone),
value: value.to_string(),
})
}

/// Parses [`PoSQLTimeZone`] into a `TimezoneInfo`.
impl From<PoSQLTimeZone> for TimezoneInfo {
fn from(posql_timezone: PoSQLTimeZone) -> Self {
match posql_timezone.offset() {
0 => TimezoneInfo::None,
_ => TimezoneInfo::WithTimeZone,
}
}
}

impl From<Identifier> for Ident {
fn from(id: Identifier) -> Self {
Ident::new(id.as_str())
Expand Down Expand Up @@ -125,7 +204,7 @@ impl From<PoSqlOrderBy> for OrderByExpr {
impl From<Expression> for Expr {
fn from(expr: Expression) -> Self {
match expr {
Expression::Literal(literal) => literal.into(),
Expression::Literal(literal) => Expr::from(literal),
Expression::Column(identifier) => id(identifier),
Expression::Unary { op, expr } => Expr::UnaryOp {
op: op.into(),
Expand Down Expand Up @@ -268,6 +347,11 @@ mod test {
"select timestamp '2024-11-07T04:55:12.345+03:00' as time from t;",
"select timestamp(3) '2024-11-07 01:55:12.345 UTC' as time from t;",
);

check_posql_intermediate_ast_to_sqlparser_equivalence(
"select timestamp '2024-11-07T04:55:12+00:00' as time from t;",
"select timestamp(0) '2024-11-07 04:55:12 UTC' as time from t;",
);
}

// Check that PoSQL intermediate AST can be converted to SQL parser AST and that the two are equal.
Expand Down
5 changes: 3 additions & 2 deletions crates/proof-of-sql/benches/bench_append_rows.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use proof_of_sql::{
DoryCommitment, DoryProverPublicSetup, DoryScalar, ProverSetup, PublicParameters,
},
};
use proof_of_sql_parser::posql_time::{PoSQLTimeUnit, PoSQLTimeZone};
use proof_of_sql_parser::posql_time::PoSQLTimeUnit;
use rand::Rng;
use sqlparser::ast::TimezoneInfo;

/// Bench dory performance when appending rows to a table. This includes the computation of
/// commitments. Chose the number of columns to randomly generate across supported `PoSQL`
Expand Down Expand Up @@ -121,7 +122,7 @@ pub fn generate_random_owned_table<S: Scalar>(
"timestamptz" => columns.push(timestamptz(
&*identifier,
PoSQLTimeUnit::Second,
PoSQLTimeZone::utc(),
TimezoneInfo::None,
vec![rng.gen::<i64>(); num_rows],
)),
_ => unreachable!(),
Expand Down
120 changes: 65 additions & 55 deletions crates/proof-of-sql/src/base/arrow/arrow_array_to_column_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,60 +202,69 @@ impl ArrayRefExt for ArrayRef {
}
}
// Handle all possible TimeStamp TimeUnit instances
DataType::Timestamp(time_unit, tz) => match time_unit {
ArrowTimeUnit::Second => {
if let Some(array) = self.as_any().downcast_ref::<TimestampSecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
DataType::Timestamp(time_unit, tz) => {
let timezone = PoSQLTimeZone::try_from(tz)?;
match time_unit {
ArrowTimeUnit::Second => {
if let Some(array) = self.as_any().downcast_ref::<TimestampSecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Second,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Millisecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampMillisecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Millisecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Millisecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampMillisecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Millisecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Microsecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampMicrosecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Microsecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Microsecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampMicrosecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Microsecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
ArrowTimeUnit::Nanosecond => {
if let Some(array) = self.as_any().downcast_ref::<TimestampNanosecondArray>() {
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
PoSQLTimeZone::try_from(tz)?,
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
ArrowTimeUnit::Nanosecond => {
if let Some(array) =
self.as_any().downcast_ref::<TimestampNanosecondArray>()
{
Ok(Column::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
timezone.into(),
&array.values()[range.start..range.end],
))
} else {
Err(ArrowArrayToColumnConversionError::UnsupportedType {
datatype: self.data_type().clone(),
})
}
}
}
},
}
DataType::Utf8 => {
if let Some(array) = self.as_any().downcast_ref::<StringArray>() {
let vals = alloc
Expand Down Expand Up @@ -292,6 +301,7 @@ mod tests {
use alloc::sync::Arc;
use arrow::array::Decimal256Builder;
use core::str::FromStr;
use sqlparser::ast::TimezoneInfo;

#[test]
fn we_can_convert_timestamp_array_normal_range() {
Expand All @@ -305,7 +315,7 @@ mod tests {
let result = array.to_column::<TestScalar>(&alloc, &(1..3), None);
assert_eq!(
result.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3])
);
}

Expand All @@ -323,7 +333,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}

Expand All @@ -339,7 +349,7 @@ mod tests {
let result = array.to_column::<DoryScalar>(&alloc, &(1..1), None);
assert_eq!(
result.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}

Expand Down Expand Up @@ -1006,7 +1016,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[..])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[..])
);
}

Expand Down Expand Up @@ -1076,7 +1086,7 @@ mod tests {
array
.to_column::<TestScalar>(&alloc, &(1..3), None)
.unwrap(),
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &data[1..3])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &data[1..3])
);
}

Expand Down Expand Up @@ -1134,7 +1144,7 @@ mod tests {
.unwrap();
assert_eq!(
result,
Column::TimestampTZ(PoSQLTimeUnit::Second, PoSQLTimeZone::utc(), &[])
Column::TimestampTZ(PoSQLTimeUnit::Second, TimezoneInfo::None, &[])
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl TryFrom<DataType> for ColumnType {
};
Ok(ColumnType::TimestampTZ(
posql_time_unit,
PoSQLTimeZone::try_from(&timezone_option)?,
PoSQLTimeZone::try_from(&timezone_option)?.into(),
))
}
DataType::Utf8 => Ok(ColumnType::VarChar),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Second,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -252,7 +252,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Millisecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -266,7 +266,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Microsecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand All @@ -280,7 +280,7 @@ impl<S: Scalar> TryFrom<&ArrayRef> for OwnedColumn<S> {
let timestamps = array.values().iter().copied().collect::<Vec<i64>>();
Ok(OwnedColumn::TimestampTZ(
PoSQLTimeUnit::Nanosecond,
PoSQLTimeZone::try_from(timezone)?,
PoSQLTimeZone::try_from(timezone)?.into(),
timestamps,
))
}
Expand Down
Loading
Loading