From ae0f401d89cf9dc8b717e6b95f73ed4f3be9798b Mon Sep 17 00:00:00 2001
From: Alex Huang
Date: Sat, 20 Jan 2024 04:14:44 +0800
Subject: [PATCH] Add support for FixedSizeList type in `arrow_cast`, hashing (#8344)

* Add support for parsing FixedSizeList type

* fix fmt

* support cast fixedsizelist from list

* clean comment

* support cast between NULL and FixedSizeList

* add test for FixedSizeList hash

* add test for cast fixedsizelist
---
 datafusion/common/src/hash_utils.rs           | 73 ++++++++++++++++++-
 datafusion/common/src/scalar.rs               | 26 +++++--
 datafusion/common/src/utils.rs                | 17 ++++-
 datafusion/expr/src/utils.rs                  |  1 +
 datafusion/sql/src/expr/arrow_cast.rs         | 17 +++++
 .../sqllogictest/test_files/arrow_typeof.slt  | 39 ++++++++++-
 6 files changed, 164 insertions(+), 9 deletions(-)

diff --git a/datafusion/common/src/hash_utils.rs b/datafusion/common/src/hash_utils.rs
index 8dcc00ca1c29..d5a1b3ee363b 100644
--- a/datafusion/common/src/hash_utils.rs
+++ b/datafusion/common/src/hash_utils.rs
@@ -27,8 +27,9 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array};
 use arrow_buffer::i256;
 
 use crate::cast::{
-    as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array,
-    as_primitive_array, as_string_array, as_struct_array,
+    as_boolean_array, as_fixed_size_list_array, as_generic_binary_array,
+    as_large_list_array, as_list_array, as_primitive_array, as_string_array,
+    as_struct_array,
 };
 use crate::error::{DataFusionError, Result, _internal_err};
 
@@ -267,6 +268,44 @@ where
     Ok(())
 }
 
+/// Hashes every row of a `FixedSizeListArray` into `hashes_buffer`.
+///
+/// The child values are hashed once with `create_hashes`; each valid row then folds
+/// the hashes of its own values into its slot with `combine_hashes`. Null rows are
+/// left untouched.
+fn hash_fixed_list_array(
+    array: &FixedSizeListArray,
+    random_state: &RandomState,
+    hashes_buffer: &mut [u64],
+) -> Result<()> {
+    let values = array.values().clone();
+    // Every row owns exactly `value_length()` consecutive child values, so row `i`
+    // covers `values_hashes[i * offset_size..(i + 1) * offset_size]`.
+    let offset_size = array.value_length() as usize;
+    let nulls = array.nulls();
+    let mut values_hashes = vec![0u64; values.len()];
+    create_hashes(&[values], random_state, &mut values_hashes)?;
+    if let Some(nulls) = nulls {
+        for i in 0..array.len() {
+            if nulls.is_valid(i) {
+                let hash = &mut hashes_buffer[i];
+                for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size]
+                {
+                    *hash = combine_hashes(*hash, *values_hash);
+                }
+            }
+        }
+    } else {
+        for i in 0..array.len() {
+            let hash = &mut hashes_buffer[i];
+            for values_hash in &values_hashes[i * offset_size..(i + 1) * offset_size] {
+                *hash = combine_hashes(*hash, *values_hash);
+            }
+        }
+    }
+    Ok(())
+}
+
 /// Test version of `create_hashes` that produces the same value for
 /// all hashes (to test collisions)
 ///
@@ -366,6 +405,10 @@ pub fn create_hashes<'a>(
                 let array = as_large_list_array(array)?;
                 hash_list_array(array, random_state, hashes_buffer)?;
             }
