From bd137bc026a56ad843b641ea9b4b1aadd8b1ed7c Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 29 Nov 2023 14:02:53 +0100 Subject: [PATCH 1/8] Refactor aggregate function handling --- .../core/src/datasource/listing/helpers.rs | 3 +- datafusion/core/src/physical_planner.rs | 198 +++++++------- datafusion/expr/src/expr.rs | 98 +++---- datafusion/expr/src/expr_schema.rs | 38 +-- datafusion/expr/src/tree_node/expr.rs | 58 ++--- datafusion/expr/src/udaf.rs | 11 +- datafusion/expr/src/utils.rs | 6 +- .../src/analyzer/count_wildcard_rule.rs | 10 +- .../optimizer/src/analyzer/type_coercion.rs | 68 ++--- .../optimizer/src/common_subexpr_eliminate.rs | 21 +- datafusion/optimizer/src/decorrelate.rs | 27 +- datafusion/optimizer/src/push_down_filter.rs | 1 - .../simplify_expressions/expr_simplifier.rs | 1 - .../src/single_distinct_to_groupby.rs | 21 +- .../proto/src/logical_plan/from_proto.rs | 3 +- datafusion/proto/src/logical_plan/to_proto.rs | 246 +++++++++--------- datafusion/sql/src/expr/function.rs | 4 +- datafusion/sql/src/expr/mod.rs | 3 +- datafusion/sql/src/select.rs | 9 +- .../substrait/src/logical_plan/consumer.rs | 19 +- .../substrait/src/logical_plan/producer.rs | 126 ++++----- 21 files changed, 503 insertions(+), 468 deletions(-) diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index f9b02f4d0c10..0c39877cd11e 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -122,8 +122,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { // - AGGREGATE, WINDOW and SORT should not end up in filter conditions, except maybe in some edge cases // - Can `Wildcard` be considered as a `Literal`? // - ScalarVariable could be `applicable`, but that would require access to the context - Expr::AggregateUDF { .. } - | Expr::AggregateFunction { .. } + Expr::AggregateFunction { .. } | Expr::Sort { .. } | Expr::WindowFunction { .. } | Expr::Wildcard { .. } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ef364c22ee7d..23a1aa01b128 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -82,8 +82,9 @@ use datafusion_common::{ }; use datafusion_expr::dml::{CopyOptions, CopyTo}; use datafusion_expr::expr::{ - self, AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Cast, - GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, WindowFunction, + self, AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, + Cast, GetFieldAccess, GetIndexedField, GroupingSet, InList, Like, TryCast, + WindowFunction, }; use datafusion_expr::expr_rewriter::{unalias, unnormalize_cols}; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -229,30 +230,11 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { create_function_physical_name(&fun.to_string(), false, args) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, .. - }) => create_function_physical_name(&fun.to_string(), *distinct, args), - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - // TODO: Add support for filter and order by in AggregateUDF - if filter.is_some() { - return exec_err!("aggregate expression with filter is not supported"); - } - if order_by.is_some() { - return exec_err!("aggregate expression with order_by is not supported"); - } - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_physical_name(e, false)?); - } - Ok(format!("{}({})", fun.name(), names.join(","))) - } + }) => create_function_physical_name(func_def.name(), *distinct, args), Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -1705,105 +1687,105 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( ) -> Result { match e { Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, order_by, - }) => { - let args = args - .iter() - .map(|e| { - create_physical_expr( + }) => match func_def { + AggregateFunctionDefinition::BuiltIn { fun, name } => { + let args = args + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; + let filter = match filter { + Some(e) => Some(create_physical_expr( e, logical_input_schema, physical_input_schema, execution_props, - ) - }) - .collect::>>()?; - let filter = match filter { - Some(e) => Some(create_physical_expr( - e, - logical_input_schema, + )?), + None => None, + }; + let order_by = match order_by { + Some(e) => Some( + e.iter() + .map(|expr| { + create_physical_sort_expr( + expr, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?, + ), + None => None, + }; + let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let agg_expr = aggregates::create_aggregate_expr( + fun, + *distinct, + &args, + &ordering_reqs, physical_input_schema, - execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, - }; - let ordering_reqs = order_by.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - &ordering_reqs, - physical_input_schema, - name, - )?; - Ok((agg_expr, filter, order_by)) - } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let args = args - .iter() - .map(|e| { - create_physical_expr( + name.to_string(), + )?; + Ok((agg_expr, filter, order_by)) + } + AggregateFunctionDefinition::UDF(fun) => { + let args = args + .iter() + .map(|e| { + create_physical_expr( + e, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?; + + let filter = match filter { + Some(e) => Some(create_physical_expr( e, logical_input_schema, physical_input_schema, execution_props, - ) - }) - .collect::>>()?; - - let filter = match filter { - Some(e) => Some(create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, - }; + )?), + None => None, + }; + let order_by = match order_by { + Some(e) => Some( + e.iter() + .map(|expr| { + create_physical_sort_expr( + expr, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?, + ), + None => None, + }; - let agg_expr = - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); - Ok((agg_expr?, filter, order_by)) - } + let agg_expr = + udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); + Ok((agg_expr?, filter, order_by)) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Aggregate function name should have been resolved") + } + }, other => internal_err!("Invalid aggregate expression '{other:?}'"), } } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index b46d204faafb..aec2e1bf8a52 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -154,8 +154,6 @@ pub enum Expr { AggregateFunction(AggregateFunction), /// Represents the call of a window function with arguments. WindowFunction(WindowFunction), - /// aggregate function - AggregateUDF(AggregateUDF), /// Returns whether the list contains the expr value. InList(InList), /// EXISTS subquery @@ -484,11 +482,39 @@ impl Sort { } } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// Defines which implementation of a function for DataFusion to call. +pub enum AggregateFunctionDefinition { + /// Resolved to a `BuiltinScalarFunction` + /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) + /// This variant is planned to be removed in long term + BuiltIn { + fun: aggregate_function::AggregateFunction, + name: Arc, + }, + /// Resolved to a user defined function + UDF(Arc), + /// A scalar function constructed with name. This variant can not be executed directly + /// and instead must be resolved to one of the other variants prior to physical planning. + Name(Arc), +} + +impl AggregateFunctionDefinition { + /// Function's name for display + pub fn name(&self) -> &str { + match self { + AggregateFunctionDefinition::BuiltIn { name, .. } => name.as_ref(), + AggregateFunctionDefinition::UDF(udf) => udf.name(), + AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), + } + } +} + /// Aggregate function #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct AggregateFunction { /// Name of the function - pub fun: aggregate_function::AggregateFunction, + pub func_def: AggregateFunctionDefinition, /// List of expressions to feed to the functions as arguments pub args: Vec, /// Whether this is a DISTINCT aggregation or not @@ -508,7 +534,27 @@ impl AggregateFunction { order_by: Option>, ) -> Self { Self { - fun, + func_def: AggregateFunctionDefinition::BuiltIn { + fun: fun.clone(), + name: Arc::from(fun.to_string()), + }, + args, + distinct, + filter, + order_by, + } + } + + /// Create a new ScalarFunction expression with a user-defined function (UDF) + pub fn new_udf( + udf: Arc, + args: Vec, + distinct: bool, + filter: Option>, + order_by: Option>, + ) -> Self { + Self { + func_def: AggregateFunctionDefinition::UDF(udf), args, distinct, filter, @@ -736,7 +782,6 @@ impl Expr { pub fn variant_name(&self) -> &str { match self { Expr::AggregateFunction { .. } => "AggregateFunction", - Expr::AggregateUDF { .. } => "AggregateUDF", Expr::Alias(..) => "Alias", Expr::Between { .. } => "Between", Expr::BinaryExpr { .. } => "BinaryExpr", @@ -1251,30 +1296,14 @@ impl fmt::Display for Expr { Ok(()) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, ref args, filter, order_by, .. }) => { - fmt_function(f, &fun.to_string(), *distinct, args, true)?; - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; - } - if let Some(ob) = order_by { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; - } - Ok(()) - } - Expr::AggregateUDF(AggregateUDF { - fun, - ref args, - filter, - order_by, - .. - }) => { - fmt_function(f, fun.name(), false, args, true)?; + fmt_function(f, func_def.name(), *distinct, args, true)?; if let Some(fe) = filter { write!(f, " FILTER (WHERE {fe})")?; } @@ -1579,13 +1608,13 @@ fn create_name(e: &Expr) -> Result { Ok(parts.join(" ")) } Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, order_by, }) => { - let mut name = create_function_name(&fun.to_string(), *distinct, args)?; + let mut name = create_function_name(func_def.name(), *distinct, args)?; if let Some(fe) = filter { name = format!("{name} FILTER (WHERE {fe})"); }; @@ -1594,25 +1623,6 @@ fn create_name(e: &Expr) -> Result { }; Ok(name) } - Expr::AggregateUDF(AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e)?); - } - let mut info = String::new(); - if let Some(fe) = filter { - info += &format!(" FILTER (WHERE {fe})"); - } - if let Some(ob) = order_by { - info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob)); - } - Ok(format!("{}({}){}", fun.name(), names.join(","), info)) - } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { Ok(format!("ROLLUP ({})", create_names(exprs.as_slice())?)) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index d5d9c848b2e9..65ecb4153d3e 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, BinaryExpr, Cast, GetFieldAccess, - GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast, + GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::field_util::GetFieldAccessSchema; @@ -123,19 +123,26 @@ impl ExprSchemable for Expr { .collect::>>()?; fun.return_type(&data_types) } - Expr::AggregateFunction(AggregateFunction { fun, args, .. }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - fun.return_type(&data_types) - } - Expr::AggregateUDF(AggregateUDF { fun, args, .. }) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; - fun.return_type(&data_types) + Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn { fun, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + fun.return_type(&data_types) + } + AggregateFunctionDefinition::UDF(fun) => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok(fun.return_type(&data_types)?) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + } } Expr::Not(_) | Expr::IsNull(_) @@ -252,7 +259,6 @@ impl ExprSchemable for Expr { | Expr::ScalarFunction(..) | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::Placeholder(_) => Ok(true), Expr::IsNull(_) | Expr::IsNotNull(_) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 474b5f7689b9..fa2c8e4cfdd5 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -18,9 +18,9 @@ //! Tree node implementation for logical expr use crate::expr::{ - AggregateFunction, AggregateUDF, Alias, Between, BinaryExpr, Case, Cast, - GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, - ScalarFunctionDefinition, Sort, TryCast, WindowFunction, + AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case, + Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder, + ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, WindowFunction, }; use crate::{Expr, GetFieldAccess}; @@ -108,7 +108,7 @@ impl TreeNode for Expr { expr_vec } Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) - | Expr::AggregateUDF(AggregateUDF { args, filter, order_by, .. }) => { + => { let mut expr_vec = args.clone(); if let Some(f) = filter { @@ -304,17 +304,32 @@ impl TreeNode for Expr { )), Expr::AggregateFunction(AggregateFunction { args, - fun, + func_def, distinct, filter, order_by, - }) => Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )), + }) => match func_def { + AggregateFunctionDefinition::BuiltIn { fun, .. } => { + Expr::AggregateFunction(AggregateFunction::new( + fun, args, distinct, filter, order_by, + )) + } + AggregateFunctionDefinition::UDF(fun) => { + let order_by = if let Some(order_by) = order_by { + Some(transform_vec(order_by, &mut transform)?) + } else { + None + }; + Expr::AggregateFunction(AggregateFunction::new_udf( + fun, args, distinct, filter, order_by, + )) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( transform_vec(exprs, &mut transform)?, @@ -331,24 +346,7 @@ impl TreeNode for Expr { )) } }, - Expr::AggregateUDF(AggregateUDF { - args, - fun, - filter, - order_by, - }) => { - let order_by = if let Some(order_by) = order_by { - Some(transform_vec(order_by, &mut transform)?) - } else { - None - }; - Expr::AggregateUDF(AggregateUDF::new( - fun, - transform_vec(args, &mut transform)?, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } + Expr::InList(InList { expr, list, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index b06e97acc283..cfbca4ab1337 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -107,12 +107,13 @@ impl AggregateUDF { /// This utility allows using the UDAF without requiring access to /// the registry, such as with the DataFrame API. pub fn call(&self, args: Vec) -> Expr { - Expr::AggregateUDF(crate::expr::AggregateUDF { - fun: Arc::new(self.clone()), + Expr::AggregateFunction(crate::expr::AggregateFunction::new_udf( + Arc::new(self.clone()), args, - filter: None, - order_by: None, - }) + false, + None, + None, + )) } /// Returns this function's name diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7deb13c89be5..7bab751ee1b4 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -291,7 +291,6 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet) -> Result<()> { | Expr::WindowFunction { .. } | Expr::AggregateFunction { .. } | Expr::GroupingSet(_) - | Expr::AggregateUDF { .. } | Expr::InList { .. } | Expr::Exists { .. } | Expr::InSubquery(_) @@ -600,10 +599,7 @@ pub fn group_window_expr_by_sort_keys( /// first), with duplicates omitted. pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { - matches!( - nested_expr, - Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } - ) + matches!(nested_expr, Expr::AggregateFunction { .. }) }) } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index b4de322f76f6..16d8c72d462a 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -144,20 +144,20 @@ impl TreeNodeRewriter for CountWildcardRewriter { _ => old_expr, }, Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, + func_def: _, args, distinct, filter, order_by, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction { - fun: aggregate_function::AggregateFunction::Count, - args: vec![lit(COUNT_STAR_EXPANSION)], + Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::Count, + vec![lit(COUNT_STAR_EXPANSION)], distinct, filter, order_by, - }) + )) } _ => old_expr, }, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index eb5d8c53a5e0..53aa7a4c1d34 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -28,8 +28,8 @@ use datafusion_common::{ DataFusionError, Result, ScalarValue, }; use datafusion_expr::expr::{ - self, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, ScalarFunction, - WindowFunction, + self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, WindowFunction, }; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::expr_schema::cast_subquery; @@ -346,39 +346,39 @@ impl TreeNodeRewriter for TypeCoercionRewriter { } }, Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def, args, distinct, filter, order_by, - }) => { - let new_expr = coerce_agg_exprs_for_signature( - &fun, - &args, - &self.schema, - &fun.signature(), - )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) - } - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => { - let new_expr = coerce_arguments_for_signature( - args.as_slice(), - &self.schema, - fun.signature(), - )?; - let expr = Expr::AggregateUDF(expr::AggregateUDF::new( - fun, new_expr, filter, order_by, - )); - Ok(expr) - } + }) => match func_def { + AggregateFunctionDefinition::BuiltIn { fun, .. } => { + let new_expr = coerce_agg_exprs_for_signature( + &fun, + &args, + &self.schema, + &fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::UDF(fun) => { + let new_expr = coerce_arguments_for_signature( + args.as_slice(), + &self.schema, + fun.signature(), + )?; + let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + )); + Ok(expr) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Function `Expr` with name should be resolved.") + } + }, Expr::WindowFunction(WindowFunction { fun, args, @@ -914,9 +914,10 @@ mod test { Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit(10i64)], + false, None, None, )); @@ -941,9 +942,10 @@ mod test { &accumulator, &state_type, ); - let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( + let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], + false, None, None, )); diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f5ad767c5016..8ea6ca5304da 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -528,10 +528,7 @@ impl ExprMask { | Expr::Wildcard { .. } ); - let is_aggr = matches!( - expr, - Expr::AggregateFunction(..) | Expr::AggregateUDF { .. } - ); + let is_aggr = matches!(expr, Expr::AggregateFunction(..)); match self { Self::Normal => is_normal_minus_aggregates || is_aggr, @@ -908,7 +905,7 @@ mod test { let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { - Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( Arc::new(AggregateUDF::new( "my_agg", &Signature::exact(vec![DataType::UInt32], Volatility::Stable), @@ -917,9 +914,23 @@ mod test { &state_type, )), vec![inner], + false, None, None, )) + + // Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( + // Arc::new(AggregateUDF::new( + // "my_agg", + // &Signature::exact(vec![DataType::UInt32], Volatility::Stable), + // &return_type, + // &accumulator, + // &state_type, + // )), + // vec![inner], + // None, + // None, + // )) }; // test: common aggregates diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index ed6f472186d4..f9f5847d6aff 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -22,7 +22,7 @@ use datafusion_common::tree_node::{ }; use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; -use datafusion_expr::expr::Alias; +use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; @@ -372,16 +372,25 @@ fn agg_exprs_evaluation_result_on_empty_batch( for e in agg_expr.iter() { let result_expr = e.clone().transform_up(&|expr| { let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, .. }) => { - if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some(0)))) - } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn { fun, .. } => { + if matches!(fun, datafusion_expr::AggregateFunction::Count) { + Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + 0, + )))) + } else { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + } + AggregateFunctionDefinition::UDF { .. } => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } + AggregateFunctionDefinition::Name(_) => { + Transformed::Yes(Expr::Literal(ScalarValue::Null)) + } } } - Expr::AggregateUDF(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) - } _ => Transformed::No(expr), }; Ok(new_expr) diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 95eeee931b4f..bad6e24715c9 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -253,7 +253,6 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) - | Expr::AggregateUDF { .. } | Expr::Wildcard { .. } | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), })?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 3310bfed75bf..c7366e17619c 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -332,7 +332,6 @@ impl<'a> ConstEvaluator<'a> { // Has no runtime cost, but needed during planning Expr::Alias(..) | Expr::AggregateFunction { .. } - | Expr::AggregateUDF { .. } | Expr::ScalarVariable(_, _) | Expr::Column(_) | Expr::OuterReferenceColumn(_, _) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index fa142438c4a3..f66ff0eae6f0 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -23,6 +23,7 @@ use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::{ aggregate_function::AggregateFunction::{Max, Min, Sum}, col, @@ -70,7 +71,7 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - fun, + func_def, distinct, args, filter, @@ -85,8 +86,19 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { for e in args { fields_set.insert(e.canonical_name()); } - } else if !matches!(fun, Sum | Min | Max) { - return Ok(false); + } else { + match func_def { + AggregateFunctionDefinition::BuiltIn { fun, name: _ } => { + if !matches!(fun, Sum | Min | Max) { + return Ok(false); + } else { + return Ok(true); + } + } + _ => { + return Ok(false); + } + } } } } @@ -170,7 +182,8 @@ impl OptimizerRule for SingleDistinctToGroupBy { .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { - fun, + func_def: + AggregateFunctionDefinition::BuiltIn { fun, name: _ }, args, distinct, .. diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b2455d5a0d13..d7071c6ddf10 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -1739,12 +1739,13 @@ pub fn parse_expr( ExprType::AggregateUdfExpr(pb) => { let agg_fn = registry.udaf(pb.fun_name.as_str())?; - Ok(Expr::AggregateUDF(expr::AggregateUDF::new( + Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, pb.args .iter() .map(|expr| parse_expr(expr, registry)) .collect::, Error>>()?, + false, parse_optional_expr(pb.filter.as_deref(), registry)?.map(Box::new), parse_vec_expr(&pb.order_by, registry)?, ))) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9be4a532bb5b..c8aaea7c058f 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -44,8 +44,9 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GetFieldAccess, GetIndexedField, GroupingSet, - InList, Like, Placeholder, ScalarFunction, ScalarFunctionDefinition, Sort, + self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess, + GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, + ScalarFunctionDefinition, Sort, }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, @@ -652,104 +653,139 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { } } Expr::AggregateFunction(expr::AggregateFunction { - ref fun, + ref func_def, ref args, ref distinct, ref filter, ref order_by, }) => { - let aggr_function = match fun { - AggregateFunction::ApproxDistinct => { - protobuf::AggregateFunction::ApproxDistinct - } - AggregateFunction::ApproxPercentileCont => { - protobuf::AggregateFunction::ApproxPercentileCont - } - AggregateFunction::ApproxPercentileContWithWeight => { - protobuf::AggregateFunction::ApproxPercentileContWithWeight - } - AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, - AggregateFunction::Min => protobuf::AggregateFunction::Min, - AggregateFunction::Max => protobuf::AggregateFunction::Max, - AggregateFunction::Sum => protobuf::AggregateFunction::Sum, - AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, - AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, - AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, - AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, - AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, - AggregateFunction::Avg => protobuf::AggregateFunction::Avg, - AggregateFunction::Count => protobuf::AggregateFunction::Count, - AggregateFunction::Variance => protobuf::AggregateFunction::Variance, - AggregateFunction::VariancePop => { - protobuf::AggregateFunction::VariancePop - } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } - AggregateFunction::CovariancePop => { - protobuf::AggregateFunction::CovariancePop - } - AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, - AggregateFunction::StddevPop => { - protobuf::AggregateFunction::StddevPop - } - AggregateFunction::Correlation => { - protobuf::AggregateFunction::Correlation - } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } - AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, - AggregateFunction::Median => protobuf::AggregateFunction::Median, - AggregateFunction::FirstValue => { - protobuf::AggregateFunction::FirstValueAgg - } - AggregateFunction::LastValue => { - protobuf::AggregateFunction::LastValueAgg - } - AggregateFunction::StringAgg => { - protobuf::AggregateFunction::StringAgg + match func_def { + AggregateFunctionDefinition::BuiltIn { fun, .. } => { + let aggr_function = match fun { + AggregateFunction::ApproxDistinct => { + protobuf::AggregateFunction::ApproxDistinct + } + AggregateFunction::ApproxPercentileCont => { + protobuf::AggregateFunction::ApproxPercentileCont + } + AggregateFunction::ApproxPercentileContWithWeight => { + protobuf::AggregateFunction::ApproxPercentileContWithWeight + } + AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, + AggregateFunction::Min => protobuf::AggregateFunction::Min, + AggregateFunction::Max => protobuf::AggregateFunction::Max, + AggregateFunction::Sum => protobuf::AggregateFunction::Sum, + AggregateFunction::BitAnd => protobuf::AggregateFunction::BitAnd, + AggregateFunction::BitOr => protobuf::AggregateFunction::BitOr, + AggregateFunction::BitXor => protobuf::AggregateFunction::BitXor, + AggregateFunction::BoolAnd => protobuf::AggregateFunction::BoolAnd, + AggregateFunction::BoolOr => protobuf::AggregateFunction::BoolOr, + AggregateFunction::Avg => protobuf::AggregateFunction::Avg, + AggregateFunction::Count => protobuf::AggregateFunction::Count, + AggregateFunction::Variance => protobuf::AggregateFunction::Variance, + AggregateFunction::VariancePop => { + protobuf::AggregateFunction::VariancePop + } + AggregateFunction::Covariance => { + protobuf::AggregateFunction::Covariance + } + AggregateFunction::CovariancePop => { + protobuf::AggregateFunction::CovariancePop + } + AggregateFunction::Stddev => protobuf::AggregateFunction::Stddev, + AggregateFunction::StddevPop => { + protobuf::AggregateFunction::StddevPop + } + AggregateFunction::Correlation => { + protobuf::AggregateFunction::Correlation + } + AggregateFunction::RegrSlope => { + protobuf::AggregateFunction::RegrSlope + } + AggregateFunction::RegrIntercept => { + protobuf::AggregateFunction::RegrIntercept + } + AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, + AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, + AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, + AggregateFunction::RegrCount => { + protobuf::AggregateFunction::RegrCount + } + AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, + AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, + AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } + AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, + AggregateFunction::Median => protobuf::AggregateFunction::Median, + AggregateFunction::FirstValue => { + protobuf::AggregateFunction::FirstValueAgg + } + AggregateFunction::LastValue => { + protobuf::AggregateFunction::LastValueAgg + } + AggregateFunction::StringAgg => { + protobuf::AggregateFunction::StringAgg + } + }; + + let aggregate_expr = protobuf::AggregateExprNode { + aggr_function: aggr_function.into(), + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + distinct: *distinct, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }; + Self { + expr_type: Some(ExprType::AggregateExpr(Box::new( + aggregate_expr, + ))), + } } - }; - - let aggregate_expr = protobuf::AggregateExprNode { - aggr_function: aggr_function.into(), - expr: args - .iter() - .map(|v| v.try_into()) - .collect::, _>>()?, - distinct: *distinct, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], + AggregateFunctionDefinition::UDF(fun) => Self { + expr_type: Some(ExprType::AggregateUdfExpr(Box::new( + protobuf::AggregateUdfExprNode { + fun_name: fun.name().to_string(), + args: args + .iter() + .map(|expr| expr.try_into()) + .collect::, Error>>()?, + filter: match filter { + Some(e) => Some(Box::new(e.as_ref().try_into()?)), + None => None, + }, + order_by: match order_by { + Some(e) => e + .iter() + .map(|expr| expr.try_into()) + .collect::, _>>()?, + None => vec![], + }, + }, + ))), }, - }; - Self { - expr_type: Some(ExprType::AggregateExpr(Box::new(aggregate_expr))), + AggregateFunctionDefinition::Name(_) => { + return Err(Error::NotImplemented( + "Proto serialization error: Trying to serialize a unresolved function" + .to_string(), + )); + } } } + Expr::ScalarVariable(_, _) => { return Err(Error::General( "Proto serialization error: Scalar Variable not supported" @@ -790,34 +826,6 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { )); } }, - Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }) => Self { - expr_type: Some(ExprType::AggregateUdfExpr(Box::new( - protobuf::AggregateUdfExprNode { - fun_name: fun.name().to_string(), - args: args.iter().map(|expr| expr.try_into()).collect::, - Error, - >>( - )?, - filter: match filter { - Some(e) => Some(Box::new(e.as_ref().try_into()?)), - None => None, - }, - order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| expr.try_into()) - .collect::, _>>()?, - None => vec![], - }, - }, - ))), - }, Expr::Not(expr) => { let expr = Box::new(protobuf::Not { expr: Some(Box::new(expr.as_ref().try_into()?)), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 24ba4d1b506a..958e03879842 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -135,8 +135,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // User defined aggregate functions (UDAF) have precedence in case it has the same name as a scalar built-in function if let Some(fm) = self.context_provider.get_aggregate_meta(&name) { let args = self.function_args_to_expr(args, schema, planner_context)?; - return Ok(Expr::AggregateUDF(expr::AggregateUDF::new( - fm, args, None, None, + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( + fm, args, false, None, None, ))); } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 25fe6b6633c2..743893c2022e 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -34,6 +34,7 @@ use datafusion_common::{ internal_err, not_impl_err, plan_err, Column, DFSchema, DataFusionError, Result, ScalarValue, }; +use datafusion_expr::expr::AggregateFunctionDefinition; use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ @@ -706,7 +707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { Expr::AggregateFunction(expr::AggregateFunction { - fun, + func_def: AggregateFunctionDefinition::BuiltIn { fun, .. }, args, distinct, order_by, diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 356c53605131..c546ca755206 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -170,11 +170,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select_exprs .iter() .filter(|select_expr| match select_expr { - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) => false, - Expr::Alias(Alias { expr, name: _, .. }) => !matches!( - **expr, - Expr::AggregateFunction(_) | Expr::AggregateUDF(_) - ), + Expr::AggregateFunction(_) => false, + Expr::Alias(Alias { expr, name: _, .. }) => { + !matches!(**expr, Expr::AggregateFunction(_)) + } _ => true, }) .cloned() diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index b7a51032dcd9..cf05d814a5cb 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -692,21 +692,14 @@ pub async fn from_substrait_agg_func( // try udaf first, then built-in aggr fn. if let Ok(fun) = ctx.udaf(function_name) { - Ok(Arc::new(Expr::AggregateUDF(expr::AggregateUDF { - fun, - args, - filter, - order_by, - }))) + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by), + ))) } else if let Ok(fun) = aggregate_function::AggregateFunction::from_str(function_name) { - Ok(Arc::new(Expr::AggregateFunction(expr::AggregateFunction { - fun, - args, - distinct, - filter, - order_by, - }))) + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new(fun, args, distinct, filter, order_by), + ))) } else { not_impl_err!( "Aggregated function {} is not supported: function anchor = {:?}", diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 2be3e7b4e884..9492f93c4302 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -33,8 +33,8 @@ use datafusion::common::{exec_err, internal_err, not_impl_err}; #[allow(unused_imports)] use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, ScalarFunctionDefinition, Sort, - WindowFunction, + AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + ScalarFunction as DFScalarFunction, ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -578,65 +578,73 @@ pub fn to_substrait_agg_measure( ), ) -> Result { match expr { - Expr::AggregateFunction(expr::AggregateFunction { fun, args, distinct, filter, order_by }) => { - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); - } - let function_anchor = _register_function(fun.to_string(), extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: match distinct { - true => AggregationInvocation::Distinct as i32, - false => AggregationInvocation::All as i32, - }, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), - None => None + Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => { + match func_def { + AggregateFunctionDefinition::BuiltIn { fun, .. } => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: match distinct { + true => AggregationInvocation::Distinct as i32, + false => AggregationInvocation::All as i32, + }, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) } - }) - } - Expr::AggregateUDF(expr::AggregateUDF{ fun, args, filter, order_by }) =>{ - let sorts = if let Some(order_by) = order_by { - order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? - } else { - vec![] - }; - let mut arguments: Vec = vec![]; - for arg in args { - arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); - } - let function_anchor = _register_function(fun.name().to_string(), extension_info); - Ok(Measure { - measure: Some(AggregateFunction { - function_reference: function_anchor, - arguments, - sorts, - output_type: None, - invocation: AggregationInvocation::All as i32, - phase: AggregationPhase::Unspecified as i32, - args: vec![], - options: vec![], - }), - filter: match filter { - Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), - None => None + AggregateFunctionDefinition::UDF(fun) => { + let sorts = if let Some(order_by) = order_by { + order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? + } else { + vec![] + }; + let mut arguments: Vec = vec![]; + for arg in args { + arguments.push(FunctionArgument { arg_type: Some(ArgType::Value(to_substrait_rex(arg, schema, 0, extension_info)?)) }); + } + let function_anchor = _register_function(fun.name().to_string(), extension_info); + Ok(Measure { + measure: Some(AggregateFunction { + function_reference: function_anchor, + arguments, + sorts, + output_type: None, + invocation: AggregationInvocation::All as i32, + phase: AggregationPhase::Unspecified as i32, + args: vec![], + options: vec![], + }), + filter: match filter { + Some(f) => Some(to_substrait_rex(f, schema, 0, extension_info)?), + None => None + } + }) } - }) - }, + AggregateFunctionDefinition::Name(name) => { + internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name) + } + } + + } Expr::Alias(Alias{expr,..})=> { to_substrait_agg_measure(expr, schema, extension_info) } From 0b5a28a3373acfbec5a3638af731866b2441cdfd Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 29 Nov 2023 14:18:31 +0100 Subject: [PATCH 2/8] fix ci --- datafusion/expr/src/tree_node/expr.rs | 12 ++++++++++-- datafusion/expr/src/utils.rs | 4 ++-- .../optimizer/src/common_subexpr_eliminate.rs | 16 +--------------- .../proto/tests/cases/roundtrip_logical_plan.rs | 3 ++- 4 files changed, 15 insertions(+), 20 deletions(-) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index fa2c8e4cfdd5..53cfa630dac4 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -311,7 +311,11 @@ impl TreeNode for Expr { }) => match func_def { AggregateFunctionDefinition::BuiltIn { fun, .. } => { Expr::AggregateFunction(AggregateFunction::new( - fun, args, distinct, filter, order_by, + fun, + transform_vec(args, &mut transform)?, + distinct, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, )) } AggregateFunctionDefinition::UDF(fun) => { @@ -321,7 +325,11 @@ impl TreeNode for Expr { None }; Expr::AggregateFunction(AggregateFunction::new_udf( - fun, args, distinct, filter, order_by, + fun, + transform_vec(args, &mut transform)?, + false, + transform_option_box(filter, &mut transform)?, + transform_option_vec(order_by, &mut transform)?, )) } AggregateFunctionDefinition::Name(_) => { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7bab751ee1b4..7d126a0f3373 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -594,8 +594,8 @@ pub fn group_window_expr_by_sort_keys( Ok(result) } -/// Collect all deeply nested `Expr::AggregateFunction` and -/// `Expr::AggregateUDF`. They are returned in order of occurrence (depth +/// Collect all deeply nested `Expr::AggregateFunction`. +/// They are returned in order of occurrence (depth /// first), with duplicates omitted. pub fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { find_exprs_in_exprs(exprs, &|nested_expr| { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 8ea6ca5304da..1d21407a6985 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -509,10 +509,9 @@ enum ExprMask { /// - [`Sort`](Expr::Sort) /// - [`Wildcard`](Expr::Wildcard) /// - [`AggregateFunction`](Expr::AggregateFunction) - /// - [`AggregateUDF`](Expr::AggregateUDF) Normal, - /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction) and [`AggregateUDF`](Expr::AggregateUDF). + /// Like [`Normal`](Self::Normal), but includes [`AggregateFunction`](Expr::AggregateFunction). NormalAndAggregates, } @@ -918,19 +917,6 @@ mod test { None, None, )) - - // Expr::AggregateUDF(datafusion_expr::expr::AggregateUDF::new( - // Arc::new(AggregateUDF::new( - // "my_agg", - // &Signature::exact(vec![DataType::UInt32], Volatility::Stable), - // &return_type, - // &accumulator, - // &state_type, - // )), - // vec![inner], - // None, - // None, - // )) }; // test: common aggregates diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3ab001298ed2..45727c39a373 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1375,9 +1375,10 @@ fn roundtrip_aggregate_udf() { Arc::new(vec![DataType::Float64, DataType::UInt32]), ); - let test_expr = Expr::AggregateUDF(expr::AggregateUDF::new( + let test_expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(dummy_agg.clone()), vec![lit(1.0_f64)], + false, Some(Box::new(lit(true))), None, )); From 8e1c805d7f39c4aa7e5297b8e5b0d35045cccadf Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 29 Nov 2023 14:23:50 +0100 Subject: [PATCH 3/8] update comment --- datafusion/expr/src/expr.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index aec2e1bf8a52..674557b1f9d9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -485,16 +485,13 @@ impl Sort { #[derive(Debug, Clone, PartialEq, Eq, Hash)] /// Defines which implementation of a function for DataFusion to call. pub enum AggregateFunctionDefinition { - /// Resolved to a `BuiltinScalarFunction` - /// There is plan to migrate `BuiltinScalarFunction` to UDF-based implementation (issue#8045) - /// This variant is planned to be removed in long term BuiltIn { fun: aggregate_function::AggregateFunction, name: Arc, }, /// Resolved to a user defined function UDF(Arc), - /// A scalar function constructed with name. This variant can not be executed directly + /// A aggregation function constructed with name. This variant can not be executed directly /// and instead must be resolved to one of the other variants prior to physical planning. Name(Arc), } From 49381526617b490380ad99fa27a57e7bb573aa37 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 29 Nov 2023 14:35:52 +0100 Subject: [PATCH 4/8] fix ci --- datafusion/core/src/physical_planner.rs | 34 ++++++++++++++-- datafusion/expr/src/expr.rs | 39 +++++++++++++------ .../src/analyzer/count_wildcard_rule.rs | 8 +++- .../src/single_distinct_to_groupby.rs | 17 ++------ 4 files changed, 67 insertions(+), 31 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 23a1aa01b128..ebbc3d83bdd7 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -233,8 +233,34 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { func_def, distinct, args, - .. - }) => create_function_physical_name(func_def.name(), *distinct, args), + filter, + order_by, + }) => match func_def { + AggregateFunctionDefinition::BuiltIn { fun: _, name: _ } => { + create_function_physical_name(func_def.name(), *distinct, args) + } + AggregateFunctionDefinition::UDF(fun) => { + // TODO: Add support for filter and order by in AggregateUDF + if filter.is_some() { + return exec_err!( + "aggregate expression with filter is not supported" + ); + } + if order_by.is_some() { + return exec_err!( + "aggregate expression with order_by is not supported" + ); + } + let mut names = Vec::with_capacity(args.len()); + for e in args { + names.push(create_physical_name(e, false)?); + } + Ok(format!("{}({})", fun.name(), names.join(","))) + } + AggregateFunctionDefinition::Name(_) => { + internal_err!("Aggregate function `Expr` with name should be resolved.") + } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( "ROLLUP ({})", @@ -1693,7 +1719,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( filter, order_by, }) => match func_def { - AggregateFunctionDefinition::BuiltIn { fun, name } => { + AggregateFunctionDefinition::BuiltIn { fun, name: _ } => { let args = args .iter() .map(|e| { @@ -1736,7 +1762,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( &args, &ordering_reqs, physical_input_schema, - name.to_string(), + name, )?; Ok((agg_expr, filter, order_by)) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 674557b1f9d9..7603983c1b7f 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -542,7 +542,7 @@ impl AggregateFunction { } } - /// Create a new ScalarFunction expression with a user-defined function (UDF) + /// Create a new AggregateFunction expression with a user-defined function (UDF) pub fn new_udf( udf: Arc, args: Vec, @@ -1610,16 +1610,33 @@ fn create_name(e: &Expr) -> Result { args, filter, order_by, - }) => { - let mut name = create_function_name(func_def.name(), *distinct, args)?; - if let Some(fe) = filter { - name = format!("{name} FILTER (WHERE {fe})"); - }; - if let Some(order_by) = order_by { - name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); - }; - Ok(name) - } + }) => match func_def { + AggregateFunctionDefinition::BuiltIn { fun: _, name: _ } + | AggregateFunctionDefinition::Name(_) => { + let mut name = create_function_name(func_def.name(), *distinct, args)?; + if let Some(fe) = filter { + name = format!("{name} FILTER (WHERE {fe})"); + }; + if let Some(order_by) = order_by { + name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); + }; + Ok(name) + } + AggregateFunctionDefinition::UDF(fun) => { + let mut names = Vec::with_capacity(args.len()); + for e in args { + names.push(create_name(e)?); + } + let mut info = String::new(); + if let Some(fe) = filter { + info += &format!(" FILTER (WHERE {fe})"); + } + if let Some(ob) = order_by { + info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob)); + } + Ok(format!("{}({}){}", fun.name(), names.join(","), info)) + } + }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { Ok(format!("ROLLUP ({})", create_names(exprs.as_slice())?)) diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 16d8c72d462a..b63f1d21116d 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -19,7 +19,7 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, InSubquery}; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::ScalarSubquery; @@ -144,7 +144,11 @@ impl TreeNodeRewriter for CountWildcardRewriter { _ => old_expr, }, Expr::AggregateFunction(AggregateFunction { - func_def: _, + func_def: + AggregateFunctionDefinition::BuiltIn { + fun: aggregate_function::AggregateFunction::Count, + name: _, + }, args, distinct, filter, diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index f66ff0eae6f0..f3e29296d7a3 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -71,7 +71,7 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - func_def, + func_def: AggregateFunctionDefinition::BuiltIn { fun, name: _ }, distinct, args, filter, @@ -86,19 +86,8 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { for e in args { fields_set.insert(e.canonical_name()); } - } else { - match func_def { - AggregateFunctionDefinition::BuiltIn { fun, name: _ } => { - if !matches!(fun, Sum | Min | Max) { - return Ok(false); - } else { - return Ok(true); - } - } - _ => { - return Ok(false); - } - } + } else if !matches!(fun, Sum | Min | Max) { + return Ok(false); } } } From 99273deb083c7abf32744540d590175ee44d15c7 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 29 Nov 2023 22:21:36 +0100 Subject: [PATCH 5/8] simplify the code --- datafusion/core/src/physical_planner.rs | 158 +++++++----------- datafusion/expr/src/aggregate_function.rs | 2 +- datafusion/expr/src/expr.rs | 67 ++++---- datafusion/expr/src/expr_schema.rs | 14 +- datafusion/expr/src/tree_node/expr.rs | 2 +- .../src/analyzer/count_wildcard_rule.rs | 7 +- .../optimizer/src/analyzer/type_coercion.rs | 2 +- datafusion/optimizer/src/decorrelate.rs | 2 +- .../src/single_distinct_to_groupby.rs | 5 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- datafusion/sql/src/expr/mod.rs | 2 +- .../substrait/src/logical_plan/producer.rs | 4 +- 12 files changed, 114 insertions(+), 153 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ebbc3d83bdd7..9e64eb9c5108 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -236,7 +236,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { filter, order_by, }) => match func_def { - AggregateFunctionDefinition::BuiltIn { fun: _, name: _ } => { + AggregateFunctionDefinition::BuiltIn(..) => { create_function_physical_name(func_def.name(), *distinct, args) } AggregateFunctionDefinition::UDF(fun) => { @@ -251,10 +251,10 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { "aggregate expression with order_by is not supported" ); } - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_physical_name(e, false)?); - } + let names = args + .iter() + .map(|e| create_physical_name(e, false)) + .collect::>>()?; Ok(format!("{}({})", fun.name(), names.join(","))) } AggregateFunctionDefinition::Name(_) => { @@ -1718,100 +1718,72 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( args, filter, order_by, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn { fun, name: _ } => { - let args = args - .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?; - let filter = match filter { - Some(e) => Some(create_physical_expr( + }) => { + let args = args + .iter() + .map(|e| { + create_physical_expr( e, logical_input_schema, physical_input_schema, execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, - }; - let ordering_reqs = order_by.clone().unwrap_or(vec![]); - let agg_expr = aggregates::create_aggregate_expr( - fun, - *distinct, - &args, - &ordering_reqs, + ) + }) + .collect::>>()?; + let filter = match filter { + Some(e) => Some(create_physical_expr( + e, + logical_input_schema, physical_input_schema, - name, - )?; - Ok((agg_expr, filter, order_by)) - } - AggregateFunctionDefinition::UDF(fun) => { - let args = args - .iter() - .map(|e| { - create_physical_expr( - e, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?; - - let filter = match filter { - Some(e) => Some(create_physical_expr( - e, - logical_input_schema, + execution_props, + )?), + None => None, + }; + let order_by = match order_by { + Some(e) => Some( + e.iter() + .map(|expr| { + create_physical_sort_expr( + expr, + logical_input_schema, + physical_input_schema, + execution_props, + ) + }) + .collect::>>()?, + ), + None => None, + }; + let (agg_expr, filter, order_by) = match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + let ordering_reqs = order_by.clone().unwrap_or(vec![]); + let agg_expr = aggregates::create_aggregate_expr( + fun, + *distinct, + &args, + &ordering_reqs, physical_input_schema, - execution_props, - )?), - None => None, - }; - let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - physical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), - None => None, - }; - - let agg_expr = - udaf::create_aggregate_expr(fun, &args, physical_input_schema, name); - Ok((agg_expr?, filter, order_by)) - } - AggregateFunctionDefinition::Name(_) => { - internal_err!("Aggregate function name should have been resolved") - } - }, + name, + )?; + (agg_expr, filter, order_by) + } + AggregateFunctionDefinition::UDF(fun) => { + let agg_expr = udaf::create_aggregate_expr( + fun, + &args, + physical_input_schema, + name, + ); + (agg_expr?, filter, order_by) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Aggregate function name should have been resolved" + ) + } + }; + Ok((agg_expr, filter, order_by)) + } other => internal_err!("Invalid aggregate expression '{other:?}'"), } } diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 4611c7fb10d7..cea72c3cb5e6 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -105,7 +105,7 @@ pub enum AggregateFunction { } impl AggregateFunction { - fn name(&self) -> &str { + pub fn name(&self) -> &str { use AggregateFunction::*; match self { Count => "COUNT", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 7603983c1b7f..4eb1bc1f6ca7 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -483,13 +483,10 @@ impl Sort { } #[derive(Debug, Clone, PartialEq, Eq, Hash)] -/// Defines which implementation of a function for DataFusion to call. +/// Defines which implementation of an aggregate function DataFusion should call. pub enum AggregateFunctionDefinition { - BuiltIn { - fun: aggregate_function::AggregateFunction, - name: Arc, - }, - /// Resolved to a user defined function + BuiltIn(aggregate_function::AggregateFunction), + /// Resolved to a user defined aggregate function UDF(Arc), /// A aggregation function constructed with name. This variant can not be executed directly /// and instead must be resolved to one of the other variants prior to physical planning. @@ -500,11 +497,17 @@ impl AggregateFunctionDefinition { /// Function's name for display pub fn name(&self) -> &str { match self { - AggregateFunctionDefinition::BuiltIn { name, .. } => name.as_ref(), + AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), AggregateFunctionDefinition::UDF(udf) => udf.name(), AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), } } + + pub fn new_builtin( + fun: aggregate_function::AggregateFunction, + ) -> AggregateFunctionDefinition { + Self::BuiltIn(fun) + } } /// Aggregate function @@ -531,10 +534,7 @@ impl AggregateFunction { order_by: Option>, ) -> Self { Self { - func_def: AggregateFunctionDefinition::BuiltIn { - fun: fun.clone(), - name: Arc::from(fun.to_string()), - }, + func_def: AggregateFunctionDefinition::new_builtin(fun), args, distinct, filter, @@ -1610,33 +1610,28 @@ fn create_name(e: &Expr) -> Result { args, filter, order_by, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn { fun: _, name: _ } - | AggregateFunctionDefinition::Name(_) => { - let mut name = create_function_name(func_def.name(), *distinct, args)?; - if let Some(fe) = filter { - name = format!("{name} FILTER (WHERE {fe})"); - }; - if let Some(order_by) = order_by { - name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); - }; - Ok(name) - } - AggregateFunctionDefinition::UDF(fun) => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e)?); - } - let mut info = String::new(); - if let Some(fe) = filter { - info += &format!(" FILTER (WHERE {fe})"); + }) => { + let mut name = match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + create_function_name(func_def.name(), *distinct, args)? } - if let Some(ob) = order_by { - info += &format!(" ORDER BY ([{}])", expr_vec_fmt!(ob)); + AggregateFunctionDefinition::UDF(..) => { + let mut names = Vec::with_capacity(args.len()); + for e in args { + names.push(create_name(e)?); + } + names.join(",") } - Ok(format!("{}({}){}", fun.name(), names.join(","), info)) - } - }, + }; + if let Some(fe) = filter { + name = format!("{name} FILTER (WHERE {fe})"); + }; + if let Some(order_by) = order_by { + name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); + }; + Ok(name) + } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { Ok(format!("ROLLUP ({})", create_names(exprs.as_slice())?)) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 65ecb4153d3e..99b27e8912bc 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -124,19 +124,15 @@ impl ExprSchemable for Expr { fun.return_type(&data_types) } Expr::AggregateFunction(AggregateFunction { func_def, args, .. }) => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; match func_def { - AggregateFunctionDefinition::BuiltIn { fun, .. } => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; + AggregateFunctionDefinition::BuiltIn(fun) => { fun.return_type(&data_types) } AggregateFunctionDefinition::UDF(fun) => { - let data_types = args - .iter() - .map(|e| e.get_type(schema)) - .collect::>>()?; Ok(fun.return_type(&data_types)?) } AggregateFunctionDefinition::Name(_) => { diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 53cfa630dac4..0818ffc9ef80 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -309,7 +309,7 @@ impl TreeNode for Expr { filter, order_by, }) => match func_def { - AggregateFunctionDefinition::BuiltIn { fun, .. } => { + AggregateFunctionDefinition::BuiltIn (fun) => { Expr::AggregateFunction(AggregateFunction::new( fun, transform_vec(args, &mut transform)?, diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index b63f1d21116d..fd84bb80160b 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -145,10 +145,9 @@ impl TreeNodeRewriter for CountWildcardRewriter { }, Expr::AggregateFunction(AggregateFunction { func_def: - AggregateFunctionDefinition::BuiltIn { - fun: aggregate_function::AggregateFunction::Count, - name: _, - }, + AggregateFunctionDefinition::BuiltIn( + aggregate_function::AggregateFunction::Count, + ), args, distinct, filter, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 53aa7a4c1d34..c54ca196b318 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -352,7 +352,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { filter, order_by, }) => match func_def { - AggregateFunctionDefinition::BuiltIn { fun, .. } => { + AggregateFunctionDefinition::BuiltIn (fun) => { let new_expr = coerce_agg_exprs_for_signature( &fun, &args, diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index f9f5847d6aff..19ab52b6c57a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -374,7 +374,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { match func_def { - AggregateFunctionDefinition::BuiltIn { fun, .. } => { + AggregateFunctionDefinition::BuiltIn (fun) => { if matches!(fun, datafusion_expr::AggregateFunction::Count) { Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( 0, diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index f3e29296d7a3..7e6fb6b355ab 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -71,7 +71,7 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result { let mut aggregate_count = 0; for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn { fun, name: _ }, + func_def: AggregateFunctionDefinition::BuiltIn(fun), distinct, args, filter, @@ -171,8 +171,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .iter() .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { - func_def: - AggregateFunctionDefinition::BuiltIn { fun, name: _ }, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, .. diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index c8aaea7c058f..ac5a59533645 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -660,7 +660,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref order_by, }) => { match func_def { - AggregateFunctionDefinition::BuiltIn { fun, .. } => { + AggregateFunctionDefinition::BuiltIn (fun) => { let aggr_function = match fun { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 743893c2022e..b8c130055a5a 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -707,7 +707,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ) -> Result { match self.sql_expr_to_logical_expr(expr, schema, planner_context)? { Expr::AggregateFunction(expr::AggregateFunction { - func_def: AggregateFunctionDefinition::BuiltIn { fun, .. }, + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, order_by, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 9492f93c4302..d576e70711df 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -34,7 +34,7 @@ use datafusion::common::{exec_err, internal_err, not_impl_err}; use datafusion::logical_expr::aggregate_function; use datafusion::logical_expr::expr::{ AggregateFunctionDefinition, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, - ScalarFunction as DFScalarFunction, ScalarFunctionDefinition, Sort, WindowFunction, + ScalarFunctionDefinition, Sort, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -580,7 +580,7 @@ pub fn to_substrait_agg_measure( match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, args, distinct, filter, order_by }) => { match func_def { - AggregateFunctionDefinition::BuiltIn { fun, .. } => { + AggregateFunctionDefinition::BuiltIn (fun) => { let sorts = if let Some(order_by) = order_by { order_by.iter().map(|expr| to_substrait_sort_field(expr, schema, extension_info)).collect::>>()? } else { From 6e4ca515be5e5fc9733c0d31909ef3e0abbbcc20 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 29 Nov 2023 22:28:10 +0100 Subject: [PATCH 6/8] fix fmt --- datafusion/expr/src/tree_node/expr.rs | 2 +- datafusion/optimizer/src/analyzer/type_coercion.rs | 2 +- datafusion/optimizer/src/decorrelate.rs | 2 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 0818ffc9ef80..fcb0a4cd93f3 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -309,7 +309,7 @@ impl TreeNode for Expr { filter, order_by, }) => match func_def { - AggregateFunctionDefinition::BuiltIn (fun) => { + AggregateFunctionDefinition::BuiltIn(fun) => { Expr::AggregateFunction(AggregateFunction::new( fun, transform_vec(args, &mut transform)?, diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c54ca196b318..bedc86e2f4f1 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -352,7 +352,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { filter, order_by, }) => match func_def { - AggregateFunctionDefinition::BuiltIn (fun) => { + AggregateFunctionDefinition::BuiltIn(fun) => { let new_expr = coerce_agg_exprs_for_signature( &fun, &args, diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 19ab52b6c57a..b1000f042c98 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -374,7 +374,7 @@ fn agg_exprs_evaluation_result_on_empty_batch( let new_expr = match expr { Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { match func_def { - AggregateFunctionDefinition::BuiltIn (fun) => { + AggregateFunctionDefinition::BuiltIn(fun) => { if matches!(fun, datafusion_expr::AggregateFunction::Count) { Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( 0, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ac5a59533645..6bfd4c3438f5 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -660,7 +660,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { ref order_by, }) => { match func_def { - AggregateFunctionDefinition::BuiltIn (fun) => { + AggregateFunctionDefinition::BuiltIn(fun) => { let aggr_function = match fun { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct From 441ca13669e298b950d963e76b760c8009637447 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 29 Nov 2023 22:51:09 +0100 Subject: [PATCH 7/8] fix ci --- datafusion/expr/src/expr.rs | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 4eb1bc1f6ca7..11a2db2ca62d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -502,12 +502,6 @@ impl AggregateFunctionDefinition { AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), } } - - pub fn new_builtin( - fun: aggregate_function::AggregateFunction, - ) -> AggregateFunctionDefinition { - Self::BuiltIn(fun) - } } /// Aggregate function @@ -534,7 +528,7 @@ impl AggregateFunction { order_by: Option>, ) -> Self { Self { - func_def: AggregateFunctionDefinition::new_builtin(fun), + func_def: AggregateFunctionDefinition::BuiltIn(fun), args, distinct, filter, @@ -1611,26 +1605,33 @@ fn create_name(e: &Expr) -> Result { filter, order_by, }) => { - let mut name = match func_def { + let name = match func_def { AggregateFunctionDefinition::BuiltIn(..) | AggregateFunctionDefinition::Name(..) => { create_function_name(func_def.name(), *distinct, args)? } AggregateFunctionDefinition::UDF(..) => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e)?); - } + let names: Vec = + args.iter().map(|e| create_name(e)).collect::>()?; names.join(",") } }; + let mut info = String::new(); if let Some(fe) = filter { - name = format!("{name} FILTER (WHERE {fe})"); + info += &format!(" FILTER (WHERE {fe})"); }; if let Some(order_by) = order_by { - name = format!("{name} ORDER BY [{}]", expr_vec_fmt!(order_by)); + info += &format!(" ORDER BY [{}]", expr_vec_fmt!(order_by)); }; - Ok(name) + match func_def { + AggregateFunctionDefinition::BuiltIn(..) + | AggregateFunctionDefinition::Name(..) => { + Ok(format!("{}{}", name, info)) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(format!("{}({}){}", fun.name(), name, info)) + } + } } Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => { From b9017bdfcc0b235ce5f884f619da0a9a65b16170 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Thu, 30 Nov 2023 00:04:14 +0100 Subject: [PATCH 8/8] fix clippy --- datafusion/expr/src/expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 11a2db2ca62d..256f5b210ec2 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1612,7 +1612,7 @@ fn create_name(e: &Expr) -> Result { } AggregateFunctionDefinition::UDF(..) => { let names: Vec = - args.iter().map(|e| create_name(e)).collect::>()?; + args.iter().map(create_name).collect::>()?; names.join(",") } };