Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggestion for implementing user callbacks #108

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ if(NOT AMPLSOLVER)
message(WARNING "Optional library amplsolver (ASL) was not found.")
else()
message(STATUS "Library amplsolver was found.")
add_executable(uno_ampl bindings/AMPL/AMPLModel.cpp bindings/AMPL/uno_ampl.cpp)
add_executable(uno_ampl bindings/AMPL/AMPLModel.cpp bindings/AMPL/AMPLUserCallbacks.cpp bindings/AMPL/uno_ampl.cpp)

target_link_libraries(uno_ampl PUBLIC uno ${AMPLSOLVER} ${CMAKE_DL_LIBS})
add_definitions("-D HAS_AMPLSOLVER")
Expand Down
20 changes: 20 additions & 0 deletions bindings/AMPL/AMPLUserCallbacks.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2024 Charlie Vanaret
// Licensed under the MIT license. See LICENSE file in the project directory for details.

#include "AMPLUserCallbacks.hpp"
#include "linear_algebra/Vector.hpp"
#include "optimization/Multipliers.hpp"

namespace uno {
AMPLUserCallbacks::AMPLUserCallbacks(): UserCallbacks() { }

void AMPLUserCallbacks::notify_acceptable_iterate(const Vector<double>& /*primals*/, const Multipliers& /*multipliers*/,
double /*objective_multiplier*/) {
}

void AMPLUserCallbacks::notify_new_primals(const Vector<double>& /*primals*/) {
}

void AMPLUserCallbacks::notify_new_multipliers(const Multipliers& /*multipliers*/) {
}
} // namespace
20 changes: 20 additions & 0 deletions bindings/AMPL/AMPLUserCallbacks.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) 2024 Charlie Vanaret
// Licensed under the MIT license. See LICENSE file in the project directory for details.

#ifndef UNO_AMPLUSERCALLBACKS_H
#define UNO_AMPLUSERCALLBACKS_H

#include "tools/UserCallbacks.hpp"

namespace uno {
class AMPLUserCallbacks: public UserCallbacks {
public:
AMPLUserCallbacks();

void notify_acceptable_iterate(const Vector<double>& primals, const Multipliers& multipliers, double objective_multiplier) override;
void notify_new_primals(const Vector<double>& primals) override;
void notify_new_multipliers(const Multipliers& multipliers) override;
};
} // namespace

