From 20b5e1552aa40850ecc5198a3088ffc1c2421a11 Mon Sep 17 00:00:00 2001 From: Daniel Monroe <39802758+Ergodice@users.noreply.github.com> Date: Fri, 7 Jul 2023 15:29:14 -0400 Subject: [PATCH] Merge pull request #2 from Ergodice/r50-adjust R50 adjust --- src/chess/position.cc | 11 +++-- src/chess/position.h | 2 +- src/mcts/node.cc | 7 +-- src/mcts/node.h | 24 ++++++---- src/mcts/search.cc | 107 ++++++++++++++++++++++++++++++------------ src/mcts/search.h | 20 +++++--- 6 files changed, 118 insertions(+), 53 deletions(-) diff --git a/src/chess/position.cc b/src/chess/position.cc index 9efbde0f6a..76a041877e 100644 --- a/src/chess/position.cc +++ b/src/chess/position.cc @@ -77,9 +77,7 @@ Position::Position(const ChessBoard& board, int rule50_ply, int game_ply) them_board_.Mirror(); } -uint64_t Position::Hash() const { - return us_board_.Hash(); -} +uint64_t Position::Hash() const { return us_board_.Hash(); } std::string Position::DebugString() const { return us_board_.DebugString(); } @@ -150,14 +148,17 @@ bool PositionHistory::DidRepeatSinceLastZeroingMove() const { return false; } -uint64_t PositionHistory::HashLast(int positions) const { +uint64_t PositionHistory::HashLast(int positions, int r50_ply) const { uint64_t hash = positions; for (auto iter = positions_.rbegin(), end = positions_.rend(); iter != end; ++iter) { if (!positions--) break; hash = HashCat(hash, iter->Hash()); } - return HashCat(hash, Last().GetRule50Ply()); + if (r50_ply < 0) { + r50_ply = Last().GetRule50Ply(); + } + return HashCat(hash, r50_ply); } std::string GetFen(const Position& pos) { diff --git a/src/chess/position.h b/src/chess/position.h index 879e08fb2c..6481b8554d 100644 --- a/src/chess/position.h +++ b/src/chess/position.h @@ -154,7 +154,7 @@ class PositionHistory { bool IsBlackToMove() const { return Last().IsBlackToMove(); } // Builds a hash from last X positions. - uint64_t HashLast(int positions) const; + uint64_t HashLast(int positions, int r50_ply = -1) const; // Checks for any repetitions since the last time 50 move rule was reset. bool DidRepeatSinceLastZeroingMove() const; diff --git a/src/mcts/node.cc b/src/mcts/node.cc index 5ae2bfaed8..1352d26690 100644 --- a/src/mcts/node.cc +++ b/src/mcts/node.cc @@ -366,7 +366,6 @@ void LowNode::FinalizeScoreUpdate(float v, float d, float m, float vs, m_ += multivisit * (m - m_) / (n_ + multivisit); vs_ += multivisit * (vs - vs_) / (n_ + multivisit); - assert(WLDMInvariantsHold()); // Increment N. @@ -386,7 +385,8 @@ void LowNode::AdjustForTerminal(float v, float d, float m, float vs, assert(WLDMInvariantsHold()); } -void Node::FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit) { +void Node::FinalizeScoreUpdate(float v, float d, float m, float vs, + uint32_t multivisit) { // Recompute Q. wl_ += multivisit * (v - wl_) / (n_ + multivisit); d_ += multivisit * (d - d_) / (n_ + multivisit); @@ -402,7 +402,8 @@ void Node::FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t mul n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel); } -void Node::AdjustForTerminal(float v, float d, float m, float vs, uint32_t multivisit) { +void Node::AdjustForTerminal(float v, float d, float m, float vs, + uint32_t multivisit) { assert(static_cast(multivisit) <= n_); // Recompute Q. diff --git a/src/mcts/node.h b/src/mcts/node.h index 4770d21efb..4d550eb8fa 100644 --- a/src/mcts/node.h +++ b/src/mcts/node.h @@ -292,7 +292,6 @@ class Node { float GetM() const { return m_; } float GetVS() const { return vs_; } - // Returns whether the node is known to be draw/lose/win. bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; } bool IsTbTerminal() const { return terminal_type_ == Terminal::Tablebase; } @@ -320,9 +319,11 @@ class Node { // * Q (weighted average of all V in a subtree) // * N (+=multivisit) // * N-in-flight (-=multivisit) - void FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit); + void FinalizeScoreUpdate(float v, float d, float m, float vs, + uint32_t multivisit); // Like FinalizeScoreUpdate, but it updates n existing visits by delta amount. - void AdjustForTerminal(float v, float d, float m, float vs, uint32_t multivisit); + void AdjustForTerminal(float v, float d, float m, float vs, + uint32_t multivisit); // When search decides to treat one visit as several (in case of collisions // or visiting terminal nodes several times), it amplifies the visit by // incrementing n_in_flight. @@ -453,6 +454,7 @@ class LowNode { // For non-TT nodes. LowNode(const LowNode& p) : wl_(p.wl_), + v_(p.wl_), d_(p.d_), hash_(p.hash_), m_(p.m_), @@ -492,6 +494,7 @@ class LowNode { eval->num_edges * sizeof(Edge)); wl_ = eval->q; + v_ = eval->q; d_ = eval->d; m_ = eval->m; vs_ = wl_ * wl_; @@ -513,11 +516,11 @@ class LowNode { // Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1 // for terminal nodes. float GetWL() const { return wl_; } + float GetV() const { return v_; } float GetD() const { return d_; } float GetM() const { return m_; } float GetVS() const { return vs_; } - // Returns whether the node is known to be draw/loss/win. bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; } Bounds GetBounds() const { return {lower_bound_, upper_bound_}; } @@ -542,9 +545,11 @@ class LowNode { // * Q (weighted average of all V in a subtree) // * N (+=multivisit) // * N-in-flight (-=multivisit) - void FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit); + void FinalizeScoreUpdate(float v, float d, float m, float vs, + uint32_t multivisit); // Like FinalizeScoreUpdate, but it updates n existing visits by delta amount. - void AdjustForTerminal(float v, float d, float m, float vs, uint32_t multivisit); + void AdjustForTerminal(float v, float d, float m, float vs, + uint32_t multivisit); // Deletes all children. void ReleaseChildren(GCQueue* gc_queue); @@ -618,6 +623,8 @@ class LowNode { // 4 byte fields. // Estimated remaining plies. float m_ = 0.0f; + // original eval + float v_ = 0.0f; // How many completed visits this node had. uint32_t n_ = 0; @@ -975,8 +982,9 @@ class NodeTree { size_t AllocatedNodeCount() const { return tt_.size() + non_tt_.size(); }; // Get position hash used for TT nodes and NN cache. - uint64_t GetHistoryHash(const PositionHistory& history) const { - return history.HashLast(hash_history_length_); + uint64_t GetHistoryHash(const PositionHistory& history, + int r50_ply = -1) const { + return history.HashLast(hash_history_length_, r50_ply); } private: diff --git a/src/mcts/search.cc b/src/mcts/search.cc index b4455acd4c..6e5896c361 100644 --- a/src/mcts/search.cc +++ b/src/mcts/search.cc @@ -412,7 +412,6 @@ float Search::GetDrawScore(bool is_odd_depth) const { } namespace { - inline float ComputeStdev(const SearchParams& params, float q, uint32_t n, float vs) { @@ -449,11 +448,9 @@ inline float ComputeStdevFactor(const SearchParams& params, float q, uint32_t n, } inline float ComputeStdevFactor(const SearchParams& params, Node* node) { - return ComputeStdevFactor(params, node->GetWL(), node->GetN(), - node->GetVS()); + return ComputeStdevFactor(params, node->GetWL(), node->GetN(), node->GetVS()); } - inline float GetFpu(const SearchParams& params, Node* node, bool is_root_node, float draw_score) { const auto value = params.GetFpuValue(is_root_node); @@ -480,12 +477,13 @@ inline float ComputeCpuct(const SearchParams& params, uint32_t N, return init + (k ? k * FastLog((N + base) / base) : 0.0f); } -inline float ComputeCpuct(const SearchParams& params, uint32_t n, float q, float vs, bool is_root_node) { +inline float ComputeCpuct(const SearchParams& params, uint32_t n, float q, + float vs, bool is_root_node) { const float init = params.GetCpuct(is_root_node); const float k = params.GetCpuctFactor(is_root_node); const float base = params.GetCpuctBase(is_root_node); const float stdev_factor = ComputeStdevFactor(params, q, n, vs); - return stdev_factor * (init + (k ? k * FastLog((n + base) / base) : 0.0f)) ; + return stdev_factor * (init + (k ? k * FastLog((n + base) / base) : 0.0f)); } } // namespace @@ -496,8 +494,8 @@ std::vector Search::GetVerboseStats(Node* node) const { const bool is_black_to_move = (played_history_.IsBlackToMove() == is_root); const float draw_score = GetDrawScore(is_odd_depth); const float fpu = GetFpu(params_, node, is_root, draw_score); - const float cpuct = - ComputeCpuct(params_, node->GetTotalVisits(), node->GetWL(), node->GetVS(), is_root); + const float cpuct = ComputeCpuct(params_, node->GetTotalVisits(), + node->GetWL(), node->GetVS(), is_root); const float U_coeff = cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u)); std::vector edges; @@ -529,14 +527,13 @@ std::vector Search::GetVerboseStats(Node* node) const { print(oss, "(WL: ", sign * n->GetWL(), ") ", 8, 5); print(oss, "(D: ", n->GetD(), ") ", 5, 3); print(oss, "(M: ", n->GetM(), ") ", 4, 1); - print(oss, "(STD: ", ComputeStdev(params_, n->GetWL(), n->GetN(), n->GetVS()), - ") ", - 6, 5); + print(oss, + "(STD: ", ComputeStdev(params_, n->GetWL(), n->GetN(), n->GetVS()), + ") ", 6, 5); print(oss, "(STDF: ", ComputeStdevFactor(params_, n->GetWL(), n->GetN(), n->GetVS()), ") ", 6, 5); - print(oss, "(VS: ", n->GetVS(), - ") ", 6, 5); + print(oss, "(VS: ", n->GetVS(), ") ", 6, 5); } else { *oss << "(WL: -.-----) (D: -.---) (M: -.-) "; } @@ -568,13 +565,10 @@ std::vector Search::GetVerboseStats(Node* node) const { up = -up; std::swap(lo, up); } - *oss << (lo == up - ? "(T) " - : lo == GameResult::DRAW && up == GameResult::WHITE_WON - ? "(W) " - : lo == GameResult::BLACK_WON && up == GameResult::DRAW - ? "(L) " - : ""); + *oss << (lo == up ? "(T) " + : lo == GameResult::DRAW && up == GameResult::WHITE_WON ? "(W) " + : lo == GameResult::BLACK_WON && up == GameResult::DRAW ? "(L) " + : ""); } }; @@ -1733,7 +1727,8 @@ void SearchWorker::PickNodesToExtendTask( } const float cpuct = - ComputeCpuct(params_, node->GetTotalVisits(), node->GetWL(), node->GetVS(), is_root_node); + ComputeCpuct(params_, node->GetTotalVisits(), node->GetWL(), + node->GetVS(), is_root_node); const float puct_mult = cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u)); int cache_filled_idx = -1; @@ -2020,6 +2015,45 @@ void SearchWorker::ExtendNode(NodeToProcess& picked_node) { picked_node.tt_low_node = tt_low_node; picked_node.is_tt_hit = true; } else { + // Make list of 100 LowNode references, one for each ply left + + int my_ply = picked_node.GetRule50Ply(); + float early_q = 99.0f, late_q = 0.0f; + int early_visits, late_visits; + + float err_total = 0; + float weight_total = 0; + + LowNode* r50_low_nodes[100]; + for (int r50_ply = 0; r50_ply < 100; r50_ply++) { + uint64_t hash = search_->dag_->GetHistoryHash(history, r50_ply); + auto r50_low_node = search_->dag_->TTFind(hash); + r50_low_nodes[r50_ply] = r50_low_node; + if (r50_low_node != nullptr) { + int n = r50_low_node->GetN(); + float q = r50_low_node->GetWL(); + float v = r50_low_node->GetV(); + // katago applies an exponent ~.3 < 1 so that the error calculation is + // not dominated by a few nodes + err_total += n * (q - v); + weight_total += n; + + // it is not possible for your ply to equal mine + if (r50_ply < my_ply) { + if (n > early_visits) { + early_q = q; + } + } else { + if (n > late_visits) { + late_q = q; + } + } + } + } + picked_node.comrade_error = err_total / (weight_total + 0.001f); + picked_node.early_q = early_q; + picked_node.late_q = late_q; + picked_node.lock = NNCacheLock(search_->cache_, picked_node.hash); picked_node.is_cache_hit = picked_node.lock; } @@ -2089,6 +2123,19 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process, nn_eval->d = d; } } + // after wdl scaling we adjust based on r50 + // 0.5 constant is chosen arbitrarily + // TODO: some stuff to consider with differing signs + float q_adjusted = nn_eval->q + node_to_process->comrade_error * 0.5; + float early_q = node_to_process->early_q, + late_q = node_to_process->late_q; + if (abs(q_adjusted) < abs(late_q)) { + q_adjusted = late_q; + } + if (abs(q_adjusted) > abs(early_q)) { + q_adjusted = early_q; + } + nn_eval->q = q_adjusted; node_to_process->tt_low_node->SetNNEval(nn_eval); } } @@ -2128,8 +2175,8 @@ void SearchWorker::DoBackupUpdate() { bool SearchWorker::MaybeAdjustForTerminalOrTransposition( Node* n, const LowNode* nl, float& v, float& d, float& m, float& vs, - uint32_t& n_to_fix, float& v_delta, float& d_delta, float& m_delta, float& vs_delta, - bool& update_parent_bounds) const { + uint32_t& n_to_fix, float& v_delta, float& d_delta, float& m_delta, + float& vs_delta, bool& update_parent_bounds) const { if (n->IsTerminal()) { v = n->GetWL(); d = n->GetD(); @@ -2207,7 +2254,6 @@ void SearchWorker::DoBackupUpdateSingleNode( float d_delta = 0.0f; float m_delta = 0.0f; float vs_delta = 0.0f; - // Update the low node at the start of the backup path first, but only visit // it the first time that backup sees it. @@ -2223,9 +2269,9 @@ void SearchWorker::DoBackupUpdateSingleNode( v = 0.0f; d = 1.0f; m = 1; - } else if (!MaybeAdjustForTerminalOrTransposition(n, nl, v, d, m, vs, n_to_fix, - v_delta, d_delta, m_delta, vs_delta, - update_parent_bounds)) { + } else if (!MaybeAdjustForTerminalOrTransposition( + n, nl, v, d, m, vs, n_to_fix, v_delta, d_delta, m_delta, + vs_delta, update_parent_bounds)) { // If there is nothing better, use original NN values adjusted for node. v = -nl->GetWL(); d = nl->GetD(); @@ -2278,9 +2324,10 @@ void SearchWorker::DoBackupUpdateSingleNode( bool old_update_parent_bounds = update_parent_bounds; // Try setting parent bounds except the root or those already terminal. - update_parent_bounds = - update_parent_bounds && p != search_->root_node_ && !pl->IsTerminal() && - MaybeSetBounds(p, m, &n_to_fix, &v_delta, &d_delta, &m_delta, &vs_delta); + update_parent_bounds = update_parent_bounds && p != search_->root_node_ && + !pl->IsTerminal() && + MaybeSetBounds(p, m, &n_to_fix, &v_delta, &d_delta, + &m_delta, &vs_delta); // Q will be flipped for opponent. v = -v; diff --git a/src/mcts/search.h b/src/mcts/search.h index 66259db420..0e0e875ecf 100644 --- a/src/mcts/search.h +++ b/src/mcts/search.h @@ -222,9 +222,7 @@ class SearchWorker { search_->network_->InitThread(id); for (int i = 0; i < params.GetTaskWorkersPerSearchWorker(); i++) { task_workspaces_.emplace_back(); - task_threads_.emplace_back([this, i]() { - this->RunTasks(i); - }); + task_threads_.emplace_back([this, i]() { this->RunTasks(i); }); } } @@ -300,6 +298,7 @@ class SearchWorker { node->GetLowNode(); } bool ShouldAddToInput() const { return nn_queried && !is_tt_hit; } + int GetRule50Ply() const { return history.Last().GetRule50Ply(); } // The path to the node to extend. BackupPath path; @@ -315,8 +314,14 @@ class SearchWorker { bool is_cache_hit = false; bool is_collision = false; + // values for improving r50 estimates, filled in as we go + float early_q; + float late_q; + float comrade_error; + // Details that are filled in as we go. uint64_t hash; + LowNode* tt_low_node; NNCacheLock lock; PositionHistory history; @@ -334,6 +339,8 @@ class SearchWorker { return NodeToProcess(path, history); } + void SetR50Bounds(NodeTree* dag) {} + // Method to allow NodeToProcess to conform as a 'Computation'. Only safe // to call if is_cache_hit is true in the multigather path. std::shared_ptr GetNNEval(int) const { return lock->eval; } @@ -432,9 +439,10 @@ class SearchWorker { // terminal status of node @n using information from its child low node. // Return true if adjustment happened. bool MaybeAdjustForTerminalOrTransposition(Node* n, const LowNode* nl, - float& v, float& d, float& m, float& vs, - uint32_t& n_to_fix, float& v_delta, - float& d_delta, float& m_delta, float& vs_delta, + float& v, float& d, float& m, + float& vs, uint32_t& n_to_fix, + float& v_delta, float& d_delta, + float& m_delta, float& vs_delta, bool& update_parent_bounds) const; void DoBackupUpdateSingleNode(const NodeToProcess& node_to_process); // Returns whether a node's bounds were set based on its children.