Skip to content

Commit

Permalink
In smooth TM, use piggybank when it seems like the bestmove can be ov…
Browse files Browse the repository at this point in the history
  • Loading branch information
mooskagh authored and borg323 committed Feb 20, 2023
1 parent b218726 commit 402da93
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 17 deletions.
155 changes: 140 additions & 15 deletions src/mcts/stoppers/smooth.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <functional>
#include <iomanip>
#include <optional>

#include "mcts/stoppers/legacy.h"
#include "mcts/stoppers/stoppers.h"
Expand Down Expand Up @@ -83,6 +84,19 @@ class Params {
// Max number of avg move times in piggybank.
float max_piggybank_moves() const { return max_piggybank_moves_; }

// Value of the "trend-nps-update-period-ms" option. It is handed to the
// stopper as the nps update period of its visits trend watcher.
// NOTE(review): the backing member is declared as float while this getter
// returns int64_t, so the value is converted on return -- confirm intended.
int64_t trend_nps_update_period_ms() const {
return trend_nps_update_period_ms_;
}

// Expected ratio of the best move's future nps to its current nps.
float bestmove_optimism() const { return bestmove_optimism_; }

// Expected ratio of a non-best move's future nps to its current nps.
float overtaker_optimism() const { return overtaker_optimism_; }

// Piggybank use is forced while the time since the first batch of the move
// does not exceed this many milliseconds.
float force_piggybank_ms() const { return force_piggybank_ms_; }

// Move overhead.
int64_t move_overhead_ms() const { return move_overhead_ms_; }
// Returns a function that estimates remaining moves.
Expand All @@ -104,6 +118,10 @@ class Params {
const float per_move_piggybank_fraction_;
const float max_piggybank_use_;
const float max_piggybank_moves_;
const float trend_nps_update_period_ms_;
const float bestmove_optimism_;
const float overtaker_optimism_;
const float force_piggybank_ms_;
const MovesLeftEstimator moves_left_estimator_;
};

Expand Down Expand Up @@ -144,6 +162,12 @@ Params::Params(const OptionsDict& params, int64_t move_overhead)
params.GetOrDefault<float>("max-piggybank-use", 0.94f)),
max_piggybank_moves_(
params.GetOrDefault<float>("max-piggybank-moves", 36.5f)),
trend_nps_update_period_ms_(
params.GetOrDefault<int>("trend-nps-update-period-ms", 3000)),
bestmove_optimism_(params.GetOrDefault<float>("bestmove-optimism", 0.2f)),
overtaker_optimism_(
params.GetOrDefault<float>("overtaker-optimism", 4.0f)),
force_piggybank_ms_(params.GetOrDefault<int>("force-piggybank-ms", 1000)),
moves_left_estimator_(CreateMovesLeftEstimator(params)) {}

