
support sum/avg agg for decimal, change sum(float32) --> float64 #1408

Merged
merged 7 commits into from
Dec 17, 2021

Conversation

@liukun4515 (Contributor) commented Dec 7, 2021:

Which issue does this PR close?

part of #122

The result data type for decimal case (closes #1418)

Rationale for this change

What changes are included in this PR?

Are there any user-facing changes?

@liukun4515 liukun4515 changed the title support sum/avg agg for decimal [WIP]support sum/avg agg for decimal Dec 7, 2021
@github-actions github-actions bot added the datafusion Changes in the datafusion crate label Dec 7, 2021
@liukun4515 liukun4515 marked this pull request as ready for review December 14, 2021 05:29
@liukun4515 liukun4515 changed the title [WIP]support sum/avg agg for decimal support sum/avg agg for decimal Dec 14, 2021
@liukun4515 (Contributor, Author):
PTAL @alamb @houqp

@alamb alamb changed the title support sum/avg agg for decimal support sum/avg agg for decimal, change sum(float32) --> float64 Dec 14, 2021
@alamb (Contributor) left a comment:

Thank you for the contribution; very nicely tested, @liukun4515 🏅

I left some comments, but overall I think this is looking quite good

"+-----------------+",
];
assert_eq!(
&DataType::Decimal(20, 3),
Contributor:

👍

}

/// function return type of an average
pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Decimal(precision, scale) => {
// the new precision and scale for return type of avg function
Contributor:

Can you please document the rationale for the 4 and 38 constants below (or, even better, pull them into named constants somewhere)?

I also don't understand where the additional 4 came from. I tried to see if it was what Postgres did, but when I checked, the output schema for avg(numeric(10,3)) appears to be numeric without any precision or scale specified 🤔

(arrow_dev) alamb@MacBook-Pro-2:~/Downloads$ psql
psql (14.1)
Type "help" for help.

alamb=# create table test(x decimal(10, 3));
CREATE TABLE
alamb=# insert into test values (1.02);
INSERT 0 1
alamb=# create table test2 as select avg(x) from test;
SELECT 1

alamb=# select table_name, column_name, numeric_precision, numeric_scale, data_type from information_schema.columns where table_name='test2';
 table_name | column_name | numeric_precision | numeric_scale | data_type 
------------+-------------+-------------------+---------------+-----------
 test2      | avg         |                   |               | numeric
(1 row)

@liukun4515 (Contributor, Author) commented Dec 15, 2021:

This is the intention of #1418.
In PG, decimals can be created with a larger precision than in DataFusion, and the decimal behavior in DataFusion may differ from PG's, so we should discuss those differences.
For now, for promoting the precision and scale, I just follow the Spark behavior.
@alamb

@alamb (Contributor) commented Dec 15, 2021:

I think following the Spark behavior is reasonable, but it should be documented (i.e. that the constant 4 came from Spark). Otherwise in 3 months we'll be 🤔 wondering why those particular constants were picked

Contributor Author:

Added comments to the code and filed issue #1461 to track this rule.
We can add the rule to our documentation later in a follow-up pull request.

Self {
name: name.into(),
expr,
match data_type {
Contributor:

The comment above looks out of date -- I think it should simply be removed.

And perhaps we can change this code so it doesn't use unreachable, as I think it would be fairly easy to reach this code by calling Avg::new(..) with incorrect parameters

How about something like

assert!(matches!(data_type, DataType::Float64 | DataType::Decimal(_, _)));

Which I think might be easier to diagnose if anyone hits it

Contributor Author:

Great comments!

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(AvgAccumulator::try_new(
// avg is f64 or decimal
&self.data_type,
Contributor:

if a sum of decimal(10,2) can be decimal(20,2) shouldn't the accumulator state also be decimal(20,2) to avoid overflow?

I think handling overflow is probably fine to leave for a later PR, but it is strange to me that there is a discrepancy between the type for sum and the accumulator type for computing avg

@liukun4515 (Contributor, Author) commented Dec 17, 2021:

The result type of the physical expr (sum/avg) is the same as its Accumulator's, and it is decided by sum_return_type and avg_return_type.

If the column is decimal(8,2), the avg of this column must be less than 10^8-1, but we need more digits to represent the fractional part. For example, the avg of 3, 4, 6 is 4.3333..., so we should increase the scale part.

For the sum agg, we just need to increase the precision part; the rule of adding 10 to the precision is Spark's coercion rule for summing decimals. We can define our own rules for decimal if we want.
@alamb
We can just follow Spark for now and change the rules later if we want to define our own.
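The promotion rules described above can be sketched as plain functions over (precision, scale) pairs. This is a simplified, illustrative sketch: the real code operates on arrow's DataType::Decimal, and the function and constant names here are assumptions, not the merged API.

```rust
// A 128-bit decimal holds at most 38 significant digits.
const MAX_PRECISION: u8 = 38;

// sum: widen the precision by 10, keep the scale (Spark's coercion rule).
fn sum_decimal_return(precision: u8, scale: u8) -> (u8, u8) {
    (MAX_PRECISION.min(precision + 10), scale)
}

// avg: widen both precision and scale by 4, so the fractional part of
// the average (e.g. avg(3, 4, 6) = 4.3333...) can be represented.
fn avg_decimal_return(precision: u8, scale: u8) -> (u8, u8) {
    (MAX_PRECISION.min(precision + 4), MAX_PRECISION.min(scale + 4))
}

fn main() {
    assert_eq!(sum_decimal_return(10, 2), (20, 2));
    assert_eq!(avg_decimal_return(10, 3), (14, 7));
    // both are capped at the 128-bit maximum of 38 digits
    assert_eq!(sum_decimal_return(35, 2), (38, 2));
    assert_eq!(avg_decimal_return(36, 36), (38, 38));
}
```

Note that with the cap, a precision near 38 saturates rather than overflowing the type.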

Contributor Author:

Filed issue #1460 to track the overflow.

ScalarValue::Decimal128(value, precision, scale) => {
Ok(match value {
None => ScalarValue::Decimal128(None, precision, scale),
// TODO add the checker for overflow the precision
Contributor:

👍
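The precision check that the TODO above refers to could look roughly like the following. This is a hypothetical helper, not the merged implementation: a decimal with precision p can only hold unscaled values with at most p digits.

```rust
// Check whether a raw (unscaled) i128 decimal value fits in `precision` digits.
fn fits_precision(value: i128, precision: u32) -> bool {
    let max = 10_i128.pow(precision) - 1;
    (-max..=max).contains(&value)
}

fn main() {
    assert!(fits_precision(999, 3)); // e.g. 9.99 fits in decimal(3, 2)
    assert!(!fits_precision(1000, 3)); // 10.00 overflows decimal(3, 2)
    assert!(fits_precision(-999, 3));
}
```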

// the lhs_scale must be greater or equal rhs_scale.
match (lhs, rhs) {
(None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
(None, rhs) => {
Contributor:

Suggested change
(None, rhs) => {
(None, Some(rhs)) => {

I think you could avoid the unwrap below (which does a redundant check) by using a pattern match here

Contributor Author:

Thanks, Done.

ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
}
(lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
(lhs, rhs) => {
Contributor:

Suggested change
(lhs, rhs) => {
(Some(lhs), rhs) => {

Contributor Author:

Done

sum_decimal(v1, v2, p1, s1)
} else if s1.gt(s2) && p1.ge(p2) {
// For avg aggravate function.
// In the avg function, the scale of result data type is different with the scale of the input data type.
Contributor:

I think this comment also applies to sum, not just avg

Contributor Author:

I have changed the logic for sum to handle values with different scales, so the comment has been removed.

if s1.eq(s2) {
sum_decimal(v1, v2, p1, s1)
} else if s1.gt(s2) && p1.ge(p2) {
// For avg aggravate function.
Contributor:

Suggested change
// For avg aggravate function.
// For avg aggregate function.

Contributor Author:

Done

(ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
if s1.eq(s2) {
sum_decimal(v1, v2, p1, s1)
} else if s1.gt(s2) && p1.ge(p2) {
Contributor:

I don't understand the need for this clause. Among other things, it seems to make sum for decimals non-commutative, which is confusing.

I would expect that sum(lhs, rhs) == sum(rhs, lhs) for any specific lhs and rhs

Contributor Author:

This is also ugly and confusing to me.

I have refined this part. For the input args, the left decimal value and the right decimal value, there are two cases:

  1. the scales are the same: just use the max precision as the result precision
  2. the scales differ: use max(scale1, scale2) as the result scale and the max precision as the result precision

Contributor Author:

let max_precision = p1.max(p2);
if s1.eq(s2) {
    // s1 == s2
    sum_decimal(v1, v2, max_precision, s1)
} else if s1.gt(s2) {
    // s1 > s2
    sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
} else if s1.lt(s2) {
    // s1 < s2
    sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
} else {
    return Err(DataFusionError::Internal(format!(
        "Sum is not expected to receive lhs {:?}, rhs {:?}",
        lhs, rhs
    )));
}
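The rescaling idea behind the diff-scale branch can be sketched as follows. This is illustrative, not the merged code: the caller guarantees that the first value has the larger scale, so the second value is scaled up by a power of ten before the raw i128 values are added, and the result keeps the larger scale.

```rust
// Add two optional raw decimal values where lhs_scale >= rhs_scale.
// rhs is rescaled up to lhs_scale before the addition.
fn sum_with_diff_scale(
    lhs: Option<i128>,
    rhs: Option<i128>,
    lhs_scale: u32,
    rhs_scale: u32,
) -> Option<i128> {
    let factor = 10_i128.pow(lhs_scale - rhs_scale);
    match (lhs, rhs) {
        (None, None) => None,
        (None, Some(r)) => Some(r * factor),
        (Some(l), None) => Some(l),
        (Some(l), Some(r)) => Some(l + r * factor),
    }
}

fn main() {
    // 1.23 (scale 2) + 4.5 (scale 1) = 5.73 (scale 2)
    assert_eq!(sum_with_diff_scale(Some(123), Some(45), 2, 1), Some(573));
    assert_eq!(sum_with_diff_scale(None, Some(45), 2, 1), Some(450));
}
```

Because the dispatch above always passes the larger-scale operand first, the overall sum stays commutative.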

@@ -54,8 +56,15 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
Ok(DataType::UInt64)
}
DataType::Float32 => Ok(DataType::Float32),
DataType::Float64 => Ok(DataType::Float64),
// In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc,
Member:

Suggested change
// In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc,
// In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,

Contributor Author:

Done

Comment on lines 237 to 238
(lhs, rhs) => {
ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
Member:

Suggested change
(lhs, rhs) => {
ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
(Some(lhs), Some(rhs)) => {
ScalarValue::Decimal128(Some(lhs + rhs), *precision, *scale)

Contributor Author:

Thanks @jimexist, updated per @alamb's comments.

Comment on lines 233 to 240
match (lhs, rhs) {
(None, None) => ScalarValue::Decimal128(None, *precision, *scale),
(None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
(lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
(lhs, rhs) => {
ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
}
}
Member:

Suggested change
match (lhs, rhs) {
(None, None) => ScalarValue::Decimal128(None, *precision, *scale),
(None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
(lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
(lhs, rhs) => {
ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
}
}
ScalarValue::Decimal128(Some(lhs.unwrap_or(0) + rhs.unwrap_or(0)), *precision, *scale)

Contributor Author:

Can this change handle the cases where only the left or the right is None?

Contributor Author:

Change to

match (lhs, rhs) {
    (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
    (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
    (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
    (Some(lhs_value), Some(rhs_value)) => {
        ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale)
    }
}

@liukun4515 (Contributor, Author):

In the comments, @alamb raised some questions about the precision and scale of the result data type.
I wanted to follow the guidance in https://www.postgresql.org/docs/8.2/functions-aggregate.html at first but couldn't find any useful suggestions, so I followed the behavior of Spark's sum and avg.

@alamb (Contributor) commented Dec 15, 2021:

I wanted to follow the guidance in https://www.postgresql.org/docs/8.2/functions-aggregate.html at first but couldn't find any useful suggestions, so I followed the behavior of Spark's sum and avg.

As I mentioned above, I think following the behavior of Spark is fine, but I think we should use symbolic constants (largely as a way to document why those particular constants were picked)

@liukun4515 (Contributor, Author):

As I mentioned above, I think following the behavior of Spark is fine, but I think we should use symbolic constants (largely as a way to document why those particular constants were picked)

Ok, I will refine this pull request.

@liukun4515 liukun4515 marked this pull request as draft December 17, 2021 01:54
@github-actions github-actions bot added the sql SQL Planner label Dec 17, 2021
@liukun4515 liukun4515 marked this pull request as ready for review December 17, 2021 09:08
@liukun4515 liukun4515 requested review from alamb and removed request for alamb December 17, 2021 09:29
@liukun4515 liukun4515 requested review from jimexist and alamb December 17, 2021 09:29
@alamb (Contributor) commented Dec 17, 2021:

👁️

// the new precision and scale for return type of avg function
let new_precision = 38.min(*precision + 4);
let new_scale = 38.min(*scale + 4);
// in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
Contributor:

👍

@@ -33,6 +33,11 @@ use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

// TODO may need to be moved to arrow-rs
Contributor:

Moving to arrow-rs would be a good idea I think

@alamb (Contributor) left a comment:

I think it is looking good -- thank you @liukun4515 -- nice forward progress

// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4);
let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4);
Ok(DataType::Decimal(new_precision, new_scale))
Contributor:

@alamb alamb merged commit 9d31866 into apache:master Dec 17, 2021
@alamb alamb added enhancement New feature or request and removed enhancement New feature or request labels Feb 10, 2022
Labels
datafusion Changes in the datafusion crate sql SQL Planner
Development

Successfully merging this pull request may close these issues.

Talk about the result type for coerced type
3 participants