Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add approx_median() aggregate function #1729

Merged
merged 15 commits into from
Feb 9, 2022
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ This library currently supports many SQL constructs, including
- `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)`
- Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`.
- `WHERE` to filter
- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population)
- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `APPROX_PERCENTILE_CONT`, `APPROX_MEDIAN`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population)
- `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST`

## Supported Functions
Expand Down
1 change: 1 addition & 0 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ enum AggregateFunction {
STDDEV_POP=12;
CORRELATION=13;
APPROX_PERCENTILE_CONT = 14;
APPROX_MEDIAN=15;
}

message AggregateExprNode {
Expand Down
4 changes: 4 additions & 0 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::Correlation => {
protobuf::AggregateFunction::Correlation
}
AggregateFunction::ApproxMedian => {
protobuf::AggregateFunction::ApproxMedian
}
};

let aggregate_expr = protobuf::AggregateExprNode {
Expand Down Expand Up @@ -1340,6 +1343,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::StddevPop => Self::StddevPop,
AggregateFunction::Correlation => Self::Correlation,
AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont,
AggregateFunction::ApproxMedian => Self::ApproxMedian,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::ApproxPercentileCont => {
AggregateFunction::ApproxPercentileCont
}
protobuf::AggregateFunction::ApproxMedian => AggregateFunction::ApproxMedian,
}
}
}
Expand Down
5 changes: 4 additions & 1 deletion datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ use crate::optimizer::limit_push_down::LimitPushDown;
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::projection_push_down::ProjectionPushDown;
use crate::optimizer::simplify_expressions::SimplifyExpressions;
use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use crate::optimizer::to_approx_perc::ToApproxPerc;

use crate::physical_optimizer::coalesce_batches::CoalesceBatches;
use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec;
use crate::physical_optimizer::repartition::Repartition;

use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use crate::logical_plan::plan::Explain;
use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::physical_plan::udf::ScalarUDF;
use crate::physical_plan::ExecutionPlan;
Expand Down Expand Up @@ -925,6 +927,7 @@ impl Default for ExecutionConfig {
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateLimit::new()),
Arc::new(ProjectionPushDown::new()),
Arc::new(ToApproxPerc::new()),
Arc::new(FilterPushDown::new()),
Arc::new(LimitPushDown::new()),
Arc::new(SingleDistinctToGroupBy::new()),
Expand Down
1 change: 1 addition & 0 deletions datafusion/src/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ pub mod optimizer;
pub mod projection_push_down;
pub mod simplify_expressions;
pub mod single_distinct_to_groupby;
pub mod to_approx_perc;
pub mod utils;
161 changes: 161 additions & 0 deletions datafusion/src/optimizer/to_approx_perc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Optimizer rule that rewrites an expression/function into `approx_percentile_cont`

use crate::error::Result;
use crate::execution::context::ExecutionProps;
use crate::logical_plan::plan::Aggregate;
use crate::logical_plan::{Expr, LogicalPlan};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use crate::physical_plan::aggregates;
use crate::scalar::ScalarValue;

/// Optimizer rule that rewrites an expression/function into an
/// `APPROX_PERCENTILE_CONT` call.
///
/// ```text
/// SELECT F1(s)
/// ...
///
/// Into
///
/// SELECT APPROX_PERCENTILE_CONT(s, lit(n)) as "F1(s)"
/// ...
/// ```
///
/// The rule carries no state, so every instance behaves identically.
#[derive(Debug, Default)]
pub struct ToApproxPerc {}

impl ToApproxPerc {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self {}
    }
}

fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
match plan {
LogicalPlan::Aggregate(Aggregate {
input,
aggr_expr,
schema,
group_expr,
}) => {
let new_aggr_expr = aggr_expr
.iter()
.map(|agg_expr| replace_with_percentile(agg_expr).unwrap())
.collect::<Vec<_>>();

Ok(LogicalPlan::Aggregate(Aggregate {
input: input.clone(),
aggr_expr: new_aggr_expr,
schema: schema.clone(),
group_expr: group_expr.clone(),
}))
}
_ => optimize_children(plan),
}
}

