Skip to content

Commit

Permalink
physical-plan: Cast nested group values back to dictionary if necessa…
Browse files Browse the repository at this point in the history
…ry (apache#12586)
  • Loading branch information
brancz authored and bgjackma committed Sep 25, 2024
1 parent ced28fd commit 350e3f8
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 13 deletions.
60 changes: 49 additions & 11 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ use ahash::RandomState;
use arrow::compute::cast;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::{Array, ArrayRef};
use arrow_array::{Array, ArrayRef, ListArray, StructArray};
use arrow_schema::{DataType, SchemaRef};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::{DataFusionError, Result};
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use hashbrown::raw::RawTable;
use std::sync::Arc;

/// A [`GroupValues`] making use of [`Rows`]
pub struct GroupValuesRows {
Expand Down Expand Up @@ -221,15 +222,10 @@ impl GroupValues for GroupValuesRows {
// TODO: Materialize dictionaries in group keys (#7647)
for (field, array) in self.schema.fields.iter().zip(&mut output) {
let expected = field.data_type();
if let DataType::Dictionary(_, v) = expected {
let actual = array.data_type();
if v.as_ref() != actual {
return Err(DataFusionError::Internal(format!(
"Converted group rows expected dictionary of {v} got {actual}"
)));
}
*array = cast(array.as_ref(), expected)?;
}
*array = dictionary_encode_if_necessary(
Arc::<dyn arrow_array::Array>::clone(array),
expected,
)?;
}

self.group_values = Some(group_values);
Expand All @@ -249,3 +245,45 @@ impl GroupValues for GroupValuesRows {
self.hashes_buffer.shrink_to(count);
}
}

/// Restores dictionary encoding on `array` wherever `expected` calls for it.
///
/// `expected` and `array` are walked in parallel:
///
/// * `Struct` expected, `Struct` actual: every child column is processed
///   recursively and the struct is rebuilt with the expected fields, keeping
///   the original null buffer.
/// * `List` expected, `List` actual: the values array is processed
///   recursively and the list is rebuilt with the expected item field, keeping
///   the original offsets and null buffer.
/// * `Dictionary` expected: `array` is cast to `expected`.
/// * Any other combination: `array` is returned unchanged.
///
/// # Errors
///
/// Propagates failures from `cast`, `StructArray::try_new` and
/// `ListArray::try_new`.
fn dictionary_encode_if_necessary(
    array: ArrayRef,
    expected: &DataType,
) -> Result<ArrayRef> {
    match (expected, array.data_type()) {
        // Only recurse when the actual array is a struct as well, mirroring
        // the `List` arm below. A mismatched array falls through to the
        // catch-all arm instead of panicking on a failed downcast.
        (DataType::Struct(expected_fields), DataType::Struct(_)) => {
            let Some(struct_array) = array.as_any().downcast_ref::<StructArray>()
            else {
                return Ok(array);
            };

            // Re-encode each child against the field the schema expects.
            let arrays = expected_fields
                .iter()
                .zip(struct_array.columns())
                .map(|(expected_field, column)| {
                    dictionary_encode_if_necessary(
                        Arc::<dyn arrow_array::Array>::clone(column),
                        expected_field.data_type(),
                    )
                })
                .collect::<Result<Vec<_>>>()?;

            Ok(Arc::new(StructArray::try_new(
                expected_fields.clone(),
                arrays,
                struct_array.nulls().cloned(),
            )?))
        }
        (DataType::List(expected_field), DataType::List(_)) => {
            let Some(list) = array.as_any().downcast_ref::<ListArray>() else {
                return Ok(array);
            };

            // Offsets and validity are reused; only the values may change type.
            Ok(Arc::new(ListArray::try_new(
                Arc::<arrow_schema::Field>::clone(expected_field),
                list.offsets().clone(),
                dictionary_encode_if_necessary(
                    Arc::<dyn arrow_array::Array>::clone(list.values()),
                    expected_field.data_type(),
                )?,
                list.nulls().cloned(),
            )?))
        }
        (DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
        // `array` is owned, so hand it back directly instead of cloning the Arc.
        (_, _) => Ok(array),
    }
}
128 changes: 126 additions & 2 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1213,8 +1213,10 @@ mod tests {

use arrow::array::{Float64Array, UInt32Array};
use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::DataType;
use arrow_array::{Float32Array, Int32Array};
use arrow::datatypes::{DataType, Int32Type};
use arrow_array::{
DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array,
};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
ScalarValue,
Expand All @@ -1227,6 +1229,7 @@ mod tests {
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
use datafusion_functions_aggregate::median::median_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::lit;
use datafusion_physical_expr::PhysicalSortExpr;

Expand Down Expand Up @@ -2329,6 +2332,127 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_agg_exec_struct_of_dicts() -> Result<()> {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new(
"labels".to_string(),
DataType::Struct(
vec![
Field::new_dict(
"a".to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
),
Field::new_dict(
"b".to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
),
]
.into(),
),
false,
),
Field::new("value", DataType::UInt64, false),
])),
vec![
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new_dict(
"a".to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
)),
Arc::new(
vec![Some("a"), None, Some("a")]
.into_iter()
.collect::<DictionaryArray<Int32Type>>(),
) as ArrayRef,
),
(
Arc::new(Field::new_dict(
"b".to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
)),
Arc::new(
vec![Some("b"), Some("c"), Some("b")]
.into_iter()
.collect::<DictionaryArray<Int32Type>>(),
) as ArrayRef,
),
])),
Arc::new(UInt64Array::from(vec![1, 1, 1])),
],
)
.expect("Failed to create RecordBatch");

let group_by = PhysicalGroupBy::new_single(vec![(
col("labels", &batch.schema())?,
"labels".to_string(),
)]);

let aggr_expr = vec![AggregateExprBuilder::new(
sum_udaf(),
vec![col("value", &batch.schema())?],
)
.schema(Arc::clone(&batch.schema()))
.alias(String::from("SUM(value)"))
.build()?];

let input = Arc::new(MemoryExec::try_new(
&[vec![batch.clone()]],
Arc::<arrow_schema::Schema>::clone(&batch.schema()),
None,
)?);
let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::FinalPartitioned,
group_by,
aggr_expr,
vec![None],
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
batch.schema(),
)?);

let session_config = SessionConfig::default();
let ctx = TaskContext::default().with_session_config(session_config);
let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

let expected = [
"+--------------+------------+",
"| labels | SUM(value) |",
"+--------------+------------+",
"| {a: a, b: b} | 2 |",
"| {a: , b: c} | 1 |",
"+--------------+------------+",
];
assert_batches_eq!(expected, &output);

Ok(())
}

#[tokio::test]
async fn test_skip_aggregation_after_first_batch() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Expand Down

0 comments on commit 350e3f8

Please sign in to comment.