Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change mlh threshold effect to create a smooth transition #1844

Merged
merged 8 commits into from
May 21, 2023
40 changes: 23 additions & 17 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,26 +106,29 @@ class MEvaluator {
}
}

float GetM(const EdgeAndNode& child, float q) const {
// Calculates the utility for favoring shorter wins and longer losses.
float GetMUtility(Node* child, float q) const {
if (!enabled_ || !parent_within_threshold_) return 0.0f;
const float child_m = child.GetM(parent_m_);
const float child_m = child->GetM();
float m = std::clamp(m_slope_ * (child_m - parent_m_), -m_cap_, m_cap_);
m *= FastSign(-q);
if (q_threshold_ > 0.0f && q_threshold_ < 1.0f) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining what the GetM function does?

I thought it returns moves left, but then it returns 0 if |q| < q_threshold, so it's not clear what it's supposed to do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It returns the "moves left head effect" term which gets added to the PUCT score, so that PUCT = Q + U + M. This PR doesn't change the meaning of GetM(), but your comment makes clear that it was already confusing before this PR, so a comment (even if it has nothing to do with this PR) would certainly be justified.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like this before the function declaration.

// Returns adjustment to Q based on Q and M value of a node.

Could you add this in this PR?

Also, GetM() sounds like the wrong name for this function. Maybe rename it to GetQAdjustment()?
(It's more or less clear that it's based on M because it's inside the MEvaluator class.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Settled on GetMUtility() because it best describes both where it comes from (M) and what it is (a utility).

// This allows a smooth M effect with higher q thresholds, which is
// necessary for using MLH together with contempt.
q = std::max(0.0f, (std::abs(q) - q_threshold_)) / (1.0f - q_threshold_);
}
m *= a_constant_ + a_linear_ * std::abs(q) + a_square_ * q * q;
return m;
}

float GetM(Node* child, float q) const {
float GetMUtility(const EdgeAndNode& child, float q) const {
if (!enabled_ || !parent_within_threshold_) return 0.0f;
const float child_m = child->GetM();
float m = std::clamp(m_slope_ * (child_m - parent_m_), -m_cap_, m_cap_);
m *= FastSign(-q);
m *= a_constant_ + a_linear_ * std::abs(q) + a_square_ * q * q;
return m;
if (child.GetN() == 0) return GetDefaultMUtility();
return GetMUtility(child.node(), q);
}

// The M utility to use for unvisited nodes.
float GetDefaultM() const { return 0.0f; }
float GetDefaultMUtility() const { return 0.0f; }

private:
static bool WithinThreshold(const Node* parent, float q_threshold) {
Expand Down Expand Up @@ -436,10 +439,13 @@ std::vector<std::string> Search::GetVerboseStats(Node* node) const {
up = -up;
std::swap(lo, up);
}
*oss << (lo == up ? "(T) "
: lo == GameResult::DRAW && up == GameResult::WHITE_WON ? "(W) "
: lo == GameResult::BLACK_WON && up == GameResult::DRAW ? "(L) "
: "");
*oss << (lo == up
? "(T) "
: lo == GameResult::DRAW && up == GameResult::WHITE_WON
? "(W) "
: lo == GameResult::BLACK_WON && up == GameResult::DRAW
? "(L) "
: "");
}
};

Expand All @@ -449,7 +455,7 @@ std::vector<std::string> Search::GetVerboseStats(Node* node) const {
: MEvaluator();
for (const auto& edge : edges) {
float Q = edge.GetQ(fpu, draw_score);
float M = m_evaluator.GetM(edge, Q);
float M = m_evaluator.GetMUtility(edge, Q);
std::ostringstream oss;
oss << std::left;
// TODO: should this be displaying transformed index?
Expand Down Expand Up @@ -862,7 +868,7 @@ void Search::PopulateCommonIterationStats(IterationStats* stats) {
for (const auto& edge : root_node_->Edges()) {
const auto n = edge.GetN();
const auto q = edge.GetQ(fpu, draw_score);
const auto m = m_evaluator.GetM(edge, q);
const auto m = m_evaluator.GetMUtility(edge, q);
const auto q_plus_m = q + m;
stats->edge_n.push_back(n);
if (n > 0 && edge.IsTerminal() && edge.GetWL(0.0f) > 0.0f) {
Expand Down Expand Up @@ -1599,13 +1605,13 @@ void SearchWorker::PickNodesToExtendTask(
int index = child->Index();
visited_pol += current_pol[index];
float q = child->GetQ(draw_score);
current_util[index] = q + m_evaluator.GetM(child, q);
current_util[index] = q + m_evaluator.GetMUtility(child, q);
}
const float fpu =
GetFpu(params_, node, is_root_node, draw_score, visited_pol);
for (int i = 0; i < max_needed; i++) {
if (current_util[i] == std::numeric_limits<float>::lowest()) {
current_util[i] = fpu + m_evaluator.GetDefaultM();
current_util[i] = fpu + m_evaluator.GetDefaultMUtility();
}
}

Expand Down