Skip to content

Commit

Permalink
partition by refactor (apache#28)
Browse files Browse the repository at this point in the history
* partition by refactor

* minor changes

* Unnecessary tuple to Range conversion is removed

* move transpose under common
  • Loading branch information
mustafasrepo authored Dec 15, 2022
1 parent 0a42315 commit c2a1593
Show file tree
Hide file tree
Showing 13 changed files with 262 additions and 230 deletions.
12 changes: 12 additions & 0 deletions datafusion/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,15 @@ pub fn reverse_sort_options(options: SortOptions) -> SortOptions {
nulls_first: !options.nulls_first,
}
}

/// Transposes 2d vector
pub fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> {
assert!(!original.is_empty());
let mut transposed = (0..original[0].len()).map(|_| vec![]).collect::<Vec<_>>();
for original_row in original {
for (item, transposed_row) in original_row.into_iter().zip(&mut transposed) {
transposed_row.push(item);
}
}
transposed
}
25 changes: 10 additions & 15 deletions datafusion/core/src/physical_optimizer/remove_unnecessary_sorts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,16 +364,15 @@ pub fn can_skip_sort(
.iter()
.filter(|elem| elem.is_partition)
.collect::<Vec<_>>();
let (can_skip_partition_bys, should_reverse_partition_bys) =
if partition_by_sections.is_empty() {
(true, false)
} else {
let first_reverse = partition_by_sections[0].reverse;
let can_skip_partition_bys = partition_by_sections
.iter()
.all(|c| c.is_aligned && c.reverse == first_reverse);
(can_skip_partition_bys, first_reverse)
};
let can_skip_partition_bys = if partition_by_sections.is_empty() {
true
} else {
let first_reverse = partition_by_sections[0].reverse;
let can_skip_partition_bys = partition_by_sections
.iter()
.all(|c| c.is_aligned && c.reverse == first_reverse);
can_skip_partition_bys
};
let order_by_sections = col_infos
.iter()
.filter(|elem| !elem.is_partition)
Expand All @@ -387,11 +386,7 @@ pub fn can_skip_sort(
.all(|c| c.is_aligned && c.reverse == first_reverse);
(can_skip_order_bys, first_reverse)
};
// TODO: We cannot skip partition by keys when sort direction is reversed,
// by propogating partition by sort direction to `WindowAggExec` we can skip
// these columns also. Add support for that (Use direction during partition range calculation).
let can_skip =
can_skip_order_bys && can_skip_partition_bys && !should_reverse_partition_bys;
let can_skip = can_skip_order_bys && can_skip_partition_bys;
Ok((can_skip, should_reverse_order_bys))
}

Expand Down
79 changes: 76 additions & 3 deletions datafusion/core/src/physical_plan/windows/window_agg_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,23 @@ use crate::physical_plan::{
ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
SendableRecordBatchStream, Statistics, WindowExpr,
};
use arrow::compute::concat_batches;
use arrow::compute::{
concat, concat_batches, lexicographical_partition_ranges, SortColumn,
};
use arrow::{
array::ArrayRef,
datatypes::{Schema, SchemaRef},
error::{ArrowError, Result as ArrowResult},
record_batch::RecordBatch,
};
use datafusion_common::{transpose, DataFusionError};
use datafusion_physical_expr::rewrite::TreeNodeRewritable;
use datafusion_physical_expr::EquivalentClass;
use futures::stream::Stream;
use futures::{ready, StreamExt};
use log::debug;
use std::any::Any;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -131,6 +135,25 @@ impl WindowAggExec {
pub fn input_schema(&self) -> SchemaRef {
self.input_schema.clone()
}

/// Get Partition Columns
pub fn partition_by_sort_keys(&self) -> Result<Vec<PhysicalSortExpr>> {
// All window exprs have same partition by hance we just use first one
let partition_by = self.window_expr()[0].partition_by();
let mut partition_columns = vec![];
for elem in partition_by {
if let Some(sort_keys) = &self.sort_keys {
for a in sort_keys {
if a.expr.eq(elem) {
partition_columns.push(a.clone());
break;
}
}
}
}
assert_eq!(partition_by.len(), partition_columns.len());
Ok(partition_columns)
}
}

