From 1291abcded2ac43be4f5ee19593b66e7ab8242f3 Mon Sep 17 00:00:00 2001
From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com>
Date: Wed, 14 Aug 2024 19:49:11 +0400
Subject: [PATCH] feat: Support `PERCENTILE_CONT` planning

---
 Cargo.lock | 2 +-
 datafusion-cli/Cargo.lock | 2 +-
 datafusion/common/Cargo.toml | 2 +-
 datafusion/core/Cargo.toml | 2 +-
 .../core/src/logical_plan/expr_rewriter.rs | 18 +-
 .../core/src/logical_plan/expr_schema.rs | 15 +-
 .../core/src/logical_plan/expr_visitor.rs | 16 +-
 .../optimizer/single_distinct_to_groupby.rs | 14 +-
 datafusion/core/src/optimizer/utils.rs | 32 +++-
 .../core/src/physical_plan/aggregates.rs | 126 ++++++++-----
 datafusion/core/src/physical_plan/planner.rs | 63 ++++++-
 .../core/src/physical_plan/windows/mod.rs | 9 +-
 datafusion/core/src/sql/planner.rs | 60 +++++--
 datafusion/core/src/sql/utils.rs | 29 ++-
 datafusion/expr/Cargo.toml | 2 +-
 datafusion/expr/src/aggregate_function.rs | 19 ++
 datafusion/expr/src/expr.rs | 43 ++++-
 datafusion/expr/src/expr_fn.rs | 11 ++
 datafusion/expr/src/window_function.rs | 2 +-
 .../physical-expr/src/expressions/mod.rs | 2 +
 .../src/expressions/percentile_cont.rs | 169 ++++++++++++++++++
 datafusion/proto/proto/datafusion.proto | 1 +
 datafusion/proto/src/from_proto.rs | 4 +-
 datafusion/proto/src/lib.rs | 1 +
 datafusion/proto/src/to_proto.rs | 4 +
 25 files changed, 557 insertions(+), 91 deletions(-)
 create mode 100644 datafusion/physical-expr/src/expressions/percentile_cont.rs

diff --git a/Cargo.lock b/Cargo.lock
index 8527bb7925c7..dfe9b8876007 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2409,7 +2409,7 @@ dependencies = [
 [[package]]
 name = "sqlparser"
 version = "0.16.0"
-source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=6a54d27d3b75a04b9f9cbe309a83078aa54b32fd#6a54d27d3b75a04b9f9cbe309a83078aa54b32fd"
+source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=5fe1b77d1a91b80529a0b7af0b89411d3cba5137#5fe1b77d1a91b80529a0b7af0b89411d3cba5137"
 dependencies = [
  "log",
 ]
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 46dbea44665b..e5ceae75d1d1 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1531,7 +1531,7 @@ checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451"
 [[package]]
 name = "sqlparser"
 version = "0.16.0"
-source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=6a54d27d3b75a04b9f9cbe309a83078aa54b32fd#6a54d27d3b75a04b9f9cbe309a83078aa54b32fd"
+source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=4670853207c76610da3faeee72f0665c1a816f3b#4670853207c76610da3faeee72f0665c1a816f3b"
 dependencies = [
  "log",
 ]
diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml
index 8abf1953bfdf..96917f246bd5 100644
--- a/datafusion/common/Cargo.toml
+++ b/datafusion/common/Cargo.toml
@@ -44,4 +44,4 @@ cranelift-module = { version = "0.82.0", optional = true }
 ordered-float = "2.10"
 parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "a03d4eef5640e05dddf99fc2357ad6d58b5337cb", features = ["arrow"], optional = true }
 pyo3 = { version = "0.16", optional = true }
-sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
+sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" }
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index fc9f766778aa..4a36a18c8afb 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -79,7 +79,7 @@
pin-project-lite= "^0.2.7" pyo3 = { version = "0.16", optional = true } rand = "0.8" smallvec = { version = "1.6", features = ["union"] } -sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" } +sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/core/src/logical_plan/expr_rewriter.rs b/datafusion/core/src/logical_plan/expr_rewriter.rs index 0d16d9674642..e7801ec702e7 100644 --- a/datafusion/core/src/logical_plan/expr_rewriter.rs +++ b/datafusion/core/src/logical_plan/expr_rewriter.rs @@ -252,11 +252,19 @@ impl ExprRewritable for Expr { args, fun, distinct, - } => Expr::AggregateFunction { - args: rewrite_vec(args, rewriter)?, - fun, - distinct, - }, + within_group, + } => { + let within_group = match within_group { + Some(within_group) => Some(rewrite_vec(within_group, rewriter)?), + None => None, + }; + Expr::AggregateFunction { + args: rewrite_vec(args, rewriter)?, + fun, + distinct, + within_group, + } + } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?)) diff --git a/datafusion/core/src/logical_plan/expr_schema.rs b/datafusion/core/src/logical_plan/expr_schema.rs index f7b4778adf7b..47e33bf6c2f3 100644 --- a/datafusion/core/src/logical_plan/expr_schema.rs +++ b/datafusion/core/src/logical_plan/expr_schema.rs @@ -92,12 +92,23 @@ impl ExprSchemable for Expr { .collect::>>()?; window_function::return_type(fun, &data_types) } - Expr::AggregateFunction { fun, args, .. } => { + Expr::AggregateFunction { + fun, + args, + within_group, + .. + } => { let data_types = args .iter() .map(|e| e.get_type(schema)) .collect::>>()?; - aggregate_function::return_type(fun, &data_types) + let within_group = within_group + .as_ref() + .unwrap_or(&vec![]) + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + aggregate_function::return_type(fun, &data_types, &within_group) } Expr::AggregateUDF { fun, args, .. } => { let data_types = args diff --git a/datafusion/core/src/logical_plan/expr_visitor.rs b/datafusion/core/src/logical_plan/expr_visitor.rs index e0befe0ddbc2..175f60c7106c 100644 --- a/datafusion/core/src/logical_plan/expr_visitor.rs +++ b/datafusion/core/src/logical_plan/expr_visitor.rs @@ -179,10 +179,24 @@ impl ExprVisitable for Expr { Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } | Expr::TableUDF { args, .. } - | Expr::AggregateFunction { args, .. } | Expr::AggregateUDF { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::AggregateFunction { + args, within_group, .. + } => { + let visitor = args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = if let Some(within_group) = within_group.as_ref() { + within_group + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))? 
+ } else { + visitor + }; + Ok(visitor) + } Expr::WindowFunction { args, partition_by, diff --git a/datafusion/core/src/optimizer/single_distinct_to_groupby.rs b/datafusion/core/src/optimizer/single_distinct_to_groupby.rs index 2172614f31e2..af3518f7ad8f 100644 --- a/datafusion/core/src/optimizer/single_distinct_to_groupby.rs +++ b/datafusion/core/src/optimizer/single_distinct_to_groupby.rs @@ -80,6 +80,7 @@ fn optimize(plan: &LogicalPlan) -> Result { fun: fun.clone(), args: vec![col(SINGLE_DISTINCT_ALIAS)], distinct: false, + within_group: None, } } _ => agg_expr.clone(), @@ -168,13 +169,21 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> bool { .iter() .filter(|expr| { let mut is_distinct = false; - if let Expr::AggregateFunction { distinct, args, .. } = expr { + let mut is_within_group = false; + if let Expr::AggregateFunction { + distinct, + args, + within_group, + .. + } = expr + { is_distinct = *distinct; + is_within_group = within_group.is_some(); args.iter().for_each(|expr| { fields_set.insert(expr.name(input.schema()).unwrap()); }) } - is_distinct + is_distinct && !is_within_group }) .count() == aggr_expr.len() @@ -314,6 +323,7 @@ mod tests { fun: aggregates::AggregateFunction::Max, distinct: true, args: vec![col("b")], + within_group: None, }, ], )? diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 0e908dfde963..37f62925d350 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -339,8 +339,14 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::ScalarFunction { args, .. } | Expr::ScalarUDF { args, .. } | Expr::TableUDF { args, .. } - | Expr::AggregateFunction { args, .. } | Expr::AggregateUDF { args, .. } => Ok(args.clone()), + Expr::AggregateFunction { + args, within_group, .. + } => Ok(args + .iter() + .chain(within_group.as_ref().unwrap_or(&vec![])) + .cloned() + .collect()), Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(exprs.clone()), GroupingSet::Cube(exprs) => Ok(exprs.clone()), @@ -517,11 +523,25 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { }) } } - Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { - fun: fun.clone(), - args: expressions.to_vec(), - distinct: *distinct, - }), + Expr::AggregateFunction { + fun, + distinct, + args, + .. + } => { + let args_limit = args.len(); + let within_group = if expressions.len() > args_limit { + Some(expressions[args_limit..].to_vec()) + } else { + None + }; + Ok(Expr::AggregateFunction { + fun: fun.clone(), + args: expressions[..args_limit].to_vec(), + distinct: *distinct, + within_group, + }) + } Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF { fun: fun.clone(), args: expressions.to_vec(), diff --git a/datafusion/core/src/physical_plan/aggregates.rs b/datafusion/core/src/physical_plan/aggregates.rs index 07d85d31fa34..b65fc78adb2d 100644 --- a/datafusion/core/src/physical_plan/aggregates.rs +++ b/datafusion/core/src/physical_plan/aggregates.rs @@ -44,6 +44,7 @@ pub fn create_aggregate_expr( input_phy_exprs: &[Arc], input_schema: &Schema, name: impl Into, + within_group: Vec<(Arc, bool, bool)>, ) -> Result> { let name = name.into(); // get the coerced phy exprs if some expr need to be wrapped with the try cast. 
@@ -69,7 +70,11 @@ pub fn create_aggregate_expr( .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - let return_type = return_type(fun, &input_phy_types)?; + let within_group_types = within_group + .iter() + .map(|(e, _, _)| e.data_type(input_schema)) + .collect::>>()?; + let return_type = return_type(fun, &input_phy_types, &within_group_types)?; Ok(match (fun, distinct) { (AggregateFunction::Count, false) => Arc::new(expressions::Count::new( @@ -239,6 +244,15 @@ pub fn create_aggregate_expr( .to_string(), )); } + (AggregateFunction::PercentileCont, _) => { + Arc::new(expressions::PercentileCont::new( + // Pass in the desired percentile expr + name, + coerced_phy_exprs, + return_type, + within_group, + )?) + } (AggregateFunction::ApproxMedian, false) => { Arc::new(expressions::ApproxMedian::new( coerced_phy_exprs[0].clone(), @@ -301,6 +315,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; match fun { AggregateFunction::Count => { @@ -344,6 +359,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; match fun { AggregateFunction::Count => { @@ -402,6 +418,7 @@ mod tests { &input_phy_exprs[..], &input_schema, "c1", + vec![], ) .expect("failed to create aggregate expr"); @@ -431,6 +448,7 @@ mod tests { &input_phy_exprs[..], &input_schema, "c1", + vec![], ) .expect_err("should fail due to invalid percentile"); @@ -462,6 +480,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; match fun { AggregateFunction::Min => { @@ -511,6 +530,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; match fun { AggregateFunction::Sum => { @@ -573,6 +593,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; if fun == AggregateFunction::Variance { assert!(result_agg_phy_exprs.as_any().is::()); @@ -611,6 +632,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; if fun == AggregateFunction::Variance { assert!(result_agg_phy_exprs.as_any().is::()); @@ -649,6 +671,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; if fun == AggregateFunction::Variance { assert!(result_agg_phy_exprs.as_any().is::()); @@ -687,6 +710,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; if fun == AggregateFunction::Variance { assert!(result_agg_phy_exprs.as_any().is::()); @@ -734,6 +758,7 @@ mod tests { &input_phy_exprs[0..2], &input_schema, "c1", + vec![], )?; if fun == AggregateFunction::Covariance { assert!(result_agg_phy_exprs.as_any().is::()); @@ -781,6 +806,7 @@ mod tests { &input_phy_exprs[0..2], &input_schema, "c1", + vec![], )?; if fun == AggregateFunction::Covariance { assert!(result_agg_phy_exprs.as_any().is::()); @@ -828,6 +854,7 @@ mod tests { &input_phy_exprs[0..2], &input_schema, "c1", + vec![], )?; if fun == AggregateFunction::Covariance { assert!(result_agg_phy_exprs.as_any().is::()); @@ -866,6 +893,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; if fun == AggregateFunction::ApproxMedian { @@ -896,6 +924,7 @@ mod tests { &input_phy_exprs[0..1], &input_schema, "c1", + vec![], )?; match fun { AggregateFunction::BoolAnd => { @@ -922,15 +951,18 @@ mod tests { #[test] fn test_median() -> Result<()> { - let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Utf8]); + let observed = + return_type(&AggregateFunction::ApproxMedian, &[DataType::Utf8], &[]); assert!(observed.is_err()); - let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Int32])?; + let observed = + 
return_type(&AggregateFunction::ApproxMedian, &[DataType::Int32], &[])?; assert_eq!(DataType::Int32, observed); let observed = return_type( &AggregateFunction::ApproxMedian, &[DataType::Decimal(10, 6)], + &[], ); assert!(observed.is_err()); @@ -939,19 +971,20 @@ mod tests { #[test] fn test_min_max() -> Result<()> { - let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; + let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8], &[])?; assert_eq!(DataType::Utf8, observed); - let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?; + let observed = return_type(&AggregateFunction::Max, &[DataType::Int32], &[])?; assert_eq!(DataType::Int32, observed); // test decimal for min - let observed = return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)])?; + let observed = + return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)], &[])?; assert_eq!(DataType::Decimal(10, 6), observed); // test decimal for max let observed = - return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)])?; + return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)], &[])?; assert_eq!(DataType::Decimal(28, 13), observed); Ok(()) @@ -959,22 +992,24 @@ mod tests { #[test] fn test_sum_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Sum, &[DataType::Int32])?; + let observed = return_type(&AggregateFunction::Sum, &[DataType::Int32], &[])?; assert_eq!(DataType::Int64, observed); - let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt8])?; + let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt8], &[])?; assert_eq!(DataType::UInt64, observed); - let observed = return_type(&AggregateFunction::Sum, &[DataType::Float32])?; + let observed = return_type(&AggregateFunction::Sum, &[DataType::Float32], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Sum, &[DataType::Float64])?; + let observed = return_type(&AggregateFunction::Sum, &[DataType::Float64], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(10, 5)])?; + let observed = + return_type(&AggregateFunction::Sum, &[DataType::Decimal(10, 5)], &[])?; assert_eq!(DataType::Decimal(20, 5), observed); - let observed = return_type(&AggregateFunction::Sum, &[DataType::Decimal(35, 5)])?; + let observed = + return_type(&AggregateFunction::Sum, &[DataType::Decimal(35, 5)], &[])?; assert_eq!(DataType::Decimal(38, 5), observed); Ok(()) @@ -982,71 +1017,78 @@ mod tests { #[test] fn test_sum_no_utf8() { - let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8]); + let observed = return_type(&AggregateFunction::Sum, &[DataType::Utf8], &[]); assert!(observed.is_err()); } #[test] fn test_sum_upcasts() -> Result<()> { - let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt32])?; + let observed = return_type(&AggregateFunction::Sum, &[DataType::UInt32], &[])?; assert_eq!(DataType::UInt64, observed); Ok(()) } #[test] fn test_count_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?; + let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8], &[])?; assert_eq!(DataType::Int64, observed); - let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?; + let observed = return_type(&AggregateFunction::Count, &[DataType::Int8], &[])?; assert_eq!(DataType::Int64, observed); let observed = - return_type(&AggregateFunction::Count, 
&[DataType::Decimal(28, 13)])?; + return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)], &[])?; assert_eq!(DataType::Int64, observed); Ok(()) } #[test] fn test_avg_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Avg, &[DataType::Float32])?; + let observed = return_type(&AggregateFunction::Avg, &[DataType::Float32], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64])?; + let observed = return_type(&AggregateFunction::Avg, &[DataType::Float64], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32])?; + let observed = return_type(&AggregateFunction::Avg, &[DataType::Int32], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(10, 6)])?; + let observed = + return_type(&AggregateFunction::Avg, &[DataType::Decimal(10, 6)], &[])?; assert_eq!(DataType::Decimal(14, 10), observed); - let observed = return_type(&AggregateFunction::Avg, &[DataType::Decimal(36, 6)])?; + let observed = + return_type(&AggregateFunction::Avg, &[DataType::Decimal(36, 6)], &[])?; assert_eq!(DataType::Decimal(38, 10), observed); Ok(()) } #[test] fn test_avg_no_utf8() { - let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]); + let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8], &[]); assert!(observed.is_err()); } #[test] fn test_variance_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Variance, &[DataType::Float32])?; + let observed = + return_type(&AggregateFunction::Variance, &[DataType::Float32], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Variance, &[DataType::Float64])?; + let observed = + return_type(&AggregateFunction::Variance, &[DataType::Float64], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Variance, &[DataType::Int32])?; + let observed = + return_type(&AggregateFunction::Variance, &[DataType::Int32], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Variance, &[DataType::UInt32])?; + let observed = + return_type(&AggregateFunction::Variance, &[DataType::UInt32], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Variance, &[DataType::Int64])?; + let observed = + return_type(&AggregateFunction::Variance, &[DataType::Int64], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -1054,25 +1096,27 @@ mod tests { #[test] fn test_variance_no_utf8() { - let observed = return_type(&AggregateFunction::Variance, &[DataType::Utf8]); + let observed = return_type(&AggregateFunction::Variance, &[DataType::Utf8], &[]); assert!(observed.is_err()); } #[test] fn test_stddev_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float32])?; + let observed = + return_type(&AggregateFunction::Stddev, &[DataType::Float32], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Float64])?; + let observed = + return_type(&AggregateFunction::Stddev, &[DataType::Float64], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int32])?; + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int32], &[])?; assert_eq!(DataType::Float64, observed); - 
let observed = return_type(&AggregateFunction::Stddev, &[DataType::UInt32])?; + let observed = return_type(&AggregateFunction::Stddev, &[DataType::UInt32], &[])?; assert_eq!(DataType::Float64, observed); - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int64])?; + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Int64], &[])?; assert_eq!(DataType::Float64, observed); Ok(()) @@ -1080,13 +1124,14 @@ mod tests { #[test] fn test_stddev_no_utf8() { - let observed = return_type(&AggregateFunction::Stddev, &[DataType::Utf8]); + let observed = return_type(&AggregateFunction::Stddev, &[DataType::Utf8], &[]); assert!(observed.is_err()); } #[test] fn test_bool_and_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::BoolAnd, &[DataType::Boolean])?; + let observed = + return_type(&AggregateFunction::BoolAnd, &[DataType::Boolean], &[])?; assert_eq!(DataType::Boolean, observed); Ok(()) @@ -1094,13 +1139,14 @@ mod tests { #[test] fn test_bool_and_no_utf8() { - let observed = return_type(&AggregateFunction::BoolAnd, &[DataType::Utf8]); + let observed = return_type(&AggregateFunction::BoolAnd, &[DataType::Utf8], &[]); assert!(observed.is_err()); } #[test] fn test_bool_or_return_type() -> Result<()> { - let observed = return_type(&AggregateFunction::BoolOr, &[DataType::Boolean])?; + let observed = + return_type(&AggregateFunction::BoolOr, &[DataType::Boolean], &[])?; assert_eq!(DataType::Boolean, observed); Ok(()) @@ -1108,7 +1154,7 @@ mod tests { #[test] fn test_bool_or_no_utf8() { - let observed = return_type(&AggregateFunction::BoolOr, &[DataType::Utf8]); + let observed = return_type(&AggregateFunction::BoolOr, &[DataType::Utf8], &[]); assert!(observed.is_err()); } } diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 6cd1cf91dccf..d7cddea14999 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -93,6 +93,26 @@ fn create_function_physical_name( Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) } +fn create_aggregate_function_physical_name( + fun: &str, + distinct: bool, + args: &[Expr], + within_group: Option<&[Expr]>, +) -> Result { + let function_name = create_function_physical_name(fun, distinct, args)?; + let within_group = match within_group { + Some(within_group) => { + let names = within_group + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()?; + format!(" WITHIN GROUP (ORDER BY {})", names.join(", ")) + } + None => "".to_string(), + }; + Ok(format!("{}{}", function_name, within_group)) +} + fn physical_name(e: &Expr) -> Result { create_physical_name(e, true) } @@ -189,8 +209,14 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { fun, distinct, args, + within_group, .. - } => create_function_physical_name(&fun.to_string(), *distinct, args), + } => create_aggregate_function_physical_name( + &fun.to_string(), + *distinct, + args, + within_group.as_ref().map(|exprs| exprs.as_slice()), + ), Expr::AggregateUDF { fun, args } => { let mut names = Vec::with_capacity(args.len()); for e in args { @@ -324,9 +350,16 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { Ok(format!("{} SIMILAR TO {}{}", expr, pattern, escape)) } } - Expr::Sort { .. 
} => Err(DataFusionError::Internal( - "Create physical name does not support sort expression".to_string(), - )), + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let expr = create_physical_name(expr, false)?; + let asc = if *asc { "ASC" } else { "DESC" }; + let first = if *nulls_first { "FIRST" } else { "LAST" }; + Ok(format!("{} {} NULLS {}", expr, asc, first)) + } Expr::Wildcard => Err(DataFusionError::Internal( "Create physical name does not support wildcard".to_string(), )), @@ -1562,6 +1595,7 @@ pub fn create_aggregate_expr_with_name( fun, distinct, args, + within_group, .. } => { let args = args @@ -1575,12 +1609,33 @@ pub fn create_aggregate_expr_with_name( ) }) .collect::>>()?; + let within_group = within_group + .as_ref() + .unwrap_or(&vec![]) + .iter() + .map(|wg| { + let Expr::Sort { expr, asc, nulls_first } = wg else { + return Err(DataFusionError::Internal(format!( + "Non-Sort expression encountered in ORDER BY: {}", + wg + ))); + }; + let expr = create_physical_expr( + expr, + logical_input_schema, + physical_input_schema, + execution_props, + )?; + Ok((expr, *asc, *nulls_first)) + }) + .collect::>>()?; aggregates::create_aggregate_expr( fun, *distinct, &args, physical_input_schema, name, + within_group, ) } Expr::AggregateUDF { fun, args, .. } => { diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 9b69db7e6688..4f3d98661cba 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -57,7 +57,14 @@ pub fn create_window_expr( ) -> Result> { Ok(match fun { WindowFunction::AggregateFunction(fun) => Arc::new(AggregateWindowExpr::new( - aggregates::create_aggregate_expr(fun, false, args, input_schema, name)?, + aggregates::create_aggregate_expr( + fun, + false, + args, + input_schema, + name, + vec![], + )?, partition_by, order_by, window_frame, diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 1c0402f631c0..d30d309b1cc3 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -58,7 +58,7 @@ use sqlparser::ast::{ Fetch, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, ObjectName, Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator, ShowStatementFilter, TableFactor, TableWithJoins, TrimWhereField, UnaryOperator, - Value, Values as SQLValues, + Value, Values as SQLValues, WithinGroup, }; use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; use sqlparser::ast::{ObjectType, OrderByExpr, Statement}; @@ -1437,14 +1437,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by_rex = order_by .into_iter() - .map(|e| self.order_by_to_sort_expr(e, plan.schema())) + .map(|e| self.order_by_to_sort_expr(e, plan.schema(), true)) .collect::>>()?; LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() } /// convert sql OrderByExpr to Expr::Sort - fn order_by_to_sort_expr(&self, e: OrderByExpr, schema: &DFSchema) -> Result { + fn order_by_to_sort_expr( + &self, + e: OrderByExpr, + schema: &DFSchema, + parse_indexes: bool, + ) -> Result { let OrderByExpr { asc, expr, @@ -1452,7 +1457,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = e; let expr = match expr { - SQLExpr::Value(Value::Number(v, _)) => { + SQLExpr::Value(Value::Number(v, _)) if parse_indexes => { let field_index = v .parse::() .map_err(|err| DataFusionError::Plan(err.to_string()))?; @@ -2310,7 +2315,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let 
order_by = window .order_by .into_iter() - .map(|e| self.order_by_to_sort_expr(e, schema)) + .map(|e| self.order_by_to_sort_expr(e, schema, true)) .collect::>>()?; let window_frame = window .window_frame @@ -2372,6 +2377,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun, distinct, args, + within_group: None, }); }; @@ -2438,6 +2444,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::ArrayAgg(array_agg) => self.parse_array_agg(array_agg, schema), + SQLExpr::WithinGroup(within_group) => self.parse_within_group(within_group, schema), + _ => Err(DataFusionError::NotImplemented(format!( "Unsupported ast node {:?} in sqltorel", sql @@ -2455,7 +2463,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { distinct, expr, limit, - within_group, .. } = array_agg; @@ -2474,12 +2481,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))); } - if within_group { - return Err(DataFusionError::NotImplemented( - "WITHIN GROUP not supported in ARRAY_AGG".to_string(), - )); - } - let args = vec![self.sql_expr_to_logical_expr(*expr, input_schema)?]; // next, aggregate built-ins let fun = aggregates::AggregateFunction::ArrayAgg; @@ -2488,9 +2489,35 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun, distinct, args, + within_group: None, }) } + fn parse_within_group( + &self, + within_group: WithinGroup, + input_schema: &DFSchema, + ) -> Result { + let mut expr = self.sql_expr_to_logical_expr(*within_group.expr, input_schema)?; + if let Expr::AggregateFunction { + within_group: agg_within_group, + .. + } = &mut expr + { + let order_by = within_group + .order_by + .into_iter() + .map(|e| self.order_by_to_sort_expr(e, input_schema, false)) + .collect::>>()?; + *agg_within_group = Some(order_by); + return Ok(expr); + } + Err(DataFusionError::NotImplemented( + "WITHIN GROUP is only supported with built-in aggregate functions" + .to_string(), + )) + } + fn function_args_to_expr( &self, args: Vec, @@ -4130,6 +4157,15 @@ mod tests { quick_test(sql, expected); } + #[test] + fn select_percentile_cont() { + let sql = "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY age) FROM person"; + let expected = "Projection: #PERCENTILECONT(Float64(0.5)) WITHIN GROUP (ORDER BY [#person.age ASC NULLS LAST])\ + \n Aggregate: groupBy=[[]], aggr=[[PERCENTILECONT(Float64(0.5)) WITHIN GROUP (ORDER BY #person.age ASC NULLS LAST)]]\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; diff --git a/datafusion/core/src/sql/utils.rs b/datafusion/core/src/sql/utils.rs index 7f6c86989462..651869d38dd7 100644 --- a/datafusion/core/src/sql/utils.rs +++ b/datafusion/core/src/sql/utils.rs @@ -293,14 +293,27 @@ where fun, args, distinct, - } => Ok(Expr::AggregateFunction { - fun: fun.clone(), - args: args - .iter() - .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, - distinct: *distinct, - }), + within_group, + } => { + let within_group = match within_group { + Some(within_group) => Some( + within_group + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + ), + None => None, + }; + Ok(Expr::AggregateFunction { + fun: fun.clone(), + args: args + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, + distinct: *distinct, + within_group, + }) + } Expr::WindowFunction { fun, args, diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index bceb80c09085..f9dccf30e2d4 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,4 
+38,4 @@ path = "src/lib.rs"
 ahash = { version = "0.7", default-features = false }
 arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "a03d4eef5640e05dddf99fc2357ad6d58b5337cb", features = ["prettyprint"] }
 datafusion-common = { path = "../common", version = "7.0.0" }
-sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
+sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" }
diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs
index f81efe5c35e3..16f495420d82 100644
--- a/datafusion/expr/src/aggregate_function.rs
+++ b/datafusion/expr/src/aggregate_function.rs
@@ -84,6 +84,8 @@ pub enum AggregateFunction {
     ApproxPercentileCont,
     /// Approximate continuous percentile function with weight
     ApproxPercentileContWithWeight,
+    /// Continuous percentile function
+    PercentileCont,
     /// ApproxMedian
     ApproxMedian,
     /// BoolAnd
@@ -124,6 +126,7 @@ impl FromStr for AggregateFunction {
             "approx_percentile_cont_with_weight" => {
                 AggregateFunction::ApproxPercentileContWithWeight
             }
+            "percentile_cont" => AggregateFunction::PercentileCont,
             "approx_median" => AggregateFunction::ApproxMedian,
             "bool_and" => AggregateFunction::BoolAnd,
             "bool_or" => AggregateFunction::BoolOr,
@@ -142,6 +145,7 @@
 pub fn return_type(
     fun: &AggregateFunction,
     input_expr_types: &[DataType],
+    within_group_types: &[DataType],
 ) -> Result<DataType> {
     // Note that this function *must* return the same type that the respective physical expression returns
     // or the execution panics.
@@ -178,6 +182,7 @@ pub fn return_type(
         AggregateFunction::ApproxPercentileContWithWeight => {
             Ok(coerced_data_types[0].clone())
         }
+        AggregateFunction::PercentileCont => Ok(within_group_types[0].clone()),
         AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
         AggregateFunction::BoolAnd | AggregateFunction::BoolOr => Ok(DataType::Boolean),
     }
@@ -324,6 +329,15 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
+        AggregateFunction::PercentileCont => {
+            if !matches!(input_types[0], DataType::Float64) {
+                return Err(DataFusionError::Plan(format!(
+                    "The percentile argument for {:?} must be Float64, not {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
         AggregateFunction::ApproxMedian => {
             if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
                 return Err(DataFusionError::Plan(format!(
@@ -395,6 +409,11 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
                 .collect(),
             Volatility::Immutable,
         ),
+        AggregateFunction::PercentileCont => Signature::one_of(
+            // Accept a single Float64 percentile; the ordered value comes from WITHIN GROUP (ORDER BY ...)
+            vec![TypeSignature::Exact(vec![DataType::Float64])],
+            Volatility::Immutable,
+        ),
         AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
             Signature::exact(vec![DataType::Boolean], Volatility::Immutable)
         }
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index d380d9360a8d..5ce7300bbb70 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -223,6 +223,8 @@ pub enum Expr {
         args: Vec<Expr>,
         /// Whether this is a DISTINCT aggregation or not
         distinct: bool,
+        /// WITHIN GROUP (ORDER BY ...) expression
+        within_group: Option<Vec<Expr>>,
     },
     /// Represents the call of a window function with arguments.
WindowFunction { @@ -494,7 +496,19 @@ impl std::fmt::Display for Expr { ref args, /// Whether this is a DISTINCT aggregation or not ref distinct, - } => fmt_function(f, &fun.to_string(), *distinct, args, true), + /// Aggregate function's WITHIN GROUP expression + ref within_group, + } => { + fmt_function(f, &fun.to_string(), *distinct, args, true)?; + if let Some(within_group) = within_group.as_ref() { + let exprs = within_group + .iter() + .map(|arg| format!("{}", arg)) + .collect::>(); + write!(f, " WITHIN GROUP (ORDER BY {})", exprs.join(", "))?; + } + Ok(()) + } Expr::ScalarFunction { /// Name of the function ref fun, @@ -608,8 +622,19 @@ impl fmt::Debug for Expr { fun, distinct, ref args, + within_group, .. - } => fmt_function(f, &fun.to_string(), *distinct, args, true), + } => { + fmt_function(f, &fun.to_string(), *distinct, args, false)?; + if let Some(within_group) = within_group.as_ref() { + let exprs = within_group + .iter() + .map(|arg| format!("{:?}", arg)) + .collect::>(); + write!(f, " WITHIN GROUP (ORDER BY {})", exprs.join(", "))?; + } + Ok(()) + } Expr::AggregateUDF { fun, ref args, .. } => { fmt_function(f, &fun.name, false, args, false) } @@ -956,8 +981,20 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { fun, distinct, args, + within_group, .. - } => create_function_name(&fun.to_string(), *distinct, args, input_schema), + } => { + let mut parts = vec![create_function_name( + &fun.to_string(), + *distinct, + args, + input_schema, + )?]; + if let Some(within_group) = within_group.as_ref() { + parts.push(format!("WITHIN GROUP (ORDER BY {:?})", within_group)); + } + Ok(parts.join(" ")) + } Expr::AggregateUDF { fun, args } => { let mut names = Vec::with_capacity(args.len()); for e in args { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8d6f0d08e65a..2b55b59dd932 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -57,6 +57,7 @@ pub fn min(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Min, distinct: false, args: vec![expr], + within_group: None, } } @@ -66,6 +67,7 @@ pub fn max(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Max, distinct: false, args: vec![expr], + within_group: None, } } @@ -75,6 +77,7 @@ pub fn sum(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Sum, distinct: false, args: vec![expr], + within_group: None, } } @@ -84,6 +87,7 @@ pub fn avg(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Avg, distinct: false, args: vec![expr], + within_group: None, } } @@ -93,6 +97,7 @@ pub fn count(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Count, distinct: false, args: vec![expr], + within_group: None, } } @@ -102,6 +107,7 @@ pub fn count_distinct(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::Count, distinct: true, args: vec![expr], + within_group: None, } } @@ -162,6 +168,7 @@ pub fn approx_distinct(expr: Expr) -> Expr { fun: aggregate_function::AggregateFunction::ApproxDistinct, distinct: false, args: vec![expr], + within_group: None, } } @@ -171,6 +178,7 @@ pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { fun: aggregate_function::AggregateFunction::ApproxPercentileCont, distinct: false, args: vec![expr, percentile], + within_group: None, } } @@ -184,6 +192,7 @@ pub fn approx_percentile_cont_with_weight( fun: aggregate_function::AggregateFunction::ApproxPercentileContWithWeight, distinct: false, args: vec![expr, weight_expr, percentile], + within_group: 
None,
     }
 }
 
@@ -193,6 +202,7 @@ pub fn bool_and(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::BoolAnd,
         distinct: false,
         args: vec![expr],
+        within_group: None,
     }
 }
 
@@ -202,6 +212,7 @@ pub fn bool_or(expr: Expr) -> Expr {
         fun: aggregate_function::AggregateFunction::BoolOr,
         distinct: false,
         args: vec![expr],
+        within_group: None,
     }
 }
 
diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs
index 726ca5c726cd..05c2435a748e 100644
--- a/datafusion/expr/src/window_function.rs
+++ b/datafusion/expr/src/window_function.rs
@@ -148,7 +148,7 @@ pub fn return_type(
 ) -> Result<DataType> {
     match fun {
         WindowFunction::AggregateFunction(fun) => {
-            aggregate_function::return_type(fun, input_expr_types)
+            aggregate_function::return_type(fun, input_expr_types, &[])
         }
         WindowFunction::BuiltInWindowFunction(fun) => {
             return_type_for_built_in(fun, input_expr_types)
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index d55814deae86..c92eb5d88fe9 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -49,6 +49,7 @@ mod not;
 mod nth_value;
 mod nullif;
 mod outer_column;
+mod percentile_cont;
 mod rank;
 mod row_number;
 mod stats;
@@ -95,6 +96,7 @@ pub use not::{not, NotExpr};
 pub use nth_value::NthValue;
 pub use nullif::nullif_func;
 pub use outer_column::OuterColumn;
+pub use percentile_cont::PercentileCont;
 pub use rank::{dense_rank, percent_rank, rank};
 pub use row_number::RowNumber;
 pub use stats::StatsType;
diff --git a/datafusion/physical-expr/src/expressions/percentile_cont.rs b/datafusion/physical-expr/src/expressions/percentile_cont.rs
new file mode 100644
index 000000000000..2d96be5c16e2
--- /dev/null
+++ b/datafusion/physical-expr/src/expressions/percentile_cont.rs
@@ -0,0 +1,169 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::{format_state_name, Literal};
+
+use crate::{AggregateExpr, PhysicalExpr};
+use arrow::datatypes::{DataType, Field};
+use datafusion_common::DataFusionError;
+use datafusion_common::Result;
+use datafusion_common::ScalarValue;
+use datafusion_expr::Accumulator;
+
+use std::{any::Any, sync::Arc};
+
+/// PERCENTILE_CONT aggregate expression
+#[derive(Debug)]
+pub struct PercentileCont {
+    name: String,
+    input_data_type: DataType,
+    _percentile: f64,
+    expr: Arc<dyn PhysicalExpr>,
+    _asc: bool,
+    _nulls_first: bool,
+}
+
+impl PercentileCont {
+    /// Create a new [`PercentileCont`] aggregate function.
+    pub fn new(
+        name: impl Into<String>,
+        expr: Vec<Arc<dyn PhysicalExpr>>,
+        input_data_type: DataType,
+        within_group: Vec<(Arc<dyn PhysicalExpr>, bool, bool)>,
+    ) -> Result<Self> {
+        // Arguments should be [DesiredPercentileLiteral]
+        debug_assert_eq!(expr.len(), 1);
+
+        // Extract the desired percentile literal
+        let lit = expr[0]
+            .as_any()
+            .downcast_ref::<Literal>()
+            .ok_or_else(|| {
+                DataFusionError::Internal(
+                    "desired percentile argument must be float literal".to_string(),
+                )
+            })?
+            .value();
+        let percentile = match lit {
+            ScalarValue::Float32(Some(q)) => *q as f64,
+            ScalarValue::Float64(Some(q)) => *q as f64,
+            got => return Err(DataFusionError::NotImplemented(format!(
+                "Percentile value for 'PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
+                got
+            )))
+        };
+
+        // Ensure the percentile is between 0 and 1.
+        if !(0.0..=1.0).contains(&percentile) {
+            return Err(DataFusionError::Plan(format!(
+                "Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid",
+                percentile
+            )));
+        }
+
+        // Ensure that WITHIN GROUP contains exactly one value
+        if within_group.len() != 1 {
+            return Err(DataFusionError::Plan(
+                "PERCENTILE_CONT ... WITHIN GROUP must have exactly one expression in ORDER BY".to_string(),
+            ));
+        }
+        let (order_by, asc, nulls_first) = &within_group[0];
+
+        // ORDER BY type must be Float64 or one of the Interval types
+        match input_data_type {
+            DataType::Float64 | DataType::Interval(_) => (),
+            typ => {
+                return Err(DataFusionError::Plan(format!(
+                    "WITHIN GROUP (ORDER BY ...) must be Float64 or Interval, got {}",
+                    typ
+                )))
+            }
+        }
+
+        Ok(Self {
+            name: name.into(),
+            input_data_type,
+            _percentile: percentile,
+            // The physical expr to evaluate during accumulation
+            expr: order_by.clone(),
+            _asc: *asc,
+            _nulls_first: *nulls_first,
+        })
+    }
+}
+
+impl AggregateExpr for PercentileCont {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn field(&self) -> Result<Field> {
+        Ok(Field::new(&self.name, self.input_data_type.clone(), false))
+    }
+
+    #[allow(rustdoc::private_intra_doc_links)]
+    /// See [`TDigest::to_scalar_state()`] for a description of the serialised
+    /// state.
+    fn state_fields(&self) -> Result<Vec<Field>> {
+        Ok(vec![
+            Field::new(
+                &format_state_name(&self.name, "max_size"),
+                DataType::UInt64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "sum"),
+                DataType::Float64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "count"),
+                DataType::Float64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "max"),
+                DataType::Float64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "min"),
+                DataType::Float64,
+                false,
+            ),
+            Field::new(
+                &format_state_name(&self.name, "centroids"),
+                DataType::List(Box::new(Field::new("item", DataType::Float64, true))),
+                false,
+            ),
+        ])
+    }
+
+    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
+        vec![self.expr.clone()]
+    }
+
+    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
+        Err(DataFusionError::NotImplemented(
+            "percentile_cont(...) execution is not implemented".to_string(),
+        ))
+    }
+
+    fn name(&self) -> &str {
+        &self.name
+    }
+}
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index d159e1511f87..a862f76fce5e 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -224,6 +224,7 @@ enum AggregateFunction {
   // Cubesql
   BOOL_AND = 17;
   BOOL_OR = 18;
+  PERCENTILE_CONT = 19;
 }
 
 message AggregateExprNode {
diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs
index 893e011f9c3b..70417060366e 100644
--- a/datafusion/proto/src/from_proto.rs
+++ b/datafusion/proto/src/from_proto.rs
@@ -462,6 +462,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
             protobuf::AggregateFunction::ApproxPercentileContWithWeight => {
                 Self::ApproxPercentileContWithWeight
             }
+            protobuf::AggregateFunction::PercentileCont => Self::PercentileCont,
             protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian,
             protobuf::AggregateFunction::BoolAnd => Self::BoolAnd,
             protobuf::AggregateFunction::BoolOr => Self::BoolOr,
@@ -975,7 +976,8 @@ pub fn parse_expr(
                 .iter()
                 .map(|e| parse_expr(e, registry))
                 .collect::<Result<Vec<_>, _>>()?,
-            distinct: false, // TODO
+            distinct: false,     // TODO
+            within_group: None,  // TODO
             })
         }
         ExprType::Alias(alias) => Ok(Expr::Alias(
diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs
index 25e522c5cb6f..8f511354d1f1 100644
--- a/datafusion/proto/src/lib.rs
+++ b/datafusion/proto/src/lib.rs
@@ -760,6 +760,7 @@ mod roundtrip_tests {
             fun: aggregates::AggregateFunction::ApproxPercentileCont,
             args: vec![col("bananas"), lit(0.42_f32)],
             distinct: false,
+            within_group: None,
         };
 
         let ctx = SessionContext::new();
diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs
index 90bc1b3e050d..fceac6075a90 100644
--- a/datafusion/proto/src/to_proto.rs
+++ b/datafusion/proto/src/to_proto.rs
@@ -312,6 +312,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
             AggregateFunction::ApproxPercentileContWithWeight => {
                 Self::ApproxPercentileContWithWeight
             }
+            AggregateFunction::PercentileCont => Self::PercentileCont,
             AggregateFunction::ApproxMedian => Self::ApproxMedian,
             AggregateFunction::BoolAnd => Self::BoolAnd,
             AggregateFunction::BoolOr => Self::BoolOr,
@@ -522,6 +523,9 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode {
                 AggregateFunction::ApproxPercentileContWithWeight => {
                     protobuf::AggregateFunction::ApproxPercentileContWithWeight
                 }
+                AggregateFunction::PercentileCont => {
+                    protobuf::AggregateFunction::PercentileCont
+                }
                 AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
                 AggregateFunction::Min => protobuf::AggregateFunction::Min,
                 AggregateFunction::Max => protobuf::AggregateFunction::Max,