Skip to content

Commit

Permalink
Change example simplify to always simplify its argument
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Mar 4, 2024
1 parent 550cbc4 commit fea82cb
Showing 1 changed file with 23 additions and 18 deletions.
41 changes: 23 additions & 18 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,34 +545,39 @@ impl ScalarUDFImpl for CastToI64UDF {
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
Ok(DataType::Int64)
}
// Wrap with Expr::Cast() to Int64

// Demonstrate simplifying a UDF
fn simplify(
&self,
mut args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
// DataFusion should have ensured the function is called with just a
// single argument
assert_eq!(args.len(), 1);
let arg = args.pop().unwrap();

// Note that Expr::cast_to requires an ExprSchema but simplify gets a
// SimplifyInfo so we have to replicate some of the casting logic here.
let source_type = info.get_data_type(&args[0])?;
if source_type == DataType::Int64 {
Ok(ExprSimplifyResult::Original(args))

let source_type = info.get_data_type(&arg)?;
let new_expr = if source_type == DataType::Int64 {
// the argument's data type is already the correct type
arg
} else {
// DataFusion should have ensured the function is called with just a
// single argument
assert_eq!(args.len(), 1);
let e = args.pop().unwrap();
Ok(ExprSimplifyResult::Simplified(Expr::Cast(
datafusion_expr::Cast {
expr: Box::new(e),
data_type: DataType::Int64,
},
)))
}
// need to use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
data_type: DataType::Int64,
})
};
// return the newly written argument to DataFusion
Ok(ExprSimplifyResult::Simplified(new_expr))
}

// Casting should be done in `simplify`, so we just return the first argument
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
assert_eq!(args.len(), 1);
Ok(args.first().unwrap().clone())
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!("Function should not be evaluated")
}
}

Expand Down

0 comments on commit fea82cb

Please sign in to comment.