Skip to content

Commit

Permalink
feat(expr, agg): support PERCENTILE_CONT, PERCENTILE_DISC and `MODE` aggregation (#10252)
Browse files Browse the repository at this point in the history

Signed-off-by: Richard Chien <[email protected]>
Co-authored-by: Richard Chien <[email protected]>
Co-authored-by: Noel Kwan <[email protected]>
  • Loading branch information
3 people authored Jun 14, 2023
1 parent e3fe51b commit 90ee868
Show file tree
Hide file tree
Showing 11 changed files with 563 additions and 18 deletions.
33 changes: 33 additions & 0 deletions e2e_test/batch/aggregate/ordered_set_agg.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Each `statement error` record below only asserts that the statement is rejected.

# NOTE(review): presumably rejected because the fraction `p` is a column, not a constant — confirm.
statement error
select p, percentile_cont(p) within group (order by x::float8)
from generate_series(1,5) x,
(values (0::float8),(0.1),(0.25),(0.4),(0.5),(0.6),(0.75),(0.9),(1)) v(p)
group by p order by p;

# NOTE(review): presumably rejected because an array of fractions is not supported — confirm.
statement error
select percentile_cont(array[0,1,0.25,0.75,0.5,1,0.3,0.32,0.35,0.38,0.4]) within group (order by x)
from generate_series(1,6) x;

# NOTE(review): same array-of-fractions form, here for percentile_disc over text — confirm the reason.
statement error
select percentile_disc(array[0.25,0.5,0.75]) within group (order by x)
from unnest('{fred,jim,fred,jack,jill,fred,jill,jim,jim,sheila,jim,sheila}'::text[]) u(x);

# NOTE(review): rejection reason not evident from this file (text input / collation handling?) — confirm.
statement error
select pg_collation_for(percentile_disc(1) within group (order by x collate "POSIX"))
from (values ('fred'),('jim')) v(x);

# Median of 1,3,5,7: the continuous form interpolates between 3 and 5 (expected 4),
# while the discrete form is expected to return the input value 3.
query RR
select
percentile_cont(0.5) within group (order by a),
percentile_disc(0.5) within group (order by a)
from (values(1::float8),(3),(5),(7)) t(a);
----
4 3

# First quartile of 1,3,5,7: position 0.25 * 3 = 0.75, i.e. 0.25 * 1 + 0.75 * 3 = 2.5.
query RR
select
percentile_cont(0.25) within group (order by a),
percentile_disc(0.5) within group (order by a)
from (values(1::float8),(3),(5),(7)) t(a);
----
2.5 3
3 changes: 3 additions & 0 deletions src/expr/src/agg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ mod array_agg;
mod count_star;
mod general;
mod jsonb_agg;
mod mode;
mod percentile_cont;
mod percentile_disc;
mod string_agg;

// wrappers
Expand Down
126 changes: 126 additions & 0 deletions src/expr/src/agg/mode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::array::*;
use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::types::*;
use risingwave_expr_macro::build_aggregate;

use super::Aggregator;
use crate::agg::AggCall;
use crate::Result;

/// Builds the `mode` aggregator; its output type is taken from the call's `return_type`.
#[build_aggregate("mode(*) -> *")]
fn build(agg: AggCall) -> Result<Box<dyn Aggregator>> {
    let aggregator = Mode::new(agg.return_type);
    Ok(Box::new(aggregator))
}

/// Computes the mode, the most frequent value of the aggregated argument (arbitrarily choosing the
/// first one if there are multiple equally-frequent values). The aggregated argument must be of a
/// sortable type.
///
/// Frequencies are counted over runs of consecutive equal values, so equal values must be adjacent
/// in the input. NOTE(review): presumably guaranteed by the `WITHIN GROUP (ORDER BY ...)` ordering
/// being applied before rows reach this aggregator — confirm.
///
/// ```slt
/// query I
/// select mode() within group (order by unnest) from unnest(array[1]);
/// ----
/// 1
///
/// query I
/// select mode() within group (order by unnest) from unnest(array[1,2,2,3,3,4,4,4]);
/// ----
/// 4
///
/// query R
/// select mode() within group (order by unnest) from unnest(array[0.1,0.2,0.2,0.4,0.4,0.3,0.3,0.4]);
/// ----
/// 0.4
///
/// query R
/// select mode() within group (order by unnest) from unnest(array[1,2,2,3,3,4,4,4,3]);
/// ----
/// 3
///
/// query T
/// select mode() within group (order by unnest) from unnest(array['1','2','2','3','3','4','4','4','3']);
/// ----
/// 3
///
/// query I
/// select mode() within group (order by unnest) from unnest(array[]::int[]);
/// ----
/// NULL
/// ```
#[derive(Clone, EstimateSize)]
pub struct Mode {
/// Type reported by `Aggregator::return_type`.
return_type: DataType,
/// Value of the longest run seen so far; `None` until a non-null value has been added.
cur_mode: Datum,
/// Length of the run that produced `cur_mode`.
cur_mode_freq: usize,
/// Most recently added non-null value, i.e. the value of the run currently being counted.
cur_item: Datum,
/// Number of consecutive occurrences of `cur_item` counted so far.
cur_item_freq: usize,
}

impl Mode {
    /// Creates an aggregator with no values recorded yet.
    ///
    /// `return_type` is stored as-is and handed back by `Aggregator::return_type`.
    pub fn new(return_type: DataType) -> Self {
        Self {
            cur_item: None,
            cur_item_freq: 0,
            cur_mode: None,
            cur_mode_freq: 0,
            return_type,
        }
    }

    /// Feeds one input value into the run-length bookkeeping.
    ///
    /// A NULL input is ignored. A value equal to the current item extends the current run;
    /// any other value starts a new run of length one. The mode is replaced only when the
    /// current run becomes strictly longer, so among equally long runs the earliest one wins.
    fn add_datum(&mut self, datum_ref: DatumRef<'_>) {
        match datum_ref.to_owned_datum() {
            // NULLs neither extend nor break a run.
            None => {}
            repeated if repeated == self.cur_item => self.cur_item_freq += 1,
            fresh => {
                self.cur_item = fresh;
                self.cur_item_freq = 1;
            }
        }

        let run_is_longer = self.cur_item_freq > self.cur_mode_freq;
        if run_is_longer {
            self.cur_mode = self.cur_item.clone();
            self.cur_mode_freq = self.cur_item_freq;
        }
    }
}

#[async_trait::async_trait]
impl Aggregator for Mode {
fn return_type(&self) -> DataType {
self.return_type.clone()
}

async fn update_multi(
&mut self,
input: &DataChunk,
start_row_id: usize,
end_row_id: usize,
) -> Result<()> {
let array = input.column_at(0);
for row_id in start_row_id..end_row_id {
self.add_datum(array.value_at(row_id));
}
Ok(())
}

fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
builder.append(self.cur_mode.clone());
Ok(())
}

fn estimated_size(&self) -> usize {
EstimateSize::estimated_size(self)
}
}
132 changes: 132 additions & 0 deletions src/expr/src/agg/percentile_cont.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::array::*;
use risingwave_common::estimate_size::EstimateSize;
use risingwave_common::types::*;
use risingwave_expr_macro::build_aggregate;

