Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sum statistics #1

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 73 additions & 12 deletions datafusion/common/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::fmt::{self, Debug, Display};

use crate::{Result, ScalarValue};

use arrow_schema::{Schema, SchemaRef};
use arrow_schema::{DataType, Schema, SchemaRef};

/// Represents a value with a degree of certainty. `Precision` is used to
/// propagate information about the precision of statistical values.
Expand Down Expand Up @@ -170,24 +170,63 @@ impl Precision<ScalarValue> {
pub fn add(&self, other: &Precision<ScalarValue>) -> Precision<ScalarValue> {
match (self, other) {
(Precision::Exact(a), Precision::Exact(b)) => {
if let Ok(result) = a.add(b) {
Precision::Exact(result)
} else {
Precision::Absent
}
a.add(b).map(Precision::Exact).unwrap_or(Precision::Absent)
}
(Precision::Inexact(a), Precision::Exact(b))
| (Precision::Exact(a), Precision::Inexact(b))
| (Precision::Inexact(a), Precision::Inexact(b)) => {
if let Ok(result) = a.add(b) {
Precision::Inexact(result)
} else {
Precision::Absent
}
| (Precision::Inexact(a), Precision::Inexact(b)) => a
.add(b)
.map(Precision::Inexact)
.unwrap_or(Precision::Absent),
(_, _) => Precision::Absent,
}
}

/// Calculates the difference of two (possibly inexact) [`ScalarValue`] values,
/// conservatively propagating exactness information. If one of the input
/// values is [`Precision::Absent`], the result is `Absent` too.
///
/// The result is [`Precision::Exact`] only when both operands are exact, and
/// [`Precision::Inexact`] as soon as one operand is inexact. If the underlying
/// [`ScalarValue::sub`] call fails, the result is `Absent`.
pub fn sub(&self, other: &Precision<ScalarValue>) -> Precision<ScalarValue> {
    match (self, other) {
        (Precision::Exact(a), Precision::Exact(b)) => {
            // Subtract (not add): this is the difference `self - other`.
            a.sub(b).map(Precision::Exact).unwrap_or(Precision::Absent)
        }
        (Precision::Inexact(a), Precision::Exact(b))
        | (Precision::Exact(a), Precision::Inexact(b))
        | (Precision::Inexact(a), Precision::Inexact(b)) => a
            .sub(b)
            .map(Precision::Inexact)
            .unwrap_or(Precision::Absent),
        (_, _) => Precision::Absent,
    }
}

/// Multiplies two (possibly inexact) [`ScalarValue`] statistics.
///
/// The product is [`Precision::Exact`] only if both operands are exact, and
/// [`Precision::Inexact`] if at least one of them is inexact. If either
/// operand is [`Precision::Absent`], or the checked multiplication returns an
/// error, the result is `Absent`.
pub fn multiply(&self, other: &Precision<ScalarValue>) -> Precision<ScalarValue> {
    // Bail out early when either side carries no value at all.
    let product = match (self, other) {
        (Precision::Absent, _) | (_, Precision::Absent) => return Precision::Absent,
        (
            Precision::Exact(lhs) | Precision::Inexact(lhs),
            Precision::Exact(rhs) | Precision::Inexact(rhs),
        ) => lhs.mul_checked(rhs),
    };

    // Exactness survives only if both inputs were exact.
    let both_exact = matches!((self, other), (Precision::Exact(_), Precision::Exact(_)));
    match product {
        Ok(value) if both_exact => Precision::Exact(value),
        Ok(value) => Precision::Inexact(value),
        Err(_) => Precision::Absent,
    }
}

/// Casts the wrapped value to `data_type`, keeping the same exactness.
///
/// An `Exact` value stays `Exact`, an `Inexact` value stays `Inexact`, and
/// `Absent` is returned unchanged without attempting any cast.
///
/// # Errors
///
/// Returns the error produced by the underlying scalar cast, if it fails.
pub fn cast_to(&self, data_type: &DataType) -> Result<Precision<ScalarValue>> {
    Ok(match self {
        Precision::Absent => Precision::Absent,
        Precision::Exact(scalar) => Precision::Exact(scalar.cast_to(data_type)?),
        Precision::Inexact(scalar) => Precision::Inexact(scalar.cast_to(data_type)?),
    })
}
}

impl<T: Debug + Clone + PartialEq + Eq + PartialOrd> Debug for Precision<T> {
Expand All @@ -210,6 +249,18 @@ impl<T: Debug + Clone + PartialEq + Eq + PartialOrd> Display for Precision<T> {
}
}

/// Converts a `usize` statistic into a scalar statistic, representing the
/// number as a `ScalarValue::UInt64` and keeping the exactness unchanged.
impl From<Precision<usize>> for Precision<ScalarValue> {
    fn from(value: Precision<usize>) -> Self {
        // Shared conversion of the raw number into its scalar representation.
        let to_scalar = |count: usize| ScalarValue::UInt64(Some(count as u64));
        match value {
            Precision::Absent => Precision::Absent,
            Precision::Exact(count) => Precision::Exact(to_scalar(count)),
            Precision::Inexact(count) => Precision::Inexact(to_scalar(count)),
        }
    }
}

/// Statistics for a relation
/// Fields are optional and can be inexact because the sources
/// sometimes provide approximate estimates for performance reasons
Expand Down Expand Up @@ -401,6 +452,11 @@ impl Display for Statistics {
} else {
s
};
let s = if cs.sum_value != Precision::Absent {
format!("{} Sum={}", s, cs.sum_value)
} else {
s
};
let s = if cs.null_count != Precision::Absent {
format!("{} Null={}", s, cs.null_count)
} else {
Expand Down Expand Up @@ -436,6 +492,8 @@ pub struct ColumnStatistics {
pub max_value: Precision<ScalarValue>,
/// Minimum value of column
pub min_value: Precision<ScalarValue>,
/// Sum value of a column
pub sum_value: Precision<ScalarValue>,
/// Number of distinct values
pub distinct_count: Precision<usize>,
}
Expand All @@ -458,6 +516,7 @@ impl ColumnStatistics {
null_count: Precision::Absent,
max_value: Precision::Absent,
min_value: Precision::Absent,
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
}
}
Expand All @@ -469,6 +528,7 @@ impl ColumnStatistics {
self.null_count = self.null_count.to_inexact();
self.max_value = self.max_value.to_inexact();
self.min_value = self.min_value.to_inexact();
self.sum_value = self.sum_value.to_inexact();
self.distinct_count = self.distinct_count.to_inexact();
self
}
Expand Down Expand Up @@ -646,6 +706,7 @@ mod tests {
null_count: Precision::Exact(null_count),
max_value: Precision::Exact(ScalarValue::Int64(Some(42))),
min_value: Precision::Exact(ScalarValue::Int64(Some(64))),
sum_value: Precision::Exact(ScalarValue::Int64(Some(4600))),
distinct_count: Precision::Exact(100),
}
}
Expand Down
9 changes: 5 additions & 4 deletions datafusion/core/src/datasource/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ pub async fn get_statistics_with_limit(
for (index, file_column) in
file_stats.column_statistics.clone().into_iter().enumerate()
{
col_stats_set[index].null_count = file_column.null_count;
col_stats_set[index].max_value = file_column.max_value;
col_stats_set[index].min_value = file_column.min_value;
col_stats_set[index] = file_column;
}

// If the number of rows exceeds the limit, we can stop processing
Expand Down Expand Up @@ -113,12 +111,14 @@ pub async fn get_statistics_with_limit(
null_count: file_nc,
max_value: file_max,
min_value: file_min,
sum_value: file_sum,
distinct_count: _,
} = file_col_stats;

col_stats.null_count = add_row_stats(*file_nc, col_stats.null_count);
set_max_if_greater(file_max, &mut col_stats.max_value);
set_min_if_lesser(file_min, &mut col_stats.min_value)
set_min_if_lesser(file_min, &mut col_stats.min_value);
col_stats.sum_value = file_sum.add(&col_stats.sum_value);
}

