Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Row accumulator support update Scalar values #6003

Merged
merged 6 commits into from
Apr 20, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1454,7 +1454,7 @@ impl std::hash::Hash for ScalarValue {
/// return a reference to the values array and the index into it for a
/// dictionary array
#[inline]
fn get_dict_value<K: ArrowDictionaryKeyType>(
pub fn get_dict_value<K: ArrowDictionaryKeyType>(
array: &dyn Array,
index: usize,
) -> (&ArrayRef, Option<usize>) {
Expand Down
369 changes: 335 additions & 34 deletions datafusion/core/src/physical_plan/aggregates/row_hash.rs

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,57 @@ async fn count_multi_expr() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn count_multi_expr_group_by() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("c1", DataType::Int32, true),
Field::new("c2", DataType::Int32, true),
Field::new("c3", DataType::Int32, true),
]));

let data = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int32Array::from(vec![
Some(0),
None,
Some(1),
Some(2),
None,
])),
Arc::new(Int32Array::from(vec![
Some(1),
Some(1),
Some(0),
None,
None,
])),
Arc::new(Int32Array::from(vec![
Some(10),
Some(10),
Some(10),
Some(10),
Some(10),
])),
],
)?;

let ctx = SessionContext::new();
ctx.register_batch("test", data)?;
let sql = "SELECT c3, count(c1, c2) FROM test group by c3";
let actual = execute_to_batches(&ctx, sql).await;

let expected = vec![
"+----+------------------------+",
"| c3 | COUNT(test.c1,test.c2) |",
"+----+------------------------+",
"| 10 | 2 |",
"+----+------------------------+",
];
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}

#[tokio::test]
async fn simple_avg() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down
23 changes: 19 additions & 4 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,24 @@ impl RowAccumulator for AvgRowAccumulator {
self.state_index() + 1,
accessor,
&sum::sum_batch(values, &self.sum_datatype)?,
)?;
Ok(())
)
}

fn update_scalar_values(
&mut self,
values: &[ScalarValue],
accessor: &mut RowAccessor,
) -> Result<()> {
let value = &values[0];
sum::update_avg_to_row(self.state_index(), accessor, value)
}

fn update_scalar(
&mut self,
value: &ScalarValue,
accessor: &mut RowAccessor,
) -> Result<()> {
sum::update_avg_to_row(self.state_index(), accessor, value)
}

