Skip to content

Commit

Permalink
Add AnyDictionary Abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Aug 16, 2023
1 parent 77fe72d commit aaae7b0
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 21 deletions.
4 changes: 2 additions & 2 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ where
{
let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = unary::<T, F, T>(dict_values, op);
Ok(Arc::new(array.with_values(&values)))
Ok(Arc::new(array.with_values(Arc::new(values))))
}

/// A helper function that applies a fallible unary function to a dictionary array with primitive value type.
Expand All @@ -105,7 +105,7 @@ where

let dict_values = array.values().as_any().downcast_ref().unwrap();
let values = try_unary::<T, F, T>(dict_values, op)?;
Ok(Arc::new(array.with_values(&values)))
Ok(Arc::new(array.with_values(Arc::new(values))))
}

/// Applies an infallible unary function to an array with primitive values.
Expand Down
2 changes: 1 addition & 1 deletion arrow-arith/src/temporal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ where
downcast_dictionary_array!(
array => {
let values = time_fraction_dyn(array.values(), name, op)?;
Ok(Arc::new(array.with_values(&values)))
Ok(Arc::new(array.with_values(values)))
}
dt => return_compute_error_with!(format!("{name} does not support"), dt),
)
Expand Down
90 changes: 73 additions & 17 deletions arrow-array/src/array/dictionary_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use crate::cast::AsArray;
use crate::iterator::ArrayIter;
use crate::types::*;
use crate::{
make_array, Array, ArrayAccessor, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType,
PrimitiveArray, StringArray,
downcast_dictionary_array, make_array, Array, ArrayAccessor, ArrayRef,
ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, StringArray,
};
use arrow_buffer::bit_util::set_bit;
use arrow_buffer::buffer::NullBuffer;
Expand Down Expand Up @@ -434,6 +434,7 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// Panics if `values` has a length less than the current values
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::builder::PrimitiveDictionaryBuilder;
/// # use arrow_array::{Int8Array, Int64Array, ArrayAccessor};
/// # use arrow_array::types::{Int32Type, Int8Type};
Expand All @@ -451,7 +452,7 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// let values: Int64Array = typed_dictionary.values().unary(|x| x as i64);
///
/// // Create a Dict(Int32, Int64) by swapping in the widened values
/// let new = dictionary.with_values(&values);
/// let new = dictionary.with_values(Arc::new(values));
///
/// // Verify values are as expected
/// let new_typed = new.downcast_dict::<Int64Array>().unwrap();
Expand All @@ -460,21 +461,18 @@ impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// }
/// ```
///
pub fn with_values(&self, values: &dyn Array) -> Self {
pub fn with_values(&self, values: ArrayRef) -> Self {
assert!(values.len() >= self.values.len());

let builder = self
.to_data()
.into_builder()
.data_type(DataType::Dictionary(
Box::new(K::DATA_TYPE),
Box::new(values.data_type().clone()),
))
.child_data(vec![values.to_data()]);

// SAFETY:
// Offsets were valid before and verified length is greater than or equal
Self::from(unsafe { builder.build_unchecked() })
let data_type = DataType::Dictionary(
Box::new(K::DATA_TYPE),
Box::new(values.data_type().clone()),
);
Self {
data_type,
keys: self.keys.clone(),
values,
is_ordered: false,
}
}

/// Returns `PrimitiveDictionaryBuilder` of this dictionary array for mutating
Expand Down Expand Up @@ -930,6 +928,64 @@ where
}
}

/// Returns an [`AnyDictionaryArray`] if `array` is a dictionary, otherwise `None`
///
/// This can be used to efficiently implement kernels for all possible dictionary
/// keys without needing to create specialized implementations for each key type
pub fn as_any_dictionary_array(array: &dyn Array) -> Option<&dyn AnyDictionaryArray> {
// NOTE(review): `downcast_dictionary_array!` is defined elsewhere; presumably the
// first arm rebinds `array` to the concrete dictionary type so that it can coerce
// to the trait object, and the `_` arm handles every non-dictionary array.
downcast_dictionary_array! {
array => Some(array),
_ => None
}
}

/// A [`DictionaryArray`] with the key type erased
///
/// Exposes the keys, the values and value replacement of a dictionary
/// through an object-safe interface, so callers need not be generic over
/// the key type.
///
/// See [`as_any_dictionary_array`]
pub trait AnyDictionaryArray: Array {
/// Returns the primitive keys of this dictionary as an [`Array`]
fn keys(&self) -> &dyn Array;

/// Returns the values of this dictionary
fn values(&self) -> &ArrayRef;

/// Returns the keys of this dictionary as usize
///
/// The values for nulls will be arbitrary, but are guaranteed
/// to be in the range `0..self.values().len()`
///
/// # Panics
///
/// Panics if `self.values().len() == 0`
fn normalized_keys(&self) -> Vec<usize>;

/// Create a new [`DictionaryArray`] replacing `values` with the new values
///
/// The result is returned as a type-erased [`ArrayRef`].
///
/// See [`DictionaryArray::with_values`]
fn with_values(&self, values: ArrayRef) -> ArrayRef;
}

/// Key-type-erased access to a [`DictionaryArray`], delegating to its inherent methods
impl<K: ArrowDictionaryKeyType> AnyDictionaryArray for DictionaryArray<K> {
    /// Returns the keys array as a type-erased [`Array`]
    fn keys(&self) -> &dyn Array {
        &self.keys
    }

    /// Returns the dictionary values
    fn values(&self) -> &ArrayRef {
        self.values()
    }

    /// Returns every key converted to `usize` and clamped to `0..values.len()`
    ///
    /// Null slots are not skipped: whatever key is stored there is clamped
    /// like any other, so the result is always a valid index into the values.
    ///
    /// # Panics
    ///
    /// Panics if the dictionary has no values, as no valid index exists.
    fn normalized_keys(&self) -> Vec<usize> {
        let v_len = self.values().len();
        assert_ne!(v_len, 0);
        // The largest valid index is `v_len - 1`; clamping to `v_len` itself
        // would let an out-of-bounds key through. The subtraction cannot
        // underflow because of the assertion above.
        let max_key = v_len - 1;
        self.keys()
            .values()
            .iter()
            .map(|x| x.as_usize().min(max_key))
            .collect()
    }

    /// Replaces the values, returning the new dictionary as an [`ArrayRef`]
    ///
    /// See [`DictionaryArray::with_values`]
    fn with_values(&self, values: ArrayRef) -> ArrayRef {
        Arc::new(self.with_values(values))
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion arrow-row/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1642,7 +1642,7 @@ mod tests {
// Construct dictionary with a timezone
let dict = a.finish();
let values = TimestampNanosecondArray::from(dict.values().to_data());
let dict_with_tz = dict.with_values(&values.with_timezone("+02:00"));
let dict_with_tz = dict.with_values(Arc::new(values.with_timezone("+02:00")));
let d = DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Timestamp(
Expand Down

0 comments on commit aaae7b0

Please sign in to comment.