Skip to content

Commit

Permalink
ARROW-6736: [Rust] [DataFusion] Evaluate the input to the aggregate e…
Browse files Browse the repository at this point in the history
…xpression just once per batch

The current implementation of aggregate expressions in the new physical plan had a flaw where the input to the aggregate expression was repeatedly being evaluated (once per row instead of once per batch). This PR fixes this.

Closes #5542 from andygrove/ARROW-6736 and squashes the following commits:

f0fadaf <Andy Grove> Evaluate the input to the aggregate expression just once per batch

Authored-by: Andy Grove <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
  • Loading branch information
andygrove committed Oct 4, 2019
1 parent a98a61d commit 399ab8f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
15 changes: 12 additions & 3 deletions rust/datafusion/src/execution/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ impl AggregateExpr for Sum {
}
}

fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
self.expr.evaluate(batch)
}

fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
Rc::new(RefCell::new(SumAccumulator {
expr: self.expr.clone(),
Expand Down Expand Up @@ -149,8 +153,12 @@ struct SumAccumulator {
}

impl Accumulator for SumAccumulator {
fn accumulate(&mut self, batch: &RecordBatch, row_index: usize) -> Result<()> {
let array = self.expr.evaluate(batch)?;
fn accumulate(
&mut self,
batch: &RecordBatch,
array: &ArrayRef,
row_index: usize,
) -> Result<()> {
match self.expr.data_type(batch.schema())? {
DataType::Int8 => {
sum_accumulate!(self, array, row_index, Int8Array, Int64, i64)
Expand Down Expand Up @@ -697,9 +705,10 @@ mod tests {
fn do_sum(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
let sum = sum(col(0));
let accum = sum.create_accumulator();
let input = sum.evaluate_input(batch)?;
let mut accum = accum.borrow_mut();
for i in 0..batch.num_rows() {
accum.accumulate(&batch, i)?;
accum.accumulate(&batch, &input, i)?;
}
accum.get_value()
}
Expand Down
29 changes: 26 additions & 3 deletions rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ impl BatchIterator for GroupedHashAggregateIterator {
.map(|expr| expr.evaluate(&batch))
.collect::<Result<Vec<_>>>()?;

// evaluate the inputs to the aggregate expressions for this batch
let aggr_input_values = self
.aggr_expr
.iter()
.map(|expr| expr.evaluate_input(&batch))
.collect::<Result<Vec<_>>>()?;

// iterate over each row in the batch
for row in 0..batch.num_rows() {
// create grouping key for this row
Expand All @@ -290,7 +297,10 @@ impl BatchIterator for GroupedHashAggregateIterator {
Some(accumulators) => {
let _ = accumulators
.iter()
.map(|accum| accum.borrow_mut().accumulate(&batch, row))
.zip(aggr_input_values.iter())
.map(|(accum, input)| {
accum.borrow_mut().accumulate(&batch, input, row)
})
.collect::<Result<Vec<_>>>()?;
Ok(true)
}
Expand All @@ -306,7 +316,10 @@ impl BatchIterator for GroupedHashAggregateIterator {

let _ = accumulators
.iter()
.map(|accum| accum.borrow_mut().accumulate(&batch, row))
.zip(aggr_input_values.iter())
.map(|(accum, input)| {
accum.borrow_mut().accumulate(&batch, input, row)
})
.collect::<Result<Vec<_>>>()?;

map.insert(key.clone(), accumulators);
Expand Down Expand Up @@ -511,11 +524,21 @@ impl BatchIterator for HashAggregateIterator {

// iterate over input and perform aggregation
while let Some(batch) = input.next()? {
// evaluate the inputs to the aggregate expressions for this batch
let aggr_input_values = self
.aggr_expr
.iter()
.map(|expr| expr.evaluate_input(&batch))
.collect::<Result<Vec<_>>>()?;

// iterate over each row in the batch
for row in 0..batch.num_rows() {
let _ = accumulators
.iter()
.map(|accum| accum.borrow_mut().accumulate(&batch, row))
.zip(aggr_input_values.iter())
.map(|(accum, input)| {
accum.borrow_mut().accumulate(&batch, input, row)
})
.collect::<Result<Vec<_>>>()?;
}
}
Expand Down
9 changes: 8 additions & 1 deletion rust/datafusion/src/execution/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ pub trait AggregateExpr: Send + Sync {
fn name(&self) -> String;
/// Get the data type of this expression, given the schema of the input
fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
/// Evaluate the expressioon being aggregated
fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef>;
/// Create an accumulator for this aggregate expression
fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>>;
/// Create an aggregate expression for combining the results of accumulators from partitions.
Expand All @@ -76,7 +78,12 @@ pub trait AggregateExpr: Send + Sync {
/// Aggregate accumulator
pub trait Accumulator {
/// Update the accumulator based on a row in a batch
fn accumulate(&mut self, batch: &RecordBatch, row_index: usize) -> Result<()>;
fn accumulate(
&mut self,
batch: &RecordBatch,
input: &ArrayRef,
row_index: usize,
) -> Result<()>;
/// Get the final value for the accumulator
fn get_value(&self) -> Result<Option<ScalarValue>>;
}
Expand Down

0 comments on commit 399ab8f

Please sign in to comment.