Skip to content

Commit

Permalink
Support standard deviation factor
Browse files Browse the repository at this point in the history
  • Loading branch information
Ergodice authored and Craftyawesome committed Jul 18, 2023
1 parent 20a4b66 commit 7de757e
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 28 deletions.
2 changes: 1 addition & 1 deletion build.cmd
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ set EIGEN=false
set TEST=false

rem 2. Edit the paths for the build dependencies.
set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.0
set CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.0
set CUDNN_PATH=%CUDA_PATH%
set OPENBLAS_PATH=C:\OpenBLAS
set MKL_PATH=C:\Program Files (x86)\IntelSWTools\compilers_and_libraries\windows\mkl
Expand Down
21 changes: 17 additions & 4 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ void Node::Trim(GCQueue* gc_queue) {

d_ = 0.0f;
m_ = 0.0f;
vs_ = 0.0f;
n_ = 0;
n_in_flight_ = 0;

Expand Down Expand Up @@ -225,6 +226,7 @@ void LowNode::MakeTerminal(GameResult result, float plies_left, Terminal type) {
wl_ = -1.0f;
d_ = 0.0f;
}
vs_ = wl_ * wl_;

assert(WLDMInvariantsHold());
}
Expand All @@ -240,6 +242,7 @@ void LowNode::MakeNotTerminal(const Node* node) {
wl_ = 0.0;
d_ = 0.0;
m_ = 0.0;
vs_ = 0.0;

// Include children too.
if (node->GetNumEdges() > 0) {
Expand All @@ -252,13 +255,15 @@ void LowNode::MakeNotTerminal(const Node* node) {
wl_ += child.GetWL(0.0f) * n;
d_ += child.GetD(0.0f) * n;
m_ += child.GetM(0.0f) * n;
vs_ += child.GetVS(0.0f) * n;
}
}

// Recompute with current eval (instead of network's) and children's eval.
wl_ /= n_;
d_ /= n_;
m_ /= n_;
vs_ /= n_;
}

assert(WLDMInvariantsHold());
Expand Down Expand Up @@ -290,6 +295,7 @@ void Node::MakeTerminal(GameResult result, float plies_left, Terminal type) {
// comparable to another non-loss choice. Force this by clearing the policy.
SetP(0.0f);
}
vs_ = wl_ * wl_;

assert(WLDMInvariantsHold());
}
Expand All @@ -313,13 +319,15 @@ void Node::MakeNotTerminal(bool also_low_node) {
wl_ = -low_node_->GetWL();
d_ = low_node_->GetD();
m_ = low_node_->GetM() + 1;
vs_ = low_node_->GetVS();
} else { // Real terminal.
lower_bound_ = GameResult::BLACK_WON;
upper_bound_ = GameResult::WHITE_WON;
n_ = 0.0f;
wl_ = 0.0f;
d_ = 0.0f;
m_ = 0.0f;
vs_ = 0.0f;
}

assert(WLDMInvariantsHold());
Expand Down Expand Up @@ -349,37 +357,41 @@ void Node::CancelScoreUpdate(uint32_t multivisit) {
n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel);
}

