Skip to content

Commit

Permalink
[solvers] Add max_explored_nodes for branch_and_bound.
Browse files Browse the repository at this point in the history
Early terminate the branch and bound process if the explored nodes in the tree exceeds this number. (#19366)
  • Loading branch information
hongkai-dai authored May 10, 2023
1 parent 26ebe98 commit ec2ee07
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 34 deletions.
18 changes: 15 additions & 3 deletions bindings/pydrake/solvers/solvers_py_branch_and_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,21 @@ void DefineSolversBranchAndBound(py::module m) {
{
using Class = MixedIntegerBranchAndBound;
constexpr auto& cls_doc = doc.MixedIntegerBranchAndBound;
py::class_<Class>(m, "MixedIntegerBranchAndBound", cls_doc.doc)
.def(py::init<const MathematicalProgram&, const SolverId&>(),
py::arg("prog"), py::arg("solver_id"), cls_doc.ctor.doc)
py::class_<Class> bnb_cls(m, "MixedIntegerBranchAndBound", cls_doc.doc);

py::class_<MixedIntegerBranchAndBound::Options>(
bnb_cls, "Options", cls_doc.Options.doc)
.def(py::init<>(), cls_doc.Options.ctor.doc)
.def_readwrite("max_explored_nodes",
&MixedIntegerBranchAndBound::Options::max_explored_nodes,
cls_doc.Options.max_explored_nodes.doc);

bnb_cls
.def(py::init<const MathematicalProgram&, const SolverId&,
MixedIntegerBranchAndBound::Options>(),
py::arg("prog"), py::arg("solver_id"),
py::arg("options") = MixedIntegerBranchAndBound::Options{},
cls_doc.ctor.doc)
.def("Solve", &Class::Solve, cls_doc.Solve.doc)
.def("GetOptimalCost", &Class::GetOptimalCost,
cls_doc.GetOptimalCost.doc)
Expand Down
21 changes: 14 additions & 7 deletions bindings/pydrake/solvers/test/branch_and_bound_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,18 @@ def test(self):
prog.AddLinearConstraint(b[1] + 1.2 * x[1] - b[0] <= 5)
prog.AddQuadraticCost(x[0] * x[0])

dut = MixedIntegerBranchAndBound(prog, OsqpSolver().solver_id())
solution_result = dut.Solve()
dut1 = MixedIntegerBranchAndBound(prog, OsqpSolver().solver_id())
solution_result = dut1.Solve()
self.assertEqual(solution_result, SolutionResult.kSolutionFound)
self.assertAlmostEqual(dut.GetOptimalCost(), 1.)
self.assertAlmostEqual(dut.GetSubOptimalCost(0), 1.)
self.assertAlmostEqual(dut.GetSolution(x[0], 0), 1.)
self.assertAlmostEqual(dut.GetSolution(x[0], 1), 1.)
np.testing.assert_allclose(dut.GetSolution(x, 0), [1., 0.], atol=1e-12)
self.assertAlmostEqual(dut1.GetOptimalCost(), 1.)
self.assertAlmostEqual(dut1.GetSubOptimalCost(0), 1.)
self.assertAlmostEqual(dut1.GetSolution(x[0], 0), 1.)
self.assertAlmostEqual(dut1.GetSolution(x[0], 1), 1.)
np.testing.assert_allclose(
dut1.GetSolution(x, 0), [1., 0.], atol=1e-12)

options = MixedIntegerBranchAndBound.Options()
options.max_explored_nodes = 1
dut2 = MixedIntegerBranchAndBound(
prog=prog, solver_id=OsqpSolver().solver_id(), options=options)
solution_result = dut2.Solve()
50 changes: 38 additions & 12 deletions solvers/branch_and_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,22 @@ bool MixedIntegerBranchAndBoundNode::optimal_solution_is_integral() const {
DRAKE_UNREACHABLE();
}

bool MixedIntegerBranchAndBoundNode::is_explored() const {
return prog_result_.get() != nullptr;
}

int MixedIntegerBranchAndBoundNode::NumExploredNodesInSubtree() const {
// First count this node as the root of the subtree.
int ret = is_explored();
if (left_child_.get() != nullptr) {
ret += left_child_->NumExploredNodesInSubtree();
}
if (right_child_.get() != nullptr) {
ret += right_child_->NumExploredNodesInSubtree();
}
return ret;
}

bool IsVariableInList(const std::list<symbolic::Variable>& variable_list,
const symbolic::Variable& variable) {
for (const auto& var : variable_list) {
Expand Down Expand Up @@ -314,8 +330,10 @@ void MixedIntegerBranchAndBoundNode::Branch(
}

MixedIntegerBranchAndBound::MixedIntegerBranchAndBound(
const MathematicalProgram& prog, const SolverId& solver_id)
const MathematicalProgram& prog, const SolverId& solver_id,
MixedIntegerBranchAndBound::Options options)
: root_{nullptr},
options_{std::move(options)},
map_old_vars_to_new_vars_{},
best_upper_bound_{std::numeric_limits<double>::infinity()},
best_lower_bound_{-std::numeric_limits<double>::infinity()},
Expand Down Expand Up @@ -360,18 +378,26 @@ SolutionResult MixedIntegerBranchAndBound::Solve() {
}
MixedIntegerBranchAndBoundNode* branching_node = PickBranchingNode();
while (branching_node) {
// Found a branching node, branch on this node. If no branching node is
// found, then every leaf node is fathomed, the branch-and-bound process
// should terminate.
// TODO(hongkai.dai) We might need to have a function that picks the
// branching node together with the branching variable simultaneously.
const symbolic::Variable* branching_variable =
PickBranchingVariable(*branching_node);
BranchAndUpdate(branching_node, *branching_variable);
if (HasConverged()) {
return SolutionResult::kSolutionFound;
// Each branch will create two new nodes. So if the current number of nodes
// + 2 is larger than options_.max_explored_nodes, we don't branch
// any more.
if (options_.max_explored_nodes >= 1 &&
root_->NumExploredNodesInSubtree() + 2 > options_.max_explored_nodes) {
return SolutionResult::kIterationLimit;
} else {
// Found a branching node, branch on this node. If no branching node is
// found, then every leaf node is fathomed, the branch-and-bound process
// should terminate.
// TODO(hongkai.dai) We might need to have a function that picks the
// branching node together with the branching variable simultaneously.
const symbolic::Variable* branching_variable =
PickBranchingVariable(*branching_node);
BranchAndUpdate(branching_node, *branching_variable);
if (HasConverged()) {
return SolutionResult::kSolutionFound;
}
branching_node = PickBranchingNode();
}
branching_node = PickBranchingNode();
}
// No node to branch.
if (best_lower_bound_ == -std::numeric_limits<double>::infinity()) {
Expand Down
32 changes: 25 additions & 7 deletions solvers/branch_and_bound.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,23 @@ class MixedIntegerBranchAndBoundNode {
/** Getter for solver id. */
const SolverId& solver_id() const { return solver_id_; }

private:
/**
* If the solution to a binary variable is either less than integral_tol or
* larger than 1 - integral_tol, then we regard the solution to be binary.
* This method set this tolerance.
/** If the mathematical program in this node has been solved and the result is
* stored inside this node, then we say this node has been explored. */
[[nodiscard]] bool is_explored() const;

/** Returns the total number of explored nodes in the subtree
* (including this node if it has been explored).
*/
[[nodiscard]] int NumExploredNodesInSubtree() const;

private:
// If the solution to a binary variable is either less than integral_tol or
// larger than 1 - integral_tol, then we regard the solution to be binary.
// This method set this tolerance.
void set_integral_tolerance(double integral_tol) {
integral_tol_ = integral_tol;
}

private:
// Constructs an empty node. Clone the input mathematical program to this
// node. The child and the parent nodes are all nullptr.
// @param prog The optimization program whose binary variable constraints are
Expand Down Expand Up @@ -273,6 +279,15 @@ class MixedIntegerBranchAndBound {
kMinLowerBound, ///< Pick the node with the smallest optimal cost.
};

struct Options {
Options() {}
// The maximal number of explored nodes in the tree. The branch and bound
// process will terminate if the tree has explored this number of nodes.
// max_explored_nodes <= 0 means that we don't put an upper bound on
// the number of explored nodes.
int max_explored_nodes{-1};
};

/**
* The function signature for the user defined method to pick a branching node
* or a branching variable.
Expand All @@ -292,7 +307,8 @@ class MixedIntegerBranchAndBound {
* @param solver_id The ID of the solver for the optimization.
*/
explicit MixedIntegerBranchAndBound(const MathematicalProgram& prog,
const SolverId& solver_id);
const SolverId& solver_id,
Options options = Options{});

/**
* Solve the mixed-integer problem (MIP) through a branch and bound process.
Expand Down Expand Up @@ -637,6 +653,8 @@ class MixedIntegerBranchAndBound {
// The root node of the tree.
std::unique_ptr<MixedIntegerBranchAndBoundNode> root_;

MixedIntegerBranchAndBound::Options options_;

// We re-created the decision variables in the optimization program in the
// branch-and-bound. All nodes uses the same new set of decision variables,
// which is different from the variables in the original mixed-integer program
Expand Down
Loading

0 comments on commit ec2ee07

Please sign in to comment.