From 6459945fb3455ced6d23f23612cbbcd74ce58e07 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 13 Jul 2023 07:54:22 -0400 Subject: [PATCH] Fix incorrect results in `BitAnd` GroupsAccumulator Fix accumulator --- .../sqllogictests/test_files/aggregate.slt | 184 +++++++++++------- .../src/aggregate/bit_and_or_xor.rs | 83 +++----- .../aggregate/groups_accumulator/prim_op.rs | 12 +- 3 files changed, 160 insertions(+), 119 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt index 95cf51d57187..72b9e8400b61 100644 --- a/datafusion/core/tests/sqllogictests/test_files/aggregate.slt +++ b/datafusion/core/tests/sqllogictests/test_files/aggregate.slt @@ -1420,65 +1420,95 @@ select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.c 2 1 1.414213562373 1 -# sum / count for all nulls -statement ok -create table the_nulls as values (null::bigint, 1), (null::bigint, 1), (null::bigint, 2); -# counts should be zeros (even for nulls) -query II -SELECT count(column1), column2 from the_nulls group by column2 order by column2; ----- -0 1 -0 2 - -# sums should be null -query II -SELECT sum(column1), column2 from the_nulls group by column2 order by column2; +# aggregates on empty tables +statement ok +CREATE TABLE empty (column1 bigint, column2 int); + +# no group by column +query IIRIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1) +FROM empty +---- +0 NULL NULL NULL NULL NULL NULL NULL + +# Same query but with grouping (no groups, so no output) +query IIRIIIIII +SELECT + count(column1), + sum(column1), + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1), + column2 +FROM empty +GROUP BY column2 +ORDER BY column2; ---- -NULL 1 -NULL 2 -# avg should be null -query RI -SELECT avg(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 -# bit_and should be null -query II -SELECT bit_and(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 +statement ok +drop table empty -# bit_or should be null -query II -SELECT bit_or(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 +# aggregates on all nulls +statement ok +CREATE TABLE the_nulls +AS VALUES + (null::bigint, 1), + (null::bigint, 1), + (null::bigint, 2); -# bit_xor should be null query II -SELECT bit_xor(column1), column2 from the_nulls group by column2 order by column2; +select * from the_nulls ---- NULL 1 -NULL 2 - -# min should be null -query II -SELECT min(column1), column2 from the_nulls group by column2 order by column2; ----- NULL 1 NULL 2 -# max should be null -query II -SELECT max(column1), column2 from the_nulls group by column2 order by column2; ----- -NULL 1 -NULL 2 +# no group by column +query IIRIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1) +FROM the_nulls +---- +0 NULL NULL NULL NULL NULL NULL NULL + +# Same query but with grouping +query IIRIIIIII +SELECT + count(column1), -- counts should be zero, even for nulls + sum(column1), -- other aggregates should be null + avg(column1), + min(column1), + max(column1), + bit_and(column1), + bit_or(column1), + bit_xor(column1), + column2 +FROM the_nulls +GROUP BY column2 +ORDER BY column2; +---- +0 NULL NULL NULL NULL NULL NULL NULL 1 +0 NULL NULL NULL NULL NULL NULL NULL 2 statement ok @@ -1489,29 +1519,49 @@ create table bit_aggregate_functions ( c1 SMALLINT NOT NULL, c2 SMALLINT NOT NULL, c3 SMALLINT, + tag varchar ) as values - (5, 10, 11), - (33, 11, null), - (9, 12, null); - -# query_bit_and -query III -SELECT bit_and(c1), bit_and(c2), bit_and(c3) FROM bit_aggregate_functions ----- -1 8 11 - -# query_bit_or -query III -SELECT bit_or(c1), bit_or(c2), bit_or(c3) FROM bit_aggregate_functions ----- -45 15 11 + (5, 10, 11, 'A'), + (33, 11, null, 'B'), + (9, 12, null, 'A'); + +# query_bit_and, query_bit_or, query_bit_xor +query IIIIIIIII +SELECT + bit_and(c1), + bit_and(c2), + bit_and(c3), + bit_or(c1), + bit_or(c2), + bit_or(c3), + bit_xor(c1), + bit_xor(c2), + bit_xor(c3) +FROM bit_aggregate_functions +---- +1 8 11 45 15 11 45 13 11 + +# query_bit_and, query_bit_or, query_bit_xor, with group +query IIIIIIIIIT +SELECT + bit_and(c1), + bit_and(c2), + bit_and(c3), + bit_or(c1), + bit_or(c2), + bit_or(c3), + bit_xor(c1), + bit_xor(c2), + bit_xor(c3), + tag +FROM bit_aggregate_functions +GROUP BY tag +ORDER BY tag +---- +1 8 11 13 14 11 12 6 11 A +33 11 NULL 33 11 NULL 33 11 NULL B -# query_bit_xor -query III -SELECT bit_xor(c1), bit_xor(c2), bit_xor(c3) FROM bit_aggregate_functions ----- -45 13 11 statement ok create table bool_aggregate_functions ( diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs index ab37e5891e3f..6a2d50938944 100644 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs @@ -49,15 +49,16 @@ use arrow::compute::{bit_and, bit_or, bit_xor}; use datafusion_row::accessor::RowAccessor; /// Creates a [`PrimitiveGroupsAccumulator`] with the specified -/// [`ArrowPrimitiveType`] which applies `$FN` to each element +/// [`ArrowPrimitiveType`] that initailizes each accumulator to $START +/// and applies `$FN` to each element /// /// [`ArrowPrimitiveType`]: arrow::datatypes::ArrowPrimitiveType -macro_rules! instantiate_primitive_accumulator { - ($SELF:expr, $PRIMTYPE:ident, $FN:expr) => {{ - Ok(Box::new(PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new( - &$SELF.data_type, - $FN, - ))) +macro_rules! instantiate_accumulator { + ($SELF:expr, $START:expr, $PRIMTYPE:ident, $FN:expr) => {{ + Ok(Box::new( + PrimitiveGroupsAccumulator::<$PRIMTYPE, _>::new(&$SELF.data_type, $FN) + .with_starting_value($START), + )) }}; } @@ -279,35 +280,31 @@ impl AggregateExpr for BitAnd { use std::ops::BitAndAssign; match self.data_type { DataType::Int8 => { - instantiate_primitive_accumulator!(self, Int8Type, |x, y| x - .bitand_assign(y)) + instantiate_accumulator!(self, -1, Int8Type, |x, y| x.bitand_assign(y)) } DataType::Int16 => { - instantiate_primitive_accumulator!(self, Int16Type, |x, y| x - .bitand_assign(y)) + instantiate_accumulator!(self, -1, Int16Type, |x, y| x.bitand_assign(y)) } DataType::Int32 => { - instantiate_primitive_accumulator!(self, Int32Type, |x, y| x - .bitand_assign(y)) + instantiate_accumulator!(self, -1, Int32Type, |x, y| x.bitand_assign(y)) } DataType::Int64 => { - instantiate_primitive_accumulator!(self, Int64Type, |x, y| x - .bitand_assign(y)) + instantiate_accumulator!(self, -1, Int64Type, |x, y| x.bitand_assign(y)) } DataType::UInt8 => { - instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x + instantiate_accumulator!(self, u8::MAX, UInt8Type, |x, y| x .bitand_assign(y)) } DataType::UInt16 => { - instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x + instantiate_accumulator!(self, u16::MAX, UInt16Type, |x, y| x .bitand_assign(y)) } DataType::UInt32 => { - instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x + instantiate_accumulator!(self, u32::MAX, UInt32Type, |x, y| x .bitand_assign(y)) } DataType::UInt64 => { - instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x + instantiate_accumulator!(self, u64::MAX, UInt64Type, |x, y| x .bitand_assign(y)) } @@ -517,36 +514,28 @@ impl AggregateExpr for BitOr { use std::ops::BitOrAssign; match self.data_type { DataType::Int8 => { - instantiate_primitive_accumulator!(self, Int8Type, |x, y| x - .bitor_assign(y)) + instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitor_assign(y)) } DataType::Int16 => { - instantiate_primitive_accumulator!(self, Int16Type, |x, y| x - .bitor_assign(y)) + instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitor_assign(y)) } DataType::Int32 => { - instantiate_primitive_accumulator!(self, Int32Type, |x, y| x - .bitor_assign(y)) + instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitor_assign(y)) } DataType::Int64 => { - instantiate_primitive_accumulator!(self, Int64Type, |x, y| x - .bitor_assign(y)) + instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitor_assign(y)) } DataType::UInt8 => { - instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x - .bitor_assign(y)) + instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitor_assign(y)) } DataType::UInt16 => { - instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x - .bitor_assign(y)) + instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitor_assign(y)) } DataType::UInt32 => { - instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x - .bitor_assign(y)) + instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitor_assign(y)) } DataType::UInt64 => { - instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x - .bitor_assign(y)) + instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitor_assign(y)) } _ => Err(DataFusionError::NotImplemented(format!( @@ -756,36 +745,28 @@ impl AggregateExpr for BitXor { use std::ops::BitXorAssign; match self.data_type { DataType::Int8 => { - instantiate_primitive_accumulator!(self, Int8Type, |x, y| x - .bitxor_assign(y)) + instantiate_accumulator!(self, 0, Int8Type, |x, y| x.bitxor_assign(y)) } DataType::Int16 => { - instantiate_primitive_accumulator!(self, Int16Type, |x, y| x - .bitxor_assign(y)) + instantiate_accumulator!(self, 0, Int16Type, |x, y| x.bitxor_assign(y)) } DataType::Int32 => { - instantiate_primitive_accumulator!(self, Int32Type, |x, y| x - .bitxor_assign(y)) + instantiate_accumulator!(self, 0, Int32Type, |x, y| x.bitxor_assign(y)) } DataType::Int64 => { - instantiate_primitive_accumulator!(self, Int64Type, |x, y| x - .bitxor_assign(y)) + instantiate_accumulator!(self, 0, Int64Type, |x, y| x.bitxor_assign(y)) } DataType::UInt8 => { - instantiate_primitive_accumulator!(self, UInt8Type, |x, y| x - .bitxor_assign(y)) + instantiate_accumulator!(self, 0, UInt8Type, |x, y| x.bitxor_assign(y)) } DataType::UInt16 => { - instantiate_primitive_accumulator!(self, UInt16Type, |x, y| x - .bitxor_assign(y)) + instantiate_accumulator!(self, 0, UInt16Type, |x, y| x.bitxor_assign(y)) } DataType::UInt32 => { - instantiate_primitive_accumulator!(self, UInt32Type, |x, y| x - .bitxor_assign(y)) + instantiate_accumulator!(self, 0, UInt32Type, |x, y| x.bitxor_assign(y)) } DataType::UInt64 => { - instantiate_primitive_accumulator!(self, UInt64Type, |x, y| x - .bitxor_assign(y)) + instantiate_accumulator!(self, 0, UInt64Type, |x, y| x.bitxor_assign(y)) } _ => Err(DataFusionError::NotImplemented(format!( diff --git a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs index 860301078909..a49651a5e3fa 100644 --- a/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs +++ b/datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs @@ -47,6 +47,9 @@ where /// The output type (needed for Decimal precision and scale) data_type: DataType, + /// The starting value for new groups + starting_value: T::Native, + /// Track nulls in the input / filters null_state: NullState, @@ -64,9 +67,16 @@ where values: vec![], data_type: data_type.clone(), null_state: NullState::new(), + starting_value: T::default_value(), prim_fn, } } + + /// Set the starting values for new groups + pub fn with_starting_value(mut self, starting_value: T::Native) -> Self { + self.starting_value = starting_value; + self + } } impl GroupsAccumulator for PrimitiveGroupsAccumulator @@ -85,7 +95,7 @@ where let values = values[0].as_primitive::(); // update values - self.values.resize(total_num_groups, T::default_value()); + self.values.resize(total_num_groups, self.starting_value); // NullState dispatches / handles tracking nulls and groups that saw no values self.null_state.accumulate(