Skip to content

Commit

Permalink
Support compute return types from argument values (not just their Dat…
Browse files Browse the repository at this point in the history
…aTypes) (apache#8985)

* ScalarValue return types from argument values

* change file name

* try using ?Sized

* use Ok

* move method default impl outside trait

* Use type trait for ExprSchemable

* fix nit

* Proposed Return Type from Expr suggestions (#1)

* Improve return_type_from_args

* Rework example

* Update datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

---------

Co-authored-by: Junhao Liu <[email protected]>

* Apply suggestions from code review

Co-authored-by: Alex Huang <[email protected]>

* Fix tests + clippy

* rework types to use dyn trait

* fmt

* docs

* Apply suggestions from code review

Co-authored-by: Jeffrey Vo <[email protected]>

* Add docs explaining what happens when both `return_type` and `return_type_from_exprs` are called

* clippy

* fix doc -- comedy of errors

---------

Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Alex Huang <[email protected]>
Co-authored-by: Jeffrey Vo <[email protected]>
  • Loading branch information
4 people authored Feb 15, 2024
1 parent 85be1bc commit 92d9274
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 53 deletions.
142 changes: 139 additions & 3 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::{assert_batches_eq, cast::as_int32_array, Result, ScalarValue};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
plan_err, DataFusionError, ExprSchema, Result, ScalarValue,
};
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, LogicalPlanBuilder, ScalarUDF,
ScalarUDFImpl, Signature, Volatility,
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
use std::sync::Arc;

Expand Down Expand Up @@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
Ok(())
}

/// Test-only scalar UDF registered under the name `take`; its
/// `ScalarUDFImpl` implementation derives the return type from the value of
/// the third argument.
#[derive(Debug)]
struct TakeUDF {
/// Signature handed out by `ScalarUDFImpl::signature`; `TakeUDF::new`
/// builds it with `Signature::any(3, Volatility::Immutable)`.
signature: Signature,
}

impl TakeUDF {
    /// Creates the UDF with a signature built by
    /// `Signature::any(3, Volatility::Immutable)`.
    fn new() -> Self {
        let signature = Signature::any(3, Volatility::Immutable);
        Self { signature }
    }
}

/// Implement a ScalarUDFImpl whose return type is a function of the input values
impl ScalarUDFImpl for TakeUDF {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"take"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
not_impl_err!("Not called because the return_type_from_exprs is implemented")
}

/// This function returns the type of the first or second argument based on
/// the third argument:
///
/// 1. If the third argument is '0', return the type of the first argument
/// 2. If the third argument is '1', return the type of the second argument
fn return_type_from_exprs(
&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
) -> Result<DataType> {
if arg_exprs.len() != 3 {
return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
}

let take_idx = if let Some(Expr::Literal(ScalarValue::Int64(Some(idx)))) =
arg_exprs.get(2)
{
if *idx == 0 || *idx == 1 {
*idx as usize
} else {
return plan_err!("The third argument must be 0 or 1, got: {idx}");
}
} else {
return plan_err!(
"The third argument must be a literal of type int64, but got {:?}",
arg_exprs.get(2)
);
};

arg_exprs.get(take_idx).unwrap().get_type(schema)
}

// The actual implementation
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let take_idx = match &args[2] {
ColumnarValue::Scalar(ScalarValue::Int64(Some(v))) if v < &2 => *v as usize,
_ => unreachable!(),
};
match &args[take_idx] {
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(array.clone())),
ColumnarValue::Scalar(_) => unimplemented!(),
}
}
}

