From 8ec511e51cdfdef8e3f79116076fbb962f53f887 Mon Sep 17 00:00:00 2001 From: Mustafa akur <106137913+mustafasrepo@users.noreply.github.com> Date: Tue, 27 Dec 2022 00:17:32 +0300 Subject: [PATCH] Unnecessary SortExec removal rule from Physical Plan (#4691) * Sort Removal rule initial commit * move ordering satisfy to the util * update test and change repartition maintain_input_order impl * simplifications * partition by refactor (#28) * partition by refactor * minor changes * Unnecessary tuple to Range conversion is removed * move transpose under common * Add naive sort removal rule * Add todo for finer Sort removal handling * Refactors to improve readability and reduce nesting * reverse expr returns Option (no need for support check) * fix tests * partition by and order by no longer ends up at the same window group * Refactor to simplify code * Better comments, change method names * Resolve errors introduced by syncing * address reviews * address reviews * Rename to less confusing OptimizeSorts Co-authored-by: Mehmet Ozan Kabak --- datafusion/common/src/lib.rs | 10 + datafusion/core/src/execution/context.rs | 7 + .../src/physical_optimizer/enforcement.rs | 76 +- datafusion/core/src/physical_optimizer/mod.rs | 1 + .../src/physical_optimizer/optimize_sorts.rs | 887 ++++++++++++++++++ .../core/src/physical_optimizer/utils.rs | 75 ++ datafusion/core/src/physical_plan/common.rs | 25 + datafusion/core/src/physical_plan/planner.rs | 2 +- .../core/src/physical_plan/repartition.rs | 11 +- .../physical_plan/windows/window_agg_exec.rs | 81 +- datafusion/core/tests/sql/explain_analyze.rs | 5 - datafusion/core/tests/sql/window.rs | 582 +++++++++++- datafusion/expr/src/logical_plan/builder.rs | 2 +- datafusion/expr/src/utils.rs | 62 +- datafusion/expr/src/window_frame.rs | 29 + .../physical-expr/src/aggregate/count.rs | 6 +- datafusion/physical-expr/src/aggregate/mod.rs | 8 + datafusion/physical-expr/src/aggregate/sum.rs | 6 +- .../physical-expr/src/window/aggregate.rs | 94 +- .../physical-expr/src/window/built_in.rs | 67 +- .../window/built_in_window_function_expr.rs | 6 + .../physical-expr/src/window/cume_dist.rs | 25 +- .../physical-expr/src/window/lead_lag.rs | 27 +- .../physical-expr/src/window/nth_value.rs | 14 + datafusion/physical-expr/src/window/ntile.rs | 9 +- .../src/window/partition_evaluator.rs | 55 +- datafusion/physical-expr/src/window/rank.rs | 28 +- .../physical-expr/src/window/row_number.rs | 18 +- .../src/window/sliding_aggregate.rs | 111 ++- .../physical-expr/src/window/window_expr.rs | 51 +- .../src/window/window_frame_state.rs | 15 +- 31 files changed, 2017 insertions(+), 378 deletions(-) create mode 100644 datafusion/core/src/physical_optimizer/optimize_sorts.rs diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 60d69324913b..392fa3f25a67 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -30,6 +30,7 @@ pub mod stats; mod table_reference; pub mod test_util; +use arrow::compute::SortOptions; pub use column::Column; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema}; pub use error::{field_not_found, DataFusionError, Result, SchemaError}; @@ -63,3 +64,12 @@ macro_rules! downcast_value { })? }}; } + +/// Computes the "reverse" of given `SortOptions`. +// TODO: If/when arrow supports `!` for `SortOptions`, we can remove this. +pub fn reverse_sort_options(options: SortOptions) -> SortOptions { + SortOptions { + descending: !options.descending, + nulls_first: !options.nulls_first, + } +} diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 098dafdc0280..978bde2a2ed8 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -100,6 +100,7 @@ use url::Url; use crate::catalog::listing_schema::ListingSchemaProvider; use crate::datasource::object_store::ObjectStoreUrl; use crate::execution::memory_pool::MemoryPool; +use crate::physical_optimizer::optimize_sorts::OptimizeSorts; use uuid::Uuid; use super::options::{ @@ -1580,6 +1581,12 @@ impl SessionState { // To make sure the SinglePartition is satisfied, run the BasicEnforcement again, originally it was the AddCoalescePartitionsExec here. physical_optimizers.push(Arc::new(BasicEnforcement::new())); + // `BasicEnforcement` stage conservatively inserts `SortExec`s to satisfy ordering requirements. + // However, a deeper analysis may sometimes reveal that such a `SortExec` is actually unnecessary. + // These cases typically arise when we have reversible `WindowAggExec`s or deep subqueries. The + // rule below performs this analysis and removes unnecessary `SortExec`s. + physical_optimizers.push(Arc::new(OptimizeSorts::new())); + let mut this = SessionState { session_id, optimizer: Optimizer::new(), diff --git a/datafusion/core/src/physical_optimizer/enforcement.rs b/datafusion/core/src/physical_optimizer/enforcement.rs index 06832ac2498a..4a496a3ef9f1 100644 --- a/datafusion/core/src/physical_optimizer/enforcement.rs +++ b/datafusion/core/src/physical_optimizer/enforcement.rs @@ -20,6 +20,7 @@ //! use crate::config::OPT_TOP_DOWN_JOIN_KEY_REORDERING; use crate::error::Result; +use crate::physical_optimizer::utils::{add_sort_above_child, ordering_satisfy}; use crate::physical_optimizer::PhysicalOptimizerRule; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy}; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; @@ -29,8 +30,7 @@ use crate::physical_plan::joins::{ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::rewrite::TreeNodeRewritable; -use crate::physical_plan::sorts::sort::SortExec; -use crate::physical_plan::sorts::sort::SortOptions; +use crate::physical_plan::sorts::sort::{SortExec, SortOptions}; use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::Partitioning; @@ -42,9 +42,8 @@ use datafusion_physical_expr::equivalence::EquivalenceProperties; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::expressions::NoOp; use datafusion_physical_expr::{ - expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, - normalize_sort_expr_with_equivalence_properties, AggregateExpr, PhysicalExpr, - PhysicalSortExpr, + expr_list_eq_strict_order, normalize_expr_with_equivalence_properties, AggregateExpr, + PhysicalExpr, }; use std::collections::HashMap; use std::sync::Arc; @@ -919,9 +918,7 @@ fn ensure_distribution_and_ordering( Ok(child) } else { let sort_expr = required.unwrap().to_vec(); - Ok(Arc::new(SortExec::new_with_partitioning( - sort_expr, child, true, None, - )) as Arc) + add_sort_above_child(&child, sort_expr) } }) .collect(); @@ -929,61 +926,6 @@ fn ensure_distribution_and_ordering( with_new_children_if_necessary(plan, new_children?) } -/// Check the required ordering requirements are satisfied by the provided PhysicalSortExprs. -fn ordering_satisfy EquivalenceProperties>( - provided: Option<&[PhysicalSortExpr]>, - required: Option<&[PhysicalSortExpr]>, - equal_properties: F, -) -> bool { - match (provided, required) { - (_, None) => true, - (None, Some(_)) => false, - (Some(provided), Some(required)) => { - if required.len() > provided.len() { - false - } else { - let fast_match = required - .iter() - .zip(provided.iter()) - .all(|(order1, order2)| order1.eq(order2)); - - if !fast_match { - let eq_properties = equal_properties(); - let eq_classes = eq_properties.classes(); - if !eq_classes.is_empty() { - let normalized_required_exprs = required - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - let normalized_provided_exprs = provided - .iter() - .map(|e| { - normalize_sort_expr_with_equivalence_properties( - e.clone(), - eq_classes, - ) - }) - .collect::>(); - normalized_required_exprs - .iter() - .zip(normalized_provided_exprs.iter()) - .all(|(order1, order2)| order1.eq(order2)) - } else { - fast_match - } - } else { - fast_match - } - } - } - } -} - #[derive(Debug, Clone)] struct JoinKeyPairs { left_keys: Vec>, @@ -1063,10 +1005,10 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_expr::logical_plan::JoinType; use datafusion_expr::Operator; - use datafusion_physical_expr::expressions::binary; - use datafusion_physical_expr::expressions::lit; - use datafusion_physical_expr::expressions::Column; - use datafusion_physical_expr::{expressions, PhysicalExpr}; + use datafusion_physical_expr::{ + expressions, expressions::binary, expressions::lit, expressions::Column, + PhysicalExpr, PhysicalSortExpr, + }; use std::ops::Deref; use super::*; diff --git a/datafusion/core/src/physical_optimizer/mod.rs b/datafusion/core/src/physical_optimizer/mod.rs index 36b00a0e01bc..0fd0600fbe67 100644 --- a/datafusion/core/src/physical_optimizer/mod.rs +++ b/datafusion/core/src/physical_optimizer/mod.rs @@ -22,6 +22,7 @@ pub mod aggregate_statistics; pub mod coalesce_batches; pub mod enforcement; pub mod join_selection; +pub mod optimize_sorts; pub mod optimizer; pub mod pruning; pub mod repartition; diff --git a/datafusion/core/src/physical_optimizer/optimize_sorts.rs b/datafusion/core/src/physical_optimizer/optimize_sorts.rs new file mode 100644 index 000000000000..cb421b7b82fd --- /dev/null +++ b/datafusion/core/src/physical_optimizer/optimize_sorts.rs @@ -0,0 +1,887 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! OptimizeSorts optimizer rule inspects [SortExec]s in the given physical +//! plan and removes the ones it can prove unnecessary. The rule can work on +//! valid *and* invalid physical plans with respect to sorting requirements, +//! but always produces a valid physical plan in this sense. +//! +//! A non-realistic but easy to follow example: Assume that we somehow get the fragment +//! "SortExec: [nullable_col@0 ASC]", +//! " SortExec: [non_nullable_col@1 ASC]", +//! in the physical plan. The first sort is unnecessary since its result is overwritten +//! by another SortExec. Therefore, this rule removes it from the physical plan. +use crate::error::Result; +use crate::physical_optimizer::utils::{ + add_sort_above_child, ordering_satisfy, ordering_satisfy_concrete, +}; +use crate::physical_optimizer::PhysicalOptimizerRule; +use crate::physical_plan::rewrite::TreeNodeRewritable; +use crate::physical_plan::sorts::sort::SortExec; +use crate::physical_plan::windows::WindowAggExec; +use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use crate::prelude::SessionConfig; +use arrow::datatypes::SchemaRef; +use datafusion_common::{reverse_sort_options, DataFusionError}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use itertools::izip; +use std::iter::zip; +use std::sync::Arc; + +/// This rule inspects SortExec's in the given physical plan and removes the +/// ones it can prove unnecessary. +#[derive(Default)] +pub struct OptimizeSorts {} + +impl OptimizeSorts { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +/// This is a "data class" we use within the [OptimizeSorts] rule that +/// tracks the closest `SortExec` descendant for every child of a plan. +#[derive(Debug, Clone)] +struct PlanWithCorrespondingSort { + plan: Arc, + // For every child, keep a vector of `ExecutionPlan`s starting from the + // closest `SortExec` till the current plan. The first index of the tuple is + // the child index of the plan -- we need this information as we make updates. + sort_onwards: Vec)>>, +} + +impl PlanWithCorrespondingSort { + pub fn new(plan: Arc) -> Self { + let length = plan.children().len(); + PlanWithCorrespondingSort { + plan, + sort_onwards: vec![vec![]; length], + } + } + + pub fn children(&self) -> Vec { + self.plan + .children() + .into_iter() + .map(|child| PlanWithCorrespondingSort::new(child)) + .collect() + } +} + +impl TreeNodeRewritable for PlanWithCorrespondingSort { + fn map_children(self, transform: F) -> Result + where + F: FnMut(Self) -> Result, + { + let children = self.children(); + if children.is_empty() { + Ok(self) + } else { + let children_requirements = children + .into_iter() + .map(transform) + .collect::>>()?; + let children_plans = children_requirements + .iter() + .map(|elem| elem.plan.clone()) + .collect::>(); + let sort_onwards = children_requirements + .iter() + .map(|item| { + if item.sort_onwards.is_empty() { + vec![] + } else { + // TODO: When `maintains_input_order` returns Vec, + // pass the order-enforcing sort upwards. + item.sort_onwards[0].clone() + } + }) + .collect::>(); + let plan = with_new_children_if_necessary(self.plan, children_plans)?; + Ok(PlanWithCorrespondingSort { plan, sort_onwards }) + } + } +} + +impl PhysicalOptimizerRule for OptimizeSorts { + fn optimize( + &self, + plan: Arc, + _config: &SessionConfig, + ) -> Result> { + // Execute a post-order traversal to adjust input key ordering: + let plan_requirements = PlanWithCorrespondingSort::new(plan); + let adjusted = plan_requirements.transform_up(&optimize_sorts)?; + Ok(adjusted.plan) + } + + fn name(&self) -> &str { + "OptimizeSorts" + } + + fn schema_check(&self) -> bool { + true + } +} + +fn optimize_sorts( + requirements: PlanWithCorrespondingSort, +) -> Result> { + // Perform naive analysis at the beginning -- remove already-satisfied sorts: + if let Some(result) = analyze_immediate_sort_removal(&requirements)? { + return Ok(Some(result)); + } + let plan = &requirements.plan; + let mut new_children = plan.children().clone(); + let mut new_onwards = requirements.sort_onwards.clone(); + for (idx, (child, sort_onwards, required_ordering)) in izip!( + new_children.iter_mut(), + new_onwards.iter_mut(), + plan.required_input_ordering() + ) + .enumerate() + { + let physical_ordering = child.output_ordering(); + match (required_ordering, physical_ordering) { + (Some(required_ordering), Some(physical_ordering)) => { + let is_ordering_satisfied = ordering_satisfy_concrete( + physical_ordering, + required_ordering, + || child.equivalence_properties(), + ); + if !is_ordering_satisfied { + // Make sure we preserve the ordering requirements: + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; + let sort_expr = required_ordering.to_vec(); + *child = add_sort_above_child(child, sort_expr)?; + sort_onwards.push((idx, child.clone())) + } else if let [first, ..] = sort_onwards.as_slice() { + // The ordering requirement is met, we can analyze if there is an unnecessary sort: + let sort_any = first.1.clone(); + let sort_exec = convert_to_sort_exec(&sort_any)?; + let sort_output_ordering = sort_exec.output_ordering(); + let sort_input_ordering = sort_exec.input().output_ordering(); + // Simple analysis: Does the input of the sort in question already satisfy the ordering requirements? + if ordering_satisfy(sort_input_ordering, sort_output_ordering, || { + sort_exec.input().equivalence_properties() + }) { + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; + } else if let Some(window_agg_exec) = + requirements.plan.as_any().downcast_ref::() + { + // For window expressions, we can remove some sorts when we can + // calculate the result in reverse: + if let Some(res) = analyze_window_sort_removal( + window_agg_exec, + sort_exec, + sort_onwards, + )? { + return Ok(Some(res)); + } + } + // TODO: Once we can ensure that required ordering information propagates with + // necessary lineage information, compare `sort_input_ordering` and `required_ordering`. + // This will enable us to handle cases such as (a,b) -> Sort -> (a,b,c) -> Required(a,b). + // Currently, we can not remove such sorts. + } + } + (Some(required), None) => { + // Ordering requirement is not met, we should add a SortExec to the plan. + let sort_expr = required.to_vec(); + *child = add_sort_above_child(child, sort_expr)?; + *sort_onwards = vec![(idx, child.clone())]; + } + (None, Some(_)) => { + // We have a SortExec whose effect may be neutralized by a order-imposing + // operator. In this case, remove this sort: + if !requirements.plan.maintains_input_order() { + update_child_to_remove_unnecessary_sort(child, sort_onwards)?; + } + } + (None, None) => {} + } + } + if plan.children().is_empty() { + Ok(Some(requirements)) + } else { + let new_plan = requirements.plan.with_new_children(new_children)?; + for (idx, (trace, required_ordering)) in new_onwards + .iter_mut() + .zip(new_plan.required_input_ordering()) + .enumerate() + .take(new_plan.children().len()) + { + // TODO: When `maintains_input_order` returns a `Vec`, use corresponding index. + if new_plan.maintains_input_order() + && required_ordering.is_none() + && !trace.is_empty() + { + trace.push((idx, new_plan.clone())); + } else { + trace.clear(); + if new_plan.as_any().is::() { + trace.push((idx, new_plan.clone())); + } + } + } + Ok(Some(PlanWithCorrespondingSort { + plan: new_plan, + sort_onwards: new_onwards, + })) + } +} + +/// Analyzes a given `SortExec` to determine whether its input already has +/// a finer ordering than this `SortExec` enforces. +fn analyze_immediate_sort_removal( + requirements: &PlanWithCorrespondingSort, +) -> Result> { + if let Some(sort_exec) = requirements.plan.as_any().downcast_ref::() { + // If this sort is unnecessary, we should remove it: + if ordering_satisfy( + sort_exec.input().output_ordering(), + sort_exec.output_ordering(), + || sort_exec.input().equivalence_properties(), + ) { + // Since we know that a `SortExec` has exactly one child, + // we can use the zero index safely: + let mut new_onwards = requirements.sort_onwards[0].to_vec(); + if !new_onwards.is_empty() { + new_onwards.pop(); + } + return Ok(Some(PlanWithCorrespondingSort { + plan: sort_exec.input().clone(), + sort_onwards: vec![new_onwards], + })); + } + } + Ok(None) +} + +/// Analyzes a `WindowAggExec` to determine whether it may allow removing a sort. +fn analyze_window_sort_removal( + window_agg_exec: &WindowAggExec, + sort_exec: &SortExec, + sort_onward: &mut Vec<(usize, Arc)>, +) -> Result> { + let required_ordering = sort_exec.output_ordering().ok_or_else(|| { + DataFusionError::Plan("A SortExec should have output ordering".to_string()) + })?; + let physical_ordering = sort_exec.input().output_ordering(); + let physical_ordering = if let Some(physical_ordering) = physical_ordering { + physical_ordering + } else { + // If there is no physical ordering, there is no way to remove a sort -- immediately return: + return Ok(None); + }; + let window_expr = window_agg_exec.window_expr(); + let (can_skip_sorting, should_reverse) = can_skip_sort( + window_expr[0].partition_by(), + required_ordering, + &sort_exec.input().schema(), + physical_ordering, + )?; + if can_skip_sorting { + let new_window_expr = if should_reverse { + window_expr + .iter() + .map(|e| e.get_reverse_expr()) + .collect::>>() + } else { + Some(window_expr.to_vec()) + }; + if let Some(window_expr) = new_window_expr { + let new_child = remove_corresponding_sort_from_sub_plan(sort_onward)?; + let new_schema = new_child.schema(); + let new_plan = Arc::new(WindowAggExec::try_new( + window_expr, + new_child, + new_schema, + window_agg_exec.partition_keys.clone(), + Some(physical_ordering.to_vec()), + )?); + return Ok(Some(PlanWithCorrespondingSort::new(new_plan))); + } + } + Ok(None) +} + +/// Updates child to remove the unnecessary sorting below it. +fn update_child_to_remove_unnecessary_sort( + child: &mut Arc, + sort_onwards: &mut Vec<(usize, Arc)>, +) -> Result<()> { + if !sort_onwards.is_empty() { + *child = remove_corresponding_sort_from_sub_plan(sort_onwards)?; + } + Ok(()) +} + +/// Converts an [ExecutionPlan] trait object to a [SortExec] when possible. +fn convert_to_sort_exec(sort_any: &Arc) -> Result<&SortExec> { + sort_any.as_any().downcast_ref::().ok_or_else(|| { + DataFusionError::Plan("Given ExecutionPlan is not a SortExec".to_string()) + }) +} + +/// Removes the sort from the plan in `sort_onwards`. +fn remove_corresponding_sort_from_sub_plan( + sort_onwards: &mut Vec<(usize, Arc)>, +) -> Result> { + let (sort_child_idx, sort_any) = sort_onwards[0].clone(); + let sort_exec = convert_to_sort_exec(&sort_any)?; + let mut prev_layer = sort_exec.input().clone(); + let mut prev_child_idx = sort_child_idx; + // In the loop below, se start from 1 as the first one is a SortExec + // and we are removing it from the plan. + for (child_idx, layer) in sort_onwards.iter().skip(1) { + let mut children = layer.children(); + children[prev_child_idx] = prev_layer; + prev_layer = layer.clone().with_new_children(children)?; + prev_child_idx = *child_idx; + } + // We have removed the sort, hence empty the sort_onwards: + sort_onwards.clear(); + Ok(prev_layer) +} + +#[derive(Debug)] +/// This structure stores extra column information required to remove unnecessary sorts. +pub struct ColumnInfo { + is_aligned: bool, + reverse: bool, + is_partition: bool, +} + +/// Compares physical ordering and required ordering of all `PhysicalSortExpr`s and returns a tuple. +/// The first element indicates whether these `PhysicalSortExpr`s can be removed from the physical plan. +/// The second element is a flag indicating whether we should reverse the sort direction in order to +/// remove physical sort expressions from the plan. +pub fn can_skip_sort( + partition_keys: &[Arc], + required: &[PhysicalSortExpr], + input_schema: &SchemaRef, + physical_ordering: &[PhysicalSortExpr], +) -> Result<(bool, bool)> { + if required.len() > physical_ordering.len() { + return Ok((false, false)); + } + let mut col_infos = vec![]; + for (sort_expr, physical_expr) in zip(required, physical_ordering) { + let column = sort_expr.expr.clone(); + let is_partition = partition_keys.iter().any(|e| e.eq(&column)); + let (is_aligned, reverse) = + check_alignment(input_schema, physical_expr, sort_expr); + col_infos.push(ColumnInfo { + is_aligned, + reverse, + is_partition, + }); + } + let partition_by_sections = col_infos + .iter() + .filter(|elem| elem.is_partition) + .collect::>(); + let can_skip_partition_bys = if partition_by_sections.is_empty() { + true + } else { + let first_reverse = partition_by_sections[0].reverse; + let can_skip_partition_bys = partition_by_sections + .iter() + .all(|c| c.is_aligned && c.reverse == first_reverse); + can_skip_partition_bys + }; + let order_by_sections = col_infos + .iter() + .filter(|elem| !elem.is_partition) + .collect::>(); + let (can_skip_order_bys, should_reverse_order_bys) = if order_by_sections.is_empty() { + (true, false) + } else { + let first_reverse = order_by_sections[0].reverse; + let can_skip_order_bys = order_by_sections + .iter() + .all(|c| c.is_aligned && c.reverse == first_reverse); + (can_skip_order_bys, first_reverse) + }; + let can_skip = can_skip_order_bys && can_skip_partition_bys; + Ok((can_skip, should_reverse_order_bys)) +} + +/// Compares `physical_ordering` and `required` ordering, returns a tuple +/// indicating (1) whether this column requires sorting, and (2) whether we +/// should reverse the window expression in order to avoid sorting. +fn check_alignment( + input_schema: &SchemaRef, + physical_ordering: &PhysicalSortExpr, + required: &PhysicalSortExpr, +) -> (bool, bool) { + if required.expr.eq(&physical_ordering.expr) { + let nullable = required.expr.nullable(input_schema).unwrap(); + let physical_opts = physical_ordering.options; + let required_opts = required.options; + let is_reversed = if nullable { + physical_opts == reverse_sort_options(required_opts) + } else { + // If the column is not nullable, NULLS FIRST/LAST is not important. + physical_opts.descending != required_opts.descending + }; + let can_skip = !nullable || is_reversed || (physical_opts == required_opts); + (can_skip, is_reversed) + } else { + (false, false) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::physical_plan::displayable; + use crate::physical_plan::filter::FilterExec; + use crate::physical_plan::memory::MemoryExec; + use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; + use crate::physical_plan::windows::create_window_expr; + use crate::prelude::SessionContext; + use arrow::compute::SortOptions; + use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::Result; + use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunction}; + use datafusion_physical_expr::expressions::{col, NotExpr}; + use datafusion_physical_expr::PhysicalSortExpr; + use std::sync::Arc; + + fn create_test_schema() -> Result { + let nullable_column = Field::new("nullable_col", DataType::Int32, true); + let non_nullable_column = Field::new("non_nullable_col", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![nullable_column, non_nullable_column])); + + Ok(schema) + } + + #[tokio::test] + async fn test_is_column_aligned_nullable() -> Result<()> { + let schema = create_test_schema()?; + let params = vec![ + ((true, true), (false, false), (true, true)), + ((true, true), (false, true), (false, false)), + ((true, true), (true, false), (false, false)), + ((true, false), (false, true), (true, true)), + ((true, false), (false, false), (false, false)), + ((true, false), (true, true), (false, false)), + ]; + for ( + (physical_desc, physical_nulls_first), + (req_desc, req_nulls_first), + (is_aligned_expected, reverse_expected), + ) in params + { + let physical_ordering = PhysicalSortExpr { + expr: col("nullable_col", &schema)?, + options: SortOptions { + descending: physical_desc, + nulls_first: physical_nulls_first, + }, + }; + let required_ordering = PhysicalSortExpr { + expr: col("nullable_col", &schema)?, + options: SortOptions { + descending: req_desc, + nulls_first: req_nulls_first, + }, + }; + let (is_aligned, reverse) = + check_alignment(&schema, &physical_ordering, &required_ordering); + assert_eq!(is_aligned, is_aligned_expected); + assert_eq!(reverse, reverse_expected); + } + + Ok(()) + } + + #[tokio::test] + async fn test_is_column_aligned_non_nullable() -> Result<()> { + let schema = create_test_schema()?; + + let params = vec![ + ((true, true), (false, false), (true, true)), + ((true, true), (false, true), (true, true)), + ((true, true), (true, false), (true, false)), + ((true, false), (false, true), (true, true)), + ((true, false), (false, false), (true, true)), + ((true, false), (true, true), (true, false)), + ]; + for ( + (physical_desc, physical_nulls_first), + (req_desc, req_nulls_first), + (is_aligned_expected, reverse_expected), + ) in params + { + let physical_ordering = PhysicalSortExpr { + expr: col("non_nullable_col", &schema)?, + options: SortOptions { + descending: physical_desc, + nulls_first: physical_nulls_first, + }, + }; + let required_ordering = PhysicalSortExpr { + expr: col("non_nullable_col", &schema)?, + options: SortOptions { + descending: req_desc, + nulls_first: req_nulls_first, + }, + }; + let (is_aligned, reverse) = + check_alignment(&schema, &physical_ordering, &required_ordering); + assert_eq!(is_aligned, is_aligned_expected); + assert_eq!(reverse, reverse_expected); + } + + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs, source, None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let physical_plan = Arc::new(SortExec::try_new(sort_exprs, sort_exec, None)?) + as Arc; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortExec: [nullable_col@0 ASC]", + " SortExec: [non_nullable_col@1 ASC]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + OptimizeSorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { vec!["SortExec: [nullable_col@0 ASC]"] }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", source.schema().as_ref()).unwrap(), + options: SortOptions { + descending: true, + nulls_first: true, + }, + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) + as Arc; + let window_agg_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col("non_nullable_col", &schema)?], + &[], + &sort_exprs, + Arc::new(WindowFrame::new(true)), + schema.as_ref(), + )?], + sort_exec.clone(), + sort_exec.schema(), + vec![], + Some(sort_exprs), + )?) as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("non_nullable_col", window_agg_exec.schema().as_ref()).unwrap(), + options: SortOptions { + descending: false, + nulls_first: false, + }, + }]; + let sort_exec = Arc::new(SortExec::try_new( + sort_exprs.clone(), + window_agg_exec, + None, + )?) as Arc; + // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before + let filter_exec = Arc::new(FilterExec::try_new( + Arc::new(NotExpr::new( + col("non_nullable_col", schema.as_ref()).unwrap(), + )), + sort_exec, + )?) as Arc; + // let filter_exec = sort_exec; + let window_agg_exec = Arc::new(WindowAggExec::try_new( + vec![create_window_expr( + &WindowFunction::AggregateFunction(AggregateFunction::Count), + "count".to_owned(), + &[col("non_nullable_col", &schema)?], + &[], + &sort_exprs, + Arc::new(WindowFrame::new(true)), + schema.as_ref(), + )?], + filter_exec.clone(), + filter_exec.schema(), + vec![], + Some(sort_exprs), + )?) as Arc; + let physical_plan = window_agg_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " FilterExec: NOT non_nullable_col@1", + " SortExec: [non_nullable_col@2 ASC NULLS LAST]", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " SortExec: [non_nullable_col@1 DESC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + OptimizeSorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(NULL) }]", + " FilterExec: NOT non_nullable_col@1", + " WindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow }]", + " SortExec: [non_nullable_col@1 DESC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_add_required_sort() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let physical_plan = Arc::new(SortPreservingMergeExec::new(sort_exprs, source)) + as Arc; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { vec!["SortPreservingMergeExec: [nullable_col@0 ASC]"] }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + OptimizeSorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_remove_unnecessary_sort1() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new(sort_exprs.clone(), source, None)?) + as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let sort_exprs = vec![PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }]; + let sort_exec = Arc::new(SortExec::try_new( + sort_exprs.clone(), + sort_preserving_merge_exec, + None, + )?) as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let physical_plan = sort_preserving_merge_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + OptimizeSorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortPreservingMergeExec: [nullable_col@0 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } + + #[tokio::test] + async fn test_change_wrong_sorting() -> Result<()> { + let session_ctx = SessionContext::new(); + let conf = session_ctx.copied_config(); + let schema = create_test_schema()?; + let source = Arc::new(MemoryExec::try_new(&[], schema.clone(), None)?) + as Arc; + let sort_exprs = vec![ + PhysicalSortExpr { + expr: col("nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }, + PhysicalSortExpr { + expr: col("non_nullable_col", schema.as_ref()).unwrap(), + options: SortOptions::default(), + }, + ]; + let sort_exec = Arc::new(SortExec::try_new( + vec![sort_exprs[0].clone()], + source, + None, + )?) as Arc; + let sort_preserving_merge_exec = + Arc::new(SortPreservingMergeExec::new(sort_exprs, sort_exec)) + as Arc; + let physical_plan = sort_preserving_merge_exec; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: [nullable_col@0 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + let optimized_physical_plan = + OptimizeSorts::new().optimize(physical_plan, &conf)?; + let formatted = displayable(optimized_physical_plan.as_ref()) + .indent() + .to_string(); + let expected = { + vec![ + "SortPreservingMergeExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " SortExec: [nullable_col@0 ASC,non_nullable_col@1 ASC]", + " MemoryExec: partitions=0, partition_sizes=[]", + ] + }; + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + Ok(()) + } +} diff --git a/datafusion/core/src/physical_optimizer/utils.rs b/datafusion/core/src/physical_optimizer/utils.rs index 4aceb776d7d5..8f1fe2d08213 100644 --- a/datafusion/core/src/physical_optimizer/utils.rs +++ b/datafusion/core/src/physical_optimizer/utils.rs @@ -21,7 +21,12 @@ use super::optimizer::PhysicalOptimizerRule; use crate::execution::context::SessionConfig; use crate::error::Result; +use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{with_new_children_if_necessary, ExecutionPlan}; +use datafusion_physical_expr::{ + normalize_sort_expr_with_equivalence_properties, EquivalenceProperties, + PhysicalSortExpr, +}; use std::sync::Arc; /// Convenience rule for writing optimizers: recursively invoke @@ -45,3 +50,73 @@ pub fn optimize_children( with_new_children_if_necessary(plan, children) } } + +/// Checks whether given ordering requirements are satisfied by provided [PhysicalSortExpr]s. +pub fn ordering_satisfy EquivalenceProperties>( + provided: Option<&[PhysicalSortExpr]>, + required: Option<&[PhysicalSortExpr]>, + equal_properties: F, +) -> bool { + match (provided, required) { + (_, None) => true, + (None, Some(_)) => false, + (Some(provided), Some(required)) => { + ordering_satisfy_concrete(provided, required, equal_properties) + } + } +} + +pub fn ordering_satisfy_concrete EquivalenceProperties>( + provided: &[PhysicalSortExpr], + required: &[PhysicalSortExpr], + equal_properties: F, +) -> bool { + if required.len() > provided.len() { + false + } else if required + .iter() + .zip(provided.iter()) + .all(|(order1, order2)| order1.eq(order2)) + { + true + } else if let eq_classes @ [_, ..] = equal_properties().classes() { + let normalized_required_exprs = required + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes) + }) + .collect::>(); + let normalized_provided_exprs = provided + .iter() + .map(|e| { + normalize_sort_expr_with_equivalence_properties(e.clone(), eq_classes) + }) + .collect::>(); + normalized_required_exprs + .iter() + .zip(normalized_provided_exprs.iter()) + .all(|(order1, order2)| order1.eq(order2)) + } else { + false + } +} + +/// Util function to add SortExec above child +/// preserving the original partitioning +pub fn add_sort_above_child( + child: &Arc, + sort_expr: Vec, +) -> Result> { + let new_child = if child.output_partitioning().partition_count() > 1 { + Arc::new(SortExec::new_with_partitioning( + sort_expr, + child.clone(), + true, + None, + )) as Arc + } else { + Arc::new(SortExec::try_new(sort_expr, child.clone(), None)?) + as Arc + }; + Ok(new_child) +} diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs index b29dc0cb8c11..1c36014f2012 100644 --- a/datafusion/core/src/physical_plan/common.rs +++ b/datafusion/core/src/physical_plan/common.rs @@ -266,6 +266,22 @@ impl Drop for AbortOnDropMany { } } +/// Transposes the given vector of vectors. +pub fn transpose(original: Vec>) -> Vec> { + match original.as_slice() { + [] => vec![], + [first, ..] => { + let mut result = (0..first.len()).map(|_| vec![]).collect::>(); + for row in original { + for (item, transposed_row) in row.into_iter().zip(&mut result) { + transposed_row.push(item); + } + } + result + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -332,6 +348,15 @@ mod tests { assert_eq!(actual, expected); Ok(()) } + + #[test] + fn test_transpose() -> Result<()> { + let in_data = vec![vec![1, 2, 3], vec![4, 5, 6]]; + let transposed = transpose(in_data); + let expected = vec![vec![1, 4], vec![2, 5], vec![3, 6]]; + assert_eq!(expected, transposed); + Ok(()) + } } /// Write in Arrow IPC format. diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 0a598be873cd..e16c518b62d9 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -583,7 +583,7 @@ impl DefaultPhysicalPlanner { let physical_input_schema = input_exec.schema(); let sort_keys = sort_keys .iter() - .map(|e| match e { + .map(|(e, _)| match e { Expr::Sort(expr::Sort { expr, asc, diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 9492fb7497a6..3dc0c6d337cc 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -289,7 +289,16 @@ impl ExecutionPlan for RepartitionExec { } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None + if self.maintains_input_order() { + self.input().output_ordering() + } else { + None + } + } + + fn maintains_input_order(&self) -> bool { + // We preserve ordering when input partitioning is 1 + self.input().output_partitioning().partition_count() <= 1 } fn equivalence_properties(&self) -> EquivalenceProperties { diff --git a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs index 914e3e71dbad..d1ea0af69ad1 100644 --- a/datafusion/core/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/core/src/physical_plan/windows/window_agg_exec.rs @@ -19,6 +19,7 @@ use crate::error::Result; use crate::execution::context::TaskContext; +use crate::physical_plan::common::transpose; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, @@ -28,19 +29,23 @@ use crate::physical_plan::{ ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream, Statistics, WindowExpr, }; -use arrow::compute::concat_batches; +use arrow::compute::{ + concat, concat_batches, lexicographical_partition_ranges, SortColumn, +}; use arrow::{ array::ArrayRef, datatypes::{Schema, SchemaRef}, error::{ArrowError, Result as ArrowResult}, record_batch::RecordBatch, }; +use datafusion_common::DataFusionError; use datafusion_physical_expr::rewrite::TreeNodeRewritable; use datafusion_physical_expr::EquivalentClass; use futures::stream::Stream; use futures::{ready, StreamExt}; use log::debug; use std::any::Any; +use std::ops::Range; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -131,6 +136,28 @@ impl WindowAggExec { pub fn input_schema(&self) -> SchemaRef { self.input_schema.clone() } + + /// Return the output sort order of partition keys: For example + /// OVER(PARTITION BY a, ORDER BY b) -> would give sorting of the column a + // We are sure that partition by columns are always at the beginning of sort_keys + // Hence returned `PhysicalSortExpr` corresponding to `PARTITION BY` columns can be used safely + // to calculate partition separation points + pub fn partition_by_sort_keys(&self) -> Result> { + let mut result = vec![]; + // All window exprs have the same partition by, so we just use the first one: + let partition_by = self.window_expr()[0].partition_by(); + let sort_keys = self.sort_keys.as_deref().unwrap_or(&[]); + for item in partition_by { + if let Some(a) = sort_keys.iter().find(|&e| e.expr.eq(item)) { + result.push(a.clone()); + } else { + return Err(DataFusionError::Execution( + "Partition key not found in sort keys".to_string(), + )); + } + } + Ok(result) + } } impl ExecutionPlan for WindowAggExec { @@ -253,6 +280,7 @@ impl ExecutionPlan for WindowAggExec { self.window_expr.clone(), input, BaselineMetrics::new(&self.metrics, partition), + self.partition_by_sort_keys()?, )); Ok(stream) } @@ -337,6 +365,7 @@ pub struct WindowAggStream { batches: Vec, finished: bool, window_expr: Vec>, + partition_by_sort_keys: Vec, baseline_metrics: BaselineMetrics, } @@ -347,6 +376,7 @@ impl WindowAggStream { window_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, + partition_by_sort_keys: Vec, ) -> Self { Self { schema, @@ -355,6 +385,7 @@ impl WindowAggStream { finished: false, window_expr, baseline_metrics, + partition_by_sort_keys, } } @@ -368,9 +399,32 @@ impl WindowAggStream { let batch = concat_batches(&self.input.schema(), &self.batches)?; - // calculate window cols - let mut columns = compute_window_aggregates(&self.window_expr, &batch) - .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let partition_by_sort_keys = self + .partition_by_sort_keys + .iter() + .map(|elem| elem.evaluate_to_sort_column(&batch)) + .collect::>>()?; + let partition_points = + self.evaluate_partition_points(batch.num_rows(), &partition_by_sort_keys)?; + + let mut partition_results = vec![]; + // Calculate window cols + for partition_point in partition_points { + let length = partition_point.end - partition_point.start; + partition_results.push( + compute_window_aggregates( + &self.window_expr, + &batch.slice(partition_point.start, length), + ) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?, + ) + } + let mut columns = transpose(partition_results) + .iter() + .map(|elems| concat(&elems.iter().map(|x| x.as_ref()).collect::>())) + .collect::>() + .into_iter() + .collect::>>()?; // combine with the original cols // note the setup of window aggregates is that they newly calculated window @@ -378,6 +432,25 @@ impl WindowAggStream { columns.extend_from_slice(batch.columns()); RecordBatch::try_new(self.schema.clone(), columns) } + + /// Evaluates the partition points given the sort columns. If the sort columns are + /// empty, then the result will be a single element vector spanning the entire batch. + fn evaluate_partition_points( + &self, + num_rows: usize, + partition_columns: &[SortColumn], + ) -> Result>> { + Ok(if partition_columns.is_empty() { + vec![Range { + start: 0, + end: num_rows, + }] + } else { + lexicographical_partition_ranges(partition_columns) + .map_err(DataFusionError::ArrowError)? + .collect::>() + }) + } } impl Stream for WindowAggStream { diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index 62aeb02557a7..90fd91164fe1 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -61,11 +61,6 @@ async fn explain_analyze_baseline_metrics() { "AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]", "metrics=[output_rows=5, elapsed_compute=" ); - assert_metrics!( - &formatted, - "SortExec: [c1@0 ASC NULLS LAST]", - "metrics=[output_rows=5, elapsed_compute=" - ); assert_metrics!( &formatted, "FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index c9ef64212a6b..41278e1208b7 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -1748,17 +1748,20 @@ async fn test_window_partition_by_order_by() -> Result<()> { let msg = format!("Creating logical plan for '{}'", sql); let dataframe = ctx.sql(sql).await.expect(&msg); - let physical_plan = dataframe.create_physical_plan().await.unwrap(); + let physical_plan = dataframe.create_physical_plan().await?; let formatted = displayable(physical_plan.as_ref()).indent().to_string(); - // Only 1 SortExec was added let expected = { vec![ - "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as COUNT(UInt8(1))]", - " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }, COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", - " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + "ProjectionExec: expr=[SUM(aggregate_test_100.c4) PARTITION BY [aggregate_test_100.c1, aggregate_test_100.c2] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@1 as SUM(aggregate_test_100.c4), COUNT(UInt8(1)) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING@0 as COUNT(UInt8(1))]", + " WindowAggExec: wdw=[COUNT(UInt8(1)): Ok(Field { name: \"COUNT(UInt8(1))\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c1@1 ASC NULLS LAST,c2@2 ASC NULLS LAST]", " CoalesceBatchesExec: target_batch_size=4096", - " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 2)", - " RepartitionExec: partitioning=RoundRobinBatch(2)", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 1 }], 2)", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c4): Ok(Field { name: \"SUM(aggregate_test_100.c4)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c2@1 ASC NULLS LAST]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }, Column { name: \"c2\", index: 1 }], 2)", + " RepartitionExec: partitioning=RoundRobinBatch(2)", ] }; @@ -1772,3 +1775,568 @@ async fn test_window_partition_by_order_by() -> Result<()> { ); Ok(()) } + +#[tokio::test] +async fn test_window_agg_sort_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@2 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4268716378 | 8498370520 | 24997484146 |", + "| 4229654142 | 12714811027 | 29012926487 |", + "| 4216440507 | 16858984380 | 28743001064 |", + "| 4144173353 | 20935849039 | 28472563256 |", + "| 4076864659 | 24997484146 | 28118515915 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_reversed_plan_builtin() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + FIRST_VALUE(c9) OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as fv1, + FIRST_VALUE(c9) OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as fv2, + LAG(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lag1, + LAG(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lag2, + LEAD(c9, 2, 10101) OVER(ORDER BY c9 ASC) as lead1, + LEAD(c9, 2, 10101) OVER(ORDER BY c9 DESC ROWS BETWEEN 10 PRECEDING and 1 FOLLOWING) as lead2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@6 as c9, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as fv1, FIRST_VALUE(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@3 as fv2, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as lag1, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@4 as lag2, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@2 as lead1, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 10 PRECEDING AND 1 FOLLOWING@5 as lead2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(UInt32(NULL)) }]", + " WindowAggExec: wdw=[FIRST_VALUE(aggregate_test_100.c9): Ok(Field { name: \"FIRST_VALUE(aggregate_test_100.c9)\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }, LAG(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LAG(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }, LEAD(aggregate_test_100.c9,Int64(2),Int64(10101)): Ok(Field { name: \"LEAD(aggregate_test_100.c9,Int64(2),Int64(10101))\", data_type: UInt32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(10)), end_bound: Following(UInt64(1)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+------------+------------+------------+------------+------------+------------+", + "| c9 | fv1 | fv2 | lag1 | lag2 | lead1 | lead2 |", + "+------------+------------+------------+------------+------------+------------+------------+", + "| 4268716378 | 4229654142 | 4268716378 | 4216440507 | 10101 | 10101 | 4216440507 |", + "| 4229654142 | 4216440507 | 4268716378 | 4144173353 | 10101 | 10101 | 4144173353 |", + "| 4216440507 | 4144173353 | 4229654142 | 4076864659 | 4268716378 | 4268716378 | 4076864659 |", + "| 4144173353 | 4076864659 | 4216440507 | 4061635107 | 4229654142 | 4229654142 | 4061635107 |", + "| 4076864659 | 4061635107 | 4144173353 | 4015442341 | 4216440507 | 4216440507 | 4015442341 |", + "+------------+------------+------------+------------+------------+------------+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_non_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + ROW_NUMBER() OVER(ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn1, + ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // We cannot reverse each window function (ROW_NUMBER is not reversible) + let expected = { + vec![ + "ProjectionExec: expr=[c9@2 as c9, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as rn1, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@1 ASC NULLS LAST]", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+-----+-----+", + "| c9 | rn1 | rn2 |", + "+-----------+-----+-----+", + "| 28774375 | 1 | 100 |", + "| 63044568 | 2 | 99 |", + "| 141047417 | 3 | 98 |", + "| 141680161 | 4 | 97 |", + "| 145294611 | 5 | 96 |", + "+-----------+-----+-----+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_multi_layer_non_reversed_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c9 ASC, c1 ASC, c2 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(ORDER BY c9 DESC, c1 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2, + ROW_NUMBER() OVER(ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as rn2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // We cannot reverse each window function (ROW_NUMBER is not reversible) + let expected = { + vec![ + "ProjectionExec: expr=[c9@5 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 ASC NULLS LAST, aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c2 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c1 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@2 as sum2, ROW_NUMBER() ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as rn2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@4 ASC NULLS LAST,c1@2 ASC NULLS LAST,c2@3 ASC NULLS LAST]", + " WindowAggExec: wdw=[ROW_NUMBER(): Ok(Field { name: \"ROW_NUMBER()\", data_type: UInt64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c9@2 DESC,c1@0 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------+------------+-----------+-----+", + "| c9 | sum1 | sum2 | rn2 |", + "+-----------+------------+-----------+-----+", + "| 28774375 | 745354217 | 91818943 | 100 |", + "| 63044568 | 988558066 | 232866360 | 99 |", + "| 141047417 | 1285934966 | 374546521 | 98 |", + "| 141680161 | 1654839259 | 519841132 | 97 |", + "| 145294611 | 1980231675 | 745354217 | 96 |", + "+-----------+------------+-----------+-----+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_complex_plan() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_null_cases_csv(&ctx).await?; + let sql = "SELECT + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as a, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as b, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as c, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as d, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as e, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as f, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING) as g, + SUM(c1) OVER (ORDER BY c3) as h, + SUM(c1) OVER (ORDER BY c3 DESC) as i, + SUM(c1) OVER (ORDER BY c3 NULLS first) as j, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first) as k, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last) as l, + SUM(c1) OVER (ORDER BY c3, c2) as m, + SUM(c1) OVER (ORDER BY c3, c1 DESC) as n, + SUM(c1) OVER (ORDER BY c3 DESC, c1) as o, + SUM(c1) OVER (ORDER BY c3, c1 NULLs first) as p, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as a1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as b1, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as c1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as d1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as e1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as f1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING) as g1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as h1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as j1, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as k1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as l1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as m1, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as n1, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN UNBOUNDED PRECEDING AND current row) as o1, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as h11, + SUM(c1) OVER (ORDER BY c3 RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as j11, + SUM(c1) OVER (ORDER BY c3 DESC RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as k11, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as l11, + SUM(c1) OVER (ORDER BY c3 DESC NULLS last RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as m11, + SUM(c1) OVER (ORDER BY c3 DESC NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as n11, + SUM(c1) OVER (ORDER BY c3 NULLS first RANGE BETWEEN current row AND UNBOUNDED FOLLOWING) as o11 + FROM null_cases + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Unnecessary SortExecs are removed + let expected = { + vec![ + "ProjectionExec: expr=[SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as a, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@0 as b, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@15 as c, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as d, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@11 as e, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@15 as f, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN 10 PRECEDING AND 11 FOLLOWING@7 as g, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as i, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as j, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as k, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as l, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@4 as m, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@6 as n, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST, null_cases.c1 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@19 as o, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST, null_cases.c1 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@5 as p, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as a1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@2 as b1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@17 as c1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as d1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@13 as e1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@17 as f1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND 11 FOLLOWING@9 as g1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as h1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as j1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as k1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as l1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@12 as m1, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@16 as n1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@8 as o1, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as h11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@3 as j11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@18 as k11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as l11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS LAST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@14 as m11, SUM(null_cases.c1) ORDER BY [null_cases.c3 DESC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@18 as n11, SUM(null_cases.c1) ORDER BY [null_cases.c3 ASC NULLS FIRST] RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING@10 as o11]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@17 ASC NULLS LAST,c2@16 ASC NULLS LAST]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@16 ASC NULLS LAST,c1@14 ASC]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(10)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(11)), end_bound: Following(Int64(NULL)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(10)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: Following(Int64(11)) }, SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int64(NULL)) }]", + " WindowAggExec: wdw=[SUM(null_cases.c1): Ok(Field { name: \"SUM(null_cases.c1)\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int64(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@2 DESC,c1@0 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_partitionby_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(ORDER BY c1, c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c1 ASC NULLS LAST, aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4015442341 | 21907044499 | 21907044499 |", + "| 3998790955 | 24576419362 | 24576419362 |", + "| 3959216334 | 23063303501 | 23063303501 |", + "| 3717551163 | 21560567246 | 21560567246 |", + "| 3276123488 | 19815386638 | 19815386638 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_partitionby_reversed_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT + c9, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 ASC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum1, + SUM(c9) OVER(PARTITION BY c1 ORDER BY c9 DESC ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c9@3 as c9, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 ASC NULLS LAST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@0 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c1] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] ROWS BETWEEN 1 PRECEDING AND 5 FOLLOWING@1 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(5)), end_bound: Following(UInt64(1)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(1)), end_bound: Following(UInt64(5)) }]", + " SortExec: [c1@0 ASC NULLS LAST,c9@1 DESC]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+------------+-------------+-------------+", + "| c9 | sum1 | sum2 |", + "+------------+-------------+-------------+", + "| 4015442341 | 8014233296 | 21907044499 |", + "| 3998790955 | 11973449630 | 24576419362 |", + "| 3959216334 | 15691000793 | 23063303501 |", + "| 3717551163 | 18967124281 | 21560567246 |", + "| 3276123488 | 21907044499 | 19815386638 |", + "+------------+-------------+-------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_binary_expr() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c3, + SUM(c9) OVER(ORDER BY c3+c4 DESC, c9 DESC, c2 ASC) as sum1, + SUM(c9) OVER(ORDER BY c3+c4 ASC, c9 ASC ) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 ASC NULLS LAST, aggregate_test_100.c9 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: CurrentRow, end_bound: Following(Int16(NULL)) }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW: Ok(Field { name: \"SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 + aggregate_test_100.c4 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int16(NULL)), end_bound: CurrentRow }]", + " SortExec: [CAST(c3@1 AS Int16) + c4@2 DESC,c9@3 DESC,c2@0 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-------------+--------------+", + "| c3 | sum1 | sum2 |", + "+-----+-------------+--------------+", + "| -86 | 2861911482 | 222089770060 |", + "| 13 | 5075947208 | 219227858578 |", + "| 125 | 8701233618 | 217013822852 |", + "| 123 | 11293564174 | 213388536442 |", + "| 97 | 14767488750 | 210796205886 |", + "+-----+-------------+--------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> { + let config = SessionConfig::new() + .with_target_partitions(8) + .with_repartition_windows(true); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT count(*) as global_count FROM + (SELECT count(*), c1 + FROM aggregate_test_100 + WHERE c13 != 'C2GT5KVyOPZpgKVl110TyZO0NcJ434' + GROUP BY c1 + ORDER BY c1 ) AS a "; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Unnecessary Sort in the sub query is removed + let expected = { + vec![ + "ProjectionExec: expr=[COUNT(UInt8(1))@0 as global_count]", + " AggregateExec: mode=Final, gby=[], aggr=[COUNT(UInt8(1))]", + " CoalescePartitionsExec", + " AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]", + " RepartitionExec: partitioning=RoundRobinBatch(8)", + " CoalescePartitionsExec", + " AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]", + " CoalesceBatchesExec: target_batch_size=4096", + " RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8)", + " AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]", + " CoalesceBatchesExec: target_batch_size=4096", + " FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434", + " RepartitionExec: partitioning=RoundRobinBatch(8)", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+--------------+", + "| global_count |", + "+--------------+", + "| 5 |", + "+--------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn test_window_agg_sort_orderby_reversed_partitionby_reversed_plan() -> Result<()> { + let config = SessionConfig::new().with_repartition_windows(false); + let ctx = SessionContext::with_config(config); + register_aggregate_csv(&ctx).await?; + let sql = "SELECT c3, + SUM(c9) OVER(ORDER BY c3 DESC, c9 DESC, c2 ASC) as sum1, + SUM(c9) OVER(PARTITION BY c3 ORDER BY c9 DESC ) as sum2 + FROM aggregate_test_100 + LIMIT 5"; + + let msg = format!("Creating logical plan for '{}'", sql); + let dataframe = ctx.sql(sql).await.expect(&msg); + let physical_plan = dataframe.create_physical_plan().await?; + let formatted = displayable(physical_plan.as_ref()).indent().to_string(); + // Only 1 SortExec was added + let expected = { + vec![ + "ProjectionExec: expr=[c3@3 as c3, SUM(aggregate_test_100.c9) ORDER BY [aggregate_test_100.c3 DESC NULLS FIRST, aggregate_test_100.c9 DESC NULLS FIRST, aggregate_test_100.c2 ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@1 as sum1, SUM(aggregate_test_100.c9) PARTITION BY [aggregate_test_100.c3] ORDER BY [aggregate_test_100.c9 DESC NULLS FIRST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW@0 as sum2]", + " GlobalLimitExec: skip=0, fetch=5", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(NULL)), end_bound: CurrentRow }]", + " WindowAggExec: wdw=[SUM(aggregate_test_100.c9): Ok(Field { name: \"SUM(aggregate_test_100.c9)\", data_type: UInt64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(Int8(NULL)), end_bound: CurrentRow }]", + " SortExec: [c3@1 DESC,c9@2 DESC,c2@0 ASC NULLS LAST]", + ] + }; + + let actual: Vec<&str> = formatted.trim().lines().collect(); + let actual_len = actual.len(); + let actual_trim_last = &actual[..actual_len - 1]; + assert_eq!( + expected, actual_trim_last, + "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", + expected, actual + ); + + let actual = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----+-------------+------------+", + "| c3 | sum1 | sum2 |", + "+-----+-------------+------------+", + "| 125 | 3625286410 | 3625286410 |", + "| 123 | 7192027599 | 3566741189 |", + "| 123 | 9784358155 | 6159071745 |", + "| 122 | 13845993262 | 4061635107 |", + "| 120 | 16676974334 | 2830981072 |", + "+-----+-------------+------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index bf2a1d001867..eeb3215c4b6f 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -257,7 +257,7 @@ impl LogicalPlanBuilder { // The sort_by() implementation here is a stable sort. // Note that by this rule if there's an empty over, it'll be at the top level groups.sort_by(|(key_a, _), (key_b, _)| { - for (first, second) in key_a.iter().zip(key_b.iter()) { + for ((first, _), (second, _)) in key_a.iter().zip(key_b.iter()) { let key_ordering = compare_sort_expr(first, second, plan.schema()); match key_ordering { Ordering::Less => { diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 3ee36de17622..ca06dfdb4aaf 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -204,7 +204,9 @@ pub fn expand_qualified_wildcard( expand_wildcard(&qualifier_schema, plan) } -type WindowSortKey = Vec; +/// (expr, "is the SortExpr for window (either comes from PARTITION BY or ORDER BY columns)") +/// if bool is true SortExpr comes from `PARTITION BY` column, if false comes from `ORDER BY` column +type WindowSortKey = Vec<(Expr, bool)>; /// Generate a sort key for a given window expr's partition_by and order_bu expr pub fn generate_sort_key( @@ -224,6 +226,7 @@ pub fn generate_sort_key( .collect::>>()?; let mut final_sort_keys = vec![]; + let mut is_partition_flag = vec![]; partition_by.iter().for_each(|e| { // By default, create sort key with ASC is true and NULLS LAST to be consistent with // PostgreSQL's rule: https://www.postgresql.org/docs/current/queries-order.html @@ -232,18 +235,26 @@ pub fn generate_sort_key( let order_by_key = &order_by[pos]; if !final_sort_keys.contains(order_by_key) { final_sort_keys.push(order_by_key.clone()); + is_partition_flag.push(true); } } else if !final_sort_keys.contains(&e) { final_sort_keys.push(e); + is_partition_flag.push(true); } }); order_by.iter().for_each(|e| { if !final_sort_keys.contains(e) { final_sort_keys.push(e.clone()); + is_partition_flag.push(false); } }); - Ok(final_sort_keys) + let res = final_sort_keys + .into_iter() + .zip(is_partition_flag) + .map(|(lhs, rhs)| (lhs, rhs)) + .collect::>(); + Ok(res) } /// Compare the sort expr as PostgreSQL's common_prefix_cmp(): @@ -1043,9 +1054,13 @@ mod tests { let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; let result = group_window_expr_by_sort_keys(exprs)?; - let key1 = vec![age_asc.clone(), name_desc.clone()]; + let key1 = vec![(age_asc.clone(), false), (name_desc.clone(), false)]; let key2 = vec![]; - let key3 = vec![name_desc, age_asc, created_at_desc]; + let key3 = vec![ + (name_desc, false), + (age_asc, false), + (created_at_desc, false), + ]; let expected: Vec<(WindowSortKey, Vec<&Expr>)> = vec![ (key1, vec![&max1, &min3]), @@ -1112,21 +1127,30 @@ mod tests { ]; let expected = vec![ - Expr::Sort(Sort { - expr: Box::new(col("age")), - asc: asc_, - nulls_first: nulls_first_, - }), - Expr::Sort(Sort { - expr: Box::new(col("name")), - asc: asc_, - nulls_first: nulls_first_, - }), - Expr::Sort(Sort { - expr: Box::new(col("created_at")), - asc: true, - nulls_first: false, - }), + ( + Expr::Sort(Sort { + expr: Box::new(col("age")), + asc: asc_, + nulls_first: nulls_first_, + }), + true, + ), + ( + Expr::Sort(Sort { + expr: Box::new(col("name")), + asc: asc_, + nulls_first: nulls_first_, + }), + true, + ), + ( + Expr::Sort(Sort { + expr: Box::new(col("created_at")), + asc: true, + nulls_first: false, + }), + true, + ), ]; let result = generate_sort_key(partition_by, order_by)?; assert_eq!(expected, result); diff --git a/datafusion/expr/src/window_frame.rs b/datafusion/expr/src/window_frame.rs index 62c7c57d47ba..100ea8e1ded1 100644 --- a/datafusion/expr/src/window_frame.rs +++ b/datafusion/expr/src/window_frame.rs @@ -113,6 +113,35 @@ impl WindowFrame { } } } + + /// Get reversed window frame. For example + /// `3 ROWS PRECEDING AND 2 ROWS FOLLOWING` --> + /// `2 ROWS PRECEDING AND 3 ROWS FOLLOWING` + pub fn reverse(&self) -> Self { + let start_bound = match &self.end_bound { + WindowFrameBound::Preceding(elem) => { + WindowFrameBound::Following(elem.clone()) + } + WindowFrameBound::Following(elem) => { + WindowFrameBound::Preceding(elem.clone()) + } + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + }; + let end_bound = match &self.start_bound { + WindowFrameBound::Preceding(elem) => { + WindowFrameBound::Following(elem.clone()) + } + WindowFrameBound::Following(elem) => { + WindowFrameBound::Preceding(elem.clone()) + } + WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow, + }; + WindowFrame { + units: self.units, + start_bound, + end_bound, + } + } } /// There are five ways to describe starting and ending frame boundaries: diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 6c43344db97a..813952117af1 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -36,7 +36,7 @@ use crate::expressions::format_state_name; /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Count { name: String, data_type: DataType, @@ -105,6 +105,10 @@ impl AggregateExpr for Count { Ok(Box::new(CountRowAccumulator::new(start_index))) } + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) + } + fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(CountAccumulator::new())) } diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 436a2339663f..947336596292 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -103,6 +103,14 @@ pub trait AggregateExpr: Send + Sync + Debug { ))) } + /// Construct an expression that calculates the aggregate in reverse. + /// Typically the "reverse" expression is itself (e.g. SUM, COUNT). + /// For aggregates that do not support calculation in reverse, + /// returns None (which is the default value). + fn reverse_expr(&self) -> Option> { + None + } + /// Creates accumulator implementation that supports retract fn create_sliding_accumulator(&self) -> Result> { Err(DataFusionError::NotImplemented(format!( diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 8d2620296c2e..c2d54c40ed7c 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -44,7 +44,7 @@ use arrow::compute::cast; use datafusion_row::accessor::RowAccessor; /// SUM aggregate expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Sum { name: String, data_type: DataType, @@ -123,6 +123,10 @@ impl AggregateExpr for Sum { ))) } + fn reverse_expr(&self) -> Option> { + Some(Arc::new(self.clone())) + } + fn create_sliding_accumulator(&self) -> Result> { Ok(Box::new(SumAccumulator::try_new(&self.data_type)?)) } diff --git a/datafusion/physical-expr/src/window/aggregate.rs b/datafusion/physical-expr/src/window/aggregate.rs index c42f7ff55a36..5c46f38f220f 100644 --- a/datafusion/physical-expr/src/window/aggregate.rs +++ b/datafusion/physical-expr/src/window/aggregate.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::iter::IntoIterator; +use std::ops::Range; use std::sync::Arc; use arrow::array::Array; @@ -30,6 +31,8 @@ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::WindowFrame; +use crate::window::window_expr::reverse_order_bys; +use crate::window::SlidingAggregateWindowExpr; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -89,49 +92,41 @@ impl WindowExpr for AggregateWindowExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let partition_columns = self.partition_columns(batch)?; - let partition_points = - self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results: Vec = vec![]; - for partition_range in &partition_points { - let mut accumulator = self.aggregate.create_accumulator()?; - let length = partition_range.end - partition_range.start; - let (values, order_bys) = - self.get_values_orderbys(&batch.slice(partition_range.start, length))?; - - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - let mut last_range: (usize, usize) = (0, 0); - - // We iterate on each row to perform a running calculation. - // First, cur_range is calculated, then it is compared with last_range. - for i in 0..length { - let cur_range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - length, - i, - )?; - let value = if cur_range.0 == cur_range.1 { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type())? - } else { - // Accumulate any new rows that have entered the window: - let update_bound = cur_range.1 - last_range.1; - if update_bound > 0 { - let update: Vec = values - .iter() - .map(|v| v.slice(last_range.1, update_bound)) - .collect(); - accumulator.update_batch(&update)? - } - accumulator.evaluate()? - }; - row_wise_results.push(value); - last_range = cur_range; - } + + let mut accumulator = self.aggregate.create_accumulator()?; + let length = batch.num_rows(); + let (values, order_bys) = self.get_values_orderbys(batch)?; + + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let mut last_range = Range { start: 0, end: 0 }; + + // We iterate on each row to perform a running calculation. + // First, cur_range is calculated, then it is compared with last_range. + for i in 0..length { + let cur_range = + window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?; + let value = if cur_range.end == cur_range.start { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? + } else { + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.end - last_range.end; + if update_bound > 0 { + let update: Vec = values + .iter() + .map(|v| v.slice(last_range.end, update_bound)) + .collect(); + accumulator.update_batch(&update)? + } + accumulator.evaluate()? + }; + row_wise_results.push(value); + last_range = cur_range; } + ScalarValue::iter_to_array(row_wise_results.into_iter()) } @@ -146,4 +141,25 @@ impl WindowExpr for AggregateWindowExpr { fn get_window_frame(&self) -> &Arc { &self.window_frame } + + fn get_reverse_expr(&self) -> Option> { + self.aggregate.reverse_expr().map(|reverse_expr| { + let reverse_window_frame = self.window_frame.reverse(); + if reverse_window_frame.start_bound.is_unbounded() { + Arc::new(AggregateWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + )) as _ + } else { + Arc::new(SlidingAggregateWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + )) as _ + } + }) + } } diff --git a/datafusion/physical-expr/src/window/built_in.rs b/datafusion/physical-expr/src/window/built_in.rs index 95bf01608b82..9804432b2056 100644 --- a/datafusion/physical-expr/src/window/built_in.rs +++ b/datafusion/physical-expr/src/window/built_in.rs @@ -20,15 +20,15 @@ use super::window_frame_state::WindowFrameContext; use super::BuiltInWindowFunctionExpr; use super::WindowExpr; +use crate::window::window_expr::reverse_order_bys; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; -use arrow::compute::{concat, SortOptions}; +use arrow::compute::SortOptions; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::DataFusionError; use datafusion_common::Result; +use datafusion_common::ScalarValue; use datafusion_expr::WindowFrame; use std::any::Any; -use std::ops::Range; use std::sync::Arc; /// A window expr that takes the form of a built in window function @@ -91,50 +91,49 @@ impl WindowExpr for BuiltInWindowExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let evaluator = self.expr.create_evaluator()?; let num_rows = batch.num_rows(); - let partition_columns = self.partition_columns(batch)?; - let partition_points = - self.evaluate_partition_points(num_rows, &partition_columns)?; - - let results = if evaluator.uses_window_frame() { + if evaluator.uses_window_frame() { let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results = vec![]; - for partition_range in &partition_points { - let length = partition_range.end - partition_range.start; - let (values, order_bys) = self - .get_values_orderbys(&batch.slice(partition_range.start, length))?; - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - // We iterate on each row to calculate window frame range and and window function result - for idx in 0..length { - let range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - num_rows, - idx, - )?; - let range = Range { - start: range.0, - end: range.1, - }; - let value = evaluator.evaluate_inside_range(&values, range)?; - row_wise_results.push(value.to_array()); - } + + let length = batch.num_rows(); + let (values, order_bys) = self.get_values_orderbys(batch)?; + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + // We iterate on each row to calculate window frame range and and window function result + for idx in 0..length { + let range = window_frame_ctx.calculate_range( + &order_bys, + &sort_options, + num_rows, + idx, + )?; + let value = evaluator.evaluate_inside_range(&values, range)?; + row_wise_results.push(value); } - row_wise_results + ScalarValue::iter_to_array(row_wise_results.into_iter()) } else if evaluator.include_rank() { let columns = self.sort_columns(batch)?; let sort_partition_points = self.evaluate_partition_points(num_rows, &columns)?; - evaluator.evaluate_with_rank(partition_points, sort_partition_points)? + evaluator.evaluate_with_rank(num_rows, &sort_partition_points) } else { let (values, _) = self.get_values_orderbys(batch)?; - evaluator.evaluate(&values, partition_points)? - }; - let results = results.iter().map(|i| i.as_ref()).collect::>(); - concat(&results).map_err(DataFusionError::ArrowError) + evaluator.evaluate(&values, num_rows) + } } fn get_window_frame(&self) -> &Arc { &self.window_frame } + + fn get_reverse_expr(&self) -> Option> { + self.expr.reverse_expr().map(|reverse_expr| { + Arc::new(BuiltInWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + )) as _ + }) + } } diff --git a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs index 7f7a27435c39..c358403fefda 100644 --- a/datafusion/physical-expr/src/window/built_in_window_function_expr.rs +++ b/datafusion/physical-expr/src/window/built_in_window_function_expr.rs @@ -58,4 +58,10 @@ pub trait BuiltInWindowFunctionExpr: Send + Sync + std::fmt::Debug { /// Create built-in window evaluator with a batch fn create_evaluator(&self) -> Result>; + + /// Construct Reverse Expression that produces the same result + /// on a reversed window. For example `lead(10)` --> `lag(10)` + fn reverse_expr(&self) -> Option> { + None + } } diff --git a/datafusion/physical-expr/src/window/cume_dist.rs b/datafusion/physical-expr/src/window/cume_dist.rs index 4202058a3c5a..45fe51178afc 100644 --- a/datafusion/physical-expr/src/window/cume_dist.rs +++ b/datafusion/physical-expr/src/window/cume_dist.rs @@ -73,19 +73,19 @@ impl PartitionEvaluator for CumeDistEvaluator { true } - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - partition: Range, + num_rows: usize, ranks_in_partition: &[Range], ) -> Result { - let scaler = (partition.end - partition.start) as f64; + let scalar = num_rows as f64; let result = Float64Array::from_iter_values( ranks_in_partition .iter() .scan(0_u64, |acc, range| { let len = range.end - range.start; *acc += len as u64; - let value: f64 = (*acc as f64) / scaler; + let value: f64 = (*acc as f64) / scalar; let result = iter::repeat(value).take(len); Some(result) }) @@ -102,15 +102,14 @@ mod tests { fn test_i32_result( expr: &CumeDist, - partition: Range, + num_rows: usize, ranks: Vec>, expected: Vec, ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank(vec![partition], ranks)?; - assert_eq!(1, result.len()); - let result = as_float64_array(&result[0])?; + .evaluate_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -121,19 +120,19 @@ mod tests { let r = cume_dist("arr".into()); let expected = vec![0.0; 0]; - test_i32_result(&r, 0..0, vec![], expected)?; + test_i32_result(&r, 0, vec![], expected)?; let expected = vec![1.0; 1]; - test_i32_result(&r, 0..1, vec![0..1], expected)?; + test_i32_result(&r, 1, vec![0..1], expected)?; let expected = vec![1.0; 2]; - test_i32_result(&r, 0..2, vec![0..2], expected)?; + test_i32_result(&r, 2, vec![0..2], expected)?; let expected = vec![0.5, 0.5, 1.0, 1.0]; - test_i32_result(&r, 0..4, vec![0..2, 2..4], expected)?; + test_i32_result(&r, 4, vec![0..2, 2..4], expected)?; let expected = vec![0.25, 0.5, 0.75, 1.0]; - test_i32_result(&r, 0..4, vec![0..1, 1..2, 2..3, 3..4], expected)?; + test_i32_result(&r, 4, vec![0..1, 1..2, 2..3, 3..4], expected)?; Ok(()) } diff --git a/datafusion/physical-expr/src/window/lead_lag.rs b/datafusion/physical-expr/src/window/lead_lag.rs index c7fc73b9f1c1..e18815c4c3a6 100644 --- a/datafusion/physical-expr/src/window/lead_lag.rs +++ b/datafusion/physical-expr/src/window/lead_lag.rs @@ -28,7 +28,6 @@ use datafusion_common::ScalarValue; use datafusion_common::{DataFusionError, Result}; use std::any::Any; use std::ops::Neg; -use std::ops::Range; use std::sync::Arc; /// window shift expression @@ -107,6 +106,16 @@ impl BuiltInWindowFunctionExpr for WindowShift { default_value: self.default_value.clone(), })) } + + fn reverse_expr(&self) -> Option> { + Some(Arc::new(Self { + name: self.name.clone(), + data_type: self.data_type.clone(), + shift_offset: -self.shift_offset, + expr: self.expr.clone(), + default_value: self.default_value.clone(), + })) + } } pub(crate) struct WindowShiftEvaluator { @@ -164,15 +173,10 @@ fn shift_with_default_value( } impl PartitionEvaluator for WindowShiftEvaluator { - fn evaluate_partition( - &self, - values: &[ArrayRef], - partition: Range, - ) -> Result { + fn evaluate(&self, values: &[ArrayRef], _num_rows: usize) -> Result { // LEAD, LAG window functions take single column, values will have size 1 let value = &values[0]; - let value = value.slice(partition.start, partition.end - partition.start); - shift_with_default_value(&value, self.shift_offset, self.default_value.as_ref()) + shift_with_default_value(value, self.shift_offset, self.default_value.as_ref()) } } @@ -191,9 +195,10 @@ mod tests { let schema = Schema::new(vec![Field::new("arr", DataType::Int32, false)]); let batch = RecordBatch::try_new(Arc::new(schema), values.clone())?; let values = expr.evaluate_args(&batch)?; - let result = expr.create_evaluator()?.evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_int32_array(&result[0])?; + let result = expr + .create_evaluator()? + .evaluate(&values, batch.num_rows())?; + let result = as_int32_array(&result)?; assert_eq!(expected, *result); Ok(()) } diff --git a/datafusion/physical-expr/src/window/nth_value.rs b/datafusion/physical-expr/src/window/nth_value.rs index 63a2354c9e4c..e998b47018a5 100644 --- a/datafusion/physical-expr/src/window/nth_value.rs +++ b/datafusion/physical-expr/src/window/nth_value.rs @@ -123,6 +123,20 @@ impl BuiltInWindowFunctionExpr for NthValue { fn create_evaluator(&self) -> Result> { Ok(Box::new(NthValueEvaluator { kind: self.kind })) } + + fn reverse_expr(&self) -> Option> { + let reversed_kind = match self.kind { + NthValueKind::First => NthValueKind::Last, + NthValueKind::Last => NthValueKind::First, + NthValueKind::Nth(_) => return None, + }; + Some(Arc::new(Self { + name: self.name.clone(), + expr: self.expr.clone(), + data_type: self.data_type.clone(), + kind: reversed_kind, + })) + } } /// Value evaluator for nth_value functions diff --git a/datafusion/physical-expr/src/window/ntile.rs b/datafusion/physical-expr/src/window/ntile.rs index ed00c3c86955..f5844eccc63a 100644 --- a/datafusion/physical-expr/src/window/ntile.rs +++ b/datafusion/physical-expr/src/window/ntile.rs @@ -26,7 +26,6 @@ use arrow::datatypes::Field; use arrow_schema::DataType; use datafusion_common::Result; use std::any::Any; -use std::ops::Range; use std::sync::Arc; #[derive(Debug)] @@ -70,12 +69,8 @@ pub(crate) struct NtileEvaluator { } impl PartitionEvaluator for NtileEvaluator { - fn evaluate_partition( - &self, - _values: &[ArrayRef], - partition: Range, - ) -> Result { - let num_rows = (partition.end - partition.start) as u64; + fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { + let num_rows = num_rows as u64; let mut vec: Vec = Vec::new(); for i in 0..num_rows { let res = i * self.n / num_rows; diff --git a/datafusion/physical-expr/src/window/partition_evaluator.rs b/datafusion/physical-expr/src/window/partition_evaluator.rs index 1608758d61b3..86500441df5b 100644 --- a/datafusion/physical-expr/src/window/partition_evaluator.rs +++ b/datafusion/physical-expr/src/window/partition_evaluator.rs @@ -22,23 +22,6 @@ use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use std::ops::Range; -/// Given a partition range, and the full list of sort partition points, given that the sort -/// partition points are sorted using [partition columns..., order columns...], the split -/// boundaries would align (what's sorted on [partition columns...] would definitely be sorted -/// on finer columns), so this will use binary search to find ranges that are within the -/// partition range and return the valid slice. -pub(crate) fn find_ranges_in_range<'a>( - partition_range: &Range, - sort_partition_points: &'a [Range], -) -> &'a [Range] { - let start_idx = sort_partition_points - .partition_point(|sort_range| sort_range.start < partition_range.start); - let end_idx = start_idx - + sort_partition_points[start_idx..] - .partition_point(|sort_range| sort_range.end <= partition_range.end); - &sort_partition_points[start_idx..end_idx] -} - /// Partition evaluator pub trait PartitionEvaluator { /// Whether the evaluator should be evaluated with rank @@ -50,49 +33,17 @@ pub trait PartitionEvaluator { false } - /// evaluate the partition evaluator against the partitions - fn evaluate( - &self, - values: &[ArrayRef], - partition_points: Vec>, - ) -> Result> { - partition_points - .into_iter() - .map(|partition| self.evaluate_partition(values, partition)) - .collect() - } - - /// evaluate the partition evaluator against the partitions with rank information - fn evaluate_with_rank( - &self, - partition_points: Vec>, - sort_partition_points: Vec>, - ) -> Result> { - partition_points - .into_iter() - .map(|partition| { - let ranks_in_partition = - find_ranges_in_range(&partition, &sort_partition_points); - self.evaluate_partition_with_rank(partition, ranks_in_partition) - }) - .collect() - } - /// evaluate the partition evaluator against the partition - fn evaluate_partition( - &self, - _values: &[ArrayRef], - _partition: Range, - ) -> Result { + fn evaluate(&self, _values: &[ArrayRef], _num_rows: usize) -> Result { Err(DataFusionError::NotImplemented( "evaluate_partition is not implemented by default".into(), )) } /// evaluate the partition evaluator against the partition but with rank - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - _partition: Range, + _num_rows: usize, _ranks_in_partition: &[Range], ) -> Result { Err(DataFusionError::NotImplemented( diff --git a/datafusion/physical-expr/src/window/rank.rs b/datafusion/physical-expr/src/window/rank.rs index 8ed0319a10b0..87e01528de5a 100644 --- a/datafusion/physical-expr/src/window/rank.rs +++ b/datafusion/physical-expr/src/window/rank.rs @@ -114,9 +114,9 @@ impl PartitionEvaluator for RankEvaluator { true } - fn evaluate_partition_with_rank( + fn evaluate_with_rank( &self, - partition: Range, + num_rows: usize, ranks_in_partition: &[Range], ) -> Result { // see https://www.postgresql.org/docs/current/functions-window.html @@ -132,7 +132,7 @@ impl PartitionEvaluator for RankEvaluator { )), RankType::Percent => { // Returns the relative rank of the current row, that is (rank - 1) / (total partition rows - 1). The value thus ranges from 0 to 1 inclusive. - let denominator = (partition.end - partition.start) as f64; + let denominator = num_rows as f64; Arc::new(Float64Array::from_iter_values( ranks_in_partition .iter() @@ -177,15 +177,14 @@ mod tests { fn test_f64_result( expr: &Rank, - range: Range, + num_rows: usize, ranks: Vec>, expected: Vec, ) -> Result<()> { let result = expr .create_evaluator()? - .evaluate_with_rank(vec![range], ranks)?; - assert_eq!(1, result.len()); - let result = as_float64_array(&result[0])?; + .evaluate_with_rank(num_rows, &ranks)?; + let result = as_float64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -196,11 +195,8 @@ mod tests { ranks: Vec>, expected: Vec, ) -> Result<()> { - let result = expr - .create_evaluator()? - .evaluate_with_rank(vec![0..8], ranks)?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + let result = expr.create_evaluator()?.evaluate_with_rank(8, &ranks)?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(expected, result); Ok(()) @@ -228,19 +224,19 @@ mod tests { // empty case let expected = vec![0.0; 0]; - test_f64_result(&r, 0..0, vec![0..0; 0], expected)?; + test_f64_result(&r, 0, vec![0..0; 0], expected)?; // singleton case let expected = vec![0.0]; - test_f64_result(&r, 0..1, vec![0..1], expected)?; + test_f64_result(&r, 1, vec![0..1], expected)?; // uniform case let expected = vec![0.0; 7]; - test_f64_result(&r, 0..7, vec![0..7], expected)?; + test_f64_result(&r, 7, vec![0..7], expected)?; // non-trivial case let expected = vec![0.0, 0.0, 0.0, 0.5, 0.5, 0.5, 0.5]; - test_f64_result(&r, 0..7, vec![0..3, 3..7], expected)?; + test_f64_result(&r, 7, vec![0..3, 3..7], expected)?; Ok(()) } diff --git a/datafusion/physical-expr/src/window/row_number.rs b/datafusion/physical-expr/src/window/row_number.rs index f70d9ea379dd..b27ac29d2764 100644 --- a/datafusion/physical-expr/src/window/row_number.rs +++ b/datafusion/physical-expr/src/window/row_number.rs @@ -24,7 +24,6 @@ use arrow::array::{ArrayRef, UInt64Array}; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; use std::any::Any; -use std::ops::Range; use std::sync::Arc; /// row_number expression @@ -69,12 +68,7 @@ impl BuiltInWindowFunctionExpr for RowNumber { pub(crate) struct NumRowsEvaluator {} impl PartitionEvaluator for NumRowsEvaluator { - fn evaluate_partition( - &self, - _values: &[ArrayRef], - partition: Range, - ) -> Result { - let num_rows = partition.end - partition.start; + fn evaluate(&self, _values: &[ArrayRef], num_rows: usize) -> Result { Ok(Arc::new(UInt64Array::from_iter_values( 1..(num_rows as u64) + 1, ))) @@ -99,9 +93,8 @@ mod tests { let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? - .evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + .evaluate(&values, batch.num_rows())?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); Ok(()) @@ -118,9 +111,8 @@ mod tests { let values = row_number.evaluate_args(&batch)?; let result = row_number .create_evaluator()? - .evaluate(&values, vec![0..8])?; - assert_eq!(1, result.len()); - let result = as_uint64_array(&result[0])?; + .evaluate(&values, batch.num_rows())?; + let result = as_uint64_array(&result)?; let result = result.values(); assert_eq!(vec![1, 2, 3, 4, 5, 6, 7, 8], result); Ok(()) diff --git a/datafusion/physical-expr/src/window/sliding_aggregate.rs b/datafusion/physical-expr/src/window/sliding_aggregate.rs index 9dbaca76e689..2a0fa86b7fe3 100644 --- a/datafusion/physical-expr/src/window/sliding_aggregate.rs +++ b/datafusion/physical-expr/src/window/sliding_aggregate.rs @@ -19,6 +19,7 @@ use std::any::Any; use std::iter::IntoIterator; +use std::ops::Range; use std::sync::Arc; use arrow::array::Array; @@ -30,6 +31,8 @@ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::WindowFrame; +use crate::window::window_expr::reverse_order_bys; +use crate::window::AggregateWindowExpr; use crate::{expressions::PhysicalSortExpr, PhysicalExpr}; use crate::{window::WindowExpr, AggregateExpr}; @@ -89,57 +92,48 @@ impl WindowExpr for SlidingAggregateWindowExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let partition_columns = self.partition_columns(batch)?; - let partition_points = - self.evaluate_partition_points(batch.num_rows(), &partition_columns)?; let sort_options: Vec = self.order_by.iter().map(|o| o.options).collect(); let mut row_wise_results: Vec = vec![]; - for partition_range in &partition_points { - let mut accumulator = self.aggregate.create_sliding_accumulator()?; - let length = partition_range.end - partition_range.start; - let (values, order_bys) = - self.get_values_orderbys(&batch.slice(partition_range.start, length))?; - - let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); - let mut last_range: (usize, usize) = (0, 0); - - // We iterate on each row to perform a running calculation. - // First, cur_range is calculated, then it is compared with last_range. - for i in 0..length { - let cur_range = window_frame_ctx.calculate_range( - &order_bys, - &sort_options, - length, - i, - )?; - let value = if cur_range.0 == cur_range.1 { - // We produce None if the window is empty. - ScalarValue::try_from(self.aggregate.field()?.data_type())? - } else { - // Accumulate any new rows that have entered the window: - let update_bound = cur_range.1 - last_range.1; - if update_bound > 0 { - let update: Vec = values - .iter() - .map(|v| v.slice(last_range.1, update_bound)) - .collect(); - accumulator.update_batch(&update)? - } - // Remove rows that have now left the window: - let retract_bound = cur_range.0 - last_range.0; - if retract_bound > 0 { - let retract: Vec = values - .iter() - .map(|v| v.slice(last_range.0, retract_bound)) - .collect(); - accumulator.retract_batch(&retract)? - } - accumulator.evaluate()? - }; - row_wise_results.push(value); - last_range = cur_range; - } + + let mut accumulator = self.aggregate.create_sliding_accumulator()?; + let length = batch.num_rows(); + let (values, order_bys) = self.get_values_orderbys(batch)?; + + let mut window_frame_ctx = WindowFrameContext::new(&self.window_frame); + let mut last_range = Range { start: 0, end: 0 }; + + // We iterate on each row to perform a running calculation. + // First, cur_range is calculated, then it is compared with last_range. + for i in 0..length { + let cur_range = + window_frame_ctx.calculate_range(&order_bys, &sort_options, length, i)?; + let value = if cur_range.start == cur_range.end { + // We produce None if the window is empty. + ScalarValue::try_from(self.aggregate.field()?.data_type())? + } else { + // Accumulate any new rows that have entered the window: + let update_bound = cur_range.end - last_range.end; + if update_bound > 0 { + let update: Vec = values + .iter() + .map(|v| v.slice(last_range.end, update_bound)) + .collect(); + accumulator.update_batch(&update)? + } + // Remove rows that have now left the window: + let retract_bound = cur_range.start - last_range.start; + if retract_bound > 0 { + let retract: Vec = values + .iter() + .map(|v| v.slice(last_range.start, retract_bound)) + .collect(); + accumulator.retract_batch(&retract)? + } + accumulator.evaluate()? + }; + row_wise_results.push(value); + last_range = cur_range; } ScalarValue::iter_to_array(row_wise_results.into_iter()) } @@ -155,4 +149,25 @@ impl WindowExpr for SlidingAggregateWindowExpr { fn get_window_frame(&self) -> &Arc { &self.window_frame } + + fn get_reverse_expr(&self) -> Option> { + self.aggregate.reverse_expr().map(|reverse_expr| { + let reverse_window_frame = self.window_frame.reverse(); + if reverse_window_frame.start_bound.is_unbounded() { + Arc::new(AggregateWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + )) as _ + } else { + Arc::new(SlidingAggregateWindowExpr::new( + reverse_expr, + &self.partition_by.clone(), + &reverse_order_bys(&self.order_by), + Arc::new(self.window_frame.reverse()), + )) as _ + } + }) + } } diff --git a/datafusion/physical-expr/src/window/window_expr.rs b/datafusion/physical-expr/src/window/window_expr.rs index 2fbc6e2c4c8e..a718fa4cd3b3 100644 --- a/datafusion/physical-expr/src/window/window_expr.rs +++ b/datafusion/physical-expr/src/window/window_expr.rs @@ -17,10 +17,10 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::compute::kernels::partition::lexicographical_partition_ranges; -use arrow::compute::kernels::sort::{SortColumn, SortOptions}; +use arrow::compute::kernels::sort::SortColumn; use arrow::record_batch::RecordBatch; use arrow::{array::ArrayRef, datatypes::Field}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{reverse_sort_options, DataFusionError, Result}; use datafusion_expr::WindowFrame; use std::any::Any; use std::fmt::Debug; @@ -86,31 +86,6 @@ pub trait WindowExpr: Send + Sync + Debug { /// expressions that's from the window function's order by clause, empty if absent fn order_by(&self) -> &[PhysicalSortExpr]; - /// get partition columns that can be used for partitioning, empty if absent - fn partition_columns(&self, batch: &RecordBatch) -> Result> { - self.partition_by() - .iter() - .map(|expr| { - if let Some(idx) = - self.order_by().iter().position(|key| key.expr.eq(expr)) - { - self.order_by()[idx].clone() - } else { - // When ASC is true, by default NULLS LAST to be consistent with PostgreSQL's rule: - // https://www.postgresql.org/docs/current/queries-order.html - PhysicalSortExpr { - expr: expr.clone(), - options: SortOptions { - descending: false, - nulls_first: false, - }, - } - } - .evaluate_to_sort_column(batch) - }) - .collect() - } - /// get order by columns, empty if absent fn order_by_columns(&self, batch: &RecordBatch) -> Result> { self.order_by() @@ -121,10 +96,8 @@ pub trait WindowExpr: Send + Sync + Debug { /// get sort columns that can be used for peer evaluation, empty if absent fn sort_columns(&self, batch: &RecordBatch) -> Result> { - let mut sort_columns = self.partition_columns(batch)?; let order_by_columns = self.order_by_columns(batch)?; - sort_columns.extend(order_by_columns); - Ok(sort_columns) + Ok(order_by_columns) } /// Get values columns(argument of Window Function) @@ -140,6 +113,22 @@ pub trait WindowExpr: Send + Sync + Debug { Ok((values, order_bys)) } - // Get window frame of this WindowExpr, None if absent + /// Get the window frame of this [WindowExpr]. fn get_window_frame(&self) -> &Arc; + + /// Get the reverse expression of this [WindowExpr]. + fn get_reverse_expr(&self) -> Option>; +} + +/// Reverses the ORDER BY expression, which is useful during equivalent window +/// expression construction. For instance, 'ORDER BY a ASC, NULLS LAST' turns into +/// 'ORDER BY a DESC, NULLS FIRST'. +pub fn reverse_order_bys(order_bys: &[PhysicalSortExpr]) -> Vec { + order_bys + .iter() + .map(|e| PhysicalSortExpr { + expr: e.expr.clone(), + options: reverse_sort_options(e.options), + }) + .collect() } diff --git a/datafusion/physical-expr/src/window/window_frame_state.rs b/datafusion/physical-expr/src/window/window_frame_state.rs index 307ea91440df..b49bd3a22a78 100644 --- a/datafusion/physical-expr/src/window/window_frame_state.rs +++ b/datafusion/physical-expr/src/window/window_frame_state.rs @@ -26,6 +26,7 @@ use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits}; use std::cmp::min; use std::collections::VecDeque; use std::fmt::Debug; +use std::ops::Range; use std::sync::Arc; /// This object stores the window frame state for use in incremental calculations. @@ -68,7 +69,7 @@ impl<'a> WindowFrameContext<'a> { sort_options: &[SortOptions], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { match *self { WindowFrameContext::Rows(window_frame) => { Self::calculate_range_rows(window_frame, length, idx) @@ -99,7 +100,7 @@ impl<'a> WindowFrameContext<'a> { window_frame: &Arc, length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { let start = match window_frame.start_bound { // UNBOUNDED PRECEDING WindowFrameBound::Preceding(ScalarValue::UInt64(None)) => 0, @@ -152,7 +153,7 @@ impl<'a> WindowFrameContext<'a> { return Err(DataFusionError::Internal("Rows should be Uint".to_string())) } }; - Ok((start, end)) + Ok(Range { start, end }) } } @@ -171,7 +172,7 @@ impl WindowFrameStateRange { sort_options: &[SortOptions], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { let start = match window_frame.start_bound { WindowFrameBound::Preceding(ref n) => { if n.is_null() { @@ -240,7 +241,7 @@ impl WindowFrameStateRange { } } }; - Ok((start, end)) + Ok(Range { start, end }) } /// This function does the heavy lifting when finding range boundaries. It is meant to be @@ -333,7 +334,7 @@ impl WindowFrameStateGroups { range_columns: &[ArrayRef], length: usize, idx: usize, - ) -> Result<(usize, usize)> { + ) -> Result> { if range_columns.is_empty() { return Err(DataFusionError::Execution( "GROUPS mode requires an ORDER BY clause".to_string(), @@ -399,7 +400,7 @@ impl WindowFrameStateGroups { )) } }; - Ok((start, end)) + Ok(Range { start, end }) } /// This function does the heavy lifting when finding group boundaries. It is meant to be