Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support mathematics operation for decimal data type #1554

Merged
merged 9 commits into from
Jan 18, 2022
205 changes: 200 additions & 5 deletions datafusion/src/physical_plan/coercion_rule/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use crate::arrow::datatypes::DataType;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::Operator;
use crate::physical_plan::expressions::coercion::{
dictionary_coercion, eq_coercion, is_numeric, like_coercion, numerical_coercion,
string_coercion, temporal_coercion,
dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion,
temporal_coercion,
};
use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128};

/// Coercion rules for all binary operators. Returns the output type
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
Expand All @@ -49,12 +50,11 @@ pub(crate) fn coerce_types(
Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type),
// for math expressions, the final value of the coercion is also the return type
// because coercion favours higher information types
// TODO: support decimal data type
Operator::Plus
| Operator::Minus
| Operator::Modulo
| Operator::Divide
| Operator::Multiply => numerical_coercion(lhs_type, rhs_type),
| Operator::Multiply => mathematics_numerical_coercion(op, lhs_type, rhs_type),
Operator::RegexMatch
| Operator::RegexIMatch
| Operator::RegexNotMatch
Expand Down Expand Up @@ -162,12 +162,141 @@ fn get_comparison_common_decimal_type(
}
}

// Convert the numeric data type to the decimal data type.
// Now, we just support the signed integer type and floating-point type.
fn convert_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe this would be better called coerce_numeric_type_to_decimal?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

match numeric_type {
DataType::Int8 => Some(DataType::Decimal(3, 0)),
DataType::Int16 => Some(DataType::Decimal(5, 0)),
DataType::Int32 => Some(DataType::Decimal(10, 0)),
DataType::Int64 => Some(DataType::Decimal(20, 0)),
// TODO if we convert the floating-point data to the decimal type, it maybe overflow.
DataType::Float32 => Some(DataType::Decimal(14, 7)),
DataType::Float64 => Some(DataType::Decimal(30, 15)),
_ => None,
}
}

fn mathematics_numerical_coercion(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function looks almost the same as numerical_coercion -- which is used for equality

What is the reason for making a new function rather than extending numerical_coercion with Decimal support?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
The numerical_coercion is used by other coercion functions:

  1. numerical_coercion->eq_coercion -> Operator::IsDistinctFrom | Operator::IsNotDistinctFrom =>
  2. numerical_coercion -> dictionary_value_coercion
    If I extend the numerical_coercion, it may take a side effect to other coercion logical.
    I will refactor and merge these two together and make it clear if there is no conflict.
    @alamb

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thank you

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did you try extending numerical_coercion (and if so, did it have any side effects)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alamb I will extend it in the follow-up pull request and support dict and distinct with decimal.
In the follow-up pull request, I think we can refactor out a common part.

mathematics_op: &Operator,
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;

// error on any non-numeric type
if !is_numeric(lhs_type) || !is_numeric(rhs_type) {
return None;
};

// same type => all good
if lhs_type == rhs_type {
return Some(lhs_type.clone());
}

// these are ordered from most informative to least informative so
// that the coercion removes the least amount of information
match (lhs_type, rhs_type) {
(Decimal(_, _), Decimal(_, _)) => {
coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type)
}
(Decimal(_, _), _) => {
let converted_decimal_type = convert_numeric_type_to_decimal(rhs_type);
match converted_decimal_type {
None => None,
Some(right_decimal_type) => coercion_decimal_mathematics_type(
mathematics_op,
lhs_type,
&right_decimal_type,
),
}
}
(_, Decimal(_, _)) => {
let converted_decimal_type = convert_numeric_type_to_decimal(lhs_type);
match converted_decimal_type {
None => None,
Some(left_decimal_type) => coercion_decimal_mathematics_type(
mathematics_op,
&left_decimal_type,
rhs_type,
),
}
}
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
(Int64, _) | (_, Int64) => Some(Int64),
(Int32, _) | (_, Int32) => Some(Int32),
(Int16, _) | (_, Int16) => Some(Int16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt64, _) | (_, UInt64) => Some(UInt64),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(UInt16, _) | (_, UInt16) => Some(UInt16),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
}
}

fn create_decimal_type(precision: usize, scale: usize) -> DataType {
DataType::Decimal(
MAX_PRECISION_FOR_DECIMAL128.min(precision),
MAX_SCALE_FOR_DECIMAL128.min(scale),
)
}

fn coercion_decimal_mathematics_type(
mathematics_op: &Operator,
left_decimal_type: &DataType,
right_decimal_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (left_decimal_type, right_decimal_type) {
// The coercion rule from spark
// https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35
(Decimal(p1, s1), Decimal(p2, s2)) => {
match mathematics_op {
Operator::Plus | Operator::Minus => {
// max(s1, s2)
let result_scale = *s1.max(s2);
// max(s1, s2) + max(p1-s1, p2-s2) + 1
let result_precision = result_scale + (*p1 - *s1).max(*p2 - *s2) + 1;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Multiply => {
// s1 + s2
let result_scale = *s1 + *s2;
// p1 + p2 + 1
let result_precision = *p1 + *p2 + 1;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Divide => {
// max(6, s1 + p2 + 1)
let result_scale = 6.max(*s1 + *p2 + 1);
// p1 - s1 + s2 + max(6, s1 + p2 + 1)
let result_precision = result_scale + *p1 - *s1 + *s2;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Modulo => {
// max(s1, s2)
let result_scale = *s1.max(s2);
// min(p1-s1, p2-s2) + max(s1, s2)
let result_precision = result_scale + (*p1 - *s1).min(*p2 - *s2);
Some(create_decimal_type(result_precision, result_scale))
}
_ => unreachable!(),
}
}
_ => unreachable!(),
}
}

#[cfg(test)]
mod tests {
use crate::arrow::datatypes::DataType;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::Operator;
use crate::physical_plan::coercion_rule::binary_rule::coerce_types;
use crate::physical_plan::coercion_rule::binary_rule::{
coerce_types, coercion_decimal_mathematics_type, convert_numeric_type_to_decimal,
};

#[test]

Expand Down Expand Up @@ -226,4 +355,70 @@ mod tests {
assert!(result_type.is_err());
Ok(())
}

#[test]
fn test_decimal_mathematics_op_type() {
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Int8).unwrap(),
DataType::Decimal(3, 0)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Int16).unwrap(),
DataType::Decimal(5, 0)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Int32).unwrap(),
DataType::Decimal(10, 0)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Int64).unwrap(),
DataType::Decimal(20, 0)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Float32).unwrap(),
DataType::Decimal(14, 7)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Float64).unwrap(),
DataType::Decimal(30, 15)
);

let op = Operator::Plus;
let left_decimal_type = DataType::Decimal(10, 3);
let right_decimal_type = DataType::Decimal(20, 4);
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(21, 4), result.unwrap());
let op = Operator::Minus;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(21, 4), result.unwrap());
let op = Operator::Multiply;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(31, 7), result.unwrap());
let op = Operator::Divide;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(35, 24), result.unwrap());
let op = Operator::Modulo;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(11, 4), result.unwrap());
}
}
Loading