From 90ee868d9a168f7d38b190796d89cbd40159a988 Mon Sep 17 00:00:00 2001
From: Xinjing Hu
Date: Wed, 14 Jun 2023 19:00:27 +0800
Subject: [PATCH] feat(expr, agg): support `PERCENTILE_CONT`, `PERCENTILE_DISC`
 and `MODE` aggregation (#10252)

Signed-off-by: Richard Chien
Co-authored-by: Richard Chien
Co-authored-by: Noel Kwan <47273164+kwannoel@users.noreply.github.com>
---
 .../batch/aggregate/ordered_set_agg.slt.part |  33 ++++
 src/expr/src/agg/mod.rs                      |   3 +
 src/expr/src/agg/mode.rs                     | 126 +++++++++++++++
 src/expr/src/agg/percentile_cont.rs          | 132 ++++++++++++++++
 src/expr/src/agg/percentile_disc.rs          | 143 ++++++++++++++++++
 .../tests/testdata/input/agg.yaml            |  25 +++
 .../tests/testdata/output/agg.yaml           |  43 +++++-
 src/frontend/src/binder/expr/function.rs     |  39 ++++-
 src/frontend/src/expr/agg_call.rs            |  16 +-
 .../src/optimizer/plan_node/logical_agg.rs   |  18 +++
 src/tests/sqlsmith/src/sql_gen/types.rs      |   3 +
 11 files changed, 563 insertions(+), 18 deletions(-)
 create mode 100644 e2e_test/batch/aggregate/ordered_set_agg.slt.part
 create mode 100644 src/expr/src/agg/mode.rs
 create mode 100644 src/expr/src/agg/percentile_cont.rs
 create mode 100644 src/expr/src/agg/percentile_disc.rs

diff --git a/e2e_test/batch/aggregate/ordered_set_agg.slt.part b/e2e_test/batch/aggregate/ordered_set_agg.slt.part
new file mode 100644
index 0000000000000..6cf42db843136
--- /dev/null
+++ b/e2e_test/batch/aggregate/ordered_set_agg.slt.part
@@ -0,0 +1,33 @@
+statement error
+select p, percentile_cont(p) within group (order by x::float8)
+from generate_series(1,5) x,
+     (values (0::float8),(0.1),(0.25),(0.4),(0.5),(0.6),(0.75),(0.9),(1)) v(p)
+group by p order by p;
+
+statement error
+select percentile_cont(array[0,1,0.25,0.75,0.5,1,0.3,0.32,0.35,0.38,0.4]) within group (order by x)
+from generate_series(1,6) x;
+
+statement error
+select percentile_disc(array[0.25,0.5,0.75]) within group (order by x)
+from unnest('{fred,jim,fred,jack,jill,fred,jill,jim,jim,sheila,jim,sheila}'::text[]) u(x);
+
+statement error
+select pg_collation_for(percentile_disc(1) within group (order by x collate "POSIX"))
+  from (values ('fred'),('jim')) v(x);
+
+query RR
+select
+  percentile_cont(0.5) within group (order by a),
+  percentile_disc(0.5) within group (order by a)
+from (values(1::float8),(3),(5),(7)) t(a);
+----
+4 3
+
+query RR
+select
+  percentile_cont(0.25) within group (order by a),
+  percentile_disc(0.5) within group (order by a)
+from (values(1::float8),(3),(5),(7)) t(a);
+----
+2.5 3
diff --git a/src/expr/src/agg/mod.rs b/src/expr/src/agg/mod.rs
index c03aa597a9067..b0d7722d240e8 100644
--- a/src/expr/src/agg/mod.rs
+++ b/src/expr/src/agg/mod.rs
@@ -28,6 +28,9 @@ mod array_agg;
 mod count_star;
 mod general;
 mod jsonb_agg;
+mod mode;
+mod percentile_cont;
+mod percentile_disc;
 mod string_agg;

 // wrappers
diff --git a/src/expr/src/agg/mode.rs b/src/expr/src/agg/mode.rs
new file mode 100644
index 0000000000000..297e169dc93c2
--- /dev/null
+++ b/src/expr/src/agg/mode.rs
@@ -0,0 +1,126 @@
+// Copyright 2023 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use risingwave_common::array::*;
+use risingwave_common::estimate_size::EstimateSize;
+use risingwave_common::types::*;
+use risingwave_expr_macro::build_aggregate;
+
+use super::Aggregator;
+use crate::agg::AggCall;
+use crate::Result;
+
+#[build_aggregate("mode(*) -> *")]
+fn build(agg: AggCall) -> Result<Box<dyn Aggregator>> {
+    Ok(Box::new(Mode::new(agg.return_type)))
+}
+
+/// Computes the mode, the most frequent value of the aggregated argument (arbitrarily choosing the
+/// first one if there are multiple equally-frequent values). The aggregated argument must be of a
+/// sortable type.
+///
+/// ```slt
+/// query I
+/// select mode() within group (order by unnest) from unnest(array[1]);
+/// ----
+/// 1
+///
+/// query I
+/// select mode() within group (order by unnest) from unnest(array[1,2,2,3,3,4,4,4]);
+/// ----
+/// 4
+///
+/// query R
+/// select mode() within group (order by unnest) from unnest(array[0.1,0.2,0.2,0.4,0.4,0.3,0.3,0.4]);
+/// ----
+/// 0.4
+///
+/// query R
+/// select mode() within group (order by unnest) from unnest(array[1,2,2,3,3,4,4,4,3]);
+/// ----
+/// 3
+///
+/// query T
+/// select mode() within group (order by unnest) from unnest(array['1','2','2','3','3','4','4','4','3']);
+/// ----
+/// 3
+///
+/// query I
+/// select mode() within group (order by unnest) from unnest(array[]::int[]);
+/// ----
+/// NULL
+/// ```
+#[derive(Clone, EstimateSize)]
+pub struct Mode {
+    return_type: DataType,
+    cur_mode: Datum,
+    cur_mode_freq: usize,
+    cur_item: Datum,
+    cur_item_freq: usize,
+}
+
+impl Mode {
+    pub fn new(return_type: DataType) -> Self {
+        Self {
+            return_type,
+            cur_mode: None,
+            cur_mode_freq: 0,
+            cur_item: None,
+            cur_item_freq: 0,
+        }
+    }
+
+    fn add_datum(&mut self, datum_ref: DatumRef<'_>) {
+        let datum = datum_ref.to_owned_datum();
+        if datum.is_some() && self.cur_item == datum {
+            self.cur_item_freq += 1;
+        } else if datum.is_some() {
+            self.cur_item = datum;
+            self.cur_item_freq = 1;
+        }
+        if self.cur_item_freq > self.cur_mode_freq {
+            self.cur_mode = self.cur_item.clone();
+            self.cur_mode_freq = self.cur_item_freq;
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl Aggregator for Mode {
+    fn return_type(&self) -> DataType {
+        self.return_type.clone()
+    }
+
+    async fn update_multi(
+        &mut self,
+        input: &DataChunk,
+        start_row_id: usize,
+        end_row_id: usize,
+    ) -> Result<()> {
+        let array = input.column_at(0);
+        for row_id in start_row_id..end_row_id {
+            self.add_datum(array.value_at(row_id));
+        }
+        Ok(())
+    }
+
+    fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
+        builder.append(self.cur_mode.clone());
+        Ok(())
+    }
+
+    fn estimated_size(&self) -> usize {
+        EstimateSize::estimated_size(self)
+    }
+}
diff --git a/src/expr/src/agg/percentile_cont.rs b/src/expr/src/agg/percentile_cont.rs
new file mode 100644
index 0000000000000..1e88557712c5d
--- /dev/null
+++ b/src/expr/src/agg/percentile_cont.rs
@@ -0,0 +1,132 @@
+// Copyright 2023 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use risingwave_common::array::*;
+use risingwave_common::estimate_size::EstimateSize;
+use risingwave_common::types::*;
+use risingwave_expr_macro::build_aggregate;
+
+use super::Aggregator;
+use crate::agg::AggCall;
+use crate::Result;
+
+/// Computes the continuous percentile, a value corresponding to the specified fraction within the
+/// ordered set of aggregated argument values. This will interpolate between adjacent input items if
+/// needed.
+///
+/// ```slt
+/// statement ok
+/// create table t(x int, y bigint, z real, w double, v varchar);
+///
+/// statement ok
+/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000');
+///
+/// query R
+/// select percentile_cont(0.45) within group (order by x desc) from t;
+/// ----
+/// 2.1
+///
+/// query R
+/// select percentile_cont(0.45) within group (order by y desc) from t;
+/// ----
+/// 21
+///
+/// query R
+/// select percentile_cont(0.45) within group (order by z desc) from t;
+/// ----
+/// 210
+///
+/// query R
+/// select percentile_cont(0.45) within group (order by w desc) from t;
+/// ----
+/// 2100
+///
+/// query R
+/// select percentile_cont(NULL) within group (order by w desc) from t;
+/// ----
+/// NULL
+///
+/// statement ok
+/// drop table t;
+/// ```
+#[build_aggregate("percentile_cont(float64) -> float64")]
+fn build(agg: AggCall) -> Result<Box<dyn Aggregator>> {
+    let fraction: Option<f64> = agg.direct_args[0]
+        .literal()
+        .map(|x| (*x.as_float64()).into());
+    Ok(Box::new(PercentileCont::new(fraction)))
+}
+
+#[derive(Clone, EstimateSize)]
+pub struct PercentileCont {
+    fractions: Option<f64>,
+    data: Vec<f64>,
+}
+
+impl PercentileCont {
+    pub fn new(fractions: Option<f64>) -> Self {
+        Self {
+            fractions,
+            data: vec![],
+        }
+    }
+
+    fn add_datum(&mut self, datum_ref: DatumRef<'_>) {
+        if let Some(datum) = datum_ref.to_owned_datum() {
+            self.data.push((*datum.as_float64()).into());
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl Aggregator for PercentileCont {
+    fn return_type(&self) -> DataType {
+        DataType::Float64
+    }
+
+    async fn update_multi(
+        &mut self,
+        input: &DataChunk,
+        start_row_id: usize,
+        end_row_id: usize,
+    ) -> Result<()> {
+        let array = input.column_at(0);
+        for row_id in start_row_id..end_row_id {
+            self.add_datum(array.value_at(row_id));
+        }
+        Ok(())
+    }
+
+    fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
+        if let Some(fractions) = self.fractions && !self.data.is_empty() {
+            let rn = fractions * (self.data.len() - 1) as f64;
+            let crn = f64::ceil(rn);
+            let frn = f64::floor(rn);
+            let result = if crn == frn {
+                self.data[crn as usize]
+            } else {
+                (crn - rn) * self.data[frn as usize]
+                    + (rn - frn) * self.data[crn as usize]
+            };
+            builder.append(Some(ScalarImpl::Float64(result.into())));
+        } else {
+            builder.append(Datum::None);
+        }
+        Ok(())
+    }
+
+    fn estimated_size(&self) -> usize {
+        EstimateSize::estimated_size(self)
+    }
+}
diff --git a/src/expr/src/agg/percentile_disc.rs b/src/expr/src/agg/percentile_disc.rs
new file mode 100644
index 0000000000000..a8ab7ccb0fe68
--- /dev/null
+++ b/src/expr/src/agg/percentile_disc.rs
@@ -0,0 +1,143 @@
+// Copyright 2023 RisingWave Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use risingwave_common::array::*;
+use risingwave_common::estimate_size::EstimateSize;
+use risingwave_common::types::*;
+use risingwave_expr_macro::build_aggregate;
+
+use super::Aggregator;
+use crate::agg::AggCall;
+use crate::Result;
+
+/// Computes the discrete percentile, the first value within the ordered set of aggregated argument
+/// values whose position in the ordering equals or exceeds the specified fraction. The aggregated
+/// argument must be of a sortable type.
+///
+/// ```slt
+/// statement ok
+/// create table t(x int, y bigint, z real, w double, v varchar);
+///
+/// statement ok
+/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000');
+///
+/// query R
+/// select percentile_disc(0) within group (order by x) from t;
+/// ----
+/// 1
+///
+/// query R
+/// select percentile_disc(0.33) within group (order by y) from t;
+/// ----
+/// 10
+///
+/// query R
+/// select percentile_disc(0.34) within group (order by z) from t;
+/// ----
+/// 200
+///
+/// query R
+/// select percentile_disc(0.67) within group (order by w) from t;
+/// ----
+/// 3000
+///
+/// query R
+/// select percentile_disc(1) within group (order by v) from t;
+/// ----
+/// 30000
+///
+/// query R
+/// select percentile_disc(NULL) within group (order by w) from t;
+/// ----
+/// NULL
+///
+/// statement ok
+/// drop table t;
+/// ```
+#[build_aggregate("percentile_disc(*) -> *")]
+fn build(agg: AggCall) -> Result<Box<dyn Aggregator>> {
+    let fraction: Option<f64> = agg.direct_args[0]
+        .literal()
+        .map(|x| (*x.as_float64()).into());
+    Ok(Box::new(PercentileDisc::new(fraction, agg.return_type)))
+}
+
+#[derive(Clone)]
+pub struct PercentileDisc {
+    fractions: Option<f64>,
+    return_type: DataType,
+    data: Vec<ScalarImpl>,
+}
+
+impl EstimateSize for PercentileDisc {
+    fn estimated_heap_size(&self) -> usize {
+        self.data
+            .iter()
+            .fold(0, |acc, x| acc + x.estimated_heap_size())
+    }
+}
+
+impl PercentileDisc {
+    pub fn new(fractions: Option<f64>, return_type: DataType) -> Self {
+        Self {
+            fractions,
+            return_type,
+            data: vec![],
+        }
+    }
+
+    fn add_datum(&mut self, datum_ref: DatumRef<'_>) {
+        if let Some(datum) = datum_ref.to_owned_datum() {
+            self.data.push(datum);
+        }
+    }
+}
+
+#[async_trait::async_trait]
+impl Aggregator for PercentileDisc {
+    fn return_type(&self) -> DataType {
+        self.return_type.clone()
+    }
+
+    async fn update_multi(
+        &mut self,
+        input: &DataChunk,
+        start_row_id: usize,
+        end_row_id: usize,
+    ) -> Result<()> {
+        let array = input.column_at(0);
+        for row_id in start_row_id..end_row_id {
+            self.add_datum(array.value_at(row_id));
+        }
+        Ok(())
+    }
+
+    fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
+        if let Some(fractions) = self.fractions && !self.data.is_empty() {
+            let rn = fractions * self.data.len() as f64;
+            if fractions == 0.0 {
+                builder.append(Some(self.data[0].clone()));
+            } else {
+                builder.append(Some(self.data[f64::ceil(rn) as usize - 1].clone()));
+            }
+        } else {
+            builder.append(Datum::None);
+        }
+        Ok(())
+    }
+
+    fn estimated_size(&self) -> usize {
+        EstimateSize::estimated_size(self)
+    }
+}
diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml
index 6836385aa9ad9..71f50323179ca 100644
--- a/src/frontend/planner_test/tests/testdata/input/agg.yaml
+++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml
@@ -815,13 +815,38 @@
     select percentile_cont('abc') within group (order by y) from t;
   expected_outputs:
   - binder_error
+- sql: |
+    create table t (x int, y int);
+    select percentile_cont(1.3) within group (order by y) from t;
+  expected_outputs:
+  - binder_error
 - sql: |
     create table t (x int, y int);
     select percentile_cont(0, 0) within group (order by y) from t;
   expected_outputs:
   - binder_error
+- sql: |
+    create table t (x int, y varchar);
+    select percentile_cont(0) within group (order by y) from t;
+  expected_outputs:
+  - binder_error
 - sql: |
     create table t (x int, y int);
     select percentile_cont(0) within group (order by y desc) from t;
   expected_outputs:
   - batch_plan
+- sql: |
+    create table t (x int, y varchar);
+    select percentile_disc(1) within group (order by y desc) from t;
+  expected_outputs:
+  - batch_plan
+- sql: |
+    create table t (x int, y varchar);
+    select mode() within group (order by y desc) from t;
+  expected_outputs:
+  - batch_plan
+- sql: |
+    create table t (x int, y varchar);
+    select mode(1) within group (order by y desc) from t;
+  expected_outputs:
+  - binder_error
\ No newline at end of file
diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml
index 0f7c52f181a37..a0777d5f906b9 100644
--- a/src/frontend/planner_test/tests/testdata/output/agg.yaml
+++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml
@@ -1428,7 +1428,15 @@
     Bind error: failed to bind expression: percentile_cont('abc')

     Caused by:
-      Invalid input syntax: arg in percentile_cont must be double precision
+      Invalid input syntax: arg in percentile_cont must be float64
+- sql: |
+    create table t (x int, y int);
+    select percentile_cont(1.3) within group (order by y) from t;
+  binder_error: |-
+    Bind error: failed to bind expression: percentile_cont(1.3)
+
+    Caused by:
+      Invalid input syntax: arg in percentile_cont must be between 0 and 1
 - sql: |
     create table t (x int, y int);
     select percentile_cont(0, 0) within group (order by y) from t;
@@ -1437,10 +1445,41 @@

     Caused by:
       Invalid input syntax: only one arg is expected in percentile_cont
+- sql: |
+    create table t (x int, y varchar);
+    select percentile_cont(0) within group (order by y) from t;
+  binder_error: |-
+    Bind error: failed to bind expression: percentile_cont(0)
+
+    Caused by:
+      Bind error: cannot cast type "varchar" to "double precision" in Implicit context
 - sql: |
     create table t (x int, y int);
     select percentile_cont(0) within group (order by y desc) from t;
   batch_plan: |
-    BatchSimpleAgg { aggs: [percentile_cont(t.y order_by(t.y DESC))] }
+    BatchSimpleAgg { aggs: [percentile_cont($expr1 order_by(t.y DESC))] }
+    └─BatchExchange { order: [], dist: Single }
+      └─BatchProject { exprs: [t.y::Float64 as $expr1, t.y] }
+        └─BatchScan { table: t, columns: [t.y], distribution: SomeShard }
+- sql: |
+    create table t (x int, y varchar);
+    select percentile_disc(1) within group (order by y desc) from t;
+  batch_plan: |
+    BatchSimpleAgg { aggs: [percentile_disc(t.y order_by(t.y DESC))] }
+    └─BatchExchange { order: [], dist: Single }
+      └─BatchScan { table: t, columns: [t.y], distribution: SomeShard }
+- sql: |
+    create table t (x int, y varchar);
+    select mode() within group (order by y desc) from t;
+  batch_plan: |
+    BatchSimpleAgg { aggs: [mode(t.y order_by(t.y DESC))] }
     └─BatchExchange { order: [], dist: Single }
       └─BatchScan { table: t, columns: [t.y], distribution: SomeShard }
+- sql: |
+    create table t (x int, y varchar);
+    select mode(1) within group (order by y desc) from t;
+  binder_error: |-
+    Bind error: failed to bind expression: mode(1)
+
+    Caused by:
+      Invalid input syntax: no arguments are expected in mode agg
diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs
index 121f2f99fbc4d..c1b089560bef4 100644
--- a/src/frontend/src/binder/expr/function.rs
+++ b/src/frontend/src/binder/expr/function.rs
@@ -159,8 +159,14 @@ impl Binder {
                 ))
                 .into());
             }
+            if kind == AggKind::Mode && !f.args.is_empty() {
+                return Err(ErrorCode::InvalidInputSyntax(
+                    "no arguments are expected in mode agg".to_string(),
+                )
+                .into());
+            }
             self.ensure_aggregate_allowed()?;
-            let inputs: Vec<ExprImpl> = if f.within_group.is_some() {
+            let mut inputs: Vec<ExprImpl> = if f.within_group.is_some() {
                 f.within_group
                     .iter()
                     .map(|x| self.bind_function_expr_arg(FunctionArgExpr::Expr(x.expr.clone())))
@@ -173,6 +179,15 @@ impl Binder {
                     .flatten_ok()
                     .try_collect()?
             };
+            if kind == AggKind::PercentileCont {
+                inputs[0] = inputs
+                    .iter()
+                    .exactly_one()
+                    .unwrap()
+                    .clone()
+                    .cast_implicit(DataType::Float64)?;
+            }
+
             if f.distinct {
                 match &kind {
                     AggKind::Count if inputs.is_empty() => {
@@ -280,13 +295,23 @@ impl Binder {
                 .cast_implicit(DataType::Float64)?
                 .fold_const()
             {
-                Ok::<_, RwError>(vec![Literal::new(casted, DataType::Float64)])
+                if casted
+                    .clone()
+                    .is_some_and(|x| !(0.0..=1.0).contains(&Into::<f64>::into(*x.as_float64())))
+                {
+                    Err(ErrorCode::InvalidInputSyntax(format!(
+                        "arg in {} must be between 0 and 1",
+                        kind
+                    ))
+                    .into())
+                } else {
+                    Ok::<_, RwError>(vec![Literal::new(casted, DataType::Float64)])
+                }
             } else {
-                Err(ErrorCode::InvalidInputSyntax(format!(
-                    "arg in {} must be double precision",
-                    kind
-                ))
-                .into())
+                Err(
+                    ErrorCode::InvalidInputSyntax(format!("arg in {} must be float64", kind))
+                        .into(),
+                )
             }
         } else {
             Ok(vec![])
diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs
index daf8d30859e8b..6760376de1d81 100644
--- a/src/frontend/src/expr/agg_call.rs
+++ b/src/frontend/src/expr/agg_call.rs
@@ -69,14 +69,7 @@ impl AggCall {
         // XXX: some special cases that can not be handled by signature map.
            // may return list or struct type
-            (
-                AggKind::Min
-                | AggKind::Max
-                | AggKind::FirstValue
-                | AggKind::PercentileDisc
-                | AggKind::Mode,
-                [input],
-            ) => input.clone(),
+            (AggKind::Min | AggKind::Max | AggKind::FirstValue, [input]) => input.clone(),
             (AggKind::ArrayAgg, [input]) => List(Box::new(input.clone())),
             // functions that are rewritten in the frontend and don't exist in the expr crate
             (AggKind::Avg, [input]) => match input {
@@ -93,7 +86,12 @@ impl AggCall {
                 Float32 | Float64 | Int256 => Float64,
                 _ => return Err(err()),
             },
-            (AggKind::PercentileCont, _) => Float64,
+            // Ordered-Set Aggregation
+            (AggKind::PercentileCont, [input]) => match input {
+                Float64 => Float64,
+                _ => return Err(err()),
+            },
+            (AggKind::PercentileDisc | AggKind::Mode, [input]) => input.clone(),

             // other functions are handled by signature map
             _ => {
diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs
index 6cbbe5dccd3c1..c1a563310af33 100644
--- a/src/frontend/src/optimizer/plan_node/logical_agg.rs
+++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs
@@ -1062,6 +1062,24 @@ fn new_stream_hash_agg(logical: Agg<PlanRef>, vnode_col_idx: Option<usize>) -> StreamHashAgg {

 impl ToStream for LogicalAgg {
     fn to_stream(&self, ctx: &mut ToStreamContext) -> Result<PlanRef> {
+        for agg_call in self.agg_calls() {
+            if matches!(
+                agg_call.agg_kind,
+                AggKind::BitAnd
+                    | AggKind::BitOr
+                    | AggKind::BoolAnd
+                    | AggKind::BoolOr
+                    | AggKind::PercentileCont
+                    | AggKind::PercentileDisc
+                    | AggKind::Mode
+            ) {
+                return Err(ErrorCode::NotImplemented(
+                    format!("{} aggregation in materialized view", agg_call.agg_kind),
+                    None.into(),
+                )
+                .into());
+            }
+        }
         let eowc = ctx.emit_on_window_close();
         let stream_input = self.input().to_stream(ctx)?;
diff --git a/src/tests/sqlsmith/src/sql_gen/types.rs b/src/tests/sqlsmith/src/sql_gen/types.rs
index b53209e6e2432..9f0af59765693 100644
--- a/src/tests/sqlsmith/src/sql_gen/types.rs
+++ b/src/tests/sqlsmith/src/sql_gen/types.rs
@@ -219,6 +219,9 @@ pub(crate) static AGG_FUNC_TABLE: LazyLock<HashMap<DataType, Vec<AggFuncSig>>> =
                     AggKind::BitOr,
                     AggKind::BoolAnd,
                     AggKind::BoolOr,
+                    AggKind::PercentileCont,
+                    AggKind::PercentileDisc,
+                    AggKind::Mode,
                 ]
                 .contains(&func.func)
             })
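
For reference, all three new aggregates follow PostgreSQL's ordered-set aggregate syntax: an optional direct argument followed by WITHIN GROUP (ORDER BY ...). A minimal usage sketch, reusing the values list from the e2e test above (per the logical_agg.rs change, these calls are rejected in materialized views for now and are batch-only):

    -- percentile_cont interpolates (median = 4), percentile_disc picks an actual
    -- input value (median = 3), and mode() returns the most frequent value.
    SELECT
        percentile_cont(0.5) WITHIN GROUP (ORDER BY a) AS median_interpolated,
        percentile_disc(0.5) WITHIN GROUP (ORDER BY a) AS median_discrete,
        mode() WITHIN GROUP (ORDER BY a) AS most_frequent
    FROM (VALUES (1::float8), (3), (5), (7)) t(a);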