Skip to content

Commit

Permalink
change simplify signature
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <[email protected]>
  • Loading branch information
jayzhan211 committed Mar 2, 2024
1 parent adcb2e2 commit 5d1d500
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,16 @@ use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
use datafusion::{execution::registry::FunctionRegistry, test_util};
use datafusion_common::cast::as_float64_array;
use datafusion_common::DFSchemaRef;
use datafusion_common::DFSchema;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
plan_err, DataFusionError, ExprSchema, Result, ScalarValue,
};
use datafusion_expr::simplify::Simplified;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Simplified, Volatility,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};

use rand::{thread_rng, Rng};
Expand Down Expand Up @@ -529,8 +531,9 @@ impl ScalarUDFImpl for CastToI64UDF {
Ok(DataType::Int64)
}
// Wrap with Expr::Cast() to Int64
fn simplify(&self, args: &[Expr], schema: DFSchemaRef) -> Result<Simplified> {
fn simplify(&self, args: &[Expr], info: &dyn SimplifyInfo) -> Result<Simplified> {
let e = args[0].to_owned();
let schema = info.schema().unwrap_or_else(|| DFSchema::empty().into());
let casted_expr = e.cast_to(&DataType::Int64, schema.as_ref())?;
Ok(Simplified::Rewritten(casted_expr))
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ pub use signature::{
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl};
pub use udf::{ScalarUDF, ScalarUDFImpl, Simplified};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};

Expand Down
9 changes: 9 additions & 0 deletions datafusion/expr/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,12 @@ impl<'a> SimplifyInfo for SimplifyContext<'a> {
self.props
}
}

/// Was the expression simplified?
pub enum Simplified {
/// The function call was simplified to an entirely new Expr
Rewritten(Expr),
/// the function call could not be simplified, and the arguments
/// are return unmodified
Original,
}
19 changes: 5 additions & 14 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,20 @@

//! [`ScalarUDF`]: Scalar User Defined Functions

use crate::simplify::{Simplified, SimplifyInfo};
use crate::ExprSchemable;
use crate::{
ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction,
ScalarFunctionImplementation, Signature,
};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchemaRef, ExprSchema, Result};
use datafusion_common::{ExprSchema, Result};
use std::any::Any;
use std::fmt;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::sync::Arc;

// TODO(In this PR): Move to simplify.rs
/// Was the expression simplified?
pub enum Simplified {
/// The function call was simplified to an entirely new Expr
Rewritten(Expr),
/// the function call could not be simplified, and the arguments
/// are return unmodified
Original,
}

/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input. This
Expand Down Expand Up @@ -173,8 +164,8 @@ impl ScalarUDF {
/// Do the function rewrite
///
/// See [`ScalarUDFImpl::simplify`] for more details.
pub fn simplify(&self, args: &[Expr], schema: DFSchemaRef) -> Result<Simplified> {
self.inner.simplify(args, schema)
pub fn simplify(&self, args: &[Expr], info: &dyn SimplifyInfo) -> Result<Simplified> {
self.inner.simplify(args, info)
}

/// Invoke the function on `args`, returning the appropriate result.
Expand Down Expand Up @@ -358,7 +349,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
// Do the function rewrite.
// 'args': The arguments of the function
// 'schema': The schema of the function
fn simplify(&self, _args: &[Expr], _schema: DFSchemaRef) -> Result<Simplified> {
fn simplify(&self, _args: &[Expr], _info: &dyn SimplifyInfo) -> Result<Simplified> {
Ok(Simplified::Original)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
//! This module implements a rule that do function simplification.

use datafusion_common::tree_node::TreeNodeRewriter;
use datafusion_common::{DFSchema, Result};
use datafusion_common::Result;
use datafusion_expr::simplify::Simplified;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::Simplified;
use datafusion_expr::{expr::ScalarFunction, Expr, ScalarFunctionDefinition};

pub(super) struct FunctionSimplifier<'a, S> {
Expand All @@ -42,12 +42,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for FunctionSimplifier<'a, S> {
args,
}) = &expr
{
let schema = self
.info
.schema()
.unwrap_or_else(|| DFSchema::empty().into());

let simplified_expr = udf.simplify(args, schema)?;
let simplified_expr = udf.simplify(args, self.info)?;
match simplified_expr {
Simplified::Original => Ok(expr),
Simplified::Rewritten(expr) => Ok(expr),
Expand Down

0 comments on commit 5d1d500

Please sign in to comment.