Skip to content

Commit

Permalink
fix: correct decoding of rust_decimal::Decimal for high-precision v…
Browse files Browse the repository at this point in the history
…alues

also fixes handling of feature flags
  • Loading branch information
abonander committed Oct 16, 2023
1 parent 540baf7 commit e3cf1e1
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 38 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion sqlx-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ migrate = ["sqlx-core/migrate"]
offline = ["sqlx-core/offline"]

# Type integration features which require additional dependencies
rust_decimal = ["dep:rust_decimal", "dep:num-bigint"]
rust_decimal = ["dep:rust_decimal", "rust_decimal/maths"]
bigdecimal = ["dep:bigdecimal", "dep:num-bigint"]

[dependencies]
Expand Down
57 changes: 26 additions & 31 deletions sqlx-postgres/src/types/rust_decimal.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use num_bigint::{BigInt, Sign};
use rust_decimal::{
prelude::{ToPrimitive, Zero},
Decimal,
};
use rust_decimal::{prelude::Zero, Decimal};

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand All @@ -11,6 +7,8 @@ use crate::types::numeric::{PgNumeric, PgNumericSign};
use crate::types::Type;
use crate::{PgArgumentBuffer, PgHasArrayType, PgTypeInfo, PgValueFormat, PgValueRef, Postgres};

use rust_decimal::MathematicalOps;

impl Type<Postgres> for Decimal {
fn type_info() -> PgTypeInfo {
PgTypeInfo::NUMERIC
Expand All @@ -27,7 +25,7 @@ impl TryFrom<PgNumeric> for Decimal {
type Error = BoxDynError;

fn try_from(numeric: PgNumeric) -> Result<Self, BoxDynError> {
let (digits, sign, weight) = match numeric {
let (digits, sign, mut weight) = match numeric {
PgNumeric::Number {
digits,
sign,
Expand All @@ -41,39 +39,33 @@ impl TryFrom<PgNumeric> for Decimal {
};

if digits.is_empty() {
// Postgres returns an empty digit array for 0 but BigInt expects at least one zero
// Postgres returns an empty digit array for 0
return Ok(0u64.into());
}

let sign = match sign {
PgNumericSign::Positive => Sign::Plus,
PgNumericSign::Negative => Sign::Minus,
};
let mut value = Decimal::ZERO;

// Sum over `digits`, multiply each by its weight and add it to `value`.
for digit in digits {
let mul = Decimal::from(10_000i16)
.checked_powi(weight as i64)
.ok_or("value not representable as rust_decimal::Decimal")?;

let part = Decimal::from(digit) * mul;

// weight is 0 if the decimal point falls after the first base-10000 digit
let scale = (digits.len() as i64 - weight as i64 - 1) * 4;
value = value
.checked_add(part)
.ok_or("value not representable as rust_decimal::Decimal")?;

// no optimized algorithm for base-10 so use base-100 for faster processing
let mut cents = Vec::with_capacity(digits.len() * 2);
for digit in &digits {
cents.push((digit / 100) as u8);
cents.push((digit % 100) as u8);
weight = weight.checked_sub(1).ok_or("weight underflowed")?;
}

let bigint = BigInt::from_radix_be(sign, &cents, 100)
.ok_or("PgNumeric contained an out-of-range digit")?;

match (bigint.to_i128(), scale) {
// A negative scale, meaning we have nothing on the right and must
// add zeroes to the left.
(Some(num), scale) if scale < 0 => Ok(Decimal::from_i128_with_scale(
num * 10i128.pow(scale.abs() as u32),
0,
)),
// A positive scale, so we have decimals on the right.
(Some(num), _) => Ok(Decimal::from_i128_with_scale(num, scale as u32)),
(None, _) => Err("Decimal's integer part out of range.".into()),
match sign {
PgNumericSign::Positive => value.set_sign_positive(true),
PgNumericSign::Negative => value.set_sign_negative(true),
}

Ok(value)
}
}

Expand Down Expand Up @@ -403,4 +395,7 @@ mod decimal_to_pgnumeric {
}
);
}

#[test]
fn issue_666_trailing_zeroes_at_max_precision() {}
}
4 changes: 2 additions & 2 deletions tests/mysql/types.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
extern crate time_ as time;

#[cfg(feature = "decimal")]
#[cfg(feature = "rust_decimal")]
use std::str::FromStr;

use sqlx::mysql::MySql;
Expand Down Expand Up @@ -223,7 +223,7 @@ test_type!(bigdecimal<sqlx::types::BigDecimal>(
"CAST(12345.6789 AS DECIMAL(9, 4))" == "12345.6789".parse::<sqlx::types::BigDecimal>().unwrap(),
));

#[cfg(feature = "decimal")]
#[cfg(feature = "rust_decimal")]
test_type!(decimal<sqlx::types::Decimal>(MySql,
"CAST(0 as DECIMAL(0, 0))" == sqlx::types::Decimal::from_str("0").unwrap(),
"CAST(1 AS DECIMAL(1, 0))" == sqlx::types::Decimal::from_str("1").unwrap(),
Expand Down
8 changes: 5 additions & 3 deletions tests/postgres/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use sqlx::postgres::types::{Oid, PgCiText, PgInterval, PgMoney, PgRange};
use sqlx::postgres::Postgres;
use sqlx_test::{test_decode_type, test_prepared_type, test_type};

#[cfg(any(postgres_14, postgres_15))]
use std::str::FromStr;

test_type!(null<Option<i16>>(Postgres,
Expand Down Expand Up @@ -475,7 +474,7 @@ test_type!(numrange_bigdecimal<PgRange<sqlx::types::BigDecimal>>(Postgres,
Bound::Excluded("2.4".parse::<sqlx::types::BigDecimal>().unwrap())))
));

#[cfg(feature = "decimal")]
#[cfg(feature = "rust_decimal")]
test_type!(decimal<sqlx::types::Decimal>(Postgres,
"0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(),
"1::numeric" == sqlx::types::Decimal::from_str("1").unwrap(),
Expand All @@ -484,9 +483,12 @@ test_type!(decimal<sqlx::types::Decimal>(Postgres,
"0.01234::numeric" == sqlx::types::Decimal::from_str("0.01234").unwrap(),
"12.34::numeric" == sqlx::types::Decimal::from_str("12.34").unwrap(),
"12345.6789::numeric" == sqlx::types::Decimal::from_str("12345.6789").unwrap(),
// https://github.com/launchbadge/sqlx/issues/666#issuecomment-683872154
"17.905625985174584660842500258::numeric" == sqlx::types::Decimal::from_str("17.905625985174584660842500258").unwrap(),
"-17.905625985174584660842500258::numeric" == sqlx::types::Decimal::from_str("-17.905625985174584660842500258").unwrap(),
));

#[cfg(feature = "decimal")]
#[cfg(feature = "rust_decimal")]
test_type!(numrange_decimal<PgRange<sqlx::types::Decimal>>(Postgres,
"'(1.3,2.4)'::numrange" == PgRange::from(
(Bound::Excluded(sqlx::types::Decimal::from_str("1.3").unwrap()),
Expand Down

0 comments on commit e3cf1e1

Please sign in to comment.