Skip to content

Commit

Permalink
support mathematics operation for decimal data type (#1554)
Browse files Browse the repository at this point in the history
* support comparison for decimal data type

* add more test for decimal comparison; refactor binary test

* change cargo temporarily

* support decimal to arithmetic operation

* minor fix

* add doc/comment for coercion rule

* address comments
  • Loading branch information
liukun4515 authored Jan 18, 2022
1 parent 82e8003 commit c549d51
Show file tree
Hide file tree
Showing 2 changed files with 480 additions and 17 deletions.
205 changes: 200 additions & 5 deletions datafusion/src/physical_plan/coercion_rule/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ use crate::arrow::datatypes::DataType;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::Operator;
use crate::physical_plan::expressions::coercion::{
dictionary_coercion, eq_coercion, is_numeric, like_coercion, numerical_coercion,
string_coercion, temporal_coercion,
dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion,
temporal_coercion,
};
use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128};

/// Coercion rules for all binary operators. Returns the output type
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
Expand All @@ -49,12 +50,11 @@ pub(crate) fn coerce_types(
Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type),
// for math expressions, the final value of the coercion is also the return type
// because coercion favours higher information types
// TODO: support decimal data type
Operator::Plus
| Operator::Minus
| Operator::Modulo
| Operator::Divide
| Operator::Multiply => numerical_coercion(lhs_type, rhs_type),
| Operator::Multiply => mathematics_numerical_coercion(op, lhs_type, rhs_type),
Operator::RegexMatch
| Operator::RegexIMatch
| Operator::RegexNotMatch
Expand Down Expand Up @@ -162,12 +162,141 @@ fn get_comparison_common_decimal_type(
}
}

// Convert the numeric data type to the decimal data type.
// Now, we just support the signed integer type and floating-point type.
fn coerce_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
match numeric_type {
DataType::Int8 => Some(DataType::Decimal(3, 0)),
DataType::Int16 => Some(DataType::Decimal(5, 0)),
DataType::Int32 => Some(DataType::Decimal(10, 0)),
DataType::Int64 => Some(DataType::Decimal(20, 0)),
// TODO if we convert the floating-point data to the decimal type, it maybe overflow.
DataType::Float32 => Some(DataType::Decimal(14, 7)),
DataType::Float64 => Some(DataType::Decimal(30, 15)),
_ => None,
}
}

fn mathematics_numerical_coercion(
mathematics_op: &Operator,
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;

// error on any non-numeric type
if !is_numeric(lhs_type) || !is_numeric(rhs_type) {
return None;
};

// same type => all good
if lhs_type == rhs_type {
return Some(lhs_type.clone());
}

// these are ordered from most informative to least informative so
// that the coercion removes the least amount of information
match (lhs_type, rhs_type) {
(Decimal(_, _), Decimal(_, _)) => {
coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type)
}
(Decimal(_, _), _) => {
let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type);
match converted_decimal_type {
None => None,
Some(right_decimal_type) => coercion_decimal_mathematics_type(
mathematics_op,
lhs_type,
&right_decimal_type,
),
}
}
(_, Decimal(_, _)) => {
let converted_decimal_type = coerce_numeric_type_to_decimal(lhs_type);
match converted_decimal_type {
None => None,
Some(left_decimal_type) => coercion_decimal_mathematics_type(
mathematics_op,
&left_decimal_type,
rhs_type,
),
}
}
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
(Int64, _) | (_, Int64) => Some(Int64),
(Int32, _) | (_, Int32) => Some(Int32),
(Int16, _) | (_, Int16) => Some(Int16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt64, _) | (_, UInt64) => Some(UInt64),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(UInt16, _) | (_, UInt16) => Some(UInt16),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
}
}

fn create_decimal_type(precision: usize, scale: usize) -> DataType {
DataType::Decimal(
MAX_PRECISION_FOR_DECIMAL128.min(precision),
MAX_SCALE_FOR_DECIMAL128.min(scale),
)
}

fn coercion_decimal_mathematics_type(
mathematics_op: &Operator,
left_decimal_type: &DataType,
right_decimal_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (left_decimal_type, right_decimal_type) {
// The coercion rule from spark
// https://github.com/apache/spark/blob/c20af535803a7250fef047c2bf0fe30be242369d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala#L35
(Decimal(p1, s1), Decimal(p2, s2)) => {
match mathematics_op {
Operator::Plus | Operator::Minus => {
// max(s1, s2)
let result_scale = *s1.max(s2);
// max(s1, s2) + max(p1-s1, p2-s2) + 1
let result_precision = result_scale + (*p1 - *s1).max(*p2 - *s2) + 1;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Multiply => {
// s1 + s2
let result_scale = *s1 + *s2;
// p1 + p2 + 1
let result_precision = *p1 + *p2 + 1;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Divide => {
// max(6, s1 + p2 + 1)
let result_scale = 6.max(*s1 + *p2 + 1);
// p1 - s1 + s2 + max(6, s1 + p2 + 1)
let result_precision = result_scale + *p1 - *s1 + *s2;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Modulo => {
// max(s1, s2)
let result_scale = *s1.max(s2);
// min(p1-s1, p2-s2) + max(s1, s2)
let result_precision = result_scale + (*p1 - *s1).min(*p2 - *s2);
Some(create_decimal_type(result_precision, result_scale))
}
_ => unreachable!(),
}
}
_ => unreachable!(),
}
}

#[cfg(test)]
mod tests {
use crate::arrow::datatypes::DataType;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::Operator;
use crate::physical_plan::coercion_rule::binary_rule::coerce_types;
use crate::physical_plan::coercion_rule::binary_rule::{
coerce_numeric_type_to_decimal, coerce_types, coercion_decimal_mathematics_type,
};

#[test]

Expand Down Expand Up @@ -226,4 +355,70 @@ mod tests {
assert!(result_type.is_err());
Ok(())
}

#[test]
fn test_decimal_mathematics_op_type() {
assert_eq!(
coerce_numeric_type_to_decimal(&DataType::Int8).unwrap(),
DataType::Decimal(3, 0)
);
assert_eq!(
coerce_numeric_type_to_decimal(&DataType::Int16).unwrap(),
DataType::Decimal(5, 0)
);
assert_eq!(
coerce_numeric_type_to_decimal(&DataType::Int32).unwrap(),
DataType::Decimal(10, 0)
);
assert_eq!(
coerce_numeric_type_to_decimal(&DataType::Int64).unwrap(),
DataType::Decimal(20, 0)
);
assert_eq!(
coerce_numeric_type_to_decimal(&DataType::Float32).unwrap(),
DataType::Decimal(14, 7)
);
assert_eq!(
coerce_numeric_type_to_decimal(&DataType::Float64).unwrap(),
DataType::Decimal(30, 15)
);

let op = Operator::Plus;
let left_decimal_type = DataType::Decimal(10, 3);
let right_decimal_type = DataType::Decimal(20, 4);
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(21, 4), result.unwrap());
let op = Operator::Minus;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(21, 4), result.unwrap());
let op = Operator::Multiply;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(31, 7), result.unwrap());
let op = Operator::Divide;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(35, 24), result.unwrap());
let op = Operator::Modulo;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(11, 4), result.unwrap());
}
}
Loading

0 comments on commit c549d51

Please sign in to comment.