Skip to content

Commit

Permalink
ARROW-6658: [Rust][Datafusion] Implement AVG expression
Browse files Browse the repository at this point in the history
I wasn't sure about the data types of sum and count (I picked the broadest: f64 and i64). Also, it may or may not be better to implement this as SUM()/COUNT().

Either way, the changes to mod.rs are needed to test SQL with f64.

Closes apache#5558 from alippai/ARROW-6658 and squashes the following commits:

20cddef <Andy Grove> fix typo
62372f6 <Andy Grove> Remove unwrap
cc32f22 <Andy Grove> rebase
ebc3acb <Adam Lippai> ARROW-6658:  Implement AVG expression

Lead-authored-by: Adam Lippai <[email protected]>
Co-authored-by: Andy Grove <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
  • Loading branch information
alippai and andygrove committed Oct 5, 2019
1 parent bf68fa8 commit fc93312
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 1 deletion.
33 changes: 32 additions & 1 deletion rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation;
use crate::execution::physical_plan::common;
use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::{
BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, Sum,
Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, Sum,
};
use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
use crate::execution::physical_plan::merge::MergeExec;
Expand Down Expand Up @@ -333,6 +333,9 @@ impl ExecutionContext {
"sum" => Ok(Arc::new(Sum::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
"avg" => Ok(Arc::new(Avg::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
"max" => Ok(Arc::new(Max::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
Expand Down Expand Up @@ -636,6 +639,20 @@ mod tests {
Ok(())
}

#[test]
fn aggregate_avg() -> Result<()> {
    // Ungrouped AVG over both columns must yield exactly one batch
    let batches = execute("SELECT AVG(c1), AVG(c2) FROM test", 4)?;
    assert_eq!(batches.len(), 1);

    // Sort the formatted rows so the comparison is independent of row order
    let mut actual = test::format_batch(&batches[0]);
    actual.sort();

    let expected: Vec<&str> = vec!["1.5,5.5"];
    assert_eq!(actual, expected);

    Ok(())
}

#[test]
fn aggregate_max() -> Result<()> {
let results = execute("SELECT MAX(c1), MAX(c2) FROM test", 4)?;
Expand Down Expand Up @@ -678,6 +695,20 @@ mod tests {
Ok(())
}

#[test]
fn aggregate_grouped_avg() -> Result<()> {
    // AVG(c2) grouped by c1 must yield exactly one batch
    let batches = execute("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4)?;
    assert_eq!(batches.len(), 1);

    // Sort the formatted rows so the comparison is independent of group order
    let mut actual = test::format_batch(&batches[0]);
    actual.sort();

    let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"];
    assert_eq!(actual, expected);

    Ok(())
}

#[test]
fn aggregate_grouped_max() -> Result<()> {
let results = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4)?;
Expand Down
220 changes: 220 additions & 0 deletions rust/datafusion/src/execution/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,129 @@ pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Sum::new(expr))
}

/// AVG aggregate expression.
///
/// Holds the physical expression whose values are averaged. The type reported
/// by `data_type` is `Float64` for every supported numeric input type.
pub struct Avg {
/// Expression evaluated against each record batch to produce the input values
expr: Arc<dyn PhysicalExpr>,
}

impl Avg {
/// Create a new AVG aggregate function
pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
Self { expr }
}
}

impl AggregateExpr for Avg {
    /// Name of this aggregate function
    fn name(&self) -> String {
        String::from("AVG")
    }

    /// Result type of the aggregate.
    ///
    /// Every supported numeric input type produces `Float64`; any other input
    /// type is reported as an `ExecutionError::General`.
    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
        let input_type = self.expr.data_type(input_schema)?;
        match input_type {
            DataType::Float32
            | DataType::Float64
            | DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64 => Ok(DataType::Float64),
            unsupported => Err(ExecutionError::General(format!(
                "AVG does not support {:?}",
                unsupported
            ))),
        }
    }

    /// Evaluate the wrapped expression against `batch` to obtain the input array
    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
        self.expr.evaluate(batch)
    }

    /// Create a fresh accumulator with no sum and no count recorded yet
    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
        Rc::new(RefCell::new(AvgAccumulator {
            expr: Arc::clone(&self.expr),
            sum: None,
            count: None,
        }))
    }

    /// Create the expression used to combine partial results: another AVG
    /// reading from the column at `column_index`.
    ///
    /// NOTE(review): this averages the partial averages, which looks like it
    /// only equals the overall average when every partial result covers the
    /// same number of rows -- confirm against how the combiner is used.
    fn create_combiner(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
        let partial_column = Arc::new(Column::new(column_index));
        Arc::new(Avg::new(partial_column))
    }
}

// Folds the value at `$ROW_INDEX` of `$ARRAY` into the running `sum` / `count`
// fields of `$SELF`.
//
// `$ARRAY` is downcast to `$ARRAY_TYPE`; a failed downcast evaluates to an
// `ExecutionError::General`. Rows that are not valid (null) leave the state
// untouched. Values are widened to `f64` before being added.
macro_rules! avg_accumulate {
    ($SELF:ident, $ARRAY:ident, $ROW_INDEX:expr, $ARRAY_TYPE:ident) => {{
        match $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>() {
            Some(typed) => {
                if $ARRAY.is_valid($ROW_INDEX) {
                    let addend = typed.value($ROW_INDEX) as f64;
                    // Start from this value unless both sum and count already exist
                    let (next_sum, next_count) = match ($SELF.sum, $SELF.count) {
                        (Some(total), Some(seen)) => (total + addend, seen + 1),
                        _ => (addend, 1),
                    };
                    $SELF.sum = Some(next_sum);
                    $SELF.count = Some(next_count);
                }
                Ok(())
            }
            None => Err(ExecutionError::General(
                "Failed to downcast array".to_string(),
            )),
        }
    }};
}
/// Running state for one AVG computation
struct AvgAccumulator {
/// Input expression; its data type selects the concrete array type to downcast to
expr: Arc<dyn PhysicalExpr>,
/// Sum of the non-null values seen so far, widened to f64; `None` until the first valid value
sum: Option<f64>,
/// Number of non-null values seen so far; `None` until the first valid value
count: Option<i64>,
}

impl Accumulator for AvgAccumulator {
    /// Fold the value at `row_index` of `array` into the running sum and count.
    ///
    /// The concrete array type is chosen from the data type the input
    /// expression reports for the batch's schema. Unsupported types and failed
    /// downcasts are returned as `ExecutionError::General`.
    fn accumulate(
        &mut self,
        batch: &RecordBatch,
        array: &ArrayRef,
        row_index: usize,
    ) -> Result<()> {
        let input_type = self.expr.data_type(batch.schema())?;
        match input_type {
            DataType::Int8 => avg_accumulate!(self, array, row_index, Int8Array),
            DataType::Int16 => avg_accumulate!(self, array, row_index, Int16Array),
            DataType::Int32 => avg_accumulate!(self, array, row_index, Int32Array),
            DataType::Int64 => avg_accumulate!(self, array, row_index, Int64Array),
            DataType::UInt8 => avg_accumulate!(self, array, row_index, UInt8Array),
            DataType::UInt16 => avg_accumulate!(self, array, row_index, UInt16Array),
            DataType::UInt32 => avg_accumulate!(self, array, row_index, UInt32Array),
            DataType::UInt64 => avg_accumulate!(self, array, row_index, UInt64Array),
            DataType::Float32 => avg_accumulate!(self, array, row_index, Float32Array),
            DataType::Float64 => avg_accumulate!(self, array, row_index, Float64Array),
            unsupported => Err(ExecutionError::General(format!(
                "AVG does not support {:?}",
                unsupported
            ))),
        }
    }

    /// Current average as a `Float64` scalar, or `None` when no valid value
    /// has been accumulated yet.
    fn get_value(&self) -> Result<Option<ScalarValue>> {
        if let (Some(total), Some(seen)) = (self.sum, self.count) {
            return Ok(Some(ScalarValue::Float64(total / seen as f64)));
        }
        Ok(None)
    }
}

/// Create a avg expression
pub fn avg(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Avg::new(expr))
}

/// MAX aggregate expression
pub struct Max {
expr: Arc<dyn PhysicalExpr>,
Expand Down Expand Up @@ -1000,6 +1123,20 @@ mod tests {

Ok(())
}
#[test]
fn avg_contract() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

    // The expression itself reports its name and a Float64 result type
    let expr = avg(col(0));
    assert_eq!("AVG".to_string(), expr.name());
    assert_eq!(DataType::Float64, expr.data_type(&schema)?);

    // The combiner built from it must honour the same contract
    let merged = expr.create_combiner(0);
    assert_eq!("AVG".to_string(), merged.name());
    assert_eq!(DataType::Float64, merged.data_type(&schema)?);

    Ok(())
}

#[test]
fn sum_i32() -> Result<()> {
Expand All @@ -1013,6 +1150,18 @@ mod tests {
Ok(())
}

#[test]
fn avg_i32() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)])?;

    // (1 + 2 + 3 + 4 + 5) / 5 = 3
    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3.0)));

    Ok(())
}

