diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index ad32151ab141..a29ae9d13935 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -545,34 +545,39 @@ impl ScalarUDFImpl for CastToI64UDF { fn return_type(&self, _args: &[DataType]) -> Result { Ok(DataType::Int64) } - // Wrap with Expr::Cast() to Int64 + + // Demonstrate simplifying a UDF fn simplify( &self, mut args: Vec, info: &dyn SimplifyInfo, ) -> Result { + // DataFusion should have ensured the function is called with just a + // single argument + assert_eq!(args.len(), 1); + let arg = args.pop().unwrap(); + // Note that Expr::cast_to requires an ExprSchema but simplify gets a // SimplifyInfo so we have to replicate some of the casting logic here. - let source_type = info.get_data_type(&args[0])?; - if source_type == DataType::Int64 { - Ok(ExprSimplifyResult::Original(args)) + + let source_type = info.get_data_type(&arg)?; + let new_expr = if source_type == DataType::Int64 { + // the argument's data type is already the correct type + arg } else { - // DataFusion should have ensured the function is called with just a - // single argument - assert_eq!(args.len(), 1); - let e = args.pop().unwrap(); - Ok(ExprSimplifyResult::Simplified(Expr::Cast( - datafusion_expr::Cast { - expr: Box::new(e), - data_type: DataType::Int64, - }, - ))) - } + // need to use an actual cast to get the correct type + Expr::Cast(datafusion_expr::Cast { + expr: Box::new(arg), + data_type: DataType::Int64, + }) + }; + // return the newly written argument to DataFusion + Ok(ExprSimplifyResult::Simplified(new_expr)) } + // Casting should be done in `simplify`, so we just return the first argument - fn invoke(&self, args: &[ColumnarValue]) -> Result { - assert_eq!(args.len(), 1); - Ok(args.first().unwrap().clone()) + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("Function should not be evaluated") } }