Skip to content

Commit

Permalink
Merge pull request LeelaChessZero#2 from Ergodice/r50-adjust
Browse files Browse the repository at this point in the history
R50 adjust
  • Loading branch information
Ergodice authored and Craftyawesome committed Jul 18, 2023
1 parent 7de757e commit 20b5e15
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 53 deletions.
11 changes: 6 additions & 5 deletions src/chess/position.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ Position::Position(const ChessBoard& board, int rule50_ply, int game_ply)
them_board_.Mirror();
}

uint64_t Position::Hash() const {
return us_board_.Hash();
}
uint64_t Position::Hash() const { return us_board_.Hash(); }

std::string Position::DebugString() const { return us_board_.DebugString(); }

Expand Down Expand Up @@ -150,14 +148,17 @@ bool PositionHistory::DidRepeatSinceLastZeroingMove() const {
return false;
}

uint64_t PositionHistory::HashLast(int positions) const {
uint64_t PositionHistory::HashLast(int positions, int r50_ply) const {
uint64_t hash = positions;
for (auto iter = positions_.rbegin(), end = positions_.rend(); iter != end;
++iter) {
if (!positions--) break;
hash = HashCat(hash, iter->Hash());
}
return HashCat(hash, Last().GetRule50Ply());
if (r50_ply < 0) {
r50_ply = Last().GetRule50Ply();
}
return HashCat(hash, r50_ply);
}

std::string GetFen(const Position& pos) {
Expand Down
2 changes: 1 addition & 1 deletion src/chess/position.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class PositionHistory {
bool IsBlackToMove() const { return Last().IsBlackToMove(); }

// Builds a hash from last X positions.
uint64_t HashLast(int positions) const;
uint64_t HashLast(int positions, int r50_ply = -1) const;

// Checks for any repetitions since the last time 50 move rule was reset.
bool DidRepeatSinceLastZeroingMove() const;
Expand Down
7 changes: 4 additions & 3 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,6 @@ void LowNode::FinalizeScoreUpdate(float v, float d, float m, float vs,
m_ += multivisit * (m - m_) / (n_ + multivisit);
vs_ += multivisit * (vs - vs_) / (n_ + multivisit);


assert(WLDMInvariantsHold());

// Increment N.
Expand All @@ -386,7 +385,8 @@ void LowNode::AdjustForTerminal(float v, float d, float m, float vs,
assert(WLDMInvariantsHold());
}

void Node::FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit) {
void Node::FinalizeScoreUpdate(float v, float d, float m, float vs,
uint32_t multivisit) {
// Recompute Q.
wl_ += multivisit * (v - wl_) / (n_ + multivisit);
d_ += multivisit * (d - d_) / (n_ + multivisit);
Expand All @@ -402,7 +402,8 @@ void Node::FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t mul
n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel);
}

void Node::AdjustForTerminal(float v, float d, float m, float vs, uint32_t multivisit) {
void Node::AdjustForTerminal(float v, float d, float m, float vs,
uint32_t multivisit) {
assert(static_cast<uint32_t>(multivisit) <= n_);

// Recompute Q.
Expand Down
24 changes: 16 additions & 8 deletions src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,6 @@ class Node {
float GetM() const { return m_; }
float GetVS() const { return vs_; }


// Returns whether the node is known to be draw/lose/win.
bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; }
bool IsTbTerminal() const { return terminal_type_ == Terminal::Tablebase; }
Expand Down Expand Up @@ -320,9 +319,11 @@ class Node {
// * Q (weighted average of all V in a subtree)
// * N (+=multivisit)
// * N-in-flight (-=multivisit)
void FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit);
void FinalizeScoreUpdate(float v, float d, float m, float vs,
uint32_t multivisit);
// Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
void AdjustForTerminal(float v, float d, float m, float vs, uint32_t multivisit);
void AdjustForTerminal(float v, float d, float m, float vs,
uint32_t multivisit);
// When search decides to treat one visit as several (in case of collisions
// or visiting terminal nodes several times), it amplifies the visit by
// incrementing n_in_flight.
Expand Down Expand Up @@ -453,6 +454,7 @@ class LowNode {
// For non-TT nodes.
LowNode(const LowNode& p)
: wl_(p.wl_),
v_(p.wl_),
d_(p.d_),
hash_(p.hash_),
m_(p.m_),
Expand Down Expand Up @@ -492,6 +494,7 @@ class LowNode {
eval->num_edges * sizeof(Edge));

wl_ = eval->q;
v_ = eval->q;
d_ = eval->d;
m_ = eval->m;
vs_ = wl_ * wl_;
Expand All @@ -513,11 +516,11 @@ class LowNode {
// Returns node eval, i.e. average subtree V for non-terminal node and -1/0/1
// for terminal nodes.
float GetWL() const { return wl_; }
float GetV() const { return v_; }
float GetD() const { return d_; }
float GetM() const { return m_; }
float GetVS() const { return vs_; }


// Returns whether the node is known to be draw/loss/win.
bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; }
Bounds GetBounds() const { return {lower_bound_, upper_bound_}; }
Expand All @@ -542,9 +545,11 @@ class LowNode {
// * Q (weighted average of all V in a subtree)
// * N (+=multivisit)
// * N-in-flight (-=multivisit)
void FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit);
void FinalizeScoreUpdate(float v, float d, float m, float vs,
uint32_t multivisit);
// Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
void AdjustForTerminal(float v, float d, float m, float vs, uint32_t multivisit);
void AdjustForTerminal(float v, float d, float m, float vs,
uint32_t multivisit);

// Deletes all children.
void ReleaseChildren(GCQueue* gc_queue);
Expand Down Expand Up @@ -618,6 +623,8 @@ class LowNode {
// 4 byte fields.
// Estimated remaining plies.
float m_ = 0.0f;
// original eval
float v_ = 0.0f;
// How many completed visits this node had.
uint32_t n_ = 0;

Expand Down Expand Up @@ -975,8 +982,9 @@ class NodeTree {
size_t AllocatedNodeCount() const { return tt_.size() + non_tt_.size(); };

// Get position hash used for TT nodes and NN cache.
uint64_t GetHistoryHash(const PositionHistory& history) const {
return history.HashLast(hash_history_length_);
uint64_t GetHistoryHash(const PositionHistory& history,
int r50_ply = -1) const {
return history.HashLast(hash_history_length_, r50_ply);
}

private:
Expand Down
107 changes: 77 additions & 30 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,6 @@ float Search::GetDrawScore(bool is_odd_depth) const {
}

namespace {


inline float ComputeStdev(const SearchParams& params, float q, uint32_t n,
float vs) {
Expand Down Expand Up @@ -449,11 +448,9 @@ inline float ComputeStdevFactor(const SearchParams& params, float q, uint32_t n,
}

inline float ComputeStdevFactor(const SearchParams& params, Node* node) {
return ComputeStdevFactor(params, node->GetWL(), node->GetN(),
node->GetVS());
return ComputeStdevFactor(params, node->GetWL(), node->GetN(), node->GetVS());
}


inline float GetFpu(const SearchParams& params, Node* node, bool is_root_node,
float draw_score) {
const auto value = params.GetFpuValue(is_root_node);
Expand All @@ -480,12 +477,13 @@ inline float ComputeCpuct(const SearchParams& params, uint32_t N,
return init + (k ? k * FastLog((N + base) / base) : 0.0f);
}

inline float ComputeCpuct(const SearchParams& params, uint32_t n, float q, float vs, bool is_root_node) {
inline float ComputeCpuct(const SearchParams& params, uint32_t n, float q,
float vs, bool is_root_node) {
const float init = params.GetCpuct(is_root_node);
const float k = params.GetCpuctFactor(is_root_node);
const float base = params.GetCpuctBase(is_root_node);
const float stdev_factor = ComputeStdevFactor(params, q, n, vs);
return stdev_factor * (init + (k ? k * FastLog((n + base) / base) : 0.0f)) ;
return stdev_factor * (init + (k ? k * FastLog((n + base) / base) : 0.0f));
}

} // namespace
Expand All @@ -496,8 +494,8 @@ std::vector<std::string> Search::GetVerboseStats(Node* node) const {
const bool is_black_to_move = (played_history_.IsBlackToMove() == is_root);
const float draw_score = GetDrawScore(is_odd_depth);
const float fpu = GetFpu(params_, node, is_root, draw_score);
const float cpuct =
ComputeCpuct(params_, node->GetTotalVisits(), node->GetWL(), node->GetVS(), is_root);
const float cpuct = ComputeCpuct(params_, node->GetTotalVisits(),
node->GetWL(), node->GetVS(), is_root);
const float U_coeff =
cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
std::vector<EdgeAndNode> edges;
Expand Down Expand Up @@ -529,14 +527,13 @@ std::vector<std::string> Search::GetVerboseStats(Node* node) const {
print(oss, "(WL: ", sign * n->GetWL(), ") ", 8, 5);
print(oss, "(D: ", n->GetD(), ") ", 5, 3);
print(oss, "(M: ", n->GetM(), ") ", 4, 1);
print(oss, "(STD: ", ComputeStdev(params_, n->GetWL(), n->GetN(), n->GetVS()),
") ",
6, 5);
print(oss,
"(STD: ", ComputeStdev(params_, n->GetWL(), n->GetN(), n->GetVS()),
") ", 6, 5);
print(oss, "(STDF: ",
ComputeStdevFactor(params_, n->GetWL(), n->GetN(), n->GetVS()),
") ", 6, 5);
print(oss, "(VS: ", n->GetVS(),
") ", 6, 5);
print(oss, "(VS: ", n->GetVS(), ") ", 6, 5);
} else {
*oss << "(WL: -.-----) (D: -.---) (M: -.-) ";
}
Expand Down Expand Up @@ -568,13 +565,10 @@ std::vector<std::string> Search::GetVerboseStats(Node* node) const {
up = -up;
std::swap(lo, up);
}
*oss << (lo == up
? "(T) "
: lo == GameResult::DRAW && up == GameResult::WHITE_WON
? "(W) "
: lo == GameResult::BLACK_WON && up == GameResult::DRAW
? "(L) "
: "");
*oss << (lo == up ? "(T) "
: lo == GameResult::DRAW && up == GameResult::WHITE_WON ? "(W) "
: lo == GameResult::BLACK_WON && up == GameResult::DRAW ? "(L) "
: "");
}
};

Expand Down Expand Up @@ -1733,7 +1727,8 @@ void SearchWorker::PickNodesToExtendTask(
}

const float cpuct =
ComputeCpuct(params_, node->GetTotalVisits(), node->GetWL(), node->GetVS(), is_root_node);
ComputeCpuct(params_, node->GetTotalVisits(), node->GetWL(),
node->GetVS(), is_root_node);
const float puct_mult =
cpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u));
int cache_filled_idx = -1;
Expand Down Expand Up @@ -2020,6 +2015,45 @@ void SearchWorker::ExtendNode(NodeToProcess& picked_node) {
picked_node.tt_low_node = tt_low_node;
picked_node.is_tt_hit = true;
} else {
// Make list of 100 LowNode references, one for each ply left

int my_ply = picked_node.GetRule50Ply();
float early_q = 99.0f, late_q = 0.0f;
int early_visits, late_visits;

float err_total = 0;
float weight_total = 0;

LowNode* r50_low_nodes[100];
for (int r50_ply = 0; r50_ply < 100; r50_ply++) {
uint64_t hash = search_->dag_->GetHistoryHash(history, r50_ply);
auto r50_low_node = search_->dag_->TTFind(hash);
r50_low_nodes[r50_ply] = r50_low_node;
if (r50_low_node != nullptr) {
int n = r50_low_node->GetN();
float q = r50_low_node->GetWL();
float v = r50_low_node->GetV();
// katago applies an exponent ~.3 < 1 so that the error calculation is
// not dominated by a few nodes
err_total += n * (q - v);
weight_total += n;

// it is not possible for your ply to equal mine
if (r50_ply < my_ply) {
if (n > early_visits) {
early_q = q;
}
} else {
if (n > late_visits) {
late_q = q;
}
}
}
}
picked_node.comrade_error = err_total / (weight_total + 0.001f);
picked_node.early_q = early_q;
picked_node.late_q = late_q;

picked_node.lock = NNCacheLock(search_->cache_, picked_node.hash);
picked_node.is_cache_hit = picked_node.lock;
}
Expand Down Expand Up @@ -2089,6 +2123,19 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process,
nn_eval->d = d;
}
}
// after wdl scaling we adjust based on r50
// 0.5 constant is chosen arbitrarily
// TODO: some stuff to consider with differing signs
float q_adjusted = nn_eval->q + node_to_process->comrade_error * 0.5;
float early_q = node_to_process->early_q,
late_q = node_to_process->late_q;
if (abs(q_adjusted) < abs(late_q)) {
q_adjusted = late_q;
}
if (abs(q_adjusted) > abs(early_q)) {
q_adjusted = early_q;
}
nn_eval->q = q_adjusted;
node_to_process->tt_low_node->SetNNEval(nn_eval);
}
}
Expand Down Expand Up @@ -2128,8 +2175,8 @@ void SearchWorker::DoBackupUpdate() {

bool SearchWorker::MaybeAdjustForTerminalOrTransposition(
Node* n, const LowNode* nl, float& v, float& d, float& m, float& vs,
uint32_t& n_to_fix, float& v_delta, float& d_delta, float& m_delta, float& vs_delta,
bool& update_parent_bounds) const {
uint32_t& n_to_fix, float& v_delta, float& d_delta, float& m_delta,
float& vs_delta, bool& update_parent_bounds) const {
if (n->IsTerminal()) {
v = n->GetWL();
d = n->GetD();
Expand Down Expand Up @@ -2207,7 +2254,6 @@ void SearchWorker::DoBackupUpdateSingleNode(
float d_delta = 0.0f;
float m_delta = 0.0f;
float vs_delta = 0.0f;


// Update the low node at the start of the backup path first, but only visit
// it the first time that backup sees it.
Expand All @@ -2223,9 +2269,9 @@ void SearchWorker::DoBackupUpdateSingleNode(
v = 0.0f;
d = 1.0f;
m = 1;
} else if (!MaybeAdjustForTerminalOrTransposition(n, nl, v, d, m, vs, n_to_fix,
v_delta, d_delta, m_delta, vs_delta,
update_parent_bounds)) {
} else if (!MaybeAdjustForTerminalOrTransposition(
n, nl, v, d, m, vs, n_to_fix, v_delta, d_delta, m_delta,
vs_delta, update_parent_bounds)) {
// If there is nothing better, use original NN values adjusted for node.
v = -nl->GetWL();
d = nl->GetD();
Expand Down Expand Up @@ -2278,9 +2324,10 @@ void SearchWorker::DoBackupUpdateSingleNode(

bool old_update_parent_bounds = update_parent_bounds;
// Try setting parent bounds except the root or those already terminal.
update_parent_bounds =
update_parent_bounds && p != search_->root_node_ && !pl->IsTerminal() &&
MaybeSetBounds(p, m, &n_to_fix, &v_delta, &d_delta, &m_delta, &vs_delta);
update_parent_bounds = update_parent_bounds && p != search_->root_node_ &&
!pl->IsTerminal() &&
MaybeSetBounds(p, m, &n_to_fix, &v_delta, &d_delta,
&m_delta, &vs_delta);

// Q will be flipped for opponent.
v = -v;
Expand Down
Loading

0 comments on commit 20b5e15

Please sign in to comment.