Skip to content

Commit

Permalink
Add dictionary array support for substring function (#1665)
Browse files Browse the repository at this point in the history
* initial commit

* add test

* comments

* more comments
  • Loading branch information
sunchao authored May 9, 2022
1 parent e02869a commit daed6ab
Showing 1 changed file with 175 additions and 85 deletions.
260 changes: 175 additions & 85 deletions arrow/src/compute/kernels/substring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,137 @@
//! Defines kernel to extract a substring of an Array
//! Supported array types: \[Large\]StringArray, \[Large\]BinaryArray

use crate::array::DictionaryArray;
use crate::buffer::MutableBuffer;
use crate::datatypes::*;
use crate::{array::*, buffer::Buffer};
use crate::{
datatypes::DataType,
error::{ArrowError, Result},
};
use std::cmp::Ordering;
use std::sync::Arc;

/// Returns an ArrayRef with substrings of all the elements in `array`.
///
/// # Arguments
///
/// * `start` - The start index of all substrings.
/// If `start >= 0`, then count from the start of the string,
/// otherwise count from the end of the string.
///
/// * `length`(option) - The length of all substrings.
/// If `length` is `None`, then the substring is from `start` to the end of the string.
///
/// Attention: Both `start` and `length` are counted by byte, not by char.
///
/// # Basic usage
/// ```
/// # use arrow::array::StringArray;
/// # use arrow::compute::kernels::substring::substring;
/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]);
/// let result = substring(&array, 1, Some(4)).unwrap();
/// let result = result.as_any().downcast_ref::<StringArray>().unwrap();
/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")]));
/// ```
///
/// # Error
/// - The function errors when the passed array is not a \[Large\]String array, \[Large\]Binary
/// array, or DictionaryArray with \[Large\]String or \[Large\]Binary as its value type.
/// - The function errors if the offset of a substring in the input array is at invalid char boundary (only for \[Large\]String array).
///
/// ## Example of trying to get an invalid utf-8 format substring
/// ```
/// # use arrow::array::StringArray;
/// # use arrow::compute::kernels::substring::substring;
/// let array = StringArray::from(vec![Some("E=mc²")]);
/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string();
/// assert!(error.contains("invalid utf-8 boundary"));
/// ```
pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> {
macro_rules! substring_dict {
($kt: ident, $($t: ident: $gt: ident), *) => {
match $kt.as_ref() {
$(
&DataType::$t => {
let dict = array
.as_any()
.downcast_ref::<DictionaryArray<$gt>>()
.unwrap_or_else(|| {
panic!("Expect 'DictionaryArray<{}>' but got array of data type {:?}",
stringify!($gt), array.data_type())
});
let values = substring(dict.values(), start, length)?;
let result = DictionaryArray::try_new(dict.keys(), &values)?;
Ok(Arc::new(result))
},
)*
t => panic!("Unsupported dictionary key type: {}", t)
}
}
}

match array.data_type() {
DataType::Dictionary(kt, _) => {
substring_dict!(
kt,
Int8: Int8Type,
Int16: Int16Type,
Int32: Int32Type,
Int64: Int64Type,
UInt8: UInt8Type,
UInt16: UInt16Type,
UInt32: UInt32Type,
UInt64: UInt64Type
)
}
DataType::LargeBinary => binary_substring(
array
.as_any()
.downcast_ref::<LargeBinaryArray>()
.expect("A large binary is expected"),
start,
length.map(|e| e as i64),
),
DataType::Binary => binary_substring(
array
.as_any()
.downcast_ref::<BinaryArray>()
.expect("A binary is expected"),
start as i32,
length.map(|e| e as i32),
),
DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring(
array
.as_any()
.downcast_ref::<FixedSizeBinaryArray>()
.expect("a fixed size binary is expected"),
*old_len,
start as i32,
length.map(|e| e as i32),
),
DataType::LargeUtf8 => utf8_substring(
array
.as_any()
.downcast_ref::<LargeStringArray>()
.expect("A large string is expected"),
start,
length.map(|e| e as i64),
),
DataType::Utf8 => utf8_substring(
array
.as_any()
.downcast_ref::<StringArray>()
.expect("A string is expected"),
start as i32,
length.map(|e| e as i32),
),
_ => Err(ArrowError::ComputeError(format!(
"substring does not support type {:?}",
array.data_type()
))),
}
}

