Skip to content

Commit

Permalink
ARROW-6656: [Rust][Datafusion] Add MAX expression
Browse files Browse the repository at this point in the history
  • Loading branch information
alippai authored and andygrove committed Oct 4, 2019
1 parent 368562b commit 9f98de7
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 1 deletion.
33 changes: 32 additions & 1 deletion rust/datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation;
use crate::execution::physical_plan::common;
use crate::execution::physical_plan::datasource::DatasourceExec;
use crate::execution::physical_plan::expressions::{
BinaryExpr, CastExpr, Column, Count, Literal, Sum,
BinaryExpr, CastExpr, Column, Count, Literal, Max, Sum,
};
use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
use crate::execution::physical_plan::merge::MergeExec;
Expand Down Expand Up @@ -333,6 +333,9 @@ impl ExecutionContext {
"sum" => Ok(Arc::new(Sum::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
"max" => Ok(Arc::new(Max::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
"count" => Ok(Arc::new(Count::new(
self.create_physical_expr(&args[0], input_schema)?,
))),
Expand Down Expand Up @@ -630,6 +633,20 @@ mod tests {
Ok(())
}

#[test]
fn aggregate_max() -> Result<()> {
    // Ungrouped MAX over both columns of the `test` table; the query is
    // expected to come back as exactly one record batch.
    let batches = execute("SELECT MAX(c1), MAX(c2) FROM test", 4)?;
    assert_eq!(batches.len(), 1);

    // Render the batch as text rows and sort them so the comparison does not
    // depend on the order in which rows were produced.
    let mut actual = test::format_batch(&batches[0]);
    actual.sort();

    let expected: Vec<&str> = vec!["3,10"];
    assert_eq!(actual, expected);

    Ok(())
}

#[test]
fn aggregate_grouped() -> Result<()> {
let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4)?;
Expand All @@ -644,6 +661,20 @@ mod tests {
Ok(())
}

#[test]
fn aggregate_grouped_max() -> Result<()> {
    // MAX(c2) computed per distinct c1 value; the query is expected to come
    // back as exactly one record batch.
    let batches = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4)?;
    assert_eq!(batches.len(), 1);

    // Group order is not asserted, so the formatted rows are sorted before
    // being compared.
    let mut actual = test::format_batch(&batches[0]);
    actual.sort();

    let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"];
    assert_eq!(actual, expected);

    Ok(())
}

#[test]
fn count_basic() -> Result<()> {
let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1)?;
Expand Down
228 changes: 228 additions & 0 deletions rust/datafusion/src/execution/physical_plan/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,137 @@ pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
Arc::new(Sum::new(expr))
}

/// MAX aggregate expression.
///
/// Holds the physical expression whose values are aggregated.
pub struct Max {
    /// Input expression the aggregate is computed over.
    expr: Arc<dyn PhysicalExpr>,
}

impl Max {
    /// Create a new MAX aggregate function over `expr`.
    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
        Max { expr }
    }
}

impl AggregateExpr for Max {
fn name(&self) -> String {
"MAX".to_string()
}

fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
match self.expr.data_type(input_schema)? {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
Ok(DataType::Int64)
}
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
Ok(DataType::UInt64)
}
DataType::Float32 => Ok(DataType::Float32),
DataType::Float64 => Ok(DataType::Float64),
other => Err(ExecutionError::General(format!(
"MAX does not support {:?}",
other
))),
}
}

fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
Rc::new(RefCell::new(MaxAccumulator {
expr: self.expr.clone(),
max: None,
}))
}

fn create_combiner(&self, column_index: usize) -> Arc<dyn AggregateExpr> {
Arc::new(Max::new(Arc::new(Column::new(column_index))))
}
}

