Skip to content

Commit

Permalink
Add Count Aggregate Expression
Browse files Browse the repository at this point in the history
  • Loading branch information
sinistersnare committed Sep 28, 2019
1 parent 7fb6b75 commit 2e85443
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 1 deletion.
44 changes: 43 additions & 1 deletion rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation;
use crate::execution::physical_plan::common;
use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::{
BinaryExpr, CastExpr, Column, Literal, Sum,
BinaryExpr, CastExpr, Column, Count, Literal, Sum,
};
use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
use crate::execution::physical_plan::merge::MergeExec;
Expand Down Expand Up @@ -333,6 +333,9 @@ impl ExecutionContext {
"sum" => Ok(Arc::new(Sum::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
"count" => Ok(Arc::new(Count::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
other => Err(ExecutionError::NotImplemented(format!(
"Unsupported aggregate function '{}'",
other
Expand Down Expand Up @@ -641,6 +644,45 @@ mod tests {
Ok(())
}

#[test]
fn count_basic() -> Result<()> {
let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1)?;
assert_eq!(results.len(), 1);

let batch = &results[0];
let expected: Vec<&str> = vec!["10,10"];
let mut rows = test::format_batch(&batch);
rows.sort();
assert_eq!(rows, expected);
Ok(())
}

#[test]
fn count_partitioned() -> Result<()> {
let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 4)?;
assert_eq!(results.len(), 1);

let batch = &results[0];
let expected: Vec<&str> = vec!["40,40"];
let mut rows = test::format_batch(&batch);
rows.sort();
assert_eq!(rows, expected);
Ok(())
}

#[test]
fn count_aggregated() -> Result<()> {
let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4)?;
assert_eq!(results.len(), 1);

let batch = &results[0];
let expected = vec!["0,10", "1,10", "2,10", "3,10"];
let mut rows = test::format_batch(&batch);
rows.sort();
assert_eq!(rows, expected);
Ok(())
}

/// Execute SQL and return results
fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
let tmp_dir = TempDir::new("execute")?;
Expand Down
107 changes: 107 additions & 0 deletions rust/datafusion/src/execution/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ macro_rules! sum_accumulate {
}
}};
}

struct SumAccumulator {
expr: Arc<dyn PhysicalExpr>,
sum: Option<ScalarValue>,
Expand Down Expand Up @@ -199,6 +200,66 @@ pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Sum::new(expr))
}

/// COUNT aggregate expression
/// If given something like `SELECT COUNT(*) FROM table`, the number of rows in the table will be returned.
/// If given a specific column, `SELECT COUNT(column_name) FROM table`, the number of *non-null*
/// rows in that column will be returned.
pub struct Count {
expr: Arc<dyn PhysicalExpr>,
}

impl Count {
/// Create a new COUNT aggregate function.
pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
Self { expr: expr }
}
}

impl AggregateExpr for Count {
fn name(&self) -> String {
"COUNT".to_string()
}

fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
Ok(DataType::UInt64)
}

fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
Rc::new(RefCell::new(CountAccumulator {
expr: self.expr.clone(),
count: 0,
}))
}

fn create_combiner(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
Arc::new(Sum::new(Arc::new(Column::new(column_index))))
}
}

struct CountAccumulator {
expr: Arc<dyn PhysicalExpr>,
count: u64,
}

impl Accumulator for CountAccumulator {
fn accumulate(&mut self, batch: &RecordBatch, row_index: usize) -> Result<()> {
let array = self.expr.evaluate(batch)?;
if array.is_valid(row_index) {
self.count += 1;
}
Ok(())
}

fn get_value(&self) -> Result<Option<ScalarValue>> {
Ok(Some(ScalarValue::UInt64(self.count)))
}
}

/// Create a count expression
pub fn count(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Count::new(expr))
}

/// Invoke a compute kernel on a pair of arrays
macro_rules! compute_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
Expand Down Expand Up @@ -607,6 +668,42 @@ mod tests {
}
}

#[test]
fn count_elements() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(5)));
Ok(())
}

#[test]
fn count_with_nulls() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), None]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(3)));
Ok(())
}

#[test]
fn count_all_nulls() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
let a = BooleanArray::from(vec![None, None, None, None, None, None, None, None]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0)));
Ok(())
}

#[test]
fn count_empty() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
let a = BooleanArray::from(Vec::<bool>::new());
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
assert_eq!(do_count(&batch)?, Some(ScalarValue::UInt64(0)));
Ok(())
}

#[test]
fn sum_contract() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down Expand Up @@ -703,4 +800,14 @@ mod tests {
}
accum.get_value()
}

fn do_count(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
let count = count(col(0));
let accum = count.create_accumulator();
let mut accum = accum.borrow_mut();
for i in 0..batch.num_rows() {
accum.accumulate(&batch, i)?;
}
accum.get_value()
}
}

0 comments on commit 2e85443

Please sign in to comment.