Materialize dictionaries in group keys #8291

Merged
15 changes: 3 additions & 12 deletions datafusion/core/tests/path_partition.rs
@@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> {
assert_eq!(min_limit, resulting_limit);

let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
-let month = match extract_as_utf(&s) {
-    Some(month) => month,
-    s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
+let month = match s {
+    ScalarValue::Utf8(Some(month)) => month,
+    s => panic!("Expected month as Utf8 found {s:?}"),
};

let sql_on_partition_boundary = format!(
@@ -191,15 +191,6 @@ async fn parquet_distinct_partition_col() -> Result<()> {
Ok(())
}

-fn extract_as_utf(v: &ScalarValue) -> Option<String> {
-    if let ScalarValue::Dictionary(_, v) = v {
-        if let ScalarValue::Utf8(v) = v.as_ref() {
-            return v.clone();
-        }
-    }
-    None
-}

#[tokio::test]
async fn csv_filter_with_file_col() -> Result<()> {
let ctx = SessionContext::new();
27 changes: 4 additions & 23 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,22 +17,18 @@

use crate::aggregates::group_values::GroupValues;
use ahash::RandomState;
-use arrow::compute::cast;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::{Array, ArrayRef};
-use arrow_schema::{DataType, SchemaRef};
+use arrow_array::ArrayRef;
+use arrow_schema::SchemaRef;
use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_physical_expr::EmitTo;
use hashbrown::raw::RawTable;

/// A [`GroupValues`] making use of [`Rows`]
pub struct GroupValuesRows {
-/// The output schema
-schema: SchemaRef,
-
/// Converter for the group values
row_converter: RowConverter,

@@ -79,7 +75,6 @@ impl GroupValuesRows {
let map = RawTable::with_capacity(0);

Ok(Self {
-schema,
row_converter,
map,
map_size: 0,
@@ -170,7 +165,7 @@ impl GroupValues for GroupValuesRows {
.take()
.expect("Can not emit from empty rows");

-let mut output = match emit_to {
+let output = match emit_to {
EmitTo::All => {
let output = self.row_converter.convert_rows(&group_values)?;
group_values.clear();
Expand Down Expand Up @@ -203,20 +198,6 @@ impl GroupValues for GroupValuesRows {
}
};

-// TODO: Materialize dictionaries in group keys (#7647)
-for (field, array) in self.schema.fields.iter().zip(&mut output) {
-    let expected = field.data_type();
-    if let DataType::Dictionary(_, v) = expected {
-        let actual = array.data_type();
-        if v.as_ref() != actual {
-            return Err(DataFusionError::Internal(format!(
-                "Converted group rows expected dictionary of {v} got {actual}"
-            )));
-        }
-        *array = cast(array.as_ref(), expected)?;
-    }
-}
-
self.group_values = Some(group_values);
Ok(output)
}
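The cast-back loop removed above was only needed while the aggregate's output schema still advertised dictionary types, because `RowConverter::convert_rows` emits the underlying value types. A minimal standalone sketch of that behavior (a hypothetical example, assuming a recent arrow-rs in which the row format no longer preserves dictionary encoding):

```rust
use std::sync::Arc;

use arrow::row::{RowConverter, SortField};
use arrow_array::{types::Int32Type, ArrayRef, DictionaryArray};
use arrow_schema::DataType;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A dictionary-encoded string column, like a partitioned "month" key.
    let dict: DictionaryArray<Int32Type> =
        vec!["2021", "2021", "2022"].into_iter().collect();
    let col: ArrayRef = Arc::new(dict);

    // Converter keyed on the column's dictionary type.
    let converter = RowConverter::new(vec![SortField::new(col.data_type().clone())])?;
    let rows = converter.convert_columns(&[col])?;

    // Converting the rows back yields the value type (Utf8),
    // not the original dictionary encoding.
    let back = converter.convert_rows(&rows)?;
    assert_eq!(back[0].data_type(), &DataType::Utf8);
    Ok(())
}
```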
31 changes: 29 additions & 2 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -38,6 +38,7 @@ use crate::{
use arrow::array::ArrayRef;
use arrow::datatypes::{Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
+use arrow_schema::DataType;
use datafusion_common::stats::Precision;
use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
use datafusion_execution::TaskContext;
@@ -286,6 +287,9 @@ pub struct AggregateExec {
limit: Option<usize>,
/// Input plan, could be a partial aggregate or the input to the aggregate
pub input: Arc<dyn ExecutionPlan>,
+/// Original aggregation schema; can differ from `schema` in that dictionary
+/// group keys have not yet been materialized
+original_schema: SchemaRef,
/// Schema after the aggregate is applied
schema: SchemaRef,
/// Input schema before any aggregation is applied. For partial aggregate this will be the
@@ -469,15 +473,19 @@ impl AggregateExec {
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
-let schema = create_schema(
+let original_schema = create_schema(
&input.schema(),
&group_by.expr,
&aggr_expr,
group_by.contains_null(),
mode,
)?;

-let schema = Arc::new(schema);
+let schema = Arc::new(materialize_dict_group_keys(
+    &original_schema,
+    group_by.expr.len(),
+));
+let original_schema = Arc::new(original_schema);
// Reset ordering requirement to `None` if aggregator is not order-sensitive
order_by_expr = aggr_expr
.iter()
@@ -552,6 +560,7 @@ impl AggregateExec {
filter_expr,
order_by_expr,
input,
+original_schema,
schema,
input_schema,
projection_mapping,
@@ -973,6 +982,24 @@ fn create_schema(
Ok(Schema::new(fields))
}

+/// Returns schema with dictionary group keys materialized as their value types.
+/// The actual conversion happens in `RowConverter`, and we don't do an
+/// unnecessary conversion back into dictionaries.
+fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema {

Contributor (alamb) commented on this function:

    GroupValues is a trait and there are two implementations of it -- GroupValuesRows and GroupValuesPrimitive.

    This conversion only applies to GroupValuesRows (though GroupValuesPrimitive won't be used for input dictionary columns anyway).

    I wonder if there is some way to move the conversion into GroupValuesRows so it is clearer from the context when this is needed or not.

Contributor Author replied:

    I'll take a look.

Contributor Author replied:

    @alamb I do understand your point and I saw this when creating my PR, but I don't quite see a good way around it: the problem is that GroupValues gets created only during execution, in AggregateExec::execute_typed, but the correct ("materialized") schema looks to be needed in AggregateExec itself, i.e. already at the planning stage (as part of impl ExecutionPlan for AggregateExec).
    Maybe I miss some option here?

+    let fields = schema
+        .fields
+        .iter()
+        .enumerate()
+        .map(|(i, field)| match field.data_type() {
+            DataType::Dictionary(_, value_data_type) if i < group_count => {
+                Field::new(field.name(), *value_data_type.clone(), field.is_nullable())
+            }
+            _ => Field::clone(field),
+        })
+        .collect::<Vec<_>>();
+    Schema::new(fields)
+}

fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
let group_fields = schema.fields()[0..group_count].to_vec();
Arc::new(Schema::new(group_fields))
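A usage sketch of the new helper (hypothetical test code, assuming `materialize_dict_group_keys` is in scope): only fields in the first `group_count` positions have their dictionary type replaced by the value type.

```rust
use arrow_schema::{DataType, Field, Schema};

fn demo_materialize() {
    // One dictionary-encoded group key followed by one aggregate column.
    let schema = Schema::new(vec![
        Field::new(
            "month",
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
            false,
        ),
        Field::new("cnt", DataType::Int64, true),
    ]);

    // group_count = 1: only the group key is materialized to Utf8.
    let materialized = materialize_dict_group_keys(&schema, 1);
    assert_eq!(materialized.field(0).data_type(), &DataType::Utf8);
    // Columns past group_count keep their original type.
    assert_eq!(materialized.field(1).data_type(), &DataType::Int64);
}
```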
4 changes: 3 additions & 1 deletion datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -324,7 +324,9 @@ impl GroupedHashAggregateStream {
.map(create_group_accumulator)
.collect::<Result<_>>()?;

-let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+// we need to use the original schema so the RowConverter in group_values below
+// will do the proper conversion of dictionaries into value types
+let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len());
let spill_expr = group_schema
.fields
.into_iter()
9 changes: 9 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -2421,6 +2421,15 @@ select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
4
5

+query T
+select arrow_typeof(x_dict) from value_dict group by x_dict;
+----
+Int32
+Int32
+Int32
+Int32
+Int32

statement ok
drop table value
