Skip to content

Commit

Permalink
first draft
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Feb 21, 2024
1 parent 453a45a commit bba4940
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch,
UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
Expand All @@ -26,12 +29,15 @@ use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
plan_err, DataFusionError, ExprSchema, Result, ScalarValue,
};
use datafusion_common::{DFField, DFSchema};
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Simplified, Volatility,
};

use rand::{thread_rng, Rng};
use std::any::Any;
use std::collections::HashMap;
use std::iter;
use std::sync::Arc;

Expand Down Expand Up @@ -498,6 +504,81 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
Ok(())
}

#[derive(Debug)]
struct CastToI64UDF {
signature: Signature,
}

impl CastToI64UDF {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for CastToI64UDF {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"cast_to_i64"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
Ok(DataType::Int64)
}
// Wrap with Expr::Cast() to Int64
fn simplify(&self, args: Vec<Expr>) -> Result<Simplified> {
let dfs = DFSchema::new_with_metadata(
vec![DFField::new(Some("t"), "x", DataType::Float32, true)],
HashMap::default(),
)?;
let e = args[0].clone();
let casted_expr = e.cast_to(&DataType::Int64, &dfs)?;
Ok(Simplified::Rewritten(casted_expr))
}
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(args.get(0).unwrap().clone())
}
}

#[tokio::test]
async fn test_user_defined_functions_cast_to_i64() -> Result<()> {
let ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, false)]));

let batch = RecordBatch::try_new(
schema,
vec![Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]))],
)?;

ctx.register_batch("t", batch)?;

let cast_to_i64_udf = ScalarUDF::from(CastToI64UDF::new());
ctx.register_udf(cast_to_i64_udf);

let result = plan_and_collect(&ctx, "SELECT cast_to_i64(x) FROM t").await?;

assert_batches_eq!(
&[
"+------------------+",
"| cast_to_i64(t.x) |",
"+------------------+",
"| 1 |",
"| 2 |",
"| 3 |",
"+------------------+"
],
&result
);

Ok(())
}

#[derive(Debug)]
struct TakeUDF {
signature: Signature,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub use signature::{
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udf::{ScalarUDF, ScalarUDFImpl, Simplified};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

Expand Down
17 changes: 17 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ use std::fmt::Debug;
use std::fmt::Formatter;
use std::sync::Arc;

/// Was the expression simplified?
pub enum Simplified {
/// The function call was simplified to an entirely new Expr
Rewritten(Expr),
/// the function call could not be simplified, and the arguments
/// are return unmodified
Original(Vec<Expr>),
}

/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input. This
Expand Down Expand Up @@ -160,6 +169,10 @@ impl ScalarUDF {
self.inner.return_type_from_exprs(args, schema)
}

pub fn simplify(&self, args: Vec<Expr>) -> Result<Simplified> {
self.inner.simplify(args)
}

/// Invoke the function on `args`, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke`] for more details.
Expand Down Expand Up @@ -337,6 +350,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> {
Ok(None)
}

fn simplify(&self, args: Vec<Expr>) -> Result<Simplified> {
Ok(Simplified::Original(args))
}
}

/// ScalarUDF that adds an alias to the underlying function. It is better to
Expand Down
62 changes: 36 additions & 26 deletions datafusion/physical-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,34 +258,44 @@ pub fn create_physical_expr(
)))
}

Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
let physical_args = args
.iter()
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
.collect::<Result<Vec<_>>>()?;
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
functions::create_physical_expr(
fun,
&physical_args,
input_schema,
execution_props,
)
}
ScalarFunctionDefinition::UDF(fun) => {
let return_type = fun.return_type_from_exprs(args, input_dfschema)?;
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
let physical_args = args
.iter()
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
.collect::<Result<Vec<_>>>()?;

udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
return_type,
)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
functions::create_physical_expr(
fun,
&physical_args,
input_schema,
execution_props,
)
}
}
ScalarFunctionDefinition::UDF(fun) => {
let args = match fun.simplify(args.to_owned())? {
datafusion_expr::Simplified::Original(args) => args,
datafusion_expr::Simplified::Rewritten(expr) => vec![expr],
};

let physical_args = args
.iter()
.map(|e| create_physical_expr(e, input_dfschema, execution_props))
.collect::<Result<Vec<_>>>()?;

let return_type =
fun.return_type_from_exprs(args.as_slice(), input_dfschema)?;

udf::create_physical_expr(
fun.clone().as_ref(),
&physical_args,
return_type,
)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
},
Expr::Between(Between {
expr,
negated,
Expand Down

0 comments on commit bba4940

Please sign in to comment.