impl ExecutionPlan for WindowAggExec {
Expand Down Expand Up @@ -253,6 +276,7 @@ impl ExecutionPlan for WindowAggExec {
self.window_expr.clone(),
input,
BaselineMetrics::new(&self.metrics, partition),
self.partition_by_sort_keys()?,
));
Ok(stream)
}
Expand Down Expand Up @@ -337,6 +361,7 @@ pub struct WindowAggStream {
batches: Vec<RecordBatch>,
finished: bool,
window_expr: Vec<Arc<dyn WindowExpr>>,
partition_by_sort_keys: Vec<PhysicalSortExpr>,
baseline_metrics: BaselineMetrics,
}

Expand All @@ -347,6 +372,7 @@ impl WindowAggStream {
window_expr: Vec<Arc<dyn WindowExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
partition_by_sort_keys: Vec<PhysicalSortExpr>,
) -> Self {
Self {
schema,
Expand All @@ -355,6 +381,7 @@ impl WindowAggStream {
finished: false,
window_expr,
baseline_metrics,
partition_by_sort_keys,
}
}

Expand All @@ -369,15 +396,61 @@ impl WindowAggStream {
let batch = concat_batches(&self.input.schema(), &self.batches)?;

// calculate window cols
let mut columns = compute_window_aggregates(&self.window_expr, &batch)
.map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
let partition_columns = self.partition_columns(&batch)?;
let partition_points =
self.evaluate_partition_points(batch.num_rows(), &partition_columns)?;

let mut partition_results = vec![];
for partition_point in partition_points {
let length = partition_point.end - partition_point.start;
partition_results.push(
compute_window_aggregates(
&self.window_expr,
&batch.slice(partition_point.start, length),
)
.map_err(|e| ArrowError::ExternalError(Box::new(e)))?,
)
}
let mut columns = transpose(partition_results)
.iter()
.map(|elems| concat(&elems.iter().map(|x| x.as_ref()).collect::<Vec<_>>()))
.collect::<Vec<_>>()
.into_iter()
.collect::<ArrowResult<Vec<ArrayRef>>>()?;

// combine with the original cols
// note the setup of window aggregates is that they newly calculated window
// expressions are always prepended to the columns
columns.extend_from_slice(batch.columns());
RecordBatch::try_new(self.schema.clone(), columns)
}

/// Get Partition Columns
pub fn partition_columns(&self, batch: &RecordBatch) -> Result<Vec<SortColumn>> {
self.partition_by_sort_keys
.iter()
.map(|elem| elem.evaluate_to_sort_column(batch))
.collect::<Result<Vec<_>>>()
}

/// evaluate the partition points given the sort columns; if the sort columns are
/// empty then the result will be a single element vec of the whole column rows.
fn evaluate_partition_points(
&self,
num_rows: usize,
partition_columns: &[SortColumn],
) -> Result<Vec<Range<usize>>> {
if partition_columns.is_empty() {
Ok(vec![Range {
start: 0,
end: num_rows,
}])
} else {
Ok(lexicographical_partition_ranges(partition_columns)
.map_err(DataFusionError::ArrowError)?
.collect::<Vec<_>>())
}
}
}

impl Stream for WindowAggStream {
Expand Down
54 changes: 54 additions & 0 deletions datafusion/core/tests/sql/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2177,3 +2177,57 @@ async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> {

Ok(())
}

#[tokio::test]
async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Result<()> {
let config = SessionConfig::new().with_repartition_windows(false);
let ctx = SessionContext::with_config(config);
register_aggregate_csv(&ctx).await?;
let sql = "SELECT c3,
SUM(c9) OVER(ORDER BY c3 DESC, c9 DESC, c2 ASC) as sum1,
SUM(c9) OVER(PARTITION BY c3 ORDER BY c9 DESC ) as sum2
FROM aggregate_test_100
LIMIT 5";

let msg = format!("Creating logical plan for '{}'", sql);
let plan = ctx.create_logical_plan(sql).expect(&msg);
let state = ctx.state();
let logical_plan = state.optimize(&plan)?;
let physical_plan = state.create_physical_plan(&logical_plan).await?;
let formatted = displayable(physical_plan.as_ref()).indent().to_string();
// Only 1 SortExec was added
let expected = {
vec![
"ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]",
" GlobalLimitExec: skip=0, fetch=5",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]",
" WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }]",
" SortExec: [c3@1 DESC,c9@2 DESC,c2@0 ASC NULLS LAST]",
]
};

let actual: Vec<&str> = formatted.trim().lines().collect();
let actual_len = actual.len();
let actual_trim_last = &actual[..actual_len - 1];
assert_eq!(
expected, actual_trim_last,
"\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
expected, actual
);

let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-----+-------------+------------+",
"| c3 | sum1 | sum2 |",
"+-----+-------------+------------+",
"| 125 | 3625286410 | 3625286410 |",
"| 123 | 7192027599 | 3566741189 |",
"| 123 | 9784358155 | 6159071745 |",
"| 122 | 13845993262 | 4061635107 |",
"| 120 | 16676974334 | 2830981072 |",
"+-----+-------------+------------+",
];
assert_batches_eq!(expected, &actual);