#[test]
fn max_i32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down Expand Up @@ -1049,6 +1198,18 @@ mod tests {
Ok(())
}

#[test]
fn avg_i32_with_nulls() -> Result<()> {
    // The column contains a null, so the field is declared nullable to keep
    // the schema consistent with the data.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

    // The null row is skipped: (1 + 3 + 4 + 5) / 4 = 3.25
    assert_eq!(do_avg(&batch)?, Some(ScalarValue::Float64(3.25)));

    Ok(())
}

#[test]
fn max_i32_with_nulls() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down Expand Up @@ -1109,6 +1270,18 @@ mod tests {
Ok(())
}

#[test]
fn avg_i32_all_nulls() -> Result<()> {
    // Every value is null, so the field is declared nullable to keep the
    // schema consistent with the data.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let a = Int32Array::from(vec![None, None]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

    // With no valid value accumulated there is no average to report
    assert_eq!(do_avg(&batch)?, None);

    Ok(())
}

#[test]
fn sum_u32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
Expand All @@ -1121,6 +1294,18 @@ mod tests {
Ok(())
}

#[test]
fn avg_u32() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
    let values = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)])?;

    // Unsigned input still yields a Float64 average
    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3.0)));

    Ok(())
}

#[test]
fn max_u32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
Expand Down Expand Up @@ -1157,6 +1342,18 @@ mod tests {
Ok(())
}

