Skip to content

Commit

Permalink
Merge pull request #12 from mooskagh/master
Browse files Browse the repository at this point in the history
Fix castling handling, update default params.
  • Loading branch information
mooskagh authored Jun 2, 2018
2 parents f97b871 + 506956d commit 2321011
Show file tree
Hide file tree
Showing 16 changed files with 174 additions and 106 deletions.
10 changes: 9 additions & 1 deletion src/chess/bitboard.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ std::vector<unsigned short> BuildMoveIndices() {
}

const std::vector<unsigned short> kMoveToIdx = BuildMoveIndices();
const int kKingCastleIndex =
kMoveToIdx[BoardSquare("e1").as_int() * 64 + BoardSquare("h1").as_int()];
const int kQueenCastleIndex =
kMoveToIdx[BoardSquare("e1").as_int() * 64 + BoardSquare("a1").as_int()];
} // namespace

Move::Move(const std::string& str, bool black) {
Expand Down Expand Up @@ -307,6 +311,10 @@ uint16_t Move::as_packed_int() const {
}
}

uint16_t Move::as_nn_index() const { return kMoveToIdx[as_packed_int()]; }
uint16_t Move::as_nn_index() const {
if (!castling_) return kMoveToIdx[as_packed_int()];
if (from_.col() < to_.col()) return kKingCastleIndex;
return kQueenCastleIndex;
}

} // namespace lczero
6 changes: 1 addition & 5 deletions src/chess/bitboard.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,7 @@ class Move {
}

