Skip to content

Commit

Permalink
Make SUM and AVG Aggregate Type Coercion Explicit (#7369)
Browse files Browse the repository at this point in the history
* Make Aggregate Type Coercion Explicit

* Clippy
  • Loading branch information
tustvold authored Aug 22, 2023
1 parent ffccbe6 commit 870857a
Show file tree
Hide file tree
Showing 19 changed files with 385 additions and 416 deletions.
37 changes: 2 additions & 35 deletions datafusion/core/src/physical_plan/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use datafusion_execution::TaskContext;
use datafusion_expr::Accumulator;
use datafusion_physical_expr::{
equivalence::project_equivalence_properties,
expressions::{Avg, CastExpr, Column, Sum},
expressions::Column,
normalize_out_expr_with_columns_map, reverse_order_bys,
utils::{convert_to_expr, get_indices_of_matching_exprs},
AggregateExpr, LexOrdering, LexOrderingReq, OrderingEquivalenceProperties,
Expand Down Expand Up @@ -1010,40 +1010,7 @@ fn aggregate_expressions(
| AggregateMode::SinglePartitioned => Ok(aggr_expr
.iter()
.map(|agg| {
let pre_cast_type = if let Some(Sum {
data_type,
pre_cast_to_sum_type,
..
}) = agg.as_any().downcast_ref::<Sum>()
{
if *pre_cast_to_sum_type {
Some(data_type.clone())
} else {
None
}
} else if let Some(Avg {
sum_data_type,
pre_cast_to_sum_type,
..
}) = agg.as_any().downcast_ref::<Avg>()
{
if *pre_cast_to_sum_type {
Some(sum_data_type.clone())
} else {
None
}
} else {
None
};
let mut result = agg
.expressions()
.into_iter()
.map(|expr| {
pre_cast_type.clone().map_or(expr.clone(), |cast_type| {
Arc::new(CastExpr::new(expr, cast_type, None))
})
})
.collect::<Vec<_>>();
let mut result = agg.expressions().clone();
// In partial mode, append ordering requirements to expressions' results.
// Ordering requirements are used by subsequent executors to satisfy the required
// ordering for `AggregateMode::FinalPartitioned`/`AggregateMode::Final` modes.
Expand Down
11 changes: 10 additions & 1 deletion datafusion/core/tests/fuzz_cases/window_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ use datafusion_expr::{

use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::expressions::{col, lit};
use datafusion_expr::type_coercion::aggregates::coerce_types;
use datafusion_physical_expr::expressions::{cast, col, lit};
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;

Expand Down Expand Up @@ -261,6 +262,14 @@ fn get_random_function(
let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];
let (window_fn, new_args) = window_fn_map.values().collect::<Vec<_>>()[rand_fn_idx];
if let WindowFunction::AggregateFunction(f) = window_fn {
let a = args[0].clone();
let dt = a.data_type(schema.as_ref()).unwrap();
let sig = f.signature();
let coerced = coerce_types(f, &[dt], &sig).unwrap();
args[0] = cast(a, schema, coerced[0].clone()).unwrap();
}

for new_arg in new_args {
args.push(new_arg.clone());
}
Expand Down
24 changes: 10 additions & 14 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,20 +228,16 @@ impl AggregateFunction {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.

let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
self,
input_expr_types,
&self.signature(),
)
// original errors are all related to wrong function signature
// aggregate them for better error message
.map_err(|_| {
DataFusionError::Plan(utils::generate_signature_error_msg(
&format!("{self}"),
self.signature(),
input_expr_types,
))
})?;
let coerced_data_types = coerce_types(self, input_expr_types, &self.signature())
// original errors are all related to wrong function signature
// aggregate them for better error message
.map_err(|_| {
DataFusionError::Plan(utils::generate_signature_error_msg(
&format!("{self}"),
self.signature(),
input_expr_types,
))
})?;

match self {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Expand Down
111 changes: 67 additions & 44 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use arrow::datatypes::{
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};

use datafusion_common::{internal_err, plan_err, DataFusionError, Result};
use std::ops::Deref;

Expand Down Expand Up @@ -89,6 +90,7 @@ pub fn coerce_types(
input_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
use DataType::*;
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun, input_types, &signature.type_signature)?;

Expand All @@ -105,26 +107,44 @@ pub fn coerce_types(
AggregateFunction::Sum => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
if !is_sum_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(input_types.to_vec())
let v = match &input_types[0] {
Decimal128(p, s) => Decimal128(*p, *s),
Decimal256(p, s) => Decimal256(*p, *s),
d if d.is_signed_integer() => Int64,
d if d.is_unsigned_integer() => UInt64,
d if d.is_floating() => Float64,
Dictionary(_, v) => {
return coerce_types(agg_fun, &[v.as_ref().clone()], signature)
}
_ => {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
)
}
};
Ok(vec![v])
}
AggregateFunction::Avg => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval
if !is_avg_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(input_types.to_vec())
let v = match &input_types[0] {
Decimal128(p, s) => Decimal128(*p, *s),
Decimal256(p, s) => Decimal256(*p, *s),
d if d.is_numeric() => Float64,
Dictionary(_, v) => {
return coerce_types(agg_fun, &[v.as_ref().clone()], signature)
}
_ => {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
)
}
};
Ok(vec![v])
}
AggregateFunction::BitAnd
| AggregateFunction::BitOr
Expand Down Expand Up @@ -160,7 +180,7 @@ pub fn coerce_types(
input_types[0]
);
}
Ok(input_types.to_vec())
Ok(vec![Float64, Float64])
}
AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
Expand All @@ -170,7 +190,7 @@ pub fn coerce_types(
input_types[0]
);
}
Ok(input_types.to_vec())
Ok(vec![Float64, Float64])
}
AggregateFunction::Stddev | AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
Expand All @@ -180,7 +200,7 @@ pub fn coerce_types(
input_types[0]
);
}
Ok(input_types.to_vec())
Ok(vec![Float64])
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
Expand All @@ -190,7 +210,7 @@ pub fn coerce_types(
input_types[0]
);
}
Ok(input_types.to_vec())
Ok(vec![Float64, Float64])
}
AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
Expand All @@ -211,7 +231,7 @@ pub fn coerce_types(
input_types[0]
);
}
Ok(input_types.to_vec())
Ok(vec![Float64, Float64])
}
AggregateFunction::ApproxPercentileCont => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
Expand Down Expand Up @@ -357,11 +377,9 @@ fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
/// function return type of a sum
pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
arg_type if SIGNED_INTEGERS.contains(arg_type) => Ok(DataType::Int64),
arg_type if UNSIGNED_INTEGERS.contains(arg_type) => Ok(DataType::UInt64),
// In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
// the result type of floating-point is FLOAT64 with the double precision.
DataType::Float64 | DataType::Float32 => Ok(DataType::Float64),
DataType::Int64 => Ok(DataType::Int64),
DataType::UInt64 => Ok(DataType::UInt64),
DataType::Float64 => Ok(DataType::Float64),
DataType::Decimal128(precision, scale) => {
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
Expand All @@ -374,9 +392,6 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal256(new_precision, *scale))
}
DataType::Dictionary(_, dict_value_type) => {
sum_return_type(dict_value_type.as_ref())
}
other => plan_err!("SUM does not support type \"{other:?}\""),
}
}
Expand Down Expand Up @@ -601,21 +616,29 @@ mod tests {
assert_eq!(*input_type, result.unwrap());
}
}
// test sum, avg
let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
let input_types = vec![
vec![DataType::Int32],
vec![DataType::Float32],
vec![DataType::Decimal128(20, 3)],
vec![DataType::Decimal256(20, 3)],
];
for fun in funs {
for input_type in &input_types {
let signature = fun.signature();
let result = coerce_types(&fun, input_type, &signature);
assert_eq!(*input_type, result.unwrap());
}
}
// test sum
let fun = AggregateFunction::Sum;
let signature = fun.signature();
let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap();
assert_eq!(r[0], DataType::Int64);
let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap();
assert_eq!(r[0], DataType::Float64);
let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap();
assert_eq!(r[0], DataType::Decimal128(20, 3));
let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap();
assert_eq!(r[0], DataType::Decimal256(20, 3));

