diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index dc54b9978ec33..f07c8b987faab 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation; use crate::execution::physical_plan::common; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::expressions::{ - BinaryExpr, CastExpr, Column, Literal, Sum, + BinaryExpr, CastExpr, Column, Count, Literal, Sum, }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::merge::MergeExec; @@ -333,6 +333,9 @@ impl ExecutionContext { "sum" => Ok(Arc::new(Sum::new( self.create_physical_expr(&args[0], input_schema)?, ))), + "count" => Ok(Arc::new(Count::new( + self.create_physical_expr(&args[0], input_schema)?, + ))), other => Err(ExecutionError::NotImplemented(format!( "Unsupported aggregate function '{}'", other @@ -641,6 +644,45 @@ mod tests { Ok(()) } + #[test] + fn count_basic() -> Result<()> { + let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["10,10"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + Ok(()) + } + + #[test] + fn count_partitioned() -> Result<()> { + let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["40,40"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + Ok(()) + } + + #[test] + fn count_aggregated() -> Result<()> { + let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected = vec!["0,10", "1,10", "2,10", "3,10"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + Ok(()) + } + /// Execute SQL and return results fn execute(sql: &str, partition_count: usize) -> Result> { let tmp_dir = TempDir::new("execute")?; diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index f63b40cb688ff..8b9f79171fd29 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -147,6 +147,7 @@ macro_rules! sum_accumulate { } }}; } + struct SumAccumulator { expr: Arc, sum: Option, @@ -207,6 +208,70 @@ pub fn sum(expr: Arc) -> Arc { Arc::new(Sum::new(expr)) } +/// COUNT aggregate expression +/// Returns the amount of non-null values of the given expression. +pub struct Count { + expr: Arc, +} + +impl Count { + /// Create a new COUNT aggregate function. + pub fn new(expr: Arc) -> Self { + Self { expr: expr } + } +} + +impl AggregateExpr for Count { + fn name(&self) -> String { + "COUNT".to_string() + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::UInt64) + } + + fn evaluate_input(&self, batch: &RecordBatch) -> Result { + self.expr.evaluate(batch) + } + + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(CountAccumulator { + count: 0, + })) + } + + fn create_combiner(&self, column_index: usize) -> Arc { + Arc::new(Sum::new(Arc::new(Column::new(column_index)))) + } +} + +struct CountAccumulator { + count: u64, +} + +impl Accumulator for CountAccumulator { + fn accumulate( + &mut self, + _batch: &RecordBatch, + array: &ArrayRef, + row_index: usize, + ) -> Result<()> { + if array.is_valid(row_index) { + self.count += 1; + } + Ok(()) + } + + fn get_value(&self) -> Result> { + Ok(Some(ScalarValue::UInt64(self.count))) + } +} + +/// Create a count expression +pub fn count(expr: Arc) -> Arc { + Arc::new(Count::new(expr)) +} + /// Invoke a compute kernel on a pair of arrays macro_rules! compute_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -702,6 +767,42 @@ mod tests { Ok(()) } + #[test] + fn count_elements() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(5))); + Ok(()) + } + + #[test] + fn count_with_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(3))); + Ok(()) + } + + #[test] + fn count_all_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let a = BooleanArray::from(vec![None, None, None, None, None, None, None, None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0))); + Ok(()) + } + + #[test] + fn count_empty() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let a = BooleanArray::from(Vec::::new()); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0))); + Ok(()) + } + fn do_sum(batch: &RecordBatch) -> Result> { let sum = sum(col(0)); let accum = sum.create_accumulator(); @@ -712,4 +813,15 @@ mod tests { } accum.get_value() } + + fn do_count(batch: &RecordBatch) -> Result> { + let count = count(col(0)); + let accum = count.create_accumulator(); + let input = count.evaluate_input(batch)?; + let mut accum = accum.borrow_mut(); + for i in 0..batch.num_rows() { + accum.accumulate(&batch, &input, i)?; + } + accum.get_value() + } }