std::string as_string() const {
BoardSquare to = to_;
if (castling_) {
to = BoardSquare(to.row(), (to.col() == 7) ? 6 : 2);
}
std::string res = from_.as_string() + to.as_string();
std::string res = from_.as_string() + to_.as_string();
switch (promotion_) {
case Promotion::None:
return res;
Expand Down
16 changes: 5 additions & 11 deletions src/chess/board.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ MoveList ChessBoard::GeneratePseudolegalMoves() const {
}
}
if (can_castle) {
result.emplace_back(source, BoardSquare(0, 7));
result.emplace_back(source, BoardSquare(0, 6));
result.back().SetCastling();
}
}
Expand All @@ -230,7 +230,7 @@ MoveList ChessBoard::GeneratePseudolegalMoves() const {
}
}
if (can_castle) {
result.emplace_back(source, BoardSquare(0, 0));
result.emplace_back(source, BoardSquare(0, 2));
result.back().SetCastling();
}
}
Expand Down Expand Up @@ -344,9 +344,9 @@ bool ChessBoard::ApplyMove(Move move) {
const auto to_row = to.row();
const auto to_col = to.col();

// Remove our piece from old location, but not put to destination
// (for the case of castling).
// Move in our pieces.
our_pieces_.reset(from);
our_pieces_.set(to);

// Remove captured piece
bool reset_50_moves = their_pieces_.get(to);
Expand Down Expand Up @@ -378,26 +378,20 @@ bool ChessBoard::ApplyMove(Move move) {
if (from == our_king_) {
castlings_.reset_we_can_00();
castlings_.reset_we_can_000();
our_king_ = to;
// Castling
if (to_col - from_col > 1) {
// 0-0
our_pieces_.reset(7);
rooks_.reset(7);
our_pieces_.set(5);
rooks_.set(5);
our_king_ = BoardSquare(0, 6); /* g8 */
our_pieces_.set(our_king_);
} else if (from_col - to_col > 1) {
// 0-0-0
our_pieces_.reset(0);
rooks_.reset(0);
our_pieces_.set(3);
rooks_.set(3);
our_king_ = BoardSquare(0, 2); /* c8 */
our_pieces_.set(our_king_);
} else {
our_king_ = to;
our_pieces_.set(to);
}
return reset_50_moves;
}
Expand Down
32 changes: 26 additions & 6 deletions src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,17 @@ void EngineController::PopulateOptions(OptionsParser* options) {
options->Add<ChoiceOption>(kNnBackendStr, backends, "backend") =
backends.empty() ? "<none>" : backends[0];
options->Add<StringOption>(kNnBackendOptionsStr, "backend-opts");
options->Add<FloatOption>(kSlowMoverStr, 0.0, 100.0, "slowmover") = 2.2;
options->Add<FloatOption>(kSlowMoverStr, 0.0f, 100.0f, "slowmover") = 2.2f;
options->Add<IntOption>(kMoveOverheadStr, 0, 10000, "move-overhead") = 100;

Search::PopulateUciParams(options);

auto defaults = options->GetMutableDefaultsOptions();

defaults->Set<int>(Search::kMiniBatchSizeStr, 256); // Minibatch = 256
defaults->Set<float>(Search::kFpuReductionStr, 0.2f); // FPU reduction = 0.2
defaults->Set<float>(Search::kCpuctStr, 3.1f); // CPUCT = 3.1
defaults->Set<int>(Search::kAllowedNodeCollisionsStr, 32); // Node collisions
}

SearchLimits EngineController::PopulateSearchLimits(int /*ply*/, bool is_black,
Expand All @@ -87,8 +94,9 @@ SearchLimits EngineController::PopulateSearchLimits(int /*ply*/, bool is_black,
float slowmover = options_.Get<float>(kSlowMoverStr);
int64_t move_overhead = options_.Get<int>(kMoveOverheadStr);
// Total time till control including increments.
auto total_moves_time = std::max(
int64_t{0}, time + increment * (movestogo - 1) - move_overhead * movestogo);
auto total_moves_time =
std::max(int64_t{0},
time + increment * (movestogo - 1) - move_overhead * movestogo);

const int kSmartPruningToleranceMs = 200;

Expand All @@ -99,10 +107,12 @@ SearchLimits EngineController::PopulateSearchLimits(int /*ply*/, bool is_black,
// reduce it.
if (slowmover < 1.0 || this_move_time > kSmartPruningToleranceMs) {
// Budget X*slowmover for current move, X*1.0 for the rest.
this_move_time = total_moves_time / (movestogo - 1 + slowmover) * slowmover;
this_move_time = static_cast<int64_t>(
total_moves_time / (movestogo - 1 + slowmover) * slowmover);
}
// Make sure we don't exceed current time limit with what we calculated.
limits.time_ms = std::max(int64_t{0}, std::min(this_move_time, time - move_overhead));
limits.time_ms =
std::max(int64_t{0}, std::min(this_move_time, time - move_overhead));
return limits;
}

Expand Down Expand Up @@ -134,6 +144,11 @@ void EngineController::UpdateNetwork() {

void EngineController::SetCacheSize(int size) { cache_.SetCapacity(size); }

void EngineController::EnsureReady() {
GarbageCollectNodePool();
std::unique_lock<RpSharedMutex> lock(busy_mutex_);
}

void EngineController::NewGame() {
SharedLock lock(busy_mutex_);
cache_.Clear();
Expand All @@ -151,7 +166,7 @@ void EngineController::SetPosition(const std::string& fen,

std::vector<Move> moves;
for (const auto& move : moves_str) moves.emplace_back(move);
tree_->ResetToPosition(fen, moves);
tree_->ResetToPosition(fen, moves, false);
UpdateNetwork();
}

Expand All @@ -163,6 +178,11 @@ void EngineController::Go(const GoParams& params) {
auto limits = PopulateSearchLimits(tree_->GetPlyCount(),
tree_->IsBlackToMove(), params);

BestMoveInfo::Callback best_move_callback = [this](const BestMoveInfo& info) {
best_move_callback_(info);
GarbageCollectNodePool();
};

search_ =
std::make_unique<Search>(*tree_, network_.get(), best_move_callback_,
info_callback_, limits, options_, &cache_);
Expand Down
2 changes: 1 addition & 1 deletion src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class EngineController {
void PopulateOptions(OptionsParser* options);

// Blocks.
void EnsureReady() { std::unique_lock<RpSharedMutex> lock(busy_mutex_); }
void EnsureReady();

// Must not block.
void NewGame();
Expand Down
91 changes: 58 additions & 33 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,21 @@ class Node::Pool {
public:
// Allocates a new node and initializes it with all zeros.
Node* AllocateNode();
// Return node to the pool.
void ReleaseNode(Node*);

// Release* function don't release trees immediately but rather schedule
// release until when GarbageCollect() is called.
// Releases all children of the node, except specified. Also updates pointers
// accordingly.
void ReleaseAllChildrenExceptOne(Node* root, Node* subtree);
// Releases all children, but doesn't release the node isself.
void ReleaseChildren(Node*);
// Releases all children and the node itself;
void ReleaseSubtree(Node*);
// Really releases subtrees makerd for release earlier.
void GarbageCollect();

private:
void AllocateNewBatch();
void ReleaseNodeInternal(Node*);
void ReleaseChildrenInternal(Node*);
void ReleaseSubtreeInternal(Node*);

union FreeNode {
FreeNode* next;
Expand All @@ -65,6 +64,8 @@ class Node::Pool {
FreeNode() {}
};

static FreeNode* UnrollNodeTree(FreeNode* node);

mutable Mutex mutex_;
// Linked list of free nodes.
FreeNode* free_list_ GUARDED_BY(mutex_) = nullptr;
Expand All @@ -74,8 +75,41 @@ class Node::Pool {
FreeNode* reserve_list_ GUARDED_BY(allocations_mutex_) = nullptr;
std::vector<std::unique_ptr<FreeNode[]>> allocations_
GUARDED_BY(allocations_mutex_);

mutable Mutex gc_mutex_;
std::vector<Node*> subtrees_to_gc_ GUARDED_BY(gc_mutex_);
};

Node::Pool::FreeNode* Node::Pool::UnrollNodeTree(FreeNode* node) {
if (!node->node.child_) return node;
FreeNode* prev = node;
for (Node* iter = node->node.child_; iter; iter = iter->sibling_) {
FreeNode* next = reinterpret_cast<FreeNode*>(iter);
prev->next = next;
prev = UnrollNodeTree(next);
}
return prev;
}

void Node::Pool::GarbageCollect() {
while (true) {
Node* node_to_gc = nullptr;
{
Mutex::Lock lock(gc_mutex_);
if (subtrees_to_gc_.empty()) return;
node_to_gc = subtrees_to_gc_.back();
subtrees_to_gc_.pop_back();
}
FreeNode* head = reinterpret_cast<FreeNode*>(node_to_gc);
FreeNode* tail = UnrollNodeTree(head);
{
Mutex::Lock lock(mutex_);
tail->next = free_list_;
free_list_ = head;
}
}
}

Node* Node::Pool::AllocateNode() {
while (true) {
Node* result = nullptr;
Expand Down Expand Up @@ -111,17 +145,6 @@ Node* Node::Pool::AllocateNode() {
}
}

void Node::Pool::ReleaseNode(Node* node) {
Mutex::Lock lock(mutex_);
ReleaseNodeInternal(node);
}

void Node::Pool::ReleaseNodeInternal(Node* node) REQUIRES(mutex_) {
auto* free_node = reinterpret_cast<FreeNode*>(node);
free_node->next = free_list_;
free_list_ = free_node;
}

void Node::Pool::AllocateNewBatch() REQUIRES(allocations_mutex_) {
allocations_.emplace_back(std::make_unique<FreeNode[]>(kAllocationSize));

Expand All @@ -134,11 +157,6 @@ void Node::Pool::AllocateNewBatch() REQUIRES(allocations_mutex_) {
}

void Node::Pool::ReleaseChildren(Node* node) {
Mutex::Lock lock(mutex_);
ReleaseChildrenInternal(node);
}

void Node::Pool::ReleaseChildrenInternal(Node* node) REQUIRES(mutex_) {
Node* next = node->child_;
// Iterating manually rather than with iterator, as node is released in the
// middle and can be taken by other threads, so we have to be careful.
Expand All @@ -147,7 +165,7 @@ void Node::Pool::ReleaseChildrenInternal(Node* node) REQUIRES(mutex_) {
// Getting next after releasing node, as otherwise it can be reallocated
// and overwritten.
next = next->sibling_;
ReleaseSubtreeInternal(iter);
ReleaseSubtree(iter);
}
node->child_ = nullptr;
}
Expand All @@ -173,17 +191,18 @@ void Node::Pool::ReleaseAllChildrenExceptOne(Node* root, Node* subtree) {
}

void Node::Pool::ReleaseSubtree(Node* node) {
Mutex::Lock lock(mutex_);
ReleaseSubtreeInternal(node);
}

void Node::Pool::ReleaseSubtreeInternal(Node* node) REQUIRES(mutex_) {
ReleaseChildrenInternal(node);
ReleaseNodeInternal(node);
Mutex::Lock lock(gc_mutex_);
subtrees_to_gc_.push_back(node);
}

Node::Pool gNodePool;

void GarbageCollectNodePool() { gNodePool.GarbageCollect(); }

/////////////////////////////////////////////////////////////////////////
// Node
/////////////////////////////////////////////////////////////////////////

Node* Node::CreateChild(Move m) {
Node* new_node = gNodePool.AllocateNode();
new_node->parent_ = this;
Expand Down Expand Up @@ -287,7 +306,8 @@ V3TrainingData Node::GetV3TrainingData(GameResult game_result,
result.version = 3;

// Populate probabilities.
float total_n = n_ - 1; // First visit was expansion of it inself.
float total_n =
static_cast<float>(n_ - 1); // First visit was expansion of it inself.
std::memset(result.probabilities, 0, sizeof(result.probabilities));
for (Node* iter : Children()) {
result.probabilities[iter->move_.as_nn_index()] = iter->n_ / total_n;
Expand Down Expand Up @@ -324,7 +344,7 @@ V3TrainingData Node::GetV3TrainingData(GameResult game_result,
return result;
}

void NodeTree::MakeMove(Move move) {
void NodeTree::MakeMove(Move move, bool auto_garbage_collect) {
if (HeadPosition().IsBlackToMove()) move.Mirror();

Node* new_head = nullptr;
Expand All @@ -337,10 +357,13 @@ void NodeTree::MakeMove(Move move) {
gNodePool.ReleaseAllChildrenExceptOne(current_head_, new_head);
current_head_ = new_head ? new_head : current_head_->CreateChild(move);
history_.Append(move);

if (auto_garbage_collect) GarbageCollectNodePool();
}

void NodeTree::ResetToPosition(const std::string& starting_fen,
const std::vector<Move>& moves) {
const std::vector<Move>& moves,
bool auto_garbage_collect) {
ChessBoard starting_board;
int no_capture_ply;
int full_moves;
Expand All @@ -363,7 +386,7 @@ void NodeTree::ResetToPosition(const std::string& starting_fen,
current_head_ = gamebegin_node_;
bool seen_old_head = (gamebegin_node_ == old_head);
for (const auto& move : moves) {
MakeMove(move);
MakeMove(move, false);
if (old_head == current_head_) seen_old_head = true;
}

Expand All @@ -374,6 +397,8 @@ void NodeTree::ResetToPosition(const std::string& starting_fen,
gNodePool.ReleaseChildren(current_head_);
current_head_->ResetStats();
}

if (auto_garbage_collect) GarbageCollectNodePool();
}

void NodeTree::DeallocateTree() {
Expand Down
Loading

0 comments on commit 2321011

Please sign in to comment.