From b4ff5b33219f029d1063dd5faae9f9a6f6f866e1 Mon Sep 17 00:00:00 2001 From: Dom Date: Mon, 10 Jan 2022 14:15:01 +0000 Subject: [PATCH] refactor: bastilla approx_quantile support Adds bastilla wire encoding for approx_quantile. Adding support for this required modifying the AggregateExprNode proto message to support propigating multiple LogicalExprNode aggregate arguments - all the existing aggregations take a single argument, so this wasn't needed before. This commit adds "repeated" to the expr field, which I believe is backwards compatible as described here: https://developers.google.com/protocol-buffers/docs/proto3#updating Specifically, adding "repeated" to an existing message field: "For ... message fields, optional is compatible with repeated" No existing tests needed fixing, and a new roundtrip test is included that covers the change to allow multiple expr. --- ballista/rust/core/proto/ballista.proto | 3 ++- .../core/src/serde/logical_plan/from_proto.rs | 6 +++++- ballista/rust/core/src/serde/logical_plan/mod.rs | 15 ++++++++++++++- .../rust/core/src/serde/logical_plan/to_proto.rs | 14 ++++++++++---- ballista/rust/core/src/serde/mod.rs | 3 +++ 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index aa7b6a9f900fe..30fe595150994 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -173,11 +173,12 @@ enum AggregateFunction { VARIANCE_POP=8; STDDEV=9; STDDEV_POP=10; + APPROX_QUANTILE = 11; } message AggregateExprNode { AggregateFunction aggr_function = 1; - LogicalExprNode expr = 2; + repeated LogicalExprNode expr = 2; } enum BuiltInWindowFunction { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index dfac547d7bb35..422907d34004c 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -962,7 +962,11 @@ impl TryInto for &protobuf::LogicalExprNode { Ok(Expr::AggregateFunction { fun, - args: vec![parse_required_expr(&expr.expr)?], + args: expr + .expr + .iter() + .map(|e| e.try_into()) + .collect::, _>>()?, distinct: false, //TODO }) } diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index a0f481a803258..f5efacce9ec0c 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -24,7 +24,6 @@ mod roundtrip_tests { use super::super::{super::error::Result, protobuf}; use crate::error::BallistaError; use core::panic; - use datafusion::logical_plan::Repartition; use datafusion::{ arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}, datasource::object_store::local::LocalFileSystem, @@ -37,6 +36,7 @@ mod roundtrip_tests { scalar::ScalarValue, sql::parser::FileType, }; + use datafusion::{logical_plan::Repartition, physical_plan::aggregates}; use protobuf::arrow_type; use std::{convert::TryInto, sync::Arc}; @@ -988,4 +988,17 @@ mod roundtrip_tests { Ok(()) } + + #[test] + fn roundtrip_approx_quantile() -> Result<()> { + let test_expr = Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxQuantile, + args: vec![col("bananas"), lit(0.42)], + distinct: false, + }; + + roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr); + + Ok(()) + } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 01428d9ba7a77..d1d05fc5661c6 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1020,6 +1020,9 @@ impl TryInto for &Expr { AggregateFunction::ApproxDistinct => { protobuf::AggregateFunction::ApproxDistinct } + AggregateFunction::ApproxQuantile => { + protobuf::AggregateFunction::ApproxQuantile + } AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max, @@ -1036,11 +1039,13 @@ impl TryInto for &Expr { } }; - let arg = &args[0]; - let aggregate_expr = Box::new(protobuf::AggregateExprNode { + let aggregate_expr = protobuf::AggregateExprNode { aggr_function: aggr_function.into(), - expr: Some(Box::new(arg.try_into()?)), - }); + expr: args + .iter() + .map(|v| v.try_into()) + .collect::, _>>()?, + }; Ok(protobuf::LogicalExprNode { expr_type: Some(ExprType::AggregateExpr(aggregate_expr)), }) @@ -1268,6 +1273,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::VariancePop => Self::VariancePop, AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, + AggregateFunction::ApproxQuantile => Self::ApproxQuantile, } } } diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index fd3b57b3deda1..ccad3a328a399 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -123,6 +123,9 @@ impl From for AggregateFunction { protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop, protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev, protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop, + protobuf::AggregateFunction::ApproxQuantile => { + AggregateFunction::ApproxQuantile + } } } }