Skip to content

Commit

Permalink
Use better initial guesses for Roots
Browse files Browse the repository at this point in the history
  • Loading branch information
cuviper committed Dec 5, 2018
1 parent 1d8857a commit aab142f
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 62 deletions.
147 changes: 104 additions & 43 deletions src/biguint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1267,6 +1267,35 @@ impl Integer for BigUint {
}
}

#[inline]
fn fixpoint<F>(mut x: BigUint, max_bits: usize, f: F) -> BigUint
where
F: Fn(&BigUint) -> BigUint,
{
let mut xn = f(&x);

// If the value increased, then the initial guess must have been low.
// Repeat until we reverse course.
while x < xn {
// Sometimes an increase will go way too far, especially with large
// powers, and then take a long time to walk back. We know an upper
// bound based on bit size, so saturate on that.
x = if xn.bits() > max_bits {
BigUint::one() << max_bits
} else {
xn
};
xn = f(&x);
}

// Now keep repeating while the estimate is decreasing.
while x > xn {
x = xn;
xn = f(&x);
}
x
}

impl Roots for BigUint {
// nth_root, sqrt and cbrt use Newton's method to compute
// principal root of a given degree for a given integer.
Expand All @@ -1288,27 +1317,42 @@ impl Roots for BigUint {
_ => (),
}

let n = n as usize;
let n_min_1 = n - 1;

let guess = BigUint::one() << (self.bits() / n + 1);

let mut u = guess;
let mut s: BigUint;
// The root of non-zero values less than 2ⁿ can only be 1.
let bits = self.bits();
if bits <= n as usize {
return BigUint::one()
}

loop {
s = u;
let q = self / s.pow(n_min_1);
let t: BigUint = n_min_1 * &s + q;
// If we fit in `u64`, compute the root that way.
if let Some(x) = self.to_u64() {
return x.nth_root(n).into();
}

u = t / n;
let max_bits = bits / n as usize + 1;

if u >= s {
break;
let guess = if let Some(f) = self.to_f64() {
// We fit in `f64` (lossy), so get a better initial guess from that.
BigUint::from_f64((f.ln() / f64::from(n)).exp()).unwrap()
} else {
// Try to guess by scaling down such that it does fit in `f64`.
// With some (x * 2ⁿᵏ), its nth root ≈ (ⁿ√x * 2ᵏ)
let nsz = n as usize;
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
let root_scale = (extra_bits + (nsz - 1)) / nsz;
let scale = root_scale * nsz;
if scale < bits && bits - scale > nsz {
(self >> scale).nth_root(n) << root_scale
} else {
BigUint::one() << max_bits
}
}
};

s
let n_min_1 = n - 1;
fixpoint(guess, max_bits, move |s| {
let q = self / s.pow(n_min_1);
let t = n_min_1 * s + q;
t / n
})
}

// Reference:
Expand All @@ -1318,47 +1362,64 @@ impl Roots for BigUint {
return self.clone();
}

let guess = BigUint::one() << (self.bits() / 2 + 1);

let mut u = guess;
let mut s: BigUint;
// If we fit in `u64`, compute the root that way.
if let Some(x) = self.to_u64() {
return x.sqrt().into();
}

loop {
s = u;
let q = self / &s;
let t: BigUint = &s + q;
u = t >> 1;
let bits = self.bits();
let max_bits = bits / 2 as usize + 1;

if u >= s {
break;
}
}
let guess = if let Some(f) = self.to_f64() {
// We fit in `f64` (lossy), so get a better initial guess from that.
BigUint::from_f64(f.sqrt()).unwrap()
} else {
// Try to guess by scaling down such that it does fit in `f64`.
// With some (x * 2²ᵏ), its sqrt ≈ (√x * 2ᵏ)
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
let root_scale = (extra_bits + 1) / 2;
let scale = root_scale * 2;
(self >> scale).sqrt() << root_scale
};

s
fixpoint(guess, max_bits, move |s| {
let q = self / s;
let t = s + q;
t >> 1
})
}

fn cbrt(&self) -> Self {
if self.is_zero() || self.is_one() {
return self.clone();
}

let guess = BigUint::one() << (self.bits() / 3 + 1);
// If we fit in `u64`, compute the root that way.
if let Some(x) = self.to_u64() {
return x.cbrt().into();
}

let mut u = guess;
let mut s: BigUint;
let bits = self.bits();
let max_bits = bits / 3 as usize + 1;

loop {
s = u;
let q = self / (&s * &s);
let t: BigUint = (&s << 1) + q;
u = t / 3u32;
let guess = if let Some(f) = self.to_f64() {
// We fit in `f64` (lossy), so get a better initial guess from that.
BigUint::from_f64(f.cbrt()).unwrap()
} else {
// Try to guess by scaling down such that it does fit in `f64`.
// With some (x * 2³ᵏ), its cbrt ≈ (∛x * 2ᵏ)
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
let root_scale = (extra_bits + 2) / 3;
let scale = root_scale * 3;
(self >> scale).cbrt() << root_scale
};

if u >= s {
break;
}
}

s
fixpoint(guess, max_bits, move |s| {
let q = self / (s * s);
let t = (s << 1) + q;
t / 3u32
})
}
}