fn merge_batch(
Expand All @@ -315,8 +331,7 @@ impl RowAccumulator for AvgRowAccumulator {

// sum
let difference = sum::sum_batch(&states[1], &self.sum_datatype)?;
sum::add_to_row(self.state_index() + 1, accessor, &difference)?;
Ok(())
sum::add_to_row(self.state_index() + 1, accessor, &difference)
}

fn evaluate(&self, accessor: &RowAccessor) -> Result<ScalarValue> {
Expand Down
25 changes: 25 additions & 0 deletions datafusion/physical-expr/src/aggregate/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,31 @@ impl RowAccumulator for CountRowAccumulator {
Ok(())
}

fn update_scalar_values(
&mut self,
values: &[ScalarValue],
accessor: &mut RowAccessor,
) -> Result<()> {
if !values.iter().any(|s| matches!(s, ScalarValue::Null)) {
accessor.add_u64(self.state_index, 1)
}
Ok(())
}

fn update_scalar(
&mut self,
value: &ScalarValue,
accessor: &mut RowAccessor,
) -> Result<()> {
match value {
ScalarValue::Null => {
// do not update the accumulator
}
_ => accessor.add_u64(self.state_index, 1),
}
Ok(())
}

fn merge_batch(
&mut self,
states: &[ArrayRef],
Expand Down
40 changes: 38 additions & 2 deletions datafusion/physical-expr/src/aggregate/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ macro_rules! min_max_v2 {
ScalarValue::Decimal128(rhs, ..) => {
typed_min_max_v2!($INDEX, $ACC, rhs, i128, $OP)
}
ScalarValue::Null => {
// do nothing
}
e => {
return Err(DataFusionError::Internal(format!(
"MIN/MAX is not expected to receive scalars of incompatible types {:?}",
Expand Down Expand Up @@ -647,8 +650,24 @@ impl RowAccumulator for MaxRowAccumulator {
) -> Result<()> {
let values = &values[0];
let delta = &max_batch(values)?;
max_row(self.index, accessor, delta)?;
Ok(())
max_row(self.index, accessor, delta)
}

fn update_scalar_values(
&mut self,
values: &[ScalarValue],
accessor: &mut RowAccessor,
) -> Result<()> {
let value = &values[0];
max_row(self.index, accessor, value)
}

fn update_scalar(
&mut self,
value: &ScalarValue,
accessor: &mut RowAccessor,
) -> Result<()> {
max_row(self.index, accessor, value)
}

fn merge_batch(
Expand Down Expand Up @@ -894,6 +913,23 @@ impl RowAccumulator for MinRowAccumulator {
Ok(())
}

fn update_scalar_values(
&mut self,
values: &[ScalarValue],
accessor: &mut RowAccessor,
) -> Result<()> {
let value = &values[0];
min_row(self.index, accessor, value)
}

fn update_scalar(
&mut self,
value: &ScalarValue,
accessor: &mut RowAccessor,
) -> Result<()> {
min_row(self.index, accessor, value)
}

fn merge_batch(
&mut self,
states: &[ArrayRef],
Expand Down
14 changes: 14 additions & 0 deletions datafusion/physical-expr/src/aggregate/row_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ pub trait RowAccumulator: Send + Sync + Debug {
accessor: &mut RowAccessor,
) -> Result<()>;

/// updates the accumulator's state from a vector of Scalar value.
fn update_scalar_values(
&mut self,
values: &[ScalarValue],
accessor: &mut RowAccessor,
) -> Result<()>;

/// updates the accumulator's state from a Scalar value.
fn update_scalar(
&mut self,
value: &ScalarValue,
accessor: &mut RowAccessor,
) -> Result<()>;

/// updates the accumulator's state from a vector of states.
fn merge_batch(
&mut self,
Expand Down
67 changes: 65 additions & 2 deletions datafusion/physical-expr/src/aggregate/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,26 @@ macro_rules! sum_row {
}};
}

macro_rules! avg_row {
($INDEX:ident, $ACC:ident, $DELTA:expr, $TYPE:ident) => {{
paste::item! {
if let Some(v) = $DELTA {
$ACC.add_u64($INDEX, 1);
$ACC.[<add_ $TYPE>]($INDEX + 1, *v)
}
}
}};
}

pub(crate) fn add_to_row(
index: usize,
accessor: &mut RowAccessor,
s: &ScalarValue,
) -> Result<()> {
match s {
ScalarValue::Null => {
// do nothing
}
ScalarValue::Float64(rhs) => {
sum_row!(index, accessor, rhs, f64)
}
Expand All @@ -270,6 +284,39 @@ pub(crate) fn add_to_row(
Ok(())
}

pub(crate) fn update_avg_to_row(
index: usize,
accessor: &mut RowAccessor,
s: &ScalarValue,
) -> Result<()> {
match s {
ScalarValue::Null => {
// do nothing
}
ScalarValue::Float64(rhs) => {
avg_row!(index, accessor, rhs, f64)
}
ScalarValue::Float32(rhs) => {
avg_row!(index, accessor, rhs, f32)
}
ScalarValue::UInt64(rhs) => {
avg_row!(index, accessor, rhs, u64)
}
ScalarValue::Int64(rhs) => {
avg_row!(index, accessor, rhs, i64)
}
ScalarValue::Decimal128(rhs, _, _) => {
avg_row!(index, accessor, rhs, i128)
}
_ => {
let msg =
format!("Row avg updater is not expected to receive a scalar {s:?}");
return Err(DataFusionError::Internal(msg));
}
}
Ok(())
}

impl Accumulator for SumAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.sum.clone(), ScalarValue::from(self.count)])
Expand Down Expand Up @@ -331,8 +378,24 @@ impl RowAccumulator for SumRowAccumulator {
) -> Result<()> {
let values = &values[0];
let delta = sum_batch(values, &self.datatype)?;
add_to_row(self.index, accessor, &delta)?;
Ok(())
add_to_row(self.index, accessor, &delta)
}

fn update_scalar_values(
&mut self,
values: &[ScalarValue],
accessor: &mut RowAccessor,
) -> Result<()> {
let value = &values[0];
add_to_row(self.index, accessor, value)
}

fn update_scalar(
&mut self,
value: &ScalarValue,
accessor: &mut RowAccessor,
) -> Result<()> {
add_to_row(self.index, accessor, value)
}

fn merge_batch(
Expand Down
3 changes: 3 additions & 0 deletions datafusion/row/src/accessor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ macro_rules! fn_add_idx {
($NATIVE: ident) => {
paste::item! {
/// add field at `idx` with `value`
#[inline(always)]
pub fn [<add_ $NATIVE>](&mut self, idx: usize, value: $NATIVE) {
if self.is_valid_at(idx) {
self.[<set_ $NATIVE>](idx, value + self.[<get_ $NATIVE>](idx));
Expand All @@ -87,6 +88,7 @@ macro_rules! fn_max_min_idx {
($NATIVE: ident, $OP: ident) => {
paste::item! {
/// check max then update
#[inline(always)]
pub fn [<$OP _ $NATIVE>](&mut self, idx: usize, value: $NATIVE) {
if self.is_valid_at(idx) {
let v = value.$OP(self.[<get_ $NATIVE>](idx));
Expand All @@ -103,6 +105,7 @@ macro_rules! fn_max_min_idx {
macro_rules! fn_get_idx_scalar {
($NATIVE: ident, $SCALAR:ident) => {
paste::item! {
#[inline(always)]
pub fn [<get_ $NATIVE _scalar>](&self, idx: usize) -> ScalarValue {
if self.is_valid_at(idx) {
ScalarValue::$SCALAR(Some(self.[<get_ $NATIVE>](idx)))
Expand Down