From bf68fa83f24167bba737c0902e75b9280b12a881 Mon Sep 17 00:00:00 2001 From: Adam Lippai Date: Fri, 4 Oct 2019 16:15:12 -0600 Subject: [PATCH] ARROW-6656: [Rust][Datafusion] Add MAX, MIN expressions This is a naive implementation, just copy-paste, may need some refactor :) Closes #5557 from alippai/ARROW-6656 and squashes the following commits: 3be218aba Update to use new trait definition 85eed74b7 ARROW-6656: Add MIN expression 9f98de7bf ARROW-6656: Add MAX expression Lead-authored-by: Adam Lippai Co-authored-by: Andy Grove Signed-off-by: Andy Grove --- rust/datafusion/src/execution/context.rs | 64 ++- .../execution/physical_plan/expressions.rs | 474 ++++++++++++++++++ 2 files changed, 537 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index f07c8b987faab..92bb22c950bf8 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -38,7 +38,7 @@ use crate::execution::limit::LimitRelation; use crate::execution::physical_plan::common; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::expressions::{ - BinaryExpr, CastExpr, Column, Count, Literal, Sum, + BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, Sum, }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::merge::MergeExec; @@ -333,6 +333,12 @@ impl ExecutionContext { "sum" => Ok(Arc::new(Sum::new( self.create_physical_expr(&args[0], input_schema)?, ))), + "max" => Ok(Arc::new(Max::new( + self.create_physical_expr(&args[0], input_schema)?, + ))), + "min" => Ok(Arc::new(Min::new( + self.create_physical_expr(&args[0], input_schema)?, + ))), "count" => Ok(Arc::new(Count::new( self.create_physical_expr(&args[0], input_schema)?, ))), @@ -630,6 +636,34 @@ mod tests { Ok(()) } + #[test] + fn aggregate_max() -> Result<()> { + let results = execute("SELECT MAX(c1), MAX(c2) FROM test", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["3,10"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + + Ok(()) + } + + #[test] + fn aggregate_min() -> Result<()> { + let results = execute("SELECT MIN(c1), MIN(c2) FROM test", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["0,1"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + + Ok(()) + } + #[test] fn aggregate_grouped() -> Result<()> { let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4)?; @@ -644,6 +678,34 @@ mod tests { Ok(()) } + #[test] + fn aggregate_grouped_max() -> Result<()> { + let results = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + + Ok(()) + } + + #[test] + fn aggregate_grouped_min() -> Result<()> { + let results = execute("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + + Ok(()) + } + #[test] fn count_basic() -> Result<()> { let results = execute("SELECT COUNT(c1), COUNT(c2) FROM test", 1)?; diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index 5f535367efdb5..b8aa09754756b 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -208,6 +208,284 @@ pub fn sum(expr: Arc) -> Arc { Arc::new(Sum::new(expr)) } +/// MAX aggregate expression +pub struct Max { + expr: Arc, +} + +impl Max { + /// Create a new MAX aggregate function + pub fn new(expr: Arc) -> Self { + Self { expr } + } +} + +impl AggregateExpr for Max { + fn name(&self) -> String { + "MAX".to_string() + } + + fn data_type(&self, input_schema: &Schema) -> Result { + match self.expr.data_type(input_schema)? { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + Ok(DataType::Int64) + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + Ok(DataType::UInt64) + } + DataType::Float32 => Ok(DataType::Float32), + DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "MAX does not support {:?}", + other + ))), + } + } + + fn evaluate_input(&self, batch: &RecordBatch) -> Result { + self.expr.evaluate(batch) + } + + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(MaxAccumulator { + expr: self.expr.clone(), + max: None, + })) + } + + fn create_combiner(&self, column_index: usize) -> Arc { + Arc::new(Max::new(Arc::new(Column::new(column_index)))) + } +} + +macro_rules! max_accumulate { + ($SELF:ident, $ARRAY:ident, $ROW_INDEX:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{ + if let Some(array) = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>() { + if $ARRAY.is_valid($ROW_INDEX) { + let value = array.value($ROW_INDEX); + $SELF.max = match $SELF.max { + Some(ScalarValue::$SCALAR_VARIANT(n)) => { + if n > (value as $TY) { + Some(ScalarValue::$SCALAR_VARIANT(n)) + } else { + Some(ScalarValue::$SCALAR_VARIANT(value as $TY)) + } + } + Some(_) => { + return Err(ExecutionError::InternalError( + "Unexpected ScalarValue variant".to_string(), + )) + } + None => Some(ScalarValue::$SCALAR_VARIANT(value as $TY)), + }; + } + Ok(()) + } else { + Err(ExecutionError::General( + "Failed to downcast array".to_string(), + )) + } + }}; +} +struct MaxAccumulator { + expr: Arc, + max: Option, +} + +impl Accumulator for MaxAccumulator { + fn accumulate( + &mut self, + batch: &RecordBatch, + array: &ArrayRef, + row_index: usize, + ) -> Result<()> { + match self.expr.data_type(batch.schema())? { + DataType::Int8 => { + max_accumulate!(self, array, row_index, Int8Array, Int64, i64) + } + DataType::Int16 => { + max_accumulate!(self, array, row_index, Int16Array, Int64, i64) + } + DataType::Int32 => { + max_accumulate!(self, array, row_index, Int32Array, Int64, i64) + } + DataType::Int64 => { + max_accumulate!(self, array, row_index, Int64Array, Int64, i64) + } + DataType::UInt8 => { + max_accumulate!(self, array, row_index, UInt8Array, UInt64, u64) + } + DataType::UInt16 => { + max_accumulate!(self, array, row_index, UInt16Array, UInt64, u64) + } + DataType::UInt32 => { + max_accumulate!(self, array, row_index, UInt32Array, UInt64, u64) + } + DataType::UInt64 => { + max_accumulate!(self, array, row_index, UInt64Array, UInt64, u64) + } + DataType::Float32 => { + max_accumulate!(self, array, row_index, Float32Array, Float32, f32) + } + DataType::Float64 => { + max_accumulate!(self, array, row_index, Float64Array, Float64, f64) + } + other => Err(ExecutionError::General(format!( + "MAX does not support {:?}", + other + ))), + } + } + + fn get_value(&self) -> Result> { + Ok(self.max.clone()) + } +} + +/// Create a max expression +pub fn max(expr: Arc) -> Arc { + Arc::new(Max::new(expr)) +} + +/// MIN aggregate expression +pub struct Min { + expr: Arc, +} + +impl Min { + /// Create a new MIN aggregate function + pub fn new(expr: Arc) -> Self { + Self { expr } + } +} + +impl AggregateExpr for Min { + fn name(&self) -> String { + "MIN".to_string() + } + + fn data_type(&self, input_schema: &Schema) -> Result { + match self.expr.data_type(input_schema)? { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + Ok(DataType::Int64) + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + Ok(DataType::UInt64) + } + DataType::Float32 => Ok(DataType::Float32), + DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "MIN does not support {:?}", + other + ))), + } + } + + fn evaluate_input(&self, batch: &RecordBatch) -> Result { + self.expr.evaluate(batch) + } + + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(MinAccumulator { + expr: self.expr.clone(), + min: None, + })) + } + + fn create_combiner(&self, column_index: usize) -> Arc { + Arc::new(Min::new(Arc::new(Column::new(column_index)))) + } +} + +macro_rules! min_accumulate { + ($SELF:ident, $ARRAY:ident, $ROW_INDEX:expr, $ARRAY_TYPE:ident, $SCALAR_VARIANT:ident, $TY:ty) => {{ + if let Some(array) = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>() { + if $ARRAY.is_valid($ROW_INDEX) { + let value = array.value($ROW_INDEX); + $SELF.min = match $SELF.min { + Some(ScalarValue::$SCALAR_VARIANT(n)) => { + if n < (value as $TY) { + Some(ScalarValue::$SCALAR_VARIANT(n)) + } else { + Some(ScalarValue::$SCALAR_VARIANT(value as $TY)) + } + } + Some(_) => { + return Err(ExecutionError::InternalError( + "Unexpected ScalarValue variant".to_string(), + )) + } + None => Some(ScalarValue::$SCALAR_VARIANT(value as $TY)), + }; + } + Ok(()) + } else { + Err(ExecutionError::General( + "Failed to downcast array".to_string(), + )) + } + }}; +} +struct MinAccumulator { + expr: Arc, + min: Option, +} + +impl Accumulator for MinAccumulator { + fn accumulate( + &mut self, + batch: &RecordBatch, + array: &ArrayRef, + row_index: usize, + ) -> Result<()> { + match self.expr.data_type(batch.schema())? { + DataType::Int8 => { + min_accumulate!(self, array, row_index, Int8Array, Int64, i64) + } + DataType::Int16 => { + min_accumulate!(self, array, row_index, Int16Array, Int64, i64) + } + DataType::Int32 => { + min_accumulate!(self, array, row_index, Int32Array, Int64, i64) + } + DataType::Int64 => { + min_accumulate!(self, array, row_index, Int64Array, Int64, i64) + } + DataType::UInt8 => { + min_accumulate!(self, array, row_index, UInt8Array, UInt64, u64) + } + DataType::UInt16 => { + min_accumulate!(self, array, row_index, UInt16Array, UInt64, u64) + } + DataType::UInt32 => { + min_accumulate!(self, array, row_index, UInt32Array, UInt64, u64) + } + DataType::UInt64 => { + min_accumulate!(self, array, row_index, UInt64Array, UInt64, u64) + } + DataType::Float32 => { + min_accumulate!(self, array, row_index, Float32Array, Float32, f32) + } + DataType::Float64 => { + min_accumulate!(self, array, row_index, Float64Array, Float64, f64) + } + other => Err(ExecutionError::General(format!( + "MIN does not support {:?}", + other + ))), + } + } + + fn get_value(&self) -> Result> { + Ok(self.min.clone()) + } +} + +/// Create a min expression +pub fn min(expr: Arc) -> Arc { + Arc::new(Min::new(expr)) +} + /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. pub struct Count { @@ -693,6 +971,36 @@ mod tests { Ok(()) } + #[test] + fn max_contract() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let max = max(col(0)); + assert_eq!("MAX".to_string(), max.name()); + assert_eq!(DataType::Int64, max.data_type(&schema)?); + + let combiner = max.create_combiner(0); + assert_eq!("MAX".to_string(), combiner.name()); + assert_eq!(DataType::Int64, combiner.data_type(&schema)?); + + Ok(()) + } + + #[test] + fn min_contract() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let min = min(col(0)); + assert_eq!("MIN".to_string(), min.name()); + assert_eq!(DataType::Int64, min.data_type(&schema)?); + + let combiner = min.create_combiner(0); + assert_eq!("MIN".to_string(), combiner.name()); + assert_eq!(DataType::Int64, combiner.data_type(&schema)?); + + Ok(()) + } + #[test] fn sum_i32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -705,6 +1013,30 @@ mod tests { Ok(()) } + #[test] + fn max_i32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5))); + + Ok(()) + } + + #[test] + fn min_i32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_min(&batch)?, Some(ScalarValue::Int64(1))); + + Ok(()) + } + #[test] fn sum_i32_with_nulls() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -717,6 +1049,30 @@ mod tests { Ok(()) } + #[test] + fn max_i32_with_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_max(&batch)?, Some(ScalarValue::Int64(5))); + + Ok(()) + } + + #[test] + fn min_i32_with_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![Some(1), None, Some(3), Some(4), Some(5)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_min(&batch)?, Some(ScalarValue::Int64(1))); + + Ok(()) + } + #[test] fn sum_i32_all_nulls() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -729,6 +1085,30 @@ mod tests { Ok(()) } + #[test] + fn max_i32_all_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![None, None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_max(&batch)?, None); + + Ok(()) + } + + #[test] + fn min_i32_all_nulls() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let a = Int32Array::from(vec![None, None]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_min(&batch)?, None); + + Ok(()) + } + #[test] fn sum_u32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); @@ -741,6 +1121,30 @@ mod tests { Ok(()) } + #[test] + fn max_u32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + + let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_max(&batch)?, Some(ScalarValue::UInt64(5_u64))); + + Ok(()) + } + + #[test] + fn min_u32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]); + + let a = UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_min(&batch)?, Some(ScalarValue::UInt64(1_u64))); + + Ok(()) + } + #[test] fn sum_f32() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); @@ -753,6 +1157,30 @@ mod tests { Ok(()) } + #[test] + fn max_f32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + + let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_max(&batch)?, Some(ScalarValue::Float32(5_f32))); + + Ok(()) + } + + #[test] + fn min_f32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float32, false)]); + + let a = Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_min(&batch)?, Some(ScalarValue::Float32(1_f32))); + + Ok(()) + } + #[test] fn sum_f64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); @@ -765,6 +1193,30 @@ mod tests { Ok(()) } + #[test] + fn max_f64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_max(&batch)?, Some(ScalarValue::Float64(5_f64))); + + Ok(()) + } + + #[test] + fn min_f64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Float64, false)]); + + let a = Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + assert_eq!(do_min(&batch)?, Some(ScalarValue::Float64(1_f64))); + + Ok(()) + } + #[test] fn count_elements() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); @@ -812,6 +1264,28 @@ mod tests { accum.get_value() } + fn do_max(batch: &RecordBatch) -> Result> { + let max = max(col(0)); + let accum = max.create_accumulator(); + let input = max.evaluate_input(batch)?; + let mut accum = accum.borrow_mut(); + for i in 0..batch.num_rows() { + accum.accumulate(&batch, &input, i)?; + } + accum.get_value() + } + + fn do_min(batch: &RecordBatch) -> Result> { + let min = min(col(0)); + let accum = min.create_accumulator(); + let input = min.evaluate_input(batch)?; + let mut accum = accum.borrow_mut(); + for i in 0..batch.num_rows() { + accum.accumulate(&batch, &input, i)?; + } + accum.get_value() + } + fn do_count(batch: &RecordBatch) -> Result> { let count = count(col(0)); let accum = count.create_accumulator();