diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index fb006e532ff3..0b0d364ca7c7 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -177,6 +177,7 @@ enum AggregateFunction { STDDEV_POP=12; CORRELATION=13; APPROX_PERCENTILE_CONT = 14; + APPROX_MEDIAN=15; } message AggregateExprNode { diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 4b13ce577cfb..84910b2c31fa 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1100,6 +1100,9 @@ impl TryInto for &Expr { AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } + AggregateFunction::ApproxMedian => { + protobuf::AggregateFunction::ApproxMedian + } }; let aggregate_expr = protobuf::AggregateExprNode { @@ -1340,6 +1343,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::StddevPop => Self::StddevPop, AggregateFunction::Correlation => Self::Correlation, AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, + AggregateFunction::ApproxMedian => Self::ApproxMedian, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index 64a60dc4da5d..f7b0b9436c4c 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -132,6 +132,7 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ApproxPercentileCont => { AggregateFunction::ApproxPercentileCont } + protobuf::AggregateFunction::ApproxMedian => AggregateFunction::ApproxMedian, } } } diff --git a/datafusion-expr/src/aggregate_function.rs b/datafusion-expr/src/aggregate_function.rs index 8f12e88bf1a2..4e03445f7209 100644 --- a/datafusion-expr/src/aggregate_function.rs +++ b/datafusion-expr/src/aggregate_function.rs @@ -51,6 +51,8 @@ pub enum AggregateFunction { Correlation, /// Approximate continuous percentile function ApproxPercentileCont, + /// ApproxMedian + ApproxMedian, } impl fmt::Display for AggregateFunction { @@ -82,6 +84,7 @@ impl FromStr for AggregateFunction { "covar_pop" => AggregateFunction::CovariancePop, "corr" => AggregateFunction::Correlation, "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, + "approx_median" => AggregateFunction::ApproxMedian, _ => { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 2f8663ecae01..e3d0ba93c3f0 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -70,13 +70,15 @@ use crate::optimizer::limit_push_down::LimitPushDown; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; use crate::optimizer::simplify_expressions::SimplifyExpressions; +use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; +use crate::optimizer::to_approx_perc::ToApproxPerc; + use crate::physical_optimizer::coalesce_batches::CoalesceBatches; use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec; use crate::physical_optimizer::repartition::Repartition; use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::logical_plan::plan::Explain; -use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use crate::physical_plan::planner::DefaultPhysicalPlanner; use crate::physical_plan::udf::ScalarUDF; use crate::physical_plan::ExecutionPlan; @@ -913,6 +915,10 @@ impl Default for ExecutionConfig { Arc::new(FilterPushDown::new()), Arc::new(LimitPushDown::new()), Arc::new(SingleDistinctToGroupBy::new()), + // ToApproxPerc must be applied last because + // it rewrites only the function and may interfere with + // other rules + Arc::new(ToApproxPerc::new()), ], physical_optimizers: vec![ Arc::new(AggregateStatistics::new()), diff --git a/datafusion/src/optimizer/mod.rs b/datafusion/src/optimizer/mod.rs index 984cbee90947..418eaad4bc5c 100644 --- a/datafusion/src/optimizer/mod.rs +++ b/datafusion/src/optimizer/mod.rs @@ -27,4 +27,5 @@ pub mod optimizer; pub mod projection_push_down; pub mod simplify_expressions; pub mod single_distinct_to_groupby; +pub mod to_approx_perc; pub mod utils; diff --git a/datafusion/src/optimizer/to_approx_perc.rs b/datafusion/src/optimizer/to_approx_perc.rs new file mode 100644 index 000000000000..c33c3f67602a --- /dev/null +++ b/datafusion/src/optimizer/to_approx_perc.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! espression/function to approx_percentile optimizer rule + +use crate::error::Result; +use crate::execution::context::ExecutionProps; +use crate::logical_plan::plan::Aggregate; +use crate::logical_plan::{Expr, LogicalPlan}; +use crate::optimizer::optimizer::OptimizerRule; +use crate::optimizer::utils; +use crate::physical_plan::aggregates; +use crate::scalar::ScalarValue; + +/// espression/function to approx_percentile optimizer rule +/// ```text +/// SELECT F1(s) +/// ... +/// +/// Into +/// +/// SELECT APPROX_PERCENTILE_CONT(s, lit(n)) as "F1(s)" +/// ... +/// ``` +pub struct ToApproxPerc {} + +impl ToApproxPerc { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl Default for ToApproxPerc { + fn default() -> Self { + Self::new() + } +} + +fn optimize(plan: &LogicalPlan) -> Result { + match plan { + LogicalPlan::Aggregate(Aggregate { + input, + aggr_expr, + schema, + group_expr, + }) => { + let new_aggr_expr = aggr_expr + .iter() + .map(|agg_expr| replace_with_percentile(agg_expr).unwrap()) + .collect::>(); + + Ok(LogicalPlan::Aggregate(Aggregate { + input: input.clone(), + aggr_expr: new_aggr_expr, + schema: schema.clone(), + group_expr: group_expr.clone(), + })) + } + _ => optimize_children(plan), + } +} + +fn optimize_children(plan: &LogicalPlan) -> Result { + let expr = plan.expressions(); + let inputs = plan.inputs(); + let new_inputs = inputs + .iter() + .map(|plan| optimize(plan)) + .collect::>>()?; + utils::from_plan(plan, &expr, &new_inputs) +} + +fn replace_with_percentile(expr: &Expr) -> Result { + match expr { + Expr::AggregateFunction { + fun, + args, + distinct, + } => { + let mut new_args = args.clone(); + let mut new_func = fun.clone(); + if fun == &aggregates::AggregateFunction::ApproxMedian { + new_args.push(Expr::Literal(ScalarValue::Float64(Some(0.5_f64)))); + new_func = aggregates::AggregateFunction::ApproxPercentileCont; + } + + Ok(Expr::AggregateFunction { + fun: new_func, + args: new_args, + distinct: *distinct, + }) + } + _ => Ok(expr.clone()), + } +} + +impl OptimizerRule for ToApproxPerc { + fn optimize( + &self, + plan: &LogicalPlan, + _execution_props: &ExecutionProps, + ) -> Result { + optimize(plan) + } + fn name(&self) -> &str { + "ToApproxPerc" + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::{col, LogicalPlanBuilder}; + use crate::physical_plan::aggregates; + use crate::test::*; + + fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { + let rule = ToApproxPerc::new(); + let optimized_plan = rule + .optimize(plan, &ExecutionProps::new()) + .expect("failed to optimize plan"); + let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); + assert_eq!(formatted_plan, expected); + } + + #[test] + fn median_1() -> Result<()> { + let table_scan = test_table_scan()?; + let expr = Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxMedian, + distinct: false, + args: vec![col("b")], + }; + + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![expr])? + .build()?; + + // Rewrite to use approx_percentile + let expected = "Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#test.b, Float64(0.5))]] [APPROXMEDIAN(test.b):UInt32;N]\ + \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq(&plan, expected); + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index a1531d4a7b83..5de8a9da71ad 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -89,6 +89,7 @@ pub fn return_type( true, )))), AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), + AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), } } @@ -277,6 +278,18 @@ pub fn create_aggregate_expr( .to_string(), )); } + (AggregateFunction::ApproxMedian, false) => { + Arc::new(expressions::ApproxMedian::new( + coerced_phy_exprs[0].clone(), + name, + return_type, + )) + } + (AggregateFunction::ApproxMedian, true) => { + return Err(DataFusionError::NotImplemented( + "MEDIAN(DISTINCT) aggregations are not available".to_string(), + )); + } }) } @@ -326,7 +339,8 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature { | AggregateFunction::Variance | AggregateFunction::VariancePop | AggregateFunction::Stddev - | AggregateFunction::StddevPop => { + | AggregateFunction::StddevPop + | AggregateFunction::ApproxMedian => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::Covariance | AggregateFunction::CovariancePop => { @@ -350,8 +364,9 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::physical_plan::expressions::{ - ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count, - Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, + ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, Correlation, + Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, + Variance, }; use crate::{error::Result, scalar::ScalarValue}; @@ -924,6 +939,62 @@ mod tests { Ok(()) } + #[test] + fn test_median_expr() -> Result<()> { + let funcs = vec![AggregateFunction::ApproxMedian]; + let data_types = vec![ + DataType::UInt32, + DataType::UInt64, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + + if fun == AggregateFunction::ApproxMedian { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + } + } + Ok(()) + } + + #[test] + fn test_median() -> Result<()> { + let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Utf8]); + assert!(observed.is_err()); + + let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Int32])?; + assert_eq!(DataType::Int32, observed); + + let observed = return_type( + &AggregateFunction::ApproxMedian, + &[DataType::Decimal(10, 6)], + ); + assert!(observed.is_err()); + + Ok(()) + } + #[test] fn test_min_max() -> Result<()> { let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?; diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs index bae2de74c7b7..47d406579241 100644 --- a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -154,6 +154,15 @@ pub(crate) fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::ApproxMedian => { + if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } } } diff --git a/datafusion/src/physical_plan/expressions/approx_median.rs b/datafusion/src/physical_plan/expressions/approx_median.rs new file mode 100644 index 000000000000..2ca585759c6b --- /dev/null +++ b/datafusion/src/physical_plan/expressions/approx_median.rs @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution + +use std::any::Any; +use std::sync::Arc; + +use crate::error::Result; +use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use arrow::{datatypes::DataType, datatypes::Field}; + +/// MEDIAN aggregate expression +#[derive(Debug)] +pub struct ApproxMedian { + name: String, + expr: Arc, + data_type: DataType, +} + +impl ApproxMedian { + /// Create a new APPROX_MEDIAN aggregate function + pub fn new( + expr: Arc, + name: impl Into, + data_type: DataType, + ) -> Self { + Self { + name: name.into(), + expr, + data_type, + } + } +} + +impl AggregateExpr for ApproxMedian { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result> { + unimplemented!() + } + + fn state_fields(&self) -> Result> { + unimplemented!() + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn name(&self) -> &str { + &self.name + } +} diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 9344fbd6b1bc..06afe004ff34 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -44,6 +44,7 @@ mod lead_lag; mod literal; #[macro_use] mod min_max; +mod approx_median; mod correlation; mod covariance; mod distinct_expressions; @@ -65,6 +66,7 @@ pub mod helpers { } pub use approx_distinct::ApproxDistinct; +pub(crate) use approx_median::ApproxMedian; pub use approx_percentile_cont::{ is_approx_percentile_cont_supported_arg_type, ApproxPercentileCont, }; diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index fd1d15cc0ca7..528386d0ecba 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -219,6 +219,39 @@ async fn csv_query_stddev_6() -> Result<()> { Ok(()) } +#[tokio::test] +async fn csv_query_median_1() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT approx_median(c2) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["3"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_median_2() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT approx_median(c6) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["1146409980542786560"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_median_3() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx).await?; + let sql = "SELECT approx_median(c12) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["0.5550065410522981"]]; + assert_float_eq(&expected, &actual); + Ok(()) +} + #[tokio::test] async fn csv_query_external_table_count() { let mut ctx = ExecutionContext::new();