diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 14d9197217f06..f81efe5c35e3e 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -419,6 +419,7 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { let new_precision = DECIMAL_MAX_PRECISION.min(*precision + 10); Ok(DataType::Decimal(new_precision, *scale)) } + DataType::Null => Ok(DataType::Null), other => Err(DataFusionError::Plan(format!( "SUM does not support type \"{:?}\"", other @@ -526,6 +527,7 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { | DataType::UInt64 | DataType::Float32 | DataType::Float64 => Ok(DataType::Float64), + DataType::Null => Ok(DataType::Null), other => Err(DataFusionError::Plan(format!( "AVG does not support {:?}", other @@ -616,6 +618,7 @@ pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { | DataType::Float32 | DataType::Float64 | DataType::Decimal(_, _) + | DataType::Null ) } @@ -633,6 +636,7 @@ pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { | DataType::Float32 | DataType::Float64 | DataType::Decimal(_, _) + | DataType::Null ) } @@ -778,6 +782,7 @@ mod tests { vec![DataType::Int32], vec![DataType::Float32], vec![DataType::Decimal(20, 3)], + vec![DataType::Null], ]; for fun in funs { for input_type in &input_types { diff --git a/datafusion/physical-expr/src/expressions/average.rs b/datafusion/physical-expr/src/expressions/average.rs index 5c26e6c9d97cb..4844ed3784d5b 100644 --- a/datafusion/physical-expr/src/expressions/average.rs +++ b/datafusion/physical-expr/src/expressions/average.rs @@ -52,7 +52,7 @@ impl Avg { // the result of avg just support FLOAT64 and Decimal data type. assert!(matches!( data_type, - DataType::Float64 | DataType::Decimal(_, _) + DataType::Float64 | DataType::Decimal(_, _) | DataType::Null )); Self { name: name.into(), @@ -160,6 +160,7 @@ impl Accumulator for AvgAccumulator { ), }) } + ScalarValue::Null => Ok(ScalarValue::Null), _ => Err(DataFusionError::Internal( "Sum should be f64 on average".to_string(), )), diff --git a/datafusion/physical-expr/src/expressions/sum.rs b/datafusion/physical-expr/src/expressions/sum.rs index 84594fe7a530a..fe4fb2dc11c8c 100644 --- a/datafusion/physical-expr/src/expressions/sum.rs +++ b/datafusion/physical-expr/src/expressions/sum.rs @@ -159,6 +159,7 @@ pub(super) fn sum_batch(values: &ArrayRef) -> Result { DataType::UInt32 => typed_sum_delta_batch!(values, UInt32Array, UInt32), DataType::UInt16 => typed_sum_delta_batch!(values, UInt16Array, UInt16), DataType::UInt8 => typed_sum_delta_batch!(values, UInt8Array, UInt8), + DataType::Null => ScalarValue::Null, e => { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive the type {:?}", @@ -297,6 +298,7 @@ pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { (ScalarValue::Int64(lhs), ScalarValue::Int8(rhs)) => { typed_sum!(lhs, rhs, Int64, i64) } + (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, e => { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive a scalar {:?}",