Skip to content

Commit

Permalink
Restrict DictionaryArray to ArrowDictionaryKeyType (#3813)
Browse files Browse the repository at this point in the history
* Restrict DictionaryArray to ArrowDictionaryKeyType

* Fixes
  • Loading branch information
tustvold authored Mar 7, 2023
1 parent 379bd23 commit 6678b23
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 50 deletions.
8 changes: 4 additions & 4 deletions arrow-arith/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ fn math_op_dict<K, T, F>(
op: F,
) -> Result<PrimitiveArray<T>, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> T::Native,
{
Expand Down Expand Up @@ -580,7 +580,7 @@ fn math_checked_op_dict<K, T, F>(
op: F,
) -> Result<PrimitiveArray<T>, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
{
Expand Down Expand Up @@ -613,7 +613,7 @@ fn math_divide_checked_op_dict<K, T, F>(
op: F,
) -> Result<PrimitiveArray<T>, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
{
Expand Down Expand Up @@ -666,7 +666,7 @@ fn math_divide_safe_op_dict<K, T, F>(
op: F,
) -> Result<ArrayRef, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> Option<T::Native>,
{
Expand Down
5 changes: 3 additions & 2 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use arrow_array::builder::BufferBuilder;
use arrow_array::iterator::ArrayIter;
use arrow_array::types::ArrowDictionaryKeyType;
use arrow_array::*;
use arrow_buffer::buffer::{BooleanBuffer, NullBuffer};
use arrow_buffer::{Buffer, MutableBuffer};
Expand Down Expand Up @@ -96,7 +97,7 @@ where
/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
Expand All @@ -111,7 +112,7 @@ fn try_unary_dict<K, F, T>(
op: F,
) -> Result<ArrayRef, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native, ArrowError>,
{
Expand Down
46 changes: 19 additions & 27 deletions arrow-array/src/array/dictionary_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ pub type UInt64DictionaryArray = DictionaryArray<UInt64Type>;
/// .collect();
/// assert_eq!(&array, &expected);
/// ```
pub struct DictionaryArray<K: ArrowPrimitiveType> {
pub struct DictionaryArray<K: ArrowDictionaryKeyType> {
/// Data of this dictionary. Note that this is _not_ compatible with the C Data interface,
/// as, in the current implementation, `values` below are the first child of this struct.
data: ArrayData,
Expand All @@ -223,7 +223,7 @@ pub struct DictionaryArray<K: ArrowPrimitiveType> {
is_ordered: bool,
}

impl<K: ArrowPrimitiveType> Clone for DictionaryArray<K> {
impl<K: ArrowDictionaryKeyType> Clone for DictionaryArray<K> {
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
Expand All @@ -234,7 +234,7 @@ impl<K: ArrowPrimitiveType> Clone for DictionaryArray<K> {
}
}

impl<K: ArrowPrimitiveType> DictionaryArray<K> {
impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// Attempt to create a new DictionaryArray with a specified keys
/// (indexes into the dictionary) and values (dictionary)
/// array. Returns an error if there are any keys that are outside
Expand Down Expand Up @@ -436,7 +436,7 @@ impl<K: ArrowPrimitiveType> DictionaryArray<K> {
}

/// Constructs a `DictionaryArray` from an array data reference.
impl<T: ArrowPrimitiveType> From<ArrayData> for DictionaryArray<T> {
impl<T: ArrowDictionaryKeyType> From<ArrayData> for DictionaryArray<T> {
fn from(data: ArrayData) -> Self {
assert_eq!(
data.buffers().len(),
Expand Down Expand Up @@ -482,7 +482,7 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for DictionaryArray<T> {
}
}

impl<T: ArrowPrimitiveType> From<DictionaryArray<T>> for ArrayData {
impl<T: ArrowDictionaryKeyType> From<DictionaryArray<T>> for ArrayData {
fn from(array: DictionaryArray<T>) -> Self {
array.data
}
Expand Down Expand Up @@ -543,7 +543,7 @@ impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray<T>
}
}

impl<T: ArrowPrimitiveType> Array for DictionaryArray<T> {
impl<T: ArrowDictionaryKeyType> Array for DictionaryArray<T> {
fn as_any(&self) -> &dyn Any {
self
}
Expand All @@ -557,7 +557,7 @@ impl<T: ArrowPrimitiveType> Array for DictionaryArray<T> {
}
}

impl<T: ArrowPrimitiveType> std::fmt::Debug for DictionaryArray<T> {
impl<T: ArrowDictionaryKeyType> std::fmt::Debug for DictionaryArray<T> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
writeln!(
f,
Expand All @@ -583,15 +583,15 @@ impl<T: ArrowPrimitiveType> std::fmt::Debug for DictionaryArray<T> {
/// assert_eq!(maybe_val.unwrap(), orig)
/// }
/// ```
pub struct TypedDictionaryArray<'a, K: ArrowPrimitiveType, V> {
pub struct TypedDictionaryArray<'a, K: ArrowDictionaryKeyType, V> {
/// The dictionary array
dictionary: &'a DictionaryArray<K>,
/// The values of the dictionary
values: &'a V,
}

// Manually implement `Clone` to avoid `V: Clone` type constraint
impl<'a, K: ArrowPrimitiveType, V> Clone for TypedDictionaryArray<'a, K, V> {
impl<'a, K: ArrowDictionaryKeyType, V> Clone for TypedDictionaryArray<'a, K, V> {
fn clone(&self) -> Self {
Self {
dictionary: self.dictionary,
Expand All @@ -600,15 +600,17 @@ impl<'a, K: ArrowPrimitiveType, V> Clone for TypedDictionaryArray<'a, K, V> {
}
}

impl<'a, K: ArrowPrimitiveType, V> Copy for TypedDictionaryArray<'a, K, V> {}
impl<'a, K: ArrowDictionaryKeyType, V> Copy for TypedDictionaryArray<'a, K, V> {}

impl<'a, K: ArrowPrimitiveType, V> std::fmt::Debug for TypedDictionaryArray<'a, K, V> {
impl<'a, K: ArrowDictionaryKeyType, V> std::fmt::Debug
for TypedDictionaryArray<'a, K, V>
{
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
writeln!(f, "TypedDictionaryArray({:?})", self.dictionary)
}
}

impl<'a, K: ArrowPrimitiveType, V> TypedDictionaryArray<'a, K, V> {
impl<'a, K: ArrowDictionaryKeyType, V> TypedDictionaryArray<'a, K, V> {
/// Returns the keys of this [`TypedDictionaryArray`]
pub fn keys(&self) -> &'a PrimitiveArray<K> {
self.dictionary.keys()
Expand All @@ -620,7 +622,7 @@ impl<'a, K: ArrowPrimitiveType, V> TypedDictionaryArray<'a, K, V> {
}
}

impl<'a, K: ArrowPrimitiveType, V: Sync> Array for TypedDictionaryArray<'a, K, V> {
impl<'a, K: ArrowDictionaryKeyType, V: Sync> Array for TypedDictionaryArray<'a, K, V> {
fn as_any(&self) -> &dyn Any {
self.dictionary
}
Expand All @@ -636,7 +638,7 @@ impl<'a, K: ArrowPrimitiveType, V: Sync> Array for TypedDictionaryArray<'a, K, V

impl<'a, K, V> IntoIterator for TypedDictionaryArray<'a, K, V>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
Self: ArrayAccessor,
{
type Item = Option<<Self as ArrayAccessor>::Item>;
Expand All @@ -649,7 +651,7 @@ where

impl<'a, K, V> ArrayAccessor for TypedDictionaryArray<'a, K, V>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
V: Sync + Send,
&'a V: ArrayAccessor,
<&'a V as ArrayAccessor>::Item: Default,
Expand Down Expand Up @@ -684,10 +686,8 @@ mod tests {
use super::*;
use crate::builder::PrimitiveDictionaryBuilder;
use crate::cast::as_dictionary_array;
use crate::types::{
Float32Type, Int16Type, Int32Type, Int8Type, UInt32Type, UInt8Type,
};
use crate::{Float32Array, Int16Array, Int32Array, Int8Array};
use crate::types::{Int16Type, Int32Type, Int8Type, UInt32Type, UInt8Type};
use crate::{Int16Array, Int32Array, Int8Array};
use arrow_buffer::{Buffer, ToByteSlice};
use std::sync::Arc;

Expand Down Expand Up @@ -955,14 +955,6 @@ mod tests {
DictionaryArray::<Int32Type>::try_new(&keys, &values).unwrap();
}

#[test]
#[should_panic(expected = "Dictionary key type must be integer, but was Float32")]
fn test_try_wrong_dictionary_key_type() {
let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect();
let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect();
DictionaryArray::<Float32Type>::try_new(&keys, &values).unwrap();
}

#[test]
#[should_panic(
expected = "DictionaryArray's data type must match, expected Int64 got Int32"
Expand Down
2 changes: 1 addition & 1 deletion arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ impl<T: ArrowPrimitiveType> PartialEq for PrimitiveArray<T> {
}
}

impl<K: ArrowPrimitiveType> PartialEq for DictionaryArray<K> {
impl<K: ArrowDictionaryKeyType> PartialEq for DictionaryArray<K> {
fn eq(&self, other: &Self) -> bool {
self.data().eq(other.data())
}
Expand Down
7 changes: 4 additions & 3 deletions arrow-array/src/builder/primitive_dictionary_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::builder::{ArrayBuilder, PrimitiveBuilder};
use crate::types::ArrowDictionaryKeyType;
use crate::{Array, ArrayRef, ArrowPrimitiveType, DictionaryArray};
use arrow_buffer::{ArrowNativeType, ToByteSlice};
use arrow_schema::{ArrowError, DataType};
Expand Down Expand Up @@ -172,7 +173,7 @@ where

impl<K, V> ArrayBuilder for PrimitiveDictionaryBuilder<K, V>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
V: ArrowPrimitiveType,
{
/// Returns the builder as an non-mutable `Any` reference.
Expand Down Expand Up @@ -213,7 +214,7 @@ where

impl<K, V> PrimitiveDictionaryBuilder<K, V>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
V: ArrowPrimitiveType,
{
/// Append a primitive value to the array. Return an existing index
Expand Down Expand Up @@ -312,7 +313,7 @@ where
}
}

impl<K: ArrowPrimitiveType, P: ArrowPrimitiveType> Extend<Option<P::Native>>
impl<K: ArrowDictionaryKeyType, P: ArrowPrimitiveType> Extend<Option<P::Native>>
for PrimitiveDictionaryBuilder<K, P>
{
#[inline]
Expand Down
18 changes: 9 additions & 9 deletions arrow-ord/src/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ fn unpack_dict_comparison<K>(
dict_comparison: BooleanArray,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
K::Native: num::ToPrimitive,
{
// TODO: Use take_boolean (#2967)
Expand Down Expand Up @@ -2035,7 +2035,7 @@ fn cmp_dict_primitive<K, T, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
T: ArrowPrimitiveType + Sync + Send,
F: Fn(T::Native, T::Native) -> bool,
{
Expand All @@ -2055,7 +2055,7 @@ fn cmp_dict_string_array<K, OffsetSize: OffsetSizeTrait, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(&str, &str) -> bool,
{
compare_op(
Expand All @@ -2078,7 +2078,7 @@ fn cmp_dict_boolean_array<K, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(bool, bool) -> bool,
{
compare_op(
Expand All @@ -2097,7 +2097,7 @@ fn cmp_dict_binary_array<K, OffsetSize: OffsetSizeTrait, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(&[u8], &[u8]) -> bool,
{
compare_op(
Expand All @@ -2121,7 +2121,7 @@ pub fn cmp_dict<K, T, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
T: ArrowPrimitiveType + Sync + Send,
F: Fn(T::Native, T::Native) -> bool,
{
Expand All @@ -2141,7 +2141,7 @@ pub fn cmp_dict_bool<K, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(bool, bool) -> bool,
{
compare_op(
Expand All @@ -2160,7 +2160,7 @@ pub fn cmp_dict_utf8<K, OffsetSize: OffsetSizeTrait, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(&str, &str) -> bool,
{
compare_op(
Expand All @@ -2182,7 +2182,7 @@ pub fn cmp_dict_binary<K, OffsetSize: OffsetSizeTrait, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(&[u8], &[u8]) -> bool,
{
compare_op(
Expand Down
4 changes: 2 additions & 2 deletions arrow-select/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;

use arrow_array::builder::BooleanBufferBuilder;
use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array};
use arrow_array::types::ByteArrayType;
use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
use arrow_array::*;
use arrow_buffer::bit_util;
use arrow_buffer::{buffer::buffer_bin_and, Buffer, MutableBuffer};
Expand Down Expand Up @@ -671,7 +671,7 @@ fn filter_dict<T>(
predicate: &FilterPredicate,
) -> DictionaryArray<T>
where
T: ArrowPrimitiveType,
T: ArrowDictionaryKeyType,
T::Native: num::Num,
{
let builder = filter_primitive::<T>(array.keys(), predicate)
Expand Down
2 changes: 1 addition & 1 deletion arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ fn take_dict<T, I>(
indices: &PrimitiveArray<I>,
) -> Result<DictionaryArray<T>, ArrowError>
where
T: ArrowPrimitiveType,
T: ArrowDictionaryKeyType,
T::Native: num::Num,
I: ArrowPrimitiveType,
I::Native: ToPrimitive,
Expand Down
2 changes: 1 addition & 1 deletion arrow-string/src/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ macro_rules! dict_function {
///
/// See the documentation on [`like_utf8`] for more details.
#[cfg(feature = "dyn_cmp_dict")]
fn $fn_name<K: ArrowPrimitiveType>(
fn $fn_name<K: arrow_array::types::ArrowDictionaryKeyType>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
) -> Result<BooleanArray, ArrowError> {
Expand Down

0 comments on commit 6678b23

Please sign in to comment.