From 8dcef91806443f9b9b512bf6d819dc20961b29c8 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Thu, 6 Oct 2022 14:27:57 -0500 Subject: [PATCH] Remove type coercions from ScalarValue and aggregation function code (#3705) * Sanitize ScalarValue and aggregation code from type coercions * Remove forced type cast from sum_row! macro used in SumRowAccumulator --- datafusion/common/src/scalar.rs | 470 +++++++----------- .../physical-expr/src/aggregate/average.rs | 66 +-- .../src/aggregate/correlation.rs | 24 +- .../physical-expr/src/aggregate/count.rs | 48 +- .../physical-expr/src/aggregate/covariance.rs | 27 +- .../physical-expr/src/aggregate/min_max.rs | 214 ++------ .../physical-expr/src/aggregate/stddev.rs | 29 +- datafusion/physical-expr/src/aggregate/sum.rs | 160 ++---- .../src/aggregate/sum_distinct.rs | 28 +- .../physical-expr/src/aggregate/variance.rs | 49 +- .../physical-expr/src/expressions/mod.rs | 14 + 11 files changed, 334 insertions(+), 795 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 42f0a7d16fcc..c3f91dd9b1d1 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -312,155 +312,186 @@ impl Eq for ScalarValue {} // TODO implement this in arrow-rs with simd // https://github.com/apache/arrow-rs/issues/1010 macro_rules! decimal_op { - ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr, $OPERATION:tt ) => {{ - let (difference, side) = if $LHS_SCALE > $RHS_SCALE { - ($LHS_SCALE - $RHS_SCALE, true) - } else { - ($RHS_SCALE - $LHS_SCALE, false) - }; - let scale = max($LHS_SCALE, $RHS_SCALE); - match ($LHS, $RHS, difference) { - (None, None, _) => ScalarValue::Decimal128(None, $PRECISION, scale), - (None, Some(rhs_value), 0) => ScalarValue::Decimal128(Some((0 as i128) $OPERATION rhs_value), $PRECISION, scale), - (None, Some(rhs_value), _) => { - let mut new_value = ((0 as i128) $OPERATION rhs_value); - if side { - new_value *= 10_i128.pow((difference) as u32) - }; - ScalarValue::Decimal128(Some(new_value), $PRECISION, scale) - } - (Some(lhs_value), None, 0) => ScalarValue::Decimal128(Some(lhs_value $OPERATION (0 as i128)), $PRECISION, scale), - (Some(lhs_value), None, _) => { - let mut new_value = (lhs_value $OPERATION (0 as i128)); - if !!!side { - new_value *= 10_i128.pow((difference) as u32) + ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr, $OPERATION:tt) => {{ + let (difference, side) = if $LHS_SCALE > $RHS_SCALE { + ($LHS_SCALE - $RHS_SCALE, true) + } else { + ($RHS_SCALE - $LHS_SCALE, false) + }; + let scale = max($LHS_SCALE, $RHS_SCALE); + Ok(match ($LHS, $RHS, difference) { + (None, None, _) => ScalarValue::Decimal128(None, $PRECISION, scale), + (lhs, None, 0) => ScalarValue::Decimal128(*lhs, $PRECISION, scale), + (Some(lhs_value), None, _) => { + let mut new_value = *lhs_value; + if !side { + new_value *= 10_i128.pow(difference as u32) + } + ScalarValue::Decimal128(Some(new_value), $PRECISION, scale) } - ScalarValue::Decimal128(Some(new_value), $PRECISION, scale) - } - (Some(lhs_value), Some(rhs_value), 0) => { - ScalarValue::Decimal128(Some(lhs_value $OPERATION rhs_value), $PRECISION, scale) - } - (Some(lhs_value), Some(rhs_value), _) => { - let new_value = if side { - rhs_value * 10_i128.pow((difference) as u32) $OPERATION lhs_value - } else { - lhs_value * 10_i128.pow((difference) as u32) $OPERATION rhs_value - }; - ScalarValue::Decimal128(Some(new_value), $PRECISION, scale) - } - }} + (None, Some(rhs_value), 0) => { + let value = decimal_right!(*rhs_value, $OPERATION); + ScalarValue::Decimal128(Some(value), $PRECISION, scale) + } + (None, Some(rhs_value), _) => { + let mut new_value = decimal_right!(*rhs_value, $OPERATION); + if side { + new_value *= 10_i128.pow(difference as u32) + }; + ScalarValue::Decimal128(Some(new_value), $PRECISION, scale) + } + (Some(lhs_value), Some(rhs_value), 0) => { + decimal_binary_op!(lhs_value, rhs_value, $OPERATION, $PRECISION, scale) + } + (Some(lhs_value), Some(rhs_value), _) => { + let (left_arg, right_arg) = if side { + (*lhs_value, rhs_value * 10_i128.pow(difference as u32)) + } else { + (lhs_value * 10_i128.pow(difference as u32), *rhs_value) + }; + decimal_binary_op!(left_arg, right_arg, $OPERATION, $PRECISION, scale) + } + }) + }}; +} - } +macro_rules! decimal_binary_op { + ($LHS:expr, $RHS:expr, $OPERATION:tt, $PRECISION:expr, $SCALE:expr) => { + // TODO: This simple implementation loses precision for calculations like + // multiplication and division. Improve this implementation for such + // operations. + ScalarValue::Decimal128(Some($LHS $OPERATION $RHS), $PRECISION, $SCALE) + }; } -// Returns the result of applying operation to two scalar values, including coercion into $TYPE. -macro_rules! typed_op { - ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $TYPE:ident, $OPERATION:tt) => { - Some(ScalarValue::$SCALAR(match ($LEFT, $RIGHT) { - (None, None) => None, - (Some(a), None) => Some((*a as $TYPE) $OPERATION (0 as $TYPE)), - (None, Some(b)) => Some((0 as $TYPE) $OPERATION (*b as $TYPE)), - (Some(a), Some(b)) => Some((*a as $TYPE) $OPERATION (*b as $TYPE)), - })) +macro_rules! decimal_right { + ($TERM:expr, +) => { + $TERM + }; + ($TERM:expr, *) => { + $TERM + }; + ($TERM:expr, -) => { + -$TERM + }; + ($TERM:expr, /) => { + Err(DataFusionError::NotImplemented(format!( + "Decimal reciprocation not yet supported", + ))) }; } -macro_rules! impl_common_symmetric_cases_op { - ($LHS:expr, $RHS:expr, $OPERATION:tt, [$([$L_TYPE:ident, $R_TYPE:ident, $O_TYPE:ident, $O_PRIM:ident]),+]) => { - match ($LHS, $RHS) { - $( - (ScalarValue::$L_TYPE(lhs), ScalarValue::$R_TYPE(rhs)) => { - typed_op!(lhs, rhs, $O_TYPE, $O_PRIM, $OPERATION) - } - (ScalarValue::$R_TYPE(lhs), ScalarValue::$L_TYPE(rhs)) => { - typed_op!(lhs, rhs, $O_TYPE, $O_PRIM, $OPERATION) - } - )+ - _ => None +// Returns the result of applying operation to two scalar values. +macro_rules! primitive_op { + ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $OPERATION:tt) => { + match ($LEFT, $RIGHT) { + (lhs, None) => Ok(ScalarValue::$SCALAR(*lhs)), + #[allow(unused_variables)] + (None, Some(b)) => { primitive_right!(*b, $OPERATION, $SCALAR) }, + (Some(a), Some(b)) => Ok(ScalarValue::$SCALAR(Some(*a $OPERATION *b))), } - } + }; } -macro_rules! impl_common_cases_op { +macro_rules! primitive_right { + ($TERM:expr, +, $SCALAR:ident) => { + Ok(ScalarValue::$SCALAR(Some($TERM))) + }; + ($TERM:expr, *, $SCALAR:ident) => { + Ok(ScalarValue::$SCALAR(Some($TERM))) + }; + ($TERM:expr, -, UInt64) => { + unsigned_subtraction_error!("UInt64") + }; + ($TERM:expr, -, UInt32) => { + unsigned_subtraction_error!("UInt32") + }; + ($TERM:expr, -, UInt16) => { + unsigned_subtraction_error!("UInt16") + }; + ($TERM:expr, -, UInt8) => { + unsigned_subtraction_error!("UInt8") + }; + ($TERM:expr, -, $SCALAR:ident) => { + Ok(ScalarValue::$SCALAR(Some(-$TERM))) + }; + ($TERM:expr, /, Float64) => { + Ok(ScalarValue::$SCALAR(Some($TERM.recip()))) + }; + ($TERM:expr, /, Float32) => { + Ok(ScalarValue::$SCALAR(Some($TERM.recip()))) + }; + ($TERM:expr, /, $SCALAR:ident) => { + Err(DataFusionError::Internal(format!( + "Can not divide an uninitialized value to a non-floating point value", + ))) + }; +} + +macro_rules! unsigned_subtraction_error { + ($SCALAR:expr) => {{ + let msg = format!( + "Can not subtract a {} value from an uninitialized value", + $SCALAR + ); + Err(DataFusionError::Internal(msg)) + }}; +} + +macro_rules! impl_op { ($LHS:expr, $RHS:expr, $OPERATION:tt) => { match ($LHS, $RHS) { ( ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2), ) => { - let max_precision = *p1.max(p2); - Some(decimal_op!(v1, v2, max_precision, *s1, *s2, $OPERATION)) + decimal_op!(v1, v2, *p1.max(p2), *s1, *s2, $OPERATION) } (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - typed_op!(lhs, rhs, Float64, f64, $OPERATION) + primitive_op!(lhs, rhs, Float64, $OPERATION) } (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - typed_op!(lhs, rhs, Float32, f32, $OPERATION) + primitive_op!(lhs, rhs, Float32, $OPERATION) } (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - typed_op!(lhs, rhs, UInt64, u64, $OPERATION) + primitive_op!(lhs, rhs, UInt64, $OPERATION) } (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - typed_op!(lhs, rhs, Int64, i64, $OPERATION) + primitive_op!(lhs, rhs, Int64, $OPERATION) } (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { - typed_op!(lhs, rhs, UInt32, u32, $OPERATION) + primitive_op!(lhs, rhs, UInt32, $OPERATION) } (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { - typed_op!(lhs, rhs, Int32, i32, $OPERATION) + primitive_op!(lhs, rhs, Int32, $OPERATION) } (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { - typed_op!(lhs, rhs, UInt16, u16, $OPERATION) + primitive_op!(lhs, rhs, UInt16, $OPERATION) } (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { - typed_op!(lhs, rhs, Int16, i16, $OPERATION) + primitive_op!(lhs, rhs, Int16, $OPERATION) } (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { - typed_op!(lhs, rhs, UInt8, u8, $OPERATION) + primitive_op!(lhs, rhs, UInt8, $OPERATION) } (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { - typed_op!(lhs, rhs, Int8, i8, $OPERATION) + primitive_op!(lhs, rhs, Int8, $OPERATION) + } + _ => { + impl_distinct_cases_op!($LHS, $RHS, $OPERATION) } - _ => impl_common_symmetric_cases_op!( - $LHS, - $RHS, - $OPERATION, - [ - // Float64 coerces everything to f64: - [Float64, Float32, Float64, f64], - [Float64, Int64, Float64, f64], - [Float64, Int32, Float64, f64], - [Float64, Int16, Float64, f64], - [Float64, Int8, Float64, f64], - [Float64, UInt64, Float64, f64], - [Float64, UInt32, Float64, f64], - [Float64, UInt16, Float64, f64], - [Float64, UInt8, Float64, f64], - // UInt64 coerces all smaller unsigned types to u64: - [UInt64, UInt32, UInt64, u64], - [UInt64, UInt16, UInt64, u64], - [UInt64, UInt8, UInt64, u64], - // Int64 coerces all smaller integral types to i64: - [Int64, Int32, Int64, i64], - [Int64, Int16, Int64, i64], - [Int64, Int8, Int64, i64], - [Int64, UInt32, Int64, i64], - [Int64, UInt16, Int64, i64], - [Int64, UInt8, Int64, i64] - ] - ), } }; } -/// If we want a special implementation for an ooperation this is the place to implement it -/// For instance, in the future we may want to implement subtraction for dates but not summation -/// so we can implement special case in the corresponding place +// If we want a special implementation for an operation this is the place to implement it. +// For instance, in the future we may want to implement subtraction for dates but not addition. +// We can implement such special cases here. macro_rules! impl_distinct_cases_op { ($LHS:expr, $RHS:expr, +) => { match ($LHS, $RHS) { e => Err(DataFusionError::Internal(format!( - "Summation is not implemented for {:?}", + "Addition is not implemented for {:?}", e ))), } @@ -475,15 +506,6 @@ macro_rules! impl_distinct_cases_op { }; } -macro_rules! impl_op { - ($LHS:expr, $RHS:expr, $OPERATION:tt) => { - match impl_common_cases_op!($LHS, $RHS, $OPERATION) { - Some(elem) => Ok(elem), - None => impl_distinct_cases_op!($LHS, $RHS, $OPERATION), - } - }; -} - // manual implementation of `Hash` that uses OrderedFloat to // get defined behavior for floating point impl std::hash::Hash for ScalarValue { @@ -938,11 +960,13 @@ impl ScalarValue { } pub fn is_unsigned(&self) -> bool { - let value_type = self.get_datatype(); - value_type == DataType::UInt64 - || value_type == DataType::UInt32 - || value_type == DataType::UInt16 - || value_type == DataType::UInt8 + matches!( + self, + ScalarValue::UInt8(_) + | ScalarValue::UInt16(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + ) } /// whether this value is null or not. @@ -2180,35 +2204,43 @@ impl TryFrom<&DataType> for ScalarValue { } } +// TODO: Remove these coercions once the hardcoded "u64" offset is changed to a +// ScalarValue in WindowFrameBound. pub trait TryFromValue { fn try_from_value(datatype: &DataType, value: T) -> Result; } macro_rules! impl_try_from_value { - ($NATIVE:ty, [$([$SCALAR:ident, $PRIMITIVE:tt]),+]) => { + ($NATIVE:ty, [$([$SCALAR:ident, $PRIMITIVE:ty]),+]) => { impl TryFromValue<$NATIVE> for ScalarValue { fn try_from_value(datatype: &DataType, value: $NATIVE) -> Result { - Ok(match datatype { - $(DataType::$SCALAR => ScalarValue::$SCALAR(Some(value as $PRIMITIVE)),)+ + match datatype { + $(DataType::$SCALAR => Ok(ScalarValue::$SCALAR(Some(value as $PRIMITIVE))),)+ _ => { - return Err(DataFusionError::NotImplemented(format!( - "Can't create a scalar from data_type \"{:?}\"", - datatype - ))); + let msg = format!("Can't create a scalar from data_type \"{:?}\"", datatype); + Err(DataFusionError::NotImplemented(msg)) } - }) + } } } }; } -macro_rules! impl_try_from_value_all { - ([$($NATIVE:ty),+]) => { - $(impl_try_from_value!($NATIVE, [[Float64, f64], [Float32, f32], [UInt64, u64], [UInt32, u32], [UInt16, u16], [UInt8, u8], [Int64, i64], [Int32, i32], [Int16, i16], [Int8, i8]]);)+ - } -} - -impl_try_from_value_all!([f64, f32, u64, u32, u16, u8, i64, i32, i16, i8]); +impl_try_from_value!( + u64, + [ + [Float64, f64], + [Float32, f32], + [UInt64, u64], + [UInt32, u32], + [UInt16, u16], + [UInt8, u8], + [Int64, i64], + [Int32, i32], + [Int16, i16], + [Int8, i8] + ] +); macro_rules! format_option { ($F:expr, $EXPR:expr) => {{ @@ -2440,18 +2472,6 @@ mod tests { float_value.sub(&float_value_2)?, ScalarValue::Float64(Some(0.)) ); - assert_eq!( - float_value.sub(&float_value_2)?, - ScalarValue::Float64(Some(0.)) - ); - assert_eq!( - float_value.sub(&float_value_2)?, - ScalarValue::Float64(Some(0.)) - ); - assert_eq!( - float_value.sub(&float_value_2)?, - ScalarValue::Float64(Some(0.)) - ); assert_eq!( float_value.sub(float_value_2)?, ScalarValue::Float64(Some(0.)) @@ -3693,37 +3713,36 @@ mod tests { Ok(()) } - #[test] - fn test_subtraction() { - let lhs = ScalarValue::Float64(Some(11.0)); - let rhs = ScalarValue::Float64(Some(12.0)); - assert_eq!(lhs.sub(rhs).unwrap(), ScalarValue::Float64(Some(-1.0))); - } - - #[test] - fn expect_subtraction_error() { - let lhs = ScalarValue::UInt64(Some(12)); - let rhs = ScalarValue::Int32(Some(-3)); - let expected_error = "Subtraction is not implemented"; - match lhs.sub(&rhs) { - Ok(_result) => { - panic!( - "Expected summation error between lhs: '{:?}', rhs: {:?}", - lhs, rhs - ); - } - Err(e) => { - let error_message = e.to_string(); - assert!( - error_message.contains(expected_error), - "Expected error '{}' not found in actual error '{}'", - expected_error, - error_message - ); + macro_rules! expect_operation_error { + ($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => { + #[test] + fn $TEST_NAME() { + let lhs = ScalarValue::UInt64(Some(12)); + let rhs = ScalarValue::Int32(Some(-3)); + match lhs.$FUNCTION(&rhs) { + Ok(_result) => { + panic!( + "Expected summation error between lhs: '{:?}', rhs: {:?}", + lhs, rhs + ); + } + Err(e) => { + let error_message = e.to_string(); + assert!( + error_message.contains($EXPECTED_ERROR), + "Expected error '{}' not found in actual error '{}'", + $EXPECTED_ERROR, + error_message + ); + } + } } - } + }; } + expect_operation_error!(expect_add_error, add, "Addition is not implemented"); + expect_operation_error!(expect_sub_error, sub, "Subtraction is not implemented"); + macro_rules! decimal_op_test_cases { ($OPERATION:ident, [$([$L_VALUE:expr, $L_PRECISION:expr, $L_SCALE:expr, $R_VALUE:expr, $R_PRECISION:expr, $R_SCALE:expr, $O_VALUE:expr, $O_PRECISION:expr, $O_SCALE:expr]),+]) => { $( @@ -3791,109 +3810,4 @@ mod tests { ] ); } - - macro_rules! op_test_cases { - ($LHS:expr, $RHS:expr, $OUT:expr, $OPERATION:ident, [$([$L_TYPE:ident, $L_PRIM:ident, $R_TYPE:ident, $R_PRIM:ident, $O_TYPE:ident, $O_PRIM:ident]),+]) => { - $( - // From left - let lhs = ScalarValue::$L_TYPE(Some($LHS as $L_PRIM)); - let rhs = ScalarValue::$R_TYPE(Some($RHS as $R_PRIM)); - assert_eq!(lhs.$OPERATION(rhs).unwrap(), ScalarValue::$O_TYPE(Some($OUT as $O_PRIM))); - // From right. The values ($RHS and $LHS) also crossed to produce same output for subtraction. - let lhs = ScalarValue::$L_TYPE(Some($RHS as $L_PRIM)); - let rhs = ScalarValue::$R_TYPE(Some($LHS as $R_PRIM)); - assert_eq!(rhs.$OPERATION(lhs).unwrap(), ScalarValue::$O_TYPE(Some($OUT as $O_PRIM))); - )+ - }; - } - - #[test] - fn test_sum_operation_different_types() { - op_test_cases!( - 11, - 12, - 23, - add, - [ - // FloatXY coerces everything to fXY: - [Float64, f64, Float32, f32, Float64, f64], - [Float64, f64, Int64, i64, Float64, f64], - [Float64, f64, Int32, i32, Float64, f64], - [Float64, f64, Int16, i16, Float64, f64], - [Float64, f64, Int8, i8, Float64, f64], - [Float64, f64, UInt64, u64, Float64, f64], - [Float64, f64, UInt32, u32, Float64, f64], - [Float64, f64, UInt16, u16, Float64, f64], - [Float64, f64, UInt8, u8, Float64, f64], - // UIntXY coerces all smaller unsigned types to uXY: - [UInt64, u64, UInt32, u32, UInt64, u64], - [UInt64, u64, UInt16, u16, UInt64, u64], - [UInt64, u64, UInt8, u8, UInt64, u64], - // IntXY types coerce smaller integral types to iXY: - [Int64, i64, Int32, i32, Int64, i64], - [Int64, i64, Int16, i16, Int64, i64], - [Int64, i64, Int8, i8, Int64, i64], - [Int64, i64, UInt32, u32, Int64, i64], - [Int64, i64, UInt16, u16, Int64, i64], - [Int64, i64, UInt8, u8, Int64, i64] - ] - ); - } - - #[test] - fn test_sub_operation_different_types() { - op_test_cases!( - 20, - 8, - 12, - sub, - [ - // FloatXY coerces everything to fXY: - [Float64, f64, Float32, f32, Float64, f64], - [Float64, f64, Int64, i64, Float64, f64], - [Float64, f64, Int32, i32, Float64, f64], - [Float64, f64, Int16, i16, Float64, f64], - [Float64, f64, Int8, i8, Float64, f64], - [Float64, f64, UInt64, u64, Float64, f64], - [Float64, f64, UInt32, u32, Float64, f64], - [Float64, f64, UInt16, u16, Float64, f64], - [Float64, f64, UInt8, u8, Float64, f64], - // UIntXY coerces all smaller unsigned types to uXY: - [UInt64, u64, UInt32, u32, UInt64, u64], - [UInt64, u64, UInt16, u16, UInt64, u64], - [UInt64, u64, UInt8, u8, UInt64, u64], - // IntXY types coerce smaller integral types to iXY: - [Int64, i64, Int32, i32, Int64, i64], - [Int64, i64, Int16, i16, Int64, i64], - [Int64, i64, Int8, i8, Int64, i64], - [Int64, i64, UInt32, u32, Int64, i64], - [Int64, i64, UInt16, u16, Int64, i64], - [Int64, i64, UInt8, u8, Int64, i64] - ] - ); - } - - #[test] - fn expect_summation_error() { - let lhs = ScalarValue::UInt64(Some(12)); - let rhs = ScalarValue::Int32(Some(-3)); - let expected_error = "Summation is not implemented"; - match lhs.add(&rhs) { - Ok(_result) => { - panic!( - "Expected summation error between lhs: '{:?}', rhs: {:?}", - lhs, rhs - ); - } - Err(e) => { - let error_message = e.to_string(); - assert!( - error_message.contains(expected_error), - "Expected error '{}' not found in actual error '{}'", - expected_error, - error_message - ); - } - } - } } diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 723ae7e9abb3..f034e3d56897 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -230,7 +230,6 @@ impl RowAccumulator for AvgRowAccumulator { // sum sum::add_to_row( - &self.sum_datatype, self.state_index() + 1, accessor, &sum::sum_batch(values, &self.sum_datatype)?, @@ -249,12 +248,8 @@ impl RowAccumulator for AvgRowAccumulator { accessor.add_u64(self.state_index(), delta); // sum - sum::add_to_row( - &self.sum_datatype, - self.state_index() + 1, - accessor, - &sum::sum_batch(&states[1], &self.sum_datatype)?, - )?; + let difference = sum::sum_batch(&states[1], &self.sum_datatype)?; + sum::add_to_row(self.state_index() + 1, accessor, &difference)?; Ok(()) } @@ -301,8 +296,7 @@ mod tests { array, DataType::Decimal128(10, 0), Avg, - ScalarValue::Decimal128(Some(35000), 14, 4), - DataType::Decimal128(14, 4) + ScalarValue::Decimal128(Some(35000), 14, 4) ) } @@ -318,8 +312,7 @@ mod tests { array, DataType::Decimal128(10, 0), Avg, - ScalarValue::Decimal128(Some(32500), 14, 4), - DataType::Decimal128(14, 4) + ScalarValue::Decimal128(Some(32500), 14, 4) ) } @@ -337,21 +330,14 @@ mod tests { array, DataType::Decimal128(10, 0), Avg, - ScalarValue::Decimal128(None, 14, 4), - DataType::Decimal128(14, 4) + ScalarValue::Decimal128(None, 14, 4) ) } #[test] fn avg_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - Avg, - ScalarValue::from(3_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3_f64)) } #[test] @@ -363,63 +349,33 @@ mod tests { Some(4), Some(5), ])); - generic_test_op!( - a, - DataType::Int32, - Avg, - ScalarValue::from(3.25f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3.25f64)) } #[test] fn avg_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!( - a, - DataType::Int32, - Avg, - ScalarValue::Float64(None), - DataType::Float64 - ) + generic_test_op!(a, DataType::Int32, Avg, ScalarValue::Float64(None)) } #[test] fn avg_u32() -> Result<()> { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - Avg, - ScalarValue::from(3.0f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::UInt32, Avg, ScalarValue::from(3.0f64)) } #[test] fn avg_f32() -> Result<()> { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!( - a, - DataType::Float32, - Avg, - ScalarValue::from(3_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float32, Avg, ScalarValue::from(3_f64)) } #[test] fn avg_f64() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - Avg, - ScalarValue::from(3_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float64, Avg, ScalarValue::from(3_f64)) } } diff --git a/datafusion/physical-expr/src/aggregate/correlation.rs b/datafusion/physical-expr/src/aggregate/correlation.rs index f25cb5790d37..8645bd5497af 100644 --- a/datafusion/physical-expr/src/aggregate/correlation.rs +++ b/datafusion/physical-expr/src/aggregate/correlation.rs @@ -217,8 +217,7 @@ mod tests { DataType::Float64, DataType::Float64, Correlation, - ScalarValue::from(0.9819805060619659), - DataType::Float64 + ScalarValue::from(0.9819805060619659_f64) ) } @@ -233,8 +232,7 @@ mod tests { DataType::Float64, DataType::Float64, Correlation, - ScalarValue::from(0.17066403719657236), - DataType::Float64 + ScalarValue::from(0.17066403719657236_f64) ) } @@ -249,8 +247,7 @@ mod tests { DataType::Float64, DataType::Float64, Correlation, - ScalarValue::from(1_f64), - DataType::Float64 + ScalarValue::from(1_f64) ) } @@ -269,8 +266,7 @@ mod tests { DataType::Float64, DataType::Float64, Correlation, - ScalarValue::from(0.9860135594710389), - DataType::Float64 + ScalarValue::from(0.9860135594710389_f64) ) } @@ -285,8 +281,7 @@ mod tests { DataType::Int32, DataType::Int32, Correlation, - ScalarValue::from(1_f64), - DataType::Float64 + ScalarValue::from(1_f64) ) } @@ -300,8 +295,7 @@ mod tests { DataType::UInt32, DataType::UInt32, Correlation, - ScalarValue::from(1_f64), - DataType::Float64 + ScalarValue::from(1_f64) ) } @@ -315,8 +309,7 @@ mod tests { DataType::Float32, DataType::Float32, Correlation, - ScalarValue::from(1_f64), - DataType::Float64 + ScalarValue::from(1_f64) ) } @@ -333,8 +326,7 @@ mod tests { DataType::Int32, DataType::Int32, Correlation, - ScalarValue::from(0.1889822365046137), - DataType::Float64 + ScalarValue::from(0.1889822365046137_f64) ) } diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 8cfb85fa9152..b64328aa3a78 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -210,13 +210,7 @@ mod tests { #[test] fn count_elements() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - Count, - ScalarValue::from(5i64), - DataType::Int64 - ) + generic_test_op!(a, DataType::Int32, Count, ScalarValue::from(5i64)) } #[test] @@ -229,13 +223,7 @@ mod tests { Some(3), None, ])); - generic_test_op!( - a, - DataType::Int32, - Count, - ScalarValue::from(3i64), - DataType::Int64 - ) + generic_test_op!(a, DataType::Int32, Count, ScalarValue::from(3i64)) } #[test] @@ -243,51 +231,27 @@ mod tests { let a: ArrayRef = Arc::new(BooleanArray::from(vec![ None, None, None, None, None, None, None, None, ])); - generic_test_op!( - a, - DataType::Boolean, - Count, - ScalarValue::from(0i64), - DataType::Int64 - ) + generic_test_op!(a, DataType::Boolean, Count, ScalarValue::from(0i64)) } #[test] fn count_empty() -> Result<()> { let a: Vec = vec![]; let a: ArrayRef = Arc::new(BooleanArray::from(a)); - generic_test_op!( - a, - DataType::Boolean, - Count, - ScalarValue::from(0i64), - DataType::Int64 - ) + generic_test_op!(a, DataType::Boolean, Count, ScalarValue::from(0i64)) } #[test] fn count_utf8() -> Result<()> { let a: ArrayRef = Arc::new(StringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"])); - generic_test_op!( - a, - DataType::Utf8, - Count, - ScalarValue::from(5i64), - DataType::Int64 - ) + generic_test_op!(a, DataType::Utf8, Count, ScalarValue::from(5i64)) } #[test] fn count_large_utf8() -> Result<()> { let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "bb", "ccc", "dddd", "ad"])); - generic_test_op!( - a, - DataType::LargeUtf8, - Count, - ScalarValue::from(5i64), - DataType::Int64 - ) + generic_test_op!(a, DataType::LargeUtf8, Count, ScalarValue::from(5i64)) } } diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index 0911111a3f99..63a8137c2371 100644 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -397,8 +397,7 @@ mod tests { DataType::Float64, DataType::Float64, CovariancePop, - ScalarValue::from(0.6666666666666666), - DataType::Float64 + ScalarValue::from(0.6666666666666666_f64) ) } @@ -413,8 +412,7 @@ mod tests { DataType::Float64, DataType::Float64, Covariance, - ScalarValue::from(1_f64), - DataType::Float64 + ScalarValue::from(1_f64) ) } @@ -429,8 +427,7 @@ mod tests { DataType::Float64, DataType::Float64, Covariance, - ScalarValue::from(0.9033333333333335_f64), - DataType::Float64 + ScalarValue::from(0.9033333333333335_f64) ) } @@ -445,8 +442,7 @@ mod tests { DataType::Float64, DataType::Float64, CovariancePop, - ScalarValue::from(0.6022222222222223_f64), - DataType::Float64 + ScalarValue::from(0.6022222222222223_f64) ) } @@ -465,8 +461,7 @@ mod tests { DataType::Float64, DataType::Float64, CovariancePop, - ScalarValue::from(0.7616666666666666), - DataType::Float64 + ScalarValue::from(0.7616666666666666_f64) ) } @@ -481,8 +476,7 @@ mod tests { DataType::Int32, DataType::Int32, CovariancePop, - ScalarValue::from(0.6666666666666666_f64), - DataType::Float64 + ScalarValue::from(0.6666666666666666_f64) ) } @@ -496,8 +490,7 @@ mod tests { DataType::UInt32, DataType::UInt32, CovariancePop, - ScalarValue::from(0.6666666666666666_f64), - DataType::Float64 + ScalarValue::from(0.6666666666666666_f64) ) } @@ -511,8 +504,7 @@ mod tests { DataType::Float32, DataType::Float32, CovariancePop, - ScalarValue::from(0.6666666666666666_f64), - DataType::Float64 + ScalarValue::from(0.6666666666666666_f64) ) } @@ -527,8 +519,7 @@ mod tests { DataType::Int32, DataType::Int32, CovariancePop, - ScalarValue::from(1_f64), - DataType::Float64 + ScalarValue::from(1_f64) ) } diff --git a/datafusion/physical-expr/src/aggregate/min_max.rs b/datafusion/physical-expr/src/aggregate/min_max.rs index bdccdf522207..36d58c78008e 100644 --- a/datafusion/physical-expr/src/aggregate/min_max.rs +++ b/datafusion/physical-expr/src/aggregate/min_max.rs @@ -154,16 +154,10 @@ macro_rules! typed_min_max_batch_string { // Statically-typed version of min/max(array) -> ScalarValue for non-string types. macro_rules! typed_min_max_batch { - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{ - let array = downcast_value!($VALUES, $ARRAYTYPE); - let value = compute::$OP(array); - ScalarValue::$SCALAR(value) - }}; - - ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident, $TZ:expr) => {{ + ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ let array = downcast_value!($VALUES, $ARRAYTYPE); let value = compute::$OP(array); - ScalarValue::$SCALAR(value, $TZ.clone()) + ScalarValue::$SCALAR(value, $($EXTRA_ARGS.clone()),*) }}; } @@ -296,41 +290,18 @@ fn max_batch(values: &ArrayRef) -> Result { _ => min_max_batch!(values, max), }) } -macro_rules! typed_min_max_decimal { - ($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR( - match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), - (Some(a), Some(b)) => Some((*a).$OP(*b)), - }, - $PRECISION.clone(), - $SCALE.clone(), - ) - }}; -} // min/max of two non-string scalar values. macro_rules! typed_min_max { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR(match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), - (Some(a), Some(b)) => Some((*a).$OP(*b)), - }) - }}; - - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $TZ:expr) => {{ + ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ ScalarValue::$SCALAR( match ($VALUE, $DELTA) { (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), (Some(a), Some(b)) => Some((*a).$OP(*b)), }, - $TZ.clone(), + $($EXTRA_ARGS.clone()),* ) }}; } @@ -363,13 +334,16 @@ macro_rules! typed_min_max_string { macro_rules! min_max { ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ Ok(match ($VALUE, $DELTA) { - (ScalarValue::Decimal128(lhsv,lhsp,lhss), ScalarValue::Decimal128(rhsv,rhsp,rhss)) => { + ( + lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), + rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) + ) => { if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max_decimal!(lhsv, rhsv, lhsp, lhss, Decimal128, $OP) + typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) } else { return Err(DataFusionError::Internal(format!( "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (ScalarValue::Decimal128(*lhsv,*lhsp,*lhss),ScalarValue::Decimal128(*rhsv,*rhsp,*rhss)) + (lhs, rhs) ))); } } @@ -815,8 +789,7 @@ mod tests { array, DataType::Decimal128(10, 0), Min, - ScalarValue::Decimal128(Some(1), 10, 0), - DataType::Decimal128(10, 0) + ScalarValue::Decimal128(Some(1), 10, 0) ) } @@ -834,8 +807,7 @@ mod tests { array, DataType::Decimal128(10, 0), Min, - ScalarValue::Decimal128(None, 10, 0), - DataType::Decimal128(10, 0) + ScalarValue::Decimal128(None, 10, 0) ) } @@ -853,8 +825,7 @@ mod tests { array, DataType::Decimal128(10, 0), Min, - ScalarValue::Decimal128(Some(1), 10, 0), - DataType::Decimal128(10, 0) + ScalarValue::Decimal128(Some(1), 10, 0) ) } @@ -906,8 +877,7 @@ mod tests { array, DataType::Decimal128(10, 0), Max, - ScalarValue::Decimal128(Some(5), 10, 0), - DataType::Decimal128(10, 0) + ScalarValue::Decimal128(Some(5), 10, 0) ) } @@ -923,8 +893,7 @@ mod tests { array, DataType::Decimal128(10, 0), Max, - ScalarValue::Decimal128(Some(5), 10, 0), - DataType::Decimal128(10, 0) + ScalarValue::Decimal128(Some(5), 10, 0) ) } @@ -941,33 +910,20 @@ mod tests { array, DataType::Decimal128(10, 0), Min, - ScalarValue::Decimal128(None, 10, 0), - DataType::Decimal128(10, 0) + ScalarValue::Decimal128(None, 10, 0) ) } #[test] fn max_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - Max, - ScalarValue::from(5i32), - DataType::Int32 - ) + generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32)) } #[test] fn min_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - Min, - ScalarValue::from(1i32), - DataType::Int32 - ) + generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32)) } #[test] @@ -977,8 +933,7 @@ mod tests { a, DataType::Utf8, Max, - ScalarValue::Utf8(Some("d".to_string())), - DataType::Utf8 + ScalarValue::Utf8(Some("d".to_string())) ) } @@ -989,8 +944,7 @@ mod tests { a, DataType::LargeUtf8, Max, - ScalarValue::LargeUtf8(Some("d".to_string())), - DataType::LargeUtf8 + ScalarValue::LargeUtf8(Some("d".to_string())) ) } @@ -1001,8 +955,7 @@ mod tests { a, DataType::Utf8, Min, - ScalarValue::Utf8(Some("a".to_string())), - DataType::Utf8 + ScalarValue::Utf8(Some("a".to_string())) ) } @@ -1013,8 +966,7 @@ mod tests { a, DataType::LargeUtf8, Min, - ScalarValue::LargeUtf8(Some("a".to_string())), - DataType::LargeUtf8 + ScalarValue::LargeUtf8(Some("a".to_string())) ) } @@ -1027,13 +979,7 @@ mod tests { Some(4), Some(5), ])); - generic_test_op!( - a, - DataType::Int32, - Max, - ScalarValue::from(5i32), - DataType::Int32 - ) + generic_test_op!(a, DataType::Int32, Max, ScalarValue::from(5i32)) } #[test] @@ -1045,163 +991,85 @@ mod tests { Some(4), Some(5), ])); - generic_test_op!( - a, - DataType::Int32, - Min, - ScalarValue::from(1i32), - DataType::Int32 - ) + generic_test_op!(a, DataType::Int32, Min, ScalarValue::from(1i32)) } #[test] fn max_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!( - a, - DataType::Int32, - Max, - ScalarValue::Int32(None), - DataType::Int32 - ) + generic_test_op!(a, DataType::Int32, Max, ScalarValue::Int32(None)) } #[test] fn min_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!( - a, - DataType::Int32, - Min, - ScalarValue::Int32(None), - DataType::Int32 - ) + generic_test_op!(a, DataType::Int32, Min, ScalarValue::Int32(None)) } #[test] fn max_u32() -> Result<()> { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - Max, - ScalarValue::from(5_u32), - DataType::UInt32 - ) + generic_test_op!(a, DataType::UInt32, Max, ScalarValue::from(5_u32)) } #[test] fn min_u32() -> Result<()> { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - Min, - ScalarValue::from(1u32), - DataType::UInt32 - ) + generic_test_op!(a, DataType::UInt32, Min, ScalarValue::from(1u32)) } #[test] fn max_f32() -> Result<()> { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!( - a, - DataType::Float32, - Max, - ScalarValue::from(5_f32), - DataType::Float32 - ) + generic_test_op!(a, DataType::Float32, Max, ScalarValue::from(5_f32)) } #[test] fn min_f32() -> Result<()> { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!( - a, - DataType::Float32, - Min, - ScalarValue::from(1_f32), - DataType::Float32 - ) + generic_test_op!(a, DataType::Float32, Min, ScalarValue::from(1_f32)) } #[test] fn max_f64() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - Max, - ScalarValue::from(5_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float64, Max, ScalarValue::from(5_f64)) } #[test] fn min_f64() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - Min, - ScalarValue::from(1_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float64, Min, ScalarValue::from(1_f64)) } #[test] fn min_date32() -> Result<()> { let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Date32, - Min, - ScalarValue::Date32(Some(1)), - DataType::Date32 - ) + generic_test_op!(a, DataType::Date32, Min, ScalarValue::Date32(Some(1))) } #[test] fn min_date64() -> Result<()> { let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Date64, - Min, - ScalarValue::Date64(Some(1)), - DataType::Date64 - ) + generic_test_op!(a, DataType::Date64, Min, ScalarValue::Date64(Some(1))) } #[test] fn max_date32() -> Result<()> { let a: ArrayRef = Arc::new(Date32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Date32, - Max, - ScalarValue::Date32(Some(5)), - DataType::Date32 - ) + generic_test_op!(a, DataType::Date32, Max, ScalarValue::Date32(Some(5))) } #[test] fn max_date64() -> Result<()> { let a: ArrayRef = Arc::new(Date64Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Date64, - Max, - ScalarValue::Date64(Some(5)), - DataType::Date64 - ) + generic_test_op!(a, DataType::Date64, Max, ScalarValue::Date64(Some(5))) } #[test] @@ -1211,8 +1079,7 @@ mod tests { a, DataType::Time64(TimeUnit::Nanosecond), Max, - ScalarValue::Time64(Some(5)), - DataType::Time64(TimeUnit::Nanosecond) + ScalarValue::Time64(Some(5)) ) } @@ -1223,8 +1090,7 @@ mod tests { a, DataType::Time64(TimeUnit::Nanosecond), Max, - ScalarValue::Time64(Some(5)), - DataType::Time64(TimeUnit::Nanosecond) + ScalarValue::Time64(Some(5)) ) } } diff --git a/datafusion/physical-expr/src/aggregate/stddev.rs b/datafusion/physical-expr/src/aggregate/stddev.rs index 77f080293e27..5197018a568b 100644 --- a/datafusion/physical-expr/src/aggregate/stddev.rs +++ b/datafusion/physical-expr/src/aggregate/stddev.rs @@ -227,13 +227,7 @@ mod tests { #[test] fn stddev_f64_1() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64])); - generic_test_op!( - a, - DataType::Float64, - StddevPop, - ScalarValue::from(0.5_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float64, StddevPop, ScalarValue::from(0.5_f64)) } #[test] @@ -243,8 +237,7 @@ mod tests { a, DataType::Float64, StddevPop, - ScalarValue::from(0.7760297817881877), - DataType::Float64 + ScalarValue::from(0.7760297817881877_f64) ) } @@ -256,8 +249,7 @@ mod tests { a, DataType::Float64, StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2), - DataType::Float64 + ScalarValue::from(std::f64::consts::SQRT_2) ) } @@ -268,8 +260,7 @@ mod tests { a, DataType::Float64, Stddev, - ScalarValue::from(0.9504384952922168), - DataType::Float64 + ScalarValue::from(0.9504384952922168_f64) ) } @@ -280,8 +271,7 @@ mod tests { a, DataType::Int32, StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2), - DataType::Float64 + ScalarValue::from(std::f64::consts::SQRT_2) ) } @@ -293,8 +283,7 @@ mod tests { a, DataType::UInt32, StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2), - DataType::Float64 + ScalarValue::from(std::f64::consts::SQRT_2) ) } @@ -306,8 +295,7 @@ mod tests { a, DataType::Float32, StddevPop, - ScalarValue::from(std::f64::consts::SQRT_2), - DataType::Float64 + ScalarValue::from(std::f64::consts::SQRT_2) ) } @@ -341,8 +329,7 @@ mod tests { a, DataType::Int32, StddevPop, - ScalarValue::from(1.479019945774904), - DataType::Float64 + ScalarValue::from(1.479019945774904_f64) ) } diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 892ef5964deb..ca9b4c819eab 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -168,12 +168,11 @@ fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: u8) -> Result s + v.as_i128(), + None => s, + }); + Ok(ScalarValue::Decimal128(Some(result), precision, scale)) } @@ -206,87 +205,37 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result {{ paste::item! { - match $DELTA { - None => {} - Some(v) => $ACC.[]($INDEX, *v as $TYPE) + if let Some(v) = $DELTA { + $ACC.[]($INDEX, *v) } } }}; } pub(crate) fn add_to_row( - dt: &DataType, index: usize, accessor: &mut RowAccessor, s: &ScalarValue, ) -> Result<()> { - match (dt, s) { - // float64 coerces everything to f64 - (DataType::Float64, ScalarValue::Float64(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::Float32(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::Int64(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::Int32(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::Int16(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::Int8(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::UInt64(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::UInt32(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::UInt16(rhs)) => { - sum_row!(index, accessor, rhs, f64) - } - (DataType::Float64, ScalarValue::UInt8(rhs)) => { + match s { + ScalarValue::Float64(rhs) => { sum_row!(index, accessor, rhs, f64) } - // float32 has no cast - (DataType::Float32, ScalarValue::Float32(rhs)) => { + ScalarValue::Float32(rhs) => { sum_row!(index, accessor, rhs, f32) } - // u64 coerces u* to u64 - (DataType::UInt64, ScalarValue::UInt64(rhs)) => { + ScalarValue::UInt64(rhs) => { sum_row!(index, accessor, rhs, u64) } - (DataType::UInt64, ScalarValue::UInt32(rhs)) => { - sum_row!(index, accessor, rhs, u64) - } - (DataType::UInt64, ScalarValue::UInt16(rhs)) => { - sum_row!(index, accessor, rhs, u64) - } - (DataType::UInt64, ScalarValue::UInt8(rhs)) => { - sum_row!(index, accessor, rhs, u64) - } - // i64 coerces i* to i64 - (DataType::Int64, ScalarValue::Int64(rhs)) => { + ScalarValue::Int64(rhs) => { sum_row!(index, accessor, rhs, i64) } - (DataType::Int64, ScalarValue::Int32(rhs)) => { - sum_row!(index, accessor, rhs, i64) - } - (DataType::Int64, ScalarValue::Int16(rhs)) => { - sum_row!(index, accessor, rhs, i64) - } - (DataType::Int64, ScalarValue::Int8(rhs)) => { - sum_row!(index, accessor, rhs, i64) - } - e => { - return Err(DataFusionError::Internal(format!( + _ => { + let msg = format!( "Row sum updater is not expected to receive a scalar {:?}", - e - ))); + s + ); + return Err(DataFusionError::Internal(msg)); } } Ok(()) @@ -303,18 +252,16 @@ impl Accumulator for SumAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; self.count += (values.len() - values.data().null_count()) as u64; - self.sum = self - .sum - .add(&sum_batch(values, &self.sum.get_datatype())?)?; + let delta = sum_batch(values, &self.sum.get_datatype())?; + self.sum = self.sum.add(&delta)?; Ok(()) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; self.count -= (values.len() - values.data().null_count()) as u64; - self.sum = self - .sum - .sub(&sum_batch(values, &self.sum.get_datatype())?)?; + let delta = sum_batch(values, &self.sum.get_datatype())?; + self.sum = self.sum.sub(&delta)?; Ok(()) } @@ -353,12 +300,8 @@ impl RowAccumulator for SumRowAccumulator { accessor: &mut RowAccessor, ) -> Result<()> { let values = &values[0]; - add_to_row( - &self.datatype, - self.index, - accessor, - &sum_batch(values, &self.datatype)?, - )?; + let delta = sum_batch(values, &self.datatype)?; + add_to_row(self.index, accessor, &delta)?; Ok(()) } @@ -414,8 +357,7 @@ mod tests { array, DataType::Decimal128(10, 0), Sum, - ScalarValue::Decimal128(Some(15), 20, 0), - DataType::Decimal128(20, 0) + ScalarValue::Decimal128(Some(15), 20, 0) ) } @@ -442,8 +384,7 @@ mod tests { array, DataType::Decimal128(35, 0), Sum, - ScalarValue::Decimal128(Some(13), 38, 0), - DataType::Decimal128(38, 0) + ScalarValue::Decimal128(Some(13), 38, 0) ) } @@ -465,21 +406,14 @@ mod tests { array, DataType::Decimal128(10, 0), Sum, - ScalarValue::Decimal128(None, 20, 0), - DataType::Decimal128(20, 0) + ScalarValue::Decimal128(None, 20, 0) ) } #[test] fn sum_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - Sum, - ScalarValue::from(15i64), - DataType::Int64 - ) + generic_test_op!(a, DataType::Int32, Sum, ScalarValue::from(15i32)) } #[test] @@ -491,63 +425,33 @@ mod tests { Some(4), Some(5), ])); - generic_test_op!( - a, - DataType::Int32, - Sum, - ScalarValue::from(13i64), - DataType::Int64 - ) + generic_test_op!(a, DataType::Int32, Sum, ScalarValue::from(13i32)) } #[test] fn sum_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!( - a, - DataType::Int32, - Sum, - ScalarValue::Int64(None), - DataType::Int64 - ) + generic_test_op!(a, DataType::Int32, Sum, ScalarValue::Int32(None)) } #[test] fn sum_u32() -> Result<()> { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - Sum, - ScalarValue::from(15u64), - DataType::UInt64 - ) + generic_test_op!(a, DataType::UInt32, Sum, ScalarValue::from(15u32)) } #[test] fn sum_f32() -> Result<()> { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!( - a, - DataType::Float32, - Sum, - ScalarValue::from(15_f32), - DataType::Float32 - ) + generic_test_op!(a, DataType::Float32, Sum, ScalarValue::from(15_f32)) } #[test] fn sum_f64() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - Sum, - ScalarValue::from(15_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float64, Sum, ScalarValue::from(15_f64)) } } diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index c8abdcac0cb8..73c4828e887f 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -200,7 +200,7 @@ mod tests { } macro_rules! generic_test_sum_distinct { - ($ARRAY:expr, $DATATYPE:expr, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ + ($ARRAY:expr, $DATATYPE:expr, $EXPECTED:expr) => {{ let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?; @@ -208,7 +208,7 @@ mod tests { let agg = Arc::new(DistinctSum::new( vec![col("a", &schema)?], "count_distinct_a".to_string(), - $EXPECTED_DATATYPE, + $EXPECTED.get_datatype(), )); let actual = aggregate(&batch, agg)?; let expected = ScalarValue::from($EXPECTED); @@ -241,12 +241,7 @@ mod tests { Some(2), Some(3), ])); - generic_test_sum_distinct!( - array, - DataType::Int32, - ScalarValue::from(6i64), - DataType::Int64 - ) + generic_test_sum_distinct!(array, DataType::Int32, ScalarValue::from(6_i32)) } #[test] @@ -258,24 +253,14 @@ mod tests { Some(3_u32), None, ])); - generic_test_sum_distinct!( - array, - DataType::UInt32, - ScalarValue::from(4i64), - DataType::Int64 - ) + generic_test_sum_distinct!(array, DataType::UInt32, ScalarValue::from(4_u32)) } #[test] fn sum_distinct_f64() -> Result<()> { let array: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 1_f64, 3_f64, 3_f64, 3_f64])); - generic_test_sum_distinct!( - array, - DataType::Float64, - ScalarValue::from(4_f64), - DataType::Float64 - ) + generic_test_sum_distinct!(array, DataType::Float64, ScalarValue::from(4_f64)) } #[test] @@ -289,8 +274,7 @@ mod tests { generic_test_sum_distinct!( array, DataType::Decimal128(35, 0), - ScalarValue::Decimal128(Some(1), 38, 0), - DataType::Decimal128(38, 0) + ScalarValue::Decimal128(Some(1), 38, 0) ) } } diff --git a/datafusion/physical-expr/src/aggregate/variance.rs b/datafusion/physical-expr/src/aggregate/variance.rs index 7ccec55ac34e..d6ed8c95778b 100644 --- a/datafusion/physical-expr/src/aggregate/variance.rs +++ b/datafusion/physical-expr/src/aggregate/variance.rs @@ -326,8 +326,7 @@ mod tests { a, DataType::Float64, VariancePop, - ScalarValue::from(0.25_f64), - DataType::Float64 + ScalarValue::from(0.25_f64) ) } @@ -335,26 +334,14 @@ mod tests { fn variance_f64_2() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - VariancePop, - ScalarValue::from(2_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float64, VariancePop, ScalarValue::from(2_f64)) } #[test] fn variance_f64_3() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!( - a, - DataType::Float64, - Variance, - ScalarValue::from(2.5_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float64, Variance, ScalarValue::from(2.5_f64)) } #[test] @@ -364,47 +351,28 @@ mod tests { a, DataType::Float64, Variance, - ScalarValue::from(0.9033333333333333_f64), - DataType::Float64 + ScalarValue::from(0.9033333333333333_f64) ) } #[test] fn variance_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!( - a, - DataType::Int32, - VariancePop, - ScalarValue::from(2_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Int32, VariancePop, ScalarValue::from(2_f64)) } #[test] fn variance_u32() -> Result<()> { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!( - a, - DataType::UInt32, - VariancePop, - ScalarValue::from(2.0f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::UInt32, VariancePop, ScalarValue::from(2.0f64)) } #[test] fn variance_f32() -> Result<()> { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!( - a, - DataType::Float32, - VariancePop, - ScalarValue::from(2_f64), - DataType::Float64 - ) + generic_test_op!(a, DataType::Float32, VariancePop, ScalarValue::from(2_f64)) } #[test] @@ -437,8 +405,7 @@ mod tests { a, DataType::Int32, VariancePop, - ScalarValue::from(2.1875f64), - DataType::Float64 + ScalarValue::from(2.1875_f64) ) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 00bf6aafade4..208e6d0b51fb 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -101,6 +101,9 @@ pub(crate) mod tests { /// macro to perform an aggregation and verify the result. #[macro_export] macro_rules! generic_test_op { + ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr) => { + generic_test_op!($ARRAY, $DATATYPE, $OP, $EXPECTED, $EXPECTED.get_datatype()) + }; ($ARRAY:expr, $DATATYPE:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ let schema = Schema::new(vec![Field::new("a", $DATATYPE, true)]); @@ -123,6 +126,17 @@ pub(crate) mod tests { /// macro to perform an aggregation with two inputs and verify the result. #[macro_export] macro_rules! generic_test_op2 { + ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr) => { + generic_test_op2!( + $ARRAY1, + $ARRAY2, + $DATATYPE1, + $DATATYPE2, + $OP, + $EXPECTED, + $EXPECTED.get_datatype() + ) + }; ($ARRAY1:expr, $ARRAY2:expr, $DATATYPE1:expr, $DATATYPE2:expr, $OP:ident, $EXPECTED:expr, $EXPECTED_DATATYPE:expr) => {{ let schema = Schema::new(vec![ Field::new("a", $DATATYPE1, true),