+            DataType::FixedSizeList(_,_) => {
+                let array = as_fixed_size_list_array(array)?;
+                hash_fixed_list_array(array, random_state, hashes_buffer)?;
+            }
             _ => {
                 // This is internal because we should have caught this before.
                 return _internal_err!(
@@ -546,6 +589,32 @@ mod tests {
         assert_eq!(hashes[2], hashes[3]);
     }
 
+    #[test]
+    // Tests actual values of hashes, which are different if forcing collisions
+    #[cfg(not(feature = "force_hash_collisions"))]
+    fn create_hashes_for_fixed_size_list_arrays() {
+        let data = vec![
+            Some(vec![Some(0), Some(1), Some(2)]),
+            None,
+            Some(vec![Some(3), None, Some(5)]),
+            Some(vec![Some(3), None, Some(5)]),
+            None,
+            Some(vec![Some(0), Some(1), Some(2)]),
+        ];
+        let list_array =
+            Arc::new(FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(
+                data, 3,
+            )) as ArrayRef;
+        let random_state = RandomState::with_seeds(0, 0, 0, 0);
+        let mut hashes = vec![0; list_array.len()];
+        create_hashes(&[list_array], &random_state, &mut hashes).unwrap();
+        assert_eq!(hashes[0], hashes[5]);
+        assert_eq!(hashes[1], hashes[4]);
+        assert_eq!(hashes[2], hashes[3]);
+        // Rows with different contents must differ once child values are folded in.
+        assert_ne!(hashes[0], hashes[2]);
+    }
+
     #[test]
     // Tests actual values of hashes, which are different if forcing collisions
     #[cfg(not(feature = "force_hash_collisions"))]
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 20d03c70960a..99b8cff20de7 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -34,8 +34,9 @@ use crate::cast::{
 };
 use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
 use crate::hash_utils::create_hashes;
-use crate::utils::{array_into_large_list_array, array_into_list_array};
-
+use crate::utils::{
+    array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array,
+};
 use arrow::compute::kernels::numeric::*;
 use arrow::util::display::{ArrayFormatter, FormatOptions};
 use arrow::{
@@ -2223,9 +2224,11 @@ impl ScalarValue {
                 let list_array = as_fixed_size_list_array(array)?;
                 let nested_array = list_array.value(index);
-                // Produces a single element `ListArray` with the value at `index`.
-                let arr = Arc::new(array_into_list_array(nested_array));
+                // Wraps the value at `index` in a single element `FixedSizeListArray`.
+                let list_size = nested_array.len();
+                let arr =
+                    Arc::new(array_into_fixed_size_list_array(nested_array, list_size));
 
-                ScalarValue::List(arr)
+                ScalarValue::FixedSizeList(arr)
             }
             DataType::Date32 => typed_cast!(array, index, Date32Array, Date32)?,
             DataType::Date64 => typed_cast!(array, index, Date64Array, Date64)?,
@@ -2971,6 +2974,19 @@ impl TryFrom<&DataType> for ScalarValue {
                 .to_owned()
                 .into(),
             ),
+            // `ScalarValue::FixedSizeList` contains single element `FixedSizeList`.
+            DataType::FixedSizeList(field, size) => ScalarValue::FixedSizeList(
+                new_null_array(
+                    &DataType::FixedSizeList(
+                        Arc::new(Field::new("item", field.data_type().clone(), true)),
+                        *size,
+                    ),
+                    1,
+                )
+                .as_fixed_size_list()
+                .to_owned()
+                .into(),
+            ),
             DataType::Struct(fields) => ScalarValue::Struct(None, fields.clone()),
             DataType::Null => ScalarValue::Null,
             _ => {
diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs
index 0a61fce15482..d21bd464f850 100644
--- a/datafusion/common/src/utils.rs
+++ b/datafusion/common/src/utils.rs
@@ -25,7 +25,9 @@ use arrow::compute;
 use arrow::compute::{partition, SortColumn, SortOptions};
 use arrow::datatypes::{Field, SchemaRef, UInt32Type};
 use arrow::record_batch::RecordBatch;
-use arrow_array::{Array, LargeListArray, ListArray, RecordBatchOptions};
+use arrow_array::{
+    Array, FixedSizeListArray, LargeListArray, ListArray, RecordBatchOptions,
+};
 use arrow_schema::DataType;
 use sqlparser::ast::Ident;
 use sqlparser::dialect::GenericDialect;
@@ -368,6 +370,26 @@ pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray {
     )
 }
 
+/// Wrap `arr` as the values of a `FixedSizeListArray` with lists of `list_size` items.
+///
+/// The item field is named "item" and is nullable; no null buffer is attached.
+///
+/// # Panics
+///
+/// Panics if `list_size` does not fit in an `i32`.
+pub fn array_into_fixed_size_list_array(
+    arr: ArrayRef,
+    list_size: usize,
+) -> FixedSizeListArray {
+    let list_size = i32::try_from(list_size).expect("list_size must fit in i32");
+    FixedSizeListArray::new(
+        Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
+        list_size,
+        arr,
+        None,
+    )
+}
+
 /// Wrap arrays into a single element `ListArray`.
/// /// Example: diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 40c2c4705362..02479c0765bd 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -911,6 +911,7 @@ pub fn can_hash(data_type: &DataType) -> bool { } DataType::List(_) => true, DataType::LargeList(_) => true, + DataType::FixedSizeList(_, _) => true, _ => false, } } diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/sql/src/expr/arrow_cast.rs index ade8b96b5cc2..9a0d61f41c01 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/sql/src/expr/arrow_cast.rs @@ -150,6 +150,7 @@ impl<'a> Parser<'a> { Token::Dictionary => self.parse_dictionary(), Token::List => self.parse_list(), Token::LargeList => self.parse_large_list(), + Token::FixedSizeList => self.parse_fixed_size_list(), tok => Err(make_error( self.val, &format!("finding next type, got unexpected '{tok}'"), @@ -177,6 +178,19 @@ impl<'a> Parser<'a> { )))) } + /// Parses the FixedSizeList type + fn parse_fixed_size_list(&mut self) -> Result { + self.expect_token(Token::LParen)?; + let length = self.parse_i32("FixedSizeList")?; + self.expect_token(Token::Comma)?; + let data_type = self.parse_next_type()?; + self.expect_token(Token::RParen)?; + Ok(DataType::FixedSizeList( + Arc::new(Field::new("item", data_type, true)), + length, + )) + } + /// Parses the next timeunit fn parse_time_unit(&mut self, context: &str) -> Result { match self.next_token()? 
{ @@ -508,6 +522,7 @@ impl<'a> Tokenizer<'a> { "List" => Token::List, "LargeList" => Token::LargeList, + "FixedSizeList" => Token::FixedSizeList, "Second" => Token::TimeUnit(TimeUnit::Second), "Millisecond" => Token::TimeUnit(TimeUnit::Millisecond), @@ -598,6 +613,7 @@ enum Token { DoubleQuotedString(String), List, LargeList, + FixedSizeList, } impl Display for Token { @@ -606,6 +622,7 @@ impl Display for Token { Token::SimpleType(t) => write!(f, "{t}"), Token::List => write!(f, "List"), Token::LargeList => write!(f, "LargeList"), + Token::FixedSizeList => write!(f, "FixedSizeList"), Token::Timestamp => write!(f, "Timestamp"), Token::Time32 => write!(f, "Time32"), Token::Time64 => write!(f, "Time64"), diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 5e9e7ff03d8b..afc28ecc39dc 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -384,4 +384,41 @@ LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, di query T select arrow_typeof(arrow_cast(make_array([1, 2, 3]), 'LargeList(LargeList(Int64))')); ---- -LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) \ No newline at end of file +LargeList(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + +## FixedSizeList + +query ? +select arrow_cast(null, 'FixedSizeList(1, Int64)'); +---- +NULL + +#TODO: arrow-rs doesn't support it yet +#query ? +#select arrow_cast('1', 'FixedSizeList(1, Int64)'); +#---- +#[1] + + +query ? 
+select arrow_cast([1], 'FixedSizeList(1, Int64)'); +---- +[1] + +query error DataFusion error: Optimizer rule 'simplify_expressions' failed +select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(4, Int64)'); + +query ? +select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(3, Int64)'); +---- +[1, 2, 3] + +query T +select arrow_typeof(arrow_cast(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 'FixedSizeList(3, Int64)')); +---- +FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, 3) + +query ? +select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)'); +---- +[1, 2, 3] \ No newline at end of file