From aab142fa24980400f4631789fede916ff43f123b Mon Sep 17 00:00:00 2001 From: Josh Stone Date: Wed, 5 Dec 2018 10:41:10 -0800 Subject: [PATCH] Use better initial guesses for Roots --- src/biguint.rs | 147 ++++++++++++++++++++++++++++++++++--------------- tests/roots.rs | 119 ++++++++++++++++++++++++++++++++------- 2 files changed, 204 insertions(+), 62 deletions(-) diff --git a/src/biguint.rs b/src/biguint.rs index 71242225..2b105eea 100644 --- a/src/biguint.rs +++ b/src/biguint.rs @@ -1267,6 +1267,35 @@ impl Integer for BigUint { } } +#[inline] +fn fixpoint<F>(mut x: BigUint, max_bits: usize, f: F) -> BigUint +where + F: Fn(&BigUint) -> BigUint, +{ + let mut xn = f(&x); + + // If the value increased, then the initial guess must have been low. + // Repeat until we reverse course. + while x < xn { + // Sometimes an increase will go way too far, especially with large + // powers, and then take a long time to walk back. We know an upper + // bound based on bit size, so saturate on that. + x = if xn.bits() > max_bits { + BigUint::one() << max_bits + } else { + xn + }; + xn = f(&x); + } + + // Now keep repeating while the estimate is decreasing. + while x > xn { + x = xn; + xn = f(&x); + } + x +} + impl Roots for BigUint { // nth_root, sqrt and cbrt use Newton's method to compute // principal root of a given degree for a given integer. @@ -1288,27 +1317,42 @@ impl Roots for BigUint { _ => (), } - let n = n as usize; - let n_min_1 = n - 1; - - let guess = BigUint::one() << (self.bits() / n + 1); - - let mut u = guess; - let mut s: BigUint; + // The root of non-zero values less than 2ⁿ can only be 1. + let bits = self.bits(); + if bits <= n as usize { + return BigUint::one() + } - loop { - s = u; - let q = self / s.pow(n_min_1); - let t: BigUint = n_min_1 * &s + q; + // If we fit in `u64`, compute the root that way.
+ if let Some(x) = self.to_u64() { + return x.nth_root(n).into(); + } - u = t / n; + let max_bits = bits / n as usize + 1; - if u >= s { - break; + let guess = if let Some(f) = self.to_f64() { + // We fit in `f64` (lossy), so get a better initial guess from that. + BigUint::from_f64((f.ln() / f64::from(n)).exp()).unwrap() + } else { + // Try to guess by scaling down such that it does fit in `f64`. + // With some (x * 2ⁿᵏ), its nth root ≈ (ⁿ√x * 2ᵏ) + let nsz = n as usize; + let extra_bits = bits - (f64::MAX_EXP as usize - 1); + let root_scale = (extra_bits + (nsz - 1)) / nsz; + let scale = root_scale * nsz; + if scale < bits && bits - scale > nsz { + (self >> scale).nth_root(n) << root_scale + } else { + BigUint::one() << max_bits } - } + }; - s + let n_min_1 = n - 1; + fixpoint(guess, max_bits, move |s| { + let q = self / s.pow(n_min_1); + let t = n_min_1 * s + q; + t / n + }) } // Reference: @@ -1318,23 +1362,31 @@ impl Roots for BigUint { return self.clone(); } - let guess = BigUint::one() << (self.bits() / 2 + 1); - - let mut u = guess; - let mut s: BigUint; + // If we fit in `u64`, compute the root that way. + if let Some(x) = self.to_u64() { + return x.sqrt().into(); + } - loop { - s = u; - let q = self / &s; - let t: BigUint = &s + q; - u = t >> 1; + let bits = self.bits(); + let max_bits = bits / 2 as usize + 1; - if u >= s { - break; - } - } + let guess = if let Some(f) = self.to_f64() { + // We fit in `f64` (lossy), so get a better initial guess from that. + BigUint::from_f64(f.sqrt()).unwrap() + } else { + // Try to guess by scaling down such that it does fit in `f64`. 
+ // With some (x * 2²ᵏ), its sqrt ≈ (√x * 2ᵏ) + let extra_bits = bits - (f64::MAX_EXP as usize - 1); + let root_scale = (extra_bits + 1) / 2; + let scale = root_scale * 2; + (self >> scale).sqrt() << root_scale + }; - s + fixpoint(guess, max_bits, move |s| { + let q = self / s; + let t = s + q; + t >> 1 + }) } fn cbrt(&self) -> Self { @@ -1342,23 +1394,32 @@ impl Roots for BigUint { return self.clone(); } - let guess = BigUint::one() << (self.bits() / 3 + 1); + // If we fit in `u64`, compute the root that way. + if let Some(x) = self.to_u64() { + return x.cbrt().into(); + } - let mut u = guess; - let mut s: BigUint; + let bits = self.bits(); + let max_bits = bits / 3 as usize + 1; - loop { - s = u; - let q = self / (&s * &s); - let t: BigUint = (&s << 1) + q; - u = t / 3u32; + let guess = if let Some(f) = self.to_f64() { + // We fit in `f64` (lossy), so get a better initial guess from that. + BigUint::from_f64(f.cbrt()).unwrap() + } else { + // Try to guess by scaling down such that it does fit in `f64`. 
+ // With some (x * 2³ᵏ), its cbrt ≈ (∛x * 2ᵏ) + let extra_bits = bits - (f64::MAX_EXP as usize - 1); + let root_scale = (extra_bits + 2) / 3; + let scale = root_scale * 3; + (self >> scale).cbrt() << root_scale + }; - if u >= s { - break; - } - } - s + fixpoint(guess, max_bits, move |s| { + let q = self / (s * s); + let t = (s << 1) + q; + t / 3u32 + }) } } diff --git a/tests/roots.rs b/tests/roots.rs index 9c4ee484..d0cdf654 100644 --- a/tests/roots.rs +++ b/tests/roots.rs @@ -2,57 +2,138 @@ extern crate num_bigint; extern crate num_integer; extern crate num_traits; +#[cfg(feature = "rand")] +extern crate rand; + mod biguint { use num_bigint::BigUint; - use num_traits::Pow; - use std::str::FromStr; + use num_traits::{One, Pow, Zero}; - fn check(x: u64, n: u32) { - let big_x = BigUint::from(x); - let res = big_x.nth_root(n); + fn check<T: Into<BigUint>>(x: T, n: u32) { + let x: BigUint = x.into(); + let root = x.nth_root(n); + println!("check {}.nth_root({}) = {}", x, n, root); if n == 2 { - assert_eq!(&res, &big_x.sqrt()) + assert_eq!(root, x.sqrt()) } else if n == 3 { - assert_eq!(&res, &big_x.cbrt()) + assert_eq!(root, x.cbrt()) + } + + let lo = root.pow(n); + assert!(lo <= x); + assert_eq!(lo.nth_root(n), root); + if !lo.is_zero() { + assert_eq!((&lo - 1u32).nth_root(n), &root - 1u32); } - assert!(res.pow(n) <= big_x); - assert!((res + 1u32).pow(n) > big_x); + let hi = (&root + 1u32).pow(n); + assert!(hi > x); + assert_eq!(hi.nth_root(n), &root + 1u32); + assert_eq!((&hi - 1u32).nth_root(n), root); } #[test] fn test_sqrt() { - check(99, 2); - check(100, 2); - check(120, 2); + check(99u32, 2); + check(100u32, 2); + check(120u32, 2); } #[test] fn test_cbrt() { - check(8, 3); - check(26, 3); + check(8u32, 3); + check(26u32, 3); } #[test] fn test_nth_root() { - check(0, 1); - check(10, 1); - check(100, 4); + check(0u32, 1); + check(10u32, 1); + check(100u32, 4); } #[test] #[should_panic] fn test_nth_root_n_is_zero() { - check(4, 0); + check(4u32, 0); } #[test] fn
test_nth_root_big() { - let x = BigUint::from_str("123_456_789").unwrap(); + let x = BigUint::from(123_456_789_u32); let expected = BigUint::from(6u32); assert_eq!(x.nth_root(10), expected); + check(x, 10); + } + + #[test] + fn test_nth_root_googol() { + let googol = BigUint::from(10u32).pow(100u32); + + // perfect divisors of 100 + for &n in &[2, 4, 5, 10, 20, 25, 50, 100] { + let expected = BigUint::from(10u32).pow(100u32 / n); + assert_eq!(googol.nth_root(n), expected); + check(googol.clone(), n); + } + } + + #[test] + fn test_nth_root_twos() { + const EXP: u32 = 12; + const LOG2: usize = 1 << EXP; + let x = BigUint::one() << LOG2; + + // the perfect divisors are just powers of two + for exp in 1..EXP + 1 { + let n = 2u32.pow(exp); + let expected = BigUint::one() << (LOG2 / n as usize); + assert_eq!(x.nth_root(n), expected); + check(x.clone(), n); + } + + // degenerate cases should return quickly + assert!(x.nth_root(x.bits() as u32).is_one()); + assert!(x.nth_root(std::i32::MAX as u32).is_one()); + assert!(x.nth_root(std::u32::MAX).is_one()); + } + + #[cfg(feature = "rand")] + #[test] + fn test_roots_rand() { + use num_bigint::RandBigInt; + use rand::{thread_rng, Rng}; + use rand::distributions::Uniform; + + let mut rng = thread_rng(); + let bit_range = Uniform::new(0, 2048); + let sample_bits: Vec<_> = rng.sample_iter(&bit_range).take(100).collect(); + for bits in sample_bits { + let x = rng.gen_biguint(bits); + for n in 2..11 { + check(x.clone(), n); + } + check(x.clone(), 100); + } + } + + #[test] + fn test_roots_rand1() { + // A random input that found regressions + let s = "575981506858479247661989091587544744717244516135539456183849\ + 986593934723426343633698413178771587697273822147578889823552\ + 182702908597782734558103025298880194023243541613924361007059\ + 353344183590348785832467726433749431093350684849462759540710\ + 026019022227591412417064179299354183441181373862905039254106\ + 4781867"; + let x: BigUint = s.parse().unwrap(); + + check(x.clone(), 
2); + check(x.clone(), 3); + check(x.clone(), 10); + check(x.clone(), 100); } }