diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index fdfb26f7f72a..13617c24a80f 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -82,7 +82,7 @@ where { let dict_values = array.values().as_any().downcast_ref().unwrap(); let values = unary::(dict_values, op); - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(Arc::new(values)))) } /// A helper function that applies a fallible unary function to a dictionary array with primitive value type. @@ -105,7 +105,7 @@ where let dict_values = array.values().as_any().downcast_ref().unwrap(); let values = try_unary::(dict_values, op)?; - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(Arc::new(values)))) } /// Applies an infallible unary function to an array with primitive values. diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs index ef551ceeddb7..7855b6fc6e46 100644 --- a/arrow-arith/src/temporal.rs +++ b/arrow-arith/src/temporal.rs @@ -462,7 +462,7 @@ where downcast_dictionary_array!( array => { let values = time_fraction_dyn(array.values(), name, op)?; - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(values))) } dt => return_compute_error_with!(format!("{name} does not support"), dt), ) diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index 2d80c75f073a..d7060b3baf35 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -20,8 +20,8 @@ use crate::cast::AsArray; use crate::iterator::ArrayIter; use crate::types::*; use crate::{ - make_array, Array, ArrayAccessor, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, - PrimitiveArray, StringArray, + downcast_dictionary_array, make_array, Array, ArrayAccessor, ArrayRef, + ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, StringArray, }; use arrow_buffer::bit_util::set_bit; use arrow_buffer::buffer::NullBuffer; @@ -434,6 +434,7 @@ impl DictionaryArray { /// Panics if `values` has a length less than the current values /// /// ``` + /// # use std::sync::Arc; /// # use arrow_array::builder::PrimitiveDictionaryBuilder; /// # use arrow_array::{Int8Array, Int64Array, ArrayAccessor}; /// # use arrow_array::types::{Int32Type, Int8Type}; @@ -451,7 +452,7 @@ impl DictionaryArray { /// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64); /// /// // Create a Dict(Int32, - /// let new = dictionary.with_values(&values); + /// let new = dictionary.with_values(Arc::new(values)); /// /// // Verify values are as expected /// let new_typed = new.downcast_dict::().unwrap(); @@ -460,21 +461,18 @@ impl DictionaryArray { /// } /// ``` /// - pub fn with_values(&self, values: &dyn Array) -> Self { + pub fn with_values(&self, values: ArrayRef) -> Self { assert!(values.len() >= self.values.len()); - - let builder = self - .to_data() - .into_builder() - .data_type(DataType::Dictionary( - Box::new(K::DATA_TYPE), - Box::new(values.data_type().clone()), - )) - .child_data(vec![values.to_data()]); - - // SAFETY: - // Offsets were valid before and verified length is greater than or equal - Self::from(unsafe { builder.build_unchecked() }) + let data_type = DataType::Dictionary( + Box::new(K::DATA_TYPE), + Box::new(values.data_type().clone()), + ); + Self { + data_type, + keys: self.keys.clone(), + values, + is_ordered: false, + } } /// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating @@ -930,6 +928,64 @@ where } } +/// Returns a [`AnyDictionaryArray`] if `array` is a dictionary +/// +/// This can be used to efficiently implement kernels for all possible dictionary +/// keys without needing to create specialized implementations for each key type +pub fn as_any_dictionary_array(array: &dyn Array) -> Option<&dyn AnyDictionaryArray> { + downcast_dictionary_array! { + array => Some(array), + _ => None + } +} + +/// A [`DictionaryArray`] with the key type erased +/// +/// See [`as_any_dictionary_array`] +pub trait AnyDictionaryArray: Array { + /// Returns the primitive keys of this dictionary as an [`Array`] + fn keys(&self) -> &dyn Array; + + /// Returns the values of this dictionary + fn values(&self) -> &ArrayRef; + + /// Returns the keys of this dictionary as usize + /// + /// The values for nulls will be arbitrary, but are guaranteed + /// to be in the range `0..self.values.len()` + /// + /// # Panic + /// + /// Panics if `values.len() == 0` + fn normalized_keys(&self) -> Vec; + + /// Create a new [`DictionaryArray`] replacing `values` with the new values + /// + /// See [`DictionaryArray::with_values`] + fn with_values(&self, values: ArrayRef) -> ArrayRef; +} + +impl AnyDictionaryArray for DictionaryArray { + fn keys(&self) -> &dyn Array { + &self.keys + } + + fn values(&self) -> &ArrayRef { + self.values() + } + + fn normalized_keys(&self) -> Vec { + let v_len = self.values().len(); + assert_ne!(v_len, 0); + let iter = self.keys().values().iter(); + iter.map(|x| x.as_usize().min(v_len)).collect() + } + + fn with_values(&self, values: ArrayRef) -> ArrayRef { + Arc::new(self.with_values(values)) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 3cd082c51165..18b5890d4a3a 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1642,7 +1642,7 @@ mod tests { // Construct dictionary with a timezone let dict = a.finish(); let values = TimestampNanosecondArray::from(dict.values().to_data()); - let dict_with_tz = dict.with_values(&values.with_timezone("+02:00")); + let dict_with_tz = dict.with_values(Arc::new(values.with_timezone("+02:00"))); let d = DataType::Dictionary( Box::new(DataType::Int32), Box::new(DataType::Timestamp(