Skip to content

Commit

Permalink
feature: Add a WindowUDFImpl::simplify() API (#9906)
Browse files Browse the repository at this point in the history
* feature: Add a WindowUDFImpl::simplfy() API

Signed-off-by: guojidan <[email protected]>

* fix doc

Signed-off-by: guojidan <[email protected]>

* fix fmt

Signed-off-by: guojidan <[email protected]>

---------

Signed-off-by: guojidan <[email protected]>
  • Loading branch information
guojidan authored May 30, 2024
1 parent 77352b2 commit 3d00760
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 4 deletions.
142 changes: 142 additions & 0 deletions datafusion-examples/examples/simplify_udwf_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::any::Any;

use arrow_schema::DataType;
use datafusion::execution::context::SessionContext;
use datafusion::{error::Result, execution::options::CsvReadOptions};
use datafusion_expr::function::WindowFunctionSimplification;
use datafusion_expr::{
expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr,
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
};

/// This UDWF will show how to use the WindowUDFImpl::simplify() API
#[derive(Debug, Clone)]
struct SimplifySmoothItUdf {
signature: Signature,
}

impl SimplifySmoothItUdf {
fn new() -> Self {
Self {
signature: Signature::exact(
// this function will always take one arguments of type f64
vec![DataType::Float64],
// this function is deterministic and will always return the same
// result for the same input
Volatility::Immutable,
),
}
}
}
impl WindowUDFImpl for SimplifySmoothItUdf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"simplify_smooth_it"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
todo!()
}

/// this function will simplify `SimplifySmoothItUdf` to `SmoothItUdf`.
fn simplify(&self) -> Option<WindowFunctionSimplification> {
// Ok(ExprSimplifyResult::Simplified(Expr::WindowFunction(
// WindowFunction {
// fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
// AggregateFunction::Avg,
// ),
// args,
// partition_by: partition_by.to_vec(),
// order_by: order_by.to_vec(),
// window_frame: window_frame.clone(),
// null_treatment: *null_treatment,
// },
// )))
let simplify = |window_function: datafusion_expr::expr::WindowFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::WindowFunction(WindowFunction {
fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
AggregateFunction::Avg,
),
args: window_function.args,
partition_by: window_function.partition_by,
order_by: window_function.order_by,
window_frame: window_function.window_frame,
null_treatment: window_function.null_treatment,
}))
};

Some(Box::new(simplify))
}
}

// create local execution context with `cars.csv` registered as a table named `cars`
async fn create_context() -> Result<SessionContext> {
// declare a new context. In spark API, this corresponds to a new spark SQL session
let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
println!("pwd: {}", std::env::current_dir().unwrap().display());
let csv_path = "../../datafusion/core/tests/data/cars.csv".to_string();
let read_options = CsvReadOptions::default().has_header(true);

ctx.register_csv("cars", &csv_path, read_options).await?;
Ok(ctx)
}

#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context().await?;
let simplify_smooth_it = WindowUDF::from(SimplifySmoothItUdf::new());
ctx.register_udwf(simplify_smooth_it.clone());

// Use SQL to run the new window function
let df = ctx.sql("SELECT * from cars").await?;
// print the results
df.show().await?;

let df = ctx
.sql(
"SELECT \
car, \
speed, \
simplify_smooth_it(speed) OVER (PARTITION BY car ORDER BY time) AS smooth_speed,\
time \
from cars \
ORDER BY \
car",
)
.await?;
// print the results
df.show().await?;

Ok(())
}
13 changes: 13 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,16 @@ pub type AggregateFunctionSimplification = Box<
&dyn crate::simplify::SimplifyInfo,
) -> Result<Expr>,
>;

/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure
/// A closure with two arguments:
/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked
/// * 'info': [crate::simplify::SimplifyInfo]
///
/// closure returns simplified [Expr] or an error.
pub type WindowFunctionSimplification = Box<
dyn Fn(
crate::expr::WindowFunction,
&dyn crate::simplify::SimplifyInfo,
) -> Result<Expr>,
>;
34 changes: 32 additions & 2 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! [`WindowUDF`]: User Defined Window Functions

