Skip to content

Commit

Permalink
support decimal to arithmetic operation
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Jan 17, 2022
1 parent a2551db commit bf63119
Show file tree
Hide file tree
Showing 8 changed files with 479 additions and 20 deletions.
3 changes: 1 addition & 2 deletions ballista/rust/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ tonic = "0.5"
uuid = { version = "0.8", features = ["v4"] }
chrono = { version = "0.4", default-features = false }

#arrow-flight = { version = "6.4.0" }
arrow-flight = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow-flight" }
arrow-flight = { version = "6.4.0" }
datafusion = { path = "../../../datafusion", version = "6.0.0" }

[dev-dependencies]
Expand Down
2 changes: 1 addition & 1 deletion ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl protobuf::IntervalUnit {
match interval_unit {
IntervalUnit::YearMonth => protobuf::IntervalUnit::YearMonth,
IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime,
IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano,
IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano,
}
}

Expand Down
6 changes: 2 additions & 4 deletions ballista/rust/executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ edition = "2018"
snmalloc = ["snmalloc-rs"]

[dependencies]
#arrow = { version = "6.4.0" }
#arrow-flight = { version = "6.4.0" }
arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow" }
arrow-flight = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow-flight" }
arrow = { version = "6.4.0" }
arrow-flight = { version = "6.4.0" }
anyhow = "1"
async-trait = "0.1.36"
ballista-core = { path = "../core", version = "0.6.0" }
Expand Down
3 changes: 1 addition & 2 deletions datafusion-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,5 @@ clap = "2.33"
rustyline = "9.0"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
datafusion = { path = "../datafusion", version = "6.0.0" }
#arrow = { version = "6.4.0" }
arrow = { version = "6.4.0" }
ballista = { path = "../ballista/rust/client", version = "0.6.0" }
arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow" }
3 changes: 1 addition & 2 deletions datafusion-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ path = "examples/avro_sql.rs"
required-features = ["datafusion/avro"]

[dev-dependencies]
#arrow-flight = { version = "6.4.0" }
arrow-flight = { version = "6.4.0" }
datafusion = { path = "../datafusion" }
arrow-flight = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow-flight" }
prost = "0.8"
tonic = "0.5"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
Expand Down
6 changes: 2 additions & 4 deletions datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,8 @@ avro = ["avro-rs", "num-traits"]
[dependencies]
ahash = { version = "0.7", default-features = false }
hashbrown = { version = "0.11", features = ["raw"] }
#arrow = { version = "6.4.0", features = ["prettyprint"] }
#parquet = { version = "6.4.0", features = ["arrow"] }
arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow", features = ["prettyprint"] }
parquet = { path = "/Users/kliu3/Documents/github/arrow-rs/parquet", features = ["arrow"] }
arrow = { version = "6.4.0", features = ["prettyprint"] }
parquet = { version = "6.4.0", features = ["arrow"] }
sqlparser = "0.13"
paste = "^1.0"
num_cpus = "1.13.0"
Expand Down
205 changes: 200 additions & 5 deletions datafusion/src/physical_plan/coercion_rule/binary_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use crate::arrow::datatypes::DataType;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::Operator;
use crate::physical_plan::expressions::coercion::{
dictionary_coercion, eq_coercion, is_numeric, like_coercion, numerical_coercion,
string_coercion, temporal_coercion,
dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion,
temporal_coercion,
};
use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128};

/// Coercion rules for all binary operators. Returns the output type
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
Expand All @@ -30,12 +31,11 @@ pub(crate) fn coerce_types(
Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type),
// for math expressions, the final value of the coercion is also the return type
// because coercion favours higher information types
// TODO: support decimal data type
Operator::Plus
| Operator::Minus
| Operator::Modulo
| Operator::Divide
| Operator::Multiply => numerical_coercion(lhs_type, rhs_type),
| Operator::Multiply => mathematics_numerical_coercion(op, lhs_type, rhs_type),
Operator::RegexMatch
| Operator::RegexIMatch
| Operator::RegexNotMatch
Expand Down Expand Up @@ -143,12 +143,141 @@ fn get_comparison_common_decimal_type(
}
}

// Convert the numeric data type to the decimal data type.
// Now, we just support the signed integer type and floating-point type.
fn convert_numeric_type_to_decimal(numeric_type: &DataType) -> Option<DataType> {
match numeric_type {
DataType::Int8 => Some(DataType::Decimal(3, 0)),
DataType::Int16 => Some(DataType::Decimal(5, 0)),
DataType::Int32 => Some(DataType::Decimal(10, 0)),
DataType::Int64 => Some(DataType::Decimal(20, 0)),
// TODO if we convert the floating-point data to the decimal type, it maybe overflow.
DataType::Float32 => Some(DataType::Decimal(14, 7)),
DataType::Float64 => Some(DataType::Decimal(30, 15)),
_ => None,
}
}

