From bf63119ac36e1ef2ec229752eee892cb59845df6 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 4 Jan 2022 15:40:09 +0800 Subject: [PATCH] support decimal to arithmetic operation --- ballista/rust/core/Cargo.toml | 3 +- .../core/src/serde/logical_plan/to_proto.rs | 2 +- ballista/rust/executor/Cargo.toml | 6 +- datafusion-cli/Cargo.toml | 3 +- datafusion-examples/Cargo.toml | 3 +- datafusion/Cargo.toml | 6 +- .../coercion_rule/binary_rule.rs | 205 ++++++++++++- .../src/physical_plan/expressions/binary.rs | 271 ++++++++++++++++++ 8 files changed, 479 insertions(+), 20 deletions(-) diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 12b4e6f0a281..fc357048b37b 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -43,8 +43,7 @@ tonic = "0.5" uuid = { version = "0.8", features = ["v4"] } chrono = { version = "0.4", default-features = false } -#arrow-flight = { version = "6.4.0" } -arrow-flight = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow-flight" } +arrow-flight = { version = "6.4.0" } datafusion = { path = "../../../datafusion", version = "6.0.0" } [dev-dependencies] diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 0058bed26505..e6aa0fb1ba6b 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -60,7 +60,7 @@ impl protobuf::IntervalUnit { match interval_unit { IntervalUnit::YearMonth => protobuf::IntervalUnit::YearMonth, IntervalUnit::DayTime => protobuf::IntervalUnit::DayTime, - IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano, + IntervalUnit::MonthDayNano => protobuf::IntervalUnit::MonthDayNano, } } diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml index 8033823e26d2..00f3aab745ff 100644 --- a/ballista/rust/executor/Cargo.toml +++ b/ballista/rust/executor/Cargo.toml @@ -29,10 +29,8 @@ edition = "2018" snmalloc = ["snmalloc-rs"] [dependencies] -#arrow = { version = "6.4.0" } -#arrow-flight = { version = "6.4.0" } -arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow" } -arrow-flight = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow-flight" } +arrow = { version = "6.4.0" } +arrow-flight = { version = "6.4.0" } anyhow = "1" async-trait = "0.1.36" ballista-core = { path = "../core", version = "0.6.0" } diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml index 14cefcf7acf1..394bd1e3a29b 100644 --- a/datafusion-cli/Cargo.toml +++ b/datafusion-cli/Cargo.toml @@ -31,6 +31,5 @@ clap = "2.33" rustyline = "9.0" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } datafusion = { path = "../datafusion", version = "6.0.0" } -#arrow = { version = "6.4.0" } +arrow = { version = "6.4.0" } ballista = { path = "../ballista/rust/client", version = "0.6.0" } -arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow" } diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index e24d5a5d9820..f7ef66d99bde 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -34,9 +34,8 @@ path = "examples/avro_sql.rs" required-features = ["datafusion/avro"] [dev-dependencies] -#arrow-flight = { version = "6.4.0" } +arrow-flight = { version = "6.4.0" } datafusion = { path = "../datafusion" } -arrow-flight = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow-flight" } prost = "0.8" tonic = "0.5" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] } diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 2e54389afe2b..b9192826120e 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -52,10 +52,8 @@ avro = ["avro-rs", "num-traits"] [dependencies] ahash = { version = "0.7", default-features = false } hashbrown = { version = "0.11", features = ["raw"] } -#arrow = { version = "6.4.0", features = ["prettyprint"] } -#parquet = { version = "6.4.0", features = ["arrow"] } -arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow", features = ["prettyprint"] } -parquet = { path = "/Users/kliu3/Documents/github/arrow-rs/parquet", features = ["arrow"] } +arrow = { version = "6.4.0", features = ["prettyprint"] } +parquet = { version = "6.4.0", features = ["arrow"] } sqlparser = "0.13" paste = "^1.0" num_cpus = "1.13.0" diff --git a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs index bf836f9e7547..1ac295723fc6 100644 --- a/datafusion/src/physical_plan/coercion_rule/binary_rule.rs +++ b/datafusion/src/physical_plan/coercion_rule/binary_rule.rs @@ -2,9 +2,10 @@ use crate::arrow::datatypes::DataType; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Operator; use crate::physical_plan::expressions::coercion::{ - dictionary_coercion, eq_coercion, is_numeric, like_coercion, numerical_coercion, - string_coercion, temporal_coercion, + dictionary_coercion, eq_coercion, is_numeric, like_coercion, string_coercion, + temporal_coercion, }; +use crate::scalar::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128}; /// Coercion rules for all binary operators. Returns the output type /// of applying `op` to an argument of `lhs_type` and `rhs_type`. @@ -30,12 +31,11 @@ pub(crate) fn coerce_types( Operator::Like | Operator::NotLike => like_coercion(lhs_type, rhs_type), // for math expressions, the final value of the coercion is also the return type // because coercion favours higher information types - // TODO: support decimal data type Operator::Plus | Operator::Minus | Operator::Modulo | Operator::Divide - | Operator::Multiply => numerical_coercion(lhs_type, rhs_type), + | Operator::Multiply => mathematics_numerical_coercion(op, lhs_type, rhs_type), Operator::RegexMatch | Operator::RegexIMatch | Operator::RegexNotMatch @@ -143,12 +143,141 @@ fn get_comparison_common_decimal_type( } } +// Convert the numeric data type to the decimal data type. +// Now, we just support the signed integer type and floating-point type. +fn convert_numeric_type_to_decimal(numeric_type: &DataType) -> Option { + match numeric_type { + DataType::Int8 => Some(DataType::Decimal(3, 0)), + DataType::Int16 => Some(DataType::Decimal(5, 0)), + DataType::Int32 => Some(DataType::Decimal(10, 0)), + DataType::Int64 => Some(DataType::Decimal(20, 0)), + // TODO if we convert the floating-point data to the decimal type, it maybe overflow. + DataType::Float32 => Some(DataType::Decimal(14, 7)), + DataType::Float64 => Some(DataType::Decimal(30, 15)), + _ => None, + } +} + +fn mathematics_numerical_coercion( + mathematics_op: &Operator, + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + + // error on any non-numeric type + if !is_numeric(lhs_type) || !is_numeric(rhs_type) { + return None; + }; + + // same type => all good + if lhs_type == rhs_type { + return Some(lhs_type.clone()); + } + + // these are ordered from most informative to least informative so + // that the coercion removes the least amount of information + match (lhs_type, rhs_type) { + (Decimal(_, _), Decimal(_, _)) => { + coercion_decimal_mathematics_type(mathematics_op, lhs_type, rhs_type) + } + (Decimal(_, _), _) => { + let converted_decimal_type = convert_numeric_type_to_decimal(rhs_type); + if converted_decimal_type.is_none() { + None + } else { + coercion_decimal_mathematics_type( + mathematics_op, + lhs_type, + &converted_decimal_type.unwrap(), + ) + } + } + (_, Decimal(_, _)) => { + let converted_decimal_type = convert_numeric_type_to_decimal(lhs_type); + if converted_decimal_type.is_none() { + None + } else { + coercion_decimal_mathematics_type( + mathematics_op, + &converted_decimal_type.unwrap(), + rhs_type, + ) + } + } + (Float64, _) | (_, Float64) => Some(Float64), + (_, Float32) | (Float32, _) => Some(Float32), + (Int64, _) | (_, Int64) => Some(Int64), + (Int32, _) | (_, Int32) => Some(Int32), + (Int16, _) | (_, Int16) => Some(Int16), + (Int8, _) | (_, Int8) => Some(Int8), + (UInt64, _) | (_, UInt64) => Some(UInt64), + (UInt32, _) | (_, UInt32) => Some(UInt32), + (UInt16, _) | (_, UInt16) => Some(UInt16), + (UInt8, _) | (_, UInt8) => Some(UInt8), + _ => None, + } +} + +fn create_decimal_type(precision: usize, scale: usize) -> DataType { + DataType::Decimal( + MAX_PRECISION_FOR_DECIMAL128.min(precision), + MAX_SCALE_FOR_DECIMAL128.min(scale), + ) +} + +fn coercion_decimal_mathematics_type( + mathematics_op: &Operator, + left_decimal_type: &DataType, + right_decimal_type: &DataType, +) -> Option { + use arrow::datatypes::DataType::*; + match (left_decimal_type, right_decimal_type) { + (Decimal(p1, s1), Decimal(p2, s2)) => { + match mathematics_op { + Operator::Plus | Operator::Minus => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // max(s1, s2) + max(p1-s1, p2-s2) + 1 + let result_precision = result_scale + (*p1 - *s1).max(*p2 - *s2) + 1; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Multiply => { + // s1 + s2 + let result_scale = *s1 + *s2; + // p1 + p2 + 1 + let result_precision = *p1 + *p2 + 1; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Divide => { + // max(6, s1 + p2 + 1) + let result_scale = 6.max(*s1 + *p2 + 1); + // p1 - s1 + s2 + max(6, s1 + p2 + 1) + let result_precision = result_scale + *p1 - *s1 + *s2; + Some(create_decimal_type(result_precision, result_scale)) + } + Operator::Modulo => { + // max(s1, s2) + let result_scale = *s1.max(s2); + // min(p1-s1, p2-s2) + max(s1, s2) + let result_precision = result_scale + (*p1 - *s1).min(*p2 - *s2); + Some(create_decimal_type(result_precision, result_scale)) + } + _ => unreachable!(), + } + } + _ => unreachable!(), + } +} + #[cfg(test)] mod tests { use crate::arrow::datatypes::DataType; use crate::error::{DataFusionError, Result}; use crate::logical_plan::Operator; - use crate::physical_plan::coercion_rule::binary_rule::coerce_types; + use crate::physical_plan::coercion_rule::binary_rule::{ + coerce_types, coercion_decimal_mathematics_type, convert_numeric_type_to_decimal, + }; #[test] @@ -207,4 +336,70 @@ mod tests { assert!(result_type.is_err()); Ok(()) } + + #[test] + fn test_decimal_mathematics_op_type() { + assert_eq!( + convert_numeric_type_to_decimal(&DataType::Int8).unwrap(), + DataType::Decimal(3, 0) + ); + assert_eq!( + convert_numeric_type_to_decimal(&DataType::Int16).unwrap(), + DataType::Decimal(5, 0) + ); + assert_eq!( + convert_numeric_type_to_decimal(&DataType::Int32).unwrap(), + DataType::Decimal(10, 0) + ); + assert_eq!( + convert_numeric_type_to_decimal(&DataType::Int64).unwrap(), + DataType::Decimal(20, 0) + ); + assert_eq!( + convert_numeric_type_to_decimal(&DataType::Float32).unwrap(), + DataType::Decimal(14, 7) + ); + assert_eq!( + convert_numeric_type_to_decimal(&DataType::Float64).unwrap(), + DataType::Decimal(30, 15) + ); + + let op = Operator::Plus; + let left_decimal_type = DataType::Decimal(10, 3); + let right_decimal_type = DataType::Decimal(20, 4); + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(21, 4), result.unwrap()); + let op = Operator::Minus; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(21, 4), result.unwrap()); + let op = Operator::Multiply; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(31, 7), result.unwrap()); + let op = Operator::Divide; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(35, 24), result.unwrap()); + let op = Operator::Modulo; + let result = coercion_decimal_mathematics_type( + &op, + &left_decimal_type, + &right_decimal_type, + ); + assert_eq!(DataType::Decimal(11, 4), result.unwrap()); + } } diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs index 0574da763569..3a7dcbd2ba20 100644 --- a/datafusion/src/physical_plan/expressions/binary.rs +++ b/datafusion/src/physical_plan/expressions/binary.rs @@ -41,6 +41,7 @@ use arrow::compute::kernels::comparison::{ regexp_is_match_utf8_scalar, }; use arrow::datatypes::{ArrowNumericType, DataType, Schema, TimeUnit}; +use arrow::error::ArrowError::DivideByZero; use arrow::record_batch::RecordBatch; use crate::error::{DataFusionError, Result}; @@ -263,6 +264,80 @@ fn is_not_distinct_from_decimal( Ok(bool_builder.finish()) } +fn add_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(left.value(i) + right.value(i))?; + } + } + Ok(decimal_builder.finish()) +} + +fn subtract_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(left.value(i) - right.value(i))?; + } + } + Ok(decimal_builder.finish()) +} + +fn multiply_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + let divide = 10_i128.pow(left.scale() as u32); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else { + decimal_builder.append_value(left.value(i) * right.value(i) / divide)?; + } + } + Ok(decimal_builder.finish()) +} + +fn divide_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + let mul = 10_f64.powi(left.scale() as i32); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else if right.value(i) == 0 { + return Err(DataFusionError::ArrowError(DivideByZero)); + } else { + let l_value = left.value(i) as f64; + let r_value = right.value(i) as f64; + let result = ((l_value / r_value) * mul) as i128; + decimal_builder.append_value(result)?; + } + } + Ok(decimal_builder.finish()) +} + +fn modulus_decimal(left: &DecimalArray, right: &DecimalArray) -> Result { + let mut decimal_builder = + DecimalBuilder::new(left.len(), left.precision(), left.scale()); + for i in 0..left.len() { + if left.is_null(i) || right.is_null(i) { + decimal_builder.append_null()?; + } else if right.value(i) == 0 { + return Err(DataFusionError::ArrowError(DivideByZero)); + } else { + decimal_builder.append_value(left.value(i) % right.value(i))?; + } + } + Ok(decimal_builder.finish()) +} + /// Binary expression #[derive(Debug)] pub struct BinaryExpr { @@ -472,6 +547,9 @@ macro_rules! binary_string_array_op { macro_rules! binary_primitive_array_op { ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ match $LEFT.data_type() { + // TODO support decimal type + // which is not the primitive type + DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray), DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array), @@ -2570,4 +2648,197 @@ mod tests { Ok(()) } + + #[test] + fn arithmetic_decimal_op_test() -> Result<()> { + let value_i128: i128 = 123; + let left_decimal_array = create_decimal_array( + &[ + Some(value_i128), + None, + Some(value_i128 - 1), + Some(value_i128 + 1), + ], + 25, + 3, + )?; + let right_decimal_array = create_decimal_array( + &[ + Some(value_i128), + Some(value_i128), + Some(value_i128), + Some(value_i128), + ], + 25, + 3, + )?; + // add + let result = add_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = + create_decimal_array(&[Some(246), None, Some(245), Some(247)], 25, 3)?; + assert_eq!(expect, result); + // subtract + let result = subtract_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array(&[Some(0), None, Some(-1), Some(1)], 25, 3)?; + assert_eq!(expect, result); + // multiply + let result = multiply_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array(&[Some(15), None, Some(15), Some(15)], 25, 3)?; + assert_eq!(expect, result); + // divide + let left_decimal_array = create_decimal_array( + &[Some(1234567), None, Some(1234567), Some(1234567)], + 25, + 3, + )?; + let right_decimal_array = + create_decimal_array(&[Some(10), Some(100), Some(55), Some(-123)], 25, 3)?; + let result = divide_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array( + &[Some(123456700), None, Some(22446672), Some(-10037130)], + 25, + 3, + )?; + assert_eq!(expect, result); + // modulus + let result = modulus_decimal(&left_decimal_array, &right_decimal_array)?; + let expect = create_decimal_array(&[Some(7), None, Some(37), Some(16)], 25, 3)?; + assert_eq!(expect, result); + + Ok(()) + } + + fn apply_arithmetic_op( + schema: &SchemaRef, + left: &ArrayRef, + right: &ArrayRef, + op: Operator, + expected: ArrayRef, + ) -> Result<()> { + let arithmetic_op = + binary_simple(col("a", &schema)?, op, col("b", &schema)?, &schema); + let data: Vec = vec![left.clone(), right.clone()]; + let batch = RecordBatch::try_new(schema.clone(), data)?; + let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows()); + + assert_eq!(result.as_ref(), expected.as_ref()); + Ok(()) + } + + #[test] + fn arithmetic_decimal_expr_test() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let value: i128 = 123; + let decimal_array = Arc::new(create_decimal_array( + &[ + Some(value as i128), // 1.23 + None, + Some((value - 1) as i128), // 1.22 + Some((value + 1) as i128), // 1.24 + ], + 10, + 2, + )?) as ArrayRef; + let int32_array = Arc::new(Int32Array::from(vec![ + Some(123), + Some(122), + Some(123), + Some(124), + ])) as ArrayRef; + + // add: Int32array add decimal array + let expect = Arc::new(create_decimal_array( + &[Some(12423), None, Some(12422), Some(12524)], + 13, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Plus, + expect, + ) + .unwrap(); + + // subtract: decimal array subtract int32 array + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, true), + Field::new("a", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[Some(-12177), None, Some(-12178), Some(-12276)], + 13, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Minus, + expect, + ) + .unwrap(); + + // multiply: decimal array multiply int32 array + let expect = Arc::new(create_decimal_array( + &[Some(15129), None, Some(15006), Some(15376)], + 21, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Multiply, + expect, + ) + .unwrap(); + // divide: int32 array divide decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[ + Some(10000000000000), + None, + Some(10081967213114), + Some(10000000000000), + ], + 23, + 11, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Divide, + expect, + ) + .unwrap(); + // modulus: int32 array modulus decimal array + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Decimal(10, 2), true), + ])); + let expect = Arc::new(create_decimal_array( + &[Some(000), None, Some(100), Some(000)], + 10, + 2, + )?) as ArrayRef; + apply_arithmetic_op( + &schema, + &int32_array, + &decimal_array, + Operator::Modulo, + expect, + ) + .unwrap(); + + Ok(()) + } }