Skip to content

Commit

Permalink
Merge branch 'master' into WDLconversion
Browse files Browse the repository at this point in the history
  • Loading branch information
Naphthalin authored May 21, 2023
2 parents f876733 + 905d88e commit 1e00ff0
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ class MEvaluator {
}
}

float GetM(const EdgeAndNode& child, float q) const {
// Calculates the utility for favoring shorter wins and longer losses.
float GetMUtility(Node* child, float q) const {
if (!enabled_ || !parent_within_threshold_) return 0.0f;
const float child_m = child.GetM(parent_m_);
const float child_m = child->GetM();
float m = std::clamp(m_slope_ * (child_m - parent_m_), -m_cap_, m_cap_);
m *= FastSign(-q);
if (q_threshold_ > 0.0f && q_threshold_ < 1.0f) {
Expand All @@ -120,17 +121,14 @@ class MEvaluator {
return m;
}

float GetM(Node* child, float q) const {
float GetMUtility(const EdgeAndNode& child, float q) const {
if (!enabled_ || !parent_within_threshold_) return 0.0f;
const float child_m = child->GetM();
float m = std::clamp(m_slope_ * (child_m - parent_m_), -m_cap_, m_cap_);
m *= FastSign(-q);
m *= a_constant_ + a_linear_ * std::abs(q) + a_square_ * q * q;
return m;
if (child.GetN() == 0) return GetDefaultMUtility();
return GetMUtility(child.node(), q);
}

// The M utility to use for unvisited nodes.
float GetDefaultM() const { return 0.0f; }
float GetDefaultMUtility() const { return 0.0f; }

private:
static bool WithinThreshold(const Node* parent, float q_threshold) {
Expand Down Expand Up @@ -518,7 +516,7 @@ std::vector<std::string> Search::GetVerboseStats(Node* node) const {
: MEvaluator();
for (const auto& edge : edges) {
float Q = edge.GetQ(fpu, draw_score);
float M = m_evaluator.GetM(edge, Q);
float M = m_evaluator.GetMUtility(edge, Q);
std::ostringstream oss;
oss << std::left;
// TODO: should this be displaying transformed index?
Expand Down Expand Up @@ -931,7 +929,7 @@ void Search::PopulateCommonIterationStats(IterationStats* stats) {
for (const auto& edge : root_node_->Edges()) {
const auto n = edge.GetN();
const auto q = edge.GetQ(fpu, draw_score);
const auto m = m_evaluator.GetM(edge, q);
const auto m = m_evaluator.GetMUtility(edge, q);
const auto q_plus_m = q + m;
stats->edge_n.push_back(n);
if (n > 0 && edge.IsTerminal() && edge.GetWL(0.0f) > 0.0f) {
Expand Down Expand Up @@ -1668,13 +1666,13 @@ void SearchWorker::PickNodesToExtendTask(
int index = child->Index();
visited_pol += current_pol[index];
float q = child->GetQ(draw_score);
current_util[index] = q + m_evaluator.GetM(child, q);
current_util[index] = q + m_evaluator.GetMUtility(child, q);
}
const float fpu =
GetFpu(params_, node, is_root_node, draw_score, visited_pol);
for (int i = 0; i < max_needed; i++) {
if (current_util[i] == std::numeric_limits<float>::lowest()) {
current_util[i] = fpu + m_evaluator.GetDefaultM();
current_util[i] = fpu + m_evaluator.GetDefaultMUtility();
}
}

Expand Down

0 comments on commit 1e00ff0

Please sign in to comment.