Skip to content

Commit

Permalink
physical-plan: Cast nested group values back to dictionary if necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
brancz committed Sep 23, 2024
1 parent 3b93cc9 commit 4adf217
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 3 deletions.
98 changes: 97 additions & 1 deletion datafusion/physical-plan/src/aggregates/group_values/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ use ahash::RandomState;
use arrow::compute::cast;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::{Array, ArrayRef};
use arrow_array::{Array, ArrayRef, ListArray, StructArray};
use arrow_schema::{DataType, SchemaRef};
use datafusion_common::hash_utils::create_hashes;
use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use hashbrown::raw::RawTable;
use std::sync::Arc;

/// A [`GroupValues`] making use of [`Rows`]
pub struct GroupValuesRows {
Expand Down Expand Up @@ -230,6 +231,11 @@ impl GroupValues for GroupValuesRows {
}
*array = cast(array.as_ref(), expected)?;
}

if expected.is_nested() && needs_nested_dictionary_encoding(expected, array)?
{
*array = dictionary_encode_nested(array.clone(), expected)?;
}
}

self.group_values = Some(group_values);
Expand All @@ -249,3 +255,93 @@ impl GroupValues for GroupValuesRows {
self.hashes_buffer.shrink_to(count);
}
}

fn needs_nested_dictionary_encoding(
expected: &DataType,
actual: &ArrayRef,
) -> Result<bool> {
match (expected, actual.data_type()) {
(
&DataType::Struct(ref expected_fields),
&DataType::Struct(ref actual_fields),
) => {
if expected_fields.len() != actual_fields.len() {
return Err(DataFusionError::Internal(format!(
"Converted group rows expected struct of {} fields got {}",
expected_fields.len(),
actual_fields.len(),
)));
}

let actual_struct = actual.as_any().downcast_ref::<StructArray>().unwrap();
Ok(expected_fields
.iter()
.zip(actual_struct.columns().iter())
.map(|(expected_field, actual_column)| {
// Propagate the result of needs_nested_dictionary_encoding
needs_nested_dictionary_encoding(
expected_field.data_type(),
actual_column,
)
})
.try_fold(false, |acc, needs_nested| {
Ok::<bool, DataFusionError>(acc || needs_nested?)
})?)
}
(&DataType::List(ref expected_field), &DataType::List(_)) => {
let actual_list = actual.as_any().downcast_ref::<ListArray>().unwrap();
needs_nested_dictionary_encoding(
expected_field.data_type(),
actual_list.values(),
)
}
(&DataType::Dictionary(_, ref value), _) => {
let actual_data_type = actual.data_type();
let value_data_type = value.as_ref();
if value_data_type != actual_data_type {
return Err(DataFusionError::Internal(format!(
"Converted group rows expected dictionary of {value_data_type} got {actual_data_type}"
)));
}

Ok(true)
}
(_, _) => Ok(false),
}
}

fn dictionary_encode_nested(array: ArrayRef, expected: &DataType) -> Result<ArrayRef> {
match (expected, array.data_type()) {
(&DataType::Struct(ref expected_fields), _) => {
let struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
let arrays = expected_fields
.iter()
.zip(struct_array.columns())
.map(|(expected_field, column)| {
dictionary_encode_nested(column.clone(), expected_field.data_type())
})
.collect::<Result<Vec<_>>>()?;

Ok(Arc::new(StructArray::try_new(
expected_fields.clone(),
arrays,
struct_array.nulls().cloned(),
)?))
}
(&DataType::List(ref expected_field), &DataType::List(_)) => {
let list = array.as_any().downcast_ref::<ListArray>().unwrap();

Ok(Arc::new(ListArray::try_new(
expected_field.clone(),
list.offsets().clone(),
dictionary_encode_nested(
list.values().clone(),
expected_field.data_type(),
)?,
list.nulls().cloned(),
)?))
}
(&DataType::Dictionary(_, _), _) => Ok(cast(array.as_ref(), expected)?),
(_, _) => Ok(array.clone()),
}
}
128 changes: 126 additions & 2 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1200,8 +1200,10 @@ mod tests {

use arrow::array::{Float64Array, UInt32Array};
use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::DataType;
use arrow_array::{Float32Array, Int32Array};
use arrow::datatypes::{DataType, Int32Type};
use arrow_array::{
DictionaryArray, Float32Array, Int32Array, StructArray, UInt64Array,
};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
ScalarValue,
Expand All @@ -1214,6 +1216,7 @@ mod tests {
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::first_last::{first_value_udaf, last_value_udaf};
use datafusion_functions_aggregate::median::median_udaf;
use datafusion_functions_aggregate::sum::sum_udaf;
use datafusion_physical_expr::expressions::lit;
use datafusion_physical_expr::PhysicalSortExpr;

Expand Down Expand Up @@ -2316,6 +2319,127 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_agg_exec_struct_of_dicts() -> Result<()> {
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![
Field::new(
"labels".to_string(),
DataType::Struct(
vec![
Field::new_dict(
"a".to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
),
Field::new_dict(
"b".to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
),
]
.into(),
),
false,
),
Field::new("value", DataType::UInt64, false),
])),
vec![
Arc::new(StructArray::from(vec![
(
Arc::new(Field::new_dict(
"a".to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
)),
Arc::new(
vec![Some("a"), None, Some("a")]
.into_iter()
.collect::<DictionaryArray<Int32Type>>(),
) as ArrayRef,
),
(
Arc::new(Field::new_dict(
"b".to_string(),
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Utf8),
),
true,
0,
false,
)),
Arc::new(
vec![Some("b"), Some("c"), Some("b")]
.into_iter()
.collect::<DictionaryArray<Int32Type>>(),
) as ArrayRef,
),
])),
Arc::new(UInt64Array::from(vec![1, 1, 1])),
],
)
.expect("Failed to create RecordBatch");

let group_by = PhysicalGroupBy::new_single(vec![(
col("labels", &batch.schema())?,
"labels".to_string(),
)]);

let aggr_expr = vec![AggregateExprBuilder::new(
sum_udaf(),
vec![col("value", &batch.schema())?],
)
.schema(Arc::clone(&batch.schema()))
.alias(String::from("SUM(value)"))
.build()?];

let input = Arc::new(MemoryExec::try_new(
&[vec![batch.clone()]],
Arc::clone(&batch.schema().clone()),
None,
)?);
let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::FinalPartitioned,
group_by,
aggr_expr,
vec![None],
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
batch.schema(),
)?);

let session_config = SessionConfig::default();
let ctx = TaskContext::default().with_session_config(session_config);
let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

let expected = [
"+--------------+------------+",
"| labels | SUM(value) |",
"+--------------+------------+",
"| {a: a, b: b} | 2 |",
"| {a: , b: c} | 1 |",
"+--------------+------------+",
];
assert_batches_eq!(expected, &output);

Ok(())
}

#[tokio::test]
async fn test_skip_aggregation_after_first_batch() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Expand Down

0 comments on commit 4adf217

Please sign in to comment.