Skip to content

Commit

Permalink
Add approx-median operator (#1729)
Browse files Browse the repository at this point in the history
* add median operator

* update doc

* rename median to approx_median

* rename median to approx_median

* add doc

* test optimizer

* try rewriting logical plan

* move rewrite rule to earlier stages

* fix lint

* move the rule after projection push down

* get ready to merge

* remove unused function

* remove commented out code

* Update datafusion/src/optimizer/to_approx_perc.rs

Co-authored-by: Andrew Lamb <[email protected]>

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
realno and alamb authored Feb 9, 2022
1 parent 59ecf2b commit 6e02d2d
Show file tree
Hide file tree
Showing 12 changed files with 371 additions and 4 deletions.
1 change: 1 addition & 0 deletions ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ enum AggregateFunction {
STDDEV_POP=12;
CORRELATION=13;
APPROX_PERCENTILE_CONT = 14;
APPROX_MEDIAN=15;
}

message AggregateExprNode {
Expand Down
4 changes: 4 additions & 0 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::Correlation => {
protobuf::AggregateFunction::Correlation
}
AggregateFunction::ApproxMedian => {
protobuf::AggregateFunction::ApproxMedian
}
};

let aggregate_expr = protobuf::AggregateExprNode {
Expand Down Expand Up @@ -1340,6 +1343,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::StddevPop => Self::StddevPop,
AggregateFunction::Correlation => Self::Correlation,
AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont,
AggregateFunction::ApproxMedian => Self::ApproxMedian,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::ApproxPercentileCont => {
AggregateFunction::ApproxPercentileCont
}
protobuf::AggregateFunction::ApproxMedian => AggregateFunction::ApproxMedian,
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions datafusion-expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub enum AggregateFunction {
Correlation,
/// Approximate continuous percentile function
ApproxPercentileCont,
/// ApproxMedian
ApproxMedian,
}

impl fmt::Display for AggregateFunction {
Expand Down Expand Up @@ -82,6 +84,7 @@ impl FromStr for AggregateFunction {
"covar_pop" => AggregateFunction::CovariancePop,
"corr" => AggregateFunction::Correlation,
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
"approx_median" => AggregateFunction::ApproxMedian,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down
8 changes: 7 additions & 1 deletion datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ use crate::optimizer::limit_push_down::LimitPushDown;
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::projection_push_down::ProjectionPushDown;
use crate::optimizer::simplify_expressions::SimplifyExpressions;
use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use crate::optimizer::to_approx_perc::ToApproxPerc;

use crate::physical_optimizer::coalesce_batches::CoalesceBatches;
use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec;
use crate::physical_optimizer::repartition::Repartition;

use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use crate::logical_plan::plan::Explain;
use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::physical_plan::udf::ScalarUDF;
use crate::physical_plan::ExecutionPlan;
Expand Down Expand Up @@ -913,6 +915,10 @@ impl Default for ExecutionConfig {
Arc::new(FilterPushDown::new()),
Arc::new(LimitPushDown::new()),
Arc::new(SingleDistinctToGroupBy::new()),
// ToApproxPerc must be applied last because
// it rewrites only the function and may interfere with
// other rules
Arc::new(ToApproxPerc::new()),
],
physical_optimizers: vec![
Arc::new(AggregateStatistics::new()),
Expand Down
1 change: 1 addition & 0 deletions datafusion/src/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ pub mod optimizer;
pub mod projection_push_down;
pub mod simplify_expressions;
pub mod single_distinct_to_groupby;
pub mod to_approx_perc;
pub mod utils;
161 changes: 161 additions & 0 deletions datafusion/src/optimizer/to_approx_perc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! espression/function to approx_percentile optimizer rule

use crate::error::Result;
use crate::execution::context::ExecutionProps;
use crate::logical_plan::plan::Aggregate;
use crate::logical_plan::{Expr, LogicalPlan};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use crate::physical_plan::aggregates;
use crate::scalar::ScalarValue;

/// espression/function to approx_percentile optimizer rule
/// ```text
/// SELECT F1(s)
/// ...
///
/// Into
///
/// SELECT APPROX_PERCENTILE_CONT(s, lit(n)) as "F1(s)"
/// ...
/// ```
pub struct ToApproxPerc {}

impl ToApproxPerc {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}

impl Default for ToApproxPerc {
fn default() -> Self {
Self::new()
}
}

fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
match plan {
LogicalPlan::Aggregate(Aggregate {
input,
aggr_expr,
schema,
group_expr,
}) => {
let new_aggr_expr = aggr_expr
.iter()
.map(|agg_expr| replace_with_percentile(agg_expr).unwrap())
.collect::<Vec<_>>();

Ok(LogicalPlan::Aggregate(Aggregate {
input: input.clone(),
aggr_expr: new_aggr_expr,
schema: schema.clone(),
group_expr: group_expr.clone(),
}))
}
_ => optimize_children(plan),
}
}

fn optimize_children(plan: &LogicalPlan) -> Result<LogicalPlan> {
let expr = plan.expressions();
let inputs = plan.inputs();
let new_inputs = inputs
.iter()
.map(|plan| optimize(plan))
.collect::<Result<Vec<_>>>()?;
utils::from_plan(plan, &expr, &new_inputs)
}

fn replace_with_percentile(expr: &Expr) -> Result<Expr> {
match expr {
Expr::AggregateFunction {
fun,
args,
distinct,
} => {
let mut new_args = args.clone();
let mut new_func = fun.clone();
if fun == &aggregates::AggregateFunction::ApproxMedian {
new_args.push(Expr::Literal(ScalarValue::Float64(Some(0.5_f64))));
new_func = aggregates::AggregateFunction::ApproxPercentileCont;
}

Ok(Expr::AggregateFunction {
fun: new_func,
args: new_args,
distinct: *distinct,
})
}
_ => Ok(expr.clone()),
}
}

impl OptimizerRule for ToApproxPerc {
fn optimize(
&self,
plan: &LogicalPlan,
_execution_props: &ExecutionProps,
) -> Result<LogicalPlan> {
optimize(plan)
}
fn name(&self) -> &str {
"ToApproxPerc"
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::logical_plan::{col, LogicalPlanBuilder};
use crate::physical_plan::aggregates;
use crate::test::*;

fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
let rule = ToApproxPerc::new();
let optimized_plan = rule
.optimize(plan, &ExecutionProps::new())
.expect("failed to optimize plan");
let formatted_plan = format!("{}", optimized_plan.display_indent_schema());
assert_eq!(formatted_plan, expected);
}

#[test]
fn median_1() -> Result<()> {
let table_scan = test_table_scan()?;
let expr = Expr::AggregateFunction {
fun: aggregates::AggregateFunction::ApproxMedian,
distinct: false,
args: vec![col("b")],
};

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(Vec::<Expr>::new(), vec![expr])?
.build()?;

// Rewrite to use approx_percentile
let expected = "Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#test.b, Float64(0.5))]] [APPROXMEDIAN(test.b):UInt32;N]\
\n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_eq(&plan, expected);
Ok(())
}
}
77 changes: 74 additions & 3 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ pub fn return_type(
true,
)))),
AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
}
}

