From 033d9a6b4d03f751035acd9435082c0d73a18606 Mon Sep 17 00:00:00 2001 From: Sean Smith Date: Sat, 27 Jan 2024 13:33:40 -0600 Subject: [PATCH 1/2] Properly encode STRING_AGG in physical plan protobufs --- .../proto/src/physical_plan/to_proto.rs | 10 +++- .../tests/cases/roundtrip_physical_plan.rs | 55 ++++++++++++++----- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index f4e3f9e4dca7..cff32ca2f8c9 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -49,9 +49,9 @@ use datafusion::physical_plan::expressions::{ CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, LastValue, LikeExpr, Literal, Max, Median, - Min, NegativeExpr, NotExpr, NthValue, Ntile, OrderSensitiveArrayAgg, Rank, RankType, - Regr, RegrType, RowNumber, Stddev, StddevPop, Sum, TryCastExpr, Variance, - VariancePop, WindowShift, + Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, + Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, + TryCastExpr, Variance, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -363,6 +363,10 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::FirstValueAgg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::LastValueAgg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::StringAgg + } else if aggr_expr.downcast_ref::().is_some() { + protobuf::AggregateFunction::NthValueAgg } else { return not_impl_err!("Aggregate function not supported: {expr:?}"); }; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 9a95e103c294..29887ae40b2e 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -44,7 +44,8 @@ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, - GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, Sum, + GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, + StringAgg, Sum, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::FileSinkExec; @@ -328,20 +329,46 @@ fn rountrip_aggregate() -> Result<()> { let groups: Vec<(Arc, String)> = vec![(col("a", &schema)?, "unused".to_string())]; - let aggregates: Vec> = vec![Arc::new(Avg::new( - cast(col("b", &schema)?, &schema, DataType::Float64)?, - "AVG(b)".to_string(), - DataType::Float64, - ))]; + let test_cases: Vec>> = vec![ + // AVG + vec![Arc::new(Avg::new( + cast(col("b", &schema)?, &schema, DataType::Float64)?, + "AVG(b)".to_string(), + DataType::Float64, + ))], + // TODO: + // // NTH_VALUE + // vec![Arc::new(NthValueAgg::new( + // col("b", &schema)?, + // 1, + // "NTH_VALUE(b, 1)".to_string(), + // DataType::Int64, + // false, + // Vec::new(), + // Vec::new(), + // ))], + // STRING_AGG + vec![Arc::new(StringAgg::new( + cast(col("b", &schema)?, &schema, DataType::Utf8)?, + lit(ScalarValue::Utf8(Some(",".to_string()))), + "STRING_AGG(name, ',')".to_string(), + DataType::Utf8, + ))], + ]; - roundtrip_test(Arc::new(AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::new_single(groups.clone()), - aggregates.clone(), - vec![None], - Arc::new(EmptyExec::new(schema.clone())), - schema, - )?)) + for aggregates in test_cases { + let schema = schema.clone(); + roundtrip_test(Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates, + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?))?; + } + + Ok(()) } #[test] From 53bc2df4b0c3614e84611ade1917bbcd7d8a3fb9 Mon Sep 17 00:00:00 2001 From: Sean Smith Date: Sat, 27 Jan 2024 13:48:51 -0600 Subject: [PATCH 2/2] reference issue for nth_value --- datafusion/proto/tests/cases/roundtrip_physical_plan.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 29887ae40b2e..38eb39000317 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -336,7 +336,7 @@ fn rountrip_aggregate() -> Result<()> { "AVG(b)".to_string(), DataType::Float64, ))], - // TODO: + // TODO: See // // NTH_VALUE // vec![Arc::new(NthValueAgg::new( // col("b", &schema)?,