Implement BiCGSTAB (#340)
jlogan03 committed Oct 8, 2023
1 parent f03627f commit 00d4c1e
Showing 2 changed files with 392 additions and 0 deletions.
1 change: 1 addition & 0 deletions sprs/src/sparse/linalg.rs
@@ -6,6 +6,7 @@
use crate::{DenseVector, DenseVectorMut};
use num_traits::Num;

pub mod bicgstab;
pub mod etree;
pub mod ordering;
pub mod trisolve;
391 changes: 391 additions & 0 deletions sprs/src/sparse/linalg/bicgstab.rs
@@ -0,0 +1,391 @@
//! Stabilized bi-conjugate gradient solver for solving Ax = b with x unknown. Suitable for non-symmetric matrices.
//! A simple, sparse-sparse, serial, un-preconditioned implementation.
//!
//! # References
//! The original paper, which is thoroughly paywalled but widely referenced:
//!
//! ```text
//! H. A. van der Vorst,
//! “Bi-CGSTAB: A Fast and Smoothly Converging Variant of Bi-CG for the Solution of Nonsymmetric Linear Systems,”
//! SIAM Journal on Scientific and Statistical Computing, vol. 13, no. 2, pp. 631–644, 1992, doi: 10.1137/0913035.
//! ```
//!
//! A useful discussion of computational cost and convergence characteristics for the CG
//! family of algorithms can be found in the paper that introduces QMRCGSTAB, in Table 1:
//!
//! ```text
//! T. F. Chan, E. Gallopoulos, V. Simoncini, T. Szeto, and C. H. Tong,
//! “A Quasi-Minimal Residual Variant of the Bi-CGSTAB Algorithm for Nonsymmetric Systems,”
//! SIAM J. Sci. Comput., vol. 15, no. 2, pp. 338–347, Mar. 1994, doi: 10.1137/0915023.
//! ```
//!
//! A less-paywalled pseudocode variant for this solver (as well as CG and CGS) can be found at:
//! ```text
//! https://utminers.utep.edu/xzeng/2017spring_math5330/MATH_5330_Computational_Methods_of_Linear_Algebra_files/ln07.pdf
//! ```
//!
//! # Example
//! ```rust
//! use sprs::{CsMatI, CsVecI};
//! use sprs::linalg::bicgstab::BiCGSTAB;
//!
//! let a = CsMatI::new_csc(
//! (4, 4),
//! vec![0, 2, 4, 6, 8],
//! vec![0, 3, 1, 2, 1, 2, 0, 3],
//! vec![1.0, 2., 21., 6., 6., 2., 2., 8.],
//! );
//!
//! // Solve Ax=b
//! let tol = 1e-60;
//! let max_iter = 50;
//! let b = CsVecI::new(4, vec![0, 1, 2, 3], vec![1.0; 4]);
//! let x0 = CsVecI::new(4, vec![0, 1, 2, 3], vec![1.0, 1.0, 1.0, 1.0]);
//!
//! let res = BiCGSTAB::<'_, f64, _, _>::solve(
//! a.view(),
//! x0.view(),
//! b.view(),
//! tol,
//! max_iter,
//! )
//! .unwrap();
//! let b_recovered = &a * &res.x();
//!
//! println!("Iteration count {:?}", res.iteration_count());
//! println!("Soft restart count {:?}", res.soft_restart_count());
//! println!("Hard restart count {:?}", res.hard_restart_count());
//!
//! // Make sure the solved values match expectation
//! for (input, output) in
//! b.to_dense().iter().zip(b_recovered.to_dense().iter())
//! {
//! assert!(
//! (1.0 - input / output).abs() < tol,
//! "Solved output did not match input"
//! );
//! }
//! ```
//!
//! # Commentary
//! This implementation differs slightly from the common pseudocode variants in the following ways (see the sketch after this list):
//! * Both soft-restart and hard-restart logic are present
//! * Soft restart on `r` becoming perpendicular to `rhat`
//! * Hard restart to check true error before claiming convergence
//! * Soft-restart logic uses a correct metric of perpendicularity instead of a magnitude heuristic
//! * The usual method, which compares a fixed value to `rho`, does not capture the fact that the
//! magnitude of `rho` will naturally decrease as the solver approaches convergence
//! * This change eliminates the effect where a soft restart is performed on every one of the last few
//! iterations of any solve with a reasonable error tolerance
//! * Hard-restart logic provides some real guarantee of correctness
//! * The usual implementations keep a cheap, but inaccurate, running estimate of the error
//! * That decreases the cost of iterations by about half by eliminating a matrix-vector multiplication,
//! but allows the estimate of error to drift numerically, which causes the solver to return claiming
//! convergence when the solved output does not, in fact, match the input system
//! * This change guarantees that the solver will not return claiming convergence unless the solution
//! actually matches the input system. If the solver reaches a falsely-converged state, it refreshes
//! its estimate of the error and continues iterating until it reaches either true convergence or the
//! maximum number of iterations
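//!
//! In pseudocode, the two restart mechanisms described above amount to the following
//! (a sketch, not the literal implementation below, which also tracks restart counts):
//!
//! ```text
//! // Soft restart: `rho = dot(rhat, r)` is normalized by `err^2 = dot(r, r)`, so the
//! // check measures the angle between `rhat` and `r` instead of comparing `rho`
//! // against a fixed magnitude
//! if abs(rho) / err^2 < soft_restart_threshold {
//!     rhat = r;  p = r;  rho = err^2;
//! }
//!
//! // Hard restart: before claiming convergence, recompute the true residual,
//! // discarding the cheap running estimate of the error
//! if err < tol {
//!     r = b - A * x;  err = norm(r);
//!     if err < tol { return x; }  // true convergence
//!     // ...otherwise, keep iterating
//! }
//! ```
//!
//! For callers that want to drive iterations manually with custom convergence logic,
//! a sketch (type annotations omitted) using only the methods defined below:
//!
//! ```text
//! let mut solver = BiCGSTAB::new(a.view(), x0.view(), b.view())
//!     .with_restart_threshold(0.1);
//! while solver.err() > tol && solver.iteration_count() < max_iter {
//!     solver.step();
//! }
//! solver.hard_restart(); // recompute the true residual before trusting convergence
//! ```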
use crate::indexing::SpIndex;
use crate::sparse::{CsMatViewI, CsVecI, CsVecViewI};
use num_traits::One;

/// Stabilized bi-conjugate gradient solver
#[derive(Debug)]
pub struct BiCGSTAB<'a, T, I: SpIndex, Iptr: SpIndex> {
// Solver state and configuration
iteration_count: usize,
soft_restart_threshold: T,
soft_restart_count: usize,
hard_restart_count: usize,
// Problem statement: solve `a * x = b`; `err` is the normed residual `||b - a * x||`
err: T,
a: CsMatViewI<'a, T, I, Iptr>,
b: CsVecViewI<'a, T, I>,
x: CsVecI<T, I>,
// Intermediate vectors
r: CsVecI<T, I>,
rhat: CsVecI<T, I>, // Arbitrary w/ dot(rhat, r) != 0
p: CsVecI<T, I>,
// Intermediate scalars
rho: T,
}

macro_rules! bicgstab_impl {
($T: ty) => {
impl<'a, I: SpIndex, Iptr: SpIndex> BiCGSTAB<'a, $T, I, Iptr> {
/// Initialize the solver with a fresh error estimate
pub fn new(
a: CsMatViewI<'a, $T, I, Iptr>,
x0: CsVecViewI<'a, $T, I>,
b: CsVecViewI<'a, $T, I>,
) -> Self {
let r = &b - &(&a.view() * &x0.view()).view();
let rhat = r.to_owned();
let p = r.to_owned();
let err = (&r).l2_norm();
let rho = err * err;
let x = x0.to_owned();
Self {
iteration_count: 0,
soft_restart_threshold: 0.1 * <$T>::one(), // A sensible default
soft_restart_count: 0,
hard_restart_count: 0,
err,
a,
b,
x,
r,
rhat,
p,
rho,
}
}

/// Attempt to solve the system to the given tolerance on normed error,
/// returning an error if convergence is not achieved within the given
/// number of iterations.
pub fn solve(
a: CsMatViewI<'a, $T, I, Iptr>,
x0: CsVecViewI<'a, $T, I>,
b: CsVecViewI<'a, $T, I>,
tol: $T,
max_iter: usize,
) -> Result<
Box<BiCGSTAB<'a, $T, I, Iptr>>,
Box<BiCGSTAB<'a, $T, I, Iptr>>,
> {
let mut solver = Self::new(a, x0, b);
for _ in 0..max_iter {
solver.step();
if solver.err() < tol {
// Check true error, which may not match the running error estimate
// and either continue iterations or return depending on result.
solver.hard_restart();
if solver.err() < tol {
return Ok(Box::new(solver));
}
}
}

// If we ran past the iteration limit, return an error, but still include the partial results
Err(Box::new(solver))
}

/// Reset the reference direction `rhat` to be equal to `r`
/// to prevent a singularity in `1 / rho`.
pub fn soft_restart(&mut self) {
self.soft_restart_count += 1;
self.rhat = self.r.to_owned();
self.rho = self.err * self.err; // Shortcut to (&self.r).squared_l2_norm();
self.p = self.r.to_owned();
}

/// Recalculate the residual vector `r` and the normed error from scratch using `a`, `b`, and the latest `x`.
pub fn hard_restart(&mut self) {
self.hard_restart_count += 1;
// Recalculate true error
self.r = &self.b - &(&self.a.view() * &self.x.view()).view();
self.err = (&self.r).l2_norm();
// Recalculate reference directions
self.soft_restart();
self.soft_restart_count -= 1; // Don't increment soft restart count for hard restarts
}

/// Run one iteration of the solver, returning the latest estimate
/// of the normed error.
pub fn step(&mut self) -> $T {
self.iteration_count += 1;

// Gradient descent step
let v = &self.a.view() * &self.p.view();
let alpha = self.rho / ((&self.rhat).dot(&v));
let h = &self.x + &self.p.map(|x| x * alpha); // latest estimate of `x`

// Conjugate direction step
let s = &self.r - &v.map(|x| x * alpha); // residual of `h`: s = b - A*h
let t = &self.a.view() * &s.view();
let omega = t.dot(&s) / t.squared_l2_norm();
self.x = &h.view() + &s.map(|x| omega * x);

// Check error
self.r = &s - &t.map(|x| x * omega);
self.err = (&self.r).l2_norm();

// Prep for next pass
let rho_prev = self.rho;
self.rho = (&self.rhat).dot(&self.r);

// Soft-restart if `rhat` is becoming perpendicular to `r`.
if self.rho.abs() / (self.err * self.err)
< self.soft_restart_threshold
{
self.soft_restart();
} else {
let beta = (self.rho / rho_prev) * (alpha / omega);
self.p = &self.r
+ (&self.p - &v.map(|x| x * omega)).map(|x| x * beta);
}

self.err
}

/// Set the threshold on normalized `rho` (`|rho| / err^2`) below which a soft restart is triggered
pub fn with_restart_threshold(mut self, thresh: $T) -> Self {
self.soft_restart_threshold = thresh;
self
}

/// Number of iterations performed so far
pub fn iteration_count(&self) -> usize {
self.iteration_count
}

/// The threshold on normalized `rho` below which a soft restart is triggered
pub fn soft_restart_threshold(&self) -> $T {
self.soft_restart_threshold
}

/// Number of soft restarts that have been done so far
pub fn soft_restart_count(&self) -> usize {
self.soft_restart_count
}

/// Number of hard restarts that have been done so far
pub fn hard_restart_count(&self) -> usize {
self.hard_restart_count
}

/// Latest estimate of normed error
pub fn err(&self) -> $T {
self.err
}

/// `dot(rhat, r)`, a measure of how well-conditioned the
/// update to the gradient descent step direction will be.
pub fn rho(&self) -> $T {
self.rho
}

/// The problem matrix
pub fn a(&self) -> CsMatViewI<'_, $T, I, Iptr> {
self.a.view()
}

/// The latest solution vector
pub fn x(&self) -> CsVecViewI<'_, $T, I> {
self.x.view()
}

/// The right-hand-side vector `b`
pub fn b(&self) -> CsVecViewI<'_, $T, I> {
self.b.view()
}

/// Latest residual error vector
pub fn r(&self) -> CsVecViewI<'_, $T, I> {
self.r.view()
}

/// Latest reference direction.
/// `rhat` is arbitrary w/ dot(rhat, r) != 0,
/// and is reset parallel to `r` when needed to avoid
/// `1 / rho` becoming singular.
pub fn rhat(&self) -> CsVecViewI<'_, $T, I> {
self.rhat.view()
}

/// Gradient descent step direction, unscaled
pub fn p(&self) -> CsVecViewI<'_, $T, I> {
self.p.view()
}
}
};
}

bicgstab_impl!(f64);
bicgstab_impl!(f32);

#[cfg(test)]
mod test {
use super::*;
use crate::CsMatI;

#[test]
fn test_bicgstab_f32() {
let a = CsMatI::new_csc(
(4, 4),
vec![0, 2, 4, 6, 8],
vec![0, 3, 1, 2, 1, 2, 0, 3],
vec![1.0, 2., 21., 6., 6., 2., 2., 8.],
);

// Solve Ax=b
let tol = 1e-18;
let max_iter = 50;
let b = CsVecI::new(4, vec![0, 1, 2, 3], vec![1.0; 4]);
let x0 = CsVecI::new(4, vec![0, 1, 2, 3], vec![1.0, 1.0, 1.0, 1.0]);

let res = BiCGSTAB::<'_, f32, _, _>::solve(
a.view(),
x0.view(),
b.view(),
tol,
max_iter,
)
.unwrap();
let b_recovered = &a * &res.x();

println!("Iteration count {:?}", res.iteration_count());
println!("Soft restart count {:?}", res.soft_restart_count());
println!("Hard restart count {:?}", res.hard_restart_count());

// Make sure the solved values match expectation
for (input, output) in
b.to_dense().iter().zip(b_recovered.to_dense().iter())
{
assert!(
(1.0 - input / output).abs() < tol,
"Solved output did not match input"
);
}
}

#[test]
fn test_bicgstab_f64() {
let a = CsMatI::new_csc(
(4, 4),
vec![0, 2, 4, 6, 8],
vec![0, 3, 1, 2, 1, 2, 0, 3],
vec![1.0, 2., 21., 6., 6., 2., 2., 8.],
);

// Solve Ax=b
let tol = 1e-60;
let max_iter = 50;
let b = CsVecI::new(4, vec![0, 1, 2, 3], vec![1.0; 4]);
let x0 = CsVecI::new(4, vec![0, 1, 2, 3], vec![1.0, 1.0, 1.0, 1.0]);

let res = BiCGSTAB::<'_, f64, _, _>::solve(
a.view(),
x0.view(),
b.view(),
tol,
max_iter,
)
.unwrap();
let b_recovered = &a * &res.x();

println!("Iteration count {:?}", res.iteration_count());
println!("Soft restart count {:?}", res.soft_restart_count());
println!("Hard restart count {:?}", res.hard_restart_count());

// Make sure the solved values match expectation
for (input, output) in
b.to_dense().iter().zip(b_recovered.to_dense().iter())
{
assert!(
(1.0 - input / output).abs() < tol,
"Solved output did not match input"
);
}
}
}