Expand Down Expand Up @@ -268,6 +269,18 @@ pub fn create_aggregate_expr(
.to_string(),
));
}
(AggregateFunction::ApproxMedian, false) => {
Arc::new(expressions::ApproxMedian::new(
coerced_phy_exprs[0].clone(),
name,
return_type,
))
}
(AggregateFunction::ApproxMedian, true) => {
return Err(DataFusionError::NotImplemented(
"MEDIAN(DISTINCT) aggregations are not available".to_string(),
));
}
})
}

Expand Down Expand Up @@ -317,7 +330,8 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
| AggregateFunction::Variance
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop => {
| AggregateFunction::StddevPop
| AggregateFunction::ApproxMedian => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::Covariance | AggregateFunction::CovariancePop => {
Expand All @@ -341,8 +355,9 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
mod tests {
use super::*;
use crate::physical_plan::expressions::{
ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count,
Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance,
ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, Correlation,
Count, Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum,
Variance,
};
use crate::{error::Result, scalar::ScalarValue};

Expand Down Expand Up @@ -915,6 +930,62 @@ mod tests {
Ok(())
}

#[test]
fn test_median_expr() -> Result<()> {
let funcs = vec![AggregateFunction::ApproxMedian];
let data_types = vec![
DataType::UInt32,
DataType::UInt64,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
for fun in funcs {
for data_type in &data_types {
let input_schema =
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
)];
let result_agg_phy_exprs = create_aggregate_expr(
&fun,
false,
&input_phy_exprs[0..1],
&input_schema,
"c1",
)?;

if fun == AggregateFunction::ApproxMedian {
assert!(result_agg_phy_exprs.as_any().is::<ApproxMedian>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", data_type.clone(), true),
result_agg_phy_exprs.field().unwrap()
);
}
}
}
Ok(())
}

#[test]
fn test_median() -> Result<()> {
let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Utf8]);
assert!(observed.is_err());

let observed = return_type(&AggregateFunction::ApproxMedian, &[DataType::Int32])?;
assert_eq!(DataType::Int32, observed);

let observed = return_type(
&AggregateFunction::ApproxMedian,
&[DataType::Decimal(10, 6)],
);
assert!(observed.is_err());

Ok(())
}

#[test]
fn test_min_max() -> Result<()> {
let observed = return_type(&AggregateFunction::Min, &[DataType::Utf8])?;
Expand Down
9 changes: 9 additions & 0 deletions datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ pub(crate) fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxMedian => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
}
}

Expand Down
Loading

0 comments on commit 6e02d2d

Please sign in to comment.