diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 272e75acba6f..cfd3b7194429 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1268,6 +1268,35 @@ mod tests { Ok(()) } + #[tokio::test] + async fn window() -> Result<()> { + let results = execute( + "SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5", + 4, + ) + .await?; + // result in one batch, although e.g. having 2 batches do not change + // result semantics, having a len=1 assertion upfront keeps surprises + // at bay + assert_eq!(results.len(), 1); + + let expected = vec![ + "+----+----+---------+-----------+---------+---------+---------+", + "| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |", + "+----+----+---------+-----------+---------+---------+---------+", + "| 0 | 1 | 220 | 40 | 10 | 1 | 5.5 |", + "| 0 | 2 | 220 | 40 | 10 | 1 | 5.5 |", + "| 0 | 3 | 220 | 40 | 10 | 1 | 5.5 |", + "| 0 | 4 | 220 | 40 | 10 | 1 | 5.5 |", + "| 0 | 5 | 220 | 40 | 10 | 1 | 5.5 |", + "+----+----+---------+-----------+---------+---------+---------+", + ]; + + // window function shall respect ordering + assert_batches_eq!(expected, &results); + Ok(()) + } + #[tokio::test] async fn aggregate() -> Result<()> { let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?; diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 4d57c39bb31c..803870f3f784 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -41,6 +41,7 @@ mod min_max; mod negative; mod not; mod nullif; +mod row_number; mod sum; mod try_cast; @@ -58,6 +59,7 @@ pub use min_max::{Max, Min}; pub use negative::{negative, NegativeExpr}; pub use not::{not, NotExpr}; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; +pub use row_number::RowNumber; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; /// returns the name of the state diff --git a/datafusion/src/physical_plan/expressions/row_number.rs b/datafusion/src/physical_plan/expressions/row_number.rs new file mode 100644 index 000000000000..f399995461f7 --- /dev/null +++ b/datafusion/src/physical_plan/expressions/row_number.rs @@ -0,0 +1,174 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expression for `row_number` that can evaluated at runtime during query execution + +use crate::error::Result; +use crate::physical_plan::{ + window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator, +}; +use crate::scalar::ScalarValue; +use arrow::array::{ArrayRef, UInt64Array}; +use arrow::datatypes::{DataType, Field}; +use std::any::Any; +use std::sync::Arc; + +/// row_number expression +#[derive(Debug)] +pub struct RowNumber { + name: String, +} + +impl RowNumber { + /// Create a new ROW_NUMBER function + pub fn new(name: String) -> Self { + Self { name } + } +} + +impl BuiltInWindowFunctionExpr for RowNumber { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + let nullable = false; + let data_type = DataType::UInt64; + Ok(Field::new(&self.name(), data_type, nullable)) + } + + fn expressions(&self) -> Vec> { + vec![] + } + + fn name(&self) -> &str { + self.name.as_str() + } + + fn create_accumulator(&self) -> Result> { + Ok(Box::new(RowNumberAccumulator::new())) + } +} + +#[derive(Debug)] +struct RowNumberAccumulator { + row_number: u64, +} + +impl RowNumberAccumulator { + /// new row_number accumulator + pub fn new() -> Self { + // row number is 1 based + Self { row_number: 1 } + } +} + +impl WindowAccumulator for RowNumberAccumulator { + fn scan(&mut self, _values: &[ScalarValue]) -> Result> { + let result = Some(ScalarValue::UInt64(Some(self.row_number))); + self.row_number += 1; + Ok(result) + } + + fn scan_batch( + &mut self, + num_rows: usize, + _values: &[ArrayRef], + ) -> Result> { + let new_row_number = self.row_number + (num_rows as u64); + // TODO: probably would be nice to have a (optimized) kernel for this at some point to + // generate an array like this. + let result = UInt64Array::from_iter_values(self.row_number..new_row_number); + self.row_number = new_row_number; + Ok(Some(Arc::new(result))) + } + + fn evaluate(&self) -> Result> { + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use arrow::record_batch::RecordBatch; + use arrow::{array::*, datatypes::*}; + + #[test] + fn row_number_all_null() -> Result<()> { + let arr: ArrayRef = Arc::new(BooleanArray::from(vec![ + None, None, None, None, None, None, None, None, + ])); + let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; + + let row_number = Arc::new(RowNumber::new("row_number".to_owned())); + + let mut acc = row_number.create_accumulator()?; + let expr = row_number.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(&batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + + let result = acc.scan_batch(batch.num_rows(), &values)?; + assert_eq!(true, result.is_some()); + + let result = result.unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + let result = result.values(); + assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); + + let result = acc.evaluate()?; + assert_eq!(false, result.is_some()); + Ok(()) + } + + #[test] + fn row_number_all_values() -> Result<()> { + let arr: ArrayRef = Arc::new(BooleanArray::from(vec![ + true, false, true, false, false, true, false, true, + ])); + let schema = Schema::new(vec![Field::new("arr", DataType::Boolean, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?; + + let row_number = Arc::new(RowNumber::new("row_number".to_owned())); + + let mut acc = row_number.create_accumulator()?; + let expr = row_number.expressions(); + let values = expr + .iter() + .map(|e| e.evaluate(&batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + + let result = acc.scan_batch(batch.num_rows(), &values)?; + assert_eq!(true, result.is_some()); + + let result = result.unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + let result = result.values(); + assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); + + let result = acc.evaluate()?; + assert_eq!(false, result.is_some()); + Ok(()) + } +} diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index c9d268619cad..5008f49250b0 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -712,7 +712,7 @@ impl GroupedHashAggregateStream { tx.send(result) }); - GroupedHashAggregateStream { + Self { schema, output: rx, finished: false, @@ -825,7 +825,8 @@ fn aggregate_expressions( } pin_project! { - struct HashAggregateStream { + /// stream struct for hash aggregation + pub struct HashAggregateStream { schema: SchemaRef, #[pin] output: futures::channel::oneshot::Receiver>, @@ -878,7 +879,7 @@ impl HashAggregateStream { tx.send(result) }); - HashAggregateStream { + Self { schema, output: rx, finished: false, diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index c053229bc000..4f90a8cf7d6e 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -17,22 +17,23 @@ //! Traits for physical query plan, supporting parallel execution for partitioned relations. -use std::fmt::{self, Debug, Display}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::{any::Any, pin::Pin}; - use crate::execution::context::ExecutionContextState; use crate::logical_plan::LogicalPlan; -use crate::{error::Result, scalar::ScalarValue}; +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; - use async_trait::async_trait; pub use display::DisplayFormatType; use futures::stream::Stream; +use std::fmt::{self, Debug, Display}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::{any::Any, pin::Pin}; use self::{display::DisplayableExecutionPlan, merge::MergeExec}; use hashbrown::HashMap; @@ -457,10 +458,22 @@ pub trait WindowExpr: Send + Sync + Debug { fn name(&self) -> &str { "WindowExpr: default name" } + + /// the accumulator used to accumulate values from the expressions. + /// the accumulator expects the same number of arguments as `expressions` and must + /// return states with the same description as `state_fields` + fn create_accumulator(&self) -> Result>; + + /// expressions that are passed to the WindowAccumulator. + /// Functions which take a single input argument, such as `sum`, return a single [`Expr`], + /// others (e.g. `cov`) return many. + fn expressions(&self) -> Vec>; } /// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and -/// generically accumulates values. An accumulator knows how to: +/// generically accumulates values. +/// +/// An accumulator knows how to: /// * update its state from inputs via `update` /// * convert its internal state to a vector of scalar values /// * update its state from multiple accumulators' states via `merge` @@ -509,6 +522,58 @@ pub trait Accumulator: Send + Sync + Debug { fn evaluate(&self) -> Result; } +/// A window accumulator represents a stateful object that lives throughout the evaluation of multiple +/// rows and generically accumulates values. +/// +/// An accumulator knows how to: +/// * update its state from inputs via `update` +/// * convert its internal state to a vector of scalar values +/// * update its state from multiple accumulators' states via `merge` +/// * compute the final value from its internal state via `evaluate` +pub trait WindowAccumulator: Send + Sync + Debug { + /// scans the accumulator's state from a vector of scalars, similar to Accumulator it also + /// optionally generates values. + fn scan(&mut self, values: &[ScalarValue]) -> Result>; + + /// scans the accumulator's state from a vector of arrays. + fn scan_batch( + &mut self, + num_rows: usize, + values: &[ArrayRef], + ) -> Result> { + if values.is_empty() { + return Ok(None); + }; + // transpose columnar to row based so that we can apply window + let result = (0..num_rows) + .map(|index| { + let v = values + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + self.scan(&v) + }) + .collect::>>>()? + .into_iter() + .collect::>>(); + + Ok(match result { + Some(arr) if num_rows == arr.len() => Some(ScalarValue::iter_to_array(&arr)?), + None => None, + Some(arr) => { + return Err(DataFusionError::Internal(format!( + "expect scan batch to return {:?} rows, but got {:?}", + num_rows, + arr.len() + ))) + } + }) + } + + /// returns its value based on its current state. + fn evaluate(&self) -> Result>; +} + pub mod aggregates; pub mod array_expressions; pub mod coalesce_batches; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 018925d0e535..7ddfaf8f6897 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -147,8 +147,10 @@ impl DefaultPhysicalPlanner { // Initially need to perform the aggregate and then merge the partitions let input_exec = self.create_initial_plan(input, ctx_state)?; let input_schema = input_exec.schema(); - let physical_input_schema = input_exec.as_ref().schema(); + let logical_input_schema = input.as_ref().schema(); + let physical_input_schema = input_exec.as_ref().schema(); + let window_expr = window_expr .iter() .map(|e| { diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs index 7cd4d9df7875..c5b838c6e84b 100644 --- a/datafusion/src/physical_plan/sort.rs +++ b/datafusion/src/physical_plan/sort.rs @@ -250,6 +250,7 @@ fn sort_batches( } pin_project! { + /// stream for sort plan struct SortStream { #[pin] output: futures::channel::oneshot::Receiver>>, diff --git a/datafusion/src/physical_plan/window_functions.rs b/datafusion/src/physical_plan/window_functions.rs index 65d5373d54f4..e6afcaad8ad6 100644 --- a/datafusion/src/physical_plan/window_functions.rs +++ b/datafusion/src/physical_plan/window_functions.rs @@ -20,12 +20,15 @@ //! //! see also https://www.postgresql.org/docs/current/functions-window.html +use crate::arrow::datatypes::Field; use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ aggregates, aggregates::AggregateFunction, functions::Signature, - type_coercion::data_types, + type_coercion::data_types, PhysicalExpr, WindowAccumulator, }; use arrow::datatypes::DataType; +use std::any::Any; +use std::sync::Arc; use std::{fmt, str::FromStr}; /// WindowFunction @@ -143,52 +146,92 @@ impl FromStr for BuiltInWindowFunction { /// Returns the datatype of the window function pub fn return_type(fun: &WindowFunction, arg_types: &[DataType]) -> Result { + match fun { + WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types), + WindowFunction::BuiltInWindowFunction(fun) => { + return_type_for_built_in(fun, arg_types) + } + } +} + +/// Returns the datatype of the built-in window function +pub(super) fn return_type_for_built_in( + fun: &BuiltInWindowFunction, + arg_types: &[DataType], +) -> Result { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. // verify that this is a valid set of data types for this function - data_types(arg_types, &signature(fun))?; + data_types(arg_types, &signature_for_built_in(fun))?; match fun { - WindowFunction::AggregateFunction(fun) => aggregates::return_type(fun, arg_types), - WindowFunction::BuiltInWindowFunction(fun) => match fun { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), - BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { - Ok(DataType::Float64) - } - BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue - | BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()), - }, + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank => Ok(DataType::UInt64), + BuiltInWindowFunction::PercentRank | BuiltInWindowFunction::CumeDist => { + Ok(DataType::Float64) + } + BuiltInWindowFunction::Ntile => Ok(DataType::UInt32), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue + | BuiltInWindowFunction::NthValue => Ok(arg_types[0].clone()), } } /// the signatures supported by the function `fun`. -fn signature(fun: &WindowFunction) -> Signature { - // note: the physical expression must accept the type returned by this function or the execution panics. +pub fn signature(fun: &WindowFunction) -> Signature { match fun { WindowFunction::AggregateFunction(fun) => aggregates::signature(fun), - WindowFunction::BuiltInWindowFunction(fun) => match fun { - BuiltInWindowFunction::RowNumber - | BuiltInWindowFunction::Rank - | BuiltInWindowFunction::DenseRank - | BuiltInWindowFunction::PercentRank - | BuiltInWindowFunction::CumeDist => Signature::Any(0), - BuiltInWindowFunction::Lag - | BuiltInWindowFunction::Lead - | BuiltInWindowFunction::FirstValue - | BuiltInWindowFunction::LastValue => Signature::Any(1), - BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]), - BuiltInWindowFunction::NthValue => Signature::Any(2), - }, + WindowFunction::BuiltInWindowFunction(fun) => signature_for_built_in(fun), + } +} + +/// the signatures supported by the built-in window function `fun`. +pub(super) fn signature_for_built_in(fun: &BuiltInWindowFunction) -> Signature { + // note: the physical expression must accept the type returned by this function or the execution panics. + match fun { + BuiltInWindowFunction::RowNumber + | BuiltInWindowFunction::Rank + | BuiltInWindowFunction::DenseRank + | BuiltInWindowFunction::PercentRank + | BuiltInWindowFunction::CumeDist => Signature::Any(0), + BuiltInWindowFunction::Lag + | BuiltInWindowFunction::Lead + | BuiltInWindowFunction::FirstValue + | BuiltInWindowFunction::LastValue => Signature::Any(1), + BuiltInWindowFunction::Ntile => Signature::Exact(vec![DataType::UInt64]), + BuiltInWindowFunction::NthValue => Signature::Any(2), } } +/// A window expression that is a built-in window function +pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { + /// Returns the aggregate expression as [`Any`](std::any::Any) so that it can be + /// downcast to a specific implementation. + fn as_any(&self) -> &dyn Any; + + /// the field of the final result of this aggregation. + fn field(&self) -> Result; + + /// expressions that are passed to the Accumulator. + /// Single-column aggregations such as `sum` return a single value, others (e.g. `cov`) return many. + fn expressions(&self) -> Vec>; + + /// Human readable name such as `"MIN(c2)"` or `"RANK()"`. The default + /// implementation returns placeholder text. + fn name(&self) -> &str { + "BuiltInWindowFunctionExpr: default name" + } + + /// the accumulator used to accumulate values from the expressions. + /// the accumulator expects the same number of arguments as `expressions` and must + /// return states with the same description as `state_fields` + fn create_accumulator(&self) -> Result>; +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs index bdd25d69fd55..8ced3aec8ec1 100644 --- a/datafusion/src/physical_plan/windows.rs +++ b/datafusion/src/physical_plan/windows.rs @@ -19,13 +19,30 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{ - aggregates, window_functions::WindowFunction, AggregateExpr, Distribution, - ExecutionPlan, Partitioning, PhysicalExpr, SendableRecordBatchStream, WindowExpr, + aggregates, + expressions::RowNumber, + window_functions::BuiltInWindowFunctionExpr, + window_functions::{BuiltInWindowFunction, WindowFunction}, + Accumulator, AggregateExpr, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, + RecordBatchStream, SendableRecordBatchStream, WindowAccumulator, WindowExpr, +}; +use crate::scalar::ScalarValue; +use arrow::compute::concat; +use arrow::{ + array::{Array, ArrayRef}, + datatypes::{Field, Schema, SchemaRef}, + error::{ArrowError, Result as ArrowResult}, + record_batch::RecordBatch, }; -use arrow::datatypes::{Field, Schema, SchemaRef}; use async_trait::async_trait; +use futures::stream::{Stream, StreamExt}; +use futures::Future; +use pin_project_lite::pin_project; use std::any::Any; +use std::iter; +use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; /// Window execution plan #[derive(Debug)] @@ -57,18 +74,55 @@ pub fn create_window_expr( name, )?, })), - WindowFunction::BuiltInWindowFunction(fun) => { - Err(DataFusionError::NotImplemented(format!( - "window function with {:?} not implemented", - fun - ))) - } + WindowFunction::BuiltInWindowFunction(fun) => Ok(Arc::new(BuiltInWindowExpr { + window: create_built_in_window_expr(fun, args, input_schema, name)?, + })), + } +} + +fn create_built_in_window_expr( + fun: &BuiltInWindowFunction, + _args: &[Arc], + _input_schema: &Schema, + name: String, +) -> Result> { + match fun { + BuiltInWindowFunction::RowNumber => Ok(Arc::new(RowNumber::new(name))), + _ => Err(DataFusionError::NotImplemented(format!( + "Window function with {:?} not yet implemented", + fun + ))), } } /// A window expr that takes the form of a built in window function #[derive(Debug)] -pub struct BuiltInWindowExpr {} +pub struct BuiltInWindowExpr { + window: Arc, +} + +impl WindowExpr for BuiltInWindowExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.window.name() + } + + fn field(&self) -> Result { + self.window.field() + } + + fn expressions(&self) -> Vec> { + self.window.expressions() + } + + fn create_accumulator(&self) -> Result> { + self.window.create_accumulator() + } +} /// A window expr that takes the form of an aggregate function #[derive(Debug)] @@ -76,6 +130,23 @@ pub struct AggregateWindowExpr { aggregate: Arc, } +#[derive(Debug)] +struct AggregateWindowAccumulator { + accumulator: Box, +} + +impl WindowAccumulator for AggregateWindowAccumulator { + fn scan(&mut self, values: &[ScalarValue]) -> Result> { + self.accumulator.update(values)?; + Ok(None) + } + + /// returns its value based on its current state. + fn evaluate(&self) -> Result> { + Ok(Some(self.accumulator.evaluate()?)) + } +} + impl WindowExpr for AggregateWindowExpr { /// Return a reference to Any that can be used for downcasting fn as_any(&self) -> &dyn Any { @@ -89,6 +160,15 @@ impl WindowExpr for AggregateWindowExpr { fn field(&self) -> Result { self.aggregate.field() } + + fn expressions(&self) -> Vec> { + self.aggregate.expressions() + } + + fn create_accumulator(&self) -> Result> { + let accumulator = self.aggregate.create_accumulator()?; + Ok(Box::new(AggregateWindowAccumulator { accumulator })) + } } fn create_schema( @@ -120,12 +200,17 @@ impl WindowAggExec { }) } + /// Window expressions + pub fn window_expr(&self) -> &[Arc] { + &self.window_expr + } + /// Input plan pub fn input(&self) -> &Arc { &self.input } - /// Get the input schema before any aggregates are applied + /// Get the input schema before any window functions are applied pub fn input_schema(&self) -> SchemaRef { self.input_schema.clone() } @@ -163,7 +248,7 @@ impl ExecutionPlan for WindowAggExec { 1 => Ok(Arc::new(WindowAggExec::try_new( self.window_expr.clone(), children[0].clone(), - children[0].schema(), + self.input_schema.clone(), )?)), _ => Err(DataFusionError::Internal( "WindowAggExec wrong number of children".to_owned(), @@ -186,10 +271,258 @@ impl ExecutionPlan for WindowAggExec { )); } - // let input = self.input.execute(0).await?; + let input = self.input.execute(partition).await?; + + let stream = Box::pin(WindowAggStream::new( + self.schema.clone(), + self.window_expr.clone(), + input, + )); + Ok(stream) + } +} + +pin_project! { + /// stream for window aggregation plan + pub struct WindowAggStream { + schema: SchemaRef, + #[pin] + output: futures::channel::oneshot::Receiver>, + finished: bool, + } +} + +type WindowAccumulatorItem = Box; + +fn window_expressions( + window_expr: &[Arc], +) -> Result>>> { + Ok(window_expr + .iter() + .map(|expr| expr.expressions()) + .collect::>()) +} + +fn window_aggregate_batch( + batch: &RecordBatch, + window_accumulators: &mut [WindowAccumulatorItem], + expressions: &[Vec>], +) -> Result>> { + // 1.1 iterate accumulators and respective expressions together + // 1.2 evaluate expressions + // 1.3 update / merge window accumulators with the expressions' values + + // 1.1 + window_accumulators + .iter_mut() + .zip(expressions) + .map(|(window_acc, expr)| { + // 1.2 + let values = &expr + .iter() + .map(|e| e.evaluate(batch)) + .map(|r| r.map(|v| v.into_array(batch.num_rows()))) + .collect::>>()?; + + window_acc.scan_batch(batch.num_rows(), values) + }) + .into_iter() + .collect::>>() +} + +/// returns a vector of ArrayRefs, where each entry corresponds to one window expr +fn finalize_window_aggregation( + window_accumulators: &[WindowAccumulatorItem], +) -> Result>> { + window_accumulators + .iter() + .map(|window_accumulator| window_accumulator.evaluate()) + .collect::>>() +} + +fn create_window_accumulators( + window_expr: &[Arc], +) -> Result> { + window_expr + .iter() + .map(|expr| expr.create_accumulator()) + .collect::>>() +} + +async fn compute_window_aggregate( + schema: SchemaRef, + window_expr: Vec>, + mut input: SendableRecordBatchStream, +) -> ArrowResult { + let mut window_accumulators = create_window_accumulators(&window_expr) + .map_err(DataFusionError::into_arrow_external_error)?; + + let expressions = window_expressions(&window_expr) + .map_err(DataFusionError::into_arrow_external_error)?; + + let expressions = Arc::new(expressions); + + // TODO each element shall have some size hint + let mut accumulator: Vec> = + iter::repeat(vec![]).take(window_expr.len()).collect(); + + let mut original_batches: Vec = vec![]; + + let mut total_num_rows = 0; + + while let Some(batch) = input.next().await { + let batch = batch?; + total_num_rows += batch.num_rows(); + original_batches.push(batch.clone()); + + let batch_aggregated = + window_aggregate_batch(&batch, &mut window_accumulators, &expressions) + .map_err(DataFusionError::into_arrow_external_error)?; + accumulator.iter_mut().zip(batch_aggregated).for_each( + |(acc_for_window, window_batch)| { + if let Some(data) = window_batch { + acc_for_window.push(data); + } + }, + ); + } + + let aggregated_mapped = finalize_window_aggregation(&window_accumulators) + .map_err(DataFusionError::into_arrow_external_error)?; + + let mut columns: Vec = accumulator + .iter() + .zip(aggregated_mapped) + .map(|(acc, agg)| { + Ok(match (acc, agg) { + (acc, Some(scalar_value)) if acc.is_empty() => { + scalar_value.to_array_of_size(total_num_rows) + } + (acc, None) if !acc.is_empty() => { + let vec_array: Vec<&dyn Array> = + acc.iter().map(|arc| arc.as_ref()).collect(); + concat(&vec_array)? + } + _ => { + return Err(DataFusionError::Execution( + "Invalid window function behavior".to_owned(), + )) + } + }) + }) + .collect::>>() + .map_err(DataFusionError::into_arrow_external_error)?; + + for i in 0..(schema.fields().len() - window_expr.len()) { + let col = concat( + &original_batches + .iter() + .map(|batch| batch.column(i).as_ref()) + .collect::>(), + )?; + columns.push(col); + } + + RecordBatch::try_new(schema.clone(), columns) +} + +impl WindowAggStream { + /// Create a new WindowAggStream + pub fn new( + schema: SchemaRef, + window_expr: Vec>, + input: SendableRecordBatchStream, + ) -> Self { + let (tx, rx) = futures::channel::oneshot::channel(); + let schema_clone = schema.clone(); + tokio::spawn(async move { + let result = compute_window_aggregate(schema_clone, window_expr, input).await; + tx.send(result) + }); + + Self { + output: rx, + finished: false, + schema, + } + } +} + +impl Stream for WindowAggStream { + type Item = ArrowResult; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.finished { + return Poll::Ready(None); + } - Err(DataFusionError::NotImplemented( - "WindowAggExec::execute".to_owned(), - )) + // is the output ready? + let this = self.project(); + let output_poll = this.output.poll(cx); + + match output_poll { + Poll::Ready(result) => { + *this.finished = true; + // check for error in receiving channel and unwrap actual result + let result = match result { + Err(e) => Some(Err(ArrowError::ExternalError(Box::new(e)))), // error receiving + Ok(result) => Some(result), + }; + Poll::Ready(result) + } + Poll::Pending => Poll::Pending, + } + } +} + +impl RecordBatchStream for WindowAggStream { + /// Get the schema + fn schema(&self) -> SchemaRef { + self.schema.clone() } } + +#[cfg(test)] +mod tests { + // use super::*; + + // /// some mock data to test windows + // fn some_data() -> (Arc, Vec) { + // // define a schema. + // let schema = Arc::new(Schema::new(vec![ + // Field::new("a", DataType::UInt32, false), + // Field::new("b", DataType::Float64, false), + // ])); + + // // define data. + // ( + // schema.clone(), + // vec![ + // RecordBatch::try_new( + // schema.clone(), + // vec![ + // Arc::new(UInt32Array::from(vec![2, 3, 4, 4])), + // Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), + // ], + // ) + // .unwrap(), + // RecordBatch::try_new( + // schema, + // vec![ + // Arc::new(UInt32Array::from(vec![2, 3, 3, 4])), + // Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])), + // ], + // ) + // .unwrap(), + // ], + // ) + // } + + // #[tokio::test] + // async fn window_function() -> Result<()> { + // let input: Arc = unimplemented!(); + // let input_schema = input.schema(); + // let window_expr = vec![]; + // WindowAggExec::try_new(window_expr, input, input_schema); + // } +} diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index e68c53b251e6..55bc88eedf9a 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -797,20 +797,31 @@ async fn csv_query_count() -> Result<()> { Ok(()) } -// FIXME uncomment this when exec is done -// #[tokio::test] -// async fn csv_query_window_with_empty_over() -> Result<()> { -// let mut ctx = ExecutionContext::new(); -// register_aggregate_csv(&mut ctx)?; -// let sql = "SELECT count(c12) over () FROM aggregate_test_100"; -// // FIXME: so far the WindowAggExec is not implemented -// // and the current behavior is to throw not implemented exception - -// let result = execute(&mut ctx, sql).await; -// let expected: Vec> = vec![]; -// assert_eq!(result, expected); -// Ok(()) -// } +#[tokio::test] +async fn csv_query_window_with_empty_over() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_aggregate_csv(&mut ctx)?; + let sql = "select \ + c2, \ + sum(c3) over (), \ + avg(c3) over (), \ + count(c3) over (), \ + max(c3) over (), \ + min(c3) over () \ + from aggregate_test_100 \ + order by c2 \ + limit 5"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![ + vec!["1", "781", "7.81", "100", "125", "-117"], + vec!["1", "781", "7.81", "100", "125", "-117"], + vec!["1", "781", "7.81", "100", "125", "-117"], + vec!["1", "781", "7.81", "100", "125", "-117"], + vec!["1", "781", "7.81", "100", "125", "-117"], + ]; + assert_eq!(expected, actual); + Ok(()) +} #[tokio::test] async fn csv_query_group_by_int_count() -> Result<()> { diff --git a/parquet-testing b/parquet-testing index 8e7badc6a381..ddd898958803 160000 --- a/parquet-testing +++ b/parquet-testing @@ -1 +1 @@ -Subproject commit 8e7badc6a3817a02e06d17b5d8ab6b6dc356e890 +Subproject commit ddd898958803cb89b7156c6350584d1cda0fe8de