Skip to content

Commit

Permalink
Updated to use new StringArrays enum.
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 committed Aug 17, 2024
1 parent f01c084 commit 50359b1
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 109 deletions.
74 changes: 12 additions & 62 deletions datafusion/functions/src/unicode/lpad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@ use std::any::Any;
use std::fmt::Write;
use std::sync::Arc;

use arrow::array::{
Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
OffsetSizeTrait, StringViewArray,
};
use arrow::array::{Array, ArrayRef, GenericStringBuilder, Int64Array, OffsetSizeTrait};
use arrow::datatypes::DataType;
use unicode_segmentation::UnicodeSegmentation;
use DataType::{LargeUtf8, Utf8, Utf8View};
Expand All @@ -32,7 +29,7 @@ use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use crate::utils::{make_scalar_function, utf8_to_str_type, StringArrayType};
use crate::utils::{make_scalar_function, utf8_to_str_type, Iter, StringArrays};

#[derive(Debug)]
pub struct LPadFunc {
Expand Down Expand Up @@ -110,70 +107,23 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let length_array = as_int64_array(&args[1])?;

match (args.len(), args[0].data_type()) {
(2, Utf8View) => lpad_impl::<&StringViewArray, &GenericStringArray<i32>, T>(
args[0].as_string_view(),
length_array,
None,
),
(2, Utf8 | LargeUtf8) => lpad_impl::<
&GenericStringArray<T>,
&GenericStringArray<T>,
T,
>(args[0].as_string::<T>(), length_array, None),
(3, Utf8View) => lpad_with_replace::<&StringViewArray, T>(
args[0].as_string_view(),
length_array,
&args[2],
),
(3, Utf8 | LargeUtf8) => lpad_with_replace::<&GenericStringArray<T>, T>(
args[0].as_string::<T>(),
(2, Utf8View | Utf8 | LargeUtf8) => {
lpad_impl::<T>(StringArrays::try_from(&args[0])?, length_array, None)
}
(3, Utf8View | Utf8 | LargeUtf8) => lpad_impl::<T>(
StringArrays::try_from(&args[0])?,
length_array,
&args[2],
Some(StringArrays::try_from(&args[2])?),
),
(_, _) => unreachable!(),
}
}

fn lpad_with_replace<'a, V, T: OffsetSizeTrait>(
string_array: V,
length_array: &Int64Array,
fill_array: &'a ArrayRef,
) -> Result<ArrayRef>
where
V: StringArrayType<'a>,
{
match fill_array.data_type() {
Utf8View => lpad_impl::<V, &StringViewArray, T>(
string_array,
length_array,
Some(fill_array.as_string_view()),
),
LargeUtf8 => lpad_impl::<V, &GenericStringArray<i64>, T>(
string_array,
length_array,
Some(fill_array.as_string::<i64>()),
),
Utf8 => lpad_impl::<V, &GenericStringArray<i32>, T>(
string_array,
length_array,
Some(fill_array.as_string::<i32>()),
),
other => {
exec_err!("Unsupported data type {other:?} for function lpad")
}
}
}

