From eb13f598fa722382cb6580ff9fc31e458ad7a8a9 Mon Sep 17 00:00:00 2001
From: Lordworms <48054792+Lordworms@users.noreply.github.com>
Date: Mon, 18 Mar 2024 09:48:12 -0500
Subject: [PATCH 001/117] Adding Constant Check for FilterExec (#9649)

* fix bugs in adding extra SortExec

* adding tests

* optimize code

* Update datafusion/physical-plan/src/filter.rs

Co-authored-by: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com>

* optimize code

* optimize code

* optimize code

* optimize code

* fix clippy

---------

Co-authored-by: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com>
---
 datafusion/physical-plan/src/filter.rs        | 27 +++++++-
 .../test_files/filter_without_sort_exec.slt   | 61 +++++++++++++++++++
 2 files changed, 87 insertions(+), 1 deletion(-)
 create mode 100644 datafusion/sqllogictest/test_files/filter_without_sort_exec.slt

diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs
index 4155b00820f4..72f885a93962 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -159,6 +159,27 @@ impl FilterExec {
         })
     }
 
+    fn extend_constants(
+        input: &Arc<dyn ExecutionPlan>,
+        predicate: &Arc<dyn PhysicalExpr>,
+    ) -> Vec<Arc<dyn PhysicalExpr>> {
+        let mut res_constants = Vec::new();
+        let input_eqs = input.equivalence_properties();
+
+        let conjunctions = split_conjunction(predicate);
+        for conjunction in conjunctions {
+            if let Some(binary) = conjunction.as_any().downcast_ref::<BinaryExpr>() {
+                if binary.op() == &Operator::Eq {
+                    if input_eqs.is_expr_constant(binary.left()) {
+                        res_constants.push(binary.right().clone())
+                    } else if input_eqs.is_expr_constant(binary.right()) {
+                        res_constants.push(binary.left().clone())
+                    }
+                }
+            }
+        }
+        res_constants
+    }
     /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
     fn compute_properties(
         input: &Arc<dyn ExecutionPlan>,
@@ -181,8 +202,12 @@ impl FilterExec {
             .into_iter()
             .filter(|column| stats.column_statistics[column.index()].is_singleton())
             .map(|column| Arc::new(column) as _);
+        // this is for statistics
         eq_properties = eq_properties.add_constants(constants);
-
+        // this is for a logical constant (for example, given a = '1', a can be marked as a constant)
+        // TODO: how to handle the multiple ways of expressing equality (for example c1 BETWEEN 0 AND 0)
+        eq_properties =
+            eq_properties.add_constants(Self::extend_constants(input, predicate));
         Ok(PlanProperties::new(
             eq_properties,
             input.output_partitioning().clone(), // Output Partitioning
diff --git a/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt
new file mode 100644
index 000000000000..05e622db8a02
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/filter_without_sort_exec.slt
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+
+#   http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# prepare table +statement ok +CREATE UNBOUNDED EXTERNAL TABLE data ( + "date" VARCHAR, + "ticker" VARCHAR, + "time" VARCHAR, +) STORED AS CSV +WITH ORDER ("date", "ticker", "time") +LOCATION './a.parquet'; + + +# query +query TT +explain SELECT * FROM data +WHERE ticker = 'A' +ORDER BY "date", "time"; +---- +logical_plan +Sort: data.date ASC NULLS LAST, data.time ASC NULLS LAST +--Filter: data.ticker = Utf8("A") +----TableScan: data projection=[date, ticker, time] +physical_plan +SortPreservingMergeExec: [date@0 ASC NULLS LAST,time@2 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: ticker@1 = A +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------StreamingTableExec: partition_sizes=1, projection=[date, ticker, time], infinite_source=true, output_ordering=[date@0 ASC NULLS LAST, ticker@1 ASC NULLS LAST, time@2 ASC NULLS LAST] + +# query +query TT +explain SELECT * FROM data +WHERE date = 'A' +ORDER BY "ticker", "time"; +---- +logical_plan +Sort: data.ticker ASC NULLS LAST, data.time ASC NULLS LAST +--Filter: data.date = Utf8("A") +----TableScan: data projection=[date, ticker, time] +physical_plan +SortPreservingMergeExec: [ticker@1 ASC NULLS LAST,time@2 ASC NULLS LAST] +--CoalesceBatchesExec: target_batch_size=8192 +----FilterExec: date@0 = A +------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------StreamingTableExec: partition_sizes=1, projection=[date, ticker, time], infinite_source=true, output_ordering=[date@0 ASC NULLS LAST, ticker@1 ASC NULLS LAST, time@2 ASC NULLS LAST] From 4e8ac98fbbebbf965eebba5cc40ecf7c590a6d28 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 18 Mar 2024 13:37:17 -0400 Subject: [PATCH 002/117] chore(deps-dev): bump follow-redirects (#9609) Bumps [follow-redirects](https://github.com/follow-redirects/follow-redirects) from 1.15.4 to 1.15.6. - [Release notes](https://github.com/follow-redirects/follow-redirects/releases) - [Commits](https://github.com/follow-redirects/follow-redirects/compare/v1.15.4...v1.15.6) --- updated-dependencies: - dependency-name: follow-redirects dependency-type: indirect ... 
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../wasmtest/datafusion-wasm-app/package-lock.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index 5163c99bd5ac..aac87845bc9f 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -1666,9 +1666,9 @@ } }, "node_modules/follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true, "funding": [ { @@ -5580,9 +5580,9 @@ } }, "follow-redirects": { - "version": "1.15.4", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.4.tgz", - "integrity": "sha512-Cr4D/5wlrb0z9dgERpUL3LrmPKVDsETIJhaCMeDfuFYcqa5bldGV6wBsAN6X/vxlXQtFBMrXdXxdL8CbDTGniw==", + "version": "1.15.6", + "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", + "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", "dev": true }, "forwarded": { From 449738cd41158cb7cf65ad98abb8fda882256586 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Tue, 19 Mar 2024 02:21:07 +0800 Subject: [PATCH 003/117] move array_replace family functions to datafusion-function-array crate (#9651) * Add array replace functions * fix ci * fix ci * Update dependencies in Cargo.lock file * Fix formatting in comment * fix ci * rename mod * fix conflict * remove duplicated function * fix: clippy --------- Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 47 +- datafusion/core/benches/array_expression.rs | 42 +- datafusion/expr/src/built_in_function.rs | 28 -- datafusion/expr/src/expr_fn.rs | 24 - datafusion/functions-array/Cargo.toml | 1 + datafusion/functions-array/src/core.rs | 2 +- datafusion/functions-array/src/lib.rs | 9 +- datafusion/functions-array/src/position.rs | 106 +---- datafusion/functions-array/src/replace.rs | 362 +++++++++++++++ datafusion/functions-array/src/utils.rs | 6 +- .../physical-expr/src/array_expressions.rs | 423 ------------------ datafusion/physical-expr/src/functions.rs | 16 +- datafusion/physical-expr/src/lib.rs | 1 - datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 25 +- datafusion/proto/src/logical_plan/to_proto.rs | 3 - .../tests/cases/roundtrip_logical_plan.rs | 8 + 19 files changed, 435 insertions(+), 695 deletions(-) create mode 100644 datafusion/functions-array/src/replace.rs delete mode 100644 datafusion/physical-expr/src/array_expressions.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index deda497d9dd3..8e2a2c353e2d 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -378,13 +378,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.78" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -776,9 +776,9 @@ dependencies = [ [[package]] name = "brotli" -version = "3.4.0" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "516074a47ef4bce09577a3b379392300159ce5b1ba2e501ff1c819950066100f" +checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391" dependencies = [ "alloc-no-stdlib", "alloc-stdlib", @@ -1080,7 +1080,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad291aa74992b9b7a7e88c38acbbf6ad7e107f1d90ee8775b7bc1fc3394f485c" dependencies = [ "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1270,6 +1270,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-buffer", + "arrow-ord", "arrow-schema", "datafusion-common", "datafusion-execution", @@ -1639,7 +1640,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -1713,9 +1714,9 @@ checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] name = "h2" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" +checksum = "4fbd2820c5e49886948654ab546d0688ff24530286bdcf8fca3cefb16d4618eb" dependencies = [ "bytes", "fnv", @@ -2560,7 +2561,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3101,7 +3102,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3236,7 +3237,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3282,7 +3283,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3295,7 +3296,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3317,9 +3318,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.52" +version = "2.0.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" dependencies = [ "proc-macro2", "quote", @@ -3403,7 +3404,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3498,7 +3499,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3595,7 +3596,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3640,7 +3641,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.53", ] [[package]] @@ -3794,7 +3795,7 
+3795,7 @@ dependencies = [
  "once_cell",
  "proc-macro2",
  "quote",
- "syn 2.0.52",
+ "syn 2.0.53",
  "wasm-bindgen-shared",
 ]
 
@@ -3828,7 +3829,7 @@
 checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.52",
+ "syn 2.0.53",
  "wasm-bindgen-backend",
  "wasm-bindgen-shared",
 ]
 
@@ -4086,7 +4087,7 @@
 checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.52",
+ "syn 2.0.53",
 ]
 
 [[package]]
diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/core/benches/array_expression.rs
index 95bc93e0e353..c980329620aa 100644
--- a/datafusion/core/benches/array_expression.rs
+++ b/datafusion/core/benches/array_expression.rs
@@ -22,48 +22,32 @@ extern crate datafusion;
 mod data_utils;
 use crate::criterion::Criterion;
-use arrow_array::cast::AsArray;
-use arrow_array::types::Int64Type;
-use arrow_array::{ArrayRef, Int64Array, ListArray};
-use datafusion_physical_expr::array_expressions;
-use std::sync::Arc;
+use datafusion::functions_array::expr_fn::{array_replace_all, make_array};
+use datafusion_expr::lit;
 
 fn criterion_benchmark(c: &mut Criterion) {
     // Construct large arrays for benchmarking
 
     let array_len = 100000000;
-    let array = (0..array_len).map(|_| Some(2_i64)).collect::<Vec<_>>();
-    let list_array = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
-        Some(array.clone()),
-        Some(array.clone()),
-        Some(array),
-    ]);
-    let from_array = Int64Array::from_value(2, 3);
-    let to_array = Int64Array::from_value(-2, 3);
+    let array = (0..array_len).map(|_| lit(2_i64)).collect::<Vec<_>>();
+    let list_array = make_array(vec![make_array(array); 3]);
+    let from_array = make_array(vec![lit(2_i64); 3]);
+    let to_array = make_array(vec![lit(-2_i64); 3]);
 
-    let args = vec![
-        Arc::new(list_array) as ArrayRef,
-        Arc::new(from_array) as ArrayRef,
-        Arc::new(to_array) as ArrayRef,
-    ];
-
-    let array = (0..array_len).map(|_| Some(-2_i64)).collect::<Vec<_>>();
-    let expected_array = ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
-        Some(array.clone()),
-        Some(array.clone()),
-        Some(array),
-    ]);
+    let expected_array = list_array.clone();
 
     // Benchmark array functions
     c.bench_function("array_replace", |b| {
         b.iter(|| {
             assert_eq!(
-                array_expressions::array_replace_all(args.as_slice())
-                    .unwrap()
-                    .as_list::<i32>(),
-                criterion::black_box(&expected_array)
+                array_replace_all(
+                    list_array.clone(),
+                    from_array.clone(),
+                    to_array.clone()
+                ),
+                *criterion::black_box(&expected_array)
             )
         })
     });
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index fe3397b1af52..79cd6a24ce39 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -102,14 +102,6 @@ pub enum BuiltinScalarFunction {
     /// cot
     Cot,
 
-    // array functions
-    /// array_replace
-    ArrayReplace,
-    /// array_replace_n
-    ArrayReplaceN,
-    /// array_replace_all
-    ArrayReplaceAll,
-
     // string functions
     /// ascii
     Ascii,
@@ -262,9 +254,6 @@ impl BuiltinScalarFunction {
             BuiltinScalarFunction::Cbrt => Volatility::Immutable,
             BuiltinScalarFunction::Cot => Volatility::Immutable,
             BuiltinScalarFunction::Trunc => Volatility::Immutable,
-            BuiltinScalarFunction::ArrayReplace => Volatility::Immutable,
-            BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable,
-            BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable,
             BuiltinScalarFunction::Ascii => Volatility::Immutable,
             BuiltinScalarFunction::BitLength => Volatility::Immutable,
             BuiltinScalarFunction::Btrim => 
Volatility::Immutable, @@ -322,9 +311,6 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { - BuiltinScalarFunction::ArrayReplace => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()), - BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::Ascii => Ok(Int32), BuiltinScalarFunction::BitLength => { utf8_to_int_type(&input_expr_types[0], "bit_length") @@ -477,11 +463,6 @@ impl BuiltinScalarFunction { // for now, the list is small, as we do not have many built-in functions. match self { - BuiltinScalarFunction::ArrayReplace => Signature::any(3, self.volatility()), - BuiltinScalarFunction::ArrayReplaceN => Signature::any(4, self.volatility()), - BuiltinScalarFunction::ArrayReplaceAll => { - Signature::any(3, self.volatility()) - } BuiltinScalarFunction::Concat | BuiltinScalarFunction::ConcatWithSeparator => { Signature::variadic(vec![Utf8], self.volatility()) @@ -779,15 +760,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Levenshtein => &["levenshtein"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], BuiltinScalarFunction::FindInSet => &["find_in_set"], - - // hashing functions - BuiltinScalarFunction::ArrayReplace => &["array_replace", "list_replace"], - BuiltinScalarFunction::ArrayReplaceN => { - &["array_replace_n", "list_replace_n"] - } - BuiltinScalarFunction::ArrayReplaceAll => { - &["array_replace_all", "list_replace_all"] - } BuiltinScalarFunction::OverLay => &["overlay"], } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index c5ad2a9b3ce4..b76164a1c83c 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -584,25 +584,6 @@ scalar_expr!( scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); -scalar_expr!( - ArrayReplace, - array_replace, - array from to, - "replaces the first occurrence of the specified element with another specified element." -); -scalar_expr!( - ArrayReplaceN, - array_replace_n, - array from to max, - "replaces the first `max` occurrences of the specified element with another specified element." -); -scalar_expr!( - ArrayReplaceAll, - array_replace_all, - array from to, - "replaces all occurrences of the specified element with another specified element." 
-);
-
 // string functions
 scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");
 scalar_expr!(
@@ -1145,11 +1126,6 @@ mod test {
         test_scalar_expr!(Translate, translate, string, from, to);
         test_scalar_expr!(Trim, trim, string);
         test_scalar_expr!(Upper, upper, string);
-
-        test_scalar_expr!(ArrayReplace, array_replace, array, from, to);
-        test_scalar_expr!(ArrayReplaceN, array_replace_n, array, from, to, max);
-        test_scalar_expr!(ArrayReplaceAll, array_replace_all, array, from, to);
-
         test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len);
         test_nary_scalar_expr!(OverLay, overlay, string, characters, position);
         test_scalar_expr!(Levenshtein, levenshtein, string1, string2);
diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml
index 99239ffb3bdc..80c0e5e18768 100644
--- a/datafusion/functions-array/Cargo.toml
+++ b/datafusion/functions-array/Cargo.toml
@@ -40,6 +40,7 @@ path = "src/lib.rs"
 arrow = { workspace = true }
 arrow-array = { workspace = true }
 arrow-buffer = { workspace = true }
+arrow-ord = { workspace = true }
 arrow-schema = { workspace = true }
 datafusion-common = { workspace = true }
 datafusion-execution = { workspace = true }
diff --git a/datafusion/functions-array/src/core.rs b/datafusion/functions-array/src/core.rs
index 4c84b7018c99..fdd127cc3f32 100644
--- a/datafusion/functions-array/src/core.rs
+++ b/datafusion/functions-array/src/core.rs
@@ -96,7 +96,7 @@ impl ScalarUDFImpl for MakeArray {
         }
     }
 
-    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
         make_scalar_function(make_array_inner)(args)
     }
 
diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs
index 2c19dfad6222..fb16acdef2bd 100644
--- a/datafusion/functions-array/src/lib.rs
+++ b/datafusion/functions-array/src/lib.rs
@@ -36,6 +36,7 @@ mod extract;
 mod kernels;
 mod position;
 mod remove;
+mod replace;
 mod rewrite;
 mod set_ops;
 mod udf;
@@ -66,6 +67,9 @@ pub mod expr_fn {
     pub use super::remove::array_remove;
     pub use super::remove::array_remove_all;
     pub use super::remove::array_remove_n;
+    pub use super::replace::array_replace;
+    pub use super::replace::array_replace_all;
+    pub use super::replace::array_replace_n;
     pub use super::set_ops::array_distinct;
     pub use super::set_ops::array_intersect;
     pub use super::set_ops::array_union;
@@ -120,8 +124,11 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
         position::array_position_udf(),
         position::array_positions_udf(),
         remove::array_remove_udf(),
-        remove::array_remove_n_udf(),
         remove::array_remove_all_udf(),
+        remove::array_remove_n_udf(),
+        replace::array_replace_n_udf(),
+        replace::array_replace_all_udf(),
+        replace::array_replace_udf(),
     ];
     functions.into_iter().try_for_each(|udf| {
         let existing_udf = registry.register_udf(udf)?;
diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-array/src/position.rs
index 4988e0ded106..627cf3cb0cf0 100644
--- a/datafusion/functions-array/src/position.rs
+++ b/datafusion/functions-array/src/position.rs
@@ -27,8 +27,7 @@ use std::sync::Arc;
 
 use arrow_array::types::UInt64Type;
 use arrow_array::{
-    Array, ArrayRef, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar,
-    UInt32Array, UInt64Array,
+    Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array,
 };
 use datafusion_common::cast::{
     as_generic_list_array, as_int64_array, as_large_list_array, as_list_array,
@@ -36,6 +35,8 @@ 
use datafusion_common::cast::{ use datafusion_common::{exec_err, internal_err}; use itertools::Itertools; +use crate::utils::compare_element_to_list; + make_udf_function!( ArrayPosition, array_position, @@ -173,107 +174,6 @@ fn generic_position( Ok(Arc::new(UInt64Array::from(data))) } -/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. -/// -/// # Arguments -/// -/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. It represents the list array for which the equality or inequality will be compared. -/// -/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. -/// -/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. -/// -/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. -/// -/// # Returns -/// -/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. -/// -/// # Example -/// -/// ```text -/// compare_element_to_list( -/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] -/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] -/// -/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] -/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] -/// ) -/// ``` -fn compare_element_to_list( - list_array_row: &dyn Array, - element_array: &dyn Array, - row_index: usize, - eq: bool, -) -> datafusion_common::Result { - if list_array_row.data_type() != element_array.data_type() { - return exec_err!( - "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", - list_array_row.data_type(), - element_array.data_type() - ); - } - - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let res = match element_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - DataType::LargeList(_) => { - // compare each element of the from array - let element_array_row_inner = - as_large_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_large_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - _ => { - let element_arr = Scalar::new(element_array_row); - // use not_distinct so we can compare NULL - if eq { - arrow::compute::kernels::cmp::not_distinct(&list_array_row, &element_arr)? 
-            } else {
-                arrow::compute::kernels::cmp::distinct(&list_array_row, &element_arr)?
-            }
-        }
-    };
-
-    Ok(res)
-}
-
 make_udf_function!(
     ArrayPositions,
     array_positions,
diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-array/src/replace.rs
new file mode 100644
index 000000000000..8ff65d315431
--- /dev/null
+++ b/datafusion/functions-array/src/replace.rs
@@ -0,0 +1,362 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for array functions.
+
+use arrow::array::{
+    Array, ArrayRef, AsArray, Capacities, MutableArrayData, OffsetSizeTrait,
+};
+use arrow::datatypes::DataType;
+
+use arrow_array::GenericListArray;
+use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer};
+use arrow_schema::Field;
+use datafusion_common::cast::as_int64_array;
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::expr::ScalarFunction;
+use datafusion_expr::Expr;
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+use crate::utils::compare_element_to_list;
+use crate::utils::make_scalar_function;
+
+use std::any::Any;
+use std::sync::Arc;
+
+// Create static instances of ScalarUDFs for each function
+make_udf_function!(ArrayReplace,
+    array_replace,
+    array from to,
+    "replaces the first occurrence of the specified element with another specified element.",
+    array_replace_udf
+);
+make_udf_function!(ArrayReplaceN,
+    array_replace_n,
+    array from to max,
+    "replaces the first `max` occurrences of the specified element with another specified element.",
+    array_replace_n_udf
+);
+make_udf_function!(ArrayReplaceAll,
+    array_replace_all,
+    array from to,
+    "replaces all occurrences of the specified element with another specified element.",
+    array_replace_all_udf
+);
+
+#[derive(Debug)]
+pub(super) struct ArrayReplace {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl ArrayReplace {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(3, Volatility::Immutable),
+            aliases: vec![String::from("array_replace"), String::from("list_replace")],
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArrayReplace {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "array_replace"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, args: &[DataType]) -> datafusion_common::Result<DataType> {
+        Ok(args[0].clone())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        make_scalar_function(array_replace_inner)(args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
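+// Illustration of the semantics implemented below (comments added editorially):
+// array_replace([1, 2, 3, 2], 2, 10) -> [1, 10, 3, 2] replaces only the first
+// match per row. All three variants delegate to the shared row-by-row kernel
+// `general_replace`, differing only in how many matches they allow per row.
+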
+#[derive(Debug)]
+pub(super) struct ArrayReplaceN {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl ArrayReplaceN {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(4, Volatility::Immutable),
+            aliases: vec![
+                String::from("array_replace_n"),
+                String::from("list_replace_n"),
+            ],
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArrayReplaceN {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "array_replace_n"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, args: &[DataType]) -> datafusion_common::Result<DataType> {
+        Ok(args[0].clone())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        make_scalar_function(array_replace_n_inner)(args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+#[derive(Debug)]
+pub(super) struct ArrayReplaceAll {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl ArrayReplaceAll {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::any(3, Volatility::Immutable),
+            aliases: vec![
+                String::from("array_replace_all"),
+                String::from("list_replace_all"),
+            ],
+        }
+    }
+}
+
+impl ScalarUDFImpl for ArrayReplaceAll {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "array_replace_all"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, args: &[DataType]) -> datafusion_common::Result<DataType> {
+        Ok(args[0].clone())
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result<ColumnarValue> {
+        make_scalar_function(array_replace_all_inner)(args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences
+/// of `from_array[i]` with `to_array[i]`.
+///
+/// The type of each **element** in `list_array` must be the same as the type of
+/// `from_array` and `to_array`. This function also handles nested arrays
+/// (\[`ListArray`\] of \[`ListArray`\]s)
+///
+/// For example, when called to replace a list array (where each element is a
+/// list of int32s), the second and third argument are int32 arrays, and the
+/// fourth argument is the number of occurrences to replace
+///
+/// ```text
+/// general_replace(
+///   [1, 2, 3, 2], 2, 10, 1    ==> [1, 10, 3, 2]   (only the first 2 is replaced)
+///   [4, 5, 6, 5], 5, 20, 2    ==> [4, 20, 6, 20]  (both 5s are replaced)
+/// )
+/// ```
+fn general_replace<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    from_array: &ArrayRef,
+    to_array: &ArrayRef,
+    arr_n: Vec<i64>,
+) -> Result<ArrayRef> {
+    // Build up the offsets for the final output array
+    let mut offsets: Vec<O> = vec![O::usize_as(0)];
+    let values = list_array.values();
+    let original_data = values.to_data();
+    let to_data = to_array.to_data();
+    let capacity = Capacities::Array(original_data.len());
+
+    // First array is the original array, second array is the element to replace with.
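+    // (Editorial note:) MutableArrayData copies each output row from one of two
+    // sources: index 0 (the original child values) and index 1 (the replacement
+    // values), so rows are spliced together without materializing intermediates.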
+    let mut mutable = MutableArrayData::with_capacities(
+        vec![&original_data, &to_data],
+        false,
+        capacity,
+    );
+
+    let mut valid = BooleanBufferBuilder::new(list_array.len());
+
+    for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
+        if list_array.is_null(row_index) {
+            offsets.push(offsets[row_index]);
+            valid.append(false);
+            continue;
+        }
+
+        let start = offset_window[0];
+        let end = offset_window[1];
+
+        let list_array_row = list_array.value(row_index);
+
+        // Compute all positions in list_row_array (that is itself an
+        // array) that are equal to `from_array_row`
+        let eq_array =
+            compare_element_to_list(&list_array_row, &from_array, row_index, true)?;
+
+        let original_idx = O::usize_as(0);
+        let replace_idx = O::usize_as(1);
+        let n = arr_n[row_index];
+        let mut counter = 0;
+
+        // All elements are false, no need to replace, just copy original data
+        if eq_array.false_count() == eq_array.len() {
+            mutable.extend(
+                original_idx.to_usize().unwrap(),
+                start.to_usize().unwrap(),
+                end.to_usize().unwrap(),
+            );
+            offsets.push(offsets[row_index] + (end - start));
+            valid.append(true);
+            continue;
+        }
+
+        for (i, to_replace) in eq_array.iter().enumerate() {
+            let i = O::usize_as(i);
+            if let Some(true) = to_replace {
+                mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1);
+                counter += 1;
+                if counter == n {
+                    // copy original data for any matches past n
+                    mutable.extend(
+                        original_idx.to_usize().unwrap(),
+                        (start + i).to_usize().unwrap() + 1,
+                        end.to_usize().unwrap(),
+                    );
+                    break;
+                }
+            } else {
+                // copy original data for false / null matches
+                mutable.extend(
+                    original_idx.to_usize().unwrap(),
+                    (start + i).to_usize().unwrap(),
+                    (start + i).to_usize().unwrap() + 1,
+                );
+            }
+        }
+
+        offsets.push(offsets[row_index] + (end - start));
+        valid.append(true);
+    }
+
+    let data = mutable.freeze();
+
+    Ok(Arc::new(GenericListArray::<O>::try_new(
+        Arc::new(Field::new("item", list_array.value_type(), true)),
+        OffsetBuffer::<O>::new(offsets.into()),
+        arrow_array::make_array(data),
+        Some(NullBuffer::new(valid.finish())),
+    )?))
+}
+
+pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 3 {
+        return exec_err!("array_replace expects three arguments");
+    }
+
+    // replace at most one occurrence for each element
+    let arr_n = vec![1; args[0].len()];
+    let array = &args[0];
+    match array.data_type() {
+        DataType::List(_) => {
+            let list_array = array.as_list::<i32>();
+            general_replace::<i32>(list_array, &args[1], &args[2], arr_n)
+        }
+        DataType::LargeList(_) => {
+            let list_array = array.as_list::<i64>();
+            general_replace::<i64>(list_array, &args[1], &args[2], arr_n)
+        }
+        array_type => exec_err!("array_replace does not support type '{array_type:?}'."),
+    }
+}
+
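+// (Editorial note:) the n-variant below uses caller-supplied per-row counts;
+// a count of 1 reproduces array_replace semantics, and i64::MAX (used by
+// array_replace_all) replaces every match in the row.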
exec_err!("array_replace_all expects three arguments"); + } + + // replace all occurrences (up to "i64::MAX") + let arr_n = vec![i64::MAX; args[0].len()]; + let array = &args[0]; + match array.data_type() { + DataType::List(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + DataType::LargeList(_) => { + let list_array = array.as_list::(); + general_replace::(list_array, &args[1], &args[2], arr_n) + } + array_type => { + exec_err!("array_replace_all does not support type '{array_type:?}'.") + } + } +} diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index ad613163c6af..9589cb05fe9b 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use arrow::{array::ArrayRef, datatypes::DataType}; + use arrow_array::{ Array, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, }; @@ -27,6 +28,7 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::Field; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { @@ -202,9 +204,9 @@ pub(crate) fn compare_element_to_list( let element_arr = Scalar::new(element_array_row); // use not_distinct so we can compare NULL if eq { - arrow::compute::kernels::cmp::not_distinct(&list_array_row, &element_arr)? + arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? } else { - arrow::compute::kernels::cmp::distinct(&list_array_row, &element_arr)? + arrow_ord::cmp::distinct(&list_array_row, &element_arr)? } } }; diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs deleted file mode 100644 index c3c0f4c82282..000000000000 --- a/datafusion/physical-expr/src/array_expressions.rs +++ /dev/null @@ -1,423 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Array expressions - -use std::sync::Arc; - -use arrow::array::*; -use arrow::buffer::OffsetBuffer; -use arrow::datatypes::{DataType, Field}; -use arrow_buffer::NullBuffer; - -use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; -use datafusion_common::utils::array_into_list_array; -use datafusion_common::{exec_err, plan_err, Result}; - -/// Computes a BooleanArray indicating equality or inequality between elements in a list array and a specified element array. -/// -/// # Arguments -/// -/// * `list_array_row` - A reference to a trait object implementing the Arrow `Array` trait. 
It represents the list array for which the equality or inequality will be compared. -/// -/// * `element_array` - A reference to a trait object implementing the Arrow `Array` trait. It represents the array with which each element in the `list_array_row` will be compared. -/// -/// * `row_index` - The index of the row in the `element_array` and `list_array` to use for the comparison. -/// -/// * `eq` - A boolean flag. If `true`, the function computes equality; if `false`, it computes inequality. -/// -/// # Returns -/// -/// Returns a `Result` representing the comparison results. The result may contain an error if there are issues with the computation. -/// -/// # Example -/// -/// ```text -/// compare_element_to_list( -/// [1, 2, 3], [1, 2, 3], 0, true => [true, false, false] -/// [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false] -/// -/// [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false] -/// [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false] -/// ) -/// ``` -fn compare_element_to_list( - list_array_row: &dyn Array, - element_array: &dyn Array, - row_index: usize, - eq: bool, -) -> Result { - if list_array_row.data_type() != element_array.data_type() { - return exec_err!( - "compare_element_to_list received incompatible types: '{:?}' and '{:?}'.", - list_array_row.data_type(), - element_array.data_type() - ); - } - - let indices = UInt32Array::from(vec![row_index as u32]); - let element_array_row = arrow::compute::take(element_array, &indices, None)?; - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let res = match element_array_row.data_type() { - // arrow_ord::cmp::eq does not support ListArray, so we need to compare it by loop - DataType::List(_) => { - // compare each element of the from array - let element_array_row_inner = as_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - DataType::LargeList(_) => { - // compare each element of the from array - let element_array_row_inner = - as_large_list_array(&element_array_row)?.value(0); - let list_array_row_inner = as_large_list_array(list_array_row)?; - - list_array_row_inner - .iter() - // compare element by element the current row of list_array - .map(|row| { - row.map(|row| { - if eq { - row.eq(&element_array_row_inner) - } else { - row.ne(&element_array_row_inner) - } - }) - }) - .collect::() - } - _ => { - let element_arr = Scalar::new(element_array_row); - // use not_distinct so we can compare NULL - if eq { - arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)? - } else { - arrow_ord::cmp::distinct(&list_array_row, &element_arr)? - } - } - }; - - Ok(res) -} - -/// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` or 'LargeListArray' depending on the offset size. 
-/// -/// # Example (non nested) -/// -/// Calling `array(col1, col2)` where col1 and col2 are non nested -/// would return a single new `ListArray`, where each row was a list -/// of 2 elements: -/// -/// ```text -/// ┌─────────┐ ┌─────────┐ ┌──────────────┐ -/// │ ┌─────┐ │ │ ┌─────┐ │ │ ┌──────────┐ │ -/// │ │ A │ │ │ │ X │ │ │ │ [A, X] │ │ -/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ -/// │ │NULL │ │ │ │ Y │ │──────────▶│ │[NULL, Y] │ │ -/// │ ├─────┤ │ │ ├─────┤ │ │ ├──────────┤ │ -/// │ │ C │ │ │ │ Z │ │ │ │ [C, Z] │ │ -/// │ └─────┘ │ │ └─────┘ │ │ └──────────┘ │ -/// └─────────┘ └─────────┘ └──────────────┘ -/// col1 col2 output -/// ``` -/// -/// # Example (nested) -/// -/// Calling `array(col1, col2)` where col1 and col2 are lists -/// would return a single new `ListArray`, where each row was a list -/// of the corresponding elements of col1 and col2. -/// -/// ``` text -/// ┌──────────────┐ ┌──────────────┐ ┌─────────────────────────────┐ -/// │ ┌──────────┐ │ │ ┌──────────┐ │ │ ┌────────────────────────┐ │ -/// │ │ [A, X] │ │ │ │ [] │ │ │ │ [[A, X], []] │ │ -/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────┤ │ -/// │ │[NULL, Y] │ │ │ │[Q, R, S] │ │───────▶│ │ [[NULL, Y], [Q, R, S]] │ │ -/// │ ├──────────┤ │ │ ├──────────┤ │ │ ├────────────────────────│ │ -/// │ │ [C, Z] │ │ │ │ NULL │ │ │ │ [[C, Z], NULL] │ │ -/// │ └──────────┘ │ │ └──────────┘ │ │ └────────────────────────┘ │ -/// └──────────────┘ └──────────────┘ └─────────────────────────────┘ -/// col1 col2 output -/// ``` -fn array_array( - args: &[ArrayRef], - data_type: DataType, -) -> Result { - // do not accept 0 arguments. - if args.is_empty() { - return plan_err!("Array requires at least one argument"); - } - - let mut data = vec![]; - let mut total_len = 0; - for arg in args { - let arg_data = if arg.as_any().is::() { - ArrayData::new_empty(&data_type) - } else { - arg.to_data() - }; - total_len += arg_data.len(); - data.push(arg_data); - } - - let mut offsets: Vec = Vec::with_capacity(total_len); - offsets.push(O::usize_as(0)); - - let capacity = Capacities::Array(total_len); - let data_ref = data.iter().collect::>(); - let mut mutable = MutableArrayData::with_capacities(data_ref, true, capacity); - - let num_rows = args[0].len(); - for row_idx in 0..num_rows { - for (arr_idx, arg) in args.iter().enumerate() { - if !arg.as_any().is::() - && !arg.is_null(row_idx) - && arg.is_valid(row_idx) - { - mutable.extend(arr_idx, row_idx, row_idx + 1); - } else { - mutable.extend_nulls(1); - } - } - offsets.push(O::usize_as(mutable.len())); - } - let data = mutable.freeze(); - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type, true)), - OffsetBuffer::new(offsets.into()), - arrow_array::make_array(data), - None, - )?)) -} - -/// `make_array` SQL function -pub fn make_array(arrays: &[ArrayRef]) -> Result { - let mut data_type = DataType::Null; - for arg in arrays { - let arg_data_type = arg.data_type(); - if !arg_data_type.equals_datatype(&DataType::Null) { - data_type = arg_data_type.clone(); - break; - } - } - - match data_type { - // Either an empty array or all nulls: - DataType::Null => { - let array = - new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); - Ok(Arc::new(array_into_list_array(array))) - } - DataType::LargeList(..) => array_array::(arrays, data_type), - _ => array_array::(arrays, data_type), - } -} - -/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences -/// of `from_array[i]`, `to_array[i]`. 
-/// -/// The type of each **element** in `list_array` must be the same as the type of -/// `from_array` and `to_array`. This function also handles nested arrays -/// ([`ListArray`] of [`ListArray`]s) -/// -/// For example, when called to replace a list array (where each element is a -/// list of int32s, the second and third argument are int32 arrays, and the -/// fourth argument is the number of occurrences to replace -/// -/// ```text -/// general_replace( -/// [1, 2, 3, 2], 2, 10, 1 ==> [1, 10, 3, 2] (only the first 2 is replaced) -/// [4, 5, 6, 5], 5, 20, 2 ==> [4, 20, 6, 20] (both 5s are replaced) -/// ) -/// ``` -fn general_replace( - list_array: &GenericListArray, - from_array: &ArrayRef, - to_array: &ArrayRef, - arr_n: Vec, -) -> Result { - // Build up the offsets for the final output array - let mut offsets: Vec = vec![O::usize_as(0)]; - let values = list_array.values(); - let original_data = values.to_data(); - let to_data = to_array.to_data(); - let capacity = Capacities::Array(original_data.len()); - - // First array is the original array, second array is the element to replace with. - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &to_data], - false, - capacity, - ); - - let mut valid = BooleanBufferBuilder::new(list_array.len()); - - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - if list_array.is_null(row_index) { - offsets.push(offsets[row_index]); - valid.append(false); - continue; - } - - let start = offset_window[0]; - let end = offset_window[1]; - - let list_array_row = list_array.value(row_index); - - // Compute all positions in list_row_array (that is itself an - // array) that are equal to `from_array_row` - let eq_array = - compare_element_to_list(&list_array_row, &from_array, row_index, true)?; - - let original_idx = O::usize_as(0); - let replace_idx = O::usize_as(1); - let n = arr_n[row_index]; - let mut counter = 0; - - // All elements are false, no need to replace, just copy original data - if eq_array.false_count() == eq_array.len() { - mutable.extend( - original_idx.to_usize().unwrap(), - start.to_usize().unwrap(), - end.to_usize().unwrap(), - ); - offsets.push(offsets[row_index] + (end - start)); - valid.append(true); - continue; - } - - for (i, to_replace) in eq_array.iter().enumerate() { - let i = O::usize_as(i); - if let Some(true) = to_replace { - mutable.extend(replace_idx.to_usize().unwrap(), row_index, row_index + 1); - counter += 1; - if counter == n { - // copy original data for any matches past n - mutable.extend( - original_idx.to_usize().unwrap(), - (start + i).to_usize().unwrap() + 1, - end.to_usize().unwrap(), - ); - break; - } - } else { - // copy original data for false / null matches - mutable.extend( - original_idx.to_usize().unwrap(), - (start + i).to_usize().unwrap(), - (start + i).to_usize().unwrap() + 1, - ); - } - } - - offsets.push(offsets[row_index] + (end - start)); - valid.append(true); - } - - let data = mutable.freeze(); - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", list_array.value_type(), true)), - OffsetBuffer::::new(offsets.into()), - arrow_array::make_array(data), - Some(NullBuffer::new(valid.finish())), - )?)) -} - -pub fn array_replace(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace expects three arguments"); - } - - // replace at most one occurence for each element - let arr_n = vec![1; args[0].len()]; - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let 
list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => exec_err!("array_replace does not support type '{array_type:?}'."), - } -} - -pub fn array_replace_n(args: &[ArrayRef]) -> Result { - if args.len() != 4 { - return exec_err!("array_replace_n expects four arguments"); - } - - // replace the specified number of occurences - let arr_n = as_int64_array(&args[3])?.values().to_vec(); - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => { - exec_err!("array_replace_n does not support type '{array_type:?}'.") - } - } -} - -pub fn array_replace_all(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!("array_replace_all expects three arguments"); - } - - // replace all occurrences (up to "i64::MAX") - let arr_n = vec![i64::MAX; args[0].len()]; - let array = &args[0]; - match array.data_type() { - DataType::List(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - DataType::LargeList(_) => { - let list_array = array.as_list::(); - general_replace::(list_array, &args[1], &args[2], arr_n) - } - array_type => { - exec_err!("array_replace_all does not support type '{array_type:?}'.") - } - } -} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 994c17309ec0..c6c185e002f0 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -32,8 +32,8 @@ use crate::sort_properties::SortProperties; use crate::{ - array_expressions, conditional_expressions, math_expressions, string_expressions, - PhysicalExpr, ScalarFunctionExpr, + conditional_expressions, math_expressions, string_expressions, PhysicalExpr, + ScalarFunctionExpr, }; use arrow::{ array::ArrayRef, @@ -253,18 +253,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Cot => { Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } - - // array functions - BuiltinScalarFunction::ArrayReplace => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace)(args) - }), - BuiltinScalarFunction::ArrayReplaceN => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace_n)(args) - }), - BuiltinScalarFunction::ArrayReplaceAll => Arc::new(|args| { - make_scalar_function_inner(array_expressions::array_replace_all)(args) - }), - // string functions BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index e8b80ee4e1e6..1791a6ed60b2 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -17,7 +17,6 @@ pub mod aggregate; pub mod analysis; -pub mod array_expressions; pub mod binary_map; pub mod conditional_expressions; pub mod equivalence; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 597094758584..6879f70cd05c 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -637,7 +637,7 @@ enum ScalarFunction { // 93 
was ArrayPositions // 94 was ArrayPrepend // 95 was ArrayRemove - ArrayReplace = 96; + // 96 was ArrayReplace // 97 was ArrayToString // 98 was Cardinality // 99 was ArrayElement @@ -647,9 +647,9 @@ enum ScalarFunction { // 105 was ArrayHasAny // 106 was ArrayHasAll // 107 was ArrayRemoveN - ArrayReplaceN = 108; + // 108 was ArrayReplaceN // 109 was ArrayRemoveAll - ArrayReplaceAll = 110; + // 110 was ArrayReplaceAll Nanvl = 111; // 112 was Flatten // 113 was IsNan diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index cb9633338e8f..75c135fd01b4 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22951,10 +22951,7 @@ impl serde::Serialize for ScalarFunction { Self::Factorial => "Factorial", Self::Lcm => "Lcm", Self::Gcd => "Gcd", - Self::ArrayReplace => "ArrayReplace", Self::Cot => "Cot", - Self::ArrayReplaceN => "ArrayReplaceN", - Self::ArrayReplaceAll => "ArrayReplaceAll", Self::Nanvl => "Nanvl", Self::Iszero => "Iszero", Self::OverLay => "OverLay", @@ -23032,10 +23029,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Factorial", "Lcm", "Gcd", - "ArrayReplace", "Cot", - "ArrayReplaceN", - "ArrayReplaceAll", "Nanvl", "Iszero", "OverLay", @@ -23142,10 +23136,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Factorial" => Ok(ScalarFunction::Factorial), "Lcm" => Ok(ScalarFunction::Lcm), "Gcd" => Ok(ScalarFunction::Gcd), - "ArrayReplace" => Ok(ScalarFunction::ArrayReplace), "Cot" => Ok(ScalarFunction::Cot), - "ArrayReplaceN" => Ok(ScalarFunction::ArrayReplaceN), - "ArrayReplaceAll" => Ok(ScalarFunction::ArrayReplaceAll), "Nanvl" => Ok(ScalarFunction::Nanvl), "Iszero" => Ok(ScalarFunction::Iszero), "OverLay" => Ok(ScalarFunction::OverLay), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index f5ef6c1f74f0..c9cc4a9b073b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2910,7 +2910,7 @@ pub enum ScalarFunction { /// 93 was ArrayPositions /// 94 was ArrayPrepend /// 95 was ArrayRemove - ArrayReplace = 96, + /// 96 was ArrayReplace /// 97 was ArrayToString /// 98 was Cardinality /// 99 was ArrayElement @@ -2920,9 +2920,9 @@ pub enum ScalarFunction { /// 105 was ArrayHasAny /// 106 was ArrayHasAll /// 107 was ArrayRemoveN - ArrayReplaceN = 108, + /// 108 was ArrayReplaceN /// 109 was ArrayRemoveAll - ArrayReplaceAll = 110, + /// 110 was ArrayReplaceAll Nanvl = 111, /// 112 was Flatten /// 113 was IsNan @@ -3019,10 +3019,7 @@ impl ScalarFunction { ScalarFunction::Factorial => "Factorial", ScalarFunction::Lcm => "Lcm", ScalarFunction::Gcd => "Gcd", - ScalarFunction::ArrayReplace => "ArrayReplace", ScalarFunction::Cot => "Cot", - ScalarFunction::ArrayReplaceN => "ArrayReplaceN", - ScalarFunction::ArrayReplaceAll => "ArrayReplaceAll", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Iszero => "Iszero", ScalarFunction::OverLay => "OverLay", @@ -3094,10 +3091,7 @@ impl ScalarFunction { "Factorial" => Some(Self::Factorial), "Lcm" => Some(Self::Lcm), "Gcd" => Some(Self::Gcd), - "ArrayReplace" => Some(Self::ArrayReplace), "Cot" => Some(Self::Cot), - "ArrayReplaceN" => Some(Self::ArrayReplaceN), - "ArrayReplaceAll" => Some(Self::ArrayReplaceAll), "Nanvl" => Some(Self::Nanvl), "Iszero" => Some(Self::Iszero), "OverLay" => Some(Self::OverLay), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3822b74bc18c..06aab16edd57 
100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,9 +48,9 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, array_replace, array_replace_all, array_replace_n, ascii, asinh, atan, atan2, - atanh, bit_length, btrim, cbrt, ceil, character_length, chr, coalesce, concat_expr, - concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, + acosh, ascii, asinh, atan, atan2, atanh, bit_length, btrim, cbrt, ceil, + character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, + degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -466,9 +466,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Trim => Self::Trim, ScalarFunction::Ltrim => Self::Ltrim, ScalarFunction::Rtrim => Self::Rtrim, - ScalarFunction::ArrayReplace => Self::ArrayReplace, - ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, - ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, ScalarFunction::Ascii => Self::Ascii, @@ -1362,22 +1359,6 @@ pub fn parse_expr( ScalarFunction::Acosh => { Ok(acosh(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::ArrayReplace => Ok(array_replace( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::ArrayReplaceN => Ok(array_replace_n( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - parse_expr(&args[3], registry, codec)?, - )), - ScalarFunction::ArrayReplaceAll => Ok(array_replace_all( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)), diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 7a17d2a2b405..478f7c779552 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1453,9 +1453,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Trim => Self::Trim, BuiltinScalarFunction::Ltrim => Self::Ltrim, BuiltinScalarFunction::Rtrim => Self::Rtrim, - BuiltinScalarFunction::ArrayReplace => Self::ArrayReplace, - BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN, - BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, BuiltinScalarFunction::Ascii => Self::Ascii, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 479f80fbdddf..93de560dbee5 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -605,6 +605,14 @@ async fn roundtrip_expr_api() -> Result<()> { make_array(vec![lit(3), lit(3), lit(2), lit(3), lit(1)]), lit(3), ), + 
array_replace(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), + array_replace_n( + make_array(vec![lit(1), lit(2), lit(3)]), + lit(2), + lit(4), + lit(1), + ), + array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), ]; // ensure expressions created with the expr api can be round tripped From 2c7cf5f41af8f74feec8c8581732ef0fb5003008 Mon Sep 17 00:00:00 2001 From: InventiveCoder <163831412+InventiveCoder@users.noreply.github.com> Date: Tue, 19 Mar 2024 02:21:46 +0800 Subject: [PATCH 004/117] chore: remove repetitive words (#9673) Signed-off-by: InventiveCoder --- datafusion/common/src/stats.rs | 2 +- datafusion/core/src/datasource/file_format/mod.rs | 2 +- .../src/physical_optimizer/enforce_distribution.rs | 2 +- .../src/physical_optimizer/output_requirements.rs | 2 +- .../src/physical_optimizer/projection_pushdown.rs | 14 +++++++------- datafusion/physical-expr/src/binary_map.rs | 2 +- .../physical-plan/src/aggregates/order/mod.rs | 2 +- datafusion/sql/src/expr/arrow_cast.rs | 2 +- .../sqllogictest/test_files/create_function.slt | 2 +- datafusion/sqllogictest/test_files/limit.slt | 2 +- dev/changelog/13.0.0.md | 2 +- dev/changelog/7.0.0.md | 2 +- docs/source/contributor-guide/communication.md | 2 +- docs/source/library-user-guide/adding-udfs.md | 2 +- 14 files changed, 20 insertions(+), 20 deletions(-) diff --git a/datafusion/common/src/stats.rs b/datafusion/common/src/stats.rs index a10e05a55c64..6cefef8d0eb5 100644 --- a/datafusion/common/src/stats.rs +++ b/datafusion/common/src/stats.rs @@ -221,7 +221,7 @@ pub struct Statistics { /// Total bytes of the table rows. pub total_byte_size: Precision, /// Statistics on a column level. It contains a [`ColumnStatistics`] for - /// each field in the schema of the the table to which the [`Statistics`] refer. + /// each field in the schema of the table to which the [`Statistics`] refer. pub column_statistics: Vec, } diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 72dc289d4b64..5ee0f7186703 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -49,7 +49,7 @@ use object_store::{ObjectMeta, ObjectStore}; /// This trait abstracts all the file format specific implementations /// from the [`TableProvider`]. This helps code re-utilization across -/// providers that support the the same file formats. +/// providers that support the same file formats. 
 ///
 /// [`TableProvider`]: crate::datasource::provider::TableProvider
 #[async_trait]
diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
index 54fe6e8406fd..0740a8d2cdbc 100644
--- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
@@ -392,7 +392,7 @@ fn adjust_input_keys_ordering(
             let expr = proj.expr();
             // For Projection, we need to transform the requirements to the columns before the Projection
             // And then to push down the requirements
-            // Construct a mapping from new name to the the orginal Column
+            // Construct a mapping from new name to the original Column
             let new_required = map_columns_before_projection(&requirements.data, expr);
             if new_required.len() == requirements.data.len() {
                 requirements.children[0].data = new_required;
diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs
index bd71b3e8ed80..bf010a5e39d8 100644
--- a/datafusion/core/src/physical_optimizer/output_requirements.rs
+++ b/datafusion/core/src/physical_optimizer/output_requirements.rs
@@ -216,7 +216,7 @@ impl PhysicalOptimizerRule for OutputRequirements {
     }
 }
 
-/// This functions adds ancillary `OutputRequirementExec` to the the physical plan, so that
+/// This function adds an ancillary `OutputRequirementExec` to the physical plan, so that
 /// global requirements are not lost during optimization.
 fn require_top_ordering(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
     let (new_plan, is_changed) = require_top_ordering_helper(plan)?;
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index e8f3bf01ecaa..ab5611597472 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -322,7 +322,7 @@ fn try_swapping_with_output_req(
     projection: &ProjectionExec,
     output_req: &OutputRequirementExec,
 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-    // If the projection does not narrow the the schema, we should not try to push it down:
+    // If the projection does not narrow the schema, we should not try to push it down:
     if projection.expr().len() >= projection.input().schema().fields().len() {
         return Ok(None);
     }
@@ -372,7 +372,7 @@ fn try_swapping_with_coalesce_partitions(
     projection: &ProjectionExec,
 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-    // If the projection does not narrow the the schema, we should not try to push it down:
+    // If the projection does not narrow the schema, we should not try to push it down:
     if projection.expr().len() >= projection.input().schema().fields().len() {
         return Ok(None);
     }
@@ -387,7 +387,7 @@ fn try_swapping_with_filter(
     projection: &ProjectionExec,
     filter: &FilterExec,
 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-    // If the projection does not narrow the the schema, we should not try to push it down:
+    // If the projection does not narrow the schema, we should not try to push it down:
     if projection.expr().len() >= projection.input().schema().fields().len() {
         return Ok(None);
     }
@@ -412,7 +412,7 @@ fn try_swapping_with_repartition(
     projection: &ProjectionExec,
     repartition: &RepartitionExec,
 ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-    // If the projection does not narrow the the schema, we should not try to push it down.
+    // If the projection does not narrow the schema, we should not try to push it down.
if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -454,7 +454,7 @@ fn try_swapping_with_sort( projection: &ProjectionExec, sort: &SortExec, ) -> Result>> { - // If the projection does not narrow the the schema, we should not try to push it down. + // If the projection does not narrow the schema, we should not try to push it down. if projection.expr().len() >= projection.input().schema().fields().len() { return Ok(None); } @@ -1082,7 +1082,7 @@ fn join_table_borders( (far_right_left_col_ind, far_left_right_col_ind) } -/// Tries to update the equi-join `Column`'s of a join as if the the input of +/// Tries to update the equi-join `Column`'s of a join as if the input of /// the join was replaced by a projection. fn update_join_on( proj_left_exprs: &[(Column, String)], @@ -1152,7 +1152,7 @@ fn new_columns_for_join_on( (new_columns.len() == hash_join_on.len()).then_some(new_columns) } -/// Tries to update the column indices of a [`JoinFilter`] as if the the input of +/// Tries to update the column indices of a [`JoinFilter`] as if the input of /// the join was replaced by a projection. fn update_join_filter( projection_left_exprs: &[(Column, String)], diff --git a/datafusion/physical-expr/src/binary_map.rs b/datafusion/physical-expr/src/binary_map.rs index b661f0a74148..6c3a452a8611 100644 --- a/datafusion/physical-expr/src/binary_map.rs +++ b/datafusion/physical-expr/src/binary_map.rs @@ -280,7 +280,7 @@ where /// # Returns /// /// The payload value for the entry, either the existing value or - /// the the newly inserted value + /// the newly inserted value /// /// # Safety: /// diff --git a/datafusion/physical-plan/src/aggregates/order/mod.rs b/datafusion/physical-plan/src/aggregates/order/mod.rs index 4f1914b12c96..556103e1e222 100644 --- a/datafusion/physical-plan/src/aggregates/order/mod.rs +++ b/datafusion/physical-plan/src/aggregates/order/mod.rs @@ -40,7 +40,7 @@ pub(crate) enum GroupOrdering { } impl GroupOrdering { - /// Create a `GroupOrdering` for the the specified ordering + /// Create a `GroupOrdering` for the specified ordering pub fn try_new( input_schema: &Schema, mode: &InputOrderMode, diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index 9a0d61f41c01..a75cdf9e3c6b 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -76,7 +76,7 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result /// Parses `str` into a `DataType`. /// -/// `parse_data_type` is the the reverse of [`DataType`]'s `Display` +/// `parse_data_type` is the reverse of [`DataType`]'s `Display` /// impl, and maintains the invariant that /// `parse_data_type(data_type.to_string()) == data_type` /// diff --git a/datafusion/sqllogictest/test_files/create_function.slt b/datafusion/sqllogictest/test_files/create_function.slt index baa40ac64afc..4f0c53c36ca1 100644 --- a/datafusion/sqllogictest/test_files/create_function.slt +++ b/datafusion/sqllogictest/test_files/create_function.slt @@ -47,7 +47,7 @@ select abs(-1); statement ok DROP FUNCTION abs; -# now the the query errors +# now the query errors query error Invalid function 'abs'. 
select abs(-1); diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt index 92093ba13eba..0d98c41d0028 100644 --- a/datafusion/sqllogictest/test_files/limit.slt +++ b/datafusion/sqllogictest/test_files/limit.slt @@ -320,7 +320,7 @@ SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11); 0 # The aggregate does not need to be computed because the input statistics are exact and -# the number of rows is less than or equal to the the "fetch+skip" value (LIMIT+OFFSET). +# the number of rows is less than or equal to the "fetch+skip" value (LIMIT+OFFSET). query TT EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8); ---- diff --git a/dev/changelog/13.0.0.md b/dev/changelog/13.0.0.md index 0f35903e2600..14b42a052ef9 100644 --- a/dev/changelog/13.0.0.md +++ b/dev/changelog/13.0.0.md @@ -87,7 +87,7 @@ - Optimizer rule 'projection_push_down' failed due to unexpected error: Error during planning: Aggregate schema has wrong number of fields. Expected 3 got 8 [\#3704](https://github.com/apache/arrow-datafusion/issues/3704) - Optimizer regressions in `unwrap_cast_in_comparison` [\#3690](https://github.com/apache/arrow-datafusion/issues/3690) - Internal error when evaluating a predicate = "The type of Dictionary\(Int16, Utf8\) = Int64 of binary physical should be same" [\#3685](https://github.com/apache/arrow-datafusion/issues/3685) -- Specialized regexp_replace should early-abort when the the input arrays are empty [\#3647](https://github.com/apache/arrow-datafusion/issues/3647) +- Specialized regexp_replace should early-abort when the input arrays are empty [\#3647](https://github.com/apache/arrow-datafusion/issues/3647) - Internal error: Failed to coerce types Decimal128\(10, 2\) and Boolean in BETWEEN expression [\#3646](https://github.com/apache/arrow-datafusion/issues/3646) - Internal error: Failed to coerce types Decimal128\(10, 2\) and Boolean in BETWEEN expression [\#3645](https://github.com/apache/arrow-datafusion/issues/3645) - Type coercion error: The type of Boolean AND Decimal128\(10, 2\) of binary physical should be same [\#3644](https://github.com/apache/arrow-datafusion/issues/3644) diff --git a/dev/changelog/7.0.0.md b/dev/changelog/7.0.0.md index e63c2a4455c9..4d2606d7bfbe 100644 --- a/dev/changelog/7.0.0.md +++ b/dev/changelog/7.0.0.md @@ -56,7 +56,7 @@ - Keep all datafusion's packages up to date with Dependabot [\#1472](https://github.com/apache/arrow-datafusion/issues/1472) - ExecutionContext support init ExecutionContextState with `new(state: Arc>)` method [\#1439](https://github.com/apache/arrow-datafusion/issues/1439) - support the decimal scalar value [\#1393](https://github.com/apache/arrow-datafusion/issues/1393) -- Documentation for using scalar functions with the the DataFrame API [\#1364](https://github.com/apache/arrow-datafusion/issues/1364) +- Documentation for using scalar functions with the DataFrame API [\#1364](https://github.com/apache/arrow-datafusion/issues/1364) - Support `boolean == boolean` and `boolean != boolean` operators [\#1159](https://github.com/apache/arrow-datafusion/issues/1159) - Support DataType::Decimal\(15, 2\) in TPC-H benchmark [\#174](https://github.com/apache/arrow-datafusion/issues/174) - Make `MemoryStream` public [\#150](https://github.com/apache/arrow-datafusion/issues/150) diff --git a/docs/source/contributor-guide/communication.md b/docs/source/contributor-guide/communication.md index 8678aa534baf..7b5e71bc3a1c 100644 --- a/docs/source/contributor-guide/communication.md +++ 
b/docs/source/contributor-guide/communication.md @@ -44,7 +44,7 @@ request one in the `Arrow Rust` channel of the [Arrow Rust Discord server](https ## Mailing list We also use arrow.apache.org's `dev@` mailing list for release coordination and occasional design discussions. Other -than the the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. +than the release process, most DataFusion mailing list traffic will link to a GitHub issue or PR for discussion. ([subscribe](mailto:dev-subscribe@arrow.apache.org), [unsubscribe](mailto:dev-unsubscribe@arrow.apache.org), [archives](https://lists.apache.org/list.html?dev@arrow.apache.org)). diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index f433e026e0a2..ad210724103d 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -204,7 +204,7 @@ let df = ctx.sql(&sql).await.unwrap(); ## Adding a Window UDF -Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the the proximal rows is helpful, but adds some complexity to the implementation. +Scalar UDFs are functions that take a row of data and return a single value. Window UDFs are similar, but they also have access to the rows around them. Access to the proximal rows is helpful, but adds some complexity to the implementation. For example, we will declare a user defined window function that computes a moving average. From c0a21b28c7dadd7d3e1db1fbe2433735a2b65d5a Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Mon, 18 Mar 2024 14:23:16 -0400 Subject: [PATCH 005/117] Update example-usage.md to remove reference to simd and rust nightly. (#9677) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. --- docs/source/user-guide/example-usage.md | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 1c5c8f49a16a..c5eefbdaf156 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -240,17 +240,11 @@ async fn main() -> datafusion::error::Result<()> { } ``` -Finally, in order to build with the `simd` optimization `cargo nightly` is required. - -```shell -rustup toolchain install nightly -``` - Based on the instruction set architecture you are building on you will want to configure the `target-cpu` as well, ideally with `native` or at least `avx2`. 
 ```shell
-RUSTFLAGS='-C target-cpu=native' cargo +nightly run --release
+RUSTFLAGS='-C target-cpu=native' cargo run --release
 ```
 
 ## Enable backtraces

From 35ff7a66c0e2579489e1408bb426fe4444f6ce2e Mon Sep 17 00:00:00 2001
From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com>
Date: Mon, 18 Mar 2024 21:23:31 +0300
Subject: [PATCH 006/117] Minor changes (#9674)

---
 .../physical-expr/src/window/nth_value.rs     | 42 ++++++++++---------
 1 file changed, 22 insertions(+), 20 deletions(-)

diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs
index e913f39333f9..9de71c2d604c 100644
--- a/datafusion/physical-expr/src/window/nth_value.rs
+++ b/datafusion/physical-expr/src/window/nth_value.rs
@@ -225,31 +225,38 @@ impl PartitionEvaluator for NthValueEvaluator {
         }
 
         // Extract valid indices if ignoring nulls.
-        let (slice, valid_indices) = if self.ignore_nulls {
+        let valid_indices = if self.ignore_nulls {
+            // Calculate valid indices, inside the window frame boundaries
             let slice = arr.slice(range.start, n_range);
-            let valid_indices =
-                slice.nulls().unwrap().valid_indices().collect::<Vec<_>>();
+            let valid_indices = slice
+                .nulls()
+                .map(|nulls| {
+                    nulls
+                        .valid_indices()
+                        // Add offset `range.start` to valid indices, to point correct index in the original arr.
+                        .map(|idx| idx + range.start)
+                        .collect::<Vec<_>>()
+                })
+                .unwrap_or_default();
             if valid_indices.is_empty() {
                 return ScalarValue::try_from(arr.data_type());
             }
-            (Some(slice), Some(valid_indices))
+            Some(valid_indices)
         } else {
-            (None, None)
+            None
        };
         match self.state.kind {
             NthValueKind::First => {
-                if let Some(slice) = &slice {
-                    let valid_indices = valid_indices.unwrap();
-                    ScalarValue::try_from_array(slice, valid_indices[0])
+                if let Some(valid_indices) = &valid_indices {
+                    ScalarValue::try_from_array(arr, valid_indices[0])
                 } else {
                     ScalarValue::try_from_array(arr, range.start)
                 }
             }
             NthValueKind::Last => {
-                if let Some(slice) = &slice {
-                    let valid_indices = valid_indices.unwrap();
+                if let Some(valid_indices) = &valid_indices {
                     ScalarValue::try_from_array(
-                        slice,
+                        arr,
                         valid_indices[valid_indices.len() - 1],
                     )
                 } else {

From 4687a2f793019deb199f1759f5171730a6434189 Mon Sep 17 00:00:00 2001
From: comphead
Date: Mon, 18 Mar 2024 13:59:00 -0700
Subject: [PATCH 007/117] minor: Remove deprecated methods
 (#9627)

* minor: remove deprecated code

* Remove deprecated test

* docs

---------

Co-authored-by: Andrew Lamb
---
 datafusion/common/src/dfschema.rs             | 51 +++-----------
 datafusion/core/src/dataframe/mod.rs          | 10 ---
 datafusion/core/src/datasource/listing/url.rs | 33 ---------
 datafusion/core/src/execution/context/mod.rs  | 43 ------------
 datafusion/execution/src/config.rs            | 16 +----
 datafusion/execution/src/task.rs              | 45 ++----------
 datafusion/expr/src/aggregate_function.rs     | 22 ------
 datafusion/expr/src/expr.rs                   | 28 --------
 datafusion/expr/src/expr_rewriter/mod.rs      | 43 ------------
 datafusion/expr/src/function.rs               | 25 +------
 datafusion/expr/src/logical_plan/plan.rs      | 70 -------------------
 datafusion/physical-plan/src/common.rs        |  6 --
 datafusion/physical-plan/src/sorts/sort.rs    | 27 -------
 13 files changed, 18 insertions(+), 401 deletions(-)

diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs
index 2642032c9a04..597507a044a2 100644
--- a/datafusion/common/src/dfschema.rs
+++ b/datafusion/common/src/dfschema.rs
@@ -97,10 +97,11 @@ pub type DFSchemaRef = Arc<DFSchema>;
 /// ```rust
 /// use datafusion_common::{DFSchema, DFField};
 /// use arrow_schema::Schema;
+/// use std::collections::HashMap;
 ///
-/// let df_schema = DFSchema::new(vec![
+/// let df_schema = DFSchema::new_with_metadata(vec![
 ///    DFField::new_unqualified("c1", arrow::datatypes::DataType::Int32, false),
-/// ]).unwrap();
+/// ], HashMap::new()).unwrap();
 /// let schema = Schema::from(df_schema);
 /// assert_eq!(schema.fields().len(), 1);
 /// ```
@@ -124,12 +125,6 @@ impl DFSchema {
         }
     }
 
-    #[deprecated(since = "7.0.0", note = "please use `new_with_metadata` instead")]
-    /// Create a new `DFSchema`
-    pub fn new(fields: Vec<DFField>) -> Result<Self> {
-        Self::new_with_metadata(fields, HashMap::new())
-    }
-
     /// Create a new `DFSchema`
     pub fn new_with_metadata(
         fields: Vec<DFField>,
@@ -251,32 +246,6 @@ impl DFSchema {
         &self.fields[i]
     }
 
-    #[deprecated(since = "8.0.0", note = "please use `index_of_column_by_name` instead")]
-    /// Find the index of the column with the given unqualified name
-    pub fn index_of(&self, name: &str) -> Result<usize> {
-        for i in 0..self.fields.len() {
-            if self.fields[i].name() == name {
-                return Ok(i);
-            } else {
-                // Now that `index_of` is deprecated an error is thrown if
-                // a fully qualified field name is provided.
-                match &self.fields[i].qualifier {
-                    Some(qualifier) => {
-                        if (qualifier.to_string() + "." + self.fields[i].name()) == name {
-                            return _plan_err!(
-                                "Fully qualified field name '{name}' was supplied to `index_of` \
-                                which is deprecated.
Please use `index_of_column_by_name` instead" - ); - } - } - None => (), - } - } - } - - Err(unqualified_field_not_found(name, self)) - } - pub fn index_of_column_by_name( &self, qualifier: Option<&TableReference>, @@ -1146,13 +1115,10 @@ mod tests { Ok(()) } - #[allow(deprecated)] #[test] fn helpful_error_messages() -> Result<()> { let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?; let expected_help = "Valid fields are t1.c0, t1.c1."; - // Pertinent message parts - let expected_err_msg = "Fully qualified field name 't1.c0'"; assert_contains!( schema .field_with_qualified_name(&TableReference::bare("x"), "y") @@ -1167,11 +1133,12 @@ mod tests { .to_string(), expected_help ); - assert_contains!(schema.index_of("y").unwrap_err().to_string(), expected_help); - assert_contains!( - schema.index_of("t1.c0").unwrap_err().to_string(), - expected_err_msg - ); + assert!(schema.index_of_column_by_name(None, "y").unwrap().is_none()); + assert!(schema + .index_of_column_by_name(None, "t1.c0") + .unwrap() + .is_none()); + Ok(()) } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 5f192b83fdd9..25830401571d 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1001,16 +1001,6 @@ impl DataFrame { Arc::new(DataFrameTableProvider { plan: self.plan }) } - /// Return the optimized logical plan represented by this DataFrame. - /// - /// Note: This method should not be used outside testing, as it loses the snapshot - /// of the [`SessionState`] attached to this [`DataFrame`] and consequently subsequent - /// operations may take place against a different state - #[deprecated(since = "23.0.0", note = "Use DataFrame::into_optimized_plan")] - pub fn to_logical_plan(self) -> Result { - self.into_optimized_plan() - } - /// Return a DataFrame with the explanation of its plan so far. /// /// if `analyze` is specified, runs the plan and reports metrics diff --git a/datafusion/core/src/datasource/listing/url.rs b/datafusion/core/src/datasource/listing/url.rs index d9149bcc20e0..eb95dc7b1d24 100644 --- a/datafusion/core/src/datasource/listing/url.rs +++ b/datafusion/core/src/datasource/listing/url.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-use std::fs; - use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::context::SessionState; use datafusion_common::{DataFusionError, Result}; @@ -117,37 +115,6 @@ impl ListingTableUrl { } } - /// Get object store for specified input_url - /// if input_url is actually not a url, we assume it is a local file path - /// if we have a local path, create it if not exists so ListingTableUrl::parse works - #[deprecated(note = "Use parse")] - pub fn parse_create_local_if_not_exists( - s: impl AsRef, - is_directory: bool, - ) -> Result { - let s = s.as_ref(); - let is_valid_url = Url::parse(s).is_ok(); - - match is_valid_url { - true => ListingTableUrl::parse(s), - false => { - let path = std::path::PathBuf::from(s); - if !path.exists() { - if is_directory { - fs::create_dir_all(path)?; - } else { - // ensure parent directory exists - if let Some(parent) = path.parent() { - fs::create_dir_all(parent)?; - } - fs::File::create(path)?; - } - } - ListingTableUrl::parse(s) - } - } - } - /// Creates a new [`ListingTableUrl`] interpreting `s` as a filesystem path #[cfg(not(target_arch = "wasm32"))] fn parse_path(s: &str) -> Result { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 32c1c60ec564..1ac7da465216 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1181,49 +1181,6 @@ impl SessionContext { } } - /// Returns the set of available tables in the default catalog and - /// schema. - /// - /// Use [`table`] to get a specific table. - /// - /// [`table`]: SessionContext::table - #[deprecated( - since = "23.0.0", - note = "Please use the catalog provider interface (`SessionContext::catalog`) to examine available catalogs, schemas, and tables" - )] - pub fn tables(&self) -> Result> { - Ok(self - .state - .read() - // a bare reference will always resolve to the default catalog and schema - .schema_for_ref(TableReference::Bare { table: "".into() })? - .table_names() - .iter() - .cloned() - .collect()) - } - - /// Optimizes the logical plan by applying optimizer rules. - #[deprecated( - since = "23.0.0", - note = "Use SessionState::optimize to ensure a consistent state for planning and execution" - )] - pub fn optimize(&self, plan: &LogicalPlan) -> Result { - self.state.read().optimize(plan) - } - - /// Creates a physical plan from a logical plan. - #[deprecated( - since = "23.0.0", - note = "Use SessionState::create_physical_plan or DataFrame::create_physical_plan to ensure a consistent state for planning and execution" - )] - pub async fn create_physical_plan( - &self, - logical_plan: &LogicalPlan, - ) -> Result> { - self.state().create_physical_plan(logical_plan).await - } - /// Get a new TaskContext to run in this session pub fn task_ctx(&self) -> Arc { Arc::new(TaskContext::from(self)) diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 312aef953e9c..360bac71c510 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -434,9 +434,9 @@ impl SessionConfig { /// converted to strings. /// /// Note that this method will eventually be deprecated and - /// replaced by [`config_options`]. + /// replaced by [`options`]. 
/// - /// [`config_options`]: Self::config_options + /// [`options`]: Self::options pub fn to_props(&self) -> HashMap { let mut map = HashMap::new(); // copy configs from config_options @@ -447,18 +447,6 @@ impl SessionConfig { map } - /// Return a handle to the configuration options. - #[deprecated(since = "21.0.0", note = "use options() instead")] - pub fn config_options(&self) -> &ConfigOptions { - &self.options - } - - /// Return a mutable handle to the configuration options. - #[deprecated(since = "21.0.0", note = "use options_mut() instead")] - pub fn config_options_mut(&mut self) -> &mut ConfigOptions { - &mut self.options - } - /// Add extensions. /// /// Extensions can be used to attach extra data to the session config -- e.g. tracing information or caches. diff --git a/datafusion/execution/src/task.rs b/datafusion/execution/src/task.rs index cae410655d10..4216ce95f35e 100644 --- a/datafusion/execution/src/task.rs +++ b/datafusion/execution/src/task.rs @@ -20,10 +20,7 @@ use std::{ sync::Arc, }; -use datafusion_common::{ - config::{ConfigOptions, Extensions}, - plan_datafusion_err, DataFusionError, Result, -}; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF}; use crate::{ @@ -102,39 +99,6 @@ impl TaskContext { } } - /// Create a new task context instance, by first copying all - /// name/value pairs from `task_props` into a `SessionConfig`. - #[deprecated( - since = "21.0.0", - note = "Construct SessionConfig and call TaskContext::new() instead" - )] - pub fn try_new( - task_id: String, - session_id: String, - task_props: HashMap, - scalar_functions: HashMap>, - aggregate_functions: HashMap>, - runtime: Arc, - extensions: Extensions, - ) -> Result { - let mut config = ConfigOptions::new().with_extensions(extensions); - for (k, v) in task_props { - config.set(&k, &v)?; - } - let session_config = SessionConfig::from(config); - let window_functions = HashMap::new(); - - Ok(Self::new( - Some(task_id), - session_id, - session_config, - scalar_functions, - aggregate_functions, - window_functions, - runtime, - )) - } - /// Return the SessionConfig associated with this [TaskContext] pub fn session_config(&self) -> &SessionConfig { &self.session_config @@ -160,7 +124,7 @@ impl TaskContext { self.runtime.clone() } - /// Update the [`ConfigOptions`] + /// Update the [`SessionConfig`] pub fn with_session_config(mut self, session_config: SessionConfig) -> Self { self.session_config = session_config; self @@ -229,7 +193,10 @@ impl FunctionRegistry for TaskContext { #[cfg(test)] mod tests { use super::*; - use datafusion_common::{config::ConfigExtension, extensions_options}; + use datafusion_common::{ + config::{ConfigExtension, ConfigOptions, Extensions}, + extensions_options, + }; extensions_options! { struct TestExtension { diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 574de3e7082a..85f8c74f3737 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -218,19 +218,6 @@ impl FromStr for AggregateFunction { } } -/// Returns the datatype of the aggregate function. -/// This is used to get the returned data type for aggregate expr. 
-#[deprecated( - since = "27.0.0", - note = "please use `AggregateFunction::return_type` instead" -)] -pub fn return_type( - fun: &AggregateFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - impl AggregateFunction { /// Returns the datatype of the aggregate function given its argument types /// @@ -328,15 +315,6 @@ pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result { avg_sum_type(&coerced_data_types[0]) } -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `AggregateFunction::signature` instead" -)] -pub fn signature(fun: &AggregateFunction) -> Signature { - fun.signature() -} - impl AggregateFunction { /// the signatures supported by the function `fun`. pub fn signature(&self) -> Signature { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 0da05d96f67e..7ede4cd8ffc9 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -703,27 +703,6 @@ pub fn find_df_window_func(name: &str) -> Option { } } -/// Returns the datatype of the window function -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::return_type` instead" -)] -pub fn return_type( - fun: &WindowFunctionDefinition, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -/// the signatures supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `WindowFunction::signature` instead" -)] -pub fn signature(fun: &WindowFunctionDefinition) -> Signature { - fun.signature() -} - // Exists expression. #[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct Exists { @@ -887,13 +866,6 @@ impl Expr { create_name(self) } - /// Returns the name of this expression as it should appear in a schema. This name - /// will not include any CAST expressions. - #[deprecated(since = "14.0.0", note = "please use `display_name` instead")] - pub fn name(&self) -> Result { - self.display_name() - } - /// Returns a full and complete string representation of this expression. pub fn canonical_name(&self) -> String { format!("{self}") diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 357b1aed7dde..ea3ffadda391 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -74,31 +74,6 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { .data() } -/// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions -/// in the `expr` expression tree. 
-#[deprecated( - since = "20.0.0", - note = "use normalize_col_with_schemas_and_ambiguity_check instead" -)] -#[allow(deprecated)] -pub fn normalize_col_with_schemas( - expr: Expr, - schemas: &[&Arc], - using_columns: &[HashSet], -) -> Result { - expr.transform(&|expr| { - Ok({ - if let Expr::Column(c) = expr { - let col = c.normalize_with_schemas(schemas, using_columns)?; - Transformed::yes(Expr::Column(col)) - } else { - Transformed::no(expr) - } - }) - }) - .data() -} - /// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage pub fn normalize_col_with_schemas_and_ambiguity_check( expr: Expr, @@ -398,24 +373,6 @@ mod test { ); } - #[test] - #[allow(deprecated)] - fn normalize_cols_priority() { - let expr = col("a") + col("b"); - // Schemas with multiple matches for column a, first takes priority - let schema_a = make_schema_with_empty_metadata(vec![make_field("tableA", "a")]); - let schema_b = make_schema_with_empty_metadata(vec![make_field("tableB", "b")]); - let schema_a2 = make_schema_with_empty_metadata(vec![make_field("tableA2", "a")]); - let schemas = vec![schema_a2, schema_b, schema_a] - .into_iter() - .map(Arc::new) - .collect::>(); - let schemas = schemas.iter().collect::>(); - - let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap(); - assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b")); - } - #[test] fn normalize_cols_non_exist() { // test normalizing columns when the name doesn't exist diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index a3760eeb357d..adf4dd3fef20 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,9 +17,7 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{ - Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature, -}; +use crate::{Accumulator, ColumnarValue, PartitionEvaluator}; use arrow::datatypes::DataType; use datafusion_common::Result; use std::sync::Arc; @@ -53,24 +51,3 @@ pub type PartitionEvaluatorFactory = /// its state, given its return datatype. pub type StateTypeFunction = Arc Result>> + Send + Sync>; - -/// Returns the datatype of the scalar function -#[deprecated( - since = "27.0.0", - note = "please use `BuiltinScalarFunction::return_type` instead" -)] -pub fn return_type( - fun: &BuiltinScalarFunction, - input_expr_types: &[DataType], -) -> Result { - fun.return_type(input_expr_types) -} - -/// Return the [`Signature`] supported by the function `fun`. -#[deprecated( - since = "27.0.0", - note = "please use `BuiltinScalarFunction::signature` instead" -)] -pub fn signature(fun: &BuiltinScalarFunction) -> Signature { - fun.signature() -} diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c6f280acb255..08fe3380061f 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -217,56 +217,6 @@ impl LogicalPlan { } } - /// Get all meaningful schemas of a plan and its children plan. 
- #[deprecated(since = "20.0.0")] - pub fn all_schemas(&self) -> Vec<&DFSchemaRef> { - match self { - // return self and children schemas - LogicalPlan::Window(_) - | LogicalPlan::Projection(_) - | LogicalPlan::Aggregate(_) - | LogicalPlan::Unnest(_) - | LogicalPlan::Join(_) - | LogicalPlan::CrossJoin(_) => { - let mut schemas = vec![self.schema()]; - self.inputs().iter().for_each(|input| { - schemas.push(input.schema()); - }); - schemas - } - // just return self.schema() - LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::EmptyRelation(_) - | LogicalPlan::Ddl(_) - | LogicalPlan::Dml(_) - | LogicalPlan::Copy(_) - | LogicalPlan::Values(_) - | LogicalPlan::SubqueryAlias(_) - | LogicalPlan::Union(_) - | LogicalPlan::Extension(_) - | LogicalPlan::TableScan(_) => { - vec![self.schema()] - } - LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { - // return only the schema of the static term - static_term.all_schemas() - } - // return children schemas - LogicalPlan::Limit(_) - | LogicalPlan::Subquery(_) - | LogicalPlan::Repartition(_) - | LogicalPlan::Sort(_) - | LogicalPlan::Filter(_) - | LogicalPlan::Distinct(_) - | LogicalPlan::Prepare(_) => { - self.inputs().iter().map(|p| p.schema()).collect() - } - // return empty - LogicalPlan::Statement(_) | LogicalPlan::DescribeTable(_) => vec![], - } - } - /// Returns the (fixed) output schema for explain plans pub fn explain_schema() -> SchemaRef { SchemaRef::new(Schema::new(vec![ @@ -3079,14 +3029,6 @@ digraph { empty_schema: DFSchemaRef, } - impl NoChildExtension { - fn empty() -> Self { - Self { - empty_schema: Arc::new(DFSchema::empty()), - } - } - } - impl UserDefinedLogicalNode for NoChildExtension { fn as_any(&self) -> &dyn std::any::Any { unimplemented!() @@ -3129,18 +3071,6 @@ digraph { } } - #[test] - #[allow(deprecated)] - fn test_extension_all_schemas() { - let plan = LogicalPlan::Extension(Extension { - node: Arc::new(NoChildExtension::empty()), - }); - - let schemas = plan.all_schemas(); - assert_eq!(1, schemas.len()); - assert_eq!(0, schemas[0].fields().len()); - } - #[test] fn test_replace_invalid_placeholder() { // test empty placeholder diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs index f4a2cba68e16..59c54199333e 100644 --- a/datafusion/physical-plan/src/common.rs +++ b/datafusion/physical-plan/src/common.rs @@ -349,12 +349,6 @@ pub fn can_project( } } -/// Returns the total number of bytes of memory occupied physically by this batch. -#[deprecated(since = "28.0.0", note = "RecordBatch::get_array_memory_size")] -pub fn batch_byte_size(batch: &RecordBatch) -> usize { - batch.get_array_memory_size() -} - #[cfg(test)] mod tests { use std::ops::Not; diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index db352bb2c86f..a80dab058ca6 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -733,16 +733,6 @@ pub struct SortExec { } impl SortExec { - /// Create a new sort execution plan - #[deprecated(since = "22.0.0", note = "use `new` and `with_fetch`")] - pub fn try_new( - expr: Vec, - input: Arc, - fetch: Option, - ) -> Result { - Ok(Self::new(expr, input).with_fetch(fetch)) - } - /// Create a new sort execution plan that produces a single, /// sorted output partition. 
pub fn new(expr: Vec, input: Arc) -> Self { @@ -758,23 +748,6 @@ impl SortExec { } } - /// Create a new sort execution plan with the option to preserve - /// the partitioning of the input plan - #[deprecated( - since = "22.0.0", - note = "use `new`, `with_fetch` and `with_preserve_partioning` instead" - )] - pub fn new_with_partitioning( - expr: Vec, - input: Arc, - preserve_partitioning: bool, - fetch: Option, - ) -> Self { - Self::new(expr, input) - .with_fetch(fetch) - .with_preserve_partitioning(preserve_partitioning) - } - /// Whether this `SortExec` preserves partitioning of the children pub fn preserve_partitioning(&self) -> bool { self.preserve_partitioning From 2499245f348f2b8fe9777ab7ff7552642c56b4ce Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 18 Mar 2024 17:12:04 -0400 Subject: [PATCH 008/117] Migrate `arrow_cast` to a UDF (#9610) * feat: arrow_cast function as UDF * fix: cargo.lock in datafusion-cli * fix: unwrap arg1 on match arm Co-authored-by: Andrew Lamb * fix: unwrap on matching arms using some * Rewrite to use simplify API * Update error messages * Fix up tests * Update cargo.lock * fix test * fix * Fix merge errors, return error --------- Co-authored-by: Brayan Jules Co-authored-by: Brayan Jules --- datafusion-examples/examples/to_char.rs | 46 +++--- .../core/tests/optimizer_integration.rs | 27 +++- .../user_defined/user_defined_aggregates.rs | 10 +- .../expr => functions/src/core}/arrow_cast.rs | 145 ++++++++++++------ datafusion/functions/src/core/mod.rs | 3 + datafusion/sql/src/expr/function.rs | 8 - datafusion/sql/src/expr/mod.rs | 1 - datafusion/sql/src/lib.rs | 1 - datafusion/sql/tests/sql_integration.rs | 14 +- .../sqllogictest/test_files/arrow_typeof.slt | 9 +- 10 files changed, 155 insertions(+), 109 deletions(-) rename datafusion/{sql/src/expr => functions/src/core}/arrow_cast.rs (90%) diff --git a/datafusion-examples/examples/to_char.rs b/datafusion-examples/examples/to_char.rs index e99f69fbcd55..ef616d72cc1c 100644 --- a/datafusion-examples/examples/to_char.rs +++ b/datafusion-examples/examples/to_char.rs @@ -125,14 +125,14 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+------------+", - "| t.values |", - "+------------+", - "| 2020-09-01 |", - "| 2020-09-02 |", - "| 2020-09-03 |", - "| 2020-09-04 |", - "+------------+", + "+-----------------------------------+", + "| arrow_cast(t.values,Utf8(\"Utf8\")) |", + "+-----------------------------------+", + "| 2020-09-01 |", + "| 2020-09-02 |", + "| 2020-09-03 |", + "| 2020-09-04 |", + "+-----------------------------------+", ], &result ); @@ -146,11 +146,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+-----------------------------------------------------------------+", - "| to_char(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", - "+-----------------------------------------------------------------+", - "| 03-08-2023 14:38:50 |", - "+-----------------------------------------------------------------+", + "+-------------------------------------------------------------------------------------------------------------+", + "| to_char(arrow_cast(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"Timestamp(Second, None)\")),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", + "+-------------------------------------------------------------------------------------------------------------+", + "| 03-08-2023 14:38:50 |", + "+-------------------------------------------------------------------------------------------------------------+", ], &result ); @@ -165,11 +165,11 @@ async fn main() -> 
Result<()> { assert_batches_eq!( &[ - "+---------------------------------------+", - "| to_char(Int64(123456),Utf8(\"pretty\")) |", - "+---------------------------------------+", - "| 1 days 10 hours 17 mins 36 secs |", - "+---------------------------------------+", + "+----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"pretty\")) |", + "+----------------------------------------------------------------------------+", + "| 1 days 10 hours 17 mins 36 secs |", + "+----------------------------------------------------------------------------+", ], &result ); @@ -184,11 +184,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+----------------------------------------+", - "| to_char(Int64(123456),Utf8(\"iso8601\")) |", - "+----------------------------------------+", - "| PT123456S |", - "+----------------------------------------+", + "+-----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"iso8601\")) |", + "+-----------------------------------------------------------------------------+", + "| PT123456S |", + "+-----------------------------------------------------------------------------+", ], &result ); diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index f9696955769e..60010bdddfb8 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -15,6 +15,9 @@ // specific language governing permissions and limitations // under the License. +//! Tests for the DataFusion SQL query planner that require functions from the +//! datafusion-functions crate. + use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -42,12 +45,18 @@ fn init() { let _ = env_logger::try_init(); } +#[test] +fn select_arrow_cast() { + let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; + let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ + \n EmptyRelation"; + quick_test(sql, expected); +} #[test] fn timestamp_nano_ts_none_predicates() -> Result<()> { let sql = "SELECT col_int32 FROM test WHERE col_ts_nano_none < (now() - interval '1 hour')"; - let plan = test_sql(sql)?; // a scan should have the now()... predicate folded to a single // constant and compared to the column without a cast so it can be // pushed down / pruned @@ -55,7 +64,7 @@ fn timestamp_nano_ts_none_predicates() -> Result<()> { "Projection: test.col_int32\ \n Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\ \n TableScan: test projection=[col_int32, col_ts_nano_none]"; - assert_eq!(expected, format!("{plan:?}")); + quick_test(sql, expected); Ok(()) } @@ -74,6 +83,11 @@ fn timestamp_nano_ts_utc_predicates() { assert_eq!(expected, format!("{plan:?}")); } +fn quick_test(sql: &str, expected_plan: &str) { + let plan = test_sql(sql).unwrap(); + assert_eq!(expected_plan, format!("{:?}", plan)); +} + fn test_sql(sql: &str) -> Result { // parse the SQL let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... 
@@ -81,12 +95,9 @@ fn test_sql(sql: &str) -> Result { let statement = &ast[0]; // create a logical query plan - let now_udf = datetime::functions() - .iter() - .find(|f| f.name() == "now") - .unwrap() - .to_owned(); - let context_provider = MyContextProvider::default().with_udf(now_udf); + let context_provider = MyContextProvider::default() + .with_udf(datetime::now()) + .with_udf(datafusion_functions::core::arrow_cast()); let sql_to_rel = SqlToRel::new(&context_provider); let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 3f40c55a3ed7..a58a8cf51681 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -184,11 +184,11 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let expected = [ - "+-------------+", - "| SUM(t.time) |", - "+-------------+", - "| 19000 |", - "+-------------+", + "+---------------------------------------+", + "| SUM(arrow_cast(t.time,Utf8(\"Int64\"))) |", + "+---------------------------------------+", + "| 19000 |", + "+---------------------------------------+", ]; assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs similarity index 90% rename from datafusion/sql/src/expr/arrow_cast.rs rename to datafusion/functions/src/core/arrow_cast.rs index a75cdf9e3c6b..b6c1b5eb9a38 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -15,63 +15,125 @@ // specific language governing permissions and limitations // under the License. -//! Implementation of the `arrow_cast` function that allows -//! casting to arbitrary arrow types (rather than SQL types) +//! [`ArrowCastFunc`]: Implementation of the `arrow_cast` +use std::any::Any; use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; -use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; +use arrow::datatypes::{DataType, Field, IntervalUnit, TimeUnit}; use datafusion_common::{ - plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, + internal_err, plan_datafusion_err, plan_err, DataFusionError, ExprSchema, Result, + ScalarValue, }; -use datafusion_common::plan_err; -use datafusion_expr::{Expr, ExprSchemable}; +use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; -pub const ARROW_CAST_NAME: &str = "arrow_cast"; - -/// Create an [`Expr`] that evaluates the `arrow_cast` function +/// Implements casting to arbitrary arrow types (rather than SQL types) +/// +/// Note that the `arrow_cast` function is somewhat special in that its +/// return depends only on the *value* of its second argument (not its type) /// -/// This function is not a [`BuiltinScalarFunction`] because the -/// return type of [`BuiltinScalarFunction`] depends only on the -/// *types* of the arguments. However, the type of `arrow_type` depends on -/// the *value* of its second argument. +/// It is implemented by calling the same underlying arrow `cast` kernel as +/// normal SQL casts. /// -/// Use the `cast` function to cast to SQL type (which is then mapped -/// to the corresponding arrow type). 
For example to cast to `int` -/// (which is then mapped to the arrow type `Int32`) +/// For example to cast to `int` using SQL (which is then mapped to the arrow +/// type `Int32`) /// /// ```sql /// select cast(column_x as int) ... /// ``` /// -/// Use the `arrow_cast` functiont to cast to a specfic arrow type +/// You can use the `arrow_cast` functiont to cast to a specific arrow type /// /// For example /// ```sql /// select arrow_cast(column_x, 'Float64') /// ``` -/// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction -pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { +#[derive(Debug)] +pub(super) struct ArrowCastFunc { + signature: Signature, +} + +impl ArrowCastFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ArrowCastFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "arrow_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + // should be using return_type_from_exprs and not calling the default + // implementation + internal_err!("arrow_cast should return type from exprs") + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + _schema: &dyn ExprSchema, + _arg_types: &[DataType], + ) -> Result { + data_type_from_args(args) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + internal_err!("arrow_cast should have been simplified to cast") + } + + fn simplify( + &self, + mut args: Vec, + info: &dyn SimplifyInfo, + ) -> Result { + // convert this into a real cast + let target_type = data_type_from_args(&args)?; + // remove second (type) argument + args.pop().unwrap(); + let arg = args.pop().unwrap(); + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == target_type { + // the argument's data type is already the correct type + arg + } else { + // Use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(arg), + data_type: target_type, + }) + }; + // return the newly written argument to DataFusion + Ok(ExprSimplifyResult::Simplified(new_expr)) + } +} + +/// Returns the requested type from the arguments +fn data_type_from_args(args: &[Expr]) -> Result { if args.len() != 2 { return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } - let arg1 = args.pop().unwrap(); - let arg0 = args.pop().unwrap(); - - // arg1 must be a string - let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { - v - } else { + let Expr::Literal(ScalarValue::Utf8(Some(val))) = &args[1] else { return plan_err!( - "arrow_cast requires its second argument to be a constant string, got {arg1}" + "arrow_cast requires its second argument to be a constant string, got {:?}", + &args[1] ); }; - - // do the actual lookup to the appropriate data type - let data_type = parse_data_type(&data_type_string)?; - - arg0.cast_to(&data_type, schema) + parse_data_type(val) } /// Parses `str` into a `DataType`. 
@@ -80,22 +142,8 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result /// impl, and maintains the invariant that /// `parse_data_type(data_type.to_string()) == data_type` /// -/// Example: -/// ``` -/// # use datafusion_sql::parse_data_type; -/// # use arrow_schema::DataType; -/// let display_value = "Int32"; -/// -/// // "Int32" is the Display value of `DataType` -/// assert_eq!(display_value, &format!("{}", DataType::Int32)); -/// -/// // parse_data_type coverts "Int32" back to `DataType`: -/// let data_type = parse_data_type(display_value).unwrap(); -/// assert_eq!(data_type, DataType::Int32); -/// ``` -/// /// Remove if added to arrow: -pub fn parse_data_type(val: &str) -> Result { +fn parse_data_type(val: &str) -> Result { Parser::new(val).parse() } @@ -647,8 +695,6 @@ impl Display for Token { #[cfg(test)] mod test { - use arrow_schema::{IntervalUnit, TimeUnit}; - use super::*; #[test] @@ -844,7 +890,6 @@ mod test { assert!(message.contains("Must be a supported arrow type name such as 'Int32' or 'Timestamp(Nanosecond, None)'")); } } - println!(" Ok"); } } } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 73cc4d18bf9f..5a0bd2c77f63 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -17,6 +17,7 @@ //! "core" DataFusion functions +mod arrow_cast; mod arrowtypeof; mod getfield; mod nullif; @@ -25,6 +26,7 @@ mod nvl2; mod r#struct; // create UDFs +make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); make_udf_function!(nvl::NVLFunc, NVL, nvl); make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); @@ -35,6 +37,7 @@ make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."), + (arrow_cast, arg_1 arg_2, "returns arg_1 cast to the `arrow_type` given the second argument. 
This can be used to cast to a specific `arrow_type`."), (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"), (nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."), (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."), diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index ffc951a6fa66..582404b29749 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -34,8 +34,6 @@ use sqlparser::ast::{ use std::str::FromStr; use strum::IntoEnumIterator; -use super::arrow_cast::ARROW_CAST_NAME; - /// Suggest a valid function based on an invalid input function name pub fn suggest_valid_function( input_function_name: &str, @@ -249,12 +247,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { null_treatment, ))); }; - - // Special case arrow_cast (as its type is dependent on its argument value) - if name == ARROW_CAST_NAME { - let args = self.function_args_to_expr(args, schema, planner_context)?; - return super::arrow_cast::create_arrow_cast(args, schema); - } } // Could not find the relevant function, so return an error diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index a6f1c78c7250..5e9c0623a265 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -pub(crate) mod arrow_cast; mod binary_op; mod function; mod grouping_set; diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index e8e07eebe22d..12d6a4669634 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -42,5 +42,4 @@ pub mod utils; mod values; pub use datafusion_common::{ResolvedTableReference, TableReference}; -pub use expr::arrow_cast::parse_data_type; pub use sqlparser; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index c9c2bdd694b5..b6077353e5dd 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2566,15 +2566,6 @@ fn approx_median_window() { quick_test(sql, expected); } -#[test] -fn select_arrow_cast() { - let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo', 'LargeUtf8')"; - let expected = "\ - Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\ - \n EmptyRelation"; - quick_test(sql, expected); -} - #[test] fn select_typed_date_string() { let sql = "SELECT date '2020-12-10' AS date"; @@ -2670,6 +2661,11 @@ fn logical_plan_with_dialect_and_options( vec![DataType::Int32, DataType::Int32], DataType::Int32, )) + .with_udf(make_udf( + "arrow_cast", + vec![DataType::Int64, DataType::Utf8], + DataType::Float64, + )) .with_udf(make_udf( "date_trunc", vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)], diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 8b3bd7eac95d..3e8694f3b2c2 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -92,10 +92,11 @@ SELECT arrow_cast('1', 'Int16') 1 # Basic error test -query error Error during planning: arrow_cast needs 2 arguments, 1 provided +query error DataFusion error: Error during planning: No function matches the given name and argument types 'arrow_cast\(Utf8\)'. You might need to add explicit type casts. 
SELECT arrow_cast('1') -query error Error during planning: arrow_cast requires its second argument to be a constant string, got Int64\(43\) + +query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown @@ -315,7 +316,7 @@ select arrow_cast(interval '30 minutes', 'Duration(Second)'); ---- 0 days 0 hours 30 mins 0 secs -query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\) +query error DataFusion error: This feature is not implemented: Unsupported CAST from Utf8 to Duration\(Second\) select arrow_cast('30 minutes', 'Duration(Second)'); @@ -336,7 +337,7 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( ---- 2000-01-01T00:00:00+08:00 -statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone +statement error DataFusion error: Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); From e53eb036f5c61f7d7bd90047976511628ddca2d0 Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Mon, 18 Mar 2024 22:13:14 +0100 Subject: [PATCH 009/117] parquet: Add row_groups_matched_{statistics,bloom_filter} statistics (#9640) * test_row_group_prune: Display which assertion failed * Add row_groups_matched_{statistics,bloom_filter} statistics This helps diagnostic whether a Bloom filter mismatches (because of high false-positive probability caused by suboptimal tuning) or is not used at all. --- .../physical_plan/parquet/metrics.rs | 14 ++ .../physical_plan/parquet/row_groups.rs | 4 + datafusion/core/tests/parquet/mod.rs | 17 +++ .../core/tests/parquet/row_group_pruning.rs | 126 +++++++++++++++++- datafusion/core/tests/sql/explain_analyze.rs | 4 + 5 files changed, 160 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs index a17a3c6d9752..c2a7e4345a5b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/metrics.rs @@ -29,8 +29,12 @@ use crate::physical_plan::metrics::{ pub struct ParquetFileMetrics { /// Number of times the predicate could not be evaluated pub predicate_evaluation_errors: Count, + /// Number of row groups whose bloom filters were checked and matched + pub row_groups_matched_bloom_filter: Count, /// Number of row groups pruned by bloom filters pub row_groups_pruned_bloom_filter: Count, + /// Number of row groups whose statistics were checked and matched + pub row_groups_matched_statistics: Count, /// Number of row groups pruned by statistics pub row_groups_pruned_statistics: Count, /// Total number of bytes scanned @@ -56,10 +60,18 @@ impl ParquetFileMetrics { .with_new_label("filename", filename.to_string()) .counter("predicate_evaluation_errors", partition); + let row_groups_matched_bloom_filter = MetricBuilder::new(metrics) + .with_new_label("filename", filename.to_string()) + .counter("row_groups_matched_bloom_filter", partition); + let row_groups_pruned_bloom_filter = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .counter("row_groups_pruned_bloom_filter", partition); + let row_groups_matched_statistics = MetricBuilder::new(metrics) + 
.with_new_label("filename", filename.to_string()) + .counter("row_groups_matched_statistics", partition); + let row_groups_pruned_statistics = MetricBuilder::new(metrics) .with_new_label("filename", filename.to_string()) .counter("row_groups_pruned_statistics", partition); @@ -85,7 +97,9 @@ impl ParquetFileMetrics { Self { predicate_evaluation_errors, + row_groups_matched_bloom_filter, row_groups_pruned_bloom_filter, + row_groups_matched_statistics, row_groups_pruned_statistics, bytes_scanned, pushdown_rows_filtered, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index ef2eb775e037..1a84f52a33fd 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -94,6 +94,7 @@ pub(crate) fn prune_row_groups_by_statistics( metrics.predicate_evaluation_errors.add(1); } } + metrics.row_groups_matched_statistics.add(1); } filtered.push(idx) @@ -166,6 +167,9 @@ pub(crate) async fn prune_row_groups_by_bloom_filters< if prune_group { metrics.row_groups_pruned_bloom_filter.add(1); } else { + if !stats.column_sbbf.is_empty() { + metrics.row_groups_matched_bloom_filter.add(1); + } filtered.push(*idx); } } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 7649b6acd45c..c60780919489 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -117,16 +117,33 @@ impl TestOutput { self.metric_value("predicate_evaluation_errors") } + /// The number of row_groups matched by bloom filter + fn row_groups_matched_bloom_filter(&self) -> Option { + self.metric_value("row_groups_matched_bloom_filter") + } + /// The number of row_groups pruned by bloom filter fn row_groups_pruned_bloom_filter(&self) -> Option { self.metric_value("row_groups_pruned_bloom_filter") } + /// The number of row_groups matched by statistics + fn row_groups_matched_statistics(&self) -> Option { + self.metric_value("row_groups_matched_statistics") + } + /// The number of row_groups pruned by statistics fn row_groups_pruned_statistics(&self) -> Option { self.metric_value("row_groups_pruned_statistics") } + /// The number of row_groups matched by bloom filter or statistics + fn row_groups_matched(&self) -> Option { + self.row_groups_matched_bloom_filter() + .zip(self.row_groups_matched_statistics()) + .map(|(a, b)| a + b) + } + /// The number of row_groups pruned fn row_groups_pruned(&self) -> Option { self.row_groups_pruned_bloom_filter() diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index fa53b9c56cec..b7038ef1a73f 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -29,7 +29,9 @@ struct RowGroupPruningTest { scenario: Scenario, query: String, expected_errors: Option, + expected_row_group_matched_by_statistics: Option, expected_row_group_pruned_by_statistics: Option, + expected_row_group_matched_by_bloom_filter: Option, expected_row_group_pruned_by_bloom_filter: Option, expected_results: usize, } @@ -40,7 +42,9 @@ impl RowGroupPruningTest { scenario: Scenario::Timestamps, // or another default query: String::new(), expected_errors: None, + expected_row_group_matched_by_statistics: None, expected_row_group_pruned_by_statistics: None, + expected_row_group_matched_by_bloom_filter: None, expected_row_group_pruned_by_bloom_filter: None, 
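+            // The matched/pruned expectations default to `None`; each test
+            // supplies its expected counts via the `with_*` builders below.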
expected_results: 0, } @@ -64,12 +68,24 @@ impl RowGroupPruningTest { self } + // Set the expected matched row groups by statistics + fn with_matched_by_stats(mut self, matched_by_stats: Option) -> Self { + self.expected_row_group_matched_by_statistics = matched_by_stats; + self + } + // Set the expected pruned row groups by statistics fn with_pruned_by_stats(mut self, pruned_by_stats: Option) -> Self { self.expected_row_group_pruned_by_statistics = pruned_by_stats; self } + // Set the expected matched row groups by bloom filter + fn with_matched_by_bloom_filter(mut self, matched_by_bf: Option) -> Self { + self.expected_row_group_matched_by_bloom_filter = matched_by_bf; + self + } + // Set the expected pruned row groups by bloom filter fn with_pruned_by_bloom_filter(mut self, pruned_by_bf: Option) -> Self { self.expected_row_group_pruned_by_bloom_filter = pruned_by_bf; @@ -90,20 +106,36 @@ impl RowGroupPruningTest { .await; println!("{}", output.description()); - assert_eq!(output.predicate_evaluation_errors(), self.expected_errors); + assert_eq!( + output.predicate_evaluation_errors(), + self.expected_errors, + "mismatched predicate_evaluation" + ); + assert_eq!( + output.row_groups_matched_statistics(), + self.expected_row_group_matched_by_statistics, + "mismatched row_groups_matched_statistics", + ); assert_eq!( output.row_groups_pruned_statistics(), - self.expected_row_group_pruned_by_statistics + self.expected_row_group_pruned_by_statistics, + "mismatched row_groups_pruned_statistics", + ); + assert_eq!( + output.row_groups_matched_bloom_filter(), + self.expected_row_group_matched_by_bloom_filter, + "mismatched row_groups_matched_bloom_filter", ); assert_eq!( output.row_groups_pruned_bloom_filter(), - self.expected_row_group_pruned_by_bloom_filter + self.expected_row_group_pruned_by_bloom_filter, + "mismatched row_groups_pruned_bloom_filter", ); assert_eq!( output.result_rows, self.expected_results, - "{}", - output.description() + "mismatched expected rows: {}", + output.description(), ); } } @@ -114,7 +146,9 @@ async fn prune_timestamps_nanos() { .with_scenario(Scenario::Timestamps) .with_query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -129,7 +163,9 @@ async fn prune_timestamps_micros() { "SELECT * FROM t where micros < to_timestamp_micros('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -144,7 +180,9 @@ async fn prune_timestamps_millis() { "SELECT * FROM t where micros < to_timestamp_millis('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -159,7 +197,9 @@ async fn prune_timestamps_seconds() { "SELECT * FROM t where seconds < to_timestamp_seconds('2020-01-02 01:01:11Z')", ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -172,7 +212,9 @@ async fn 
prune_date32() { .with_scenario(Scenario::Dates) .with_query("SELECT * FROM t where date32 < cast('2020-01-02' as date)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -201,6 +243,7 @@ async fn prune_date64() { println!("{}", output.description()); // This should prune out groups without error assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_matched(), Some(1)); assert_eq!(output.row_groups_pruned(), Some(3)); assert_eq!(output.result_rows, 1, "{}", output.description()); } @@ -211,7 +254,9 @@ async fn prune_disabled() { .with_scenario(Scenario::Timestamps) .with_query("SELECT * FROM t where nanos < to_timestamp('2020-01-02 01:01:11Z')") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(10) .test_row_group_prune() @@ -230,6 +275,7 @@ async fn prune_disabled() { // This should not prune any assert_eq!(output.predicate_evaluation_errors(), Some(0)); + assert_eq!(output.row_groups_matched(), Some(0)); assert_eq!(output.row_groups_pruned(), Some(0)); assert_eq!( output.result_rows, @@ -245,7 +291,9 @@ async fn prune_int32_lt() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i < 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -257,7 +305,9 @@ async fn prune_int32_lt() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where -i > -1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -270,7 +320,9 @@ async fn prune_int32_eq() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i = 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -282,7 +334,9 @@ async fn prune_int32_scalar_fun_and_eq() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i = 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -295,7 +349,9 @@ async fn prune_int32_scalar_fun() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where abs(i) = 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(3) .test_row_group_prune() @@ -308,7 +364,9 @@ async fn prune_int32_complex_expr() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i+1 = 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -321,7 +379,9 @@ async fn prune_int32_complex_expr_subtract() { 
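+        // `1 - i > 1` is not a bare column comparison, so neither statistics
+        // nor bloom filters can match or prune any row group for this query.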
.with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where 1-i > 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -334,7 +394,9 @@ async fn prune_f64_lt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where f < 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -343,7 +405,9 @@ async fn prune_f64_lt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where -f > -1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(11) .test_row_group_prune() @@ -358,7 +422,9 @@ async fn prune_f64_scalar_fun_and_gt() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where abs(f - 1) <= 0.000001 and f >= 0.1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -372,7 +438,9 @@ async fn prune_f64_scalar_fun() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where abs(f-1) <= 0.000001") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -386,7 +454,9 @@ async fn prune_f64_complex_expr() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where f+1 > 1.1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -400,7 +470,9 @@ async fn prune_f64_complex_expr_subtract() { .with_scenario(Scenario::Float64) .with_query("SELECT * FROM t where 1-f > 1") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(9) .test_row_group_prune() @@ -414,7 +486,9 @@ async fn prune_int32_eq_in_list() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i in (1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(1) .test_row_group_prune() @@ -429,7 +503,9 @@ async fn prune_int32_eq_in_list_2() { .with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i in (1000)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) .with_pruned_by_stats(Some(4)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(0) .test_row_group_prune() @@ -449,7 +525,9 @@ async fn prune_int32_eq_large_in_list() { .as_str(), ) .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(1)) .with_expected_rows(0) .test_row_group_prune() @@ -463,7 +541,9 @@ async fn prune_int32_eq_in_list_negated() { 
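+        // A negated IN list cannot exclude any row group: statistics match
+        // all four groups, and the bloom filter cannot prove a group holds
+        // only the excluded value, so nothing is pruned.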
.with_scenario(Scenario::Int32) .with_query("SELECT * FROM t where i not in (1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(4)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(19) .test_row_group_prune() @@ -479,7 +559,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col < 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -488,7 +570,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) .test_row_group_prune() @@ -497,7 +581,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col < 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -506,7 +592,9 @@ async fn prune_decimal_lt() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col < cast(4.55 as decimal(20,2))") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(8) .test_row_group_prune() @@ -522,7 +610,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col = 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -531,7 +621,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col = 4.00") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -541,7 +633,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col = 4") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -550,7 +644,9 @@ async fn prune_decimal_eq() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col = 4.00") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() @@ -567,7 +663,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4,3,2,123456789123)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + 
.with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -576,7 +674,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -585,7 +685,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::Decimal) .with_query("SELECT * FROM t where decimal_col in (4,3,2,123456789123)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -594,7 +696,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalLargePrecision) .with_query("SELECT * FROM t where decimal_col in (4.00,3.00,11.2345,1)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(6) .test_row_group_prune() @@ -605,7 +709,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalBloomFilterInt32) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) .test_row_group_prune() @@ -616,7 +722,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalBloomFilterInt64) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) .test_row_group_prune() @@ -627,7 +735,9 @@ async fn prune_decimal_in_list() { .with_scenario(Scenario::DecimalLargePrecisionBloomFilter) .with_query("SELECT * FROM t where decimal_col in (5)") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(2)) .with_expected_rows(1) .test_row_group_prune() @@ -644,7 +754,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend'") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(2)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(7) .test_row_group_prune() @@ -653,7 +765,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"name\" != 'HTTP GET / DISPATCH'") .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(5) .test_row_group_prune() @@ -662,7 +776,9 @@ async fn prune_periods_in_column_names() { .with_scenario(Scenario::PeriodsInColumnNames) .with_query( "SELECT \"name\", \"service.name\" FROM t WHERE \"service.name\" = 'frontend' AND \"name\" != 'HTTP GET / DISPATCH'") 
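+        // With both predicates combined, statistics match one of the three
+        // row groups and prune the other two; the bloom filter also matches
+        // the surviving group.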
.with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(1)) .with_pruned_by_bloom_filter(Some(0)) .with_expected_rows(2) .test_row_group_prune() diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 695b3ba745e2..30b11fe2a0ee 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -737,7 +737,9 @@ async fn parquet_explain_analyze() { // should contain aggregated stats assert_contains!(&formatted, "output_rows=8"); + assert_contains!(&formatted, "row_groups_matched_bloom_filter=0"); assert_contains!(&formatted, "row_groups_pruned_bloom_filter=0"); + assert_contains!(&formatted, "row_groups_matched_statistics=1"); assert_contains!(&formatted, "row_groups_pruned_statistics=0"); } @@ -754,7 +756,9 @@ async fn parquet_explain_analyze_verbose() { .to_string(); // should contain the raw per file stats (with the label) + assert_contains!(&formatted, "row_groups_matched_bloom_filter{partition=0"); assert_contains!(&formatted, "row_groups_pruned_bloom_filter{partition=0"); + assert_contains!(&formatted, "row_groups_matched_statistics{partition=0"); assert_contains!(&formatted, "row_groups_pruned_statistics{partition=0"); } From b137f60b9b6132d389efa9911b929d7b4d285b3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Metehan=20Y=C4=B1ld=C4=B1r=C4=B1m?= <100111937+metesynnada@users.noreply.github.com> Date: Tue, 19 Mar 2024 01:45:26 +0300 Subject: [PATCH 010/117] Make COPY TO align with CREATE EXTERNAL TABLE (#9604) --- datafusion-cli/src/catalog.rs | 2 +- datafusion-cli/src/exec.rs | 6 +- datafusion/common/src/config.rs | 221 +++++++++++++----- datafusion/common/src/file_options/mod.rs | 85 +++---- datafusion/core/src/dataframe/mod.rs | 9 +- datafusion/core/src/dataframe/parquet.rs | 5 +- .../src/datasource/file_format/options.rs | 2 +- .../core/src/datasource/listing/table.rs | 40 ++-- .../src/datasource/listing_table_factory.rs | 6 +- datafusion/core/src/execution/context/mod.rs | 15 +- datafusion/core/src/physical_planner.rs | 2 +- datafusion/core/src/test_util/parquet.rs | 2 +- datafusion/core/tests/sql/sql_api.rs | 12 +- .../tests/cases/roundtrip_logical_plan.rs | 9 +- datafusion/sql/src/parser.rs | 206 ++++++++++++---- datafusion/sql/src/statement.rs | 139 ++++------- datafusion/sql/tests/sql_integration.rs | 26 ++- datafusion/sqllogictest/test_files/copy.slt | 159 ++++++------- .../test_files/create_external_table.slt | 4 +- .../sqllogictest/test_files/csv_files.slt | 10 +- .../sqllogictest/test_files/group_by.slt | 8 +- .../sqllogictest/test_files/parquet.slt | 8 +- .../sqllogictest/test_files/repartition.slt | 2 +- .../test_files/repartition_scan.slt | 8 +- .../test_files/schema_evolution.slt | 8 +- docs/source/user-guide/sql/dml.md | 2 +- 26 files changed, 598 insertions(+), 398 deletions(-) diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index a8ecb98637cb..46dd8bb00f06 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -189,7 +189,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { &state, table_url.scheme(), url, - state.default_table_options(), + &state.default_table_options(), ) .await?; state.runtime_env().register_object_store(url, store); diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index b11f1c202284..ea765ee8eceb 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -412,7 +412,7 @@ mod tests { 
) })?; for location in locations { - let sql = format!("copy (values (1,2)) to '{}';", location); + let sql = format!("copy (values (1,2)) to '{}' STORED AS PARQUET;", location); let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?; for statement in statements { //Should not fail @@ -438,8 +438,8 @@ mod tests { let location = "s3://bucket/path/file.parquet"; // Missing region, use object_store defaults - let sql = format!("COPY (values (1,2)) TO '{location}' - (format parquet, 'aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')"); + let sql = format!("COPY (values (1,2)) TO '{location}' STORED AS PARQUET + OPTIONS ('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}')"); copy_to_table_test(location, &sql).await?; Ok(()) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 68b9ec9dab94..968d8215ca4d 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -1109,58 +1109,163 @@ macro_rules! extensions_options { } } +/// Represents the configuration options available for handling different table formats within a data processing application. +/// This struct encompasses options for various file formats including CSV, Parquet, and JSON, allowing for flexible configuration +/// of parsing and writing behaviors specific to each format. Additionally, it supports extending functionality through custom extensions. #[derive(Debug, Clone, Default)] pub struct TableOptions { + /// Configuration options for CSV file handling. This includes settings like the delimiter, + /// quote character, and whether the first row is considered as headers. pub csv: CsvOptions, + + /// Configuration options for Parquet file handling. This includes settings for compression, + /// encoding, and other Parquet-specific file characteristics. pub parquet: TableParquetOptions, + + /// Configuration options for JSON file handling. pub json: JsonOptions, + + /// The current file format that the table operations should assume. This option allows + /// for dynamic switching between the supported file types (e.g., CSV, Parquet, JSON). pub current_format: Option, - /// Optional extensions registered using [`Extensions::insert`] + + /// Optional extensions that can be used to extend or customize the behavior of the table + /// options. Extensions can be registered using `Extensions::insert` and might include + /// custom file handling logic, additional configuration parameters, or other enhancements. pub extensions: Extensions, } impl ConfigField for TableOptions { + /// Visits configuration settings for the current file format, or all formats if none is selected. + /// + /// This method adapts the behavior based on whether a file format is currently selected in `current_format`. + /// If a format is selected, it visits only the settings relevant to that format. Otherwise, + /// it visits all available format settings. 
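+    ///
+    /// For example, when `current_format` is `Some(FileType::CSV)`, only the
+    /// CSV settings are visited, under the `format` key prefix.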
fn visit(&self, v: &mut V, _key_prefix: &str, _description: &'static str) { - self.csv.visit(v, "csv", ""); - self.parquet.visit(v, "parquet", ""); - self.json.visit(v, "json", ""); + if let Some(file_type) = &self.current_format { + match file_type { + #[cfg(feature = "parquet")] + FileType::PARQUET => self.parquet.visit(v, "format", ""), + FileType::CSV => self.csv.visit(v, "format", ""), + FileType::JSON => self.json.visit(v, "format", ""), + _ => {} + } + } else { + self.csv.visit(v, "csv", ""); + self.parquet.visit(v, "parquet", ""); + self.json.visit(v, "json", ""); + } } + /// Sets a configuration value for a specific key within `TableOptions`. + /// + /// This method delegates setting configuration values to the specific file format configurations, + /// based on the current format selected. If no format is selected, it returns an error. + /// + /// # Parameters + /// + /// * `key`: The configuration key specifying which setting to adjust, prefixed with the format (e.g., "format.delimiter") + /// for CSV format. + /// * `value`: The value to set for the specified configuration key. + /// + /// # Returns + /// + /// A result indicating success or an error if the key is not recognized, if a format is not specified, + /// or if setting the configuration value fails for the specific format. fn set(&mut self, key: &str, value: &str) -> Result<()> { // Extensions are handled in the public `ConfigOptions::set` let (key, rem) = key.split_once('.').unwrap_or((key, "")); + let Some(format) = &self.current_format else { + return _config_err!("Specify a format for TableOptions"); + }; match key { - "csv" => self.csv.set(rem, value), - "parquet" => self.parquet.set(rem, value), - "json" => self.json.set(rem, value), + "format" => match format { + #[cfg(feature = "parquet")] + FileType::PARQUET => self.parquet.set(rem, value), + FileType::CSV => self.csv.set(rem, value), + FileType::JSON => self.json.set(rem, value), + _ => { + _config_err!("Config value \"{key}\" is not supported on {}", format) + } + }, _ => _config_err!("Config value \"{key}\" not found on TableOptions"), } } } impl TableOptions { - /// Creates a new [`ConfigOptions`] with default values + /// Constructs a new instance of `TableOptions` with default settings. + /// + /// # Returns + /// + /// A new `TableOptions` instance with default configuration values. pub fn new() -> Self { Self::default() } + /// Sets the file format for the table. + /// + /// # Parameters + /// + /// * `format`: The file format to use (e.g., CSV, Parquet). pub fn set_file_format(&mut self, format: FileType) { self.current_format = Some(format); } + /// Creates a new `TableOptions` instance initialized with settings from a given session config. + /// + /// # Parameters + /// + /// * `config`: A reference to the session `ConfigOptions` from which to derive initial settings. + /// + /// # Returns + /// + /// A new `TableOptions` instance with settings applied from the session config. pub fn default_from_session_config(config: &ConfigOptions) -> Self { - let mut initial = TableOptions::default(); - initial.parquet.global = config.execution.parquet.clone(); + let initial = TableOptions::default(); + initial.combine_with_session_config(config); initial } - /// Set extensions to provided value + /// Updates the current `TableOptions` with settings from a given session config. + /// + /// # Parameters + /// + /// * `config`: A reference to the session `ConfigOptions` whose settings are to be applied. 
+ /// + /// # Returns + /// + /// A new `TableOptions` instance with updated settings from the session config. + pub fn combine_with_session_config(&self, config: &ConfigOptions) -> Self { + let mut clone = self.clone(); + clone.parquet.global = config.execution.parquet.clone(); + clone + } + + /// Sets the extensions for this `TableOptions` instance. + /// + /// # Parameters + /// + /// * `extensions`: The `Extensions` instance to set. + /// + /// # Returns + /// + /// A new `TableOptions` instance with the specified extensions applied. pub fn with_extensions(mut self, extensions: Extensions) -> Self { self.extensions = extensions; self } - /// Set a configuration option + /// Sets a specific configuration option. + /// + /// # Parameters + /// + /// * `key`: The configuration key (e.g., "format.delimiter"). + /// * `value`: The value to set for the specified key. + /// + /// # Returns + /// + /// A result indicating success or failure in setting the configuration option. pub fn set(&mut self, key: &str, value: &str) -> Result<()> { let (prefix, _) = key.split_once('.').ok_or_else(|| { DataFusionError::Configuration(format!( @@ -1168,28 +1273,7 @@ impl TableOptions { )) })?; - if prefix == "csv" || prefix == "json" || prefix == "parquet" { - if let Some(format) = &self.current_format { - match format { - FileType::CSV if prefix != "csv" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for CSV format" - ))) - } - #[cfg(feature = "parquet")] - FileType::PARQUET if prefix != "parquet" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for PARQUET format" - ))) - } - FileType::JSON if prefix != "json" => { - return Err(DataFusionError::Configuration(format!( - "Key \"{key}\" is not applicable for JSON format" - ))) - } - _ => {} - } - } + if prefix == "format" { return ConfigField::set(self, key, value); } @@ -1202,6 +1286,15 @@ impl TableOptions { e.0.set(key, value) } + /// Initializes a new `TableOptions` from a hash map of string settings. + /// + /// # Parameters + /// + /// * `settings`: A hash map where each key-value pair represents a configuration setting. + /// + /// # Returns + /// + /// A result containing the new `TableOptions` instance or an error if any setting could not be applied. pub fn from_string_hash_map(settings: &HashMap) -> Result { let mut ret = Self::default(); for (k, v) in settings { @@ -1211,6 +1304,15 @@ impl TableOptions { Ok(ret) } + /// Modifies the current `TableOptions` instance with settings from a hash map. + /// + /// # Parameters + /// + /// * `settings`: A hash map where each key-value pair represents a configuration setting. + /// + /// # Returns + /// + /// A result indicating success or failure in applying the settings. pub fn alter_with_string_hash_map( &mut self, settings: &HashMap, @@ -1221,7 +1323,11 @@ impl TableOptions { Ok(()) } - /// Returns the [`ConfigEntry`] stored within this [`ConfigOptions`] + /// Retrieves all configuration entries from this `TableOptions`. + /// + /// # Returns + /// + /// A vector of `ConfigEntry` instances, representing all the configuration options within this `TableOptions`. 
pub fn entries(&self) -> Vec { struct Visitor(Vec); @@ -1249,9 +1355,7 @@ impl TableOptions { } let mut v = Visitor(vec![]); - self.visit(&mut v, "csv", ""); - self.visit(&mut v, "json", ""); - self.visit(&mut v, "parquet", ""); + self.visit(&mut v, "format", ""); v.0.extend(self.extensions.0.values().flat_map(|e| e.0.entries())); v.0 @@ -1556,6 +1660,7 @@ mod tests { use crate::config::{ ConfigEntry, ConfigExtension, ExtensionOptions, Extensions, TableOptions, }; + use crate::FileType; #[derive(Default, Debug, Clone)] pub struct TestExtensionConfig { @@ -1609,12 +1714,13 @@ mod tests { } #[test] - fn alter_kafka_config() { + fn alter_test_extension_config() { let mut extension = Extensions::new(); extension.insert(TestExtensionConfig::default()); let mut table_config = TableOptions::new().with_extensions(extension); - table_config.set("parquet.write_batch_size", "10").unwrap(); - assert_eq!(table_config.parquet.global.write_batch_size, 10); + table_config.set_file_format(FileType::CSV); + table_config.set("format.delimiter", ";").unwrap(); + assert_eq!(table_config.csv.delimiter, b';'); table_config.set("test.bootstrap.servers", "asd").unwrap(); let kafka_config = table_config .extensions @@ -1626,11 +1732,25 @@ mod tests { ); } + #[test] + fn csv_u8_table_options() { + let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::CSV); + table_config.set("format.delimiter", ";").unwrap(); + assert_eq!(table_config.csv.delimiter as char, ';'); + table_config.set("format.escape", "\"").unwrap(); + assert_eq!(table_config.csv.escape.unwrap() as char, '"'); + table_config.set("format.escape", "\'").unwrap(); + assert_eq!(table_config.csv.escape.unwrap() as char, '\''); + } + + #[cfg(feature = "parquet")] #[test] fn parquet_table_options() { let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config - .set("parquet.bloom_filter_enabled::col1", "true") + .set("format.bloom_filter_enabled::col1", "true") .unwrap(); assert_eq!( table_config.parquet.column_specific_options["col1"].bloom_filter_enabled, @@ -1638,26 +1758,17 @@ mod tests { ); } - #[test] - fn csv_u8_table_options() { - let mut table_config = TableOptions::new(); - table_config.set("csv.delimiter", ";").unwrap(); - assert_eq!(table_config.csv.delimiter as char, ';'); - table_config.set("csv.escape", "\"").unwrap(); - assert_eq!(table_config.csv.escape.unwrap() as char, '"'); - table_config.set("csv.escape", "\'").unwrap(); - assert_eq!(table_config.csv.escape.unwrap() as char, '\''); - } - + #[cfg(feature = "parquet")] #[test] fn parquet_table_options_config_entry() { let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config - .set("parquet.bloom_filter_enabled::col1", "true") + .set("format.bloom_filter_enabled::col1", "true") .unwrap(); let entries = table_config.entries(); assert!(entries .iter() - .any(|item| item.key == "parquet.bloom_filter_enabled::col1")) + .any(|item| item.key == "format.bloom_filter_enabled::col1")) } } diff --git a/datafusion/common/src/file_options/mod.rs b/datafusion/common/src/file_options/mod.rs index a72b812adc8d..eb1ce1b364fd 100644 --- a/datafusion/common/src/file_options/mod.rs +++ b/datafusion/common/src/file_options/mod.rs @@ -35,7 +35,7 @@ mod tests { config::TableOptions, file_options::{csv_writer::CsvWriterOptions, json_writer::JsonWriterOptions}, parsers::CompressionTypeVariant, - Result, + FileType, Result, }; use parquet::{ @@ -47,35 +47,36 @@ mod tests { #[test] fn 
test_writeroptions_parquet_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("parquet.max_row_group_size".to_owned(), "123".to_owned()); - option_map.insert("parquet.data_pagesize_limit".to_owned(), "123".to_owned()); - option_map.insert("parquet.write_batch_size".to_owned(), "123".to_owned()); - option_map.insert("parquet.writer_version".to_owned(), "2.0".to_owned()); + option_map.insert("format.max_row_group_size".to_owned(), "123".to_owned()); + option_map.insert("format.data_pagesize_limit".to_owned(), "123".to_owned()); + option_map.insert("format.write_batch_size".to_owned(), "123".to_owned()); + option_map.insert("format.writer_version".to_owned(), "2.0".to_owned()); option_map.insert( - "parquet.dictionary_page_size_limit".to_owned(), + "format.dictionary_page_size_limit".to_owned(), "123".to_owned(), ); option_map.insert( - "parquet.created_by".to_owned(), + "format.created_by".to_owned(), "df write unit test".to_owned(), ); option_map.insert( - "parquet.column_index_truncate_length".to_owned(), + "format.column_index_truncate_length".to_owned(), "123".to_owned(), ); option_map.insert( - "parquet.data_page_row_count_limit".to_owned(), + "format.data_page_row_count_limit".to_owned(), "123".to_owned(), ); - option_map.insert("parquet.bloom_filter_enabled".to_owned(), "true".to_owned()); - option_map.insert("parquet.encoding".to_owned(), "plain".to_owned()); - option_map.insert("parquet.dictionary_enabled".to_owned(), "true".to_owned()); - option_map.insert("parquet.compression".to_owned(), "zstd(4)".to_owned()); - option_map.insert("parquet.statistics_enabled".to_owned(), "page".to_owned()); - option_map.insert("parquet.bloom_filter_fpp".to_owned(), "0.123".to_owned()); - option_map.insert("parquet.bloom_filter_ndv".to_owned(), "123".to_owned()); + option_map.insert("format.bloom_filter_enabled".to_owned(), "true".to_owned()); + option_map.insert("format.encoding".to_owned(), "plain".to_owned()); + option_map.insert("format.dictionary_enabled".to_owned(), "true".to_owned()); + option_map.insert("format.compression".to_owned(), "zstd(4)".to_owned()); + option_map.insert("format.statistics_enabled".to_owned(), "page".to_owned()); + option_map.insert("format.bloom_filter_fpp".to_owned(), "0.123".to_owned()); + option_map.insert("format.bloom_filter_ndv".to_owned(), "123".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -131,54 +132,52 @@ mod tests { let mut option_map: HashMap = HashMap::new(); option_map.insert( - "parquet.bloom_filter_enabled::col1".to_owned(), + "format.bloom_filter_enabled::col1".to_owned(), "true".to_owned(), ); option_map.insert( - "parquet.bloom_filter_enabled::col2.nested".to_owned(), + "format.bloom_filter_enabled::col2.nested".to_owned(), "true".to_owned(), ); - option_map.insert("parquet.encoding::col1".to_owned(), "plain".to_owned()); - option_map.insert("parquet.encoding::col2.nested".to_owned(), "rle".to_owned()); + option_map.insert("format.encoding::col1".to_owned(), "plain".to_owned()); + option_map.insert("format.encoding::col2.nested".to_owned(), "rle".to_owned()); option_map.insert( - "parquet.dictionary_enabled::col1".to_owned(), + "format.dictionary_enabled::col1".to_owned(), "true".to_owned(), ); option_map.insert( - "parquet.dictionary_enabled::col2.nested".to_owned(), + 
"format.dictionary_enabled::col2.nested".to_owned(), "true".to_owned(), ); - option_map.insert("parquet.compression::col1".to_owned(), "zstd(4)".to_owned()); + option_map.insert("format.compression::col1".to_owned(), "zstd(4)".to_owned()); option_map.insert( - "parquet.compression::col2.nested".to_owned(), + "format.compression::col2.nested".to_owned(), "zstd(10)".to_owned(), ); option_map.insert( - "parquet.statistics_enabled::col1".to_owned(), + "format.statistics_enabled::col1".to_owned(), "page".to_owned(), ); option_map.insert( - "parquet.statistics_enabled::col2.nested".to_owned(), + "format.statistics_enabled::col2.nested".to_owned(), "none".to_owned(), ); option_map.insert( - "parquet.bloom_filter_fpp::col1".to_owned(), + "format.bloom_filter_fpp::col1".to_owned(), "0.123".to_owned(), ); option_map.insert( - "parquet.bloom_filter_fpp::col2.nested".to_owned(), + "format.bloom_filter_fpp::col2.nested".to_owned(), "0.456".to_owned(), ); + option_map.insert("format.bloom_filter_ndv::col1".to_owned(), "123".to_owned()); option_map.insert( - "parquet.bloom_filter_ndv::col1".to_owned(), - "123".to_owned(), - ); - option_map.insert( - "parquet.bloom_filter_ndv::col2.nested".to_owned(), + "format.bloom_filter_ndv::col2.nested".to_owned(), "456".to_owned(), ); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::PARQUET); table_config.alter_with_string_hash_map(&option_map)?; let parquet_options = ParquetWriterOptions::try_from(&table_config.parquet)?; @@ -271,16 +270,17 @@ mod tests { // for StatementOptions fn test_writeroptions_csv_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("csv.has_header".to_owned(), "true".to_owned()); - option_map.insert("csv.date_format".to_owned(), "123".to_owned()); - option_map.insert("csv.datetime_format".to_owned(), "123".to_owned()); - option_map.insert("csv.timestamp_format".to_owned(), "2.0".to_owned()); - option_map.insert("csv.time_format".to_owned(), "123".to_owned()); - option_map.insert("csv.null_value".to_owned(), "123".to_owned()); - option_map.insert("csv.compression".to_owned(), "gzip".to_owned()); - option_map.insert("csv.delimiter".to_owned(), ";".to_owned()); + option_map.insert("format.has_header".to_owned(), "true".to_owned()); + option_map.insert("format.date_format".to_owned(), "123".to_owned()); + option_map.insert("format.datetime_format".to_owned(), "123".to_owned()); + option_map.insert("format.timestamp_format".to_owned(), "2.0".to_owned()); + option_map.insert("format.time_format".to_owned(), "123".to_owned()); + option_map.insert("format.null_value".to_owned(), "123".to_owned()); + option_map.insert("format.compression".to_owned(), "gzip".to_owned()); + option_map.insert("format.delimiter".to_owned(), ";".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::CSV); table_config.alter_with_string_hash_map(&option_map)?; let csv_options = CsvWriterOptions::try_from(&table_config.csv)?; @@ -299,9 +299,10 @@ mod tests { // for StatementOptions fn test_writeroptions_json_from_statement_options() -> Result<()> { let mut option_map: HashMap = HashMap::new(); - option_map.insert("json.compression".to_owned(), "gzip".to_owned()); + option_map.insert("format.compression".to_owned(), "gzip".to_owned()); let mut table_config = TableOptions::new(); + table_config.set_file_format(FileType::JSON); table_config.alter_with_string_hash_map(&option_map)?; let json_options = 
JsonWriterOptions::try_from(&table_config.json)?; diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 25830401571d..eea5fc1127ce 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1151,8 +1151,8 @@ impl DataFrame { "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), )); } - let table_options = self.session_state.default_table_options(); - let props = writer_options.unwrap_or_else(|| table_options.csv.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().csv); let plan = LogicalPlanBuilder::copy_to( self.plan, @@ -1200,9 +1200,8 @@ impl DataFrame { )); } - let table_options = self.session_state.default_table_options(); - - let props = writer_options.unwrap_or_else(|| table_options.json.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().json); let plan = LogicalPlanBuilder::copy_to( self.plan, diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index f4e8c9dfcd6f..e3f606e322fe 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -57,9 +57,8 @@ impl DataFrame { )); } - let table_options = self.session_state.default_table_options(); - - let props = writer_options.unwrap_or_else(|| table_options.parquet.clone()); + let props = writer_options + .unwrap_or_else(|| self.session_state.default_table_options().parquet); let plan = LogicalPlanBuilder::copy_to( self.plan, diff --git a/datafusion/core/src/datasource/file_format/options.rs b/datafusion/core/src/datasource/file_format/options.rs index f66683c311c1..f5bd72495d66 100644 --- a/datafusion/core/src/datasource/file_format/options.rs +++ b/datafusion/core/src/datasource/file_format/options.rs @@ -461,7 +461,7 @@ pub trait ReadOptions<'a> { return Ok(Arc::new(s.to_owned())); } - self.to_listing_options(config, state.default_table_options().clone()) + self.to_listing_options(config, state.default_table_options()) .infer_schema(&state, &table_path) .await } diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 2a2551236e1b..c1e337b5c44a 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -118,7 +118,7 @@ impl ListingTableConfig { } } - fn infer_format(path: &str) -> Result<(Arc, String)> { + fn infer_file_type(path: &str) -> Result<(FileType, String)> { let err_msg = format!("Unable to infer file type from path: {path}"); let mut exts = path.rsplit('.'); @@ -139,20 +139,7 @@ impl ListingTableConfig { .get_ext_with_compression(file_compression_type.to_owned()) .map_err(|_| DataFusionError::Internal(err_msg))?; - let file_format: Arc = match file_type { - FileType::ARROW => Arc::new(ArrowFormat), - FileType::AVRO => Arc::new(AvroFormat), - FileType::CSV => Arc::new( - CsvFormat::default().with_file_compression_type(file_compression_type), - ), - FileType::JSON => Arc::new( - JsonFormat::default().with_file_compression_type(file_compression_type), - ), - #[cfg(feature = "parquet")] - FileType::PARQUET => Arc::new(ParquetFormat::default()), - }; - - Ok((file_format, ext)) + Ok((file_type, ext)) } /// Infer `ListingOptions` based on `table_path` suffix. 
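As an aside on the hunk above: the following is a minimal, self-contained sketch of the extension-based detection that `infer_file_type` performs. The `FileType` enum and the compression list here are simplified stand-ins for the DataFusion types, not the actual implementation.

#[derive(Debug, Clone, Copy, PartialEq)]
enum FileType {
    Csv,
    Json,
    Parquet,
}

fn infer_file_type(path: &str) -> Option<(FileType, String)> {
    // Walk extensions from the right: "data.csv.gz" yields "gz" first, then
    // "csv", and the compression suffix is kept in the returned extension.
    let mut exts = path.rsplit('.');
    let mut ext = exts.next()?.to_lowercase();
    let compressed = matches!(ext.as_str(), "gz" | "bz2" | "xz" | "zst");
    if compressed {
        let inner = exts.next()?.to_lowercase();
        ext = format!("{inner}.{ext}");
    }
    let file_type = match ext.split('.').next()? {
        "csv" => FileType::Csv,
        "json" | "ndjson" => FileType::Json,
        // Parquet handles compression internally, so an outer ".gz" on a
        // parquet file is rejected here.
        "parquet" if !compressed => FileType::Parquet,
        _ => return None,
    };
    Some((file_type, ext))
}

fn main() {
    assert_eq!(
        infer_file_type("s3://bucket/data.csv.gz"),
        Some((FileType::Csv, "csv.gz".to_string()))
    );
    assert_eq!(
        infer_file_type("/tmp/part-0.parquet"),
        Some((FileType::Parquet, "parquet".to_string()))
    );
    assert_eq!(infer_file_type("notes.txt"), None);
}

Splitting from the right lets a compressed extension such as "csv.gz" keep its compression suffix while still resolving to the CSV format, mirroring how the real code pairs a `FileType` with the matched extension string.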
@@ -173,10 +160,27 @@ impl ListingTableConfig { .await .ok_or_else(|| DataFusionError::Internal("No files for table".into()))??; - let (format, file_extension) = - ListingTableConfig::infer_format(file.location.as_ref())?; + let (file_type, file_extension) = + ListingTableConfig::infer_file_type(file.location.as_ref())?; + + let mut table_options = state.default_table_options(); + table_options.set_file_format(file_type.clone()); + let file_format: Arc = match file_type { + FileType::CSV => { + Arc::new(CsvFormat::default().with_options(table_options.csv)) + } + #[cfg(feature = "parquet")] + FileType::PARQUET => { + Arc::new(ParquetFormat::default().with_options(table_options.parquet)) + } + FileType::AVRO => Arc::new(AvroFormat), + FileType::JSON => { + Arc::new(JsonFormat::default().with_options(table_options.json)) + } + FileType::ARROW => Arc::new(ArrowFormat), + }; - let listing_options = ListingOptions::new(format) + let listing_options = ListingOptions::new(file_format) .with_file_extension(file_extension) .with_target_partitions(state.config().target_partitions()); diff --git a/datafusion/core/src/datasource/listing_table_factory.rs b/datafusion/core/src/datasource/listing_table_factory.rs index 4e126bbba9f9..b616e0181cfc 100644 --- a/datafusion/core/src/datasource/listing_table_factory.rs +++ b/datafusion/core/src/datasource/listing_table_factory.rs @@ -34,7 +34,6 @@ use crate::datasource::TableProvider; use crate::execution::context::SessionState; use arrow::datatypes::{DataType, SchemaRef}; -use datafusion_common::config::TableOptions; use datafusion_common::{arrow_datafusion_err, DataFusionError, FileType}; use datafusion_expr::CreateExternalTable; @@ -58,8 +57,7 @@ impl TableProviderFactory for ListingTableFactory { state: &SessionState, cmd: &CreateExternalTable, ) -> datafusion_common::Result> { - let mut table_options = - TableOptions::default_from_session_config(state.config_options()); + let mut table_options = state.default_table_options(); let file_type = FileType::from_str(cmd.file_type.as_str()).map_err(|_| { DataFusionError::Execution(format!("Unknown FileType {}", cmd.file_type)) })?; @@ -227,7 +225,7 @@ mod tests { let name = OwnedTableReference::bare("foo".to_string()); let mut options = HashMap::new(); - options.insert("csv.schema_infer_max_rec".to_owned(), "1000".to_owned()); + options.insert("format.schema_infer_max_rec".to_owned(), "1000".to_owned()); let cmd = CreateExternalTable { name, location: csv_file.path().to_str().unwrap().to_string(), diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 1ac7da465216..116e45c8c130 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -384,9 +384,9 @@ impl SessionContext { self.state.read().config.clone() } - /// Return a copied version of config for this Session + /// Return a copied version of table options for this Session pub fn copied_table_options(&self) -> TableOptions { - self.state.read().default_table_options().clone() + self.state.read().default_table_options() } /// Creates a [`DataFrame`] from SQL query text. @@ -1750,11 +1750,7 @@ impl SessionState { .0 .insert(ObjectName(vec![Ident::from(table.name.as_str())])); } - DFStatement::CopyTo(CopyToStatement { - source, - target: _, - options: _, - }) => match source { + DFStatement::CopyTo(CopyToStatement { source, .. 
}) => match source { CopyToSource::Relation(table_name) => { visitor.insert(table_name); } @@ -1963,8 +1959,9 @@ impl SessionState { } /// return the TableOptions options with its extensions - pub fn default_table_options(&self) -> &TableOptions { - &self.table_option_namespace + pub fn default_table_options(&self) -> TableOptions { + self.table_option_namespace + .combine_with_session_config(self.config_options()) } /// Get a new TaskContext to run in this session diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 96f5e1c3ffd3..ee581ca64214 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -595,7 +595,7 @@ impl DefaultPhysicalPlanner { table_partition_cols, overwrite: false, }; - let mut table_options = session_state.default_table_options().clone(); + let mut table_options = session_state.default_table_options(); let sink_format: Arc = match format_options { FormatOptions::CSV(options) => { table_options.csv = options.clone(); diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs index 7a466a666d8d..8113d799a184 100644 --- a/datafusion/core/src/test_util/parquet.rs +++ b/datafusion/core/src/test_util/parquet.rs @@ -165,7 +165,7 @@ impl TestParquetFile { // run coercion on the filters to coerce types etc. let props = ExecutionProps::new(); let context = SimplifyContext::new(&props).with_schema(df_schema.clone()); - let parquet_options = ctx.state().default_table_options().parquet.clone(); + let parquet_options = ctx.copied_table_options().parquet; if let Some(filter) = maybe_filter { let simplifier = ExprSimplifier::new(context); let filter = simplifier.coerce(filter, df_schema.clone()).unwrap(); diff --git a/datafusion/core/tests/sql/sql_api.rs b/datafusion/core/tests/sql/sql_api.rs index d7adc9611b2f..b3a819fbc331 100644 --- a/datafusion/core/tests/sql/sql_api.rs +++ b/datafusion/core/tests/sql/sql_api.rs @@ -16,6 +16,7 @@ // under the License. 
use datafusion::prelude::*; + use tempfile::TempDir; #[tokio::test] @@ -27,7 +28,7 @@ async fn unsupported_ddl_returns_error() { // disallow ddl let options = SQLOptions::new().with_allow_ddl(false); - let sql = "create view test_view as select * from test"; + let sql = "CREATE VIEW test_view AS SELECT * FROM test"; let df = ctx.sql_with_options(sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -46,7 +47,7 @@ async fn unsupported_dml_returns_error() { let options = SQLOptions::new().with_allow_dml(false); - let sql = "insert into test values (1)"; + let sql = "INSERT INTO test VALUES (1)"; let df = ctx.sql_with_options(sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -67,7 +68,10 @@ async fn unsupported_copy_returns_error() { let options = SQLOptions::new().with_allow_dml(false); - let sql = format!("copy (values(1)) to '{}'", tmpfile.to_string_lossy()); + let sql = format!( + "COPY (values(1)) TO '{}' STORED AS parquet", + tmpfile.to_string_lossy() + ); let df = ctx.sql_with_options(&sql, options).await; assert_eq!( df.unwrap_err().strip_backtrace(), @@ -106,7 +110,7 @@ async fn ddl_can_not_be_planned_by_session_state() { let state = ctx.state(); // can not create a logical plan for catalog DDL - let sql = "drop table test"; + let sql = "DROP TABLE test"; let plan = state.create_logical_plan(sql).await.unwrap(); let physical_plan = state.create_physical_plan(&plan).await; assert_eq!( diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 93de560dbee5..3c43f100750f 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -35,7 +35,7 @@ use datafusion_common::config::{FormatOptions, TableOptions}; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ internal_err, not_impl_err, plan_err, DFField, DFSchema, DFSchemaRef, - DataFusionError, Result, ScalarValue, + DataFusionError, FileType, Result, ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -314,10 +314,9 @@ async fn roundtrip_logical_plan_copy_to_sql_options() -> Result<()> { let ctx = SessionContext::new(); let input = create_csv_scan(&ctx).await?; - - let mut table_options = - TableOptions::default_from_session_config(ctx.state().config_options()); - table_options.set("csv.delimiter", ";")?; + let mut table_options = ctx.copied_table_options(); + table_options.set_file_format(FileType::CSV); + table_options.set("format.delimiter", ";")?; let plan = LogicalPlan::Copy(CopyTo { input: Arc::new(input), diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index effc1d096cfd..a5d7970495c5 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -17,21 +17,20 @@ //! 
[`DFParser`]: DataFusion SQL Parser based on [`sqlparser`] +use std::collections::{HashMap, VecDeque}; +use std::fmt; +use std::str::FromStr; + use datafusion_common::parsers::CompressionTypeVariant; -use sqlparser::ast::{OrderByExpr, Query, Value}; -use sqlparser::tokenizer::Word; use sqlparser::{ ast::{ - ColumnDef, ColumnOptionDef, ObjectName, Statement as SQLStatement, - TableConstraint, + ColumnDef, ColumnOptionDef, ObjectName, OrderByExpr, Query, + Statement as SQLStatement, TableConstraint, Value, }, dialect::{keywords::Keyword, Dialect, GenericDialect}, parser::{Parser, ParserError}, - tokenizer::{Token, TokenWithLocation, Tokenizer}, + tokenizer::{Token, TokenWithLocation, Tokenizer, Word}, }; -use std::collections::VecDeque; -use std::fmt; -use std::{collections::HashMap, str::FromStr}; // Use `Parser::expected` instead, if possible macro_rules! parser_err { @@ -102,6 +101,12 @@ pub struct CopyToStatement { pub source: CopyToSource, /// The URL to where the data is heading pub target: String, + /// Partition keys + pub partitioned_by: Vec, + /// Indicates whether there is a header row (e.g. CSV) + pub has_header: bool, + /// File type (Parquet, NDJSON, CSV etc.) + pub stored_as: Option, /// Target specific options pub options: Vec<(String, Value)>, } @@ -111,15 +116,27 @@ impl fmt::Display for CopyToStatement { let Self { source, target, + partitioned_by, + stored_as, options, + .. } = self; write!(f, "COPY {source} TO {target}")?; + if let Some(file_type) = stored_as { + write!(f, " STORED AS {}", file_type)?; + } + if !partitioned_by.is_empty() { + write!(f, " PARTITIONED BY ({})", partitioned_by.join(", "))?; + } + + if self.has_header { + write!(f, " WITH HEADER ROW")?; + } if !options.is_empty() { let opts: Vec<_> = options.iter().map(|(k, v)| format!("{k} {v}")).collect(); - // print them in sorted order - write!(f, " ({})", opts.join(", "))?; + write!(f, " OPTIONS ({})", opts.join(", "))?; } Ok(()) @@ -243,6 +260,15 @@ impl fmt::Display for Statement { } } +fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { + if field.is_some() { + return Err(ParserError::ParserError(format!( + "{name} specified more than once", + ))); + } + Ok(()) +} + /// Datafusion SQL Parser based on [`sqlparser`] /// /// Parses DataFusion's SQL dialect, often delegating to [`sqlparser`]'s [`Parser`]. @@ -370,21 +396,79 @@ impl<'a> DFParser<'a> { CopyToSource::Relation(table_name) }; - self.parser.expect_keyword(Keyword::TO)?; + #[derive(Default)] + struct Builder { + stored_as: Option, + target: Option, + partitioned_by: Option>, + has_header: Option, + options: Option>, + } - let target = self.parser.parse_literal_string()?; + let mut builder = Builder::default(); - // check for options in parens - let options = if self.parser.peek_token().token == Token::LParen { - self.parse_value_options()? 
- } else { - vec![] + loop { + if let Some(keyword) = self.parser.parse_one_of_keywords(&[ + Keyword::STORED, + Keyword::TO, + Keyword::PARTITIONED, + Keyword::OPTIONS, + Keyword::WITH, + ]) { + match keyword { + Keyword::STORED => { + self.parser.expect_keyword(Keyword::AS)?; + ensure_not_set(&builder.stored_as, "STORED AS")?; + builder.stored_as = Some(self.parse_file_format()?); + } + Keyword::TO => { + ensure_not_set(&builder.target, "TO")?; + builder.target = Some(self.parser.parse_literal_string()?); + } + Keyword::WITH => { + self.parser.expect_keyword(Keyword::HEADER)?; + self.parser.expect_keyword(Keyword::ROW)?; + ensure_not_set(&builder.has_header, "WITH HEADER ROW")?; + builder.has_header = Some(true); + } + Keyword::PARTITIONED => { + self.parser.expect_keyword(Keyword::BY)?; + ensure_not_set(&builder.partitioned_by, "PARTITIONED BY")?; + builder.partitioned_by = Some(self.parse_partitions()?); + } + Keyword::OPTIONS => { + ensure_not_set(&builder.options, "OPTIONS")?; + builder.options = Some(self.parse_value_options()?); + } + _ => { + unreachable!() + } + } + } else { + let token = self.parser.next_token(); + if token == Token::EOF || token == Token::SemiColon { + break; + } else { + return Err(ParserError::ParserError(format!( + "Unexpected token {token}" + ))); + } + } + } + + let Some(target) = builder.target else { + return Err(ParserError::ParserError( + "Missing TO clause in COPY statement".into(), + )); }; Ok(Statement::CopyTo(CopyToStatement { source, target, - options, + partitioned_by: builder.partitioned_by.unwrap_or(vec![]), + has_header: builder.has_header.unwrap_or(false), + stored_as: builder.stored_as, + options: builder.options.unwrap_or(vec![]), })) } @@ -624,15 +708,6 @@ impl<'a> DFParser<'a> { } let mut builder = Builder::default(); - fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { - if field.is_some() { - return Err(ParserError::ParserError(format!( - "{name} specified more than once", - ))); - } - Ok(()) - } - loop { if let Some(keyword) = self.parser.parse_one_of_keywords(&[ Keyword::STORED, @@ -1321,10 +1396,13 @@ mod tests { #[test] fn copy_to_table_to_table() -> Result<(), ParserError> { // positive case - let sql = "COPY foo TO bar"; + let sql = "COPY foo TO bar STORED AS CSV"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("CSV".to_owned()), options: vec![], }); @@ -1335,10 +1413,22 @@ mod tests { #[test] fn explain_copy_to_table_to_table() -> Result<(), ParserError> { let cases = vec![ - ("EXPLAIN COPY foo TO bar", false, false), - ("EXPLAIN ANALYZE COPY foo TO bar", true, false), - ("EXPLAIN VERBOSE COPY foo TO bar", false, true), - ("EXPLAIN ANALYZE VERBOSE COPY foo TO bar", true, true), + ("EXPLAIN COPY foo TO bar STORED AS PARQUET", false, false), + ( + "EXPLAIN ANALYZE COPY foo TO bar STORED AS PARQUET", + true, + false, + ), + ( + "EXPLAIN VERBOSE COPY foo TO bar STORED AS PARQUET", + false, + true, + ), + ( + "EXPLAIN ANALYZE VERBOSE COPY foo TO bar STORED AS PARQUET", + true, + true, + ), ]; for (sql, analyze, verbose) in cases { println!("sql: {sql}, analyze: {analyze}, verbose: {verbose}"); @@ -1346,6 +1436,9 @@ mod tests { let expected_copy = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("PARQUET".to_owned()), options: vec![], }); let expected = 
Statement::Explain(ExplainStatement { @@ -1375,10 +1468,13 @@ mod tests { panic!("Expected query, got {statement:?}"); }; - let sql = "COPY (SELECT 1) TO bar"; + let sql = "COPY (SELECT 1) TO bar STORED AS CSV WITH HEADER ROW"; let expected = Statement::CopyTo(CopyToStatement { source: CopyToSource::Query(query), target: "bar".to_string(), + partitioned_by: vec![], + has_header: true, + stored_as: Some("CSV".to_owned()), options: vec![], }); assert_eq!(verified_stmt(sql), expected); @@ -1387,10 +1483,31 @@ mod tests { #[test] fn copy_to_options() -> Result<(), ParserError> { - let sql = "COPY foo TO bar (row_group_size 55)"; + let sql = "COPY foo TO bar STORED AS CSV OPTIONS (row_group_size 55)"; + let expected = Statement::CopyTo(CopyToStatement { + source: object_name("foo"), + target: "bar".to_string(), + partitioned_by: vec![], + has_header: false, + stored_as: Some("CSV".to_owned()), + options: vec![( + "row_group_size".to_string(), + Value::Number("55".to_string(), false), + )], + }); + assert_eq!(verified_stmt(sql), expected); + Ok(()) + } + + #[test] + fn copy_to_partitioned_by() -> Result<(), ParserError> { + let sql = "COPY foo TO bar STORED AS CSV PARTITIONED BY (a) OPTIONS (row_group_size 55)"; let expected = Statement::CopyTo(CopyToStatement { source: object_name("foo"), target: "bar".to_string(), + partitioned_by: vec!["a".to_string()], + has_header: false, + stored_as: Some("CSV".to_owned()), options: vec![( "row_group_size".to_string(), Value::Number("55".to_string(), false), @@ -1404,24 +1521,24 @@ mod tests { fn copy_to_multi_options() -> Result<(), ParserError> { // order of options is preserved let sql = - "COPY foo TO bar (format parquet, row_group_size 55, compression snappy)"; + "COPY foo TO bar STORED AS parquet OPTIONS ('format.row_group_size' 55, 'format.compression' snappy)"; let expected_options = vec![ ( - "format".to_string(), - Value::UnQuotedString("parquet".to_string()), - ), - ( - "row_group_size".to_string(), + "format.row_group_size".to_string(), Value::Number("55".to_string(), false), ), ( - "compression".to_string(), + "format.compression".to_string(), Value::UnQuotedString("snappy".to_string()), ), ]; - let options = if let Statement::CopyTo(copy_to) = verified_stmt(sql) { + let mut statements = DFParser::parse_sql(sql).unwrap(); + assert_eq!(statements.len(), 1); + let only_statement = statements.pop_front().unwrap(); + + let options = if let Statement::CopyTo(copy_to) = only_statement { copy_to.options } else { panic!("Expected copy"); @@ -1460,7 +1577,10 @@ mod tests { } let only_statement = statements.pop_front().unwrap(); - assert_eq!(canonical, only_statement.to_string()); + assert_eq!( + canonical.to_uppercase(), + only_statement.to_string().to_uppercase() + ); only_statement } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 412c3b753ed5..e50aceb757df 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -813,20 +813,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn copy_to_plan(&self, statement: CopyToStatement) -> Result { // determine if source is table or query and handle accordingly let copy_source = statement.source; - let input = match copy_source { + let (input, input_schema, table_ref) = match copy_source { CopyToSource::Relation(object_name) => { - let table_ref = - self.object_name_to_table_reference(object_name.clone())?; - let table_source = self.context_provider.get_table_source(table_ref)?; - LogicalPlanBuilder::scan( - object_name_to_string(&object_name), 
- table_source, - None, - )? - .build()? + let table_name = object_name_to_string(&object_name); + let table_ref = self.object_name_to_table_reference(object_name)?; + let table_source = + self.context_provider.get_table_source(table_ref.clone())?; + let plan = + LogicalPlanBuilder::scan(table_name, table_source, None)?.build()?; + let input_schema = plan.schema().clone(); + (plan, input_schema, Some(table_ref)) } CopyToSource::Query(query) => { - self.query_to_plan(query, &mut PlannerContext::new())? + let plan = self.query_to_plan(query, &mut PlannerContext::new())?; + let input_schema = plan.schema().clone(); + (plan, input_schema, None) } }; @@ -852,8 +853,41 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { options.insert(key.to_lowercase(), value_string.to_lowercase()); } - let file_type = try_infer_file_type(&mut options, &statement.target)?; - let partition_by = take_partition_by(&mut options); + let file_type = if let Some(file_type) = statement.stored_as { + FileType::from_str(&file_type).map_err(|_| { + DataFusionError::Configuration(format!("Unknown FileType {}", file_type)) + })? + } else { + let e = || { + DataFusionError::Configuration( + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." + .to_string(), + ) + }; + // try to infer file format from file extension + let extension: &str = &Path::new(&statement.target) + .extension() + .ok_or_else(e)? + .to_str() + .ok_or_else(e)? + .to_lowercase(); + + FileType::from_str(extension).map_err(|e| { + DataFusionError::Configuration(format!( + "{}. Use STORED AS to define file format.", + e + )) + })? + }; + + let partition_by = statement + .partitioned_by + .iter() + .map(|col| input_schema.field_with_name(table_ref.as_ref(), col)) + .collect::>>()? + .into_iter() + .map(|f| f.name().to_owned()) + .collect(); Ok(LogicalPlan::Copy(CopyTo { input: Arc::new(input), @@ -1469,82 +1503,3 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .is_ok() } } - -/// Infers the file type for a given target based on provided options or file extension. -/// -/// This function tries to determine the file type based on the 'format' option present -/// in the provided options hashmap. If 'format' is not explicitly set, the function attempts -/// to infer the file type from the file extension of the target. It returns an error if neither -/// the format option is set nor the file extension can be determined or parsed. -/// -/// # Arguments -/// -/// * `options` - A mutable reference to a HashMap containing options where the file format -/// might be specified under the 'format' key. -/// * `target` - A string slice representing the path to the file for which the file type needs to be inferred. -/// -/// # Returns -/// -/// Returns `Result` which is Ok if the file type could be successfully inferred, -/// otherwise returns an error in case of failure to determine or parse the file format or extension. -/// -/// # Errors -/// -/// This function returns an error in two cases: -/// - If the 'format' option is not set and the file extension cannot be retrieved from `target`. -/// - If the file extension is found but cannot be converted into a valid string. 
-/// -pub fn try_infer_file_type( - options: &mut HashMap, - target: &str, -) -> Result { - let explicit_format = options.remove("format"); - let format = match explicit_format { - Some(s) => FileType::from_str(&s), - None => { - // try to infer file format from file extension - let extension: &str = &Path::new(target) - .extension() - .ok_or(DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension!" - .to_string(), - ))? - .to_str() - .ok_or(DataFusionError::Configuration( - "Format not explicitly set and failed to parse file extension!" - .to_string(), - ))? - .to_lowercase(); - - FileType::from_str(extension) - } - }?; - - Ok(format) -} - -/// Extracts and parses the 'partition_by' option from a provided options hashmap. -/// -/// This function looks for a 'partition_by' key in the options hashmap. If found, -/// it splits the value by commas, trims each resulting string, and replaces double -/// single quotes with a single quote. It returns a vector of partition column names. -/// -/// # Arguments -/// -/// * `options` - A mutable reference to a HashMap containing options where 'partition_by' -/// might be specified. -/// -/// # Returns -/// -/// Returns a `Vec` containing partition column names. If the 'partition_by' option -/// is not present, returns an empty vector. -pub fn take_partition_by(options: &mut HashMap) -> Vec { - let partition_by = options.remove("partition_by"); - match partition_by { - Some(part_cols) => part_cols - .split(',') - .map(|s| s.trim().replace("''", "'")) - .collect::>(), - None => vec![], - } -} diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index b6077353e5dd..6d335f1f8fc9 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -22,25 +22,23 @@ use std::{sync::Arc, vec}; use arrow_schema::TimeUnit::Nanosecond; use arrow_schema::*; -use datafusion_sql::planner::PlannerContext; -use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; -use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; - +use datafusion_common::config::ConfigOptions; use datafusion_common::{ - config::ConfigOptions, DataFusionError, Result, ScalarValue, TableReference, + plan_err, DFSchema, DataFusionError, ParamValues, Result, ScalarValue, TableReference, }; -use datafusion_common::{plan_err, DFSchema, ParamValues}; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, AggregateUDF, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, TableSource, Volatility, WindowUDF, }; +use datafusion_sql::unparser::{expr_to_sql, plan_to_sql}; use datafusion_sql::{ parser::DFParser, - planner::{ContextProvider, ParserOptions, SqlToRel}, + planner::{ContextProvider, ParserOptions, PlannerContext, SqlToRel}, }; use rstest::rstest; +use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; #[test] @@ -389,7 +387,7 @@ fn plan_rollback_transaction_chained() { #[test] fn plan_copy_to() { - let sql = "COPY test_decimal to 'output.csv'"; + let sql = "COPY test_decimal to 'output.csv' STORED AS CSV"; let plan = r#" CopyTo: format=csv output_url=output.csv options: () TableScan: test_decimal @@ -410,6 +408,18 @@ Explain quick_test(sql, plan); } +#[test] +fn plan_explain_copy_to_format() { + let sql = "EXPLAIN COPY test_decimal to 'output.tbl' STORED AS CSV"; + let plan = r#" +Explain + CopyTo: format=csv output_url=output.tbl options: () + TableScan: test_decimal + "# + .trim(); + quick_test(sql, 
plan); +} + #[test] fn plan_copy_to_query() { let sql = "COPY (select * from test_decimal limit 10) to 'output.csv'"; diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index df23a993ebce..4d4f596d0c60 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -21,13 +21,13 @@ create table source_table(col1 integer, col2 varchar) as values (1, 'Foo'), (2, # Copy to directory as multiple files query IT -COPY source_table TO 'test_files/scratch/copy/table/' (format parquet, 'parquet.compression' 'zstd(10)'); +COPY source_table TO 'test_files/scratch/copy/table/' STORED AS parquet OPTIONS ('format.compression' 'zstd(10)'); ---- 2 # Copy to directory as partitioned files query IT -COPY source_table TO 'test_files/scratch/copy/partitioned_table1/' (format parquet, 'parquet.compression' 'zstd(10)', partition_by 'col2'); +COPY source_table TO 'test_files/scratch/copy/partitioned_table1/' STORED AS parquet PARTITIONED BY (col2) OPTIONS ('format.compression' 'zstd(10)'); ---- 2 @@ -54,8 +54,8 @@ select * from validate_partitioned_parquet_bar order by col1; # Copy to directory as partitioned files query ITT -COPY (values (1, 'a', 'x'), (2, 'b', 'y'), (3, 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table2/' -(format parquet, partition_by 'column2, column3', 'parquet.compression' 'zstd(10)'); +COPY (values (1, 'a', 'x'), (2, 'b', 'y'), (3, 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table2/' STORED AS parquet PARTITIONED BY (column2, column3) +OPTIONS ('format.compression' 'zstd(10)'); ---- 3 @@ -82,8 +82,8 @@ select * from validate_partitioned_parquet_a_x order by column1; # Copy to directory as partitioned files query TTT -COPY (values ('1', 'a', 'x'), ('2', 'b', 'y'), ('3', 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table3/' -(format parquet, 'parquet.compression' 'zstd(10)', partition_by 'column1, column3'); +COPY (values ('1', 'a', 'x'), ('2', 'b', 'y'), ('3', 'c', 'z')) TO 'test_files/scratch/copy/partitioned_table3/' STORED AS parquet PARTITIONED BY (column1, column3) +OPTIONS ('format.compression' 'zstd(10)'); ---- 3 @@ -111,49 +111,52 @@ a statement ok create table test ("'test'" varchar, "'test2'" varchar, "'test3'" varchar); -query TTT -insert into test VALUES ('a', 'x', 'aa'), ('b','y', 'bb'), ('c', 'z', 'cc') ----- -3 - -query T -select "'test'" from test ----- -a -b -c - -# Note to place a single ' inside of a literal string escape by putting two '' -query TTT -copy test to 'test_files/scratch/copy/escape_quote' (format csv, partition_by '''test2'',''test3''') ----- -3 - -statement ok -CREATE EXTERNAL TABLE validate_partitioned_escape_quote STORED AS CSV -LOCATION 'test_files/scratch/copy/escape_quote/' PARTITIONED BY ("'test2'", "'test3'"); - +## Until the partition by parsing uses ColumnDef, this test is meaningless since it becomes an overfit. Even in +## CREATE EXTERNAL TABLE, there is a schema mismatch, this should be an issue. 
+# +#query TTT +#insert into test VALUES ('a', 'x', 'aa'), ('b','y', 'bb'), ('c', 'z', 'cc') +#---- +#3 +# +#query T +#select "'test'" from test +#---- +#a +#b +#c +# +# # Note to place a single ' inside of a literal string escape by putting two '' +#query TTT +#copy test to 'test_files/scratch/copy/escape_quote' STORED AS CSV; +#---- +#3 +# +#statement ok +#CREATE EXTERNAL TABLE validate_partitioned_escape_quote STORED AS CSV +#LOCATION 'test_files/scratch/copy/escape_quote/' PARTITIONED BY ("'test2'", "'test3'"); +# # This triggers a panic (index out of bounds) # https://github.com/apache/arrow-datafusion/issues/9269 #query #select * from validate_partitioned_escape_quote; query TT -EXPLAIN COPY source_table TO 'test_files/scratch/copy/table/' (format parquet, 'parquet.compression' 'zstd(10)'); +EXPLAIN COPY source_table TO 'test_files/scratch/copy/table/' STORED AS PARQUET OPTIONS ('format.compression' 'zstd(10)'); ---- logical_plan -CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: (parquet.compression zstd(10)) +CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: (format.compression zstd(10)) --TableScan: source_table projection=[col1, col2] physical_plan FileSinkExec: sink=ParquetSink(file_groups=[]) --MemoryExec: partitions=1, partition_sizes=[1] # Error case -query error DataFusion error: Invalid or Unsupported Configuration: Format not explicitly set and unable to get file extension! +query error DataFusion error: Invalid or Unsupported Configuration: Format not explicitly set and unable to get file extension! Use STORED AS to define file format. EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' query TT -EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' (format parquet) +EXPLAIN COPY source_table to 'test_files/scratch/copy/table/' STORED AS PARQUET ---- logical_plan CopyTo: format=parquet output_url=test_files/scratch/copy/table/ options: () @@ -164,7 +167,7 @@ FileSinkExec: sink=ParquetSink(file_groups=[]) # Copy more files to directory via query query IT -COPY (select * from source_table UNION ALL select * from source_table) to 'test_files/scratch/copy/table/' (format parquet); +COPY (select * from source_table UNION ALL select * from source_table) to 'test_files/scratch/copy/table/' STORED AS PARQUET; ---- 4 @@ -185,7 +188,7 @@ select * from validate_parquet; query ? copy (values (struct(timestamp '2021-01-01 01:00:01', 1)), (struct(timestamp '2022-01-01 01:00:01', 2)), (struct(timestamp '2023-01-03 01:00:01', 3)), (struct(timestamp '2024-01-01 01:00:01', 4))) -to 'test_files/scratch/copy/table_nested2/' (format parquet); +to 'test_files/scratch/copy/table_nested2/' STORED AS PARQUET; ---- 4 @@ -204,7 +207,7 @@ query ?? COPY (values (struct ('foo', (struct ('foo', make_array(struct('a',1), struct('b',2))))), make_array(timestamp '2023-01-01 01:00:01',timestamp '2023-01-01 01:00:01')), (struct('bar', (struct ('foo', make_array(struct('aa',10), struct('bb',20))))), make_array(timestamp '2024-01-01 01:00:01', timestamp '2024-01-01 01:00:01'))) -to 'test_files/scratch/copy/table_nested/' (format parquet); +to 'test_files/scratch/copy/table_nested/' STORED AS PARQUET; ---- 2 @@ -221,7 +224,7 @@ select * from validate_parquet_nested; query ? copy (values ([struct('foo', 1), struct('bar', 2)])) to 'test_files/scratch/copy/array_of_struct/' -(format parquet); +STORED AS PARQUET; ---- 1 @@ -236,8 +239,7 @@ select * from validate_array_of_struct; query ? 
copy (values (struct('foo', [1,2,3], struct('bar', [2,3,4])))) -to 'test_files/scratch/copy/struct_with_array/' -(format parquet); +to 'test_files/scratch/copy/struct_with_array/' STORED AS PARQUET; ---- 1 @@ -255,31 +257,32 @@ select * from validate_struct_with_array; query IT COPY source_table TO 'test_files/scratch/copy/table_with_options/' -(format parquet, -'parquet.compression' snappy, -'parquet.compression::col1' 'zstd(5)', -'parquet.compression::col2' snappy, -'parquet.max_row_group_size' 12345, -'parquet.data_pagesize_limit' 1234, -'parquet.write_batch_size' 1234, -'parquet.writer_version' 2.0, -'parquet.dictionary_page_size_limit' 123, -'parquet.created_by' 'DF copy.slt', -'parquet.column_index_truncate_length' 123, -'parquet.data_page_row_count_limit' 1234, -'parquet.bloom_filter_enabled' true, -'parquet.bloom_filter_enabled::col1' false, -'parquet.bloom_filter_fpp::col2' 0.456, -'parquet.bloom_filter_ndv::col2' 456, -'parquet.encoding' plain, -'parquet.encoding::col1' DELTA_BINARY_PACKED, -'parquet.dictionary_enabled::col2' true, -'parquet.dictionary_enabled' false, -'parquet.statistics_enabled' page, -'parquet.statistics_enabled::col2' none, -'parquet.max_statistics_size' 123, -'parquet.bloom_filter_fpp' 0.001, -'parquet.bloom_filter_ndv' 100 +STORED AS PARQUET +OPTIONS ( +'format.compression' snappy, +'format.compression::col1' 'zstd(5)', +'format.compression::col2' snappy, +'format.max_row_group_size' 12345, +'format.data_pagesize_limit' 1234, +'format.write_batch_size' 1234, +'format.writer_version' 2.0, +'format.dictionary_page_size_limit' 123, +'format.created_by' 'DF copy.slt', +'format.column_index_truncate_length' 123, +'format.data_page_row_count_limit' 1234, +'format.bloom_filter_enabled' true, +'format.bloom_filter_enabled::col1' false, +'format.bloom_filter_fpp::col2' 0.456, +'format.bloom_filter_ndv::col2' 456, +'format.encoding' plain, +'format.encoding::col1' DELTA_BINARY_PACKED, +'format.dictionary_enabled::col2' true, +'format.dictionary_enabled' false, +'format.statistics_enabled' page, +'format.statistics_enabled::col2' none, +'format.max_statistics_size' 123, +'format.bloom_filter_fpp' 0.001, +'format.bloom_filter_ndv' 100 ) ---- 2 @@ -312,7 +315,7 @@ select * from validate_parquet_single; # copy from table to folder of compressed json files query IT -COPY source_table to 'test_files/scratch/copy/table_json_gz' (format json, 'json.compression' gzip); +COPY source_table to 'test_files/scratch/copy/table_json_gz' STORED AS JSON OPTIONS ('format.compression' gzip); ---- 2 @@ -328,7 +331,7 @@ select * from validate_json_gz; # copy from table to folder of compressed csv files query IT -COPY source_table to 'test_files/scratch/copy/table_csv' (format csv, 'csv.has_header' false, 'csv.compression' gzip); +COPY source_table to 'test_files/scratch/copy/table_csv' STORED AS CSV OPTIONS ('format.has_header' false, 'format.compression' gzip); ---- 2 @@ -360,7 +363,7 @@ select * from validate_single_csv; # Copy from table to folder of json query IT -COPY source_table to 'test_files/scratch/copy/table_json' (format json); +COPY source_table to 'test_files/scratch/copy/table_json' STORED AS JSON; ---- 2 @@ -376,7 +379,7 @@ select * from validate_json; # Copy from table to single json file query IT -COPY source_table to 'test_files/scratch/copy/table.json'; +COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON ; ---- 2 @@ -394,12 +397,12 @@ select * from validate_single_json; query IT COPY source_table to 'test_files/scratch/copy/table_csv_with_options' 
-(format csv, -'csv.has_header' false, -'csv.compression' uncompressed, -'csv.datetime_format' '%FT%H:%M:%S.%9f', -'csv.delimiter' ';', -'csv.null_value' 'NULLVAL'); +STORED AS CSV OPTIONS ( +'format.has_header' false, +'format.compression' uncompressed, +'format.datetime_format' '%FT%H:%M:%S.%9f', +'format.delimiter' ';', +'format.null_value' 'NULLVAL'); ---- 2 @@ -417,7 +420,7 @@ select * from validate_csv_with_options; # Copy from table to single arrow file query IT -COPY source_table to 'test_files/scratch/copy/table.arrow'; +COPY source_table to 'test_files/scratch/copy/table.arrow' STORED AS ARROW; ---- 2 @@ -437,7 +440,7 @@ select * from validate_arrow_file; query T? COPY (values ('c', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('d', arrow_cast('bar', 'Dictionary(Int32, Utf8)'))) -to 'test_files/scratch/copy/table_dict.arrow'; +to 'test_files/scratch/copy/table_dict.arrow' STORED AS ARROW; ---- 2 @@ -456,7 +459,7 @@ d bar # Copy from table to folder of json query IT -COPY source_table to 'test_files/scratch/copy/table_arrow' (format arrow); +COPY source_table to 'test_files/scratch/copy/table_arrow' STORED AS ARROW; ---- 2 @@ -475,12 +478,12 @@ select * from validate_arrow; # Copy from table with options query error DataFusion error: Invalid or Unsupported Configuration: Config value "row_group_size" not found on JsonOptions -COPY source_table to 'test_files/scratch/copy/table.json' ('json.row_group_size' 55); +COPY source_table to 'test_files/scratch/copy/table.json' STORED AS JSON OPTIONS ('format.row_group_size' 55); # Incomplete statement query error DataFusion error: SQL error: ParserError\("Expected \), found: EOF"\) COPY (select col2, sum(col1) from source_table # Copy from table with non literal -query error DataFusion error: SQL error: ParserError\("Expected ',' or '\)' after option definition, found: \+"\) +query error DataFusion error: SQL error: ParserError\("Unexpected token \("\) COPY source_table to '/tmp/table.parquet' (row_group_size 55 + 102); diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index 3b85dd9e986f..c4a26a5e227d 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -101,8 +101,8 @@ statement error DataFusion error: SQL error: ParserError\("Unexpected token FOOB CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV FOOBAR BARBAR BARFOO LOCATION 'foo.csv'; # Conflicting options -statement error DataFusion error: Invalid or Unsupported Configuration: Key "parquet.column_index_truncate_length" is not applicable for CSV format +statement error DataFusion error: Invalid or Unsupported Configuration: Config value "column_index_truncate_length" not found on CsvOptions CREATE EXTERNAL TABLE csv_table (column1 int) STORED AS CSV LOCATION 'foo.csv' -OPTIONS ('csv.delimiter' ';', 'parquet.column_index_truncate_length' '123') +OPTIONS ('format.delimiter' ';', 'format.column_index_truncate_length' '123') diff --git a/datafusion/sqllogictest/test_files/csv_files.slt b/datafusion/sqllogictest/test_files/csv_files.slt index 7b299c0cf143..ab6847afb6a5 100644 --- a/datafusion/sqllogictest/test_files/csv_files.slt +++ b/datafusion/sqllogictest/test_files/csv_files.slt @@ -23,7 +23,7 @@ c2 VARCHAR ) STORED AS CSV WITH HEADER ROW DELIMITER ',' -OPTIONS ('csv.quote' '~') +OPTIONS ('format.quote' '~') LOCATION '../core/tests/data/quote.csv'; statement ok @@ -33,7 +33,7 @@ c2 VARCHAR ) STORED AS CSV WITH 
HEADER ROW DELIMITER ',' -OPTIONS ('csv.escape' '\') +OPTIONS ('format.escape' '\') LOCATION '../core/tests/data/escape.csv'; query TT @@ -71,7 +71,7 @@ c2 VARCHAR ) STORED AS CSV WITH HEADER ROW DELIMITER ',' -OPTIONS ('csv.escape' '"') +OPTIONS ('format.escape' '"') LOCATION '../core/tests/data/escape.csv'; # TODO: Validate this with better data. @@ -117,14 +117,14 @@ CREATE TABLE src_table_2 ( query ITII COPY src_table_1 TO 'test_files/scratch/csv_files/csv_partitions/1.csv' -(FORMAT CSV); +STORED AS CSV; ---- 4 query ITII COPY src_table_2 TO 'test_files/scratch/csv_files/csv_partitions/2.csv' -(FORMAT CSV); +STORED AS CSV; ---- 4 diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 3d9f8ff3ad2c..869462b4722a 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -4506,28 +4506,28 @@ CREATE TABLE src_table ( query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/0.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/1.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/2.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 query PI COPY (SELECT * FROM src_table) TO 'test_files/scratch/group_by/timestamp_table/3.csv' -(FORMAT CSV); +STORED AS CSV; ---- 10 diff --git a/datafusion/sqllogictest/test_files/parquet.slt b/datafusion/sqllogictest/test_files/parquet.slt index b7cd1243cb0f..3cc52666d533 100644 --- a/datafusion/sqllogictest/test_files/parquet.slt +++ b/datafusion/sqllogictest/test_files/parquet.slt @@ -45,7 +45,7 @@ CREATE TABLE src_table ( query ITID COPY (SELECT * FROM src_table LIMIT 3) TO 'test_files/scratch/parquet/test_table/0.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -53,7 +53,7 @@ TO 'test_files/scratch/parquet/test_table/0.parquet' query ITID COPY (SELECT * FROM src_table WHERE int_col > 3 LIMIT 3) TO 'test_files/scratch/parquet/test_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -128,7 +128,7 @@ SortPreservingMergeExec: [string_col@1 ASC NULLS LAST,int_col@0 ASC NULLS LAST] query ITID COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) TO 'test_files/scratch/parquet/test_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 @@ -281,7 +281,7 @@ LIMIT 10; query ITID COPY (SELECT * FROM src_table WHERE int_col > 6 LIMIT 3) TO 'test_files/scratch/parquet/test_table/subdir/3.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ---- 3 diff --git a/datafusion/sqllogictest/test_files/repartition.slt b/datafusion/sqllogictest/test_files/repartition.slt index 391a6739b060..594c52f12d75 100644 --- a/datafusion/sqllogictest/test_files/repartition.slt +++ b/datafusion/sqllogictest/test_files/repartition.slt @@ -25,7 +25,7 @@ set datafusion.execution.target_partitions = 4; statement ok COPY (VALUES (1, 2), (2, 5), (3, 2), (4, 5), (5, 0)) TO 'test_files/scratch/repartition/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; statement ok CREATE EXTERNAL TABLE parquet_table(column1 int, column2 int) diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index 15fe670a454c..fe0f6c1e8139 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -35,7 +35,7 @@ set 
datafusion.optimizer.repartition_file_min_size = 1; # Note filename 2.parquet to test sorting (on local file systems it is often listed before 1.parquet) statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; statement ok CREATE EXTERNAL TABLE parquet_table(column1 int) @@ -86,7 +86,7 @@ set datafusion.optimizer.enable_round_robin_repartition = true; # create a second parquet file statement ok COPY (VALUES (100), (200)) TO 'test_files/scratch/repartition_scan/parquet_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; ## Still expect to see the scan read the file as "4" groups with even sizes. One group should read ## parts of both files. @@ -158,7 +158,7 @@ DROP TABLE parquet_table_with_order; # create a single csv file statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/csv_table/1.csv' -(FORMAT csv, 'csv.has_header' true); +STORED AS CSV WITH HEADER ROW; statement ok CREATE EXTERNAL TABLE csv_table(column1 int) @@ -202,7 +202,7 @@ DROP TABLE csv_table; # create a single json file statement ok COPY (VALUES (1), (2), (3), (4), (5)) TO 'test_files/scratch/repartition_scan/json_table/1.json' -(FORMAT json); +STORED AS JSON; statement ok CREATE EXTERNAL TABLE json_table (column1 int) diff --git a/datafusion/sqllogictest/test_files/schema_evolution.slt b/datafusion/sqllogictest/test_files/schema_evolution.slt index aee0e97edc1e..5572c4a5ffef 100644 --- a/datafusion/sqllogictest/test_files/schema_evolution.slt +++ b/datafusion/sqllogictest/test_files/schema_evolution.slt @@ -31,7 +31,7 @@ COPY ( SELECT column1 as a, column2 as b FROM ( VALUES ('foo', 1), ('foo', 2), ('foo', 3) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/1.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File2 has only b @@ -40,7 +40,7 @@ COPY ( SELECT column1 as b FROM ( VALUES (10) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/2.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File3 has a column from 'z' which does not appear in the table # but also values from a which do appear in the table @@ -49,7 +49,7 @@ COPY ( SELECT column1 as z, column2 as a FROM ( VALUES ('bar', 'foo'), ('blarg', 'foo') ) ) TO 'test_files/scratch/schema_evolution/parquet_table/3.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # File4 has data for b and a (reversed) and d statement ok @@ -57,7 +57,7 @@ COPY ( SELECT column1 as b, column2 as a, column3 as c FROM ( VALUES (100, 'foo', 10.5), (200, 'foo', 12.6), (300, 'bzz', 13.7) ) ) TO 'test_files/scratch/schema_evolution/parquet_table/4.parquet' -(FORMAT PARQUET); +STORED AS PARQUET; # The logical distribution of `a`, `b` and `c` in the files is like this: # diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md index 405e77a21b26..b9614bb8f929 100644 --- a/docs/source/user-guide/sql/dml.md +++ b/docs/source/user-guide/sql/dml.md @@ -49,7 +49,7 @@ Copy the contents of `source_table` to one or more Parquet formatted files in the `dir_name` directory: ```sql -> COPY source_table TO 'dir_name' (FORMAT parquet); +> COPY source_table TO 'dir_name' STORED AS PARQUET; +-------+ | count | +-------+ From fa7ca27c15328247dbf98b2f8773c19398b8a745 Mon Sep 17 00:00:00 2001 From: Chunchun Ye <14298407+appletreeisyellow@users.noreply.github.com> Date: Mon, 18 Mar 2024 17:53:46 -0500 Subject: [PATCH 011/117] Support "A column is known to be entirely NULL" in `PruningPredicate` (#9223) * feat: add row_counts() to 
PruningStatistics trait * chore: remove comments * feat(pruning): add predicate rewrite for `CASE WHEN x_null_count = x_row_count THEN false ELSE ... END` * chore: clippy and update pruning predicates in tests * chore(pruning): fix data type and column expression for null and row counts chore: fix pruning_predicate in slt tests chore: clippy * doc: add examples in doc * chore: update comments * docs: use feedback Co-authored-by: Andrew Lamb docs: take more feedback * test: add test * docs: update comments * docs: update comments to put rewritten predicate first --- datafusion-examples/examples/pruning.rs | 5 + .../datasource/physical_plan/parquet/mod.rs | 2 +- .../physical_plan/parquet/page_filter.rs | 4 + .../physical_plan/parquet/row_groups.rs | 8 + .../core/src/physical_optimizer/pruning.rs | 516 ++++++++++++++++-- .../test_files/repartition_scan.slt | 8 +- 6 files changed, 492 insertions(+), 51 deletions(-) diff --git a/datafusion-examples/examples/pruning.rs b/datafusion-examples/examples/pruning.rs index 1d84fc2d1e0a..3fa35049a8da 100644 --- a/datafusion-examples/examples/pruning.rs +++ b/datafusion-examples/examples/pruning.rs @@ -163,6 +163,11 @@ impl PruningStatistics for MyCatalog { None } + fn row_counts(&self, _column: &Column) -> Option { + // In this example, we know nothing about the number of rows in each file + None + } + fn contained( &self, _column: &Column, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 2cfbb578da66..a2e645cf3e72 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -1870,7 +1870,7 @@ mod tests { assert_contains!( &display, - "pruning_predicate=c1_min@0 != bar OR bar != c1_max@1" + "pruning_predicate=CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != bar OR bar != c1_max@1 END" ); assert_contains!(&display, r#"predicate=c1@0 != bar"#); diff --git a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs index 064a8e1fff33..c7706f3458d0 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/page_filter.rs @@ -547,6 +547,10 @@ impl<'a> PruningStatistics for PagesPruningStatistics<'a> { } } + fn row_counts(&self, _column: &datafusion_common::Column) -> Option { + None + } + fn contained( &self, _column: &datafusion_common::Column, diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 1a84f52a33fd..a0bb5ab71204 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -199,6 +199,10 @@ impl PruningStatistics for BloomFilterStatistics { None } + fn row_counts(&self, _column: &Column) -> Option { + None + } + /// Use bloom filters to determine if we are sure this column can not /// possibly contain `values` /// @@ -332,6 +336,10 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { scalar.to_array().ok() } + fn row_counts(&self, _column: &Column) -> Option { + None + } + fn contained( &self, _column: &Column, diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index d2126f90eca9..80bb5ad42e81 100644 --- 
a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -53,7 +53,7 @@ use log::trace; /// /// 1. Minimum and maximum values for columns /// -/// 2. Null counts for columns +/// 2. Null counts and row counts for columns /// /// 3. Whether the values in a column are contained in a set of literals /// @@ -100,7 +100,8 @@ pub trait PruningStatistics { /// these statistics. /// /// This value corresponds to the size of the [`ArrayRef`] returned by - /// [`Self::min_values`], [`Self::max_values`], and [`Self::null_counts`]. + /// [`Self::min_values`], [`Self::max_values`], [`Self::null_counts`], + /// and [`Self::row_counts`]. fn num_containers(&self) -> usize; /// Return the number of null values for the named column as an @@ -111,6 +112,14 @@ pub trait PruningStatistics { /// Note: the returned array must contain [`Self::num_containers`] rows fn null_counts(&self, column: &Column) -> Option; + /// Return the number of rows for the named column in each container + /// as an `Option`. + /// + /// See [`Self::min_values`] for when to return `None` and null values. + /// + /// Note: the returned array must contain [`Self::num_containers`] rows + fn row_counts(&self, column: &Column) -> Option; + /// Returns [`BooleanArray`] where each row represents information known /// about specific literal `values` in a column. /// @@ -268,7 +277,7 @@ pub trait PruningStatistics { /// 3. [`PruningStatistics`] that provides information about columns in that /// schema, for multiple “containers”. For each column in each container, it /// provides optional information on contained values, min_values, max_values, -/// and null_counts counts. +/// null_counts counts, and row_counts counts. /// /// **Outputs**: /// A (non null) boolean value for each container: @@ -306,17 +315,23 @@ pub trait PruningStatistics { /// * `false`: there are no rows that could possibly match the predicate, /// **PRUNES** the container /// -/// For example, given a column `x`, the `x_min` and `x_max` and `x_null_count` -/// represent the minimum and maximum values, and the null count of column `x`, -/// provided by the `PruningStatistics`. Here are some examples of the rewritten -/// predicates: +/// For example, given a column `x`, the `x_min`, `x_max`, `x_null_count`, and +/// `x_row_count` represent the minimum and maximum values, the null count of +/// column `x`, and the row count of column `x`, provided by the `PruningStatistics`. +/// `x_null_count` and `x_row_count` are used to handle the case where the column `x` +/// is known to be all `NULL`s. Note this is different from knowing nothing about +/// the column `x`, which confusingly is encoded by returning `NULL` for the min/max +/// values from [`PruningStatistics::max_values`] and [`PruningStatistics::min_values`]. 
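+/// For instance, a container holding 100 rows of `x` that are all `NULL` would
+/// report `x_null_count = 100` and `x_row_count = 100`, so a predicate like
+/// `x = 5` can be pruned for it, whereas a container with no statistics at all
+/// reports `NULL` for each of these values and cannot be pruned on this basis.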
+/// +/// Here are some examples of the rewritten predicates: /// /// Original Predicate | Rewritten Predicate /// ------------------ | -------------------- -/// `x = 5` | `x_min <= 5 AND 5 <= x_max` -/// `x < 5` | `x_max < 5` -/// `x = 5 AND y = 10` | `x_min <= 5 AND 5 <= x_max AND y_min <= 10 AND 10 <= y_max` -/// `x IS NULL` | `x_null_count > 0` +/// `x = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END` +/// `x < 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_max < 5 END` +/// `x = 5 AND y = 10` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_min <= 5 AND 5 <= x_max END AND CASE WHEN y_null_count = y_row_count THEN false ELSE y_min <= 10 AND 10 <= y_max END` +/// `x IS NULL` | `CASE WHEN x_null_count = x_row_count THEN false ELSE x_null_count > 0 END` +/// `CAST(x as int) = 5` | `CASE WHEN x_null_count = x_row_count THEN false ELSE CAST(x_min as int) <= 5 AND 5 <= CAST(x_max as int) END` /// /// ## Predicate Evaluation /// The PruningPredicate works in two passes @@ -326,28 +341,47 @@ pub trait PruningStatistics { /// LiteralGuarantees are not satisfied /// /// **Second Pass**: Evaluates the rewritten expression using the -/// min/max/null_counts values for each column for each container. For any +/// min/max/null_counts/row_counts values for each column for each container. For any /// container that this expression evaluates to `false`, it rules out those /// containers. /// -/// For example, given the predicate, `x = 5 AND y = 10`, if we know `x` is -/// between `1 and 100` and we know that `y` is between `4` and `7`, the input -/// statistics might look like +/// +/// ### Example 1 +/// +/// Given the predicate, `x = 5 AND y = 10`, the rewritten predicate would look like: +/// +/// ```sql +/// CASE +/// WHEN x_null_count = x_row_count THEN false +/// ELSE x_min <= 5 AND 5 <= x_max +/// END +/// AND +/// CASE +/// WHEN y_null_count = y_row_count THEN false +/// ELSE y_min <= 10 AND 10 <= y_max +/// END +/// ``` +/// +/// If we know that for a given container, `x` is between `1 and 100` and we know that +/// `y` is between `4` and `7`, we know nothing about the null count and row count of +/// `x` and `y`, the input statistics might look like: /// /// Column | Value /// -------- | ----- /// `x_min` | `1` /// `x_max` | `100` +/// `x_null_count` | `null` +/// `x_row_count` | `null` /// `y_min` | `4` /// `y_max` | `7` +/// `y_null_count` | `null` +/// `y_row_count` | `null` /// -/// The rewritten predicate would look like -/// -/// `x_min <= 5 AND 5 <= x_max AND y_min <= 10 AND 10 <= y_max` -/// -/// When these values are substituted in to the rewritten predicate and +/// When these statistics values are substituted in to the rewritten predicate and /// simplified, the result is `false`: /// +/// * `CASE WHEN null = null THEN false ELSE 1 <= 5 AND 5 <= 100 END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END` +/// * `null = null` is `null` which is not true, so the `CASE` expression will use the `ELSE` clause /// * `1 <= 5 AND 5 <= 100 AND 4 <= 10 AND 10 <= 7` /// * `true AND true AND true AND false` /// * `false` @@ -364,6 +398,52 @@ pub trait PruningStatistics { /// more analysis, for example by actually reading the data and evaluating the /// predicate row by row. 
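+///
+/// The rewritten form is what surfaces in `EXPLAIN` output for a Parquet scan;
+/// for example, the predicate `c1 != 'bar'` is displayed as the pruning
+/// predicate below (mirroring the updated test expectation earlier in this
+/// patch):
+///
+/// ```text
+/// CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != bar OR bar != c1_max@1 END
+/// ```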
///
+/// ### Example 2
+///
+/// Given the same predicate, `x = 5 AND y = 10`, the rewritten predicate would
+/// look the same as in example 1:
+///
+/// ```sql
+/// CASE
+///     WHEN x_null_count = x_row_count THEN false
+///     ELSE x_min <= 5 AND 5 <= x_max
+/// END
+/// AND
+/// CASE
+///     WHEN y_null_count = y_row_count THEN false
+///     ELSE y_min <= 10 AND 10 <= y_max
+/// END
+/// ```
+///
+/// If we know that for another given container, `x_min` is NULL and `x_max` is
+/// NULL (the min/max values are unknown), `x_null_count` is `100` and `x_row_count`
+/// is `100`; we know that `y` is between `4` and `7`, but we know nothing about
+/// the null count and row count of `y`. The input statistics might look like:
+///
+/// Column | Value
+/// -------- | -----
+/// `x_min` | `null`
+/// `x_max` | `null`
+/// `x_null_count` | `100`
+/// `x_row_count` | `100`
+/// `y_min` | `4`
+/// `y_max` | `7`
+/// `y_null_count` | `null`
+/// `y_row_count` | `null`
+///
+/// When these statistics values are substituted into the rewritten predicate and
+/// simplified, the result is `false`:
+///
+/// * `CASE WHEN 100 = 100 THEN false ELSE null <= 5 AND 5 <= null END AND CASE WHEN null = null THEN false ELSE 4 <= 10 AND 10 <= 7 END`
+/// * Since `100 = 100` is `true`, the `CASE` expression will use the `THEN` clause, i.e. `false`
+/// * The other `CASE` expression will use the `ELSE` clause, i.e. `4 <= 10 AND 10 <= 7`
+/// * `false AND true`
+/// * `false`
+///
+/// Returning `false` means the container can be pruned, which matches the
+/// intuition that `x = 5 AND y = 10` can’t be true when all values in `x`
+/// are known to be NULL.
+///
 /// # Related Work
 ///
 /// [`PruningPredicate`] implements the type of min/max pruning described in
@@ -744,6 +824,22 @@ impl RequiredColumns {
             "null_count",
         )
     }
+
+    /// rewrite col --> col_row_count
+    fn row_count_column_expr(
+        &mut self,
+        column: &phys_expr::Column,
+        column_expr: &Arc<dyn PhysicalExpr>,
+        field: &Field,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        self.stat_column_expr(
+            column,
+            column_expr,
+            field,
+            StatisticsType::RowCount,
+            "row_count",
+        )
+    }
 }
 
 impl From<Vec<(phys_expr::Column, StatisticsType, Field)>> for RequiredColumns {
@@ -794,6 +890,7 @@ fn build_statistics_record_batch<S: PruningStatistics>(
                 StatisticsType::Min => statistics.min_values(&column),
                 StatisticsType::Max => statistics.max_values(&column),
                 StatisticsType::NullCount => statistics.null_counts(&column),
+                StatisticsType::RowCount => statistics.row_counts(&column),
             };
             let array = array.unwrap_or_else(|| new_null_array(data_type, num_containers));
 
@@ -903,6 +1000,46 @@ impl<'a> PruningExpressionBuilder<'a> {
         self.required_columns
             .max_column_expr(&self.column, &self.column_expr, self.field)
     }
+
+    /// This function simply returns the `null_count` physical expression, no matter what the
+    /// predicate expression is
+    ///
+    /// i.e., x > 5 => x_null_count,
+    /// cast(x as int) < 10 => x_null_count,
+    /// try_cast(x as float) < 10.0 => x_null_count
+    fn null_count_column_expr(&mut self) -> Result<Arc<dyn PhysicalExpr>> {
+        // Reuse the column as a generic [`phys_expr::Column`] expression
+        let column_expr = Arc::new(self.column.clone()) as _;
+
+        // null_count is DataType::UInt64, which is different from the column's data type (i.e. self.field)
+        let null_count_field = &Field::new(self.field.name(), DataType::UInt64, true);
+
+        self.required_columns.null_count_column_expr(
+            &self.column,
+            &column_expr,
+            null_count_field,
+        )
+    }
+
+    /// This function simply returns the `row_count` physical expression, no matter what the
+    /// predicate expression is
+    ///
+    /// i.e., x > 5 => x_row_count,
+    /// cast(x as int) < 10 => x_row_count,
+    /// try_cast(x as float) < 10.0 => x_row_count
+    fn row_count_column_expr(&mut self) -> Result<Arc<dyn PhysicalExpr>> {
+        // Reuse the column as a generic [`phys_expr::Column`] expression
+        let column_expr = Arc::new(self.column.clone()) as _;
+
+        // row_count is DataType::UInt64, which is different from the column's data type (i.e. self.field)
+        let row_count_field = &Field::new(self.field.name(), DataType::UInt64, true);
+
+        self.required_columns.row_count_column_expr(
+            &self.column,
+            &column_expr,
+            row_count_field,
+        )
+    }
 }
 
 /// This function is designed to rewrite the column_expr to
@@ -1320,14 +1457,56 @@ fn build_statistics_expr(
             );
         }
     };
+    let statistics_expr = wrap_case_expr(statistics_expr, expr_builder)?;
     Ok(statistics_expr)
 }
 
+/// Wrap the statistics expression in a case expression.
+/// This is necessary to handle the case where the column is known
+/// to be all nulls.
+///
+/// For example:
+///
+/// `x_min <= 10 AND 10 <= x_max`
+///
+/// will become
+///
+/// ```sql
+/// CASE
+///     WHEN x_null_count = x_row_count THEN false
+///     ELSE x_min <= 10 AND 10 <= x_max
+/// END
+/// ```
+///
+/// If the column is known to be all nulls, then the expression
+/// `x_null_count = x_row_count` will be true, which will cause the
+/// case expression to return false. Therefore, prune out the container.
+fn wrap_case_expr(
+    statistics_expr: Arc<dyn PhysicalExpr>,
+    expr_builder: &mut PruningExpressionBuilder,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    // x_null_count = x_row_count
+    let when_null_count_eq_row_count = Arc::new(phys_expr::BinaryExpr::new(
+        expr_builder.null_count_column_expr()?,
+        Operator::Eq,
+        expr_builder.row_count_column_expr()?,
+    ));
+    let then = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(false))));
+
+    // CASE WHEN x_null_count = x_row_count THEN false ELSE <statistics_expr> END
+    Ok(Arc::new(phys_expr::CaseExpr::try_new(
+        None,
+        vec![(when_null_count_eq_row_count, then)],
+        Some(statistics_expr),
+    )?))
+}
+
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub(crate) enum StatisticsType {
     Min,
     Max,
     NullCount,
+    RowCount,
 }
 
 #[cfg(test)]
@@ -1361,6 +1540,7 @@ mod tests {
         max: Option<ArrayRef>,
         /// Optional values
         null_counts: Option<ArrayRef>,
+        row_counts: Option<ArrayRef>,
         /// Optional known values (e.g. mimic a bloom filter)
         /// (value, contained)
         /// If present, all BooleanArrays must be the same size as min/max
@@ -1440,6 +1620,10 @@ mod tests {
             self.null_counts.clone()
        }
 
+        fn row_counts(&self) -> Option<ArrayRef> {
+            self.row_counts.clone()
+        }
+
         /// return an iterator over all arrays in this statistics
         fn arrays(&self) -> Vec<ArrayRef> {
             let contained_arrays = self
@@ -1451,6 +1635,7 @@ mod tests {
                 self.min.as_ref().cloned(),
                 self.max.as_ref().cloned(),
                 self.null_counts.as_ref().cloned(),
+                self.row_counts.as_ref().cloned(),
             ]
             .into_iter()
             .flatten()
@@ -1509,6 +1694,20 @@ mod tests {
             self
         }
 
+        /// Add row counts. There must be the same number of row counts as
+        /// there are containers
+        fn with_row_counts(
+            mut self,
+            counts: impl IntoIterator<Item = Option<u64>>,
+        ) -> Self {
+            let row_counts: ArrayRef =
+                Arc::new(counts.into_iter().collect::<UInt64Array>());
+
+            self.assert_invariants();
+            self.row_counts = Some(row_counts);
+            self
+        }
+
         /// Add contained information.
pub fn with_contained( mut self, @@ -1576,6 +1775,28 @@ mod tests { self } + /// Add row counts for the specified columm. + /// There must be the same number of row counts as + /// there are containers + fn with_row_counts( + mut self, + name: impl Into, + counts: impl IntoIterator>, + ) -> Self { + let col = Column::from_name(name.into()); + + // take stats out and update them + let container_stats = self + .stats + .remove(&col) + .unwrap_or_default() + .with_row_counts(counts); + + // put stats back in + self.stats.insert(col, container_stats); + self + } + /// Add contained information for the specified columm. fn with_contained( mut self, @@ -1628,6 +1849,13 @@ mod tests { .unwrap_or(None) } + fn row_counts(&self, column: &Column) -> Option { + self.stats + .get(column) + .map(|container_stats| container_stats.row_counts()) + .unwrap_or(None) + } + fn contained( &self, column: &Column, @@ -1663,6 +1891,10 @@ mod tests { None } + fn row_counts(&self, _column: &Column) -> Option { + None + } + fn contained( &self, _column: &Column, @@ -1853,7 +2085,7 @@ mod tests { #[test] fn row_group_predicate_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1"; + let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 END"; // test column on the left let expr = col("c1").eq(lit(1)); @@ -1873,7 +2105,7 @@ mod tests { #[test] fn row_group_predicate_not_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 != 1 OR 1 != c1_max@1"; + let expected_expr = "CASE WHEN c1_null_count@2 = c1_row_count@3 THEN false ELSE c1_min@0 != 1 OR 1 != c1_max@1 END"; // test column on the left let expr = col("c1").not_eq(lit(1)); @@ -1893,7 +2125,8 @@ mod tests { #[test] fn row_group_predicate_gt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_max@0 > 1"; + let expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 > 1 END"; // test column on the left let expr = col("c1").gt(lit(1)); @@ -1913,7 +2146,7 @@ mod tests { #[test] fn row_group_predicate_gt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_max@0 >= 1"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_max@0 >= 1 END"; // test column on the left let expr = col("c1").gt_eq(lit(1)); @@ -1932,7 +2165,8 @@ mod tests { #[test] fn row_group_predicate_lt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 < 1"; + let expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; // test column on the left let expr = col("c1").lt(lit(1)); @@ -1952,7 +2186,7 @@ mod tests { #[test] fn row_group_predicate_lt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min@0 <= 1"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 <= 1 END"; // test column on the left let expr = col("c1").lt_eq(lit(1)); @@ -1977,7 +2211,8 @@ mod tests { ]); // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); - let expected_expr = "c1_min@0 < 1"; + let 
expected_expr = + "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < 1 END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2043,7 +2278,7 @@ mod tests { #[test] fn row_group_predicate_lt_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); - let expected_expr = "c1_min@0 < true"; + let expected_expr = "CASE WHEN c1_null_count@1 = c1_row_count@2 THEN false ELSE c1_min@0 < true END"; // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated @@ -2066,7 +2301,21 @@ mod tests { let expr = col("c1") .lt(lit(1)) .and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3)))); - let expected_expr = "c1_min@0 < 1 AND (c2_min@1 <= 2 AND 2 <= c2_max@2 OR c2_min@1 <= 3 AND 3 <= c2_max@2)"; + let expected_expr = "\ + CASE \ + WHEN c1_null_count@1 = c1_row_count@2 THEN false \ + ELSE c1_min@0 < 1 \ + END \ + AND (\ + CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@3 <= 2 AND 2 <= c2_max@4 \ + END \ + OR CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@3 <= 3 AND 3 <= c2_max@4 \ + END\ + )"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut required_columns); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2080,10 +2329,30 @@ mod tests { c1_min_field.with_nullable(true) // could be nullable if stats are not present ) ); + // c1 < 1 should add c1_null_count + let c1_null_count_field = Field::new("c1_null_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[1], + ( + phys_expr::Column::new("c1", 0), + StatisticsType::NullCount, + c1_null_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); + // c1 < 1 should add c1_row_count + let c1_row_count_field = Field::new("c1_row_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[2], + ( + phys_expr::Column::new("c1", 0), + StatisticsType::RowCount, + c1_row_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); // c2 = 2 should add c2_min and c2_max let c2_min_field = Field::new("c2_min", DataType::Int32, false); assert_eq!( - required_columns.columns[1], + required_columns.columns[3], ( phys_expr::Column::new("c2", 1), StatisticsType::Min, @@ -2092,15 +2361,35 @@ mod tests { ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); assert_eq!( - required_columns.columns[2], + required_columns.columns[4], ( phys_expr::Column::new("c2", 1), StatisticsType::Max, c2_max_field.with_nullable(true) // could be nullable if stats are not present ) ); + // c2 = 2 should add c2_null_count + let c2_null_count_field = Field::new("c2_null_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[5], + ( + phys_expr::Column::new("c2", 1), + StatisticsType::NullCount, + c2_null_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); + // c2 = 2 should add c2_row_count + let c2_row_count_field = Field::new("c2_row_count", DataType::UInt64, false); + assert_eq!( + required_columns.columns[6], + ( + phys_expr::Column::new("c2", 1), + StatisticsType::RowCount, + c2_row_count_field.with_nullable(true) // could be nullable if stats are not present + ) + ); // c2 = 3 shouldn't add any new statistics fields - assert_eq!(required_columns.columns.len(), 3); + assert_eq!(required_columns.columns.len(), 7); Ok(()) } @@ -2117,7 +2406,18 
@@ mod tests { vec![lit(1), lit(2), lit(3)], false, )); - let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 3 AND 3 <= c1_max@1 \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2153,9 +2453,19 @@ mod tests { vec![lit(1), lit(2), lit(3)], true, )); - let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ - AND (c1_min@0 != 2 OR 2 != c1_max@1) \ - AND (c1_min@0 != 3 OR 3 != c1_max@1)"; + let expected_expr = "\ + CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 1 OR 1 != c1_max@1 \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 2 OR 2 != c1_max@1 \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 != 3 OR 3 != c1_max@1 \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2201,7 +2511,24 @@ mod tests { // test c1 in(1, 2) and c2 BETWEEN 4 AND 5 let expr3 = expr1.and(expr2); - let expected_expr = "(c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1) AND c2_max@2 >= 4 AND c2_min@3 <= 5"; + let expected_expr = "\ + (\ + CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 1 AND 1 <= c1_max@1 \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE c1_min@0 <= 2 AND 2 <= c1_max@1 \ + END\ + ) AND CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_max@4 >= 4 \ + END \ + AND CASE \ + WHEN c2_null_count@5 = c2_row_count@6 THEN false \ + ELSE c2_min@7 <= 5 \ + END"; let predicate_expr = test_build_predicate_expression(&expr3, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2228,9 +2555,12 @@ mod tests { #[test] fn row_group_predicate_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = - "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ + END"; + // test cast(c1 as int64) = 1 // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); let predicate_expr = @@ -2243,7 +2573,10 @@ mod tests { test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); - let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; + let expected_expr = "CASE \ + WHEN c1_null_count@1 = c1_row_count@2 THEN false \ + ELSE TRY_CAST(c1_max@0 AS Int64) > 1 \ + END"; // test column on the left let expr = @@ -2275,7 +2608,18 @@ mod tests { ], false, )); - let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; + let expected_expr = "CASE \ + WHEN 
c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) \ + END \ + OR CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64) \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2289,10 +2633,18 @@ mod tests { ], true, )); - let expected_expr = - "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ - AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ - AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; + let expected_expr = "CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64) \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64) \ + END \ + AND CASE \ + WHEN c1_null_count@2 = c1_row_count@3 THEN false \ + ELSE CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64) \ + END"; let predicate_expr = test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new()); assert_eq!(predicate_expr.to_string(), expected_expr); @@ -2819,7 +3171,7 @@ mod tests { let expected_ret = &[false, true, true, true, false]; prune_with_expr( - // i IS NULL, with actual null statistcs + // i IS NULL, with actual null statistics col("i").is_null(), &schema, &statistics, @@ -2827,6 +3179,78 @@ mod tests { ); } + #[test] + fn prune_int32_column_is_known_all_null() { + let (schema, statistics) = int32_setup(); + + // Expression "i < 0" + // i [-5, 5] ==> some rows could pass (must keep) + // i [1, 11] ==> no rows can pass (not keep) + // i [-11, -1] ==> all rows must pass (must keep) + // i [NULL, NULL] ==> unknown (must keep) + // i [1, NULL] ==> no rows can pass (not keep) + let expected_ret = &[true, false, true, true, false]; + + prune_with_expr( + // i < 0 + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + + // provide row counts for each column + let statistics = statistics.with_row_counts( + "i", + vec![ + Some(10), // 10 rows of data + Some(9), // 9 rows of data + None, // unknown row counts + Some(4), + Some(10), + ], + ); + + // pruning result is still the same if we only know row counts + prune_with_expr( + // i < 0, with only row counts statistics + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + + // provide null counts for each column + let statistics = statistics.with_null_counts( + "i", + vec![ + Some(0), // no nulls + Some(1), // 1 null + None, // unknown nulls + Some(4), // 4 nulls, which is the same as the row counts, i.e. 
this column is all null (don't keep) + Some(0), // 0 nulls (max=null too which means no known max) + ], + ); + + // Expression "i < 0" with actual null and row counts statistics + // col | min, max | row counts | null counts | + // ----+--------------+------------+-------------+ + // i | [-5, 5] | 10 | 0 | ==> Some rows could pass (must keep) + // i | [1, 11] | 9 | 1 | ==> No rows can pass (not keep) + // i | [-11,-1] | Unknown | Unknown | ==> All rows must pass (must keep) + // i | [NULL, NULL] | 4 | 4 | ==> The column is all null (not keep) + // i | [1, NULL] | 10 | 0 | ==> No rows can pass (not keep) + let expected_ret = &[true, false, true, false, false]; + + prune_with_expr( + // i < 0, with actual null and row counts statistics + col("i").lt(lit(0)), + &schema, + &statistics, + expected_ret, + ); + } + #[test] fn prune_cast_column_scalar() { // The data type of column i is INT32 diff --git a/datafusion/sqllogictest/test_files/repartition_scan.slt b/datafusion/sqllogictest/test_files/repartition_scan.slt index fe0f6c1e8139..f9699a5fda8f 100644 --- a/datafusion/sqllogictest/test_files/repartition_scan.slt +++ b/datafusion/sqllogictest/test_files/repartition_scan.slt @@ -61,7 +61,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # disable round robin repartitioning statement ok @@ -77,7 +77,7 @@ Filter: parquet_table.column1 != Int32(42) physical_plan CoalesceBatchesExec: target_batch_size=8192 --FilterExec: column1@0 != 42 -----ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +----ParquetExec: 
file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..104], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:104..208], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:208..312], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:312..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # enable round robin repartitioning again statement ok @@ -102,7 +102,7 @@ SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --SortExec: expr=[column1@0 ASC NULLS LAST] ----CoalesceBatchesExec: target_batch_size=8192 ------FilterExec: column1@0 != 42 ---------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +--------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..205], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:205..405, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..5], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:5..210], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:210..414]]}, projection=[column1], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] ## Read the files as though they are ordered @@ -138,7 +138,7 @@ physical_plan SortPreservingMergeExec: [column1@0 ASC NULLS LAST] --CoalesceBatchesExec: target_batch_size=8192 ----FilterExec: column1@0 != 42 -------ParquetExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=column1_min@0 != 42 OR 42 != column1_max@1, required_guarantees=[column1 not in (42)] +------ParquetExec: file_groups={4 groups: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:0..202], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:0..207], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/2.parquet:207..414], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/repartition_scan/parquet_table/1.parquet:202..405]]}, projection=[column1], output_ordering=[column1@0 ASC NULLS LAST], predicate=column1@0 != 42, pruning_predicate=CASE WHEN column1_null_count@2 = column1_row_count@3 THEN false ELSE column1_min@0 != 42 OR 42 != column1_max@1 END, required_guarantees=[column1 not in (42)] # Cleanup statement ok From b0b329ba39403b9e87156d6f9b8c5464dc6d2480 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 19 Mar 2024 06:57:52 +0800 Subject: [PATCH 012/117] Suppress self update for windows CI runner (#9661) * suppress self update for window Signed-off-by: jayzhan211 * Update .github/actions/setup-windows-builder/action.yaml --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- .github/actions/setup-windows-builder/action.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/actions/setup-windows-builder/action.yaml b/.github/actions/setup-windows-builder/action.yaml index 9ab5c4a8b1bb..a26a34a3db93 100644 --- a/.github/actions/setup-windows-builder/action.yaml +++ b/.github/actions/setup-windows-builder/action.yaml @@ -38,8 +38,8 @@ runs: - name: Setup Rust toolchain shell: bash run: | - rustup update stable - rustup toolchain install stable + # Avoid self update to avoid CI failures: https://github.com/apache/arrow-datafusion/issues/9653 + rustup toolchain install stable --no-self-update rustup default stable rustup component add rustfmt - name: Configure rust runtime env From 8438d2b1ea67fda64955839fb4bd4ed88b861ade Mon Sep 17 00:00:00 2001 From: Suriya Kandaswamy Date: Tue, 19 Mar 2024 10:14:42 -0400 Subject: [PATCH 013/117] add schema to SQL ast builder (#9624) * add schema to ast builder * add schema test --- datafusion/sql/src/unparser/plan.rs | 9 ++++++--- datafusion/sql/tests/sql_integration.rs | 9 +++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e1f5135efda9..c9b0a8a04c7e 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -124,9 +124,12 @@ impl Unparser<'_> { match plan { LogicalPlan::TableScan(scan) => { let mut builder = TableRelationBuilder::default(); - builder.name(ast::ObjectName(vec![ - self.new_ident(scan.table_name.table().to_string()) - ])); + let mut table_parts = vec![]; + if let Some(schema_name) = scan.table_name.schema() { + table_parts.push(self.new_ident(schema_name.to_string())); + } + table_parts.push(self.new_ident(scan.table_name.table().to_string())); + builder.name(ast::ObjectName(table_parts)); relation.table(builder); Ok(()) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 6d335f1f8fc9..47638e58ff00 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -41,6 +41,15 @@ use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; +#[test] +fn test_schema_support() { + quick_test( + "SELECT * FROM s1.test", + "Projection: s1.test.t_date32, s1.test.t_date64\ + \n TableScan: s1.test", + ); +} + 
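Aside, not part of the patch: the `table_parts` handling above is what lets a
schema-qualified scan survive unparsing. A minimal sketch of driving it,
assuming the `plan_to_sql` entry point in `datafusion_sql::unparser`:

    use datafusion_expr::LogicalPlan;
    use datafusion_sql::unparser::plan_to_sql;

    /// Render a plan back to SQL text. With the change above, a `TableScan`
    /// of `s1.test` renders as `FROM s1.test` rather than a bare `FROM test`.
    fn to_sql_string(plan: &LogicalPlan) -> datafusion_common::Result<String> {
        Ok(plan_to_sql(plan)?.to_string())
    }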
#[test] fn parse_decimals() { let test_data = [ From 9b098eef6f7d8b6d1162ccbdc9053f8e1cb999d4 Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Tue, 19 Mar 2024 15:20:01 +0100 Subject: [PATCH 014/117] Add tests for row group pruning on strings (#9642) --- datafusion/core/tests/parquet/mod.rs | 112 +++++++++++++++++- .../core/tests/parquet/row_group_pruning.rs | 102 ++++++++++++++++ 2 files changed, 211 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index c60780919489..3fe51288e79a 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -19,9 +19,9 @@ use arrow::array::Decimal128Array; use arrow::{ array::{ - Array, ArrayRef, Date32Array, Date64Array, Float64Array, Int32Array, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, + Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray, + Float64Array, Int32Array, StringArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -70,6 +70,7 @@ enum Scenario { DecimalBloomFilterInt64, DecimalLargePrecision, DecimalLargePrecisionBloomFilter, + ByteArray, PeriodsInColumnNames, } @@ -506,6 +507,51 @@ fn make_date_batch(offset: Duration) -> RecordBatch { .unwrap() } +/// returns a batch with two columns (note "service.name" is the name +/// of the column. It is *not* a table named service.name +/// +/// name | service.name +fn make_bytearray_batch( + name: &str, + string_values: Vec<&str>, + binary_values: Vec<&[u8]>, + fixedsize_values: Vec<&[u8; 3]>, +) -> RecordBatch { + let num_rows = string_values.len(); + let name: StringArray = std::iter::repeat(Some(name)).take(num_rows).collect(); + let service_string: StringArray = string_values.iter().map(Some).collect(); + let service_binary: BinaryArray = binary_values.iter().map(Some).collect(); + let service_fixedsize: FixedSizeBinaryArray = fixedsize_values + .iter() + .map(|value| Some(value.as_slice())) + .collect::>() + .into(); + + let schema = Schema::new(vec![ + Field::new("name", name.data_type().clone(), true), + // note the column name has a period in it! + Field::new("service_string", service_string.data_type().clone(), true), + Field::new("service_binary", service_binary.data_type().clone(), true), + Field::new( + "service_fixedsize", + service_fixedsize.data_type().clone(), + true, + ), + ]); + let schema = Arc::new(schema); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(name), + Arc::new(service_string), + Arc::new(service_binary), + Arc::new(service_fixedsize), + ], + ) + .unwrap() +} + /// returns a batch with two columns (note "service.name" is the name /// of the column. It is *not* a table named service.name /// @@ -604,6 +650,66 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_decimal_batch(vec![100000, 200000, 300000, 400000, 600000], 38, 5), ] } + Scenario::ByteArray => { + // frontends first, then backends. All in order, except frontends 4 and 7 + // are swapped to cause a statistics false positive on the 'fixed size' column. 
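+            // e.g. after the swap the first batch's `service_fixedsize` min/max
+            // range is [fe1, fe7], so a point lookup for fe4 "matches" that row
+            // group on statistics alone even though fe4 only occurs in the
+            // second ('mixed') batch.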
+ vec![ + make_bytearray_batch( + "all frontends", + vec![ + "frontend one", + "frontend two", + "frontend three", + "frontend seven", + "frontend five", + ], + vec![ + b"frontend one", + b"frontend two", + b"frontend three", + b"frontend seven", + b"frontend five", + ], + vec![b"fe1", b"fe2", b"fe3", b"fe7", b"fe5"], + ), + make_bytearray_batch( + "mixed", + vec![ + "frontend six", + "frontend four", + "backend one", + "backend two", + "backend three", + ], + vec![ + b"frontend six", + b"frontend four", + b"backend one", + b"backend two", + b"backend three", + ], + vec![b"fe6", b"fe4", b"be1", b"be2", b"be3"], + ), + make_bytearray_batch( + "all backends", + vec![ + "backend four", + "backend five", + "backend six", + "backend seven", + "backend eight", + ], + vec![ + b"backend four", + b"backend five", + b"backend six", + b"backend seven", + b"backend eight", + ], + vec![b"be4", b"be5", b"be6", b"be7", b"be8"], + ), + ] + } Scenario::PeriodsInColumnNames => { vec![ // all frontend diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index b7038ef1a73f..406eb721bf94 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -744,6 +744,108 @@ async fn prune_decimal_in_list() { .await; } +#[tokio::test] +async fn prune_string_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'backend one'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_eq_no_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'backend nine'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string = 'frontend nine'", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'frontend five' < 'frontend nine' < 'frontend two' + // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string != 'backend one'", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_string_lt() { + 
RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string < 'backend one'", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_string FROM t WHERE service_string < 'backend zero'", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; +} + #[tokio::test] async fn prune_periods_in_column_names() { // There are three row groups for "service.name", each with 5 rows = 15 rows total From 3c3b22866a7ece784208e9d499119b2e13399762 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Tue, 19 Mar 2024 11:38:25 -0400 Subject: [PATCH 015/117] Fix incorrect results with multiple `COUNT(DISTINCT..)` aggregates on dictionaries (#9679) * Add test for multiple count distincts on a dictionary * Fix accumulator merge bug * Fix cleanup code --- datafusion/common/src/scalar/mod.rs | 2 +- .../src/aggregate/count_distinct/mod.rs | 32 +++++++-- .../sqllogictest/test_files/dictionary.slt | 67 +++++++++++++++++++ 3 files changed, 93 insertions(+), 8 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 5ace44f24b69..316624175e1c 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1746,7 +1746,7 @@ impl ScalarValue { } /// Converts `Vec` where each element has type corresponding to - /// `data_type`, to a [`ListArray`]. + /// `data_type`, to a single element [`ListArray`]. /// /// Example /// ``` diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 71782fcc5f9b..fb5e7710496c 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -47,7 +47,7 @@ use crate::binary_map::OutputType; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; -/// Expression for a COUNT(DISTINCT) aggregation. +/// Expression for a `COUNT(DISTINCT)` aggregation. #[derive(Debug)] pub struct DistinctCount { /// Column name @@ -100,6 +100,7 @@ impl AggregateExpr for DistinctCount { use TimeUnit::*; Ok(match &self.state_data_type { + // try and use a specialized accumulator if possible, otherwise fall back to generic accumulator Int8 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int16 => Box::new(PrimitiveDistinctCountAccumulator::::new()), Int32 => Box::new(PrimitiveDistinctCountAccumulator::::new()), @@ -157,6 +158,7 @@ impl AggregateExpr for DistinctCount { OutputType::Binary, )), + // Use the generic accumulator based on `ScalarValue` for all other types _ => Box::new(DistinctCountAccumulator { values: HashSet::default(), state_data_type: self.state_data_type.clone(), @@ -183,7 +185,11 @@ impl PartialEq for DistinctCount { } /// General purpose distinct accumulator that works for any DataType by using -/// [`ScalarValue`]. 
Some types have specialized accumulators that are (much)
+/// [`ScalarValue`].
+///
+/// It stores intermediate results as a `ListArray`
+///
+/// Note that many types have specialized accumulators that are (much)
 /// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
 /// [`BytesDistinctCountAccumulator`]
 #[derive(Debug)]
@@ -193,8 +199,9 @@ struct DistinctCountAccumulator {
 }
 
 impl DistinctCountAccumulator {
-    // calculating the size for fixed length values, taking first batch size * number of batches
-    // This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
+    // calculating the size for fixed length values, taking first batch size *
+    // number of batches This method is faster than .full_size(), however it is
+    // not suitable for variable length values like strings or complex types
     fn fixed_size(&self) -> usize {
         std::mem::size_of_val(self)
             + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -207,7 +214,8 @@ impl DistinctCountAccumulator {
             + std::mem::size_of::<DataType>()
     }
 
-    // calculates the size as accurate as possible, call to this method is expensive
+    // calculates the size as accurately as possible. Note that calling this
+    // method is expensive
     fn full_size(&self) -> usize {
         std::mem::size_of_val(self)
             + (std::mem::size_of::<ScalarValue>() * self.values.capacity())
@@ -221,6 +229,7 @@ impl DistinctCountAccumulator {
 }
 
 impl Accumulator for DistinctCountAccumulator {
+    /// Returns the distinct values seen so far as (one element) ListArray.
     fn state(&mut self) -> Result<Vec<ScalarValue>> {
         let scalars = self.values.iter().cloned().collect::<Vec<_>>();
         let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
@@ -246,6 +255,11 @@ impl Accumulator for DistinctCountAccumulator {
         })
     }
 
+    /// Merges multiple sets of distinct values into the current set.
+    ///
+    /// The input to this function is a `ListArray` with **multiple** rows,
+    /// where each row contains the values from a partial aggregate's phase (e.g.
+    /// the result of calling `Self::state` on multiple accumulators).
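+    ///
+    /// For example, if one partial state saw {1, 2} and another saw {2, 3},
+    /// the input is the two-row `ListArray` `[[1, 2], [2, 3]]` and the merged
+    /// set of distinct values becomes {1, 2, 3}.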
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { if states.is_empty() { return Ok(()); @@ -253,8 +267,12 @@ impl Accumulator for DistinctCountAccumulator { assert_eq!(states.len(), 1, "array_agg states must be singleton!"); let array = &states[0]; let list_array = array.as_list::(); - let inner_array = list_array.value(0); - self.update_batch(&[inner_array]) + for inner_array in list_array.iter() { + let inner_array = inner_array + .expect("counts are always non null, so are intermediate results"); + self.update_batch(&[inner_array])?; + } + Ok(()) } fn evaluate(&mut self) -> Result { diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index 002aade2528e..af7bf5cb16e8 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -280,3 +280,70 @@ ORDER BY 2023-12-20T01:20:00 1000 f2 foo 2023-12-20T01:30:00 1000 f1 32.0 2023-12-20T01:30:00 1000 f2 foo + +# Cleanup +statement ok +drop view m1; + +statement ok +drop view m2; + +###### +# Create a table using UNION ALL to get 2 partitions (very important) +###### +statement ok +create table m3_source as + select * from (values('foo', 'bar', 1)) + UNION ALL + select * from (values('foo', 'baz', 1)); + +###### +# Now, create a table with the same data, but column2 has type `Dictionary(Int32)` to trigger the fallback code +###### +statement ok +create table m3 as + select + column1, + arrow_cast(column2, 'Dictionary(Int32, Utf8)') as "column2", + column3 +from m3_source; + +# there are two values in column2 +query T?I rowsort +SELECT * +FROM m3; +---- +foo bar 1 +foo baz 1 + +# There is 1 distinct value in column1 +query I +SELECT count(distinct column1) +FROM m3 +GROUP BY column3; +---- +1 + +# There are 2 distinct values in column2 +query I +SELECT count(distinct column2) +FROM m3 +GROUP BY column3; +---- +2 + +# Should still get the same results when querying in the same query +query II +SELECT count(distinct column1), count(distinct column2) +FROM m3 +GROUP BY column3; +---- +1 2 + + +# Cleanup +statement ok +drop table m3; + +statement ok +drop table m3_source; From b87dd6143c2dc089b07f74780bd525c4369e68a3 Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Tue, 19 Mar 2024 16:43:11 +0100 Subject: [PATCH 016/117] Add support for Bloom filters on binary columns (#9644) --- .../physical_plan/parquet/row_groups.rs | 1 + .../core/tests/parquet/row_group_pruning.rs | 102 ++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index a0bb5ab71204..9cd46994960f 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -225,6 +225,7 @@ impl PruningStatistics for BloomFilterStatistics { .map(|value| { match value { ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Binary(Some(v)) => sbbf.check(v), ScalarValue::Boolean(Some(v)) => sbbf.check(v), ScalarValue::Float64(Some(v)) => sbbf.check(v), ScalarValue::Float32(Some(v)) => sbbf.check(v), diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 406eb721bf94..55112193502d 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -846,6 +846,108 @@ async fn prune_string_lt() { 
.await; } +#[tokio::test] +async fn prune_binary_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_eq_no_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('backend nine' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all backends' batch: 'backend five' < 'backend one' < 'backend three' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary = CAST('frontend nine' AS bytea)", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'frontend five' < 'frontend nine' < 'frontend two' + // false positive on 'mixed' batch: 'backend one' < 'frontend nine' < 'frontend six' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(2)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary != CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_binary_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary < CAST('backend one' AS bytea)", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_binary FROM t WHERE service_binary < CAST('backend zero' AS bytea)", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; +} + #[tokio::test] async fn prune_periods_in_column_names() { // There are three row groups for "service.name", each with 5 rows = 15 rows total From 7af69f9768497060343ae2a6fbd1991e9a047dce Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Wed, 20 Mar 2024 05:22:12 +1300 Subject: [PATCH 
017/117] Update Arrow/Parquet to `51.0.0`, tonic to `0.11` (#9613) * Prepare for arrow 51 * Fix datafusion-proto * Update deserialize_to_struct * Format * Update pins * Update datafusion-cli Cargo.lock * Remove stale comment * Add comment to seconds --------- Co-authored-by: Andrew Lamb --- Cargo.toml | 18 +- datafusion-cli/Cargo.lock | 88 +++---- datafusion-cli/Cargo.toml | 4 +- datafusion-examples/Cargo.toml | 3 +- .../examples/deserialize_to_struct.rs | 58 ++--- .../examples/flight/flight_server.rs | 9 +- .../examples/flight/flight_sql_server.rs | 3 + .../common/src/file_options/parquet_writer.rs | 1 + datafusion/common/src/scalar/mod.rs | 10 +- .../src/datasource/avro_to_arrow/schema.rs | 6 + .../src/datasource/file_format/parquet.rs | 4 - .../datasource/physical_plan/parquet/mod.rs | 8 +- .../functions/src/datetime/date_part.rs | 214 ++++++------------ datafusion/proto/src/logical_plan/to_proto.rs | 3 + datafusion/sql/src/unparser/expr.rs | 4 + 15 files changed, 195 insertions(+), 238 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 48e555bd5527..d9e69e53db7c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,14 +57,14 @@ version = "36.0.0" # for the inherited dependency but cannot do the reverse (override from true to false). # # See for more detaiils: https://github.com/rust-lang/cargo/issues/11329 -arrow = { version = "50.0.0", features = ["prettyprint"] } -arrow-array = { version = "50.0.0", default-features = false, features = ["chrono-tz"] } -arrow-buffer = { version = "50.0.0", default-features = false } -arrow-flight = { version = "50.0.0", features = ["flight-sql-experimental"] } -arrow-ipc = { version = "50.0.0", default-features = false, features = ["lz4"] } -arrow-ord = { version = "50.0.0", default-features = false } -arrow-schema = { version = "50.0.0", default-features = false } -arrow-string = { version = "50.0.0", default-features = false } +arrow = { version = "51.0.0", features = ["prettyprint"] } +arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] } +arrow-buffer = { version = "51.0.0", default-features = false } +arrow-flight = { version = "51.0.0", features = ["flight-sql-experimental"] } +arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4"] } +arrow-ord = { version = "51.0.0", default-features = false } +arrow-schema = { version = "51.0.0", default-features = false } +arrow-string = { version = "51.0.0", default-features = false } async-trait = "0.1.73" bigdecimal = "=0.4.1" bytes = "1.4" @@ -95,7 +95,7 @@ log = "^0.4" num_cpus = "1.13.0" object_store = { version = "0.9.0", default-features = false } parking_lot = "0.12" -parquet = { version = "50.0.0", default-features = false, features = ["arrow", "async", "object_store"] } +parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" rstest = "0.18.0" serde_json = "1" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 8e2a2c353e2d..51cccf60a1e4 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -130,9 +130,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "arrow" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa285343fba4d829d49985bdc541e3789cf6000ed0e84be7c039438df4a4e78c" +checksum = "219d05930b81663fd3b32e3bde8ce5bff3c4d23052a99f11a8fa50a3b47b2658" dependencies = [ "arrow-arith", "arrow-array", @@ -151,9 +151,9 @@ 
dependencies = [ [[package]] name = "arrow-arith" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "753abd0a5290c1bcade7c6623a556f7d1659c5f4148b140b5b63ce7bd1a45705" +checksum = "0272150200c07a86a390be651abdd320a2d12e84535f0837566ca87ecd8f95e0" dependencies = [ "arrow-array", "arrow-buffer", @@ -166,9 +166,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d390feeb7f21b78ec997a4081a025baef1e2e0d6069e181939b61864c9779609" +checksum = "8010572cf8c745e242d1b632bd97bd6d4f40fefed5ed1290a8f433abaa686fea" dependencies = [ "ahash", "arrow-buffer", @@ -183,9 +183,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69615b061701bcdffbc62756bc7e85c827d5290b472b580c972ebbbf690f5aa4" +checksum = "0d0a2432f0cba5692bf4cb757469c66791394bac9ec7ce63c1afe74744c37b27" dependencies = [ "bytes", "half", @@ -194,28 +194,30 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e448e5dd2f4113bf5b74a1f26531708f5edcacc77335b7066f9398f4bcf4cdef" +checksum = "9abc10cd7995e83505cc290df9384d6e5412b207b79ce6bdff89a10505ed2cba" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", - "base64 0.21.7", + "atoi", + "base64 0.22.0", "chrono", "comfy-table", "half", "lexical-core", "num", + "ryu", ] [[package]] name = "arrow-csv" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46af72211f0712612f5b18325530b9ad1bfbdc87290d5fbfd32a7da128983781" +checksum = "95cbcba196b862270bf2a5edb75927380a7f3a163622c61d40cbba416a6305f2" dependencies = [ "arrow-array", "arrow-buffer", @@ -232,9 +234,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67d644b91a162f3ad3135ce1184d0a31c28b816a581e08f29e8e9277a574c64e" +checksum = "2742ac1f6650696ab08c88f6dd3f0eb68ce10f8c253958a18c943a68cd04aec5" dependencies = [ "arrow-buffer", "arrow-schema", @@ -244,9 +246,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03dea5e79b48de6c2e04f03f62b0afea7105be7b77d134f6c5414868feefb80d" +checksum = "a42ea853130f7e78b9b9d178cb4cd01dee0f78e64d96c2949dc0a915d6d9e19d" dependencies = [ "arrow-array", "arrow-buffer", @@ -259,9 +261,9 @@ dependencies = [ [[package]] name = "arrow-json" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8950719280397a47d37ac01492e3506a8a724b3fb81001900b866637a829ee0f" +checksum = "eaafb5714d4e59feae964714d724f880511500e3569cc2a94d02456b403a2a49" dependencies = [ "arrow-array", "arrow-buffer", @@ -279,9 +281,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ed9630979034077982d8e74a942b7ac228f33dd93a93b615b4d02ad60c260be" +checksum = "e3e6b61e3dc468f503181dccc2fc705bdcc5f2f146755fa5b56d0a6c5943f412" dependencies = [ "arrow-array", "arrow-buffer", @@ -294,9 +296,9 @@ dependencies = [ 
[[package]] name = "arrow-row" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "007035e17ae09c4e8993e4cb8b5b96edf0afb927cd38e2dff27189b274d83dcf" +checksum = "848ee52bb92eb459b811fb471175ea3afcf620157674c8794f539838920f9228" dependencies = [ "ahash", "arrow-array", @@ -309,15 +311,15 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ff3e9c01f7cd169379d269f926892d0e622a704960350d09d331be3ec9e0029" +checksum = "02d9483aaabe910c4781153ae1b6ae0393f72d9ef757d38d09d450070cf2e528" [[package]] name = "arrow-select" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ce20973c1912de6514348e064829e50947e35977bb9d7fb637dc99ea9ffd78c" +checksum = "849524fa70e0e3c5ab58394c770cb8f514d0122d20de08475f7b472ed8075830" dependencies = [ "ahash", "arrow-array", @@ -329,15 +331,16 @@ dependencies = [ [[package]] name = "arrow-string" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00f3b37f2aeece31a2636d1b037dabb69ef590e03bdc7eb68519b51ec86932a7" +checksum = "9373cb5a021aee58863498c37eb484998ef13377f69989c6c5ccfbd258236cdb" dependencies = [ "arrow-array", "arrow-buffer", "arrow-data", "arrow-schema", "arrow-select", + "memchr", "num", "regex", "regex-syntax", @@ -387,6 +390,15 @@ dependencies = [ "syn 2.0.53", ] +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + [[package]] name = "atty" version = "0.2.14" @@ -739,9 +751,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "blake2" @@ -2128,7 +2140,7 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "libc", "redox_syscall", ] @@ -2442,9 +2454,9 @@ dependencies = [ [[package]] name = "parquet" -version = "50.0.0" +version = "51.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "547b92ebf0c1177e3892f44c8f79757ee62e678d564a9834189725f2c5b7a750" +checksum = "096795d4f47f65fd3ee1ec5a98b77ab26d602f2cc785b0e4be5443add17ecc32" dependencies = [ "ahash", "arrow-array", @@ -2454,7 +2466,7 @@ dependencies = [ "arrow-ipc", "arrow-schema", "arrow-select", - "base64 0.21.7", + "base64 0.22.0", "brotli", "bytes", "chrono", @@ -2903,7 +2915,7 @@ version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", @@ -3720,9 +3732,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" +checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" dependencies = [ "getrandom", "serde", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index ad506762f0d0..da744a06f3aa 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -30,7 +30,7 @@ rust-version = "1.72" readme = "README.md" [dependencies] -arrow = "50.0.0" +arrow = "51.0.0" async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" @@ -52,7 +52,7 @@ futures = "0.3" mimalloc = { version = "0.1", default-features = false } object_store = { version = "0.9.0", features = ["aws", "gcp", "http"] } parking_lot = { version = "0.12" } -parquet = { version = "50.0.0", default-features = false } +parquet = { version = "51.0.0", default-features = false } regex = "1.8" rustyline = "11.0" tokio = { version = "1.24", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] } diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index ad2a49fb352e..2b6e869ec500 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -74,7 +74,6 @@ serde = { version = "1.0.136", features = ["derive"] } serde_json = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } -# 0.10 and 0.11 are incompatible. Need to upgrade tonic to 0.11 when upgrading to arrow 51 -tonic = "0.10" +tonic = "0.11" url = { workspace = true } uuid = "1.2" diff --git a/datafusion-examples/examples/deserialize_to_struct.rs b/datafusion-examples/examples/deserialize_to_struct.rs index e999fc4dac3e..985cab703a5c 100644 --- a/datafusion-examples/examples/deserialize_to_struct.rs +++ b/datafusion-examples/examples/deserialize_to_struct.rs @@ -15,14 +15,13 @@ // specific language governing permissions and limitations // under the License. +use arrow::array::AsArray; +use arrow::datatypes::{Float64Type, Int32Type}; use datafusion::error::Result; use datafusion::prelude::*; -use serde::Deserialize; +use futures::StreamExt; /// This example shows that it is possible to convert query results into Rust structs . -/// It will collect the query results into RecordBatch, then convert it to serde_json::Value. -/// Then, serde_json::Value is turned into Rust's struct. -/// Any datatype with `Deserialize` implemeneted works. 
#[tokio::main] async fn main() -> Result<()> { let data_list = Data::new().await?; @@ -30,10 +29,10 @@ async fn main() -> Result<()> { Ok(()) } -#[derive(Deserialize, Debug)] +#[derive(Debug)] struct Data { #[allow(dead_code)] - int_col: i64, + int_col: i32, #[allow(dead_code)] double_col: f64, } @@ -41,35 +40,36 @@ struct Data { impl Data { pub async fn new() -> Result> { // this group is almost the same as the one you find it in parquet_sql.rs - let batches = { - let ctx = SessionContext::new(); + let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); + let testdata = datafusion::test_util::parquet_test_data(); - ctx.register_parquet( - "alltypes_plain", - &format!("{testdata}/alltypes_plain.parquet"), - ParquetReadOptions::default(), - ) - .await?; + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; - let df = ctx - .sql("SELECT int_col, double_col FROM alltypes_plain") - .await?; + let df = ctx + .sql("SELECT int_col, double_col FROM alltypes_plain") + .await?; - df.clone().show().await?; + df.clone().show().await?; - df.collect().await? - }; - let batches: Vec<_> = batches.iter().collect(); + let mut stream = df.execute_stream().await?; + let mut list = vec![]; + while let Some(b) = stream.next().await.transpose()? { + let int_col = b.column(0).as_primitive::(); + let float_col = b.column(1).as_primitive::(); - // converts it to serde_json type and then convert that into Rust type - let list = arrow::json::writer::record_batches_to_json_rows(&batches[..])? - .into_iter() - .map(|val| serde_json::from_value(serde_json::Value::Object(val))) - .take_while(|val| val.is_ok()) - .map(|val| val.unwrap()) - .collect(); + for (i, f) in int_col.values().iter().zip(float_col.values()) { + list.push(Data { + int_col: *i, + double_col: *f, + }) + } + } Ok(list) } diff --git a/datafusion-examples/examples/flight/flight_server.rs b/datafusion-examples/examples/flight/flight_server.rs index cb7b7c28d909..f9d1b8029f04 100644 --- a/datafusion-examples/examples/flight/flight_server.rs +++ b/datafusion-examples/examples/flight/flight_server.rs @@ -18,7 +18,7 @@ use arrow::ipc::writer::{DictionaryTracker, IpcDataGenerator}; use std::sync::Arc; -use arrow_flight::SchemaAsIpc; +use arrow_flight::{PollInfo, SchemaAsIpc}; use datafusion::arrow::error::ArrowError; use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::listing::{ListingOptions, ListingTableUrl}; @@ -177,6 +177,13 @@ impl FlightService for FlightServiceImpl { ) -> Result, Status> { Err(Status::unimplemented("Not yet implemented")) } + + async fn poll_flight_info( + &self, + _request: Request, + ) -> Result, Status> { + Err(Status::unimplemented("Not yet implemented")) + } } fn to_tonic_err(e: datafusion::error::DataFusionError) -> Status { diff --git a/datafusion-examples/examples/flight/flight_sql_server.rs b/datafusion-examples/examples/flight/flight_sql_server.rs index 35d475623062..ed9457643b7d 100644 --- a/datafusion-examples/examples/flight/flight_sql_server.rs +++ b/datafusion-examples/examples/flight/flight_sql_server.rs @@ -307,6 +307,8 @@ impl FlightSqlService for FlightSqlServiceImpl { let endpoint = FlightEndpoint { ticket: Some(ticket), location: vec![], + expiration_time: None, + app_metadata: Default::default(), }; let endpoints = vec![endpoint]; @@ -329,6 +331,7 @@ impl FlightSqlService for FlightSqlServiceImpl { total_records: -1_i64, total_bytes: -1_i64, ordered: 
false, + app_metadata: Default::default(), }; let resp = Response::new(info); Ok(resp) diff --git a/datafusion/common/src/file_options/parquet_writer.rs b/datafusion/common/src/file_options/parquet_writer.rs index e8a350e8d389..28e73ba48f53 100644 --- a/datafusion/common/src/file_options/parquet_writer.rs +++ b/datafusion/common/src/file_options/parquet_writer.rs @@ -156,6 +156,7 @@ pub(crate) fn parse_encoding_string( "plain" => Ok(parquet::basic::Encoding::PLAIN), "plain_dictionary" => Ok(parquet::basic::Encoding::PLAIN_DICTIONARY), "rle" => Ok(parquet::basic::Encoding::RLE), + #[allow(deprecated)] "bit_packed" => Ok(parquet::basic::Encoding::BIT_PACKED), "delta_binary_packed" => Ok(parquet::basic::Encoding::DELTA_BINARY_PACKED), "delta_length_byte_array" => { diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 316624175e1c..a2484e93e812 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1650,7 +1650,11 @@ impl ScalarValue { | DataType::Duration(_) | DataType::Union(_, _) | DataType::Map(_, _) - | DataType::RunEndEncoded(_, _) => { + | DataType::RunEndEncoded(_, _) + | DataType::Utf8View + | DataType::BinaryView + | DataType::ListView(_) + | DataType::LargeListView(_) => { return _internal_err!( "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, @@ -5769,7 +5773,7 @@ mod tests { let batch = RecordBatch::try_from_iter(vec![("s", arr as _)]).unwrap(); #[rustfmt::skip] - let expected = [ + let expected = [ "+---+", "| s |", "+---+", @@ -5803,7 +5807,7 @@ mod tests { &DataType::List(Arc::new(Field::new( "item", DataType::Timestamp(TimeUnit::Millisecond, Some(s.into())), - true + true, ))) ); } diff --git a/datafusion/core/src/datasource/avro_to_arrow/schema.rs b/datafusion/core/src/datasource/avro_to_arrow/schema.rs index 761e6b62680f..039a6aacc07e 100644 --- a/datafusion/core/src/datasource/avro_to_arrow/schema.rs +++ b/datafusion/core/src/datasource/avro_to_arrow/schema.rs @@ -224,6 +224,12 @@ fn default_field_name(dt: &DataType) -> &str { DataType::RunEndEncoded(_, _) => { unimplemented!("RunEndEncoded support not implemented") } + DataType::Utf8View + | DataType::BinaryView + | DataType::ListView(_) + | DataType::LargeListView(_) => { + unimplemented!("View support not implemented") + } DataType::Decimal128(_, _) => "decimal", DataType::Decimal256(_, _) => "decimal", } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index c04c536e7ca6..b7626d41f4dd 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -78,9 +78,6 @@ use hashbrown::HashMap; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; -/// Size of the buffer for [`AsyncArrowWriter`]. -const PARQUET_WRITER_BUFFER_SIZE: usize = 10485760; - /// Initial writing buffer size. Note this is just a size hint for efficiency. It /// will grow beyond the set value if needed. 
const INITIAL_BUFFER_BYTES: usize = 1048576; @@ -626,7 +623,6 @@ impl ParquetSink { let writer = AsyncArrowWriter::try_new( multipart_writer, self.get_writer_schema(), - PARQUET_WRITER_BUFFER_SIZE, Some(parquet_props), )?; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index a2e645cf3e72..282cd624d036 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -701,12 +701,8 @@ pub async fn plan_to_parquet( let (_, multipart_writer) = storeref.put_multipart(&file).await?; let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let mut writer = AsyncArrowWriter::try_new( - multipart_writer, - plan.schema(), - 10485760, - propclone, - )?; + let mut writer = + AsyncArrowWriter::try_new(multipart_writer, plan.schema(), propclone)?; while let Some(next_batch) = stream.next().await { let batch = next_batch?; writer.write(&batch).await?; diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 1f00f5bc3137..5d2719bf0365 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -18,16 +18,14 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::types::ArrowTemporalType; -use arrow::array::{Array, ArrayRef, ArrowNumericType, Float64Array, PrimitiveArray}; -use arrow::compute::cast; -use arrow::compute::kernels::temporal; +use arrow::array::{Array, ArrayRef, Float64Array}; +use arrow::compute::{binary, cast, date_part, DatePart}; use arrow::datatypes::DataType::{Date32, Date64, Float64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_timestamp_microsecond_array, + as_date32_array, as_date64_array, as_int32_array, as_timestamp_microsecond_array, as_timestamp_millisecond_array, as_timestamp_nanosecond_array, as_timestamp_second_array, }; @@ -78,46 +76,6 @@ impl DatePartFunc { } } -macro_rules! extract_date_part { - ($ARRAY: expr, $FN:expr) => { - match $ARRAY.data_type() { - DataType::Date32 => { - let array = as_date32_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - DataType::Date64 => { - let array = as_date64_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - DataType::Timestamp(time_unit, _) => match time_unit { - TimeUnit::Second => { - let array = as_timestamp_second_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Millisecond => { - let array = as_timestamp_millisecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Microsecond => { - let array = as_timestamp_microsecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) - } - TimeUnit::Nanosecond => { - let array = as_timestamp_nanosecond_array($ARRAY)?; - Ok($FN(array) - .map(|v| cast(&(Arc::new(v) as ArrayRef), &DataType::Float64))?) 
- } - }, - datatype => exec_err!("Extract does not support datatype {:?}", datatype), - } - }; -} - impl ScalarUDFImpl for DatePartFunc { fn as_any(&self) -> &dyn Any { self @@ -139,16 +97,15 @@ impl ScalarUDFImpl for DatePartFunc { if args.len() != 2 { return exec_err!("Expected two arguments in DATE_PART"); } - let (date_part, array) = (&args[0], &args[1]); + let (part, array) = (&args[0], &args[1]); - let date_part = - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = date_part { - v - } else { - return exec_err!( - "First argument of `DATE_PART` must be non-null scalar Utf8" - ); - }; + let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { + v + } else { + return exec_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ); + }; let is_scalar = matches!(array, ColumnarValue::Scalar(_)); @@ -157,28 +114,28 @@ impl ScalarUDFImpl for DatePartFunc { ColumnarValue::Scalar(scalar) => scalar.to_array()?, }; - let arr = match date_part.to_lowercase().as_str() { - "year" => extract_date_part!(&array, temporal::year), - "quarter" => extract_date_part!(&array, temporal::quarter), - "month" => extract_date_part!(&array, temporal::month), - "week" => extract_date_part!(&array, temporal::week), - "day" => extract_date_part!(&array, temporal::day), - "doy" => extract_date_part!(&array, temporal::doy), - "dow" => extract_date_part!(&array, temporal::num_days_from_sunday), - "hour" => extract_date_part!(&array, temporal::hour), - "minute" => extract_date_part!(&array, temporal::minute), - "second" => extract_date_part!(&array, seconds), - "millisecond" => extract_date_part!(&array, millis), - "microsecond" => extract_date_part!(&array, micros), - "nanosecond" => extract_date_part!(&array, nanos), - "epoch" => extract_date_part!(&array, epoch), - _ => exec_err!("Date part '{date_part}' not supported"), - }?; + let arr = match part.to_lowercase().as_str() { + "year" => date_part_f64(array.as_ref(), DatePart::Year)?, + "quarter" => date_part_f64(array.as_ref(), DatePart::Quarter)?, + "month" => date_part_f64(array.as_ref(), DatePart::Month)?, + "week" => date_part_f64(array.as_ref(), DatePart::Week)?, + "day" => date_part_f64(array.as_ref(), DatePart::Day)?, + "doy" => date_part_f64(array.as_ref(), DatePart::DayOfYear)?, + "dow" => date_part_f64(array.as_ref(), DatePart::DayOfWeekSunday0)?, + "hour" => date_part_f64(array.as_ref(), DatePart::Hour)?, + "minute" => date_part_f64(array.as_ref(), DatePart::Minute)?, + "second" => seconds(array.as_ref(), Second)?, + "millisecond" => seconds(array.as_ref(), Millisecond)?, + "microsecond" => seconds(array.as_ref(), Microsecond)?, + "nanosecond" => seconds(array.as_ref(), Nanosecond)?, + "epoch" => epoch(array.as_ref())?, + _ => return exec_err!("Date part '{part}' not supported"), + }; Ok(if is_scalar { - ColumnarValue::Scalar(ScalarValue::try_from_array(&arr?, 0)?) + ColumnarValue::Scalar(ScalarValue::try_from_array(arr.as_ref(), 0)?) } else { - ColumnarValue::Array(arr?) + ColumnarValue::Array(arr) }) } @@ -187,83 +144,52 @@ impl ScalarUDFImpl for DatePartFunc { } } -fn to_ticks(array: &PrimitiveArray, frac: i32) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - let zipped = temporal::second(array)? 
- .values() - .iter() - .zip(temporal::nanosecond(array)?.values().iter()) - .map(|o| (*o.0 as f64 + (*o.1 as f64) / 1_000_000_000.0) * (frac as f64)) - .collect::>(); - - Ok(Float64Array::from(zipped)) +/// Invoke [`date_part`] and cast the result to Float64 +fn date_part_f64(array: &dyn Array, part: DatePart) -> Result { + Ok(cast(date_part(array, part)?.as_ref(), &Float64)?) } -fn seconds(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1) -} - -fn millis(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000) -} - -fn micros(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000_000) +/// invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the +/// result to a total number of seconds, milliseconds, microseconds or +/// nanoseconds +/// +/// # Panics +/// If `array` is not a temporal type such as Timestamp or Date32 +fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { + let sf = match unit { + Second => 1_f64, + Millisecond => 1_000_f64, + Microsecond => 1_000_000_f64, + Nanosecond => 1_000_000_000_f64, + }; + let secs = date_part(array, DatePart::Second)?; + let secs = as_int32_array(secs.as_ref())?; + let subsecs = date_part(array, DatePart::Nanosecond)?; + let subsecs = as_int32_array(subsecs.as_ref())?; + + let r: Float64Array = binary(secs, subsecs, |secs, subsecs| { + (secs as f64 + (subsecs as f64 / 1_000_000_000_f64)) * sf + })?; + Ok(Arc::new(r)) } -fn nanos(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - to_ticks(array, 1_000_000_000) -} +fn epoch(array: &dyn Array) -> Result { + const SECONDS_IN_A_DAY: f64 = 86400_f64; -fn epoch(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - let b = match array.data_type() { - Timestamp(tu, _) => { - let scale = match tu { - Second => 1, - Millisecond => 1_000, - Microsecond => 1_000_000, - Nanosecond => 1_000_000_000, - } as f64; - array.unary(|n| { - let n: i64 = n.into(); - n as f64 / scale - }) + let f: Float64Array = match array.data_type() { + Timestamp(Second, _) => as_timestamp_second_array(array)?.unary(|x| x as f64), + Timestamp(Millisecond, _) => { + as_timestamp_millisecond_array(array)?.unary(|x| x as f64 / 1_000_f64) + } + Timestamp(Microsecond, _) => { + as_timestamp_microsecond_array(array)?.unary(|x| x as f64 / 1_000_000_f64) } - Date32 => { - let seconds_in_a_day = 86400_f64; - array.unary(|n| { - let n: i64 = n.into(); - n as f64 * seconds_in_a_day - }) + Timestamp(Nanosecond, _) => { + as_timestamp_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) } - Date64 => array.unary(|n| { - let n: i64 = n.into(); - n as f64 / 1_000_f64 - }), - _ => return exec_err!("Can not convert {:?} to epoch", array.data_type()), + Date32 => as_date32_array(array)?.unary(|x| x as f64 * SECONDS_IN_A_DAY), + Date64 => as_date64_array(array)?.unary(|x| x as f64 / 1_000_f64), + d => return exec_err!("Can not convert {d:?} to epoch"), }; - Ok(b) + Ok(Arc::new(f)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 478f7c779552..92015594906b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -230,6 +230,9 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { "Proto 
serialization error: The RunEndEncoded data type is not yet supported".to_owned() )) } + DataType::Utf8View | DataType::BinaryView | DataType::ListView(_) | DataType::LargeListView(_) => { + return Err(Error::General(format!("Proto serialization error: {val} not yet supported"))) + } }; Ok(res) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 9680177d736f..c26e8481ce43 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -491,11 +491,15 @@ impl Unparser<'_> { DataType::Binary => todo!(), DataType::FixedSizeBinary(_) => todo!(), DataType::LargeBinary => todo!(), + DataType::BinaryView => todo!(), DataType::Utf8 => Ok(ast::DataType::Varchar(None)), DataType::LargeUtf8 => Ok(ast::DataType::Text), + DataType::Utf8View => todo!(), DataType::List(_) => todo!(), DataType::FixedSizeList(_, _) => todo!(), DataType::LargeList(_) => todo!(), + DataType::ListView(_) => todo!(), + DataType::LargeListView(_) => todo!(), DataType::Struct(_) => todo!(), DataType::Union(_, _) => todo!(), DataType::Dictionary(_, _) => todo!(), From 7fab5ac53c1e715743aee7a51111c2976add8a99 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Wed, 20 Mar 2024 00:58:10 +0800 Subject: [PATCH 018/117] Move inlist rule to expr_simplifier (#9692) * move inlist rule to expr_simplifier Signed-off-by: jayzhan211 * clippy Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- .../simplify_expressions/expr_simplifier.rs | 220 +++++++++++++++++- .../simplify_expressions/inlist_simplifier.rs | 122 +--------- .../sqllogictest/test_files/predicates.slt | 2 +- 3 files changed, 210 insertions(+), 134 deletions(-) diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5b5bca75ddb0..61e002ece98b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -21,7 +21,7 @@ use std::borrow::Cow; use std::collections::HashSet; use std::ops::Not; -use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; +use super::inlist_simplifier::ShortenInListSimplifier; use super::utils::*; use crate::analyzer::type_coercion::TypeCoercionRewriter; use crate::simplify_expressions::guarantees::GuaranteeRewriter; @@ -175,7 +175,6 @@ impl ExprSimplifier { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); - let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); if self.canonicalize { @@ -190,8 +189,6 @@ impl ExprSimplifier { .data()? .rewrite(&mut simplifier) .data()? - .rewrite(&mut inlist_simplifier) - .data()? .rewrite(&mut guarantee_rewriter) .data()? 
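        // For orientation: these chained rewrite passes are what sit behind
        // the public simplification entry point. A minimal, hedged sketch of
        // driving it from user code (assuming the `ExprSimplifier` /
        // `SimplifyContext` API and a `DFSchemaRef` named `df_schema`):
        //
        //     let props = ExecutionProps::new();
        //     let context = SimplifyContext::new(&props).with_schema(df_schema);
        //     let simplified = ExprSimplifier::new(context).simplify(expr)?;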
        // run both passes twice to try and minimize simplifications that we missed
@@ -1452,13 +1449,8 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 op: Operator::Or,
                 right,
             }) if are_inlist_and_eq(left.as_ref(), right.as_ref()) => {
-                let left = as_inlist(left.as_ref());
-                let right = as_inlist(right.as_ref());
-
-                let lhs = left.unwrap();
-                let rhs = right.unwrap();
-                let lhs = lhs.into_owned();
-                let rhs = rhs.into_owned();
+                let lhs = to_inlist(*left).unwrap();
+                let rhs = to_inlist(*right).unwrap();
                 let mut seen: HashSet<Expr> = HashSet::new();
                 let list = lhs
                     .list
@@ -1473,7 +1465,123 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                     negated: false,
                 };
-                return Ok(Transformed::yes(Expr::InList(merged_inlist)));
+                Transformed::yes(Expr::InList(merged_inlist))
+            }
+
+            // Simplify expressions that are guaranteed to be true or false to a literal boolean expression
+            //
+            // Rules:
+            // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists
+            // Intersection:
+            // 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false`
+            // 2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)`
+            // 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)`
+            // Union:
+            // 4. `a not in (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)`
+            // # This rule is handled by `or_in_list_simplifier.rs`
+            // 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)`
+            // If one of the expressions is `IN` and the other is `NOT IN`, then we apply the set difference to the `IN` expression
+            // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false`
+            // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5`
+            // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)`
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                false,
+                false,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_intersection(l1, l2, false).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                true,
+                true,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_union(l1, l2, true).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                false,
+                true,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_except(l1, l2).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::And,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                true,
+                false,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_except(l2, l1).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
+            Expr::BinaryExpr(BinaryExpr {
+                left,
+                op: Operator::Or,
+                right,
+            }) if are_inlist_and_eq_and_match_neg(
+                left.as_ref(),
+                right.as_ref(),
+                true,
+                true,
+            ) =>
+            {
+                match (*left, *right) {
+                    (Expr::InList(l1), Expr::InList(l2)) => {
+                        return inlist_intersection(l1, l2, true).map(Transformed::yes);
+                    }
+                    // Matched previously once
+                    _ => unreachable!(),
+                }
+            }
+
             // no additional rewrites possible
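For a concrete feel of rules 6 and 8 above: in `a IN (1,2,3,4) AND a NOT IN
(1,2,3,4,5)` every element of the `IN` list is excluded, so the predicate
folds to `false`, while in `a IN (1,2,3,4) AND a NOT IN (5,6,7,8)` nothing is
excluded and the `IN` list survives unchanged. A hedged sketch of triggering
rule 8 through the expression API (assuming `col`/`lit` and `Expr::in_list`
from `datafusion_expr`):

    let expr = col("a")
        .in_list(vec![lit(1), lit(2), lit(3), lit(4)], false)
        .and(col("a").in_list(vec![lit(5), lit(6), lit(7), lit(8)], true));
    // after simplification this reduces to: a IN (1, 2, 3, 4)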
@@ -1482,6 +1590,22 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
     }
 }
 
+// TODO: We might not need this after deref patterns for Box are stabilized. https://github.com/rust-lang/rust/issues/87121
+fn are_inlist_and_eq_and_match_neg(
+    left: &Expr,
+    right: &Expr,
+    is_left_neg: bool,
+    is_right_neg: bool,
+) -> bool {
+    match (left, right) {
+        (Expr::InList(l), Expr::InList(r)) => {
+            l.expr == r.expr && l.negated == is_left_neg && r.negated == is_right_neg
+        }
+        _ => false,
+    }
+}
+
+// TODO: We might not need this after deref patterns for Box are stabilized. https://github.com/rust-lang/rust/issues/87121
 fn are_inlist_and_eq(left: &Expr, right: &Expr) -> bool {
     let left = as_inlist(left);
     let right = as_inlist(right);
@@ -1519,6 +1643,78 @@ fn as_inlist(expr: &Expr) -> Option<Cow<InList>> {
     }
 }
 
+fn to_inlist(expr: Expr) -> Option<InList> {
+    match expr {
+        Expr::InList(inlist) => Some(inlist),
+        Expr::BinaryExpr(BinaryExpr {
+            left,
+            op: Operator::Eq,
+            right,
+        }) => match (left.as_ref(), right.as_ref()) {
+            (Expr::Column(_), Expr::Literal(_)) => Some(InList {
+                expr: left,
+                list: vec![*right],
+                negated: false,
+            }),
+            (Expr::Literal(_), Expr::Column(_)) => Some(InList {
+                expr: right,
+                list: vec![*left],
+                negated: false,
+            }),
+            _ => None,
+        },
+        _ => None,
+    }
+}
+
+/// Return the union of two inlist expressions
+/// maintaining the order of the elements in the two lists
+fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
+    // extend the list in l1 with the elements in l2 that are not already in l1
+    let l1_items: HashSet<_> = l1.list.iter().collect();
+
+    // keep all l2 items that do not also appear in l1
+    let keep_l2: Vec<_> = l2
+        .list
+        .into_iter()
+        .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) })
+        .collect();
+
+    l1.list.extend(keep_l2);
+    l1.negated = negated;
+    Ok(Expr::InList(l1))
+}
+
+/// Return the intersection of two inlist expressions
+/// maintaining the order of the elements in the two lists
+fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result<Expr> {
+    let l2_items = l2.list.iter().collect::<HashSet<_>>();
+
+    // remove all items from l1 that are not in l2
+    l1.list.retain(|e| l2_items.contains(e));
+
+    // e in () is always false
+    // e not in () is always true
+    if l1.list.is_empty() {
+        return Ok(lit(negated));
+    }
+    Ok(Expr::InList(l1))
+}
+
+/// Return all the items in l1 that are not in l2,
+/// maintaining the order of the elements in the two lists
+fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
+    let l2_items = l2.list.iter().collect::<HashSet<_>>();
+
+    // keep only items from l1 that are not in l2
+    l1.list.retain(|e| !l2_items.contains(e));
+
+    if l1.list.is_empty() {
+        return Ok(lit(false));
+    }
+    Ok(Expr::InList(l1))
+}
+
 #[cfg(test)]
 mod tests {
     use std::{
diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
index 5d1cf27827a9..9dcb8ed15563 100644
--- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs
@@ -19,12 +19,10 @@
 
 use super::THRESHOLD_INLINE_INLIST;
 
-use std::collections::HashSet;
-
 use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
 use datafusion_common::Result;
 use datafusion_expr::expr::InList;
-use datafusion_expr::{lit, BinaryExpr, Expr, Operator};
+use datafusion_expr::Expr;
 
 pub(super) struct
ShortenInListSimplifier {} @@ -97,121 +95,3 @@ impl TreeNodeRewriter for ShortenInListSimplifier { Ok(Transformed::no(expr)) } } - -pub(super) struct InListSimplifier {} - -impl InListSimplifier { - pub(super) fn new() -> Self { - Self {} - } -} - -impl TreeNodeRewriter for InListSimplifier { - type Node = Expr; - - fn f_up(&mut self, expr: Expr) -> Result> { - // Simplify expressions that is guaranteed to be true or false to a literal boolean expression - // - // Rules: - // If both expressions are `IN` or `NOT IN`, then we can apply intersection or union on both lists - // Intersection: - // 1. `a in (1,2,3) AND a in (4,5) -> a in (), which is false` - // 2. `a in (1,2,3) AND a in (2,3,4) -> a in (2,3)` - // 3. `a not in (1,2,3) OR a not in (3,4,5,6) -> a not in (3)` - // Union: - // 4. `a not int (1,2,3) AND a not in (4,5,6) -> a not in (1,2,3,4,5,6)` - // # This rule is handled by `or_in_list_simplifier.rs` - // 5. `a in (1,2,3) OR a in (4,5,6) -> a in (1,2,3,4,5,6)` - // If one of the expressions is `IN` and another one is `NOT IN`, then we apply exception on `In` expression - // 6. `a in (1,2,3,4) AND a not in (1,2,3,4,5) -> a in (), which is false` - // 7. `a not in (1,2,3,4) AND a in (1,2,3,4,5) -> a = 5` - // 8. `a in (1,2,3,4) AND a not in (5,6,7,8) -> a in (1,2,3,4)` - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr.clone() { - match (*left, op, *right) { - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && !l1.negated && !l2.negated => - { - return inlist_intersection(l1, l2, false).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && l2.negated => - { - return inlist_union(l1, l2, true).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && !l1.negated && l2.negated => - { - return inlist_except(l1, l2).map(Transformed::yes); - } - (Expr::InList(l1), Operator::And, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && !l2.negated => - { - return inlist_except(l2, l1).map(Transformed::yes); - } - (Expr::InList(l1), Operator::Or, Expr::InList(l2)) - if l1.expr == l2.expr && l1.negated && l2.negated => - { - return inlist_intersection(l1, l2, true).map(Transformed::yes); - } - (left, op, right) => { - // put the expression back together - return Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { - left: Box::new(left), - op, - right: Box::new(right), - }))); - } - } - } - - Ok(Transformed::no(expr)) - } -} - -/// Return the union of two inlist expressions -/// maintaining the order of the elements in the two lists -fn inlist_union(mut l1: InList, l2: InList, negated: bool) -> Result { - // extend the list in l1 with the elements in l2 that are not already in l1 - let l1_items: HashSet<_> = l1.list.iter().collect(); - - // keep all l2 items that do not also appear in l1 - let keep_l2: Vec<_> = l2 - .list - .into_iter() - .filter_map(|e| if l1_items.contains(&e) { None } else { Some(e) }) - .collect(); - - l1.list.extend(keep_l2); - l1.negated = negated; - Ok(Expr::InList(l1)) -} - -/// Return the intersection of two inlist expressions -/// maintaining the order of the elements in the two lists -fn inlist_intersection(mut l1: InList, l2: InList, negated: bool) -> Result { - let l2_items = l2.list.iter().collect::>(); - - // remove all items from l1 that are not in l2 - l1.list.retain(|e| l2_items.contains(e)); - - // e in () is always false - // e not in () is always true - if l1.list.is_empty() { - return 
Ok(lit(negated)); - } - Ok(Expr::InList(l1)) -} - -/// Return the all items in l1 that are not in l2 -/// maintaining the order of the elements in the two lists -fn inlist_except(mut l1: InList, l2: InList) -> Result { - let l2_items = l2.list.iter().collect::>(); - - // keep only items from l1 that are not in l2 - l1.list.retain(|e| !l2_items.contains(e)); - - if l1.list.is_empty() { - return Ok(lit(false)); - } - Ok(Expr::InList(l1)) -} diff --git a/datafusion/sqllogictest/test_files/predicates.slt b/datafusion/sqllogictest/test_files/predicates.slt index 4c9254beef6b..33c9ff7c3eed 100644 --- a/datafusion/sqllogictest/test_files/predicates.slt +++ b/datafusion/sqllogictest/test_files/predicates.slt @@ -781,4 +781,4 @@ logical_plan EmptyRelation physical_plan EmptyExec statement ok -drop table t; +drop table t; \ No newline at end of file From 09747596fd75bfce8903e86472cccb8acc524453 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 19 Mar 2024 11:49:36 -0600 Subject: [PATCH 019/117] Support Serde for ScalarUDF in Physical Expressions (#9436) * initial try * revert * stage commit * use ScalarFunctionDefinition to rewrite PhysicalExpr proto * cargo fmt * feat : add test * fix bug * fix wrong delete code when resolve conflict * Update datafusion/proto/src/physical_plan/to_proto.rs Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> * Update datafusion/proto/tests/cases/roundtrip_physical_plan.rs Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> * address the comment --------- Co-authored-by: Dan Harris <1327726+thinkharderdev@users.noreply.github.com> --- .../physical_optimizer/projection_pushdown.rs | 58 ++- datafusion/physical-expr/src/functions.rs | 10 +- .../physical-expr/src/scalar_function.rs | 26 +- datafusion/physical-expr/src/udf.rs | 7 +- datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 21 + datafusion/proto/src/generated/prost.rs | 2 + .../proto/src/physical_plan/from_proto.rs | 44 +- datafusion/proto/src/physical_plan/mod.rs | 139 ++++-- .../proto/src/physical_plan/to_proto.rs | 433 ++++++++++-------- .../tests/cases/roundtrip_physical_plan.rs | 157 ++++++- 11 files changed, 634 insertions(+), 264 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index ab5611597472..ed445e6d48b8 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1287,6 +1287,7 @@ fn new_join_children( #[cfg(test)] mod tests { use super::*; + use std::any::Any; use std::sync::Arc; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -1313,7 +1314,10 @@ mod tests { use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_expr::{ + ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, + }; use datafusion_physical_expr::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; @@ -1329,6 +1333,42 @@ mod tests { use itertools::Itertools; + /// Mocked UDF + #[derive(Debug)] + struct DummyUDF { + signature: Signature, + } + + impl DummyUDF { + fn new() -> Self { + Self { + signature: 
Signature::variadic_any(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } + } + #[test] fn test_update_matching_exprs() -> Result<()> { let exprs: Vec> = vec![ @@ -1345,7 +1385,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1412,7 +1454,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1482,7 +1526,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1549,7 +1595,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b_new", 1)), diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index c6c185e002f0..e76e7f56dc95 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -44,6 +44,7 @@ use arrow_array::Array; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; pub use datafusion_expr::FuncMonotonicity; +use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{ type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, @@ -57,7 +58,7 @@ pub fn create_physical_expr( fun: &BuiltinScalarFunction, input_phy_exprs: &[Arc], input_schema: &Schema, - execution_props: &ExecutionProps, + _execution_props: &ExecutionProps, ) -> Result> { let input_expr_types = input_phy_exprs .iter() @@ -69,14 +70,12 @@ pub fn create_physical_expr( let data_type = fun.return_type(&input_expr_types)?; - let fun_expr: ScalarFunctionImplementation = - create_physical_fun(fun, execution_props)?; - let monotonicity = fun.monotonicity(); + let fun_def = ScalarFunctionDefinition::BuiltIn(*fun); Ok(Arc::new(ScalarFunctionExpr::new( &format!("{fun}"), - fun_expr, + fun_def, input_phy_exprs.to_vec(), data_type, monotonicity, @@ -195,7 +194,6 @@ where /// Create a physical scalar function. 
pub fn create_physical_fun( fun: &BuiltinScalarFunction, - _execution_props: &ExecutionProps, ) -> Result { Ok(match fun { // math functions diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 1c9f0e609c3c..d34084236690 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,22 +34,22 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::functions::out_ordering; +use crate::functions::{create_physical_fun, out_ordering}; use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use datafusion_expr::{ expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, - ScalarFunctionImplementation, + ScalarFunctionDefinition, }; /// Physical expression of a scalar function pub struct ScalarFunctionExpr { - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, name: String, args: Vec>, return_type: DataType, @@ -79,7 +79,7 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, args: Vec>, return_type: DataType, monotonicity: Option, @@ -96,7 +96,7 @@ impl ScalarFunctionExpr { } /// Get the scalar function implementation - pub fn fun(&self) -> &ScalarFunctionImplementation { + pub fn fun(&self) -> &ScalarFunctionDefinition { &self.fun } @@ -172,8 +172,18 @@ impl PhysicalExpr for ScalarFunctionExpr { }; // evaluate the function - let fun = self.fun.as_ref(); - (fun)(&inputs) + match self.fun { + ScalarFunctionDefinition::BuiltIn(ref fun) => { + let fun = create_physical_fun(fun)?; + (fun)(&inputs) + } + ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs), + ScalarFunctionDefinition::Name(_) => { + internal_err!( + "Name function must be resolved to one of the other variants prior to physical planning" + ) + } + } } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index ede3e5badbb1..4fc94bfa15ec 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -20,7 +20,9 @@ use crate::{PhysicalExpr, ScalarFunctionExpr}; use arrow_schema::Schema; use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; -use datafusion_expr::{type_coercion::functions::data_types, Expr}; +use datafusion_expr::{ + type_coercion::functions::data_types, Expr, ScalarFunctionDefinition, +}; use std::sync::Arc; /// Create a physical expression of the UDF. 
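The hunk below completes the picture for user-defined functions: instead of
handing the physical expression a bare `ScalarFunctionImplementation`, the
planner now wraps the whole `ScalarUDF` in a `ScalarFunctionDefinition`, and
`ScalarFunctionExpr::evaluate` dispatches on that definition at runtime. The
essence of the change, directly mirroring this patch (`fun` is a
`&ScalarUDF`):

    // wrap the udf itself so the physical expr can invoke it at evaluate time
    let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone()));

This is also what makes the proto round-trip in the later hunks possible: a
`ScalarUDF`, unlike a function pointer, carries its name and can be
re-resolved from a registry or decoded from bytes on the other side.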
@@ -45,9 +47,10 @@ pub fn create_physical_expr( let return_type = fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), - fun.fun(), + fun_def, input_phy_exprs.to_vec(), return_type, fun.monotonicity()?, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 6879f70cd05c..7a9b427ce7d3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1458,6 +1458,7 @@ message PhysicalExprNode { message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; + optional bytes fun_definition = 3; ArrowType return_type = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 75c135fd01b4..fd27520b3be0 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20391,6 +20391,9 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } if self.return_type.is_some() { len += 1; } @@ -20401,6 +20404,10 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.return_type.as_ref() { struct_ser.serialize_field("returnType", v)?; } @@ -20416,6 +20423,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { const FIELDS: &[&str] = &[ "name", "args", + "fun_definition", + "funDefinition", "return_type", "returnType", ]; @@ -20424,6 +20433,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { enum GeneratedField { Name, Args, + FunDefinition, ReturnType, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -20448,6 +20458,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { match value { "name" => Ok(GeneratedField::Name), "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "returnType" | "return_type" => Ok(GeneratedField::ReturnType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -20470,6 +20481,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { let mut name__ = None; let mut args__ = None; + let mut fun_definition__ = None; let mut return_type__ = None; while let Some(k) = map_.next_key()? 
{ match k { @@ -20485,6 +20497,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } args__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::ReturnType => { if return_type__.is_some() { return Err(serde::de::Error::duplicate_field("returnType")); @@ -20496,6 +20516,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { Ok(PhysicalScalarUdfNode { name: name__.unwrap_or_default(), args: args__.unwrap_or_default(), + fun_definition: fun_definition__, return_type: return_type__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c9cc4a9b073b..16ad2b848db9 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2092,6 +2092,8 @@ pub struct PhysicalScalarUdfNode { pub name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(message, optional, tag = "4")] pub return_type: ::core::option::Option, } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 184c048c1bdd..ca54d4e803ca 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -59,9 +59,12 @@ use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; use chrono::{TimeZone, Utc}; +use datafusion_expr::ScalarFunctionDefinition; use object_store::path::Path; use object_store::ObjectMeta; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { Column::new(&c.name, c.index as usize) @@ -82,7 +85,8 @@ pub fn parse_physical_sort_expr( input_schema: &Schema, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -110,7 +114,9 @@ pub fn parse_physical_sort_exprs( .iter() .map(|sort_expr| { if let Some(expr) = &sort_expr.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = + parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; let options = SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -137,16 +143,17 @@ pub fn parse_physical_window_expr( registry: &dyn FunctionRegistry, input_schema: &Schema, ) -> Result> { + let codec = DefaultPhysicalExtensionCodec {}; let window_node_expr = proto .args .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) + .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .collect::>>()?; let partition_by = proto .partition_by .iter() - .map(|p| parse_physical_expr(p, registry, input_schema)) + .map(|p| parse_physical_expr(p, registry, input_schema, &codec)) .collect::>>()?; let order_by = proto @@ -191,6 +198,7 
@@ pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { let expr_type = proto .expr_type @@ -270,7 +278,7 @@ pub fn parse_physical_expr( )?, e.list .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?, &e.negated, input_schema, @@ -278,7 +286,7 @@ pub fn parse_physical_expr( ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, e.when_then_expr .iter() @@ -301,7 +309,7 @@ pub fn parse_physical_expr( .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( @@ -334,7 +342,7 @@ pub fn parse_physical_expr( let args = e .args .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?; // TODO Do not create new the ExecutionProps @@ -348,19 +356,22 @@ pub fn parse_physical_expr( )? } ExprType::ScalarUdf(e) => { - let udf = registry.udf(e.name.as_str())?; + let udf = match &e.fun_definition { + Some(buf) => codec.try_decode_udf(&e.name, buf)?, + None => registry.udf(e.name.as_str())?, + }; let signature = udf.signature(); - let scalar_fun = udf.fun().clone(); + let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone()); let args = e .args .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?; Arc::new(ScalarFunctionExpr::new( e.name.as_str(), - scalar_fun, + scalar_fun_def, args, convert_required!(e.return_type)?, None, @@ -394,7 +405,8 @@ fn parse_required_physical_expr( field: &str, input_schema: &Schema, ) -> Result> { - expr.map(|e| parse_physical_expr(e, registry, input_schema)) + let codec = DefaultPhysicalExtensionCodec {}; + expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .transpose()? 
.ok_or_else(|| { DataFusionError::Internal(format!("Missing required field {field:?}")) @@ -439,10 +451,11 @@ pub fn parse_protobuf_hash_partitioning( ) -> Result> { match partitioning { Some(hash_part) => { + let codec = DefaultPhysicalExtensionCodec {}; let expr = hash_part .hash_expr .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) + .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .collect::>, _>>()?; Ok(Some(Partitioning::Hash( @@ -503,6 +516,7 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { + let codec = DefaultPhysicalExtensionCodec {}; let sort_expr = node_collection .physical_sort_expr_nodes .iter() @@ -510,7 +524,7 @@ pub fn parse_protobuf_file_scan_config( let expr = node .expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, &schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, &schema, &codec)) .unwrap()?; Ok(PhysicalSortExpr { expr, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 004948da938f..da31c5e762bc 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -20,6 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use self::from_proto::parse_physical_window_expr; +use self::to_proto::serialize_physical_expr; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::convert_required; @@ -138,7 +139,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr(expr, registry, input.schema().as_ref())?, + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + )?, name.to_string(), )) }) @@ -156,7 +162,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .transpose()? 
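        // The `extension_codec` threaded through all of these call sites is
        // what lets user-defined functions survive a round trip: when a
        // serialized expression carries a `fun_definition` payload, decoding
        // goes through `PhysicalExtensionCodec::try_decode_udf` instead of
        // the function registry. A hedged sketch of a custom codec (the
        // `MyUdfCodec` name and body are illustrative, not part of this
        // patch; the remaining trait methods are elided):
        //
        //     #[derive(Debug)]
        //     struct MyUdfCodec;
        //
        //     impl PhysicalExtensionCodec for MyUdfCodec {
        //         fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result<Arc<ScalarUDF>> {
        //             match name {
        //                 "dummy_udf" => Ok(Arc::new(ScalarUDF::new_from_impl(DummyUDF::new()))),
        //                 _ => internal_err!("unknown udf: {name}"),
        //             }
        //         }
        //         // try_encode_udf writes any state the udf needs into the buffer
        //     }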
.ok_or_else(|| { @@ -208,6 +219,7 @@ impl AsExecutionPlan for PhysicalPlanNode { expr, registry, base_config.file_schema.as_ref(), + extension_codec, ) }) .transpose()?; @@ -254,7 +266,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .hash_expr .iter() .map(|e| { - parse_physical_expr(e, registry, input.schema().as_ref()) + parse_physical_expr( + e, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>, _>>()?; @@ -329,7 +346,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>>>()?; @@ -396,8 +418,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -406,8 +433,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -434,7 +466,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| { expr.expr .as_ref() - .map(|e| parse_physical_expr(e, registry, &physical_schema)) + .map(|e| { + parse_physical_expr( + e, + registry, + &physical_schema, + extension_codec, + ) + }) .transpose() }) .collect::, _>>()?; @@ -451,7 +490,7 @@ impl AsExecutionPlan for PhysicalPlanNode { match expr_type { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() - .map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect(); + .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); let ordering_req: Vec = agg_node.ordering_req.iter() .map(|e| parse_physical_sort_expr(e, registry, &physical_schema).unwrap()).collect(); agg_node.aggregate_function.as_ref().map(|func| { @@ -524,11 +563,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -555,6 +596,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -635,11 +677,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -666,6 +710,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -805,7 +850,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? 
.as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -852,7 +897,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -916,6 +961,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -1088,7 +1134,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let expr = exec .expr() .iter() - .map(|expr| expr.0.clone().try_into()) + .map(|expr| serialize_physical_expr(expr.0.clone(), extension_codec)) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); return Ok(protobuf::PhysicalPlanNode { @@ -1128,7 +1174,10 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(exec.predicate().clone().try_into()?), + expr: Some(serialize_physical_expr( + exec.predicate().clone(), + extension_codec, + )?), default_filter_selectivity: exec.default_selectivity() as u32, }, ))), @@ -1183,8 +1232,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1196,7 +1245,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1254,8 +1306,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1267,7 +1319,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1304,7 +1359,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1321,7 +1379,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: 
Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1423,14 +1484,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .group_expr() .null_expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1512,7 +1573,7 @@ impl AsExecutionPlan for PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() - .map(|pred| pred.clone().try_into()) + .map(|pred| serialize_physical_expr(pred.clone(), extension_codec)) .transpose()?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( @@ -1559,7 +1620,9 @@ impl AsExecutionPlan for PhysicalPlanNode { PartitionMethod::Hash(protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() - .map(|expr| expr.clone().try_into()) + .map(|expr| { + serialize_physical_expr(expr.clone(), extension_codec) + }) .collect::>>()?, partition_count: *partition_count as u64, }) @@ -1592,7 +1655,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1658,7 +1724,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1695,7 +1764,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1743,7 +1815,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1773,7 +1845,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -1816,7 +1888,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ba77b30b7f8d..b66709d0c5bd 100644 --- 
a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,7 +22,6 @@ use std::{ sync::Arc, }; -use crate::logical_plan::csv_writer_options_to_proto; use crate::protobuf::{ self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode, @@ -31,13 +30,10 @@ use crate::protobuf::{ #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion::datasource::{ - file_format::csv::CsvSink, - file_format::json::JsonSink, - listing::{FileRange, PartitionedFile}, - physical_plan::FileScanConfig, - physical_plan::FileSinkConfig, -}; + +use datafusion_expr::ScalarFunctionDefinition; + +use crate::logical_plan::csv_writer_options_to_proto; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; @@ -46,16 +42,24 @@ use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, LastValue, LikeExpr, Literal, Max, Median, - Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, - Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, - TryCastExpr, Variance, VariancePop, WindowShift, + InListExpr, IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, + NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, + RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, + Variance, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion::{ + datasource::{ + file_format::{csv::CsvSink, json::JsonSink}, + listing::{FileRange, PartitionedFile}, + physical_plan::{FileScanConfig, FileSinkConfig}, + }, + physical_plan::expressions::LikeExpr, +}; use datafusion_common::config::{ ColumnOptions, CsvOptions, FormatOptions, JsonOptions, ParquetOptions, TableParquetOptions, @@ -68,14 +72,17 @@ use datafusion_common::{ DataFusionError, JoinSide, Result, }; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; fn try_from(a: Arc) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let expressions: Vec = a .expressions() .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), &codec)) .collect::>>()?; let ordering_req: Vec = a @@ -237,16 +244,16 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - + let codec = DefaultPhysicalExtensionCodec {}; let args = args .into_iter() - .map(|e| e.try_into()) + .map(|e| serialize_physical_expr(e, &codec)) .collect::>>()?; let partition_by = window_expr .partition_by() .iter() - .map(|p| p.clone().try_into()) + .map(|p| serialize_physical_expr(p.clone(), &codec)) .collect::>>()?; let order_by = window_expr @@ -374,195 +381,250 @@ fn 
aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { Ok(AggrFn { inner, distinct }) } -impl TryFrom> for protobuf::PhysicalExprNode { - type Error = DataFusionError; - - fn try_from(value: Arc) -> Result { - let expr = value.as_any(); - - if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Column( - protobuf::PhysicalColumn { - name: expr.name().to_string(), - index: expr.index() as u32, - }, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(expr.left().to_owned().try_into()?)), - r: Some(Box::new(expr.right().to_owned().try_into()?)), - op: format!("{:?}", expr.op()), - }); +/// Serialize a `PhysicalExpr` to default protobuf representation. +/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]) +pub fn serialize_physical_expr( + value: Arc, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let expr = value.as_any(); + + if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: expr.name().to_string(), + index: expr.index() as u32, + }, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { + l: Some(Box::new(serialize_physical_expr( + expr.left().clone(), + codec, + )?)), + r: Some(Box::new(serialize_physical_expr( + expr.right().clone(), + codec, + )?)), + op: format!("{:?}", expr.op()), + }); - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( - binary_expr, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::Case( - Box::new( - protobuf::PhysicalCaseNode { - expr: expr - .expr() - .map(|exp| exp.clone().try_into().map(Box::new)) - .transpose()?, - when_then_expr: expr - .when_then_expr() - .iter() - .map(|(when_expr, then_expr)| { - try_parse_when_then_expr(when_expr, then_expr) - }) - .collect::, - Self::Error, - >>()?, - else_expr: expr - .else_expr() - .map(|a| a.clone().try_into().map(Box::new)) - .transpose()?, - }, - ), + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( + binary_expr, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::Case( + Box::new( + protobuf::PhysicalCaseNode { + expr: expr + .expr() + .map(|exp| { + serialize_physical_expr(exp.clone(), codec) + .map(Box::new) + }) + .transpose()?, + when_then_expr: expr + .when_then_expr() + .iter() + .map(|(when_expr, then_expr)| { + try_parse_when_then_expr(when_expr, then_expr, codec) + }) + .collect::, + DataFusionError, + >>()?, + else_expr: expr + .else_expr() + .map(|a| { + serialize_physical_expr(a.clone(), codec) + .map(Box::new) + }) + .transpose()?, + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr( - Box::new(protobuf::PhysicalNot { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - 
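// A minimal, self-contained sketch of the pattern introduced above: a
// `TryFrom` impl cannot carry extra state, so serialization becomes a free
// function that threads an extension codec through every recursive call.
// `Expr`, `Node`, and `Codec` here are hypothetical stand-ins, not
// DataFusion's real types.
trait Codec {
    fn try_encode_leaf(&self, name: &str, buf: &mut Vec<u8>) -> Result<(), String>;
}

enum Expr {
    Column(String),
    Binary { l: Box<Expr>, r: Box<Expr> },
    Udf(String),
}

enum Node {
    Column(String),
    Binary { l: Box<Node>, r: Box<Node> },
    Udf { name: String, payload: Option<Vec<u8>> },
}

fn serialize(expr: &Expr, codec: &dyn Codec) -> Result<Node, String> {
    match expr {
        Expr::Column(c) => Ok(Node::Column(c.clone())),
        // The codec is passed down so a UDF nested anywhere in the tree
        // (e.g. inside a binary expression) can still reach the encoder.
        Expr::Binary { l, r } => Ok(Node::Binary {
            l: Box::new(serialize(l, codec)?),
            r: Box::new(serialize(r, codec)?),
        }),
        Expr::Udf(name) => {
            let mut buf = Vec::new();
            codec.try_encode_leaf(name, &mut buf)?;
            Ok(Node::Udf {
                name: name.clone(),
                payload: if buf.is_empty() { None } else { Some(buf) },
            })
        }
    }
}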
Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( - Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::InList( - Box::new( - protobuf::PhysicalInListNode { - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - list: expr - .list() - .iter() - .map(|a| a.clone().try_into()) - .collect::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( + protobuf::PhysicalNot { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( + Box::new(protobuf::PhysicalIsNotNull { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::InList( + Box::new( + protobuf::PhysicalInListNode { + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + list: expr + .list() + .iter() + .map(|a| serialize_physical_expr(a.clone(), codec)) + .collect::, - Self::Error, + DataFusionError, >>()?, - negated: expr.negated(), - }, - ), + negated: expr.negated(), + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Negative( - Box::new(protobuf::PhysicalNegativeNode { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(lit) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( - lit.value().try_into()?, - )), - }) - } else if let Some(cast) = expr.downcast_ref::() { + ), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( + protobuf::PhysicalNegativeNode { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(lit) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( + lit.value().try_into()?, + )), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( + protobuf::PhysicalCastNode { + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), + arrow_type: 
Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let args: Vec = expr + .args() + .iter() + .map(|e| serialize_physical_expr(e.to_owned(), codec)) + .collect::, _>>()?; + if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { + let fun: protobuf::ScalarFunction = (&fun).try_into()?; + Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( - protobuf::PhysicalCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction( + protobuf::PhysicalScalarFunctionNode { + name: expr.name().to_string(), + fun: fun.into(), + args, + return_type: Some(expr.return_type().try_into()?), }, - ))), - }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast( - Box::new(protobuf::PhysicalTryCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), - }), )), }) - } else if let Some(expr) = expr.downcast_ref::() { - let args: Vec = expr - .args() - .iter() - .map(|e| e.to_owned().try_into()) - .collect::, _>>()?; - if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { - let fun: protobuf::ScalarFunction = (&fun).try_into()?; - - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::ScalarFunction( - protobuf::PhysicalScalarFunctionNode { - name: expr.name().to_string(), - fun: fun.into(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - ), - ), - }) - } else { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( - protobuf::PhysicalScalarUdfNode { - name: expr.name().to_string(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - )), - }) + } else { + let mut buf = Vec::new(); + match expr.fun() { + ScalarFunctionDefinition::UDF(udf) => { + codec.try_encode_udf(udf, &mut buf)?; + } + _ => { + return not_impl_err!( + "Proto serialization error: Trying to serialize a unresolved function" + ); + } } - } else if let Some(expr) = expr.downcast_ref::() { + + let fun_definition = if buf.is_empty() { None } else { Some(buf) }; Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr( - Box::new(protobuf::PhysicalLikeExprNode { - negated: expr.negated(), - case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - pattern: Some(Box::new(expr.pattern().to_owned().try_into()?)), - }), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( + protobuf::PhysicalScalarUdfNode { + name: expr.name().to_string(), + args, + fun_definition, + return_type: Some(expr.return_type().try_into()?), + }, )), }) - } else { - internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: 
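// The branch above resolves a function name back to a builtin before falling
// through to the UDF path. A tiny stand-alone model of that dispatch order,
// where `Builtin` is a hypothetical stand-in for `BuiltinScalarFunction`:
use std::str::FromStr;

#[derive(Debug)]
enum Builtin {
    Sin,
    Cos,
}

impl FromStr for Builtin {
    type Err = ();
    fn from_str(s: &str) -> Result<Self, ()> {
        match s {
            "sin" => Ok(Builtin::Sin),
            "cos" => Ok(Builtin::Cos),
            _ => Err(()),
        }
    }
}

fn classify(name: &str) -> String {
    if let Ok(f) = Builtin::from_str(name) {
        // Builtins serialize as an enum tag; no payload is needed.
        format!("builtin: {f:?}")
    } else {
        // Unknown names must be UDFs and carry their own encoded definition.
        format!("udf: {name}")
    }
}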
Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( + protobuf::PhysicalLikeExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + pattern: Some(Box::new(serialize_physical_expr( + expr.pattern().to_owned(), + codec, + )?)), + }, + ))), + }) + } else { + internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } } fn try_parse_when_then_expr( when_expr: &Arc, then_expr: &Arc, + codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(when_expr.clone().try_into()?), - then_expr: Some(then_expr.clone().try_into()?), + when_expr: Some(serialize_physical_expr(when_expr.clone(), codec)?), + then_expr: Some(serialize_physical_expr(then_expr.clone(), codec)?), }) } @@ -683,6 +745,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { fn try_from( conf: &FileScanConfig, ) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let file_groups = conf .file_groups .iter() @@ -694,7 +757,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { let expr_node_vec = order .iter() .map(|sort_expr| { - let expr = sort_expr.expr.clone().try_into()?; + let expr = serialize_physical_expr(sort_expr.expr.clone(), &codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !sort_expr.options.descending, @@ -757,10 +820,11 @@ impl TryFrom>> for protobuf::MaybeFilter { type Error = DataFusionError; fn try_from(expr: Option>) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(expr.try_into()?), + expr: Some(serialize_physical_expr(expr, &codec)?), }), } } @@ -786,8 +850,9 @@ impl TryFrom for protobuf::PhysicalSortExprNode { type Error = DataFusionError; fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result { + let codec = DefaultPhysicalExtensionCodec {}; Ok(PhysicalSortExprNode { - expr: Some(Box::new(sort_expr.expr.try_into()?)), + expr: Some(Box::new(serialize_physical_expr(sort_expr.expr, &codec)?)), asc: !sort_expr.options.descending, nulls_first: sort_expr.options.nulls_first, }) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7f0c6286a19d..4924128ae190 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
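// The tests in this file all follow one shape: encode, decode, and compare
// Debug renderings. A generic sketch of that harness; the `encode`/`decode`
// closures stand in for the generated proto conversions and are assumptions,
// not real DataFusion APIs:
fn assert_roundtrip<T: std::fmt::Debug>(
    value: T,
    encode: impl Fn(&T) -> Vec<u8>,
    decode: impl Fn(&[u8]) -> T,
) {
    let before = format!("{value:?}");
    let bytes = encode(&value);
    let after = format!("{:?}", decode(&bytes));
    // As the doc comment below notes, comparing string renderings is a
    // pragmatic proxy and does not prove that every detail survived.
    assert_eq!(before, after);
}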
+use std::any::Any; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -32,7 +33,7 @@ use datafusion::datasource::physical_plan::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, FileSinkConfig, ParquetExec, }; -use datafusion::execution::context::ExecutionProps; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; @@ -49,7 +50,6 @@ use datafusion::physical_plan::expressions::{ NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum, }; use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::functions; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, @@ -73,13 +73,19 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature, - SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, + ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::physical_plan::{ + AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; -use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +use prost::Message; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. 
Note that this often isn't sufficient to guarantee that no information is @@ -603,14 +609,11 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); - let execution_props = ExecutionProps::new(); - - let fun_expr = - functions::create_physical_fun(&BuiltinScalarFunction::Sin, &execution_props)?; + let fun_def = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Sin); let expr = ScalarFunctionExpr::new( "sin", - fun_expr, + fun_def, vec![col("a", &schema)?], DataType::Float64, None, @@ -646,9 +649,11 @@ fn roundtrip_scalar_udf() -> Result<()> { scalar_fn.clone(), ); + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone())); + let expr = ScalarFunctionExpr::new( "dummy", - scalar_fn, + fun_def, vec![col("a", &schema)?], DataType::Int64, None, @@ -665,6 +670,134 @@ fn roundtrip_scalar_udf() -> Result<()> { roundtrip_test_with_context(Arc::new(project), ctx) } +#[test] +fn roundtrip_scalar_udf_extension_codec() { + #[derive(Debug)] + struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, + } + + impl MyRegexUdf { + fn new(pattern: String) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + pattern, + } + } + } + + /// Implement the ScalarUDFImpl trait for MyRegexUdf + impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.first(), Some(&DataType::Utf8)) { + return plan_err!("regex_udf only accepts Utf8 arguments"); + } + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, + } + + #[derive(Debug)] + pub struct ScalarUDFExtensionCodec {} + + impl PhysicalExtensionCodec for ScalarUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + not_impl_err!("No extension codec provided") + } + + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode regex_udf: {}", + err + )) + })?; + + Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( + proto.pattern, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") + } + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + } + Ok(()) + } + } + + let pattern = ".*"; + let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); + let test_expr = ScalarFunctionExpr::new( + udf.name(), + ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), + vec![], + DataType::Int32, + None, + false, + ); + let fmt_expr = format!("{test_expr:?}"); + let ctx = SessionContext::new(); + + ctx.register_udf(udf.clone()); + let extension_codec = 
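// The codec above moves the UDF's state through a prost message. A minimal
// sketch of that encode/decode pair in isolation, assuming the `prost` crate
// with its derive feature; `RegexPayload` is a hypothetical message, not the
// one defined in this test:
use prost::Message;

#[derive(Clone, PartialEq, ::prost::Message)]
struct RegexPayload {
    #[prost(string, tag = "1")]
    pattern: String,
}

fn payload_roundtrip() -> Result<RegexPayload, prost::DecodeError> {
    let original = RegexPayload { pattern: ".*".to_string() };
    let mut buf = Vec::new();
    // Encoding into a Vec<u8> cannot run out of capacity.
    original.encode(&mut buf).expect("encoding into a Vec cannot fail");
    RegexPayload::decode(buf.as_slice())
}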
ScalarUDFExtensionCodec {}; + let proto: protobuf::PhysicalExprNode = + match serialize_physical_expr(Arc::new(test_expr), &extension_codec) { + Ok(proto) => proto, + Err(e) => panic!("failed to serialize expr: {e:?}"), + }; + let field_a = Field::new("a", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![field_a])); + let round_trip = + parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap(); + assert_eq!(fmt_expr, format!("{round_trip:?}")); +} #[test] fn roundtrip_distinct_count() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); From 8074ca1e758470319699a562074290906003b312 Mon Sep 17 00:00:00 2001 From: Brent Gardner Date: Tue, 19 Mar 2024 12:14:13 -0600 Subject: [PATCH 020/117] Support Union types in `ScalarValue` (#9683) Support Union types in `ScalarValue` (#9683) --- datafusion/common/src/error.rs | 4 +- datafusion/common/src/scalar/mod.rs | 82 ++++++ datafusion/physical-plan/src/filter.rs | 35 +++ datafusion/proto/proto/datafusion.proto | 15 + datafusion/proto/src/generated/pbjson.rs | 272 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 26 +- .../proto/src/logical_plan/from_proto.rs | 35 +++ datafusion/proto/src/logical_plan/to_proto.rs | 29 ++ datafusion/sql/src/unparser/expr.rs | 1 + 9 files changed, 496 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index 1ecd5b62bee8..d1e47b473499 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -63,7 +63,7 @@ pub enum DataFusionError { IoError(io::Error), /// Error when SQL is syntactically incorrect. /// - /// 2nd argument is for optional backtrace + /// 2nd argument is for optional backtrace SQL(ParserError, Option), /// Error when a feature is not yet implemented. /// @@ -101,7 +101,7 @@ pub enum DataFusionError { /// This error can be returned in cases such as when schema inference is not /// possible and when column names are not unique. /// - /// 2nd argument is for optional backtrace + /// 2nd argument is for optional backtrace /// Boxing the optional backtrace to prevent SchemaError(SchemaError, Box>), /// Error during execution of the query. diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index a2484e93e812..d33b8b6e142c 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -53,6 +53,8 @@ use arrow::{ }, }; use arrow_array::{ArrowNativeTypeOp, Scalar}; +use arrow_buffer::Buffer; +use arrow_schema::{UnionFields, UnionMode}; pub use struct_builder::ScalarStructBuilder; @@ -275,6 +277,11 @@ pub enum ScalarValue { DurationMicrosecond(Option), /// Duration in nanoseconds DurationNanosecond(Option), + /// A nested datatype that can represent slots of differing types. 
Components: + /// `.0`: a tuple of union `type_id` and the single value held by this Scalar + /// `.1`: the list of fields, zero-to-one of which will by set in `.0` + /// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came + Union(Option<(i8, Box)>, UnionFields, UnionMode), /// Dictionary type: index type and value Dictionary(Box, Box), } @@ -375,6 +382,10 @@ impl PartialEq for ScalarValue { (IntervalDayTime(_), _) => false, (IntervalMonthDayNano(v1), IntervalMonthDayNano(v2)) => v1.eq(v2), (IntervalMonthDayNano(_), _) => false, + (Union(val1, fields1, mode1), Union(val2, fields2, mode2)) => { + val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2) + } + (Union(_, _, _), _) => false, (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), (Dictionary(_, _), _) => false, (Null, Null) => true, @@ -500,6 +511,14 @@ impl PartialOrd for ScalarValue { (DurationMicrosecond(_), _) => None, (DurationNanosecond(v1), DurationNanosecond(v2)) => v1.partial_cmp(v2), (DurationNanosecond(_), _) => None, + (Union(v1, t1, m1), Union(v2, t2, m2)) => { + if t1.eq(t2) && m1.eq(m2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Union(_, _, _), _) => None, (Dictionary(k1, v1), Dictionary(k2, v2)) => { // Don't compare if the key types don't match (it is effectively a different datatype) if k1 == k2 { @@ -663,6 +682,11 @@ impl std::hash::Hash for ScalarValue { IntervalYearMonth(v) => v.hash(state), IntervalDayTime(v) => v.hash(state), IntervalMonthDayNano(v) => v.hash(state), + Union(v, t, m) => { + v.hash(state); + t.hash(state); + m.hash(state); + } Dictionary(k, v) => { k.hash(state); v.hash(state); @@ -1093,6 +1117,7 @@ impl ScalarValue { ScalarValue::DurationNanosecond(_) => { DataType::Duration(TimeUnit::Nanosecond) } + ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), ScalarValue::Dictionary(k, v) => { DataType::Dictionary(k.clone(), Box::new(v.data_type())) } @@ -1292,6 +1317,7 @@ impl ScalarValue { ScalarValue::DurationMillisecond(v) => v.is_none(), ScalarValue::DurationMicrosecond(v) => v.is_none(), ScalarValue::DurationNanosecond(v) => v.is_none(), + ScalarValue::Union(v, _, _) => v.is_none(), ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -2087,6 +2113,39 @@ impl ScalarValue { e, size ), + ScalarValue::Union(value, fields, _mode) => match value { + Some((v_id, value)) => { + let mut field_type_ids = Vec::::with_capacity(fields.len()); + let mut child_arrays = + Vec::<(Field, ArrayRef)>::with_capacity(fields.len()); + for (f_id, field) in fields.iter() { + let ar = if f_id == *v_id { + value.to_array_of_size(size)? + } else { + let dt = field.data_type(); + new_null_array(dt, size) + }; + let field = (**field).clone(); + child_arrays.push((field, ar)); + field_type_ids.push(f_id); + } + let type_ids = repeat(*v_id).take(size).collect::>(); + let type_ids = Buffer::from_slice_ref(type_ids); + let value_offsets: Option = None; + let ar = UnionArray::try_new( + field_type_ids.as_slice(), + type_ids, + value_offsets, + child_arrays, + ) + .map_err(|e| DataFusionError::ArrowError(e, None))?; + Arc::new(ar) + } + None => { + let dt = self.data_type(); + new_null_array(&dt, size) + } + }, ScalarValue::Dictionary(key_type, v) => { // values array is one element long (the value) match key_type.as_ref() { @@ -2626,6 +2685,9 @@ impl ScalarValue { ScalarValue::DurationNanosecond(val) => { eq_array_primitive!(array, index, DurationNanosecondArray, val)? 
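// A usage sketch for the variant added above, relying only on behavior
// visible in this diff (datafusion_common's ScalarValue plus arrow_schema's
// union types); treat it as illustrative rather than a documented API:
use arrow_schema::{DataType, Field, UnionFields, UnionMode};
use datafusion_common::ScalarValue;

fn union_scalar_example() {
    let fields = UnionFields::new(
        vec![0, 1],
        vec![
            Field::new("f1", DataType::Int32, true),
            Field::new("f2", DataType::Utf8, true),
        ],
    );
    // A populated slot: type id 0 holds an Int32.
    let v = ScalarValue::Union(
        Some((0, Box::new(ScalarValue::Int32(Some(7))))),
        fields.clone(),
        UnionMode::Sparse,
    );
    assert!(!v.is_null());
    assert_eq!(v.data_type(), DataType::Union(fields.clone(), UnionMode::Sparse));
    // The null case keeps the fields and mode but carries no (type id, value).
    assert!(ScalarValue::Union(None, fields, UnionMode::Sparse).is_null());
}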
} + ScalarValue::Union(_, _, _) => { + return _not_impl_err!("Union is not supported yet") + } ScalarValue::Dictionary(key_type, v) => { let (values_array, values_index) = match key_type.as_ref() { DataType::Int8 => get_dict_value::(array, index)?, @@ -2703,6 +2765,15 @@ impl ScalarValue { ScalarValue::LargeList(arr) => arr.get_array_memory_size(), ScalarValue::FixedSizeList(arr) => arr.get_array_memory_size(), ScalarValue::Struct(arr) => arr.get_array_memory_size(), + ScalarValue::Union(vals, fields, _mode) => { + vals.as_ref() + .map(|(_id, sv)| sv.size() - std::mem::size_of_val(sv)) + .unwrap_or_default() + // `fields` is boxed, so it is NOT already included in `self` + + std::mem::size_of_val(fields) + + (std::mem::size_of::() * fields.len()) + + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() + } ScalarValue::Dictionary(dt, sv) => { // `dt` and `sv` are boxed, so they are NOT already included in `self` dt.size() + sv.size() @@ -3048,6 +3119,9 @@ impl TryFrom<&DataType> for ScalarValue { .to_owned() .into(), ), + DataType::Union(fields, mode) => { + ScalarValue::Union(None, fields.clone(), *mode) + } DataType::Null => ScalarValue::Null, _ => { return _not_impl_err!( @@ -3164,6 +3238,10 @@ impl fmt::Display for ScalarValue { .join(",") )? } + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "{}:{}", id, val)?, + None => write!(f, "NULL")?, + }, ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; @@ -3279,6 +3357,10 @@ impl fmt::Debug for ScalarValue { ScalarValue::DurationNanosecond(_) => { write!(f, "DurationNanosecond(\"{self}\")") } + ScalarValue::Union(val, _fields, _mode) => match val { + Some((id, val)) => write!(f, "Union {}:{}", id, val), + None => write!(f, "Union(NULL)"), + }, ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 72f885a93962..f44ade7106df 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -441,7 +441,9 @@ mod tests { use crate::test::exec::StatisticsExec; use crate::ExecutionPlan; + use crate::empty::EmptyExec; use arrow::datatypes::{DataType, Field, Schema}; + use arrow_schema::{UnionFields, UnionMode}; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_expr::Operator; @@ -1090,4 +1092,37 @@ mod tests { assert_eq!(statistics.total_byte_size, Precision::Inexact(1600)); Ok(()) } + + #[test] + fn test_equivalence_properties_union_type() -> Result<()> { + let union_type = DataType::Union( + UnionFields::new( + vec![0, 1], + vec![ + Field::new("f1", DataType::Int32, true), + Field::new("f2", DataType::Utf8, true), + ], + ), + UnionMode::Sparse, + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", union_type, true), + ])); + + let exec = FilterExec::try_new( + binary( + binary(col("c1", &schema)?, Operator::GtEq, lit(1i32), &schema)?, + Operator::And, + binary(col("c1", &schema)?, Operator::LtEq, lit(4i32), &schema)?, + &schema, + )?, + Arc::new(EmptyExec::new(schema.clone())), + )?; + + exec.statistics().unwrap(); + + Ok(()) + } } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 7a9b427ce7d3..10f79a2b8cc8 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -988,6 
+988,20 @@ message IntervalMonthDayNanoValue { int64 nanos = 3; } +message UnionField { + int32 field_id = 1; + Field field = 2; +} + +message UnionValue { + // Note that a null union value must have one or more fields, so we + // encode a null UnionValue as one with value_id == 128 + int32 value_id = 1; + ScalarValue value = 2; + repeated UnionField fields = 3; + UnionMode mode = 4; +} + message ScalarFixedSizeBinary{ bytes values = 1; int32 length = 2; @@ -1042,6 +1056,7 @@ message ScalarValue{ ScalarTime64Value time64_value = 30; IntervalMonthDayNanoValue interval_month_day_nano = 31; ScalarFixedSizeBinary fixed_size_binary_value = 34; + UnionValue union_value = 42; } } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index fd27520b3be0..7757a64ef359 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -24053,6 +24053,9 @@ impl serde::Serialize for ScalarValue { scalar_value::Value::FixedSizeBinaryValue(v) => { struct_ser.serialize_field("fixedSizeBinaryValue", v)?; } + scalar_value::Value::UnionValue(v) => { + struct_ser.serialize_field("unionValue", v)?; + } } } struct_ser.end() @@ -24137,6 +24140,8 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "intervalMonthDayNano", "fixed_size_binary_value", "fixedSizeBinaryValue", + "union_value", + "unionValue", ]; #[allow(clippy::enum_variant_names)] @@ -24177,6 +24182,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { Time64Value, IntervalMonthDayNano, FixedSizeBinaryValue, + UnionValue, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -24234,6 +24240,7 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { "time64Value" | "time64_value" => Ok(GeneratedField::Time64Value), "intervalMonthDayNano" | "interval_month_day_nano" => Ok(GeneratedField::IntervalMonthDayNano), "fixedSizeBinaryValue" | "fixed_size_binary_value" => Ok(GeneratedField::FixedSizeBinaryValue), + "unionValue" | "union_value" => Ok(GeneratedField::UnionValue), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -24483,6 +24490,13 @@ impl<'de> serde::Deserialize<'de> for ScalarValue { return Err(serde::de::Error::duplicate_field("fixedSizeBinaryValue")); } value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::FixedSizeBinaryValue) +; + } + GeneratedField::UnionValue => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("unionValue")); + } + value__ = map_.next_value::<::std::option::Option<_>>()?.map(scalar_value::Value::UnionValue) ; } } @@ -26942,6 +26956,117 @@ impl<'de> serde::Deserialize<'de> for UnionExecNode { deserializer.deserialize_struct("datafusion.UnionExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionField { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.field_id != 0 { + len += 1; + } + if self.field.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionField", len)?; + if self.field_id != 0 { + struct_ser.serialize_field("fieldId", &self.field_id)?; + } + if let Some(v) = self.field.as_ref() { + struct_ser.serialize_field("field", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionField { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: 
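// Why the .proto comment above can reserve 128 as a "null" marker: Arrow
// union type ids are i8, so every real id fits in i8's range, while the wire
// value is an i32 where 128 can never collide with a real id. A hypothetical
// decoder-side check built on that fact (not the generated code's logic):
fn decode_value_id(value_id: i32) -> Option<i8> {
    // Anything outside i8 (such as the 128 sentinel) is treated as "null".
    i8::try_from(value_id).ok()
}

fn sentinel_examples() {
    assert_eq!(decode_value_id(1), Some(1));
    assert_eq!(decode_value_id(128), None);
}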
serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "field_id", + "fieldId", + "field", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + FieldId, + Field, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "fieldId" | "field_id" => Ok(GeneratedField::FieldId), + "field" => Ok(GeneratedField::Field), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionField") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut field_id__ = None; + let mut field__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::FieldId => { + if field_id__.is_some() { + return Err(serde::de::Error::duplicate_field("fieldId")); + } + field_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Field => { + if field__.is_some() { + return Err(serde::de::Error::duplicate_field("field")); + } + field__ = map_.next_value()?; + } + } + } + Ok(UnionField { + field_id: field_id__.unwrap_or_default(), + field: field__, + }) + } + } + deserializer.deserialize_struct("datafusion.UnionField", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UnionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -27104,6 +27229,153 @@ impl<'de> serde::Deserialize<'de> for UnionNode { deserializer.deserialize_struct("datafusion.UnionNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for UnionValue { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.value_id != 0 { + len += 1; + } + if self.value.is_some() { + len += 1; + } + if !self.fields.is_empty() { + len += 1; + } + if self.mode != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.UnionValue", len)?; + if self.value_id != 0 { + struct_ser.serialize_field("valueId", &self.value_id)?; + } + if let Some(v) = self.value.as_ref() { + struct_ser.serialize_field("value", v)?; + } + if !self.fields.is_empty() { + struct_ser.serialize_field("fields", &self.fields)?; + } + if self.mode != 0 { + let v = UnionMode::try_from(self.mode) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.mode)))?; + struct_ser.serialize_field("mode", &v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for UnionValue { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "value_id", + "valueId", + "value", + "fields", + "mode", + ]; + + 
#[allow(clippy::enum_variant_names)] + enum GeneratedField { + ValueId, + Value, + Fields, + Mode, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "valueId" | "value_id" => Ok(GeneratedField::ValueId), + "value" => Ok(GeneratedField::Value), + "fields" => Ok(GeneratedField::Fields), + "mode" => Ok(GeneratedField::Mode), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = UnionValue; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.UnionValue") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut value_id__ = None; + let mut value__ = None; + let mut fields__ = None; + let mut mode__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::ValueId => { + if value_id__.is_some() { + return Err(serde::de::Error::duplicate_field("valueId")); + } + value_id__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map_.next_value()?; + } + GeneratedField::Fields => { + if fields__.is_some() { + return Err(serde::de::Error::duplicate_field("fields")); + } + fields__ = Some(map_.next_value()?); + } + GeneratedField::Mode => { + if mode__.is_some() { + return Err(serde::de::Error::duplicate_field("mode")); + } + mode__ = Some(map_.next_value::()? 
as i32); + } + } + } + Ok(UnionValue { + value_id: value_id__.unwrap_or_default(), + value: value__, + fields: fields__.unwrap_or_default(), + mode: mode__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.UnionValue", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for UniqueConstraint { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 16ad2b848db9..ab0ddb14ebfc 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1225,6 +1225,28 @@ pub struct IntervalMonthDayNanoValue { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionField { + #[prost(int32, tag = "1")] + pub field_id: i32, + #[prost(message, optional, tag = "2")] + pub field: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct UnionValue { + /// Note that a null union value must have one or more fields, so we + /// encode a null UnionValue as one with value_id == 128 + #[prost(int32, tag = "1")] + pub value_id: i32, + #[prost(message, optional, boxed, tag = "2")] + pub value: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, repeated, tag = "3")] + pub fields: ::prost::alloc::vec::Vec, + #[prost(enumeration = "UnionMode", tag = "4")] + pub mode: i32, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ScalarFixedSizeBinary { #[prost(bytes = "vec", tag = "1")] pub values: ::prost::alloc::vec::Vec, @@ -1236,7 +1258,7 @@ pub struct ScalarFixedSizeBinary { pub struct ScalarValue { #[prost( oneof = "scalar_value::Value", - tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34" + tags = "33, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 32, 20, 39, 21, 24, 25, 35, 36, 37, 38, 26, 27, 28, 29, 30, 31, 34, 42" )] pub value: ::core::option::Option, } @@ -1320,6 +1342,8 @@ pub mod scalar_value { IntervalMonthDayNano(super::IntervalMonthDayNanoValue), #[prost(message, tag = "34")] FixedSizeBinaryValue(super::ScalarFixedSizeBinary), + #[prost(message, tag = "42")] + UnionValue(::prost::alloc::boxed::Box), } } #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 06aab16edd57..8581156e2bb8 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -768,6 +768,41 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::IntervalMonthDayNano(v) => Self::IntervalMonthDayNano(Some( IntervalMonthDayNanoType::make_value(v.months, v.days, v.nanos), )), + Value::UnionValue(val) => { + let mode = match val.mode { + 0 => UnionMode::Sparse, + 1 => UnionMode::Dense, + id => Err(Error::unknown("UnionMode", id))?, + }; + let ids = val + .fields + .iter() + .map(|f| f.field_id as i8) + .collect::>(); + let fields = val + .fields + .iter() + .map(|f| f.field.clone()) + .collect::>>(); + let fields = fields.ok_or_else(|| Error::required("UnionField"))?; + let fields = fields + .iter() + .map(Field::try_from) + .collect::, _>>()?; + let fields = UnionFields::new(ids, fields); + let v_id = val.value_id as i8; + let val = match &val.value { + None => None, + 
Some(val) => { + let val: ScalarValue = val + .as_ref() + .try_into() + .map_err(|_| Error::General("Invalid Scalar".to_string()))?; + Some((v_id, Box::new(val))) + } + }; + Self::Union(val, fields, mode) + } Value::FixedSizeBinaryValue(v) => { Self::FixedSizeBinary(v.length, Some(v.clone().values)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 92015594906b..05a29ff6d42b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -30,6 +30,7 @@ use crate::protobuf::{ }, AnalyzedLogicalPlanType, CubeNode, EmptyMessage, GroupingSetNode, LogicalExprList, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, + UnionField, UnionValue, }; use arrow::{ @@ -1405,6 +1406,34 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { }; Ok(protobuf::ScalarValue { value: Some(value) }) } + + ScalarValue::Union(val, df_fields, mode) => { + let mut fields = Vec::::with_capacity(df_fields.len()); + for (id, field) in df_fields.iter() { + let field_id = id as i32; + let field = Some(field.as_ref().try_into()?); + let field = UnionField { field_id, field }; + fields.push(field); + } + let mode = match mode { + UnionMode::Sparse => 0, + UnionMode::Dense => 1, + }; + let value = match val { + None => None, + Some((_id, v)) => Some(Box::new(v.as_ref().try_into()?)), + }; + let val = UnionValue { + value_id: val.as_ref().map(|(id, _v)| *id as i32).unwrap_or(0), + value, + fields, + mode, + }; + let val = Value::UnionValue(Box::new(val)); + let val = protobuf::ScalarValue { value: Some(val) }; + Ok(val) + } + ScalarValue::Dictionary(index_type, val) => { let value: protobuf::ScalarValue = val.as_ref().try_into()?; Ok(protobuf::ScalarValue { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index c26e8481ce43..43f3e348dc32 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -456,6 +456,7 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Null)) } ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), + ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Dictionary(..) 
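// The union mode crosses the wire as a bare i32. Both directions of that
// mapping in isolation, mirroring the match arms above (Sparse = 0,
// Dense = 1, anything else rejected); a sketch, not the generated code:
use arrow_schema::UnionMode;

fn mode_to_i32(mode: UnionMode) -> i32 {
    match mode {
        UnionMode::Sparse => 0,
        UnionMode::Dense => 1,
    }
}

fn mode_from_i32(v: i32) -> Result<UnionMode, String> {
    match v {
        0 => Ok(UnionMode::Sparse),
        1 => Ok(UnionMode::Dense),
        other => Err(format!("unknown UnionMode: {other}")),
    }
}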
=> not_impl_err!("Unsupported scalar: {v:?}"), } } From ad8d552b9f150c3c066b0764e84f72b667a649ff Mon Sep 17 00:00:00 2001 From: Val Lorentz Date: Tue, 19 Mar 2024 22:09:20 +0100 Subject: [PATCH 021/117] parquet: Add support for row group pruning on FixedSizeBinary (#9646) * Add support for row group pruning on FixedSizeBinary * Check statistics values are valid for their type --- .../physical_plan/parquet/row_groups.rs | 1 + .../physical_plan/parquet/statistics.rs | 27 ++++- .../core/tests/parquet/row_group_pruning.rs | 101 ++++++++++++++++++ 3 files changed, 127 insertions(+), 2 deletions(-) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 9cd46994960f..a82c5d97a2b7 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -226,6 +226,7 @@ impl PruningStatistics for BloomFilterStatistics { match value { ScalarValue::Utf8(Some(v)) => sbbf.check(&v.as_str()), ScalarValue::Binary(Some(v)) => sbbf.check(v), + ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), ScalarValue::Boolean(Some(v)) => sbbf.check(v), ScalarValue::Float64(Some(v)) => sbbf.check(v), ScalarValue::Float32(Some(v)) => sbbf.check(v), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index 4e472606da51..aac5aff80f16 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -105,14 +105,20 @@ macro_rules! get_statistic { let s = std::str::from_utf8(s.$bytes_func()) .map(|s| s.to_string()) .ok(); + if s.is_none() { + log::debug!( + "Utf8 statistics is a non-UTF8 value, ignoring it." + ); + } Some(ScalarValue::Utf8(s)) } } } - // type not supported yet + // type not fully supported yet ParquetStatistics::FixedLenByteArray(s) => { match $target_arrow_type { - // just support the decimal data type + // just support specific logical data types, there are others each + // with their own ordering Some(DataType::Decimal128(precision, scale)) => { Some(ScalarValue::Decimal128( Some(from_bytes_to_i128(s.$bytes_func())), @@ -120,6 +126,23 @@ macro_rules! 
get_statistic { *scale, )) } + Some(DataType::FixedSizeBinary(size)) => { + let value = s.$bytes_func().to_vec(); + let value = if value.len().try_into() == Ok(*size) { + Some(value) + } else { + log::debug!( + "FixedSizeBinary({}) statistics is a binary of size {}, ignoring it.", + size, + value.len(), + ); + None + }; + Some(ScalarValue::FixedSizeBinary( + *size, + value, + )) + } _ => None, } } diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 55112193502d..ed48d040648c 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -948,6 +948,107 @@ async fn prune_binary_lt() { .await; } +#[tokio::test] +async fn prune_fixedsizebinary_eq_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('fe6' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('fe6' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // false positive on 'all frontends' batch: 'fe1' < 'fe6' < 'fe7' + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(1)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(1) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_eq_no_match() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize = ARROW_CAST(CAST('be9' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // false positive on 'mixed' batch: 'be1' < 'be9' < 'fe4' + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_neq() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize != ARROW_CAST(CAST('be1' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(3)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(14) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn prune_fixedsizebinary_lt() { + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + .with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize < ARROW_CAST(CAST('be3' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + // matches 'all backends' only + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(2)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::ByteArray) + 
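// The guard above silently ignores malformed statistics instead of failing
// the scan. The length check in isolation, where `size` is the declared
// FixedSizeBinary width and `bytes` the raw statistics value from parquet
// (a sketch of the idea, not the macro's exact code):
fn checked_fixed_size_stat(size: i32, bytes: &[u8]) -> Option<Vec<u8>> {
    if i32::try_from(bytes.len()).ok() == Some(size) {
        Some(bytes.to_vec())
    } else {
        // A value of the wrong width yields no statistic, so pruning falls
        // back to the safe answer: "this row group may match".
        None
    }
}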
.with_query( + "SELECT name, service_fixedsize FROM t WHERE service_fixedsize < ARROW_CAST(CAST('be9' AS bytea), 'FixedSizeBinary(3)')", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(2)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + // all backends from 'mixed' and 'all backends' + .with_expected_rows(8) + .test_row_group_prune() + .await; +} + #[tokio::test] async fn prune_periods_in_column_names() { // There are three row groups for "service.name", each with 5 rows = 15 rows total From 89efc4a7e06bd0295ca72dd6ec5fe987d1ac246b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Mar 2024 08:58:02 -0400 Subject: [PATCH 022/117] Minor: Add documentation about LogicalPlan::expressions (#9698) --- datafusion/expr/src/logical_plan/extension.rs | 9 +++++---- datafusion/expr/src/logical_plan/plan.rs | 14 +++++++++++--- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index f87ca45f14be..bb2c932ce391 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -53,10 +53,11 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { /// Return the output schema of this logical plan node. fn schema(&self) -> &DFSchemaRef; - /// Returns all expressions in the current logical plan node. This - /// should not include expressions of any inputs (aka - /// non-recursively). These expressions are used for optimizer - /// passes and rewrites. + /// Returns all expressions in the current logical plan node. This should + /// not include expressions of any inputs (aka non-recursively). + /// + /// These expressions are used for optimizer + /// passes and rewrites. See [`LogicalPlan::expressions`] for more details. fn expressions(&self) -> Vec; /// A list of output columns (e.g. the names of columns in diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 08fe3380061f..05d7ac539458 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -234,9 +234,17 @@ impl LogicalPlan { ]) } - /// returns all expressions (non-recursively) in the current - /// logical plan node. This does not include expressions in any - /// children + /// Returns all expressions (non-recursively) evaluated by the current + /// logical plan node. This does not include expressions in any children + /// + /// The returned expressions do not necessarily represent or even + /// contributed to the output schema of this node. For example, + /// `LogicalPlan::Filter` returns the filter expression even though the + /// output of a Filter has the same columns as the input. + /// + /// The expressions do contain all the columns that are used by this plan, + /// so if there are columns not referenced by these expressions then + /// DataFusion's optimizer attempts to optimize them away. 
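// A toy model of the contract described above: `expressions()` reports only
// the node's own expressions, never its children's, and a Filter surfaces
// its predicate even though it adds no output columns. `Plan` here is a
// deliberately simplified stand-in for LogicalPlan:
enum Plan {
    Scan { table: String },
    Filter { predicate: String, input: Box<Plan> },
}

impl Plan {
    fn expressions(&self) -> Vec<&str> {
        match self {
            // A scan evaluates nothing of its own.
            Plan::Scan { .. } => vec![],
            // Non-recursive: the input's expressions are not included.
            Plan::Filter { predicate, .. } => vec![predicate.as_str()],
        }
    }
}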
pub fn expressions(self: &LogicalPlan) -> Vec { let mut exprs = vec![]; self.inspect_expressions(|e| { From 1d0171ab9d33fc7896861dee85804d7daf0a6390 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 20 Mar 2024 08:24:33 -0700 Subject: [PATCH 023/117] Make builtin window function output datatype to be derived from schema (#9686) * Make builtin window function output datatype to be derived from schema --- datafusion/core/src/physical_planner.rs | 22 ++++----- .../core/tests/fuzz_cases/window_fuzz.rs | 39 +++++++++++++-- datafusion/physical-plan/src/windows/mod.rs | 47 ++++++++++--------- 3 files changed, 72 insertions(+), 36 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ee581ca64214..ca708b05823e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -742,13 +742,13 @@ impl DefaultPhysicalPlanner { ); } - let logical_input_schema = input.schema(); + let logical_schema = logical_plan.schema(); let window_expr = window_expr .iter() .map(|e| { create_window_expr( e, - logical_input_schema, + logical_schema, session_state.execution_props(), ) }) @@ -1578,11 +1578,11 @@ pub fn is_window_frame_bound_valid(window_frame: &WindowFrame) -> bool { pub fn create_window_expr_with_name( e: &Expr, name: impl Into, - logical_input_schema: &DFSchema, + logical_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { let name = name.into(); - let physical_input_schema: &Schema = &logical_input_schema.into(); + let physical_schema: &Schema = &logical_schema.into(); match e { Expr::WindowFunction(WindowFunction { fun, @@ -1594,17 +1594,15 @@ pub fn create_window_expr_with_name( }) => { let args = args .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) + .map(|e| create_physical_expr(e, logical_schema, execution_props)) .collect::>>()?; let partition_by = partition_by .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) + .map(|e| create_physical_expr(e, logical_schema, execution_props)) .collect::>>()?; let order_by = order_by .iter() - .map(|e| { - create_physical_sort_expr(e, logical_input_schema, execution_props) - }) + .map(|e| create_physical_sort_expr(e, logical_schema, execution_props)) .collect::>>()?; if !is_window_frame_bound_valid(window_frame) { @@ -1625,7 +1623,7 @@ pub fn create_window_expr_with_name( &partition_by, &order_by, window_frame, - physical_input_schema, + physical_schema, ignore_nulls, ) } @@ -1636,7 +1634,7 @@ pub fn create_window_expr_with_name( /// Create a window expression from a logical expression or an alias pub fn create_window_expr( e: &Expr, - logical_input_schema: &DFSchema, + logical_schema: &DFSchema, execution_props: &ExecutionProps, ) -> Result> { // unpack aliased logical expressions, e.g. "sum(col) over () as total" @@ -1644,7 +1642,7 @@ pub fn create_window_expr( Expr::Alias(Alias { expr, name, .. 
}) => (name.clone(), expr.as_ref()),
        _ => (e.display_name()?, e),
    };
-    create_window_expr_with_name(e, name, logical_input_schema, execution_props)
+    create_window_expr_with_name(e, name, logical_schema, execution_props)
 }
 
 type AggregateExprWithOptionalArgs = (
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 00c65995a5ff..2514324a9541 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -22,6 +22,7 @@ use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::SchemaRef;
 use arrow::record_batch::RecordBatch;
 use arrow::util::pretty::pretty_format_batches;
+use arrow_schema::{Field, Schema};
 use datafusion::physical_plan::memory::MemoryExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::windows::{
@@ -39,6 +40,7 @@ use datafusion_expr::{
 };
 use datafusion_physical_expr::expressions::{cast, col, lit};
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use itertools::Itertools;
 use test_utils::add_empty_batches;
 
 use hashbrown::HashMap;
@@ -273,6 +275,9 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
                 window_frame.is_causal()
             };
 
+            let extended_schema =
+                schema_add_window_fields(&args, &schema, &window_fn, fn_name)?;
+
             let window_expr = create_window_expr(
                 &window_fn,
                 fn_name.to_string(),
@@ -280,7 +285,7 @@
                 &partitionby_exprs,
                 &orderby_exprs,
                 Arc::new(window_frame),
-                schema.as_ref(),
+                &extended_schema,
                 false,
             )?;
             let running_window_exec = Arc::new(BoundedWindowAggExec::try_new(
@@ -678,6 +683,8 @@ async fn run_window_test(
         exec1 = Arc::new(SortExec::new(sort_keys, exec1)) as _;
     }
 
+    let extended_schema = schema_add_window_fields(&args, &schema, &window_fn, &fn_name)?;
+
     let usual_window_exec = Arc::new(WindowAggExec::try_new(
         vec![create_window_expr(
             &window_fn,
@@ -686,7 +693,7 @@
             &partitionby_exprs,
             &orderby_exprs,
             Arc::new(window_frame.clone()),
-            schema.as_ref(),
+            &extended_schema,
             false,
         )?],
         exec1,
@@ -704,7 +711,7 @@
             &partitionby_exprs,
             &orderby_exprs,
             Arc::new(window_frame.clone()),
-            schema.as_ref(),
+            &extended_schema,
             false,
         )?],
         exec2,
@@ -747,6 +754,32 @@ async fn run_window_test(
     Ok(())
 }
 
+// The planner has the fully updated schema before calling `create_window_expr`.
+// Replicate the same behavior for this test.
+fn schema_add_window_fields(
+    args: &[Arc<dyn PhysicalExpr>],
+    schema: &Arc<Schema>,
+    window_fn: &WindowFunctionDefinition,
+    fn_name: &str,
+) -> Result<Arc<Schema>> {
+    let data_types = args
+        .iter()
+        .map(|e| e.clone().as_ref().data_type(schema))
+        .collect::<Result<Vec<_>>>()?;
+    let window_expr_return_type = window_fn.return_type(&data_types)?;
+    let mut window_fields = schema
+        .fields()
+        .iter()
+        .map(|f| f.as_ref().clone())
+        .collect_vec();
+    window_fields.extend_from_slice(&[Field::new(
+        fn_name,
+        window_expr_return_type,
+        true,
+    )]);
+    Ok(Arc::new(Schema::new(window_fields)))
+}
+
 /// Return randomly sized record batches with:
 /// three sorted int32 columns 'a', 'b', 'c' ranged from 0..DISTINCT as columns
 /// one random int32 column x
diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs
index da2b24487d02..21f42f41fb5c 100644
--- a/datafusion/physical-plan/src/windows/mod.rs
+++ b/datafusion/physical-plan/src/windows/mod.rs
@@ -174,20 +174,15 @@ fn create_built_in_window_expr(
     name: String,
     ignore_nulls: bool,
 ) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
-    // need
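// What the helper above buys, shown with plain arrow types: after extension
// the schema contains one extra, nullable field named after the window
// expression, which downstream code can resolve by name. Illustrative only;
// the UInt64 output type for `row_number` is an assumption of this sketch:
use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema};

fn extend_for_window() {
    let base = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let mut fields: Vec<Field> =
        base.fields().iter().map(|f| f.as_ref().clone()).collect();
    fields.push(Field::new("row_number", DataType::UInt64, true));
    let extended = Arc::new(Schema::new(fields));

    assert_eq!(extended.fields().len(), base.fields().len() + 1);
    assert!(extended.index_of("row_number").is_ok());
}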
to get the types into an owned vec for some reason - let input_types: Vec<_> = args - .iter() - .map(|arg| arg.data_type(input_schema)) - .collect::>()?; + // derive the output datatype from incoming schema + let out_data_type: &DataType = input_schema.field_with_name(&name)?.data_type(); - // figure out the output type - let data_type = &fun.return_type(&input_types)?; Ok(match fun { - BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, data_type)), - BuiltInWindowFunction::Rank => Arc::new(rank(name, data_type)), - BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, data_type)), - BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, data_type)), - BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, data_type)), + BuiltInWindowFunction::RowNumber => Arc::new(RowNumber::new(name, out_data_type)), + BuiltInWindowFunction::Rank => Arc::new(rank(name, out_data_type)), + BuiltInWindowFunction::DenseRank => Arc::new(dense_rank(name, out_data_type)), + BuiltInWindowFunction::PercentRank => Arc::new(percent_rank(name, out_data_type)), + BuiltInWindowFunction::CumeDist => Arc::new(cume_dist(name, out_data_type)), BuiltInWindowFunction::Ntile => { let n = get_scalar_value_from_args(args, 0)?.ok_or_else(|| { DataFusionError::Execution( @@ -201,13 +196,13 @@ fn create_built_in_window_expr( if n.is_unsigned() { let n: u64 = n.try_into()?; - Arc::new(Ntile::new(name, n, data_type)) + Arc::new(Ntile::new(name, n, out_data_type)) } else { let n: i64 = n.try_into()?; if n <= 0 { return exec_err!("NTILE requires a positive integer"); } - Arc::new(Ntile::new(name, n as u64, data_type)) + Arc::new(Ntile::new(name, n as u64, out_data_type)) } } BuiltInWindowFunction::Lag => { @@ -216,10 +211,10 @@ fn create_built_in_window_expr( .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; Arc::new(lag( name, - data_type.clone(), + out_data_type.clone(), arg, shift_offset, default_value, @@ -232,10 +227,10 @@ fn create_built_in_window_expr( .map(|v| v.try_into()) .and_then(|v| v.ok()); let default_value = - get_casted_value(get_scalar_value_from_args(args, 2)?, data_type)?; + get_casted_value(get_scalar_value_from_args(args, 2)?, out_data_type)?; Arc::new(lead( name, - data_type.clone(), + out_data_type.clone(), arg, shift_offset, default_value, @@ -252,18 +247,28 @@ fn create_built_in_window_expr( Arc::new(NthValue::nth( name, arg, - data_type.clone(), + out_data_type.clone(), n, ignore_nulls, )?) 
} BuiltInWindowFunction::FirstValue => { let arg = args[0].clone(); - Arc::new(NthValue::first(name, arg, data_type.clone(), ignore_nulls)) + Arc::new(NthValue::first( + name, + arg, + out_data_type.clone(), + ignore_nulls, + )) } BuiltInWindowFunction::LastValue => { let arg = args[0].clone(); - Arc::new(NthValue::last(name, arg, data_type.clone(), ignore_nulls)) + Arc::new(NthValue::last( + name, + arg, + out_data_type.clone(), + ignore_nulls, + )) } }) } From 3bf06d3cc40657d38ab3425dca1945e4592d2d05 Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Wed, 20 Mar 2024 12:19:33 -0700 Subject: [PATCH 024/117] Issue-9660 - Extract array_to_string and string_to_array from kernels and udf containers (#9704) --- datafusion/functions-array/src/kernels.rs | 329 +-------------- datafusion/functions-array/src/lib.rs | 9 +- datafusion/functions-array/src/string.rs | 479 ++++++++++++++++++++++ datafusion/functions-array/src/udf.rs | 137 +------ datafusion/functions-array/src/utils.rs | 12 + 5 files changed, 502 insertions(+), 464 deletions(-) create mode 100644 datafusion/functions-array/src/string.rs diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index 15cdf8f279ae..ec0942837795 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/kernels.rs @@ -18,10 +18,8 @@ //! implementation kernels for array functions use arrow::array::{ - Array, ArrayRef, BooleanArray, Capacities, Date32Array, Float32Array, Float64Array, - GenericListArray, Int16Array, Int32Array, Int64Array, Int8Array, LargeListArray, - LargeStringArray, ListArray, ListBuilder, MutableArrayData, OffsetSizeTrait, - StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Array, ArrayRef, BooleanArray, Capacities, Date32Array, GenericListArray, Int64Array, + LargeListArray, ListArray, MutableArrayData, OffsetSizeTrait, UInt64Array, }; use arrow::compute; use arrow::datatypes::{ @@ -33,335 +31,18 @@ use arrow_schema::FieldRef; use arrow_schema::SortOptions; use datafusion_common::cast::{ - as_date32_array, as_generic_list_array, as_generic_string_array, as_int64_array, - as_interval_mdn_array, as_large_list_array, as_list_array, as_null_array, - as_string_array, + as_date32_array, as_generic_list_array, as_int64_array, as_interval_mdn_array, + as_large_list_array, as_list_array, as_null_array, as_string_array, }; use datafusion_common::{ exec_err, internal_datafusion_err, not_impl_datafusion_err, DataFusionError, Result, ScalarValue, }; +use crate::utils::downcast_arg; use std::any::type_name; use std::sync::Arc; -macro_rules! downcast_arg { - ($ARG:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast to {}", - type_name::<$ARRAY_TYPE>() - )) - })? - }}; -} - -macro_rules! to_string { - ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ - let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); - for x in arr { - match x { - Some(x) => { - $ARG.push_str(&x.to_string()); - $ARG.push_str($DELIMITER); - } - None => { - if $WITH_NULL_STRING { - $ARG.push_str($NULL_STRING); - $ARG.push_str($DELIMITER); - } - } - } - } - Ok($ARG) - }}; -} - -macro_rules! 
call_array_function { - ($DATATYPE:expr, false) => { - match $DATATYPE { - DataType::Utf8 => array_function!(StringArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), - } - }; - ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ - match $DATATYPE { - DataType::List(_) => array_function!(ListArray), - DataType::Utf8 => array_function!(StringArray), - DataType::LargeUtf8 => array_function!(LargeStringArray), - DataType::Boolean => array_function!(BooleanArray), - DataType::Float32 => array_function!(Float32Array), - DataType::Float64 => array_function!(Float64Array), - DataType::Int8 => array_function!(Int8Array), - DataType::Int16 => array_function!(Int16Array), - DataType::Int32 => array_function!(Int32Array), - DataType::Int64 => array_function!(Int64Array), - DataType::UInt8 => array_function!(UInt8Array), - DataType::UInt16 => array_function!(UInt16Array), - DataType::UInt32 => array_function!(UInt32Array), - DataType::UInt64 => array_function!(UInt64Array), - _ => unreachable!(), - } - }}; -} - -/// Array_to_string SQL function -pub(super) fn array_to_string(args: &[ArrayRef]) -> Result { - if args.len() < 2 || args.len() > 3 { - return exec_err!("array_to_string expects two or three arguments"); - } - - let arr = &args[0]; - - let delimiters = as_string_array(&args[1])?; - let delimiters: Vec> = delimiters.iter().collect(); - - let mut null_string = String::from(""); - let mut with_null_string = false; - if args.len() == 3 { - null_string = as_string_array(&args[2])?.value(0).to_string(); - with_null_string = true; - } - - fn compute_array_to_string( - arg: &mut String, - arr: ArrayRef, - delimiter: String, - null_string: String, - with_null_string: bool, - ) -> datafusion_common::Result<&mut String> { - match arr.data_type() { - DataType::List(..) => { - let list_array = as_list_array(&arr)?; - for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } - - Ok(arg) - } - DataType::LargeList(..) => { - let list_array = as_large_list_array(&arr)?; - for i in 0..list_array.len() { - compute_array_to_string( - arg, - list_array.value(i), - delimiter.clone(), - null_string.clone(), - with_null_string, - )?; - } - - Ok(arg) - } - DataType::Null => Ok(arg), - data_type => { - macro_rules! 
array_function { - ($ARRAY_TYPE:ident) => { - to_string!( - arg, - arr, - &delimiter, - &null_string, - with_null_string, - $ARRAY_TYPE - ) - }; - } - call_array_function!(data_type, false) - } - } - } - - fn generate_string_array( - list_arr: &GenericListArray, - delimiters: Vec>, - null_string: String, - with_null_string: bool, - ) -> datafusion_common::Result { - let mut res: Vec> = Vec::new(); - for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { - if let (Some(arr), Some(delimiter)) = (arr, delimiter) { - let mut arg = String::from(""); - let s = compute_array_to_string( - &mut arg, - arr, - delimiter.to_string(), - null_string.clone(), - with_null_string, - )? - .clone(); - - if let Some(s) = s.strip_suffix(delimiter) { - res.push(Some(s.to_string())); - } else { - res.push(Some(s)); - } - } else { - res.push(None); - } - } - - Ok(StringArray::from(res)) - } - - let arr_type = arr.data_type(); - let string_arr = match arr_type { - DataType::List(_) | DataType::FixedSizeList(_, _) => { - let list_array = as_list_array(&arr)?; - generate_string_array::( - list_array, - delimiters, - null_string, - with_null_string, - )? - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&arr)?; - generate_string_array::( - list_array, - delimiters, - null_string, - with_null_string, - )? - } - _ => { - let mut arg = String::from(""); - let mut res: Vec> = Vec::new(); - // delimiter length is 1 - assert_eq!(delimiters.len(), 1); - let delimiter = delimiters[0].unwrap(); - let s = compute_array_to_string( - &mut arg, - arr.clone(), - delimiter.to_string(), - null_string, - with_null_string, - )? - .clone(); - - if !s.is_empty() { - let s = s.strip_suffix(delimiter).unwrap().to_string(); - res.push(Some(s)); - } else { - res.push(Some(s)); - } - StringArray::from(res) - } - }; - - Ok(Arc::new(string_arr)) -} - -/// Splits string at occurrences of delimiter and returns an array of parts -/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' -pub fn string_to_array(args: &[ArrayRef]) -> Result { - if args.len() < 2 || args.len() > 3 { - return exec_err!("string_to_array expects two or three arguments"); - } - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - - let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( - string_array.len(), - string_array.get_buffer_memory_size(), - )); - - match args.len() { - 2 => { - string_array.iter().zip(delimiter_array.iter()).for_each( - |(string, delimiter)| { - match (string, delimiter) { - (Some(string), Some("")) => { - list_builder.values().append_value(string); - list_builder.append(true); - } - (Some(string), Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - list_builder.values().append_value(s); - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - list_builder.values().append_value(c); - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value - } - }, - ); - } - - 3 => { - let null_value_array = as_generic_string_array::(&args[2])?; - string_array - .iter() - .zip(delimiter_array.iter()) - .zip(null_value_array.iter()) - .for_each(|((string, delimiter), null_value)| { - match (string, delimiter) { - (Some(string), Some("")) => { - if Some(string) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(string); - } - list_builder.append(true); - } - (Some(string), 
Some(delimiter)) => { - string.split(delimiter).for_each(|s| { - if Some(s) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(s); - } - }); - list_builder.append(true); - } - (Some(string), None) => { - string.chars().map(|c| c.to_string()).for_each(|c| { - if Some(c.as_str()) == null_value { - list_builder.values().append_null(); - } else { - list_builder.values().append_value(c); - } - }); - list_builder.append(true); - } - _ => list_builder.append(false), // null value - } - }); - } - _ => { - return exec_err!( - "Expect string_to_array function to take two or three parameters" - ) - } - } - - let list_array = list_builder.finish(); - Ok(Arc::new(list_array) as ArrayRef) -} - /// Generates an array of integers from start to stop with a given step. /// /// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index fb16acdef2bd..f8d85800b3e3 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -39,6 +39,7 @@ mod remove; mod replace; mod rewrite; mod set_ops; +mod string; mod udf; mod utils; @@ -73,6 +74,8 @@ pub mod expr_fn { pub use super::set_ops::array_distinct; pub use super::set_ops::array_intersect; pub use super::set_ops::array_union; + pub use super::string::array_to_string; + pub use super::string::string_to_array; pub use super::udf::array_dims; pub use super::udf::array_empty; pub use super::udf::array_length; @@ -81,19 +84,17 @@ pub mod expr_fn { pub use super::udf::array_resize; pub use super::udf::array_reverse; pub use super::udf::array_sort; - pub use super::udf::array_to_string; pub use super::udf::cardinality; pub use super::udf::flatten; pub use super::udf::gen_series; pub use super::udf::range; - pub use super::udf::string_to_array; } /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ - udf::array_to_string_udf(), - udf::string_to_array_udf(), + string::array_to_string_udf(), + string::string_to_array_udf(), udf::range_udf(), udf::gen_series_udf(), udf::array_dims_udf(), diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs new file mode 100644 index 000000000000..3140866f5ff6 --- /dev/null +++ b/datafusion/functions-array/src/string.rs @@ -0,0 +1,479 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_to_string and string_to_array functions. 
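//!
//! Illustrative behavior (a sketch, not an exhaustive specification):
//!
//! * `array_to_string([1, 2, 3], ',')` returns `'1,2,3'`
//! * `string_to_array('abc~@~def~@~ghi', '~@~')` returns `['abc', 'def', 'ghi']`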
+ +use arrow::array::{ + Array, ArrayRef, BooleanArray, Float32Array, Float64Array, GenericListArray, + Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, ListBuilder, + OffsetSizeTrait, StringArray, StringBuilder, UInt16Array, UInt32Array, UInt64Array, + UInt8Array, +}; +use arrow::datatypes::{DataType, Field}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{Expr, TypeSignature}; + +use datafusion_common::{plan_err, DataFusionError, Result}; + +use std::any::{type_name, Any}; + +use crate::utils::{downcast_arg, make_scalar_function}; +use arrow_schema::DataType::{LargeUtf8, Utf8}; +use datafusion_common::cast::{ + as_generic_string_array, as_large_list_array, as_list_array, as_string_array, +}; +use datafusion_common::exec_err; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::sync::Arc; + +macro_rules! to_string { + ($ARG:expr, $ARRAY:expr, $DELIMITER:expr, $NULL_STRING:expr, $WITH_NULL_STRING:expr, $ARRAY_TYPE:ident) => {{ + let arr = downcast_arg!($ARRAY, $ARRAY_TYPE); + for x in arr { + match x { + Some(x) => { + $ARG.push_str(&x.to_string()); + $ARG.push_str($DELIMITER); + } + None => { + if $WITH_NULL_STRING { + $ARG.push_str($NULL_STRING); + $ARG.push_str($DELIMITER); + } + } + } + } + Ok($ARG) + }}; +} + +macro_rules! call_array_function { + ($DATATYPE:expr, false) => { + match $DATATYPE { + DataType::Utf8 => array_function!(StringArray), + DataType::LargeUtf8 => array_function!(LargeStringArray), + DataType::Boolean => array_function!(BooleanArray), + DataType::Float32 => array_function!(Float32Array), + DataType::Float64 => array_function!(Float64Array), + DataType::Int8 => array_function!(Int8Array), + DataType::Int16 => array_function!(Int16Array), + DataType::Int32 => array_function!(Int32Array), + DataType::Int64 => array_function!(Int64Array), + DataType::UInt8 => array_function!(UInt8Array), + DataType::UInt16 => array_function!(UInt16Array), + DataType::UInt32 => array_function!(UInt32Array), + DataType::UInt64 => array_function!(UInt64Array), + _ => unreachable!(), + } + }; + ($DATATYPE:expr, $INCLUDE_LIST:expr) => {{ + match $DATATYPE { + DataType::List(_) => array_function!(ListArray), + DataType::Utf8 => array_function!(StringArray), + DataType::LargeUtf8 => array_function!(LargeStringArray), + DataType::Boolean => array_function!(BooleanArray), + DataType::Float32 => array_function!(Float32Array), + DataType::Float64 => array_function!(Float64Array), + DataType::Int8 => array_function!(Int8Array), + DataType::Int16 => array_function!(Int16Array), + DataType::Int32 => array_function!(Int32Array), + DataType::Int64 => array_function!(Int64Array), + DataType::UInt8 => array_function!(UInt8Array), + DataType::UInt16 => array_function!(UInt16Array), + DataType::UInt32 => array_function!(UInt32Array), + DataType::UInt64 => array_function!(UInt64Array), + _ => unreachable!(), + } + }}; +} + +// Create static instances of ScalarUDFs for each function +make_udf_function!( + ArrayToString, + array_to_string, + array delimiter, // arg name + "converts each element to its text representation.", // doc + array_to_string_udf // internal function name +); +#[derive(Debug)] +pub(super) struct ArrayToString { + signature: Signature, + aliases: Vec, +} + +impl ArrayToString { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![ + String::from("array_to_string"), + String::from("list_to_string"), + String::from("array_join"), + String::from("list_join"), + ], 
+ } + } +} + +impl ScalarUDFImpl for ArrayToString { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_to_string" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, + _ => { + return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_to_string_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + StringToArray, + string_to_array, + string delimiter null_string, // arg name + "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`", // doc + string_to_array_udf // internal function name +); +#[derive(Debug)] +pub(super) struct StringToArray { + signature: Signature, + aliases: Vec, +} + +impl StringToArray { + pub fn new() -> Self { + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Uniform(2, vec![Utf8, LargeUtf8]), + TypeSignature::Uniform(3, vec![Utf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + aliases: vec![ + String::from("string_to_array"), + String::from("string_to_list"), + ], + } + } +} + +impl ScalarUDFImpl for StringToArray { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "string_to_array" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + Ok(match arg_types[0] { + Utf8 | LargeUtf8 => { + List(Arc::new(Field::new("item", arg_types[0].clone(), true))) + } + _ => { + return plan_err!( + "The string_to_array function can only accept Utf8 or LargeUtf8." + ); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + Utf8 => make_scalar_function(string_to_array_inner::)(args), + LargeUtf8 => make_scalar_function(string_to_array_inner::)(args), + other => { + exec_err!("unsupported type for string_to_array function as {other}") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_to_string SQL function +pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("array_to_string expects two or three arguments"); + } + + let arr = &args[0]; + + let delimiters = as_string_array(&args[1])?; + let delimiters: Vec> = delimiters.iter().collect(); + + let mut null_string = String::from(""); + let mut with_null_string = false; + if args.len() == 3 { + null_string = as_string_array(&args[2])?.value(0).to_string(); + with_null_string = true; + } + + fn compute_array_to_string( + arg: &mut String, + arr: ArrayRef, + delimiter: String, + null_string: String, + with_null_string: bool, + ) -> Result<&mut String> { + match arr.data_type() { + DataType::List(..) => { + let list_array = as_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + + Ok(arg) + } + DataType::LargeList(..) 
=> { + let list_array = as_large_list_array(&arr)?; + for i in 0..list_array.len() { + compute_array_to_string( + arg, + list_array.value(i), + delimiter.clone(), + null_string.clone(), + with_null_string, + )?; + } + + Ok(arg) + } + DataType::Null => Ok(arg), + data_type => { + macro_rules! array_function { + ($ARRAY_TYPE:ident) => { + to_string!( + arg, + arr, + &delimiter, + &null_string, + with_null_string, + $ARRAY_TYPE + ) + }; + } + call_array_function!(data_type, false) + } + } + } + + fn generate_string_array( + list_arr: &GenericListArray, + delimiters: Vec>, + null_string: String, + with_null_string: bool, + ) -> Result { + let mut res: Vec> = Vec::new(); + for (arr, &delimiter) in list_arr.iter().zip(delimiters.iter()) { + if let (Some(arr), Some(delimiter)) = (arr, delimiter) { + let mut arg = String::from(""); + let s = compute_array_to_string( + &mut arg, + arr, + delimiter.to_string(), + null_string.clone(), + with_null_string, + )? + .clone(); + + if let Some(s) = s.strip_suffix(delimiter) { + res.push(Some(s.to_string())); + } else { + res.push(Some(s)); + } + } else { + res.push(None); + } + } + + Ok(StringArray::from(res)) + } + + let arr_type = arr.data_type(); + let string_arr = match arr_type { + DataType::List(_) | DataType::FixedSizeList(_, _) => { + let list_array = as_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&arr)?; + generate_string_array::( + list_array, + delimiters, + null_string, + with_null_string, + )? + } + _ => { + let mut arg = String::from(""); + let mut res: Vec> = Vec::new(); + // delimiter length is 1 + assert_eq!(delimiters.len(), 1); + let delimiter = delimiters[0].unwrap(); + let s = compute_array_to_string( + &mut arg, + arr.clone(), + delimiter.to_string(), + null_string, + with_null_string, + )? 
+ .clone(); + + if !s.is_empty() { + let s = s.strip_suffix(delimiter).unwrap().to_string(); + res.push(Some(s)); + } else { + res.push(Some(s)); + } + StringArray::from(res) + } + }; + + Ok(Arc::new(string_arr)) +} + +/// String_to_array SQL function +/// Splits string at occurrences of delimiter and returns an array of parts +/// string_to_array('abc~@~def~@~ghi', '~@~') = '["abc", "def", "ghi"]' +pub fn string_to_array_inner(args: &[ArrayRef]) -> Result { + if args.len() < 2 || args.len() > 3 { + return exec_err!("string_to_array expects two or three arguments"); + } + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + + let mut list_builder = ListBuilder::new(StringBuilder::with_capacity( + string_array.len(), + string_array.get_buffer_memory_size(), + )); + + match args.len() { + 2 => { + string_array.iter().zip(delimiter_array.iter()).for_each( + |(string, delimiter)| { + match (string, delimiter) { + (Some(string), Some("")) => { + list_builder.values().append_value(string); + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + list_builder.values().append_value(s); + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + list_builder.values().append_value(c); + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }, + ); + } + + 3 => { + let null_value_array = as_generic_string_array::(&args[2])?; + string_array + .iter() + .zip(delimiter_array.iter()) + .zip(null_value_array.iter()) + .for_each(|((string, delimiter), null_value)| { + match (string, delimiter) { + (Some(string), Some("")) => { + if Some(string) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(string); + } + list_builder.append(true); + } + (Some(string), Some(delimiter)) => { + string.split(delimiter).for_each(|s| { + if Some(s) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(s); + } + }); + list_builder.append(true); + } + (Some(string), None) => { + string.chars().map(|c| c.to_string()).for_each(|c| { + if Some(c.as_str()) == null_value { + list_builder.values().append_null(); + } else { + list_builder.values().append_value(c); + } + }); + list_builder.append(true); + } + _ => list_builder.append(false), // null value + } + }); + } + _ => { + return exec_err!( + "Expect string_to_array function to take two or three parameters" + ) + } + } + + let list_array = list_builder.finish(); + Ok(Arc::new(list_array) as ArrayRef) +} diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index e0793900c6b3..5f5d90851758 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -17,11 +17,10 @@ //! [`ScalarUDFImpl`] definitions for array functions. 
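// Illustrative usage of the UDFs extracted to `string.rs` above (a sketch only;
// the `expr_fn` wrappers are generated by `make_udf_function!`, and `col`/`lit`
// come from `datafusion_expr`):
//
//   use datafusion_functions_array::expr_fn::{array_to_string, string_to_array};
//   let joined = array_to_string(col("tags"), lit(","));
//   let parts = string_to_array(lit("a,b,c"), lit(","), lit("NULL"));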
-use arrow::array::{NullArray, StringArray}; use arrow::datatypes::DataType; use arrow::datatypes::Field; use arrow::datatypes::IntervalUnit::MonthDayNano; -use arrow_schema::DataType::{LargeUtf8, List, Utf8}; +use arrow_schema::DataType::List; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::Result; @@ -32,140 +31,6 @@ use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -// Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayToString, - array_to_string, - array delimiter, // arg name - "converts each element to its text representation.", // doc - array_to_string_udf // internal function name -); -#[derive(Debug)] -pub(super) struct ArrayToString { - signature: Signature, - aliases: Vec, -} - -impl ArrayToString { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("array_to_string"), - String::from("list_to_string"), - String::from("array_join"), - String::from("list_join"), - ], - } - } -} - -impl ScalarUDFImpl for ArrayToString { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_to_string" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, - _ => { - return plan_err!("The array_to_string function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_to_string(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!(StringToArray, - string_to_array, - string delimiter null_string, // arg name - "splits a `string` based on a `delimiter` and returns an array of parts. Any parts matching the optional `null_string` will be replaced with `NULL`", // doc - string_to_array_udf // internal function name -); -#[derive(Debug)] -pub(super) struct StringToArray { - signature: Signature, - aliases: Vec, -} - -impl StringToArray { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![ - String::from("string_to_array"), - String::from("string_to_list"), - ], - } - } -} - -impl ScalarUDFImpl for StringToArray { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "string_to_array" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - Utf8 | LargeUtf8 => { - List(Arc::new(Field::new("item", arg_types[0].clone(), true))) - } - _ => { - return plan_err!( - "The string_to_array function can only accept Utf8 or LargeUtf8." - ); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let mut args = ColumnarValue::values_to_arrays(args)?; - // Case: delimiter is NULL, needs to be handled as well. 
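// (Illustration of the case noted above: a NULL delimiter splits the string into
// individual characters, e.g. string_to_array('abc', NULL) = ['a', 'b', 'c'].)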
- if args[1].as_any().is::() { - args[1] = Arc::new(StringArray::new_null(args[1].len())); - }; - - match args[0].data_type() { - Utf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - LargeUtf8 => { - crate::kernels::string_to_array::(&args).map(ColumnarValue::Array) - } - other => { - exec_err!("unsupported type for string_to_array function as {other}") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - make_udf_function!( Range, range, diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index 9589cb05fe9b..c0f7627d2ab7 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -214,6 +214,18 @@ pub(crate) fn compare_element_to_list( Ok(res) } +macro_rules! downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast to {}", + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} +pub(crate) use downcast_arg; + #[cfg(test)] mod tests { use super::*; From 55aacf62b39c7632df6536b2c1bf3856faf708ac Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Mar 2024 15:19:42 -0400 Subject: [PATCH 025/117] Document MSRV policy (#9681) --- README.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index abd727672aca..c3d7c6792990 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,11 @@ Optional features: [apache avro]: https://avro.apache.org/ [apache parquet]: https://parquet.apache.org/ -## Rust Version Compatibility +## Rust Version Compatibility Policy -Datafusion crate is tested with the [minimum required stable Rust version](https://github.com/search?q=repo%3Aapache%2Farrow-datafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) +DataFusion's Minimum Required Stable Rust Version (MSRV) policy is to support +each stable Rust version for 6 months after it is +[released](https://github.com/rust-lang/rust/blob/master/RELEASES.md). This +generally translates to support for the most recent 3 to 4 stable Rust versions. + +We enforce this policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Farrow-datafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) From 496e4b67a05bb49af8d7aa1ca5035312fd4e54f9 Mon Sep 17 00:00:00 2001 From: comphead Date: Wed, 20 Mar 2024 13:32:15 -0700 Subject: [PATCH 026/117] doc: Add DataFusion profiling documentation for MacOS (#9711) * Add profiling doc for MacOS --- docs/source/index.rst | 3 +- docs/source/library-user-guide/profiling.md | 63 +++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 docs/source/library-user-guide/profiling.md diff --git a/docs/source/index.rst b/docs/source/index.rst index f7c0873f3a5f..919a7ad7036f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -79,7 +79,7 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for .. toctree:: :maxdepth: 1 :caption: Library User Guide - + library-user-guide/index library-user-guide/using-the-sql-api library-user-guide/working-with-exprs @@ -89,6 +89,7 @@ Please see the `developer’s guide`_ for contributing and `communication`_ for library-user-guide/adding-udfs library-user-guide/custom-table-providers library-user-guide/extending-operators + library-user-guide/profiling .. 
_toc.contributor-guide: diff --git a/docs/source/library-user-guide/profiling.md b/docs/source/library-user-guide/profiling.md new file mode 100644 index 000000000000..a20489496f0c --- /dev/null +++ b/docs/source/library-user-guide/profiling.md @@ -0,0 +1,63 @@ + + +# Profiling Cookbook + +The section contains examples how to perform CPU profiling for Apache Arrow DataFusion on different operating systems. + +## MacOS + +### Building a flamegraph + +- [cargo-flamegraph](https://github.com/flamegraph-rs/flamegraph) + +Test: + +```bash +CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --unit-test datafusion -- dataframe::tests::test_array_agg +``` + +Benchmark: + +```bash +CARGO_PROFILE_RELEASE_DEBUG=true cargo flamegraph --root --bench sql_planner -- --bench +``` + +Open `flamegraph.svg` file with the browser + +- dtrace with DataFusion CLI + +```bash +git clone https://github.com/brendangregg/FlameGraph.git /tmp/fg +cd datafusion-cli +CARGO_PROFILE_RELEASE_DEBUG=true cargo build --release +echo "select * from table;" >> test.sql +sudo dtrace -c './target/debug/datafusion-cli -f test.sql' -o out.stacks -n 'profile-997 /execname == "datafusion-cli"/ { @[ustack(100)] = count(); }' +/tmp/fg/FlameGraph/stackcollapse.pl out.stacks | /tmp/fg/FlameGraph/flamegraph.pl > flamegraph.svg +``` + +Open `flamegraph.svg` file with the browser + +### CPU profiling with XCode Instruments + +[Video: how to CPU profile DataFusion with XCode Instruments](https://youtu.be/P3dXH61Kr5U) + +## Linux + +## Windows From e522bcebb04288fe7fe27192c51dabdf04e6ac88 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Mar 2024 17:13:43 -0400 Subject: [PATCH 027/117] Minor: add ticket reference to commented out test (#9715) --- datafusion/sqllogictest/test_files/copy.slt | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 4d4f596d0c60..7884bece1f39 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -111,6 +111,7 @@ a statement ok create table test ("'test'" varchar, "'test2'" varchar, "'test3'" varchar); +# https://github.com/apache/arrow-datafusion/issues/9714 ## Until the partition by parsing uses ColumnDef, this test is meaningless since it becomes an overfit. Even in ## CREATE EXTERNAL TABLE, there is a schema mismatch, this should be an issue. 
# From 7a0dd6ff5a78e10a96cb6ee7e1390b2a2df941b2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 20 Mar 2024 17:33:59 -0400 Subject: [PATCH 028/117] Minor: Change path from `common_runtime` to `common-runtime` (#9717) --- Cargo.toml | 4 ++-- datafusion/{common_runtime => common-runtime}/Cargo.toml | 0 datafusion/{common_runtime => common-runtime}/README.md | 0 datafusion/{common_runtime => common-runtime}/src/common.rs | 0 datafusion/{common_runtime => common-runtime}/src/lib.rs | 0 5 files changed, 2 insertions(+), 2 deletions(-) rename datafusion/{common_runtime => common-runtime}/Cargo.toml (100%) rename datafusion/{common_runtime => common-runtime}/README.md (100%) rename datafusion/{common_runtime => common-runtime}/src/common.rs (100%) rename datafusion/{common_runtime => common-runtime}/src/lib.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index d9e69e53db7c..abe6d2c1744b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ exclude = ["datafusion-cli"] members = [ "datafusion/common", - "datafusion/common_runtime", + "datafusion/common-runtime", "datafusion/core", "datafusion/expr", "datafusion/execution", @@ -73,7 +73,7 @@ ctor = "0.2.0" dashmap = "5.4.0" datafusion = { path = "datafusion/core", version = "36.0.0", default-features = false } datafusion-common = { path = "datafusion/common", version = "36.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common_runtime", version = "36.0.0" } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "36.0.0" } datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } datafusion-expr = { path = "datafusion/expr", version = "36.0.0" } datafusion-functions = { path = "datafusion/functions", version = "36.0.0" } diff --git a/datafusion/common_runtime/Cargo.toml b/datafusion/common-runtime/Cargo.toml similarity index 100% rename from datafusion/common_runtime/Cargo.toml rename to datafusion/common-runtime/Cargo.toml diff --git a/datafusion/common_runtime/README.md b/datafusion/common-runtime/README.md similarity index 100% rename from datafusion/common_runtime/README.md rename to datafusion/common-runtime/README.md diff --git a/datafusion/common_runtime/src/common.rs b/datafusion/common-runtime/src/common.rs similarity index 100% rename from datafusion/common_runtime/src/common.rs rename to datafusion/common-runtime/src/common.rs diff --git a/datafusion/common_runtime/src/lib.rs b/datafusion/common-runtime/src/lib.rs similarity index 100% rename from datafusion/common_runtime/src/lib.rs rename to datafusion/common-runtime/src/lib.rs From dbfb153658f17448af8e7de7bab0d37f73cdeac1 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 20 Mar 2024 16:01:25 -0600 Subject: [PATCH 029/117] Use object_store:BufWriter to replace put_multipart (#9648) * feat: use BufWriter to replace put_multipart * feat: remove AbortableWrite * fix clippy * fix: add doc comment --- Cargo.toml | 2 +- .../file_format/file_compression_type.rs | 7 +- .../src/datasource/file_format/parquet.rs | 19 ++-- .../src/datasource/file_format/write/mod.rs | 100 ++---------------- .../file_format/write/orchestration.rs | 18 +--- .../core/src/datasource/physical_plan/csv.rs | 10 +- .../core/src/datasource/physical_plan/json.rs | 10 +- .../datasource/physical_plan/parquet/mod.rs | 5 +- 8 files changed, 37 insertions(+), 134 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index abe6d2c1744b..c3dade8bc6c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -93,7 +93,7 @@ indexmap = "2.0.0" itertools = "0.12" 
log = "^0.4" num_cpus = "1.13.0" -object_store = { version = "0.9.0", default-features = false } +object_store = { version = "0.9.1", default-features = false } parking_lot = "0.12" parquet = { version = "51.0.0", default-features = false, features = ["arrow", "async", "object_store"] } rand = "0.8" diff --git a/datafusion/core/src/datasource/file_format/file_compression_type.rs b/datafusion/core/src/datasource/file_format/file_compression_type.rs index c538819e2684..c1fbe352d37b 100644 --- a/datafusion/core/src/datasource/file_format/file_compression_type.rs +++ b/datafusion/core/src/datasource/file_format/file_compression_type.rs @@ -43,6 +43,7 @@ use futures::stream::BoxStream; use futures::StreamExt; #[cfg(feature = "compression")] use futures::TryStreamExt; +use object_store::buffered::BufWriter; use tokio::io::AsyncWrite; #[cfg(feature = "compression")] use tokio_util::io::{ReaderStream, StreamReader}; @@ -148,11 +149,11 @@ impl FileCompressionType { }) } - /// Wrap the given `AsyncWrite` so that it performs compressed writes + /// Wrap the given `BufWriter` so that it performs compressed writes /// according to this `FileCompressionType`. pub fn convert_async_writer( &self, - w: Box, + w: BufWriter, ) -> Result> { Ok(match self.variant { #[cfg(feature = "compression")] @@ -169,7 +170,7 @@ impl FileCompressionType { "Compression feature is not enabled".to_owned(), )) } - UNCOMPRESSED => w, + UNCOMPRESSED => Box::new(w), }) } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index b7626d41f4dd..ec333bb557d2 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -23,7 +23,7 @@ use std::fmt::Debug; use std::sync::Arc; use super::write::demux::start_demuxer_task; -use super::write::{create_writer, AbortableWrite, SharedBuffer}; +use super::write::{create_writer, SharedBuffer}; use super::{FileFormat, FileScanConfig}; use crate::arrow::array::{ BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, RecordBatch, @@ -56,6 +56,7 @@ use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; +use object_store::buffered::BufWriter; use parquet::arrow::arrow_writer::{ compute_leaves, get_column_writers, ArrowColumnChunk, ArrowColumnWriter, ArrowLeafColumn, @@ -613,19 +614,13 @@ impl ParquetSink { location: &Path, object_store: Arc, parquet_props: WriterProperties, - ) -> Result< - AsyncArrowWriter>, - > { - let (_, multipart_writer) = object_store - .put_multipart(location) - .await - .map_err(DataFusionError::ObjectStore)?; + ) -> Result> { + let buf_writer = BufWriter::new(object_store, location.clone()); let writer = AsyncArrowWriter::try_new( - multipart_writer, + buf_writer, self.get_writer_schema(), Some(parquet_props), )?; - Ok(writer) } @@ -943,7 +938,7 @@ async fn concatenate_parallel_row_groups( mut serialize_rx: Receiver>, schema: Arc, writer_props: Arc, - mut object_store_writer: AbortableWrite>, + mut object_store_writer: Box, ) -> Result { let merged_buff = SharedBuffer::new(INITIAL_BUFFER_BYTES); @@ -985,7 +980,7 @@ async fn concatenate_parallel_row_groups( /// task then stitches these independent RowGroups together and streams this large /// single parquet file to an ObjectStore in multiple parts. 
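/// A rough sketch of the data flow (illustrative only):
///
/// ```text
/// column serialize tasks ==> serialize_rx ==> concatenate_parallel_row_groups
///                                                 ==> BufWriter ==> ObjectStore
/// ```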
async fn output_single_parquet_file_parallelized( - object_store_writer: AbortableWrite>, + object_store_writer: Box, data: Receiver, output_schema: Arc, parquet_props: &WriterProperties, diff --git a/datafusion/core/src/datasource/file_format/write/mod.rs b/datafusion/core/src/datasource/file_format/write/mod.rs index 410a32a19cc1..42115fc7b93f 100644 --- a/datafusion/core/src/datasource/file_format/write/mod.rs +++ b/datafusion/core/src/datasource/file_format/write/mod.rs @@ -18,21 +18,18 @@ //! Module containing helper methods/traits related to enabling //! write support for the various file formats -use std::io::{Error, Write}; -use std::pin::Pin; +use std::io::Write; use std::sync::Arc; -use std::task::{Context, Poll}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::error::Result; use arrow_array::RecordBatch; -use datafusion_common::DataFusionError; use bytes::Bytes; -use futures::future::BoxFuture; +use object_store::buffered::BufWriter; use object_store::path::Path; -use object_store::{MultipartId, ObjectStore}; +use object_store::ObjectStore; use tokio::io::AsyncWrite; pub(crate) mod demux; @@ -69,79 +66,6 @@ impl Write for SharedBuffer { } } -/// Stores data needed during abortion of MultiPart writers -#[derive(Clone)] -pub(crate) struct MultiPart { - /// A shared reference to the object store - store: Arc, - multipart_id: MultipartId, - location: Path, -} - -impl MultiPart { - /// Create a new `MultiPart` - pub fn new( - store: Arc, - multipart_id: MultipartId, - location: Path, - ) -> Self { - Self { - store, - multipart_id, - location, - } - } -} - -/// A wrapper struct with abort method and writer -pub(crate) struct AbortableWrite { - writer: W, - multipart: MultiPart, -} - -impl AbortableWrite { - /// Create a new `AbortableWrite` instance with the given writer, and write mode. - pub(crate) fn new(writer: W, multipart: MultiPart) -> Self { - Self { writer, multipart } - } - - /// handling of abort for different write modes - pub(crate) fn abort_writer(&self) -> Result>> { - let multi = self.multipart.clone(); - Ok(Box::pin(async move { - multi - .store - .abort_multipart(&multi.location, &multi.multipart_id) - .await - .map_err(DataFusionError::ObjectStore) - })) - } -} - -impl AsyncWrite for AbortableWrite { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_write(cx, buf) - } - - fn poll_flush( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_flush(cx) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.get_mut().writer).poll_shutdown(cx) - } -} - /// A trait that defines the methods required for a RecordBatch serializer. pub trait BatchSerializer: Sync + Send { /// Asynchronously serializes a `RecordBatch` and returns the serialized bytes. @@ -150,19 +74,15 @@ pub trait BatchSerializer: Sync + Send { fn serialize(&self, batch: RecordBatch, initial: bool) -> Result; } -/// Returns an [`AbortableWrite`] which writes to the given object store location -/// with the specified compression +/// Returns an [`AsyncWrite`] which writes to the given object store location +/// with the specified compression. +/// We drop the `AbortableWrite` struct and the writer will not try to cleanup on failure. +/// Users can configure automatic cleanup with their cloud provider. 
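/// Example (a sketch, assuming an in-memory store and `tokio::io::AsyncWriteExt`
/// in scope):
///
/// ```ignore
/// let store: Arc<dyn ObjectStore> = Arc::new(object_store::memory::InMemory::new());
/// let path = Path::from("out/data.csv.gz");
/// let mut writer = create_writer(FileCompressionType::GZIP, &path, store).await?;
/// writer.write_all(b"a,b\n1,2\n").await?;
/// writer.shutdown().await?;
/// ```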
pub(crate) async fn create_writer( file_compression_type: FileCompressionType, location: &Path, object_store: Arc, -) -> Result>> { - let (multipart_id, writer) = object_store - .put_multipart(location) - .await - .map_err(DataFusionError::ObjectStore)?; - Ok(AbortableWrite::new( - file_compression_type.convert_async_writer(writer)?, - MultiPart::new(object_store, multipart_id, location.clone()), - )) +) -> Result> { + let buf_writer = BufWriter::new(object_store, location.clone()); + file_compression_type.convert_async_writer(buf_writer) } diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs index b7f268959311..3ae2122de827 100644 --- a/datafusion/core/src/datasource/file_format/write/orchestration.rs +++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use super::demux::start_demuxer_task; -use super::{create_writer, AbortableWrite, BatchSerializer}; +use super::{create_writer, BatchSerializer}; use crate::datasource::file_format::file_compression_type::FileCompressionType; use crate::datasource::physical_plan::FileSinkConfig; use crate::error::Result; @@ -39,7 +39,7 @@ use tokio::io::{AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{self, Receiver}; use tokio::task::JoinSet; -type WriterType = AbortableWrite>; +type WriterType = Box; type SerializerType = Arc; /// Serializes a single data stream in parallel and writes to an ObjectStore @@ -49,7 +49,7 @@ type SerializerType = Arc; pub(crate) async fn serialize_rb_stream_to_object_store( mut data_rx: Receiver, serializer: Arc, - mut writer: AbortableWrite>, + mut writer: WriterType, ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> { let (tx, mut rx) = mpsc::channel::>>(100); @@ -173,19 +173,9 @@ pub(crate) async fn stateless_serialize_and_write_files( // Finalize or abort writers as appropriate for mut writer in finished_writers.into_iter() { - match any_errors { - true => { - let abort_result = writer.abort_writer(); - if abort_result.is_err() { - any_abort_errors = true; - } - } - false => { - writer.shutdown() + writer.shutdown() .await .map_err(|_| internal_datafusion_err!("Error encountered while finalizing writes! 
Partial results may have been written to ObjectStore!"))?; - } - } } if any_errors { diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 5fcb9f483952..31cc52f79697 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -44,6 +44,7 @@ use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; use futures::{ready, StreamExt, TryStreamExt}; +use object_store::buffered::BufWriter; use object_store::{GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -471,7 +472,7 @@ pub async fn plan_to_csv( let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let (_, mut multipart_writer) = storeref.put_multipart(&file).await?; + let mut buf_writer = BufWriter::new(storeref, file.clone()); let mut buffer = Vec::with_capacity(1024); //only write headers on first iteration let mut write_headers = true; @@ -481,15 +482,12 @@ pub async fn plan_to_csv( .build(buffer); writer.write(&batch)?; buffer = writer.into_inner(); - multipart_writer.write_all(&buffer).await?; + buf_writer.write_all(&buffer).await?; buffer.clear(); //prevent writing headers more than once write_headers = false; } - multipart_writer - .shutdown() - .await - .map_err(DataFusionError::from) + buf_writer.shutdown().await.map_err(DataFusionError::from) }); } diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 068426e0fdcb..194a4a91c34a 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -43,6 +43,7 @@ use datafusion_physical_expr::{EquivalenceProperties, LexOrdering}; use bytes::{Buf, Bytes}; use futures::{ready, StreamExt, TryStreamExt}; +use object_store::buffered::BufWriter; use object_store::{self, GetOptions, GetResultPayload, ObjectStore}; use tokio::io::AsyncWriteExt; use tokio::task::JoinSet; @@ -338,21 +339,18 @@ pub async fn plan_to_json( let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { - let (_, mut multipart_writer) = storeref.put_multipart(&file).await?; + let mut buf_writer = BufWriter::new(storeref, file.clone()); let mut buffer = Vec::with_capacity(1024); while let Some(batch) = stream.next().await.transpose()? 
{ let mut writer = json::LineDelimitedWriter::new(buffer); writer.write(&batch)?; buffer = writer.into_inner(); - multipart_writer.write_all(&buffer).await?; + buf_writer.write_all(&buffer).await?; buffer.clear(); } - multipart_writer - .shutdown() - .await - .map_err(DataFusionError::from) + buf_writer.shutdown().await.map_err(DataFusionError::from) }); } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 282cd624d036..767cde9cc55e 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -52,6 +52,7 @@ use futures::future::BoxFuture; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; use log::debug; +use object_store::buffered::BufWriter; use object_store::path::Path; use object_store::ObjectStore; use parquet::arrow::arrow_reader::ArrowReaderOptions; @@ -698,11 +699,11 @@ pub async fn plan_to_parquet( let propclone = writer_properties.clone(); let storeref = store.clone(); - let (_, multipart_writer) = storeref.put_multipart(&file).await?; + let buf_writer = BufWriter::new(storeref, file.clone()); let mut stream = plan.execute(i, task_ctx.clone())?; join_set.spawn(async move { let mut writer = - AsyncArrowWriter::try_new(multipart_writer, plan.schema(), propclone)?; + AsyncArrowWriter::try_new(buf_writer, plan.schema(), propclone)?; while let Some(next_batch) = stream.next().await { let batch = next_batch?; writer.write(&batch).await?; From 14972e6ae4be799450d1fbb81073fa0e1cbe57bc Mon Sep 17 00:00:00 2001 From: Kunal Kundu Date: Thu, 21 Mar 2024 05:24:04 +0530 Subject: [PATCH 030/117] Fix COPY TO failing on passing format options through CLI (#9709) * Fix COPY TO failing on passing format options through CLI * fix clippy lint error --- datafusion-cli/src/exec.rs | 20 +++++++++++++++++-- .../common/src/file_options/file_type.rs | 14 +++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index ea765ee8eceb..4e374a4c0032 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -40,6 +40,7 @@ use datafusion::prelude::SessionContext; use datafusion::sql::parser::{DFParser, Statement}; use datafusion::sql::sqlparser::dialect::dialect_from_str; +use datafusion_common::FileType; use rustyline::error::ReadlineError; use rustyline::Editor; use tokio::signal; @@ -257,15 +258,23 @@ async fn create_plan( // datafusion-cli specific options before passing through to datafusion. Otherwise, datafusion // will raise Configuration errors. 
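// For example (illustrative only), a statement such as
//   COPY my_table TO 's3://bucket/out.parquet'
// needs the `s3` object store registered (with any credential options applied)
// before the plan reaches DataFusion proper.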
if let LogicalPlan::Ddl(DdlStatement::CreateExternalTable(cmd)) = &plan { - register_object_store_and_config_extensions(ctx, &cmd.location, &cmd.options) - .await?; + register_object_store_and_config_extensions( + ctx, + &cmd.location, + &cmd.options, + None, + ) + .await?; } if let LogicalPlan::Copy(copy_to) = &mut plan { + let format: FileType = (©_to.format_options).into(); + register_object_store_and_config_extensions( ctx, ©_to.output_url, ©_to.options, + Some(format), ) .await?; } @@ -303,6 +312,7 @@ pub(crate) async fn register_object_store_and_config_extensions( ctx: &SessionContext, location: &String, options: &HashMap, + format: Option, ) -> Result<()> { // Parse the location URL to extract the scheme and other components let table_path = ListingTableUrl::parse(location)?; @@ -318,6 +328,9 @@ pub(crate) async fn register_object_store_and_config_extensions( // Clone and modify the default table options based on the provided options let mut table_options = ctx.state().default_table_options().clone(); + if let Some(format) = format { + table_options.set_file_format(format); + } table_options.alter_with_string_hash_map(options)?; // Retrieve the appropriate object store based on the scheme, URL, and modified table options @@ -347,6 +360,7 @@ mod tests { &ctx, &cmd.location, &cmd.options, + None, ) .await?; } else { @@ -367,10 +381,12 @@ mod tests { let plan = ctx.state().create_logical_plan(sql).await?; if let LogicalPlan::Copy(cmd) = &plan { + let format: FileType = (&cmd.format_options).into(); register_object_store_and_config_extensions( &ctx, &cmd.output_url, &cmd.options, + Some(format), ) .await?; } else { diff --git a/datafusion/common/src/file_options/file_type.rs b/datafusion/common/src/file_options/file_type.rs index 812cb02a5f77..fc0bb7445645 100644 --- a/datafusion/common/src/file_options/file_type.rs +++ b/datafusion/common/src/file_options/file_type.rs @@ -20,6 +20,7 @@ use std::fmt::{self, Display}; use std::str::FromStr; +use crate::config::FormatOptions; use crate::error::{DataFusionError, Result}; /// The default file extension of arrow files @@ -55,6 +56,19 @@ pub enum FileType { JSON, } +impl From<&FormatOptions> for FileType { + fn from(value: &FormatOptions) -> Self { + match value { + FormatOptions::CSV(_) => FileType::CSV, + FormatOptions::JSON(_) => FileType::JSON, + #[cfg(feature = "parquet")] + FormatOptions::PARQUET(_) => FileType::PARQUET, + FormatOptions::AVRO => FileType::AVRO, + FormatOptions::ARROW => FileType::ARROW, + } + } +} + impl GetExt for FileType { fn get_ext(&self) -> String { match self { From b72d25cc3a3a4257de1fc88e8df56b4c874d60ce Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Thu, 21 Mar 2024 07:56:54 +0800 Subject: [PATCH 031/117] fix: recursive cte hangs on joins (#9687) * fix: recursive cte hangs on joins * Use ExecutionPlan::with_new_children * Naming --- .../physical-plan/src/recursive_query.rs | 26 ++++++- datafusion/sqllogictest/test_files/cte.slt | 73 +++++++++++++++++-- 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 68abc9653a8b..140820ff782a 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -309,10 +309,9 @@ impl RecursiveQueryStream { // Downstream plans should not expect any partitioning. 
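// (Concretely: each pass re-executes the recursive term as one stream, so
// partition 0 is requested below regardless of the plan's declared partitioning.)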
let partition = 0;
 
-        self.recursive_stream = Some(
-            self.recursive_term
-                .execute(partition, self.task_context.clone())?,
-        );
+        let recursive_plan = reset_plan_states(self.recursive_term.clone())?;
+        self.recursive_stream =
+            Some(recursive_plan.execute(partition, self.task_context.clone())?);
         self.poll_next(cx)
     }
 }
@@ -343,6 +342,25 @@ fn assign_work_table(
     .data()
 }
 
+/// Some plans will change their internal states after execution, making them unable to be executed again.
+/// This function uses `ExecutionPlan::with_new_children` to fork a new plan with initial states.
+///
+/// An example is `CrossJoinExec`, which loads the left table into memory and stores it in the plan.
+/// However, if the data of the left table is derived from the work table, it will become outdated
+/// as the work table changes. When the next iteration executes this plan again, we must clear the left table.
+fn reset_plan_states(plan: Arc<dyn ExecutionPlan>) -> Result<Arc<dyn ExecutionPlan>> {
+    plan.transform_up(&|plan| {
+        // WorkTableExec's states have already been updated correctly.
+        if plan.as_any().is::<WorkTableExec>() {
+            Ok(Transformed::no(plan))
+        } else {
+            let new_plan = plan.clone().with_new_children(plan.children())?;
+            Ok(Transformed::yes(new_plan))
+        }
+    })
+    .data()
+}
+
 impl Stream for RecursiveQueryStream {
     type Item = Result<RecordBatch>;
diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt
index 6b9db5589391..50c88e41959f 100644
--- a/datafusion/sqllogictest/test_files/cte.slt
+++ b/datafusion/sqllogictest/test_files/cte.slt
@@ -40,11 +40,6 @@ ProjectionExec: expr=[1 as a, 2 as b, 3 as c]
 --PlaceholderRowExec
 
-
-# enable recursive CTEs
-statement ok
-set datafusion.execution.enable_recursive_ctes = true;
-
 # trivial recursive CTE works
 query I rowsort
 WITH RECURSIVE nodes AS (
@@ -651,3 +646,71 @@ WITH RECURSIVE my_cte AS (
 WHERE my_cte.a<5
 )
 SELECT a FROM my_cte;
+
+
+# Test issue: https://github.com/apache/arrow-datafusion/issues/9680
+query I
+WITH RECURSIVE recursive_cte AS (
+    SELECT 1 as val
+    UNION ALL
+    (
+        WITH sub_cte AS (
+            SELECT 2 as val
+        )
+        SELECT
+            2 as val
+        FROM recursive_cte
+        CROSS JOIN sub_cte
+        WHERE recursive_cte.val < 2
+    )
+)
+SELECT * FROM recursive_cte;
+----
+1
+2
+
+# Test issue: https://github.com/apache/arrow-datafusion/issues/9680
+# 'recursive_cte' should stay on the left side of the cross join, since exercising that shape is the purpose of the above query.
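+# (Before this fix, the cross join buffered its work-table-derived input once
+# and replayed the stale rows on every iteration, so the query above hung;
+# see reset_plan_states in recursive_query.rs.)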
+query TT +explain WITH RECURSIVE recursive_cte AS ( + SELECT 1 as val + UNION ALL + ( + WITH sub_cte AS ( + SELECT 2 as val + ) + SELECT + 2 as val + FROM recursive_cte + CROSS JOIN sub_cte + WHERE recursive_cte.val < 2 + ) +) +SELECT * FROM recursive_cte; +---- +logical_plan +Projection: recursive_cte.val +--SubqueryAlias: recursive_cte +----RecursiveQuery: is_distinct=false +------Projection: Int64(1) AS val +--------EmptyRelation +------Projection: Int64(2) AS val +--------CrossJoin: +----------Filter: recursive_cte.val < Int64(2) +------------TableScan: recursive_cte +----------SubqueryAlias: sub_cte +------------Projection: Int64(2) AS val +--------------EmptyRelation +physical_plan +RecursiveQueryExec: name=recursive_cte, is_distinct=false +--ProjectionExec: expr=[1 as val] +----PlaceholderRowExec +--ProjectionExec: expr=[2 as val] +----CrossJoinExec +------CoalescePartitionsExec +--------CoalesceBatchesExec: target_batch_size=8182 +----------FilterExec: val@0 < 2 +------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +--------------WorkTableExec: name=recursive_cte +------ProjectionExec: expr=[2 as val] +--------PlaceholderRowExec From 1d8a41bc8e08b56e90d6f8e6ef20e39a126987e4 Mon Sep 17 00:00:00 2001 From: "Reilly.tang" Date: Thu, 21 Mar 2024 07:57:05 +0800 Subject: [PATCH 032/117] Move `starts_with`, `to_hex`,` trim`, `upper` to datafusion-functions (and add string_expressions) (#9541) * [task #9539] Move starts_with, to_hex, trim, upper to datafusion-functions Signed-off-by: tangruilin * Export expr_fn, restore tests * fix comments --------- Signed-off-by: tangruilin Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 57 +--- datafusion/expr/src/expr_fn.rs | 18 -- datafusion/functions/Cargo.toml | 3 + datafusion/functions/src/lib.rs | 9 +- datafusion/functions/src/string/mod.rs | 292 ++++++++++++++++++ .../functions/src/string/starts_with.rs | 89 ++++++ datafusion/functions/src/string/to_hex.rs | 155 ++++++++++ datafusion/functions/src/string/trim.rs | 78 +++++ datafusion/functions/src/string/upper.rs | 66 ++++ datafusion/physical-expr/src/functions.rs | 118 ------- .../physical-expr/src/string_expressions.rs | 77 +---- datafusion/proto/proto/datafusion.proto | 8 +- datafusion/proto/src/generated/pbjson.rs | 12 - datafusion/proto/src/generated/prost.rs | 16 +- .../proto/src/logical_plan/from_proto.rs | 22 +- datafusion/proto/src/logical_plan/to_proto.rs | 4 - datafusion/sql/src/expr/mod.rs | 2 +- 17 files changed, 720 insertions(+), 306 deletions(-) create mode 100644 datafusion/functions/src/string/mod.rs create mode 100644 datafusion/functions/src/string/starts_with.rs create mode 100644 datafusion/functions/src/string/to_hex.rs create mode 100644 datafusion/functions/src/string/trim.rs create mode 100644 datafusion/functions/src/string/upper.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 79cd6a24ce39..fffe2cf4c9c9 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -147,20 +147,12 @@ pub enum BuiltinScalarFunction { Rtrim, /// split_part SplitPart, - /// starts_with - StartsWith, /// strpos Strpos, /// substr Substr, - /// to_hex - ToHex, /// translate Translate, - /// trim - Trim, - /// upper - Upper, /// uuid Uuid, /// overlay @@ -276,13 +268,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Rpad => Volatility::Immutable, BuiltinScalarFunction::Rtrim => Volatility::Immutable, BuiltinScalarFunction::SplitPart 
=> Volatility::Immutable, - BuiltinScalarFunction::StartsWith => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, - BuiltinScalarFunction::ToHex => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, - BuiltinScalarFunction::Trim => Volatility::Immutable, - BuiltinScalarFunction::Upper => Volatility::Immutable, BuiltinScalarFunction::OverLay => Volatility::Immutable, BuiltinScalarFunction::Levenshtein => Volatility::Immutable, BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, @@ -365,7 +353,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SplitPart => { utf8_to_str_type(&input_expr_types[0], "split_part") } - BuiltinScalarFunction::StartsWith => Ok(Boolean), BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") @@ -373,12 +360,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Substr => { utf8_to_str_type(&input_expr_types[0], "substr") } - BuiltinScalarFunction::ToHex => Ok(match input_expr_types[0] { - Int8 | Int16 | Int32 | Int64 => Utf8, - _ => { - return plan_err!("The to_hex function can only accept integers."); - } - }), BuiltinScalarFunction::SubstrIndex => { utf8_to_str_type(&input_expr_types[0], "substr_index") } @@ -388,10 +369,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Translate => { utf8_to_str_type(&input_expr_types[0], "translate") } - BuiltinScalarFunction::Trim => utf8_to_str_type(&input_expr_types[0], "trim"), - BuiltinScalarFunction::Upper => { - utf8_to_str_type(&input_expr_types[0], "upper") - } BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Gcd @@ -476,18 +453,16 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Lower | BuiltinScalarFunction::OctetLength - | BuiltinScalarFunction::Reverse - | BuiltinScalarFunction::Upper => { + | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } BuiltinScalarFunction::Btrim | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim - | BuiltinScalarFunction::Trim => Signature::one_of( + | BuiltinScalarFunction::Rtrim => Signature::one_of( vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], self.volatility(), ), - BuiltinScalarFunction::Chr | BuiltinScalarFunction::ToHex => { + BuiltinScalarFunction::Chr => { Signature::uniform(1, vec![Int64], self.volatility()) } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { @@ -519,17 +494,17 @@ impl BuiltinScalarFunction { self.volatility(), ), - BuiltinScalarFunction::EndsWith - | BuiltinScalarFunction::Strpos - | BuiltinScalarFunction::StartsWith => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - self.volatility(), - ), + BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { + Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + self.volatility(), + ) + } BuiltinScalarFunction::Substr => Signature::one_of( vec![ @@ -749,13 +724,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Rpad => &["rpad"], BuiltinScalarFunction::Rtrim => &["rtrim"], BuiltinScalarFunction::SplitPart => &["split_part"], - BuiltinScalarFunction::StartsWith => &["starts_with"], BuiltinScalarFunction::Strpos => &["strpos", 
"instr", "position"], BuiltinScalarFunction::Substr => &["substr"], - BuiltinScalarFunction::ToHex => &["to_hex"], BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::Trim => &["trim"], - BuiltinScalarFunction::Upper => &["upper"], BuiltinScalarFunction::Uuid => &["uuid"], BuiltinScalarFunction::Levenshtein => &["levenshtein"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index b76164a1c83c..8667f631c507 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -575,12 +575,6 @@ scalar_expr!(Log10, log10, num, "base 10 logarithm of number"); scalar_expr!(Ln, ln, num, "natural logarithm (base e) of number"); scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); -scalar_expr!( - ToHex, - to_hex, - num, - "returns the hexdecimal representation of an integer" -); scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); @@ -630,19 +624,11 @@ scalar_expr!( "removes all characters, spaces by default, from the end of a string" ); scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); -scalar_expr!(StartsWith, starts_with, string prefix, "whether the `string` starts with the `prefix`"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); -scalar_expr!( - Trim, - trim, - string, - "removes all characters, space by default from the string" -); -scalar_expr!(Upper, upper, string, "converts the string to upper case"); //use vec as parameter nary_scalar_expr!( Lpad, @@ -1117,15 +1103,11 @@ mod test { test_nary_scalar_expr!(Rpad, rpad, string, count, characters); test_scalar_expr!(Rtrim, rtrim, string); test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); - test_scalar_expr!(StartsWith, starts_with, string, characters); test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); test_scalar_expr!(Substr, substring, string, position, count); - test_scalar_expr!(ToHex, to_hex, string); test_scalar_expr!(Translate, translate, string, from, to); - test_scalar_expr!(Trim, trim, string); - test_scalar_expr!(Upper, upper, string); test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); test_nary_scalar_expr!(OverLay, overlay, string, characters, position); test_scalar_expr!(Levenshtein, levenshtein, string1, string2); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 5a6da5345d7c..b12c99e84a90 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -29,6 +29,8 @@ authors = { workspace = true } rust-version = { workspace = true } [features] +# enable string functions 
+string_expressions = [] # enable core functions core_expressions = [] # enable datetime functions @@ -41,6 +43,7 @@ default = [ "math_expressions", "regex_expressions", "crypto_expressions", + "string_expressions", ] # enable encode/decode functions encoding_expressions = ["base64", "hex"] diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index 3a2eab8e5f05..f469b343e144 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -84,6 +84,10 @@ use log::debug; #[macro_use] pub mod macros; +#[cfg(feature = "string_expressions")] +pub mod string; +make_stub_package!(string, "string_expressions"); + /// Core datafusion expressions /// Enabled via feature flag `core_expressions` #[cfg(feature = "core_expressions")] @@ -134,6 +138,8 @@ pub mod expr_fn { pub use super::math::expr_fn::*; #[cfg(feature = "regex_expressions")] pub use super::regex::expr_fn::*; + #[cfg(feature = "string_expressions")] + pub use super::string::expr_fn::*; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -144,7 +150,8 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { .chain(encoding::functions()) .chain(math::functions()) .chain(regex::functions()) - .chain(crypto::functions()); + .chain(crypto::functions()) + .chain(string::functions()); all_functions.try_for_each(|udf| { let existing_udf = registry.register_udf(udf)?; diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs new file mode 100644 index 000000000000..08fcbb363bbc --- /dev/null +++ b/datafusion/functions/src/string/mod.rs @@ -0,0 +1,292 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}, + datatypes::DataType, +}; +use datafusion_common::{ + cast::as_generic_string_array, exec_err, plan_err, Result, ScalarValue, +}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_physical_expr::functions::Hint; +use std::{ + fmt::{Display, Formatter}, + sync::Arc, +}; + +/// Creates a function to identify the optimal return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! 
get_optimal_return_type {
+    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
+        fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
+            Ok(match arg_type {
+                // LargeBinary inputs are automatically coerced to Utf8
+                DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
+                // Binary inputs are automatically coerced to Utf8
+                DataType::Utf8 | DataType::Binary => $utf8Type,
+                DataType::Null => DataType::Null,
+                DataType::Dictionary(_, value_type) => match **value_type {
+                    DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
+                    DataType::Utf8 | DataType::Binary => $utf8Type,
+                    DataType::Null => DataType::Null,
+                    _ => {
+                        return plan_err!(
+                            "The {} function can only accept strings, but got {:?}.",
+                            name.to_uppercase(),
+                            **value_type
+                        );
+                    }
+                },
+                data_type => {
+                    return plan_err!(
+                        "The {} function can only accept strings, but got {:?}.",
+                        name.to_uppercase(),
+                        data_type
+                    );
+                }
+            })
+        }
+    };
+}
+
+// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size.
+get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
+
+/// applies a unary expression to `args[0]` that is expected to be downcastable to
+/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset)
+/// # Errors
+/// This function errors when:
+/// * the number of arguments is not 1
+/// * the first argument is not castable to a `GenericStringArray`
+pub(crate) fn unary_string_function<'a, T, O, F, R>(
+    args: &[&'a dyn Array],
+    op: F,
+    name: &str,
+) -> Result<GenericStringArray<O>>
+where
+    R: AsRef<str>,
+    O: OffsetSizeTrait,
+    T: OffsetSizeTrait,
+    F: Fn(&'a str) -> R,
+{
+    if args.len() != 1 {
+        return exec_err!(
+            "{:?} args were supplied but {} takes exactly one argument",
+            args.len(),
+            name
+        );
+    }
+
+    let string_array = as_generic_string_array::<T>(args[0])?;
+
+    // first map is the iterator, second is for the `Option<_>`
+    Ok(string_array.iter().map(|string| string.map(&op)).collect())
+}
+
+fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result<ColumnarValue>
+where
+    R: AsRef<str>,
+    F: Fn(&'a str) -> R,
+{
+    match &args[0] {
+        ColumnarValue::Array(a) => match a.data_type() {
+            DataType::Utf8 => {
+                Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
+                    i32,
+                    i32,
+                    _,
+                    _,
+                >(
+                    &[a.as_ref()], op, name
+                )?)))
+            }
+            DataType::LargeUtf8 => {
+                Ok(ColumnarValue::Array(Arc::new(unary_string_function::<
+                    i64,
+                    i64,
+                    _,
+                    _,
+                >(
+                    &[a.as_ref()], op, name
+                )?)))
+            }
+            other => exec_err!("Unsupported data type {other:?} for function {name}"),
+        },
+        ColumnarValue::Scalar(scalar) => match scalar {
+            ScalarValue::Utf8(a) => {
+                let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
+                Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
+            }
+            ScalarValue::LargeUtf8(a) => {
+                let result = a.as_ref().map(|x| (op)(x).as_ref().to_string());
+                Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
+            }
+            other => exec_err!("Unsupported data type {other:?} for function {name}"),
+        },
+    }
+}
+
+// TODO: remove allow(dead_code) after moving ltrim and rtrim
+enum TrimType {
+    #[allow(dead_code)]
+    Left,
+    #[allow(dead_code)]
+    Right,
+    Both,
+}
+
+impl Display for TrimType {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            TrimType::Left => write!(f, "ltrim"),
+            TrimType::Right => write!(f, "rtrim"),
+            TrimType::Both => write!(f, "btrim"),
+        }
+    }
+}
+
+fn general_trim<T: OffsetSizeTrait>(
+    args: &[ArrayRef],
+    trim_type: TrimType,
+) -> Result<ArrayRef> {
+    let func = match trim_type {
+        TrimType::Left => |input, pattern:
&str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + + match args.len() { + 1 => { + let result = string_array + .iter() + .map(|string| string.map(|string: &str| func(string, " "))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let characters_array = as_generic_string_array::(&args[1])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .map(|(string, characters)| match (string, characters) { + (Some(string), Some(characters)) => Some(func(string, characters)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } + } +} + +pub(super) fn make_scalar_function( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) + .map(|(arg, hint)| { + // Decide on the length to expand this scalar to depending + // on the given hints. + let expansion_len = match hint { + Hint::AcceptsSingular => 1, + Hint::Pad => inferred_length, + }; + arg.clone().into_array(expansion_len) + }) + .collect::>>()?; + + let result = (inner)(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + }) +} + +mod starts_with; +mod to_hex; +mod trim; +mod upper; +// create UDFs +make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); +make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); +make_udf_function!(trim::TrimFunc, TRIM, trim); +make_udf_function!(upper::UpperFunc, UPPER, upper); + +export_functions!( + ( + starts_with, + arg1 arg2, + "Returns true if string starts with prefix."), + ( + to_hex, + arg1, + "Converts an integer to a hexadecimal string."), + (trim, + arg1, + "removes all characters, space by default from the string"), + (upper, + arg1, + "Converts a string to uppercase.")); diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs new file mode 100644 index 000000000000..1fce399d1e70 --- /dev/null +++ b/datafusion/functions/src/string/starts_with.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use crate::string::make_scalar_function; + +/// Returns true if string starts with prefix. +/// starts_with('alphabet', 'alph') = 't' +pub fn starts_with(args: &[ArrayRef]) -> Result { + let left = as_generic_string_array::(&args[0])?; + let right = as_generic_string_array::(&args[1])?; + + let result = arrow::compute::kernels::comparison::starts_with(left, right)?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct StartsWithFunc { + signature: Signature, +} +impl StartsWithFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for StartsWithFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "starts_with" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(Boolean) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(starts_with::, vec![])(args), + DataType::LargeUtf8 => { + return make_scalar_function(starts_with::, vec![])(args); + } + _ => internal_err!("Unsupported data type"), + } + } +} diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs new file mode 100644 index 000000000000..4dfc84887da2 --- /dev/null +++ b/datafusion/functions/src/string/to_hex.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
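+
+// Note: `to_hex` formats through Rust's lower-hex `{:x}`, so a negative
+// Int64 such as -1 renders as its two's-complement bit pattern
+// ('ffffffffffffffff'), as the tests at the bottom of this file show.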
+ +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::{ + ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, +}; +use datafusion_common::cast::as_primitive_array; +use datafusion_common::Result; +use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use super::make_scalar_function; + +/// Converts the number to its equivalent hexadecimal representation. +/// to_hex(2147483647) = '7fffffff' +pub fn to_hex(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let integer_array = as_primitive_array::(&args[0])?; + + let result = integer_array + .iter() + .map(|integer| { + if let Some(value) = integer { + if let Some(value_usize) = value.to_usize() { + Ok(Some(format!("{value_usize:x}"))) + } else if let Some(value_isize) = value.to_isize() { + Ok(Some(format!("{value_isize:x}"))) + } else { + exec_err!("Unsupported data type {integer:?} for function to_hex") + } + } else { + Ok(None) + } + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct ToHexFunc { + signature: Signature, +} +impl ToHexFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ToHexFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "to_hex" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match arg_types[0] { + Int8 | Int16 | Int32 | Int64 => Utf8, + _ => { + return plan_err!("The to_hex function can only accept integers."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Int32 => make_scalar_function(to_hex::, vec![])(args), + DataType::Int64 => make_scalar_function(to_hex::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function to_hex"), + } + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{Int32Array, StringArray}, + datatypes::Int32Type, + }; + + use datafusion_common::cast::as_string_array; + + use super::*; + + #[test] + // Test to_hex function for zero + fn to_hex_zero() -> Result<()> { + let array = vec![0].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("0")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } + + #[test] + // Test to_hex function for positive number + fn to_hex_positive_number() -> Result<()> { + let array = vec![100].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("64")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } + + #[test] + // Test to_hex function for negative number + fn to_hex_negative_number() -> Result<()> { + let array = vec![-1].into_iter().collect::(); + let array_ref = Arc::new(array); + let hex_value_arc = to_hex::(&[array_ref])?; + let hex_value = as_string_array(&hex_value_arc)?; + let expected = StringArray::from(vec![Some("ffffffffffffffff")]); + assert_eq!(&expected, hex_value); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/trim.rs 
b/datafusion/functions/src/string/trim.rs new file mode 100644 index 000000000000..e04a171722e3 --- /dev/null +++ b/datafusion/functions/src/string/trim.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +use crate::string::{make_scalar_function, utf8_to_str_type}; + +use super::{general_trim, TrimType}; + +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +/// btrim('xyxtrimyyx', 'xyz') = 'trim' +pub fn btrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Both) +} + +#[derive(Debug)] +pub(super) struct TrimFunc { + signature: Signature, +} + +impl TrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "trim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "trim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(btrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(btrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function trim"), + } + } +} diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs new file mode 100644 index 000000000000..ed41487699aa --- /dev/null +++ b/datafusion/functions/src/string/upper.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +use crate::string::utf8_to_str_type; + +use super::handle; + +#[derive(Debug)] +pub(super) struct UpperFunc { + signature: Signature, +} + +impl UpperFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for UpperFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "upper" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "upper") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + handle(args, |string| string.to_uppercase(), "upper") + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index e76e7f56dc95..f2c93c3ec1dd 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -447,17 +447,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function split_part") } }), - BuiltinScalarFunction::StartsWith => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::starts_with::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::starts_with::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function starts_with") - } - }), BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::ends_with::)(args) @@ -497,15 +486,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function substr"), }), - BuiltinScalarFunction::ToHex => Arc::new(|args| match args[0].data_type() { - DataType::Int32 => { - make_scalar_function_inner(string_expressions::to_hex::)(args) - } - DataType::Int64 => { - make_scalar_function_inner(string_expressions::to_hex::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function to_hex"), - }), BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!( @@ -527,16 +507,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function translate") } }), - BuiltinScalarFunction::Trim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function trim"), - }), - BuiltinScalarFunction::Upper => Arc::new(string_expressions::upper), BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -1797,38 +1767,6 @@ mod tests { Utf8, StringArray ); - test_function!( - StartsWith, - &[lit("alphabet"), lit("alph"),], - Ok(Some(true)), - bool, - Boolean, - BooleanArray - ); - test_function!( - StartsWith, - &[lit("alphabet"), lit("blph"),], - Ok(Some(false)), - bool, - Boolean, - 
BooleanArray - ); - test_function!( - StartsWith, - &[lit(ScalarValue::Utf8(None)), lit("alph"),], - Ok(None), - bool, - Boolean, - BooleanArray - ); - test_function!( - StartsWith, - &[lit("alphabet"), lit(ScalarValue::Utf8(None)),], - Ok(None), - bool, - Boolean, - BooleanArray - ); test_function!( EndsWith, &[lit("alphabet"), lit("alph"),], @@ -2149,62 +2087,6 @@ mod tests { Utf8, StringArray ); - test_function!( - Trim, - &[lit(" trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit("trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit(" trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Trim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit("upper")], - Ok(Some("UPPER")), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit("UPPER")], - Ok(Some("UPPER")), - &str, - Utf8, - StringArray - ); - test_function!( - Upper, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); Ok(()) } diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index ace7ef2888a3..86c0092a220d 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -32,16 +32,14 @@ use arrow::{ Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, OffsetSizeTrait, StringArray, }, - datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType}, + datatypes::DataType, }; use uuid::Uuid; use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use datafusion_common::{ - cast::{ - as_generic_string_array, as_int64_array, as_primitive_array, as_string_array, - }, + cast::{as_generic_string_array, as_int64_array, as_string_array}, exec_err, ScalarValue, }; use datafusion_expr::ColumnarValue; @@ -526,34 +524,6 @@ pub fn ends_with(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Converts the number to its equivalent hexadecimal representation. -/// to_hex(2147483647) = '7fffffff' -pub fn to_hex(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let integer_array = as_primitive_array::(&args[0])?; - - let result = integer_array - .iter() - .map(|integer| { - if let Some(value) = integer { - if let Some(value_usize) = value.to_usize() { - Ok(Some(format!("{value_usize:x}"))) - } else if let Some(value_isize) = value.to_isize() { - Ok(Some(format!("{value_isize:x}"))) - } else { - exec_err!("Unsupported data type {integer:?} for function to_hex") - } - } else { - Ok(None) - } - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) -} - /// Converts the string to all upper case. 
/// upper('tom') = 'TOM' pub fn upper(args: &[ColumnarValue]) -> Result { @@ -709,54 +679,13 @@ pub fn levenshtein(args: &[ArrayRef]) -> Result { #[cfg(test)] mod tests { - use arrow::{array::Int32Array, datatypes::Int32Type}; + use arrow::array::Int32Array; use arrow_array::Int64Array; use datafusion_common::cast::as_int32_array; - use crate::string_expressions; - use super::*; - #[test] - // Test to_hex function for zero - fn to_hex_zero() -> Result<()> { - let array = vec![0].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("0")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - - #[test] - // Test to_hex function for positive number - fn to_hex_positive_number() -> Result<()> { - let array = vec![100].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("64")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - - #[test] - // Test to_hex function for negative number - fn to_hex_negative_number() -> Result<()> { - let array = vec![-1].into_iter().collect::(); - let array_ref = Arc::new(array); - let hex_value_arc = string_expressions::to_hex::(&[array_ref])?; - let hex_value = as_string_array(&hex_value_arc)?; - let expected = StringArray::from(vec![Some("ffffffffffffffff")]); - assert_eq!(&expected, hex_value); - - Ok(()) - } - #[test] fn to_overlay() -> Result<()> { let string = diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 10f79a2b8cc8..c009682d5a4d 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -592,18 +592,18 @@ enum ScalarFunction { // 48 was SHA384 // 49 was SHA512 SplitPart = 50; - StartsWith = 51; + // StartsWith = 51; Strpos = 52; Substr = 53; - ToHex = 54; + // ToHex = 54; // 55 was ToTimestamp // 56 was ToTimestampMillis // 57 was ToTimestampMicros // 58 was ToTimestampSeconds // 59 was Now Translate = 60; - Trim = 61; - Upper = 62; + // Trim = 61; + // Upper = 62; Coalesce = 63; Power = 64; // 65 was StructFun diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7757a64ef359..58683dba6dff 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22949,13 +22949,9 @@ impl serde::Serialize for ScalarFunction { Self::Rpad => "Rpad", Self::Rtrim => "Rtrim", Self::SplitPart => "SplitPart", - Self::StartsWith => "StartsWith", Self::Strpos => "Strpos", Self::Substr => "Substr", - Self::ToHex => "ToHex", Self::Translate => "Translate", - Self::Trim => "Trim", - Self::Upper => "Upper", Self::Coalesce => "Coalesce", Self::Power => "Power", Self::Atan2 => "Atan2", @@ -23027,13 +23023,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Rpad", "Rtrim", "SplitPart", - "StartsWith", "Strpos", "Substr", - "ToHex", "Translate", - "Trim", - "Upper", "Coalesce", "Power", "Atan2", @@ -23134,13 +23126,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Rpad" => Ok(ScalarFunction::Rpad), "Rtrim" => Ok(ScalarFunction::Rtrim), "SplitPart" => Ok(ScalarFunction::SplitPart), - "StartsWith" => Ok(ScalarFunction::StartsWith), "Strpos" => Ok(ScalarFunction::Strpos), "Substr" => Ok(ScalarFunction::Substr), - "ToHex" => 
Ok(ScalarFunction::ToHex), "Translate" => Ok(ScalarFunction::Translate), - "Trim" => Ok(ScalarFunction::Trim), - "Upper" => Ok(ScalarFunction::Upper), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), "Atan2" => Ok(ScalarFunction::Atan2), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ab0ddb14ebfc..8eabb3b18603 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2891,18 +2891,18 @@ pub enum ScalarFunction { /// 48 was SHA384 /// 49 was SHA512 SplitPart = 50, - StartsWith = 51, + /// StartsWith = 51; Strpos = 52, Substr = 53, - ToHex = 54, + /// ToHex = 54; /// 55 was ToTimestamp /// 56 was ToTimestampMillis /// 57 was ToTimestampMicros /// 58 was ToTimestampSeconds /// 59 was Now Translate = 60, - Trim = 61, - Upper = 62, + /// Trim = 61; + /// Upper = 62; Coalesce = 63, Power = 64, /// 65 was StructFun @@ -3022,13 +3022,9 @@ impl ScalarFunction { ScalarFunction::Rpad => "Rpad", ScalarFunction::Rtrim => "Rtrim", ScalarFunction::SplitPart => "SplitPart", - ScalarFunction::StartsWith => "StartsWith", ScalarFunction::Strpos => "Strpos", ScalarFunction::Substr => "Substr", - ScalarFunction::ToHex => "ToHex", ScalarFunction::Translate => "Translate", - ScalarFunction::Trim => "Trim", - ScalarFunction::Upper => "Upper", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", ScalarFunction::Atan2 => "Atan2", @@ -3094,13 +3090,9 @@ impl ScalarFunction { "Rpad" => Some(Self::Rpad), "Rtrim" => Some(Self::Rtrim), "SplitPart" => Some(Self::SplitPart), - "StartsWith" => Some(Self::StartsWith), "Strpos" => Some(Self::Strpos), "Substr" => Some(Self::Substr), - "ToHex" => Some(Self::ToHex), "Translate" => Some(Self::Translate), - "Trim" => Some(Self::Trim), - "Upper" => Some(Self::Upper), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), "Atan2" => Some(Self::Atan2), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 8581156e2bb8..64ceb37d2961 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -57,10 +57,9 @@ use datafusion_expr::{ logical_plan::{PlanType, StringifiedPlan}, lower, lpad, ltrim, nanvl, octet_length, overlay, pi, power, radians, random, repeat, replace, reverse, right, round, rpad, rtrim, signum, sin, sinh, split_part, sqrt, - starts_with, strpos, substr, substr_index, substring, to_hex, translate, trim, trunc, - upper, uuid, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, - BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, - GroupingSet, + strpos, substr, substr_index, substring, translate, trunc, uuid, AggregateFunction, + Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, + GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -462,8 +461,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::OctetLength => Self::OctetLength, ScalarFunction::Concat => Self::Concat, ScalarFunction::Lower => Self::Lower, - ScalarFunction::Upper => Self::Upper, - ScalarFunction::Trim => Self::Trim, ScalarFunction::Ltrim => Self::Ltrim, ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::Log2 => Self::Log2, @@ -485,10 +482,8 @@ impl From<&protobuf::ScalarFunction> for 
BuiltinScalarFunction { ScalarFunction::Right => Self::Right, ScalarFunction::Rpad => Self::Rpad, ScalarFunction::SplitPart => Self::SplitPart, - ScalarFunction::StartsWith => Self::StartsWith, ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, - ScalarFunction::ToHex => Self::ToHex, ScalarFunction::Uuid => Self::Uuid, ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, @@ -1444,10 +1439,6 @@ pub fn parse_expr( ScalarFunction::Lower => { Ok(lower(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Upper => { - Ok(upper(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Trim => Ok(trim(parse_expr(&args[0], registry, codec)?)), ScalarFunction::Ltrim => { Ok(ltrim(parse_expr(&args[0], registry, codec)?)) } @@ -1532,10 +1523,6 @@ pub fn parse_expr( parse_expr(&args[1], registry, codec)?, parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::StartsWith => Ok(starts_with( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::EndsWith => Ok(ends_with( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, @@ -1563,9 +1550,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::ToHex => { - Ok(to_hex(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Translate => Ok(translate( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 05a29ff6d42b..89bd93550a04 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1481,8 +1481,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::OctetLength => Self::OctetLength, BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::Lower => Self::Lower, - BuiltinScalarFunction::Upper => Self::Upper, - BuiltinScalarFunction::Trim => Self::Trim, BuiltinScalarFunction::Ltrim => Self::Ltrim, BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::Log2 => Self::Log2, @@ -1505,10 +1503,8 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Right => Self::Right, BuiltinScalarFunction::Rpad => Self::Rpad, BuiltinScalarFunction::SplitPart => Self::SplitPart, - BuiltinScalarFunction::StartsWith => Self::StartsWith, BuiltinScalarFunction::Strpos => Self::Strpos, BuiltinScalarFunction::Substr => Self::Substr, - BuiltinScalarFunction::ToHex => Self::ToHex, BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 5e9c0623a265..c34b42193cec 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -747,7 +747,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Some(TrimWhereField::Leading) => BuiltinScalarFunction::Ltrim, Some(TrimWhereField::Trailing) => BuiltinScalarFunction::Rtrim, Some(TrimWhereField::Both) => BuiltinScalarFunction::Btrim, - None => BuiltinScalarFunction::Trim, + None => BuiltinScalarFunction::Btrim, }; let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; From dc373a3550610ce041fd73a1eabe08b096d6ed27 Mon Sep 17 00:00:00 2001 From: Jeffrey Vo Date: Fri, 22 Mar 2024 01:13:44 +1100 Subject: [PATCH 
033/117] Support for `extract(x from time)` / `date_part` from time types (#8693) * Initial support for `extract(x from time)` * Update function docs * Add extract tests --- datafusion/common/src/cast.rs | 37 ++- .../functions/src/datetime/date_part.rs | 31 +- datafusion/sqllogictest/test_files/expr.slt | 287 ++++++++++++++++++ .../source/user-guide/sql/scalar_functions.md | 27 +- 4 files changed, 345 insertions(+), 37 deletions(-) diff --git a/datafusion/common/src/cast.rs b/datafusion/common/src/cast.rs index 088f03e002ed..0dc0532bbb6f 100644 --- a/datafusion/common/src/cast.rs +++ b/datafusion/common/src/cast.rs @@ -24,17 +24,18 @@ use crate::{downcast_value, DataFusionError, Result}; use arrow::{ array::{ Array, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, - DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, Float32Array, - Float64Array, GenericBinaryArray, GenericListArray, GenericStringArray, - Int32Array, Int64Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, - IntervalYearMonthArray, LargeListArray, ListArray, MapArray, NullArray, - OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt32Array, UInt64Array, UInt8Array, UnionArray, + Decimal256Array, DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, + Float32Array, Float64Array, GenericBinaryArray, GenericListArray, + GenericStringArray, Int32Array, Int64Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, LargeListArray, ListArray, + MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, + Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt32Array, UInt64Array, + UInt8Array, UnionArray, }, datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType}, }; -use arrow_array::Decimal256Array; // Downcast ArrayRef to Date32Array pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> { @@ -154,6 +155,26 @@ pub fn as_union_array(array: &dyn Array) -> Result<&UnionArray> { Ok(downcast_value!(array, UnionArray)) } +// Downcast ArrayRef to Time32SecondArray +pub fn as_time32_second_array(array: &dyn Array) -> Result<&Time32SecondArray> { + Ok(downcast_value!(array, Time32SecondArray)) +} + +// Downcast ArrayRef to Time32MillisecondArray +pub fn as_time32_millisecond_array(array: &dyn Array) -> Result<&Time32MillisecondArray> { + Ok(downcast_value!(array, Time32MillisecondArray)) +} + +// Downcast ArrayRef to Time64MicrosecondArray +pub fn as_time64_microsecond_array(array: &dyn Array) -> Result<&Time64MicrosecondArray> { + Ok(downcast_value!(array, Time64MicrosecondArray)) +} + +// Downcast ArrayRef to Time64NanosecondArray +pub fn as_time64_nanosecond_array(array: &dyn Array) -> Result<&Time64NanosecondArray> { + Ok(downcast_value!(array, Time64NanosecondArray)) +} + // Downcast ArrayRef to TimestampNanosecondArray pub fn as_timestamp_nanosecond_array( array: &dyn Array, diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index 5d2719bf0365..b41f7e13cff2 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -20,14 +20,17 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef, Float64Array}; use arrow::compute::{binary, cast, date_part, DatePart}; -use 
arrow::datatypes::DataType::{Date32, Date64, Float64, Timestamp, Utf8}; +use arrow::datatypes::DataType::{ + Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, +}; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::cast::{ - as_date32_array, as_date64_array, as_int32_array, as_timestamp_microsecond_array, - as_timestamp_millisecond_array, as_timestamp_nanosecond_array, - as_timestamp_second_array, + as_date32_array, as_date64_array, as_int32_array, as_time32_millisecond_array, + as_time32_second_array, as_time64_microsecond_array, as_time64_nanosecond_array, + as_timestamp_microsecond_array, as_timestamp_millisecond_array, + as_timestamp_nanosecond_array, as_timestamp_second_array, }; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; @@ -68,6 +71,10 @@ impl DatePartFunc { ]), Exact(vec![Utf8, Date64]), Exact(vec![Utf8, Date32]), + Exact(vec![Utf8, Time32(Second)]), + Exact(vec![Utf8, Time32(Millisecond)]), + Exact(vec![Utf8, Time64(Microsecond)]), + Exact(vec![Utf8, Time64(Nanosecond)]), ], Volatility::Immutable, ), @@ -149,12 +156,9 @@ fn date_part_f64(array: &dyn Array, part: DatePart) -> Result { Ok(cast(date_part(array, part)?.as_ref(), &Float64)?) } -/// invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the +/// Invoke [`date_part`] on an `array` (e.g. Timestamp) and convert the /// result to a total number of seconds, milliseconds, microseconds or /// nanoseconds -/// -/// # Panics -/// If `array` is not a temporal type such as Timestamp or Date32 fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { let sf = match unit { Second => 1_f64, @@ -163,6 +167,7 @@ fn seconds(array: &dyn Array, unit: TimeUnit) -> Result { Nanosecond => 1_000_000_000_f64, }; let secs = date_part(array, DatePart::Second)?; + // This assumes array is primitive and not a dictionary let secs = as_int32_array(secs.as_ref())?; let subsecs = date_part(array, DatePart::Nanosecond)?; let subsecs = as_int32_array(subsecs.as_ref())?; @@ -189,6 +194,16 @@ fn epoch(array: &dyn Array) -> Result { } Date32 => as_date32_array(array)?.unary(|x| x as f64 * SECONDS_IN_A_DAY), Date64 => as_date64_array(array)?.unary(|x| x as f64 / 1_000_f64), + Time32(Second) => as_time32_second_array(array)?.unary(|x| x as f64), + Time32(Millisecond) => { + as_time32_millisecond_array(array)?.unary(|x| x as f64 / 1_000_f64) + } + Time64(Microsecond) => { + as_time64_microsecond_array(array)?.unary(|x| x as f64 / 1_000_000_f64) + } + Time64(Nanosecond) => { + as_time64_nanosecond_array(array)?.unary(|x| x as f64 / 1_000_000_000_f64) + } d => return exec_err!("Can not convert {d:?} to epoch"), }; Ok(Arc::new(f)) diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 73fb5eec97d5..d6343f9a3fe8 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -939,6 +939,293 @@ SELECT date_part('nanosecond', '2020-09-08T12:00:12.12345678+00:00') ---- 12123456780 +# test_date_part_time + +## time32 seconds +query R +SELECT date_part('hour', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +32 + 
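+# Note: date_part/extract return Float64 (query type R) for every field, so
+# 'second' and the sub-second fields can carry fractional values for the
+# millisecond/microsecond/nanosecond-precision inputs exercised below.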
+query R +SELECT date_part('second', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query R +SELECT extract(second from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000 + +query R +SELECT date_part('microsecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +50000000000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +query R +SELECT extract(epoch from arrow_cast('23:32:50'::time, 'Time32(Second)')) +---- +84770 + +## time32 milliseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50.123 + +query R +SELECT extract(second from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50.123 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123 + +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +50123000000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123'::time, 'Time32(Millisecond)')) +---- +84770.123 + +## time64 microseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50.123456 + +query R +SELECT extract(second from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50.123456 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123.456 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123.456 
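+
+# The requested field keeps its fractional part at finer resolutions:
+# 50.123456 s above yields 50123.456 for 'millisecond' (query type R = FLOAT64).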
+ +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456000 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +50123456000 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456'::time, 'Time64(Microsecond)')) +---- +84770.123456 + +## time64 nanoseconds +query R +SELECT date_part('hour', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query R +SELECT extract(hour from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +23 + +query R +SELECT date_part('minute', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query R +SELECT extract(minute from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +32 + +query R +SELECT date_part('second', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50.123456789 + +query R +SELECT extract(second from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50.123456789 + +query R +SELECT date_part('millisecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123.456789 + +query R +SELECT extract(millisecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123.456789 + +# just some floating point stuff happening in the result here +query R +SELECT date_part('microsecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456.789000005 + +query R +SELECT extract(microsecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456.789000005 + +query R +SELECT date_part('nanosecond', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456789 + +query R +SELECT extract(nanosecond from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +50123456789 + +query R +SELECT date_part('epoch', arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + +query R +SELECT extract(epoch from arrow_cast('23:32:50.123456789'::time, 'Time64(Nanosecond)')) +---- +84770.123456789 + # test_extract_epoch query R diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index b63fa9950ae0..d4570dbc35f2 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1624,34 +1624,19 @@ _Alias of [date_part](#date_part)._ ### `extract` Returns a sub-field from a time value as an integer. -Similar to `date_part`, but with different arguments. ``` extract(field FROM source) ``` -#### Arguments - -- **field**: Part or field of the date to return. - The following date fields are supported: +Equivalent to calling `date_part('field', source)`. 
For example, these are equivalent:

-  - year
-  - quarter _(emits value in inclusive range [1, 4] based on which quartile of the year the date is in)_
-  - month
-  - week _(week of the year)_
-  - day _(day of the month)_
-  - hour
-  - minute
-  - second
-  - millisecond
-  - microsecond
-  - nanosecond
-  - dow _(day of the week)_
-  - doy _(day of the year)_
-  - epoch _(seconds since Unix epoch)_
+```sql
+extract(day FROM '2024-04-13'::date)
+date_part('day', '2024-04-13'::date)
+```

-- **source**: Source time expression to operate on.
-  Can be a constant, column, or function.
+See [date_part](#date_part).

 ### `make_date`

From edaf235828a90042eaf918ec4b3ee5ab2716f060 Mon Sep 17 00:00:00 2001
From: comphead
Date: Thu, 21 Mar 2024 08:12:18 -0700
Subject: [PATCH 034/117] doc: Updated known users list and usage dependency
 description (#9718)

* minor: update known users and usage description
---
 docs/source/user-guide/example-usage.md | 10 +++++-----
 docs/source/user-guide/introduction.md  |  1 +
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md
index c5eefbdaf156..31b599ac3308 100644
--- a/docs/source/user-guide/example-usage.md
+++ b/docs/source/user-guide/example-usage.md
@@ -23,20 +23,20 @@ In this example some simple processing is performed on the [`example.csv`](https
 Even [`more code examples`](https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples) attached to the project.

-## Add DataFusion as a dependency
+## Add published DataFusion dependency

 Find the latest available DataFusion version on [DataFusion's crates.io] page. Add the dependency to your `Cargo.toml` file:

 ```toml
-datafusion = "31"
+datafusion = "latest_version"
 tokio = "1.0"
 ```

-## Add DataFusion latest codebase as a dependency
+## Add latest non-published DataFusion dependency

-Cargo supports adding dependency directly from Github which allows testing out latest DataFusion codebase without waiting the code to be released to crates.io
-according to the [DataFusion release schedule](https://github.com/apache/arrow-datafusion/blob/main/dev/release/README.md#release-process)
+DataFusion changes are published to `crates.io` according to the [release schedule](https://github.com/apache/arrow-datafusion/blob/main/dev/release/README.md#release-process).
+If you need to test out DataFusion changes that are merged but not yet published, Cargo supports adding a dependency directly on a GitHub branch:

 ```toml
 datafusion = { git = "https://github.com/apache/arrow-datafusion", branch = "main"}
diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md
index ae2684699726..0e9d731c6e21 100644
--- a/docs/source/user-guide/introduction.md
+++ b/docs/source/user-guide/introduction.md
@@ -96,6 +96,7 @@ Here are some active projects using DataFusion:

 - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust
 - [Ballista](https://github.com/apache/arrow-ballista) Distributed SQL Query Engine
+- [Comet](https://github.com/apache/arrow-datafusion-comet) Apache Spark native query execution plugin
 - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database
 - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust)
 - [Dask SQL](https://github.com/dask-contrib/dask-sql) Distributed SQL query engine in Python
From c5c9d3f57f361c6c01d0cb01c416f6a7e9dfd906 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Thu, 21 Mar 2024 11:18:14 -0400
Subject: [PATCH 035/117] Minor: improve documentation for
 `CommonSubexprEliminate` (#9700)

---
 .../optimizer/src/common_subexpr_eliminate.rs | 26 +++++++++++++++++--
 1 file changed, 24 insertions(+), 2 deletions(-)

diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 7b8eccad5133..e73885c6aaef 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -53,10 +53,32 @@ type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
 /// here is not such a good choose.
 type Identifier = String;

-/// Perform Common Sub-expression Elimination optimization.
+/// Performs Common Sub-expression Elimination optimization.
 ///
-/// Currently only common sub-expressions within one logical plan will
+/// This optimization improves query performance by computing expressions that
+/// appear more than once and reusing those results rather than re-computing
+/// the same value.
+///
+/// Currently only common sub-expressions within a single `LogicalPlan` will
 /// be eliminated.
+///
+/// # Example
+///
+/// Given a projection that computes the same expensive expression
+/// multiple times such as parsing a string as a date with `to_date` twice:
+///
+/// ```text
+/// ProjectionExec(expr=[extract (day from to_date(c1)), extract (year from to_date(c1))])
+/// ```
+///
+/// This optimization will rewrite the plan to compute the common expression once
+/// using a new `ProjectionExec` and then rewrite the original expressions to
+/// refer to that new column.
+///
+/// ```text
+/// ProjectionExec(exprs=[extract (day from new_col), extract (year from new_col)]) <-- reuse here
+/// ProjectionExec(exprs=[to_date(c1) as new_col]) <-- compute to_date once
+/// ```
 pub struct CommonSubexprEliminate {}

 impl CommonSubexprEliminate {
From eda2ddfc123a0549c7df7fe0500b48bff1f76910 Mon Sep 17 00:00:00 2001
From: comphead
Date: Thu, 21 Mar 2024 11:07:40 -0700
Subject: [PATCH 036/117] build: modify code to comply with latest clippy
 requirement (#9725)

* fix CI clippy

* fix scalar size test

* fix tests

* fix tests
---
 datafusion/common/src/scalar/mod.rs                        | 3 ++-
 datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs         | 2 +-
 datafusion/expr/src/expr_rewriter/mod.rs                   | 2 +-
 datafusion/functions/benches/regx.rs                       | 4 ++--
 datafusion/functions/benches/to_char.rs                    | 2 +-
 .../optimizer/src/simplify_expressions/expr_simplifier.rs  | 3 ++-
 datafusion/physical-expr/src/equivalence/class.rs          | 6 +++---
 datafusion/physical-expr/src/equivalence/ordering.rs       | 4 ++--
 datafusion/physical-expr/src/equivalence/properties.rs     | 2 +-
 datafusion/physical-plan/src/sorts/partial_sort.rs         | 2 +-
 datafusion/physical-plan/src/union.rs                      | 2 +-
 datafusion/substrait/src/serializer.rs                     | 1 +
 12 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index d33b8b6e142c..2a99b667d8f1 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -4539,7 +4539,8 @@ mod tests {

         // The alignment requirements differ across architectures and
         // thus the size of the enum appears to as well
-        assert_eq!(std::mem::size_of::<ScalarValue>(), 48);
+        // The value can be changed depending on rust version
+        assert_eq!(std::mem::size_of::<ScalarValue>(), 64);
     }

     #[test]
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 59905d859dc8..8df16e7944d2 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ 
b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -46,7 +46,7 @@ use tokio::task::JoinSet; /// same results #[tokio::test(flavor = "multi_thread")] async fn streaming_aggregate_test() { - let test_cases = vec![ + let test_cases = [ vec!["a"], vec!["b", "a"], vec!["c", "a"], diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index ea3ffadda391..7a227a91c455 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -379,7 +379,7 @@ mod test { let expr = col("a") + col("b"); let schema_a = make_schema_with_empty_metadata(vec![make_field("\"tableA\"", "a")]); - let schemas = vec![schema_a]; + let schemas = [schema_a]; let schemas = schemas.iter().collect::>(); let error = diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs index 5831e263b4eb..f22be5ba3532 100644 --- a/datafusion/functions/benches/regx.rs +++ b/datafusion/functions/benches/regx.rs @@ -44,7 +44,7 @@ fn data(rng: &mut ThreadRng) -> StringArray { } fn regex(rng: &mut ThreadRng) -> StringArray { - let samples = vec![ + let samples = [ ".*([A-Z]{1}).*".to_string(), "^(A).*".to_string(), r#"[\p{Letter}-]+"#.to_string(), @@ -60,7 +60,7 @@ fn regex(rng: &mut ThreadRng) -> StringArray { } fn flags(rng: &mut ThreadRng) -> StringArray { - let samples = vec![Some("i".to_string()), Some("im".to_string()), None]; + let samples = [Some("i".to_string()), Some("im".to_string()), None]; let mut sb = StringBuilder::new(); for _ in 0..1000 { let sample = samples.choose(rng).unwrap(); diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 45a40f175da4..d9a153e64abc 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -64,7 +64,7 @@ fn data(rng: &mut ThreadRng) -> Date32Array { } fn patterns(rng: &mut ThreadRng) -> StringArray { - let samples = vec![ + let samples = [ "%Y:%m:%d".to_string(), "%d-%m-%Y".to_string(), "%d%m%Y".to_string(), diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 61e002ece98b..1cbe7decf15b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -405,11 +405,12 @@ struct ConstEvaluator<'a> { input_batch: RecordBatch, } +#[allow(dead_code)] /// The simplify result of ConstEvaluator enum ConstSimplifyResult { // Expr was simplifed and contains the new expression Simplified(ScalarValue), - // Evalaution encountered an error, contains the original expression + // Evaluation encountered an error, contains the original expression SimplifyRuntimeError(DataFusionError, Expr), } diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 280535f5e6be..58519c61cf1f 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -535,7 +535,7 @@ mod tests { #[test] fn test_remove_redundant_entries_eq_group() -> Result<()> { - let entries = vec![ + let entries = [ EquivalenceClass::new(vec![lit(1), lit(1), lit(2)]), // This group is meaningless should be removed EquivalenceClass::new(vec![lit(3), lit(3)]), @@ -543,11 +543,11 @@ mod tests { ]; // Given equivalences classes are not in succinct form. // Expected form is the most plain representation that is functionally same. 
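        // (These array literals likely satisfy clippy's `useless_vec` lint;
        // `.to_vec()` below converts back where an owned `Vec` is required.)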
-        let expected = vec![
+        let expected = [
             EquivalenceClass::new(vec![lit(1), lit(2)]),
             EquivalenceClass::new(vec![lit(4), lit(5), lit(6)]),
         ];
-        let mut eq_groups = EquivalenceGroup::new(entries);
+        let mut eq_groups = EquivalenceGroup::new(entries.to_vec());
         eq_groups.remove_redundant_entries();

         let eq_groups = eq_groups.classes;
diff --git a/datafusion/physical-expr/src/equivalence/ordering.rs b/datafusion/physical-expr/src/equivalence/ordering.rs
index c7cb9e5f530e..1364d3a8c028 100644
--- a/datafusion/physical-expr/src/equivalence/ordering.rs
+++ b/datafusion/physical-expr/src/equivalence/ordering.rs
@@ -746,7 +746,7 @@ mod tests {
         // Generate a data that satisfies properties given
         let table_data_with_properties =
             generate_table_for_eq_properties(&eq_properties, N_ELEMENTS, N_DISTINCT)?;
-        let col_exprs = vec![
+        let col_exprs = [
             col("a", &test_schema)?,
             col("b", &test_schema)?,
             col("c", &test_schema)?,
@@ -815,7 +815,7 @@ mod tests {
             Operator::Plus,
             col("b", &test_schema)?,
         )) as Arc<dyn PhysicalExpr>;
-        let exprs = vec![
+        let exprs = [
             col("a", &test_schema)?,
             col("b", &test_schema)?,
             col("c", &test_schema)?,
diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs
index a08e85b24162..5eb9d6eb1b86 100644
--- a/datafusion/physical-expr/src/equivalence/properties.rs
+++ b/datafusion/physical-expr/src/equivalence/properties.rs
@@ -1793,7 +1793,7 @@ mod tests {
             Operator::Plus,
             col("b", &test_schema)?,
         )) as Arc<dyn PhysicalExpr>;
-        let exprs = vec![
+        let exprs = [
             col("a", &test_schema)?,
             col("b", &test_schema)?,
             col("c", &test_schema)?,
diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs
index 500df6153fdb..2acb881246a4 100644
--- a/datafusion/physical-plan/src/sorts/partial_sort.rs
+++ b/datafusion/physical-plan/src/sorts/partial_sort.rs
@@ -578,7 +578,7 @@ mod tests {
     #[tokio::test]
     async fn test_partial_sort2() -> Result<()> {
         let task_ctx = Arc::new(TaskContext::default());
-        let source_tables = vec![
+        let source_tables = [
             test::build_table_scan_i32(
                 ("a", &vec![0, 0, 0, 0, 1, 1, 1, 1]),
                 ("b", &vec![1, 1, 3, 3, 4, 4, 2, 2]),
diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs
index 7eaac74a5449..64322bd5f101 100644
--- a/datafusion/physical-plan/src/union.rs
+++ b/datafusion/physical-plan/src/union.rs
@@ -740,7 +740,7 @@ mod tests {
         let col_e = &col("e", &schema)?;
         let col_f = &col("f", &schema)?;
         let options = SortOptions::default();
-        let test_cases = vec![
+        let test_cases = [
             //-----------TEST CASE 1----------//
             (
                 // First child orderings
diff --git a/datafusion/substrait/src/serializer.rs b/datafusion/substrait/src/serializer.rs
index e8698253edb5..6b81e33dfc37 100644
--- a/datafusion/substrait/src/serializer.rs
+++ b/datafusion/substrait/src/serializer.rs
@@ -27,6 +27,7 @@ use substrait::proto::Plan;
 use std::fs::OpenOptions;
 use std::io::{Read, Write};

+#[allow(clippy::suspicious_open_options)]
 pub async fn serialize(sql: &str, ctx: &SessionContext, path: &str) -> Result<()> {
     let protobuf_out = serialize_bytes(sql, ctx).await;
     let mut file = OpenOptions::new().create(true).write(true).open(path)?;
From 6d74025bb17342f9451702b84cf8d04171a29149 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Thu, 21 Mar 2024 15:08:27 -0400
Subject: [PATCH 037/117] Minor: return internal error rather than panic on
 unexpected error in COUNT DISTINCT (#9712)

* Minor: return internal error rather than panic on unexpected error in COUNT DISTINCT

* Update 
datafusion/physical-expr/src/aggregate/count_distinct/mod.rs * Update datafusion/physical-expr/src/aggregate/count_distinct/mod.rs Co-authored-by: comphead --------- Co-authored-by: comphead --- .../physical-expr/src/aggregate/count_distinct/mod.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index fb5e7710496c..9c5605f495ea 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -35,7 +35,7 @@ use arrow_array::types::{ TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::Accumulator; use crate::aggregate::count_distinct::bytes::BytesDistinctCountAccumulator; @@ -268,8 +268,11 @@ impl Accumulator for DistinctCountAccumulator { let array = &states[0]; let list_array = array.as_list::(); for inner_array in list_array.iter() { - let inner_array = inner_array - .expect("counts are always non null, so are intermediate results"); + let Some(inner_array) = inner_array else { + return internal_err!( + "Intermediate results of COUNT DISTINCT should always be non null" + ); + }; self.update_batch(&[inner_array])?; } Ok(()) From 5f0cb49c8b1a47830d80a7add1d3c96d7d5a0025 Mon Sep 17 00:00:00 2001 From: wiedld Date: Thu, 21 Mar 2024 12:40:51 -0700 Subject: [PATCH 038/117] fix(9678): short circuiting prevented population of visited stack, for common subexpr elimination optimization (#9685) * test(9678): reproducer of short-circuiting causing expr elimination to error * fix(9678): populate visited stack for short-circuited expressions, during the common-expr elimination optimization * test(9678): reproducer for optimizer error (in common_subexpr_eliminate), as seen in other test case * chore: extract id_array into abstraction, to make it more clear the relationship between the two visitors * refactor: tweak the fix and make code more explicit (JumpMark, node_to_identifier) * fix: get the series_number and curr_id with the correct self.current_idx, before the various incr/decr * chore: remove unneeded conditional check (already done earlier), and add code comments * Refine documentation in common_subexpr_eliminate.rs * chore: cleanup -- fix 1 doc comment and consolidate common-expr-elimination test with other expr test --------- Co-authored-by: Andrew Lamb --- .../optimizer/src/common_subexpr_eliminate.rs | 130 +++++++++++------- datafusion/sqllogictest/test_files/expr.slt | 36 +++++ 2 files changed, 115 insertions(+), 51 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index e73885c6aaef..0c9064d0641f 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,8 +29,7 @@ use datafusion_common::tree_node::{ TreeNodeVisitor, }; use datafusion_common::{ - internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef, - DataFusionError, Result, + internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window}; @@ -42,8 +41,36 @@ use datafusion_expr::{col, Expr, ExprSchemable}; /// - DataType of this expression. 
type ExprSet = HashMap; -/// Identifier type. Current implementation use describe of an expression (type String) as -/// Identifier. +/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an +/// initial expression walk. +/// +/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove +/// common subexpressions. +/// +/// Elements in this array are created on the walk down the expression tree +/// during `f_down`. Thus element 0 is the root of the expression tree. The +/// tuple contains: +/// - series_number. +/// - Incremented during `f_up`, start from 1. +/// - Thus, items with higher idx have the lower series_number. +/// - [`Identifier`] +/// - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination. +/// +/// # Example +/// An expression like `(a + b)` would have the following `IdArray`: +/// ```text +/// [ +/// (3, "a + b"), +/// (2, "a"), +/// (1, "b") +/// ] +/// ``` +type IdArray = Vec<(usize, Identifier)>; + +/// Identifier for each subexpression. +/// +/// Note that the current implementation uses the `Display` of an expression +/// (a `String`) as `Identifier`. /// /// An identifier should (ideally) be able to "hash", "accumulate", "equal" and "have no /// collision (as low as possible)" @@ -293,8 +320,9 @@ impl CommonSubexprEliminate { agg_exprs.push(expr.alias(&name)); proj_exprs.push(Expr::Column(Column::from_name(name))); } else { - let id = - ExprIdentifierVisitor::<'static>::desc_expr(&expr_rewritten); + let id = ExprIdentifierVisitor::<'static>::expr_identifier( + &expr_rewritten, + ); let out_name = expr_rewritten.to_field(&new_input_schema)?.qualified_name(); agg_exprs.push(expr_rewritten.alias(&id)); @@ -557,15 +585,15 @@ impl ExprMask { /// This visitor implementation use a stack `visit_stack` to track traversal, which /// lets us know when a sub-tree's visiting is finished. When `pre_visit` is called /// (traversing to a new node), an `EnterMark` and an `ExprItem` will be pushed into stack. -/// And try to pop out a `EnterMark` on leaving a node (`post_visit()`). All `ExprItem` +/// And try to pop out a `EnterMark` on leaving a node (`f_up()`). All `ExprItem` /// before the first `EnterMark` is considered to be sub-tree of the leaving node. /// /// This visitor also records identifier in `id_array`. Makes the following traverse /// pass can get the identifier of a node without recalculate it. We assign each node /// in the expr tree a series number, start from 1, maintained by `series_number`. -/// Series number represents the order we left (`post_visit`) a node. Has the property +/// Series number represents the order we left (`f_up()`) a node. Has the property /// that child node's series number always smaller than parent's. While `id_array` is -/// organized in the order we enter (`pre_visit`) a node. `node_count` helps us to +/// organized in the order we enter (`f_down()`) a node. `node_count` helps us to /// get the index of `id_array` for each node. /// /// `Expr` without sub-expr (column, literal etc.) will not have identifier @@ -574,15 +602,15 @@ struct ExprIdentifierVisitor<'a> { // param expr_set: &'a mut ExprSet, /// series number (usize) and identifier. - id_array: &'a mut Vec<(usize, Identifier)>, + id_array: &'a mut IdArray, /// input schema for the node that we're optimizing, so we can determine the correct datatype /// for each subexpression input_schema: DFSchemaRef, // inner states visit_stack: Vec, - /// increased in pre_visit, start from 0. 
+    /// increased in fn_down, start from 0.
     node_count: usize,
-    /// increased in post_visit, start from 1.
+    /// increased in fn_up, start from 1.
     series_number: usize,
     /// which expression should be skipped?
     expr_mask: ExprMask,
 }

 enum VisitRecord {
     /// `usize` is the monotone increasing series number assigned in pre_visit().
     /// Starts from 0. Is used to index the identifier array `id_array` in post_visit().
     EnterMark(usize),
+    /// the node's children were skipped => jump to f_up on same node
+    JumpMark(usize),
     /// Accumulated identifier of sub expression.
     ExprItem(Identifier),
 }

 impl ExprIdentifierVisitor<'_> {
-    fn desc_expr(expr: &Expr) -> String {
+    fn expr_identifier(expr: &Expr) -> Identifier {
         format!("{expr}")
     }

     /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
     /// before it.
-    fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> {
+    fn pop_enter_mark(&mut self) -> (usize, Identifier) {
         let mut desc = String::new();
         while let Some(item) = self.visit_stack.pop() {
             match item {
-                VisitRecord::EnterMark(idx) => {
-                    return Some((idx, desc));
+                VisitRecord::EnterMark(idx) | VisitRecord::JumpMark(idx) => {
+                    return (idx, desc);
                 }
-                VisitRecord::ExprItem(s) => {
-                    desc.push_str(&s);
+                VisitRecord::ExprItem(id) => {
+                    desc.push_str(&id);
                 }
             }
         }
-        None
+        unreachable!("Enter mark should be paired with node number");
     }
 }

 impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
     type Node = Expr;

     fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
+        // put placeholder, sets the proper array length
+        self.id_array.push((0, "".to_string()));
+
         // related to https://github.com/apache/arrow-datafusion/issues/8814
         // If the expr contain volatile expression or is a short-circuit expression, skip it.
         if expr.short_circuits() || is_volatile_expression(expr)? {
-            return Ok(TreeNodeRecursion::Jump);
+            self.visit_stack
+                .push(VisitRecord::JumpMark(self.node_count));
+            return Ok(TreeNodeRecursion::Jump); // go to f_up
         }
+
         self.visit_stack
             .push(VisitRecord::EnterMark(self.node_count));
         self.node_count += 1;
-        // put placeholder
-        self.id_array.push((0, "".to_string()));
+
         Ok(TreeNodeRecursion::Continue)
     }

     fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
         self.series_number += 1;

-        let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else {
-            return Ok(TreeNodeRecursion::Continue);
-        };
+        let (idx, sub_expr_identifier) = self.pop_enter_mark();
+
+        // skip exprs that should not be recognized.
         if self.expr_mask.ignores(expr) {
-            self.id_array[idx].0 = self.series_number;
-            let desc = Self::desc_expr(expr);
-            self.visit_stack.push(VisitRecord::ExprItem(desc));
+            let curr_expr_identifier = Self::expr_identifier(expr);
+            self.visit_stack
+                .push(VisitRecord::ExprItem(curr_expr_identifier));
+            self.id_array[idx].0 = self.series_number;
+            // leave the Identifier as empty "", since it will not be used as a common expr
             return Ok(TreeNodeRecursion::Continue);
         }
-        let mut desc = Self::desc_expr(expr);
-        desc.push_str(&sub_expr_desc);
+        let mut desc = Self::expr_identifier(expr);
+        desc.push_str(&sub_expr_identifier);

         self.id_array[idx] = (self.series_number, desc.clone());
         self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
@@ -693,7 +728,7 @@ fn expr_to_identifier(
 /// evaluate result of replaced expression.
 struct CommonSubexprRewriter<'a> {
     expr_set: &'a ExprSet,
-    id_array: &'a [(usize, Identifier)],
+    id_array: &'a IdArray,
     /// Which identifier is replaced.
affected_id: &'a mut BTreeSet, @@ -715,20 +750,26 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { if expr.short_circuits() || is_volatile_expression(&expr)? { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } + + let (series_number, curr_id) = &self.id_array[self.curr_index]; + + // halting conditions if self.curr_index >= self.id_array.len() - || self.max_series_number > self.id_array[self.curr_index].0 + || self.max_series_number > *series_number { return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } - let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { - self.curr_index += 1; + self.curr_index += 1; // incr idx for id_array, when not jumping return Ok(Transformed::no(expr)); } + + // lookup previously visited expression match self.expr_set.get(curr_id) { Some((_, counter, _)) => { + // if has a commonly used (a.k.a. 1+ use) expr if *counter > 1 { self.affected_id.insert(curr_id.clone()); @@ -741,23 +782,10 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { )); } - let (series_number, id) = &self.id_array[self.curr_index]; + // incr idx for id_array, when not jumping self.curr_index += 1; - // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. - let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - internal_datafusion_err!("expr_set invalid state") - })?; - if *series_number < self.max_series_number - || id.is_empty() - || expr_set_item.1 <= 1 - { - return Ok(Transformed::new( - expr, - false, - TreeNodeRecursion::Jump, - )); - } + // series_number was the inverse number ordering (when doing f_up) self.max_series_number = *series_number; // step index to skip all sub-node (which has smaller series number). while self.curr_index < self.id_array.len() @@ -771,7 +799,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. 
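                        // Aliasing the common column back to the original
                        // expression's display name keeps the output schema
                        // unchanged; `TreeNodeRecursion::Jump` then skips the
                        // children, which are now computed by the extracted
                        // common projection.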
Ok(Transformed::new( - col(id).alias(expr_name), + col(curr_id).alias(expr_name), true, TreeNodeRecursion::Jump, )) @@ -787,7 +815,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { fn replace_common_expr( expr: Expr, - id_array: &[(usize, Identifier)], + id_array: &IdArray, expr_set: &ExprSet, affected_id: &mut BTreeSet, ) -> Result { diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index d6343f9a3fe8..69f3e439eac9 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2205,3 +2205,39 @@ false true false true NULL NULL NULL NULL false false true true false false true false + + +############# +## Common Subexpr Eliminate Tests +############# + +statement ok +CREATE TABLE doubles ( + f64 DOUBLE +) as VALUES + (10.1) +; + +# common subexpr with alias +query RRR rowsort +select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) from doubles; +---- +10.1 0 1.570796326795 + +# common subexpr with coalesce (short-circuited) +query RRR rowsort +select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles; +---- +10.1 0.09900990099 1.471623942989 + +# common subexpr with coalesce (short-circuited) and alias +query RRR rowsort +select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles; +---- +10.1 0.09900990099 1.471623942989 + +# common subexpr with case (short-circuited) +query RRR rowsort +select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles; +---- +10.1 0.09900990099 1.471623942989 From 2b69acca7b0662b86f70bc9eeb58c12cdcdf971b Mon Sep 17 00:00:00 2001 From: Huaijin Date: Fri, 22 Mar 2024 10:51:33 +0800 Subject: [PATCH 039/117] perf: improve to_field performance (#9722) * perf: improve to_field performance * finish * remove duplicate code * retrigger ci --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/dfschema.rs | 12 ++++ datafusion/expr/src/expr_schema.rs | 106 +++++++++++++++++++++++------ 2 files changed, 98 insertions(+), 20 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 597507a044a2..90fb0b035d35 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -740,6 +740,9 @@ pub trait ExprSchema: std::fmt::Debug { /// Returns the column's optional metadata. 
    fn metadata(&self, col: &Column) -> Result<&HashMap<String, String>>;
+
+    /// Return the column's datatype and nullability
+    fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>;
 }

 // Implement `ExprSchema` for `Arc<DFSchema>`
@@ -755,6 +758,10 @@ impl<P: AsRef<DFSchema> + std::fmt::Debug> ExprSchema for P {
     fn metadata(&self, col: &Column) -> Result<&HashMap<String, String>> {
         ExprSchema::metadata(self.as_ref(), col)
     }
+
+    fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
+        self.as_ref().data_type_and_nullable(col)
+    }
 }

 impl ExprSchema for DFSchema {
@@ -769,6 +776,11 @@ impl ExprSchema for DFSchema {
     fn metadata(&self, col: &Column) -> Result<&HashMap<String, String>> {
         Ok(self.field_from_column(col)?.metadata())
     }
+
+    fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
+        let field = self.field_from_column(col)?;
+        Ok((field.data_type(), field.is_nullable()))
+    }
 }

 /// DFField wraps an Arrow field and adds an optional qualifier
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 1d83fbe8c0e0..f1ac22d584ee 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -50,6 +50,10 @@ pub trait ExprSchemable {
     /// cast to a type with respect to a schema
     fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
+
+    /// Given a schema, return the type and nullability of the expr
+    fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
+        -> Result<(DataType, bool)>;
 }

 impl ExprSchemable for Expr {
@@ -370,32 +374,90 @@ impl ExprSchemable for Expr {
         }
     }

+    /// Returns the datatype and nullability of the expression based on [ExprSchema].
+    ///
+    /// Note: [`DFSchema`] implements [ExprSchema].
+    ///
+    /// [`DFSchema`]: datafusion_common::DFSchema
+    ///
+    /// # Errors
+    ///
+    /// This function errors when it is not possible to compute its
+    /// datatype or nullability.
+    fn data_type_and_nullable(
+        &self,
+        schema: &dyn ExprSchema,
+    ) -> Result<(DataType, bool)> {
+        match self {
+            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
+                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
+                    None => schema
+                        .data_type_and_nullable(&Column::from_name(name))
+                        .map(|(d, n)| (d.clone(), n)),
+                    Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)),
+                },
+                _ => expr.data_type_and_nullable(schema),
+            },
+            Expr::Sort(Sort { expr, .. }) | Expr::Negative(expr) => {
+                expr.data_type_and_nullable(schema)
+            }
+            Expr::Column(c) => schema
+                .data_type_and_nullable(c)
+                .map(|(d, n)| (d.clone(), n)),
+            Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)),
+            Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)),
+            Expr::Literal(l) => Ok((l.data_type(), l.is_null())),
+            Expr::IsNull(_)
+            | Expr::IsNotNull(_)
+            | Expr::IsTrue(_)
+            | Expr::IsFalse(_)
+            | Expr::IsUnknown(_)
+            | Expr::IsNotTrue(_)
+            | Expr::IsNotFalse(_)
+            | Expr::IsNotUnknown(_)
+            | Expr::Exists { .. } => Ok((DataType::Boolean, false)),
+            Expr::ScalarSubquery(subquery) => Ok((
+                subquery.subquery.schema().field(0).data_type().clone(),
+                subquery.subquery.schema().field(0).is_nullable(),
+            )),
+            Expr::BinaryExpr(BinaryExpr {
+                ref left,
+                ref right,
+                ref op,
+            }) => {
+                let left = left.data_type_and_nullable(schema)?;
+                let right = right.data_type_and_nullable(schema)?;
+                Ok((get_result_type(&left.0, op, &right.0)?, left.1 || right.1))
+            }
+            _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
+        }
+    }
+
     /// Returns a [arrow::datatypes::Field] compatible with this expression.
/// /// So for example, a projected expression `col(c1) + col(c2)` is /// placed in an output field **named** col("c1 + c2") fn to_field(&self, input_schema: &dyn ExprSchema) -> Result { match self { - Expr::Column(c) => Ok(DFField::new( - c.relation.clone(), - &c.name, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - ) - .with_metadata(self.metadata(input_schema)?)), - Expr::Alias(Alias { relation, name, .. }) => Ok(DFField::new( - relation.clone(), - name, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - ) - .with_metadata(self.metadata(input_schema)?)), - _ => Ok(DFField::new_unqualified( - &self.display_name()?, - self.get_type(input_schema)?, - self.nullable(input_schema)?, - ) - .with_metadata(self.metadata(input_schema)?)), + Expr::Column(c) => { + let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; + Ok( + DFField::new(c.relation.clone(), &c.name, data_type, nullable) + .with_metadata(self.metadata(input_schema)?), + ) + } + Expr::Alias(Alias { relation, name, .. }) => { + let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; + Ok(DFField::new(relation.clone(), name, data_type, nullable) + .with_metadata(self.metadata(input_schema)?)) + } + _ => { + let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; + Ok( + DFField::new_unqualified(&self.display_name()?, data_type, nullable) + .with_metadata(self.metadata(input_schema)?), + ) + } } } @@ -704,5 +766,9 @@ mod tests { fn metadata(&self, _col: &Column) -> Result<&HashMap> { Ok(&self.metadata) } + + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + Ok((self.data_type(col)?, self.nullable(col)?)) + } } } From 47f4b5a67ac3b327764cbd4c0f42da7ac44854e5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 22 Mar 2024 08:47:48 -0400 Subject: [PATCH 040/117] Minor: Run ScalarValue size test on aarch again (#9728) * Minor: Run ScalarValue size test on aarch again * add comments --- datafusion/common/src/scalar/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 2a99b667d8f1..88d40a35585d 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -4528,18 +4528,18 @@ mod tests { assert_eq!(expected, data_type.try_into().unwrap()) } - // this test fails on aarch, so don't run it there - #[cfg(not(target_arch = "aarch64"))] #[test] fn size_of_scalar() { // Since ScalarValues are used in a non trivial number of places, // making it larger means significant more memory consumption // per distinct value. // + // Thus this test ensures that no code change makes ScalarValue larger + // // The alignment requirements differ across architectures and // thus the size of the enum appears to as well - // The value can be changed depending on rust version + // The value may also change depending on rust version assert_eq!(std::mem::size_of::(), 64); } From d321ba3cc31ba20823a7b6452899da070f528522 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 22 Mar 2024 14:00:11 -0400 Subject: [PATCH 041/117] Move trim functions (btrim, ltrim, rtrim) to datafusion_functions, make expr_fn API consistent (#9730) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. 
* Move trim functions to datafusion-functions * Doc updates for ltrim, rtrim and trim to reflect how they actually function. * Fixed struct name Trim -> BTrim --- .../tests/dataframe/dataframe_functions.rs | 4 +- datafusion/expr/src/built_in_function.rs | 27 -- datafusion/expr/src/expr_fn.rs | 21 -- datafusion/functions/Cargo.toml | 7 +- .../src/string/{trim.rs => btrim.rs} | 37 ++- datafusion/functions/src/string/common.rs | 265 +++++++++++++++ datafusion/functions/src/string/ltrim.rs | 77 +++++ datafusion/functions/src/string/mod.rs | 301 +++--------------- datafusion/functions/src/string/rtrim.rs | 77 +++++ .../functions/src/string/starts_with.rs | 3 +- datafusion/functions/src/string/to_hex.rs | 3 +- datafusion/functions/src/string/upper.rs | 5 +- datafusion/physical-expr/src/functions.rs | 187 ----------- .../physical-expr/src/string_expressions.rs | 94 +----- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 28 +- datafusion/proto/src/logical_plan/to_proto.rs | 3 - datafusion/sql/src/expr/mod.rs | 53 +-- .../source/user-guide/sql/scalar_functions.md | 54 ++-- 21 files changed, 559 insertions(+), 714 deletions(-) rename datafusion/functions/src/string/{trim.rs => btrim.rs} (73%) create mode 100644 datafusion/functions/src/string/common.rs create mode 100644 datafusion/functions/src/string/ltrim.rs create mode 100644 datafusion/functions/src/string/rtrim.rs diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index cea701492910..6ebd64c9b628 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -367,7 +367,7 @@ async fn test_fn_lpad_with_string() -> Result<()> { #[tokio::test] async fn test_fn_ltrim() -> Result<()> { - let expr = ltrim(lit(" a b c ")); + let expr = ltrim(vec![lit(" a b c ")]); let expected = [ "+-----------------------------------------+", @@ -384,7 +384,7 @@ async fn test_fn_ltrim() -> Result<()> { #[tokio::test] async fn test_fn_ltrim_with_columns() -> Result<()> { - let expr = ltrim(col("a")); + let expr = ltrim(vec![col("a")]); let expected = [ "+---------------+", diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fffe2cf4c9c9..785965f6f693 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -107,8 +107,6 @@ pub enum BuiltinScalarFunction { Ascii, /// bit_length BitLength, - /// btrim - Btrim, /// character_length CharacterLength, /// chr @@ -127,8 +125,6 @@ pub enum BuiltinScalarFunction { Lpad, /// lower Lower, - /// ltrim - Ltrim, /// octet_length OctetLength, /// random @@ -143,8 +139,6 @@ pub enum BuiltinScalarFunction { Right, /// rpad Rpad, - /// rtrim - Rtrim, /// split_part SplitPart, /// strpos @@ -248,7 +242,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Trunc => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => Volatility::Immutable, - BuiltinScalarFunction::Btrim => Volatility::Immutable, BuiltinScalarFunction::CharacterLength => Volatility::Immutable, BuiltinScalarFunction::Chr => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, @@ -258,7 +251,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Left => Volatility::Immutable, BuiltinScalarFunction::Lpad => 
Volatility::Immutable, BuiltinScalarFunction::Lower => Volatility::Immutable, - BuiltinScalarFunction::Ltrim => Volatility::Immutable, BuiltinScalarFunction::OctetLength => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, BuiltinScalarFunction::Repeat => Volatility::Immutable, @@ -266,7 +258,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Reverse => Volatility::Immutable, BuiltinScalarFunction::Right => Volatility::Immutable, BuiltinScalarFunction::Rpad => Volatility::Immutable, - BuiltinScalarFunction::Rtrim => Volatility::Immutable, BuiltinScalarFunction::SplitPart => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, @@ -303,9 +294,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::BitLength => { utf8_to_int_type(&input_expr_types[0], "bit_length") } - BuiltinScalarFunction::Btrim => { - utf8_to_str_type(&input_expr_types[0], "btrim") - } BuiltinScalarFunction::CharacterLength => { utf8_to_int_type(&input_expr_types[0], "character_length") } @@ -325,9 +313,6 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "lower") } BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), - BuiltinScalarFunction::Ltrim => { - utf8_to_str_type(&input_expr_types[0], "ltrim") - } BuiltinScalarFunction::OctetLength => { utf8_to_int_type(&input_expr_types[0], "octet_length") } @@ -347,9 +332,6 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "right") } BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), - BuiltinScalarFunction::Rtrim => { - utf8_to_str_type(&input_expr_types[0], "rtrim") - } BuiltinScalarFunction::SplitPart => { utf8_to_str_type(&input_expr_types[0], "split_part") } @@ -456,12 +438,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::Btrim - | BuiltinScalarFunction::Ltrim - | BuiltinScalarFunction::Rtrim => Signature::one_of( - vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], - self.volatility(), - ), BuiltinScalarFunction::Chr => { Signature::uniform(1, vec![Int64], self.volatility()) } @@ -703,7 +679,6 @@ impl BuiltinScalarFunction { // string functions BuiltinScalarFunction::Ascii => &["ascii"], BuiltinScalarFunction::BitLength => &["bit_length"], - BuiltinScalarFunction::Btrim => &["btrim"], BuiltinScalarFunction::CharacterLength => { &["character_length", "char_length", "length"] } @@ -715,14 +690,12 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Left => &["left"], BuiltinScalarFunction::Lower => &["lower"], BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Ltrim => &["ltrim"], BuiltinScalarFunction::OctetLength => &["octet_length"], BuiltinScalarFunction::Repeat => &["repeat"], BuiltinScalarFunction::Replace => &["replace"], BuiltinScalarFunction::Reverse => &["reverse"], BuiltinScalarFunction::Right => &["right"], BuiltinScalarFunction::Rpad => &["rpad"], - BuiltinScalarFunction::Rtrim => &["rtrim"], BuiltinScalarFunction::SplitPart => &["split_part"], BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8667f631c507..a834ccab9d15 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -601,12 +601,6 @@ scalar_expr!( scalar_expr!(InitCap, initcap, string, 
"converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Lower, lower, string, "convert the string to lower case"); -scalar_expr!( - Ltrim, - ltrim, - string, - "removes all characters, spaces by default, from the beginning of a string" -); scalar_expr!( OctetLength, octet_length, @@ -617,12 +611,6 @@ scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `fro scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times"); scalar_expr!(Reverse, reverse, string, "reverses the `string`"); scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`"); -scalar_expr!( - Rtrim, - rtrim, - string, - "removes all characters, spaces by default, from the end of a string" -); scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); @@ -640,11 +628,6 @@ nary_scalar_expr!( rpad, "fill up a string to the length by appending the characters" ); -nary_scalar_expr!( - Btrim, - btrim, - "removes all characters, spaces by default, from both sides of a string" -); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c nary_scalar_expr!( @@ -1082,8 +1065,6 @@ mod test { test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); - test_nary_scalar_expr!(Btrim, btrim, string); - test_nary_scalar_expr!(Btrim, btrim, string, characters); test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Chr, chr, string); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); @@ -1093,7 +1074,6 @@ mod test { test_scalar_expr!(Lower, lower, string); test_nary_scalar_expr!(Lpad, lpad, string, count); test_nary_scalar_expr!(Lpad, lpad, string, count, characters); - test_scalar_expr!(Ltrim, ltrim, string); test_scalar_expr!(OctetLength, octet_length, string); test_scalar_expr!(Replace, replace, string, from, to); test_scalar_expr!(Repeat, repeat, string, count); @@ -1101,7 +1081,6 @@ mod test { test_scalar_expr!(Right, right, string, count); test_nary_scalar_expr!(Rpad, rpad, string, count); test_nary_scalar_expr!(Rpad, rpad, string, count, characters); - test_scalar_expr!(Rtrim, rtrim, string); test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index b12c99e84a90..0410d89d123f 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -29,10 +29,9 @@ authors = { workspace = true } rust-version = { workspace = true } [features] -# enable string functions -string_expressions = [] # enable core functions core_expressions = [] +crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] # enable datetime functions datetime_expressions = [] # Enable encoding by default so the doctests work. In general don't automatically enable all packages. 
@@ -51,7 +50,9 @@ encoding_expressions = ["base64", "hex"] math_expressions = [] # enable regular expressions regex_expressions = ["regex"] -crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] +# enable string functions +string_expressions = [] + [lib] name = "datafusion_functions" path = "src/lib.rs" diff --git a/datafusion/functions/src/string/trim.rs b/datafusion/functions/src/string/btrim.rs similarity index 73% rename from datafusion/functions/src/string/trim.rs rename to datafusion/functions/src/string/btrim.rs index e04a171722e3..de1c9cc69b72 100644 --- a/datafusion/functions/src/string/trim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -16,30 +16,30 @@ // under the License. use arrow::array::{ArrayRef, OffsetSizeTrait}; -use arrow::datatypes::DataType; -use datafusion_common::exec_err; -use datafusion_common::Result; -use datafusion_expr::ColumnarValue; -use datafusion_expr::TypeSignature::*; -use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use crate::string::{make_scalar_function, utf8_to_str_type}; +use arrow::datatypes::DataType; -use super::{general_trim, TrimType}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; -/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. +use crate::string::common::*; + +/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { +fn btrim(args: &[ArrayRef]) -> Result { general_trim::(args, TrimType::Both) } #[derive(Debug)] -pub(super) struct TrimFunc { +pub(super) struct BTrimFunc { signature: Signature, + aliases: Vec, } -impl TrimFunc { +impl BTrimFunc { pub fn new() -> Self { use DataType::*; Self { @@ -47,17 +47,18 @@ impl TrimFunc { vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], Volatility::Immutable, ), + aliases: vec![String::from("trim")], } } } -impl ScalarUDFImpl for TrimFunc { +impl ScalarUDFImpl for BTrimFunc { fn as_any(&self) -> &dyn Any { self } fn name(&self) -> &str { - "trim" + "btrim" } fn signature(&self) -> &Signature { @@ -65,14 +66,18 @@ impl ScalarUDFImpl for TrimFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "trim") + utf8_to_str_type(&arg_types[0], "btrim") } fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { DataType::Utf8 => make_scalar_function(btrim::, vec![])(args), DataType::LargeUtf8 => make_scalar_function(btrim::, vec![])(args), - other => exec_err!("Unsupported data type {other:?} for function trim"), + other => exec_err!("Unsupported data type {other:?} for function btrim"), } } + + fn aliases(&self) -> &[String] { + &self.aliases + } } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs new file mode 100644 index 000000000000..97465420fb99 --- /dev/null +++ b/datafusion/functions/src/string/common.rs @@ -0,0 +1,265 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::{Display, Formatter}; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::Result; +use datafusion_common::{exec_err, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_physical_expr::functions::Hint; + +pub(crate) enum TrimType { + Left, + Right, + Both, +} + +impl Display for TrimType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + TrimType::Left => write!(f, "ltrim"), + TrimType::Right => write!(f, "rtrim"), + TrimType::Both => write!(f, "btrim"), + } + } +} + +pub(crate) fn general_trim( + args: &[ArrayRef], + trim_type: TrimType, +) -> Result { + let func = match trim_type { + TrimType::Left => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_start_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Right => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>(input, pattern.as_ref()) + }, + TrimType::Both => |input, pattern: &str| { + let pattern = pattern.chars().collect::>(); + str::trim_end_matches::<&[char]>( + str::trim_start_matches::<&[char]>(input, pattern.as_ref()), + pattern.as_ref(), + ) + }, + }; + + let string_array = as_generic_string_array::(&args[0])?; + + match args.len() { + 1 => { + let result = string_array + .iter() + .map(|string| string.map(|string: &str| func(string, " "))) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + 2 => { + let characters_array = as_generic_string_array::(&args[1])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .map(|(string, characters)| match (string, characters) { + (Some(string), Some(characters)) => Some(func(string, characters)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." + ) + } + } +} + +/// Creates a function to identify the optimal return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! 
get_optimal_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + DataType::Dictionary(_, value_type) => match **value_type { + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + _ => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + **value_type + ); + } + }, + data_type => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + data_type + ); + } + }) + } + }; +} + +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. +get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +/// applies a unary expression to `args[0]` that is expected to be downcastable to +/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) +/// # Errors +/// This function errors when: +/// * the number of arguments is not 1 +/// * the first argument is not castable to a `GenericStringArray` +pub(crate) fn unary_string_function<'a, T, O, F, R>( + args: &[&'a dyn Array], + op: F, + name: &str, +) -> Result> +where + R: AsRef, + O: OffsetSizeTrait, + T: OffsetSizeTrait, + F: Fn(&'a str) -> R, +{ + if args.len() != 1 { + return exec_err!( + "{:?} args were supplied but {} takes exactly one argument", + args.len(), + name + ); + } + + let string_array = as_generic_string_array::(args[0])?; + + // first map is the iterator, second is for the `Option<_>` + Ok(string_array.iter().map(|string| string.map(&op)).collect()) +} + +pub(crate) fn handle<'a, F, R>( + args: &'a [ColumnarValue], + op: F, + name: &str, +) -> Result +where + R: AsRef, + F: Fn(&'a str) -> R, +{ + match &args[0] { + ColumnarValue::Array(a) => match a.data_type() { + DataType::Utf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i32, + i32, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + DataType::LargeUtf8 => { + Ok(ColumnarValue::Array(Arc::new(unary_string_function::< + i64, + i64, + _, + _, + >( + &[a.as_ref()], op, name + )?))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + } + ScalarValue::LargeUtf8(a) => { + let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); + Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) + } + other => exec_err!("Unsupported data type {other:?} for function {name}"), + }, + } +} + +pub(super) fn make_scalar_function( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. 
+        let len = args
+            .iter()
+            .fold(Option::<usize>::None, |acc, arg| match arg {
+                ColumnarValue::Scalar(_) => acc,
+                ColumnarValue::Array(a) => Some(a.len()),
+            });
+
+        let is_scalar = len.is_none();
+
+        let inferred_length = len.unwrap_or(1);
+        let args = args
+            .iter()
+            .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
+            .map(|(arg, hint)| {
+                // Decide on the length to expand this scalar to depending
+                // on the given hints.
+                let expansion_len = match hint {
+                    Hint::AcceptsSingular => 1,
+                    Hint::Pad => inferred_length,
+                };
+                arg.clone().into_array(expansion_len)
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let result = (inner)(&args);
+        if is_scalar {
+            // If all inputs are scalar, keep the output as a scalar
+            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
+            result.map(ColumnarValue::Scalar)
+        } else {
+            result.map(ColumnarValue::Array)
+        }
+    })
+}
diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs
new file mode 100644
index 000000000000..535ffb14f5f5
--- /dev/null
+++ b/datafusion/functions/src/string/ltrim.rs
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, OffsetSizeTrait};
+use std::any::Any;
+
+use arrow::datatypes::DataType;
+
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::TypeSignature::*;
+use datafusion_expr::{ColumnarValue, Volatility};
+use datafusion_expr::{ScalarUDFImpl, Signature};
+
+use crate::string::common::*;
+
+/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed.
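+/// The characters to trim are treated as a set: each character of the trim
+/// string is removed independently, not as a whole substring.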
+/// ltrim('zzzytest', 'xyz') = 'test' +fn ltrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Left) +} + +#[derive(Debug)] +pub(super) struct LtrimFunc { + signature: Signature, +} + +impl LtrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LtrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ltrim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "ltrim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(ltrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(ltrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function ltrim"), + } + } +} diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 08fcbb363bbc..13c02d5dfac3 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -15,278 +15,63 @@ // specific language governing permissions and limitations // under the License. -use arrow::{ - array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}, - datatypes::DataType, -}; -use datafusion_common::{ - cast::as_generic_string_array, exec_err, plan_err, Result, ScalarValue, -}; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; -use datafusion_physical_expr::functions::Hint; -use std::{ - fmt::{Display, Formatter}, - sync::Arc, -}; +//! "string" DataFusion functions -/// Creates a function to identify the optimal return type of a string function given -/// the type of its first argument. -/// -/// If the input type is `LargeUtf8` or `LargeBinary` the return type is -/// `$largeUtf8Type`, -/// -/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, -macro_rules! get_optimal_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - fn $FUNC(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - // LargeBinary inputs are automatically coerced to Utf8 - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - // Binary inputs are automatically coerced to Utf8 - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - DataType::Dictionary(_, value_type) => match **value_type { - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - _ => { - return plan_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - **value_type - ); - } - }, - data_type => { - return plan_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - data_type - ); - } - }) - } - }; -} +use std::sync::Arc; -// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. 
-get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); +use datafusion_expr::ScalarUDF; -/// applies a unary expression to `args[0]` that is expected to be downcastable to -/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) -/// # Errors -/// This function errors when: -/// * the number of arguments is not 1 -/// * the first argument is not castable to a `GenericStringArray` -pub(crate) fn unary_string_function<'a, T, O, F, R>( - args: &[&'a dyn Array], - op: F, - name: &str, -) -> Result> -where - R: AsRef, - O: OffsetSizeTrait, - T: OffsetSizeTrait, - F: Fn(&'a str) -> R, -{ - if args.len() != 1 { - return exec_err!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name - ); - } +mod btrim; +mod common; +mod ltrim; +mod rtrim; +mod starts_with; +mod to_hex; +mod upper; - let string_array = as_generic_string_array::(args[0])?; +// create UDFs +make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); +make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); +make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); +make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); +make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); +make_udf_function!(upper::UpperFunc, UPPER, upper); - // first map is the iterator, second is for the `Option<_>` - Ok(string_array.iter().map(|string| string.map(&op)).collect()) -} +pub mod expr_fn { + use datafusion_expr::Expr; -fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result -where - R: AsRef, - F: Fn(&'a str) -> R, -{ - match &args[0] { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i32, - i32, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i64, - i64, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, + #[doc = "Removes all characters, spaces by default, from both sides of a string"] + pub fn btrim(args: Vec) -> Expr { + super::btrim().call(args) } -} - -// TODO: mode allow[(dead_code)] after move ltrim and rtrim -enum TrimType { - #[allow(dead_code)] - Left, - #[allow(dead_code)] - Right, - Both, -} -impl Display for TrimType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - TrimType::Left => write!(f, "ltrim"), - TrimType::Right => write!(f, "rtrim"), - TrimType::Both => write!(f, "btrim"), - } + #[doc = "Removes all characters, spaces by default, from the beginning of a string"] + pub fn ltrim(args: Vec) -> Expr { + super::ltrim().call(args) } -} - -fn general_trim( - args: &[ArrayRef], - trim_type: TrimType, -) -> Result { - let func = match trim_type { - TrimType::Left => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_start_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Right => |input, pattern: &str| { - let pattern = 
pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Both => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>( - str::trim_start_matches::<&[char]>(input, pattern.as_ref()), - pattern.as_ref(), - ) - }, - }; - - let string_array = as_generic_string_array::(&args[0])?; - match args.len() { - 1 => { - let result = string_array - .iter() - .map(|string| string.map(|string: &str| func(string, " "))) - .collect::>(); + #[doc = "Removes all characters, spaces by default, from the end of a string"] + pub fn rtrim(args: Vec) -> Expr { + super::rtrim().call(args) + } - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let characters_array = as_generic_string_array::(&args[1])?; + #[doc = "Returns true if string starts with prefix."] + pub fn starts_with(arg1: Expr, arg2: Expr) -> Expr { + super::starts_with().call(vec![arg1, arg2]) + } - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters)), - _ => None, - }) - .collect::>(); + #[doc = "Converts an integer to a hexadecimal string."] + pub fn to_hex(arg1: Expr) -> Expr { + super::to_hex().call(vec![arg1]) + } - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!( - "{trim_type} was called with {other} arguments. It requires at least 1 and at most 2." - ) - } + #[doc = "Converts a string to uppercase."] + pub fn upper(arg1: Expr) -> Expr { + super::upper().call(vec![arg1]) } } -pub(super) fn make_scalar_function( - inner: F, - hints: Vec, -) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, -{ - Arc::new(move |args: &[ColumnarValue]| { - // first, identify if any of the arguments is an Array. If yes, store its `len`, - // as any scalar will need to be converted to an array of len `len`. - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let is_scalar = len.is_none(); - - let inferred_length = len.unwrap_or(1); - let args = args - .iter() - .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) - .map(|(arg, hint)| { - // Decide on the length to expand this scalar to depending - // on the given hints. 
- let expansion_len = match hint { - Hint::AcceptsSingular => 1, - Hint::Pad => inferred_length, - }; - arg.clone().into_array(expansion_len) - }) - .collect::>>()?; - - let result = (inner)(&args); - if is_scalar { - // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) - } else { - result.map(ColumnarValue::Array) - } - }) +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![btrim(), ltrim(), rtrim(), starts_with(), to_hex(), upper()] } - -mod starts_with; -mod to_hex; -mod trim; -mod upper; -// create UDFs -make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); -make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); -make_udf_function!(trim::TrimFunc, TRIM, trim); -make_udf_function!(upper::UpperFunc, UPPER, upper); - -export_functions!( - ( - starts_with, - arg1 arg2, - "Returns true if string starts with prefix."), - ( - to_hex, - arg1, - "Converts an integer to a hexadecimal string."), - (trim, - arg1, - "removes all characters, space by default from the string"), - (upper, - arg1, - "Converts a string to uppercase.")); diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs new file mode 100644 index 000000000000..17d2f8234b34 --- /dev/null +++ b/datafusion/functions/src/string/rtrim.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. 
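+/// As with `ltrim`, the characters of the trim string form a set and are
+/// removed independently, not as a whole substring.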
+/// rtrim('testxxzx', 'xyz') = 'test' +fn rtrim(args: &[ArrayRef]) -> Result { + general_trim::(args, TrimType::Right) +} + +#[derive(Debug)] +pub(super) struct RtrimFunc { + signature: Signature, +} + +impl RtrimFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8]), Exact(vec![Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RtrimFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "rtrim" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "rtrim") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(rtrim::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(rtrim::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function rtrim"), + } + } +} diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 1fce399d1e70..4450b9d332a0 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::make_scalar_function; use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; @@ -24,8 +25,6 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -use crate::string::make_scalar_function; - /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' pub fn starts_with(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 4dfc84887da2..1bdece3f7af8 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::string::common::make_scalar_function; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, @@ -27,8 +28,6 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -use super::make_scalar_function; - /// Converts the number to its equivalent hexadecimal representation. /// to_hex(2147483647) = '7fffffff' pub fn to_hex(args: &[ArrayRef]) -> Result diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index ed41487699aa..a0c910ebb2c8 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -15,16 +15,13 @@ // specific language governing permissions and limitations // under the License. 
+use crate::string::common::{handle, utf8_to_str_type}; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; -use crate::string::utf8_to_str_type; - -use super::handle; - #[derive(Debug)] pub(super) struct UpperFunc { signature: Signature, diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index f2c93c3ec1dd..a6efe0e0861d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -273,15 +273,6 @@ pub fn create_physical_fun( _ => unreachable!(), }, }), - BuiltinScalarFunction::Btrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::btrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function btrim"), - }), BuiltinScalarFunction::CharacterLength => { Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -347,15 +338,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function lpad"), }), - BuiltinScalarFunction::Ltrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::ltrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::ltrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function ltrim"), - }), BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { @@ -427,15 +409,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function rpad"), }), - BuiltinScalarFunction::Rtrim => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::rtrim::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::rtrim::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function rtrim"), - }), BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::split_part::)(args) @@ -752,70 +725,6 @@ mod tests { Int32Array ); test_function!(BitLength, &[lit("")], Ok(Some(0)), i32, Int32, Int32Array); - test_function!( - Btrim, - &[lit(" trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit(" trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("\n trim \n")], - Ok(Some("\n trim \n")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("xyxtrimyyx"), lit("xyz"),], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("\nxyxtrimyyx\n"), lit("xyz\n"),], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit(ScalarValue::Utf8(None)), lit("xyz"),], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Btrim, - &[lit("xyxtrimyyx"), lit(ScalarValue::Utf8(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); #[cfg(feature = "unicode_expressions")] test_function!( CharacterLength, @@ -1287,54 +1196,6 
@@ mod tests { Utf8, StringArray ); - test_function!( - Ltrim, - &[lit(" trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit(" trim ")], - Ok(Some("trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("trim ")], - Ok(Some("trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit("\n trim ")], - Ok(Some("\n trim ")), - &str, - Utf8, - StringArray - ); - test_function!( - Ltrim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); test_function!( OctetLength, &[lit("chars")], @@ -1683,54 +1544,6 @@ mod tests { Utf8, StringArray ); - test_function!( - Rtrim, - &[lit("trim ")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim ")], - Ok(Some(" trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim \n")], - Ok(Some(" trim \n")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(" trim")], - Ok(Some(" trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit("trim")], - Ok(Some("trim")), - &str, - Utf8, - StringArray - ); - test_function!( - Rtrim, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); test_function!( SplitPart, &[ diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 86c0092a220d..f5229d92545e 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -21,11 +21,8 @@ //! String expressions +use std::iter; use std::sync::Arc; -use std::{ - fmt::{Display, Formatter}, - iter, -}; use arrow::{ array::{ @@ -346,95 +343,6 @@ pub fn lower(args: &[ColumnarValue]) -> Result { handle(args, |string| string.to_lowercase(), "lower") } -enum TrimType { - Left, - Right, - Both, -} - -impl Display for TrimType { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - TrimType::Left => write!(f, "ltrim"), - TrimType::Right => write!(f, "rtrim"), - TrimType::Both => write!(f, "btrim"), - } - } -} - -fn general_trim( - args: &[ArrayRef], - trim_type: TrimType, -) -> Result { - let func = match trim_type { - TrimType::Left => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_start_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Right => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>(input, pattern.as_ref()) - }, - TrimType::Both => |input, pattern: &str| { - let pattern = pattern.chars().collect::>(); - str::trim_end_matches::<&[char]>( - str::trim_start_matches::<&[char]>(input, pattern.as_ref()), - pattern.as_ref(), - ) - }, - }; - - let string_array = as_generic_string_array::(&args[0])?; - - match args.len() { - 1 => { - let result = string_array - .iter() - .map(|string| string.map(|string: &str| func(string, " "))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - 2 => { - let characters_array = as_generic_string_array::(&args[1])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!( - "{trim_type} was called with 
{other} arguments. It requires at least 1 and at most 2." - ) - } - } -} - -/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed. -/// btrim('xyxtrimyyx', 'xyz') = 'trim' -pub fn btrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Both) -} - -/// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. -/// ltrim('zzzytest', 'xyz') = 'test' -pub fn ltrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Left) -} - -/// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. -/// rtrim('testxxzx', 'xyz') = 'test' -pub fn rtrim(args: &[ArrayRef]) -> Result { - general_trim::(args, TrimType::Right) -} - /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index c009682d5a4d..416b49db7aa7 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -564,7 +564,7 @@ enum ScalarFunction { // 20 was Array // RegexpMatch = 21; BitLength = 22; - Btrim = 23; + // 23 was Btrim CharacterLength = 24; Chr = 25; Concat = 26; @@ -575,7 +575,7 @@ enum ScalarFunction { Left = 31; Lpad = 32; Lower = 33; - Ltrim = 34; + // 34 was Ltrim // 35 was MD5 // 36 was NullIf OctetLength = 37; @@ -586,7 +586,7 @@ enum ScalarFunction { Reverse = 42; Right = 43; Rpad = 44; - Rtrim = 45; + // 45 was Rtrim // 46 was SHA224 // 47 was SHA256 // 48 was SHA384 diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 58683dba6dff..49102137b659 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22930,7 +22930,6 @@ impl serde::Serialize for ScalarFunction { Self::Sqrt => "Sqrt", Self::Trunc => "Trunc", Self::BitLength => "BitLength", - Self::Btrim => "Btrim", Self::CharacterLength => "CharacterLength", Self::Chr => "Chr", Self::Concat => "Concat", @@ -22939,7 +22938,6 @@ impl serde::Serialize for ScalarFunction { Self::Left => "Left", Self::Lpad => "Lpad", Self::Lower => "Lower", - Self::Ltrim => "Ltrim", Self::OctetLength => "OctetLength", Self::Random => "Random", Self::Repeat => "Repeat", @@ -22947,7 +22945,6 @@ impl serde::Serialize for ScalarFunction { Self::Reverse => "Reverse", Self::Right => "Right", Self::Rpad => "Rpad", - Self::Rtrim => "Rtrim", Self::SplitPart => "SplitPart", Self::Strpos => "Strpos", Self::Substr => "Substr", @@ -23004,7 +23001,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sqrt", "Trunc", "BitLength", - "Btrim", "CharacterLength", "Chr", "Concat", @@ -23013,7 +23009,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Left", "Lpad", "Lower", - "Ltrim", "OctetLength", "Random", "Repeat", @@ -23021,7 +23016,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Reverse", "Right", "Rpad", - "Rtrim", "SplitPart", "Strpos", "Substr", @@ -23107,7 +23101,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sqrt" => Ok(ScalarFunction::Sqrt), "Trunc" => Ok(ScalarFunction::Trunc), "BitLength" => Ok(ScalarFunction::BitLength), - "Btrim" => Ok(ScalarFunction::Btrim), "CharacterLength" => Ok(ScalarFunction::CharacterLength), "Chr" => Ok(ScalarFunction::Chr), "Concat" => Ok(ScalarFunction::Concat), @@ -23116,7 +23109,6 @@ impl<'de> 
serde::Deserialize<'de> for ScalarFunction { "Left" => Ok(ScalarFunction::Left), "Lpad" => Ok(ScalarFunction::Lpad), "Lower" => Ok(ScalarFunction::Lower), - "Ltrim" => Ok(ScalarFunction::Ltrim), "OctetLength" => Ok(ScalarFunction::OctetLength), "Random" => Ok(ScalarFunction::Random), "Repeat" => Ok(ScalarFunction::Repeat), @@ -23124,7 +23116,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Reverse" => Ok(ScalarFunction::Reverse), "Right" => Ok(ScalarFunction::Right), "Rpad" => Ok(ScalarFunction::Rpad), - "Rtrim" => Ok(ScalarFunction::Rtrim), "SplitPart" => Ok(ScalarFunction::SplitPart), "Strpos" => Ok(ScalarFunction::Strpos), "Substr" => Ok(ScalarFunction::Substr), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 8eabb3b18603..5e458bfef016 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2863,7 +2863,7 @@ pub enum ScalarFunction { /// 20 was Array /// RegexpMatch = 21; BitLength = 22, - Btrim = 23, + /// 23 was Btrim CharacterLength = 24, Chr = 25, Concat = 26, @@ -2874,7 +2874,7 @@ pub enum ScalarFunction { Left = 31, Lpad = 32, Lower = 33, - Ltrim = 34, + /// 34 was Ltrim /// 35 was MD5 /// 36 was NullIf OctetLength = 37, @@ -2885,7 +2885,7 @@ pub enum ScalarFunction { Reverse = 42, Right = 43, Rpad = 44, - Rtrim = 45, + /// 45 was Rtrim /// 46 was SHA224 /// 47 was SHA256 /// 48 was SHA384 @@ -3003,7 +3003,6 @@ impl ScalarFunction { ScalarFunction::Sqrt => "Sqrt", ScalarFunction::Trunc => "Trunc", ScalarFunction::BitLength => "BitLength", - ScalarFunction::Btrim => "Btrim", ScalarFunction::CharacterLength => "CharacterLength", ScalarFunction::Chr => "Chr", ScalarFunction::Concat => "Concat", @@ -3012,7 +3011,6 @@ impl ScalarFunction { ScalarFunction::Left => "Left", ScalarFunction::Lpad => "Lpad", ScalarFunction::Lower => "Lower", - ScalarFunction::Ltrim => "Ltrim", ScalarFunction::OctetLength => "OctetLength", ScalarFunction::Random => "Random", ScalarFunction::Repeat => "Repeat", @@ -3020,7 +3018,6 @@ impl ScalarFunction { ScalarFunction::Reverse => "Reverse", ScalarFunction::Right => "Right", ScalarFunction::Rpad => "Rpad", - ScalarFunction::Rtrim => "Rtrim", ScalarFunction::SplitPart => "SplitPart", ScalarFunction::Strpos => "Strpos", ScalarFunction::Substr => "Substr", @@ -3071,7 +3068,6 @@ impl ScalarFunction { "Sqrt" => Some(Self::Sqrt), "Trunc" => Some(Self::Trunc), "BitLength" => Some(Self::BitLength), - "Btrim" => Some(Self::Btrim), "CharacterLength" => Some(Self::CharacterLength), "Chr" => Some(Self::Chr), "Concat" => Some(Self::Concat), @@ -3080,7 +3076,6 @@ impl ScalarFunction { "Left" => Some(Self::Left), "Lpad" => Some(Self::Lpad), "Lower" => Some(Self::Lower), - "Ltrim" => Some(Self::Ltrim), "OctetLength" => Some(Self::OctetLength), "Random" => Some(Self::Random), "Repeat" => Some(Self::Repeat), @@ -3088,7 +3083,6 @@ impl ScalarFunction { "Reverse" => Some(Self::Reverse), "Right" => Some(Self::Right), "Rpad" => Some(Self::Rpad), - "Rtrim" => Some(Self::Rtrim), "SplitPart" => Some(Self::SplitPart), "Strpos" => Some(Self::Strpos), "Substr" => Some(Self::Substr), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 64ceb37d2961..d41add915a96 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,17 +48,16 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use 
datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, ascii, asinh, atan, atan2, atanh, bit_length, btrim, cbrt, ceil, - character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, - degrees, ends_with, exp, + acosh, ascii, asinh, atan, atan2, atanh, bit_length, cbrt, ceil, character_length, + chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, ltrim, nanvl, octet_length, overlay, pi, power, radians, random, repeat, - replace, reverse, right, round, rpad, rtrim, signum, sin, sinh, split_part, sqrt, - strpos, substr, substr_index, substring, translate, trunc, uuid, AggregateFunction, - Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, + lower, lpad, nanvl, octet_length, overlay, pi, power, radians, random, repeat, + replace, reverse, right, round, rpad, signum, sin, sinh, split_part, sqrt, strpos, + substr, substr_index, substring, translate, trunc, uuid, AggregateFunction, Between, + BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, @@ -461,13 +460,10 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::OctetLength => Self::OctetLength, ScalarFunction::Concat => Self::Concat, ScalarFunction::Lower => Self::Lower, - ScalarFunction::Ltrim => Self::Ltrim, - ScalarFunction::Rtrim => Self::Rtrim, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, ScalarFunction::Ascii => Self::Ascii, ScalarFunction::BitLength => Self::BitLength, - ScalarFunction::Btrim => Self::Btrim, ScalarFunction::CharacterLength => Self::CharacterLength, ScalarFunction::Chr => Self::Chr, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, @@ -1439,12 +1435,6 @@ pub fn parse_expr( ScalarFunction::Lower => { Ok(lower(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Ltrim => { - Ok(ltrim(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Rtrim => { - Ok(rtrim(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Ascii => { Ok(ascii(parse_expr(&args[0], registry, codec)?)) } @@ -1512,12 +1502,6 @@ pub fn parse_expr( .map(|expr| parse_expr(expr, registry, codec)) .collect::, _>>()?, )), - ScalarFunction::Btrim => Ok(btrim( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), ScalarFunction::SplitPart => Ok(split_part( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 89bd93550a04..39d663b6c59b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1481,13 +1481,10 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::OctetLength => Self::OctetLength, BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::Lower => Self::Lower, - BuiltinScalarFunction::Ltrim => Self::Ltrim, - BuiltinScalarFunction::Rtrim => Self::Rtrim, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, 
BuiltinScalarFunction::Ascii => Self::Ascii, BuiltinScalarFunction::BitLength => Self::BitLength, - BuiltinScalarFunction::Btrim => Self::Btrim, BuiltinScalarFunction::CharacterLength => Self::CharacterLength, BuiltinScalarFunction::Chr => Self::Chr, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index c34b42193cec..04f8001bfc1b 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,20 +15,11 @@ // specific language governing permissions and limitations // under the License. -mod binary_op; -mod function; -mod grouping_set; -mod identifier; -mod json_access; -mod order_by; -mod subquery; -mod substring; -mod unary_op; -mod value; - -use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::DataType; use arrow_schema::TimeUnit; +use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; +use sqlparser::parser::ParserError::ParserError; + use datafusion_common::{ internal_datafusion_err, internal_err, not_impl_err, plan_err, Column, DFSchema, Result, ScalarValue, @@ -40,8 +31,19 @@ use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast, Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast, }; -use sqlparser::ast::{ArrayAgg, Expr as SQLExpr, JsonOperator, TrimWhereField, Value}; -use sqlparser::parser::ParserError::ParserError; + +use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; + +mod binary_op; +mod function; +mod grouping_set; +mod identifier; +mod json_access; +mod order_by; +mod subquery; +mod substring; +mod unary_op; +mod value; impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(crate) fn sql_expr_to_logical_expr( @@ -743,13 +745,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = match trim_where { - Some(TrimWhereField::Leading) => BuiltinScalarFunction::Ltrim, - Some(TrimWhereField::Trailing) => BuiltinScalarFunction::Rtrim, - Some(TrimWhereField::Both) => BuiltinScalarFunction::Btrim, - None => BuiltinScalarFunction::Btrim, - }; - let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; let args = match (trim_what, trim_characters) { (Some(to_trim), None) => { @@ -774,7 +769,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } (None, None) => Ok(vec![arg]), }?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + + let fun_name = match trim_where { + Some(TrimWhereField::Leading) => "ltrim", + Some(TrimWhereField::Trailing) => "rtrim", + Some(TrimWhereField::Both) => "btrim", + None => "trim", + }; + let fun = self + .context_provider + .get_function_meta(fun_name) + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected '{fun_name}' function") + })?; + + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } fn sql_overlay_to_expr( diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index d4570dbc35f2..5eb3436b4256 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -731,12 +731,15 @@ btrim(str[, trim_str]) Can be a constant, column, or function, and any combination of string operators. - **trim_str**: String expression to trim from the beginning and end of the input string. Can be a constant, column, or function, and any combination of arithmetic operators. 
- _Default is whitespace characters_. + _Default is whitespace characters._ **Related functions**: [ltrim](#ltrim), -[rtrim](#rtrim), -[trim](#trim) +[rtrim](#rtrim) + +#### Aliases + +- trim ### `char_length` @@ -919,26 +922,25 @@ lpad(str, n[, padding_str]) ### `ltrim` -Removes leading spaces from a string. +Trims the specified trim string from the beginning of a string. +If no trim string is provided, all whitespace is removed from the start +of the input string. ``` -ltrim(str) +ltrim(str[, trim_str]) ``` #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of string operators. +- **trim_str**: String expression to trim from the beginning of the input string. + Can be a constant, column, or function, and any combination of arithmetic operators. + _Default is whitespace characters._ **Related functions**: [btrim](#btrim), -[rtrim](#rtrim), -[trim](#trim) - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. +[rtrim](#rtrim) ### `octet_length` @@ -1040,21 +1042,25 @@ rpad(str, n[, padding_str]) ### `rtrim` -Removes trailing spaces from a string. +Trims the specified trim string from the end of a string. +If no trim string is provided, all whitespace is removed from the end +of the input string. ``` -rtrim(str) +rtrim(str[, trim_str]) ``` #### Arguments - **str**: String expression to operate on. Can be a constant, column, or function, and any combination of string operators. +- **trim_str**: String expression to trim from the end of the input string. + Can be a constant, column, or function, and any combination of arithmetic operators. + _Default is whitespace characters._ **Related functions**: [btrim](#btrim), -[ltrim](#ltrim), -[trim](#trim) +[ltrim](#ltrim) ### `split_part` @@ -1154,21 +1160,7 @@ to_hex(int) ### `trim` -Removes leading and trailing spaces from a string. - -``` -trim(str) -``` - -#### Arguments - -- **str**: String expression to operate on. - Can be a constant, column, or function, and any combination of string operators. - -**Related functions**: -[btrim](#btrim), -[ltrim](#ltrim), -[rtrim](#rtrim) +_Alias of [btrim](#btrim)._ ### `upper` From 4913a0025b63211d4848693e972540c781c1cf98 Mon Sep 17 00:00:00 2001 From: Kunal Kundu Date: Fri, 22 Mar 2024 23:50:21 +0530 Subject: [PATCH 042/117] make format prefix optional for format options in COPY (#9723) * make format prefix optional for format options in COPY * fix clippy lint error * Add negative test case for unknown option * Improve test comments --------- Co-authored-by: Andrew Lamb --- datafusion/sql/src/statement.rs | 11 +++++- datafusion/sqllogictest/test_files/copy.slt | 39 +++++++++++++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index e50aceb757df..4cca4c114a91 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -850,7 +850,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return plan_err!("Unsupported Value in COPY statement {}", value); } }; - options.insert(key.to_lowercase(), value_string.to_lowercase()); + if !(&key.contains('.')) { + // If config does not belong to any namespace, assume it is + // a format option and apply the format prefix for backwards + // compatibility. 
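+                // e.g. `OPTIONS (compression snappy)` is rewritten to the
+                // `format.compression` key, as exercised by the copy.slt cases below.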
+ + let renamed_key = format!("format.{}", key); + options.insert(renamed_key.to_lowercase(), value_string.to_lowercase()); + } else { + options.insert(key.to_lowercase(), value_string.to_lowercase()); + } } let file_type = if let Some(file_type) = statement.stored_as { diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 7884bece1f39..75f1ccb07aac 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -474,6 +474,45 @@ select * from validate_arrow; 1 Foo 2 Bar +# Format Options Support without the 'format.' prefix + +# Copy with format options for Parquet without the 'format.' prefix +query IT +COPY source_table TO 'test_files/scratch/copy/format_table.parquet' +OPTIONS ( + compression snappy, + 'compression::col1' 'zstd(5)' +); +---- +2 + +# Copy with format options for JSON without the 'format.' prefix +query IT +COPY source_table to 'test_files/scratch/copy/format_table' +STORED AS JSON OPTIONS (compression gzip); +---- +2 + +# Copy with format options for CSV without the 'format.' prefix +query IT +COPY source_table to 'test_files/scratch/copy/format_table.csv' +OPTIONS ( + has_header false, + compression xz, + datetime_format '%FT%H:%M:%S.%9f', + delimiter ';', + null_value 'NULLVAL' +); +---- +2 + +# Copy with unknown format options without the 'format.' prefix to ensure error is sensible +query error DataFusion error: Invalid or Unsupported Configuration: Config value "unknown_option" not found on CsvOptions +COPY source_table to 'test_files/scratch/copy/format_table2.csv' +OPTIONS ( + unknown_option false, +); + # Error cases: From d43dd1694ff9123488bbff6c2fde30c451fda2de Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Fri, 22 Mar 2024 11:48:11 -0700 Subject: [PATCH 043/117] refactor: Extract `range` and `gen_series` functions from `functions-array` subcrate' s `kernels` and `udf` containers (#9720) * Issue-9705 - Extract range and gen_series functions from functions-array subcrate' s kernels and udf containers * Issue-9705 - Address review comment --- datafusion/functions-array/src/kernels.rs | 177 +----------- datafusion/functions-array/src/lib.rs | 9 +- datafusion/functions-array/src/range.rs | 332 ++++++++++++++++++++++ datafusion/functions-array/src/udf.rs | 142 --------- 4 files changed, 342 insertions(+), 318 deletions(-) create mode 100644 datafusion/functions-array/src/range.rs diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index ec0942837795..9b743fa913cf 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/kernels.rs @@ -18,152 +18,28 @@ //! 
implementation kernels for array functions use arrow::array::{ - Array, ArrayRef, BooleanArray, Capacities, Date32Array, GenericListArray, Int64Array, + Array, ArrayRef, BooleanArray, Capacities, GenericListArray, Int64Array, LargeListArray, ListArray, MutableArrayData, OffsetSizeTrait, UInt64Array, }; use arrow::compute; -use arrow::datatypes::{ - DataType, Date32Type, Field, IntervalMonthDayNanoType, UInt64Type, -}; +use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow_array::new_null_array; use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use arrow_schema::FieldRef; use arrow_schema::SortOptions; use datafusion_common::cast::{ - as_date32_array, as_generic_list_array, as_int64_array, as_interval_mdn_array, - as_large_list_array, as_list_array, as_null_array, as_string_array, + as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, + as_null_array, as_string_array, }; use datafusion_common::{ - exec_err, internal_datafusion_err, not_impl_datafusion_err, DataFusionError, Result, - ScalarValue, + exec_err, internal_datafusion_err, DataFusionError, Result, ScalarValue, }; use crate::utils::downcast_arg; use std::any::type_name; use std::sync::Arc; -/// Generates an array of integers from start to stop with a given step. -/// -/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. -/// It returns a `Result` representing the resulting ListArray after the operation. -/// -/// # Arguments -/// -/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. -/// -/// # Examples -/// -/// gen_range(3) => [0, 1, 2] -/// gen_range(1, 4) => [1, 2, 3] -/// gen_range(1, 7, 2) => [1, 3, 5] -pub(super) fn gen_range(args: &[ArrayRef], include_upper: bool) -> Result { - let (start_array, stop_array, step_array) = match args.len() { - 1 => (None, as_int64_array(&args[0])?, None), - 2 => ( - Some(as_int64_array(&args[0])?), - as_int64_array(&args[1])?, - None, - ), - 3 => ( - Some(as_int64_array(&args[0])?), - as_int64_array(&args[1])?, - Some(as_int64_array(&args[2])?), - ), - _ => return exec_err!("gen_range expects 1 to 3 arguments"), - }; - - let mut values = vec![]; - let mut offsets = vec![0]; - let mut valid = BooleanBufferBuilder::new(stop_array.len()); - for (idx, stop) in stop_array.iter().enumerate() { - match retrieve_range_args(start_array, stop, step_array, idx) { - Some((_, _, 0)) => { - return exec_err!( - "step can't be 0 for function {}(start [, stop, step])", - if include_upper { - "generate_series" - } else { - "range" - } - ); - } - Some((start, stop, step)) => { - // Below, we utilize `usize` to represent steps. - // On 32-bit targets, the absolute value of `i64` may fail to fit into `usize`. - let step_abs = usize::try_from(step.unsigned_abs()).map_err(|_| { - not_impl_datafusion_err!("step {} can't fit into usize", step) - })?; - values.extend( - gen_range_iter(start, stop, step < 0, include_upper) - .step_by(step_abs), - ); - offsets.push(values.len() as i32); - valid.append(true); - } - // If any of the arguments is NULL, append a NULL value to the result. 
- None => { - offsets.push(values.len() as i32); - valid.append(false); - } - }; - } - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Int64, true)), - OffsetBuffer::new(offsets.into()), - Arc::new(Int64Array::from(values)), - Some(NullBuffer::new(valid.finish())), - )?); - Ok(arr) -} - -/// Get the (start, stop, step) args for the range and generate_series function. -/// If any of the arguments is NULL, returns None. -fn retrieve_range_args( - start_array: Option<&Int64Array>, - stop: Option, - step_array: Option<&Int64Array>, - idx: usize, -) -> Option<(i64, i64, i64)> { - // Default start value is 0 if not provided - let start = - start_array.map_or(Some(0), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; - let stop = stop?; - // Default step value is 1 if not provided - let step = - step_array.map_or(Some(1), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; - Some((start, stop, step)) -} - -/// Returns an iterator of i64 values from start to stop -fn gen_range_iter( - start: i64, - stop: i64, - decreasing: bool, - include_upper: bool, -) -> Box> { - match (decreasing, include_upper) { - // Decreasing range, stop is inclusive - (true, true) => Box::new((stop..=start).rev()), - // Decreasing range, stop is exclusive - (true, false) => { - if stop == i64::MAX { - // start is never greater than stop, and stop is exclusive, - // so the decreasing range must be empty. - Box::new(std::iter::empty()) - } else { - // Increase the stop value by one to exclude it. - // Since stop is not i64::MAX, `stop + 1` will not overflow. - Box::new((stop + 1..=start).rev()) - } - } - // Increasing range, stop is inclusive - (false, true) => Box::new(start..=stop), - // Increasing range, stop is exclusive - (false, false) => Box::new(start..stop), - } -} - /// Returns the length of each array dimension fn compute_array_dims(arr: Option) -> Result>>> { let mut value = match arr { @@ -285,49 +161,6 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { array_type => exec_err!("array_ndims does not support type {array_type:?}"), } } -pub fn gen_range_date( - args: &[ArrayRef], - include_upper: bool, -) -> datafusion_common::Result { - if args.len() != 3 { - return exec_err!("arguments length does not match"); - } - let (start_array, stop_array, step_array) = ( - Some(as_date32_array(&args[0])?), - as_date32_array(&args[1])?, - Some(as_interval_mdn_array(&args[2])?), - ); - - let mut values = vec![]; - let mut offsets = vec![0]; - for (idx, stop) in stop_array.iter().enumerate() { - let mut stop = stop.unwrap_or(0); - let start = start_array.as_ref().map(|x| x.value(idx)).unwrap_or(0); - let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); - let (months, days, _) = IntervalMonthDayNanoType::to_parts(step); - let neg = months < 0 || days < 0; - if !include_upper { - stop = Date32Type::subtract_month_day_nano(stop, step); - } - let mut new_date = start; - loop { - if neg && new_date < stop || !neg && new_date > stop { - break; - } - values.push(new_date); - new_date = Date32Type::add_month_day_nano(new_date, step); - } - offsets.push(values.len() as i32); - } - - let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Date32, true)), - OffsetBuffer::new(offsets.into()), - Arc::new(Date32Array::from(values)), - None, - )?); - Ok(arr) -} /// Array_empty SQL function pub fn array_empty(args: &[ArrayRef]) -> Result { diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index f8d85800b3e3..6ed77e5d170c 100644 
--- a/datafusion/functions-array/src/lib.rs
+++ b/datafusion/functions-array/src/lib.rs
@@ -35,6 +35,7 @@ mod except;
 mod extract;
 mod kernels;
 mod position;
+mod range;
 mod remove;
 mod replace;
 mod rewrite;
@@ -65,6 +66,8 @@ pub mod expr_fn {
     pub use super::extract::array_slice;
     pub use super::position::array_position;
     pub use super::position::array_positions;
+    pub use super::range::gen_series;
+    pub use super::range::range;
     pub use super::remove::array_remove;
     pub use super::remove::array_remove_all;
     pub use super::remove::array_remove_n;
@@ -86,8 +89,6 @@ pub mod expr_fn {
     pub use super::udf::array_sort;
     pub use super::udf::cardinality;
     pub use super::udf::flatten;
-    pub use super::udf::gen_series;
-    pub use super::udf::range;
 }

 /// Registers all enabled packages with a [`FunctionRegistry`]
@@ -95,8 +96,8 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
     let functions: Vec<Arc<ScalarUDF>> = vec![
         string::array_to_string_udf(),
         string::string_to_array_udf(),
-        udf::range_udf(),
-        udf::gen_series_udf(),
+        range::range_udf(),
+        range::gen_series_udf(),
         udf::array_dims_udf(),
         udf::cardinality_udf(),
         udf::array_ndims_udf(),
diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-array/src/range.rs
new file mode 100644
index 000000000000..7dfce71332a1
--- /dev/null
+++ b/datafusion/functions-array/src/range.rs
@@ -0,0 +1,332 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ScalarUDFImpl`] definitions for range and gen_series functions.
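+//!
+//! `range` excludes the upper bound while `generate_series` includes it: for
+//! example, `range(1, 4)` yields `[1, 2, 3]` and `generate_series(1, 4)` yields
+//! `[1, 2, 3, 4]`.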
+ +use arrow::array::{Array, ArrayRef, Int64Array, ListArray}; +use arrow::datatypes::{DataType, Field}; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use std::any::Any; + +use arrow_array::types::{Date32Type, IntervalMonthDayNanoType}; +use arrow_array::Date32Array; +use arrow_schema::IntervalUnit::MonthDayNano; +use datafusion_common::cast::{as_date32_array, as_int64_array, as_interval_mdn_array}; +use datafusion_common::{exec_err, not_impl_datafusion_err, Result}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::Expr; +use datafusion_expr::{ + ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility, +}; +use std::sync::Arc; + +make_udf_function!( + Range, + range, + start stop step, + "create a list of values in the range between start and stop", + range_udf +); +#[derive(Debug)] +pub(super) struct Range { + signature: Signature, + aliases: Vec, +} +impl Range { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![Int64]), + TypeSignature::Exact(vec![Int64, Int64]), + TypeSignature::Exact(vec![Int64, Int64, Int64]), + TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("range")], + } + } +} +impl ScalarUDFImpl for Range { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "range" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + match args[0].data_type() { + DataType::Int64 => gen_range_inner(&args, false).map(ColumnarValue::Array), + DataType::Date32 => gen_range_date(&args, false).map(ColumnarValue::Array), + _ => { + exec_err!("unsupported type for range") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +make_udf_function!( + GenSeries, + gen_series, + start stop step, + "create a list of values in the range between start and stop, include upper bound", + gen_series_udf +); +#[derive(Debug)] +pub(super) struct GenSeries { + signature: Signature, + aliases: Vec, +} +impl GenSeries { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + TypeSignature::Exact(vec![Int64]), + TypeSignature::Exact(vec![Int64, Int64]), + TypeSignature::Exact(vec![Int64, Int64, Int64]), + TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("generate_series")], + } + } +} +impl ScalarUDFImpl for GenSeries { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "generate_series" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let args = ColumnarValue::values_to_arrays(args)?; + match args[0].data_type() { + DataType::Int64 => gen_range_inner(&args, true).map(ColumnarValue::Array), + DataType::Date32 => gen_range_date(&args, true).map(ColumnarValue::Array), + _ => { + exec_err!("unsupported type for range") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Generates an array of 
integers from start to stop with a given step. +/// +/// This function takes 1 to 3 ArrayRefs as arguments, representing start, stop, and step values. +/// It returns a `Result` representing the resulting ListArray after the operation. +/// +/// # Arguments +/// +/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values. +/// +/// # Examples +/// +/// gen_range(3) => [0, 1, 2] +/// gen_range(1, 4) => [1, 2, 3] +/// gen_range(1, 7, 2) => [1, 3, 5] +pub(super) fn gen_range_inner( + args: &[ArrayRef], + include_upper: bool, +) -> Result { + let (start_array, stop_array, step_array) = match args.len() { + 1 => (None, as_int64_array(&args[0])?, None), + 2 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + None, + ), + 3 => ( + Some(as_int64_array(&args[0])?), + as_int64_array(&args[1])?, + Some(as_int64_array(&args[2])?), + ), + _ => return exec_err!("gen_range expects 1 to 3 arguments"), + }; + + let mut values = vec![]; + let mut offsets = vec![0]; + let mut valid = BooleanBufferBuilder::new(stop_array.len()); + for (idx, stop) in stop_array.iter().enumerate() { + match retrieve_range_args(start_array, stop, step_array, idx) { + Some((_, _, 0)) => { + return exec_err!( + "step can't be 0 for function {}(start [, stop, step])", + if include_upper { + "generate_series" + } else { + "range" + } + ); + } + Some((start, stop, step)) => { + // Below, we utilize `usize` to represent steps. + // On 32-bit targets, the absolute value of `i64` may fail to fit into `usize`. + let step_abs = usize::try_from(step.unsigned_abs()).map_err(|_| { + not_impl_datafusion_err!("step {} can't fit into usize", step) + })?; + values.extend( + gen_range_iter(start, stop, step < 0, include_upper) + .step_by(step_abs), + ); + offsets.push(values.len() as i32); + valid.append(true); + } + // If any of the arguments is NULL, append a NULL value to the result. + None => { + offsets.push(values.len() as i32); + valid.append(false); + } + }; + } + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Int64, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Int64Array::from(values)), + Some(NullBuffer::new(valid.finish())), + )?); + Ok(arr) +} + +/// Get the (start, stop, step) args for the range and generate_series function. +/// If any of the arguments is NULL, returns None. +fn retrieve_range_args( + start_array: Option<&Int64Array>, + stop: Option, + step_array: Option<&Int64Array>, + idx: usize, +) -> Option<(i64, i64, i64)> { + // Default start value is 0 if not provided + let start = + start_array.map_or(Some(0), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; + let stop = stop?; + // Default step value is 1 if not provided + let step = + step_array.map_or(Some(1), |arr| arr.is_valid(idx).then(|| arr.value(idx)))?; + Some((start, stop, step)) +} + +/// Returns an iterator of i64 values from start to stop +fn gen_range_iter( + start: i64, + stop: i64, + decreasing: bool, + include_upper: bool, +) -> Box> { + match (decreasing, include_upper) { + // Decreasing range, stop is inclusive + (true, true) => Box::new((stop..=start).rev()), + // Decreasing range, stop is exclusive + (true, false) => { + if stop == i64::MAX { + // start is never greater than stop, and stop is exclusive, + // so the decreasing range must be empty. + Box::new(std::iter::empty()) + } else { + // Increase the stop value by one to exclude it. + // Since stop is not i64::MAX, `stop + 1` will not overflow. 
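+                // For example, start = 7 and stop = 3 gives (4..=7).rev(),
+                // i.e. 7, 6, 5, 4; the exclusive stop value 3 is omitted.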
+ Box::new((stop + 1..=start).rev()) + } + } + // Increasing range, stop is inclusive + (false, true) => Box::new(start..=stop), + // Increasing range, stop is exclusive + (false, false) => Box::new(start..stop), + } +} + +fn gen_range_date(args: &[ArrayRef], include_upper: bool) -> Result { + if args.len() != 3 { + return exec_err!("arguments length does not match"); + } + let (start_array, stop_array, step_array) = ( + Some(as_date32_array(&args[0])?), + as_date32_array(&args[1])?, + Some(as_interval_mdn_array(&args[2])?), + ); + + let mut values = vec![]; + let mut offsets = vec![0]; + for (idx, stop) in stop_array.iter().enumerate() { + let mut stop = stop.unwrap_or(0); + let start = start_array.as_ref().map(|x| x.value(idx)).unwrap_or(0); + let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1); + let (months, days, _) = IntervalMonthDayNanoType::to_parts(step); + let neg = months < 0 || days < 0; + if !include_upper { + stop = Date32Type::subtract_month_day_nano(stop, step); + } + let mut new_date = start; + loop { + if neg && new_date < stop || !neg && new_date > stop { + break; + } + values.push(new_date); + new_date = Date32Type::add_month_day_nano(new_date, step); + } + offsets.push(values.len() as i32); + } + + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Date32, true)), + OffsetBuffer::new(offsets.into()), + Arc::new(Date32Array::from(values)), + None, + )?); + Ok(arr) +} diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index 5f5d90851758..156703105766 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -19,158 +19,16 @@ use arrow::datatypes::DataType; use arrow::datatypes::Field; -use arrow::datatypes::IntervalUnit::MonthDayNano; use arrow_schema::DataType::List; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::Result; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; -use datafusion_expr::TypeSignature; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; -make_udf_function!( - Range, - range, - start stop step, - "create a list of values in the range between start and stop", - range_udf -); -#[derive(Debug)] -pub(super) struct Range { - signature: Signature, - aliases: Vec, -} -impl Range { - pub fn new() -> Self { - use DataType::*; - Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - ], - Volatility::Immutable, - ), - aliases: vec![String::from("range")], - } - } -} -impl ScalarUDFImpl for Range { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "range" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - match args[0].data_type() { - arrow::datatypes::DataType::Int64 => { - crate::kernels::gen_range(&args, false).map(ColumnarValue::Array) - } - arrow::datatypes::DataType::Date32 => { - crate::kernels::gen_range_date(&args, false).map(ColumnarValue::Array) - } - _ => { - exec_err!("unsupported type for range") - } 
- } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - GenSeries, - gen_series, - start stop step, - "create a list of values in the range between start and stop, include upper bound", - gen_series_udf -); -#[derive(Debug)] -pub(super) struct GenSeries { - signature: Signature, - aliases: Vec, -} -impl GenSeries { - pub fn new() -> Self { - use DataType::*; - Self { - signature: Signature::one_of( - vec![ - TypeSignature::Exact(vec![Int64]), - TypeSignature::Exact(vec![Int64, Int64]), - TypeSignature::Exact(vec![Int64, Int64, Int64]), - TypeSignature::Exact(vec![Date32, Date32, Interval(MonthDayNano)]), - ], - Volatility::Immutable, - ), - aliases: vec![String::from("generate_series")], - } - } -} -impl ScalarUDFImpl for GenSeries { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "generate_series" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - match args[0].data_type() { - arrow::datatypes::DataType::Int64 => { - crate::kernels::gen_range(&args, true).map(ColumnarValue::Array) - } - arrow::datatypes::DataType::Date32 => { - crate::kernels::gen_range_date(&args, true).map(ColumnarValue::Array) - } - _ => { - exec_err!("unsupported type for range") - } - } - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - make_udf_function!( ArrayDims, array_dims, From 6c6305159711fbca43ff7798faa52aaebcb2e7d3 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sat, 23 Mar 2024 04:41:02 +0800 Subject: [PATCH 044/117] Move ascii function to datafusion_functions (#9740) * Move ascii function to datafusion_functions Signed-off-by: Chojan Shang * Minor update Signed-off-by: Chojan Shang * Minor update Signed-off-by: Chojan Shang * Move more sqllogictests Signed-off-by: Chojan Shang * Fix leftover merge --------- Signed-off-by: Chojan Shang Co-authored-by: Andrew Lamb --- datafusion/expr/src/built_in_function.rs | 8 +- datafusion/expr/src/expr_fn.rs | 2 - datafusion/functions/src/string/ascii.rs | 91 +++++++++++++++++++ datafusion/functions/src/string/mod.rs | 17 +++- datafusion/physical-expr/src/functions.rs | 36 -------- .../physical-expr/src/string_expressions.rs | 18 ---- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 - datafusion/sqllogictest/test_files/expr.slt | 15 +++ 12 files changed, 127 insertions(+), 78 deletions(-) create mode 100644 datafusion/functions/src/string/ascii.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 785965f6f693..7649c27b392f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -103,8 +103,6 @@ pub enum BuiltinScalarFunction { Cot, // string functions - /// ascii - Ascii, /// bit_length BitLength, /// character_length @@ -240,7 +238,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::Ascii => Volatility::Immutable, BuiltinScalarFunction::BitLength => 
Volatility::Immutable, BuiltinScalarFunction::CharacterLength => Volatility::Immutable, BuiltinScalarFunction::Chr => Volatility::Immutable, @@ -290,7 +287,6 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { - BuiltinScalarFunction::Ascii => Ok(Int32), BuiltinScalarFunction::BitLength => { utf8_to_int_type(&input_expr_types[0], "bit_length") } @@ -429,8 +425,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::Ascii - | BuiltinScalarFunction::BitLength + BuiltinScalarFunction::BitLength | BuiltinScalarFunction::CharacterLength | BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Lower @@ -677,7 +672,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => &["coalesce"], // string functions - BuiltinScalarFunction::Ascii => &["ascii"], BuiltinScalarFunction::BitLength => &["bit_length"], BuiltinScalarFunction::CharacterLength => { &["character_length", "char_length", "length"] diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a834ccab9d15..061b16562e82 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -579,7 +579,6 @@ scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); // string functions -scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); scalar_expr!( BitLength, bit_length, @@ -1063,7 +1062,6 @@ mod test { test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); - test_scalar_expr!(Ascii, ascii, input); test_scalar_expr!(BitLength, bit_length, string); test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Chr, chr, string); diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs new file mode 100644 index 000000000000..5bd77833a935 --- /dev/null +++ b/datafusion/functions/src/string/ascii.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::string::common::make_scalar_function; +use arrow::array::Int32Array; +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +/// Returns the numeric code of the first character of the argument. 
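+/// The code is the Unicode code point of the first character, so non-ASCII
+/// input yields values above 127, e.g. ascii('é') = 233.
+///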
+/// ascii('x') = 120 +pub fn ascii(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + let mut chars = string.chars(); + chars.next().map_or(0, |v| v as i32) + }) + }) + .collect::(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct AsciiFunc { + signature: Signature, +} +impl AsciiFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for AsciiFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "ascii" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(Int32) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(ascii::, vec![])(args), + DataType::LargeUtf8 => { + return make_scalar_function(ascii::, vec![])(args); + } + _ => internal_err!("Unsupported data type"), + } + } +} diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 13c02d5dfac3..63026092f39a 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; +mod ascii; mod btrim; mod common; mod ltrim; @@ -30,6 +31,7 @@ mod to_hex; mod upper; // create UDFs +make_udf_function!(ascii::AsciiFunc, ASCII, ascii); make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); @@ -40,6 +42,11 @@ make_udf_function!(upper::UpperFunc, UPPER, upper); pub mod expr_fn { use datafusion_expr::Expr; + #[doc = "Returns the numeric code of the first character of the argument."] + pub fn ascii(arg1: Expr) -> Expr { + super::ascii().call(vec![arg1]) + } + #[doc = "Removes all characters, spaces by default, from both sides of a string"] pub fn btrim(args: Vec) -> Expr { super::btrim().call(args) @@ -73,5 +80,13 @@ pub mod expr_fn { /// Return a list of all functions in this package pub fn functions() -> Vec> { - vec![btrim(), ltrim(), rtrim(), starts_with(), to_hex(), upper()] + vec![ + ascii(), + btrim(), + ltrim(), + rtrim(), + starts_with(), + to_hex(), + upper(), + ] } diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index a6efe0e0861d..d66af3d22a40 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -252,15 +252,6 @@ pub fn create_physical_fun( Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } // string functions - BuiltinScalarFunction::Ascii => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::ascii::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::ascii::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function ascii"), - }), BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), ColumnarValue::Scalar(v) => match v { @@ -681,33 +672,6 @@ mod tests { #[test] fn test_functions() -> Result<()> { - test_function!(Ascii, &[lit("x")], Ok(Some(120)), i32, Int32, Int32Array); - 
test_function!(Ascii, &[lit("ésoj")], Ok(Some(233)), i32, Int32, Int32Array); - test_function!( - Ascii, - &[lit("💯")], - Ok(Some(128175)), - i32, - Int32, - Int32Array - ); - test_function!( - Ascii, - &[lit("💯a")], - Ok(Some(128175)), - i32, - Int32, - Int32Array - ); - test_function!(Ascii, &[lit("")], Ok(Some(0)), i32, Int32, Int32Array); - test_function!( - Ascii, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); test_function!( BitLength, &[lit("chars")], diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index f5229d92545e..6877fb18ad4f 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -115,24 +115,6 @@ where } } -/// Returns the numeric code of the first character of the argument. -/// ascii('x') = 120 -pub fn ascii(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - let mut chars = string.chars(); - chars.next().map_or(0, |v| v as i32) - }) - }) - .collect::(); - - Ok(Arc::new(result) as ArrayRef) -} - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 416b49db7aa7..3724eb9be4ad 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -545,7 +545,7 @@ enum ScalarFunction { // 1 was Acos // 2 was Asin Atan = 3; - Ascii = 4; + // 4 was Ascii Ceil = 5; Cos = 6; // 7 was Digest diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 49102137b659..90b0e22c779d 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22915,7 +22915,6 @@ impl serde::Serialize for ScalarFunction { let variant = match self { Self::Unknown => "unknown", Self::Atan => "Atan", - Self::Ascii => "Ascii", Self::Ceil => "Ceil", Self::Cos => "Cos", Self::Exp => "Exp", @@ -22986,7 +22985,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { const FIELDS: &[&str] = &[ "unknown", "Atan", - "Ascii", "Ceil", "Cos", "Exp", @@ -23086,7 +23084,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { match value { "unknown" => Ok(ScalarFunction::Unknown), "Atan" => Ok(ScalarFunction::Atan), - "Ascii" => Ok(ScalarFunction::Ascii), "Ceil" => Ok(ScalarFunction::Ceil), "Cos" => Ok(ScalarFunction::Cos), "Exp" => Ok(ScalarFunction::Exp), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 5e458bfef016..09e2fa07c877 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2844,7 +2844,7 @@ pub enum ScalarFunction { /// 1 was Acos /// 2 was Asin Atan = 3, - Ascii = 4, + /// 4 was Ascii Ceil = 5, Cos = 6, /// 7 was Digest @@ -2988,7 +2988,6 @@ impl ScalarFunction { match self { ScalarFunction::Unknown => "unknown", ScalarFunction::Atan => "Atan", - ScalarFunction::Ascii => "Ascii", ScalarFunction::Ceil => "Ceil", ScalarFunction::Cos => "Cos", ScalarFunction::Exp => "Exp", @@ -3053,7 +3052,6 @@ impl ScalarFunction { match value { "unknown" => Some(Self::Unknown), "Atan" => Some(Self::Atan), - "Ascii" => Some(Self::Ascii), "Ceil" => Some(Self::Ceil), "Cos" => 
Some(Self::Cos), "Exp" => Some(Self::Exp), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d41add915a96..fc39df6a815b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,8 +48,8 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, ascii, asinh, atan, atan2, atanh, bit_length, cbrt, ceil, character_length, - chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, + acosh, asinh, atan, atan2, atanh, bit_length, cbrt, ceil, character_length, chr, + coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -462,7 +462,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Lower => Self::Lower, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, - ScalarFunction::Ascii => Self::Ascii, ScalarFunction::BitLength => Self::BitLength, ScalarFunction::CharacterLength => Self::CharacterLength, ScalarFunction::Chr => Self::Chr, @@ -1435,9 +1434,6 @@ pub fn parse_expr( ScalarFunction::Lower => { Ok(lower(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Ascii => { - Ok(ascii(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::BitLength => { Ok(bit_length(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 39d663b6c59b..a774444960f3 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1483,7 +1483,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Lower => Self::Lower, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, - BuiltinScalarFunction::Ascii => Self::Ascii, BuiltinScalarFunction::BitLength => Self::BitLength, BuiltinScalarFunction::CharacterLength => Self::CharacterLength, BuiltinScalarFunction::Chr => Self::Chr, diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 69f3e439eac9..70fdc26a6002 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -335,6 +335,21 @@ SELECT ascii(NULL) ---- NULL +query I +SELECT ascii('ésoj') +---- +233 + +query I +SELECT ascii('💯') +---- +128175 + +query I +SELECT ascii('💯a') +---- +128175 + query I SELECT bit_length('') ---- From 1dbec3e10ed299e2b5f337f3c6f14bd4f4560257 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Sat, 23 Mar 2024 06:05:56 -0500 Subject: [PATCH 045/117] adding expr to string for IsNotNull IsTrue IsFalse and IsUnkown (#9739) * adding expr to string for IsNotNull IsTrue IsFalse and IsUnkown * fix clippy * change to raw string --- datafusion/sql/src/unparser/expr.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 43f3e348dc32..8d25a607bb89 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -181,6 +181,18 @@ impl Unparser<'_> { negated: 
insubq.negated, }) } + Expr::IsNotNull(expr) => { + Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?))) + } + Expr::IsTrue(expr) => { + Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql(expr)?))) + } + Expr::IsFalse(expr) => { + Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?))) + } + Expr::IsUnknown(expr) => { + Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) + } _ => not_impl_err!("Unsupported expression: {expr:?}"), } } @@ -599,6 +611,19 @@ mod tests { }), "COUNT(DISTINCT *)", ), + (Expr::IsNotNull(Box::new(col("a"))), r#""a" IS NOT NULL"#), + ( + Expr::IsTrue(Box::new((col("a") + col("b")).gt(lit(4)))), + r#"(("a" + "b") > 4) IS TRUE"#, + ), + ( + Expr::IsFalse(Box::new((col("a") + col("b")).gt(lit(4)))), + r#"(("a" + "b") > 4) IS FALSE"#, + ), + ( + Expr::IsUnknown(Box::new((col("a") + col("b")).gt(lit(4)))), + r#"(("a" + "b") > 4) IS UNKNOWN"#, + ), ]; for (expr, expected) in tests { From 02fd450aa0f43f5ad5981394719a8ec0451e83a0 Mon Sep 17 00:00:00 2001 From: Devin D'Angelo Date: Sat, 23 Mar 2024 07:14:16 -0400 Subject: [PATCH 046/117] fix: parallel parquet can underflow when max_record_batch_rows < execution.batch_size (#9737) * loop split rb * add test * add new test * fmt * lower batch size in test * make test faster * use path not into_path --- datafusion/core/src/dataframe/parquet.rs | 56 +++++++++++++- .../src/datasource/file_format/parquet.rs | 73 ++++++++++--------- 2 files changed, 93 insertions(+), 36 deletions(-) diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index e3f606e322fe..7cc3201bf7e4 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -74,6 +74,7 @@ impl DataFrame { #[cfg(test)] mod tests { + use std::collections::HashMap; use std::sync::Arc; use super::super::Result; @@ -81,9 +82,10 @@ mod tests { use crate::arrow::util::pretty; use crate::execution::context::SessionContext; use crate::execution::options::ParquetReadOptions; - use crate::test_util; + use crate::test_util::{self, register_aggregate_csv}; use datafusion_common::file_options::parquet_writer::parse_compression_string; + use datafusion_execution::config::SessionConfig; use datafusion_expr::{col, lit}; use object_store::local::LocalFileSystem; @@ -150,7 +152,7 @@ mod tests { .await?; // Check that file actually used the specified compression - let file = std::fs::File::open(tmp_dir.into_path().join("test.parquet"))?; + let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?; let reader = parquet::file::serialized_reader::SerializedFileReader::new(file) @@ -166,4 +168,54 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn write_parquet_with_small_rg_size() -> Result<()> { + // This test verifies writing a parquet file with small rg size + // relative to datafusion.execution.batch_size does not panic + let mut ctx = SessionContext::new_with_config( + SessionConfig::from_string_hash_map(HashMap::from_iter( + [("datafusion.execution.batch_size", "10")] + .iter() + .map(|(s1, s2)| (s1.to_string(), s2.to_string())), + ))?, + ); + register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; + let test_df = ctx.table("aggregate_test_100").await?; + + let output_path = "file://local/test.parquet"; + + for rg_size in 1..10 { + let df = test_df.clone(); + let tmp_dir = TempDir::new()?; + let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?); + let local_url = Url::parse("file://local").unwrap(); + let ctx = &test_df.session_state; + 
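+            // Map the `file://local` scheme used by `output_path` onto the
+            // temporary directory before writing.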
ctx.runtime_env().register_object_store(&local_url, local); + let mut options = TableParquetOptions::default(); + options.global.max_row_group_size = rg_size; + options.global.allow_single_file_parallelism = true; + df.write_parquet( + output_path, + DataFrameWriteOptions::new().with_single_file_output(true), + Some(options), + ) + .await?; + + // Check that file actually used the correct rg size + let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?; + + let reader = + parquet::file::serialized_reader::SerializedFileReader::new(file) + .unwrap(); + + let parquet_metadata = reader.metadata(); + + let written_rows = parquet_metadata.row_group(0).num_rows(); + + assert_eq!(written_rows as usize, rg_size); + } + + Ok(()) + } } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index ec333bb557d2..bcf4e8a2c8e4 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -876,42 +876,47 @@ fn spawn_parquet_parallel_serialization_task( )?; let mut current_rg_rows = 0; - while let Some(rb) = data.recv().await { - if current_rg_rows + rb.num_rows() < max_row_group_rows { - send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone()) - .await?; - current_rg_rows += rb.num_rows(); - } else { - let rows_left = max_row_group_rows - current_rg_rows; - let a = rb.slice(0, rows_left); - send_arrays_to_col_writers(&col_array_channels, &a, schema.clone()) - .await?; + while let Some(mut rb) = data.recv().await { + // This loop allows the "else" block to repeatedly split the RecordBatch to handle the case + // when max_row_group_rows < execution.batch_size as an alternative to a recursive async + // function. + loop { + if current_rg_rows + rb.num_rows() < max_row_group_rows { + send_arrays_to_col_writers(&col_array_channels, &rb, schema.clone()) + .await?; + current_rg_rows += rb.num_rows(); + break; + } else { + let rows_left = max_row_group_rows - current_rg_rows; + let a = rb.slice(0, rows_left); + send_arrays_to_col_writers(&col_array_channels, &a, schema.clone()) + .await?; + + // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup + // on a separate task, so that we can immediately start on the next RG before waiting + // for the current one to finish. + drop(col_array_channels); + let finalize_rg_task = spawn_rg_join_and_finalize_task( + column_writer_handles, + max_row_group_rows, + ); + + serialize_tx.send(finalize_rg_task).await.map_err(|_| { + DataFusionError::Internal( + "Unable to send closed RG to concat task!".into(), + ) + })?; - // Signal the parallel column writers that the RowGroup is done, join and finalize RowGroup - // on a separate task, so that we can immediately start on the next RG before waiting - // for the current one to finish. 
- drop(col_array_channels); - let finalize_rg_task = spawn_rg_join_and_finalize_task( - column_writer_handles, - max_row_group_rows, - ); - - serialize_tx.send(finalize_rg_task).await.map_err(|_| { - DataFusionError::Internal( - "Unable to send closed RG to concat task!".into(), - ) - })?; + current_rg_rows = 0; + rb = rb.slice(rows_left, rb.num_rows() - rows_left); - let b = rb.slice(rows_left, rb.num_rows() - rows_left); - (column_writer_handles, col_array_channels) = - spawn_column_parallel_row_group_writer( - schema.clone(), - writer_props.clone(), - max_buffer_rb, - )?; - send_arrays_to_col_writers(&col_array_channels, &b, schema.clone()) - .await?; - current_rg_rows = b.num_rows(); + (column_writer_handles, col_array_channels) = + spawn_column_parallel_row_group_writer( + schema.clone(), + writer_props.clone(), + max_buffer_rb, + )?; + } } } From 40fb1b859be4dd399922c498d49b9b847874af2b Mon Sep 17 00:00:00 2001 From: Kunal Kundu Date: Sat, 23 Mar 2024 16:57:45 +0530 Subject: [PATCH 047/117] support format in options of COPY command (#9744) * support format in options of COPY command * fix clippy lint error * add testcase to verify priority b/w STORED AS and OPTIONS (format <>) --- datafusion/sql/src/statement.rs | 12 ++++--- datafusion/sql/tests/sql_integration.rs | 12 +++++++ datafusion/sqllogictest/test_files/copy.slt | 38 +++++++++++++++++++++ 3 files changed, 58 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 4cca4c114a91..7717f75d16b8 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -850,7 +850,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return plan_err!("Unsupported Value in COPY statement {}", value); } }; - if !(&key.contains('.')) { + if !(key.contains('.') || key == "format") { // If config does not belong to any namespace, assume it is // a format option and apply the format prefix for backwards // compatibility. @@ -866,12 +866,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FileType::from_str(&file_type).map_err(|_| { DataFusionError::Configuration(format!("Unknown FileType {}", file_type)) })? + } else if let Some(format) = options.remove("format") { + // try to infer file format from the "format" key in options + FileType::from_str(&format) + .map_err(|e| DataFusionError::Configuration(format!("{}", e)))? } else { let e = || { DataFusionError::Configuration( - "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." - .to_string(), - ) + "Format not explicitly set and unable to get file extension! Use STORED AS to define file format." 
+ .to_string(), + ) }; // try to infer file format from file extension let extension: &str = &Path::new(&statement.target) diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 47638e58ff00..c738a2bd754f 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -442,6 +442,18 @@ CopyTo: format=csv output_url=output.csv options: () quick_test(sql, plan); } +#[test] +fn plan_copy_stored_as_priority() { + let sql = "COPY (select * from (values (1))) to 'output/' STORED AS CSV OPTIONS (format json)"; + let plan = r#" +CopyTo: format=csv output_url=output/ options: (format json) + Projection: column1 + Values: (Int64(1)) + "# + .trim(); + quick_test(sql, plan); +} + #[test] fn plan_insert() { let sql = diff --git a/datafusion/sqllogictest/test_files/copy.slt b/datafusion/sqllogictest/test_files/copy.slt index 75f1ccb07aac..fca892dfcdad 100644 --- a/datafusion/sqllogictest/test_files/copy.slt +++ b/datafusion/sqllogictest/test_files/copy.slt @@ -514,6 +514,44 @@ OPTIONS ( ); +# Format Options Support with format in OPTIONS i.e. COPY { table_name | query } TO 'file_name' OPTIONS (format , ...) + +query I +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS (format parquet); +---- +1 + +query I +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS (format parquet, compression 'zstd(10)'); +---- +1 + +query I +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS (format json, compression gzip); +---- +1 + +query I +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS ( + format csv, + has_header false, + compression xz, + datetime_format '%FT%H:%M:%S.%9f', + delimiter ';', + null_value 'NULLVAL' +); +---- +1 + +query error DataFusion error: Invalid or Unsupported Configuration: This feature is not implemented: Unknown FileType: NOTVALIDFORMAT +COPY (select * from (values (1))) to 'test_files/scratch/copy/' +OPTIONS (format notvalidformat, compression 'zstd(5)'); + + # Error cases: # Copy from table with options From e91d100ebbeba9fcbb8a762573172021b581e1c5 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sat, 23 Mar 2024 08:20:14 -0400 Subject: [PATCH 048/117] Move lower, octet_length to datafusion-functions (#9747) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Move trim functions to datafusion-functions * Doc updates for ltrim, rtrim and trim to reflect how they actually function. 
* Fixed struct name Trim -> BTrim * Move lower, octet_length to datafusion-functions --- datafusion/expr/src/built_in_function.rs | 16 -- datafusion/expr/src/expr_fn.rs | 9 - datafusion/functions/src/string/common.rs | 67 +++++++ datafusion/functions/src/string/lower.rs | 63 +++++++ datafusion/functions/src/string/mod.rs | 16 ++ .../functions/src/string/octet_length.rs | 173 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 40 +--- .../physical-expr/src/string_expressions.rs | 86 --------- datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 18 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 - 13 files changed, 329 insertions(+), 179 deletions(-) create mode 100644 datafusion/functions/src/string/lower.rs create mode 100644 datafusion/functions/src/string/octet_length.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 7649c27b392f..d0ec1326c49e 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -121,10 +121,6 @@ pub enum BuiltinScalarFunction { Left, /// lpad Lpad, - /// lower - Lower, - /// octet_length - OctetLength, /// random Random, /// repeat @@ -247,8 +243,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => Volatility::Immutable, BuiltinScalarFunction::Left => Volatility::Immutable, BuiltinScalarFunction::Lpad => Volatility::Immutable, - BuiltinScalarFunction::Lower => Volatility::Immutable, - BuiltinScalarFunction::OctetLength => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, BuiltinScalarFunction::Repeat => Volatility::Immutable, BuiltinScalarFunction::Replace => Volatility::Immutable, @@ -305,13 +299,7 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "initcap") } BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), - BuiltinScalarFunction::Lower => { - utf8_to_str_type(&input_expr_types[0], "lower") - } BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), - BuiltinScalarFunction::OctetLength => { - utf8_to_int_type(&input_expr_types[0], "octet_length") - } BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::Uuid => Ok(Utf8), @@ -428,8 +416,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::BitLength | BuiltinScalarFunction::CharacterLength | BuiltinScalarFunction::InitCap - | BuiltinScalarFunction::Lower - | BuiltinScalarFunction::OctetLength | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } @@ -682,9 +668,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lower => &["lower"], BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::OctetLength => &["octet_length"], BuiltinScalarFunction::Repeat => &["repeat"], BuiltinScalarFunction::Replace => &["replace"], BuiltinScalarFunction::Reverse => &["reverse"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 061b16562e82..e1ab11c5b778 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -599,13 +599,6 @@ scalar_expr!( ); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase 
and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); -scalar_expr!(Lower, lower, string, "convert the string to lower case"); -scalar_expr!( - OctetLength, - octet_length, - string, - "returns the number of bytes of a string" -); scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `from` with `to` in the `string`"); scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times"); scalar_expr!(Reverse, reverse, string, "reverses the `string`"); @@ -1069,10 +1062,8 @@ mod test { test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(Left, left, string, count); - test_scalar_expr!(Lower, lower, string); test_nary_scalar_expr!(Lpad, lpad, string, count); test_nary_scalar_expr!(Lpad, lpad, string, count, characters); - test_scalar_expr!(OctetLength, octet_length, string); test_scalar_expr!(Replace, replace, string, from, to); test_scalar_expr!(Repeat, repeat, string, count); test_scalar_expr!(Reverse, reverse, string); diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 97465420fb99..339f4e6c1a23 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -141,6 +141,9 @@ macro_rules! get_optimal_return_type { // `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); +// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. +get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); + /// applies a unary expression to `args[0]` that is expected to be downcastable to /// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) /// # Errors @@ -263,3 +266,67 @@ where } }) } + +#[cfg(test)] +pub mod test { + /// $FUNC ScalarUDFImpl to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result + /// $EXPECTED_TYPE is the expected value type + /// $EXPECTED_DATA_TYPE is the expected result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! 
test_function { + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + let expected: Result> = $EXPECTED; + let func = $FUNC; + + let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let return_type = func.return_type(&type_array); + + match expected { + Ok(expected) => { + assert_eq!(return_type.is_ok(), true); + assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); + + let result = func.invoke($ARGS); + assert_eq!(result.is_ok(), true); + + let len = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let inferred_length = len.unwrap_or(1); + let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + if return_type.is_err() { + match return_type { + Ok(_) => assert!(false, "expected error"), + Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } + } + } + else { + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke($ARGS) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); + } + } + } + } + }; + }; + } + + pub(crate) use test_function; +} diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs new file mode 100644 index 000000000000..42bda0470067 --- /dev/null +++ b/datafusion/functions/src/string/lower.rs @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::string::common::{handle, utf8_to_str_type}; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +#[derive(Debug)] +pub(super) struct LowerFunc { + signature: Signature, +} + +impl LowerFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LowerFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "lower" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "lower") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + handle(args, |string| string.to_lowercase(), "lower") + } +} diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 63026092f39a..a70a695e935b 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -24,7 +24,9 @@ use datafusion_expr::ScalarUDF; mod ascii; mod btrim; mod common; +mod lower; mod ltrim; +mod octet_length; mod rtrim; mod starts_with; mod to_hex; @@ -34,6 +36,8 @@ mod upper; make_udf_function!(ascii::AsciiFunc, ASCII, ascii); make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); +make_udf_function!(lower::LowerFunc, LOWER, lower); +make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length); make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); @@ -52,11 +56,21 @@ pub mod expr_fn { super::btrim().call(args) } + #[doc = "Converts a string to lowercase."] + pub fn lower(arg1: Expr) -> Expr { + super::lower().call(vec![arg1]) + } + #[doc = "Removes all characters, spaces by default, from the beginning of a string"] pub fn ltrim(args: Vec) -> Expr { super::ltrim().call(args) } + #[doc = "returns the number of bytes of a string"] + pub fn octet_length(args: Vec) -> Expr { + super::octet_length().call(args) + } + #[doc = "Removes all characters, spaces by default, from the end of a string"] pub fn rtrim(args: Vec) -> Expr { super::rtrim().call(args) @@ -83,7 +97,9 @@ pub fn functions() -> Vec> { vec![ ascii(), btrim(), + lower(), ltrim(), + octet_length(), rtrim(), starts_with(), to_hex(), diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs new file mode 100644 index 000000000000..36a62fbe4e38 --- /dev/null +++ b/datafusion/functions/src/string/octet_length.rs @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::kernels::length::length; +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +#[derive(Debug)] +pub(super) struct OctetLengthFunc { + signature: Signature, +} + +impl OctetLengthFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for OctetLengthFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "octet_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "octet_length") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "octet_length function requires 1 argument, got {}", + args.len() + ); + } + + match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| x.len() as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), + )), + _ => unreachable!(), + }, + } + } +} + +#[cfg(test)] +mod tests { + use crate::string::common::test::test_function; + use crate::string::octet_length::OctetLengthFunc; + use arrow::array::{Array, Int32Array, StringArray}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::ScalarValue; + use datafusion_common::{exec_err, Result}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], + exec_err!( + "The OCTET_LENGTH function can only accept strings, but got Int32." 
+ ), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Array(Arc::new(StringArray::from(vec![ + String::from("chars"), + String::from("chars2"), + ])))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))) + ], + exec_err!("octet_length function requires 1 argument, got 2"), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("chars") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("josé") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("") + )))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + test_function!( + OctetLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index d66af3d22a40..2436fa24d4ef 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -37,7 +37,7 @@ use crate::{ }; use arrow::{ array::ArrayRef, - compute::kernels::length::{bit_length, length}, + compute::kernels::length::bit_length, datatypes::{DataType, Int32Type, Int64Type, Schema}, }; use arrow_array::Array; @@ -317,7 +317,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function left"), }), - BuiltinScalarFunction::Lower => Arc::new(string_expressions::lower), BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad"); @@ -329,18 +328,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function lpad"), }), - BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( - v.as_ref().map(|x| x.len() as i32), - ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), - )), - _ => unreachable!(), - }, - }), BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::repeat::)(args) @@ -1160,31 +1147,6 @@ mod tests { Utf8, StringArray ); - test_function!( - OctetLength, - &[lit("chars")], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - test_function!( - OctetLength, - &[lit("josé")], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - test_function!(OctetLength, &[lit("")], Ok(Some(0)), i32, Int32, Int32Array); - test_function!( - OctetLength, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); test_function!( Repeat, &[lit("Pg"), lit(ScalarValue::Int64(Some(4))),], diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 6877fb18ad4f..13e4ce77e0ac 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ 
b/datafusion/physical-expr/src/string_expressions.rs @@ -41,80 +41,6 @@ use datafusion_common::{ }; use datafusion_expr::ColumnarValue; -/// applies a unary expression to `args[0]` that is expected to be downcastable to -/// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) -/// # Errors -/// This function errors when: -/// * the number of arguments is not 1 -/// * the first argument is not castable to a `GenericStringArray` -pub(crate) fn unary_string_function<'a, T, O, F, R>( - args: &[&'a dyn Array], - op: F, - name: &str, -) -> Result> -where - R: AsRef, - O: OffsetSizeTrait, - T: OffsetSizeTrait, - F: Fn(&'a str) -> R, -{ - if args.len() != 1 { - return exec_err!( - "{:?} args were supplied but {} takes exactly one argument", - args.len(), - name - ); - } - - let string_array = as_generic_string_array::(args[0])?; - - // first map is the iterator, second is for the `Option<_>` - Ok(string_array.iter().map(|string| string.map(&op)).collect()) -} - -fn handle<'a, F, R>(args: &'a [ColumnarValue], op: F, name: &str) -> Result -where - R: AsRef, - F: Fn(&'a str) -> R, -{ - match &args[0] { - ColumnarValue::Array(a) => match a.data_type() { - DataType::Utf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i32, - i32, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - DataType::LargeUtf8 => { - Ok(ColumnarValue::Array(Arc::new(unary_string_function::< - i64, - i64, - _, - _, - >( - &[a.as_ref()], op, name - )?))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) - } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| (op)(x).as_ref().to_string()); - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) - } - other => exec_err!("Unsupported data type {other:?} for function {name}"), - }, - } -} - /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' pub fn chr(args: &[ArrayRef]) -> Result { @@ -319,12 +245,6 @@ pub fn instr(args: &[ArrayRef]) -> Result { } } -/// Converts the string to all lower case. -/// lower('TOM') = 'tom' -pub fn lower(args: &[ColumnarValue]) -> Result { - handle(args, |string| string.to_lowercase(), "lower") -} - /// Repeats string the specified number of times. /// repeat('Pg', 4) = 'PgPgPgPg' pub fn repeat(args: &[ArrayRef]) -> Result { @@ -414,12 +334,6 @@ pub fn ends_with(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } -/// Converts the string to all upper case. 
-/// upper('tom') = 'TOM' -pub fn upper(args: &[ColumnarValue]) -> Result { - handle(args, |string| string.to_uppercase(), "upper") -} - /// Prints random (v4) uuid values per row /// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' pub fn uuid(args: &[ColumnarValue]) -> Result { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 3724eb9be4ad..e4953283b184 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -574,11 +574,11 @@ enum ScalarFunction { InitCap = 30; Left = 31; Lpad = 32; - Lower = 33; + // 33 was Lower // 34 was Ltrim // 35 was MD5 // 36 was NullIf - OctetLength = 37; + // 37 was OctetLength Random = 38; // 39 was RegexpReplace Repeat = 40; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 90b0e22c779d..7cdebdf85944 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22936,8 +22936,6 @@ impl serde::Serialize for ScalarFunction { Self::InitCap => "InitCap", Self::Left => "Left", Self::Lpad => "Lpad", - Self::Lower => "Lower", - Self::OctetLength => "OctetLength", Self::Random => "Random", Self::Repeat => "Repeat", Self::Replace => "Replace", @@ -23006,8 +23004,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "InitCap", "Left", "Lpad", - "Lower", - "OctetLength", "Random", "Repeat", "Replace", @@ -23105,8 +23101,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "InitCap" => Ok(ScalarFunction::InitCap), "Left" => Ok(ScalarFunction::Left), "Lpad" => Ok(ScalarFunction::Lpad), - "Lower" => Ok(ScalarFunction::Lower), - "OctetLength" => Ok(ScalarFunction::OctetLength), "Random" => Ok(ScalarFunction::Random), "Repeat" => Ok(ScalarFunction::Repeat), "Replace" => Ok(ScalarFunction::Replace), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 09e2fa07c877..2932bcf6d93f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2873,11 +2873,11 @@ pub enum ScalarFunction { InitCap = 30, Left = 31, Lpad = 32, - Lower = 33, + /// 33 was Lower /// 34 was Ltrim /// 35 was MD5 /// 36 was NullIf - OctetLength = 37, + /// 37 was OctetLength Random = 38, /// 39 was RegexpReplace Repeat = 40, @@ -3009,8 +3009,6 @@ impl ScalarFunction { ScalarFunction::InitCap => "InitCap", ScalarFunction::Left => "Left", ScalarFunction::Lpad => "Lpad", - ScalarFunction::Lower => "Lower", - ScalarFunction::OctetLength => "OctetLength", ScalarFunction::Random => "Random", ScalarFunction::Repeat => "Repeat", ScalarFunction::Replace => "Replace", @@ -3073,8 +3071,6 @@ impl ScalarFunction { "InitCap" => Some(Self::InitCap), "Left" => Some(Self::Left), "Lpad" => Some(Self::Lpad), - "Lower" => Some(Self::Lower), - "OctetLength" => Some(Self::OctetLength), "Random" => Some(Self::Random), "Repeat" => Some(Self::Repeat), "Replace" => Some(Self::Replace), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index fc39df6a815b..d00aeeda462b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -54,11 +54,11 @@ use datafusion_expr::{ factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, levenshtein, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lower, lpad, nanvl, octet_length, overlay, pi, power, radians, random, repeat, - replace, reverse, right, round, rpad, 
signum, sin, sinh, split_part, sqrt, strpos, - substr, substr_index, substring, translate, trunc, uuid, AggregateFunction, Between, - BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, - GetFieldAccess, GetIndexedField, GroupingSet, + lpad, nanvl, overlay, pi, power, radians, random, repeat, replace, reverse, right, + round, rpad, signum, sin, sinh, split_part, sqrt, strpos, substr, substr_index, + substring, translate, trunc, uuid, AggregateFunction, Between, BinaryExpr, + BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, + GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -457,9 +457,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Ceil => Self::Ceil, ScalarFunction::Round => Self::Round, ScalarFunction::Trunc => Self::Trunc, - ScalarFunction::OctetLength => Self::OctetLength, ScalarFunction::Concat => Self::Concat, - ScalarFunction::Lower => Self::Lower, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, ScalarFunction::BitLength => Self::BitLength, @@ -1428,12 +1426,6 @@ pub fn parse_expr( ScalarFunction::Signum => { Ok(signum(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::OctetLength => { - Ok(octet_length(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Lower => { - Ok(lower(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::BitLength => { Ok(bit_length(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index a774444960f3..edb8c4e4eb01 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1478,9 +1478,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Ceil => Self::Ceil, BuiltinScalarFunction::Round => Self::Round, BuiltinScalarFunction::Trunc => Self::Trunc, - BuiltinScalarFunction::OctetLength => Self::OctetLength, BuiltinScalarFunction::Concat => Self::Concat, - BuiltinScalarFunction::Lower => Self::Lower, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, BuiltinScalarFunction::BitLength => Self::BitLength, From 01ff53771a5e866813ef9636e3f7eec6b88ce4a4 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sat, 23 Mar 2024 08:20:47 -0400 Subject: [PATCH 049/117] Fixed missing trim() in rust api (#9749) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Fixed missing trim() function. 
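[Editorial sketch] To make the fix above concrete, a hedged usage example; the enclosing function and column name are hypothetical and not part of this patch. The restored `trim` builder simply forwards its argument list to `btrim`, so both the one- and two-argument SQL forms become reachable from the Rust API:

    // Minimal sketch, assuming the datafusion-functions crate is in scope.
    use datafusion_expr::{col, lit, Expr};
    use datafusion_functions::string::expr_fn::trim;

    fn trim_examples() -> (Expr, Expr) {
        // trim(col) behaves like btrim(col): strips spaces from both sides
        let spaces = trim(vec![col("name")]);
        // trim(col, chars) behaves like btrim(col, chars): strips the given characters
        let custom = trim(vec![col("name"), lit("xy")]);
        (spaces, custom)
    }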
--- datafusion/functions/src/string/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index a70a695e935b..a6d844932655 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -86,6 +86,11 @@ pub mod expr_fn { super::to_hex().call(vec![arg1]) } + #[doc = "Removes all characters, spaces by default, from both sides of a string"] + pub fn trim(args: Vec) -> Expr { + super::btrim().call(args) + } + #[doc = "Converts a string to uppercase."] pub fn upper(arg1: Expr) -> Expr { super::upper().call(vec![arg1]) From 6eda0e2e35c99b64e107f860598085b8e06b2277 Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Sat, 23 Mar 2024 17:57:13 -0700 Subject: [PATCH 050/117] Issue-9750 - Extract array_length, array_reverse and array_sort functions from functions-array subcrate' s kernels and udf containers (#9751) --- datafusion/functions-array/src/kernels.rs | 232 +--------------------- datafusion/functions-array/src/length.rs | 158 +++++++++++++++ datafusion/functions-array/src/lib.rs | 15 +- datafusion/functions-array/src/range.rs | 19 +- datafusion/functions-array/src/reverse.rs | 147 ++++++++++++++ datafusion/functions-array/src/sort.rs | 176 ++++++++++++++++ datafusion/functions-array/src/udf.rs | 167 ---------------- 7 files changed, 507 insertions(+), 407 deletions(-) create mode 100644 datafusion/functions-array/src/length.rs create mode 100644 datafusion/functions-array/src/reverse.rs create mode 100644 datafusion/functions-array/src/sort.rs diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index 9b743fa913cf..1fb3abd52906 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/kernels.rs @@ -18,19 +18,18 @@ //! implementation kernels for array functions use arrow::array::{ - Array, ArrayRef, BooleanArray, Capacities, GenericListArray, Int64Array, - LargeListArray, ListArray, MutableArrayData, OffsetSizeTrait, UInt64Array, + Array, ArrayRef, BooleanArray, Capacities, GenericListArray, Int64Array, ListArray, + MutableArrayData, OffsetSizeTrait, UInt64Array, }; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; use arrow_array::new_null_array; -use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use arrow_buffer::{ArrowNativeType, OffsetBuffer}; use arrow_schema::FieldRef; -use arrow_schema::SortOptions; use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, - as_null_array, as_string_array, + as_null_array, }; use datafusion_common::{ exec_err, internal_datafusion_err, DataFusionError, Result, ScalarValue, @@ -190,64 +189,6 @@ fn general_array_empty(array: &ArrayRef) -> Result Ok(Arc::new(builder)) } -/// Returns the length of a concrete array dimension -fn compute_array_length( - arr: Option, - dimension: Option, -) -> Result> { - let mut current_dimension: i64 = 1; - let mut value = match arr { - Some(arr) => arr, - None => return Ok(None), - }; - let dimension = match dimension { - Some(value) => { - if value < 1 { - return Ok(None); - } - - value - } - None => return Ok(None), - }; - - loop { - if current_dimension == dimension { - return Ok(Some(value.len() as u64)); - } - - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - current_dimension += 1; - } - DataType::LargeList(..) 
=> { - value = downcast_arg!(value, LargeListArray).value(0); - current_dimension += 1; - } - _ => return Ok(None), - } - } -} - -/// Dispatch array length computation based on the offset type. -fn general_array_length(array: &[ArrayRef]) -> Result { - let list_array = as_generic_list_array::(&array[0])?; - let dimension = if array.len() == 2 { - as_int64_array(&array[1])?.clone() - } else { - Int64Array::from_value(1, list_array.len()) - }; - - let result = list_array - .iter() - .zip(dimension.iter()) - .map(|(arr, dim)| compute_array_length(arr, dim)) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) -} - /// Array_repeat SQL function pub fn array_repeat(args: &[ArrayRef]) -> Result { if args.len() != 2 { @@ -394,19 +335,6 @@ fn general_list_repeat( )?)) } -/// Array_length SQL function -pub fn array_length(args: &[ArrayRef]) -> Result { - if args.len() != 1 && args.len() != 2 { - return exec_err!("array_length expects one or two arguments"); - } - - match &args[0].data_type() { - DataType::List(_) => general_array_length::(args), - DataType::LargeList(_) => general_array_length::(args), - array_type => exec_err!("array_length does not support type '{array_type:?}'"), - } -} - /// array_resize SQL function pub fn array_resize(arg: &[ArrayRef]) -> Result { if arg.len() < 2 || arg.len() > 3 { @@ -501,89 +429,6 @@ where )?)) } -/// Array_sort SQL function -pub fn array_sort(args: &[ArrayRef]) -> Result { - if args.is_empty() || args.len() > 3 { - return exec_err!("array_sort expects one to three arguments"); - } - - let sort_option = match args.len() { - 1 => None, - 2 => { - let sort = as_string_array(&args[1])?.value(0); - Some(SortOptions { - descending: order_desc(sort)?, - nulls_first: true, - }) - } - 3 => { - let sort = as_string_array(&args[1])?.value(0); - let nulls_first = as_string_array(&args[2])?.value(0); - Some(SortOptions { - descending: order_desc(sort)?, - nulls_first: order_nulls_first(nulls_first)?, - }) - } - _ => return exec_err!("array_sort expects 1 to 3 arguments"), - }; - - let list_array = as_list_array(&args[0])?; - let row_count = list_array.len(); - - let mut array_lengths = vec![]; - let mut arrays = vec![]; - let mut valid = BooleanBufferBuilder::new(row_count); - for i in 0..row_count { - if list_array.is_null(i) { - array_lengths.push(0); - valid.append(false); - } else { - let arr_ref = list_array.value(i); - let arr_ref = arr_ref.as_ref(); - - let sorted_array = compute::sort(arr_ref, sort_option)?; - array_lengths.push(sorted_array.len()); - arrays.push(sorted_array); - valid.append(true); - } - } - - // Assume all arrays have the same data type - let data_type = list_array.value_type(); - let buffer = valid.finish(); - - let elements = arrays - .iter() - .map(|a| a.as_ref()) - .collect::>(); - - let list_arr = ListArray::new( - Arc::new(Field::new("item", data_type, true)), - OffsetBuffer::from_lengths(array_lengths), - Arc::new(compute::concat(elements.as_slice())?), - Some(NullBuffer::new(buffer)), - ); - Ok(Arc::new(list_arr)) -} - -fn order_desc(modifier: &str) -> Result { - match modifier.to_uppercase().as_str() { - "DESC" => Ok(true), - "ASC" => Ok(false), - _ => exec_err!("the second parameter of array_sort expects DESC or ASC"), - } -} - -fn order_nulls_first(modifier: &str) -> Result { - match modifier.to_uppercase().as_str() { - "NULLS FIRST" => Ok(true), - "NULLS LAST" => Ok(false), - _ => exec_err!( - "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" - ), - } -} - // Create new offsets that are euqiavlent to 
`flatten` the array. fn get_offsets_for_flatten( offsets: OffsetBuffer, @@ -652,72 +497,3 @@ pub fn flatten(args: &[ArrayRef]) -> Result { } } } - -/// array_reverse SQL function -pub fn array_reverse(arg: &[ArrayRef]) -> Result { - if arg.len() != 1 { - return exec_err!("array_reverse needs one argument"); - } - - match &arg[0].data_type() { - DataType::List(field) => { - let array = as_list_array(&arg[0])?; - general_array_reverse::(array, field) - } - DataType::LargeList(field) => { - let array = as_large_list_array(&arg[0])?; - general_array_reverse::(array, field) - } - DataType::Null => Ok(arg[0].clone()), - array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), - } -} - -fn general_array_reverse( - array: &GenericListArray, - field: &FieldRef, -) -> Result -where - O: TryFrom, -{ - let values = array.values(); - let original_data = values.to_data(); - let capacity = Capacities::Array(original_data.len()); - let mut offsets = vec![O::usize_as(0)]; - let mut nulls = vec![]; - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], false, capacity); - - for (row_index, offset_window) in array.offsets().windows(2).enumerate() { - // skip the null value - if array.is_null(row_index) { - nulls.push(false); - offsets.push(offsets[row_index] + O::one()); - mutable.extend(0, 0, 1); - continue; - } else { - nulls.push(true); - } - - let start = offset_window[0]; - let end = offset_window[1]; - - let mut index = end - O::one(); - let mut cnt = 0; - - while index >= start { - mutable.extend(0, index.to_usize().unwrap(), index.to_usize().unwrap() + 1); - index = index - O::one(); - cnt += 1; - } - offsets.push(offsets[row_index] + O::usize_as(cnt)); - } - - let data = mutable.freeze(); - Ok(Arc::new(GenericListArray::::try_new( - field.clone(), - OffsetBuffer::::new(offsets.into()), - arrow_array::make_array(data), - Some(nulls.into()), - )?)) -} diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-array/src/length.rs new file mode 100644 index 000000000000..e8e361131763 --- /dev/null +++ b/datafusion/functions-array/src/length.rs @@ -0,0 +1,158 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_length function. 
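+// [Editorial sketch, not part of this commit] A hedged usage example of the
+// relocated builder; the column name is hypothetical:
+//
+//     use datafusion_expr::col;
+//     use datafusion_functions_array::expr_fn::array_length;
+//     // array_length([1, 2, 3]) = 3 (first dimension when no dim is given,
+//     // per general_array_length below)
+//     let len = array_length(col("values"));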
+ +use crate::utils::{downcast_arg, make_scalar_function}; +use arrow_array::{ + Array, ArrayRef, Int64Array, LargeListArray, ListArray, OffsetSizeTrait, UInt64Array, +}; +use arrow_schema::DataType; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; +use core::any::type_name; +use datafusion_common::cast::{as_generic_list_array, as_int64_array}; +use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayLength, + array_length, + array, + "returns the length of the array dimension.", + array_length_udf +); + +#[derive(Debug)] +pub(super) struct ArrayLength { + signature: Signature, + aliases: Vec, +} +impl ArrayLength { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("array_length"), String::from("list_length")], + } + } +} + +impl ScalarUDFImpl for ArrayLength { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + _ => { + return plan_err!("The array_length function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_length_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_length SQL function +pub fn array_length_inner(args: &[ArrayRef]) -> datafusion_common::Result { + if args.len() != 1 && args.len() != 2 { + return exec_err!("array_length expects one or two arguments"); + } + + match &args[0].data_type() { + List(_) => general_array_length::(args), + LargeList(_) => general_array_length::(args), + array_type => exec_err!("array_length does not support type '{array_type:?}'"), + } +} + +/// Dispatch array length computation based on the offset type. +fn general_array_length( + array: &[ArrayRef], +) -> datafusion_common::Result { + let list_array = as_generic_list_array::(&array[0])?; + let dimension = if array.len() == 2 { + as_int64_array(&array[1])?.clone() + } else { + Int64Array::from_value(1, list_array.len()) + }; + + let result = list_array + .iter() + .zip(dimension.iter()) + .map(|(arr, dim)| compute_array_length(arr, dim)) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Returns the length of a concrete array dimension +fn compute_array_length( + arr: Option, + dimension: Option, +) -> datafusion_common::Result> { + let mut current_dimension: i64 = 1; + let mut value = match arr { + Some(arr) => arr, + None => return Ok(None), + }; + let dimension = match dimension { + Some(value) => { + if value < 1 { + return Ok(None); + } + + value + } + None => return Ok(None), + }; + + loop { + if current_dimension == dimension { + return Ok(Some(value.len() as u64)); + } + + match value.data_type() { + List(..) => { + value = downcast_arg!(value, ListArray).value(0); + current_dimension += 1; + } + LargeList(..) 
=> { + value = downcast_arg!(value, LargeListArray).value(0); + current_dimension += 1; + } + _ => return Ok(None), + } + } +} diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 6ed77e5d170c..f4ca5408aa6b 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -34,12 +34,15 @@ mod core; mod except; mod extract; mod kernels; +mod length; mod position; mod range; mod remove; mod replace; +mod reverse; mod rewrite; mod set_ops; +mod sort; mod string; mod udf; mod utils; @@ -64,6 +67,7 @@ pub mod expr_fn { pub use super::extract::array_pop_back; pub use super::extract::array_pop_front; pub use super::extract::array_slice; + pub use super::length::array_length; pub use super::position::array_position; pub use super::position::array_positions; pub use super::range::gen_series; @@ -74,19 +78,18 @@ pub mod expr_fn { pub use super::replace::array_replace; pub use super::replace::array_replace_all; pub use super::replace::array_replace_n; + pub use super::reverse::array_reverse; pub use super::set_ops::array_distinct; pub use super::set_ops::array_intersect; pub use super::set_ops::array_union; + pub use super::sort::array_sort; pub use super::string::array_to_string; pub use super::string::string_to_array; pub use super::udf::array_dims; pub use super::udf::array_empty; - pub use super::udf::array_length; pub use super::udf::array_ndims; pub use super::udf::array_repeat; pub use super::udf::array_resize; - pub use super::udf::array_reverse; - pub use super::udf::array_sort; pub use super::udf::cardinality; pub use super::udf::flatten; } @@ -114,12 +117,12 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { array_has::array_has_all_udf(), array_has::array_has_any_udf(), udf::array_empty_udf(), - udf::array_length_udf(), + length::array_length_udf(), udf::flatten_udf(), - udf::array_sort_udf(), + sort::array_sort_udf(), udf::array_repeat_udf(), udf::array_resize_udf(), - udf::array_reverse_udf(), + reverse::array_reverse_udf(), set_ops::array_distinct_udf(), set_ops::array_intersect_udf(), set_ops::array_union_udf(), diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-array/src/range.rs index 7dfce71332a1..176a5617d599 100644 --- a/datafusion/functions-array/src/range.rs +++ b/datafusion/functions-array/src/range.rs @@ -22,6 +22,7 @@ use arrow::datatypes::{DataType, Field}; use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use std::any::Any; +use crate::utils::make_scalar_function; use arrow_array::types::{Date32Type, IntervalMonthDayNanoType}; use arrow_array::Date32Array; use arrow_schema::IntervalUnit::MonthDayNano; @@ -85,10 +86,13 @@ impl ScalarUDFImpl for Range { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; match args[0].data_type() { - DataType::Int64 => gen_range_inner(&args, false).map(ColumnarValue::Array), - DataType::Date32 => gen_range_date(&args, false).map(ColumnarValue::Array), + DataType::Int64 => { + make_scalar_function(|args| gen_range_inner(args, false))(args) + } + DataType::Date32 => { + make_scalar_function(|args| gen_range_date(args, false))(args) + } _ => { exec_err!("unsupported type for range") } @@ -151,10 +155,13 @@ impl ScalarUDFImpl for GenSeries { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; match args[0].data_type() { - DataType::Int64 => gen_range_inner(&args, 
true).map(ColumnarValue::Array), - DataType::Date32 => gen_range_date(&args, true).map(ColumnarValue::Array), + DataType::Int64 => { + make_scalar_function(|args| gen_range_inner(args, true))(args) + } + DataType::Date32 => { + make_scalar_function(|args| gen_range_date(args, true))(args) + } _ => { exec_err!("unsupported type for range") } diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-array/src/reverse.rs new file mode 100644 index 000000000000..7eb9e53deef4 --- /dev/null +++ b/datafusion/functions-array/src/reverse.rs @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_reverse function. + +use crate::utils::make_scalar_function; +use arrow::array::{Capacities, MutableArrayData}; +use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; +use arrow_buffer::OffsetBuffer; +use arrow_schema::{DataType, FieldRef}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::exec_err; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayReverse, + array_reverse, + array, + "reverses the order of elements in the array.", + array_reverse_udf +); + +#[derive(Debug)] +pub(super) struct ArrayReverse { + signature: Signature, + aliases: Vec, +} + +impl ArrayReverse { + pub fn new() -> Self { + Self { + signature: Signature::any(1, Volatility::Immutable), + aliases: vec!["array_reverse".to_string(), "list_reverse".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayReverse { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_reverse" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(arg_types[0].clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_reverse_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// array_reverse SQL function +pub fn array_reverse_inner(arg: &[ArrayRef]) -> datafusion_common::Result { + if arg.len() != 1 { + return exec_err!("array_reverse needs one argument"); + } + + match &arg[0].data_type() { + DataType::List(field) => { + let array = as_list_array(&arg[0])?; + general_array_reverse::(array, field) + } + DataType::LargeList(field) => { + let array = as_large_list_array(&arg[0])?; + general_array_reverse::(array, field) + } + DataType::Null => Ok(arg[0].clone()), + array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), + } +} + +fn general_array_reverse( + array: &GenericListArray, + field: 
&FieldRef, +) -> datafusion_common::Result +where + O: TryFrom, +{ + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut offsets = vec![O::usize_as(0)]; + let mut nulls = vec![]; + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + // skip the null value + if array.is_null(row_index) { + nulls.push(false); + offsets.push(offsets[row_index] + O::one()); + mutable.extend(0, 0, 1); + continue; + } else { + nulls.push(true); + } + + let start = offset_window[0]; + let end = offset_window[1]; + + let mut index = end - O::one(); + let mut cnt = 0; + + while index >= start { + mutable.extend(0, index.to_usize().unwrap(), index.to_usize().unwrap() + 1); + index = index - O::one(); + cnt += 1; + } + offsets.push(offsets[row_index] + O::usize_as(cnt)); + } + + let data = mutable.freeze(); + Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + Some(nulls.into()), + )?)) +} diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-array/src/sort.rs new file mode 100644 index 000000000000..2f3fa33e6857 --- /dev/null +++ b/datafusion/functions-array/src/sort.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_sort function. 
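+// [Editorial sketch, not part of this commit] A hedged usage example; the
+// column name is hypothetical. The second and third arguments must be the
+// literal strings accepted by order_desc / order_nulls_first below:
+//
+//     use datafusion_expr::{col, lit};
+//     use datafusion_functions_array::expr_fn::array_sort;
+//     // "ASC" | "DESC", then "NULLS FIRST" | "NULLS LAST"
+//     let sorted = array_sort(col("values"), lit("DESC"), lit("NULLS LAST"));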
+ +use crate::utils::make_scalar_function; +use arrow::compute; +use arrow_array::{Array, ArrayRef, ListArray}; +use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List}; +use arrow_schema::{DataType, Field, SortOptions}; +use datafusion_common::cast::{as_list_array, as_string_array}; +use datafusion_common::exec_err; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArraySort, + array_sort, + array desc null_first, + "returns sorted array.", + array_sort_udf +); + +#[derive(Debug)] +pub(super) struct ArraySort { + signature: Signature, + aliases: Vec, +} + +impl ArraySort { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec!["array_sort".to_string(), "list_sort".to_string()], + } + } +} + +impl ScalarUDFImpl for ArraySort { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_sort" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + match &arg_types[0] { + List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + LargeList(field) => Ok(LargeList(Arc::new(Field::new( + "item", + field.data_type().clone(), + true, + )))), + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + ), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_sort_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_sort SQL function +pub fn array_sort_inner(args: &[ArrayRef]) -> datafusion_common::Result { + if args.is_empty() || args.len() > 3 { + return exec_err!("array_sort expects one to three arguments"); + } + + let sort_option = match args.len() { + 1 => None, + 2 => { + let sort = as_string_array(&args[1])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: true, + }) + } + 3 => { + let sort = as_string_array(&args[1])?.value(0); + let nulls_first = as_string_array(&args[2])?.value(0); + Some(SortOptions { + descending: order_desc(sort)?, + nulls_first: order_nulls_first(nulls_first)?, + }) + } + _ => return exec_err!("array_sort expects 1 to 3 arguments"), + }; + + let list_array = as_list_array(&args[0])?; + let row_count = list_array.len(); + + let mut array_lengths = vec![]; + let mut arrays = vec![]; + let mut valid = BooleanBufferBuilder::new(row_count); + for i in 0..row_count { + if list_array.is_null(i) { + array_lengths.push(0); + valid.append(false); + } else { + let arr_ref = list_array.value(i); + let arr_ref = arr_ref.as_ref(); + + let sorted_array = compute::sort(arr_ref, sort_option)?; + array_lengths.push(sorted_array.len()); + arrays.push(sorted_array); + valid.append(true); + } + } + + // Assume all arrays have the same data type + let data_type = list_array.value_type(); + let buffer = valid.finish(); + + let elements = arrays + .iter() + .map(|a| a.as_ref()) + .collect::>(); + + let list_arr = ListArray::new( + Arc::new(Field::new("item", data_type, true)), + OffsetBuffer::from_lengths(array_lengths), + Arc::new(compute::concat(elements.as_slice())?), + Some(NullBuffer::new(buffer)), + ); + Ok(Arc::new(list_arr)) +} + +fn order_desc(modifier: &str) -> 
datafusion_common::Result { + match modifier.to_uppercase().as_str() { + "DESC" => Ok(true), + "ASC" => Ok(false), + _ => exec_err!("the second parameter of array_sort expects DESC or ASC"), + } +} + +fn order_nulls_first(modifier: &str) -> datafusion_common::Result { + match modifier.to_uppercase().as_str() { + "NULLS FIRST" => Ok(true), + "NULLS LAST" => Ok(false), + _ => exec_err!( + "the third parameter of array_sort expects NULLS FIRST or NULLS LAST" + ), + } +} diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index 156703105766..943950447bc4 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -86,70 +86,6 @@ impl ScalarUDFImpl for ArrayDims { } } -make_udf_function!( - ArraySort, - array_sort, - array desc null_first, - "returns sorted array.", - array_sort_udf -); - -#[derive(Debug)] -pub(super) struct ArraySort { - signature: Signature, - aliases: Vec, -} - -impl ArraySort { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec!["array_sort".to_string(), "list_sort".to_string()], - } - } -} - -impl ScalarUDFImpl for ArraySort { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_sort" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - LargeList(field) => Ok(LargeList(Arc::new(Field::new( - "item", - field.data_type().clone(), - true, - )))), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_sort(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - make_udf_function!( ArrayResize, array_resize, @@ -420,60 +356,6 @@ impl ScalarUDFImpl for ArrayRepeat { } } -make_udf_function!( - ArrayLength, - array_length, - array, - "returns the length of the array dimension.", - array_length_udf -); - -#[derive(Debug)] -pub(super) struct ArrayLength { - signature: Signature, - aliases: Vec, -} -impl ArrayLength { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_length"), String::from("list_length")], - } - } -} - -impl ScalarUDFImpl for ArrayLength { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_length" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The array_length function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_length(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - make_udf_function!( Flatten, flatten, @@ -541,52 +423,3 @@ impl ScalarUDFImpl for Flatten { &self.aliases } } - -make_udf_function!( - ArrayReverse, - array_reverse, - array, - "reverses the order of elements in the array.", - array_reverse_udf -); - 
-#[derive(Debug)] -pub(super) struct ArrayReverse { - signature: Signature, - aliases: Vec, -} - -impl crate::udf::ArrayReverse { - pub fn new() -> Self { - Self { - signature: Signature::any(1, Volatility::Immutable), - aliases: vec!["array_reverse".to_string(), "list_reverse".to_string()], - } - } -} - -impl ScalarUDFImpl for crate::udf::ArrayReverse { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_reserse" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(arg_types[0].clone()) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_reverse(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} From f5c5343127649c787741f05b53eb3990cc770b3f Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Sat, 23 Mar 2024 20:31:29 -0700 Subject: [PATCH 051/117] Issue-9761 - Extract array_empty and array_repeat functions from functions-array subcrate' s kernels and udf containers (#9762) --- datafusion/functions-array/src/empty.rs | 111 +++++++++++ datafusion/functions-array/src/kernels.rs | 181 +---------------- datafusion/functions-array/src/lib.rs | 10 +- datafusion/functions-array/src/repeat.rs | 232 ++++++++++++++++++++++ datafusion/functions-array/src/udf.rs | 107 ---------- 5 files changed, 351 insertions(+), 290 deletions(-) create mode 100644 datafusion/functions-array/src/empty.rs create mode 100644 datafusion/functions-array/src/repeat.rs diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-array/src/empty.rs new file mode 100644 index 000000000000..37b247deb4c8 --- /dev/null +++ b/datafusion/functions-array/src/empty.rs @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_empty function. 
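+// [Editorial sketch, not part of this commit] A hedged usage example; the
+// column name is hypothetical. Per general_array_empty below, the result is
+// true when a list has no elements or contains only null elements:
+//
+//     use datafusion_expr::col;
+//     use datafusion_functions_array::expr_fn::array_empty;
+//     let is_empty = array_empty(col("values"));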
+ +use crate::utils::make_scalar_function; +use arrow_array::{ArrayRef, BooleanArray, OffsetSizeTrait}; +use arrow_schema::DataType; +use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List}; +use datafusion_common::cast::{as_generic_list_array, as_null_array}; +use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayEmpty, + array_empty, + array, + "returns true for an empty array or false for a non-empty array.", + array_empty_udf +); + +#[derive(Debug)] +pub(super) struct ArrayEmpty { + signature: Signature, + aliases: Vec, +} +impl ArrayEmpty { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("empty")], + } + } +} + +impl ScalarUDFImpl for ArrayEmpty { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "empty" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, + _ => { + return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_empty_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_empty SQL function +pub fn array_empty_inner(args: &[ArrayRef]) -> datafusion_common::Result { + if args.len() != 1 { + return exec_err!("array_empty expects one argument"); + } + + if as_null_array(&args[0]).is_ok() { + // Make sure to return Boolean type. + return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); + } + let array_type = args[0].data_type(); + + match array_type { + List(_) => general_array_empty::(&args[0]), + LargeList(_) => general_array_empty::(&args[0]), + _ => exec_err!("array_empty does not support type '{array_type:?}'."), + } +} + +fn general_array_empty( + array: &ArrayRef, +) -> datafusion_common::Result { + let array = as_generic_list_array::(array)?; + let builder = array + .iter() + .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) + .collect::(); + Ok(Arc::new(builder)) +} diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index 1fb3abd52906..4745db0170ed 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/kernels.rs @@ -18,18 +18,15 @@ //! 
implementation kernels for array functions use arrow::array::{ - Array, ArrayRef, BooleanArray, Capacities, GenericListArray, Int64Array, ListArray, + Array, ArrayRef, Capacities, GenericListArray, Int64Array, ListArray, MutableArrayData, OffsetSizeTrait, UInt64Array, }; -use arrow::compute; -use arrow::datatypes::{DataType, Field, UInt64Type}; -use arrow_array::new_null_array; +use arrow::datatypes::{DataType, UInt64Type}; use arrow_buffer::{ArrowNativeType, OffsetBuffer}; use arrow_schema::FieldRef; use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, - as_null_array, }; use datafusion_common::{ exec_err, internal_datafusion_err, DataFusionError, Result, ScalarValue, @@ -161,180 +158,6 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { } } -/// Array_empty SQL function -pub fn array_empty(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_empty expects one argument"); - } - - if as_null_array(&args[0]).is_ok() { - // Make sure to return Boolean type. - return Ok(Arc::new(BooleanArray::new_null(args[0].len()))); - } - let array_type = args[0].data_type(); - - match array_type { - DataType::List(_) => general_array_empty::(&args[0]), - DataType::LargeList(_) => general_array_empty::(&args[0]), - _ => exec_err!("array_empty does not support type '{array_type:?}'."), - } -} - -fn general_array_empty(array: &ArrayRef) -> Result { - let array = as_generic_list_array::(array)?; - let builder = array - .iter() - .map(|arr| arr.map(|arr| arr.len() == arr.null_count())) - .collect::(); - Ok(Arc::new(builder)) -} - -/// Array_repeat SQL function -pub fn array_repeat(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!("array_repeat expects two arguments"); - } - - let element = &args[0]; - let count_array = as_int64_array(&args[1])?; - - match element.data_type() { - DataType::List(_) => { - let list_array = as_list_array(element)?; - general_list_repeat::(list_array, count_array) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(element)?; - general_list_repeat::(list_array, count_array) - } - _ => general_repeat::(element, count_array), - } -} - -/// For each element of `array[i]` repeat `count_array[i]` times. -/// -/// Assumption for the input: -/// 1. `count[i] >= 0` -/// 2. 
`array.len() == count_array.len()` -/// -/// For example, -/// ```text -/// array_repeat( -/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] -/// ) -/// ``` -fn general_repeat( - array: &ArrayRef, - count_array: &Int64Array, -) -> Result { - let data_type = array.data_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (row_index, &count) in count_vec.iter().enumerate() { - let repeated_array = if array.is_null(row_index) { - new_null_array(data_type, count) - } else { - let original_data = array.to_data(); - let capacity = Capacities::Array(count); - let mut mutable = - MutableArrayData::with_capacities(vec![&original_data], false, capacity); - - for _ in 0..count { - mutable.extend(0, row_index, row_index + 1); - } - - let data = mutable.freeze(); - arrow_array::make_array(data) - }; - new_values.push(repeated_array); - } - - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; - - Ok(Arc::new(GenericListArray::::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), - OffsetBuffer::from_lengths(count_vec), - values, - None, - )?)) -} - -/// Handle List version of `general_repeat` -/// -/// For each element of `list_array[i]` repeat `count_array[i]` times. -/// -/// For example, -/// ```text -/// array_repeat( -/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] -/// ) -/// ``` -fn general_list_repeat( - list_array: &GenericListArray, - count_array: &Int64Array, -) -> Result { - let data_type = list_array.data_type(); - let value_type = list_array.value_type(); - let mut new_values = vec![]; - - let count_vec = count_array - .values() - .to_vec() - .iter() - .map(|x| *x as usize) - .collect::>(); - - for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { - let list_arr = match list_array_row { - Some(list_array_row) => { - let original_data = list_array_row.to_data(); - let capacity = Capacities::Array(original_data.len() * count); - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data], - false, - capacity, - ); - - for _ in 0..count { - mutable.extend(0, 0, original_data.len()); - } - - let data = mutable.freeze(); - let repeated_array = arrow_array::make_array(data); - - let list_arr = GenericListArray::::try_new( - Arc::new(Field::new("item", value_type.clone(), true)), - OffsetBuffer::::from_lengths(vec![original_data.len(); count]), - repeated_array, - None, - )?; - Arc::new(list_arr) as ArrayRef - } - None => new_null_array(data_type, count), - }; - new_values.push(list_arr); - } - - let lengths = new_values.iter().map(|a| a.len()).collect::>(); - let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = compute::concat(&new_values)?; - - Ok(Arc::new(ListArray::try_new( - Arc::new(Field::new("item", data_type.to_owned(), true)), - OffsetBuffer::::from_lengths(lengths), - values, - None, - )?)) -} - /// array_resize SQL function pub fn array_resize(arg: &[ArrayRef]) -> Result { if arg.len() < 2 || arg.len() > 3 { diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index f4ca5408aa6b..4a7bb3fda90d 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -31,6 +31,7 @@ pub mod macros; mod array_has; mod concat; mod core; +mod empty; mod except; mod extract; mod kernels; @@ -38,6 +39,7 @@ mod length; mod position; mod range; mod 
remove; +mod repeat; mod replace; mod reverse; mod rewrite; @@ -62,6 +64,7 @@ pub mod expr_fn { pub use super::concat::array_concat; pub use super::concat::array_prepend; pub use super::core::make_array; + pub use super::empty::array_empty; pub use super::except::array_except; pub use super::extract::array_element; pub use super::extract::array_pop_back; @@ -75,6 +78,7 @@ pub mod expr_fn { pub use super::remove::array_remove; pub use super::remove::array_remove_all; pub use super::remove::array_remove_n; + pub use super::repeat::array_repeat; pub use super::replace::array_replace; pub use super::replace::array_replace_all; pub use super::replace::array_replace_n; @@ -86,9 +90,7 @@ pub mod expr_fn { pub use super::string::array_to_string; pub use super::string::string_to_array; pub use super::udf::array_dims; - pub use super::udf::array_empty; pub use super::udf::array_ndims; - pub use super::udf::array_repeat; pub use super::udf::array_resize; pub use super::udf::cardinality; pub use super::udf::flatten; @@ -116,11 +118,11 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), - udf::array_empty_udf(), + empty::array_empty_udf(), length::array_length_udf(), udf::flatten_udf(), sort::array_sort_udf(), - udf::array_repeat_udf(), + repeat::array_repeat_udf(), udf::array_resize_udf(), reverse::array_reverse_udf(), set_ops::array_distinct_udf(), diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-array/src/repeat.rs new file mode 100644 index 000000000000..bf967f65724b --- /dev/null +++ b/datafusion/functions-array/src/repeat.rs @@ -0,0 +1,232 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_repeat function. 
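+// [Editorial sketch, not part of this commit] A hedged usage example of the
+// two-argument builder generated by make_udf_function! below; the literals
+// are hypothetical:
+//
+//     use datafusion_expr::lit;
+//     use datafusion_functions_array::expr_fn::array_repeat;
+//     // array_repeat(1, 3) = [1, 1, 1]
+//     let repeated = array_repeat(lit(1), lit(3));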
+ +use crate::utils::make_scalar_function; +use arrow::array::{Capacities, MutableArrayData}; +use arrow::compute; +use arrow_array::{ + new_null_array, Array, ArrayRef, GenericListArray, Int64Array, ListArray, + OffsetSizeTrait, +}; +use arrow_buffer::OffsetBuffer; +use arrow_schema::DataType::{LargeList, List}; +use arrow_schema::{DataType, Field}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::exec_err; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayRepeat, + array_repeat, + element count, // arg name + "returns an array containing element `count` times.", // doc + array_repeat_udf // internal function name +); +#[derive(Debug)] +pub(super) struct ArrayRepeat { + signature: Signature, + aliases: Vec, +} + +impl ArrayRepeat { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec![String::from("array_repeat"), String::from("list_repeat")], + } + } +} + +impl ScalarUDFImpl for ArrayRepeat { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_repeat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(List(Arc::new(Field::new( + "item", + arg_types[0].clone(), + true, + )))) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_repeat_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Array_repeat SQL function +pub fn array_repeat_inner(args: &[ArrayRef]) -> datafusion_common::Result { + if args.len() != 2 { + return exec_err!("array_repeat expects two arguments"); + } + + let element = &args[0]; + let count_array = as_int64_array(&args[1])?; + + match element.data_type() { + List(_) => { + let list_array = as_list_array(element)?; + general_list_repeat::(list_array, count_array) + } + LargeList(_) => { + let list_array = as_large_list_array(element)?; + general_list_repeat::(list_array, count_array) + } + _ => general_repeat::(element, count_array), + } +} + +/// For each element of `array[i]` repeat `count_array[i]` times. +/// +/// Assumption for the input: +/// 1. `count[i] >= 0` +/// 2. 
`array.len() == count_array.len()` +/// +/// For example, +/// ```text +/// array_repeat( +/// [1, 2, 3], [2, 0, 1] => [[1, 1], [], [3]] +/// ) +/// ``` +fn general_repeat( + array: &ArrayRef, + count_array: &Int64Array, +) -> datafusion_common::Result { + let data_type = array.data_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (row_index, &count) in count_vec.iter().enumerate() { + let repeated_array = if array.is_null(row_index) { + new_null_array(data_type, count) + } else { + let original_data = array.to_data(); + let capacity = Capacities::Array(count); + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + + for _ in 0..count { + mutable.extend(0, row_index, row_index + 1); + } + + let data = mutable.freeze(); + arrow_array::make_array(data) + }; + new_values.push(repeated_array); + } + + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(GenericListArray::::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::from_lengths(count_vec), + values, + None, + )?)) +} + +/// Handle List version of `general_repeat` +/// +/// For each element of `list_array[i]` repeat `count_array[i]` times. +/// +/// For example, +/// ```text +/// array_repeat( +/// [[1, 2, 3], [4, 5], [6]], [2, 0, 1] => [[[1, 2, 3], [1, 2, 3]], [], [[6]]] +/// ) +/// ``` +fn general_list_repeat( + list_array: &GenericListArray, + count_array: &Int64Array, +) -> datafusion_common::Result { + let data_type = list_array.data_type(); + let value_type = list_array.value_type(); + let mut new_values = vec![]; + + let count_vec = count_array + .values() + .to_vec() + .iter() + .map(|x| *x as usize) + .collect::>(); + + for (list_array_row, &count) in list_array.iter().zip(count_vec.iter()) { + let list_arr = match list_array_row { + Some(list_array_row) => { + let original_data = list_array_row.to_data(); + let capacity = Capacities::Array(original_data.len() * count); + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data], + false, + capacity, + ); + + for _ in 0..count { + mutable.extend(0, 0, original_data.len()); + } + + let data = mutable.freeze(); + let repeated_array = arrow_array::make_array(data); + + let list_arr = GenericListArray::::try_new( + Arc::new(Field::new("item", value_type.clone(), true)), + OffsetBuffer::::from_lengths(vec![original_data.len(); count]), + repeated_array, + None, + )?; + Arc::new(list_arr) as ArrayRef + } + None => new_null_array(data_type, count), + }; + new_values.push(list_arr); + } + + let lengths = new_values.iter().map(|a| a.len()).collect::>(); + let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); + let values = compute::concat(&new_values)?; + + Ok(Arc::new(ListArray::try_new( + Arc::new(Field::new("item", data_type.to_owned(), true)), + OffsetBuffer::::from_lengths(lengths), + values, + None, + )?)) +} diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index 943950447bc4..9cbcf0a923d1 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -19,7 +19,6 @@ use arrow::datatypes::DataType; use arrow::datatypes::Field; -use arrow_schema::DataType::List; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::Result; @@ -250,112 +249,6 @@ impl ScalarUDFImpl for 
ArrayNdims { } } -make_udf_function!( - ArrayEmpty, - array_empty, - array, - "returns true for an empty array or false for a non-empty array.", - array_empty_udf -); - -#[derive(Debug)] -pub(super) struct ArrayEmpty { - signature: Signature, - aliases: Vec, -} -impl ArrayEmpty { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("empty")], - } - } -} - -impl ScalarUDFImpl for ArrayEmpty { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "empty" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, - _ => { - return plan_err!("The array_empty function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_empty(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - ArrayRepeat, - array_repeat, - element count, // arg name - "returns an array containing element `count` times.", // doc - array_repeat_udf // internal function name -); -#[derive(Debug)] -pub(super) struct ArrayRepeat { - signature: Signature, - aliases: Vec, -} - -impl ArrayRepeat { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec![String::from("array_repeat"), String::from("list_repeat")], - } - } -} - -impl ScalarUDFImpl for ArrayRepeat { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_repeat" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - Ok(List(Arc::new(Field::new( - "item", - arg_types[0].clone(), - true, - )))) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_repeat(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - make_udf_function!( Flatten, flatten, From 7b8d6587ae7c913215c0a27869937d2fb438c1e9 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Sun, 24 Mar 2024 13:07:23 +0800 Subject: [PATCH 052/117] Minor: remove an outdated TODO in `TypeCoercion` (#9752) --- .../optimizer/src/analyzer/type_coercion.rs | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index f8dcf460a469..c76c1c8a7bd0 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -47,8 +47,8 @@ use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, type_coercion, AggregateFunction, Expr, ExprSchemable, LogicalPlan, Operator, - Projection, ScalarFunctionDefinition, ScalarUDF, Signature, WindowFrame, - WindowFrameBound, WindowFrameUnits, + ScalarFunctionDefinition, ScalarUDF, Signature, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; #[derive(Default)] @@ -76,7 +76,7 @@ fn analyze_internal( plan: &LogicalPlan, ) -> Result { // optimize child plans first - let mut new_inputs = plan + let new_inputs = plan .inputs() .iter() .map(|p| analyze_internal(external_schema, p)) @@ -110,14 +110,7 @@ fn 
analyze_internal( }) .collect::>>()?; - // TODO: with_new_exprs can't change the schema, so we need to do this here - match &plan { - LogicalPlan::Projection(_) => Ok(LogicalPlan::Projection(Projection::try_new( - new_expr, - Arc::new(new_inputs.swap_remove(0)), - )?)), - _ => plan.with_new_exprs(new_expr, new_inputs), - } + plan.with_new_exprs(new_expr, new_inputs) } pub(crate) struct TypeCoercionRewriter { From 06ee8a443911797d97fc216a4fc8a29b7604dca8 Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Sun, 24 Mar 2024 00:01:44 -0700 Subject: [PATCH 053/117] Issue-9765 - Extract array_resize and cardinality functions from functions-array subcrate' s kernels and udf containers (#9766) --- datafusion/functions-array/src/cardinality.rs | 115 +++++++++++ datafusion/functions-array/src/kernels.rs | 165 +--------------- datafusion/functions-array/src/lib.rs | 10 +- datafusion/functions-array/src/resize.rs | 179 ++++++++++++++++++ datafusion/functions-array/src/udf.rs | 110 ----------- datafusion/functions-array/src/utils.rs | 49 +++-- 6 files changed, 343 insertions(+), 285 deletions(-) create mode 100644 datafusion/functions-array/src/cardinality.rs create mode 100644 datafusion/functions-array/src/resize.rs diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-array/src/cardinality.rs new file mode 100644 index 000000000000..483336fe081d --- /dev/null +++ b/datafusion/functions-array/src/cardinality.rs @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for cardinality function. 
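+//!
+//! A sketch of the intended semantics, inferred from the kernel below
+//! (the product of the lengths of all array dimensions):
+//!
+//! ```text
+//! cardinality([[1, 2], [3, 4], [5, 6]]) -> 6
+//! cardinality([]) -> NULL
+//! ```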
+ +use crate::utils::make_scalar_function; +use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait, UInt64Array}; +use arrow_schema::DataType; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::{exec_err, plan_err}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + Cardinality, + cardinality, + array, + "returns the total number of elements in the array.", + cardinality_udf +); + +impl Cardinality { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("cardinality")], + } + } +} + +#[derive(Debug)] +pub(super) struct Cardinality { + signature: Signature, + aliases: Vec, +} +impl ScalarUDFImpl for Cardinality { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "cardinality" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + Ok(match arg_types[0] { + List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, + _ => { + return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(cardinality_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Cardinality SQL function +pub fn cardinality_inner(args: &[ArrayRef]) -> datafusion_common::Result { + if args.len() != 1 { + return exec_err!("cardinality expects one argument"); + } + + match &args[0].data_type() { + List(_) => { + let list_array = as_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + LargeList(_) => { + let list_array = as_large_list_array(&args[0])?; + generic_list_cardinality::(list_array) + } + other => { + exec_err!("cardinality does not support type '{:?}'", other) + } + } +} + +fn generic_list_cardinality( + array: &GenericListArray, +) -> datafusion_common::Result { + let result = array + .iter() + .map(|arr| match crate::utils::compute_array_dims(arr)? { + Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), + None => Ok(None), + }) + .collect::>()?; + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs index 4745db0170ed..1a08b64197a9 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/kernels.rs @@ -18,80 +18,19 @@ //! 
implementation kernels for array functions use arrow::array::{ - Array, ArrayRef, Capacities, GenericListArray, Int64Array, ListArray, - MutableArrayData, OffsetSizeTrait, UInt64Array, + Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, }; use arrow::datatypes::{DataType, UInt64Type}; -use arrow_buffer::{ArrowNativeType, OffsetBuffer}; -use arrow_schema::FieldRef; +use arrow_buffer::OffsetBuffer; use datafusion_common::cast::{ - as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, -}; -use datafusion_common::{ - exec_err, internal_datafusion_err, DataFusionError, Result, ScalarValue, + as_generic_list_array, as_large_list_array, as_list_array, }; +use datafusion_common::{exec_err, Result}; -use crate::utils::downcast_arg; -use std::any::type_name; +use crate::utils::compute_array_dims; use std::sync::Arc; -/// Returns the length of each array dimension -fn compute_array_dims(arr: Option) -> Result>>> { - let mut value = match arr { - Some(arr) => arr, - None => return Ok(None), - }; - if value.is_empty() { - return Ok(None); - } - let mut res = vec![Some(value.len() as u64)]; - - loop { - match value.data_type() { - DataType::List(..) => { - value = downcast_arg!(value, ListArray).value(0); - res.push(Some(value.len() as u64)); - } - _ => return Ok(Some(res)), - } - } -} - -fn generic_list_cardinality( - array: &GenericListArray, -) -> Result { - let result = array - .iter() - .map(|arr| match compute_array_dims(arr)? { - Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), - None => Ok(None), - }) - .collect::>()?; - Ok(Arc::new(result) as ArrayRef) -} - -/// Cardinality SQL function -pub fn cardinality(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("cardinality expects one argument"); - } - - match &args[0].data_type() { - DataType::List(_) => { - let list_array = as_list_array(&args[0])?; - generic_list_cardinality::(list_array) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&args[0])?; - generic_list_cardinality::(list_array) - } - other => { - exec_err!("cardinality does not support type '{:?}'", other) - } - } -} - /// Array_dims SQL function pub fn array_dims(args: &[ArrayRef]) -> Result { if args.len() != 1 { @@ -158,100 +97,6 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result { } } -/// array_resize SQL function -pub fn array_resize(arg: &[ArrayRef]) -> Result { - if arg.len() < 2 || arg.len() > 3 { - return exec_err!("array_resize needs two or three arguments"); - } - - let new_len = as_int64_array(&arg[1])?; - let new_element = if arg.len() == 3 { - Some(arg[2].clone()) - } else { - None - }; - - match &arg[0].data_type() { - DataType::List(field) => { - let array = as_list_array(&arg[0])?; - general_list_resize::(array, new_len, field, new_element) - } - DataType::LargeList(field) => { - let array = as_large_list_array(&arg[0])?; - general_list_resize::(array, new_len, field, new_element) - } - array_type => exec_err!("array_resize does not support type '{array_type:?}'."), - } -} - -/// array_resize keep the original array and append the default element to the end -fn general_list_resize( - array: &GenericListArray, - count_array: &Int64Array, - field: &FieldRef, - default_element: Option, -) -> Result -where - O: TryInto, -{ - let data_type = array.value_type(); - - let values = array.values(); - let original_data = values.to_data(); - - // create default element array - let default_element = if let Some(default_element) = default_element { - default_element 
- } else { - let null_scalar = ScalarValue::try_from(&data_type)?; - null_scalar.to_array_of_size(original_data.len())? - }; - let default_value_data = default_element.to_data(); - - // create a mutable array to store the original data - let capacity = Capacities::Array(original_data.len() + default_value_data.len()); - let mut offsets = vec![O::usize_as(0)]; - let mut mutable = MutableArrayData::with_capacities( - vec![&original_data, &default_value_data], - false, - capacity, - ); - - for (row_index, offset_window) in array.offsets().windows(2).enumerate() { - let count = count_array.value(row_index).to_usize().ok_or_else(|| { - internal_datafusion_err!("array_resize: failed to convert size to usize") - })?; - let count = O::usize_as(count); - let start = offset_window[0]; - if start + count > offset_window[1] { - let extra_count = - (start + count - offset_window[1]).try_into().map_err(|_| { - internal_datafusion_err!( - "array_resize: failed to convert size to i64" - ) - })?; - let end = offset_window[1]; - mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); - // append default element - for _ in 0..extra_count { - mutable.extend(1, row_index, row_index + 1); - } - } else { - let end = start + count; - mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); - }; - offsets.push(offsets[row_index] + count); - } - - let data = mutable.freeze(); - Ok(Arc::new(GenericListArray::::try_new( - field.clone(), - OffsetBuffer::::new(offsets.into()), - arrow_array::make_array(data), - None, - )?)) -} - // Create new offsets that are euqiavlent to `flatten` the array. fn get_offsets_for_flatten( offsets: OffsetBuffer, diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 4a7bb3fda90d..feecd18c2e8d 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -29,6 +29,7 @@ pub mod macros; mod array_has; +mod cardinality; mod concat; mod core; mod empty; @@ -41,6 +42,7 @@ mod range; mod remove; mod repeat; mod replace; +mod resize; mod reverse; mod rewrite; mod set_ops; @@ -60,6 +62,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; pub use super::concat::array_prepend; @@ -82,6 +85,7 @@ pub mod expr_fn { pub use super::replace::array_replace; pub use super::replace::array_replace_all; pub use super::replace::array_replace_n; + pub use super::resize::array_resize; pub use super::reverse::array_reverse; pub use super::set_ops::array_distinct; pub use super::set_ops::array_intersect; @@ -91,8 +95,6 @@ pub mod expr_fn { pub use super::string::string_to_array; pub use super::udf::array_dims; pub use super::udf::array_ndims; - pub use super::udf::array_resize; - pub use super::udf::cardinality; pub use super::udf::flatten; } @@ -104,7 +106,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { range::range_udf(), range::gen_series_udf(), udf::array_dims_udf(), - udf::cardinality_udf(), + cardinality::cardinality_udf(), udf::array_ndims_udf(), concat::array_append_udf(), concat::array_prepend_udf(), @@ -123,7 +125,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { udf::flatten_udf(), sort::array_sort_udf(), repeat::array_repeat_udf(), - udf::array_resize_udf(), + resize::array_resize_udf(), reverse::array_reverse_udf(), 
set_ops::array_distinct_udf(), set_ops::array_intersect_udf(), diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-array/src/resize.rs new file mode 100644 index 000000000000..f3996110f904 --- /dev/null +++ b/datafusion/functions-array/src/resize.rs @@ -0,0 +1,179 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_resize function. + +use crate::utils::make_scalar_function; +use arrow::array::{Capacities, MutableArrayData}; +use arrow_array::{ArrayRef, GenericListArray, Int64Array, OffsetSizeTrait}; +use arrow_buffer::{ArrowNativeType, OffsetBuffer}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List}; +use arrow_schema::{DataType, FieldRef}; +use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; +use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue}; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +make_udf_function!( + ArrayResize, + array_resize, + array size value, + "returns an array with the specified size filled with the given value.", + array_resize_udf +); + +#[derive(Debug)] +pub(super) struct ArrayResize { + signature: Signature, + aliases: Vec, +} + +impl ArrayResize { + pub fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: vec!["array_resize".to_string(), "list_resize".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayResize { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "array_resize" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + match &arg_types[0] { + List(field) | FixedSizeList(field, _) => Ok(List(field.clone())), + LargeList(field) => Ok(LargeList(field.clone())), + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + ), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(array_resize_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// array_resize SQL function +pub fn array_resize_inner(arg: &[ArrayRef]) -> datafusion_common::Result { + if arg.len() < 2 || arg.len() > 3 { + return exec_err!("array_resize needs two or three arguments"); + } + + let new_len = as_int64_array(&arg[1])?; + let new_element = if arg.len() == 3 { + Some(arg[2].clone()) + } else { + None + }; + + match &arg[0].data_type() { + DataType::List(field) => { + let array = as_list_array(&arg[0])?; + general_list_resize::(array, new_len, field, new_element) + } + DataType::LargeList(field) => { + let 
array = as_large_list_array(&arg[0])?; + general_list_resize::(array, new_len, field, new_element) + } + array_type => exec_err!("array_resize does not support type '{array_type:?}'."), + } +} + +/// array_resize keep the original array and append the default element to the end +fn general_list_resize( + array: &GenericListArray, + count_array: &Int64Array, + field: &FieldRef, + default_element: Option, +) -> datafusion_common::Result +where + O: TryInto, +{ + let data_type = array.value_type(); + + let values = array.values(); + let original_data = values.to_data(); + + // create default element array + let default_element = if let Some(default_element) = default_element { + default_element + } else { + let null_scalar = ScalarValue::try_from(&data_type)?; + null_scalar.to_array_of_size(original_data.len())? + }; + let default_value_data = default_element.to_data(); + + // create a mutable array to store the original data + let capacity = Capacities::Array(original_data.len() + default_value_data.len()); + let mut offsets = vec![O::usize_as(0)]; + let mut mutable = MutableArrayData::with_capacities( + vec![&original_data, &default_value_data], + false, + capacity, + ); + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let count = count_array.value(row_index).to_usize().ok_or_else(|| { + internal_datafusion_err!("array_resize: failed to convert size to usize") + })?; + let count = O::usize_as(count); + let start = offset_window[0]; + if start + count > offset_window[1] { + let extra_count = + (start + count - offset_window[1]).try_into().map_err(|_| { + internal_datafusion_err!( + "array_resize: failed to convert size to i64" + ) + })?; + let end = offset_window[1]; + mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); + // append default element + for _ in 0..extra_count { + mutable.extend(1, row_index, row_index + 1); + } + } else { + let end = start + count; + mutable.extend(0, (start).to_usize().unwrap(), (end).to_usize().unwrap()); + }; + offsets.push(offsets[row_index] + count); + } + + let data = mutable.freeze(); + Ok(Arc::new(GenericListArray::::try_new( + field.clone(), + OffsetBuffer::::new(offsets.into()), + arrow_array::make_array(data), + None, + )?)) +} diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs index 9cbcf0a923d1..bdc11155b633 100644 --- a/datafusion/functions-array/src/udf.rs +++ b/datafusion/functions-array/src/udf.rs @@ -85,116 +85,6 @@ impl ScalarUDFImpl for ArrayDims { } } -make_udf_function!( - ArrayResize, - array_resize, - array size value, - "returns an array with the specified size filled with the given value.", - array_resize_udf -); - -#[derive(Debug)] -pub(super) struct ArrayResize { - signature: Signature, - aliases: Vec, -} - -impl ArrayResize { - pub fn new() -> Self { - Self { - signature: Signature::variadic_any(Volatility::Immutable), - aliases: vec!["array_resize".to_string(), "list_resize".to_string()], - } - } -} - -impl ScalarUDFImpl for ArrayResize { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "array_resize" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - match &arg_types[0] { - List(field) | FixedSizeList(field, _) => Ok(List(field.clone())), - LargeList(field) => Ok(LargeList(field.clone())), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), - } - } - - fn invoke(&self, args: 
&[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_resize(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - -make_udf_function!( - Cardinality, - cardinality, - array, - "returns the total number of elements in the array.", - cardinality_udf -); - -impl Cardinality { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("cardinality")], - } - } -} - -#[derive(Debug)] -pub(super) struct Cardinality { - signature: Signature, - aliases: Vec, -} -impl ScalarUDFImpl for Cardinality { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "cardinality" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - Ok(match arg_types[0] { - List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, - _ => { - return plan_err!("The cardinality function can only accept List/LargeList/FixedSizeList."); - } - }) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::cardinality(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases - } -} - make_udf_function!( ArrayNdims, array_ndims, diff --git a/datafusion/functions-array/src/utils.rs b/datafusion/functions-array/src/utils.rs index c0f7627d2ab7..d86e4fe2ab7b 100644 --- a/datafusion/functions-array/src/utils.rs +++ b/datafusion/functions-array/src/utils.rs @@ -22,15 +22,30 @@ use std::sync::Arc; use arrow::{array::ArrayRef, datatypes::DataType}; use arrow_array::{ - Array, BooleanArray, GenericListArray, OffsetSizeTrait, Scalar, UInt32Array, + Array, BooleanArray, GenericListArray, ListArray, OffsetSizeTrait, Scalar, + UInt32Array, }; use arrow_buffer::OffsetBuffer; use arrow_schema::Field; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, plan_err, Result, ScalarValue}; +use core::any::type_name; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +macro_rules! downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + DataFusionError::Internal(format!( + "could not cast to {}", + type_name::<$ARRAY_TYPE>() + )) + })? + }}; +} +pub(crate) use downcast_arg; + pub(crate) fn check_datatypes(name: &str, args: &[&ArrayRef]) -> Result<()> { let data_type = args[0].data_type(); if !args.iter().all(|arg| { @@ -214,17 +229,29 @@ pub(crate) fn compare_element_to_list( Ok(res) } -macro_rules! downcast_arg { - ($ARG:expr, $ARRAY_TYPE:ident) => {{ - $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { - DataFusionError::Internal(format!( - "could not cast to {}", - type_name::<$ARRAY_TYPE>() - )) - })? - }}; +/// Returns the length of each array dimension +pub(crate) fn compute_array_dims( + arr: Option, +) -> Result>>> { + let mut value = match arr { + Some(arr) => arr, + None => return Ok(None), + }; + if value.is_empty() { + return Ok(None); + } + let mut res = vec![Some(value.len() as u64)]; + + loop { + match value.data_type() { + DataType::List(..) 
=> { + value = downcast_arg!(value, ListArray).value(0); + res.push(Some(value.len() as u64)); + } + _ => return Ok(Some(res)), + } + } } -pub(crate) use downcast_arg; #[cfg(test)] mod tests { From b1f377465a50b388a312272e2552b7ce036f6935 Mon Sep 17 00:00:00 2001 From: Adam Curtis Date: Sun, 24 Mar 2024 05:49:52 -0400 Subject: [PATCH 054/117] fix: change placeholder errors from Internal to Plan (#9745) * fix: change placeholder errors from Internal to Plan * Add tests for error case --------- Co-authored-by: Andrew Lamb --- datafusion/common/src/error.rs | 1 + datafusion/common/src/param_value.rs | 16 ++++------- datafusion/sql/tests/sql_integration.rs | 37 ++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/datafusion/common/src/error.rs b/datafusion/common/src/error.rs index d1e47b473499..cafab6d334b3 100644 --- a/datafusion/common/src/error.rs +++ b/datafusion/common/src/error.rs @@ -601,6 +601,7 @@ pub use config_err as _config_err; pub use internal_datafusion_err as _internal_datafusion_err; pub use internal_err as _internal_err; pub use not_impl_err as _not_impl_err; +pub use plan_datafusion_err as _plan_datafusion_err; pub use plan_err as _plan_err; pub use schema_err as _schema_err; diff --git a/datafusion/common/src/param_value.rs b/datafusion/common/src/param_value.rs index c614098713d6..8d61bad97b9f 100644 --- a/datafusion/common/src/param_value.rs +++ b/datafusion/common/src/param_value.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::error::_plan_err; -use crate::{DataFusionError, Result, ScalarValue}; +use crate::error::{_plan_datafusion_err, _plan_err}; +use crate::{Result, ScalarValue}; use arrow_schema::DataType; use std::collections::HashMap; @@ -75,16 +75,12 @@ impl ParamValues { let idx = id[1..] .parse::() .map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {e}" - )) + _plan_datafusion_err!("Failed to parse placeholder id: {e}") })? 
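                     // placeholder ids such as `$1` are 1-based; convert to a 0-based index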
.checked_sub(1); // value at the idx-th position in param_values should be the value for the placeholder let value = idx.and_then(|idx| list.get(idx)).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with id {id}" - )) + _plan_datafusion_err!("No value found for placeholder with id {id}") })?; Ok(value.clone()) } @@ -93,9 +89,7 @@ impl ParamValues { let name = &id[1..]; // value at the name position in param_values should be the value for the placeholder let value = map.get(name).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with name {id}" - )) + _plan_datafusion_err!("No value found for placeholder with name {id}") })?; Ok(value.clone()) } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index c738a2bd754f..448a9c54202e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -24,7 +24,8 @@ use arrow_schema::TimeUnit::Nanosecond; use arrow_schema::*; use datafusion_common::config::ConfigOptions; use datafusion_common::{ - plan_err, DFSchema, DataFusionError, ParamValues, Result, ScalarValue, TableReference, + assert_contains, plan_err, DFSchema, DataFusionError, ParamValues, Result, + ScalarValue, TableReference, }; use datafusion_expr::{ logical_plan::{LogicalPlan, Prepare}, @@ -4364,6 +4365,40 @@ fn test_prepare_statement_to_plan_value_list() { prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } +#[test] +fn test_prepare_statement_unknown_list_param() { + let sql = "SELECT id from person where id = $2"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with id $2" + ); +} + +#[test] +fn test_prepare_statement_unknown_hash_param() { + let sql = "SELECT id from person where id = $bar"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::Map(HashMap::new()); + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!( + err.to_string(), + "Error during planning: No value found for placeholder with name $bar" + ); +} + +#[test] +fn test_prepare_statement_bad_list_idx() { + let sql = "SELECT id from person where id = $foo"; + let plan = logical_plan(sql).unwrap(); + let param_values = ParamValues::List(vec![]); + + let err = plan.replace_params_with_values(¶m_values).unwrap_err(); + assert_contains!(err.to_string(), "Error during planning: Failed to parse placeholder id: invalid digit found in string"); +} + #[test] fn test_table_alias() { let sql = "select * from (\ From 67e0bd32a86de2675f677ddd565fa2f9d5c0383f Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 24 Mar 2024 06:07:31 -0400 Subject: [PATCH 055/117] Move levenshtein, uuid, overlay to datafusion-functions (#9760) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Fixed missing trim() function. 
* Move levenshtein, uuid, overlay to datafusion-functions --- datafusion-cli/Cargo.lock | 3 +- datafusion-examples/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- datafusion/expr/src/built_in_function.rs | 35 ---- datafusion/expr/src/expr_fn.rs | 26 --- datafusion/functions/Cargo.toml | 2 + .../functions/src/string/levenshtein.rs | 146 +++++++++++++ datafusion/functions/src/string/mod.rs | 24 +++ datafusion/functions/src/string/overlay.rs | 190 +++++++++++++++++ datafusion/functions/src/string/uuid.rs | 73 +++++++ datafusion/physical-expr/Cargo.toml | 1 - datafusion/physical-expr/src/functions.rs | 29 +-- .../physical-expr/src/string_expressions.rs | 192 ------------------ datafusion/physical-plan/Cargo.toml | 1 - datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 27 +-- datafusion/proto/src/logical_plan/to_proto.rs | 3 - datafusion/sql/src/expr/mod.rs | 9 +- 20 files changed, 458 insertions(+), 334 deletions(-) create mode 100644 datafusion/functions/src/string/levenshtein.rs create mode 100644 datafusion/functions/src/string/overlay.rs create mode 100644 datafusion/functions/src/string/uuid.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 51cccf60a1e4..2f1d95d639d4 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1273,6 +1273,7 @@ dependencies = [ "md-5", "regex", "sha2", + "uuid", ] [[package]] @@ -1340,7 +1341,6 @@ dependencies = [ "regex", "sha2", "unicode-segmentation", - "uuid", ] [[package]] @@ -1370,7 +1370,6 @@ dependencies = [ "pin-project-lite", "rand", "tokio", - "uuid", ] [[package]] diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index 2b6e869ec500..4966143782ba 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -76,4 +76,4 @@ tempfile = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] } tonic = "0.11" url = { workspace = true } -uuid = "1.2" +uuid = "1.7" diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index a3570834fdb7..1e5c0d748e3d 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -122,7 +122,7 @@ tempfile = { workspace = true } tokio = { workspace = true } tokio-util = { version = "0.7.4", features = ["io"], optional = true } url = { workspace = true } -uuid = { version = "1.0", features = ["v4"] } +uuid = { version = "1.7", features = ["v4"] } xz2 = { version = "0.1", optional = true, features = ["static"] } zstd = { version = "0.13", optional = true, default-features = false } diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index d0ec1326c49e..1904d58cfc92 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -141,12 +141,6 @@ pub enum BuiltinScalarFunction { Substr, /// translate Translate, - /// uuid - Uuid, - /// overlay - OverLay, - /// levenshtein - Levenshtein, /// substr_index SubstrIndex, /// find_in_set @@ -253,14 +247,11 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, - BuiltinScalarFunction::OverLay => Volatility::Immutable, - BuiltinScalarFunction::Levenshtein => Volatility::Immutable, BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, 
BuiltinScalarFunction::FindInSet => Volatility::Immutable, // Volatile builtin functions BuiltinScalarFunction::Random => Volatility::Volatile, - BuiltinScalarFunction::Uuid => Volatility::Volatile, } } @@ -302,7 +293,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), - BuiltinScalarFunction::Uuid => Ok(Utf8), BuiltinScalarFunction::Repeat => { utf8_to_str_type(&input_expr_types[0], "repeat") } @@ -362,14 +352,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Iszero => Ok(Boolean), - BuiltinScalarFunction::OverLay => { - utf8_to_str_type(&input_expr_types[0], "overlay") - } - - BuiltinScalarFunction::Levenshtein => { - utf8_to_int_type(&input_expr_types[0], "levenshtein") - } - BuiltinScalarFunction::Atan | BuiltinScalarFunction::Acosh | BuiltinScalarFunction::Asinh @@ -490,7 +472,6 @@ impl BuiltinScalarFunction { } BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), - BuiltinScalarFunction::Uuid => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Power => Signature::one_of( vec![Exact(vec![Int64, Int64]), Exact(vec![Float64, Float64])], self.volatility(), @@ -536,19 +517,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { Signature::uniform(2, vec![Int64], self.volatility()) } - BuiltinScalarFunction::OverLay => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - self.volatility(), - ), - BuiltinScalarFunction::Levenshtein => Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], - self.volatility(), - ), BuiltinScalarFunction::Atan | BuiltinScalarFunction::Acosh | BuiltinScalarFunction::Asinh @@ -678,11 +646,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::Uuid => &["uuid"], - BuiltinScalarFunction::Levenshtein => &["levenshtein"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], BuiltinScalarFunction::FindInSet => &["find_in_set"], - BuiltinScalarFunction::OverLay => &["overlay"], } } } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index e1ab11c5b778..60db21e5f5fe 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -575,7 +575,6 @@ scalar_expr!(Log10, log10, num, "base 10 logarithm of number"); scalar_expr!(Ln, ln, num, "natural logarithm (base e) of number"); scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); -scalar_expr!(Uuid, uuid, , "returns uuid v4 as a string value"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); // string functions @@ -628,12 +627,6 @@ nary_scalar_expr!( "concatenates several strings, placing a seperator between each one" ); nary_scalar_expr!(Concat, concat_expr, "concatenates several strings"); -nary_scalar_expr!( - OverLay, - overlay, - "replace the substring of string that starts at the start'th character and extends for count characters with new 
substring" -); - scalar_expr!(Nanvl, nanvl, x y, "returns x if x is not NaN otherwise returns y"); scalar_expr!( Iszero, @@ -642,7 +635,6 @@ scalar_expr!( "returns true if a given number is +0.0 or -0.0 otherwise returns false" ); -scalar_expr!(Levenshtein, levenshtein, string1 string2, "Returns the Levenshtein distance between the two given strings"); scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"); @@ -1076,25 +1068,7 @@ mod test { test_scalar_expr!(Substr, substr, string, position); test_scalar_expr!(Substr, substring, string, position, count); test_scalar_expr!(Translate, translate, string, from, to); - test_nary_scalar_expr!(OverLay, overlay, string, characters, position, len); - test_nary_scalar_expr!(OverLay, overlay, string, characters, position); - test_scalar_expr!(Levenshtein, levenshtein, string1, string2); test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); test_scalar_expr!(FindInSet, find_in_set, string, stringlist); } - - #[test] - fn uuid_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args, - }) = uuid() - { - let name = BuiltinScalarFunction::Uuid; - assert_eq!(name, fun); - assert_eq!(0, args.len()); - } else { - unreachable!(); - } - } } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 0410d89d123f..81050dfddf66 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -75,6 +75,8 @@ log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } +uuid = { version = "1.7", features = ["v4"] } + [dev-dependencies] criterion = "0.5" rand = { workspace = true } diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs new file mode 100644 index 000000000000..b5de4b28948f --- /dev/null +++ b/datafusion/functions/src/string/levenshtein.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::utils::datafusion_strsim; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; + +use crate::string::common::{make_scalar_function, utf8_to_int_type}; + +#[derive(Debug)] +pub(super) struct LevenshteinFunc { + signature: Signature, +} + +impl LevenshteinFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LevenshteinFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "levenshtein" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "levenshtein") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(levenshtein::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(levenshtein::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function levenshtein") + } + } + } +} + +///Returns the Levenshtein distance between the two given strings. +/// LEVENSHTEIN('kitten', 'sitting') = 3 +pub fn levenshtein(args: &[ArrayRef]) -> Result { + if args.len() != 2 { + return exec_err!( + "levenshtein function requires two arguments, got {}", + args.len() + ); + } + let str1_array = as_generic_string_array::(&args[0])?; + let str2_array = as_generic_string_array::(&args[1])?; + match args[0].data_type() { + DataType::Utf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i32) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + DataType::LargeUtf8 => { + let result = str1_array + .iter() + .zip(str2_array.iter()) + .map(|(string1, string2)| match (string1, string2) { + (Some(string1), Some(string2)) => { + Some(datafusion_strsim::levenshtein(string1, string2) as i64) + } + _ => None, + }) + .collect::(); + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!( + "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." 
+ ) + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Int32Array, StringArray}; + + use datafusion_common::cast::as_int32_array; + + use super::*; + + #[test] + fn to_levenshtein() -> Result<()> { + let string1_array = + Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); + let string2_array = + Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); + let res = levenshtein::(&[string1_array, string2_array]).unwrap(); + let result = + as_int32_array(&res).expect("failed to initialized function levenshtein"); + let expected = Int32Array::from(vec![2, 3, 2, 3]); + assert_eq!(&expected, result); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index a6d844932655..165a7c660404 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -24,24 +24,30 @@ use datafusion_expr::ScalarUDF; mod ascii; mod btrim; mod common; +mod levenshtein; mod lower; mod ltrim; mod octet_length; +mod overlay; mod rtrim; mod starts_with; mod to_hex; mod upper; +mod uuid; // create UDFs make_udf_function!(ascii::AsciiFunc, ASCII, ascii); make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); +make_udf_function!(levenshtein::LevenshteinFunc, LEVENSHTEIN, levenshtein); make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); make_udf_function!(lower::LowerFunc, LOWER, lower); make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length); +make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay); make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); make_udf_function!(upper::UpperFunc, UPPER, upper); +make_udf_function!(uuid::UuidFunc, UUID, uuid); pub mod expr_fn { use datafusion_expr::Expr; @@ -56,6 +62,11 @@ pub mod expr_fn { super::btrim().call(args) } + #[doc = "Returns the Levenshtein distance between the two given strings"] + pub fn levenshtein(arg1: Expr, arg2: Expr) -> Expr { + super::levenshtein().call(vec![arg1, arg2]) + } + #[doc = "Converts a string to lowercase."] pub fn lower(arg1: Expr) -> Expr { super::lower().call(vec![arg1]) @@ -71,6 +82,11 @@ pub mod expr_fn { super::octet_length().call(args) } + #[doc = "replace the substring of string that starts at the start'th character and extends for count characters with new substring"] + pub fn overlay(args: Vec) -> Expr { + super::overlay().call(args) + } + #[doc = "Removes all characters, spaces by default, from the end of a string"] pub fn rtrim(args: Vec) -> Expr { super::rtrim().call(args) @@ -95,6 +111,11 @@ pub mod expr_fn { pub fn upper(arg1: Expr) -> Expr { super::upper().call(vec![arg1]) } + + #[doc = "returns uuid v4 as a string value"] + pub fn uuid() -> Expr { + super::uuid().call(vec![]) + } } /// Return a list of all functions in this package @@ -102,12 +123,15 @@ pub fn functions() -> Vec> { vec![ ascii(), btrim(), + levenshtein(), lower(), ltrim(), octet_length(), + overlay(), rtrim(), starts_with(), to_hex(), upper(), + uuid(), ] } diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs new file mode 100644 index 000000000000..d7cc0da8068e --- /dev/null +++ b/datafusion/functions/src/string/overlay.rs @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +#[derive(Debug)] +pub(super) struct OverlayFunc { + signature: Signature, +} + +impl OverlayFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for OverlayFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "overlay" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "overlay") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(overlay::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(overlay::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function overlay"), + } + } +} + +/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) +/// Replaces a substring of string1 with string2 starting at the integer bit +/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas +/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead +pub fn overlay(args: &[ArrayRef]) -> Result { + match args.len() { + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos 
+ replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + 4 => { + let string_array = as_generic_string_array::(&args[0])?; + let characters_array = as_generic_string_array::(&args[1])?; + let pos_num = as_int64_array(&args[2])?; + let len_num = as_int64_array(&args[3])?; + + let result = string_array + .iter() + .zip(characters_array.iter()) + .zip(pos_num.iter()) + .zip(len_num.iter()) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = + String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) + } + other => { + exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Int64Array, StringArray}; + + use super::*; + + #[test] + fn to_overlay() -> Result<()> { + let string = + Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); + let replace_string = + Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); + let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start + let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len + + let res = overlay::(&[string, replace_string, start, end]).unwrap(); + let result = as_generic_string_array::(&res).unwrap(); + let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); + assert_eq!(&expected, result); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/uuid.rs b/datafusion/functions/src/string/uuid.rs new file mode 100644 index 000000000000..791ad6d3c4f3 --- /dev/null +++ b/datafusion/functions/src/string/uuid.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
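+
+//! [`ScalarUDFImpl`] definitions for the uuid function, which returns a
+//! random (v4) uuid string per row, e.g.
+//! 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' (example value taken from the
+//! `invoke` doc comment below).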
+ +use std::any::Any; +use std::iter; +use std::sync::Arc; + +use arrow::array::GenericStringArray; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Utf8; +use uuid::Uuid; + +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +#[derive(Debug)] +pub(super) struct UuidFunc { + signature: Signature, +} + +impl UuidFunc { + pub fn new() -> Self { + Self { + signature: Signature::exact(vec![], Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for UuidFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "uuid" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + /// Prints random (v4) uuid values per row + /// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let len: usize = match &args[0] { + ColumnarValue::Array(array) => array.len(), + _ => return exec_err!("Expect uuid function to take no param"), + }; + + let values = iter::repeat_with(|| Uuid::new_v4().to_string()).take(len); + let array = GenericStringArray::::from_iter_values(values); + Ok(ColumnarValue::Array(Arc::new(array))) + } +} diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index d63ad9bb4a3a..24b831e7c575 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -74,7 +74,6 @@ rand = { workspace = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] criterion = "0.5" diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 2436fa24d4ef..8759adc89b40 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -458,29 +458,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function translate") } }), - BuiltinScalarFunction::Uuid => Arc::new(string_expressions::uuid), - BuiltinScalarFunction::OverLay => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::overlay::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::overlay::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function overlay"), - }), - BuiltinScalarFunction::Levenshtein => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => make_scalar_function_inner( - string_expressions::levenshtein::, - )(args), - DataType::LargeUtf8 => make_scalar_function_inner( - string_expressions::levenshtein::, - )(args), - other => { - exec_err!("Unsupported data type {other:?} for function levenshtein") - } - }) - } BuiltinScalarFunction::SubstrIndex => { Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -1868,11 +1845,7 @@ mod tests { let execution_props = ExecutionProps::new(); let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let funs = [ - BuiltinScalarFunction::Pi, - BuiltinScalarFunction::Random, - BuiltinScalarFunction::Uuid, - ]; + let funs = [BuiltinScalarFunction::Pi, BuiltinScalarFunction::Random]; for fun in funs.iter() { create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?; diff --git 
a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 13e4ce77e0ac..766e167a9426 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -21,7 +21,6 @@ //! String expressions -use std::iter; use std::sync::Arc; use arrow::{ @@ -31,9 +30,7 @@ use arrow::{ }, datatypes::DataType, }; -use uuid::Uuid; -use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use datafusion_common::{ cast::{as_generic_string_array, as_int64_array, as_string_array}, @@ -333,192 +330,3 @@ pub fn ends_with(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } - -/// Prints random (v4) uuid values per row -/// uuid() = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11' -pub fn uuid(args: &[ColumnarValue]) -> Result { - let len: usize = match &args[0] { - ColumnarValue::Array(array) => array.len(), - _ => return exec_err!("Expect uuid function to take no param"), - }; - - let values = iter::repeat_with(|| Uuid::new_v4().to_string()).take(len); - let array = GenericStringArray::::from_iter_values(values); - Ok(ColumnarValue::Array(Arc::new(array))) -} - -/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) -/// Replaces a substring of string1 with string2 starting at the integer bit -/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas -/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead -pub fn overlay(args: &[ArrayRef]) -> Result { - match args.len() { - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .map(|((string, characters), start_pos)| { - match (string, characters, start_pos) { - (Some(string), Some(characters), Some(start_pos)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = characters_len as i64; - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - 4 => { - let string_array = as_generic_string_array::(&args[0])?; - let characters_array = as_generic_string_array::(&args[1])?; - let pos_num = as_int64_array(&args[2])?; - let len_num = as_int64_array(&args[3])?; - - let result = string_array - .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .zip(len_num.iter()) - .map(|(((string, characters), start_pos), len)| { - match (string, characters, start_pos, len) { - (Some(string), Some(characters), Some(start_pos), Some(len)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = len.min(string_len as i64); - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && 
start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") - } - } -} - -///Returns the Levenshtein distance between the two given strings. -/// LEVENSHTEIN('kitten', 'sitting') = 3 -pub fn levenshtein(args: &[ArrayRef]) -> Result { - if args.len() != 2 { - return exec_err!( - "levenshtein function requires two arguments, got {}", - args.len() - ); - } - let str1_array = as_generic_string_array::(&args[0])?; - let str2_array = as_generic_string_array::(&args[1])?; - match args[0].data_type() { - DataType::Utf8 => { - let result = str1_array - .iter() - .zip(str2_array.iter()) - .map(|(string1, string2)| match (string1, string2) { - (Some(string1), Some(string2)) => { - Some(datafusion_strsim::levenshtein(string1, string2) as i32) - } - _ => None, - }) - .collect::(); - Ok(Arc::new(result) as ArrayRef) - } - DataType::LargeUtf8 => { - let result = str1_array - .iter() - .zip(str2_array.iter()) - .map(|(string1, string2)| match (string1, string2) { - (Some(string1), Some(string2)) => { - Some(datafusion_strsim::levenshtein(string1, string2) as i64) - } - _ => None, - }) - .collect::(); - Ok(Arc::new(result) as ArrayRef) - } - other => { - exec_err!( - "levenshtein was called with {other} datatype arguments. It requires Utf8 or LargeUtf8." - ) - } - } -} - -#[cfg(test)] -mod tests { - use arrow::array::Int32Array; - use arrow_array::Int64Array; - - use datafusion_common::cast::as_int32_array; - - use super::*; - - #[test] - fn to_overlay() -> Result<()> { - let string = - Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])); - let replace_string = - Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])); - let start = Arc::new(Int64Array::from(vec![4, 1, 1, 2])); // start - let end = Arc::new(Int64Array::from(vec![5, 7, 2, 4])); // replace len - - let res = overlay::(&[string, replace_string, start, end]).unwrap(); - let result = as_generic_string_array::(&res).unwrap(); - let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]); - assert_eq!(&expected, result); - - Ok(()) - } - - #[test] - fn to_levenshtein() -> Result<()> { - let string1_array = - Arc::new(StringArray::from(vec!["123", "abc", "xyz", "kitten"])); - let string2_array = - Arc::new(StringArray::from(vec!["321", "def", "zyx", "sitting"])); - let res = levenshtein::(&[string1_array, string2_array]).unwrap(); - let result = - as_int32_array(&res).expect("failed to initialized function levenshtein"); - let expected = Int32Array::from(vec![2, 3, 2, 3]); - assert_eq!(&expected, result); - - Ok(()) - } -} diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 72ee4fb3ef7e..1ba32bff746e 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -58,7 +58,6 @@ parking_lot = { workspace = true } pin-project-lite = "^0.2.7" rand = { workspace = true } tokio = { workspace = true } -uuid = { version = "^1.2", features = ["v4"] } [dev-dependencies] rstest = { workspace = true } diff --git a/datafusion/proto/proto/datafusion.proto 
b/datafusion/proto/proto/datafusion.proto index e4953283b184..795995ce2c46 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -613,7 +613,7 @@ enum ScalarFunction { // 69 was ArrowTypeof // 70 was CurrentDate // 71 was CurrentTime - Uuid = 72; + // 72 was Uuid Cbrt = 73; Acosh = 74; Asinh = 75; @@ -660,11 +660,11 @@ enum ScalarFunction { // 118 was ToTimestampNanos // 119 was ArrayIntersect // 120 was ArrayUnion - OverLay = 121; + // 121 was OverLay // 122 is Range // 123 is ArrayExcept // 124 was ArrayPopFront - Levenshtein = 125; + // 125 was Levenshtein SubstrIndex = 126; FindInSet = 127; // 128 was ArraySort diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7cdebdf85944..3941171e4fe6 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22949,7 +22949,6 @@ impl serde::Serialize for ScalarFunction { Self::Coalesce => "Coalesce", Self::Power => "Power", Self::Atan2 => "Atan2", - Self::Uuid => "Uuid", Self::Cbrt => "Cbrt", Self::Acosh => "Acosh", Self::Asinh => "Asinh", @@ -22965,8 +22964,6 @@ impl serde::Serialize for ScalarFunction { Self::Cot => "Cot", Self::Nanvl => "Nanvl", Self::Iszero => "Iszero", - Self::OverLay => "OverLay", - Self::Levenshtein => "Levenshtein", Self::SubstrIndex => "SubstrIndex", Self::FindInSet => "FindInSet", Self::EndsWith => "EndsWith", @@ -23017,7 +23014,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Coalesce", "Power", "Atan2", - "Uuid", "Cbrt", "Acosh", "Asinh", @@ -23033,8 +23029,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cot", "Nanvl", "Iszero", - "OverLay", - "Levenshtein", "SubstrIndex", "FindInSet", "EndsWith", @@ -23114,7 +23108,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), "Atan2" => Ok(ScalarFunction::Atan2), - "Uuid" => Ok(ScalarFunction::Uuid), "Cbrt" => Ok(ScalarFunction::Cbrt), "Acosh" => Ok(ScalarFunction::Acosh), "Asinh" => Ok(ScalarFunction::Asinh), @@ -23130,8 +23123,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cot" => Ok(ScalarFunction::Cot), "Nanvl" => Ok(ScalarFunction::Nanvl), "Iszero" => Ok(ScalarFunction::Iszero), - "OverLay" => Ok(ScalarFunction::OverLay), - "Levenshtein" => Ok(ScalarFunction::Levenshtein), "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), "FindInSet" => Ok(ScalarFunction::FindInSet), "EndsWith" => Ok(ScalarFunction::EndsWith), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2932bcf6d93f..58fda7fcb5ad 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2912,7 +2912,7 @@ pub enum ScalarFunction { /// 69 was ArrowTypeof /// 70 was CurrentDate /// 71 was CurrentTime - Uuid = 72, + /// 72 was Uuid Cbrt = 73, Acosh = 74, Asinh = 75, @@ -2959,11 +2959,11 @@ pub enum ScalarFunction { /// 118 was ToTimestampNanos /// 119 was ArrayIntersect /// 120 was ArrayUnion - OverLay = 121, + /// 121 was OverLay /// 122 is Range /// 123 is ArrayExcept /// 124 was ArrayPopFront - Levenshtein = 125, + /// 125 was Levenshtein SubstrIndex = 126, FindInSet = 127, /// 128 was ArraySort @@ -3022,7 +3022,6 @@ impl ScalarFunction { ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", ScalarFunction::Atan2 => "Atan2", - ScalarFunction::Uuid => "Uuid", ScalarFunction::Cbrt => "Cbrt", ScalarFunction::Acosh => "Acosh", 
ScalarFunction::Asinh => "Asinh", @@ -3038,8 +3037,6 @@ impl ScalarFunction { ScalarFunction::Cot => "Cot", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Iszero => "Iszero", - ScalarFunction::OverLay => "OverLay", - ScalarFunction::Levenshtein => "Levenshtein", ScalarFunction::SubstrIndex => "SubstrIndex", ScalarFunction::FindInSet => "FindInSet", ScalarFunction::EndsWith => "EndsWith", @@ -3084,7 +3081,6 @@ impl ScalarFunction { "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), "Atan2" => Some(Self::Atan2), - "Uuid" => Some(Self::Uuid), "Cbrt" => Some(Self::Cbrt), "Acosh" => Some(Self::Acosh), "Asinh" => Some(Self::Asinh), @@ -3100,8 +3096,6 @@ impl ScalarFunction { "Cot" => Some(Self::Cot), "Nanvl" => Some(Self::Nanvl), "Iszero" => Some(Self::Iszero), - "OverLay" => Some(Self::OverLay), - "Levenshtein" => Some(Self::Levenshtein), "SubstrIndex" => Some(Self::SubstrIndex), "FindInSet" => Some(Self::FindInSet), "EndsWith" => Some(Self::EndsWith), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d00aeeda462b..3b44c1cb276d 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -51,14 +51,13 @@ use datafusion_expr::{ acosh, asinh, atan, atan2, atanh, bit_length, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, levenshtein, ln, log, - log10, log2, + factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lpad, nanvl, overlay, pi, power, radians, random, repeat, replace, reverse, right, - round, rpad, signum, sin, sinh, split_part, sqrt, strpos, substr, substr_index, - substring, translate, trunc, uuid, AggregateFunction, Between, BinaryExpr, - BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, - GetIndexedField, GroupingSet, + lpad, nanvl, pi, power, radians, random, repeat, replace, reverse, right, round, + rpad, signum, sin, sinh, split_part, sqrt, strpos, substr, substr_index, substring, + translate, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, + BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, + GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -477,7 +476,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::SplitPart => Self::SplitPart, ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, - ScalarFunction::Uuid => Self::Uuid, ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Pi => Self::Pi, @@ -485,8 +483,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::Iszero => Self::Iszero, - ScalarFunction::OverLay => Self::OverLay, - ScalarFunction::Levenshtein => Self::Levenshtein, ScalarFunction::SubstrIndex => Self::SubstrIndex, ScalarFunction::FindInSet => Self::FindInSet, } @@ -1449,7 +1445,6 @@ pub fn parse_expr( parse_expr(&args[1], registry, codec)?, )), ScalarFunction::Random => Ok(random()), - ScalarFunction::Uuid => Ok(uuid()), ScalarFunction::Repeat => Ok(repeat( parse_expr(&args[0], registry, codec)?, 
parse_expr(&args[1], registry, codec)?, @@ -1518,10 +1513,6 @@ pub fn parse_expr( )) } } - ScalarFunction::Levenshtein => Ok(levenshtein( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Translate => Ok(translate( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, @@ -1554,12 +1545,6 @@ pub fn parse_expr( ScalarFunction::Iszero => { Ok(iszero(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::OverLay => Ok(overlay( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), ScalarFunction::SubstrIndex => Ok(substr_index( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index edb8c4e4eb01..446a91a39a1b 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1490,7 +1490,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Left => Self::Left, BuiltinScalarFunction::Lpad => Self::Lpad, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Uuid => Self::Uuid, BuiltinScalarFunction::Repeat => Self::Repeat, BuiltinScalarFunction::Replace => Self::Replace, BuiltinScalarFunction::Reverse => Self::Reverse, @@ -1506,8 +1505,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::Iszero => Self::Iszero, - BuiltinScalarFunction::OverLay => Self::OverLay, - BuiltinScalarFunction::Levenshtein => Self::Levenshtein, BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, BuiltinScalarFunction::FindInSet => Self::FindInSet, }; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 04f8001bfc1b..d1fc03194997 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -795,7 +795,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, planner_context: &mut PlannerContext, ) -> Result { - let fun = BuiltinScalarFunction::OverLay; + let fun = self + .context_provider + .get_function_meta("overlay") + .ok_or_else(|| { + internal_datafusion_err!("Unable to find expected 'overlay' function") + })?; let arg = self.sql_expr_to_logical_expr(expr, schema, planner_context)?; let what_arg = self.sql_expr_to_logical_expr(overlay_what, schema, planner_context)?; @@ -809,7 +814,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => vec![arg, what_arg, from_arg], }; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args))) + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args))) } fn sql_position_to_expr( &self, From 227d1f85ed77e55534e8c9b9052f3781676a7d2d Mon Sep 17 00:00:00 2001 From: Kunal Kundu Date: Sun, 24 Mar 2024 23:41:21 +0530 Subject: [PATCH 056/117] improve null handling for to_char (#9689) * improve null handling for to_char * early return from to_char for null format * remove invalid comment, update example * rename column for consistency across platforms for tests * return None instead of empty string from to_char * use arrow:new_null_array for fast init --- datafusion-examples/examples/to_char.rs | 19 +++++++++++ datafusion/functions/src/datetime/to_char.rs | 32 +++++++++++++++---- .../sqllogictest/test_files/timestamps.slt | 14 ++++++-- 3 files changed, 55 insertions(+), 10 deletions(-) diff --git 
a/datafusion-examples/examples/to_char.rs b/datafusion-examples/examples/to_char.rs index ef616d72cc1c..f8ed68b46f19 100644 --- a/datafusion-examples/examples/to_char.rs +++ b/datafusion-examples/examples/to_char.rs @@ -193,5 +193,24 @@ async fn main() -> Result<()> { &result ); + // output format is null + + let result = ctx + .sql("SELECT to_char(arrow_cast(123456, 'Duration(Second)'), null) as result") + .await? + .collect() + .await?; + + assert_batches_eq!( + &[ + "+--------+", + "| result |", + "+--------+", + "| |", + "+--------+", + ], + &result + ); + Ok(()) } diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 3ca098b1f99b..ef5c45a5ad9c 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::sync::Arc; use arrow::array::cast::AsArray; -use arrow::array::{Array, ArrayRef, StringArray}; +use arrow::array::{new_null_array, Array, ArrayRef, StringArray}; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{ Date32, Date64, Duration, Time32, Time64, Timestamp, Utf8, @@ -109,7 +109,6 @@ impl ScalarUDFImpl for ToCharFunc { } match &args[1] { - // null format, use default formats ColumnarValue::Scalar(ScalarValue::Utf8(None)) | ColumnarValue::Scalar(ScalarValue::Null) => { _to_char_scalar(args[0].clone(), None) @@ -175,6 +174,18 @@ fn _to_char_scalar( let data_type = &expression.data_type(); let is_scalar_expression = matches!(&expression, ColumnarValue::Scalar(_)); let array = expression.into_array(1)?; + + if format.is_none() { + if is_scalar_expression { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } else { + return Ok(ColumnarValue::Array(new_null_array( + &DataType::Utf8, + array.len(), + ))); + } + } + let format_options = match _build_format_options(data_type, format) { Ok(value) => value, Err(value) => return value, @@ -202,7 +213,7 @@ fn _to_char_scalar( fn _to_char_array(args: &[ColumnarValue]) -> Result { let arrays = ColumnarValue::values_to_arrays(args)?; - let mut results: Vec = vec![]; + let mut results: Vec> = vec![]; let format_array = arrays[1].as_string::(); let data_type = arrays[0].data_type(); @@ -212,6 +223,10 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { } else { Some(format_array.value(idx)) }; + if format.is_none() { + results.push(None); + continue; + } let format_options = match _build_format_options(data_type, format) { Ok(value) => value, Err(value) => return value, @@ -221,7 +236,7 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { let formatter = ArrayFormatter::try_new(arrays[0].as_ref(), &format_options)?; let result = formatter.value(idx).try_to_string(); match result { - Ok(value) => results.push(value), + Ok(value) => results.push(Some(value)), Err(e) => return exec_err!("{}", e), } } @@ -230,9 +245,12 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { ColumnarValue::Array(_) => Ok(ColumnarValue::Array(Arc::new(StringArray::from( results, )) as ArrayRef)), - ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( - results.first().unwrap().to_string(), - )))), + ColumnarValue::Scalar(_) => match results.first().unwrap() { + Some(value) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + value.to_string(), + )))), + None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + }, } } diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 
f718bbf14cbc..f0e04b522a78 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -2661,7 +2661,7 @@ PT123456S query T select to_char(arrow_cast(123456, 'Duration(Second)'), null); ---- -PT123456S +NULL query error DataFusion error: Execution error: Cast error: Format error SELECT to_char(timestamps, '%X%K') from formats; @@ -2672,14 +2672,22 @@ SELECT to_char('2000-02-03'::date, '%X%K'); query T SELECT to_char(timestamps, null) from formats; ---- -2024-01-01T06:00:00Z -2025-01-01T23:59:58Z +NULL +NULL query T SELECT to_char(null, '%d-%m-%Y'); ---- (empty) +query T +SELECT to_char(column1, column2) +FROM +(VALUES ('2024-01-01 06:00:00'::timestamp, null), ('2025-01-01 23:59:58'::timestamp, '%d:%m:%Y %H-%M-%S')); +---- +NULL +01:01:2025 23-59-58 + statement ok drop table formats; From cb9da2b46f602f006bdd8902208817830a6fc2f1 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Sun, 24 Mar 2024 12:15:34 -0600 Subject: [PATCH 057/117] Add Expr->String for ScalarFunction and InList (#9759) * add ScalarFunction and InList * cargo fmt * address comment --- datafusion/sql/src/unparser/expr.rs | 99 +++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 6 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8d25a607bb89..d007d4a843a2 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -52,13 +52,48 @@ impl Unparser<'_> { match expr { Expr::InList(InList { expr, - list: _, - negated: _, + list, + negated, }) => { - not_impl_err!("Unsupported expression: {expr:?}") + let list_expr = list + .iter() + .map(|e| self.expr_to_sql(e)) + .collect::>>()?; + Ok(ast::Expr::InList { + expr: Box::new(self.expr_to_sql(expr)?), + list: list_expr, + negated: *negated, + }) } - Expr::ScalarFunction(ScalarFunction { .. 
}) => { - not_impl_err!("Unsupported expression: {expr:?}") + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + let func_name = func_def.name(); + + let args = args + .iter() + .map(|e| { + if matches!(e, Expr::Wildcard { qualifier: None }) { + Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) + } else { + self.expr_to_sql(e).map(|e| { + FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) + }) + } + }) + .collect::>>()?; + + Ok(ast::Expr::Function(Function { + name: ast::ObjectName(vec![Ident { + value: func_name.to_string(), + quote_style: None, + }]), + args, + filter: None, + null_treatment: None, + over: None, + distinct: false, + special: false, + order_by: vec![], + })) } Expr::Between(Between { expr, @@ -526,13 +561,53 @@ impl Unparser<'_> { #[cfg(test)] mod tests { + use std::any::Any; + use datafusion_common::TableReference; - use datafusion_expr::{col, expr::AggregateFunction, lit}; + use datafusion_expr::{ + col, expr::AggregateFunction, lit, ColumnarValue, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, + }; use crate::unparser::dialect::CustomDialect; use super::*; + /// Mocked UDF + #[derive(Debug)] + struct DummyUDF { + signature: Signature, + } + + impl DummyUDF { + fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } + } // See sql::tests for E2E tests. #[test] @@ -561,6 +636,18 @@ mod tests { }), r#"CAST("a" AS INTEGER UNSIGNED)"#, ), + ( + col("a").in_list(vec![lit(1), lit(2), lit(3)], false), + r#""a" IN (1, 2, 3)"#, + ), + ( + col("a").in_list(vec![lit(1), lit(2), lit(3)], true), + r#""a" NOT IN (1, 2, 3)"#, + ), + ( + ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]), + r#"dummy_udf("a", "b")"#, + ), ( Expr::Literal(ScalarValue::Date64(Some(0))), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, From 1e4ddb6d86328cb6596bb50da9ccc654f19a83ea Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sun, 24 Mar 2024 14:28:53 -0400 Subject: [PATCH 058/117] Move repeat, replace, split_part to datafusion_functions (#9784) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Fixed missing trim() function. 
* Move repeat, replace, split_part to datafusion_functions --- datafusion/expr/src/built_in_function.rs | 44 +---- datafusion/expr/src/expr_fn.rs | 6 - datafusion/functions/src/string/mod.rs | 24 +++ datafusion/functions/src/string/repeat.rs | 144 +++++++++++++++ datafusion/functions/src/string/replace.rs | 97 ++++++++++ datafusion/functions/src/string/split_part.rs | 170 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 122 ++----------- .../physical-expr/src/string_expressions.rs | 67 ------- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 26 +-- datafusion/proto/src/logical_plan/to_proto.rs | 3 - 13 files changed, 469 insertions(+), 261 deletions(-) create mode 100644 datafusion/functions/src/string/repeat.rs create mode 100644 datafusion/functions/src/string/replace.rs create mode 100644 datafusion/functions/src/string/split_part.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 1904d58cfc92..b3f17ae3c2ca 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -123,18 +123,12 @@ pub enum BuiltinScalarFunction { Lpad, /// random Random, - /// repeat - Repeat, - /// replace - Replace, /// reverse Reverse, /// right Right, /// rpad Rpad, - /// split_part - SplitPart, /// strpos Strpos, /// substr @@ -238,12 +232,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Left => Volatility::Immutable, BuiltinScalarFunction::Lpad => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Repeat => Volatility::Immutable, - BuiltinScalarFunction::Replace => Volatility::Immutable, BuiltinScalarFunction::Reverse => Volatility::Immutable, BuiltinScalarFunction::Right => Volatility::Immutable, BuiltinScalarFunction::Rpad => Volatility::Immutable, - BuiltinScalarFunction::SplitPart => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, @@ -293,12 +284,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), - BuiltinScalarFunction::Repeat => { - utf8_to_str_type(&input_expr_types[0], "repeat") - } - BuiltinScalarFunction::Replace => { - utf8_to_str_type(&input_expr_types[0], "replace") - } BuiltinScalarFunction::Reverse => { utf8_to_str_type(&input_expr_types[0], "reverse") } @@ -306,9 +291,6 @@ impl BuiltinScalarFunction { utf8_to_str_type(&input_expr_types[0], "right") } BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), - BuiltinScalarFunction::SplitPart => { - utf8_to_str_type(&input_expr_types[0], "split_part") - } BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") @@ -417,21 +399,12 @@ impl BuiltinScalarFunction { self.volatility(), ) } - BuiltinScalarFunction::Left - | BuiltinScalarFunction::Repeat - | BuiltinScalarFunction::Right => Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], - self.volatility(), - ), - BuiltinScalarFunction::SplitPart => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, Utf8, Int64]), - 
Exact(vec![Utf8, LargeUtf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - self.volatility(), - ), + BuiltinScalarFunction::Left | BuiltinScalarFunction::Right => { + Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + self.volatility(), + ) + } BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { Signature::one_of( @@ -467,7 +440,7 @@ impl BuiltinScalarFunction { self.volatility(), ), - BuiltinScalarFunction::Replace | BuiltinScalarFunction::Translate => { + BuiltinScalarFunction::Translate => { Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) } BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), @@ -637,12 +610,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => &["initcap"], BuiltinScalarFunction::Left => &["left"], BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Repeat => &["repeat"], - BuiltinScalarFunction::Replace => &["replace"], BuiltinScalarFunction::Reverse => &["reverse"], BuiltinScalarFunction::Right => &["right"], BuiltinScalarFunction::Rpad => &["rpad"], - BuiltinScalarFunction::SplitPart => &["split_part"], BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::Translate => &["translate"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 60db21e5f5fe..f75d8869671e 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -598,11 +598,8 @@ scalar_expr!( ); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); -scalar_expr!(Replace, replace, string from to, "replaces all occurrences of `from` with `to` in the `string`"); -scalar_expr!(Repeat, repeat, string n, "repeats the `string` to `n` times"); scalar_expr!(Reverse, reverse, string, "reverses the `string`"); scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`"); -scalar_expr!(SplitPart, split_part, string delimiter index, "splits a string based on a delimiter and picks out the desired field based on the index."); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); @@ -1056,13 +1053,10 @@ mod test { test_scalar_expr!(Left, left, string, count); test_nary_scalar_expr!(Lpad, lpad, string, count); test_nary_scalar_expr!(Lpad, lpad, string, count, characters); - test_scalar_expr!(Replace, replace, string, from, to); - test_scalar_expr!(Repeat, repeat, string, count); test_scalar_expr!(Reverse, reverse, string); test_scalar_expr!(Right, right, string, count); test_nary_scalar_expr!(Rpad, rpad, string, count); test_nary_scalar_expr!(Rpad, rpad, string, count, characters); - test_scalar_expr!(SplitPart, split_part, expr, delimiter, index); test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index 165a7c660404..d2b9fb2da805 100644 --- a/datafusion/functions/src/string/mod.rs +++ 
b/datafusion/functions/src/string/mod.rs @@ -29,7 +29,10 @@ mod lower; mod ltrim; mod octet_length; mod overlay; +mod repeat; +mod replace; mod rtrim; +mod split_part; mod starts_with; mod to_hex; mod upper; @@ -43,8 +46,11 @@ make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); make_udf_function!(lower::LowerFunc, LOWER, lower); make_udf_function!(octet_length::OctetLengthFunc, OCTET_LENGTH, octet_length); make_udf_function!(overlay::OverlayFunc, OVERLAY, overlay); +make_udf_function!(repeat::RepeatFunc, REPEAT, repeat); +make_udf_function!(replace::ReplaceFunc, REPLACE, replace); make_udf_function!(rtrim::RtrimFunc, RTRIM, rtrim); make_udf_function!(starts_with::StartsWithFunc, STARTS_WITH, starts_with); +make_udf_function!(split_part::SplitPartFunc, SPLIT_PART, split_part); make_udf_function!(to_hex::ToHexFunc, TO_HEX, to_hex); make_udf_function!(upper::UpperFunc, UPPER, upper); make_udf_function!(uuid::UuidFunc, UUID, uuid); @@ -87,11 +93,26 @@ pub mod expr_fn { super::overlay().call(args) } + #[doc = "Repeats the `string` to `n` times"] + pub fn repeat(string: Expr, n: Expr) -> Expr { + super::repeat().call(vec![string, n]) + } + + #[doc = "Replaces all occurrences of `from` with `to` in the `string`"] + pub fn replace(string: Expr, from: Expr, to: Expr) -> Expr { + super::replace().call(vec![string, from, to]) + } + #[doc = "Removes all characters, spaces by default, from the end of a string"] pub fn rtrim(args: Vec) -> Expr { super::rtrim().call(args) } + #[doc = "Splits a string based on a delimiter and picks out the desired field based on the index."] + pub fn split_part(string: Expr, delimiter: Expr, index: Expr) -> Expr { + super::split_part().call(vec![string, delimiter, index]) + } + #[doc = "Returns true if string starts with prefix."] pub fn starts_with(arg1: Expr, arg2: Expr) -> Expr { super::starts_with().call(vec![arg1, arg2]) @@ -128,7 +149,10 @@ pub fn functions() -> Vec> { ltrim(), octet_length(), overlay(), + repeat(), + replace(), rtrim(), + split_part(), starts_with(), to_hex(), upper(), diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs new file mode 100644 index 000000000000..83bc929cb9a4 --- /dev/null +++ b/datafusion/functions/src/string/repeat.rs @@ -0,0 +1,144 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
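Before the repeat implementation below, a hedged sketch of how the relocated functions are reached through the expr_fn wrappers registered in mod.rs above. The crate and import paths are assumptions based on this patch's module layout, not verified against a published release:

use datafusion_expr::{col, lit, Expr};
use datafusion_functions::string::expr_fn::{repeat, replace, split_part};

fn main() {
    // Build logical expressions with the relocated UDF wrappers.
    let exprs: Vec<Expr> = vec![
        repeat(col("s"), lit(4)),                 // repeat(s, 4)
        replace(col("s"), lit("cd"), lit("XX")),  // replace(s, 'cd', 'XX')
        split_part(col("s"), lit("~@~"), lit(2)), // split_part(s, '~@~', 2)
    ];
    for e in &exprs {
        println!("{e}");
    }
}
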
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +#[derive(Debug)] +pub(super) struct RepeatFunc { + signature: Signature, +} + +impl RepeatFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for RepeatFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "repeat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "repeat") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(repeat::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(repeat::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function repeat"), + } + } +} + +/// Repeats string the specified number of times. +/// repeat('Pg', 4) = 'PgPgPgPg' +fn repeat(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let number_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(number_array.iter()) + .map(|(string, number)| match (string, number) { + (Some(string), Some(number)) => Some(string.repeat(number as usize)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::Result; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::common::test::test_function; + use crate::string::repeat::RepeatFunc; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(Some("PgPgPgPg")), + &str, + Utf8, + StringArray + ); + + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + RepeatFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs new file mode 100644 index 000000000000..e35244296090 --- /dev/null +++ b/datafusion/functions/src/string/replace.rs @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +#[derive(Debug)] +pub(super) struct ReplaceFunc { + signature: Signature, +} + +impl ReplaceFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ReplaceFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "replace" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "replace") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(replace::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(replace::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function replace") + } + } + } +} + +/// Replaces all occurrences in string of substring from with substring to. +/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef' +fn replace(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let from_array = as_generic_string_array::(&args[1])?; + let to_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +mod test {} diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs new file mode 100644 index 000000000000..af201e90fcf6 --- /dev/null +++ b/datafusion/functions/src/string/split_part.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
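The split_part implementation that follows keeps PostgreSQL's one-based field numbering: index 1 is the first field, an index past the last field yields an empty string, and a non-positive index is an error. A standalone sketch of those semantics, independent of arrow (the helper name is illustrative only):

fn split_part_scalar(s: &str, delimiter: &str, n: i64) -> Result<String, String> {
    if n <= 0 {
        // Mirrors the kernel's error for non-positive positions.
        return Err("field position must be greater than zero".to_string());
    }
    let fields: Vec<&str> = s.split(delimiter).collect();
    // Fields are numbered from one; out-of-range yields an empty string.
    Ok(fields.get(n as usize - 1).copied().unwrap_or("").to_string())
}

fn main() {
    assert_eq!(split_part_scalar("abc~@~def~@~ghi", "~@~", 2).unwrap(), "def");
    assert_eq!(split_part_scalar("abc~@~def~@~ghi", "~@~", 20).unwrap(), "");
    assert!(split_part_scalar("abc~@~def", "~@~", 0).is_err());
}
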
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +#[derive(Debug)] +pub(super) struct SplitPartFunc { + signature: Signature, +} + +impl SplitPartFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, Utf8, Int64]), + Exact(vec![Utf8, LargeUtf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SplitPartFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "split_part" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "split_part") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(split_part::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(split_part::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function split_part") + } + } + } +} + +/// Splits string at occurrences of delimiter and returns the n'th field (counting from one). +/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' +fn split_part(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let delimiter_array = as_generic_string_array::(&args[1])?; + let n_array = as_int64_array(&args[2])?; + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(n_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + if n <= 0 { + exec_err!("field position must be greater than zero") + } else { + let split_string: Vec<&str> = string.split(delimiter).collect(); + match split_string.get(n as usize - 1) { + Some(s) => Ok(Some(*s)), + None => Ok(Some("")), + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::ScalarValue; + use datafusion_common::{exec_err, Result}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::string::common::test::test_function; + use crate::string::split_part::SplitPartFunc; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + SplitPartFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ], + Ok(Some("def")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + )))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SplitPartFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + "abc~@~def~@~ghi" + 
)))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ], + exec_err!("field position must be greater than zero"), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 8759adc89b40..163598c2df82 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -30,17 +30,16 @@ //! an argument i32 is passed to a function that supports f64, the //! argument is automatically is coerced to f64. -use crate::sort_properties::SortProperties; -use crate::{ - conditional_expressions, math_expressions, string_expressions, PhysicalExpr, - ScalarFunctionExpr, -}; +use std::ops::Neg; +use std::sync::Arc; + use arrow::{ array::ArrayRef, compute::kernels::length::bit_length, datatypes::{DataType, Int32Type, Int64Type, Schema}, }; use arrow_array::Array; + use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; pub use datafusion_expr::FuncMonotonicity; @@ -49,8 +48,12 @@ use datafusion_expr::{ type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, }; -use std::ops::Neg; -use std::sync::Arc; + +use crate::sort_properties::SortProperties; +use crate::{ + conditional_expressions, math_expressions, string_expressions, PhysicalExpr, + ScalarFunctionExpr, +}; /// Create a physical (function) expression. /// This function errors when `args`' can't be coerced to a valid argument type of the function. @@ -328,26 +331,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function lpad"), }), - BuiltinScalarFunction::Repeat => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::repeat::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::repeat::)(args) - } - other => exec_err!("Unsupported data type {other:?} for function repeat"), - }), - BuiltinScalarFunction::Replace => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::replace::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::replace::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function replace") - } - }), BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { let func = @@ -387,17 +370,6 @@ pub fn create_physical_fun( } other => exec_err!("Unsupported data type {other:?} for function rpad"), }), - BuiltinScalarFunction::SplitPart => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - make_scalar_function_inner(string_expressions::split_part::)(args) - } - DataType::LargeUtf8 => { - make_scalar_function_inner(string_expressions::split_part::)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function split_part") - } - }), BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::ends_with::)(args) @@ -568,9 +540,6 @@ fn func_order_in_one_dimension( #[cfg(test)] mod tests { - use super::*; - use crate::expressions::lit; - use crate::expressions::try_cast; use arrow::{ array::{ Array, ArrayRef, BooleanArray, Float32Array, Float64Array, Int32Array, @@ -579,12 +548,18 @@ mod tests { datatypes::Field, 
record_batch::RecordBatch, }; + use datafusion_common::cast::as_uint64_array; use datafusion_common::{exec_err, internal_err, plan_err}; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::Signature; + use crate::expressions::lit; + use crate::expressions::try_cast; + + use super::*; + /// $FUNC function to test /// $ARGS arguments (vec) to pass to function /// $EXPECTED a Result> where Result allows testing errors and Option allows testing Null @@ -1124,33 +1099,6 @@ mod tests { Utf8, StringArray ); - test_function!( - Repeat, - &[lit("Pg"), lit(ScalarValue::Int64(Some(4))),], - Ok(Some("PgPgPgPg")), - &str, - Utf8, - StringArray - ); - test_function!( - Repeat, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(4))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Repeat, - &[lit("Pg"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); #[cfg(feature = "unicode_expressions")] test_function!( Reverse, @@ -1447,42 +1395,6 @@ mod tests { Utf8, StringArray ); - test_function!( - SplitPart, - &[ - lit("abc~@~def~@~ghi"), - lit("~@~"), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("def")), - &str, - Utf8, - StringArray - ); - test_function!( - SplitPart, - &[ - lit("abc~@~def~@~ghi"), - lit("~@~"), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - SplitPart, - &[ - lit("abc~@~def~@~ghi"), - lit("~@~"), - lit(ScalarValue::Int64(Some(-1))), - ], - exec_err!("field position must be greater than zero"), - &str, - Utf8, - StringArray - ); test_function!( EndsWith, &[lit("alphabet"), lit("alph"),], @@ -1812,7 +1724,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); // pick some arbitrary functions to test - let funs = [BuiltinScalarFunction::Concat, BuiltinScalarFunction::Repeat]; + let funs = [BuiltinScalarFunction::Concat]; for fun in funs.iter() { let expr = create_physical_expr_with_type_coercion( diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 766e167a9426..812b746354a4 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -242,73 +242,6 @@ pub fn instr(args: &[ArrayRef]) -> Result { } } -/// Repeats string the specified number of times. -/// repeat('Pg', 4) = 'PgPgPgPg' -pub fn repeat(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let number_array = as_int64_array(&args[1])?; - - let result = string_array - .iter() - .zip(number_array.iter()) - .map(|(string, number)| match (string, number) { - (Some(string), Some(number)) => Some(string.repeat(number as usize)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Replaces all occurrences in string of substring from with substring to. 
-/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef' -pub fn replace(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let from_array = as_generic_string_array::(&args[1])?; - let to_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .map(|((string, from), to)| match (string, from, to) { - (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Splits string at occurrences of delimiter and returns the n'th field (counting from one). -/// split_part('abc~@~def~@~ghi', '~@~', 2) = 'def' -pub fn split_part(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let n_array = as_int64_array(&args[2])?; - let result = string_array - .iter() - .zip(delimiter_array.iter()) - .zip(n_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - if n <= 0 { - exec_err!("field position must be greater than zero") - } else { - let split_string: Vec<&str> = string.split(delimiter).collect(); - match split_string.get(n as usize - 1) { - Some(s) => Ok(Some(*s)), - None => Ok(Some("")), - } - } - } - _ => Ok(None), - }) - .collect::>>()?; - - Ok(Arc::new(result) as ArrayRef) -} - /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' pub fn starts_with(args: &[ArrayRef]) -> Result { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 795995ce2c46..297e355dd7b1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -581,8 +581,8 @@ enum ScalarFunction { // 37 was OctetLength Random = 38; // 39 was RegexpReplace - Repeat = 40; - Replace = 41; + // 40 was Repeat + // 41 was Replace Reverse = 42; Right = 43; Rpad = 44; @@ -591,7 +591,7 @@ enum ScalarFunction { // 47 was SHA256 // 48 was SHA384 // 49 was SHA512 - SplitPart = 50; + // 50 was SplitPart // StartsWith = 51; Strpos = 52; Substr = 53; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 3941171e4fe6..dce815f0f234 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22937,12 +22937,9 @@ impl serde::Serialize for ScalarFunction { Self::Left => "Left", Self::Lpad => "Lpad", Self::Random => "Random", - Self::Repeat => "Repeat", - Self::Replace => "Replace", Self::Reverse => "Reverse", Self::Right => "Right", Self::Rpad => "Rpad", - Self::SplitPart => "SplitPart", Self::Strpos => "Strpos", Self::Substr => "Substr", Self::Translate => "Translate", @@ -23002,12 +22999,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Left", "Lpad", "Random", - "Repeat", - "Replace", "Reverse", "Right", "Rpad", - "SplitPart", "Strpos", "Substr", "Translate", @@ -23096,12 +23090,9 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Left" => Ok(ScalarFunction::Left), "Lpad" => Ok(ScalarFunction::Lpad), "Random" => Ok(ScalarFunction::Random), - "Repeat" => Ok(ScalarFunction::Repeat), - "Replace" => Ok(ScalarFunction::Replace), "Reverse" => Ok(ScalarFunction::Reverse), "Right" => Ok(ScalarFunction::Right), "Rpad" => Ok(ScalarFunction::Rpad), - "SplitPart" => Ok(ScalarFunction::SplitPart), "Strpos" => Ok(ScalarFunction::Strpos), "Substr" => 
Ok(ScalarFunction::Substr), "Translate" => Ok(ScalarFunction::Translate), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 58fda7fcb5ad..2292687b45a6 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2880,8 +2880,8 @@ pub enum ScalarFunction { /// 37 was OctetLength Random = 38, /// 39 was RegexpReplace - Repeat = 40, - Replace = 41, + /// 40 was Repeat + /// 41 was Replace Reverse = 42, Right = 43, Rpad = 44, @@ -2890,7 +2890,7 @@ pub enum ScalarFunction { /// 47 was SHA256 /// 48 was SHA384 /// 49 was SHA512 - SplitPart = 50, + /// 50 was SplitPart /// StartsWith = 51; Strpos = 52, Substr = 53, @@ -3010,12 +3010,9 @@ impl ScalarFunction { ScalarFunction::Left => "Left", ScalarFunction::Lpad => "Lpad", ScalarFunction::Random => "Random", - ScalarFunction::Repeat => "Repeat", - ScalarFunction::Replace => "Replace", ScalarFunction::Reverse => "Reverse", ScalarFunction::Right => "Right", ScalarFunction::Rpad => "Rpad", - ScalarFunction::SplitPart => "SplitPart", ScalarFunction::Strpos => "Strpos", ScalarFunction::Substr => "Substr", ScalarFunction::Translate => "Translate", @@ -3069,12 +3066,9 @@ impl ScalarFunction { "Left" => Some(Self::Left), "Lpad" => Some(Self::Lpad), "Random" => Some(Self::Random), - "Repeat" => Some(Self::Repeat), - "Replace" => Some(Self::Replace), "Reverse" => Some(Self::Reverse), "Right" => Some(Self::Right), "Rpad" => Some(Self::Rpad), - "SplitPart" => Some(Self::SplitPart), "Strpos" => Some(Self::Strpos), "Substr" => Some(Self::Substr), "Translate" => Some(Self::Translate), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 3b44c1cb276d..b78e3ae6dc61 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -53,11 +53,10 @@ use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lpad, nanvl, pi, power, radians, random, repeat, replace, reverse, right, round, - rpad, signum, sin, sinh, split_part, sqrt, strpos, substr, substr_index, substring, - translate, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, - BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, - GroupingSet, + lpad, nanvl, pi, power, radians, random, reverse, right, round, rpad, signum, sin, + sinh, sqrt, strpos, substr, substr_index, substring, translate, trunc, + AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, + Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -468,12 +467,9 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Left => Self::Left, ScalarFunction::Lpad => Self::Lpad, ScalarFunction::Random => Self::Random, - ScalarFunction::Repeat => Self::Repeat, - ScalarFunction::Replace => Self::Replace, ScalarFunction::Reverse => Self::Reverse, ScalarFunction::Right => Self::Right, ScalarFunction::Rpad => Self::Rpad, - ScalarFunction::SplitPart => Self::SplitPart, ScalarFunction::Strpos => Self::Strpos, ScalarFunction::Substr => Self::Substr, ScalarFunction::Translate => Self::Translate, @@ -1445,15 +1441,6 @@ pub fn parse_expr( parse_expr(&args[1], registry, codec)?, )), 
        ScalarFunction::Random => Ok(random()),
-        ScalarFunction::Repeat => Ok(repeat(
-            parse_expr(&args[0], registry, codec)?,
-            parse_expr(&args[1], registry, codec)?,
-        )),
-        ScalarFunction::Replace => Ok(replace(
-            parse_expr(&args[0], registry, codec)?,
-            parse_expr(&args[1], registry, codec)?,
-            parse_expr(&args[2], registry, codec)?,
-        )),
         ScalarFunction::Reverse => {
             Ok(reverse(parse_expr(&args[0], registry, codec)?))
         }
@@ -1485,11 +1472,6 @@ pub fn parse_expr(
                 .map(|expr| parse_expr(expr, registry, codec))
                 .collect::, _>>()?,
         )),
-        ScalarFunction::SplitPart => Ok(split_part(
-            parse_expr(&args[0], registry, codec)?,
-            parse_expr(&args[1], registry, codec)?,
-            parse_expr(&args[2], registry, codec)?,
-        )),
         ScalarFunction::EndsWith => Ok(ends_with(
             parse_expr(&args[0], registry, codec)?,
             parse_expr(&args[1], registry, codec)?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 446a91a39a1b..0c0f0c6e0a92 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1490,12 +1490,9 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
            BuiltinScalarFunction::Left => Self::Left,
            BuiltinScalarFunction::Lpad => Self::Lpad,
            BuiltinScalarFunction::Random => Self::Random,
-            BuiltinScalarFunction::Repeat => Self::Repeat,
-            BuiltinScalarFunction::Replace => Self::Replace,
            BuiltinScalarFunction::Reverse => Self::Reverse,
            BuiltinScalarFunction::Right => Self::Right,
            BuiltinScalarFunction::Rpad => Self::Rpad,
-            BuiltinScalarFunction::SplitPart => Self::SplitPart,
            BuiltinScalarFunction::Strpos => Self::Strpos,
            BuiltinScalarFunction::Substr => Self::Substr,
            BuiltinScalarFunction::Translate => Self::Translate,

From 916d4dbcf7e9d70d722b8fc662aef738f61b1409 Mon Sep 17 00:00:00 2001
From: Eren Avsarogullari
Date: Sun, 24 Mar 2024 15:37:57 -0700
Subject: [PATCH 059/117] Issue-9767 - Extract array_dims, array_ndims and
 flatten functions from functions-array subcrate's kernels and udf containers
 (#9786)

---
 .../src/{udf.rs => dimension.rs}              | 137 ++++++------
 datafusion/functions-array/src/except.rs      |   2 +-
 .../src/{kernels.rs => flatten.rs}            | 195 +++++++++---------
 datafusion/functions-array/src/lib.rs         |  16 +-
 4 files changed, 175 insertions(+), 175 deletions(-)
 rename datafusion/functions-array/src/{udf.rs => dimension.rs} (56%)
 rename datafusion/functions-array/src/{kernels.rs => flatten.rs} (55%)

diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/dimension.rs
similarity index 56%
rename from datafusion/functions-array/src/udf.rs
rename to datafusion/functions-array/src/dimension.rs
index bdc11155b633..569eff66f7f4 100644
--- a/datafusion/functions-array/src/udf.rs
+++ b/datafusion/functions-array/src/dimension.rs
@@ -15,17 +15,22 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! [`ScalarUDFImpl`] definitions for array functions.
+//! [`ScalarUDFImpl`] definitions for array_dims and array_ndims functions.
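Before the moved implementation continues, a minimal sketch of what the two relocated kernels compute, using a toy nested value in place of Arrow list arrays (in the real code array_ndims is derived from the data type and the per-row dimensions walk the first child, so the recursion below is only illustrative):

    #[derive(Debug)]
    enum Value {
        Int(i64),
        List(Vec<Value>),
    }

    // array_dims([[1, 2], [3, 4]]) = [2, 2]: one length per nesting level.
    fn dims(v: &Value, out: &mut Vec<u64>) {
        if let Value::List(items) = v {
            out.push(items.len() as u64);
            if let Some(first) = items.first() {
                dims(first, out); // follow the first element, as the kernel does
            }
        }
    }

    // array_ndims([[1, 2], [3, 4]]) = 2: the nesting depth.
    fn ndims(v: &Value) -> u64 {
        match v {
            Value::Int(_) => 0,
            Value::List(items) => 1 + items.first().map_or(0, ndims),
        }
    }

    fn main() {
        let v = Value::List(vec![
            Value::List(vec![Value::Int(1), Value::Int(2)]),
            Value::List(vec![Value::Int(3), Value::Int(4)]),
        ]);
        let mut d = Vec::new();
        dims(&v, &mut d);
        assert_eq!(d, vec![2, 2]);
        assert_eq!(ndims(&v), 2);
    }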
-use arrow::datatypes::DataType; -use arrow::datatypes::Field; -use datafusion_common::exec_err; -use datafusion_common::plan_err; -use datafusion_common::Result; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use arrow::array::{ + Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, +}; +use arrow::datatypes::{DataType, UInt64Type}; use std::any::Any; + +use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::{exec_err, plan_err, Result}; + +use crate::utils::{compute_array_dims, make_scalar_function}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; +use arrow_schema::Field; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::sync::Arc; make_udf_function!( @@ -64,7 +69,6 @@ impl ScalarUDFImpl for ArrayDims { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; Ok(match arg_types[0] { List(_) | LargeList(_) | FixedSizeList(_, _) => { List(Arc::new(Field::new("item", UInt64, true))) @@ -76,8 +80,7 @@ impl ScalarUDFImpl for ArrayDims { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_dims(&args).map(ColumnarValue::Array) + make_scalar_function(array_dims_inner)(args) } fn aliases(&self) -> &[String] { @@ -120,7 +123,6 @@ impl ScalarUDFImpl for ArrayNdims { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; Ok(match arg_types[0] { List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, _ => { @@ -130,8 +132,7 @@ impl ScalarUDFImpl for ArrayNdims { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::array_ndims(&args).map(ColumnarValue::Array) + make_scalar_function(array_ndims_inner)(args) } fn aliases(&self) -> &[String] { @@ -139,70 +140,68 @@ impl ScalarUDFImpl for ArrayNdims { } } -make_udf_function!( - Flatten, - flatten, - array, - "flattens an array of arrays into a single array.", - flatten_udf -); +/// Array_dims SQL function +pub fn array_dims_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_dims needs one argument"); + } -#[derive(Debug)] -pub(super) struct Flatten { - signature: Signature, - aliases: Vec, -} -impl Flatten { - pub fn new() -> Self { - Self { - signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("flatten")], + let data = match args[0].data_type() { + List(_) => { + let array = as_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? } - } + LargeList(_) => { + let array = as_large_list_array(&args[0])?; + array + .iter() + .map(compute_array_dims) + .collect::>>()? 
+ } + array_type => { + return exec_err!("array_dims does not support type '{array_type:?}'"); + } + }; + + let result = ListArray::from_iter_primitive::(data); + + Ok(Arc::new(result) as ArrayRef) } -impl ScalarUDFImpl for Flatten { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "flatten" +/// Array_ndims SQL function +pub fn array_ndims_inner(args: &[ArrayRef]) -> Result { + if args.len() != 1 { + return exec_err!("array_ndims needs one argument"); } - fn signature(&self) -> &Signature { - &self.signature - } + fn general_list_ndims( + array: &GenericListArray, + ) -> Result { + let mut data = Vec::new(); + let ndims = datafusion_common::utils::list_ndims(array.data_type()); - fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; - fn get_base_type(data_type: &DataType) -> Result { - match data_type { - List(field) | FixedSizeList(field, _) - if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => - { - get_base_type(field.data_type()) - } - LargeList(field) if matches!(field.data_type(), LargeList(_)) => { - get_base_type(field.data_type()) - } - Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), - FixedSizeList(field, _) => Ok(List(field.clone())), - _ => exec_err!( - "Not reachable, data_type should be List, LargeList or FixedSizeList" - ), + for arr in array.iter() { + if arr.is_some() { + data.push(Some(ndims)) + } else { + data.push(None) } } - let data_type = get_base_type(&arg_types[0])?; - Ok(data_type) + Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - let args = ColumnarValue::values_to_arrays(args)?; - crate::kernels::flatten(&args).map(ColumnarValue::Array) - } - - fn aliases(&self) -> &[String] { - &self.aliases + match args[0].data_type() { + List(_) => { + let array = as_list_array(&args[0])?; + general_list_ndims::(array) + } + LargeList(_) => { + let array = as_large_list_array(&args[0])?; + general_list_ndims::(array) + } + array_type => exec_err!("array_ndims does not support type {array_type:?}"), } } diff --git a/datafusion/functions-array/src/except.rs b/datafusion/functions-array/src/except.rs index 1faaf80e69f6..72932d530ad0 100644 --- a/datafusion/functions-array/src/except.rs +++ b/datafusion/functions-array/src/except.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! implementation kernel for array_except function +//! [`ScalarUDFImpl`] definitions for array_except function. use crate::utils::check_datatypes; use arrow::row::{RowConverter, SortField}; diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/flatten.rs similarity index 55% rename from datafusion/functions-array/src/kernels.rs rename to datafusion/functions-array/src/flatten.rs index 1a08b64197a9..27d4b1d5f971 100644 --- a/datafusion/functions-array/src/kernels.rs +++ b/datafusion/functions-array/src/flatten.rs @@ -15,111 +15,124 @@ // specific language governing permissions and limitations // under the License. -//! implementation kernels for array functions +//! [`ScalarUDFImpl`] definitions for flatten function. 
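One detail worth keeping in mind while reading the moved code: flatten removes a nesting level without touching the values, by looking the outer offsets up in the inner offsets (this is what `get_offsets_for_flatten` further down does). A minimal sketch with plain vectors standing in for Arrow offset buffers:

    // For [[ [1], [2, 3] ], [ [4, 5, 6] ]]:
    //   inner offsets over the values      = [0, 1, 3, 6]
    //   outer offsets over the inner lists = [0, 2, 3]
    // Composing them yields the offsets of [[1, 2, 3], [4, 5, 6]].
    fn compose_offsets(inner: &[i32], outer: &[i32]) -> Vec<i32> {
        outer.iter().map(|&i| inner[i as usize]).collect()
    }

    fn main() {
        assert_eq!(compose_offsets(&[0, 1, 3, 6], &[0, 2, 3]), vec![0, 3, 6]);
    }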
-use arrow::array::{ - Array, ArrayRef, GenericListArray, ListArray, OffsetSizeTrait, UInt64Array, -}; -use arrow::datatypes::{DataType, UInt64Type}; +use crate::utils::make_scalar_function; +use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::OffsetBuffer; - +use arrow_schema::DataType; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; use datafusion_common::cast::{ as_generic_list_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, Result}; - -use crate::utils::compute_array_dims; +use datafusion_common::exec_err; +use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; use std::sync::Arc; -/// Array_dims SQL function -pub fn array_dims(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("array_dims needs one argument"); +make_udf_function!( + Flatten, + flatten, + array, + "flattens an array of arrays into a single array.", + flatten_udf +); + +#[derive(Debug)] +pub(super) struct Flatten { + signature: Signature, + aliases: Vec, +} +impl Flatten { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + aliases: vec![String::from("flatten")], + } } +} - let data = match args[0].data_type() { - DataType::List(_) => { - let array = as_list_array(&args[0])?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - DataType::LargeList(_) => { - let array = as_large_list_array(&args[0])?; - array - .iter() - .map(compute_array_dims) - .collect::>>()? - } - array_type => { - return exec_err!("array_dims does not support type '{array_type:?}'"); +impl ScalarUDFImpl for Flatten { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "flatten" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn get_base_type(data_type: &DataType) -> datafusion_common::Result { + match data_type { + List(field) | FixedSizeList(field, _) + if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => + { + get_base_type(field.data_type()) + } + LargeList(field) if matches!(field.data_type(), LargeList(_)) => { + get_base_type(field.data_type()) + } + Null | List(_) | LargeList(_) => Ok(data_type.to_owned()), + FixedSizeList(field, _) => Ok(List(field.clone())), + _ => exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + ), + } } - }; - let result = ListArray::from_iter_primitive::(data); + let data_type = get_base_type(&arg_types[0])?; + Ok(data_type) + } - Ok(Arc::new(result) as ArrayRef) + fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + make_scalar_function(flatten_inner)(args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } } -/// Array_ndims SQL function -pub fn array_ndims(args: &[ArrayRef]) -> Result { +/// Flatten SQL function +pub fn flatten_inner(args: &[ArrayRef]) -> datafusion_common::Result { if args.len() != 1 { - return exec_err!("array_ndims needs one argument"); + return exec_err!("flatten expects one argument"); } - fn general_list_ndims( - array: &GenericListArray, - ) -> Result { - let mut data = Vec::new(); - let ndims = datafusion_common::utils::list_ndims(array.data_type()); - - for arr in array.iter() { - if arr.is_some() { - data.push(Some(ndims)) - } else { - data.push(None) - } + let array_type = args[0].data_type(); + match array_type { + List(_) => { + let list_arr 
= as_list_array(&args[0])?; + let flattened_array = flatten_internal::(list_arr.clone(), None)?; + Ok(Arc::new(flattened_array) as ArrayRef) } - - Ok(Arc::new(UInt64Array::from(data)) as ArrayRef) - } - match args[0].data_type() { - DataType::List(_) => { - let array = as_list_array(&args[0])?; - general_list_ndims::(array) + LargeList(_) => { + let list_arr = as_large_list_array(&args[0])?; + let flattened_array = flatten_internal::(list_arr.clone(), None)?; + Ok(Arc::new(flattened_array) as ArrayRef) } - DataType::LargeList(_) => { - let array = as_large_list_array(&args[0])?; - general_list_ndims::(array) + Null => Ok(args[0].clone()), + _ => { + exec_err!("flatten does not support type '{array_type:?}'") } - array_type => exec_err!("array_ndims does not support type {array_type:?}"), } } -// Create new offsets that are euqiavlent to `flatten` the array. -fn get_offsets_for_flatten( - offsets: OffsetBuffer, - indexes: OffsetBuffer, -) -> OffsetBuffer { - let buffer = offsets.into_inner(); - let offsets: Vec = indexes - .iter() - .map(|i| buffer[i.to_usize().unwrap()]) - .collect(); - OffsetBuffer::new(offsets.into()) -} - fn flatten_internal( list_arr: GenericListArray, indexes: Option>, -) -> Result> { +) -> datafusion_common::Result> { let (field, offsets, values, _) = list_arr.clone().into_parts(); let data_type = field.data_type(); match data_type { // Recursively get the base offsets for flattened array - DataType::List(_) | DataType::LargeList(_) => { + List(_) | LargeList(_) => { let sub_list = as_generic_list_array::(&values)?; if let Some(indexes) = indexes { let offsets = get_offsets_for_flatten(offsets, indexes); @@ -141,27 +154,15 @@ fn flatten_internal( } } -/// Flatten SQL function -pub fn flatten(args: &[ArrayRef]) -> Result { - if args.len() != 1 { - return exec_err!("flatten expects one argument"); - } - - let array_type = args[0].data_type(); - match array_type { - DataType::List(_) => { - let list_arr = as_list_array(&args[0])?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) - } - DataType::LargeList(_) => { - let list_arr = as_large_list_array(&args[0])?; - let flattened_array = flatten_internal::(list_arr.clone(), None)?; - Ok(Arc::new(flattened_array) as ArrayRef) - } - DataType::Null => Ok(args[0].clone()), - _ => { - exec_err!("flatten does not support type '{array_type:?}'") - } - } +// Create new offsets that are equivalent to `flatten` the array. 
+fn get_offsets_for_flatten( + offsets: OffsetBuffer, + indexes: OffsetBuffer, +) -> OffsetBuffer { + let buffer = offsets.into_inner(); + let offsets: Vec = indexes + .iter() + .map(|i| buffer[i.to_usize().unwrap()]) + .collect(); + OffsetBuffer::new(offsets.into()) } diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index feecd18c2e8d..30a63deee0e3 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -32,10 +32,11 @@ mod array_has; mod cardinality; mod concat; mod core; +mod dimension; mod empty; mod except; mod extract; -mod kernels; +mod flatten; mod length; mod position; mod range; @@ -48,7 +49,6 @@ mod rewrite; mod set_ops; mod sort; mod string; -mod udf; mod utils; use datafusion_common::Result; @@ -67,12 +67,15 @@ pub mod expr_fn { pub use super::concat::array_concat; pub use super::concat::array_prepend; pub use super::core::make_array; + pub use super::dimension::array_dims; + pub use super::dimension::array_ndims; pub use super::empty::array_empty; pub use super::except::array_except; pub use super::extract::array_element; pub use super::extract::array_pop_back; pub use super::extract::array_pop_front; pub use super::extract::array_slice; + pub use super::flatten::flatten; pub use super::length::array_length; pub use super::position::array_position; pub use super::position::array_positions; @@ -93,9 +96,6 @@ pub mod expr_fn { pub use super::sort::array_sort; pub use super::string::array_to_string; pub use super::string::string_to_array; - pub use super::udf::array_dims; - pub use super::udf::array_ndims; - pub use super::udf::flatten; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -105,9 +105,9 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { string::string_to_array_udf(), range::range_udf(), range::gen_series_udf(), - udf::array_dims_udf(), + dimension::array_dims_udf(), cardinality::cardinality_udf(), - udf::array_ndims_udf(), + dimension::array_ndims_udf(), concat::array_append_udf(), concat::array_prepend_udf(), concat::array_concat_udf(), @@ -122,7 +122,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { array_has::array_has_any_udf(), empty::array_empty_udf(), length::array_length_udf(), - udf::flatten_udf(), + flatten::flatten_udf(), sort::array_sort_udf(), repeat::array_repeat_udf(), resize::array_resize_udf(), From c5faaf7f22a715bf79cb1289f2b5c15131f95ecb Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 24 Mar 2024 18:38:26 -0400 Subject: [PATCH 060/117] Minor: Improve documentation about `ColumnarValues::values_to_array` (#9774) * Minor: Improve documentation about `ColumnarValues::values_to_array` * Apply suggestions from code review Co-authored-by: Liang-Chi Hsieh --------- Co-authored-by: Liang-Chi Hsieh --- datafusion/expr/src/columnar_value.rs | 14 +++++++++++--- datafusion/expr/src/udf.rs | 6 ++++-- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/columnar_value.rs b/datafusion/expr/src/columnar_value.rs index 831edc078d6a..87c3c063b91a 100644 --- a/datafusion/expr/src/columnar_value.rs +++ b/datafusion/expr/src/columnar_value.rs @@ -26,11 +26,14 @@ use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; /// Represents the result of evaluating an expression: either a single -/// `ScalarValue` or an [`ArrayRef`]. +/// [`ScalarValue`] or an [`ArrayRef`]. 
/// /// While a [`ColumnarValue`] can always be converted into an array /// for convenience, it is often much more performant to provide an /// optimized path for scalar values. +/// +/// See [`ColumnarValue::values_to_arrays`] for a function that converts +/// multiple columnar values into arrays of the same length. #[derive(Clone, Debug)] pub enum ColumnarValue { /// Array of values @@ -59,8 +62,13 @@ impl ColumnarValue { } } - /// Convert a columnar value into an ArrayRef. [`Self::Scalar`] is - /// converted by repeating the same scalar multiple times. + /// Convert a columnar value into an Arrow [`ArrayRef`] with the specified + /// number of rows. [`Self::Scalar`] is converted by repeating the same + /// scalar multiple times which is not as efficient as handling the scalar + /// directly. + /// + /// See [`Self::values_to_arrays`] to convert multiple columnar values into + /// arrays of the same length. /// /// # Errors /// diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3002a745055f..56266a05170b 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -326,8 +326,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// For the best performance, the implementations of `invoke` should handle /// the common case when one or more of their arguments are constant values - /// (aka [`ColumnarValue::Scalar`]). Calling [`ColumnarValue::into_array`] - /// and treating all arguments as arrays will work, but will be slower. + /// (aka [`ColumnarValue::Scalar`]). + /// + /// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments + /// to arrays, which will likely be simpler code, but be slower. fn invoke(&self, args: &[ColumnarValue]) -> Result; /// Returns any aliases (alternate names) for this function. From bd9b33ceca3553431b6b328de93f1836bbd9e263 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 24 Mar 2024 18:38:48 -0400 Subject: [PATCH 061/117] Fix panic in `struct` function with mixed scalar/array arguments (#9775) --- datafusion/functions/src/core/struct.rs | 11 ++--------- datafusion/sqllogictest/test_files/struct.slt | 9 +++++++++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 2a8622f0a1ec..ac300e0abde3 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -47,17 +47,10 @@ fn array_struct(args: &[ArrayRef]) -> Result { Ok(Arc::new(StructArray::from(vec))) } + /// put values in a struct array. fn struct_expr(args: &[ColumnarValue]) -> Result { - let arrays = args - .iter() - .map(|x| { - Ok(match x { - ColumnarValue::Array(array) => array.clone(), - ColumnarValue::Scalar(scalar) => scalar.to_array()?.clone(), - }) - }) - .collect::>>()?; + let arrays = ColumnarValue::values_to_arrays(args)?; Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } #[derive(Debug)] diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 936dedcc896e..1ab6f3908b53 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -58,6 +58,15 @@ select struct(a, b, c) from values; {c0: 2, c1: 2.2, c2: b} {c0: 3, c1: 3.3, c2: c} +# struct scalar function with columns and scalars +query ? 
+select struct(a, 'foo') from values;
+----
+{c0: 1, c1: foo}
+{c0: 2, c1: foo}
+{c0: 3, c1: foo}
+
+
 # explain struct scalar function with columns #1
 query TT
 explain select struct(a, b, c) from values;

From 8ebff9e9fb7ac365cc3be687f42f6315c2303fe4 Mon Sep 17 00:00:00 2001
From: Eren Avsarogullari
Date: Mon, 25 Mar 2024 04:19:04 -0700
Subject: [PATCH 062/117] refactor: Apply minor refactorings to
 `functions-array` crate (#9788)

* Issue-9787 - Apply minor refactorings to functions-array crate

* Issue-9787 - Clean-up redundant datafusion_common::Result definitions

* Issue-9787 - Addressed review comment

---
 datafusion/functions-array/src/array_has.rs   | 14 ++---
 datafusion/functions-array/src/cardinality.rs | 11 ++--
 datafusion/functions-array/src/concat.rs      | 10 ++--
 datafusion/functions-array/src/empty.rs       | 12 ++---
 datafusion/functions-array/src/except.rs      | 15 +++---
 datafusion/functions-array/src/extract.rs     | 37 ++++++-------
 datafusion/functions-array/src/flatten.rs     | 12 ++---
 datafusion/functions-array/src/length.rs      | 16 +++---
 datafusion/functions-array/src/lib.rs         |  6 +--
 .../src/{core.rs => make_array.rs}            | 23 ++++----
 datafusion/functions-array/src/position.rs    | 42 ++++++---------
 datafusion/functions-array/src/range.rs       | 25 +++------
 datafusion/functions-array/src/remove.rs      | 37 ++++++-------
 datafusion/functions-array/src/repeat.rs      | 13 ++---
 datafusion/functions-array/src/replace.rs     | 23 ++++----
 datafusion/functions-array/src/resize.rs      | 15 +++---
 datafusion/functions-array/src/reverse.rs     | 18 ++++---
 datafusion/functions-array/src/rewrite.rs     |  3 +-
 datafusion/functions-array/src/set_ops.rs     | 53 +++++++++----------
 datafusion/functions-array/src/sort.rs        | 13 ++---
 datafusion/functions-array/src/string.rs      | 16 +++---
 21 files changed, 195 insertions(+), 219 deletions(-)
 rename datafusion/functions-array/src/{core.rs => make_array.rs} (92%)

diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs
index 17c0ad1619d6..4e4ebaf035fc 100644
--- a/datafusion/functions-array/src/array_has.rs
+++ b/datafusion/functions-array/src/array_has.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! [`ScalarUDFImpl`] definitions for array functions.
+//! [`ScalarUDFImpl`] definitions for array_has, array_has_all and array_has_any functions.
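A recurring theme in this commit (and in the `struct` fix from patch 061) is normalizing mixed scalar/array arguments to arrays of one common row count before a kernel runs; `make_scalar_function` and `ColumnarValue::values_to_arrays` encapsulate that. A minimal sketch of the normalization, with toy types in place of DataFusion's `ColumnarValue` and `ScalarValue`:

    #[derive(Clone)]
    enum Columnar {
        Array(Vec<i64>),
        Scalar(i64),
    }

    fn values_to_arrays(args: &[Columnar]) -> Result<Vec<Vec<i64>>, String> {
        // The row count comes from the array arguments, which must agree.
        let mut len = None;
        for a in args {
            if let Columnar::Array(v) = a {
                match len {
                    Some(l) if l != v.len() => return Err("length mismatch".into()),
                    _ => len = Some(v.len()),
                }
            }
        }
        let len = len.unwrap_or(1); // all-scalar input behaves like one row
        Ok(args
            .iter()
            .map(|a| match a {
                Columnar::Array(v) => v.clone(),
                // Expanding (instead of producing a length-1 array) is what
                // avoids the row-count mismatch behind the earlier panic.
                Columnar::Scalar(s) => vec![*s; len],
            })
            .collect())
    }

    fn main() {
        let out =
            values_to_arrays(&[Columnar::Array(vec![1, 2, 3]), Columnar::Scalar(7)])
                .unwrap();
        assert_eq!(out, vec![vec![1, 2, 3], vec![7, 7, 7]]);
    }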
use arrow::array::{Array, ArrayRef, BooleanArray, OffsetSizeTrait}; use arrow::datatypes::DataType; @@ -85,11 +85,11 @@ impl ScalarUDFImpl for ArrayHas { &self.signature } - fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, _: &[DataType]) -> Result { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; if args.len() != 2 { @@ -147,11 +147,11 @@ impl ScalarUDFImpl for ArrayHasAll { &self.signature } - fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, _: &[DataType]) -> Result { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; if args.len() != 2 { return exec_err!("array_has_all needs two arguments"); @@ -204,11 +204,11 @@ impl ScalarUDFImpl for ArrayHasAny { &self.signature } - fn return_type(&self, _: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, _: &[DataType]) -> Result { Ok(DataType::Boolean) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; if args.len() != 2 { diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-array/src/cardinality.rs index 483336fe081d..ed9f8d01f973 100644 --- a/datafusion/functions-array/src/cardinality.rs +++ b/datafusion/functions-array/src/cardinality.rs @@ -22,6 +22,7 @@ use arrow_array::{ArrayRef, GenericListArray, OffsetSizeTrait, UInt64Array}; use arrow_schema::DataType; use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; use datafusion_common::cast::{as_large_list_array, as_list_array}; +use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; @@ -62,7 +63,7 @@ impl ScalarUDFImpl for Cardinality { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, _ => { @@ -71,7 +72,7 @@ impl ScalarUDFImpl for Cardinality { }) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(cardinality_inner)(args) } @@ -81,7 +82,7 @@ impl ScalarUDFImpl for Cardinality { } /// Cardinality SQL function -pub fn cardinality_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn cardinality_inner(args: &[ArrayRef]) -> Result { if args.len() != 1 { return exec_err!("cardinality expects one argument"); } @@ -103,13 +104,13 @@ pub fn cardinality_inner(args: &[ArrayRef]) -> datafusion_common::Result( array: &GenericListArray, -) -> datafusion_common::Result { +) -> Result { let result = array .iter() .map(|arr| match crate::utils::compute_array_dims(arr)? 
{ Some(vector) => Ok(Some(vector.iter().map(|x| x.unwrap()).product::())), None => Ok(None), }) - .collect::>()?; + .collect::>()?; Ok(Arc::new(result) as ArrayRef) } diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-array/src/concat.rs index a8e7d1008f46..cb76192e29c2 100644 --- a/datafusion/functions-array/src/concat.rs +++ b/datafusion/functions-array/src/concat.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// Includes `array append`, `array prepend`, and `array concat` functions +//! [`ScalarUDFImpl`] definitions for `array_append`, `array_prepend` and `array_concat` functions. use std::{any::Any, cmp::Ordering, sync::Arc}; @@ -39,7 +39,7 @@ use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function make_udf_function!( ArrayAppend, array_append, - array element, // arg name + array element, // arg name "appends an element to the end of an array.", // doc array_append_udf // internal function name ); @@ -283,9 +283,9 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .collect::>(); // Concatenated array on i-th row - let concated_array = arrow::compute::concat(elements.as_slice())?; - array_lengths.push(concated_array.len()); - arrays.push(concated_array); + let concatenated_array = arrow::compute::concat(elements.as_slice())?; + array_lengths.push(concatenated_array.len()); + arrays.push(concatenated_array); valid.append(true); } } diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-array/src/empty.rs index 37b247deb4c8..f11a6f07cfc8 100644 --- a/datafusion/functions-array/src/empty.rs +++ b/datafusion/functions-array/src/empty.rs @@ -22,7 +22,7 @@ use arrow_array::{ArrayRef, BooleanArray, OffsetSizeTrait}; use arrow_schema::DataType; use arrow_schema::DataType::{Boolean, FixedSizeList, LargeList, List}; use datafusion_common::cast::{as_generic_list_array, as_null_array}; -use datafusion_common::{exec_err, plan_err}; +use datafusion_common::{exec_err, plan_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -62,7 +62,7 @@ impl ScalarUDFImpl for ArrayEmpty { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { List(_) | LargeList(_) | FixedSizeList(_, _) => Boolean, _ => { @@ -71,7 +71,7 @@ impl ScalarUDFImpl for ArrayEmpty { }) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_empty_inner)(args) } @@ -81,7 +81,7 @@ impl ScalarUDFImpl for ArrayEmpty { } /// Array_empty SQL function -pub fn array_empty_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_empty_inner(args: &[ArrayRef]) -> Result { if args.len() != 1 { return exec_err!("array_empty expects one argument"); } @@ -99,9 +99,7 @@ pub fn array_empty_inner(args: &[ArrayRef]) -> datafusion_common::Result( - array: &ArrayRef, -) -> datafusion_common::Result { +fn general_array_empty(array: &ArrayRef) -> Result { let array = as_generic_list_array::(array)?; let builder = array .iter() diff --git a/datafusion/functions-array/src/except.rs b/datafusion/functions-array/src/except.rs index 72932d530ad0..444c7c758771 100644 --- a/datafusion/functions-array/src/except.rs +++ b/datafusion/functions-array/src/except.rs @@ -17,13 +17,13 @@ //! 
[`ScalarUDFImpl`] definitions for array_except function. -use crate::utils::check_datatypes; +use crate::utils::{check_datatypes, make_scalar_function}; use arrow::row::{RowConverter, SortField}; use arrow_array::cast::AsArray; use arrow_array::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, FieldRef}; -use datafusion_common::{exec_err, internal_err}; +use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -66,16 +66,15 @@ impl ScalarUDFImpl for ArrayExcept { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { match (&arg_types[0].clone(), &arg_types[1].clone()) { (DataType::Null, _) | (_, DataType::Null) => Ok(arg_types[0].clone()), (dt, _) => Ok(dt.clone()), } } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_except_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_except_inner)(args) } fn aliases(&self) -> &[String] { @@ -84,7 +83,7 @@ impl ScalarUDFImpl for ArrayExcept { } /// Array_except SQL function -pub fn array_except_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_except_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_except needs two arguments"); } @@ -118,7 +117,7 @@ fn general_except( l: &GenericListArray, r: &GenericListArray, field: &FieldRef, -) -> datafusion_common::Result> { +) -> Result> { let converter = RowConverter::new(vec![SortField::new(l.value_type())])?; let l_values = l.values().to_owned(); diff --git a/datafusion/functions-array/src/extract.rs b/datafusion/functions-array/src/extract.rs index 86eeaea3c9b4..0dbd106b6f18 100644 --- a/datafusion/functions-array/src/extract.rs +++ b/datafusion/functions-array/src/extract.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// Array Element and Array Slice +//! [`ScalarUDFImpl`] definitions for array_element, array_slice, array_pop_front and array_pop_back functions. 
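For reference while reading the extraction code: array_element follows the SQL convention of 1-based positions, negative positions counting from the end, and NULL for anything out of range (`adjusted_array_index` below implements this on Arrow offsets). A minimal sketch on a plain slice:

    fn array_element(list: &[i64], index: i64) -> Option<i64> {
        let len = list.len() as i64;
        // 1-based from the front, negative from the back; 0 is never valid.
        let adjusted = if index < 0 { index + len } else { index - 1 };
        if (0..len).contains(&adjusted) {
            Some(list[adjusted as usize])
        } else {
            None // out of range yields NULL, not an error
        }
    }

    fn main() {
        let l = [10, 20, 30];
        assert_eq!(array_element(&l, 2), Some(20));
        assert_eq!(array_element(&l, -1), Some(30));
        assert_eq!(array_element(&l, 0), None);
        assert_eq!(array_element(&l, 4), None);
    }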
use arrow::array::Array; use arrow::array::ArrayRef; @@ -27,15 +27,14 @@ use arrow::array::MutableArrayData; use arrow::array::OffsetSizeTrait; use arrow::buffer::OffsetBuffer; use arrow::datatypes::DataType; +use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::Field; use datafusion_common::cast::as_int64_array; use datafusion_common::cast::as_large_list_array; use datafusion_common::cast::as_list_array; -use datafusion_common::exec_err; -use datafusion_common::internal_datafusion_err; -use datafusion_common::plan_err; -use datafusion_common::DataFusionError; -use datafusion_common::Result; +use datafusion_common::{ + exec_err, internal_datafusion_err, plan_err, DataFusionError, Result, +}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; @@ -110,7 +109,6 @@ impl ScalarUDFImpl for ArrayElement { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; match &arg_types[0] { List(field) | LargeList(field) @@ -137,18 +135,18 @@ impl ScalarUDFImpl for ArrayElement { /// /// For example: /// > array_element(\[1, 2, 3], 2) -> 2 -fn array_element_inner(args: &[ArrayRef]) -> datafusion_common::Result { +fn array_element_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_element needs two arguments"); } match &args[0].data_type() { - DataType::List(_) => { + List(_) => { let array = as_list_array(&args[0])?; let indexes = as_int64_array(&args[1])?; general_array_element::(array, indexes) } - DataType::LargeList(_) => { + LargeList(_) => { let array = as_large_list_array(&args[0])?; let indexes = as_int64_array(&args[1])?; general_array_element::(array, indexes) @@ -163,7 +161,7 @@ fn array_element_inner(args: &[ArrayRef]) -> datafusion_common::Result fn general_array_element( array: &GenericListArray, indexes: &Int64Array, -) -> datafusion_common::Result +) -> Result where i64: TryInto, { @@ -175,10 +173,7 @@ where let mut mutable = MutableArrayData::with_capacities(vec![&original_data], true, capacity); - fn adjusted_array_index( - index: i64, - len: O, - ) -> datafusion_common::Result> + fn adjusted_array_index(index: i64, len: O) -> Result> where i64: TryInto, { @@ -302,11 +297,11 @@ fn array_slice_inner(args: &[ArrayRef]) -> Result { let array_data_type = args[0].data_type(); match array_data_type { - DataType::List(_) => { + List(_) => { let array = as_list_array(&args[0])?; general_array_slice::(array, from_array, to_array, stride) } - DataType::LargeList(_) => { + LargeList(_) => { let array = as_large_list_array(&args[0])?; let from_array = as_int64_array(&args[1])?; let to_array = as_int64_array(&args[2])?; @@ -545,11 +540,11 @@ impl ScalarUDFImpl for ArrayPopFront { fn array_pop_front_inner(args: &[ArrayRef]) -> Result { let array_data_type = args[0].data_type(); match array_data_type { - DataType::List(_) => { + List(_) => { let array = as_list_array(&args[0])?; general_pop_front_list::(array) } - DataType::LargeList(_) => { + LargeList(_) => { let array = as_large_list_array(&args[0])?; general_pop_front_list::(array) } @@ -627,11 +622,11 @@ fn array_pop_back_inner(args: &[ArrayRef]) -> Result { let array_data_type = args[0].data_type(); match array_data_type { - DataType::List(_) => { + List(_) => { let array = as_list_array(&args[0])?; general_pop_back_list::(array) } - DataType::LargeList(_) => { + LargeList(_) => { let array = as_large_list_array(&args[0])?; general_pop_back_list::(array) } diff --git 
a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-array/src/flatten.rs index 27d4b1d5f971..e2b50c6c02cc 100644 --- a/datafusion/functions-array/src/flatten.rs +++ b/datafusion/functions-array/src/flatten.rs @@ -25,7 +25,7 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; use datafusion_common::cast::{ as_generic_list_array, as_large_list_array, as_list_array, }; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -66,8 +66,8 @@ impl ScalarUDFImpl for Flatten { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { - fn get_base_type(data_type: &DataType) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { + fn get_base_type(data_type: &DataType) -> Result { match data_type { List(field) | FixedSizeList(field, _) if matches!(field.data_type(), List(_) | FixedSizeList(_, _)) => @@ -89,7 +89,7 @@ impl ScalarUDFImpl for Flatten { Ok(data_type) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(flatten_inner)(args) } @@ -99,7 +99,7 @@ impl ScalarUDFImpl for Flatten { } /// Flatten SQL function -pub fn flatten_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn flatten_inner(args: &[ArrayRef]) -> Result { if args.len() != 1 { return exec_err!("flatten expects one argument"); } @@ -126,7 +126,7 @@ pub fn flatten_inner(args: &[ArrayRef]) -> datafusion_common::Result { fn flatten_internal( list_arr: GenericListArray, indexes: Option>, -) -> datafusion_common::Result> { +) -> Result> { let (field, offsets, values, _) = list_arr.clone().into_parts(); let data_type = field.data_type(); diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-array/src/length.rs index e8e361131763..9bbd11950d21 100644 --- a/datafusion/functions-array/src/length.rs +++ b/datafusion/functions-array/src/length.rs @@ -26,7 +26,7 @@ use arrow_schema::DataType::{FixedSizeList, LargeList, List, UInt64}; use core::any::type_name; use datafusion_common::cast::{as_generic_list_array, as_int64_array}; use datafusion_common::DataFusionError; -use datafusion_common::{exec_err, plan_err}; +use datafusion_common::{exec_err, plan_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -66,7 +66,7 @@ impl ScalarUDFImpl for ArrayLength { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(match arg_types[0] { List(_) | LargeList(_) | FixedSizeList(_, _) => UInt64, _ => { @@ -75,7 +75,7 @@ impl ScalarUDFImpl for ArrayLength { }) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_length_inner)(args) } @@ -85,7 +85,7 @@ impl ScalarUDFImpl for ArrayLength { } /// Array_length SQL function -pub fn array_length_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_length_inner(args: &[ArrayRef]) -> Result { if args.len() != 1 && args.len() != 2 { return exec_err!("array_length expects one or two arguments"); } @@ -98,9 +98,7 @@ pub fn array_length_inner(args: &[ArrayRef]) -> 
datafusion_common::Result( - array: &[ArrayRef], -) -> datafusion_common::Result { +fn general_array_length(array: &[ArrayRef]) -> Result { let list_array = as_generic_list_array::(&array[0])?; let dimension = if array.len() == 2 { as_int64_array(&array[1])?.clone() @@ -112,7 +110,7 @@ fn general_array_length( .iter() .zip(dimension.iter()) .map(|(arr, dim)| compute_array_length(arr, dim)) - .collect::>()?; + .collect::>()?; Ok(Arc::new(result) as ArrayRef) } @@ -121,7 +119,7 @@ fn general_array_length( fn compute_array_length( arr: Option, dimension: Option, -) -> datafusion_common::Result> { +) -> Result> { let mut current_dimension: i64 = 1; let mut value = match arr { Some(arr) => arr, diff --git a/datafusion/functions-array/src/lib.rs b/datafusion/functions-array/src/lib.rs index 30a63deee0e3..7c261f958bf0 100644 --- a/datafusion/functions-array/src/lib.rs +++ b/datafusion/functions-array/src/lib.rs @@ -31,13 +31,13 @@ pub mod macros; mod array_has; mod cardinality; mod concat; -mod core; mod dimension; mod empty; mod except; mod extract; mod flatten; mod length; +mod make_array; mod position; mod range; mod remove; @@ -66,7 +66,6 @@ pub mod expr_fn { pub use super::concat::array_append; pub use super::concat::array_concat; pub use super::concat::array_prepend; - pub use super::core::make_array; pub use super::dimension::array_dims; pub use super::dimension::array_ndims; pub use super::empty::array_empty; @@ -77,6 +76,7 @@ pub mod expr_fn { pub use super::extract::array_slice; pub use super::flatten::flatten; pub use super::length::array_length; + pub use super::make_array::make_array; pub use super::position::array_position; pub use super::position::array_positions; pub use super::range::gen_series; @@ -116,7 +116,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { extract::array_pop_back_udf(), extract::array_pop_front_udf(), extract::array_slice_udf(), - core::make_array_udf(), + make_array::make_array_udf(), array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), diff --git a/datafusion/functions-array/src/core.rs b/datafusion/functions-array/src/make_array.rs similarity index 92% rename from datafusion/functions-array/src/core.rs rename to datafusion/functions-array/src/make_array.rs index fdd127cc3f32..8eaae09f28f5 100644 --- a/datafusion/functions-array/src/core.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -// core array function like `make_array` +//! [`ScalarUDFImpl`] definitions for `make_array` function. 
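A note on the renamed module's behavior, visible in `make_array_inner` below: the element type of the resulting list is taken from the first argument whose type is not Null, and an empty or all-null call stays Null. A minimal sketch with strings standing in for Arrow `DataType`s:

    fn pick_element_type(arg_types: &[String]) -> String {
        arg_types
            .iter()
            .find(|t| t.as_str() != "Null")
            .cloned()
            .unwrap_or_else(|| "Null".to_string())
    }

    fn main() {
        let args = ["Null".to_string(), "Int64".to_string()];
        assert_eq!(pick_element_type(&args), "Int64");
        assert_eq!(pick_element_type(&[]), "Null");
    }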
use std::{any::Any, sync::Arc}; @@ -24,9 +24,9 @@ use arrow_array::{ new_null_array, Array, ArrayRef, GenericListArray, NullArray, OffsetSizeTrait, }; use arrow_buffer::OffsetBuffer; +use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; -use datafusion_common::Result; -use datafusion_common::{plan_err, utils::array_into_list_array}; +use datafusion_common::{plan_err, utils::array_into_list_array, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::Expr; use datafusion_expr::{ @@ -73,7 +73,7 @@ impl ScalarUDFImpl for MakeArray { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { match arg_types.len() { 0 => Ok(DataType::List(Arc::new(Field::new( "item", @@ -89,9 +89,7 @@ impl ScalarUDFImpl for MakeArray { } } - Ok(DataType::List(Arc::new(Field::new( - "item", expr_type, true, - )))) + Ok(List(Arc::new(Field::new("item", expr_type, true)))) } } } @@ -109,10 +107,10 @@ impl ScalarUDFImpl for MakeArray { /// Constructs an array using the input `data` as `ArrayRef`. /// Returns a reference-counted `Array` instance result. pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { - let mut data_type = DataType::Null; + let mut data_type = Null; for arg in arrays { let arg_data_type = arg.data_type(); - if !arg_data_type.equals_datatype(&DataType::Null) { + if !arg_data_type.equals_datatype(&Null) { data_type = arg_data_type.clone(); break; } @@ -120,12 +118,11 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { match data_type { // Either an empty array or all nulls: - DataType::Null => { - let array = - new_null_array(&DataType::Null, arrays.iter().map(|a| a.len()).sum()); + Null => { + let array = new_null_array(&Null, arrays.iter().map(|a| a.len()).sum()); Ok(Arc::new(array_into_list_array(array))) } - DataType::LargeList(..) => array_array::(arrays, data_type), + LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), } } diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-array/src/position.rs index 627cf3cb0cf0..a5a7a7405aa9 100644 --- a/datafusion/functions-array/src/position.rs +++ b/datafusion/functions-array/src/position.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array_position function. +//! [`ScalarUDFImpl`] definitions for array_position and array_positions functions. 
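For reference, the contract behind `generic_position` further down, sketched on a plain slice: the result is the 1-based position of the first match at or after the optional starting position, or NULL when there is no match (the real code additionally rejects non-positive starting positions, which the sketch glosses over):

    fn array_position<T: PartialEq>(list: &[T], element: &T, from: i64) -> Option<u64> {
        let start = (from.max(1) - 1) as usize; // `from` is 1-based too
        list.iter()
            .enumerate()
            .skip(start)
            .find(|(_, v)| *v == element)
            .map(|(i, _)| i as u64 + 1)
    }

    fn main() {
        let l = ["a", "b", "c", "b"];
        assert_eq!(array_position(&l, &"b", 1), Some(2));
        assert_eq!(array_position(&l, &"b", 3), Some(4));
        assert_eq!(array_position(&l, &"z", 1), None);
    }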
use arrow_schema::DataType::{LargeList, List, UInt64}; use arrow_schema::{DataType, Field}; @@ -32,10 +32,10 @@ use arrow_array::{ use datafusion_common::cast::{ as_generic_list_array, as_int64_array, as_large_list_array, as_list_array, }; -use datafusion_common::{exec_err, internal_err}; +use datafusion_common::{exec_err, internal_err, Result}; use itertools::Itertools; -use crate::utils::compare_element_to_list; +use crate::utils::{compare_element_to_list, make_scalar_function}; make_udf_function!( ArrayPosition, @@ -78,16 +78,12 @@ impl ScalarUDFImpl for ArrayPosition { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(UInt64) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_position_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_position_inner)(args) } fn aliases(&self) -> &[String] { @@ -96,7 +92,7 @@ impl ScalarUDFImpl for ArrayPosition { } /// Array_position SQL function -pub fn array_position_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_position_inner(args: &[ArrayRef]) -> Result { if args.len() < 2 || args.len() > 3 { return exec_err!("array_position expects two or three arguments"); } @@ -106,9 +102,7 @@ pub fn array_position_inner(args: &[ArrayRef]) -> datafusion_common::Result exec_err!("array_position does not support type '{array_type:?}'."), } } -fn general_position_dispatch( - args: &[ArrayRef], -) -> datafusion_common::Result { +fn general_position_dispatch(args: &[ArrayRef]) -> Result { let list_array = as_generic_list_array::(&args[0])?; let element_array = &args[1]; @@ -146,7 +140,7 @@ fn generic_position( list_array: &GenericListArray, element_array: &ArrayRef, arr_from: Vec, // 0-indexed -) -> datafusion_common::Result { +) -> Result { let mut data = Vec::with_capacity(list_array.len()); for (row_index, (list_array_row, &from)) in @@ -211,16 +205,12 @@ impl ScalarUDFImpl for ArrayPositions { &self.signature } - fn return_type( - &self, - _arg_types: &[DataType], - ) -> datafusion_common::Result { + fn return_type(&self, _arg_types: &[DataType]) -> Result { Ok(List(Arc::new(Field::new("item", UInt64, true)))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_positions_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_positions_inner)(args) } fn aliases(&self) -> &[String] { @@ -229,7 +219,7 @@ impl ScalarUDFImpl for ArrayPositions { } /// Array_positions SQL function -pub fn array_positions_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_positions_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_positions expects two arguments"); } @@ -237,12 +227,12 @@ pub fn array_positions_inner(args: &[ArrayRef]) -> datafusion_common::Result { + List(_) => { let arr = as_list_array(&args[0])?; crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) } - DataType::LargeList(_) => { + LargeList(_) => { let arr = as_large_list_array(&args[0])?; crate::utils::check_datatypes("array_positions", &[arr.values(), element])?; general_positions::(arr, element) @@ -256,7 +246,7 @@ pub fn array_positions_inner(args: 
&[ArrayRef]) -> datafusion_common::Result( list_array: &GenericListArray, element_array: &ArrayRef, -) -> datafusion_common::Result { +) -> Result { let mut data = Vec::with_capacity(list_array.len()); for (row_index, list_array_row) in list_array.iter().enumerate() { diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-array/src/range.rs index 176a5617d599..1c9e0c878e6e 100644 --- a/datafusion/functions-array/src/range.rs +++ b/datafusion/functions-array/src/range.rs @@ -25,6 +25,7 @@ use std::any::Any; use crate::utils::make_scalar_function; use arrow_array::types::{Date32Type, IntervalMonthDayNanoType}; use arrow_array::Date32Array; +use arrow_schema::DataType::{Date32, Int64, Interval, List}; use arrow_schema::IntervalUnit::MonthDayNano; use datafusion_common::cast::{as_date32_array, as_int64_array, as_interval_mdn_array}; use datafusion_common::{exec_err, not_impl_datafusion_err, Result}; @@ -49,7 +50,6 @@ pub(super) struct Range { } impl Range { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( vec![ @@ -77,7 +77,6 @@ impl ScalarUDFImpl for Range { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; Ok(List(Arc::new(Field::new( "item", arg_types[0].clone(), @@ -87,12 +86,8 @@ impl ScalarUDFImpl for Range { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Int64 => { - make_scalar_function(|args| gen_range_inner(args, false))(args) - } - DataType::Date32 => { - make_scalar_function(|args| gen_range_date(args, false))(args) - } + Int64 => make_scalar_function(|args| gen_range_inner(args, false))(args), + Date32 => make_scalar_function(|args| gen_range_date(args, false))(args), _ => { exec_err!("unsupported type for range") } @@ -118,7 +113,6 @@ pub(super) struct GenSeries { } impl GenSeries { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( vec![ @@ -146,7 +140,6 @@ impl ScalarUDFImpl for GenSeries { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; Ok(List(Arc::new(Field::new( "item", arg_types[0].clone(), @@ -156,12 +149,8 @@ impl ScalarUDFImpl for GenSeries { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Int64 => { - make_scalar_function(|args| gen_range_inner(args, true))(args) - } - DataType::Date32 => { - make_scalar_function(|args| gen_range_date(args, true))(args) - } + Int64 => make_scalar_function(|args| gen_range_inner(args, true))(args), + Date32 => make_scalar_function(|args| gen_range_date(args, true))(args), _ => { exec_err!("unsupported type for range") } @@ -242,7 +231,7 @@ pub(super) fn gen_range_inner( }; } let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Int64, true)), + Arc::new(Field::new("item", Int64, true)), OffsetBuffer::new(offsets.into()), Arc::new(Int64Array::from(values)), Some(NullBuffer::new(valid.finish())), @@ -330,7 +319,7 @@ fn gen_range_date(args: &[ArrayRef], include_upper: bool) -> Result { } let arr = Arc::new(ListArray::try_new( - Arc::new(Field::new("item", DataType::Date32, true)), + Arc::new(Field::new("item", Date32, true)), OffsetBuffer::new(offsets.into()), Arc::new(Date32Array::from(values)), None, diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-array/src/remove.rs index 91c76a6708dc..21e373081054 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-array/src/remove.rs @@ -18,6 +18,7 @@ //! 
[`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions. use crate::utils; +use crate::utils::make_scalar_function; use arrow_array::cast::AsArray; use arrow_array::{ new_empty_array, Array, ArrayRef, BooleanArray, GenericListArray, OffsetSizeTrait, @@ -25,7 +26,7 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::{DataType, Field}; use datafusion_common::cast::as_int64_array; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -58,6 +59,7 @@ impl ScalarUDFImpl for ArrayRemove { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_remove" } @@ -66,13 +68,12 @@ impl ScalarUDFImpl for ArrayRemove { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_remove_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_remove_inner)(args) } fn aliases(&self) -> &[String] { @@ -107,6 +108,7 @@ impl ScalarUDFImpl for ArrayRemoveN { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_remove_n" } @@ -115,13 +117,12 @@ impl ScalarUDFImpl for ArrayRemoveN { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_remove_n_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_remove_n_inner)(args) } fn aliases(&self) -> &[String] { @@ -159,6 +160,7 @@ impl ScalarUDFImpl for ArrayRemoveAll { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_remove_all" } @@ -167,13 +169,12 @@ impl ScalarUDFImpl for ArrayRemoveAll { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { - let args = ColumnarValue::values_to_arrays(args)?; - array_remove_all_inner(&args).map(ColumnarValue::Array) + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(array_remove_all_inner)(args) } fn aliases(&self) -> &[String] { @@ -182,7 +183,7 @@ impl ScalarUDFImpl for ArrayRemoveAll { } /// Array_remove SQL function -pub fn array_remove_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_remove_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_remove expects two arguments"); } @@ -192,7 +193,7 @@ pub fn array_remove_inner(args: &[ArrayRef]) -> datafusion_common::Result datafusion_common::Result { +pub fn array_remove_n_inner(args: &[ArrayRef]) -> Result { if args.len() != 3 { return exec_err!("array_remove_n expects three arguments"); } @@ -202,7 +203,7 @@ pub fn array_remove_n_inner(args: &[ArrayRef]) -> datafusion_common::Result datafusion_common::Result { +pub fn array_remove_all_inner(args: &[ArrayRef]) -> Result { if 
args.len() != 2 { return exec_err!("array_remove_all expects two arguments"); } @@ -215,7 +216,7 @@ fn array_remove_internal( array: &ArrayRef, element_array: &ArrayRef, arr_n: Vec, -) -> datafusion_common::Result { +) -> Result { match array.data_type() { DataType::List(_) => { let list_array = array.as_list::(); @@ -252,7 +253,7 @@ fn general_remove( list_array: &GenericListArray, element_array: &ArrayRef, arr_n: Vec, -) -> datafusion_common::Result { +) -> Result { let data_type = list_array.value_type(); let mut new_values = vec![]; // Build up the offsets for the final output array diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-array/src/repeat.rs index bf967f65724b..89b766bdcdfc 100644 --- a/datafusion/functions-array/src/repeat.rs +++ b/datafusion/functions-array/src/repeat.rs @@ -28,7 +28,7 @@ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List}; use arrow_schema::{DataType, Field}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -60,6 +60,7 @@ impl ScalarUDFImpl for ArrayRepeat { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_repeat" } @@ -68,7 +69,7 @@ impl ScalarUDFImpl for ArrayRepeat { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(List(Arc::new(Field::new( "item", arg_types[0].clone(), @@ -76,7 +77,7 @@ impl ScalarUDFImpl for ArrayRepeat { )))) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_repeat_inner)(args) } @@ -86,7 +87,7 @@ impl ScalarUDFImpl for ArrayRepeat { } /// Array_repeat SQL function -pub fn array_repeat_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_repeat_inner(args: &[ArrayRef]) -> Result { if args.len() != 2 { return exec_err!("array_repeat expects two arguments"); } @@ -122,7 +123,7 @@ pub fn array_repeat_inner(args: &[ArrayRef]) -> datafusion_common::Result( array: &ArrayRef, count_array: &Int64Array, -) -> datafusion_common::Result { +) -> Result { let data_type = array.data_type(); let mut new_values = vec![]; @@ -176,7 +177,7 @@ fn general_repeat( fn general_list_repeat( list_array: &GenericListArray, count_array: &Int64Array, -) -> datafusion_common::Result { +) -> Result { let data_type = list_array.data_type(); let value_type = list_array.value_type(); let mut new_values = vec![]; diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-array/src/replace.rs index 8ff65d315431..c32305bb454b 100644 --- a/datafusion/functions-array/src/replace.rs +++ b/datafusion/functions-array/src/replace.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! [`ScalarUDFImpl`] definitions for array functions. +//! [`ScalarUDFImpl`] definitions for array_replace, array_replace_n and array_replace_all functions. 
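The three replace variants share one kernel and differ only in the per-row replacement budget, as the `arr_n` handling below shows: 1 for array_replace, the user-supplied n for array_replace_n, and an effectively unbounded count for array_replace_all. A minimal sketch on a plain slice:

    fn replace_n<T: Clone + PartialEq>(list: &[T], from: &T, to: &T, mut n: u64) -> Vec<T> {
        list.iter()
            .map(|v| {
                if v == from && n > 0 {
                    n -= 1; // spend one unit of the replacement budget
                    to.clone()
                } else {
                    v.clone()
                }
            })
            .collect()
    }

    fn main() {
        let l = [1, 2, 1, 1];
        assert_eq!(replace_n(&l, &1, &9, 1), vec![9, 2, 1, 1]); // array_replace
        assert_eq!(replace_n(&l, &1, &9, 2), vec![9, 2, 9, 1]); // array_replace_n
        assert_eq!(replace_n(&l, &1, &9, u64::MAX), vec![9, 2, 9, 9]); // array_replace_all
    }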
use arrow::array::{ Array, ArrayRef, AsArray, Capacities, MutableArrayData, OffsetSizeTrait, @@ -76,6 +76,7 @@ impl ScalarUDFImpl for ArrayReplace { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_replace" } @@ -84,11 +85,11 @@ impl ScalarUDFImpl for ArrayReplace { &self.signature } - fn return_type(&self, args: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, args: &[DataType]) -> Result { Ok(args[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_replace_inner)(args) } @@ -119,6 +120,7 @@ impl ScalarUDFImpl for ArrayReplaceN { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_replace_n" } @@ -127,11 +129,11 @@ impl ScalarUDFImpl for ArrayReplaceN { &self.signature } - fn return_type(&self, args: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, args: &[DataType]) -> Result { Ok(args[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_replace_n_inner)(args) } @@ -162,6 +164,7 @@ impl ScalarUDFImpl for ArrayReplaceAll { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_replace_all" } @@ -170,11 +173,11 @@ impl ScalarUDFImpl for ArrayReplaceAll { &self.signature } - fn return_type(&self, args: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, args: &[DataType]) -> Result { Ok(args[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_replace_all_inner)(args) } @@ -183,7 +186,7 @@ impl ScalarUDFImpl for ArrayReplaceAll { } } -/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurences +/// For each element of `list_array[i]`, replaces up to `arr_n[i]` occurrences /// of `from_array[i]`, `to_array[i]`. 
/// /// The type of each **element** in `list_array` must be the same as the type of @@ -299,7 +302,7 @@ pub(crate) fn array_replace_inner(args: &[ArrayRef]) -> Result { return exec_err!("array_replace expects three arguments"); } - // replace at most one occurence for each element + // replace at most one occurrence for each element let arr_n = vec![1; args[0].len()]; let array = &args[0]; match array.data_type() { @@ -320,7 +323,7 @@ pub(crate) fn array_replace_n_inner(args: &[ArrayRef]) -> Result { return exec_err!("array_replace_n expects four arguments"); } - // replace the specified number of occurences + // replace the specified number of occurrences let arr_n = as_int64_array(&args[3])?.values().to_vec(); let array = &args[0]; match array.data_type() { diff --git a/datafusion/functions-array/src/resize.rs b/datafusion/functions-array/src/resize.rs index f3996110f904..c5855d054494 100644 --- a/datafusion/functions-array/src/resize.rs +++ b/datafusion/functions-array/src/resize.rs @@ -24,7 +24,7 @@ use arrow_buffer::{ArrowNativeType, OffsetBuffer}; use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::cast::{as_int64_array, as_large_list_array, as_list_array}; -use datafusion_common::{exec_err, internal_datafusion_err, ScalarValue}; +use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -57,6 +57,7 @@ impl ScalarUDFImpl for ArrayResize { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_resize" } @@ -65,7 +66,7 @@ impl ScalarUDFImpl for ArrayResize { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { List(field) | FixedSizeList(field, _) => Ok(List(field.clone())), LargeList(field) => Ok(LargeList(field.clone())), @@ -75,7 +76,7 @@ impl ScalarUDFImpl for ArrayResize { } } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_resize_inner)(args) } @@ -85,7 +86,7 @@ impl ScalarUDFImpl for ArrayResize { } /// array_resize SQL function -pub fn array_resize_inner(arg: &[ArrayRef]) -> datafusion_common::Result { +pub(crate) fn array_resize_inner(arg: &[ArrayRef]) -> Result { if arg.len() < 2 || arg.len() > 3 { return exec_err!("array_resize needs two or three arguments"); } @@ -98,11 +99,11 @@ pub fn array_resize_inner(arg: &[ArrayRef]) -> datafusion_common::Result { + List(field) => { let array = as_list_array(&arg[0])?; general_list_resize::(array, new_len, field, new_element) } - DataType::LargeList(field) => { + LargeList(field) => { let array = as_large_list_array(&arg[0])?; general_list_resize::(array, new_len, field, new_element) } @@ -116,7 +117,7 @@ fn general_list_resize( count_array: &Int64Array, field: &FieldRef, default_element: Option, -) -> datafusion_common::Result +) -> Result where O: TryInto, { diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-array/src/reverse.rs index 7eb9e53deef4..8324c407bd86 100644 --- a/datafusion/functions-array/src/reverse.rs +++ b/datafusion/functions-array/src/reverse.rs @@ -21,9 +21,10 @@ use crate::utils::make_scalar_function; use arrow::array::{Capacities, MutableArrayData}; use arrow_array::{Array, ArrayRef, GenericListArray, 
OffsetSizeTrait}; use arrow_buffer::OffsetBuffer; +use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, FieldRef}; use datafusion_common::cast::{as_large_list_array, as_list_array}; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -56,6 +57,7 @@ impl ScalarUDFImpl for ArrayReverse { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_reverse" } @@ -64,11 +66,11 @@ impl ScalarUDFImpl for ArrayReverse { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { Ok(arg_types[0].clone()) } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_reverse_inner)(args) } @@ -78,21 +80,21 @@ impl ScalarUDFImpl for ArrayReverse { } /// array_reverse SQL function -pub fn array_reverse_inner(arg: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_reverse_inner(arg: &[ArrayRef]) -> Result { if arg.len() != 1 { return exec_err!("array_reverse needs one argument"); } match &arg[0].data_type() { - DataType::List(field) => { + List(field) => { let array = as_list_array(&arg[0])?; general_array_reverse::(array, field) } - DataType::LargeList(field) => { + LargeList(field) => { let array = as_large_list_array(&arg[0])?; general_array_reverse::(array, field) } - DataType::Null => Ok(arg[0].clone()), + Null => Ok(arg[0].clone()), array_type => exec_err!("array_reverse does not support type '{array_type:?}'."), } } @@ -100,7 +102,7 @@ pub fn array_reverse_inner(arg: &[ArrayRef]) -> datafusion_common::Result( array: &GenericListArray, field: &FieldRef, -) -> datafusion_common::Result +) -> Result where O: TryFrom, { diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 6a91e9078232..d231dce4cb68 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -23,6 +23,7 @@ use crate::extract::{array_element, array_slice}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; use datafusion_common::utils::list_ndims; +use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::FunctionRewrite; @@ -42,7 +43,7 @@ impl FunctionRewrite for ArrayFunctionRewriter { expr: Expr, schema: &DFSchema, _config: &ConfigOptions, - ) -> datafusion_common::Result> { + ) -> Result> { let transformed = match expr { // array1 @> array2 -> array_has_all(array1, array2) Expr::BinaryExpr(BinaryExpr { left, op, right }) diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-array/src/set_ops.rs index df5bc91a2689..5f3087fafd6f 100644 --- a/datafusion/functions-array/src/set_ops.rs +++ b/datafusion/functions-array/src/set_ops.rs @@ -15,15 +15,16 @@ // specific language governing permissions and limitations // under the License. -//! Array Intersection, Union, and Distinct functions +//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions. 
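Reading ahead to `generic_set_lists` and `general_set_op` below, the `Null` match arms encode how the set operations degrade for untyped inputs. Summarized here as a reading of those arms themselves, not of any separate documentation:

// array_union(NULL, arr)       -> distinct(arr)   (and symmetrically)
// array_intersect(NULL, arr)   -> empty Null array
// array_intersect(arr, NULL)   -> empty array via make_array_inner(&[])
// union/intersect(NULL, NULL)  -> empty Null array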
-use crate::core::make_array_inner; +use crate::make_array::make_array_inner; use crate::utils::make_scalar_function; use arrow::array::{new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, FieldRef}; use arrow::row::{RowConverter, SortField}; +use arrow_schema::DataType::{FixedSizeList, LargeList, List, Null}; use datafusion_common::cast::{as_large_list_array, as_list_array}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::expr::ScalarFunction; @@ -48,7 +49,7 @@ make_udf_function!( ArrayIntersect, array_intersect, first_array second_array, - "Returns an array of the elements in the intersection of array1 and array2.", + "returns an array of the elements in the intersection of array1 and array2.", array_intersect_udf ); @@ -56,7 +57,7 @@ make_udf_function!( ArrayDistinct, array_distinct, array, - "return distinct values from the array after removing duplicates.", + "returns distinct values from the array after removing duplicates.", array_distinct_udf ); @@ -79,6 +80,7 @@ impl ScalarUDFImpl for ArrayUnion { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_union" } @@ -89,8 +91,8 @@ impl ScalarUDFImpl for ArrayUnion { fn return_type(&self, arg_types: &[DataType]) -> Result { match (&arg_types[0], &arg_types[1]) { - (&DataType::Null, dt) => Ok(dt.clone()), - (dt, DataType::Null) => Ok(dt.clone()), + (&Null, dt) => Ok(dt.clone()), + (dt, Null) => Ok(dt.clone()), (dt, _) => Ok(dt.clone()), } } @@ -126,6 +128,7 @@ impl ScalarUDFImpl for ArrayIntersect { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_intersect" } @@ -136,12 +139,8 @@ impl ScalarUDFImpl for ArrayIntersect { fn return_type(&self, arg_types: &[DataType]) -> Result { match (arg_types[0].clone(), arg_types[1].clone()) { - (DataType::Null, DataType::Null) | (DataType::Null, _) => Ok(DataType::Null), - (_, DataType::Null) => Ok(DataType::List(Arc::new(Field::new( - "item", - DataType::Null, - true, - )))), + (Null, Null) | (Null, _) => Ok(Null), + (_, Null) => Ok(List(Arc::new(Field::new("item", Null, true)))), (dt, _) => Ok(dt), } } @@ -174,6 +173,7 @@ impl ScalarUDFImpl for ArrayDistinct { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_distinct" } @@ -183,7 +183,6 @@ impl ScalarUDFImpl for ArrayDistinct { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; match &arg_types[0] { List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( "item", @@ -218,17 +217,17 @@ fn array_distinct_inner(args: &[ArrayRef]) -> Result { } // handle null - if args[0].data_type() == &DataType::Null { + if args[0].data_type() == &Null { return Ok(args[0].clone()); } // handle for list & largelist match args[0].data_type() { - DataType::List(field) => { + List(field) => { let array = as_list_array(&args[0])?; general_array_distinct(array, field) } - DataType::LargeList(field) => { + LargeList(field) => { let array = as_large_list_array(&args[0])?; general_array_distinct(array, field) } @@ -257,10 +256,10 @@ fn generic_set_lists( field: Arc, set_op: SetOp, ) -> Result { - if matches!(l.value_type(), DataType::Null) { + if matches!(l.value_type(), Null) { let field = Arc::new(Field::new("item", r.value_type(), true)); return general_array_distinct::(r, &field); - } else if matches!(r.value_type(), DataType::Null) { + } else if matches!(r.value_type(), Null) { let field = Arc::new(Field::new("item", 
l.value_type(), true)); return general_array_distinct::(l, &field); } @@ -331,43 +330,43 @@ fn general_set_op( set_op: SetOp, ) -> Result { match (array1.data_type(), array2.data_type()) { - (DataType::Null, DataType::List(field)) => { + (Null, List(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&DataType::Null)); + return Ok(new_empty_array(&Null)); } let array = as_list_array(&array2)?; general_array_distinct::(array, field) } - (DataType::List(field), DataType::Null) => { + (List(field), Null) => { if set_op == SetOp::Intersect { return make_array_inner(&[]); } let array = as_list_array(&array1)?; general_array_distinct::(array, field) } - (DataType::Null, DataType::LargeList(field)) => { + (Null, LargeList(field)) => { if set_op == SetOp::Intersect { - return Ok(new_empty_array(&DataType::Null)); + return Ok(new_empty_array(&Null)); } let array = as_large_list_array(&array2)?; general_array_distinct::(array, field) } - (DataType::LargeList(field), DataType::Null) => { + (LargeList(field), Null) => { if set_op == SetOp::Intersect { return make_array_inner(&[]); } let array = as_large_list_array(&array1)?; general_array_distinct::(array, field) } - (DataType::Null, DataType::Null) => Ok(new_empty_array(&DataType::Null)), + (Null, Null) => Ok(new_empty_array(&Null)), - (DataType::List(field), DataType::List(_)) => { + (List(field), List(_)) => { let array1 = as_list_array(&array1)?; let array2 = as_list_array(&array2)?; generic_set_lists::(array1, array2, field.clone(), set_op) } - (DataType::LargeList(field), DataType::LargeList(_)) => { + (LargeList(field), LargeList(_)) => { let array1 = as_large_list_array(&array1)?; let array2 = as_large_list_array(&array2)?; generic_set_lists::(array1, array2, field.clone(), set_op) diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-array/src/sort.rs index 2f3fa33e6857..af78712065fc 100644 --- a/datafusion/functions-array/src/sort.rs +++ b/datafusion/functions-array/src/sort.rs @@ -24,7 +24,7 @@ use arrow_buffer::{BooleanBufferBuilder, NullBuffer, OffsetBuffer}; use arrow_schema::DataType::{FixedSizeList, LargeList, List}; use arrow_schema::{DataType, Field, SortOptions}; use datafusion_common::cast::{as_list_array, as_string_array}; -use datafusion_common::exec_err; +use datafusion_common::{exec_err, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -57,6 +57,7 @@ impl ScalarUDFImpl for ArraySort { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_sort" } @@ -65,7 +66,7 @@ impl ScalarUDFImpl for ArraySort { &self.signature } - fn return_type(&self, arg_types: &[DataType]) -> datafusion_common::Result { + fn return_type(&self, arg_types: &[DataType]) -> Result { match &arg_types[0] { List(field) | FixedSizeList(field, _) => Ok(List(Arc::new(Field::new( "item", @@ -83,7 +84,7 @@ impl ScalarUDFImpl for ArraySort { } } - fn invoke(&self, args: &[ColumnarValue]) -> datafusion_common::Result { + fn invoke(&self, args: &[ColumnarValue]) -> Result { make_scalar_function(array_sort_inner)(args) } @@ -93,7 +94,7 @@ impl ScalarUDFImpl for ArraySort { } /// Array_sort SQL function -pub fn array_sort_inner(args: &[ArrayRef]) -> datafusion_common::Result { +pub fn array_sort_inner(args: &[ArrayRef]) -> Result { if args.is_empty() || args.len() > 3 { return exec_err!("array_sort expects one to three arguments"); } @@ -157,7 +158,7 @@ pub fn array_sort_inner(args: &[ArrayRef]) -> 
datafusion_common::Result datafusion_common::Result { +fn order_desc(modifier: &str) -> Result { match modifier.to_uppercase().as_str() { "DESC" => Ok(true), "ASC" => Ok(false), @@ -165,7 +166,7 @@ fn order_desc(modifier: &str) -> datafusion_common::Result { } } -fn order_nulls_first(modifier: &str) -> datafusion_common::Result { +fn order_nulls_first(modifier: &str) -> Result { match modifier.to_uppercase().as_str() { "NULLS FIRST" => Ok(true), "NULLS LAST" => Ok(false), diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs index 3140866f5ff6..38059035005b 100644 --- a/datafusion/functions-array/src/string.rs +++ b/datafusion/functions-array/src/string.rs @@ -32,7 +32,7 @@ use datafusion_common::{plan_err, DataFusionError, Result}; use std::any::{type_name, Any}; use crate::utils::{downcast_arg, make_scalar_function}; -use arrow_schema::DataType::{LargeUtf8, Utf8}; +use arrow_schema::DataType::{FixedSizeList, LargeList, LargeUtf8, List, Null, Utf8}; use datafusion_common::cast::{ as_generic_string_array, as_large_list_array, as_list_array, as_string_array, }; @@ -133,6 +133,7 @@ impl ScalarUDFImpl for ArrayToString { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "array_to_string" } @@ -142,7 +143,6 @@ impl ScalarUDFImpl for ArrayToString { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; Ok(match arg_types[0] { List(_) | LargeList(_) | FixedSizeList(_, _) => Utf8, _ => { @@ -195,6 +195,7 @@ impl ScalarUDFImpl for StringToArray { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "string_to_array" } @@ -204,7 +205,6 @@ impl ScalarUDFImpl for StringToArray { } fn return_type(&self, arg_types: &[DataType]) -> Result { - use DataType::*; Ok(match arg_types[0] { Utf8 | LargeUtf8 => { List(Arc::new(Field::new("item", arg_types[0].clone(), true))) @@ -258,7 +258,7 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { with_null_string: bool, ) -> Result<&mut String> { match arr.data_type() { - DataType::List(..) => { + List(..) => { let list_array = as_list_array(&arr)?; for i in 0..list_array.len() { compute_array_to_string( @@ -272,7 +272,7 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { Ok(arg) } - DataType::LargeList(..) => { + LargeList(..) => { let list_array = as_large_list_array(&arr)?; for i in 0..list_array.len() { compute_array_to_string( @@ -286,7 +286,7 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { Ok(arg) } - DataType::Null => Ok(arg), + Null => Ok(arg), data_type => { macro_rules! array_function { ($ARRAY_TYPE:ident) => { @@ -339,7 +339,7 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { let arr_type = arr.data_type(); let string_arr = match arr_type { - DataType::List(_) | DataType::FixedSizeList(_, _) => { + List(_) | FixedSizeList(_, _) => { let list_array = as_list_array(&arr)?; generate_string_array::( list_array, @@ -348,7 +348,7 @@ pub(super) fn array_to_string_inner(args: &[ArrayRef]) -> Result { with_null_string, )? 
} - DataType::LargeList(_) => { + LargeList(_) => { let list_array = as_large_list_array(&arr)?; generate_string_array::( list_array, From f7e55814b9e009f310cb49c9e694a778da938a23 Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Mon, 25 Mar 2024 19:19:57 +0800 Subject: [PATCH 063/117] Move bit_length and chr functions to datafusion_functions (#9782) * Move bit_length function to datafusion_functions Signed-off-by: Chojan Shang * Move chr function to datafusion_functions Signed-off-by: Chojan Shang * Port error sqllogictests Signed-off-by: Chojan Shang * Refactor to keep ui Signed-off-by: Chojan Shang * Make clippy happy Signed-off-by: Chojan Shang --------- Signed-off-by: Chojan Shang --- datafusion/expr/src/built_in_function.rs | 18 +--- datafusion/expr/src/expr_fn.rs | 14 --- datafusion/functions/src/string/bit_length.rs | 85 ++++++++++++++++ datafusion/functions/src/string/chr.rs | 96 +++++++++++++++++++ datafusion/functions/src/string/mod.rs | 16 ++++ datafusion/physical-expr/src/functions.rs | 89 ----------------- .../physical-expr/src/string_expressions.rs | 31 +----- datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 -- datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 10 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 - datafusion/sqllogictest/test_files/expr.slt | 6 ++ 13 files changed, 211 insertions(+), 174 deletions(-) create mode 100644 datafusion/functions/src/string/bit_length.rs create mode 100644 datafusion/functions/src/string/chr.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index b3f17ae3c2ca..bb0f79f8eca4 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -103,12 +103,8 @@ pub enum BuiltinScalarFunction { Cot, // string functions - /// bit_length - BitLength, /// character_length CharacterLength, - /// chr - Chr, /// concat Concat, /// concat_ws @@ -222,9 +218,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::BitLength => Volatility::Immutable, BuiltinScalarFunction::CharacterLength => Volatility::Immutable, - BuiltinScalarFunction::Chr => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, @@ -263,13 +257,9 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. 
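// For example, `bit_length` returns Int32 for Utf8 input but Int64 for
// LargeUtf8 input (see the scalar arms of the new `BitLengthFunc::invoke`
// below), which is why helpers like `utf8_to_int_type` dispatch on the
// argument type instead of returning one fixed type.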
match self { - BuiltinScalarFunction::BitLength => { - utf8_to_int_type(&input_expr_types[0], "bit_length") - } BuiltinScalarFunction::CharacterLength => { utf8_to_int_type(&input_expr_types[0], "character_length") } - BuiltinScalarFunction::Chr => Ok(Utf8), BuiltinScalarFunction::Coalesce => { // COALESCE has multiple args and they might get coerced, get a preview of this let coerced_types = data_types(input_expr_types, &self.signature()); @@ -377,15 +367,11 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::BitLength - | BuiltinScalarFunction::CharacterLength + BuiltinScalarFunction::CharacterLength | BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::Chr => { - Signature::uniform(1, vec![Int64], self.volatility()) - } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { Signature::one_of( vec![ @@ -599,13 +585,11 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => &["coalesce"], // string functions - BuiltinScalarFunction::BitLength => &["bit_length"], BuiltinScalarFunction::CharacterLength => { &["character_length", "char_length", "length"] } BuiltinScalarFunction::Concat => &["concat"], BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], - BuiltinScalarFunction::Chr => &["chr"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], BuiltinScalarFunction::Left => &["left"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index f75d8869671e..0ea946288e0f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -578,24 +578,12 @@ scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argu scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); // string functions -scalar_expr!( - BitLength, - bit_length, - string, - "the number of bits in the `string`" -); scalar_expr!( CharacterLength, character_length, string, "the number of characters in the `string`" ); -scalar_expr!( - Chr, - chr, - code_point, - "converts the Unicode code point to a UTF8 character" -); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Reverse, reverse, string, "reverses the `string`"); @@ -1044,9 +1032,7 @@ mod test { test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); - test_scalar_expr!(BitLength, bit_length, string); test_scalar_expr!(CharacterLength, character_length, string); - test_scalar_expr!(Chr, chr, string); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs new file mode 100644 index 000000000000..9f612751584e --- /dev/null +++ b/datafusion/functions/src/string/bit_length.rs @@ -0,0 +1,85 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::kernels::length::bit_length; +use std::any::Any; + +use arrow::datatypes::DataType; + +use datafusion_common::{exec_err, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +#[derive(Debug)] +pub(super) struct BitLengthFunc { + signature: Signature, +} + +impl BitLengthFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for BitLengthFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "bit_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "bit_length") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return exec_err!( + "bit_length function requires 1 argument, got {}", + args.len() + ); + } + + match &args[0] { + ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), + ColumnarValue::Scalar(v) => match v { + ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + v.as_ref().map(|x| (x.len() * 8) as i32), + ))), + ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), + )), + _ => unreachable!(), + }, + } + } +} diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs new file mode 100644 index 000000000000..df3b803ba659 --- /dev/null +++ b/datafusion/functions/src/string/chr.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
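// Expected behavior, carried over from the `Chr` test cases this patch
// deletes from physical-expr's functions.rs further below:
//   chr(120)      -> 'x'
//   chr(128175)   -> '💯'
//   chr(NULL)     -> NULL
//   chr(0)        -> error: "null character not permitted."
//   chr(i64::MAX) -> error: "requested character too large for encoding."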
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::array::StringArray; +use arrow::datatypes::DataType; +use arrow::datatypes::DataType::Int64; +use arrow::datatypes::DataType::Utf8; + +use datafusion_common::cast::as_int64_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, Volatility}; +use datafusion_expr::{ScalarUDFImpl, Signature}; + +use crate::string::common::*; + +/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. +/// chr(65) = 'A' +pub fn chr(args: &[ArrayRef]) -> Result { + let integer_array = as_int64_array(&args[0])?; + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|integer: Option| { + integer + .map(|integer| { + if integer == 0 { + exec_err!("null character not permitted.") + } else { + match core::char::from_u32(integer as u32) { + Some(integer) => Ok(integer.to_string()), + None => { + exec_err!("requested character too large for encoding.") + } + } + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +#[derive(Debug)] +pub(super) struct ChrFunc { + signature: Signature, +} + +impl ChrFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for ChrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "chr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(chr, vec![])(args) + } +} diff --git a/datafusion/functions/src/string/mod.rs b/datafusion/functions/src/string/mod.rs index d2b9fb2da805..81639c45f7ff 100644 --- a/datafusion/functions/src/string/mod.rs +++ b/datafusion/functions/src/string/mod.rs @@ -22,7 +22,9 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; mod ascii; +mod bit_length; mod btrim; +mod chr; mod common; mod levenshtein; mod lower; @@ -40,7 +42,9 @@ mod uuid; // create UDFs make_udf_function!(ascii::AsciiFunc, ASCII, ascii); +make_udf_function!(bit_length::BitLengthFunc, BIT_LENGTH, bit_length); make_udf_function!(btrim::BTrimFunc, BTRIM, btrim); +make_udf_function!(chr::ChrFunc, CHR, chr); make_udf_function!(levenshtein::LevenshteinFunc, LEVENSHTEIN, levenshtein); make_udf_function!(ltrim::LtrimFunc, LTRIM, ltrim); make_udf_function!(lower::LowerFunc, LOWER, lower); @@ -63,11 +67,21 @@ pub mod expr_fn { super::ascii().call(vec![arg1]) } + #[doc = "Returns the number of bits in the `string`"] + pub fn bit_length(arg: Expr) -> Expr { + super::bit_length().call(vec![arg]) + } + #[doc = "Removes all characters, spaces by default, from both sides of a string"] pub fn btrim(args: Vec) -> Expr { super::btrim().call(args) } + #[doc = "Converts the Unicode code point to a UTF8 character"] + pub fn chr(arg: Expr) -> Expr { + super::chr().call(vec![arg]) + } + #[doc = "Returns the Levenshtein distance between the two given strings"] pub fn levenshtein(arg1: Expr, arg2: Expr) -> Expr { super::levenshtein().call(vec![arg1, arg2]) @@ -143,7 +157,9 @@ pub mod expr_fn { pub fn functions() -> Vec> { vec![ ascii(), + bit_length(), btrim(), + chr(), levenshtein(), lower(), ltrim(), diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 163598c2df82..cd9bba63d624 100644 --- 
a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -35,7 +35,6 @@ use std::sync::Arc; use arrow::{ array::ArrayRef, - compute::kernels::length::bit_length, datatypes::{DataType, Int32Type, Int64Type, Schema}, }; use arrow_array::Array; @@ -255,18 +254,6 @@ pub fn create_physical_fun( Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } // string functions - BuiltinScalarFunction::BitLength => Arc::new(|args| match &args[0] { - ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( - v.as_ref().map(|x| (x.len() * 8) as i32), - ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), - )), - _ => unreachable!(), - }, - }), BuiltinScalarFunction::CharacterLength => { Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { @@ -290,9 +277,6 @@ pub fn create_physical_fun( ), }) } - BuiltinScalarFunction::Chr => { - Arc::new(|args| make_scalar_function_inner(string_expressions::chr)(args)) - } BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| { @@ -611,23 +595,6 @@ mod tests { #[test] fn test_functions() -> Result<()> { - test_function!( - BitLength, - &[lit("chars")], - Ok(Some(40)), - i32, - Int32, - Int32Array - ); - test_function!( - BitLength, - &[lit("josé")], - Ok(Some(40)), - i32, - Int32, - Int32Array - ); - test_function!(BitLength, &[lit("")], Ok(Some(0)), i32, Int32, Int32Array); #[cfg(feature = "unicode_expressions")] test_function!( CharacterLength, @@ -675,62 +642,6 @@ mod tests { Int32, Int32Array ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(128175)))], - Ok(Some("💯")), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(120)))], - Ok(Some("x")), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(128175)))], - Ok(Some("💯")), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(0)))], - exec_err!("null character not permitted."), - &str, - Utf8, - StringArray - ); - test_function!( - Chr, - &[lit(ScalarValue::Int64(Some(i64::MAX)))], - exec_err!("requested character too large for encoding."), - &str, - Utf8, - StringArray - ); test_function!( Concat, &[lit("aa"), lit("bb"), lit("cc"),], diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 812b746354a4..2185b7c5b4a1 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -33,40 +33,11 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::{ - cast::{as_generic_string_array, as_int64_array, as_string_array}, + cast::{as_generic_string_array, as_string_array}, exec_err, ScalarValue, }; use datafusion_expr::ColumnarValue; -/// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. 
-/// chr(65) = 'A' -pub fn chr(args: &[ArrayRef]) -> Result { - let integer_array = as_int64_array(&args[0])?; - - // first map is the iterator, second is for the `Option<_>` - let result = integer_array - .iter() - .map(|integer: Option| { - integer - .map(|integer| { - if integer == 0 { - exec_err!("null character not permitted.") - } else { - match core::char::from_u32(integer as u32) { - Some(integer) => Ok(integer.to_string()), - None => { - exec_err!("requested character too large for encoding.") - } - } - } - }) - .transpose() - }) - .collect::>()?; - - Ok(Arc::new(result) as ArrayRef) -} - /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' pub fn concat(args: &[ColumnarValue]) -> Result { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 297e355dd7b1..f405ecf976be 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -563,10 +563,10 @@ enum ScalarFunction { Trunc = 19; // 20 was Array // RegexpMatch = 21; - BitLength = 22; + // 22 was BitLength // 23 was Btrim CharacterLength = 24; - Chr = 25; + // 25 was Chr Concat = 26; ConcatWithSeparator = 27; // 28 was DatePart diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index dce815f0f234..0d22ba5db773 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22928,9 +22928,7 @@ impl serde::Serialize for ScalarFunction { Self::Sin => "Sin", Self::Sqrt => "Sqrt", Self::Trunc => "Trunc", - Self::BitLength => "BitLength", Self::CharacterLength => "CharacterLength", - Self::Chr => "Chr", Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", @@ -22990,9 +22988,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sin", "Sqrt", "Trunc", - "BitLength", "CharacterLength", - "Chr", "Concat", "ConcatWithSeparator", "InitCap", @@ -23081,9 +23077,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sin" => Ok(ScalarFunction::Sin), "Sqrt" => Ok(ScalarFunction::Sqrt), "Trunc" => Ok(ScalarFunction::Trunc), - "BitLength" => Ok(ScalarFunction::BitLength), "CharacterLength" => Ok(ScalarFunction::CharacterLength), - "Chr" => Ok(ScalarFunction::Chr), "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2292687b45a6..07c3fad15373 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2862,10 +2862,10 @@ pub enum ScalarFunction { Trunc = 19, /// 20 was Array /// RegexpMatch = 21; - BitLength = 22, + /// 22 was BitLength /// 23 was Btrim CharacterLength = 24, - Chr = 25, + /// 25 was Chr Concat = 26, ConcatWithSeparator = 27, /// 28 was DatePart @@ -3001,9 +3001,7 @@ impl ScalarFunction { ScalarFunction::Sin => "Sin", ScalarFunction::Sqrt => "Sqrt", ScalarFunction::Trunc => "Trunc", - ScalarFunction::BitLength => "BitLength", ScalarFunction::CharacterLength => "CharacterLength", - ScalarFunction::Chr => "Chr", ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", @@ -3057,9 +3055,7 @@ impl ScalarFunction { "Sin" => Some(Self::Sin), "Sqrt" => Some(Self::Sqrt), "Trunc" => Some(Self::Trunc), - 
"BitLength" => Some(Self::BitLength), "CharacterLength" => Some(Self::CharacterLength), - "Chr" => Some(Self::Chr), "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index b78e3ae6dc61..d5eebcb69841 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,8 +48,8 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, asinh, atan, atan2, atanh, bit_length, cbrt, ceil, character_length, chr, - coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, + acosh, asinh, atan, atan2, atanh, cbrt, ceil, character_length, coalesce, + concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -458,9 +458,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Concat => Self::Concat, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, - ScalarFunction::BitLength => Self::BitLength, ScalarFunction::CharacterLength => Self::CharacterLength, - ScalarFunction::Chr => Self::Chr, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, @@ -1418,13 +1416,9 @@ pub fn parse_expr( ScalarFunction::Signum => { Ok(signum(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::BitLength => { - Ok(bit_length(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::CharacterLength => { Ok(character_length(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], registry, codec)?)), ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 0c0f0c6e0a92..0432b54acfa8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1481,9 +1481,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, - BuiltinScalarFunction::BitLength => Self::BitLength, BuiltinScalarFunction::CharacterLength => Self::CharacterLength, - BuiltinScalarFunction::Chr => Self::Chr, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 70fdc26a6002..75bcbc07755b 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -415,6 +415,12 @@ SELECT chr(CAST(NULL AS int)) ---- NULL +statement error DataFusion error: Execution error: null character not permitted. +SELECT chr(CAST(0 AS int)) + +statement error DataFusion error: Execution error: requested character too large for encoding. 
+SELECT chr(CAST(9223372036854775807 AS bigint)) + query T SELECT concat('a','b','c') ---- From ad89ff82421e9c4670f4440dd5a6fa6fb55c40c3 Mon Sep 17 00:00:00 2001 From: Harvey Yue Date: Mon, 25 Mar 2024 20:57:55 +0800 Subject: [PATCH 064/117] Support tencent cloud COS storage in `datafusion-cli` (#9734) * Support tencent cloud COS storage * Fix clippy * Update docs/source/user-guide/cli.md --------- Co-authored-by: Andrew Lamb --- datafusion-cli/src/catalog.rs | 2 +- datafusion-cli/src/exec.rs | 16 +++++++++ datafusion-cli/src/object_storage.rs | 50 +++++++++++++++++++--------- docs/source/user-guide/cli.md | 37 ++++++++++++++++---- 4 files changed, 81 insertions(+), 24 deletions(-) diff --git a/datafusion-cli/src/catalog.rs b/datafusion-cli/src/catalog.rs index 46dd8bb00f06..0fbb7a5908b5 100644 --- a/datafusion-cli/src/catalog.rs +++ b/datafusion-cli/src/catalog.rs @@ -177,7 +177,7 @@ impl SchemaProvider for DynamicFileSchemaProvider { // Register the store for this URL. Here we don't have access // to any command options so the only choice is to use an empty collection match scheme { - "s3" | "oss" => { + "s3" | "oss" | "cos" => { state = state.add_table_options_extension(AwsOptions::default()); } "gs" | "gcs" => { diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs index 4e374a4c0032..114e3cefa3bf 100644 --- a/datafusion-cli/src/exec.rs +++ b/datafusion-cli/src/exec.rs @@ -415,6 +415,7 @@ mod tests { let locations = vec![ "s3://bucket/path/file.parquet", "oss://bucket/path/file.parquet", + "cos://bucket/path/file.parquet", "gcs://bucket/path/file.parquet", ]; let mut ctx = SessionContext::new(); @@ -497,6 +498,21 @@ mod tests { Ok(()) } + #[tokio::test] + async fn create_object_store_table_cos() -> Result<()> { + let access_key_id = "fake_access_key_id"; + let secret_access_key = "fake_secret_access_key"; + let endpoint = "fake_endpoint"; + let location = "cos://bucket/path/file.parquet"; + + // Should be OK + let sql = format!("CREATE EXTERNAL TABLE test STORED AS PARQUET + OPTIONS('aws.access_key_id' '{access_key_id}', 'aws.secret_access_key' '{secret_access_key}', 'aws.cos.endpoint' '{endpoint}') LOCATION '{location}'"); + create_external_table_test(location, &sql).await?; + + Ok(()) + } + #[tokio::test] async fn create_object_store_table_gcs() -> Result<()> { let service_account_path = "fake_service_account_path"; diff --git a/datafusion-cli/src/object_storage.rs b/datafusion-cli/src/object_storage.rs index 033c8f839ab2..94560cb9d8da 100644 --- a/datafusion-cli/src/object_storage.rs +++ b/datafusion-cli/src/object_storage.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::fmt::{Debug, Display}; use std::sync::Arc; -use datafusion::common::{config_namespace, exec_datafusion_err, exec_err, internal_err}; +use datafusion::common::{exec_datafusion_err, exec_err, internal_err}; use datafusion::error::{DataFusionError, Result}; use datafusion::execution::context::SessionState; use datafusion::prelude::SessionContext; @@ -106,12 +106,27 @@ impl CredentialProvider for S3CredentialProvider { pub fn get_oss_object_store_builder( url: &Url, aws_options: &AwsOptions, +) -> Result { + get_object_store_builder(url, aws_options, true) +} + +pub fn get_cos_object_store_builder( + url: &Url, + aws_options: &AwsOptions, +) -> Result { + get_object_store_builder(url, aws_options, false) +} + +fn get_object_store_builder( + url: &Url, + aws_options: &AwsOptions, + virtual_hosted_style_request: bool, ) -> Result { let bucket_name = get_bucket_name(url)?; let mut builder = 
AmazonS3Builder::from_env() - .with_virtual_hosted_style_request(true) + .with_virtual_hosted_style_request(virtual_hosted_style_request) .with_bucket_name(bucket_name) - // oss don't care about the "region" field + // oss/cos don't care about the "region" field .with_region("do_not_care"); if let (Some(access_key_id), Some(secret_access_key)) = @@ -122,7 +137,7 @@ pub fn get_oss_object_store_builder( .with_secret_access_key(secret_access_key); } - if let Some(endpoint) = &aws_options.oss.endpoint { + if let Some(endpoint) = &aws_options.endpoint { builder = builder.with_endpoint(endpoint); } @@ -171,14 +186,8 @@ pub struct AwsOptions { pub session_token: Option, /// AWS Region pub region: Option, - /// Object Storage Service options - pub oss: OssOptions, -} - -config_namespace! { - pub struct OssOptions { - pub endpoint: Option, default = None - } + /// OSS or COS Endpoint + pub endpoint: Option, } impl ExtensionOptions for AwsOptions { @@ -210,8 +219,8 @@ impl ExtensionOptions for AwsOptions { "region" => { self.region.set(rem, value)?; } - "oss" => { - self.oss.set(rem, value)?; + "oss" | "cos" => { + self.endpoint.set(rem, value)?; } _ => { return internal_err!("Config value \"{}\" not found on AwsOptions", rem); @@ -252,7 +261,7 @@ impl ExtensionOptions for AwsOptions { .visit(&mut v, "secret_access_key", ""); self.session_token.visit(&mut v, "session_token", ""); self.region.visit(&mut v, "region", ""); - self.oss.visit(&mut v, "oss", ""); + self.endpoint.visit(&mut v, "endpoint", ""); v.0 } } @@ -376,7 +385,7 @@ pub(crate) fn register_options(ctx: &SessionContext, scheme: &str) { // Match the provided scheme against supported cloud storage schemes: match scheme { // For Amazon S3 or Alibaba Cloud OSS - "s3" | "oss" => { + "s3" | "oss" | "cos" => { // Register AWS specific table options in the session context: ctx.register_table_options_extension(AwsOptions::default()) } @@ -415,6 +424,15 @@ pub(crate) async fn get_object_store( let builder = get_oss_object_store_builder(url, options)?; Arc::new(builder.build()?) } + "cos" => { + let Some(options) = table_options.extensions.get::() else { + return exec_err!( + "Given table options incompatible with the 'cos' scheme" + ); + }; + let builder = get_cos_object_store_builder(url, options)?; + Arc::new(builder.build()?) 
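// The one substantive difference between the two schemes above: OSS
// builders are created with virtual-hosted-style requests enabled, COS
// builders with them disabled (the `true`/`false` passed through
// `get_oss_object_store_builder` and `get_cos_object_store_builder`);
// credential and endpoint handling is otherwise shared.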
+ } "gs" | "gcs" => { let Some(options) = table_options.extensions.get::() else { return exec_err!( diff --git a/docs/source/user-guide/cli.md b/docs/source/user-guide/cli.md index a94e2427eaa2..da4c9870545a 100644 --- a/docs/source/user-guide/cli.md +++ b/docs/source/user-guide/cli.md @@ -312,9 +312,9 @@ select count(*) from hits; CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS( - 'access_key_id' '******', - 'secret_access_key' '******', - 'region' 'us-east-2' + 'aws.access_key_id' '******', + 'aws.secret_access_key' '******', + 'aws.region' 'us-east-2' ) LOCATION 's3://bucket/path/file.parquet'; ``` @@ -365,9 +365,9 @@ Details of the environment variables that can be used are: CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS( - 'access_key_id' '******', - 'secret_access_key' '******', - 'endpoint' 'https://bucket.oss-cn-hangzhou.aliyuncs.com' + 'aws.access_key_id' '******', + 'aws.secret_access_key' '******', + 'aws.oss.endpoint' 'https://bucket.oss-cn-hangzhou.aliyuncs.com' ) LOCATION 'oss://bucket/path/file.parquet'; ``` @@ -380,6 +380,29 @@ The supported OPTIONS are: Note that the `endpoint` format of oss needs to be: `https://{bucket}.{oss-region-endpoint}` +## Registering COS Data Sources + +[Tencent cloud COS](https://cloud.tencent.com/product/cos) data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. + +```sql +CREATE EXTERNAL TABLE test +STORED AS PARQUET +OPTIONS( + 'aws.access_key_id' '******', + 'aws.secret_access_key' '******', + 'aws.cos.endpoint' 'https://cos.ap-singapore.myqcloud.com' +) +LOCATION 'cos://bucket/path/file.parquet'; +``` + +The supported OPTIONS are: + +- access_key_id +- secret_access_key +- endpoint + +Note that the `endpoint` format of urls must be: `https://cos.{cos-region-endpoint}` + ## Registering GCS Data Sources [Google Cloud Storage](https://cloud.google.com/storage) data sources can be registered by executing a `CREATE EXTERNAL TABLE` SQL statement. @@ -388,7 +411,7 @@ Note that the `endpoint` format of oss needs to be: `https://{bucket}.{oss-regio CREATE EXTERNAL TABLE test STORED AS PARQUET OPTIONS( - 'service_account_path' '/tmp/gcs.json', + 'gcp.service_account_path' '/tmp/gcs.json', ) LOCATION 'gs://bucket/path/file.parquet'; ``` From 0b955776d172b4eb304097d8bab0bd82c3d20915 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Mon, 25 Mar 2024 14:50:22 +0000 Subject: [PATCH 065/117] Make it easier to register configuration extension ... (#9781) ... options closes #9529 --- datafusion/execution/src/config.rs | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 360bac71c510..0a7a87c7d81a 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -22,7 +22,10 @@ use std::{ sync::Arc, }; -use datafusion_common::{config::ConfigOptions, Result, ScalarValue}; +use datafusion_common::{ + config::{ConfigExtension, ConfigOptions}, + Result, ScalarValue, +}; /// Configuration options for [`SessionContext`]. 
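/// For example, assuming a user-defined type `MyExt` that implements
/// [`ConfigExtension`] (the name is a placeholder, not something defined in
/// this patch), custom options can now be attached with
/// `SessionConfig::new().with_option_extension(MyExt::default())`; see the
/// `with_option_extension` method added below.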
/// @@ -198,6 +201,12 @@ impl SessionConfig { self } + /// Insert new [ConfigExtension] + pub fn with_option_extension(mut self, extension: T) -> Self { + self.options_mut().extensions.insert(extension); + self + } + /// Get [`target_partitions`] /// /// [`target_partitions`]: datafusion_common::config::ExecutionOptions::target_partitions From 349c586ea7c680c802f36c750e5f1e2dc7beb65b Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Mon, 25 Mar 2024 20:49:54 -0600 Subject: [PATCH 066/117] Expr to sql : Case (#9798) --- datafusion/sql/src/unparser/expr.rs | 51 ++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index d007d4a843a2..49a940060bf3 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -113,10 +113,38 @@ impl Unparser<'_> { } Expr::Case(Case { expr, - when_then_expr: _, - else_expr: _, + when_then_expr, + else_expr, }) => { - not_impl_err!("Unsupported expression: {expr:?}") + let conditions = when_then_expr + .iter() + .map(|(w, _)| self.expr_to_sql(w)) + .collect::>>()?; + let results = when_then_expr + .iter() + .map(|(_, t)| self.expr_to_sql(t)) + .collect::>>()?; + let operand = match expr.as_ref() { + Some(e) => match self.expr_to_sql(e) { + Ok(sql_expr) => Some(Box::new(sql_expr)), + Err(_) => None, + }, + None => None, + }; + let else_result = match else_expr.as_ref() { + Some(e) => match self.expr_to_sql(e) { + Ok(sql_expr) => Some(Box::new(sql_expr)), + Err(_) => None, + }, + None => None, + }; + + Ok(ast::Expr::Case { + operand, + conditions, + results, + else_result, + }) } Expr::Cast(Cast { expr, data_type }) => { let inner_expr = self.expr_to_sql(expr)?; @@ -565,7 +593,7 @@ mod tests { use datafusion_common::TableReference; use datafusion_expr::{ - col, expr::AggregateFunction, lit, ColumnarValue, ScalarUDF, ScalarUDFImpl, + case, col, expr::AggregateFunction, lit, ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; @@ -622,6 +650,13 @@ mod tests { .gt(lit(4)), r#"("a"."b"."c" > 4)"#, ), + ( + case(col("a")) + .when(lit(1), lit(true)) + .when(lit(0), lit(false)) + .otherwise(lit(ScalarValue::Null))?, + r#"CASE "a" WHEN 1 THEN true WHEN 0 THEN false ELSE NULL END"#, + ), ( Expr::Cast(Cast { expr: Box::new(col("a")), @@ -698,17 +733,17 @@ mod tests { }), "COUNT(DISTINCT *)", ), - (Expr::IsNotNull(Box::new(col("a"))), r#""a" IS NOT NULL"#), + (col("a").is_not_null(), r#""a" IS NOT NULL"#), ( - Expr::IsTrue(Box::new((col("a") + col("b")).gt(lit(4)))), + (col("a") + col("b")).gt(lit(4)).is_true(), r#"(("a" + "b") > 4) IS TRUE"#, ), ( - Expr::IsFalse(Box::new((col("a") + col("b")).gt(lit(4)))), + (col("a") + col("b")).gt(lit(4)).is_false(), r#"(("a" + "b") > 4) IS FALSE"#, ), ( - Expr::IsUnknown(Box::new((col("a") + col("b")).gt(lit(4)))), + (col("a") + col("b")).gt(lit(4)).is_unknown(), r#"(("a" + "b") > 4) IS UNKNOWN"#, ), ]; From 39f4aaf5cd1abfc9204c3eb96effdb4ebcf5b882 Mon Sep 17 00:00:00 2001 From: Sebastian Espinosa <40347293+sebastian2296@users.noreply.github.com> Date: Tue, 26 Mar 2024 08:39:16 -0500 Subject: [PATCH 067/117] feat: Between expr to sql string (#9803) * feat: Between expr to sql string * fix: use between logical expr * fix: format using fmt * fix: remove redundant field name --- datafusion/sql/src/unparser/expr.rs | 35 +++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs 
index 49a940060bf3..550b02cea7f6 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -97,11 +97,19 @@ impl Unparser<'_> { } Expr::Between(Between { expr, - negated: _, - low: _, - high: _, + negated, + low, + high, }) => { - not_impl_err!("Unsupported expression: {expr:?}") + let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_low = self.expr_to_sql(low)?; + let sql_high = self.expr_to_sql(high)?; + Ok(ast::Expr::Nested(Box::new(self.between_op_to_sql( + sql_parser_expr, + *negated, + sql_low, + sql_high, + )))) } Expr::Column(col) => self.col_to_sql(col), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { @@ -291,6 +299,21 @@ impl Unparser<'_> { } } + pub(super) fn between_op_to_sql( + &self, + expr: ast::Expr, + negated: bool, + low: ast::Expr, + high: ast::Expr, + ) -> ast::Expr { + ast::Expr::Between { + expr: Box::new(expr), + negated, + low: Box::new(low), + high: Box::new(high), + } + } + fn op_to_sql(&self, op: &Operator) -> Result { match op { Operator::Eq => Ok(ast::BinaryOperator::Eq), @@ -746,6 +769,10 @@ mod tests { (col("a") + col("b")).gt(lit(4)).is_unknown(), r#"(("a" + "b") > 4) IS UNKNOWN"#, ), + ( + Expr::between(col("a"), lit(1), lit(7)), + r#"("a" BETWEEN 1 AND 7)"#, + ), ]; for (expr, expected) in tests { From e337832946fc69f6ceccd14b96a071cc2cd4693d Mon Sep 17 00:00:00 2001 From: Eren Avsarogullari Date: Tue, 26 Mar 2024 16:37:51 -0700 Subject: [PATCH 068/117] Issue-New - Add array_empty and list_empty functions support as alias for empty function (#9807) --- datafusion/functions-array/src/empty.rs | 6 +- datafusion/sqllogictest/test_files/array.slt | 71 ++++++++++++++++++- docs/source/user-guide/expressions.md | 1 + .../source/user-guide/sql/scalar_functions.md | 24 +++++-- 4 files changed, 95 insertions(+), 7 deletions(-) diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-array/src/empty.rs index f11a6f07cfc8..d5fa174eee5f 100644 --- a/datafusion/functions-array/src/empty.rs +++ b/datafusion/functions-array/src/empty.rs @@ -45,7 +45,11 @@ impl ArrayEmpty { pub fn new() -> Self { Self { signature: Signature::array(Volatility::Immutable), - aliases: vec![String::from("empty")], + aliases: vec![ + "empty".to_string(), + "array_empty".to_string(), + "list_empty".to_string(), + ], } } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index ad979a316709..3456963aacfc 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -6116,7 +6116,7 @@ from fixed_size_flatten_table; [1, 2, 3] [1, 2, 3, 4, 5, 6] [1, 2, 3] [1.0, 2.1, 2.2, 3.2, 3.3, 3.4] [1, 2, 3, 4, 5, 6] [8, 9, 10, 11, 12, 13] [1, 2, 3] [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] -## empty +## empty (aliases: `array_empty`, `list_empty`) # empty scalar function #1 query B select empty(make_array(1)); @@ -6207,6 +6207,75 @@ NULL false false +## array_empty (aliases: `empty`, `list_empty`) +# array_empty scalar function #1 +query B +select array_empty(make_array(1)); +---- +false + +query B +select array_empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + +# array_empty scalar function #2 +query B +select array_empty(make_array()); +---- +true + +query B +select array_empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + +# array_empty scalar function #3 +query B +select array_empty(make_array(NULL)); +---- +false + +query B +select array_empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false + +## 
list_empty (aliases: `empty`, `array_empty`) +# list_empty scalar function #1 +query B +select list_empty(make_array(1)); +---- +false + +query B +select list_empty(arrow_cast(make_array(1), 'LargeList(Int64)')); +---- +false + +# list_empty scalar function #2 +query B +select list_empty(make_array()); +---- +true + +query B +select list_empty(arrow_cast(make_array(), 'LargeList(Null)')); +---- +true + +# list_empty scalar function #3 +query B +select list_empty(make_array(NULL)); +---- +false + +query B +select list_empty(arrow_cast(make_array(NULL), 'LargeList(Null)')); +---- +false + +# string_to_array scalar function query ? SELECT string_to_array('abcxxxdef', 'xxx') ---- diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index 005d2ec94229..a5fc13491677 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -217,6 +217,7 @@ select log(-1), log(0), sqrt(-1); | array_dims(array) | Returns an array of the array's dimensions. `array_dims([[1, 2, 3], [4, 5, 6]]) -> [2, 3]` | | array_distinct(array) | Returns distinct values from the array after removing duplicates. `array_distinct([1, 3, 2, 3, 1, 2, 4]) -> [1, 2, 3, 4]` | | array_element(array, index) | Extracts the element with the index n from the array `array_element([1, 2, 3, 4], 3) -> 3` | +| empty(array) | Returns true for an empty array or false for a non-empty array. `empty([1]) -> false` | | flatten(array) | Converts an array of arrays to a flat array `flatten([[1], [2, 3], [4, 5, 6]]) -> [1, 2, 3, 4, 5, 6]` | | array_length(array, dimension) | Returns the length of the array dimension. `array_length([1, 2, 3, 4, 5]) -> 5` | | array_ndims(array) | Returns the number of dimensions of the array. `array_ndims([[1, 2, 3], [4, 5, 6]]) -> 2` | diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md index 5eb3436b4256..52edf4bb7217 100644 --- a/docs/source/user-guide/sql/scalar_functions.md +++ b/docs/source/user-guide/sql/scalar_functions.md @@ -1931,6 +1931,7 @@ from_unixtime(expression) - [array_has_all](#array_has_all) - [array_has_any](#array_has_any) - [array_element](#array_element) +- [array_empty](#array_empty) - [array_except](#array_except) - [array_extract](#array_extract) - [array_fill](#array_fill) @@ -3009,6 +3010,11 @@ empty(array) +------------------+ ``` +#### Aliases + +- array_empty, +- list_empty + ### `generate_series` Similar to the range function, but it includes the upper bound. @@ -3038,10 +3044,6 @@ generate_series(start, stop, step) _Alias of [array_append](#array_append)._ -### `list_sort` - -_Alias of [array_sort](#array_sort)._ - ### `list_cat` _Alias of [array_concat](#array_concat)._ @@ -3062,6 +3064,10 @@ _Alias of [array_dims](#array_distinct)._ _Alias of [array_element](#array_element)._ +### `list_empty` + +_Alias of [empty](#empty)._ + ### `list_except` _Alias of [array_element](#array_except)._ @@ -3170,13 +3176,17 @@ _Alias of [array_reverse](#array_reverse)._ _Alias of [array_slice](#array_slice)._ +### `list_sort` + +_Alias of [array_sort](#array_sort)._ + ### `list_to_string` _Alias of [array_to_string](#array_to_string)._ ### `list_union` -_Alias of [array_to_string](#array_union)._ +_Alias of [array_union](#array_union)._ ### `make_array` @@ -3186,6 +3196,10 @@ Returns an Arrow array using the specified input expressions. 
make_array(expression1[, ..., expression_n]) ``` +### `array_empty` + +_Alias of [empty](#empty)._ + #### Arguments - **expression_n**: Expression to include in the output array. From 0b11d143ec77aeac0207bd20f790b21790397767 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Wed, 27 Mar 2024 11:00:29 +0800 Subject: [PATCH 069/117] expr like to sql (#9805) --- datafusion/sql/src/unparser/expr.rs | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 550b02cea7f6..8610269bdb29 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -175,14 +175,17 @@ impl Unparser<'_> { not_impl_err!("Unsupported expression: {expr:?}") } Expr::Like(Like { - negated: _, + negated, expr, - pattern: _, - escape_char: _, + pattern, + escape_char, case_insensitive: _, - }) => { - not_impl_err!("Unsupported expression: {expr:?}") - } + }) => Ok(ast::Expr::Like { + negated: *negated, + expr: Box::new(self.expr_to_sql(expr)?), + pattern: Box::new(self.expr_to_sql(pattern)?), + escape_char: *escape_char, + }), Expr::AggregateFunction(agg) => { let func_name = if let AggregateFunctionDefinition::BuiltIn(built_in) = &agg.func_def @@ -706,6 +709,16 @@ mod tests { ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]), r#"dummy_udf("a", "b")"#, ), + ( + Expr::Like(Like { + negated: true, + expr: Box::new(col("a")), + pattern: Box::new(lit("foo")), + escape_char: Some('o'), + case_insensitive: true, + }), + r#""a" NOT LIKE 'foo' ESCAPE 'o'"#, + ), ( Expr::Literal(ScalarValue::Date64(Some(0))), r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#, From ccd850bef0706664e49d792d0442c0ac16df866b Mon Sep 17 00:00:00 2001 From: Sebastian Espinosa <40347293+sebastian2296@users.noreply.github.com> Date: Tue, 26 Mar 2024 22:04:14 -0500 Subject: [PATCH 070/117] feat: Not expr to string (#9802) * feat: Not expr to string * fix: use not logical expr * fix: format using fmt * fix: import within test mod * fix: format new import --- datafusion/sql/src/unparser/expr.rs | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8610269bdb29..a29b5014b1ce 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -24,7 +24,9 @@ use datafusion_expr::{ expr::{AggregateFunctionDefinition, Alias, InList, ScalarFunction, WindowFunction}, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, }; -use sqlparser::ast::{self, Function, FunctionArg, Ident}; +use sqlparser::ast::{ + self, Expr as AstExpr, Function, FunctionArg, Ident, UnaryOperator, +}; use super::Unparser; @@ -267,6 +269,13 @@ impl Unparser<'_> { Expr::IsUnknown(expr) => { Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) } + Expr::Not(expr) => { + let sql_parser_expr = self.expr_to_sql(expr)?; + Ok(AstExpr::UnaryOp { + op: UnaryOperator::Not, + expr: Box::new(sql_parser_expr), + }) + } _ => not_impl_err!("Unsupported expression: {expr:?}"), } } @@ -619,8 +628,8 @@ mod tests { use datafusion_common::TableReference; use datafusion_expr::{ - case, col, expr::AggregateFunction, lit, ColumnarValue, ScalarUDF, ScalarUDFImpl, - Signature, Volatility, + case, col, expr::AggregateFunction, lit, not, ColumnarValue, ScalarUDF, + ScalarUDFImpl, Signature, Volatility, }; use crate::unparser::dialect::CustomDialect; @@ -782,6 +791,7 @@ mod tests { (col("a") + 
col("b")).gt(lit(4)).is_unknown(), r#"(("a" + "b") > 4) IS UNKNOWN"#, ), + (not(col("a")), r#"NOT "a""#), ( Expr::between(col("a"), lit(1), lit(7)), r#"("a" BETWEEN 1 AND 7)"#, From a4c71e220be20c591b3dde38d5a9aa410e458466 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 27 Mar 2024 06:09:23 +0300 Subject: [PATCH 071/117] [Minor]: Move some repetitive codes to functions(proto) (#9811) * add parse_exprs util * Minor changes * Minor changes * Add vector field converter * Add serialize exprs * proto to arrow field conversion * Simplifications * All tests pass * Simplifications --- .../proto/src/logical_plan/from_proto.rs | 176 ++++++------------ datafusion/proto/src/logical_plan/to_proto.rs | 116 +++++------- .../proto/src/physical_plan/from_proto.rs | 106 ++++------- datafusion/proto/src/physical_plan/mod.rs | 51 +++-- .../proto/src/physical_plan/to_proto.rs | 124 ++++++------ 5 files changed, 231 insertions(+), 342 deletions(-) diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d5eebcb69841..4b9874bf8f65 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -323,11 +323,7 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { DataType::FixedSizeList(Arc::new(list_type), list_size) } arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct( - strct - .sub_field_types - .iter() - .map(Field::try_from) - .collect::>()?, + parse_proto_fields_to_fields(&strct.sub_field_types)?.into(), ), arrow_type::ArrowTypeEnum::Union(union) => { let union_mode = protobuf::UnionMode::try_from(union.union_mode) @@ -336,11 +332,7 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { protobuf::UnionMode::Dense => UnionMode::Dense, protobuf::UnionMode::Sparse => UnionMode::Sparse, }; - let union_fields = union - .union_types - .iter() - .map(TryInto::try_into) - .collect::, _>>()?; + let union_fields = parse_proto_fields_to_fields(&union.union_types)?; // Default to index based type ids if not provided let type_ids: Vec<_> = match union.type_ids.is_empty() { @@ -763,10 +755,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .map(|f| f.field.clone()) .collect::>>(); let fields = fields.ok_or_else(|| Error::required("UnionField"))?; - let fields = fields - .iter() - .map(Field::try_from) - .collect::, _>>()?; + let fields = parse_proto_fields_to_fields(&fields)?; let fields = UnionFields::new(ids, fields); let v_id = val.value_id as i8; let val = match &val.value { @@ -937,11 +926,7 @@ pub fn parse_expr( match expr_type { ExprType::BinaryExpr(binary_expr) => { let op = from_proto_binary_op(&binary_expr.op)?; - let operands = binary_expr - .operands - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?; + let operands = parse_exprs(&binary_expr.operands, registry, codec)?; if operands.len() < 2 { return Err(proto_error( @@ -1025,16 +1010,8 @@ pub fn parse_expr( .window_function .as_ref() .ok_or_else(|| Error::required("window_function"))?; - let partition_by = expr - .partition_by - .iter() - .map(|e| parse_expr(e, registry, codec)) - .collect::, _>>()?; - let mut order_by = expr - .order_by - .iter() - .map(|e| parse_expr(e, registry, codec)) - .collect::, _>>()?; + let partition_by = parse_exprs(&expr.partition_by, registry, codec)?; + let mut order_by = parse_exprs(&expr.order_by, registry, codec)?; let window_frame = expr .window_frame .as_ref() @@ -1130,10 
+1107,7 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, - expr.expr - .iter() - .map(|e| parse_expr(e, registry, codec)) - .collect::, _>>()?, + parse_exprs(&expr.expr, registry, codec)?, expr.distinct, parse_optional_expr(expr.filter.as_deref(), registry, codec)? .map(Box::new), @@ -1331,11 +1305,7 @@ pub fn parse_expr( parse_required_expr(negative.expr.as_deref(), registry, "expr", codec)?, ))), ExprType::Unnest(unnest) => { - let exprs = unnest - .exprs - .iter() - .map(|e| parse_expr(e, registry, codec)) - .collect::, _>>()?; + let exprs = parse_exprs(&unnest.exprs, registry, codec)?; Ok(Expr::Unnest(Unnest { exprs })) } ExprType::InList(in_list) => Ok(Expr::InList(InList::new( @@ -1345,11 +1315,7 @@ pub fn parse_expr( "expr", codec, )?), - in_list - .list - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, + parse_exprs(&in_list.list, registry, codec)?, in_list.negated, ))), ExprType::Wildcard(protobuf::Wildcard { qualifier }) => Ok(Expr::Wildcard { @@ -1401,18 +1367,8 @@ pub fn parse_expr( Ok(factorial(parse_expr(&args[0], registry, codec)?)) } ScalarFunction::Ceil => Ok(ceil(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Round => Ok(round( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::Trunc => Ok(trunc( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), + ScalarFunction::Round => Ok(round(parse_exprs(args, registry, codec)?)), + ScalarFunction::Trunc => Ok(trunc(parse_exprs(args, registry, codec)?)), ScalarFunction::Signum => { Ok(signum(parse_expr(&args[0], registry, codec)?)) } @@ -1442,30 +1398,14 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Concat => Ok(concat_expr( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::ConcatWithSeparator => Ok(concat_ws_expr( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::Lpad => Ok(lpad( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), - ScalarFunction::Rpad => Ok(rpad( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), + ScalarFunction::Concat => { + Ok(concat_expr(parse_exprs(args, registry, codec)?)) + } + ScalarFunction::ConcatWithSeparator => { + Ok(concat_ws_expr(parse_exprs(args, registry, codec)?)) + } + ScalarFunction::Lpad => Ok(lpad(parse_exprs(args, registry, codec)?)), + ScalarFunction::Rpad => Ok(rpad(parse_exprs(args, registry, codec)?)), ScalarFunction::EndsWith => Ok(ends_with( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, @@ -1494,12 +1434,9 @@ pub fn parse_expr( parse_expr(&args[1], registry, codec)?, parse_expr(&args[2], registry, codec)?, )), - ScalarFunction::Coalesce => Ok(coalesce( - args.to_owned() - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, _>>()?, - )), + ScalarFunction::Coalesce => { + Ok(coalesce(parse_exprs(args, registry, codec)?)) + } ScalarFunction::Pi => Ok(pi()), ScalarFunction::Power => Ok(power( parse_expr(&args[0], registry, codec)?, @@ -1543,9 +1480,7 @@ pub fn parse_expr( }; Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( scalar_fn, - args.iter() - .map(|expr| parse_expr(expr, registry, codec)) - 
.collect::, Error>>()?, + parse_exprs(args, registry, codec)?, ))) } ExprType::AggregateUdfExpr(pb) => { @@ -1553,10 +1488,7 @@ pub fn parse_expr( Ok(Expr::AggregateFunction(expr::AggregateFunction::new_udf( agg_fn, - pb.args - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>()?, + parse_exprs(&pb.args, registry, codec)?, false, parse_optional_expr(pb.filter.as_deref(), registry, codec)?.map(Box::new), parse_vec_expr(&pb.order_by, registry, codec)?, @@ -1566,28 +1498,16 @@ pub fn parse_expr( ExprType::GroupingSet(GroupingSetNode { expr }) => { Ok(Expr::GroupingSet(GroupingSets( expr.iter() - .map(|expr_list| { - expr_list - .expr - .iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>() - }) + .map(|expr_list| parse_exprs(&expr_list.expr, registry, codec)) .collect::, Error>>()?, ))) } ExprType::Cube(CubeNode { expr }) => Ok(Expr::GroupingSet(GroupingSet::Cube( - expr.iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>()?, + parse_exprs(expr, registry, codec)?, ))), - ExprType::Rollup(RollupNode { expr }) => { - Ok(Expr::GroupingSet(GroupingSet::Rollup( - expr.iter() - .map(|expr| parse_expr(expr, registry, codec)) - .collect::, Error>>()?, - ))) - } + ExprType::Rollup(RollupNode { expr }) => Ok(Expr::GroupingSet( + GroupingSet::Rollup(parse_exprs(expr, registry, codec)?), + )), ExprType::Placeholder(PlaceholderNode { id, data_type }) => match data_type { None => Ok(Expr::Placeholder(Placeholder::new(id.clone(), None))), Some(data_type) => Ok(Expr::Placeholder(Placeholder::new( @@ -1598,6 +1518,24 @@ pub fn parse_expr( } } +/// Parse a vector of `protobuf::LogicalExprNode`s. +pub fn parse_exprs<'a, I>( + protos: I, + registry: &dyn FunctionRegistry, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + let res = protos + .into_iter() + .map(|elem| { + parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) + }) + .collect::>>()?; + Ok(res) +} + /// Parse an optional escape_char for Like, ILike, SimilarTo fn parse_escape_char(s: &str) -> Result> { match s.len() { @@ -1654,12 +1592,7 @@ fn parse_vec_expr( registry: &dyn FunctionRegistry, codec: &dyn LogicalExtensionCodec, ) -> Result>, Error> { - let res = p - .iter() - .map(|elem| { - parse_expr(elem, registry, codec).map_err(|e| plan_datafusion_err!("{}", e)) - }) - .collect::>>()?; + let res = parse_exprs(p, registry, codec)?; // Convert empty vector to None. Ok((!res.is_empty()).then_some(res)) } @@ -1690,3 +1623,16 @@ fn parse_required_expr( fn proto_error>(message: S) -> Error { Error::General(message.into()) } + +/// Converts a vector of `protobuf::Field`s to `Arc`s. +fn parse_proto_fields_to_fields<'a, I>( + fields: I, +) -> std::result::Result, Error> +where + I: IntoIterator, +{ + fields + .into_iter() + .map(Field::try_from) + .collect::>() +} diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 0432b54acfa8..1335d511a0ea 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -19,6 +19,8 @@ //! DataFusion logical plans to be serialized and transmitted between //! processes. 
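 //!
 //! For example (illustrative usage only), a whole set of expressions can be
 //! serialized in one call: `let protos = serialize_exprs(&exprs, codec)?;`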
+use std::sync::Arc; + use crate::protobuf::{ self, arrow_type::ArrowTypeEnum, @@ -186,10 +188,7 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { field_type: Some(Box::new(item_type.as_ref().try_into()?)), })), DataType::Struct(struct_fields) => Self::Struct(protobuf::Struct { - sub_field_types: struct_fields - .iter() - .map(|field| field.as_ref().try_into()) - .collect::, Error>>()?, + sub_field_types: convert_arc_fields_to_proto_fields(struct_fields)?, }), DataType::Union(fields, union_mode) => { let union_mode = match union_mode { @@ -197,10 +196,7 @@ impl TryFrom<&DataType> for protobuf::arrow_type::ArrowTypeEnum { UnionMode::Dense => protobuf::UnionMode::Dense, }; Self::Union(protobuf::Union { - union_types: fields - .iter() - .map(|(_, field)| field.as_ref().try_into()) - .collect::, Error>>()?, + union_types: convert_arc_fields_to_proto_fields(fields.iter().map(|(_, item)|item))?, union_mode: union_mode.into(), type_ids: fields.iter().map(|(x, _)| x as i32).collect(), }) @@ -262,11 +258,7 @@ impl TryFrom<&Schema> for protobuf::Schema { fn try_from(schema: &Schema) -> Result { Ok(Self { - columns: schema - .fields() - .iter() - .map(|f| f.as_ref().try_into()) - .collect::, Error>>()?, + columns: convert_arc_fields_to_proto_fields(schema.fields())?, metadata: schema.metadata.clone(), }) } @@ -277,11 +269,7 @@ impl TryFrom for protobuf::Schema { fn try_from(schema: SchemaRef) -> Result { Ok(Self { - columns: schema - .fields() - .iter() - .map(|f| f.as_ref().try_into()) - .collect::, Error>>()?, + columns: convert_arc_fields_to_proto_fields(schema.fields())?, metadata: schema.metadata.clone(), }) } @@ -486,6 +474,19 @@ impl TryFrom<&WindowFrame> for protobuf::WindowFrame { } } +pub fn serialize_exprs<'a, I>( + exprs: I, + codec: &dyn LogicalExtensionCodec, +) -> Result, Error> +where + I: IntoIterator, +{ + exprs + .into_iter() + .map(|expr| serialize_expr(expr, codec)) + .collect::, Error>>() +} + pub fn serialize_expr( expr: &Expr, codec: &dyn LogicalExtensionCodec, @@ -543,11 +544,7 @@ pub fn serialize_expr( // We need to reverse exprs since operands are expected to be // linearized from left innermost to right outermost (but while // traversing the chain we do the exact opposite). 
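                // (Illustration, following the comment above: for a chain such
                // as `a OR b OR c`, traversal yields [c, b, a]; the reversal
                // below restores the expected [a, b, c].)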
- operands: exprs - .into_iter() - .rev() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + operands: serialize_exprs(exprs.into_iter().rev(), codec)?, op: format!("{op:?}"), }; protobuf::LogicalExprNode { @@ -639,14 +636,8 @@ pub fn serialize_expr( } else { None }; - let partition_by = partition_by - .iter() - .map(|e| serialize_expr(e, codec)) - .collect::, _>>()?; - let order_by = order_by - .iter() - .map(|e| serialize_expr(e, codec)) - .collect::, _>>()?; + let partition_by = serialize_exprs(partition_by, codec)?; + let order_by = serialize_exprs(order_by, codec)?; let window_frame: Option = Some(window_frame.try_into()?); @@ -744,20 +735,14 @@ pub fn serialize_expr( let aggregate_expr = protobuf::AggregateExprNode { aggr_function: aggr_function.into(), - expr: args - .iter() - .map(|v| serialize_expr(v, codec)) - .collect::, _>>()?, + expr: serialize_exprs(args, codec)?, distinct: *distinct, filter: match filter { Some(e) => Some(Box::new(serialize_expr(e, codec)?)), None => None, }, order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, _>>()?, + Some(e) => serialize_exprs(e, codec)?, None => vec![], }, }; @@ -769,19 +754,13 @@ pub fn serialize_expr( expr_type: Some(ExprType::AggregateUdfExpr(Box::new( protobuf::AggregateUdfExprNode { fun_name: fun.name().to_string(), - args: args - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + args: serialize_exprs(args, codec)?, filter: match filter { Some(e) => Some(Box::new(serialize_expr(e.as_ref(), codec)?)), None => None, }, order_by: match order_by { - Some(e) => e - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, _>>()?, + Some(e) => serialize_exprs(e, codec)?, None => vec![], }, }, @@ -801,10 +780,7 @@ pub fn serialize_expr( )) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let args = args - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?; + let args = serialize_exprs(args, codec)?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { let fun: protobuf::ScalarFunction = fun.try_into()?; @@ -997,10 +973,7 @@ pub fn serialize_expr( } Expr::Unnest(Unnest { exprs }) => { let expr = protobuf::Unnest { - exprs: exprs - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + exprs: serialize_exprs(exprs, codec)?, }; protobuf::LogicalExprNode { expr_type: Some(ExprType::Unnest(expr)), @@ -1013,10 +986,7 @@ pub fn serialize_expr( }) => { let expr = Box::new(protobuf::InListNode { expr: Some(Box::new(serialize_expr(expr.as_ref(), codec)?)), - list: list - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + list: serialize_exprs(list, codec)?, negated: *negated, }); protobuf::LogicalExprNode { @@ -1077,18 +1047,12 @@ pub fn serialize_expr( Expr::GroupingSet(GroupingSet::Cube(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Cube(CubeNode { - expr: exprs - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + expr: serialize_exprs(exprs, codec)?, })), }, Expr::GroupingSet(GroupingSet::Rollup(exprs)) => protobuf::LogicalExprNode { expr_type: Some(ExprType::Rollup(RollupNode { - expr: exprs - .iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + expr: serialize_exprs(exprs, codec)?, })), }, Expr::GroupingSet(GroupingSet::GroupingSets(exprs)) => { @@ -1098,10 +1062,7 @@ pub fn serialize_expr( .iter() .map(|expr_list| { Ok(LogicalExprList { - expr: expr_list - 
.iter() - .map(|expr| serialize_expr(expr, codec)) - .collect::, Error>>()?, + expr: serialize_exprs(expr_list, codec)?, }) }) .collect::, Error>>()?, @@ -1680,3 +1641,16 @@ fn encode_scalar_nested_value( _ => unreachable!(), } } + +/// Converts a vector of `Arc`s to `protobuf::Field`s +fn convert_arc_fields_to_proto_fields<'a, I>( + fields: I, +) -> Result, Error> +where + I: IntoIterator>, +{ + fields + .into_iter() + .map(|field| field.as_ref().try_into()) + .collect::, Error>>() +} diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index ca54d4e803ca..aaca4dc48236 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -83,10 +83,10 @@ pub fn parse_physical_sort_expr( proto: &protobuf::PhysicalSortExprNode, registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result { if let Some(expr) = &proto.expr { - let codec = DefaultPhysicalExtensionCodec {}; - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; + let expr = parse_physical_expr(expr.as_ref(), registry, input_schema, codec)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -109,22 +109,12 @@ pub fn parse_physical_sort_exprs( proto: &[protobuf::PhysicalSortExprNode], registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { proto .iter() .map(|sort_expr| { - if let Some(expr) = &sort_expr.expr { - let codec = DefaultPhysicalExtensionCodec {}; - let expr = - parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; - let options = SortOptions { - descending: !sort_expr.asc, - nulls_first: sort_expr.nulls_first, - }; - Ok(PhysicalSortExpr { expr, options }) - } else { - Err(proto_error("Unexpected empty physical expression")) - } + parse_physical_sort_expr(sort_expr, registry, input_schema, codec) }) .collect::>>() } @@ -144,23 +134,14 @@ pub fn parse_physical_window_expr( input_schema: &Schema, ) -> Result> { let codec = DefaultPhysicalExtensionCodec {}; - let window_node_expr = proto - .args - .iter() - .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) - .collect::>>()?; + let window_node_expr = + parse_physical_exprs(&proto.args, registry, input_schema, &codec)?; - let partition_by = proto - .partition_by - .iter() - .map(|p| parse_physical_expr(p, registry, input_schema, &codec)) - .collect::>>()?; + let partition_by = + parse_physical_exprs(&proto.partition_by, registry, input_schema, &codec)?; - let order_by = proto - .order_by - .iter() - .map(|o| parse_physical_sort_expr(o, registry, input_schema)) - .collect::>>()?; + let order_by = + parse_physical_sort_exprs(&proto.order_by, registry, input_schema, &codec)?; let window_frame = proto .window_frame @@ -186,6 +167,21 @@ pub fn parse_physical_window_expr( ) } +pub fn parse_physical_exprs<'a, I>( + protos: I, + registry: &dyn FunctionRegistry, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, +) -> Result>> +where + I: IntoIterator, +{ + protos + .into_iter() + .map(|p| parse_physical_expr(p, registry, input_schema, codec)) + .collect::>>() +} + /// Parses a physical expression from a protobuf. 
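 /// (To convert a whole collection of protobuf expressions at once, see the
 /// [`parse_physical_exprs`] helper defined above, which maps this function
 /// over an iterator of nodes.)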
/// /// # Arguments @@ -276,10 +272,7 @@ pub fn parse_physical_expr( "expr", input_schema, )?, - e.list - .iter() - .map(|x| parse_physical_expr(x, registry, input_schema, codec)) - .collect::, _>>()?, + parse_physical_exprs(&e.list, registry, input_schema, codec)?, &e.negated, input_schema, )?, @@ -339,11 +332,7 @@ pub fn parse_physical_expr( ) })?; - let args = e - .args - .iter() - .map(|x| parse_physical_expr(x, registry, input_schema, codec)) - .collect::, _>>()?; + let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?; // TODO Do not create new the ExecutionProps let execution_props = ExecutionProps::new(); @@ -363,11 +352,7 @@ pub fn parse_physical_expr( let signature = udf.signature(); let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone()); - let args = e - .args - .iter() - .map(|x| parse_physical_expr(x, registry, input_schema, codec)) - .collect::, _>>()?; + let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?; Arc::new(ScalarFunctionExpr::new( e.name.as_str(), @@ -452,11 +437,12 @@ pub fn parse_protobuf_hash_partitioning( match partitioning { Some(hash_part) => { let codec = DefaultPhysicalExtensionCodec {}; - let expr = hash_part - .hash_expr - .iter() - .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) - .collect::>, _>>()?; + let expr = parse_physical_exprs( + &hash_part.hash_expr, + registry, + input_schema, + &codec, + )?; Ok(Some(Partitioning::Hash( expr, @@ -517,24 +503,12 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { let codec = DefaultPhysicalExtensionCodec {}; - let sort_expr = node_collection - .physical_sort_expr_nodes - .iter() - .map(|node| { - let expr = node - .expr - .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, &schema, &codec)) - .unwrap()?; - Ok(PhysicalSortExpr { - expr, - options: SortOptions { - descending: !node.asc, - nulls_first: node.nulls_first, - }, - }) - }) - .collect::>>()?; + let sort_expr = parse_physical_sort_exprs( + &node_collection.physical_sort_expr_nodes, + registry, + &schema, + &codec, + )?; output_ordering.push(sort_expr); } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index da31c5e762bc..00dacffe06c2 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -48,7 +48,7 @@ use datafusion::datasource::physical_plan::ParquetExec; use datafusion::datasource::physical_plan::{AvroExec, CsvExec}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; -use datafusion::physical_expr::PhysicalExprRef; +use datafusion::physical_expr::{PhysicalExprRef, PhysicalSortRequirement}; use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateMode}; use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy}; use datafusion::physical_plan::analyze::AnalyzeExec; @@ -492,7 +492,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let input_phy_expr: Vec> = agg_node.expr.iter() .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); let ordering_req: Vec = agg_node.ordering_req.iter() - .map(|e| parse_physical_sort_expr(e, registry, &physical_schema).unwrap()).collect(); + .map(|e| parse_physical_sort_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); agg_node.aggregate_function.as_ref().map(|func| { match func { AggregateFunction::AggrFunction(i) => { @@ -736,6 
+736,7 @@ impl AsExecutionPlan for PhysicalPlanNode { &sym_join.left_sort_exprs, registry, &left_schema, + extension_codec, )?; let left_sort_exprs = if left_sort_exprs.is_empty() { None @@ -747,6 +748,7 @@ impl AsExecutionPlan for PhysicalPlanNode { &sym_join.right_sort_exprs, registry, &right_schema, + extension_codec, )?; let right_sort_exprs = if right_sort_exprs.is_empty() { None @@ -1018,14 +1020,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .sort_order .as_ref() .map(|collection| { - collection - .physical_sort_expr_nodes - .iter() - .map(|proto| { - parse_physical_sort_expr(proto, registry, &sink_schema) - .map(Into::into) - }) - .collect::>>() + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + registry, + &sink_schema, + extension_codec, + ) + .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) }) .transpose()?; Ok(Arc::new(FileSinkExec::new( @@ -1049,14 +1050,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .sort_order .as_ref() .map(|collection| { - collection - .physical_sort_expr_nodes - .iter() - .map(|proto| { - parse_physical_sort_expr(proto, registry, &sink_schema) - .map(Into::into) - }) - .collect::>>() + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + registry, + &sink_schema, + extension_codec, + ) + .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) }) .transpose()?; Ok(Arc::new(FileSinkExec::new( @@ -1080,14 +1080,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .sort_order .as_ref() .map(|collection| { - collection - .physical_sort_expr_nodes - .iter() - .map(|proto| { - parse_physical_sort_expr(proto, registry, &sink_schema) - .map(Into::into) - }) - .collect::>>() + parse_physical_sort_exprs( + &collection.physical_sort_expr_nodes, + registry, + &sink_schema, + extension_codec, + ) + .map(|item| PhysicalSortRequirement::from_sort_exprs(&item)) }) .transpose()?; Ok(Arc::new(FileSinkExec::new( diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index b66709d0c5bd..e1574f48fb8e 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -79,18 +79,10 @@ impl TryFrom> for protobuf::PhysicalExprNode { fn try_from(a: Arc) -> Result { let codec = DefaultPhysicalExtensionCodec {}; - let expressions: Vec = a - .expressions() - .iter() - .map(|e| serialize_physical_expr(e.clone(), &codec)) - .collect::>>()?; + let expressions = serialize_physical_exprs(a.expressions(), &codec)?; - let ordering_req: Vec = a - .order_bys() - .unwrap_or(&[]) - .iter() - .map(|e| e.clone().try_into()) - .collect::>>()?; + let ordering_req = a.order_bys().unwrap_or(&[]).to_vec(); + let ordering_req = serialize_physical_sort_exprs(ordering_req, &codec)?; if let Some(a) = a.as_any().downcast_ref::() { let name = a.fun().name().to_string(); @@ -245,22 +237,12 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; let codec = DefaultPhysicalExtensionCodec {}; - let args = args - .into_iter() - .map(|e| serialize_physical_expr(e, &codec)) - .collect::>>()?; - - let partition_by = window_expr - .partition_by() - .iter() - .map(|p| serialize_physical_expr(p.clone(), &codec)) - .collect::>>()?; + let args = serialize_physical_exprs(args, &codec)?; + let partition_by = + serialize_physical_exprs(window_expr.partition_by().to_vec(), &codec)?; - let order_by = window_expr - .order_by() - .iter() - .map(|o| o.clone().try_into()) - .collect::>>()?; + let order_by 
= + serialize_physical_sort_exprs(window_expr.order_by().to_vec(), &codec)?; let window_frame: protobuf::WindowFrame = window_frame .as_ref() @@ -381,6 +363,45 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { Ok(AggrFn { inner, distinct }) } +pub fn serialize_physical_sort_exprs( + sort_exprs: I, + codec: &dyn PhysicalExtensionCodec, +) -> Result, DataFusionError> +where + I: IntoIterator, +{ + sort_exprs + .into_iter() + .map(|sort_expr| serialize_physical_sort_expr(sort_expr, codec)) + .collect() +} + +pub fn serialize_physical_sort_expr( + sort_expr: PhysicalSortExpr, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let PhysicalSortExpr { expr, options } = sort_expr; + let expr = serialize_physical_expr(expr, codec)?; + Ok(PhysicalSortExprNode { + expr: Some(Box::new(expr)), + asc: !options.descending, + nulls_first: options.nulls_first, + }) +} + +pub fn serialize_physical_exprs( + values: I, + codec: &dyn PhysicalExtensionCodec, +) -> Result, DataFusionError> +where + I: IntoIterator>, +{ + values + .into_iter() + .map(|value| serialize_physical_expr(value, codec)) + .collect() +} + /// Serialize a `PhysicalExpr` to default protobuf representation. /// /// If required, a [`PhysicalExtensionCodec`] can be provided which can handle @@ -488,27 +509,16 @@ pub fn serialize_physical_expr( }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::InList( - Box::new( - protobuf::PhysicalInListNode { - expr: Some(Box::new(serialize_physical_expr( - expr.expr().to_owned(), - codec, - )?)), - list: expr - .list() - .iter() - .map(|a| serialize_physical_expr(a.clone(), codec)) - .collect::, - DataFusionError, - >>()?, - negated: expr.negated(), - }, - ), - ), - ), + expr_type: Some(protobuf::physical_expr_node::ExprType::InList(Box::new( + protobuf::PhysicalInListNode { + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + list: serialize_physical_exprs(expr.list().to_vec(), codec)?, + negated: expr.negated(), + }, + ))), }) } else if let Some(expr) = expr.downcast_ref::() { Ok(protobuf::PhysicalExprNode { @@ -552,11 +562,7 @@ pub fn serialize_physical_expr( ))), }) } else if let Some(expr) = expr.downcast_ref::() { - let args: Vec = expr - .args() - .iter() - .map(|e| serialize_physical_expr(e.to_owned(), codec)) - .collect::, _>>()?; + let args = serialize_physical_exprs(expr.args().to_vec(), codec)?; if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { let fun: protobuf::ScalarFunction = (&fun).try_into()?; @@ -754,18 +760,8 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { let mut output_orderings = vec![]; for order in &conf.output_ordering { - let expr_node_vec = order - .iter() - .map(|sort_expr| { - let expr = serialize_physical_expr(sort_expr.expr.clone(), &codec)?; - Ok(PhysicalSortExprNode { - expr: Some(Box::new(expr)), - asc: !sort_expr.options.descending, - nulls_first: sort_expr.options.nulls_first, - }) - }) - .collect::>>()?; - output_orderings.push(expr_node_vec) + let ordering = serialize_physical_sort_exprs(order.to_vec(), &codec)?; + output_orderings.push(ordering) } // Fields must be added to the schema so that they can persist in the protobuf From 8d3504cbe7823d29116db237f4a4cb0bb6a5989d Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Tue, 26 Mar 2024 23:21:08 -0700 Subject: [PATCH 072/117] Implement IGNORE NULLS for LAST_VALUE (#9801) * Implement IGNORE NULLS for LAST_VALUE * address comments 
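Behavior sketch, mirroring the sqllogictests added in this patch (table `t`
holds VALUES (3), (4), (NULL)):

    SELECT LAST_VALUE(column1) IGNORE NULLS FROM t;   -- returns 4
    SELECT LAST_VALUE(column1) RESPECT NULLS FROM t;  -- returns NULL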
--------- Co-authored-by: Huaxin Gao --- datafusion/core/tests/sql/aggregates.rs | 80 ------------------- .../physical-expr/src/aggregate/build_in.rs | 17 ++-- .../physical-expr/src/aggregate/first_last.rs | 49 ++++++++++-- .../sqllogictest/test_files/aggregate.slt | 52 ++++++++++++ 4 files changed, 104 insertions(+), 94 deletions(-) diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 14bc7a3d4f68..84b791a3de05 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -321,83 +321,3 @@ async fn test_accumulator_row_accumulator() -> Result<()> { Ok(()) } - -#[tokio::test] -async fn test_first_value() -> Result<()> { - let session_ctx = SessionContext::new(); - session_ctx - .sql("CREATE TABLE abc AS VALUES (null,2,3), (4,5,6)") - .await? - .collect() - .await?; - - let results1 = session_ctx - .sql("SELECT FIRST_VALUE(column1) ignore nulls FROM abc") - .await? - .collect() - .await?; - let expected1 = [ - "+--------------------------+", - "| FIRST_VALUE(abc.column1) |", - "+--------------------------+", - "| 4 |", - "+--------------------------+", - ]; - assert_batches_eq!(expected1, &results1); - - let results2 = session_ctx - .sql("SELECT FIRST_VALUE(column1) respect nulls FROM abc") - .await? - .collect() - .await?; - let expected2 = [ - "+--------------------------+", - "| FIRST_VALUE(abc.column1) |", - "+--------------------------+", - "| |", - "+--------------------------+", - ]; - assert_batches_eq!(expected2, &results2); - - Ok(()) -} - -#[tokio::test] -async fn test_first_value_with_sort() -> Result<()> { - let session_ctx = SessionContext::new(); - session_ctx - .sql("CREATE TABLE abc AS VALUES (null,2,3), (null,1,6), (4, 5, 5), (1, 4, 7), (2, 3, 8)") - .await? - .collect() - .await?; - - let results1 = session_ctx - .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) ignore nulls FROM abc") - .await? - .collect() - .await?; - let expected1 = [ - "+--------------------------+", - "| FIRST_VALUE(abc.column1) |", - "+--------------------------+", - "| 2 |", - "+--------------------------+", - ]; - assert_batches_eq!(expected1, &results1); - - let results2 = session_ctx - .sql("SELECT FIRST_VALUE(column1 ORDER BY column2) respect nulls FROM abc") - .await? 
- .collect() - .await?; - let expected2 = [ - "+--------------------------+", - "| FIRST_VALUE(abc.column1) |", - "+--------------------------+", - "| |", - "+--------------------------+", - ]; - assert_batches_eq!(expected2, &results2); - - Ok(()) -} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 846431034c96..cee679863870 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -370,13 +370,16 @@ pub fn create_aggregate_expr( ) .with_ignore_nulls(ignore_nulls), ), - (AggregateFunction::LastValue, _) => Arc::new(expressions::LastValue::new( - input_phy_exprs[0].clone(), - name, - input_phy_types[0].clone(), - ordering_req.to_vec(), - ordering_types, - )), + (AggregateFunction::LastValue, _) => Arc::new( + expressions::LastValue::new( + input_phy_exprs[0].clone(), + name, + input_phy_types[0].clone(), + ordering_req.to_vec(), + ordering_types, + ) + .with_ignore_nulls(ignore_nulls), + ), (AggregateFunction::NthValue, _) => { let expr = &input_phy_exprs[0]; let Some(n) = input_phy_exprs[1] diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs index 17dd3ef1206d..6d6e32a14987 100644 --- a/datafusion/physical-expr/src/aggregate/first_last.rs +++ b/datafusion/physical-expr/src/aggregate/first_last.rs @@ -393,6 +393,7 @@ pub struct LastValue { expr: Arc, ordering_req: LexOrdering, requirement_satisfied: bool, + ignore_nulls: bool, } impl LastValue { @@ -412,9 +413,15 @@ impl LastValue { expr, ordering_req, requirement_satisfied, + ignore_nulls: false, } } + pub fn with_ignore_nulls(mut self, ignore_nulls: bool) -> Self { + self.ignore_nulls = ignore_nulls; + self + } + /// Returns the name of the aggregate expression. pub fn name(&self) -> &str { &self.name @@ -483,6 +490,7 @@ impl AggregateExpr for LastValue { &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), + self.ignore_nulls, ) .map(|acc| { Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ @@ -528,6 +536,7 @@ impl AggregateExpr for LastValue { &self.input_data_type, &self.order_by_data_types, self.ordering_req.clone(), + self.ignore_nulls, ) .map(|acc| { Box::new(acc.with_requirement_satisfied(self.requirement_satisfied)) as _ @@ -561,6 +570,8 @@ struct LastValueAccumulator { ordering_req: LexOrdering, // Stores whether incoming data already satisfies the ordering requirement. requirement_satisfied: bool, + // Ignore null values. + ignore_nulls: bool, } impl LastValueAccumulator { @@ -569,6 +580,7 @@ impl LastValueAccumulator { data_type: &DataType, ordering_dtypes: &[DataType], ordering_req: LexOrdering, + ignore_nulls: bool, ) -> Result { let orderings = ordering_dtypes .iter() @@ -581,6 +593,7 @@ impl LastValueAccumulator { orderings, ordering_req, requirement_satisfied, + ignore_nulls, }) } @@ -597,7 +610,17 @@ impl LastValueAccumulator { }; if self.requirement_satisfied { // Get last entry according to the order of data: - return Ok((!value.is_empty()).then_some(value.len() - 1)); + if self.ignore_nulls { + // If ignoring nulls, find the last non-null value. 
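+                    // e.g. for stored values [3, 4, NULL], the reverse scan
+                    // skips the trailing NULL and returns index 1 (the last
+                    // non-null value, 4)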
+ for i in (0..value.len()).rev() { + if !value.is_null(i) { + return Ok(Some(i)); + } + } + return Ok(None); + } else { + return Ok((!value.is_empty()).then_some(value.len() - 1)); + } } let sort_columns = ordering_values .iter() @@ -611,8 +634,20 @@ impl LastValueAccumulator { } }) .collect::>(); - let indices = lexsort_to_indices(&sort_columns, Some(1))?; - Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + + if self.ignore_nulls { + let indices = lexsort_to_indices(&sort_columns, None)?; + // If ignoring nulls, find the last non-null value. + for index in indices.iter().flatten() { + if !value.is_null(index as usize) { + return Ok(Some(index as usize)); + } + } + Ok(None) + } else { + let indices = lexsort_to_indices(&sort_columns, Some(1))?; + Ok((!indices.is_empty()).then_some(indices.value(0) as _)) + } } fn with_requirement_satisfied(mut self, requirement_satisfied: bool) -> Self { @@ -746,7 +781,7 @@ mod tests { let mut first_accumulator = FirstValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; // first value in the tuple is start of the range (inclusive), // second value in the tuple is end of the range (exclusive) let ranges: Vec<(i64, i64)> = vec![(0, 10), (1, 11), (2, 13)]; @@ -814,13 +849,13 @@ mod tests { // LastValueAccumulator let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; last_accumulator.update_batch(&[arrs[0].clone()])?; let state1 = last_accumulator.state()?; let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; last_accumulator.update_batch(&[arrs[1].clone()])?; let state2 = last_accumulator.state()?; @@ -836,7 +871,7 @@ mod tests { } let mut last_accumulator = - LastValueAccumulator::try_new(&DataType::Int64, &[], vec![])?; + LastValueAccumulator::try_new(&DataType::Int64, &[], vec![], false)?; last_accumulator.merge_batch(&states)?; let merged_state = last_accumulator.state()?; diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 19bcf6024b50..4929ab485d6d 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3376,3 +3376,55 @@ SELECT FIRST_VALUE(column1 ORDER BY column2) IGNORE NULLS FROM t; statement ok DROP TABLE t; + +# Test for ignore null in LAST_VALUE +statement ok +CREATE TABLE t AS VALUES (3), (4), (null::bigint); + +query I +SELECT LAST_VALUE(column1) FROM t; +---- +NULL + +query I +SELECT LAST_VALUE(column1) RESPECT NULLS FROM t; +---- +NULL + +query I +SELECT LAST_VALUE(column1) IGNORE NULLS FROM t; +---- +4 + +statement ok +DROP TABLE t; + +# Test for ignore null with ORDER BY in LAST_VALUE +statement ok +CREATE TABLE t AS VALUES (3, 3), (4, 4), (null::bigint, 1), (null::bigint, 2); + +query I +SELECT column1 FROM t ORDER BY column2 DESC; +---- +4 +3 +NULL +NULL + +query I +SELECT LAST_VALUE(column1 ORDER BY column2 DESC) FROM t; +---- +NULL + +query I +SELECT LAST_VALUE(column1 ORDER BY column2 DESC) RESPECT NULLS FROM t; +---- +NULL + +query I +SELECT LAST_VALUE(column1 ORDER BY column2 DESC) IGNORE NULLS FROM t; +---- +3 + +statement ok +DROP TABLE t; From 56c735c458c6d6dd7696941457dd4bbe95eaa2e0 Mon Sep 17 
00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:24:37 +0300 Subject: [PATCH 073/117] [MINOR]: Move some repetitive codes to functions (#9810) * Minor changes * Accept both owned and reference --- datafusion/core/src/datasource/memory.rs | 17 +++---- datafusion/core/src/physical_planner.rs | 63 ++++++++++-------------- datafusion/physical-expr/src/lib.rs | 2 +- datafusion/physical-expr/src/planner.rs | 48 +++++++++--------- 4 files changed, 60 insertions(+), 70 deletions(-) diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 3c76ee635855..608a46144da3 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -33,7 +33,7 @@ use crate::physical_plan::{ common, DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, Partitioning, SendableRecordBatchStream, }; -use crate::physical_planner::create_physical_sort_expr; +use crate::physical_planner::create_physical_sort_exprs; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; @@ -231,16 +231,11 @@ impl TableProvider for MemTable { let file_sort_order = sort_order .iter() .map(|sort_exprs| { - sort_exprs - .iter() - .map(|expr| { - create_physical_sort_expr( - expr, - &df_schema, - state.execution_props(), - ) - }) - .collect::>>() + create_physical_sort_exprs( + sort_exprs, + &df_schema, + state.execution_props(), + ) }) .collect::>>()?; exec = exec.with_sort_information(file_sort_order); diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index ca708b05823e..deac3dcf46fc 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -43,7 +43,7 @@ use crate::logical_expr::{ Repartition, Union, UserDefinedLogicalNode, }; use crate::logical_expr::{Limit, Values}; -use crate::physical_expr::create_physical_expr; +use crate::physical_expr::{create_physical_expr, create_physical_exprs}; use crate::physical_optimizer::optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::analyze::AnalyzeExec; @@ -96,6 +96,7 @@ use datafusion_sql::utils::window_expr_common_partition_keys; use async_trait::async_trait; use datafusion_common::config::FormatOptions; +use datafusion_physical_expr::LexOrdering; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryStreamExt}; use itertools::{multiunzip, Itertools}; @@ -958,14 +959,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::Sort(Sort { expr, input, fetch, .. 
}) => { let physical_input = self.create_initial_plan(input, session_state).await?; let input_dfschema = input.as_ref().schema(); - let sort_expr = expr - .iter() - .map(|e| create_physical_sort_expr( - e, - input_dfschema, - session_state.execution_props(), - )) - .collect::>>()?; + let sort_expr = create_physical_sort_exprs(expr, input_dfschema, session_state.execution_props())?; let new_sort = SortExec::new(sort_expr, physical_input) .with_fetch(*fetch); Ok(Arc::new(new_sort)) @@ -1592,18 +1586,11 @@ pub fn create_window_expr_with_name( window_frame, null_treatment, }) => { - let args = args - .iter() - .map(|e| create_physical_expr(e, logical_schema, execution_props)) - .collect::>>()?; - let partition_by = partition_by - .iter() - .map(|e| create_physical_expr(e, logical_schema, execution_props)) - .collect::>>()?; - let order_by = order_by - .iter() - .map(|e| create_physical_sort_expr(e, logical_schema, execution_props)) - .collect::>>()?; + let args = create_physical_exprs(args, logical_schema, execution_props)?; + let partition_by = + create_physical_exprs(partition_by, logical_schema, execution_props)?; + let order_by = + create_physical_sort_exprs(order_by, logical_schema, execution_props)?; if !is_window_frame_bound_valid(window_frame) { return plan_err!( @@ -1670,10 +1657,8 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( order_by, null_treatment, }) => { - let args = args - .iter() - .map(|e| create_physical_expr(e, logical_input_schema, execution_props)) - .collect::>>()?; + let args = + create_physical_exprs(args, logical_input_schema, execution_props)?; let filter = match filter { Some(e) => Some(create_physical_expr( e, @@ -1683,17 +1668,11 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( None => None, }; let order_by = match order_by { - Some(e) => Some( - e.iter() - .map(|expr| { - create_physical_sort_expr( - expr, - logical_input_schema, - execution_props, - ) - }) - .collect::>>()?, - ), + Some(e) => Some(create_physical_sort_exprs( + e, + logical_input_schema, + execution_props, + )?), None => None, }; let ignore_nulls = null_treatment @@ -1780,6 +1759,18 @@ pub fn create_physical_sort_expr( } } +/// Create vector of physical sort expression from a vector of logical expression +pub fn create_physical_sort_exprs( + exprs: &[Expr], + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result { + exprs + .iter() + .map(|expr| create_physical_sort_expr(expr, input_dfschema, execution_props)) + .collect::>>() +} + impl DefaultPhysicalPlanner { /// Handles capturing the various plans for EXPLAIN queries /// diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 1791a6ed60b2..1dead099540b 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -53,7 +53,7 @@ pub use physical_expr::{ physical_exprs_bag_equal, physical_exprs_contains, physical_exprs_equal, PhysicalExpr, PhysicalExprRef, }; -pub use planner::create_physical_expr; +pub use planner::{create_physical_expr, create_physical_exprs}; pub use scalar_function::ScalarFunctionExpr; pub use sort_expr::{ LexOrdering, LexOrderingRef, LexRequirement, LexRequirementRef, PhysicalSortExpr, diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 241f01a4170a..319d9ca2269a 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -168,20 +168,15 @@ pub fn create_physical_expr( } else { None }; - let when_expr = case + let 
(when_expr, then_expr): (Vec<&Expr>, Vec<&Expr>) = case .when_then_expr .iter() - .map(|(w, _)| { - create_physical_expr(w.as_ref(), input_dfschema, execution_props) - }) - .collect::>>()?; - let then_expr = case - .when_then_expr - .iter() - .map(|(_, t)| { - create_physical_expr(t.as_ref(), input_dfschema, execution_props) - }) - .collect::>>()?; + .map(|(w, t)| (w.as_ref(), t.as_ref())) + .unzip(); + let when_expr = + create_physical_exprs(when_expr, input_dfschema, execution_props)?; + let then_expr = + create_physical_exprs(then_expr, input_dfschema, execution_props)?; let when_then_expr: Vec<(Arc, Arc)> = when_expr .iter() @@ -248,10 +243,8 @@ pub fn create_physical_expr( } Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - let physical_args = args - .iter() - .map(|e| create_physical_expr(e, input_dfschema, execution_props)) - .collect::>>()?; + let physical_args = + create_physical_exprs(args, input_dfschema, execution_props)?; match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -310,12 +303,8 @@ pub fn create_physical_expr( let value_expr = create_physical_expr(expr, input_dfschema, execution_props)?; - let list_exprs = list - .iter() - .map(|expr| { - create_physical_expr(expr, input_dfschema, execution_props) - }) - .collect::>>()?; + let list_exprs = + create_physical_exprs(list, input_dfschema, execution_props)?; expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, @@ -325,6 +314,21 @@ pub fn create_physical_expr( } } +/// Create vector of Physical Expression from a vector of logical expression +pub fn create_physical_exprs<'a, I>( + exprs: I, + input_dfschema: &DFSchema, + execution_props: &ExecutionProps, +) -> Result>> +where + I: IntoIterator, +{ + exprs + .into_iter() + .map(|expr| create_physical_expr(expr, input_dfschema, execution_props)) + .collect::>>() +} + #[cfg(test)] mod tests { use super::*; From 75caa9ceb38418e3b222be0a9189c878135f978a Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Wed, 27 Mar 2024 21:18:29 +0800 Subject: [PATCH 074/117] fix: ensure mutual compatibility of the two input schemas from recursive CTEs (#9795) * fix: Ensure mutual compatibility of the two input schemas from recursive CTEs * fix typo --- datafusion/expr/src/logical_plan/builder.rs | 26 +++++++++++++----- datafusion/sqllogictest/test_files/cte.slt | 30 +++++++++++++++++++++ 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 01e6af948762..f47249d76d5b 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -51,9 +51,9 @@ use arrow::datatypes::{DataType, Schema, SchemaRef}; use datafusion_common::config::FormatOptions; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::{ - get_target_functional_dependencies, plan_datafusion_err, plan_err, Column, DFField, - DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, ScalarValue, - TableReference, ToDFSchema, UnnestOptions, + get_target_functional_dependencies, not_impl_err, plan_datafusion_err, plan_err, + Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result, + ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; /// Default table name for unnamed table @@ -132,14 +132,26 @@ impl LogicalPlanBuilder { ) -> Result { // TODO: we need to do a bunch of validation here. Maybe more. 
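        // (For instance, the two UNION branches of a recursive CTE must agree
        // on column count, and their field types must be mutually coercible;
        // the checks added below cover exactly these two cases.)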
        if is_distinct {
-            return Err(DataFusionError::NotImplemented(
-                "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported".to_string(),
-            ));
+            return not_impl_err!(
+                "Recursive queries with a distinct 'UNION' (in which the previous iteration's results will be de-duplicated) is not supported"
+            );
+        }
+        // Ensure that the static term and the recursive term have the same number of fields
+        let static_fields_len = self.plan.schema().fields().len();
+        let recursive_fields_len = recursive_term.schema().fields().len();
+        if static_fields_len != recursive_fields_len {
+            return plan_err!(
+                "Non-recursive term and recursive term must have the same number of columns ({} != {})",
+                static_fields_len, recursive_fields_len
+            );
         }
+        // Ensure that the recursive term has the same field types as the static term
+        let coerced_recursive_term =
+            coerce_plan_expr_for_schema(&recursive_term, self.plan.schema())?;
         Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery {
             name,
             static_term: Arc::new(self.plan.clone()),
-            recursive_term: Arc::new(recursive_term),
+            recursive_term: Arc::new(coerced_recursive_term),
             is_distinct,
         })))
     }
diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt
index 50c88e41959f..e33dfabaf2ca 100644
--- a/datafusion/sqllogictest/test_files/cte.slt
+++ b/datafusion/sqllogictest/test_files/cte.slt
@@ -714,3 +714,33 @@ RecursiveQueryExec: name=recursive_cte, is_distinct=false
 --------------WorkTableExec: name=recursive_cte
 ------ProjectionExec: expr=[2 as val]
 --------PlaceholderRowExec
+
+# Test issue: https://github.com/apache/arrow-datafusion/issues/9794
+# Non-recursive term and recursive term have different types
+query IT
+WITH RECURSIVE my_cte AS(
+    SELECT 1::int AS a
+    UNION ALL
+    SELECT a::bigint+2 FROM my_cte WHERE a<3
+) SELECT *, arrow_typeof(a) FROM my_cte;
+----
+1 Int32
+3 Int32
+
+# Test issue: https://github.com/apache/arrow-datafusion/issues/9794
+# Non-recursive term and recursive term have different number of columns
+query error DataFusion error: Error during planning: Non\-recursive term and recursive term must have the same number of columns \(1 != 3\)
+WITH RECURSIVE my_cte AS (
+    SELECT 1::bigint AS a
+    UNION ALL
+    SELECT a+2, 'a','c' FROM my_cte WHERE a<3
+) SELECT * FROM my_cte;
+
+# Test issue: https://github.com/apache/arrow-datafusion/issues/9794
+# Non-recursive term and recursive term have different types, and cannot be cast
+query error DataFusion error: Arrow error: Cast error: Cannot cast string 'abc' to value of Int64 type
+WITH RECURSIVE my_cte AS (
+    SELECT 1 AS a
+    UNION ALL
+    SELECT 'abc' FROM my_cte WHERE CAST(a AS text) !='abc'
+) SELECT * FROM my_cte;

From 1b6ae8fcdad8324c30e654eaf47eb1ae9ddcd964 Mon Sep 17 00:00:00 2001
From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com>
Date: Wed, 27 Mar 2024 17:37:02 +0300
Subject: [PATCH 075/117] Add support for constant expression evaluation in limit (#9790)

* Add support for constant expression evaluation in limit

* Use existing const evaluator

* Revert "Use existing const evaluator"

This reverts commit 99b7a552d6ebf03997afd7d06f48da1e2adc4d94.
* Update datafusion/sql/src/query.rs Co-authored-by: Andrew Lamb * Add negative tests --------- Co-authored-by: Andrew Lamb --- datafusion/sql/src/query.rs | 83 ++++++++++++++----- datafusion/sqllogictest/test_files/select.slt | 24 +++++- 2 files changed, 83 insertions(+), 24 deletions(-) diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index ea8edd0771c8..eda8398c432b 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -25,6 +25,7 @@ use datafusion_common::{ }; use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, + Operator, }; use sqlparser::ast::{ Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, @@ -221,37 +222,29 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let skip = match skip { - Some(skip_expr) => match self.sql_to_expr( - skip_expr.value, - input.schema(), - &mut PlannerContext::new(), - )? { - Expr::Literal(ScalarValue::Int64(Some(s))) => { - if s < 0 { - return plan_err!("Offset must be >= 0, '{s}' was provided."); - } - Ok(s as usize) - } - _ => plan_err!("Unexpected expression in OFFSET clause"), - }?, - _ => 0, - }; + Some(skip_expr) => { + let expr = self.sql_to_expr( + skip_expr.value, + input.schema(), + &mut PlannerContext::new(), + )?; + let n = get_constant_result(&expr, "OFFSET")?; + convert_usize_with_check(n, "OFFSET") + } + _ => Ok(0), + }?; let fetch = match fetch { Some(limit_expr) if limit_expr != sqlparser::ast::Expr::Value(Value::Null) => { - let n = match self.sql_to_expr( + let expr = self.sql_to_expr( limit_expr, input.schema(), &mut PlannerContext::new(), - )? { - Expr::Literal(ScalarValue::Int64(Some(n))) if n >= 0 => { - Ok(n as usize) - } - _ => plan_err!("LIMIT must not be negative"), - }?; - Some(n) + )?; + let n = get_constant_result(&expr, "LIMIT")?; + Some(convert_usize_with_check(n, "LIMIT")?) } _ => None, }; @@ -283,3 +276,47 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } } + +/// Retrieves the constant result of an expression, evaluating it if possible. +/// +/// This function takes an expression and an argument name as input and returns +/// a `Result` indicating either the constant result of the expression or an +/// error if the expression cannot be evaluated. +/// +/// # Arguments +/// +/// * `expr` - An `Expr` representing the expression to evaluate. +/// * `arg_name` - The name of the argument for error messages. +/// +/// # Returns +/// +/// * `Result` - An `Ok` variant containing the constant result if evaluation is successful, +/// or an `Err` variant containing an error message if evaluation fails. +/// +/// tracks a more general solution +fn get_constant_result(expr: &Expr, arg_name: &str) -> Result { + match expr { + Expr::Literal(ScalarValue::Int64(Some(s))) => Ok(*s), + Expr::BinaryExpr(binary_expr) => { + let lhs = get_constant_result(&binary_expr.left, arg_name)?; + let rhs = get_constant_result(&binary_expr.right, arg_name)?; + let res = match binary_expr.op { + Operator::Plus => lhs + rhs, + Operator::Minus => lhs - rhs, + Operator::Multiply => lhs * rhs, + _ => return plan_err!("Unsupported operator for {arg_name} clause"), + }; + Ok(res) + } + _ => plan_err!("Unexpected expression in {arg_name} clause"), + } +} + +/// Converts an `i64` to `usize`, performing a boundary check. 
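+/// Negative values produce a plan error, since `LIMIT` and `OFFSET` only
+/// accept non-negative row counts.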
+fn convert_usize_with_check(n: i64, arg_name: &str) -> Result { + if n < 0 { + plan_err!("{arg_name} must be >= 0, '{n}' was provided.") + } else { + Ok(n as usize) + } +} diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 3d3e73e81637..3a5c6497ebd4 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -550,9 +550,31 @@ select * from (select 1 a union all select 2) b order by a limit 1; 1 # select limit clause invalid -statement error DataFusion error: Error during planning: LIMIT must not be negative +statement error DataFusion error: Error during planning: LIMIT must be >= 0, '\-1' was provided\. select * from (select 1 a union all select 2) b order by a limit -1; +# select limit with basic arithmetic +query I +select * from (select 1 a union all select 2) b order by a limit 1+1; +---- +1 +2 + +# select limit with basic arithmetic +query I +select * from (values (1)) LIMIT 10*100; +---- +1 + +# More complex expressions in the limit is not supported yet. +# See issue: https://github.com/apache/arrow-datafusion/issues/9821 +statement error DataFusion error: Error during planning: Unsupported operator for LIMIT clause +select * from (values (1)) LIMIT 100/10; + +# More complex expressions in the limit is not supported yet. +statement error DataFusion error: Error during planning: Unexpected expression in LIMIT clause +select * from (values (1)) LIMIT cast(column1 as tinyint); + # select limit clause query I select * from (select 1 a union all select 2) b order by a limit null; From ba8f1af25d2c7d81d1b06e838cf5895e3da1dbb9 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Wed, 27 Mar 2024 17:49:16 +0300 Subject: [PATCH 076/117] Projection Pushdown through user defined LogicalPlan nodes. (#9690) * Naive support for schema preserving plans * Add mapping support between schemas * Fix name * Update comment * Update comment * Do not calculate mapping for unnecessary sections * Update datafusion/optimizer/src/optimize_projections.rs Co-authored-by: Andrew Lamb * Add new tests * Add new api to get necessary columns * Add new test for multi children * Address reviews --------- Co-authored-by: Andrew Lamb --- datafusion/expr/src/logical_plan/extension.rs | 43 +++ .../optimizer/src/optimize_projections.rs | 289 +++++++++++++++++- 2 files changed, 325 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/logical_plan/extension.rs b/datafusion/expr/src/logical_plan/extension.rs index bb2c932ce391..b55256ca17de 100644 --- a/datafusion/expr/src/logical_plan/extension.rs +++ b/datafusion/expr/src/logical_plan/extension.rs @@ -98,6 +98,24 @@ pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync { inputs: &[LogicalPlan], ) -> Arc; + /// Returns the necessary input columns for this node required to compute + /// the columns in the output schema + /// + /// This is used for projection push-down when DataFusion has determined that + /// only a subset of the output columns of this node are needed by its parents. + /// This API is used to tell DataFusion which, if any, of the input columns are no longer + /// needed. + /// + /// Return `None`, the default, if this information can not be determined. 
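+    /// (For instance, a pass-through node whose output schema equals its
+    /// input schema can return `Some(vec![output_columns.to_vec()])` for its
+    /// single child.)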
+ /// Returns `Some(_)` with the column indices for each child of this node that are + /// needed to compute `output_columns` + fn necessary_children_exprs( + &self, + _output_columns: &[usize], + ) -> Option>> { + None + } + /// Update the hash `state` with this node requirements from /// [`Hash`]. /// @@ -243,6 +261,24 @@ pub trait UserDefinedLogicalNodeCore: // but the doc comments have not been updated. #[allow(clippy::wrong_self_convention)] fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self; + + /// Returns the necessary input columns for this node required to compute + /// the columns in the output schema + /// + /// This is used for projection push-down when DataFusion has determined that + /// only a subset of the output columns of this node are needed by its parents. + /// This API is used to tell DataFusion which, if any, of the input columns are no longer + /// needed. + /// + /// Return `None`, the default, if this information can not be determined. + /// Returns `Some(_)` with the column indices for each child of this node that are + /// needed to compute `output_columns` + fn necessary_children_exprs( + &self, + _output_columns: &[usize], + ) -> Option>> { + None + } } /// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode` @@ -284,6 +320,13 @@ impl UserDefinedLogicalNode for T { Arc::new(self.from_template(exprs, inputs)) } + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + self.necessary_children_exprs(output_columns) + } + fn dyn_hash(&self, state: &mut dyn Hasher) { let mut s = state; self.hash(&mut s); diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 08ee38f64abd..b942f187c331 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -31,7 +31,8 @@ use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::SchemaRef; use datafusion_common::{ - get_required_group_by_exprs_indices, Column, DFSchema, DFSchemaRef, JoinType, Result, + get_required_group_by_exprs_indices, internal_err, Column, DFSchema, DFSchemaRef, + JoinType, Result, }; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::{ @@ -162,14 +163,40 @@ fn optimize_projections( .map(|input| ((0..input.schema().fields().len()).collect_vec(), false)) .collect::>() } + LogicalPlan::Extension(extension) => { + let necessary_children_indices = if let Some(necessary_children_indices) = + extension.node.necessary_children_exprs(indices) + { + necessary_children_indices + } else { + // Requirements from parent cannot be routed down to user defined logical plan safely + return Ok(None); + }; + let children = extension.node.inputs(); + if children.len() != necessary_children_indices.len() { + return internal_err!("Inconsistent length between children and necessary children indices. \ + Make sure `.necessary_children_exprs` implementation of the `UserDefinedLogicalNode` is \ + consistent with actual children length for the node."); + } + // Expressions used by node. + let exprs = plan.expressions(); + children + .into_iter() + .zip(necessary_children_indices) + .map(|(child, necessary_indices)| { + let child_schema = child.schema(); + let child_req_indices = + indices_referred_by_exprs(child_schema, exprs.iter())?; + Ok((merge_slices(&necessary_indices, &child_req_indices), false)) + }) + .collect::>>()? 
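+                // Each child must provide the columns the node routed down
+                // for the parent's requirements, merged with the columns
+                // referred to by the node's own expressions.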
+ } LogicalPlan::EmptyRelation(_) | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) - | LogicalPlan::Extension(_) | LogicalPlan::DescribeTable(_) => { // These operators have no inputs, so stop the optimization process. - // TODO: Add support for `LogicalPlan::Extension`. return Ok(None); } LogicalPlan::Projection(proj) => { @@ -899,21 +926,161 @@ fn is_projection_unnecessary(input: &LogicalPlan, proj_exprs: &[Expr]) -> Result #[cfg(test)] mod tests { + use std::fmt::Formatter; use std::sync::Arc; use crate::optimize_projections::OptimizeProjections; - use crate::test::{assert_optimized_plan_eq, test_table_scan}; + use crate::test::{ + assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name, + }; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Result, TableReference}; + use datafusion_common::{Column, DFSchemaRef, JoinType, Result, TableReference}; use datafusion_expr::{ - binary_expr, col, count, lit, logical_plan::builder::LogicalPlanBuilder, not, - table_scan, try_cast, when, Expr, Like, LogicalPlan, Operator, + binary_expr, build_join_schema, col, count, lit, + logical_plan::builder::LogicalPlanBuilder, not, table_scan, try_cast, when, + BinaryExpr, Expr, Extension, Like, LogicalPlan, Operator, + UserDefinedLogicalNodeCore, }; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } + #[derive(Debug, Hash, PartialEq, Eq)] + struct NoOpUserDefined { + exprs: Vec, + schema: DFSchemaRef, + input: Arc, + } + + impl NoOpUserDefined { + fn new(schema: DFSchemaRef, input: Arc) -> Self { + Self { + exprs: vec![], + schema, + input, + } + } + + fn with_exprs(mut self, exprs: Vec) -> Self { + self.exprs = exprs; + self + } + } + + impl UserDefinedLogicalNodeCore for NoOpUserDefined { + fn name(&self) -> &str { + "NoOpUserDefined" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.exprs.clone() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "NoOpUserDefined") + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + Self { + exprs: exprs.to_vec(), + input: Arc::new(inputs[0].clone()), + schema: self.schema.clone(), + } + } + + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + // Since schema is same. Output columns requires their corresponding version in the input columns. 
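+            // i.e. requested output column `i` maps directly to input column `i`.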
+ Some(vec![output_columns.to_vec()]) + } + } + + #[derive(Debug, Hash, PartialEq, Eq)] + struct UserDefinedCrossJoin { + exprs: Vec, + schema: DFSchemaRef, + left_child: Arc, + right_child: Arc, + } + + impl UserDefinedCrossJoin { + fn new(left_child: Arc, right_child: Arc) -> Self { + let left_schema = left_child.schema(); + let right_schema = right_child.schema(); + let schema = Arc::new( + build_join_schema(left_schema, right_schema, &JoinType::Inner).unwrap(), + ); + Self { + exprs: vec![], + schema, + left_child, + right_child, + } + } + } + + impl UserDefinedLogicalNodeCore for UserDefinedCrossJoin { + fn name(&self) -> &str { + "UserDefinedCrossJoin" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.left_child, &self.right_child] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + self.exprs.clone() + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "UserDefinedCrossJoin") + } + + fn from_template(&self, exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + assert_eq!(inputs.len(), 2); + Self { + exprs: exprs.to_vec(), + left_child: Arc::new(inputs[0].clone()), + right_child: Arc::new(inputs[1].clone()), + schema: self.schema.clone(), + } + } + + fn necessary_children_exprs( + &self, + output_columns: &[usize], + ) -> Option>> { + let left_child_len = self.left_child.schema().fields().len(); + let mut left_reqs = vec![]; + let mut right_reqs = vec![]; + for &out_idx in output_columns { + if out_idx < left_child_len { + left_reqs.push(out_idx); + } else { + // Output indices further than the left_child_len + // comes from right children + right_reqs.push(out_idx - left_child_len) + } + } + Some(vec![left_reqs, right_reqs]) + } + } + #[test] fn merge_two_projection() -> Result<()> { let table_scan = test_table_scan()?; @@ -1192,4 +1359,112 @@ mod tests { \n TableScan: test projection=[a]"; assert_optimized_plan_equal(&plan, expected) } + + // Since only column `a` is referred at the output. Scan should only contain projection=[a]. + // User defined node should be able to propagate necessary expressions by its parent to its child. + #[test] + fn test_user_defined_logical_plan_node() -> Result<()> { + let table_scan = test_table_scan()?; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new(NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + )), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Only column `a` is referred at the output. However, User defined node itself uses column `b` + // during its operation. Hence, scan should contain projection=[a, b]. + // User defined node should be able to propagate necessary expressions by its parent, as well as its own + // required expressions. + #[test] + fn test_user_defined_logical_plan_node2() -> Result<()> { + let table_scan = test_table_scan()?; + let exprs = vec![Expr::Column(Column::from_qualified_name("b"))]; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new( + NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + ) + .with_exprs(exprs), + ), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? 
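+            // Only `a` survives the projection, but the node's own expression
+            // references `b`, so the scan below must keep projection=[a, b].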
+ .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a, b]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c` + // during its operation. Hence, scan should contain projection=[a, b, c]. + // User defined node should be able to propagate necessary expressions by its parent, as well as its own + // required expressions. Expressions doesn't have to be just column. Requirements from complex expressions + // should be propagated also. + #[test] + fn test_user_defined_logical_plan_node3() -> Result<()> { + let table_scan = test_table_scan()?; + let left_expr = Expr::Column(Column::from_qualified_name("b")); + let right_expr = Expr::Column(Column::from_qualified_name("c")); + let binary_expr = Expr::BinaryExpr(BinaryExpr::new( + Box::new(left_expr), + Operator::Plus, + Box::new(right_expr), + )); + let exprs = vec![binary_expr]; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new( + NoOpUserDefined::new( + table_scan.schema().clone(), + Arc::new(table_scan.clone()), + ) + .with_exprs(exprs), + ), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: test.a, Int32(0) AS d\ + \n NoOpUserDefined\ + \n TableScan: test projection=[a, b, c]"; + assert_optimized_plan_equal(&plan, expected) + } + + // Columns `l.a`, `l.c`, `r.a` is referred at the output. + // User defined node should be able to propagate necessary expressions by its parent, to its children. + // Even if it has multiple children. + // left child should have `projection=[a, c]`, and right side should have `projection=[a]`. + #[test] + fn test_user_defined_logical_plan_node4() -> Result<()> { + let left_table = test_table_scan_with_name("l")?; + let right_table = test_table_scan_with_name("r")?; + let custom_plan = LogicalPlan::Extension(Extension { + node: Arc::new(UserDefinedCrossJoin::new( + Arc::new(left_table.clone()), + Arc::new(right_table.clone()), + )), + }); + let plan = LogicalPlanBuilder::from(custom_plan) + .project(vec![col("l.a"), col("l.c"), col("r.a"), lit(0).alias("d")])? + .build()?; + + let expected = "Projection: l.a, l.c, r.a, Int32(0) AS d\ + \n UserDefinedCrossJoin\ + \n TableScan: l projection=[a, c]\ + \n TableScan: r projection=[a]"; + assert_optimized_plan_equal(&plan, expected) + } } From 3dfbc97d9a4eb58b081dd64d78d4899b22decf32 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 27 Mar 2024 12:10:47 -0400 Subject: [PATCH 077/117] chore(deps): update substrait requirement from 0.27.0 to 0.28.0 (#9809) * chore(deps): update substrait requirement from 0.27.0 to 0.28.0 Updates the requirements on [substrait](https://github.com/substrait-io/substrait-rs) to permit the latest version. - [Release notes](https://github.com/substrait-io/substrait-rs/releases) - [Changelog](https://github.com/substrait-io/substrait-rs/blob/main/CHANGELOG.md) - [Commits](https://github.com/substrait-io/substrait-rs/compare/v0.27.0...v0.28.0) --- updated-dependencies: - dependency-name: substrait dependency-type: direct:production ... 
Signed-off-by: dependabot[bot] * update dataufsion-cli dependencies --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 112 ++++++++++++++++---------------- datafusion/substrait/Cargo.toml | 2 +- 2 files changed, 57 insertions(+), 57 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 2f1d95d639d4..b5535a47e9c1 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -39,9 +39,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] @@ -272,7 +272,7 @@ dependencies = [ "arrow-schema", "chrono", "half", - "indexmap 2.2.5", + "indexmap 2.2.6", "lexical-core", "num", "serde", @@ -381,13 +381,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.78" +version = "0.1.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85" +checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -412,9 +412,9 @@ dependencies = [ [[package]] name = "autocfg" -version = "1.1.0" +version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" [[package]] name = "aws-config" @@ -708,9 +708,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.69" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", "cc", @@ -832,9 +832,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "bytes-utils" @@ -885,9 +885,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.35" +version = "0.4.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf5903dcbc0a39312feb77df2ff4c76387d591b9fc7b04a238dcf8bb62639a" +checksum = "8a0d04d43504c61aa6c7531f1871dd0d418d91130162063b789da00fd7057a5e" dependencies = [ "android-tzdata", "iana-time-zone", @@ -1092,7 +1092,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad291aa74992b9b7a7e88c38acbbf6ad7e107f1d90ee8775b7bc1fc3394f485c" dependencies = [ "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -1145,7 +1145,7 @@ dependencies = [ "glob", "half", "hashbrown 0.14.3", - "indexmap 2.2.5", + "indexmap 2.2.6", "itertools", "log", "num-traits", @@ -1331,7 +1331,7 @@ dependencies = [ "half", "hashbrown 0.14.3", "hex", - "indexmap 2.2.5", + "indexmap 
2.2.6", "itertools", "log", "md-5", @@ -1362,7 +1362,7 @@ dependencies = [ "futures", "half", "hashbrown 0.14.3", - "indexmap 2.2.5", + "indexmap 2.2.6", "itertools", "log", "once_cell", @@ -1530,9 +1530,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.0.1" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" [[package]] name = "fd-lock" @@ -1651,7 +1651,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -1735,7 +1735,7 @@ dependencies = [ "futures-sink", "futures-util", "http", - "indexmap 2.2.5", + "indexmap 2.2.6", "slab", "tokio", "tokio-util", @@ -1952,9 +1952,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.5" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", "hashbrown 0.14.3", @@ -1995,9 +1995,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "jobserver" @@ -2514,7 +2514,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 2.2.5", + "indexmap 2.2.6", ] [[package]] @@ -2572,7 +2572,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -2755,9 +2755,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", @@ -2784,15 +2784,15 @@ checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" [[package]] name = "reqwest" -version = "0.11.26" +version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ "base64 0.21.7", "bytes", @@ -2910,9 +2910,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.31" +version = "0.38.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" dependencies = [ "bitflags 2.5.0", "errno", @@ -2978,9 
+2978,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8" +checksum = "868e20fada228fefaf6b652e00cc73623d54f8171e7352c18bb281571f2d92da" [[package]] name = "rustls-webpki" @@ -3113,14 +3113,14 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] name = "serde_json" -version = "1.0.114" +version = "1.0.115" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" +checksum = "12dc5c46daa8e9fdf4f5e71b6cf9a53f2487da0e86e55808e2d35539666497dd" dependencies = [ "itoa", "ryu", @@ -3176,9 +3176,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "snafu" @@ -3248,7 +3248,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -3294,7 +3294,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -3307,7 +3307,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -3329,9 +3329,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.53" +version = "2.0.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" +checksum = "002a1b3dbf967edfafc32655d0f377ab0bb7b994aa1d32c8cc7e9b8bf3ebb8f0" dependencies = [ "proc-macro2", "quote", @@ -3372,7 +3372,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85b77fafb263dd9d05cbeac119526425676db3784113aa9295c88498cbf8bff1" dependencies = [ "cfg-if", - "fastrand 2.0.1", + "fastrand 2.0.2", "rustix", "windows-sys 0.52.0", ] @@ -3415,7 +3415,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -3510,7 +3510,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -3607,7 +3607,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -3652,7 +3652,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", ] [[package]] @@ -3806,7 +3806,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", "wasm-bindgen-shared", ] @@ -3840,7 +3840,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.53", + "syn 2.0.55", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4098,7 +4098,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 
2.0.53", + "syn 2.0.55", ] [[package]] diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml index 7475dfc1e37b..cc79685c9429 100644 --- a/datafusion/substrait/Cargo.toml +++ b/datafusion/substrait/Cargo.toml @@ -36,7 +36,7 @@ itertools = { workspace = true } object_store = { workspace = true } prost = "0.12" prost-types = "0.12" -substrait = "0.27.0" +substrait = "0.28.0" [dev-dependencies] tokio = { workspace = true } From 7f4b338d6f7e4434529b87ca9eb273c7fb8819ec Mon Sep 17 00:00:00 2001 From: Marko Grujic Date: Wed, 27 Mar 2024 22:52:58 +0100 Subject: [PATCH 078/117] Run TPC-H SF10 during PR benchmarks (#9822) * Run TPC-H SF10 during PR benchmarks * Add memory benchmarks to the workflow Also distinguish the output file by the SF used. --- .github/workflows/pr_benchmarks.yml | 11 +++++++++-- benchmarks/bench.sh | 4 ++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr_benchmarks.yml b/.github/workflows/pr_benchmarks.yml index b7b85c9fcf14..29d001783b17 100644 --- a/.github/workflows/pr_benchmarks.yml +++ b/.github/workflows/pr_benchmarks.yml @@ -28,9 +28,10 @@ jobs: cd benchmarks mkdir data - # Setup the TPC-H data set with a scale factor of 10 + # Setup the TPC-H data sets for scale factors 1 and 10 ./bench.sh data tpch - + ./bench.sh data tpch10 + - name: Generate unique result names run: | echo "HEAD_LONG_SHA=$(git log -1 --format='%H')" >> "$GITHUB_ENV" @@ -44,6 +45,9 @@ jobs: cd benchmarks ./bench.sh run tpch + ./bench.sh run tpch_mem + ./bench.sh run tpch10 + ./bench.sh run tpch_mem10 # For some reason this step doesn't seem to propagate the env var down into the script if [ -d "results/HEAD" ]; then @@ -64,6 +68,9 @@ jobs: cd benchmarks ./bench.sh run tpch + ./bench.sh run tpch_mem + ./bench.sh run tpch10 + ./bench.sh run tpch_mem10 echo ${{ github.event.issue.number }} > pr diff --git a/benchmarks/bench.sh b/benchmarks/bench.sh index 039f4790acb0..a72400892752 100755 --- a/benchmarks/bench.sh +++ b/benchmarks/bench.sh @@ -314,7 +314,7 @@ run_tpch() { fi TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - RESULTS_FILE="${RESULTS_DIR}/tpch.json" + RESULTS_FILE="${RESULTS_DIR}/tpch_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch benchmark..." $CARGO_COMMAND --bin tpch -- benchmark datafusion --iterations 5 --path "${TPCH_DIR}" --format parquet -o ${RESULTS_FILE} @@ -329,7 +329,7 @@ run_tpch_mem() { fi TPCH_DIR="${DATA_DIR}/tpch_sf${SCALE_FACTOR}" - RESULTS_FILE="${RESULTS_DIR}/tpch_mem.json" + RESULTS_FILE="${RESULTS_DIR}/tpch_mem_sf${SCALE_FACTOR}.json" echo "RESULTS_FILE: ${RESULTS_FILE}" echo "Running tpch_mem benchmark..." # -m means in memory From 0534382b9984146a06dbab7a5a9cea3f105d11dd Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Wed, 27 Mar 2024 19:36:38 -0700 Subject: [PATCH 079/117] Expose `parser` on DFParser to enable user controlled parsing (#9729) * poc: custom parser * play with extension statement * tweak * Revert "tweak" This reverts commit e57006e089c2378ca6cdad5c628a6b7c8d90a884. * Revert "play with extension statement" This reverts commit 86588e4513fc787c01c567c40bb76cb598a55ad1. 
* style: cargo fmt * Update datafusion-examples/examples/sql_parsing.rs Co-authored-by: Andrew Lamb * Apply suggestions from code review Co-authored-by: Andrew Lamb * style: cargo cmt * refactor: less nesting in parse statement * docs: better example description --------- Co-authored-by: Andrew Lamb --- datafusion-examples/README.md | 21 +-- datafusion-examples/examples/sql_dialect.rs | 134 ++++++++++++++++++++ datafusion/sql/src/parser.rs | 2 +- 3 files changed, 146 insertions(+), 11 deletions(-) create mode 100644 datafusion-examples/examples/sql_dialect.rs diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index dbc8050555b9..7ca90463cf8c 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -42,36 +42,37 @@ cargo run --example csv_sql ## Single Process +- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) +- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) +- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) - [`avro_sql.rs`](examples/avro_sql.rs): Build and run a query plan from a SQL statement against a local AVRO file +- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog - [`csv_sql.rs`](examples/csv_sql.rs): Build and run a query plan from a SQL statement against a local CSV file - [`csv_sql_streaming.rs`](examples/csv_sql_streaming.rs): Build and run a streaming query plan from a SQL statement against a local CSV file -- [`catalog.rs`](examples/catalog.rs): Register the table into a custom catalog - [`custom_datasource.rs`](examples/custom_datasource.rs): Run queries against a custom datasource (TableProvider) -- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file - [`dataframe-to-s3.rs`](examples/external_dependency/dataframe-to-s3.rs): Run a query using a DataFrame against a parquet file from s3 and writing back to s3 -- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame +- [`dataframe.rs`](examples/dataframe.rs): Run a query using a DataFrame against a local parquet file - [`dataframe_in_memory.rs`](examples/dataframe_in_memory.rs): Run a query using a DataFrame against data in memory +- [`dataframe_output.rs`](examples/dataframe_output.rs): Examples of methods which write data out from a DataFrame - [`deserialize_to_struct.rs`](examples/deserialize_to_struct.rs): Convert query results into rust structs using serde - [`expr_api.rs`](examples/expr_api.rs): Create, execute, simplify and analyze `Expr`s - [`flight_sql_server.rs`](examples/flight/flight_sql_server.rs): Run DataFusion as a standalone process and execute SQL queries from JDBC clients - [`function_factory.rs`](examples/function_factory.rs): Register `CREATE FUNCTION` handler to implement SQL macros - [`make_date.rs`](examples/make_date.rs): Examples of using the make_date function - [`memtable.rs`](examples/memtable.rs): Create an query data in memory using SQL and `RecordBatch`es -- [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics - [`parquet_sql.rs`](examples/parquet_sql.rs): Build and run a query plan from a SQL statement against a local Parquet file - [`parquet_sql_multiple_files.rs`](examples/parquet_sql_multiple_files.rs): Build and run a query plan from a SQL statement 
against multiple local Parquet files +- [`pruning.rs`](examples/parquet_sql.rs): Use pruning to rule out files based on statistics - [`query-aws-s3.rs`](examples/external_dependency/query-aws-s3.rs): Configure `object_store` and run a query against files stored in AWS S3 - [`query-http-csv.rs`](examples/query-http-csv.rs): Configure `object_store` and run a query against files vi HTTP - [`regexp.rs`](examples/regexp.rs): Examples of using regular expression functions - [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass -- [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function -- [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions -- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) -- [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) -- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) +- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) -- [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) +- [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a custom SQL dialect on top of `DFParser` +- [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function +- [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using to_timestamp functions ## Distributed diff --git a/datafusion-examples/examples/sql_dialect.rs b/datafusion-examples/examples/sql_dialect.rs new file mode 100644 index 000000000000..259f38216b80 --- /dev/null +++ b/datafusion-examples/examples/sql_dialect.rs @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::fmt::Display; + +use datafusion::error::Result; +use datafusion_sql::{ + parser::{CopyToSource, CopyToStatement, DFParser, Statement}, + sqlparser::{keywords::Keyword, parser::ParserError, tokenizer::Token}, +}; + +/// This example demonstrates how to use the DFParser to parse a statement in a custom way +/// +/// This technique can be used to implement a custom SQL dialect, for example. 
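+///
+/// Here the wrapper recognizes a custom `COPY ... STORED AS FASTA` form and
+/// falls back to the stock DataFusion parser for every other statement.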
+#[tokio::main] +async fn main() -> Result<()> { + let mut my_parser = + MyParser::new("COPY source_table TO 'file.fasta' STORED AS FASTA")?; + + let my_statement = my_parser.parse_statement()?; + + match my_statement { + MyStatement::DFStatement(s) => println!("df: {}", s), + MyStatement::MyCopyTo(s) => println!("my_copy: {}", s), + } + + Ok(()) +} + +/// Here we define a Parser for our new SQL dialect that wraps the existing `DFParser` +struct MyParser<'a> { + df_parser: DFParser<'a>, +} + +impl MyParser<'_> { + fn new(sql: &str) -> Result { + let df_parser = DFParser::new(sql)?; + Ok(Self { df_parser }) + } + + /// Returns true if the next token is `COPY` keyword, false otherwise + fn is_copy(&self) -> bool { + matches!( + self.df_parser.parser.peek_token().token, + Token::Word(w) if w.keyword == Keyword::COPY + ) + } + + /// This is the entry point to our parser -- it handles `COPY` statements specially + /// but otherwise delegates to the existing DataFusion parser. + pub fn parse_statement(&mut self) -> Result { + if self.is_copy() { + self.df_parser.parser.next_token(); // COPY + let df_statement = self.df_parser.parse_copy()?; + + if let Statement::CopyTo(s) = df_statement { + Ok(MyStatement::from(s)) + } else { + Ok(MyStatement::DFStatement(Box::from(df_statement))) + } + } else { + let df_statement = self.df_parser.parse_statement()?; + Ok(MyStatement::from(df_statement)) + } + } +} + +enum MyStatement { + DFStatement(Box), + MyCopyTo(MyCopyToStatement), +} + +impl Display for MyStatement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MyStatement::DFStatement(s) => write!(f, "{}", s), + MyStatement::MyCopyTo(s) => write!(f, "{}", s), + } + } +} + +impl From for MyStatement { + fn from(s: Statement) -> Self { + Self::DFStatement(Box::from(s)) + } +} + +impl From for MyStatement { + fn from(s: CopyToStatement) -> Self { + if s.stored_as == Some("FASTA".to_string()) { + Self::MyCopyTo(MyCopyToStatement::from(s)) + } else { + Self::DFStatement(Box::from(Statement::CopyTo(s))) + } + } +} + +struct MyCopyToStatement { + pub source: CopyToSource, + pub target: String, +} + +impl From for MyCopyToStatement { + fn from(s: CopyToStatement) -> Self { + Self { + source: s.source, + target: s.target, + } + } +} + +impl Display for MyCopyToStatement { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "COPY {} TO '{}' STORED AS FASTA", + self.source, self.target + ) + } +} diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index a5d7970495c5..c585917a1ed0 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -278,7 +278,7 @@ fn ensure_not_set(field: &Option, name: &str) -> Result<(), ParserError> { /// `CREATE EXTERNAL TABLE` have special syntax in DataFusion. 
See /// [`Statement`] for a list of this special syntax pub struct DFParser<'a> { - parser: Parser<'a>, + pub parser: Parser<'a>, } impl<'a> DFParser<'a> { From ce3d446be5f6a11664e100fc47940e6ecb5418d3 Mon Sep 17 00:00:00 2001 From: Lordworms <48054792+Lordworms@users.noreply.github.com> Date: Thu, 28 Mar 2024 09:47:37 -0500 Subject: [PATCH 080/117] Disable parallel reading for gziped ndjson file (#9799) * for debug * disable paralle reading for gziped ndjson file * directly return None * delete .gz * fix clippy --- .../core/src/datasource/physical_plan/json.rs | 52 ++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index 194a4a91c34a..c876b3d078f3 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -150,6 +150,9 @@ impl ExecutionPlan for NdJsonExec { target_partitions: usize, config: &datafusion_common::config::ConfigOptions, ) -> Result>> { + if self.file_compression_type == FileCompressionType::GZIP { + return Ok(None); + } let repartition_file_min_size = config.optimizer.repartition_file_min_size; let preserve_order_within_groups = self.properties().output_ordering().is_some(); let file_groups = &self.base_config.file_groups; @@ -392,11 +395,14 @@ mod tests { use arrow::datatypes::{Field, SchemaBuilder}; use datafusion_common::cast::{as_int32_array, as_int64_array, as_string_array}; use datafusion_common::FileType; - + use flate2::write::GzEncoder; + use flate2::Compression; use futures::StreamExt; use object_store::chunked::ChunkedStore; use object_store::local::LocalFileSystem; use rstest::*; + use std::fs::File; + use std::io; use tempfile::TempDir; use url::Url; @@ -884,4 +890,48 @@ mod tests { Ok(()) } + fn compress_file(path: &str, output_path: &str) -> io::Result<()> { + let input_file = File::open(path)?; + let mut reader = BufReader::new(input_file); + + let output_file = File::create(output_path)?; + let writer = std::io::BufWriter::new(output_file); + + let mut encoder = GzEncoder::new(writer, Compression::default()); + io::copy(&mut reader, &mut encoder)?; + + encoder.finish()?; + Ok(()) + } + #[tokio::test] + async fn test_disable_parallel_for_json_gz() -> Result<()> { + let config = SessionConfig::new() + .with_repartition_file_scans(true) + .with_repartition_file_min_size(0) + .with_target_partitions(4); + let ctx = SessionContext::new_with_config(config); + let path = format!("{TEST_DATA_BASE}/1.json"); + let compressed_path = format!("{}.gz", &path); + compress_file(&path, &compressed_path)?; + let read_option = NdJsonReadOptions::default() + .file_compression_type(FileCompressionType::GZIP) + .file_extension("gz"); + let df = ctx.read_json(compressed_path.clone(), read_option).await?; + let res = df.collect().await; + fs::remove_file(&compressed_path)?; + assert_batches_eq!( + &[ + "+-----+------------------+---------------+------+", + "| a | b | c | d |", + "+-----+------------------+---------------+------+", + "| 1 | [2.0, 1.3, -6.1] | [false, true] | 4 |", + "| -10 | [2.0, 1.3, -6.1] | [true, true] | 4 |", + "| 2 | [2.0, , -6.1] | [false, ] | text |", + "| | | | |", + "+-----+------------------+---------------+------+", + ], + &res? 
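+            // The gzip-compressed file must decode to the same rows as the
+            // original, uncompressed NDJSON input.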
+ ); + Ok(()) + } } From 666f7a5221ac9b4d5232cef7b8008ca71d2c1be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vojt=C4=9Bch=20Toman?= Date: Thu, 28 Mar 2024 17:57:16 +0100 Subject: [PATCH 081/117] Optimize to_timestamp (with format) (#9090) (#9833) Eliminate duplicate parsing of the input and format strings in some cases Co-authored-by: Vojtech Toman --- datafusion/functions/src/datetime/common.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index 007ffd35ca3a..f0689ffd64e9 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -22,8 +22,9 @@ use arrow::array::{ }; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::DataType; +use chrono::format::{parse, Parsed, StrftimeItems}; use chrono::LocalResult::Single; -use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; +use chrono::{DateTime, TimeZone, Utc}; use itertools::Either; use datafusion_common::cast::as_generic_string_array; @@ -84,12 +85,15 @@ pub(crate) fn string_to_datetime_formatted( )) }; + let mut parsed = Parsed::new(); + parse(&mut parsed, s, StrftimeItems::new(format)).map_err(|e| err(&e.to_string()))?; + // attempt to parse the string assuming it has a timezone - let dt = DateTime::parse_from_str(s, format); + let dt = parsed.to_datetime(); if let Err(e) = &dt { // no timezone or other failure, try without a timezone - let ndt = NaiveDateTime::parse_from_str(s, format); + let ndt = parsed.to_naive_datetime_with_offset(0); if let Err(e) = &ndt { return Err(err(&e.to_string())); } From 2cca3710f3b31148ffe99d9e225c768c921a748b Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 13:02:27 -0400 Subject: [PATCH 082/117] Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function (#9825) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Fixed missing trim() function. 
* Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function --- datafusion-cli/Cargo.lock | 1 + datafusion/core/Cargo.toml | 1 + .../tests/dataframe/dataframe_functions.rs | 1 + datafusion/expr/src/built_in_function.rs | 14 +- datafusion/expr/src/expr_fn.rs | 8 - datafusion/functions/Cargo.toml | 4 + datafusion/functions/src/lib.rs | 9 + datafusion/functions/src/string/ascii.rs | 2 +- datafusion/functions/src/string/bit_length.rs | 4 +- datafusion/functions/src/string/btrim.rs | 1 + datafusion/functions/src/string/chr.rs | 2 +- datafusion/functions/src/string/common.rs | 158 +--------------- .../functions/src/string/levenshtein.rs | 3 +- datafusion/functions/src/string/lower.rs | 8 +- datafusion/functions/src/string/ltrim.rs | 3 +- .../functions/src/string/octet_length.rs | 13 +- datafusion/functions/src/string/overlay.rs | 2 +- datafusion/functions/src/string/repeat.rs | 4 +- datafusion/functions/src/string/replace.rs | 2 +- datafusion/functions/src/string/rtrim.rs | 1 + datafusion/functions/src/string/split_part.rs | 4 +- .../functions/src/string/starts_with.rs | 9 +- datafusion/functions/src/string/to_hex.rs | 9 +- datafusion/functions/src/string/upper.rs | 3 +- .../functions/src/unicode/character_length.rs | 176 +++++++++++++++++ datafusion/functions/src/unicode/mod.rs | 55 ++++++ datafusion/functions/src/utils.rs | 178 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 70 ------- .../physical-expr/src/unicode_expressions.rs | 23 --- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 8 +- datafusion/proto/src/logical_plan/to_proto.rs | 1 - datafusion/sql/Cargo.toml | 1 + datafusion/sql/tests/sql_integration.rs | 15 +- 36 files changed, 484 insertions(+), 318 deletions(-) create mode 100644 datafusion/functions/src/unicode/character_length.rs create mode 100644 datafusion/functions/src/unicode/mod.rs create mode 100644 datafusion/functions/src/utils.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index b5535a47e9c1..ba60c04cea55 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1273,6 +1273,7 @@ dependencies = [ "md-5", "regex", "sha2", + "unicode-segmentation", "uuid", ] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 1e5c0d748e3d..de03579975a2 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -70,6 +70,7 @@ unicode_expressions = [ "datafusion-physical-expr/unicode_expressions", "datafusion-optimizer/unicode_expressions", "datafusion-sql/unicode_expressions", + "datafusion-functions/unicode_expressions", ] [dependencies] diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 6ebd64c9b628..4371cce856ce 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -37,6 +37,7 @@ use datafusion::assert_batches_eq; use datafusion_common::DFSchema; use datafusion_expr::expr::Alias; use datafusion_expr::{approx_median, cast, ExprSchemable}; +use datafusion_functions::unicode::expr_fn::character_length; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index bb0f79f8eca4..eefbc131a27b 100644 --- a/datafusion/expr/src/built_in_function.rs +++ 
b/datafusion/expr/src/built_in_function.rs @@ -103,8 +103,6 @@ pub enum BuiltinScalarFunction { Cot, // string functions - /// character_length - CharacterLength, /// concat Concat, /// concat_ws @@ -218,7 +216,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Cbrt => Volatility::Immutable, BuiltinScalarFunction::Cot => Volatility::Immutable, BuiltinScalarFunction::Trunc => Volatility::Immutable, - BuiltinScalarFunction::CharacterLength => Volatility::Immutable, BuiltinScalarFunction::Concat => Volatility::Immutable, BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, @@ -257,9 +254,6 @@ impl BuiltinScalarFunction { // the return type of the built in function. // Some built-in functions' return type depends on the incoming type. match self { - BuiltinScalarFunction::CharacterLength => { - utf8_to_int_type(&input_expr_types[0], "character_length") - } BuiltinScalarFunction::Coalesce => { // COALESCE has multiple args and they might get coerced, get a preview of this let coerced_types = data_types(input_expr_types, &self.signature()); @@ -367,9 +361,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::CharacterLength - | BuiltinScalarFunction::InitCap - | BuiltinScalarFunction::Reverse => { + BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Reverse => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { @@ -584,10 +576,6 @@ impl BuiltinScalarFunction { // conditional functions BuiltinScalarFunction::Coalesce => &["coalesce"], - // string functions - BuiltinScalarFunction::CharacterLength => { - &["character_length", "char_length", "length"] - } BuiltinScalarFunction::Concat => &["concat"], BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0ea946288e0f..654464798625 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -577,13 +577,6 @@ scalar_expr!(Power, power, base exponent, "`base` raised to the power of `expone scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); -// string functions -scalar_expr!( - CharacterLength, - character_length, - string, - "the number of characters in the `string`" -); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Reverse, reverse, string, "reverses the `string`"); @@ -1032,7 +1025,6 @@ mod test { test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); - test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 81050dfddf66..0cab0276ff4b 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -43,6 +43,7 @@ default = [ "regex_expressions", "crypto_expressions", "string_expressions", + "unicode_expressions", ] # enable encode/decode functions 
encoding_expressions = ["base64", "hex"] @@ -52,6 +53,8 @@ math_expressions = [] regex_expressions = ["regex"] # enable string functions string_expressions = [] +# enable unicode functions +unicode_expressions = ["unicode-segmentation"] [lib] name = "datafusion_functions" @@ -75,6 +78,7 @@ log = { workspace = true } md-5 = { version = "^0.10.0", optional = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } +unicode-segmentation = { version = "^1.7.1", optional = true } uuid = { version = "1.7", features = ["v4"] } [dev-dependencies] diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs index f469b343e144..2a00839dc532 100644 --- a/datafusion/functions/src/lib.rs +++ b/datafusion/functions/src/lib.rs @@ -124,6 +124,12 @@ make_stub_package!(regex, "regex_expressions"); pub mod crypto; make_stub_package!(crypto, "crypto_expressions"); +#[cfg(feature = "unicode_expressions")] +pub mod unicode; +make_stub_package!(unicode, "unicode_expressions"); + +mod utils; + /// Fluent-style API for creating `Expr`s pub mod expr_fn { #[cfg(feature = "core_expressions")] @@ -140,6 +146,8 @@ pub mod expr_fn { pub use super::regex::expr_fn::*; #[cfg(feature = "string_expressions")] pub use super::string::expr_fn::*; + #[cfg(feature = "unicode_expressions")] + pub use super::unicode::expr_fn::*; } /// Registers all enabled packages with a [`FunctionRegistry`] @@ -151,6 +159,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { .chain(math::functions()) .chain(regex::functions()) .chain(crypto::functions()) + .chain(unicode::functions()) .chain(string::functions()); all_functions.try_for_each(|udf| { diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 5bd77833a935..9a07f4c19cf1 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use crate::utils::make_scalar_function; use arrow::array::Int32Array; use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 9f612751584e..6a200471d42d 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::compute::kernels::length::bit_length; use std::any::Any; +use arrow::compute::kernels::length::bit_length; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::utf8_to_int_type; #[derive(Debug)] pub(super) struct BitLengthFunc { diff --git a/datafusion/functions/src/string/btrim.rs b/datafusion/functions/src/string/btrim.rs index de1c9cc69b72..573a23d07021 100644 --- a/datafusion/functions/src/string/btrim.rs +++ b/datafusion/functions/src/string/btrim.rs @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with leading and trailing characters removed. 
If the characters are not specified, whitespace is removed. /// btrim('xyxtrimyyx', 'xyz') = 'trim' diff --git a/datafusion/functions/src/string/chr.rs b/datafusion/functions/src/string/chr.rs index df3b803ba659..d1f8dc398a2b 100644 --- a/datafusion/functions/src/string/chr.rs +++ b/datafusion/functions/src/string/chr.rs @@ -29,7 +29,7 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::make_scalar_function; /// Returns the character with the given code. chr(0) is disallowed because text data types cannot store that character. /// chr(65) = 'A' diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 339f4e6c1a23..276aad121df2 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -24,8 +24,7 @@ use arrow::datatypes::DataType; use datafusion_common::cast::as_generic_string_array; use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; -use datafusion_physical_expr::functions::Hint; +use datafusion_expr::ColumnarValue; pub(crate) enum TrimType { Left, @@ -98,52 +97,6 @@ pub(crate) fn general_trim( } } -/// Creates a function to identify the optimal return type of a string function given -/// the type of its first argument. -/// -/// If the input type is `LargeUtf8` or `LargeBinary` the return type is -/// `$largeUtf8Type`, -/// -/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, -macro_rules! get_optimal_return_type { - ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { - pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { - Ok(match arg_type { - // LargeBinary inputs are automatically coerced to Utf8 - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - // Binary inputs are automatically coerced to Utf8 - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - DataType::Dictionary(_, value_type) => match **value_type { - DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, - DataType::Utf8 | DataType::Binary => $utf8Type, - DataType::Null => DataType::Null, - _ => { - return datafusion_common::exec_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - **value_type - ); - } - }, - data_type => { - return datafusion_common::exec_err!( - "The {} function can only accept strings, but got {:?}.", - name.to_uppercase(), - data_type - ); - } - }) - } - }; -} - -// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. -get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); - -// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. -get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - /// applies a unary expression to `args[0]` that is expected to be downcastable to /// a `GenericStringArray` and returns a `GenericStringArray` (which may have a different offset) /// # Errors @@ -221,112 +174,3 @@ where }, } } - -pub(super) fn make_scalar_function( - inner: F, - hints: Vec, -) -> ScalarFunctionImplementation -where - F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, -{ - Arc::new(move |args: &[ColumnarValue]| { - // first, identify if any of the arguments is an Array. 
If yes, store its `len`, - // as any scalar will need to be converted to an array of len `len`. - let len = args - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - - let is_scalar = len.is_none(); - - let inferred_length = len.unwrap_or(1); - let args = args - .iter() - .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) - .map(|(arg, hint)| { - // Decide on the length to expand this scalar to depending - // on the given hints. - let expansion_len = match hint { - Hint::AcceptsSingular => 1, - Hint::Pad => inferred_length, - }; - arg.clone().into_array(expansion_len) - }) - .collect::>>()?; - - let result = (inner)(&args); - if is_scalar { - // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) - } else { - result.map(ColumnarValue::Array) - } - }) -} - -#[cfg(test)] -pub mod test { - /// $FUNC ScalarUDFImpl to test - /// $ARGS arguments (vec) to pass to function - /// $EXPECTED a Result - /// $EXPECTED_TYPE is the expected value type - /// $EXPECTED_DATA_TYPE is the expected result type - /// $ARRAY_TYPE is the column type after function applied - macro_rules! test_function { - ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { - let expected: Result> = $EXPECTED; - let func = $FUNC; - - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); - let return_type = func.return_type(&type_array); - - match expected { - Ok(expected) => { - assert_eq!(return_type.is_ok(), true); - assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); - - let result = func.invoke($ARGS); - assert_eq!(result.is_ok(), true); - - let len = $ARGS - .iter() - .fold(Option::::None, |acc, arg| match arg { - ColumnarValue::Scalar(_) => acc, - ColumnarValue::Array(a) => Some(a.len()), - }); - let inferred_length = len.unwrap_or(1); - let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); - let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); - - // value is correct - match expected { - Some(v) => assert_eq!(result.value(0), v), - None => assert!(result.is_null(0)), - }; - } - Err(expected_error) => { - if return_type.is_err() { - match return_type { - Ok(_) => assert!(false, "expected error"), - Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } - } - } - else { - // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke($ARGS) { - Ok(_) => assert!(false, "expected error"), - Err(error) => { - assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); - } - } - } - } - }; - }; - } - - pub(crate) use test_function; -} diff --git a/datafusion/functions/src/string/levenshtein.rs b/datafusion/functions/src/string/levenshtein.rs index b5de4b28948f..8f497e73e393 100644 --- a/datafusion/functions/src/string/levenshtein.rs +++ b/datafusion/functions/src/string/levenshtein.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, Int32Array, Int64Array, OffsetSizeTrait}; use arrow::datatypes::DataType; +use crate::utils::{make_scalar_function, utf8_to_int_type}; use datafusion_common::cast::as_generic_string_array; use datafusion_common::utils::datafusion_strsim; use datafusion_common::{exec_err, Result}; @@ -28,8 
+29,6 @@ use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use crate::string::common::{make_scalar_function, utf8_to_int_type}; - #[derive(Debug)] pub(super) struct LevenshteinFunc { signature: Signature, diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 42bda0470067..327772bd808d 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -15,12 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::{handle, utf8_to_str_type}; +use std::any::Any; + use arrow::datatypes::DataType; + use datafusion_common::Result; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; + +use crate::string::common::handle; +use crate::utils::utf8_to_str_type; #[derive(Debug)] pub(super) struct LowerFunc { diff --git a/datafusion/functions/src/string/ltrim.rs b/datafusion/functions/src/string/ltrim.rs index 535ffb14f5f5..e6926e5bd56e 100644 --- a/datafusion/functions/src/string/ltrim.rs +++ b/datafusion/functions/src/string/ltrim.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayRef, OffsetSizeTrait}; use std::any::Any; +use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with leading characters removed. If the characters are not specified, whitespace is removed. /// ltrim('zzzytest', 'xyz') = 'test' diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index 36a62fbe4e38..639bf6cb48a9 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::compute::kernels::length::length; use std::any::Any; +use arrow::compute::kernels::length::length; use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::utf8_to_int_type; #[derive(Debug)] pub(super) struct OctetLengthFunc { @@ -86,14 +86,17 @@ impl ScalarUDFImpl for OctetLengthFunc { #[cfg(test)] mod tests { - use crate::string::common::test::test_function; - use crate::string::octet_length::OctetLengthFunc; + use std::sync::Arc; + use arrow::array::{Array, Int32Array, StringArray}; use arrow::datatypes::DataType::Int32; + use datafusion_common::ScalarValue; use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use std::sync::Arc; + + use crate::string::octet_length::OctetLengthFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index d7cc0da8068e..8b9cc03afc4d 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct OverlayFunc { diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 83bc929cb9a4..f4319af0a5c4 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct RepeatFunc { @@ -99,8 +99,8 @@ mod tests { use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::string::common::test::test_function; use crate::string::repeat::RepeatFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index e35244296090..e869ac205440 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct ReplaceFunc { diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index 17d2f8234b34..d04d15ce8847 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -26,6 +26,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; /// Returns the longest string with trailing characters removed. If the characters are not specified, whitespace is removed. 
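/// rtrim('  test  ') = '  test' (with no trim characters given, only trailing whitespace is removed)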
/// rtrim('testxxzx', 'xyz') = 'test' diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index af201e90fcf6..0aa968a1ef5b 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -27,7 +27,7 @@ use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; -use crate::string::common::*; +use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub(super) struct SplitPartFunc { @@ -117,8 +117,8 @@ mod tests { use datafusion_common::{exec_err, Result}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; - use crate::string::common::test::test_function; use crate::string::split_part::SplitPartFunc; + use crate::utils::test::test_function; #[test] fn test_functions() -> Result<()> { diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 4450b9d332a0..f1b03907f8d8 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + use arrow::array::{ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; + use datafusion_common::{cast::as_generic_string_array, internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; + +use crate::utils::make_scalar_function; /// Returns true if string starts with prefix. /// starts_with('alphabet', 'alph') = 't' diff --git a/datafusion/functions/src/string/to_hex.rs b/datafusion/functions/src/string/to_hex.rs index 1bdece3f7af8..ab320c68d493 100644 --- a/datafusion/functions/src/string/to_hex.rs +++ b/datafusion/functions/src/string/to_hex.rs @@ -15,18 +15,21 @@ // specific language governing permissions and limitations // under the License. -use crate::string::common::make_scalar_function; +use std::any::Any; +use std::sync::Arc; + use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::{ ArrowNativeType, ArrowPrimitiveType, DataType, Int32Type, Int64Type, }; + use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::{exec_err, plan_err}; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; -use std::any::Any; -use std::sync::Arc; + +use crate::utils::make_scalar_function; /// Converts the number to its equivalent hexadecimal representation. /// to_hex(2147483647) = '7fffffff' diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index a0c910ebb2c8..066174abf277 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. 
-use crate::string::common::{handle, utf8_to_str_type}; +use crate::string::common::handle; +use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::ColumnarValue; diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs new file mode 100644 index 000000000000..51331bf9a586 --- /dev/null +++ b/datafusion/functions/src/unicode/character_length.rs @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::{make_scalar_function, utf8_to_int_type}; +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug)] +pub(super) struct CharacterLengthFunc { + signature: Signature, + aliases: Vec, +} + +impl CharacterLengthFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + aliases: vec![String::from("length"), String::from("char_length")], + } + } +} + +impl ScalarUDFImpl for CharacterLengthFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "character_length" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_int_type(&arg_types[0], "character_length") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(character_length::, vec![])(args) + } + DataType::LargeUtf8 => { + make_scalar_function(character_length::, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function character_length") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns number of characters in the string. 
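+/// Characters are counted as Unicode code points rather than bytes, so a multi-byte character such as 'é' counts as one: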
+/// character_length('josé') = 4 +/// The implementation counts UTF-8 code points to count the number of characters +fn character_length(args: &[ArrayRef]) -> Result +where + T::Native: OffsetSizeTrait, +{ + let string_array: &GenericStringArray = + as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| { + string.map(|string: &str| { + T::Native::from_usize(string.chars().count()) + .expect("should not fail as string.chars will always return integer") + }) + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use crate::unicode::character_length::CharacterLengthFunc; + use crate::utils::test::test_function; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + #[test] + fn test_functions() -> Result<()> { + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("chars") + )))], + Ok(Some(5)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("josé") + )))], + Ok(Some(4)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( + String::from("") + )))], + Ok(Some(0)), + i32, + Int32, + Int32Array + ); + #[cfg(feature = "unicode_expressions")] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + i32, + Int32, + Int32Array + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + CharacterLengthFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("josé"))))], + internal_err!( + "function character_length requires compilation with feature flag: unicode_expressions." + ), + i32, + Int32, + Int32Array + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs new file mode 100644 index 000000000000..291de3843903 --- /dev/null +++ b/datafusion/functions/src/unicode/mod.rs @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
"unicode" DataFusion functions + +use std::sync::Arc; + +use datafusion_expr::ScalarUDF; + +mod character_length; + +// create UDFs +make_udf_function!( + character_length::CharacterLengthFunc, + CHARACTER_LENGTH, + character_length +); + +pub mod expr_fn { + use datafusion_expr::Expr; + + #[doc = "the number of characters in the `string`"] + pub fn char_length(string: Expr) -> Expr { + character_length(string) + } + + #[doc = "the number of characters in the `string`"] + pub fn character_length(string: Expr) -> Expr { + super::character_length().call(vec![string]) + } + + #[doc = "the number of characters in the `string`"] + pub fn length(string: Expr) -> Expr { + character_length(string) + } +} + +/// Return a list of all functions in this package +pub fn functions() -> Vec> { + vec![character_length()] +} diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs new file mode 100644 index 000000000000..f45deafdb37a --- /dev/null +++ b/datafusion/functions/src/utils.rs @@ -0,0 +1,178 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_physical_expr::functions::Hint; +use std::sync::Arc; + +/// Creates a function to identify the optimal return type of a string function given +/// the type of its first argument. +/// +/// If the input type is `LargeUtf8` or `LargeBinary` the return type is +/// `$largeUtf8Type`, +/// +/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`, +macro_rules! get_optimal_return_type { + ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { + pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result { + Ok(match arg_type { + // LargeBinary inputs are automatically coerced to Utf8 + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + // Binary inputs are automatically coerced to Utf8 + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + DataType::Dictionary(_, value_type) => match **value_type { + DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type, + DataType::Utf8 | DataType::Binary => $utf8Type, + DataType::Null => DataType::Null, + _ => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + **value_type + ); + } + }, + data_type => { + return datafusion_common::exec_err!( + "The {} function can only accept strings, but got {:?}.", + name.to_uppercase(), + data_type + ); + } + }) + } + }; +} + +// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. 
+get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); + +// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. +get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); + +pub(super) fn make_scalar_function( + inner: F, + hints: Vec, +) -> ScalarFunctionImplementation +where + F: Fn(&[ArrayRef]) -> Result + Sync + Send + 'static, +{ + Arc::new(move |args: &[ColumnarValue]| { + // first, identify if any of the arguments is an Array. If yes, store its `len`, + // as any scalar will need to be converted to an array of len `len`. + let len = args + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + + let is_scalar = len.is_none(); + + let inferred_length = len.unwrap_or(1); + let args = args + .iter() + .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad))) + .map(|(arg, hint)| { + // Decide on the length to expand this scalar to depending + // on the given hints. + let expansion_len = match hint { + Hint::AcceptsSingular => 1, + Hint::Pad => inferred_length, + }; + arg.clone().into_array(expansion_len) + }) + .collect::>>()?; + + let result = (inner)(&args); + if is_scalar { + // If all inputs are scalar, keeps output as scalar + let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + result.map(ColumnarValue::Scalar) + } else { + result.map(ColumnarValue::Array) + } + }) +} + +#[cfg(test)] +pub mod test { + /// $FUNC ScalarUDFImpl to test + /// $ARGS arguments (vec) to pass to function + /// $EXPECTED a Result + /// $EXPECTED_TYPE is the expected value type + /// $EXPECTED_DATA_TYPE is the expected result type + /// $ARRAY_TYPE is the column type after function applied + macro_rules! 
test_function { + ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => { + let expected: Result> = $EXPECTED; + let func = $FUNC; + + let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let return_type = func.return_type(&type_array); + + match expected { + Ok(expected) => { + assert_eq!(return_type.is_ok(), true); + assert_eq!(return_type.unwrap(), $EXPECTED_DATA_TYPE); + + let result = func.invoke($ARGS); + assert_eq!(result.is_ok(), true); + + let len = $ARGS + .iter() + .fold(Option::::None, |acc, arg| match arg { + ColumnarValue::Scalar(_) => acc, + ColumnarValue::Array(a) => Some(a.len()), + }); + let inferred_length = len.unwrap_or(1); + let result = result.unwrap().clone().into_array(inferred_length).expect("Failed to convert to array"); + let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type"); + + // value is correct + match expected { + Some(v) => assert_eq!(result.value(0), v), + None => assert!(result.is_null(0)), + }; + } + Err(expected_error) => { + if return_type.is_err() { + match return_type { + Ok(_) => assert!(false, "expected error"), + Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } + } + } + else { + // invoke is expected error - cannot use .expect_err() due to Debug not being implemented + match func.invoke($ARGS) { + Ok(_) => assert!(false, "expected error"), + Err(error) => { + assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); + } + } + } + } + }; + }; + } + + pub(crate) use test_function; +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index cd9bba63d624..9adc8536341d 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -254,29 +254,6 @@ pub fn create_physical_fun( Arc::new(|args| make_scalar_function_inner(math_expressions::cot)(args)) } // string functions - BuiltinScalarFunction::CharacterLength => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int32Type, - "character_length" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - character_length, - Int64Type, - "character_length" - ); - make_scalar_function_inner(func)(args) - } - other => exec_err!( - "Unsupported data type {other:?} for function character_length" - ), - }) - } BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| { @@ -595,53 +572,6 @@ mod tests { #[test] fn test_functions() -> Result<()> { - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("chars")], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("josé")], - Ok(Some(4)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit("")], - Ok(Some(0)), - i32, - Int32, - Int32Array - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - CharacterLength, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - i32, - Int32, - Int32Array - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - 
CharacterLength, - &[lit("josé")], - internal_err!( - "function character_length requires compilation with feature flag: unicode_expressions." - ), - i32, - Int32, - Int32Array - ); test_function!( Concat, &[lit("aa"), lit("bb"), lit("cc"),], diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs index 8ec9e062d9b7..c7e4b7d7c443 100644 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ b/datafusion/physical-expr/src/unicode_expressions.rs @@ -36,29 +36,6 @@ use datafusion_common::{ exec_err, Result, }; -/// Returns number of characters in the string. -/// character_length('josé') = 4 -/// The implementation counts UTF-8 code points to count the number of characters -pub fn character_length(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - let string_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - - let result = string_array - .iter() - .map(|string| { - string.map(|string: &str| { - T::Native::from_usize(string.chars().count()) - .expect("should not fail as string.chars will always return integer") - }) - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - /// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. /// left('abcde', 2) = 'ab' /// The implementation uses UTF-8 code points as characters diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index f405ecf976be..766ca6633ee1 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -565,7 +565,7 @@ enum ScalarFunction { // RegexpMatch = 21; // 22 was BitLength // 23 was Btrim - CharacterLength = 24; + // 24 was CharacterLength // 25 was Chr Concat = 26; ConcatWithSeparator = 27; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0d22ba5db773..f2814956ef1b 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22928,7 +22928,6 @@ impl serde::Serialize for ScalarFunction { Self::Sin => "Sin", Self::Sqrt => "Sqrt", Self::Trunc => "Trunc", - Self::CharacterLength => "CharacterLength", Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", @@ -22988,7 +22987,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sin", "Sqrt", "Trunc", - "CharacterLength", "Concat", "ConcatWithSeparator", "InitCap", @@ -23077,7 +23075,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Sin" => Ok(ScalarFunction::Sin), "Sqrt" => Ok(ScalarFunction::Sqrt), "Trunc" => Ok(ScalarFunction::Trunc), - "CharacterLength" => Ok(ScalarFunction::CharacterLength), "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 07c3fad15373..ecc94fcdaf99 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2864,7 +2864,7 @@ pub enum ScalarFunction { /// RegexpMatch = 21; /// 22 was BitLength /// 23 was Btrim - CharacterLength = 24, + /// 24 was CharacterLength /// 25 was Chr Concat = 26, ConcatWithSeparator = 27, @@ -3001,7 +3001,6 @@ impl ScalarFunction { ScalarFunction::Sin => "Sin", ScalarFunction::Sqrt => "Sqrt", ScalarFunction::Trunc => "Trunc", - ScalarFunction::CharacterLength => "CharacterLength", 
ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", @@ -3055,7 +3054,6 @@ impl ScalarFunction { "Sin" => Some(Self::Sin), "Sqrt" => Some(Self::Sqrt), "Trunc" => Some(Self::Trunc), - "CharacterLength" => Some(Self::CharacterLength), "Concat" => Some(Self::Concat), "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 4b9874bf8f65..19edd71a3a80 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -48,8 +48,8 @@ use datafusion_expr::expr::Unnest; use datafusion_expr::expr::{Alias, Placeholder}; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - acosh, asinh, atan, atan2, atanh, cbrt, ceil, character_length, coalesce, - concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, + acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, + cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, @@ -450,7 +450,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Concat => Self::Concat, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, - ScalarFunction::CharacterLength => Self::CharacterLength, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, @@ -1372,9 +1371,6 @@ pub fn parse_expr( ScalarFunction::Signum => { Ok(signum(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::CharacterLength => { - Ok(character_length(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::InitCap => { Ok(initcap(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1335d511a0ea..11fc7362c75d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1442,7 +1442,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Concat => Self::Concat, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, - BuiltinScalarFunction::CharacterLength => Self::CharacterLength, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index ca2c1a240c21..b9f6dc259eb7 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -49,6 +49,7 @@ strum = { version = "0.26.1", features = ["derive"] } [dev-dependencies] ctor = { workspace = true } +datafusion-functions = { workspace = true, default-features = true } env_logger = { workspace = true } paste = "^1.0" rstest = { workspace = true } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 448a9c54202e..101c31039c7e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -38,6 +38,7 @@ use datafusion_sql::{ planner::{ContextProvider, ParserOptions, 
PlannerContext, SqlToRel}, }; +use datafusion_functions::unicode; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -88,7 +89,7 @@ fn parse_decimals() { fn parse_ident_normalization() { let test_data = [ ( - "SELECT LENGTH('str')", + "SELECT CHARACTER_LENGTH('str')", "Ok(Projection: character_length(Utf8(\"str\"))\n EmptyRelation)", false, ), @@ -2688,6 +2689,7 @@ fn logical_plan_with_dialect_and_options( options: ParserOptions, ) -> Result { let context = MockContextProvider::default() + .with_udf(unicode::character_length().as_ref().clone()) .with_udf(make_udf( "nullif", vec![DataType::Int32, DataType::Int32], @@ -4508,26 +4510,27 @@ fn test_field_not_found_window_function() { #[test] fn test_parse_escaped_string_literal_value() { - let sql = r"SELECT length('\r\n') AS len"; + let sql = r"SELECT character_length('\r\n') AS len"; let expected = "Projection: character_length(Utf8(\"\\r\\n\")) AS len\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\r\n') AS len"; + let sql = r"SELECT character_length(E'\r\n') AS len"; let expected = "Projection: character_length(Utf8(\"\r\n\")) AS len\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; + let sql = + r"SELECT character_length(E'\445') AS len, E'\x4B' AS hex, E'\u0001' AS unicode"; let expected = "Projection: character_length(Utf8(\"%\")) AS len, Utf8(\"\u{004b}\") AS hex, Utf8(\"\u{0001}\") AS unicode\ \n EmptyRelation"; quick_test(sql, expected); - let sql = r"SELECT length(E'\000') AS len"; + let sql = r"SELECT character_length(E'\000') AS len"; assert_eq!( logical_plan(sql).unwrap_err().strip_backtrace(), - "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 15\")" + "SQL error: TokenizerError(\"Unterminated encoded string literal at Line: 1, Column 25\")" ) } From 656887c1158e6855da54407986c710b8006968fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Thu, 28 Mar 2024 18:04:53 +0100 Subject: [PATCH 083/117] [Minor] Update TPCDS tests, remove some #[ignore]d tests (#9829) * Update TPCDS tests * Fmt --- datafusion/core/tests/tpcds_planning.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/datafusion/core/tests/tpcds_planning.rs b/datafusion/core/tests/tpcds_planning.rs index 4db97c75cb33..e8d2c3764e0c 100644 --- a/datafusion/core/tests/tpcds_planning.rs +++ b/datafusion/core/tests/tpcds_planning.rs @@ -73,9 +73,6 @@ async fn tpcds_logical_q9() -> Result<()> { create_logical_plan(9).await } -#[ignore] -// Schema error: No field named 'c'.'c_customer_sk'. -- issue: https://github.com/apache/arrow-datafusion/issues/4794 #[tokio::test] async fn tpcds_logical_q10() -> Result<()> { create_logical_plan(10).await } @@ -201,9 +198,6 @@ async fn tpcds_logical_q34() -> Result<()> { create_logical_plan(34).await } -#[ignore] -// Schema error: No field named 'c'.'c_customer_sk'.
-// issue: https://github.com/apache/arrow-datafusion/issues/4794 #[tokio::test] async fn tpcds_logical_q35() -> Result<()> { create_logical_plan(35).await @@ -577,7 +571,7 @@ async fn tpcds_physical_q9() -> Result<()> { create_physical_plan(9).await } -#[ignore] // FieldNotFound +#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q10() -> Result<()> { create_physical_plan(10).await @@ -703,7 +697,7 @@ async fn tpcds_physical_q34() -> Result<()> { create_physical_plan(34).await } -#[ignore] // FieldNotFound +#[ignore] // Physical plan does not support logical expression Exists() #[tokio::test] async fn tpcds_physical_q35() -> Result<()> { create_physical_plan(35).await @@ -734,7 +728,8 @@ async fn tpcds_physical_q40() -> Result<()> { create_physical_plan(40).await } -#[ignore] // Physical plan does not support logical expression () +#[ignore] +// Context("check_analyzed_plan", Plan("Correlated column is not allowed in predicate: (..) #[tokio::test] async fn tpcds_physical_q41() -> Result<()> { create_physical_plan(41).await From 45b8b0b2afa92656c1d35728d13645db5e97f70b Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 28 Mar 2024 10:15:22 -0700 Subject: [PATCH 084/117] doc: Adding baseline benchmark example (#9827) * Adding baseline benchmark example --- docs/source/contributor-guide/index.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 9d3a177be6bd..eadf4147c57e 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -237,6 +237,25 @@ If the environment variable `PARQUET_FILE` is set, the benchmark will run querie The benchmark will automatically remove any generated parquet file on exit, however, if interrupted (e.g. by CTRL+C) it will not. This can be useful for analysing the particular file after the fact, or preserving it to use with `PARQUET_FILE` in subsequent runs. +### Comparing Baselines + +By default, Criterion.rs will compare the measurements against the previous run (if any). Sometimes it's useful to keep a set of measurements around for several runs. For example, you might want to make multiple changes to the code while comparing against the master branch. For this situation, Criterion.rs supports custom baselines. + +``` + git checkout main + cargo bench --bench sql_planner -- --save-baseline main + git checkout YOUR_BRANCH + cargo bench --bench sql_planner -- --baseline main +``` + +Note: For MacOS it may be required to run `cargo bench` with `sudo` + +``` +sudo cargo bench ... +``` + +More information on [Baselines](https://bheisler.github.io/criterion.rs/book/user_guide/command_line_options.html#baselines) + ### Upstream Benchmark Suites Instructions and tooling for running upstream benchmark suites against DataFusion can be found in [benchmarks](https://github.com/apache/arrow-datafusion/tree/main/benchmarks). 
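For example, a run that reuses a previously preserved file might look like this (both the file path and the benchmark name below are illustrative):

```
PARQUET_FILE=/tmp/generated.parquet cargo bench --bench parquet_query_sql
```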
From 3a1e3adc727eb46c582d48105a4b115f5b80a992 Mon Sep 17 00:00:00 2001 From: Matthew Turner Date: Thu, 28 Mar 2024 14:08:38 -0400 Subject: [PATCH 085/117] Add name method to execution plan (#9793) * Add name method to execution plan * Cleanup * Change default impl * Add tests * Clippy * Use unimplemented macro * Fix --- .../examples/custom_datasource.rs | 4 + .../datasource/physical_plan/arrow_file.rs | 4 + .../core/src/datasource/physical_plan/avro.rs | 4 + .../core/src/datasource/physical_plan/csv.rs | 4 + .../core/src/datasource/physical_plan/json.rs | 4 + .../datasource/physical_plan/parquet/mod.rs | 4 + .../enforce_distribution.rs | 4 + .../physical_optimizer/output_requirements.rs | 4 + datafusion/core/src/physical_planner.rs | 4 + .../physical-plan/src/aggregates/mod.rs | 8 + datafusion/physical-plan/src/analyze.rs | 4 + .../physical-plan/src/coalesce_batches.rs | 4 + .../physical-plan/src/coalesce_partitions.rs | 4 + datafusion/physical-plan/src/display.rs | 4 + datafusion/physical-plan/src/empty.rs | 4 + datafusion/physical-plan/src/explain.rs | 4 + datafusion/physical-plan/src/filter.rs | 4 + datafusion/physical-plan/src/insert.rs | 4 + .../physical-plan/src/joins/cross_join.rs | 4 + .../physical-plan/src/joins/hash_join.rs | 4 + .../src/joins/nested_loop_join.rs | 4 + .../src/joins/sort_merge_join.rs | 4 + .../src/joins/symmetric_hash_join.rs | 4 + datafusion/physical-plan/src/lib.rs | 140 ++++++++++++++++++ datafusion/physical-plan/src/limit.rs | 8 + datafusion/physical-plan/src/memory.rs | 4 + .../physical-plan/src/placeholder_row.rs | 4 + datafusion/physical-plan/src/projection.rs | 4 + .../physical-plan/src/recursive_query.rs | 4 + .../physical-plan/src/repartition/mod.rs | 4 + .../physical-plan/src/sorts/partial_sort.rs | 4 + datafusion/physical-plan/src/sorts/sort.rs | 4 + .../src/sorts/sort_preserving_merge.rs | 4 + datafusion/physical-plan/src/streaming.rs | 4 + datafusion/physical-plan/src/union.rs | 8 + datafusion/physical-plan/src/unnest.rs | 4 + datafusion/physical-plan/src/values.rs | 4 + .../src/windows/bounded_window_agg_exec.rs | 4 + .../src/windows/window_agg_exec.rs | 4 + datafusion/physical-plan/src/work_table.rs | 4 + .../custom-table-providers.md | 4 + 41 files changed, 312 insertions(+) diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index 0b7e3d4c6442..ba0d2f3b30f8 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -226,6 +226,10 @@ impl DisplayAs for CustomExec { } impl ExecutionPlan for CustomExec { + fn name(&self) -> &'static str { + "CustomExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/src/datasource/physical_plan/arrow_file.rs b/datafusion/core/src/datasource/physical_plan/arrow_file.rs index 96b3adf968b8..1e8775731015 100644 --- a/datafusion/core/src/datasource/physical_plan/arrow_file.rs +++ b/datafusion/core/src/datasource/physical_plan/arrow_file.rs @@ -122,6 +122,10 @@ impl DisplayAs for ArrowExec { } impl ExecutionPlan for ArrowExec { + fn name(&self) -> &'static str { + "ArrowExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/src/datasource/physical_plan/avro.rs b/datafusion/core/src/datasource/physical_plan/avro.rs index 2ccd83de80cb..4e5140e82d3f 100644 --- a/datafusion/core/src/datasource/physical_plan/avro.rs +++ b/datafusion/core/src/datasource/physical_plan/avro.rs @@ -99,6 +99,10 @@ impl DisplayAs for AvroExec { } impl ExecutionPlan for 
AvroExec { + fn name(&self) -> &'static str { + "AvroExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs index 31cc52f79697..831ef4520567 100644 --- a/datafusion/core/src/datasource/physical_plan/csv.rs +++ b/datafusion/core/src/datasource/physical_plan/csv.rs @@ -160,6 +160,10 @@ impl DisplayAs for CsvExec { } impl ExecutionPlan for CsvExec { + fn name(&self) -> &'static str { + "CsvExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs index c876b3d078f3..a5afda47527f 100644 --- a/datafusion/core/src/datasource/physical_plan/json.rs +++ b/datafusion/core/src/datasource/physical_plan/json.rs @@ -127,6 +127,10 @@ impl DisplayAs for NdJsonExec { } impl ExecutionPlan for NdJsonExec { + fn name(&self) -> &'static str { + "NdJsonExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 767cde9cc55e..377dad5cee6c 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -315,6 +315,10 @@ impl DisplayAs for ParquetExec { } impl ExecutionPlan for ParquetExec { + fn name(&self) -> &'static str { + "ParquetExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 0740a8d2cdbc..a58f8698d6ce 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1369,6 +1369,10 @@ pub(crate) mod tests { } impl ExecutionPlan for SortRequiredExec { + fn name(&self) -> &'static str { + "SortRequiredExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index bf010a5e39d8..829d523c990c 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -133,6 +133,10 @@ impl DisplayAs for OutputRequirementExec { } impl ExecutionPlan for OutputRequirementExec { + fn name(&self) -> &'static str { + "OutputRequirementExec" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index deac3dcf46fc..0a1730e944d3 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -2616,6 +2616,10 @@ mod tests { } impl ExecutionPlan for NoOpExecutionPlan { + fn name(&self) -> &'static str { + "NoOpExecutionPlan" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 65987e01553d..e263876b07d5 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -636,6 +636,10 @@ impl DisplayAs for AggregateExec { } impl ExecutionPlan for AggregateExec { + fn name(&self) -> &'static str { + 
"AggregateExec" + } + /// Return a reference to Any that can be used for down-casting fn as_any(&self) -> &dyn Any { self @@ -1658,6 +1662,10 @@ mod tests { } impl ExecutionPlan for TestYieldingExec { + fn name(&self) -> &'static str { + "TestYieldingExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/analyze.rs b/datafusion/physical-plan/src/analyze.rs index 83a73ee992fb..c420581c4323 100644 --- a/datafusion/physical-plan/src/analyze.rs +++ b/datafusion/physical-plan/src/analyze.rs @@ -111,6 +111,10 @@ impl DisplayAs for AnalyzeExec { } impl ExecutionPlan for AnalyzeExec { + fn name(&self) -> &'static str { + "AnalyzeExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/coalesce_batches.rs b/datafusion/physical-plan/src/coalesce_batches.rs index 0b9ecebbb1e8..bc7c4a3d0673 100644 --- a/datafusion/physical-plan/src/coalesce_batches.rs +++ b/datafusion/physical-plan/src/coalesce_batches.rs @@ -104,6 +104,10 @@ impl DisplayAs for CoalesceBatchesExec { } impl ExecutionPlan for CoalesceBatchesExec { + fn name(&self) -> &'static str { + "CoalesceBatchesExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs index 1e58260a5344..1c725ce31f14 100644 --- a/datafusion/physical-plan/src/coalesce_partitions.rs +++ b/datafusion/physical-plan/src/coalesce_partitions.rs @@ -89,6 +89,10 @@ impl DisplayAs for CoalescePartitionsExec { } impl ExecutionPlan for CoalescePartitionsExec { + fn name(&self) -> &'static str { + "CoalescePartitionsExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/display.rs b/datafusion/physical-plan/src/display.rs index 4b7b35e53e1b..ca93ce5e7b83 100644 --- a/datafusion/physical-plan/src/display.rs +++ b/datafusion/physical-plan/src/display.rs @@ -489,6 +489,10 @@ mod tests { } impl ExecutionPlan for TestStatsExecPlan { + fn name(&self) -> &'static str { + "TestStatsExecPlan" + } + fn as_any(&self) -> &dyn std::any::Any { self } diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 4ff79cdaae70..8e8eb4d25e32 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -101,6 +101,10 @@ impl DisplayAs for EmptyExec { } impl ExecutionPlan for EmptyExec { + fn name(&self) -> &'static str { + "EmptyExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/explain.rs b/datafusion/physical-plan/src/explain.rs index 320ee37bed95..649946993229 100644 --- a/datafusion/physical-plan/src/explain.rs +++ b/datafusion/physical-plan/src/explain.rs @@ -98,6 +98,10 @@ impl DisplayAs for ExplainExec { } impl ExecutionPlan for ExplainExec { + fn name(&self) -> &'static str { + "ExplainExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index f44ade7106df..2996152fb924 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -231,6 +231,10 @@ impl DisplayAs for FilterExec { } impl ExecutionPlan for FilterExec { + fn name(&self) -> 
&'static str { + "FilterExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/insert.rs b/datafusion/physical-plan/src/insert.rs index 16c929b78144..f0233264f280 100644 --- a/datafusion/physical-plan/src/insert.rs +++ b/datafusion/physical-plan/src/insert.rs @@ -206,6 +206,10 @@ impl DisplayAs for FileSinkExec { } impl ExecutionPlan for FileSinkExec { + fn name(&self) -> &'static str { + "FileSinkExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 9f8dc0ce56b0..19d34f8048e3 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -194,6 +194,10 @@ impl DisplayAs for CrossJoinExec { } impl ExecutionPlan for CrossJoinExec { + fn name(&self) -> &'static str { + "CrossJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index a1c50a2113ba..1c0181c2e116 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -611,6 +611,10 @@ fn project_index_to_exprs( } impl ExecutionPlan for HashJoinExec { + fn name(&self) -> &'static str { + "HashJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index c6d891dd13c1..e6236e45f0a7 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -205,6 +205,10 @@ impl DisplayAs for NestedLoopJoinExec { } impl ExecutionPlan for NestedLoopJoinExec { + fn name(&self) -> &'static str { + "NestedLoopJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs b/datafusion/physical-plan/src/joins/sort_merge_join.rs index 7b70a2952b4c..21630087f2ca 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs @@ -262,6 +262,10 @@ impl DisplayAs for SortMergeJoinExec { } impl ExecutionPlan for SortMergeJoinExec { + fn name(&self) -> &'static str { + "SortMergeJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 79b8c813d860..453b217f7fc7 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -385,6 +385,10 @@ impl DisplayAs for SymmetricHashJoinExec { } impl ExecutionPlan for SymmetricHashJoinExec { + fn name(&self) -> &'static str { + "SymmetricHashJoinExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 6334a4a211d4..4b4b37f8b51b 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -113,6 +113,15 @@ pub use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream}; /// [`required_input_distribution`]: ExecutionPlan::required_input_distribution /// [`required_input_ordering`]: ExecutionPlan::required_input_ordering pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { + /// Short name for 
the ExecutionPlan, such as 'ParquetExec'. + fn name(&self) -> &'static str { + let full_name = std::any::type_name::<Self>(); + let maybe_start_idx = full_name.rfind(':'); + match maybe_start_idx { + Some(start_idx) => &full_name[start_idx + 1..], + None => "UNKNOWN", + } + } /// Returns the execution plan as [`Any`] so that it can be /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; @@ -778,4 +787,135 @@ pub fn get_plan_string(plan: &Arc<dyn ExecutionPlan>) -> Vec<String> { #[allow(clippy::single_component_path_imports)] use rstest_reuse; +#[cfg(test)] +mod tests { + use std::any::Any; + use std::sync::Arc; + + use arrow_schema::{Schema, SchemaRef}; + use datafusion_common::{Result, Statistics}; + use datafusion_execution::{SendableRecordBatchStream, TaskContext}; + + use crate::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties}; + + #[derive(Debug)] + pub struct EmptyExec; + + impl EmptyExec { + pub fn new(_schema: SchemaRef) -> Self { + Self + } + } + + impl DisplayAs for EmptyExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for EmptyExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + unimplemented!() + } + + fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + self: Arc<Self>, + _: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + unimplemented!() + } + + fn statistics(&self) -> Result<Statistics> { + unimplemented!() + } + } + + #[derive(Debug)] + pub struct RenamedEmptyExec; + + impl RenamedEmptyExec { + pub fn new(_schema: SchemaRef) -> Self { + Self + } + } + + impl DisplayAs for RenamedEmptyExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + _f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + unimplemented!() + } + } + + impl ExecutionPlan for RenamedEmptyExec { + fn name(&self) -> &'static str { + "MyRenamedEmptyExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + unimplemented!() + } + + fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { + vec![] + } + + fn with_new_children( + self: Arc<Self>, + _: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + unimplemented!() + } + + fn execute( + &self, + _partition: usize, + _context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + unimplemented!() + } + + fn statistics(&self) -> Result<Statistics> { + unimplemented!() + } + } + + #[test] + fn test_execution_plan_name() { + let schema1 = Arc::new(Schema::empty()); + let default_name_exec = EmptyExec::new(schema1); + assert_eq!(default_name_exec.name(), "EmptyExec"); + + let schema2 = Arc::new(Schema::empty()); + let renamed_exec = RenamedEmptyExec::new(schema2); + assert_eq!(renamed_exec.name(), "MyRenamedEmptyExec"); + } +} + pub mod test; diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs index 9fa15cbf64e2..fab483b0da7d 100644 --- a/datafusion/physical-plan/src/limit.rs +++ b/datafusion/physical-plan/src/limit.rs @@ -111,6 +111,10 @@ impl DisplayAs for GlobalLimitExec { } impl ExecutionPlan for GlobalLimitExec { + fn name(&self) -> &'static str { + "GlobalLimitExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -317,6 +321,10 @@ impl DisplayAs for LocalLimitExec { } impl ExecutionPlan for LocalLimitExec { + fn name(&self) -> &'static str { + "LocalLimitExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self
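
As a side note on the default `name()` implementation added above: it keeps only the text after the last `:` of the fully qualified type name. A minimal standalone sketch (a toy illustration, not DataFusion code; `short_name` is a hypothetical helper) behaves like this:

```rust
// Minimal sketch of the default `name()` trimming logic shown above.
// `short_name` is a hypothetical helper introduced only for illustration.
fn short_name(full_name: &'static str) -> &'static str {
    match full_name.rfind(':') {
        // Keep everything after the last `:` of `std::any::type_name`-style output.
        Some(start_idx) => &full_name[start_idx + 1..],
        None => "UNKNOWN",
    }
}

fn main() {
    // A fully qualified type name is trimmed to its last path segment.
    assert_eq!(
        short_name("datafusion::physical_plan::empty::EmptyExec"),
        "EmptyExec"
    );
    // An unqualified name (no `::` separator) falls back to "UNKNOWN".
    assert_eq!(short_name("EmptyExec"), "UNKNOWN");
}
```

diff --git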
a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 795ec3c7315e..883cdb540a9e 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -103,6 +103,10 @@ impl DisplayAs for MemoryExec { } impl ExecutionPlan for MemoryExec { + fn name(&self) -> &'static str { + "MemoryExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 3880cf3d77af..c047ff5122fe 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -119,6 +119,10 @@ impl DisplayAs for PlaceholderRowExec { } impl ExecutionPlan for PlaceholderRowExec { + fn name(&self) -> &'static str { + "PlaceholderRowExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 8fe82e7de3eb..f72815c01a9e 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -180,6 +180,10 @@ impl DisplayAs for ProjectionExec { } impl ExecutionPlan for ProjectionExec { + fn name(&self) -> &'static str { + "ProjectionExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 140820ff782a..ba7d1a54548a 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -108,6 +108,10 @@ impl RecursiveQueryExec { } impl ExecutionPlan for RecursiveQueryExec { + fn name(&self) -> &'static str { + "RecursiveQueryExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 7ac70949f893..c0dbf5164e19 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -406,6 +406,10 @@ impl DisplayAs for RepartitionExec { } impl ExecutionPlan for RepartitionExec { + fn name(&self) -> &'static str { + "RepartitionExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/sorts/partial_sort.rs b/datafusion/physical-plan/src/sorts/partial_sort.rs index 2acb881246a4..d24bc5a670e5 100644 --- a/datafusion/physical-plan/src/sorts/partial_sort.rs +++ b/datafusion/physical-plan/src/sorts/partial_sort.rs @@ -226,6 +226,10 @@ impl DisplayAs for PartialSortExec { } impl ExecutionPlan for PartialSortExec { + fn name(&self) -> &'static str { + "PartialSortExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index a80dab058ca6..a6f47d3d2fc9 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -860,6 +860,10 @@ impl DisplayAs for SortExec { } impl ExecutionPlan for SortExec { + fn name(&self) -> &'static str { + "SortExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 556615f64de6..edef022b0c00 100644 --- 
a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -144,6 +144,10 @@ impl DisplayAs for SortPreservingMergeExec { } impl ExecutionPlan for SortPreservingMergeExec { + fn name(&self) -> &'static str { + "SortPreservingMergeExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/streaming.rs b/datafusion/physical-plan/src/streaming.rs index 7b062ab8741f..d7e254c42fe1 100644 --- a/datafusion/physical-plan/src/streaming.rs +++ b/datafusion/physical-plan/src/streaming.rs @@ -191,6 +191,10 @@ impl DisplayAs for StreamingTableExec { #[async_trait] impl ExecutionPlan for StreamingTableExec { + fn name(&self) -> &'static str { + "StreamingTableExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/union.rs b/datafusion/physical-plan/src/union.rs index 64322bd5f101..69901aa2fa37 100644 --- a/datafusion/physical-plan/src/union.rs +++ b/datafusion/physical-plan/src/union.rs @@ -183,6 +183,10 @@ impl DisplayAs for UnionExec { } impl ExecutionPlan for UnionExec { + fn name(&self) -> &'static str { + "UnionExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self @@ -370,6 +374,10 @@ impl DisplayAs for InterleaveExec { } impl ExecutionPlan for InterleaveExec { + fn name(&self) -> &'static str { + "InterleaveExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs index 886b718e6efe..324e2ea2d773 100644 --- a/datafusion/physical-plan/src/unnest.rs +++ b/datafusion/physical-plan/src/unnest.rs @@ -112,6 +112,10 @@ impl DisplayAs for UnnestExec { } impl ExecutionPlan for UnnestExec { + fn name(&self) -> &'static str { + "UnnestExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 8868a59008b7..63e8c32349ab 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -154,6 +154,10 @@ impl DisplayAs for ValuesExec { } impl ExecutionPlan for ValuesExec { + fn name(&self) -> &'static str { + "ValuesExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs index 70b6182d81e7..75e203891cad 100644 --- a/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/bounded_window_agg_exec.rs @@ -237,6 +237,10 @@ impl DisplayAs for BoundedWindowAggExec { } impl ExecutionPlan for BoundedWindowAggExec { + fn name(&self) -> &'static str { + "BoundedWindowAggExec" + } + /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/windows/window_agg_exec.rs b/datafusion/physical-plan/src/windows/window_agg_exec.rs index e300eee49d31..46ba21bd797e 100644 --- a/datafusion/physical-plan/src/windows/window_agg_exec.rs +++ b/datafusion/physical-plan/src/windows/window_agg_exec.rs @@ -172,6 +172,10 @@ impl DisplayAs for WindowAggExec { } impl ExecutionPlan for WindowAggExec { + fn name(&self) -> &'static str { + "WindowAggExec" + } + /// Return a reference to Any that can be 
used for downcasting fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/physical-plan/src/work_table.rs b/datafusion/physical-plan/src/work_table.rs index f6fc0334dfc5..dfdb624a5625 100644 --- a/datafusion/physical-plan/src/work_table.rs +++ b/datafusion/physical-plan/src/work_table.rs @@ -157,6 +157,10 @@ impl DisplayAs for WorkTableExec { } impl ExecutionPlan for WorkTableExec { + fn name(&self) -> &'static str { + "WorkTableExec" + } + fn as_any(&self) -> &dyn Any { self } diff --git a/docs/source/library-user-guide/custom-table-providers.md b/docs/source/library-user-guide/custom-table-providers.md index 9da207da68f3..11024f77e0d0 100644 --- a/docs/source/library-user-guide/custom-table-providers.md +++ b/docs/source/library-user-guide/custom-table-providers.md @@ -46,6 +46,10 @@ struct CustomExec { } impl ExecutionPlan for CustomExec { + fn name(&self) -> &'static str { + "CustomExec" + } + fn execute( &self, _partition: usize, From b6915f560aa8a48c5e4784b14d38d6809f858585 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 28 Mar 2024 14:11:07 -0400 Subject: [PATCH 086/117] chore(deps-dev): bump express (#9826) Bumps [express](https://github.com/expressjs/express) from 4.18.2 to 4.19.2. - [Release notes](https://github.com/expressjs/express/releases) - [Changelog](https://github.com/expressjs/express/blob/master/History.md) - [Commits](https://github.com/expressjs/express/compare/4.18.2...4.19.2) --- updated-dependencies: - dependency-name: express dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../datafusion-wasm-app/package-lock.json | 350 +++++++++++++----- 1 file changed, 265 insertions(+), 85 deletions(-) diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json index aac87845bc9f..8b1b8ae079c2 100644 --- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json +++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json @@ -731,13 +731,13 @@ } }, "node_modules/body-parser": { - "version": "1.20.1", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.1.tgz", - "integrity": "sha512-jWi7abTbYwajOytWCQc37VulmWiRae5RyTpaCyDcS5/lMdtwSz5lOpDE67srw/HYe35f1z3fDQw+3txg7gNtWw==", + "version": "1.20.2", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", + "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", "dev": true, "dependencies": { "bytes": "3.1.2", - "content-type": "~1.0.4", + "content-type": "~1.0.5", "debug": "2.6.9", "depd": "2.0.0", "destroy": "1.2.0", @@ -745,7 +745,7 @@ "iconv-lite": "0.4.24", "on-finished": "2.4.1", "qs": "6.11.0", - "raw-body": "2.5.1", + "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" }, @@ -892,13 +892,19 @@ } }, "node_modules/call-bind": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", - "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", + "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", "dev": true, "dependencies": { "es-define-property": "^1.0.0", "es-errors": "^1.3.0", 
"function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1109,9 +1115,9 @@ } }, "node_modules/cookie": { - "version": "0.5.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz", - "integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==", + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", "dev": true, "engines": { "node": ">= 0.6" @@ -1204,6 +1210,23 @@ "node": ">= 10" } }, + "node_modules/define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "dependencies": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/define-lazy-prop": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", @@ -1323,6 +1346,27 @@ "node": ">=4" } }, + "node_modules/es-define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", + "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.2.4" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/es-module-lexer": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", @@ -1435,17 +1479,17 @@ } }, "node_modules/express": { - "version": "4.18.2", - "resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz", - "integrity": "sha512-5/PsL6iGPdfQ/lKM1UuielYgv3BUoJfz1aUwU9vHZ+J7gyvwdQXFEBIEIaxeGf0GIcreATNyBExtalisDbuMqQ==", + "version": "4.19.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.19.2.tgz", + "integrity": "sha512-5T6nhjsT+EOMzuck8JjBHARTHfMht0POzlA60WV2pMD3gyXw2LZnZ+ueGdNxG+0calOJcWKbpFcuzLZ91YWq9Q==", "dev": true, "dependencies": { "accepts": "~1.3.8", "array-flatten": "1.1.1", - "body-parser": "1.20.1", + "body-parser": "1.20.2", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.5.0", + "cookie": "0.6.0", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", @@ -1742,21 +1786,28 @@ } }, "node_modules/function-bind": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", - "integrity": "sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", - "dev": true + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", + "dev": true, + 
"funding": { + "url": "https://github.com/sponsors/ljharb" + } }, "node_modules/get-intrinsic": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.1.tgz", - "integrity": "sha512-2DcsyfABl+gVHEfCOaTrWgyt+tb6MSEGmKq+kI5HwLbIYgjgmMcV8KQ41uaKz1xxUcn9tJtgFbQUEVcEbd0FYw==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", + "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", "dev": true, "dependencies": { - "function-bind": "^1.1.1", - "has": "^1.0.3", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", "has-proto": "^1.0.1", - "has-symbols": "^1.0.3" + "has-symbols": "^1.0.3", + "hasown": "^2.0.0" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1832,6 +1883,18 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/gopd": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", + "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.1.3" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/graceful-fs": { "version": "4.2.11", "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", @@ -1865,10 +1928,22 @@ "node": ">=8" } }, + "node_modules/has-property-descriptors": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", + "dev": true, + "dependencies": { + "es-define-property": "^1.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/has-proto": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.1.tgz", - "integrity": "sha512-7qE+iP+O+bgF9clE5+UoBFzE65mlBiVj3tKCrlNQ0Ogwm0BjpT/gK4SlLYDMybDh5I3TCTKnPPa0oMG7JDYrhg==", + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", + "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", "dev": true, "engines": { "node": ">= 0.4" @@ -1889,6 +1964,18 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/hpack.js": { "version": "2.1.6", "resolved": "https://registry.npmjs.org/hpack.js/-/hpack.js-2.1.6.tgz", @@ -2648,9 +2735,9 @@ } }, "node_modules/object-inspect": { - "version": "1.12.3", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", - "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "version": "1.13.1", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.1.tgz", + "integrity": "sha512-5qoj1RUiKOMsCCNLV1CBiPYE10sziTsnmNxkAI/rZhiD63CF7IqdFGC/XzjWjpSgLf0LxXX3bDFIh0E18f6UhQ==", "dev": true, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -2987,9 +3074,9 @@ } }, 
"node_modules/raw-body": { - "version": "2.5.1", - "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz", - "integrity": "sha512-qqJBtEyVgS0ZmPGdCFPWJ3FreoqvG4MVQln/kCgF7Olq95IbOp0/BWyMwbdtn4VTvkM8Y7khCQ2Xgk/tcrCXig==", + "version": "2.5.2", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.2.tgz", + "integrity": "sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==", "dev": true, "dependencies": { "bytes": "3.1.2", @@ -3357,6 +3444,23 @@ "node": ">= 0.8.0" } }, + "node_modules/set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -3406,14 +3510,18 @@ } }, "node_modules/side-channel": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", - "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", + "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", "dev": true, "dependencies": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - "object-inspect": "^1.9.0" + "call-bind": "^1.0.7", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.4", + "object-inspect": "^1.13.1" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -4868,13 +4976,13 @@ "dev": true }, "body-parser": { - "version": "1.20.1", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.1.tgz", - "integrity": "sha512-jWi7abTbYwajOytWCQc37VulmWiRae5RyTpaCyDcS5/lMdtwSz5lOpDE67srw/HYe35f1z3fDQw+3txg7gNtWw==", + "version": "1.20.2", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", + "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", "dev": true, "requires": { "bytes": "3.1.2", - "content-type": "~1.0.4", + "content-type": "~1.0.5", "debug": "2.6.9", "depd": "2.0.0", "destroy": "1.2.0", @@ -4882,7 +4990,7 @@ "iconv-lite": "0.4.24", "on-finished": "2.4.1", "qs": "6.11.0", - "raw-body": "2.5.1", + "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" }, @@ -4992,13 +5100,16 @@ } }, "call-bind": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", - "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", + "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", "dev": true, "requires": { - "function-bind": "^1.1.1", - "get-intrinsic": "^1.0.2" + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "set-function-length": 
"^1.2.1" } }, "caniuse-lite": { @@ -5144,9 +5255,9 @@ "dev": true }, "cookie": { - "version": "0.5.0", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.5.0.tgz", - "integrity": "sha512-YZ3GUyn/o8gfKJlnlX7g7xq4gyO6OSuhGPKaaGssGB2qgDUS0gPgtTvoyZLTt9Ab6dC4hfc9dV5arkvc/OCmrw==", + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", + "integrity": "sha512-U71cyTamuh1CRNCfpGY6to28lxvNwPG4Guz/EVjgf3Jmzv0vlDp1atT9eS5dDjMYHucpHbWns6Lwf3BKz6svdw==", "dev": true }, "cookie-signature": { @@ -5220,6 +5331,17 @@ "execa": "^5.0.0" } }, + "define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "requires": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + } + }, "define-lazy-prop": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-2.0.0.tgz", @@ -5308,6 +5430,21 @@ "integrity": "sha512-ZtUjZO6l5mwTHvc1L9+1q5p/R3wTopcfqMW8r5t8SJSKqeVI/LtajORwRFEKpEFuekjD0VBjwu1HMxL4UalIRw==", "dev": true }, + "es-define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", + "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "dev": true, + "requires": { + "get-intrinsic": "^1.2.4" + } + }, + "es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true + }, "es-module-lexer": { "version": "1.3.1", "resolved": "https://registry.npmjs.org/es-module-lexer/-/es-module-lexer-1.3.1.tgz", @@ -5395,17 +5532,17 @@ } }, "express": { - "version": "4.18.2", - "resolved": "https://registry.npmjs.org/express/-/express-4.18.2.tgz", - "integrity": "sha512-5/PsL6iGPdfQ/lKM1UuielYgv3BUoJfz1aUwU9vHZ+J7gyvwdQXFEBIEIaxeGf0GIcreATNyBExtalisDbuMqQ==", + "version": "4.19.2", + "resolved": "https://registry.npmjs.org/express/-/express-4.19.2.tgz", + "integrity": "sha512-5T6nhjsT+EOMzuck8JjBHARTHfMht0POzlA60WV2pMD3gyXw2LZnZ+ueGdNxG+0calOJcWKbpFcuzLZ91YWq9Q==", "dev": true, "requires": { "accepts": "~1.3.8", "array-flatten": "1.1.1", - "body-parser": "1.20.1", + "body-parser": "1.20.2", "content-disposition": "0.5.4", "content-type": "~1.0.4", - "cookie": "0.5.0", + "cookie": "0.6.0", "cookie-signature": "1.0.6", "debug": "2.6.9", "depd": "2.0.0", @@ -5626,21 +5763,22 @@ "optional": true }, "function-bind": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.1.tgz", - "integrity": "sha512-yIovAzMX49sF8Yl58fSCWJ5svSLuaibPxXQJFLmBObTuCr0Mf1KiPopGM9NiFjiYBCbfaa2Fh6breQ6ANVTI0A==", + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", + "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", "dev": true }, "get-intrinsic": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.1.tgz", - "integrity": "sha512-2DcsyfABl+gVHEfCOaTrWgyt+tb6MSEGmKq+kI5HwLbIYgjgmMcV8KQ41uaKz1xxUcn9tJtgFbQUEVcEbd0FYw==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", + 
"integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", "dev": true, "requires": { - "function-bind": "^1.1.1", - "has": "^1.0.3", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", "has-proto": "^1.0.1", - "has-symbols": "^1.0.3" + "has-symbols": "^1.0.3", + "hasown": "^2.0.0" } }, "get-stream": { @@ -5692,6 +5830,15 @@ "slash": "^3.0.0" } }, + "gopd": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", + "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "dev": true, + "requires": { + "get-intrinsic": "^1.1.3" + } + }, "graceful-fs": { "version": "4.2.11", "resolved": "https://registry.npmjs.org/graceful-fs/-/graceful-fs-4.2.11.tgz", @@ -5719,10 +5866,19 @@ "integrity": "sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==", "dev": true }, + "has-property-descriptors": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", + "dev": true, + "requires": { + "es-define-property": "^1.0.0" + } + }, "has-proto": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.1.tgz", - "integrity": "sha512-7qE+iP+O+bgF9clE5+UoBFzE65mlBiVj3tKCrlNQ0Ogwm0BjpT/gK4SlLYDMybDh5I3TCTKnPPa0oMG7JDYrhg==", + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", + "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", "dev": true }, "has-symbols": { @@ -5731,6 +5887,15 @@ "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", "dev": true }, + "hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "requires": { + "function-bind": "^1.1.2" + } + }, "hpack.js": { "version": "2.1.6", "resolved": "https://registry.npmjs.org/hpack.js/-/hpack.js-2.1.6.tgz", @@ -6284,9 +6449,9 @@ } }, "object-inspect": { - "version": "1.12.3", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", - "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "version": "1.13.1", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.1.tgz", + "integrity": "sha512-5qoj1RUiKOMsCCNLV1CBiPYE10sziTsnmNxkAI/rZhiD63CF7IqdFGC/XzjWjpSgLf0LxXX3bDFIh0E18f6UhQ==", "dev": true }, "obuf": { @@ -6523,9 +6688,9 @@ "dev": true }, "raw-body": { - "version": "2.5.1", - "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.1.tgz", - "integrity": "sha512-qqJBtEyVgS0ZmPGdCFPWJ3FreoqvG4MVQln/kCgF7Olq95IbOp0/BWyMwbdtn4VTvkM8Y7khCQ2Xgk/tcrCXig==", + "version": "2.5.2", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-2.5.2.tgz", + "integrity": "sha512-8zGqypfENjCIqGhgXToC8aB2r7YrBX+AQAfIPs/Mlk+BtPTztOvTS01NRW/3Eh60J+a48lt8qsCzirQ6loCVfA==", "dev": true, "requires": { "bytes": "3.1.2", @@ -6813,6 +6978,20 @@ "send": "0.18.0" } }, + "set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + 
"integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "requires": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + } + }, "setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -6850,14 +7029,15 @@ "dev": true }, "side-channel": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", - "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", + "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", "dev": true, "requires": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - "object-inspect": "^1.9.0" + "call-bind": "^1.0.7", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.4", + "object-inspect": "^1.13.1" } }, "signal-exit": { From 6f9948b8c027f782431805331de174c4092de40a Mon Sep 17 00:00:00 2001 From: Trent Hauck Date: Thu, 28 Mar 2024 13:40:36 -0700 Subject: [PATCH 087/117] feat: pass SessionState not SessionConfig to FunctionFactory::create (#9837) --- datafusion-examples/examples/function_factory.rs | 7 ++++--- datafusion/core/src/execution/context/mod.rs | 4 ++-- .../tests/user_defined/user_defined_scalar_functions.rs | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/datafusion-examples/examples/function_factory.rs b/datafusion-examples/examples/function_factory.rs index 6c033e6c8eef..a7c8558c6da8 100644 --- a/datafusion-examples/examples/function_factory.rs +++ b/datafusion-examples/examples/function_factory.rs @@ -16,8 +16,9 @@ // under the License. use datafusion::error::Result; -use datafusion::execution::config::SessionConfig; -use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionContext}; +use datafusion::execution::context::{ + FunctionFactory, RegisterFunction, SessionContext, SessionState, +}; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{exec_err, internal_err, DataFusionError}; use datafusion_expr::simplify::ExprSimplifyResult; @@ -91,7 +92,7 @@ impl FunctionFactory for CustomFunctionFactory { /// the function instance. 
From 81c96fc3db0ea35638278f32df066be63b745a51 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 28 Mar 2024 17:37:25 -0600 Subject: [PATCH 088/117] Prepare 37.0.0 Release (#9697) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * bump version * changelog * Update configs.md * Update Cargo.toml * Update 37.0.0.md * Update 37.0.0.md * Update 37.0.0.md * update changelog * update changelog --------- Co-authored-by: Daniël Heres --- Cargo.toml | 30 +-- datafusion-cli/Cargo.lock | 24 +-- datafusion-cli/Cargo.toml | 4 +- datafusion/CHANGELOG.md | 1 + dev/changelog/37.0.0.md | 347 ++++++++++++++++++++++++++++++ docs/source/user-guide/configs.md | 2 +- 6 files changed, 378 insertions(+), 30 deletions(-) create mode 100644 dev/changelog/37.0.0.md diff --git a/Cargo.toml b/Cargo.toml index c3dade8bc6c5..8e89e5ef3b85 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,7 +49,7 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/apache/arrow-datafusion" rust-version = "1.72" -version = "36.0.0" +version = "37.0.0" [workspace.dependencies] # We turn off default-features for some dependencies here so the workspaces which inherit them can @@ -71,20 +71,20 @@ bytes = "1.4" chrono = { version = "0.4.34", default-features = false } ctor = "0.2.0" dashmap = "5.4.0" -datafusion = { path = "datafusion/core", version = "36.0.0", default-features = false } -datafusion-common = { path = "datafusion/common", version = "36.0.0", default-features = false } -datafusion-common-runtime = { path = "datafusion/common-runtime", version = "36.0.0" } -datafusion-execution = { path = "datafusion/execution", version = "36.0.0" } -datafusion-expr = { path = "datafusion/expr", version = "36.0.0" } -datafusion-functions = { path = "datafusion/functions", version = "36.0.0" } -datafusion-functions-array = { path = "datafusion/functions-array", version = "36.0.0" } -datafusion-optimizer = { path = 
"datafusion/optimizer", version = "36.0.0", default-features = false } -datafusion-physical-expr = { path = "datafusion/physical-expr", version = "36.0.0", default-features = false } -datafusion-physical-plan = { path = "datafusion/physical-plan", version = "36.0.0" } -datafusion-proto = { path = "datafusion/proto", version = "36.0.0" } -datafusion-sql = { path = "datafusion/sql", version = "36.0.0" } -datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "36.0.0" } -datafusion-substrait = { path = "datafusion/substrait", version = "36.0.0" } +datafusion = { path = "datafusion/core", version = "37.0.0", default-features = false } +datafusion-common = { path = "datafusion/common", version = "37.0.0", default-features = false } +datafusion-common-runtime = { path = "datafusion/common-runtime", version = "37.0.0" } +datafusion-execution = { path = "datafusion/execution", version = "37.0.0" } +datafusion-expr = { path = "datafusion/expr", version = "37.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "37.0.0" } +datafusion-functions-array = { path = "datafusion/functions-array", version = "37.0.0" } +datafusion-optimizer = { path = "datafusion/optimizer", version = "37.0.0", default-features = false } +datafusion-physical-expr = { path = "datafusion/physical-expr", version = "37.0.0", default-features = false } +datafusion-physical-plan = { path = "datafusion/physical-plan", version = "37.0.0" } +datafusion-proto = { path = "datafusion/proto", version = "37.0.0" } +datafusion-sql = { path = "datafusion/sql", version = "37.0.0" } +datafusion-sqllogictest = { path = "datafusion/sqllogictest", version = "37.0.0" } +datafusion-substrait = { path = "datafusion/substrait", version = "37.0.0" } doc-comment = "0.3" env_logger = "0.11" futures = "0.3" diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index ba60c04cea55..0277d23f4de0 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1116,7 +1116,7 @@ dependencies = [ [[package]] name = "datafusion" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "apache-avro", @@ -1167,7 +1167,7 @@ dependencies = [ [[package]] name = "datafusion-cli" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "assert_cmd", @@ -1195,7 +1195,7 @@ dependencies = [ [[package]] name = "datafusion-common" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "apache-avro", @@ -1215,14 +1215,14 @@ dependencies = [ [[package]] name = "datafusion-common-runtime" -version = "36.0.0" +version = "37.0.0" dependencies = [ "tokio", ] [[package]] name = "datafusion-execution" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "chrono", @@ -1241,7 +1241,7 @@ dependencies = [ [[package]] name = "datafusion-expr" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "arrow", @@ -1256,7 +1256,7 @@ dependencies = [ [[package]] name = "datafusion-functions" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "base64 0.22.0", @@ -1279,7 +1279,7 @@ dependencies = [ [[package]] name = "datafusion-functions-array" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "arrow-array", @@ -1297,7 +1297,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "async-trait", @@ -1313,7 +1313,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "arrow", @@ -1346,7 +1346,7 @@ 
dependencies = [ [[package]] name = "datafusion-physical-plan" -version = "36.0.0" +version = "37.0.0" dependencies = [ "ahash", "arrow", @@ -1375,7 +1375,7 @@ dependencies = [ [[package]] name = "datafusion-sql" -version = "36.0.0" +version = "37.0.0" dependencies = [ "arrow", "arrow-array", diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index da744a06f3aa..18e14357314e 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "datafusion-cli" description = "Command Line Client for DataFusion query engine." -version = "36.0.0" +version = "37.0.0" authors = ["Apache Arrow "] edition = "2021" keywords = ["arrow", "datafusion", "query", "sql"] @@ -35,7 +35,7 @@ async-trait = "0.1.41" aws-config = "0.55" aws-credential-types = "0.55" clap = { version = "3", features = ["derive", "cargo"] } -datafusion = { path = "../datafusion/core", version = "36.0.0", features = [ +datafusion = { path = "../datafusion/core", version = "37.0.0", features = [ "avro", "crypto_expressions", "datetime_expressions", diff --git a/datafusion/CHANGELOG.md b/datafusion/CHANGELOG.md index 2d09782a3982..c111375e3058 100644 --- a/datafusion/CHANGELOG.md +++ b/datafusion/CHANGELOG.md @@ -19,6 +19,7 @@ # Changelog +- [37.0.0](../dev/changelog/37.0.0.md) - [36.0.0](../dev/changelog/36.0.0.md) - [35.0.0](../dev/changelog/35.0.0.md) - [34.0.0](../dev/changelog/34.0.0.md) diff --git a/dev/changelog/37.0.0.md b/dev/changelog/37.0.0.md new file mode 100644 index 000000000000..b1fcd5fdf008 --- /dev/null +++ b/dev/changelog/37.0.0.md @@ -0,0 +1,347 @@ + + +## [37.0.0](https://github.com/apache/arrow-datafusion/tree/37.0.0) (2024-03-28) + +[Full Changelog](https://github.com/apache/arrow-datafusion/compare/36.0.0...37.0.0) + +**Breaking changes:** + +- refactor: Change `SchemaProvider::table` to return `Result` rather than `Option<..>` [#9307](https://github.com/apache/arrow-datafusion/pull/9307) (crepererum) +- feat: issue_9285: port builtin reg function into datafusion-function-\* (1/3 regexpmatch) [#9329](https://github.com/apache/arrow-datafusion/pull/9329) (Lordworms) +- Cache common plan properties to eliminate recursive calls in physical plan [#9346](https://github.com/apache/arrow-datafusion/pull/9346) (mustafasrepo) +- Consolidate `TreeNode` transform and rewrite APIs [#8891](https://github.com/apache/arrow-datafusion/pull/8891) (peter-toth) +- Extend argument types for udf `return_type_from_exprs` [#9522](https://github.com/apache/arrow-datafusion/pull/9522) (jayzhan211) +- Systematic Configuration in 'Create External Table' and 'Copy To' Options [#9382](https://github.com/apache/arrow-datafusion/pull/9382) (metesynnada) +- Move trim functions (btrim, ltrim, rtrim) to datafusion_functions, make expr_fn API consistent [#9730](https://github.com/apache/arrow-datafusion/pull/9730) (Omega359) + +**Performance related:** + +- perf: improve to_field performance [#9722](https://github.com/apache/arrow-datafusion/pull/9722) (haohuaijin) + +**Implemented enhancements:** + +- feat: support for defining ARRAY columns in `CREATE TABLE` [#9381](https://github.com/apache/arrow-datafusion/pull/9381) (jonahgao) +- feat: support `unnest` in FROM clause [#9355](https://github.com/apache/arrow-datafusion/pull/9355) (jonahgao) +- feat: support nvl2 function [#9364](https://github.com/apache/arrow-datafusion/pull/9364) (guojidan) +- feat: issue #9224 substitute tlide in table path [#9259](https://github.com/apache/arrow-datafusion/pull/9259) (Lordworms) +- feat: replace std 
Instant with wasm-compatible wrapper [#9189](https://github.com/apache/arrow-datafusion/pull/9189) (waynexia) +- feat: support `unnest` with additional columns [#9400](https://github.com/apache/arrow-datafusion/pull/9400) (jonahgao) +- feat: Support `EscapedStringLiteral`, update sqlparser to `0.44.0` [#9268](https://github.com/apache/arrow-datafusion/pull/9268) (JasonLi-cn) +- feat: add support for fixed list wildcard in type signature [#9312](https://github.com/apache/arrow-datafusion/pull/9312) (universalmind303) +- feat: Add projection to HashJoinExec. [#9236](https://github.com/apache/arrow-datafusion/pull/9236) (my-vegetable-has-exploded) +- feat: function name hints for UDFs [#9407](https://github.com/apache/arrow-datafusion/pull/9407) (SteveLauC) +- feat: Introduce convert Expr to SQL string API and basic feature [#9517](https://github.com/apache/arrow-datafusion/pull/9517) (backkem) +- feat: implement more expr_to_sql functionality [#9578](https://github.com/apache/arrow-datafusion/pull/9578) (devinjdangelo) +- feat: implement aggregation and subquery plans to SQL [#9606](https://github.com/apache/arrow-datafusion/pull/9606) (devinjdangelo) +- feat: track memory usage for recursive CTE, enable recursive CTEs by default [#9619](https://github.com/apache/arrow-datafusion/pull/9619) (jonahgao) +- feat: Between expr to sql string [#9803](https://github.com/apache/arrow-datafusion/pull/9803) (sebastian2296) +- feat: Expose `array_empty` and `list_empty` functions as alias of `empty` function [#9807](https://github.com/apache/arrow-datafusion/pull/9807) (erenavsarogullari) +- feat: Not expr to string [#9802](https://github.com/apache/arrow-datafusion/pull/9802) (sebastian2296) +- feat: pass SessionState not SessionConfig to FunctionFactory::create [#9837](https://github.com/apache/arrow-datafusion/pull/9837) (tshauck) + +**Fixed bugs:** + +- fix: use `JoinSet` to make spawned tasks cancel-safe [#9318](https://github.com/apache/arrow-datafusion/pull/9318) (DDtKey) +- fix: nvl function's return type [#9357](https://github.com/apache/arrow-datafusion/pull/9357) (guojidan) +- fix: panic in isnan() when no args are given [#9377](https://github.com/apache/arrow-datafusion/pull/9377) (SteveLauC) +- fix: using test data sample for catalog example [#9372](https://github.com/apache/arrow-datafusion/pull/9372) (korowa) +- fix: sort_batch function unsupported mixed types with list [#9410](https://github.com/apache/arrow-datafusion/pull/9410) (JasonLi-cn) +- fix: casting to ARRAY types failed [#9441](https://github.com/apache/arrow-datafusion/pull/9441) (jonahgao) +- fix: reading from partitioned `json` & `arrow` tables [#9431](https://github.com/apache/arrow-datafusion/pull/9431) (korowa) +- fix: coalesce function should return correct data type [#9459](https://github.com/apache/arrow-datafusion/pull/9459) (viirya) +- fix: `generate_series` and `range` panic on edge cases [#9503](https://github.com/apache/arrow-datafusion/pull/9503) (jonahgao) +- fix: `substr_index` not handling negative occurrence correctly [#9475](https://github.com/apache/arrow-datafusion/pull/9475) (jonahgao) +- fix: support two argument TRIM [#9521](https://github.com/apache/arrow-datafusion/pull/9521) (tshauck) +- fix: incorrect null handling in `range` and `generate_series` [#9574](https://github.com/apache/arrow-datafusion/pull/9574) (jonahgao) +- fix: recursive cte hangs on joins [#9687](https://github.com/apache/arrow-datafusion/pull/9687) (jonahgao) +- fix: parallel parquet can underflow when max_record_batch_rows < 
execution.batch_size [#9737](https://github.com/apache/arrow-datafusion/pull/9737) (devinjdangelo) +- fix: change placeholder errors from Internal to Plan [#9745](https://github.com/apache/arrow-datafusion/pull/9745) (erratic-pattern) +- fix: ensure mutual compatibility of the two input schemas from recursive CTEs [#9795](https://github.com/apache/arrow-datafusion/pull/9795) (jonahgao) + +**Documentation updates:** + +- docs: put flatten in top fn list [#9376](https://github.com/apache/arrow-datafusion/pull/9376) (SteveLauC) +- Update documentation so list_to_string alias to point to array_to_string [#9374](https://github.com/apache/arrow-datafusion/pull/9374) (monkwire) +- Uplift keys/dependencies to use more workspace inheritance [#9293](https://github.com/apache/arrow-datafusion/pull/9293) (Jefffrey) +- docs: update contributor guide (migration to sqllogictest is done) [#9408](https://github.com/apache/arrow-datafusion/pull/9408) (SteveLauC) +- Move the to_timestamp\* functions to datafusion-functions [#9388](https://github.com/apache/arrow-datafusion/pull/9388) (Omega359) +- NEW Logo [#9385](https://github.com/apache/arrow-datafusion/pull/9385) (pinarbayata) +- Minor: docs: rm duplicate words. [#9449](https://github.com/apache/arrow-datafusion/pull/9449) (my-vegetable-has-exploded) +- Update contributor guide with updated scalar function howto [#9438](https://github.com/apache/arrow-datafusion/pull/9438) (Omega359) +- docs: fix extraneous char in array functions table of contents [#9560](https://github.com/apache/arrow-datafusion/pull/9560) (tshauck) +- doc: Add missing doc link [#9631](https://github.com/apache/arrow-datafusion/pull/9631) (Weijun-H) +- chore: remove repetitive word `the the` --> `the` in docs / comments [#9673](https://github.com/apache/arrow-datafusion/pull/9673) (InventiveCoder) +- Update example-usage.md to remove reference to simd and rust nightly. 
[#9677](https://github.com/apache/arrow-datafusion/pull/9677) (Omega359) +- Minor: Improve documentation for `LogicalPlan::expressions` [#9698](https://github.com/apache/arrow-datafusion/pull/9698) (alamb) +- Add Minimum Supported Rust Version policy to docs [#9681](https://github.com/apache/arrow-datafusion/pull/9681) (alamb) +- doc: Updated known users list and usage dependency description [#9718](https://github.com/apache/arrow-datafusion/pull/9718) (comphead) + +**Merged pull requests:** + +- refactor: Change `SchemaProvider::table` to return `Result` rather than `Option<..>` [#9307](https://github.com/apache/arrow-datafusion/pull/9307) (crepererum) +- fix write_partitioned_parquet_results test case bug [#9360](https://github.com/apache/arrow-datafusion/pull/9360) (guojidan) +- fix: use `JoinSet` to make spawned tasks cancel-safe [#9318](https://github.com/apache/arrow-datafusion/pull/9318) (DDtKey) +- Update nix requirement from 0.27.1 to 0.28.0 [#9344](https://github.com/apache/arrow-datafusion/pull/9344) (dependabot[bot]) +- Replace usages of internal_err with exec_err where appropriate [#9241](https://github.com/apache/arrow-datafusion/pull/9241) (Omega359) +- feat : Support for deregistering user defined functions [#9239](https://github.com/apache/arrow-datafusion/pull/9239) (mobley-trent) +- fix: nvl function's return type [#9357](https://github.com/apache/arrow-datafusion/pull/9357) (guojidan) +- refactor: move acos() to function crate [#9297](https://github.com/apache/arrow-datafusion/pull/9297) (SteveLauC) +- docs: put flatten in top fn list [#9376](https://github.com/apache/arrow-datafusion/pull/9376) (SteveLauC) +- Update documentation so list_to_string alias to point to array_to_string [#9374](https://github.com/apache/arrow-datafusion/pull/9374) (monkwire) +- feat: issue_9285: port builtin reg function into datafusion-function-\* (1/3 regexpmatch) [#9329](https://github.com/apache/arrow-datafusion/pull/9329) (Lordworms) +- Add test to verify issue #9161 [#9265](https://github.com/apache/arrow-datafusion/pull/9265) (jonahgao) +- refactor: fix error macros hygiene (always import `DataFusionError`) [#9366](https://github.com/apache/arrow-datafusion/pull/9366) (crepererum) +- feat: support for defining ARRAY columns in `CREATE TABLE` [#9381](https://github.com/apache/arrow-datafusion/pull/9381) (jonahgao) +- fix: panic in isnan() when no args are given [#9377](https://github.com/apache/arrow-datafusion/pull/9377) (SteveLauC) +- feat: support `unnest` in FROM clause [#9355](https://github.com/apache/arrow-datafusion/pull/9355) (jonahgao) +- feat: support nvl2 function [#9364](https://github.com/apache/arrow-datafusion/pull/9364) (guojidan) +- refactor: move asin() to function crate [#9379](https://github.com/apache/arrow-datafusion/pull/9379) (SteveLauC) +- fix: using test data sample for catalog example [#9372](https://github.com/apache/arrow-datafusion/pull/9372) (korowa) +- delete tail space, fix `error: unused import: DataFusionError` [#9386](https://github.com/apache/arrow-datafusion/pull/9386) (Tangruilin) +- Run cargo-fmt on `datafusion-functions/core` [#9367](https://github.com/apache/arrow-datafusion/pull/9367) (alamb) +- Cache common plan properties to eliminate recursive calls in physical plan [#9346](https://github.com/apache/arrow-datafusion/pull/9346) (mustafasrepo) +- Run cargo-fmt on all of `datafusion-functions` [#9390](https://github.com/apache/arrow-datafusion/pull/9390) (alamb) +- feat: issue #9224 substitute tlide in table path 
[#9259](https://github.com/apache/arrow-datafusion/pull/9259) (Lordworms) +- port range function and change gen_series logic [#9352](https://github.com/apache/arrow-datafusion/pull/9352) (Lordworms) +- [MINOR]: Generate physical plan, instead of logical plan in the bench test [#9383](https://github.com/apache/arrow-datafusion/pull/9383) (mustafasrepo) +- Add `to_date` function [#9019](https://github.com/apache/arrow-datafusion/pull/9019) (Tangruilin) +- Minor: clarify performance in docs for `ScalarUDF`, `ScalarUDAF` and `ScalarUDWF` [#9384](https://github.com/apache/arrow-datafusion/pull/9384) (alamb) +- feat: replace std Instant with wasm-compatible wrapper [#9189](https://github.com/apache/arrow-datafusion/pull/9189) (waynexia) +- Uplift keys/dependencies to use more workspace inheritance [#9293](https://github.com/apache/arrow-datafusion/pull/9293) (Jefffrey) +- Improve documentation for ExecutionPlanProperties, use consistent field name [#9389](https://github.com/apache/arrow-datafusion/pull/9389) (alamb) +- Doc: Workaround for Running cargo test locally without signficant memory [#9402](https://github.com/apache/arrow-datafusion/pull/9402) (devinjdangelo) +- feat: support `unnest` with additional columns [#9400](https://github.com/apache/arrow-datafusion/pull/9400) (jonahgao) +- Minor: improve the display name of `unnest` expressions [#9412](https://github.com/apache/arrow-datafusion/pull/9412) (jonahgao) +- Minor: Move function signature check to planning stage [#9401](https://github.com/apache/arrow-datafusion/pull/9401) (2010YOUY01) +- chore(deps): update substrait requirement from 0.24.0 to 0.25.1 [#9406](https://github.com/apache/arrow-datafusion/pull/9406) (dependabot[bot]) +- docs: update contributor guide (migration to sqllogictest is done) [#9408](https://github.com/apache/arrow-datafusion/pull/9408) (SteveLauC) +- Move the to_timestamp\* functions to datafusion-functions [#9388](https://github.com/apache/arrow-datafusion/pull/9388) (Omega359) +- Minor: Support LargeList List Range indexing and fix large list handling in ConstEvaluator [#9393](https://github.com/apache/arrow-datafusion/pull/9393) (jayzhan211) +- NEW Logo [#9385](https://github.com/apache/arrow-datafusion/pull/9385) (pinarbayata) +- Handle serde for ScalarUDF [#9395](https://github.com/apache/arrow-datafusion/pull/9395) (yyy1000) +- Minior: Add tests with `sqrt` with negative argument [#9426](https://github.com/apache/arrow-datafusion/pull/9426) (caicancai) +- Move SpawnedTask from datafusion_physical_plan to new `datafusion_common_runtime` crate [#9414](https://github.com/apache/arrow-datafusion/pull/9414) (mustafasrepo) +- Re-export datafusion-functions-array [#9433](https://github.com/apache/arrow-datafusion/pull/9433) (andygrove) +- Minor: Support LargeList for ListIndex [#9424](https://github.com/apache/arrow-datafusion/pull/9424) (PsiACE) +- move ArrayDims, ArrayNdims and Cardinality to datafusion-function-crate [#9425](https://github.com/apache/arrow-datafusion/pull/9425) (Weijun-H) +- refactor: make instr() an alias of strpos() [#9396](https://github.com/apache/arrow-datafusion/pull/9396) (SteveLauC) +- Add test case for invalid tz in timestamp literal [#9429](https://github.com/apache/arrow-datafusion/pull/9429) (MohamedAbdeen21) +- Minor: simplify call [#9434](https://github.com/apache/arrow-datafusion/pull/9434) (alamb) +- Support IGNORE NULLS for LEAD window function [#9419](https://github.com/apache/arrow-datafusion/pull/9419) (comphead) +- fix sqllogicaltest result 
[#9444](https://github.com/apache/arrow-datafusion/pull/9444) (jackwener) +- Minor: docs: rm duplicate words. [#9449](https://github.com/apache/arrow-datafusion/pull/9449) (my-vegetable-has-exploded) +- minor: fix some cargo clippy warnings [#9442](https://github.com/apache/arrow-datafusion/pull/9442) (jackwener) +- port regexp_like function and port related tests [#9397](https://github.com/apache/arrow-datafusion/pull/9397) (Lordworms) +- fix: sort_batch function unsupported mixed types with list [#9410](https://github.com/apache/arrow-datafusion/pull/9410) (JasonLi-cn) +- refactor: add `join_unwind` to `SpawnedTask` [#9422](https://github.com/apache/arrow-datafusion/pull/9422) (DDtKey) +- Ignore null LEAD support for small batch sizes. [#9445](https://github.com/apache/arrow-datafusion/pull/9445) (mustafasrepo) +- fix: casting to ARRAY types failed [#9441](https://github.com/apache/arrow-datafusion/pull/9441) (jonahgao) +- fix: reading from partitioned `json` & `arrow` tables [#9431](https://github.com/apache/arrow-datafusion/pull/9431) (korowa) +- feat: Support `EscapedStringLiteral`, update sqlparser to `0.44.0` [#9268](https://github.com/apache/arrow-datafusion/pull/9268) (JasonLi-cn) +- Minor: fix LEAD test description [#9451](https://github.com/apache/arrow-datafusion/pull/9451) (comphead) +- Consolidate `TreeNode` transform and rewrite APIs [#8891](https://github.com/apache/arrow-datafusion/pull/8891) (peter-toth) +- Support `Date32` arguments for `generate_series` [#9420](https://github.com/apache/arrow-datafusion/pull/9420) (Lordworms) +- Minor: change doc for range [#9455](https://github.com/apache/arrow-datafusion/pull/9455) (Lordworms) +- doc: add missing function index in scalar_expression.md [#9462](https://github.com/apache/arrow-datafusion/pull/9462) (Weijun-H) +- build: Update bigdecimal version in `Cargo.toml` [#9471](https://github.com/apache/arrow-datafusion/pull/9471) (comphead) +- chore(deps): update base64 requirement from 0.21 to 0.22 [#9446](https://github.com/apache/arrow-datafusion/pull/9446) (dependabot[bot]) +- Port regexp_replace functions and related tests [#9454](https://github.com/apache/arrow-datafusion/pull/9454) (Lordworms) +- Update contributor guide with updated scalar function howto [#9438](https://github.com/apache/arrow-datafusion/pull/9438) (Omega359) +- feat: add support for fixed list wildcard in type signature [#9312](https://github.com/apache/arrow-datafusion/pull/9312) (universalmind303) +- Add a `ScalarUDFImpl::simplify()` API, move `SimplifyInfo` et al to datafusion_expr [#9304](https://github.com/apache/arrow-datafusion/pull/9304) (jayzhan211) +- Implement IGNORE NULLS for FIRST_VALUE [#9411](https://github.com/apache/arrow-datafusion/pull/9411) (huaxingao) +- Add pluggable handler for `CREATE FUNCTION` [#9333](https://github.com/apache/arrow-datafusion/pull/9333) (milenkovicm) +- Enable configurable display of partition sizes in the explain statement [#9474](https://github.com/apache/arrow-datafusion/pull/9474) (jayzhan211) +- Reduce casts for LEAD/LAG [#9468](https://github.com/apache/arrow-datafusion/pull/9468) (comphead) +- [CI build] fix chrono suggestions [#9486](https://github.com/apache/arrow-datafusion/pull/9486) (comphead) +- Make regex dependency optional in datafusion-functions, add CI checks for function packages [#9473](https://github.com/apache/arrow-datafusion/pull/9473) (alamb) +- fix: coalesce function should return correct data type [#9459](https://github.com/apache/arrow-datafusion/pull/9459) (viirya) +- LEAD/LAG calculate
default value once [#9485](https://github.com/apache/arrow-datafusion/pull/9485) (comphead) +- chore: simplify the return type of `validate_data_types()` [#9491](https://github.com/apache/arrow-datafusion/pull/9491) (waynexia) +- minor: use arrow-rs casting from Float to Timestamp [#9500](https://github.com/apache/arrow-datafusion/pull/9500) (comphead) +- chore(deps): update substrait requirement from 0.25.1 to 0.27.0 [#9502](https://github.com/apache/arrow-datafusion/pull/9502) (dependabot[bot]) +- fix: `generate_series` and `range` panic on edge cases [#9503](https://github.com/apache/arrow-datafusion/pull/9503) (jonahgao) +- Fix nondeterministic behaviour of schema nullability of lag window query [#9508](https://github.com/apache/arrow-datafusion/pull/9508) (mustafasrepo) +- Add `to_unixtime` function [#9077](https://github.com/apache/arrow-datafusion/pull/9077) (Tangruilin) +- Minor: fixed transformed state in UDF Simplify [#9484](https://github.com/apache/arrow-datafusion/pull/9484) (alamb) +- test: port strpos test in physical_expr/src/functions to sqllogictest [#9439](https://github.com/apache/arrow-datafusion/pull/9439) (SteveLauC) +- Port ArrayHas family to `functions-array` [#9496](https://github.com/apache/arrow-datafusion/pull/9496) (jayzhan211) +- port array_empty and array_length to datafusion-function-array crate [#9510](https://github.com/apache/arrow-datafusion/pull/9510) (Weijun-H) +- fix: `substr_index` not handling negative occurrence correctly [#9475](https://github.com/apache/arrow-datafusion/pull/9475) (jonahgao) +- [minor] extract collect file statistics method and add doc [#9490](https://github.com/apache/arrow-datafusion/pull/9490) (Ted-Jiang) +- test: sqllogictests for multiple tables join [#9480](https://github.com/apache/arrow-datafusion/pull/9480) (korowa) +- Add support for ignore nulls for LEAD, LAG in WindowAggExec [#9498](https://github.com/apache/arrow-datafusion/pull/9498) (Lordworms) +- Minor: Improve log expr description [#9516](https://github.com/apache/arrow-datafusion/pull/9516) (caicancai) +- port flatten to datafusion-function-array [#9523](https://github.com/apache/arrow-datafusion/pull/9523) (Weijun-H) +- feat: Add projection to HashJoinExec.
[#9236](https://github.com/apache/arrow-datafusion/pull/9236) (my-vegetable-has-exploded) +- Add example for `FunctionFactory` [#9482](https://github.com/apache/arrow-datafusion/pull/9482) (milenkovicm) +- Move date_part, date_trunc, date_bin functions to datafusion-functions [#9435](https://github.com/apache/arrow-datafusion/pull/9435) (Omega359) +- fix: support two argument TRIM [#9521](https://github.com/apache/arrow-datafusion/pull/9521) (tshauck) +- Remove physical expr of ListIndex and ListRange, convert to `array_element` and `array_slice` functions [#9492](https://github.com/apache/arrow-datafusion/pull/9492) (jayzhan211) +- feat: function name hints for UDFs [#9407](https://github.com/apache/arrow-datafusion/pull/9407) (SteveLauC) +- Minor: Improve documentation for registering `AnalyzerRule` [#9520](https://github.com/apache/arrow-datafusion/pull/9520) (alamb) +- Extend argument types for udf `return_type_from_exprs` [#9522](https://github.com/apache/arrow-datafusion/pull/9522) (jayzhan211) +- move make_array array_append array_prepend array_concat functions to datafusion-functions-array crate [#9504](https://github.com/apache/arrow-datafusion/pull/9504) (guojidan) +- Port `StringToArray` to `function-arrays` subcrate [#9543](https://github.com/apache/arrow-datafusion/pull/9543) (erenavsarogullari) +- Minor: remove `..` pattern matching in sql planner [#9531](https://github.com/apache/arrow-datafusion/pull/9531) (alamb) +- Minor: Fix document Interval syntax [#9542](https://github.com/apache/arrow-datafusion/pull/9542) (yyy1000) +- Port `struct` to datafusion-functions [#9546](https://github.com/apache/arrow-datafusion/pull/9546) (yyy1000) +- UDAF and UDWF support aliases [#9489](https://github.com/apache/arrow-datafusion/pull/9489) (lewiszlw) +- docs: fix extraneous char in array functions table of contents [#9560](https://github.com/apache/arrow-datafusion/pull/9560) (tshauck) +- [MINOR]: Fix nondeterministic test [#9559](https://github.com/apache/arrow-datafusion/pull/9559) (mustafasrepo) +- Port `arrow_typeof` to datafusion-function [#9524](https://github.com/apache/arrow-datafusion/pull/9524) (yyy1000) +- feat: Introduce convert Expr to SQL string API and basic feature [#9517](https://github.com/apache/arrow-datafusion/pull/9517) (backkem) +- Port `ArraySort` to `function-arrays` subcrate [#9551](https://github.com/apache/arrow-datafusion/pull/9551) (erenavsarogullari) +- refactor: unify some plan optimization in CommonSubexprEliminate [#9556](https://github.com/apache/arrow-datafusion/pull/9556) (jackwener) +- Port `ArrayDistinct` to `functions-array` subcrate [#9549](https://github.com/apache/arrow-datafusion/pull/9549) (erenavsarogullari) +- Minor: add a sql_planner benchmark to reflect selecting many fields on a huge table [#9536](https://github.com/apache/arrow-datafusion/pull/9536) (haohuaijin) +- Support IGNORE NULLS for FIRST/LAST window function [#9470](https://github.com/apache/arrow-datafusion/pull/9470) (huaxingao) +- Systematic Configuration in 'Create External Table' and 'Copy To' Options [#9382](https://github.com/apache/arrow-datafusion/pull/9382) (metesynnada) +- fix: incorrect null handling in `range` and `generate_series` [#9574](https://github.com/apache/arrow-datafusion/pull/9574) (jonahgao) +- Update README.md [#9572](https://github.com/apache/arrow-datafusion/pull/9572) (Abdullahsab3) +- Port tan, tanh to datafusion-functions [#9535](https://github.com/apache/arrow-datafusion/pull/9535) (ongchi) +- feat(9493): provide access to FileMetaData for files
written with ParquetSink [#9548](https://github.com/apache/arrow-datafusion/pull/9548) (wiedld) +- Export datafusion-functions UDFs publicly [#9585](https://github.com/apache/arrow-datafusion/pull/9585) (alamb) +- Update the comment and add a check [#9571](https://github.com/apache/arrow-datafusion/pull/9571) (colommar) +- Port `ArrayRepeat` to `functions-array` subcrate [#9568](https://github.com/apache/arrow-datafusion/pull/9568) (erenavsarogullari) +- Fix ApproxPercentileAccumulator on zero values [#9582](https://github.com/apache/arrow-datafusion/pull/9582) (Dandandan) +- Add `FunctionRewrite` API, Move Array specific rewrites to `datafusion_functions_array` [#9583](https://github.com/apache/arrow-datafusion/pull/9583) (alamb) +- Move from_unixtime, now, current_date, current_time functions to datafusion-functions [#9537](https://github.com/apache/arrow-datafusion/pull/9537) (Omega359) +- minor: update Debug trait impl for WindowsFrame [#9587](https://github.com/apache/arrow-datafusion/pull/9587) (comphead) +- Initial support for LogicalPlan to SQL String [#9596](https://github.com/apache/arrow-datafusion/pull/9596) (backkem) +- refactor: use a common macro to define math UDFs [#9598](https://github.com/apache/arrow-datafusion/pull/9598) (jonahgao) +- Move all `crypto` related functions to `datafusion-functions` [#9590](https://github.com/apache/arrow-datafusion/pull/9590) (Lordworms) +- Remove physical expr of NamedStructField, convert to `get_field` function call [#9563](https://github.com/apache/arrow-datafusion/pull/9563) (yyy1000) +- Add `/benchmark` github command to run a comparison benchmark between base and PR commit [#9461](https://github.com/apache/arrow-datafusion/pull/9461) (gruuya) +- support unnest as subexpression [#9592](https://github.com/apache/arrow-datafusion/pull/9592) (YjyJeff) +- feat: implement more expr_to_sql functionality [#9578](https://github.com/apache/arrow-datafusion/pull/9578) (devinjdangelo) +- Port `ArrayResize` to `functions-array` subcrate [#9570](https://github.com/apache/arrow-datafusion/pull/9570) (erenavsarogullari) +- Move make_date, to_char to datafusion-functions [#9601](https://github.com/apache/arrow-datafusion/pull/9601) (Omega359) +- Fix to_timestamp benchmark [#9608](https://github.com/apache/arrow-datafusion/pull/9608) (Omega359) +- feat: implement aggregation and subquery plans to SQL [#9606](https://github.com/apache/arrow-datafusion/pull/9606) (devinjdangelo) +- Port ArrayElem/Slice/PopFront/Back into `functions-array` [#9615](https://github.com/apache/arrow-datafusion/pull/9615) (jayzhan211) +- Minor: Remove datafusion-functions-array dependency from datafusion-optimizer [#9621](https://github.com/apache/arrow-datafusion/pull/9621) (alamb) +- Enable TTY during bench data generation [#9626](https://github.com/apache/arrow-datafusion/pull/9626) (gruuya) +- Remove constant expressions from SortExprs in the SortExec [#9618](https://github.com/apache/arrow-datafusion/pull/9618) (mustafasrepo) +- Try fixing missing results name in the benchmark step [#9632](https://github.com/apache/arrow-datafusion/pull/9632) (gruuya) +- feat: track memory usage for recursive CTE, enable recursive CTEs by default [#9619](https://github.com/apache/arrow-datafusion/pull/9619) (jonahgao) +- doc: Add missing doc link [#9631](https://github.com/apache/arrow-datafusion/pull/9631) (Weijun-H) +- Add explicit move of PR bench results if they were placed in HEAD dir [#9636](https://github.com/apache/arrow-datafusion/pull/9636) (gruuya) +- Add `array_reverse` function to
datafusion-function-\* crate [#9630](https://github.com/apache/arrow-datafusion/pull/9630) (Weijun-H) +- Move parts of `InListSimplifier` simplify rules to `Simplifier` [#9628](https://github.com/apache/arrow-datafusion/pull/9628) (jayzhan211) +- Port Array Union and Intersect to `functions-array` [#9629](https://github.com/apache/arrow-datafusion/pull/9629) (jayzhan211) +- Port `ArrayPosition` and `ArrayPositions` to `functions-array` subcrate [#9617](https://github.com/apache/arrow-datafusion/pull/9617) (erenavsarogullari) +- Optimize make_date (#9089) [#9600](https://github.com/apache/arrow-datafusion/pull/9600) (vojtechtoman) +- Support AT TIME ZONE clause [#9647](https://github.com/apache/arrow-datafusion/pull/9647) (tinfoil-knight) +- Window Linear Mode use smaller buffers [#9597](https://github.com/apache/arrow-datafusion/pull/9597) (mustafasrepo) +- Port `ArrayExcept` to `functions-array` subcrate [#9634](https://github.com/apache/arrow-datafusion/pull/9634) (erenavsarogullari) +- chore: improve array expression doc and clean up array_expression.rs [#9650](https://github.com/apache/arrow-datafusion/pull/9650) (Weijun-H) +- Minor: remove clone in `exprlist_to_fields` [#9657](https://github.com/apache/arrow-datafusion/pull/9657) (jayzhan211) +- Port `ArrayRemove`, `ArrayRemoveN`, `ArrayRemoveAll` to `functions-array` subcrate [#9656](https://github.com/apache/arrow-datafusion/pull/9656) (erenavsarogullari) +- Minor: Remove redundant dependencies from `datafusion-functions/Cargo.toml` [#9622](https://github.com/apache/arrow-datafusion/pull/9622) (alamb) +- Support IGNORE NULLS for NTH_VALUE window function [#9625](https://github.com/apache/arrow-datafusion/pull/9625) (huaxingao) +- Improve Robustness of Unparser Testing and Implementation [#9623](https://github.com/apache/arrow-datafusion/pull/9623) (devinjdangelo) +- Adding Constant Check for FilterExec [#9649](https://github.com/apache/arrow-datafusion/pull/9649) (Lordworms) +- chore(deps-dev): bump follow-redirects from 1.15.4 to 1.15.6 in /datafusion/wasmtest/datafusion-wasm-app [#9609](https://github.com/apache/arrow-datafusion/pull/9609) (dependabot[bot]) +- move array_replace family functions to datafusion-function-array crate [#9651](https://github.com/apache/arrow-datafusion/pull/9651) (Weijun-H) +- chore: remove repetitive word `the the` --> `the` in docs / comments [#9673](https://github.com/apache/arrow-datafusion/pull/9673) (InventiveCoder) +- Update example-usage.md to remove reference to simd and rust nightly. 
[#9677](https://github.com/apache/arrow-datafusion/pull/9677) (Omega359) +- [MINOR]: Remove some `.unwrap`s from nth_value.rs file [#9674](https://github.com/apache/arrow-datafusion/pull/9674) (mustafasrepo) +- minor: Remove deprecated methods [#9627](https://github.com/apache/arrow-datafusion/pull/9627) (comphead) +- Migrate `arrow_cast` to a UDF [#9610](https://github.com/apache/arrow-datafusion/pull/9610) (alamb) +- parquet: Add row_groups_matched_{statistics,bloom_filter} statistics [#9640](https://github.com/apache/arrow-datafusion/pull/9640) (progval) +- Make COPY TO align with CREATE EXTERNAL TABLE [#9604](https://github.com/apache/arrow-datafusion/pull/9604) (metesynnada) +- Support "A column is known to be entirely NULL" in `PruningPredicate` [#9223](https://github.com/apache/arrow-datafusion/pull/9223) (appletreeisyellow) +- Suppress self update for Windows CI runner [#9661](https://github.com/apache/arrow-datafusion/pull/9661) (jayzhan211) +- add schema to SQL ast builder [#9624](https://github.com/apache/arrow-datafusion/pull/9624) (sardination) +- core/tests/parquet/row_group_pruning.rs: Add tests for strings [#9642](https://github.com/apache/arrow-datafusion/pull/9642) (progval) +- Fix incorrect results with multiple `COUNT(DISTINCT..)` aggregates on dictionaries [#9679](https://github.com/apache/arrow-datafusion/pull/9679) (alamb) +- parquet: Add support for Bloom filters on binary columns [#9644](https://github.com/apache/arrow-datafusion/pull/9644) (progval) +- Update Arrow/Parquet to `51.0.0`, tonic to `0.11` [#9613](https://github.com/apache/arrow-datafusion/pull/9613) (tustvold) +- Move inlist rule to expr_simplifier [#9692](https://github.com/apache/arrow-datafusion/pull/9692) (jayzhan211) +- Support Serde for ScalarUDF in Physical Expressions [#9436](https://github.com/apache/arrow-datafusion/pull/9436) (yyy1000) +- Support Union types in `ScalarValue` [#9683](https://github.com/apache/arrow-datafusion/pull/9683) (avantgardnerio) +- parquet: Add support for row group pruning on FixedSizeBinary [#9646](https://github.com/apache/arrow-datafusion/pull/9646) (progval) +- Minor: Improve documentation for `LogicalPlan::expressions` [#9698](https://github.com/apache/arrow-datafusion/pull/9698) (alamb) +- Make builtin window function output datatype be derived from schema [#9686](https://github.com/apache/arrow-datafusion/pull/9686) (comphead) +- refactor: Extract `array_to_string` and `string_to_array` from `functions-array` subcrate's `kernels` and `udf` containers [#9704](https://github.com/apache/arrow-datafusion/pull/9704) (erenavsarogullari) +- Add Minimum Supported Rust Version policy to docs [#9681](https://github.com/apache/arrow-datafusion/pull/9681) (alamb) +- doc: Add DataFusion profiling documentation for MacOS [#9711](https://github.com/apache/arrow-datafusion/pull/9711) (comphead) +- Minor: add ticket reference to commented out test [#9715](https://github.com/apache/arrow-datafusion/pull/9715) (alamb) +- Minor: Rename path from `common_runtime` to `common-runtime` [#9717](https://github.com/apache/arrow-datafusion/pull/9717) (alamb) +- Use object_store::BufWriter to replace put_multipart [#9648](https://github.com/apache/arrow-datafusion/pull/9648) (yyy1000) +- Fix COPY TO failing on passing format options through CLI [#9709](https://github.com/apache/arrow-datafusion/pull/9709) (tinfoil-knight) +- fix: recursive cte hangs on joins [#9687](https://github.com/apache/arrow-datafusion/pull/9687) (jonahgao) +- Move `starts_with`, `to_hex`, `trim`, `upper` to
datafusion-functions (and add string_expressions) [#9541](https://github.com/apache/arrow-datafusion/pull/9541) (Tangruilin) +- Support for `extract(x from time)` / `date_part` from time types [#8693](https://github.com/apache/arrow-datafusion/pull/8693) (Jefffrey) +- doc: Updated known users list and usage dependency description [#9718](https://github.com/apache/arrow-datafusion/pull/9718) (comphead) +- Minor: improve documentation for `CommonSubexprEliminate` [#9700](https://github.com/apache/arrow-datafusion/pull/9700) (alamb) +- build: modify code to comply with latest clippy requirement [#9725](https://github.com/apache/arrow-datafusion/pull/9725) (comphead) +- Minor: return internal error rather than panic on unexpected error in COUNT DISTINCT [#9712](https://github.com/apache/arrow-datafusion/pull/9712) (alamb) +- fix(9678): short circuiting prevented population of visited stack, for common subexpr elimination optimization [#9685](https://github.com/apache/arrow-datafusion/pull/9685) (wiedld) +- perf: improve to_field performance [#9722](https://github.com/apache/arrow-datafusion/pull/9722) (haohuaijin) +- Minor: Run ScalarValue size test on aarch again [#9728](https://github.com/apache/arrow-datafusion/pull/9728) (alamb) +- Move trim functions (btrim, ltrim, rtrim) to datafusion_functions, make expr_fn API consistent [#9730](https://github.com/apache/arrow-datafusion/pull/9730) (Omega359) +- make format prefix optional for format options in COPY [#9723](https://github.com/apache/arrow-datafusion/pull/9723) (tinfoil-knight) +- refactor: Extract `range` and `gen_series` functions from `functions-array` subcrate's `kernels` and `udf` containers [#9720](https://github.com/apache/arrow-datafusion/pull/9720) (erenavsarogullari) +- Move ascii function to datafusion_functions [#9740](https://github.com/apache/arrow-datafusion/pull/9740) (PsiACE) +- adding expr to string for IsNotNull IsTrue IsFalse and IsUnknown [#9739](https://github.com/apache/arrow-datafusion/pull/9739) (Lordworms) +- fix: parallel parquet can underflow when max_record_batch_rows < execution.batch_size [#9737](https://github.com/apache/arrow-datafusion/pull/9737) (devinjdangelo) +- support format in options of COPY command [#9744](https://github.com/apache/arrow-datafusion/pull/9744) (tinfoil-knight) +- Move lower, octet_length to datafusion-functions [#9747](https://github.com/apache/arrow-datafusion/pull/9747) (Omega359) +- Fixed missing trim() in Rust API [#9749](https://github.com/apache/arrow-datafusion/pull/9749) (Omega359) +- refactor: Extract `array_length`, `array_reverse` and `array_sort` functions from `functions-array` subcrate's `kernels` and `udf` containers [#9751](https://github.com/apache/arrow-datafusion/pull/9751) (erenavsarogullari) +- refactor: Extract `array_empty` and `array_repeat` functions from `functions-array` subcrate's `kernels` and `udf` containers [#9762](https://github.com/apache/arrow-datafusion/pull/9762) (erenavsarogullari) +- Minor: remove an outdated TODO in `TypeCoercion` [#9752](https://github.com/apache/arrow-datafusion/pull/9752) (jonahgao) +- refactor: Extract `array_resize` and `cardinality` functions from `functions-array` subcrate's `kernels` and `udf` containers [#9766](https://github.com/apache/arrow-datafusion/pull/9766) (erenavsarogullari) +- fix: change placeholder errors from Internal to Plan [#9745](https://github.com/apache/arrow-datafusion/pull/9745) (erratic-pattern) +- Move levenshtein, uuid, overlay to datafusion-functions
[#9760](https://github.com/apache/arrow-datafusion/pull/9760) (Omega359) +- improve null handling for to_char [#9689](https://github.com/apache/arrow-datafusion/pull/9689) (tinfoil-knight) +- Add Expr->String for ScalarFunction and InList [#9759](https://github.com/apache/arrow-datafusion/pull/9759) (yyy1000) +- Move repeat, replace, split_part to datafusion_functions [#9784](https://github.com/apache/arrow-datafusion/pull/9784) (Omega359) +- refactor: Extract `array_dims`, `array_ndims` and `flatten` functions from `functions-array` subcrate's `kernels` and `udf` containers [#9786](https://github.com/apache/arrow-datafusion/pull/9786) (erenavsarogullari) +- Minor: Improve documentation about `ColumnarValues::values_to_array` [#9774](https://github.com/apache/arrow-datafusion/pull/9774) (alamb) +- Fix panic in `struct` function with mixed scalar/array arguments [#9775](https://github.com/apache/arrow-datafusion/pull/9775) (alamb) +- refactor: Apply minor refactorings to `functions-array` crate [#9788](https://github.com/apache/arrow-datafusion/pull/9788) (erenavsarogullari) +- Move bit_length and chr functions to datafusion_functions [#9782](https://github.com/apache/arrow-datafusion/pull/9782) (PsiACE) +- Support Tencent Cloud COS storage in `datafusion-cli` [#9734](https://github.com/apache/arrow-datafusion/pull/9734) (harveyyue) +- Make it easier to register configuration extension ... [#9781](https://github.com/apache/arrow-datafusion/pull/9781) (milenkovicm) +- Expr to Sql: Case [#9798](https://github.com/apache/arrow-datafusion/pull/9798) (yyy1000) +- feat: Between expr to sql string [#9803](https://github.com/apache/arrow-datafusion/pull/9803) (sebastian2296) +- feat: Expose `array_empty` and `list_empty` functions as aliases of the `empty` function [#9807](https://github.com/apache/arrow-datafusion/pull/9807) (erenavsarogullari) +- Support Expr `Like` to sql [#9805](https://github.com/apache/arrow-datafusion/pull/9805) (Weijun-H) +- feat: Not expr to string [#9802](https://github.com/apache/arrow-datafusion/pull/9802) (sebastian2296) +- [Minor]: Move some repetitive code to functions (proto) [#9811](https://github.com/apache/arrow-datafusion/pull/9811) (mustafasrepo) +- Implement IGNORE NULLS for LAST_VALUE [#9801](https://github.com/apache/arrow-datafusion/pull/9801) (huaxingao) +- [MINOR]: Move some repetitive code to functions [#9810](https://github.com/apache/arrow-datafusion/pull/9810) (mustafasrepo) +- fix: ensure mutual compatibility of the two input schemas from recursive CTEs [#9795](https://github.com/apache/arrow-datafusion/pull/9795) (jonahgao) +- Add support for constant expression evaluation in limit [#9790](https://github.com/apache/arrow-datafusion/pull/9790) (mustafasrepo) +- Projection Pushdown through user defined LogicalPlan nodes.
[#9690](https://github.com/apache/arrow-datafusion/pull/9690) (mustafasrepo) +- chore(deps): update substrait requirement from 0.27.0 to 0.28.0 [#9809](https://github.com/apache/arrow-datafusion/pull/9809) (dependabot[bot]) +- Run TPC-H SF10 during PR benchmarks [#9822](https://github.com/apache/arrow-datafusion/pull/9822) (gruuya) +- Expose `parser` on DFParser to enable user controlled parsing [#9729](https://github.com/apache/arrow-datafusion/pull/9729) (tshauck) +- Disable parallel reading for gzipped ndjson file [#9799](https://github.com/apache/arrow-datafusion/pull/9799) (Lordworms) +- Optimize to_timestamp (with format) (#9090) [#9833](https://github.com/apache/arrow-datafusion/pull/9833) (vojtechtoman) +- Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function [#9825](https://github.com/apache/arrow-datafusion/pull/9825) (Omega359) +- [Minor] Update TPC-DS tests, remove some #[ignore]d tests [#9829](https://github.com/apache/arrow-datafusion/pull/9829) (Dandandan) +- doc: Adding baseline benchmark example [#9827](https://github.com/apache/arrow-datafusion/pull/9827) (comphead) +- Add name method to execution plan [#9793](https://github.com/apache/arrow-datafusion/pull/9793) (matthewmturner) +- chore(deps-dev): bump express from 4.18.2 to 4.19.2 in /datafusion/wasmtest/datafusion-wasm-app [#9826](https://github.com/apache/arrow-datafusion/pull/9826) (dependabot[bot]) +- feat: pass SessionState not SessionConfig to FunctionFactory::create [#9837](https://github.com/apache/arrow-datafusion/pull/9837) (tshauck) diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index 492be93caf0c..a95f2f802dfb 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -64,7 +64,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.execution.parquet.statistics_enabled | NULL | Sets if statistics are enabled for any column Valid values are: "none", "chunk", and "page" These values are not case sensitive. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_statistics_size | NULL | Sets max statistics size for any column. If NULL, uses default parquet writer setting | | datafusion.execution.parquet.max_row_group_size | 1048576 | Target maximum number of rows in each row group (defaults to 1M rows). Writing larger row groups requires more memory to write, but can get better compression and be faster to read. | -| datafusion.execution.parquet.created_by | datafusion version 36.0.0 | Sets "created by" property | +| datafusion.execution.parquet.created_by | datafusion version 37.0.0 | Sets "created by" property | | datafusion.execution.parquet.column_index_truncate_length | NULL | Sets column index truncate length | | datafusion.execution.parquet.data_page_row_count_limit | 18446744073709551615 | Sets best effort maximum number of rows in data page | | datafusion.execution.parquet.encoding | NULL | Sets default encoding for any column Valid values are: plain, plain_dictionary, rle, bit_packed, delta_binary_packed, delta_length_byte_array, delta_byte_array, rle_dictionary, and byte_stream_split. These values are not case sensitive.
If NULL, uses default parquet writer setting | From 09f5a544d25f36ff1d65cc377123aee9b0e8f538 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Thu, 28 Mar 2024 22:56:15 -0400 Subject: [PATCH 089/117] move Left, Lpad, Reverse, Right, Rpad functions to datafusion_functions (#9841) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Fixed missing trim() function. * Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function * move Left, Lpad, Reverse, Right, Rpad functions to datafusion_functions * Code cleanup from PR review. --- datafusion/expr/src/built_in_function.rs | 50 +- datafusion/expr/src/expr_fn.rs | 21 - datafusion/functions/src/unicode/left.rs | 236 +++++++ datafusion/functions/src/unicode/lpad.rs | 369 +++++++++++ datafusion/functions/src/unicode/mod.rs | 44 +- datafusion/functions/src/unicode/reverse.rs | 149 +++++ datafusion/functions/src/unicode/right.rs | 238 +++++++ datafusion/functions/src/unicode/rpad.rs | 361 +++++++++++ datafusion/physical-expr/src/functions.rs | 606 ------------------ datafusion/physical-expr/src/planner.rs | 4 +- .../physical-expr/src/unicode_expressions.rs | 263 +------- datafusion/proto/proto/datafusion.proto | 10 +- datafusion/proto/src/generated/pbjson.rs | 15 - datafusion/proto/src/generated/prost.rs | 20 +- .../proto/src/logical_plan/from_proto.rs | 53 +- datafusion/proto/src/logical_plan/to_proto.rs | 5 - 16 files changed, 1428 insertions(+), 1016 deletions(-) create mode 100644 datafusion/functions/src/unicode/left.rs create mode 100644 datafusion/functions/src/unicode/lpad.rs create mode 100644 datafusion/functions/src/unicode/reverse.rs create mode 100644 datafusion/functions/src/unicode/right.rs create mode 100644 datafusion/functions/src/unicode/rpad.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index eefbc131a27b..196d278dc70e 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -111,18 +111,8 @@ pub enum BuiltinScalarFunction { EndsWith, /// initcap InitCap, - /// left - Left, - /// lpad - Lpad, /// random Random, - /// reverse - Reverse, - /// right - Right, - /// rpad - Rpad, /// strpos Strpos, /// substr @@ -220,12 +210,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable, BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, - BuiltinScalarFunction::Left => Volatility::Immutable, - BuiltinScalarFunction::Lpad => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Reverse => Volatility::Immutable, - BuiltinScalarFunction::Right => Volatility::Immutable, - BuiltinScalarFunction::Rpad => Volatility::Immutable, BuiltinScalarFunction::Strpos => Volatility::Immutable, BuiltinScalarFunction::Substr => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, @@ -264,17 +249,8 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::InitCap => { utf8_to_str_type(&input_expr_types[0], "initcap") } - BuiltinScalarFunction::Left => utf8_to_str_type(&input_expr_types[0], "left"), - BuiltinScalarFunction::Lpad => utf8_to_str_type(&input_expr_types[0], "lpad"), BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => 
Ok(Float64), - BuiltinScalarFunction::Reverse => { - utf8_to_str_type(&input_expr_types[0], "reverse") - } - BuiltinScalarFunction::Right => { - utf8_to_str_type(&input_expr_types[0], "right") - } - BuiltinScalarFunction::Rpad => utf8_to_str_type(&input_expr_types[0], "rpad"), BuiltinScalarFunction::EndsWith => Ok(Boolean), BuiltinScalarFunction::Strpos => { utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") @@ -361,28 +337,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => { Signature::variadic_equal(self.volatility()) } - BuiltinScalarFunction::InitCap | BuiltinScalarFunction::Reverse => { + BuiltinScalarFunction::InitCap => { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::Lpad | BuiltinScalarFunction::Rpad => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Utf8]), - Exact(vec![LargeUtf8, Int64, Utf8]), - Exact(vec![Utf8, Int64, LargeUtf8]), - Exact(vec![LargeUtf8, Int64, LargeUtf8]), - ], - self.volatility(), - ) - } - BuiltinScalarFunction::Left | BuiltinScalarFunction::Right => { - Signature::one_of( - vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], - self.volatility(), - ) - } BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { Signature::one_of( @@ -580,11 +537,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Left => &["left"], - BuiltinScalarFunction::Lpad => &["lpad"], - BuiltinScalarFunction::Reverse => &["reverse"], - BuiltinScalarFunction::Right => &["right"], - BuiltinScalarFunction::Rpad => &["rpad"], BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::Translate => &["translate"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 654464798625..21dab72855e5 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -578,25 +578,11 @@ scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argu scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); -scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); -scalar_expr!(Reverse, reverse, string, "reverses the `string`"); -scalar_expr!(Right, right, string n, "returns the last `n` characters in the `string`"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); -//use vec as parameter -nary_scalar_expr!( - Lpad, - lpad, - "fill up a string to the length by prepending the characters" -); -nary_scalar_expr!( - Rpad, - rpad, - "fill up a string to the length by appending the characters" -); nary_scalar_expr!(Coalesce, coalesce, "returns 
`coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c nary_scalar_expr!( @@ -1028,13 +1014,6 @@ mod test { test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); - test_scalar_expr!(Left, left, string, count); - test_nary_scalar_expr!(Lpad, lpad, string, count); - test_nary_scalar_expr!(Lpad, lpad, string, count, characters); - test_scalar_expr!(Reverse, reverse, string); - test_scalar_expr!(Right, right, string, count); - test_nary_scalar_expr!(Rpad, rpad, string, count); - test_nary_scalar_expr!(Rpad, rpad, string, count, characters); test_scalar_expr!(EndsWith, ends_with, string, characters); test_scalar_expr!(Strpos, strpos, string, substring); test_scalar_expr!(Substr, substr, string, position); diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs new file mode 100644 index 000000000000..473589fdc8aa --- /dev/null +++ b/datafusion/functions/src/unicode/left.rs @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::cmp::Ordering; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::exec_err; +use datafusion_common::Result; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct LeftFunc { + signature: Signature, +} + +impl LeftFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LeftFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "left" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "left") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(left::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(left::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function left"), + } + } +} + +/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters. 
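+/// left('abcde', -2) = 'abc' (negative n keeps all but the last 2 characters; this example mirrors the tests below)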
+/// left('abcde', 2) = 'ab' +/// The implementation uses UTF-8 code points as characters +pub fn left(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + let n_array = as_int64_array(&args[1])?; + let result = string_array + .iter() + .zip(n_array.iter()) + .map(|(string, n)| match (string, n) { + (Some(string), Some(n)) => match n.cmp(&0) { + Ordering::Less => { + let len = string.chars().count() as i64; + Some(if n.abs() < len { + string.chars().take((len + n) as usize).collect::() + } else { + "".to_string() + }) + } + Ordering::Equal => Some("".to_string()), + Ordering::Greater => { + Some(string.chars().take(n as usize).collect::()) + } + }, + _ => None, + }) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::left::LeftFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ab")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(200i64)), + ], + Ok(Some("abcde")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ], + Ok(Some("abc")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(-200i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray + ); + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("joséé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + LeftFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("abcde")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + internal_err!( + "function left requires compilation with feature flag: unicode_expressions." 
+ ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs new file mode 100644 index 000000000000..76a8e68cca25 --- /dev/null +++ b/datafusion/functions/src/unicode/lpad.rs @@ -0,0 +1,369 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use unicode_segmentation::UnicodeSegmentation; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +#[derive(Debug)] +pub(super) struct LPadFunc { + signature: Signature, +} + +impl LPadFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Int64]), + Exact(vec![LargeUtf8, Int64]), + Exact(vec![Utf8, Int64, Utf8]), + Exact(vec![LargeUtf8, Int64, Utf8]), + Exact(vec![Utf8, Int64, LargeUtf8]), + Exact(vec![LargeUtf8, Int64, LargeUtf8]), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for LPadFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "lpad" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "lpad") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(lpad::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(lpad::, vec![])(args), + other => exec_err!("Unsupported data type {other:?} for function lpad"), + } + } +} + +/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right). 
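+/// Note: the target length is counted in extended grapheme clusters (via unicode-segmentation), so a combining character sequence counts as one character toward the length.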
+/// lpad('hi', 5, 'xy') = 'xyxhi' +pub fn lpad(args: &[ArrayRef]) -> Result { + match args.len() { + 2 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .map(|(string, length)| match (string, length) { + (Some(string), Some(length)) => { + if length > i32::MAX as i64 { + return exec_err!( + "lpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else { + let mut s: String = " ".repeat(length - graphemes.len()); + s.push_str(string); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + 3 => { + let string_array = as_generic_string_array::(&args[0])?; + let length_array = as_int64_array(&args[1])?; + let fill_array = as_generic_string_array::(&args[2])?; + + let result = string_array + .iter() + .zip(length_array.iter()) + .zip(fill_array.iter()) + .map(|((string, length), fill)| match (string, length, fill) { + (Some(string), Some(length), Some(fill)) => { + if length > i32::MAX as i64 { + return exec_err!( + "lpad requested length {length} too large" + ); + } + + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + let fill_chars = fill.chars().collect::>(); + + if length < graphemes.len() { + Ok(Some(graphemes[..length].concat())) + } else if fill_chars.is_empty() { + Ok(Some(string.to_string())) + } else { + let mut s = string.to_string(); + let mut char_vector = + Vec::::with_capacity(length - graphemes.len()); + for l in 0..length - graphemes.len() { + char_vector.push( + *fill_chars.get(l % fill_chars.len()).unwrap(), + ); + } + s.insert_str( + 0, + char_vector.iter().collect::().as_str(), + ); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::>>()?; + + Ok(Arc::new(result) as ArrayRef) + } + other => exec_err!( + "lpad was called with {other} arguments. It requires at least 2 and at most 3." 
+ ), + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::lpad::LPadFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some(" josé")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(Some("xyxhi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(21i64)), + ColumnarValue::Scalar(ScalarValue::from("abcdef")), + ], + Ok(Some("abcdefabcdefabcdefahi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from(" ")), + ], + Ok(Some(" hi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("")), + ], + Ok(Some("hi")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("hi")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(10i64)), + ColumnarValue::Scalar(ScalarValue::from("xy")), + ], + Ok(Some("xyxyxyjosé")), + &str, + Utf8, + StringArray + ); + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(10i64)), + 
ColumnarValue::Scalar(ScalarValue::from("éñ")), + ], + Ok(Some("éñéñéñjosé")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + LPadFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("josé")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + internal_err!( + "function lpad requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index 291de3843903..ea4e70a92199 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -22,6 +22,11 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; mod character_length; +mod left; +mod lpad; +mod reverse; +mod right; +mod rpad; // create UDFs make_udf_function!( @@ -29,6 +34,11 @@ make_udf_function!( CHARACTER_LENGTH, character_length ); +make_udf_function!(left::LeftFunc, LEFT, left); +make_udf_function!(lpad::LPadFunc, LPAD, lpad); +make_udf_function!(right::RightFunc, RIGHT, right); +make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); +make_udf_function!(rpad::RPadFunc, RPAD, rpad); pub mod expr_fn { use datafusion_expr::Expr; @@ -47,9 +57,41 @@ pub mod expr_fn { pub fn length(string: Expr) -> Expr { character_length(string) } + + #[doc = "returns the first `n` characters in the `string`"] + pub fn left(string: Expr, n: Expr) -> Expr { + super::left().call(vec![string, n]) + } + + #[doc = "fill up a string to the length by prepending the characters"] + pub fn lpad(args: Vec) -> Expr { + super::lpad().call(args) + } + + #[doc = "reverses the `string`"] + pub fn reverse(string: Expr) -> Expr { + super::reverse().call(vec![string]) + } + + #[doc = "returns the last `n` characters in the `string`"] + pub fn right(string: Expr, n: Expr) -> Expr { + super::right().call(vec![string, n]) + } + + #[doc = "fill up a string to the length by appending the characters"] + pub fn rpad(args: Vec) -> Expr { + super::rpad().call(args) + } } /// Return a list of all functions in this package pub fn functions() -> Vec> { - vec![character_length()] + vec![ + character_length(), + left(), + lpad(), + reverse(), + right(), + rpad(), + ] } diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs new file mode 100644 index 000000000000..42ca6e0d17c3 --- /dev/null +++ b/datafusion/functions/src/unicode/reverse.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct ReverseFunc { + signature: Signature, +} + +impl ReverseFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![Utf8, LargeUtf8], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for ReverseFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "reverse" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + utf8_to_str_type(&arg_types[0], "reverse") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(reverse::, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(reverse::, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function reverse") + } + } + } +} + +/// Reverses the order of the characters in the string. +/// reverse('abcde') = 'edcba' +/// The implementation uses UTF-8 code points as characters +pub fn reverse(args: &[ArrayRef]) -> Result { + let string_array = as_generic_string_array::(&args[0])?; + + let result = string_array + .iter() + .map(|string| string.map(|string: &str| string.chars().rev().collect::())) + .collect::>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::reverse::ReverseFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("abcde"))], + Ok(Some("edcba")), + &str, + Utf8, + StringArray + ); + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("loẅks"))], + Ok(Some("sk̈wol")), + &str, + Utf8, + StringArray + ); + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("loẅks"))], + Ok(Some("sk̈wol")), + &str, + Utf8, + StringArray + ); + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + Ok(None), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + ReverseFunc::new(), + &[ColumnarValue::Scalar(ScalarValue::from("abcde"))], + internal_err!( + "function reverse requires compilation with feature flag: unicode_expressions." + ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs new file mode 100644 index 000000000000..d1bd976342b2 --- /dev/null +++ b/datafusion/functions/src/unicode/right.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::cmp::{max, Ordering};
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+
+use datafusion_common::cast::{as_generic_string_array, as_int64_array};
+use datafusion_common::exec_err;
+use datafusion_common::Result;
+use datafusion_expr::TypeSignature::Exact;
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+use crate::utils::{make_scalar_function, utf8_to_str_type};
+
+#[derive(Debug)]
+pub(super) struct RightFunc {
+ signature: Signature,
+}
+
+impl RightFunc {
+ pub fn new() -> Self {
+ use DataType::*;
+ Self {
+ signature: Signature::one_of(
+ vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
+ Volatility::Immutable,
+ ),
+ }
+ }
+}
+
+impl ScalarUDFImpl for RightFunc {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "right"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ utf8_to_str_type(&arg_types[0], "right")
+ }
+
+ fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ match args[0].data_type() {
+ DataType::Utf8 => make_scalar_function(right::<i32>, vec![])(args),
+ DataType::LargeUtf8 => make_scalar_function(right::<i64>, vec![])(args),
+ other => exec_err!("Unsupported data type {other:?} for function right"),
+ }
+ }
+}
+
+/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters.
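+/// (e.g. right('abcde', -2) = 'cde', as exercised in the tests below)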
+
+/// right('abcde', 2) = 'de'
+/// The implementation uses UTF-8 code points as characters
+pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let n_array = as_int64_array(&args[1])?;
+
+ let result = string_array
+ .iter()
+ .zip(n_array.iter())
+ .map(|(string, n)| match (string, n) {
+ (Some(string), Some(n)) => match n.cmp(&0) {
+ Ordering::Less => Some(
+ string
+ .chars()
+ .skip(n.unsigned_abs() as usize)
+ .collect::<String>(),
+ ),
+ Ordering::Equal => Some("".to_string()),
+ Ordering::Greater => Some(
+ string
+ .chars()
+ .skip(max(string.chars().count() as i64 - n, 0) as usize)
+ .collect::<String>(),
+ ),
+ },
+ _ => None,
+ })
+ .collect::<GenericStringArray<T>>();
+
+ Ok(Arc::new(result) as ArrayRef)
+}
+
+#[cfg(test)]
+mod tests {
+ use arrow::array::{Array, StringArray};
+ use arrow::datatypes::DataType::Utf8;
+
+ use datafusion_common::{Result, ScalarValue};
+ use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+ use crate::unicode::right::RightFunc;
+ use crate::utils::test::test_function;
+
+ #[test]
+ fn test_functions() -> Result<()> {
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("abcde")),
+ ColumnarValue::Scalar(ScalarValue::from(2i64)),
+ ],
+ Ok(Some("de")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("abcde")),
+ ColumnarValue::Scalar(ScalarValue::from(200i64)),
+ ],
+ Ok(Some("abcde")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("abcde")),
+ ColumnarValue::Scalar(ScalarValue::from(-2i64)),
+ ],
+ Ok(Some("cde")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("abcde")),
+ ColumnarValue::Scalar(ScalarValue::from(-200i64)),
+ ],
+ Ok(Some("")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("abcde")),
+ ColumnarValue::Scalar(ScalarValue::from(0i64)),
+ ],
+ Ok(Some("")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+ ColumnarValue::Scalar(ScalarValue::from(2i64)),
+ ],
+ Ok(None),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("abcde")),
+ ColumnarValue::Scalar(ScalarValue::Int64(None)),
+ ],
+ Ok(None),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ],
+ Ok(Some("éésoj")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
+ ColumnarValue::Scalar(ScalarValue::from(-3i64)),
+ ],
+ Ok(Some("éésoj")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ #[cfg(not(feature = "unicode_expressions"))]
+ test_function!(
+ RightFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("abcde")),
+ ColumnarValue::Scalar(ScalarValue::from(2i64)),
+ ],
+ internal_err!(
+ "function right requires compilation with feature flag: unicode_expressions."
+
+ ),
+ &str,
+ Utf8,
+ StringArray
+ );
+
+ Ok(())
+ }
+}
diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs
new file mode 100644
index 000000000000..070278c90b2f
--- /dev/null
+++ b/datafusion/functions/src/unicode/rpad.rs
@@ -0,0 +1,361 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+use datafusion_common::cast::{as_generic_string_array, as_int64_array};
+use unicode_segmentation::UnicodeSegmentation;
+
+use crate::utils::{make_scalar_function, utf8_to_str_type};
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::TypeSignature::Exact;
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+#[derive(Debug)]
+pub(super) struct RPadFunc {
+ signature: Signature,
+}
+
+impl RPadFunc {
+ pub fn new() -> Self {
+ use DataType::*;
+ Self {
+ signature: Signature::one_of(
+ vec![
+ Exact(vec![Utf8, Int64]),
+ Exact(vec![LargeUtf8, Int64]),
+ Exact(vec![Utf8, Int64, Utf8]),
+ Exact(vec![LargeUtf8, Int64, Utf8]),
+ Exact(vec![Utf8, Int64, LargeUtf8]),
+ Exact(vec![LargeUtf8, Int64, LargeUtf8]),
+ ],
+ Volatility::Immutable,
+ ),
+ }
+ }
+}
+
+impl ScalarUDFImpl for RPadFunc {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn name(&self) -> &str {
+ "rpad"
+ }
+
+ fn signature(&self) -> &Signature {
+ &self.signature
+ }
+
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ utf8_to_str_type(&arg_types[0], "rpad")
+ }
+
+ fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+ match args[0].data_type() {
+ DataType::Utf8 => make_scalar_function(rpad::<i32>, vec![])(args),
+ DataType::LargeUtf8 => make_scalar_function(rpad::<i64>, vec![])(args),
+ other => exec_err!("Unsupported data type {other:?} for function rpad"),
+ }
+ }
+}
+
+/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
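+/// With the two-argument form the fill defaults to a single space, e.g.
+/// rpad('hi', 5) = 'hi   ' (as exercised in the tests below).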
+
+/// rpad('hi', 5, 'xy') = 'hixyx'
+pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+ match args.len() {
+ 2 => {
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let length_array = as_int64_array(&args[1])?;
+
+ let result = string_array
+ .iter()
+ .zip(length_array.iter())
+ .map(|(string, length)| match (string, length) {
+ (Some(string), Some(length)) => {
+ if length > i32::MAX as i64 {
+ return exec_err!(
+ "rpad requested length {length} too large"
+ );
+ }
+
+ let length = if length < 0 { 0 } else { length as usize };
+ if length == 0 {
+ Ok(Some("".to_string()))
+ } else {
+ let graphemes = string.graphemes(true).collect::<Vec<&str>>();
+ if length < graphemes.len() {
+ Ok(Some(graphemes[..length].concat()))
+ } else {
+ let mut s = string.to_string();
+ s.push_str(" ".repeat(length - graphemes.len()).as_str());
+ Ok(Some(s))
+ }
+ }
+ }
+ _ => Ok(None),
+ })
+ .collect::<Result<GenericStringArray<T>>>()?;
+ Ok(Arc::new(result) as ArrayRef)
+ }
+ 3 => {
+ let string_array = as_generic_string_array::<T>(&args[0])?;
+ let length_array = as_int64_array(&args[1])?;
+ let fill_array = as_generic_string_array::<T>(&args[2])?;
+
+ let result = string_array
+ .iter()
+ .zip(length_array.iter())
+ .zip(fill_array.iter())
+ .map(|((string, length), fill)| match (string, length, fill) {
+ (Some(string), Some(length), Some(fill)) => {
+ if length > i32::MAX as i64 {
+ return exec_err!(
+ "rpad requested length {length} too large"
+ );
+ }
+
+ let length = if length < 0 { 0 } else { length as usize };
+ let graphemes = string.graphemes(true).collect::<Vec<&str>>();
+ let fill_chars = fill.chars().collect::<Vec<char>>();
+
+ if length < graphemes.len() {
+ Ok(Some(graphemes[..length].concat()))
+ } else if fill_chars.is_empty() {
+ Ok(Some(string.to_string()))
+ } else {
+ let mut s = string.to_string();
+ let mut char_vector =
+ Vec::<char>::with_capacity(length - graphemes.len());
+ for l in 0..length - graphemes.len() {
+ char_vector
+ .push(*fill_chars.get(l % fill_chars.len()).unwrap());
+ }
+ s.push_str(char_vector.iter().collect::<String>().as_str());
+ Ok(Some(s))
+ }
+ }
+ _ => Ok(None),
+ })
+ .collect::<Result<GenericStringArray<T>>>()?;
+
+ Ok(Arc::new(result) as ArrayRef)
+ }
+ other => exec_err!(
+ "rpad was called with {other} arguments. It requires at least 2 and at most 3."
+
+ ),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use arrow::array::{Array, StringArray};
+ use arrow::datatypes::DataType::Utf8;
+
+ use datafusion_common::{Result, ScalarValue};
+ use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+
+ use crate::unicode::rpad::RPadFunc;
+ use crate::utils::test::test_function;
+
+ #[test]
+ fn test_functions() -> Result<()> {
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("josé")),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ],
+ Ok(Some("josé ")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ],
+ Ok(Some("hi   ")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::from(0i64)),
+ ],
+ Ok(Some("")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::Int64(None)),
+ ],
+ Ok(None),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ],
+ Ok(None),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ColumnarValue::Scalar(ScalarValue::from("xy")),
+ ],
+ Ok(Some("hixyx")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::from(21i64)),
+ ColumnarValue::Scalar(ScalarValue::from("abcdef")),
+ ],
+ Ok(Some("hiabcdefabcdefabcdefa")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ColumnarValue::Scalar(ScalarValue::from(" ")),
+ ],
+ Ok(Some("hi   ")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ColumnarValue::Scalar(ScalarValue::from("")),
+ ],
+ Ok(Some("hi")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ColumnarValue::Scalar(ScalarValue::from("xy")),
+ ],
+ Ok(None),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::Int64(None)),
+ ColumnarValue::Scalar(ScalarValue::from("xy")),
+ ],
+ Ok(None),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("hi")),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+ ],
+ Ok(None),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("josé")),
+ ColumnarValue::Scalar(ScalarValue::from(10i64)),
+ ColumnarValue::Scalar(ScalarValue::from("xy")),
+ ],
+ Ok(Some("joséxyxyxy")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("josé")),
+ ColumnarValue::Scalar(ScalarValue::from(10i64)),
+
ColumnarValue::Scalar(ScalarValue::from("éñ")),
+ ],
+ Ok(Some("josééñéñéñ")),
+ &str,
+ Utf8,
+ StringArray
+ );
+ #[cfg(not(feature = "unicode_expressions"))]
+ test_function!(
+ RPadFunc::new(),
+ &[
+ ColumnarValue::Scalar(ScalarValue::from("josé")),
+ ColumnarValue::Scalar(ScalarValue::from(5i64)),
+ ],
+ internal_err!(
+ "function rpad requires compilation with feature flag: unicode_expressions."
+ ),
+ &str,
+ Utf8,
+ StringArray
+ );
+
+ Ok(())
+ }
+}
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 9adc8536341d..c1b4900e399a 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -270,67 +270,6 @@ pub fn create_physical_fun(
 exec_err!("Unsupported data type {other:?} for function initcap")
 }
 }),
- BuiltinScalarFunction::Left => Arc::new(|args| match args[0].data_type() {
- DataType::Utf8 => {
- let func = invoke_if_unicode_expressions_feature_flag!(left, i32, "left");
- make_scalar_function_inner(func)(args)
- }
- DataType::LargeUtf8 => {
- let func = invoke_if_unicode_expressions_feature_flag!(left, i64, "left");
- make_scalar_function_inner(func)(args)
- }
- other => exec_err!("Unsupported data type {other:?} for function left"),
- }),
- BuiltinScalarFunction::Lpad => Arc::new(|args| match args[0].data_type() {
- DataType::Utf8 => {
- let func = invoke_if_unicode_expressions_feature_flag!(lpad, i32, "lpad");
- make_scalar_function_inner(func)(args)
- }
- DataType::LargeUtf8 => {
- let func = invoke_if_unicode_expressions_feature_flag!(lpad, i64, "lpad");
- make_scalar_function_inner(func)(args)
- }
- other => exec_err!("Unsupported data type {other:?} for function lpad"),
- }),
- BuiltinScalarFunction::Reverse => Arc::new(|args| match args[0].data_type() {
- DataType::Utf8 => {
- let func =
- invoke_if_unicode_expressions_feature_flag!(reverse, i32, "reverse");
- make_scalar_function_inner(func)(args)
- }
- DataType::LargeUtf8 => {
- let func =
- invoke_if_unicode_expressions_feature_flag!(reverse, i64, "reverse");
- make_scalar_function_inner(func)(args)
- }
- other => {
- exec_err!("Unsupported data type {other:?} for function reverse")
- }
- }),
- BuiltinScalarFunction::Right => Arc::new(|args| match args[0].data_type() {
- DataType::Utf8 => {
- let func =
- invoke_if_unicode_expressions_feature_flag!(right, i32, "right");
- make_scalar_function_inner(func)(args)
- }
- DataType::LargeUtf8 => {
- let func =
- invoke_if_unicode_expressions_feature_flag!(right, i64, "right");
- make_scalar_function_inner(func)(args)
- }
- other => exec_err!("Unsupported data type {other:?} for function right"),
- }),
- BuiltinScalarFunction::Rpad => Arc::new(|args| match args[0].data_type() {
- DataType::Utf8 => {
- let func = invoke_if_unicode_expressions_feature_flag!(rpad, i32, "rpad");
- make_scalar_function_inner(func)(args)
- }
- DataType::LargeUtf8 => {
- let func = invoke_if_unicode_expressions_feature_flag!(rpad, i64, "rpad");
- make_scalar_function_inner(func)(args)
- }
- other => exec_err!("Unsupported data type {other:?} for function rpad"),
- }),
 BuiltinScalarFunction::EndsWith => Arc::new(|args| match args[0].data_type() {
 DataType::Utf8 => {
 make_scalar_function_inner(string_expressions::ends_with::<i32>)(args)
@@ -691,551 +630,6 @@ mod tests {
 Utf8,
 StringArray
 );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Left,
- &[lit("abcde"), lit(ScalarValue::Int8(Some(2))),],
- Ok(Some("ab")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
-
test_function!(
- Left,
- &[lit("abcde"), lit(ScalarValue::Int64(Some(200))),],
- Ok(Some("abcde")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Left,
- &[lit("abcde"), lit(ScalarValue::Int64(Some(-2))),],
- Ok(Some("abc")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Left,
- &[lit("abcde"), lit(ScalarValue::Int64(Some(-200))),],
- Ok(Some("")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Left,
- &[lit("abcde"), lit(ScalarValue::Int64(Some(0))),],
- Ok(Some("")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Left,
- &[
- lit(ScalarValue::Utf8(None)),
- lit(ScalarValue::Int64(Some(2))),
- ],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Left,
- &[lit("abcde"), lit(ScalarValue::Int64(None)),],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Left,
- &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),],
- Ok(Some("joséé")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Left,
- &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-3))),],
- Ok(Some("joséé")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(not(feature = "unicode_expressions"))]
- test_function!(
- Left,
- &[
- lit("abcde"),
- lit(ScalarValue::Int8(Some(2))),
- ],
- internal_err!(
- "function left requires compilation with feature flag: unicode_expressions."
- ),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[lit("josé"), lit(ScalarValue::Int64(Some(5))),],
- Ok(Some(" josé")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(5))),],
- Ok(Some("   hi")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(0))),],
- Ok(Some("")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[lit("hi"), lit(ScalarValue::Int64(None)),],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[
- lit(ScalarValue::Utf8(None)),
- lit(ScalarValue::Int64(Some(5))),
- ],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit("xy"),],
- Ok(Some("xyxhi")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(21))), lit("abcdef"),],
- Ok(Some("abcdefabcdefabcdefahi")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(" "),],
- Ok(Some("   hi")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(""),],
- Ok(Some("hi")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Lpad,
- &[
- lit(ScalarValue::Utf8(None)),
- lit(ScalarValue::Int64(Some(5))),
- lit("xy"),
- ],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
-
Lpad, - &[lit("hi"), lit(ScalarValue::Int64(None)), lit("xy"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[ - lit("hi"), - lit(ScalarValue::Int64(Some(5))), - lit(ScalarValue::Utf8(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("xy"),], - Ok(Some("xyxyxyjosé")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Lpad, - &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("éñ"),], - Ok(Some("éñéñéñjosé")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Lpad, - &[ - lit("josé"), - lit(ScalarValue::Int64(Some(5))), - ], - internal_err!( - "function lpad requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("abcde")], - Ok(Some("edcba")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("loẅks")], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit("loẅks")], - Ok(Some("sk̈wol")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Reverse, - &[lit(ScalarValue::Utf8(None))], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Reverse, - &[lit("abcde")], - internal_err!( - "function reverse requires compilation with feature flag: unicode_expressions." - ), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int8(Some(2))),], - Ok(Some("de")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(200))),], - Ok(Some("abcde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-2))),], - Ok(Some("cde")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(-200))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[ - lit(ScalarValue::Utf8(None)), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("abcde"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("éésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Right, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("éésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Right, - &[ - lit("abcde"), - lit(ScalarValue::Int8(Some(2))), - ], - 
internal_err!(
- "function right requires compilation with feature flag: unicode_expressions."
- ),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("josé"), lit(ScalarValue::Int64(Some(5))),],
- Ok(Some("josé ")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(5))),],
- Ok(Some("hi   ")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(0))),],
- Ok(Some("")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("hi"), lit(ScalarValue::Int64(None)),],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[
- lit(ScalarValue::Utf8(None)),
- lit(ScalarValue::Int64(Some(5))),
- ],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit("xy"),],
- Ok(Some("hixyx")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(21))), lit("abcdef"),],
- Ok(Some("hiabcdefabcdefabcdefa")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(" "),],
- Ok(Some("hi   ")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("hi"), lit(ScalarValue::Int64(Some(5))), lit(""),],
- Ok(Some("hi")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[
- lit(ScalarValue::Utf8(None)),
- lit(ScalarValue::Int64(Some(5))),
- lit("xy"),
- ],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("hi"), lit(ScalarValue::Int64(None)), lit("xy"),],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[
- lit("hi"),
- lit(ScalarValue::Int64(Some(5))),
- lit(ScalarValue::Utf8(None)),
- ],
- Ok(None),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("xy"),],
- Ok(Some("joséxyxyxy")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(feature = "unicode_expressions")]
- test_function!(
- Rpad,
- &[lit("josé"), lit(ScalarValue::Int64(Some(10))), lit("éñ"),],
- Ok(Some("josééñéñéñ")),
- &str,
- Utf8,
- StringArray
- );
- #[cfg(not(feature = "unicode_expressions"))]
- test_function!(
- Rpad,
- &[
- lit("josé"),
- lit(ScalarValue::Int64(Some(5))),
- ],
- internal_err!(
- "function rpad requires compilation with feature flag: unicode_expressions."
- ),
- &str,
- Utf8,
- StringArray
- );
 test_function!(
 EndsWith,
 &[lit("alphabet"), lit("alph"),],
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 319d9ca2269a..0dbea09ffb51 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -335,11 +335,11 @@ mod tests {
 use arrow_array::{ArrayRef, BooleanArray, RecordBatch, StringArray};
 use arrow_schema::{DataType, Field, Schema};
 use datafusion_common::{DFSchema, Result};
- use datafusion_expr::{col, left, Literal};
+ use datafusion_expr::{col, lit};

 #[test]
 fn test_create_physical_expr_scalar_input_output() -> Result<()> {
- let expr = col("letter").eq(left("APACHE".lit(), 1i64.lit()));
+ let expr = col("letter").eq(lit("A"));

 let schema = Schema::new(vec![Field::new("letter", DataType::Utf8, false)]);
 let df_schema = DFSchema::try_from_qualified_schema("data", &schema)?;
diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs
index c7e4b7d7c443..faff21111a61 100644
--- a/datafusion/physical-expr/src/unicode_expressions.rs
+++ b/datafusion/physical-expr/src/unicode_expressions.rs
@@ -21,7 +21,7 @@

 //! Unicode expressions

-use std::cmp::{max, Ordering};
+use std::cmp::max;
 use std::sync::Arc;

 use arrow::{
@@ -36,267 +36,6 @@ use datafusion_common::{
 exec_err, Result,
 };

-/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters.
-/// left('abcde', 2) = 'ab'
-/// The implementation uses UTF-8 code points as characters
-pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = as_generic_string_array::<T>(&args[0])?;
- let n_array = as_int64_array(&args[1])?;
- let result = string_array
- .iter()
- .zip(n_array.iter())
- .map(|(string, n)| match (string, n) {
- (Some(string), Some(n)) => match n.cmp(&0) {
- Ordering::Less => {
- let len = string.chars().count() as i64;
- Some(if n.abs() < len {
- string.chars().take((len + n) as usize).collect::<String>()
- } else {
- "".to_string()
- })
- }
- Ordering::Equal => Some("".to_string()),
- Ordering::Greater => {
- Some(string.chars().take(n as usize).collect::<String>())
- }
- },
- _ => None,
- })
- .collect::<GenericStringArray<T>>();
-
- Ok(Arc::new(result) as ArrayRef)
-}
-
-/// Extends the string to length 'length' by prepending the characters fill (a space by default). If the string is already longer than length then it is truncated (on the right).
-/// lpad('hi', 5, 'xy') = 'xyxhi'
-pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- match args.len() {
- 2 => {
- let string_array = as_generic_string_array::<T>(&args[0])?;
- let length_array = as_int64_array(&args[1])?;
-
- let result = string_array
- .iter()
- .zip(length_array.iter())
- .map(|(string, length)| match (string, length) {
- (Some(string), Some(length)) => {
- if length > i32::MAX as i64 {
- return exec_err!(
- "lpad requested length {length} too large"
- );
- }
-
- let length = if length < 0 { 0 } else { length as usize };
- if length == 0 {
- Ok(Some("".to_string()))
- } else {
- let graphemes = string.graphemes(true).collect::<Vec<&str>>();
- if length < graphemes.len() {
- Ok(Some(graphemes[..length].concat()))
- } else {
- let mut s: String = " ".repeat(length - graphemes.len());
- s.push_str(string);
- Ok(Some(s))
- }
- }
- }
- _ => Ok(None),
- })
- .collect::<Result<GenericStringArray<T>>>()?;
-
- Ok(Arc::new(result) as ArrayRef)
- }
- 3 => {
- let string_array = as_generic_string_array::<T>(&args[0])?;
- let length_array = as_int64_array(&args[1])?;
- let fill_array = as_generic_string_array::<T>(&args[2])?;
-
- let result = string_array
- .iter()
- .zip(length_array.iter())
- .zip(fill_array.iter())
- .map(|((string, length), fill)| match (string, length, fill) {
- (Some(string), Some(length), Some(fill)) => {
- if length > i32::MAX as i64 {
- return exec_err!(
- "lpad requested length {length} too large"
- );
- }
-
- let length = if length < 0 { 0 } else { length as usize };
- if length == 0 {
- Ok(Some("".to_string()))
- } else {
- let graphemes = string.graphemes(true).collect::<Vec<&str>>();
- let fill_chars = fill.chars().collect::<Vec<char>>();
-
- if length < graphemes.len() {
- Ok(Some(graphemes[..length].concat()))
- } else if fill_chars.is_empty() {
- Ok(Some(string.to_string()))
- } else {
- let mut s = string.to_string();
- let mut char_vector =
- Vec::<char>::with_capacity(length - graphemes.len());
- for l in 0..length - graphemes.len() {
- char_vector.push(
- *fill_chars.get(l % fill_chars.len()).unwrap(),
- );
- }
- s.insert_str(
- 0,
- char_vector.iter().collect::<String>().as_str(),
- );
- Ok(Some(s))
- }
- }
- }
- _ => Ok(None),
- })
- .collect::<Result<GenericStringArray<T>>>()?;
-
- Ok(Arc::new(result) as ArrayRef)
- }
- other => exec_err!(
- "lpad was called with {other} arguments. It requires at least 2 and at most 3."
- ),
- }
-}
-
-/// Reverses the order of the characters in the string.
-/// reverse('abcde') = 'edcba'
-/// The implementation uses UTF-8 code points as characters
-pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = as_generic_string_array::<T>(&args[0])?;
-
- let result = string_array
- .iter()
- .map(|string| string.map(|string: &str| string.chars().rev().collect::<String>()))
- .collect::<GenericStringArray<T>>();
-
- Ok(Arc::new(result) as ArrayRef)
-}
-
-/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters.
-/// right('abcde', 2) = 'de'
-/// The implementation uses UTF-8 code points as characters
-pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- let string_array = as_generic_string_array::<T>(&args[0])?;
- let n_array = as_int64_array(&args[1])?;
-
- let result = string_array
- .iter()
- .zip(n_array.iter())
- .map(|(string, n)| match (string, n) {
- (Some(string), Some(n)) => match n.cmp(&0) {
- Ordering::Less => Some(
- string
- .chars()
- .skip(n.unsigned_abs() as usize)
- .collect::<String>(),
- ),
- Ordering::Equal => Some("".to_string()),
- Ordering::Greater => Some(
- string
- .chars()
- .skip(max(string.chars().count() as i64 - n, 0) as usize)
- .collect::<String>(),
- ),
- },
- _ => None,
- })
- .collect::<GenericStringArray<T>>();
-
- Ok(Arc::new(result) as ArrayRef)
-}
-
-/// Extends the string to length 'length' by appending the characters fill (a space by default). If the string is already longer than length then it is truncated.
-/// rpad('hi', 5, 'xy') = 'hixyx'
-pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
- match args.len() {
- 2 => {
- let string_array = as_generic_string_array::<T>(&args[0])?;
- let length_array = as_int64_array(&args[1])?;
-
- let result = string_array
- .iter()
- .zip(length_array.iter())
- .map(|(string, length)| match (string, length) {
- (Some(string), Some(length)) => {
- if length > i32::MAX as i64 {
- return exec_err!(
- "rpad requested length {length} too large"
- );
- }
-
- let length = if length < 0 { 0 } else { length as usize };
- if length == 0 {
- Ok(Some("".to_string()))
- } else {
- let graphemes = string.graphemes(true).collect::<Vec<&str>>();
- if length < graphemes.len() {
- Ok(Some(graphemes[..length].concat()))
- } else {
- let mut s = string.to_string();
- s.push_str(" ".repeat(length - graphemes.len()).as_str());
- Ok(Some(s))
- }
- }
- }
- _ => Ok(None),
- })
- .collect::<Result<GenericStringArray<T>>>()?;
- Ok(Arc::new(result) as ArrayRef)
- }
- 3 => {
- let string_array = as_generic_string_array::<T>(&args[0])?;
- let length_array = as_int64_array(&args[1])?;
- let fill_array = as_generic_string_array::<T>(&args[2])?;
-
- let result = string_array
- .iter()
- .zip(length_array.iter())
- .zip(fill_array.iter())
- .map(|((string, length), fill)| match (string, length, fill) {
- (Some(string), Some(length), Some(fill)) => {
- if length > i32::MAX as i64 {
- return exec_err!(
- "rpad requested length {length} too large"
- );
- }
-
- let length = if length < 0 { 0 } else { length as usize };
- let graphemes = string.graphemes(true).collect::<Vec<&str>>();
- let fill_chars = fill.chars().collect::<Vec<char>>();
-
- if length < graphemes.len() {
- Ok(Some(graphemes[..length].concat()))
- } else if fill_chars.is_empty() {
- Ok(Some(string.to_string()))
- } else {
- let mut s = string.to_string();
- let mut char_vector =
- Vec::<char>::with_capacity(length - graphemes.len());
- for l in 0..length - graphemes.len() {
- char_vector
- .push(*fill_chars.get(l % fill_chars.len()).unwrap());
- }
- s.push_str(char_vector.iter().collect::<String>().as_str());
- Ok(Some(s))
- }
- }
- _ => Ok(None),
- })
- .collect::<Result<GenericStringArray<T>>>()?;
-
- Ok(Arc::new(result) as ArrayRef)
- }
- other => exec_err!(
- "rpad was called with {other} arguments. It requires at least 2 and at most 3."
- ),
- }
-}
-
 /// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
/// strpos('high', 'ig') = 2 /// The implementation uses UTF-8 code points as characters diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 766ca6633ee1..6319372d98d2 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -572,8 +572,8 @@ enum ScalarFunction { // 28 was DatePart // 29 was DateTrunc InitCap = 30; - Left = 31; - Lpad = 32; + // 31 was Left + // 32 was Lpad // 33 was Lower // 34 was Ltrim // 35 was MD5 @@ -583,9 +583,9 @@ enum ScalarFunction { // 39 was RegexpReplace // 40 was Repeat // 41 was Replace - Reverse = 42; - Right = 43; - Rpad = 44; + // 42 was Reverse + // 43 was Right + // 44 was Rpad // 45 was Rtrim // 46 was SHA224 // 47 was SHA256 diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f2814956ef1b..7281bc9dc263 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22931,12 +22931,7 @@ impl serde::Serialize for ScalarFunction { Self::Concat => "Concat", Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", - Self::Left => "Left", - Self::Lpad => "Lpad", Self::Random => "Random", - Self::Reverse => "Reverse", - Self::Right => "Right", - Self::Rpad => "Rpad", Self::Strpos => "Strpos", Self::Substr => "Substr", Self::Translate => "Translate", @@ -22990,12 +22985,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Concat", "ConcatWithSeparator", "InitCap", - "Left", - "Lpad", "Random", - "Reverse", - "Right", - "Rpad", "Strpos", "Substr", "Translate", @@ -23078,12 +23068,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Concat" => Ok(ScalarFunction::Concat), "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), - "Left" => Ok(ScalarFunction::Left), - "Lpad" => Ok(ScalarFunction::Lpad), "Random" => Ok(ScalarFunction::Random), - "Reverse" => Ok(ScalarFunction::Reverse), - "Right" => Ok(ScalarFunction::Right), - "Rpad" => Ok(ScalarFunction::Rpad), "Strpos" => Ok(ScalarFunction::Strpos), "Substr" => Ok(ScalarFunction::Substr), "Translate" => Ok(ScalarFunction::Translate), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index ecc94fcdaf99..2fe89efb9cea 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2871,8 +2871,8 @@ pub enum ScalarFunction { /// 28 was DatePart /// 29 was DateTrunc InitCap = 30, - Left = 31, - Lpad = 32, + /// 31 was Left + /// 32 was Lpad /// 33 was Lower /// 34 was Ltrim /// 35 was MD5 @@ -2882,9 +2882,9 @@ pub enum ScalarFunction { /// 39 was RegexpReplace /// 40 was Repeat /// 41 was Replace - Reverse = 42, - Right = 43, - Rpad = 44, + /// 42 was Reverse + /// 43 was Right + /// 44 was Rpad /// 45 was Rtrim /// 46 was SHA224 /// 47 was SHA256 @@ -3004,12 +3004,7 @@ impl ScalarFunction { ScalarFunction::Concat => "Concat", ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", - ScalarFunction::Left => "Left", - ScalarFunction::Lpad => "Lpad", ScalarFunction::Random => "Random", - ScalarFunction::Reverse => "Reverse", - ScalarFunction::Right => "Right", - ScalarFunction::Rpad => "Rpad", ScalarFunction::Strpos => "Strpos", ScalarFunction::Substr => "Substr", ScalarFunction::Translate => "Translate", @@ -3057,12 +3052,7 @@ impl ScalarFunction { "Concat" => Some(Self::Concat), "ConcatWithSeparator" => 
Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), - "Left" => Some(Self::Left), - "Lpad" => Some(Self::Lpad), "Random" => Some(Self::Random), - "Reverse" => Some(Self::Reverse), - "Right" => Some(Self::Right), - "Rpad" => Some(Self::Rpad), "Strpos" => Some(Self::Strpos), "Substr" => Some(Self::Substr), "Translate" => Some(Self::Translate), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 19edd71a3a80..2c6f2e479b24 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -17,18 +17,6 @@ use std::sync::Arc; -use crate::protobuf::{ - self, - plan_type::PlanTypeEnum::{ - AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, - FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, - InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, - OptimizedPhysicalPlan, - }, - AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, - OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, -}; - use arrow::{ array::AsArray, buffer::Buffer, @@ -38,6 +26,7 @@ use arrow::{ }, ipc::{reader::read_record_batch, root_as_message}, }; + use datafusion::execution::registry::FunctionRegistry; use datafusion_common::{ arrow_datafusion_err, internal_err, plan_datafusion_err, Column, Constraint, @@ -51,17 +40,29 @@ use datafusion_expr::{ acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, floor, gcd, initcap, iszero, lcm, left, ln, log, log10, log2, + factorial, find_in_set, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - lpad, nanvl, pi, power, radians, random, reverse, right, round, rpad, signum, sin, - sinh, sqrt, strpos, substr, substr_index, substring, translate, trunc, - AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, - Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, strpos, substr, + substr_index, substring, translate, trunc, AggregateFunction, Between, BinaryExpr, + BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, + GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, }; +use crate::protobuf::{ + self, + plan_type::PlanTypeEnum::{ + AnalyzedLogicalPlan, FinalAnalyzedLogicalPlan, FinalLogicalPlan, + FinalPhysicalPlan, FinalPhysicalPlanWithStats, InitialLogicalPlan, + InitialPhysicalPlan, InitialPhysicalPlanWithStats, OptimizedLogicalPlan, + OptimizedPhysicalPlan, + }, + AnalyzedLogicalPlanType, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, + OptimizedPhysicalPlanType, PlaceholderNode, RollupNode, +}; + use super::LogicalExtensionCodec; #[derive(Debug)] @@ -453,12 +454,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, - ScalarFunction::Left => Self::Left, - ScalarFunction::Lpad => Self::Lpad, ScalarFunction::Random => Self::Random, - ScalarFunction::Reverse => Self::Reverse, - ScalarFunction::Right => Self::Right, - ScalarFunction::Rpad => Self::Rpad, ScalarFunction::Strpos => 
Self::Strpos, ScalarFunction::Substr => Self::Substr, ScalarFunction::Translate => Self::Translate, @@ -1382,26 +1378,13 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Left => Ok(left( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Random => Ok(random()), - ScalarFunction::Reverse => { - Ok(reverse(parse_expr(&args[0], registry, codec)?)) - } - ScalarFunction::Right => Ok(right( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), ScalarFunction::Concat => { Ok(concat_expr(parse_exprs(args, registry, codec)?)) } ScalarFunction::ConcatWithSeparator => { Ok(concat_ws_expr(parse_exprs(args, registry, codec)?)) } - ScalarFunction::Lpad => Ok(lpad(parse_exprs(args, registry, codec)?)), - ScalarFunction::Rpad => Ok(rpad(parse_exprs(args, registry, codec)?)), ScalarFunction::EndsWith => Ok(ends_with( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 11fc7362c75d..ea682a5a22f8 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1445,12 +1445,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, - BuiltinScalarFunction::Left => Self::Left, - BuiltinScalarFunction::Lpad => Self::Lpad, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Reverse => Self::Reverse, - BuiltinScalarFunction::Right => Self::Right, - BuiltinScalarFunction::Rpad => Self::Rpad, BuiltinScalarFunction::Strpos => Self::Strpos, BuiltinScalarFunction::Substr => Self::Substr, BuiltinScalarFunction::Translate => Self::Translate, From 7f497b3b23d4aa2cb6336671d09b9c9837ed0d82 Mon Sep 17 00:00:00 2001 From: Mustafa Akur <106137913+mustafasrepo@users.noreply.github.com> Date: Fri, 29 Mar 2024 10:34:49 +0300 Subject: [PATCH 090/117] Add non-column expression equality tracking to filter exec (#9819) * Add non-column expression equality tracking to filter exec * Minor changes --- datafusion/physical-plan/src/filter.rs | 47 +++++++++---------- datafusion/physical-plan/src/lib.rs | 1 - datafusion/sqllogictest/test_files/select.slt | 21 +++++++++ 3 files changed, 44 insertions(+), 25 deletions(-) diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 2996152fb924..a9201f435ad8 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -29,7 +29,7 @@ use super::{ }; use crate::{ metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - Column, DisplayFormatType, ExecutionPlan, + DisplayFormatType, ExecutionPlan, }; use arrow::compute::filter_record_batch; @@ -192,9 +192,7 @@ impl FilterExec { let mut eq_properties = input.equivalence_properties().clone(); let (equal_pairs, _) = collect_columns_from_predicate(predicate); for (lhs, rhs) in equal_pairs { - let lhs_expr = Arc::new(lhs.clone()) as _; - let rhs_expr = Arc::new(rhs.clone()) as _; - eq_properties.add_equal_conditions(&lhs_expr, &rhs_expr) + eq_properties.add_equal_conditions(lhs, rhs) } // Add the columns that have only one viable value (singleton) after // filtering to constants. 
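+        // For example, after a predicate like `a = 5` the filtered statistics
+        // report column `a` as a singleton, so it is added as a constant here.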
@@ -405,34 +403,33 @@ impl RecordBatchStream for FilterExecStream {
 /// Return the equals Column-Pairs and Non-equals Column-Pairs
 fn collect_columns_from_predicate(predicate: &Arc<dyn PhysicalExpr>) -> EqualAndNonEqual {
- let mut eq_predicate_columns = Vec::<(&Column, &Column)>::new();
- let mut ne_predicate_columns = Vec::<(&Column, &Column)>::new();
+ let mut eq_predicate_columns = Vec::<PhysicalExprPairRef>::new();
+ let mut ne_predicate_columns = Vec::<PhysicalExprPairRef>::new();

 let predicates = split_conjunction(predicate);
 predicates.into_iter().for_each(|p| {
 if let Some(binary) = p.as_any().downcast_ref::<BinaryExpr>() {
- if let (Some(left_column), Some(right_column)) = (
- binary.left().as_any().downcast_ref::<Column>(),
- binary.right().as_any().downcast_ref::<Column>(),
- ) {
- match binary.op() {
- Operator::Eq => {
- eq_predicate_columns.push((left_column, right_column))
- }
- Operator::NotEq => {
- ne_predicate_columns.push((left_column, right_column))
- }
- _ => {}
+ match binary.op() {
+ Operator::Eq => {
+ eq_predicate_columns.push((binary.left(), binary.right()))
+ }
+ Operator::NotEq => {
+ ne_predicate_columns.push((binary.left(), binary.right()))
 }
+ _ => {}
 }
 }
 });
 (eq_predicate_columns, ne_predicate_columns)
 }
+
+/// Pair of `Arc<dyn PhysicalExpr>`s
+pub type PhysicalExprPairRef<'a> = (&'a Arc<dyn PhysicalExpr>, &'a Arc<dyn PhysicalExpr>);
+
 /// The equals Column-Pairs and Non-equals Column-Pairs in the Predicates
 pub type EqualAndNonEqual<'a> =
- (Vec<(&'a Column, &'a Column)>, Vec<(&'a Column, &'a Column)>);
+ (Vec<PhysicalExprPairRef<'a>>, Vec<PhysicalExprPairRef<'a>>);

 #[cfg(test)]
 mod tests {
@@ -482,14 +479,16 @@ mod tests {
 )?;

 let (equal_pairs, ne_pairs) = collect_columns_from_predicate(&predicate);
+ assert_eq!(2, equal_pairs.len());
+ assert!(equal_pairs[0].0.eq(&col("c2", &schema)?));
+ assert!(equal_pairs[0].1.eq(&lit(4u32)));

- assert_eq!(1, equal_pairs.len());
- assert_eq!(equal_pairs[0].0.name(), "c2");
- assert_eq!(equal_pairs[0].1.name(), "c9");
+ assert!(equal_pairs[1].0.eq(&col("c2", &schema)?));
+ assert!(equal_pairs[1].1.eq(&col("c9", &schema)?));

 assert_eq!(1, ne_pairs.len());
- assert_eq!(ne_pairs[0].0.name(), "c1");
- assert_eq!(ne_pairs[0].1.name(), "c13");
+ assert!(ne_pairs[0].0.eq(&col("c1", &schema)?));
+ assert!(ne_pairs[0].1.eq(&col("c13", &schema)?));

 Ok(())
 }
diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs
index 4b4b37f8b51b..3e8e439c9a38 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -33,7 +33,6 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::utils::DataPtr;
 use datafusion_common::Result;
 use datafusion_execution::TaskContext;
-use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::{
 EquivalenceProperties, LexOrdering, PhysicalSortExpr, PhysicalSortRequirement,
 };
diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt
index 3a5c6497ebd4..ad4b0df1a546 100644
--- a/datafusion/sqllogictest/test_files/select.slt
+++ b/datafusion/sqllogictest/test_files/select.slt
@@ -1386,6 +1386,27 @@ AggregateExec: mode=FinalPartitioned, gby=[c2@0 as c2], aggr=[COUNT(*)]
--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
----------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c2], has_header=true

+# FilterExec can track equality of non-column expressions.
+# The plan below shouldn't have a SortExec: given that column 'a' is ordered,
+# 'CAST(ROUND(b) as INT)' is also ordered once the filter is applied.
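+# (Equality here is tracked between whole physical expressions, not just
+# columns, so the ordering of column 'a' carries over to the cast expression.)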
+query TT
+EXPLAIN SELECT *
+FROM annotated_data_finite2
+WHERE CAST(ROUND(b) as INT) = a
+ORDER BY CAST(ROUND(b) as INT);
+----
+logical_plan
+Sort: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) ASC NULLS LAST
+--Filter: CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a
+----TableScan: annotated_data_finite2 projection=[a0, a, b, c, d], partial_filters=[CAST(round(CAST(annotated_data_finite2.b AS Float64)) AS Int32) = annotated_data_finite2.a]
+physical_plan
+SortPreservingMergeExec: [CAST(round(CAST(b@2 AS Float64)) AS Int32) ASC NULLS LAST]
+--CoalesceBatchesExec: target_batch_size=8192
+----FilterExec: CAST(round(CAST(b@2 AS Float64)) AS Int32) = a@1
+------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+--------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a0, a, b, c, d], output_ordering=[a@1 ASC NULLS LAST, b@2 ASC NULLS LAST, c@3 ASC NULLS LAST], has_header=true
+
+
 statement ok
 drop table annotated_data_finite2;

From d7957636327fb8d89e6428152492107e39d614b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com>
Date: Fri, 29 Mar 2024 14:28:41 +0300
Subject: [PATCH 091/117] datafusion-cli support for multiple commands in a
 single line (#9831)

* Multiple Create External Tables are supported from the CLI

* Handle in-quote semicolons

* add test
---
 datafusion-cli/src/exec.rs   | 32 +++++++------
 datafusion-cli/src/helper.rs | 91 ++++++++++++++++++++++++++++++++----
 2 files changed, 100 insertions(+), 23 deletions(-)

diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs
index 114e3cefa3bf..53375ab4104f 100644
--- a/datafusion-cli/src/exec.rs
+++ b/datafusion-cli/src/exec.rs
@@ -22,6 +22,7 @@ use std::fs::File;
 use std::io::prelude::*;
 use std::io::BufReader;

+use crate::helper::split_from_semicolon;
 use crate::print_format::PrintFormat;
 use crate::{
 command::{Command, OutputFormat},
@@ -164,21 +165,24 @@ pub async fn exec_from_repl(
 }
 }
 Ok(line) => {
- rl.add_history_entry(line.trim_end())?;
- tokio::select! {
- res = exec_and_print(ctx, print_options, line) => match res {
- Ok(_) => {}
- Err(err) => eprintln!("{err}"),
- },
- _ = signal::ctrl_c() => {
- println!("^C");
- continue
- },
+ let lines = split_from_semicolon(line);
+ for line in lines {
+ rl.add_history_entry(line.trim_end())?;
+ tokio::select!
{
+ res = exec_and_print(ctx, print_options, line) => match res {
+ Ok(_) => {}
+ Err(err) => eprintln!("{err}"),
+ },
+ _ = signal::ctrl_c() => {
+ println!("^C");
+ continue
+ },
+ }
+ // dialect might have changed
+ rl.helper_mut().unwrap().set_dialect(
+ &ctx.task_ctx().session_config().options().sql_parser.dialect,
+ );
+ }
 }
- // dialect might have changed
- rl.helper_mut().unwrap().set_dialect(
- &ctx.task_ctx().session_config().options().sql_parser.dialect,
- );
 }
 Err(ReadlineError::Interrupted) => {
 println!("^C");
diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs
index a8e149b4c5c6..8b196484ee2c 100644
--- a/datafusion-cli/src/helper.rs
+++ b/datafusion-cli/src/helper.rs
@@ -86,16 +86,23 @@ impl CliHelper {
 ))))
 }
 };
-
- match DFParser::parse_sql_with_dialect(&sql, dialect.as_ref()) {
- Ok(statements) if statements.is_empty() => Ok(ValidationResult::Invalid(
- Some(" 🤔 You entered an empty statement".to_string()),
- )),
- Ok(_statements) => Ok(ValidationResult::Valid(None)),
- Err(err) => Ok(ValidationResult::Invalid(Some(format!(
- " 🤔 Invalid statement: {err}",
- )))),
+ let lines = split_from_semicolon(sql);
+ for line in lines {
+ match DFParser::parse_sql_with_dialect(&line, dialect.as_ref()) {
+ Ok(statements) if statements.is_empty() => {
+ return Ok(ValidationResult::Invalid(Some(
+ " 🤔 You entered an empty statement".to_string(),
+ )));
+ }
+ Ok(_statements) => {}
+ Err(err) => {
+ return Ok(ValidationResult::Invalid(Some(format!(
+ " 🤔 Invalid statement: {err}",
+ ))));
+ }
+ }
 }
+ Ok(ValidationResult::Valid(None))
 } else if input.starts_with('\\') {
 // command
 Ok(ValidationResult::Valid(None))
@@ -197,6 +204,37 @@ pub fn unescape_input(input: &str) -> datafusion::error::Result<String> {
 Ok(result)
 }

+/// Splits a string which consists of multiple queries.
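+/// Semicolons inside single- or double-quoted strings are not treated as
+/// separators; e.g. `SELECT 1; SELECT 'a;b';` splits into `SELECT 1;` and
+/// `SELECT 'a;b';` (see the tests below).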
+pub(crate) fn split_from_semicolon(sql: String) -> Vec<String> {
+ let mut commands = Vec::new();
+ let mut current_command = String::new();
+ let mut in_single_quote = false;
+ let mut in_double_quote = false;
+
+ for c in sql.chars() {
+ if c == '\'' && !in_double_quote {
+ in_single_quote = !in_single_quote;
+ } else if c == '"' && !in_single_quote {
+ in_double_quote = !in_double_quote;
+ }
+
+ if c == ';' && !in_single_quote && !in_double_quote {
+ if !current_command.trim().is_empty() {
+ commands.push(format!("{};", current_command.trim()));
+ current_command.clear();
+ }
+ } else {
+ current_command.push(c);
+ }
+ }
+
+ if !current_command.trim().is_empty() {
+ commands.push(format!("{};", current_command.trim()));
+ }
+
+ commands
+}
+
 #[cfg(test)]
 mod tests {
 use std::io::{BufRead, Cursor};
@@ -292,4 +330,39 @@ mod tests {

 Ok(())
 }
+
+ #[test]
+ fn test_split_from_semicolon() {
+ let sql = "SELECT 1; SELECT 2;";
+ let expected = vec!["SELECT 1;", "SELECT 2;"];
+ assert_eq!(split_from_semicolon(sql.to_string()), expected);
+
+ let sql = r#"SELECT ";";"#;
+ let expected = vec![r#"SELECT ";";"#];
+ assert_eq!(split_from_semicolon(sql.to_string()), expected);
+
+ let sql = "SELECT ';';";
+ let expected = vec!["SELECT ';';"];
+ assert_eq!(split_from_semicolon(sql.to_string()), expected);
+
+ let sql = r#"SELECT 1; SELECT 'value;value'; SELECT 1 as "text;text";"#;
+ let expected = vec![
+ "SELECT 1;",
+ "SELECT 'value;value';",
+ r#"SELECT 1 as "text;text";"#,
+ ];
+ assert_eq!(split_from_semicolon(sql.to_string()), expected);
+
+ let sql = "";
+ let expected: Vec<String> = Vec::new();
+ assert_eq!(split_from_semicolon(sql.to_string()), expected);
+
+ let sql = "SELECT 1";
+ let expected = vec!["SELECT 1;"];
+ assert_eq!(split_from_semicolon(sql.to_string()), expected);
+
+ let sql = "SELECT 1; ";
+ let expected = vec!["SELECT 1;"];
+ assert_eq!(split_from_semicolon(sql.to_string()), expected);
+ }
 }

From 230a6b476804c0a8964d559cc16e41328a43efc5 Mon Sep 17 00:00:00 2001
From: Andrew Lamb <andrew@nerdnetworks.org>
Date: Fri, 29 Mar 2024 08:03:21 -0400
Subject: [PATCH 092/117] Add tests for filtering, grouping, aggregation of
 ARRAYs (#9695)

* Add tests for filtering, grouping, aggregation of ARRAYs

* Update output to correct results
---
 .../sqllogictest/test_files/array_query.slt  | 160 ++++++++++++++++++
 1 file changed, 160 insertions(+)
 create mode 100644 datafusion/sqllogictest/test_files/array_query.slt

diff --git a/datafusion/sqllogictest/test_files/array_query.slt b/datafusion/sqllogictest/test_files/array_query.slt
new file mode 100644
index 000000000000..24c99fc849b6
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/array_query.slt
+ +## Tests for basic array queries + +# Make a table with multiple input partitions +statement ok +CREATE TABLE data AS + SELECT * FROM (VALUES + ([1,2,3], [4,5], 1) + ) + UNION ALL + SELECT * FROM (VALUES + ([2,3], [2,3], 1), + ([1,2,3], NULL, 1) + ) +; + +query ??I rowsort +SELECT * FROM data; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +########### +# Filtering +########### + +query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +SELECT * FROM data WHERE column1 = [1,2,3]; + +query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) == List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +SELECT * FROM data WHERE column1 = column2 + +query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +SELECT * FROM data WHERE column1 != [1,2,3]; + +query error DataFusion error: Arrow error: Invalid argument error: Invalid comparison operation: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) != List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) +SELECT * FROM data WHERE column1 != column2 + +########### +# Aggregates +########### + +query error Internal error: Min/Max accumulator not implemented for type List +SELECT min(column1) FROM data; + +query error Internal error: Min/Max accumulator not implemented for type List +SELECT max(column1) FROM data; + +query I +SELECT count(column1) FROM data; +---- +3 + +# note single count distincts are rewritten to use a group by +query I +SELECT count(distinct column1) FROM data; +---- +2 + +query I +SELECT count(distinct column2) FROM data; +---- +2 + + +# note multiple count distincts are not rewritten +query II +SELECT count(distinct column1), count(distinct column2) FROM data; +---- +2 2 + + +########### +# GROUP BY +########### + + +query I +SELECT count(column1) FROM data GROUP BY column3; +---- +3 + +# note single count distincts are rewritten to use a group by +query I +SELECT count(distinct column1) FROM data GROUP BY column3; +---- +2 + +query I +SELECT count(distinct column2) FROM data GROUP BY column3; +---- +2 + +# note multiple count distincts are not rewritten +query II +SELECT count(distinct column1), count(distinct column2) FROM data GROUP BY column3; +---- +2 2 + + +########### +# ORDER BY +########### + +query ??I +SELECT * FROM data ORDER BY column2; +---- +[2, 3] [2, 3] 1 +[1, 2, 3] [4, 5] 1 +[1, 2, 3] NULL 1 + +query ??I +SELECT * FROM data ORDER BY column2 DESC; +---- +[1, 2, 3] NULL 1 +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 + +query ??I +SELECT * FROM data ORDER BY column2 DESC NULLS LAST; +---- +[1, 2, 3] [4, 5] 1 +[2, 3] [2, 3] 1 +[1, 2, 3] NULL 1 + +# multi column +query ??I +SELECT * FROM data ORDER BY 
column1, column2; +---- +[1, 2, 3] [4, 5] 1 +[1, 2, 3] NULL 1 +[2, 3] [2, 3] 1 + +query ??I +SELECT * FROM data ORDER BY column1, column3, column2; +---- +[1, 2, 3] [4, 5] 1 +[1, 2, 3] NULL 1 +[2, 3] [2, 3] 1 + + +statement ok +drop table data From aaad010e82d51d84c441f11f4359616fab39b960 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 29 Mar 2024 09:25:55 -0400 Subject: [PATCH 093/117] Remove vestigal conbench integration (#9855) --- conbench/.flake8 | 2 - conbench/.gitignore | 130 ----------------- conbench/.isort.cfg | 2 - conbench/README.md | 252 --------------------------------- conbench/_criterion.py | 98 ------------- conbench/benchmarks.json | 8 -- conbench/benchmarks.py | 41 ------ conbench/requirements-test.txt | 3 - conbench/requirements.txt | 1 - 9 files changed, 537 deletions(-) delete mode 100644 conbench/.flake8 delete mode 100755 conbench/.gitignore delete mode 100644 conbench/.isort.cfg delete mode 100644 conbench/README.md delete mode 100644 conbench/_criterion.py delete mode 100644 conbench/benchmarks.json delete mode 100644 conbench/benchmarks.py delete mode 100644 conbench/requirements-test.txt delete mode 100644 conbench/requirements.txt diff --git a/conbench/.flake8 b/conbench/.flake8 deleted file mode 100644 index e44b81084185..000000000000 --- a/conbench/.flake8 +++ /dev/null @@ -1,2 +0,0 @@ -[flake8] -ignore = E501 diff --git a/conbench/.gitignore b/conbench/.gitignore deleted file mode 100755 index aa44ee2adbd4..000000000000 --- a/conbench/.gitignore +++ /dev/null @@ -1,130 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - diff --git a/conbench/.isort.cfg b/conbench/.isort.cfg deleted file mode 100644 index f238bf7ea137..000000000000 --- a/conbench/.isort.cfg +++ /dev/null @@ -1,2 +0,0 @@ -[settings] -profile = black diff --git a/conbench/README.md b/conbench/README.md deleted file mode 100644 index f655ac8bd297..000000000000 --- a/conbench/README.md +++ /dev/null @@ -1,252 +0,0 @@ - - -# DataFusion + Conbench Integration - - -## Quick start - -``` -$ cd ~/arrow-datafusion/conbench/ -$ conda create -y -n conbench python=3.9 -$ conda activate conbench -(conbench) $ pip install -r requirements.txt -(conbench) $ conbench datafusion -``` - -## Example output - -``` -{ - "batch_id": "3c82f9d23fce49328b78ba9fd963b254", - "context": { - "benchmark_language": "Rust" - }, - "github": { - "commit": "e8c198b9fac6cd8822b950b9f71898e47965488d", - "repository": "https://github.com/dianaclarke/arrow-datafusion" - }, - "info": {}, - "machine_info": { - "architecture_name": "x86_64", - "cpu_core_count": "8", - "cpu_frequency_max_hz": "2400000000", - "cpu_l1d_cache_bytes": "65536", - "cpu_l1i_cache_bytes": "131072", - "cpu_l2_cache_bytes": "4194304", - "cpu_l3_cache_bytes": "0", - "cpu_model_name": "Apple M1", - "cpu_thread_count": "8", - "gpu_count": "0", - "gpu_product_names": [], - "kernel_name": "20.6.0", - "memory_bytes": "17179869184", - "name": "diana", - "os_name": "macOS", - "os_version": "10.16" - }, - "run_id": "ec2a50b9380c470b96d7eb7d63ab5b77", - "stats": { - "data": [ - "0.001532", - "0.001394", - "0.001333", - "0.001356", - "0.001379", - "0.001361", - "0.001307", - "0.001348", - "0.001436", - "0.001397", - "0.001339", - "0.001523", - "0.001593", - "0.001415", - "0.001344", - "0.001312", - "0.001402", - "0.001362", - "0.001329", - "0.001330", - "0.001447", - "0.001413", - "0.001536", - "0.001330", - "0.001333", - "0.001338", - "0.001333", - "0.001331", - "0.001426", - "0.001575", - "0.001362", - "0.001343", - "0.001334", - "0.001383", - "0.001476", - "0.001356", - "0.001362", - "0.001334", - "0.001390", - "0.001497", - "0.001330", - "0.001347", - "0.001331", - "0.001468", - "0.001377", - "0.001351", - "0.001328", - "0.001509", - "0.001338", - "0.001355", - "0.001332", - "0.001485", - "0.001370", - "0.001366", - "0.001507", - "0.001358", - "0.001331", - "0.001463", - "0.001362", - "0.001336", - "0.001428", - "0.001343", - "0.001359", - "0.001905", - "0.001726", - "0.001411", - "0.001433", - "0.001391", - "0.001453", - "0.001346", - "0.001339", - "0.001420", - "0.001330", - "0.001422", - "0.001683", - "0.001426", - "0.001349", - "0.001342", - "0.001430", - "0.001330", - "0.001436", - "0.001331", - "0.001415", - "0.001332", - "0.001408", - "0.001343", - "0.001392", - "0.001371", - "0.001655", - "0.001354", - "0.001438", - "0.001347", - "0.001341", - "0.001374", - "0.001453", - "0.001352", - "0.001358", - "0.001398", - "0.001362", - "0.001454" - ], - "iqr": "0.000088", - "iterations": 100, - "max": "0.001905", - "mean": "0.001401", - "median": "0.001362", - "min": "0.001307", - "q1": "0.001340", - "q3": "0.001428", - "stdev": "0.000095", - "time_unit": "s", - "times": [], - "unit": "s" - }, - 
"tags": { - "name": "aggregate_query_group_by", - "suite": "aggregate_query_group_by" - }, - "timestamp": "2022-02-09T01:32:55.769468+00:00" -} -``` - -## Debug with test benchmark - -``` -(conbench) $ cd ~/arrow-datafusion/conbench/ -(conbench) $ conbench test --iterations=3 - -Benchmark result: -{ - "batch_id": "41a144761bc24d82b94efa70d6e460b3", - "context": { - "benchmark_language": "Python" - }, - "github": { - "commit": "e8c198b9fac6cd8822b950b9f71898e47965488d", - "repository": "https://github.com/dianaclarke/arrow-datafusion" - }, - "info": { - "benchmark_language_version": "Python 3.9.7" - }, - "machine_info": { - "architecture_name": "x86_64", - "cpu_core_count": "8", - "cpu_frequency_max_hz": "2400000000", - "cpu_l1d_cache_bytes": "65536", - "cpu_l1i_cache_bytes": "131072", - "cpu_l2_cache_bytes": "4194304", - "cpu_l3_cache_bytes": "0", - "cpu_model_name": "Apple M1", - "cpu_thread_count": "8", - "gpu_count": "0", - "gpu_product_names": [], - "kernel_name": "20.6.0", - "memory_bytes": "17179869184", - "name": "diana", - "os_name": "macOS", - "os_version": "10.16" - }, - "run_id": "71f46362db8844afacea82cba119cefc", - "stats": { - "data": [ - "0.000001", - "0.000001", - "0.000000" - ], - "iqr": "0.000000", - "iterations": 3, - "max": "0.000001", - "mean": "0.000001", - "median": "0.000001", - "min": "0.000000", - "q1": "0.000000", - "q3": "0.000001", - "stdev": "0.000001", - "time_unit": "s", - "times": [], - "unit": "s" - }, - "tags": { - "name": "test" - }, - "timestamp": "2022-02-09T01:36:45.823615+00:00" -} -``` - diff --git a/conbench/_criterion.py b/conbench/_criterion.py deleted file mode 100644 index 168a1b9b6cb1..000000000000 --- a/conbench/_criterion.py +++ /dev/null @@ -1,98 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import collections -import csv -import os -import pathlib -import subprocess - -import conbench.runner -from conbench.machine_info import github_info - - -def _result_in_seconds(row): - # sample_measured_value - The value of the measurement for this sample. 
- # Note that this is the measured value for the whole sample, not the - # time-per-iteration To calculate the time-per-iteration, use - # sample_measured_value/iteration_count - # -- https://bheisler.github.io/criterion.rs/book/user_guide/csv_output.html - count = int(row["iteration_count"]) - sample = float(row["sample_measured_value"]) - return sample / count / 10**9 - - -def _parse_benchmark_group(row): - parts = row["group"].split(",") - if len(parts) > 1: - suite, name = parts[0], ",".join(parts[1:]) - else: - suite, name = row["group"], row["group"] - return suite.strip(), name.strip() - - -def _read_results(src_dir): - results = collections.defaultdict(lambda: collections.defaultdict(list)) - path = pathlib.Path(os.path.join(src_dir, "target", "criterion")) - for path in list(path.glob("**/new/raw.csv")): - with open(path) as csv_file: - reader = csv.DictReader(csv_file) - for row in reader: - suite, name = _parse_benchmark_group(row) - results[suite][name].append(_result_in_seconds(row)) - return results - - -def _execute_command(command): - try: - print(command) - result = subprocess.run(command, capture_output=True, check=True) - except subprocess.CalledProcessError as e: - print(e.stderr.decode("utf-8")) - raise e - return result.stdout.decode("utf-8"), result.stderr.decode("utf-8") - - -class CriterionBenchmark(conbench.runner.Benchmark): - external = True - - def run(self, **kwargs): - src_dir = os.path.join(os.getcwd(), "..") - self._cargo_bench(src_dir) - results = _read_results(src_dir) - for suite in results: - self.conbench.mark_new_batch() - for name, data in results[suite].items(): - yield self._record_result(suite, name, data, kwargs) - - def _cargo_bench(self, src_dir): - os.chdir(src_dir) - _execute_command(["cargo", "bench"]) - - def _record_result(self, suite, name, data, options): - tags = {"suite": suite} - result = {"data": data, "unit": "s"} - context = {"benchmark_language": "Rust"} - github = github_info() - return self.conbench.record( - result, - name, - tags=tags, - context=context, - github=github, - options=options, - ) diff --git a/conbench/benchmarks.json b/conbench/benchmarks.json deleted file mode 100644 index bb7033547722..000000000000 --- a/conbench/benchmarks.json +++ /dev/null @@ -1,8 +0,0 @@ -[ - { - "command": "datafusion", - "flags": { - "language": "Rust" - } - } -] diff --git a/conbench/benchmarks.py b/conbench/benchmarks.py deleted file mode 100644 index f80b3add90f9..000000000000 --- a/conbench/benchmarks.py +++ /dev/null @@ -1,41 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-
-import conbench.runner
-
-import _criterion
-
-
-@conbench.runner.register_benchmark
-class TestBenchmark(conbench.runner.Benchmark):
-    name = "test"
-
-    def run(self, **kwargs):
-        yield self.conbench.benchmark(
-            self._f(),
-            self.name,
-            options=kwargs,
-        )
-
-    def _f(self):
-        return lambda: 1 + 1
-
-
-@conbench.runner.register_benchmark
-class CargoBenchmarks(_criterion.CriterionBenchmark):
-    name = "datafusion"
-    description = "Run Arrow DataFusion micro benchmarks."
diff --git a/conbench/requirements-test.txt b/conbench/requirements-test.txt
deleted file mode 100644
index 5e5647acd2d6..000000000000
--- a/conbench/requirements-test.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-black
-flake8
-isort
diff --git a/conbench/requirements.txt b/conbench/requirements.txt
deleted file mode 100644
index a877c7b44e9b..000000000000
--- a/conbench/requirements.txt
+++ /dev/null
@@ -1 +0,0 @@
-conbench

From 2d023299fa2544350cb18b45181cc8aa729eda3f Mon Sep 17 00:00:00 2001
From: Renjie Liu
Date: Fri, 29 Mar 2024 21:38:43 +0800
Subject: [PATCH 094/117] feat: Add display_pg_json for LogicalPlan (#9789)

* feat: Add display_pg_json for LogicalPlan

* Fix lints

* Fix comments

* Fix format
---
 datafusion-cli/Cargo.lock                   |   1 +
 datafusion/expr/Cargo.toml                  |   1 +
 datafusion/expr/src/logical_plan/display.rs | 494 +++++++++++++++++++-
 datafusion/expr/src/logical_plan/plan.rs    |  82 ++++
 4 files changed, 577 insertions(+), 1 deletion(-)

diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 0277d23f4de0..2bbe89f24bbe 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1249,6 +1249,7 @@ dependencies = [
  "chrono",
  "datafusion-common",
  "paste",
+ "serde_json",
  "sqlparser",
  "strum 0.26.2",
  "strum_macros 0.26.2",
diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml
index 621a320230f2..6f6147d36883 100644
--- a/datafusion/expr/Cargo.toml
+++ b/datafusion/expr/Cargo.toml
@@ -43,6 +43,7 @@ arrow-array = { workspace = true }
 chrono = { workspace = true }
 datafusion-common = { workspace = true, default-features = true }
 paste = "^1.0"
+serde_json = { workspace = true }
 sqlparser = { workspace = true }
 strum = { version = "0.26.1", features = ["derive"] }
 strum_macros = "0.26.0"
diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs
index e0cb44626e24..edc3afd55d63 100644
--- a/datafusion/expr/src/logical_plan/display.rs
+++ b/datafusion/expr/src/logical_plan/display.rs
@@ -16,14 +16,22 @@
 // under the License.

 //! This module provides logic for displaying LogicalPlans in various styles
+use std::collections::HashMap;
 use std::fmt;

-use crate::LogicalPlan;
+use crate::{
+    expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr,
+    Filter, Join, Limit, LogicalPlan, Partitioning, Prepare, Projection, RecursiveQuery,
+    Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan,
+    Unnest, Values, Window,
+};

+use crate::dml::CopyTo;
 use arrow::datatypes::Schema;
 use datafusion_common::display::GraphvizBuilder;
 use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor};
 use datafusion_common::DataFusionError;
+use serde_json::json;

 /// Formats plans with a single line per node. For example:
 ///
@@ -221,6 +229,490 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> {
     }
 }

+/// Formats plans for display in the PostgreSQL plan JSON format.
+///
+/// There are already many existing visualizers for this format, for example [dalibo](https://explain.dalibo.com/).
+/// Unfortunately, there is no formal spec for this format, but it is widely used in the PostgreSQL community.
+///
+/// Here is an example of the format:
+///
+/// ```json
+/// [
+///     {
+///         "Plan": {
+///             "Node Type": "Sort",
+///             "Output": [
+///                 "question_1.id",
+///                 "question_1.title",
+///                 "question_1.text",
+///                 "question_1.file",
+///                 "question_1.type",
+///                 "question_1.source",
+///                 "question_1.exam_id"
+///             ],
+///             "Sort Key": [
+///                 "question_1.id"
+///             ],
+///             "Plans": [
+///                 {
+///                     "Node Type": "Seq Scan",
+///                     "Parent Relationship": "Left",
+///                     "Relation Name": "question",
+///                     "Schema": "public",
+///                     "Alias": "question_1",
+///                     "Output": [
+///                         "question_1.id",
+///                         "question_1.title",
+///                         "question_1.text",
+///                         "question_1.file",
+///                         "question_1.type",
+///                         "question_1.source",
+///                         "question_1.exam_id"
+///                     ],
+///                     "Filter": "(question_1.exam_id = 1)"
+///                 }
+///             ]
+///         }
+///     }
+/// ]
+/// ```
+pub struct PgJsonVisitor<'a, 'b> {
+    f: &'a mut fmt::Formatter<'b>,
+
+    /// A mapping from plan node id to the plan node json representation.
+    objects: HashMap<u32, serde_json::Value>,
+
+    next_id: u32,
+
+    /// If true, includes summarized schema information
+    with_schema: bool,
+
+    /// Holds the ids (as generated from `graphviz_builder`) of all
+    /// parent nodes
+    parent_ids: Vec<u32>,
+}
+
+impl<'a, 'b> PgJsonVisitor<'a, 'b> {
+    pub fn new(f: &'a mut fmt::Formatter<'b>) -> Self {
+        Self {
+            f,
+            objects: HashMap::new(),
+            next_id: 0,
+            with_schema: false,
+            parent_ids: Vec::new(),
+        }
+    }
+
+    /// Sets a flag which controls if the output schema is displayed
+    pub fn with_schema(&mut self, with_schema: bool) {
+        self.with_schema = with_schema;
+    }
+
+    /// Converts a logical plan node to a json object.
+    fn to_json_value(node: &LogicalPlan) -> serde_json::Value {
+        match node {
+            LogicalPlan::EmptyRelation(_) => {
+                json!({
+                    "Node Type": "EmptyRelation",
+                })
+            }
+            LogicalPlan::RecursiveQuery(RecursiveQuery { is_distinct, .. }) => {
+                json!({
+                    "Node Type": "RecursiveQuery",
+                    "Is Distinct": is_distinct,
+                })
+            }
+            LogicalPlan::Values(Values { ref values, .. }) => {
+                let str_values = values
+                    .iter()
+                    // limit to only 5 values to avoid horrible display
+                    .take(5)
+                    .map(|row| {
+                        let item = row
+                            .iter()
+                            .map(|expr| expr.to_string())
+                            .collect::<Vec<_>>()
+                            .join(", ");
+                        format!("({item})")
+                    })
+                    .collect::<Vec<_>>()
+                    .join(", ");
+
+                let elipse = if values.len() > 5 { "..." } else { "" };
+
+                let values_str = format!("{}{}", str_values, elipse);
+                json!({
+                    "Node Type": "Values",
+                    "Values": values_str
+                })
+            }
+            LogicalPlan::TableScan(TableScan {
+                ref source,
+                ref table_name,
+                ref filters,
+                ref fetch,
+                ..
+            }) => {
+                let mut object = json!({
+                    "Node Type": "TableScan",
+                    "Relation Name": table_name.table(),
+                });
+
+                if let Some(s) = table_name.schema() {
+                    object["Schema"] = serde_json::Value::String(s.to_string());
+                }
+
+                if let Some(c) = table_name.catalog() {
+                    object["Catalog"] = serde_json::Value::String(c.to_string());
+                }
+
+                if !filters.is_empty() {
+                    let mut full_filter = vec![];
+                    let mut partial_filter = vec![];
+                    let mut unsupported_filters = vec![];
+                    let filters: Vec<&Expr> = filters.iter().collect();
+
+                    if let Ok(results) = source.supports_filters_pushdown(&filters) {
+                        filters.iter().zip(results.iter()).for_each(
+                            |(x, res)| match res {
+                                TableProviderFilterPushDown::Exact => full_filter.push(x),
+                                TableProviderFilterPushDown::Inexact => {
+                                    partial_filter.push(x)
+                                }
+                                TableProviderFilterPushDown::Unsupported => {
+                                    unsupported_filters.push(x)
+                                }
+                            },
+                        );
+                    }
+
+                    if !full_filter.is_empty() {
+                        object["Full Filters"] = serde_json::Value::String(
+                            expr_vec_fmt!(full_filter).to_string(),
+                        );
+                    };
+                    if !partial_filter.is_empty() {
+                        object["Partial Filters"] = serde_json::Value::String(
+                            expr_vec_fmt!(partial_filter).to_string(),
+                        );
+                    }
+                    if !unsupported_filters.is_empty() {
+                        object["Unsupported Filters"] = serde_json::Value::String(
+                            expr_vec_fmt!(unsupported_filters).to_string(),
+                        );
+                    }
+                }
+
+                if let Some(f) = fetch {
+                    object["Fetch"] = serde_json::Value::Number((*f).into());
+                }
+
+                object
+            }
+            LogicalPlan::Projection(Projection { ref expr, .. }) => {
+                json!({
+                    "Node Type": "Projection",
+                    "Expressions": expr.iter().map(|e| e.to_string()).collect::<Vec<_>>()
+                })
+            }
+            LogicalPlan::Dml(DmlStatement { table_name, op, .. }) => {
+                json!({
+                    "Node Type": "Projection",
+                    "Operation": op.name(),
+                    "Table Name": table_name.table()
+                })
+            }
+            LogicalPlan::Copy(CopyTo {
+                input: _,
+                output_url,
+                format_options,
+                partition_by: _,
+                options,
+            }) => {
+                let op_str = options
+                    .iter()
+                    .map(|(k, v)| format!("{}={}", k, v))
+                    .collect::<Vec<_>>()
+                    .join(", ");
+                json!({
+                    "Node Type": "CopyTo",
+                    "Output URL": output_url,
+                    "Format Options": format!("{}", format_options),
+                    "Options": op_str
+                })
+            }
+            LogicalPlan::Ddl(ddl) => {
+                json!({
+                    "Node Type": "Ddl",
+                    "Operation": format!("{}", ddl.display())
+                })
+            }
+            LogicalPlan::Filter(Filter {
+                predicate: ref expr,
+                ..
+            }) => {
+                json!({
+                    "Node Type": "Filter",
+                    "Condition": format!("{}", expr)
+                })
+            }
+            LogicalPlan::Window(Window {
+                ref window_expr, ..
+            }) => {
+                json!({
+                    "Node Type": "WindowAggr",
+                    "Expressions": expr_vec_fmt!(window_expr)
+                })
+            }
+            LogicalPlan::Aggregate(Aggregate {
+                ref group_expr,
+                ref aggr_expr,
+                ..
+            }) => {
+                json!({
+                    "Node Type": "Aggregate",
+                    "Group By": expr_vec_fmt!(group_expr),
+                    "Aggregates": expr_vec_fmt!(aggr_expr)
+                })
+            }
+            LogicalPlan::Sort(Sort { expr, fetch, .. }) => {
+                let mut object = json!({
+                    "Node Type": "Sort",
+                    "Sort Key": expr_vec_fmt!(expr),
+                });
+
+                if let Some(fetch) = fetch {
+                    object["Fetch"] = serde_json::Value::Number((*fetch).into());
+                }
+
+                object
+            }
+            LogicalPlan::Join(Join {
+                on: ref keys,
+                filter,
+                join_constraint,
+                join_type,
+                ..
+            }) => {
+                let join_expr: Vec<String> =
+                    keys.iter().map(|(l, r)| format!("{l} = {r}")).collect();
+                let filter_expr = filter
+                    .as_ref()
+                    .map(|expr| format!(" Filter: {expr}"))
+                    .unwrap_or_else(|| "".to_string());
+                json!({
+                    "Node Type": format!("{} Join", join_type),
+                    "Join Constraint": format!("{:?}", join_constraint),
+                    "Join Keys": join_expr.join(", "),
+                    "Filter": format!("{}", filter_expr)
+                })
+            }
+            LogicalPlan::CrossJoin(_) => {
+                json!({
+                    "Node Type": "Cross Join"
+                })
+            }
+            LogicalPlan::Repartition(Repartition {
+                partitioning_scheme,
+                ..
+            }) => match partitioning_scheme {
+                Partitioning::RoundRobinBatch(n) => {
+                    json!({
+                        "Node Type": "Repartition",
+                        "Partitioning Scheme": "RoundRobinBatch",
+                        "Partition Count": n
+                    })
+                }
+                Partitioning::Hash(expr, n) => {
+                    let hash_expr: Vec<String> =
+                        expr.iter().map(|e| format!("{e}")).collect();
+
+                    json!({
+                        "Node Type": "Repartition",
+                        "Partitioning Scheme": "Hash",
+                        "Partition Count": n,
+                        "Partitioning Key": hash_expr
+                    })
+                }
+                Partitioning::DistributeBy(expr) => {
+                    let dist_by_expr: Vec<String> =
+                        expr.iter().map(|e| format!("{e}")).collect();
+                    json!({
+                        "Node Type": "Repartition",
+                        "Partitioning Scheme": "DistributeBy",
+                        "Partitioning Key": dist_by_expr
+                    })
+                }
+            },
+            LogicalPlan::Limit(Limit {
+                ref skip,
+                ref fetch,
+                ..
+            }) => {
+                let mut object = serde_json::json!(
+                    {
+                        "Node Type": "Limit",
+                        "Skip": skip,
+                    }
+                );
+                if let Some(f) = fetch {
+                    object["Fetch"] = serde_json::Value::Number((*f).into());
+                };
+                object
+            }
+            LogicalPlan::Subquery(Subquery { .. }) => {
+                json!({
+                    "Node Type": "Subquery"
+                })
+            }
+            LogicalPlan::SubqueryAlias(SubqueryAlias { ref alias, .. }) => {
+                json!({
+                    "Node Type": "Subquery",
+                    "Alias": alias.table(),
+                })
+            }
+            LogicalPlan::Statement(statement) => {
+                json!({
+                    "Node Type": "Statement",
+                    "Statement": format!("{}", statement.display())
+                })
+            }
+            LogicalPlan::Distinct(distinct) => match distinct {
+                Distinct::All(_) => {
+                    json!({
+                        "Node Type": "DistinctAll"
+                    })
+                }
+                Distinct::On(DistinctOn {
+                    on_expr,
+                    select_expr,
+                    sort_expr,
+                    ..
+                }) => {
+                    let mut object = json!({
+                        "Node Type": "DistinctOn",
+                        "On": expr_vec_fmt!(on_expr),
+                        "Select": expr_vec_fmt!(select_expr),
+                    });
+                    if let Some(sort_expr) = sort_expr {
+                        object["Sort"] = serde_json::Value::String(
+                            expr_vec_fmt!(sort_expr).to_string(),
+                        );
+                    }
+
+                    object
+                }
+            },
+            LogicalPlan::Explain { .. } => {
+                json!({
+                    "Node Type": "Explain"
+                })
+            }
+            LogicalPlan::Analyze { .. } => {
+                json!({
+                    "Node Type": "Analyze"
+                })
+            }
+            LogicalPlan::Union(_) => {
+                json!({
+                    "Node Type": "Union"
+                })
+            }
+            LogicalPlan::Extension(e) => {
+                json!({
+                    "Node Type": e.node.name(),
+                    "Detail": format!("{:?}", e.node)
+                })
+            }
+            LogicalPlan::Prepare(Prepare {
+                name, data_types, ..
+            }) => {
+                json!({
+                    "Node Type": "Prepare",
+                    "Name": name,
+                    "Data Types": format!("{:?}", data_types)
+                })
+            }
+            LogicalPlan::DescribeTable(DescribeTable { .. }) => {
+                json!({
+                    "Node Type": "DescribeTable"
+                })
+            }
+            LogicalPlan::Unnest(Unnest { column, .. }) => {
+                json!({
+                    "Node Type": "Unnest",
+                    "Column": format!("{}", column)
+                })
+            }
+        }
+    }
+}
+
+impl<'a, 'b> TreeNodeVisitor for PgJsonVisitor<'a, 'b> {
+    type Node = LogicalPlan;
+
+    fn f_down(
+        &mut self,
+        node: &LogicalPlan,
+    ) -> datafusion_common::Result<TreeNodeRecursion> {
+        let id = self.next_id;
+        self.next_id += 1;
+        let mut object = Self::to_json_value(node);
+
+        object["Plans"] = serde_json::Value::Array(vec![]);
+
+        if self.with_schema {
+            object["Output"] = serde_json::Value::Array(
+                node.schema()
+                    .fields()
+                    .iter()
+                    .map(|f| f.name().to_string())
+                    .map(serde_json::Value::String)
+                    .collect(),
+            );
+        };
+
+        self.objects.insert(id, object);
+        self.parent_ids.push(id);
+        Ok(TreeNodeRecursion::Continue)
+    }
+
+    fn f_up(
+        &mut self,
+        _node: &Self::Node,
+    ) -> datafusion_common::Result<TreeNodeRecursion> {
+        let id = self.parent_ids.pop().unwrap();
+
+        let current_node = self.objects.remove(&id).ok_or_else(|| {
+            DataFusionError::Internal("Missing current node!".to_string())
+        })?;
+
+        if let Some(parent_id) = self.parent_ids.last() {
+            let parent_node = self
+                .objects
+                .get_mut(parent_id)
+                .expect("Missing parent node!");
+            let plans = parent_node
+                .get_mut("Plans")
+                .and_then(|p| p.as_array_mut())
+                .expect("Plans should be an array");
+
+            plans.push(current_node);
+        } else {
+            // This is the root node
+            let plan = serde_json::json!([{"Plan": current_node}]);
+            write!(
+                self.f,
+                "{}",
+                serde_json::to_string_pretty(&plan)
+                    .map_err(|e| DataFusionError::External(Box::new(e)))?
+            )?;
+        }
+
+        Ok(TreeNodeRecursion::Continue)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use arrow::datatypes::{DataType, Field};
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 05d7ac539458..9f4094d483c9 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -54,6 +54,7 @@ use datafusion_common::{
 };

 // backwards compatibility
+use crate::display::PgJsonVisitor;
 pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
 pub use datafusion_common::{JoinConstraint, JoinType};

@@ -1302,6 +1303,26 @@ impl LogicalPlan {
         Wrapper(self)
     }

+    /// Return a displayable structure that produces a plan in the PostgreSQL JSON format.
+    ///
+    /// Users can use this format to visualize the plan in existing plan visualization tools, for example [dalibo](https://explain.dalibo.com/)
+    pub fn display_pg_json(&self) -> impl Display + '_ {
+        // Boilerplate structure to wrap LogicalPlan with something
+        // that can be formatted
+        struct Wrapper<'a>(&'a LogicalPlan);
+        impl<'a> Display for Wrapper<'a> {
+            fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+                let mut visitor = PgJsonVisitor::new(f);
+                visitor.with_schema(true);
+                match self.0.visit(&mut visitor) {
+                    Ok(_) => Ok(()),
+                    Err(_) => Err(fmt::Error),
+                }
+            }
+        }
+        Wrapper(self)
+    }
+
     /// Return a `format`able structure that produces lines meant for
     /// graphical display using the `DOT` language. This format can be
     /// visualized using software from
@@ -2781,6 +2802,67 @@ digraph {
         Ok(())
     }

+    #[test]
+    fn test_display_pg_json() -> Result<()> {
+        let plan = display_plan()?;
+
+        let expected_pg_json = r#"[
+  {
+    "Plan": {
+      "Expressions": [
+        "employee_csv.id"
+      ],
+      "Node Type": "Projection",
+      "Output": [
+        "id"
+      ],
+      "Plans": [
+        {
+          "Condition": "employee_csv.state IN (<subquery>)",
+          "Node Type": "Filter",
+          "Output": [
+            "id",
+            "state"
+          ],
+          "Plans": [
+            {
+              "Node Type": "Subquery",
+              "Output": [
+                "state"
+              ],
+              "Plans": [
+                {
+                  "Node Type": "TableScan",
+                  "Output": [
+                    "state"
+                  ],
+                  "Plans": [],
+                  "Relation Name": "employee_csv"
+                }
+              ]
+            },
+            {
+              "Node Type": "TableScan",
+              "Output": [
+                "id",
+                "state"
+              ],
+              "Plans": [],
+              "Relation Name": "employee_csv"
+            }
+          ]
+        }
+      ]
+    }
+  }
+]"#;
+
+        let pg_json = format!("{}", plan.display_pg_json());
+
+        assert_eq!(expected_pg_json, pg_json);
+        Ok(())
+    }
+
     /// Tests for the Visitor trait and walking logical plan nodes
     #[derive(Debug, Default)]
     struct OkVisitor {

From 2956ec2962d7af94be53243427f8795d29fa90a3 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Fri, 29 Mar 2024 09:39:27 -0400
Subject: [PATCH 095/117] Update `COPY` documentation to reflect changes (#9754)

---
 docs/source/user-guide/sql/dml.md | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/docs/source/user-guide/sql/dml.md b/docs/source/user-guide/sql/dml.md
index b9614bb8f929..79c36092fd3d 100644
--- a/docs/source/user-guide/sql/dml.md
+++ b/docs/source/user-guide/sql/dml.md
@@ -25,11 +25,14 @@ and modifying data in tables.
 ## COPY

 Copies the contents of a table or query to file(s). Supported file
-formats are `parquet`, `csv`, and `json` and can be inferred based on
-filename if writing to a single file.
+formats are `parquet`, `csv`, `json`, and `arrow`.
-COPY { table_name | query } TO 'file_name' [ ( option [, ... ] ) ]
+COPY { table_name | query } 
+TO 'file_name'
+[ STORED AS format ]
+[ PARTITIONED BY column_name [, ...] ]
+[ OPTIONS( option [, ... ] ) ]
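+
+For example, the clauses above can be combined; the following statement
+(illustrative table and column names) writes hive-partitioned parquet files
+with an explicit row group size:
+
+```sql
+COPY source_table TO 'dir_name'
+STORED AS PARQUET
+PARTITIONED BY (column1)
+OPTIONS (MAX_ROW_GROUP_SIZE 10000000);
+```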
 
For a detailed list of valid OPTIONS, see [Write Options](write_options). @@ -61,7 +64,7 @@ Copy the contents of `source_table` to multiple directories of hive-style partitioned parquet files: ```sql -> COPY source_table TO 'dir_name' (FORMAT parquet, partition_by 'column1, column2'); +> COPY source_table TO 'dir_name' STORED AS parquet, PARTITIONED BY (column1, column2); +-------+ | count | +-------+ @@ -74,7 +77,7 @@ results (maintaining the order) to a parquet file named `output.parquet` with a maximum parquet row group size of 10MB: ```sql -> COPY (SELECT * from source ORDER BY time) TO 'output.parquet' (ROW_GROUP_LIMIT_BYTES 10000000); +> COPY (SELECT * from source ORDER BY time) TO 'output.parquet' OPTIONS (MAX_ROW_GROUP_SIZE 10000000); +-------+ | count | +-------+ @@ -82,6 +85,12 @@ results (maintaining the order) to a parquet file named +-------+ ``` +The output format is determined by the first match of the following rules: + +1. Value of `STORED AS` +2. Value of the `OPTION (FORMAT ..)` +3. Filename extension (e.g. `foo.parquet` implies `PARQUET` format) + ## INSERT Insert values into a table. From f1adc68394ff378382dad5b36e89886139e902fa Mon Sep 17 00:00:00 2001 From: Marko Grujic Date: Fri, 29 Mar 2024 14:41:40 +0100 Subject: [PATCH 096/117] Remove the two cases most likely to cause OOM in CI (#9858) --- .github/workflows/pr_benchmarks.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/pr_benchmarks.yml b/.github/workflows/pr_benchmarks.yml index 29d001783b17..5827c42e85ae 100644 --- a/.github/workflows/pr_benchmarks.yml +++ b/.github/workflows/pr_benchmarks.yml @@ -47,7 +47,6 @@ jobs: ./bench.sh run tpch ./bench.sh run tpch_mem ./bench.sh run tpch10 - ./bench.sh run tpch_mem10 # For some reason this step doesn't seem to propagate the env var down into the script if [ -d "results/HEAD" ]; then @@ -70,7 +69,6 @@ jobs: ./bench.sh run tpch ./bench.sh run tpch_mem ./bench.sh run tpch10 - ./bench.sh run tpch_mem10 echo ${{ github.event.issue.number }} > pr From bf141dd113291615cbe986545969fae6368efa98 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Fri, 29 Mar 2024 09:42:37 -0400 Subject: [PATCH 097/117] Minor: make uuid an optional dependency on datafusion-functions (#9771) * Minor: make uuid an optional dependency on datafusion-functions * fix merge --- datafusion/functions/Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 0cab0276ff4b..3ae3061012e0 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -52,7 +52,7 @@ math_expressions = [] # enable regular expressions regex_expressions = ["regex"] # enable string functions -string_expressions = [] +string_expressions = ["uuid"] # enable unicode functions unicode_expressions = ["unicode-segmentation"] @@ -79,7 +79,7 @@ md-5 = { version = "^0.10.0", optional = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } -uuid = { version = "1.7", features = ["v4"] } +uuid = { version = "1.7", features = ["v4"], optional = true } [dev-dependencies] criterion = "0.5" From c202965c1740140fa4ff49364c99f3c4c9293182 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Fri, 29 Mar 2024 22:43:51 +0900 Subject: [PATCH 098/117] Add `Spice.ai` to Known Users (#9852) --- docs/source/user-guide/introduction.md | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 0e9d731c6e21..708318db4aba 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -116,6 +116,7 @@ Here are some active projects using DataFusion: - [Restate](https://github.com/restatedev) Easily build resilient applications using distributed durable async/await - [ROAPI](https://github.com/roapi/roapi) - [Seafowl](https://github.com/splitgraph/seafowl) CDN-friendly analytical database +- [Spice.ai](https://github.com/spiceai/spiceai) Unified SQL query interface & materialization engine - [Synnada](https://synnada.ai/) Streaming-first framework for data products - [VegaFusion](https://vegafusion.io/) Server-side acceleration for the [Vega](https://vega.github.io/) visualization grammar - [ZincObserve](https://github.com/zinclabs/zincobserve) Distributed cloud native observability platform @@ -146,6 +147,7 @@ Here are some less active projects that used DataFusion: [qv]: https://github.com/timvw/qv [roapi]: https://github.com/roapi/roapi [seafowl]: https://github.com/splitgraph/seafowl +[spice.ai]: https://github.com/spiceai/spiceai [synnada]: https://synnada.ai/ [tensorbase]: https://github.com/tensorbase/tensorbase [vegafusion]: https://vegafusion.io/ From 5ab5511db6f1715dea1f123cc40e480490443bca Mon Sep 17 00:00:00 2001 From: comphead Date: Fri, 29 Mar 2024 06:44:40 -0700 Subject: [PATCH 099/117] minor: add a hint how to adjust max rows displayed (#9845) --- datafusion-cli/src/print_options.rs | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/datafusion-cli/src/print_options.rs b/datafusion-cli/src/print_options.rs index 02cb0fb9c63e..93630c8d48f8 100644 --- a/datafusion-cli/src/print_options.rs +++ b/datafusion-cli/src/print_options.rs @@ -73,21 +73,22 @@ pub struct PrintOptions { pub color: bool, } -fn get_timing_info_str( +// Returns the query execution details formatted +fn get_execution_details_formatted( row_count: usize, maxrows: MaxRows, query_start_time: Instant, ) -> String { - let row_word = if row_count == 1 { "row" } else { "rows" }; let nrows_shown_msg = match maxrows { - MaxRows::Limited(nrows) if nrows < row_count => format!(" ({} shown)", nrows), + MaxRows::Limited(nrows) if nrows < row_count => { + format!("(First {nrows} displayed. Use --maxrows to adjust)") + } _ => String::new(), }; format!( - "{} {} in set{}. Query took {:.3} seconds.\n", + "{} row(s) fetched. 
{}\nElapsed {:.3} seconds.\n", row_count, - row_word, nrows_shown_msg, query_start_time.elapsed().as_secs_f64() ) @@ -107,7 +108,7 @@ impl PrintOptions { .print_batches(&mut writer, batches, self.maxrows, true)?; let row_count: usize = batches.iter().map(|b| b.num_rows()).sum(); - let timing_info = get_timing_info_str( + let formatted_exec_details = get_execution_details_formatted( row_count, if self.format == PrintFormat::Table { self.maxrows @@ -118,7 +119,7 @@ impl PrintOptions { ); if !self.quiet { - writeln!(writer, "{timing_info}")?; + writeln!(writer, "{formatted_exec_details}")?; } Ok(()) @@ -154,11 +155,14 @@ impl PrintOptions { with_header = false; } - let timing_info = - get_timing_info_str(row_count, MaxRows::Unlimited, query_start_time); + let formatted_exec_details = get_execution_details_formatted( + row_count, + MaxRows::Unlimited, + query_start_time, + ); if !self.quiet { - writeln!(writer, "{timing_info}")?; + writeln!(writer, "{formatted_exec_details}")?; } Ok(()) From 3eeb108125b35424baac39dd20ba88433b347419 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 29 Mar 2024 07:47:14 -0600 Subject: [PATCH 100/117] Exclude .github directory from release tarball (#9850) --- .gitattributes | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitattributes b/.gitattributes index 7ff0bbb6d959..bcdeffc09a11 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,3 @@ +.github/ export-ignore datafusion/proto/src/generated/prost.rs linguist-generated datafusion/proto/src/generated/pbjson.rs linguist-generated From c2879f510533a01bc04ef75da4f1416d0ddb99f6 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Fri, 29 Mar 2024 09:53:03 -0400 Subject: [PATCH 101/117] move strpos, substr functions to datafusion_functions (#9849) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Fixed missing trim() function. 
* Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function * move Left, Lpad, Reverse, Right, Rpad functions to datafusion_functions * move strpos, substr functions to datafusion_functions * Cleanup tests --- datafusion/expr/src/built_in_function.rs | 36 +- datafusion/expr/src/expr_fn.rs | 6 - datafusion/functions/src/unicode/mod.rs | 31 ++ datafusion/functions/src/unicode/strpos.rs | 121 ++++++ datafusion/functions/src/unicode/substr.rs | 392 ++++++++++++++++++ datafusion/physical-expr/src/functions.rs | 258 +----------- .../physical-expr/src/unicode_expressions.rs | 95 ----- datafusion/proto/Cargo.toml | 1 + datafusion/proto/proto/datafusion.proto | 4 +- datafusion/proto/src/generated/pbjson.rs | 6 - datafusion/proto/src/generated/prost.rs | 8 +- .../proto/src/logical_plan/from_proto.rs | 29 +- datafusion/proto/src/logical_plan/to_proto.rs | 2 - .../tests/cases/roundtrip_logical_plan.rs | 29 +- datafusion/proto/tests/cases/serialize.rs | 5 +- datafusion/sql/src/expr/mod.rs | 9 +- datafusion/sql/src/expr/substring.rs | 16 +- datafusion/sqllogictest/test_files/scalar.slt | 2 +- 18 files changed, 598 insertions(+), 452 deletions(-) create mode 100644 datafusion/functions/src/unicode/strpos.rs create mode 100644 datafusion/functions/src/unicode/substr.rs diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 196d278dc70e..423fc11c1d8c 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -113,10 +113,6 @@ pub enum BuiltinScalarFunction { InitCap, /// random Random, - /// strpos - Strpos, - /// substr - Substr, /// translate Translate, /// substr_index @@ -211,8 +207,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Strpos => Volatility::Immutable, - BuiltinScalarFunction::Substr => Volatility::Immutable, BuiltinScalarFunction::Translate => Volatility::Immutable, BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, BuiltinScalarFunction::FindInSet => Volatility::Immutable, @@ -252,12 +246,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::EndsWith => Ok(Boolean), - BuiltinScalarFunction::Strpos => { - utf8_to_int_type(&input_expr_types[0], "strpos/instr/position") - } - BuiltinScalarFunction::Substr => { - utf8_to_str_type(&input_expr_types[0], "substr") - } BuiltinScalarFunction::SubstrIndex => { utf8_to_str_type(&input_expr_types[0], "substr_index") } @@ -341,24 +329,12 @@ impl BuiltinScalarFunction { Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility()) } - BuiltinScalarFunction::EndsWith | BuiltinScalarFunction::Strpos => { - Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![Utf8, LargeUtf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![LargeUtf8, LargeUtf8]), - ], - self.volatility(), - ) - } - - BuiltinScalarFunction::Substr => Signature::one_of( + BuiltinScalarFunction::EndsWith => Signature::one_of( vec![ - Exact(vec![Utf8, Int64]), - Exact(vec![LargeUtf8, Int64]), - Exact(vec![Utf8, Int64, Int64]), - Exact(vec![LargeUtf8, Int64, Int64]), + Exact(vec![Utf8, Utf8]), + Exact(vec![Utf8, LargeUtf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![LargeUtf8, LargeUtf8]), ], self.volatility(), ), @@ -537,8 +513,6 @@ impl 
BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Strpos => &["strpos", "instr", "position"], - BuiltinScalarFunction::Substr => &["substr"], BuiltinScalarFunction::Translate => &["translate"], BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], BuiltinScalarFunction::FindInSet => &["find_in_set"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 21dab72855e5..09170ae639ff 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -579,9 +579,6 @@ scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); -scalar_expr!(Strpos, strpos, string substring, "finds the position from where the `substring` matches the `string`"); -scalar_expr!(Substr, substr, string position, "substring from the `position` to the end"); -scalar_expr!(Substr, substring, string position length, "substring from the `position` with `length` characters"); scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c @@ -1015,9 +1012,6 @@ mod test { test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); - test_scalar_expr!(Strpos, strpos, string, substring); - test_scalar_expr!(Substr, substr, string, position); - test_scalar_expr!(Substr, substring, string, position, count); test_scalar_expr!(Translate, translate, string, from, to); test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); test_scalar_expr!(FindInSet, find_in_set, string, stringlist); diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index ea4e70a92199..ddab0d1e27c9 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -27,6 +27,8 @@ mod lpad; mod reverse; mod right; mod rpad; +mod strpos; +mod substr; // create UDFs make_udf_function!( @@ -39,6 +41,8 @@ make_udf_function!(lpad::LPadFunc, LPAD, lpad); make_udf_function!(right::RightFunc, RIGHT, right); make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); make_udf_function!(rpad::RPadFunc, RPAD, rpad); +make_udf_function!(strpos::StrposFunc, STRPOS, strpos); +make_udf_function!(substr::SubstrFunc, SUBSTR, substr); pub mod expr_fn { use datafusion_expr::Expr; @@ -53,6 +57,11 @@ pub mod expr_fn { super::character_length().call(vec![string]) } + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn instr(string: Expr, substring: Expr) -> Expr { + strpos(string, substring) + } + #[doc = "the number of characters in the `string`"] pub fn length(string: Expr) -> Expr { character_length(string) @@ -68,6 +77,11 @@ pub mod expr_fn { super::lpad().call(args) } + #[doc = "finds the position from where the `substring` matches the `string`"] + pub fn position(string: Expr, substring: Expr) -> Expr { + strpos(string, 
substring)
+    }
+
     #[doc = "reverses the `string`"]
     pub fn reverse(string: Expr) -> Expr {
         super::reverse().call(vec![string])
@@ -82,6 +96,21 @@ pub mod expr_fn {
     pub fn rpad(args: Vec<Expr>) -> Expr {
         super::rpad().call(args)
     }
+
+    #[doc = "finds the position from where the `substring` matches the `string`"]
+    pub fn strpos(string: Expr, substring: Expr) -> Expr {
+        super::strpos().call(vec![string, substring])
+    }
+
+    #[doc = "substring from the `position` to the end"]
+    pub fn substr(string: Expr, position: Expr) -> Expr {
+        super::substr().call(vec![string, position])
+    }
+
+    #[doc = "substring from the `position` with `length` characters"]
+    pub fn substring(string: Expr, position: Expr, length: Expr) -> Expr {
+        super::substr().call(vec![string, position, length])
+    }
 }

 /// Return a list of all functions in this package
@@ -93,5 +122,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         reverse(),
         right(),
         rpad(),
+        strpos(),
+        substr(),
     ]
 }
diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs
new file mode 100644
index 000000000000..1e8bfa37d40e
--- /dev/null
+++ b/datafusion/functions/src/unicode/strpos.rs
@@ -0,0 +1,121 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::sync::Arc;
+
+use arrow::array::{
+    ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
+};
+use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
+
+use datafusion_common::cast::as_generic_string_array;
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::TypeSignature::Exact;
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+use crate::utils::{make_scalar_function, utf8_to_int_type};
+
+#[derive(Debug)]
+pub(super) struct StrposFunc {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl StrposFunc {
+    pub fn new() -> Self {
+        use DataType::*;
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    Exact(vec![Utf8, Utf8]),
+                    Exact(vec![Utf8, LargeUtf8]),
+                    Exact(vec![LargeUtf8, Utf8]),
+                    Exact(vec![LargeUtf8, LargeUtf8]),
+                ],
+                Volatility::Immutable,
+            ),
+            aliases: vec![String::from("instr"), String::from("position")],
+        }
+    }
+}
+
+impl ScalarUDFImpl for StrposFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "strpos"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        utf8_to_int_type(&arg_types[0], "strpos/instr/position")
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        match args[0].data_type() {
+            DataType::Utf8 => make_scalar_function(strpos::<Int32Type>, vec![])(args),
+            DataType::LargeUtf8 => {
+                make_scalar_function(strpos::<Int64Type>, vec![])(args)
+            }
+            other => exec_err!("Unsupported data type {other:?} for function strpos"),
+        }
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+}
+
+/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
+/// strpos('high', 'ig') = 2
+/// The implementation uses UTF-8 code points as characters
+fn strpos<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
+where
+    T::Native: OffsetSizeTrait,
+{
+    let string_array: &GenericStringArray<T::Native> =
+        as_generic_string_array::<T::Native>(&args[0])?;
+
+    let substring_array: &GenericStringArray<T::Native> =
+        as_generic_string_array::<T::Native>(&args[1])?;
+
+    let result = string_array
+        .iter()
+        .zip(substring_array.iter())
+        .map(|(string, substring)| match (string, substring) {
+            (Some(string), Some(substring)) => {
+                // the find method returns the byte index of the substring
+                // Next, we count the number of the chars until that byte
+                T::Native::from_usize(
+                    string
+                        .find(substring)
+                        .map(|x| string[..x].chars().count() + 1)
+                        .unwrap_or(0),
+                )
+            }
+            _ => None,
+        })
+        .collect::<PrimitiveArray<T>>();
+
+    Ok(Arc::new(result) as ArrayRef)
+}
diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs
new file mode 100644
index 000000000000..403157e2a85a
--- /dev/null
+++ b/datafusion/functions/src/unicode/substr.rs
@@ -0,0 +1,392 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::cmp::max;
+use std::sync::Arc;
+
+use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
+use arrow::datatypes::DataType;
+
+use datafusion_common::cast::{as_generic_string_array, as_int64_array};
+use datafusion_common::{exec_err, Result};
+use datafusion_expr::TypeSignature::Exact;
+use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+
+use crate::utils::{make_scalar_function, utf8_to_str_type};
+
+#[derive(Debug)]
+pub(super) struct SubstrFunc {
+    signature: Signature,
+}
+
+impl SubstrFunc {
+    pub fn new() -> Self {
+        use DataType::*;
+        Self {
+            signature: Signature::one_of(
+                vec![
+                    Exact(vec![Utf8, Int64]),
+                    Exact(vec![LargeUtf8, Int64]),
+                    Exact(vec![Utf8, Int64, Int64]),
+                    Exact(vec![LargeUtf8, Int64, Int64]),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SubstrFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "substr"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        utf8_to_str_type(&arg_types[0], "substr")
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        match args[0].data_type() {
+            DataType::Utf8 => make_scalar_function(substr::<i32>, vec![])(args),
+            DataType::LargeUtf8 => make_scalar_function(substr::<i64>, vec![])(args),
+            other => exec_err!("Unsupported data type {other:?} for function substr"),
+        }
+    }
+}
+
+/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
+/// substr('alphabet', 3) = 'phabet'
+/// substr('alphabet', 3, 2) = 'ph'
+/// The implementation uses UTF-8 code points as characters
+pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
+    match args.len() {
+        2 => {
+            let string_array = as_generic_string_array::<T>(&args[0])?;
+            let start_array = as_int64_array(&args[1])?;
+
+            let result = string_array
+                .iter()
+                .zip(start_array.iter())
+                .map(|(string, start)| match (string, start) {
+                    (Some(string), Some(start)) => {
+                        if start <= 0 {
+                            Some(string.to_string())
+                        } else {
+                            Some(string.chars().skip(start as usize - 1).collect())
+                        }
+                    }
+                    _ => None,
+                })
+                .collect::<GenericStringArray<T>>();
+
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        3 => {
+            let string_array = as_generic_string_array::<T>(&args[0])?;
+            let start_array = as_int64_array(&args[1])?;
+            let count_array = as_int64_array(&args[2])?;
+
+            let result = string_array
+                .iter()
+                .zip(start_array.iter())
+                .zip(count_array.iter())
+                .map(|((string, start), count)| match (string, start, count) {
+                    (Some(string), Some(start), Some(count)) => {
+                        if count < 0 {
+                            exec_err!(
+                                "negative substring length not allowed: substr(<str>, {start}, {count})"
+                            )
+                        } else {
+                            let skip = max(0, start - 1);
+                            let count = max(0, count + (if start < 1 { start - 1 } else { 0 }));
+                            Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::<String>()))
+                        }
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<GenericStringArray<T>>>()?;
+
+            Ok(Arc::new(result) as ArrayRef)
+        }
+        other => {
+            exec_err!("substr was called with {other} arguments.
It requires 2 or 3.") + } + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{exec_err, Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::substr::SubstrFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("ésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("joséésoj")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ], + Ok(Some("joséésoj")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("lphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(30i64)), + ], + Ok(Some("")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::Int64(None)), + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(2i64)), + ], + Ok(Some("ph")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::Scalar(ScalarValue::from(20i64)), + ], + Ok(Some("phabet")), + &str, + Utf8, + StringArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::Scalar(ScalarValue::from(5i64)), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from 5 (10 + -5) + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::Scalar(ScalarValue::from(10i64)), + ], + Ok(Some("alph")), + &str, + Utf8, + StringArray + ); + // starting from -1 (4 + -5) + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("alphabet")), + ColumnarValue::Scalar(ScalarValue::from(-5i64)), + 
ColumnarValue::Scalar(ScalarValue::from(4i64)),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        // starting from 0 (5 + -5)
+        test_function!(
+            SubstrFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from(-5i64)),
+                ColumnarValue::Scalar(ScalarValue::from(5i64)),
+            ],
+            Ok(Some("")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::Int64(None)),
+                ColumnarValue::Scalar(ScalarValue::from(20i64)),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from(3i64)),
+                ColumnarValue::Scalar(ScalarValue::Int64(None)),
+            ],
+            Ok(None),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from(1i64)),
+                ColumnarValue::Scalar(ScalarValue::from(-1i64)),
+            ],
+            exec_err!("negative substring length not allowed: substr(<str>, 1, -1)"),
+            &str,
+            Utf8,
+            StringArray
+        );
+        test_function!(
+            SubstrFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("joséésoj")),
+                ColumnarValue::Scalar(ScalarValue::from(5i64)),
+                ColumnarValue::Scalar(ScalarValue::from(2i64)),
+            ],
+            Ok(Some("és")),
+            &str,
+            Utf8,
+            StringArray
+        );
+        #[cfg(not(feature = "unicode_expressions"))]
+        test_function!(
+            SubstrFunc::new(),
+            &[
+                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
+                ColumnarValue::Scalar(ScalarValue::from(0i64)),
+            ],
+            internal_err!(
+                "function substr requires compilation with feature flag: unicode_expressions."
+            ),
+            &str,
+            Utf8,
+            StringArray
+        );
+
+        Ok(())
+    }
+}
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index c1b4900e399a..513dd71d4074 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -281,34 +281,6 @@ pub fn create_physical_fun(
             exec_err!("Unsupported data type {other:?} for function ends_with")
         }
     }),
-    BuiltinScalarFunction::Strpos => Arc::new(|args| match args[0].data_type() {
-        DataType::Utf8 => {
-            let func = invoke_if_unicode_expressions_feature_flag!(
-                strpos, Int32Type, "strpos"
-            );
-            make_scalar_function_inner(func)(args)
-        }
-        DataType::LargeUtf8 => {
-            let func = invoke_if_unicode_expressions_feature_flag!(
-                strpos, Int64Type, "strpos"
-            );
-            make_scalar_function_inner(func)(args)
-        }
-        other => exec_err!("Unsupported data type {other:?} for function strpos"),
-    }),
-    BuiltinScalarFunction::Substr => Arc::new(|args| match args[0].data_type() {
-        DataType::Utf8 => {
-            let func =
-                invoke_if_unicode_expressions_feature_flag!(substr, i32, "substr");
-            make_scalar_function_inner(func)(args)
-        }
-        DataType::LargeUtf8 => {
-            let func =
-                invoke_if_unicode_expressions_feature_flag!(substr, i64, "substr");
-            make_scalar_function_inner(func)(args)
-        }
-        other => exec_err!("Unsupported data type {other:?} for function substr"),
-    }),
     BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() {
         DataType::Utf8 => {
             let func = invoke_if_unicode_expressions_feature_flag!(
@@ -450,7 +422,7 @@ mod tests {
     };
 
     use datafusion_common::cast::as_uint64_array;
-    use datafusion_common::{exec_err, internal_err, plan_err};
+    use datafusion_common::{internal_err, plan_err};
     use datafusion_common::{DataFusionError, Result, ScalarValue};
     use
datafusion_expr::type_coercion::functions::data_types; use datafusion_expr::Signature; @@ -663,234 +635,6 @@ mod tests { BooleanArray ); #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(0))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(5))),], - Ok(Some("ésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("joséésoj"), lit(ScalarValue::Int64(Some(-5))),], - Ok(Some("joséésoj")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(1))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(2))),], - Ok(Some("lphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(3))),], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(-3))),], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(Some(30))),], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[lit("alphabet"), lit(ScalarValue::Int64(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(2))), - ], - Ok(Some("ph")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(3))), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(0))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some("alph")), - &str, - Utf8, - StringArray - ); - // starting from 5 (10 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(10))), - ], - Ok(Some("alph")), - &str, - Utf8, - StringArray - ); - // starting from -1 (4 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(4))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - // starting from 0 (5 + -5) - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(Some(-5))), - lit(ScalarValue::Int64(Some(5))), - ], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - lit(ScalarValue::Int64(None)), - lit(ScalarValue::Int64(Some(20))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Substr, - &[ - lit("alphabet"), - 
lit(ScalarValue::Int64(Some(3))),
-            lit(ScalarValue::Int64(None)),
-        ],
-        Ok(None),
-        &str,
-        Utf8,
-        StringArray
-    );
-    #[cfg(feature = "unicode_expressions")]
-    test_function!(
-        Substr,
-        &[
-            lit("alphabet"),
-            lit(ScalarValue::Int64(Some(1))),
-            lit(ScalarValue::Int64(Some(-1))),
-        ],
-        exec_err!("negative substring length not allowed: substr(<str>, 1, -1)"),
-        &str,
-        Utf8,
-        StringArray
-    );
-    #[cfg(feature = "unicode_expressions")]
-    test_function!(
-        Substr,
-        &[
-            lit("joséésoj"),
-            lit(ScalarValue::Int64(Some(5))),
-            lit(ScalarValue::Int64(Some(2))),
-        ],
-        Ok(Some("és")),
-        &str,
-        Utf8,
-        StringArray
-    );
-    #[cfg(not(feature = "unicode_expressions"))]
-    test_function!(
-        Substr,
-        &[
-            lit("alphabet"),
-            lit(ScalarValue::Int64(Some(0))),
-        ],
-        internal_err!(
-            "function substr requires compilation with feature flag: unicode_expressions."
-        ),
-        &str,
-        Utf8,
-        StringArray
-    );
-    #[cfg(feature = "unicode_expressions")]
     test_function!(
         Translate,
         &[lit("12345"), lit("143"), lit("ax"),],
diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs
index faff21111a61..ecbd1ea320d4 100644
--- a/datafusion/physical-expr/src/unicode_expressions.rs
+++ b/datafusion/physical-expr/src/unicode_expressions.rs
@@ -21,7 +21,6 @@
 
 //! Unicode expressions
 
-use std::cmp::max;
 use std::sync::Arc;
 
 use arrow::{
@@ -36,100 +35,6 @@ use datafusion_common::{
     exec_err, Result,
 };
 
-/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
-/// strpos('high', 'ig') = 2
-/// The implementation uses UTF-8 code points as characters
-pub fn strpos<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
-where
-    T::Native: OffsetSizeTrait,
-{
-    let string_array: &GenericStringArray<T::Native> =
-        as_generic_string_array::<T::Native>(&args[0])?;
-
-    let substring_array: &GenericStringArray<T::Native> =
-        as_generic_string_array::<T::Native>(&args[1])?;
-
-    let result = string_array
-        .iter()
-        .zip(substring_array.iter())
-        .map(|(string, substring)| match (string, substring) {
-            (Some(string), Some(substring)) => {
-                // the find method returns the byte index of the substring
-                // Next, we count the number of the chars until that byte
-                T::Native::from_usize(
-                    string
-                        .find(substring)
-                        .map(|x| string[..x].chars().count() + 1)
-                        .unwrap_or(0),
-                )
-            }
-            _ => None,
-        })
-        .collect::<PrimitiveArray<T>>();
-
-    Ok(Arc::new(result) as ArrayRef)
-}
-
-/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
-/// substr('alphabet', 3) = 'phabet'
-/// substr('alphabet', 3, 2) = 'ph'
-/// The implementation uses UTF-8 code points as characters
-pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
-    match args.len() {
-        2 => {
-            let string_array = as_generic_string_array::<T>(&args[0])?;
-            let start_array = as_int64_array(&args[1])?;
-
-            let result = string_array
-                .iter()
-                .zip(start_array.iter())
-                .map(|(string, start)| match (string, start) {
-                    (Some(string), Some(start)) => {
-                        if start <= 0 {
-                            Some(string.to_string())
-                        } else {
-                            Some(string.chars().skip(start as usize - 1).collect())
-                        }
-                    }
-                    _ => None,
-                })
-                .collect::<GenericStringArray<T>>();
-
-            Ok(Arc::new(result) as ArrayRef)
-        }
-        3 => {
-            let string_array = as_generic_string_array::<T>(&args[0])?;
-            let start_array = as_int64_array(&args[1])?;
-            let count_array = as_int64_array(&args[2])?;
-
-            let result = string_array
-                .iter()
-                .zip(start_array.iter())
-                .zip(count_array.iter())
-                .map(|((string, start), count)| match (string, start, count) {
-                    (Some(string), Some(start), Some(count)) => {
-                        if count < 0 {
-                            exec_err!(
-                                "negative substring length not allowed: substr(<str>, {start}, {count})"
-                            )
-                        } else {
-                            let skip = max(0, start - 1);
-                            let count = max(0, count + (if start < 1 { start - 1 } else { 0 }));
-                            Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::<String>()))
-                        }
-                    }
-                    _ => Ok(None),
-                })
-                .collect::<Result<GenericStringArray<T>>>()?;
-
-            Ok(Arc::new(result) as ArrayRef)
-        }
-        other => {
-            exec_err!("substr was called with {other} arguments. It requires 2 or 3.")
-        }
-    }
-}
-
 /// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
 /// translate('12345', '143', 'ax') = 'a2x5'
 pub fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
diff --git a/datafusion/proto/Cargo.toml b/datafusion/proto/Cargo.toml
index f5297aefcd1c..bec2b8c53a7a 100644
--- a/datafusion/proto/Cargo.toml
+++ b/datafusion/proto/Cargo.toml
@@ -54,6 +54,7 @@ serde = { version = "1.0", optional = true }
 serde_json = { workspace = true, optional = true }
 
 [dev-dependencies]
+datafusion-functions = { workspace = true, default-features = true }
 doc-comment = { workspace = true }
 strum = { version = "0.26.1", features = ["derive"] }
 tokio = { workspace = true, features = ["rt-multi-thread"] }
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 6319372d98d2..3a187eabe836 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -593,8 +593,8 @@ enum ScalarFunction {
   // 49 was SHA512
   // 50 was SplitPart
   // StartsWith = 51;
-  Strpos = 52;
-  Substr = 53;
+  // 52 was Strpos
+  // 53 was Substr
   // ToHex = 54;
   // 55 was ToTimestamp
   // 56 was ToTimestampMillis
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 7281bc9dc263..07b91b26d60b 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -22932,8 +22932,6 @@ impl serde::Serialize for ScalarFunction {
             Self::ConcatWithSeparator => "ConcatWithSeparator",
             Self::InitCap => "InitCap",
             Self::Random => "Random",
-            Self::Strpos => "Strpos",
-            Self::Substr => "Substr",
             Self::Translate => "Translate",
             Self::Coalesce => "Coalesce",
             Self::Power => "Power",
@@ -22986,8 +22984,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             "ConcatWithSeparator",
             "InitCap",
             "Random",
-            "Strpos",
-            "Substr",
             "Translate",
             "Coalesce",
             "Power",
@@
-23069,8 +23065,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), "Random" => Ok(ScalarFunction::Random), - "Strpos" => Ok(ScalarFunction::Strpos), - "Substr" => Ok(ScalarFunction::Substr), "Translate" => Ok(ScalarFunction::Translate), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 2fe89efb9cea..babeccec595f 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2892,8 +2892,8 @@ pub enum ScalarFunction { /// 49 was SHA512 /// 50 was SplitPart /// StartsWith = 51; - Strpos = 52, - Substr = 53, + /// 52 was Strpos + /// 53 was Substr /// ToHex = 54; /// 55 was ToTimestamp /// 56 was ToTimestampMillis @@ -3005,8 +3005,6 @@ impl ScalarFunction { ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", ScalarFunction::Random => "Random", - ScalarFunction::Strpos => "Strpos", - ScalarFunction::Substr => "Substr", ScalarFunction::Translate => "Translate", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", @@ -3053,8 +3051,6 @@ impl ScalarFunction { "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), "Random" => Some(Self::Random), - "Strpos" => Some(Self::Strpos), - "Substr" => Some(Self::Substr), "Translate" => Some(Self::Translate), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2c6f2e479b24..ff3d6773d512 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -42,10 +42,10 @@ use datafusion_expr::{ expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, logical_plan::{PlanType, StringifiedPlan}, - nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, strpos, substr, - substr_index, substring, translate, trunc, AggregateFunction, Between, BinaryExpr, - BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, - GetIndexedField, GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, substr_index, + translate, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, + BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, + GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -455,8 +455,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Random => Self::Random, - ScalarFunction::Strpos => Self::Strpos, - ScalarFunction::Substr => Self::Substr, ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Pi => Self::Pi, @@ -1389,25 +1387,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, )), - ScalarFunction::Strpos => Ok(strpos( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), - ScalarFunction::Substr => { - if args.len() > 2 { - assert_eq!(args.len(), 3); - Ok(substring( - parse_expr(&args[0], 
registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )) - } else { - Ok(substr( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )) - } - } ScalarFunction::Translate => Ok(translate( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, codec)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index ea682a5a22f8..89d49c5658a2 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1446,8 +1446,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Strpos => Self::Strpos, - BuiltinScalarFunction::Substr => Self::Substr, BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 3c43f100750f..3a47f556c0f3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -34,8 +34,8 @@ use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::config::{FormatOptions, TableOptions}; use datafusion_common::scalar::ScalarStructBuilder; use datafusion_common::{ - internal_err, not_impl_err, plan_err, DFField, DFSchema, DFSchemaRef, - DataFusionError, FileType, Result, ScalarValue, + internal_datafusion_err, internal_err, not_impl_err, plan_err, DFField, DFSchema, + DFSchemaRef, DataFusionError, FileType, Result, ScalarValue, }; use datafusion_expr::dml::CopyTo; use datafusion_expr::expr::{ @@ -44,8 +44,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ - col, create_udaf, lit, Accumulator, AggregateFunction, - BuiltinScalarFunction::{Sqrt, Substr}, + col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::Sqrt, ColumnarValue, Expr, ExprSchemable, LogicalPlan, Operator, PartitionEvaluator, ScalarUDF, ScalarUDFImpl, Signature, TryCast, Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, WindowUDF, @@ -60,6 +59,7 @@ use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; use datafusion_proto::protobuf; +use datafusion::execution::FunctionRegistry; use prost::Message; #[cfg(feature = "json")] @@ -1863,17 +1863,28 @@ fn roundtrip_cube() { #[test] fn roundtrip_substr() { + let ctx = SessionContext::new(); + + let fun = ctx + .state() + .udf("substr") + .map_err(|e| { + internal_datafusion_err!("Unable to find expected 'substr' function: {e:?}") + }) + .unwrap(); + // substr(string, position) - let test_expr = - Expr::ScalarFunction(ScalarFunction::new(Substr, vec![col("col"), lit(1_i64)])); + let test_expr = Expr::ScalarFunction(ScalarFunction::new_udf( + fun.clone(), + vec![col("col"), lit(1_i64)], + )); // substr(string, position, count) - let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new( - Substr, + let test_expr_with_count = Expr::ScalarFunction(ScalarFunction::new_udf( + fun, vec![col("col"), lit(1_i64), lit(1_i64)], )); - let ctx = 
SessionContext::new();
     roundtrip_expr_test(test_expr, ctx.clone());
     roundtrip_expr_test(test_expr_with_count, ctx);
 }
diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs
index d4a1ab44a6ea..972382b841d5 100644
--- a/datafusion/proto/tests/cases/serialize.rs
+++ b/datafusion/proto/tests/cases/serialize.rs
@@ -260,10 +260,7 @@ fn test_expression_serialization_roundtrip() {
     let lit = Expr::Literal(ScalarValue::Utf8(None));
     for builtin_fun in BuiltinScalarFunction::iter() {
         // default to 4 args (though some exprs like substr have error checking)
-        let num_args = match builtin_fun {
-            BuiltinScalarFunction::Substr => 3,
-            _ => 4,
-        };
+        let num_args = 4;
         let args: Vec<_> = std::iter::repeat(&lit).take(num_args).cloned().collect();
         let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, args));
 
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index d1fc03194997..43bf2d871564 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -823,12 +823,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         schema: &DFSchema,
         planner_context: &mut PlannerContext,
     ) -> Result<Expr> {
-        let fun = BuiltinScalarFunction::Strpos;
+        let fun = self
+            .context_provider
+            .get_function_meta("strpos")
+            .ok_or_else(|| {
+                internal_datafusion_err!("Unable to find expected 'strpos' function")
+            })?;
         let substr =
             self.sql_expr_to_logical_expr(substr_expr, schema, planner_context)?;
         let fullstr = self.sql_expr_to_logical_expr(str_expr, schema, planner_context)?;
         let args = vec![fullstr, substr];
-        Ok(Expr::ScalarFunction(ScalarFunction::new(fun, args)))
+        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args)))
     }
     fn sql_agg_with_filter_to_expr(
         &self,
diff --git a/datafusion/sql/src/expr/substring.rs b/datafusion/sql/src/expr/substring.rs
index a5d1abf0f265..f58c6f3b94d0 100644
--- a/datafusion/sql/src/expr/substring.rs
+++ b/datafusion/sql/src/expr/substring.rs
@@ -16,10 +16,10 @@
 // under the License.
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
-use datafusion_common::plan_err;
+use datafusion_common::{internal_datafusion_err, plan_err};
 use datafusion_common::{DFSchema, Result, ScalarValue};
 use datafusion_expr::expr::ScalarFunction;
-use datafusion_expr::{BuiltinScalarFunction, Expr};
+use datafusion_expr::Expr;
 use sqlparser::ast::Expr as SQLExpr;
 
 impl<'a, S: ContextProvider> SqlToRel<'a, S> {
@@ -68,9 +68,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             }
         };
 
-        Ok(Expr::ScalarFunction(ScalarFunction::new(
-            BuiltinScalarFunction::Substr,
-            args,
-        )))
+        let fun = self
+            .context_provider
+            .get_function_meta("substr")
+            .ok_or_else(|| {
+                internal_datafusion_err!("Unable to find expected 'substr' function")
+            })?;
+
+        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, args)))
     }
 }
diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt
index a77a2bf4059c..20c8b3d25fdd 100644
--- a/datafusion/sqllogictest/test_files/scalar.slt
+++ b/datafusion/sqllogictest/test_files/scalar.slt
@@ -2087,7 +2087,7 @@ select position('' in '')
 1
 
 
-query error DataFusion error: Error during planning: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64.
+query error DataFusion error: Execution error: The STRPOS/INSTR/POSITION function can only accept strings, but got Int64.
select position(1 in 1) From eb83e95bfaaacd991dc73757ea53851a906e11ac Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Fri, 29 Mar 2024 23:07:52 +0900 Subject: [PATCH 102/117] Add datafusion-federation to Integrations (#9853) --- docs/source/user-guide/introduction.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 708318db4aba..be15848407a2 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md @@ -141,7 +141,7 @@ Here are some less active projects that used DataFusion: [kamu]: https://github.com/kamu-data/kamu-cli [greptime db]: https://github.com/GreptimeTeam/greptimedb [horaedb]: https://github.com/apache/incubator-horaedb -[influxdb iox]: https://github.com/influxdata/influxdb_iox +[influxdb]: https://github.com/influxdata/influxdb [parseable]: https://github.com/parseablehq/parseable [prql-query]: https://github.com/prql/prql-query [qv]: https://github.com/timvw/qv @@ -169,6 +169,7 @@ provide integrations with other systems, some of which are described below: - [datafusion-bigtable](https://github.com/datafusion-contrib/datafusion-bigtable) - [datafusion-catalogprovider-glue](https://github.com/datafusion-contrib/datafusion-catalogprovider-glue) +- [datafusion-federation](https://github.com/datafusion-contrib/datafusion-federation) ## Why DataFusion? From 2e94e2fe97503c30c6efa0aa6ea88d2aaf50041b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 29 Mar 2024 10:14:34 -0400 Subject: [PATCH 103/117] chore(deps): update cargo requirement from 0.77.0 to 0.78.1 (#9844) Updates the requirements on [cargo](https://github.com/rust-lang/cargo) to permit the latest version. - [Changelog](https://github.com/rust-lang/cargo/blob/master/CHANGELOG.md) - [Commits](https://github.com/rust-lang/cargo/compare/0.77.0...0.78.1) --- updated-dependencies: - dependency-name: cargo dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- datafusion/core/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index de03579975a2..fbbe047880b1 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -130,7 +130,7 @@ zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] async-trait = { workspace = true } bigdecimal = { workspace = true } -cargo = "0.77.0" +cargo = "0.78.1" criterion = { version = "0.5", features = ["async_tokio"] } csv = "1.1.6" ctor = { workspace = true } From 179179c0b719a7f9e33d138ab728fdc2b0e1e1d8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:00:28 -0400 Subject: [PATCH 104/117] chore(deps-dev): bump webpack-dev-middleware (#9741) Bumps [webpack-dev-middleware](https://github.com/webpack/webpack-dev-middleware) from 5.3.3 to 5.3.4. - [Release notes](https://github.com/webpack/webpack-dev-middleware/releases) - [Changelog](https://github.com/webpack/webpack-dev-middleware/blob/v5.3.4/CHANGELOG.md) - [Commits](https://github.com/webpack/webpack-dev-middleware/compare/v5.3.3...v5.3.4) --- updated-dependencies: - dependency-name: webpack-dev-middleware dependency-type: indirect ... 
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
---
 .../wasmtest/datafusion-wasm-app/package-lock.json | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json
index 8b1b8ae079c2..7d324d074c9d 100644
--- a/datafusion/wasmtest/datafusion-wasm-app/package-lock.json
+++ b/datafusion/wasmtest/datafusion-wasm-app/package-lock.json
@@ -4053,9 +4053,9 @@
       }
     },
     "node_modules/webpack-dev-middleware": {
-      "version": "5.3.3",
-      "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.3.tgz",
-      "integrity": "sha512-hj5CYrY0bZLB+eTO+x/j67Pkrquiy7kWepMHmUMoPsmcUaeEnQJqFzHJOyxgWlq746/wUuA64p9ta34Kyb01pA==",
+      "version": "5.3.4",
+      "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz",
+      "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==",
       "dev": true,
       "dependencies": {
         "colorette": "^2.0.10",
@@ -7427,9 +7427,9 @@
       }
     },
     "webpack-dev-middleware": {
-      "version": "5.3.3",
-      "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.3.tgz",
-      "integrity": "sha512-hj5CYrY0bZLB+eTO+x/j67Pkrquiy7kWepMHmUMoPsmcUaeEnQJqFzHJOyxgWlq746/wUuA64p9ta34Kyb01pA==",
+      "version": "5.3.4",
+      "resolved": "https://registry.npmjs.org/webpack-dev-middleware/-/webpack-dev-middleware-5.3.4.tgz",
+      "integrity": "sha512-BVdTqhhs+0IfoeAf7EoH5WE+exCmqGerHfDM0IL096Px60Tq2Mn9MAbnaGUe6HiMa41KMCYF19gyzZmBcq/o4Q==",
       "dev": true,
       "requires": {
         "colorette": "^2.0.10",

From 21fe0b7762d088731689750e2cef1762d4f9db5e Mon Sep 17 00:00:00 2001
From: Eduard Karacharov <13005055+korowa@users.noreply.github.com>
Date: Sat, 30 Mar 2024 14:44:35 +0200
Subject: [PATCH 105/117] Implement semi/anti join output statistics estimation
 (#9800)

* semi/anti join output statistics

* fix antijoin cardinality estimation
---
 datafusion/physical-plan/src/joins/utils.rs | 373 +++++++++++++++++---
 1 file changed, 323 insertions(+), 50 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs
index 1cb2b100e2d6..a3d20b97d1ab 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -825,27 +825,27 @@ fn estimate_join_cardinality(
     right_stats: Statistics,
     on: &JoinOn,
 ) -> Option<PartialJoinStatistics> {
+    let (left_col_stats, right_col_stats) = on
+        .iter()
+        .map(|(left, right)| {
+            match (
+                left.as_any().downcast_ref::<Column>(),
+                right.as_any().downcast_ref::<Column>(),
+            ) {
+                (Some(left), Some(right)) => (
+                    left_stats.column_statistics[left.index()].clone(),
+                    right_stats.column_statistics[right.index()].clone(),
+                ),
+                _ => (
+                    ColumnStatistics::new_unknown(),
+                    ColumnStatistics::new_unknown(),
+                ),
+            }
+        })
+        .unzip::<_, _, Vec<_>, Vec<_>>();
+
     match join_type {
         JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
-            let (left_col_stats, right_col_stats) = on
-                .iter()
-                .map(|(left, right)| {
-                    match (
-                        left.as_any().downcast_ref::<Column>(),
-                        right.as_any().downcast_ref::<Column>(),
-                    ) {
-                        (Some(left), Some(right)) => (
-                            left_stats.column_statistics[left.index()].clone(),
-                            right_stats.column_statistics[right.index()].clone(),
-                        ),
-                        _ => (
-                            ColumnStatistics::new_unknown(),
-                            ColumnStatistics::new_unknown(),
-                        ),
-                    }
-                })
-                .unzip::<_, _, Vec<_>, Vec<_>>();
-
             let ij_cardinality = estimate_inner_join_cardinality(
                 Statistics {
                     num_rows: left_stats.num_rows.clone(),
@@ -888,10 +888,38 @@ fn estimate_join_cardinality(
             })
         }
 
-        JoinType::LeftSemi
-        | JoinType::RightSemi
-        | JoinType::LeftAnti
-        | JoinType::RightAnti => None,
+        // For SemiJoins estimation result is either zero, in cases when inputs
+        // are non-overlapping according to statistics, or equal to number of rows
+        // for outer input
+        JoinType::LeftSemi | JoinType::RightSemi => {
+            let (outer_stats, inner_stats) = match join_type {
+                JoinType::LeftSemi => (left_stats, right_stats),
+                _ => (right_stats, left_stats),
+            };
+            let cardinality = match estimate_disjoint_inputs(&outer_stats, &inner_stats) {
+                Some(estimation) => *estimation.get_value()?,
+                None => *outer_stats.num_rows.get_value()?,
+            };
+
+            Some(PartialJoinStatistics {
+                num_rows: cardinality,
+                column_statistics: outer_stats.column_statistics,
+            })
+        }
+
+        // For AntiJoins estimation always equals to outer statistics, as
+        // non-overlapping inputs won't affect estimation
+        JoinType::LeftAnti | JoinType::RightAnti => {
+            let outer_stats = match join_type {
+                JoinType::LeftAnti => left_stats,
+                _ => right_stats,
+            };
+
+            Some(PartialJoinStatistics {
+                num_rows: *outer_stats.num_rows.get_value()?,
+                column_statistics: outer_stats.column_statistics,
+            })
+        }
     }
 }
 
@@ -903,6 +931,11 @@ fn estimate_inner_join_cardinality(
     left_stats: Statistics,
     right_stats: Statistics,
 ) -> Option<Precision<usize>> {
+    // Immediatedly return if inputs considered as non-overlapping
+    if let Some(estimation) = estimate_disjoint_inputs(&left_stats, &right_stats) {
+        return Some(estimation);
+    };
+
     // The algorithm here is partly based on the non-histogram selectivity estimation
     // from Spark's Catalyst optimizer.
     let mut join_selectivity = Precision::Absent;
@@ -911,30 +944,13 @@ fn estimate_inner_join_cardinality(
         .iter()
         .zip(right_stats.column_statistics.iter())
     {
-        // If there is no overlap in any of the join columns, this means the join
-        // itself is disjoint and the cardinality is 0. Though we can only assume
-        // this when the statistics are exact (since it is a very strong assumption).
-        if left_stat.min_value.get_value()? > right_stat.max_value.get_value()? {
-            return Some(
-                if left_stat.min_value.is_exact().unwrap_or(false)
-                    && right_stat.max_value.is_exact().unwrap_or(false)
-                {
-                    Precision::Exact(0)
-                } else {
-                    Precision::Inexact(0)
-                },
-            );
-        }
-        if left_stat.max_value.get_value()? < right_stat.min_value.get_value()? {
-            return Some(
-                if left_stat.max_value.is_exact().unwrap_or(false)
-                    && right_stat.min_value.is_exact().unwrap_or(false)
-                {
-                    Precision::Exact(0)
-                } else {
-                    Precision::Inexact(0)
-                },
-            );
+        // Break if any of statistics bounds are undefined
+        if left_stat.min_value.get_value().is_none()
+            || left_stat.max_value.get_value().is_none()
+            || right_stat.min_value.get_value().is_none()
+            || right_stat.max_value.get_value().is_none()
+        {
+            return None;
         }
 
         let left_max_distinct = max_distinct_count(&left_stats.num_rows, left_stat);
@@ -968,6 +984,58 @@
     }
 }
 
+/// Estimates if inputs are non-overlapping, using input statistics.
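+/// (A worked example with made-up bounds: an exact left-side minimum of 30
+/// against an exact right-side maximum of 20 on a join key proves that no row
+/// can match, yielding Precision::Exact(0); inexact bounds yield Inexact(0).)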
+/// If inputs are disjoint, returns zero estimation, otherwise returns None
+fn estimate_disjoint_inputs(
+    left_stats: &Statistics,
+    right_stats: &Statistics,
+) -> Option<Precision<usize>> {
+    for (left_stat, right_stat) in left_stats
+        .column_statistics
+        .iter()
+        .zip(right_stats.column_statistics.iter())
+    {
+        // If there is no overlap in any of the join columns, this means the join
+        // itself is disjoint and the cardinality is 0. Though we can only assume
+        // this when the statistics are exact (since it is a very strong assumption).
+        let left_min_val = left_stat.min_value.get_value();
+        let right_max_val = right_stat.max_value.get_value();
+        if left_min_val.is_some()
+            && right_max_val.is_some()
+            && left_min_val > right_max_val
+        {
+            return Some(
+                if left_stat.min_value.is_exact().unwrap_or(false)
+                    && right_stat.max_value.is_exact().unwrap_or(false)
+                {
+                    Precision::Exact(0)
+                } else {
+                    Precision::Inexact(0)
+                },
+            );
+        }
+
+        let left_max_val = left_stat.max_value.get_value();
+        let right_min_val = right_stat.min_value.get_value();
+        if left_max_val.is_some()
+            && right_min_val.is_some()
+            && left_max_val < right_min_val
+        {
+            return Some(
+                if left_stat.max_value.is_exact().unwrap_or(false)
+                    && right_stat.min_value.is_exact().unwrap_or(false)
+                {
+                    Precision::Exact(0)
+                } else {
+                    Precision::Inexact(0)
+                },
+            );
+        }
+    }
+
+    None
+}
+
 /// Estimate the number of maximum distinct values that can be present in the
 /// given column from its statistics. If distinct_count is available, uses it
 /// directly. Otherwise, if the column is numeric and has min/max values, it
@@ -1716,9 +1784,11 @@ mod tests {
     #[test]
     fn test_inner_join_cardinality_single_column() -> Result<()> {
        let cases: Vec<(PartialStats, PartialStats, Option<Precision<usize>>)> = vec![
-            // -----------------------------------------------------------------------------
-            // | left(rows, min, max, distinct), right(rows, min, max, distinct), expected |
-            // -----------------------------------------------------------------------------
+            // ------------------------------------------------
+            // | left(rows, min, max, distinct, null_count),  |
+            // | right(rows, min, max, distinct, null_count), |
+            // | expected,                                    |
+            // ------------------------------------------------
 
             // Cardinality computation
             // =======================
@@ -1824,6 +1894,11 @@ mod tests {
                 None,
             ),
             // Non overlapping min/max (when exact=False).
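+            // The cases added below show that half-open (partially Absent)
+            // bounds are enough: an inexact max of 4 on one side against an
+            // inexact min of 5 on the other still proves disjointness.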
+            (
+                (10, Absent, Inexact(4), Absent, Absent),
+                (10, Inexact(5), Absent, Absent, Absent),
+                Some(Inexact(0)),
+            ),
             (
                 (10, Inexact(0), Inexact(10), Absent, Absent),
                 (10, Inexact(11), Inexact(20), Absent, Absent),
@@ -2106,6 +2181,204 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_anti_semi_join_cardinality() -> Result<()> {
+        let cases: Vec<(JoinType, PartialStats, PartialStats, Option<usize>)> = vec![
+            // ------------------------------------------------
+            // | join_type ,                                  |
+            // | left(rows, min, max, distinct, null_count),  |
+            // | right(rows, min, max, distinct, null_count), |
+            // | expected,                                    |
+            // ------------------------------------------------
+
+            // Cardinality computation
+            // =======================
+            (
+                JoinType::LeftSemi,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(15), Inexact(25), Absent, Absent),
+                Some(50),
+            ),
+            (
+                JoinType::RightSemi,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(15), Inexact(25), Absent, Absent),
+                Some(10),
+            ),
+            (
+                JoinType::LeftSemi,
+                (10, Absent, Absent, Absent, Absent),
+                (50, Absent, Absent, Absent, Absent),
+                Some(10),
+            ),
+            (
+                JoinType::LeftSemi,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(30), Inexact(40), Absent, Absent),
+                Some(0),
+            ),
+            (
+                JoinType::LeftSemi,
+                (50, Inexact(10), Absent, Absent, Absent),
+                (10, Absent, Inexact(5), Absent, Absent),
+                Some(0),
+            ),
+            (
+                JoinType::LeftSemi,
+                (50, Absent, Inexact(20), Absent, Absent),
+                (10, Inexact(30), Absent, Absent, Absent),
+                Some(0),
+            ),
+            (
+                JoinType::LeftAnti,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(15), Inexact(25), Absent, Absent),
+                Some(50),
+            ),
+            (
+                JoinType::RightAnti,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(15), Inexact(25), Absent, Absent),
+                Some(10),
+            ),
+            (
+                JoinType::LeftAnti,
+                (10, Absent, Absent, Absent, Absent),
+                (50, Absent, Absent, Absent, Absent),
+                Some(10),
+            ),
+            (
+                JoinType::LeftAnti,
+                (50, Inexact(10), Inexact(20), Absent, Absent),
+                (10, Inexact(30), Inexact(40), Absent, Absent),
+                Some(50),
+            ),
+            (
+                JoinType::LeftAnti,
+                (50, Inexact(10), Absent, Absent, Absent),
+                (10, Absent, Inexact(5), Absent, Absent),
+                Some(50),
+            ),
+            (
+                JoinType::LeftAnti,
+                (50, Absent, Inexact(20), Absent, Absent),
+                (10, Inexact(30), Absent, Absent, Absent),
+                Some(50),
+            ),
+        ];
+
+        let join_on = vec![(
+            Arc::new(Column::new("l_col", 0)) as _,
+            Arc::new(Column::new("r_col", 0)) as _,
+        )];
+
+        for (join_type, outer_info, inner_info, expected) in cases {
+            let outer_num_rows = outer_info.0;
+            let outer_col_stats = vec![create_column_stats(
+                outer_info.1,
+                outer_info.2,
+                outer_info.3,
+                outer_info.4,
+            )];
+
+            let inner_num_rows = inner_info.0;
+            let inner_col_stats = vec![create_column_stats(
+                inner_info.1,
+                inner_info.2,
+                inner_info.3,
+                inner_info.4,
+            )];
+
+            let output_cardinality = estimate_join_cardinality(
+                &join_type,
+                Statistics {
+                    num_rows: Inexact(outer_num_rows),
+                    total_byte_size: Absent,
+                    column_statistics: outer_col_stats,
+                },
+                Statistics {
+                    num_rows: Inexact(inner_num_rows),
+                    total_byte_size: Absent,
+                    column_statistics: inner_col_stats,
+                },
+                &join_on,
+            )
+            .map(|cardinality| cardinality.num_rows);
+
+            assert_eq!(
+                output_cardinality, expected,
+                "failure for join_type: {}",
+                join_type
+            );
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_semi_join_cardinality_absent_rows() -> Result<()> {
+        let dummy_column_stats =
+            vec![create_column_stats(Absent, Absent, Absent, Absent)];
+        let join_on = vec![(
+            Arc::new(Column::new("l_col",
0)) as _, + Arc::new(Column::new("r_col", 0)) as _, + )]; + + let absent_outer_estimation = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Absent, + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + Statistics { + num_rows: Exact(10), + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + &join_on, + ); + assert!( + absent_outer_estimation.is_none(), + "Expected \"None\" esimated SemiJoin cardinality for absent outer num_rows" + ); + + let absent_inner_estimation = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Inexact(500), + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + Statistics { + num_rows: Absent, + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + &join_on, + ).expect("Expected non-empty PartialJoinStatistics for SemiJoin with absent inner num_rows"); + + assert_eq!(absent_inner_estimation.num_rows, 500, "Expected outer.num_rows esimated SemiJoin cardinality for absent inner num_rows"); + + let absent_inner_estimation = estimate_join_cardinality( + &JoinType::LeftSemi, + Statistics { + num_rows: Absent, + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + Statistics { + num_rows: Absent, + total_byte_size: Absent, + column_statistics: dummy_column_stats.clone(), + }, + &join_on, + ); + assert!(absent_inner_estimation.is_none(), "Expected \"None\" esimated SemiJoin cardinality for absent outer and inner num_rows"); + + Ok(()) + } + #[test] fn test_calculate_join_output_ordering() -> Result<()> { let options = SortOptions::default(); From 078aeb6b8ade94689fc6c54143d4be9038929c4d Mon Sep 17 00:00:00 2001 From: Kunal Kundu Date: Sat, 30 Mar 2024 18:33:26 +0530 Subject: [PATCH 106/117] move Log2, Log10, Ln to datafusion-functions (#9869) * move log2 * move log10, ln * refactor log_b functions to use macro * update proto --- datafusion/expr/src/built_in_function.rs | 21 ------------------- datafusion/expr/src/expr_fn.rs | 6 ------ datafusion/functions/src/macros.rs | 10 +++++++-- datafusion/functions/src/math/mod.rs | 15 +++++++++---- datafusion/physical-expr/src/functions.rs | 3 --- datafusion/proto/proto/datafusion.proto | 6 +++--- datafusion/proto/src/generated/pbjson.rs | 9 -------- datafusion/proto/src/generated/prost.rs | 12 +++-------- .../proto/src/logical_plan/from_proto.rs | 10 +-------- datafusion/proto/src/logical_plan/to_proto.rs | 3 --- 10 files changed, 26 insertions(+), 69 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 423fc11c1d8c..f07e84027552 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -71,14 +71,8 @@ pub enum BuiltinScalarFunction { Lcm, /// iszero Iszero, - /// ln, Natural logarithm - Ln, /// log, same as log10 Log, - /// log10 - Log10, - /// log2 - Log2, /// nanvl Nanvl, /// pi @@ -187,10 +181,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Gcd => Volatility::Immutable, BuiltinScalarFunction::Iszero => Volatility::Immutable, BuiltinScalarFunction::Lcm => Volatility::Immutable, - BuiltinScalarFunction::Ln => Volatility::Immutable, BuiltinScalarFunction::Log => Volatility::Immutable, - BuiltinScalarFunction::Log10 => Volatility::Immutable, - BuiltinScalarFunction::Log2 => Volatility::Immutable, BuiltinScalarFunction::Nanvl => Volatility::Immutable, BuiltinScalarFunction::Pi => Volatility::Immutable, 
BuiltinScalarFunction::Power => Volatility::Immutable, @@ -292,9 +283,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Degrees | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 | BuiltinScalarFunction::Radians | BuiltinScalarFunction::Round | BuiltinScalarFunction::Signum @@ -412,9 +400,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Degrees | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 | BuiltinScalarFunction::Radians | BuiltinScalarFunction::Signum | BuiltinScalarFunction::Sin @@ -450,9 +435,6 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Floor - | BuiltinScalarFunction::Ln - | BuiltinScalarFunction::Log10 - | BuiltinScalarFunction::Log2 | BuiltinScalarFunction::Radians | BuiltinScalarFunction::Round | BuiltinScalarFunction::Signum @@ -490,10 +472,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Gcd => &["gcd"], BuiltinScalarFunction::Iszero => &["iszero"], BuiltinScalarFunction::Lcm => &["lcm"], - BuiltinScalarFunction::Ln => &["ln"], BuiltinScalarFunction::Log => &["log"], - BuiltinScalarFunction::Log10 => &["log10"], - BuiltinScalarFunction::Log2 => &["log2"], BuiltinScalarFunction::Nanvl => &["nanvl"], BuiltinScalarFunction::Pi => &["pi"], BuiltinScalarFunction::Power => &["power", "pow"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 09170ae639ff..e216e4e86dc1 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -570,9 +570,6 @@ scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) "); scalar_expr!(Exp, exp, num, "exponential"); scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor"); scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple"); -scalar_expr!(Log2, log2, num, "base 2 logarithm of number"); -scalar_expr!(Log10, log10, num, "base 10 logarithm of number"); -scalar_expr!(Ln, ln, num, "natural logarithm (base e) of number"); scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); @@ -1001,9 +998,6 @@ mod test { test_nary_scalar_expr!(Trunc, trunc, num, precision); test_unary_scalar_expr!(Signum, signum); test_unary_scalar_expr!(Exp, exp); - test_unary_scalar_expr!(Log2, log2); - test_unary_scalar_expr!(Log10, log10); - test_unary_scalar_expr!(Ln, ln); test_scalar_expr!(Atan2, atan2, y, x); test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index e735523df621..b23baeeacf23 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -157,14 +157,16 @@ macro_rules! downcast_arg { /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument macro_rules! 
make_math_unary_udf {
-    ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident) => {
+    ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $MONOTONICITY:expr) => {
         make_udf_function!($NAME::$UDF, $GNAME, $NAME);
 
         mod $NAME {
             use arrow::array::{ArrayRef, Float32Array, Float64Array};
             use arrow::datatypes::DataType;
             use datafusion_common::{exec_err, DataFusionError, Result};
-            use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
+            use datafusion_expr::{
+                ColumnarValue, FuncMonotonicity, ScalarUDFImpl, Signature, Volatility,
+            };
             use std::any::Any;
             use std::sync::Arc;
 
@@ -208,6 +210,10 @@ macro_rules! make_math_unary_udf {
                 }
             }
 
+            fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> {
+                Ok($MONOTONICITY)
+            }
+
             fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
                 let args = ColumnarValue::values_to_arrays(args)?;
 
diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs
index 27deb7d68427..3a4c1b1e8710 100644
--- a/datafusion/functions/src/math/mod.rs
+++ b/datafusion/functions/src/math/mod.rs
@@ -24,10 +24,14 @@ mod nans;
 make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
 make_udf_function!(abs::AbsFunc, ABS, abs);
 
-make_math_unary_udf!(TanhFunc, TANH, tanh, tanh);
-make_math_unary_udf!(AcosFunc, ACOS, acos, acos);
-make_math_unary_udf!(AsinFunc, ASIN, asin, asin);
-make_math_unary_udf!(TanFunc, TAN, tan, tan);
+make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)]));
+make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)]));
+make_math_unary_udf!(LnFunc, LN, ln, ln, Some(vec![Some(true)]));
+
+make_math_unary_udf!(TanhFunc, TANH, tanh, tanh, None);
+make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None);
+make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None);
+make_math_unary_udf!(TanFunc, TAN, tan, tan, None);
 
 // Export the functions out of this package, both as expr_fn as well as a list of functions
 export_functions!(
@@ -37,6 +41,9 @@ export_functions!(
         "returns true if a given number is +NaN or -NaN otherwise returns false"
     ),
     (abs, num, "returns the absolute value of a given number"),
+    (log2, num, "base 2 logarithm of a number"),
+    (log10, num, "base 10 logarithm of a number"),
+    (ln, num, "natural logarithm (base e) of a number"),
     (
         acos,
         num,
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 513dd71d4074..515511b15fbb 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -221,9 +221,6 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::Lcm => {
             Arc::new(|args| make_scalar_function_inner(math_expressions::lcm)(args))
         }
-        BuiltinScalarFunction::Ln => Arc::new(math_expressions::ln),
-        BuiltinScalarFunction::Log10 => Arc::new(math_expressions::log10),
-        BuiltinScalarFunction::Log2 => Arc::new(math_expressions::log2),
         BuiltinScalarFunction::Nanvl => {
             Arc::new(|args| make_scalar_function_inner(math_expressions::nanvl)(args))
         }
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 3a187eabe836..81451e40aa50 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -551,10 +551,10 @@ enum ScalarFunction {
   // 7 was Digest
   Exp = 8;
   Floor = 9;
-  Ln = 10;
+  // 10 was Ln
   Log = 11;
-  Log10 = 12;
-  Log2 = 13;
+  // 12 was Log10
+  // 13 was Log2
   Round = 14;
   Signum = 15;
   Sin = 16;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 07b91b26d60b..2949ab807e04 100644
---
a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22919,10 +22919,7 @@ impl serde::Serialize for ScalarFunction { Self::Cos => "Cos", Self::Exp => "Exp", Self::Floor => "Floor", - Self::Ln => "Ln", Self::Log => "Log", - Self::Log10 => "Log10", - Self::Log2 => "Log2", Self::Round => "Round", Self::Signum => "Signum", Self::Sin => "Sin", @@ -22971,10 +22968,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cos", "Exp", "Floor", - "Ln", "Log", - "Log10", - "Log2", "Round", "Signum", "Sin", @@ -23052,10 +23046,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cos" => Ok(ScalarFunction::Cos), "Exp" => Ok(ScalarFunction::Exp), "Floor" => Ok(ScalarFunction::Floor), - "Ln" => Ok(ScalarFunction::Ln), "Log" => Ok(ScalarFunction::Log), - "Log10" => Ok(ScalarFunction::Log10), - "Log2" => Ok(ScalarFunction::Log2), "Round" => Ok(ScalarFunction::Round), "Signum" => Ok(ScalarFunction::Signum), "Sin" => Ok(ScalarFunction::Sin), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index babeccec595f..6f7e8a9789a6 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2850,10 +2850,10 @@ pub enum ScalarFunction { /// 7 was Digest Exp = 8, Floor = 9, - Ln = 10, + /// 10 was Ln Log = 11, - Log10 = 12, - Log2 = 13, + /// 12 was Log10 + /// 13 was Log2 Round = 14, Signum = 15, Sin = 16, @@ -2992,10 +2992,7 @@ impl ScalarFunction { ScalarFunction::Cos => "Cos", ScalarFunction::Exp => "Exp", ScalarFunction::Floor => "Floor", - ScalarFunction::Ln => "Ln", ScalarFunction::Log => "Log", - ScalarFunction::Log10 => "Log10", - ScalarFunction::Log2 => "Log2", ScalarFunction::Round => "Round", ScalarFunction::Signum => "Signum", ScalarFunction::Sin => "Sin", @@ -3038,10 +3035,7 @@ impl ScalarFunction { "Cos" => Some(Self::Cos), "Exp" => Some(Self::Exp), "Floor" => Some(Self::Floor), - "Ln" => Some(Self::Ln), "Log" => Some(Self::Log), - "Log10" => Some(Self::Log10), - "Log2" => Some(Self::Log2), "Round" => Some(Self::Round), "Signum" => Some(Self::Signum), "Sin" => Some(Self::Sin), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index ff3d6773d512..d372cb428c73 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -40,7 +40,7 @@ use datafusion_expr::{ acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, floor, gcd, initcap, iszero, lcm, ln, log, log10, log2, + factorial, find_in_set, floor, gcd, initcap, iszero, lcm, log, logical_plan::{PlanType, StringifiedPlan}, nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, substr_index, translate, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, @@ -437,8 +437,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Atanh => Self::Atanh, ScalarFunction::Exp => Self::Exp, ScalarFunction::Log => Self::Log, - ScalarFunction::Ln => Self::Ln, - ScalarFunction::Log10 => Self::Log10, ScalarFunction::Degrees => Self::Degrees, ScalarFunction::Radians => Self::Radians, ScalarFunction::Factorial => Self::Factorial, @@ -449,7 +447,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Round => Self::Round, ScalarFunction::Trunc => Self::Trunc, ScalarFunction::Concat => 
Self::Concat, - ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, ScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, ScalarFunction::EndsWith => Self::EndsWith, @@ -1348,11 +1345,6 @@ pub fn parse_expr( ScalarFunction::Radians => { Ok(radians(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::Log2 => Ok(log2(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Ln => Ok(ln(parse_expr(&args[0], registry, codec)?)), - ScalarFunction::Log10 => { - Ok(log10(parse_expr(&args[0], registry, codec)?)) - } ScalarFunction::Floor => { Ok(floor(parse_expr(&args[0], registry, codec)?)) } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 89d49c5658a2..1e4e85c51f70 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1431,8 +1431,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Gcd => Self::Gcd, BuiltinScalarFunction::Lcm => Self::Lcm, BuiltinScalarFunction::Log => Self::Log, - BuiltinScalarFunction::Ln => Self::Ln, - BuiltinScalarFunction::Log10 => Self::Log10, BuiltinScalarFunction::Degrees => Self::Degrees, BuiltinScalarFunction::Radians => Self::Radians, BuiltinScalarFunction::Floor => Self::Floor, @@ -1440,7 +1438,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::Round => Self::Round, BuiltinScalarFunction::Trunc => Self::Trunc, BuiltinScalarFunction::Concat => Self::Concat, - BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, BuiltinScalarFunction::ConcatWithSeparator => Self::ConcatWithSeparator, BuiltinScalarFunction::EndsWith => Self::EndsWith, From d896000ad3111e46d1b3e53b03c9a10092b51e65 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 30 Mar 2024 12:54:26 -0400 Subject: [PATCH 107/117] Add CI compile checks for feature flags in datafusion-functions (#9772) * Add CI checks for feature flags * Fix builds * Move function benchmark to datafusion-functions crate * Less aggressive ci checks * Improve doc * Fix compilation of datafusion-array * toml format * Update datafusion/functions-array/benches/array_expression.rs --- .github/workflows/rust.yml | 59 ++++++++++++++----- datafusion/core/Cargo.toml | 7 +-- datafusion/functions-array/Cargo.toml | 7 +++ .../benches/array_expression.rs | 4 +- datafusion/functions/Cargo.toml | 4 ++ 5 files changed, 60 insertions(+), 21 deletions(-) rename datafusion/{core => functions-array}/benches/array_expression.rs (93%) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 07c46351e9ac..ffd45b9777ef 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -65,42 +65,73 @@ jobs: # this key equals the ones on `linux-build-lib` for re-use key: cargo-cache-benchmark-${{ hashFiles('datafusion/**/Cargo.toml', 'benchmarks/Cargo.toml', 'datafusion-cli/Cargo.toml') }} - - name: Check workspace without default features + - name: Check datafusion without default features + # Some of the test binaries require the parquet feature still + #run: cargo check --all-targets --no-default-features -p datafusion run: cargo check --no-default-features -p datafusion - name: Check datafusion-common without default features - run: cargo check --tests --no-default-features -p datafusion-common + run: cargo check --all-targets --no-default-features -p datafusion-common + + - name: Check datafusion-functions + run: cargo check 
--all-targets --no-default-features -p datafusion-functions - name: Check workspace in debug mode run: cargo check - - name: Check workspace with all features + - name: Check workspace with avro,json features run: cargo check --workspace --benches --features avro,json + - name: Check Cargo.lock for datafusion-cli + run: | + # If this test fails, try running `cargo update` in the `datafusion-cli` directory + # and check in the updated Cargo.lock file. + cargo check --manifest-path datafusion-cli/Cargo.toml --locked + # Ensure that the datafusion crate can be built with only a subset of the function # packages enabled. - - name: Check function packages (array_expressions) + - name: Check datafusion (array_expressions) run: cargo check --no-default-features --features=array_expressions -p datafusion - - name: Check function packages (datetime_expressions) + - name: Check datafusion (crypto) + run: cargo check --no-default-features --features=crypto_expressions -p datafusion + + - name: Check datafusion (datetime_expressions) run: cargo check --no-default-features --features=datetime_expressions -p datafusion - - name: Check function packages (encoding_expressions) + - name: Check datafusion (encoding_expressions) run: cargo check --no-default-features --features=encoding_expressions -p datafusion - - name: Check function packages (math_expressions) + - name: Check datafusion (math_expressions) run: cargo check --no-default-features --features=math_expressions -p datafusion - - name: Check function packages (regex_expressions) + - name: Check datafusion (regex_expressions) run: cargo check --no-default-features --features=regex_expressions -p datafusion - - name: Check Cargo.lock for datafusion-cli - run: | - # If this test fails, try running `cargo update` in the `datafusion-cli` directory - # and check in the updated Cargo.lock file. - cargo check --manifest-path datafusion-cli/Cargo.toml --locked + - name: Check datafusion (string_expressions) + run: cargo check --no-default-features --features=string_expressions -p datafusion + + # Ensure that the datafusion-functions crate can be built with only a subset of the function + # packages enabled. 
+ - name: Check datafusion-functions (crypto) + run: cargo check --all-targets --no-default-features --features=crypto_expressions -p datafusion-functions + + - name: Check datafusion-functions (datetime_expressions) + run: cargo check --all-targets --no-default-features --features=datetime_expressions -p datafusion-functions + + - name: Check datafusion-functions (encoding_expressions) + run: cargo check --all-targets --no-default-features --features=encoding_expressions -p datafusion-functions + + - name: Check datafusion-functions (math_expressions) + run: cargo check --all-targets --no-default-features --features=math_expressions -p datafusion-functions + + - name: Check datafusion-functions (regex_expressions) + run: cargo check --all-targets --no-default-features --features=regex_expressions -p datafusion-functions + + - name: Check datafusion-functions (string_expressions) + run: cargo check --all-targets --no-default-features --features=string_expressions -p datafusion-functions - # test the crate + # Run tests linux-test: name: cargo test (amd64) needs: [ linux-build-lib ] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index fbbe047880b1..18946334dbf5 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -50,6 +50,7 @@ default = [ "datetime_expressions", "encoding_expressions", "regex_expressions", + "string_expressions", "unicode_expressions", "compression", "parquet", @@ -66,6 +67,7 @@ regex_expressions = [ "datafusion-functions/regex_expressions", ] serde = ["arrow-schema/serde"] +string_expressions = ["datafusion-functions/string_expressions"] unicode_expressions = [ "datafusion-physical-expr/unicode_expressions", "datafusion-optimizer/unicode_expressions", @@ -188,6 +190,7 @@ name = "physical_plan" [[bench]] harness = false name = "parquet_query_sql" +required-features = ["parquet"] [[bench]] harness = false @@ -204,7 +207,3 @@ name = "sort" [[bench]] harness = false name = "topk_aggregate" - -[[bench]] -harness = false -name = "array_expression" diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index 80c0e5e18768..6ef9c6b055af 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -49,3 +49,10 @@ datafusion-functions = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" + +[dev-dependencies] +criterion = { version = "0.5", features = ["async_tokio"] } + +[[bench]] +harness = false +name = "array_expression" diff --git a/datafusion/core/benches/array_expression.rs b/datafusion/functions-array/benches/array_expression.rs similarity index 93% rename from datafusion/core/benches/array_expression.rs rename to datafusion/functions-array/benches/array_expression.rs index c980329620aa..48b829793cef 100644 --- a/datafusion/core/benches/array_expression.rs +++ b/datafusion/functions-array/benches/array_expression.rs @@ -18,12 +18,10 @@ #[macro_use] extern crate criterion; extern crate arrow; -extern crate datafusion; -mod data_utils; use crate::criterion::Criterion; -use datafusion::functions_array::expr_fn::{array_replace_all, make_array}; use datafusion_expr::lit; +use datafusion_functions_array::expr_fn::{array_replace_all, make_array}; fn criterion_benchmark(c: &mut Criterion) { // Construct large arrays for benchmarking diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 3ae3061012e0..51452b9d4ca1 100644 --- a/datafusion/functions/Cargo.toml +++ 
b/datafusion/functions/Cargo.toml @@ -90,15 +90,19 @@ tokio = { workspace = true, features = ["macros", "rt", "sync"] } [[bench]] harness = false name = "to_timestamp" +required-features = ["datetime_expressions"] [[bench]] harness = false name = "regx" +required-features = ["regex_expressions"] [[bench]] harness = false name = "make_date" +required-features = ["datetime_expressions"] [[bench]] harness = false name = "to_char" +required-features = ["datetime_expressions"] From a5f771470e518ca5a65089464e9008aac819c318 Mon Sep 17 00:00:00 2001 From: Bruce Ritchie Date: Sat, 30 Mar 2024 16:09:55 -0400 Subject: [PATCH 108/117] move the Translate, SubstrIndex, FindInSet functions to datafusion-functions (#9864) * Fix to_timestamp benchmark * Remove reference to simd and nightly build as simd is no longer an available feature in DataFusion and building with nightly may not be a good recommendation when getting started. * Fixed missing trim() function. * Create unicode module in datafusion/functions/src/unicode and unicode_expressions feature flag, move char_length function * move Left, Lpad, Reverse, Right, Rpad functions to datafusion_functions * move strpos, substr functions to datafusion_functions * move the Translate, SubstrIndex, FindInSet functions to new datafusion-functions crate * Test code cleanup * unicode_expressions Cargo.toml updates. --------- Co-authored-by: Andrew Lamb --- datafusion-cli/Cargo.lock | 2 +- datafusion/core/Cargo.toml | 2 - datafusion/expr/src/built_in_function.rs | 40 ---- datafusion/expr/src/expr_fn.rs | 7 - datafusion/functions/Cargo.toml | 3 +- .../functions/src/unicode/find_in_set.rs | 119 ++++++++++ datafusion/functions/src/unicode/mod.rs | 24 ++ .../functions/src/unicode/substrindex.rs | 138 ++++++++++++ datafusion/functions/src/unicode/translate.rs | 213 ++++++++++++++++++ datafusion/optimizer/Cargo.toml | 3 +- datafusion/physical-expr/Cargo.toml | 3 - datafusion/physical-expr/src/functions.rs | 148 +----------- datafusion/physical-expr/src/lib.rs | 2 - .../physical-expr/src/unicode_expressions.rs | 172 -------------- datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 9 - datafusion/proto/src/generated/prost.rs | 12 +- .../proto/src/logical_plan/from_proto.rs | 26 +-- datafusion/proto/src/logical_plan/to_proto.rs | 3 - 19 files changed, 510 insertions(+), 422 deletions(-) create mode 100644 datafusion/functions/src/unicode/find_in_set.rs create mode 100644 datafusion/functions/src/unicode/substrindex.rs create mode 100644 datafusion/functions/src/unicode/translate.rs delete mode 100644 datafusion/physical-expr/src/unicode_expressions.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 2bbe89f24bbe..3be92221d3ee 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1268,6 +1268,7 @@ dependencies = [ "datafusion-execution", "datafusion-expr", "datafusion-physical-expr", + "hashbrown 0.14.3", "hex", "itertools", "log", @@ -1342,7 +1343,6 @@ dependencies = [ "rand", "regex", "sha2", - "unicode-segmentation", ] [[package]] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 18946334dbf5..f483f8aed1cd 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -69,8 +69,6 @@ regex_expressions = [ serde = ["arrow-schema/serde"] string_expressions = ["datafusion-functions/string_expressions"] unicode_expressions = [ - "datafusion-physical-expr/unicode_expressions", - "datafusion-optimizer/unicode_expressions", 
"datafusion-sql/unicode_expressions", "datafusion-functions/unicode_expressions", ] diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index f07e84027552..f8d16f465091 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -107,12 +107,6 @@ pub enum BuiltinScalarFunction { InitCap, /// random Random, - /// translate - Translate, - /// substr_index - SubstrIndex, - /// find_in_set - FindInSet, } /// Maps the sql function name to `BuiltinScalarFunction` @@ -198,9 +192,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::EndsWith => Volatility::Immutable, BuiltinScalarFunction::InitCap => Volatility::Immutable, BuiltinScalarFunction::Radians => Volatility::Immutable, - BuiltinScalarFunction::Translate => Volatility::Immutable, - BuiltinScalarFunction::SubstrIndex => Volatility::Immutable, - BuiltinScalarFunction::FindInSet => Volatility::Immutable, // Volatile builtin functions BuiltinScalarFunction::Random => Volatility::Volatile, @@ -237,15 +228,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Pi => Ok(Float64), BuiltinScalarFunction::Random => Ok(Float64), BuiltinScalarFunction::EndsWith => Ok(Boolean), - BuiltinScalarFunction::SubstrIndex => { - utf8_to_str_type(&input_expr_types[0], "substr_index") - } - BuiltinScalarFunction::FindInSet => { - utf8_to_int_type(&input_expr_types[0], "find_in_set") - } - BuiltinScalarFunction::Translate => { - utf8_to_str_type(&input_expr_types[0], "translate") - } BuiltinScalarFunction::Factorial | BuiltinScalarFunction::Gcd @@ -326,22 +308,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - - BuiltinScalarFunction::SubstrIndex => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8, Int64]), - Exact(vec![LargeUtf8, LargeUtf8, Int64]), - ], - self.volatility(), - ), - BuiltinScalarFunction::FindInSet => Signature::one_of( - vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], - self.volatility(), - ), - - BuiltinScalarFunction::Translate => { - Signature::one_of(vec![Exact(vec![Utf8, Utf8, Utf8])], self.volatility()) - } BuiltinScalarFunction::Pi => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Random => Signature::exact(vec![], self.volatility()), BuiltinScalarFunction::Power => Signature::one_of( @@ -492,9 +458,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"], BuiltinScalarFunction::EndsWith => &["ends_with"], BuiltinScalarFunction::InitCap => &["initcap"], - BuiltinScalarFunction::Translate => &["translate"], - BuiltinScalarFunction::SubstrIndex => &["substr_index", "substring_index"], - BuiltinScalarFunction::FindInSet => &["find_in_set"], } } } @@ -559,9 +522,6 @@ macro_rules! get_optimal_return_type { // `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size. get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8); -// `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size. 
-get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32); - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index e216e4e86dc1..ab5628fece12 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -576,7 +576,6 @@ scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(EndsWith, ends_with, string suffix, "whether the `string` ends with the `suffix`"); -scalar_expr!(Translate, translate, string from to, "replaces the characters in `from` with the counterpart in `to`"); nary_scalar_expr!(Coalesce, coalesce, "returns `coalesce(args...)`, which evaluates to the value of the first [Expr] which is not NULL"); //there is a func concat_ws before, so use concat_ws_expr as name.c nary_scalar_expr!( @@ -593,9 +592,6 @@ scalar_expr!( "returns true if a given number is +0.0 or -0.0 otherwise returns false" ); -scalar_expr!(SubstrIndex, substr_index, string delimiter count, "Returns the substring from str before count occurrences of the delimiter"); -scalar_expr!(FindInSet, find_in_set, str strlist, "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"); - /// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression. pub fn case(expr: Expr) -> CaseBuilder { CaseBuilder::new(Some(Box::new(expr)), vec![], vec![], None) @@ -1006,8 +1002,5 @@ mod test { test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); test_scalar_expr!(EndsWith, ends_with, string, characters); - test_scalar_expr!(Translate, translate, string, from, to); - test_scalar_expr!(SubstrIndex, substr_index, string, delimiter, count); - test_scalar_expr!(FindInSet, find_in_set, string, stringlist); } } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 51452b9d4ca1..425ac207c33e 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -54,7 +54,7 @@ regex_expressions = ["regex"] # enable string functions string_expressions = ["uuid"] # enable unicode functions -unicode_expressions = ["unicode-segmentation"] +unicode_expressions = ["hashbrown", "unicode-segmentation"] [lib] name = "datafusion_functions" @@ -72,6 +72,7 @@ datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true, default-features = true } +hashbrown = { version = "0.14", features = ["raw"], optional = true } hex = { version = "0.4", optional = true } itertools = { workspace = true } log = { workspace = true } diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs new file mode 100644 index 000000000000..7e0306d49454 --- /dev/null +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ + ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray, +}; +use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_int_type}; + +#[derive(Debug)] +pub(super) struct FindInSetFunc { + signature: Signature, +} + +impl FindInSetFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for FindInSetFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "find_in_set" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + utf8_to_int_type(&arg_types[0], "find_in_set") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + match args[0].data_type() { + DataType::Utf8 => { + make_scalar_function(find_in_set::<Int32Type>, vec![])(args) + } + DataType::LargeUtf8 => { + make_scalar_function(find_in_set::<Int64Type>, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function find_in_set") + } + } + } +} + +///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings +///A string list is a string composed of substrings separated by , characters. +pub fn find_in_set<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef> +where + T::Native: OffsetSizeTrait, +{ + if args.len() != 2 { + return exec_err!( + "find_in_set was called with {} arguments.
It requires 2.", + args.len() + ); + } + + let str_array: &GenericStringArray<T::Native> = + as_generic_string_array::<T::Native>(&args[0])?; + let str_list_array: &GenericStringArray<T::Native> = + as_generic_string_array::<T::Native>(&args[1])?; + + let result = str_array + .iter() + .zip(str_list_array.iter()) + .map(|(string, str_list)| match (string, str_list) { + (Some(string), Some(str_list)) => { + let mut res = 0; + let str_set: Vec<&str> = str_list.split(',').collect(); + for (idx, str) in str_set.iter().enumerate() { + if str == &string { + res = idx + 1; + break; + } + } + T::Native::from_usize(res) + } + _ => None, + }) + .collect::<PrimitiveArray<T>>(); + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/mod.rs b/datafusion/functions/src/unicode/mod.rs index ddab0d1e27c9..eba4cd5048eb 100644 --- a/datafusion/functions/src/unicode/mod.rs +++ b/datafusion/functions/src/unicode/mod.rs @@ -22,6 +22,7 @@ use std::sync::Arc; use datafusion_expr::ScalarUDF; mod character_length; +mod find_in_set; mod left; mod lpad; mod reverse; @@ -29,6 +30,8 @@ mod right; mod rpad; mod strpos; mod substr; +mod substrindex; +mod translate; // create UDFs make_udf_function!( @@ -36,6 +39,7 @@ make_udf_function!( CHARACTER_LENGTH, character_length ); +make_udf_function!(find_in_set::FindInSetFunc, FIND_IN_SET, find_in_set); make_udf_function!(left::LeftFunc, LEFT, left); make_udf_function!(lpad::LPadFunc, LPAD, lpad); make_udf_function!(right::RightFunc, RIGHT, right); @@ -43,6 +47,8 @@ make_udf_function!(reverse::ReverseFunc, REVERSE, reverse); make_udf_function!(rpad::RPadFunc, RPAD, rpad); make_udf_function!(strpos::StrposFunc, STRPOS, strpos); make_udf_function!(substr::SubstrFunc, SUBSTR, substr); +make_udf_function!(substrindex::SubstrIndexFunc, SUBSTR_INDEX, substr_index); +make_udf_function!(translate::TranslateFunc, TRANSLATE, translate); pub mod expr_fn { use datafusion_expr::Expr; @@ -57,6 +63,11 @@ pub mod expr_fn { super::character_length().call(vec![string]) } + #[doc = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings"] + pub fn find_in_set(string: Expr, strlist: Expr) -> Expr { + super::find_in_set().call(vec![string, strlist]) + } + #[doc = "finds the position from where the `substring` matches the `string`"] pub fn instr(string: Expr, substring: Expr) -> Expr { strpos(string, substring) } @@ -111,12 +122,23 @@ pub mod expr_fn { pub fn substring(string: Expr, position: Expr, length: Expr) -> Expr { super::substr().call(vec![string, position, length]) } + + #[doc = "Returns the substring from str before count occurrences of the delimiter"] + pub fn substr_index(string: Expr, delimiter: Expr, count: Expr) -> Expr { + super::substr_index().call(vec![string, delimiter, count]) + } + + #[doc = "replaces the characters in `from` with the counterpart in `to`"] + pub fn translate(string: Expr, from: Expr, to: Expr) -> Expr { + super::translate().call(vec![string, from, to]) + } } /// Return a list of all functions in this package pub fn functions() -> Vec<Arc<ScalarUDF>> { vec![ character_length(), + find_in_set(), left(), lpad(), reverse(), @@ -124,5 +146,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> { rpad(), strpos(), substr(), + substr_index(), + translate(), ] } diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs new file mode 100644 index 000000000000..77e8116fff4c --- /dev/null +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more
contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; + +use datafusion_common::cast::{as_generic_string_array, as_int64_array}; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct SubstrIndexFunc { + signature: Signature, + aliases: Vec<String>, +} + +impl SubstrIndexFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8, Int64]), + Exact(vec![LargeUtf8, LargeUtf8, Int64]), + ], + Volatility::Immutable, + ), + aliases: vec![String::from("substring_index")], + } + } +} + +impl ScalarUDFImpl for SubstrIndexFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "substr_index" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + utf8_to_str_type(&arg_types[0], "substr_index") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(substr_index::<i32>, vec![])(args), + DataType::LargeUtf8 => { + make_scalar_function(substr_index::<i64>, vec![])(args) + } + other => { + exec_err!("Unsupported data type {other:?} for function substr_index") + } + } + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. If count is negative, everything to the right of the final delimiter (counting from the right) is returned. +/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www +/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache +/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org +/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org +pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + if args.len() != 3 { + return exec_err!( + "substr_index was called with {} arguments. It requires 3.", + args.len() + ); + } + + let string_array = as_generic_string_array::<T>(&args[0])?; + let delimiter_array = as_generic_string_array::<T>(&args[1])?; + let count_array = as_int64_array(&args[2])?; + + let result = string_array + .iter() + .zip(delimiter_array.iter()) + .zip(count_array.iter()) + .map(|((string, delimiter), n)| match (string, delimiter, n) { + (Some(string), Some(delimiter), Some(n)) => { + // In MySQL, these cases will return an empty string.
+ if n == 0 || string.is_empty() || delimiter.is_empty() { + return Some(String::new()); + } + + let splitted: Box<dyn Iterator<Item = &str>> = if n > 0 { + Box::new(string.split(delimiter)) + } else { + Box::new(string.rsplit(delimiter)) + }; + let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); + // The length of the substring covered by substr_index. + let length = splitted + .take(occurrences) // at least 1 element, since n != 0 + .map(|s| s.len() + delimiter.len()) + .sum::<usize>() + - delimiter.len(); + if n > 0 { + Some(string[..length].to_owned()) + } else { + Some(string[string.len() - length..].to_owned()) + } + } + _ => None, + }) + .collect::<GenericStringArray<T>>(); + + Ok(Arc::new(result) as ArrayRef) +} diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs new file mode 100644 index 000000000000..bc1836700304 --- /dev/null +++ b/datafusion/functions/src/unicode/translate.rs @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::datatypes::DataType; +use hashbrown::HashMap; +use unicode_segmentation::UnicodeSegmentation; + +use datafusion_common::cast::as_generic_string_array; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::TypeSignature::Exact; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; + +use crate::utils::{make_scalar_function, utf8_to_str_type}; + +#[derive(Debug)] +pub(super) struct TranslateFunc { + signature: Signature, +} + +impl TranslateFunc { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Utf8, Utf8, Utf8])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for TranslateFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "translate" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + utf8_to_str_type(&arg_types[0], "translate") + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { + match args[0].data_type() { + DataType::Utf8 => make_scalar_function(translate::<i32>, vec![])(args), + DataType::LargeUtf8 => make_scalar_function(translate::<i64>, vec![])(args), + other => { + exec_err!("Unsupported data type {other:?} for function translate") + } + } + } +} + +/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted.
+/// translate('12345', '143', 'ax') = 'a2x5' +fn translate<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> { + let string_array = as_generic_string_array::<T>(&args[0])?; + let from_array = as_generic_string_array::<T>(&args[1])?; + let to_array = as_generic_string_array::<T>(&args[2])?; + + let result = string_array + .iter() + .zip(from_array.iter()) + .zip(to_array.iter()) + .map(|((string, from), to)| match (string, from, to) { + (Some(string), Some(from), Some(to)) => { + // create a hashmap of [char, index] to change from O(n) to O(1) for from list + let from_map: HashMap<&str, usize> = from + .graphemes(true) + .collect::<Vec<&str>>() + .iter() + .enumerate() + .map(|(index, c)| (c.to_owned(), index)) + .collect(); + + let to = to.graphemes(true).collect::<Vec<&str>>(); + + Some( + string + .graphemes(true) + .collect::<Vec<&str>>() + .iter() + .flat_map(|c| match from_map.get(*c) { + Some(n) => to.get(*n).copied(), + None => Some(*c), + }) + .collect::<Vec<&str>>() + .concat(), + ) + } + _ => None, + }) + .collect::<GenericStringArray<T>>(); + + Ok(Arc::new(result) as ArrayRef) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; + + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::translate::TranslateFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() -> Result<()> { + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(Some("a2x5")), + &str, + Utf8, + StringArray + ); + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::Scalar(ScalarValue::from("ax")) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::Utf8(None)) + ], + Ok(None), + &str, + Utf8, + StringArray + ); + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("é2íñ5")), + ColumnarValue::Scalar(ScalarValue::from("éñí")), + ColumnarValue::Scalar(ScalarValue::from("óü")), + ], + Ok(Some("ó2ü5")), + &str, + Utf8, + StringArray + ); + #[cfg(not(feature = "unicode_expressions"))] + test_function!( + TranslateFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::from("12345")), + ColumnarValue::Scalar(ScalarValue::from("143")), + ColumnarValue::Scalar(ScalarValue::from("ax")), + ], + internal_err!( + "function translate requires compilation with feature flag: unicode_expressions."
+ ), + &str, + Utf8, + StringArray + ); + + Ok(()) + } +} diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 861715b351a6..1d64a22f1463 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -34,9 +34,8 @@ path = "src/lib.rs" [features] crypto_expressions = ["datafusion-physical-expr/crypto_expressions"] -default = ["unicode_expressions", "crypto_expressions", "regex_expressions"] +default = ["crypto_expressions", "regex_expressions"] regex_expressions = ["datafusion-physical-expr/regex_expressions"] -unicode_expressions = ["datafusion-physical-expr/unicode_expressions"] [dependencies] arrow = { workspace = true } diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 24b831e7c575..baca00bea724 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -37,12 +37,10 @@ crypto_expressions = ["md-5", "sha2", "blake2", "blake3"] default = [ "crypto_expressions", "regex_expressions", - "unicode_expressions", "encoding_expressions", ] encoding_expressions = ["base64", "hex"] regex_expressions = ["regex"] -unicode_expressions = ["unicode-segmentation"] [dependencies] ahash = { version = "0.8", default-features = false, features = [ @@ -73,7 +71,6 @@ petgraph = "0.6.2" rand = { workspace = true } regex = { version = "1.8", optional = true } sha2 = { version = "^0.10.1", optional = true } -unicode-segmentation = { version = "^1.7.1", optional = true } [dev-dependencies] criterion = "0.5" diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 515511b15fbb..5b9b46c3991b 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -35,7 +35,7 @@ use std::sync::Arc; use arrow::{ array::ArrayRef, - datatypes::{DataType, Int32Type, Int64Type, Schema}, + datatypes::{DataType, Schema}, }; use arrow_array::Array; @@ -85,26 +85,6 @@ pub fn create_physical_expr( ))) } -#[cfg(feature = "unicode_expressions")] -macro_rules! invoke_if_unicode_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => {{ - use crate::unicode_expressions; - unicode_expressions::$FUNC::<$T> - }}; -} - -#[cfg(not(feature = "unicode_expressions"))] -macro_rules! 
invoke_if_unicode_expressions_feature_flag { - ($FUNC:ident, $T:tt, $NAME:expr) => { - |_: &[ArrayRef]| -> Result { - internal_err!( - "function {} requires compilation with feature flag: unicode_expressions.", - $NAME - ) - } - }; -} - #[derive(Debug, Clone, Copy)] pub enum Hint { /// Indicates the argument needs to be padded if it is scalar @@ -278,71 +258,6 @@ pub fn create_physical_fun( exec_err!("Unsupported data type {other:?} for function ends_with") } }), - BuiltinScalarFunction::Translate => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i32, - "translate" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - translate, - i64, - "translate" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function translate") - } - }), - BuiltinScalarFunction::SubstrIndex => { - Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - substr_index, - i32, - "substr_index" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - substr_index, - i64, - "substr_index" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function substr_index") - } - }) - } - BuiltinScalarFunction::FindInSet => Arc::new(|args| match args[0].data_type() { - DataType::Utf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - find_in_set, - Int32Type, - "find_in_set" - ); - make_scalar_function_inner(func)(args) - } - DataType::LargeUtf8 => { - let func = invoke_if_unicode_expressions_feature_flag!( - find_in_set, - Int64Type, - "find_in_set" - ); - make_scalar_function_inner(func)(args) - } - other => { - exec_err!("Unsupported data type {other:?} for function find_in_set") - } - }), }) } @@ -631,66 +546,7 @@ mod tests { Boolean, BooleanArray ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit("143"), lit("ax"),], - Ok(Some("a2x5")), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit(ScalarValue::Utf8(None)), lit("143"), lit("ax"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit(ScalarValue::Utf8(None)), lit("ax"),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("12345"), lit("143"), lit(ScalarValue::Utf8(None)),], - Ok(None), - &str, - Utf8, - StringArray - ); - #[cfg(feature = "unicode_expressions")] - test_function!( - Translate, - &[lit("é2íñ5"), lit("éñí"), lit("óü"),], - Ok(Some("ó2ü5")), - &str, - Utf8, - StringArray - ); - #[cfg(not(feature = "unicode_expressions"))] - test_function!( - Translate, - &[ - lit("12345"), - lit("143"), - lit("ax"), - ], - internal_err!( - "function translate requires compilation with feature flag: unicode_expressions." 
- ), - &str, - Utf8, - StringArray - ); + Ok(()) } diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index 1dead099540b..7819d5116160 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -33,8 +33,6 @@ pub mod sort_properties; pub mod string_expressions; pub mod tree_node; pub mod udf; -#[cfg(feature = "unicode_expressions")] -pub mod unicode_expressions; pub mod utils; pub mod window; diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs deleted file mode 100644 index ecbd1ea320d4..000000000000 --- a/datafusion/physical-expr/src/unicode_expressions.rs +++ /dev/null @@ -1,172 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Some of these functions reference the Postgres documentation -// or implementation to ensure compatibility and are subject to -// the Postgres license. - -//! Unicode expressions - -use std::sync::Arc; - -use arrow::{ - array::{ArrayRef, GenericStringArray, OffsetSizeTrait, PrimitiveArray}, - datatypes::{ArrowNativeType, ArrowPrimitiveType}, -}; -use hashbrown::HashMap; -use unicode_segmentation::UnicodeSegmentation; - -use datafusion_common::{ - cast::{as_generic_string_array, as_int64_array}, - exec_err, Result, -}; - -/// Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted. -/// translate('12345', '143', 'ax') = 'a2x5' -pub fn translate(args: &[ArrayRef]) -> Result { - let string_array = as_generic_string_array::(&args[0])?; - let from_array = as_generic_string_array::(&args[1])?; - let to_array = as_generic_string_array::(&args[2])?; - - let result = string_array - .iter() - .zip(from_array.iter()) - .zip(to_array.iter()) - .map(|((string, from), to)| match (string, from, to) { - (Some(string), Some(from), Some(to)) => { - // create a hashmap of [char, index] to change from O(n) to O(1) for from list - let from_map: HashMap<&str, usize> = from - .graphemes(true) - .collect::>() - .iter() - .enumerate() - .map(|(index, c)| (c.to_owned(), index)) - .collect(); - - let to = to.graphemes(true).collect::>(); - - Some( - string - .graphemes(true) - .collect::>() - .iter() - .flat_map(|c| match from_map.get(*c) { - Some(n) => to.get(*n).copied(), - None => Some(*c), - }) - .collect::>() - .concat(), - ) - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -/// Returns the substring from str before count occurrences of the delimiter delim. If count is positive, everything to the left of the final delimiter (counting from the left) is returned. 
If count is negative, everything to the right of the final delimiter (counting from the right) is returned. -/// SUBSTRING_INDEX('www.apache.org', '.', 1) = www -/// SUBSTRING_INDEX('www.apache.org', '.', 2) = www.apache -/// SUBSTRING_INDEX('www.apache.org', '.', -2) = apache.org -/// SUBSTRING_INDEX('www.apache.org', '.', -1) = org -pub fn substr_index(args: &[ArrayRef]) -> Result { - if args.len() != 3 { - return exec_err!( - "substr_index was called with {} arguments. It requires 3.", - args.len() - ); - } - - let string_array = as_generic_string_array::(&args[0])?; - let delimiter_array = as_generic_string_array::(&args[1])?; - let count_array = as_int64_array(&args[2])?; - - let result = string_array - .iter() - .zip(delimiter_array.iter()) - .zip(count_array.iter()) - .map(|((string, delimiter), n)| match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - // In MySQL, these cases will return an empty string. - if n == 0 || string.is_empty() || delimiter.is_empty() { - return Some(String::new()); - } - - let splitted: Box> = if n > 0 { - Box::new(string.split(delimiter)) - } else { - Box::new(string.rsplit(delimiter)) - }; - let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX); - // The length of the substring covered by substr_index. - let length = splitted - .take(occurrences) // at least 1 element, since n != 0 - .map(|s| s.len() + delimiter.len()) - .sum::() - - delimiter.len(); - if n > 0 { - Some(string[..length].to_owned()) - } else { - Some(string[string.len() - length..].to_owned()) - } - } - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) -} - -///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings -///A string list is a string composed of substrings separated by , characters. -pub fn find_in_set(args: &[ArrayRef]) -> Result -where - T::Native: OffsetSizeTrait, -{ - if args.len() != 2 { - return exec_err!( - "find_in_set was called with {} arguments. 
It requires 2.", - args.len() - ); - } - - let str_array: &GenericStringArray = - as_generic_string_array::(&args[0])?; - let str_list_array: &GenericStringArray = - as_generic_string_array::(&args[1])?; - - let result = str_array - .iter() - .zip(str_list_array.iter()) - .map(|(string, str_list)| match (string, str_list) { - (Some(string), Some(str_list)) => { - let mut res = 0; - let str_set: Vec<&str> = str_list.split(',').collect(); - for (idx, str) in str_set.iter().enumerate() { - if str == &string { - res = idx + 1; - break; - } - } - T::Native::from_usize(res) - } - _ => None, - }) - .collect::>(); - Ok(Arc::new(result) as ArrayRef) -} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 81451e40aa50..b756e0575d71 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -601,7 +601,7 @@ enum ScalarFunction { // 57 was ToTimestampMicros // 58 was ToTimestampSeconds // 59 was Now - Translate = 60; + // 60 was Translate // Trim = 61; // Upper = 62; Coalesce = 63; @@ -665,8 +665,8 @@ enum ScalarFunction { // 123 is ArrayExcept // 124 was ArrayPopFront // 125 was Levenshtein - SubstrIndex = 126; - FindInSet = 127; + // 126 was SubstrIndex + // 127 was FindInSet // 128 was ArraySort // 129 was ArrayDistinct // 130 was ArrayResize diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 2949ab807e04..3c3d60300786 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -22929,7 +22929,6 @@ impl serde::Serialize for ScalarFunction { Self::ConcatWithSeparator => "ConcatWithSeparator", Self::InitCap => "InitCap", Self::Random => "Random", - Self::Translate => "Translate", Self::Coalesce => "Coalesce", Self::Power => "Power", Self::Atan2 => "Atan2", @@ -22948,8 +22947,6 @@ impl serde::Serialize for ScalarFunction { Self::Cot => "Cot", Self::Nanvl => "Nanvl", Self::Iszero => "Iszero", - Self::SubstrIndex => "SubstrIndex", - Self::FindInSet => "FindInSet", Self::EndsWith => "EndsWith", }; serializer.serialize_str(variant) @@ -22978,7 +22975,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator", "InitCap", "Random", - "Translate", "Coalesce", "Power", "Atan2", @@ -22997,8 +22993,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cot", "Nanvl", "Iszero", - "SubstrIndex", - "FindInSet", "EndsWith", ]; @@ -23056,7 +23050,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ConcatWithSeparator" => Ok(ScalarFunction::ConcatWithSeparator), "InitCap" => Ok(ScalarFunction::InitCap), "Random" => Ok(ScalarFunction::Random), - "Translate" => Ok(ScalarFunction::Translate), "Coalesce" => Ok(ScalarFunction::Coalesce), "Power" => Ok(ScalarFunction::Power), "Atan2" => Ok(ScalarFunction::Atan2), @@ -23075,8 +23068,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cot" => Ok(ScalarFunction::Cot), "Nanvl" => Ok(ScalarFunction::Nanvl), "Iszero" => Ok(ScalarFunction::Iszero), - "SubstrIndex" => Ok(ScalarFunction::SubstrIndex), - "FindInSet" => Ok(ScalarFunction::FindInSet), "EndsWith" => Ok(ScalarFunction::EndsWith), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 6f7e8a9789a6..9860587d3eca 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2900,7 +2900,7 @@ pub enum ScalarFunction { /// 57 was ToTimestampMicros 
/// 58 was ToTimestampSeconds /// 59 was Now - Translate = 60, + /// 60 was Translate /// Trim = 61; /// Upper = 62; Coalesce = 63, @@ -2964,8 +2964,8 @@ pub enum ScalarFunction { /// 123 is ArrayExcept /// 124 was ArrayPopFront /// 125 was Levenshtein - SubstrIndex = 126, - FindInSet = 127, + /// 126 was SubstrIndex + /// 127 was FindInSet /// 128 was ArraySort /// 129 was ArrayDistinct /// 130 was ArrayResize @@ -3002,7 +3002,6 @@ impl ScalarFunction { ScalarFunction::ConcatWithSeparator => "ConcatWithSeparator", ScalarFunction::InitCap => "InitCap", ScalarFunction::Random => "Random", - ScalarFunction::Translate => "Translate", ScalarFunction::Coalesce => "Coalesce", ScalarFunction::Power => "Power", ScalarFunction::Atan2 => "Atan2", @@ -3021,8 +3020,6 @@ impl ScalarFunction { ScalarFunction::Cot => "Cot", ScalarFunction::Nanvl => "Nanvl", ScalarFunction::Iszero => "Iszero", - ScalarFunction::SubstrIndex => "SubstrIndex", - ScalarFunction::FindInSet => "FindInSet", ScalarFunction::EndsWith => "EndsWith", } } @@ -3045,7 +3042,6 @@ impl ScalarFunction { "ConcatWithSeparator" => Some(Self::ConcatWithSeparator), "InitCap" => Some(Self::InitCap), "Random" => Some(Self::Random), - "Translate" => Some(Self::Translate), "Coalesce" => Some(Self::Coalesce), "Power" => Some(Self::Power), "Atan2" => Some(Self::Atan2), @@ -3064,8 +3060,6 @@ impl ScalarFunction { "Cot" => Some(Self::Cot), "Nanvl" => Some(Self::Nanvl), "Iszero" => Some(Self::Iszero), - "SubstrIndex" => Some(Self::SubstrIndex), - "FindInSet" => Some(Self::FindInSet), "EndsWith" => Some(Self::EndsWith), _ => None, } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index d372cb428c73..c068cfd46c1f 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -40,12 +40,11 @@ use datafusion_expr::{ acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees, ends_with, exp, expr::{self, InList, Sort, WindowFunction}, - factorial, find_in_set, floor, gcd, initcap, iszero, lcm, log, + factorial, floor, gcd, initcap, iszero, lcm, log, logical_plan::{PlanType, StringifiedPlan}, - nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, substr_index, - translate, trunc, AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, - BuiltinScalarFunction, Case, Cast, Expr, GetFieldAccess, GetIndexedField, - GroupingSet, + nanvl, pi, power, radians, random, round, signum, sin, sinh, sqrt, trunc, + AggregateFunction, Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, + Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet, GroupingSet::GroupingSets, JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, WindowFrameUnits, @@ -452,15 +451,12 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::EndsWith => Self::EndsWith, ScalarFunction::InitCap => Self::InitCap, ScalarFunction::Random => Self::Random, - ScalarFunction::Translate => Self::Translate, ScalarFunction::Coalesce => Self::Coalesce, ScalarFunction::Pi => Self::Pi, ScalarFunction::Power => Self::Power, ScalarFunction::Atan2 => Self::Atan2, ScalarFunction::Nanvl => Self::Nanvl, ScalarFunction::Iszero => Self::Iszero, - ScalarFunction::SubstrIndex => Self::SubstrIndex, - ScalarFunction::FindInSet => Self::FindInSet, } } } @@ -1379,11 +1375,6 @@ pub fn parse_expr( parse_expr(&args[0], registry, codec)?, parse_expr(&args[1], registry, 
codec)?, )), - ScalarFunction::Translate => Ok(translate( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), ScalarFunction::Coalesce => { Ok(coalesce(parse_exprs(args, registry, codec)?)) } @@ -1408,15 +1399,6 @@ pub fn parse_expr( ScalarFunction::Iszero => { Ok(iszero(parse_expr(&args[0], registry, codec)?)) } - ScalarFunction::SubstrIndex => Ok(substr_index( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - parse_expr(&args[2], registry, codec)?, - )), - ScalarFunction::FindInSet => Ok(find_in_set( - parse_expr(&args[0], registry, codec)?, - parse_expr(&args[1], registry, codec)?, - )), } } ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode { diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 1e4e85c51f70..9d433bb6ff97 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1443,15 +1443,12 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::EndsWith => Self::EndsWith, BuiltinScalarFunction::InitCap => Self::InitCap, BuiltinScalarFunction::Random => Self::Random, - BuiltinScalarFunction::Translate => Self::Translate, BuiltinScalarFunction::Coalesce => Self::Coalesce, BuiltinScalarFunction::Pi => Self::Pi, BuiltinScalarFunction::Power => Self::Power, BuiltinScalarFunction::Atan2 => Self::Atan2, BuiltinScalarFunction::Nanvl => Self::Nanvl, BuiltinScalarFunction::Iszero => Self::Iszero, - BuiltinScalarFunction::SubstrIndex => Self::SubstrIndex, - BuiltinScalarFunction::FindInSet => Self::FindInSet, }; Ok(scalar_function) From aa879bf045965d20792e9cd6ac08c550b3615280 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 30 Mar 2024 17:11:55 -0300 Subject: [PATCH 109/117] Support custom struct field names with new scalar function named_struct (#9743) * Support custom struct field names with new scalar function named_struct * add tests and correctly handle mixed array and scalar values * fix slt * fmt * port test to slt --------- Co-authored-by: Andrew Lamb --- datafusion/functions/src/core/mod.rs | 3 + datafusion/functions/src/core/named_struct.rs | 148 ++++++++++++++++++ datafusion/sql/src/expr/mod.rs | 45 ++++-- .../sqllogictest/test_files/explain.slt | 4 +- datafusion/sqllogictest/test_files/struct.slt | 112 ++++++++++++- .../source/user-guide/sql/scalar_functions.md | 58 +++++-- 6 files changed, 343 insertions(+), 27 deletions(-) create mode 100644 datafusion/functions/src/core/named_struct.rs diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 5a0bd2c77f63..85d2410251c5 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -20,6 +20,7 @@ mod arrow_cast; mod arrowtypeof; mod getfield; +mod named_struct; mod nullif; mod nvl; mod nvl2; @@ -32,6 +33,7 @@ make_udf_function!(nvl::NVLFunc, NVL, nvl); make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof); make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); +make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct); make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); // Export the functions out of this package, both as expr_fn as well as a list of functions @@ -42,5 +44,6 @@ export_functions!( (nvl2, arg_1 arg_2 arg_3, "Returns value2 if
value1 is not NULL; otherwise, it returns value3."), (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."), (r#struct, args, "Returns a struct with the given arguments"), + (named_struct, args, "Returns a struct with the given names and arguments pairs"), (get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct") ); diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs new file mode 100644 index 000000000000..327a41baa741 --- /dev/null +++ b/datafusion/functions/src/core/named_struct.rs @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::StructArray; +use arrow::datatypes::{DataType, Field, Fields}; +use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +/// put values in a struct array. +fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> { + // do not accept 0 arguments. + if args.is_empty() { + return exec_err!( + "named_struct requires at least one pair of arguments, got 0 instead" + ); + } + + if args.len() % 2 != 0 { + return exec_err!( + "named_struct requires an even number of arguments, got {} instead", + args.len() + ); + } + + let (names, values): (Vec<_>, Vec<_>) = args + .chunks_exact(2) + .enumerate() + .map(|(i, chunk)| { + + let name_column = &chunk[0]; + + let name = match name_column { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar, + _ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2) + }; + + Ok((name, chunk[1].clone())) + }) + .collect::<Result<Vec<_>>>()?
+        .into_iter()
+        .unzip();
+
+    let arrays = ColumnarValue::values_to_arrays(&values)?;
+
+    let fields = names
+        .into_iter()
+        .zip(arrays)
+        .map(|(name, value)| {
+            (
+                Arc::new(Field::new(name, value.data_type().clone(), true)),
+                value,
+            )
+        })
+        .collect::<Vec<_>>();
+
+    Ok(ColumnarValue::Array(Arc::new(StructArray::from(fields))))
+}
+
+#[derive(Debug)]
+pub(super) struct NamedStructFunc {
+    signature: Signature,
+}
+
+impl NamedStructFunc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for NamedStructFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "named_struct"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!(
+            "named_struct: return_type called instead of return_type_from_exprs"
+        )
+    }
+
+    fn return_type_from_exprs(
+        &self,
+        args: &[datafusion_expr::Expr],
+        schema: &dyn datafusion_common::ExprSchema,
+        _arg_types: &[DataType],
+    ) -> Result<DataType> {
+        // do not accept 0 arguments.
+        if args.is_empty() {
+            return exec_err!(
+                "named_struct requires at least one pair of arguments, got 0 instead"
+            );
+        }
+
+        if args.len() % 2 != 0 {
+            return exec_err!(
+                "named_struct requires an even number of arguments, got {} instead",
+                args.len()
+            );
+        }
+
+        let return_fields = args
+            .chunks_exact(2)
+            .enumerate()
+            .map(|(i, chunk)| {
+                let name = &chunk[0];
+                let value = &chunk[1];
+
+                if let Expr::Literal(ScalarValue::Utf8(Some(name))) = name {
+                    Ok(Field::new(name, value.get_type(schema)?, true))
+                } else {
+                    exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2)
+                }
+            })
+            .collect::<Result<Vec<Field>>>()?;
+        Ok(DataType::Struct(Fields::from(return_fields)))
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
+        named_struct_expr(args)
+    }
+}
diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs
index 43bf2d871564..064578ad51d6 100644
--- a/datafusion/sql/src/expr/mod.rs
+++ b/datafusion/sql/src/expr/mod.rs
@@ -29,7 +29,8 @@ use datafusion_expr::expr::InList;
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::{
     col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
-    Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast,
+    Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Literal, Operator,
+    TryCast,
 };

 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
@@ -604,18 +605,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         }
         let args = values
             .into_iter()
-            .map(|value| {
-                self.sql_expr_to_logical_expr(value, input_schema, planner_context)
+            .enumerate()
+            .map(|(i, value)| {
+                let args = if let SQLExpr::Named { expr, name } = value {
+                    [
+                        name.value.lit(),
+                        self.sql_expr_to_logical_expr(
+                            *expr,
+                            input_schema,
+                            planner_context,
+                        )?,
+                    ]
+                } else {
+                    [
+                        format!("c{i}").lit(),
+                        self.sql_expr_to_logical_expr(
+                            value,
+                            input_schema,
+                            planner_context,
+                        )?,
+                    ]
+                };
+
+                Ok(args)
             })
-            .collect::<Result<Vec<_>>>()?;
-        let struct_func = self
+            .collect::<Result<Vec<_>>>()?
+ .into_iter() + .flatten() + .collect(); + + let named_struct_func = self .context_provider - .get_function_meta("struct") + .get_function_meta("named_struct") .ok_or_else(|| { - internal_datafusion_err!("Unable to find expected 'struct' function") - })?; + internal_datafusion_err!("Unable to find expected 'named_struct' function") + })?; + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( - struct_func, + named_struct_func, args, ))) } diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt index b7ad36dace16..4653250cf93f 100644 --- a/datafusion/sqllogictest/test_files/explain.slt +++ b/datafusion/sqllogictest/test_files/explain.slt @@ -390,8 +390,8 @@ query TT explain select struct(1, 2.3, 'abc'); ---- logical_plan -Projection: Struct({c0:1,c1:2.3,c2:abc}) AS struct(Int64(1),Float64(2.3),Utf8("abc")) +Projection: Struct({c0:1,c1:2.3,c2:abc}) AS named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc")) --EmptyRelation physical_plan -ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))] +ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc"))] --PlaceholderRowExec diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 1ab6f3908b53..2e0b699f6dd6 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -23,11 +23,12 @@ statement ok CREATE TABLE values( a INT, b FLOAT, - c VARCHAR + c VARCHAR, + n VARCHAR, ) AS VALUES - (1, 1.1, 'a'), - (2, 2.2, 'b'), - (3, 3.3, 'c') + (1, 1.1, 'a', NULL), + (2, 2.2, 'b', NULL), + (3, 3.3, 'c', NULL) ; # struct[i] @@ -50,6 +51,18 @@ select struct(1, 3.14, 'e'); ---- {c0: 1, c1: 3.14, c2: e} +# struct scalar function with named values +query ? +select struct(1 as "name0", 3.14 as name1, 'e', true as 'name3'); +---- +{name0: 1, name1: 3.14, c2: e, name3: true} + +# struct scalar function with mixed named and unnamed values +query ? +select struct(1, 3.14 as name1, 'e', true); +---- +{c0: 1, name1: 3.14, c2: e, c3: true} + # struct scalar function with columns #1 query ? select struct(a, b, c) from values; @@ -72,11 +85,98 @@ query TT explain select struct(a, b, c) from values; ---- logical_plan -Projection: struct(values.a, values.b, values.c) +Projection: named_struct(Utf8("c0"), values.a, Utf8("c1"), values.b, Utf8("c2"), values.c) --TableScan: values projection=[a, b, c] physical_plan -ProjectionExec: expr=[struct(a@0, b@1, c@2) as struct(values.a,values.b,values.c)] +ProjectionExec: expr=[named_struct(c0, a@0, c1, b@1, c2, c@2) as named_struct(Utf8("c0"),values.a,Utf8("c1"),values.b,Utf8("c2"),values.c)] --MemoryExec: partitions=1, partition_sizes=[1] +# error on 0 arguments +query error DataFusion error: Error during planning: No function matches the given name and argument types 'named_struct\(\)'. You might need to add explicit type casts. 
+select named_struct(); + +# error on odd number of arguments #1 +query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead +select named_struct('a'); + +# error on odd number of arguments #2 +query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead +select named_struct(1); + +# error on odd number of arguments #3 +query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead +select named_struct(values.a) from values; + +# error on odd number of arguments #4 +query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 3 instead +select named_struct('a', 1, 'b'); + +# error on even argument not a string literal #1 +query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(1\) instead at position 0 +select named_struct(1, 'a'); + +# error on even argument not a string literal #2 +query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(0\) instead at position 2 +select named_struct('corret', 1, 0, 'wrong'); + +# error on even argument not a string literal #3 +query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.a instead at position 0 +select named_struct(values.a, 'a') from values; + +# error on even argument not a string literal #4 +query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.c instead at position 0 +select named_struct(values.c, 'c') from values; + +# named_struct with mixed scalar and array values #1 +query ? +select named_struct('scalar', 27, 'array', values.a, 'null', NULL) from values; +---- +{scalar: 27, array: 1, null: } +{scalar: 27, array: 2, null: } +{scalar: 27, array: 3, null: } + +# named_struct with mixed scalar and array values #2 +query ? +select named_struct('array', values.a, 'scalar', 27, 'null', NULL) from values; +---- +{array: 1, scalar: 27, null: } +{array: 2, scalar: 27, null: } +{array: 3, scalar: 27, null: } + +# named_struct with mixed scalar and array values #3 +query ? +select named_struct('null', NULL, 'array', values.a, 'scalar', 27) from values; +---- +{null: , array: 1, scalar: 27} +{null: , array: 2, scalar: 27} +{null: , array: 3, scalar: 27} + +# named_struct with mixed scalar and array values #4 +query ? +select named_struct('null_array', values.n, 'array', values.a, 'scalar', 27, 'null', NULL) from values; +---- +{null_array: , array: 1, scalar: 27, null: } +{null_array: , array: 2, scalar: 27, null: } +{null_array: , array: 3, scalar: 27, null: } + +# named_struct arrays only +query ? +select named_struct('field_a', a, 'field_b', b) from values; +---- +{field_a: 1, field_b: 1.1} +{field_a: 2, field_b: 2.2} +{field_a: 3, field_b: 3.3} + +# named_struct scalars only +query ? 
+select named_struct('field_a', 1, 'field_b', 2);
+----
+{field_a: 1, field_b: 2}
+
 statement ok
 drop table values;
+
+query T
+select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3));
+----
+Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
\ No newline at end of file
diff --git a/docs/source/user-guide/sql/scalar_functions.md b/docs/source/user-guide/sql/scalar_functions.md
index 52edf4bb7217..e2e129a2e2d1 100644
--- a/docs/source/user-guide/sql/scalar_functions.md
+++ b/docs/source/user-guide/sql/scalar_functions.md
@@ -3312,11 +3312,12 @@ are not allowed
 ## Struct Functions

 - [struct](#struct)
+- [named_struct](#named_struct)

 ### `struct`

-Returns an Arrow struct using the specified input expressions.
-Fields in the returned struct use the `cN` naming convention.
+Returns an Arrow struct using the specified input expressions, optionally named.
+Fields in the returned struct use the optional name or the `cN` naming convention.
 For example: `c0`, `c1`, `c2`, etc.

 ```
 struct(expression1[, ..., expression_n])
 ```

 For example, this query converts two columns `a` and `b` to a single column with
-a struct type of fields `c0` and `c1`:
+a struct type of fields `field_a` and `c1`:

 ```
 select * from t;
 +---+---+
 | a | b |
 +---+---+
 | 1 | 2 |
 | 3 | 4 |
 +---+---+

-select struct(a, b) from t;
-+-----------------+
-| struct(t.a,t.b) |
-+-----------------+
-| {c0: 1, c1: 2}  |
-| {c0: 3, c1: 4}  |
-+-----------------+
+select struct(a as field_a, b) from t;
++--------------------------------------------------+
+| named_struct(Utf8("field_a"),t.a,Utf8("c1"),t.b) |
++--------------------------------------------------+
+| {field_a: 1, c1: 2}                              |
+| {field_a: 3, c1: 4}                              |
++--------------------------------------------------+
 ```

 #### Arguments

 - **expression_n**: Expression to include in the output struct.
-  Can be a constant, column, or function, and any combination of arithmetic or
-  string operators.
+  Can be a constant, column, or function, any combination of arithmetic or
+  string operators, or a named expression of any of the preceding.
+
+### `named_struct`
+
+Returns an Arrow struct using the specified name and input expression pairs.
+
+```
+named_struct(expression1_name, expression1_input[, ..., expression_n_name, expression_n_input])
+```
+
+For example, this query converts two columns `a` and `b` to a single column with
+a struct type of fields `field_a` and `field_b`:
+
+```
+select * from t;
++---+---+
+| a | b |
++---+---+
+| 1 | 2 |
+| 3 | 4 |
++---+---+
+
+select named_struct('field_a', a, 'field_b', b) from t;
++-------------------------------------------------------+
+| named_struct(Utf8("field_a"),t.a,Utf8("field_b"),t.b) |
++-------------------------------------------------------+
+| {field_a: 1, field_b: 2}                              |
+| {field_a: 3, field_b: 4}                              |
++-------------------------------------------------------+
+```
+
+#### Arguments
+
+- **expression_n_name**: Name of the column field.
+  Must be a constant string.
+- **expression_n_input**: Expression to include in the output struct.
+  Can be a constant, column, or function, and any combination of arithmetic or
+  string operators.
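As a quick illustration of the planner change in this patch: `struct(a AS field_a, b)` is now lowered to the same plan as an explicit `named_struct` call, with `c{i}` names filled in for unnamed values. Below is a minimal Rust sketch of building that expression directly. The `Vec<Expr>` signature generated by `export_functions!` and the `datafusion_functions::core::expr_fn` re-export path are assumptions based on the macro invocation above, not verified API:

```rust
use datafusion_expr::{col, lit, Expr};
// assumed re-export path for the expr_fn helper generated by `export_functions!`
use datafusion_functions::core::expr_fn::named_struct;

/// Builds the expression the SQL planner now produces for
/// `struct(a AS field_a, b)`: names and values are interleaved
/// into one flat argument list.
fn struct_with_named_field() -> Expr {
    named_struct(vec![
        lit("field_a"), col("a"), // explicitly named field
        lit("c1"), col("b"),      // unnamed value falls back to the `c{i}` convention
    ])
}
```

The slt cases above exercise the same behaviour from SQL, including the mixed named/unnamed case.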
From ab88220e900f1ba8ec3c65a1e4ae53b3feb2a2af Mon Sep 17 00:00:00 2001 From: Mohamed Abdeen <83442793+MohamedAbdeen21@users.noreply.github.com> Date: Sat, 30 Mar 2024 22:14:43 +0200 Subject: [PATCH 110/117] Allow declaring partition columns in `PARTITION BY` clause, backwards compatible (#9599) * Draft allow both syntaxes * suppress unused code error * prevent constraints in partition clauses * fix clippy * More tests * comment and prevent constraints on partition columns * trailing whitespaces * End-to-End test of new Hive syntax --------- Co-authored-by: Mohamed Abdeen --- datafusion/sql/src/parser.rs | 59 +++++++++++++++++-- .../test_files/create_external_table.slt | 10 +++- .../test_files/insert_to_external.slt | 36 +++++++++-- 3 files changed, 94 insertions(+), 11 deletions(-) diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index c585917a1ed0..67fa1325eea7 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -175,7 +175,7 @@ pub(crate) type LexOrdering = Vec; /// [ WITH HEADER ROW ] /// [ DELIMITER ] /// [ COMPRESSION TYPE ] -/// [ PARTITIONED BY () ] +/// [ PARTITIONED BY ( | ) ] /// [ WITH ORDER () /// [ OPTIONS () ] /// LOCATION @@ -693,7 +693,7 @@ impl<'a> DFParser<'a> { self.parser .parse_keywords(&[Keyword::IF, Keyword::NOT, Keyword::EXISTS]); let table_name = self.parser.parse_object_name(true)?; - let (columns, constraints) = self.parse_columns()?; + let (mut columns, constraints) = self.parse_columns()?; #[derive(Default)] struct Builder { @@ -754,7 +754,30 @@ impl<'a> DFParser<'a> { Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; ensure_not_set(&builder.table_partition_cols, "PARTITIONED BY")?; - builder.table_partition_cols = Some(self.parse_partitions()?); + // Expects either list of column names (col_name [, col_name]*) + // or list of column definitions (col_name datatype [, col_name datatype]* ) + // use the token after the name to decide which parsing rule to use + // Note that mixing both names and definitions is not allowed + let peeked = self.parser.peek_nth_token(2); + if peeked == Token::Comma || peeked == Token::RParen { + // list of column names + builder.table_partition_cols = Some(self.parse_partitions()?) 
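+                        // e.g. `PARTITIONED BY (p1, p2)`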
+                    } else {
+                        // list of column defs, e.g. `PARTITIONED BY (p1 int, p2 string)`
+                        let (cols, cons) = self.parse_columns()?;
+                        builder.table_partition_cols = Some(
+                            cols.iter().map(|col| col.name.to_string()).collect(),
+                        );
+
+                        columns.extend(cols);
+
+                        if !cons.is_empty() {
+                            return Err(ParserError::ParserError(
+                                "Constraints on Partition Columns are not supported"
+                                    .to_string(),
+                            ));
+                        }
+                    }
                 }
                 Keyword::OPTIONS => {
                     ensure_not_set(&builder.options, "OPTIONS")?;
@@ -1167,9 +1190,37 @@ mod tests {
         });
         expect_parse_ok(sql, expected)?;

-        // Error cases: partition column does not support type
+        // positive case: column definition allowed in 'partition by' clause
         let sql =
             "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int) LOCATION 'foo.csv'";
+        let expected = Statement::CreateExternalTable(CreateExternalTable {
+            name: "t".into(),
+            columns: vec![
+                make_column_def("c1", DataType::Int(None)),
+                make_column_def("p1", DataType::Int(None)),
+            ],
+            file_type: "CSV".to_string(),
+            has_header: false,
+            delimiter: ',',
+            location: "foo.csv".into(),
+            table_partition_cols: vec!["p1".to_string()],
+            order_exprs: vec![],
+            if_not_exists: false,
+            file_compression_type: UNCOMPRESSED,
+            unbounded: false,
+            options: HashMap::new(),
+            constraints: vec![],
+        });
+        expect_parse_ok(sql, expected)?;
+
+        // negative case: mixed column defs and column names in `PARTITIONED BY` clause
+        let sql =
+            "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (p1 int, c1) LOCATION 'foo.csv'";
+        expect_parse_error(sql, "sql parser error: Expected a data type name, found: )");
+
+        // negative case: mixed column defs and column names in `PARTITIONED BY` clause
+        let sql =
+            "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c1, p1 int) LOCATION 'foo.csv'";
         expect_parse_error(sql, "sql parser error: Expected ',' or ')' after partition definition, found: int");

         // positive case: additional options (one entry) can be specified
diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt
index c4a26a5e227d..a200217af6e1 100644
--- a/datafusion/sqllogictest/test_files/create_external_table.slt
+++ b/datafusion/sqllogictest/test_files/create_external_table.slt
@@ -100,9 +100,17 @@ CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV WITH LOCATION 'foo.csv';
 statement error DataFusion error: SQL error: ParserError\("Unexpected token FOOBAR"\)
 CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV FOOBAR BARBAR BARFOO LOCATION 'foo.csv';

+# Missing partition column
+statement error DataFusion error: Arrow error: Schema error: Unable to get field named "c2".
Valid fields: \["c1"\] +create EXTERNAL TABLE t(c1 int) STORED AS CSV PARTITIONED BY (c2) LOCATION 'foo.csv' + +# Duplicate Column in `PARTITIONED BY` clause +statement error DataFusion error: Schema error: Schema contains duplicate unqualified field name c1 +create EXTERNAL TABLE t(c1 int, c2 int) STORED AS CSV PARTITIONED BY (c1 int) LOCATION 'foo.csv' + # Conflicting options statement error DataFusion error: Invalid or Unsupported Configuration: Config value "column_index_truncate_length" not found on CsvOptions CREATE EXTERNAL TABLE csv_table (column1 int) STORED AS CSV LOCATION 'foo.csv' -OPTIONS ('format.delimiter' ';', 'format.column_index_truncate_length' '123') +OPTIONS ('format.delimiter' ';', 'format.column_index_truncate_length' '123') \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/insert_to_external.slt b/datafusion/sqllogictest/test_files/insert_to_external.slt index dc60bafaa8db..4b9af3bdeafb 100644 --- a/datafusion/sqllogictest/test_files/insert_to_external.slt +++ b/datafusion/sqllogictest/test_files/insert_to_external.slt @@ -42,7 +42,7 @@ LOCATION '../../testing/data/csv/aggregate_test_100.csv' statement ok -create table dictionary_encoded_values as values +create table dictionary_encoded_values as values ('a', arrow_cast('foo', 'Dictionary(Int32, Utf8)')), ('b', arrow_cast('bar', 'Dictionary(Int32, Utf8)')); query TTT @@ -55,13 +55,13 @@ statement ok CREATE EXTERNAL TABLE dictionary_encoded_parquet_partitioned( a varchar, b varchar, -) +) STORED AS parquet LOCATION 'test_files/scratch/insert_to_external/parquet_types_partitioned/' PARTITIONED BY (b); query TT -insert into dictionary_encoded_parquet_partitioned +insert into dictionary_encoded_parquet_partitioned select * from dictionary_encoded_values ---- 2 @@ -76,13 +76,13 @@ statement ok CREATE EXTERNAL TABLE dictionary_encoded_arrow_partitioned( a varchar, b varchar, -) +) STORED AS arrow LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/' PARTITIONED BY (b); query TT -insert into dictionary_encoded_arrow_partitioned +insert into dictionary_encoded_arrow_partitioned select * from dictionary_encoded_values ---- 2 @@ -90,7 +90,7 @@ select * from dictionary_encoded_values statement ok CREATE EXTERNAL TABLE dictionary_encoded_arrow_test_readback( a varchar, -) +) STORED AS arrow LOCATION 'test_files/scratch/insert_to_external/arrow_dict_partitioned/b=bar/'; @@ -185,6 +185,30 @@ select * from partitioned_insert_test_verify; 1 2 +statement ok +CREATE EXTERNAL TABLE +partitioned_insert_test_hive(c bigint) +STORED AS csv +LOCATION 'test_files/scratch/insert_to_external/insert_to_partitioned' +PARTITIONED BY (a string, b string); + +query ITT +INSERT INTO partitioned_insert_test_hive VALUES (3,30,300); +---- +1 + +query ITT +SELECT * FROM partitioned_insert_test_hive order by a,b,c; +---- +1 10 100 +1 10 200 +1 20 100 +2 20 100 +1 20 200 +2 20 200 +3 30 300 + + statement ok CREATE EXTERNAL TABLE partitioned_insert_test_json(a string, b string) From 2159d8295218544c3e076e486c00514be7b4deb9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 30 Mar 2024 16:17:29 -0400 Subject: [PATCH 111/117] Minor: Move depcheck out of datafusion crate (200 less crates to compile) (#9865) * Minor: Move depcheck out of main datafusion test * Update dev/depcheck/README.md Co-authored-by: Andrew Lamb --------- Co-authored-by: comphead --- .github/workflows/rust.yml | 19 ++++++++++++ Cargo.toml | 2 +- datafusion/core/Cargo.toml | 1 - dev/depcheck/Cargo.toml | 25 ++++++++++++++++ 
dev/depcheck/README.md | 26 ++++++++++++++++ .../depcheck.rs => dev/depcheck/src/main.rs | 30 +++++++++++++++---- 6 files changed, 96 insertions(+), 7 deletions(-) create mode 100644 dev/depcheck/Cargo.toml create mode 100644 dev/depcheck/README.md rename datafusion/core/tests/depcheck.rs => dev/depcheck/src/main.rs (75%) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index ffd45b9777ef..6f6179fa52a2 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -195,6 +195,25 @@ jobs: - name: Verify Working Directory Clean run: git diff --exit-code + depcheck: + name: circular dependency check + needs: [ linux-build-lib ] + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v4 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: stable + - name: Check dependencies + run: | + cd dev/depcheck + cargo run + # Run `cargo test doc` (test documentation examples) linux-test-doc: name: cargo test doc (amd64) diff --git a/Cargo.toml b/Cargo.toml index 8e89e5ef3b85..9df489724d46 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [workspace] -exclude = ["datafusion-cli"] +exclude = ["datafusion-cli", "dev/depcheck"] members = [ "datafusion/common", "datafusion/common-runtime", diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index f483f8aed1cd..77a909731d89 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -130,7 +130,6 @@ zstd = { version = "0.13", optional = true, default-features = false } [dev-dependencies] async-trait = { workspace = true } bigdecimal = { workspace = true } -cargo = "0.78.1" criterion = { version = "0.5", features = ["async_tokio"] } csv = "1.1.6" ctor = { workspace = true } diff --git a/dev/depcheck/Cargo.toml b/dev/depcheck/Cargo.toml new file mode 100644 index 000000000000..cb4e77eabb22 --- /dev/null +++ b/dev/depcheck/Cargo.toml @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Circular dependency checker for DataFusion +[package] +name = "depcheck" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +cargo = "0.78.1" diff --git a/dev/depcheck/README.md b/dev/depcheck/README.md new file mode 100644 index 000000000000..4a628cdd88e9 --- /dev/null +++ b/dev/depcheck/README.md @@ -0,0 +1,26 @@ + + +This directory contains a tool that ensures there are no circular dependencies +in the DataFusion codebase. 
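+
+The check can be run locally with `cargo run` from `dev/depcheck`, mirroring
+the CI job above.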
+
+Specifically, it checks that no crate's tests depend on another crate which
+depends on the first, which prevents publishing to crates.io, for example
+
+[issue 9277]: https://github.com/apache/arrow-datafusion/issues/9277
diff --git a/datafusion/core/tests/depcheck.rs b/dev/depcheck/src/main.rs
similarity index 75%
rename from datafusion/core/tests/depcheck.rs
rename to dev/depcheck/src/main.rs
index 94448818691e..b52074c9b1d3 100644
--- a/datafusion/core/tests/depcheck.rs
+++ b/dev/depcheck/src/main.rs
@@ -15,18 +15,38 @@
 // specific language governing permissions and limitations
 // under the License.

+extern crate cargo;
+
+use cargo::CargoResult;
 /// Check for circular dependencies between DataFusion crates
 use std::collections::{HashMap, HashSet};
 use std::env;
 use std::path::Path;

 use cargo::util::config::Config;
-#[test]
-fn test_deps() -> Result<(), Box<dyn std::error::Error>> {
+
+/// Verifies that there are no circular dependencies between DataFusion crates
+/// (which prevents publishing on crates.io) by parsing the Cargo.toml files and
+/// checking the dependency graph.
+///
+/// See https://github.com/apache/arrow-datafusion/issues/9278 for more details
+fn main() -> CargoResult<()> {
     let config = Config::default()?;
+    // This is the path for the depcheck binary
     let path = env::var("CARGO_MANIFEST_DIR").unwrap();
-    let dir = Path::new(&path);
-    let root_cargo_toml = dir.join("Cargo.toml");
+    let root_cargo_toml = Path::new(&path)
+        // dev directory
+        .parent()
+        .expect("Can not find dev directory")
+        // project root directory
+        .parent()
+        .expect("Can not find project root directory")
+        .join("Cargo.toml");
+
+    println!(
+        "Checking for circular dependencies in {}",
+        root_cargo_toml.display()
+    );

     let workspace = cargo::core::Workspace::new(&root_cargo_toml, &config)?;
     let (_, resolve) = cargo::ops::resolve_ws(&workspace)?;
@@ -50,7 +70,7 @@ fn test_deps() -> Result<(), Box<dyn std::error::Error>> {
             check_circular_deps(root_package, dep, &package_deps, &mut seen);
         }
     }
-
+    println!("No circular dependencies found");
     Ok(())
 }

From 57c0bc65280c467ef6f5a498c6e78a616401b286 Mon Sep 17 00:00:00 2001
From: Lordworms <48054792+Lordworms@users.noreply.github.com>
Date: Sat, 30 Mar 2024 15:18:16 -0500
Subject: [PATCH 112/117] delete duplicate test (#9866)

---
 datafusion/functions/benches/regx.rs | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/datafusion/functions/benches/regx.rs b/datafusion/functions/benches/regx.rs
index f22be5ba3532..da4882381e76 100644
--- a/datafusion/functions/benches/regx.rs
+++ b/datafusion/functions/benches/regx.rs
@@ -103,20 +103,6 @@ fn criterion_benchmark(c: &mut Criterion) {
         })
     });

-    c.bench_function("regexp_match_1000", |b| {
-        let mut rng = rand::thread_rng();
-        let data = Arc::new(data(&mut rng)) as ArrayRef;
-        let regex = Arc::new(regex(&mut rng)) as ArrayRef;
-        let flags = Arc::new(flags(&mut rng)) as ArrayRef;
-
-        b.iter(|| {
-            black_box(
-                regexp_match::<i32>(&[data.clone(), regex.clone(), flags.clone()])
-                    .expect("regexp_match should work on valid values"),
-            )
-        })
-    });
-
     c.bench_function("regexp_replace_1000", |b| {
         let mut rng = rand::thread_rng();
         let data = Arc::new(data(&mut rng)) as ArrayRef;

From 2cb6f73cbb53e08eadcf91954ade5c76c2803379 Mon Sep 17 00:00:00 2001
From: Val Lorentz
Date: Sun, 31 Mar 2024 10:58:39 +0200
Subject: [PATCH 113/117] parquet: Add tests for pruning on Int8/Int16/Int64
 columns (#9778)

* parquet: Add tests for Bloom filters on Int8/Int16/Int64 columns

* Document int_tests macro

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
---
datafusion/core/Cargo.toml | 1 + datafusion/core/tests/parquet/mod.rs | 48 ++- datafusion/core/tests/parquet/page_pruning.rs | 276 +++++++------- .../core/tests/parquet/row_group_pruning.rs | 339 +++++++++--------- 4 files changed, 353 insertions(+), 311 deletions(-) diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 77a909731d89..610784f91dec 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -136,6 +136,7 @@ ctor = { workspace = true } doc-comment = { workspace = true } env_logger = { workspace = true } half = { workspace = true, default-features = true } +paste = "^1.0" postgres-protocol = "0.6.4" postgres-types = { version = "0.2.4", features = ["derive", "with-chrono-0_4"] } rand = { workspace = true, features = ["small_rng"] } diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 3fe51288e79a..368637d024e6 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -20,8 +20,9 @@ use arrow::array::Decimal128Array; use arrow::{ array::{ Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray, - Float64Array, Int32Array, StringArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -62,7 +63,7 @@ fn init() { enum Scenario { Timestamps, Dates, - Int32, + Int, Int32Range, Float64, Decimal, @@ -389,12 +390,31 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch { /// Return record batch with i32 sequence /// /// Columns are named -/// "i" -> Int32Array -fn make_int32_batch(start: i32, end: i32) -> RecordBatch { - let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); - let v: Vec = (start..end).collect(); - let array = Arc::new(Int32Array::from(v)) as ArrayRef; - RecordBatch::try_new(schema, vec![array.clone()]).unwrap() +/// "i8" -> Int8Array +/// "i16" -> Int16Array +/// "i32" -> Int32Array +/// "i64" -> Int64Array +fn make_int_batches(start: i8, end: i8) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("i8", DataType::Int8, true), + Field::new("i16", DataType::Int16, true), + Field::new("i32", DataType::Int32, true), + Field::new("i64", DataType::Int64, true), + ])); + let v8: Vec = (start..end).collect(); + let v16: Vec = (start as _..end as _).collect(); + let v32: Vec = (start as _..end as _).collect(); + let v64: Vec = (start as _..end as _).collect(); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int8Array::from(v8)) as ArrayRef, + Arc::new(Int16Array::from(v16)) as ArrayRef, + Arc::new(Int32Array::from(v32)) as ArrayRef, + Arc::new(Int64Array::from(v64)) as ArrayRef, + ], + ) + .unwrap() } fn make_int32_range(start: i32, end: i32) -> RecordBatch { @@ -589,12 +609,12 @@ fn create_data_batch(scenario: Scenario) -> Vec { make_date_batch(TimeDelta::try_days(3600).unwrap()), ] } - Scenario::Int32 => { + Scenario::Int => { vec![ - make_int32_batch(-5, 0), - make_int32_batch(-4, 1), - make_int32_batch(0, 5), - make_int32_batch(5, 10), + make_int_batches(-5, 0), + make_int_batches(-4, 1), + make_int_batches(0, 5), + make_int_batches(5, 10), ] } Scenario::Int32Range => { diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs 
index 3a43428f5bcf..e9e99cd3f88e 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -371,112 +371,149 @@ async fn prune_date64() { assert_eq!(output.result_rows, 1, "{}", output.description()); } -#[tokio::test] -// null count min max -// page-0 0 -5 -1 -// page-1 0 -4 0 -// page-2 0 0 4 -// page-3 0 5 9 -async fn prune_int32_lt() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where i < 1", - Some(0), - Some(5), - 11, - ) - .await; - // result of sql "SELECT * FROM t where i < 1" is same as - // "SELECT * FROM t where -i > -1" - test_prune( - Scenario::Int32, - "SELECT * FROM t where -i > -1", - Some(0), - Some(5), - 11, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_gt() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where i > 8", - Some(0), - Some(15), - 1, - ) - .await; - - test_prune( - Scenario::Int32, - "SELECT * FROM t where -i < -8", - Some(0), - Some(15), - 1, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_eq() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where i = 1", - Some(0), - Some(15), - 1, - ) - .await; -} -#[tokio::test] -async fn prune_int32_scalar_fun_and_eq() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where abs(i) = 1 and i = 1", - Some(0), - Some(15), - 1, - ) - .await; +macro_rules! int_tests { + ($bits:expr) => { + paste::item! { + #[tokio::test] + // null count min max + // page-0 0 -5 -1 + // page-1 0 -4 0 + // page-2 0 0 4 + // page-3 0 5 9 + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} < 1", $bits), + Some(0), + Some(5), + 11, + ) + .await; + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where -i{} > -1", $bits), + Some(0), + Some(5), + 11, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} > 8", $bits), + Some(0), + Some(15), + 1, + ) + .await; + + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where -i{} < -8", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} = 1", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where abs(i{}) = 1 and i{} = 1", $bits, $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where abs(i{}) = 1", $bits), + Some(0), + Some(0), + 3, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{}+1 = 1", $bits), + Some(0), + Some(0), + 2, + ) + .await; + } + + #[tokio::test] + async fn []() { + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where 1-i{} > 1", $bits), + Some(0), + Some(0), + 9, + ) + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where in (1)" + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} in (1)", $bits), + Some(0), + Some(15), + 1, + ) + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + test_prune( + Scenario::Int, + &format!("SELECT * FROM t where i{} not in (1)", $bits), + Some(0), + Some(0), + 19, + ) + .await; + } + } + } } -#[tokio::test] -async fn prune_int32_scalar_fun() { - 
test_prune( - Scenario::Int32, - "SELECT * FROM t where abs(i) = 1", - Some(0), - Some(0), - 3, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_complex_expr() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where i+1 = 1", - Some(0), - Some(0), - 2, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_complex_expr_subtract() { - test_prune( - Scenario::Int32, - "SELECT * FROM t where 1-i > 1", - Some(0), - Some(0), - 9, - ) - .await; -} +int_tests!(8); +int_tests!(16); +int_tests!(32); +int_tests!(64); #[tokio::test] // null count min max @@ -556,37 +593,6 @@ async fn prune_f64_complex_expr_subtract() { .await; } -#[tokio::test] -// null count min max -// page-0 0 -5 -1 -// page-1 0 -4 0 -// page-2 0 0 4 -// page-3 0 5 9 -async fn prune_int32_eq_in_list() { - // result of sql "SELECT * FROM t where in (1)" - test_prune( - Scenario::Int32, - "SELECT * FROM t where i in (1)", - Some(0), - Some(15), - 1, - ) - .await; -} - -#[tokio::test] -async fn prune_int32_eq_in_list_negated() { - // result of sql "SELECT * FROM t where not in (1)" prune nothing - test_prune( - Scenario::Int32, - "SELECT * FROM t where i not in (1)", - Some(0), - Some(0), - 19, - ) - .await; -} - #[tokio::test] async fn prune_decimal_lt() { // The data type of decimal_col is decimal(9,2) diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index ed48d040648c..b70102f78a96 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -285,105 +285,191 @@ async fn prune_disabled() { ); } -#[tokio::test] -async fn prune_int32_lt() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i < 1") - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(3)) - .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; - - // result of sql "SELECT * FROM t where i < 1" is same as - // "SELECT * FROM t where -i > -1" - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where -i > -1") - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(3)) - .with_pruned_by_stats(Some(1)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(11) - .test_row_group_prune() - .await; -} - -#[tokio::test] -async fn prune_int32_eq() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i = 1") - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; -} -#[tokio::test] -async fn prune_int32_scalar_fun_and_eq() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i = 1") - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - .with_matched_by_bloom_filter(Some(1)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(1) - .test_row_group_prune() - .await; -} - -#[tokio::test] -async fn prune_int32_scalar_fun() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where abs(i) = 1") - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(0)) 
- .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(3) - .test_row_group_prune() - .await; +// $bits: number of bits of the integer to test (8, 16, 32, 64) +// $correct_bloom_filters: if false, replicates the +// https://github.com/apache/arrow-datafusion/issues/9779 bug so that tests pass +// if and only if Bloom filters on Int8 and Int16 columns are still buggy. +macro_rules! int_tests { + ($bits:expr, correct_bloom_filters: $correct_bloom_filters:expr) => { + paste::item! { + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} < 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + + // result of sql "SELECT * FROM t where i < 1" is same as + // "SELECT * FROM t where -i > -1" + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where -i{} > -1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(3)) + .with_pruned_by_stats(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(11) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) + .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) + .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .test_row_group_prune() + .await; + } + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) + .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) + .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where abs(i{}) = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(3) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{}+1 = 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(2) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where 1-i{} > 1", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + 
.with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(9) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where in (1)" + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} in (1)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(3)) + .with_matched_by_bloom_filter(Some(if $correct_bloom_filters { 1 } else { 0 })) + .with_pruned_by_bloom_filter(Some(if $correct_bloom_filters { 0 } else { 1 })) + .with_expected_rows(if $correct_bloom_filters { 1 } else { 0 }) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where in (1000)", prune all + // test whether statistics works + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} in (100)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(0)) + .with_pruned_by_stats(Some(4)) + .with_matched_by_bloom_filter(Some(0)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(0) + .test_row_group_prune() + .await; + } + + #[tokio::test] + async fn []() { + // result of sql "SELECT * FROM t where not in (1)" prune nothing + RowGroupPruningTest::new() + .with_scenario(Scenario::Int) + .with_query(&format!("SELECT * FROM t where i{} not in (1)", $bits)) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(4)) + .with_pruned_by_stats(Some(0)) + .with_matched_by_bloom_filter(Some(4)) + .with_pruned_by_bloom_filter(Some(0)) + .with_expected_rows(19) + .test_row_group_prune() + .await; + } + } + }; } -#[tokio::test] -async fn prune_int32_complex_expr() { - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i+1 = 1") - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) - .with_pruned_by_stats(Some(0)) - .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(2) - .test_row_group_prune() - .await; -} +int_tests!(8, correct_bloom_filters: false); +int_tests!(16, correct_bloom_filters: false); +int_tests!(32, correct_bloom_filters: true); +int_tests!(64, correct_bloom_filters: true); #[tokio::test] -async fn prune_int32_complex_expr_subtract() { +async fn prune_int32_eq_large_in_list() { + // result of sql "SELECT * FROM t where i in (2050...2582)", prune all RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where 1-i > 1") + .with_scenario(Scenario::Int32Range) + .with_query( + format!( + "SELECT * FROM t where i in ({})", + (200050..200082).join(",") + ) + .as_str(), + ) .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(0)) + .with_matched_by_stats(Some(1)) .with_pruned_by_stats(Some(0)) .with_matched_by_bloom_filter(Some(0)) - .with_pruned_by_bloom_filter(Some(0)) - .with_expected_rows(9) + .with_pruned_by_bloom_filter(Some(1)) + .with_expected_rows(0) .test_row_group_prune() .await; } @@ -479,77 +565,6 @@ async fn prune_f64_complex_expr_subtract() { .await; } -#[tokio::test] -async fn prune_int32_eq_in_list() { - // result of sql "SELECT * FROM t where in (1)" - RowGroupPruningTest::new() - .with_scenario(Scenario::Int32) - .with_query("SELECT * FROM t where i in (1)") - .with_expected_errors(Some(0)) - .with_matched_by_stats(Some(1)) - .with_pruned_by_stats(Some(3)) - 
.with_matched_by_bloom_filter(Some(1))
-        .with_pruned_by_bloom_filter(Some(0))
-        .with_expected_rows(1)
-        .test_row_group_prune()
-        .await;
-}
-
-#[tokio::test]
-async fn prune_int32_eq_in_list_2() {
-    // result of sql "SELECT * FROM t where in (1000)", prune all
-    // test whether statistics works
-    RowGroupPruningTest::new()
-        .with_scenario(Scenario::Int32)
-        .with_query("SELECT * FROM t where i in (1000)")
-        .with_expected_errors(Some(0))
-        .with_matched_by_stats(Some(0))
-        .with_pruned_by_stats(Some(4))
-        .with_matched_by_bloom_filter(Some(0))
-        .with_pruned_by_bloom_filter(Some(0))
-        .with_expected_rows(0)
-        .test_row_group_prune()
-        .await;
-}
-
-#[tokio::test]
-async fn prune_int32_eq_large_in_list() {
-    // result of sql "SELECT * FROM t where i in (2050...2582)", prune all
-    RowGroupPruningTest::new()
-        .with_scenario(Scenario::Int32Range)
-        .with_query(
-            format!(
-                "SELECT * FROM t where i in ({})",
-                (200050..200082).join(",")
-            )
-            .as_str(),
-        )
-        .with_expected_errors(Some(0))
-        .with_matched_by_stats(Some(1))
-        .with_pruned_by_stats(Some(0))
-        .with_matched_by_bloom_filter(Some(0))
-        .with_pruned_by_bloom_filter(Some(1))
-        .with_expected_rows(0)
-        .test_row_group_prune()
-        .await;
-}
-
-#[tokio::test]
-async fn prune_int32_eq_in_list_negated() {
-    // result of sql "SELECT * FROM t where not in (1)" prune nothing
-    RowGroupPruningTest::new()
-        .with_scenario(Scenario::Int32)
-        .with_query("SELECT * FROM t where i not in (1)")
-        .with_expected_errors(Some(0))
-        .with_matched_by_stats(Some(4))
-        .with_pruned_by_stats(Some(0))
-        .with_matched_by_bloom_filter(Some(4))
-        .with_pruned_by_bloom_filter(Some(0))
-        .with_expected_rows(19)
-        .test_row_group_prune()
-        .await;
-}
-
 #[tokio::test]
 async fn prune_decimal_lt() {
     // The data type of decimal_col is decimal(9,2)

From 66c8ba22aa9569300c2ae12cef337fee089875e0 Mon Sep 17 00:00:00 2001
From: Alex Huang
Date: Sun, 31 Mar 2024 17:41:15 +0800
Subject: [PATCH 114/117] move `Atan2`, `Atan`, `Acosh`, `Asinh`, `Atanh` to
 `datafusion-function` (#9872)

* Refactor math functions in datafusion code

* fix ci

* fix: avoid regression

* refactor: move atan2 function

* chore: move atan2 test

---
 datafusion/expr/src/built_in_function.rs      |  48 +-----
 datafusion/expr/src/expr_fn.rs                |  10 --
 datafusion/functions/src/macros.rs            |  29 ++++
 datafusion/functions/src/math/atan2.rs        | 140 ++++++++++++++++++
 datafusion/functions/src/math/mod.rs          |  14 +-
 datafusion/functions/src/utils.rs             |   3 +
 datafusion/physical-expr/src/functions.rs     |   7 -
 .../physical-expr/src/math_expressions.rs     |  61 --------
 datafusion/proto/proto/datafusion.proto       |  10 +-
 datafusion/proto/src/generated/pbjson.rs      |  15 --
 datafusion/proto/src/generated/prost.rs       |  20 +--
 .../proto/src/logical_plan/from_proto.rs      |  23 +--
 datafusion/proto/src/logical_plan/to_proto.rs |   5 -
 13 files changed, 201 insertions(+), 184 deletions(-)
 create mode 100644 datafusion/functions/src/math/atan2.rs

diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index f8d16f465091..a1b3b717392e 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -37,16 +37,6 @@ use strum_macros::EnumIter;
 #[derive(Debug, Clone, PartialEq, Eq, Hash, EnumIter, Copy)]
 pub enum BuiltinScalarFunction {
     // math functions
-    /// atan
-    Atan,
-    /// atan2
-    Atan2,
-    /// acosh
-    Acosh,
-    /// asinh
-    Asinh,
-    /// atanh
-    Atanh,
     /// cbrt
     Cbrt,
     /// ceil
@@ -159,11 +149,6 @@ impl BuiltinScalarFunction {
     pub fn volatility(&self) -> Volatility {
         match self {
             // 
Immutable scalar builtins - BuiltinScalarFunction::Atan => Volatility::Immutable, - BuiltinScalarFunction::Atan2 => Volatility::Immutable, - BuiltinScalarFunction::Acosh => Volatility::Immutable, - BuiltinScalarFunction::Asinh => Volatility::Immutable, - BuiltinScalarFunction::Atanh => Volatility::Immutable, BuiltinScalarFunction::Ceil => Volatility::Immutable, BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, @@ -238,11 +223,6 @@ impl BuiltinScalarFunction { _ => Ok(Float64), }, - BuiltinScalarFunction::Atan2 => match &input_expr_types[0] { - Float32 => Ok(Float32), - _ => Ok(Float64), - }, - BuiltinScalarFunction::Log => match &input_expr_types[0] { Float32 => Ok(Float32), _ => Ok(Float64), @@ -255,11 +235,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Iszero => Ok(Boolean), - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Cos | BuiltinScalarFunction::Cosh | BuiltinScalarFunction::Degrees @@ -332,10 +308,7 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Atan2 => Signature::one_of( - vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], - self.volatility(), - ), + BuiltinScalarFunction::Log => Signature::one_of( vec![ Exact(vec![Float32]), @@ -355,11 +328,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Gcd | BuiltinScalarFunction::Lcm => { Signature::uniform(2, vec![Int64], self.volatility()) } - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Cbrt + BuiltinScalarFunction::Cbrt | BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Cos | BuiltinScalarFunction::Cosh @@ -392,11 +361,7 @@ impl BuiltinScalarFunction { pub fn monotonicity(&self) -> Option { if matches!( &self, - BuiltinScalarFunction::Atan - | BuiltinScalarFunction::Acosh - | BuiltinScalarFunction::Asinh - | BuiltinScalarFunction::Atanh - | BuiltinScalarFunction::Ceil + BuiltinScalarFunction::Ceil | BuiltinScalarFunction::Degrees | BuiltinScalarFunction::Exp | BuiltinScalarFunction::Factorial @@ -421,11 +386,6 @@ impl BuiltinScalarFunction { /// Returns all names that can be used to call this function pub fn aliases(&self) -> &'static [&'static str] { match self { - BuiltinScalarFunction::Acosh => &["acosh"], - BuiltinScalarFunction::Asinh => &["asinh"], - BuiltinScalarFunction::Atan => &["atan"], - BuiltinScalarFunction::Atanh => &["atanh"], - BuiltinScalarFunction::Atan2 => &["atan2"], BuiltinScalarFunction::Cbrt => &["cbrt"], BuiltinScalarFunction::Ceil => &["ceil"], BuiltinScalarFunction::Cos => &["cos"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index ab5628fece12..a2015787040f 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -541,10 +541,6 @@ scalar_expr!(Cos, cos, num, "cosine"); scalar_expr!(Cot, cot, num, "cotangent"); scalar_expr!(Sinh, sinh, num, "hyperbolic sine"); scalar_expr!(Cosh, cosh, num, "hyperbolic cosine"); -scalar_expr!(Atan, atan, num, "inverse tangent"); -scalar_expr!(Asinh, asinh, num, "inverse hyperbolic sine"); -scalar_expr!(Acosh, acosh, num, "inverse hyperbolic cosine"); -scalar_expr!(Atanh, atanh, num, "inverse hyperbolic tangent"); scalar_expr!(Factorial, factorial, num, "factorial"); scalar_expr!( Floor, @@ -571,7 +567,6 @@ 
scalar_expr!(Exp, exp, num, "exponential"); scalar_expr!(Gcd, gcd, arg_1 arg_2, "greatest common divisor"); scalar_expr!(Lcm, lcm, arg_1 arg_2, "least common multiple"); scalar_expr!(Power, power, base exponent, "`base` raised to the power of `exponent`"); -scalar_expr!(Atan2, atan2, y x, "inverse tangent of a division given in the argument"); scalar_expr!(Log, log, base x, "logarithm of a `x` for a particular `base`"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); @@ -979,10 +974,6 @@ mod test { test_unary_scalar_expr!(Cot, cot); test_unary_scalar_expr!(Sinh, sinh); test_unary_scalar_expr!(Cosh, cosh); - test_unary_scalar_expr!(Atan, atan); - test_unary_scalar_expr!(Asinh, asinh); - test_unary_scalar_expr!(Acosh, acosh); - test_unary_scalar_expr!(Atanh, atanh); test_unary_scalar_expr!(Factorial, factorial); test_unary_scalar_expr!(Floor, floor); test_unary_scalar_expr!(Ceil, ceil); @@ -994,7 +985,6 @@ mod test { test_nary_scalar_expr!(Trunc, trunc, num, precision); test_unary_scalar_expr!(Signum, signum); test_unary_scalar_expr!(Exp, exp); - test_scalar_expr!(Atan2, atan2, y, x); test_scalar_expr!(Nanvl, nanvl, x, y); test_scalar_expr!(Iszero, iszero, input); diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index b23baeeacf23..4907d74fe941 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -156,6 +156,7 @@ macro_rules! downcast_arg { /// $GNAME: a singleton instance of the UDF /// $NAME: the name of the function /// $UNARY_FUNC: the unary function to apply to the argument +/// $MONOTONIC_FUNC: the monotonicity of the function macro_rules! make_math_unary_udf { ($UDF:ident, $GNAME:ident, $NAME:ident, $UNARY_FUNC:ident, $MONOTONICITY:expr) => { make_udf_function!($NAME::$UDF, $GNAME, $NAME); @@ -249,3 +250,31 @@ macro_rules! make_math_unary_udf { } }; } + +#[macro_export] +macro_rules! make_function_inputs2 { + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE>() + }}; + ($ARG1: expr, $ARG2: expr, $NAME1:expr, $NAME2: expr, $ARRAY_TYPE1:ident, $ARRAY_TYPE2:ident, $FUNC: block) => {{ + let arg1 = downcast_arg!($ARG1, $NAME1, $ARRAY_TYPE1); + let arg2 = downcast_arg!($ARG2, $NAME2, $ARRAY_TYPE2); + + arg1.iter() + .zip(arg2.iter()) + .map(|(a1, a2)| match (a1, a2) { + (Some(a1), Some(a2)) => Some($FUNC(a1, a2.try_into().ok()?)), + _ => None, + }) + .collect::<$ARRAY_TYPE1>() + }}; +} diff --git a/datafusion/functions/src/math/atan2.rs b/datafusion/functions/src/math/atan2.rs new file mode 100644 index 000000000000..b090c6c454fd --- /dev/null +++ b/datafusion/functions/src/math/atan2.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math function: `atan2()`. + +use arrow::array::{ArrayRef, Float32Array, Float64Array}; +use arrow::datatypes::DataType; +use datafusion_common::DataFusionError; +use datafusion_common::{exec_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; +use std::sync::Arc; + +use crate::make_function_inputs2; +use crate::utils::make_scalar_function; + +#[derive(Debug)] +pub(super) struct Atan2 { + signature: Signature, +} + +impl Atan2 { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::one_of( + vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for Atan2 { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "atan2" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use self::DataType::*; + match &arg_types[0] { + Float32 => Ok(Float32), + _ => Ok(Float64), + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + make_scalar_function(atan2, vec![])(args) + } +} + +/// Atan2 SQL function +pub fn atan2(args: &[ArrayRef]) -> Result { + match args[0].data_type() { + DataType::Float64 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float64Array, + { f64::atan2 } + )) as ArrayRef), + + DataType::Float32 => Ok(Arc::new(make_function_inputs2!( + &args[0], + &args[1], + "y", + "x", + Float32Array, + { f32::atan2 } + )) as ArrayRef), + + other => exec_err!("Unsupported data type {other:?} for function atan2"), + } +} + +#[cfg(test)] +mod test { + use super::*; + use datafusion_common::cast::{as_float32_array, as_float64_array}; + + #[test] + fn test_atan2_f64() { + let args: Vec = vec![ + Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("failed to initialize function atan2"); + let floats = + as_float64_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), (2.0_f64).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0)); + } + + #[test] + fn test_atan2_f32() { + let args: Vec = vec![ + Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y + Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x + ]; + + let result = atan2(&args).expect("failed to initialize function atan2"); + let floats = + as_float32_array(&result).expect("failed to initialize function atan2"); + + assert_eq!(floats.len(), 4); + assert_eq!(floats.value(0), (2.0_f32).atan2(1.0)); + assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0)); + assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0)); + assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0)); + } +} diff --git a/datafusion/functions/src/math/mod.rs 
diff --git a/datafusion/functions/src/math/mod.rs b/datafusion/functions/src/math/mod.rs
index 3a4c1b1e8710..2ee1fffa1625 100644
--- a/datafusion/functions/src/math/mod.rs
+++ b/datafusion/functions/src/math/mod.rs
@@ -18,11 +18,13 @@
 //! "math" DataFusion functions
 
 mod abs;
+mod atan2;
 mod nans;
 
 // Create UDFs
 make_udf_function!(nans::IsNanFunc, ISNAN, isnan);
 make_udf_function!(abs::AbsFunc, ABS, abs);
+make_udf_function!(atan2::Atan2, ATAN2, atan2);
 
 make_math_unary_udf!(Log2Func, LOG2, log2, log2, Some(vec![Some(true)]));
 make_math_unary_udf!(Log10Func, LOG10, log10, log10, Some(vec![Some(true)]));
@@ -33,6 +35,11 @@ make_math_unary_udf!(AcosFunc, ACOS, acos, acos, None);
 make_math_unary_udf!(AsinFunc, ASIN, asin, asin, None);
 make_math_unary_udf!(TanFunc, TAN, tan, tan, None);
 
+make_math_unary_udf!(AtanhFunc, ATANH, atanh, atanh, Some(vec![Some(true)]));
+make_math_unary_udf!(AsinhFunc, ASINH, asinh, asinh, Some(vec![Some(true)]));
+make_math_unary_udf!(AcoshFunc, ACOSH, acosh, acosh, Some(vec![Some(true)]));
+make_math_unary_udf!(AtanFunc, ATAN, atan, atan, Some(vec![Some(true)]));
+
 // Export the functions out of this package, both as expr_fn as well as a list of functions
 export_functions!(
     (
@@ -55,5 +62,10 @@ export_functions!(
         "returns the arc sine or inverse sine of a number"
     ),
     (tan, num, "returns the tangent of a number"),
-    (tanh, num, "returns the hyperbolic tangent of a number")
+    (tanh, num, "returns the hyperbolic tangent of a number"),
+    (atanh, num, "returns inverse hyperbolic tangent"),
+    (asinh, num, "returns inverse hyperbolic sine"),
+    (acosh, num, "returns inverse hyperbolic cosine"),
+    (atan, num, "returns inverse tangent"),
+    (atan2, y x, "returns inverse tangent of a division given in the argument")
 );
diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs
index f45deafdb37a..9b7144b483bd 100644
--- a/datafusion/functions/src/utils.rs
+++ b/datafusion/functions/src/utils.rs
@@ -68,6 +68,9 @@ get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
 // `utf8_to_int_type`: returns either a Int32 or Int64 based on the input type size.
 get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
 
+/// Creates a scalar function implementation for the given function.
+/// * `inner` - the function to be executed
+/// * `hints` - hints to be used when expanding scalars to arrays
 pub(super) fn make_scalar_function<F>(
     inner: F,
     hints: Vec<Hint>,
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 5b9b46c3991b..a1e471bdd422 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -179,10 +179,6 @@ pub fn create_physical_fun(
 ) -> Result<ScalarFunctionImplementation> {
     Ok(match fun {
         // math functions
-        BuiltinScalarFunction::Atan => Arc::new(math_expressions::atan),
-        BuiltinScalarFunction::Acosh => Arc::new(math_expressions::acosh),
-        BuiltinScalarFunction::Asinh => Arc::new(math_expressions::asinh),
-        BuiltinScalarFunction::Atanh => Arc::new(math_expressions::atanh),
         BuiltinScalarFunction::Ceil => Arc::new(math_expressions::ceil),
         BuiltinScalarFunction::Cos => Arc::new(math_expressions::cos),
         BuiltinScalarFunction::Cosh => Arc::new(math_expressions::cosh),
@@ -221,9 +217,6 @@ pub fn create_physical_fun(
         BuiltinScalarFunction::Power => {
             Arc::new(|args| make_scalar_function_inner(math_expressions::power)(args))
         }
-        BuiltinScalarFunction::Atan2 => {
-            Arc::new(|args| make_scalar_function_inner(math_expressions::atan2)(args))
-        }
         BuiltinScalarFunction::Log => {
             Arc::new(|args| make_scalar_function_inner(math_expressions::log)(args))
         }
diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs
index db8855cb5400..5339c12f6e93 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -492,31 +492,6 @@ pub fn power(args: &[ArrayRef]) -> Result<ArrayRef> {
     }
 }
 
-/// Atan2 SQL function
-pub fn atan2(args: &[ArrayRef]) -> Result<ArrayRef> {
-    match args[0].data_type() {
-        DataType::Float64 => Ok(Arc::new(make_function_inputs2!(
-            &args[0],
-            &args[1],
-            "y",
-            "x",
-            Float64Array,
-            { f64::atan2 }
-        )) as ArrayRef),
-
-        DataType::Float32 => Ok(Arc::new(make_function_inputs2!(
-            &args[0],
-            &args[1],
-            "y",
-            "x",
-            Float32Array,
-            { f32::atan2 }
-        )) as ArrayRef),
-
-        other => exec_err!("Unsupported data type {other:?} for function atan2"),
-    }
-}
-
 /// Log SQL function
 pub fn log(args: &[ArrayRef]) -> Result<ArrayRef> {
     // Support overloaded log(base, x) and log(x) which defaults to log(10, x)
@@ -725,42 +700,6 @@ mod tests {
         assert_eq!(floats.value(3), 625);
     }
 
-    #[test]
-    fn test_atan2_f64() {
-        let args: Vec<ArrayRef> = vec![
-            Arc::new(Float64Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y
-            Arc::new(Float64Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x
-        ];
-
-        let result = atan2(&args).expect("failed to initialize function atan2");
-        let floats =
-            as_float64_array(&result).expect("failed to initialize function atan2");
-
-        assert_eq!(floats.len(), 4);
-        assert_eq!(floats.value(0), (2.0_f64).atan2(1.0));
-        assert_eq!(floats.value(1), (-3.0_f64).atan2(2.0));
-        assert_eq!(floats.value(2), (4.0_f64).atan2(-3.0));
-        assert_eq!(floats.value(3), (-5.0_f64).atan2(-4.0));
-    }
-
-    #[test]
-    fn test_atan2_f32() {
-        let args: Vec<ArrayRef> = vec![
-            Arc::new(Float32Array::from(vec![2.0, -3.0, 4.0, -5.0])), // y
-            Arc::new(Float32Array::from(vec![1.0, 2.0, -3.0, -4.0])), // x
-        ];
-
-        let result = atan2(&args).expect("failed to initialize function atan2");
-        let floats =
-            as_float32_array(&result).expect("failed to initialize function atan2");
-
-        assert_eq!(floats.len(), 4);
-        assert_eq!(floats.value(0), (2.0_f32).atan2(1.0));
-        assert_eq!(floats.value(1), (-3.0_f32).atan2(2.0));
-        assert_eq!(floats.value(2), (4.0_f32).atan2(-3.0));
-        assert_eq!(floats.value(3), (-5.0_f32).atan2(-4.0));
-    }
-
     #[test]
     fn test_log_f64() {
         let args: Vec<ArrayRef> = vec![
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index b756e0575d71..e959cad2a810 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -544,7 +544,7 @@ enum ScalarFunction {
   unknown = 0;
   // 1 was Acos
   // 2 was Asin
-  Atan = 3;
+  // 3 was Atan
   // 4 was Ascii
   Ceil = 5;
   Cos = 6;
@@ -608,16 +608,16 @@ enum ScalarFunction {
   Power = 64;
   // 65 was StructFun
   // 66 was FromUnixtime
-  Atan2 = 67;
+  // 67 was Atan2
   // 68 was DateBin
   // 69 was ArrowTypeof
   // 70 was CurrentDate
   // 71 was CurrentTime
   // 72 was Uuid
   Cbrt = 73;
-  Acosh = 74;
-  Asinh = 75;
-  Atanh = 76;
+  // 74 was Acosh
+  // 75 was Asinh
+  // 76 was Atanh
   Sinh = 77;
   Cosh = 78;
   // Tanh = 79;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 3c3d60300786..d900d0031df3 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -22914,7 +22914,6 @@ impl serde::Serialize for ScalarFunction {
     {
         let variant = match self {
             Self::Unknown => "unknown",
-            Self::Atan => "Atan",
             Self::Ceil => "Ceil",
             Self::Cos => "Cos",
             Self::Exp => "Exp",
@@ -22931,11 +22930,7 @@ impl serde::Serialize for ScalarFunction {
             Self::Random => "Random",
             Self::Coalesce => "Coalesce",
             Self::Power => "Power",
-            Self::Atan2 => "Atan2",
             Self::Cbrt => "Cbrt",
-            Self::Acosh => "Acosh",
-            Self::Asinh => "Asinh",
-            Self::Atanh => "Atanh",
             Self::Sinh => "Sinh",
             Self::Cosh => "Cosh",
             Self::Pi => "Pi",
@@ -22960,7 +22955,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
     {
         const FIELDS: &[&str] = &[
             "unknown",
-            "Atan",
             "Ceil",
             "Cos",
             "Exp",
@@ -22977,11 +22971,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             "Random",
             "Coalesce",
             "Power",
-            "Atan2",
             "Cbrt",
-            "Acosh",
-            "Asinh",
-            "Atanh",
             "Sinh",
             "Cosh",
             "Pi",
@@ -23035,7 +23025,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
             {
                 match value {
                     "unknown" => Ok(ScalarFunction::Unknown),
-                    "Atan" => Ok(ScalarFunction::Atan),
                     "Ceil" => Ok(ScalarFunction::Ceil),
                     "Cos" => Ok(ScalarFunction::Cos),
                     "Exp" => Ok(ScalarFunction::Exp),
@@ -23052,11 +23041,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction {
                     "Random" => Ok(ScalarFunction::Random),
                     "Coalesce" => Ok(ScalarFunction::Coalesce),
                     "Power" => Ok(ScalarFunction::Power),
-                    "Atan2" => Ok(ScalarFunction::Atan2),
                     "Cbrt" => Ok(ScalarFunction::Cbrt),
-                    "Acosh" => Ok(ScalarFunction::Acosh),
-                    "Asinh" => Ok(ScalarFunction::Asinh),
-                    "Atanh" => Ok(ScalarFunction::Atanh),
                     "Sinh" => Ok(ScalarFunction::Sinh),
                     "Cosh" => Ok(ScalarFunction::Cosh),
                     "Pi" => Ok(ScalarFunction::Pi),
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 9860587d3eca..753abb4e2756 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2843,7 +2843,7 @@ pub enum ScalarFunction {
     Unknown = 0,
     /// 1 was Acos
     /// 2 was Asin
-    Atan = 3,
+    /// 3 was Atan
     /// 4 was Ascii
     Ceil = 5,
     Cos = 6,
@@ -2907,16 +2907,16 @@ pub enum ScalarFunction {
     Power = 64,
     /// 65 was StructFun
     /// 66 was FromUnixtime
-    Atan2 = 67,
+    /// 67 was Atan2
     /// 68 was DateBin
     /// 69 was ArrowTypeof
     /// 70 was CurrentDate
     /// 71 was CurrentTime
     /// 72 was Uuid
     Cbrt = 73,
-    Acosh = 74,
-    Asinh = 75,
-    Atanh = 76,
+    /// 74 was Acosh
+    /// 75 was Asinh
+    /// 76 was Atanh
     Sinh = 77,
     Cosh = 78,
     /// Tanh = 79;
@@ -2987,7 +2987,6 @@ impl ScalarFunction {
     pub fn as_str_name(&self) -> &'static str {
         match self {
             ScalarFunction::Unknown => "unknown",
-            ScalarFunction::Atan => "Atan",
             ScalarFunction::Ceil => "Ceil",
             ScalarFunction::Cos => "Cos",
             ScalarFunction::Exp => "Exp",
@@ -3004,11 +3003,7 @@ impl ScalarFunction {
             ScalarFunction::Random => "Random",
             ScalarFunction::Coalesce => "Coalesce",
             ScalarFunction::Power => "Power",
-            ScalarFunction::Atan2 => "Atan2",
             ScalarFunction::Cbrt => "Cbrt",
-            ScalarFunction::Acosh => "Acosh",
-            ScalarFunction::Asinh => "Asinh",
-            ScalarFunction::Atanh => "Atanh",
             ScalarFunction::Sinh => "Sinh",
             ScalarFunction::Cosh => "Cosh",
             ScalarFunction::Pi => "Pi",
@@ -3027,7 +3022,6 @@ impl ScalarFunction {
     pub fn from_str_name(value: &str) -> ::core::option::Option<Self> {
         match value {
             "unknown" => Some(Self::Unknown),
-            "Atan" => Some(Self::Atan),
             "Ceil" => Some(Self::Ceil),
             "Cos" => Some(Self::Cos),
             "Exp" => Some(Self::Exp),
@@ -3044,11 +3038,7 @@ impl ScalarFunction {
             "Random" => Some(Self::Random),
             "Coalesce" => Some(Self::Coalesce),
             "Power" => Some(Self::Power),
-            "Atan2" => Some(Self::Atan2),
             "Cbrt" => Some(Self::Cbrt),
-            "Acosh" => Some(Self::Acosh),
-            "Asinh" => Some(Self::Asinh),
-            "Atanh" => Some(Self::Atanh),
             "Sinh" => Some(Self::Sinh),
             "Cosh" => Some(Self::Cosh),
             "Pi" => Some(Self::Pi),
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index c068cfd46c1f..f9e2dc5596ac 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -37,8 +37,8 @@ use datafusion_expr::expr::Unnest;
 use datafusion_expr::expr::{Alias, Placeholder};
 use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
 use datafusion_expr::{
-    acosh, asinh, atan, atan2, atanh, cbrt, ceil, coalesce, concat_expr, concat_ws_expr,
-    cos, cosh, cot, degrees, ends_with, exp,
+    cbrt, ceil, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, degrees,
+    ends_with, exp,
     expr::{self, InList, Sort, WindowFunction},
     factorial, floor, gcd, initcap, iszero, lcm, log,
     logical_plan::{PlanType, StringifiedPlan},
@@ -428,12 +428,8 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
             ScalarFunction::Sin => Self::Sin,
             ScalarFunction::Cos => Self::Cos,
             ScalarFunction::Cot => Self::Cot,
-            ScalarFunction::Atan => Self::Atan,
             ScalarFunction::Sinh => Self::Sinh,
             ScalarFunction::Cosh => Self::Cosh,
-            ScalarFunction::Asinh => Self::Asinh,
-            ScalarFunction::Acosh => Self::Acosh,
-            ScalarFunction::Atanh => Self::Atanh,
             ScalarFunction::Exp => Self::Exp,
             ScalarFunction::Log => Self::Log,
             ScalarFunction::Degrees => Self::Degrees,
@@ -454,7 +450,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
             ScalarFunction::Coalesce => Self::Coalesce,
             ScalarFunction::Pi => Self::Pi,
             ScalarFunction::Power => Self::Power,
-            ScalarFunction::Atan2 => Self::Atan2,
             ScalarFunction::Nanvl => Self::Nanvl,
             ScalarFunction::Iszero => Self::Iszero,
         }
@@ -1318,22 +1313,12 @@ pub fn parse_expr(
         match scalar_function {
             ScalarFunction::Unknown => Err(proto_error("Unknown scalar function")),
-            ScalarFunction::Asinh => {
-                Ok(asinh(parse_expr(&args[0], registry, codec)?))
-            }
-            ScalarFunction::Acosh => {
-                Ok(acosh(parse_expr(&args[0], registry, codec)?))
-            }
             ScalarFunction::Sqrt => Ok(sqrt(parse_expr(&args[0], registry, codec)?)),
             ScalarFunction::Cbrt => Ok(cbrt(parse_expr(&args[0], registry, codec)?)),
             ScalarFunction::Sin => Ok(sin(parse_expr(&args[0], registry, codec)?)),
             ScalarFunction::Cos => Ok(cos(parse_expr(&args[0], registry, codec)?)),
-            ScalarFunction::Atan => Ok(atan(parse_expr(&args[0], registry, codec)?)),
             ScalarFunction::Sinh => Ok(sinh(parse_expr(&args[0], registry, codec)?)),
             ScalarFunction::Cosh => Ok(cosh(parse_expr(&args[0], registry, codec)?)),
-            ScalarFunction::Atanh => {
-                Ok(atanh(parse_expr(&args[0], registry, codec)?))
-            }
             ScalarFunction::Exp => Ok(exp(parse_expr(&args[0], registry, codec)?)),
             ScalarFunction::Degrees => {
                 Ok(degrees(parse_expr(&args[0], registry, codec)?))
@@ -1387,10 +1372,6 @@ pub fn parse_expr(
                 parse_expr(&args[0], registry, codec)?,
                 parse_expr(&args[1], registry, codec)?,
             )),
-            ScalarFunction::Atan2 => Ok(atan2(
-                parse_expr(&args[0], registry, codec)?,
-                parse_expr(&args[1], registry, codec)?,
-            )),
             ScalarFunction::Cot => Ok(cot(parse_expr(&args[0], registry, codec)?)),
             ScalarFunction::Nanvl => Ok(nanvl(
                 parse_expr(&args[0], registry, codec)?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 9d433bb6ff97..3ee69066e1aa 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -1422,10 +1422,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
             BuiltinScalarFunction::Cot => Self::Cot,
             BuiltinScalarFunction::Sinh => Self::Sinh,
             BuiltinScalarFunction::Cosh => Self::Cosh,
-            BuiltinScalarFunction::Atan => Self::Atan,
-            BuiltinScalarFunction::Asinh => Self::Asinh,
-            BuiltinScalarFunction::Acosh => Self::Acosh,
-            BuiltinScalarFunction::Atanh => Self::Atanh,
             BuiltinScalarFunction::Exp => Self::Exp,
             BuiltinScalarFunction::Factorial => Self::Factorial,
             BuiltinScalarFunction::Gcd => Self::Gcd,
@@ -1446,7 +1442,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
             BuiltinScalarFunction::Coalesce => Self::Coalesce,
             BuiltinScalarFunction::Pi => Self::Pi,
             BuiltinScalarFunction::Power => Self::Power,
-            BuiltinScalarFunction::Atan2 => Self::Atan2,
             BuiltinScalarFunction::Nanvl => Self::Nanvl,
             BuiltinScalarFunction::Iszero => Self::Iszero,
         };
From ef601d2caa7f62cc5c6dde7bb7c371a1e298c2fb Mon Sep 17 00:00:00 2001
From: Yijie Shen
Date: Sun, 31 Mar 2024 02:41:57 -0700
Subject: [PATCH 115/117] minor(doc): fix dead link for catalogs example (#9883)

---
 datafusion/core/src/catalog/mod.rs         | 2 +-
 docs/source/library-user-guide/catalogs.md | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/datafusion/core/src/catalog/mod.rs b/datafusion/core/src/catalog/mod.rs
index 8aeeaf9f72d8..d39fad8a5643 100644
--- a/datafusion/core/src/catalog/mod.rs
+++ b/datafusion/core/src/catalog/mod.rs
@@ -177,7 +177,7 @@ impl CatalogProviderList for MemoryCatalogProviderList {
 ///
 /// [`datafusion-cli`]: https://arrow.apache.org/datafusion/user-guide/cli.html
 /// [`DynamicFileCatalogProvider`]: https://github.com/apache/arrow-datafusion/blob/31b9b48b08592b7d293f46e75707aad7dadd7cbc/datafusion-cli/src/catalog.rs#L75
-/// [`catalog.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/external_dependency/catalog.rs
+/// [`catalog.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/catalog.rs
 /// [delta-rs]: https://github.com/delta-io/delta-rs
 /// [`UnityCatalogProvider`]: https://github.com/delta-io/delta-rs/blob/951436ecec476ce65b5ed3b58b50fb0846ca7b91/crates/deltalake-core/src/data_catalog/unity/datafusion.rs#L111-L123
 ///
diff --git a/docs/source/library-user-guide/catalogs.md b/docs/source/library-user-guide/catalogs.md
index 06cd2765d161..d30e26f1964a 100644
--- a/docs/source/library-user-guide/catalogs.md
+++ b/docs/source/library-user-guide/catalogs.md
@@ -19,7 +19,7 @@
 
 # Catalogs, Schemas, and Tables
 
-This section describes how to create and manage catalogs, schemas, and tables in DataFusion. For those wanting to dive into the code quickly please see the [example](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/external_dependency/catalog.rs).
+This section describes how to create and manage catalogs, schemas, and tables in DataFusion. For those wanting to dive into the code quickly please see the [example](https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/catalog.rs).
 
 ## General Concepts
 
From a23f50768deea4757593233056ad10cf847c35ff Mon Sep 17 00:00:00 2001
From: Val Lorentz
Date: Sun, 31 Mar 2024 13:40:21 +0200
Subject: [PATCH 116/117] parquet: Add tests for page pruning on unsigned integers (#9888)

---
 datafusion/core/tests/parquet/mod.rs          |  43 ++++++-
 datafusion/core/tests/parquet/page_pruning.rs | 114 ++++++++++++++++++
 2 files changed, 155 insertions(+), 2 deletions(-)

diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs
index 368637d024e6..1da86a0363a5 100644
--- a/datafusion/core/tests/parquet/mod.rs
+++ b/datafusion/core/tests/parquet/mod.rs
@@ -22,7 +22,7 @@ use arrow::{
         Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray,
         Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
         TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
-        TimestampSecondArray,
+        TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
     },
     datatypes::{DataType, Field, Schema},
     record_batch::RecordBatch,
@@ -65,6 +65,7 @@ enum Scenario {
     Dates,
     Int,
     Int32Range,
+    UInt,
     Float64,
     Decimal,
     DecimalBloomFilterInt32,
@@ -387,7 +388,7 @@ fn make_timestamp_batch(offset: Duration) -> RecordBatch {
     .unwrap()
 }
 
-/// Return record batch with i32 sequence
+/// Return record batch with i8, i16, i32, and i64 sequences
 ///
 /// Columns are named
 /// "i8" -> Int8Array
@@ -417,6 +418,36 @@ fn make_int_batches(start: i8, end: i8) -> RecordBatch {
     .unwrap()
 }
 
+/// Return record batch with u8, u16, u32, and u64 sequences
+///
+/// Columns are named
+/// "u8" -> UInt8Array
+/// "u16" -> UInt16Array
+/// "u32" -> UInt32Array
+/// "u64" -> UInt64Array
+fn make_uint_batches(start: u8, end: u8) -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("u8", DataType::UInt8, true),
+        Field::new("u16", DataType::UInt16, true),
+        Field::new("u32", DataType::UInt32, true),
+        Field::new("u64", DataType::UInt64, true),
+    ]));
+    let v8: Vec<u8> = (start..end).collect();
+    let v16: Vec<u16> = (start as _..end as _).collect();
+    let v32: Vec<u32> = (start as _..end as _).collect();
+    let v64: Vec<u64> = (start as _..end as _).collect();
+    RecordBatch::try_new(
+        schema,
+        vec![
+            Arc::new(UInt8Array::from(v8)) as ArrayRef,
+            Arc::new(UInt16Array::from(v16)) as ArrayRef,
+            Arc::new(UInt32Array::from(v32)) as ArrayRef,
+            Arc::new(UInt64Array::from(v64)) as ArrayRef,
+        ],
+    )
+    .unwrap()
+}
+
 fn make_int32_range(start: i32, end: i32) -> RecordBatch {
     let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
     let v = vec![start, end];
@@ -620,6 +651,14 @@ fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> {
         Scenario::Int32Range => {
             vec![make_int32_range(0, 10), make_int32_range(200000, 300000)]
         }
+        Scenario::UInt => {
+            vec![
+                make_uint_batches(0, 5),
+                make_uint_batches(1, 6),
+                make_uint_batches(5, 10),
+                make_uint_batches(250, 255),
+            ]
+        }
         Scenario::Float64 => {
             vec![
                 make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]),
diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs
index e9e99cd3f88e..da9617f13ee9 100644
--- a/datafusion/core/tests/parquet/page_pruning.rs
+++ b/datafusion/core/tests/parquet/page_pruning.rs
@@ -515,6 +515,120 @@ int_tests!(16);
 int_tests!(32);
 int_tests!(64);
 
+macro_rules! uint_tests {
+    ($bits:expr) => {
+        paste::item! {
+            #[tokio::test]
+            //                       null count  min  max
+            // page-0                         0    0    4
+            // page-1                         0    1    5
+            // page-2                         0    5    9
+            // page-3                         0  250  254
+            async fn [<prune_uint $bits _lt>]() {
+                test_prune(
+                    Scenario::UInt,
+                    &format!("SELECT * FROM t where u{} < 6", $bits),
+                    Some(0),
+                    Some(5),
+                    11,
+                )
+                .await;
+            }
+
+            #[tokio::test]
+            async fn [<prune_uint $bits _gt>]() {
+                test_prune(
+                    Scenario::UInt,
+                    &format!("SELECT * FROM t where u{} > 253", $bits),
+                    Some(0),
+                    Some(15),
+                    1,
+                )
+                .await;
+            }
+
+            #[tokio::test]
+            async fn [<prune_uint $bits _eq>]() {
+                test_prune(
+                    Scenario::UInt,
+                    &format!("SELECT * FROM t where u{} = 6", $bits),
+                    Some(0),
+                    Some(15),
+                    1,
+                )
+                .await;
+            }
+
+            #[tokio::test]
+            async fn [<prune_uint $bits _scalar_fun_and_eq>]() {
+                test_prune(
+                    Scenario::UInt,
+                    &format!("SELECT * FROM t where power(u{}, 2) = 36 and u{} = 6", $bits, $bits),
+                    Some(0),
+                    Some(15),
+                    1,
+                )
+                .await;
+            }
+
+            #[tokio::test]
+            async fn [<prune_uint $bits _scalar_fun>]() {
+                test_prune(
+                    Scenario::UInt,
+                    &format!("SELECT * FROM t where power(u{}, 2) = 25", $bits),
+                    Some(0),
+                    Some(0),
+                    2,
+                )
+                .await;
+            }
+
+            #[tokio::test]
+            async fn [<prune_uint $bits _complex_expr>]() {
+                test_prune(
+                    Scenario::UInt,
+                    &format!("SELECT * FROM t where u{}+1 = 6", $bits),
+                    Some(0),
+                    Some(0),
+                    2,
+                )
+                .await;
+            }
+
+            #[tokio::test]
+            async fn [<prune_uint $bits _eq_in>]() {
+                // result of sql "SELECT * FROM t where u{} in (6)"
+                test_prune(
+                    Scenario::UInt,
+                    &format!("SELECT * FROM t where u{} in (6)", $bits),
+                    Some(0),
+                    Some(15),
+                    1,
+                )
+                .await;
+            }
+
+            #[tokio::test]
+            async fn [<prune_uint $bits _not_eq_in>]() {
+                // result of sql "SELECT * FROM t where u{} not in (6)" prunes nothing
+                test_prune(
+                    Scenario::UInt,
+                    &format!("SELECT * FROM t where u{} not in (6)", $bits),
+                    Some(0),
+                    Some(0),
+                    19,
+                )
+                .await;
+            }
+        }
+    }
+}
+
+uint_tests!(8);
+uint_tests!(16);
+uint_tests!(32);
+uint_tests!(64);
+
 #[tokio::test]
 //                       null count  min   max
 // page-0                         0  -5.0  -1.0
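An aside on the `[<...>]` groups used in the macro above: inside `paste::item!` they concatenate the surrounding tokens into a single identifier, which is how one `uint_tests!(N)` invocation per bit width stamps out a distinct set of test functions. A minimal sketch of just that mechanism (illustration only; assumes the `paste` crate, which the test file above already uses):

    // Minimal illustration of paste's identifier concatenation.
    macro_rules! make_checks {
        ($bits:expr) => {
            paste::item! {
                // [<check_ $bits>] expands to check_8, check_16, ... per invocation.
                fn [<check_ $bits>]() -> u32 {
                    $bits
                }
            }
        };
    }

    make_checks!(8);
    make_checks!(16);

    fn main() {
        assert_eq!(check_8(), 8);
        assert_eq!(check_16(), 16);
    }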
From cd7a00b08309f7229073e4bba686d6271726ab1c Mon Sep 17 00:00:00 2001
From: wiedld
Date: Sun, 31 Mar 2024 05:09:06 -0700
Subject: [PATCH 117/117] fix(9870): common expression elimination optimization, should always re-find the correct expression during re-write. (#9871)

* test(9870): reproducer of error with jumping traversal patterns in common-expr-elimination traversals

* refactor: remove the IdArray ordered idx, since the idx ordering does not always stay in sync with the updated TreeNode traversal

* refactor: use the only reproducible key (expr_identifier) for expr_set, while keeping the (stack-popped) symbol used for alias.

* refactor: encapsulate most of the logic within ExprSet, and delineate the expr_identifier from the alias symbol

* test(9870): demonstrate that the sqllogictests are now passing
---
 datafusion/expr/src/logical_plan/plan.rs      |   2 +-
 .../optimizer/src/common_subexpr_eliminate.rs | 441 ++++++------------
 datafusion/sqllogictest/test_files/expr.slt   |  63 +++
 3 files changed, 214 insertions(+), 292 deletions(-)

diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 9f4094d483c9..0bf5b8dffaa2 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -2389,7 +2389,7 @@ impl DistinctOn {
 
 /// Aggregates its input based on a set of grouping and aggregate
 /// expressions (e.g. SUM).
-#[derive(Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
 // mark non_exhaustive to encourage use of try_new/new()
 #[non_exhaustive]
 pub struct Aggregate {
diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 0c9064d0641f..25c25c63f0b7 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -17,6 +17,7 @@
 
 //! Eliminate common sub-expression.
 
+use std::collections::hash_map::Entry;
 use std::collections::{BTreeSet, HashMap};
 use std::sync::Arc;
 
@@ -35,37 +36,75 @@ use datafusion_expr::expr::Alias;
 use datafusion_expr::logical_plan::{Aggregate, LogicalPlan, Projection, Window};
 use datafusion_expr::{col, Expr, ExprSchemable};
 
-/// A map from expression's identifier to tuple including
-/// - the expression itself (cloned)
-/// - counter
-/// - DataType of this expression.
-type ExprSet = HashMap<Identifier, (Expr, usize, DataType)>;
+/// Set of expressions generated by the [`ExprIdentifierVisitor`]
+/// and consumed by the [`CommonSubexprRewriter`].
+#[derive(Default)]
+struct ExprSet {
+    /// A map from expression's identifier (stringified expr) to tuple including:
+    /// - the expression itself (cloned)
+    /// - counter
+    /// - DataType of this expression.
+    /// - symbol used as the identifier in the alias.
+    map: HashMap<Identifier, (Expr, usize, DataType, Identifier)>,
+}
 
-/// An ordered map of Identifiers assigned by `ExprIdentifierVisitor` in an
-/// initial expression walk.
-///
-/// Used by `CommonSubexprRewriter`, which rewrites the expressions to remove
-/// common subexpressions.
-///
-/// Elements in this array are created on the walk down the expression tree
-/// during `f_down`. Thus element 0 is the root of the expression tree. The
-/// tuple contains:
-/// - series_number.
-///   - Incremented during `f_up`, start from 1.
-///   - Thus, items with higher idx have the lower series_number.
-/// - [`Identifier`]
-///   - Identifier of the expression. If empty (`""`), expr should not be considered for common elimination.
-///
-/// # Example
-/// An expression like `(a + b)` would have the following `IdArray`:
-/// ```text
-/// [
-///   (3, "a + b"),
-///   (2, "a"),
-///   (1, "b")
-/// ]
-/// ```
-type IdArray = Vec<(usize, Identifier)>;
+impl ExprSet {
+    fn expr_identifier(expr: &Expr) -> Identifier {
+        format!("{expr}")
+    }
+
+    fn get(&self, key: &Identifier) -> Option<&(Expr, usize, DataType, Identifier)> {
+        self.map.get(key)
+    }
+
+    fn entry(
+        &mut self,
+        key: Identifier,
+    ) -> Entry<'_, Identifier, (Expr, usize, DataType, Identifier)> {
+        self.map.entry(key)
+    }
+
+    fn populate_expr_set(
+        &mut self,
+        expr: &[Expr],
+        input_schema: DFSchemaRef,
+        expr_mask: ExprMask,
+    ) -> Result<()> {
+        expr.iter().try_for_each(|e| {
+            self.expr_to_identifier(e, Arc::clone(&input_schema), expr_mask)?;
+
+            Ok(())
+        })
+    }
+
+    /// Go through an expression tree and generate identifier for every node in this tree.
+    fn expr_to_identifier(
+        &mut self,
+        expr: &Expr,
+        input_schema: DFSchemaRef,
+        expr_mask: ExprMask,
+    ) -> Result<()> {
+        expr.visit(&mut ExprIdentifierVisitor {
+            expr_set: self,
+            input_schema,
+            visit_stack: vec![],
+            node_count: 0,
+            expr_mask,
+        })?;
+
+        Ok(())
+    }
+}
+
+impl From<Vec<(Identifier, (Expr, usize, DataType, Identifier))>> for ExprSet {
+    fn from(entries: Vec<(Identifier, (Expr, usize, DataType, Identifier))>) -> Self {
+        let mut expr_set = Self::default();
+        entries.into_iter().for_each(|(k, v)| {
+            expr_set.map.insert(k, v);
+        });
+        expr_set
+    }
+}
 
 /// Identifier for each subexpression.
 ///
@@ -112,21 +151,16 @@ impl CommonSubexprEliminate {
     fn rewrite_exprs_list(
         &self,
        exprs_list: &[&[Expr]],
-        arrays_list: &[&[Vec<(usize, String)>]],
         expr_set: &ExprSet,
         affected_id: &mut BTreeSet<Identifier>,
     ) -> Result<Vec<Vec<Expr>>> {
         exprs_list
             .iter()
-            .zip(arrays_list.iter())
-            .map(|(exprs, arrays)| {
+            .map(|exprs| {
                 exprs
                     .iter()
                     .cloned()
-                    .zip(arrays.iter())
-                    .map(|(expr, id_array)| {
-                        replace_common_expr(expr, id_array, expr_set, affected_id)
-                    })
+                    .map(|expr| replace_common_expr(expr, expr_set, affected_id))
                     .collect::<Result<Vec<_>>>()
             })
             .collect::<Result<Vec<_>>>()
@@ -135,7 +169,6 @@ impl CommonSubexprEliminate {
     fn rewrite_expr(
         &self,
         exprs_list: &[&[Expr]],
-        arrays_list: &[&[Vec<(usize, String)>]],
         input: &LogicalPlan,
         expr_set: &ExprSet,
         config: &dyn OptimizerConfig,
     ) -> Result<(Vec<Vec<Expr>>, LogicalPlan)> {
         let mut affected_id = BTreeSet::<Identifier>::new();
 
         let rewrite_exprs =
-            self.rewrite_exprs_list(exprs_list, arrays_list, expr_set, &mut affected_id)?;
+            self.rewrite_exprs_list(exprs_list, expr_set, &mut affected_id)?;
 
         let mut new_input = self
             .try_optimize(input, config)?
@@ -161,8 +194,7 @@ impl CommonSubexprEliminate {
         config: &dyn OptimizerConfig,
     ) -> Result<LogicalPlan> {
         let mut window_exprs = vec![];
-        let mut arrays_per_window = vec![];
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
         // Get all window expressions inside the consecutive window operators.
         // Consecutive window expressions may refer to same complex expression.
@@ -181,30 +213,18 @@ impl CommonSubexprEliminate {
             plan = input.as_ref().clone();
 
             let input_schema = Arc::clone(input.schema());
-            let arrays =
-                to_arrays(&window_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+            expr_set.populate_expr_set(&window_expr, input_schema, ExprMask::Normal)?;
 
             window_exprs.push(window_expr);
-            arrays_per_window.push(arrays);
         }
 
         let mut window_exprs = window_exprs
             .iter()
             .map(|expr| expr.as_slice())
             .collect::<Vec<_>>();
-        let arrays_per_window = arrays_per_window
-            .iter()
-            .map(|arrays| arrays.as_slice())
-            .collect::<Vec<_>>();
 
-        assert_eq!(window_exprs.len(), arrays_per_window.len());
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &window_exprs,
-            &arrays_per_window,
-            &plan,
-            &expr_set,
-            config,
-        )?;
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&window_exprs, &plan, &expr_set, config)?;
         assert_eq!(window_exprs.len(), new_expr.len());
 
         // Construct consecutive window operator, with their corresponding new window expressions.
@@ -241,46 +261,36 @@ impl CommonSubexprEliminate {
             input,
             ..
         } = aggregate;
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
-        // rewrite inputs
+        // build expr_set, with groupby and aggr
         let input_schema = Arc::clone(input.schema());
-        let group_arrays = to_arrays(
+        expr_set.populate_expr_set(
             group_expr,
             Arc::clone(&input_schema),
-            &mut expr_set,
             ExprMask::Normal,
         )?;
-        let aggr_arrays =
-            to_arrays(aggr_expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+        expr_set.populate_expr_set(aggr_expr, input_schema, ExprMask::Normal)?;
 
-        let (mut new_expr, new_input) = self.rewrite_expr(
-            &[group_expr, aggr_expr],
-            &[&group_arrays, &aggr_arrays],
-            input,
-            &expr_set,
-            config,
-        )?;
+        // rewrite inputs
+        let (mut new_expr, new_input) =
+            self.rewrite_expr(&[group_expr, aggr_expr], input, &expr_set, config)?;
         // note the reversed pop order.
         let new_aggr_expr = pop_expr(&mut new_expr)?;
         let new_group_expr = pop_expr(&mut new_expr)?;
 
         // create potential projection on top
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
         let new_input_schema = Arc::clone(new_input.schema());
-        let aggr_arrays = to_arrays(
+        expr_set.populate_expr_set(
             &new_aggr_expr,
             new_input_schema.clone(),
-            &mut expr_set,
             ExprMask::NormalAndAggregates,
         )?;
+
         let mut affected_id = BTreeSet::<Identifier>::new();
-        let mut rewritten = self.rewrite_exprs_list(
-            &[&new_aggr_expr],
-            &[&aggr_arrays],
-            &expr_set,
-            &mut affected_id,
-        )?;
+        let mut rewritten =
+            self.rewrite_exprs_list(&[&new_aggr_expr], &expr_set, &mut affected_id)?;
         let rewritten = pop_expr(&mut rewritten)?;
 
         if affected_id.is_empty() {
@@ -300,9 +310,9 @@ impl CommonSubexprEliminate {
             for id in affected_id {
                 match expr_set.get(&id) {
-                    Some((expr, _, _)) => {
+                    Some((expr, _, _, symbol)) => {
                         // todo: check `nullable`
-                        agg_exprs.push(expr.clone().alias(&id));
+                        agg_exprs.push(expr.clone().alias(symbol.as_str()));
                     }
                     _ => {
                         return internal_err!("expr_set invalid state");
@@ -320,9 +330,7 @@ impl CommonSubexprEliminate {
                     agg_exprs.push(expr.alias(&name));
                     proj_exprs.push(Expr::Column(Column::from_name(name)));
                 } else {
-                    let id = ExprIdentifierVisitor::<'static>::expr_identifier(
-                        &expr_rewritten,
-                    );
+                    let id = ExprSet::expr_identifier(&expr_rewritten);
                     let out_name =
                         expr_rewritten.to_field(&new_input_schema)?.qualified_name();
                     agg_exprs.push(expr_rewritten.alias(&id));
@@ -356,13 +364,13 @@ impl CommonSubexprEliminate {
         let inputs = plan.inputs();
         let input = inputs[0];
         let input_schema = Arc::clone(input.schema());
-        let mut expr_set = ExprSet::new();
+        let mut expr_set = ExprSet::default();
 
         // Visit expr list and build expr identifier to occurring count map (`expr_set`).
-        let arrays = to_arrays(&expr, input_schema, &mut expr_set, ExprMask::Normal)?;
+        expr_set.populate_expr_set(&expr, input_schema, ExprMask::Normal)?;
 
         let (mut new_expr, new_input) =
-            self.rewrite_expr(&[&expr], &[&arrays], input, &expr_set, config)?;
+            self.rewrite_expr(&[&expr], input, &expr_set, config)?;
 
         plan.with_new_exprs(pop_expr(&mut new_expr)?, vec![new_input])
     }
@@ -448,28 +456,6 @@ fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
         .ok_or_else(|| DataFusionError::Internal("Failed to pop expression".to_string()))
 }
 
-fn to_arrays(
-    expr: &[Expr],
-    input_schema: DFSchemaRef,
-    expr_set: &mut ExprSet,
-    expr_mask: ExprMask,
-) -> Result<Vec<Vec<(usize, Identifier)>>> {
-    expr.iter()
-        .map(|e| {
-            let mut id_array = vec![];
-            expr_to_identifier(
-                e,
-                expr_set,
-                &mut id_array,
-                Arc::clone(&input_schema),
-                expr_mask,
-            )?;
-
-            Ok(id_array)
-        })
-        .collect::<Result<Vec<_>>>()
-}
-
 /// Build the "intermediate" projection plan that evaluates the extracted common expressions.
 fn build_common_expr_project_plan(
     input: LogicalPlan,
@@ -481,11 +467,11 @@ fn build_common_expr_project_plan(
     for id in affected_id {
         match expr_set.get(&id) {
-            Some((expr, _, data_type)) => {
+            Some((expr, _, data_type, symbol)) => {
                 // todo: check `nullable`
                 let field = DFField::new_unqualified(&id, data_type.clone(), true);
                 fields_set.insert(field.name().to_owned());
-                project_exprs.push(expr.clone().alias(&id));
+                project_exprs.push(expr.clone().alias(symbol.as_str()));
             }
             _ => {
                 return internal_err!("expr_set invalid state");
@@ -601,8 +587,6 @@ struct ExprIdentifierVisitor<'a> {
     // param
     expr_set: &'a mut ExprSet,
-    /// series number (usize) and identifier.
-    id_array: &'a mut IdArray,
     /// input schema for the node that we're optimizing, so we can determine the correct datatype
     /// for each subexpression
     input_schema: DFSchemaRef,
@@ -610,8 +594,6 @@ struct ExprIdentifierVisitor<'a> {
     visit_stack: Vec<VisitRecord>,
     /// increased in fn_down, start from 0.
     node_count: usize,
-    /// increased in fn_up, start from 1.
-    series_number: usize,
     /// which expression should be skipped?
     expr_mask: ExprMask,
 }
@@ -628,10 +610,6 @@ enum VisitRecord {
 }
 
 impl ExprIdentifierVisitor<'_> {
-    fn expr_identifier(expr: &Expr) -> Identifier {
-        format!("{expr}")
-    }
-
     /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
     /// before it.
     fn pop_enter_mark(&mut self) -> (usize, Identifier) {
@@ -655,9 +633,6 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
     type Node = Expr;
 
     fn f_down(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
-        // put placeholder, sets the proper array length
-        self.id_array.push((0, "".to_string()));
-
         // related to https://github.com/apache/arrow-datafusion/issues/8814
        // If the expr contains a volatile expression or is a short-circuit expression, skip it.
         if expr.short_circuits() || is_volatile_expression(expr)? {
@@ -674,70 +649,38 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> {
     }
 
     fn f_up(&mut self, expr: &Expr) -> Result<TreeNodeRecursion> {
-        self.series_number += 1;
-
-        let (idx, sub_expr_identifier) = self.pop_enter_mark();
+        let (_idx, sub_expr_identifier) = self.pop_enter_mark();
 
         // skip exprs that should not be recognized.
         if self.expr_mask.ignores(expr) {
-            let curr_expr_identifier = Self::expr_identifier(expr);
+            let curr_expr_identifier = ExprSet::expr_identifier(expr);
             self.visit_stack
                 .push(VisitRecord::ExprItem(curr_expr_identifier));
-            self.id_array[idx].0 = self.series_number;
             // leave Identifier as empty "", since will not use as common expr
             return Ok(TreeNodeRecursion::Continue);
         }
 
-        let mut desc = Self::expr_identifier(expr);
-        desc.push_str(&sub_expr_identifier);
+        let curr_expr_identifier = ExprSet::expr_identifier(expr);
+        let alias_symbol = format!("{curr_expr_identifier}{sub_expr_identifier}");
 
-        self.id_array[idx] = (self.series_number, desc.clone());
-        self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
+        self.visit_stack
+            .push(VisitRecord::ExprItem(alias_symbol.clone()));
 
         let data_type = expr.get_type(&self.input_schema)?;
 
         self.expr_set
-            .entry(desc)
-            .or_insert_with(|| (expr.clone(), 0, data_type))
+            .entry(curr_expr_identifier)
+            .or_insert_with(|| (expr.clone(), 0, data_type, alias_symbol))
             .1 += 1;
         Ok(TreeNodeRecursion::Continue)
     }
 }
 
-/// Go through an expression tree and generate identifier for every node in this tree.
-fn expr_to_identifier(
-    expr: &Expr,
-    expr_set: &mut ExprSet,
-    id_array: &mut Vec<(usize, Identifier)>,
-    input_schema: DFSchemaRef,
-    expr_mask: ExprMask,
-) -> Result<()> {
-    expr.visit(&mut ExprIdentifierVisitor {
-        expr_set,
-        id_array,
-        input_schema,
-        visit_stack: vec![],
-        node_count: 0,
-        series_number: 0,
-        expr_mask,
-    })?;
-
-    Ok(())
-}
-
 /// Rewrite expression by replacing detected common sub-expression with
 /// the corresponding temporary column name. That column contains the
 /// evaluate result of replaced expression.
 struct CommonSubexprRewriter<'a> {
     expr_set: &'a ExprSet,
-    id_array: &'a IdArray,
     /// Which identifier is replaced.
     affected_id: &'a mut BTreeSet<Identifier>,
-
-    /// the max series number we have rewritten. Other expression nodes
-    /// with smaller series number is already replaced and shouldn't
-    /// do anything with them.
-    max_series_number: usize,
-    /// current node's information's index in `id_array`.
-    curr_index: usize,
 }
 
 impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
@@ -751,80 +694,41 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
             return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
         }
 
-        let (series_number, curr_id) = &self.id_array[self.curr_index];
-
-        // halting conditions
-        if self.curr_index >= self.id_array.len()
-            || self.max_series_number > *series_number
-        {
-            return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
-        }
-
-        // skip `Expr`s without identifier (empty identifier).
-        if curr_id.is_empty() {
-            self.curr_index += 1; // incr idx for id_array, when not jumping
-            return Ok(Transformed::no(expr));
-        }
+        let curr_id = &ExprSet::expr_identifier(&expr);
 
         // lookup previously visited expression
         match self.expr_set.get(curr_id) {
-            Some((_, counter, _)) => {
+            Some((_, counter, _, symbol)) => {
                 // if has a commonly used (a.k.a. 1+ use) expr
                 if *counter > 1 {
                     self.affected_id.insert(curr_id.clone());
 
-                    // This expr tree is finished.
-                    if self.curr_index >= self.id_array.len() {
-                        return Ok(Transformed::new(
-                            expr,
-                            false,
-                            TreeNodeRecursion::Jump,
-                        ));
-                    }
-
-                    // incr idx for id_array, when not jumping
-                    self.curr_index += 1;
-
-                    // series_number was the inverse number ordering (when doing f_up)
-                    self.max_series_number = *series_number;
-                    // step index to skip all sub-node (which has smaller series number).
-                    while self.curr_index < self.id_array.len()
-                        && *series_number > self.id_array[self.curr_index].0
-                    {
-                        self.curr_index += 1;
-                    }
-
                     let expr_name = expr.display_name()?;
                     // Alias this `Column` expr to its original "expr name",
                     // `projection_push_down` optimizer uses "expr name" to eliminate useless
                     // projections.
                     Ok(Transformed::new(
-                        col(curr_id).alias(expr_name),
+                        col(symbol).alias(expr_name),
                         true,
                         TreeNodeRecursion::Jump,
                     ))
                 } else {
-                    self.curr_index += 1;
                     Ok(Transformed::no(expr))
                 }
             }
-            _ => internal_err!("expr_set invalid state"),
+            None => Ok(Transformed::no(expr)),
         }
     }
 }
 
 fn replace_common_expr(
     expr: Expr,
-    id_array: &IdArray,
     expr_set: &ExprSet,
     affected_id: &mut BTreeSet<Identifier>,
 ) -> Result<Expr> {
     expr.rewrite(&mut CommonSubexprRewriter {
         expr_set,
-        id_array,
         affected_id,
-        max_series_number: 0,
-        curr_index: 0,
     })
     .data()
 }
@@ -860,73 +764,6 @@ mod test {
         assert_eq!(expected, formatted_plan);
     }
 
-    #[test]
-    fn id_array_visitor() -> Result<()> {
-        let expr = ((sum(col("a") + lit(1))) - avg(col("c"))) * lit(2);
-
-        let schema = Arc::new(DFSchema::new_with_metadata(
-            vec![
-                DFField::new_unqualified("a", DataType::Int64, false),
-                DFField::new_unqualified("c", DataType::Int64, false),
-            ],
-            Default::default(),
-        )?);
-
-        // skip aggregates
-        let mut id_array = vec![];
-        expr_to_identifier(
-            &expr,
-            &mut HashMap::new(),
-            &mut id_array,
-            Arc::clone(&schema),
-            ExprMask::Normal,
-        )?;
-
-        let expected = vec![
-            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
-            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)SUM(a + Int32(1))"),
-            (4, ""),
-            (3, "a + Int32(1)Int32(1)a"),
-            (1, ""),
-            (2, ""),
-            (6, ""),
-            (5, ""),
-            (8, "")
-        ]
-        .into_iter()
-        .map(|(number, id)| (number, id.into()))
-        .collect::<Vec<_>>();
-        assert_eq!(expected, id_array);
-
-        // include aggregates
-        let mut id_array = vec![];
-        expr_to_identifier(
-            &expr,
-            &mut HashMap::new(),
-            &mut id_array,
-            Arc::clone(&schema),
-            ExprMask::NormalAndAggregates,
-        )?;
-
-        let expected = vec![
-            (9, "(SUM(a + Int32(1)) - AVG(c)) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
-            (3, "a + Int32(1)Int32(1)a"),
-            (1, ""),
-            (2, ""),
-            (6, "AVG(c)c"),
-            (5, ""),
-            (8, "")
-        ]
-        .into_iter()
-        .map(|(number, id)| (number, id.into()))
-        .collect::<Vec<_>>();
-        assert_eq!(expected, id_array);
-
-        Ok(())
-    }
-
     #[test]
     fn tpch_q1_simplified() -> Result<()> {
         // SQL:
@@ -1171,24 +1008,28 @@ mod test {
         let table_scan = test_table_scan().unwrap();
 
         let affected_id: BTreeSet<Identifier> =
             ["c+a".to_string(), "b+a".to_string()].into_iter().collect();
-        let expr_set_1 = [
+        let expr_set_1 = vec![
             (
                 "c+a".to_string(),
-                (col("c") + col("a"), 1, DataType::UInt32),
+                (col("c") + col("a"), 1, DataType::UInt32, "c+a".to_string()),
             ),
             (
                 "b+a".to_string(),
-                (col("b") + col("a"), 1, DataType::UInt32),
+                (col("b") + col("a"), 1, DataType::UInt32, "b+a".to_string()),
             ),
         ]
-        .into_iter()
-        .collect();
-        let expr_set_2 = [
-            ("c+a".to_string(), (col("c+a"), 1, DataType::UInt32)),
-            ("b+a".to_string(), (col("b+a"), 1, DataType::UInt32)),
+        .into();
+        let expr_set_2 = vec![
+            (
+                "c+a".to_string(),
+                (col("c+a"), 1, DataType::UInt32, "c+a".to_string()),
+            ),
+            (
+                "b+a".to_string(),
+                (col("b+a"), 1, DataType::UInt32, "b+a".to_string()),
+            ),
         ]
-        .into_iter()
-        .collect();
+        .into();
         let project =
             build_common_expr_project_plan(table_scan, affected_id.clone(), &expr_set_1)
                 .unwrap();
@@ -1214,30 +1055,48 @@ mod test {
         let affected_id: BTreeSet<Identifier> =
             ["test1.c+test1.a".to_string(), "test1.b+test1.a".to_string()]
                 .into_iter()
                 .collect();
-        let expr_set_1 = [
+        let expr_set_1 = vec![
             (
                 "test1.c+test1.a".to_string(),
-                (col("test1.c") + col("test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.c") + col("test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.c+test1.a".to_string(),
+                ),
             ),
             (
                 "test1.b+test1.a".to_string(),
-                (col("test1.b") + col("test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.b") + col("test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.b+test1.a".to_string(),
+                ),
             ),
         ]
-        .into_iter()
-        .collect();
+        .into();
-        let expr_set_2 = [
+        let expr_set_2 = vec![
             (
                 "test1.c+test1.a".to_string(),
-                (col("test1.c+test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.c+test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.c+test1.a".to_string(),
+                ),
             ),
             (
                 "test1.b+test1.a".to_string(),
-                (col("test1.b+test1.a"), 1, DataType::UInt32),
+                (
+                    col("test1.b+test1.a"),
+                    1,
+                    DataType::UInt32,
+                    "test1.b+test1.a".to_string(),
+                ),
             ),
         ]
-        .into_iter()
-        .collect();
+        .into();
         let project =
             build_common_expr_project_plan(join, affected_id.clone(), &expr_set_1)
                 .unwrap();
diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt
index 75bcbc07755b..2e0cbf50cab9 100644
--- a/datafusion/sqllogictest/test_files/expr.slt
+++ b/datafusion/sqllogictest/test_files/expr.slt
@@ -2262,3 +2262,66 @@ query RRR rowsort
 select f64, case when f64 > 0 then 1.0 / f64 else null end, acos(case when f64 > 0 then 1.0 / f64 else null end) from doubles;
 ----
 10.1 0.09900990099 1.471623942989
+
+
+statement ok
+CREATE TABLE t1(
+    time TIMESTAMP,
+    load1 DOUBLE,
+    load2 DOUBLE,
+    host VARCHAR
+) AS VALUES
+  (to_timestamp_nanos(1527018806000000000), 1.1, 101, 'host1'),
+  (to_timestamp_nanos(1527018806000000000), 2.2, 202, 'host2'),
+  (to_timestamp_nanos(1527018806000000000), 3.3, 303, 'host3'),
+  (to_timestamp_nanos(1527018806000000000), 1.1, 101, NULL)
+;
+
+# struct scalar function with columns
+query ?
+select struct(time,load1,load2,host) from t1;
+----
+{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: host1}
+{c0: 2018-05-22T19:53:26, c1: 2.2, c2: 202.0, c3: host2}
+{c0: 2018-05-22T19:53:26, c1: 3.3, c2: 303.0, c3: host3}
+{c0: 2018-05-22T19:53:26, c1: 1.1, c2: 101.0, c3: }
+
+# can have an aggregate function with an inner coalesce
+query TR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1
+host2 2.2
+host3 3.3
+
+# can have an aggregate function with an inner CASE WHEN
+query TR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 101
+host2 202
+host3 303
+
+# can have 2 projections with aggr(short_circuited), with different short-circuited expr
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
+
+# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. CASE WHEN)
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
+
+# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. coalesce)
+query TRR
+select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;
+----
+host1 1.1 101
+host2 2.2 202
+host3 3.3 303
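For context on the fix in this last patch: the rewriter now re-derives each expression's key from the expression itself instead of walking a positional side array, so a lookup succeeds no matter how the tree traversal jumps. A standalone sketch of that keying idea with a toy expression type (illustration only; the real code keys on the `Display` form of `datafusion_expr::Expr`):

    use std::collections::HashMap;

    #[derive(Clone)]
    enum Expr {
        Col(String),
        Add(Box<Expr>, Box<Expr>),
    }

    // The reproducible key: derived from the expression alone,
    // never from traversal order.
    fn identifier(e: &Expr) -> String {
        match e {
            Expr::Col(c) => c.clone(),
            Expr::Add(l, r) => format!("({} + {})", identifier(l), identifier(r)),
        }
    }

    // Count every subexpression under its identifier.
    fn count(e: &Expr, set: &mut HashMap<String, usize>) {
        *set.entry(identifier(e)).or_insert(0) += 1;
        if let Expr::Add(l, r) = e {
            count(l, set);
            count(r, set);
        }
    }

    fn main() {
        // (a + b) appears twice inside ((a + b) + (a + b)).
        let ab = Expr::Add(
            Box::new(Expr::Col("a".into())),
            Box::new(Expr::Col("b".into())),
        );
        let root = Expr::Add(Box::new(ab.clone()), Box::new(ab.clone()));

        let mut set = HashMap::new();
        count(&root, &mut set);

        // A later rewrite pass can re-derive the key from any occurrence.
        assert_eq!(set[&identifier(&ab)], 2);
    }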