/// Fold one row of a typed array into the running maximum held in `$SELF.max`.
///
/// `$ARRAY` is downcast to `$ARRAY_TYPE`. When the row at `$ROW_INDEX` is
/// valid, its value is cast to `$TY` and compared with the current
/// `ScalarValue::$SCALAR_VARIANT`; the stored value is kept only when it is
/// strictly greater than the new one. Invalid (null) rows leave the state
/// untouched.
///
/// The expansion evaluates to `Ok(())` on success and to
/// `Err(ExecutionError::General)` when the downcast fails. If the stored
/// maximum holds a different `ScalarValue` variant, the expansion `return`s
/// `Err(ExecutionError::InternalError)` from the enclosing function.
macro_rules! max_accumulate {
    ($SELF:ident, $ARRAY:ident, $ROW_INDEX:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{
        match $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>() {
            None => Err(ExecutionError::General(
                "Failed to downcast array".to_string(),
            )),
            Some(typed) => {
                if $ARRAY.is_valid($ROW_INDEX) {
                    let candidate = typed.value($ROW_INDEX) as $TY;
                    $SELF.max = match $SELF.max {
                        None => Some(ScalarValue::$SCALAR_VARIANT(candidate)),
                        Some(ScalarValue::$SCALAR_VARIANT(current)) => {
                            let winner = if current > candidate {
                                current
                            } else {
                                candidate
                            };
                            Some(ScalarValue::$SCALAR_VARIANT(winner))
                        }
                        Some(_) => {
                            return Err(ExecutionError::InternalError(
                                "Unexpected ScalarValue variant".to_string(),
                            ))
                        }
                    };
                }
                Ok(())
            }
        }
    }};
}
/// Accumulator state for the MAX aggregate: the input expression plus the
/// largest value observed so far.
struct MaxAccumulator {
/// Expression evaluated against each batch to obtain the values to compare.
expr: Arc<dyn PhysicalExpr>,
/// Running maximum; `None` until a valid (non-null) value has been accumulated.
max: Option<ScalarValue>,
}

impl Accumulator for MaxAccumulator {
    /// Fold row `row_index` of `batch` into the running maximum.
    ///
    /// The input expression is evaluated against `batch`, then the resulting
    /// array is dispatched on the expression's data type. Signed integers are
    /// tracked as `Int64`, unsigned integers as `UInt64`, and floats keep
    /// their width.
    ///
    /// # Errors
    ///
    /// Propagates evaluation and type-resolution errors, and returns
    /// `ExecutionError::General` for unsupported input types.
    fn accumulate(&mut self, batch: &RecordBatch, row_index: usize) -> Result<()> {
        let values = self.expr.evaluate(batch)?;
        let input_type = self.expr.data_type(batch.schema())?;
        match input_type {
            DataType::Int8 => max_accumulate!(self, values, row_index, Int8Array, Int64, i64),
            DataType::Int16 => max_accumulate!(self, values, row_index, Int16Array, Int64, i64),
            DataType::Int32 => max_accumulate!(self, values, row_index, Int32Array, Int64, i64),
            DataType::Int64 => max_accumulate!(self, values, row_index, Int64Array, Int64, i64),
            DataType::UInt8 => max_accumulate!(self, values, row_index, UInt8Array, UInt64, u64),
            DataType::UInt16 => {
                max_accumulate!(self, values, row_index, UInt16Array, UInt64, u64)
            }
            DataType::UInt32 => {
                max_accumulate!(self, values, row_index, UInt32Array, UInt64, u64)
            }
            DataType::UInt64 => {
                max_accumulate!(self, values, row_index, UInt64Array, UInt64, u64)
            }
            DataType::Float32 => {
                max_accumulate!(self, values, row_index, Float32Array, Float32, f32)
            }
            DataType::Float64 => {
                max_accumulate!(self, values, row_index, Float64Array, Float64, f64)
            }
            unsupported => Err(ExecutionError::General(format!(
                "MAX does not support {:?}",
                unsupported
            ))),
        }
    }

    /// Return a copy of the current maximum, or `None` if nothing has been
    /// recorded yet.
    fn get_value(&self) -> Result<Option<ScalarValue>> {
        let current = self.max.clone();
        Ok(current)
    }
}

