Skip to content

Commit

Permalink
Don't do any smart pruning if smart pruning factor is 0 or lower. (#1307
Browse files Browse the repository at this point in the history
)

* Don't do any smart pruning if smart pruning factor is 0 or lower.

Also disable this smart pruning in benchmark and selfplay that don't support self pruning setting.

* Review feedback applied.

* Fix name of parameter.
  • Loading branch information
Tilps authored May 24, 2020
1 parent e0bd02f commit 9883568
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/benchmark/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ void Benchmark::Run() {
stopper->AddStopper(std::make_unique<TimeLimitStopper>(movetime));
}
if (visits > -1) {
stopper->AddStopper(std::make_unique<VisitsStopper>(visits));
stopper->AddStopper(std::make_unique<VisitsStopper>(visits, false));
}

NNCache cache;
Expand Down
11 changes: 7 additions & 4 deletions src/mcts/stoppers/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,19 @@ void PopulateCommonUciStoppers(ChainedSearchStopper* stopper,
const auto cache_size_mb = options.Get<int>(kNNCacheSizeId);
const int ram_limit = options.Get<int>(kRamLimitMbId);
if (ram_limit) {
stopper->AddStopper(
std::make_unique<MemoryWatchingStopper>(cache_size_mb, ram_limit));
stopper->AddStopper(std::make_unique<MemoryWatchingStopper>(
cache_size_mb, ram_limit,
options.Get<float>(kSmartPruningFactorId) > 0.0f));
}

// "go nodes" stopper.
if (params.nodes) {
if (options.Get<bool>(kNodesAsPlayoutsId)) {
stopper->AddStopper(std::make_unique<PlayoutsStopper>(*params.nodes));
stopper->AddStopper(std::make_unique<PlayoutsStopper>(
*params.nodes, options.Get<float>(kSmartPruningFactorId) > 0.0f));
} else {
stopper->AddStopper(std::make_unique<VisitsStopper>(*params.nodes));
stopper->AddStopper(std::make_unique<VisitsStopper>(
*params.nodes, options.Get<float>(kSmartPruningFactorId) > 0.0f));
}
}

Expand Down
18 changes: 12 additions & 6 deletions src/mcts/stoppers/stoppers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ void ChainedSearchStopper::OnSearchDone(const IterationStats& stats) {

bool VisitsStopper::ShouldStop(const IterationStats& stats,
StoppersHints* hints) {
hints->UpdateEstimatedRemainingRemainingPlayouts(nodes_limit_ -
stats.total_nodes);
if (populate_remaining_playouts_) {
hints->UpdateEstimatedRemainingRemainingPlayouts(nodes_limit_ -
stats.total_nodes);
}
if (stats.total_nodes >= nodes_limit_) {
LOGFILE << "Stopped search: Reached visits limit: " << stats.total_nodes
<< ">=" << nodes_limit_;
Expand All @@ -74,8 +76,10 @@ bool VisitsStopper::ShouldStop(const IterationStats& stats,

bool PlayoutsStopper::ShouldStop(const IterationStats& stats,
StoppersHints* hints) {
hints->UpdateEstimatedRemainingRemainingPlayouts(nodes_limit_ -
stats.nodes_since_movestart);
if (populate_remaining_playouts_) {
hints->UpdateEstimatedRemainingRemainingPlayouts(
nodes_limit_ - stats.nodes_since_movestart);
}
if (stats.nodes_since_movestart >= nodes_limit_) {
LOGFILE << "Stopped search: Reached playouts limit: "
<< stats.nodes_since_movestart << ">=" << nodes_limit_;
Expand All @@ -97,10 +101,12 @@ const size_t kAvgCacheItemSize =
MemoryWatchingStopper::kAvgMovesPerPosition;
} // namespace

MemoryWatchingStopper::MemoryWatchingStopper(int cache_size, int ram_limit_mb)
MemoryWatchingStopper::MemoryWatchingStopper(int cache_size, int ram_limit_mb,
bool populate_remaining_playouts)
: VisitsStopper(
(ram_limit_mb * 1000000LL - cache_size * kAvgCacheItemSize) /
kAvgNodeSize) {
kAvgNodeSize,
populate_remaining_playouts) {
LOGFILE << "RAM limit " << ram_limit_mb << "MB. Cache takes "
<< cache_size * kAvgCacheItemSize / 1000000
<< "MB. Remaining memory is enough for " << GetVisitsLimit()
Expand Down
13 changes: 10 additions & 3 deletions src/mcts/stoppers/stoppers.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,29 @@ class ChainedSearchStopper : public SearchStopper {
// Watches visits (total tree nodes) and predicts remaining visits.
class VisitsStopper : public SearchStopper {
public:
VisitsStopper(int64_t limit) : nodes_limit_(limit) {}
VisitsStopper(int64_t limit, bool populate_remaining_playouts)
: nodes_limit_(limit),
populate_remaining_playouts_(populate_remaining_playouts) {}
int64_t GetVisitsLimit() const { return nodes_limit_; }
bool ShouldStop(const IterationStats&, StoppersHints*) override;

private:
const int64_t nodes_limit_;
const bool populate_remaining_playouts_;
};

// Watches playouts (new tree nodes) and predicts remaining visits.
class PlayoutsStopper : public SearchStopper {
public:
PlayoutsStopper(int64_t limit) : nodes_limit_(limit) {}
PlayoutsStopper(int64_t limit, bool populate_remaining_playouts)
: nodes_limit_(limit),
populate_remaining_playouts_(populate_remaining_playouts) {}
int64_t GetVisitsLimit() const { return nodes_limit_; }
bool ShouldStop(const IterationStats&, StoppersHints*) override;

private:
const int64_t nodes_limit_;
const bool populate_remaining_playouts_;
};

// Computes tree size which may fit into the memory and limits by that tree
Expand All @@ -78,7 +84,8 @@ class MemoryWatchingStopper : public VisitsStopper {
public:
// Must be in sync with description at kRamLimitMbId.
static constexpr size_t kAvgMovesPerPosition = 30;
MemoryWatchingStopper(int cache_size, int ram_limit_mb);
MemoryWatchingStopper(int cache_size, int ram_limit_mb,
bool populate_remaining_playouts);
};

// Stops after time budget is gone.
Expand Down
6 changes: 4 additions & 2 deletions src/selfplay/game.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,11 @@ std::unique_ptr<ChainedSearchStopper> SelfPlayLimits::MakeSearchStopper()
const {
auto result = std::make_unique<ChainedSearchStopper>();

if (visits >= 0) result->AddStopper(std::make_unique<VisitsStopper>(visits));
if (visits >= 0) {
result->AddStopper(std::make_unique<VisitsStopper>(visits, false));
}
if (playouts >= 0) {
result->AddStopper(std::make_unique<PlayoutsStopper>(playouts));
result->AddStopper(std::make_unique<PlayoutsStopper>(playouts, false));
}
if (movetime >= 0) {
result->AddStopper(std::make_unique<TimeLimitStopper>(movetime));
Expand Down

0 comments on commit 9883568

Please sign in to comment.