Skip to content

Commit

Permalink
better group value view aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Jul 17, 2024
1 parent 1cf55e4 commit 67618a7
Show file tree
Hide file tree
Showing 7 changed files with 237 additions and 20 deletions.
4 changes: 4 additions & 0 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use ahash::RandomState;
use datafusion_physical_expr_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use std::collections::HashSet;
use std::ops::BitAnd;
use std::{fmt::Debug, sync::Arc};
Expand Down Expand Up @@ -230,6 +231,9 @@ impl AggregateUDFImpl for Count {
DataType::Utf8 => {
Box::new(BytesDistinctCountAccumulator::<i32>::new(OutputType::Utf8))
}
DataType::Utf8View => {
Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8))
}
DataType::LargeUtf8 => {
Box::new(BytesDistinctCountAccumulator::<i64>::new(OutputType::Utf8))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values

use crate::binary_map::{ArrowBytesSet, OutputType};
use crate::binary_view_map::ArrowBytesViewSet;
use arrow::array::{ArrayRef, OffsetSizeTrait};
use datafusion_common::cast::as_list_array;
use datafusion_common::utils::array_into_list_array_nullable;
Expand Down Expand Up @@ -88,3 +89,63 @@ impl<O: OffsetSizeTrait> Accumulator for BytesDistinctCountAccumulator<O> {
std::mem::size_of_val(self) + self.0.size()
}
}

/// Specialized implementation of
/// `COUNT DISTINCT` for [`StringViewArray`] and [`BinaryViewArray`].
///
/// [`StringViewArray`]: arrow::array::StringViewArray
/// [`BinaryViewArray`]: arrow::array::BinaryViewArray
#[derive(Debug)]
pub struct BytesViewDistinctCountAccumulator(ArrowBytesViewSet);

impl BytesViewDistinctCountAccumulator {
pub fn new(output_type: OutputType) -> Self {
Self(ArrowBytesViewSet::new(output_type))
}
}

impl Accumulator for BytesViewDistinctCountAccumulator {
fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
let set = self.0.take();
let arr = set.into_state();
let list = Arc::new(array_into_list_array_nullable(arr));
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
if values.is_empty() {
return Ok(());
}

self.0.insert(&values[0]);

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
self.0.insert(&list);
};
Ok(())
})
}

fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.0.non_null_len() as i64)))
}

fn size(&self) -> usize {
std::mem::size_of_val(self) + self.0.size()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ mod bytes;
mod native;

pub use bytes::BytesDistinctCountAccumulator;
pub use bytes::BytesViewDistinctCountAccumulator;
pub use native::FloatDistinctCountAccumulator;
pub use native::PrimitiveDistinctCountAccumulator;
6 changes: 6 additions & 0 deletions datafusion/physical-expr-common/src/binary_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ use std::sync::Arc;
pub enum OutputType {
/// `StringArray` or `LargeStringArray`
Utf8,
/// `StringViewArray`
Utf8View,
/// `BinaryArray` or `LargeBinaryArray`
Binary,
/// `BinaryViewArray`
BinaryView,
}

/// HashSet optimized for storing string or binary values that can produce that
Expand Down Expand Up @@ -318,6 +322,7 @@ where
observe_payload_fn,
)
}
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
};
}

Expand Down Expand Up @@ -516,6 +521,7 @@ where
GenericStringArray::new_unchecked(offsets, values, nulls)
})
}
_ => unreachable!("View types should use `ArrowBytesViewMap`"),
}
}

Expand Down
23 changes: 14 additions & 9 deletions datafusion/physical-expr-common/src/binary_view_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,9 @@ use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
use std::fmt::Debug;
use std::sync::Arc;

/// Should the output be a String or Binary?
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputType {
/// `StringArray` or `LargeStringArray`
Utf8View,
/// `BinaryArray` or `LargeBinaryArray`
BinaryView,
}
use crate::binary_map::OutputType;

/// HashSet optimized for storing string or binary values that can produce that
//// HashSet optimized for storing string or binary values that can produce that
/// the final set as a GenericStringArray with minimal copies.
#[derive(Debug)]
pub struct ArrowBytesViewSet(ArrowBytesViewMap<()>);
Expand All @@ -53,6 +46,14 @@ impl ArrowBytesViewSet {
.insert_if_new(values, make_payload_fn, observe_payload_fn);
}

/// Return the contents of this set and replace it with a new empty set with
/// the same output type
pub fn take(&mut self) -> Self {
let mut new_self = Self::new(self.0.output_type);
std::mem::swap(self, &mut new_self);
new_self
}

/// Converts this set into a `StringArray`/`LargeStringArray` or
/// `BinaryArray`/`LargeBinaryArray` containing each distinct value that
/// was interned. This is done without copying the values.
Expand Down Expand Up @@ -214,6 +215,7 @@ where
observe_payload_fn,
)
}
_ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"),
};
}

