From d5c9328c60e54dadcca5b4c7ef0ef7fbfe88a390 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Sat, 9 May 2020 19:33:04 +0200 Subject: [PATCH] Move move filter population to a constructor. (#1281) --- src/chess/bitboard.h | 2 +- src/mcts/search.cc | 94 ++++++++++++++++++++++---------------------- src/mcts/search.h | 12 ++---- 3 files changed, 51 insertions(+), 57 deletions(-) diff --git a/src/chess/bitboard.h b/src/chess/bitboard.h index 63e77559b1..d66ea49a89 100644 --- a/src/chess/bitboard.h +++ b/src/chess/bitboard.h @@ -259,7 +259,7 @@ class Move { uint16_t as_nn_index(int transform) const; explicit operator bool() const { return data_ != 0; } - bool operator==(const Move& other) { return data_ == other.data_; } + bool operator==(const Move& other) const { return data_ == other.data_; } void Mirror() { data_ ^= 0b111000111000; } diff --git a/src/mcts/search.cc b/src/mcts/search.cc index 08553dfbf7..1741c04569 100644 --- a/src/mcts/search.cc +++ b/src/mcts/search.cc @@ -47,6 +47,28 @@ namespace lczero { namespace { // Maximum delay between outputting "uci info" when nothing interesting happens. const int kUciInfoMinimumFrequencyMs = 5000; + +MoveList MakeRootMoveFilter(const MoveList& searchmoves, + SyzygyTablebase* syzygy_tb, + const PositionHistory& history, bool fast_play, + std::atomic* tb_hits) { + // Search moves overrides tablebase. + if (!searchmoves.empty()) return searchmoves; + const auto& board = history.Last().GetBoard(); + MoveList root_moves; + if (!syzygy_tb || !board.castlings().no_legal_castle() || + (board.ours() | board.theirs()).count() > syzygy_tb->max_cardinality()) { + return root_moves; + } + if (syzygy_tb->root_probe( + history.Last(), fast_play || history.DidRepeatSinceLastZeroingMove(), + &root_moves) || + syzygy_tb->root_probe_wdl(history.Last(), &root_moves)) { + tb_hits->fetch_add(1, std::memory_order_acq_rel); + } + return root_moves; +} + } // namespace Search::Search(const NodeTree& tree, Network* network, @@ -63,11 +85,14 @@ Search::Search(const NodeTree& tree, Network* network, syzygy_tb_(syzygy_tb), played_history_(tree.GetPositionHistory()), network_(network), + params_(options), searchmoves_(searchmoves), start_time_(start_time), initial_visits_(root_node_->GetN()), - uci_responder_(std::move(uci_responder)), - params_(options) { + root_move_filter_( + MakeRootMoveFilter(searchmoves_, syzygy_tb_, played_history_, + params_.GetSyzygyFastPlay(), &tb_hits_)), + uci_responder_(std::move(uci_responder)) { if (params_.GetMaxConcurrentSearchers() != 0) { pending_searchers_.store(params_.GetMaxConcurrentSearchers(), std::memory_order_release); @@ -476,25 +501,6 @@ std::int64_t Search::GetTotalPlayouts() const { return total_playouts_; } -bool Search::PopulateRootMoveLimit(MoveList* root_moves) const { - // Search moves overrides tablebase. - if (!searchmoves_.empty()) { - *root_moves = searchmoves_; - return false; - } - auto board = played_history_.Last().GetBoard(); - if (!syzygy_tb_ || !board.castlings().no_legal_castle() || - (board.ours() | board.theirs()).count() > syzygy_tb_->max_cardinality()) { - return false; - } - return syzygy_tb_->root_probe( - played_history_.Last(), - params_.GetSyzygyFastPlay() || - played_history_.DidRepeatSinceLastZeroingMove(), - root_moves) || - syzygy_tb_->root_probe_wdl(played_history_.Last(), root_moves); -} - void Search::ResetBestMove() { SharedMutex::Lock nodes_lock(nodes_mutex_); Mutex::Lock lock(counters_mutex_); @@ -522,8 +528,9 @@ void Search::EnsureBestMoveKnown() REQUIRES(nodes_mutex_) if (moves >= decay_delay_moves + decay_moves) { temperature = 0.0; } else if (moves >= decay_delay_moves) { - temperature *= static_cast - (decay_delay_moves + decay_moves - moves) / decay_moves; + temperature *= + static_cast(decay_delay_moves + decay_moves - moves) / + decay_moves; } // don't allow temperature to decay below endgame temperature if (temperature < params_.GetTemperatureEndgame()) { @@ -543,10 +550,6 @@ void Search::EnsureBestMoveKnown() REQUIRES(nodes_mutex_) std::vector Search::GetBestChildrenNoTemperature(Node* parent, int count, int depth) const { - MoveList root_limit; - if (parent == root_node_) { - PopulateRootMoveLimit(&root_limit); - } const bool is_odd_depth = (depth % 2) == 1; const float draw_score = GetDrawScore(is_odd_depth); // Best child is selected using the following criteria: @@ -557,9 +560,9 @@ std::vector Search::GetBestChildrenNoTemperature(Node* parent, // * If that number is larger than 0, the one with larger eval wins. std::vector edges; for (auto edge : parent->Edges()) { - if (parent == root_node_ && !root_limit.empty() && - std::find(root_limit.begin(), root_limit.end(), edge.GetMove()) == - root_limit.end()) { + if (parent == root_node_ && !root_move_filter_.empty() && + std::find(root_move_filter_.begin(), root_move_filter_.end(), + edge.GetMove()) == root_move_filter_.end()) { continue; } edges.push_back(edge); @@ -648,8 +651,6 @@ EdgeAndNode Search::GetBestChildNoTemperature(Node* parent, int depth) const { EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const { // Root is at even depth. const float draw_score = GetDrawScore(/* is_odd_depth= */ false); - MoveList root_limit; - PopulateRootMoveLimit(&root_limit); std::vector cumulative_sums; float sum = 0.0; @@ -660,8 +661,9 @@ EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const { GetFpu(params_, root_node_, /* is_root= */ true, draw_score); for (auto edge : root_node_->Edges()) { - if (!root_limit.empty() && std::find(root_limit.begin(), root_limit.end(), - edge.GetMove()) == root_limit.end()) { + if (!root_move_filter_.empty() && + std::find(root_move_filter_.begin(), root_move_filter_.end(), + edge.GetMove()) == root_move_filter_.end()) { continue; } if (edge.GetN() + offset > max_n) { @@ -677,8 +679,9 @@ EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const { const float min_eval = max_eval - params_.GetTemperatureWinpctCutoff() / 50.0f; for (auto edge : root_node_->Edges()) { - if (!root_limit.empty() && std::find(root_limit.begin(), root_limit.end(), - edge.GetMove()) == root_limit.end()) { + if (!root_move_filter_.empty() && + std::find(root_move_filter_.begin(), root_move_filter_.end(), + edge.GetMove()) == root_move_filter_.end()) { continue; } if (edge.GetQ(fpu, draw_score, /* logit_q= */ false) < min_eval) continue; @@ -695,8 +698,9 @@ EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const { cumulative_sums.begin(); for (auto edge : root_node_->Edges()) { - if (!root_limit.empty() && std::find(root_limit.begin(), root_limit.end(), - edge.GetMove()) == root_limit.end()) { + if (!root_move_filter_.empty() && + std::find(root_move_filter_.begin(), root_move_filter_.end(), + edge.GetMove()) == root_move_filter_.end()) { continue; } if (edge.GetQ(fpu, draw_score, /* logit_q= */ false) < min_eval) continue; @@ -902,13 +906,6 @@ void SearchWorker::InitializeIteration( computation_ = std::make_unique(std::move(computation), search_->cache_); minibatch_.clear(); - - if (!root_move_filter_populated_) { - root_move_filter_populated_ = true; - if (search_->PopulateRootMoveLimit(&root_move_filter_)) { - search_->tb_hits_.fetch_add(1, std::memory_order_acq_rel); - } - } } // 2. Gather minibatch. @@ -1021,6 +1018,7 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend( bool is_root_node = true; const float even_draw_score = search_->GetDrawScore(false); const float odd_draw_score = search_->GetDrawScore(true); + const auto& root_move_filter = search_->root_move_filter_; uint16_t depth = 0; bool node_already_updated = true; @@ -1096,9 +1094,9 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend( continue; } // If root move filter exists, make sure move is in the list. - if (!root_move_filter_.empty() && - std::find(root_move_filter_.begin(), root_move_filter_.end(), - child.GetMove()) == root_move_filter_.end()) { + if (!root_move_filter.empty() && + std::find(root_move_filter.begin(), root_move_filter.end(), + child.GetMove()) == root_move_filter.end()) { continue; } } diff --git a/src/mcts/search.h b/src/mcts/search.h index ff5b703c8e..ccb40282d0 100644 --- a/src/mcts/search.h +++ b/src/mcts/search.h @@ -121,10 +121,6 @@ class Search { // uci `stop` command; void WatchdogThread(); - // Populates the given list with allowed root moves. - // Returns true if the population came from tablebase. - bool PopulateRootMoveLimit(MoveList* root_moves) const; - // Fills IterationStats with global (rather than per-thread) portion of search // statistics. Currently all stats there (in IterationStats) are global // though. @@ -173,9 +169,13 @@ class Search { const PositionHistory& played_history_; Network* const network_; + const SearchParams params_; const MoveList searchmoves_; const std::chrono::steady_clock::time_point start_time_; const int64_t initial_visits_; + // tb_hits_ must be initialized before root_move_filter_. + std::atomic tb_hits_{0}; + const MoveList root_move_filter_; mutable SharedMutex nodes_mutex_; EdgeAndNode current_best_edge_ GUARDED_BY(nodes_mutex_); @@ -188,7 +188,6 @@ class Search { // Cumulative depth of all paths taken in PickNodetoExtend. uint64_t cum_depth_ GUARDED_BY(nodes_mutex_) = 0; std::optional nps_start_time_; - std::atomic tb_hits_{0}; std::atomic pending_searchers_{0}; @@ -196,7 +195,6 @@ class Search { GUARDED_BY(nodes_mutex_); std::unique_ptr uci_responder_; - const SearchParams params_; friend class SearchWorker; }; @@ -320,8 +318,6 @@ class SearchWorker { std::unique_ptr computation_; // History is reset and extended by PickNodeToExtend(). PositionHistory history_; - MoveList root_move_filter_; - bool root_move_filter_populated_ = false; int number_out_of_order_ = 0; const SearchParams& params_; std::unique_ptr precached_node_;