// Returns the updated value of @from, towards @to by the number of halves
Expand Down Expand Up @@ -198,9 +222,36 @@ float LinearDecay(float cur_value, float target_value,

class SmoothTimeManager;

// Keeps timestamped snapshots of per-move visit counts and extrapolates the
// observed visit rates to tell whether the currently most visited move could
// be overtaken by another move by a given time. All mutable state is guarded
// by mutex_.
class VisitsTrendWatcher {
public:
// @nps_update_period is compared against the time elapsed since the "cur"
// snapshot when deciding whether to rotate snapshots in Update().
// @bestmove_optimism and @overtaker_optimism are multipliers applied to the
// measured visit rate of the most visited move and of every other move,
// respectively, when projecting their future visit counts.
VisitsTrendWatcher(float nps_update_period, float bestmove_optimism,
float overtaker_optimism)
: nps_update_period_(nps_update_period),
bestmove_optimism_(bestmove_optimism),
overtaker_optimism_(overtaker_optimism) {}

// Feeds a new sample of per-move visit counts taken at @timestamp.
void Update(uint64_t timestamp, const std::vector<uint32_t>& visits);
// Returns true if some move other than the most visited one is projected to
// have more visits than it at time @by_which_time.
bool IsBestmoveBeingOvertaken(uint64_t by_which_time) const;

private:
const float nps_update_period_;
const float bestmove_optimism_;
const float overtaker_optimism_;

// Mutable so that the const query method can take the lock.
mutable Mutex mutex_;
// Baseline snapshot: visit rates are measured from here to the "last" one.
uint64_t prev_timestamp_ GUARDED_BY(mutex_) = 0;
std::vector<uint32_t> prev_visits_ GUARDED_BY(mutex_);
// Snapshot that becomes the baseline when Update() rotates the snapshots.
uint64_t cur_timestamp_ GUARDED_BY(mutex_) = 0;
std::vector<uint32_t> cur_visits_ GUARDED_BY(mutex_);
// Most recent sample accepted by Update().
uint64_t last_timestamp_ GUARDED_BY(mutex_) = 0;
std::vector<uint32_t> last_visits_ GUARDED_BY(mutex_);
};

class SmoothStopper : public SearchStopper {
public:
SmoothStopper(int64_t deadline_ms, int64_t allowed_piggybank_use_ms,
float nps_update_period, float bestmove_optimism,
float overtaker_optimism, int64_t forces_piggybank_ms,
SmoothTimeManager* manager);

private:
Expand All @@ -209,7 +260,9 @@ class SmoothStopper : public SearchStopper {

const int64_t deadline_ms_;
const int64_t allowed_piggybank_use_ms_;
const int64_t forced_piggybank_use_ms_;

VisitsTrendWatcher visits_trend_watcher_;
SmoothTimeManager* const manager_;
std::atomic_flag used_piggybank_;
};
Expand Down Expand Up @@ -251,19 +304,19 @@ class SmoothTimeManager : public TimeManager {
: total_move_time / move_allocated_time_ms_;
// Recompute expected move time for logging.
const float expected_move_time = move_allocated_time_ms_ * timeuse_;
// If piggybank was used, cannot update timeuse_.

int64_t piggybank_time_used = 0;
if (used_piggybank) {
piggybank_time_used = std::max(int64_t(), total_move_time - time_budget);
piggybank_time_ -= piggybank_time_used;
} else {
timeuse_ =
ExponentialDecay(timeuse_, this_move_time_use,
params_.smartpruning_timeuse_halfupdate_moves(),
this_move_time_fraction);
if (timeuse_ < params_.min_smartpruning_timeuse()) {
timeuse_ = params_.min_smartpruning_timeuse();
}
}
// If piggybank was used, time use is 100%.
timeuse_ =
ExponentialDecay(timeuse_, used_piggybank ? 1.0f : this_move_time_use,
params_.smartpruning_timeuse_halfupdate_moves(),
this_move_time_fraction);
if (timeuse_ < params_.min_smartpruning_timeuse()) {
timeuse_ = params_.min_smartpruning_timeuse();
}
// Remember final number of nodes for tree reuse estimation.
last_move_final_nodes_ = total_nodes;
Expand Down Expand Up @@ -404,8 +457,10 @@ class SmoothTimeManager : public TimeManager {
<< ", moves=" << remaining_moves << ", time=" << total_remaining_ms
<< "ms, nps=" << nps_;

return std::make_unique<SmoothStopper>(move_allocated_time_ms_,
allowed_piggybank_time_ms, this);
return std::make_unique<SmoothStopper>(
move_allocated_time_ms_, allowed_piggybank_time_ms,
params_.trend_nps_update_period_ms(), params_.bestmove_optimism(),
params_.overtaker_optimism(), params_.force_piggybank_ms(), this);
}

void UpdateTreeReuseFactor(int64_t new_move_nodes) REQUIRES(mutex_) {
Expand Down Expand Up @@ -460,11 +515,65 @@ class SmoothTimeManager : public TimeManager {
bool is_first_move_ GUARDED_BY(mutex_) = true;
};

// Records a new sample of per-move visit counts taken at @timestamp.
//
// Three snapshots are maintained: "prev" is the baseline that visit rates are
// measured from, "cur" is the candidate for the next baseline, and "last" is
// the most recent sample. Once at least nps_update_period_ has elapsed since
// "cur" was taken, "cur" is promoted to "prev" and the new sample becomes
// "cur". After warm-up the [prev, last] measurement window therefore covers
// roughly one to two update periods.
//
// Samples whose timestamp is not strictly newer than the latest accepted one
// are ignored.
void VisitsTrendWatcher::Update(uint64_t timestamp,
                                const std::vector<uint32_t>& visits) {
  Mutex::Lock lock(mutex_);
  if (timestamp <= last_timestamp_) return;
  if (prev_visits_.empty()) {
    // First sample: seed both the baseline and its successor with it.
    prev_visits_ = visits;
    cur_visits_ = visits;
    prev_timestamp_ = timestamp;
    cur_timestamp_ = timestamp;
  }
  last_timestamp_ = timestamp;
  last_visits_ = visits;
  // Rotate only after a full update period has passed since "cur". The
  // comparison used to be `>=`, which rotated on every sample taken within
  // the period (collapsing the window to a single inter-sample interval) and
  // stopped rotating for good as soon as one gap exceeded the period.
  if (cur_timestamp_ + nps_update_period_ <= timestamp) {
    prev_timestamp_ = cur_timestamp_;
    prev_visits_ = std::move(cur_visits_);
    cur_visits_ = last_visits_;
    cur_timestamp_ = last_timestamp_;
  }
}

// Extrapolates per-move visit rates to @by_which_time and reports whether any
// move other than the currently most visited one is projected to end up with
// strictly more visits than it.
//
// The rate of each move is its visit gain between the "prev" and "last"
// snapshots divided by the time between them. The most visited move's rate is
// scaled by bestmove_optimism_, every other move's by overtaker_optimism_.
// Returns false when there is no measurement window yet or when
// @by_which_time is not later than the last sample.
bool VisitsTrendWatcher::IsBestmoveBeingOvertaken(
    uint64_t by_which_time) const {
  Mutex::Lock lock(mutex_);
  // Without a non-empty measurement window there is no trend to extrapolate.
  if (prev_timestamp_ >= last_timestamp_) return false;
  // The target time has to lie after the latest sample.
  if (by_which_time <= last_timestamp_) return false;

  const auto window = last_timestamp_ - prev_timestamp_;
  const auto horizon = by_which_time - last_timestamp_;
  // Projected visit count of move @idx at @by_which_time, with its measured
  // rate scaled by @optimism.
  const auto projected_visits = [&](size_t idx, float optimism) {
    const float rate =
        static_cast<float>(last_visits_[idx] - prev_visits_[idx]) / window;
    return last_visits_[idx] + optimism * rate * horizon;
  };

  const size_t leader =
      std::max_element(last_visits_.begin(), last_visits_.end()) -
      last_visits_.begin();
  const auto leader_projection = projected_visits(leader, bestmove_optimism_);
  for (size_t idx = 0; idx < last_visits_.size(); ++idx) {
    if (idx != leader &&
        projected_visits(idx, overtaker_optimism_) > leader_projection) {
      return true;
    }
  }
  return false;
}

SmoothStopper::SmoothStopper(int64_t deadline_ms,
int64_t allowed_piggybank_use_ms,
float nps_update_period, float bestmove_optimism,
float overtaker_optimism,
int64_t forced_piggybank_use_ms,
SmoothTimeManager* manager)
: deadline_ms_(deadline_ms),
allowed_piggybank_use_ms_(allowed_piggybank_use_ms),
forced_piggybank_use_ms_(forced_piggybank_use_ms),
visits_trend_watcher_(nps_update_period, bestmove_optimism,
overtaker_optimism),
manager_(manager) {
used_piggybank_.clear();
}
Expand All @@ -478,21 +587,37 @@ bool SmoothStopper::ShouldStop(const IterationStats& stats,
return true;
}

visits_trend_watcher_.Update(stats.time_since_movestart, stats.edge_n);
const auto deadline_with_piggybank = deadline_ms_ + allowed_piggybank_use_ms_;
const bool force_use_piggybank =
stats.time_since_first_batch <= forced_piggybank_use_ms_;
const bool use_piggybank =
(stats.time_usage_hint_ == IterationStats::TimeUsageHint::kNeedMoreTime);
(stats.time_usage_hint_ == IterationStats::TimeUsageHint::kNeedMoreTime ||
force_use_piggybank ||
visits_trend_watcher_.IsBestmoveBeingOvertaken(deadline_with_piggybank));
const int64_t time_limit =
use_piggybank ? (deadline_ms_ + allowed_piggybank_use_ms_) : deadline_ms_;
use_piggybank ? deadline_with_piggybank : deadline_ms_;
hints->UpdateEstimatedNps(nps);
hints->UpdateEstimatedRemainingTimeMs(time_limit -
stats.time_since_movestart);
if (use_piggybank && stats.time_since_movestart >= deadline_ms_) {
// This is not entirely accurate: because the estimated remaining time is
// extended, smart pruning triggers later, so we spend more time than we
// would if use_piggybank were false, even before reaching the deadline.
used_piggybank_.test_and_set();
if (!used_piggybank_.test_and_set()) {
LOGFILE << "Entering piggybank, reason: "
<< (stats.time_usage_hint_ ==
IterationStats::TimeUsageHint::kNeedMoreTime
? "requested by search."
: force_use_piggybank
? "forced used in the beginning of the move."
: "bestmove can be overtaken.");
}
}
if (stats.time_since_movestart >= time_limit) {
LOGFILE << "Stopping search: Ran out of time.";
LOGFILE << "Stopping search: Ran out of time. elapsed=" << std::fixed
<< stats.time_since_movestart << " limit=" << time_limit
<< " piggy=" << use_piggybank;
return true;
}
return false;
Expand Down
5 changes: 3 additions & 2 deletions src/mcts/stoppers/stoppers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,9 @@ bool SmartPruningStopper::ShouldStop(const IterationStats& stats,
}

if (remaining_playouts < (largest_n - second_largest_n)) {
LOGFILE << remaining_playouts << " playouts remaining. Best move has "
<< largest_n << " visits, second best -- " << second_largest_n
LOGFILE << std::fixed << remaining_playouts
<< " playouts remaining. Best move has " << largest_n
<< " visits, second best -- " << second_largest_n
<< ". Difference is " << (largest_n - second_largest_n)
<< ", so stopping the search after "
<< stats.batches_since_movestart << " batches.";
Expand Down

0 comments on commit 402da93

Please sign in to comment.