From 933fec845e02bb8983a0b932cfa12ebb27054748 Mon Sep 17 00:00:00 2001 From: Takahiro Ebato Date: Fri, 27 Dec 2024 15:32:56 +0900 Subject: [PATCH] Consolidate example: simplify_udaf_expression.rs into advanced_udaf.rs (#13905) --- datafusion-examples/examples/advanced_udaf.rs | 185 +++++++++++++----- .../examples/simplify_udaf_expression.rs | 176 ----------------- 2 files changed, 132 insertions(+), 229 deletions(-) delete mode 100644 datafusion-examples/examples/simplify_udaf_expression.rs diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 414596bdc678..a914cea4a928 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -31,7 +31,9 @@ use datafusion::error::Result; use datafusion::prelude::*; use datafusion_common::{cast::as_float64_array, ScalarValue}; use datafusion_expr::{ - function::{AccumulatorArgs, StateFieldsArgs}, + expr::AggregateFunction, + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + simplify::SimplifyInfo, Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, }; @@ -197,40 +199,6 @@ impl Accumulator for GeometricMean { } } -// create local session context with an in-memory table -fn create_context() -> Result { - use datafusion::datasource::MemTable; - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float32, false), - ])); - - // define data in two partitions - let batch1 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), - Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), - ], - )?; - let batch2 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![64.0])), - Arc::new(Float32Array::from(vec![2.0])), - ], - )?; - - // declare a new context. In spark API, this corresponds to a new spark SQLsession - let ctx = SessionContext::new(); - - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - Ok(ctx) -} - // Define a `GroupsAccumulator` for GeometricMean /// which handles accumulator state for multiple groups at once. /// This API is significantly more complicated than `Accumulator`, which manages @@ -399,35 +367,146 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator { } } +/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user +/// defined aggregate function with a different expression which is defined in the `simplify` method. 
+#[derive(Debug, Clone)] +struct SimplifiedGeoMeanUdaf { + signature: Signature, +} + +impl SimplifiedGeoMeanUdaf { + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "simplified_geo_mean" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + unimplemented!("should not get here"); + } + + /// Optionally replaces a UDAF with another expression during query optimization. + fn simplify(&self) -> Option { + let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { + // Replaces the UDAF with `GeoMeanUdaf` as a placeholder example to demonstrate the `simplify` method. + // In real-world scenarios, you might create UDFs from built-in expressions. + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + Arc::new(AggregateUDF::from(GeoMeanUdaf::new())), + aggregate_function.args, + aggregate_function.distinct, + aggregate_function.filter, + aggregate_function.order_by, + aggregate_function.null_treatment, + ))) + }; + Some(Box::new(simplify)) + } +} + +// create local session context with an in-memory table +fn create_context() -> Result { + use datafusion::datasource::MemTable; + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), + Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![64.0])), + Arc::new(Float32Array::from(vec![2.0])), + ], + )?; + + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + #[tokio::main] async fn main() -> Result<()> { let ctx = create_context()?; - // create the AggregateUDF - let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new()); - ctx.register_udaf(geometric_mean.clone()); + let geo_mean_udf = AggregateUDF::from(GeoMeanUdaf::new()); + let simplified_geo_mean_udf = AggregateUDF::from(SimplifiedGeoMeanUdaf::new()); + + for (udf, udf_name) in [ + (geo_mean_udf, "geo_mean"), + (simplified_geo_mean_udf, "simplified_geo_mean"), + ] { + ctx.register_udaf(udf.clone()); - let sql_df = ctx.sql("SELECT geo_mean(a) FROM t group by b").await?; - sql_df.show().await?; + let sql_df = ctx + .sql(&format!("SELECT {}(a) FROM t GROUP BY b", udf_name)) + .await?; + sql_df.show().await?; - // get a DataFrame from the context - // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. 
- let df = ctx.table("t").await?; + // get a DataFrame from the context + // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. + let df = ctx.table("t").await?; - // perform the aggregation - let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?; + // perform the aggregation + let df = df.aggregate(vec![], vec![udf.call(vec![col("a")])])?; - // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature. + // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature. - // execute the query - let results = df.collect().await?; + // execute the query + let results = df.collect().await?; - // downcast the array to the expected type - let result = as_float64_array(results[0].column(0))?; + // downcast the array to the expected type + let result = as_float64_array(results[0].column(0))?; - // verify that the calculation is correct - assert!((result.value(0) - 8.0).abs() < f64::EPSILON); - println!("The geometric mean of [2,4,8,64] is {}", result.value(0)); + // verify that the calculation is correct + assert!((result.value(0) - 8.0).abs() < f64::EPSILON); + println!("The geometric mean of [2,4,8,64] is {}", result.value(0)); + } Ok(()) } diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs deleted file mode 100644 index 52a27317e3c3..000000000000 --- a/datafusion-examples/examples/simplify_udaf_expression.rs +++ /dev/null @@ -1,176 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use std::{any::Any, sync::Arc}; - -use arrow_schema::{Field, Schema}; - -use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch}; -use datafusion::error::Result; -use datafusion::functions_aggregate::average::avg_udaf; -use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; -use datafusion::{assert_batches_eq, prelude::*}; -use datafusion_common::cast::as_float64_array; -use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs}; -use datafusion_expr::simplify::SimplifyInfo; -use datafusion_expr::{ - expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF, - AggregateUDFImpl, GroupsAccumulator, Signature, -}; - -/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user -/// defined aggregate function with a different expression which is defined in the `simplify` method. 
- -#[derive(Debug, Clone)] -struct BetterAvgUdaf { - signature: Signature, -} - -impl BetterAvgUdaf { - /// Create a new instance of the GeoMeanUdaf struct - fn new() -> Self { - Self { - signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), - } - } -} - -impl AggregateUDFImpl for BetterAvgUdaf { - fn as_any(&self) -> &dyn Any { - self - } - - fn name(&self) -> &str { - "better_avg" - } - - fn signature(&self) -> &Signature { - &self.signature - } - - fn return_type(&self, _arg_types: &[DataType]) -> Result { - Ok(DataType::Float64) - } - - fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { - unimplemented!("should not be invoked") - } - - fn state_fields(&self, _args: StateFieldsArgs) -> Result> { - unimplemented!("should not be invoked") - } - - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { - true - } - - fn create_groups_accumulator( - &self, - _args: AccumulatorArgs, - ) -> Result> { - unimplemented!("should not get here"); - } - - // we override method, to return new expression which would substitute - // user defined function call - fn simplify(&self) -> Option { - // as an example for this functionality we replace UDF function - // with build-in aggregate function to illustrate the use - let simplify = |aggregate_function: AggregateFunction, _: &dyn SimplifyInfo| { - Ok(Expr::AggregateFunction(AggregateFunction::new_udf( - avg_udaf(), - // yes it is the same Avg, `BetterAvgUdaf` was just a - // marketing pitch :) - aggregate_function.args, - aggregate_function.distinct, - aggregate_function.filter, - aggregate_function.order_by, - aggregate_function.null_treatment, - ))) - }; - - Some(Box::new(simplify)) - } -} - -// create local session context with an in-memory table -fn create_context() -> Result { - use datafusion::datasource::MemTable; - // define a schema. - let schema = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Float32, false), - Field::new("b", DataType::Float32, false), - ])); - - // define data in two partitions - let batch1 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), - Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), - ], - )?; - let batch2 = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Float32Array::from(vec![16.0])), - Arc::new(Float32Array::from(vec![2.0])), - ], - )?; - - let ctx = SessionContext::new(); - - // declare a table in memory. In spark API, this corresponds to createDataFrame(...). - let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; - ctx.register_table("t", Arc::new(provider))?; - Ok(ctx) -} - -#[tokio::main] -async fn main() -> Result<()> { - let ctx = create_context()?; - - let better_avg = AggregateUDF::from(BetterAvgUdaf::new()); - ctx.register_udaf(better_avg.clone()); - - let result = ctx - .sql("SELECT better_avg(a) FROM t group by b") - .await? - .collect() - .await?; - - let expected = [ - "+-----------------+", - "| better_avg(t.a) |", - "+-----------------+", - "| 7.5 |", - "+-----------------+", - ]; - - assert_batches_eq!(expected, &result); - - let df = ctx.table("t").await?; - let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?; - - let results = df.collect().await?; - let result = as_float64_array(results[0].column(0))?; - - assert!((result.value(0) - 7.5).abs() < f64::EPSILON); - println!("The average of [2,4,8,16] is {}", result.value(0)); - - Ok(()) -}
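
Note (not part of the commit above): the rewrite performed by `AggregateUDFImpl::simplify` happens during logical optimization, so it can be observed without fully executing the query. Below is a minimal sketch that reuses `create_context`, `GeoMeanUdaf`, and `SimplifiedGeoMeanUdaf` exactly as added in advanced_udaf.rs by this patch; the `EXPLAIN` query is an assumed convenience for printing the optimized plan and is not something the consolidated example itself runs.

    // Sketch only: assumes the imports and definitions added to advanced_udaf.rs above.
    let ctx = create_context()?;
    ctx.register_udaf(AggregateUDF::from(GeoMeanUdaf::new()));
    ctx.register_udaf(AggregateUDF::from(SimplifiedGeoMeanUdaf::new()));

    // EXPLAIN prints the logical and physical plans. Because `simplify` replaces the
    // `simplified_geo_mean` aggregate with a call to `GeoMeanUdaf`, the optimized plan
    // should show the geo_mean aggregate rather than the placeholder UDAF.
    ctx.sql("EXPLAIN SELECT simplified_geo_mean(a) FROM t GROUP BY b")
        .await?
        .show()
        .await?;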