Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In smooth TM, use piggybank when it seems like the bestmove can be overtaken. #1762

Merged
merged 12 commits into from
Dec 4, 2022
155 changes: 140 additions & 15 deletions src/mcts/stoppers/smooth.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <functional>
#include <iomanip>
#include <optional>

#include "mcts/stoppers/legacy.h"
#include "mcts/stoppers/stoppers.h"
Expand Down Expand Up @@ -83,6 +84,19 @@ class Params {
// Max number of avg move times in piggybank.
float max_piggybank_moves() const { return max_piggybank_moves_; }

int64_t trend_nps_update_period_ms() const {
return trend_nps_update_period_ms_;
}

// Expected ration of the best move nps in future, to the current nps.
float bestmove_optimism() const { return bestmove_optimism_; }

// Expected ration of the non-best move nps in future, to the current nps.
float overtaker_optimism() const { return overtaker_optimism_; }

// Force a use of piggybank during the first few milliseconds of the move.
float force_piggybank_ms() const { return force_piggybank_ms_; }

// Move overhead.
int64_t move_overhead_ms() const { return move_overhead_ms_; }
// Returns a function function that estimates remaining moves.
Expand All @@ -104,6 +118,10 @@ class Params {
const float per_move_piggybank_fraction_;
const float max_piggybank_use_;
const float max_piggybank_moves_;
const float trend_nps_update_period_ms_;
const float bestmove_optimism_;
const float overtaker_optimism_;
const float force_piggybank_ms_;
const MovesLeftEstimator moves_left_estimator_;
};

Expand Down Expand Up @@ -144,6 +162,12 @@ Params::Params(const OptionsDict& params, int64_t move_overhead)
params.GetOrDefault<float>("max-piggybank-use", 0.94f)),
max_piggybank_moves_(
params.GetOrDefault<float>("max-piggybank-moves", 36.5f)),
trend_nps_update_period_ms_(
params.GetOrDefault<int>("trend-nps-update-period-ms", 3000)),
bestmove_optimism_(params.GetOrDefault<float>("bestmove-optimism", 0.2f)),
overtaker_optimism_(
params.GetOrDefault<float>("overtaker-optimism", 4.0f)),
force_piggybank_ms_(params.GetOrDefault<int>("force-piggybank-ms", 1000)),
moves_left_estimator_(CreateMovesLeftEstimator(params)) {}

