From c91efc27658e58264c4f346a5cfdec8810179e90 Mon Sep 17 00:00:00 2001
From: Yang Jiang <37145547+Ted-Jiang@users.noreply.github.com>
Date: Mon, 18 Apr 2022 10:19:26 +0800
Subject: [PATCH] [Ballista] Enable ApproxPercentileWithWeight in Ballista and
 fill UT (#2192)

* enable ApproxPercentileWithWeight in Ballista

* add ApproxPercentileWithWeight in Ballista proto
---
 ballista/rust/client/src/context.rs          | 223 +++++++++++++++++-
 ballista/rust/core/proto/datafusion.proto    |   1 +
 .../core/src/serde/physical_plan/to_proto.rs |   6 +
 .../approx_percentile_cont_with_weight.rs    |  13 +-
 4 files changed, 226 insertions(+), 17 deletions(-)

diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs
index 5899598ba2de..7dc7ec63b956 100644
--- a/ballista/rust/client/src/context.rs
+++ b/ballista/rust/client/src/context.rs
@@ -642,7 +642,6 @@ mod tests {
             BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
         };
         use datafusion::arrow::util::pretty::pretty_format_batches;
-        use datafusion::assert_batches_eq;
         let config = BallistaConfigBuilder::default()
             .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
             .build()
@@ -696,13 +695,15 @@ mod tests {

     #[tokio::test]
     #[cfg(feature = "standalone")]
-    async fn test_percentile_func() {
+    async fn test_aggregate_func() {
         use crate::context::BallistaContext;
         use ballista_core::config::{
             BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
         };
+        use datafusion::arrow;
         use datafusion::arrow::util::pretty::pretty_format_batches;
         use datafusion::prelude::ParquetReadOptions;
+
         let config = BallistaConfigBuilder::default()
             .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
             .build()
@@ -718,6 +719,199 @@ mod tests {
             )
             .await
             .unwrap();
+
+        let df = context.sql("select min(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------+",
+            "| MIN(test.id) |",
+            "+--------------+",
+            "| 0            |",
+            "+--------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select max(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------+",
+            "| MAX(test.id) |",
+            "+--------------+",
+            "| 7            |",
+            "+--------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select SUM(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------+",
+            "| SUM(test.id) |",
+            "+--------------+",
+            "| 28           |",
+            "+--------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select AVG(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------+",
+            "| AVG(test.id) |",
+            "+--------------+",
+            "| 3.5          |",
+            "+--------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select COUNT(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+----------------+",
+            "| COUNT(test.id) |",
+            "+----------------+",
+            "| 8              |",
+            "+----------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select approx_distinct(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+-------------------------+",
+            "| APPROXDISTINCT(test.id) |",
+            "+-------------------------+",
+            "| 8                       |",
+            "+-------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select ARRAY_AGG(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------------------+",
+            "| ARRAYAGG(test.id)        |",
+            "+--------------------------+",
+            "| [4, 5, 6, 7, 2, 3, 0, 1] |",
+            "+--------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select VAR(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+-------------------+",
+            "| VARIANCE(test.id) |",
+            "+-------------------+",
+            "| 6.000000000000001 |",
+            "+-------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select VAR_POP(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+----------------------+",
+            "| VARIANCEPOP(test.id) |",
+            "+----------------------+",
+            "| 5.250000000000001    |",
+            "+----------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select VAR_SAMP(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+-------------------+",
+            "| VARIANCE(test.id) |",
+            "+-------------------+",
+            "| 6.000000000000001 |",
+            "+-------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select STDDEV(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------------+",
+            "| STDDEV(test.id)    |",
+            "+--------------------+",
+            "| 2.4494897427831783 |",
+            "+--------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select STDDEV_SAMP(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------------+",
+            "| STDDEV(test.id)    |",
+            "+--------------------+",
+            "| 2.4494897427831783 |",
+            "+--------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select COVAR(id, tinyint_col) from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------------------------------+",
+            "| COVARIANCE(test.id,test.tinyint_col) |",
+            "+--------------------------------------+",
+            "| 0.28571428571428586                  |",
+            "+--------------------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select CORR(id, tinyint_col) from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+---------------------------------------+",
+            "| CORRELATION(test.id,test.tinyint_col) |",
+            "+---------------------------------------+",
+            "| 0.21821789023599245                   |",
+            "+---------------------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select approx_percentile_cont_with_weight(\"id\", 2, 0.5) from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+---------------------------------------------------------------+",
+            "| APPROXPERCENTILECONTWITHWEIGHT(test.id,Int64(2),Float64(0.5)) |",
+            "+---------------------------------------------------------------+",
+            "| 1                                                             |",
+            "+---------------------------------------------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
         let df = context
             .sql("select approx_percentile_cont(\"double_col\", 0.5) from test")
             .await
@@ -731,14 +925,21 @@ mod tests {
             "| 7.574919706101287                                  |",
             "+----------------------------------------------------+",
         ];
-        assert_eq!(
-            expected,
-            pretty_format_batches(&*res)
-                .unwrap()
-                .to_string()
-                .trim()
-                .lines()
-                .collect::<Vec<&str>>()
-        );
+        assert_result_eq(expected, &*res);
+
+        fn assert_result_eq(
+            expected: Vec<&str>,
+            results: &[arrow::record_batch::RecordBatch],
+        ) {
+            assert_eq!(
+                expected,
+                pretty_format_batches(results)
+                    .unwrap()
+                    .to_string()
+                    .trim()
+                    .lines()
+                    .collect::<Vec<&str>>()
+            );
+        }
     }
 }
diff --git a/ballista/rust/core/proto/datafusion.proto b/ballista/rust/core/proto/datafusion.proto
index 1dc9b34f7dd5..9999abbf2cf0 100644
--- a/ballista/rust/core/proto/datafusion.proto
+++ b/ballista/rust/core/proto/datafusion.proto
@@ -201,6 +201,7 @@ enum AggregateFunction {
   CORRELATION=13;
   APPROX_PERCENTILE_CONT = 14;
   APPROX_MEDIAN=15;
+  APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
 }

 message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 3a1f24d0f6d1..d022766d85af 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -123,6 +123,12 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
             .is_some()
         {
             Ok(AggregateFunction::ApproxPercentileCont.into())
+        } else if self
+            .as_any()
+            .downcast_ref::<ApproxPercentileContWithWeight>()
+            .is_some()
+        {
+            Ok(AggregateFunction::ApproxPercentileContWithWeight.into())
         } else if self
             .as_any()
             .downcast_ref::<ApproxMedian>()
diff --git a/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
index 33b2ee7a67c4..1beb7a86cfca 100644
--- a/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
+++ b/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
@@ -38,6 +38,7 @@ pub struct ApproxPercentileContWithWeight {
     approx_percentile_cont: ApproxPercentileCont,
     column_expr: Arc<dyn PhysicalExpr>,
     weight_expr: Arc<dyn PhysicalExpr>,
+    percentile_expr: Arc<dyn PhysicalExpr>,
 }

 impl ApproxPercentileContWithWeight {
@@ -58,6 +59,7 @@ impl ApproxPercentileContWithWeight {
             approx_percentile_cont,
             column_expr: expr[0].clone(),
             weight_expr: expr[1].clone(),
+            percentile_expr: expr[2].clone(),
         })
     }
 }
@@ -79,7 +81,11 @@ impl AggregateExpr for ApproxPercentileContWithWeight {
     }

     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![self.column_expr.clone(), self.weight_expr.clone()]
+        vec![
+            self.column_expr.clone(),
+            self.weight_expr.clone(),
+            self.percentile_expr.clone(),
+        ]
     }

     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
@@ -115,11 +121,6 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
     }

     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        debug_assert_eq!(
-            values.len(),
-            2,
-            "invalid number of values in batch percentile update"
-        );
         let means = &values[0];
         let weights = &values[1];
         debug_assert_eq!(
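
For reference, the aggregate enabled by this patch can be exercised end to end with a small standalone driver along the lines of the sketch below. This is an illustrative sketch only and not part of the patch: the import paths, the parquet file location, the number of concurrent tasks, and the constant weight (2) and percentile (0.5) arguments are assumptions that mirror test_aggregate_func above, and it needs the ballista client crate built with the "standalone" feature.

// Illustrative sketch (assumptions noted above), mirroring the unit test.
use ballista::context::BallistaContext;
use ballista_core::config::{BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA};
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::prelude::ParquetReadOptions;

#[tokio::main]
async fn main() {
    let config = BallistaConfigBuilder::default()
        .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
        .build()
        .unwrap();
    // Spin up an in-process scheduler/executor pair; one concurrent task is
    // enough for a local smoke test ("standalone" feature required).
    let ctx = BallistaContext::standalone(&config, 1).await.unwrap();
    ctx.register_parquet(
        "test",
        "testdata/alltypes_plain.parquet",
        ParquetReadOptions::default(),
    )
    .await
    .unwrap();

    // Constant weight of 2 per row and the 0.5 percentile, as in the test.
    let df = ctx
        .sql("select approx_percentile_cont_with_weight(\"id\", 2, 0.5) from test")
        .await
        .unwrap();
    let batches = df.collect().await.unwrap();
    println!("{}", pretty_format_batches(&batches).unwrap());
}

The same query string works against a remote BallistaContext as well, since the serde change above is what lets the physical plan for this aggregate round-trip through the Ballista protobuf.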