diff --git a/src/mcts/stoppers/smooth.cc b/src/mcts/stoppers/smooth.cc index e0cc38e651..2a1a196247 100644 --- a/src/mcts/stoppers/smooth.cc +++ b/src/mcts/stoppers/smooth.cc @@ -29,6 +29,7 @@ #include #include +#include #include "mcts/stoppers/legacy.h" #include "mcts/stoppers/stoppers.h" @@ -83,6 +84,19 @@ class Params { // Max number of avg move times in piggybank. float max_piggybank_moves() const { return max_piggybank_moves_; } + int64_t trend_nps_update_period_ms() const { + return trend_nps_update_period_ms_; + } + + // Expected ration of the best move nps in future, to the current nps. + float bestmove_optimism() const { return bestmove_optimism_; } + + // Expected ration of the non-best move nps in future, to the current nps. + float overtaker_optimism() const { return overtaker_optimism_; } + + // Force a use of piggybank during the first few milliseconds of the move. + float force_piggybank_ms() const { return force_piggybank_ms_; } + // Move overhead. int64_t move_overhead_ms() const { return move_overhead_ms_; } // Returns a function function that estimates remaining moves. @@ -104,6 +118,10 @@ class Params { const float per_move_piggybank_fraction_; const float max_piggybank_use_; const float max_piggybank_moves_; + const float trend_nps_update_period_ms_; + const float bestmove_optimism_; + const float overtaker_optimism_; + const float force_piggybank_ms_; const MovesLeftEstimator moves_left_estimator_; }; @@ -144,6 +162,12 @@ Params::Params(const OptionsDict& params, int64_t move_overhead) params.GetOrDefault("max-piggybank-use", 0.94f)), max_piggybank_moves_( params.GetOrDefault("max-piggybank-moves", 36.5f)), + trend_nps_update_period_ms_( + params.GetOrDefault("trend-nps-update-period-ms", 3000)), + bestmove_optimism_(params.GetOrDefault("bestmove-optimism", 0.2f)), + overtaker_optimism_( + params.GetOrDefault("overtaker-optimism", 4.0f)), + force_piggybank_ms_(params.GetOrDefault("force-piggybank-ms", 1000)), moves_left_estimator_(CreateMovesLeftEstimator(params)) {} // Returns the updated value of @from, towards @to by the number of halves @@ -198,9 +222,36 @@ float LinearDecay(float cur_value, float target_value, class SmoothTimeManager; +class VisitsTrendWatcher { + public: + VisitsTrendWatcher(float nps_update_period, float bestmove_optimism, + float overtaker_optimism) + : nps_update_period_(nps_update_period), + bestmove_optimism_(bestmove_optimism), + overtaker_optimism_(overtaker_optimism) {} + + void Update(uint64_t timestamp, const std::vector& visits); + bool IsBestmoveBeingOvertaken(uint64_t by_which_time) const; + + private: + const float nps_update_period_; + const float bestmove_optimism_; + const float overtaker_optimism_; + + mutable Mutex mutex_; + uint64_t prev_timestamp_ GUARDED_BY(mutex_) = 0; + std::vector prev_visits_ GUARDED_BY(mutex_); + uint64_t cur_timestamp_ GUARDED_BY(mutex_) = 0; + std::vector cur_visits_ GUARDED_BY(mutex_); + uint64_t last_timestamp_ GUARDED_BY(mutex_) = 0; + std::vector last_visits_ GUARDED_BY(mutex_); +}; + class SmoothStopper : public SearchStopper { public: SmoothStopper(int64_t deadline_ms, int64_t allowed_piggybank_use_ms, + float nps_update_period, float bestmove_optimism, + float overtaker_optimism, int64_t forces_piggybank_ms, SmoothTimeManager* manager); private: @@ -209,7 +260,9 @@ class SmoothStopper : public SearchStopper { const int64_t deadline_ms_; const int64_t allowed_piggybank_use_ms_; + const int64_t forced_piggybank_use_ms_; + VisitsTrendWatcher visits_trend_watcher_; SmoothTimeManager* const manager_; std::atomic_flag used_piggybank_; }; @@ -251,19 +304,19 @@ class SmoothTimeManager : public TimeManager { : total_move_time / move_allocated_time_ms_; // Recompute expected move time for logging. const float expected_move_time = move_allocated_time_ms_ * timeuse_; - // If piggybank was used, cannot update timeuse_. + int64_t piggybank_time_used = 0; if (used_piggybank) { piggybank_time_used = std::max(int64_t(), total_move_time - time_budget); piggybank_time_ -= piggybank_time_used; - } else { - timeuse_ = - ExponentialDecay(timeuse_, this_move_time_use, - params_.smartpruning_timeuse_halfupdate_moves(), - this_move_time_fraction); - if (timeuse_ < params_.min_smartpruning_timeuse()) { - timeuse_ = params_.min_smartpruning_timeuse(); - } + } + // If piggybank was used, time use is 100%. + timeuse_ = + ExponentialDecay(timeuse_, used_piggybank ? 1.0f : this_move_time_use, + params_.smartpruning_timeuse_halfupdate_moves(), + this_move_time_fraction); + if (timeuse_ < params_.min_smartpruning_timeuse()) { + timeuse_ = params_.min_smartpruning_timeuse(); } // Remember final number of nodes for tree reuse estimation. last_move_final_nodes_ = total_nodes; @@ -404,8 +457,10 @@ class SmoothTimeManager : public TimeManager { << ", moves=" << remaining_moves << ", time=" << total_remaining_ms << "ms, nps=" << nps_; - return std::make_unique(move_allocated_time_ms_, - allowed_piggybank_time_ms, this); + return std::make_unique( + move_allocated_time_ms_, allowed_piggybank_time_ms, + params_.trend_nps_update_period_ms(), params_.bestmove_optimism(), + params_.overtaker_optimism(), params_.force_piggybank_ms(), this); } void UpdateTreeReuseFactor(int64_t new_move_nodes) REQUIRES(mutex_) { @@ -460,11 +515,65 @@ class SmoothTimeManager : public TimeManager { bool is_first_move_ GUARDED_BY(mutex_) = true; }; +void VisitsTrendWatcher::Update(uint64_t timestamp, + const std::vector& visits) { + Mutex::Lock lock(mutex_); + if (timestamp <= last_timestamp_) return; + if (prev_visits_.empty()) { + prev_visits_ = visits; + cur_visits_ = visits; + prev_timestamp_ = timestamp; + cur_timestamp_ = timestamp; + } + last_timestamp_ = timestamp; + last_visits_ = visits; + if (cur_timestamp_ + nps_update_period_ >= timestamp) { + prev_timestamp_ = cur_timestamp_; + prev_visits_ = std::move(cur_visits_); + cur_visits_ = last_visits_; + cur_timestamp_ = last_timestamp_; + } +} + +bool VisitsTrendWatcher::IsBestmoveBeingOvertaken( + uint64_t by_which_time) const { + Mutex::Lock lock(mutex_); + // If we don't have any stats yet, we cannot stop the search. + if (prev_timestamp_ >= last_timestamp_) return false; + if (by_which_time <= last_timestamp_) return false; + std::vector npms; + npms.reserve(last_visits_.size()); + for (size_t i = 0; i < last_visits_.size(); ++i) { + npms.push_back(static_cast(last_visits_[i] - prev_visits_[i]) / + (last_timestamp_ - prev_timestamp_)); + } + const size_t bestmove_idx = + std::max_element(last_visits_.begin(), last_visits_.end()) - + last_visits_.begin(); + const auto planned_bestmove_visits = + last_visits_[bestmove_idx] + bestmove_optimism_ * npms[bestmove_idx] * + (by_which_time - last_timestamp_); + for (size_t i = 0; i < last_visits_.size(); ++i) { + if (i == bestmove_idx) continue; + const auto planned_visits = + last_visits_[i] + + overtaker_optimism_ * npms[i] * (by_which_time - last_timestamp_); + if (planned_visits > planned_bestmove_visits) return true; + } + return false; +} + SmoothStopper::SmoothStopper(int64_t deadline_ms, int64_t allowed_piggybank_use_ms, + float nps_update_period, float bestmove_optimism, + float overtaker_optimism, + int64_t forced_piggybank_use_ms, SmoothTimeManager* manager) : deadline_ms_(deadline_ms), allowed_piggybank_use_ms_(allowed_piggybank_use_ms), + forced_piggybank_use_ms_(forced_piggybank_use_ms), + visits_trend_watcher_(nps_update_period, bestmove_optimism, + overtaker_optimism), manager_(manager) { used_piggybank_.clear(); } @@ -478,10 +587,16 @@ bool SmoothStopper::ShouldStop(const IterationStats& stats, return true; } + visits_trend_watcher_.Update(stats.time_since_movestart, stats.edge_n); + const auto deadline_with_piggybank = deadline_ms_ + allowed_piggybank_use_ms_; + const bool force_use_piggybank = + stats.time_since_first_batch <= forced_piggybank_use_ms_; const bool use_piggybank = - (stats.time_usage_hint_ == IterationStats::TimeUsageHint::kNeedMoreTime); + (stats.time_usage_hint_ == IterationStats::TimeUsageHint::kNeedMoreTime || + force_use_piggybank || + visits_trend_watcher_.IsBestmoveBeingOvertaken(deadline_with_piggybank)); const int64_t time_limit = - use_piggybank ? (deadline_ms_ + allowed_piggybank_use_ms_) : deadline_ms_; + use_piggybank ? deadline_with_piggybank : deadline_ms_; hints->UpdateEstimatedNps(nps); hints->UpdateEstimatedRemainingTimeMs(time_limit - stats.time_since_movestart); @@ -489,10 +604,20 @@ bool SmoothStopper::ShouldStop(const IterationStats& stats, // It's not entirely correct as due to extended remaining time smart pruning // will trigger later and we spend more time than if use_piggyback was // false, even before reaching the deadline. - used_piggybank_.test_and_set(); + if (!used_piggybank_.test_and_set()) { + LOGFILE << "Entering piggybank, reason: " + << (stats.time_usage_hint_ == + IterationStats::TimeUsageHint::kNeedMoreTime + ? "requested by search." + : force_use_piggybank + ? "forced used in the beginning of the move." + : "bestmove can be overtaken."); + } } if (stats.time_since_movestart >= time_limit) { - LOGFILE << "Stopping search: Ran out of time."; + LOGFILE << "Stopping search: Ran out of time. elapsed=" << std::fixed + << stats.time_since_movestart << " limit=" << time_limit + << " piggy=" << use_piggybank; return true; } return false; diff --git a/src/mcts/stoppers/stoppers.cc b/src/mcts/stoppers/stoppers.cc index 8d6f7ae426..a66e4f8063 100644 --- a/src/mcts/stoppers/stoppers.cc +++ b/src/mcts/stoppers/stoppers.cc @@ -242,8 +242,9 @@ bool SmartPruningStopper::ShouldStop(const IterationStats& stats, } if (remaining_playouts < (largest_n - second_largest_n)) { - LOGFILE << remaining_playouts << " playouts remaining. Best move has " - << largest_n << " visits, second best -- " << second_largest_n + LOGFILE << std::fixed << remaining_playouts + << " playouts remaining. Best move has " << largest_n + << " visits, second best -- " << second_largest_n << ". Difference is " << (largest_n - second_largest_n) << ", so stopping the search after " << stats.batches_since_movestart << " batches.";