From 0b607aedaf630e89867a4ec4e10509b50fc0d6d5 Mon Sep 17 00:00:00 2001 From: Li Yazhou Date: Tue, 9 Nov 2021 11:13:40 +0800 Subject: [PATCH] allow null array to be casted to all other types --- src/compute/cast/mod.rs | 80 ++++++++++++++++++++++++++++++++++++++-- tests/it/compute/cast.rs | 38 +++++++++++++++++++ 2 files changed, 114 insertions(+), 4 deletions(-) diff --git a/src/compute/cast/mod.rs b/src/compute/cast/mod.rs index 99c1542e792..a918e4a96e6 100644 --- a/src/compute/cast/mod.rs +++ b/src/compute/cast/mod.rs @@ -78,6 +78,44 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } match (from_type, to_type) { + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | List(_) + | Dictionary(_, _), + ) + | ( + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | List(_) + | Dictionary(_, _), + Null, + ) => true, (Struct(_), _) => false, (_, Struct(_)) => false, (List(list_from), List(list_to)) => { @@ -254,7 +292,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Timestamp(_, _), Date64) => true, (Int64, Duration(_)) => true, (Duration(_), Int64) => true, - (Null, Int32) => true, (_, _) => false, } } @@ -337,7 +374,44 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu let as_options = options.with_wrapped(true); match (from_type, to_type) { - (Null, Int32) => Ok(new_null_array(to_type.clone(), array.len())), + ( + Null, + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | List(_) + | Dictionary(_, _), + ) + | ( + Boolean + | Int8 + | UInt8 + | Int16 + | UInt16 + | Int32 + | UInt32 + | Float32 + | Date32 + | Time32(_) + | Int64 + | UInt64 + | Float64 + | Date64 + | List(_) + | Dictionary(_, _), + Null, + ) => Ok(new_null_array(to_type.clone(), array.len())), (Struct(_), _) => Err(ArrowError::NotYetImplemented( "Cannot cast from struct to other types".to_string(), )), @@ -790,8 +864,6 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu (Int64, Duration(_)) => primitive_to_same_primitive_dyn::(array, to_type), (Duration(_), Int64) => primitive_to_same_primitive_dyn::(array, to_type), - // null to primitive/flat types - //(Null, Int32) => Ok(Box::new(Int32Array::from(vec![None; array.len()]))), (_, _) => Err(ArrowError::NotYetImplemented(format!( "Casting from {:?} to {:?} not supported", from_type, to_type, diff --git a/tests/it/compute/cast.rs b/tests/it/compute/cast.rs index c1156553ee5..94bc05f497c 100644 --- a/tests/it/compute/cast.rs +++ b/tests/it/compute/cast.rs @@ -597,6 +597,44 @@ fn naive_timestamp_to_utf8() { assert_eq!(expected, result.as_ref()); } +#[test] +fn null_array_from_and_to_others() { + macro_rules! typed_test { + ($ARR_TYPE:ident, $DATATYPE:ident) => {{ + { + let array = new_null_array(DataType::Null, 6); + let expected = $ARR_TYPE::from(vec![None; 6]); + let cast_type = DataType::$DATATYPE; + let result = + cast(array.as_ref(), &cast_type, CastOptions::default()).expect("cast failed"); + let result = result.as_any().downcast_ref::<$ARR_TYPE>().unwrap(); + assert_eq!(result.data_type(), &cast_type); + assert_eq!(result, &expected); + } + { + let array = $ARR_TYPE::from(vec![None; 4]); + let expected = NullArray::new_null(DataType::Null, 4); + let result = + cast(&array, &DataType::Null, CastOptions::default()).expect("cast failed"); + let result = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result.data_type(), &DataType::Null); + assert_eq!(result, &expected); + } + }}; + } + + typed_test!(Int16Array, Int16); + typed_test!(Int32Array, Int32); + typed_test!(Int64Array, Int64); + + typed_test!(UInt16Array, UInt16); + typed_test!(UInt32Array, UInt32); + typed_test!(UInt64Array, UInt64); + + typed_test!(Float32Array, Float32); + typed_test!(Float64Array, Float64); +} + /* #[test] fn dict_to_dict_bad_index_value_primitive() {