Skip to content

Commit

Permalink
squashed a bunch of changes from temp branch
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunaugh committed Feb 24, 2024
1 parent 1647d11 commit f9222d9
Show file tree
Hide file tree
Showing 15 changed files with 130 additions and 128 deletions.
29 changes: 29 additions & 0 deletions ChessNotationConverter/Helpers/TrainingDataHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
using ChessNotationConverter.Models;

namespace ChessNotationConverter.Helpers
{
internal static class TrainingDataHelper
{
private static string outFile = "../../../../train_data.txt";

internal static void WritePositionsToTrainingFile(List<Evaluation> evaluations)
{
using var sw = new StreamWriter(outFile, true);

foreach (var evaluation in evaluations)
{
if (evaluation.Position.BlackToMove)
{
evaluation.Position = evaluation.Position.Invert();
if (evaluation.Score != 0)
{
evaluation.Score = -evaluation.Score; // no need to invert 0
}
}

var outputStr = evaluation.Position.Serialize(true) + evaluation.Score;
sw.WriteLine(outputStr);
}
}
}
}
4 changes: 2 additions & 2 deletions ChessNotationConverter/Models/Evaluation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
{
internal class Evaluation
{
internal Position Position { get; private set; }
internal float Score { get; private set; }
internal Position Position { get; set; }
internal float Score { get; set; }
internal int Hash { get; }
internal Evaluation(Position position, float score)
{
Expand Down
37 changes: 31 additions & 6 deletions ChessNotationConverter/Models/Position.cs
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
using System.Text;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Text;

namespace ChessNotationConverter.Models
{
internal class Position
{
private Player player;
internal bool WhiteToMove => player == Player.White;
internal bool BlackToMove => player == Player.Black;
private Player playerToMove;
internal bool WhiteToMove => playerToMove == Player.White;
internal bool BlackToMove => playerToMove == Player.Black;
internal int[,] Matrix { get; set; }

public Position(Player player)
{
this.player = player;
this.playerToMove = player;

// Starting position
Matrix = new int[8, 8]
Expand All @@ -27,9 +29,14 @@ public Position(Player player)
{ -5,-2,-3,-8,-9,-3,-2,-5 },
};
}
public Position(int[,] board, Player player)
{
playerToMove = player;
Matrix = board;
}
public Position(Player player, Position previous, string moveStr)
{
this.player = player;
playerToMove = player;
Matrix = previous.MakeMove(moveStr);
}

Expand Down Expand Up @@ -93,5 +100,23 @@ internal string Serialize(bool trailingComma = true)

return trailingComma ? sb.ToString() : sb.ToString().TrimEnd(',');
}

internal Position Invert()
{
var newPosition = DeepCopy(Matrix);

for (int i = 0; i < 8; i++)
{
for (int j = 0; j < 8; j++)
{
if (newPosition[i,j] != 0)
{
newPosition[i,j] = -newPosition[i,j];
}
}
}

return new Position(newPosition, playerToMove);
}
}
}
44 changes: 33 additions & 11 deletions ChessNotationConverter/PositionEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,43 @@ public static class PositionEvaluator
{
internal static Evaluation EvaluatePosition(Game game, int positionIndex)
{
float result;
var result = (float)0;

// this assumes white won
// using logarithmic function to avoid late-middle game confusion
result = 1 - (float)Math.Log(game.MoveCount + 1 - positionIndex) / (float)Math.Log(game.MoveCount + 1);

if(game.Outcome == -1)
if(game.Outcome == 0)
{
// negate for black
result = -result;
// for a draw, every move was pretty good
result = 0.15f;
}
else if (game.Outcome == 0)
else
{
// always 0 for draw - assume no difference in eval for white/black
result = 0f;
// early game
if (positionIndex < 21)
{
// If the move was played, we'll say it was good
result = 0.1f;
}
// middle/end game
else
{
// this assumes white won
// using logarithmic function to avoid late-middle game confusion
//var numerator = (float)Math.Log(game.MoveCount + 1 - positionIndex);
//var denominator = (float)Math.Log(game.MoveCount + 1);
//result = 1 - numerator / denominator;

int adjustedMoveCount = Math.Max(0, game.MoveCount - 20);
int adjustedPositionIndex = Math.Max(0, positionIndex - 20);

var numerator = (float)Math.Log(adjustedMoveCount + 1 - adjustedPositionIndex);
var denominator = (float)Math.Log(adjustedMoveCount + 1);
result = 1 - numerator / denominator;

if (game.Outcome == -1)
{
// negate for black
result = -result;
}
}
}

return new Evaluation(game.Positions[positionIndex], result);
Expand Down
30 changes: 6 additions & 24 deletions ChessNotationConverter/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
// https://www.chess.com/analysis?tab=analysis

using ChessNotationConverter;
using ChessNotationConverter.Helpers;
using ChessNotationConverter.Models;

const string filePath = "C:\\Users\\sbrunaugh\\Downloads\\all_with_filtered_anotations_since1998.txt\\all_with_filtered_anotations_since1998.txt";
const string whiteOutFile = "../../../../train_data_white.txt";
const string blackOutFile = "../../../../train_data_black.txt";
const string filePath = "../../../../all_with_filtered_anotations_since1998.txt";

string line;
int lineNumber = 0;
var games = new List<Game>();
Expand Down Expand Up @@ -65,32 +65,14 @@

games.Clear();

using (var whiteSw = new StreamWriter(whiteOutFile, true))
{
using (var blackSw = new StreamWriter(blackOutFile, true))
{
foreach (var evaluation in evaluations)
{
var outputStr = evaluation.Position.Serialize(true) + evaluation.Score;

if(evaluation.Position.WhiteToMove)
{
whiteSw.WriteLine(outputStr);
}
else
{
blackSw.WriteLine(outputStr);
}
}
}
}
TrainingDataHelper.WritePositionsToTrainingFile(evaluations);

Console.WriteLine($"Counts: {gameCount} games with {uniquePositions} unique positions written. Discarded {duplicatePositions} duplicate positions.");
evaluations.Clear();
}

// convert/evaluate only 100000
if (gameCount >= 100000)
// convert/evaluate only 50000
if (gameCount >= 50000)
break;
}
}
Expand Down
Binary file removed EngineInteraction.xlsx
Binary file not shown.
16 changes: 4 additions & 12 deletions neuralnetwork/ChessModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,19 @@
import numpy as np