// Returns the updated value of @from, towards @to by the number of halves
Expand Down Expand Up @@ -198,9 +222,36 @@ float LinearDecay(float cur_value, float target_value,

class SmoothTimeManager;

class VisitsTrendWatcher {
public:
VisitsTrendWatcher(float nps_update_period, float bestmove_optimism,
float overtaker_optimism)
: nps_update_period_(nps_update_period),
bestmove_optimism_(bestmove_optimism),
overtaker_optimism_(overtaker_optimism) {}

void Update(uint64_t timestamp, const std::vector<uint32_t>& visits);
bool IsBestmoveBeingOvertaken(uint64_t by_which_time) const;

private:
const float nps_update_period_;
const float bestmove_optimism_;
const float overtaker_optimism_;

mutable Mutex mutex_;
uint64_t prev_timestamp_ GUARDED_BY(mutex_) = 0;
std::vector<uint32_t> prev_visits_ GUARDED_BY(mutex_);
uint64_t cur_timestamp_ GUARDED_BY(mutex_) = 0;
std::vector<uint32_t> cur_visits_ GUARDED_BY(mutex_);
uint64_t last_timestamp_ GUARDED_BY(mutex_) = 0;
std::vector<uint32_t> last_visits_ GUARDED_BY(mutex_);
};

class SmoothStopper : public SearchStopper {
public:
SmoothStopper(int64_t deadline_ms, int64_t allowed_piggybank_use_ms,
float nps_update_period, float bestmove_optimism,
float overtaker_optimism, int64_t forces_piggybank_ms,
SmoothTimeManager* manager);

private:
Expand All @@ -209,7 +260,9 @@ class SmoothStopper : public SearchStopper {

const int64_t deadline_ms_;
const int64_t allowed_piggybank_use_ms_;
const int64_t forced_piggybank_use_ms_;

VisitsTrendWatcher visits_trend_watcher_;
SmoothTimeManager* const manager_;
std::atomic_flag used_piggybank_;
};
Expand Down Expand Up @@ -251,19 +304,19 @@ class SmoothTimeManager : public TimeManager {
: total_move_time / move_allocated_time_ms_;
// Recompute expected move time for logging.
const float expected_move_time = move_allocated_time_ms_ * timeuse_;
// If piggybank was used, cannot update timeuse_.

int64_t piggybank_time_used = 0;
if (used_piggybank) {
piggybank_time_used = std::max(int64_t(), total_move_time - time_budget);
piggybank_time_ -= piggybank_time_used;
} else {
timeuse_ =
ExponentialDecay(timeuse_, this_move_time_use,
params_.smartpruning_timeuse_halfupdate_moves(),
this_move_time_fraction);
if (timeuse_ < params_.min_smartpruning_timeuse()) {
timeuse_ = params_.min_smartpruning_timeuse();
}
}
// If piggybank was used, time use is 100%.
timeuse_ =
ExponentialDecay(timeuse_, used_piggybank ? 1.0f : this_move_time_use,
params_.smartpruning_timeuse_halfupdate_moves(),
this_move_time_fraction);
if (timeuse_ < params_.min_smartpruning_timeuse()) {
timeuse_ = params_.min_smartpruning_timeuse();
}
// Remember final number of nodes for tree reuse estimation.
last_move_final_nodes_ = total_nodes;
Expand Down Expand Up @@ -404,8 +457,10 @@ class SmoothTimeManager : public TimeManager {
<< ", moves=" << remaining_moves << ", time=" << total_remaining_ms
<< "ms, nps=" << nps_;

return std::make_unique<SmoothStopper>(move_allocated_time_ms_,
allowed_piggybank_time_ms, this);
return std::make_unique<SmoothStopper>(
move_allocated_time_ms_, allowed_piggybank_time_ms,
params_.trend_nps_update_period_ms(), params_.bestmove_optimism(),
params_.overtaker_optimism(), params_.force_piggybank_ms(), this);
}

void UpdateTreeReuseFactor(int64_t new_move_nodes) REQUIRES(mutex_) {
Expand Down Expand Up @@ -460,11 +515,65 @@ class SmoothTimeManager : public TimeManager {
bool is_first_move_ GUARDED_BY(mutex_) = true;
};

void VisitsTrendWatcher::Update(uint64_t timestamp,
const std::vector<uint32_t>& visits) {
Mutex::Lock lock(mutex_);
if (timestamp <= last_timestamp_) return;
if (prev_visits_.empty()) {
prev_visits_ = visits;
cur_visits_ = visits;
prev_timestamp_ = timestamp;
cur_timestamp_ = timestamp;
}
last_timestamp_ = timestamp;
last_visits_ = visits;
if (cur_timestamp_ + nps_update_period_ >= timestamp) {
prev_timestamp_ = cur_timestamp_;
prev_visits_ = std::move(cur_visits_);
cur_visits_ = last_visits_;
cur_timestamp_ = last_timestamp_;
}
}

bool VisitsTrendWatcher::IsBestmoveBeingOvertaken(
uint64_t by_which_time) const {
Mutex::Lock lock(mutex_);
// If we don't have any stats yet, we cannot stop the search.
if (prev_timestamp_ >= last_timestamp_) return false;
if (by_which_time <= last_timestamp_) return false;
std::vector<float> npms;
npms.reserve(last_visits_.size());
for (size_t i = 0; i < last_visits_.size(); ++i) {
npms.push_back(static_cast<float>(last_visits_[i] - prev_visits_[i]) /
(last_timestamp_ - prev_timestamp_));
}
const size_t bestmove_idx =
std::max_element(last_visits_.begin(), last_visits_.end()) -
last_visits_.begin();
const auto planned_bestmove_visits =
last_visits_[bestmove_idx] + bestmove_optimism_ * npms[bestmove_idx] *
(by_which_time - last_timestamp_);
for (size_t i = 0; i < last_visits_.size(); ++i) {
if (i == bestmove_idx) continue;
const auto planned_visits =
last_visits_[i] +
overtaker_optimism_ * npms[i] * (by_which_time - last_timestamp_);
if (planned_visits > planned_bestmove_visits) return true;
}
return false;
}

SmoothStopper::SmoothStopper(int64_t deadline_ms,
int64_t allowed_piggybank_use_ms,
float nps_update_period, float bestmove_optimism,
float overtaker_optimism,
int64_t forced_piggybank_use_ms,
SmoothTimeManager* manager)
: deadline_ms_(deadline_ms),
allowed_piggybank_use_ms_(allowed_piggybank_use_ms),
forced_piggybank_use_ms_(forced_piggybank_use_ms),
visits_trend_watcher_(nps_update_period, bestmove_optimism,
overtaker_optimism),
manager_(manager) {
used_piggybank_.clear();
}
Expand All @@ -478,21 +587,37 @@ bool SmoothStopper::ShouldStop(const IterationStats& stats,
return true;
}

visits_trend_watcher_.Update(stats.time_since_movestart, stats.edge_n);
const auto deadline_with_piggybank = deadline_ms_ + allowed_piggybank_use_ms_;
const bool force_use_piggybank =
stats.time_since_first_batch <= forced_piggybank_use_ms_;
const bool use_piggybank =
(stats.time_usage_hint_ == IterationStats::TimeUsageHint::kNeedMoreTime);
(stats.time_usage_hint_ == IterationStats::TimeUsageHint::kNeedMoreTime ||
force_use_piggybank ||
visits_trend_watcher_.IsBestmoveBeingOvertaken(deadline_with_piggybank));
const int64_t time_limit =
use_piggybank ? (deadline_ms_ + allowed_piggybank_use_ms_) : deadline_ms_;
use_piggybank ? deadline_with_piggybank : deadline_ms_;
hints->UpdateEstimatedNps(nps);
hints->UpdateEstimatedRemainingTimeMs(time_limit -
stats.time_since_movestart);
if (use_piggybank && stats.time_since_movestart >= deadline_ms_) {
// It's not entirely correct as due to extended remaining time smart pruning
// will trigger later and we spend more time than if use_piggyback was
// false, even before reaching the deadline.
used_piggybank_.test_and_set();
if (!used_piggybank_.test_and_set()) {
LOGFILE << "Entering piggybank, reason: "
<< (stats.time_usage_hint_ ==
IterationStats::TimeUsageHint::kNeedMoreTime
? "requested by search."
: force_use_piggybank
? "forced used in the beginning of the move."
: "bestmove can be overtaken.");
}
}
if (stats.time_since_movestart >= time_limit) {
LOGFILE << "Stopping search: Ran out of time.";
LOGFILE << "Stopping search: Ran out of time. elapsed=" << std::fixed
<< stats.time_since_movestart << " limit=" << time_limit
<< " piggy=" << use_piggybank;
return true;
}
return false;
Expand Down
5 changes: 3 additions & 2 deletions src/mcts/stoppers/stoppers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,9 @@ bool SmartPruningStopper::ShouldStop(const IterationStats& stats,
}

if (remaining_playouts < (largest_n - second_largest_n)) {
LOGFILE << remaining_playouts << " playouts remaining. Best move has "
<< largest_n << " visits, second best -- " << second_largest_n
LOGFILE << std::fixed << remaining_playouts
<< " playouts remaining. Best move has " << largest_n
<< " visits, second best -- " << second_largest_n
<< ". Difference is " << (largest_n - second_largest_n)
<< ", so stopping the search after "
<< stats.batches_since_movestart << " batches.";
Expand Down