Skip to content

Commit

Permalink
Merge pull request LeelaChessZero#3 from Ergodice/r50-adjust
Browse files Browse the repository at this point in the history
R50 adjust
  • Loading branch information
Ergodice authored Jul 15, 2023
2 parents b819495 + edaad59 commit c545246
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 56 deletions.
6 changes: 6 additions & 0 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,12 @@ std::pair<LowNode*, bool> NodeTree::TTGetOrCreate(uint64_t hash) {
return {tt_iter->second.get(), is_tt_miss};
}

std::pair<LowNode*, bool> NodeTree::TTGetOrCreate(const LowNode& p, uint64_t hash) {
auto [tt_iter, is_tt_miss] =
tt_.insert({hash, std::make_unique<LowNode>(p, hash)});
return {tt_iter->second.get(), is_tt_miss};
}

void NodeTree::TTMaintenance() { TTGCSome(0); }

void NodeTree::TTClear() {
Expand Down
23 changes: 23 additions & 0 deletions src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,25 @@ class LowNode {
edges_ = std::make_unique<Edge[]>(num_edges_);
std::memcpy(edges_.get(), p.edges_.get(), num_edges_ * sizeof(Edge));
}

LowNode(const LowNode& p, const uint64_t hash)
: wl_(p.wl_),
v_(p.wl_),
hash_(hash),
d_(p.d_),
m_(p.m_),
vs_(p.vs_),
num_edges_(p.num_edges_),
terminal_type_(Terminal::NonTerminal),
lower_bound_(GameResult::BLACK_WON),
upper_bound_(GameResult::WHITE_WON),
is_transposition(false),
is_tt_(false) {
assert(p.edges_);
edges_ = std::make_unique<Edge[]>(num_edges_);
std::memcpy(edges_.get(), p.edges_.get(), num_edges_ * sizeof(Edge));
}

// Init @edges_ with moves from @moves and 0 policy.
// Also create the first child at @index.
// For non-TT nodes.
Expand Down Expand Up @@ -966,6 +985,10 @@ class NodeTree {
// new low node and insert it into the Transposition Table if it is not there
// already. Return the low node for the hash.
std::pair<LowNode*, bool> TTGetOrCreate(uint64_t hash);


std::pair<LowNode*, bool> TTGetOrCreate(const LowNode& p, uint64_t hash);

// Evict unused low nodes from the Transposition Table.
void TTMaintenance();
// Clear the Transposition Table.
Expand Down
129 changes: 76 additions & 53 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2015,20 +2015,52 @@ void SearchWorker::ExtendNode(NodeToProcess& picked_node) {
picked_node.tt_low_node = tt_low_node;
picked_node.is_tt_hit = true;
} else {
// Make list of 100 LowNode references, one for each ply left






int my_ply = picked_node.GetRule50Ply();
int ply_lo, ply_hi;

if (my_ply <= 64) {
ply_lo = (my_ply / 8) * 8;
ply_hi = ply_lo + 7;
} else {
ply_lo = my_ply;
ply_hi = my_ply;
}

int max_visits = 0;
LowNode* comrade_low_node = nullptr;
for (int ply = ply_lo; ply <= ply_hi; ply++) {
uint64_t hash = search_->dag_->GetHistoryHash(history, ply);
auto low_node = search_->dag_->TTFind(hash);
if (low_node != nullptr) {
int visits = low_node->GetN();
if (visits > max_visits) {
max_visits = visits;
comrade_low_node = low_node;
}
}
}
if (comrade_low_node != nullptr) {
picked_node.comrade_low_node = comrade_low_node;
picked_node.is_comrade_hit = true;
}

float early_q = 99.0f, late_q = 0.0f;
int early_visits, late_visits;

float err_total = 0;
float weight_total = 0;

LowNode* r50_low_nodes[100];

// calculate total error
for (int r50_ply = 0; r50_ply < 100; r50_ply++) {
uint64_t hash = search_->dag_->GetHistoryHash(history, r50_ply);
auto r50_low_node = search_->dag_->TTFind(hash);
r50_low_nodes[r50_ply] = r50_low_node;
if (r50_low_node != nullptr) {
int n = r50_low_node->GetN();
float q = r50_low_node->GetWL();
Expand All @@ -2038,21 +2070,9 @@ void SearchWorker::ExtendNode(NodeToProcess& picked_node) {
err_total += n * (q - v);
weight_total += n;

// it is not possible for your ply to equal mine
if (r50_ply < my_ply) {
if (n > early_visits) {
early_q = q;
}
} else {
if (n > late_visits) {
late_q = q;
}
}
}
}
picked_node.comrade_error = err_total / (weight_total + 0.001f);
picked_node.early_q = early_q;
picked_node.late_q = late_q;

picked_node.lock = NNCacheLock(search_->cache_, picked_node.hash);
picked_node.is_cache_hit = picked_node.lock;
Expand Down Expand Up @@ -2097,46 +2117,49 @@ void SearchWorker::FetchSingleNodeResult(NodeToProcess* node_to_process,
if (!node_to_process->nn_queried) return;

if (!node_to_process->is_tt_hit) {
auto [tt_low_node, is_tt_miss] =
search_->dag_->TTGetOrCreate(node_to_process->hash);

assert(tt_low_node != nullptr);
node_to_process->tt_low_node = tt_low_node;
if (is_tt_miss) {
auto nn_eval = computation.GetNNEval(idx_in_computation).get();
if (params_.GetContemptPerspective() != ContemptPerspective::NONE) {
bool root_stm =
(params_.GetContemptPerspective() == ContemptPerspective::STM
? !(search_->played_history_.Last().IsBlackToMove())
: (params_.GetContemptPerspective() ==
ContemptPerspective::WHITE));
auto sign = (root_stm ^ node_to_process->history.IsBlackToMove())
? 1.0f
: -1.0f;
if (params_.GetWDLRescaleRatio() != 1.0f ||
params_.GetWDLRescaleDiff() != 0.0f) {
float v = nn_eval->q;
float d = nn_eval->d;
WDLRescale(v, d, nullptr, params_.GetWDLRescaleRatio(),
params_.GetWDLRescaleDiff(), sign, false);
nn_eval->q = v;
nn_eval->d = d;


if (node_to_process->is_comrade_hit) {
LowNode comrade_low_node = *(node_to_process->comrade_low_node);
auto [tt_low_node, is_tt_miss] =
search_->dag_->TTGetOrCreate(comrade_low_node, node_to_process->hash);
assert(tt_low_node != nullptr);
node_to_process->tt_low_node = tt_low_node;
}
else {
auto [tt_low_node, is_tt_miss] =
search_->dag_->TTGetOrCreate(node_to_process->hash);
assert(tt_low_node != nullptr);

node_to_process->tt_low_node = tt_low_node;
if (is_tt_miss) {
auto nn_eval = computation.GetNNEval(idx_in_computation).get();
if (params_.GetContemptPerspective() != ContemptPerspective::NONE) {
bool root_stm =
(params_.GetContemptPerspective() == ContemptPerspective::STM
? !(search_->played_history_.Last().IsBlackToMove())
: (params_.GetContemptPerspective() ==
ContemptPerspective::WHITE));
auto sign = (root_stm ^ node_to_process->history.IsBlackToMove())
? 1.0f
: -1.0f;
if (params_.GetWDLRescaleRatio() != 1.0f ||
params_.GetWDLRescaleDiff() != 0.0f) {
float v = nn_eval->q;
float d = nn_eval->d;
WDLRescale(v, d, nullptr, params_.GetWDLRescaleRatio(),
params_.GetWDLRescaleDiff(), sign, false);
nn_eval->q = v;
nn_eval->d = d;
}
}
// after wdl scaling we adjust based on r50
// 0.5 constant is chosen arbitrarily
// TODO: some stuff to consider with differing signs
float q_adjusted = nn_eval->q + node_to_process->comrade_error * 0.5;
nn_eval->q = q_adjusted;
node_to_process->tt_low_node->SetNNEval(nn_eval);
}
// after wdl scaling we adjust based on r50
// 0.5 constant is chosen arbitrarily
// TODO: some stuff to consider with differing signs
float q_adjusted = nn_eval->q + node_to_process->comrade_error * 0.5;
float early_q = node_to_process->early_q,
late_q = node_to_process->late_q;
if (abs(q_adjusted) < abs(late_q)) {
q_adjusted = late_q;
}
if (abs(q_adjusted) > abs(early_q)) {
q_adjusted = early_q;
}
nn_eval->q = q_adjusted;
node_to_process->tt_low_node->SetNNEval(nn_eval);
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/mcts/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ class SearchWorker {
return is_tt_hit || is_cache_hit || node->IsTerminal() ||
node->GetLowNode();
}
bool ShouldAddToInput() const { return nn_queried && !is_tt_hit; }
bool ShouldAddToInput() const { return nn_queried && !is_tt_hit && !is_comrade_hit; }
int GetRule50Ply() const { return history.Last().GetRule50Ply(); }

// The path to the node to extend.
Expand All @@ -311,18 +311,20 @@ class SearchWorker {
uint32_t maxvisit = 0;
bool nn_queried = false;
bool is_tt_hit = false;
bool is_comrade_hit = false;

bool is_cache_hit = false;
bool is_collision = false;

// values for improving r50 estimates, filled in as we go
float early_q;
float late_q;
float comrade_error;

// Details that are filled in as we go.
uint64_t hash;

LowNode* tt_low_node;
LowNode* comrade_low_node;

NNCacheLock lock;
PositionHistory history;
bool ooo_completed = false;
Expand Down

0 comments on commit c545246

Please sign in to comment.