Skip to content

Commit

Permalink
Move lower, octet_length to datafusion-functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 committed Mar 23, 2024
1 parent ddd0627 commit 36a5480
Show file tree
Hide file tree
Showing 13 changed files with 330 additions and 181 deletions.
16 changes: 0 additions & 16 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ pub enum BuiltinScalarFunction {
Left,
/// lpad
Lpad,
/// lower
Lower,
/// octet_length
OctetLength,
/// random
Random,
/// repeat
Expand Down Expand Up @@ -247,8 +243,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::InitCap => Volatility::Immutable,
BuiltinScalarFunction::Left => Volatility::Immutable,
BuiltinScalarFunction::Lpad => Volatility::Immutable,
BuiltinScalarFunction::Lower => Volatility::Immutable,
BuiltinScalarFunction::OctetLength => Volatility::Immutable,
BuiltinScalarFunction::Radians => Volatility::Immutable,
BuiltinScalarFunction::Repeat => Volatility::Immutable,
BuiltinScalarFunction::Replace => Volatility::Immutable,
Expand Down Expand Up @@ -305,13 +299,7 @@ impl BuiltinScalarFunction {
utf8_to_str_type(&input_expr_types[0], "initcap")
}
BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"),
BuiltinScalarFunction::Lower => {
utf8_to_str_type(&input_expr_types[0], "lower")
}
BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"),
BuiltinScalarFunction::OctetLength => {
utf8_to_int_type(&input_expr_types[0], "octet_length")
}
BuiltinScalarFunction::Pi => Ok(Float64),
BuiltinScalarFunction::Random => Ok(Float64),
BuiltinScalarFunction::Uuid => Ok(Utf8),
Expand Down Expand Up @@ -428,8 +416,6 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::BitLength
| BuiltinScalarFunction::CharacterLength
| BuiltinScalarFunction::InitCap
| BuiltinScalarFunction::Lower
| BuiltinScalarFunction::OctetLength
| BuiltinScalarFunction::Reverse => {
Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility())
}
Expand Down Expand Up @@ -682,9 +668,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::EndsWith => &["ends_with"],
BuiltinScalarFunction::InitCap => &["initcap"],
BuiltinScalarFunction::Left => &["left"],
BuiltinScalarFunction::Lower => &["lower"],
BuiltinScalarFunction::Lpad => &["lpad"],
BuiltinScalarFunction::OctetLength => &["octet_length"],
BuiltinScalarFunction::Repeat => &["repeat"],
BuiltinScalarFunction::Replace => &["replace"],
BuiltinScalarFunction::Reverse => &["reverse"],
Expand Down
9 changes: 0 additions & 9 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,13 +599,6 @@ scalar_expr!(
);
scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase");
scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`");
scalar_expr!(Lower, lower, string, "convert the string to lower case");
scalar_expr!(
OctetLength,
octet_length,
string,
"returns the number of bytes of a string"
);
scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `from` with `to` in the `string`");
scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times");
scalar_expr!(Reverse, reverse, string, "reverses the `string`");
Expand Down Expand Up @@ -1069,10 +1062,8 @@ mod test {
test_scalar_expr!(Lcm, lcm, arg_1, arg_2);
test_scalar_expr!(InitCap, initcap, string);
test_scalar_expr!(Left, left, string, count);
test_scalar_expr!(Lower, lower, string);
test_nary_scalar_expr!(Lpad, lpad, string, count);
test_nary_scalar_expr!(Lpad, lpad, string, count, characters);
test_scalar_expr!(OctetLength, octet_length, string);
test_scalar_expr!(Replace, replace, string, from, to);
test_scalar_expr!(Repeat, repeat, string, count);
test_scalar_expr!(Reverse, reverse, string);
Expand Down
67 changes: 67 additions & 0 deletions datafusion/functions/src/string/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ macro_rules! get_optimal_return_type {
// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);

// `utf8_to_int_type`: returns either an Int32 or an Int64 based on the input type size.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);

/// applies a unary expression to `args[0]` that is expected to be downcastable to
/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset)
/// # Errors
Expand Down Expand Up @@ -263,3 +266,67 @@ where
}
})
}

#[cfg(test)]
pub mod test {
    /// Checks both `return_type` and `invoke` of a scalar function against an
    /// expected outcome.
    ///
    /// * `$FUNC` - the `ScalarUDFImpl` under test
    /// * `$ARGS` - the arguments passed to the function; the expression is
    ///   expanded more than once, so it should be free of side effects
    /// * `$EXPECTED` - a `Result<Option<$EXPECTED_TYPE>>`: `Ok(Some(v))` for a
    ///   value, `Ok(None)` for a null result, `Err(_)` when an error is expected
    /// * `$EXPECTED_TYPE` - the Rust type of the expected value
    /// * `$EXPECTED_DATA_TYPE` - the data type `return_type` must report
    /// * `$ARRAY_TYPE` - the concrete array type the result is downcast to
    ///
    /// Only row 0 of the resulting array is inspected.
    ///
    /// `Result` and `ColumnarValue` are referenced unqualified, so both must be
    /// in scope wherever the macro is invoked.
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;

            let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            let return_type = func.return_type(&type_array);

            match expected {
                Ok(expected) => {
                    // `expect` surfaces the underlying error in the failure
                    // message instead of a bare "false != true".
                    let return_type = return_type.expect("return_type should succeed");
                    assert_eq!(return_type, $EXPECTED_DATA_TYPE);

                    let result = func.invoke($ARGS).expect("invoke should succeed");

                    // Scalars carry no length; use the length of the last array
                    // argument, or 1 when every argument is a scalar.
                    let len = $ARGS
                        .iter()
                        .fold(Option::<usize>::None, |acc, arg| match arg {
                            ColumnarValue::Scalar(_) => acc,
                            ColumnarValue::Array(a) => Some(a.len()),
                        });
                    let inferred_length = len.unwrap_or(1);
                    let result = result
                        .into_array(inferred_length)
                        .expect("Failed to convert to array");
                    let result = result
                        .as_any()
                        .downcast_ref::<$ARRAY_TYPE>()
                        .expect("Failed to convert to type");

                    // value is correct
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => match return_type {
                    // The arguments were already rejected while resolving the return type.
                    Err(error) => {
                        datafusion_common::assert_contains!(
                            expected_error.strip_backtrace(),
                            error.strip_backtrace()
                        );
                    }
                    // invoke is expected error - cannot use .expect_err() due to Debug not being implemented
                    Ok(_) => match func.invoke($ARGS) {
                        Ok(_) => panic!("expected error"),
                        Err(error) => {
                            assert!(expected_error
                                .strip_backtrace()
                                .starts_with(&error.strip_backtrace()));
                        }
                    },
                },
            };
        };
    }

    pub(crate) use test_function;
}
63 changes: 63 additions & 0 deletions datafusion/functions/src/string/lower.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::string::common::{handle, utf8_to_str_type};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;

/// Scalar UDF backing the `lower` string function.
#[derive(Debug)]
pub(super) struct LowerFunc {
    /// Accepted argument types and volatility, built once in [`LowerFunc::new`].
    signature: Signature,
}

impl Default for LowerFunc {
    /// Equivalent to [`LowerFunc::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl LowerFunc {
    /// Creates the function with an immutable, uniform signature taking one
    /// argument of type `Utf8` or `LargeUtf8`.
    pub fn new() -> Self {
        use DataType::*;
        Self {
            signature: Signature::uniform(
                1,
                vec![Utf8, LargeUtf8],
                Volatility::Immutable,
            ),
        }
    }
}