/// Optimizes every input of `plan` and rebuilds the node from its own
/// expressions and the optimized inputs.
///
/// Stops at, and returns, the first error raised while optimizing an input.
fn optimize_children(plan: &LogicalPlan) -> Result<LogicalPlan> {
    let node_exprs = plan.expressions();
    let children = plan.inputs();

    let mut optimized_children = Vec::with_capacity(children.len());
    for child in children.iter() {
        optimized_children.push(optimize(child)?);
    }

    utils::from_plan(plan, &node_exprs, &optimized_children)
}

/// Returns a copy of `expr` in which an `ApproxMedian` aggregate call is
/// turned into an `ApproxPercentileCont` call.
///
/// The rewritten call keeps the original arguments and `distinct` flag and
/// gains one trailing argument, the literal `Float64(0.5)`. Every other
/// expression is returned as an unchanged clone.
fn replace_with_percentile(expr: &Expr) -> Result<Expr> {
    let rewritten = match expr {
        Expr::AggregateFunction {
            fun,
            args,
            distinct,
        } if fun == &aggregates::AggregateFunction::ApproxMedian => {
            // Keep the median's own arguments and append the 0.5 percentile.
            let mut percentile_args = args.clone();
            percentile_args.push(Expr::Literal(ScalarValue::Float64(Some(0.5_f64))));

            Expr::AggregateFunction {
                fun: aggregates::AggregateFunction::ApproxPercentileCont,
                args: percentile_args,
                distinct: *distinct,
            }
        }
        untouched => untouched.clone(),
    };

    Ok(rewritten)
}

impl OptimizerRule for ToApproxPerc {
    /// Name under which this rule identifies itself.
    fn name(&self) -> &str {
        "ToApproxPerc"
    }

    /// Applies the rewrite to `plan` by delegating to the module-level
    /// `optimize` function; the execution properties are not consulted.
    fn optimize(
        &self,
        plan: &LogicalPlan,
        _execution_props: &ExecutionProps,
    ) -> Result<LogicalPlan> {
        let rewritten = optimize(plan);
        rewritten
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::logical_plan::{col, LogicalPlanBuilder};
    use crate::physical_plan::aggregates;
    use crate::test::*;

    /// Runs `ToApproxPerc` on `plan` and asserts that the indented,
    /// schema-annotated rendering of the result equals `expected`.
    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
        let rule = ToApproxPerc::new();
        let optimized_plan = rule
            .optimize(plan, &ExecutionProps::new())
            .expect("failed to optimize plan");
        let formatted_plan = optimized_plan.display_indent_schema().to_string();
        assert_eq!(formatted_plan, expected);
    }

    /// An ungrouped `APPROX_MEDIAN(b)` must be rewritten into
    /// `APPROX_PERCENTILE_CONT(b, 0.5)`.
    #[test]
    fn median_1() -> Result<()> {
        let table_scan = test_table_scan()?;
        let expr = Expr::AggregateFunction {
            fun: aggregates::AggregateFunction::ApproxMedian,
            distinct: false,
            args: vec![col("b")],
        };

        let plan = LogicalPlanBuilder::from(table_scan)
            .aggregate(Vec::<Expr>::new(), vec![expr])?
            .build()?;

        // The aggregate expression is rewritten, while the node's schema
        // still carries the original `APPROXMEDIAN(test.b)` column name.
        // NOTE(review): the child node is assumed to be indented by two
        // spaces in the rendered plan — confirm against the plan display code.
        let expected = "Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#test.b, Float64(0.5))]] [APPROXMEDIAN(test.b):UInt32;N]\
        \n  TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";

        assert_optimized_plan_eq(&plan, expected);
        Ok(())
    }
}
80 changes: 77 additions & 3 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ pub enum AggregateFunction {
Correlation,
/// Approximate continuous percentile function
ApproxPercentileCont,
/// Approximate median function
ApproxMedian,
}

impl fmt::Display for AggregateFunction {
Expand Down Expand Up @@ -113,6 +115,7 @@ impl FromStr for AggregateFunction {
"covar_pop" => AggregateFunction::CovariancePop,
"corr" => AggregateFunction::Correlation,
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
"approx_median" => AggregateFunction::ApproxMedian,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -161,6 +164,7 @@ pub fn return_type(
true,
)))),
AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
}
}

