Skip to content

Commit

Permalink
Merge pull request #31 from dwmunster/bugfix/non-terminating-line-search
Browse files Browse the repository at this point in the history
Solvers: Avoid non-terminating condition in line search
  • Loading branch information
TobiasJacob authored Jul 7, 2024
2 parents f2f4590 + 6c53f8d commit f77c3e0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
26 changes: 23 additions & 3 deletions src/solvers/bfgs_solver.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use nalgebra::{DMatrix, UniformNorm};
use std::error::Error;

use nalgebra::{DMatrix, UniformNorm};

use crate::sketch::Sketch;
use crate::solvers::line_search::line_search_wolfe;
use crate::solvers::line_search::{line_search_wolfe, LineSearchError};

use super::Solver;

Expand Down Expand Up @@ -44,6 +45,8 @@ impl Solver for BFGSSolver {

let mut h = DMatrix::identity(n, n);

let mut recently_reset = false;

while iterations < self.max_iterations {
let loss = sketch.get_loss();
if loss < self.min_loss {
Expand All @@ -64,7 +67,24 @@ impl Solver for BFGSSolver {
return Err("search direction contains non-finite values".into());
}

let alpha = line_search_wolfe(sketch, &p, &gradient)?;
let alpha = match line_search_wolfe(sketch, &p, &gradient) {
Ok(alpha) => alpha,
Err(LineSearchError::SearchFailed) => {
// If the line search could not find a suitable step size, the Hessian
// approximation may be inaccurate. Resetting the Hessian to the identity matrix
// will restart with a steepest descent step and hopefully build a better
// approximation.
if recently_reset {
return Err("bfgs: line search failed twice in a row".into());
}
h = DMatrix::identity(n, n);
recently_reset = true;
continue;
}
Err(e) => return Err(e.into()),
};

recently_reset = false;

let s = alpha * &p;

Expand Down
21 changes: 15 additions & 6 deletions src/solvers/line_search.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,37 @@
use crate::sketch::Sketch;
use nalgebra::DVector;
use std::error::Error;
use thiserror::Error;

#[derive(Debug, Error)]
pub enum LineSearchError {
#[error("line search failed: search direction is not a descent direction")]
NotDescentDirection,
#[error("line search failed: could not find a suitable step size")]
SearchFailed,
}

const WOLFE_C1: f64 = 1e-4;
const WOLFE_C2: f64 = 0.9;
const MAX_ITER: usize = 15;

pub(crate) fn line_search_wolfe(
sketch: &mut Sketch,
direction: &DVector<f64>,
gradient: &DVector<f64>,
) -> Result<f64, Box<dyn Error>> {
) -> Result<f64, LineSearchError> {
let mut alpha = 1.0;
let m = gradient.dot(direction);
if m >= 0.0 {
return Err("line search failed: search direction is not a descent direction".into());
return Err(LineSearchError::NotDescentDirection);
}
let curvature_condition = WOLFE_C2 * m;
let loss = sketch.get_loss();
let x0 = sketch.get_data();
while alpha > 1e-16 {
for _i in 0..MAX_ITER {
let data = &x0 + alpha * direction;
sketch.set_data(data);
let new_loss = sketch.get_loss();
// Sufficent decrease condition
// Sufficient decrease condition
if new_loss <= loss + WOLFE_C1 * alpha * m {
// Curvature condition
let new_gradient = sketch.get_gradient();
Expand All @@ -35,5 +44,5 @@ pub(crate) fn line_search_wolfe(
alpha *= 0.5;
}
}
Err("line search failed: alpha is too small".into())
Err(LineSearchError::SearchFailed)
}

0 comments on commit f77c3e0

Please sign in to comment.