/// Create a MAX aggregate expression over `expr`, returned as a shared
/// `AggregateExpr` trait object.
pub fn max(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
    let aggregate = Max::new(expr);
    Arc::new(aggregate)
}

/// COUNT aggregate expression
/// Returns the amount of non-null values of the given expression.
pub struct Count {
Expand Down Expand Up @@ -693,6 +824,21 @@ mod tests {
Ok(())
}

#[test]
fn max_contract() -> Result<()> {
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

    // The aggregate over column 0 and the combiner derived from it must
    // report the same name and the same widened result type.
    let max_expr = max(col(0));
    let combiner = max_expr.create_combiner(0);

    for aggregate in vec![max_expr, combiner] {
        assert_eq!(String::from("MAX"), aggregate.name());
        assert_eq!(DataType::Int64, aggregate.data_type(&schema)?);
    }

    Ok(())
}

#[test]
fn sum_i32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand All @@ -705,6 +851,18 @@ mod tests {
Ok(())
}

#[test]
fn max_i32() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

    let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    // Int32 input is reported as an Int64 scalar.
    let observed = do_max(&batch)?;
    assert_eq!(observed, Some(ScalarValue::Int64(5)));

    Ok(())
}

#[test]
fn sum_i32_with_nulls() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand All @@ -717,6 +875,18 @@ mod tests {
Ok(())
}

#[test]
fn max_i32_with_nulls() -> Result<()> {
    // The column contains a null, so the field must be declared nullable for
    // the schema to describe the data it is paired with.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

    // The null row is skipped; the maximum comes from the valid rows only.
    assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5)));

    Ok(())
}

#[test]
fn sum_i32_all_nulls() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand All @@ -729,6 +899,18 @@ mod tests {
Ok(())
}

#[test]
fn max_i32_all_nulls() -> Result<()> {
    // Every value is null, so the field must be declared nullable for the
    // schema to describe the data it is paired with.
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

    let a = Int32Array::from(vec![None, None]);
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

    // With no valid rows the accumulator never records a value.
    assert_eq!(do_max(&batch)?, None);

    Ok(())
}

#[test]
fn sum_u32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
Expand All @@ -741,6 +923,18 @@ mod tests {
Ok(())
}

#[test]
fn max_u32() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, false)]));

    let values = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    // UInt32 input is reported as a UInt64 scalar.
    let observed = do_max(&batch)?;
    assert_eq!(observed, Some(ScalarValue::UInt64(5_u64)));

    Ok(())
}

#[test]
fn sum_f32() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]);
Expand All @@ -753,6 +947,18 @@ mod tests {
Ok(())
}

#[test]
fn max_f32() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));

    let values = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    // Float32 input keeps its width in the result.
    let observed = do_max(&batch)?;
    assert_eq!(observed, Some(ScalarValue::Float32(5_f32)));

    Ok(())
}

#[test]
fn sum_f64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]);
Expand All @@ -765,6 +971,18 @@ mod tests {
Ok(())
}

#[test]
fn max_f64() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float64, false)]));

    let values = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]);
    let batch = RecordBatch::try_new(schema, vec![Arc::new(values)])?;

    // Float64 input keeps its width in the result.
    let observed = do_max(&batch)?;
    assert_eq!(observed, Some(ScalarValue::Float64(5_f64)));

    Ok(())
}

#[test]
fn count_elements() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down Expand Up @@ -812,6 +1030,16 @@ mod tests {
accum.get_value()
}

/// Test helper: run a MAX aggregate over column 0 of `batch`, feeding every
/// row to a fresh accumulator, and return the accumulator's final value.
fn do_max(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
    let aggregate = max(col(0));
    let cell = aggregate.create_accumulator();
    let mut state = cell.borrow_mut();
    // Stops at the first row whose accumulation fails and propagates that error.
    (0..batch.num_rows()).try_for_each(|row| state.accumulate(batch, row))?;
    state.get_value()
}

fn do_count(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
let count = count(col(0));
let accum = count.create_accumulator();
Expand Down

0 comments on commit 9f98de7

Please sign in to comment.