diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs
index dd8eb52f67c71..abe6ab283aff4 100644
--- a/datafusion/core/tests/path_partition.rs
+++ b/datafusion/core/tests/path_partition.rs
@@ -168,9 +168,9 @@ async fn parquet_distinct_partition_col() -> Result<()> {
     assert_eq!(min_limit, resulting_limit);
 
     let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
-    let month = match extract_as_utf(&s) {
-        Some(month) => month,
-        s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
+    let month = match s {
+        ScalarValue::Utf8(Some(month)) => month,
+        s => panic!("Expected month as Utf8 found {s:?}"),
     };
 
     let sql_on_partition_boundary = format!(
@@ -191,15 +191,6 @@ async fn parquet_distinct_partition_col() -> Result<()> {
     Ok(())
 }
 
-fn extract_as_utf(v: &ScalarValue) -> Option<String> {
-    if let ScalarValue::Dictionary(_, v) = v {
-        if let ScalarValue::Utf8(v) = v.as_ref() {
-            return v.clone();
-        }
-    }
-    None
-}
-
 #[tokio::test]
 async fn csv_filter_with_file_col() -> Result<()> {
     let ctx = SessionContext::new();
diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs
index 10ff9edb8912f..e7c7a42cf9029 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/row.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs
@@ -17,22 +17,18 @@
 
 use crate::aggregates::group_values::GroupValues;
 use ahash::RandomState;
-use arrow::compute::cast;
 use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
-use arrow_array::{Array, ArrayRef};
-use arrow_schema::{DataType, SchemaRef};
+use arrow_array::ArrayRef;
+use arrow_schema::SchemaRef;
 use datafusion_common::hash_utils::create_hashes;
-use datafusion_common::{DataFusionError, Result};
+use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
 use datafusion_physical_expr::EmitTo;
 use hashbrown::raw::RawTable;
 
 /// A [`GroupValues`] making use of [`Rows`]
 pub struct GroupValuesRows {
-    /// The output schema
-    schema: SchemaRef,
-
     /// Converter for the group values
     row_converter: RowConverter,
 
@@ -79,7 +75,6 @@ impl GroupValuesRows {
         let map = RawTable::with_capacity(0);
 
         Ok(Self {
-            schema,
             row_converter,
             map,
             map_size: 0,
@@ -170,7 +165,7 @@ impl GroupValues for GroupValuesRows {
             .take()
             .expect("Can not emit from empty rows");
 
-        let mut output = match emit_to {
+        let output = match emit_to {
             EmitTo::All => {
                 let output = self.row_converter.convert_rows(&group_values)?;
                 group_values.clear();
@@ -203,20 +198,6 @@ impl GroupValues for GroupValuesRows {
             }
         };
 
-        // TODO: Materialize dictionaries in group keys (#7647)
-        for (field, array) in self.schema.fields.iter().zip(&mut output) {
-            let expected = field.data_type();
-            if let DataType::Dictionary(_, v) = expected {
-                let actual = array.data_type();
-                if v.as_ref() != actual {
-                    return Err(DataFusionError::Internal(format!(
-                        "Converted group rows expected dictionary of {v} got {actual}"
-                    )));
-                }
-                *array = cast(array.as_ref(), expected)?;
-            }
-        }
-
         self.group_values = Some(group_values);
         Ok(output)
     }
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 7d7fba6ef6c31..d594335af44f2 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -38,6 +38,7 @@ use crate::{
 use arrow::array::ArrayRef;
 use arrow::datatypes::{Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
+use arrow_schema::DataType;
 use datafusion_common::stats::Precision;
 use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
 use datafusion_execution::TaskContext;
@@ -286,6 +287,9 @@ pub struct AggregateExec {
     limit: Option<usize>,
     /// Input plan, could be a partial aggregate or the input to the aggregate
     pub input: Arc<dyn ExecutionPlan>,
+    /// Original aggregation schema; may differ from `schema` before dictionary group
+    /// keys are materialized
+    original_schema: SchemaRef,
     /// Schema after the aggregate is applied
     schema: SchemaRef,
     /// Input schema before any aggregation is applied. For partial aggregate this will be the
@@ -469,7 +473,7 @@ impl AggregateExec {
         input: Arc<dyn ExecutionPlan>,
         input_schema: SchemaRef,
     ) -> Result<Self> {
-        let schema = create_schema(
+        let original_schema = create_schema(
             &input.schema(),
             &group_by.expr,
             &aggr_expr,
@@ -477,7 +481,11 @@ impl AggregateExec {
             mode,
         )?;
 
-        let schema = Arc::new(schema);
+        let schema = Arc::new(materialize_dict_group_keys(
+            &original_schema,
+            group_by.expr.len(),
+        ));
+        let original_schema = Arc::new(original_schema);
         // Reset ordering requirement to `None` if aggregator is not order-sensitive
         order_by_expr = aggr_expr
             .iter()
@@ -552,6 +560,7 @@ impl AggregateExec {
             filter_expr,
             order_by_expr,
             input,
+            original_schema,
             schema,
             input_schema,
             projection_mapping,
@@ -973,6 +982,24 @@ fn create_schema(
     Ok(Schema::new(fields))
 }
 
+/// Returns the schema with dictionary group keys materialized as their value types.
+/// The actual conversion happens in `RowConverter`, and we avoid an unnecessary
+/// conversion back into dictionaries.
+fn materialize_dict_group_keys(schema: &Schema, group_count: usize) -> Schema {
+    let fields = schema
+        .fields
+        .iter()
+        .enumerate()
+        .map(|(i, field)| match field.data_type() {
+            DataType::Dictionary(_, value_data_type) if i < group_count => {
+                Field::new(field.name(), *value_data_type.clone(), field.is_nullable())
+            }
+            _ => Field::clone(field),
+        })
+        .collect::<Vec<_>>();
+    Schema::new(fields)
+}
+
 fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
     let group_fields = schema.fields()[0..group_count].to_vec();
     Arc::new(Schema::new(group_fields))
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index f96417fc323b0..2f94c3630c33a 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -324,7 +324,9 @@ impl GroupedHashAggregateStream {
             .map(create_group_accumulator)
             .collect::<Result<Vec<_>>>()?;
 
-        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+        // we need to use the original schema so the RowConverter in group_values below
+        // does the proper conversion of dictionaries into value types
+        let group_schema = group_schema(&agg.original_schema, agg_group_by.expr.len());
         let spill_expr = group_schema
             .fields
             .into_iter()
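Not part of the patch, but a minimal standalone sketch of the arrow-rs behavior the change relies on: with an arrow version where the row format no longer preserves dictionary encoding (the context of this diff), round-tripping a dictionary column through `RowConverter` yields the dictionary's value type. That is why `materialize_dict_group_keys` rewrites group-key fields in the output schema and the old cast-back loop in `GroupValuesRows::emit` can be deleted. The column values here (`"jan"`, `"feb"`) and `Int32` keys are illustrative only.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, DictionaryArray};
use arrow::datatypes::{DataType, Int32Type};
use arrow::error::ArrowError;
use arrow::row::{RowConverter, SortField};

fn main() -> Result<(), ArrowError> {
    // A dictionary-encoded string column, like the `month` partition key
    // in the test above.
    let dict: DictionaryArray<Int32Type> =
        vec!["jan", "feb", "jan"].into_iter().collect();
    let col: ArrayRef = Arc::new(dict);

    // Round-trip through the row format, as GroupValuesRows does for group keys.
    let converter =
        RowConverter::new(vec![SortField::new(col.data_type().clone())])?;
    let rows = converter.convert_columns(&[col])?;
    let output = converter.convert_rows(&rows)?;

    // The column comes back materialized as the dictionary's value type,
    // not as Dictionary(Int32, Utf8) -- matching the schema produced by
    // materialize_dict_group_keys.
    assert_eq!(output[0].data_type(), &DataType::Utf8);
    Ok(())
}
```

Under that assumption, keeping the value type in the aggregate's output schema simply accepts what `convert_rows` already produces, instead of paying for a `cast` back into dictionaries after every emit.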