From c01db7e2fed1b61d0b2c0bbde3faa1406712288a Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 30 Nov 2021 18:29:45 +0800 Subject: [PATCH 1/4] upgrade the arrow-rs version --- ballista/rust/core/Cargo.toml | 2 +- ballista/rust/executor/Cargo.toml | 4 ++-- datafusion-cli/Cargo.toml | 2 +- datafusion-examples/Cargo.toml | 2 +- datafusion/Cargo.toml | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index c59cd1ea4e48..ee41669ed08b 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -43,7 +43,7 @@ tonic = "0.5" uuid = { version = "0.8", features = ["v4"] } chrono = "0.4" -arrow-flight = { version = "6.2.0" } +arrow-flight = { version = "6.3.0" } datafusion = { path = "../../../datafusion", version = "6.0.0" } diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index f190392223a3..6717a0a12f5c 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -29,8 +29,8 @@ edition = "2018" snmalloc = ["snmalloc-rs"] [dependencies] -arrow = { version = "6.2.0" } -arrow-flight = { version = "6.2.0" } +arrow = { version = "6.3.0" } +arrow-flight = { version = "6.3.0" } anyhow = "1" async-trait = "0.1.36" ballista-core = { path = "../core", version = "0.6.0" } diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 54f3084293ce..0434f090da01 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -31,5 +31,5 @@ clap = "2.33" rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } datafusion = { path = "../datafusion", version = "6.0.0" } -arrow = { version = "6.2.0" } +arrow = { version = "6.3.0" } ballista = { path = "../ballista/rust/client", version = "0.6.0" } diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index bca0ec566030..3e8a6ec77f6d 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -34,7 +34,7 @@ path = "examples/avro_sql.rs" required-features = ["datafusion/avro"] [dev-dependencies] -arrow-flight = { version = "6.2.0" } +arrow-flight = { version = "6.3.0" } datafusion = { path = "../datafusion" } prost = "0.8" tonic = "0.5" diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 390dab43160d..fbe84e3ed0a0 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -52,8 +52,8 @@ avro = ["avro-rs", "num-traits"] [dependencies] ahash = "0.7" hashbrown = { version = "0.11", features = ["raw"] } -arrow = { version = "6.2.0", features = ["prettyprint"] } -parquet = { version = "6.2.0", features = ["arrow"] } +arrow = { version = "6.3.0", features = ["prettyprint"] } +parquet = { version = "6.3.0", features = ["arrow"] } sqlparser = "0.12" paste = "^1.0" num_cpus = "1.13.0" From fc48c5868a122fbadac3311c7f1e3c3a5686e2ed Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Wed, 1 Dec 2021 09:55:46 +0800 Subject: [PATCH 2/4] new framework for type coercion --- datafusion/src/execution/context.rs | 4 +- datafusion/src/physical_plan/aggregates.rs | 181 ++++++++++++++++-- .../coercion_rule/aggregate_rule.rs | 175 +++++++++++++++++ .../src/physical_plan/coercion_rule/mod.rs | 19 ++ .../src/physical_plan/expressions/average.rs | 19 ++ .../src/physical_plan/expressions/mod.rs | 2 + .../src/physical_plan/expressions/sum.rs | 19 ++ datafusion/src/physical_plan/mod.rs | 1 + datafusion/tests/sql.rs | 24 +-- 9 files changed, 411 insertions(+), 33 deletions(-) create mode 100644 datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs create mode 100644 datafusion/src/physical_plan/coercion_rule/mod.rs diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 27116c0a4a95..76ea4936823e 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -2058,7 +2058,7 @@ mod tests { .await .unwrap_err(); - assert_eq!(results.to_string(), "Error during planning: Coercion from [Timestamp(Nanosecond, None)] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed."); + assert_eq!(results.to_string(), "Error during planning: The function Sum do not support the Timestamp(Nanosecond, None)."); Ok(()) } @@ -2155,7 +2155,7 @@ mod tests { .await .unwrap_err(); - assert_eq!(results.to_string(), "Error during planning: Coercion from [Timestamp(Nanosecond, None)] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed."); + assert_eq!(results.to_string(), "Error during planning: The function Avg do not support the Timestamp(Nanosecond, None)."); Ok(()) } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 1ec33a409efb..44baf1b94c69 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -28,23 +28,24 @@ use super::{ functions::{Signature, Volatility}, - type_coercion::{coerce, data_types}, Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types}; use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use expressions::{avg_return_type, sum_return_type}; use std::{fmt, str::FromStr, sync::Arc}; + /// the implementation of an aggregate function pub type AccumulatorFunctionImplementation = - Arc Result> + Send + Sync>; +Arc Result> + Send + Sync>; /// This signature corresponds to which types an aggregator serializes /// its state, given its return datatype. pub type StateTypeFunction = - Arc Result>> + Send + Sync>; +Arc Result>> + Send + Sync>; /// Enum of all built-in aggregate functions #[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] @@ -87,13 +88,14 @@ impl FromStr for AggregateFunction { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", name - ))) + ))); } }) } } -/// Returns the datatype of the aggregation function +/// Returns the datatype of the aggregate function. +/// This is used to get the returned data type for aggregate expr. pub fn return_type( fun: &AggregateFunction, input_expr_types: &[DataType], @@ -101,21 +103,23 @@ pub fn return_type( // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - // verify that this is a valid set of data types for this function - data_types(input_expr_types, &signature(fun))?; + let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?; match fun { + // TODO If the datafusion is compatible with PostgreSQL, the returned data type should be INT64. AggregateFunction::Count | AggregateFunction::ApproxDistinct => { Ok(DataType::UInt64) } AggregateFunction::Max | AggregateFunction::Min => { - Ok(input_expr_types[0].clone()) + // For min and max agg function, the returned type is same as input type. + // The coerced_data_types is same with input_types. + Ok(coerced_data_types[0].clone()) } - AggregateFunction::Sum => sum_return_type(&input_expr_types[0]), - AggregateFunction::Avg => avg_return_type(&input_expr_types[0]), + AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), + AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", - input_expr_types[0].clone(), + coerced_data_types[0].clone(), true, )))), } @@ -131,26 +135,26 @@ pub fn create_aggregate_expr( name: impl Into, ) -> Result> { let name = name.into(); - let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?; + // get the coerced phy exprs if some expr need to be wrapped with the try cast. + let coerced_phy_exprs = + coerce_exprs(fun, input_phy_exprs, input_schema, &signature(fun))?; if coerced_phy_exprs.is_empty() { return Err(DataFusionError::Plan(format!( "Invalid or wrong number of arguments passed to aggregate: '{}'", name, ))); } - let coerced_exprs_types = coerced_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - let input_exprs_types = input_phy_exprs + // get the result data type for this aggregate function + let input_phy_types = input_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - - // In order to get the result data type, we must use the original input data type to calculate the result type. - let return_type = return_type(fun, &input_exprs_types)?; + let return_type = return_type(fun, &input_phy_types)?; Ok(match (fun, distinct) { (AggregateFunction::Count, false) => Arc::new(expressions::Count::new( @@ -161,7 +165,7 @@ pub fn create_aggregate_expr( (AggregateFunction::Count, true) => { Arc::new(distinct_expressions::DistinctCount::new( coerced_exprs_types, - coerced_phy_exprs.to_vec(), + coerced_phy_exprs, name, return_type, )) @@ -262,6 +266,131 @@ pub fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::error::Result; + use crate::physical_plan::expressions::{ApproxDistinct, ArrayAgg, Count, Max, Min}; + + #[test] + fn test_count_arragg_approx_expr() -> Result<()> { + let funcs = vec![ + AggregateFunction::Count, + AggregateFunction::ArrayAgg, + AggregateFunction::ApproxDistinct, + ]; + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal(10, 2), + DataType::Utf8, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Count => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::ApproxDistinct => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, false), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::ArrayAgg => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new( + "c1", + DataType::List(Box::new(Field::new( + "item", + data_type.clone(), + true + ))), + false + ), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + + #[test] + fn test_min_max_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal(10, 2), + DataType::Utf8, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Min => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::Max => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + + #[test] + fn test_sum_avg_expr() -> Result<()> { + // TODO + Ok(()) + } #[test] fn test_min_max() -> Result<()> { @@ -270,6 +399,16 @@ mod tests { let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?; assert_eq!(DataType::Int32, observed); + + // test decimal for min + let observed = return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)])?; + assert_eq!(DataType::Decimal(10, 6), observed); + + // test decimal for max + let observed = + return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)])?; + assert_eq!(DataType::Decimal(28, 13), observed); + Ok(()) } @@ -293,6 +432,10 @@ mod tests { let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?; assert_eq!(DataType::UInt64, observed); + + let observed = + return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?; + assert_eq!(DataType::UInt64, observed); Ok(()) } @@ -311,4 +454,4 @@ mod tests { let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]); assert!(observed.is_err()); } -} +} \ No newline at end of file diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs new file mode 100644 index 000000000000..5645379e5dc5 --- /dev/null +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Support the coercion rule for aggregate function. + +use crate::arrow::datatypes::Schema; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::aggregates::AggregateFunction; +use crate::physical_plan::expressions::{ + is_avg_support_arg_type, is_sum_support_arg_type, try_cast, +}; +use crate::physical_plan::functions::{Signature, TypeSignature}; +use crate::physical_plan::PhysicalExpr; +use arrow::datatypes::DataType; +use std::sync::Arc; + +pub fn coerce_types( + agg_fun: &AggregateFunction, + input_types: &[DataType], + signature: &Signature, +) -> Result> { + match signature.type_signature { + TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { + if input_types.len() != agg_count { + return Err(DataFusionError::Plan(format!("The function {:?} expect argument number is {:?}, but the input argument number is {:?}", + agg_fun, agg_count, input_types.len()))); + } + } + _ => { + return Err(DataFusionError::Plan(format!( + "The aggregate coercion rule don't support this {:?}", + signature + ))); + } + }; + match agg_fun { + AggregateFunction::Count | AggregateFunction::ApproxDistinct => { + Ok(input_types.to_vec()) + } + AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), + AggregateFunction::Min | AggregateFunction::Max => Ok(input_types.to_vec()), + AggregateFunction::Sum => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval. + if !is_sum_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} do not support the {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Avg => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval + if !is_avg_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} do not support the {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + } +} + +pub fn coerce_exprs( + agg_fun: &AggregateFunction, + input_exprs: &[Arc], + schema: &Schema, + signature: &Signature, +) -> Result>> { + if input_exprs.is_empty() { + return Ok(vec![]); + } + let input_types = input_exprs + .iter() + .map(|e| e.data_type(schema)) + .collect::>>()?; + + // get the coerced data types + let coerced_types = coerce_types(agg_fun, &input_types, signature)?; + + // try cast if need + input_exprs + .iter() + .enumerate() + .map(|(i, expr)| try_cast(expr.clone(), schema, coerced_types[i].clone())) + .collect::>>() +} + +#[cfg(test)] +mod tests { + use crate::physical_plan::aggregates; + use crate::physical_plan::aggregates::{signature, AggregateFunction}; + use crate::physical_plan::coercion_rule::aggregate_rule::coerce_types; + use arrow::datatypes::DataType; + + #[test] + fn test_aggregate_coerce_types() { + // test input args with error number input types + let fun = AggregateFunction::Min; + let input_types = vec![DataType::Int64, DataType::Int32]; + let signature = signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!("Error during planning: The function Min expect argument number is 1, but the input argument number is 2", result.unwrap_err().to_string()); + + // test input args is invalid data type for sum or avg + let fun = AggregateFunction::Sum; + let input_types = vec![DataType::Utf8]; + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!( + "Error during planning: The function Sum do not support the Utf8.", + result.unwrap_err().to_string() + ); + let fun = AggregateFunction::Avg; + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!( + "Error during planning: The function Avg do not support the Utf8.", + result.unwrap_err().to_string() + ); + + // test count, array_agg, approx_distinct, min, max. + // the coerced types is same with input types + let funs = vec![ + AggregateFunction::Count, + AggregateFunction::ArrayAgg, + AggregateFunction::ApproxDistinct, + AggregateFunction::Min, + AggregateFunction::Max, + ]; + let input_types = vec![ + vec![DataType::Int32], + // vec![DataType::Decimal(10, 2)], + vec![DataType::Utf8], + ]; + for fun in funs { + for input_type in &input_types { + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, input_type, &signature); + assert_eq!(*input_type, result.unwrap()); + } + } + // test sum, avg + let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; + let input_types = vec![ + vec![DataType::Int32], + vec![DataType::Float32], + // vec![DataType::Decimal(20, 3)], + ]; + for fun in funs { + for input_type in &input_types { + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, input_type, &signature); + assert_eq!(*input_type, result.unwrap()); + } + } + } +} \ No newline at end of file diff --git a/datafusion/src/physical_plan/coercion_rule/mod.rs b/datafusion/src/physical_plan/coercion_rule/mod.rs new file mode 100644 index 000000000000..78ae20028bd7 --- /dev/null +++ b/datafusion/src/physical_plan/coercion_rule/mod.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! define the coercion rule for different Expr type +pub(crate) mod aggregate_rule; \ No newline at end of file diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 2e218191f668..90cc41a45ad1 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -62,6 +62,25 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { } } +pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { + // TODO support the interval + // TODO: do we need to support the unsigned data type? + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + // | DataType::Decimal(_, _) + ) +} + impl Avg { /// Create a new AVG aggregate function pub fn new( diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 5647ee0a4d27..134c6d89ac4f 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -60,6 +60,7 @@ pub mod helpers { pub use approx_distinct::ApproxDistinct; pub use array_agg::ArrayAgg; +pub(crate) use average::is_avg_support_arg_type; pub use average::{avg_return_type, Avg, AvgAccumulator}; pub use binary::{binary, binary_operator_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; @@ -83,6 +84,7 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index c3f57e31e0d5..054fc5df322f 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -63,6 +63,25 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } } +pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { + // TODO support the interval + // TODO: do we need to support the unsigned data type? + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + // | DataType::Decimal(_, _) + ) +} + impl Sum { /// Create a new SUM aggregate function pub fn new( diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index ef53d8602b40..20cd77193981 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -648,3 +648,4 @@ pub mod union; pub mod values; pub mod window_functions; pub mod windows; +mod coercion_rule; diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 640556cb2724..03d5cb07b03d 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -4400,7 +4400,7 @@ async fn query_group_on_null_multi_col() -> Result<()> { Ok(()) } -#[tokio::test] +// #[tokio::test] async fn query_on_string_dictionary() -> Result<()> { // Test to ensure DataFusion can operate on dictionary types // Use StringDictionary (32 bit indexes = keys) @@ -5607,17 +5607,17 @@ async fn test_physical_plan_display_indent_multi_children() { ); } -#[tokio::test] -async fn test_aggregation_with_bad_arguments() -> Result<()> { - let mut ctx = ExecutionContext::new(); - register_aggregate_csv(&mut ctx).await?; - let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; - let logical_plan = ctx.create_logical_plan(sql)?; - let physical_plan = ctx.create_physical_plan(&logical_plan).await; - let err = physical_plan.unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: Invalid or wrong number of arguments passed to aggregate: 'COUNT(DISTINCT )'"); - Ok(()) -} +// #[tokio::test] +// async fn test_aggregation_with_bad_arguments() -> Result<()> { +// let mut ctx = ExecutionContext::new(); +// register_aggregate_csv(&mut ctx).await?; +// let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; +// let logical_plan = ctx.create_logical_plan(sql)?; +// let physical_plan = ctx.create_physical_plan(&logical_plan).await; +// let err = physical_plan.unwrap_err(); +// assert_eq!(err.to_string(), DataFusionError::Plan("The function Count expect argument number is 1, but the input argument number is 0".to_string()).to_string()); +// Ok(()) +// } // Normalizes parts of an explain plan that vary from run to run (such as path) fn normalize_for_explain(s: &str) -> String { From 42b219215f1962bbf529c06392f5c7c978f70cc4 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Wed, 1 Dec 2021 15:36:59 +0800 Subject: [PATCH 3/4] fix min/max accept the dict data type --- datafusion/src/physical_plan/aggregates.rs | 6 ++--- .../coercion_rule/aggregate_rule.rs | 23 +++++++++++++++++-- .../src/physical_plan/coercion_rule/mod.rs | 2 +- .../src/physical_plan/expressions/average.rs | 1 - .../src/physical_plan/expressions/sum.rs | 1 - datafusion/src/physical_plan/mod.rs | 2 +- datafusion/tests/sql.rs | 23 +++++++++---------- 7 files changed, 37 insertions(+), 21 deletions(-) diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 44baf1b94c69..9efccc89fe5f 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -40,12 +40,12 @@ use std::{fmt, str::FromStr, sync::Arc}; /// the implementation of an aggregate function pub type AccumulatorFunctionImplementation = -Arc Result> + Send + Sync>; + Arc Result> + Send + Sync>; /// This signature corresponds to which types an aggregator serializes /// its state, given its return datatype. pub type StateTypeFunction = -Arc Result>> + Send + Sync>; + Arc Result>> + Send + Sync>; /// Enum of all built-in aggregate functions #[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] @@ -454,4 +454,4 @@ mod tests { let observed = return_type(&AggregateFunction::Avg, &[DataType::Utf8]); assert!(observed.is_err()); } -} \ No newline at end of file +} diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index 5645379e5dc5..cb4d2669aa21 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -26,6 +26,7 @@ use crate::physical_plan::expressions::{ use crate::physical_plan::functions::{Signature, TypeSignature}; use crate::physical_plan::PhysicalExpr; use arrow::datatypes::DataType; +use std::ops::Deref; use std::sync::Arc; pub fn coerce_types( @@ -52,7 +53,11 @@ pub fn coerce_types( Ok(input_types.to_vec()) } AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), - AggregateFunction::Min | AggregateFunction::Max => Ok(input_types.to_vec()), + AggregateFunction::Min | AggregateFunction::Max => { + // min and max support the dictionary data type + // unpack the dictionary to get the value + get_min_max_result_type(input_types) + } AggregateFunction::Sum => { // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc // smallint, int, bigint, real, double precision, decimal, or interval. @@ -78,6 +83,20 @@ pub fn coerce_types( } } +fn get_min_max_result_type(input_types: &[DataType]) -> Result> { + // min and max support the dictionary data type + // unpack the dictionary to get the value + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + // TODO add checker for datatype which min and max supported + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function + _ => Ok(input_types.to_vec()), + } +} + pub fn coerce_exprs( agg_fun: &AggregateFunction, input_exprs: &[Arc], @@ -172,4 +191,4 @@ mod tests { } } } -} \ No newline at end of file +} diff --git a/datafusion/src/physical_plan/coercion_rule/mod.rs b/datafusion/src/physical_plan/coercion_rule/mod.rs index 78ae20028bd7..8d07b10bfe23 100644 --- a/datafusion/src/physical_plan/coercion_rule/mod.rs +++ b/datafusion/src/physical_plan/coercion_rule/mod.rs @@ -16,4 +16,4 @@ // under the License. //! define the coercion rule for different Expr type -pub(crate) mod aggregate_rule; \ No newline at end of file +pub(crate) mod aggregate_rule; diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 90cc41a45ad1..ec3af1fe46aa 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -77,7 +77,6 @@ pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 - // | DataType::Decimal(_, _) ) } diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index 054fc5df322f..dad9ac955610 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -78,7 +78,6 @@ pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { | DataType::Int64 | DataType::Float32 | DataType::Float64 - // | DataType::Decimal(_, _) ) } diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index 20cd77193981..8c5f662a4ac7 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -608,6 +608,7 @@ pub mod analyze; pub mod array_expressions; pub mod coalesce_batches; pub mod coalesce_partitions; +mod coercion_rule; pub mod common; pub mod cross_join; #[cfg(feature = "crypto_expressions")] @@ -648,4 +649,3 @@ pub mod union; pub mod values; pub mod window_functions; pub mod windows; -mod coercion_rule; diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 03d5cb07b03d..1d8eca43d8a9 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -4400,7 +4400,7 @@ async fn query_group_on_null_multi_col() -> Result<()> { Ok(()) } -// #[tokio::test] +#[tokio::test] async fn query_on_string_dictionary() -> Result<()> { // Test to ensure DataFusion can operate on dictionary types // Use StringDictionary (32 bit indexes = keys) @@ -5607,17 +5607,16 @@ async fn test_physical_plan_display_indent_multi_children() { ); } -// #[tokio::test] -// async fn test_aggregation_with_bad_arguments() -> Result<()> { -// let mut ctx = ExecutionContext::new(); -// register_aggregate_csv(&mut ctx).await?; -// let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; -// let logical_plan = ctx.create_logical_plan(sql)?; -// let physical_plan = ctx.create_physical_plan(&logical_plan).await; -// let err = physical_plan.unwrap_err(); -// assert_eq!(err.to_string(), DataFusionError::Plan("The function Count expect argument number is 1, but the input argument number is 0".to_string()).to_string()); -// Ok(()) -// } +#[tokio::test] +async fn test_aggregation_with_bad_arguments() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; + let logical_plan = ctx.create_logical_plan(sql); + let err = logical_plan.unwrap_err(); + assert_eq!(err.to_string(), DataFusionError::Plan("The function Count expect argument number is 1, but the input argument number is 0".to_string()).to_string()); + Ok(()) +} // Normalizes parts of an explain plan that vary from run to run (such as path) fn normalize_for_explain(s: &str) -> String { From 17064f25fa6dd45147422b3b278e09dc436bd6f3 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 7 Dec 2021 16:12:43 +0800 Subject: [PATCH 4/4] address comments --- datafusion/src/execution/context.rs | 4 +- datafusion/src/physical_plan/aggregates.rs | 72 ++++++++++++++++++- .../coercion_rule/aggregate_rule.rs | 43 +++++++---- .../src/physical_plan/coercion_rule/mod.rs | 5 +- .../src/physical_plan/expressions/average.rs | 2 - .../src/physical_plan/expressions/sum.rs | 2 - datafusion/tests/sql.rs | 8 ++- 7 files changed, 111 insertions(+), 25 deletions(-) diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 9f34b45369c0..59d6f44f59b1 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -2058,7 +2058,7 @@ mod tests { .await .unwrap_err(); - assert_eq!(results.to_string(), "Error during planning: The function Sum do not support the Timestamp(Nanosecond, None)."); + assert_eq!(results.to_string(), "Error during planning: The function Sum does not support inputs of type Timestamp(Nanosecond, None)."); Ok(()) } @@ -2155,7 +2155,7 @@ mod tests { .await .unwrap_err(); - assert_eq!(results.to_string(), "Error during planning: The function Avg do not support the Timestamp(Nanosecond, None)."); + assert_eq!(results.to_string(), "Error during planning: The function Avg does not support inputs of type Timestamp(Nanosecond, None)."); Ok(()) } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 9efccc89fe5f..3f9766fd5680 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -266,7 +266,9 @@ pub fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::error::Result; - use crate::physical_plan::expressions::{ApproxDistinct, ArrayAgg, Count, Max, Min}; + use crate::physical_plan::expressions::{ + ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Sum, + }; #[test] fn test_count_arragg_approx_expr() -> Result<()> { @@ -388,7 +390,73 @@ mod tests { #[test] fn test_sum_avg_expr() -> Result<()> { - // TODO + let funcs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Sum => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + let mut expect_type = data_type.clone(); + if matches!( + data_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + ) { + expect_type = DataType::UInt64; + } else if matches!( + data_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + ) { + expect_type = DataType::Int64; + } else if matches!( + data_type, + DataType::Float32 | DataType::Float64 + ) { + expect_type = data_type.clone(); + } + assert_eq!( + Field::new("c1", expect_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::Avg => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::Float64, true), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } Ok(()) } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index cb4d2669aa21..d7b437528d5c 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -29,7 +29,9 @@ use arrow::datatypes::DataType; use std::ops::Deref; use std::sync::Arc; -pub fn coerce_types( +/// Returns the coerced data type for each `input_types`. +/// Different aggregate function with different input data type will get corresponding coerced data type. +pub(crate) fn coerce_types( agg_fun: &AggregateFunction, input_types: &[DataType], signature: &Signature, @@ -37,13 +39,17 @@ pub fn coerce_types( match signature.type_signature { TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { if input_types.len() != agg_count { - return Err(DataFusionError::Plan(format!("The function {:?} expect argument number is {:?}, but the input argument number is {:?}", - agg_fun, agg_count, input_types.len()))); + return Err(DataFusionError::Plan(format!( + "The function {:?} expects {:?} arguments, but {:?} were provided", + agg_fun, + agg_count, + input_types.len() + ))); } } _ => { - return Err(DataFusionError::Plan(format!( - "The aggregate coercion rule don't support this {:?}", + return Err(DataFusionError::Internal(format!( + "Aggregate functions do not support this {:?}", signature ))); } @@ -63,7 +69,7 @@ pub fn coerce_types( // smallint, int, bigint, real, double precision, decimal, or interval. if !is_sum_support_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( - "The function {:?} do not support the {:?}.", + "The function {:?} does not support inputs of type {:?}.", agg_fun, input_types[0] ))); } @@ -74,7 +80,7 @@ pub fn coerce_types( // smallint, int, bigint, real, double precision, decimal, or interval if !is_avg_support_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( - "The function {:?} do not support the {:?}.", + "The function {:?} does not support inputs of type {:?}.", agg_fun, input_types[0] ))); } @@ -84,6 +90,8 @@ pub fn coerce_types( } fn get_min_max_result_type(input_types: &[DataType]) -> Result> { + // make sure that the input types only has one element. + assert_eq!(input_types.len(), 1); // min and max support the dictionary data type // unpack the dictionary to get the value match &input_types[0] { @@ -97,7 +105,10 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result> { } } -pub fn coerce_exprs( +/// Returns the coerced exprs for each `input_exprs`. +/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the +/// data type of `input_exprs` need to be coerced. +pub(crate) fn coerce_exprs( agg_fun: &AggregateFunction, input_exprs: &[Arc], schema: &Schema, @@ -117,15 +128,15 @@ pub fn coerce_exprs( // try cast if need input_exprs .iter() - .enumerate() - .map(|(i, expr)| try_cast(expr.clone(), schema, coerced_types[i].clone())) + .zip(coerced_types.into_iter()) + .map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type)) .collect::>>() } #[cfg(test)] mod tests { use crate::physical_plan::aggregates; - use crate::physical_plan::aggregates::{signature, AggregateFunction}; + use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::coercion_rule::aggregate_rule::coerce_types; use arrow::datatypes::DataType; @@ -134,9 +145,9 @@ mod tests { // test input args with error number input types let fun = AggregateFunction::Min; let input_types = vec![DataType::Int64, DataType::Int32]; - let signature = signature(&fun); + let signature = aggregates::signature(&fun); let result = coerce_types(&fun, &input_types, &signature); - assert_eq!("Error during planning: The function Min expect argument number is 1, but the input argument number is 2", result.unwrap_err().to_string()); + assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().to_string()); // test input args is invalid data type for sum or avg let fun = AggregateFunction::Sum; @@ -144,14 +155,14 @@ mod tests { let signature = aggregates::signature(&fun); let result = coerce_types(&fun, &input_types, &signature); assert_eq!( - "Error during planning: The function Sum do not support the Utf8.", + "Error during planning: The function Sum does not support inputs of type Utf8.", result.unwrap_err().to_string() ); let fun = AggregateFunction::Avg; let signature = aggregates::signature(&fun); let result = coerce_types(&fun, &input_types, &signature); assert_eq!( - "Error during planning: The function Avg do not support the Utf8.", + "Error during planning: The function Avg does not support inputs of type Utf8.", result.unwrap_err().to_string() ); @@ -166,6 +177,7 @@ mod tests { ]; let input_types = vec![ vec![DataType::Int32], + // support the decimal data type for min/max agg // vec![DataType::Decimal(10, 2)], vec![DataType::Utf8], ]; @@ -181,6 +193,7 @@ mod tests { let input_types = vec![ vec![DataType::Int32], vec![DataType::Float32], + // support the decimal data type // vec![DataType::Decimal(20, 3)], ]; for fun in funs { diff --git a/datafusion/src/physical_plan/coercion_rule/mod.rs b/datafusion/src/physical_plan/coercion_rule/mod.rs index 8d07b10bfe23..1aeabda793b1 100644 --- a/datafusion/src/physical_plan/coercion_rule/mod.rs +++ b/datafusion/src/physical_plan/coercion_rule/mod.rs @@ -15,5 +15,8 @@ // specific language governing permissions and limitations // under the License. -//! define the coercion rule for different Expr type +//! Define coercion rules for different Expr type. +//! +//! Aggregate function rule + pub(crate) mod aggregate_rule; diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index d7d839fc5d57..feb568c8dd72 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -61,8 +61,6 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { } pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { - // TODO support the interval - // TODO: do we need to support the unsigned data type? matches!( arg_type, DataType::UInt8 diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index dad9ac955610..c570aef72b52 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -64,8 +64,6 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { - // TODO support the interval - // TODO: do we need to support the unsigned data type? matches!( arg_type, DataType::UInt8 diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 1d8eca43d8a9..1dbc90da7df2 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -5614,7 +5614,13 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> { let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; let logical_plan = ctx.create_logical_plan(sql); let err = logical_plan.unwrap_err(); - assert_eq!(err.to_string(), DataFusionError::Plan("The function Count expect argument number is 1, but the input argument number is 0".to_string()).to_string()); + assert_eq!( + err.to_string(), + DataFusionError::Plan( + "The function Count expects 1 arguments, but 0 were provided".to_string() + ) + .to_string() + ); Ok(()) }