From aaae7b0c73b24e9ba73522f93795d2e5e6fd6908 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 16 Aug 2023 16:52:34 +0100 Subject: [PATCH 1/3] Add AnyDictionary Abstraction --- arrow-arith/src/arity.rs | 4 +- arrow-arith/src/temporal.rs | 2 +- arrow-array/src/array/dictionary_array.rs | 90 ++++++++++++++++++----- arrow-row/src/lib.rs | 2 +- 4 files changed, 77 insertions(+), 21 deletions(-) diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index fdfb26f7f72a..13617c24a80f 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -82,7 +82,7 @@ where { let dict_values = array.values().as_any().downcast_ref().unwrap(); let values = unary::(dict_values, op); - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(Arc::new(values)))) } /// A helper function that applies a fallible unary function to a dictionary array with primitive value type. @@ -105,7 +105,7 @@ where let dict_values = array.values().as_any().downcast_ref().unwrap(); let values = try_unary::(dict_values, op)?; - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(Arc::new(values)))) } /// Applies an infallible unary function to an array with primitive values. diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs index ef551ceeddb7..7855b6fc6e46 100644 --- a/arrow-arith/src/temporal.rs +++ b/arrow-arith/src/temporal.rs @@ -462,7 +462,7 @@ where downcast_dictionary_array!( array => { let values = time_fraction_dyn(array.values(), name, op)?; - Ok(Arc::new(array.with_values(&values))) + Ok(Arc::new(array.with_values(values))) } dt => return_compute_error_with!(format!("{name} does not support"), dt), ) diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index 2d80c75f073a..d7060b3baf35 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -20,8 +20,8 @@ use crate::cast::AsArray; use crate::iterator::ArrayIter; use crate::types::*; use crate::{ - make_array, Array, ArrayAccessor, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, - PrimitiveArray, StringArray, + downcast_dictionary_array, make_array, Array, ArrayAccessor, ArrayRef, + ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, StringArray, }; use arrow_buffer::bit_util::set_bit; use arrow_buffer::buffer::NullBuffer; @@ -434,6 +434,7 @@ impl DictionaryArray { /// Panics if `values` has a length less than the current values /// /// ``` + /// # use std::sync::Arc; /// # use arrow_array::builder::PrimitiveDictionaryBuilder; /// # use arrow_array::{Int8Array, Int64Array, ArrayAccessor}; /// # use arrow_array::types::{Int32Type, Int8Type}; @@ -451,7 +452,7 @@ impl DictionaryArray { /// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64); /// /// // Create a Dict(Int32, - /// let new = dictionary.with_values(&values); + /// let new = dictionary.with_values(Arc::new(values)); /// /// // Verify values are as expected /// let new_typed = new.downcast_dict::().unwrap(); @@ -460,21 +461,18 @@ impl DictionaryArray { /// } /// ``` /// - pub fn with_values(&self, values: &dyn Array) -> Self { + pub fn with_values(&self, values: ArrayRef) -> Self { assert!(values.len() >= self.values.len()); - - let builder = self - .to_data() - .into_builder() - .data_type(DataType::Dictionary( - Box::new(K::DATA_TYPE), - Box::new(values.data_type().clone()), - )) - .child_data(vec![values.to_data()]); - - // SAFETY: - // Offsets were valid before and verified length is greater than or equal - Self::from(unsafe { builder.build_unchecked() }) + let data_type = DataType::Dictionary( + Box::new(K::DATA_TYPE), + Box::new(values.data_type().clone()), + ); + Self { + data_type, + keys: self.keys.clone(), + values, + is_ordered: false, + } } /// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating @@ -930,6 +928,64 @@ where } } +/// Returns a [`AnyDictionaryArray`] if `array` is a dictionary +/// +/// This can be used to efficiently implement kernels for all possible dictionary +/// keys without needing to create specialized implementations for each key type +pub fn as_any_dictionary_array(array: &dyn Array) -> Option<&dyn AnyDictionaryArray> { + downcast_dictionary_array! { + array => Some(array), + _ => None + } +} + +/// A [`DictionaryArray`] with the key type erased +/// +/// See [`as_any_dictionary_array`] +pub trait AnyDictionaryArray: Array { + /// Returns the primitive keys of this dictionary as an [`Array`] + fn keys(&self) -> &dyn Array; + + /// Returns the values of this dictionary + fn values(&self) -> &ArrayRef; + + /// Returns the keys of this dictionary as usize + /// + /// The values for nulls will be arbitrary, but are guaranteed + /// to be in the range `0..self.values.len()` + /// + /// # Panic + /// + /// Panics if `values.len() == 0` + fn normalized_keys(&self) -> Vec; + + /// Create a new [`DictionaryArray`] replacing `values` with the new values + /// + /// See [`DictionaryArray::with_values`] + fn with_values(&self, values: ArrayRef) -> ArrayRef; +} + +impl AnyDictionaryArray for DictionaryArray { + fn keys(&self) -> &dyn Array { + &self.keys + } + + fn values(&self) -> &ArrayRef { + self.values() + } + + fn normalized_keys(&self) -> Vec { + let v_len = self.values().len(); + assert_ne!(v_len, 0); + let iter = self.keys().values().iter(); + iter.map(|x| x.as_usize().min(v_len)).collect() + } + + fn with_values(&self, values: ArrayRef) -> ArrayRef { + Arc::new(self.with_values(values)) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 3cd082c51165..18b5890d4a3a 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -1642,7 +1642,7 @@ mod tests { // Construct dictionary with a timezone let dict = a.finish(); let values = TimestampNanosecondArray::from(dict.values().to_data()); - let dict_with_tz = dict.with_values(&values.with_timezone("+02:00")); + let dict_with_tz = dict.with_values(Arc::new(values.with_timezone("+02:00"))); let d = DataType::Dictionary( Box::new(DataType::Int32), Box::new(DataType::Timestamp( From b95d7629eea6226c75130c1664bc2541ebe612a3 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 16 Aug 2023 20:42:14 +0100 Subject: [PATCH 2/3] Review feedback --- arrow-arith/src/arity.rs | 4 +++ arrow-array/src/array/dictionary_array.rs | 38 +++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index 13617c24a80f..f8db44745859 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -109,6 +109,7 @@ where } /// Applies an infallible unary function to an array with primitive values. +#[deprecated(note = "Use arrow_array::as_any_dictionary_array")] pub fn unary_dyn(array: &dyn Array, op: F) -> Result where T: ArrowPrimitiveType, @@ -134,6 +135,7 @@ where } /// Applies a fallible unary function to an array with primitive values. +#[deprecated(note = "Use arrow_array::as_any_dictionary_array")] pub fn try_unary_dyn(array: &dyn Array, op: F) -> Result where T: ArrowPrimitiveType, @@ -436,6 +438,7 @@ mod tests { use arrow_array::types::*; #[test] + #[allow(deprecated)] fn test_unary_f64_slice() { let input = Float64Array::from(vec![Some(5.1f64), None, Some(6.8), None, Some(7.2)]); @@ -455,6 +458,7 @@ mod tests { } #[test] + #[allow(deprecated)] fn test_unary_dict_and_unary_dyn() { let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(5).unwrap(); diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index d7060b3baf35..ed4f7d447157 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -932,6 +932,44 @@ where /// /// This can be used to efficiently implement kernels for all possible dictionary /// keys without needing to create specialized implementations for each key type +/// +/// For example +/// +/// ``` +/// # use arrow_array::*; +/// # use arrow_array::cast::AsArray; +/// # use arrow_array::builder::PrimitiveDictionaryBuilder; +/// # use arrow_array::types::*; +/// # use arrow_schema::ArrowError; +/// # use std::sync::Arc; +/// +/// fn to_string(a: &dyn Array) -> Result { +/// if let Some(d) = as_any_dictionary_array(a) { +/// // Recursively handle dictionary input +/// let r = to_string(d.values().as_ref())?; +/// return Ok(d.with_values(r)); +/// } +/// downcast_primitive_array! { +/// a => Ok(Arc::new(a.iter().map(|x| x.map(|x| x.to_string())).collect::())), +/// d => Err(ArrowError::InvalidArgumentError(format!("{d:?} not supported"))) +/// } +/// } +/// +/// let result = to_string(&Int32Array::from(vec![1, 2, 3])).unwrap(); +/// let actual = result.as_string::().iter().map(Option::unwrap).collect::>(); +/// assert_eq!(actual, &["1", "2", "3"]); +/// +/// let mut dict = PrimitiveDictionaryBuilder::::new(); +/// dict.extend([Some(1), Some(1), Some(2), Some(3), Some(2)]); +/// let dict = dict.finish(); +/// +/// let r = to_string(&dict).unwrap(); +/// let r = r.as_dictionary::().downcast_dict::().unwrap(); +/// assert_eq!(r.keys(), dict.keys()); // Keys are the same +/// +/// let actual = r.into_iter().map(Option::unwrap).collect::>(); +/// assert_eq!(actual, &["1", "1", "2", "3", "2"]); +/// ``` pub fn as_any_dictionary_array(array: &dyn Array) -> Option<&dyn AnyDictionaryArray> { downcast_dictionary_array! { array => Some(array), From 98bae7b64babdfd8df66d2224626609ac8b65de9 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 16 Aug 2023 20:55:32 +0100 Subject: [PATCH 3/3] Move to AsArray --- arrow-arith/src/arity.rs | 4 ++-- arrow-array/src/array/dictionary_array.rs | 18 +++++------------- arrow-array/src/cast.rs | 20 ++++++++++++++++++++ 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index f8db44745859..f3118d104536 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -109,7 +109,7 @@ where } /// Applies an infallible unary function to an array with primitive values. -#[deprecated(note = "Use arrow_array::as_any_dictionary_array")] +#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] pub fn unary_dyn(array: &dyn Array, op: F) -> Result where T: ArrowPrimitiveType, @@ -135,7 +135,7 @@ where } /// Applies a fallible unary function to an array with primitive values. -#[deprecated(note = "Use arrow_array::as_any_dictionary_array")] +#[deprecated(note = "Use arrow_array::AnyDictionaryArray")] pub fn try_unary_dyn(array: &dyn Array, op: F) -> Result where T: ArrowPrimitiveType, diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index ed4f7d447157..ed043754da4b 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -20,8 +20,8 @@ use crate::cast::AsArray; use crate::iterator::ArrayIter; use crate::types::*; use crate::{ - downcast_dictionary_array, make_array, Array, ArrayAccessor, ArrayRef, - ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, StringArray, + make_array, Array, ArrayAccessor, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, + PrimitiveArray, StringArray, }; use arrow_buffer::bit_util::set_bit; use arrow_buffer::buffer::NullBuffer; @@ -928,7 +928,7 @@ where } } -/// Returns a [`AnyDictionaryArray`] if `array` is a dictionary +/// A [`DictionaryArray`] with the key type erased /// /// This can be used to efficiently implement kernels for all possible dictionary /// keys without needing to create specialized implementations for each key type @@ -944,7 +944,7 @@ where /// # use std::sync::Arc; /// /// fn to_string(a: &dyn Array) -> Result { -/// if let Some(d) = as_any_dictionary_array(a) { +/// if let Some(d) = a.as_any_dictionary_opt() { /// // Recursively handle dictionary input /// let r = to_string(d.values().as_ref())?; /// return Ok(d.with_values(r)); @@ -970,16 +970,8 @@ where /// let actual = r.into_iter().map(Option::unwrap).collect::>(); /// assert_eq!(actual, &["1", "1", "2", "3", "2"]); /// ``` -pub fn as_any_dictionary_array(array: &dyn Array) -> Option<&dyn AnyDictionaryArray> { - downcast_dictionary_array! { - array => Some(array), - _ => None - } -} - -/// A [`DictionaryArray`] with the key type erased /// -/// See [`as_any_dictionary_array`] +/// See [`AsArray::as_any_dictionary_opt`] and [`AsArray::as_any_dictionary`] pub trait AnyDictionaryArray: Array { /// Returns the primitive keys of this dictionary as an [`Array`] fn keys(&self) -> &dyn Array; diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index 66b40d5b8eb3..b6cda44e8973 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -833,6 +833,14 @@ pub trait AsArray: private::Sealed { fn as_dictionary(&self) -> &DictionaryArray { self.as_dictionary_opt().expect("dictionary array") } + + /// Downcasts this to a [`AnyDictionaryArray`] returning `None` if not possible + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray>; + + /// Downcasts this to a [`AnyDictionaryArray`] panicking if not possible + fn as_any_dictionary(&self) -> &dyn AnyDictionaryArray { + self.as_any_dictionary_opt().expect("any dictionary array") + } } impl private::Sealed for dyn Array + '_ {} @@ -874,6 +882,14 @@ impl AsArray for dyn Array + '_ { ) -> Option<&DictionaryArray> { self.as_any().downcast_ref() } + + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> { + let array = self; + downcast_dictionary_array! { + array => Some(array), + _ => None + } + } } impl private::Sealed for ArrayRef {} @@ -915,6 +931,10 @@ impl AsArray for ArrayRef { ) -> Option<&DictionaryArray> { self.as_ref().as_dictionary_opt() } + + fn as_any_dictionary_opt(&self) -> Option<&dyn AnyDictionaryArray> { + self.as_ref().as_any_dictionary_opt() + } } #[cfg(test)]