Skip to content

Commit

Permalink
refactor: ballista approx_quantile support
Browse files Browse the repository at this point in the history
Adds ballista wire encoding for approx_quantile.

Adding support for this required modifying the AggregateExprNode proto
message to support propagating multiple LogicalExprNode aggregate
arguments - all the existing aggregations take a single argument, so
this wasn't needed before.

This commit adds "repeated" to the expr field, which I believe is
backwards compatible as described here:

	https://developers.google.com/protocol-buffers/docs/proto3#updating

Specifically, adding "repeated" to an existing message field:

	"For ... message fields, optional is compatible with repeated"

No existing tests needed fixing, and a new roundtrip test is included
that covers the change to allow multiple expr.
  • Loading branch information
domodwyer committed Jan 11, 2022
1 parent 8714293 commit b4ff5b3
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 7 deletions.
3 changes: 2 additions & 1 deletion ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,12 @@ enum AggregateFunction {
VARIANCE_POP=8;
STDDEV=9;
STDDEV_POP=10;
APPROX_QUANTILE = 11;
}

message AggregateExprNode {
AggregateFunction aggr_function = 1;
LogicalExprNode expr = 2;
repeated LogicalExprNode expr = 2;
}

enum BuiltInWindowFunction {
Expand Down
6 changes: 5 additions & 1 deletion ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,11 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {

Ok(Expr::AggregateFunction {
fun,
args: vec![parse_required_expr(&expr.expr)?],
args: expr
.expr
.iter()
.map(|e| e.try_into())
.collect::<Result<Vec<_>, _>>()?,
distinct: false, //TODO
})
}
Expand Down
15 changes: 14 additions & 1 deletion ballista/rust/core/src/serde/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ mod roundtrip_tests {
use super::super::{super::error::Result, protobuf};
use crate::error::BallistaError;
use core::panic;
use datafusion::logical_plan::Repartition;
use datafusion::{
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit},
datasource::object_store::local::LocalFileSystem,
Expand All @@ -37,6 +36,7 @@ mod roundtrip_tests {
scalar::ScalarValue,
sql::parser::FileType,
};
use datafusion::{logical_plan::Repartition, physical_plan::aggregates};
use protobuf::arrow_type;
use std::{convert::TryInto, sync::Arc};

Expand Down Expand Up @@ -988,4 +988,17 @@ mod roundtrip_tests {

Ok(())
}

// Round-trip test for an `ApproxQuantile` aggregate expression that carries
// more than one argument, passed through `protobuf::LogicalExprNode`.
#[test]
fn roundtrip_approx_quantile() -> Result<()> {
// Two arguments: a column reference and a literal 0.42.
let test_expr = Expr::AggregateFunction {
fun: aggregates::AggregateFunction::ApproxQuantile,
args: vec![col("bananas"), lit(0.42)],
distinct: false,
};

// NOTE(review): `roundtrip_test!` is defined outside this view; presumably it
// converts `test_expr` to the proto type and back and asserts equality — confirm.
roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);

Ok(())
}
}
14 changes: 10 additions & 4 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::ApproxDistinct => {
protobuf::AggregateFunction::ApproxDistinct
}
AggregateFunction::ApproxQuantile => {
protobuf::AggregateFunction::ApproxQuantile
}
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
Expand All @@ -1036,11 +1039,13 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
}
};

let arg = &args[0];
let aggregate_expr = Box::new(protobuf::AggregateExprNode {
let aggregate_expr = protobuf::AggregateExprNode {
aggr_function: aggr_function.into(),
expr: Some(Box::new(arg.try_into()?)),
});
expr: args
.iter()
.map(|v| v.try_into())
.collect::<Result<Vec<_>, _>>()?,
};
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
})
Expand Down Expand Up @@ -1268,6 +1273,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::VariancePop => Self::VariancePop,
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
AggregateFunction::ApproxQuantile => Self::ApproxQuantile,
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop,
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
protobuf::AggregateFunction::ApproxQuantile => {
AggregateFunction::ApproxQuantile
}
}
}
}
Expand Down

0 comments on commit b4ff5b3

Please sign in to comment.