From 50359b1d75feddb67044243737ea7ac621be5b39 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sat, 17 Aug 2024 15:39:50 -0400 Subject: [PATCH] Updated to use new StringArrays enum. --- datafusion/functions/src/unicode/lpad.rs | 74 ++++-------------------- datafusion/functions/src/unicode/rpad.rs | 64 ++++++-------------- 2 files changed, 29 insertions(+), 109 deletions(-) diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 50dc06cdbfb5..b0f3636c36c8 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -19,10 +19,7 @@ use std::any::Any; use std::fmt::Write; use std::sync::Arc; -use arrow::array::{ - Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array, - OffsetSizeTrait, StringViewArray, -}; +use arrow::array::{Array, ArrayRef, GenericStringBuilder, Int64Array, OffsetSizeTrait}; use arrow::datatypes::DataType; use unicode_segmentation::UnicodeSegmentation; use DataType::{LargeUtf8, Utf8, Utf8View}; @@ -32,7 +29,7 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; -use crate::utils::{make_scalar_function, utf8_to_str_type, StringArrayType}; +use crate::utils::{make_scalar_function, utf8_to_str_type, Iter, StringArrays}; #[derive(Debug)] pub struct LPadFunc { @@ -110,70 +107,23 @@ pub fn lpad(args: &[ArrayRef]) -> Result { let length_array = as_int64_array(&args[1])?; match (args.len(), args[0].data_type()) { - (2, Utf8View) => lpad_impl::<&StringViewArray, &GenericStringArray, T>( - args[0].as_string_view(), - length_array, - None, - ), - (2, Utf8 | LargeUtf8) => lpad_impl::< - &GenericStringArray, - &GenericStringArray, - T, - >(args[0].as_string::(), length_array, None), - (3, Utf8View) => lpad_with_replace::<&StringViewArray, T>( - args[0].as_string_view(), - length_array, - &args[2], - ), - (3, Utf8 | LargeUtf8) => lpad_with_replace::<&GenericStringArray, T>( - args[0].as_string::(), + (2, Utf8View | Utf8 | LargeUtf8) => { + lpad_impl::(StringArrays::try_from(&args[0])?, length_array, None) + } + (3, Utf8View | Utf8 | LargeUtf8) => lpad_impl::( + StringArrays::try_from(&args[0])?, length_array, - &args[2], + Some(StringArrays::try_from(&args[2])?), ), (_, _) => unreachable!(), } } -fn lpad_with_replace<'a, V, T: OffsetSizeTrait>( - string_array: V, - length_array: &Int64Array, - fill_array: &'a ArrayRef, -) -> Result -where - V: StringArrayType<'a>, -{ - match fill_array.data_type() { - Utf8View => lpad_impl::( - string_array, - length_array, - Some(fill_array.as_string_view()), - ), - LargeUtf8 => lpad_impl::, T>( - string_array, - length_array, - Some(fill_array.as_string::()), - ), - Utf8 => lpad_impl::, T>( - string_array, - length_array, - Some(fill_array.as_string::()), - ), - other => { - exec_err!("Unsupported data type {other:?} for function lpad") - } - } -} - -fn lpad_impl<'a, V, V2, T>( - string_array: V, +fn lpad_impl<'a, T: OffsetSizeTrait>( + string_array: StringArrays, length_array: &Int64Array, - fill_array: Option, -) -> Result -where - V: StringArrayType<'a>, - V2: StringArrayType<'a>, - T: OffsetSizeTrait, -{ + fill_array: Option, +) -> Result { let array = if fill_array.is_none() { let mut builder: GenericStringBuilder = GenericStringBuilder::new(); diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 4bcf102c8793..31a59ac2ad39 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -20,12 +20,11 @@ use std::sync::Arc; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; -use datafusion_common::cast::{ - as_generic_string_array, as_int64_array, as_string_view_array, -}; +use DataType::*; +use datafusion_common::cast::as_int64_array; use unicode_segmentation::UnicodeSegmentation; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::{make_scalar_function, utf8_to_str_type, Iter, StringArrays}; use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -86,26 +85,26 @@ impl ScalarUDFImpl for RPadFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args.len() { 2 => match args[0].data_type() { - DataType::Utf8 | DataType::Utf8View => { + Utf8 | Utf8View => { make_scalar_function(rpad::, vec![])(args) } - DataType::LargeUtf8 => { + LargeUtf8 => { make_scalar_function(rpad::, vec![])(args) } other => exec_err!("Unsupported data type {other:?} for function rpad"), }, 3 => match (args[0].data_type(), args[2].data_type()) { ( - DataType::Utf8 | DataType::Utf8View, - DataType::Utf8 | DataType::Utf8View, + Utf8 | Utf8View, + Utf8 | Utf8View, ) => make_scalar_function(rpad::, vec![])(args), - (DataType::LargeUtf8, DataType::LargeUtf8) => { + (LargeUtf8, LargeUtf8) => { make_scalar_function(rpad::, vec![])(args) } - (DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => { + (LargeUtf8, Utf8View | Utf8) => { make_scalar_function(rpad::, vec![])(args) } - (DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => { + (Utf8View | Utf8, LargeUtf8) => { make_scalar_function(rpad::, vec![])(args) } (first_type, last_type) => { @@ -185,56 +184,27 @@ macro_rules! process_rpad { }}; } -/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. +/// Extends the string to length 'length' by appending the characters fill (a space by default). +/// If the string is already longer than length then it is truncated. /// rpad('hi', 5, 'xy') = 'hixyx' pub fn rpad( args: &[ArrayRef], ) -> Result { match (args.len(), args[0].data_type()) { - (2, DataType::Utf8View) => { - let string_array = as_string_view_array(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = process_rpad!(string_array, length_array)?; - Ok(Arc::new(result) as ArrayRef) - } (2, _) => { - let string_array = as_generic_string_array::(&args[0])?; + let string_array = StringArrays::try_from(&args[0])?; let length_array = as_int64_array(&args[1])?; let result = process_rpad!(string_array, length_array)?; Ok(Arc::new(result) as ArrayRef) } - (3, DataType::Utf8View) => { - let string_array = as_string_view_array(&args[0])?; - let length_array = as_int64_array(&args[1])?; - match args[2].data_type() { - DataType::Utf8View => { - let fill_array = as_string_view_array(&args[2])?; - let result = process_rpad!(string_array, length_array, fill_array)?; - Ok(Arc::new(result) as ArrayRef) - } - DataType::Utf8 | DataType::LargeUtf8 => { - let fill_array = as_generic_string_array::(&args[2])?; - let result = process_rpad!(string_array, length_array, fill_array)?; - Ok(Arc::new(result) as ArrayRef) - } - other_type => { - exec_err!("unsupported type for rpad's third operator: {}", other_type) - } - } - } (3, _) => { - let string_array = as_generic_string_array::(&args[0])?; + let string_array = StringArrays::try_from(&args[0])?; let length_array = as_int64_array(&args[1])?; + match args[2].data_type() { - DataType::Utf8View => { - let fill_array = as_string_view_array(&args[2])?; - let result = process_rpad!(string_array, length_array, fill_array)?; - Ok(Arc::new(result) as ArrayRef) - } - DataType::Utf8 | DataType::LargeUtf8 => { - let fill_array = as_generic_string_array::(&args[2])?; + Utf8View | Utf8 | LargeUtf8 => { + let fill_array = StringArrays::try_from(&args[2])?; let result = process_rpad!(string_array, length_array, fill_array)?; Ok(Arc::new(result) as ArrayRef) }