Skip to content

Commit

Permalink
Merge pull request #21 from dwmunster/feature/bfgs-speedup
Browse files Browse the repository at this point in the history
BFGS: Performance Improvements
  • Loading branch information
TobiasJacob authored May 30, 2024
2 parents bb8133a + 28511d6 commit 397d85b
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 64 deletions.
10 changes: 5 additions & 5 deletions src/primitives/arc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::cell::RefCell;
use std::rc::Rc;

use nalgebra::{DVector, DVectorView, SMatrix, SMatrixView, SVector, Vector2};
use nalgebra::{DVectorView, SMatrix, SMatrixView, SVector, Vector2};
use serde::{Deserialize, Serialize};

#[cfg(feature = "tsify")]
Expand Down Expand Up @@ -177,12 +177,12 @@ impl PrimitiveLike for Arc {
self.gradient = SVector::<f64, 3>::zeros();
}

fn get_data(&self) -> DVector<f64> {
DVector::from_row_slice(self.data.as_slice())
fn get_data(&self) -> DVectorView<f64> {
self.data.as_view()
}

fn get_gradient(&self) -> DVector<f64> {
DVector::from_row_slice(self.gradient.as_slice())
fn get_gradient(&self) -> DVectorView<f64> {
self.gradient.as_view()
}

fn set_data(&mut self, data: DVectorView<f64>) {
Expand Down
10 changes: 5 additions & 5 deletions src/primitives/circle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::cell::RefCell;
use std::rc::Rc;

use nalgebra::{DVector, DVectorView, SMatrix, SMatrixView, SVector};
use nalgebra::{DVectorView, SMatrix, SMatrixView, SVector};
use serde::{Deserialize, Serialize};

#[cfg(feature = "tsify")]
Expand Down Expand Up @@ -70,17 +70,17 @@ impl PrimitiveLike for Circle {
self.gradient = SVector::<f64, 1>::zeros();
}

fn get_data(&self) -> DVector<f64> {
DVector::from_row_slice(self.data.as_slice())
fn get_data(&self) -> DVectorView<f64> {
self.data.as_view()
}

fn set_data(&mut self, data: DVectorView<f64>) {
assert!(data.iter().all(|x| x.is_finite()));
self.data.copy_from(&data);
}

fn get_gradient(&self) -> DVector<f64> {
DVector::from_row_slice(self.gradient.as_slice())
fn get_gradient(&self) -> DVectorView<f64> {
self.gradient.as_view()
}

fn to_primitive(&self) -> super::Primitive {
Expand Down
17 changes: 11 additions & 6 deletions src/primitives/line.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::cell::RefCell;
use std::rc::Rc;

use nalgebra::{DVector, DVectorView, SMatrix, SMatrixView};
use nalgebra::{DVectorView, SMatrix, SMatrixView, SVector};
use serde::{Deserialize, Serialize};

#[cfg(feature = "tsify")]
Expand All @@ -16,11 +16,16 @@ use super::{PrimitiveCell, PrimitiveLike};
pub struct Line {
start: Rc<RefCell<Point2>>,
end: Rc<RefCell<Point2>>,
empty: SVector<f64, 0>,
}

impl Line {
pub fn new(start: Rc<RefCell<Point2>>, end: Rc<RefCell<Point2>>) -> Self {
Self { start, end }
Self {
start,
end,
empty: SVector::<f64, 0>::zeros(),
}
}

pub fn start(&self) -> Rc<RefCell<Point2>> {
Expand Down Expand Up @@ -71,18 +76,18 @@ impl PrimitiveLike for Line {
// Referenced points will zero their gradients automatically as they are part of the sketch
}

fn get_data(&self) -> DVector<f64> {
fn get_data(&self) -> DVectorView<f64> {
// empty vector
DVector::from_row_slice(&[])
self.empty.as_view()
}

fn set_data(&mut self, _data: DVectorView<f64>) {
// Do nothing
}

fn get_gradient(&self) -> DVector<f64> {
fn get_gradient(&self) -> DVectorView<f64> {
// empty vector
DVector::from_row_slice(&[])
self.empty.as_view()
}

fn to_primitive(&self) -> super::Primitive {
Expand Down
6 changes: 3 additions & 3 deletions src/primitives/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Debug;
use std::ptr;
use std::rc::Rc;

use nalgebra::{DVector, DVectorView};
use nalgebra::DVectorView;
use serde::{Deserialize, Serialize};

#[cfg(feature = "tsify")]
Expand All @@ -18,9 +18,9 @@ pub mod point2;
pub trait PrimitiveLike: Debug {
fn references(&self) -> Vec<PrimitiveCell>;
fn zero_gradient(&mut self);
fn get_data(&self) -> DVector<f64>;
fn get_data(&self) -> DVectorView<f64>;
fn set_data(&mut self, data: DVectorView<f64>);
fn get_gradient(&self) -> DVector<f64>;
fn get_gradient(&self) -> DVectorView<f64>;
fn to_primitive(&self) -> Primitive;
}

Expand Down
10 changes: 5 additions & 5 deletions src/primitives/point2.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use nalgebra::{DVector, DVectorView, SMatrix, SMatrixView, Vector2};
use nalgebra::{DVectorView, SMatrix, SMatrixView, Vector2};
use serde::{Deserialize, Serialize};

#[cfg(feature = "tsify")]
Expand Down Expand Up @@ -64,17 +64,17 @@ impl PrimitiveLike for Point2 {
self.gradient = Vector2::zeros();
}

fn get_data(&self) -> DVector<f64> {
DVector::from_row_slice(self.data.as_slice())
fn get_data(&self) -> DVectorView<f64> {
self.data.as_view()
}

fn set_data(&mut self, data: DVectorView<f64>) {
assert!(data.iter().all(|x| x.is_finite()));
self.data = Vector2::from_row_slice(data.as_slice());
}

fn get_gradient(&self) -> DVector<f64> {
DVector::from_row_slice(self.gradient.as_slice())
fn get_gradient(&self) -> DVectorView<f64> {
self.gradient.as_view()
}

fn to_primitive(&self) -> super::Primitive {
Expand Down
15 changes: 9 additions & 6 deletions src/sketch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ impl Sketch {
let mut data = DVector::zeros(self.get_n_dofs());
let mut i = 0;
for primitive in self.primitives.iter() {
let primitive_data = primitive.1.borrow().get_data();
let p = primitive.1.borrow();
let primitive_data = p.get_data();
data.rows_mut(i, primitive_data.len())
.copy_from(&primitive_data);
i += primitive_data.len();
Expand Down Expand Up @@ -135,7 +136,8 @@ impl Sketch {
let mut gradient = DVector::zeros(self.get_n_dofs());
let mut i = 0;
for primitive in self.primitives.iter() {
let primitive_gradient = primitive.1.borrow().get_gradient();
let p = primitive.1.borrow();
let primitive_gradient = p.get_gradient();
assert!(
primitive_gradient.iter().all(|x| x.is_finite()),
"Gradient contains NaN or Inf"
Expand Down Expand Up @@ -168,7 +170,8 @@ impl Sketch {
// Copy the gradient of the constraint to the jacobian
let mut j = 0;
for primitive in self.primitives.iter() {
let primitive_gradient = primitive.1.borrow().get_gradient();
let p = primitive.1.borrow();
let primitive_gradient = p.get_gradient();
jacobian
.row_mut(i)
.columns_mut(j, primitive_gradient.len())
Expand Down Expand Up @@ -202,13 +205,13 @@ impl Sketch {
// Compare to numerical gradients
let constraint_loss = constraint.borrow().loss_value();
for primitive in self.primitives.iter_mut() {
let original_value = primitive.1.borrow().get_data();
let analytical_gradient = primitive.1.borrow().get_gradient();
let original_value = primitive.1.borrow().get_data().clone_owned();
let analytical_gradient = primitive.1.borrow().get_gradient().clone_owned();
let mut numerical_gradient = DVector::zeros(original_value.len());
let n = primitive.1.borrow().get_data().len();
assert!(n == analytical_gradient.len());
for i in 0..n {
let mut new_value = original_value.clone();
let mut new_value = original_value.clone_owned();
new_value[i] += epsilon;
primitive
.1
Expand Down
48 changes: 14 additions & 34 deletions src/solvers/bfgs_solver.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use std::ops::DerefMut;
use std::{cell::RefCell, error::Error, rc::Rc};

use nalgebra::{DMatrix, UniformNorm};

use crate::sketch::Sketch;
use crate::solvers::line_search::line_search_wolfe;

use super::Solver;

const WOLFE_C1: f64 = 1e-4;
const WOLFE_C2: f64 = 0.9;

pub struct BFGSSolver {
max_iterations: usize,
min_loss: f64,
Expand Down Expand Up @@ -43,7 +42,6 @@ impl Solver for BFGSSolver {
);

let mut data = sketch.borrow().get_data();
let mut alpha = 1.0;
while iterations < self.max_iterations {
let loss = sketch.borrow_mut().get_loss();
if loss < self.min_loss {
Expand All @@ -64,32 +62,7 @@ impl Solver for BFGSSolver {
return Err("search direction contains non-finite values".into());
}

let m = gradient.dot(&p);
if m > 0.0 {
return Err("search direction is not a descent direction".into());
}

let curvature_condition = WOLFE_C2 * m;
while alpha > 1e-16 {
let new_data = &data + alpha * &p;
sketch.borrow_mut().set_data(new_data);
let new_loss = sketch.borrow_mut().get_loss();
// Sufficient decrease condition
if new_loss <= loss + WOLFE_C1 * alpha * m {
// Curvature condition
let new_gradient = sketch.borrow_mut().get_gradient();
let curvature = p.dot(&new_gradient);
if curvature >= curvature_condition {
break;
}
alpha *= 1.5;
} else {
alpha *= 0.5;
}
}
if alpha < 1e-16 {
return Err("could not find a suitable step size".into());
}
let alpha = line_search_wolfe(sketch.borrow_mut().deref_mut(), &p, &gradient)?;

let s = alpha * &p;

Expand All @@ -105,10 +78,17 @@ impl Solver for BFGSSolver {
// println!("Warning: s_dot_y is too small");
s_dot_y += 1e-6;
}
let factor = s_dot_y + (y.transpose() * &h * &y)[(0, 0)];
let new_h = &h + factor * (&s * s.transpose()) / (s_dot_y * s_dot_y)
- (&h * &y * s.transpose() + &s * &y.transpose() * &h) / s_dot_y;
h = new_h;

let hy = &h * &y;
let factor = (s_dot_y + y.dot(&hy)) / (s_dot_y * s_dot_y);
// h = 1.0*h + factor * s * s'
h.ger(factor, &s, &s, 1.0);

let hys_factor = -1.0 / s_dot_y;
// h = 1.0*h - hy * s' / s_dot_y
h.ger(hys_factor, &hy, &s, 1.0);
// h = 1.0*h - s' * hy' / s_dot_y
h.ger(hys_factor, &s, &hy, 1.0);

iterations += 1;
}
Expand Down
39 changes: 39 additions & 0 deletions src/solvers/line_search.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
use crate::sketch::Sketch;
use nalgebra::DVector;
use std::error::Error;

const WOLFE_C1: f64 = 1e-4;
const WOLFE_C2: f64 = 0.9;

pub(crate) fn line_search_wolfe(
sketch: &mut Sketch,
direction: &DVector<f64>,
gradient: &DVector<f64>,
) -> Result<f64, Box<dyn Error>> {
let mut alpha = 1.0;
let m = gradient.dot(direction);
if m >= 0.0 {
return Err("line search failed: search direction is not a descent direction".into());
}
let curvature_condition = WOLFE_C2 * m;
let loss = sketch.get_loss();
let x0 = sketch.get_data();
while alpha > 1e-16 {
let data = &x0 + alpha * direction;
sketch.set_data(data);
let new_loss = sketch.get_loss();
// Sufficent decrease condition
if new_loss <= loss + WOLFE_C1 * alpha * m {
// Curvature condition
let new_gradient = sketch.get_gradient();
let curvature = new_gradient.dot(direction);
if curvature >= curvature_condition {
return Ok(alpha);
}
alpha *= 1.5;
} else {
alpha *= 0.5;
}
}
Err("line search failed: alpha is too small".into())
}
2 changes: 2 additions & 0 deletions src/solvers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::{cell::RefCell, error::Error, rc::Rc};

use crate::sketch::Sketch;

mod line_search;

pub mod bfgs_solver;
pub mod gauss_newton_solver;
pub mod gradient_based_solver;
Expand Down

0 comments on commit 397d85b

Please sign in to comment.