diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 92bb22c950bf8..c606d82db8cfc 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation; use crate::execution::physical_plan::common; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::expressions::{ - BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, Sum, + Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, Sum, }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::merge::MergeExec; @@ -333,6 +333,9 @@ impl ExecutionContext { "sum" => Ok(Arc::new(Sum::new( self.create_physical_expr(&args[0], input_schema)?, ))), + "avg" => Ok(Arc::new(Avg::new( + self.create_physical_expr(&args[0], input_schema)?, + ))), "max" => Ok(Arc::new(Max::new( self.create_physical_expr(&args[0], input_schema)?, ))), @@ -636,6 +639,20 @@ mod tests { Ok(()) } + #[test] + fn aggregate_avg() -> Result<()> { + let results = execute("SELECT AVG(c1), AVG(c2) FROM test", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["1.5,5.5"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + + Ok(()) + } + #[test] fn aggregate_max() -> Result<()> { let results = execute("SELECT MAX(c1), MAX(c2) FROM test", 4)?; @@ -678,6 +695,20 @@ mod tests { Ok(()) } + #[test] + fn aggregate_grouped_avg() -> Result<()> { + let results = execute("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + + Ok(()) + } + #[test] fn aggregate_grouped_max() -> Result<()> { let results = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4)?; diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index b8aa09754756b..28940fb9d36dc 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -208,6 +208,129 @@ pub fn sum(expr: Arc) -> Arc { Arc::new(Sum::new(expr)) } +/// AVG aggregate expression +pub struct Avg { + expr: Arc, +} + +impl Avg { + /// Create a new AVG aggregate function + pub fn new(expr: Arc) -> Self { + Self { expr } + } +} + +impl AggregateExpr for Avg { + fn name(&self) -> String { + "AVG".to_string() + } + + fn data_type(&self, input_schema: &Schema) -> Result { + match self.expr.data_type(input_schema)? { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "AVG does not support {:?}", + other + ))), + } + } + + fn evaluate_input(&self, batch: &RecordBatch) -> Result { + self.expr.evaluate(batch) + } + + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(AvgAccumulator { + expr: self.expr.clone(), + sum: None, + count: None, + })) + } + + fn create_combiner(&self, column_index: usize) -> Arc { + Arc::new(Avg::new(Arc::new(Column::new(column_index)))) + } +} + +macro_rules! avg_accumulate { + ($SELF:ident, $ARRAY:ident, $ROW_INDEX:expr, $ARRAY_TYPE:ident) => {{ + if let Some(array) = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>() { + if $ARRAY.is_valid($ROW_INDEX) { + let value = array.value($ROW_INDEX); + match ($SELF.sum, $SELF.count) { + (Some(sum), Some(count)) => { + $SELF.sum = Some(sum + value as f64); + $SELF.count = Some(count + 1); + } + _ => { + $SELF.sum = Some(value as f64); + $SELF.count = Some(1); + } + }; + } + Ok(()) + } else { + Err(ExecutionError::General( + "Failed to downcast array".to_string(), + )) + } + }}; +} +struct AvgAccumulator { + expr: Arc, + sum: Option, + count: Option, +} + +impl Accumulator for AvgAccumulator { + fn accumulate( + &mut self, + batch: &RecordBatch, + array: &ArrayRef, + row_index: usize, + ) -> Result<()> { + match self.expr.data_type(batch.schema())? { + DataType::Int8 => avg_accumulate!(self, array, row_index, Int8Array), + DataType::Int16 => avg_accumulate!(self, array, row_index, Int16Array), + DataType::Int32 => avg_accumulate!(self, array, row_index, Int32Array), + DataType::Int64 => avg_accumulate!(self, array, row_index, Int64Array), + DataType::UInt8 => avg_accumulate!(self, array, row_index, UInt8Array), + DataType::UInt16 => avg_accumulate!(self, array, row_index, UInt16Array), + DataType::UInt32 => avg_accumulate!(self, array, row_index, UInt32Array), + DataType::UInt64 => avg_accumulate!(self, array, row_index, UInt64Array), + DataType::Float32 => avg_accumulate!(self, array, row_index, Float32Array), + DataType::Float64 => avg_accumulate!(self, array, row_index, Float64Array), + other => Err(ExecutionError::General(format!( + "AVG does not support {:?}", + other + ))), + } + } + + fn get_value(&self) -> Result> { + match (self.sum, self.count) { + (Some(sum), Some(count)) => { + Ok(Some(ScalarValue::Float64(sum / count as f64))) + } + _ => Ok(None), + } + } +} + +/// Create a avg expression +pub fn avg(expr: Arc) -> Arc { + Arc::new(Avg::new(expr)) +} + /// MAX aggregate expression pub struct Max { expr: Arc, @@ -1000,6 +1123,20 @@ mod tests { Ok(()) } + #[test] + fn avg_contract() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let avg = avg(col(0)); + assert_eq!("AVG".to_string(), avg.name()); + assert_eq!(DataType::Float64, avg.data_type(&schema)?); + + let combiner = avg.create_combiner(0); + assert_eq!("AVG".to_string(), combiner.name()); + assert_eq!(DataType::Float64, combiner.data_type(&schema)?); + + Ok(()) + } #[test] fn sum_i32() -> Result<()> { @@ -1013,6 +1150,18 @@ mod tests { Ok(()) } + #[test] + fn avg_i32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64))); + + Ok(()) + } + #[test] fn max_i32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -1049,6 +1198,18 @@ mod tests { Ok(()) } + #[test] + fn avg_i32_with_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3.25))); + + Ok(()) + } + #[test] fn max_i32_with_nulls() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -1109,6 +1270,18 @@ mod tests { Ok(()) } + #[test] + fn avg_i32_all_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![None, None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_avg(&batch)?, None); + + Ok(()) + } + #[test] fn sum_u32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); @@ -1121,6 +1294,18 @@ mod tests { Ok(()) } + #[test] + fn avg_u32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + + let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64))); + + Ok(()) + } + #[test] fn max_u32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); @@ -1157,6 +1342,18 @@ mod tests { Ok(()) } + #[test] + fn avg_f32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + + let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64))); + + Ok(()) + } + #[test] fn max_f32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); @@ -1193,6 +1390,18 @@ mod tests { Ok(()) } + #[test] + fn avg_f64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3_f64))); + + Ok(()) + } + #[test] fn max_f64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); @@ -1296,4 +1505,15 @@ mod tests { } accum.get_value() } + + fn do_avg(batch: &RecordBatch) -> Result> { + let avg = avg(col(0)); + let accum = avg.create_accumulator(); + let input = avg.evaluate_input(batch)?; + let mut accum = accum.borrow_mut(); + for i in 0..batch.num_rows() { + accum.accumulate(&batch, &input, i)?; + } + accum.get_value() + } } diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index 34235e3f8ad2a..bc4dc1c80c7e2 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -179,6 +179,22 @@ pub fn format_batch(batch: &RecordBatch) -> Vec { .unwrap() .value(row_index) )), + DataType::Float32 => s.push_str(&format!( + "{:?}", + array + .as_any() + .downcast_ref::() + .unwrap() + .value(row_index) + )), + DataType::Float64 => s.push_str(&format!( + "{:?}", + array + .as_any() + .downcast_ref::() + .unwrap() + .value(row_index) + )), _ => s.push('?'), } }