-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support mathematics operation for decimal data type #1554
Changes from 8 commits
6687f29
e2df9f5
ed8ef8f
a2551db
bf63119
3d15e8f
047827f
f1e7081
a61f7f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,9 +21,10 @@ use crate::arrow::datatypes::DataType; | |
use crate::error::{DataFusionError, Result}; | ||
use crate::logical_plan::Operator; | ||
use crate::physical_plan::expressions::coercion::{ | ||
dictionary_coercion, eq_coercion, is_numeric, like_coercion, numerical_coercion, | ||
string_coercion, temporal_coercion, | ||
dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion, | ||
temporal_coercion, | ||
}; | ||
use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128}; | ||
|
||
/// Coercion rules for all binary operators. Returns the output type | ||
/// of applying `op` to an argument of `lhs_type` and `rhs_type`. | ||
|
@@ -49,12 +50,11 @@ pub(crate) fn coerce_types( | |
Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type), | ||
// for math expressions, the final value of the coercion is also the return type | ||
// because coercion favours higher information types | ||
// TODO: support decimal data type | ||
Operator::Plus | ||
| Operator::Minus | ||
| Operator::Modulo | ||
| Operator::Divide | ||
| Operator::Multiply => numerical_coercion(lhs_type, rhs_type), | ||
| Operator::Multiply => mathematics_numerical_coercion(op, lhs_type, rhs_type), | ||
Operator::RegexMatch | ||
| Operator::RegexIMatch | ||
| Operator::RegexNotMatch | ||
|
@@ -162,12 +162,141 @@ fn get_comparison_common_decimal_type( | |
} | ||
} | ||
|
||
// Convert the numeric data type to the decimal data type. | ||
// Now, we just support the signed integer type and floating-point type. | ||
fn convert_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> { | ||
match numeric_type { | ||
DataType::Int8 => Some(DataType::Decimal(3, 0)), | ||
DataType::Int16 => Some(DataType::Decimal(5, 0)), | ||
DataType::Int32 => Some(DataType::Decimal(10, 0)), | ||
DataType::Int64 => Some(DataType::Decimal(20, 0)), | ||
// TODO if we convert the floating-point data to the decimal type, it maybe overflow. | ||
DataType::Float32 => Some(DataType::Decimal(14, 7)), | ||
DataType::Float64 => Some(DataType::Decimal(30, 15)), | ||
_ => None, | ||
} | ||
} | ||
|
||
fn mathematics_numerical_coercion( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function looks almost the same as What is the reason for making a new function rather than extending There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thank you There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. did you try extending There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @alamb I will extend it in the follow-up pull request and support dict and |
||
mathematics_op: &Operator, | ||
lhs_type: &DataType, | ||
rhs_type: &DataType, | ||
) -> Option<DataType> { | ||
use arrow::datatypes::DataType::*; | ||
|
||
// error on any non-numeric type | ||
if !is_numeric(lhs_type) || !is_numeric(rhs_type) { | ||
return None; | ||
}; | ||
|
||
// same type => all good | ||
if lhs_type == rhs_type { | ||
return Some(lhs_type.clone()); | ||
} | ||
|
||
// these are ordered from most informative to least informative so | ||
// that the coercion removes the least amount of information | ||
match (lhs_type, rhs_type) { | ||
(Decimal(_, _), Decimal(_, _)) => { | ||
coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type) | ||
} | ||
(Decimal(_, _), _) => { | ||
let converted_decimal_type = convert_numeric_type_to_decimal(rhs_type); | ||
match converted_decimal_type { | ||
None => None, | ||
Some(right_decimal_type) => coercion_decimal_mathematics_type( | ||
mathematics_op, | ||
lhs_type, | ||
&right_decimal_type, | ||
), | ||
} | ||
} | ||
(_, Decimal(_, _)) => { | ||
let converted_decimal_type = convert_numeric_type_to_decimal(lhs_type); | ||
match converted_decimal_type { | ||
None => None, | ||
Some(left_decimal_type) => coercion_decimal_mathematics_type( | ||
mathematics_op, | ||
&left_decimal_type, | ||
rhs_type, | ||
), | ||
} | ||
} | ||
(Float64, _) | (_, Float64) => Some(Float64), | ||
(_, Float32) | (Float32, _) => Some(Float32), | ||
(Int64, _) | (_, Int64) => Some(Int64), | ||
(Int32, _) | (_, Int32) => Some(Int32), | ||
(Int16, _) | (_, Int16) => Some(Int16), | ||
(Int8, _) | (_, Int8) => Some(Int8), | ||
(UInt64, _) | (_, UInt64) => Some(UInt64), | ||
(UInt32, _) | (_, UInt32) => Some(UInt32), | ||
(UInt16, _) | (_, UInt16) => Some(UInt16), | ||
(UInt8, _) | (_, UInt8) => Some(UInt8), | ||
_ => None, | ||
} | ||
} | ||
|
||
fn create_decimal_type(precision: usize, scale: usize) -> DataType { | ||
DataType::Decimal( | ||
MAX_PRECISION_FOR_DECIMAL128.min(precision), | ||
MAX_SCALE_FOR_DECIMAL128.min(scale), | ||
) | ||
} | ||
|
||
fn coercion_decimal_mathematics_type( | ||
mathematics_op: &Operator, | ||
left_decimal_type: &DataType, | ||
right_decimal_type: &DataType, | ||
) -> Option<DataType> { | ||
use arrow::datatypes::DataType::*; | ||
match (left_decimal_type, right_decimal_type) { | ||
// The coercion rule from spark | ||
// https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35 | ||
(Decimal(p1, s1), Decimal(p2, s2)) => { | ||
match mathematics_op { | ||
Operator::Plus | Operator::Minus => { | ||
// max(s1, s2) | ||
let result_scale = *s1.max(s2); | ||
// max(s1, s2) + max(p1-s1, p2-s2) + 1 | ||
let result_precision = result_scale + (*p1 - *s1).max(*p2 - *s2) + 1; | ||
Some(create_decimal_type(result_precision, result_scale)) | ||
} | ||
Operator::Multiply => { | ||
// s1 + s2 | ||
let result_scale = *s1 + *s2; | ||
// p1 + p2 + 1 | ||
let result_precision = *p1 + *p2 + 1; | ||
Some(create_decimal_type(result_precision, result_scale)) | ||
} | ||
Operator::Divide => { | ||
// max(6, s1 + p2 + 1) | ||
let result_scale = 6.max(*s1 + *p2 + 1); | ||
// p1 - s1 + s2 + max(6, s1 + p2 + 1) | ||
let result_precision = result_scale + *p1 - *s1 + *s2; | ||
Some(create_decimal_type(result_precision, result_scale)) | ||
} | ||
Operator::Modulo => { | ||
// max(s1, s2) | ||
let result_scale = *s1.max(s2); | ||
// min(p1-s1, p2-s2) + max(s1, s2) | ||
let result_precision = result_scale + (*p1 - *s1).min(*p2 - *s2); | ||
Some(create_decimal_type(result_precision, result_scale)) | ||
} | ||
_ => unreachable!(), | ||
} | ||
} | ||
_ => unreachable!(), | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::arrow::datatypes::DataType; | ||
use crate::error::{DataFusionError, Result}; | ||
use crate::logical_plan::Operator; | ||
use crate::physical_plan::coercion_rule::binary_rule::coerce_types; | ||
use crate::physical_plan::coercion_rule::binary_rule::{ | ||
coerce_types, coercion_decimal_mathematics_type, convert_numeric_type_to_decimal, | ||
}; | ||
|
||
#[test] | ||
|
||
|
@@ -226,4 +355,70 @@ mod tests { | |
assert!(result_type.is_err()); | ||
Ok(()) | ||
} | ||
|
||
#[test]
fn test_decimal_mathematics_op_type() {
    // Each supported numeric type and the decimal type it widens to.
    let conversions = [
        (DataType::Int8, DataType::Decimal(3, 0)),
        (DataType::Int16, DataType::Decimal(5, 0)),
        (DataType::Int32, DataType::Decimal(10, 0)),
        (DataType::Int64, DataType::Decimal(20, 0)),
        (DataType::Float32, DataType::Decimal(14, 7)),
        (DataType::Float64, DataType::Decimal(30, 15)),
    ];
    for (numeric_type, expected) in conversions.iter() {
        assert_eq!(
            convert_numeric_type_to_decimal(numeric_type).unwrap(),
            *expected
        );
    }

    // Result type of each arithmetic operator for a fixed pair of decimals.
    let left_decimal_type = DataType::Decimal(10, 3);
    let right_decimal_type = DataType::Decimal(20, 4);
    let expectations = [
        (Operator::Plus, DataType::Decimal(21, 4)),
        (Operator::Minus, DataType::Decimal(21, 4)),
        (Operator::Multiply, DataType::Decimal(31, 7)),
        (Operator::Divide, DataType::Decimal(35, 24)),
        (Operator::Modulo, DataType::Decimal(11, 4)),
    ];
    for (op, expected) in expectations.iter() {
        let result = coercion_decimal_mathematics_type(
            op,
            &left_decimal_type,
            &right_decimal_type,
        );
        assert_eq!(*expected, result.unwrap());
    }
}
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe this would be better called `coerce_numeric_type_to_decimal`?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done