diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index c2fd32a96c4f..f7c13948b2dc 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -15,21 +15,32 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, TimeUnit}; +use arrow_array::types::{ + ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, + Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, + Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, + TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, + TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::PrimitiveArray; use std::any::Any; +use std::cmp::Eq; use std::fmt::Debug; +use std::hash::Hash; use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; use std::collections::HashSet; -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::Result; -use datafusion_common::ScalarValue; +use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::{Result, ScalarValue}; use datafusion_expr::Accumulator; type DistinctScalarValues = ScalarValue; @@ -60,6 +71,18 @@ impl DistinctCount { } } +macro_rules! native_distinct_count_accumulator { + ($TYPE:ident) => {{ + Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new())) + }}; +} + +macro_rules! float_distinct_count_accumulator { + ($TYPE:ident) => {{ + Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new())) + }}; +} + impl AggregateExpr for DistinctCount { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -83,10 +106,57 @@ impl AggregateExpr for DistinctCount { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(DistinctCountAccumulator { - values: HashSet::default(), - state_data_type: self.state_data_type.clone(), - })) + use DataType::*; + use TimeUnit::*; + + match &self.state_data_type { + Int8 => native_distinct_count_accumulator!(Int8Type), + Int16 => native_distinct_count_accumulator!(Int16Type), + Int32 => native_distinct_count_accumulator!(Int32Type), + Int64 => native_distinct_count_accumulator!(Int64Type), + UInt8 => native_distinct_count_accumulator!(UInt8Type), + UInt16 => native_distinct_count_accumulator!(UInt16Type), + UInt32 => native_distinct_count_accumulator!(UInt32Type), + UInt64 => native_distinct_count_accumulator!(UInt64Type), + Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type), + Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type), + + Date32 => native_distinct_count_accumulator!(Date32Type), + Date64 => native_distinct_count_accumulator!(Date64Type), + Time32(Millisecond) => { + native_distinct_count_accumulator!(Time32MillisecondType) + } + Time32(Second) => { + native_distinct_count_accumulator!(Time32SecondType) + } + Time64(Microsecond) => { + native_distinct_count_accumulator!(Time64MicrosecondType) + } + Time64(Nanosecond) => { + native_distinct_count_accumulator!(Time64NanosecondType) + } + Timestamp(Microsecond, _) => { + native_distinct_count_accumulator!(TimestampMicrosecondType) + } + Timestamp(Millisecond, _) => { + native_distinct_count_accumulator!(TimestampMillisecondType) + } + Timestamp(Nanosecond, _) => { + native_distinct_count_accumulator!(TimestampNanosecondType) + } + Timestamp(Second, _) => { + native_distinct_count_accumulator!(TimestampSecondType) + } + + Float16 => float_distinct_count_accumulator!(Float16Type), + Float32 => float_distinct_count_accumulator!(Float32Type), + Float64 => float_distinct_count_accumulator!(Float64Type), + + _ => Ok(Box::new(DistinctCountAccumulator { + values: HashSet::default(), + state_data_type: self.state_data_type.clone(), + })), + } } fn name(&self) -> &str { @@ -192,6 +262,182 @@ impl Accumulator for DistinctCountAccumulator { } } +#[derive(Debug)] +struct NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + values: HashSet, +} + +impl NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, + T::Native: Eq + Hash, +{ + fn new() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for NativeDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send + Debug, + T::Native: Eq + Hash, +{ + fn state(&self) -> Result> { + let arr = Arc::new(PrimitiveArray::::from_iter_values( + self.values.iter().cloned(), + )) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)) as ArrayRef; + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + arr.iter().for_each(|value| { + if let Some(value) = value { + self.values.insert(value); + } + }); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let list = as_primitive_array::(&list)?; + self.values.extend(list.values()) + }; + Ok(()) + }) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) + / 7) + .next_power_of_two(); + + // Size of accumulator + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of HashSet + std::mem::size_of_val(self) + + std::mem::size_of::() * estimated_buckets + + estimated_buckets + + std::mem::size_of_val(&self.values) + } +} + +#[derive(Debug)] +struct FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + values: HashSet, RandomState>, +} + +impl FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send, +{ + fn new() -> Self { + Self { + values: HashSet::default(), + } + } +} + +impl Accumulator for FloatDistinctCountAccumulator +where + T: ArrowPrimitiveType + Send + Debug, +{ + fn state(&self) -> Result> { + let arr = Arc::new(PrimitiveArray::::from_iter_values( + self.values.iter().map(|v| v.0), + )) as ArrayRef; + let list = Arc::new(array_into_list_array(arr)) as ArrayRef; + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let arr = as_primitive_array::(&values[0])?; + arr.iter().for_each(|value| { + if let Some(value) = value { + self.values.insert(Hashable(value)); + } + }); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let list = as_primitive_array::(&list)?; + self.values + .extend(list.values().iter().map(|v| Hashable(*v))); + }; + Ok(()) + }) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.values.len() as i64))) + } + + fn size(&self) -> usize { + let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) + / 7) + .next_power_of_two(); + + // Size of accumulator + // + size of entry * number of buckets + // + 1 byte for each bucket + // + fixed size of HashSet + std::mem::size_of_val(self) + + std::mem::size_of::() * estimated_buckets + + estimated_buckets + + std::mem::size_of_val(&self.values) + } +} + #[cfg(test)] mod tests { use crate::expressions::NoOp; @@ -206,6 +452,8 @@ mod tests { Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; + use arrow_array::Decimal256Array; + use arrow_buffer::i256; use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; use datafusion_common::internal_err; use datafusion_common::DataFusionError; @@ -367,6 +615,35 @@ mod tests { }}; } + macro_rules! test_count_distinct_update_batch_bigint { + ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ + let values: Vec> = vec![ + Some(i256::from(1)), + Some(i256::from(1)), + None, + Some(i256::from(3)), + Some(i256::from(2)), + None, + Some(i256::from(2)), + Some(i256::from(3)), + Some(i256::from(1)), + ]; + + let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; + + let (states, result) = run_update_batch(&arrays)?; + + let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); + state_vec.sort(); + + assert_eq!(states.len(), 1); + assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); + assert_eq!(result, ScalarValue::Int64(Some(3))); + + Ok(()) + }}; + } + #[test] fn count_distinct_update_batch_i8() -> Result<()> { test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) @@ -417,6 +694,11 @@ mod tests { test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) } + #[test] + fn count_distinct_update_batch_i256() -> Result<()> { + test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) + } + #[test] fn count_distinct_update_batch_boolean() -> Result<()> { let get_count = |data: BooleanArray| -> Result<(Vec, i64)> { diff --git a/datafusion/physical-expr/src/aggregate/sum_distinct.rs b/datafusion/physical-expr/src/aggregate/sum_distinct.rs index 0cf4a90ab8cc..6dbb39224629 100644 --- a/datafusion/physical-expr/src/aggregate/sum_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/sum_distinct.rs @@ -25,11 +25,11 @@ use arrow::array::{Array, ArrayRef}; use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::{ArrowNativeTypeOp, ArrowPrimitiveType}; -use arrow_buffer::{ArrowNativeType, ToByteSlice}; +use arrow_buffer::ArrowNativeType; use std::collections::HashSet; use crate::aggregate::sum::downcast_sum; -use crate::aggregate::utils::down_cast_any_ref; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::{AggregateExpr, PhysicalExpr}; use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::aggregates::sum_return_type; @@ -119,24 +119,6 @@ impl PartialEq for DistinctSum { } } -/// A wrapper around a type to provide hash for floats -#[derive(Copy, Clone)] -struct Hashable(T); - -impl std::hash::Hash for Hashable { - fn hash(&self, state: &mut H) { - self.0.to_byte_slice().hash(state) - } -} - -impl PartialEq for Hashable { - fn eq(&self, other: &Self) -> bool { - self.0.is_eq(other.0) - } -} - -impl Eq for Hashable {} - struct DistinctSumAccumulator { values: HashSet, RandomState>, data_type: DataType, diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index 9777158da133..d73c46a0f687 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -28,7 +28,7 @@ use arrow_array::types::{ Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, }; -use arrow_buffer::ArrowNativeType; +use arrow_buffer::{ArrowNativeType, ToByteSlice}; use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::{exec_err, DataFusionError, Result}; use datafusion_expr::Accumulator; @@ -211,3 +211,21 @@ pub(crate) fn ordering_fields( pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec { ordering_req.iter().map(|item| item.options).collect() } + +/// A wrapper around a type to provide hash for floats +#[derive(Copy, Clone, Debug)] +pub(crate) struct Hashable(pub T); + +impl std::hash::Hash for Hashable { + fn hash(&self, state: &mut H) { + self.0.to_byte_slice().hash(state) + } +} + +impl PartialEq for Hashable { + fn eq(&self, other: &Self) -> bool { + self.0.is_eq(other.0) + } +} + +impl Eq for Hashable {}