Skip to content

Commit

Permalink
[CP-SAT] improve work sharing
Browse files Browse the repository at this point in the history
  • Loading branch information
lperron committed Nov 20, 2024
1 parent 91b3c10 commit b21268c
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 86 deletions.
1 change: 1 addition & 0 deletions ortools/sat/cp_model_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ absl::flat_hash_map<std::string, SatParameters> GetNamedParameters(
SatParameters new_params = base_params;
new_params.set_use_shared_tree_search(true);
new_params.set_search_branching(SatParameters::AUTOMATIC_SEARCH);
new_params.set_linearization_level(0);

// These settings don't make sense with shared tree search, turn them off as
// they can break things.
Expand Down
7 changes: 5 additions & 2 deletions ortools/sat/cp_model_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1517,12 +1517,15 @@ void SolveCpModelParallel(SharedClasses* shared, Model* global_model) {
NeighborhoodGeneratorHelper* helper = unique_helper.get();
subsolvers.push_back(std::move(unique_helper));

// How many shared tree workers to run?
const int num_shared_tree_workers = shared->shared_tree_manager->NumWorkers();

// Add shared tree workers if asked.
if (params.shared_tree_num_workers() > 0 &&
if (num_shared_tree_workers >= 2 &&
shared->model_proto.assumptions().empty()) {
for (const SatParameters& local_params : RepeatParameters(
name_filter.Filter({name_to_params.at("shared_tree")}),
params.shared_tree_num_workers())) {
num_shared_tree_workers)) {
full_worker_subsolvers.push_back(std::make_unique<FullProblemSolver>(
local_params.name(), local_params,
/*split_in_chunks=*/params.interleave_search(), shared));
Expand Down
4 changes: 1 addition & 3 deletions ortools/sat/parameters_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ std::string ValidateParameters(const SatParameters& params) {
TEST_IS_FINITE(restart_dl_average_ratio);
TEST_IS_FINITE(restart_lbd_average_ratio);
TEST_IS_FINITE(shared_tree_open_leaves_per_worker);
TEST_IS_FINITE(shared_tree_worker_objective_split_probability);
TEST_IS_FINITE(shaving_search_deterministic_time);
TEST_IS_FINITE(strategy_change_increase_ratio);
TEST_IS_FINITE(symmetry_detection_deterministic_time_limit);
Expand All @@ -103,7 +102,7 @@ std::string ValidateParameters(const SatParameters& params) {
const int kMaxReasonableParallelism = 10'000;
TEST_IN_RANGE(num_workers, 0, kMaxReasonableParallelism);
TEST_IN_RANGE(num_search_workers, 0, kMaxReasonableParallelism);
TEST_IN_RANGE(shared_tree_num_workers, 0, kMaxReasonableParallelism);
TEST_IN_RANGE(shared_tree_num_workers, -1, kMaxReasonableParallelism);
TEST_IN_RANGE(interleave_batch_size, 0, kMaxReasonableParallelism);
TEST_IN_RANGE(shared_tree_open_leaves_per_worker, 1,
kMaxReasonableParallelism);
Expand All @@ -113,7 +112,6 @@ std::string ValidateParameters(const SatParameters& params) {
TEST_IN_RANGE(mip_max_activity_exponent, 1, 62);
TEST_IN_RANGE(mip_max_bound, 0, 1e17);
TEST_IN_RANGE(solution_pool_size, 1, std::numeric_limits<int32_t>::max());
TEST_IN_RANGE(shared_tree_worker_objective_split_probability, 0.0, 1.0);

// Feasibility jump.
TEST_NOT_NAN(feasibility_jump_decay);
Expand Down
4 changes: 2 additions & 2 deletions ortools/sat/parameters_validation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ TEST(ValidateParameters, LinearizationLevel) {

TEST(ValidateParameters, NumSharedTreeSearchWorkers) {
SatParameters params;
params.set_shared_tree_num_workers(-1);
EXPECT_THAT(ValidateParameters(params), HasSubstr("should be in [0,10000]"));
params.set_shared_tree_num_workers(-2);
EXPECT_THAT(ValidateParameters(params), HasSubstr("should be in [-1,10000]"));
}

TEST(ValidateParameters, SharedTreeSearchMaxNodesPerWorker) {
Expand Down
26 changes: 12 additions & 14 deletions ortools/sat/sat_parameters.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ option java_multiple_files = true;
// Contains the definitions for all the sat algorithm parameters and their
// default values.
//
// NEXT TAG: 304
// NEXT TAG: 305
message SatParameters {
// In some context, like in a portfolio of search, it makes sense to name a
// given parameters set for logging purpose.
Expand Down Expand Up @@ -1085,41 +1085,39 @@ message SatParameters {
// TODO(user): Add reference to paper when published.
optional double violation_ls_compound_move_probability = 259 [default = 0.5];

// Enables experimental workstealing-like shared tree search.
// If non-zero, start this many complete worker threads to explore a shared
// Enables shared tree search.
// If positive, start this many complete worker threads to explore a shared
// search tree. These workers communicate objective bounds and simple decision
// nogoods relating to the shared prefix of the tree, and will avoid exploring
// the same subtrees as one another.
// Specifying a negative number uses a heuristic to select an appropriate
// number of shared tree workeres based on the total number of workers.
optional int32 shared_tree_num_workers = 235 [default = 0];

// Set on shared subtree workers. Users should not set this directly.
optional bool use_shared_tree_search = 236 [default = false];

// After their assigned prefix, shared tree workers will branch on the
// objective with this probability. Higher numbers cause the shared tree
// search to focus on improving the lower bound over finding primal solutions.
optional double shared_tree_worker_objective_split_probability = 237
[default = 0.5];

// Minimum number of restarts before a worker will replace a subtree
// Minimum restarts before a worker will replace a subtree
// that looks "bad" based on the average LBD of learned clauses.
optional int32 shared_tree_worker_min_restarts_per_subtree = 282
[default = 1];

// If true, workers share more of the information from their local trail.
// Specifically, literals implied by the shared tree decisions and
// the longest conflict-free assignment from the last restart (to enable
// cross-worker phase-saving).
// Specifically, literals implied by the shared tree decisions.
optional bool shared_tree_worker_enable_trail_sharing = 295 [default = true];

// If true, shared tree workers share their target phase when returning an
// assigned subtree for the next worker to use.
optional bool shared_tree_worker_enable_phase_sharing = 304 [default = true];

// How many open leaf nodes should the shared tree maintain per worker.
optional double shared_tree_open_leaves_per_worker = 281 [default = 2.0];

// In order to limit total shared memory and communication overhead, limit the
// total number of nodes that may be generated in the shared tree. If the
// shared tree runs out of unassigned leaves, workers act as portfolio
// workers. Note: this limit includes interior nodes, not just leaves.
optional int32 shared_tree_max_nodes_per_worker = 238 [default = 100000];
optional int32 shared_tree_max_nodes_per_worker = 238 [default = 10000];

enum SharedTreeSplitStrategy {
// Uses the default strategy, currently equivalent to
Expand Down
123 changes: 64 additions & 59 deletions ortools/sat/work_assignment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,22 @@ int MaxAllowedDiscrepancyPlusDepth(int num_leaves) {
}
return i;
}

int DefaultNumSharedTreeWorkers(Model* model) {
const SatParameters& params = *model->Get<SatParameters>();
// Shared tree workers are not deterministic, so don't enable them by default
// in interleaved search which is normally used to get deterministic results.
if (params.interleave_search()) return 0;
if (params.num_workers() < 16) return 0;
const bool has_objective =
model->Get<CpModelProto>()->has_objective() ||
model->Get<CpModelProto>()->has_floating_point_objective();
if (has_objective) {
return (params.num_workers() - 8) / 2;
}
return (params.num_workers() - 8) * 3 / 4;
}

} // namespace

Literal ProtoLiteral::Decode(CpModelMapping* mapping,
Expand Down Expand Up @@ -113,20 +129,9 @@ std::optional<ProtoLiteral> ProtoLiteral::EncodeInteger(
std::optional<ProtoLiteral> ProtoLiteral::Encode(Literal literal,
CpModelMapping* mapping,
IntegerEncoder* encoder) {
if (literal.Index() == kNoLiteralIndex) {
return std::nullopt;
}
int model_var =
mapping->GetProtoVariableFromBooleanVariable(literal.Variable());
if (model_var != -1) {
CHECK(mapping->IsBoolean(model_var));
ProtoLiteral result{
literal.IsPositive() ? model_var : NegatedRef(model_var),
literal.IsPositive() ? 1 : 0};
DCHECK_EQ(result.Decode(mapping, encoder), literal);
DCHECK_EQ(result.Negated().Decode(mapping, encoder), literal.Negated());
return result;
}
const std::optional<ProtoLiteral> result = EncodeLiteral(literal, mapping);
if (result.has_value()) return result;

for (auto int_lit : encoder->GetIntegerLiterals(literal)) {
auto result = EncodeInteger(int_lit, mapping);
if (result.has_value()) {
Expand All @@ -138,6 +143,22 @@ std::optional<ProtoLiteral> ProtoLiteral::Encode(Literal literal,
return std::nullopt;
}

std::optional<ProtoLiteral> ProtoLiteral::EncodeLiteral(
Literal literal, CpModelMapping* mapping) {
if (literal.Index() == kNoLiteralIndex) {
return std::nullopt;
}
int model_var =
mapping->GetProtoVariableFromBooleanVariable(literal.Variable());
if (model_var == -1) {
return std::nullopt;
}
DCHECK(mapping->IsBoolean(model_var));
ProtoLiteral result{literal.IsPositive() ? model_var : NegatedRef(model_var),
literal.IsPositive() ? 1 : 0};
return result;
}

void ProtoTrail::PushLevel(const ProtoLiteral& decision,
IntegerValue objective_lb, int node_id) {
CHECK_GT(node_id, 0);
Expand Down Expand Up @@ -209,15 +230,17 @@ absl::Span<const ProtoLiteral> ProtoTrail::Implications(int level) const {

SharedTreeManager::SharedTreeManager(Model* model)
: params_(*model->GetOrCreate<SatParameters>()),
num_workers_(std::max(1, params_.shared_tree_num_workers())),
num_workers_(params_.shared_tree_num_workers() >= 0
? params_.shared_tree_num_workers()
: DefaultNumSharedTreeWorkers(model)),
shared_response_manager_(model->GetOrCreate<SharedResponseManager>()),
num_splits_wanted_(
num_workers_ * params_.shared_tree_open_leaves_per_worker() - 1),
max_nodes_(params_.shared_tree_max_nodes_per_worker() >=
std::numeric_limits<int>::max() / num_workers_
? std::numeric_limits<int>::max()
: num_workers_ *
params_.shared_tree_max_nodes_per_worker()) {
max_nodes_(
params_.shared_tree_max_nodes_per_worker() >=
std::numeric_limits<int>::max() / std::max(num_workers_, 1)
? std::numeric_limits<int>::max()
: num_workers_ * params_.shared_tree_max_nodes_per_worker()) {
// Create the root node with a fake literal.
nodes_.push_back(
{.literal = ProtoLiteral(),
Expand Down Expand Up @@ -253,7 +276,12 @@ bool SharedTreeManager::SyncTree(ProtoTrail& path) {
if (level > 0 && !node->closed) {
NodeTrailInfo* trail_info = GetTrailInfo(node);
for (const ProtoLiteral& implication : path.Implications(level)) {
trail_info->implications.insert(implication);
auto it = trail_info->implications
.emplace(implication.proto_var(), implication.lb())
.first;
if (it->second < implication.lb()) {
it->second = implication.lb();
}
}
}
prev_level = level;
Expand Down Expand Up @@ -360,6 +388,7 @@ void SharedTreeManager::ReplaceTree(ProtoTrail& path) {
unassigned_leaves_.pop_back();
if (!leaf->closed && leaf->children[0] == nullptr) {
AssignLeaf(path, leaf);
path.SetTargetPhase(GetTrailInfo(leaf)->phase);
return;
}
}
Expand Down Expand Up @@ -539,8 +568,8 @@ void SharedTreeManager::AssignLeaf(ProtoTrail& path, Node* leaf) {
path.SetLevelImplied(path.MaxLevel());
}
if (params_.shared_tree_worker_enable_trail_sharing()) {
for (const ProtoLiteral& implication : GetTrailInfo(leaf)->implications) {
path.AddImplication(path.MaxLevel(), implication);
for (const auto& [var, lb] : GetTrailInfo(leaf)->implications) {
path.AddImplication(path.MaxLevel(), ProtoLiteral(var, lb));
}
}
}
Expand Down Expand Up @@ -723,36 +752,7 @@ bool SharedTreeWorker::NextDecision(LiteralIndex* decision_index) {
*decision_index = decision.Index();
return true;
}
if (objective_ == nullptr ||
objective_->objective_var == kNoIntegerVariable) {
return helper_->GetDecision(decision_policy, decision_index);
}
// If the current node is close to the global lower bound, maybe try to
// improve it.
const IntegerValue root_obj_lb =
integer_trail_->LevelZeroLowerBound(objective_->objective_var);
const IntegerValue root_obj_ub =
integer_trail_->LevelZeroUpperBound(objective_->objective_var);
const IntegerValue obj_split =
root_obj_lb + absl::LogUniform<int64_t>(
*random_, 0, (root_obj_ub - root_obj_lb).value());
const double objective_split_probability =
parameters_->shared_tree_worker_objective_split_probability();
return helper_->GetDecision(
[&]() -> BooleanOrIntegerLiteral {
IntegerValue obj_lb =
integer_trail_->LowerBound(objective_->objective_var);
IntegerValue obj_ub =
integer_trail_->UpperBound(objective_->objective_var);
if (obj_lb > obj_split || obj_ub <= obj_split ||
next_level > assigned_tree_.MaxLevel() + 1 ||
absl::Bernoulli(*random_, 1 - objective_split_probability)) {
return decision_policy();
}
return BooleanOrIntegerLiteral(
IntegerLiteral::LowerOrEqual(objective_->objective_var, obj_split));
},
decision_index);
return helper_->GetDecision(decision_policy, decision_index);
}

void SharedTreeWorker::MaybeProposeSplit() {
Expand Down Expand Up @@ -795,20 +795,25 @@ bool SharedTreeWorker::SyncWithSharedTree() {
<< " restarts prev depth: " << assigned_tree_.MaxLevel()
<< " target: " << assigned_tree_lbds_.WindowAverage()
<< " lbd: " << restart_policy_->LbdAverageSinceReset();
if (parameters_->shared_tree_worker_enable_trail_sharing()) {
std::vector<ProtoLiteral> phase_out;
if (parameters_->shared_tree_worker_enable_phase_sharing() &&
assigned_tree_.MaxLevel() > 0 &&
!decision_policy_->GetBestPartialAssignment().empty()) {
assigned_tree_.ClearTargetPhase();
for (Literal lit : decision_policy_->GetBestPartialAssignment()) {
auto encoded = ProtoLiteral::Encode(lit, mapping_, encoder_);
// Only set the phase for booleans to avoid creating literals on other
// workers.
auto encoded = ProtoLiteral::EncodeLiteral(lit, mapping_);
if (!encoded.has_value()) continue;
phase_out.push_back(*encoded);
assigned_tree_.SetPhase(*encoded);
}
assigned_tree_.SetPhase(phase_out);
}
manager_->ReplaceTree(assigned_tree_);
tree_assignment_restart_ = num_restarts_;
assigned_tree_lbds_.Add(restart_policy_->LbdAverageSinceReset());
restart_policy_->Reset();
if (parameters_->shared_tree_worker_enable_trail_sharing()) {
if (parameters_->shared_tree_worker_enable_phase_sharing()) {
VLOG(2) << "Importing phase of length: "
<< assigned_tree_.TargetPhase().size();
decision_policy_->ClearBestPartialAssignment();
for (const ProtoLiteral& lit : assigned_tree_.TargetPhase()) {
decision_policy_->SetTargetPolarity(DecodeDecision(lit));
Expand Down
27 changes: 21 additions & 6 deletions ortools/sat/work_assignment.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define OR_TOOLS_SAT_WORK_ASSIGNMENT_H_

#include <stdint.h>
#include <sys/stat.h>

#include <array>
#include <cmath>
Expand Down Expand Up @@ -58,6 +59,8 @@ class ProtoLiteral {
ProtoLiteral Negated() const {
return ProtoLiteral(NegatedRef(proto_var_), -lb_ + 1);
}
int proto_var() const { return proto_var_; }
IntegerValue lb() const { return lb_; }
bool operator==(const ProtoLiteral& other) const {
return proto_var_ == other.proto_var_ && lb_ == other.lb_;
}
Expand All @@ -69,9 +72,16 @@ class ProtoLiteral {

// Note you should only decode integer literals at the root level.
Literal Decode(CpModelMapping*, IntegerEncoder*) const;

// Enodes a literal as a ProtoLiteral. This can encode literals that occur in
// the proto model, and also integer bounds literals.
static std::optional<ProtoLiteral> Encode(Literal, CpModelMapping*,
IntegerEncoder*);

// As above, but will only encode literals that are boolean variables or their
// negations (i.e. not integer bounds literals).
static std::optional<ProtoLiteral> EncodeLiteral(Literal, CpModelMapping*);

private:
IntegerLiteral DecodeInteger(CpModelMapping*) const;
static std::optional<ProtoLiteral> EncodeInteger(IntegerLiteral,
Expand Down Expand Up @@ -136,11 +146,15 @@ class ProtoTrail {
absl::Span<const ProtoLiteral> Literals() const { return literals_; }

const std::vector<ProtoLiteral>& TargetPhase() const { return target_phase_; }
void SetPhase(absl::Span<const ProtoLiteral> phase) {
target_phase_.clear();
void ClearTargetPhase() { target_phase_.clear(); }
void SetPhase(const ProtoLiteral& lit) {
if (implication_level_.contains(lit)) return;
target_phase_.push_back(lit);
}
void SetTargetPhase(absl::Span<const ProtoLiteral> phase) {
ClearTargetPhase();
for (const ProtoLiteral& lit : phase) {
if (implication_level_.contains(lit)) return;
target_phase_.push_back(lit);
SetPhase(lit);
}
}

Expand Down Expand Up @@ -203,12 +217,13 @@ class SharedTreeManager {
}

private:
// Because it is quite difficult to get a flat_hash_set to release memory,
// Because it is quite difficult to get a flat_hash_map to release memory,
// we store info we need only for open nodes implications via a unique_ptr.
// Note to simplify code, the root will always have a NodeTrailInfo after it
// is closed.
struct NodeTrailInfo {
absl::flat_hash_set<ProtoLiteral> implications;
// A map from literal to the best lower bound proven at this node.
absl::flat_hash_map<int, IntegerValue> implications;
// This is only non-empty for nodes where all but one descendent is closed
// (i.e. mostly leaves).
std::vector<ProtoLiteral> phase;
Expand Down
Loading

0 comments on commit b21268c

Please sign in to comment.