diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index 4caff5ca68448..b8aa09754756b 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -242,6 +242,10 @@ impl AggregateExpr for Max { } } + fn evaluate_input(&self, batch: &RecordBatch) -> Result { + self.expr.evaluate(batch) + } + fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(MaxAccumulator { expr: self.expr.clone(), @@ -289,8 +293,12 @@ struct MaxAccumulator { } impl Accumulator for MaxAccumulator { - fn accumulate(&mut self, batch: &RecordBatch, row_index: usize) -> Result<()> { - let array = self.expr.evaluate(batch)?; + fn accumulate( + &mut self, + batch: &RecordBatch, + array: &ArrayRef, + row_index: usize, + ) -> Result<()> { match self.expr.data_type(batch.schema())? { DataType::Int8 => { max_accumulate!(self, array, row_index, Int8Array, Int64, i64) @@ -373,6 +381,10 @@ impl AggregateExpr for Min { } } + fn evaluate_input(&self, batch: &RecordBatch) -> Result { + self.expr.evaluate(batch) + } + fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(MinAccumulator { expr: self.expr.clone(), @@ -420,8 +432,12 @@ struct MinAccumulator { } impl Accumulator for MinAccumulator { - fn accumulate(&mut self, batch: &RecordBatch, row_index: usize) -> Result<()> { - let array = self.expr.evaluate(batch)?; + fn accumulate( + &mut self, + batch: &RecordBatch, + array: &ArrayRef, + row_index: usize, + ) -> Result<()> { match self.expr.data_type(batch.schema())? { DataType::Int8 => { min_accumulate!(self, array, row_index, Int8Array, Int64, i64) @@ -1251,9 +1267,10 @@ mod tests { fn do_max(batch: &RecordBatch) -> Result> { let max = max(col(0)); let accum = max.create_accumulator(); + let input = max.evaluate_input(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { - accum.accumulate(&batch, i)?; + accum.accumulate(&batch, &input, i)?; } accum.get_value() } @@ -1261,9 +1278,10 @@ mod tests { fn do_min(batch: &RecordBatch) -> Result> { let min = min(col(0)); let accum = min.create_accumulator(); + let input = min.evaluate_input(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { - accum.accumulate(&batch, i)?; + accum.accumulate(&batch, &input, i)?; } accum.get_value() }