Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify spilling merge logic in GroupedHashAggregate #12517

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 43 additions & 33 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Hash aggregation

use std::mem::take;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
Expand Down Expand Up @@ -65,6 +66,8 @@ pub(crate) enum ExecutionState {
///
/// See "partial aggregation" discussion on [`GroupedHashAggregateStream`]
SkippingAggregation,
/// Done reading input and producing output by merging spilled data
MergingSpills,
/// All input has been consumed and all groups have been emitted
Done,
}
Expand All @@ -88,9 +91,6 @@ struct SpillState {
/// aggregate_arguments for merging spilled data
merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,

/// GROUP BY expressions for merging spilled data
merging_group_by: PhysicalGroupBy,

// ========================================================================
// STATES:
// Fields changes during execution. Can be buffer, or state flags that
Expand All @@ -99,9 +99,6 @@ struct SpillState {
/// If data has previously been spilled, the locations of the
/// spill files (in Arrow IPC format)
spills: Vec<RefCountedTempFile>,

/// true when streaming merge is in progress
is_stream_merging: bool,
}

/// Tracks if the aggregate should skip partial aggregations
Expand Down Expand Up @@ -514,9 +511,7 @@ impl GroupedHashAggregateStream {
spills: vec![],
spill_expr,
spill_schema: Arc::clone(&agg_schema),
is_stream_merging: false,
merging_aggregate_arguments,
merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
};

// Skip aggregation is supported if:
Expand Down Expand Up @@ -667,6 +662,8 @@ impl Stream for GroupedHashAggregateStream {
}
}

ExecutionState::MergingSpills => return self.input.poll_next_unpin(cx),

ExecutionState::SkippingAggregation => {
match ready!(self.input.poll_next_unpin(cx)) {
Some(Ok(batch)) => {
Expand Down Expand Up @@ -740,26 +737,13 @@ impl GroupedHashAggregateStream {
/// Perform group-by aggregation for the given [`RecordBatch`].
fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
// Evaluate the grouping expressions
let group_by_values = if self.spill_state.is_stream_merging {
evaluate_group_by(&self.spill_state.merging_group_by, &batch)?
} else {
evaluate_group_by(&self.group_by, &batch)?
};
let group_by_values = evaluate_group_by(&self.group_by, &batch)?;

// Evaluate the aggregation expressions.
let input_values = if self.spill_state.is_stream_merging {
evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)?
} else {
evaluate_many(&self.aggregate_arguments, &batch)?
};
let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;

// Evaluate the filter expressions, if any, against the inputs
let filter_values = if self.spill_state.is_stream_merging {
let filter_expressions = vec![None; self.accumulators.len()];
evaluate_optional(&filter_expressions, &batch)?
} else {
evaluate_optional(&self.filter_expressions, &batch)?
};
let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;

for group_values in &group_by_values {
// calculate the group indices for each input row
Expand Down Expand Up @@ -793,9 +777,7 @@ impl GroupedHashAggregateStream {
match self.mode {
AggregateMode::Partial
| AggregateMode::Single
| AggregateMode::SinglePartitioned
if !self.spill_state.is_stream_merging =>
{
| AggregateMode::SinglePartitioned => {
acc.update_batch(
values,
group_indices,
Expand Down Expand Up @@ -887,7 +869,6 @@ impl GroupedHashAggregateStream {
&& batch.num_rows() > 0
&& matches!(self.group_ordering, GroupOrdering::None)
&& !matches!(self.mode, AggregateMode::Partial)
&& !self.spill_state.is_stream_merging
&& self.update_memory_reservation().is_err()
{
// Use input batch (Partial mode) schema for spilling because
Expand Down Expand Up @@ -966,8 +947,7 @@ impl GroupedHashAggregateStream {
let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?;
streams.push(stream);
}
self.spill_state.is_stream_merging = true;
self.input = streaming_merge(
let input = streaming_merge(
streams,
schema,
&self.spill_state.spill_expr,
Expand All @@ -976,8 +956,38 @@ impl GroupedHashAggregateStream {
None,
self.reservation.new_empty(),
)?;
self.input_done = false;
self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
let group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
let filter_expressions = vec![None; self.accumulators.len()];

let group_schema = group_schema(&self.schema, self.group_by.expr.len());
let aggregate_arguments = self.spill_state.merging_aggregate_arguments.clone();

self.input = Box::pin(Self {
schema: Arc::clone(&self.schema),
input,
mode: AggregateMode::Final,
accumulators: take(&mut self.accumulators),
aggregate_arguments: aggregate_arguments.clone(),
filter_expressions,
group_by: PhysicalGroupBy::new_single(self.group_by.expr.clone()),
reservation: self.reservation.new_empty(),
group_values: new_group_values(group_schema)?,
current_group_indices: Default::default(),
exec_state: ExecutionState::ReadingInput,
baseline_metrics: self.baseline_metrics.clone(),
batch_size: self.batch_size,
group_ordering,
input_done: false,
runtime: Arc::clone(&self.runtime),
spill_state: SpillState {
spill_expr: self.spill_state.spill_expr.clone(),
spill_schema: Arc::clone(&self.schema),
merging_aggregate_arguments: aggregate_arguments,
spills: vec![],
},
group_values_soft_limit: None,
skip_aggregation_probe: None,
});
Ok(())
}

Expand All @@ -1002,7 +1012,7 @@ impl GroupedHashAggregateStream {
} else {
// If spill files exist, stream-merge them.
self.update_merged_stream()?;
ExecutionState::ReadingInput
ExecutionState::MergingSpills
};
timer.done();
Ok(())
Expand Down