Skip to content

Commit

Permalink
Supports nulls in group by
Browse files Browse the repository at this point in the history
  • Loading branch information
ilya-biryukov committed Aug 12, 2021
1 parent 82bda79 commit e524173
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 33 deletions.
4 changes: 3 additions & 1 deletion datafusion/src/physical_plan/distinct_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ impl Accumulator for DistinctCountAccumulator {
self.values.iter().unique().for_each(|distinct_values| {
distinct_values.0.iter().enumerate().for_each(
|(col_index, distinct_value)| {
cols_vec[col_index].push(ScalarValue::from(distinct_value));
cols_vec[col_index].push(
distinct_value.to_scalar(&self.state_data_types[col_index]),
);
},
)
});
Expand Down
41 changes: 17 additions & 24 deletions datafusion/src/physical_plan/group_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ use std::convert::{From, TryFrom};

use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;
use arrow::datatypes::DataType;

/// Enumeration of types that can be used in a GROUP BY expression
#[allow(missing_docs)]
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub enum GroupByScalar {
Null,
Float32(OrderedFloat<f32>),
Float64(OrderedFloat<f64>),
UInt8(u8),
Expand Down Expand Up @@ -95,18 +97,10 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
| ScalarValue::UInt32(None)
| ScalarValue::UInt64(None)
| ScalarValue::Utf8(None)
| ScalarValue::Int64Decimal(None, 0)
| ScalarValue::Int64Decimal(None, 1)
| ScalarValue::Int64Decimal(None, 2)
| ScalarValue::Int64Decimal(None, 3)
| ScalarValue::Int64Decimal(None, 4)
| ScalarValue::Int64Decimal(None, 5)
| ScalarValue::Int64Decimal(None, 10) => {
return Err(DataFusionError::Internal(format!(
"Cannot convert a ScalarValue holding NULL ({:?})",
scalar_value
)));
}
| ScalarValue::Int64Decimal(None, _)
| ScalarValue::TimestampMillisecond(None)
| ScalarValue::TimestampMicrosecond(None)
| ScalarValue::TimestampNanosecond(None) => GroupByScalar::Null,
v => {
return Err(DataFusionError::Internal(format!(
"Cannot convert a ScalarValue with associated DataType {:?}",
Expand All @@ -117,9 +111,13 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
}
}

impl From<&GroupByScalar> for ScalarValue {
fn from(group_by_scalar: &GroupByScalar) -> Self {
match group_by_scalar {
impl GroupByScalar {
/// Convert to ScalarValue.
pub fn to_scalar(&self, ty: &DataType) -> ScalarValue {
let r = match self {
GroupByScalar::Null => {
ScalarValue::try_from(ty).expect("could not create null")
}
GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())),
GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())),
GroupByScalar::Boolean(v) => ScalarValue::Boolean(Some(*v)),
Expand Down Expand Up @@ -147,7 +145,9 @@ impl From<&GroupByScalar> for ScalarValue {
ScalarValue::TimestampNanosecond(Some(*v))
}
GroupByScalar::Date32(v) => ScalarValue::Date32(Some(*v)),
}
};
debug_assert_eq!(&r.get_datatype(), ty);
r
}
}

Expand Down Expand Up @@ -199,14 +199,7 @@ mod tests {
fn from_scalar_holding_none() {
let scalar_value = ScalarValue::Int8(None);
let result = GroupByScalar::try_from(&scalar_value);

match result {
Err(DataFusionError::Internal(error_message)) => assert_eq!(
error_message,
String::from("Cannot convert a ScalarValue holding NULL (Int8(NULL))")
),
_ => panic!("Unexpected result"),
}
assert_eq!(result.unwrap(), GroupByScalar::Null);
}

#[test]
Expand Down
11 changes: 8 additions & 3 deletions datafusion/src/physical_plan/hash_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,7 @@ pub(crate) fn create_batch_from_map(
*mode,
group_by_values,
accumulator_set,
&output_schema.fields()[0..num_group_expr],
&mut key_columns,
&mut value_columns,
)
Expand Down Expand Up @@ -1161,6 +1162,7 @@ pub fn write_group_result_row(
mode: AggregateMode,
group_by_values: &[GroupByScalar],
accumulator_set: &AccumulatorSet,
key_fields: &[Field],
key_columns: &mut Vec<Box<dyn ArrayBuilder>>,
value_columns: &mut Vec<Box<dyn ArrayBuilder>>,
) -> Result<()> {
Expand All @@ -1179,7 +1181,7 @@ pub fn write_group_result_row(
.append_value(str)?;
}
v => {
let scalar = &ScalarValue::from(v);
let scalar = v.to_scalar(key_fields[i].data_type());
if add_key_columns {
key_columns.push(create_builder(&scalar));
}
Expand Down Expand Up @@ -1386,7 +1388,10 @@ fn dictionary_create_group_by_value<K: ArrowDictionaryKeyType>(
}

/// Extract the value in `col[row]` as a GroupByScalar
fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
pub(crate) fn create_group_by_value(col: &ArrayRef, row: usize) -> Result<GroupByScalar> {
if col.is_null(row) {
return Ok(GroupByScalar::Null);
}
match col.data_type() {
DataType::Float32 => {
let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
Expand Down Expand Up @@ -1549,7 +1554,7 @@ async fn compute_grouped_sorted_aggregate(
.map_err(DataFusionError::into_arrow_external_error)?;

state
.add_batch(mode, &aggr_expr, &group_values, &aggr_input_values)
.add_batch(mode, &aggr_expr, &group_values, &aggr_input_values, &schema)
.map_err(DataFusionError::into_arrow_external_error)?;
}
state.finish(mode, schema)
Expand Down
12 changes: 7 additions & 5 deletions datafusion/src/physical_plan/sorted_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
use crate::error::{DataFusionError, Result};
use crate::physical_plan::group_scalar::GroupByScalar;
use crate::physical_plan::hash_aggregate::{
create_accumulators, create_group_by_values, write_group_result_row, AccumulatorSet,
AggregateMode,
create_accumulators, create_group_by_value, create_group_by_values,
write_group_result_row, AccumulatorSet, AggregateMode,
};
use crate::physical_plan::AggregateExpr;
use crate::scalar::ScalarValue;
use arrow::array::{ArrayBuilder, ArrayRef, LargeStringArray, StringArray};
use arrow::datatypes::SchemaRef;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use itertools::Itertools;
use smallvec::smallvec;
Expand Down Expand Up @@ -67,6 +67,7 @@ impl SortedAggState {
mode,
&agg.key,
&agg.accumulators,
&schema.fields()[0..agg.key.len()],
&mut self.processed_keys,
&mut self.processed_values,
)
Expand All @@ -91,6 +92,7 @@ impl SortedAggState {
agg_exprs: &Vec<Arc<dyn AggregateExpr>>,
key_columns: &[ArrayRef],
aggr_input_values: &[Vec<ArrayRef>],
out_schema: &Schema,
) -> Result<()> {
assert_ne!(key_columns.len(), 0);
assert_eq!(aggr_input_values.len(), agg_exprs.len());
Expand Down Expand Up @@ -134,6 +136,7 @@ impl SortedAggState {
mode,
&current_agg.key,
&current_agg.accumulators,
&out_schema.fields()[0..current_agg.key.len()],
&mut self.processed_keys,
&mut self.processed_values,
)?;
Expand Down Expand Up @@ -216,8 +219,7 @@ fn agg_key_equals(
}
}
l => {
let r = ScalarValue::try_from_array(&key_columns[i], row)?;
if ScalarValue::from(l) != r {
if l != &create_group_by_value(&key_columns[i], row)? {
return Ok(false);
}
}
Expand Down

0 comments on commit e524173

Please sign in to comment.