Skip to content

Commit

Permalink
[Ballista] Enable ApproxPercentileWithWeight in Ballista and fill UT (#…
Browse files Browse the repository at this point in the history
…2192)

* enable ApproxPercentileWithWeight in Ballista

* add ApproxPercentileWithWeight in Ballista proto
  • Loading branch information
Ted-Jiang authored Apr 18, 2022
1 parent 22b70b8 commit c91efc2
Show file tree
Hide file tree
Showing 4 changed files with 226 additions and 17 deletions.
223 changes: 212 additions & 11 deletions ballista/rust/client/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,6 @@ mod tests {
BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
};
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::assert_batches_eq;
let config = BallistaConfigBuilder::default()
.set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
.build()
Expand Down Expand Up @@ -696,13 +695,15 @@ mod tests {

#[tokio::test]
#[cfg(feature = "standalone")]
async fn test_percentile_func() {
async fn test_aggregate_func() {
use crate::context::BallistaContext;
use ballista_core::config::{
BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
};
use datafusion::arrow;
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::prelude::ParquetReadOptions;

let config = BallistaConfigBuilder::default()
.set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
.build()
Expand All @@ -718,6 +719,199 @@ mod tests {
)
.await
.unwrap();

let df = context.sql("select min(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
"| MIN(test.id) |",
"+--------------+",
"| 0 |",
"+--------------+",
];
assert_result_eq(expected, &*res);

let df = context.sql("select max(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
"| MAX(test.id) |",
"+--------------+",
"| 7 |",
"+--------------+",
];
assert_result_eq(expected, &*res);

let df = context.sql("select SUM(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
"| SUM(test.id) |",
"+--------------+",
"| 28 |",
"+--------------+",
];
assert_result_eq(expected, &*res);

let df = context.sql("select AVG(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------+",
"| AVG(test.id) |",
"+--------------+",
"| 3.5 |",
"+--------------+",
];
assert_result_eq(expected, &*res);

let df = context.sql("select COUNT(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+----------------+",
"| COUNT(test.id) |",
"+----------------+",
"| 8 |",
"+----------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select approx_distinct(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------------+",
"| APPROXDISTINCT(test.id) |",
"+-------------------------+",
"| 8 |",
"+-------------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select ARRAY_AGG(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------+",
"| ARRAYAGG(test.id) |",
"+--------------------------+",
"| [4, 5, 6, 7, 2, 3, 0, 1] |",
"+--------------------------+",
];
assert_result_eq(expected, &*res);

let df = context.sql("select VAR(\"id\") from test").await.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------+",
"| VARIANCE(test.id) |",
"+-------------------+",
"| 6.000000000000001 |",
"+-------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select VAR_POP(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+----------------------+",
"| VARIANCEPOP(test.id) |",
"+----------------------+",
"| 5.250000000000001 |",
"+----------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select VAR_SAMP(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+-------------------+",
"| VARIANCE(test.id) |",
"+-------------------+",
"| 6.000000000000001 |",
"+-------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select STDDEV(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------+",
"| STDDEV(test.id) |",
"+--------------------+",
"| 2.4494897427831783 |",
"+--------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select STDDEV_SAMP(\"id\") from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------+",
"| STDDEV(test.id) |",
"+--------------------+",
"| 2.4494897427831783 |",
"+--------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select COVAR(id, tinyint_col) from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+--------------------------------------+",
"| COVARIANCE(test.id,test.tinyint_col) |",
"+--------------------------------------+",
"| 0.28571428571428586 |",
"+--------------------------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select CORR(id, tinyint_col) from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+---------------------------------------+",
"| CORRELATION(test.id,test.tinyint_col) |",
"+---------------------------------------+",
"| 0.21821789023599245 |",
"+---------------------------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select approx_percentile_cont_with_weight(\"id\", 2, 0.5) from test")
.await
.unwrap();
let res = df.collect().await.unwrap();
let expected = vec![
"+---------------------------------------------------------------+",
"| APPROXPERCENTILECONTWITHWEIGHT(test.id,Int64(2),Float64(0.5)) |",
"+---------------------------------------------------------------+",
"| 1 |",
"+---------------------------------------------------------------+",
];
assert_result_eq(expected, &*res);

let df = context
.sql("select approx_percentile_cont(\"double_col\", 0.5) from test")
.await
Expand All @@ -731,14 +925,21 @@ mod tests {
"+----------------------------------------------------+",
];

assert_eq!(
expected,
pretty_format_batches(&*res)
.unwrap()
.to_string()
.trim()
.lines()
.collect::<Vec<&str>>()
);
assert_result_eq(expected, &*res);

fn assert_result_eq(
expected: Vec<&str>,
results: &[arrow::record_batch::RecordBatch],
) {
assert_eq!(
expected,
pretty_format_batches(results)
.unwrap()
.to_string()
.trim()
.lines()
.collect::<Vec<&str>>()
);
}
}
}
1 change: 1 addition & 0 deletions ballista/rust/core/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ enum AggregateFunction {
CORRELATION=13;
APPROX_PERCENTILE_CONT = 14;
APPROX_MEDIAN=15;
APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
}

message AggregateExprNode {
Expand Down
6 changes: 6 additions & 0 deletions ballista/rust/core/src/serde/physical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
.is_some()
{
Ok(AggregateFunction::ApproxPercentileCont.into())
} else if self
.as_any()
.downcast_ref::<expressions::ApproxPercentileContWithWeight>()
.is_some()
{
Ok(AggregateFunction::ApproxPercentileContWithWeight.into())
} else if self
.as_any()
.downcast_ref::<expressions::ApproxMedian>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pub struct ApproxPercentileContWithWeight {
approx_percentile_cont: ApproxPercentileCont,
column_expr: Arc<dyn PhysicalExpr>,
weight_expr: Arc<dyn PhysicalExpr>,
percentile_expr: Arc<dyn PhysicalExpr>,
}

impl ApproxPercentileContWithWeight {
Expand All @@ -58,6 +59,7 @@ impl ApproxPercentileContWithWeight {
approx_percentile_cont,
column_expr: expr[0].clone(),
weight_expr: expr[1].clone(),
percentile_expr: expr[2].clone(),
})
}
}
Expand All @@ -79,7 +81,11 @@ impl AggregateExpr for ApproxPercentileContWithWeight {
}

fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.column_expr.clone(), self.weight_expr.clone()]
vec![
self.column_expr.clone(),
self.weight_expr.clone(),
self.percentile_expr.clone(),
]
}

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Expand Down Expand Up @@ -115,11 +121,6 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
debug_assert_eq!(
values.len(),
2,
"invalid number of values in batch percentile update"
);
let means = &values[0];
let weights = &values[1];
debug_assert_eq!(
Expand Down

0 comments on commit c91efc2

Please sign in to comment.