Skip to content

Commit

Permalink
Add FromIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Jul 30, 2022
1 parent d727618 commit c81dc82
Showing 1 changed file with 103 additions and 21 deletions.
124 changes: 103 additions & 21 deletions arrow/src/array/array_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::array::{ArrayAccessor, Decimal128Iter, Decimal256Iter};
use num::BigInt;
use std::borrow::Borrow;
use std::convert::From;
use std::fmt;
Expand All @@ -27,8 +28,10 @@ use super::{
use super::{BooleanBufferBuilder, FixedSizeBinaryArray};
#[allow(deprecated)]
pub use crate::array::DecimalIter;
use crate::buffer::Buffer;
use crate::datatypes::{validate_decimal_precision, DECIMAL_DEFAULT_SCALE};
use crate::buffer::{Buffer, MutableBuffer};
use crate::datatypes::{
validate_decimal_precision, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
};
use crate::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE};
use crate::error::{ArrowError, Result};
use crate::util::decimal::{BasicDecimal, Decimal128, Decimal256};
Expand Down Expand Up @@ -219,6 +222,16 @@ pub trait BasicDecimalArray<T: BasicDecimal, U: From<ArrayData>>:
let array_data = unsafe { builder.build_unchecked() };
U::from(array_data)
}

/// The default precision and scale used when not specified.
fn default_type() -> DataType {
// Keep maximum precision
if Self::VALUE_LENGTH == 16 {
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE)
} else {
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE)
}
}
}

impl BasicDecimalArray<Decimal128, Decimal128Array> for Decimal128Array {
Expand Down Expand Up @@ -324,12 +337,6 @@ impl Decimal128Array {
self.data = self.data.with_data_type(new_data_type);
Ok(self)
}

/// The default precision and scale used when not specified.
pub fn default_type() -> DataType {
// Keep maximum precision
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE)
}
}

impl From<ArrayData> for Decimal128Array {
Expand Down Expand Up @@ -384,6 +391,59 @@ impl<'a> Decimal128Array {
}
}

impl From<BigInt> for Decimal256 {
fn from(bigint: BigInt) -> Self {
Decimal256::from_big_int(&bigint, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE)
.unwrap()
}
}

fn build_decimal_array_from<U: BasicDecimalArray<T, U>, T>(
null_buf: BooleanBufferBuilder,
buffer: Buffer,
) -> U
where
T: BasicDecimal,
U: From<ArrayData>,
{
let data = unsafe {
ArrayData::new_unchecked(
U::default_type(),
null_buf.len(),
None,
Some(null_buf.into()),
0,
vec![buffer],
vec![],
)
};
U::from(data)
}

impl<Ptr: Into<Decimal256>> FromIterator<Option<Ptr>> for Decimal256Array {
fn from_iter<I: IntoIterator<Item = Option<Ptr>>>(iter: I) -> Self {
let iter = iter.into_iter();
let (lower, upper) = iter.size_hint();
let size_hint = upper.unwrap_or(lower);

let mut null_buf = BooleanBufferBuilder::new(size_hint);

let mut buffer = MutableBuffer::from_len_zeroed(0);

iter.for_each(|item| {
if let Some(a) = item {
null_buf.append(true);
buffer.extend_from_slice(Into::into(a).raw_value());
} else {
null_buf.append(false);
buffer.extend_zeros(32);
}
});

build_decimal_array_from::<Decimal256Array, _>(null_buf, buffer.into())
}
}

impl<Ptr: Borrow<Option<i128>>> FromIterator<Ptr> for Decimal128Array {
fn from_iter<I: IntoIterator<Item = Ptr>>(iter: I) -> Self {
let iter = iter.into_iter();
Expand All @@ -405,18 +465,7 @@ impl<Ptr: Borrow<Option<i128>>> FromIterator<Ptr> for Decimal128Array {
})
.collect();

let data = unsafe {
ArrayData::new_unchecked(
Self::default_type(),
null_buf.len(),
None,
Some(null_buf.into()),
0,
vec![buffer],
vec![],
)
};
Decimal128Array::from(data)
build_decimal_array_from::<Decimal128Array, _>(null_buf, buffer.into())
}
}

Expand Down Expand Up @@ -794,7 +843,6 @@ mod tests {

#[test]
fn test_decimal256_iter() {
// TODO: Impl FromIterator for Decimal256Array
let mut builder = Decimal256Builder::new(30, 76, 6);
let value = BigInt::from_str_radix("12345", 10).unwrap();
let decimal1 = Decimal256::from_big_int(&value, 76, 6).unwrap();
Expand All @@ -811,4 +859,38 @@ mod tests {
let collected: Vec<_> = array.iter().collect();
assert_eq!(vec![Some(decimal1), None, Some(decimal2)], collected);
}

#[test]
fn test_from_iter_decimal256array() {
let value1 = BigInt::from_str_radix("12345", 10).unwrap();
let value2 = BigInt::from_str_radix("56789", 10).unwrap();

let array: Decimal256Array =
vec![Some(value1.clone()), None, Some(value2.clone())]
.into_iter()
.collect();
assert_eq!(array.len(), 3);
assert_eq!(array.data_type(), &DataType::Decimal256(76, 10));
assert_eq!(
Decimal256::from_big_int(
&value1,
DECIMAL256_MAX_PRECISION,
DECIMAL_DEFAULT_SCALE
)
.unwrap(),
array.value(0)
);
assert!(!array.is_null(0));
assert!(array.is_null(1));
assert_eq!(
Decimal256::from_big_int(
&value2,
DECIMAL256_MAX_PRECISION,
DECIMAL_DEFAULT_SCALE
)
.unwrap(),
array.value(2)
);
assert!(!array.is_null(2));
}
}

0 comments on commit c81dc82

Please sign in to comment.