Skip to content

Commit

Permalink
ARROW-6658: [Rust][Datafusion] Implement AVG expression
Browse files Browse the repository at this point in the history
  • Loading branch information
alippai committed Oct 1, 2019
1 parent 8231fcb commit 5174a2a
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 1 deletion.
33 changes: 32 additions & 1 deletion rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation;
use crate::execution::physical_plan::common;
use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::{
BinaryExpr, CastExpr, Column, Literal, Sum,
Avg, BinaryExpr, CastExpr, Column, Literal, Sum,
};
use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
use crate::execution::physical_plan::merge::MergeExec;
Expand Down Expand Up @@ -333,6 +333,9 @@ impl ExecutionContext {
"sum" => Ok(Arc::new(Sum::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
"avg" => Ok(Arc::new(Avg::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
other => Err(ExecutionError::NotImplemented(format!(
"Unsupported aggregate function '{}'",
other
Expand Down Expand Up @@ -627,6 +630,20 @@ mod tests {
Ok(())
}

/// Runs an un-grouped `AVG` over two columns across 4 partitions and checks
/// that exactly one batch with the single row `1.5,5.5` comes back.
#[test]
fn aggregate_avg() -> Result<()> {
    let batches = execute("SELECT AVG(c1), AVG(c2) FROM test", 4)?;
    assert_eq!(batches.len(), 1);

    let want: Vec<&str> = vec!["1.5,5.5"];
    // Sort the formatted rows so the comparison is independent of the order
    // in which rows are returned.
    let mut got = test::format_batch(&batches[0]);
    got.sort();
    assert_eq!(got, want);

    Ok(())
}

#[test]
fn aggregate_grouped() -> Result<()> {
let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4)?;
Expand All @@ -641,6 +658,20 @@ mod tests {
Ok(())
}

/// Runs `AVG(c2)` grouped by `c1` across 4 partitions and checks that one
/// batch comes back containing one row per group, each averaging to `5.5`.
#[test]
fn aggregate_grouped_avg() -> Result<()> {
    let batches = execute("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4)?;
    assert_eq!(batches.len(), 1);

    let want: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"];
    // Sort the formatted rows so the comparison is independent of the order
    // in which groups are emitted.
    let mut got = test::format_batch(&batches[0]);
    got.sort();
    assert_eq!(got, want);

    Ok(())
}

/// Execute SQL and return results
fn execute(sql: &str, partition_count: usize) -> Result<Vec<RecordBatch>> {
let tmp_dir = TempDir::new("execute")?;
Expand Down
211 changes: 211 additions & 0 deletions rust/datafusion/src/execution/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,121 @@ pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Sum::new(expr))
}

/// AVG aggregate expression.
///
/// Holds the physical expression whose values are averaged.
pub struct Avg {
    /// Input expression that produces the values to average.
    expr: Arc<dyn PhysicalExpr>,
}

impl Avg {
    /// Build an AVG aggregate over the given input expression.
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Avg { expr }
    }
}

impl AggregateExpr for Avg {
    /// Name of this aggregate: always `"AVG"`.
    fn name(&self) -> String {
        String::from("AVG")
    }

    /// Result type of the aggregate.
    ///
    /// Every supported integer or floating-point input type yields
    /// `Float64`; any other input type is rejected with an error.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        let input_type = self.expr.data_type(input_schema)?;
        match input_type {
            DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
            | DataType::Float32
            | DataType::Float64 => Ok(DataType::Float64),
            unsupported => Err(ExecutionError::General(format!(
                "AVG does not support {:?}",
                unsupported
            ))),
        }
    }

    /// Create a fresh accumulator with no sum and no count recorded yet.
    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
        let accumulator = AvgAccumulator {
            expr: Arc::clone(&self.expr),
            sum: None,
            count: None,
        };
        Rc::new(RefCell::new(accumulator))
    }

    /// Create the combining aggregate: another AVG that reads its input from
    /// the column at `column_index`.
    ///
    /// NOTE(review): this averages the values found in that column. If the
    /// column holds per-partition averages, the result presumably equals the
    /// overall average only when every partial result covers the same number
    /// of rows -- confirm against how the combiner is used.
    fn create_combiner(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
        let partial_column = Column::new(column_index);
        Arc::new(Avg::new(Arc::new(partial_column)))
    }
}

/// Fold one row of a typed array into an accumulator's running `sum`/`count`.
///
/// `$SELF` must expose `sum: Option<f64>` and `count: Option<i64>`. `$ARRAY`
/// is downcast to `$ARRAY_TYPE`; when the row at `$ROW_INDEX` is valid its
/// value is cast to `f64`, added to the sum, and the count is incremented.
/// Invalid (null) rows leave the state untouched. The macro evaluates to
/// `Ok(())`, or to an `ExecutionError::General` when the downcast fails.
macro_rules! avg_accumulate {
    ($SELF:ident, $ARRAY:ident, $ROW_INDEX:expr, $ARRAY_TYPE:ident) => {{
        match $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>() {
            Some(typed) => {
                if $ARRAY.is_valid($ROW_INDEX) {
                    let addend = typed.value($ROW_INDEX) as f64;
                    // Work out the new count before `sum` is overwritten,
                    // because `sum` tells us whether this is the first value.
                    let seen = match $SELF.sum {
                        Some(_) => $SELF.count.unwrap() + 1,
                        None => 1,
                    };
                    $SELF.sum = Some($SELF.sum.map_or(addend, |total| total + addend));
                    $SELF.count = Some(seen);
                }
                Ok(())
            }
            None => Err(ExecutionError::General(
                "Failed to downcast array".to_string(),
            )),
        }
    }};
}
/// Running state for an AVG aggregate.
///
/// `sum` and `count` are both `None` until the first non-null value has been
/// accumulated; afterwards both are `Some`.
struct AvgAccumulator {
    /// Expression evaluated against each batch to obtain the input values.
    expr: Arc<dyn PhysicalExpr>,
    /// Sum of all non-null values seen so far, widened to `f64`.
    sum: Option<f64>,
    /// Number of non-null values seen so far.
    count: Option<i64>,
}

impl Accumulator for AvgAccumulator {
    /// Fold the value at `row_index` of `batch` into the running sum/count.
    ///
    /// The wrapped expression is evaluated against `batch`, the resulting
    /// array is downcast according to the expression's data type, and a
    /// valid (non-null) value is added to the sum as an `f64`.
    ///
    /// # Errors
    /// Fails when evaluating the expression or resolving its data type
    /// fails, when the data type is not one of the numeric types handled
    /// below, or when the array cannot be downcast to the matching type.
    fn accumulate(&mut self, batch: &RecordBatch, row_index: usize) -> Result<()> {
        // Note: the expression is evaluated on every call, i.e. once per row.
        let array = self.expr.evaluate(batch)?;
        match self.expr.data_type(batch.schema())? {
            DataType::Int8 => avg_accumulate!(self, array, row_index, Int8Array),
            DataType::Int16 => avg_accumulate!(self, array, row_index, Int16Array),
            DataType::Int32 => avg_accumulate!(self, array, row_index, Int32Array),
            DataType::Int64 => avg_accumulate!(self, array, row_index, Int64Array),
            DataType::UInt8 => avg_accumulate!(self, array, row_index, UInt8Array),
            DataType::UInt16 => avg_accumulate!(self, array, row_index, UInt16Array),
            DataType::UInt32 => avg_accumulate!(self, array, row_index, UInt32Array),
            DataType::UInt64 => avg_accumulate!(self, array, row_index, UInt64Array),
            DataType::Float32 => avg_accumulate!(self, array, row_index, Float32Array),
            DataType::Float64 => avg_accumulate!(self, array, row_index, Float64Array),
            // Report the aggregate that actually failed: this is AVG, not SUM.
            other => Err(ExecutionError::General(format!(
                "AVG does not support {:?}",
                other
            ))),
        }
    }

    /// Return the average of the accumulated values as `Float64`, or `None`
    /// when no non-null value has been accumulated.
    fn get_value(&self) -> Result<Option<ScalarValue>> {
        match (self.sum, self.count) {
            (Some(sum), Some(count)) => {
                Ok(Some(ScalarValue::Float64(sum / count as f64)))
            }
            _ => Ok(None),
        }
    }
}

/// Create a avg expression
pub fn avg(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Avg::new(expr))
}

/// Invoke a compute kernel on a pair of arrays
macro_rules! compute_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
Expand Down Expand Up @@ -621,6 +736,20 @@ mod tests {

Ok(())
}
/// The AVG expression and its combiner both report the name `AVG` and a
/// `Float64` result type for an `Int32` input column.
#[test]
fn avg_contract() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let expected_name = "AVG".to_string();

    let avg_expr = avg(col(0));
    assert_eq!(expected_name, avg_expr.name());
    assert_eq!(DataType::Float64, avg_expr.data_type(&schema)?);

    let combiner = avg_expr.create_combiner(0);
    assert_eq!(expected_name, combiner.name());
    assert_eq!(DataType::Float64, combiner.data_type(&schema)?);

    Ok(())
}

#[test]
fn sum_i32() -> Result<()> {
Expand All @@ -634,6 +763,18 @@ mod tests {
Ok(())
}

/// AVG over the `Int32` values 1..=5 is 3.0.
#[test]
fn avg_i32() -> Result<()> {
    let field = Field::new("a", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![field]));

    let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3_f64)));

    Ok(())
}

