From d0c73f832ffb4e5452e26218f86f11ff5d827982 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Tue, 22 Aug 2023 18:02:12 +0100 Subject: [PATCH 1/2] Specialize Median Accumulator --- .../physical-expr/src/aggregate/median.rs | 193 ++++++------------ 1 file changed, 64 insertions(+), 129 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/median.rs b/datafusion/physical-expr/src/aggregate/median.rs index 2f6096609319..1ec412402638 100644 --- a/datafusion/physical-expr/src/aggregate/median.rs +++ b/datafusion/physical-expr/src/aggregate/median.rs @@ -20,13 +20,15 @@ use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -use arrow::array::{Array, ArrayRef, UInt32Array}; -use arrow::compute::sort_to_indices; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field}; -use datafusion_common::internal_err; +use arrow_array::cast::AsArray; +use arrow_array::{downcast_integer, ArrowNativeTypeOp, ArrowNumericType}; +use arrow_buffer::ArrowNativeType; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use std::any::Any; +use std::fmt::Formatter; use std::sync::Arc; /// MEDIAN aggregate expression. This uses a lot of memory because all values need to be @@ -65,11 +67,29 @@ impl AggregateExpr for Median { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(MedianAccumulator { - data_type: self.data_type.clone(), - arrays: vec![], - all_values: vec![], - })) + use arrow_array::types::*; + macro_rules! helper { + ($t:ty, $dt:expr) => { + Ok(Box::new(MedianAccumulator::<$t> { + data_type: $dt.clone(), + all_values: vec![], + })) + }; + } + let dt = &self.data_type; + downcast_integer! { + dt => (helper, dt), + DataType::Float16 => helper!(Float16Type, dt), + DataType::Float32 => helper!(Float32Type, dt), + DataType::Float64 => helper!(Float64Type, dt), + DataType::Decimal128(_, _) => helper!(Decimal128Type, dt), + DataType::Decimal256(_, _) => helper!(Decimal256Type, dt), + _ => Err(DataFusionError::NotImplemented(format!( + "MedianAccumulator not supported for {} with {}", + self.name(), + self.data_type + ))), + } } fn state_fields(&self) -> Result> { @@ -106,7 +126,6 @@ impl PartialEq for Median { } } -#[derive(Debug)] /// The median accumulator accumulates the raw input values /// as `ScalarValue`s /// @@ -114,151 +133,68 @@ impl PartialEq for Median { /// `merge_batch` and a `Vec` of `ArrayRef` that are converted to scalar values /// in the final evaluation step so that we avoid expensive conversions and /// allocations during `update_batch`. -struct MedianAccumulator { +struct MedianAccumulator { data_type: DataType, - arrays: Vec, - all_values: Vec, + all_values: Vec, } -impl Accumulator for MedianAccumulator { +impl std::fmt::Debug for MedianAccumulator { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "MedianAccumulator({})", self.data_type) + } +} + +impl Accumulator for MedianAccumulator { fn state(&self) -> Result> { - let all_values = to_scalar_values(&self.arrays)?; + let all_values = self + .all_values + .iter() + .map(|x| ScalarValue::new_primitive::(Some(*x), &self.data_type)) + .collect(); let state = ScalarValue::new_list(Some(all_values), self.data_type.clone()); Ok(vec![state]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - assert_eq!(values.len(), 1); - let array = &values[0]; - - // Defer conversions to scalar values to final evaluation. - assert_eq!(array.data_type(), &self.data_type); - self.arrays.push(array.clone()); - + let values = values[0].as_primitive::(); + self.all_values.reserve(values.len() - values.null_count()); + self.all_values.extend(values.iter().flatten()); Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - assert_eq!(states.len(), 1); - - let array = &states[0]; - assert!(matches!(array.data_type(), DataType::List(_))); - for index in 0..array.len() { - match ScalarValue::try_from_array(array, index)? { - ScalarValue::List(Some(mut values), _) => { - self.all_values.append(&mut values); - } - ScalarValue::List(None, _) => {} // skip empty state - v => { - return internal_err!( - "unexpected state in median. Expected DataType::List, got {v:?}" - ) - } - } + let array = states[0].as_list::(); + for v in array.iter().flatten() { + self.update_batch(&[v])? } Ok(()) } fn evaluate(&self) -> Result { - let batch_values = to_scalar_values(&self.arrays)?; - - if !self - .all_values - .iter() - .chain(batch_values.iter()) - .any(|v| !v.is_null()) - { - return ScalarValue::try_from(&self.data_type); - } - - // Create an array of all the non null values and find the - // sorted indexes - let array = ScalarValue::iter_to_array( - self.all_values - .iter() - .chain(batch_values.iter()) - // ignore null values - .filter(|v| !v.is_null()) - .cloned(), - )?; - - // find the mid point - let len = array.len(); - let mid = len / 2; - - // only sort up to the top size/2 elements - let limit = Some(mid + 1); - let options = None; - let indices = sort_to_indices(&array, options, limit)?; - - // pick the relevant indices in the original arrays - let result = if len >= 2 && len % 2 == 0 { - // even number of values, average the two mid points - let s1 = scalar_at_index(&array, &indices, mid - 1)?; - let s2 = scalar_at_index(&array, &indices, mid)?; - match s1.add(s2)? { - ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(v / 2)), - ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(v / 2)), - ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(v / 2)), - ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(v / 2)), - ScalarValue::UInt8(Some(v)) => ScalarValue::UInt8(Some(v / 2)), - ScalarValue::UInt16(Some(v)) => ScalarValue::UInt16(Some(v / 2)), - ScalarValue::UInt32(Some(v)) => ScalarValue::UInt32(Some(v / 2)), - ScalarValue::UInt64(Some(v)) => ScalarValue::UInt64(Some(v / 2)), - ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(v / 2.0)), - ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(v / 2.0)), - ScalarValue::Decimal128(Some(v), p, s) => { - ScalarValue::Decimal128(Some(v / 2), p, s) - } - v => { - return internal_err!("Unsupported type in MedianAccumulator: {v:?}") - } - } + // TODO: evaluate could pass &mut self + let mut d = self.all_values.clone(); + let cmp = |x: &T::Native, y: &T::Native| x.compare(*y); + + let len = d.len(); + let median = if len == 0 { + None + } else if len % 2 == 0 { + let (low, high, _) = d.select_nth_unstable_by(len / 2, cmp); + let (_, low, _) = low.select_nth_unstable_by(low.len() - 1, cmp); + let median = low.add_wrapping(*high).div_wrapping(T::Native::usize_as(2)); + Some(median) } else { - // odd number of values, pick that one - scalar_at_index(&array, &indices, mid)? + let (_, median, _) = d.select_nth_unstable_by(len / 2, cmp); + Some(*median) }; - - Ok(result) + Ok(ScalarValue::new_primitive::(median, &self.data_type)) } fn size(&self) -> usize { - let arrays_size: usize = self.arrays.iter().map(|a| a.len()).sum(); - std::mem::size_of_val(self) - + ScalarValue::size_of_vec(&self.all_values) - + arrays_size - - std::mem::size_of_val(&self.all_values) - + self.data_type.size() - - std::mem::size_of_val(&self.data_type) - } -} - -fn to_scalar_values(arrays: &[ArrayRef]) -> Result> { - let num_values: usize = arrays.iter().map(|a| a.len()).sum(); - let mut all_values = Vec::with_capacity(num_values); - - for array in arrays { - for index in 0..array.len() { - all_values.push(ScalarValue::try_from_array(&array, index)?); - } + + self.all_values.capacity() * std::mem::size_of::() } - - Ok(all_values) -} - -/// Given a returns `array[indicies[indicie_index]]` as a `ScalarValue` -fn scalar_at_index( - array: &dyn Array, - indices: &UInt32Array, - indicies_index: usize, -) -> Result { - let array_index = indices - .value(indicies_index) - .try_into() - .expect("Convert uint32 to usize"); - ScalarValue::try_from_array(array, array_index) } #[cfg(test)] @@ -269,7 +205,6 @@ mod tests { use crate::generic_test_op; use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; - use datafusion_common::Result; #[test] fn median_decimal() -> Result<()> { From 5893150c4a7a50841123ee7a567afb680caecd6b Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Tue, 22 Aug 2023 18:35:22 +0100 Subject: [PATCH 2/2] Tweak memory limit test --- datafusion/core/tests/memory_limit.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 623a6f71f451..f722addf6fb2 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -68,12 +68,12 @@ async fn oom_sort() { #[tokio::test] async fn group_by_none() { TestCase::new() - .with_query("select median(image) from t") + .with_query("select median(request_bytes) from t") .with_expected_errors(vec![ "Resources exhausted: Failed to allocate additional", "AggregateStream", ]) - .with_memory_limit(20_000) + .with_memory_limit(2_000) .run() .await }