Skip to content

Commit

Permalink
Internal cast for array()
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Jun 9, 2023
1 parent 1d3860d commit 7b67c23
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 1 deletion.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ arrow-flight = { version = "41.0.0", features = ["flight-sql-experimental"] }
arrow-buffer = { version = "41.0.0", default-features = false }
arrow-schema = { version = "41.0.0", default-features = false }
arrow-array = { version = "41.0.0", default-features = false, features = ["chrono-tz"] }
arrow-cast = { version = "41.0.0", default-features = false }
parquet = { version = "41.0.0", features = ["arrow", "async", "object_store"] }

[profile.release]
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub use nullif::SUPPORTED_NULLIF_TYPES;
pub use operator::Operator;
pub use signature::{Signature, TypeSignature, Volatility};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use type_coercion::comparison_coercion;
pub use udaf::AggregateUDF;
pub use udf::ScalarUDF;
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
2 changes: 2 additions & 0 deletions datafusion/expr/src/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ pub mod binary;
pub mod functions;
pub mod other;

pub use binary::comparison_coercion;

use arrow::datatypes::DataType;
/// Determine whether the given data type `dt` represents signed numeric values.
pub fn is_signed_numeric(dt: &DataType) -> bool {
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ ahash = { version = "0.8", default-features = false, features = ["runtime-rng"]
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
arrow-cast = { workspace = true }
arrow-schema = { workspace = true }
blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
Expand Down
58 changes: 57 additions & 1 deletion datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ use arrow::array::*;
use arrow::buffer::Buffer;
use arrow::compute;
use arrow::datatypes::{DataType, Field};
use arrow_cast::cast;
use core::any::type_name;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::comparison_coercion;
use datafusion_expr::ColumnarValue;
use std::sync::Arc;

Expand Down Expand Up @@ -85,7 +87,18 @@ fn array_array(args: &[ArrayRef]) -> Result<ArrayRef> {
));
}

let data_type = args[0].data_type();
let data_type = args
.iter()
.skip(1)
.fold(args[0].data_type().clone(), |acc, x| {
comparison_coercion(&acc, x.data_type()).unwrap_or(acc)
});

let args: &[ArrayRef] = &args
.iter()
.map(|item| cast(item, &data_type).unwrap())
.collect::<Vec<ArrayRef>>();

let res = match data_type {
DataType::List(..) => {
let arrays =
Expand Down Expand Up @@ -1124,6 +1137,49 @@ mod tests {
.values()
)
}
#[test]
fn test_array_with_different_types_1() {
let args = [
ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
ColumnarValue::Scalar(ScalarValue::Int64(Some(1))),
];
let array = array(&args)
.expect("failed to initialize function array")
.into_array(1);
let result = as_list_array(&array).expect("failed to initialize function array");
assert_eq!(result.len(), 1);
assert_eq!(
&[1, 1],
result
.value(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap()
.values()
);
}

#[test]
fn test_array_with_different_types_2() {
let args = [
ColumnarValue::Scalar(ScalarValue::Float32(Some(1.0))),
ColumnarValue::Scalar(ScalarValue::Float64(Some(1.0))),
];
let array = array(&args)
.expect("failed to initialize function array")
.into_array(1);
let result = as_list_array(&array).expect("failed to initialize function array");
assert_eq!(result.len(), 1);
assert_eq!(
&[1.0, 1.0],
result
.value(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.values()
)
}

#[test]
fn test_nested_array() {
Expand Down

0 comments on commit 7b67c23

Please sign in to comment.