#[test]
fn sum_i32_with_nulls() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand All @@ -646,6 +787,18 @@ mod tests {
Ok(())
}

/// Null entries are ignored: the average of 1, 3, 4 and 5 is 3.25.
#[test]
fn avg_i32_with_nulls() -> Result<()> {
    // NOTE(review): the field is declared non-nullable although the array
    // below contains a null.
    let field = Field::new("a", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![field]));

    let values = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3.25)));

    Ok(())
}

#[test]
fn sum_i32_all_nulls() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand All @@ -658,6 +811,18 @@ mod tests {
Ok(())
}

/// When every entry is null the accumulator yields no value at all.
#[test]
fn avg_i32_all_nulls() -> Result<()> {
    // NOTE(review): the field is declared non-nullable although the array
    // below contains only nulls.
    let field = Field::new("a", DataType::Int32, false);
    let schema = Arc::new(Schema::new(vec![field]));

    let values = Int32Array::from(vec![None, None]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    let average = do_avg(&batch)?;
    assert_eq!(average, None);

    Ok(())
}

#[test]
fn sum_u32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
Expand All @@ -670,6 +835,18 @@ mod tests {
Ok(())
}

/// AVG over the `UInt32` values 1..=5 is 3.0.
#[test]
fn avg_u32() -> Result<()> {
    let field = Field::new("a", DataType::UInt32, false);
    let schema = Arc::new(Schema::new(vec![field]));

    let values = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3_f64)));

    Ok(())
}