Ok(())
}
89 changes: 41 additions & 48 deletions datafusion/physical-expr/src/window/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
use std::any::Any;
use std::iter::IntoIterator;
use std::ops::Range;
use std::sync::Arc;

use arrow::array::Array;
Expand Down Expand Up @@ -90,58 +91,50 @@ impl WindowExpr for AggregateWindowExpr {
}

fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
let partition_columns = self.partition_columns(batch)?;
let partition_points =
self.evaluate_partition_points(batch.num_rows(), &partition_columns)?;
let sort_options: Vec<SortOptions> =
self.order_by.iter().map(|o| o.options).collect();
let mut row_wise_results: Vec<ScalarValue> = vec![];
for partition_range in &partition_points {
let mut accumulator = self.aggregate.create_accumulator()?;
let length = partition_range.end - partition_range.start;
let (values, order_bys) =
self.get_values_orderbys(&batch.slice(partition_range.start, length))?;

let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
let mut last_range: (usize, usize) = (0, 0);

// We iterate on each row to perform a running calculation.
// First, cur_range is calculated, then it is compared with last_range.
for i in 0..length {
let cur_range = window_frame_ctx.calculate_range(
&order_bys,
&sort_options,
length,
i,
)?;
let value = if cur_range.0 == cur_range.1 {
// We produce None if the window is empty.
ScalarValue::try_from(self.aggregate.field()?.data_type())?
} else {
// Accumulate any new rows that have entered the window:
let update_bound = cur_range.1 - last_range.1;
if update_bound > 0 {
let update: Vec<ArrayRef> = values
.iter()
.map(|v| v.slice(last_range.1, update_bound))
.collect();
accumulator.update_batch(&update)?
}
// Remove rows that have now left the window:
let retract_bound = cur_range.0 - last_range.0;
if retract_bound > 0 {
let retract: Vec<ArrayRef> = values
.iter()
.map(|v| v.slice(last_range.0, retract_bound))
.collect();
accumulator.retract_batch(&retract)?
}
accumulator.evaluate()?
};
row_wise_results.push(value);
last_range = cur_range;
}

let mut accumulator = self.aggregate.create_accumulator()?;
let length = batch.num_rows();
let (values, order_bys) = self.get_values_orderbys(batch)?;

let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame);
let mut last_range = Range { start: 0, end: 0 };

// We iterate on each row to perform a running calculation.
// First, cur_range is calculated, then it is compared with last_range.
for i in 0..length {
let cur_range =
window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?;
let value = if cur_range.end == cur_range.start {
// We produce None if the window is empty.
ScalarValue::try_from(self.aggregate.field()?.data_type())?
} else {
// Accumulate any new rows that have entered the window:
let update_bound = cur_range.end - last_range.end;
if update_bound > 0 {
let update: Vec<ArrayRef> = values
.iter()
.map(|v| v.slice(last_range.end, update_bound))
.collect();
accumulator.update_batch(&update)?
}
// Remove rows that have now left the window:
let retract_bound = cur_range.start - last_range.start;
if retract_bound > 0 {
let retract: Vec<ArrayRef> = values
.iter()
.map(|v| v.slice(last_range.start, retract_bound))
.collect();
accumulator.retract_batch(&retract)?
}
accumulator.evaluate()?
};
row_wise_results.push(value);
last_range = cur_range;
}

ScalarValue::iter_to_array(row_wise_results.into_iter())
}

Expand Down
Loading

0 comments on commit c2a1593

Please sign in to comment.