Skip to content

Commit

Permalink
AggregateExec: Take grouping sets into account for InputOrderMode (#1…
Browse files Browse the repository at this point in the history
…1301)

* AggregateExec: Take grouping sets into account for InputOrderMode

* pr comments
  • Loading branch information
thinkharderdev authored Jul 7, 2024
1 parent 08c5345 commit e693ed7
Showing 1 changed file with 113 additions and 8 deletions.
121 changes: 113 additions & 8 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,14 +369,26 @@ impl AggregateExec {
new_requirement.extend(req);
new_requirement = collapse_lex_req(new_requirement);

let input_order_mode =
if indices.len() == groupby_exprs.len() && !indices.is_empty() {
InputOrderMode::Sorted
} else if !indices.is_empty() {
InputOrderMode::PartiallySorted(indices)
} else {
InputOrderMode::Linear
};
// If our aggregation has grouping sets then our base grouping exprs will
// be expanded based on the flags in `group_by.groups` where for each
// group we swap the grouping expr for `null` if the flag is `true`
// That means that each index in `indices` is valid if and only if
// it is not null in every group
let indices: Vec<usize> = indices
.into_iter()
.filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
.collect();

let input_order_mode = if indices.len() == groupby_exprs.len()
&& !indices.is_empty()
&& group_by.groups.len() == 1
{
InputOrderMode::Sorted
} else if !indices.is_empty() {
InputOrderMode::PartiallySorted(indices)
} else {
InputOrderMode::Linear
};

// construct a map from the input expression to the output expression of the Aggregation group by
let projection_mapping =
Expand Down Expand Up @@ -1180,6 +1192,7 @@ mod tests {
use arrow::array::{Float64Array, UInt32Array};
use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::DataType;
use arrow_array::{Float32Array, Int32Array};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
ScalarValue,
Expand All @@ -1195,7 +1208,9 @@ mod tests {
use datafusion_physical_expr::expressions::{lit, OrderSensitiveArrayAgg};
use datafusion_physical_expr::PhysicalSortExpr;

use crate::common::collect;
use datafusion_physical_expr_common::aggregate::create_aggregate_expr;
use datafusion_physical_expr_common::expressions::Literal;
use futures::{FutureExt, Stream};

// Generate a schema which consists of 5 columns (a, b, c, d, e)
Expand Down Expand Up @@ -2267,4 +2282,94 @@ mod tests {
assert_eq!(new_agg.schema(), aggregate_exec.schema());
Ok(())
}

#[tokio::test]
async fn test_agg_exec_group_by_const() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, true),
Field::new("b", DataType::Float32, true),
Field::new("const", DataType::Int32, false),
]));

let col_a = col("a", &schema)?;
let col_b = col("b", &schema)?;
let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

let groups = PhysicalGroupBy::new(
vec![
(col_a, "a".to_string()),
(col_b, "b".to_string()),
(const_expr, "const".to_string()),
],
vec![
(
Arc::new(Literal::new(ScalarValue::Float32(None))),
"a".to_string(),
),
(
Arc::new(Literal::new(ScalarValue::Float32(None))),
"b".to_string(),
),
(
Arc::new(Literal::new(ScalarValue::Int32(None))),
"const".to_string(),
),
],
vec![
vec![false, true, true],
vec![true, false, true],
vec![true, true, false],
],
);

let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
count_udaf().as_ref(),
&[lit(1)],
&[datafusion_expr::lit(1)],
&[],
&[],
schema.as_ref(),
"1",
false,
false,
)?];

let input_batches = (0..4)
.map(|_| {
let a = Arc::new(Float32Array::from(vec![0.; 8192]));
let b = Arc::new(Float32Array::from(vec![0.; 8192]));
let c = Arc::new(Int32Array::from(vec![1; 8192]));

RecordBatch::try_new(schema.clone(), vec![a, b, c]).unwrap()
})
.collect();

let input =
Arc::new(MemoryExec::try_new(&[input_batches], schema.clone(), None)?);

let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups,
aggregates.clone(),
vec![None],
input,
schema,
)?);

let output =
collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;

let expected = [
"+-----+-----+-------+----------+",
"| a | b | const | 1[count] |",
"+-----+-----+-------+----------+",
"| | 0.0 | | 32768 |",
"| 0.0 | | | 32768 |",
"| | | 1 | 32768 |",
"+-----+-----+-------+----------+",
];
assert_batches_sorted_eq!(expected, &output);

Ok(())
}
}

0 comments on commit e693ed7

Please sign in to comment.