Skip to content

Commit

Permalink
Add a ScalarUDFImpl::simplfy() API, move SimplifyInfo et al to da…
Browse files Browse the repository at this point in the history
…tafusion_expr (#9304)

* first draft

Signed-off-by: jayzhan211 <[email protected]>

* clippy

Signed-off-by: jayzhan211 <[email protected]>

* add comments

Signed-off-by: jayzhan211 <[email protected]>

* move to optimize rule

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* fix explain test

Signed-off-by: jayzhan211 <[email protected]>

* move to simplifier

Signed-off-by: jayzhan211 <[email protected]>

* pass with schema

Signed-off-by: jayzhan211 <[email protected]>

* fix explain

Signed-off-by: jayzhan211 <[email protected]>

* fix doc

Signed-off-by: jayzhan211 <[email protected]>

* move to expr

Signed-off-by: jayzhan211 <[email protected]>

* change simplify signature

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* cleanup

Signed-off-by: jayzhan211 <[email protected]>

* fix doc

Signed-off-by: jayzhan211 <[email protected]>

* fix doc

Signed-off-by: jayzhan211 <[email protected]>

* Update datafusion/expr/src/udf.rs

* Add backwards compatibile uses, inline FunctionSimplifier, rename to ExprSimplifyResult

* Remove DFSchema from SimplifyInfo

* Avoid requiring argument copies

* Improve docs

* fix link

* fix doc test

* Update datafusion/physical-expr/src/lib.rs

* Change example simplify to always simplify its argument

* Clarify comment

---------

Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
jayzhan211 and alamb authored Mar 5, 2024
1 parent 3854419 commit 2873fd0
Show file tree
Hide file tree
Showing 35 changed files with 287 additions and 99 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ use arrow::record_batch::RecordBatch;
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::{DFField, DFSchema};
use datafusion::error::Result;
use datafusion::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use datafusion::physical_expr::execution_props::ExecutionProps;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::physical_expr::{
analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr,
};
use datafusion::prelude::*;
use datafusion_common::{ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{ColumnarValue, ExprSchemable, Operator};

/// This example demonstrates the DataFusion [`Expr`] API.
Expand Down
3 changes: 2 additions & 1 deletion datafusion-examples/examples/simple_udtf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::{plan_err, ScalarValue};
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::{Expr, TableType};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use std::fs::File;
use std::io::Seek;
use std::path::Path;
Expand Down
6 changes: 2 additions & 4 deletions datafusion/core/src/datasource/listing/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@ use arrow::{
use arrow_schema::Fields;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;

use futures::stream::{BoxStream, FuturesUnordered};
use futures::{StreamExt, TryStreamExt};
use futures::stream::{BoxStream, FuturesUnordered, StreamExt, TryStreamExt};
use log::{debug, trace};
use object_store::path::Path;
use object_store::{ObjectMeta, ObjectStore};
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/src/datasource/physical_plan/parquet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,13 +800,14 @@ mod tests {
ArrayRef, Date64Array, Int32Array, Int64Array, Int8Array, StringArray,
StructArray,
};

use arrow::datatypes::{DataType, Field, Schema, SchemaBuilder};
use arrow::record_batch::RecordBatch;
use arrow_schema::Fields;
use datafusion_common::{assert_contains, FileType, GetExt, ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{col, lit, when, Expr};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;

use chrono::{TimeZone, Utc};
use futures::StreamExt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,9 @@ mod test {
use super::*;
use arrow::datatypes::Field;
use datafusion_common::ToDFSchema;
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{cast, col, lit, Expr};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use parquet::arrow::parquet_to_arrow_schema;
use parquet::file::reader::{FileReader, SerializedFileReader};
use rand::prelude::*;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ mod tests {
use arrow::datatypes::Schema;
use arrow::datatypes::{DataType, Field};
use datafusion_common::{Result, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{cast, col, lit, Expr};
use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
use parquet::arrow::arrow_to_parquet_schema;
use parquet::arrow::async_reader::ParquetObjectReader;
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ use datafusion_common::{
tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor},
};
use datafusion_execution::registry::SerializerRegistry;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::var_provider::is_system_variables;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
pub use datafusion_physical_expr::execution_props::ExecutionProps;
use datafusion_physical_expr::var_provider::is_system_variables;
use parking_lot::RwLock;
use std::collections::hash_map::Entry;
use std::string::String;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1341,10 +1341,10 @@ mod tests {
datatypes::{DataType, TimeUnit},
};
use datafusion_common::{ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;
use std::collections::HashMap;
use std::ops::{Not, Rem};

Expand Down
5 changes: 3 additions & 2 deletions datafusion/core/src/test_util/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ use crate::datasource::listing::{ListingTableUrl, PartitionedFile};
use crate::datasource::object_store::ObjectStoreUrl;
use crate::datasource::physical_plan::{FileScanConfig, ParquetExec};
use crate::error::Result;
use crate::optimizer::simplify_expressions::{ExprSimplifier, SimplifyContext};
use crate::logical_expr::execution_props::ExecutionProps;
use crate::logical_expr::simplify::SimplifyContext;
use crate::optimizer::simplify_expressions::ExprSimplifier;
use crate::physical_expr::create_physical_expr;
use crate::physical_expr::execution_props::ExecutionProps;
use crate::physical_plan::filter::FilterExec;
use crate::physical_plan::metrics::MetricsSet;
use crate::physical_plan::ExecutionPlan;
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/variable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@

//! Variable provider for `@name` and `@@name` style runtime values.

pub use datafusion_physical_expr::var_provider::{VarProvider, VarType};
pub use datafusion_expr::var_provider::{VarProvider, VarType};
2 changes: 1 addition & 1 deletion datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ use datafusion_common::{assert_contains, DataFusionError, ScalarValue, UnnestOpt
use datafusion_execution::config::SessionConfig;
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, count, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, sum, when, wildcard, AggregateFunction, Expr,
ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
use datafusion_physical_expr::var_provider::{VarProvider, VarType};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/parquet/page_pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ use datafusion::physical_plan::metrics::MetricValue;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::prelude::SessionContext;
use datafusion_common::{ScalarValue, Statistics, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::{col, lit, Expr};
use datafusion_physical_expr::create_physical_expr;
use datafusion_physical_expr::execution_props::ExecutionProps;

use futures::StreamExt;
use object_store::path::Path;
Expand Down
26 changes: 14 additions & 12 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@
use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::{ArrayRef, Int32Array};
use chrono::{DateTime, TimeZone, Utc};
use datafusion::common::DFSchema;
use datafusion::{error::Result, execution::context::ExecutionProps, prelude::*};
use datafusion_common::cast::as_int32_array;
use datafusion_common::ScalarValue;
use datafusion_common::{DFSchemaRef, ToDFSchema};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr, table_scan, BuiltinScalarFunction, Cast, ColumnarValue, Expr, ExprSchemable,
LogicalPlan, LogicalPlanBuilder, ScalarUDF, Volatility,
};
use datafusion_optimizer::simplify_expressions::{
ExprSimplifier, SimplifyExpressions, SimplifyInfo,
};
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
use std::sync::Arc;

Expand All @@ -42,7 +41,7 @@ use std::sync::Arc;
/// objects or from some other implementation
struct MyInfo {
/// The input schema
schema: DFSchema,
schema: DFSchemaRef,

/// Execution specific details needed for constant evaluation such
/// as the current time for `now()` and [VariableProviders]
Expand All @@ -51,24 +50,27 @@ struct MyInfo {

impl SimplifyInfo for MyInfo {
fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
Ok(matches!(expr.get_type(&self.schema)?, DataType::Boolean))
Ok(matches!(
expr.get_type(self.schema.as_ref())?,
DataType::Boolean
))
}

fn nullable(&self, expr: &Expr) -> Result<bool> {
expr.nullable(&self.schema)
expr.nullable(self.schema.as_ref())
}

fn execution_props(&self) -> &ExecutionProps {
&self.execution_props
}

fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
expr.get_type(&self.schema)
expr.get_type(self.schema.as_ref())
}
}

impl From<DFSchema> for MyInfo {
fn from(schema: DFSchema) -> Self {
impl From<DFSchemaRef> for MyInfo {
fn from(schema: DFSchemaRef) -> Self {
Self {
schema,
execution_props: ExecutionProps::new(),
Expand All @@ -81,13 +83,13 @@ impl From<DFSchema> for MyInfo {
/// a: Int32 (possibly with nulls)
/// b: Int32
/// s: Utf8
fn schema() -> DFSchema {
fn schema() -> DFSchemaRef {
Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, false),
Field::new("s", DataType::Utf8, false),
])
.try_into()
.to_dfschema_ref()
.unwrap()
}

Expand Down
102 changes: 101 additions & 1 deletion datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{Array, ArrayRef, Float64Array, Int32Array, RecordBatch, UInt8Array};
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_schema::{DataType, Field, Schema};
use datafusion::prelude::*;
Expand All @@ -26,10 +28,13 @@ use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, cast::as_int32_array, not_impl_err,
plan_err, ExprSchema, Result, ScalarValue,
};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};

use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
Expand Down Expand Up @@ -514,6 +519,101 @@ async fn deregister_udf() -> Result<()> {
Ok(())
}

#[derive(Debug)]
struct CastToI64UDF {
signature: Signature,
}

impl CastToI64UDF {
fn new() -> Self {
Self {
signature: Signature::any(1, Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for CastToI64UDF {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"cast_to_i64"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
Ok(DataType::Int64)
}

// Demonstrate simplifying a UDF
fn simplify(
&self,
mut args: Vec<Expr>,
info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
// DataFusion should have ensured the function is called with just a
// single argument
assert_eq!(args.len(), 1);
let arg = args.pop().unwrap();

// Note that Expr::cast_to requires an ExprSchema but simplify gets a
// SimplifyInfo so we have to replicate some of the casting logic here.

let source_type = info.get_data_type(&arg)?;
let new_expr = if source_type == DataType::Int64 {
// the argument's data type is already the correct type
arg
} else {
// need to use an actual cast to get the correct type
Expr::Cast(datafusion_expr::Cast {
expr: Box::new(arg),
data_type: DataType::Int64,
})
};
// return the newly written argument to DataFusion
Ok(ExprSimplifyResult::Simplified(new_expr))
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!("Function should have been simplified prior to evaluation")
}
}

#[tokio::test]
async fn test_user_defined_functions_cast_to_i64() -> Result<()> {
let ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![Field::new("x", DataType::Float32, false)]));

let batch = RecordBatch::try_new(
schema,
vec![Arc::new(Float32Array::from(vec![1.0, 2.0, 3.0]))],
)?;

ctx.register_batch("t", batch)?;

let cast_to_i64_udf = ScalarUDF::from(CastToI64UDF::new());
ctx.register_udf(cast_to_i64_udf);

let result = plan_and_collect(&ctx, "SELECT cast_to_i64(x) FROM t").await?;

assert_batches_eq!(
&[
"+------------------+",
"| cast_to_i64(t.x) |",
"+------------------+",
"| 1 |",
"| 2 |",
"| 3 |",
"+------------------+"
],
&result
);

Ok(())
}

#[derive(Debug)]
struct TakeUDF {
signature: Signature,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ ahash = { version = "0.8", default-features = false, features = [
] }
arrow = { workspace = true }
arrow-array = { workspace = true }
chrono = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
paste = "^1.0"
sqlparser = { workspace = true }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,12 @@ use std::sync::Arc;
/// Holds per-query execution properties and data (such as statement
/// starting timestamps).
///
/// An [`ExecutionProps`] is created each time a [`LogicalPlan`] is
/// An [`ExecutionProps`] is created each time a `LogicalPlan` is
/// prepared for execution (optimized). If the same plan is optimized
/// multiple times, a new `ExecutionProps` is created each time.
///
/// It is important that this structure be cheap to create as it is
/// done so during predicate pruning and expression simplification
///
/// [`LogicalPlan`]: datafusion_expr::LogicalPlan
#[derive(Clone, Debug)]
pub struct ExecutionProps {
pub query_execution_start_time: DateTime<Utc>,
Expand Down
Loading

0 comments on commit 2873fd0

Please sign in to comment.