// test avg
let fun = AggregateFunction::Avg;
let signature = fun.signature();
let r = coerce_types(&fun, &[DataType::Int32], &signature).unwrap();
assert_eq!(r[0], DataType::Float64);
let r = coerce_types(&fun, &[DataType::Float32], &signature).unwrap();
assert_eq!(r[0], DataType::Float64);
let r = coerce_types(&fun, &[DataType::Decimal128(20, 3)], &signature).unwrap();
assert_eq!(r[0], DataType::Decimal128(20, 3));
let r = coerce_types(&fun, &[DataType::Decimal256(20, 3)], &signature).unwrap();
assert_eq!(r[0], DataType::Decimal256(20, 3));

// ApproxPercentileCont input types
let input_types = vec![
Expand Down
21 changes: 17 additions & 4 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_numeric, is_utf8_or_large_u
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, LogicalPlan, Operator,
Projection, WindowFrame, WindowFrameBound, WindowFrameUnits,
type_coercion, window_function, AggregateFunction, BuiltinScalarFunction, Expr,
LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits,
};
use datafusion_expr::{ExprSchemable, Signature};

Expand Down Expand Up @@ -381,6 +381,19 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}) => {
let window_frame =
coerce_window_frame(window_frame, &self.schema, &order_by)?;

let args = match &fun {
window_function::WindowFunction::AggregateFunction(fun) => {
coerce_agg_exprs_for_signature(
fun,
&args,
&self.schema,
&fun.signature(),
)?
}
_ => args,
};

let expr = Expr::WindowFunction(WindowFunction::new(
fun,
args,
Expand Down Expand Up @@ -961,7 +974,7 @@ mod test {
None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
let expected = "Projection: AVG(Int64(12))\n EmptyRelation";
let expected = "Projection: AVG(CAST(Int64(12) AS Float64))\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;

let empty = empty_with_type(DataType::Int32);
Expand All @@ -974,7 +987,7 @@ mod test {
None,
));
let plan = LogicalPlan::Projection(Projection::try_new(vec![agg_expr], empty)?);
let expected = "Projection: AVG(a)\n EmptyRelation";
let expected = "Projection: AVG(CAST(a AS Float64))\n EmptyRelation";
assert_analyzed_plan_eq(Arc::new(TypeCoercion::new()), &plan, expected)?;
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ fn subquery_filter_with_cast() -> Result<()> {
\n Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.AVG(test.col_int32)\
\n TableScan: test projection=[col_int32]\
\n SubqueryAlias: __scalar_sq_1\
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\
\n Aggregate: groupBy=[[]], aggr=[[AVG(CAST(test.col_int32 AS Float64))]]\
\n Projection: test.col_int32\
\n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\
\n TableScan: test projection=[col_int32, col_utf8]";
Expand Down
Loading

0 comments on commit 870857a

Please sign in to comment.