Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict DictionaryArray to ArrowDictionaryKeyType #3813

Merged
merged 2 commits into from
Mar 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions arrow-arith/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ fn math_op_dict<K, T, F>(
op: F,
) -> Result<PrimitiveArray<T>, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> T::Native,
{
Expand Down Expand Up @@ -580,7 +580,7 @@ fn math_checked_op_dict<K, T, F>(
op: F,
) -> Result<PrimitiveArray<T>, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
{
Expand Down Expand Up @@ -613,7 +613,7 @@ fn math_divide_checked_op_dict<K, T, F>(
op: F,
) -> Result<PrimitiveArray<T>, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> Result<T::Native, ArrowError>,
{
Expand Down Expand Up @@ -666,7 +666,7 @@ fn math_divide_safe_op_dict<K, T, F>(
op: F,
) -> Result<ArrayRef, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowNumericType,
F: Fn(T::Native, T::Native) -> Option<T::Native>,
{
Expand Down
5 changes: 3 additions & 2 deletions arrow-arith/src/arity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use arrow_array::builder::BufferBuilder;
use arrow_array::iterator::ArrayIter;
use arrow_array::types::ArrowDictionaryKeyType;
use arrow_array::*;
use arrow_buffer::buffer::{BooleanBuffer, NullBuffer};
use arrow_buffer::{Buffer, MutableBuffer};
Expand Down Expand Up @@ -96,7 +97,7 @@ where
/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> T::Native,
{
Expand All @@ -111,7 +112,7 @@ fn try_unary_dict<K, F, T>(
op: F,
) -> Result<ArrayRef, ArrowError>
where
K: ArrowNumericType,
K: ArrowDictionaryKeyType + ArrowNumericType,
T: ArrowPrimitiveType,
F: Fn(T::Native) -> Result<T::Native, ArrowError>,
{
Expand Down
46 changes: 19 additions & 27 deletions arrow-array/src/array/dictionary_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ pub type UInt64DictionaryArray = DictionaryArray<UInt64Type>;
/// .collect();
/// assert_eq!(&array, &expected);
/// ```
pub struct DictionaryArray<K: ArrowPrimitiveType> {
pub struct DictionaryArray<K: ArrowDictionaryKeyType> {
/// Data of this dictionary. Note that this is _not_ compatible with the C Data interface,
/// as, in the current implementation, `values` below are the first child of this struct.
data: ArrayData,
Expand All @@ -223,7 +223,7 @@ pub struct DictionaryArray<K: ArrowPrimitiveType> {
is_ordered: bool,
}

impl<K: ArrowPrimitiveType> Clone for DictionaryArray<K> {
impl<K: ArrowDictionaryKeyType> Clone for DictionaryArray<K> {
fn clone(&self) -> Self {
Self {
data: self.data.clone(),
Expand All @@ -234,7 +234,7 @@ impl<K: ArrowPrimitiveType> Clone for DictionaryArray<K> {
}
}

impl<K: ArrowPrimitiveType> DictionaryArray<K> {
impl<K: ArrowDictionaryKeyType> DictionaryArray<K> {
/// Attempt to create a new DictionaryArray with the specified keys
/// (indexes into the dictionary) and values (dictionary)
/// array. Returns an error if there are any keys that are outside
Expand Down Expand Up @@ -436,7 +436,7 @@ impl<K: ArrowPrimitiveType> DictionaryArray<K> {
}

/// Constructs a `DictionaryArray` from an array data reference.
impl<T: ArrowPrimitiveType> From<ArrayData> for DictionaryArray<T> {
impl<T: ArrowDictionaryKeyType> From<ArrayData> for DictionaryArray<T> {
fn from(data: ArrayData) -> Self {
assert_eq!(
data.buffers().len(),
Expand Down Expand Up @@ -482,7 +482,7 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for DictionaryArray<T> {
}
}

impl<T: ArrowPrimitiveType> From<DictionaryArray<T>> for ArrayData {
impl<T: ArrowDictionaryKeyType> From<DictionaryArray<T>> for ArrayData {
fn from(array: DictionaryArray<T>) -> Self {
array.data
}
Expand Down Expand Up @@ -543,7 +543,7 @@ impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray<T>
}
}

impl<T: ArrowPrimitiveType> Array for DictionaryArray<T> {
impl<T: ArrowDictionaryKeyType> Array for DictionaryArray<T> {
fn as_any(&self) -> &dyn Any {
self
}
Expand All @@ -557,7 +557,7 @@ impl<T: ArrowPrimitiveType> Array for DictionaryArray<T> {
}
}

impl<T: ArrowPrimitiveType> std::fmt::Debug for DictionaryArray<T> {
impl<T: ArrowDictionaryKeyType> std::fmt::Debug for DictionaryArray<T> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
writeln!(
f,
Expand All @@ -583,15 +583,15 @@ impl<T: ArrowPrimitiveType> std::fmt::Debug for DictionaryArray<T> {
/// assert_eq!(maybe_val.unwrap(), orig)
/// }
/// ```
pub struct TypedDictionaryArray<'a, K: ArrowPrimitiveType, V> {
pub struct TypedDictionaryArray<'a, K: ArrowDictionaryKeyType, V> {
/// The dictionary array
dictionary: &'a DictionaryArray<K>,
/// The values of the dictionary
values: &'a V,
}

// Manually implement `Clone` to avoid `V: Clone` type constraint
impl<'a, K: ArrowPrimitiveType, V> Clone for TypedDictionaryArray<'a, K, V> {
impl<'a, K: ArrowDictionaryKeyType, V> Clone for TypedDictionaryArray<'a, K, V> {
fn clone(&self) -> Self {
Self {
dictionary: self.dictionary,
Expand All @@ -600,15 +600,17 @@ impl<'a, K: ArrowPrimitiveType, V> Clone for TypedDictionaryArray<'a, K, V> {
}
}

impl<'a, K: ArrowPrimitiveType, V> Copy for TypedDictionaryArray<'a, K, V> {}
impl<'a, K: ArrowDictionaryKeyType, V> Copy for TypedDictionaryArray<'a, K, V> {}

impl<'a, K: ArrowPrimitiveType, V> std::fmt::Debug for TypedDictionaryArray<'a, K, V> {
impl<'a, K: ArrowDictionaryKeyType, V> std::fmt::Debug
for TypedDictionaryArray<'a, K, V>
{
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
writeln!(f, "TypedDictionaryArray({:?})", self.dictionary)
}
}

impl<'a, K: ArrowPrimitiveType, V> TypedDictionaryArray<'a, K, V> {
impl<'a, K: ArrowDictionaryKeyType, V> TypedDictionaryArray<'a, K, V> {
/// Returns the keys of this [`TypedDictionaryArray`]
pub fn keys(&self) -> &'a PrimitiveArray<K> {
self.dictionary.keys()
Expand All @@ -620,7 +622,7 @@ impl<'a, K: ArrowPrimitiveType, V> TypedDictionaryArray<'a, K, V> {
}
}

impl<'a, K: ArrowPrimitiveType, V: Sync> Array for TypedDictionaryArray<'a, K, V> {
impl<'a, K: ArrowDictionaryKeyType, V: Sync> Array for TypedDictionaryArray<'a, K, V> {
fn as_any(&self) -> &dyn Any {
self.dictionary
}
Expand All @@ -636,7 +638,7 @@ impl<'a, K: ArrowPrimitiveType, V: Sync> Array for TypedDictionaryArray<'a, K, V

impl<'a, K, V> IntoIterator for TypedDictionaryArray<'a, K, V>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
Self: ArrayAccessor,
{
type Item = Option<<Self as ArrayAccessor>::Item>;
Expand All @@ -649,7 +651,7 @@ where

impl<'a, K, V> ArrayAccessor for TypedDictionaryArray<'a, K, V>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
V: Sync + Send,
&'a V: ArrayAccessor,
<&'a V as ArrayAccessor>::Item: Default,
Expand Down Expand Up @@ -684,10 +686,8 @@ mod tests {
use super::*;
use crate::builder::PrimitiveDictionaryBuilder;
use crate::cast::as_dictionary_array;
use crate::types::{
Float32Type, Int16Type, Int32Type, Int8Type, UInt32Type, UInt8Type,
};
use crate::{Float32Array, Int16Array, Int32Array, Int8Array};
use crate::types::{Int16Type, Int32Type, Int8Type, UInt32Type, UInt8Type};
use crate::{Int16Array, Int32Array, Int8Array};
use arrow_buffer::{Buffer, ToByteSlice};
use std::sync::Arc;

Expand Down Expand Up @@ -955,14 +955,6 @@ mod tests {
DictionaryArray::<Int32Type>::try_new(&keys, &values).unwrap();
}

#[test]
#[should_panic(expected = "Dictionary key type must be integer, but was Float32")]
fn test_try_wrong_dictionary_key_type() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now results in a compilation error

let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect();
let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect();
DictionaryArray::<Float32Type>::try_new(&keys, &values).unwrap();
}

#[test]
#[should_panic(
expected = "DictionaryArray's data type must match, expected Int64 got Int32"
Expand Down
2 changes: 1 addition & 1 deletion arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ impl<T: ArrowPrimitiveType> PartialEq for PrimitiveArray<T> {
}
}

impl<K: ArrowPrimitiveType> PartialEq for DictionaryArray<K> {
impl<K: ArrowDictionaryKeyType> PartialEq for DictionaryArray<K> {
fn eq(&self, other: &Self) -> bool {
self.data().eq(other.data())
}
Expand Down
7 changes: 4 additions & 3 deletions arrow-array/src/builder/primitive_dictionary_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::builder::{ArrayBuilder, PrimitiveBuilder};
use crate::types::ArrowDictionaryKeyType;
use crate::{Array, ArrayRef, ArrowPrimitiveType, DictionaryArray};
use arrow_buffer::{ArrowNativeType, ToByteSlice};
use arrow_schema::{ArrowError, DataType};
Expand Down Expand Up @@ -172,7 +173,7 @@ where

impl<K, V> ArrayBuilder for PrimitiveDictionaryBuilder<K, V>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
V: ArrowPrimitiveType,
{
/// Returns the builder as a non-mutable `Any` reference.
Expand Down Expand Up @@ -213,7 +214,7 @@ where

impl<K, V> PrimitiveDictionaryBuilder<K, V>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
V: ArrowPrimitiveType,
{
/// Append a primitive value to the array. Return an existing index
Expand Down Expand Up @@ -312,7 +313,7 @@ where
}
}

impl<K: ArrowPrimitiveType, P: ArrowPrimitiveType> Extend<Option<P::Native>>
impl<K: ArrowDictionaryKeyType, P: ArrowPrimitiveType> Extend<Option<P::Native>>
for PrimitiveDictionaryBuilder<K, P>
{
#[inline]
Expand Down
18 changes: 9 additions & 9 deletions arrow-ord/src/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1253,7 +1253,7 @@ fn unpack_dict_comparison<K>(
dict_comparison: BooleanArray,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
K::Native: num::ToPrimitive,
{
// TODO: Use take_boolean (#2967)
Expand Down Expand Up @@ -2035,7 +2035,7 @@ fn cmp_dict_primitive<K, T, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
T: ArrowPrimitiveType + Sync + Send,
F: Fn(T::Native, T::Native) -> bool,
{
Expand All @@ -2055,7 +2055,7 @@ fn cmp_dict_string_array<K, OffsetSize: OffsetSizeTrait, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(&str, &str) -> bool,
{
compare_op(
Expand All @@ -2078,7 +2078,7 @@ fn cmp_dict_boolean_array<K, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(bool, bool) -> bool,
{
compare_op(
Expand All @@ -2097,7 +2097,7 @@ fn cmp_dict_binary_array<K, OffsetSize: OffsetSizeTrait, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(&[u8], &[u8]) -> bool,
{
compare_op(
Expand All @@ -2121,7 +2121,7 @@ pub fn cmp_dict<K, T, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
T: ArrowPrimitiveType + Sync + Send,
F: Fn(T::Native, T::Native) -> bool,
{
Expand All @@ -2141,7 +2141,7 @@ pub fn cmp_dict_bool<K, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(bool, bool) -> bool,
{
compare_op(
Expand All @@ -2160,7 +2160,7 @@ pub fn cmp_dict_utf8<K, OffsetSize: OffsetSizeTrait, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(&str, &str) -> bool,
{
compare_op(
Expand All @@ -2182,7 +2182,7 @@ pub fn cmp_dict_binary<K, OffsetSize: OffsetSizeTrait, F>(
op: F,
) -> Result<BooleanArray, ArrowError>
where
K: ArrowPrimitiveType,
K: ArrowDictionaryKeyType,
F: Fn(&[u8], &[u8]) -> bool,
{
compare_op(
Expand Down
4 changes: 2 additions & 2 deletions arrow-select/src/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;

use arrow_array::builder::BooleanBufferBuilder;
use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array};
use arrow_array::types::ByteArrayType;
use arrow_array::types::{ArrowDictionaryKeyType, ByteArrayType};
use arrow_array::*;
use arrow_buffer::bit_util;
use arrow_buffer::{buffer::buffer_bin_and, Buffer, MutableBuffer};
Expand Down Expand Up @@ -671,7 +671,7 @@ fn filter_dict<T>(
predicate: &FilterPredicate,
) -> DictionaryArray<T>
where
T: ArrowPrimitiveType,
T: ArrowDictionaryKeyType,
T::Native: num::Num,
{
let builder = filter_primitive::<T>(array.keys(), predicate)
Expand Down
2 changes: 1 addition & 1 deletion arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ fn take_dict<T, I>(
indices: &PrimitiveArray<I>,
) -> Result<DictionaryArray<T>, ArrowError>
where
T: ArrowPrimitiveType,
T: ArrowDictionaryKeyType,
T::Native: num::Num,
I: ArrowPrimitiveType,
I::Native: ToPrimitive,
Expand Down
2 changes: 1 addition & 1 deletion arrow-string/src/like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ macro_rules! dict_function {
///
/// See the documentation on [`like_utf8`] for more details.
#[cfg(feature = "dyn_cmp_dict")]
fn $fn_name<K: ArrowPrimitiveType>(
fn $fn_name<K: arrow_array::types::ArrowDictionaryKeyType>(
left: &DictionaryArray<K>,
right: &DictionaryArray<K>,
) -> Result<BooleanArray, ArrowError> {
Expand Down