From 2026df418d4977bedb8592f28d08568e166c67b3 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 14 Mar 2024 09:42:01 -0400 Subject: [PATCH 1/9] Fix to_timestamp benchmark --- datafusion/functions/benches/to_timestamp.rs | 173 ++++++++++--------- 1 file changed, 92 insertions(+), 81 deletions(-) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index c83824526442..31d609dee9bc 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -17,97 +17,108 @@ extern crate criterion; +use std::sync::Arc; + +use arrow_array::builder::StringBuilder; +use arrow_array::ArrayRef; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_expr::lit; -use datafusion_functions::expr_fn::to_timestamp; +use datafusion_expr::ColumnarValue; +use datafusion_functions::datetime::to_timestamp; fn criterion_benchmark(c: &mut Criterion) { c.bench_function("to_timestamp_no_formats", |b| { - let inputs = vec![ - lit("1997-01-31T09:26:56.123Z"), - lit("1997-01-31T09:26:56.123-05:00"), - lit("1997-01-31 09:26:56.123-05:00"), - lit("2023-01-01 04:05:06.789 -08"), - lit("1997-01-31T09:26:56.123"), - lit("1997-01-31 09:26:56.123"), - lit("1997-01-31 09:26:56"), - lit("1997-01-31 13:26:56"), - lit("1997-01-31 13:26:56+04:00"), - lit("1997-01-31"), - ]; + let mut inputs = StringBuilder::new(); + inputs.append_value("1997-01-31T09:26:56.123Z"); + inputs.append_value("1997-01-31T09:26:56.123-05:00"); + inputs.append_value("1997-01-31 09:26:56.123-05:00"); + inputs.append_value("2023-01-01 04:05:06.789 -08"); + inputs.append_value("1997-01-31T09:26:56.123"); + inputs.append_value("1997-01-31 09:26:56.123"); + inputs.append_value("1997-01-31 09:26:56"); + inputs.append_value("1997-01-31 13:26:56"); + inputs.append_value("1997-01-31 13:26:56+04:00"); + inputs.append_value("1997-01-31"); + + let string_array = ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef); + b.iter(|| { - for i in inputs.iter() { - black_box(to_timestamp(vec![i.clone()])); - } - }); + black_box( + to_timestamp() + .invoke(&[string_array.clone()]) + .expect("to_timestamp should work on valid values"), + ) + }) }); c.bench_function("to_timestamp_with_formats", |b| { - let mut inputs = vec![]; - let mut format1 = vec![]; - let mut format2 = vec![]; - let mut format3 = vec![]; - - inputs.push(lit("1997-01-31T09:26:56.123Z")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%dT%H:%M:%S%.f%Z")); - - inputs.push(lit("1997-01-31T09:26:56.123-05:00")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%dT%H:%M:%S%.f%z")); - - inputs.push(lit("1997-01-31 09:26:56.123-05:00")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H:%M:%S%.f%Z")); - - inputs.push(lit("2023-01-01 04:05:06.789 -08")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H:%M:%S%.f %#z")); - - inputs.push(lit("1997-01-31T09:26:56.123")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%dT%H:%M:%S%.f")); - - inputs.push(lit("1997-01-31 09:26:56.123")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H:%M:%S%.f")); - - inputs.push(lit("1997-01-31 09:26:56")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H:%M:%S")); - - inputs.push(lit("1997-01-31 092656")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H%M%S")); - - inputs.push(lit("1997-01-31 092656+04:00")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d %H%M%S%:z")); - - inputs.push(lit("Sun Jul 8 00:34:60 2001")); - format1.push(lit("%+")); - format2.push(lit("%c")); - format3.push(lit("%Y-%m-%d 00:00:00")); - + let mut inputs = StringBuilder::new(); + let mut format1_builder = StringBuilder::with_capacity(2, 10); + let mut format2_builder = StringBuilder::with_capacity(2, 10); + let mut format3_builder = StringBuilder::with_capacity(2, 10); + + inputs.append_value("1997-01-31T09:26:56.123Z"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%Z"); + + inputs.append_value("1997-01-31T09:26:56.123-05:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f%z"); + + inputs.append_value("1997-01-31 09:26:56.123-05:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f%Z"); + + inputs.append_value("2023-01-01 04:05:06.789 -08"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f %#z"); + + inputs.append_value("1997-01-31T09:26:56.123"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%dT%H:%M:%S%.f"); + + inputs.append_value("1997-01-31 09:26:56.123"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S%.f"); + + inputs.append_value("1997-01-31 09:26:56"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H:%M:%S"); + + inputs.append_value("1997-01-31 092656"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H%M%S"); + + inputs.append_value("1997-01-31 092656+04:00"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d %H%M%S%:z"); + + inputs.append_value("Sun Jul 8 00:34:60 2001"); + format1_builder.append_value("%+"); + format2_builder.append_value("%c"); + format3_builder.append_value("%Y-%m-%d 00:00:00"); + + let args = [ + ColumnarValue::Array(Arc::new(inputs.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format1_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format2_builder.finish()) as ArrayRef), + ColumnarValue::Array(Arc::new(format3_builder.finish()) as ArrayRef), + ]; b.iter(|| { - inputs.iter().enumerate().for_each(|(idx, i)| { - black_box(to_timestamp(vec![ - i.clone(), - format1.get(idx).unwrap().clone(), - format2.get(idx).unwrap().clone(), - format3.get(idx).unwrap().clone(), - ])); - }) + black_box( + to_timestamp() + .invoke(&args.clone()) + .expect("to_timestamp should work on valid values"), + ) }) }); } From 6a450b4fa2caa523cb42580c912f742dd1a1ed2b Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 18 Mar 2024 10:43:02 -0400 Subject: [PATCH 2/9] Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. --- docs/source/user-guide/example-usage.md | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 1c5c8f49a16a..c5eefbdaf156 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -240,17 +240,11 @@ async fn main() -> datafusion::error::Result<()> { } ``` -Finally, in order to build with the `simd` optimization `cargo nightly` is required. - -```shell -rustup toolchain install nightly -``` - Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally with `native` or at least `avx2`. ```shell -RUSTFLAGS='-C target-cpu=native' cargo +nightly run --release +RUSTFLAGS='-C target-cpu=native' cargo run --release ``` ## Enable backtraces From a94a4f6c3317e8e952d34d968996fbd603cd0c2e Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 22 Mar 2024 22:20:35 -0400 Subject: [PATCH 3/9] Fixed missing trim() function. --- datafusion/functions/src/string/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 63026092f39a..517869a25682 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -72,6 +72,11 @@ pub mod expr_fn { super::to_hex().call(vec![arg1]) } + #[doc = "Removes all characters, spaces by default, from both sides of a string"] + pub fn trim(args: Vec) -> Expr { + super::btrim().call(args) + } + #[doc = "Converts a string to uppercase."] pub fn upper(arg1: Expr) -> Expr { super::upper().call(vec![arg1]) From e3860fa52a6118720d42b74305bc92b2ace58f43 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Wed, 27 Mar 2024 12:08:02 -0400 Subject: [PATCH 4/9] Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function --- datafusion-cli/Cargo.lock | 1 + datafusion/core/Cargo.toml | 1 + .../tests/dataframe/dataframe_functions.rs | 1 + datafusion/expr/src/built_in_function.rs | 14 +- datafusion/expr/src/expr_fn.rs | 8 - datafusion/functions/Cargo.toml | 4 + datafusion/functions/src/lib.rs | 9 + datafusion/functions/src/string/ascii.rs | 2 +- datafusion/functions/src/string/bit_length.rs | 4 +- datafusion/functions/src/string/btrim.rs | 1 + datafusion/functions/src/string/chr.rs | 2 +- datafusion/functions/src/string/common.rs | 158 +--------------- .../functions/src/string/levenshtein.rs | 3 +- datafusion/functions/src/string/lower.rs | 8 +- datafusion/functions/src/string/ltrim.rs | 3 +- .../functions/src/string/octet_length.rs | 13 +- datafusion/functions/src/string/overlay.rs | 2 +- datafusion/functions/src/string/repeat.rs | 4 +- datafusion/functions/src/string/replace.rs | 2 +- datafusion/functions/src/string/rtrim.rs | 1 + datafusion/functions/src/string/split_part.rs | 4 +- .../functions/src/string/starts_with.rs | 9 +- datafusion/functions/src/string/to_hex.rs | 9 +- datafusion/functions/src/string/upper.rs | 3 +- .../functions/src/unicode/character_length.rs | 176 +++++++++++++++++ datafusion/functions/src/unicode/mod.rs | 55 ++++++ datafusion/functions/src/utils.rs | 178 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 70 ------- .../physical-expr/src/unicode_expressions.rs | 23 --- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 - datafusion/sql/Cargo.toml | 1 + datafusion/sql/tests/sql_integration.rs | 15 +- 36 files changed, 484 insertions(+), 318 deletions(-) create mode 100644 datafusion/functions/src/unicode/character_length.rs create mode 100644 datafusion/functions/src/unicode/mod.rs create mode 100644 datafusion/functions/src/utils.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 2f1d95d639d4..424dda7fdc61 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1273,6 +1273,7 @@ dependencies = [ "md-5", "regex", "sha2", + "unicode-segmentation", "uuid", ] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 1e5c0d748e3d..de03579975a2 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -70,6 +70,7 @@ unicode_expressions = [ "datafusion-physical-expr/unicode_expressions", "datafusion-optimizer/unicode_expressions", "datafusion-sql/unicode_expressions", + "datafusion-functions/unicode_expressions", ] [dependencies] diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 6ebd64c9b628..4371cce856ce 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -37,6 +37,7 @@ use datafusion::assert_batches_eq; use datafusion_common::DFSchema; use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast, ExprSchemable}; +use datafusion_functions::unicode::expr_fn::character_length; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index bb0f79f8eca4..eefbc131a27b 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -103,8 +103,6 @@ pub enum BuiltinScalarFunction { Cot, // string functions - /// character_length - CharacterLength, /// concat Concat, /// concat_ws @@ -218,7 +216,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::CharacterLength => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, @@ -257,9 +254,6 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { - BuiltinScalarFunction::CharacterLength => { - utf8_to_int_type(&input_expr_types[0], "character_length") - } BuiltinScalarFunction::Coalesce => { // COALESCE has multiple args and they might get coerced, get a preview of this let coerced_types = data_types(input_expr_types, &self.signature()); @@ -367,9 +361,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::InitCap - | BuiltinScalarFunction::Reverse => { + BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { @@ -584,10 +576,6 @@ impl BuiltinScalarFunction { // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], - // string functions - BuiltinScalarFunction::CharacterLength => { - &["character_length", "char_length", "length"] - } BuiltinScalarFunction::Concat => &["concat"], BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0ea946288e0f..654464798625 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -577,13 +577,6 @@ scalar_expr!(Power, power, base exponent, "`base` raised to the power of `expone scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); -// string functions -scalar_expr!( - CharacterLength, - character_length, - string, - "the number of characters in the `string`" -); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Reverse, reverse, string, "reverses the `string`"); @@ -1032,7 +1025,6 @@ mod test { test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); - test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 81050dfddf66..0cab0276ff4b 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -43,6 +43,7 @@ default = [ "regex_expressions", "crypto_expressions", "string_expressions", + "unicode_expressions", ] # enable encode/decode functions encoding_expressions = ["base64", "hex"] @@ -52,6 +53,8 @@ math_expressions = [] regex_expressions = ["regex"] # enable string functions string_expressions = [] +# enable unicode functions +unicode_expressions = ["unicode-segmentation"] [lib] name = "datafusion_functions" @@ -75,6 +78,7 @@ log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } +unicode-segmentation = { version = "^1.7.1", optional = true } uuid = { version = "1.7", features = ["v4"] } [dev-dependencies] diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index f469b343e144..2a00839dc532 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -124,6 +124,12 @@ make_stub_package!(regex, "regex_expressions"); pub mod crypto; make_stub_package!(crypto, "crypto_expressions"); +#[cfg(feature = "unicode_expressions")] +pub mod unicode; +make_stub_package!(unicode, "unicode_expressions"); + +mod utils; + /// Fluent-style API for creating `Expr`s pub mod expr_fn { #[cfg(feature = "core_expressions")] @@ -140,6 +146,8 @@ pub mod expr_fn { pub use super::regex::expr_fn::*; #[cfg(feature = "string_expressions")] pub use super::string::expr_fn::*; + #[cfg(feature = "unicode_expressions")] + pub use super::unicode::expr_fn::*; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -151,6 +159,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { .chain(math::functions()) .chain(regex::functions()) .chain(crypto::functions()) + .chain(unicode::functions()) .chain(string::functions()); all_functions.try_for_each(|udf| { diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 5bd77833a935..9a07f4c19cf1 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use crate::utils::make_scalar_function; use arrow::array::Int32Array; use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 9f612751584e..6a200471d42d 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::compute::kernels::length::bit_length; use std::any::Any; +use arrow::compute::kernels::length::bit_length; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::utf8_to_int_type; #[derive(Debug)] pub(super) struct BitLengthFunc { diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index de1c9cc69b72..573a23d07021 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index df3b803ba659..d1f8dc398a2b 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -29,7 +29,7 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::make_scalar_function; /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 339f4e6c1a23..276aad121df2 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -24,8 +24,7 @@ use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; -use datafusion_physical_expr::functions::Hint; +use datafusion_expr::ColumnarValue; pub(crate) enum TrimType { Left, @@ -98,52 +97,6 @@ pub(crate) fn general_trim( } } -/// Creates a function to identify the optimal return type of a string function given -/// the type of its first argument. -/// -/// If the input type is `LargeUtf8` or `LargeBinary` the return type is -/// `$largeUtf8Type`, -/// -/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, -macro_rules! get_optimal_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - // LargeBinary inputs are automatically coerced to Utf8 - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - // Binary inputs are automatically coerced to Utf8 - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - DataType::Dictionary(_, value_type) => match **value_type { - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - _ => { - return datafusion_common::exec_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - **value_type - ); - } - }, - data_type => { - return datafusion_common::exec_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - data_type - ); - } - }) - } - }; -} - -// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. -get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); - -// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. -get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - /// applies a unary expression to `args[0]` that is expected to be downcastable to /// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) /// # Errors @@ -221,112 +174,3 @@ where }, } } - -pub(super) fn make_scalar_function( - inner: F, - hints: Vec, -) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, -{ - Arc::new(move |args: &[ColumnarValue]| { - // first, identify if any of the arguments is an Array. If yes, store its `len`, - // as any scalar will need to be converted to an array of len `len`. - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let is_scalar = len.is_none(); - - let inferred_length = len.unwrap_or(1); - let args = args - .iter() - .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) - .map(|(arg, hint)| { - // Decide on the length to expand this scalar to depending - // on the given hints. - let expansion_len = match hint { - Hint::AcceptsSingular => 1, - Hint::Pad => inferred_length, - }; - arg.clone().into_array(expansion_len) - }) - .collect::>>()?; - - let result = (inner)(&args); - if is_scalar { - // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) - } else { - result.map(ColumnarValue::Array) - } - }) -} - -#[cfg(test)] -pub mod test { - /// $FUNC ScalarUDFImpl to test - /// $ARGS arguments (vec) to pass to function - /// $EXPECTED a Result - /// $EXPECTED_TYPE is the expected value type - /// $EXPECTED_DATA_TYPE is the expected result type - /// $ARRAY_TYPE is the column type after function applied - macro_rules! test_function { - ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { - let expected: Result> = $EXPECTED; - let func = $FUNC; - - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); - let return_type = func.return_type(&type_array); - - match expected { - Ok(expected) => { - assert_eq!(return_type.is_ok(), true); - assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); - - let result = func.invoke($ARGS); - assert_eq!(result.is_ok(), true); - - let len = $ARGS - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - let inferred_length = len.unwrap_or(1); - let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); - let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); - - // value is correct - match expected { - Some(v) => assert_eq!(result.value(0), v), - None => assert!(result.is_null(0)), - }; - } - Err(expected_error) => { - if return_type.is_err() { - match return_type { - Ok(_) => assert!(false, "expected error"), - Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } - } - } - else { - // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke($ARGS) { - Ok(_) => assert!(false, "expected error"), - Err(error) => { - assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); - } - } - } - } - }; - }; - } - - pub(crate) use test_function; -} diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index b5de4b28948f..8f497e73e393 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; use arrow::datatypes::DataType; +use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::utils::datafusion_strsim; use datafusion_common::{exec_err, Result}; @@ -28,8 +29,6 @@ use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::string::common::{make_scalar_function, utf8_to_int_type}; - #[derive(Debug)] pub(super) struct LevenshteinFunc { signature: Signature, diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 42bda0470067..327772bd808d 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::{handle, utf8_to_str_type}; +use std::any::Any; + use arrow::datatypes::DataType; + use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; + +use crate::string::common::handle; +use crate::utils::utf8_to_str_type; #[derive(Debug)] pub(super) struct LowerFunc { diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 535ffb14f5f5..e6926e5bd56e 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, OffsetSizeTrait}; use std::any::Any; +use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 36a62fbe4e38..639bf6cb48a9 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::compute::kernels::length::length; use std::any::Any; +use arrow::compute::kernels::length::length; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::utf8_to_int_type; #[derive(Debug)] pub(super) struct OctetLengthFunc { @@ -86,14 +86,17 @@ impl ScalarUDFImpl for OctetLengthFunc { #[cfg(test)] mod tests { - use crate::string::common::test::test_function; - use crate::string::octet_length::OctetLengthFunc; + use std::sync::Arc; + use arrow::array::{Array, Int32Array, StringArray}; use arrow::datatypes::DataType::Int32; + use datafusion_common::ScalarValue; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use std::sync::Arc; + + use crate::string::octet_length::OctetLengthFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index d7cc0da8068e..8b9cc03afc4d 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct OverlayFunc { diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 83bc929cb9a4..f4319af0a5c4 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct RepeatFunc { @@ -99,8 +99,8 @@ mod tests { use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::string::common::test::test_function; use crate::string::repeat::RepeatFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index e35244296090..e869ac205440 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct ReplaceFunc { diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 17d2f8234b34..d04d15ce8847 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. /// rtrim('testxxzx', 'xyz') = 'test' diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index af201e90fcf6..0aa968a1ef5b 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct SplitPartFunc { @@ -117,8 +117,8 @@ mod tests { use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::string::common::test::test_function; use crate::string::split_part::SplitPartFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 4450b9d332a0..f1b03907f8d8 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; + use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; + +use crate::utils::make_scalar_function; /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 1bdece3f7af8..ab320c68d493 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, }; + use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; + +use crate::utils::make_scalar_function; /// Converts the number to its equivalent hexadecimal representation. /// to_hex(2147483647) = '7fffffff' diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index a0c910ebb2c8..066174abf277 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::{handle, utf8_to_str_type}; +use crate::string::common::handle; +use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::ColumnarValue; diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs new file mode 100644 index 000000000000..51331bf9a586 --- /dev/null +++ b/datafusion/functions/src/unicode/character_length.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::{make_scalar_function, utf8_to_int_type}; +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug)] +pub(super) struct CharacterLengthFunc { + signature: Signature, + aliases: Vec, +} + +impl CharacterLengthFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + aliases: vec![String::from("length"), String::from("char_length")], + } + } +} + +impl ScalarUDFImpl for CharacterLengthFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "character_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "character_length") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(character_length::, vec![])(args) + } + DataType::LargeUtf8 => { + make_scalar_function(character_length::, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function character_length") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns number of characters in the string. +/// character_length('josé') = 4 +/// The implementation counts UTF-8 code points to count the number of characters +fn character_length(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + T::Native::from_usize(string.chars().count()) + .expect("should not fail as string.chars will always return integer") + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::unicode::character_length::CharacterLengthFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("chars") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("josé") + )))], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("") + )))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé"))))], + internal_err!( + "function character_length requires compilation with feature flag: unicode_expressions." + ), + i32, + Int32, + Int32Array + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs new file mode 100644 index 000000000000..291de3843903 --- /dev/null +++ b/datafusion/functions/src/unicode/mod.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! "unicode" DataFusion functions + +use std::sync::Arc; + +use datafusion_expr::ScalarUDF; + +mod character_length; + +// create UDFs +make_udf_function!( + character_length::CharacterLengthFunc, + CHARACTER_LENGTH, + character_length +); + +pub mod expr_fn { + use datafusion_expr::Expr; + + #[doc = "the number of characters in the `string`"] + pub fn char_length(string: Expr) -> Expr { + character_length(string) + } + + #[doc = "the number of characters in the `string`"] + pub fn character_length(string: Expr) -> Expr { + super::character_length().call(vec![string]) + } + + #[doc = "the number of characters in the `string`"] + pub fn length(string: Expr) -> Expr { + character_length(string) + } +} + +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![character_length()] +} diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs new file mode 100644 index 000000000000..f45deafdb37a --- /dev/null +++ b/datafusion/functions/src/utils.rs @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_physical_expr::functions::Hint; +use std::sync::Arc; + +/// Creates a function to identify the optimal return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! get_optimal_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + DataType::Dictionary(_, value_type) => match **value_type { + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + _ => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + **value_type + ); + } + }, + data_type => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + data_type + ); + } + }) + } + }; +} + +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. +get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. +get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); + +pub(super) fn make_scalar_function( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) + .map(|(arg, hint)| { + // Decide on the length to expand this scalar to depending + // on the given hints. + let expansion_len = match hint { + Hint::AcceptsSingular => 1, + Hint::Pad => inferred_length, + }; + arg.clone().into_array(expansion_len) + }) + .collect::>>()?; + + let result = (inner)(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + }) +} + +#[cfg(test)] +pub mod test { + /// $FUNC ScalarUDFImpl to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result + /// $EXPECTED_TYPE is the expected value type + /// $EXPECTED_DATA_TYPE is the expected result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! test_function { + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + let expected: Result> = $EXPECTED; + let func = $FUNC; + + let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let return_type = func.return_type(&type_array); + + match expected { + Ok(expected) => { + assert_eq!(return_type.is_ok(), true); + assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); + + let result = func.invoke($ARGS); + assert_eq!(result.is_ok(), true); + + let len = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let inferred_length = len.unwrap_or(1); + let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + if return_type.is_err() { + match return_type { + Ok(_) => assert!(false, "expected error"), + Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } + } + } + else { + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke($ARGS) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); + } + } + } + } + }; + }; + } + + pub(crate) use test_function; +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index cd9bba63d624..9adc8536341d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -254,29 +254,6 @@ pub fn create_physical_fun( Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } // string functions - BuiltinScalarFunction::CharacterLength => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int32Type, - "character_length" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int64Type, - "character_length" - ); - make_scalar_function_inner(func)(args) - } - other => exec_err!( - "Unsupported data type {other:?} for function character_length" - ), - }) - } BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| { @@ -595,53 +572,6 @@ mod tests { #[test] fn test_functions() -> Result<()> { - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("chars")], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("josé")], - Ok(Some(4)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("")], - Ok(Some(0)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - CharacterLength, - &[lit("josé")], - internal_err!( - "function character_length requires compilation with feature flag: unicode_expressions." - ), - i32, - Int32, - Int32Array - ); test_function!( Concat, &[lit("aa"), lit("bb"), lit("cc"),], diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 8ec9e062d9b7..c7e4b7d7c443 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -36,29 +36,6 @@ use datafusion_common::{ exec_err, Result, }; -/// Returns number of characters in the string. -/// character_length('josé') = 4 -/// The implementation counts UTF-8 code points to count the number of characters -pub fn character_length(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - T::Native::from_usize(string.chars().count()) - .expect("should not fail as string.chars will always return integer") - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. /// left('abcde', 2) = 'ab' /// The implementation uses UTF-8 code points as characters diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f405ecf976be..766ca6633ee1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -565,7 +565,7 @@ enum ScalarFunction { // RegexpMatch = 21; // 22 was BitLength // 23 was Btrim - CharacterLength = 24; + // 24 was CharacterLength // 25 was Chr Concat = 26; ConcatWithSeparator = 27; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0d22ba5db773..f2814956ef1b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22928,7 +22928,6 @@ impl serde::Serialize for ScalarFunction { Self::Sin => "Sin", Self::Sqrt => "Sqrt", Self::Trunc => "Trunc", - Self::CharacterLength => "CharacterLength", Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", @@ -22988,7 +22987,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sin", "Sqrt", "Trunc", - "CharacterLength", "Concat", "ConcatWithSeparator", "InitCap", @@ -23077,7 +23075,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sin" => Ok(ScalarFunction::Sin), "Sqrt" => Ok(ScalarFunction::Sqrt), "Trunc" => Ok(ScalarFunction::Trunc), - "CharacterLength" => Ok(ScalarFunction::CharacterLength), "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 07c3fad15373..ecc94fcdaf99 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2864,7 +2864,7 @@ pub enum ScalarFunction { /// RegexpMatch = 21; /// 22 was BitLength /// 23 was Btrim - CharacterLength = 24, + /// 24 was CharacterLength /// 25 was Chr Concat = 26, ConcatWithSeparator = 27, @@ -3001,7 +3001,6 @@ impl ScalarFunction { ScalarFunction::Sin => "Sin", ScalarFunction::Sqrt => "Sqrt", ScalarFunction::Trunc => "Trunc", - ScalarFunction::CharacterLength => "CharacterLength", ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", @@ -3055,7 +3054,6 @@ impl ScalarFunction { "Sin" => Some(Self::Sin), "Sqrt" => Some(Self::Sqrt), "Trunc" => Some(Self::Trunc), - "CharacterLength" => Some(Self::CharacterLength), "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4b9874bf8f65..19edd71a3a80 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,8 +48,8 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, asinh, atan, atan2, atanh, cbrt, ceil, character_length, coalesce, - concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, + acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, + cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -450,7 +450,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Concat => Self::Concat, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, - ScalarFunction::CharacterLength => Self::CharacterLength, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, @@ -1372,9 +1371,6 @@ pub fn parse_expr( ScalarFunction::Signum => { Ok(signum(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::CharacterLength => { - Ok(character_length(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1335d511a0ea..11fc7362c75d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1442,7 +1442,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, - BuiltinScalarFunction::CharacterLength => Self::CharacterLength, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index ca2c1a240c21..b9f6dc259eb7 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -49,6 +49,7 @@ strum = { version = "0.26.1", features = ["derive"] } [dev-dependencies] ctor = { workspace = true } +datafusion-functions = { workspace = true, default-features = true } env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 448a9c54202e..101c31039c7e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -38,6 +38,7 @@ use datafusion_sql::{ planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, }; +use datafusion_functions::unicode; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -88,7 +89,7 @@ fn parse_decimals() { fn parse_ident_normalization() { let test_data = [ ( - "SELECT LENGTH('str')", + "SELECT CHARACTER_LENGTH('str')", "Ok(Projection: character_length(Utf8(\"str\"))\n EmptyRelation)", false, ), @@ -2688,6 +2689,7 @@ fn logical_plan_with_dialect_and_options( options: ParserOptions, ) -> Result { let context = MockContextProvider::default() + .with_udf(unicode::character_length().as_ref().clone()) .with_udf(make_udf( "nullif", vec![DataType::Int32, DataType::Int32], @@ -4508,26 +4510,27 @@ fn test_field_not_found_window_function() { #[test] fn test_parse_escaped_string_literal_value() { - let sql = r"SELECT length('\r\n') AS len"; + let sql = r"SELECT character_length('\r\n') AS len"; let expected = "Projection: character_length(Utf8(\"\\r\\n\")) AS len\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\r\n') AS len"; + let sql = r"SELECT character_length(E'\r\n') AS len"; let expected = "Projection: character_length(Utf8(\"\r\n\")) AS len\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; + let sql = + r"SELECT character_length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; let expected = "Projection: character_length(Utf8(\"%\")) AS len, Utf8(\"\u{004b}\") AS hex, Utf8(\"\u{0001}\") AS unicode\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\000') AS len"; + let sql = r"SELECT character_length(E'\000') AS len"; assert_eq!( logical_plan(sql).unwrap_err().strip_backtrace(), - "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 15\")" + "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 25\")" ) } From 47eac75b5ca4973c18846c0dd0bc38feac4eb1f0 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 10:43:15 -0400 Subject: [PATCH 5/9] move Left, Lpad, Reverse, Right, Rpad functions to datafusion_functions --- datafusion/expr/src/built_in_function.rs | 50 +- datafusion/expr/src/expr_fn.rs | 21 - datafusion/functions/src/unicode/left.rs | 245 +++++++ datafusion/functions/src/unicode/lpad.rs | 383 +++++++++++ datafusion/functions/src/unicode/mod.rs | 44 +- datafusion/functions/src/unicode/reverse.rs | 159 +++++ datafusion/functions/src/unicode/right.rs | 247 +++++++ datafusion/functions/src/unicode/rpad.rs | 375 +++++++++++ datafusion/physical-expr/src/functions.rs | 606 ------------------ datafusion/physical-expr/src/planner.rs | 4 +- .../physical-expr/src/unicode_expressions.rs | 263 +------- datafusion/proto/proto/datafusion.proto | 10 +- datafusion/proto/src/generated/pbjson.rs | 15 - datafusion/proto/src/generated/prost.rs | 20 +- .../proto/src/logical_plan/from_proto.rs | 53 +- datafusion/proto/src/logical_plan/to_proto.rs | 5 - 16 files changed, 1484 insertions(+), 1016 deletions(-) create mode 100644 datafusion/functions/src/unicode/left.rs create mode 100644 datafusion/functions/src/unicode/lpad.rs create mode 100644 datafusion/functions/src/unicode/reverse.rs create mode 100644 datafusion/functions/src/unicode/right.rs create mode 100644 datafusion/functions/src/unicode/rpad.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index eefbc131a27b..196d278dc70e 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -111,18 +111,8 @@ pub enum BuiltinScalarFunction { EndsWith, /// initcap InitCap, - /// left - Left, - /// lpad - Lpad, /// random Random, - /// reverse - Reverse, - /// right - Right, - /// rpad - Rpad, /// strpos Strpos, /// substr @@ -220,12 +210,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, - BuiltinScalarFunction::Left => Volatility::Immutable, - BuiltinScalarFunction::Lpad => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Reverse => Volatility::Immutable, - BuiltinScalarFunction::Right => Volatility::Immutable, - BuiltinScalarFunction::Rpad => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, @@ -264,17 +249,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } - BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), - BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), - BuiltinScalarFunction::Reverse => { - utf8_to_str_type(&input_expr_types[0], "reverse") - } - BuiltinScalarFunction::Right => { - utf8_to_str_type(&input_expr_types[0], "right") - } - BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") @@ -361,28 +337,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Reverse => { + BuiltinScalarFunction::InitCap => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), - Exact(vec![Utf8, Int64, LargeUtf8]), - Exact(vec![LargeUtf8, Int64, LargeUtf8]), - ], - self.volatility(), - ) - } - BuiltinScalarFunction::Left | BuiltinScalarFunction::Right => { - Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], - self.volatility(), - ) - } BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { Signature::one_of( @@ -580,11 +537,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Reverse => &["reverse"], - BuiltinScalarFunction::Right => &["right"], - BuiltinScalarFunction::Rpad => &["rpad"], BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::Translate => &["translate"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 654464798625..21dab72855e5 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -578,25 +578,11 @@ scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argu scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); -scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); -scalar_expr!(Reverse, reverse, string, "reverses the `string`"); -scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); -//use vec as parameter -nary_scalar_expr!( - Lpad, - lpad, - "fill up a string to the length by prepending the characters" -); -nary_scalar_expr!( - Rpad, - rpad, - "fill up a string to the length by appending the characters" -); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c nary_scalar_expr!( @@ -1028,13 +1014,6 @@ mod test { test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); - test_scalar_expr!(Left, left, string, count); - test_nary_scalar_expr!(Lpad, lpad, string, count); - test_nary_scalar_expr!(Lpad, lpad, string, count, characters); - test_scalar_expr!(Reverse, reverse, string); - test_scalar_expr!(Right, right, string, count); - test_nary_scalar_expr!(Rpad, rpad, string, count); - test_nary_scalar_expr!(Rpad, rpad, string, count, characters); test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs new file mode 100644 index 000000000000..76da56abc19e --- /dev/null +++ b/datafusion/functions/src/unicode/left.rs @@ -0,0 +1,245 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct LeftFunc { + signature: Signature, +} + +impl LeftFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LeftFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "left" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "left") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(left::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(left::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function left"), + } + } +} + +/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. +/// left('abcde', 2) = 'ab' +/// The implementation uses UTF-8 code points as characters +pub fn left(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let n_array = as_int64_array(&args[1])?; + let result = string_array + .iter() + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => match n.cmp(&0) { + Ordering::Less => { + let len = string.chars().count() as i64; + Some(if n.abs() < len { + string.chars().take((len + n) as usize).collect::() + } else { + "".to_string() + }) + } + Ordering::Equal => Some("".to_string()), + Ordering::Greater => { + Some(string.chars().take(n as usize).collect::()) + } + }, + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::left::LeftFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("ab")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(200))), + ], + Ok(Some("abcde")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))), + ], + Ok(Some("abc")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-200))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + LeftFunc::new90, + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + internal_err!( + "function left requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs new file mode 100644 index 000000000000..a0968b36920f --- /dev/null +++ b/datafusion/functions/src/unicode/lpad.rs @@ -0,0 +1,383 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub(super) struct LPadFunc { + signature: Signature, +} + +impl LPadFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LPadFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "lpad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "lpad") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(lpad::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(lpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function lpad"), + } + } +} + +/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). +/// lpad('hi', 5, 'xy') = 'xyxhi' +pub fn lpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "lpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else { + let mut s: String = " ".repeat(length - graphemes.len()); + s.push_str(string); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + let fill_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "lpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector.push( + *fill_chars.get(l % fill_chars.len()).unwrap(), + ); + } + s.insert_str( + 0, + char_vector.iter().collect::().as_str(), + ); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => exec_err!( + "lpad was called with {other} arguments. It requires at least 2 and at most 3." + ), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::lpad::LPadFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some(" josé")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(Some("xyxhi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(21))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcdef")))), + ], + Ok(Some("abcdefabcdefabcdefahi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(Some("xyxyxyjosé")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("éñ")))), + ], + Ok(Some("éñéñéñjosé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + internal_err!( + "function lpad requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 291de3843903..ea4e70a92199 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -22,6 +22,11 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; mod character_length; +mod left; +mod lpad; +mod reverse; +mod right; +mod rpad; // create UDFs make_udf_function!( @@ -29,6 +34,11 @@ make_udf_function!( CHARACTER_LENGTH, character_length ); +make_udf_function!(left::LeftFunc, LEFT, left); +make_udf_function!(lpad::LPadFunc, LPAD, lpad); +make_udf_function!(right::RightFunc, RIGHT, right); +make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); +make_udf_function!(rpad::RPadFunc, RPAD, rpad); pub mod expr_fn { use datafusion_expr::Expr; @@ -47,9 +57,41 @@ pub mod expr_fn { pub fn length(string: Expr) -> Expr { character_length(string) } + + #[doc = "returns the first `n` characters in the `string`"] + pub fn left(string: Expr, n: Expr) -> Expr { + super::left().call(vec![string, n]) + } + + #[doc = "fill up a string to the length by prepending the characters"] + pub fn lpad(args: Vec) -> Expr { + super::lpad().call(args) + } + + #[doc = "reverses the `string`"] + pub fn reverse(string: Expr) -> Expr { + super::reverse().call(vec![string]) + } + + #[doc = "returns the last `n` characters in the `string`"] + pub fn right(string: Expr, n: Expr) -> Expr { + super::right().call(vec![string, n]) + } + + #[doc = "fill up a string to the length by appending the characters"] + pub fn rpad(args: Vec) -> Expr { + super::rpad().call(args) + } } /// Return a list of all functions in this package pub fn functions() -> Vec> { - vec![character_length()] + vec![ + character_length(), + left(), + lpad(), + reverse(), + right(), + rpad(), + ] } diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs new file mode 100644 index 000000000000..e1996fcb39c4 --- /dev/null +++ b/datafusion/functions/src/unicode/reverse.rs @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct ReverseFunc { + signature: Signature, +} + +impl ReverseFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ReverseFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "reverse" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "reverse") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(reverse::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(reverse::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function reverse") + } + } + } +} + +/// Reverses the order of the characters in the string. +/// reverse('abcde') = 'edcba' +/// The implementation uses UTF-8 code points as characters +pub fn reverse(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| string.map(|string: &str| string.chars().rev().collect::())) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::reverse::ReverseFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("abcde") + )))], + Ok(Some("edcba")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("loẅks") + )))], + Ok(Some("sk̈wol")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("loẅks") + )))], + Ok(Some("sk̈wol")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde"))))], + internal_err!( + "function reverse requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs new file mode 100644 index 000000000000..5eddf7b37bf0 --- /dev/null +++ b/datafusion/functions/src/unicode/right.rs @@ -0,0 +1,247 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::{max, Ordering}; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct RightFunc { + signature: Signature, +} + +impl RightFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RightFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "right" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "right") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(right::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(right::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function right"), + } + } +} + +/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. +/// right('abcde', 2) = 'de' +/// The implementation uses UTF-8 code points as characters +pub fn right(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let n_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => match n.cmp(&0) { + Ordering::Less => Some( + string + .chars() + .skip(n.unsigned_abs() as usize) + .collect::(), + ), + Ordering::Equal => Some("".to_string()), + Ordering::Greater => Some( + string + .chars() + .skip(max(string.chars().count() as i64 - n, 0) as usize) + .collect::(), + ), + }, + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::right::RightFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("de")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(200))), + ], + Ok(Some("abcde")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-2))), + ], + Ok(Some("cde")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-200))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("éésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RightFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("éésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + Right, + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcde")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + internal_err!( + "function right requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs new file mode 100644 index 000000000000..352b2f823008 --- /dev/null +++ b/datafusion/functions/src/unicode/rpad.rs @@ -0,0 +1,375 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub(super) struct RPadFunc { + signature: Signature, +} + +impl RPadFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RPadFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rpad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "rpad") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(rpad::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(rpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function rpad"), + } + } +} + +/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. +/// rpad('hi', 5, 'xy') = 'hixyx' +pub fn rpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else { + let mut s = string.to_string(); + s.push_str(" ".repeat(length - graphemes.len()).as_str()); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + let fill_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "rpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector + .push(*fill_chars.get(l % fill_chars.len()).unwrap()); + } + s.push_str(char_vector.iter().collect::().as_str()); + Ok(Some(s)) + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => exec_err!( + "rpad was called with {other} arguments. It requires at least 2 and at most 3." + ), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::rpad::RPadFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("josé ")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(Some("hixyx")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(21))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("abcdef")))), + ], + Ok(Some("hiabcdefabcdefabcdefa")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(" ")))), + ], + Ok(Some("hi ")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("")))), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("hi")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("xy")))), + ], + Ok(Some("joséxyxyxy")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("éñ")))), + ], + Ok(Some("josééñéñéñ")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + RPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + internal_err!( + "function rpad requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 9adc8536341d..c1b4900e399a 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -270,67 +270,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function initcap") } }), - BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function left"), - }), - BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function lpad"), - }), - BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse"); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function reverse") - } - }), - BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(right, i32, "right"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(right, i64, "right"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function right"), - }), - BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function rpad"), - }), BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::ends_with::)(args) @@ -691,551 +630,6 @@ mod tests { Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int8(Some(2))),], - Ok(Some("ab")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(200))),], - Ok(Some("abcde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-2))),], - Ok(Some("abc")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-200))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("abcde"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("joséé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Left, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("joséé")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Left, - &[ - lit("abcde"), - lit(ScalarValue::Int8(Some(2))), - ], - internal_err!( - "function left requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some(" josé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit("xy"),], - Ok(Some("xyxhi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(21))), lit("abcdef"),], - Ok(Some("abcdefabcdefabcdefahi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(" "),], - Ok(Some(" hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(""),], - Ok(Some("hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - lit("xy"), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("hi"), lit(ScalarValue::Int64(None)), lit("xy"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit("hi"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("xy"),], - Ok(Some("xyxyxyjosé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("éñ"),], - Ok(Some("éñéñéñjosé")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Lpad, - &[ - lit("josé"), - lit(ScalarValue::Int64(Some(5))), - ], - internal_err!( - "function lpad requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("abcde")], - Ok(Some("edcba")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("loẅks")], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("loẅks")], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Reverse, - &[lit("abcde")], - internal_err!( - "function reverse requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int8(Some(2))),], - Ok(Some("de")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(200))),], - Ok(Some("abcde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-2))),], - Ok(Some("cde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-200))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("éésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("éésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Right, - &[ - lit("abcde"), - lit(ScalarValue::Int8(Some(2))), - ], - internal_err!( - "function right requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("josé ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("hi ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit("xy"),], - Ok(Some("hixyx")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(21))), lit("abcdef"),], - Ok(Some("hiabcdefabcdefabcdefa")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(" "),], - Ok(Some("hi ")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(""),], - Ok(Some("hi")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(5))), - lit("xy"), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("hi"), lit(ScalarValue::Int64(None)), lit("xy"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[ - lit("hi"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("xy"),], - Ok(Some("joséxyxyxy")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Rpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("éñ"),], - Ok(Some("josééñéñéñ")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Rpad, - &[ - lit("josé"), - lit(ScalarValue::Int64(Some(5))), - ], - internal_err!( - "function rpad requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); test_function!( EndsWith, &[lit("alphabet"), lit("alph"),], diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 319d9ca2269a..0dbea09ffb51 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -335,11 +335,11 @@ mod tests { use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray}; use arrow_schema::{DataType, Field, Schema}; use datafusion_common::{DFSchema, Result}; - use datafusion_expr::{col, left, Literal}; + use datafusion_expr::{col, lit}; #[test] fn test_create_physical_expr_scalar_input_output() -> Result<()> { - let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit())); + let expr = col("letter").eq(lit("A")); let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]); let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?; diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index c7e4b7d7c443..faff21111a61 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -21,7 +21,7 @@ //! Unicode expressions -use std::cmp::{max, Ordering}; +use std::cmp::max; use std::sync::Arc; use arrow::{ @@ -36,267 +36,6 @@ use datafusion_common::{ exec_err, Result, }; -/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. -/// left('abcde', 2) = 'ab' -/// The implementation uses UTF-8 code points as characters -pub fn left(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let n_array = as_int64_array(&args[1])?; - let result = string_array - .iter() - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => { - let len = string.chars().count() as i64; - Some(if n.abs() < len { - string.chars().take((len + n) as usize).collect::() - } else { - "".to_string() - }) - } - Ordering::Equal => Some("".to_string()), - Ordering::Greater => { - Some(string.chars().take(n as usize).collect::()) - } - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). -/// lpad('hi', 5, 'xy') = 'xyxhi' -pub fn lpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s: String = " ".repeat(length - graphemes.len()); - s.push_str(string); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "lpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector.push( - *fill_chars.get(l % fill_chars.len()).unwrap(), - ); - } - s.insert_str( - 0, - char_vector.iter().collect::().as_str(), - ); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "lpad was called with {other} arguments. It requires at least 2 and at most 3." - ), - } -} - -/// Reverses the order of the characters in the string. -/// reverse('abcde') = 'edcba' -/// The implementation uses UTF-8 code points as characters -pub fn reverse(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| string.map(|string: &str| string.chars().rev().collect::())) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters. -/// right('abcde', 2) = 'de' -/// The implementation uses UTF-8 code points as characters -pub fn right(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let n_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(n_array.iter()) - .map(|(string, n)| match (string, n) { - (Some(string), Some(n)) => match n.cmp(&0) { - Ordering::Less => Some( - string - .chars() - .skip(n.unsigned_abs() as usize) - .collect::(), - ), - Ordering::Equal => Some("".to_string()), - Ordering::Greater => Some( - string - .chars() - .skip(max(string.chars().count() as i64 - n, 0) as usize) - .collect::(), - ), - }, - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated. -/// rpad('hi', 5, 'xy') = 'hixyx' -pub fn rpad(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .map(|(string, length)| match (string, length) { - (Some(string), Some(length)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - if length == 0 { - Ok(Some("".to_string())) - } else { - let graphemes = string.graphemes(true).collect::>(); - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else { - let mut s = string.to_string(); - s.push_str(" ".repeat(length - graphemes.len()).as_str()); - Ok(Some(s)) - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let length_array = as_int64_array(&args[1])?; - let fill_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(length_array.iter()) - .zip(fill_array.iter()) - .map(|((string, length), fill)| match (string, length, fill) { - (Some(string), Some(length), Some(fill)) => { - if length > i32::MAX as i64 { - return exec_err!( - "rpad requested length {length} too large" - ); - } - - let length = if length < 0 { 0 } else { length as usize }; - let graphemes = string.graphemes(true).collect::>(); - let fill_chars = fill.chars().collect::>(); - - if length < graphemes.len() { - Ok(Some(graphemes[..length].concat())) - } else if fill_chars.is_empty() { - Ok(Some(string.to_string())) - } else { - let mut s = string.to_string(); - let mut char_vector = - Vec::::with_capacity(length - graphemes.len()); - for l in 0..length - graphemes.len() { - char_vector - .push(*fill_chars.get(l % fill_chars.len()).unwrap()); - } - s.push_str(char_vector.iter().collect::().as_str()); - Ok(Some(s)) - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => exec_err!( - "rpad was called with {other} arguments. It requires at least 2 and at most 3." - ), - } -} - /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) /// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 766ca6633ee1..6319372d98d2 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -572,8 +572,8 @@ enum ScalarFunction { // 28 was DatePart // 29 was DateTrunc InitCap = 30; - Left = 31; - Lpad = 32; + // 31 was Left + // 32 was Lpad // 33 was Lower // 34 was Ltrim // 35 was MD5 @@ -583,9 +583,9 @@ enum ScalarFunction { // 39 was RegexpReplace // 40 was Repeat // 41 was Replace - Reverse = 42; - Right = 43; - Rpad = 44; + // 42 was Reverse + // 43 was Right + // 44 was Rpad // 45 was Rtrim // 46 was SHA224 // 47 was SHA256 diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f2814956ef1b..7281bc9dc263 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22931,12 +22931,7 @@ impl serde::Serialize for ScalarFunction { Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", - Self::Left => "Left", - Self::Lpad => "Lpad", Self::Random => "Random", - Self::Reverse => "Reverse", - Self::Right => "Right", - Self::Rpad => "Rpad", Self::Strpos => "Strpos", Self::Substr => "Substr", Self::Translate => "Translate", @@ -22990,12 +22985,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Concat", "ConcatWithSeparator", "InitCap", - "Left", - "Lpad", "Random", - "Reverse", - "Right", - "Rpad", "Strpos", "Substr", "Translate", @@ -23078,12 +23068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), - "Left" => Ok(ScalarFunction::Left), - "Lpad" => Ok(ScalarFunction::Lpad), "Random" => Ok(ScalarFunction::Random), - "Reverse" => Ok(ScalarFunction::Reverse), - "Right" => Ok(ScalarFunction::Right), - "Rpad" => Ok(ScalarFunction::Rpad), "Strpos" => Ok(ScalarFunction::Strpos), "Substr" => Ok(ScalarFunction::Substr), "Translate" => Ok(ScalarFunction::Translate), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ecc94fcdaf99..2fe89efb9cea 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2871,8 +2871,8 @@ pub enum ScalarFunction { /// 28 was DatePart /// 29 was DateTrunc InitCap = 30, - Left = 31, - Lpad = 32, + /// 31 was Left + /// 32 was Lpad /// 33 was Lower /// 34 was Ltrim /// 35 was MD5 @@ -2882,9 +2882,9 @@ pub enum ScalarFunction { /// 39 was RegexpReplace /// 40 was Repeat /// 41 was Replace - Reverse = 42, - Right = 43, - Rpad = 44, + /// 42 was Reverse + /// 43 was Right + /// 44 was Rpad /// 45 was Rtrim /// 46 was SHA224 /// 47 was SHA256 @@ -3004,12 +3004,7 @@ impl ScalarFunction { ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", - ScalarFunction::Left => "Left", - ScalarFunction::Lpad => "Lpad", ScalarFunction::Random => "Random", - ScalarFunction::Reverse => "Reverse", - ScalarFunction::Right => "Right", - ScalarFunction::Rpad => "Rpad", ScalarFunction::Strpos => "Strpos", ScalarFunction::Substr => "Substr", ScalarFunction::Translate => "Translate", @@ -3057,12 +3052,7 @@ impl ScalarFunction { "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), - "Left" => Some(Self::Left), - "Lpad" => Some(Self::Lpad), "Random" => Some(Self::Random), - "Reverse" => Some(Self::Reverse), - "Right" => Some(Self::Right), - "Rpad" => Some(Self::Rpad), "Strpos" => Some(Self::Strpos), "Substr" => Some(Self::Substr), "Translate" => Some(Self::Translate), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 19edd71a3a80..2c6f2e479b24 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,18 +17,6 @@ use std::sync::Arc; -use crate::protobuf::{ - self, - plan_type::PlanTypeEnum::{ - AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, - InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, - OptimizedPhysicalPlan, - }, - AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, - OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, -}; - use arrow::{ array::AsArray, buffer::Buffer, @@ -38,6 +26,7 @@ use arrow::{ }, ipc::{reader::read_record_batch, root_as_message}, }; + use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, @@ -51,17 +40,29 @@ use datafusion_expr::{ acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, + factorial, find_in_set, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lpad, nanvl, pi, power, radians, random, reverse, right, round, rpad, signum, sin, - sinh, sqrt, strpos, substr, substr_index, substring, translate, trunc, - AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, - Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, strpos, substr, + substr_index, substring, translate, trunc, AggregateFunction, Between, BinaryExpr, + BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, + GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; +use crate::protobuf::{ + self, + plan_type::PlanTypeEnum::{ + AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, + OptimizedPhysicalPlan, + }, + AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, + OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, +}; + use super::LogicalExtensionCodec; #[derive(Debug)] @@ -453,12 +454,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, - ScalarFunction::Left => Self::Left, - ScalarFunction::Lpad => Self::Lpad, ScalarFunction::Random => Self::Random, - ScalarFunction::Reverse => Self::Reverse, - ScalarFunction::Right => Self::Right, - ScalarFunction::Rpad => Self::Rpad, ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, ScalarFunction::Translate => Self::Translate, @@ -1382,26 +1378,13 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Left => Ok(left( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Random => Ok(random()), - ScalarFunction::Reverse => { - Ok(reverse(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Right => Ok(right( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Concat => { Ok(concat_expr(parse_exprs(args, registry, codec)?)) } ScalarFunction::ConcatWithSeparator => { Ok(concat_ws_expr(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Lpad => Ok(lpad(parse_exprs(args, registry, codec)?)), - ScalarFunction::Rpad => Ok(rpad(parse_exprs(args, registry, codec)?)), ScalarFunction::EndsWith => Ok(ends_with( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 11fc7362c75d..ea682a5a22f8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1445,12 +1445,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, - BuiltinScalarFunction::Left => Self::Left, - BuiltinScalarFunction::Lpad => Self::Lpad, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Reverse => Self::Reverse, - BuiltinScalarFunction::Right => Self::Right, - BuiltinScalarFunction::Rpad => Self::Rpad, BuiltinScalarFunction::Strpos => Self::Strpos, BuiltinScalarFunction::Substr => Self::Substr, BuiltinScalarFunction::Translate => Self::Translate, From d3fac7bb49c13338019f1cc6ba5c9a77c3244372 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 12:56:07 -0400 Subject: [PATCH 6/9] move strpos, substr functions to datafusion_functions --- datafusion/expr/src/built_in_function.rs | 36 +- datafusion/expr/src/expr_fn.rs | 6 - datafusion/functions/src/unicode/mod.rs | 31 ++ datafusion/functions/src/unicode/strpos.rs | 121 ++++++ datafusion/functions/src/unicode/substr.rs | 411 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 258 +---------- .../physical-expr/src/unicode_expressions.rs | 95 ---- datafusion/proto/Cargo.toml | 1 + datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 29 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 - .../tests/cases/roundtrip_logical_plan.rs | 29 +- datafusion/proto/tests/cases/serialize.rs | 5 +- datafusion/sql/src/expr/mod.rs | 9 +- datafusion/sql/src/expr/substring.rs | 16 +- datafusion/sqllogictest/test_files/scalar.slt | 2 +- 18 files changed, 617 insertions(+), 452 deletions(-) create mode 100644 datafusion/functions/src/unicode/strpos.rs create mode 100644 datafusion/functions/src/unicode/substr.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 196d278dc70e..423fc11c1d8c 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -113,10 +113,6 @@ pub enum BuiltinScalarFunction { InitCap, /// random Random, - /// strpos - Strpos, - /// substr - Substr, /// translate Translate, /// substr_index @@ -211,8 +207,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Strpos => Volatility::Immutable, - BuiltinScalarFunction::Substr => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, BuiltinScalarFunction::FindInSet => Volatility::Immutable, @@ -252,12 +246,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::EndsWith => Ok(Boolean), - BuiltinScalarFunction::Strpos => { - utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") - } - BuiltinScalarFunction::Substr => { - utf8_to_str_type(&input_expr_types[0], "substr") - } BuiltinScalarFunction::SubstrIndex => { utf8_to_str_type(&input_expr_types[0], "substr_index") } @@ -341,24 +329,12 @@ impl BuiltinScalarFunction { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - self.volatility(), - ) - } - - BuiltinScalarFunction::Substr => Signature::one_of( + BuiltinScalarFunction::EndsWith => Signature::one_of( vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), ], self.volatility(), ), @@ -537,8 +513,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], - BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::Translate => &["translate"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], BuiltinScalarFunction::FindInSet => &["find_in_set"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 21dab72855e5..09170ae639ff 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -579,9 +579,6 @@ scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); -scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); -scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); -scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c @@ -1015,9 +1012,6 @@ mod test { test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); - test_scalar_expr!(Strpos, strpos, string, substring); - test_scalar_expr!(Substr, substr, string, position); - test_scalar_expr!(Substr, substring, string, position, count); test_scalar_expr!(Translate, translate, string, from, to); test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); test_scalar_expr!(FindInSet, find_in_set, string, stringlist); diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index ea4e70a92199..ddab0d1e27c9 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -27,6 +27,8 @@ mod lpad; mod reverse; mod right; mod rpad; +mod strpos; +mod substr; // create UDFs make_udf_function!( @@ -39,6 +41,8 @@ make_udf_function!(lpad::LPadFunc, LPAD, lpad); make_udf_function!(right::RightFunc, RIGHT, right); make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); make_udf_function!(rpad::RPadFunc, RPAD, rpad); +make_udf_function!(strpos::StrposFunc, STRPOS, strpos); +make_udf_function!(substr::SubstrFunc, SUBSTR, substr); pub mod expr_fn { use datafusion_expr::Expr; @@ -53,6 +57,11 @@ pub mod expr_fn { super::character_length().call(vec![string]) } + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn instr(string: Expr, substring: Expr) -> Expr { + strpos(string, substring) + } + #[doc = "the number of characters in the `string`"] pub fn length(string: Expr) -> Expr { character_length(string) @@ -68,6 +77,11 @@ pub mod expr_fn { super::lpad().call(args) } + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn position(string: Expr, substring: Expr) -> Expr { + strpos(string, substring) + } + #[doc = "reverses the `string`"] pub fn reverse(string: Expr) -> Expr { super::reverse().call(vec![string]) @@ -82,6 +96,21 @@ pub mod expr_fn { pub fn rpad(args: Vec) -> Expr { super::rpad().call(args) } + + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn strpos(string: Expr, substring: Expr) -> Expr { + super::strpos().call(vec![string, substring]) + } + + #[doc = "substring from the `position` to the end"] + pub fn substr(string: Expr, position: Expr) -> Expr { + super::substr().call(vec![string, position]) + } + + #[doc = "substring from the `position` with `length` characters"] + pub fn substring(string: Expr, position: Expr, length: Expr) -> Expr { + super::substr().call(vec![string, position, length]) + } } /// Return a list of all functions in this package @@ -93,5 +122,7 @@ pub fn functions() -> Vec> { reverse(), right(), rpad(), + strpos(), + substr(), ] } diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs new file mode 100644 index 000000000000..1e8bfa37d40e --- /dev/null +++ b/datafusion/functions/src/unicode/strpos.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_int_type}; + +#[derive(Debug)] +pub(super) struct StrposFunc { + signature: Signature, + aliases: Vec, +} + +impl StrposFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("instr"), String::from("position")], + } + } +} + +impl ScalarUDFImpl for StrposFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "strpos" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "strpos/instr/position") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(strpos::, vec![])(args), + DataType::LargeUtf8 => { + make_scalar_function(strpos::, vec![])(args) + } + other => exec_err!("Unsupported data type {other:?} for function strpos"), + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) +/// strpos('high', 'ig') = 2 +/// The implementation uses UTF-8 code points as characters +fn strpos(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + + let substring_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = string_array + .iter() + .zip(substring_array.iter()) + .map(|(string, substring)| match (string, substring) { + (Some(string), Some(substring)) => { + // the find method returns the byte index of the substring + // Next, we count the number of the chars until that byte + T::Native::from_usize( + string + .find(substring) + .map(|x| string[..x].chars().count() + 1) + .unwrap_or(0), + ) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs new file mode 100644 index 000000000000..7afe8204768a --- /dev/null +++ b/datafusion/functions/src/unicode/substr.rs @@ -0,0 +1,411 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::max; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct SubstrFunc { + signature: Signature, +} + +impl SubstrFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, Int64, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SubstrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "substr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "substr") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(substr::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(substr::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function substr"), + } + } +} + +/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) +/// substr('alphabet', 3) = 'phabet' +/// substr('alphabet', 3, 2) = 'ph' +/// The implementation uses UTF-8 code points as characters +pub fn substr(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let start_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(start_array.iter()) + .map(|(string, start)| match (string, start) { + (Some(string), Some(start)) => { + if start <= 0 { + Some(string.to_string()) + } else { + Some(string.chars().skip(start as usize - 1).collect()) + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let start_array = as_int64_array(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(start_array.iter()) + .zip(count_array.iter()) + .map(|((string, start), count)| match (string, start, count) { + (Some(string), Some(start), Some(count)) => { + if count < 0 { + exec_err!( + "negative substring length not allowed: substr(, {start}, {count})" + ) + } else { + let skip = max(0, start - 1); + let count = max(0, count + (if start < 1 {start - 1} else {0})); + Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!("substr was called with {other} arguments. It requires 2 or 3.") + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{exec_err, Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::substr::SubstrFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ], + Ok(Some("joséésoj")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("lphabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-3))), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(30))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from 5 (10 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(10))), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from -1 (4 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + // starting from 0 (5 + -5) + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-5))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ], + exec_err!("negative substring length not allowed: substr(, 1, -1)"), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("joséésoj")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(5))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("és")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ], + internal_err!( + "function substr requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c1b4900e399a..513dd71d4074 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -281,34 +281,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function ends_with") } }), - BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - strpos, Int32Type, "strpos" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - strpos, Int64Type, "strpos" - ); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function strpos"), - }), - BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr"); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = - invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr"); - make_scalar_function_inner(func)(args) - } - other => exec_err!("Unsupported data type {other:?} for function substr"), - }), BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -450,7 +422,7 @@ mod tests { }; use datafusion_common::cast::as_uint64_array; - use datafusion_common::{exec_err, internal_err, plan_err}; + use datafusion_common::{internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::Signature; @@ -663,234 +635,6 @@ mod tests { BooleanArray ); #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("ésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-5))),], - Ok(Some("joséésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(1))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(2))),], - Ok(Some("lphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(3))),], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(30))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("ph")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(0))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some("alph")), - &str, - Utf8, - StringArray - ); - // starting from 5 (10 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(10))), - ], - Ok(Some("alph")), - &str, - Utf8, - StringArray - ); - // starting from -1 (4 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(4))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - // starting from 0 (5 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(None)), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(1))), - lit(ScalarValue::Int64(Some(-1))), - ], - exec_err!("negative substring length not allowed: substr(, 1, -1)"), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("joséésoj"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("és")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(0))), - ], - internal_err!( - "function substr requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] test_function!( Translate, &[lit("12345"), lit("143"), lit("ax"),], diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index faff21111a61..ecbd1ea320d4 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -21,7 +21,6 @@ //! Unicode expressions -use std::cmp::max; use std::sync::Arc; use arrow::{ @@ -36,100 +35,6 @@ use datafusion_common::{ exec_err, Result, }; -/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.) -/// strpos('high', 'ig') = 2 -/// The implementation uses UTF-8 code points as characters -pub fn strpos(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let substring_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(substring_array.iter()) - .map(|(string, substring)| match (string, substring) { - (Some(string), Some(substring)) => { - // the find method returns the byte index of the substring - // Next, we count the number of the chars until that byte - T::Native::from_usize( - string - .find(substring) - .map(|x| string[..x].chars().count() + 1) - .unwrap_or(0), - ) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).) -/// substr('alphabet', 3) = 'phabet' -/// substr('alphabet', 3, 2) = 'ph' -/// The implementation uses UTF-8 code points as characters -pub fn substr(args: &[ArrayRef]) -> Result { - match args.len() { - 2 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(start_array.iter()) - .map(|(string, start)| match (string, start) { - (Some(string), Some(start)) => { - if start <= 0 { - Some(string.to_string()) - } else { - Some(string.chars().skip(start as usize - 1).collect()) - } - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let start_array = as_int64_array(&args[1])?; - let count_array = as_int64_array(&args[2])?; - - let result = string_array - .iter() - .zip(start_array.iter()) - .zip(count_array.iter()) - .map(|((string, start), count)| match (string, start, count) { - (Some(string), Some(start), Some(count)) => { - if count < 0 { - exec_err!( - "negative substring length not allowed: substr(, {start}, {count})" - ) - } else { - let skip = max(0, start - 1); - let count = max(0, count + (if start < 1 {start - 1} else {0})); - Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::())) - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!("substr was called with {other} arguments. It requires 2 or 3.") - } - } -} - /// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. /// translate('12345', '143', 'ax') = 'a2x5' pub fn translate(args: &[ArrayRef]) -> Result { diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml index f5297aefcd1c..bec2b8c53a7a 100644 --- a/datafusion/proto/Cargo.toml +++ b/datafusion/proto/Cargo.toml @@ -54,6 +54,7 @@ serde = { version = "1.0", optional = true } serde_json = { workspace = true, optional = true } [dev-dependencies] +datafusion-functions = { workspace = true, default-features = true } doc-comment = { workspace = true } strum = { version = "0.26.1", features = ["derive"] } tokio = { workspace = true, features = ["rt-multi-thread"] } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6319372d98d2..3a187eabe836 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -593,8 +593,8 @@ enum ScalarFunction { // 49 was SHA512 // 50 was SplitPart // StartsWith = 51; - Strpos = 52; - Substr = 53; + // 52 was Strpos + // 53 was Substr // ToHex = 54; // 55 was ToTimestamp // 56 was ToTimestampMillis diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7281bc9dc263..07b91b26d60b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22932,8 +22932,6 @@ impl serde::Serialize for ScalarFunction { Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", Self::Random => "Random", - Self::Strpos => "Strpos", - Self::Substr => "Substr", Self::Translate => "Translate", Self::Coalesce => "Coalesce", Self::Power => "Power", @@ -22986,8 +22984,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator", "InitCap", "Random", - "Strpos", - "Substr", "Translate", "Coalesce", "Power", @@ -23069,8 +23065,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), "Random" => Ok(ScalarFunction::Random), - "Strpos" => Ok(ScalarFunction::Strpos), - "Substr" => Ok(ScalarFunction::Substr), "Translate" => Ok(ScalarFunction::Translate), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2fe89efb9cea..babeccec595f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2892,8 +2892,8 @@ pub enum ScalarFunction { /// 49 was SHA512 /// 50 was SplitPart /// StartsWith = 51; - Strpos = 52, - Substr = 53, + /// 52 was Strpos + /// 53 was Substr /// ToHex = 54; /// 55 was ToTimestamp /// 56 was ToTimestampMillis @@ -3005,8 +3005,6 @@ impl ScalarFunction { ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", ScalarFunction::Random => "Random", - ScalarFunction::Strpos => "Strpos", - ScalarFunction::Substr => "Substr", ScalarFunction::Translate => "Translate", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", @@ -3053,8 +3051,6 @@ impl ScalarFunction { "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), "Random" => Some(Self::Random), - "Strpos" => Some(Self::Strpos), - "Substr" => Some(Self::Substr), "Translate" => Some(Self::Translate), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2c6f2e479b24..ff3d6773d512 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -42,10 +42,10 @@ use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, strpos, substr, - substr_index, substring, translate, trunc, AggregateFunction, Between, BinaryExpr, - BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, - GetIndexedField, GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, substr_index, + translate, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, + BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, + GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -455,8 +455,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Random => Self::Random, - ScalarFunction::Strpos => Self::Strpos, - ScalarFunction::Substr => Self::Substr, ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Pi => Self::Pi, @@ -1389,25 +1387,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Strpos => Ok(strpos( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Substr => { - if args.len() > 2 { - assert_eq!(args.len(), 3); - Ok(substring( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )) - } else { - Ok(substr( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )) - } - } ScalarFunction::Translate => Ok(translate( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ea682a5a22f8..89d49c5658a2 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1446,8 +1446,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Strpos => Self::Strpos, - BuiltinScalarFunction::Substr => Self::Substr, BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3c43f100750f..3a47f556c0f3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -34,8 +34,8 @@ use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DFField, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, + internal_datafusion_err, internal_err, not_impl_err, plan_err, DFField, DFSchema, + DFSchemaRef, DataFusionError, FileType, Result, ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -44,8 +44,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - col, create_udaf, lit, Accumulator, AggregateFunction, - BuiltinScalarFunction::{Sqrt, Substr}, + col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, ColumnarValue, Expr, ExprSchemable, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, @@ -60,6 +59,7 @@ use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; use datafusion_proto::protobuf; +use datafusion::execution::FunctionRegistry; use prost::Message; #[cfg(feature = "json")] @@ -1863,17 +1863,28 @@ fn roundtrip_cube() { #[test] fn roundtrip_substr() { + let ctx = SessionContext::new(); + + let fun = ctx + .state() + .udf("substr") + .map_err(|e| { + internal_datafusion_err!("Unable to find expected 'substr' function: {e:?}") + }) + .unwrap(); + // substr(string, position) - let test_expr = - Expr::ScalarFunction(ScalarFunction::new(Substr, vec![col("col"), lit(1_i64)])); + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + fun.clone(), + vec![col("col"), lit(1_i64)], + )); // substr(string, position, count) - let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new( - Substr, + let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new_udf( + fun, vec![col("col"), lit(1_i64), lit(1_i64)], )); - let ctx = SessionContext::new(); roundtrip_expr_test(test_expr, ctx.clone()); roundtrip_expr_test(test_expr_with_count, ctx); } diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs index d4a1ab44a6ea..972382b841d5 100644 --- a/datafusion/proto/tests/cases/serialize.rs +++ b/datafusion/proto/tests/cases/serialize.rs @@ -260,10 +260,7 @@ fn test_expression_serialization_roundtrip() { let lit = Expr::Literal(ScalarValue::Utf8(None)); for builtin_fun in BuiltinScalarFunction::iter() { // default to 4 args (though some exprs like substr have error checking) - let num_args = match builtin_fun { - BuiltinScalarFunction::Substr => 3, - _ => 4, - }; + let num_args = 4; let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect(); let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, args)); diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index d1fc03194997..43bf2d871564 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -823,12 +823,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = BuiltinScalarFunction::Strpos; + let fun = self + .context_provider + .get_function_meta("strpos") + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected 'strpos' function") + })?; let substr = self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?; let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?; let args = vec![fullstr, substr]; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } fn sql_agg_with_filter_to_expr( &self, diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs index a5d1abf0f265..f58c6f3b94d0 100644 --- a/datafusion/sql/src/expr/substring.rs +++ b/datafusion/sql/src/expr/substring.rs @@ -16,10 +16,10 @@ // under the License. use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; -use datafusion_common::plan_err; +use datafusion_common::{internal_datafusion_err, plan_err}; use datafusion_common::{DFSchema, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::{BuiltinScalarFunction, Expr}; +use datafusion_expr::Expr; use sqlparser::ast::Expr as SQLExpr; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -68,9 +68,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } }; - Ok(Expr::ScalarFunction(ScalarFunction::new( - BuiltinScalarFunction::Substr, - args, - ))) + let fun = self + .context_provider + .get_function_meta("substr") + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected 'substr' function") + })?; + + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } } diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index a77a2bf4059c..20c8b3d25fdd 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -2087,7 +2087,7 @@ select position('' in '') 1 -query error DataFusion error: Error during planning: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64. +query error DataFusion error: Execution error: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64. select position(1 in 1) From cda52307700446e8a75c8c70e3c7ec02eb9e4d66 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 19:15:07 -0400 Subject: [PATCH 7/9] move the Translate, SubstrIndex, FindInSet functions to new datafusion-functions crate --- datafusion-cli/Cargo.lock | 1 + datafusion/expr/src/built_in_function.rs | 40 ---- datafusion/expr/src/expr_fn.rs | 7 - datafusion/functions/Cargo.toml | 3 +- .../functions/src/unicode/find_in_set.rs | 129 +++++++++++ datafusion/functions/src/unicode/mod.rs | 24 ++ .../functions/src/unicode/substrindex.rs | 148 ++++++++++++ datafusion/functions/src/unicode/translate.rs | 218 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 148 +----------- datafusion/physical-expr/src/lib.rs | 2 - .../physical-expr/src/unicode_expressions.rs | 172 -------------- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 26 +-- datafusion/proto/src/logical_plan/to_proto.rs | 3 - 16 files changed, 534 insertions(+), 414 deletions(-) create mode 100644 datafusion/functions/src/unicode/find_in_set.rs create mode 100644 datafusion/functions/src/unicode/substrindex.rs create mode 100644 datafusion/functions/src/unicode/translate.rs delete mode 100644 datafusion/physical-expr/src/unicode_expressions.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index ba60c04cea55..522531a83179 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1267,6 +1267,7 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", + "hashbrown 0.14.3", "hex", "itertools", "log", diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 423fc11c1d8c..487bb893016f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -113,12 +113,6 @@ pub enum BuiltinScalarFunction { InitCap, /// random Random, - /// translate - Translate, - /// substr_index - SubstrIndex, - /// find_in_set - FindInSet, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -207,9 +201,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Translate => Volatility::Immutable, - BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, - BuiltinScalarFunction::FindInSet => Volatility::Immutable, // Volatile builtin functions BuiltinScalarFunction::Random => Volatility::Volatile, @@ -246,15 +237,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::EndsWith => Ok(Boolean), - BuiltinScalarFunction::SubstrIndex => { - utf8_to_str_type(&input_expr_types[0], "substr_index") - } - BuiltinScalarFunction::FindInSet => { - utf8_to_int_type(&input_expr_types[0], "find_in_set") - } - BuiltinScalarFunction::Translate => { - utf8_to_str_type(&input_expr_types[0], "translate") - } BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Gcd @@ -338,22 +320,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - - BuiltinScalarFunction::SubstrIndex => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - self.volatility(), - ), - BuiltinScalarFunction::FindInSet => Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], - self.volatility(), - ), - - BuiltinScalarFunction::Translate => { - Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) - } BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Power => Signature::one_of( @@ -513,9 +479,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], - BuiltinScalarFunction::FindInSet => &["find_in_set"], } } } @@ -580,9 +543,6 @@ macro_rules! get_optimal_return_type { // `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); -// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. -get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 09170ae639ff..6cbe32b81e9d 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -579,7 +579,6 @@ scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); -scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c nary_scalar_expr!( @@ -596,9 +595,6 @@ scalar_expr!( "returns true if a given number is +0.0 or -0.0 otherwise returns false" ); -scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); -scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"); - /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None) @@ -1012,8 +1008,5 @@ mod test { test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); - test_scalar_expr!(Translate, translate, string, from, to); - test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); - test_scalar_expr!(FindInSet, find_in_set, string, stringlist); } } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 0cab0276ff4b..70c6b3e238d6 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -54,7 +54,7 @@ regex_expressions = ["regex"] # enable string functions string_expressions = [] # enable unicode functions -unicode_expressions = ["unicode-segmentation"] +unicode_expressions = ["hashbrown", "unicode-segmentation"] [lib] name = "datafusion_functions" @@ -72,6 +72,7 @@ datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } +hashbrown = { version = "0.14", features = ["raw"], optional = true } hex = { version = "0.4", optional = true } itertools = { workspace = true } log = { workspace = true } diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs new file mode 100644 index 000000000000..9a9ab09de2f3 --- /dev/null +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_int_type}; + +#[derive(Debug)] +pub(super) struct FindInSetFunc { + signature: Signature, +} + +impl FindInSetFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for FindInSetFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "find_in_set" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "find_in_set") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(find_in_set::, vec![])(args) + } + DataType::LargeUtf8 => { + make_scalar_function(find_in_set::, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + } + } +} + +///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings +///A string list is a string composed of substrings separated by , characters. +pub fn find_in_set(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + if args.len() != 2 { + return exec_err!( + "find_in_set was called with {} arguments. It requires 2.", + args.len() + ); + } + + let str_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + let str_list_array: &GenericStringArray = + as_generic_string_array::(&args[1])?; + + let result = str_array + .iter() + .zip(str_list_array.iter()) + .map(|(string, str_list)| match (string, str_list) { + (Some(string), Some(str_list)) => { + let mut res = 0; + let str_set: Vec<&str> = str_list.split(',').collect(); + for (idx, str) in str_set.iter().enumerate() { + if str == &string { + res = idx + 1; + break; + } + } + T::Native::from_usize(res) + } + _ => None, + }) + .collect::>(); + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use datafusion_common::Result; + + #[test] + fn test_functions() -> Result<()> { + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index ddab0d1e27c9..eba4cd5048eb 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; mod character_length; +mod find_in_set; mod left; mod lpad; mod reverse; @@ -29,6 +30,8 @@ mod right; mod rpad; mod strpos; mod substr; +mod substrindex; +mod translate; // create UDFs make_udf_function!( @@ -36,6 +39,7 @@ make_udf_function!( CHARACTER_LENGTH, character_length ); +make_udf_function!(find_in_set::FindInSetFunc, FIND_IN_SET, find_in_set); make_udf_function!(left::LeftFunc, LEFT, left); make_udf_function!(lpad::LPadFunc, LPAD, lpad); make_udf_function!(right::RightFunc, RIGHT, right); @@ -43,6 +47,8 @@ make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); make_udf_function!(rpad::RPadFunc, RPAD, rpad); make_udf_function!(strpos::StrposFunc, STRPOS, strpos); make_udf_function!(substr::SubstrFunc, SUBSTR, substr); +make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); +make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); pub mod expr_fn { use datafusion_expr::Expr; @@ -57,6 +63,11 @@ pub mod expr_fn { super::character_length().call(vec![string]) } + #[doc = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"] + pub fn find_in_set(string: Expr, strlist: Expr) -> Expr { + super::find_in_set().call(vec![string, strlist]) + } + #[doc = "finds the position from where the `substring` matches the `string`"] pub fn instr(string: Expr, substring: Expr) -> Expr { strpos(string, substring) @@ -111,12 +122,23 @@ pub mod expr_fn { pub fn substring(string: Expr, position: Expr, length: Expr) -> Expr { super::substr().call(vec![string, position, length]) } + + #[doc = "Returns the substring from str before count occurrences of the delimiter"] + pub fn substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr { + super::substr_index().call(vec![string, delimiter, count]) + } + + #[doc = "replaces the characters in `from` with the counterpart in `to`"] + pub fn translate(string: Expr, from: Expr, to: Expr) -> Expr { + super::translate().call(vec![string, from, to]) + } } /// Return a list of all functions in this package pub fn functions() -> Vec> { vec![ character_length(), + find_in_set(), left(), lpad(), reverse(), @@ -124,5 +146,7 @@ pub fn functions() -> Vec> { rpad(), strpos(), substr(), + substr_index(), + translate(), ] } diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs new file mode 100644 index 000000000000..d115a31cb54f --- /dev/null +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct SubstrIndexFunc { + signature: Signature, + aliases: Vec, +} + +impl SubstrIndexFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("substring_index")], + } + } +} + +impl ScalarUDFImpl for SubstrIndexFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "substr_index" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "substr_index") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(substr_index::, vec![])(args), + DataType::LargeUtf8 => { + make_scalar_function(substr_index::, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function substr_index") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index(args: &[ArrayRef]) -> Result { + if args.len() != 3 { + return exec_err!( + "substr_index was called with {} arguments. It requires 3.", + args.len() + ); + } + + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + // In MySQL, these cases will return an empty string. + if n == 0 || string.is_empty() || delimiter.is_empty() { + return Some(String::new()); + } + + let splitted: Box> = if n > 0 { + Box::new(string.split(delimiter)) + } else { + Box::new(string.rsplit(delimiter)) + }; + let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); + // The length of the substring covered by substr_index. + let length = splitted + .take(occurrences) // at least 1 element, since n != 0 + .map(|s| s.len() + delimiter.len()) + .sum::() + - delimiter.len(); + if n > 0 { + Some(string[..length].to_owned()) + } else { + Some(string[string.len() - length..].to_owned()) + } + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use datafusion_common::Result; + + #[test] + fn test_functions() -> Result<()> { + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs new file mode 100644 index 000000000000..25f6aed55fd1 --- /dev/null +++ b/datafusion/functions/src/unicode/translate.rs @@ -0,0 +1,218 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use hashbrown::HashMap; +use unicode_segmentation::UnicodeSegmentation; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct TranslateFunc { + signature: Signature, +} + +impl TranslateFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TranslateFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "translate" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "translate") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(translate::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(translate::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function translate") + } + } + } +} + +/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. +/// translate('12345', '143', 'ax') = 'a2x5' +fn translate(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let from_array = as_generic_string_array::(&args[1])?; + let to_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => { + // create a hashmap of [char, index] to change from O(n) to O(1) for from list + let from_map: HashMap<&str, usize> = from + .graphemes(true) + .collect::>() + .iter() + .enumerate() + .map(|(index, c)| (c.to_owned(), index)) + .collect(); + + let to = to.graphemes(true).collect::>(); + + Some( + string + .graphemes(true) + .collect::>() + .iter() + .flat_map(|c| match from_map.get(*c) { + Some(n) => to.get(*n).copied(), + None => Some(*c), + }) + .collect::>() + .concat(), + ) + } + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::translate::TranslateFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("12345")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("143")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ax")))) + ], + Ok(Some("a2x5")), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("143")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ax")))) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("12345")))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ax")))) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("12345")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("143")))), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("é2íñ5")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("éñí")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("óü")))), + ], + Ok(Some("ó2ü5")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("12345")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("143")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ax")))), + ], + internal_err!( + "function translate requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 513dd71d4074..3717ab6b3cad 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -35,7 +35,7 @@ use std::sync::Arc; use arrow::{ array::ArrayRef, - datatypes::{DataType, Int32Type, Int64Type, Schema}, + datatypes::{DataType, Schema}, }; use arrow_array::Array; @@ -85,26 +85,6 @@ pub fn create_physical_expr( ))) } -#[cfg(feature = "unicode_expressions")] -macro_rules! invoke_if_unicode_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => {{ - use crate::unicode_expressions; - unicode_expressions::$FUNC::<$T> - }}; -} - -#[cfg(not(feature = "unicode_expressions"))] -macro_rules! invoke_if_unicode_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => { - |_: &[ArrayRef]| -> Result { - internal_err!( - "function {} requires compilation with feature flag: unicode_expressions.", - $NAME - ) - } - }; -} - #[derive(Debug, Clone, Copy)] pub enum Hint { /// Indicates the argument needs to be padded if it is scalar @@ -281,71 +261,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function ends_with") } }), - BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i32, - "translate" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i64, - "translate" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function translate") - } - }), - BuiltinScalarFunction::SubstrIndex => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - substr_index, - i32, - "substr_index" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - substr_index, - i64, - "substr_index" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function substr_index") - } - }) - } - BuiltinScalarFunction::FindInSet => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - find_in_set, - Int32Type, - "find_in_set" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - find_in_set, - Int64Type, - "find_in_set" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function find_in_set") - } - }), }) } @@ -634,66 +549,7 @@ mod tests { Boolean, BooleanArray ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit("143"), lit("ax"),], - Ok(Some("a2x5")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit(ScalarValue::Utf8(None)), lit("143"), lit("ax"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit(ScalarValue::Utf8(None)), lit("ax"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit("143"), lit(ScalarValue::Utf8(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("é2íñ5"), lit("éñí"), lit("óü"),], - Ok(Some("ó2ü5")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Translate, - &[ - lit("12345"), - lit("143"), - lit("ax"), - ], - internal_err!( - "function translate requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); + Ok(()) } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 1dead099540b..7819d5116160 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -33,8 +33,6 @@ pub mod sort_properties; pub mod string_expressions; pub mod tree_node; pub mod udf; -#[cfg(feature = "unicode_expressions")] -pub mod unicode_expressions; pub mod utils; pub mod window; diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs deleted file mode 100644 index ecbd1ea320d4..000000000000 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ /dev/null @@ -1,172 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Some of these functions reference the Postgres documentation -// or implementation to ensure compatibility and are subject to -// the Postgres license. - -//! Unicode expressions - -use std::sync::Arc; - -use arrow::{ - array::{ArrayRef, GenericStringArray, OffsetSizeTrait, PrimitiveArray}, - datatypes::{ArrowNativeType, ArrowPrimitiveType}, -}; -use hashbrown::HashMap; -use unicode_segmentation::UnicodeSegmentation; - -use datafusion_common::{ - cast::{as_generic_string_array, as_int64_array}, - exec_err, Result, -}; - -/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. -/// translate('12345', '143', 'ax') = 'a2x5' -pub fn translate(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let from_array = as_generic_string_array::(&args[1])?; - let to_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .map(|((string, from), to)| match (string, from, to) { - (Some(string), Some(from), Some(to)) => { - // create a hashmap of [char, index] to change from O(n) to O(1) for from list - let from_map: HashMap<&str, usize> = from - .graphemes(true) - .collect::>() - .iter() - .enumerate() - .map(|(index, c)| (c.to_owned(), index)) - .collect(); - - let to = to.graphemes(true).collect::>(); - - Some( - string - .graphemes(true) - .collect::>() - .iter() - .flat_map(|c| match from_map.get(*c) { - Some(n) => to.get(*n).copied(), - None => Some(*c), - }) - .collect::>() - .concat(), - ) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. -/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www -/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache -/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org -/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org -pub fn substr_index(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!( - "substr_index was called with {} arguments. It requires 3.", - args.len() - ); - } - - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let count_array = as_int64_array(&args[2])?; - - let result = string_array - .iter() - .zip(delimiter_array.iter()) - .zip(count_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - // In MySQL, these cases will return an empty string. - if n == 0 || string.is_empty() || delimiter.is_empty() { - return Some(String::new()); - } - - let splitted: Box> = if n > 0 { - Box::new(string.split(delimiter)) - } else { - Box::new(string.rsplit(delimiter)) - }; - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - // The length of the substring covered by substr_index. - let length = splitted - .take(occurrences) // at least 1 element, since n != 0 - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len(); - if n > 0 { - Some(string[..length].to_owned()) - } else { - Some(string[string.len() - length..].to_owned()) - } - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings -///A string list is a string composed of substrings separated by , characters. -pub fn find_in_set(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - if args.len() != 2 { - return exec_err!( - "find_in_set was called with {} arguments. It requires 2.", - args.len() - ); - } - - let str_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - let str_list_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = str_array - .iter() - .zip(str_list_array.iter()) - .map(|(string, str_list)| match (string, str_list) { - (Some(string), Some(str_list)) => { - let mut res = 0; - let str_set: Vec<&str> = str_list.split(',').collect(); - for (idx, str) in str_set.iter().enumerate() { - if str == &string { - res = idx + 1; - break; - } - } - T::Native::from_usize(res) - } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) -} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 3a187eabe836..b9fc92e1eeb0 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -601,7 +601,7 @@ enum ScalarFunction { // 57 was ToTimestampMicros // 58 was ToTimestampSeconds // 59 was Now - Translate = 60; + // 60 was Translate // Trim = 61; // Upper = 62; Coalesce = 63; @@ -665,8 +665,8 @@ enum ScalarFunction { // 123 is ArrayExcept // 124 was ArrayPopFront // 125 was Levenshtein - SubstrIndex = 126; - FindInSet = 127; + // 126 was SubstrIndex + // 127 was FindInSet // 128 was ArraySort // 129 was ArrayDistinct // 130 was ArrayResize diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 07b91b26d60b..179fb0aaf824 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22932,7 +22932,6 @@ impl serde::Serialize for ScalarFunction { Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", Self::Random => "Random", - Self::Translate => "Translate", Self::Coalesce => "Coalesce", Self::Power => "Power", Self::Atan2 => "Atan2", @@ -22951,8 +22950,6 @@ impl serde::Serialize for ScalarFunction { Self::Cot => "Cot", Self::Nanvl => "Nanvl", Self::Iszero => "Iszero", - Self::SubstrIndex => "SubstrIndex", - Self::FindInSet => "FindInSet", Self::EndsWith => "EndsWith", }; serializer.serialize_str(variant) @@ -22984,7 +22981,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator", "InitCap", "Random", - "Translate", "Coalesce", "Power", "Atan2", @@ -23003,8 +22999,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cot", "Nanvl", "Iszero", - "SubstrIndex", - "FindInSet", "EndsWith", ]; @@ -23065,7 +23059,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), "Random" => Ok(ScalarFunction::Random), - "Translate" => Ok(ScalarFunction::Translate), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), "Atan2" => Ok(ScalarFunction::Atan2), @@ -23084,8 +23077,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cot" => Ok(ScalarFunction::Cot), "Nanvl" => Ok(ScalarFunction::Nanvl), "Iszero" => Ok(ScalarFunction::Iszero), - "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), - "FindInSet" => Ok(ScalarFunction::FindInSet), "EndsWith" => Ok(ScalarFunction::EndsWith), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index babeccec595f..34610c0f6477 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2900,7 +2900,7 @@ pub enum ScalarFunction { /// 57 was ToTimestampMicros /// 58 was ToTimestampSeconds /// 59 was Now - Translate = 60, + /// 60 was Translate /// Trim = 61; /// Upper = 62; Coalesce = 63, @@ -2964,8 +2964,8 @@ pub enum ScalarFunction { /// 123 is ArrayExcept /// 124 was ArrayPopFront /// 125 was Levenshtein - SubstrIndex = 126, - FindInSet = 127, + /// 126 was SubstrIndex + /// 127 was FindInSet /// 128 was ArraySort /// 129 was ArrayDistinct /// 130 was ArrayResize @@ -3005,7 +3005,6 @@ impl ScalarFunction { ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", ScalarFunction::Random => "Random", - ScalarFunction::Translate => "Translate", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", ScalarFunction::Atan2 => "Atan2", @@ -3024,8 +3023,6 @@ impl ScalarFunction { ScalarFunction::Cot => "Cot", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Iszero => "Iszero", - ScalarFunction::SubstrIndex => "SubstrIndex", - ScalarFunction::FindInSet => "FindInSet", ScalarFunction::EndsWith => "EndsWith", } } @@ -3051,7 +3048,6 @@ impl ScalarFunction { "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), "Random" => Some(Self::Random), - "Translate" => Some(Self::Translate), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), "Atan2" => Some(Self::Atan2), @@ -3070,8 +3066,6 @@ impl ScalarFunction { "Cot" => Some(Self::Cot), "Nanvl" => Some(Self::Nanvl), "Iszero" => Some(Self::Iszero), - "SubstrIndex" => Some(Self::SubstrIndex), - "FindInSet" => Some(Self::FindInSet), "EndsWith" => Some(Self::EndsWith), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ff3d6773d512..46732e596629 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -40,12 +40,11 @@ use datafusion_expr::{ acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, + factorial, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, substr_index, - translate, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, - BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, - GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, trunc, + AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, + Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -455,15 +454,12 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Random => Self::Random, - ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Pi => Self::Pi, ScalarFunction::Power => Self::Power, ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::Iszero => Self::Iszero, - ScalarFunction::SubstrIndex => Self::SubstrIndex, - ScalarFunction::FindInSet => Self::FindInSet, } } } @@ -1387,11 +1383,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Translate => Ok(translate( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } @@ -1416,15 +1407,6 @@ pub fn parse_expr( ScalarFunction::Iszero => { Ok(iszero(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::SubstrIndex => Ok(substr_index( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::FindInSet => Ok(find_in_set( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 89d49c5658a2..d7ef3d0d3f40 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1446,15 +1446,12 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::Iszero => Self::Iszero, - BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, - BuiltinScalarFunction::FindInSet => Self::FindInSet, }; Ok(scalar_function) From d6cc497c9a3f85c93fe6fe17ed6a87b2053c0a20 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 21:30:44 -0400 Subject: [PATCH 8/9] Test code cleanup --- .../functions/src/unicode/find_in_set.rs | 10 ------ .../functions/src/unicode/substrindex.rs | 10 ------ datafusion/functions/src/unicode/translate.rs | 35 ++++++++----------- 3 files changed, 15 insertions(+), 40 deletions(-) diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 9a9ab09de2f3..7e0306d49454 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -117,13 +117,3 @@ where .collect::>(); Ok(Arc::new(result) as ArrayRef) } - -#[cfg(test)] -mod tests { - use datafusion_common::Result; - - #[test] - fn test_functions() -> Result<()> { - Ok(()) - } -} diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index d115a31cb54f..77e8116fff4c 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -136,13 +136,3 @@ pub fn substr_index(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } - -#[cfg(test)] -mod tests { - use datafusion_common::Result; - - #[test] - fn test_functions() -> Result<()> { - Ok(()) - } -} diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index 25f6aed55fd1..bc1836700304 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -132,51 +132,47 @@ mod tests { #[test] fn test_functions() -> Result<()> { - #[cfg(feature = "unicode_expressions")] test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("12345")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("143")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ax")))) + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) ], Ok(Some("a2x5")), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( TranslateFunc::new(), &[ ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("143")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ax")))) + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) ], Ok(None), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("12345")))), + ColumnarValue::Scalar(ScalarValue::from("12345")), ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ax")))) + ColumnarValue::Scalar(ScalarValue::from("ax")) ], Ok(None), &str, Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("12345")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("143")))), + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), ColumnarValue::Scalar(ScalarValue::Utf8(None)) ], Ok(None), @@ -184,13 +180,12 @@ mod tests { Utf8, StringArray ); - #[cfg(feature = "unicode_expressions")] test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("é2íñ5")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("éñí")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("óü")))), + ColumnarValue::Scalar(ScalarValue::from("é2íñ5")), + ColumnarValue::Scalar(ScalarValue::from("éñí")), + ColumnarValue::Scalar(ScalarValue::from("óü")), ], Ok(Some("ó2ü5")), &str, @@ -201,9 +196,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("12345")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("143")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ax")))), + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")), ], internal_err!( "function translate requires compilation with feature flag: unicode_expressions." From 3fa6281ebd0bfdd9ecff354e95fd5b56a7ae56c4 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 22:52:34 -0400 Subject: [PATCH 9/9] unicode_expressions Cargo.toml updates. --- datafusion-cli/Cargo.lock | 1 - datafusion/core/Cargo.toml | 2 -- datafusion/optimizer/Cargo.toml | 3 +-- datafusion/physical-expr/Cargo.toml | 3 --- 4 files changed, 1 insertion(+), 8 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 522531a83179..f4be1c97bcb2 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1342,7 +1342,6 @@ dependencies = [ "rand", "regex", "sha2", - "unicode-segmentation", ] [[package]] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index de03579975a2..f7007d653f48 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -67,8 +67,6 @@ regex_expressions = [ ] serde = ["arrow-schema/serde"] unicode_expressions = [ - "datafusion-physical-expr/unicode_expressions", - "datafusion-optimizer/unicode_expressions", "datafusion-sql/unicode_expressions", "datafusion-functions/unicode_expressions", ] diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 861715b351a6..1d64a22f1463 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -34,9 +34,8 @@ path = "src/lib.rs" [features] crypto_expressions = ["datafusion-physical-expr/crypto_expressions"] -default = ["unicode_expressions", "crypto_expressions", "regex_expressions"] +default = ["crypto_expressions", "regex_expressions"] regex_expressions = ["datafusion-physical-expr/regex_expressions"] -unicode_expressions = ["datafusion-physical-expr/unicode_expressions"] [dependencies] arrow = { workspace = true } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 24b831e7c575..baca00bea724 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -37,12 +37,10 @@ crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] default = [ "crypto_expressions", "regex_expressions", - "unicode_expressions", "encoding_expressions", ] encoding_expressions = ["base64", "hex"] regex_expressions = ["regex"] -unicode_expressions = ["unicode-segmentation"] [dependencies] ahash = { version = "0.8", default-features = false, features = [ @@ -73,7 +71,6 @@ petgraph = "0.6.2" rand = { workspace = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } -unicode-segmentation = { version = "^1.7.1", optional = true } [dev-dependencies] criterion = "0.5"