void LowNode::FinalizeScoreUpdate(float v, float d, float m,
void LowNode::FinalizeScoreUpdate(float v, float d, float m, float vs,
                                  uint32_t multivisit) {
  // Folds `multivisit` new visits, each carrying the samples (v, d, m, vs),
  // into the running averages and then bumps the visit counter.
  assert(edges_);

  // Visit count once this update is applied; every average below is
  // re-weighted against it.
  const auto visits_after = n_ + multivisit;

  // Recompute Q: avg += k * (sample - avg) / (n + k).
  wl_ += multivisit * (v - wl_) / visits_after;
  d_ += multivisit * (d - d_) / visits_after;
  m_ += multivisit * (m - m_) / visits_after;
  vs_ += multivisit * (vs - vs_) / visits_after;

  assert(WLDMInvariantsHold());

  // Increment N (only after the averages were updated against the old N).
  n_ += multivisit;
}

void LowNode::AdjustForTerminal(float v, float d, float m,
void LowNode::AdjustForTerminal(float v, float d, float m, float vs,
                                uint32_t multivisit) {
  // Shifts the averages as if `multivisit` of the already counted visits had
  // each changed by the deltas (v, d, m, vs). N itself is not modified.
  assert(multivisit <= n_);

  // Recompute Q: avg += k * delta / n.
  wl_ += multivisit * v / n_;
  d_ += multivisit * d / n_;
  m_ += multivisit * m / n_;
  vs_ += multivisit * vs / n_;

  assert(WLDMInvariantsHold());
}

void Node::FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit) {
void Node::FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit) {
// Recompute Q.
wl_ += multivisit * (v - wl_) / (n_ + multivisit);
d_ += multivisit * (d - d_) / (n_ + multivisit);
m_ += multivisit * (m - m_) / (n_ + multivisit);
vs_ += multivisit * (vs - vs_) / (n_ + multivisit);

assert(WLDMInvariantsHold());

Expand All @@ -390,13 +402,14 @@ void Node::FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit) {
n_in_flight_.fetch_sub(multivisit, std::memory_order_acq_rel);
}

void Node::AdjustForTerminal(float v, float d, float m, uint32_t multivisit) {
void Node::AdjustForTerminal(float v, float d, float m, float vs,
                             uint32_t multivisit) {
  // Shifts the averages as if `multivisit` of the already counted visits had
  // each changed by the deltas (v, d, m, vs). N itself is not modified.
  assert(multivisit <= n_);

  // Recompute Q: avg += k * delta / n.
  wl_ += multivisit * v / n_;
  d_ += multivisit * d / n_;
  m_ += multivisit * m / n_;
  vs_ += multivisit * vs / n_;

  assert(WLDMInvariantsHold());
}
Expand Down
19 changes: 15 additions & 4 deletions src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ class Node {
float GetWL() const { return wl_; }
float GetD() const { return d_; }
float GetM() const { return m_; }
// Returns the averaged `vs` statistic; the double-precision vs_ is narrowed
// to float on return, like the WL/D getters above.
float GetVS() const { return vs_; }


// Returns whether the node is known to be draw/lose/win.
bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; }
Expand Down Expand Up @@ -318,9 +320,9 @@ class Node {
// * Q (weighted average of all V in a subtree)
// * N (+=multivisit)
// * N-in-flight (-=multivisit)
void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit);
void FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit);
// Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
void AdjustForTerminal(float v, float d, float m, uint32_t multivisit);
void AdjustForTerminal(float v, float d, float m, float vs, uint32_t multivisit);
// When search decides to treat one visit as several (in case of collisions
// or visiting terminal nodes several times), it amplifies the visit by
// incrementing n_in_flight.
Expand Down Expand Up @@ -397,6 +399,7 @@ class Node {
// Averaged draw probability. Works similarly to WL, except that D is not
// flipped depending on the side to move.
double d_ = 0.0f;
// Averaged `vs` sample, maintained by FinalizeScoreUpdate/AdjustForTerminal
// in the same way as WL and D. MakeTerminal sets it to wl_ * wl_, so it
// presumably tracks the mean squared value used for a standard-deviation
// estimate -- TODO confirm against the search code.
double vs_ = 0.0;

// 8 byte fields on 64-bit platforms, 4 byte on 32-bit.
// Pointer to the low node.
Expand Down Expand Up @@ -453,6 +456,7 @@ class LowNode {
d_(p.d_),
hash_(p.hash_),
m_(p.m_),
vs_(p.vs_),
num_edges_(p.num_edges_),
terminal_type_(Terminal::NonTerminal),
lower_bound_(GameResult::BLACK_WON),
Expand Down Expand Up @@ -490,6 +494,7 @@ class LowNode {
wl_ = eval->q;
d_ = eval->d;
m_ = eval->m;
vs_ = wl_ * wl_;

assert(WLDMInvariantsHold());

Expand All @@ -510,6 +515,8 @@ class LowNode {
float GetWL() const { return wl_; }
float GetD() const { return d_; }
float GetM() const { return m_; }
// Returns the averaged `vs` statistic; the double-precision vs_ is narrowed
// to float on return, like the WL/D getters above.
float GetVS() const { return vs_; }


// Returns whether the node is known to be draw/loss/win.
bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; }
Expand All @@ -535,9 +542,9 @@ class LowNode {
// * Q (weighted average of all V in a subtree)
// * N (+=multivisit)
// * N-in-flight (-=multivisit)
void FinalizeScoreUpdate(float v, float d, float m, uint32_t multivisit);
void FinalizeScoreUpdate(float v, float d, float m, float vs, uint32_t multivisit);
// Like FinalizeScoreUpdate, but it updates n existing visits by delta amount.
void AdjustForTerminal(float v, float d, float m, uint32_t multivisit);
void AdjustForTerminal(float v, float d, float m, float vs, uint32_t multivisit);

// Deletes all children.
void ReleaseChildren(GCQueue* gc_queue);
Expand Down Expand Up @@ -598,6 +605,7 @@ class LowNode {
// Averaged draw probability. Works similarly to WL, except that D is not
// flipped depending on the side to move.
double d_ = 0.0f;
// Averaged `vs` sample, maintained by FinalizeScoreUpdate/AdjustForTerminal
// in the same way as WL and D. It is seeded with wl_ * wl_ when the eval is
// stored and when the node becomes terminal, so it presumably tracks the mean
// squared value used for a standard-deviation estimate -- TODO confirm
// against the search code.
double vs_ = 0.0f;
// Position hash and a TT key.
uint64_t hash_ = 0;

Expand Down Expand Up @@ -666,6 +674,9 @@ class EdgeAndNode {
// M of the attached node, or `default_m` when there is no node yet or the
// node has not been visited.
float GetM(float default_m) const {
  const bool has_visits = node_ && node_->GetN() > 0;
  if (!has_visits) return default_m;
  return node_->GetM();
}
// VS of the attached node, or `default_vs` when there is no node yet or the
// node has not been visited.
float GetVS(float default_vs) const {
  const bool has_visits = node_ && node_->GetN() > 0;
  if (!has_visits) return default_vs;
  return node_->GetVS();
}
// N-related getters, from Node (if exists).
uint32_t GetN() const { return node_ ? node_->GetN() : 0; }
int GetNStarted() const { return node_ ? node_->GetNStarted() : 0; }
Expand Down
32 changes: 31 additions & 1 deletion src/mcts/params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,22 @@ const OptionId SearchParams::kUCIRatingAdvId{
"", "UCI_RatingAdv",
"UCI extension used by some GUIs to pass the estimated Elo advantage over "
"the current opponent, used as the default contempt value."};
// Option ids for the stdev- and advantage-based cpuct modulation parameters.
// Their value ranges and defaults are registered in SearchParams::Populate().
const OptionId SearchParams::kCpuctUtilityStdevPriorId{
"cpuct-utility-stdev-prior", "CpuctUtilityStdevPrior",
"Prior for stdev cpuct formula."};
const OptionId SearchParams::kCpuctUtilityStdevScaleId{
"cpuct-utility-stdev-scale", "CpuctUtilityStdevScale",
"Scale value for the utility stdev in the cpuct formula."};
const OptionId SearchParams::kCpuctUtilityStdevPriorWeightId{
"cpuct-utility-stdev-prior-weight", "CpuctUtilityStdevPriorWeight",
"How much to weigh the prior value in the calculation of stdev."};
// NOTE(review): the two advantage options below describe themselves as
// ignored by the engine in their help text -- confirm whether they are only
// kept for option compatibility.
const OptionId SearchParams::kCpuctAdvantageScaleId{
"cpuct-advantage-scale", "CpuctAdvantageScale",
"How heavily to weight the advantage when modulating cpuct. IGNORED BY ENGINE."};
const OptionId SearchParams::kCpuctAdvantagePriorWeightId{
"cpuct-advantage-prior-weight", "CpuctAdvantagePriorWeight",
"Effectively how long to wait before applying advantage modulation to "
"cpuct. IGNORED BY ENGINE."};

void SearchParams::Populate(OptionsParser* options) {
// Here the "uci optimized defaults" are set.
Expand Down Expand Up @@ -551,6 +567,13 @@ void SearchParams::Populate(OptionsParser* options) {
options->Add<IntOption>(kThreadIdlingThresholdId, 0, 128) = 1;
options->Add<StringOption>(kUCIOpponentId);
options->Add<FloatOption>(kUCIRatingAdvId, -10000.0f, 10000.0f) = 0.0f;
options->Add<FloatOption>(kCpuctUtilityStdevPriorId, 0.0f, 2.0f) = 0.1f;
options->Add<FloatOption>(kCpuctUtilityStdevScaleId, 0.0f, 1.0f) = 0.0f;
options->Add<FloatOption>(kCpuctUtilityStdevPriorWeightId, 0.0f, 10000.0f) =
10.0f;
options->Add<FloatOption>(kCpuctAdvantageScaleId, -1.0f, 1.0f) = 0.0f;
options->Add<FloatOption>(kCpuctAdvantagePriorWeightId, 0.0f, 10000.0f) =
10.0f;

options->HideOption(kNoiseEpsilonId);
options->HideOption(kNoiseAlphaId);
Expand Down Expand Up @@ -665,7 +688,14 @@ SearchParams::SearchParams(const OptionsDict& options)
kMaxCollisionVisitsScalingEnd(
options.Get<int>(kMaxCollisionVisitsScalingEndId)),
kMaxCollisionVisitsScalingPower(
options.Get<float>(kMaxCollisionVisitsScalingPowerId)) {
options.Get<float>(kMaxCollisionVisitsScalingPowerId)),
kCpuctUtilityStdevPrior(options.Get<float>(kCpuctUtilityStdevPriorId)),
kCpuctUtilityStdevScale(options.Get<float>(kCpuctUtilityStdevScaleId)),
kCpuctUtilityStdevPriorWeight(
options.Get<float>(kCpuctUtilityStdevPriorWeightId)),
kCpuctAdvantageScale(options.Get<float>(kCpuctAdvantageScaleId)),
kCpuctAdvantagePriorWeight(
options.Get<float>(kCpuctAdvantagePriorWeightId)) {
if (std::max(std::abs(kDrawScoreSidetomove), std::abs(kDrawScoreOpponent)) +
std::max(std::abs(kDrawScoreWhite), std::abs(kDrawScoreBlack)) >
1.0f) {
Expand Down
20 changes: 20 additions & 0 deletions src/mcts/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ class SearchParams {
float GetMaxCollisionVisitsScalingPower() const {
return kMaxCollisionVisitsScalingPower;
}
// Accessors for the utility-stdev cpuct parameters (prior, scale and the
// weight given to the prior), as read from the options at construction.
float GetCpuctUtilityStdevPrior() const { return kCpuctUtilityStdevPrior; }
float GetCpuctUtilityStdevScale() const { return kCpuctUtilityStdevScale; }
float GetCpuctUtilityStdevPriorWeight() const {
return kCpuctUtilityStdevPriorWeight;
}

// Accessors for the advantage-based cpuct parameters (scale and prior
// weight), as read from the options at construction.
float GetCpuctAdvantageScale() const { return kCpuctAdvantageScale; }
float GetCpuctAdvantagePriorWeight() const {
return kCpuctAdvantagePriorWeight;
}

// Search parameter IDs.
static const OptionId kMiniBatchSizeId;
Expand Down Expand Up @@ -224,6 +234,11 @@ class SearchParams {
static const OptionId kMaxCollisionVisitsScalingPowerId;
static const OptionId kUCIOpponentId;
static const OptionId kUCIRatingAdvId;
static const OptionId kCpuctUtilityStdevPriorId;
static const OptionId kCpuctUtilityStdevScaleId;
static const OptionId kCpuctUtilityStdevPriorWeightId;
static const OptionId kCpuctAdvantageScaleId;
static const OptionId kCpuctAdvantagePriorWeightId;

private:
const OptionsDict& options_;
Expand Down Expand Up @@ -283,6 +298,11 @@ class SearchParams {
const int kMaxCollisionVisitsScalingStart;
const int kMaxCollisionVisitsScalingEnd;
const float kMaxCollisionVisitsScalingPower;
const float kCpuctUtilityStdevPrior;
const float kCpuctUtilityStdevScale;
const float kCpuctUtilityStdevPriorWeight;
const float kCpuctAdvantageScale;
const float kCpuctAdvantagePriorWeight;
};

} // namespace lczero
Loading

0 comments on commit 7de757e

Please sign in to comment.