From cdf527fa3f0ef53fbdcc3963df6e66e4dd6799ec Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 9 Jun 2023 16:00:35 +0800 Subject: [PATCH 1/7] Internal cast for array() Signed-off-by: jayzhan211 --- Cargo.toml | 1 + datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/type_coercion/mod.rs | 2 + datafusion/physical-expr/Cargo.toml | 1 + .../physical-expr/src/array_expressions.rs | 58 ++++++++++++++++++- 5 files changed, 62 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 6b24a44e9ad7..9f494ab472c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ arrow-flight = { version = "41.0.0", features = ["flight-sql-experimental"] } arrow-buffer = { version = "41.0.0", default-features = false } arrow-schema = { version = "41.0.0", default-features = false } arrow-array = { version = "41.0.0", default-features = false, features = ["chrono-tz"] } +arrow-cast = { version = "41.0.0", default-features = false } parquet = { version = "41.0.0", features = ["arrow", "async", "object_store"] } [profile.release] diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 1675afb9c98a..75f5986621fb 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -71,6 +71,7 @@ pub use nullif::SUPPORTED_NULLIF_TYPES; pub use operator::Operator; pub use signature::{Signature, TypeSignature, Volatility}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; +pub use type_coercion::comparison_coercion; pub use udaf::AggregateUDF; pub use udf::ScalarUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 0881bce98d6a..08a6de73606a 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -36,6 +36,8 @@ pub mod binary; pub mod functions; pub mod other; +pub use binary::comparison_coercion; + use arrow::datatypes::DataType; /// Determine whether the given data type `dt` represents signed numeric values. pub fn is_signed_numeric(dt: &DataType) -> bool { diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index b851c00edc2b..919a345ac27f 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -46,6 +46,7 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } +arrow-cast = { workspace = true } arrow-schema = { workspace = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 44b747082a5c..6238365c21f8 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -21,9 +21,11 @@ use arrow::array::*; use arrow::buffer::Buffer; use arrow::compute; use arrow::datatypes::{DataType, Field}; +use arrow_cast::cast; use core::any::type_name; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::comparison_coercion; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -85,7 +87,18 @@ fn array_array(args: &[ArrayRef]) -> Result { )); } - let data_type = args[0].data_type(); + let data_type = args + .iter() + .skip(1) + .fold(args[0].data_type().clone(), |acc, x| { + comparison_coercion(&acc, x.data_type()).unwrap_or(acc) + }); + + let args: &[ArrayRef] = &args + .iter() + .map(|item| cast(item, &data_type).unwrap()) + .collect::>(); + let res = match data_type { DataType::List(..) => { let arrays = @@ -1156,6 +1169,49 @@ mod tests { .values() ) } + #[test] + fn test_array_with_different_types_1() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), + ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), + ]; + let array = array(&args) + .expect("failed to initialize function array") + .into_array(1); + let result = as_list_array(&array).expect("failed to initialize function array"); + assert_eq!(result.len(), 1); + assert_eq!( + &[1, 1], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ); + } + + #[test] + fn test_array_with_different_types_2() { + let args = [ + ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), + ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), + ]; + let array = array(&args) + .expect("failed to initialize function array") + .into_array(1); + let result = as_list_array(&array).expect("failed to initialize function array"); + assert_eq!(result.len(), 1); + assert_eq!( + &[1.0, 1.0], + result + .value(0) + .as_any() + .downcast_ref::() + .unwrap() + .values() + ) + } #[test] fn test_nested_array() { From 89f58e784597bed339782bf8eaac7c691c56da42 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Fri, 9 Jun 2023 16:24:44 +0800 Subject: [PATCH 2/7] Add sqllogictest Signed-off-by: jayzhan211 --- datafusion/core/tests/sqllogictests/test_files/array.slt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 183522138044..78697a0f4694 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -300,3 +300,8 @@ query II rowsort select array_ndims(make_array()), array_ndims(make_array(make_array())) ---- 1 2 + +query ? +select make_array(1, 2.0) +---- +[1.0, 2.0] From b5e5d622a7d0faa47cb03cb5dd7e98f3f23b8007 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 12 Jun 2023 20:40:35 +0800 Subject: [PATCH 3/7] address CI fail Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 90 +++++++++++++++------------------------ 1 file changed, 34 insertions(+), 56 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 71b18f71a5fb..f3cae2330466 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -657,9 +657,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.3.3" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ae2468a89544a466886840aa467a25b766499f4f04bf7d9fcd10ecee9fccef" +checksum = "729b71f35bd3fa1a4c86b85d32c8b9069ea7fe14f7a53cfabb65f62d4265b888" dependencies = [ "arrayref", "arrayvec", @@ -883,9 +883,9 @@ dependencies = [ [[package]] name = "constant_time_eq" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13418e745008f7349ec7e449155f419a61b92b58a99cc3616942b926825ec76b" +checksum = "21a53c0a4d288377e7415b53dcfc3c04da5cdc2cc95c8d5ac178b58f0b861ad6" [[package]] name = "core-foundation" @@ -1104,6 +1104,7 @@ dependencies = [ "arrow", "arrow-array", "arrow-buffer", + "arrow-cast", "arrow-schema", "blake2", "blake3", @@ -1434,9 +1435,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", "libc", @@ -1634,14 +1635,14 @@ dependencies = [ "hyper", "rustls 0.21.1", "tokio", - "tokio-rustls 0.24.0", + "tokio-rustls 0.24.1", ] [[package]] name = "iana-time-zone" -version = "0.1.56" +version = "0.1.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0722cd7114b7de04316e7ea5456a0bbb20e4adb46fd27a3697adb812cff0f37c" +checksum = "2fad5b825842d2b38bd206f3e81d6957625fd7f0a361e345c30e01a0ae2dd613" dependencies = [ "android_system_properties", "core-foundation-sys", @@ -1855,9 +1856,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.18" +version = "0.4.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "518ef76f2f87365916b142844c16d8fefd85039bc5699050210a7778ee1cd1de" +checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" [[package]] name = "lz4" @@ -2101,9 +2102,9 @@ dependencies = [ [[package]] name = "os_str_bytes" -version = "6.5.0" +version = "6.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267" +checksum = "4d5d9eb14b174ee9aa2ef96dc2b94637a2d4b6e7cb873c7e171f0c20c6cf3eac" [[package]] name = "outref" @@ -2131,7 +2132,7 @@ dependencies = [ "libc", "redox_syscall 0.3.5", "smallvec", - "windows-targets 0.48.0", + "windows-targets", ] [[package]] @@ -2313,9 +2314,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.59" +version = "1.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6aeca18b86b413c660b781aa319e4e2648a3e6f9eadc9b47e9038e6fe9f3451b" +checksum = "dec2b086b7a862cf4de201096214fa870344cf922b2b30c167badb3af3195406" dependencies = [ "unicode-ident", ] @@ -2454,7 +2455,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "tokio", - "tokio-rustls 0.24.0", + "tokio-rustls 0.24.1", "tokio-util", "tower-service", "url", @@ -2492,9 +2493,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.37.19" +version = "0.37.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acf8729d8542766f1b2cf77eb034d52f40d375bb8b615d0b147089946e16613d" +checksum = "b96e891d04aa506a6d1f318d2771bcb1c7dfda84e126660ace067c9b474bb2c0" dependencies = [ "bitflags", "errno", @@ -2665,18 +2666,18 @@ checksum = "e6b44e8fc93a14e66336d230954dda83d18b4605ccace8fe09bc7514a71ad0bc" [[package]] name = "serde" -version = "1.0.163" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2113ab51b87a539ae008b5c6c02dc020ffa39afd2d83cffcb3f4eb2722cebec2" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.163" +version = "1.0.164" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c805777e3930c8883389c602315a24224bcc738b63905ef87cd1420353ea93e" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" dependencies = [ "proc-macro2", "quote", @@ -2873,15 +2874,16 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.5.0" +version = "3.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" +checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" dependencies = [ + "autocfg", "cfg-if", "fastrand", "redox_syscall 0.3.5", "rustix", - "windows-sys 0.45.0", + "windows-sys 0.48.0", ] [[package]] @@ -2932,9 +2934,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.21" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f3403384eaacbca9923fa06940178ac13e4edb725486d70e8e15881d0c836cc" +checksum = "ea9e1b3cf1243ae005d9e74085d4d542f3125458f3a81af210d901dcd7411efd" dependencies = [ "serde", "time-core", @@ -3022,9 +3024,9 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.24.0" +version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0d409377ff5b1e3ca6437aa86c1eb7d40c134bfec254e44c830defa92669db5" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ "rustls 0.21.1", "tokio", @@ -3392,7 +3394,7 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e686886bc078bc1b0b600cac0147aadb815089b6e4da64016cbd754b6342700f" dependencies = [ - "windows-targets 0.48.0", + "windows-targets", ] [[package]] @@ -3410,37 +3412,13 @@ dependencies = [ "windows_x86_64_msvc 0.42.2", ] -[[package]] -name = "windows-sys" -version = "0.45.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" -dependencies = [ - "windows-targets 0.42.2", -] - [[package]] name = "windows-sys" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets 0.48.0", -] - -[[package]] -name = "windows-targets" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" -dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", + "windows-targets", ] [[package]] From c1bf80750cead4a04696ba13c3375cdf5d4cdb6c Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 13 Jun 2023 20:02:02 +0800 Subject: [PATCH 4/7] address comments Signed-off-by: jayzhan211 --- Cargo.toml | 1 - datafusion-cli/Cargo.lock | 1 - .../optimizer/src/analyzer/type_coercion.rs | 36 ++++++++++-- datafusion/physical-expr/Cargo.toml | 1 - .../physical-expr/src/array_expressions.rs | 58 +------------------ 5 files changed, 32 insertions(+), 65 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9f494ab472c3..6b24a44e9ad7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,6 @@ arrow-flight = { version = "41.0.0", features = ["flight-sql-experimental"] } arrow-buffer = { version = "41.0.0", default-features = false } arrow-schema = { version = "41.0.0", default-features = false } arrow-array = { version = "41.0.0", default-features = false, features = ["chrono-tz"] } -arrow-cast = { version = "41.0.0", default-features = false } parquet = { version = "41.0.0", features = ["arrow", "async", "object_store"] } [profile.release] diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index f3cae2330466..9f6296b5b08c 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1104,7 +1104,6 @@ dependencies = [ "arrow", "arrow-array", "arrow-buffer", - "arrow-cast", "arrow-schema", "blake2", "blake3", diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3ee6a2401b02..4e0d60898471 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -41,8 +41,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_numeric, is_utf8_or_large_u use datafusion_expr::utils::from_plan; use datafusion_expr::{ aggregate_function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, - is_unknown, type_coercion, AggregateFunction, Expr, LogicalPlan, Operator, - Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, + is_unknown, type_coercion, AggregateFunction, BuiltinScalarFunction, Expr, + LogicalPlan, Operator, Projection, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use datafusion_expr::{ExprSchemable, Signature}; @@ -387,13 +387,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter { Ok(expr) } Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let nex_expr = coerce_arguments_for_signature( + let mut new_expr = coerce_arguments_for_signature( args.as_slice(), &self.schema, &fun.signature(), )?; - let expr = Expr::ScalarFunction(ScalarFunction::new(fun, nex_expr)); - Ok(expr) + if fun == BuiltinScalarFunction::MakeArray { + new_expr = coerce_arguments(new_expr.as_slice())?; + } + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_expr))) } Expr::AggregateFunction(expr::AggregateFunction { fun, @@ -602,6 +604,30 @@ fn coerce_arguments_for_signature( .collect::>>() } +fn coerce_arguments(expressions: &[Expr]) -> Result> { + if expressions.is_empty() { + return Ok(vec![]); + } + + let current_types = expressions + .iter() + .map(|e| e.get_type(&DFSchema::empty())) + .collect::>>()?; + + let new_type = current_types + .iter() + .skip(1) + .fold(current_types.first().unwrap().clone(), |acc, x| { + comparison_coercion(&acc, x).unwrap_or(acc) + }); + + expressions + .iter() + .enumerate() + .map(|(_, expr)| cast_expr(expr, &new_type, &DFSchema::empty())) + .collect() +} + /// Cast `expr` to the specified type, if possible fn cast_expr(expr: &Expr, to_type: &DataType, schema: &DFSchema) -> Result { expr.clone().cast_to(to_type, schema) diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 919a345ac27f..b851c00edc2b 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -46,7 +46,6 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] arrow = { workspace = true } arrow-array = { workspace = true } arrow-buffer = { workspace = true } -arrow-cast = { workspace = true } arrow-schema = { workspace = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 6238365c21f8..44b747082a5c 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -21,11 +21,9 @@ use arrow::array::*; use arrow::buffer::Buffer; use arrow::compute; use arrow::datatypes::{DataType, Field}; -use arrow_cast::cast; use core::any::type_name; use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::comparison_coercion; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -87,18 +85,7 @@ fn array_array(args: &[ArrayRef]) -> Result { )); } - let data_type = args - .iter() - .skip(1) - .fold(args[0].data_type().clone(), |acc, x| { - comparison_coercion(&acc, x.data_type()).unwrap_or(acc) - }); - - let args: &[ArrayRef] = &args - .iter() - .map(|item| cast(item, &data_type).unwrap()) - .collect::>(); - + let data_type = args[0].data_type(); let res = match data_type { DataType::List(..) => { let arrays = @@ -1169,49 +1156,6 @@ mod tests { .values() ) } - #[test] - fn test_array_with_different_types_1() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ]; - let array = array(&args) - .expect("failed to initialize function array") - .into_array(1); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 1); - assert_eq!( - &[1, 1], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ); - } - - #[test] - fn test_array_with_different_types_2() { - let args = [ - ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))), - ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))), - ]; - let array = array(&args) - .expect("failed to initialize function array") - .into_array(1); - let result = as_list_array(&array).expect("failed to initialize function array"); - assert_eq!(result.len(), 1); - assert_eq!( - &[1.0, 1.0], - result - .value(0) - .as_any() - .downcast_ref::() - .unwrap() - .values() - ) - } #[test] fn test_nested_array() { From 050aafbf357a5d2f0daed50ec8ea85ff483124b3 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 13 Jun 2023 20:17:03 +0800 Subject: [PATCH 5/7] refactor Signed-off-by: jayzhan211 --- datafusion/expr/src/lib.rs | 1 - datafusion/expr/src/type_coercion/mod.rs | 2 - .../optimizer/src/analyzer/type_coercion.rs | 50 +++++++++++-------- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 75f5986621fb..1675afb9c98a 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -71,7 +71,6 @@ pub use nullif::SUPPORTED_NULLIF_TYPES; pub use operator::Operator; pub use signature::{Signature, TypeSignature, Volatility}; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use type_coercion::comparison_coercion; pub use udaf::AggregateUDF; pub use udf::ScalarUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/type_coercion/mod.rs b/datafusion/expr/src/type_coercion/mod.rs index 08a6de73606a..0881bce98d6a 100644 --- a/datafusion/expr/src/type_coercion/mod.rs +++ b/datafusion/expr/src/type_coercion/mod.rs @@ -36,8 +36,6 @@ pub mod binary; pub mod functions; pub mod other; -pub use binary::comparison_coercion; - use arrow::datatypes::DataType; /// Determine whether the given data type `dt` represents signed numeric values. pub fn is_signed_numeric(dt: &DataType) -> bool { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4e0d60898471..b187e1e59597 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -387,15 +387,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { Ok(expr) } Expr::ScalarFunction(ScalarFunction { fun, args }) => { - let mut new_expr = coerce_arguments_for_signature( + let new_args = coerce_arguments_for_signature( args.as_slice(), &self.schema, &fun.signature(), )?; - if fun == BuiltinScalarFunction::MakeArray { - new_expr = coerce_arguments(new_expr.as_slice())?; - } - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_expr))) + let new_args = coerce_arguments_for_fun(fun, new_args.as_slice())?; + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) } Expr::AggregateFunction(expr::AggregateFunction { fun, @@ -604,28 +602,36 @@ fn coerce_arguments_for_signature( .collect::>>() } -fn coerce_arguments(expressions: &[Expr]) -> Result> { +fn coerce_arguments_for_fun( + fun: BuiltinScalarFunction, + expressions: &[Expr], +) -> Result> { if expressions.is_empty() { return Ok(vec![]); } - let current_types = expressions - .iter() - .map(|e| e.get_type(&DFSchema::empty())) - .collect::>>()?; - - let new_type = current_types - .iter() - .skip(1) - .fold(current_types.first().unwrap().clone(), |acc, x| { - comparison_coercion(&acc, x).unwrap_or(acc) - }); + if fun == BuiltinScalarFunction::MakeArray { + // Find the final data type for the function arguments + let current_types = expressions + .iter() + .map(|e| e.get_type(&DFSchema::empty())) + .collect::>>()?; + + let new_type = current_types + .iter() + .skip(1) + .fold(current_types.first().unwrap().clone(), |acc, x| { + comparison_coercion(&acc, x).unwrap_or(acc) + }); + + return expressions + .iter() + .enumerate() + .map(|(_, expr)| cast_expr(expr, &new_type, &DFSchema::empty())) + .collect(); + } - expressions - .iter() - .enumerate() - .map(|(_, expr)| cast_expr(expr, &new_type, &DFSchema::empty())) - .collect() + Ok(expressions.to_vec()) } /// Cast `expr` to the specified type, if possible From 1b983b4be2a4a507a6fe34ff118de2afe725034f Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Tue, 13 Jun 2023 20:55:13 +0800 Subject: [PATCH 6/7] add schema for coerce_args_for_fun Signed-off-by: jayzhan211 --- datafusion/optimizer/src/analyzer/type_coercion.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index b187e1e59597..b17c314cc662 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -392,7 +392,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature(), )?; - let new_args = coerce_arguments_for_fun(fun, new_args.as_slice())?; + let new_args = + coerce_arguments_for_fun(new_args.as_slice(), &self.schema, &fun)?; Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) } Expr::AggregateFunction(expr::AggregateFunction { @@ -603,18 +604,19 @@ fn coerce_arguments_for_signature( } fn coerce_arguments_for_fun( - fun: BuiltinScalarFunction, expressions: &[Expr], + schema: &DFSchema, + fun: &BuiltinScalarFunction, ) -> Result> { if expressions.is_empty() { return Ok(vec![]); } - if fun == BuiltinScalarFunction::MakeArray { + if *fun == BuiltinScalarFunction::MakeArray { // Find the final data type for the function arguments let current_types = expressions .iter() - .map(|e| e.get_type(&DFSchema::empty())) + .map(|e| e.get_type(schema)) .collect::>>()?; let new_type = current_types @@ -627,7 +629,7 @@ fn coerce_arguments_for_fun( return expressions .iter() .enumerate() - .map(|(_, expr)| cast_expr(expr, &new_type, &DFSchema::empty())) + .map(|(_, expr)| cast_expr(expr, &new_type, schema)) .collect(); } From 531f76a2b055cec7bdfeec05931018b0fbfc8f68 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Wed, 14 Jun 2023 09:07:53 +0800 Subject: [PATCH 7/7] add more tests Signed-off-by: jayzhan211 --- .../tests/sqllogictests/test_files/array.slt | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index 78697a0f4694..459046136b83 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -305,3 +305,34 @@ query ? select make_array(1, 2.0) ---- [1.0, 2.0] + +query ? +select make_array(null, 1.0) +---- +[, 1.0] + +query ? +select make_array(1, 2.0, null, 3) +---- +[1.0, 2.0, , 3.0] + +query ? +select make_array(1.0, '2', null) +---- +[1.0, 2, ] + +statement ok +create table foo1 (x int, y double) as values (1, 2.0); + +query ? +select make_array(x, y) from foo1; +---- +[1.0, 2.0] + +statement ok +create table foo2 (x float, y varchar) as values (1.0, '1'); + +query ? +select make_array(x, y) from foo2; +---- +[1.0, 1]