diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs index 4a5a2c34a393d..e2939595611c5 100644 --- a/datafusion/src/scalar.rs +++ b/datafusion/src/scalar.rs @@ -43,6 +43,8 @@ pub enum ScalarValue { Float32(Option), /// 64bit float Float64(Option), + /// 128bit decimal, using the i128 to represent the decimal + Decimal128(Option, Option, Option), /// signed 8bit int Int8(Option), /// signed 16bit int @@ -100,6 +102,12 @@ impl PartialEq for ScalarValue { // any newly added enum variant will require editing this list // or else face a compile error match (self, other) { + (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { + v1.eq(v2) && p1.eq(p2) && s1.eq(s2) + // TODO how to handle this case: decimal(123,10,1) with decimal(1230,10,2) + // these values present the same value, but has the diff data type + } + (Decimal128(_, _, _), _) => false, (Boolean(v1), Boolean(v2)) => v1.eq(v2), (Boolean(_), _) => false, (Float32(v1), Float32(v2)) => { @@ -171,6 +179,17 @@ impl PartialOrd for ScalarValue { // any newly added enum variant will require editing this list // or else face a compile error match (self, other) { + // TODO decimal type, we just compare the values which have the same precision and scale. + // need to support decimal value with diff precision and scale + // TODO how to compare Null decimal with other decimal + (Decimal128(v1, p1, s1), Decimal128(v2, p2, s2)) => { + if p1.eq(p2) && s1.eq(s2) { + v1.partial_cmp(v2) + } else { + None + } + } + (Decimal128(_, _, _), _) => None, (Boolean(v1), Boolean(v2)) => v1.partial_cmp(v2), (Boolean(_), _) => None, (Float32(v1), Float32(v2)) => { @@ -253,6 +272,11 @@ impl std::hash::Hash for ScalarValue { fn hash(&self, state: &mut H) { use ScalarValue::*; match self { + Decimal128(v, p, s) => { + v.hash(state); + p.hash(state); + s.hash(state) + } Boolean(v) => v.hash(state), Float32(v) => { let v = v.map(OrderedFloat); @@ -453,6 +477,26 @@ macro_rules! eq_array_primitive { } impl ScalarValue { + /// Create a decimal Scalar from value/precision and scale. + pub fn try_new_decimal128( + value: i128, + precision: usize, + scale: usize, + ) -> Result { + // make sure the precision and scale is valid + // TODO const the max precision and min scale + if precision <= 38 && scale <= precision { + return Ok(ScalarValue::Decimal128( + Some(value), + Some(precision), + Some(scale), + )); + } + return Err(DataFusionError::Internal(format!( + "Can not new a decimal type ScalarValue for precision {} and scale {}", + precision, scale + ))); + } /// Getter for the `DataType` of the value pub fn get_datatype(&self) -> DataType { match self { @@ -465,6 +509,14 @@ impl ScalarValue { ScalarValue::Int16(_) => DataType::Int16, ScalarValue::Int32(_) => DataType::Int32, ScalarValue::Int64(_) => DataType::Int64, + ScalarValue::Decimal128(_, Some(precision), Some(scale)) => { + DataType::Decimal(*precision, *scale) + } + ScalarValue::Decimal128(_, _, _) => { + // TODO add the default precision and scale for this case + // DataType::Decimal(38, 0) + panic!("The Decimal Scalar value with invalid precision or scale."); + } ScalarValue::TimestampSecond(_) => { DataType::Timestamp(TimeUnit::Second, None) } @@ -513,6 +565,9 @@ impl ScalarValue { ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)), ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)), ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)), + ScalarValue::Decimal128(Some(v), Some(precision), Some(scale)) => { + ScalarValue::Decimal128(Some(-v), Some(*precision), Some(*scale)) + } _ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self), } } @@ -541,6 +596,7 @@ impl ScalarValue { | ScalarValue::TimestampMicrosecond(None) | ScalarValue::TimestampNanosecond(None) | ScalarValue::Struct(None, _) + | ScalarValue::Decimal128(None, _, _) // For decimal type, the value is null means ScalarValue::Decimal128 is null. ) } @@ -590,7 +646,7 @@ impl ScalarValue { None => { return Err(DataFusionError::Internal( "Empty iterator passed to ScalarValue::iter_to_array".to_string(), - )) + )); } Some(sv) => sv.get_datatype(), }; @@ -706,6 +762,11 @@ impl ScalarValue { } let array: ArrayRef = match &data_type { + DataType::Decimal(precision, scale) => { + let decimal_array = + ScalarValue::iter_to_decimal_array(scalars, precision, scale)?; + Arc::new(decimal_array) + } DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), @@ -831,13 +892,40 @@ impl ScalarValue { "Unsupported creation of {:?} array from ScalarValue {:?}", data_type, scalars.peek() - ))) + ))); } }; Ok(array) } + fn iter_to_decimal_array( + scalars: impl IntoIterator, + precision: &usize, + scale: &usize, + ) -> Result { + // collect the value as Option + let array = scalars + .into_iter() + .map(|element: ScalarValue| match element { + ScalarValue::Decimal128(v1, _, _) => v1, + _ => unreachable!(), + }) + .collect::>>(); + + // build the decimal array using the Decimal Builder + let mut builder = DecimalBuilder::new(array.len(), *precision, *scale); + array.iter().for_each(|element| match element { + None => { + builder.append_null().unwrap(); + } + Some(v) => { + builder.append_value(*v).unwrap(); + } + }); + Ok(builder.finish()) + } + fn iter_to_array_list( scalars: impl IntoIterator, data_type: &DataType, @@ -905,9 +993,36 @@ impl ScalarValue { Ok(list_array) } + fn build_decimal_array( + value: &Option, + precision: &Option, + scale: &Option, + size: usize, + ) -> DecimalArray { + // TODO check the precision and scale + let mut builder = DecimalBuilder::new(size, precision.unwrap(), scale.unwrap()); + match value { + None => { + for _i in 0..size { + builder.append_null().unwrap(); + } + } + Some(v) => { + let v = *v; + for _i in 0..size { + builder.append_value(v).unwrap(); + } + } + }; + builder.finish() + } + /// Converts a scalar value into an array of `size` rows. pub fn to_array_of_size(&self, size: usize) -> ArrayRef { match self { + ScalarValue::Decimal128(e, precision, scale) => { + Arc::new(ScalarValue::build_decimal_array(e, precision, scale, size)) + } ScalarValue::Boolean(e) => { Arc::new(BooleanArray::from(vec![*e; size])) as ArrayRef } @@ -1061,12 +1176,15 @@ impl ScalarValue { Arc::new(StructArray::from(field_values)) } None => { - let field_values: Vec<_> = fields.iter().map(|field| { + let field_values: Vec<_> = fields + .iter() + .map(|field| { let none_field = Self::try_from(field.data_type()).expect( "Failed to construct null ScalarValue from Struct field type" ); (field.clone(), none_field.to_array_of_size(size)) - }).collect(); + }) + .collect(); Arc::new(StructArray::from(field_values)) } @@ -1074,6 +1192,25 @@ impl ScalarValue { } } + fn get_decimal_value_from_array( + array: &ArrayRef, + index: usize, + precision: &usize, + scale: &usize, + ) -> ScalarValue { + let array = array.as_any().downcast_ref::().unwrap(); + // TODO add checker: the precision and scale are same with array + if array.is_null(index) { + ScalarValue::Decimal128(None, Some(*precision), Some(*scale)) + } else { + ScalarValue::Decimal128( + Some(array.value(index)), + Some(*precision), + Some(*scale), + ) + } + } + /// Converts a value in `array` at `index` into a ScalarValue pub fn try_from_array(array: &ArrayRef, index: usize) -> Result { // handle NULL value @@ -1082,6 +1219,9 @@ impl ScalarValue { } Ok(match array.data_type() { + DataType::Decimal(precision, scale) => { + ScalarValue::get_decimal_value_from_array(array, index, precision, scale) + } DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean), DataType::Float64 => typed_cast!(array, index, Float64Array, Float64), DataType::Float32 => typed_cast!(array, index, Float32Array, Float32), @@ -1162,7 +1302,7 @@ impl ScalarValue { return Err(DataFusionError::Internal(format!( "Index type not supported while creating scalar from dictionary: {}", array.data_type(), - ))) + ))); } }; @@ -1194,11 +1334,28 @@ impl ScalarValue { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar from array of type \"{:?}\"", other - ))) + ))); } }) } + fn eq_array_decimal( + array: &ArrayRef, + index: usize, + value: &Option, + precision: usize, + scale: usize, + ) -> bool { + let array = array.as_any().downcast_ref::().unwrap(); + if array.precision() != precision || array.scale() != scale { + return false; + } + match value { + None => array.is_null(index), + Some(v) => !array.is_null(index) && array.value(index) == *v, + } + } + /// Compares a single row of array @ index for equality with self, /// in an optimized fashion. /// @@ -1222,6 +1379,11 @@ impl ScalarValue { } match self { + ScalarValue::Decimal128(v, Some(precision), Some(scale)) => { + ScalarValue::eq_array_decimal(array, index, v, *precision, *scale) + } + // TODO the precision or scale is none, we can't handle this case. + ScalarValue::Decimal128(_, _, _) => unimplemented!(), ScalarValue::Boolean(val) => { eq_array_primitive!(array, index, BooleanArray, val) } @@ -1458,6 +1620,9 @@ impl TryFrom<&DataType> for ScalarValue { DataType::UInt16 => ScalarValue::UInt16(None), DataType::UInt32 => ScalarValue::UInt32(None), DataType::UInt64 => ScalarValue::UInt64(None), + DataType::Decimal(precision, scale) => { + ScalarValue::Decimal128(None, Some(*precision), Some(*scale)) + } DataType::Utf8 => ScalarValue::Utf8(None), DataType::LargeUtf8 => ScalarValue::LargeUtf8(None), DataType::Date32 => ScalarValue::Date32(None), @@ -1487,7 +1652,7 @@ impl TryFrom<&DataType> for ScalarValue { return Err(DataFusionError::NotImplemented(format!( "Can't create a scalar from data_type \"{:?}\"", datatype - ))) + ))); } }) } @@ -1505,6 +1670,9 @@ macro_rules! format_option { impl fmt::Display for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { + ScalarValue::Decimal128(v, p, s) => { + write!(f, "{}", format!("{:?},{:?},{:?}", v, p, s))?; + } ScalarValue::Boolean(e) => format_option!(f, e)?, ScalarValue::Float32(e) => format_option!(f, e)?, ScalarValue::Float64(e) => format_option!(f, e)?, @@ -1579,6 +1747,7 @@ impl fmt::Display for ScalarValue { impl fmt::Debug for ScalarValue { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { + ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({})", self), ScalarValue::Boolean(_) => write!(f, "Boolean({})", self), ScalarValue::Float32(_) => write!(f, "Float32({})", self), ScalarValue::Float64(_) => write!(f, "Float64({})", self), @@ -1676,6 +1845,102 @@ impl ScalarType for TimestampNanosecondType { mod tests { use super::*; + #[test] + fn scalar_decimal_test() { + let decimal_value = ScalarValue::Decimal128(Some(123), Some(10), Some(1)); + assert_eq!(DataType::Decimal(10, 1), decimal_value.get_datatype()); + assert!(!decimal_value.is_null()); + let neg_decimal_value = decimal_value.arithmetic_negate(); + match neg_decimal_value { + ScalarValue::Decimal128(v, _, _) => { + assert_eq!(-123, v.unwrap()); + } + _ => { + unreachable!(); + } + } + + // decimal scalar to array + let array = decimal_value.to_array(); + let array = array.as_any().downcast_ref::().unwrap(); + assert_eq!(1, array.len()); + assert_eq!(DataType::Decimal(10, 1), array.data_type().clone()); + assert_eq!(123i128, array.value(0)); + + // decimal scalar to array with size + let array = decimal_value.to_array_of_size(10); + let array_decimal = array.as_any().downcast_ref::().unwrap(); + assert_eq!(10, array.len()); + assert_eq!(DataType::Decimal(10, 1), array.data_type().clone()); + assert_eq!(123i128, array_decimal.value(0)); + assert_eq!(123i128, array_decimal.value(9)); + // test eq array + assert!(decimal_value.eq_array(&array, 1)); + assert!(decimal_value.eq_array(&array, 5)); + // test try from array + assert_eq!( + decimal_value, + ScalarValue::try_from_array(&array, 5).unwrap() + ); + + assert_eq!( + decimal_value, + ScalarValue::try_new_decimal128(123, 10, 1).unwrap() + ); + + // test compare + let left = ScalarValue::Decimal128(Some(123), Some(10), Some(2)); + let right = ScalarValue::Decimal128(Some(124), Some(10), Some(2)); + assert!(!left.eq(&right)); + let result = left < right; + assert!(result); + let result = left <= right; + assert!(result); + let right = ScalarValue::Decimal128(Some(124), Some(10), Some(3)); + // make sure that two decimals with diff datatype can't be compared. + let result = left.partial_cmp(&right); + assert_eq!(None, result); + + let decimal_vec = vec![ + ScalarValue::Decimal128(Some(1), Some(10), Some(2)), + ScalarValue::Decimal128(Some(2), Some(10), Some(2)), + ScalarValue::Decimal128(Some(3), Some(10), Some(2)), + ]; + // convert the vec to decimal array and check the result + let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); + assert_eq!(3, array.len()); + assert_eq!(DataType::Decimal(10, 2), array.data_type().clone()); + + let decimal_vec = vec![ + ScalarValue::Decimal128(Some(1), Some(10), Some(2)), + ScalarValue::Decimal128(Some(2), Some(10), Some(2)), + ScalarValue::Decimal128(Some(3), Some(10), Some(2)), + ScalarValue::Decimal128(None, Some(10), Some(2)), + ]; + let array = ScalarValue::iter_to_array(decimal_vec.into_iter()).unwrap(); + assert_eq!(4, array.len()); + assert_eq!(DataType::Decimal(10, 2), array.data_type().clone()); + + assert!(ScalarValue::try_new_decimal128(1, 10, 2) + .unwrap() + .eq_array(&array, 0)); + assert!(ScalarValue::try_new_decimal128(2, 10, 2) + .unwrap() + .eq_array(&array, 1)); + assert!(ScalarValue::try_new_decimal128(3, 10, 2) + .unwrap() + .eq_array(&array, 2)); + assert_eq!( + ScalarValue::Decimal128(None, Some(10), Some(2)), + ScalarValue::try_from_array(&array, 3).unwrap() + ); + // TODO why the index has no data, but we can get the null result + assert_eq!( + ScalarValue::Decimal128(None, Some(10), Some(2)), + ScalarValue::try_from_array(&array, 4).unwrap() + ); + } + #[test] fn scalar_value_to_array_u64() { let value = ScalarValue::UInt64(Some(13u64)); @@ -1909,7 +2174,7 @@ mod tests { // Since ScalarValues are used in a non trivial number of places, // making it larger means significant more memory consumption // per distinct value. - assert_eq!(std::mem::size_of::(), 32); + assert_eq!(std::mem::size_of::(), 64); } #[test] @@ -2088,11 +2353,11 @@ mod tests { assert_eq!( List( Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32) + Box::new(DataType::Int32), ) .partial_cmp(&List( Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32) + Box::new(DataType::Int32), )), Some(Ordering::Equal) ); @@ -2100,11 +2365,11 @@ mod tests { assert_eq!( List( Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), - Box::new(DataType::Int32) + Box::new(DataType::Int32), ) .partial_cmp(&List( Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32) + Box::new(DataType::Int32), )), Some(Ordering::Greater) ); @@ -2112,11 +2377,11 @@ mod tests { assert_eq!( List( Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32) + Box::new(DataType::Int32), ) .partial_cmp(&List( Some(Box::new(vec![Int32(Some(10)), Int32(Some(5))])), - Box::new(DataType::Int32) + Box::new(DataType::Int32), )), Some(Ordering::Less) ); @@ -2125,11 +2390,11 @@ mod tests { assert_eq!( List( Some(Box::new(vec![Int64(Some(1)), Int64(Some(5))])), - Box::new(DataType::Int64) + Box::new(DataType::Int64), ) .partial_cmp(&List( Some(Box::new(vec![Int32(Some(1)), Int32(Some(5))])), - Box::new(DataType::Int32) + Box::new(DataType::Int32), )), None ); @@ -2137,11 +2402,11 @@ mod tests { assert_eq!( ScalarValue::from(vec![ ("A", ScalarValue::from(1.0)), - ("B", ScalarValue::from("Z")) + ("B", ScalarValue::from("Z")), ]) .partial_cmp(&ScalarValue::from(vec![ ("A", ScalarValue::from(2.0)), - ("B", ScalarValue::from("A")) + ("B", ScalarValue::from("A")), ])), Some(Ordering::Less) ); @@ -2150,11 +2415,11 @@ mod tests { assert_eq!( ScalarValue::from(vec![ ("A", ScalarValue::from(1.0)), - ("B", ScalarValue::from("Z")) + ("B", ScalarValue::from("Z")), ]) .partial_cmp(&ScalarValue::from(vec![ ("a", ScalarValue::from(2.0)), - ("b", ScalarValue::from("A")) + ("b", ScalarValue::from("A")), ])), None );