diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 17d23a86f85c..59d6f44f59b1 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -2058,7 +2058,7 @@ mod tests { .await .unwrap_err(); - assert_eq!(results.to_string(), "Error during planning: Coercion from [Timestamp(Nanosecond, None)] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed."); + assert_eq!(results.to_string(), "Error during planning: The function Sum does not support inputs of type Timestamp(Nanosecond, None)."); Ok(()) } @@ -2155,7 +2155,7 @@ mod tests { .await .unwrap_err(); - assert_eq!(results.to_string(), "Error during planning: Coercion from [Timestamp(Nanosecond, None)] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed."); + assert_eq!(results.to_string(), "Error during planning: The function Avg does not support inputs of type Timestamp(Nanosecond, None)."); Ok(()) } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 1ec33a409efb..3f9766fd5680 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -28,15 +28,16 @@ use super::{ functions::{Signature, Volatility}, - type_coercion::{coerce, data_types}, Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types}; use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use expressions::{avg_return_type, sum_return_type}; use std::{fmt, str::FromStr, sync::Arc}; + /// the implementation of an aggregate function pub type AccumulatorFunctionImplementation = Arc Result> + Send + Sync>; @@ -87,13 +88,14 @@ impl FromStr for AggregateFunction { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", name - ))) + ))); } }) } } -/// Returns the datatype of the aggregation function +/// Returns the datatype of the aggregate function. +/// This is used to get the returned data type for aggregate expr. pub fn return_type( fun: &AggregateFunction, input_expr_types: &[DataType], @@ -101,21 +103,23 @@ pub fn return_type( // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &signature(fun))?; + let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?; match fun { + // TODO If the datafusion is compatible with PostgreSQL, the returned data type should be INT64. AggregateFunction::Count | AggregateFunction::ApproxDistinct => { Ok(DataType::UInt64) } AggregateFunction::Max | AggregateFunction::Min => { - Ok(input_expr_types[0].clone()) + // For min and max agg function, the returned type is same as input type. + // The coerced_data_types is same with input_types. + Ok(coerced_data_types[0].clone()) } - AggregateFunction::Sum => sum_return_type(&input_expr_types[0]), - AggregateFunction::Avg => avg_return_type(&input_expr_types[0]), + AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), + AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", - input_expr_types[0].clone(), + coerced_data_types[0].clone(), true, )))), } @@ -131,26 +135,26 @@ pub fn create_aggregate_expr( name: impl Into, ) -> Result> { let name = name.into(); - let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?; + // get the coerced phy exprs if some expr need to be wrapped with the try cast. + let coerced_phy_exprs = + coerce_exprs(fun, input_phy_exprs, input_schema, &signature(fun))?; if coerced_phy_exprs.is_empty() { return Err(DataFusionError::Plan(format!( "Invalid or wrong number of arguments passed to aggregate: '{}'", name, ))); } - let coerced_exprs_types = coerced_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - let input_exprs_types = input_phy_exprs + // get the result data type for this aggregate function + let input_phy_types = input_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - - // In order to get the result data type, we must use the original input data type to calculate the result type. - let return_type = return_type(fun, &input_exprs_types)?; + let return_type = return_type(fun, &input_phy_types)?; Ok(match (fun, distinct) { (AggregateFunction::Count, false) => Arc::new(expressions::Count::new( @@ -161,7 +165,7 @@ pub fn create_aggregate_expr( (AggregateFunction::Count, true) => { Arc::new(distinct_expressions::DistinctCount::new( coerced_exprs_types, - coerced_phy_exprs.to_vec(), + coerced_phy_exprs, name, return_type, )) @@ -262,6 +266,199 @@ pub fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::error::Result; + use crate::physical_plan::expressions::{ + ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, + }; + + #[test] + fn test_count_arragg_approx_expr() -> Result<()> { + let funcs = vec![ + AggregateFunction::Count, + AggregateFunction::ArrayAgg, + AggregateFunction::ApproxDistinct, + ]; + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal(10, 2), + DataType::Utf8, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Count => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::ApproxDistinct => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, false), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::ArrayAgg => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new( + "c1", + DataType::List(Box::new(Field::new( + "item", + data_type.clone(), + true + ))), + false + ), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + + #[test] + fn test_min_max_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal(10, 2), + DataType::Utf8, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Min => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::Max => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + + #[test] + fn test_sum_avg_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Sum => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + let mut expect_type = data_type.clone(); + if matches!( + data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) { + expect_type = DataType::UInt64; + } else if matches!( + data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + ) { + expect_type = DataType::Int64; + } else if matches!( + data_type, + DataType::Float32 | DataType::Float64 + ) { + expect_type = data_type.clone(); + } + assert_eq!( + Field::new("c1", expect_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::Avg => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } #[test] fn test_min_max() -> Result<()> { @@ -270,6 +467,16 @@ mod tests { let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?; assert_eq!(DataType::Int32, observed); + + // test decimal for min + let observed = return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)])?; + assert_eq!(DataType::Decimal(10, 6), observed); + + // test decimal for max + let observed = + return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)])?; + assert_eq!(DataType::Decimal(28, 13), observed); + Ok(()) } @@ -293,6 +500,10 @@ mod tests { let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?; assert_eq!(DataType::UInt64, observed); + + let observed = + return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?; + assert_eq!(DataType::UInt64, observed); Ok(()) } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs new file mode 100644 index 000000000000..d7b437528d5c --- /dev/null +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -0,0 +1,207 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Support the coercion rule for aggregate function. + +use crate::arrow::datatypes::Schema; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::aggregates::AggregateFunction; +use crate::physical_plan::expressions::{ + is_avg_support_arg_type, is_sum_support_arg_type, try_cast, +}; +use crate::physical_plan::functions::{Signature, TypeSignature}; +use crate::physical_plan::PhysicalExpr; +use arrow::datatypes::DataType; +use std::ops::Deref; +use std::sync::Arc; + +/// Returns the coerced data type for each `input_types`. +/// Different aggregate function with different input data type will get corresponding coerced data type. +pub(crate) fn coerce_types( + agg_fun: &AggregateFunction, + input_types: &[DataType], + signature: &Signature, +) -> Result> { + match signature.type_signature { + TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { + if input_types.len() != agg_count { + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + agg_count, + input_types.len() + ))); + } + } + _ => { + return Err(DataFusionError::Internal(format!( + "Aggregate functions do not support this {:?}", + signature + ))); + } + }; + match agg_fun { + AggregateFunction::Count | AggregateFunction::ApproxDistinct => { + Ok(input_types.to_vec()) + } + AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), + AggregateFunction::Min | AggregateFunction::Max => { + // min and max support the dictionary data type + // unpack the dictionary to get the value + get_min_max_result_type(input_types) + } + AggregateFunction::Sum => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval. + if !is_sum_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Avg => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval + if !is_avg_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + } +} + +fn get_min_max_result_type(input_types: &[DataType]) -> Result> { + // make sure that the input types only has one element. + assert_eq!(input_types.len(), 1); + // min and max support the dictionary data type + // unpack the dictionary to get the value + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + // TODO add checker for datatype which min and max supported + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function + _ => Ok(input_types.to_vec()), + } +} + +/// Returns the coerced exprs for each `input_exprs`. +/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the +/// data type of `input_exprs` need to be coerced. +pub(crate) fn coerce_exprs( + agg_fun: &AggregateFunction, + input_exprs: &[Arc], + schema: &Schema, + signature: &Signature, +) -> Result>> { + if input_exprs.is_empty() { + return Ok(vec![]); + } + let input_types = input_exprs + .iter() + .map(|e| e.data_type(schema)) + .collect::>>()?; + + // get the coerced data types + let coerced_types = coerce_types(agg_fun, &input_types, signature)?; + + // try cast if need + input_exprs + .iter() + .zip(coerced_types.into_iter()) + .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) + .collect::>>() +} + +#[cfg(test)] +mod tests { + use crate::physical_plan::aggregates; + use crate::physical_plan::aggregates::AggregateFunction; + use crate::physical_plan::coercion_rule::aggregate_rule::coerce_types; + use arrow::datatypes::DataType; + + #[test] + fn test_aggregate_coerce_types() { + // test input args with error number input types + let fun = AggregateFunction::Min; + let input_types = vec![DataType::Int64, DataType::Int32]; + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().to_string()); + + // test input args is invalid data type for sum or avg + let fun = AggregateFunction::Sum; + let input_types = vec![DataType::Utf8]; + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!( + "Error during planning: The function Sum does not support inputs of type Utf8.", + result.unwrap_err().to_string() + ); + let fun = AggregateFunction::Avg; + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!( + "Error during planning: The function Avg does not support inputs of type Utf8.", + result.unwrap_err().to_string() + ); + + // test count, array_agg, approx_distinct, min, max. + // the coerced types is same with input types + let funs = vec![ + AggregateFunction::Count, + AggregateFunction::ArrayAgg, + AggregateFunction::ApproxDistinct, + AggregateFunction::Min, + AggregateFunction::Max, + ]; + let input_types = vec![ + vec![DataType::Int32], + // support the decimal data type for min/max agg + // vec![DataType::Decimal(10, 2)], + vec![DataType::Utf8], + ]; + for fun in funs { + for input_type in &input_types { + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, input_type, &signature); + assert_eq!(*input_type, result.unwrap()); + } + } + // test sum, avg + let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; + let input_types = vec![ + vec![DataType::Int32], + vec![DataType::Float32], + // support the decimal data type + // vec![DataType::Decimal(20, 3)], + ]; + for fun in funs { + for input_type in &input_types { + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, input_type, &signature); + assert_eq!(*input_type, result.unwrap()); + } + } + } +} diff --git a/datafusion/src/physical_plan/coercion_rule/mod.rs b/datafusion/src/physical_plan/coercion_rule/mod.rs new file mode 100644 index 000000000000..1aeabda793b1 --- /dev/null +++ b/datafusion/src/physical_plan/coercion_rule/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Define coercion rules for different Expr type. +//! +//! Aggregate function rule + +pub(crate) mod aggregate_rule; diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 17d3041453d0..feb568c8dd72 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -60,6 +60,22 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { } } +pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + impl Avg { /// Create a new AVG aggregate function pub fn new( diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 5647ee0a4d27..134c6d89ac4f 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -60,6 +60,7 @@ pub mod helpers { pub use approx_distinct::ApproxDistinct; pub use array_agg::ArrayAgg; +pub(crate) use average::is_avg_support_arg_type; pub use average::{avg_return_type, Avg, AvgAccumulator}; pub use binary::{binary, binary_operator_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; @@ -83,6 +84,7 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index c3f57e31e0d5..c570aef72b52 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -63,6 +63,22 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } } +pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ) +} + impl Sum { /// Create a new SUM aggregate function pub fn new( diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index ef53d8602b40..8c5f662a4ac7 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -608,6 +608,7 @@ pub mod analyze; pub mod array_expressions; pub mod coalesce_batches; pub mod coalesce_partitions; +mod coercion_rule; pub mod common; pub mod cross_join; #[cfg(feature = "crypto_expressions")] diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 640556cb2724..1dbc90da7df2 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -5612,10 +5612,15 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> { let mut ctx = ExecutionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; - let logical_plan = ctx.create_logical_plan(sql)?; - let physical_plan = ctx.create_physical_plan(&logical_plan).await; - let err = physical_plan.unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: Invalid or wrong number of arguments passed to aggregate: 'COUNT(DISTINCT )'"); + let logical_plan = ctx.create_logical_plan(sql); + let err = logical_plan.unwrap_err(); + assert_eq!( + err.to_string(), + DataFusionError::Plan( + "The function Count expects 1 arguments, but 0 were provided".to_string() + ) + .to_string() + ); Ok(()) }