diff --git a/src/expr/core/src/window_function/states.rs b/src/expr/core/src/window_function/states.rs index 63e9d0f8a6e4b..8b9e9a7d25d19 100644 --- a/src/expr/core/src/window_function/states.rs +++ b/src/expr/core/src/window_function/states.rs @@ -86,6 +86,17 @@ impl WindowStates { } Ok(()) } + + /// Slide all windows forward, until the current key is `curr_key`, ignoring the output and evict hints. + /// After this method, `self.curr_key() == Some(curr_key)`. + /// `curr_key` must exist in the `WindowStates`. + pub fn just_slide_to(&mut self, curr_key: &StateKey) -> Result<()> { + // TODO(rc): with the knowledge of the old output, we can "jump" to the `curr_key` directly for some window function kind + while self.curr_key() != Some(curr_key) { + self.just_slide()?; + } + Ok(()) + } } impl Deref for WindowStates { diff --git a/src/stream/src/executor/over_window/general.rs b/src/stream/src/executor/over_window/general.rs index 23623be6e0f2c..96e6d87c19977 100644 --- a/src/stream/src/executor/over_window/general.rs +++ b/src/stream/src/executor/over_window/general.rs @@ -16,19 +16,15 @@ use std::collections::{btree_map, BTreeMap, HashSet}; use std::marker::PhantomData; use std::ops::RangeInclusive; -use delta_btree_map::{Change, PositionType}; -use itertools::Itertools; +use delta_btree_map::Change; use risingwave_common::array::stream_record::Record; use risingwave_common::array::Op; use risingwave_common::row::RowExt; use risingwave_common::session_config::OverWindowCachePolicy as CachePolicy; use risingwave_common::types::DefaultOrdered; -use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common::util::memcmp_encoding::{self, MemcmpEncoded}; use risingwave_common::util::sort_util::OrderType; -use risingwave_expr::window_function::{ - create_window_state, StateKey, WindowFuncCall, WindowStates, -}; +use risingwave_expr::window_function::{StateKey, WindowFuncCall}; use super::over_partition::{ new_empty_partition_cache, shrink_partition_cache, CacheKey, OverPartition, PartitionCache, @@ -38,7 +34,6 @@ use crate::cache::ManagedLruCache; use crate::common::metrics::MetricsInfo; use crate::consistency::consistency_panic; use crate::executor::monitor::OverWindowMetrics; -use crate::executor::over_window::over_partition::AffectedRange; use crate::executor::prelude::*; /// [`OverWindowExecutor`] consumes retractable input stream and produces window function outputs. @@ -61,7 +56,6 @@ struct ExecutorInner { order_key_data_types: Vec, order_key_order_types: Vec, input_pk_indices: Vec, - input_schema_len: usize, state_key_to_table_sub_pk_proj: Vec, state_table: StateTable, @@ -179,7 +173,6 @@ impl OverWindowExecutor { order_key_data_types, order_key_order_types: args.order_key_order_types, input_pk_indices: input_info.pk_indices, - input_schema_len: input_schema.len(), state_key_to_table_sub_pk_proj, state_table: args.state_table, watermark_sequence: args.watermark_epoch, @@ -350,8 +343,9 @@ impl OverWindowExecutor { ); // Build changes for current partition. - let (part_changes, accessed_range) = - Self::build_changes_for_partition(this, &mut partition, delta).await?; + let (part_changes, accessed_range) = partition + .build_changes(&this.state_table, &this.calls, delta) + .await?; for (key, record) in part_changes { // Build chunk and yield if needed. @@ -452,155 +446,6 @@ impl OverWindowExecutor { } } - async fn build_changes_for_partition( - this: &ExecutorInner, - partition: &mut OverPartition<'_, S>, - mut delta: PartitionDelta, - ) -> StreamExecutorResult<( - BTreeMap>, - Option>, - )> { - let mut part_changes = BTreeMap::new(); - - // Find affected ranges, this also ensures that all rows in the affected ranges are loaded - // into the cache. - let (part_with_delta, affected_ranges) = partition - .find_affected_ranges(&this.state_table, &mut delta) - .await?; - - let snapshot = part_with_delta.snapshot(); - let delta = part_with_delta.delta(); - - // Generate delete changes first, because deletes are skipped during iteration over - // `part_with_delta` in the next step. - for (key, change) in delta { - if change.is_delete() { - part_changes.insert( - key.as_normal_expect().clone(), - Record::Delete { - old_row: snapshot.get(key).unwrap().clone(), - }, - ); - } - } - - let mut accessed_range: Option> = None; - - for AffectedRange { - first_frame_start, - first_curr_key, - last_curr_key, - last_frame_end, - } in affected_ranges - { - assert!(first_frame_start <= first_curr_key); - assert!(first_curr_key <= last_curr_key); - assert!(last_curr_key <= last_frame_end); - assert!(first_frame_start.is_normal()); - assert!(first_curr_key.is_normal()); - assert!(last_curr_key.is_normal()); - assert!(last_frame_end.is_normal()); - - if let Some(accessed_range) = accessed_range.as_mut() { - let min_start = first_frame_start - .as_normal_expect() - .min(accessed_range.start()) - .clone(); - let max_end = last_frame_end - .as_normal_expect() - .max(accessed_range.end()) - .clone(); - *accessed_range = min_start..=max_end; - } else { - accessed_range = Some( - first_frame_start.as_normal_expect().clone() - ..=last_frame_end.as_normal_expect().clone(), - ); - } - - let mut states = - WindowStates::new(this.calls.iter().map(create_window_state).try_collect()?); - - // Populate window states with the affected range of rows. - { - let mut cursor = part_with_delta - .find(first_frame_start) - .expect("first frame start key must exist"); - while { - let (key, row) = cursor - .key_value() - .expect("cursor must be valid until `last_frame_end`"); - - for (call, state) in this.calls.iter().zip_eq_fast(states.iter_mut()) { - // TODO(rc): batch appending - state.append( - key.as_normal_expect().clone(), - row.project(call.args.val_indices()) - .into_owned_row() - .as_inner() - .into(), - ); - } - cursor.move_next(); - - key != last_frame_end - } {} - } - - // Slide to the first affected key. We can safely compare to `Some(first_curr_key)` here - // because it must exist in the states, by the definition of affected range. - while states.curr_key() != Some(first_curr_key.as_normal_expect()) { - states.just_slide()?; - } - let mut curr_key_cursor = part_with_delta.find(first_curr_key).unwrap(); - assert_eq!( - states.curr_key(), - curr_key_cursor.key().map(CacheKey::as_normal_expect) - ); - - // Slide and generate changes. - while { - let (key, row) = curr_key_cursor - .key_value() - .expect("cursor must be valid until `last_curr_key`"); - let output = states.slide_no_evict_hint()?; - let new_row = OwnedRow::new( - row.as_inner() - .iter() - .take(this.input_schema_len) - .cloned() - .chain(output) - .collect(), - ); - - match curr_key_cursor.position() { - PositionType::Ghost => unreachable!(), - PositionType::Snapshot | PositionType::DeltaUpdate => { - // update - let old_row = snapshot.get(key).unwrap().clone(); - if old_row != new_row { - part_changes.insert( - key.as_normal_expect().clone(), - Record::Update { old_row, new_row }, - ); - } - } - PositionType::DeltaInsert => { - // insert - part_changes - .insert(key.as_normal_expect().clone(), Record::Insert { new_row }); - } - } - - curr_key_cursor.move_next(); - - key != last_curr_key - } {} - } - - Ok((part_changes, accessed_range)) - } - #[try_stream(ok = Message, error = StreamExecutorError)] async fn executor_inner(self) { let OverWindowExecutor { diff --git a/src/stream/src/executor/over_window/over_partition.rs b/src/stream/src/executor/over_window/over_partition.rs index f81ba9d89a50e..fcc481acda2ee 100644 --- a/src/stream/src/executor/over_window/over_partition.rs +++ b/src/stream/src/executor/over_window/over_partition.rs @@ -19,16 +19,17 @@ use std::collections::BTreeMap; use std::marker::PhantomData; use std::ops::{Bound, RangeInclusive}; -use delta_btree_map::{Change, DeltaBTreeMap}; +use delta_btree_map::{Change, DeltaBTreeMap, PositionType}; use educe::Educe; use futures_async_stream::for_await; use risingwave_common::array::stream_record::Record; -use risingwave_common::row::{OwnedRow, Row}; +use risingwave_common::row::{OwnedRow, Row, RowExt}; use risingwave_common::session_config::OverWindowCachePolicy as CachePolicy; use risingwave_common::types::{Datum, Sentinelled}; +use risingwave_common::util::iter_util::ZipEqFast; use risingwave_common_estimate_size::collections::EstimatedBTreeMap; use risingwave_expr::window_function::{ - RangeFrameBounds, RowsFrameBounds, StateKey, WindowFuncCall, + create_window_state, RangeFrameBounds, RowsFrameBounds, StateKey, WindowFuncCall, WindowStates, }; use risingwave_storage::store::PrefetchOptions; use risingwave_storage::StateStore; @@ -406,6 +407,157 @@ impl<'a, S: StateStore> OverPartition<'a, S> { .unwrap_or(false) } + /// Build changes for the partition, with the given `delta`. Necessary maintenance of the range + /// cache will be done during this process, like loading rows from the `table` into the cache. + pub async fn build_changes( + &mut self, + table: &StateTable, + calls: &[WindowFuncCall], + mut delta: PartitionDelta, + ) -> StreamExecutorResult<( + BTreeMap>, + Option>, + )> { + let input_schema_len = table.get_data_types().len() - calls.len(); + let mut part_changes = BTreeMap::new(); + + // Find affected ranges, this also ensures that all rows in the affected ranges are loaded + // into the cache. + let (part_with_delta, affected_ranges) = + self.find_affected_ranges(table, &mut delta).await?; + + let snapshot = part_with_delta.snapshot(); + let delta = part_with_delta.delta(); + + // Generate delete changes first, because deletes are skipped during iteration over + // `part_with_delta` in the next step. + for (key, change) in delta { + if change.is_delete() { + part_changes.insert( + key.as_normal_expect().clone(), + Record::Delete { + old_row: snapshot.get(key).unwrap().clone(), + }, + ); + } + } + + let mut accessed_range: Option> = None; + + for AffectedRange { + first_frame_start, + first_curr_key, + last_curr_key, + last_frame_end, + } in affected_ranges + { + assert!(first_frame_start <= first_curr_key); + assert!(first_curr_key <= last_curr_key); + assert!(last_curr_key <= last_frame_end); + assert!(first_frame_start.is_normal()); + assert!(first_curr_key.is_normal()); + assert!(last_curr_key.is_normal()); + assert!(last_frame_end.is_normal()); + + if let Some(accessed_range) = accessed_range.as_mut() { + let min_start = first_frame_start + .as_normal_expect() + .min(accessed_range.start()) + .clone(); + let max_end = last_frame_end + .as_normal_expect() + .max(accessed_range.end()) + .clone(); + *accessed_range = min_start..=max_end; + } else { + accessed_range = Some( + first_frame_start.as_normal_expect().clone() + ..=last_frame_end.as_normal_expect().clone(), + ); + } + + let mut states = + WindowStates::new(calls.iter().map(create_window_state).try_collect()?); + + // Populate window states with the affected range of rows. + { + let mut cursor = part_with_delta + .find(first_frame_start) + .expect("first frame start key must exist"); + while { + let (key, row) = cursor + .key_value() + .expect("cursor must be valid until `last_frame_end`"); + + for (call, state) in calls.iter().zip_eq_fast(states.iter_mut()) { + // TODO(rc): batch appending + // TODO(rc): append not only the arguments but also the old output for optimization + state.append( + key.as_normal_expect().clone(), + row.project(call.args.val_indices()) + .into_owned_row() + .as_inner() + .into(), + ); + } + cursor.move_next(); + + key != last_frame_end + } {} + } + + // Slide to the first affected key. We can safely pass in `first_curr_key` here + // because it definitely exists in the states by the definition of affected range. + states.just_slide_to(first_curr_key.as_normal_expect())?; + let mut curr_key_cursor = part_with_delta.find(first_curr_key).unwrap(); + assert_eq!( + states.curr_key(), + curr_key_cursor.key().map(CacheKey::as_normal_expect) + ); + + // Slide and generate changes. + while { + let (key, row) = curr_key_cursor + .key_value() + .expect("cursor must be valid until `last_curr_key`"); + let output = states.slide_no_evict_hint()?; + let new_row = OwnedRow::new( + row.as_inner() + .iter() + .take(input_schema_len) + .cloned() + .chain(output) + .collect(), + ); + + match curr_key_cursor.position() { + PositionType::Ghost => unreachable!(), + PositionType::Snapshot | PositionType::DeltaUpdate => { + // update + let old_row = snapshot.get(key).unwrap().clone(); + if old_row != new_row { + part_changes.insert( + key.as_normal_expect().clone(), + Record::Update { old_row, new_row }, + ); + } + } + PositionType::DeltaInsert => { + // insert + part_changes + .insert(key.as_normal_expect().clone(), Record::Insert { new_row }); + } + } + + curr_key_cursor.move_next(); + + key != last_curr_key + } {} + } + + Ok((part_changes, accessed_range)) + } + /// Write a change record to state table and cache. /// This function must be called after finding affected ranges, which means the change records /// should never exceed the cached range. @@ -437,7 +589,7 @@ impl<'a, S: StateStore> OverPartition<'a, S> { /// Find all ranges in the partition that are affected by the given delta. /// The returned ranges are guaranteed to be sorted and non-overlapping. All keys in the ranges /// are guaranteed to be cached, which means they should be [`Sentinelled::Normal`]s. - pub async fn find_affected_ranges<'s, 'delta>( + async fn find_affected_ranges<'s, 'delta>( &'s mut self, table: &StateTable, delta: &'delta mut PartitionDelta,