// If the number of rows exceeds the limit, we can stop processing
Expand Down Expand Up @@ -204,6 +204,7 @@ pub(crate) fn get_col_stats(
null_count: null_counts[i],
max_value: max_value.map(Precision::Exact).unwrap_or(Precision::Absent),
min_value: min_value.map(Precision::Exact).unwrap_or(Precision::Absent),
sum_value: Precision::Absent,
distinct_count: Precision::Absent,
}
})
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/tests/custom_sources_cases/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,14 @@ fn fully_defined() -> (Statistics, Schema) {
distinct_count: Precision::Exact(2),
max_value: Precision::Exact(ScalarValue::Int32(Some(1023))),
min_value: Precision::Exact(ScalarValue::Int32(Some(-24))),
sum_value: Precision::Exact(ScalarValue::Int64(Some(10))),
null_count: Precision::Exact(0),
},
ColumnStatistics {
distinct_count: Precision::Exact(13),
max_value: Precision::Exact(ScalarValue::Int64(Some(5486))),
min_value: Precision::Exact(ScalarValue::Int64(Some(-6783))),
sum_value: Precision::Exact(ScalarValue::Int64(Some(10))),
null_count: Precision::Exact(5),
},
],
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl fmt::Display for AggregateUDF {
}

/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
#[derive(Debug)]
pub struct StatisticsArgs<'a> {
/// The statistics of the aggregate input
pub statistics: &'a Statistics,
Expand Down
71 changes: 70 additions & 1 deletion datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ use datafusion_expr::utils::format_state_name;
use datafusion_expr::Volatility::Immutable;
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator,
ReversedUDAF, Signature,
ReversedUDAF, Signature, StatisticsArgs,
};

