From 20f4af664fc251fbd67cd761cfc6abcad0769979 Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Mon, 7 Oct 2024 22:15:37 +0200 Subject: [PATCH 1/2] simple mode f --- Cargo.toml | 6 +- src/common/collections.rs | 22 - src/common/collections/binary_map.rs | 1053 --------------------- src/common/collections/binary_view_map.rs | 765 --------------- src/common/mod.rs | 1 - src/common/mode.rs | 1 - src/common/mode/bytes.rs | 290 ++---- src/common/mode/native.rs | 12 +- src/mode.rs | 12 +- tests/main.rs | 16 +- 10 files changed, 102 insertions(+), 2076 deletions(-) delete mode 100644 src/common/collections.rs delete mode 100644 src/common/collections/binary_map.rs delete mode 100644 src/common/collections/binary_view_map.rs diff --git a/Cargo.toml b/Cargo.toml index df9fa2d..fd41e8c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,14 +31,10 @@ name = "datafusion_functions_extra" path = "src/lib.rs" [dependencies] -ahash = { version = "0.8", default-features = false, features = [ - "runtime-rng", -] } +arrow = { version = "53.0.0", features = ["test_utils"] } datafusion = "42" -hashbrown = { version = "0.14.5", features = ["raw"] } log = "^0.4" paste = "1" -arrow = { version = "53.0.0", features = ["test_utils"] } [dev-dependencies] arrow = { version = "53.0.0", features = ["test_utils"] } diff --git a/src/common/collections.rs b/src/common/collections.rs deleted file mode 100644 index d197ce5..0000000 --- a/src/common/collections.rs +++ /dev/null @@ -1,22 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -mod binary_map; -mod binary_view_map; - -pub use binary_map::ArrowBytesMap; -pub use binary_view_map::ArrowBytesViewMap; diff --git a/src/common/collections/binary_map.rs b/src/common/collections/binary_map.rs deleted file mode 100644 index bba6dd2..0000000 --- a/src/common/collections/binary_map.rs +++ /dev/null @@ -1,1053 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ArrowBytesMap`] and [`ArrowBytesSet`] for storing maps/sets of values from -//! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray. - -use ahash::RandomState; - -use arrow::array::cast::AsArray; -use arrow::array::types::{ByteArrayType, GenericBinaryType, GenericStringType}; -use arrow::array::{ - Array, ArrayRef, BooleanBufferBuilder, BufferBuilder, GenericBinaryArray, GenericStringArray, OffsetSizeTrait, -}; -use arrow::buffer::{NullBuffer, OffsetBuffer, ScalarBuffer}; -use arrow::datatypes::DataType; -use datafusion::arrow; -use datafusion::common::hash_utils::create_hashes; -use datafusion::common::utils::proxy::{RawTableAllocExt, VecAllocExt}; -use datafusion::physical_expr::binary_map::OutputType; -use std::any::type_name; -use std::fmt::Debug; -use std::mem; -use std::ops::Range; -use std::sync::Arc; - -/// Optimized map for storing Arrow "bytes" types (`String`, `LargeString`, -/// `Binary`, and `LargeBinary`) values that can produce the set of keys on -/// output as `GenericBinaryArray` without copies. -/// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. -/// -/// # Generic Arguments -/// -/// * `O`: OffsetSize (String/LargeString) -/// * `V`: payload type -/// -/// # Description -/// -/// This is a specialized HashMap with the following properties: -/// -/// 1. Optimized for storing and emitting Arrow byte types (e.g. -/// `StringArray` / `BinaryArray`) very efficiently by minimizing copying of -/// the string values themselves, both when inserting and when emitting the -/// final array. -/// -/// -/// 2. Retains the insertion order of entries in the final array. The values are -/// in the same order as they were inserted. -/// -/// Note this structure can be used as a `HashSet` by specifying the value type -/// as `()`, as is done by [`ArrowBytesSet`]. -/// -/// This map is used by the special `COUNT DISTINCT` aggregate function to -/// store the distinct values, and by the `GROUP BY` operator to store -/// group values when they are a single string array. -/// -/// # Example -/// -/// The following diagram shows how the map would store the four strings -/// "Foo", NULL, "Bar", "TheQuickBrownFox": -/// -/// * `hashtable` stores entries for each distinct string that has been -/// inserted. The entries contain the payload as well as information about the -/// value (either an offset or the actual bytes, see `Entry` docs for more -/// details) -/// -/// * `offsets` stores offsets into `buffer` for each distinct string value, -/// following the same convention as the offsets in a `StringArray` or -/// `LargeStringArray`. -/// -/// * `buffer` stores the actual byte data -/// -/// * `null`: stores the index and payload of the null value, in this case the -/// second value (index 1) -/// -/// ```text -/// ┌───────────────────────────────────┐ ┌─────┐ ┌────┐ -/// │ ... │ │ 0 │ │FooB│ -/// │ ┌──────────────────────────────┐ │ │ 0 │ │arTh│ -/// │ │ │ │ │ 3 │ │eQui│ -/// │ │ len: 3 │ │ │ 3 │ │ckBr│ -/// │ │ offset_or_inline: "Bar" │ │ │ 6 │ │ownF│ -/// │ │ payload:... │ │ │ │ │ox │ -/// │ └──────────────────────────────┘ │ │ │ │ │ -/// │ ... │ └─────┘ └────┘ -/// │ ┌──────────────────────────────┐ │ -/// │ ││ │ offsets buffer -/// │ │ len: 16 │ │ -/// │ │ offset_or_inline: 6 │ │ ┌───────────────┐ -/// │ │ payload: ... │ │ │ Some(1) │ -/// │ └──────────────────────────────┘ │ │ payload: ... │ -/// │ ... │ └───────────────┘ -/// └───────────────────────────────────┘ -/// null -/// HashTable -/// ``` -/// -/// # Entry Format -/// -/// Entries stored in a [`ArrowBytesMap`] represents a value that is either -/// stored inline or in the buffer -/// -/// This helps the case where there are many short (less than 8 bytes) strings -/// that are the same (e.g. "MA", "CA", "NY", "TX", etc) -/// -/// ```text -/// ┌──────────────────┐ -/// ─ ─ ─ ─ ─ ─ ─▶│... │ -/// │ │TheQuickBrownFox │ -/// │... │ -/// │ │ │ -/// └──────────────────┘ -/// │ buffer of u8 -/// -/// │ -/// ┌────────────────┬───────────────┬───────────────┐ -/// Storing │ │ starting byte │ length, in │ -/// "TheQuickBrownFox" │ hash value │ offset in │ bytes (not │ -/// (long string) │ │ buffer │ characters) │ -/// └────────────────┴───────────────┴───────────────┘ -/// 8 bytes 8 bytes 4 or 8 -/// -/// -/// ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐ -/// Storing "foobar" │ │ │ │ │ │ │ │ │ │ length, in │ -/// (short string) │ hash value │?│?│f│o│o│b│a│r│ bytes (not │ -/// │ │ │ │ │ │ │ │ │ │ characters) │ -/// └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘ -/// 8 bytes 8 bytes 4 or 8 -/// ``` - -// TODO: Remove after DataFusion next release once insert_or_update and get_payloads are added to the collection. -// Copied from datafusion/physical-expr-common/binary_map.rs. -pub struct ArrowBytesMap -where - O: OffsetSizeTrait, - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - /// Should the output be String or Binary? - output_type: OutputType, - /// Underlying hash set for each distinct value - map: hashbrown::raw::RawTable>, - /// Total size of the map in bytes - map_size: usize, - /// In progress arrow `Buffer` containing all values - buffer: BufferBuilder, - /// Offsets into `buffer` for each distinct value. These offsets as used - /// directly to create the final `GenericBinaryArray`. The `i`th string is - /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values - /// are stored as a zero length string. - offsets: Vec, - /// random state used to generate hashes - random_state: RandomState, - /// buffer that stores hash values (reused across batches to save allocations) - hashes_buffer: Vec, - /// `(payload, null_index)` for the 'null' value, if any - /// NOTE null_index is the logical index in the final array, not the index - /// in the buffer - null: Option<(V, usize)>, -} - -/// The size, in number of entries, of the initial hash table -const INITIAL_MAP_CAPACITY: usize = 128; -/// The initial size, in bytes, of the string data -const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024; -impl ArrowBytesMap -where - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - pub fn new(output_type: OutputType) -> Self { - Self { - output_type, - map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), - map_size: 0, - buffer: BufferBuilder::new(INITIAL_BUFFER_CAPACITY), - offsets: vec![O::default()], // first offset is always 0 - random_state: RandomState::new(), - hashes_buffer: vec![], - null: None, - } - } - - /// Return the contents of this map and replace it with a new empty map with - /// the same output type - pub fn take(&mut self) -> Self { - let mut new_self = Self::new(self.output_type); - mem::swap(self, &mut new_self); - new_self - } - - /// Inserts each value from `values` into the map, invoking `payload_fn` for - /// each value if *not* already present, deferring the allocation of the - /// payload until it is needed. - /// - /// Note that this is different than a normal map that would replace the - /// existing entry - /// - /// # Arguments: - /// - /// `values`: array whose values are inserted - /// - /// `make_payload_fn`: invoked for each value that is not already present - /// to create the payload, in order of the values in `values` - /// - /// `observe_payload_fn`: invoked once, for each value in `values`, that was - /// already present in the map, with corresponding payload value. - /// - /// # Returns - /// - /// The payload value for the entry, either the existing value or - /// the newly inserted value - /// - /// # Safety: - /// - /// Note that `make_payload_fn` and `observe_payload_fn` are only invoked - /// with valid values from `values`, not for the `NULL` value. - pub fn insert_if_new(&mut self, values: &ArrayRef, make_payload_fn: MP, observe_payload_fn: OP) - where - MP: FnMut(Option<&[u8]>) -> V, - OP: FnMut(V), - { - // Sanity array type - match self.output_type { - OutputType::Binary => { - assert!(matches!(values.data_type(), DataType::Binary | DataType::LargeBinary)); - self.insert_if_new_inner::>(values, make_payload_fn, observe_payload_fn) - } - OutputType::Utf8 => { - assert!(matches!(values.data_type(), DataType::Utf8 | DataType::LargeUtf8)); - self.insert_if_new_inner::>(values, make_payload_fn, observe_payload_fn) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - }; - } - - /// Generic version of [`Self::insert_if_new`] that handles `ByteArrayType` - /// (both String and Binary) - /// - /// Note this is the only function that is generic on [`ByteArrayType`], which - /// avoids having to template the entire structure, making the code - /// simpler and understand and reducing code bloat due to duplication. - /// - /// See comments on `insert_if_new` for more details - fn insert_if_new_inner(&mut self, values: &ArrayRef, mut make_payload_fn: MP, mut observe_payload_fn: OP) - where - MP: FnMut(Option<&[u8]>) -> V, - OP: FnMut(V), - B: ByteArrayType, - { - // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); - - // step 2: insert each value into the set, if not already present - let values = values.as_bytes::(); - - // Ensure lengths are equivalent - assert_eq!(values.len(), batch_hashes.len()); - - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // handle null value - let Some(value) = value else { - let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { - payload - } else { - let payload = make_payload_fn(None); - let null_index = self.offsets.len() - 1; - // nulls need a zero length in the offset buffer - let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - self.null = Some((payload, null_index)); - payload - }; - observe_payload_fn(payload); - continue; - }; - - // get the value as bytes - let value: &[u8] = value.as_ref(); - let value_len = O::usize_as(value.len()); - - // value is "small" - let payload = if value.len() <= SHORT_VALUE_LEN { - let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); - - // is value is already present in the set? - let entry = self.map.get_mut(hash, |header| { - // compare value if hashes match - if header.len != value_len { - return false; - } - // value is stored inline so no need to consult buffer - // (this is the "small string optimization") - inline == header.offset_or_inline - }); - - if let Some(entry) = entry { - entry.payload - } - // if no existing entry, make a new one - else { - // Put the small values into buffer and offsets so it appears - // the output array, but store the actual bytes inline for - // comparison - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); - let payload = make_payload_fn(Some(value)); - let new_header = Entry { - hash, - len: value_len, - offset_or_inline: inline, - payload, - }; - self.map - .insert_accounted(new_header, |header| header.hash, &mut self.map_size); - payload - } - } - // value is not "small" - else { - // Check if the value is already present in the set - let entry = self.map.get_mut(hash, |header| { - // compare value if hashes match - if header.len != value_len { - return false; - } - // Need to compare the bytes in the buffer - // SAFETY: buffer is only appended to, and we correctly inserted values and offsets - let existing_value = unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; - value == existing_value - }); - - if let Some(entry) = entry { - entry.payload - } - // if no existing entry, make a new one - else { - // Put the small values into buffer and offsets so it - // appears the output array, and store that offset - // so the bytes can be compared if needed - let offset = self.buffer.len(); // offset of start for data - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); - - let payload = make_payload_fn(Some(value)); - let new_header = Entry { - hash, - len: value_len, - offset_or_inline: offset, - payload, - }; - self.map - .insert_accounted(new_header, |header| header.hash, &mut self.map_size); - payload - } - }; - observe_payload_fn(payload); - } - // Check for overflow in offsets (if more data was sent than can be represented) - if O::from_usize(self.buffer.len()).is_none() { - panic!( - "Put {} bytes in buffer, more than can be represented by a {}", - self.buffer.len(), - type_name::() - ); - } - } - - /// Inserts each value from `values` into the map, invoking `make_payload_fn` for - /// each value if not already present, or `update_payload_fn` if the value already exists. - /// - /// This function handles both the insert and update cases. - /// - /// # Arguments: - /// - /// `values`: The array whose values are inserted or updated in the map. - /// - /// `make_payload_fn`: Invoked for each value that is not already present - /// to create the payload, in the order of the values in `values`. - /// - /// `update_payload_fn`: Invoked for each value that is already present, - /// allowing the payload to be updated in-place. - /// - /// # Safety: - /// - /// Note that `make_payload_fn` and `update_payload_fn` are only invoked - /// with valid values from `values`, not for the `NULL` value. - pub fn insert_or_update(&mut self, values: &ArrayRef, make_payload_fn: MP, update_payload_fn: UP) - where - MP: FnMut(Option<&[u8]>) -> V, - UP: FnMut(&mut V), - { - // Check the output type and dispatch to the appropriate internal function - match self.output_type { - OutputType::Binary => { - assert!(matches!(values.data_type(), DataType::Binary | DataType::LargeBinary)); - self.insert_or_update_inner::>(values, make_payload_fn, update_payload_fn) - } - OutputType::Utf8 => { - assert!(matches!(values.data_type(), DataType::Utf8 | DataType::LargeUtf8)); - self.insert_or_update_inner::>(values, make_payload_fn, update_payload_fn) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - }; - } - - /// Generic version of [`Self::insert_or_update`] that handles `ByteArrayType` - /// (both String and Binary). - /// - /// This is the only function that is generic on [`ByteArrayType`], which avoids having - /// to template the entire structure, simplifying the code and reducing code bloat due - /// to duplication. - /// - /// See comments on `insert_or_update` for more details. - fn insert_or_update_inner( - &mut self, - values: &ArrayRef, - mut make_payload_fn: MP, - mut update_payload_fn: UP, - ) where - MP: FnMut(Option<&[u8]>) -> V, // Function to create a new entry - UP: FnMut(&mut V), // Function to update an existing entry - B: ByteArrayType, - { - // Step 1: Compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes).unwrap(); // Compute the hashes for the values - - // Step 2: Insert or update each value - let values = values.as_bytes::(); - - assert_eq!(values.len(), batch_hashes.len()); // Ensure hash count matches value count - - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // Handle null value - let Some(value) = value else { - if let Some((ref mut payload, _)) = self.null { - // If null is already present, update the payload - update_payload_fn(payload); - } else { - // Null value doesn't exist, so create a new one - let payload = make_payload_fn(None); - let null_index = self.offsets.len() - 1; - // Nulls need a zero length in the offset buffer - let offset = self.buffer.len(); - self.offsets.push(O::usize_as(offset)); - self.null = Some((payload, null_index)); - } - continue; - }; - - let value: &[u8] = value.as_ref(); - let value_len = O::usize_as(value.len()); - - // Small value optimization - if value.len() <= SHORT_VALUE_LEN { - let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); - - // Check if the value is already present in the set - let entry = self.map.get_mut(hash, |header| { - if header.len != value_len { - return false; - } - inline == header.offset_or_inline - }); - - if let Some(entry) = entry { - update_payload_fn(&mut entry.payload); - } else { - // Insert a new value if not found - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); - let payload = make_payload_fn(Some(value)); - let new_entry = Entry { - hash, - len: value_len, - offset_or_inline: inline, - payload, - }; - self.map - .insert_accounted(new_entry, |header| header.hash, &mut self.map_size); - } - } else { - // Handle larger values - let entry = self.map.get_mut(hash, |header| { - if header.len != value_len { - return false; - } - let existing_value = unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; - value == existing_value - }); - - if let Some(entry) = entry { - update_payload_fn(&mut entry.payload); - } else { - // Insert a new large value if not found - let offset = self.buffer.len(); - self.buffer.append_slice(value); - self.offsets.push(O::usize_as(self.buffer.len())); - let payload = make_payload_fn(Some(value)); - let new_entry = Entry { - hash, - len: value_len, - offset_or_inline: offset, - payload, - }; - self.map - .insert_accounted(new_entry, |header| header.hash, &mut self.map_size); - } - }; - } - - // Ensure no overflow in offsets - if O::from_usize(self.buffer.len()).is_none() { - panic!( - "Put {} bytes in buffer, more than can be represented by a {}", - self.buffer.len(), - type_name::() - ); - } - } - - /// Generic version of [`Self::get_payloads`] that handles `ByteArrayType` - /// (both String and Binary). - /// - /// This function computes the hashes for each value and retrieves the payloads - /// stored in the map, leveraging small value optimizations when possible. - /// - /// # Arguments: - /// - /// `values`: The array whose payloads are being retrieved. - /// - /// # Returns - /// - /// A vector of payloads for each value, or `None` if the value is not found. - /// - /// # Safety: - /// - /// This function ensures that small values are handled using inline optimization - /// and larger values are safely retrieved from the buffer. - fn get_payloads_inner(self, values: &ArrayRef) -> Vec> - where - B: ByteArrayType, - { - // Step 1: Compute hashes - let mut batch_hashes = vec![0u64; values.len()]; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, &mut batch_hashes).unwrap(); // Compute the hashes for the values - - // Step 2: Get payloads for each value - let values = values.as_bytes::(); - assert_eq!(values.len(), batch_hashes.len()); // Ensure hash count matches value count - - let mut payloads = Vec::with_capacity(values.len()); - - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // Handle null value - let Some(value) = value else { - if let Some(&(payload, _)) = self.null.as_ref() { - payloads.push(Some(payload)); - } else { - payloads.push(None); - } - continue; - }; - - let value: &[u8] = value.as_ref(); - let value_len = O::usize_as(value.len()); - - // Small value optimization - let payload = if value.len() <= SHORT_VALUE_LEN { - let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); - - // Check if the value is already present in the set - let entry = self.map.get(hash, |header| { - if header.len != value_len { - return false; - } - inline == header.offset_or_inline - }); - - entry.map(|entry| entry.payload) - } else { - // Handle larger values - let entry = self.map.get(hash, |header| { - if header.len != value_len { - return false; - } - let existing_value = unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; - value == existing_value - }); - - entry.map(|entry| entry.payload) - }; - - payloads.push(payload); - } - - payloads - } - - /// Retrieves the payloads for each value from `values`, either by using - /// small value optimizations or larger value handling. - /// - /// This function will compute hashes for each value and attempt to retrieve - /// the corresponding payload from the map. If the value is not found, it will return `None`. - /// - /// # Arguments: - /// - /// `values`: The array whose payloads need to be retrieved. - /// - /// # Returns - /// - /// A vector of payloads for each value, or `None` if the value is not found. - /// - /// # Safety: - /// - /// This function handles both small and large values in a safe manner, though `unsafe` code is - /// used internally for performance optimization. - pub fn get_payloads(self, values: &ArrayRef) -> Vec> { - match self.output_type { - OutputType::Binary => { - assert!(matches!(values.data_type(), DataType::Binary | DataType::LargeBinary)); - self.get_payloads_inner::>(values) - } - OutputType::Utf8 => { - assert!(matches!(values.data_type(), DataType::Utf8 | DataType::LargeUtf8)); - self.get_payloads_inner::>(values) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - } - } - - /// Converts this set into a `StringArray`, `LargeStringArray`, - /// `BinaryArray`, or `LargeBinaryArray` containing each distinct value - /// that was inserted. This is done without copying the values. - /// - /// The values are guaranteed to be returned in the same order in which - /// they were first seen. - pub fn into_state(self) -> ArrayRef { - let Self { - output_type, - map: _, - map_size: _, - offsets, - mut buffer, - random_state: _, - hashes_buffer: _, - null, - } = self; - - // Only make a `NullBuffer` if there was a null value - let nulls = null.map(|(_payload, null_index)| { - let num_values = offsets.len() - 1; - single_null_buffer(num_values, null_index) - }); - // SAFETY: the offsets were constructed correctly in `insert_if_new` -- - // monotonically increasing, overflows were checked. - let offsets = unsafe { OffsetBuffer::new_unchecked(ScalarBuffer::from(offsets)) }; - let values = buffer.finish(); - - match output_type { - OutputType::Binary => { - // SAFETY: the offsets were constructed correctly - Arc::new(unsafe { GenericBinaryArray::new_unchecked(offsets, values, nulls) }) - } - OutputType::Utf8 => { - // SAFETY: - // 1. the offsets were constructed safely - // - // 2. we asserted the input arrays were all the correct type and - // thus since all the values that went in were valid (e.g. utf8) - // so are all the values that come out - Arc::new(unsafe { GenericStringArray::new_unchecked(offsets, values, nulls) }) - } - _ => unreachable!("View types should use `ArrowBytesViewMap`"), - } - } - - /// Total number of entries (including null, if present) - pub fn len(&self) -> usize { - self.non_null_len() + self.null.map(|_| 1).unwrap_or(0) - } - - /// Is the set empty? - pub fn is_empty(&self) -> bool { - self.map.is_empty() && self.null.is_none() - } - - /// Number of non null entries - pub fn non_null_len(&self) -> usize { - self.map.len() - } - - /// Return the total size, in bytes, of memory used to store the data in - /// this set, not including `self` - pub fn size(&self) -> usize { - self.map_size - + self.buffer.capacity() * mem::size_of::() - + self.offsets.allocated_size() - + self.hashes_buffer.allocated_size() - } -} - -/// Returns a `NullBuffer` with a single null value at the given index -fn single_null_buffer(num_values: usize, null_index: usize) -> NullBuffer { - let mut bool_builder = BooleanBufferBuilder::new(num_values); - bool_builder.append_n(num_values, true); - bool_builder.set_bit(null_index, false); - NullBuffer::from(bool_builder.finish()) -} - -impl Debug for ArrowBytesMap -where - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ArrowBytesMap") - .field("map", &"") - .field("map_size", &self.map_size) - .field("buffer", &self.buffer) - .field("random_state", &self.random_state) - .field("hashes_buffer", &self.hashes_buffer) - .finish() - } -} - -/// Maximum size of a value that can be inlined in the hash table -const SHORT_VALUE_LEN: usize = mem::size_of::(); - -/// Entry in the hash table -- see [`ArrowBytesMap`] for more details -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] -struct Entry -where - O: OffsetSizeTrait, - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - /// hash of the value (stored to avoid recomputing it in hash table check) - hash: u64, - /// if len =< [`SHORT_VALUE_LEN`]: the data inlined - /// if len > [`SHORT_VALUE_LEN`], the offset of where the data starts - offset_or_inline: usize, - /// length of the value, in bytes (use O here so we use only i32 for - /// strings, rather 64 bit usize) - len: O, - /// value stored by the entry - payload: V, -} - -impl Entry -where - O: OffsetSizeTrait, - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - /// returns self.offset..self.offset + self.len - #[inline(always)] - fn range(&self) -> Range { - self.offset_or_inline..self.offset_or_inline + self.len.as_usize() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use arrow::array::StringArray; - use datafusion::arrow; - use std::collections::HashMap; - - #[test] - fn test_insert_or_update_count_u8() { - let input = vec![ - Some("A"), - Some("bcdefghijklmnop"), - Some("X"), - Some("Y"), - None, - Some("qrstuvqxyzhjwya"), - Some("✨🔥"), - Some("🔥"), - Some("🔥🔥🔥🔥🔥🔥"), - Some("A"), // Duplicate to test the count increment - Some("Y"), // Another duplicate to test the count increment - ]; - - let mut map: ArrowBytesMap = ArrowBytesMap::new(OutputType::Utf8); - - let string_array = StringArray::from(input.clone()); - let arr: ArrayRef = Arc::new(string_array); - - map.insert_or_update( - &arr, - |_| 1u8, - |count| { - *count += 1; - }, - ); - - let expected_counts = [ - ("A", 2), - ("bcdefghijklmnop", 1), - ("X", 1), - ("Y", 2), - ("qrstuvqxyzhjwya", 1), - ("✨🔥", 1), - ("🔥", 1), - ("🔥🔥🔥🔥🔥🔥", 1), - ]; - - for &value in input.iter() { - if let Some(value) = value { - let string_array = StringArray::from(vec![Some(value)]); - let arr: ArrayRef = Arc::new(string_array); - - let mut result_payload: Option = None; - - map.insert_or_update( - &arr, - |_| { - panic!("Unexpected new entry during verification"); - }, - |count| { - result_payload = Some(*count); - }, - ); - - if let Some(expected_count) = expected_counts.iter().find(|&&(s, _)| s == value) { - assert_eq!(result_payload.unwrap(), expected_count.1); - } - } - } - } - - #[test] - fn test_insert_if_new_after_insert_or_update() { - let initial_values = StringArray::from(vec![Some("A"), Some("B"), Some("B"), Some("C"), Some("C")]); - - let mut map: ArrowBytesMap = ArrowBytesMap::new(OutputType::Utf8); - let arr: ArrayRef = Arc::new(initial_values); - - map.insert_or_update( - &arr, - |_| 1u8, - |count| { - *count += 1; - }, - ); - - let additional_values = StringArray::from(vec![Some("A"), Some("D"), Some("E")]); - let arr_additional: ArrayRef = Arc::new(additional_values); - - map.insert_if_new(&arr_additional, |_| 5u8, |_| {}); - - let combined_arr = StringArray::from(vec![Some("A"), Some("B"), Some("C"), Some("D"), Some("E")]); - - let arr_combined: ArrayRef = Arc::new(combined_arr); - let payloads = map.get_payloads(&arr_combined); - - let expected_payloads = [Some(1u8), Some(2u8), Some(2u8), Some(5u8), Some(5u8)]; - - assert_eq!(payloads, expected_payloads); - } - - #[test] - fn test_get_payloads_u8() { - let input = vec![ - Some("A"), - Some("bcdefghijklmnop"), - Some("X"), - Some("Y"), - None, - Some("qrstuvqxyzhjwya"), - Some("✨🔥"), - Some("🔥"), - Some("🔥🔥🔥🔥🔥🔥"), - Some("A"), // Duplicate to test the count increment - Some("Y"), // Another duplicate to test the count increment - ]; - - let mut map: ArrowBytesMap = ArrowBytesMap::new(OutputType::Utf8); - - let string_array = StringArray::from(input.clone()); - let arr: ArrayRef = Arc::new(string_array); - - map.insert_or_update( - &arr, - |_| 1u8, - |count| { - *count += 1; - }, - ); - - let expected_payloads = [ - Some(2u8), - Some(1u8), - Some(1u8), - Some(2u8), - Some(1u8), - Some(1u8), - Some(1u8), - Some(1u8), - Some(1u8), - Some(2u8), - Some(2u8), - ]; - - let payloads = map.get_payloads(&arr); - - assert_eq!(payloads.len(), expected_payloads.len()); - - for (i, payload) in payloads.iter().enumerate() { - assert_eq!(*payload, expected_payloads[i]); - } - } - - #[test] - fn test_map() { - let input = vec![ - // Note mix of short/long strings - Some("A"), - Some("bcdefghijklmnop"), - Some("X"), - Some("Y"), - None, - Some("qrstuvqxyzhjwya"), - Some("✨🔥"), - Some("🔥"), - Some("🔥🔥🔥🔥🔥🔥"), - ]; - - let mut test_map = TestMap::new(); - test_map.insert(&input); - test_map.insert(&input); // put it in twice - let expected_output: ArrayRef = Arc::new(StringArray::from(input)); - assert_eq!(&test_map.into_array(), &expected_output); - } - - #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] - struct TestPayload { - // store the string value to check against input - index: usize, // store the index of the string (each new string gets the next sequential input) - } - - /// Wraps an [`ArrowBytesMap`], validating its invariants - struct TestMap { - map: ArrowBytesMap, - // stores distinct strings seen, in order - strings: Vec>, - // map strings to index in strings - indexes: HashMap, usize>, - } - - impl Debug for TestMap { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TestMap") - .field("map", &"...") - .field("strings", &self.strings) - .field("indexes", &self.indexes) - .finish() - } - } - - impl TestMap { - /// creates a map with TestPayloads for the given strings and then - /// validates the payloads - fn new() -> Self { - Self { - map: ArrowBytesMap::new(OutputType::Utf8), - strings: vec![], - indexes: HashMap::new(), - } - } - - /// Inserts strings into the map - fn insert(&mut self, strings: &[Option<&str>]) { - let string_array = StringArray::from(strings.to_vec()); - let arr: ArrayRef = Arc::new(string_array); - - let mut next_index = self.indexes.len(); - let mut actual_new_strings = vec![]; - let mut actual_seen_indexes = vec![]; - // update self with new values, keeping track of newly added values - for str in strings { - let str = str.map(|s| s.to_string()); - let index = self.indexes.get(&str).cloned().unwrap_or_else(|| { - actual_new_strings.push(str.clone()); - let index = self.strings.len(); - self.strings.push(str.clone()); - self.indexes.insert(str, index); - index - }); - actual_seen_indexes.push(index); - } - - // insert the values into the map, recording what we did - let mut seen_new_strings = vec![]; - let mut seen_indexes = vec![]; - self.map.insert_if_new( - &arr, - |s| { - let value = s.map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string")); - let index = next_index; - next_index += 1; - seen_new_strings.push(value); - TestPayload { index } - }, - |payload| { - seen_indexes.push(payload.index); - }, - ); - - assert_eq!(actual_seen_indexes, seen_indexes); - assert_eq!(actual_new_strings, seen_new_strings); - } - - /// Call `self.map.into_array()` validating that the strings are in the same - /// order as they were inserted - fn into_array(self) -> ArrayRef { - let Self { - map, - strings, - indexes: _, - } = self; - - let arr = map.into_state(); - let expected: ArrayRef = Arc::new(StringArray::from(strings)); - assert_eq!(&arr, &expected); - arr - } - } -} diff --git a/src/common/collections/binary_view_map.rs b/src/common/collections/binary_view_map.rs deleted file mode 100644 index aafae8f..0000000 --- a/src/common/collections/binary_view_map.rs +++ /dev/null @@ -1,765 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! [`ArrowBytesViewMap`] and [`ArrowBytesViewSet`] for storing maps/sets of values from -//! `StringViewArray`/`BinaryViewArray`. -//! Much of the code is from `binary_map.rs`, but with simpler implementation because we directly use the -//! [`GenericByteViewBuilder`]. -use ahash::RandomState; -use arrow::array::cast::AsArray; -use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; -use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; -use datafusion::arrow; -use datafusion::common::hash_utils::create_hashes; -use datafusion::common::utils::proxy::{RawTableAllocExt, VecAllocExt}; -use datafusion::physical_expr::binary_map::OutputType; -use std::fmt::Debug; -use std::sync::Arc; - -/// Optimized map for storing Arrow "byte view" types (`StringView`, `BinaryView`) -/// values that can produce the set of keys on -/// output as `GenericBinaryViewArray` without copies. -/// -/// Equivalent to `HashSet` but with better performance for arrow -/// data. -/// -/// # Generic Arguments -/// -/// * `V`: payload type -/// -/// # Description -/// -/// This is a specialized HashMap with the following properties: -/// -/// 1. Optimized for storing and emitting Arrow byte types (e.g. -/// `StringViewArray` / `BinaryViewArray`) very efficiently by minimizing copying of -/// the string values themselves, both when inserting and when emitting the -/// final array. -/// -/// 2. Retains the insertion order of entries in the final array. The values are -/// in the same order as they were inserted. -/// -/// Note this structure can be used as a `HashSet` by specifying the value type -/// as `()`, as is done by [`ArrowBytesViewSet`]. -/// -/// This map is used by the special `COUNT DISTINCT` aggregate function to -/// store the distinct values, and by the `GROUP BY` operator to store -/// group values when they are a single string array. - -// TODO: Remove after DataFusion next release once insert_or_update and get_payloads are added to the collection. -// Copied from datafusion/physical-expr-common/binary_view_map.rs. -pub struct ArrowBytesViewMap -where - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - /// Should the output be StringView or BinaryView? - output_type: OutputType, - /// Underlying hash set for each distinct value - map: hashbrown::raw::RawTable>, - /// Total size of the map in bytes - map_size: usize, - - /// Builder for output array - builder: GenericByteViewBuilder, - /// random state used to generate hashes - random_state: RandomState, - /// buffer that stores hash values (reused across batches to save allocations) - hashes_buffer: Vec, - /// `(payload, null_index)` for the 'null' value, if any - /// NOTE null_index is the logical index in the final array, not the index - /// in the buffer - null: Option<(V, usize)>, -} - -/// The size, in number of entries, of the initial hash table -const INITIAL_MAP_CAPACITY: usize = 512; - -impl ArrowBytesViewMap -where - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - pub fn new(output_type: OutputType) -> Self { - Self { - output_type, - map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), - map_size: 0, - builder: GenericByteViewBuilder::new(), - random_state: RandomState::new(), - hashes_buffer: vec![], - null: None, - } - } - - /// Return the contents of this map and replace it with a new empty map with - /// the same output type - pub fn take(&mut self) -> Self { - let mut new_self = Self::new(self.output_type); - std::mem::swap(self, &mut new_self); - new_self - } - - /// Inserts each value from `values` into the map, invoking `payload_fn` for - /// each value if *not* already present, deferring the allocation of the - /// payload until it is needed. - /// - /// Note that this is different than a normal map that would replace the - /// existing entry - /// - /// # Arguments: - /// - /// `values`: array whose values are inserted - /// - /// `make_payload_fn`: invoked for each value that is not already present - /// to create the payload, in order of the values in `values` - /// - /// `observe_payload_fn`: invoked once, for each value in `values`, that was - /// already present in the map, with corresponding payload value. - /// - /// # Returns - /// - /// The payload value for the entry, either the existing value or - /// the newly inserted value - /// - /// # Safety: - /// - /// Note that `make_payload_fn` and `observe_payload_fn` are only invoked - /// with valid values from `values`, not for the `NULL` value. - pub fn insert_if_new(&mut self, values: &ArrayRef, make_payload_fn: MP, observe_payload_fn: OP) - where - MP: FnMut(Option<&[u8]>) -> V, - OP: FnMut(V), - { - // Sanity check array type - match self.output_type { - OutputType::BinaryView => { - assert!(matches!(values.data_type(), DataType::BinaryView)); - self.insert_if_new_inner::(values, make_payload_fn, observe_payload_fn) - } - OutputType::Utf8View => { - assert!(matches!(values.data_type(), DataType::Utf8View)); - self.insert_if_new_inner::(values, make_payload_fn, observe_payload_fn) - } - _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"), - }; - } - - /// Generic version of [`Self::insert_if_new`] that handles `ByteViewType` - /// (both StringView and BinaryView) - /// - /// Note this is the only function that is generic on [`ByteViewType`], which - /// avoids having to template the entire structure, making the code - /// simpler and understand and reducing code bloat due to duplication. - /// - /// See comments on `insert_if_new` for more details - fn insert_if_new_inner(&mut self, values: &ArrayRef, mut make_payload_fn: MP, mut observe_payload_fn: OP) - where - MP: FnMut(Option<&[u8]>) -> V, - OP: FnMut(V), - B: ByteViewType, - { - // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); - - // step 2: insert each value into the set, if not already present - let values = values.as_byte_view::(); - - // Ensure lengths are equivalent - assert_eq!(values.len(), batch_hashes.len()); - - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // handle null value - let Some(value) = value else { - let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { - payload - } else { - let payload = make_payload_fn(None); - let null_index = self.builder.len(); - self.builder.append_null(); - self.null = Some((payload, null_index)); - payload - }; - observe_payload_fn(payload); - continue; - }; - - // get the value as bytes - let value: &[u8] = value.as_ref(); - - let entry = self.map.get_mut(hash, |header| { - let v = self.builder.get_value(header.view_idx); - - if v.len() != value.len() { - return false; - } - - v == value - }); - - let payload = if let Some(entry) = entry { - entry.payload - } else { - // no existing value, make a new one. - let payload = make_payload_fn(Some(value)); - - let inner_view_idx = self.builder.len(); - let new_header = Entry { - view_idx: inner_view_idx, - hash, - payload, - }; - - self.builder.append_value(value); - - self.map.insert_accounted(new_header, |h| h.hash, &mut self.map_size); - payload - }; - observe_payload_fn(payload); - } - } - - /// Inserts each value from `values` into the map, invoking `make_payload_fn` for - /// each value if not already present, or `update_payload_fn` if the value already exists. - /// - /// This function handles both the insert and update cases. - /// - /// # Arguments: - /// - /// `values`: The array whose values are inserted or updated in the map. - /// - /// `make_payload_fn`: Invoked for each value that is not already present - /// to create the payload, in the order of the values in `values`. - /// - /// `update_payload_fn`: Invoked for each value that is already present, - /// allowing the payload to be updated in-place. - /// - /// # Safety: - /// - /// Note that `make_payload_fn` and `update_payload_fn` are only invoked - /// with valid values from `values`, not for the `NULL` value. - pub fn insert_or_update(&mut self, values: &ArrayRef, make_payload_fn: MP, update_payload_fn: UP) - where - MP: FnMut(Option<&[u8]>) -> V, - UP: FnMut(&mut V), - { - // Check the output type and dispatch to the appropriate internal function - match self.output_type { - OutputType::BinaryView => { - assert!(matches!(values.data_type(), DataType::BinaryView)); - self.insert_or_update_inner::(values, make_payload_fn, update_payload_fn) - } - OutputType::Utf8View => { - assert!(matches!(values.data_type(), DataType::Utf8View)); - self.insert_or_update_inner::(values, make_payload_fn, update_payload_fn) - } - _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), - }; - } - - /// Generic version of [`Self::insert_or_update`] that handles `ByteViewType` - /// (both StringView and BinaryView). - /// - /// This is the only function that is generic on [`ByteViewType`], which avoids having - /// to template the entire structure, simplifying the code and reducing code bloat due - /// to duplication. - /// - /// See comments on `insert_or_update` for more details. - fn insert_or_update_inner( - &mut self, - values: &ArrayRef, - mut make_payload_fn: MP, - mut update_payload_fn: UP, - ) where - MP: FnMut(Option<&[u8]>) -> V, - UP: FnMut(&mut V), - B: ByteViewType, - { - // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); - - // step 2: insert each value into the set, if not already present - let values = values.as_byte_view::(); - - // Ensure lengths are equivalent - assert_eq!(values.len(), batch_hashes.len()); - - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // Handle null value - let Some(value) = value else { - if let Some((ref mut payload, _)) = self.null { - update_payload_fn(payload); - } else { - let payload = make_payload_fn(None); - let null_index = self.builder.len(); - self.builder.append_null(); - self.null = Some((payload, null_index)); - } - continue; - }; - - let value: &[u8] = value.as_ref(); - - let entry = self.map.get_mut(hash, |header| { - let v = self.builder.get_value(header.view_idx); - - if v.len() != value.len() { - return false; - } - - v == value - }); - - if let Some(entry) = entry { - update_payload_fn(&mut entry.payload); - } else { - // no existing value, make a new one. - let payload = make_payload_fn(Some(value)); - - let inner_view_idx = self.builder.len(); - let new_header = Entry { - view_idx: inner_view_idx, - hash, - payload, - }; - - self.builder.append_value(value); - - self.map.insert_accounted(new_header, |h| h.hash, &mut self.map_size); - }; - } - } - - /// Generic version of [`Self::get_payloads`] that handles `ByteViewType` - /// (both StringView and BinaryView). - /// - /// This function computes the hashes for each value and retrieves the payloads - /// stored in the map, leveraging small value optimizations when possible. - /// - /// # Arguments: - /// - /// `values`: The array whose payloads are being retrieved. - /// - /// # Returns - /// - /// A vector of payloads for each value, or `None` if the value is not found. - /// - /// # Safety: - /// - /// This function ensures that small values are handled using inline optimization - /// and larger values are safely retrieved from the builder. - fn get_payloads_inner(self, values: &ArrayRef) -> Vec> - where - B: ByteViewType, - { - // Step 1: Compute hashes - let mut batch_hashes = vec![0u64; values.len()]; - create_hashes(&[values.clone()], &self.random_state, &mut batch_hashes).unwrap(); // Compute the hashes for the values - - // Step 2: Get payloads for each value - let values = values.as_byte_view::(); - assert_eq!(values.len(), batch_hashes.len()); // Ensure hash count matches value count - - let mut payloads = Vec::with_capacity(values.len()); - - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // Handle null value - let Some(value) = value else { - if let Some(&(payload, _)) = self.null.as_ref() { - payloads.push(Some(payload)); - } else { - payloads.push(None); - } - continue; - }; - - let value: &[u8] = value.as_ref(); - - let entry = self.map.get(hash, |header| { - let v = self.builder.get_value(header.view_idx); - v.len() == value.len() && v == value - }); - - let payload = entry.map(|e| e.payload); - payloads.push(payload); - } - - payloads - } - - /// Retrieves the payloads for each value from `values`, either by using - /// small value optimizations or larger value handling. - /// - /// This function will compute hashes for each value and attempt to retrieve - /// the corresponding payload from the map. If the value is not found, it will return `None`. - /// - /// # Arguments: - /// - /// `values`: The array whose payloads need to be retrieved. - /// - /// # Returns - /// - /// A vector of payloads for each value, or `None` if the value is not found. - pub fn get_payloads(self, values: &ArrayRef) -> Vec> { - match self.output_type { - OutputType::BinaryView => { - assert!(matches!(values.data_type(), DataType::BinaryView)); - self.get_payloads_inner::(values) - } - OutputType::Utf8View => { - assert!(matches!(values.data_type(), DataType::Utf8View)); - self.get_payloads_inner::(values) - } - _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), - } - } - - /// Converts this set into a `StringViewArray`, or `BinaryViewArray`, - /// containing each distinct value - /// that was inserted. This is done without copying the values. - /// - /// The values are guaranteed to be returned in the same order in which - /// they were first seen. - pub fn into_state(self) -> ArrayRef { - let mut builder = self.builder; - match self.output_type { - OutputType::BinaryView => { - let array = builder.finish(); - - Arc::new(array) - } - OutputType::Utf8View => { - // SAFETY: - // we asserted the input arrays were all the correct type and - // thus since all the values that went in were valid (e.g. utf8) - // so are all the values that come out - let array = builder.finish(); - let array = unsafe { array.to_string_view_unchecked() }; - Arc::new(array) - } - _ => { - unreachable!("Utf8/Binary should use `ArrowBytesMap`") - } - } - } - - /// Total number of entries (including null, if present) - pub fn len(&self) -> usize { - self.non_null_len() + self.null.map(|_| 1).unwrap_or(0) - } - - /// Is the set empty? - pub fn is_empty(&self) -> bool { - self.map.is_empty() && self.null.is_none() - } - - /// Number of non null entries - pub fn non_null_len(&self) -> usize { - self.map.len() - } - - /// Return the total size, in bytes, of memory used to store the data in - /// this set, not including `self` - pub fn size(&self) -> usize { - self.map_size + self.builder.allocated_size() + self.hashes_buffer.allocated_size() - } -} - -impl Debug for ArrowBytesViewMap -where - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ArrowBytesMap") - .field("map", &"") - .field("map_size", &self.map_size) - .field("view_builder", &self.builder) - .field("random_state", &self.random_state) - .field("hashes_buffer", &self.hashes_buffer) - .finish() - } -} - -/// Entry in the hash table -- see [`ArrowBytesViewMap`] for more details -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] -struct Entry -where - V: Debug + PartialEq + Eq + Clone + Copy + Default, -{ - /// The idx into the views array - view_idx: usize, - - hash: u64, - - /// value stored by the entry - payload: V, -} - -#[cfg(test)] -mod tests { - use arrow::array::{GenericByteViewArray, StringViewArray}; - use hashbrown::HashMap; - - use super::*; - - #[test] - fn test_insert_or_update_count_u8() { - let values = GenericByteViewArray::from(vec![ - Some("a"), - Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), - Some("🔥"), - Some("✨✨✨"), - Some("foobarbaz"), - Some("🔥"), - Some("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥"), - ]); - - let mut map: ArrowBytesViewMap = ArrowBytesViewMap::new(OutputType::Utf8View); - let arr: ArrayRef = Arc::new(values); - - map.insert_or_update( - &arr, - |_| 1u8, - |count| { - *count += 1; - }, - ); - - let expected_counts = [ - ("a", 1), - ("✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥✨🔥", 2), - ("🔥", 2), - ("✨✨✨", 1), - ("foobarbaz", 1), - ]; - - for value in expected_counts.iter() { - let string_array = GenericByteViewArray::from(vec![Some(value.0)]); - let arr: ArrayRef = Arc::new(string_array); - - let mut result_payload: Option = None; - - map.insert_or_update( - &arr, - |_| { - panic!("Unexpected new entry during verification"); - }, - |count| { - result_payload = Some(*count); - }, - ); - - assert_eq!(result_payload.unwrap(), value.1); - } - } - - #[test] - fn test_insert_if_new_after_insert_or_update() { - let initial_values = GenericByteViewArray::from(vec![Some("A"), Some("B"), Some("B"), Some("C"), Some("C")]); - - let mut map: ArrowBytesViewMap = ArrowBytesViewMap::new(OutputType::Utf8View); - let arr: ArrayRef = Arc::new(initial_values); - - map.insert_or_update( - &arr, - |_| 1u8, - |count| { - *count += 1; - }, - ); - - let additional_values = GenericByteViewArray::from(vec![Some("A"), Some("D"), Some("E")]); - let arr_additional: ArrayRef = Arc::new(additional_values); - - map.insert_if_new(&arr_additional, |_| 5u8, |_| {}); - - let expected_payloads = [Some(1u8), Some(2u8), Some(2u8), Some(5u8), Some(5u8)]; - - let combined_arr = GenericByteViewArray::from(vec![Some("A"), Some("B"), Some("C"), Some("D"), Some("E")]); - - let arr_combined: ArrayRef = Arc::new(combined_arr); - let payloads = map.get_payloads(&arr_combined); - - assert_eq!(payloads, expected_payloads); - } - - #[test] - fn test_get_payloads_u8() { - let values = GenericByteViewArray::from(vec![ - Some("A"), - Some("bcdefghijklmnop"), - Some("X"), - Some("Y"), - None, - Some("qrstuvqxyzhjwya"), - Some("✨🔥"), - Some("🔥"), - Some("🔥🔥🔥🔥🔥🔥"), - Some("A"), // Duplicate to test the count increment - Some("Y"), // Another duplicate to test the count increment - ]); - - let mut map: ArrowBytesViewMap = ArrowBytesViewMap::new(OutputType::Utf8View); - let arr: ArrayRef = Arc::new(values); - - map.insert_or_update( - &arr, - |_| 1u8, - |count| { - *count += 1; - }, - ); - - let expected_payloads = [ - Some(2u8), - Some(1u8), - Some(1u8), - Some(2u8), - Some(1u8), - Some(1u8), - Some(1u8), - Some(1u8), - Some(1u8), - Some(2u8), - Some(2u8), - ]; - - let payloads = map.get_payloads(&arr); - - assert_eq!(payloads.len(), expected_payloads.len()); - - for (i, payload) in payloads.iter().enumerate() { - assert_eq!(*payload, expected_payloads[i]); - } - } - - #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)] - struct TestPayload { - // store the string value to check against input - index: usize, // store the index of the string (each new string gets the next sequential input) - } - - /// Wraps an [`ArrowBytesViewMap`], validating its invariants - struct TestMap { - map: ArrowBytesViewMap, - // stores distinct strings seen, in order - strings: Vec>, - // map strings to index in strings - indexes: HashMap, usize>, - } - - impl TestMap { - /// creates a map with TestPayloads for the given strings and then - /// validates the payloads - fn new() -> Self { - Self { - map: ArrowBytesViewMap::new(OutputType::Utf8View), - strings: vec![], - indexes: HashMap::new(), - } - } - - /// Inserts strings into the map - fn insert(&mut self, strings: &[Option<&str>]) { - let string_array = StringViewArray::from(strings.to_vec()); - let arr: ArrayRef = Arc::new(string_array); - - let mut next_index = self.indexes.len(); - let mut actual_new_strings = vec![]; - let mut actual_seen_indexes = vec![]; - // update self with new values, keeping track of newly added values - for str in strings { - let str = str.map(|s| s.to_string()); - let index = self.indexes.get(&str).cloned().unwrap_or_else(|| { - actual_new_strings.push(str.clone()); - let index = self.strings.len(); - self.strings.push(str.clone()); - self.indexes.insert(str, index); - index - }); - actual_seen_indexes.push(index); - } - - // insert the values into the map, recording what we did - let mut seen_new_strings = vec![]; - let mut seen_indexes = vec![]; - self.map.insert_if_new( - &arr, - |s| { - let value = s.map(|s| String::from_utf8(s.to_vec()).expect("Non utf8 string")); - let index = next_index; - next_index += 1; - seen_new_strings.push(value); - TestPayload { index } - }, - |payload| { - seen_indexes.push(payload.index); - }, - ); - - assert_eq!(actual_seen_indexes, seen_indexes); - assert_eq!(actual_new_strings, seen_new_strings); - } - - /// Call `self.map.into_array()` validating that the strings are in the same - /// order as they were inserted - fn into_array(self) -> ArrayRef { - let Self { - map, - strings, - indexes: _, - } = self; - - let arr = map.into_state(); - let expected: ArrayRef = Arc::new(StringViewArray::from(strings)); - assert_eq!(&arr, &expected); - arr - } - } - - #[test] - fn test_map() { - let input = vec![ - // Note mix of short/long strings - Some("A"), - Some("bcdefghijklmnop1234567"), - Some("X"), - Some("Y"), - None, - Some("qrstuvqxyzhjwya"), - Some("✨🔥"), - Some("🔥"), - Some("🔥🔥🔥🔥🔥🔥"), - ]; - - let mut test_map = TestMap::new(); - test_map.insert(&input); - test_map.insert(&input); // put it in twice - let expected_output: ArrayRef = Arc::new(StringViewArray::from(input)); - assert_eq!(&test_map.into_array(), &expected_output); - } -} diff --git a/src/common/mod.rs b/src/common/mod.rs index 76bf27c..809ce02 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -15,5 +15,4 @@ // specific language governing permissions and limitations // under the License. -pub mod collections; pub mod mode; diff --git a/src/common/mode.rs b/src/common/mode.rs index b5d7b53..e559ce3 100644 --- a/src/common/mode.rs +++ b/src/common/mode.rs @@ -19,6 +19,5 @@ mod bytes; mod native; pub use bytes::BytesModeAccumulator; -pub use bytes::BytesViewModeAccumulator; pub use native::FloatModeAccumulator; pub use native::PrimitiveModeAccumulator; diff --git a/src/common/mode/bytes.rs b/src/common/mode/bytes.rs index 21117a5..0589c1b 100644 --- a/src/common/mode/bytes.rs +++ b/src/common/mode/bytes.rs @@ -15,198 +15,88 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - +use arrow::array::ArrayAccessor; +use arrow::array::ArrayIter; use arrow::array::ArrayRef; use arrow::array::AsArray; -use arrow::array::OffsetSizeTrait; use arrow::datatypes::DataType; use datafusion::arrow; -use datafusion::common::cast::as_list_array; use datafusion::common::cast::as_primitive_array; -use datafusion::common::utils::array_into_list_array_nullable; +use datafusion::common::cast::as_string_array; use datafusion::error::Result; use datafusion::logical_expr::Accumulator; -use datafusion::physical_expr::binary_map::ArrowBytesSet; -use datafusion::physical_expr::binary_map::OutputType; -use datafusion::physical_expr_common::binary_view_map::ArrowBytesViewSet; use datafusion::scalar::ScalarValue; - -use crate::common::collections::ArrowBytesMap; -use crate::common::collections::ArrowBytesViewMap; +use std::collections::HashMap; #[derive(Debug)] -pub struct BytesModeAccumulator { - values: ArrowBytesSet, - value_counts: ArrowBytesMap, +pub struct BytesModeAccumulator { + value_counts: HashMap, + data_type: DataType, } -impl BytesModeAccumulator { - pub fn new(output_type: OutputType) -> Self { +impl BytesModeAccumulator { + pub fn new(data_type: &DataType) -> Self { Self { - values: ArrowBytesSet::new(output_type), - value_counts: ArrowBytesMap::new(output_type), - } - } -} - -impl Accumulator for BytesModeAccumulator { - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - self.values.insert(&values[0]); - - self.value_counts.insert_or_update( - &values[0], - |maybe_value| { - if maybe_value.is_none() { - i64::MIN - } else { - 1i64 - } - }, - |count| *count += 1, - ); - - Ok(()) - } - - fn state(&mut self) -> Result> { - let values = self.values.take().into_state(); - let payloads: Vec = self - .value_counts - .take() - .get_payloads(&values) - .into_iter() - .map(|count| match count { - Some(c) => ScalarValue::Int64(Some(c)), - None => ScalarValue::Int64(None), - }) - .collect(); - - let values_list = Arc::new(array_into_list_array_nullable(values)); - let payloads_list = ScalarValue::new_list_nullable(&payloads, &DataType::Int64); - - Ok(vec![ScalarValue::List(values_list), ScalarValue::List(payloads_list)]) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); + value_counts: HashMap::new(), + data_type: data_type.clone(), } - - let arr = as_list_array(&states[0])?; - let counts = as_primitive_array::(&states[1])?; - - arr.iter().zip(counts.iter()).try_for_each(|(maybe_list, maybe_count)| { - if let (Some(list), Some(count)) = (maybe_list, maybe_count) { - // Insert or update the count for each value - self.value_counts - .insert_or_update(&list, |_| count, |existing_count| *existing_count += count); - } - Ok(()) - }) } - fn evaluate(&mut self) -> Result { - let mut max_index: Option = None; - let mut max_count: i64 = 0; - - let values = self.values.take().into_state(); - let counts = self.value_counts.take().get_payloads(&values); - - for (i, count) in counts.into_iter().enumerate() { - if let Some(c) = count { - if c > max_count { - max_count = c; - max_index = Some(i); - } - } - } - - match max_index { - Some(index) => { - let array = values.as_string::(); - let mode_value = array.value(index); - if mode_value.is_empty() { - Ok(ScalarValue::Utf8(None)) - } else if O::IS_LARGE { - Ok(ScalarValue::LargeUtf8(Some(mode_value.to_string()))) - } else { - Ok(ScalarValue::Utf8(Some(mode_value.to_string()))) - } - } - None => { - if O::IS_LARGE { - Ok(ScalarValue::LargeUtf8(None)) - } else { - Ok(ScalarValue::Utf8(None)) - } + fn update_counts<'a, V>(&mut self, array: V) + where + V: ArrayAccessor, + { + for value in ArrayIter::new(array).flatten() { + let key = value.to_string(); + if let Some(count) = self.value_counts.get_mut(&key) { + *count += 1; + } else { + self.value_counts.insert(key, 1); } } } - - fn size(&self) -> usize { - self.values.size() + self.value_counts.size() - } -} - -#[derive(Debug)] -pub struct BytesViewModeAccumulator { - values: ArrowBytesViewSet, - value_counts: ArrowBytesViewMap, -} - -impl BytesViewModeAccumulator { - pub fn new(output_type: OutputType) -> Self { - Self { - values: ArrowBytesViewSet::new(output_type), - value_counts: ArrowBytesViewMap::new(output_type), - } - } } -impl Accumulator for BytesViewModeAccumulator { +impl Accumulator for BytesModeAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { if values.is_empty() { return Ok(()); } - self.values.insert(&values[0]); + match &self.data_type { + DataType::Utf8View => { + let array = values[0].as_string_view(); + self.update_counts(array); + } + _ => { + let array = values[0].as_string::(); + self.update_counts(array); + } + }; - self.value_counts.insert_or_update( - &values[0], - |maybe_value| { - if maybe_value.is_none() { - i64::MIN - } else { - 1i64 - } - }, - |count| *count += 1, - ); Ok(()) } fn state(&mut self) -> Result> { - let values = self.values.take().into_state(); - let payloads: Vec = self + let values: Vec = self .value_counts - .take() - .get_payloads(&values) - .into_iter() - .map(|count| match count { - Some(c) => ScalarValue::Int64(Some(c)), - None => ScalarValue::Int64(None), - }) + .keys() + .map(|key| ScalarValue::Utf8(Some(key.to_string()))) + .collect(); + + let frequencies: Vec = self + .value_counts + .values() + .map(|&count| ScalarValue::Int64(Some(count))) .collect(); - let values_list = Arc::new(array_into_list_array_nullable(values)); - let payloads_list = ScalarValue::new_list_nullable(&payloads, &DataType::Int64); + let values_scalar = ScalarValue::new_list_nullable(&values, &DataType::Utf8); + let frequencies_scalar = ScalarValue::new_list_nullable(&frequencies, &DataType::Int64); - Ok(vec![ScalarValue::List(values_list), ScalarValue::List(payloads_list)]) + Ok(vec![ + ScalarValue::List(values_scalar), + ScalarValue::List(frequencies_scalar), + ]) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { @@ -214,51 +104,53 @@ impl Accumulator for BytesViewModeAccumulator { return Ok(()); } - let arr = as_list_array(&states[0])?; - let counts = as_primitive_array::(&states[1])?; + let values_array = as_string_array(&states[0])?; + let counts_array = as_primitive_array::(&states[1])?; - arr.iter().zip(counts.iter()).try_for_each(|(maybe_list, maybe_count)| { - if let (Some(list), Some(count)) = (maybe_list, maybe_count) { - // Insert or update the count for each value - self.value_counts - .insert_or_update(&list, |_| count, |existing_count| *existing_count += count); + for (i, value_option) in values_array.iter().enumerate() { + if let Some(value) = value_option { + let count = counts_array.value(i); + let entry = self.value_counts.entry(value.to_string()).or_insert(0); + *entry += count; } - Ok(()) - }) + } + + Ok(()) } fn evaluate(&mut self) -> Result { - let mut max_index: Option = None; - let mut max_count: i64 = 0; - - let values = self.values.take().into_state(); - let counts = self.value_counts.take().get_payloads(&values); - - for (i, count) in counts.into_iter().enumerate() { - if let Some(c) = count { - if c > max_count { - max_count = c; - max_index = Some(i); - } - } + if self.value_counts.is_empty() { + return match &self.data_type { + DataType::Utf8View => Ok(ScalarValue::Utf8View(None)), + _ => Ok(ScalarValue::Utf8(None)), + }; } - match max_index { - Some(index) => { - let array = values.as_string_view(); - let mode_value = array.value(index); - if mode_value.is_empty() { - Ok(ScalarValue::Utf8View(None)) - } else { - Ok(ScalarValue::Utf8View(Some(mode_value.to_string()))) - } - } - None => Ok(ScalarValue::Utf8View(None)), + let mode = self + .value_counts + .iter() + .max_by(|a, b| { + // First compare counts + a.1.cmp(b.1) + // If counts are equal, compare keys in reverse order to get the maximum string + .then_with(|| b.0.cmp(a.0)) + }) + .map(|(value, _)| value.to_string()); + + match mode { + Some(result) => match &self.data_type { + DataType::Utf8View => Ok(ScalarValue::Utf8View(Some(result))), + _ => Ok(ScalarValue::Utf8(Some(result))), + }, + None => match &self.data_type { + DataType::Utf8View => Ok(ScalarValue::Utf8View(None)), + _ => Ok(ScalarValue::Utf8(None)), + }, } } fn size(&self) -> usize { - self.values.size() + self.value_counts.size() + self.value_counts.capacity() * std::mem::size_of::<(String, i64)>() + std::mem::size_of_val(&self.data_type) } } @@ -270,7 +162,7 @@ mod tests { #[test] fn test_mode_accumulator_single_mode_utf8() -> Result<()> { - let mut acc = BytesModeAccumulator::::new(OutputType::Utf8); + let mut acc = BytesModeAccumulator::new(&DataType::Utf8); let values: ArrayRef = Arc::new(StringArray::from(vec![ Some("apple"), Some("banana"), @@ -289,7 +181,7 @@ mod tests { #[test] fn test_mode_accumulator_tie_utf8() -> Result<()> { - let mut acc = BytesModeAccumulator::::new(OutputType::Utf8); + let mut acc = BytesModeAccumulator::new(&DataType::Utf8); let values: ArrayRef = Arc::new(StringArray::from(vec![ Some("apple"), Some("banana"), @@ -307,7 +199,7 @@ mod tests { #[test] fn test_mode_accumulator_all_nulls_utf8() -> Result<()> { - let mut acc = BytesModeAccumulator::::new(OutputType::Utf8); + let mut acc = BytesModeAccumulator::new(&DataType::Utf8); let values: ArrayRef = Arc::new(StringArray::from(vec![None as Option<&str>, None, None])); acc.update_batch(&[values])?; @@ -319,7 +211,7 @@ mod tests { #[test] fn test_mode_accumulator_with_nulls_utf8() -> Result<()> { - let mut acc = BytesModeAccumulator::::new(OutputType::Utf8); + let mut acc = BytesModeAccumulator::new(&DataType::Utf8); let values: ArrayRef = Arc::new(StringArray::from(vec![ Some("apple"), None, @@ -340,7 +232,7 @@ mod tests { #[test] fn test_mode_accumulator_single_mode_utf8view() -> Result<()> { - let mut acc = BytesViewModeAccumulator::new(OutputType::Utf8View); + let mut acc = BytesModeAccumulator::new(&DataType::Utf8View); let values: ArrayRef = Arc::new(GenericByteViewArray::from(vec![ Some("apple"), Some("banana"), @@ -359,7 +251,7 @@ mod tests { #[test] fn test_mode_accumulator_tie_utf8view() -> Result<()> { - let mut acc = BytesViewModeAccumulator::new(OutputType::Utf8View); + let mut acc = BytesModeAccumulator::new(&DataType::Utf8View); let values: ArrayRef = Arc::new(GenericByteViewArray::from(vec![ Some("apple"), Some("banana"), @@ -377,7 +269,7 @@ mod tests { #[test] fn test_mode_accumulator_all_nulls_utf8view() -> Result<()> { - let mut acc = BytesViewModeAccumulator::new(OutputType::Utf8View); + let mut acc = BytesModeAccumulator::new(&DataType::Utf8View); let values: ArrayRef = Arc::new(GenericByteViewArray::from(vec![None as Option<&str>, None, None])); acc.update_batch(&[values])?; @@ -389,7 +281,7 @@ mod tests { #[test] fn test_mode_accumulator_with_nulls_utf8view() -> Result<()> { - let mut acc = BytesViewModeAccumulator::new(OutputType::Utf8View); + let mut acc = BytesModeAccumulator::new(&DataType::Utf8View); let values: ArrayRef = Arc::new(GenericByteViewArray::from(vec![ Some("apple"), None, diff --git a/src/common/mode/native.rs b/src/common/mode/native.rs index 0e733b7..a52e7b1 100644 --- a/src/common/mode/native.rs +++ b/src/common/mode/native.rs @@ -122,7 +122,7 @@ where } std::cmp::Ordering::Equal => { max_value = match max_value { - Some(ref current_max_value) if value < current_max_value => Some(*value), + Some(ref current_max_value) if value > current_max_value => Some(*value), Some(ref current_max_value) => Some(*current_max_value), None => Some(*value), }; @@ -234,7 +234,7 @@ where } std::cmp::Ordering::Equal => { max_value = match max_value { - Some(current_max_value) if value.0 < current_max_value => Some(value.0), + Some(current_max_value) if value.0 > current_max_value => Some(value.0), Some(current_max_value) => Some(current_max_value), None => Some(value.0), }; @@ -305,7 +305,7 @@ mod tests { let result = acc.evaluate()?; assert_eq!( result, - ScalarValue::new_primitive::(Some(2), &DataType::Int64)? + ScalarValue::new_primitive::(Some(3), &DataType::Int64)? ); Ok(()) } @@ -362,7 +362,7 @@ mod tests { let result = acc.evaluate()?; assert_eq!( result, - ScalarValue::new_primitive::(Some(2.0), &DataType::Float64)? + ScalarValue::new_primitive::(Some(3.0), &DataType::Float64)? ); Ok(()) } @@ -435,7 +435,7 @@ mod tests { let result = acc.evaluate()?; assert_eq!( result, - ScalarValue::new_primitive::(Some(1609545600000), &DataType::Date64)? + ScalarValue::new_primitive::(Some(1609632000000), &DataType::Date64)? ); Ok(()) } @@ -515,7 +515,7 @@ mod tests { assert_eq!( result, ScalarValue::new_primitive::( - Some(7200000000), + Some(10800000000), &DataType::Time64(TimeUnit::Microsecond) )? ); diff --git a/src/mode.rs b/src/mode.rs index 619a73e..1bca057 100644 --- a/src/mode.rs +++ b/src/mode.rs @@ -29,22 +29,18 @@ use datafusion::arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion::common::not_impl_err; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; -use datafusion::physical_expr::binary_map::OutputType; use std::any::Any; use std::fmt::Debug; -use crate::common::mode::{ - BytesModeAccumulator, BytesViewModeAccumulator, FloatModeAccumulator, PrimitiveModeAccumulator, -}; +use crate::common::mode::{BytesModeAccumulator, FloatModeAccumulator, PrimitiveModeAccumulator}; make_udaf_expr_and_func!(ModeFunction, mode, x, "Calculates the most frequent value.", mode_udaf); /// The `ModeFunction` calculates the mode (most frequent value) from a set of values. /// /// - Null values are ignored during the calculation. -/// - If multiple values have the same frequency, the first encountered value with the highest frequency is returned. -/// - In the case of `Utf8` or `Utf8View`, the first value encountered in the original order with the highest frequency is returned. +/// - If multiple values have the same frequency, the MAX value with the highest frequency is returned. pub struct ModeFunction { signature: Signature, } @@ -141,9 +137,7 @@ impl AggregateUDFImpl for ModeFunction { DataType::Float32 => Box::new(FloatModeAccumulator::::new(data_type)), DataType::Float64 => Box::new(FloatModeAccumulator::::new(data_type)), - DataType::Utf8 => Box::new(BytesModeAccumulator::::new(OutputType::Utf8)), - DataType::LargeUtf8 => Box::new(BytesModeAccumulator::::new(OutputType::Utf8)), - DataType::Utf8View => Box::new(BytesViewModeAccumulator::new(OutputType::Utf8View)), + DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 => Box::new(BytesModeAccumulator::new(data_type)), _ => { return not_impl_err!("Unsupported data type: {:?} for mode function", data_type); } diff --git a/tests/main.rs b/tests/main.rs index 3a67f56..8a07ace 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -37,7 +37,7 @@ CREATE TABLE test_table ( "#; #[tokio::test] -async fn test_mode_utf8() { +async fn test_mode() { let mut execution = TestExecution::new().await.unwrap().with_setup(TEST_TABLE).await; let actual = execution.run_and_format("SELECT MODE(utf8_col) FROM test_table").await; @@ -49,10 +49,6 @@ async fn test_mode_utf8() { - "| apple |" - +---------------------------+ "###); -} -#[tokio::test] -async fn test_mode_int64() { - let mut execution = TestExecution::new().await.unwrap().with_setup(TEST_TABLE).await; let actual = execution.run_and_format("SELECT MODE(int64_col) FROM test_table").await; @@ -63,11 +59,6 @@ async fn test_mode_int64() { - "| 3 |" - +----------------------------+ "###); -} - -#[tokio::test] -async fn test_mode_float64() { - let mut execution = TestExecution::new().await.unwrap().with_setup(TEST_TABLE).await; let actual = execution .run_and_format("SELECT MODE(float64_col) FROM test_table") @@ -80,11 +71,6 @@ async fn test_mode_float64() { - "| 3.0 |" - +------------------------------+ "###); -} - -#[tokio::test] -async fn test_mode_date64() { - let mut execution = TestExecution::new().await.unwrap().with_setup(TEST_TABLE).await; let actual = execution .run_and_format("SELECT MODE(date64_col) FROM test_table") From 146b8620e38921848862e5d6eba8b09cbc62dfc8 Mon Sep 17 00:00:00 2001 From: dmitrybugakov Date: Tue, 8 Oct 2024 12:36:21 +0200 Subject: [PATCH 2/2] improve key allocation and add utf bench --- benches/mode.rs | 57 ++++++++++++++++++++++++++++++---------- src/common/mode/bytes.rs | 6 ++--- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/benches/mode.rs b/benches/mode.rs index 796bf81..da3b87e 100644 --- a/benches/mode.rs +++ b/benches/mode.rs @@ -17,7 +17,7 @@ use std::sync::Arc; -use arrow::util::bench_util::create_primitive_array; +use arrow::util::bench_util::{create_primitive_array, create_string_array}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion::{ arrow::{ @@ -27,14 +27,28 @@ use datafusion::{ }, logical_expr::Accumulator, }; -use datafusion_functions_extra::common::mode::PrimitiveModeAccumulator; +use datafusion_functions_extra::common::mode::{BytesModeAccumulator, PrimitiveModeAccumulator}; -fn prepare_mode_accumulator() -> Box { +fn prepare_primitive_mode_accumulator() -> Box { Box::new(PrimitiveModeAccumulator::::new(&DataType::Int32)) } -fn mode_bench(c: &mut Criterion, name: &str, values: ArrayRef) { - let mut accumulator = prepare_mode_accumulator(); +fn prepare_bytes_mode_accumulator() -> Box { + Box::new(BytesModeAccumulator::new(&DataType::Utf8)) +} + +fn mode_bench_primitive(c: &mut Criterion, name: &str, values: ArrayRef) { + let mut accumulator = prepare_primitive_mode_accumulator(); + c.bench_function(name, |b| { + b.iter(|| { + accumulator.update_batch(&[values.clone()]).unwrap(); + black_box(accumulator.evaluate().unwrap()); + }); + }); +} + +fn mode_bench_bytes(c: &mut Criterion, name: &str, values: ArrayRef) { + let mut accumulator = prepare_bytes_mode_accumulator(); c.bench_function(name, |b| { b.iter(|| { accumulator.update_batch(&[values.clone()]).unwrap(); @@ -44,17 +58,32 @@ fn mode_bench(c: &mut Criterion, name: &str, values: ArrayRef) { } fn mode_benchmark(c: &mut Criterion) { - // Case: No nulls - let values = Arc::new(create_primitive_array::(8192, 0.0)) as ArrayRef; - mode_bench(c, "mode benchmark no nulls", values); + let sizes = [100_000, 1_000_000]; + let null_percentages = [0.0, 0.3, 0.7]; - // Case: 30% nulls - let values = Arc::new(create_primitive_array::(8192, 0.3)) as ArrayRef; - mode_bench(c, "mode benchmark 30% nulls", values); + for &size in &sizes { + for &null_percentage in &null_percentages { + let values = Arc::new(create_primitive_array::(size, null_percentage)) as ArrayRef; + let name = format!( + "PrimitiveModeAccumulator: {} elements, {}% nulls", + size, + null_percentage * 100.0 + ); + mode_bench_primitive(c, &name, values); + } + } - // Case: 70% nulls - let values = Arc::new(create_primitive_array::(8192, 0.7)) as ArrayRef; - mode_bench(c, "mode benchmark 70% nulls", values); + for &size in &sizes { + for &null_percentage in &null_percentages { + let values = Arc::new(create_string_array::(size, null_percentage)) as ArrayRef; + let name = format!( + "BytesModeAccumulator: {} elements, {}% nulls", + size, + null_percentage * 100.0 + ); + mode_bench_bytes(c, &name, values); + } + } } criterion_group!(benches, mode_benchmark); diff --git a/src/common/mode/bytes.rs b/src/common/mode/bytes.rs index 0589c1b..0e0b5c2 100644 --- a/src/common/mode/bytes.rs +++ b/src/common/mode/bytes.rs @@ -47,11 +47,11 @@ impl BytesModeAccumulator { V: ArrayAccessor, { for value in ArrayIter::new(array).flatten() { - let key = value.to_string(); - if let Some(count) = self.value_counts.get_mut(&key) { + let key = value; + if let Some(count) = self.value_counts.get_mut(key) { *count += 1; } else { - self.value_counts.insert(key, 1); + self.value_counts.insert(key.to_string(), 1); } } }