From 4670853207c76610da3faeee72f0665c1a816f3b Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Wed, 4 Sep 2024 18:19:30 +0400 Subject: [PATCH] feat: `WITHIN GROUP` expression support --- src/ast/mod.rs | 23 +++++++++++++++++++++++ src/parser.rs | 23 +++++++++++++++++++++-- tests/sqlparser_common.rs | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 2 deletions(-) diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 5ae4679af..7b3b300b8 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -348,6 +348,8 @@ pub enum Expr { ListAgg(ListAgg), /// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)` ArrayAgg(ArrayAgg), + /// The `WITHIN GROUP` expr `... WITHIN GROUP (ORDER BY ...)` + WithinGroup(WithinGroup), /// The `GROUPING SETS` expr. GroupingSets(Vec>), /// The `CUBE` expr. @@ -549,6 +551,7 @@ impl fmt::Display for Expr { Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s), Expr::ListAgg(listagg) => write!(f, "{}", listagg), Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg), + Expr::WithinGroup(withingroup) => write!(f, "{}", withingroup), Expr::GroupingSets(sets) => { write!(f, "GROUPING SETS (")?; let mut sep = ""; @@ -2523,6 +2526,26 @@ impl fmt::Display for ArrayAgg { } } +/// A `WITHIN GROUP` invocation ` WITHIN GROUP (ORDER BY )` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub struct WithinGroup { + pub expr: Box, + pub order_by: Vec, +} + +impl fmt::Display for WithinGroup { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} WITHIN GROUP (ORDER BY {})", + self.expr, + display_comma_separated(&self.order_by), + )?; + Ok(()) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub enum ObjectType { diff --git a/src/parser.rs b/src/parser.rs index 753c6a11d..ebd8b0bfe 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -627,14 +627,33 @@ impl<'a> Parser<'a> { None }; - Ok(Expr::Function(Function { + let within_group = if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) { + self.expect_token(&Token::LParen)?; + self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?; + let order_by_expr = self.parse_comma_separated(Parser::parse_order_by_expr)?; + self.expect_token(&Token::RParen)?; + Some(order_by_expr) + } else { + None + }; + + let function = Expr::Function(Function { name, args, over, distinct, special: false, approximate: false, - })) + }); + + Ok(if let Some(within_group) = within_group { + Expr::WithinGroup(WithinGroup { + expr: Box::new(function), + order_by: within_group, + }) + } else { + function + }) } pub fn parse_time_functions(&mut self, name: ObjectName) -> Result { diff --git a/tests/sqlparser_common.rs b/tests/sqlparser_common.rs index 96071f1f2..9ec01a60a 100644 --- a/tests/sqlparser_common.rs +++ b/tests/sqlparser_common.rs @@ -1736,6 +1736,41 @@ fn parse_array_agg_func() { } } +#[test] +fn parse_within_group() { + let sql = "SELECT PERCENTILE_CONT(0.0) WITHIN GROUP (ORDER BY name ASC NULLS FIRST)"; + let select = verified_only_select(sql); + + #[cfg(feature = "bigdecimal")] + let value = bigdecimal::BigDecimal::from(0); + #[cfg(not(feature = "bigdecimal"))] + let value = "0.0".to_string(); + let expr = Expr::Value(Value::Number(value, false)); + let function = Expr::Function(Function { + name: ObjectName(vec![Ident::new("PERCENTILE_CONT")]), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))], + over: None, + distinct: false, + special: false, + approximate: false, + }); + let within_group = vec![OrderByExpr { + expr: Expr::Identifier(Ident { + value: "name".to_string(), + quote_style: None, + }), + asc: Some(true), + nulls_first: Some(true), + }]; + assert_eq!( + &Expr::WithinGroup(WithinGroup { + expr: Box::new(function), + order_by: within_group + }), + expr_from_projection(only(&select.projection)) + ); +} + #[test] fn parse_create_table() { let sql = "CREATE TABLE uk_cities (\