class ChessModel:
w_file_path = './sbrunaugh_chess_model_v11_white.keras'
b_file_path = './sbrunaugh_chess_model_v11_black.keras'
file_path = './sbrunaugh_chess_model_v13.keras'

def __init__(self) -> None:
self.model_white = load_model(self.w_file_path)
self.model_black = load_model(self.b_file_path)
self.model = load_model(self.file_path)

def forward_pass(self, position, is_white_to_move) -> float:
if len(position) != 64:
raise ValueError("Position array must have exactly 64 items.")

model_input = np.array(position, dtype=int).reshape(1, -1)

eval = None

# Run the forward pass
if (is_white_to_move):
eval = self.model_white.predict(model_input)
print('white-to-move-model evaluated position as ', eval[0][0])
else:
eval = self.model_black.predict(model_input)
print('black-to-move-model evaluated position as ', eval[0][0])
eval = self.model.predict(model_input)
print('model evaluated position as ', eval[0][0])

return float(eval[0][0])
6 changes: 2 additions & 4 deletions neuralnetwork/count_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,5 @@ def count_lines(filename):
return len(lines)

# Usage
w_filename = "../train_data_white.txt"
b_filename = "../train_data_black.txt"
print(f"The file '{w_filename}' has {count_lines(w_filename)} lines.")
print(f"The file '{b_filename}' has {count_lines(b_filename)} lines.")
filename = "../train_data.txt"
print(f"The file '{filename}' has {count_lines(filename)} lines.")
Loading

0 comments on commit f9222d9

Please sign in to comment.