Skip to content

Commit

Permalink
Move with_precision_and_scale to trait (#2292)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored Aug 3, 2022
1 parent f40403f commit e835853
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 66 deletions.
138 changes: 84 additions & 54 deletions arrow/src/array/array_decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ use super::{BooleanBufferBuilder, FixedSizeBinaryArray};
pub use crate::array::DecimalIter;
use crate::buffer::{Buffer, MutableBuffer};
use crate::datatypes::{
validate_decimal_precision, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
validate_decimal256_precision, validate_decimal_precision, DECIMAL256_MAX_PRECISION,
DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE,
};
use crate::datatypes::{DataType, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE};
use crate::error::{ArrowError, Result};
Expand Down Expand Up @@ -95,6 +96,8 @@ pub trait BasicDecimalArray<T: BasicDecimal, U: From<ArrayData>>:
{
const VALUE_LENGTH: i32;
const DEFAULT_TYPE: DataType;
const MAX_PRECISION: usize;
const MAX_SCALE: usize;

fn data(&self) -> &ArrayData;

Expand Down Expand Up @@ -246,12 +249,72 @@ pub trait BasicDecimalArray<T: BasicDecimal, U: From<ArrayData>>:
fn default_type() -> DataType {
Self::DEFAULT_TYPE
}

/// Rebuilds this decimal array over the same underlying data, tagged with
/// the requested `precision` and `scale`.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when, checked in this order:
/// 1. `precision` is larger than [`Self::MAX_PRECISION`]
/// 2. `scale` is larger than [`Self::MAX_SCALE`]
/// 3. `scale` is larger than `precision`
///
/// Any error reported by [`Self::validate_decimal_precision`] for the
/// requested `precision` is propagated as well.
///
/// # Panics
///
/// Panics if the data type stored in `self.data()` does not equal the
/// decimal type built from `self.precision()` and `self.scale()`.
fn with_precision_and_scale(self, precision: usize, scale: usize) -> Result<U>
where
    Self: Sized,
{
    // Collect the first violated bound (if any) so there is a single
    // error exit; the checks run in the documented order.
    let rejection = if precision > Self::MAX_PRECISION {
        Some(format!(
            "precision {} is greater than max {}",
            precision,
            Self::MAX_PRECISION
        ))
    } else if scale > Self::MAX_SCALE {
        Some(format!(
            "scale {} is greater than max {}",
            scale,
            Self::MAX_SCALE
        ))
    } else if scale > precision {
        Some(format!(
            "scale {} is greater than precision {}",
            scale, precision
        ))
    } else {
        None
    };
    if let Some(message) = rejection {
        return Err(ArrowError::InvalidArgumentError(message));
    }

    // Every stored value must fit the requested precision; the
    // implementor decides when the scan is actually necessary.
    self.validate_decimal_precision(precision)?;

    // A value length of 16 bytes selects Decimal128; any other length is
    // treated as Decimal256.
    let decimal_type = |p: usize, s: usize| {
        if Self::VALUE_LENGTH == 16 {
            DataType::Decimal128(p, s)
        } else {
            DataType::Decimal256(p, s)
        }
    };

    // Invariant: the underlying data is already tagged with this array's
    // current precision and scale.
    let current_type = decimal_type(self.precision(), self.scale());
    assert_eq!(self.data().data_type(), &current_type);

    // The assertion above established that the data is a decimal type,
    // so only the precision/scale parameters of the tag change here.
    let retagged = self
        .data()
        .clone()
        .with_data_type(decimal_type(precision, scale));
    Ok(retagged.into())
}

/// Validates decimal values in this array can be properly interpreted
/// with the specified precision.
fn validate_decimal_precision(&self, precision: usize) -> Result<()>;
}

impl BasicDecimalArray<Decimal128, Decimal128Array> for Decimal128Array {
const VALUE_LENGTH: i32 = 16;
const DEFAULT_TYPE: DataType =
DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const MAX_PRECISION: usize = DECIMAL128_MAX_PRECISION;
const MAX_SCALE: usize = DECIMAL128_MAX_SCALE;

fn data(&self) -> &ArrayData {
&self.data
Expand All @@ -264,12 +327,23 @@ impl BasicDecimalArray<Decimal128, Decimal128Array> for Decimal128Array {
fn scale(&self) -> usize {
self.scale
}

/// Checks that the values of this array can be represented with
/// `precision` digits.
///
/// The scan only runs when `precision` is smaller than the array's
/// current precision; otherwise this returns `Ok(())` immediately. The
/// first error from the free function `validate_decimal_precision` is
/// returned.
fn validate_decimal_precision(&self, precision: usize) -> Result<()> {
    // Keeping or widening the precision needs no per-value check.
    if precision >= self.precision {
        return Ok(());
    }

    // `flatten` drops the `None` entries of the iterator (presumably the
    // null slots), so only present values are validated.
    self.iter().flatten().try_for_each(|value| {
        validate_decimal_precision(value.as_i128(), precision).map(|_| ())
    })
}
}

impl BasicDecimalArray<Decimal256, Decimal256Array> for Decimal256Array {
const VALUE_LENGTH: i32 = 32;
const DEFAULT_TYPE: DataType =
DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE);
const MAX_PRECISION: usize = DECIMAL256_MAX_PRECISION;
const MAX_SCALE: usize = DECIMAL256_MAX_SCALE;

fn data(&self) -> &ArrayData {
&self.data
Expand All @@ -282,6 +356,15 @@ impl BasicDecimalArray<Decimal256, Decimal256Array> for Decimal256Array {
fn scale(&self) -> usize {
self.scale
}

/// Checks that the values of this array can be represented with
/// `precision` digits.
///
/// The scan only runs when `precision` is smaller than the array's
/// current precision; otherwise this returns `Ok(())` immediately. Each
/// value is rendered to a string and handed to
/// `validate_decimal256_precision`; the first error is returned.
fn validate_decimal_precision(&self, precision: usize) -> Result<()> {
    // Keeping or widening the precision needs no per-value check.
    if precision >= self.precision {
        return Ok(());
    }

    // `flatten` drops the `None` entries of the iterator (presumably the
    // null slots), so only present values are validated.
    for value in self.iter().flatten() {
        let rendered = value.to_string();
        validate_decimal256_precision(&rendered, precision)?;
    }
    Ok(())
}
}

impl Decimal128Array {
Expand All @@ -302,59 +385,6 @@ impl Decimal128Array {
};
Decimal128Array::from(data)
}

/// Consumes this array and returns it over the same data, tagged with
/// the requested `precision` and `scale`.
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when, checked in this order:
/// 1. `precision` is larger than [`DECIMAL128_MAX_PRECISION`]
/// 2. `scale` is larger than [`DECIMAL128_MAX_SCALE`]
/// 3. `scale` is larger than `precision`
///
/// When `precision` is smaller than the current precision, any error
/// from `validate_decimal_precision` on a stored value is propagated.
///
/// # Panics
///
/// Panics if the data type stored in `self.data` is not
/// `DataType::Decimal128(self.precision, self.scale)`.
pub fn with_precision_and_scale(
    mut self,
    precision: usize,
    scale: usize,
) -> Result<Self> {
    // Collect the first violated bound (if any) so there is a single
    // error exit; the checks run in the documented order.
    let rejection = if precision > DECIMAL128_MAX_PRECISION {
        Some(format!(
            "precision {} is greater than max {}",
            precision, DECIMAL128_MAX_PRECISION
        ))
    } else if scale > DECIMAL128_MAX_SCALE {
        Some(format!(
            "scale {} is greater than max {}",
            scale, DECIMAL128_MAX_SCALE
        ))
    } else if scale > precision {
        Some(format!(
            "scale {} is greater than precision {}",
            scale, precision
        ))
    } else {
        None
    };
    if let Some(message) = rejection {
        return Err(ArrowError::InvalidArgumentError(message));
    }

    // Values only need re-validation when the precision shrinks; this
    // skips a full scan in the common widening case.
    let narrowing = precision < self.precision;
    if narrowing {
        self.iter().flatten().try_for_each(|value| {
            validate_decimal_precision(value.as_i128(), precision).map(|_| ())
        })?;
    }

    // Invariant: the underlying data is already tagged with this array's
    // current precision and scale.
    let current_type = DataType::Decimal128(self.precision, self.scale);
    assert_eq!(self.data.data_type(), &current_type);

    // The assertion above established that the data is a Decimal128
    // type, so only the precision/scale parameters of the tag change.
    self.precision = precision;
    self.scale = scale;
    self.data = self
        .data
        .with_data_type(DataType::Decimal128(precision, scale));
    Ok(self)
}
}

impl From<ArrayData> for Decimal128Array {
Expand Down
1 change: 1 addition & 0 deletions arrow/src/array/equal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ mod tests {
use std::convert::TryFrom;
use std::sync::Arc;

use crate::array::BasicDecimalArray;
use crate::array::{
array::Array, ArrayData, ArrayDataBuilder, ArrayRef, BooleanArray,
FixedSizeBinaryBuilder, FixedSizeListBuilder, GenericBinaryArray, Int32Builder,
Expand Down
3 changes: 2 additions & 1 deletion arrow/src/array/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,7 @@ mod tests {

use super::*;

use crate::array::BasicDecimalArray;
use crate::array::Decimal128Array;
use crate::{
array::{
Expand Down Expand Up @@ -708,7 +709,7 @@ mod tests {
fn test_decimal() {
let decimal_array =
create_decimal_array(&[Some(1), Some(2), None, Some(3)], 10, 3);
let arrays = vec![decimal_array.data()];
let arrays = vec![Array::data(&decimal_array)];
let mut a = MutableArrayData::new(arrays, true, 3);
a.extend(0, 0, 3);
a.extend(0, 2, 3);
Expand Down
3 changes: 2 additions & 1 deletion arrow/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,7 @@ impl<'a> ArrowArrayChild<'a> {
#[cfg(test)]
mod tests {
use super::*;
use crate::array::BasicDecimalArray;
use crate::array::{
export_array_into_raw, make_array, Array, ArrayData, BooleanArray,
Decimal128Array, DictionaryArray, DurationSecondArray, FixedSizeBinaryArray,
Expand Down Expand Up @@ -953,7 +954,7 @@ mod tests {
.unwrap();

// export it
let array = ArrowArray::try_from(original_array.data().clone())?;
let array = ArrowArray::try_from(Array::data(&original_array).clone())?;

// (simulate consumer) import it
let data = ArrayData::try_from(array)?;
Expand Down
5 changes: 2 additions & 3 deletions arrow/src/util/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
//! available unless `feature = "prettyprint"` is enabled.

use crate::{array::ArrayRef, record_batch::RecordBatch};
use std::fmt::Display;

use comfy_table::{Cell, Table};
use std::fmt::Display;

use crate::error::Result;

Expand Down Expand Up @@ -108,7 +107,7 @@ fn create_column(field: &str, columns: &[ArrayRef]) -> Result<Table> {
mod tests {
use crate::{
array::{
self, new_null_array, Array, Date32Array, Date64Array,
self, new_null_array, Array, BasicDecimalArray, Date32Array, Date64Array,
FixedSizeBinaryBuilder, Float16Array, Int32Array, PrimitiveBuilder,
StringArray, StringBuilder, StringDictionaryBuilder, StructArray,
Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray,
Expand Down
4 changes: 2 additions & 2 deletions parquet/src/arrow/array_reader/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use crate::data_type::DataType;
use crate::errors::{ParquetError, Result};
use crate::schema::types::ColumnDescPtr;
use arrow::array::{
ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, Decimal128Array,
Float32Array, Float64Array, Int32Array, Int64Array,
ArrayDataBuilder, ArrayRef, BasicDecimalArray, BooleanArray, BooleanBufferBuilder,
Decimal128Array, Float32Array, Float64Array, Int32Array, Int64Array,
};
use arrow::buffer::Buffer;
use arrow::datatypes::DataType as ArrowType;
Expand Down
10 changes: 5 additions & 5 deletions parquet/src/arrow/buffer/converter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

use crate::data_type::{ByteArray, FixedLenByteArray, Int96};
use arrow::array::{
Array, ArrayRef, BinaryArray, BinaryBuilder, Decimal128Array, FixedSizeBinaryArray,
FixedSizeBinaryBuilder, IntervalDayTimeArray, IntervalDayTimeBuilder,
IntervalYearMonthArray, IntervalYearMonthBuilder, LargeBinaryArray,
LargeBinaryBuilder, LargeStringArray, LargeStringBuilder, StringArray, StringBuilder,
TimestampNanosecondArray,
Array, ArrayRef, BasicDecimalArray, BinaryArray, BinaryBuilder, Decimal128Array,
FixedSizeBinaryArray, FixedSizeBinaryBuilder, IntervalDayTimeArray,
IntervalDayTimeBuilder, IntervalYearMonthArray, IntervalYearMonthBuilder,
LargeBinaryArray, LargeBinaryBuilder, LargeStringArray, LargeStringBuilder,
StringArray, StringBuilder, TimestampNanosecondArray,
};
use std::convert::{From, TryInto};
use std::sync::Arc;
Expand Down

0 comments on commit e835853

Please sign in to comment.