Skip to content

Commit

Permalink
Add Utf8View support to STRPOS function (#12087)
Browse files Browse the repository at this point in the history
* Add Utf8View support to STRPOS function

* fix type inconsistency

* fix type inconsistency

* refactor tests
  • Loading branch information
demetribu authored Aug 21, 2024
1 parent 121f330 commit c6be00d
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 59 deletions.
175 changes: 119 additions & 56 deletions datafusion/functions/src/unicode/strpos.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,10 @@ use std::any::Any;
use std::sync::Arc;

use arrow::array::{
ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray,
};
use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};

use datafusion_common::cast::as_generic_string_array;
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand Down Expand Up @@ -52,6 +51,9 @@ impl StrposFunc {
Exact(vec![Utf8, LargeUtf8]),
Exact(vec![LargeUtf8, Utf8]),
Exact(vec![LargeUtf8, LargeUtf8]),
Exact(vec![Utf8View, Utf8View]),
Exact(vec![Utf8View, Utf8]),
Exact(vec![Utf8View, LargeUtf8]),
],
Volatility::Immutable,
),
Expand All @@ -78,52 +80,79 @@ impl ScalarUDFImpl for StrposFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match (args[0].data_type(), args[1].data_type()) {
(DataType::Utf8, DataType::Utf8) => {
make_scalar_function(strpos::<Int32Type, Int32Type>, vec![])(args)
}
(DataType::Utf8, DataType::LargeUtf8) => {
make_scalar_function(strpos::<Int32Type, Int64Type>, vec![])(args)
}
(DataType::LargeUtf8, DataType::Utf8) => {
make_scalar_function(strpos::<Int64Type, Int32Type>, vec![])(args)
}
(DataType::LargeUtf8, DataType::LargeUtf8) => {
make_scalar_function(strpos::<Int64Type, Int64Type>, vec![])(args)
}
other => exec_err!("Unsupported data type {other:?} for function strpos"),
}
make_scalar_function(strpos, vec![])(args)
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

fn strpos(args: &[ArrayRef]) -> Result<ArrayRef> {
match (args[0].data_type(), args[1].data_type()) {
(DataType::Utf8, DataType::Utf8) => {
let string_array = args[0].as_string::<i32>();
let substring_array = args[1].as_string::<i32>();
calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
}
(DataType::Utf8, DataType::LargeUtf8) => {
let string_array = args[0].as_string::<i32>();
let substring_array = args[1].as_string::<i64>();
calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
}
(DataType::LargeUtf8, DataType::Utf8) => {
let string_array = args[0].as_string::<i64>();
let substring_array = args[1].as_string::<i32>();
calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
}
(DataType::LargeUtf8, DataType::LargeUtf8) => {
let string_array = args[0].as_string::<i64>();
let substring_array = args[1].as_string::<i64>();
calculate_strpos::<_, _, Int64Type>(string_array, substring_array)
}
(DataType::Utf8View, DataType::Utf8View) => {
let string_array = args[0].as_string_view();
let substring_array = args[1].as_string_view();
calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
}
(DataType::Utf8View, DataType::Utf8) => {
let string_array = args[0].as_string_view();
let substring_array = args[1].as_string::<i32>();
calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
}
(DataType::Utf8View, DataType::LargeUtf8) => {
let string_array = args[0].as_string_view();
let substring_array = args[1].as_string::<i64>();
calculate_strpos::<_, _, Int32Type>(string_array, substring_array)
}

other => {
exec_err!("Unsupported data type combination {other:?} for function strpos")
}
}
}

