diff --git a/datafusion-examples/examples/simplify_udwf_expression.rs b/datafusion-examples/examples/simplify_udwf_expression.rs new file mode 100644 index 000000000000..2824d03761ab --- /dev/null +++ b/datafusion-examples/examples/simplify_udwf_expression.rs @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow_schema::DataType; +use datafusion::execution::context::SessionContext; +use datafusion::{error::Result, execution::options::CsvReadOptions}; +use datafusion_expr::function::WindowFunctionSimplification; +use datafusion_expr::{ + expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, +}; + +/// This UDWF will show how to use the WindowUDFImpl::simplify() API +#[derive(Debug, Clone)] +struct SimplifySmoothItUdf { + signature: Signature, +} + +impl SimplifySmoothItUdf { + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one arguments of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} +impl WindowUDFImpl for SimplifySmoothItUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "simplify_smooth_it" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn partition_evaluator(&self) -> Result> { + todo!() + } + + /// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`. + fn simplify(&self) -> Option { + // Ok(ExprSimplifyResult::Simplified(Expr::WindowFunction( + // WindowFunction { + // fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction( + // AggregateFunction::Avg, + // ), + // args, + // partition_by: partition_by.to_vec(), + // order_by: order_by.to_vec(), + // window_frame: window_frame.clone(), + // null_treatment: *null_treatment, + // }, + // ))) + let simplify = |window_function: datafusion_expr::expr::WindowFunction, + _: &dyn SimplifyInfo| { + Ok(Expr::WindowFunction(WindowFunction { + fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction( + AggregateFunction::Avg, + ), + args: window_function.args, + partition_by: window_function.partition_by, + order_by: window_function.order_by, + window_frame: window_function.window_frame, + null_treatment: window_function.null_treatment, + })) + }; + + Some(Box::new(simplify)) + } +} + +// create local execution context with `cars.csv` registered as a table named `cars` +async fn create_context() -> Result { + // declare a new context. In spark API, this corresponds to a new spark SQL session + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + println!("pwd: {}", std::env::current_dir().unwrap().display()); + let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string(); + let read_options = CsvReadOptions::default().has_header(true); + + ctx.register_csv("cars", &csv_path, read_options).await?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context().await?; + let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new()); + ctx.register_udwf(simplify_smooth_it.clone()); + + // Use SQL to run the new window function + let df = ctx.sql("SELECT * from cars").await?; + // print the results + df.show().await?; + + let df = ctx + .sql( + "SELECT \ + car, \ + speed, \ + simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\ + time \ + from cars \ + ORDER BY \ + car", + ) + .await?; + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index eb748ed2711a..7f49b03bb2ce 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -134,3 +134,16 @@ pub type AggregateFunctionSimplification = Box< &dyn crate::simplify::SimplifyInfo, ) -> Result, >; + +/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure +/// A closure with two arguments: +/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked +/// * 'info': [crate::simplify::SimplifyInfo] +/// +/// closure returns simplified [Expr] or an error. +pub type WindowFunctionSimplification = Box< + dyn Fn( + crate::expr::WindowFunction, + &dyn crate::simplify::SimplifyInfo, + ) -> Result, +>; diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 5a8373509a40..ce28b444adbc 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -18,8 +18,8 @@ //! [`WindowUDF`]: User Defined Window Functions use crate::{ - Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature, - WindowFrame, + function::WindowFunctionSimplification, Expr, PartitionEvaluator, + PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame, }; use arrow::datatypes::DataType; use datafusion_common::Result; @@ -170,6 +170,13 @@ impl WindowUDF { self.inner.return_type(args) } + /// Do the function rewrite + /// + /// See [`WindowUDFImpl::simplify`] for more details. + pub fn simplify(&self) -> Option { + self.inner.simplify() + } + /// Return a `PartitionEvaluator` for evaluating this window function pub fn partition_evaluator_factory(&self) -> Result> { self.inner.partition_evaluator() @@ -266,6 +273,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Optionally apply per-UDWF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization. The default implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// Example: + /// [`simplify_udwf_expression.rs`]: + /// + /// # Returns + /// [None] if simplify is not defined or, + /// + /// Or, a closure with two arguments: + /// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked + /// * 'info': [crate::simplify::SimplifyInfo] + fn simplify(&self) -> Option { + None + } } /// WindowUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 25504e5c78e7..c87654292a01 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,10 +32,13 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery}; +use datafusion_expr::expr::{ + AggregateFunctionDefinition, InList, InSubquery, WindowFunction, +}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, + WindowFunctionDefinition, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; @@ -1391,6 +1394,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { (_, expr) => Transformed::no(expr), }, + Expr::WindowFunction(WindowFunction { + fun: WindowFunctionDefinition::WindowUDF(ref udwf), + .. + }) => match (udwf.simplify(), expr) { + (Some(simplify_function), Expr::WindowFunction(wf)) => { + Transformed::yes(simplify_function(wf, info)?) + } + (_, expr) => Transformed::no(expr), + }, + // // Rules for Between // @@ -1758,7 +1771,10 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { mod tests { use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; use datafusion_expr::{ - function::{AccumulatorArgs, AggregateFunctionSimplification}, + function::{ + AccumulatorArgs, AggregateFunctionSimplification, + WindowFunctionSimplification, + }, interval_arithmetic::Interval, *, }; @@ -3800,4 +3816,87 @@ mod tests { } } } + + #[test] + fn test_simplify_udwf() { + let udwf = WindowFunctionDefinition::WindowUDF( + WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(), + ); + let window_function_expr = + Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( + udwf, + vec![], + vec![], + vec![], + WindowFrame::new(None), + None, + )); + + let expected = col("result_column"); + assert_eq!(simplify(window_function_expr), expected); + + let udwf = WindowFunctionDefinition::WindowUDF( + WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(), + ); + let window_function_expr = + Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new( + udwf, + vec![], + vec![], + vec![], + WindowFrame::new(None), + None, + )); + + let expected = window_function_expr.clone(); + assert_eq!(simplify(window_function_expr), expected); + } + + /// A Mock UDWF which defines `simplify` to be used in tests + /// related to UDWF simplification + #[derive(Debug, Clone)] + struct SimplifyMockUdwf { + simplify: bool, + } + + impl SimplifyMockUdwf { + /// make simplify method return new expression + fn new_with_simplify() -> Self { + Self { simplify: true } + } + /// make simplify method return no change + fn new_without_simplify() -> Self { + Self { simplify: false } + } + } + + impl WindowUDFImpl for SimplifyMockUdwf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mock_simplify" + } + + fn signature(&self) -> &Signature { + unimplemented!() + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("not needed for tests") + } + + fn simplify(&self) -> Option { + if self.simplify { + Some(Box::new(|_, _| Ok(col("result_column")))) + } else { + None + } + } + + fn partition_evaluator(&self) -> Result> { + unimplemented!("not needed for tests") + } + } }