#endif //UNO_AMPLUSERCALLBACKS_H
6 changes: 5 additions & 1 deletion bindings/AMPL/uno_ampl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ingredients/constraint_relaxation_strategies/ConstraintRelaxationStrategy.hpp"
#include "ingredients/constraint_relaxation_strategies/ConstraintRelaxationStrategyFactory.hpp"
#include "AMPLModel.hpp"
#include "AMPLUserCallbacks.hpp"
#include "Uno.hpp"
#include "model/ModelFactory.hpp"
#include "options/DefaultOptions.hpp"
Expand Down Expand Up @@ -50,8 +51,11 @@ namespace uno {
auto globalization_mechanism = GlobalizationMechanismFactory::create(*constraint_relaxation_strategy, options);
Uno uno = Uno(*globalization_mechanism, options);

// create the user callbacks
AMPLUserCallbacks user_callbacks{};

// solve the instance
uno.solve(*model, initial_iterate, options);
uno.solve(*model, initial_iterate, options, user_callbacks);
// std::cout << "memory_allocation_amount = " << memory_allocation_amount << '\n';
}
catch (std::exception& exception) {
Expand Down
14 changes: 13 additions & 1 deletion uno/Uno.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "options/Options.hpp"
#include "tools/Statistics.hpp"
#include "tools/Timer.hpp"
#include "tools/UserCallbacks.hpp"

namespace uno {
Uno::Uno(GlobalizationMechanism& globalization_mechanism, const Options& options) :
Expand All @@ -30,7 +31,15 @@ namespace uno {

Level Logger::level = INFO;

// solve without user callbacks
void Uno::solve(const Model& model, Iterate& current_iterate, const Options& options) {
// pass user callbacks that do nothing
NoUserCallbacks user_callbacks{};
this->solve(model, current_iterate, options, user_callbacks);
}

// solve with user callbacks
void Uno::solve(const Model& model, Iterate& current_iterate, const Options& options, UserCallbacks& user_callbacks) {
Timer timer{};
Statistics statistics = Uno::create_statistics(model, options);
WarmstartInformation warmstart_information{};
Expand All @@ -54,8 +63,11 @@ namespace uno {

// compute an acceptable iterate by solving a subproblem at the current point
warmstart_information.iterate_changed();
this->globalization_mechanism.compute_next_iterate(statistics, model, current_iterate, trial_iterate, warmstart_information);
this->globalization_mechanism.compute_next_iterate(statistics, model, current_iterate, trial_iterate, warmstart_information, user_callbacks);
termination = this->termination_criteria(trial_iterate.status, major_iterations, timer.get_duration());
user_callbacks.notify_new_primals(trial_iterate.primals);
user_callbacks.notify_new_multipliers(trial_iterate.multipliers);

// the trial iterate becomes the current iterate for the next iteration
std::swap(current_iterate, trial_iterate);
}
Expand Down
3 changes: 3 additions & 0 deletions uno/Uno.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ namespace uno {
class Options;
class Statistics;
class Timer;
class UserCallbacks;

class Uno {
public:
Uno(GlobalizationMechanism& globalization_mechanism, const Options& options);

// solve with or without user callbacks
void solve(const Model& model, Iterate& initial_iterate, const Options& options);
void solve(const Model& model, Iterate& initial_iterate, const Options& options, UserCallbacks& user_callbacks);

static std::string current_version();
static void print_available_strategies();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace uno {
class Subproblem;
template <typename IndexType, typename ElementType>
class SymmetricMatrix;
class UserCallbacks;
template <typename ElementType>
class Vector;
struct WarmstartInformation;
Expand All @@ -50,7 +51,7 @@ namespace uno {

// trial iterate acceptance
[[nodiscard]] virtual bool is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& warmstart_information) = 0;
double step_length, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) = 0;
[[nodiscard]] TerminationStatus check_termination(Iterate& iterate);

// primal-dual residuals
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
#include "model/Model.hpp"
#include "optimization/Iterate.hpp"
#include "optimization/WarmstartInformation.hpp"
#include "symbolic/VectorView.hpp"
#include "options/Options.hpp"
#include "symbolic/VectorView.hpp"
#include "tools/UserCallbacks.hpp"

namespace uno {
FeasibilityRestoration::FeasibilityRestoration(const Model& model, const Options& options) :
Expand Down Expand Up @@ -148,7 +149,7 @@ namespace uno {
}

bool FeasibilityRestoration::is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& warmstart_information) {
double step_length, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
// TODO pick right multipliers
this->subproblem->postprocess_iterate(this->current_problem(), trial_iterate);
this->compute_progress_measures(current_iterate, trial_iterate);
Expand Down Expand Up @@ -176,6 +177,11 @@ namespace uno {
predicted_reduction, this->current_problem().get_objective_multiplier());
}
ConstraintRelaxationStrategy::set_progress_statistics(statistics, trial_iterate);
if (accept_iterate) {
user_callbacks.notify_acceptable_iterate(trial_iterate.primals,
this->current_phase == Phase::OPTIMALITY ? trial_iterate.multipliers : trial_iterate.feasibility_multipliers,
this->current_problem().get_objective_multiplier());
}
return accept_iterate;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace uno {

// trial iterate acceptance
[[nodiscard]] bool is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& warmstart_information) override;
double step_length, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) override;

// primal-dual residuals
void compute_primal_dual_residuals(Iterate& iterate) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "symbolic/VectorView.hpp"
#include "options/Options.hpp"
#include "tools/Statistics.hpp"
#include "tools/UserCallbacks.hpp"

/*
* Infeasibility detection and SQP methods for nonlinear optimization
Expand Down Expand Up @@ -233,7 +234,7 @@ namespace uno {
}

bool l1Relaxation::is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& /*warmstart_information*/) {
double step_length, WarmstartInformation& /*warmstart_information*/, UserCallbacks& user_callbacks) {
this->subproblem->postprocess_iterate(this->l1_relaxed_problem, trial_iterate);
this->compute_progress_measures(current_iterate, trial_iterate);
trial_iterate.objective_multiplier = this->l1_relaxed_problem.get_objective_multiplier();
Expand All @@ -254,6 +255,7 @@ namespace uno {
if (accept_iterate) {
this->check_exact_relaxation(trial_iterate);
// this->set_dual_residuals_statistics(statistics, trial_iterate);
user_callbacks.notify_acceptable_iterate(trial_iterate.primals, trial_iterate.multipliers, this->penalty_parameter);
}
this->set_progress_statistics(statistics, trial_iterate);
return accept_iterate;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace uno {

// trial iterate acceptance
[[nodiscard]] bool is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
double step_length, WarmstartInformation& warmstart_information) override;
double step_length, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) override;

// primal-dual residuals
void compute_primal_dual_residuals(Iterate& iterate) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ namespace uno {
}

void BacktrackingLineSearch::compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) {
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
DEBUG2 << "Current iterate\n" << current_iterate << '\n';

this->constraint_relaxation_strategy.compute_feasible_direction(statistics, current_iterate, this->direction, warmstart_information);
BacktrackingLineSearch::check_unboundedness(this->direction);
this->backtrack_along_direction(statistics, model, current_iterate, trial_iterate, warmstart_information);
this->backtrack_along_direction(statistics, model, current_iterate, trial_iterate, warmstart_information, user_callbacks);
}

// go a fraction along the direction by finding an acceptable step length
void BacktrackingLineSearch::backtrack_along_direction(Statistics& statistics, const Model& model, Iterate& current_iterate,
Iterate& trial_iterate, WarmstartInformation& warmstart_information) {
Iterate& trial_iterate, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
double step_length = 1.;
bool termination = false;
size_t number_iterations = 0;
Expand All @@ -59,7 +59,7 @@ namespace uno {
this->scale_duals_with_step_length ? step_length : 1.);

is_acceptable = this->constraint_relaxation_strategy.is_iterate_acceptable(statistics, current_iterate, trial_iterate, this->direction,
step_length, warmstart_information);
step_length, warmstart_information, user_callbacks);
this->set_statistics(statistics, trial_iterate, this->direction, step_length, number_iterations);
}
catch (const EvaluationError& e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ namespace uno {

void initialize(Statistics& statistics, Iterate& initial_iterate, const Options& options) override;
void compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) override;
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) override;

private:
const double backtracking_ratio;
const double minimum_step_length;
const bool scale_duals_with_step_length;

void backtrack_along_direction(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information);
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks);
[[nodiscard]] bool terminate_with_small_step_length(Statistics& statistics, Iterate& trial_iterate);
[[nodiscard]] double decrease_step_length(double step_length) const;
static void check_unboundedness(const Direction& direction);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace uno {
class Model;
class Options;
class Statistics;
class UserCallbacks;
struct WarmstartInformation;

class GlobalizationMechanism {
Expand All @@ -22,7 +23,7 @@ namespace uno {

virtual void initialize(Statistics& statistics, Iterate& initial_iterate, const Options& options) = 0;
virtual void compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) = 0;
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) = 0;

[[nodiscard]] size_t get_hessian_evaluation_count() const;
[[nodiscard]] size_t get_number_subproblems_solved() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ namespace uno {
}

void TrustRegionStrategy::compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) {
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
DEBUG2 << "Current iterate\n" << current_iterate << '\n';

size_t number_iterations = 0;
Expand Down Expand Up @@ -77,7 +77,8 @@ namespace uno {
GlobalizationMechanism::assemble_trial_iterate(model, current_iterate, trial_iterate, this->direction, 1., 1.);
this->reset_active_trust_region_multipliers(model, this->direction, trial_iterate);

is_acceptable = this->is_iterate_acceptable(statistics, current_iterate, trial_iterate, this->direction, warmstart_information);
is_acceptable = this->is_iterate_acceptable(statistics, current_iterate, trial_iterate, this->direction, warmstart_information,
user_callbacks);
if (is_acceptable) {
this->constraint_relaxation_strategy.set_dual_residuals_statistics(statistics, trial_iterate);
this->reset_radius();
Expand Down Expand Up @@ -122,9 +123,9 @@ namespace uno {

// the trial iterate is accepted by the constraint relaxation strategy or if the step is small and we cannot switch to solving the feasibility problem
bool TrustRegionStrategy::is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate,
const Direction& direction, WarmstartInformation& warmstart_information) {
const Direction& direction, WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) {
bool accept_iterate = this->constraint_relaxation_strategy.is_iterate_acceptable(statistics, current_iterate, trial_iterate, direction, 1.,
warmstart_information);
warmstart_information, user_callbacks);
this->set_statistics(statistics, trial_iterate, direction);
if (accept_iterate) {
trial_iterate.status = this->constraint_relaxation_strategy.check_termination(trial_iterate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace uno {

void initialize(Statistics& statistics, Iterate& initial_iterate, const Options& options) override;
void compute_next_iterate(Statistics& statistics, const Model& model, Iterate& current_iterate, Iterate& trial_iterate,
WarmstartInformation& warmstart_information) override;
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks) override;

private:
double radius; /*!< Current trust region radius */
Expand All @@ -26,7 +26,7 @@ namespace uno {
const double tolerance;

bool is_iterate_acceptable(Statistics& statistics, Iterate& current_iterate, Iterate& trial_iterate, const Direction& direction,
WarmstartInformation& warmstart_information);
WarmstartInformation& warmstart_information, UserCallbacks& user_callbacks);
void possibly_increase_radius(double step_norm);
void decrease_radius(double step_norm);
void decrease_radius();
Expand Down
33 changes: 33 additions & 0 deletions uno/tools/UserCallbacks.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) 2024 Charlie Vanaret
// Licensed under the MIT license. See LICENSE file in the project directory for details.

#ifndef UNO_USERCALLBACKS_H
#define UNO_USERCALLBACKS_H

namespace uno {
// forward declarations
struct Multipliers;
template <class ElementType>
class Vector;

class UserCallbacks {
public:
UserCallbacks() = default;
virtual ~UserCallbacks() = default;

virtual void notify_acceptable_iterate(const Vector<double>& primals, const Multipliers& multipliers, double objective_multiplier) = 0;
virtual void notify_new_primals(const Vector<double>& primals) = 0;
virtual void notify_new_multipliers(const Multipliers& multipliers) = 0;
};

class NoUserCallbacks: public UserCallbacks {
public:
NoUserCallbacks(): UserCallbacks() { }

void notify_acceptable_iterate(const Vector<double>& /*primals*/, const Multipliers& /*multipliers*/, double /*objective_multiplier*/) override { }
void notify_new_primals(const Vector<double>& /*primals*/) override { }
void notify_new_multipliers(const Multipliers& /*multipliers*/) override { }
};
} // namespace

#endif //UNO_USERCALLBACKS_H
Loading