/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
/// strpos('high', 'ig') = 2
/// The implementation uses UTF-8 code points as characters
fn strpos<T0: ArrowPrimitiveType, T1: ArrowPrimitiveType>(
args: &[ArrayRef],
fn calculate_strpos<'a, V1, V2, T: ArrowPrimitiveType>(
string_array: V1,
substring_array: V2,
) -> Result<ArrayRef>
where
T0::Native: OffsetSizeTrait,
T1::Native: OffsetSizeTrait,
V1: ArrayAccessor<Item = &'a str>,
V2: ArrayAccessor<Item = &'a str>,
{
let string_array: &GenericStringArray<T0::Native> =
as_generic_string_array::<T0::Native>(&args[0])?;

let substring_array: &GenericStringArray<T1::Native> =
as_generic_string_array::<T1::Native>(&args[1])?;
let string_iter = ArrayIter::new(string_array);
let substring_iter = ArrayIter::new(substring_array);

let result = string_array
.iter()
.zip(substring_array.iter())
let result = string_iter
.zip(substring_iter)
.map(|(string, substring)| match (string, substring) {
(Some(string), Some(substring)) => {
// the find method returns the byte index of the substring
// Next, we count the number of the chars until that byte
T0::Native::from_usize(
// The `find` method returns the byte index of the substring.
// We count the number of chars up to that byte index.
T::Native::from_usize(
string
.find(substring)
.map(|x| string[..x].chars().count() + 1)
Expand All @@ -132,20 +161,21 @@ where
}
_ => None,
})
.collect::<PrimitiveArray<T0>>();
.collect::<PrimitiveArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}

#[cfg(test)]
mod test {
use super::*;
mod tests {
use arrow::array::{Array, Int32Array, Int64Array};
use arrow::datatypes::DataType::{Int32, Int64};

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

use crate::unicode::strpos::StrposFunc;
use crate::utils::test::test_function;
use arrow::{
array::{Array as _, Int32Array, Int64Array},
datatypes::DataType::{Int32, Int64},
};
use datafusion_common::ScalarValue;

macro_rules! test_strpos {
($lhs:literal, $rhs:literal -> $result:literal; $t1:ident $t2:ident $t3:ident $t4:ident $t5:ident) => {
Expand All @@ -164,21 +194,54 @@ mod test {
}

#[test]
fn strpos() {
test_strpos!("foo", "bar" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("foobar", "foo" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("foobar", "bar" -> 4; Utf8 Utf8 i32 Int32 Int32Array);

test_strpos!("foo", "bar" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("foobar", "foo" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("foobar", "bar" -> 4; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

test_strpos!("foo", "bar" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("foobar", "foo" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("foobar", "bar" -> 4; Utf8 LargeUtf8 i32 Int32 Int32Array);

test_strpos!("foo", "bar" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("foobar", "foo" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("foobar", "bar" -> 4; LargeUtf8 Utf8 i64 Int64 Int64Array);
fn test_strpos_functions() {
// Utf8 and Utf8 combinations
test_strpos!("alphabet", "ph" -> 3; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8 Utf8 i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8 Utf8 i32 Int32 Int32Array);

// LargeUtf8 and LargeUtf8 combinations
test_strpos!("alphabet", "ph" -> 3; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "a" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "z" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "" -> 1; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);
test_strpos!("", "a" -> 0; LargeUtf8 LargeUtf8 i64 Int64 Int64Array);

// Utf8 and LargeUtf8 combinations
test_strpos!("alphabet", "ph" -> 3; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8 LargeUtf8 i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8 LargeUtf8 i32 Int32 Int32Array);

// LargeUtf8 and Utf8 combinations
test_strpos!("alphabet", "ph" -> 3; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "a" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "z" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("alphabet", "" -> 1; LargeUtf8 Utf8 i64 Int64 Int64Array);
test_strpos!("", "a" -> 0; LargeUtf8 Utf8 i64 Int64 Int64Array);

// Utf8View and Utf8View combinations
test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8View Utf8View i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8View Utf8View i32 Int32 Int32Array);

// Utf8View and Utf8 combinations
test_strpos!("alphabet", "ph" -> 3; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8View Utf8 i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8View Utf8 i32 Int32 Int32Array);

// Utf8View and LargeUtf8 combinations
test_strpos!("alphabet", "ph" -> 3; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "a" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "z" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("alphabet", "" -> 1; Utf8View LargeUtf8 i32 Int32 Int32Array);
test_strpos!("", "a" -> 0; Utf8View LargeUtf8 i32 Int32 Int32Array);
}
}
5 changes: 2 additions & 3 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1066,9 +1066,8 @@ EXPLAIN SELECT
FROM test;
----
logical_plan
01)Projection: strpos(__common_expr_1, Utf8("f")) AS c, strpos(__common_expr_1, CAST(test.column2_utf8view AS Utf8)) AS c2
02)--Projection: CAST(test.column1_utf8view AS Utf8) AS __common_expr_1, test.column2_utf8view
03)----TableScan: test projection=[column1_utf8view, column2_utf8view]
01)Projection: strpos(test.column1_utf8view, Utf8("f")) AS c, strpos(test.column1_utf8view, test.column2_utf8view) AS c2
02)--TableScan: test projection=[column1_utf8view, column2_utf8view]

## Ensure no casts for SUBSTR
## TODO file ticket
Expand Down

0 comments on commit c6be00d

Please sign in to comment.