From 083adadb5b44b7cbe97ac892025e4c4ffebd3387 Mon Sep 17 00:00:00 2001 From: Davis Silverman Date: Fri, 4 Oct 2019 08:27:44 -0600 Subject: [PATCH] ARROW-6657: [Rust] [DataFusion] Add Count Aggregate Expression Hi, I added this code, and the tests pass. I still need to actually test it using a real example, so I would say it's not completely ready for merge yet. Closes #5513 from sinistersnare/ARROW-6657 and squashes the following commits: 64d0c00b0 formatting 12d0c2c56 Add Count Aggregate Expression Lead-authored-by: Davis Silverman Co-authored-by: Andy Grove Signed-off-by: Andy Grove --- rust/datafusion/src/execution/context.rs | 44 ++++++- .../execution/physical_plan/expressions.rs | 110 ++++++++++++++++++ 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index dc54b9978ec33..f07c8b987faab 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation; use crate::execution::physical_plan::common; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::expressions::{ - BinaryExpr, CastExpr, Column, Literal, Sum, + BinaryExpr, CastExpr, Column, Count, Literal, Sum, }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::merge::MergeExec; @@ -333,6 +333,9 @@ impl ExecutionContext { "sum" => Ok(Arc::new(Sum::new( self.create_physical_expr(&args[0], input_schema)?, ))), + "count" => Ok(Arc::new(Count::new( + self.create_physical_expr(&args[0], input_schema)?, + ))), other => Err(ExecutionError::NotImplemented(format!( "Unsupported aggregate function '{}'", other @@ -641,6 +644,45 @@ mod tests { Ok(()) } + #[test] + fn count_basic() -> Result<()> { + let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; 
+ let expected: Vec<&str> = vec!["10,10"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + Ok(()) + } + + #[test] + fn count_partitioned() -> Result<()> { + let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["40,40"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + Ok(()) + } + + #[test] + fn count_aggregated() -> Result<()> { + let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected = vec!["0,10", "1,10", "2,10", "3,10"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + Ok(()) + } + /// Execute SQL and return results fn execute(sql: &str, partition_count: usize) -> Result> { let tmp_dir = TempDir::new("execute")?; diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index f63b40cb688ff..5f535367efdb5 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -147,6 +147,7 @@ macro_rules! sum_accumulate { } }}; } + struct SumAccumulator { expr: Arc, sum: Option, @@ -207,6 +208,68 @@ pub fn sum(expr: Arc) -> Arc { Arc::new(Sum::new(expr)) } +/// COUNT aggregate expression +/// Returns the amount of non-null values of the given expression. +pub struct Count { + expr: Arc, +} + +impl Count { + /// Create a new COUNT aggregate function. 
+ pub fn new(expr: Arc) -> Self { + Self { expr: expr } + } +} + +impl AggregateExpr for Count { + fn name(&self) -> String { + "COUNT".to_string() + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::UInt64) + } + + fn evaluate_input(&self, batch: &RecordBatch) -> Result { + self.expr.evaluate(batch) + } + + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(CountAccumulator { count: 0 })) + } + + fn create_combiner(&self, column_index: usize) -> Arc { + Arc::new(Sum::new(Arc::new(Column::new(column_index)))) + } +} + +struct CountAccumulator { + count: u64, +} + +impl Accumulator for CountAccumulator { + fn accumulate( + &mut self, + _batch: &RecordBatch, + array: &ArrayRef, + row_index: usize, + ) -> Result<()> { + if array.is_valid(row_index) { + self.count += 1; + } + Ok(()) + } + + fn get_value(&self) -> Result> { + Ok(Some(ScalarValue::UInt64(self.count))) + } +} + +/// Create a count expression +pub fn count(expr: Arc) -> Arc { + Arc::new(Count::new(expr)) +} + /// Invoke a compute kernel on a pair of arrays macro_rules! 
compute_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -702,6 +765,42 @@ mod tests { Ok(()) } + #[test] + fn count_elements() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(5))); + Ok(()) + } + + #[test] + fn count_with_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(3))); + Ok(()) + } + + #[test] + fn count_all_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let a = BooleanArray::from(vec![None, None, None, None, None, None, None, None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0))); + Ok(()) + } + + #[test] + fn count_empty() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]); + let a = BooleanArray::from(Vec::::new()); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0))); + Ok(()) + } + fn do_sum(batch: &RecordBatch) -> Result> { let sum = sum(col(0)); let accum = sum.create_accumulator(); @@ -712,4 +811,15 @@ mod tests { } accum.get_value() } + + fn do_count(batch: &RecordBatch) -> Result> { + let count = count(col(0)); + let accum = count.create_accumulator(); + let input = count.evaluate_input(batch)?; + let mut accum = accum.borrow_mut(); + for i in 0..batch.num_rows() { + accum.accumulate(&batch, &input, i)?; + } + accum.get_value() + } }