/// Checks that the output schema of `take(a, b, idx)` follows the literal
/// `idx` (via `TakeUDF::return_type_from_exprs`) and that the query returns
/// the values of the selected column.
#[tokio::test]
async fn verify_udf_return_type() -> Result<()> {
    // Create a new ScalarUDF from the implementation
    let take = ScalarUDF::from(TakeUDF::new());

    // SELECT
    //   take(smallint_col, double_col, 0) as take0,
    //   take(smallint_col, double_col, 1) as take1
    // FROM alltypes_plain;
    let exprs = vec![
        take.call(vec![col("smallint_col"), col("double_col"), lit(0_i64)])
            .alias("take0"),
        take.call(vec![col("smallint_col"), col("double_col"), lit(1_i64)])
            .alias("take1"),
    ];

    let ctx = SessionContext::new();
    register_alltypes_parquet(&ctx).await?;

    let df = ctx.table("alltypes_plain").await?.select(exprs)?;

    let schema = df.schema();

    // The output schema should be
    // * type of column smallint_col (int32)
    // * type of column double_col (float64)
    assert_eq!(schema.field(0).data_type(), &DataType::Int32);
    assert_eq!(schema.field(1).data_type(), &DataType::Float64);

    // Every cell is padded to its column's width so the rows line up with the
    // header and border lines, as in the pretty-printed batches the macro
    // compares against.
    let expected = [
        "+-------+-------+",
        "| take0 | take1 |",
        "+-------+-------+",
        "| 0     | 0.0   |",
        "| 0     | 0.0   |",
        "| 0     | 0.0   |",
        "| 0     | 0.0   |",
        "| 1     | 10.1  |",
        "| 1     | 10.1  |",
        "| 1     | 10.1  |",
        "| 1     | 10.1  |",
        "+-------+-------+",
    ];
    assert_batches_sorted_eq!(&expected, &df.collect().await?);

    Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down Expand Up @@ -531,6 +656,17 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> {
Ok(())
}

/// Registers `alltypes_plain.parquet` from the parquet test-data directory
/// with `ctx` as the table `alltypes_plain`.
async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> {
    let data_dir = datafusion::test_util::parquet_test_data();
    let file_path = format!("{data_dir}/alltypes_plain.parquet");
    let read_options = ParquetReadOptions::default();

    ctx.register_parquet("alltypes_plain", &file_path, read_options)
        .await?;
    Ok(())
}

/// Execute SQL and return results as a RecordBatch
async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result<Vec<RecordBatch>> {
ctx.sql(sql).await?.collect().await
Expand Down
40 changes: 22 additions & 18 deletions datafusion/expr/src/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,37 @@ use crate::{utils, LogicalPlan, Projection, Subquery};
use arrow::compute::can_cast_types;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{
internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema,
DataFusionError, ExprSchema, Result,
internal_err, plan_datafusion_err, plan_err, Column, DFField, DataFusionError,
ExprSchema, Result,
};
use std::collections::HashMap;
use std::sync::Arc;

/// Trait that allows an expr to be typed with respect to a schema
pub trait ExprSchemable {
/// given a schema, return the type of the expr
fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;

/// given a schema, return the nullability of the expr
fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;

/// given a schema, return the expr's optional metadata
fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, String>>;
fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>;

/// convert to a field with respect to a schema
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField>;

/// cast to a type with respect to a schema
fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
}

impl ExprSchemable for Expr {
/// Returns the [arrow::datatypes::DataType] of the expression
/// based on [ExprSchema]
///
/// Note: [DFSchema] implements [ExprSchema].
/// Note: [`DFSchema`] implements [ExprSchema].
///
/// [`DFSchema`]: datafusion_common::DFSchema
///
/// # Examples
///
Expand Down Expand Up @@ -90,7 +92,7 @@ impl ExprSchemable for Expr {
/// expression refers to a column that does not exist in the
/// schema, or when the expression is incorrectly typed
/// (e.g. `[utf8] + [bool]`).
fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
match self {
Expr::Alias(Alias { expr, name, .. }) => match &**expr {
Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
Expand Down Expand Up @@ -136,7 +138,7 @@ impl ExprSchemable for Expr {
fun.return_type(&arg_data_types)
}
ScalarFunctionDefinition::UDF(fun) => {
Ok(fun.return_type(&arg_data_types)?)
Ok(fun.return_type_from_exprs(args, schema)?)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
Expand Down Expand Up @@ -213,14 +215,16 @@ impl ExprSchemable for Expr {

/// Returns the nullability of the expression based on [ExprSchema].
///
/// Note: [DFSchema] implements [ExprSchema].
/// Note: [`DFSchema`] implements [ExprSchema].
///
/// [`DFSchema`]: datafusion_common::DFSchema
///
/// # Errors
///
/// This function errors when it is not possible to compute its
/// nullability. This happens when the expression refers to a
/// column that does not exist in the schema.
fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
match self {
Expr::Alias(Alias { expr, .. })
| Expr::Not(expr)
Expand Down Expand Up @@ -327,7 +331,7 @@ impl ExprSchemable for Expr {
}
}

fn metadata<S: ExprSchema>(&self, schema: &S) -> Result<HashMap<String, String>> {
fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> {
match self {
Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
Expand All @@ -339,7 +343,7 @@ impl ExprSchemable for Expr {
///
/// So for example, a projected expression `col(c1) + col(c2)` is
/// placed in an output field **named** col("c1 + c2")
fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
fn to_field(&self, input_schema: &dyn ExprSchema) -> Result<DFField> {
match self {
Expr::Column(c) => Ok(DFField::new(
c.relation.clone(),
Expand Down Expand Up @@ -370,7 +374,7 @@ impl ExprSchemable for Expr {
///
/// This function errors when it is impossible to cast the
/// expression to the target [arrow::datatypes::DataType].
fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> {
fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
let this_type = self.get_type(schema)?;
if this_type == *cast_to_type {
return Ok(self);
Expand All @@ -394,10 +398,10 @@ impl ExprSchemable for Expr {
}

/// return the schema [`Field`] for the type referenced by `get_indexed_field`
fn field_for_index<S: ExprSchema>(
fn field_for_index(
expr: &Expr,
field: &GetFieldAccess,
schema: &S,
schema: &dyn ExprSchema,
) -> Result<Field> {
let expr_dt = expr.get_type(schema)?;
match field {
Expand Down Expand Up @@ -457,7 +461,7 @@ mod tests {
use super::*;
use crate::{col, lit};
use arrow::datatypes::{DataType, Fields};
use datafusion_common::{Column, ScalarValue, TableReference};
use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};

macro_rules! test_is_expr_nullable {
($EXPR_TYPE:ident) => {{
Expand Down
Loading

0 comments on commit 92d9274

Please sign in to comment.