Skip to content

Commit

Permalink
Move move filter population to a constructor. (LeelaChessZero#1281)
Browse files Browse the repository at this point in the history
  • Loading branch information
mooskagh authored and AlexisOlson committed May 10, 2020
1 parent 9a1ca2a commit d5c9328
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 57 deletions.
2 changes: 1 addition & 1 deletion src/chess/bitboard.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class Move {
uint16_t as_nn_index(int transform) const;

explicit operator bool() const { return data_ != 0; }
bool operator==(const Move& other) { return data_ == other.data_; }
bool operator==(const Move& other) const { return data_ == other.data_; }

void Mirror() { data_ ^= 0b111000111000; }

Expand Down
94 changes: 46 additions & 48 deletions src/mcts/search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,28 @@ namespace lczero {
namespace {
// Maximum delay between outputting "uci info" when nothing interesting happens.
const int kUciInfoMinimumFrequencyMs = 5000;

MoveList MakeRootMoveFilter(const MoveList& searchmoves,
SyzygyTablebase* syzygy_tb,
const PositionHistory& history, bool fast_play,
std::atomic<int>* tb_hits) {
// Search moves overrides tablebase.
if (!searchmoves.empty()) return searchmoves;
const auto& board = history.Last().GetBoard();
MoveList root_moves;
if (!syzygy_tb || !board.castlings().no_legal_castle() ||
(board.ours() | board.theirs()).count() > syzygy_tb->max_cardinality()) {
return root_moves;
}
if (syzygy_tb->root_probe(
history.Last(), fast_play || history.DidRepeatSinceLastZeroingMove(),
&root_moves) ||
syzygy_tb->root_probe_wdl(history.Last(), &root_moves)) {
tb_hits->fetch_add(1, std::memory_order_acq_rel);
}
return root_moves;
}

} // namespace

Search::Search(const NodeTree& tree, Network* network,
Expand All @@ -63,11 +85,14 @@ Search::Search(const NodeTree& tree, Network* network,
syzygy_tb_(syzygy_tb),
played_history_(tree.GetPositionHistory()),
network_(network),
params_(options),
searchmoves_(searchmoves),
start_time_(start_time),
initial_visits_(root_node_->GetN()),
uci_responder_(std::move(uci_responder)),
params_(options) {
root_move_filter_(
MakeRootMoveFilter(searchmoves_, syzygy_tb_, played_history_,
params_.GetSyzygyFastPlay(), &tb_hits_)),
uci_responder_(std::move(uci_responder)) {
if (params_.GetMaxConcurrentSearchers() != 0) {
pending_searchers_.store(params_.GetMaxConcurrentSearchers(),
std::memory_order_release);
Expand Down Expand Up @@ -476,25 +501,6 @@ std::int64_t Search::GetTotalPlayouts() const {
return total_playouts_;
}

bool Search::PopulateRootMoveLimit(MoveList* root_moves) const {
// Search moves overrides tablebase.
if (!searchmoves_.empty()) {
*root_moves = searchmoves_;
return false;
}
auto board = played_history_.Last().GetBoard();
if (!syzygy_tb_ || !board.castlings().no_legal_castle() ||
(board.ours() | board.theirs()).count() > syzygy_tb_->max_cardinality()) {
return false;
}
return syzygy_tb_->root_probe(
played_history_.Last(),
params_.GetSyzygyFastPlay() ||
played_history_.DidRepeatSinceLastZeroingMove(),
root_moves) ||
syzygy_tb_->root_probe_wdl(played_history_.Last(), root_moves);
}

void Search::ResetBestMove() {
SharedMutex::Lock nodes_lock(nodes_mutex_);
Mutex::Lock lock(counters_mutex_);
Expand Down Expand Up @@ -522,8 +528,9 @@ void Search::EnsureBestMoveKnown() REQUIRES(nodes_mutex_)
if (moves >= decay_delay_moves + decay_moves) {
temperature = 0.0;
} else if (moves >= decay_delay_moves) {
temperature *= static_cast<float>
(decay_delay_moves + decay_moves - moves) / decay_moves;
temperature *=
static_cast<float>(decay_delay_moves + decay_moves - moves) /
decay_moves;
}
// don't allow temperature to decay below endgame temperature
if (temperature < params_.GetTemperatureEndgame()) {
Expand All @@ -543,10 +550,6 @@ void Search::EnsureBestMoveKnown() REQUIRES(nodes_mutex_)
std::vector<EdgeAndNode> Search::GetBestChildrenNoTemperature(Node* parent,
int count,
int depth) const {
MoveList root_limit;
if (parent == root_node_) {
PopulateRootMoveLimit(&root_limit);
}
const bool is_odd_depth = (depth % 2) == 1;
const float draw_score = GetDrawScore(is_odd_depth);
// Best child is selected using the following criteria:
Expand All @@ -557,9 +560,9 @@ std::vector<EdgeAndNode> Search::GetBestChildrenNoTemperature(Node* parent,
// * If that number is larger than 0, the one with larger eval wins.
std::vector<EdgeAndNode> edges;
for (auto edge : parent->Edges()) {
if (parent == root_node_ && !root_limit.empty() &&
std::find(root_limit.begin(), root_limit.end(), edge.GetMove()) ==
root_limit.end()) {
if (parent == root_node_ && !root_move_filter_.empty() &&
std::find(root_move_filter_.begin(), root_move_filter_.end(),
edge.GetMove()) == root_move_filter_.end()) {
continue;
}
edges.push_back(edge);
Expand Down Expand Up @@ -648,8 +651,6 @@ EdgeAndNode Search::GetBestChildNoTemperature(Node* parent, int depth) const {
EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const {
// Root is at even depth.
const float draw_score = GetDrawScore(/* is_odd_depth= */ false);
MoveList root_limit;
PopulateRootMoveLimit(&root_limit);

std::vector<float> cumulative_sums;
float sum = 0.0;
Expand All @@ -660,8 +661,9 @@ EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const {
GetFpu(params_, root_node_, /* is_root= */ true, draw_score);

for (auto edge : root_node_->Edges()) {
if (!root_limit.empty() && std::find(root_limit.begin(), root_limit.end(),
edge.GetMove()) == root_limit.end()) {
if (!root_move_filter_.empty() &&
std::find(root_move_filter_.begin(), root_move_filter_.end(),
edge.GetMove()) == root_move_filter_.end()) {
continue;
}
if (edge.GetN() + offset > max_n) {
Expand All @@ -677,8 +679,9 @@ EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const {
const float min_eval =
max_eval - params_.GetTemperatureWinpctCutoff() / 50.0f;
for (auto edge : root_node_->Edges()) {
if (!root_limit.empty() && std::find(root_limit.begin(), root_limit.end(),
edge.GetMove()) == root_limit.end()) {
if (!root_move_filter_.empty() &&
std::find(root_move_filter_.begin(), root_move_filter_.end(),
edge.GetMove()) == root_move_filter_.end()) {
continue;
}
if (edge.GetQ(fpu, draw_score, /* logit_q= */ false) < min_eval) continue;
Expand All @@ -695,8 +698,9 @@ EdgeAndNode Search::GetBestRootChildWithTemperature(float temperature) const {
cumulative_sums.begin();

for (auto edge : root_node_->Edges()) {
if (!root_limit.empty() && std::find(root_limit.begin(), root_limit.end(),
edge.GetMove()) == root_limit.end()) {
if (!root_move_filter_.empty() &&
std::find(root_move_filter_.begin(), root_move_filter_.end(),
edge.GetMove()) == root_move_filter_.end()) {
continue;
}
if (edge.GetQ(fpu, draw_score, /* logit_q= */ false) < min_eval) continue;
Expand Down Expand Up @@ -902,13 +906,6 @@ void SearchWorker::InitializeIteration(
computation_ = std::make_unique<CachingComputation>(std::move(computation),
search_->cache_);
minibatch_.clear();

if (!root_move_filter_populated_) {
root_move_filter_populated_ = true;
if (search_->PopulateRootMoveLimit(&root_move_filter_)) {
search_->tb_hits_.fetch_add(1, std::memory_order_acq_rel);
}
}
}

// 2. Gather minibatch.
Expand Down Expand Up @@ -1021,6 +1018,7 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend(
bool is_root_node = true;
const float even_draw_score = search_->GetDrawScore(false);
const float odd_draw_score = search_->GetDrawScore(true);
const auto& root_move_filter = search_->root_move_filter_;
uint16_t depth = 0;
bool node_already_updated = true;

Expand Down Expand Up @@ -1096,9 +1094,9 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend(
continue;
}
// If root move filter exists, make sure move is in the list.
if (!root_move_filter_.empty() &&
std::find(root_move_filter_.begin(), root_move_filter_.end(),
child.GetMove()) == root_move_filter_.end()) {
if (!root_move_filter.empty() &&
std::find(root_move_filter.begin(), root_move_filter.end(),
child.GetMove()) == root_move_filter.end()) {
continue;
}
}
Expand Down
12 changes: 4 additions & 8 deletions src/mcts/search.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,6 @@ class Search {
// uci `stop` command;
void WatchdogThread();

// Populates the given list with allowed root moves.
// Returns true if the population came from tablebase.
bool PopulateRootMoveLimit(MoveList* root_moves) const;

// Fills IterationStats with global (rather than per-thread) portion of search
// statistics. Currently all stats there (in IterationStats) are global
// though.
Expand Down Expand Up @@ -173,9 +169,13 @@ class Search {
const PositionHistory& played_history_;

Network* const network_;
const SearchParams params_;
const MoveList searchmoves_;
const std::chrono::steady_clock::time_point start_time_;
const int64_t initial_visits_;
// tb_hits_ must be initialized before root_move_filter_.
std::atomic<int> tb_hits_{0};
const MoveList root_move_filter_;

mutable SharedMutex nodes_mutex_;
EdgeAndNode current_best_edge_ GUARDED_BY(nodes_mutex_);
Expand All @@ -188,15 +188,13 @@ class Search {
// Cumulative depth of all paths taken in PickNodetoExtend.
uint64_t cum_depth_ GUARDED_BY(nodes_mutex_) = 0;
std::optional<std::chrono::steady_clock::time_point> nps_start_time_;
std::atomic<int> tb_hits_{0};

std::atomic<int> pending_searchers_{0};

std::vector<std::pair<Node*, int>> shared_collisions_
GUARDED_BY(nodes_mutex_);

std::unique_ptr<UciResponder> uci_responder_;
const SearchParams params_;

friend class SearchWorker;
};
Expand Down Expand Up @@ -320,8 +318,6 @@ class SearchWorker {
std::unique_ptr<CachingComputation> computation_;
// History is reset and extended by PickNodeToExtend().
PositionHistory history_;
MoveList root_move_filter_;
bool root_move_filter_populated_ = false;
int number_out_of_order_ = 0;
const SearchParams& params_;
std::unique_ptr<Node> precached_node_;
Expand Down

0 comments on commit d5c9328

Please sign in to comment.