diff --git a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp index 9fce1776b8..14d2a705b4 100644 --- a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp +++ b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.cpp @@ -28,6 +28,46 @@ namespace gtsam { typedef internal::NonlinearOptimizerState State; +/// Fletcher-Reeves formula for computing β, the direction of steepest descent. +static double FletcherReeves(const VectorValues& currentGradient, + const VectorValues& prevGradient) { + // Fletcher-Reeves: beta = g_n'*g_n/g_n-1'*g_n-1 + const double beta = std::max(0.0, currentGradient.dot(currentGradient) / + prevGradient.dot(prevGradient)); + return beta; +} + +/// Polak-Ribiere formula for computing β, the direction of steepest descent. +static double PolakRibiere(const VectorValues& currentGradient, + const VectorValues& prevGradient) { + // Polak-Ribiere: beta = g_n'*(g_n-g_n-1)/g_n-1'*g_n-1 + const double beta = + std::max(0.0, currentGradient.dot(currentGradient - prevGradient) / + prevGradient.dot(prevGradient)); + return beta; +} + +/// The Hestenes-Stiefel formula for computing β, the direction of steepest descent. +static double HestenesStiefel(const VectorValues& currentGradient, + const VectorValues& prevGradient, + const VectorValues& direction) { + // Hestenes-Stiefel: beta = g_n'*(g_n-g_n-1)/(-s_n-1')*(g_n-g_n-1) + VectorValues d = currentGradient - prevGradient; + const double beta = std::max(0.0, currentGradient.dot(d) / -direction.dot(d)); + return beta; +} + +/// The Dai-Yuan formula for computing β, the direction of steepest descent. +static double DaiYuan(const VectorValues& currentGradient, + const VectorValues& prevGradient, + const VectorValues& direction) { + // Dai-Yuan: beta = g_n'*g_n/(-s_n-1')*(g_n-g_n-1) + const double beta = + std::max(0.0, currentGradient.dot(currentGradient) / + -direction.dot(currentGradient - prevGradient)); + return beta; +} + /** * @brief Return the gradient vector of a nonlinear factor graph * @param nfg the graph @@ -43,7 +83,7 @@ static VectorValues gradientInPlace(const NonlinearFactorGraph& nfg, NonlinearConjugateGradientOptimizer::NonlinearConjugateGradientOptimizer( const NonlinearFactorGraph& graph, const Values& initialValues, - const Parameters& params) + const Parameters& params, const DirectionMethod& directionMethod) : Base(graph, std::unique_ptr( new State(initialValues, graph.error(initialValues)))), params_(params) {} @@ -169,10 +209,22 @@ NonlinearConjugateGradientOptimizer::nonlinearConjugateGradient( } else { prevGradient = currentGradient; currentGradient = gradient(currentValues); - // Polak-Ribiere: beta = g'*(g_n-g_n-1)/g_n-1'*g_n-1 - const double beta = - std::max(0.0, currentGradient.dot(currentGradient - prevGradient) / - prevGradient.dot(prevGradient)); + double beta; + + switch (directionMethod_) { + case DirectionMethod::FletcherReeves: + beta = FletcherReeves(currentGradient, prevGradient); + break; + case DirectionMethod::PolakRibiere: + beta = PolakRibiere(currentGradient, prevGradient); + break; + case DirectionMethod::HestenesStiefel: + beta = HestenesStiefel(currentGradient, prevGradient, direction); + break; + case DirectionMethod::DaiYuan: + beta = DaiYuan(currentGradient, prevGradient, direction); + break; + } direction = currentGradient + (beta * direction); } diff --git a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h index cdc0634d6f..f9cd22361f 100644 --- a/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h +++ b/gtsam/nonlinear/NonlinearConjugateGradientOptimizer.h @@ -31,16 +31,25 @@ class GTSAM_EXPORT NonlinearConjugateGradientOptimizer typedef NonlinearOptimizerParams Parameters; typedef std::shared_ptr shared_ptr; + enum class DirectionMethod { + FletcherReeves, + PolakRibiere, + HestenesStiefel, + DaiYuan + }; + protected: Parameters params_; + DirectionMethod directionMethod_; const NonlinearOptimizerParams &_params() const override { return params_; } public: /// Constructor - NonlinearConjugateGradientOptimizer(const NonlinearFactorGraph &graph, - const Values &initialValues, - const Parameters ¶ms = Parameters()); + NonlinearConjugateGradientOptimizer( + const NonlinearFactorGraph &graph, const Values &initialValues, + const Parameters ¶ms = Parameters(), + const DirectionMethod &directionMethod = DirectionMethod::PolakRibiere); /// Destructor ~NonlinearConjugateGradientOptimizer() override {}