#[test]
fn avg_f32() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
    let values = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)])?;

    // Float32 input is widened, so the average is reported as Float64
    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3.0)));

    Ok(())
}

#[test]
fn max_f32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
Expand Down Expand Up @@ -1193,6 +1390,18 @@ mod tests {
Ok(())
}

#[test]
fn avg_f64() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
    let values = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)])?;

    // (1 + 2 + 3 + 4 + 5) / 5 = 3
    let average = do_avg(&batch)?;
    assert_eq!(average, Some(ScalarValue::Float64(3.0)));

    Ok(())
}

#[test]
fn max_f64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
Expand Down Expand Up @@ -1296,4 +1505,15 @@ mod tests {
}
accum.get_value()
}

/// Test helper: run AVG over column 0 of `batch`, feeding every row to a
/// single accumulator, and return the accumulator's final value.
fn do_avg(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
    let expr = avg(col(0));
    let accumulator = expr.create_accumulator();
    let input = expr.evaluate_input(batch)?;

    let mut state = accumulator.borrow_mut();
    for row in 0..batch.num_rows() {
        state.accumulate(batch, &input, row)?;
    }
    state.get_value()
}
}
16 changes: 16 additions & 0 deletions rust/datafusion/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,22 @@ pub fn format_batch(batch: &RecordBatch) -> Vec<String> {
.unwrap()
.value(row_index)
)),
DataType::Float32 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::Float32Array>()
.unwrap()
.value(row_index)
)),
DataType::Float64 => s.push_str(&format!(
"{:?}",
array
.as_any()
.downcast_ref::<array::Float64Array>()
.unwrap()
.value(row_index)
)),
_ => s.push('?'),
}
}
Expand Down

0 comments on commit fc93312

Please sign in to comment.