From 87142933cbb19cbc992a87b9d2a6e57282bc7448 Mon Sep 17 00:00:00 2001 From: Dom Date: Sat, 8 Jan 2022 11:03:55 +0000 Subject: [PATCH] feat: approx_quantile dataframe function Adds the approx_quantile() dataframe function, and exports it in the prelude. --- datafusion/src/logical_plan/expr.rs | 9 +++++++++ datafusion/src/logical_plan/mod.rs | 14 +++++++------- datafusion/src/prelude.rs | 12 ++++++------ datafusion/tests/dataframe_functions.rs | 20 ++++++++++++++++++++ 4 files changed, 42 insertions(+), 13 deletions(-) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index dadc168530745..cb04035e92ca3 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -1635,6 +1635,15 @@ pub fn approx_distinct(expr: Expr) -> Expr { } } +/// Creates a non-distinct `ApproxQuantile` aggregate expression that calculates an approximation of the specified `quantile` for `expr`. +pub fn approx_quantile(expr: Expr, quantile: Expr) -> Expr { + Expr::AggregateFunction { + fun: aggregates::AggregateFunction::ApproxQuantile, + distinct: false, + args: vec![expr, quantile], + } +} + // TODO(kszucs): this seems buggy, unary_scalar_expr! 
is used for many // varying arity functions /// Create an convenience function representing a unary scalar function diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index 56fec3cf1a0c4..058f714d31168 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -36,13 +36,13 @@ pub use builder::{ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ - abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr, - bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr, - combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, - create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, - initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim, - max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random, - regexp_match, regexp_replace, repeat, replace, replace_col, reverse, + abs, acos, and, approx_distinct, approx_quantile, array, ascii, asin, atan, avg, + binary_expr, bit_length, btrim, case, ceil, character_length, chr, col, + columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, + create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, + floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, + lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length, + or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse, rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when, diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index abc75829ea17d..ab503761b2f49 100644 --- a/datafusion/src/prelude.rs +++ 
b/datafusion/src/prelude.rs @@ -30,10 +30,10 @@ pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; pub use crate::execution::options::AvroReadOptions; pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions}; pub use crate::logical_plan::{ - array, ascii, avg, bit_length, btrim, character_length, chr, col, concat, concat_ws, - count, create_udf, date_part, date_trunc, digest, in_list, initcap, left, length, - lit, lower, lpad, ltrim, max, md5, min, now, octet_length, random, regexp_match, - regexp_replace, repeat, replace, reverse, right, rpad, rtrim, sha224, sha256, sha384, - sha512, split_part, starts_with, strpos, substr, sum, to_hex, translate, trim, upper, - Column, JoinType, Partitioning, + approx_quantile, array, ascii, avg, bit_length, btrim, character_length, chr, col, + concat, concat_ws, count, create_udf, date_part, date_trunc, digest, in_list, + initcap, left, length, lit, lower, lpad, ltrim, max, md5, min, now, octet_length, + random, regexp_match, regexp_replace, repeat, replace, reverse, right, rpad, rtrim, + sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr, sum, to_hex, + translate, trim, upper, Column, JoinType, Partitioning, }; diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index c11aa141f003a..6cad6a8eab99e 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -153,6 +153,26 @@ async fn test_fn_btrim_with_chars() -> Result<()> { Ok(()) } +#[tokio::test] +async fn test_fn_approx_quantile() -> Result<()> { + let expr = approx_quantile(col("b"), lit(0.5)); + + let expected = vec![ + "+-------------------------------------+", + "| APPROXQUANTILE(test.b,Float64(0.5)) |", + "+-------------------------------------+", + "| 10 |", + "+-------------------------------------+", + ]; + + let df = create_test_table()?; + let batches = df.aggregate(vec![], vec![expr])?.collect().await?; + + 
assert_batches_eq!(expected, &batches); + + Ok(()) +} + #[tokio::test] async fn test_fn_character_length() -> Result<()> { let expr = character_length(col("a"));