fn lpad_impl<'a, V, V2, T>(
string_array: V,
fn lpad_impl<'a, T: OffsetSizeTrait>(
string_array: StringArrays,
length_array: &Int64Array,
fill_array: Option<V2>,
) -> Result<ArrayRef>
where
V: StringArrayType<'a>,
V2: StringArrayType<'a>,
T: OffsetSizeTrait,
{
fill_array: Option<StringArrays>,
) -> Result<ArrayRef> {
let array = if fill_array.is_none() {
let mut builder: GenericStringBuilder<T> = GenericStringBuilder::new();

Expand Down
64 changes: 17 additions & 47 deletions datafusion/functions/src/unicode/rpad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
use arrow::datatypes::DataType;
use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_string_view_array,
};
use DataType::*;
use datafusion_common::cast::as_int64_array;
use unicode_segmentation::UnicodeSegmentation;

use crate::utils::{make_scalar_function, utf8_to_str_type};
use crate::utils::{make_scalar_function, utf8_to_str_type, Iter, StringArrays};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
Expand Down Expand Up @@ -86,26 +85,26 @@ impl ScalarUDFImpl for RPadFunc {
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args.len() {
2 => match args[0].data_type() {
DataType::Utf8 | DataType::Utf8View => {
Utf8 | Utf8View => {
make_scalar_function(rpad::<i32, i32>, vec![])(args)
}
DataType::LargeUtf8 => {
LargeUtf8 => {
make_scalar_function(rpad::<i64, i64>, vec![])(args)
}
other => exec_err!("Unsupported data type {other:?} for function rpad"),
},
3 => match (args[0].data_type(), args[2].data_type()) {
(
DataType::Utf8 | DataType::Utf8View,
DataType::Utf8 | DataType::Utf8View,
Utf8 | Utf8View,
Utf8 | Utf8View,
) => make_scalar_function(rpad::<i32, i32>, vec![])(args),
(DataType::LargeUtf8, DataType::LargeUtf8) => {
(LargeUtf8, LargeUtf8) => {
make_scalar_function(rpad::<i64, i64>, vec![])(args)
}
(DataType::LargeUtf8, DataType::Utf8View | DataType::Utf8) => {
(LargeUtf8, Utf8View | Utf8) => {
make_scalar_function(rpad::<i64, i32>, vec![])(args)
}
(DataType::Utf8View | DataType::Utf8, DataType::LargeUtf8) => {
(Utf8View | Utf8, LargeUtf8) => {
make_scalar_function(rpad::<i32, i64>, vec![])(args)
}
(first_type, last_type) => {
Expand Down Expand Up @@ -185,56 +184,27 @@ macro_rules! process_rpad {
}};
}

/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
/// Extends the string to length 'length' by appending the characters fill (a space by default).
/// If the string is already longer than length then it is truncated.
/// rpad('hi', 5, 'xy') = 'hixyx'
pub fn rpad<StringArrayLen: OffsetSizeTrait, FillArrayLen: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
match (args.len(), args[0].data_type()) {
(2, DataType::Utf8View) => {
let string_array = as_string_view_array(&args[0])?;
let length_array = as_int64_array(&args[1])?;

let result = process_rpad!(string_array, length_array)?;
Ok(Arc::new(result) as ArrayRef)
}
(2, _) => {
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
let string_array = StringArrays::try_from(&args[0])?;
let length_array = as_int64_array(&args[1])?;

let result = process_rpad!(string_array, length_array)?;
Ok(Arc::new(result) as ArrayRef)
}
(3, DataType::Utf8View) => {
let string_array = as_string_view_array(&args[0])?;
let length_array = as_int64_array(&args[1])?;
match args[2].data_type() {
DataType::Utf8View => {
let fill_array = as_string_view_array(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
DataType::Utf8 | DataType::LargeUtf8 => {
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
other_type => {
exec_err!("unsupported type for rpad's third operator: {}", other_type)
}
}
}
(3, _) => {
let string_array = as_generic_string_array::<StringArrayLen>(&args[0])?;
let string_array = StringArrays::try_from(&args[0])?;
let length_array = as_int64_array(&args[1])?;

match args[2].data_type() {
DataType::Utf8View => {
let fill_array = as_string_view_array(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
DataType::Utf8 | DataType::LargeUtf8 => {
let fill_array = as_generic_string_array::<FillArrayLen>(&args[2])?;
Utf8View | Utf8 | LargeUtf8 => {
let fill_array = StringArrays::try_from(&args[2])?;
let result = process_rpad!(string_array, length_array, fill_array)?;
Ok(Arc::new(result) as ArrayRef)
}
Expand Down

0 comments on commit 50359b1

Please sign in to comment.