Expand Down Expand Up @@ -325,6 +327,9 @@ where
let array = unsafe { array.to_string_view_unchecked() };
Arc::new(array)
}
_ => {
unreachable!("Utf8/Binary should use `ArrowBytesMap`")
}
}
}

Expand Down
129 changes: 129 additions & 0 deletions datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::aggregates::group_values::GroupValues;
use arrow_array::{Array, ArrayRef, RecordBatch};
use datafusion_expr::EmitTo;
use datafusion_physical_expr::binary_map::OutputType;
use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap;

/// A [`GroupValues`] storing single column of Utf8View/BinaryView values
///
/// This specialization is significantly faster than using the more general
/// purpose `Row`s format
pub struct GroupValuesBytesView {
/// Map string/binary values to group index
map: ArrowBytesViewMap<usize>,
/// The total number of groups so far (used to assign group_index)
num_groups: usize,
}

impl GroupValuesBytesView {
pub fn new(output_type: OutputType) -> Self {
Self {
map: ArrowBytesViewMap::new(output_type),
num_groups: 0,
}
}
}

impl GroupValues for GroupValuesBytesView {
fn intern(
&mut self,
cols: &[ArrayRef],
groups: &mut Vec<usize>,
) -> datafusion_common::Result<()> {
assert_eq!(cols.len(), 1);

// look up / add entries in the table
let arr = &cols[0];

groups.clear();
self.map.insert_if_new(
arr,
// called for each new group
|_value| {
// assign new group index on each insert
let group_idx = self.num_groups;
self.num_groups += 1;
group_idx
},
// called for each group
|group_idx| {
groups.push(group_idx);
},
);

// ensure we assigned a group to for each row
assert_eq!(groups.len(), arr.len());
Ok(())
}

fn size(&self) -> usize {
self.map.size() + std::mem::size_of::<Self>()
}

fn is_empty(&self) -> bool {
self.num_groups == 0
}

fn len(&self) -> usize {
self.num_groups
}

fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result<Vec<ArrayRef>> {
// Reset the map to default, and convert it into a single array
let map_contents = self.map.take().into_state();

let group_values = match emit_to {
EmitTo::All => {
self.num_groups -= map_contents.len();
map_contents
}
EmitTo::First(n) if n == self.len() => {
self.num_groups -= map_contents.len();
map_contents
}
EmitTo::First(n) => {
// if we only wanted to take the first n, insert the rest back
// into the map we could potentially avoid this reallocation, at
// the expense of much more complex code.
// see https://github.com/apache/datafusion/issues/9195
let emit_group_values = map_contents.slice(0, n);
let remaining_group_values =
map_contents.slice(n, map_contents.len() - n);

self.num_groups = 0;
let mut group_indexes = vec![];
self.intern(&[remaining_group_values], &mut group_indexes)?;

// Verify that the group indexes were assigned in the correct order
assert_eq!(0, group_indexes[0]);

emit_group_values
}
};

Ok(vec![group_values])
}

fn clear_shrink(&mut self, _batch: &RecordBatch) {
// in theory we could potentially avoid this reallocation and clear the
// contents of the maps, but for now we just reset the map from the beginning
self.map.take();
}
}
33 changes: 22 additions & 11 deletions datafusion/physical-plan/src/aggregates/group_values/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use arrow::record_batch::RecordBatch;
use arrow_array::{downcast_primitive, ArrayRef};
use arrow_schema::{DataType, SchemaRef};
use bytes_view::GroupValuesBytesView;
use datafusion_common::Result;

pub(crate) mod primitive;
Expand All @@ -28,6 +29,7 @@ mod row;
use row::GroupValuesRows;

mod bytes;
mod bytes_view;
use bytes::GroupValuesByes;
use datafusion_physical_expr::binary_map::OutputType;

Expand Down Expand Up @@ -67,17 +69,26 @@ pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
_ => {}
}

if let DataType::Utf8 = d {
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Utf8)));
}
if let DataType::LargeUtf8 = d {
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Utf8)));
}
if let DataType::Binary = d {
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Binary)));
}
if let DataType::LargeBinary = d {
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Binary)));
match d {
DataType::Utf8 => {
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Utf8)));
}
DataType::LargeUtf8 => {
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Utf8)));
}
DataType::Utf8View => {
return Ok(Box::new(GroupValuesBytesView::new(OutputType::Utf8View)));
}
DataType::Binary => {
return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Binary)));
}
DataType::LargeBinary => {
return Ok(Box::new(GroupValuesByes::<i64>::new(OutputType::Binary)));
}
DataType::BinaryView => {
return Ok(Box::new(GroupValuesBytesView::new(OutputType::BinaryView)));
}
_ => {}
}
}

Expand Down

0 comments on commit 67618a7

Please sign in to comment.