Expand Down
119 changes: 100 additions & 19 deletions tests/roots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,138 @@ extern crate num_bigint;
extern crate num_integer;
extern crate num_traits;

#[cfg(feature = "rand")]
extern crate rand;

mod biguint {
use num_bigint::BigUint;
use num_traits::Pow;
use std::str::FromStr;
use num_traits::{One, Pow, Zero};

fn check(x: u64, n: u32) {
let big_x = BigUint::from(x);
let res = big_x.nth_root(n);
fn check<T: Into<BigUint>>(x: T, n: u32) {
let x: BigUint = x.into();
let root = x.nth_root(n);
println!("check {}.nth_root({}) = {}", x, n, root);

if n == 2 {
assert_eq!(&res, &big_x.sqrt())
assert_eq!(root, x.sqrt())
} else if n == 3 {
assert_eq!(&res, &big_x.cbrt())
assert_eq!(root, x.cbrt())
}

let lo = root.pow(n);
assert!(lo <= x);
assert_eq!(lo.nth_root(n), root);
if !lo.is_zero() {
assert_eq!((&lo - 1u32).nth_root(n), &root - 1u32);
}

assert!(res.pow(n) <= big_x);
assert!((res + 1u32).pow(n) > big_x);
let hi = (&root + 1u32).pow(n);
assert!(hi > x);
assert_eq!(hi.nth_root(n), &root + 1u32);
assert_eq!((&hi - 1u32).nth_root(n), root);
}

#[test]
fn test_sqrt() {
check(99, 2);
check(100, 2);
check(120, 2);
check(99u32, 2);
check(100u32, 2);
check(120u32, 2);
}

#[test]
fn test_cbrt() {
check(8, 3);
check(26, 3);
check(8u32, 3);
check(26u32, 3);
}

#[test]
fn test_nth_root() {
check(0, 1);
check(10, 1);
check(100, 4);
check(0u32, 1);
check(10u32, 1);
check(100u32, 4);
}

#[test]
#[should_panic]
fn test_nth_root_n_is_zero() {
check(4, 0);
check(4u32, 0);
}

#[test]
fn test_nth_root_big() {
let x = BigUint::from_str("123_456_789").unwrap();
let x = BigUint::from(123_456_789_u32);
let expected = BigUint::from(6u32);

assert_eq!(x.nth_root(10), expected);
check(x, 10);
}

#[test]
fn test_nth_root_googol() {
let googol = BigUint::from(10u32).pow(100u32);

// perfect divisors of 100
for &n in &[2, 4, 5, 10, 20, 25, 50, 100] {
let expected = BigUint::from(10u32).pow(100u32 / n);
assert_eq!(googol.nth_root(n), expected);
check(googol.clone(), n);
}
}

#[test]
fn test_nth_root_twos() {
const EXP: u32 = 12;
const LOG2: usize = 1 << EXP;
let x = BigUint::one() << LOG2;

// the perfect divisors are just powers of two
for exp in 1..EXP + 1 {
let n = 2u32.pow(exp);
let expected = BigUint::one() << (LOG2 / n as usize);
assert_eq!(x.nth_root(n), expected);
check(x.clone(), n);
}

// degenerate cases should return quickly
assert!(x.nth_root(x.bits() as u32).is_one());
assert!(x.nth_root(std::i32::MAX as u32).is_one());
assert!(x.nth_root(std::u32::MAX).is_one());
}

#[cfg(feature = "rand")]
#[test]
fn test_roots_rand() {
use num_bigint::RandBigInt;
use rand::{thread_rng, Rng};
use rand::distributions::Uniform;

let mut rng = thread_rng();
let bit_range = Uniform::new(0, 2048);
let sample_bits: Vec<_> = rng.sample_iter(&bit_range).take(100).collect();
for bits in sample_bits {
let x = rng.gen_biguint(bits);
for n in 2..11 {
check(x.clone(), n);
}
check(x.clone(), 100);
}
}

#[test]
fn test_roots_rand1() {
// A random input that found regressions
let s = "575981506858479247661989091587544744717244516135539456183849\
986593934723426343633698413178771587697273822147578889823552\
182702908597782734558103025298880194023243541613924361007059\
353344183590348785832467726433749431093350684849462759540710\
026019022227591412417064179299354183441181373862905039254106\
4781867";
let x: BigUint = s.parse().unwrap();

check(x.clone(), 2);
check(x.clone(), 3);
check(x.clone(), 10);
check(x.clone(), 100);
}
}

Expand Down

0 comments on commit aab142f

Please sign in to comment.