Skip to content

Commit

Permalink
Adding safe support to to_date and to_timestamp functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 committed Jul 9, 2024
1 parent e65c3e9 commit 0a5570c
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 77 deletions.
46 changes: 38 additions & 8 deletions datafusion/functions/src/datetime/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ pub(crate) fn handle<'a, O, F, S>(
args: &'a [ColumnarValue],
op: F,
name: &str,
safe: bool,
) -> Result<ColumnarValue>
where
O: ArrowPrimitiveType,
Expand All @@ -164,14 +165,25 @@ where
match &args[0] {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8 | DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new(
unary_string_to_primitive_function::<i32, O, _>(&[a.as_ref()], op, name)?,
unary_string_to_primitive_function::<i32, O, _>(
&[a.as_ref()],
op,
name,
safe,
)?,
))),
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| (op)(x)).transpose()?;
Ok(ColumnarValue::Scalar(S::scalar(result)))
let result = a.as_ref().map(|x| op(x)).transpose();
if let Ok(v) = result {
Ok(ColumnarValue::Scalar(S::scalar(v)))
} else if safe {
Ok(ColumnarValue::Scalar(S::scalar(None)))
} else {
Err(result.err().unwrap())
}
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
Expand All @@ -186,6 +198,7 @@ pub(crate) fn handle_multiple<'a, O, F, S, M>(
op: F,
op2: M,
name: &str,
safe: bool,
) -> Result<ColumnarValue>
where
O: ArrowPrimitiveType,
Expand Down Expand Up @@ -217,7 +230,9 @@ where
}

Ok(ColumnarValue::Array(Arc::new(
strings_to_primitive_function::<i32, O, _, _>(args, op, op2, name)?,
strings_to_primitive_function::<i32, O, _, _>(
args, op, op2, name, safe,
)?,
)))
}
other => {
Expand Down Expand Up @@ -264,8 +279,13 @@ where

if let Some(v) = val {
v
} else if safe {
Ok(ColumnarValue::Scalar(S::scalar(None)))
} else {
Err(err.unwrap())
match err {
Some(e) => Err(e),
None => Ok(ColumnarValue::Scalar(S::scalar(None))),
}
}
}
other => {
Expand All @@ -285,12 +305,13 @@ where
/// This function errors iff:
/// * the number of arguments is not > 1 or
/// * the array arguments are not castable to a `GenericStringArray` or
/// * the function `op` errors for all input
/// * the function `op` errors for all input and `safe` is false
pub(crate) fn strings_to_primitive_function<'a, T, O, F, F2>(
args: &'a [ColumnarValue],
op: F,
op2: F2,
name: &str,
safe: bool,
) -> Result<PrimitiveArray<O>>
where
O: ArrowPrimitiveType,
Expand Down Expand Up @@ -360,6 +381,7 @@ where
};

val.transpose()
.or_else(|e| if safe { Ok(None) } else { Err(e) })
})
.collect()
}
Expand All @@ -371,11 +393,12 @@ where
/// This function errors iff:
/// * the number of arguments is not 1 or
/// * the first argument is not castable to a `GenericStringArray` or
/// * the function `op` errors
/// * the function `op` errors and `safe` is false
fn unary_string_to_primitive_function<'a, T, O, F>(
args: &[&'a dyn Array],
op: F,
name: &str,
safe: bool,
) -> Result<PrimitiveArray<O>>
where
O: ArrowPrimitiveType,
Expand All @@ -393,5 +416,12 @@ where
let array = as_generic_string_array::<T>(args[0])?;

// first map is the iterator, second is for the `Option<_>`
array.iter().map(|x| x.map(&op).transpose()).collect()
array
.iter()
.map(|x| {
x.map(&op)
.transpose()
.or_else(|e| if safe { Ok(None) } else { Err(e) })
})
.collect()
}
18 changes: 17 additions & 1 deletion datafusion/functions/src/datetime/to_date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use std::any::Any;

use arrow::array::types::Date32Type;
use arrow::compute::CastOptions;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Date32;

Expand All @@ -28,6 +29,8 @@ use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
/// Scalar function implementation behind `to_date`.
///
/// Besides its signature, the function carries a `safe` flag that decides
/// what happens when an input value cannot be cast or parsed.
#[derive(Debug)]
pub struct ToDateFunc {
/// Signature of the function; the constructors in this file build it with
/// `Signature::variadic_any(Volatility::Immutable)`.
signature: Signature,
/// How cast or parsing failures are handled: return NULL when `true`
/// (safe), return an error when `false` (not safe).
safe: bool,
}

impl Default for ToDateFunc {
Expand All @@ -40,6 +43,14 @@ impl ToDateFunc {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
safe: false,
}
}

/// Creates a `ToDateFunc` with an explicit failure-handling mode.
///
/// * `safe = true`: cast or parsing failures yield NULL.
/// * `safe = false`: cast or parsing failures yield an error.
pub fn new_with_safe(safe: bool) -> Self {
    let signature = Signature::variadic_any(Volatility::Immutable);
    Self { signature, safe }
}

Expand All @@ -57,6 +68,7 @@ impl ToDateFunc {
})
},
"to_date",
self.safe,
),
n if n >= 2 => handle_multiple::<Date32Type, _, Date32Type, _>(
args,
Expand All @@ -71,6 +83,7 @@ impl ToDateFunc {
},
|n| n,
"to_date",
self.safe,
),
_ => exec_err!("Unsupported 0 argument count for function to_date"),
}
Expand Down Expand Up @@ -110,7 +123,10 @@ impl ScalarUDFImpl for ToDateFunc {
| DataType::Null
| DataType::Float64
| DataType::Date32
| DataType::Date64 => args[0].cast_to(&DataType::Date32, None),
| DataType::Date64 => match self.safe {
true => args[0].cast_to(&DataType::Date32, Some(&CastOptions::default())),
false => args[0].cast_to(&DataType::Date32, None),
},
DataType::Utf8 => self.to_date(args),
other => {
exec_err!("Unsupported data type {:?} for function to_date", other)
Expand Down
Loading

0 comments on commit 0a5570c

Please sign in to comment.