use crate::{
Expr, PartitionEvaluator, PartitionEvaluatorFactory, ReturnTypeFunction, Signature,
WindowFrame,
function::WindowFunctionSimplification, Expr, PartitionEvaluator,
PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
Expand Down Expand Up @@ -170,6 +170,13 @@ impl WindowUDF {
self.inner.return_type(args)
}

/// Do the function rewrite
///
/// See [`WindowUDFImpl::simplify`] for more details.
pub fn simplify(&self) -> Option<WindowFunctionSimplification> {
self.inner.simplify()
}

/// Return a `PartitionEvaluator` for evaluating this window function
pub fn partition_evaluator_factory(&self) -> Result<Box<dyn PartitionEvaluator>> {
self.inner.partition_evaluator()
Expand Down Expand Up @@ -266,6 +273,29 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
fn aliases(&self) -> &[String] {
&[]
}

/// Optionally apply per-UDWF simplification / rewrite rules.
///
/// This can be used to apply function specific simplification rules during
/// optimization. The default implementation does nothing.
///
/// Note that DataFusion handles simplifying arguments and "constant
/// folding" (replacing a function call with constant arguments such as
/// `my_add(1,2) --> 3` ). Thus, there is no need to implement such
/// optimizations manually for specific UDFs.
///
/// Example:
/// [`simplify_udwf_expression.rs`]: <https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simplify_udwf_expression.rs>
///
/// # Returns
/// [None] if simplify is not defined or,
///
/// Or, a closure with two arguments:
/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked
/// * 'info': [crate::simplify::SimplifyInfo]
fn simplify(&self) -> Option<WindowFunctionSimplification> {
None
}
}

/// WindowUDF that adds an alias to the underlying function. It is better to
Expand Down
103 changes: 101 additions & 2 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery};
use datafusion_expr::expr::{
AggregateFunctionDefinition, InList, InSubquery, WindowFunction,
};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
WindowFunctionDefinition,
};
use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
Expand Down Expand Up @@ -1391,6 +1394,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
(_, expr) => Transformed::no(expr),
},

Expr::WindowFunction(WindowFunction {
fun: WindowFunctionDefinition::WindowUDF(ref udwf),
..
}) => match (udwf.simplify(), expr) {
(Some(simplify_function), Expr::WindowFunction(wf)) => {
Transformed::yes(simplify_function(wf, info)?)
}
(_, expr) => Transformed::no(expr),
},

//
// Rules for Between
//
Expand Down Expand Up @@ -1758,7 +1771,10 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{
function::{AccumulatorArgs, AggregateFunctionSimplification},
function::{
AccumulatorArgs, AggregateFunctionSimplification,
WindowFunctionSimplification,
},
interval_arithmetic::Interval,
*,
};
Expand Down Expand Up @@ -3800,4 +3816,87 @@ mod tests {
}
}
}

#[test]
fn test_simplify_udwf() {
let udwf = WindowFunctionDefinition::WindowUDF(
WindowUDF::new_from_impl(SimplifyMockUdwf::new_with_simplify()).into(),
);
let window_function_expr =
Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new(
udwf,
vec![],
vec![],
vec![],
WindowFrame::new(None),
None,
));

let expected = col("result_column");
assert_eq!(simplify(window_function_expr), expected);

let udwf = WindowFunctionDefinition::WindowUDF(
WindowUDF::new_from_impl(SimplifyMockUdwf::new_without_simplify()).into(),
);
let window_function_expr =
Expr::WindowFunction(datafusion_expr::expr::WindowFunction::new(
udwf,
vec![],
vec![],
vec![],
WindowFrame::new(None),
None,
));

let expected = window_function_expr.clone();
assert_eq!(simplify(window_function_expr), expected);
}

/// A Mock UDWF which defines `simplify` to be used in tests
/// related to UDWF simplification
#[derive(Debug, Clone)]
struct SimplifyMockUdwf {
simplify: bool,
}

impl SimplifyMockUdwf {
/// make simplify method return new expression
fn new_with_simplify() -> Self {
Self { simplify: true }
}
/// make simplify method return no change
fn new_without_simplify() -> Self {
Self { simplify: false }
}
}

impl WindowUDFImpl for SimplifyMockUdwf {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"mock_simplify"
}

fn signature(&self) -> &Signature {
unimplemented!()
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
unimplemented!("not needed for tests")
}

fn simplify(&self) -> Option<WindowFunctionSimplification> {
if self.simplify {
Some(Box::new(|_, _| Ok(col("result_column"))))
} else {
None
}
}

fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
unimplemented!("not needed for tests")
}
}
}

0 comments on commit 3d00760

Please sign in to comment.