From a873f5156364f4357592c4bc9117887916e606f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Tue, 18 Jun 2024 22:08:13 +0800 Subject: [PATCH] Convert `StringAgg` to UDAF (#10945) * Convert StringAgg to UDAF * generate proto code * Fix bug * Fix * Add license * Add doc * Fix clippy * Remove aliases field * Add StringAgg proto test * Add roundtrip_expr_api test --- datafusion/expr/src/aggregate_function.rs | 8 - .../expr/src/type_coercion/aggregates.rs | 26 -- datafusion/functions-aggregate/src/lib.rs | 2 + .../functions-aggregate/src/string_agg.rs | 153 +++++++++++ .../physical-expr/src/aggregate/build_in.rs | 16 -- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/aggregate/string_agg.rs | 246 ------------------ .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 4 - .../proto/src/physical_plan/to_proto.rs | 5 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + .../tests/cases/roundtrip_physical_plan.rs | 23 +- .../sqllogictest/test_files/aggregate.slt | 16 ++ 17 files changed, 192 insertions(+), 321 deletions(-) create mode 100644 datafusion/functions-aggregate/src/string_agg.rs delete mode 100644 datafusion/physical-expr/src/aggregate/string_agg.rs diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index a7fbf26febb1..1cde1c5050a8 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -51,8 +51,6 @@ pub enum AggregateFunction { BoolAnd, /// Bool Or BoolOr, - /// String aggregation - StringAgg, } impl AggregateFunction { @@ -68,7 +66,6 @@ impl AggregateFunction { Grouping => "GROUPING", BoolAnd => "BOOL_AND", BoolOr => "BOOL_OR", - StringAgg => "STRING_AGG", } } } @@ -92,7 +89,6 @@ impl FromStr for AggregateFunction { "min" => AggregateFunction::Min, "array_agg" => AggregateFunction::ArrayAgg, "nth_value" => AggregateFunction::NthValue, - "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, // other @@ -146,7 +142,6 @@ impl AggregateFunction { )))), AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), - AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), } } } @@ -195,9 +190,6 @@ impl AggregateFunction { AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } - AggregateFunction::StringAgg => { - Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable) - } } } } diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index a216c98899fe..abe6d8b1823d 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -145,23 +145,6 @@ pub fn coerce_types( } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), - AggregateFunction::StringAgg => { - if !is_string_agg_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}", - agg_fun, - input_types[0] - ); - } - if !is_string_agg_supported_arg_type(&input_types[1]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}", - agg_fun, - input_types[1] - ); - } - Ok(vec![LargeUtf8, input_types[1].clone()]) - } } } @@ -391,15 +374,6 @@ pub fn is_integer_arg_type(arg_type: &DataType) -> bool { arg_type.is_integer() } -/// Return `true` if `arg_type` is of a [`DataType`] that the -/// [`AggregateFunction::StringAgg`] aggregation can operate on. -pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - DataType::Utf8 | DataType::LargeUtf8 | DataType::Null - ) -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 990303bd1de3..20a8d2c15926 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -70,6 +70,7 @@ pub mod approx_median; pub mod approx_percentile_cont; pub mod approx_percentile_cont_with_weight; pub mod bit_and_or_xor; +pub mod string_agg; use crate::approx_percentile_cont::approx_percentile_cont_udaf; use crate::approx_percentile_cont_with_weight::approx_percentile_cont_with_weight_udaf; @@ -138,6 +139,7 @@ pub fn all_default_aggregate_functions() -> Vec> { approx_distinct::approx_distinct_udaf(), approx_percentile_cont_udaf(), approx_percentile_cont_with_weight_udaf(), + string_agg::string_agg_udaf(), bit_and_or_xor::bit_and_udaf(), bit_and_or_xor::bit_or_udaf(), bit_and_or_xor::bit_xor_udaf(), diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs new file mode 100644 index 000000000000..371cc8fb9739 --- /dev/null +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function + +use arrow::array::ArrayRef; +use arrow_schema::DataType; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::Result; +use datafusion_common::{not_impl_err, ScalarValue}; +use datafusion_expr::function::AccumulatorArgs; +use datafusion_expr::{ + Accumulator, AggregateUDFImpl, Expr, Signature, TypeSignature, Volatility, +}; +use std::any::Any; + +make_udaf_expr_and_func!( + StringAgg, + string_agg, + expr delimiter, + "Concatenates the values of string expressions and places separator values between them", + string_agg_udaf +); + +/// STRING_AGG aggregate expression +#[derive(Debug)] +pub struct StringAgg { + signature: Signature, +} + +impl StringAgg { + /// Create a new StringAgg aggregate function + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Utf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::LargeUtf8]), + TypeSignature::Exact(vec![DataType::LargeUtf8, DataType::Null]), + ], + Volatility::Immutable, + ), + } + } +} + +impl Default for StringAgg { + fn default() -> Self { + Self::new() + } +} + +impl AggregateUDFImpl for StringAgg { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "string_agg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::LargeUtf8) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + match &acc_args.input_exprs[1] { + Expr::Literal(ScalarValue::Utf8(Some(delimiter))) + | Expr::Literal(ScalarValue::LargeUtf8(Some(delimiter))) => { + Ok(Box::new(StringAggAccumulator::new(delimiter))) + } + Expr::Literal(ScalarValue::Utf8(None)) + | Expr::Literal(ScalarValue::LargeUtf8(None)) + | Expr::Literal(ScalarValue::Null) => { + Ok(Box::new(StringAggAccumulator::new(""))) + } + _ => not_impl_err!( + "StringAgg not supported for delimiter {}", + &acc_args.input_exprs[1] + ), + } + } +} + +#[derive(Debug)] +pub(crate) struct StringAggAccumulator { + values: Option, + delimiter: String, +} + +impl StringAggAccumulator { + pub fn new(delimiter: &str) -> Self { + Self { + values: None, + delimiter: delimiter.to_string(), + } + } +} + +impl Accumulator for StringAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let string_array: Vec<_> = as_generic_string_array::(&values[0])? + .iter() + .filter_map(|v| v.as_ref().map(ToString::to_string)) + .collect(); + if !string_array.is_empty() { + let s = string_array.join(self.delimiter.as_str()); + let v = self.values.get_or_insert("".to_string()); + if !v.is_empty() { + v.push_str(self.delimiter.as_str()); + } + v.push_str(s.as_str()); + } + Ok(()) + } + + fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.update_batch(values)?; + Ok(()) + } + + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } + + fn evaluate(&mut self) -> Result { + Ok(ScalarValue::LargeUtf8(self.values.clone())) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) + + self.delimiter.capacity() + } +} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 6c01decdbf95..1dfe9ffd6905 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -155,22 +155,6 @@ pub fn create_aggregate_expr( ordering_req.to_vec(), )) } - (AggregateFunction::StringAgg, false) => { - if !ordering_req.is_empty() { - return not_impl_err!( - "STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available" - ); - } - Arc::new(expressions::StringAgg::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )) - } - (AggregateFunction::StringAgg, true) => { - return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available"); - } }) } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 0b1f5f577435..87c7deccc2cd 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -26,7 +26,6 @@ pub(crate) mod correlation; pub(crate) mod covariance; pub(crate) mod grouping; pub(crate) mod nth_value; -pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; diff --git a/datafusion/physical-expr/src/aggregate/string_agg.rs b/datafusion/physical-expr/src/aggregate/string_agg.rs deleted file mode 100644 index dc0ffc557968..000000000000 --- a/datafusion/physical-expr/src/aggregate/string_agg.rs +++ /dev/null @@ -1,246 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{format_state_name, Literal}; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field}; -use datafusion_common::cast::as_generic_string_array; -use datafusion_common::{not_impl_err, Result, ScalarValue}; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// STRING_AGG aggregate expression -#[derive(Debug)] -pub struct StringAgg { - name: String, - data_type: DataType, - expr: Arc, - delimiter: Arc, - nullable: bool, -} - -impl StringAgg { - /// Create a new StringAgg aggregate function - pub fn new( - expr: Arc, - delimiter: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - Self { - name: name.into(), - data_type, - delimiter, - expr, - nullable: true, - } - } -} - -impl AggregateExpr for StringAgg { - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new( - &self.name, - self.data_type.clone(), - self.nullable, - )) - } - - fn create_accumulator(&self) -> Result> { - if let Some(delimiter) = self.delimiter.as_any().downcast_ref::() { - match delimiter.value() { - ScalarValue::Utf8(Some(delimiter)) - | ScalarValue::LargeUtf8(Some(delimiter)) => { - return Ok(Box::new(StringAggAccumulator::new(delimiter))); - } - ScalarValue::Null => { - return Ok(Box::new(StringAggAccumulator::new(""))); - } - _ => return not_impl_err!("StringAgg not supported for {}", self.name), - } - } - not_impl_err!("StringAgg not supported for {}", self.name) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new( - format_state_name(&self.name, "string_agg"), - self.data_type.clone(), - self.nullable, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone(), self.delimiter.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for StringAgg { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - && self.delimiter.eq(&x.delimiter) - }) - .unwrap_or(false) - } -} - -#[derive(Debug)] -pub(crate) struct StringAggAccumulator { - values: Option, - delimiter: String, -} - -impl StringAggAccumulator { - pub fn new(delimiter: &str) -> Self { - Self { - values: None, - delimiter: delimiter.to_string(), - } - } -} - -impl Accumulator for StringAggAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let string_array: Vec<_> = as_generic_string_array::(&values[0])? - .iter() - .filter_map(|v| v.as_ref().map(ToString::to_string)) - .collect(); - if !string_array.is_empty() { - let s = string_array.join(self.delimiter.as_str()); - let v = self.values.get_or_insert("".to_string()); - if !v.is_empty() { - v.push_str(self.delimiter.as_str()); - } - v.push_str(s.as_str()); - } - Ok(()) - } - - fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - self.update_batch(values)?; - Ok(()) - } - - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - - fn evaluate(&mut self) -> Result { - Ok(ScalarValue::LargeUtf8(self.values.clone())) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) - + self.delimiter.capacity() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::tests::aggregate; - use crate::expressions::{col, create_aggregate_expr, try_cast}; - use arrow::datatypes::*; - use arrow::record_batch::RecordBatch; - use arrow_array::LargeStringArray; - use arrow_array::StringArray; - use datafusion_expr::type_coercion::aggregates::coerce_types; - use datafusion_expr::AggregateFunction; - - fn assert_string_aggregate( - array: ArrayRef, - function: AggregateFunction, - distinct: bool, - expected: ScalarValue, - delimiter: String, - ) { - let data_type = array.data_type(); - let sig = function.signature(); - let coerced = - coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap(); - - let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]); - let batch = - RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap(); - - let input = try_cast( - col("a", &input_schema).unwrap(), - &input_schema, - coerced[0].clone(), - ) - .unwrap(); - - let delimiter = Arc::new(Literal::new(ScalarValue::from(delimiter))); - let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]); - let agg = create_aggregate_expr( - &function, - distinct, - &[input, delimiter], - &[], - &schema, - "agg", - false, - ) - .unwrap(); - - let result = aggregate(&batch, agg).unwrap(); - assert_eq!(expected, result); - } - - #[test] - fn string_agg_utf8() { - let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"])); - assert_string_aggregate( - a, - AggregateFunction::StringAgg, - false, - ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())), - ",".to_owned(), - ); - } - - #[test] - fn string_agg_largeutf8() { - let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"])); - assert_string_aggregate( - a, - AggregateFunction::StringAgg, - false, - ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())), - "|".to_owned(), - ); - } -} diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index bffaafd7dac2..322610404074 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -47,7 +47,6 @@ pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; pub use crate::aggregate::stats::StatsType; -pub use crate::aggregate::string_agg::StringAgg; pub use crate::window::cume_dist::{cume_dist, CumeDist}; pub use crate::window::lead_lag::{lag, lead, WindowShift}; pub use crate::window::nth_value::NthValue; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index ae4445eaa8ce..6375df721ae6 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -505,7 +505,7 @@ enum AggregateFunction { // REGR_SXX = 32; // REGR_SYY = 33; // REGR_SXY = 34; - STRING_AGG = 35; + // STRING_AGG = 35; NTH_VALUE_AGG = 36; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 243c75435f8d..5c483f70d150 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -540,7 +540,6 @@ impl serde::Serialize for AggregateFunction { Self::Grouping => "GROUPING", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", - Self::StringAgg => "STRING_AGG", Self::NthValueAgg => "NTH_VALUE_AGG", }; serializer.serialize_str(variant) @@ -561,7 +560,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "GROUPING", "BOOL_AND", "BOOL_OR", - "STRING_AGG", "NTH_VALUE_AGG", ]; @@ -611,7 +609,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "GROUPING" => Ok(AggregateFunction::Grouping), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), - "STRING_AGG" => Ok(AggregateFunction::StringAgg), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 1172eccb90fd..bc5b6be2ad87 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1959,7 +1959,7 @@ pub enum AggregateFunction { /// REGR_SXX = 32; /// REGR_SYY = 33; /// REGR_SXY = 34; - StringAgg = 35, + /// STRING_AGG = 35; NthValueAgg = 36, } impl AggregateFunction { @@ -1977,7 +1977,6 @@ impl AggregateFunction { AggregateFunction::Grouping => "GROUPING", AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", - AggregateFunction::StringAgg => "STRING_AGG", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } } @@ -1992,7 +1991,6 @@ impl AggregateFunction { "GROUPING" => Some(Self::Grouping), "BOOL_AND" => Some(Self::BoolAnd), "BOOL_OR" => Some(Self::BoolOr), - "STRING_AGG" => Some(Self::StringAgg), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 43cc352f98dd..5bec655bb1ff 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -146,7 +146,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Correlation => Self::Correlation, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, - protobuf::AggregateFunction::StringAgg => Self::StringAgg, } } } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 33a58daeaf0a..66b7c77799ea 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -117,7 +117,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Correlation => Self::Correlation, AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, - AggregateFunction::StringAgg => Self::StringAgg, } } } @@ -387,9 +386,6 @@ pub fn serialize_expr( AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg - } }; let aggregate_expr = protobuf::AggregateExprNode { diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 886179bf5627..ed966509b842 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -26,8 +26,7 @@ use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr, - WindowShift, + OrderSensitiveArrayAgg, Rank, RankType, RowNumber, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -260,8 +259,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Correlation - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::StringAgg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::NthValueAgg } else { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 52696a106183..61764394ee74 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -60,6 +60,7 @@ use datafusion_expr::{ WindowFunctionDefinition, WindowUDF, WindowUDFImpl, }; use datafusion_functions_aggregate::expr_fn::{bit_and, bit_or, bit_xor}; +use datafusion_functions_aggregate::string_agg::string_agg; use datafusion_proto::bytes::{ logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, @@ -669,6 +670,7 @@ async fn roundtrip_expr_api() -> Result<()> { bit_and(lit(2)), bit_or(lit(2)), bit_xor(lit(2)), + string_agg(col("a").cast_to(&DataType::Utf8, &schema)?, lit("|")), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7f66cdbf7663..eb3313239544 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -48,7 +48,7 @@ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, NotExpr, NthValue, - PhysicalSortExpr, StringAgg, + PhysicalSortExpr, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::insert::DataSinkExec; @@ -79,6 +79,7 @@ use datafusion_expr::{ Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; +use datafusion_functions_aggregate::string_agg::StringAgg; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; @@ -357,12 +358,20 @@ fn rountrip_aggregate() -> Result<()> { Vec::new(), ))], // STRING_AGG - vec![Arc::new(StringAgg::new( - cast(col("b", &schema)?, &schema, DataType::Utf8)?, - lit(ScalarValue::Utf8(Some(",".to_string()))), - "STRING_AGG(name, ',')".to_string(), - DataType::Utf8, - ))], + vec![udaf::create_aggregate_expr( + &AggregateUDF::new_from_impl(StringAgg::new()), + &[ + cast(col("b", &schema)?, &schema, DataType::Utf8)?, + lit(ScalarValue::Utf8(Some(",".to_string()))), + ], + &[], + &[], + &[], + &schema, + "STRING_AGG(name, ',')", + false, + false, + )?], ]; for aggregates in test_cases { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 0a6def3d6f27..378cab206240 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -4972,6 +4972,22 @@ CREATE TABLE float_table ( ( 32768.3, arrow_cast('NAN','Float32'), 32768.3, 32768.3 ), ( 27.3, 27.3, 27.3, arrow_cast('NAN','Float64') ); +# Test string_agg with largeutf8 +statement ok +create table string_agg_large_utf8 (c string) as values + (arrow_cast('a', 'LargeUtf8')), + (arrow_cast('b', 'LargeUtf8')), + (arrow_cast('c', 'LargeUtf8')) +; + +query T +SELECT STRING_AGG(c, ',') FROM string_agg_large_utf8; +---- +a,b,c + +statement ok +drop table string_agg_large_utf8; + query RRRRI select min(col_f32), max(col_f32), avg(col_f32), sum(col_f32), count(col_f32) from float_table; ----