#[test]
fn sum_f32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
Expand All @@ -682,6 +859,18 @@ mod tests {
Ok(())
}

/// AVG over the `Float32` values 1.0..=5.0 is 3.0.
#[test]
fn avg_f32() -> Result<()> {
    let field = Field::new("a", DataType::Float32, false);
    let schema = Arc::new(Schema::new(vec![field]));

    let values = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3_f64)));

    Ok(())
}

#[test]
fn sum_f64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
Expand All @@ -694,6 +883,18 @@ mod tests {
Ok(())
}

/// AVG over the `Float64` values 1.0..=5.0 is 3.0.
#[test]
fn avg_f64() -> Result<()> {
    let field = Field::new("a", DataType::Float64, false);
    let schema = Arc::new(Schema::new(vec![field]));

    let values = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3_f64)));

    Ok(())
}

fn do_sum(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
let sum = sum(col(0));
let accum = sum.create_accumulator();
Expand All @@ -703,4 +904,14 @@ mod tests {
}
accum.get_value()
}

/// Test helper: feed every row of `batch` into a fresh AVG accumulator over
/// column 0 and return the accumulator's final value.
fn do_avg(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
    let avg_expr = avg(col(0));
    let shared = avg_expr.create_accumulator();
    let mut accumulator = shared.borrow_mut();
    let row_count = batch.num_rows();
    for row in 0..row_count {
        accumulator.accumulate(batch, row)?;
    }
    accumulator.get_value()
}
}
16 changes: 16 additions & 0 deletions rust/datafusion/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ pub fn format_batch(batch: &RecordBatch) -> Vec<String> {
.unwrap()
.value(row_index)
)),
DataType::Float32 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::Float32Array>()
.unwrap()
.value(row_index)
)),
DataType::Float64 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::Float64Array>()
.unwrap()
.value(row_index)
)),
_ => s.push('?'),
}
}
Expand Down

0 comments on commit 5174a2a

Please sign in to comment.