diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index a86c76b9b6dd..9812789740f7 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -22,12 +22,16 @@ use arrow_schema::{DataType, Field, Schema}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::as_float64_array; -use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue}; +use datafusion_common::{ + assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err, + plan_err, DataFusionError, ExprSchema, Result, ScalarValue, +}; use datafusion_expr::{ - create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable, + LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; use rand::{thread_rng, Rng}; +use std::any::Any; use std::iter; use std::sync::Arc; @@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> { Ok(()) } +#[derive(Debug)] +struct TakeUDF { + signature: Signature, +} + +impl TakeUDF { + fn new() -> Self { + Self { + signature: Signature::any(3, Volatility::Immutable), + } + } +} + +/// Implement a ScalarUDFImpl whose return type is a function of the input values +impl ScalarUDFImpl for TakeUDF { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "take" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, _args: &[DataType]) -> Result { + not_impl_err!("Not called because the return_type_from_exprs is implemented") + } + + /// This function returns the type of the first or second argument based on + /// the third argument: + /// + /// 1. If the third argument is '0', return the type of the first argument + /// 2. If the third argument is '1', return the type of the second argument + fn return_type_from_exprs( + &self, + arg_exprs: &[Expr], + schema: &dyn ExprSchema, + ) -> Result { + if arg_exprs.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len()); + } + + let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) = + arg_exprs.get(2) + { + if *idx == 0 || *idx == 1 { + *idx as usize + } else { + return plan_err!("The third argument must be 0 or 1, got: {idx}"); + } + } else { + return plan_err!( + "The third argument must be a literal of type int64, but got {:?}", + arg_exprs.get(2) + ); + }; + + arg_exprs.get(take_idx).unwrap().get_type(schema) + } + + // The actual implementation + fn invoke(&self, args: &[ColumnarValue]) -> Result { + let take_idx = match &args[2] { + ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize, + _ => unreachable!(), + }; + match &args[take_idx] { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())), + ColumnarValue::Scalar(_) => unimplemented!(), + } + } +} + +#[tokio::test] +async fn verify_udf_return_type() -> Result<()> { + // Create a new ScalarUDF from the implementation + let take = ScalarUDF::from(TakeUDF::new()); + + // SELECT + // take(smallint_col, double_col, 0) as take0, + // take(smallint_col, double_col, 1) as take1 + // FROM alltypes_plain; + let exprs = vec![ + take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)]) + .alias("take0"), + take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)]) + .alias("take1"), + ]; + + let ctx = SessionContext::new(); + register_alltypes_parquet(&ctx).await?; + + let df = ctx.table("alltypes_plain").await?.select(exprs)?; + + let schema = df.schema(); + + // The output schema should be + // * type of column smallint_col (int32) + // * type of column double_col (float64) + assert_eq!(schema.field(0).data_type(), &DataType::Int32); + assert_eq!(schema.field(1).data_type(), &DataType::Float64); + + let expected = [ + "+-------+-------+", + "| take0 | take1 |", + "+-------+-------+", + "| 0 | 0.0 |", + "| 0 | 0.0 |", + "| 0 | 0.0 |", + "| 0 | 0.0 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "| 1 | 10.1 |", + "+-------+-------+", + ]; + assert_batches_sorted_eq!(&expected, &df.collect().await?); + + Ok(()) +} + fn create_udf_context() -> SessionContext { let ctx = SessionContext::new(); // register a custom UDF @@ -531,6 +656,17 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> { Ok(()) } +async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { + let testdata = datafusion::test_util::parquet_test_data(); + ctx.register_parquet( + "alltypes_plain", + &format!("{testdata}/alltypes_plain.parquet"), + ParquetReadOptions::default(), + ) + .await?; + Ok(()) +} + /// Execute SQL and return results as a RecordBatch async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 517d7a35f70a..491b4a852261 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -28,8 +28,8 @@ use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; use datafusion_common::{ - internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, - DataFusionError, ExprSchema, Result, + internal_err, plan_datafusion_err, plan_err, Column, DFField, DataFusionError, + ExprSchema, Result, }; use std::collections::HashMap; use std::sync::Arc; @@ -37,26 +37,28 @@ use std::sync::Arc; /// trait to allow expr to typable with respect to a schema pub trait ExprSchemable { /// given a schema, return the type of the expr - fn get_type(&self, schema: &S) -> Result; + fn get_type(&self, schema: &dyn ExprSchema) -> Result; /// given a schema, return the nullability of the expr - fn nullable(&self, input_schema: &S) -> Result; + fn nullable(&self, input_schema: &dyn ExprSchema) -> Result; /// given a schema, return the expr's optional metadata - fn metadata(&self, schema: &S) -> Result>; + fn metadata(&self, schema: &dyn ExprSchema) -> Result>; /// convert to a field with respect to a schema - fn to_field(&self, input_schema: &DFSchema) -> Result; + fn to_field(&self, input_schema: &dyn ExprSchema) -> Result; /// cast to a type with respect to a schema - fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result; + fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result; } impl ExprSchemable for Expr { /// Returns the [arrow::datatypes::DataType] of the expression /// based on [ExprSchema] /// - /// Note: [DFSchema] implements [ExprSchema]. + /// Note: [`DFSchema`] implements [ExprSchema]. + /// + /// [`DFSchema`]: datafusion_common::DFSchema /// /// # Examples /// @@ -90,7 +92,7 @@ impl ExprSchemable for Expr { /// expression refers to a column that does not exist in the /// schema, or when the expression is incorrectly typed /// (e.g. `[utf8] + [bool]`). - fn get_type(&self, schema: &S) -> Result { + fn get_type(&self, schema: &dyn ExprSchema) -> Result { match self { Expr::Alias(Alias { expr, name, .. }) => match &**expr { Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type { @@ -136,7 +138,7 @@ impl ExprSchemable for Expr { fun.return_type(&arg_data_types) } ScalarFunctionDefinition::UDF(fun) => { - Ok(fun.return_type(&arg_data_types)?) + Ok(fun.return_type_from_exprs(args, schema)?) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -213,14 +215,16 @@ impl ExprSchemable for Expr { /// Returns the nullability of the expression based on [ExprSchema]. /// - /// Note: [DFSchema] implements [ExprSchema]. + /// Note: [`DFSchema`] implements [ExprSchema]. + /// + /// [`DFSchema`]: datafusion_common::DFSchema /// /// # Errors /// /// This function errors when it is not possible to compute its /// nullability. This happens when the expression refers to a /// column that does not exist in the schema. - fn nullable(&self, input_schema: &S) -> Result { + fn nullable(&self, input_schema: &dyn ExprSchema) -> Result { match self { Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) @@ -327,7 +331,7 @@ impl ExprSchemable for Expr { } } - fn metadata(&self, schema: &S) -> Result> { + fn metadata(&self, schema: &dyn ExprSchema) -> Result> { match self { Expr::Column(c) => Ok(schema.metadata(c)?.clone()), Expr::Alias(Alias { expr, .. }) => expr.metadata(schema), @@ -339,7 +343,7 @@ impl ExprSchemable for Expr { /// /// So for example, a projected expression `col(c1) + col(c2)` is /// placed in an output field **named** col("c1 + c2") - fn to_field(&self, input_schema: &DFSchema) -> Result { + fn to_field(&self, input_schema: &dyn ExprSchema) -> Result { match self { Expr::Column(c) => Ok(DFField::new( c.relation.clone(), @@ -370,7 +374,7 @@ impl ExprSchemable for Expr { /// /// This function errors when it is impossible to cast the /// expression to the target [arrow::datatypes::DataType]. - fn cast_to(self, cast_to_type: &DataType, schema: &S) -> Result { + fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result { let this_type = self.get_type(schema)?; if this_type == *cast_to_type { return Ok(self); @@ -394,10 +398,10 @@ impl ExprSchemable for Expr { } /// return the schema [`Field`] for the type referenced by `get_indexed_field` -fn field_for_index( +fn field_for_index( expr: &Expr, field: &GetFieldAccess, - schema: &S, + schema: &dyn ExprSchema, ) -> Result { let expr_dt = expr.get_type(schema)?; match field { @@ -457,7 +461,7 @@ mod tests { use super::*; use crate::{col, lit}; use arrow::datatypes::{DataType, Fields}; - use datafusion_common::{Column, ScalarValue, TableReference}; + use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; macro_rules! test_is_expr_nullable { ($EXPR_TYPE:ident) => {{ diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3017e1ec0271..5b5d92a628c2 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,12 +17,13 @@ //! [`ScalarUDF`]: Scalar User Defined Functions +use crate::ExprSchemable; use crate::{ ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction, ScalarFunctionImplementation, Signature, }; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{ExprSchema, Result}; use std::any::Any; use std::fmt; use std::fmt::Debug; @@ -110,7 +111,7 @@ impl ScalarUDF { /// /// If you implement [`ScalarUDFImpl`] directly you should return aliases directly. pub fn with_aliases(self, aliases: impl IntoIterator) -> Self { - Self::new_from_impl(AliasedScalarUDFImpl::new(self, aliases)) + Self::new_from_impl(AliasedScalarUDFImpl::new(self.inner.clone(), aliases)) } /// Returns a [`Expr`] logical expression to call this UDF with specified @@ -146,10 +147,17 @@ impl ScalarUDF { } /// The datatype this function returns given the input argument input types. + /// This function is used when the input arguments are [`Expr`]s. /// - /// See [`ScalarUDFImpl::return_type`] for more details. - pub fn return_type(&self, args: &[DataType]) -> Result { - self.inner.return_type(args) + /// + /// See [`ScalarUDFImpl::return_type_from_exprs`] for more details. + pub fn return_type_from_exprs( + &self, + args: &[Expr], + schema: &dyn ExprSchema, + ) -> Result { + // If the implementation provides a return_type_from_exprs, use it + self.inner.return_type_from_exprs(args, schema) } /// Invoke the function on `args`, returning the appropriate result. @@ -246,9 +254,54 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn signature(&self) -> &Signature; /// What [`DataType`] will be returned by this function, given the types of - /// the arguments + /// the arguments. + /// + /// # Notes + /// + /// If you provide an implementation for [`Self::return_type_from_exprs`], + /// DataFusion will not call `return_type` (this function). In this case it + /// is recommended to return [`DataFusionError::Internal`]. + /// + /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal fn return_type(&self, arg_types: &[DataType]) -> Result; + /// What [`DataType`] will be returned by this function, given the + /// arguments? + /// + /// Note most UDFs should implement [`Self::return_type`] and not this + /// function. The output type for most functions only depends on the types + /// of their inputs (e.g. `sqrt(f32)` is always `f32`). + /// + /// By default, this function calls [`Self::return_type`] with the + /// types of each argument. + /// + /// This method can be overridden for functions that return different + /// *types* based on the *values* of their arguments. + /// + /// For example, the following two function calls get the same argument + /// types (something and a `Utf8` string) but return different types based + /// on the value of the second argument: + /// + /// * `arrow_cast(x, 'Int16')` --> `Int16` + /// * `arrow_cast(x, 'Float32')` --> `Float32` + /// + /// # Notes: + /// + /// This function must consistently return the same type for the same + /// logical input even if the input is simplified (e.g. it must return the same + /// value for `('foo' | 'bar')` as it does for ('foobar'). + fn return_type_from_exprs( + &self, + args: &[Expr], + schema: &dyn ExprSchema, + ) -> Result { + let arg_types = args + .iter() + .map(|arg| arg.get_type(schema)) + .collect::>>()?; + self.return_type(&arg_types) + } + /// Invoke the function on `args`, returning the appropriate result /// /// The function will be invoked passed with the slice of [`ColumnarValue`] @@ -290,13 +343,13 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// implement [`ScalarUDFImpl`], which supports aliases, directly if possible. #[derive(Debug)] struct AliasedScalarUDFImpl { - inner: ScalarUDF, + inner: Arc, aliases: Vec, } impl AliasedScalarUDFImpl { pub fn new( - inner: ScalarUDF, + inner: Arc, new_aliases: impl IntoIterator, ) -> Self { let mut aliases = inner.aliases().to_vec(); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 662e0fc7c258..fba77047dd74 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -681,17 +681,17 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { let case_type = case .expr .as_ref() - .map(|expr| expr.get_type(&schema)) + .map(|expr| expr.get_type(schema)) .transpose()?; let then_types = case .when_then_expr .iter() - .map(|(_when, then)| then.get_type(&schema)) + .map(|(_when, then)| then.get_type(schema)) .collect::>>()?; let else_type = case .else_expr .as_ref() - .map(|expr| expr.get_type(&schema)) + .map(|expr| expr.get_type(schema)) .transpose()?; // find common coercible types @@ -701,7 +701,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { let when_types = case .when_then_expr .iter() - .map(|(when, _then)| when.get_type(&schema)) + .map(|(when, _then)| when.get_type(schema)) .collect::>>()?; let coerced_type = get_coerce_type_for_case_expression(&when_types, Some(case_type)); @@ -727,7 +727,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { let case_expr = case .expr .zip(case_when_coerce_type.as_ref()) - .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, &schema)) + .map(|(case_expr, coercible_type)| case_expr.cast_to(coercible_type, schema)) .transpose()? .map(Box::new); let when_then = case @@ -735,7 +735,7 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { .into_iter() .map(|(when, then)| { let when_type = case_when_coerce_type.as_ref().unwrap_or(&DataType::Boolean); - let when = when.cast_to(when_type, &schema).map_err(|e| { + let when = when.cast_to(when_type, schema).map_err(|e| { DataFusionError::Context( format!( "WHEN expressions in CASE couldn't be \ @@ -744,13 +744,13 @@ fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result { Box::new(e), ) })?; - let then = then.cast_to(&then_else_coerce_type, &schema)?; + let then = then.cast_to(&then_else_coerce_type, schema)?; Ok((Box::new(when), Box::new(then))) }) .collect::>>()?; let else_expr = case .else_expr - .map(|expr| expr.cast_to(&then_else_coerce_type, &schema)) + .map(|expr| expr.cast_to(&then_else_coerce_type, schema)) .transpose()? .map(Box::new); diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 6408af5cda99..b8491aea2d6f 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -272,11 +272,15 @@ pub fn create_physical_expr( execution_props, ) } - ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr( - fun.clone().as_ref(), - &physical_args, - input_schema, - ), + ScalarFunctionDefinition::UDF(fun) => { + let return_type = fun.return_type_from_exprs(args, input_dfschema)?; + + udf::create_physical_expr( + fun.clone().as_ref(), + &physical_args, + return_type, + ) + } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") } diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index e0117fecb4e8..d9c7c9e5c2a6 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -17,28 +17,24 @@ //! UDF support use crate::{PhysicalExpr, ScalarFunctionExpr}; -use arrow::datatypes::Schema; +use arrow_schema::DataType; use datafusion_common::Result; pub use datafusion_expr::ScalarUDF; use std::sync::Arc; /// Create a physical expression of the UDF. -/// This function errors when `args`' can't be coerced to a valid argument type of the UDF. +/// +/// Arguments: pub fn create_physical_expr( fun: &ScalarUDF, input_phy_exprs: &[Arc], - input_schema: &Schema, + return_type: DataType, ) -> Result> { - let input_exprs_types = input_phy_exprs - .iter() - .map(|e| e.data_type(input_schema)) - .collect::>>()?; - Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), fun.fun(), input_phy_exprs.to_vec(), - fun.return_type(&input_exprs_types)?, + return_type, fun.monotonicity()?, fun.signature().type_signature.supports_zero_argument(), ))) @@ -46,7 +42,6 @@ pub fn create_physical_expr( #[cfg(test)] mod tests { - use arrow::datatypes::Schema; use arrow_schema::DataType; use datafusion_common::Result; use datafusion_expr::{ @@ -102,7 +97,7 @@ mod tests { // create and register the udf let udf = ScalarUDF::from(TestScalarUDF::new()); - let p_expr = create_physical_expr(&udf, &[], &Schema::empty())?; + let p_expr = create_physical_expr(&udf, &[], DataType::Float64)?; assert_eq!( p_expr