support sum/avg agg for decimal, change sum(float32) --> float64 #1408
Conversation
Force-pushed from 8b91ff1 to 3c19c41.
Force-pushed from 667d46a to 9675e30.
Thank you for the contribution; very nicely tested, @liukun4515 🏅
I left some comments, but overall I think this is looking quite good.
"+-----------------+", | ||
]; | ||
assert_eq!( | ||
&DataType::Decimal(20, 3), |
👍
}

/// function return type of an average
pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
    match arg_type {
        DataType::Decimal(precision, scale) => {
            // the new precision and scale for return type of avg function
Can you please document the rationale for the 4 and 38 constants below (or, even better, pull them into named constants somewhere)?
I also don't understand where the additional 4 came from. I tried to see if it was what postgres did, but when I checked the output schema for avg(numeric(10,3)) it appears to be numeric without the precision or scale specified 🤔
(arrow_dev) alamb@MacBook-Pro-2:~/Downloads$ psql
psql (14.1)
Type "help" for help.
alamb=# create table test(x decimal(10, 3));
CREATE TABLE
alamb=# insert into test values (1.02);
INSERT 0 1
alamb=# create table test2 as select avg(x) from test;
SELECT 1
alamb=# select table_name, column_name, numeric_precision, numeric_scale, data_type from information_schema.columns where table_name='test2';
table_name | column_name | numeric_precision | numeric_scale | data_type
------------+-------------+-------------------+---------------+-----------
test2 | avg | | | numeric
(1 row)
I think following the Spark behavior is reasonable, but it should be documented (i.e. that the constant 4 came from Spark). Otherwise in 3 months we'll be wondering 🤔 why those particular constants were picked.
Added comments to the code and filed issue #1461 to track this rule.
We can add the rule to our documentation in a follow-up pull request.
Self {
    name: name.into(),
    expr,

match data_type {
The comment above looks out of date -- I think it should simply be removed.
And perhaps we can change this code so it doesn't use unreachable, as I think it would be fairly easy to reach this code by calling Avg::new(..) with some incorrect parameters. How about something like

    assert!(matches!(data_type, DataType::Float64 | DataType::Decimal(_, _)));

which I think might be easier to diagnose if anyone hits it.
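A minimal sketch of how that suggestion could replace the unreachable arm; the struct fields and constructor signature here are illustrative, not necessarily the PR's exact code:

```rust
use std::sync::Arc;

use arrow::datatypes::DataType;
use datafusion::physical_plan::PhysicalExpr;

// Illustrative mirror of the Avg expression's relevant fields.
struct Avg {
    name: String,
    expr: Arc<dyn PhysicalExpr>,
    data_type: DataType,
}

impl Avg {
    pub fn new(expr: Arc<dyn PhysicalExpr>, name: impl Into<String>, data_type: DataType) -> Self {
        // avg output is always Float64 or Decimal (per avg_return_type), so
        // fail fast with a readable message on an unsupported type instead of
        // hitting an unreachable!() later.
        assert!(matches!(
            data_type,
            DataType::Float64 | DataType::Decimal(_, _)
        ));
        Self {
            name: name.into(),
            expr,
            data_type,
        }
    }
}
```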
Great comments.
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
    Ok(Box::new(AvgAccumulator::try_new(
        // avg is f64 or decimal
        &self.data_type,
If a sum of decimal(10,2) can be decimal(20,2), shouldn't the accumulator state also be decimal(20,2) to avoid overflow?
I think handling overflow is probably fine to leave for a later date / PR, but it is strange to me that there is a discrepancy between the type for sum and the accumulator type for computing avg.
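As a rough sketch of the overflow arithmetic behind widening sum's precision by 10 (illustrative numbers, not code from this PR):

```rust
fn main() {
    // decimal(10, 2) stores values as an unscaled i128, at most 10^10 - 1
    // (i.e. 99999999.99).
    let max_dec_10_2: i128 = 10_i128.pow(10) - 1;

    // Summing up to 10^10 such rows can need 10 extra integer digits, which
    // is why sum widens decimal(10, 2) to decimal(20, 2).
    let rows: i128 = 10_i128.pow(10);
    let worst_case = max_dec_10_2 * rows;
    assert!(worst_case < 10_i128.pow(20)); // fits within 20 digits of precision
}
```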
The result type of the physical expression (sum/avg) is the same as its Accumulator's, and it is decided by sum_return_type and avg_return_type.
If the column is decimal(8,2), the avg of the column must be less than 10^8-1, but we need more digits to represent the fractional part. For example, the avg of 3, 4, 6 is 4.333333..., so we should increase the scale.
For the sum aggregate, we only need to increase the precision, and the rule of adding 10 to the precision is Spark's coercion rule for sum over decimals. We can define our own rules for decimal if we want.
@alamb We can just follow Spark for now, and change the rules later if we want to define our own.
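For concreteness, here is a minimal sketch of the two Spark-style rules described above, written with the named constants this PR introduces; the standalone function names are illustrative, as the real logic lives inside sum_return_type / avg_return_type:

```rust
use arrow::datatypes::DataType;

const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
const MAX_SCALE_FOR_DECIMAL128: usize = 38;

// sum: widen the precision by 10 (Spark's coercion rule) and keep the scale,
// e.g. sum over decimal(10, 2) yields decimal(20, 2).
fn decimal_sum_type(precision: usize, scale: usize) -> DataType {
    DataType::Decimal(MAX_PRECISION_FOR_DECIMAL128.min(precision + 10), scale)
}

// avg: widen both precision and scale by 4 (Spark's rule) so the result keeps
// more fractional digits, e.g. avg over decimal(10, 3) yields decimal(14, 7).
fn decimal_avg_type(precision: usize, scale: usize) -> DataType {
    DataType::Decimal(
        MAX_PRECISION_FOR_DECIMAL128.min(precision + 4),
        MAX_SCALE_FOR_DECIMAL128.min(scale + 4),
    )
}
```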
Added issue #1460 to track the overflow.
ScalarValue::Decimal128(value, precision, scale) => {
    Ok(match value {
        None => ScalarValue::Decimal128(None, precision, scale),
        // TODO add the checker for overflow the precision
👍
// the lhs_scale must be greater or equal rhs_scale.
match (lhs, rhs) {
    (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
    (None, rhs) => {
Suggested change:

    -    (None, rhs) => {
    +    (None, Some(rhs)) => {

I think you could avoid the unwrap below (which does a redundant check) by using a pattern match here.
Thanks, Done.
        ScalarValue::Decimal128(Some(new_value), *precision, *lhs_scale)
    }
    (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *lhs_scale),
    (lhs, rhs) => {
Suggested change:

    -    (lhs, rhs) => {
    +    (Some(lhs), rhs) => {
Done
    sum_decimal(v1, v2, p1, s1)
} else if s1.gt(s2) && p1.ge(p2) {
    // For avg aggravate function.
    // In the avg function, the scale of result data type is different with the scale of the input data type.
I think this comment also applies to sum, not just avg.
I have changed the sum logic to handle values with different scales, so the comment has been removed.
if s1.eq(s2) {
    sum_decimal(v1, v2, p1, s1)
} else if s1.gt(s2) && p1.ge(p2) {
    // For avg aggravate function.
Suggested change:

    -    // For avg aggravate function.
    +    // For avg aggregate function.
Done
(ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
    if s1.eq(s2) {
        sum_decimal(v1, v2, p1, s1)
    } else if s1.gt(s2) && p1.ge(p2) {
I don't understand the need for this clause. Among other things, it would seem to make sum for decimals non-commutative, which is confusing. I would expect that sum(lhs, rhs) == sum(rhs, lhs) for any specific lhs and rhs.
This was also ugly and confusing to me; I have refined this part.
For the input args, the left decimal value and the right decimal value, there are two cases (the refined logic is shown below):
- same scale: just use the max precision as the result precision
- different scales: use max(scale1, scale2) as the result scale and the max precision as the result precision
let max_precision = p1.max(p2);
if s1.eq(s2) {
    // s1 = s2
    sum_decimal(v1, v2, max_precision, s1)
} else if s1.gt(s2) {
    // s1 > s2
    sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
} else if s1.lt(s2) {
    // s1 < s2
    sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
} else {
    return Err(DataFusionError::Internal(format!(
        "Sum is not expected to receive lhs {:?}, rhs {:?}",
        lhs, rhs
    )));
}
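For reference, a hedged sketch of what the two helpers used above might look like; the actual implementations in the PR may differ. sum_decimal adds two values that share a scale, while sum_decimal_with_diff_scale rescales the smaller-scale value up before adding, which is what makes the dispatch above independent of argument order:

```rust
use datafusion::scalar::ScalarValue;

// Both values share `scale`; overflow checking is left to the follow-on issue.
fn sum_decimal(lhs: &Option<i128>, rhs: &Option<i128>, precision: &usize, scale: &usize) -> ScalarValue {
    match (lhs, rhs) {
        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
        (None, Some(r)) => ScalarValue::Decimal128(Some(*r), *precision, *scale),
        (Some(l), None) => ScalarValue::Decimal128(Some(*l), *precision, *scale),
        (Some(l), Some(r)) => ScalarValue::Decimal128(Some(l + r), *precision, *scale),
    }
}

// Assumes lhs_scale > rhs_scale: the rhs value is multiplied by
// 10^(lhs_scale - rhs_scale) so both sides use lhs_scale.
fn sum_decimal_with_diff_scale(
    lhs: &Option<i128>,
    rhs: &Option<i128>,
    precision: &usize,
    lhs_scale: &usize,
    rhs_scale: &usize,
) -> ScalarValue {
    let factor = 10_i128.pow((lhs_scale - rhs_scale) as u32);
    match (lhs, rhs) {
        (None, None) => ScalarValue::Decimal128(None, *precision, *lhs_scale),
        (None, Some(r)) => ScalarValue::Decimal128(Some(r * factor), *precision, *lhs_scale),
        (Some(l), None) => ScalarValue::Decimal128(Some(*l), *precision, *lhs_scale),
        (Some(l), Some(r)) => ScalarValue::Decimal128(Some(l + r * factor), *precision, *lhs_scale),
    }
}
```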
@@ -54,8 +56,15 @@ pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {
        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
            Ok(DataType::UInt64)
        }
        DataType::Float32 => Ok(DataType::Float32),
        DataType::Float64 => Ok(DataType::Float64),
        // In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc,
Suggested change:

    -    // In the https://www.postgresql.org/docs/8.2/functions-aggregate.html doc,
    +    // In the https://www.postgresql.org/docs/current/functions-aggregate.html doc,
Done
(lhs, rhs) => {
    ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
Suggested change:

    -    (lhs, rhs) => {
    -        ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
    +    (Some(lhs), Some(rhs)) => {
    +        ScalarValue::Decimal128(Some(lhs + rhs), *precision, *scale)
match (lhs, rhs) {
    (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
    (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
    (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
    (lhs, rhs) => {
        ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
    }
}
Suggested change:

    -    match (lhs, rhs) {
    -        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
    -        (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
    -        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
    -        (lhs, rhs) => {
    -            ScalarValue::Decimal128(Some(lhs.unwrap() + rhs.unwrap()), *precision, *scale)
    -        }
    -    }
    +    ScalarValue::Decimal128(Some(lhs.unwrap_or(0) + rhs.unwrap_or(0)), *precision, *scale)
Can this change still handle None for the left or right value? (With unwrap_or(0), a (None, None) input would become Some(0) rather than a null decimal.)
Changed to:

    match (lhs, rhs) {
        (None, None) => ScalarValue::Decimal128(None, *precision, *scale),
        (None, rhs) => ScalarValue::Decimal128(*rhs, *precision, *scale),
        (lhs, None) => ScalarValue::Decimal128(*lhs, *precision, *scale),
        (Some(lhs_value), Some(rhs_value)) => {
            ScalarValue::Decimal128(Some(lhs_value + rhs_value), *precision, *scale)
        }
    }
In the comments, @alamb made some points about the precision and scale of the result data type.
As I mentioned above, I think following the behavior of Spark is fine, but I think we should use symbolic constants (largely as a way to document why those particular constants were picked).
Ok, I will refine this pull request.
👁️
// the new precision and scale for return type of avg function
let new_precision = 38.min(*precision + 4);
let new_scale = 38.min(*scale + 4);
// in the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
👍
@@ -33,6 +33,11 @@ use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};

// TODO may need to be moved to arrow-rs
Moving to arrow-rs would be a good idea, I think.
I think it is looking good -- thank you @liukun4515 -- nice forward progress
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(*precision + 4);
let new_scale = MAX_SCALE_FOR_DECIMAL128.min(*scale + 4);
Ok(DataType::Decimal(new_precision, new_scale))
Could potentially use make_decimal_type here too: https://github.com/apache/arrow-datafusion/blob/4b454d0363b034e3ebcdf88f6add2403cd77a23b/datafusion/src/sql/utils.rs#L510
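A sketch of what that could look like; the make_decimal_type signature assumed here (Option<u64> precision and scale, returning Result<DataType>) and its import path are inferred from the linked utils.rs and should be double-checked:

```rust
use arrow::datatypes::DataType;
use datafusion::error::Result;
// Inside the datafusion crate this would be something like:
// use crate::sql::utils::make_decimal_type;

const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
const MAX_SCALE_FOR_DECIMAL128: usize = 38;

// Hypothetical rewrite of the Decimal arm of avg_return_type, letting
// make_decimal_type build (and validate) the resulting DataType::Decimal.
fn avg_decimal_return_type(precision: usize, scale: usize) -> Result<DataType> {
    let new_precision = MAX_PRECISION_FOR_DECIMAL128.min(precision + 4);
    let new_scale = MAX_SCALE_FOR_DECIMAL128.min(scale + 4);
    make_decimal_type(Some(new_precision as u64), Some(new_scale as u64))
}
```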
Which issue does this PR close?
Part of #122. Fixes the result data type for the decimal case (closes #1418).
Rationale for this change
What changes are included in this PR?
Are there any user-facing changes?