Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix castling handling, update default params. #12

Merged
merged 9 commits into from
Jun 2, 2018
10 changes: 9 additions & 1 deletion src/chess/bitboard.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ std::vector<unsigned short> BuildMoveIndices() {
}

const std::vector<unsigned short> kMoveToIdx = BuildMoveIndices();
const int kKingCastleIndex =
kMoveToIdx[BoardSquare("e1").as_int() * 64 + BoardSquare("h1").as_int()];
const int kQueenCastleIndex =
kMoveToIdx[BoardSquare("e1").as_int() * 64 + BoardSquare("a1").as_int()];
} // namespace

Move::Move(const std::string& str, bool black) {
Expand Down Expand Up @@ -307,6 +311,10 @@ uint16_t Move::as_packed_int() const {
}
}

uint16_t Move::as_nn_index() const { return kMoveToIdx[as_packed_int()]; }
uint16_t Move::as_nn_index() const {
if (!castling_) return kMoveToIdx[as_packed_int()];
if (from_.col() < to_.col()) return kKingCastleIndex;
return kQueenCastleIndex;
}

} // namespace lczero
6 changes: 1 addition & 5 deletions src/chess/bitboard.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,7 @@ class Move {
}

std::string as_string() const {
BoardSquare to = to_;
if (castling_) {
to = BoardSquare(to.row(), (to.col() == 7) ? 6 : 2);
}
std::string res = from_.as_string() + to.as_string();
std::string res = from_.as_string() + to_.as_string();
switch (promotion_) {
case Promotion::None:
return res;
Expand Down
16 changes: 5 additions & 11 deletions src/chess/board.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ MoveList ChessBoard::GeneratePseudolegalMoves() const {
}
}
if (can_castle) {
result.emplace_back(source, BoardSquare(0, 7));
result.emplace_back(source, BoardSquare(0, 6));
result.back().SetCastling();
}
}
Expand All @@ -230,7 +230,7 @@ MoveList ChessBoard::GeneratePseudolegalMoves() const {
}
}
if (can_castle) {
result.emplace_back(source, BoardSquare(0, 0));
result.emplace_back(source, BoardSquare(0, 2));
result.back().SetCastling();
}
}
Expand Down Expand Up @@ -344,9 +344,9 @@ bool ChessBoard::ApplyMove(Move move) {
const auto to_row = to.row();
const auto to_col = to.col();

// Remove our piece from old location, but not put to destination
// (for the case of castling).
// Move in our pieces.
our_pieces_.reset(from);
our_pieces_.set(to);

// Remove captured piece
bool reset_50_moves = their_pieces_.get(to);
Expand Down Expand Up @@ -378,26 +378,20 @@ bool ChessBoard::ApplyMove(Move move) {
if (from == our_king_) {
castlings_.reset_we_can_00();
castlings_.reset_we_can_000();
our_king_ = to;
// Castling
if (to_col - from_col > 1) {
// 0-0
our_pieces_.reset(7);
rooks_.reset(7);
our_pieces_.set(5);
rooks_.set(5);
our_king_ = BoardSquare(0, 6); /* g8 */
our_pieces_.set(our_king_);
} else if (from_col - to_col > 1) {
// 0-0-0
our_pieces_.reset(0);
rooks_.reset(0);
our_pieces_.set(3);
rooks_.set(3);
our_king_ = BoardSquare(0, 2); /* c8 */
our_pieces_.set(our_king_);
} else {
our_king_ = to;
our_pieces_.set(to);
}
return reset_50_moves;
}
Expand Down
32 changes: 26 additions & 6 deletions src/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,17 @@ void EngineController::PopulateOptions(OptionsParser* options) {
options->Add<ChoiceOption>(kNnBackendStr, backends, "backend") =
backends.empty() ? "<none>" : backends[0];
options->Add<StringOption>(kNnBackendOptionsStr, "backend-opts");
options->Add<FloatOption>(kSlowMoverStr, 0.0, 100.0, "slowmover") = 2.2;
options->Add<FloatOption>(kSlowMoverStr, 0.0f, 100.0f, "slowmover") = 2.2f;
options->Add<IntOption>(kMoveOverheadStr, 0, 10000, "move-overhead") = 100;

Search::PopulateUciParams(options);

auto defaults = options->GetMutableDefaultsOptions();

defaults->Set<int>(Search::kMiniBatchSizeStr, 256); // Minibatch = 256
defaults->Set<float>(Search::kFpuReductionStr, 0.2f); // FPU reduction = 0.2
defaults->Set<float>(Search::kCpuctStr, 3.1f); // CPUCT = 3.1
defaults->Set<int>(Search::kAllowedNodeCollisionsStr, 32); // Node collisions
}

SearchLimits EngineController::PopulateSearchLimits(int /*ply*/, bool is_black,
Expand All @@ -87,8 +94,9 @@ SearchLimits EngineController::PopulateSearchLimits(int /*ply*/, bool is_black,
float slowmover = options_.Get<float>(kSlowMoverStr);
int64_t move_overhead = options_.Get<int>(kMoveOverheadStr);
// Total time till control including increments.
auto total_moves_time = std::max(
int64_t{0}, time + increment * (movestogo - 1) - move_overhead * movestogo);
auto total_moves_time =
std::max(int64_t{0},
time + increment * (movestogo - 1) - move_overhead * movestogo);

const int kSmartPruningToleranceMs = 200;

Expand All @@ -99,10 +107,12 @@ SearchLimits EngineController::PopulateSearchLimits(int /*ply*/, bool is_black,
// reduce it.
if (slowmover < 1.0 || this_move_time > kSmartPruningToleranceMs) {
// Budget X*slowmover for current move, X*1.0 for the rest.
this_move_time = total_moves_time / (movestogo - 1 + slowmover) * slowmover;
this_move_time = static_cast<int64_t>(
total_moves_time / (movestogo - 1 + slowmover) * slowmover);
}
// Make sure we don't exceed current time limit with what we calculated.
limits.time_ms = std::max(int64_t{0}, std::min(this_move_time, time - move_overhead));
limits.time_ms =
std::max(int64_t{0}, std::min(this_move_time, time - move_overhead));
return limits;
}

Expand Down Expand Up @@ -134,6 +144,11 @@ void EngineController::UpdateNetwork() {

void EngineController::SetCacheSize(int size) { cache_.SetCapacity(size); }

void EngineController::EnsureReady() {
GarbageCollectNodePool();
std::unique_lock<RpSharedMutex> lock(busy_mutex_);
}

void EngineController::NewGame() {
SharedLock lock(busy_mutex_);
cache_.Clear();
Expand All @@ -151,7 +166,7 @@ void EngineController::SetPosition(const std::string& fen,

std::vector<Move> moves;
for (const auto& move : moves_str) moves.emplace_back(move);
tree_->ResetToPosition(fen, moves);
tree_->ResetToPosition(fen, moves, false);
UpdateNetwork();
}

Expand All @@ -163,6 +178,11 @@ void EngineController::Go(const GoParams& params) {
auto limits = PopulateSearchLimits(tree_->GetPlyCount(),
tree_->IsBlackToMove(), params);

BestMoveInfo::Callback best_move_callback = [this](const BestMoveInfo& info) {
best_move_callback_(info);
GarbageCollectNodePool();
};

search_ =
std::make_unique<Search>(*tree_, network_.get(), best_move_callback_,
info_callback_, limits, options_, &cache_);
Expand Down
2 changes: 1 addition & 1 deletion src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class EngineController {
void PopulateOptions(OptionsParser* options);

// Blocks.
void EnsureReady() { std::unique_lock<RpSharedMutex> lock(busy_mutex_); }
void EnsureReady();

// Must not block.
void NewGame();
Expand Down
91 changes: 58 additions & 33 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,21 @@ class Node::Pool {
public:
// Allocates a new node and initializes it with all zeros.
Node* AllocateNode();
// Return node to the pool.
void ReleaseNode(Node*);

// Release* function don't release trees immediately but rather schedule
// release until when GarbageCollect() is called.
// Releases all children of the node, except specified. Also updates pointers
// accordingly.
void ReleaseAllChildrenExceptOne(Node* root, Node* subtree);
// Releases all children, but doesn't release the node isself.
void ReleaseChildren(Node*);
// Releases all children and the node itself;
void ReleaseSubtree(Node*);
// Really releases subtrees makerd for release earlier.
void GarbageCollect();

private:
void AllocateNewBatch();
void ReleaseNodeInternal(Node*);
void ReleaseChildrenInternal(Node*);
void ReleaseSubtreeInternal(Node*);

union FreeNode {
FreeNode* next;
Expand All @@ -65,6 +64,8 @@ class Node::Pool {
FreeNode() {}
};

static FreeNode* UnrollNodeTree(FreeNode* node);

mutable Mutex mutex_;
// Linked list of free nodes.
FreeNode* free_list_ GUARDED_BY(mutex_) = nullptr;
Expand All @@ -74,8 +75,41 @@ class Node::Pool {
FreeNode* reserve_list_ GUARDED_BY(allocations_mutex_) = nullptr;
std::vector<std::unique_ptr<FreeNode[]>> allocations_
GUARDED_BY(allocations_mutex_);

mutable Mutex gc_mutex_;
std::vector<Node*> subtrees_to_gc_ GUARDED_BY(gc_mutex_);
};

Node::Pool::FreeNode* Node::Pool::UnrollNodeTree(FreeNode* node) {
if (!node->node.child_) return node;
FreeNode* prev = node;
for (Node* iter = node->node.child_; iter; iter = iter->sibling_) {
FreeNode* next = reinterpret_cast<FreeNode*>(iter);
prev->next = next;
prev = UnrollNodeTree(next);
}
return prev;
}

void Node::Pool::GarbageCollect() {
while (true) {
Node* node_to_gc = nullptr;
{
Mutex::Lock lock(gc_mutex_);
if (subtrees_to_gc_.empty()) return;
node_to_gc = subtrees_to_gc_.back();
subtrees_to_gc_.pop_back();
}
FreeNode* head = reinterpret_cast<FreeNode*>(node_to_gc);
FreeNode* tail = UnrollNodeTree(head);
{
Mutex::Lock lock(mutex_);
tail->next = free_list_;
free_list_ = head;
}
}
}

Node* Node::Pool::AllocateNode() {
while (true) {
Node* result = nullptr;
Expand Down Expand Up @@ -111,17 +145,6 @@ Node* Node::Pool::AllocateNode() {
}
}

void Node::Pool::ReleaseNode(Node* node) {
Mutex::Lock lock(mutex_);
ReleaseNodeInternal(node);
}

void Node::Pool::ReleaseNodeInternal(Node* node) REQUIRES(mutex_) {
auto* free_node = reinterpret_cast<FreeNode*>(node);
free_node->next = free_list_;
free_list_ = free_node;
}

void Node::Pool::AllocateNewBatch() REQUIRES(allocations_mutex_) {
allocations_.emplace_back(std::make_unique<FreeNode[]>(kAllocationSize));

Expand All @@ -134,11 +157,6 @@ void Node::Pool::AllocateNewBatch() REQUIRES(allocations_mutex_) {
}

void Node::Pool::ReleaseChildren(Node* node) {
Mutex::Lock lock(mutex_);
ReleaseChildrenInternal(node);
}

void Node::Pool::ReleaseChildrenInternal(Node* node) REQUIRES(mutex_) {
Node* next = node->child_;
// Iterating manually rather than with iterator, as node is released in the
// middle and can be taken by other threads, so we have to be careful.
Expand All @@ -147,7 +165,7 @@ void Node::Pool::ReleaseChildrenInternal(Node* node) REQUIRES(mutex_) {
// Getting next after releasing node, as otherwise it can be reallocated
// and overwritten.
next = next->sibling_;
ReleaseSubtreeInternal(iter);
ReleaseSubtree(iter);
}
node->child_ = nullptr;
}
Expand All @@ -173,17 +191,18 @@ void Node::Pool::ReleaseAllChildrenExceptOne(Node* root, Node* subtree) {
}

void Node::Pool::ReleaseSubtree(Node* node) {
Mutex::Lock lock(mutex_);
ReleaseSubtreeInternal(node);
}

void Node::Pool::ReleaseSubtreeInternal(Node* node) REQUIRES(mutex_) {
ReleaseChildrenInternal(node);
ReleaseNodeInternal(node);
Mutex::Lock lock(gc_mutex_);
subtrees_to_gc_.push_back(node);
}

Node::Pool gNodePool;

void GarbageCollectNodePool() { gNodePool.GarbageCollect(); }

/////////////////////////////////////////////////////////////////////////
// Node
/////////////////////////////////////////////////////////////////////////

Node* Node::CreateChild(Move m) {
Node* new_node = gNodePool.AllocateNode();
new_node->parent_ = this;
Expand Down Expand Up @@ -287,7 +306,8 @@ V3TrainingData Node::GetV3TrainingData(GameResult game_result,
result.version = 3;

// Populate probabilities.
float total_n = n_ - 1; // First visit was expansion of it inself.
float total_n =
static_cast<float>(n_ - 1); // First visit was expansion of it inself.
std::memset(result.probabilities, 0, sizeof(result.probabilities));
for (Node* iter : Children()) {
result.probabilities[iter->move_.as_nn_index()] = iter->n_ / total_n;
Expand Down Expand Up @@ -324,7 +344,7 @@ V3TrainingData Node::GetV3TrainingData(GameResult game_result,
return result;
}

void NodeTree::MakeMove(Move move) {
void NodeTree::MakeMove(Move move, bool auto_garbage_collect) {
if (HeadPosition().IsBlackToMove()) move.Mirror();

Node* new_head = nullptr;
Expand All @@ -337,10 +357,13 @@ void NodeTree::MakeMove(Move move) {
gNodePool.ReleaseAllChildrenExceptOne(current_head_, new_head);
current_head_ = new_head ? new_head : current_head_->CreateChild(move);
history_.Append(move);

if (auto_garbage_collect) GarbageCollectNodePool();
}

void NodeTree::ResetToPosition(const std::string& starting_fen,
const std::vector<Move>& moves) {
const std::vector<Move>& moves,
bool auto_garbage_collect) {
ChessBoard starting_board;
int no_capture_ply;
int full_moves;
Expand All @@ -363,7 +386,7 @@ void NodeTree::ResetToPosition(const std::string& starting_fen,
current_head_ = gamebegin_node_;
bool seen_old_head = (gamebegin_node_ == old_head);
for (const auto& move : moves) {
MakeMove(move);
MakeMove(move, false);
if (old_head == current_head_) seen_old_head = true;
}

Expand All @@ -374,6 +397,8 @@ void NodeTree::ResetToPosition(const std::string& starting_fen,
gNodePool.ReleaseChildren(current_head_);
current_head_->ResetStats();
}

if (auto_garbage_collect) GarbageCollectNodePool();
}

void NodeTree::DeallocateTree() {
Expand Down
Loading