Skip to content

Commit

Permalink
Merge pull request #1 from LeelaChessZero/master
Browse files Browse the repository at this point in the history
Catch up to master
  • Loading branch information
AlexisOlson authored Apr 18, 2020
2 parents a56709c + 2f46f4d commit 58d97b0
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 84 deletions.
8 changes: 4 additions & 4 deletions src/chess/board.cc
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ MoveList ChessBoard::GenerateLegalMoves() const {
return result;
}

void ChessBoard::SetFromFen(const std::string& fen, int* no_capture_ply,
void ChessBoard::SetFromFen(const std::string& fen, int* rule50_ply,
int* moves) {
Clear();
int row = 7;
Expand All @@ -980,10 +980,10 @@ void ChessBoard::SetFromFen(const std::string& fen, int* no_capture_ply,
string who_to_move;
string castlings;
string en_passant;
int no_capture_halfmoves;
int rule50_halfmoves;
int total_moves;
fen_str >> board >> who_to_move >> castlings >> en_passant >>
no_capture_halfmoves >> total_moves;
rule50_halfmoves >> total_moves;

if (!fen_str) throw Exception("Bad fen string: " + fen);

Expand Down Expand Up @@ -1096,7 +1096,7 @@ void ChessBoard::SetFromFen(const std::string& fen, int* no_capture_ply,
if (who_to_move == "b" || who_to_move == "B") {
Mirror();
}
if (no_capture_ply) *no_capture_ply = no_capture_halfmoves;
if (rule50_ply) *rule50_ply = rule50_halfmoves;
if (moves) *moves = total_moves;
}

Expand Down
4 changes: 2 additions & 2 deletions src/chess/board.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ class ChessBoard {
static const BitBoard kPawnMask;

// Sets position from FEN string.
// If @no_capture_ply and @moves are not nullptr, they are filled with number
// If @rule50_ply and @moves are not nullptr, they are filled with number
// of moves without capture and number of full moves since the beginning of
// the game.
void SetFromFen(const std::string& fen, int* no_capture_ply = nullptr,
void SetFromFen(const std::string& fen, int* rule50_ply = nullptr,
int* moves = nullptr);
// Nullifies the whole structure.
void Clear();
Expand Down
30 changes: 18 additions & 12 deletions src/chess/position.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@
namespace lczero {

Position::Position(const Position& parent, Move m)
: no_capture_ply_(parent.no_capture_ply_ + 1),
: rule50_ply_(parent.rule50_ply_ + 1),
ply_count_(parent.ply_count_ + 1) {
them_board_ = parent.us_board_;
const bool capture = them_board_.ApplyMove(m);
const bool is_zeroing = them_board_.ApplyMove(m);
us_board_ = them_board_;
us_board_.Mirror();
if (capture) no_capture_ply_ = 0;
if (is_zeroing) rule50_ply_ = 0;
}

Position::Position(const ChessBoard& board, int no_capture_ply, int game_ply)
: no_capture_ply_(no_capture_ply), repetitions_(0), ply_count_(game_ply) {
Position::Position(const ChessBoard& board, int rule50_ply, int game_ply)
: rule50_ply_(rule50_ply), repetitions_(0), ply_count_(game_ply) {
us_board_ = board;
them_board_ = board;
them_board_.Mirror();
Expand All @@ -54,6 +54,12 @@ uint64_t Position::Hash() const {

std::string Position::DebugString() const { return us_board_.DebugString(); }

GameResult operator-(const GameResult& res) {
return res == GameResult::BLACK_WON
? GameResult::WHITE_WON
: res == GameResult::WHITE_WON ? GameResult::BLACK_WON : res;
}

GameResult PositionHistory::ComputeGameResult() const {
const auto& board = Last().GetBoard();
auto legal_moves = board.GenerateLegalMoves();
Expand All @@ -67,17 +73,17 @@ GameResult PositionHistory::ComputeGameResult() const {
}

if (!board.HasMatingMaterial()) return GameResult::DRAW;
if (Last().GetNoCaptureNoPawnPly() >= 100) return GameResult::DRAW;
if (Last().GetRule50Ply() >= 100) return GameResult::DRAW;
if (Last().GetGamePly() >= 450) return GameResult::DRAW;
if (Last().GetRepetitions() >= 2) return GameResult::DRAW;

return GameResult::UNDECIDED;
}

void PositionHistory::Reset(const ChessBoard& board, int no_capture_ply,
void PositionHistory::Reset(const ChessBoard& board, int rule50_ply,
int game_ply) {
positions_.clear();
positions_.emplace_back(board, no_capture_ply, game_ply);
positions_.emplace_back(board, rule50_ply, game_ply);
}

void PositionHistory::Append(Move m) {
Expand All @@ -91,14 +97,14 @@ void PositionHistory::Append(Move m) {
int PositionHistory::ComputeLastMoveRepetitions() const {
const auto& last = positions_.back();
// TODO(crem) implement hash/cache based solution.
if (last.GetNoCaptureNoPawnPly() < 4) return 0;
if (last.GetRule50Ply() < 4) return 0;

for (int idx = positions_.size() - 3; idx >= 0; idx -= 2) {
const auto& pos = positions_[idx];
if (pos.GetBoard() == last.GetBoard()) {
return 1 + pos.GetRepetitions();
}
if (pos.GetNoCaptureNoPawnPly() < 2) return 0;
if (pos.GetRule50Ply() < 2) return 0;
}
return 0;
}
Expand All @@ -107,7 +113,7 @@ bool PositionHistory::DidRepeatSinceLastZeroingMove() const {
for (auto iter = positions_.rbegin(), end = positions_.rend(); iter != end;
++iter) {
if (iter->GetRepetitions() > 0) return true;
if (iter->GetNoCaptureNoPawnPly() == 0) return false;
if (iter->GetRule50Ply() == 0) return false;
}
return false;
}
Expand All @@ -119,7 +125,7 @@ uint64_t PositionHistory::HashLast(int positions) const {
if (!positions--) break;
hash = HashCat(hash, iter->Hash());
}
return HashCat(hash, Last().GetNoCaptureNoPawnPly());
return HashCat(hash, Last().GetRule50Ply());
}

} // namespace lczero
12 changes: 7 additions & 5 deletions src/chess/position.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Position {
// From parent position and move.
Position(const Position& parent, Move m);
// From particular position.
Position(const ChessBoard& board, int no_capture_ply, int game_ply);
Position(const ChessBoard& board, int rule50_ply, int game_ply);

uint64_t Hash() const;
bool IsBlackToMove() const { return us_board_.flipped(); }
Expand All @@ -54,7 +54,7 @@ class Position {
void SetRepetitions(int repetitions) { repetitions_ = repetitions; }

// Number of ply with no captures and pawn moves.
int GetNoCaptureNoPawnPly() const { return no_capture_ply_; }
int GetRule50Ply() const { return rule50_ply_; }

// Gets board from the point of view of player to move.
const ChessBoard& GetBoard() const { return us_board_; }
Expand All @@ -70,14 +70,16 @@ class Position {
ChessBoard them_board_;

// How many half-moves without capture or pawn move was there.
int no_capture_ply_ = 0;
int rule50_ply_ = 0;
// How many repetitions this position had before. For new positions it's 0.
int repetitions_;
// number of half-moves since beginning of the game.
int ply_count_ = 0;
};

enum class GameResult { UNDECIDED, WHITE_WON, DRAW, BLACK_WON };
// These are ordered so max() prefers the best result.
enum class GameResult : uint8_t { UNDECIDED, BLACK_WON, DRAW, WHITE_WON };
GameResult operator-(const GameResult& res);

class PositionHistory {
public:
Expand All @@ -102,7 +104,7 @@ class PositionHistory {
int GetLength() const { return positions_.size(); }

// Resets the position to a given state.
void Reset(const ChessBoard& board, int no_capture_ply, int game_ply);
void Reset(const ChessBoard& board, int rule50_ply, int game_ply);

// Appends a position to history.
void Append(Move m);
Expand Down
29 changes: 21 additions & 8 deletions src/mcts/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,14 @@ std::string Node::DebugString() const {
<< " Parent:" << parent_ << " Index:" << index_
<< " Child:" << child_.get() << " Sibling:" << sibling_.get()
<< " WL:" << wl_ << " N:" << n_ << " N_:" << n_in_flight_
<< " Edges:" << edges_.size();
<< " Edges:" << edges_.size()
<< " Bounds:" << static_cast<int>(lower_bound_) - 2 << ","
<< static_cast<int>(upper_bound_) - 2;
return oss.str();
}

void Node::MakeTerminal(GameResult result, float plies_left, Terminal type) {
SetBounds(result, result);
terminal_type_ = type;
m_ = plies_left;
if (result == GameResult::DRAW) {
Expand Down Expand Up @@ -257,6 +260,11 @@ void Node::MakeNotTerminal() {
}
}

void Node::SetBounds(GameResult lower, GameResult upper) {
lower_bound_ = lower;
upper_bound_ = upper;
}

bool Node::TryStartScoreUpdate() {
if (n_ == 0 && n_in_flight_ > 0) return false;
++n_in_flight_;
Expand Down Expand Up @@ -380,18 +388,21 @@ V5TrainingData Node::GetV5TrainingData(
// Other params.
if (input_format ==
pblczero::NetworkFormat::INPUT_112_WITH_CANONICALIZATION) {
result.side_to_move = position.GetBoard().en_passant().as_int() >> 56;
result.side_to_move_or_enpassant =
position.GetBoard().en_passant().as_int() >> 56;
if ((transform & FlipTransform) != 0) {
result.side_to_move = ReverseBitsInBytes(result.side_to_move);
result.side_to_move_or_enpassant =
ReverseBitsInBytes(result.side_to_move_or_enpassant);
}
// Send transform in deprecated move count so rescorer can reverse it to
// calculate the actual move list from the input data.
result.deprecated_move_count = transform;
result.invariance_info =
transform | (position.IsBlackToMove() ? (1u << 7) : 0u);
} else {
result.side_to_move = position.IsBlackToMove() ? 1 : 0;
result.deprecated_move_count = 0;
result.side_to_move_or_enpassant = position.IsBlackToMove() ? 1 : 0;
result.invariance_info = 0;
}
result.rule50_count = position.GetNoCaptureNoPawnPly();
result.rule50_count = position.GetRule50Ply();

// Game result.
if (game_result == GameResult::WHITE_WON) {
Expand Down Expand Up @@ -468,7 +479,9 @@ bool NodeTree::ResetToPosition(const std::string& starting_fen,
int no_capture_ply;
int full_moves;
starting_board.SetFromFen(starting_fen, &no_capture_ply, &full_moves);
if (gamebegin_node_ && history_.Starting().GetBoard() != starting_board) {
if (gamebegin_node_ &&
(history_.Starting().GetBoard() != starting_board ||
history_.Starting().GetRule50Ply() != no_capture_ply)) {
// Completely different position.
DeallocateTree();
}
Expand Down
25 changes: 20 additions & 5 deletions src/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,15 @@ class Node {
using Iterator = Edge_Iterator<false>;
using ConstIterator = Edge_Iterator<true>;

enum class Terminal : uint8_t { NonTerminal, Terminal, Tablebase };
enum class Terminal : uint8_t { NonTerminal, EndOfGame, Tablebase };

// Takes pointer to a parent node and own index in a parent.
Node(Node* parent, uint16_t index) : parent_(parent), index_(index) {}
Node(Node* parent, uint16_t index)
: parent_(parent),
index_(index),
terminal_type_(Terminal::NonTerminal),
lower_bound_(GameResult::BLACK_WON),
upper_bound_(GameResult::WHITE_WON) {}

// Allocates a new edge and a new node. The node has to be no edges before
// that.
Expand Down Expand Up @@ -166,13 +171,16 @@ class Node {
// Returns whether the node is known to be draw/lose/win.
bool IsTerminal() const { return terminal_type_ != Terminal::NonTerminal; }
bool IsTbTerminal() const { return terminal_type_ == Terminal::Tablebase; }
typedef std::pair<GameResult, GameResult> Bounds;
Bounds GetBounds() const { return {lower_bound_, upper_bound_}; }
uint16_t GetNumEdges() const { return edges_.size(); }

// Makes the node terminal and sets it's score.
void MakeTerminal(GameResult result, float plies_left = 0.0f,
Terminal type = Terminal::Terminal);
Terminal type = Terminal::EndOfGame);
// Makes the node not terminal and updates its visits.
void MakeNotTerminal();
void SetBounds(GameResult lower, GameResult upper);

// If this node is not in the process of being expanded by another thread
// (which can happen only if n==0 and n-in-flight==1), mark the node as
Expand Down Expand Up @@ -301,9 +309,12 @@ class Node {
// Index of this node is parent's edge list.
uint16_t index_;

// 1 byte fields.
// Bit fields using parts of uint8_t fields initialized in the constructor.
// Whether or not this node end game (with a winning of either sides or draw).
Terminal terminal_type_ = Terminal::NonTerminal;
Terminal terminal_type_ : 2;
// Best and worst result for this node.
GameResult lower_bound_ : 2;
GameResult upper_bound_ : 2;

// TODO(mooskagh) Unfriend NodeTree.
friend class NodeTree;
Expand Down Expand Up @@ -372,6 +383,10 @@ class EdgeAndNode {
// Whether the node is known to be terminal.
bool IsTerminal() const { return node_ ? node_->IsTerminal() : false; }
bool IsTbTerminal() const { return node_ ? node_->IsTbTerminal() : false; }
Node::Bounds GetBounds() const {
return node_ ? node_->GetBounds()
: Node::Bounds{GameResult::BLACK_WON, GameResult::WHITE_WON};
}

// Edge related getters.
float GetP() const { return edge_->GetP(); }
Expand Down
Loading

0 comments on commit 58d97b0

Please sign in to comment.