Skip to content

Commit

Permalink
fix(agg): fix embedded UDAF as window function (#18632)
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored Sep 25, 2024
1 parent 9360139 commit ce70a51
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 27 deletions.
66 changes: 66 additions & 0 deletions e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,69 @@ select t.value, sum00(weight) OVER (PARTITION BY value) from (values (1, 1), (nu
----
1 1
3 3

statement ok
drop aggregate sum00;

# https://github.com/risingwavelabs/risingwave/issues/18436

statement ok
CREATE TABLE exam_scores (
score_id int,
exam_id int,
student_id int,
score real,
exam_date timestamp
);

statement ok
INSERT INTO exam_scores (score_id, exam_id, student_id, score, exam_date)
VALUES
(1, 101, 1001, 85.5, '2022-01-10'),
(2, 101, 1002, 92.0, '2022-01-10'),
(3, 101, 1003, 78.5, '2022-01-10'),
(4, 102, 1001, 91.2, '2022-02-15'),
(5, 102, 1003, 88.9, '2022-02-15');

statement ok
create aggregate weighted_avg(value float, weight float) returns float language python as $$
def create_state():
return (0, 0)
def accumulate(state, value, weight):
if value is None or weight is None:
return state
(s, w) = state
s += value * weight
w += weight
return (s, w)
def retract(state, value, weight):
if value is None or weight is None:
return state
(s, w) = state
s -= value * weight
w -= weight
return (s, w)
def finish(state):
(sum, weight) = state
if weight == 0:
return None
else:
return sum / weight
$$;

query
SELECT
*,
weighted_avg(score, 1) OVER (
PARTITION BY "student_id"
ORDER BY "exam_date"
ROWS 2 PRECEDING
) AS "weighted_avg"
FROM exam_scores
ORDER BY "student_id", "exam_date";
----
1 101 1001 85.5 2022-01-10 00:00:00 85.5
4 102 1001 91.2 2022-02-15 00:00:00 88.3499984741211
2 101 1002 92 2022-01-10 00:00:00 92
3 101 1003 78.5 2022-01-10 00:00:00 78.5
5 102 1003 88.9 2022-02-15 00:00:00 83.70000076293945
17 changes: 16 additions & 1 deletion proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,20 @@ message AggCall {
ExprNode scalar = 9;
}

// The aggregation type.
//
// Ideally this should be used to encode the Rust `AggCall::agg_type` field, but historically we
// flattened it into multiple fields in proto `AggCall` - `kind` + `udf` + `scalar`. So this
// `AggType` proto type is only used by `WindowFunction` currently.
message AggType {
AggCall.Kind kind = 1;

// UDF metadata. Only present when the kind is `USER_DEFINED`.
optional UserDefinedFunctionMetadata udf_meta = 8;
// Wrapped scalar expression. Only present when the kind is `WRAP_SCALAR`.
optional ExprNode scalar_expr = 9;
}

message WindowFrame {
enum Type {
TYPE_UNSPECIFIED = 0;
Expand Down Expand Up @@ -562,7 +576,8 @@ message WindowFunction {

oneof type {
GeneralType general = 1;
AggCall.Kind aggregate = 2;
AggCall.Kind aggregate = 2 [deprecated = true]; // Deprecated since we have a new `aggregate2` variant.
AggType aggregate2 = 103;
}
repeated InputRef args = 3;
data.DataType return_type = 4;
Expand Down
39 changes: 35 additions & 4 deletions src/expr/core/src/aggregate/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use risingwave_common::types::{DataType, Datum};
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding::DatumFromProtoExt;
pub use risingwave_pb::expr::agg_call::PbKind as PbAggKind;
use risingwave_pb::expr::{PbAggCall, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata};
use risingwave_pb::expr::{
PbAggCall, PbAggType, PbExprNode, PbInputRef, PbUserDefinedFunctionMetadata,
};

use crate::expr::{
build_from_prost, BoxedExpression, ExpectExt, Expression, LiteralExpression, Token,
Expand Down Expand Up @@ -65,7 +67,7 @@ pub struct AggCall {

impl AggCall {
pub fn from_protobuf(agg_call: &PbAggCall) -> Result<Self> {
let agg_type = AggType::from_protobuf(
let agg_type = AggType::from_protobuf_flatten(
agg_call.get_kind()?,
agg_call.udf.as_ref(),
agg_call.scalar.as_ref(),
Expand Down Expand Up @@ -160,7 +162,7 @@ impl<Iter: Iterator<Item = Token>> Parser<Iter> {
self.tokens.next(); // Consume the RParen

AggCall {
agg_type: AggType::from_protobuf(func, None, None).unwrap(),
agg_type: AggType::from_protobuf_flatten(func, None, None).unwrap(),
args: AggArgs {
data_types: children.iter().map(|(_, ty)| ty.clone()).collect(),
val_indices: children.iter().map(|(idx, _)| *idx).collect(),
Expand Down Expand Up @@ -260,7 +262,7 @@ impl From<PbAggKind> for AggType {
}

impl AggType {
pub fn from_protobuf(
pub fn from_protobuf_flatten(
pb_kind: PbAggKind,
user_defined: Option<&PbUserDefinedFunctionMetadata>,
scalar: Option<&PbExprNode>,
Expand All @@ -286,6 +288,35 @@ impl AggType {
Self::WrapScalar(_) => PbAggKind::WrapScalar,
}
}

pub fn from_protobuf(pb_type: &PbAggType) -> Result<Self> {
match PbAggKind::try_from(pb_type.kind).context("no such aggregate function type")? {
PbAggKind::Unspecified => bail!("Unrecognized agg."),
PbAggKind::UserDefined => Ok(AggType::UserDefined(pb_type.get_udf_meta()?.clone())),
PbAggKind::WrapScalar => Ok(AggType::WrapScalar(pb_type.get_scalar_expr()?.clone())),
kind => Ok(AggType::Builtin(kind)),
}
}

pub fn to_protobuf(&self) -> PbAggType {
match self {
Self::Builtin(kind) => PbAggType {
kind: *kind as _,
udf_meta: None,
scalar_expr: None,
},
Self::UserDefined(udf_meta) => PbAggType {
kind: PbAggKind::UserDefined as _,
udf_meta: Some(udf_meta.clone()),
scalar_expr: None,
},
Self::WrapScalar(scalar_expr) => PbAggType {
kind: PbAggKind::WrapScalar as _,
udf_meta: None,
scalar_expr: Some(scalar_expr.clone()),
},
}
}
}

/// Macros to generate match arms for `AggType`.
Expand Down
12 changes: 7 additions & 5 deletions src/expr/core/src/window_function/kind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use anyhow::Context;
use parse_display::{Display, FromStr};
use risingwave_common::bail;

Expand Down Expand Up @@ -51,11 +52,12 @@ impl WindowFuncKind {
Ok(PbGeneralType::Lead) => Self::Lead,
Err(_) => bail!("no such window function type"),
},
PbType::Aggregate(agg_type) => match PbAggKind::try_from(*agg_type) {
// TODO(runji): support UDAF and wrapped scalar functions
Ok(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type, None, None)?),
Err(_) => bail!("no such aggregate function type"),
},
PbType::Aggregate(kind) => Self::Aggregate(AggType::from_protobuf_flatten(
PbAggKind::try_from(*kind).context("no such aggregate function type")?,
None,
None,
)?),
PbType::Aggregate2(agg_type) => Self::Aggregate(AggType::from_protobuf(agg_type)?),
};
Ok(kind)
}
Expand Down
44 changes: 30 additions & 14 deletions src/expr/impl/src/window_function/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::{bail, must_match};
use risingwave_common_estimate_size::{EstimateSize, KvSize};
use risingwave_expr::aggregate::{
AggCall, AggType, AggregateFunction, AggregateState as AggImplState, BoxedAggregateFunction,
build_append_only, AggCall, AggType, AggregateFunction, AggregateState as AggImplState,
BoxedAggregateFunction,
};
use risingwave_expr::sig::FUNCTION_REGISTRY;
use risingwave_expr::window_function::{
Expand Down Expand Up @@ -63,19 +64,34 @@ pub(super) fn new(call: &WindowFuncCall) -> Result<BoxedWindowState> {
distinct: false,
direct_args: vec![],
};
// TODO(runji): support UDAF and wrapped scalar function
let agg_kind = must_match!(agg_type, AggType::Builtin(agg_kind) => agg_kind);
let agg_func_sig = FUNCTION_REGISTRY
.get(*agg_kind, &arg_data_types, &call.return_type)
.expect("the agg func must exist");
let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state()?;
(AggImpl::Incremental(init_state), true)
} else {
(AggImpl::Full, false)
};

let (agg_func, agg_impl, enable_delta) = match agg_type {
AggType::Builtin(kind) => {
let agg_func_sig = FUNCTION_REGISTRY
.get(*kind, &arg_data_types, &call.return_type)
.expect("the agg func must exist");
let agg_func = agg_func_sig.build_aggregate(&agg_call)?;
let (agg_impl, enable_delta) =
if agg_func_sig.is_retractable() && call.frame.exclusion.is_no_others() {
let init_state = agg_func.create_state()?;
(AggImpl::Incremental(init_state), true)
} else {
(AggImpl::Full, false)
};
(agg_func, agg_impl, enable_delta)
}
AggType::UserDefined(_) => {
// TODO(rc): utilize `retract` method of embedded UDAF to do incremental aggregation
let agg_func = build_append_only(&agg_call)?;
(agg_func, AggImpl::Full, false)
}
AggType::WrapScalar(_) => {
// we have to feed the wrapped scalar function with all the rows in the window,
// instead of doing incremental aggregation
let agg_func = build_append_only(&agg_call)?;
(agg_func, AggImpl::Full, false)
}
};

let this = match &call.frame.bounds {
FrameBounds::Rows(frame_bounds) => Box::new(AggregateState {
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/binder/expr/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ impl Binder {
None
};

let agg_type = if let Some(wrapped_agg_type) = wrapped_agg_type {
Some(wrapped_agg_type)
let agg_type = if wrapped_agg_type.is_some() {
wrapped_agg_type
} else if let Some(ref udf) = udf
&& udf.kind.is_aggregate()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl PlanWindowFunction {
DenseRank => PbType::General(PbGeneralType::DenseRank as _),
Lag => PbType::General(PbGeneralType::Lag as _),
Lead => PbType::General(PbGeneralType::Lead as _),
Aggregate(agg_type) => PbType::Aggregate(agg_type.to_protobuf_simple() as _),
Aggregate(agg_type) => PbType::Aggregate2(agg_type.to_protobuf()),
};

PbWindowFunction {
Expand Down

0 comments on commit ce70a51

Please sign in to comment.