diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index ce5e0064362b..5caa6acd6745 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -18,16 +18,21 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, Int64Array, + OffsetSizeTrait, StringViewArray, +}; use arrow::datatypes::DataType; -use datafusion_common::cast::{as_generic_string_array, as_int64_array}; use unicode_segmentation::UnicodeSegmentation; +use DataType::{LargeUtf8, Utf8, Utf8View}; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::cast::as_int64_array; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use crate::utils::{make_scalar_function, utf8_to_str_type}; + #[derive(Debug)] pub struct LPadFunc { signature: Signature, @@ -45,11 +50,17 @@ impl LPadFunc { Self { signature: Signature::one_of( vec![ + Exact(vec![Utf8View, Int64]), + Exact(vec![Utf8View, Int64, Utf8View]), + Exact(vec![Utf8View, Int64, Utf8]), + Exact(vec![Utf8View, Int64, LargeUtf8]), Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8View]), Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![LargeUtf8, Int64, Utf8View]), + Exact(vec![LargeUtf8, Int64, Utf8]), Exact(vec![LargeUtf8, Int64, LargeUtf8]), ], Volatility::Immutable, @@ -76,300 +87,450 @@ impl ScalarUDFImpl for LPadFunc { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - match args[0].data_type() { - DataType::Utf8 => make_scalar_function(lpad::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(lpad::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function lpad"), - } + make_scalar_function(lpad, vec![])(args) } } -/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). +/// Extends the string to length 'length' by prepending the characters fill (a space by default). +/// If the string is already longer than length then it is truncated (on the right). /// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } +pub fn lpad(args: &[ArrayRef]) -> Result { + if args.len() <= 1 || args.len() > 3 { + return exec_err!( + "lpad was called with {} arguments. It requires at least 2 and at most 3.", + args.len() + ); + } + + let length_array = as_int64_array(&args[1])?; + + match args[0].data_type() { + Utf8 => match args.len() { + 2 => lpad_impl::<&GenericStringArray, &GenericStringArray, i32>( + args[0].as_string::(), + length_array, + None, + ), + 3 => lpad_with_replace::<&GenericStringArray, i32>( + args[0].as_string::(), + length_array, + &args[2], + ), + _ => unreachable!(), + }, + LargeUtf8 => match args.len() { + 2 => lpad_impl::<&GenericStringArray, &GenericStringArray, i64>( + args[0].as_string::(), + length_array, + None, + ), + 3 => lpad_with_replace::<&GenericStringArray, i64>( + args[0].as_string::(), + length_array, + &args[2], + ), + _ => unreachable!(), + }, + Utf8View => match args.len() { + 2 => lpad_impl::<&StringViewArray, &GenericStringArray, i32>( + args[0].as_string_view(), + length_array, + None, + ), + 3 => lpad_with_replace::<&StringViewArray, i32>( + args[0].as_string_view(), + length_array, + &args[2], + ), + _ => unreachable!(), + }, + other => { + exec_err!("Unsupported data type {other:?} for function lpad") + } + } +} - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) +fn lpad_with_replace<'a, V, T: OffsetSizeTrait>( + string_array: V, + length_array: &Int64Array, + fill_array: &'a ArrayRef, +) -> Result +where + V: StringArrayType<'a>, +{ + match fill_array.data_type() { + Utf8 => lpad_impl::, T>( + string_array, + length_array, + Some(fill_array.as_string::()), + ), + LargeUtf8 => lpad_impl::, T>( + string_array, + length_array, + Some(fill_array.as_string::()), + ), + Utf8View => lpad_impl::( + string_array, + length_array, + Some(fill_array.as_string_view()), + ), + other => { + exec_err!("Unsupported data type {other:?} for function lpad") + } + } +} + +fn lpad_impl<'a, V, V2, T>( + string_array: V, + length_array: &Int64Array, + fill_array: Option, +) -> Result +where + V: StringArrayType<'a>, + V2: StringArrayType<'a>, + T: OffsetSizeTrait, +{ + if fill_array.is_none() { + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - Ok(Some(s)) - } + let mut s: String = " ".repeat(length - graphemes.len()); + s.push_str(string); + Ok(Some(s)) } } - _ => Ok(None), - }) - .collect::>>()?; + } + _ => Ok(None), + }) + .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } + Ok(Arc::new(result) as ArrayRef) + } else { + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.unwrap().iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!("lpad requested length {length} too large"); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector.push( - *fill_chars.get(l % fill_chars.len()).unwrap(), - ); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Ok(Some(s)) + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector + .push(*fill_chars.get(l % fill_chars.len()).unwrap()); } + s.insert_str( + 0, + char_vector.iter().collect::().as_str(), + ); + Ok(Some(s)) } } - _ => Ok(None), - }) - .collect::>>()?; + } + _ => Ok(None), + }) + .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "lpad was called with {other} arguments. It requires at least 2 and at most 3." - ), + Ok(Arc::new(result) as ArrayRef) + } +} + +trait StringArrayType<'a>: ArrayAccessor + Sized { + fn iter(&self) -> ArrayIter; +} +impl<'a, O: OffsetSizeTrait> StringArrayType<'a> for &'a GenericStringArray { + fn iter(&self) -> ArrayIter { + GenericStringArray::::iter(self) + } +} +impl<'a> StringArrayType<'a> for &'a StringViewArray { + fn iter(&self) -> ArrayIter { + StringViewArray::iter(self) } } #[cfg(test)] mod tests { - use arrow::array::{Array, StringArray}; - use arrow::datatypes::DataType::Utf8; + use crate::unicode::lpad::LPadFunc; + use crate::utils::test::test_function; + + use arrow::array::{Array, LargeStringArray, StringArray}; + use arrow::datatypes::DataType::{LargeUtf8, Utf8}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::unicode::lpad::LPadFunc; - use crate::utils::test::test_function; + macro_rules! test_lpad { + ($INPUT:expr, $LENGTH:expr, $EXPECTED:expr) => { + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + }; + + ($INPUT:expr, $LENGTH:expr, $REPLACE:expr, $EXPECTED:expr) => { + // utf8, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + + // largeutf8, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + // largeutf8, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + // largeutf8, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + LargeUtf8, + LargeStringArray + ); + + // utf8view, utf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8view, largeutf8 + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + // utf8view, utf8view + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), + ColumnarValue::Scalar($LENGTH), + ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) + ], + $EXPECTED, + &str, + Utf8, + StringArray + ); + }; + } #[test] fn test_functions() -> Result<()> { - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some(" josé")), - &str, - Utf8, - StringArray - ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), - ], - Ok(Some("")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(5i64)), + Ok(Some(" josé")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Ok(Some(" hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(0i64)), + Ok(Some("")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(Some("xyxhi")), - &str, - Utf8, - StringArray + test_lpad!(Some("hi".into()), ScalarValue::Int64(None), Ok(None)); + test_lpad!(None, ScalarValue::Int64(Some(5i64)), Ok(None)); + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some("xy".into()), + Ok(Some("xyxhi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(21i64)), - ColumnarValue::Scalar(ScalarValue::from("abcdef")), - ], - Ok(Some("abcdefabcdefabcdefahi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(21i64)), + Some("abcdef".into()), + Ok(Some("abcdefabcdefabcdefahi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from(" ")), - ], - Ok(Some(" hi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some(" ".into()), + Ok(Some(" hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("")), - ], - Ok(Some("hi")), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + Some("".into()), + Ok(Some("hi")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + None, + ScalarValue::Int64(Some(5i64)), + Some("xy".into()), + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(None), + Some("xy".into()), + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray + test_lpad!( + Some("hi".into()), + ScalarValue::Int64(Some(5i64)), + None, + Ok(None) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), - ], - Ok(Some("xyxyxyjosé")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(10i64)), + Some("xy".into()), + Ok(Some("xyxyxyjosé")) ); - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("éñ")), - ], - Ok(Some("éñéñéñjosé")), - &str, - Utf8, - StringArray + test_lpad!( + Some("josé".into()), + ScalarValue::Int64(Some(10i64)), + Some("éñ".into()), + Ok(Some("éñéñéñjosé")) ); + #[cfg(not(feature = "unicode_expressions"))] - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - internal_err!( + test_lpad!(Some("josé".into()), ScalarValue::Int64(Some(5i64)), internal_err!( "function lpad requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); + )); + Ok(()) } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index c3dd791f6ca8..8a4855ea2c05 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -203,6 +203,32 @@ SELECT lpad(NULL, 5, 'xy') ---- NULL +# test largeutf8, utf8view for lpad +query T +SELECT lpad(arrow_cast('hi', 'LargeUtf8'), 5, 'xy') +---- +xyxhi + +query T +SELECT lpad(arrow_cast('hi', 'Utf8View'), 5, 'xy') +---- +xyxhi + +query T +SELECT lpad(arrow_cast('hi', 'LargeUtf8'), 5, arrow_cast('xy', 'LargeUtf8')) +---- +xyxhi + +query T +SELECT lpad(arrow_cast('hi', 'Utf8View'), 5, arrow_cast('xy', 'Utf8View')) +---- +xyxhi + +query T +SELECT lpad(arrow_cast(NULL, 'Utf8View'), 5, 'xy') +---- +NULL + query T SELECT reverse('abcde') ---- diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index e7166690580f..dcc6784bf44a 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -634,16 +634,32 @@ logical_plan 02)--TableScan: test projection=[column1_utf8view] ## Ensure no casts for LPAD -## TODO https://github.com/apache/datafusion/issues/11857 query TT EXPLAIN SELECT LPAD(column1_utf8view, 12, ' ') as c1 FROM test; ---- logical_plan -01)Projection: lpad(CAST(test.column1_utf8view AS Utf8), Int64(12), Utf8(" ")) AS c1 +01)Projection: lpad(test.column1_utf8view, Int64(12), Utf8(" ")) AS c1 02)--TableScan: test projection=[column1_utf8view] +query TT +EXPLAIN SELECT + LPAD(column1_utf8view, 12, column2_large_utf8) as c1 +FROM test; +---- +logical_plan +01)Projection: lpad(test.column1_utf8view, Int64(12), test.column2_large_utf8) AS c1 +02)--TableScan: test projection=[column2_large_utf8, column1_utf8view] + +query TT +EXPLAIN SELECT + LPAD(column1_utf8view, 12, column2_utf8view) as c1 +FROM test; +---- +logical_plan +01)Projection: lpad(test.column1_utf8view, Int64(12), test.column2_utf8view) AS c1 +02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for OCTET_LENGTH ## TODO https://github.com/apache/datafusion/issues/11858