Skip to content

Commit

Permalink
add window expression stream, delegated window aggregation to aggrega…
Browse files Browse the repository at this point in the history
…te functions, and implement `row_number` (#375)

* Squashed commit of the following:

commit 7fb3640
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 16:38:25 2021 +0800

    row number done

commit 1723926
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 16:05:50 2021 +0800

    add row number

commit bf5b8a5
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 15:04:49 2021 +0800

    save

commit d2ce852
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 14:53:05 2021 +0800

    add streams

commit 0a861a7
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 22:28:34 2021 +0800

    save stream

commit a9121af
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 22:01:51 2021 +0800

    update unit test

commit 2af2a27
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 14:25:12 2021 +0800

    fix unit test

commit bb57c76
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 14:23:34 2021 +0800

    use upper case

commit 5d96e52
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 14:16:16 2021 +0800

    fix unit test

commit 1ecae8f
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 12:27:26 2021 +0800

    fix unit test

commit bc2271d
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 10:04:29 2021 +0800

    fix error

commit 880b94f
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 08:24:00 2021 +0800

    fix unit test

commit 4e792e1
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 08:05:17 2021 +0800

    fix test

commit c36c04a
Author: Jiayu Liu <[email protected]>
Date:   Fri May 21 00:07:54 2021 +0800

    add more tests

commit f5e64de
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 23:41:36 2021 +0800

    update

commit a1eae86
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 23:36:15 2021 +0800

    enrich unit test

commit 0d2a214
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 23:25:43 2021 +0800

    adding filter by todo

commit 8b486d5
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 23:17:22 2021 +0800

    adding more built-in functions

commit abf08cd
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 22:36:27 2021 +0800

    Update datafusion/src/physical_plan/window_functions.rs

    Co-authored-by: Andrew Lamb <[email protected]>

commit 0cbca53
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 22:34:57 2021 +0800

    Update datafusion/src/physical_plan/window_functions.rs

    Co-authored-by: Andrew Lamb <[email protected]>

commit 831c069
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 22:34:04 2021 +0800

    Update datafusion/src/logical_plan/builder.rs

    Co-authored-by: Andrew Lamb <[email protected]>

commit f70c739
Author: Jiayu Liu <[email protected]>
Date:   Thu May 20 22:33:04 2021 +0800

    Update datafusion/src/logical_plan/builder.rs

    Co-authored-by: Andrew Lamb <[email protected]>

commit 3ee87aa
Author: Jiayu Liu <[email protected]>
Date:   Wed May 19 22:55:08 2021 +0800

    fix unit test

commit 5c4d92d
Author: Jiayu Liu <[email protected]>
Date:   Wed May 19 22:48:26 2021 +0800

    fix clippy

commit a0b7526
Author: Jiayu Liu <[email protected]>
Date:   Wed May 19 22:46:38 2021 +0800

    fix unused imports

commit 1d3b076
Author: Jiayu Liu <[email protected]>
Date:   Thu May 13 18:51:14 2021 +0800

    add window expr

* fix unit test
  • Loading branch information
jimexist authored May 26, 2021
1 parent 3593d1f commit 4b1e9e6
Show file tree
Hide file tree
Showing 11 changed files with 736 additions and 75 deletions.
29 changes: 29 additions & 0 deletions datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1268,6 +1268,35 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn window() -> Result<()> {
    // Every aggregate below uses an empty `OVER ()` clause; the expected table
    // therefore repeats the same SUM/COUNT/MAX/MIN/AVG of `c2` on each row.
    let results = execute(
        "SELECT c1, c2, SUM(c2) OVER (), COUNT(c2) OVER (), MAX(c2) OVER (), MIN(c2) OVER (), AVG(c2) OVER () FROM test ORDER BY c1, c2 LIMIT 5",
        4,
    )
    .await?;
    // The result is expected in a single batch. Having e.g. 2 batches would not
    // change the result semantics, but asserting len == 1 upfront keeps
    // surprises at bay.
    assert_eq!(results.len(), 1);

    // Data cells are padded to the column widths fixed by the border rows, so
    // every line of the table has the same length.
    let expected = vec![
        "+----+----+---------+-----------+---------+---------+---------+",
        "| c1 | c2 | SUM(c2) | COUNT(c2) | MAX(c2) | MIN(c2) | AVG(c2) |",
        "+----+----+---------+-----------+---------+---------+---------+",
        "| 0  | 1  | 220     | 40        | 10      | 1       | 5.5     |",
        "| 0  | 2  | 220     | 40        | 10      | 1       | 5.5     |",
        "| 0  | 3  | 220     | 40        | 10      | 1       | 5.5     |",
        "| 0  | 4  | 220     | 40        | 10      | 1       | 5.5     |",
        "| 0  | 5  | 220     | 40        | 10      | 1       | 5.5     |",
        "+----+----+---------+-----------+---------+---------+---------+",
    ];

    // the rows must come back in the order requested by `ORDER BY c1, c2`
    assert_batches_eq!(expected, &results);
    Ok(())
}

#[tokio::test]
async fn aggregate() -> Result<()> {
let results = execute("SELECT SUM(c1), SUM(c2) FROM test", 4).await?;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ mod min_max;
mod negative;
mod not;
mod nullif;
mod row_number;
mod sum;
mod try_cast;

Expand All @@ -58,6 +59,7 @@ pub use min_max::{Max, Min};
pub use negative::{negative, NegativeExpr};
pub use not::{not, NotExpr};
pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
pub use row_number::RowNumber;
pub use sum::{sum_return_type, Sum};
pub use try_cast::{try_cast, TryCastExpr};
/// returns the name of the state
Expand Down
174 changes: 174 additions & 0 deletions datafusion/src/physical_plan/expressions/row_number.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines physical expression for `row_number` that can be evaluated at runtime during query execution
use crate::error::Result;
use crate::physical_plan::{
window_functions::BuiltInWindowFunctionExpr, PhysicalExpr, WindowAccumulator,
};
use crate::scalar::ScalarValue;
use arrow::array::{ArrayRef, UInt64Array};
use arrow::datatypes::{DataType, Field};
use std::any::Any;
use std::sync::Arc;

/// Physical expression backing the `row_number` window function
#[derive(Debug)]
pub struct RowNumber {
    /// name given to this expression at construction time
    name: String,
}

impl RowNumber {
    /// Builds a `ROW_NUMBER` expression labelled with `name`
    pub fn new(name: String) -> Self {
        RowNumber { name }
    }
}

impl BuiltInWindowFunctionExpr for RowNumber {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// The output field: a non-nullable `UInt64` column named after this expression
    fn field(&self) -> Result<Field> {
        let nullable = false;
        let data_type = DataType::UInt64;
        // `name()` already yields a `&str`; borrowing it again would only create a `&&str`
        Ok(Field::new(self.name(), data_type, nullable))
    }

    /// `row_number` declares no input expressions, so this is always empty
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        vec![]
    }

    /// The name this expression was created with
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Creates a fresh accumulator whose numbering starts from the first row
    fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>> {
        Ok(Box::new(RowNumberAccumulator::new()))
    }
}

/// Accumulator that hands out consecutive row numbers
#[derive(Debug)]
struct RowNumberAccumulator {
    /// the number that will be assigned to the next scanned row
    row_number: u64,
}

impl RowNumberAccumulator {
    /// Creates an accumulator positioned at the first row
    pub fn new() -> Self {
        // numbering starts at 1, not 0
        RowNumberAccumulator { row_number: 1 }
    }
}

impl WindowAccumulator for RowNumberAccumulator {
    /// Emits the number of the current row and advances the counter by one.
    /// The input values are ignored.
    fn scan(&mut self, _values: &[ScalarValue]) -> Result<Option<ScalarValue>> {
        let current = self.row_number;
        self.row_number = current + 1;
        Ok(Some(ScalarValue::UInt64(Some(current))))
    }

    /// Emits `num_rows` consecutive row numbers as a `UInt64Array`, continuing
    /// from where the previous scan stopped. The input arrays are ignored.
    fn scan_batch(
        &mut self,
        num_rows: usize,
        _values: &[ArrayRef],
    ) -> Result<Option<ArrayRef>> {
        let first = self.row_number;
        let next = first + (num_rows as u64);
        self.row_number = next;
        // TODO: probably would be nice to have a (optimized) kernel for this at some point to
        // generate an array like this.
        let numbers = UInt64Array::from_iter_values(first..next);
        Ok(Some(Arc::new(numbers)))
    }

    /// There is no aggregate value for `row_number`; always `None`.
    fn evaluate(&self) -> Result<Option<ScalarValue>> {
        Ok(None)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Result;
    use arrow::record_batch::RecordBatch;
    use arrow::{array::*, datatypes::*};

    /// Wraps `arr` (which must hold exactly 8 rows) in a single-column batch,
    /// scans it with a fresh `row_number` accumulator and checks that the rows
    /// are numbered 1 through 8 and that `evaluate` yields no value.
    ///
    /// `nullable` must describe `arr` truthfully: a column holding nulls has to
    /// be declared nullable in the schema.
    fn assert_numbers_eight_rows(arr: ArrayRef, nullable: bool) -> Result<()> {
        let schema =
            Schema::new(vec![Field::new("arr", DataType::Boolean, nullable)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![arr])?;

        let row_number = Arc::new(RowNumber::new("row_number".to_owned()));

        let mut acc = row_number.create_accumulator()?;
        let expr = row_number.expressions();
        let values = expr
            .iter()
            .map(|e| e.evaluate(&batch))
            .map(|r| r.map(|v| v.into_array(batch.num_rows())))
            .collect::<Result<Vec<_>>>()?;

        let result = acc.scan_batch(batch.num_rows(), &values)?;
        assert!(result.is_some());

        let result = result.unwrap();
        let result = result.as_any().downcast_ref::<UInt64Array>().unwrap();
        let result = result.values();
        assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result);

        let result = acc.evaluate()?;
        assert!(result.is_none());
        Ok(())
    }

    #[test]
    fn row_number_all_null() -> Result<()> {
        let arr: ArrayRef = Arc::new(BooleanArray::from(vec![
            None, None, None, None, None, None, None, None,
        ]));
        // the column is made of nulls only, so the field must be nullable
        assert_numbers_eight_rows(arr, true)
    }

    #[test]
    fn row_number_all_values() -> Result<()> {
        let arr: ArrayRef = Arc::new(BooleanArray::from(vec![
            true, false, true, false, false, true, false, true,
        ]));
        assert_numbers_eight_rows(arr, false)
    }
}
7 changes: 4 additions & 3 deletions datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ impl GroupedHashAggregateStream {
tx.send(result)
});

GroupedHashAggregateStream {
Self {
schema,
output: rx,
finished: false,
Expand Down Expand Up @@ -825,7 +825,8 @@ fn aggregate_expressions(
}

pin_project! {
struct HashAggregateStream {
/// stream struct for hash aggregation
pub struct HashAggregateStream {
schema: SchemaRef,
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
Expand Down Expand Up @@ -878,7 +879,7 @@ impl HashAggregateStream {
tx.send(result)
});

HashAggregateStream {
Self {
schema,
output: rx,
finished: false,
Expand Down
81 changes: 73 additions & 8 deletions datafusion/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,23 @@

//! Traits for physical query plan, supporting parallel execution for partitioned relations.
use std::fmt::{self, Debug, Display};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::{any::Any, pin::Pin};

use crate::execution::context::ExecutionContextState;
use crate::logical_plan::LogicalPlan;
use crate::{error::Result, scalar::ScalarValue};
use crate::{
error::{DataFusionError, Result},
scalar::ScalarValue,
};
use arrow::datatypes::{DataType, Schema, SchemaRef};
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use arrow::{array::ArrayRef, datatypes::Field};

use async_trait::async_trait;
pub use display::DisplayFormatType;
use futures::stream::Stream;
use std::fmt::{self, Debug, Display};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::{any::Any, pin::Pin};

use self::{display::DisplayableExecutionPlan, merge::MergeExec};
use hashbrown::HashMap;
Expand Down Expand Up @@ -457,10 +458,22 @@ pub trait WindowExpr: Send + Sync + Debug {
fn name(&self) -> &str {
"WindowExpr: default name"
}

/// the accumulator used to accumulate values from the expressions.
/// the accumulator expects the same number of arguments as `expressions` and must
/// return states with the same description as `state_fields`
fn create_accumulator(&self) -> Result<Box<dyn WindowAccumulator>>;

/// expressions that are passed to the WindowAccumulator.
/// Functions which take a single input argument, such as `sum`, return a single [`Expr`],
/// others (e.g. `cov`) return many.
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>>;
}

/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
/// generically accumulates values. An accumulator knows how to:
/// generically accumulates values.
///
/// An accumulator knows how to:
/// * update its state from inputs via `update`
/// * convert its internal state to a vector of scalar values
/// * update its state from multiple accumulators' states via `merge`
Expand Down Expand Up @@ -509,6 +522,58 @@ pub trait Accumulator: Send + Sync + Debug {
fn evaluate(&self) -> Result<ScalarValue>;
}

/// A window accumulator represents a stateful object that lives throughout the evaluation of
/// multiple rows and, unlike an aggregate accumulator, may emit a value for every row it scans.
///
/// A window accumulator knows how to:
/// * update its state from a single row of inputs via `scan`, optionally emitting a value
/// * update its state from a batch of rows via `scan_batch`, optionally emitting an array
/// * optionally report a value derived from its current state via `evaluate`
pub trait WindowAccumulator: Send + Sync + Debug {
    /// scans the accumulator's state from a vector of scalars (one row of inputs); similar to
    /// Accumulator, but it also optionally generates a value for the scanned row.
    fn scan(&mut self, values: &[ScalarValue]) -> Result<Option<ScalarValue>>;

    /// scans the accumulator's state from a vector of arrays, one array per input expression.
    ///
    /// The default implementation calls `scan` once per row, for rows `0..num_rows`, and
    /// assembles the per-row outputs into a single array. It returns:
    /// * `Ok(None)` without scanning anything when `values` is empty
    /// * `Ok(None)` when at least one row produced no value (all rows are still scanned)
    /// * the first error raised while converting or scanning a row
    ///
    /// Implementations that take no input expressions must override this method, since the
    /// default never scans when `values` is empty.
    fn scan_batch(
        &mut self,
        num_rows: usize,
        values: &[ArrayRef],
    ) -> Result<Option<ArrayRef>> {
        if values.is_empty() {
            return Ok(None);
        }
        // transpose columnar to row based so that we can apply window
        let per_row = (0..num_rows)
            .map(|index| {
                let row = values
                    .iter()
                    .map(|array| ScalarValue::try_from_array(array, index))
                    .collect::<Result<Vec<_>>>()?;
                self.scan(&row)
            })
            .collect::<Result<Vec<Option<ScalarValue>>>>()?;

        // a single row without a value means there is no output for the whole batch
        let scalars = match per_row.into_iter().collect::<Option<Vec<ScalarValue>>>() {
            Some(scalars) => scalars,
            None => return Ok(None),
        };

        // defensive check: exactly one output value is expected per input row
        if scalars.len() != num_rows {
            return Err(DataFusionError::Internal(format!(
                "expect scan batch to return {:?} rows, but got {:?}",
                num_rows,
                scalars.len()
            )));
        }

        Ok(Some(ScalarValue::iter_to_array(&scalars)?))
    }

    /// returns its value based on its current state.
    fn evaluate(&self) -> Result<Option<ScalarValue>>;
}

pub mod aggregates;
pub mod array_expressions;
pub mod coalesce_batches;
Expand Down
4 changes: 3 additions & 1 deletion datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,10 @@ impl DefaultPhysicalPlanner {
// Initially need to perform the aggregate and then merge the partitions
let input_exec = self.create_initial_plan(input, ctx_state)?;
let input_schema = input_exec.schema();
let physical_input_schema = input_exec.as_ref().schema();

let logical_input_schema = input.as_ref().schema();
let physical_input_schema = input_exec.as_ref().schema();

let window_expr = window_expr
.iter()
.map(|e| {
Expand Down
1 change: 1 addition & 0 deletions datafusion/src/physical_plan/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ fn sort_batches(
}

pin_project! {
/// stream for sort plan
struct SortStream {
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
Expand Down
Loading

0 comments on commit 4b1e9e6

Please sign in to comment.