fn binary_substring<OffsetSize: OffsetSizeTrait>(
array: &GenericBinaryArray<OffsetSize>,
Expand Down Expand Up @@ -215,94 +339,10 @@ fn utf8_substring<OffsetSize: OffsetSizeTrait>(
Ok(make_array(data))
}

/// Returns an ArrayRef with substrings of all the elements in `array`.
///
/// # Arguments
///
/// * `start` - The start index of all substrings.
/// If `start >= 0`, then count from the start of the string,
/// otherwise count from the end of the string.
///
/// * `length`(option) - The length of all substrings.
/// If `length` is `None`, then the substring is from `start` to the end of the string.
///
/// Attention: Both `start` and `length` are counted by byte, not by char.
///
/// # Basic usage
/// ```
/// # use arrow::array::StringArray;
/// # use arrow::compute::kernels::substring::substring;
/// let array = StringArray::from(vec![Some("arrow"), None, Some("rust")]);
/// let result = substring(&array, 1, Some(4)).unwrap();
/// let result = result.as_any().downcast_ref::<StringArray>().unwrap();
/// assert_eq!(result, &StringArray::from(vec![Some("rrow"), None, Some("ust")]));
/// ```
///
/// # Error
/// - The function errors when the passed array is not a \[Large\]String array or \[Large\]Binary array.
/// - The function errors if the offset of a substring in the input array is at invalid char boundary (only for \[Large\]String array).
///
/// ## Example of trying to get an invalid utf-8 format substring
/// ```
/// # use arrow::array::StringArray;
/// # use arrow::compute::kernels::substring::substring;
/// let array = StringArray::from(vec![Some("E=mc²")]);
/// let error = substring(&array, 0, Some(5)).unwrap_err().to_string();
/// assert!(error.contains("invalid utf-8 boundary"));
/// ```
pub fn substring(array: &dyn Array, start: i64, length: Option<u64>) -> Result<ArrayRef> {
match array.data_type() {
DataType::LargeBinary => binary_substring(
array
.as_any()
.downcast_ref::<LargeBinaryArray>()
.expect("A large binary is expected"),
start,
length.map(|e| e as i64),
),
DataType::Binary => binary_substring(
array
.as_any()
.downcast_ref::<BinaryArray>()
.expect("A binary is expected"),
start as i32,
length.map(|e| e as i32),
),
DataType::FixedSizeBinary(old_len) => fixed_size_binary_substring(
array
.as_any()
.downcast_ref::<FixedSizeBinaryArray>()
.expect("a fixed size binary is expected"),
*old_len,
start as i32,
length.map(|e| e as i32),
),
DataType::LargeUtf8 => utf8_substring(
array
.as_any()
.downcast_ref::<LargeStringArray>()
.expect("A large string is expected"),
start,
length.map(|e| e as i64),
),
DataType::Utf8 => utf8_substring(
array
.as_any()
.downcast_ref::<StringArray>()
.expect("A string is expected"),
start as i32,
length.map(|e| e as i32),
),
_ => Err(ArrowError::ComputeError(format!(
"substring does not support type {:?}",
array.data_type()
))),
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::datatypes::*;

#[allow(clippy::type_complexity)]
fn with_nulls_generic_binary<O: OffsetSizeTrait>() -> Result<()> {
Expand Down Expand Up @@ -954,6 +994,56 @@ mod tests {
without_nulls_generic_string::<i64>()
}

#[test]
fn dictionary() -> Result<()> {
_dictionary::<Int8Type>()?;
_dictionary::<Int16Type>()?;
_dictionary::<Int32Type>()?;
_dictionary::<Int64Type>()?;
_dictionary::<UInt8Type>()?;
_dictionary::<UInt16Type>()?;
_dictionary::<UInt32Type>()?;
_dictionary::<UInt64Type>()?;
Ok(())
}

fn _dictionary<K: ArrowDictionaryKeyType>() -> Result<()> {
const TOTAL: i32 = 100;

let v = ["aaa", "bbb", "ccc", "ddd", "eee"];
let data: Vec<Option<&str>> = (0..TOTAL)
.map(|n| {
let i = n % 5;
if i == 3 {
None
} else {
Some(v[i as usize])
}
})
.collect();

let dict_array: DictionaryArray<K> = data.clone().into_iter().collect();

let expected: Vec<Option<&str>> =
data.iter().map(|opt| opt.map(|s| &s[1..3])).collect();

let res = substring(&dict_array, 1, Some(2))?;
let actual = res.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
let actual: Vec<Option<&str>> = actual
.values()
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.unwrap()
.take_iter(actual.keys_iter())
.collect();

for i in 0..TOTAL as usize {
assert_eq!(expected[i], actual[i],);
}

Ok(())
}

#[test]
fn check_invalid_array_type() {
let array = Int32Array::from(vec![Some(1), Some(2), Some(3)]);
Expand Down

0 comments on commit daed6ab

Please sign in to comment.