Skip to content

Commit

Permalink
fix(agg): fix first_value and last_value to not ignore NULLs (#19332
Browse files Browse the repository at this point in the history
)

Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored Nov 13, 2024
1 parent c93b92b commit c21a771
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 24 deletions.
88 changes: 88 additions & 0 deletions src/expr/impl/src/aggregate/first_last_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2024 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use risingwave_common::types::{Datum, ScalarRefImpl};
use risingwave_common_estimate_size::EstimateSize;
use risingwave_expr::aggregate;
use risingwave_expr::aggregate::AggStateDyn;

/// Aggregate `first_value`: keeps the first input it receives and discards every later one.
///
/// Unlike `min` and `max`, `first_value` doesn't ignore `NULL` values: if the first input is
/// `NULL`, the result is `NULL`.
///
/// ```slt
/// statement ok
/// create table t(v1 int, ts int);
///
/// statement ok
/// insert into t values (null, 1), (2, 2), (null, 3);
///
/// query I
/// select first_value(v1 order by ts) from t;
/// ----
/// NULL
///
/// statement ok
/// drop table t;
/// ```
#[aggregate("first_value(any) -> any")]
fn first_value(state: &mut FirstValueState, input: Option<ScalarRefImpl<'_>>) {
    // The outer `Option` of the state records whether anything was captured yet, so a captured
    // `NULL` (`Some(None)`) is never replaced by a later non-`NULL` input.
    state
        .0
        .get_or_insert_with(|| input.map(|scalar| scalar.into_scalar_impl()));
}

/// State of the `first_value` aggregate.
///
/// The outer `Option` is `None` (the derived `Default`) until `first_value` stores an input.
/// Afterwards it is `Some(datum)`, where `datum` is itself `None` if the stored input was `NULL`.
#[derive(Debug, Clone, Default, EstimateSize)]
struct FirstValueState(Option<Datum>);

// Empty impl: no `AggStateDyn` items are defined here for `FirstValueState`.
impl AggStateDyn for FirstValueState {}

/// Produces the aggregate's output from its state.
///
/// Yields a clone of the captured datum when one has been stored, and `None` when nothing has
/// been stored yet.
impl From<&FirstValueState> for Datum {
    fn from(state: &FirstValueState) -> Self {
        // `Option<Datum>` -> `Datum`: "nothing captured" and "captured `NULL`" both become `None`.
        state.0.clone().flatten()
    }
}

/// Aggregate `last_value`: the incoming `input` replaces the previous state.
///
/// Unlike `min` and `max`, `last_value` doesn't ignore `NULL` values: both the state and the
/// input are `Option<T>`, and the input is returned unchanged whether or not it is `None`.
///
/// ```slt
/// statement ok
/// create table t(v1 int, ts int);
///
/// statement ok
/// insert into t values (null, 1), (2, 2), (null, 3);
///
/// query I
/// select last_value(v1 order by ts) from t;
/// ----
/// NULL
///
/// statement ok
/// drop table t;
/// ```
#[aggregate("last_value(*) -> auto", state = "ref")] // TODO(rc): `last_value(any) -> any`
fn last_value<T>(_: Option<T>, input: Option<T>) -> Option<T> {
// The previous state is discarded; the newest input becomes the state.
input
}

/// Internal aggregate that tracks the most recently seen input.
///
/// Returns `input` as the new state when `retract` is `false`, and returns the existing `state`
/// unchanged when `retract` is `true`.
#[aggregate("internal_last_seen_value(*) -> auto", state = "ref", internal)]
fn internal_last_seen_value<T>(state: T, input: T, retract: bool) -> T {
    if !retract {
        return input;
    }
    // On retraction the existing state is kept as-is.
    state
}
19 changes: 0 additions & 19 deletions src/expr/impl/src/aggregate/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,6 @@ fn max<T: Ord>(state: T, input: T) -> T {
state.max(input)
}

/// Returns the existing `state` and discards the new input.
///
/// NOTE(review): this is the pre-fix definition that the commit deletes from `general.rs`; the
/// commit title says it ignored `NULL`s. Presumably the bare `T` (non-`Option`) signature means
/// `NULL` inputs never reach this function — confirm against the `#[aggregate]` macro.
#[aggregate("first_value(*) -> auto", state = "ref")]
fn first_value<T>(state: T, _: T) -> T {
state
}

/// Returns the new `input` and discards the previous state.
///
/// NOTE(review): this is the pre-fix definition that the commit deletes from `general.rs`; the
/// commit title says it ignored `NULL`s. Presumably the bare `T` (non-`Option`) signature means
/// `NULL` inputs never reach this function — confirm against the `#[aggregate]` macro.
#[aggregate("last_value(*) -> auto", state = "ref")]
fn last_value<T>(_: T, input: T) -> T {
input
}

/// Internal aggregate that tracks the most recently seen input.
///
/// Returns `input` as the new state when `retract` is `false`, and returns the existing `state`
/// unchanged when `retract` is `true`.
#[aggregate("internal_last_seen_value(*) -> auto", state = "ref", internal)]
fn internal_last_seen_value<T>(state: T, input: T, retract: bool) -> T {
    if !retract {
        return input;
    }
    // On retraction the existing state is kept as-is.
    state
}

/// Note the following corner cases:
///
/// ```slt
Expand Down
1 change: 1 addition & 0 deletions src/expr/impl/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod bit_or;
mod bit_xor;
mod bool_and;
mod bool_or;
mod first_last_value;
mod general;
mod jsonb_agg;
mod mode;
Expand Down
14 changes: 9 additions & 5 deletions src/tests/sqlsmith/src/sql_gen/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,15 @@ impl<R: Rng> SqlGenerator<'_, R> {
data_type: AstDataType::SmallInt,
value: self.gen_int(i16::MIN as isize, i16::MAX as isize),
})),
T::Varchar => Expr::Value(Value::SingleQuotedString(
(0..10)
.map(|_| self.rng.sample(Alphanumeric) as char)
.collect(),
)),
T::Varchar => Expr::Cast {
// since we are generating random scalar literal, we should cast it to avoid unknown type
expr: Box::new(Expr::Value(Value::SingleQuotedString(
(0..10)
.map(|_| self.rng.sample(Alphanumeric) as char)
.collect(),
))),
data_type: AstDataType::Varchar,
},
T::Decimal => Expr::Nested(Box::new(Expr::Value(Value::Number(self.gen_float())))),
T::Float64 => Expr::Nested(Box::new(Expr::TypedString {
data_type: AstDataType::Float(None),
Expand Down

0 comments on commit c21a771

Please sign in to comment.