From 6cec428cefe8ba624201cd77f6f1b398a13ee1c8 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Wed, 14 Aug 2024 11:19:25 +0200 Subject: [PATCH 01/12] Remove LargeUtf8|Binary, Utf8|BinaryView, and Dictionary from ScalarValue --- datafusion/common/src/scalar/mod.rs | 424 +------------ datafusion/common/src/utils/mod.rs | 4 +- .../physical_plan/file_scan_config.rs | 283 +-------- .../core/src/datasource/physical_plan/mod.rs | 4 +- .../datasource/physical_plan/parquet/mod.rs | 23 +- datafusion/core/tests/sql/path_partition.rs | 16 +- datafusion/functions-aggregate/src/min_max.rs | 58 +- .../functions-aggregate/src/string_agg.rs | 11 +- datafusion/functions/src/crypto/basic.rs | 24 +- datafusion/functions/src/datetime/common.rs | 11 +- .../functions/src/datetime/date_part.rs | 2 - .../functions/src/datetime/date_trunc.rs | 3 - .../functions/src/datetime/make_date.rs | 2 +- .../functions/src/datetime/to_timestamp.rs | 34 -- datafusion/functions/src/encoding/inner.rs | 93 +-- datafusion/functions/src/string/ascii.rs | 18 - datafusion/functions/src/string/bit_length.rs | 3 - datafusion/functions/src/string/common.rs | 4 - datafusion/functions/src/string/concat.rs | 4 +- datafusion/functions/src/string/concat_ws.rs | 8 +- datafusion/functions/src/string/initcap.rs | 38 -- .../functions/src/string/octet_length.rs | 36 -- .../functions/src/string/starts_with.rs | 12 +- .../functions/src/unicode/character_length.rs | 22 +- datafusion/functions/src/unicode/lpad.rs | 134 +--- .../optimizer/src/optimize_projections/mod.rs | 15 + .../simplify_expressions/expr_simplifier.rs | 3 +- .../src/simplify_expressions/guarantees.rs | 1 - .../src/unwrap_cast_in_comparison.rs | 114 +--- .../physical-expr/src/expressions/binary.rs | 572 +----------------- .../physical-expr/src/expressions/in_list.rs | 93 --- datafusion/physical-expr/src/utils/mod.rs | 10 +- datafusion/proto-common/src/from_proto/mod.rs | 30 +- datafusion/proto-common/src/to_proto/mod.rs | 32 - 
.../tests/cases/roundtrip_logical_plan.rs | 27 - .../tests/cases/roundtrip_physical_plan.rs | 8 +- datafusion/sql/src/unparser/expr.rs | 17 - .../substrait/src/logical_plan/consumer.rs | 6 - .../substrait/src/logical_plan/producer.rs | 8 - 39 files changed, 166 insertions(+), 2041 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index fd0c11ed0ab0..b22bae959961 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -224,18 +224,10 @@ pub enum ScalarValue { UInt64(Option), /// utf-8 encoded string. Utf8(Option), - /// utf-8 encoded string but from view types. - Utf8View(Option), - /// utf-8 encoded string representing a LargeString's arrow type. - LargeUtf8(Option), /// binary Binary(Option>), - /// binary but from view types. - BinaryView(Option>), /// fixed size binary FixedSizeBinary(i32, Option>), - /// large binary - LargeBinary(Option>), /// Fixed size list scalar. /// /// The array must be a FixedSizeListArray with length 1. 
@@ -293,8 +285,6 @@ pub enum ScalarValue { /// `.1`: the list of fields, zero-to-one of which will by set in `.0` /// `.2`: the physical storage of the source/destination UnionArray from which this Scalar came Union(Option<(i8, Box)>, UnionFields, UnionMode), - /// Dictionary type: index type and value - Dictionary(Box, Box), } impl Hash for Fl { @@ -354,18 +344,10 @@ impl PartialEq for ScalarValue { (UInt64(_), _) => false, (Utf8(v1), Utf8(v2)) => v1.eq(v2), (Utf8(_), _) => false, - (Utf8View(v1), Utf8View(v2)) => v1.eq(v2), - (Utf8View(_), _) => false, - (LargeUtf8(v1), LargeUtf8(v2)) => v1.eq(v2), - (LargeUtf8(_), _) => false, (Binary(v1), Binary(v2)) => v1.eq(v2), (Binary(_), _) => false, - (BinaryView(v1), BinaryView(v2)) => v1.eq(v2), - (BinaryView(_), _) => false, (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.eq(v2), (FixedSizeBinary(_, _), _) => false, - (LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2), - (LargeBinary(_), _) => false, (FixedSizeList(v1), FixedSizeList(v2)) => v1.eq(v2), (FixedSizeList(_), _) => false, (List(v1), List(v2)) => v1.eq(v2), @@ -414,8 +396,6 @@ impl PartialEq for ScalarValue { val1.eq(val2) && fields1.eq(fields2) && mode1.eq(mode2) } (Union(_, _, _), _) => false, - (Dictionary(k1, v1), Dictionary(k2, v2)) => k1.eq(k2) && v1.eq(v2), - (Dictionary(_, _), _) => false, (Null, Null) => true, (Null, _) => false, } @@ -483,18 +463,10 @@ impl PartialOrd for ScalarValue { (UInt64(_), _) => None, (Utf8(v1), Utf8(v2)) => v1.partial_cmp(v2), (Utf8(_), _) => None, - (LargeUtf8(v1), LargeUtf8(v2)) => v1.partial_cmp(v2), - (LargeUtf8(_), _) => None, - (Utf8View(v1), Utf8View(v2)) => v1.partial_cmp(v2), - (Utf8View(_), _) => None, (Binary(v1), Binary(v2)) => v1.partial_cmp(v2), (Binary(_), _) => None, - (BinaryView(v1), BinaryView(v2)) => v1.partial_cmp(v2), - (BinaryView(_), _) => None, (FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.partial_cmp(v2), (FixedSizeBinary(_, _), _) => None, - (LargeBinary(v1), LargeBinary(v2)) => 
v1.partial_cmp(v2), - (LargeBinary(_), _) => None, // ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are ensure to have length 1 (List(arr1), List(arr2)) => partial_cmp_list(arr1.as_ref(), arr2.as_ref()), (FixedSizeList(arr1), FixedSizeList(arr2)) => { @@ -558,15 +530,6 @@ impl PartialOrd for ScalarValue { } } (Union(_, _, _), _) => None, - (Dictionary(k1, v1), Dictionary(k2, v2)) => { - // Don't compare if the key types don't match (it is effectively a different datatype) - if k1 == k2 { - v1.partial_cmp(v2) - } else { - None - } - } - (Dictionary(_, _), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, } @@ -716,10 +679,8 @@ impl std::hash::Hash for ScalarValue { UInt16(v) => v.hash(state), UInt32(v) => v.hash(state), UInt64(v) => v.hash(state), - Utf8(v) | LargeUtf8(v) | Utf8View(v) => v.hash(state), - Binary(v) | FixedSizeBinary(_, v) | LargeBinary(v) | BinaryView(v) => { - v.hash(state) - } + Utf8(v) => v.hash(state), + Binary(v) | FixedSizeBinary(_, v) => v.hash(state), List(arr) => { hash_nested_array(arr.to_owned() as ArrayRef, state); } @@ -757,10 +718,6 @@ impl std::hash::Hash for ScalarValue { t.hash(state); m.hash(state); } - Dictionary(k, v) => { - k.hash(state); - v.hash(state); - } // stable hash for Null value Null => 1.hash(state), } @@ -791,68 +748,6 @@ pub fn get_dict_value( Ok((dict_array.values(), dict_array.key(index))) } -/// Create a dictionary array representing `value` repeated `size` -/// times -fn dict_from_scalar( - value: &ScalarValue, - size: usize, -) -> Result { - // values array is one element long (the value) - let values_array = value.to_array_of_size(1)?; - - // Create a key array with `size` elements, each of 0 - let key_array: PrimitiveArray = std::iter::repeat(if value.is_null() { - None - } else { - Some(K::default_value()) - }) - .take(size) - .collect(); - - // create a new DictionaryArray - // - // Note: this path could be made faster by using the ArrayData - // APIs and 
skipping validation, if it every comes up in - // performance traces. - Ok(Arc::new( - DictionaryArray::::try_new(key_array, values_array)?, // should always be valid by construction above - )) -} - -/// Create a dictionary array representing all the values in values -fn dict_from_values( - values_array: ArrayRef, -) -> Result { - // Create a key array with `size` elements of 0..array_len for all - // non-null value elements - let key_array: PrimitiveArray = (0..values_array.len()) - .map(|index| { - if values_array.is_valid(index) { - let native_index = K::Native::from_usize(index).ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not create index of type {} from value {}", - K::DATA_TYPE, - index - )) - })?; - Ok(Some(native_index)) - } else { - Ok(None) - } - }) - .collect::>>()? - .into_iter() - .collect(); - - // create a new DictionaryArray - // - // Note: this path could be made faster by using the ArrayData - // APIs and skipping validation, if it every comes up in - // performance traces. - let dict_array = DictionaryArray::::try_new(key_array, values_array)?; - Ok(Arc::new(dict_array)) -} - macro_rules! 
typed_cast_tz { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ use std::any::type_name; @@ -1279,12 +1174,8 @@ impl ScalarValue { ScalarValue::Float32(_) => DataType::Float32, ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, - ScalarValue::LargeUtf8(_) => DataType::LargeUtf8, - ScalarValue::Utf8View(_) => DataType::Utf8View, ScalarValue::Binary(_) => DataType::Binary, - ScalarValue::BinaryView(_) => DataType::BinaryView, ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz), - ScalarValue::LargeBinary(_) => DataType::LargeBinary, ScalarValue::List(arr) => arr.data_type().to_owned(), ScalarValue::LargeList(arr) => arr.data_type().to_owned(), ScalarValue::FixedSizeList(arr) => arr.data_type().to_owned(), @@ -1314,9 +1205,6 @@ impl ScalarValue { DataType::Duration(TimeUnit::Nanosecond) } ScalarValue::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), - ScalarValue::Dictionary(k, v) => { - DataType::Dictionary(k.clone(), Box::new(v.data_type())) - } ScalarValue::Null => DataType::Null, } } @@ -1548,13 +1436,8 @@ impl ScalarValue { ScalarValue::UInt16(v) => v.is_none(), ScalarValue::UInt32(v) => v.is_none(), ScalarValue::UInt64(v) => v.is_none(), - ScalarValue::Utf8(v) - | ScalarValue::Utf8View(v) - | ScalarValue::LargeUtf8(v) => v.is_none(), - ScalarValue::Binary(v) - | ScalarValue::BinaryView(v) - | ScalarValue::FixedSizeBinary(_, v) - | ScalarValue::LargeBinary(v) => v.is_none(), + ScalarValue::Utf8(v) => v.is_none(), + ScalarValue::Binary(v) | ScalarValue::FixedSizeBinary(_, v) => v.is_none(), // arr.len() should be 1 for a list scalar, but we don't seem to // enforce that anywhere, so we still check against array length. 
ScalarValue::List(arr) => arr.len() == arr.null_count(), @@ -1583,7 +1466,6 @@ impl ScalarValue { Some((_, s)) => s.is_null(), None => true, }, - ScalarValue::Dictionary(_, v) => v.is_null(), } } @@ -1802,12 +1684,8 @@ impl ScalarValue { DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16), DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), - DataType::Utf8View => build_array_string!(StringViewArray, Utf8View), DataType::Utf8 => build_array_string!(StringArray, Utf8), - DataType::LargeUtf8 => build_array_string!(LargeStringArray, LargeUtf8), - DataType::BinaryView => build_array_string!(BinaryViewArray, BinaryView), DataType::Binary => build_array_string!(BinaryArray, Binary), - DataType::LargeBinary => build_array_string!(LargeBinaryArray, LargeBinary), DataType::Date32 => build_array_primitive!(Date32Array, Date32), DataType::Date64 => build_array_primitive!(Date64Array, Date64), DataType::Time32(TimeUnit::Second) => { @@ -1900,40 +1778,6 @@ impl ScalarValue { let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); arrow::compute::concat(arrays.as_slice())? 
} - DataType::Dictionary(key_type, value_type) => { - // create the values array - let value_scalars = scalars - .map(|scalar| match scalar { - ScalarValue::Dictionary(inner_key_type, scalar) => { - if &inner_key_type == key_type { - Ok(*scalar) - } else { - _internal_err!("Expected inner key type of {key_type} but found: {inner_key_type}, value was ({scalar:?})") - } - } - _ => { - _internal_err!( - "Expected scalar of type {value_type} but found: {scalar} {scalar:?}" - ) - } - }) - .collect::>>()?; - - let values = Self::iter_to_array(value_scalars)?; - assert_eq!(values.data_type(), value_type.as_ref()); - - match key_type.as_ref() { - DataType::Int8 => dict_from_values::(values)?, - DataType::Int16 => dict_from_values::(values)?, - DataType::Int32 => dict_from_values::(values)?, - DataType::Int64 => dict_from_values::(values)?, - DataType::UInt8 => dict_from_values::(values)?, - DataType::UInt16 => dict_from_values::(values)?, - DataType::UInt32 => dict_from_values::(values)?, - DataType::UInt64 => dict_from_values::(values)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), - } - } DataType::FixedSizeBinary(size) => { let array = scalars .map(|sv| { @@ -1964,6 +1808,11 @@ impl ScalarValue { | DataType::Time64(TimeUnit::Millisecond) | DataType::RunEndEncoded(_, _) | DataType::ListView(_) + | DataType::LargeBinary + | DataType::BinaryView + | DataType::LargeUtf8 + | DataType::Utf8View + | DataType::Dictionary(_, _) | DataType::LargeListView(_) => { return _internal_err!( "Unsupported creation of {:?} array from ScalarValue {:?}", @@ -2186,7 +2035,7 @@ impl ScalarValue { } else { Self::iter_to_array(values.iter().cloned()).unwrap() }; - Arc::new(array_into_large_list_array(values)) + Arc::new(array_into_large_list_array(values, true)) } /// Converts a scalar value into an array of `size` rows. 
@@ -2275,18 +2124,6 @@ impl ScalarValue { } None => new_null_array(&DataType::Utf8, size), }, - ScalarValue::Utf8View(e) => match e { - Some(value) => { - Arc::new(StringViewArray::from_iter_values(repeat(value).take(size))) - } - None => new_null_array(&DataType::Utf8View, size), - }, - ScalarValue::LargeUtf8(e) => match e { - Some(value) => { - Arc::new(LargeStringArray::from_iter_values(repeat(value).take(size))) - } - None => new_null_array(&DataType::LargeUtf8, size), - }, ScalarValue::Binary(e) => match e { Some(value) => Arc::new( repeat(Some(value.as_slice())) @@ -2297,16 +2134,6 @@ impl ScalarValue { Arc::new(repeat(None::<&str>).take(size).collect::()) } }, - ScalarValue::BinaryView(e) => match e { - Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), - ), - None => { - Arc::new(repeat(None::<&str>).take(size).collect::()) - } - }, ScalarValue::FixedSizeBinary(s, e) => match e { Some(value) => Arc::new( FixedSizeBinaryArray::try_from_sparse_iter_with_size( @@ -2323,18 +2150,6 @@ impl ScalarValue { .unwrap(), ), }, - ScalarValue::LargeBinary(e) => match e { - Some(value) => Arc::new( - repeat(Some(value.as_slice())) - .take(size) - .collect::(), - ), - None => Arc::new( - repeat(None::<&str>) - .take(size) - .collect::(), - ), - }, ScalarValue::List(arr) => { Self::list_to_array_of_size(arr.as_ref() as &dyn Array, size)? 
} @@ -2473,20 +2288,6 @@ impl ScalarValue { new_null_array(&dt, size) } }, - ScalarValue::Dictionary(key_type, v) => { - // values array is one element long (the value) - match key_type.as_ref() { - DataType::Int8 => dict_from_scalar::(v, size)?, - DataType::Int16 => dict_from_scalar::(v, size)?, - DataType::Int32 => dict_from_scalar::(v, size)?, - DataType::Int64 => dict_from_scalar::(v, size)?, - DataType::UInt8 => dict_from_scalar::(v, size)?, - DataType::UInt16 => dict_from_scalar::(v, size)?, - DataType::UInt32 => dict_from_scalar::(v, size)?, - DataType::UInt64 => dict_from_scalar::(v, size)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), - } - } ScalarValue::Null => new_null_array(&DataType::Null, size), }) } @@ -2651,17 +2452,11 @@ impl ScalarValue { DataType::Int16 => typed_cast!(array, index, Int16Array, Int16)?, DataType::Int8 => typed_cast!(array, index, Int8Array, Int8)?, DataType::Binary => typed_cast!(array, index, BinaryArray, Binary)?, - DataType::LargeBinary => { - typed_cast!(array, index, LargeBinaryArray, LargeBinary)? - } - DataType::BinaryView => { - typed_cast!(array, index, BinaryViewArray, BinaryView)? - } + DataType::LargeBinary => typed_cast!(array, index, LargeBinaryArray, Binary)?, + DataType::BinaryView => typed_cast!(array, index, BinaryViewArray, Binary)?, DataType::Utf8 => typed_cast!(array, index, StringArray, Utf8)?, - DataType::LargeUtf8 => { - typed_cast!(array, index, LargeStringArray, LargeUtf8)? 
- } - DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8View)?, + DataType::LargeUtf8 => typed_cast!(array, index, LargeStringArray, Utf8)?, + DataType::Utf8View => typed_cast!(array, index, StringViewArray, Utf8)?, DataType::List(field) => { let list_array = array.as_list::(); let nested_array = list_array.value(index); @@ -2671,11 +2466,14 @@ impl ScalarValue { ScalarValue::List(arr) } - DataType::LargeList(_) => { + DataType::LargeList(field) => { let list_array = as_large_list_array(array); let nested_array = list_array.value(index); // Produces a single element `LargeListArray` with the value at `index`. - let arr = Arc::new(array_into_large_list_array(nested_array)); + let arr = Arc::new(array_into_large_list_array( + nested_array, + field.is_nullable(), + )); ScalarValue::LargeList(arr) } @@ -2745,15 +2543,13 @@ impl ScalarValue { _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), }; // look up the index in the values dictionary - let value = match values_index { + match values_index { Some(values_index) => { ScalarValue::try_from_array(values_array, values_index) } // else entry was null, so return null None => values_array.data_type().try_into(), - }?; - - Self::Dictionary(key_type.clone(), Box::new(value)) + }? } DataType::Struct(_) => { let a = array.slice(index, 1); @@ -2897,6 +2693,7 @@ impl ScalarValue { /// Panics if `self` is a dictionary with invalid key type #[inline] pub fn eq_array(&self, array: &ArrayRef, index: usize) -> Result { + // TODO(@notfilippo): maybe match on the array DataType instead of self Ok(match self { ScalarValue::Decimal128(v, precision, scale) => { ScalarValue::eq_array_decimal( @@ -2953,24 +2750,12 @@ impl ScalarValue { ScalarValue::Utf8(val) => { eq_array_primitive!(array, index, StringArray, val)? } - ScalarValue::Utf8View(val) => { - eq_array_primitive!(array, index, StringViewArray, val)? - } - ScalarValue::LargeUtf8(val) => { - eq_array_primitive!(array, index, LargeStringArray, val)? 
- } ScalarValue::Binary(val) => { eq_array_primitive!(array, index, BinaryArray, val)? } - ScalarValue::BinaryView(val) => { - eq_array_primitive!(array, index, BinaryViewArray, val)? - } ScalarValue::FixedSizeBinary(_, val) => { eq_array_primitive!(array, index, FixedSizeBinaryArray, val)? } - ScalarValue::LargeBinary(val) => { - eq_array_primitive!(array, index, LargeBinaryArray, val)? - } ScalarValue::List(arr) => { Self::eq_array_list(&(arr.to_owned() as ArrayRef), array, index) } @@ -3040,24 +2825,6 @@ impl ScalarValue { ScalarValue::Union(_, _, _) => { return _not_impl_err!("Union is not supported yet") } - ScalarValue::Dictionary(key_type, v) => { - let (values_array, values_index) = match key_type.as_ref() { - DataType::Int8 => get_dict_value::(array, index)?, - DataType::Int16 => get_dict_value::(array, index)?, - DataType::Int32 => get_dict_value::(array, index)?, - DataType::Int64 => get_dict_value::(array, index)?, - DataType::UInt8 => get_dict_value::(array, index)?, - DataType::UInt16 => get_dict_value::(array, index)?, - DataType::UInt32 => get_dict_value::(array, index)?, - DataType::UInt64 => get_dict_value::(array, index)?, - _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), - }; - // was the value in the array non null? 
- match values_index { - Some(values_index) => v.eq_array(values_array, values_index)?, - None => v.is_null(), - } - } ScalarValue::Null => array.is_null(index), }) } @@ -3100,9 +2867,7 @@ impl ScalarValue { | ScalarValue::DurationMillisecond(_) | ScalarValue::DurationMicrosecond(_) | ScalarValue::DurationNanosecond(_) => 0, - ScalarValue::Utf8(s) - | ScalarValue::LargeUtf8(s) - | ScalarValue::Utf8View(s) => { + ScalarValue::Utf8(s) => { s.as_ref().map(|s| s.capacity()).unwrap_or_default() } ScalarValue::TimestampSecond(_, s) @@ -3111,10 +2876,7 @@ impl ScalarValue { | ScalarValue::TimestampNanosecond(_, s) => { s.as_ref().map(|s| s.len()).unwrap_or_default() } - ScalarValue::Binary(b) - | ScalarValue::FixedSizeBinary(_, b) - | ScalarValue::LargeBinary(b) - | ScalarValue::BinaryView(b) => { + ScalarValue::Binary(b) | ScalarValue::FixedSizeBinary(_, b) => { b.as_ref().map(|b| b.capacity()).unwrap_or_default() } ScalarValue::List(arr) => arr.get_array_memory_size(), @@ -3131,10 +2893,6 @@ impl ScalarValue { + (std::mem::size_of::() * fields.len()) + fields.iter().map(|(_idx, field)| field.size() - std::mem::size_of_val(field)).sum::() } - ScalarValue::Dictionary(dt, sv) => { - // `dt` and `sv` are boxed, so they are NOT already included in `self` - dt.size() + sv.size() - } } } @@ -3381,12 +3139,12 @@ impl TryFrom<&DataType> for ScalarValue { ScalarValue::Decimal256(None, *precision, *scale) } DataType::Utf8 => ScalarValue::Utf8(None), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), - DataType::Utf8View => ScalarValue::Utf8View(None), + DataType::LargeUtf8 => ScalarValue::Utf8(None), + DataType::Utf8View => ScalarValue::Utf8(None), DataType::Binary => ScalarValue::Binary(None), - DataType::BinaryView => ScalarValue::BinaryView(None), + DataType::BinaryView => ScalarValue::Binary(None), DataType::FixedSizeBinary(len) => ScalarValue::FixedSizeBinary(*len, None), - DataType::LargeBinary => ScalarValue::LargeBinary(None), + DataType::LargeBinary => 
ScalarValue::Binary(None), DataType::Date32 => ScalarValue::Date32(None), DataType::Date64 => ScalarValue::Date64(None), DataType::Time32(TimeUnit::Second) => ScalarValue::Time32Second(None), @@ -3428,10 +3186,7 @@ impl TryFrom<&DataType> for ScalarValue { DataType::Duration(TimeUnit::Nanosecond) => { ScalarValue::DurationNanosecond(None) } - DataType::Dictionary(index_type, value_type) => ScalarValue::Dictionary( - index_type.clone(), - Box::new(value_type.as_ref().try_into()?), - ), + DataType::Dictionary(_, value_type) => Self::try_from(value_type.as_ref())?, // `ScalaValue::List` contains single element `ListArray`. DataType::List(field_ref) => ScalarValue::List(Arc::new( GenericListArray::new_null(Arc::clone(field_ref), 1), @@ -3512,13 +3267,8 @@ impl fmt::Display for ScalarValue { ScalarValue::TimestampMillisecond(e, _) => format_option!(f, e)?, ScalarValue::TimestampMicrosecond(e, _) => format_option!(f, e)?, ScalarValue::TimestampNanosecond(e, _) => format_option!(f, e)?, - ScalarValue::Utf8(e) - | ScalarValue::LargeUtf8(e) - | ScalarValue::Utf8View(e) => format_option!(f, e)?, - ScalarValue::Binary(e) - | ScalarValue::FixedSizeBinary(_, e) - | ScalarValue::LargeBinary(e) - | ScalarValue::BinaryView(e) => match e { + ScalarValue::Utf8(e) => format_option!(f, e)?, + ScalarValue::Binary(e) | ScalarValue::FixedSizeBinary(_, e) => match e { Some(l) => write!( f, "{}", @@ -3631,7 +3381,6 @@ impl fmt::Display for ScalarValue { Some((id, val)) => write!(f, "{}:{}", id, val)?, None => write!(f, "NULL")?, }, - ScalarValue::Dictionary(_k, v) => write!(f, "{v}")?, ScalarValue::Null => write!(f, "NULL")?, }; Ok(()) @@ -3679,22 +3428,14 @@ impl fmt::Debug for ScalarValue { } ScalarValue::Utf8(None) => write!(f, "Utf8({self})"), ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{self}\")"), - ScalarValue::Utf8View(None) => write!(f, "Utf8View({self})"), - ScalarValue::Utf8View(Some(_)) => write!(f, "Utf8View(\"{self}\")"), - ScalarValue::LargeUtf8(None) => write!(f, 
"LargeUtf8({self})"), - ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{self}\")"), ScalarValue::Binary(None) => write!(f, "Binary({self})"), ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{self}\")"), - ScalarValue::BinaryView(None) => write!(f, "BinaryView({self})"), - ScalarValue::BinaryView(Some(_)) => write!(f, "BinaryView(\"{self}\")"), ScalarValue::FixedSizeBinary(size, None) => { write!(f, "FixedSizeBinary({size}, {self})") } ScalarValue::FixedSizeBinary(size, Some(_)) => { write!(f, "FixedSizeBinary({size}, \"{self}\")") } - ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({self})"), - ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{self}\")"), ScalarValue::FixedSizeList(_) => write!(f, "FixedSizeList({self})"), ScalarValue::List(_) => write!(f, "List({self})"), ScalarValue::LargeList(_) => write!(f, "LargeList({self})"), @@ -3782,7 +3523,6 @@ impl fmt::Debug for ScalarValue { Some((id, val)) => write!(f, "Union {}:{}", id, val), None => write!(f, "Union(NULL)"), }, - ScalarValue::Dictionary(k, v) => write!(f, "Dictionary({k:?}, {v:?})"), ScalarValue::Null => write!(f, "NULL"), } } @@ -3834,9 +3574,7 @@ impl ScalarType for Date32Type { mod tests { use super::*; - use crate::cast::{ - as_map_array, as_string_array, as_struct_array, as_uint32_array, as_uint64_array, - }; + use crate::cast::{as_map_array, as_struct_array, as_uint32_array, as_uint64_array}; use crate::assert_batches_eq; use crate::utils::array_into_list_array_nullable; @@ -4805,21 +4543,11 @@ mod tests { StringArray, vec![Some("foo"), None, Some("bar")] ); - check_scalar_iter_string!( - LargeUtf8, - LargeStringArray, - vec![Some("foo"), None, Some("bar")] - ); check_scalar_iter_binary!( Binary, BinaryArray, vec![Some(b"foo"), None, Some(b"bar")] ); - check_scalar_iter_binary!( - LargeBinary, - LargeBinaryArray, - vec![Some(b"foo"), None, Some(b"bar")] - ); } #[test] @@ -4836,38 +4564,6 @@ mod tests { ); } - #[test] - fn scalar_iter_to_dictionary() { - fn 
make_val(v: Option) -> ScalarValue { - let key_type = DataType::Int32; - let value = ScalarValue::Utf8(v); - ScalarValue::Dictionary(Box::new(key_type), Box::new(value)) - } - - let scalars = [ - make_val(Some("Foo".into())), - make_val(None), - make_val(Some("Bar".into())), - ]; - - let array = ScalarValue::iter_to_array(scalars).unwrap(); - let array = as_dictionary_array::(&array).unwrap(); - let values_array = as_string_array(array.values()).unwrap(); - - let values = array - .keys_iter() - .map(|k| { - k.map(|k| { - assert!(values_array.is_valid(k)); - values_array.value(k) - }) - }) - .collect::>(); - - let expected = vec![Some("Foo"), None, Some("Bar")]; - assert_eq!(values, expected); - } - #[test] fn scalar_iter_to_array_mismatched_types() { use ScalarValue::*; @@ -4997,18 +4693,6 @@ mod tests { assert_ne!(list_scalar, nested_list_scalar); } - #[test] - fn scalar_try_from_dict_datatype() { - let data_type = - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); - let data_type = &data_type; - let expected = ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Utf8(None)), - ); - assert_eq!(expected, data_type.try_into().unwrap()) - } - #[test] fn size_of_scalar() { // Since ScalarValues are used in a non trivial number of places, @@ -5149,29 +4833,6 @@ mod tests { }}; } - /// create a test case for DictionaryArray<$INDEX_TY> - macro_rules! 
make_str_dict_test_case { - ($INPUT:expr, $INDEX_TY:ident) => {{ - TestCase { - array: Arc::new( - $INPUT - .iter() - .cloned() - .collect::>(), - ), - scalars: $INPUT - .iter() - .map(|v| { - ScalarValue::Dictionary( - Box::new($INDEX_TY::DATA_TYPE), - Box::new(ScalarValue::Utf8(v.map(|v| v.to_string()))), - ) - }) - .collect(), - } - }}; - } - let cases = vec![ make_test_case!(bool_vals, BooleanArray, Boolean), make_test_case!(f32_vals, Float32Array, Float32), @@ -5185,9 +4846,7 @@ mod tests { make_test_case!(u32_vals, UInt32Array, UInt32), make_test_case!(u64_vals, UInt64Array, UInt64), make_str_test_case!(str_vals, StringArray, Utf8), - make_str_test_case!(str_vals, LargeStringArray, LargeUtf8), make_binary_test_case!(str_vals, BinaryArray, Binary), - make_binary_test_case!(str_vals, LargeBinaryArray, LargeBinary), make_test_case!(i32_vals, Date32Array, Date32), make_test_case!(i64_vals, Date64Array, Date64), make_test_case!(i32_vals, Time32SecondArray, Time32Second), @@ -5244,14 +4903,6 @@ mod tests { IntervalMonthDayNanoArray, IntervalMonthDayNano ), - make_str_dict_test_case!(str_vals, Int8Type), - make_str_dict_test_case!(str_vals, Int16Type), - make_str_dict_test_case!(str_vals, Int32Type), - make_str_dict_test_case!(str_vals, Int64Type), - make_str_dict_test_case!(str_vals, UInt8Type), - make_str_dict_test_case!(str_vals, UInt16Type), - make_str_dict_test_case!(str_vals, UInt32Type), - make_str_dict_test_case!(str_vals, UInt64Type), ]; for case in cases { @@ -6678,15 +6329,4 @@ mod tests { ); assert!(dense_scalar.is_null()); } - - #[test] - fn null_dictionary_scalar_produces_null_dictionary_array() { - let dictionary_scalar = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Null), - ); - assert!(dictionary_scalar.is_null()); - let dictionary_array = dictionary_scalar.to_array().unwrap(); - assert!(dictionary_array.is_null(0)); - } } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 
bf506c0551eb..a5708ae6d77f 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -378,10 +378,10 @@ pub fn array_into_list_array(arr: ArrayRef, nullable: bool) -> ListArray { /// Wrap an array into a single element `LargeListArray`. /// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]` -pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray { +pub fn array_into_large_list_array(arr: ArrayRef, nullable: bool) -> LargeListArray { let offsets = OffsetBuffer::from_lengths([arr.len()]); LargeListArray::new( - Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)), + Arc::new(Field::new_list_field(arr.data_type().to_owned(), nullable)), offsets, arr, None, diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 34fb6226c1a2..363b10fb0039 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -18,9 +18,7 @@ //! [`FileScanConfig`] to configure scanning of possibly partitioned //! file sources. 
-use std::{ - borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, -}; +use std::{collections::HashMap, sync::Arc, vec}; use super::{ get_projected_output_ordering, statistics::MinMaxStatistics, FileGroupPartitioner, @@ -28,40 +26,12 @@ use super::{ use crate::datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}; use crate::{error::Result, scalar::ScalarValue}; -use arrow::array::{ArrayData, BufferBuilder}; -use arrow::buffer::Buffer; -use arrow::datatypes::{ArrowNativeType, UInt16Type}; -use arrow_array::{ArrayRef, DictionaryArray, RecordBatch, RecordBatchOptions}; -use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow_schema::{Field, Schema, SchemaRef}; use datafusion_common::stats::Precision; use datafusion_common::{exec_err, ColumnStatistics, DataFusionError, Statistics}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; -use log::warn; - -/// Convert type to a type suitable for use as a [`ListingTable`] -/// partition column. Returns `Dictionary(UInt16, val_type)`, which is -/// a reasonable trade off between a reasonable number of partition -/// values and space efficiency. -/// -/// This use this to specify types for partition columns. However -/// you MAY also choose not to dictionary-encode the data or to use a -/// different dictionary type. -/// -/// Use [`wrap_partition_value_in_dict`] to wrap a [`ScalarValue`] in the same say. -/// -/// [`ListingTable`]: crate::datasource::listing::ListingTable -pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { - DataType::Dictionary(Box::new(DataType::UInt16), Box::new(val_type)) -} - -/// Convert a [`ScalarValue`] of partition columns to a type, as -/// described in the documentation of [`wrap_partition_type_in_dict`], -/// which can wrap the types. 
-pub fn wrap_partition_value_in_dict(val: ScalarValue) -> ScalarValue { - ScalarValue::Dictionary(Box::new(DataType::UInt16), Box::new(val)) -} - /// The base configurations to provide when creating a physical plan for /// any given file format. /// @@ -382,10 +352,6 @@ impl FileScanConfig { /// have all their keys equal to 0. This enables us to re-use the same "all-zero" buffer across batches, /// which makes the space consumption of the partition columns O(batch_size) instead of O(record_count). pub struct PartitionColumnProjector { - /// An Arrow buffer initialized to zeros that represents the key array of all partition - /// columns (partition columns are materialized by dictionary arrays with only one - /// value in the dictionary, thus all the keys are equal to zero). - key_buffer_cache: ZeroBufferGenerators, /// Mapping between the indexes in the list of partition columns and the target /// schema. Sorted by index in the target schema so that we can iterate on it to /// insert the partition columns in the target record batch. 
@@ -411,7 +377,6 @@ impl PartitionColumnProjector { Self { projected_partition_indexes, - key_buffer_cache: Default::default(), projected_schema, } } @@ -445,30 +410,7 @@ impl PartitionColumnProjector { "Invalid partitioning found on disk".to_string(), ))?; - let mut partition_value = Cow::Borrowed(p_value); - - // check if user forgot to dict-encode the partition value - let field = self.projected_schema.field(sidx); - let expected_data_type = field.data_type(); - let actual_data_type = partition_value.data_type(); - if let DataType::Dictionary(key_type, _) = expected_data_type { - if !matches!(actual_data_type, DataType::Dictionary(_, _)) { - warn!("Partition value for column {} was not dictionary-encoded, applied auto-fix.", field.name()); - partition_value = Cow::Owned(ScalarValue::Dictionary( - key_type.clone(), - Box::new(partition_value.as_ref().clone()), - )); - } - } - - cols.insert( - sidx, - create_output_array( - &mut self.key_buffer_cache, - partition_value.as_ref(), - file_batch.num_rows(), - )?, - ) + cols.insert(sidx, create_output_array(p_value, file_batch.num_rows())?) } RecordBatch::try_new_with_options( @@ -480,155 +422,17 @@ impl PartitionColumnProjector { } } -#[derive(Debug, Default)] -struct ZeroBufferGenerators { - gen_i8: ZeroBufferGenerator, - gen_i16: ZeroBufferGenerator, - gen_i32: ZeroBufferGenerator, - gen_i64: ZeroBufferGenerator, - gen_u8: ZeroBufferGenerator, - gen_u16: ZeroBufferGenerator, - gen_u32: ZeroBufferGenerator, - gen_u64: ZeroBufferGenerator, -} - -/// Generate a arrow [`Buffer`] that contains zero values. 
-#[derive(Debug, Default)] -struct ZeroBufferGenerator -where - T: ArrowNativeType, -{ - cache: Option, - _t: PhantomData, -} - -impl ZeroBufferGenerator -where - T: ArrowNativeType, -{ - const SIZE: usize = std::mem::size_of::(); - - fn get_buffer(&mut self, n_vals: usize) -> Buffer { - match &mut self.cache { - Some(buf) if buf.len() >= n_vals * Self::SIZE => { - buf.slice_with_length(0, n_vals * Self::SIZE) - } - _ => { - let mut key_buffer_builder = BufferBuilder::::new(n_vals); - key_buffer_builder.advance(n_vals); // keys are all 0 - self.cache.insert(key_buffer_builder.finish()).clone() - } - } - } -} - -fn create_dict_array( - buffer_gen: &mut ZeroBufferGenerator, - dict_val: &ScalarValue, - len: usize, - data_type: DataType, -) -> Result -where - T: ArrowNativeType, -{ - let dict_vals = dict_val.to_array()?; - - let sliced_key_buffer = buffer_gen.get_buffer(len); - - // assemble pieces together - let mut builder = ArrayData::builder(data_type) - .len(len) - .add_buffer(sliced_key_buffer); - builder = builder.add_child_data(dict_vals.to_data()); - Ok(Arc::new(DictionaryArray::::from( - builder.build().unwrap(), - ))) -} - -fn create_output_array( - key_buffer_cache: &mut ZeroBufferGenerators, - val: &ScalarValue, - len: usize, -) -> Result { - if let ScalarValue::Dictionary(key_type, dict_val) = &val { - match key_type.as_ref() { - DataType::Int8 => { - return create_dict_array( - &mut key_buffer_cache.gen_i8, - dict_val, - len, - val.data_type(), - ); - } - DataType::Int16 => { - return create_dict_array( - &mut key_buffer_cache.gen_i16, - dict_val, - len, - val.data_type(), - ); - } - DataType::Int32 => { - return create_dict_array( - &mut key_buffer_cache.gen_i32, - dict_val, - len, - val.data_type(), - ); - } - DataType::Int64 => { - return create_dict_array( - &mut key_buffer_cache.gen_i64, - dict_val, - len, - val.data_type(), - ); - } - DataType::UInt8 => { - return create_dict_array( - &mut key_buffer_cache.gen_u8, - dict_val, - len, - 
val.data_type(), - ); - } - DataType::UInt16 => { - return create_dict_array( - &mut key_buffer_cache.gen_u16, - dict_val, - len, - val.data_type(), - ); - } - DataType::UInt32 => { - return create_dict_array( - &mut key_buffer_cache.gen_u32, - dict_val, - len, - val.data_type(), - ); - } - DataType::UInt64 => { - return create_dict_array( - &mut key_buffer_cache.gen_u64, - dict_val, - len, - val.data_type(), - ); - } - _ => {} - } - } - +fn create_output_array(val: &ScalarValue, len: usize) -> Result { + // TODO(@notfilippo): should we reintroduce a way to encode as dictionaries? val.to_array_of_size(len) } #[cfg(test)] mod tests { - use arrow_array::Int32Array; - use super::*; use crate::{test::columns, test_util::aggr_test_schema}; + use arrow_array::Int32Array; + use arrow_schema::DataType; #[test] fn physical_plan_config_no_projection() { @@ -637,10 +441,7 @@ mod tests { Arc::clone(&file_schema), None, Statistics::new_unknown(&file_schema), - to_partition_cols(vec![( - "date".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - )]), + to_partition_cols(vec![("date".to_owned(), DataType::Utf8)]), ); let (proj_schema, proj_statistics, _) = conf.project(); @@ -669,11 +470,9 @@ mod tests { // make a table_partition_col as a field let table_partition_col = - Field::new("date", wrap_partition_type_in_dict(DataType::Utf8), true) - .with_metadata(HashMap::from_iter(vec![( - "key_whatever".to_owned(), - "value_whatever".to_owned(), - )])); + Field::new("date", DataType::Utf8, true).with_metadata(HashMap::from_iter( + vec![("key_whatever".to_owned(), "value_whatever".to_owned())], + )); let conf = config_for_projection( Arc::clone(&file_schema), @@ -710,10 +509,7 @@ mod tests { .collect(), total_byte_size: Precision::Absent, }, - to_partition_cols(vec![( - "date".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - )]), + to_partition_cols(vec![("date".to_owned(), DataType::Utf8)]), ); let (proj_schema, proj_statistics, _) = conf.project(); @@ -742,18 
+538,9 @@ mod tests { ("c", &vec![10, 11, 12]), ); let partition_cols = vec![ - ( - "year".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), - ( - "month".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), - ( - "day".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), + ("year".to_owned(), DataType::Utf8), + ("month".to_owned(), DataType::Utf8), + ("day".to_owned(), DataType::Utf8), ]; // create a projected schema let conf = config_for_projection( @@ -785,9 +572,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::from("2021")), - wrap_partition_value_in_dict(ScalarValue::from("10")), - wrap_partition_value_in_dict(ScalarValue::from("26")), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("26"), ], ) .expect("Projection of partition columns into record batch failed"); @@ -813,9 +600,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::from("2021")), - wrap_partition_value_in_dict(ScalarValue::from("10")), - wrap_partition_value_in_dict(ScalarValue::from("27")), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("27"), ], ) .expect("Projection of partition columns into record batch failed"); @@ -843,9 +630,9 @@ mod tests { // file_batch is ok here because we kept all the file cols in the projection file_batch, &[ - wrap_partition_value_in_dict(ScalarValue::from("2021")), - wrap_partition_value_in_dict(ScalarValue::from("10")), - wrap_partition_value_in_dict(ScalarValue::from("28")), + ScalarValue::from("2021"), + ScalarValue::from("10"), + ScalarValue::from("28"), ], ) .expect("Projection of partition columns into record batch failed"); @@ -893,14 +680,8 @@ mod tests { fn test_projected_file_schema_with_partition_col() { let schema = aggr_test_schema(); let partition_cols = vec![ 
- ( - "part1".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), - ( - "part2".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), + ("part1".to_owned(), DataType::Utf8), + ("part2".to_owned(), DataType::Utf8), ]; // Projected file schema for config with projection including partition column @@ -926,14 +707,8 @@ mod tests { fn test_projected_file_schema_without_projection() { let schema = aggr_test_schema(); let partition_cols = vec![ - ( - "part1".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), - ( - "part2".to_owned(), - wrap_partition_type_in_dict(DataType::Utf8), - ), + ("part1".to_owned(), DataType::Utf8), + ("part2".to_owned(), DataType::Utf8), ]; // Projected file schema for config without projection diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index f810fb86bd89..a6f0d6e1c5a0 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -37,9 +37,7 @@ pub use arrow_file::ArrowExec; pub use avro::AvroExec; pub use csv::{CsvConfig, CsvExec, CsvExecBuilder, CsvOpener}; pub use file_groups::FileGroupPartitioner; -pub use file_scan_config::{ - wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, -}; +pub use file_scan_config::FileScanConfig; pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; pub use json::{JsonOpener, NdJsonExec}; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs index 72aabefba595..dd212077ae5d 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/mod.rs @@ -1619,10 +1619,7 @@ mod tests { partition_values: vec![ ScalarValue::from("2021"), ScalarValue::UInt8(Some(10)), - ScalarValue::Dictionary( - Box::new(DataType::UInt16), - Box::new(ScalarValue::from("26")), - ), + 
ScalarValue::from("26"), ], range: None, statistics: None, @@ -1634,14 +1631,7 @@ mod tests { Field::new("bool_col", DataType::Boolean, true), Field::new("tinyint_col", DataType::Int32, true), Field::new("month", DataType::UInt8, false), - Field::new( - "day", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - false, - ), + Field::new("day", DataType::Utf8, false), ]); let parquet_exec = ParquetExec::builder( @@ -1652,14 +1642,7 @@ mod tests { .with_table_partition_cols(vec![ Field::new("year", DataType::Utf8, false), Field::new("month", DataType::UInt8, false), - Field::new( - "day", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - false, - ), + Field::new("day", DataType::Utf8, false), ]), ) .build(); diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 7e7544bdb7c0..82a24a57c49c 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -64,13 +64,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { ], &[ ("year", DataType::Int32), - ( - "month", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - ), + ("month", DataType::Utf8), ("day", DataType::Utf8), ], "mirror:///", @@ -170,7 +164,7 @@ async fn parquet_distinct_partition_col() -> Result<()> { let s = ScalarValue::try_from_array(results[0].column(1), 0)?; let month = match extract_as_utf(&s) { Some(month) => month, - s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), + _ => panic!("Expected month as Utf8 found {s:?}"), }; let sql_on_partition_boundary = format!( @@ -192,10 +186,8 @@ async fn parquet_distinct_partition_col() -> Result<()> { } fn extract_as_utf(v: &ScalarValue) -> Option { - if let ScalarValue::Dictionary(_, v) = v { - if let ScalarValue::Utf8(v) = v.as_ref() { - return v.clone(); - } + if let ScalarValue::Utf8(v) = v { + return v.clone(); } None } diff 
--git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index f9a08631bfb9..42a423a91276 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -436,15 +436,10 @@ fn min_batch(values: &ArrayRef) -> Result { typed_min_max_batch_string!(values, StringArray, Utf8, min_string) } DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) + typed_min_max_batch_string!(values, LargeStringArray, Utf8, min_string) } DataType::Utf8View => { - typed_min_max_batch_string!( - values, - StringViewArray, - Utf8View, - min_string_view - ) + typed_min_max_batch_string!(values, StringViewArray, Utf8, min_string_view) } DataType::Boolean => { typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) @@ -453,20 +448,10 @@ fn min_batch(values: &ArrayRef) -> Result { typed_min_max_batch_binary!(&values, BinaryArray, Binary, min_binary) } DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - min_binary - ) + typed_min_max_batch_binary!(&values, LargeBinaryArray, Binary, min_binary) } DataType::BinaryView => { - typed_min_max_batch_binary!( - &values, - BinaryViewArray, - BinaryView, - min_binary_view - ) + typed_min_max_batch_binary!(&values, BinaryViewArray, Binary, min_binary_view) } _ => min_max_batch!(values, min), }) @@ -479,15 +464,10 @@ fn max_batch(values: &ArrayRef) -> Result { typed_min_max_batch_string!(values, StringArray, Utf8, max_string) } DataType::LargeUtf8 => { - typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) + typed_min_max_batch_string!(values, LargeStringArray, Utf8, max_string) } DataType::Utf8View => { - typed_min_max_batch_string!( - values, - StringViewArray, - Utf8View, - max_string_view - ) + typed_min_max_batch_string!(values, StringViewArray, Utf8, max_string_view) } DataType::Boolean => { typed_min_max_batch!(values, 
BooleanArray, Boolean, max_boolean) @@ -496,20 +476,10 @@ fn max_batch(values: &ArrayRef) -> Result { typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) } DataType::BinaryView => { - typed_min_max_batch_binary!( - &values, - BinaryViewArray, - BinaryView, - max_binary_view - ) + typed_min_max_batch_binary!(&values, BinaryViewArray, Binary, max_binary_view) } DataType::LargeBinary => { - typed_min_max_batch_binary!( - &values, - LargeBinaryArray, - LargeBinary, - max_binary - ) + typed_min_max_batch_binary!(&values, LargeBinaryArray, Binary, max_binary) } _ => min_max_batch!(values, max), }) @@ -643,21 +613,9 @@ macro_rules! min_max { (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { typed_min_max_string!(lhs, rhs, Utf8, $OP) } - (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) - } - (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8View, $OP) - } (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { typed_min_max_string!(lhs, rhs, Binary, $OP) } - (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeBinary, $OP) - } - (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { - typed_min_max_string!(lhs, rhs, BinaryView, $OP) - } (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) } diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index a7e9a37e23ad..510b3b535dc9 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -85,13 +85,12 @@ impl AggregateUDFImpl for StringAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { return match lit.value() { - 
ScalarValue::Utf8(Some(delimiter)) - | ScalarValue::LargeUtf8(Some(delimiter)) => { + ScalarValue::Utf8(Some(delimiter)) => { Ok(Box::new(StringAggAccumulator::new(delimiter.as_str()))) } - ScalarValue::Utf8(None) - | ScalarValue::LargeUtf8(None) - | ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))), + ScalarValue::Utf8(None) | ScalarValue::Null => { + Ok(Box::new(StringAggAccumulator::new(""))) + } e => not_impl_err!("StringAgg not supported for delimiter {}", e), }; } @@ -142,7 +141,7 @@ impl Accumulator for StringAggAccumulator { } fn evaluate(&mut self) -> Result { - Ok(ScalarValue::LargeUtf8(self.values.clone())) + Ok(ScalarValue::Utf8(self.values.clone())) } fn size(&self) -> usize { diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index 716afd84a9c9..f3015c24b3fa 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -121,9 +121,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result { } let digest_algorithm = match &args[1] { ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { - method.parse::() - } + ScalarValue::Utf8(Some(method)) => method.parse::(), other => exec_err!("Unsupported data type {other:?} for function digest"), }, ColumnarValue::Array(_) => { @@ -338,16 +336,16 @@ pub fn digest_process( "Unsupported data type {other:?} for function {digest_algorithm}" ), }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { - Ok(digest_algorithm - .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) + ColumnarValue::Scalar(scalar) => { + match scalar { + ScalarValue::Utf8(a) => Ok(digest_algorithm + .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), + ScalarValue::Binary(a) => Ok(digest_algorithm + .digest_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), + other => exec_err!( + "Unsupported data type {other:?} 
for function {digest_algorithm}" + ), } - ScalarValue::Binary(a) | ScalarValue::LargeBinary(a) => Ok(digest_algorithm - .digest_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), - other => exec_err!( - "Unsupported data type {other:?} for function {digest_algorithm}" - ), - }, + } } } diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index 6048eeeaa554..e0d775e602d6 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -171,7 +171,7 @@ where other => exec_err!("Unsupported data type {other:?} for function {name}"), }, ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + ScalarValue::Utf8(a) => { let result = a.as_ref().map(|x| (op)(x)).transpose()?; Ok(ColumnarValue::Scalar(S::scalar(result))) } @@ -228,7 +228,7 @@ where }, // if the first argument is a scalar utf8 all arguments are expected to be scalar utf8 ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => { + ScalarValue::Utf8(a) => { let a = a.as_ref(); // ASK: Why do we trust `a` to be non-null at this point? 
let a = unwrap_or_internal_err!(a); @@ -236,10 +236,7 @@ where let mut ret = None; for (pos, v) in args.iter().enumerate().skip(1) { - let ColumnarValue::Scalar( - ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x), - ) = v - else { + let ColumnarValue::Scalar(ScalarValue::Utf8(x)) = v else { return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); }; @@ -304,7 +301,7 @@ where Ok(Either::Left(as_generic_string_array::(a.as_ref())?)) } ColumnarValue::Scalar(s) => match s { - ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => Ok(Either::Right(a)), + ScalarValue::Utf8(a) => Ok(Either::Right(a)), other => exec_err!( "Unexpected scalar type encountered '{other}' for function '{name}'" ), diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index e24b11aeb71f..f4ea165e174e 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -140,8 +140,6 @@ impl ScalarUDFImpl for DatePartFunc { let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { v - } else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = part { - v } else { return exec_err!( "First argument of `DATE_PART` must be non-null scalar Utf8" diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 308ea668d3d7..3bb22ce7913a 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -139,9 +139,6 @@ impl ScalarUDFImpl for DateTruncFunc { let granularity = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = granularity - { - v.to_lowercase() - } else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = granularity { v.to_lowercase() } else { diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index ded7b454f9eb..5f59ef0710a1 100644 --- 
a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -221,7 +221,7 @@ mod tests { let res = MakeDateFunc::new() .invoke(&[ ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), + ColumnarValue::Scalar(ScalarValue::Utf8(Some("1".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), ]) .expect("that make_date parsed values without error"); diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index cbb6f37603d2..b767fd0720db 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -922,40 +922,6 @@ mod tests { panic!("Expected a columnar array") } - // test LargeUTF8 - let string_array = [ - ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%s".to_string()))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%c".to_string()))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("%+".to_string()))), - ]; - let parsed_timestamps = func(&string_array) - .expect("that to_timestamp with format args parsed values without error"); - if let ColumnarValue::Array(parsed_array) = parsed_timestamps { - assert_eq!(parsed_array.len(), 1); - assert!(matches!( - parsed_array.data_type(), - DataType::Timestamp(_, None) - )); - - match time_unit { - Nanosecond => { - assert_eq!(nanos_expected_timestamps, parsed_array.as_ref()) - } - Millisecond => { - assert_eq!(millis_expected_timestamps, parsed_array.as_ref()) - } - Microsecond => { - assert_eq!(micros_expected_timestamps, parsed_array.as_ref()) - } - Second => { - assert_eq!(sec_expected_timestamps, parsed_array.as_ref()) - } - }; - } else { - panic!("Expected a columnar array") - } - // test other types let string_array = [ ColumnarValue::Array(Arc::new(data.clone()) as 
ArrayRef), diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index d9ce299a2602..aedbe18ec18c 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -173,23 +173,17 @@ fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result { - match scalar { - ScalarValue::Utf8(a) => { - Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) - } - ScalarValue::LargeUtf8(a) => Ok(encoding - .encode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), - ScalarValue::Binary(a) => Ok( - encoding.encode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) - ), - ScalarValue::LargeBinary(a) => Ok(encoding - .encode_large_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), - other => exec_err!( - "Unsupported data type {other:?} for function encode({encoding})" - ), + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) } - } + ScalarValue::Binary(a) => { + Ok(encoding.encode_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))) + } + other => exec_err!( + "Unsupported data type {other:?} for function encode({encoding})" + ), + }, } } @@ -204,23 +198,17 @@ fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result { - match scalar { - ScalarValue::Utf8(a) => { - encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) - } - ScalarValue::LargeUtf8(a) => encoding - .decode_large_scalar(a.as_ref().map(|s: &String| s.as_bytes())), - ScalarValue::Binary(a) => { - encoding.decode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) - } - ScalarValue::LargeBinary(a) => encoding - .decode_large_scalar(a.as_ref().map(|v: &Vec| v.as_slice())), - other => exec_err!( - "Unsupported data type {other:?} for function decode({encoding})" - ), + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(a) => { + encoding.decode_scalar(a.as_ref().map(|s: &String| 
s.as_bytes())) } - } + ScalarValue::Binary(a) => { + encoding.decode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) + } + other => exec_err!( + "Unsupported data type {other:?} for function decode({encoding})" + ), + }, } } @@ -274,15 +262,6 @@ impl Encoding { }) } - fn encode_large_scalar(self, value: Option<&[u8]>) -> ColumnarValue { - ColumnarValue::Scalar(match self { - Self::Base64 => ScalarValue::LargeUtf8( - value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), - ), - Self::Hex => ScalarValue::LargeUtf8(value.map(hex::encode)), - }) - } - fn encode_binary_array(self, value: &dyn Array) -> Result where T: OffsetSizeTrait, @@ -335,34 +314,6 @@ impl Encoding { Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(out)))) } - fn decode_large_scalar(self, value: Option<&[u8]>) -> Result { - let value = match value { - Some(value) => value, - None => return Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(None))), - }; - - let out = match self { - Self::Base64 => { - general_purpose::STANDARD_NO_PAD - .decode(value) - .map_err(|e| { - DataFusionError::Internal(format!( - "Failed to decode value using base64: {}", - e - )) - })? 
- } - Self::Hex => hex::decode(value).map_err(|e| { - DataFusionError::Internal(format!( - "Failed to decode value using hex: {}", - e - )) - })?, - }; - - Ok(ColumnarValue::Scalar(ScalarValue::LargeBinary(Some(out)))) - } - fn decode_binary_array(self, value: &dyn Array) -> Result where T: OffsetSizeTrait, @@ -426,7 +377,7 @@ fn encode(args: &[ColumnarValue]) -> Result { } let encoding = match &args[1] { ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + ScalarValue::Utf8(Some(method)) => { method.parse::() } _ => not_impl_err!( @@ -452,7 +403,7 @@ fn decode(args: &[ColumnarValue]) -> Result { } let encoding = match &args[1] { ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { + ScalarValue::Utf8(Some(method)) => { method.parse::() } _ => not_impl_err!( diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 68ba3f5ff15f..016e5f11893a 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -128,24 +128,6 @@ mod tests { Int32, Int32Array ); - - test_function!( - AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], - $EXPECTED, - i32, - Int32, - Int32Array - ); - - test_function!( - AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], - $EXPECTED, - i32, - Int32, - Int32Array - ); }; } diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 65ec1a4a7734..6f1926f5a96d 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -81,9 +81,6 @@ impl ScalarUDFImpl for BitLengthFunc { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( v.as_ref().map(|x| (x.len() * 8) as i32), ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - 
ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)), - )), _ => unreachable!(), }, } diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 7037c1d1c3c3..403e687541b7 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -218,10 +218,6 @@ where let result = a.as_ref().map(|x| op(x)); Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) } - ScalarValue::LargeUtf8(a) => { - let result = a.as_ref().map(|x| op(x)); - Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result))) - } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, } diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 6d15e2206721..827a2dfef222 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -151,11 +151,11 @@ pub fn simplify_concat(args: Vec) -> Result { for arg in args.clone() { match arg { // filter out `null` args - Expr::Literal(ScalarValue::Utf8(None) | ScalarValue::LargeUtf8(None)) => {} + Expr::Literal(ScalarValue::Utf8(None)) => {} // All literals have been converted to Utf8 or LargeUtf8 in type_coercion. // Concatenate it with the `contiguous_scalar`. 
Expr::Literal( - ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)), + ScalarValue::Utf8(Some(v)), ) => contiguous_scalar += &v, Expr::Literal(x) => { return internal_err!( diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 4d05f4e707b1..bdf153eaccb6 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -222,9 +222,7 @@ impl ScalarUDFImpl for ConcatWsFunc { fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result { match delimiter { - Expr::Literal( - ScalarValue::Utf8(delimiter) | ScalarValue::LargeUtf8(delimiter), - ) => { + Expr::Literal(ScalarValue::Utf8(delimiter)) => { match delimiter { // when the delimiter is an empty string, // we can use `concat` to replace `concat_ws` @@ -236,8 +234,8 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result {} - Expr::Literal(ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v))) => { + Expr::Literal(ScalarValue::Utf8(None)) => {} + Expr::Literal(ScalarValue::Utf8(Some(v))) => { match contiguous_scalar { None => contiguous_scalar = Some(v.to_string()), Some(mut pre) => { diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index 4e1eb213ef57..15861c39e807 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -167,44 +167,6 @@ mod tests { Utf8, StringArray ); - test_function!( - InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - "hi THOMAS".to_string() - )))], - Ok(Some("Hi Thomas")), - &str, - Utf8, - StringArray - ); - test_function!( - InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string() - )))], - Ok(Some("Hi Thomas With M0re Than 12 Chars")), - &str, - Utf8, - StringArray - ); - test_function!( - InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - 
"".to_string() - )))], - Ok(Some("")), - &str, - Utf8, - StringArray - ); - test_function!( - InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(None))], - Ok(None), - &str, - Utf8, - StringArray - ); Ok(()) } diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index f792914d862e..d0a533333247 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -81,12 +81,6 @@ impl ScalarUDFImpl for OctetLengthFunc { ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( v.as_ref().map(|x| x.len() as i32), ))), - ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int64(v.as_ref().map(|x| x.len() as i64)), - )), - ScalarValue::Utf8View(v) => Ok(ColumnarValue::Scalar( - ScalarValue::Int32(v.as_ref().map(|x| x.len() as i32)), - )), _ => unreachable!(), }, } @@ -179,36 +173,6 @@ mod tests { Int32, Int32Array ); - test_function!( - OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - String::from("joséjoséjoséjosé") - )))], - Ok(Some(20)), - i32, - Int32, - Int32Array - ); - test_function!( - OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - String::from("josé") - )))], - Ok(Some(5)), - i32, - Int32, - Int32Array - ); - test_function!( - OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View(Some( - String::from("") - )))], - Ok(Some(0)), - i32, - Int32, - Int32Array - ); Ok(()) } diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 8450697cbf30..78dd49ba4509 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -121,17 +121,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Utf8(b.map(|s| s.to_string()))), ]; - let large_utf_8_args = vec![ - ColumnarValue::Scalar(ScalarValue::LargeUtf8(a.map(|s| s.to_string()))), - 
ColumnarValue::Scalar(ScalarValue::LargeUtf8(b.map(|s| s.to_string()))), - ]; - - let utf_8_view_args = vec![ - ColumnarValue::Scalar(ScalarValue::Utf8View(a.map(|s| s.to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8View(b.map(|s| s.to_string()))), - ]; - - vec![(utf_8_args, c), (large_utf_8_args, c), (utf_8_view_args, c)] + vec![(utf_8_args, c)] }); for (args, expected) in test_cases { diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index e46ee162ff12..9e8de0a8405f 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -122,8 +122,8 @@ where mod tests { use crate::unicode::character_length::CharacterLengthFunc; use crate::utils::test::test_function; - use arrow::array::{Array, Int32Array, Int64Array}; - use arrow::datatypes::DataType::{Int32, Int64}; + use arrow::array::{Array, Int32Array}; + use arrow::datatypes::DataType::Int32; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -137,24 +137,6 @@ mod tests { Int32, Int32Array ); - - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], - $EXPECTED, - i64, - Int64, - Int64Array - ); - - test_function!( - CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], - $EXPECTED, - i32, - Int32, - Int32Array - ); }; } diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 5caa6acd6745..8396d6dd40bd 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -284,8 +284,8 @@ mod tests { use crate::unicode::lpad::LPadFunc; use crate::utils::test::test_function; - use arrow::array::{Array, LargeStringArray, StringArray}; - use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use arrow::array::{Array,StringArray}; + use 
arrow::datatypes::DataType::{Utf8}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -303,30 +303,6 @@ mod tests { Utf8, StringArray ); - - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), - ColumnarValue::Scalar($LENGTH) - ], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), - ColumnarValue::Scalar($LENGTH) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); }; ($INPUT:expr, $LENGTH:expr, $REPLACE:expr, $EXPECTED:expr) => { @@ -343,112 +319,6 @@ mod tests { Utf8, StringArray ); - // utf8, largeutf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - // utf8, utf8view - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - - // largeutf8, utf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) - ], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - // largeutf8, largeutf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) - ], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - // largeutf8, utf8view - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) - ], - $EXPECTED, - 
&str, - LargeUtf8, - LargeStringArray - ); - - // utf8view, utf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - // utf8view, largeutf8 - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::LargeUtf8($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); - // utf8view, utf8view - test_function!( - LPadFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8View($REPLACE)) - ], - $EXPECTED, - &str, - Utf8, - StringArray - ); }; } diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index ac4ed87a4a1a..21814f0b469b 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -1677,6 +1677,21 @@ mod tests { assert_optimized_plan_equal(projection, expected) } + #[test] + fn cast_literal() -> Result<()> { + let projection = LogicalPlanBuilder::empty(false) + .project(vec![Expr::Cast(Cast::new( + Box::new(lit("hello")), + DataType::LargeUtf8, + ))])? 
+ .build()?; + + let expected = "Projection: CAST(Utf8(\"hello\") AS LargeUtf8)\ + \n EmptyRelation"; + + assert_optimized_plan_equal(projection, expected) + } + #[test] fn table_scan_projected_schema() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index c45df74a564d..17ea6235984c 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,8 +27,9 @@ use arrow::{ record_batch::RecordBatch, }; +use datafusion_common::cast::as_large_list_array; use datafusion_common::{ - cast::{as_large_list_array, as_list_array}, + cast::as_list_array, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 09fdd7685a9c..1f748123365d 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -406,7 +406,6 @@ mod tests { ScalarValue::Boolean(Some(true)), ScalarValue::Boolean(None), ScalarValue::from("abc"), - ScalarValue::LargeUtf8(Some("def".to_string())), ScalarValue::Date32(Some(18628)), ScalarValue::Date32(None), ScalarValue::Decimal128(Some(1000), 19, 2), diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index e0f50a470d43..908c7aa03337 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -315,7 +315,6 @@ fn try_cast_literal_to_type( } try_cast_numeric_literal(lit_value, target_type) .or_else(|| try_cast_string_literal(lit_value, target_type)) - .or_else(|| 
try_cast_dictionary(lit_value, target_type)) } /// Convert a numeric value from one numeric data type to another @@ -464,46 +463,16 @@ fn try_cast_string_literal( target_type: &DataType, ) -> Option { let string_value = match lit_value { - ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => { - s.clone() - } + ScalarValue::Utf8(s) => s.clone(), _ => return None, }; let scalar_value = match target_type { DataType::Utf8 => ScalarValue::Utf8(string_value), - DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), - DataType::Utf8View => ScalarValue::Utf8View(string_value), _ => return None, }; Some(scalar_value) } -/// Attempt to cast to/from a dictionary type by wrapping/unwrapping the dictionary -fn try_cast_dictionary( - lit_value: &ScalarValue, - target_type: &DataType, -) -> Option { - let lit_value_type = lit_value.data_type(); - let result_scalar = match (lit_value, target_type) { - // Unwrap dictionary when inner type matches target type - (ScalarValue::Dictionary(_, inner_value), _) - if inner_value.data_type() == *target_type => - { - (**inner_value).clone() - } - // Wrap type when target type is dictionary - (_, DataType::Dictionary(index_type, inner_type)) - if **inner_type == lit_value_type => - { - ScalarValue::Dictionary(index_type.clone(), Box::new(lit_value.clone())) - } - _ => { - return None; - } - }; - Some(result_scalar) -} - /// Cast a timestamp value from one unit to another fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option { let value = value as i64; @@ -594,45 +563,6 @@ mod tests { assert_eq!(optimize_test(expr_input, &schema), expected); } - #[test] - fn test_unwrap_cast_comparison_string() { - let schema = expr_test_schema(); - let dict = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("value")), - ); - - // cast(str1 as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = Utf8('value1') - let expr_input = cast(col("str1"), 
dict.data_type()).eq(lit(dict.clone())); - let expected = col("str1").eq(lit("value")); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // cast(tag as Utf8) = Utf8('value') => tag = arrow_cast('value', 'Dictionary') - let expr_input = cast(col("tag"), DataType::Utf8).eq(lit("value")); - let expected = col("tag").eq(lit(dict.clone())); - assert_eq!(optimize_test(expr_input, &schema), expected); - - // Verify reversed argument order - // arrow_cast('value', 'Dictionary') = cast(str1 as Dictionary) => Utf8('value1') = str1 - let expr_input = lit(dict.clone()).eq(cast(col("str1"), dict.data_type())); - let expected = lit("value").eq(col("str1")); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - - #[test] - fn test_unwrap_cast_comparison_large_string() { - let schema = expr_test_schema(); - // cast(largestr as Dictionary) = arrow_cast('value', 'Dictionary') => str1 = LargeUtf8('value1') - let dict = ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::LargeUtf8(Some("value".to_owned()))), - ); - let expr_input = cast(col("largestr"), dict.data_type()).eq(lit(dict.clone())); - let expected = - col("largestr").eq(lit(ScalarValue::LargeUtf8(Some("value".to_owned())))); - assert_eq!(optimize_test(expr_input, &schema), expected); - } - #[test] fn test_not_unwrap_cast_with_decimal_comparison() { let schema = expr_test_schema(); @@ -913,7 +843,6 @@ mod tests { ScalarValue::Decimal128(None, 3, 0), ScalarValue::Decimal128(None, 8, 2), ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), ]; for s1 in &scalars { @@ -1366,45 +1295,4 @@ mod tests { .unwrap(); assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None)); } - - #[test] - fn test_try_cast_to_string_type() { - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - - for s1 in &scalars { - for s2 in &scalars { - let expected_value = ExpectedCast::Value(s2.clone()); - - expect_cast(s1.clone(), 
s2.data_type(), expected_value); - } - } - } - #[test] - fn test_try_cast_to_dictionary_type() { - fn dictionary_type(t: DataType) -> DataType { - DataType::Dictionary(Box::new(DataType::Int32), Box::new(t)) - } - fn dictionary_value(value: ScalarValue) -> ScalarValue { - ScalarValue::Dictionary(Box::new(DataType::Int32), Box::new(value)) - } - let scalars = vec![ - ScalarValue::from("string"), - ScalarValue::LargeUtf8(Some("string".to_owned())), - ]; - for s in &scalars { - expect_cast( - s.clone(), - dictionary_type(s.data_type()), - ExpectedCast::Value(dictionary_value(s.clone())), - ); - expect_cast( - dictionary_value(s.clone()), - s.data_type(), - ExpectedCast::Value(s.clone()), - ) - } - } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 347a5d82dbec..b2c1e21d1825 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -236,7 +236,7 @@ macro_rules! compute_utf8_flag_op_scalar { .downcast_ref::<$ARRAYTYPE>() .expect("compute_utf8_flag_op_scalar failed to downcast array"); - if let ScalarValue::Utf8(Some(string_value))|ScalarValue::LargeUtf8(Some(string_value)) = $RIGHT { + if let ScalarValue::Utf8(Some(string_value)) = $RIGHT { let flag = if $FLAG { Some("i") } else { None }; let mut array = paste::expr! 
{[<$OP _utf8_scalar>]}(&ll, &string_value, flag)?; @@ -1425,87 +1425,6 @@ mod tests { Ok(()) } - #[test] - fn plus_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(2), None, Some(3), Some(6)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Plus, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(1))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn plus_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value + 1), - Some(value), - None, - Some(value + 2), - Some(value + 1), - ], - 11, - 0, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Plus, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(1), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - #[test] fn minus_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -1536,98 +1455,6 @@ mod tests { Ok(()) } - #[test] - fn minus_op_dict() -> Result<()> { - let schema = Schema::new(vec![ - Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), 
Box::new(DataType::Int32)), - true, - ), - Field::new( - "b", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - ), - ]); - - let a = Int32Array::from(vec![1, 2, 3, 4, 5]); - let keys = Int8Array::from(vec![Some(0), None, Some(1), Some(3), None]); - let a = DictionaryArray::try_new(keys, Arc::new(a))?; - - let b = Int32Array::from(vec![1, 2, 4, 8, 16]); - let keys = Int8Array::from(vec![0, 1, 1, 2, 1]); - let b = DictionaryArray::try_new(keys, Arc::new(b))?; - - apply_arithmetic::( - Arc::new(schema), - vec![Arc::new(a), Arc::new(b)], - Operator::Minus, - Int32Array::from(vec![Some(0), None, Some(0), Some(0), None]), - )?; - - Ok(()) - } - - #[test] - fn minus_op_dict_decimal() -> Result<()> { - let schema = Schema::new(vec![ - Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - ), - Field::new( - "b", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - ), - ]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value), - Some(value + 2), - Some(value - 1), - Some(value + 1), - ], - 10, - 0, - )); - - let keys = Int8Array::from(vec![Some(0), Some(2), None, Some(3), Some(0)]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let keys = Int8Array::from(vec![Some(0), None, Some(3), Some(2), Some(2)]); - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value + 1), - Some(value + 3), - Some(value), - Some(value + 2), - ], - 10, - 0, - )); - let b = DictionaryArray::try_new(keys, decimal_array)?; - - apply_arithmetic( - Arc::new(schema), - vec![Arc::new(a), Arc::new(b)], - Operator::Minus, - create_decimal_array(&[Some(-1), None, None, Some(1), Some(0)], 11, 0), - )?; - - Ok(()) - } - #[test] fn minus_op_scalar() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -1644,87 +1471,6 @@ mod tests { Ok(()) } - 
#[test] - fn minus_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(0), None, Some(1), Some(4)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Minus, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(1))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn minus_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value - 1), - Some(value - 2), - None, - Some(value), - Some(value - 1), - ], - 11, - 0, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Minus, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(1), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - #[test] fn multiply_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -1856,81 +1602,6 @@ mod tests { Ok(()) } - #[test] - fn multiply_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = 
PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(2), None, Some(4), Some(10)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Multiply, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(2))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn multiply_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[Some(246), Some(244), None, Some(248), Some(246)], - 21, - 0, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Multiply, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - #[test] fn divide_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -2074,81 +1745,6 @@ mod tests { Ok(()) } - #[test] - fn divide_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(0), None, 
Some(1), Some(2)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Divide, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(2))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - #[test] - fn divide_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[Some(615000), Some(610000), None, Some(620000), Some(615000)], - 14, - 4, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Divide, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - #[test] fn modulus_op() -> Result<()> { let schema = Arc::new(Schema::new(vec![ @@ -2282,81 +1878,6 @@ mod tests { Ok(()) } - #[test] - fn modules_op_dict_scalar() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Int32)), - true, - )]); - - let mut dict_builder = PrimitiveDictionaryBuilder::::new(); - - dict_builder.append(1)?; - dict_builder.append_null(); - dict_builder.append(2)?; - dict_builder.append(5)?; - - let a = dict_builder.finish(); - - let expected: PrimitiveArray = - PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Modulo, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Int32(Some(2))), - ), - Arc::new(expected), - )?; - - Ok(()) - } - - 
#[test] - fn modulus_op_dict_scalar_decimal() -> Result<()> { - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(10, 0)), - ), - true, - )]); - - let value = 123; - let decimal_array = Arc::new(create_decimal_array( - &[Some(value), None, Some(value - 1), Some(value + 1)], - 10, - 0, - )); - - let keys = Int8Array::from(vec![0, 2, 1, 3, 0]); - let a = DictionaryArray::try_new(keys, decimal_array)?; - - let decimal_array = Arc::new(create_decimal_array( - &[Some(1), Some(0), None, Some(0), Some(1)], - 10, - 0, - )); - - apply_arithmetic_scalar( - Arc::new(schema), - vec![Arc::new(a)], - Operator::Modulo, - ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(2), 10, 0)), - ), - decimal_array, - )?; - - Ok(()) - } - fn apply_arithmetic( schema: SchemaRef, data: Vec, @@ -3044,97 +2565,6 @@ mod tests { .unwrap() } - #[test] - fn comparison_dict_decimal_scalar_expr_test() -> Result<()> { - // scalar of decimal compare with dictionary decimal array - let value_i128 = 123; - let decimal_scalar = ScalarValue::Dictionary( - Box::new(DataType::Int8), - Box::new(ScalarValue::Decimal128(Some(value_i128), 25, 3)), - ); - let schema = Arc::new(Schema::new(vec![Field::new( - "a", - DataType::Dictionary( - Box::new(DataType::Int8), - Box::new(DataType::Decimal128(25, 3)), - ), - true, - )])); - let decimal_array = Arc::new(create_decimal_array( - &[ - Some(value_i128), - None, - Some(value_i128 - 1), - Some(value_i128 + 1), - ], - 25, - 3, - )); - - let keys = Int8Array::from(vec![Some(0), None, Some(2), Some(3)]); - let dictionary = - Arc::new(DictionaryArray::try_new(keys, decimal_array)?) 
as ArrayRef; - - // array = scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::Eq, - &BooleanArray::from(vec![Some(true), None, Some(false), Some(false)]), - ) - .unwrap(); - // array != scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::NotEq, - &BooleanArray::from(vec![Some(false), None, Some(true), Some(true)]), - ) - .unwrap(); - // array < scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::Lt, - &BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]), - ) - .unwrap(); - - // array <= scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::LtEq, - &BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]), - ) - .unwrap(); - // array > scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::Gt, - &BooleanArray::from(vec![Some(false), None, Some(false), Some(true)]), - ) - .unwrap(); - - // array >= scalar - apply_logic_op_arr_scalar( - &schema, - &dictionary, - &decimal_scalar, - Operator::GtEq, - &BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]), - ) - .unwrap(); - - Ok(()) - } - #[test] fn comparison_decimal_expr_test() -> Result<()> { // scalar of decimal compare with decimal array diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index dfc70551ccf6..33ca659987a3 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -222,7 +222,6 @@ fn evaluate_list( exec_err!("InList expression must evaluate to a scalar") } // Flatten dictionary values - ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v), ColumnarValue::Scalar(s) => Ok(s), }) }) @@ -1333,96 +1332,4 @@ mod tests { Ok(()) } - - #[test] - fn in_list_utf8_with_dict_types() -> Result<()> { - fn dict_lit(key_type: DataType, 
value: &str) -> Arc { - lit(ScalarValue::Dictionary( - Box::new(key_type), - Box::new(ScalarValue::new_utf8(value.to_string())), - )) - } - - fn null_dict_lit(key_type: DataType) -> Arc { - lit(ScalarValue::Dictionary( - Box::new(key_type), - Box::new(ScalarValue::Utf8(None)), - )) - } - - let schema = Schema::new(vec![Field::new( - "a", - DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Utf8)), - true, - )]); - let a: UInt16DictionaryArray = - vec![Some("a"), Some("d"), None].into_iter().collect(); - let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; - - // expression: "a in ("a", "b")" - let lists = [ - vec![lit("a"), lit("b")], - vec![ - dict_lit(DataType::Int8, "a"), - dict_lit(DataType::UInt16, "b"), - ], - ]; - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), Some(false), None], - Arc::clone(&col_a), - &schema - ); - } - - // expression: "a not in ("a", "b")" - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &true, - vec![Some(false), Some(true), None], - Arc::clone(&col_a), - &schema - ); - } - - // expression: "a in ("a", "b", null)" - let lists = [ - vec![lit("a"), lit("b"), lit(ScalarValue::Utf8(None))], - vec![ - dict_lit(DataType::Int8, "a"), - dict_lit(DataType::UInt16, "b"), - null_dict_lit(DataType::UInt16), - ], - ]; - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &false, - vec![Some(true), None, None], - Arc::clone(&col_a), - &schema - ); - } - - // expression: "a not in ("a", "b", null)" - for list in lists.iter() { - in_list_raw!( - batch, - list.clone(), - &true, - vec![Some(false), None, None], - Arc::clone(&col_a), - &schema - ); - } - - Ok(()) - } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 4c37db4849a7..195a32042e6f 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ 
b/datafusion/physical-expr/src/utils/mod.rs @@ -503,10 +503,7 @@ pub(crate) mod tests { let schema_big = Arc::new(Schema::new(vec![int_field, dict_field])); let pred = in_list( Arc::new(Column::new_with_schema("id", &schema_big).unwrap()), - vec![lit(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("2")), - ))], + vec![lit(ScalarValue::from("2"))], &false, &schema_big, ) @@ -516,10 +513,7 @@ pub(crate) mod tests { let expected = in_list( Arc::new(Column::new_with_schema("id", &schema_small).unwrap()), - vec![lit(ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("2")), - ))], + vec![lit(ScalarValue::from("2"))], &false, &schema_small, ) diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index feb4c11aa809..b07a31071034 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -368,8 +368,8 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Ok(match value { Value::BoolValue(v) => Self::Boolean(Some(*v)), Value::Utf8Value(v) => Self::Utf8(Some(v.to_owned())), - Value::Utf8ViewValue(v) => Self::Utf8View(Some(v.to_owned())), - Value::LargeUtf8Value(v) => Self::LargeUtf8(Some(v.to_owned())), + Value::Utf8ViewValue(v) => Self::Utf8(Some(v.to_owned())), + Value::LargeUtf8Value(v) => Self::Utf8(Some(v.to_owned())), Value::Int8Value(v) => Self::Int8(Some(*v as i8)), Value::Int16Value(v) => Self::Int16(Some(*v as i16)), Value::Int32Value(v) => Self::Int32(Some(*v)), @@ -564,25 +564,15 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } } } - Value::DictionaryValue(v) => { - let index_type: DataType = v - .index_type - .as_ref() - .ok_or_else(|| Error::required("index_type"))? - .try_into()?; - - let value: Self = v - .value - .as_ref() - .ok_or_else(|| Error::required("value"))? 
- .as_ref() - .try_into()?; - - Self::Dictionary(Box::new(index_type), Box::new(value)) - } + Value::DictionaryValue(v) => v + .value + .as_ref() + .ok_or_else(|| Error::required("value"))? + .as_ref() + .try_into()?, Value::BinaryValue(v) => Self::Binary(Some(v.clone())), - Value::BinaryViewValue(v) => Self::BinaryView(Some(v.clone())), - Value::LargeBinaryValue(v) => Self::LargeBinary(Some(v.clone())), + Value::BinaryViewValue(v) => Self::Binary(Some(v.clone())), + Value::LargeBinaryValue(v) => Self::Binary(Some(v.clone())), Value::IntervalDaytimeValue(v) => Self::IntervalDayTime(Some( IntervalDayTimeType::make_value(v.days, v.milliseconds), )), diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index 4cf7e73ac912..e09eacfb9709 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -345,16 +345,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::Utf8Value(s.to_owned()) }) } - ScalarValue::LargeUtf8(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::LargeUtf8Value(s.to_owned()) - }) - } - ScalarValue::Utf8View(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::Utf8ViewValue(s.to_owned()) - }) - } ScalarValue::List(arr) => { encode_scalar_nested_value(arr.to_owned() as ArrayRef, val) } @@ -472,16 +462,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::BinaryValue(s.to_owned()) }) } - ScalarValue::BinaryView(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::BinaryViewValue(s.to_owned()) - }) - } - ScalarValue::LargeBinary(val) => { - create_proto_scalar(val.as_ref(), &data_type, |s| { - Value::LargeBinaryValue(s.to_owned()) - }) - } ScalarValue::FixedSizeBinary(length, val) => { create_proto_scalar(val.as_ref(), &data_type, |s| { Value::FixedSizeBinaryValue(protobuf::ScalarFixedSizeBinary { @@ -622,18 +602,6 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { 
let val = protobuf::ScalarValue { value: Some(val) }; Ok(val) } - - ScalarValue::Dictionary(index_type, val) => { - let value: protobuf::ScalarValue = val.as_ref().try_into()?; - Ok(protobuf::ScalarValue { - value: Some(Value::DictionaryValue(Box::new( - protobuf::ScalarDictionaryValue { - index_type: Some(index_type.as_ref().try_into()?), - value: Some(Box::new(value)), - }, - ))), - }) - } } } } diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index eb7cc5c4b9c5..6d8ce13b4f89 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -1190,7 +1190,6 @@ fn round_trip_scalar_values() { ScalarValue::UInt32(None), ScalarValue::UInt64(None), ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), ScalarValue::List(ScalarValue::new_list_nullable(&[], &DataType::Boolean)), ScalarValue::LargeList(ScalarValue::new_large_list(&[], &DataType::Boolean)), ScalarValue::Date32(None), @@ -1229,7 +1228,6 @@ fn round_trip_scalar_values() { ScalarValue::UInt64(Some(u64::MAX)), ScalarValue::UInt64(Some(0)), ScalarValue::Utf8(Some(String::from("Test string "))), - ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), ScalarValue::Date32(Some(0)), ScalarValue::Date32(Some(i32::MAX)), ScalarValue::Date32(None), @@ -1349,18 +1347,8 @@ fn round_trip_scalar_values() { vec![Some(vec![Some(1), Some(2), Some(3)])], 3, ))), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::from("foo")), - ), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(None)), - ), ScalarValue::Binary(Some(b"bar".to_vec())), ScalarValue::Binary(None), - ScalarValue::LargeBinary(Some(b"bar".to_vec())), - ScalarValue::LargeBinary(None), ScalarStructBuilder::new() .with_scalar( Field::new("a", DataType::Int32, true), @@ -1381,20 +1369,6 @@ fn round_trip_scalar_values() { Field::new("b", 
DataType::Boolean, false), ScalarValue::from(false), ) - .with_scalar( - Field::new( - "c", - DataType::Dictionary( - Box::new(DataType::UInt16), - Box::new(DataType::Utf8), - ), - false, - ), - ScalarValue::Dictionary( - Box::new(DataType::UInt16), - Box::new("value".into()), - ), - ) .build() .unwrap(), ScalarValue::try_from(&DataType::Struct(Fields::from(vec![ @@ -1737,7 +1711,6 @@ fn roundtrip_null_scalar_values() { ScalarValue::UInt32(None), ScalarValue::UInt64(None), ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), ScalarValue::Date32(None), ScalarValue::TimestampMicrosecond(None, None), ScalarValue::TimestampNanosecond(None, None), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 6766468ef443..b18ade44377d 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -40,8 +40,7 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ - wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, - FileSinkConfig, ParquetExec, + FileScanConfig, FileSinkConfig, ParquetExec, }; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; @@ -671,8 +670,7 @@ fn roundtrip_parquet_exec_with_pruning_predicate() -> Result<()> { async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { let mut file_group = PartitionedFile::new("/path/to/part=0/file.parquet".to_string(), 1024); - file_group.partition_values = - vec![wrap_partition_value_in_dict(ScalarValue::Int64(Some(0)))]; + file_group.partition_values = vec![ScalarValue::Int64(Some(0))]; let schema = Arc::new(Schema::new(vec![Field::new("col", DataType::Utf8, false)])); let scan_config = FileScanConfig { 
@@ -684,7 +682,7 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { limit: None, table_partition_cols: vec![Field::new( "part".to_string(), - wrap_partition_type_in_dict(DataType::Int16), + DataType::Int16, false, )], output_ordering: vec![], diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 39511ea4d03a..dd024b98e6f0 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1009,27 +1009,11 @@ impl Unparser<'_> { ast::Value::SingleQuotedString(str.to_string()), )), ScalarValue::Utf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::Utf8View(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), - ScalarValue::Utf8View(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::LargeUtf8(Some(str)) => Ok(ast::Expr::Value( - ast::Value::SingleQuotedString(str.to_string()), - )), - ScalarValue::LargeUtf8(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Binary(Some(_)) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Binary(None) => Ok(ast::Expr::Value(ast::Value::Null)), - ScalarValue::BinaryView(Some(_)) => { - not_impl_err!("Unsupported scalar: {v:?}") - } - ScalarValue::BinaryView(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::FixedSizeBinary(..) 
=> { not_impl_err!("Unsupported scalar: {v:?}") } - ScalarValue::LargeBinary(Some(_)) => { - not_impl_err!("Unsupported scalar: {v:?}") - } - ScalarValue::LargeBinary(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::FixedSizeList(_a) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::List(_a) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::LargeList(_a) => not_impl_err!("Unsupported scalar: {v:?}"), @@ -1160,7 +1144,6 @@ impl Unparser<'_> { ScalarValue::Struct(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Map(_) => not_impl_err!("Unsupported scalar: {v:?}"), ScalarValue::Union(..) => not_impl_err!("Unsupported scalar: {v:?}"), - ScalarValue::Dictionary(..) => not_impl_err!("Unsupported scalar: {v:?}"), } } diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f2756bb06d1e..c74b0b162ff2 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1696,16 +1696,12 @@ fn from_substrait_literal( Some(LiteralType::Date(d)) => ScalarValue::Date32(Some(*d)), Some(LiteralType::String(s)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Utf8(Some(s.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => ScalarValue::LargeUtf8(Some(s.clone())), others => { return substrait_err!("Unknown type variation reference {others}"); } }, Some(LiteralType::Binary(b)) => match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => ScalarValue::Binary(Some(b.clone())), - LARGE_CONTAINER_TYPE_VARIATION_REF => { - ScalarValue::LargeBinary(Some(b.clone())) - } others => { return substrait_err!("Unknown type variation reference {others}"); } @@ -2052,7 +2048,6 @@ fn from_substrait_null( }, r#type::Kind::Binary(binary) => match binary.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Binary(None)), - 
LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeBinary(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), @@ -2060,7 +2055,6 @@ fn from_substrait_null( // FixedBinary is not supported because `None` doesn't have length r#type::Kind::String(string) => match string.type_variation_reference { DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::Utf8(None)), - LARGE_CONTAINER_TYPE_VARIATION_REF => Ok(ScalarValue::LargeUtf8(None)), v => not_impl_err!( "Unsupported Substrait type variation {v} of type {kind:?}" ), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index ee04749f5e6b..2112d5d2e7f1 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1855,10 +1855,6 @@ fn to_substrait_literal( LiteralType::Binary(b.clone()), DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), - ScalarValue::LargeBinary(Some(b)) => ( - LiteralType::Binary(b.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), ScalarValue::FixedSizeBinary(_, Some(b)) => ( LiteralType::FixedBinary(b.clone()), DEFAULT_TYPE_VARIATION_REF, @@ -1867,10 +1863,6 @@ fn to_substrait_literal( LiteralType::String(s.clone()), DEFAULT_CONTAINER_TYPE_VARIATION_REF, ), - ScalarValue::LargeUtf8(Some(s)) => ( - LiteralType::String(s.clone()), - LARGE_CONTAINER_TYPE_VARIATION_REF, - ), ScalarValue::Decimal128(v, p, s) if v.is_some() => ( LiteralType::Decimal(Decimal { value: v.unwrap().to_le_bytes().to_vec(), From 78dc034d4b7d644e6c34745286e82a6332dd7606 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Wed, 14 Aug 2024 18:40:21 +0200 Subject: [PATCH 02/12] Progress sync for string_view.slt --- datafusion/common/src/lib.rs | 1 + datafusion/common/src/logical/eq.rs | 26 ++++ datafusion/common/src/logical/mod.rs | 1 + datafusion/common/src/scalar/mod.rs | 123 +++++++++++++----- datafusion/core/tests/optimizer/mod.rs | 3 +- 
datafusion/expr-common/src/columnar_value.rs | 21 ++- datafusion/functions/src/unicode/lpad.rs | 4 +- .../simplify_expressions/expr_simplifier.rs | 24 +++- datafusion/physical-plan/src/projection.rs | 5 +- .../sqllogictest/test_files/string_view.slt | 27 ++-- 10 files changed, 184 insertions(+), 51 deletions(-) create mode 100644 datafusion/common/src/logical/eq.rs create mode 100644 datafusion/common/src/logical/mod.rs diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 19af889e426a..4a28667ecfca 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -37,6 +37,7 @@ pub mod file_options; pub mod format; pub mod hash_utils; pub mod instant; +pub mod logical; pub mod parsers; pub mod rounding; pub mod scalar; diff --git a/datafusion/common/src/logical/eq.rs b/datafusion/common/src/logical/eq.rs new file mode 100644 index 000000000000..902f655cd961 --- /dev/null +++ b/datafusion/common/src/logical/eq.rs @@ -0,0 +1,26 @@ +use arrow_schema::DataType; + +pub trait LogicallyEq { + #[must_use] + fn logically_eq(&self, other: &Rhs) -> bool; +} + +impl LogicallyEq for DataType { + fn logically_eq(&self, other: &Self) -> bool { + use DataType::*; + + match (self, other) { + (Utf8 | LargeUtf8 | Utf8View, Utf8 | LargeUtf8 | Utf8View) + | (Binary | LargeBinary | BinaryView, Binary | LargeBinary | BinaryView) => { + true + } + (Dictionary(_, inner), other) | (other, Dictionary(_, inner)) => { + other.logically_eq(inner) + } + (RunEndEncoded(_, inner), other) | (other, RunEndEncoded(_, inner)) => { + other.logically_eq(inner.data_type()) + } + _ => self == other, + } + } +} diff --git a/datafusion/common/src/logical/mod.rs b/datafusion/common/src/logical/mod.rs new file mode 100644 index 000000000000..798e468ebe6d --- /dev/null +++ b/datafusion/common/src/logical/mod.rs @@ -0,0 +1 @@ +pub mod eq; diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index b22bae959961..ad0930457c12 100644 --- 
a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -748,6 +748,40 @@ pub fn get_dict_value( Ok((dict_array.values(), dict_array.key(index))) } +/// Create a dictionary array representing all the values in values +fn dict_from_values( + values_array: ArrayRef, +) -> Result { + // Create a key array with `size` elements of 0..array_len for all + // non-null value elements + let key_array: PrimitiveArray = (0..values_array.len()) + .map(|index| { + if values_array.is_valid(index) { + let native_index = K::Native::from_usize(index).ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not create index of type {} from value {}", + K::DATA_TYPE, + index + )) + })?; + Ok(Some(native_index)) + } else { + Ok(None) + } + }) + .collect::>>()? + .into_iter() + .collect(); + + // create a new DictionaryArray + // + // Note: this path could be made faster by using the ArrayData + // APIs and skipping validation, if it every comes up in + // performance traces. + let dict_array = DictionaryArray::::try_new(key_array, values_array)?; + Ok(Arc::new(dict_array)) +} + macro_rules! typed_cast_tz { ($array:expr, $index:expr, $ARRAYTYPE:ident, $SCALAR:ident, $TZ:expr) => {{ use std::any::type_name; @@ -1545,6 +1579,7 @@ impl ScalarValue { Ok(Scalar::new(self.to_array_of_size(1)?)) } + /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] /// corresponding to those values. For example, an iterator of /// [`ScalarValue::Int32`] would be converted to an [`Int32Array`]. @@ -1596,6 +1631,15 @@ impl ScalarValue { Some(sv) => sv.data_type(), }; + Self::iter_to_array_of_type(scalars.collect(), &data_type) + } + + fn iter_to_array_of_type( + scalars: Vec, + data_type: &DataType, + ) -> Result { + let scalars = scalars.into_iter(); + /// Creates an array of $ARRAY_TY by unpacking values of /// SCALAR_TY for primitive types macro_rules! 
build_array_primitive { @@ -1685,7 +1729,9 @@ impl ScalarValue { DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), DataType::Utf8 => build_array_string!(StringArray, Utf8), + DataType::LargeUtf8 => build_array_string!(LargeStringArray, Utf8), DataType::Binary => build_array_string!(BinaryArray, Binary), + DataType::LargeBinary => build_array_string!(LargeBinaryArray, Binary), DataType::Date32 => build_array_primitive!(Date32Array, Date32), DataType::Date64 => build_array_primitive!(Date64Array, Date64), DataType::Time32(TimeUnit::Second) => { @@ -1758,11 +1804,8 @@ impl ScalarValue { if let Some(DataType::FixedSizeList(f, l)) = first_non_null_data_type { for array in arrays.iter_mut() { if array.is_null(0) { - *array = Arc::new(FixedSizeListArray::new_null( - Arc::clone(&f), - l, - 1, - )); + *array = + Arc::new(FixedSizeListArray::new_null(f.clone(), l, 1)); } } } @@ -1771,13 +1814,28 @@ impl ScalarValue { } DataType::List(_) | DataType::LargeList(_) - | DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) => { let arrays = scalars.map(|s| s.to_array()).collect::>>()?; let arrays = arrays.iter().map(|a| a.as_ref()).collect::>(); arrow::compute::concat(arrays.as_slice())? 
} + DataType::Dictionary(key_type, value_type) => { + let values = Self::iter_to_array(scalars)?; + assert_eq!(values.data_type(), value_type.as_ref()); + + match key_type.as_ref() { + DataType::Int8 => dict_from_values::(values)?, + DataType::Int16 => dict_from_values::(values)?, + DataType::Int32 => dict_from_values::(values)?, + DataType::Int64 => dict_from_values::(values)?, + DataType::UInt8 => dict_from_values::(values)?, + DataType::UInt16 => dict_from_values::(values)?, + DataType::UInt32 => dict_from_values::(values)?, + DataType::UInt64 => dict_from_values::(values)?, + _ => unreachable!("Invalid dictionary keys type: {:?}", key_type), + } + } DataType::FixedSizeBinary(size) => { let array = scalars .map(|sv| { @@ -1806,18 +1864,15 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) + | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) - | DataType::ListView(_) - | DataType::LargeBinary - | DataType::BinaryView - | DataType::LargeUtf8 | DataType::Utf8View - | DataType::Dictionary(_, _) + | DataType::BinaryView + | DataType::ListView(_) | DataType::LargeListView(_) => { return _internal_err!( - "Unsupported creation of {:?} array from ScalarValue {:?}", - data_type, - scalars.peek() + "Unsupported creation of {:?} array", + data_type ); } }; @@ -1940,7 +1995,7 @@ impl ScalarValue { let values = if values.is_empty() { new_empty_array(data_type) } else { - Self::iter_to_array(values.iter().cloned()).unwrap() + Self::iter_to_array_of_type(values.to_vec(), data_type).unwrap() }; Arc::new(array_into_list_array(values, nullable)) } @@ -2931,6 +2986,11 @@ impl ScalarValue { .map(|sv| sv.size() - std::mem::size_of_val(sv)) .sum::() } + + pub fn supported_datatype(data_type: &DataType) -> Result { + let scalar = Self::try_from(data_type)?; + Ok(scalar.data_type()) + } } macro_rules! 
impl_scalar { @@ -5456,22 +5516,23 @@ mod tests { check_scalar_cast(ScalarValue::Float64(None), DataType::Int16); - check_scalar_cast( - ScalarValue::from("foo"), - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - ); - - check_scalar_cast( - ScalarValue::Utf8(None), - DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), - ); - - check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View); - check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View); - check_scalar_cast( - ScalarValue::from("larger than 12 bytes string"), - DataType::Utf8View, - ); + // TODO(@notfilippo): this tests fails but it should check if logically equal + // check_scalar_cast( + // ScalarValue::from("foo"), + // DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + // ); + // + // check_scalar_cast( + // ScalarValue::Utf8(None), + // DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + // ); + // + // check_scalar_cast(ScalarValue::Utf8(None), DataType::Utf8View); + // check_scalar_cast(ScalarValue::from("foo"), DataType::Utf8View); + // check_scalar_cast( + // ScalarValue::from("larger than 12 bytes string"), + // DataType::Utf8View, + // ); } // mimics how casting work on scalar values by `casting` `scalar` to `desired_type` diff --git a/datafusion/core/tests/optimizer/mod.rs b/datafusion/core/tests/optimizer/mod.rs index f17d13a42060..0686b954207f 100644 --- a/datafusion/core/tests/optimizer/mod.rs +++ b/datafusion/core/tests/optimizer/mod.rs @@ -56,7 +56,8 @@ fn init() { #[test] fn select_arrow_cast() { let sql = "SELECT arrow_cast(1234, 'Float64') as f64, arrow_cast('foo', 'LargeUtf8') as large"; - let expected = "Projection: Float64(1234) AS f64, LargeUtf8(\"foo\") AS large\ + let expected = + "Projection: Float64(1234) AS f64, CAST(Utf8(\"foo\") AS LargeUtf8) AS large\ \n EmptyRelation"; quick_test(sql, expected); } diff --git a/datafusion/expr-common/src/columnar_value.rs 
b/datafusion/expr-common/src/columnar_value.rs index bfefb37c98d7..f430c9fa9b7e 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -17,13 +17,14 @@ //! [`ColumnarValue`] represents the result of evaluating an expression. -use arrow::array::ArrayRef; +use arrow::array::{Array, ArrayRef}; use arrow::array::NullArray; use arrow::compute::{kernels, CastOptions}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::format::DEFAULT_CAST_OPTIONS; use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; +use datafusion_common::logical::eq::LogicallyEq; /// The result of evaluating an expression. /// @@ -130,6 +131,20 @@ impl ColumnarValue { }) } + pub fn into_array_of_type(self, num_rows: usize, data_type: &DataType) -> Result { + let array = self.into_array(num_rows)?; + if array.data_type() == data_type { + Ok(array) + } else { + let cast_array = kernels::cast::cast_with_options( + &array, + data_type, + &DEFAULT_CAST_OPTIONS, + )?; + Ok(cast_array) + } + } + /// null columnar values are implemented as a null array in order to pass batch /// num_rows pub fn create_null_array(num_rows: usize) -> Self { @@ -195,6 +210,10 @@ impl ColumnarValue { kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), ColumnarValue::Scalar(scalar) => { + if scalar.data_type().logically_eq(cast_type) { + return Ok(self.clone()) + } + let scalar_array = if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) { if let ScalarValue::Float64(Some(float_ts)) = scalar { diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 8396d6dd40bd..7cca36454514 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -284,8 +284,8 @@ mod tests { use crate::unicode::lpad::LPadFunc; use crate::utils::test::test_function; - use arrow::array::{Array,StringArray}; - use 
arrow::datatypes::DataType::{Utf8}; + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 17ea6235984c..4b37c26c1f7e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -28,6 +28,7 @@ use arrow::{ }; use datafusion_common::cast::as_large_list_array; +use datafusion_common::logical::eq::LogicallyEq; use datafusion_common::{ cast::as_list_array, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, @@ -36,8 +37,8 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarV use datafusion_expr::expr::{InList, InSubquery, WindowFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ - and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, - WindowFunctionDefinition, + and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, ExprSchemable, Like, Operator, + Volatility, WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -628,15 +629,34 @@ impl<'a> ConstEvaluator<'a> { return ConstSimplifyResult::NotSimplified(s); } + let start_type = match expr.get_type(&self.input_schema) { + Ok(t) => t, + Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), + }; + let phys_expr = match create_physical_expr(&expr, &self.input_schema, self.execution_props) { Ok(e) => e, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; + let col_val = match phys_expr.evaluate(&self.input_batch) { Ok(v) => v, Err(err) => return 
ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; + + // TODO(@notfilippo): a fix for the select_arrow_cast error + let end_type = col_val.data_type(); + if end_type.logically_eq(&start_type) && start_type != end_type { + return ConstSimplifyResult::SimplifyRuntimeError( + DataFusionError::Execution(format!( + "Skipping, end_type {} is logically equal to start_type {} but not strictly equal", + end_type, start_type + )), + expr, + ); + } + match col_val { ColumnarValue::Array(a) => { if a.len() != 1 { diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index d2bb8f2b0ead..139e6c8a3018 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -306,9 +306,10 @@ impl ProjectionStream { let arrays = self .expr .iter() - .map(|expr| { + .zip(&self.schema.fields) + .map(|(expr, field)| { expr.evaluate(batch) - .and_then(|v| v.into_array(batch.num_rows())) + .and_then(|v| v.into_array_of_type(batch.num_rows(), field.data_type())) }) .collect::>>()?; diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 0a9b73babb96..6a1a2b927ecf 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -261,7 +261,7 @@ explain SELECT column1_utf8 from test where column1_utf8view = 'Andrew'; ---- logical_plan 01)Projection: test.column1_utf8 -02)--Filter: test.column1_utf8view = Utf8View("Andrew") +02)--Filter: test.column1_utf8view = CAST(Utf8("Andrew") AS Utf8View) 03)----TableScan: test projection=[column1_utf8, column1_utf8view] # reverse order should be the same @@ -270,21 +270,21 @@ explain SELECT column1_utf8 from test where 'Andrew' = column1_utf8view; ---- logical_plan 01)Projection: test.column1_utf8 -02)--Filter: test.column1_utf8view = Utf8View("Andrew") +02)--Filter: CAST(Utf8("Andrew") AS Utf8View) = test.column1_utf8view 03)----TableScan: 
test projection=[column1_utf8, column1_utf8view] query TT explain SELECT column1_utf8 from test where column1_utf8 = arrow_cast('Andrew', 'Utf8View'); ---- logical_plan -01)Filter: test.column1_utf8 = Utf8("Andrew") +01)Filter: CAST(test.column1_utf8 AS Utf8View) = CAST(Utf8("Andrew") AS Utf8View) 02)--TableScan: test projection=[column1_utf8] query TT explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Utf8View') = column1_utf8; ---- logical_plan -01)Filter: test.column1_utf8 = Utf8("Andrew") +01)Filter: CAST(Utf8("Andrew") AS Utf8View) = CAST(test.column1_utf8 AS Utf8View) 02)--TableScan: test projection=[column1_utf8] query TT @@ -292,7 +292,7 @@ explain SELECT column1_utf8 from test where column1_utf8view = arrow_cast('Andre ---- logical_plan 01)Projection: test.column1_utf8 -02)--Filter: test.column1_utf8view = Utf8View("Andrew") +02)--Filter: test.column1_utf8view = CAST(CAST(Utf8("Andrew") AS Dictionary(Int32, Utf8)) AS Utf8View) 03)----TableScan: test projection=[column1_utf8, column1_utf8view] query TT @@ -300,7 +300,7 @@ explain SELECT column1_utf8 from test where arrow_cast('Andrew', 'Dictionary(Int ---- logical_plan 01)Projection: test.column1_utf8 -02)--Filter: test.column1_utf8view = Utf8View("Andrew") +02)--Filter: CAST(CAST(Utf8("Andrew") AS Dictionary(Int32, Utf8)) AS Utf8View) = test.column1_utf8view 03)----TableScan: test projection=[column1_utf8, column1_utf8view] # compare string / stringview @@ -422,8 +422,10 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(test.column1_utf8view, Utf8View("äöüß")) AS c1, starts_with(test.column1_utf8view, Utf8View("")) AS c2, starts_with(test.column1_utf8view, Utf8View(NULL)) AS c3, starts_with(Utf8View(NULL), test.column1_utf8view) AS c4 -02)--TableScan: test projection=[column1_utf8view] +01)Projection: starts_with(test.column1_utf8view, CAST(Utf8("äöüß") AS Utf8View)) AS c1, starts_with(test.column1_utf8view, CAST(Utf8("") AS Utf8View)) AS c2, 
starts_with(test.column1_utf8view, __common_expr_1) AS c3, starts_with(__common_expr_1, test.column1_utf8view) AS c4 +02)--Projection: CAST(NULL AS Utf8View) AS __common_expr_1, test.column1_utf8view +03)----TableScan: test projection=[column1_utf8view] + ### Initcap @@ -481,8 +483,9 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: test.column1_utf8view LIKE Utf8View("foo") AS like, test.column1_utf8view ILIKE Utf8View("foo") AS ilike -02)--TableScan: test projection=[column1_utf8view] +01)Projection: test.column1_utf8view LIKE __common_expr_1 AS like, test.column1_utf8view ILIKE __common_expr_1 AS ilike +02)--Projection: CAST(Utf8("foo") AS Utf8View) AS __common_expr_1, test.column1_utf8view +03)----TableScan: test projection=[column1_utf8view] @@ -580,7 +583,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: btrim(test.column1_utf8view, Utf8View("foo")) AS l +01)Projection: btrim(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test BTRIM with Utf8View bytes longer than 12 @@ -590,7 +593,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: btrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +01)Projection: btrim(test.column1_utf8view, CAST(Utf8("this is longer than 12") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test BTRIM outputs From 1b22fa9b076d4c46e557624da7e61b2c309f6730 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Fri, 16 Aug 2024 14:36:03 +0200 Subject: [PATCH 03/12] Reintroduce Dictionary in partition values --- datafusion/common/src/scalar/mod.rs | 34 ++- datafusion/core/src/datasource/listing/mod.rs | 7 +- .../physical_plan/file_scan_config.rs | 256 ++++++++++++++++-- .../core/src/datasource/physical_plan/mod.rs | 2 +- datafusion/expr-common/src/columnar_value.rs | 12 +- .../simplify_expressions/expr_simplifier.rs | 6 +- datafusion/physical-plan/src/projection.rs | 5 +- 
.../tests/cases/roundtrip_physical_plan.rs | 4 +- 8 files changed, 273 insertions(+), 53 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index ad0930457c12..97300e8dc7a6 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -41,7 +41,7 @@ use crate::hash_utils::create_hashes; use crate::utils::{ array_into_fixed_size_list_array, array_into_large_list_array, array_into_list_array, }; -use arrow::compute::kernels::numeric::*; +use arrow::compute::kernels::{self, numeric::*}; use arrow::util::display::{array_value_to_string, ArrayFormatter, FormatOptions}; use arrow::{ array::*, @@ -1534,6 +1534,10 @@ impl ScalarValue { } } + pub fn to_array_of_type(&self, data_type: &DataType) -> Result { + self.to_array_of_size_and_type(1, data_type) + } + /// Converts a scalar value into an 1-row array. /// /// # Errors @@ -1579,7 +1583,6 @@ impl ScalarValue { Ok(Scalar::new(self.to_array_of_size(1)?)) } - /// Converts an iterator of references [`ScalarValue`] into an [`ArrayRef`] /// corresponding to those values. For example, an iterator of /// [`ScalarValue::Int32`] would be converted to an [`Int32Array`]. @@ -1631,11 +1634,11 @@ impl ScalarValue { Some(sv) => sv.data_type(), }; - Self::iter_to_array_of_type(scalars.collect(), &data_type) + Self::iter_to_array_of_type(scalars, &data_type) } fn iter_to_array_of_type( - scalars: Vec, + scalars: impl IntoIterator, data_type: &DataType, ) -> Result { let scalars = scalars.into_iter(); @@ -1821,7 +1824,7 @@ impl ScalarValue { arrow::compute::concat(arrays.as_slice())? 
} DataType::Dictionary(key_type, value_type) => { - let values = Self::iter_to_array(scalars)?; + let values = Self::iter_to_array_of_type(scalars, value_type)?; assert_eq!(values.data_type(), value_type.as_ref()); match key_type.as_ref() { @@ -1870,10 +1873,7 @@ impl ScalarValue { | DataType::BinaryView | DataType::ListView(_) | DataType::LargeListView(_) => { - return _internal_err!( - "Unsupported creation of {:?} array", - data_type - ); + return _internal_err!("Unsupported creation of {:?} array", data_type); } }; Ok(array) @@ -2093,6 +2093,17 @@ impl ScalarValue { Arc::new(array_into_large_list_array(values, true)) } + pub fn to_array_of_size_and_type( + &self, + size: usize, + data_type: &DataType, + ) -> Result { + // TODO(@notfilippo): for now cast as it's a POC, but it can be optimized later with a bit `match` + let array = self.to_array_of_size(size)?; + let cast_array = kernels::cast::cast(&array, data_type)?; + Ok(cast_array) + } + /// Converts a scalar value into an array of `size` rows. /// /// # Errors @@ -2986,11 +2997,6 @@ impl ScalarValue { .map(|sv| sv.size() - std::mem::size_of_val(sv)) .sum::() } - - pub fn supported_datatype(data_type: &DataType) -> Result { - let scalar = Self::try_from(data_type)?; - Ok(scalar.data_type()) - } } macro_rules! impl_scalar { diff --git a/datafusion/core/src/datasource/listing/mod.rs b/datafusion/core/src/datasource/listing/mod.rs index c5a441aacf1d..98e0dc48d034 100644 --- a/datafusion/core/src/datasource/listing/mod.rs +++ b/datafusion/core/src/datasource/listing/mod.rs @@ -63,13 +63,8 @@ pub struct PartitionedFile { pub object_meta: ObjectMeta, /// Values of partition columns to be appended to each row. /// - /// These MUST have the same count, order, and type than the [`table_partition_cols`]. + /// These MUST have the same count, order, than the [`table_partition_cols`]. 
/// - /// You may use [`wrap_partition_value_in_dict`] to wrap them if you have used [`wrap_partition_type_in_dict`] to wrap the column type. - /// - /// - /// [`wrap_partition_type_in_dict`]: crate::datasource::physical_plan::wrap_partition_type_in_dict - /// [`wrap_partition_value_in_dict`]: crate::datasource::physical_plan::wrap_partition_value_in_dict /// [`table_partition_cols`]: table::ListingOptions::table_partition_cols pub partition_values: Vec, /// An optional file range for a more fine-grained parallel execution diff --git a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs index 363b10fb0039..86f81b15d80f 100644 --- a/datafusion/core/src/datasource/physical_plan/file_scan_config.rs +++ b/datafusion/core/src/datasource/physical_plan/file_scan_config.rs @@ -18,7 +18,9 @@ //! [`FileScanConfig`] to configure scanning of possibly partitioned //! file sources. -use std::{collections::HashMap, sync::Arc, vec}; +use std::{ + borrow::Cow, collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc, vec, +}; use super::{ get_projected_output_ordering, statistics::MinMaxStatistics, FileGroupPartitioner, @@ -26,12 +28,29 @@ use super::{ use crate::datasource::{listing::PartitionedFile, object_store::ObjectStoreUrl}; use crate::{error::Result, scalar::ScalarValue}; -use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions}; -use arrow_schema::{Field, Schema, SchemaRef}; +use arrow::array::{ArrayData, BufferBuilder}; +use arrow::buffer::Buffer; +use arrow::datatypes::{ArrowNativeType, UInt16Type}; +use arrow_array::{ArrayRef, DictionaryArray, RecordBatch, RecordBatchOptions}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use datafusion_common::stats::Precision; use datafusion_common::{exec_err, ColumnStatistics, DataFusionError, Statistics}; use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr}; +/// Convert type to a type suitable for use as a 
[`ListingTable`] +/// partition column. Returns `Dictionary(UInt16, val_type)`, which is +/// a reasonable trade off between a reasonable number of partition +/// values and space efficiency. +/// +/// Use this to specify types for partition columns. However +/// you MAY also choose not to dictionary-encode the data or to use a +/// different dictionary type. +/// +/// [`ListingTable`]: crate::datasource::listing::ListingTable +pub fn wrap_partition_type_in_dict(val_type: DataType) -> DataType { + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(val_type)) +} + /// The base configurations to provide when creating a physical plan for /// any given file format. /// @@ -352,6 +371,10 @@ impl FileScanConfig { /// have all their keys equal to 0. This enables us to re-use the same "all-zero" buffer across batches, /// which makes the space consumption of the partition columns O(batch_size) instead of O(record_count). pub struct PartitionColumnProjector { + /// An Arrow buffer initialized to zeros that represents the key array of all partition + /// columns (partition columns are materialized by dictionary arrays with only one + /// value in the dictionary, thus all the keys are equal to zero). + key_buffer_cache: ZeroBufferGenerators, /// Mapping between the indexes in the list of partition columns and the target /// schema. Sorted by index in the target schema so that we can iterate on it to /// insert the partition columns in the target record batch. @@ -377,6 +400,7 @@ impl PartitionColumnProjector { Self { projected_partition_indexes, + key_buffer_cache: Default::default(), projected_schema, } } @@ -410,7 +434,20 @@ impl PartitionColumnProjector { "Invalid partitioning found on disk".to_string(), ))?; - cols.insert(sidx, create_output_array(p_value, file_batch.num_rows())?) 
+ let partition_value = Cow::Borrowed(p_value); + + let field = self.projected_schema.field(sidx); + let expected_data_type = field.data_type(); + + cols.insert( + sidx, + create_output_array( + &mut self.key_buffer_cache, + partition_value.as_ref(), + expected_data_type, + file_batch.num_rows(), + )?, + ) } RecordBatch::try_new_with_options( @@ -422,17 +459,165 @@ impl PartitionColumnProjector { } } -fn create_output_array(val: &ScalarValue, len: usize) -> Result { - // TODO(@notfilippo): should we reintroduce a way to encode as dictionaries? - val.to_array_of_size(len) +#[derive(Debug, Default)] +struct ZeroBufferGenerators { + gen_i8: ZeroBufferGenerator, + gen_i16: ZeroBufferGenerator, + gen_i32: ZeroBufferGenerator, + gen_i64: ZeroBufferGenerator, + gen_u8: ZeroBufferGenerator, + gen_u16: ZeroBufferGenerator, + gen_u32: ZeroBufferGenerator, + gen_u64: ZeroBufferGenerator, +} + +/// Generate an Arrow [`Buffer`] that contains zero values. +#[derive(Debug, Default)] +struct ZeroBufferGenerator +where + T: ArrowNativeType, +{ + cache: Option, + _t: PhantomData, +} + +impl ZeroBufferGenerator +where + T: ArrowNativeType, +{ + const SIZE: usize = std::mem::size_of::(); + + fn get_buffer(&mut self, n_vals: usize) -> Buffer { + match &mut self.cache { + Some(buf) if buf.len() >= n_vals * Self::SIZE => { + buf.slice_with_length(0, n_vals * Self::SIZE) + } + _ => { + let mut key_buffer_builder = BufferBuilder::::new(n_vals); + key_buffer_builder.advance(n_vals); // keys are all 0 + self.cache.insert(key_buffer_builder.finish()).clone() + } + } + } +} + +fn create_dict_array( + buffer_gen: &mut ZeroBufferGenerator, + val: &ScalarValue, + len: usize, + dict_type: &DataType, + inner_type: &DataType, +) -> Result +where + T: ArrowNativeType, +{ + let dict_vals = val.to_array_of_type(inner_type)?; + + let sliced_key_buffer = buffer_gen.get_buffer(len); + + // assemble pieces together + let mut builder = ArrayData::builder(dict_type.clone()) + .len(len) + 
.add_buffer(sliced_key_buffer); + builder = builder.add_child_data(dict_vals.to_data()); + Ok(Arc::new(DictionaryArray::::from( + builder.build().unwrap(), + ))) +} + +fn create_output_array( + key_buffer_cache: &mut ZeroBufferGenerators, + val: &ScalarValue, + data_type: &DataType, + len: usize, +) -> Result { + if let DataType::Dictionary(key_type, inner_type) = data_type { + match key_type.as_ref() { + DataType::Int8 => { + return create_dict_array( + &mut key_buffer_cache.gen_i8, + val, + len, + data_type, + inner_type, + ); + } + DataType::Int16 => { + return create_dict_array( + &mut key_buffer_cache.gen_i16, + val, + len, + data_type, + inner_type, + ); + } + DataType::Int32 => { + return create_dict_array( + &mut key_buffer_cache.gen_i32, + val, + len, + data_type, + inner_type, + ); + } + DataType::Int64 => { + return create_dict_array( + &mut key_buffer_cache.gen_i64, + val, + len, + data_type, + inner_type, + ); + } + DataType::UInt8 => { + return create_dict_array( + &mut key_buffer_cache.gen_u8, + val, + len, + data_type, + inner_type, + ); + } + DataType::UInt16 => { + return create_dict_array( + &mut key_buffer_cache.gen_u16, + val, + len, + data_type, + inner_type, + ); + } + DataType::UInt32 => { + return create_dict_array( + &mut key_buffer_cache.gen_u32, + val, + len, + data_type, + inner_type, + ); + } + DataType::UInt64 => { + return create_dict_array( + &mut key_buffer_cache.gen_u64, + val, + len, + data_type, + inner_type, + ); + } + _ => {} + } + } + + val.to_array_of_size_and_type(len, data_type) } #[cfg(test)] mod tests { + use arrow_array::Int32Array; + use super::*; use crate::{test::columns, test_util::aggr_test_schema}; - use arrow_array::Int32Array; - use arrow_schema::DataType; #[test] fn physical_plan_config_no_projection() { @@ -441,7 +626,10 @@ mod tests { Arc::clone(&file_schema), None, Statistics::new_unknown(&file_schema), - to_partition_cols(vec![("date".to_owned(), DataType::Utf8)]), + to_partition_cols(vec![( + 
"date".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + )]), ); let (proj_schema, proj_statistics, _) = conf.project(); @@ -470,9 +658,11 @@ mod tests { // make a table_partition_col as a field let table_partition_col = - Field::new("date", DataType::Utf8, true).with_metadata(HashMap::from_iter( - vec![("key_whatever".to_owned(), "value_whatever".to_owned())], - )); + Field::new("date", wrap_partition_type_in_dict(DataType::Utf8), true) + .with_metadata(HashMap::from_iter(vec![( + "key_whatever".to_owned(), + "value_whatever".to_owned(), + )])); let conf = config_for_projection( Arc::clone(&file_schema), @@ -509,7 +699,10 @@ mod tests { .collect(), total_byte_size: Precision::Absent, }, - to_partition_cols(vec![("date".to_owned(), DataType::Utf8)]), + to_partition_cols(vec![( + "date".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + )]), ); let (proj_schema, proj_statistics, _) = conf.project(); @@ -538,9 +731,18 @@ mod tests { ("c", &vec![10, 11, 12]), ); let partition_cols = vec![ - ("year".to_owned(), DataType::Utf8), - ("month".to_owned(), DataType::Utf8), - ("day".to_owned(), DataType::Utf8), + ( + "year".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "month".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "day".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), ]; // create a projected schema let conf = config_for_projection( @@ -680,8 +882,14 @@ mod tests { fn test_projected_file_schema_with_partition_col() { let schema = aggr_test_schema(); let partition_cols = vec![ - ("part1".to_owned(), DataType::Utf8), - ("part2".to_owned(), DataType::Utf8), + ( + "part1".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "part2".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), ]; // Projected file schema for config with projection including partition column @@ -707,8 +915,14 @@ mod tests { fn test_projected_file_schema_without_projection() { let schema = 
aggr_test_schema(); let partition_cols = vec![ - ("part1".to_owned(), DataType::Utf8), - ("part2".to_owned(), DataType::Utf8), + ( + "part1".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), + ( + "part2".to_owned(), + wrap_partition_type_in_dict(DataType::Utf8), + ), ]; // Projected file schema for config without projection diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index a6f0d6e1c5a0..940c1451abc0 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -37,7 +37,7 @@ pub use arrow_file::ArrowExec; pub use avro::AvroExec; pub use csv::{CsvConfig, CsvExec, CsvExecBuilder, CsvOpener}; pub use file_groups::FileGroupPartitioner; -pub use file_scan_config::FileScanConfig; +pub use file_scan_config::{wrap_partition_type_in_dict, FileScanConfig}; pub use file_stream::{FileOpenFuture, FileOpener, FileStream, OnError}; pub use json::{JsonOpener, NdJsonExec}; diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index f430c9fa9b7e..2c3cf03d61d6 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -17,14 +17,14 @@ //! [`ColumnarValue`] represents the result of evaluating an expression. -use arrow::array::{Array, ArrayRef}; use arrow::array::NullArray; +use arrow::array::{Array, ArrayRef}; use arrow::compute::{kernels, CastOptions}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::format::DEFAULT_CAST_OPTIONS; +use datafusion_common::logical::eq::LogicallyEq; use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; -use datafusion_common::logical::eq::LogicallyEq; /// The result of evaluating an expression. 
/// @@ -131,7 +131,11 @@ impl ColumnarValue { }) } - pub fn into_array_of_type(self, num_rows: usize, data_type: &DataType) -> Result { + pub fn into_array_of_type( + self, + num_rows: usize, + data_type: &DataType, + ) -> Result { let array = self.into_array(num_rows)?; if array.data_type() == data_type { Ok(array) @@ -211,7 +215,7 @@ impl ColumnarValue { )), ColumnarValue::Scalar(scalar) => { if scalar.data_type().logically_eq(cast_type) { - return Ok(self.clone()) + return Ok(self.clone()); } let scalar_array = diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 4b37c26c1f7e..559659fdbf24 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,8 +27,8 @@ use arrow::{ record_batch::RecordBatch, }; -use datafusion_common::cast::as_large_list_array; use datafusion_common::logical::eq::LogicallyEq; +use datafusion_common::{cast::as_large_list_array, exec_datafusion_err}; use datafusion_common::{ cast::as_list_array, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, @@ -649,10 +649,10 @@ impl<'a> ConstEvaluator<'a> { let end_type = col_val.data_type(); if end_type.logically_eq(&start_type) && start_type != end_type { return ConstSimplifyResult::SimplifyRuntimeError( - DataFusionError::Execution(format!( + exec_datafusion_err!( "Skipping, end_type {} is logically equal to start_type {} but not strictly equal", end_type, start_type - )), + ), expr, ); } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 139e6c8a3018..c99a55259306 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -308,8 +308,9 @@ impl ProjectionStream { .iter() .zip(&self.schema.fields) .map(|(expr, field)| { - expr.evaluate(batch) - .and_then(|v| 
v.into_array_of_type(batch.num_rows(), field.data_type())) + expr.evaluate(batch).and_then(|v| { + v.into_array_of_type(batch.num_rows(), field.data_type()) + }) }) .collect::>>()?; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index b18ade44377d..1f358fcc627c 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -40,7 +40,7 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::datasource::listing::{ListingTableUrl, PartitionedFile}; use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::datasource::physical_plan::{ - FileScanConfig, FileSinkConfig, ParquetExec, + wrap_partition_type_in_dict, FileScanConfig, FileSinkConfig, ParquetExec, }; use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::sum::sum_udaf; @@ -682,7 +682,7 @@ async fn roundtrip_parquet_exec_with_table_partition_cols() -> Result<()> { limit: None, table_partition_cols: vec![Field::new( "part".to_string(), - DataType::Int16, + wrap_partition_type_in_dict(DataType::Int16), false, )], output_ordering: vec![], From 58ce569876ac8ce44d15b4b46b5c73afa7d9bfd9 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Fri, 16 Aug 2024 14:42:42 +0200 Subject: [PATCH 04/12] Add peekable --- datafusion/common/src/logical/eq.rs | 17 +++++++++++++++++ datafusion/common/src/logical/mod.rs | 17 +++++++++++++++++ datafusion/common/src/scalar/mod.rs | 2 +- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/datafusion/common/src/logical/eq.rs b/datafusion/common/src/logical/eq.rs index 902f655cd961..c8d60736fe95 100644 --- a/datafusion/common/src/logical/eq.rs +++ b/datafusion/common/src/logical/eq.rs @@ -1,3 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use arrow_schema::DataType; pub trait LogicallyEq { diff --git a/datafusion/common/src/logical/mod.rs b/datafusion/common/src/logical/mod.rs index 798e468ebe6d..ff72c3dd28d2 100644 --- a/datafusion/common/src/logical/mod.rs +++ b/datafusion/common/src/logical/mod.rs @@ -1 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ pub mod eq; diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8e5afce2a6e8..38aa675cfabc 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1639,7 +1639,7 @@ impl ScalarValue { scalars: impl IntoIterator, data_type: &DataType, ) -> Result { - let scalars = scalars.into_iter(); + let mut scalars = scalars.into_iter().peekable(); /// Creates an array of $ARRAY_TY by unpacking values of /// SCALAR_TY for primitive types From 349a04e778c7945488413ee87739e24aea5498d4 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Fri, 16 Aug 2024 15:12:24 +0200 Subject: [PATCH 05/12] Fix impl IntoIterator with Peekable --- datafusion/common/src/scalar/mod.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 38aa675cfabc..8d2b25021c7e 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -27,7 +27,7 @@ use std::convert::Infallible; use std::fmt; use std::hash::Hash; use std::hash::Hasher; -use std::iter::repeat; +use std::iter::{repeat, Peekable}; use std::str::FromStr; use std::sync::Arc; @@ -1632,15 +1632,21 @@ impl ScalarValue { Some(sv) => sv.data_type(), }; - Self::iter_to_array_of_type(scalars, &data_type) + Self::iter_to_array_of_type_internal(&mut scalars, &data_type) } - fn iter_to_array_of_type( + pub fn iter_to_array_of_type( scalars: impl IntoIterator, data_type: &DataType, ) -> Result { let mut scalars = scalars.into_iter().peekable(); + Self::iter_to_array_of_type_internal(&mut scalars, data_type) + } + fn iter_to_array_of_type_internal( + scalars: &mut Peekable>, + data_type: &DataType, + ) -> Result { /// Creates an array of $ARRAY_TY by unpacking values of /// SCALAR_TY for primitive types macro_rules! 
build_array_primitive { @@ -1729,8 +1735,10 @@ impl ScalarValue { DataType::UInt16 => build_array_primitive!(UInt16Array, UInt16), DataType::UInt32 => build_array_primitive!(UInt32Array, UInt32), DataType::UInt64 => build_array_primitive!(UInt64Array, UInt64), + DataType::Utf8View => build_array_string!(StringViewArray, Utf8), DataType::Utf8 => build_array_string!(StringArray, Utf8), DataType::LargeUtf8 => build_array_string!(LargeStringArray, Utf8), + DataType::BinaryView => build_array_string!(BinaryViewArray, Binary), DataType::Binary => build_array_string!(BinaryArray, Binary), DataType::LargeBinary => build_array_string!(LargeBinaryArray, Binary), DataType::Date32 => build_array_primitive!(Date32Array, Date32), @@ -1815,6 +1823,7 @@ impl ScalarValue { } DataType::List(_) | DataType::LargeList(_) + | DataType::Map(_, _) | DataType::Struct(_) | DataType::Union(_, _) => { let arrays = scalars.map(|s| s.to_array()).collect::>>()?; @@ -1822,7 +1831,7 @@ impl ScalarValue { arrow::compute::concat(arrays.as_slice())? 
} DataType::Dictionary(key_type, value_type) => { - let values = Self::iter_to_array_of_type(scalars, value_type)?; + let values = Self::iter_to_array_of_type_internal(scalars, value_type)?; assert_eq!(values.data_type(), value_type.as_ref()); match key_type.as_ref() { @@ -1865,10 +1874,7 @@ impl ScalarValue { | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) - | DataType::Map(_, _) | DataType::RunEndEncoded(_, _) - | DataType::Utf8View - | DataType::BinaryView | DataType::ListView(_) | DataType::LargeListView(_) => { return _not_impl_err!( From 301f031d2851042f37c16db00a35d756152d9e85 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Fri, 16 Aug 2024 17:27:11 +0200 Subject: [PATCH 06/12] Run rust_lint --- datafusion/expr-common/src/columnar_value.rs | 5 -- datafusion/functions/benches/repeat.rs | 12 ++-- datafusion/functions/src/string/repeat.rs | 34 ---------- datafusion/functions/src/unicode/reverse.rs | 22 +------ datafusion/functions/src/unicode/substr.rs | 65 -------------------- 5 files changed, 8 insertions(+), 130 deletions(-) diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 2c3cf03d61d6..78b039f8f6fa 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -22,7 +22,6 @@ use arrow::array::{Array, ArrayRef}; use arrow::compute::{kernels, CastOptions}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::format::DEFAULT_CAST_OPTIONS; -use datafusion_common::logical::eq::LogicallyEq; use datafusion_common::{internal_err, Result, ScalarValue}; use std::sync::Arc; @@ -214,10 +213,6 @@ impl ColumnarValue { kernels::cast::cast_with_options(array, cast_type, &cast_options)?, )), ColumnarValue::Scalar(scalar) => { - if scalar.data_type().logically_eq(cast_type) { - return Ok(self.clone()); - } - let scalar_array = if cast_type == 
&DataType::Timestamp(TimeUnit::Nanosecond, None) { if let ScalarValue::Float64(Some(float_ts)) = scalar { diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 916c8374e5fb..e45313660ea2 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -67,7 +67,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 32, repeat_times, true); group.bench_function( - &format!( + format!( "repeat_string_view [size={}, repeat_times={}]", size, repeat_times ), @@ -76,7 +76,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 32, repeat_times, false); group.bench_function( - &format!( + format!( "repeat_string [size={}, repeat_times={}]", size, repeat_times ), @@ -85,7 +85,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 32, repeat_times, false); group.bench_function( - &format!( + format!( "repeat_large_string [size={}, repeat_times={}]", size, repeat_times ), @@ -103,7 +103,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 32, repeat_times, true); group.bench_function( - &format!( + format!( "repeat_string_view [size={}, repeat_times={}]", size, repeat_times ), @@ -112,7 +112,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 32, repeat_times, false); group.bench_function( - &format!( + format!( "repeat_string [size={}, repeat_times={}]", size, repeat_times ), @@ -121,7 +121,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args::(size, 32, repeat_times, false); group.bench_function( - &format!( + format!( "repeat_large_string [size={}, repeat_times={}]", size, repeat_times ), diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index 20e4462784b8..f04e4ce87546 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -178,40 +178,6 @@ mod 
tests { StringArray ); - test_function!( - RepeatFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ], - Ok(Some("PgPgPgPg")), - &str, - Utf8, - StringArray - ); - test_function!( - RepeatFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(None)), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - RepeatFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - Ok(()) } } diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index da16d3ee3752..ef290a9b3970 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -104,8 +104,8 @@ fn reverse_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor>( #[cfg(test)] mod tests { - use arrow::array::{Array, LargeStringArray, StringArray}; - use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use arrow::array::{Array, StringArray}; + use arrow::datatypes::DataType::Utf8; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -123,24 +123,6 @@ mod tests { Utf8, StringArray ); - - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))], - $EXPECTED, - &str, - LargeUtf8, - LargeStringArray - ); - - test_function!( - ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))], - $EXPECTED, - &str, - Utf8, - StringArray - ); }; } diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 9fd8c75eab23..e756d4b1af7d 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -181,71 +181,6 @@ mod tests { #[test] fn 
test_functions() -> Result<()> { - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(None)), - ColumnarValue::Scalar(ScalarValue::from(1i64)), - ], - Ok(None), - &str, - Utf8, - StringArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "alphabet" - )))), - ColumnarValue::Scalar(ScalarValue::from(0i64)), - ], - Ok(Some("alphabet")), - &str, - Utf8, - StringArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "joséésoj" - )))), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ], - Ok(Some("ésoj")), - &str, - Utf8, - StringArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "alphabet" - )))), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), - ], - Ok(Some("ph")), - &str, - Utf8, - StringArray - ); - test_function!( - SubstrFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "alphabet" - )))), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::from(20i64)), - ], - Ok(Some("phabet")), - &str, - Utf8, - StringArray - ); test_function!( SubstrFunc::new(), &[ From b55318c3c7f054f8ce968a7f0b43cdb5d6b3adb3 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Mon, 19 Aug 2024 23:42:21 +0200 Subject: [PATCH 07/12] Phyisical Scalar --- datafusion-examples/examples/advanced_udf.rs | 62 ++--- .../examples/optimizer_rule.rs | 2 +- datafusion/common/src/scalar/mod.rs | 7 +- .../core/src/physical_optimizer/pruning.rs | 18 +- .../user_defined_scalar_functions.rs | 20 +- datafusion/expr-common/src/columnar_value.rs | 134 +++++++--- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/lib.rs | 2 +- .../src/approx_percentile_cont.rs | 2 +- datafusion/functions-nested/benches/map.rs | 4 +- 
datafusion/functions-nested/src/map.rs | 6 +- datafusion/functions-nested/src/range.rs | 4 +- datafusion/functions-nested/src/utils.rs | 2 +- datafusion/functions/benches/concat.rs | 2 +- datafusion/functions/benches/date_bin.rs | 2 +- datafusion/functions/benches/ltrim.rs | 2 +- datafusion/functions/benches/make_date.rs | 12 +- datafusion/functions/benches/nullif.rs | 2 +- datafusion/functions/benches/to_char.rs | 9 +- datafusion/functions/src/core/arrowtypeof.rs | 2 +- datafusion/functions/src/core/coalesce.rs | 8 +- datafusion/functions/src/core/getfield.rs | 6 +- datafusion/functions/src/core/named_struct.rs | 11 +- datafusion/functions/src/core/nullif.rs | 26 +- datafusion/functions/src/core/nvl.rs | 20 +- datafusion/functions/src/core/nvl2.rs | 2 +- datafusion/functions/src/core/struct.rs | 6 +- datafusion/functions/src/crypto/basic.rs | 18 +- datafusion/functions/src/datetime/common.rs | 26 +- datafusion/functions/src/datetime/date_bin.rs | 238 +++++++++--------- .../functions/src/datetime/date_part.rs | 22 +- .../functions/src/datetime/date_trunc.rs | 62 +++-- .../functions/src/datetime/make_date.rs | 63 +++-- datafusion/functions/src/datetime/to_char.rs | 77 +++--- .../functions/src/datetime/to_local_time.rs | 96 ++++--- .../functions/src/datetime/to_timestamp.rs | 26 +- datafusion/functions/src/encoding/inner.rs | 14 +- datafusion/functions/src/math/log.rs | 42 ++-- datafusion/functions/src/math/pi.rs | 2 +- datafusion/functions/src/math/round.rs | 90 ++++--- datafusion/functions/src/math/trunc.rs | 28 ++- datafusion/functions/src/regex/regexplike.rs | 5 +- datafusion/functions/src/regex/regexpmatch.rs | 5 +- .../functions/src/regex/regexpreplace.rs | 4 +- datafusion/functions/src/string/ascii.rs | 2 +- datafusion/functions/src/string/bit_length.rs | 4 +- datafusion/functions/src/string/common.rs | 4 +- datafusion/functions/src/string/concat.rs | 38 +-- datafusion/functions/src/string/concat_ws.rs | 103 ++++---- datafusion/functions/src/string/contains.rs 
| 12 +- datafusion/functions/src/string/ends_with.rs | 16 +- datafusion/functions/src/string/initcap.rs | 8 +- .../functions/src/string/octet_length.rs | 30 +-- datafusion/functions/src/string/repeat.rs | 12 +- datafusion/functions/src/string/split_part.rs | 24 +- .../functions/src/string/starts_with.rs | 4 +- .../functions/src/unicode/character_length.rs | 2 +- datafusion/functions/src/unicode/left.rs | 40 +-- datafusion/functions/src/unicode/lpad.rs | 10 +- datafusion/functions/src/unicode/reverse.rs | 2 +- datafusion/functions/src/unicode/right.rs | 40 +-- datafusion/functions/src/unicode/rpad.rs | 78 +++--- datafusion/functions/src/unicode/substr.rs | 100 ++++---- .../functions/src/unicode/substrindex.rs | 42 ++-- datafusion/functions/src/unicode/translate.rs | 36 +-- datafusion/functions/src/utils.rs | 8 +- .../optimizer/src/analyzer/type_coercion.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../simplify_expressions/expr_simplifier.rs | 35 +-- datafusion/physical-expr-common/src/datum.rs | 9 +- .../physical-expr/src/expressions/binary.rs | 11 +- .../physical-expr/src/expressions/case.rs | 24 +- .../physical-expr/src/expressions/in_list.rs | 10 +- .../src/expressions/is_not_null.rs | 4 +- .../physical-expr/src/expressions/is_null.rs | 4 +- .../physical-expr/src/expressions/literal.rs | 2 +- .../physical-expr/src/expressions/negative.rs | 6 +- .../physical-expr/src/expressions/not.rs | 10 +- .../physical-expr/src/expressions/try_cast.rs | 17 +- datafusion/physical-expr/src/functions.rs | 6 +- datafusion/physical-plan/src/projection.rs | 8 +- datafusion/physical-plan/src/values.rs | 5 +- .../sqllogictest/test_files/string_view.slt | 25 +- 83 files changed, 1069 insertions(+), 919 deletions(-) diff --git a/datafusion-examples/examples/advanced_udf.rs b/datafusion-examples/examples/advanced_udf.rs index 9a3ee9c8ebcd..22d37043e473 100644 --- a/datafusion-examples/examples/advanced_udf.rs +++ b/datafusion-examples/examples/advanced_udf.rs @@ -96,8 
+96,8 @@ impl ScalarUDFImpl for PowUdf { // function, but we check again to make sure assert_eq!(args.len(), 2); let (base, exp) = (&args[0], &args[1]); - assert_eq!(base.data_type(), DataType::Float64); - assert_eq!(exp.data_type(), DataType::Float64); + assert_eq!(base.data_type(), &DataType::Float64); + assert_eq!(exp.data_type(), &DataType::Float64); match (base, exp) { // For demonstration purposes we also implement the scalar / scalar @@ -108,28 +108,31 @@ impl ScalarUDFImpl for PowUdf { // the DataFusion expression simplification logic will often invoke // this path once during planning, and simply use the result during // execution. - ( - ColumnarValue::Scalar(ScalarValue::Float64(base)), - ColumnarValue::Scalar(ScalarValue::Float64(exp)), - ) => { - // compute the output. Note DataFusion treats `None` as NULL. - let res = match (base, exp) { - (Some(base), Some(exp)) => Some(base.powf(*exp)), - // one or both arguments were NULL - _ => None, - }; - Ok(ColumnarValue::Scalar(ScalarValue::from(res))) + (ColumnarValue::Scalar(base), ColumnarValue::Scalar(exp)) => { + match (base.value(), exp.value()) { + (ScalarValue::Float64(base), ScalarValue::Float64(exp)) => { + // compute the output. Note DataFusion treats `None` as NULL. 
+ let res = match (base, exp) { + (Some(base), Some(exp)) => Some(base.powf(*exp)), + // one or both arguments were NULL + _ => None, + }; + Ok(ColumnarValue::from(ScalarValue::from(res))) + } + _ => { + internal_err!("Invalid argument types to pow function") + } + } } // special case if the exponent is a constant - ( - ColumnarValue::Array(base_array), - ColumnarValue::Scalar(ScalarValue::Float64(exp)), - ) => { - let result_array = match exp { + (ColumnarValue::Array(base_array), ColumnarValue::Scalar(exp)) => { + let result_array = match exp.value() { // a ^ null = null - None => new_null_array(base_array.data_type(), base_array.len()), + ScalarValue::Float64(None) => { + new_null_array(base_array.data_type(), base_array.len()) + } // a ^ exp - Some(exp) => { + ScalarValue::Float64(Some(exp)) => { // DataFusion has ensured both arguments are Float64: let base_array = base_array.as_primitive::(); // calculate the result for every row. The `unary` @@ -139,24 +142,25 @@ impl ScalarUDFImpl for PowUdf { compute::unary(base_array, |base| base.powf(*exp)); Arc::new(res) } + _ => return internal_err!("Invalid argument types to pow function"), }; Ok(ColumnarValue::Array(result_array)) } // special case if the base is a constant (note this code is quite // similar to the previous case, so we omit comments) - ( - ColumnarValue::Scalar(ScalarValue::Float64(base)), - ColumnarValue::Array(exp_array), - ) => { - let res = match base { - None => new_null_array(exp_array.data_type(), exp_array.len()), - Some(base) => { + (ColumnarValue::Scalar(base), ColumnarValue::Array(exp_array)) => { + let res = match base.value() { + ScalarValue::Float64(None) => { + new_null_array(exp_array.data_type(), exp_array.len()) + } + ScalarValue::Float64(Some(base)) => { let exp_array = exp_array.as_primitive::(); let res: Float64Array = compute::unary(exp_array, |exp| base.powf(exp)); Arc::new(res) } + _ => return internal_err!("Invalid argument types to pow function"), }; 
Ok(ColumnarValue::Array(res)) } @@ -169,10 +173,6 @@ impl ScalarUDFImpl for PowUdf { )?; Ok(ColumnarValue::Array(Arc::new(res))) } - // if the types were not float, it is a bug in DataFusion - _ => { - internal_err!("Invalid argument types to pow function") - } } } diff --git a/datafusion-examples/examples/optimizer_rule.rs b/datafusion-examples/examples/optimizer_rule.rs index b4663b345f64..cf24a4b23eb5 100644 --- a/datafusion-examples/examples/optimizer_rule.rs +++ b/datafusion-examples/examples/optimizer_rule.rs @@ -207,7 +207,7 @@ impl ScalarUDFImpl for MyEq { fn invoke(&self, _args: &[ColumnarValue]) -> Result { // this example simply returns "true" which is not what a real // implementation would do. - Ok(ColumnarValue::Scalar(ScalarValue::from(true))) + Ok(ColumnarValue::from(ScalarValue::from(true))) } } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 8d2b25021c7e..3d9c11a98b12 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -1813,8 +1813,11 @@ impl ScalarValue { if let Some(DataType::FixedSizeList(f, l)) = first_non_null_data_type { for array in arrays.iter_mut() { if array.is_null(0) { - *array = - Arc::new(FixedSizeListArray::new_null(f.clone(), l, 1)); + *array = Arc::new(FixedSizeListArray::new_null( + Arc::clone(&f), + l, + 1, + )); } } } diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 0ef390fff45c..fb0e658cb8b9 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -685,14 +685,16 @@ impl BoolVecBuilder { ColumnarValue::Array(array) => { self.combine_array(array.as_boolean()); } - ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) => { - // False means all containers can not pass the predicate - self.inner = vec![false; self.inner.len()]; - } - _ => { - // Null or true means the rows in container may pass this - // 
conjunct so we can't prune any containers based on that - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Boolean(Some(false)) => { + // False means all containers can not pass the predicate + self.inner = vec![false; self.inner.len()]; + } + _ => { + // Null or true means the rows in container may pass this + // conjunct so we can't prune any containers based on that + } + }, } } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 259cce74f2e5..1dcaf09784e4 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -212,7 +212,7 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF { } fn invoke_no_args(&self, _number_rows: usize) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100)))) + Ok(ColumnarValue::from(ScalarValue::Int32(Some(100)))) } } @@ -323,7 +323,7 @@ async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { vec![DataType::Int32], Arc::new(DataType::Int32), Volatility::Immutable, - Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))), + Arc::new(move |_| Ok(ColumnarValue::from(ScalarValue::Int32(Some(1))))), )); // Make sure that the UDF is used instead of the built-in function @@ -669,7 +669,10 @@ impl ScalarUDFImpl for TakeUDF { // The actual implementation fn invoke(&self, args: &[ColumnarValue]) -> Result { let take_idx = match &args[2] { - ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Int64(Some(v)) if v < &2 => *v as usize, + _ => unreachable!(), + }, _ => unreachable!(), }; match &args[take_idx] { @@ -1070,11 +1073,12 @@ impl ScalarUDFImpl for MyRegexUdf { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args { - [ColumnarValue::Scalar(ScalarValue::Utf8(value))] 
=> { - Ok(ColumnarValue::Scalar(ScalarValue::Boolean( - self.matches(value.as_deref()), - ))) - } + [ColumnarValue::Scalar(scalar)] => match scalar.value() { + ScalarValue::Utf8(value) => Ok(ColumnarValue::from( + ScalarValue::Boolean(self.matches(value.as_deref())), + )), + _ => exec_err!("regex_udf only accepts a Utf8 arguments"), + }, [ColumnarValue::Array(values)] => { let mut builder = BooleanBuilder::with_capacity(values.len()); for value in values.as_string::() { diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 78b039f8f6fa..3191f653e6db 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -22,7 +22,7 @@ use arrow::array::{Array, ArrayRef}; use arrow::compute::{kernels, CastOptions}; use arrow::datatypes::{DataType, TimeUnit}; use datafusion_common::format::DEFAULT_CAST_OPTIONS; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use std::sync::Arc; /// The result of evaluating an expression. 
@@ -89,7 +89,94 @@ pub enum ColumnarValue { /// Array of values Array(ArrayRef), /// A single value - Scalar(ScalarValue), + Scalar(Scalar), +} + +#[derive(Clone, Debug)] +pub struct Scalar { + value: ScalarValue, + data_type: DataType, +} + +impl From for Scalar { + fn from(value: ScalarValue) -> Self { + Self { + data_type: value.data_type(), + value, + } + } +} + +impl TryFrom for Scalar { + type Error = DataFusionError; + fn try_from(value: DataType) -> Result { + Ok(Self { + value: ScalarValue::try_from(&value)?, + data_type: value, + }) + } +} + +impl PartialEq for Scalar { + fn eq(&self, other: &Self) -> bool { + self.value.eq(&other.value) + } +} + +impl Scalar { + pub fn new(value: ScalarValue, data_type: DataType) -> Self { + Self { value, data_type } + } + + pub fn try_from_array(array: &dyn Array, index: usize) -> Result { + let value = ScalarValue::try_from_array(array, index)?; + Ok(Self::new(value, array.data_type().clone())) + } + + #[inline] + pub fn value(&self) -> &ScalarValue { + &self.value + } + + #[inline] + pub fn into_value(self) -> ScalarValue { + self.value + } + + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + pub fn with_data_type(mut self, data_type: DataType) -> Self { + self.data_type = data_type; + self + } + + pub fn to_array_of_size(&self, size: usize) -> Result { + self.value.to_array_of_size_and_type(size, &self.data_type) + } + + pub fn to_array(&self) -> Result { + self.to_array_of_size(1) + } + + pub fn to_scalar(&self) -> Result> { + Ok(arrow::array::Scalar::new(self.to_array()?)) + } + + pub fn iter_to_array(scalars: impl IntoIterator) -> Result { + let mut scalars = scalars.into_iter().peekable(); + + // figure out the type based on the first element + let data_type = match scalars.peek() { + None => { + return exec_err!("Empty iterator passed to Scalar::iter_to_array"); + } + Some(sv) => sv.data_type().clone(), + }; + + ScalarValue::iter_to_array_of_type(scalars.map(|scalar| scalar.value), &data_type) + 
} } impl From for ColumnarValue { @@ -100,15 +187,15 @@ impl From for ColumnarValue { impl From for ColumnarValue { fn from(value: ScalarValue) -> Self { - ColumnarValue::Scalar(value) + ColumnarValue::Scalar(Scalar::from(value)) } } impl ColumnarValue { - pub fn data_type(&self) -> DataType { + pub fn data_type(&self) -> &DataType { match self { - ColumnarValue::Array(array_value) => array_value.data_type().clone(), - ColumnarValue::Scalar(scalar_value) => scalar_value.data_type(), + ColumnarValue::Array(array_value) => array_value.data_type(), + ColumnarValue::Scalar(scalar) => scalar.data_type(), } } @@ -130,24 +217,6 @@ impl ColumnarValue { }) } - pub fn into_array_of_type( - self, - num_rows: usize, - data_type: &DataType, - ) -> Result { - let array = self.into_array(num_rows)?; - if array.data_type() == data_type { - Ok(array) - } else { - let cast_array = kernels::cast::cast_with_options( - &array, - data_type, - &DEFAULT_CAST_OPTIONS, - )?; - Ok(cast_array) - } - } - /// null columnar values are implemented as a null array in order to pass batch /// num_rows pub fn create_null_array(num_rows: usize) -> Self { @@ -215,7 +284,7 @@ impl ColumnarValue { ColumnarValue::Scalar(scalar) => { let scalar_array = if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) { - if let ScalarValue::Float64(Some(float_ts)) = scalar { + if let ScalarValue::Float64(Some(float_ts)) = scalar.value() { ScalarValue::Int64(Some( (float_ts * 1_000_000_000_f64).trunc() as i64, )) @@ -232,7 +301,10 @@ impl ColumnarValue { &cast_options, )?; let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - Ok(ColumnarValue::Scalar(cast_scalar)) + Ok(ColumnarValue::Scalar(Scalar::new( + cast_scalar, + cast_type.clone(), + ))) } } } @@ -268,7 +340,7 @@ mod tests { TestCase { input: vec![ ColumnarValue::Array(make_array(1, 3)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(100))), + ColumnarValue::from(ScalarValue::Int32(Some(100))), ], expected: vec![ make_array(1, 3), @@ 
-278,7 +350,7 @@ mod tests { // scalar and array TestCase { input: vec![ - ColumnarValue::Scalar(ScalarValue::Int32(Some(100))), + ColumnarValue::from(ScalarValue::Int32(Some(100))), ColumnarValue::Array(make_array(1, 3)), ], expected: vec![ @@ -289,9 +361,9 @@ mod tests { // multiple scalars and array TestCase { input: vec![ - ColumnarValue::Scalar(ScalarValue::Int32(Some(100))), + ColumnarValue::from(ScalarValue::Int32(Some(100))), ColumnarValue::Array(make_array(1, 3)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(200))), + ColumnarValue::from(ScalarValue::Int32(Some(200))), ], expected: vec![ make_array(100, 3), // scalar is expanded @@ -324,7 +396,7 @@ mod tests { fn values_to_arrays_mixed_length_and_scalar() { ColumnarValue::values_to_arrays(&[ ColumnarValue::Array(make_array(1, 3)), - ColumnarValue::Scalar(ScalarValue::Int32(Some(100))), + ColumnarValue::from(ScalarValue::Int32(Some(100))), ColumnarValue::Array(make_array(2, 7)), ]) .unwrap(); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index b4d489cc7c1e..26392634cc8b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2784,7 +2784,7 @@ mod test { } fn invoke(&self, _args: &[ColumnarValue]) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + Ok(ColumnarValue::from(ScalarValue::from("a"))) } } let udf = Arc::new(ScalarUDF::from(TestScalarUDF { diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 260065f69af9..2b3faf2b8d21 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -68,7 +68,7 @@ pub mod window_state; pub use built_in_window_function::BuiltInWindowFunction; pub use datafusion_expr_common::accumulator::Accumulator; -pub use datafusion_expr_common::columnar_value::ColumnarValue; +pub use datafusion_expr_common::columnar_value::{ColumnarValue, Scalar}; pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub use datafusion_expr_common::operator::Operator; pub 
use datafusion_expr_common::signature::{ diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs index 89d827e86859..7f1ffb6d7bed 100644 --- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -147,7 +147,7 @@ fn get_scalar_value(expr: &Arc) -> Result { let empty_schema = Arc::new(Schema::empty()); let batch = RecordBatch::new_empty(Arc::clone(&empty_schema)); if let ColumnarValue::Scalar(s) = expr.evaluate(&batch)? { - Ok(s) + Ok(s.into_value()) } else { internal_err!("Didn't expect ColumnarValue::Array") } diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index c9a12eefa4fa..c6dfa4a50ec8 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -87,8 +87,8 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(Int32Array::from(values(&mut rng))), None, ); - let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); - let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); + let keys = ColumnarValue::from(ScalarValue::List(Arc::new(key_list))); + let values = ColumnarValue::from(ScalarValue::List(Arc::new(value_list))); b.iter(|| { black_box( diff --git a/datafusion/functions-nested/src/map.rs b/datafusion/functions-nested/src/map.rs index b6068fdff0d5..8efdf1bd5b5b 100644 --- a/datafusion/functions-nested/src/map.rs +++ b/datafusion/functions-nested/src/map.rs @@ -59,7 +59,7 @@ fn make_map_batch(args: &[ColumnarValue]) -> datafusion_common::Result datafusion_common::Result { match columnar_value { - ColumnarValue::Scalar(value) => match value { + ColumnarValue::Scalar(value) => match value.value() { ScalarValue::List(array) => Ok(array.value(0)), ScalarValue::LargeList(array) => Ok(array.value(0)), ScalarValue::FixedSizeList(array) => Ok(array.value(0)), @@ -135,7 
+135,7 @@ fn make_map_batch_internal( let map_array = Arc::new(MapArray::from(map_data)); Ok(if can_evaluate_to_const { - ColumnarValue::Scalar(ScalarValue::try_from_array(map_array.as_ref(), 0)?) + ColumnarValue::from(ScalarValue::try_from_array(map_array.as_ref(), 0)?) } else { ColumnarValue::Array(map_array) }) diff --git a/datafusion/functions-nested/src/range.rs b/datafusion/functions-nested/src/range.rs index 5b7315719631..489660c3b340 100644 --- a/datafusion/functions-nested/src/range.rs +++ b/datafusion/functions-nested/src/range.rs @@ -88,7 +88,7 @@ impl ScalarUDFImpl for Range { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - if args.iter().any(|arg| arg.data_type() == DataType::Null) { + if args.iter().any(|arg| arg.data_type() == &DataType::Null) { return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); } match args[0].data_type() { @@ -159,7 +159,7 @@ impl ScalarUDFImpl for GenSeries { } fn invoke(&self, args: &[ColumnarValue]) -> Result { - if args.iter().any(|arg| arg.data_type() == DataType::Null) { + if args.iter().any(|arg| arg.data_type() == &DataType::Null) { return Ok(ColumnarValue::Array(Arc::new(NullArray::new(1)))); } match args[0].data_type() { diff --git a/datafusion/functions-nested/src/utils.rs b/datafusion/functions-nested/src/utils.rs index 688e1633e5cf..65d96269d779 100644 --- a/datafusion/functions-nested/src/utils.rs +++ b/datafusion/functions-nested/src/utils.rs @@ -83,7 +83,7 @@ where if is_scalar { // If all inputs are scalar, keeps output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(ColumnarValue::Scalar) + result.map(ColumnarValue::from) } else { result.map(ColumnarValue::Array) } diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 91c46ac775a8..bd3bc31b0c65 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -28,7 +28,7 @@ fn create_args(size: usize, str_len: usize) 
-> Vec { let scalar = ScalarValue::Utf8(Some(", ".to_string())); vec![ ColumnarValue::Array(Arc::clone(&array) as ArrayRef), - ColumnarValue::Scalar(scalar), + ColumnarValue::from(scalar), ColumnarValue::Array(array), ] } diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index c881947354fd..7a92037ccc5d 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -40,7 +40,7 @@ fn timestamps(rng: &mut ThreadRng) -> TimestampSecondArray { fn criterion_benchmark(c: &mut Criterion) { c.bench_function("date_bin_1000", |b| { let mut rng = rand::thread_rng(); - let interval = ColumnarValue::Scalar(ScalarValue::new_interval_dt(0, 1_000_000)); + let interval = ColumnarValue::from(ScalarValue::new_interval_dt(0, 1_000_000)); let timestamps = ColumnarValue::Array(Arc::new(timestamps(&mut rng)) as ArrayRef); let udf = date_bin(); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 01acb9de3381..525430e21ea1 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -30,7 +30,7 @@ fn create_args(size: usize, characters: &str) -> Vec { let array = Arc::new(StringArray::from_iter_values(iter)) as ArrayRef; vec![ ColumnarValue::Array(array), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(characters.to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some(characters.to_string()))), ] } diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index cb8f1abe6d5d..a865953897eb 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -72,7 +72,7 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_scalar_col_col_1000", |b| { let mut rng = rand::thread_rng(); - let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); + let year = ColumnarValue::from(ScalarValue::Int32(Some(2025))); let 
months = ColumnarValue::Array(Arc::new(months(&mut rng)) as ArrayRef); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); @@ -87,8 +87,8 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function("make_date_scalar_scalar_col_1000", |b| { let mut rng = rand::thread_rng(); - let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); - let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); + let year = ColumnarValue::from(ScalarValue::Int32(Some(2025))); + let month = ColumnarValue::from(ScalarValue::Int32(Some(11))); let days = ColumnarValue::Array(Arc::new(days(&mut rng)) as ArrayRef); b.iter(|| { @@ -101,9 +101,9 @@ fn criterion_benchmark(c: &mut Criterion) { }); c.bench_function("make_date_scalar_scalar_scalar", |b| { - let year = ColumnarValue::Scalar(ScalarValue::Int32(Some(2025))); - let month = ColumnarValue::Scalar(ScalarValue::Int32(Some(11))); - let day = ColumnarValue::Scalar(ScalarValue::Int32(Some(26))); + let year = ColumnarValue::from(ScalarValue::Int32(Some(2025))); + let month = ColumnarValue::from(ScalarValue::Int32(Some(11))); + let day = ColumnarValue::from(ScalarValue::Int32(Some(26))); b.iter(|| { black_box( diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index dfabad335835..31192c1a749f 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -29,7 +29,7 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let array = Arc::new(create_string_array_with_len::(size, 0.2, 32)); let args = vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("abcd".to_string()))), ColumnarValue::Array(array), ]; c.bench_function(&format!("nullif scalar array: {}", size), |b| { diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index d9a153e64abc..14819c90a7a5 100644 --- 
a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -98,7 +98,7 @@ fn criterion_benchmark(c: &mut Criterion) { let mut rng = rand::thread_rng(); let data = ColumnarValue::Array(Arc::new(data(&mut rng)) as ArrayRef); let patterns = - ColumnarValue::Scalar(ScalarValue::Utf8(Some("%Y-%m-%d".to_string()))); + ColumnarValue::from(ScalarValue::Utf8(Some("%Y-%m-%d".to_string()))); b.iter(|| { black_box( @@ -118,10 +118,9 @@ fn criterion_benchmark(c: &mut Criterion) { .and_utc() .timestamp_nanos_opt() .unwrap(); - let data = ColumnarValue::Scalar(TimestampNanosecond(Some(timestamp), None)); - let pattern = ColumnarValue::Scalar(ScalarValue::Utf8(Some( - "%d-%m-%Y %H:%M:%S".to_string(), - ))); + let data = ColumnarValue::from(TimestampNanosecond(Some(timestamp), None)); + let pattern = + ColumnarValue::from(ScalarValue::Utf8(Some("%d-%m-%Y %H:%M:%S".to_string()))); b.iter(|| { black_box( diff --git a/datafusion/functions/src/core/arrowtypeof.rs b/datafusion/functions/src/core/arrowtypeof.rs index cc5e7e619bd8..dd502fb2686d 100644 --- a/datafusion/functions/src/core/arrowtypeof.rs +++ b/datafusion/functions/src/core/arrowtypeof.rs @@ -65,7 +65,7 @@ impl ScalarUDFImpl for ArrowTypeOfFunc { } let input_data_type = args[0].data_type(); - Ok(ColumnarValue::Scalar(ScalarValue::from(format!( + Ok(ColumnarValue::from(ScalarValue::from(format!( "{input_data_type}" )))) } diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 15a3ddd9d6e9..c790aecbdb4f 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -91,11 +91,11 @@ impl ScalarUDFImpl for CoalesceFunc { current_value = zip(&to_apply, array, &current_value)?; remainder = and(&remainder, &is_null(array)?)?; } - ColumnarValue::Scalar(value) => { - if value.is_null() { + ColumnarValue::Scalar(scalar) => { + if scalar.value().is_null() { continue; } else { - let last_value = value.to_scalar()?; + let
last_value = scalar.to_scalar()?; current_value = zip(&remainder, &last_value, &current_value)?; break; } @@ -110,7 +110,7 @@ impl ScalarUDFImpl for CoalesceFunc { let result = args .iter() .filter_map(|x| match x { - ColumnarValue::Scalar(s) if !s.is_null() => Some(x.clone()), + ColumnarValue::Scalar(s) if !s.value().is_null() => Some(x.clone()), _ => None, }) .next() diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index a51f895c5084..756109edc4e0 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -168,7 +168,7 @@ impl ScalarUDFImpl for GetFieldFunc { } if args[0].data_type().is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Null)); + return Ok(ColumnarValue::from(ScalarValue::Null)); } let arrays = ColumnarValue::values_to_arrays(args)?; @@ -183,7 +183,7 @@ impl ScalarUDFImpl for GetFieldFunc { } }; - match (array.data_type(), name) { + match (array.data_type(), name.value()) { (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { let map_array = as_map_array(array.as_ref())?; let key_scalar: Scalar>> = Scalar::new(StringArray::from(vec![k.clone()])); @@ -227,7 +227,7 @@ impl ScalarUDFImpl for GetFieldFunc { "get indexed field is only possible on struct with utf8 indexes. \ Tried with {name:?} index" ), - (DataType::Null, _) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + (DataType::Null, _) => Ok(ColumnarValue::from(ScalarValue::Null)), (dt, name) => exec_err!( "get indexed field is only possible on lists with int64 indexes or struct \ with utf8 indexes.
Tried {dt:?} with {name:?} index" diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index f71b1b00f0fe..ca4f2819f403 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -47,8 +47,15 @@ fn named_struct_expr(args: &[ColumnarValue]) -> Result { let name_column = &chunk[0]; let name = match name_column { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar, - _ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2) + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(name_scalar)) => name_scalar, + _ => return exec_err!( + "named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2 + ) + }, + _ => return exec_err!( + "named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2 + ) }; Ok((name, chunk[1].clone())) diff --git a/datafusion/functions/src/core/nullif.rs b/datafusion/functions/src/core/nullif.rs index 6fcfbd36416e..de1099f7b0ed 100644 --- a/datafusion/functions/src/core/nullif.rs +++ b/datafusion/functions/src/core/nullif.rs @@ -17,11 +17,10 @@ use arrow::datatypes::DataType; use datafusion_common::{exec_err, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Scalar}; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::nullif::nullif; -use datafusion_common::ScalarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -131,8 +130,8 @@ fn nullif_func(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array)) } (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { - let val: ScalarValue = match lhs.eq(rhs) { - true => lhs.data_type().try_into()?, + let val = match lhs.eq(rhs) { + true => Scalar::try_from(lhs.data_type().clone())?, false 
=> lhs.clone(), }; @@ -146,6 +145,7 @@ mod tests { use std::sync::Arc; use arrow::array::*; + use datafusion_common::ScalarValue; use super::*; @@ -164,7 +164,7 @@ mod tests { ]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -190,7 +190,7 @@ mod tests { let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(1i32))); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -215,7 +215,7 @@ mod tests { let a = BooleanArray::from(vec![Some(true), Some(false), None]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + let lit_array = ColumnarValue::from(ScalarValue::Boolean(Some(false))); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -232,7 +232,7 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::from("bar")); + let lit_array = ColumnarValue::from(ScalarValue::from("bar")); let result = nullif_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -253,7 +253,7 @@ mod tests { let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result = 
nullif_func(&[lit_array, a])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -272,8 +272,8 @@ mod tests { #[test] fn nullif_scalar() -> Result<()> { - let a_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let b_eq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let a_eq = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); + let b_eq = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result_eq = nullif_func(&[a_eq, b_eq])?; let result_eq = result_eq.into_array(1).expect("Failed to convert to array"); @@ -282,8 +282,8 @@ mod tests { assert_eq!(expected_eq.as_ref(), result_eq.as_ref()); - let a_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let b_neq = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); + let a_neq = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); + let b_neq = ColumnarValue::from(ScalarValue::Int32(Some(1i32))); let result_neq = nullif_func(&[a_neq, b_neq])?; let result_neq = result_neq diff --git a/datafusion/functions/src/core/nvl.rs b/datafusion/functions/src/core/nvl.rs index a09224acefcd..09d86ddfc0d9 100644 --- a/datafusion/functions/src/core/nvl.rs +++ b/datafusion/functions/src/core/nvl.rs @@ -112,7 +112,7 @@ fn nvl_func(args: &[ColumnarValue]) -> Result { } (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => { let mut current_value = lhs; - if lhs.is_null() { + if lhs.value().is_null() { current_value = rhs; } return Ok(ColumnarValue::Scalar(current_value.clone())); @@ -147,7 +147,7 @@ mod tests { ]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(6i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(6i32))); let result = nvl_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -173,7 +173,7 @@ mod tests { let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = 
ColumnarValue::Scalar(ScalarValue::Int32(Some(20i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(20i32))); let result = nvl_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -198,7 +198,7 @@ mod tests { let a = BooleanArray::from(vec![Some(true), Some(false), None]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + let lit_array = ColumnarValue::from(ScalarValue::Boolean(Some(false))); let result = nvl_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -218,7 +218,7 @@ mod tests { let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::from("bax")); + let lit_array = ColumnarValue::from(ScalarValue::from("bax")); let result = nvl_func(&[a, lit_array])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -239,7 +239,7 @@ mod tests { let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]); let a = ColumnarValue::Array(Arc::new(a)); - let lit_array = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let lit_array = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result = nvl_func(&[lit_array, a])?; let result = result.into_array(0).expect("Failed to convert to array"); @@ -258,8 +258,8 @@ mod tests { #[test] fn nvl_scalar() -> Result<()> { - let a_null = ColumnarValue::Scalar(ScalarValue::Int32(None)); - let b_null = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); + let a_null = ColumnarValue::from(ScalarValue::Int32(None)); + let b_null = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); let result_null = nvl_func(&[a_null, b_null])?; let result_null = result_null @@ -270,8 +270,8 @@ mod tests { assert_eq!(expected_null.as_ref(), result_null.as_ref()); - let a_nnull = 
ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32))); - let b_nnull = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32))); + let a_nnull = ColumnarValue::from(ScalarValue::Int32(Some(2i32))); + let b_nnull = ColumnarValue::from(ScalarValue::Int32(Some(1i32))); let result_nnull = nvl_func(&[a_nnull, b_nnull])?; let result_nnull = result_nnull diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index 1144dc0fb7c5..f3027925b26a 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -126,7 +126,7 @@ fn nvl2_func(args: &[ColumnarValue]) -> Result { internal_err!("except Scalar value, but got Array") } ColumnarValue::Scalar(scalar) => { - if scalar.is_null() { + if scalar.value().is_null() { current_value = &args[2]; } Ok(current_value.clone()) diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index c3dee8b1ccb4..5633beea3915 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -109,9 +109,9 @@ mod tests { fn test_struct() { // struct(1, 2, 3) = {"c0": 1, "c1": 2, "c2": 3} let args = [ - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(3))), + ColumnarValue::from(ScalarValue::Int64(Some(1))), + ColumnarValue::from(ScalarValue::Int64(Some(2))), + ColumnarValue::from(ScalarValue::Int64(Some(3))), ]; let struc = struct_expr(&args) .expect("failed to initialize function struct") diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index f3015c24b3fa..85b0b4a0bc6d 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -94,6 +94,7 @@ macro_rules! 
digest_to_scalar { digest.update(v); digest.finalize().as_slice().to_vec() })) + .into() }}; } @@ -120,7 +121,7 @@ pub fn digest(args: &[ColumnarValue]) -> Result { ); } let digest_algorithm = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(Some(method)) => method.parse::(), other => exec_err!("Unsupported data type {other:?} for function digest"), }, @@ -189,10 +190,12 @@ pub fn md5(args: &[ColumnarValue]) -> Result { .collect(); ColumnarValue::Array(Arc::new(string_array)) } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => { - ColumnarValue::Scalar(ScalarValue::Utf8(opt.map(hex_encode::<_>))) - } - _ => return exec_err!("Impossibly got invalid results from digest"), + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Binary(opt) => { + ColumnarValue::from(ScalarValue::Utf8(opt.map(hex_encode::<_>))) + } + _ => return exec_err!("Impossibly got invalid results from digest"), + }, }) } @@ -254,7 +257,8 @@ impl DigestAlgorithm { let mut digest = Blake3::default(); digest.update(v); Blake3::finalize(&digest).as_bytes().to_vec() - })), + })) + .into(), }) } @@ -337,7 +341,7 @@ pub fn digest_process( ), }, ColumnarValue::Scalar(scalar) => { - match scalar { + match scalar.value() { ScalarValue::Utf8(a) => Ok(digest_algorithm .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), ScalarValue::Binary(a) => Ok(digest_algorithm diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index e0d775e602d6..d1ff06f52712 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -170,10 +170,10 @@ where ))), other => exec_err!("Unsupported data type {other:?} for function {name}"), }, - ColumnarValue::Scalar(scalar) => match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(a) => { let result = a.as_ref().map(|x| 
(op)(x)).transpose()?; - Ok(ColumnarValue::Scalar(S::scalar(result))) + Ok(ColumnarValue::from(S::scalar(result))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, @@ -227,7 +227,7 @@ where } }, // if the first argument is a scalar utf8 all arguments are expected to be scalar utf8 - ColumnarValue::Scalar(scalar) => match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(a) => { let a = a.as_ref(); // ASK: Why do we trust `a` to be non-null at this point? @@ -236,14 +236,26 @@ where let mut ret = None; for (pos, v) in args.iter().enumerate().skip(1) { - let ColumnarValue::Scalar(ScalarValue::Utf8(x)) = v else { - return exec_err!("Unsupported data type {v:?} for function {name}, arg # {pos}"); + let x = match v { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(v) => v, + _ => { + return exec_err!( + "Unsupported data type {v:?} for function {name}, arg # {pos}" + ) + } + }, + _ => { + return exec_err!( + "Unsupported data type {v:?} for function {name}, arg # {pos}" + ) + } }; if let Some(s) = x { match op(a.as_str(), s.as_str()) { Ok(r) => { - ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some( + ret = Some(Ok(ColumnarValue::from(S::scalar(Some( op2(r), ))))); break; @@ -300,7 +312,7 @@ where ColumnarValue::Array(a) => { Ok(Either::Left(as_generic_string_array::(a.as_ref())?)) } - ColumnarValue::Scalar(s) => match s { + ColumnarValue::Scalar(s) => match s.value() { ScalarValue::Utf8(a) => Ok(Either::Right(a)), other => exec_err!( "Unexpected scalar type encountered '{other}' for function '{name}'" diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 997f1a36ad04..36b97d2ef3c3 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -135,7 +135,7 @@ impl ScalarUDFImpl for DateBinFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { if args.len() == 2 
{ // Default to unix EPOCH - let origin = ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + let origin = ColumnarValue::from(ScalarValue::TimestampNanosecond( Some(0), Some("+00:00".into()), )); @@ -260,67 +260,67 @@ fn date_bin_impl( array: &ColumnarValue, origin: &ColumnarValue, ) -> Result { - let stride = match stride { - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(v))) => { - let (days, ms) = IntervalDayTimeType::to_parts(*v); - let nanos = (TimeDelta::try_days(days as i64).unwrap() - + TimeDelta::try_milliseconds(ms as i64).unwrap()) - .num_nanoseconds(); - - match nanos { - Some(v) => Interval::Nanoseconds(v), - _ => return exec_err!("DATE_BIN stride argument is too large"), - } - } - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some(v))) => { - let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); - - // If interval is months, its origin must be midnight of first date of the month - if months != 0 { - // Return error if days or nanos is not zero - if days != 0 || nanos != 0 { - return not_impl_err!( - "DATE_BIN stride does not support combination of month, day and nanosecond intervals" - ); - } else { - Interval::Months(months as i64) - } - } else { + let stride = if let ColumnarValue::Scalar(scalar) = stride { + match scalar.value() { + ScalarValue::IntervalDayTime(Some(v)) => { + let (days, ms) = IntervalDayTimeType::to_parts(*v); let nanos = (TimeDelta::try_days(days as i64).unwrap() - + Duration::nanoseconds(nanos)) + + TimeDelta::try_milliseconds(ms as i64).unwrap()) .num_nanoseconds(); + match nanos { Some(v) => Interval::Nanoseconds(v), _ => return exec_err!("DATE_BIN stride argument is too large"), } } + ScalarValue::IntervalMonthDayNano(Some(v)) => { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v); + + // If interval is months, its origin must be midnight of first date of the month + if months != 0 { + // Return error if days or nanos is not zero + if days != 0 || nanos != 0 { + 
return not_impl_err!( + "DATE_BIN stride does not support combination of month, day and nanosecond intervals" + ); + } else { + Interval::Months(months as i64) + } + } else { + let nanos = (TimeDelta::try_days(days as i64).unwrap() + + Duration::nanoseconds(nanos)) + .num_nanoseconds(); + match nanos { + Some(v) => Interval::Nanoseconds(v), + _ => return exec_err!("DATE_BIN stride argument is too large"), + } + } + } + _ => { + return exec_err!( + "DATE_BIN expects stride argument to be an INTERVAL but got {}", + scalar.data_type() + ); + } } - ColumnarValue::Scalar(v) => { - return exec_err!( - "DATE_BIN expects stride argument to be an INTERVAL but got {}", - v.data_type() - ); - } - ColumnarValue::Array(_) => { - return not_impl_err!( + } else { + return not_impl_err!( "DATE_BIN only supports literal values for the stride argument, not arrays" ); - } }; - let origin = match origin { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(v), _)) => *v, - ColumnarValue::Scalar(v) => { - return exec_err!( + let origin = if let ColumnarValue::Scalar(scalar) = origin { + match scalar.value() { + ScalarValue::TimestampNanosecond(Some(v), _) => *v, + _ => return exec_err!( "DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got {}", - v.data_type() - ); + scalar.data_type() + ) } - ColumnarValue::Array(_) => { - return not_impl_err!( + } else { + return not_impl_err!( "DATE_BIN only supports literal values for the origin argument, not arrays" ); - } }; let (stride, stride_fn) = stride.bin_fn(); @@ -345,38 +345,37 @@ fn date_bin_impl( } Ok(match array { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - v.map(apply_stride_fn), - tz_opt.clone(), - )) - } - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, 
stride_fn); - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - v.map(apply_stride_fn), - tz_opt.clone(), - )) - } - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - v.map(apply_stride_fn), - tz_opt.clone(), - )) - } - ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - let apply_stride_fn = - stride_map_fn::(origin, stride, stride_fn); - ColumnarValue::Scalar(ScalarValue::TimestampSecond( - v.map(apply_stride_fn), - tz_opt.clone(), - )) - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::TimestampNanosecond(v, tz_opt) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); + ScalarValue::TimestampNanosecond(v.map(apply_stride_fn), tz_opt.clone()) + .into() + } + ScalarValue::TimestampMicrosecond(v, tz_opt) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); + ScalarValue::TimestampMicrosecond(v.map(apply_stride_fn), tz_opt.clone()) + .into() + } + ScalarValue::TimestampMillisecond(v, tz_opt) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); + ScalarValue::TimestampMillisecond(v.map(apply_stride_fn), tz_opt.clone()) + .into() + } + ScalarValue::TimestampSecond(v, tz_opt) => { + let apply_stride_fn = + stride_map_fn::(origin, stride, stride_fn); + ScalarValue::TimestampSecond(v.map(apply_stride_fn), tz_opt.clone()) + .into() + } + value => { + return exec_err!( + "DATE_BIN expects source argument to be a TIMESTAMP scalar but got {}", value + ); + } + }, ColumnarValue::Array(array) => { fn transform_array_with_stride( @@ -427,11 +426,6 @@ fn date_bin_impl( } } } - _ => { - return exec_err!( - "DATE_BIN expects source argument to be a TIMESTAMP scalar or array" - ); - } }) } @@ -454,46 +448,46 @@ mod tests { #[test] fn test_date_bin() { let res = DateBinFunc::new().invoke(&[ - 
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); let timestamps = Arc::new((1..6).map(Some).collect::()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), ColumnarValue::Array(timestamps), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); // stride supports month-day-nano let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( + ColumnarValue::from(ScalarValue::IntervalMonthDayNano(Some( IntervalMonthDayNano { months: 0, days: 0, nanoseconds: 1, }, ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); @@ -502,7 +496,7 @@ mod tests { // // 
invalid number of arguments - let res = DateBinFunc::new().invoke(&[ColumnarValue::Scalar( + let res = DateBinFunc::new().invoke(&[ColumnarValue::from( ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, @@ -515,9 +509,9 @@ mod tests { // stride: invalid type let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -526,12 +520,12 @@ mod tests { // stride: invalid value let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 0, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -540,11 +534,9 @@ mod tests { // stride: overflow of day-time interval let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime::MAX, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime::MAX))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), 
None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -553,9 +545,9 @@ mod tests { // stride: overflow of month-day-nano interval let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -564,9 +556,9 @@ mod tests { // stride: month intervals let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::new_interval_mdn(1, 1, 1)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -575,12 +567,12 @@ mod tests { // origin: invalid type let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampMicrosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -588,12 +580,12 @@ mod tests { ); let res = DateBinFunc::new().invoke(&[ - 
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert!(res.is_ok()); @@ -610,8 +602,8 @@ mod tests { ); let res = DateBinFunc::new().invoke(&[ ColumnarValue::Array(intervals), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -621,11 +613,11 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); let res = DateBinFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + ColumnarValue::from(ScalarValue::IntervalDayTime(Some(IntervalDayTime { days: 0, milliseconds: 1, }))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Array(timestamps), ]); assert_eq!( @@ -744,9 +736,9 @@ mod tests { .with_timezone_opt(tz_opt.clone()); let result = DateBinFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + ColumnarValue::from(ScalarValue::new_interval_dt(1, 0)), ColumnarValue::Array(Arc::new(input)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + ColumnarValue::from(ScalarValue::TimestampNanosecond( Some(string_to_timestamp_nanos(origin).unwrap()), tz_opt.clone(), )), diff --git 
a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index f4ea165e174e..310ed1b6d425 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -138,12 +138,20 @@ impl ScalarUDFImpl for DatePartFunc { } let (part, array) = (&args[0], &args[1]); - let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part { - v - } else { - return exec_err!( - "First argument of `DATE_PART` must be non-null scalar Utf8" - ); + let part = match part { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(v)) => v, + _ => { + return exec_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ) + } + }, + _ => { + return exec_err!( + "First argument of `DATE_PART` must be non-null scalar Utf8" + ) + } }; let is_scalar = matches!(array, ColumnarValue::Scalar(_)); @@ -178,7 +186,7 @@ impl ScalarUDFImpl for DatePartFunc { }; Ok(if is_scalar { - ColumnarValue::Scalar(ScalarValue::try_from_array(arr.as_ref(), 0)?) + ColumnarValue::from(ScalarValue::try_from_array(arr.as_ref(), 0)?) 
} else { ColumnarValue::Array(arr) }) diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 3bb22ce7913a..bbe337d24598 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -137,12 +137,20 @@ impl ScalarUDFImpl for DateTruncFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { let (granularity, array) = (&args[0], &args[1]); - let granularity = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = - granularity - { - v.to_lowercase() - } else { - return exec_err!("Granularity of `date_trunc` must be non-null scalar Utf8"); + let granularity = match granularity { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(v)) => v.to_lowercase(), + _ => { + return exec_err!( + "Granularity of `date_trunc` must be non-null scalar Utf8" + ) + } + }, + _ => { + return exec_err!( + "Granularity of `date_trunc` must be non-null scalar Utf8" + ) + } }; fn process_array( @@ -168,22 +176,29 @@ impl ScalarUDFImpl for DateTruncFunc { let parsed_tz = parse_tz(tz_opt)?; let value = general_date_trunc(T::UNIT, v, parsed_tz, granularity.as_str())?; let value = ScalarValue::new_timestamp::(value, tz_opt.clone()); - Ok(ColumnarValue::Scalar(value)) + Ok(ColumnarValue::from(value)) } Ok(match array { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(v, tz_opt)) => { - process_scalar::(v, granularity, tz_opt)? - } - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(v, tz_opt)) => { - process_scalar::(v, granularity, tz_opt)? - } - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(v, tz_opt)) => { - process_scalar::(v, granularity, tz_opt)? - } - ColumnarValue::Scalar(ScalarValue::TimestampSecond(v, tz_opt)) => { - process_scalar::(v, granularity, tz_opt)? 
- } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::TimestampNanosecond(v, tz_opt) => { + process_scalar::(v, granularity, tz_opt)? + } + ScalarValue::TimestampMicrosecond(v, tz_opt) => { + process_scalar::(v, granularity, tz_opt)? + } + ScalarValue::TimestampMillisecond(v, tz_opt) => { + process_scalar::(v, granularity, tz_opt)? + } + ScalarValue::TimestampSecond(v, tz_opt) => { + process_scalar::(v, granularity, tz_opt)? + } + _ => { + return exec_err!( + "second argument of `date_trunc` must be nanosecond timestamp scalar or array" + ); + } + }, ColumnarValue::Array(array) => { let array_type = array.data_type(); match array_type { @@ -212,11 +227,6 @@ impl ScalarUDFImpl for DateTruncFunc { )?, } } - _ => { - return exec_err!( - "second argument of `date_trunc` must be nanosecond timestamp scalar or array" - ); - } }) } @@ -685,7 +695,7 @@ mod tests { .with_timezone_opt(tz_opt.clone()); let result = DateTruncFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::from("day")), + ColumnarValue::from(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ]) .unwrap(); @@ -843,7 +853,7 @@ mod tests { .with_timezone_opt(tz_opt.clone()); let result = DateTruncFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::from("hour")), + ColumnarValue::from(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ]) .unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 5f59ef0710a1..75dfa305351a 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -94,7 +94,7 @@ impl ScalarUDFImpl for MakeDateFunc { let ColumnarValue::Scalar(s) = col else { return exec_err!("Expected scalar value"); }; - let ScalarValue::Int32(Some(i)) = s else { + let ScalarValue::Int32(Some(i)) = s.value() else { return exec_err!("Unable to parse date from null/empty value"); }; Ok(*i) @@ -143,7 +143,7 @@ impl 
ScalarUDFImpl for MakeDateFunc { |days: i32| value = days, )?; - ColumnarValue::Scalar(ScalarValue::Date32(Some(value))) + ColumnarValue::from(ScalarValue::Date32(Some(value))) }; Ok(value) @@ -192,42 +192,51 @@ mod tests { fn test_make_date() { let res = MakeDateFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), + ColumnarValue::from(ScalarValue::Int32(Some(2024))), + ColumnarValue::from(ScalarValue::Int64(Some(1))), + ColumnarValue::from(ScalarValue::UInt32(Some(14))), ]) .expect("that make_date parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, date.unwrap()); + if let ColumnarValue::Scalar(scalar) = res { + match scalar.value() { + ScalarValue::Date32(date) => assert_eq!(19736, date.unwrap()), + _ => panic!("Expected a Date32"), + } } else { panic!("Expected a scalar value") } let res = MakeDateFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), - ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), + ColumnarValue::from(ScalarValue::Int64(Some(2024))), + ColumnarValue::from(ScalarValue::UInt64(Some(1))), + ColumnarValue::from(ScalarValue::UInt32(Some(14))), ]) .expect("that make_date parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, date.unwrap()); + if let ColumnarValue::Scalar(scalar) = res { + match scalar.value() { + ScalarValue::Date32(date) => assert_eq!(19736, date.unwrap()), + _ => panic!("Expected a Date32"), + } } else { panic!("Expected a scalar value") } let res = MakeDateFunc::new() .invoke(&[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("1".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), + 
ColumnarValue::from(ScalarValue::Utf8(Some("2024".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("1".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("14".to_string()))), ]) .expect("that make_date parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { - assert_eq!(19736, date.unwrap()); + if let ColumnarValue::Scalar(scalar) = res { + match scalar.value() { + ScalarValue::Date32(date) => assert_eq!(19736, date.unwrap()), + _ => panic!("Expected a Date32"), + } } else { panic!("Expected a scalar value") } @@ -261,7 +270,7 @@ mod tests { // invalid number of arguments let res = MakeDateFunc::new() - .invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]); + .invoke(&[ColumnarValue::from(ScalarValue::Int32(Some(1)))]); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: make_date function requires 3 arguments, got 1" @@ -269,9 +278,9 @@ mod tests { // invalid type let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -280,9 +289,9 @@ mod tests { // overflow of month let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), - ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), + ColumnarValue::from(ScalarValue::Int32(Some(2023))), + ColumnarValue::from(ScalarValue::UInt64(Some(u64::MAX))), + ColumnarValue::from(ScalarValue::Int32(Some(22))), ]); assert_eq!( res.err().unwrap().strip_backtrace(), @@ -291,9 +300,9 @@ mod 
tests { // overflow of day let res = MakeDateFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), - ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), + ColumnarValue::from(ScalarValue::Int32(Some(2023))), + ColumnarValue::from(ScalarValue::Int32(Some(22))), + ColumnarValue::from(ScalarValue::UInt32(Some(u32::MAX))), ]); assert_eq!( res.err().unwrap().strip_backtrace(), diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index f2e5af978ca0..5c0f57ca09e2 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -115,22 +115,23 @@ impl ScalarUDFImpl for ToCharFunc { } match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Null) => { - _to_char_scalar(args[0].clone(), None) - } - // constant format - ColumnarValue::Scalar(ScalarValue::Utf8(Some(format))) => { - // invoke to_char_scalar with the known string, without converting to array - _to_char_scalar(args[0].clone(), Some(format)) - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(None) | ScalarValue::Null => { + _to_char_scalar(args[0].clone(), None) + } + // constant format + ScalarValue::Utf8(Some(format)) => { + // invoke to_char_scalar with the known string, without converting to array + _to_char_scalar(args[0].clone(), Some(format)) + } + _ => { + exec_err!( + "Format for `to_char` must be non-null Utf8, received {:?}", + args[1].data_type() + ) + } + }, ColumnarValue::Array(_) => _to_char_array(args), - _ => { - exec_err!( - "Format for `to_char` must be non-null Utf8, received {:?}", - args[1].data_type() - ) - } } } @@ -177,13 +178,13 @@ fn _to_char_scalar( ) -> Result { // it's possible that the expression is a scalar however because // of the implementation in arrow-rs we need to convert it to an array - let data_type = 
&expression.data_type(); + let data_type = expression.data_type().clone(); let is_scalar_expression = matches!(&expression, ColumnarValue::Scalar(_)); let array = expression.into_array(1)?; if format.is_none() { if is_scalar_expression { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + return Ok(ColumnarValue::from(ScalarValue::Utf8(None))); } else { return Ok(ColumnarValue::Array(new_null_array( &DataType::Utf8, @@ -192,7 +193,7 @@ fn _to_char_scalar( } } - let format_options = match _build_format_options(data_type, format) { + let format_options = match _build_format_options(&data_type, format) { Ok(value) => value, Err(value) => return value, }; @@ -204,7 +205,7 @@ fn _to_char_scalar( if let Ok(formatted) = formatted { if is_scalar_expression { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + Ok(ColumnarValue::from(ScalarValue::Utf8(Some( formatted.first().unwrap().to_string(), )))) } else { @@ -252,10 +253,10 @@ fn _to_char_array(args: &[ColumnarValue]) -> Result { results, )) as ArrayRef)), ColumnarValue::Scalar(_) => match results.first().unwrap() { - Some(value) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + Some(value) => Ok(ColumnarValue::from(ScalarValue::Utf8(Some( value.to_string(), )))), - None => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))), + None => Ok(ColumnarValue::from(ScalarValue::Utf8(None))), }, } } @@ -351,13 +352,15 @@ mod tests { for (value, format, expected) in scalar_data { let result = ToCharFunc::new() - .invoke(&[ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)]) + .invoke(&[ColumnarValue::from(value), ColumnarValue::from(format)]) .expect("that to_char parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Utf8(date)) = result { - assert_eq!(expected, date.unwrap()); - } else { - panic!("Expected a scalar value") + match result { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(date)) => assert_eq!(&expected, date), + _ => panic!("Expected 
a scalar value"), + }, + _ => panic!("Expected a scalar value"), } } @@ -426,15 +429,17 @@ mod tests { for (value, format, expected) in scalar_array_data { let result = ToCharFunc::new() .invoke(&[ - ColumnarValue::Scalar(value), + ColumnarValue::from(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ]) .expect("that to_char parsed values without error"); - if let ColumnarValue::Scalar(ScalarValue::Utf8(date)) = result { - assert_eq!(expected, date.unwrap()); - } else { - panic!("Expected a scalar value") + match result { + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(date)) => assert_eq!(&expected, date), + _ => panic!("Expected a scalar value"), + }, + _ => panic!("Expected a scalar value"), } } @@ -552,7 +557,7 @@ mod tests { let result = ToCharFunc::new() .invoke(&[ ColumnarValue::Array(value as ArrayRef), - ColumnarValue::Scalar(format), + ColumnarValue::from(format), ]) .expect("that to_char parsed values without error"); @@ -585,8 +590,8 @@ mod tests { // // invalid number of arguments - let result = ToCharFunc::new() - .invoke(&[ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))]); + let result = + ToCharFunc::new().invoke(&[ColumnarValue::from(ScalarValue::Int32(Some(1)))]); assert_eq!( result.err().unwrap().strip_backtrace(), "Execution error: to_char function requires 2 arguments, got 1" @@ -594,8 +599,8 @@ mod tests { // invalid type let result = ToCharFunc::new().invoke(&[ - ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::from(ScalarValue::Int32(Some(1))), + ColumnarValue::from(ScalarValue::TimestampNanosecond(Some(1), None)), ]); assert_eq!( result.err().unwrap().strip_backtrace(), diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 634e28e6f393..3fa36223b918 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ 
b/datafusion/functions/src/datetime/to_local_time.rs @@ -97,50 +97,48 @@ impl ToLocalTimeFunc { let tz: Tz = timezone.parse()?; match time_value { - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampMillisecond( - Some(adjusted_ts), - None, - ))) - } - ColumnarValue::Scalar(ScalarValue::TimestampSecond( - Some(ts), - Some(_), - )) => { - let adjusted_ts = - adjust_to_local_time::(*ts, tz)?; - Ok(ColumnarValue::Scalar(ScalarValue::TimestampSecond( - Some(adjusted_ts), - None, - ))) - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::TimestampNanosecond(Some(ts), Some(_)) => { + let adjusted_ts = + adjust_to_local_time::(*ts, tz)?; + Ok(ColumnarValue::from(ScalarValue::TimestampNanosecond( + Some(adjusted_ts), + None, + ))) + } + ScalarValue::TimestampMicrosecond(Some(ts), Some(_)) => { + let adjusted_ts = adjust_to_local_time::< + TimestampMicrosecondType, + >(*ts, tz)?; + Ok(ColumnarValue::from(ScalarValue::TimestampMicrosecond( + Some(adjusted_ts), + None, + ))) + } + ScalarValue::TimestampMillisecond(Some(ts), Some(_)) => { + let adjusted_ts = adjust_to_local_time::< + TimestampMillisecondType, + >(*ts, tz)?; + Ok(ColumnarValue::from(ScalarValue::TimestampMillisecond( + Some(adjusted_ts), + None, + ))) + } + ScalarValue::TimestampSecond(Some(ts), Some(_)) => { + let adjusted_ts = + adjust_to_local_time::(*ts, 
tz)?; + Ok(ColumnarValue::from(ScalarValue::TimestampSecond( + Some(adjusted_ts), + None, + ))) + } + _ => { + exec_err!( + "to_local_time function requires timestamp argument, got {:?}", + time_value.data_type() + ) + } + }, ColumnarValue::Array(array) => { fn transform_array( array: &ArrayRef, @@ -185,12 +183,6 @@ impl ToLocalTimeFunc { } } } - _ => { - exec_err!( - "to_local_time function requires timestamp argument, got {:?}", - time_value.data_type() - ) - } } } _ => { @@ -486,11 +478,11 @@ mod tests { fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { let res = ToLocalTimeFunc::new() - .invoke(&[ColumnarValue::Scalar(input)]) + .invoke(&[ColumnarValue::from(input)]) .unwrap(); match res { ColumnarValue::Scalar(res) => { - assert_eq!(res, expected); + assert_eq!(res.into_value(), expected); } _ => panic!("unexpected return type"), } diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index b767fd0720db..4a6f62e39c70 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -161,7 +161,7 @@ impl ScalarUDFImpl for ToTimestampFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Int32 | DataType::Int64 => args[0] .cast_to(&Timestamp(Second, None), None)? 
.cast_to(&Timestamp(Nanosecond, None), None), @@ -214,7 +214,7 @@ impl ScalarUDFImpl for ToTimestampSecondsFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Second, None), None) } @@ -264,7 +264,7 @@ impl ScalarUDFImpl for ToTimestampMillisFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Millisecond, None), None) } @@ -314,7 +314,7 @@ impl ScalarUDFImpl for ToTimestampMicrosFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Microsecond, None), None) } @@ -364,7 +364,7 @@ impl ScalarUDFImpl for ToTimestampNanosFunc { validate_data_types(args, "to_timestamp")?; } - match args[0].data_type() { + match args[0].data_type().clone() { DataType::Null | DataType::Int32 | DataType::Int64 | Timestamp(_, None) => { args[0].cast_to(&Timestamp(Nanosecond, None), None) } @@ -803,7 +803,7 @@ mod tests { for udf in &udfs { for array in arrays { - let rt = udf.return_type(&[array.data_type()]).unwrap(); + let rt = udf.return_type(&[array.data_type().clone()]).unwrap(); assert!(matches!(rt, DataType::Timestamp(_, Some(_)))); let res = udf @@ -846,7 +846,7 @@ mod tests { for udf in &udfs { for array in arrays { - let rt = udf.return_type(&[array.data_type()]).unwrap(); + let rt = udf.return_type(&[array.data_type().clone()]).unwrap(); assert!(matches!(rt, DataType::Timestamp(_, None))); let res = udf @@ -896,9 +896,9 @@ mod tests { // test UTF8 let string_array = [ ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), - 
ColumnarValue::Scalar(ScalarValue::Utf8(Some("%s".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("%c".to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some("%+".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("%s".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("%c".to_string()))), + ColumnarValue::from(ScalarValue::Utf8(Some("%+".to_string()))), ]; let parsed_timestamps = func(&string_array) .expect("that to_timestamp with format args parsed values without error"); @@ -925,9 +925,9 @@ mod tests { // test other types let string_array = [ ColumnarValue::Array(Arc::new(data.clone()) as ArrayRef), - ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(2))), - ColumnarValue::Scalar(ScalarValue::Int32(Some(3))), + ColumnarValue::from(ScalarValue::Int32(Some(1))), + ColumnarValue::from(ScalarValue::Int32(Some(2))), + ColumnarValue::from(ScalarValue::Int32(Some(3))), ]; let expected = "Unsupported data type Int32 for function".to_string(); diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index aedbe18ec18c..5a1e85af5ccc 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -173,7 +173,7 @@ fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(a) => { Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) } @@ -198,7 +198,7 @@ fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(a) => { encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) } @@ -254,7 +254,7 @@ macro_rules! 
decode_to_array { impl Encoding { fn encode_scalar(self, value: Option<&[u8]>) -> ColumnarValue { - ColumnarValue::Scalar(match self { + ColumnarValue::from(match self { Self::Base64 => ScalarValue::Utf8( value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), ), @@ -289,7 +289,7 @@ impl Encoding { fn decode_scalar(self, value: Option<&[u8]>) -> Result { let value = match value { Some(value) => value, - None => return Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))), + None => return Ok(ColumnarValue::from(ScalarValue::Binary(None))), }; let out = match self { @@ -311,7 +311,7 @@ impl Encoding { })?, }; - Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(out)))) + Ok(ColumnarValue::from(ScalarValue::Binary(Some(out)))) } fn decode_binary_array(self, value: &dyn Array) -> Result @@ -376,7 +376,7 @@ fn encode(args: &[ColumnarValue]) -> Result { ); } let encoding = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(Some(method)) => { method.parse::() } @@ -402,7 +402,7 @@ fn decode(args: &[ColumnarValue]) -> Result { ); } let encoding = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(Some(method)) => { method.parse::() } diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index ad7cff1f7149..7c65439478df 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -110,7 +110,7 @@ impl ScalarUDFImpl for LogFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { let args = ColumnarValue::values_to_arrays(args)?; - let mut base = ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))); + let mut base = ColumnarValue::from(ScalarValue::Float32(Some(10.0))); let mut x = &args[0]; if args.len() == 2 { @@ -120,11 +120,18 @@ impl ScalarUDFImpl for LogFunc { // note in f64::log params order is different than in 
sql. e.g in sql log(base, x) == f64::log(x, base) let arr: ArrayRef = match args[0].data_type() { DataType::Float64 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { - |value: f64| f64::log(value, base as f64) - })) - } + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Float32(Some(base)) => { + Arc::new(make_function_scalar_inputs!(x, "x", Float64Array, { + |value: f64| f64::log(value, base as f64) + })) + } + _ => { + return exec_err!( + "log function requires a scalar or array for base" + ) + } + }, ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( x, base, @@ -133,17 +140,21 @@ impl ScalarUDFImpl for LogFunc { Float64Array, { f64::log } )), - _ => { - return exec_err!("log function requires a scalar or array for base") - } }, DataType::Float32 => match base { - ColumnarValue::Scalar(ScalarValue::Float32(Some(base))) => { - Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { - |value: f32| f32::log(value, base) - })) - } + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Float32(Some(base)) => { + Arc::new(make_function_scalar_inputs!(x, "x", Float32Array, { + |value: f32| f32::log(value, base) + })) + } + _ => { + return exec_err!( + "log function requires a scalar or array for base" + ) + } + }, ColumnarValue::Array(base) => Arc::new(make_function_inputs2!( x, base, @@ -152,9 +163,6 @@ impl ScalarUDFImpl for LogFunc { Float32Array, { f32::log } )), - _ => { - return exec_err!("log function requires a scalar or array for base") - } }, other => { return exec_err!("Unsupported data type {other:?} for function log") diff --git a/datafusion/functions/src/math/pi.rs b/datafusion/functions/src/math/pi.rs index c2fe4efb1139..b4adfc190e69 100644 --- a/datafusion/functions/src/math/pi.rs +++ b/datafusion/functions/src/math/pi.rs @@ -64,7 +64,7 @@ impl ScalarUDFImpl for PiFunc { } fn 
invoke_no_args(&self, _number_rows: usize) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + Ok(ColumnarValue::from(ScalarValue::Float64(Some( std::f64::consts::PI, )))) } diff --git a/datafusion/functions/src/math/round.rs b/datafusion/functions/src/math/round.rs index 89554a76febb..05d3d055c2b8 100644 --- a/datafusion/functions/src/math/round.rs +++ b/datafusion/functions/src/math/round.rs @@ -108,7 +108,7 @@ pub fn round(args: &[ArrayRef]) -> Result { ); } - let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0))); + let mut decimal_places = ColumnarValue::from(ScalarValue::Int64(Some(0))); if args.len() == 2 { decimal_places = ColumnarValue::Array(Arc::clone(&args[1])); @@ -116,25 +116,32 @@ pub fn round(args: &[ArrayRef]) -> Result { match args[0].data_type() { DataType::Float64 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places: i32 = decimal_places.try_into().map_err(|e| { - exec_datafusion_err!( - "Invalid value for decimal places: {decimal_places}: {e}" - ) - })?; + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Int64(Some(decimal_places)) => { + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float64Array, - { - |value: f64| { - (value * 10.0_f64.powi(decimal_places)).round() - / 10.0_f64.powi(decimal_places) + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float64Array, + { + |value: f64| { + (value * 10.0_f64.powi(decimal_places)).round() + / 10.0_f64.powi(decimal_places) + } } - } - )) as ArrayRef) - } + )) as ArrayRef) + } + _ => { + exec_err!( + "round function requires a scalar or array for decimal_places" + ) + } + }, ColumnarValue::Array(decimal_places) => { let options = CastOptions { safe: false, // raise error if 
the cast is not possible @@ -159,31 +166,35 @@ pub fn round(args: &[ArrayRef]) -> Result { } )) as ArrayRef) } - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } }, DataType::Float32 => match decimal_places { - ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => { - let decimal_places: i32 = decimal_places.try_into().map_err(|e| { - exec_datafusion_err!( - "Invalid value for decimal places: {decimal_places}: {e}" - ) - })?; + ColumnarValue::Scalar(scalar) => match scalar.into_value() { + ScalarValue::Int64(Some(decimal_places)) => { + let decimal_places: i32 = decimal_places.try_into().map_err(|e| { + exec_datafusion_err!( + "Invalid value for decimal places: {decimal_places}: {e}" + ) + })?; - Ok(Arc::new(make_function_scalar_inputs!( - &args[0], - "value", - Float32Array, - { - |value: f32| { - (value * 10.0_f32.powi(decimal_places)).round() - / 10.0_f32.powi(decimal_places) + Ok(Arc::new(make_function_scalar_inputs!( + &args[0], + "value", + Float32Array, + { + |value: f32| { + (value * 10.0_f32.powi(decimal_places)).round() + / 10.0_f32.powi(decimal_places) + } } - } - )) as ArrayRef) - } + )) as ArrayRef) + } + _ => { + exec_err!( + "round function requires a scalar or array for decimal_places" + ) + } + }, ColumnarValue::Array(_) => { let ColumnarValue::Array(decimal_places) = decimal_places.cast_to(&Int32, None).map_err(|e| { @@ -208,9 +219,6 @@ pub fn round(args: &[ArrayRef]) -> Result { } )) as ArrayRef) } - _ => { - exec_err!("round function requires a scalar or array for decimal_places") - } }, other => exec_err!("Unsupported data type {other:?} for function round"), diff --git a/datafusion/functions/src/math/trunc.rs b/datafusion/functions/src/math/trunc.rs index 3344438454c4..e6020f18d168 100644 --- a/datafusion/functions/src/math/trunc.rs +++ b/datafusion/functions/src/math/trunc.rs @@ -115,16 +115,22 @@ fn trunc(args: &[ArrayRef]) -> Result { //or then invoke the compute_truncate method to process 
precision let num = &args[0]; let precision = if args.len() == 1 { - ColumnarValue::Scalar(Int64(Some(0))) + ColumnarValue::from(Int64(Some(0))) } else { ColumnarValue::Array(Arc::clone(&args[1])) }; match args[0].data_type() { Float64 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }), - ) as ArrayRef), + ColumnarValue::Scalar(scalar) => match scalar.value() { + Int64(Some(0)) => Ok(Arc::new(make_function_scalar_inputs!( + num, + "num", + Float64Array, + { f64::trunc } + )) as ArrayRef), + _ => exec_err!("trunc function requires a scalar or array for precision"), + }, ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( num, precision, @@ -134,12 +140,17 @@ fn trunc(args: &[ArrayRef]) -> Result { Int64Array, { compute_truncate64 } )) as ArrayRef), - _ => exec_err!("trunc function requires a scalar or array for precision"), }, Float32 => match precision { - ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new( - make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }), - ) as ArrayRef), + ColumnarValue::Scalar(scalar) => match scalar.value() { + Int64(Some(0)) => Ok(Arc::new(make_function_scalar_inputs!( + num, + "num", + Float32Array, + { f32::trunc } + )) as ArrayRef), + _ => exec_err!("trunc function requires a scalar or array for precision"), + }, ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!( num, precision, @@ -149,7 +160,6 @@ fn trunc(args: &[ArrayRef]) -> Result { Int64Array, { compute_truncate32 } )) as ArrayRef), - _ => exec_err!("trunc function requires a scalar or array for precision"), }, other => exec_err!("Unsupported data type {other:?} for function trunc"), } diff --git a/datafusion/functions/src/regex/regexplike.rs b/datafusion/functions/src/regex/regexplike.rs index 20029ba005c4..c67bc4157298 100644 --- a/datafusion/functions/src/regex/regexplike.rs +++ 
b/datafusion/functions/src/regex/regexplike.rs @@ -20,13 +20,12 @@ use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use datafusion_common::exec_err; -use datafusion_common::ScalarValue; use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Scalar}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; @@ -99,7 +98,7 @@ impl ScalarUDFImpl for RegexpLikeFunc { let result = regexp_like_func(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) diff --git a/datafusion/functions/src/regex/regexpmatch.rs b/datafusion/functions/src/regex/regexpmatch.rs index 764acd7de757..2f7f1d05b2ae 100644 --- a/datafusion/functions/src/regex/regexpmatch.rs +++ b/datafusion/functions/src/regex/regexpmatch.rs @@ -21,13 +21,12 @@ use arrow::compute::kernels::regexp; use arrow::datatypes::DataType; use arrow::datatypes::Field; use datafusion_common::exec_err; -use datafusion_common::ScalarValue; use datafusion_common::{arrow_datafusion_err, plan_err}; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; -use datafusion_expr::ColumnarValue; use datafusion_expr::TypeSignature::*; +use datafusion_expr::{ColumnarValue, Scalar}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; use std::sync::Arc; @@ -97,7 +96,7 @@ impl ScalarUDFImpl for RegexpMatchFunc { let result = regexp_match_func(&args); if is_scalar { // If all inputs are scalar, keeps output as 
scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index d28c6cd36d65..021e3105624f 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -27,12 +27,12 @@ use arrow::datatypes::DataType; use datafusion_common::cast::as_string_view_array; use datafusion_common::exec_err; use datafusion_common::plan_err; -use datafusion_common::ScalarValue; use datafusion_common::{ cast::as_generic_string_array, internal_err, DataFusionError, Result, }; use datafusion_expr::function::Hint; use datafusion_expr::ColumnarValue; +use datafusion_expr::Scalar; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use regex::Regex; @@ -116,7 +116,7 @@ impl ScalarUDFImpl for RegexpReplaceFunc { let result = regexp_replace_func(args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) diff --git a/datafusion/functions/src/string/ascii.rs b/datafusion/functions/src/string/ascii.rs index 016e5f11893a..196a5a607639 100644 --- a/datafusion/functions/src/string/ascii.rs +++ b/datafusion/functions/src/string/ascii.rs @@ -122,7 +122,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( AsciiFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + &[ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/string/bit_length.rs b/datafusion/functions/src/string/bit_length.rs index 
6f1926f5a96d..0dd07fa87f69 100644 --- a/datafusion/functions/src/string/bit_length.rs +++ b/datafusion/functions/src/string/bit_length.rs @@ -77,8 +77,8 @@ impl ScalarUDFImpl for BitLengthFunc { match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(bit_length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + ColumnarValue::Scalar(v) => match v.value() { + ScalarValue::Utf8(v) => Ok(ColumnarValue::from(ScalarValue::Int32( v.as_ref().map(|x| (x.len() * 8) as i32), ))), _ => unreachable!(), diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 393b1f4b3317..3a59c8a115f0 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -214,10 +214,10 @@ where >(array, op)?)), other => exec_err!("Unsupported data type {other:?} for function {name}"), }, - ColumnarValue::Scalar(scalar) => match scalar { + ColumnarValue::Scalar(scalar) => match scalar.value() { ScalarValue::Utf8(a) => { let result = a.as_ref().map(|x| op(x)); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + Ok(ColumnarValue::from(ScalarValue::Utf8(result))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 827a2dfef222..770dd0c30979 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -83,11 +83,13 @@ impl ScalarUDFImpl for ConcatFunc { if array_len.is_none() { let mut result = String::new(); for arg in args { - if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { - result.push_str(v); + if let ColumnarValue::Scalar(scalar) = arg { + if let ScalarValue::Utf8(Some(v)) = scalar.value() { + result.push_str(v); + } } } - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + return 
Ok(ColumnarValue::from(ScalarValue::Utf8(Some(result)))); } // Array @@ -97,12 +99,15 @@ impl ScalarUDFImpl for ConcatFunc { for arg in args { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(maybe_value) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + } } - } + _ => unreachable!(), + }, ColumnarValue::Array(array) => { let string_array = as_string_array(array)?; data_size += string_array.values().len(); @@ -113,7 +118,6 @@ impl ScalarUDFImpl for ConcatFunc { }; columns.push(column); } - _ => unreachable!(), } } @@ -203,9 +207,9 @@ mod tests { test_function!( ConcatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::from("bb")), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::from("bb")), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(Some("aabbcc")), &str, @@ -215,9 +219,9 @@ mod tests { test_function!( ConcatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(Some("aacc")), &str, @@ -226,7 +230,7 @@ mod tests { ); test_function!( ConcatFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + &[ColumnarValue::from(ScalarValue::Utf8(None))], Ok(Some("")), &str, Utf8, @@ -240,7 +244,7 @@ mod tests { fn concat() -> Result<()> { let c0 = ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); - let c1 = 
ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c1 = ColumnarValue::from(ScalarValue::Utf8(Some(",".to_string()))); let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ Some("x"), None, diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index bdf153eaccb6..5b9dd5f75ab6 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -92,10 +92,13 @@ impl ScalarUDFImpl for ConcatWsFunc { // Scalar if array_len.is_none() { let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(s)) => s, + ScalarValue::Utf8(None) => { + return Ok(ColumnarValue::from(ScalarValue::Utf8(None))); + } + _ => unreachable!(), + }, _ => unreachable!(), }; @@ -104,27 +107,33 @@ impl ScalarUDFImpl for ConcatWsFunc { for arg in iter.by_ref() { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - result.push_str(s); - break; - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(s)) => { + result.push_str(s); + break; + } + ScalarValue::Utf8(None) => {} + _ => unreachable!(), + }, _ => unreachable!(), } } for arg in iter.by_ref() { match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - result.push_str(sep); - result.push_str(s); - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(s)) => { + result.push_str(sep); + result.push_str(s); + } + ScalarValue::Utf8(None) => {} + _ => unreachable!(), + }, _ => unreachable!(), } } - return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + return 
Ok(ColumnarValue::from(ScalarValue::Utf8(Some(result)))); } // Array @@ -133,13 +142,18 @@ impl ScalarUDFImpl for ConcatWsFunc { // parse sep let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { - data_size += s.len() * len * (args.len() - 2); // estimate - ColumnarValueRef::Scalar(s.as_bytes()) - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { - return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(Some(s)) => { + data_size += s.len() * len * (args.len() - 2); // estimate + ColumnarValueRef::Scalar(s.as_bytes()) + } + ScalarValue::Utf8(None) => { + return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null( + len, + )))); + } + _ => unreachable!(), + }, ColumnarValue::Array(array) => { let string_array = as_string_array(array)?; data_size += string_array.values().len() * (args.len() - 2); // estimate @@ -149,18 +163,20 @@ impl ScalarUDFImpl for ConcatWsFunc { ColumnarValueRef::NonNullableArray(string_array) } } - _ => unreachable!(), }; let mut columns = Vec::with_capacity(args.len() - 1); for arg in &args[1..] 
{ match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(s) = maybe_value { - data_size += s.len() * len; - columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Utf8(maybe_value) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + } } - } + _ => unreachable!(), + }, ColumnarValue::Array(array) => { let string_array = as_string_array(array)?; data_size += string_array.values().len(); @@ -171,7 +187,6 @@ impl ScalarUDFImpl for ConcatWsFunc { }; columns.push(column); } - _ => unreachable!(), } } @@ -316,10 +331,10 @@ mod tests { test_function!( ConcatWsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("|")), - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::from("bb")), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::from("|")), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::from("bb")), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(Some("aa|bb|cc")), &str, @@ -329,8 +344,8 @@ mod tests { test_function!( ConcatWsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("|")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("|")), + ColumnarValue::from(ScalarValue::Utf8(None)), ], Ok(Some("")), &str, @@ -340,10 +355,10 @@ mod tests { test_function!( ConcatWsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::from("bb")), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::from("bb")), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(None), &str, @@ -353,10 +368,10 @@ mod tests { test_function!( 
ConcatWsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("|")), - ColumnarValue::Scalar(ScalarValue::from("aa")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("cc")), + ColumnarValue::from(ScalarValue::from("|")), + ColumnarValue::from(ScalarValue::from("aa")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("cc")), ], Ok(Some("aa|cc")), &str, @@ -370,7 +385,7 @@ mod tests { #[test] fn concat_ws() -> Result<()> { // sep is scalar - let c0 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c0 = ColumnarValue::from(ScalarValue::Utf8(Some(",".to_string()))); let c1 = ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index faf979f80614..32a13e2cda72 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -108,8 +108,8 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("alph")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("alph")), ], Ok(Some(true)), bool, @@ -119,8 +119,8 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("dddddd")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("dddddd")), ], Ok(Some(false)), bool, @@ -130,8 +130,8 @@ mod tests { test_function!( ContainsFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("pha")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("pha")), ], Ok(Some(true)), 
bool, diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs index 03a1795954d0..82fde772c282 100644 --- a/datafusion/functions/src/string/ends_with.rs +++ b/datafusion/functions/src/string/ends_with.rs @@ -111,8 +111,8 @@ mod tests { test_function!( EndsWithFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("alph")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("alph")), ], Ok(Some(false)), bool, @@ -122,8 +122,8 @@ mod tests { test_function!( EndsWithFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from("bet")), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from("bet")), ], Ok(Some(true)), bool, @@ -133,8 +133,8 @@ mod tests { test_function!( EndsWithFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("alph")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("alph")), ], Ok(None), bool, @@ -144,8 +144,8 @@ mod tests { test_function!( EndsWithFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::Utf8(None)), ], Ok(None), bool, diff --git a/datafusion/functions/src/string/initcap.rs b/datafusion/functions/src/string/initcap.rs index 15861c39e807..66e095b84e0f 100644 --- a/datafusion/functions/src/string/initcap.rs +++ b/datafusion/functions/src/string/initcap.rs @@ -137,7 +137,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))], + &[ColumnarValue::from(ScalarValue::from("hi THOMAS"))], Ok(Some("Hi Thomas")), &str, Utf8, @@ -145,7 +145,7 @@ mod tests { ); 
test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + &[ColumnarValue::from(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -153,7 +153,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::from(""))], + &[ColumnarValue::from(ScalarValue::from(""))], Ok(Some("")), &str, Utf8, @@ -161,7 +161,7 @@ mod tests { ); test_function!( InitcapFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + &[ColumnarValue::from(ScalarValue::Utf8(None))], Ok(None), &str, Utf8, diff --git a/datafusion/functions/src/string/octet_length.rs b/datafusion/functions/src/string/octet_length.rs index d0a533333247..f228291a7925 100644 --- a/datafusion/functions/src/string/octet_length.rs +++ b/datafusion/functions/src/string/octet_length.rs @@ -77,8 +77,8 @@ impl ScalarUDFImpl for OctetLengthFunc { match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), - ColumnarValue::Scalar(v) => match v { - ScalarValue::Utf8(v) => Ok(ColumnarValue::Scalar(ScalarValue::Int32( + ColumnarValue::Scalar(v) => match v.value() { + ScalarValue::Utf8(v) => Ok(ColumnarValue::from(ScalarValue::Int32( v.as_ref().map(|x| x.len() as i32), ))), _ => unreachable!(), @@ -105,7 +105,7 @@ mod tests { fn test_functions() -> Result<()> { test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Int32(Some(12)))], + &[ColumnarValue::from(ScalarValue::Int32(Some(12)))], exec_err!( "The OCTET_LENGTH function can only accept strings, but got Int32." 
), @@ -127,8 +127,8 @@ mod tests { test_function!( OctetLengthFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("chars")))) + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("chars")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("chars")))) ], exec_err!("octet_length function requires 1 argument, got 2"), i32, @@ -137,9 +137,9 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("chars") - )))], + &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + "chars" + ))))], Ok(Some(5)), i32, Int32, @@ -147,9 +147,9 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("josé") - )))], + &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + "josé" + ))))], Ok(Some(5)), i32, Int32, @@ -157,9 +157,9 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(Some( - String::from("") - )))], + &[ColumnarValue::from(ScalarValue::Utf8(Some(String::from( + "" + ))))], Ok(Some(0)), i32, Int32, @@ -167,7 +167,7 @@ mod tests { ); test_function!( OctetLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8(None))], + &[ColumnarValue::from(ScalarValue::Utf8(None))], Ok(None), i32, Int32, diff --git a/datafusion/functions/src/string/repeat.rs b/datafusion/functions/src/string/repeat.rs index f04e4ce87546..64ba7e5f65c8 100644 --- a/datafusion/functions/src/string/repeat.rs +++ b/datafusion/functions/src/string/repeat.rs @@ -147,8 +147,8 @@ mod tests { test_function!( RepeatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::from(ScalarValue::Int64(Some(4))), ], Ok(Some("PgPgPgPg")), &str, @@ -158,8 
+158,8 @@ mod tests { test_function!( RepeatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::Int64(Some(4))), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::Int64(Some(4))), ], Ok(None), &str, @@ -169,8 +169,8 @@ mod tests { test_function!( RepeatFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("Pg")))), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("Pg")))), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 19721f0fad28..6c0690526803 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -196,11 +196,11 @@ mod tests { test_function!( SplitPartFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(2))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::from(ScalarValue::Int64(Some(2))), ], Ok(Some("def")), &str, @@ -210,11 +210,11 @@ mod tests { test_function!( SplitPartFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(20))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::from(ScalarValue::Int64(Some(20))), ], Ok(Some("")), &str, @@ -224,11 +224,11 @@ mod tests { test_function!( SplitPartFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + 
ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(-1))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::from(ScalarValue::Int64(Some(-1))), ], Ok(Some("ghi")), &str, @@ -238,11 +238,11 @@ mod tests { test_function!( SplitPartFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from( + ColumnarValue::from(ScalarValue::Utf8(Some(String::from( "abc~@~def~@~ghi" )))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))), - ColumnarValue::Scalar(ScalarValue::Int64(Some(0))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("~@~")))), + ColumnarValue::from(ScalarValue::Int64(Some(0))), ], exec_err!("field position must not be zero"), &str, diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs index 78dd49ba4509..07302aa06465 100644 --- a/datafusion/functions/src/string/starts_with.rs +++ b/datafusion/functions/src/string/starts_with.rs @@ -117,8 +117,8 @@ mod tests { .into_iter() .flat_map(|(a, b, c)| { let utf_8_args = vec![ - ColumnarValue::Scalar(ScalarValue::Utf8(a.map(|s| s.to_string()))), - ColumnarValue::Scalar(ScalarValue::Utf8(b.map(|s| s.to_string()))), + ColumnarValue::from(ScalarValue::Utf8(a.map(|s| s.to_string()))), + ColumnarValue::from(ScalarValue::Utf8(b.map(|s| s.to_string()))), ]; vec![(utf_8_args, c)] diff --git a/datafusion/functions/src/unicode/character_length.rs b/datafusion/functions/src/unicode/character_length.rs index 9e8de0a8405f..02a6ed100949 100644 --- a/datafusion/functions/src/unicode/character_length.rs +++ b/datafusion/functions/src/unicode/character_length.rs @@ -131,7 +131,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( CharacterLengthFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + 
&[ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, i32, Int32, diff --git a/datafusion/functions/src/unicode/left.rs b/datafusion/functions/src/unicode/left.rs index c49784948dd0..f1f84c98ef5e 100644 --- a/datafusion/functions/src/unicode/left.rs +++ b/datafusion/functions/src/unicode/left.rs @@ -153,8 +153,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("ab")), &str, @@ -164,8 +164,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(200i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(200i64)), ], Ok(Some("abcde")), &str, @@ -175,8 +175,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(-2i64)), ], Ok(Some("abc")), &str, @@ -186,8 +186,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(-200i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(-200i64)), ], Ok(Some("")), &str, @@ -197,8 +197,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("")), &str, @@ -208,8 +208,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + 
ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(None), &str, @@ -219,8 +219,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -230,8 +230,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("joséé")), &str, @@ -241,8 +241,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(-3i64)), ], Ok(Some("joséé")), &str, @@ -253,8 +253,8 @@ mod tests { test_function!( LeftFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(2i64)), ], internal_err!( "function left requires compilation with feature flag: unicode_expressions." 
diff --git a/datafusion/functions/src/unicode/lpad.rs b/datafusion/functions/src/unicode/lpad.rs index 0664da4101e7..d45b2639cb6c 100644 --- a/datafusion/functions/src/unicode/lpad.rs +++ b/datafusion/functions/src/unicode/lpad.rs @@ -265,8 +265,8 @@ mod tests { test_function!( LPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), - ColumnarValue::Scalar($LENGTH) + ColumnarValue::from(ScalarValue::Utf8($INPUT)), + ColumnarValue::from($LENGTH) ], $EXPECTED, &str, @@ -280,9 +280,9 @@ mod tests { test_function!( LPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8($INPUT)), - ColumnarValue::Scalar($LENGTH), - ColumnarValue::Scalar(ScalarValue::Utf8($REPLACE)) + ColumnarValue::from(ScalarValue::Utf8($INPUT)), + ColumnarValue::from($LENGTH), + ColumnarValue::from(ScalarValue::Utf8($REPLACE)) ], $EXPECTED, &str, diff --git a/datafusion/functions/src/unicode/reverse.rs b/datafusion/functions/src/unicode/reverse.rs index ef290a9b3970..7b72570a83e4 100644 --- a/datafusion/functions/src/unicode/reverse.rs +++ b/datafusion/functions/src/unicode/reverse.rs @@ -117,7 +117,7 @@ mod tests { ($INPUT:expr, $EXPECTED:expr) => { test_function!( ReverseFunc::new(), - &[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))], + &[ColumnarValue::from(ScalarValue::Utf8($INPUT))], $EXPECTED, &str, Utf8, diff --git a/datafusion/functions/src/unicode/right.rs b/datafusion/functions/src/unicode/right.rs index 9d542bb2c006..7fadb058c19b 100644 --- a/datafusion/functions/src/unicode/right.rs +++ b/datafusion/functions/src/unicode/right.rs @@ -156,8 +156,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("de")), &str, @@ -167,8 +167,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - 
ColumnarValue::Scalar(ScalarValue::from(200i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(200i64)), ], Ok(Some("abcde")), &str, @@ -178,8 +178,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(-2i64)), ], Ok(Some("cde")), &str, @@ -189,8 +189,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(-200i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(-200i64)), ], Ok(Some("")), &str, @@ -200,8 +200,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("")), &str, @@ -211,8 +211,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(None), &str, @@ -222,8 +222,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -233,8 +233,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("éésoj")), &str, @@ -244,8 +244,8 @@ mod tests { 
test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(-3i64)), ], Ok(Some("éésoj")), &str, @@ -256,8 +256,8 @@ mod tests { test_function!( RightFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("abcde")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("abcde")), + ColumnarValue::from(ScalarValue::from(2i64)), ], internal_err!( "function right requires compilation with feature flag: unicode_expressions." diff --git a/datafusion/functions/src/unicode/rpad.rs b/datafusion/functions/src/unicode/rpad.rs index 4bcf102c8793..657e595f94fa 100644 --- a/datafusion/functions/src/unicode/rpad.rs +++ b/datafusion/functions/src/unicode/rpad.rs @@ -266,8 +266,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("josé")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("josé ")), &str, @@ -277,8 +277,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("hi ")), &str, @@ -288,8 +288,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("")), &str, @@ -299,8 +299,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("hi")), + 
ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -310,8 +310,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(None), &str, @@ -321,9 +321,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("xy")), ], Ok(Some("hixyx")), &str, @@ -333,9 +333,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(21i64)), - ColumnarValue::Scalar(ScalarValue::from("abcdef")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(21i64)), + ColumnarValue::from(ScalarValue::from("abcdef")), ], Ok(Some("hiabcdefabcdefabcdefa")), &str, @@ -345,9 +345,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from(" ")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from(" ")), ], Ok(Some("hi ")), &str, @@ -357,9 +357,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("")), ], Ok(Some("hi")), &str, @@ -369,9 +369,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - 
ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("xy")), ], Ok(None), &str, @@ -381,9 +381,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ColumnarValue::Scalar(ScalarValue::from("xy")), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("xy")), ], Ok(None), &str, @@ -393,9 +393,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("hi")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("hi")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::Utf8(None)), ], Ok(None), &str, @@ -405,9 +405,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("xy")), + ColumnarValue::from(ScalarValue::from("josé")), + ColumnarValue::from(ScalarValue::from(10i64)), + ColumnarValue::from(ScalarValue::from("xy")), ], Ok(Some("joséxyxyxy")), &str, @@ -417,9 +417,9 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - ColumnarValue::Scalar(ScalarValue::from(10i64)), - ColumnarValue::Scalar(ScalarValue::from("éñ")), + ColumnarValue::from(ScalarValue::from("josé")), + ColumnarValue::from(ScalarValue::from(10i64)), + ColumnarValue::from(ScalarValue::from("éñ")), ], Ok(Some("josééñéñéñ")), &str, @@ -430,8 +430,8 @@ mod tests { test_function!( RPadFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("josé")), - 
ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("josé")), + ColumnarValue::from(ScalarValue::from(5i64)), ], internal_err!( "function rpad requires compilation with feature flag: unicode_expressions." diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index e756d4b1af7d..7ac2e2da8a29 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -184,8 +184,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("alphabet")), &str, @@ -195,8 +195,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("ésoj")), &str, @@ -206,8 +206,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(-5i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(-5i64)), ], Ok(Some("joséésoj")), &str, @@ -217,8 +217,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(1i64)), ], Ok(Some("alphabet")), &str, @@ -228,8 +228,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(2i64)), ], 
Ok(Some("lphabet")), &str, @@ -239,8 +239,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(3i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(3i64)), ], Ok(Some("phabet")), &str, @@ -250,8 +250,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(-3i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(-3i64)), ], Ok(Some("alphabet")), &str, @@ -261,8 +261,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(30i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(30i64)), ], Ok(Some("")), &str, @@ -272,8 +272,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -283,9 +283,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(3i64)), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("ph")), &str, @@ -295,9 +295,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::from(20i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(3i64)), + 
ColumnarValue::from(ScalarValue::from(20i64)), ], Ok(Some("phabet")), &str, @@ -307,9 +307,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("alph")), &str, @@ -320,9 +320,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(-5i64)), - ColumnarValue::Scalar(ScalarValue::from(10i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(-5i64)), + ColumnarValue::from(ScalarValue::from(10i64)), ], Ok(Some("alph")), &str, @@ -333,9 +333,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(-5i64)), - ColumnarValue::Scalar(ScalarValue::from(4i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(-5i64)), + ColumnarValue::from(ScalarValue::from(4i64)), ], Ok(Some("")), &str, @@ -346,9 +346,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(-5i64)), - ColumnarValue::Scalar(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(-5i64)), + ColumnarValue::from(ScalarValue::from(5i64)), ], Ok(Some("")), &str, @@ -358,9 +358,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::Int64(None)), - ColumnarValue::Scalar(ScalarValue::from(20i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + 
ColumnarValue::from(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from(20i64)), ], Ok(None), &str, @@ -370,9 +370,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(3i64)), - ColumnarValue::Scalar(ScalarValue::Int64(None)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(3i64)), + ColumnarValue::from(ScalarValue::Int64(None)), ], Ok(None), &str, @@ -382,9 +382,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), - ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from(-1i64)), ], exec_err!("negative substring length not allowed: substr(, 1, -1)"), &str, @@ -394,9 +394,9 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("joséésoj")), - ColumnarValue::Scalar(ScalarValue::from(5i64)), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("joséésoj")), + ColumnarValue::from(ScalarValue::from(5i64)), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("és")), &str, @@ -407,8 +407,8 @@ mod tests { test_function!( SubstrFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("alphabet")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("alphabet")), + ColumnarValue::from(ScalarValue::from(0i64)), ], internal_err!( "function substr requires compilation with feature flag: unicode_expressions." 
diff --git a/datafusion/functions/src/unicode/substrindex.rs b/datafusion/functions/src/unicode/substrindex.rs index 6591ee26403a..9ca3d018d884 100644 --- a/datafusion/functions/src/unicode/substrindex.rs +++ b/datafusion/functions/src/unicode/substrindex.rs @@ -213,9 +213,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(1i64)), ], Ok(Some("www")), &str, @@ -225,9 +225,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(2i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(2i64)), ], Ok(Some("www.apache")), &str, @@ -237,9 +237,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(-2i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(-2i64)), ], Ok(Some("apache.org")), &str, @@ -249,9 +249,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(-1i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(-1i64)), ], Ok(Some("org")), &str, @@ -261,9 +261,9 @@ mod tests { test_function!( 
SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(0i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(0i64)), ], Ok(Some("")), &str, @@ -273,9 +273,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("")), - ColumnarValue::Scalar(ScalarValue::from(".")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from("")), + ColumnarValue::from(ScalarValue::from(".")), + ColumnarValue::from(ScalarValue::from(1i64)), ], Ok(Some("")), &str, @@ -285,9 +285,9 @@ mod tests { test_function!( SubstrIndexFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("www.apache.org")), - ColumnarValue::Scalar(ScalarValue::from("")), - ColumnarValue::Scalar(ScalarValue::from(1i64)), + ColumnarValue::from(ScalarValue::from("www.apache.org")), + ColumnarValue::from(ScalarValue::from("")), + ColumnarValue::from(ScalarValue::from(1i64)), ], Ok(Some("")), &str, diff --git a/datafusion/functions/src/unicode/translate.rs b/datafusion/functions/src/unicode/translate.rs index a42b9c6cb857..d49559d452c8 100644 --- a/datafusion/functions/src/unicode/translate.rs +++ b/datafusion/functions/src/unicode/translate.rs @@ -171,9 +171,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("12345")), - ColumnarValue::Scalar(ScalarValue::from("143")), - ColumnarValue::Scalar(ScalarValue::from("ax")) + ColumnarValue::from(ScalarValue::from("12345")), + ColumnarValue::from(ScalarValue::from("143")), + ColumnarValue::from(ScalarValue::from("ax")) ], Ok(Some("a2x5")), &str, @@ -183,9 +183,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("143")), - 
ColumnarValue::Scalar(ScalarValue::from("ax")) + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("143")), + ColumnarValue::from(ScalarValue::from("ax")) ], Ok(None), &str, @@ -195,9 +195,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("12345")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)), - ColumnarValue::Scalar(ScalarValue::from("ax")) + ColumnarValue::from(ScalarValue::from("12345")), + ColumnarValue::from(ScalarValue::Utf8(None)), + ColumnarValue::from(ScalarValue::from("ax")) ], Ok(None), &str, @@ -207,9 +207,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("12345")), - ColumnarValue::Scalar(ScalarValue::from("143")), - ColumnarValue::Scalar(ScalarValue::Utf8(None)) + ColumnarValue::from(ScalarValue::from("12345")), + ColumnarValue::from(ScalarValue::from("143")), + ColumnarValue::from(ScalarValue::Utf8(None)) ], Ok(None), &str, @@ -219,9 +219,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("é2íñ5")), - ColumnarValue::Scalar(ScalarValue::from("éñí")), - ColumnarValue::Scalar(ScalarValue::from("óü")), + ColumnarValue::from(ScalarValue::from("é2íñ5")), + ColumnarValue::from(ScalarValue::from("éñí")), + ColumnarValue::from(ScalarValue::from("óü")), ], Ok(Some("ó2ü5")), &str, @@ -232,9 +232,9 @@ mod tests { test_function!( TranslateFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::from("12345")), - ColumnarValue::Scalar(ScalarValue::from("143")), - ColumnarValue::Scalar(ScalarValue::from("ax")), + ColumnarValue::from(ScalarValue::from("12345")), + ColumnarValue::from(ScalarValue::from("143")), + ColumnarValue::from(ScalarValue::from("ax")), ], internal_err!( "function translate requires compilation with feature flag: unicode_expressions." 
diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 7b367174006d..f9918f2d9aae 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -20,9 +20,9 @@ use std::sync::Arc; use arrow::array::ArrayRef; use arrow::datatypes::DataType; -use datafusion_common::{Result, ScalarValue}; +use datafusion_common::Result; use datafusion_expr::function::Hint; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_expr::{ColumnarValue, Scalar, ScalarFunctionImplementation}; /// Creates a function to identify the optimal return type of a string function given /// the type of its first argument. @@ -114,7 +114,7 @@ where let result = (inner)(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) @@ -135,7 +135,7 @@ pub mod test { let expected: Result> = $EXPECTED; let func = $FUNC; - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let type_array = $ARGS.iter().map(|arg| arg.data_type().clone()).collect::>(); let return_type = func.return_type(&type_array); match expected { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 40efbba6de7a..d5c76ed7c833 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -986,7 +986,7 @@ mod test { } fn invoke(&self, _args: &[ColumnarValue]) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::from("a"))) + Ok(ColumnarValue::from(ScalarValue::from("a"))) } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 8455919c35a8..a0aafe2f8644 100644 --- a/datafusion/optimizer/src/push_down_filter.rs 
+++ b/datafusion/optimizer/src/push_down_filter.rs @@ -3047,7 +3047,7 @@ Projection: a, b } fn invoke(&self, _args: &[ColumnarValue]) -> Result { - Ok(ColumnarValue::Scalar(ScalarValue::from(1))) + Ok(ColumnarValue::from(ScalarValue::from(1))) } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 559659fdbf24..a707ae80695b 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -27,13 +27,14 @@ use arrow::{ record_batch::RecordBatch, }; -use datafusion_common::logical::eq::LogicallyEq; +// use datafusion_common::logical::eq::LogicallyEq; use datafusion_common::{cast::as_large_list_array, exec_datafusion_err}; use datafusion_common::{ cast::as_list_array, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; +use datafusion_common::logical::eq::LogicallyEq; use datafusion_expr::expr::{InList, InSubquery, WindowFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ @@ -629,7 +630,7 @@ impl<'a> ConstEvaluator<'a> { return ConstSimplifyResult::NotSimplified(s); } - let start_type = match expr.get_type(&self.input_schema) { + let expected_type = match expr.get_type(&self.input_schema) { Ok(t) => t, Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; @@ -645,18 +646,6 @@ impl<'a> ConstEvaluator<'a> { Err(err) => return ConstSimplifyResult::SimplifyRuntimeError(err, expr), }; - // TODO(@notfilippo): a fix for the select_arrow_cast error - let end_type = col_val.data_type(); - if end_type.logically_eq(&start_type) && start_type != end_type { - return ConstSimplifyResult::SimplifyRuntimeError( - exec_datafusion_err!( - "Skipping, end_type {} is logically equal to start_type {} but not strictly equal", - end_type, 
start_type - ), - expr, - ); - } - match col_val { ColumnarValue::Array(a) => { if a.len() != 1 { @@ -691,8 +680,22 @@ impl<'a> ConstEvaluator<'a> { } } ColumnarValue::Scalar(s) => { + // TODO(@notfilippo): a fix for the select_arrow_cast error + let actual_type = s.value().data_type(); + if expected_type.logically_eq(&actual_type) + && expected_type.ne(&actual_type) + { + return ConstSimplifyResult::SimplifyRuntimeError( + exec_datafusion_err!( + "Skipping, actual_type {} is logically equal to expected_type {} but not strictly equal", + actual_type, expected_type + ), + expr, + ); + } + // TODO: support the optimization for `Map` type after support impl hash for it - if matches!(&s, ScalarValue::Map(_)) { + if matches!(s.value(), ScalarValue::Map(_)) { ConstSimplifyResult::SimplifyRuntimeError( DataFusionError::NotImplemented( "Const evaluate for Map type is still not supported" @@ -701,7 +704,7 @@ impl<'a> ConstEvaluator<'a> { expr, ) } else { - ConstSimplifyResult::Simplified(s) + ConstSimplifyResult::Simplified(s.into_value()) } } } diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 96c08d0d3a5b..e2baa373bd41 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -21,8 +21,8 @@ use arrow::buffer::NullBuffer; use arrow::compute::SortOptions; use arrow::error::ArrowError; use datafusion_common::internal_err; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr_common::columnar_value::ColumnarValue; +use datafusion_common::Result; +use datafusion_expr_common::columnar_value::{ColumnarValue, Scalar}; use datafusion_expr_common::operator::Operator; use std::sync::Arc; @@ -39,15 +39,14 @@ pub fn apply( Ok(ColumnarValue::Array(f(&left.as_ref(), &right.as_ref())?)) } (ColumnarValue::Scalar(left), ColumnarValue::Array(right)) => Ok( - ColumnarValue::Array(f(&left.to_scalar()?, &right.as_ref())?), + 
ColumnarValue::Array(f(&left.value().to_scalar()?, &right.as_ref())?), ), (ColumnarValue::Array(left), ColumnarValue::Scalar(right)) => Ok( ColumnarValue::Array(f(&left.as_ref(), &right.to_scalar()?)?), ), (ColumnarValue::Scalar(left), ColumnarValue::Scalar(right)) => { let array = f(&left.to_scalar()?, &right.to_scalar()?)?; - let scalar = ScalarValue::try_from_array(array.as_ref(), 0)?; - Ok(ColumnarValue::Scalar(scalar)) + Ok(ColumnarValue::Scalar(Scalar::try_from_array(&array, 0)?)) } } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 291db27e6bd7..452a57a08ae8 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -276,8 +276,8 @@ impl PhysicalExpr for BinaryExpr { let lhs = self.left.evaluate(batch)?; let rhs = self.right.evaluate(batch)?; - let left_data_type = lhs.data_type(); - let right_data_type = rhs.data_type(); + let left_data_type = lhs.data_type().clone(); + let right_data_type = rhs.data_type().clone(); let schema = batch.schema(); let input_schema = schema.as_ref(); @@ -319,9 +319,10 @@ impl PhysicalExpr for BinaryExpr { let scalar_result = match (&lhs, &rhs) { (ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => { // if left is array and right is literal - use scalar operations - self.evaluate_array_scalar(array, scalar.clone())?.map(|r| { - r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) - }) + self.evaluate_array_scalar(array, scalar.clone().into_value())? 
+ .map(|r| { + r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) + }) } (_, _) => None, // default to array implementation }; diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index c6afb5c05985..10e63b289653 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -223,12 +223,12 @@ impl CaseExpr { .evaluate_selection(batch, &when_match)?; current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_match)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_match, &then_value.to_scalar()?, ¤t_value)? - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Null => nullif(current_value.as_ref(), &when_match)?, + then_value => { + zip(&when_match, &then_value.to_scalar()?, ¤t_value)? + } + }, ColumnarValue::Array(then_value) => { zip(&when_match, &then_value, ¤t_value)? } @@ -294,12 +294,12 @@ impl CaseExpr { .evaluate_selection(batch, &when_value)?; current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? - } + ColumnarValue::Scalar(scalar) => match scalar.value() { + ScalarValue::Null => nullif(current_value.as_ref(), &when_value)?, + then_value => { + zip(&when_value, &then_value.to_scalar()?, ¤t_value)? + } + }, ColumnarValue::Array(then_value) => { zip(&when_value, &then_value, ¤t_value)? 
} diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 33ca659987a3..b2dadccb4a17 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -37,10 +37,8 @@ use datafusion_common::cast::{ as_boolean_array, as_generic_binary_array, as_string_array, }; use datafusion_common::hash_utils::HashValue; -use datafusion_common::{ - exec_err, internal_err, not_impl_err, DFSchema, Result, ScalarValue, -}; -use datafusion_expr::{ColumnarValue, Operator}; +use datafusion_common::{exec_err, internal_err, not_impl_err, DFSchema, Result}; +use datafusion_expr::{ColumnarValue, Operator, Scalar}; use datafusion_physical_expr_common::datum::compare_op_for_nested; use ahash::RandomState; @@ -227,7 +225,7 @@ fn evaluate_list( }) .collect::>>()?; - ScalarValue::iter_to_array(scalars) + Scalar::iter_to_array(scalars) } fn try_cast_static_filter_to_set( @@ -453,7 +451,7 @@ mod tests { use super::*; use crate::expressions; use crate::expressions::{col, lit, try_cast}; - use datafusion_common::plan_err; + use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::type_coercion::binary::comparison_coercion; type InListCastResult = (Arc, Vec>); diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 58559352d44c..50c3cbab9baf 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -76,8 +76,8 @@ impl PhysicalExpr for IsNotNullExpr { let is_not_null = super::is_null::compute_is_not_null(array)?; Ok(ColumnarValue::Array(Arc::new(is_not_null))) } - ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( - ScalarValue::Boolean(Some(!scalar.is_null())), + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::from( + ScalarValue::Boolean(Some(!scalar.value().is_null())), )), } } diff --git 
a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 3cdb49bcab42..cdc5f101002e 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -80,8 +80,8 @@ impl PhysicalExpr for IsNullExpr { ColumnarValue::Array(array) => { Ok(ColumnarValue::Array(Arc::new(compute_is_null(array)?))) } - ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( - ScalarValue::Boolean(Some(scalar.is_null())), + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::from( + ScalarValue::Boolean(Some(scalar.value().is_null())), )), } } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index ed24e9028153..e064abbca35c 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -72,7 +72,7 @@ impl PhysicalExpr for Literal { } fn evaluate(&self, _batch: &RecordBatch) -> Result { - Ok(ColumnarValue::Scalar(self.value.clone())) + Ok(ColumnarValue::from(self.value.clone())) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index b5ebc250cb89..01429614d552 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -83,9 +83,9 @@ impl PhysicalExpr for NegativeExpr { let result = neg_wrapping(array.as_ref())?; Ok(ColumnarValue::Array(result)) } - ColumnarValue::Scalar(scalar) => { - Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?)) - } + ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::from( + (scalar.into_value().arithmetic_negate())?, + )), } } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index b69954e00bba..7a0afaa1a637 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ 
b/datafusion/physical-expr/src/expressions/not.rs @@ -78,13 +78,11 @@ impl PhysicalExpr for NotExpr { ))) } ColumnarValue::Scalar(scalar) => { - if scalar.is_null() { - return Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))); + if scalar.value().is_null() { + return Ok(ColumnarValue::from(ScalarValue::Boolean(None))); } - let bool_value: bool = scalar.try_into()?; - Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( - !bool_value, - )))) + let bool_value: bool = scalar.into_value().try_into()?; + Ok(ColumnarValue::from(ScalarValue::Boolean(Some(!bool_value)))) } } } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 43b6c993d2b2..8da0717165e3 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -23,12 +23,12 @@ use std::sync::Arc; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; use arrow::compute; -use arrow::compute::{cast_with_options, CastOptions}; +use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_common::{not_impl_err, Result}; use datafusion_expr::ColumnarValue; /// TRY_CAST expression casts an expression to a specific data type and returns NULL on invalid cast @@ -83,18 +83,7 @@ impl PhysicalExpr for TryCastExpr { safe: true, format_options: DEFAULT_FORMAT_OPTIONS, }; - match value { - ColumnarValue::Array(array) => { - let cast = cast_with_options(&array, &self.cast_type, &options)?; - Ok(ColumnarValue::Array(cast)) - } - ColumnarValue::Scalar(scalar) => { - let array = scalar.to_array()?; - let cast_array = cast_with_options(&array, &self.cast_type, &options)?; - let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?; - 
Ok(ColumnarValue::Scalar(cast_scalar)) - } - } + value.cast_to(&self.cast_type, Some(&options)) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index e33c28df1988..6b7c64003408 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -34,8 +34,8 @@ use std::sync::Arc; use arrow::array::{Array, ArrayRef}; -use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use datafusion_common::Result; +use datafusion_expr::{ColumnarValue, Scalar, ScalarFunctionImplementation}; pub use crate::scalar_function::create_physical_expr; // For backward compatibility @@ -113,7 +113,7 @@ where let result = (inner)(&args); if is_scalar { // If all inputs are scalar, keeps output as scalar - let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); + let result = result.and_then(|arr| Scalar::try_from_array(&arr, 0)); result.map(ColumnarValue::Scalar) } else { result.map(ColumnarValue::Array) diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index c99a55259306..d2bb8f2b0ead 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -306,11 +306,9 @@ impl ProjectionStream { let arrays = self .expr .iter() - .zip(&self.schema.fields) - .map(|(expr, field)| { - expr.evaluate(batch).and_then(|v| { - v.into_array_of_type(batch.num_rows(), field.data_type()) - }) + .map(|expr| { + expr.evaluate(batch) + .and_then(|v| v.into_array(batch.num_rows())) }) .collect::>>()?; diff --git a/datafusion/physical-plan/src/values.rs b/datafusion/physical-plan/src/values.rs index 3ea27d62d80b..d0b62a89851b 100644 --- a/datafusion/physical-plan/src/values.rs +++ b/datafusion/physical-plan/src/values.rs @@ -33,6 +33,7 @@ use arrow::datatypes::{Schema, SchemaRef}; use arrow::record_batch::{RecordBatch, 
RecordBatchOptions}; use datafusion_common::{internal_err, plan_err, Result, ScalarValue}; use datafusion_execution::TaskContext; +use datafusion_expr::Scalar; use datafusion_physical_expr::EquivalenceProperties; /// Execution plan for values list based relation (produces constant rows) @@ -74,7 +75,7 @@ impl ValuesExec { match r { Ok(ColumnarValue::Scalar(scalar)) => Ok(scalar), Ok(ColumnarValue::Array(a)) if a.len() == 1 => { - ScalarValue::try_from_array(&a, 0) + Ok(Scalar::from(ScalarValue::try_from_array(&a, 0)?)) } Ok(ColumnarValue::Array(a)) => { plan_err!( @@ -85,7 +86,7 @@ impl ValuesExec { } }) .collect::>>() - .and_then(ScalarValue::iter_to_array) + .and_then(Scalar::iter_to_array) }) .collect::>>()?; let batch = RecordBatch::try_new(Arc::clone(&schema), arr)?; diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 5d0315e2d26f..21954b3453c6 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -503,7 +503,7 @@ SELECT INITCAP(column1_large_utf8_lower) as c3 FROM test_lowercase; ---- -Andrew Andrew Andrew +Andrew Andrew Andrew Xiangpeng Xiangpeng Xiangpeng Raphael Raphael Raphael NULL NULL NULL @@ -688,7 +688,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: ltrim(test.column1_utf8view, Utf8View("foo")) AS l +01)Projection: ltrim(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test LTRIM with Utf8View bytes longer than 12 @@ -698,7 +698,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: ltrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +01)Projection: ltrim(test.column1_utf8view, CAST(Utf8("this is longer than 12") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test LTRIM outputs @@ -734,7 +734,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: 
rtrim(test.column1_utf8view, Utf8View("foo")) AS l +01)Projection: rtrim(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test RTRIM with Utf8View bytes longer than 12 @@ -744,7 +744,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: rtrim(test.column1_utf8view, Utf8View("this is longer than 12")) AS l +01)Projection: rtrim(test.column1_utf8view, CAST(Utf8("this is longer than 12") AS Utf8View)) AS l 02)--TableScan: test projection=[column1_utf8view] # Test RTRIM outputs @@ -816,7 +816,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: ends_with(test.column1_utf8view, Utf8View("foo")) AS c1, ends_with(test.column2_utf8view, test.column2_utf8view) AS c2 +01)Projection: ends_with(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS c1, ends_with(test.column2_utf8view, test.column2_utf8view) AS c2 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for LEVENSHTEIN @@ -827,7 +827,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: levenshtein(test.column1_utf8view, Utf8View("foo")) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2 +01)Projection: levenshtein(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS c1, levenshtein(test.column1_utf8view, test.column2_utf8view) AS c2 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for LOWER @@ -887,7 +887,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: overlay(test.column1_utf8view, Utf8View("foo"), Int64(2)) AS c1 +01)Projection: overlay(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View), Int64(2)) AS c1 02)--TableScan: test projection=[column1_utf8view] query T @@ -1084,8 +1084,9 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: substr_index(test.column1_utf8view, Utf8View("a"), Int64(1)) AS c, substr_index(test.column1_utf8view, Utf8View("a"), Int64(2)) AS c2 -02)--TableScan: test 
projection=[column1_utf8view] +01)Projection: substr_index(test.column1_utf8view, __common_expr_1, Int64(1)) AS c, substr_index(test.column1_utf8view, __common_expr_1, Int64(2)) AS c2 +02)--Projection: CAST(Utf8("a") AS Utf8View) AS __common_expr_1, test.column1_utf8view +03)----TableScan: test projection=[column1_utf8view] query TT SELECT @@ -1106,7 +1107,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: starts_with(test.column1_utf8view, Utf8View("foo")) AS c, starts_with(test.column1_utf8view, test.column2_utf8view) AS c2 +01)Projection: starts_with(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View)) AS c, starts_with(test.column1_utf8view, test.column2_utf8view) AS c2 02)--TableScan: test projection=[column1_utf8view, column2_utf8view] ## Ensure no casts for TRANSLATE @@ -1126,7 +1127,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: find_in_set(test.column1_utf8view, Utf8View("a,b,c,d")) AS c +01)Projection: find_in_set(test.column1_utf8view, CAST(Utf8("a,b,c,d") AS Utf8View)) AS c 02)--TableScan: test projection=[column1_utf8view] query I From 2578bf9062281e0b21d1de9e1c6eefea85b564fb Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Tue, 20 Aug 2024 10:58:19 +0200 Subject: [PATCH 08/12] Make expr.slt pass --- datafusion/functions/src/encoding/inner.rs | 78 +++++++++++++------ .../sqllogictest/test_files/dictionary.slt | 12 +-- 2 files changed, 62 insertions(+), 28 deletions(-) diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 5a1e85af5ccc..e2dbf5164fde 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -28,7 +28,7 @@ use datafusion_common::{ }; use datafusion_common::{exec_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, Scalar}; use std::sync::Arc; use std::{fmt, str::FromStr}; @@ -173,13 +173,24 @@ fn 
encode_process(value: &ColumnarValue, encoding: Encoding) -> Result match scalar.value() { - ScalarValue::Utf8(a) => { - Ok(encoding.encode_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) - } - ScalarValue::Binary(a) => { - Ok(encoding.encode_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))) - } + ColumnarValue::Scalar(scalar) => match (scalar.value(), scalar.data_type()) { + (ScalarValue::Utf8(a), DataType::Utf8) => Ok(encoding.encode_scalar( + a.as_ref().map(|s: &String| s.as_bytes()), + DataType::Utf8, + )), + (ScalarValue::Utf8(a), DataType::LargeUtf8) => Ok(encoding.encode_scalar( + a.as_ref().map(|s: &String| s.as_bytes()), + DataType::LargeUtf8, + )), + (ScalarValue::Binary(a), DataType::Binary) => Ok(encoding.encode_scalar( + a.as_ref().map(|v: &Vec| v.as_slice()), + DataType::Utf8, + )), + (ScalarValue::Binary(a), DataType::LargeBinary) => Ok(encoding + .encode_scalar( + a.as_ref().map(|v: &Vec| v.as_slice()), + DataType::LargeUtf8, + )), other => exec_err!( "Unsupported data type {other:?} for function encode({encoding})" ), @@ -198,13 +209,23 @@ fn decode_process(value: &ColumnarValue, encoding: Encoding) -> Result match scalar.value() { - ScalarValue::Utf8(a) => { - encoding.decode_scalar(a.as_ref().map(|s: &String| s.as_bytes())) - } - ScalarValue::Binary(a) => { - encoding.decode_scalar(a.as_ref().map(|v: &Vec| v.as_slice())) - } + ColumnarValue::Scalar(scalar) => match (scalar.value(), scalar.data_type()) { + (ScalarValue::Utf8(a), DataType::Utf8) => encoding.decode_scalar( + a.as_ref().map(|s: &String| s.as_bytes()), + DataType::Binary, + ), + (ScalarValue::Utf8(a), DataType::LargeUtf8) => encoding.decode_scalar( + a.as_ref().map(|s: &String| s.as_bytes()), + DataType::LargeBinary, + ), + (ScalarValue::Binary(a), DataType::Binary) => encoding.decode_scalar( + a.as_ref().map(|v: &Vec| v.as_slice()), + DataType::Binary, + ), + (ScalarValue::Binary(a), DataType::LargeBinary) => encoding.decode_scalar( + a.as_ref().map(|v: &Vec| v.as_slice()), + 
DataType::LargeBinary, + ), other => exec_err!( "Unsupported data type {other:?} for function decode({encoding})" ), @@ -253,13 +274,14 @@ macro_rules! decode_to_array { } impl Encoding { - fn encode_scalar(self, value: Option<&[u8]>) -> ColumnarValue { - ColumnarValue::from(match self { + fn encode_scalar(self, value: Option<&[u8]>, data_type: DataType) -> ColumnarValue { + let value = match self { Self::Base64 => ScalarValue::Utf8( value.map(|v| general_purpose::STANDARD_NO_PAD.encode(v)), ), Self::Hex => ScalarValue::Utf8(value.map(hex::encode)), - }) + }; + ColumnarValue::Scalar(Scalar::new(value, data_type)) } fn encode_binary_array(self, value: &dyn Array) -> Result @@ -286,10 +308,19 @@ impl Encoding { Ok(ColumnarValue::Array(array)) } - fn decode_scalar(self, value: Option<&[u8]>) -> Result { + fn decode_scalar( + self, + value: Option<&[u8]>, + data_type: DataType, + ) -> Result { let value = match value { Some(value) => value, - None => return Ok(ColumnarValue::from(ScalarValue::Binary(None))), + None => { + return Ok(ColumnarValue::Scalar(Scalar::new( + ScalarValue::Binary(None), + data_type, + ))) + } }; let out = match self { @@ -311,7 +342,10 @@ impl Encoding { })?, }; - Ok(ColumnarValue::from(ScalarValue::Binary(Some(out)))) + Ok(ColumnarValue::Scalar(Scalar::new( + ScalarValue::Binary(Some(out)), + data_type, + ))) } fn decode_binary_array(self, value: &dyn Array) -> Result @@ -403,7 +437,7 @@ fn decode(args: &[ColumnarValue]) -> Result { } let encoding = match &args[1] { ColumnarValue::Scalar(scalar) => match scalar.value() { - ScalarValue::Utf8(Some(method)) => { + ScalarValue::Utf8(Some(method)) => { method.parse::() } _ => not_impl_err!( diff --git a/datafusion/sqllogictest/test_files/dictionary.slt b/datafusion/sqllogictest/test_files/dictionary.slt index ec8a51488564..48cbd36d2132 100644 --- a/datafusion/sqllogictest/test_files/dictionary.slt +++ b/datafusion/sqllogictest/test_files/dictionary.slt @@ -407,11 +407,11 @@ query TT explain SELECT 
* from test where column2 = '1'; ---- logical_plan -01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +01)Filter: test.column2 = CAST(Utf8("1") AS Dictionary(Int32, Utf8)) 02)--TableScan: test projection=[column1, column2] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column2@1 = 1 +02)--FilterExec: column2@1 = CAST(1 AS Dictionary(Int32, Utf8)) 03)----MemoryExec: partitions=1, partition_sizes=[1] # try literal = col to verify order doesn't matter @@ -420,11 +420,11 @@ query TT explain SELECT * from test where '1' = column2 ---- logical_plan -01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +01)Filter: CAST(Utf8("1") AS Dictionary(Int32, Utf8)) = test.column2 02)--TableScan: test projection=[column1, column2] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column2@1 = 1 +02)--FilterExec: CAST(1 AS Dictionary(Int32, Utf8)) = column2@1 03)----MemoryExec: partitions=1, partition_sizes=[1] @@ -438,9 +438,9 @@ query TT explain SELECT * from test where column2 = 1; ---- logical_plan -01)Filter: test.column2 = Dictionary(Int32, Utf8("1")) +01)Filter: CAST(test.column2 AS Utf8) = Utf8("1") 02)--TableScan: test projection=[column1, column2] physical_plan 01)CoalesceBatchesExec: target_batch_size=8192 -02)--FilterExec: column2@1 = 1 +02)--FilterExec: CAST(column2@1 AS Utf8) = 1 03)----MemoryExec: partitions=1, partition_sizes=[1] From 058553e8d06bb716c3f4b91c59e8a4e75eb5d685 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Tue, 20 Aug 2024 11:08:22 +0200 Subject: [PATCH 09/12] Apply suggestions and lint --- datafusion/common/src/logical/{eq.rs => equality.rs} | 4 ++++ datafusion/common/src/logical/mod.rs | 2 +- datafusion/functions/src/core/coalesce.rs | 2 +- .../optimizer/src/simplify_expressions/expr_simplifier.rs | 2 +- datafusion/physical-expr/src/functions.rs | 4 ++-- 5 files changed, 9 insertions(+), 5 deletions(-) rename datafusion/common/src/logical/{eq.rs => equality.rs} (87%) diff --git 
a/datafusion/common/src/logical/eq.rs b/datafusion/common/src/logical/equality.rs similarity index 87% rename from datafusion/common/src/logical/eq.rs rename to datafusion/common/src/logical/equality.rs index c8d60736fe95..239cebf1338f 100644 --- a/datafusion/common/src/logical/eq.rs +++ b/datafusion/common/src/logical/equality.rs @@ -31,9 +31,13 @@ impl LogicallyEq for DataType { | (Binary | LargeBinary | BinaryView, Binary | LargeBinary | BinaryView) => { true } + (Dictionary(_, left), Dictionary(_, right)) => left.logically_eq(right), (Dictionary(_, inner), other) | (other, Dictionary(_, inner)) => { other.logically_eq(inner) } + (RunEndEncoded(_, left), RunEndEncoded(_, right)) => { + left.data_type().logically_eq(right.data_type()) + } (RunEndEncoded(_, inner), other) | (other, RunEndEncoded(_, inner)) => { other.logically_eq(inner.data_type()) } diff --git a/datafusion/common/src/logical/mod.rs b/datafusion/common/src/logical/mod.rs index ff72c3dd28d2..ff4f478bc1a8 100644 --- a/datafusion/common/src/logical/mod.rs +++ b/datafusion/common/src/logical/mod.rs @@ -15,4 +15,4 @@ // specific language governing permissions and limitations // under the License. 
-pub mod eq; +pub mod equality; diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index c790aecbdb4f..ad193c932b8f 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -81,7 +81,7 @@ impl ScalarUDFImpl for CoalesceFunc { if let Some(size) = return_array.next() { // start with nulls as default output - let mut current_value = new_null_array(&return_type, size); + let mut current_value = new_null_array(return_type, size); let mut remainder = BooleanArray::from(vec![true; size]); for arg in args { diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index a707ae80695b..5b592042a59a 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -28,13 +28,13 @@ use arrow::{ }; // use datafusion_common::logical::eq::LogicallyEq; +use datafusion_common::logical::equality::LogicallyEq; use datafusion_common::{cast::as_large_list_array, exec_datafusion_err}; use datafusion_common::{ cast::as_list_array, tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_common::logical::eq::LogicallyEq; use datafusion_expr::expr::{InList, InSubquery, WindowFunction}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 6b7c64003408..92e64ddaca7a 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -69,9 +69,9 @@ where make_scalar_function_with_hints(inner, vec![]) } -/// Just like [`make_scalar_function`], decorates the given function to handle both [`ScalarValue`]s and arrays. 
+/// Just like [`make_scalar_function`], decorates the given function to handle both [`Scalar`]s and arrays. /// Additionally can receive a `hints` vector which can be used to control the output arrays when generating them -/// from [`ScalarValue`]s. +/// from [`Scalar`]s. /// /// Each element of the `hints` vector gets mapped to the corresponding argument of the function. The number of hints /// can be less or greater than the number of arguments (for functions with variable number of arguments). Each unmapped From 4ad496553df74ce5d11c176451a19612a1b2a23c Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Tue, 20 Aug 2024 11:45:13 +0200 Subject: [PATCH 10/12] Updates due to merge --- datafusion/functions/src/string/replace.rs | 39 ++----------------- .../physical-expr/src/expressions/binary.rs | 8 ++-- datafusion/physical-expr/src/functions.rs | 2 +- .../sqllogictest/test_files/string_view.slt | 6 ++- 4 files changed, 14 insertions(+), 41 deletions(-) diff --git a/datafusion/functions/src/string/replace.rs b/datafusion/functions/src/string/replace.rs index 13fa3d55672d..0a58b9e55727 100644 --- a/datafusion/functions/src/string/replace.rs +++ b/datafusion/functions/src/string/replace.rs @@ -127,18 +127,17 @@ mod tests { use super::*; use crate::utils::test::test_function; use arrow::array::Array; - use arrow::array::LargeStringArray; use arrow::array::StringArray; - use arrow::datatypes::DataType::{LargeUtf8, Utf8}; + use arrow::datatypes::DataType::Utf8; use datafusion_common::ScalarValue; #[test] fn test_functions() -> Result<()> { test_function!( ReplaceFunc::new(), &[ - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))), - ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))), + ColumnarValue::from(ScalarValue::Utf8(Some(String::from("bb")))), + 
ColumnarValue::from(ScalarValue::Utf8(Some(String::from("ccc")))), ], Ok(Some("aacccdqcccc")), &str, @@ -146,36 +145,6 @@ mod tests { StringArray ); - test_function!( - ReplaceFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from( - "aabbb" - )))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))), - ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))), - ], - Ok(Some("aacc")), - &str, - LargeUtf8, - LargeStringArray - ); - - test_function!( - ReplaceFunc::new(), - &[ - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( - "aabbbcw" - )))), - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))), - ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))), - ], - Ok(Some("aaccbcw")), - &str, - Utf8, - StringArray - ); - Ok(()) } } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 074a5ce43858..e9b66ca53972 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -323,9 +323,11 @@ impl PhysicalExpr for BinaryExpr { None } else { self.evaluate_array_scalar(array, scalar.clone().into_value())? 
- .map(|r| { - r.and_then(|a| to_result_type_array(&self.op, a, &result_type)) - }) + .map(|r| { + r.and_then(|a| { + to_result_type_array(&self.op, a, &result_type) + }) + }) } } (_, _) => None, // default to array implementation diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 92e64ddaca7a..7935839d1c98 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -46,7 +46,7 @@ pub fn columnar_values_to_array(args: &[ColumnarValue]) -> Result> ColumnarValue::values_to_arrays(args) } -/// Decorates a function to handle [`ScalarValue`]s by converting them to arrays before calling the function +/// Decorates a function to handle [`Scalar`]s by converting them to arrays before calling the function /// and vice-versa after evaluation. /// Note that this function makes a scalar function with no arguments or all scalar inputs return a scalar. /// That's said its output will be same for all input rows in a batch. 
diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index f5eafd73016d..8b62b68bebcc 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -946,8 +946,10 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: replace(test.column1_utf8view, Utf8View("foo"), Utf8View("bar")) AS c1, replace(test.column1_utf8view, test.column2_utf8view, Utf8View("bar")) AS c2 -02)--TableScan: test projection=[column1_utf8view, column2_utf8view] +01)Projection: replace(test.column1_utf8view, CAST(Utf8("foo") AS Utf8View), __common_expr_1) AS c1, replace(test.column1_utf8view, test.column2_utf8view, __common_expr_1) AS c2 +02)--Projection: CAST(Utf8("bar") AS Utf8View) AS __common_expr_1, test.column1_utf8view, test.column2_utf8view +03)----TableScan: test projection=[column1_utf8view, column2_utf8view] + query TT SELECT From 9308d0b5b17b0fb06a75cf997ff99788f69e8ab6 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Fri, 23 Aug 2024 11:36:15 +0200 Subject: [PATCH 11/12] Fix aggregates --- datafusion/expr-common/src/accumulator.rs | 10 +++++ .../src/aggregate/groups_accumulator.rs | 14 +++---- datafusion/functions-aggregate/src/min_max.rs | 39 ++++++++++++++++++- .../functions-aggregate/src/string_agg.rs | 13 ++++++- .../physical-plan/src/aggregates/mod.rs | 6 ++- 5 files changed, 71 insertions(+), 11 deletions(-) diff --git a/datafusion/expr-common/src/accumulator.rs b/datafusion/expr-common/src/accumulator.rs index 75335209451e..43f1e86a1ff7 100644 --- a/datafusion/expr-common/src/accumulator.rs +++ b/datafusion/expr-common/src/accumulator.rs @@ -17,6 +17,7 @@ //! Accumulator module contains the trait definition for aggregation function's accumulators. 
+use crate::columnar_value::Scalar; use arrow::array::ArrayRef; use datafusion_common::{internal_err, Result, ScalarValue}; use std::fmt::Debug; @@ -72,6 +73,10 @@ pub trait Accumulator: Send + Sync + Debug { /// when possible (for example distinct strings) fn evaluate(&mut self) -> Result; + fn evaluate_as_scalar(&mut self) -> Result { + self.evaluate().map(Scalar::from) + } + /// Returns the allocated size required for this accumulator, in /// bytes, including `Self`. /// @@ -250,6 +255,11 @@ pub trait Accumulator: Send + Sync + Debug { /// ``` fn state(&mut self) -> Result>; + fn state_as_scalars(&mut self) -> Result> { + self.state() + .map(|scalars| scalars.into_iter().map(Scalar::from).collect()) + } + /// Updates the accumulator's state from an `Array` containing one /// or more intermediate values. /// diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index 3984b02c5fbb..8166fd20a571 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -30,9 +30,9 @@ use arrow::{ }; use datafusion_common::{ arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result, - ScalarValue, }; use datafusion_expr_common::accumulator::Accumulator; +use datafusion_expr_common::columnar_value::Scalar; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] @@ -275,15 +275,15 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { let states = emit_to.take_needed(&mut self.states); - let results: Vec = states + let results: Vec = states .into_iter() .map(|mut state| { self.free_allocation(state.size()); - state.accumulator.evaluate() + state.accumulator.evaluate_as_scalar() }) .collect::>()?; - let result = ScalarValue::iter_to_array(results); + 
let result = Scalar::iter_to_array(results); self.adjust_allocation(vec_size_pre, self.states.allocated_size()); @@ -296,11 +296,11 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // each accumulator produces a potential vector of values // which we need to form into columns - let mut results: Vec> = vec![]; + let mut results: Vec> = vec![]; for mut state in states { self.free_allocation(state.size()); - let accumulator_state = state.accumulator.state()?; + let accumulator_state = state.accumulator.state_as_scalars()?; results.resize_with(accumulator_state.len(), Vec::new); for (idx, state_val) in accumulator_state.into_iter().enumerate() { results[idx].push(state_val); @@ -310,7 +310,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter { // create an array for each intermediate column let arrays = results .into_iter() - .map(ScalarValue::iter_to_array) + .map(Scalar::iter_to_array) .collect::>>()?; // double check each array has the same length (aka the diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 2db5861ae141..fe9d188c563c 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -63,10 +63,10 @@ use arrow::datatypes::{ }; use datafusion_common::ScalarValue; -use datafusion_expr::GroupsAccumulator; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, }; +use datafusion_expr::{GroupsAccumulator, Scalar}; use half::f16; use std::ops::Deref; @@ -769,6 +769,7 @@ macro_rules! 
min_max { #[derive(Debug)] pub struct MaxAccumulator { max: ScalarValue, + return_type: DataType, } impl MaxAccumulator { @@ -776,6 +777,7 @@ impl MaxAccumulator { pub fn try_new(datatype: &DataType) -> Result { Ok(Self { max: ScalarValue::try_from(datatype)?, + return_type: datatype.clone(), }) } } @@ -797,9 +799,15 @@ impl Accumulator for MaxAccumulator { fn state(&mut self) -> Result> { Ok(vec![self.evaluate()?]) } + fn state_as_scalars(&mut self) -> Result> { + Ok(vec![self.evaluate_as_scalar()?]) + } fn evaluate(&mut self) -> Result { Ok(self.max.clone()) } + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new(self.evaluate()?, self.return_type.clone())) + } fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.max) + self.max.size() @@ -810,6 +818,7 @@ impl Accumulator for MaxAccumulator { pub struct SlidingMaxAccumulator { max: ScalarValue, moving_max: MovingMax, + return_type: DataType, } impl SlidingMaxAccumulator { @@ -818,6 +827,7 @@ impl SlidingMaxAccumulator { Ok(Self { max: ScalarValue::try_from(datatype)?, moving_max: MovingMax::::new(), + return_type: datatype.clone(), }) } } @@ -856,6 +866,13 @@ impl Accumulator for SlidingMaxAccumulator { Ok(self.max.clone()) } + fn state_as_scalars(&mut self) -> Result> { + Ok(vec![self.evaluate_as_scalar()?]) + } + + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new(self.evaluate()?, self.return_type.clone())) + } fn supports_retract_batch(&self) -> bool { true } @@ -1026,6 +1043,7 @@ impl AggregateUDFImpl for Min { #[derive(Debug)] pub struct MinAccumulator { min: ScalarValue, + return_type: DataType, } impl MinAccumulator { @@ -1033,6 +1051,7 @@ impl MinAccumulator { pub fn try_new(datatype: &DataType) -> Result { Ok(Self { min: ScalarValue::try_from(datatype)?, + return_type: datatype.clone(), }) } } @@ -1042,6 +1061,10 @@ impl Accumulator for MinAccumulator { Ok(vec![self.evaluate()?]) } + fn state_as_scalars(&mut self) -> Result> { + 
Ok(vec![self.evaluate_as_scalar()?]) + } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &min_batch(values)?; @@ -1059,6 +1082,10 @@ impl Accumulator for MinAccumulator { Ok(self.min.clone()) } + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new(self.evaluate()?, self.return_type.clone())) + } + fn size(&self) -> usize { std::mem::size_of_val(self) - std::mem::size_of_val(&self.min) + self.min.size() } @@ -1068,6 +1095,7 @@ impl Accumulator for MinAccumulator { pub struct SlidingMinAccumulator { min: ScalarValue, moving_min: MovingMin, + return_type: DataType, } impl SlidingMinAccumulator { @@ -1075,6 +1103,7 @@ impl SlidingMinAccumulator { Ok(Self { min: ScalarValue::try_from(datatype)?, moving_min: MovingMin::::new(), + return_type: datatype.clone(), }) } } @@ -1084,6 +1113,14 @@ impl Accumulator for SlidingMinAccumulator { Ok(vec![self.min.clone()]) } + fn state_as_scalars(&mut self) -> Result> { + Ok(vec![self.evaluate_as_scalar()?]) + } + + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new(self.evaluate()?, self.return_type.clone())) + } + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { for idx in 0..values[0].len() { let val = ScalarValue::try_from_array(&values[0], idx)?; diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 510b3b535dc9..5269cff8b9e2 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -24,7 +24,7 @@ use datafusion_common::Result; use datafusion_common::{not_impl_err, ScalarValue}; use datafusion_expr::function::AccumulatorArgs; use datafusion_expr::{ - Accumulator, AggregateUDFImpl, Signature, TypeSignature, Volatility, + Accumulator, AggregateUDFImpl, Scalar, Signature, TypeSignature, Volatility, }; use datafusion_physical_expr::expressions::Literal; use std::any::Any; @@ -140,10 +140,21 @@ impl Accumulator for 
StringAggAccumulator { Ok(vec![self.evaluate()?]) } + fn state_as_scalars(&mut self) -> Result> { + Ok(vec![self.evaluate_as_scalar()?]) + } + fn evaluate(&mut self) -> Result { Ok(ScalarValue::Utf8(self.values.clone())) } + fn evaluate_as_scalar(&mut self) -> Result { + Ok(Scalar::new( + ScalarValue::Utf8(self.values.clone()), + DataType::LargeUtf8, + )) + } + fn size(&self) -> usize { std::mem::size_of_val(self) + self.values.as_ref().map(|v| v.capacity()).unwrap_or(0) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 89d4c452cca6..bebfd64097a1 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1074,7 +1074,7 @@ pub fn finalize_aggregation( accumulators .iter_mut() .map(|accumulator| { - accumulator.state().and_then(|e| { + accumulator.state_as_scalars().and_then(|e| { e.iter() .map(|v| v.to_array()) .collect::>>() @@ -1090,7 +1090,9 @@ pub fn finalize_aggregation( // Merge the state to the final value accumulators .iter_mut() - .map(|accumulator| accumulator.evaluate().and_then(|v| v.to_array())) + .map(|accumulator| { + accumulator.evaluate_as_scalar().and_then(|v| v.to_array()) + }) .collect() } } From f158e17b55fd88db3abce5dcd2b45922b9506b84 Mon Sep 17 00:00:00 2001 From: Filippo Rossi Date: Mon, 16 Sep 2024 18:11:40 +0200 Subject: [PATCH 12/12] Resolve conflicts --- datafusion/functions-nested/src/array_has.rs | 2 +- datafusion/functions/src/string/split_part.rs | 2 +- datafusion/physical-expr-common/src/datum.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 595b4830ec01..d9a780cd2b6f 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -131,7 +131,7 @@ impl ScalarUDFImpl for ArrayHas { if is_scalar { // If all inputs are scalar, keeps output as scalar 
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(|x| ColumnarValue::from(x)) + result.map(ColumnarValue::from) } else { result.map(ColumnarValue::Array) } diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 7b0506a7495c..438e2e611359 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -173,7 +173,7 @@ impl ScalarUDFImpl for SplitPartFunc { if is_scalar { // If all inputs are scalar, keep the output as scalar let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0)); - result.map(|x| ColumnarValue::from(x)) + result.map(ColumnarValue::from) } else { result.map(ColumnarValue::Array) } diff --git a/datafusion/physical-expr-common/src/datum.rs b/datafusion/physical-expr-common/src/datum.rs index 57f710ecbad1..b6f64245ae64 100644 --- a/datafusion/physical-expr-common/src/datum.rs +++ b/datafusion/physical-expr-common/src/datum.rs @@ -21,8 +21,8 @@ use arrow::buffer::NullBuffer; use arrow::compute::SortOptions; use arrow::error::ArrowError; use datafusion_common::DataFusionError; -use datafusion_common::{arrow_datafusion_err, internal_err}; use datafusion_common::Result; +use datafusion_common::{arrow_datafusion_err, internal_err}; use datafusion_expr_common::columnar_value::{ColumnarValue, Scalar}; use datafusion_expr_common::operator::Operator; use std::sync::Arc;