Expand Down Expand Up @@ -349,6 +353,18 @@ pub fn create_aggregate_expr(
.to_string(),
));
}
(AggregateFunction::ApproxMedian, false) => {
Arc::new(expressions::ApproxMedian::new(
coerced_phy_exprs[0].clone(),
name,
return_type,
))
}
(AggregateFunction::ApproxMedian, true) => {
return Err(DataFusionError::NotImplemented(
"MEDIAN(DISTINCT) aggregations are not available".to_string(),
));
}
})
}

Expand Down Expand Up @@ -398,7 +414,8 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
| AggregateFunction::Variance
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop => {
| AggregateFunction::StddevPop
| AggregateFunction::ApproxMedian => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
Expand All @@ -422,8 +439,9 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
mod tests {
use super::*;
use crate::physical_plan::expressions::{
ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count,
Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance,
ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, Correlation,
Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum,
Variance,
};
use crate::{error::Result, scalar::ScalarValue};

Expand Down Expand Up @@ -996,6 +1014,62 @@ mod tests {
Ok(())
}

/// `create_aggregate_expr` must build an `ApproxMedian` expression for each
/// listed numeric input type, named after the requested output column and
/// typed like its input.
#[test]
fn test_median_expr() -> Result<()> {
    let numeric_types = vec![
        DataType::UInt32,
        DataType::UInt64,
        DataType::Int32,
        DataType::Int64,
        DataType::Float32,
        DataType::Float64,
    ];
    let aggregate_funcs = vec![AggregateFunction::ApproxMedian];

    for fun in aggregate_funcs {
        for numeric_type in &numeric_types {
            let schema = Schema::new(vec![Field::new("c1", numeric_type.clone(), true)]);
            let column = expressions::Column::new_with_schema("c1", &schema).unwrap();
            let phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(column)];

            let aggregate =
                create_aggregate_expr(&fun, false, &phy_exprs[0..1], &schema, "c1")?;

            if !matches!(fun, AggregateFunction::ApproxMedian) {
                continue;
            }

            assert!(aggregate.as_any().is::<ApproxMedian>());
            assert_eq!("c1", aggregate.name());

            let expected_field = Field::new("c1", numeric_type.clone(), true);
            assert_eq!(expected_field, aggregate.field().unwrap());
        }
    }
    Ok(())
}

/// `return_type` for `ApproxMedian` must reject `Utf8` and `Decimal` inputs
/// and return the input type unchanged for `Int32`.
#[test]
fn test_median() -> Result<()> {
    // String input is not accepted.
    let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Utf8]);
    assert!(observed.is_err());

    // The result type of a supported input is the input type itself.
    let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Int32])?;
    assert_eq!(DataType::Int32, observed);

    // Decimal input is not accepted.
    let observed = return_type(
        &AggregateFunction::ApproxMedian,
        &[DataType::Decimal(10, 6)],
    );
    assert!(observed.is_err());

    Ok(())
}

#[test]
fn test_min_max() -> Result<()> {
let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
Expand Down
16 changes: 13 additions & 3 deletions datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::aggregates::AggregateFunction;
use crate::physical_plan::expressions::{
is_avg_support_arg_type, is_correlation_support_arg_type,
is_covariance_support_arg_type, is_stddev_support_arg_type, is_sum_support_arg_type,
is_variance_support_arg_type, try_cast,
is_approx_median_support_arg_type, is_avg_support_arg_type,
is_correlation_support_arg_type, is_covariance_support_arg_type,
is_stddev_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type,
try_cast,
};
use crate::physical_plan::functions::{Signature, TypeSignature};
use crate::physical_plan::PhysicalExpr;
Expand Down Expand Up @@ -154,6 +155,15 @@ pub(crate) fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxMedian => {
if !is_approx_median_support_arg_type(&input_types[0]) {
realno marked this conversation as resolved.
Show resolved Hide resolved
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
}
}

Expand Down
Loading