Skip to content

Commit

Permalink
fix: Make input_sorted_by_group_key tolerate unrelated single value columns outside the group key prefix (#171)
Browse files Browse the repository at this point in the history

* fix: Make input_sorted_by_group_key tolerate unrelated single value columns outside the group key prefix

* Adds basic output_hints to SortExec actually just for the sake of test code

* Adds hash_agg_aggregation_strategy_with_nongrouped_single_value_columns_in_sort_key test
  • Loading branch information
srh authored Oct 25, 2024
1 parent c29478e commit b3acc9f
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 18 deletions.
86 changes: 79 additions & 7 deletions datafusion/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1716,16 +1716,21 @@ fn input_sorted_by_group_key(
}
sort_to_group[sort_key_pos] = group_i;
}
for i in 0..sort_key.len() {
if hints.single_value_columns.contains(&sort_key[i]) {
sort_key_hit[i] = true;

// At this point all elements of the group key mapped into some column of the sort key. This
// checks the group key is mapped into a prefix of the sort key, except that it's okay if it
// skips over single value columns.
let mut pref_len: usize = 0;
for (i, hit) in sort_key_hit.iter().enumerate() {
if !hit && !hints.single_value_columns.contains(&sort_key[i]) {
break;
}
pref_len += 1;
}

// At this point all elements of the group key mapped into some column of the sort key.
// This checks the group key is mapped into a prefix of the sort key.
let pref_len = sort_key_hit.iter().take_while(|present| **present).count();
if sort_key_hit[pref_len..].iter().any(|present| *present) {
// The group key did not hit a contiguous prefix of the sort key (ignoring single value
// columns); return false.
return false;
}

Expand Down Expand Up @@ -1753,7 +1758,8 @@ fn tuple_err<T, R>(value: (Result<T>, Result<R>)) -> Result<(T, R)> {
#[cfg(test)]
mod tests {
use super::*;
use crate::logical_plan::{DFField, DFSchema, DFSchemaRef};
use crate::logical_plan::{and, DFField, DFSchema, DFSchemaRef};
use crate::physical_plan::OptimizerHints;
use crate::physical_plan::{csv::CsvReadOptions, expressions, Partitioning};
use crate::scalar::ScalarValue;
use crate::{
Expand Down Expand Up @@ -2036,6 +2042,72 @@ mod tests {
Ok(())
}

#[test]
fn hash_agg_aggregation_strategy_with_nongrouped_single_value_columns_in_sort_key(
) -> Result<()> {
    let testdata = crate::test_util::arrow_test_data();
    let csv_path = format!("{}/csv/aggregate_test_100.csv", testdata);
    let read_options = CsvReadOptions::new().schema_infer_max_records(100);

    // Ascending, nulls-first sort expression for the named column.
    let asc = |name: &str| col(name).sort(true, true);

    // Rather than mocking an ExecutionPlan, build a real input plan whose
    // output_hints() exhibit the situation under test.
    let logical_plan = LogicalPlanBuilder::scan_csv(csv_path, read_options, None)?
        .filter(and(
            col("c4").eq(lit("value_a")),
            col("c8").eq(lit("value_b")),
        ))?
        .sort(
            ["c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8"]
                .iter()
                .map(|&name| asc(name))
                .collect::<Vec<_>>(),
        )?
        .build()?;

    let execution_plan = plan(&logical_plan)?;

    // Both single value columns (c4 and c8) are part of the sort key, but only
    // one of them will be part of the group key below.
    let hints: OptimizerHints = execution_plan.output_hints();
    assert_eq!(hints.sort_order, Some(vec![0, 1, 2, 3, 4, 5, 6, 7]));
    assert_eq!(hints.single_value_columns, vec![3, 7]);

    // The group key overlaps one single_value_column (c4), while the other
    // single value column 7 has columns 5 and 6 ("c6" and "c7" respectively)
    // sitting between it and the group key prefix.
    let group_key = vec![col("c1"), col("c2"), col("c3"), col("c4"), col("c5")];

    let mut ctx_state = make_ctx_state();
    ctx_state.config.concurrency = 4;
    let planner = DefaultPhysicalPlanner::default();

    let physical_group_key = group_key
        .iter()
        .map(|expr| {
            planner
                .create_physical_expr(
                    expr,
                    &logical_plan.schema(),
                    &execution_plan.schema(),
                    &ctx_state,
                )
                .map(|phys_expr| (phys_expr, "".to_owned()))
        })
        .collect::<Result<Vec<_>>>()?;

    let mut sort_order = Vec::<usize>::new();
    let is_sorted: bool = input_sorted_by_group_key(
        execution_plan.as_ref(),
        &physical_group_key,
        &mut sort_order,
    );
    assert!(is_sorted);
    assert_eq!(sort_order, vec![0, 1, 2, 3, 4]);

    Ok(())
}

#[test]
fn test_explain() {
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
Expand Down
41 changes: 30 additions & 11 deletions datafusion/src/physical_plan/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

//! Defines the SORT plan
use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::cube_ext;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::expressions::{Column, PhysicalSortExpr};
use crate::physical_plan::{
common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, SQLMetric,
};
use crate::physical_plan::{
OptimizerHints, RecordBatchStream, SendableRecordBatchStream,
};
pub use arrow::compute::SortOptions;
use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions};
use arrow::datatypes::SchemaRef;
Expand Down Expand Up @@ -186,15 +188,32 @@ impl ExecutionPlan for SortExec {
metrics
}

// TODO
// fn output_sort_order(&self) -> Result<Option<Vec<usize>>> {
// let mut order = Vec::with_capacity(self.expr.len());
// for s in &self.expr {
// let col = s.expr.as_any().downcast_ref::<Column>()?;
// order.push(self.schema().index_of(col.name())?);
// }
// Ok(Some(order))
// }
fn output_hints(&self) -> OptimizerHints {
    // Map the sort expressions to output column indices. We can only report a
    // sort order for the leading run of sort expressions that are plain column
    // references; the first non-column expression truncates the reported order.
    let schema = self.schema();
    let mut order = Vec::with_capacity(self.expr.len());
    for s in &self.expr {
        let column = match s.expr.as_any().downcast_ref::<Column>() {
            Some(column) => column,
            // Not a bare column reference; keep only the prefix found so far.
            None => break,
        };
        match schema.index_of(column.name()) {
            Ok(index) => order.push(index),
            // A sort column absent from the output schema means we cannot make
            // any claim about the output order.
            Err(_) => return OptimizerHints::default(),
        }
    }

    // Sorting does not change which columns hold a single value, so pass the
    // input's single-value hint through unchanged.
    let input_hints = self.input.output_hints();
    // TODO: When no sort expression truncated `order` above, we could also
    // append the input's sort_order columns that follow the sort key. Do this.

    OptimizerHints {
        sort_order: Some(order),
        single_value_columns: input_hints.single_value_columns,
    }
}
}

#[tracing::instrument(level = "trace", skip(batch, schema, expr))]
Expand Down

0 comments on commit b3acc9f

Please sign in to comment.