Skip to content

Commit

Permalink
ScalarUDF with zero arguments should be provided with one null array …
Browse files Browse the repository at this point in the history
…as parameter (#9031)

* Fix ScalaUDF with zero arguments

* Fix test

* Fix clippy

* Fix

* Exclude built-in scalar functions

* For review
  • Loading branch information
viirya authored Jan 30, 2024
1 parent efd2fd2 commit 85ceb9d
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 32 deletions.
4 changes: 4 additions & 0 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,7 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1336,6 +1337,7 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 3))),
Expand Down Expand Up @@ -1405,6 +1407,7 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1471,6 +1474,7 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d_new", 3))),
Expand Down
114 changes: 108 additions & 6 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{ArrayRef, Float64Array, Int32Array, RecordBatch};
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, Volatility,
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
};
use rand::{thread_rng, Rng};
use std::iter;
use std::sync::Arc;

/// test that casting happens on udfs.
Expand Down Expand Up @@ -166,10 +170,7 @@ async fn scalar_udf_zero_params() -> Result<()> {

ctx.register_batch("t", batch)?;
// create function just returns 100 regardless of inp
let myfunc = Arc::new(|args: &[ColumnarValue]| {
let ColumnarValue::Scalar(_) = &args[0] else {
panic!("expect scalar")
};
let myfunc = Arc::new(|_args: &[ColumnarValue]| {
Ok(ColumnarValue::Array(
Arc::new((0..1).map(|_| 100).collect::<Int32Array>()) as ArrayRef,
))
Expand Down Expand Up @@ -392,6 +393,107 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
Ok(())
}

#[derive(Debug)]
pub struct RandomUDF {
signature: Signature,
}

impl RandomUDF {
pub fn new() -> Self {
Self {
signature: Signature::any(0, Volatility::Volatile),
}
}
}

impl ScalarUDFImpl for RandomUDF {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"random_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let len: usize = match &args[0] {
// This udf is always invoked with zero argument so its argument
// is a null array indicating the batch size.
ColumnarValue::Array(array) if array.data_type().is_null() => array.len(),
_ => {
return Err(datafusion::error::DataFusionError::Internal(
"Invalid argument type".to_string(),
))
}
};
let mut rng = thread_rng();
let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len);
let array = Float64Array::from_iter_values(values);
Ok(ColumnarValue::Array(Arc::new(array)))
}
}

/// Ensure that a user defined function with zero argument will be invoked
/// with a null array indicating the batch size.
#[tokio::test]
async fn test_user_defined_functions_zero_argument() -> Result<()> {
let ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![Field::new(
"index",
DataType::UInt8,
false,
)]));

let batch = RecordBatch::try_new(
schema,
vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))],
)?;

ctx.register_batch("data_table", batch)?;

let random_normal_udf = ScalarUDF::from(RandomUDF::new());
ctx.register_udf(random_normal_udf);

let result = plan_and_collect(
&ctx,
"SELECT random_udf() AS random_udf, random() AS native_random FROM data_table",
)
.await?;

assert_eq!(result.len(), 1);
let batch = &result[0];
let random_udf = batch
.column(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let native_random = batch
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();

assert_eq!(random_udf.len(), native_random.len());

let mut previous = -1.0;
for i in 0..random_udf.len() {
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0);
assert!(random_udf.value(i) != previous);
previous = random_udf.value(i);
}

Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ pub fn create_physical_expr(
input_phy_exprs.to_vec(),
data_type,
monotonicity,
fun.signature().type_signature.supports_zero_argument(),
)))
}

