Skip to content

Commit

Permalink
Add StringView support for date_part and make_date funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
a10y committed Jul 18, 2024
1 parent 2c808fb commit 7ddd7e6
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 17 deletions.
6 changes: 6 additions & 0 deletions datafusion/common/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use arrow::{
},
datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType},
};
use arrow_array::StringViewArray;

// Downcast ArrayRef to Date32Array
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array> {
Expand Down Expand Up @@ -87,6 +88,11 @@ pub fn as_string_array(array: &dyn Array) -> Result<&StringArray> {
Ok(downcast_value!(array, StringArray))
}

// Downcast a `&dyn Array` to a `StringViewArray` reference (the `Utf8View`
// counterpart of `as_string_array`).
// NOTE(review): what happens on a type mismatch is decided by the
// `downcast_value!` macro, which is defined elsewhere — confirm before relying on it.
pub fn as_string_view_array(array: &dyn Array) -> Result<&StringViewArray> {
Ok(downcast_value!(array, StringViewArray))
}

// Downcast a `&dyn Array` to a `UInt32Array` reference.
// NOTE(review): what happens on a type mismatch is decided by the
// `downcast_value!` macro, which is defined elsewhere — confirm before relying on it.
pub fn as_uint32_array(array: &dyn Array) -> Result<&UInt32Array> {
Ok(downcast_value!(array, UInt32Array))
Expand Down
3 changes: 2 additions & 1 deletion datafusion/common/src/hash_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use arrow_buffer::IntervalMonthDayNano;
use crate::cast::{
as_boolean_array, as_fixed_size_list_array, as_generic_binary_array,
as_large_list_array, as_list_array, as_primitive_array, as_string_array,
as_struct_array,
as_string_view_array, as_struct_array,
};
use crate::error::{Result, _internal_err};

Expand Down Expand Up @@ -369,6 +369,7 @@ pub fn create_hashes<'a>(
DataType::Null => hash_null(random_state, hashes_buffer, rehash),
DataType::Boolean => hash_array(as_boolean_array(array)?, random_state, hashes_buffer, rehash),
DataType::Utf8 => hash_array(as_string_array(array)?, random_state, hashes_buffer, rehash),
DataType::Utf8View => hash_array(as_string_view_array(array)?, random_state, hashes_buffer, rehash),
DataType::LargeUtf8 => hash_array(as_largestring_array(array), random_state, hashes_buffer, rehash),
DataType::Binary => hash_array(as_generic_binary_array::<i32>(array)?, random_state, hashes_buffer, rehash),
DataType::LargeBinary => hash_array(as_generic_binary_array::<i64>(array)?, random_state, hashes_buffer, rehash),
Expand Down
26 changes: 14 additions & 12 deletions datafusion/expr/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ fn string_numeric_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<D
}

/// Coerce `lhs_type` and `rhs_type` to a common type for the purposes of a comparison operation
/// where one is temporal and one is `Utf8`/`LargeUtf8`.
/// where one is temporal and one is `Utf8View`/`Utf8`/`LargeUtf8`.
///
/// Note this cannot be performed in case of arithmetic as there is insufficient information
/// to correctly determine the type of argument. Consider
Expand All @@ -547,19 +547,21 @@ fn string_temporal_coercion(

fn match_rule(l: &DataType, r: &DataType) -> Option<DataType> {
match (l, r) {
// Coerce Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp
(Utf8, temporal) | (LargeUtf8, temporal) => match temporal {
Date32 | Date64 => Some(temporal.clone()),
Time32(_) | Time64(_) => {
if is_time_with_valid_unit(temporal.to_owned()) {
Some(temporal.to_owned())
} else {
None
// Coerce Utf8View/Utf8/LargeUtf8 to Date32/Date64/Time32/Time64/Timestamp
(Utf8, temporal) | (LargeUtf8, temporal) | (Utf8View, temporal) => {
match temporal {
Date32 | Date64 => Some(temporal.clone()),
Time32(_) | Time64(_) => {
if is_time_with_valid_unit(temporal.to_owned()) {
Some(temporal.to_owned())
} else {
None
}
}
Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())),
_ => None,
}
Timestamp(_, tz) => Some(Timestamp(TimeUnit::Nanosecond, tz.clone())),
_ => None,
},
}
_ => None,
}
}
Expand Down
30 changes: 29 additions & 1 deletion datafusion/functions/src/datetime/date_part.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Float64Array};
use arrow::compute::{binary, cast, date_part, DatePart};
use arrow::datatypes::DataType::{
Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8,
Date32, Date64, Float64, Time32, Time64, Timestamp, Utf8, Utf8View,
};
use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second};
use arrow::datatypes::{DataType, TimeUnit};
Expand Down Expand Up @@ -56,31 +56,57 @@ impl DatePartFunc {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Timestamp(Nanosecond, None)]),
Exact(vec![Utf8View, Timestamp(Nanosecond, None)]),
Exact(vec![
Utf8,
Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Nanosecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Millisecond, None)]),
Exact(vec![Utf8View, Timestamp(Millisecond, None)]),
Exact(vec![
Utf8,
Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Millisecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Microsecond, None)]),
Exact(vec![Utf8View, Timestamp(Microsecond, None)]),
Exact(vec![
Utf8,
Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Microsecond, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Timestamp(Second, None)]),
Exact(vec![Utf8View, Timestamp(Second, None)]),
Exact(vec![
Utf8,
Timestamp(Second, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![
Utf8View,
Timestamp(Second, Some(TIMEZONE_WILDCARD.into())),
]),
Exact(vec![Utf8, Date64]),
Exact(vec![Utf8View, Date64]),
Exact(vec![Utf8, Date32]),
Exact(vec![Utf8View, Date32]),
Exact(vec![Utf8, Time32(Second)]),
Exact(vec![Utf8View, Time32(Second)]),
Exact(vec![Utf8, Time32(Millisecond)]),
Exact(vec![Utf8View, Time32(Millisecond)]),
Exact(vec![Utf8, Time64(Microsecond)]),
Exact(vec![Utf8View, Time64(Microsecond)]),
Exact(vec![Utf8, Time64(Nanosecond)]),
Exact(vec![Utf8View, Time64(Nanosecond)]),
],
Volatility::Immutable,
),
Expand Down Expand Up @@ -114,6 +140,8 @@ impl ScalarUDFImpl for DatePartFunc {

let part = if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = part {
v
} else if let ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) = part {
v
} else {
return exec_err!(
"First argument of `DATE_PART` must be non-null scalar Utf8"
Expand Down
4 changes: 2 additions & 2 deletions datafusion/functions/src/datetime/make_date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arrow::array::cast::AsArray;
use arrow::array::types::{Date32Type, Int32Type};
use arrow::array::PrimitiveArray;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8};
use arrow::datatypes::DataType::{Date32, Int32, Int64, UInt32, UInt64, Utf8, Utf8View};
use chrono::prelude::*;

use datafusion_common::{exec_err, Result, ScalarValue};
Expand All @@ -45,7 +45,7 @@ impl MakeDateFunc {
Self {
signature: Signature::uniform(
3,
vec![Int32, Int64, UInt32, UInt64, Utf8],
vec![Int32, Int64, UInt32, UInt64, Utf8, Utf8View],
Volatility::Immutable,
),
}
Expand Down
8 changes: 7 additions & 1 deletion datafusion/functions/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::datatypes::DataType;

use datafusion_common::{Result, ScalarValue};
use datafusion_expr::function::Hint;
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
use std::sync::Arc;

/// Creates a function to identify the optimal return type of a string function given
/// the type of its first argument.
Expand All @@ -29,6 +31,8 @@ use std::sync::Arc;
/// `$largeUtf8Type`,
///
/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`,
///
/// If the input type is `Utf8View` the return type is `Utf8View`.
macro_rules! get_optimal_return_type {
($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
Expand All @@ -37,6 +41,8 @@ macro_rules! get_optimal_return_type {
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
// Binary inputs are automatically coerced to Utf8
DataType::Utf8 | DataType::Binary => $utf8Type,
// Utf8View inputs will yield Utf8View outputs
DataType::Utf8View => DataType::Utf8View,
DataType::Null => DataType::Null,
DataType::Dictionary(_, value_type) => match **value_type {
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
Expand Down

0 comments on commit 7ddd7e6

Please sign in to comment.