use super::Aggregator;
use crate::agg::AggCall;
use crate::Result;

/// Computes the continuous percentile, a value corresponding to the specified fraction within the
/// ordered set of aggregated argument values. This will interpolate between adjacent input items if
/// needed.
///
/// ```slt
/// statement ok
/// create table t(x int, y bigint, z real, w double, v varchar);
///
/// statement ok
/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000');
///
/// query R
/// select percentile_cont(0.45) within group (order by x desc) from t;
/// ----
/// 2.1
///
/// query R
/// select percentile_cont(0.45) within group (order by y desc) from t;
/// ----
/// 21
///
/// query R
/// select percentile_cont(0.45) within group (order by z desc) from t;
/// ----
/// 210
///
/// query R
/// select percentile_cont(0.45) within group (order by w desc) from t;
/// ----
/// 2100
///
/// query R
/// select percentile_cont(NULL) within group (order by w desc) from t;
/// ----
/// NULL
///
/// statement ok
/// drop table t;
/// ```
#[build_aggregate("percentile_cont(float64) -> float64")]
fn build(agg: AggCall) -> Result<Box<dyn Aggregator>> {
let fraction: Option<f64> = agg.direct_args[0]
.literal()
.map(|x| (*x.as_float64()).into());
Ok(Box::new(PercentileCont::new(fraction)))
}

/// Aggregation state for `percentile_cont`: the requested fraction plus every non-null input
/// value collected so far.
#[derive(Clone, EstimateSize)]
pub struct PercentileCont {
/// Requested percentile fraction; `None` makes the aggregate output NULL.
fractions: Option<f64>,
/// Non-null input values in arrival order. The output step indexes into this by rank, so it
/// treats the values as already sorted. NOTE(review): presumably the `WITHIN GROUP (ORDER BY
/// ...)` ordering is applied before rows reach this aggregator — confirm.
data: Vec<f64>,
}

impl PercentileCont {
    /// Creates an aggregator for the given fraction with no values collected yet.
    pub fn new(fractions: Option<f64>) -> Self {
        let data = Vec::new();
        Self { fractions, data }
    }

    /// Records one input value as `f64`; NULL inputs are skipped.
    fn add_datum(&mut self, datum_ref: DatumRef<'_>) {
        let Some(scalar) = datum_ref.to_owned_datum() else {
            return;
        };
        let value: f64 = (*scalar.as_float64()).into();
        self.data.push(value);
    }
}

#[async_trait::async_trait]
impl Aggregator for PercentileCont {
fn return_type(&self) -> DataType {
DataType::Float64
}

async fn update_multi(
&mut self,
input: &DataChunk,
start_row_id: usize,
end_row_id: usize,
) -> Result<()> {
let array = input.column_at(0);
for row_id in start_row_id..end_row_id {
self.add_datum(array.value_at(row_id));
}
Ok(())
}

fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
if let Some(fractions) = self.fractions && !self.data.is_empty() {
let rn = fractions * (self.data.len() - 1) as f64;
let crn = f64::ceil(rn);
let frn = f64::floor(rn);
let result = if crn == frn {
self.data[crn as usize]
} else {
(crn - rn) * self.data[frn as usize]
+ (rn - frn) * self.data[crn as usize]
};
builder.append(Some(ScalarImpl::Float64(result.into())));
} else {
builder.append(Datum::None);
}
Ok(())
}

fn estimated_size(&self) -> usize {
EstimateSize::estimated_size(self)
}
}
Loading

0 comments on commit 90ee868

Please sign in to comment.