use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
filtered_null_mask, set_nulls,
};

use datafusion_common::stats::Precision;
use datafusion_functions_aggregate_common::utils::DecimalAverager;
use datafusion_macros::user_doc;
use log::debug;
Expand Down Expand Up @@ -253,6 +254,34 @@ impl AggregateUDFImpl for Avg {
coerce_avg_type(self.name(), arg_types)
}

/// Computes `AVG` directly from table statistics when they are precise enough.
///
/// Returns `None` whenever the answer cannot be derived from statistics: for
/// `DISTINCT` aggregates, when the row count or the column's sum is not
/// [`Precision::Exact`], when there is not exactly one argument expression,
/// or when a cast or the division fails.
///
/// The average is the exact sum divided by the number of non-null rows, both
/// cast to the aggregate's return type. When there are no rows to average
/// (zero rows, or every row is known to be null), the result is a NULL of the
/// return type rather than zero, following SQL `AVG` semantics.
fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
    // The sum statistic includes duplicates, so it cannot answer AVG(DISTINCT ..).
    if statistics_args.is_distinct {
        return None;
    }

    let num_rows = match statistics_args.statistics.num_rows {
        Precision::Exact(num_rows) => num_rows,
        _ => return None,
    };
    let return_type = statistics_args.return_type;

    // AVG over an empty input is NULL, not zero.
    if num_rows == 0 {
        return ScalarValue::try_from(return_type).ok();
    }
    if statistics_args.exprs.len() != 1 {
        return None;
    }

    let col_stats = statistics_args.exprs[0]
        .column_statistics(statistics_args.statistics)
        .ok()?;
    let sum = match &col_stats.sum_value {
        Precision::Exact(sum) => sum,
        _ => return None,
    };

    // AVG ignores NULLs, so the divisor is the number of non-null rows.
    let non_null_rows = match col_stats.null_count {
        // A null count larger than the row count is inconsistent: give up.
        Precision::Exact(null_count) => num_rows.checked_sub(null_count)?,
        // NOTE(review): with an unknown or inexact null count this still
        // divides by the total row count, i.e. it assumes the column has no
        // NULLs. Returning `None` here would be the conservative choice;
        // confirm with the statistics providers before tightening.
        _ => num_rows,
    };
    if non_null_rows == 0 {
        return ScalarValue::try_from(return_type).ok();
    }

    // Cast both operands to the return type so the division yields that type.
    let sum = sum.cast_to(return_type).ok()?;
    let divisor = ScalarValue::from(non_null_rows as u64)
        .cast_to(return_type)
        .ok()?;
    sum.div(&divisor).ok()
}

fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
Expand Down Expand Up @@ -606,3 +635,43 @@ where
self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{ColumnStatistics, Statistics};
    use datafusion_physical_expr::expressions::Column;
    use std::sync::Arc;

    /// `Avg::value_from_stats` derives the average from an exact sum statistic
    /// and declines to answer for `DISTINCT` aggregates.
    #[test]
    fn sum() {
        // Only the sum of the single column is known: 10 spread over 5 rows.
        let column = ColumnStatistics {
            null_count: Precision::Absent,
            max_value: Precision::Absent,
            min_value: Precision::Absent,
            sum_value: Precision::Exact(ScalarValue::Int8(Some(10))),
            distinct_count: Default::default(),
        };
        let table_stats = Statistics {
            num_rows: Precision::Exact(5),
            total_byte_size: Precision::Absent,
            column_statistics: vec![column],
        };
        let avg = Avg::new();
        let mut args = StatisticsArgs {
            statistics: &table_stats,
            return_type: &DataType::Float64,
            is_distinct: false,
            exprs: &[Arc::new(Column::new("a", 0))],
        };

        // The Int8 sum is cast to the Float64 return type before dividing.
        let from_stats = avg.value_from_stats(&args);
        assert_eq!(from_stats, Some(ScalarValue::Float64(Some(2.0))));

        // A DISTINCT average cannot be answered from the sum statistic.
        args.is_distinct = true;
        assert_eq!(avg.value_from_stats(&args), None);
    }
}
20 changes: 8 additions & 12 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,25 +324,21 @@ impl AggregateUDFImpl for Count {
}
if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
if statistics_args.exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Column>()
{
let current_val = &statistics_args.statistics.column_statistics
[col_expr.index()]
.null_count;
if let &Precision::Exact(val) = current_val {
return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
}
} else if let Some(lit_expr) = statistics_args.exprs[0]
if let Some(lit_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some(ScalarValue::Int64(Some(num_rows as i64)));
}
}

let col_stats = statistics_args.exprs[0]
.column_statistics(statistics_args.statistics)
.ok()?;
if let Precision::Exact(val) = col_stats.null_count {
return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
}
}
}
None
Expand Down
Loading