Skip to content

Commit

Permalink
refactor(over window): move build_changes to OverPartition for be…
Browse files Browse the repository at this point in the history
…tter modularity (#18846)

Signed-off-by: Richard Chien <[email protected]>
  • Loading branch information
stdrc authored Oct 10, 2024
1 parent b12460d commit 7720bee
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 164 deletions.
11 changes: 11 additions & 0 deletions src/expr/core/src/window_function/states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ impl WindowStates {
}
Ok(())
}

/// Advances every window state until `curr_key` becomes the current key.
///
/// Outputs and evict hints produced while sliding are ignored. When this
/// method returns `Ok(())`, `self.curr_key() == Some(curr_key)` holds.
///
/// `curr_key` must exist in the `WindowStates`.
pub fn just_slide_to(&mut self, curr_key: &StateKey) -> Result<()> {
    // TODO(rc): with the knowledge of the old output, we can "jump" to the `curr_key` directly for some window function kind
    let target = Some(curr_key);
    loop {
        // Stop as soon as the current key matches; this is checked before
        // every slide, so no slide happens if we are already at the target.
        if self.curr_key() == target {
            return Ok(());
        }
        self.just_slide()?;
    }
}
}

impl Deref for WindowStates {
Expand Down
165 changes: 5 additions & 160 deletions src/stream/src/executor/over_window/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,15 @@ use std::collections::{btree_map, BTreeMap, HashSet};
use std::marker::PhantomData;
use std::ops::RangeInclusive;

use delta_btree_map::{Change, PositionType};
use itertools::Itertools;
use delta_btree_map::Change;
use risingwave_common::array::stream_record::Record;
use risingwave_common::array::Op;
use risingwave_common::row::RowExt;
use risingwave_common::session_config::OverWindowCachePolicy as CachePolicy;
use risingwave_common::types::DefaultOrdered;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_common::util::memcmp_encoding::{self, MemcmpEncoded};
use risingwave_common::util::sort_util::OrderType;
use risingwave_expr::window_function::{
create_window_state, StateKey, WindowFuncCall, WindowStates,
};
use risingwave_expr::window_function::{StateKey, WindowFuncCall};

use super::over_partition::{
new_empty_partition_cache, shrink_partition_cache, CacheKey, OverPartition, PartitionCache,
Expand All @@ -38,7 +34,6 @@ use crate::cache::ManagedLruCache;
use crate::common::metrics::MetricsInfo;
use crate::consistency::consistency_panic;
use crate::executor::monitor::OverWindowMetrics;
use crate::executor::over_window::over_partition::AffectedRange;
use crate::executor::prelude::*;

/// [`OverWindowExecutor`] consumes retractable input stream and produces window function outputs.
Expand All @@ -61,7 +56,6 @@ struct ExecutorInner<S: StateStore> {
order_key_data_types: Vec<DataType>,
order_key_order_types: Vec<OrderType>,
input_pk_indices: Vec<usize>,
input_schema_len: usize,
state_key_to_table_sub_pk_proj: Vec<usize>,

state_table: StateTable<S>,
Expand Down Expand Up @@ -179,7 +173,6 @@ impl<S: StateStore> OverWindowExecutor<S> {
order_key_data_types,
order_key_order_types: args.order_key_order_types,
input_pk_indices: input_info.pk_indices,
input_schema_len: input_schema.len(),
state_key_to_table_sub_pk_proj,
state_table: args.state_table,
watermark_sequence: args.watermark_epoch,
Expand Down Expand Up @@ -350,8 +343,9 @@ impl<S: StateStore> OverWindowExecutor<S> {
);

// Build changes for current partition.
let (part_changes, accessed_range) =
Self::build_changes_for_partition(this, &mut partition, delta).await?;
let (part_changes, accessed_range) = partition
.build_changes(&this.state_table, &this.calls, delta)
.await?;

for (key, record) in part_changes {
// Build chunk and yield if needed.
Expand Down Expand Up @@ -452,155 +446,6 @@ impl<S: StateStore> OverWindowExecutor<S> {
}
}

/// Builds the output changes of a single partition caused by `delta`.
///
/// The work is done in three stages:
/// 1. `partition.find_affected_ranges` is asked for the affected ranges and a
///    view of the partition with the delta applied (`part_with_delta`).
/// 2. A `Record::Delete` is emitted for every delete change in the delta.
/// 3. For each affected range, fresh window states are created, fed with the
///    rows from `first_frame_start` through `last_frame_end`, slid to
///    `first_curr_key`, and then slid through `last_curr_key` while emitting
///    `Record::Insert` / `Record::Update` records.
///
/// Returns the changes keyed by `StateKey`, plus the smallest single range
/// covering `first_frame_start..=last_frame_end` of every affected range
/// (`None` when there is no affected range).
///
/// # Errors
/// Propagates errors from `find_affected_ranges`, `create_window_state`,
/// `just_slide` and `slide_no_evict_hint`.
///
/// # Panics
/// Panics when one of the asserted invariants on the affected ranges does not
/// hold, or when an expected key is missing from the snapshot or the cursor.
async fn build_changes_for_partition(
this: &ExecutorInner<S>,
partition: &mut OverPartition<'_, S>,
mut delta: PartitionDelta,
) -> StreamExecutorResult<(
BTreeMap<StateKey, Record<OwnedRow>>,
Option<RangeInclusive<StateKey>>,
)> {
let mut part_changes = BTreeMap::new();

// Find affected ranges, this also ensures that all rows in the affected ranges are loaded
// into the cache.
let (part_with_delta, affected_ranges) = partition
.find_affected_ranges(&this.state_table, &mut delta)
.await?;

// From here on `delta` shadows the parameter and refers to the delta held by
// `part_with_delta`.
let snapshot = part_with_delta.snapshot();
let delta = part_with_delta.delta();

// Generate delete changes first, because deletes are skipped during iteration over
// `part_with_delta` in the next step.
for (key, change) in delta {
if change.is_delete() {
part_changes.insert(
key.as_normal_expect().clone(),
Record::Delete {
// The row being deleted is looked up in the snapshot; it is expected to exist.
old_row: snapshot.get(key).unwrap().clone(),
},
);
}
}

// Covering range of all `first_frame_start..=last_frame_end` seen so far.
let mut accessed_range: Option<RangeInclusive<StateKey>> = None;

for AffectedRange {
first_frame_start,
first_curr_key,
last_curr_key,
last_frame_end,
} in affected_ranges
{
// Sanity checks: the four keys must be ordered and must all be normal keys.
assert!(first_frame_start <= first_curr_key);
assert!(first_curr_key <= last_curr_key);
assert!(last_curr_key <= last_frame_end);
assert!(first_frame_start.is_normal());
assert!(first_curr_key.is_normal());
assert!(last_curr_key.is_normal());
assert!(last_frame_end.is_normal());

// Extend the accessed range so that it also covers this affected range.
if let Some(accessed_range) = accessed_range.as_mut() {
let min_start = first_frame_start
.as_normal_expect()
.min(accessed_range.start())
.clone();
let max_end = last_frame_end
.as_normal_expect()
.max(accessed_range.end())
.clone();
*accessed_range = min_start..=max_end;
} else {
accessed_range = Some(
first_frame_start.as_normal_expect().clone()
..=last_frame_end.as_normal_expect().clone(),
);
}

// One fresh window state per window function call.
let mut states =
WindowStates::new(this.calls.iter().map(create_window_state).try_collect()?);

// Populate window states with the affected range of rows.
{
let mut cursor = part_with_delta
.find(first_frame_start)
.expect("first frame start key must exist");
// `while { body; cond } {}` acts as a do-while loop: the body runs at least once and
// the loop ends right after the row keyed `last_frame_end` has been appended.
while {
let (key, row) = cursor
.key_value()
.expect("cursor must be valid until `last_frame_end`");

for (call, state) in this.calls.iter().zip_eq_fast(states.iter_mut()) {
// TODO(rc): batch appending
// Only the argument columns of this call are projected into its state.
state.append(
key.as_normal_expect().clone(),
row.project(call.args.val_indices())
.into_owned_row()
.as_inner()
.into(),
);
}
cursor.move_next();

key != last_frame_end
} {}
}

// Slide to the first affected key. We can safely compare to `Some(first_curr_key)` here
// because it must exist in the states, by the definition of affected range.
while states.curr_key() != Some(first_curr_key.as_normal_expect()) {
states.just_slide()?;
}
let mut curr_key_cursor = part_with_delta.find(first_curr_key).unwrap();
// The window states and the cursor must now point at the same key.
assert_eq!(
states.curr_key(),
curr_key_cursor.key().map(CacheKey::as_normal_expect)
);

// Slide and generate changes.
// Same do-while shape as above; the loop ends after the row keyed `last_curr_key`.
while {
let (key, row) = curr_key_cursor
.key_value()
.expect("cursor must be valid until `last_curr_key`");
let output = states.slide_no_evict_hint()?;
// New row = first `input_schema_len` columns of the cached row followed by the freshly
// computed window function outputs.
// NOTE(review): presumably the cached row may carry previous outputs after the input
// columns, hence the `take` — confirm against the cache layout.
let new_row = OwnedRow::new(
row.as_inner()
.iter()
.take(this.input_schema_len)
.cloned()
.chain(output)
.collect(),
);

match curr_key_cursor.position() {
// NOTE(review): treated as impossible here — presumably a valid cursor inside an
// affected range never sits on a ghost position; confirm in `delta_btree_map`.
PositionType::Ghost => unreachable!(),
PositionType::Snapshot | PositionType::DeltaUpdate => {
// update
// Emit an update only when the row actually changed.
let old_row = snapshot.get(key).unwrap().clone();
if old_row != new_row {
part_changes.insert(
key.as_normal_expect().clone(),
Record::Update { old_row, new_row },
);
}
}
PositionType::DeltaInsert => {
// insert
part_changes
.insert(key.as_normal_expect().clone(), Record::Insert { new_row });
}
}

curr_key_cursor.move_next();

key != last_curr_key
} {}
}

Ok((part_changes, accessed_range))
}

#[try_stream(ok = Message, error = StreamExecutorError)]
async fn executor_inner(self) {
let OverWindowExecutor {
Expand Down
Loading

0 comments on commit 7720bee

Please sign in to comment.