fn mathematics_numerical_coercion(
mathematics_op: &Operator,
lhs_type: &DataType,
rhs_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;

// error on any non-numeric type
if !is_numeric(lhs_type) || !is_numeric(rhs_type) {
return None;
};

// same type => all good
if lhs_type == rhs_type {
return Some(lhs_type.clone());
}

// these are ordered from most informative to least informative so
// that the coercion removes the least amount of information
match (lhs_type, rhs_type) {
(Decimal(_, _), Decimal(_, _)) => {
coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type)
}
(Decimal(_, _), _) => {
let converted_decimal_type = convert_numeric_type_to_decimal(rhs_type);
if converted_decimal_type.is_none() {
None
} else {
coercion_decimal_mathematics_type(
mathematics_op,
lhs_type,
&converted_decimal_type.unwrap(),
)
}
}
(_, Decimal(_, _)) => {
let converted_decimal_type = convert_numeric_type_to_decimal(lhs_type);
if converted_decimal_type.is_none() {
None
} else {
coercion_decimal_mathematics_type(
mathematics_op,
&converted_decimal_type.unwrap(),
rhs_type,
)
}
}
(Float64, _) | (_, Float64) => Some(Float64),
(_, Float32) | (Float32, _) => Some(Float32),
(Int64, _) | (_, Int64) => Some(Int64),
(Int32, _) | (_, Int32) => Some(Int32),
(Int16, _) | (_, Int16) => Some(Int16),
(Int8, _) | (_, Int8) => Some(Int8),
(UInt64, _) | (_, UInt64) => Some(UInt64),
(UInt32, _) | (_, UInt32) => Some(UInt32),
(UInt16, _) | (_, UInt16) => Some(UInt16),
(UInt8, _) | (_, UInt8) => Some(UInt8),
_ => None,
}
}

fn create_decimal_type(precision: usize, scale: usize) -> DataType {
DataType::Decimal(
MAX_PRECISION_FOR_DECIMAL128.min(precision),
MAX_SCALE_FOR_DECIMAL128.min(scale),
)
}

fn coercion_decimal_mathematics_type(
mathematics_op: &Operator,
left_decimal_type: &DataType,
right_decimal_type: &DataType,
) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (left_decimal_type, right_decimal_type) {
(Decimal(p1, s1), Decimal(p2, s2)) => {
match mathematics_op {
Operator::Plus | Operator::Minus => {
// max(s1, s2)
let result_scale = *s1.max(s2);
// max(s1, s2) + max(p1-s1, p2-s2) + 1
let result_precision = result_scale + (*p1 - *s1).max(*p2 - *s2) + 1;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Multiply => {
// s1 + s2
let result_scale = *s1 + *s2;
// p1 + p2 + 1
let result_precision = *p1 + *p2 + 1;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Divide => {
// max(6, s1 + p2 + 1)
let result_scale = 6.max(*s1 + *p2 + 1);
// p1 - s1 + s2 + max(6, s1 + p2 + 1)
let result_precision = result_scale + *p1 - *s1 + *s2;
Some(create_decimal_type(result_precision, result_scale))
}
Operator::Modulo => {
// max(s1, s2)
let result_scale = *s1.max(s2);
// min(p1-s1, p2-s2) + max(s1, s2)
let result_precision = result_scale + (*p1 - *s1).min(*p2 - *s2);
Some(create_decimal_type(result_precision, result_scale))
}
_ => unreachable!(),
}
}
_ => unreachable!(),
}
}

#[cfg(test)]
mod tests {
use crate::arrow::datatypes::DataType;
use crate::error::{DataFusionError, Result};
use crate::logical_plan::Operator;
use crate::physical_plan::coercion_rule::binary_rule::coerce_types;
use crate::physical_plan::coercion_rule::binary_rule::{
coerce_types, coercion_decimal_mathematics_type, convert_numeric_type_to_decimal,
};

#[test]

Expand Down Expand Up @@ -207,4 +336,70 @@ mod tests {
assert!(result_type.is_err());
Ok(())
}

#[test]
fn test_decimal_mathematics_op_type() {
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Int8).unwrap(),
DataType::Decimal(3, 0)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Int16).unwrap(),
DataType::Decimal(5, 0)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Int32).unwrap(),
DataType::Decimal(10, 0)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Int64).unwrap(),
DataType::Decimal(20, 0)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Float32).unwrap(),
DataType::Decimal(14, 7)
);
assert_eq!(
convert_numeric_type_to_decimal(&DataType::Float64).unwrap(),
DataType::Decimal(30, 15)
);

let op = Operator::Plus;
let left_decimal_type = DataType::Decimal(10, 3);
let right_decimal_type = DataType::Decimal(20, 4);
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(21, 4), result.unwrap());
let op = Operator::Minus;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(21, 4), result.unwrap());
let op = Operator::Multiply;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(31, 7), result.unwrap());
let op = Operator::Divide;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(35, 24), result.unwrap());
let op = Operator::Modulo;
let result = coercion_decimal_mathematics_type(
&op,
&left_decimal_type,
&right_decimal_type,
);
assert_eq!(DataType::Decimal(11, 4), result.unwrap());
}
}
Loading

0 comments on commit bf63119

Please sign in to comment.