impl ScalarUDFImpl for LowerFunc {
    /// Returns `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The name this function is registered under.
    fn name(&self) -> &str {
        "lower"
    }

    /// The signature built in `LowerFunc::new`.
    fn signature(&self) -> &Signature {
        &self.signature
    }

    /// Derives the output string type from the first argument's type via
    /// `utf8_to_str_type`.
    ///
    /// Indexes `arg_types[0]` directly, so an empty slice panics; presumably
    /// callers validate arguments against `signature()` first (TODO confirm).
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
        let input_type = &arg_types[0];
        utf8_to_str_type(input_type, "lower")
    }

    /// Lower-cases the string argument by passing `str::to_lowercase` to the
    /// shared `handle` helper.
    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
        handle(
            args,
            |value| value.to_lowercase(),
            "lower",
        )
    }
}
16 changes: 16 additions & 0 deletions datafusion/functions/src/string/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ use datafusion_expr::ScalarUDF;
mod ascii;
mod btrim;
mod common;
mod lower;
mod ltrim;
mod octet_length;
mod rtrim;
mod starts_with;
mod to_hex;
Expand All @@ -34,6 +36,8 @@ mod upper;
make_udf_function!(ascii::AsciiFunc, ASCII, ascii);
make_udf_function!(btrim::BTrimFunc, BTRIM, btrim);
make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim);
make_udf_function!(lower::LowerFunc, LOWER, lower);
make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length);
make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim);
make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with);
make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex);
Expand All @@ -52,11 +56,21 @@ pub mod expr_fn {
super::btrim().call(args)
}

/// Converts a string to lowercase.
pub fn lower(arg1: Expr) -> Expr {
    let udf = super::lower();
    udf.call(vec![arg1])
}

/// Removes all characters, spaces by default, from the beginning of a string
pub fn ltrim(args: Vec<Expr>) -> Expr {
    let udf = super::ltrim();
    udf.call(args)
}

/// Returns the number of bytes of a string.
pub fn octet_length(args: Vec<Expr>) -> Expr {
    // NOTE(review): `args` is forwarded unchanged; presumably it must contain
    // exactly one string expression - confirm against `OctetLengthFunc`'s signature.
    let udf = super::octet_length();
    udf.call(args)
}

#[doc = "Removes all characters, spaces by default, from the end of a string"]
pub fn rtrim(args: Vec<Expr>) -> Expr {
super::rtrim().call(args)
Expand All @@ -83,7 +97,9 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
vec![
ascii(),
btrim(),
lower(),
ltrim(),
octet_length(),
rtrim(),
starts_with(),
to_hex(),
Expand Down
Loading

0 comments on commit 36a5480

Please sign in to comment.