Expand Down
18 changes: 6 additions & 12 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ pub fn create_physical_expr(
}

Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
let mut physical_args = args
let physical_args = args
.iter()
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
.collect::<Result<Vec<_>>>()?;
Expand All @@ -272,17 +272,11 @@ pub fn create_physical_expr(
execution_props,
)
}
ScalarFunctionDefinition::UDF(fun) => {
// udfs with zero params expect null array as input
if args.is_empty() {
physical_args.push(Arc::new(Literal::new(ScalarValue::Null)));
}
udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
input_schema,
)
}
ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
input_schema,
),
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
Expand Down
17 changes: 15 additions & 2 deletions datafusion/physical-expr/src/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub struct ScalarFunctionExpr {
// and it specifies the effect of an increase or decrease in
// the corresponding `arg` to the function value.
monotonicity: Option<FuncMonotonicity>,
// Whether this function can be invoked with zero arguments
supports_zero_argument: bool,
}

impl Debug for ScalarFunctionExpr {
Expand All @@ -79,13 +81,15 @@ impl ScalarFunctionExpr {
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
monotonicity: Option<FuncMonotonicity>,
supports_zero_argument: bool,
) -> Self {
Self {
fun,
name: name.to_owned(),
args,
return_type,
monotonicity,
supports_zero_argument,
}
}

Expand Down Expand Up @@ -138,9 +142,12 @@ impl PhysicalExpr for ScalarFunctionExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
// evaluate the arguments, if there are no arguments we'll instead pass in a null array
// indicating the batch size (as a convention)
let inputs = match (self.args.len(), self.name.parse::<BuiltinScalarFunction>()) {
let inputs = match (
self.args.is_empty(),
self.name.parse::<BuiltinScalarFunction>(),
) {
// MakeArray support zero argument but has the different behavior from the array with one null.
(0, Ok(scalar_fun))
(true, Ok(scalar_fun))
if scalar_fun
.signature()
.type_signature
Expand All @@ -149,6 +156,11 @@ impl PhysicalExpr for ScalarFunctionExpr {
{
vec![ColumnarValue::create_null_array(batch.num_rows())]
}
// If the function supports zero argument, we pass in a null array indicating the batch size.
// This is for user-defined functions.
(true, Err(_)) if self.supports_zero_argument => {
vec![ColumnarValue::create_null_array(batch.num_rows())]
}
_ => self
.args
.iter()
Expand All @@ -175,6 +187,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
children,
self.return_type().clone(),
self.monotonicity.clone(),
self.supports_zero_argument,
)))
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub fn create_physical_expr(
input_phy_exprs.to_vec(),
fun.return_type(&input_exprs_types)?,
fun.monotonicity()?,
fun.signature().type_signature.supports_zero_argument(),
)))
}

Expand Down
19 changes: 8 additions & 11 deletions datafusion/proto/src/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,21 +340,17 @@ pub fn parse_physical_expr(
// TODO Do not create new the ExecutionProps
let execution_props = ExecutionProps::new();

let fun_expr = functions::create_physical_fun(
functions::create_physical_expr(
&(&scalar_function).into(),
&args,
input_schema,
&execution_props,
)?;

Arc::new(ScalarFunctionExpr::new(
&e.name,
fun_expr,
args,
convert_required!(e.return_type)?,
None,
))
)?
}
ExprType::ScalarUdf(e) => {
let scalar_fun = registry.udf(e.name.as_str())?.fun().clone();
let udf = registry.udf(e.name.as_str())?;
let signature = udf.signature();
let scalar_fun = udf.fun().clone();

let args = e
.args
Expand All @@ -368,6 +364,7 @@ pub fn parse_physical_expr(
args,
convert_required!(e.return_type)?,
None,
signature.type_signature.supports_zero_argument(),
))
}
ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new(
Expand Down
4 changes: 3 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,9 @@ fn roundtrip_builtin_scalar_function() -> Result<()> {
"acos",
fun_expr,
vec![col("a", &schema)?],
DataType::Int64,
DataType::Float64,
None,
false,
);

let project =
Expand Down Expand Up @@ -617,6 +618,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
vec![col("a", &schema)?],
DataType::Int64,
None,
false,
);

let project =
Expand Down

0 comments on commit 85ceb9d

Please sign in to comment.