diff --git a/.gitignore b/.gitignore
index b6e4761..31da0d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,5 @@
+*.zip
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/0-player_counting/0-find_top_players.sh b/0-player_counting/0-find_top_players.sh
new file mode 100755
index 0000000..52c60dc
--- /dev/null
+++ b/0-player_counting/0-find_top_players.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+lichess_raw_dir='/data/chess/bz2/standard/'
+output_dir='../../data/player_counts'
+mkdir -p $output_dir
+
+for t in $lichess_raw_dir/*-{01..11}.pgn.bz2 $lichess_raw_dir/*{3..8}-12.pgn.bz2; do
+ fname="$(basename -- $t)"
+ echo "${t} ${output_dir}/${fname}.csv.bz2"
+ screen -S "filter-${fname}" -dm bash -c "source ~/.bashrc; python3 find_top_players.py ${t} ${output_dir}/${fname}.csv.bz2"
+done
diff --git a/0-player_counting/1-collect_top_players.sh b/0-player_counting/1-collect_top_players.sh
new file mode 100755
index 0000000..17dc26e
--- /dev/null
+++ b/0-player_counting/1-collect_top_players.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+lichess_raw_dir='/data/chess/bz2/standard/'
+counts_dir='../../data/player_counts'
+counts_file='../../data/player_counts_combined.csv.bz2'
+top_list='../../data/player_counts_combined_top_names.csv.bz2'
+
+output_2000_dir='../../data/top_2000_player_games'
+output_2000_metadata_dir='../../data/top_2000_player_data'
+
+players_list='../../data/select_transfer_players'
+
+final_data_dir='../../data/transfer_players_data'
+
+num_train=10
+num_val=900
+num_test=100
+
+python3 combine_player_counts.py $counts_dir/* $counts_file
+
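+# note: the combined counts file has a header row, so "head -n 2000" keeps the header plus the top 1,999 players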
+bzcat $counts_file | head -n 2000 | bzip2 > $top_list
+
+mkdir -p $output_2000_dir
+
+python3 split_by_players.py $top_list $lichess_raw_dir/*-{01..11}.pgn.bz2 $lichess_raw_dir/*{3..8}-12.pgn.bz2 $output_2000_dir
+
+rm -v $top_list
+
+mkdir -p $output_2000_metadata_dir
+
+python3 player_game_counts.py $output_2000_dir $output_2000_metadata_dir
+
+python3 select_top_players.py $output_2000_metadata_dir \
+ ${players_list}_train.csv $num_train \
+ ${players_list}_validate.csv $num_val \
+    ${players_list}_test.csv $num_test
+
+mkdir -p $final_data_dir
+mkdir -p $final_data_dir/metadata
+cp -v ${players_list}*.csv $final_data_dir/metadata
+
+for c in "train" "validate" "test"; do
+ mkdir $final_data_dir/${c}
+ mkdir $final_data_dir/${c}_metadata
+ for t in `tail -n +2 ${players_list}_${c}.csv|awk -F ',' '{print $1}'`; do
+ cp -v ${output_2000_dir}/${t}.pgn.bz2 $final_data_dir/${c}
+ cp ${output_2000_metadata_dir}/${t}.csv.bz2 $final_data_dir/${c}_metadata
+ done
+done
diff --git a/0-player_counting/2-select_extended_set.sh b/0-player_counting/2-select_extended_set.sh
new file mode 100755
index 0000000..159e2ee
--- /dev/null
+++ b/0-player_counting/2-select_extended_set.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+set -e
+
+vals_dat_dir="../../data/transfer_players_data/validate_metadata/"
+vals_dir="../../data/transfer_players_validate"
+output_dir="../../data/transfer_players_extended"
+list_file='../../data/extended_list.csv'
+
+num_per_bin=5
+bins="1100 1300 1500 1700 1900"
+
+
+python3 select_binned_players.py $vals_dat_dir $list_file $num_per_bin $bins
+
+mkdir -p $output_dir
+
+while read player; do
+ echo $player
+ cp -r ${vals_dir}/${player} ${output_dir}
+done < $list_file
diff --git a/0-player_counting/README.md b/0-player_counting/README.md
new file mode 100755
index 0000000..4188a62
--- /dev/null
+++ b/0-player_counting/README.md
@@ -0,0 +1,3 @@
+# Player Counting
+
+This is the code we used to count the number of games each player has.
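+
+The numbered scripts are meant to run in order: `0-find_top_players.sh` counts players in each monthly Lichess dump, `1-collect_top_players.sh` combines the counts and splits out per-player games and metadata, and `2-select_extended_set.sh` picks the extended, Elo-binned set of validation players.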
diff --git a/0-player_counting/combine_player_counts.py b/0-player_counting/combine_player_counts.py
new file mode 100755
index 0000000..2492139
--- /dev/null
+++ b/0-player_counting/combine_player_counts.py
@@ -0,0 +1,31 @@
+import backend
+
+import argparse
+import bz2
+
+import pandas
+
+@backend.logged_main
+def main():
+ parser = argparse.ArgumentParser(description='Collect counts and create list from them', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('inputs', nargs = '+', help='input csvs')
+ parser.add_argument('output', help='output csv')
+ args = parser.parse_args()
+
+ counts = {}
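+    # accumulate per-player game counts across all the monthly CSVs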
+ for p in args.inputs:
+ backend.printWithDate(f"Processing {p}", end = '\r')
+ df = pandas.read_csv(p)
+ for i, row in df.iterrows():
+ try:
+ counts[row['player']] += row['count']
+ except KeyError:
+ counts[row['player']] = row['count']
+ backend.printWithDate(f"Writing")
+ with bz2.open(args.output, 'wt') as f:
+ f.write('player,count\n')
+ for p, c in sorted(counts.items(), key = lambda x: x[1], reverse=True):
+ f.write(f"{p},{c}\n")
+
+if __name__ == '__main__':
+ main()
diff --git a/0-player_counting/find_top_players.py b/0-player_counting/find_top_players.py
new file mode 100755
index 0000000..e5e195e
--- /dev/null
+++ b/0-player_counting/find_top_players.py
@@ -0,0 +1,42 @@
+import backend
+
+import argparse
+import bz2
+
+@backend.logged_main
+def main():
+ parser = argparse.ArgumentParser(description='Count number of times each player occurs in pgn', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('input', help='input pgn')
+ parser.add_argument('output', help='output csv')
+ parser.add_argument('--exclude_bullet', action='store_false', help='Remove bullet games from counts')
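+    # note: with store_false this defaults to True, so bullet games are excluded unless the flag is passed (which turns the filter off)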
+ args = parser.parse_args()
+
+ games = backend.GamesFile(args.input)
+
+ counts = {}
+
+ for i, (d, _) in enumerate(games):
+ if args.exclude_bullet and 'Bullet' in d['Event']:
+ continue
+ else:
+ add_player(d['White'], counts)
+ add_player(d['Black'], counts)
+ if i % 10000 == 0:
+ backend.printWithDate(f"{i} done with {len(counts)} players from {args.input}", end = '\r')
+
+ backend.printWithDate(f"{i} found total of {len(counts)} players from {args.input}")
+ with bz2.open(args.output, 'wt') as f:
+ f.write("player,count\n")
+ for p, c in sorted(counts.items(), key = lambda x: x[1], reverse=True):
+ f.write(f"{p},{c}\n")
+ backend.printWithDate("done")
+
+def add_player(p, d):
+ try:
+ d[p] += 1
+ except KeyError:
+ d[p] = 1
+
+if __name__ == '__main__':
+ main()
diff --git a/0-player_counting/player_game_counts.py b/0-player_counting/player_game_counts.py
new file mode 100755
index 0000000..57c38f9
--- /dev/null
+++ b/0-player_counting/player_game_counts.py
@@ -0,0 +1,67 @@
+import backend
+
+import os
+import os.path
+import csv
+import bz2
+import argparse
+
+@backend.logged_main
+def main():
+ parser = argparse.ArgumentParser(description='Get some stats about each of the games')
+ parser.add_argument('targets_dir', help='input pgns dir')
+ parser.add_argument('output_dir', help='output csvs dir')
+    parser.add_argument('--pool_size', type=int, help='Number of worker processes to run in parallel', default = 64)
+ args = parser.parse_args()
+ multiProc = backend.Multiproc(args.pool_size)
+ multiProc.reader_init(Files_lister, args.targets_dir)
+ multiProc.processor_init(Games_processor, args.output_dir)
+
+ multiProc.run()
+
+class Files_lister(backend.MultiprocIterable):
+ def __init__(self, targets_dir):
+ self.targets_dir = targets_dir
+ self.targets = [(p.path, p.name.split('.')[0]) for p in os.scandir(targets_dir) if '.pgn.bz2' in p.name]
+ backend.printWithDate(f"Found {len(self.targets)} targets in {targets_dir}")
+ def __next__(self):
+ try:
+ backend.printWithDate(f"Pushed target {len(self.targets)} remaining", end = '\r', flush = True)
+ return self.targets.pop()
+ except IndexError:
+ raise StopIteration
+
+class Games_processor(backend.MultiprocWorker):
+ def __init__(self, output_dir):
+ self.output_dir = output_dir
+
+ def __call__(self, path, name):
+ games = backend.GamesFile(path)
+ with bz2.open(os.path.join(self.output_dir, f"{name}.csv.bz2"), 'wt') as f:
+ writer = csv.DictWriter(f, ["player", "opponent","game_id", "ELO", "opp_ELO", "was_white", "result", "won", "UTCDate", "UTCTime", "TimeControl"])
+
+ writer.writeheader()
+ for d, _ in games:
+ game_dat = {}
+ game_dat['player'] = name
+ game_dat['game_id'] = d['Site'].split('/')[-1]
+ game_dat['result'] = d['Result']
+ game_dat['UTCDate'] = d['UTCDate']
+ game_dat['UTCTime'] = d['UTCTime']
+ game_dat['TimeControl'] = d['TimeControl']
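+                # the metadata file is named after the player, so matching d['Black'] tells us which colour they played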
+ if d['Black'] == name:
+ game_dat['was_white'] = False
+ game_dat['opponent'] = d['White']
+ game_dat['ELO'] = d['BlackElo']
+ game_dat['opp_ELO'] = d['WhiteElo']
+ game_dat['won'] = d['Result'] == '0-1'
+ else:
+ game_dat['was_white'] = True
+ game_dat['opponent'] = d['Black']
+ game_dat['ELO'] = d['WhiteElo']
+ game_dat['opp_ELO'] = d['BlackElo']
+ game_dat['won'] = d['Result'] == '1-0'
+ writer.writerow(game_dat)
+
+if __name__ == '__main__':
+ main()
diff --git a/0-player_counting/select_binned_players.py b/0-player_counting/select_binned_players.py
new file mode 100755
index 0000000..901eda4
--- /dev/null
+++ b/0-player_counting/select_binned_players.py
@@ -0,0 +1,52 @@
+import backend
+
+import argparse
+import bz2
+import glob
+import random
+import os.path
+import multiprocessing
+
+import pandas
+
+@backend.logged_main
+def main():
+    parser = argparse.ArgumentParser(description='Read all the metadata and select players for each Elo bin', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('csvs_dir', help='dir of csvs')
+ parser.add_argument('output_list', help='list of targets')
+ parser.add_argument('bin_size', type=int, help='players per bin')
+ parser.add_argument('bins', type=int, nargs = '+', help='bins')
+ parser.add_argument('--pool_size', type=int, help='Number of threads to use for reading', default = 48)
+ parser.add_argument('--seed', type=int, help='random seed', default = 1)
+ args = parser.parse_args()
+ random.seed(args.seed)
+
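+    # snap the requested bin edges down to the nearest 100 Elo so they match the rounded player ratings below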
+ bins = [int(b // 100 * 100) for b in args.bins]
+
+ with multiprocessing.Pool(args.pool_size) as pool:
+ players = pool.map(load_player, glob.glob(os.path.join(args.csvs_dir, '*.csv.bz2')))
+ backend.printWithDate(f"Found {len(players)} players, using {len(bins)} bins")
+ binned_players = {b : [] for b in bins}
+ for p in players:
+ pe_round = int(p['elo'] // 100 * 100)
+ if pe_round in bins:
+ binned_players[pe_round].append(p)
+ backend.printWithDate(f"Found: " + ', '.join([f"{b} : {len(p)}" for b, p in binned_players.items()]))
+
+ with open(args.output_list, 'wt') as f:
+ for b, p in binned_players.items():
+ random.shuffle(p)
+ print(b, [d['name'] for d in p[:args.bin_size]])
+ f.write('\n'.join([d['name'] for d in p[:args.bin_size]]) +'\n')
+
+def load_player(path):
+ df = pandas.read_csv(path, low_memory=False)
+ elo = df['ELO'][-10000:].mean()
+ count = len(df)
+ return {
+ 'name' : df['player'].iloc[0],
+ 'elo' : elo,
+ 'count' : count,
+ }
+if __name__ == "__main__":
+ main()
diff --git a/0-player_counting/select_top_players.py b/0-player_counting/select_top_players.py
new file mode 100755
index 0000000..7e6f54e
--- /dev/null
+++ b/0-player_counting/select_top_players.py
@@ -0,0 +1,63 @@
+import backend
+
+import argparse
+import bz2
+import glob
+import random
+import os.path
+import multiprocessing
+
+import pandas
+
+@backend.logged_main
+def main():
+ parser = argparse.ArgumentParser(description='Read all the metadata and select top n players for training/validation/testing', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('inputs', help='input csvs dir')
+ parser.add_argument('output_train', help='output csv for training data')
+ parser.add_argument('num_train', type=int, help='num for main training')
+ parser.add_argument('output_val', help='output csv for validation data')
+ parser.add_argument('num_val', type=int, help='num for big validation run')
+ parser.add_argument('output_test', help='output csv for testing data')
+ parser.add_argument('num_test', type=int, help='num for holdout set')
+    parser.add_argument('--pool_size', type=int, help='Number of worker processes for reading metadata', default = 48)
+ parser.add_argument('--min_elo', type=int, help='min elo to select', default = 1100)
+ parser.add_argument('--max_elo', type=int, help='max elo to select', default = 2000)
+ parser.add_argument('--seed', type=int, help='random seed', default = 1)
+ args = parser.parse_args()
+ random.seed(args.seed)
+
+ targets = glob.glob(os.path.join(args.inputs, '*csv.bz2'))
+
+ with multiprocessing.Pool(args.pool_size) as pool:
+ players = pool.starmap(check_player, ((t, args.min_elo, args.max_elo) for t in targets))
+
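+    # keep the num_train + num_val + num_test players with the most games, shuffle, then deal them into the three output files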
+ players_top = sorted(
+ (p for p in players if p is not None),
+ key = lambda x : x[1],
+ reverse=True,
+ )[:args.num_train + args.num_val + args.num_test]
+
+ random.shuffle(players_top)
+
+ write_output_file(args.output_train, args.num_train, players_top)
+ write_output_file(args.output_val, args.num_val, players_top)
+ write_output_file(args.output_test, args.num_test, players_top)
+
+def write_output_file(path, count, targets):
+ with open(path, 'wt') as f:
+ f.write("player,count,ELO\n")
+ for i in range(count):
+ t = targets.pop()
+ f.write(f"{t[0]},{t[1]},{t[2]}\n")
+
+def check_player(path, min_elo, max_elo):
+ df = pandas.read_csv(path, low_memory=False)
+ elo = df['ELO'][-10000:].mean()
+ count = len(df)
+ if elo > min_elo and elo < max_elo:
+ return path.split('/')[-1].split('.')[0], count, elo
+ else:
+ return None
+
+if __name__ == "__main__":
+ main()
diff --git a/0-player_counting/split_by_players.py b/0-player_counting/split_by_players.py
new file mode 100755
index 0000000..26ed883
--- /dev/null
+++ b/0-player_counting/split_by_players.py
@@ -0,0 +1,79 @@
+import backend
+
+import pandas
+import lockfile
+
+import argparse
+import bz2
+import os
+import os.path
+
+@backend.logged_main
+def main():
+    parser = argparse.ArgumentParser(description='Write pgns of games with selected players in them', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('target', help='target players list as csv')
+ parser.add_argument('inputs', nargs = '+', help='input pgns')
+ parser.add_argument('output', help='output dir')
+ parser.add_argument('--exclude_bullet', action='store_false', help='Remove bullet games from counts')
+    parser.add_argument('--pool_size', type=int, help='Number of worker processes to run in parallel', default = 48)
+ args = parser.parse_args()
+
+ df_targets = pandas.read_csv(args.target)
+ targets = set(df_targets['player'])
+
+ os.makedirs(args.output, exist_ok=True)
+
+ multiProc = backend.Multiproc(args.pool_size)
+ multiProc.reader_init(Files_lister, args.inputs)
+ multiProc.processor_init(Games_processor, targets, args.output, args.exclude_bullet)
+ multiProc.run()
+ backend.printWithDate("done")
+
+class Files_lister(backend.MultiprocIterable):
+ def __init__(self, inputs):
+ self.inputs = list(inputs)
+ backend.printWithDate(f"Found {len(self.inputs)}")
+ def __next__(self):
+ try:
+ backend.printWithDate(f"Pushed target {len(self.inputs)} remaining", end = '\r', flush = True)
+ return self.inputs.pop()
+ except IndexError:
+ raise StopIteration
+
+class Games_processor(backend.MultiprocWorker):
+ def __init__(self, targets, output_dir, exclude_bullet):
+ self.output_dir = output_dir
+ self.targets = targets
+ self.exclude_bullet = exclude_bullet
+
+ self.c = 0
+
+ def __call__(self, path):
+ games = backend.GamesFile(path)
+ self.c = 0
+ for i, (d, s) in enumerate(games):
+ if self.exclude_bullet and 'Bullet' in d['Event']:
+ continue
+ else:
+ if d['White'] in self.targets:
+ self.write_player(d['White'], s)
+ self.c += 1
+ if d['Black'] in self.targets:
+ self.write_player(d['Black'], s)
+ self.c += 1
+ if i % 10000 == 0:
+ backend.printWithDate(f"{path} {i} done with {self.c} writes", end = '\r')
+
+ def write_player(self, p_name, s):
+
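+        # several worker processes may append to the same player's pgn, so serialize writes with a per-file lock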
+ p_path = os.path.join(self.output_dir, f"{p_name}.pgn.bz2")
+ lock_path = p_path + '.lock'
+ lock = lockfile.FileLock(lock_path)
+ with lock:
+ with bz2.open(p_path, 'at') as f:
+ f.write(s)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/1-data_generation/0-0-make_training_datasets.sh b/1-data_generation/0-0-make_training_datasets.sh
new file mode 100755
index 0000000..0708b1f
--- /dev/null
+++ b/1-data_generation/0-0-make_training_datasets.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+train_frac=80
+val_frac=10
+test_frac=10
+
+input_files="/maiadata/transfer_players_data/train"
+output_files="/maiadata/transfer_players_train"
+mkdir -p $output_files
+
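+# for each player: split their games by colour, split each colour into train/validate/test, normalize the pgns with pgn-extract, then convert to training chunks with trainingdata-tool (one detached screen per split)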
+for player_file in $input_files/*.bz2; do
+ f=${player_file##*/}
+ p_name=${f%.pgn.bz2}
+ p_dir=$output_files/$p_name
+ split_dir=$output_files/$p_name/split
+ mkdir -p $p_dir
+ mkdir -p $split_dir
+ echo $p_name $p_dir
+ python split_by_player.py $player_file $p_name $split_dir/games
+
+
+ for c in "white" "black"; do
+ python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac
+
+ cd $p_dir
+ mkdir -p pgns
+ for s in "train" "validate" "test"; do
+ mkdir -p $s
+ mkdir $s/$c
+
+ #using tool from:
+ #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
+ bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000
+
+ cat *.pgn > pgns/${s}_${c}.pgn
+ rm -v *.pgn
+
+ #using tool from:
+ #https://github.com/DanielUranga/trainingdata-tool
+ screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
+ done
+ cd -
+ done
+
+done
diff --git a/1-data_generation/0-1-make_training_csvs.sh b/1-data_generation/0-1-make_training_csvs.sh
new file mode 100755
index 0000000..9a38e7d
--- /dev/null
+++ b/1-data_generation/0-1-make_training_csvs.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+data_dir="/maiadata/transfer_players_train"
+
+for player_dir in $data_dir/*; do
+ player_name=`basename ${player_dir}`
+ mkdir $player_dir/csvs
+ for c in "white" "black"; do
+ for s in "train" "validate" "test"; do
+ target=$player_dir/split/${s}_${c}.pgn.bz2
+ output=$player_dir/csvs/${s}_${c}.csv.bz2
+ echo ${player_name} ${s} ${c}
+ screen -S "csv-${player_name}-${c}-${s}" -dm bash -c "python3 ../../data_generators/pgn_to_csv.py ${target} ${output}"
+ done
+ done
+done
diff --git a/1-data_generation/0-2-make_reduced_datasets.sh b/1-data_generation/0-2-make_reduced_datasets.sh
new file mode 100755
index 0000000..39b80f9
--- /dev/null
+++ b/1-data_generation/0-2-make_reduced_datasets.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+
+train_frac=80
+val_frac=10
+test_frac=10
+
+input_files="/maiadata/transfer_players_data/train"
+output_files="/maiadata/transfer_players_train_reduced"
+mkdir -p $output_files
+
+fractions='100 10 1'
+
+for frac in `echo $fractions`; do
+ for player_file in $input_files/*.bz2; do
+ f=${player_file##*/}
+ p_name=${f%.pgn.bz2}
+ p_dir=$output_files/$p_name/$frac
+ split_dir=$output_files/$p_name/$frac/split
+ mkdir -p $p_dir
+ mkdir -p $split_dir
+
+ python pgn_fractional_split.py $player_file $p_dir/raw_reduced.pgn.bz2 $p_dir/extra.pgn.bz2 --ratios $frac `echo "1000- $frac " | bc`
+
+ echo $p_name $frac $p_dir
+ python split_by_player.py $p_dir/raw_reduced.pgn.bz2 $p_name $split_dir/games
+
+
+ for c in "white" "black"; do
+ python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac
+
+ cd $p_dir
+ mkdir -p pgns
+ for s in "train" "validate" "test"; do
+ mkdir -p $s
+ mkdir $s/$c
+
+ #using tool from:
+ #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
+ bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000
+
+ cat *.pgn > pgns/${s}_${c}.pgn
+ rm -v *.pgn
+
+ #using tool from:
+ #https://github.com/DanielUranga/trainingdata-tool
+ screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
+ done
+ cd -
+ done
+ done
+done
diff --git a/1-data_generation/1-0-make_val_datasets.sh b/1-data_generation/1-0-make_val_datasets.sh
new file mode 100755
index 0000000..927b860
--- /dev/null
+++ b/1-data_generation/1-0-make_val_datasets.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+set -e
+
+train_frac=80
+val_frac=10
+test_frac=10
+
+input_files="/maiadata/transfer_players_data/validate"
+output_files="/maiadata/transfer_players_validate"
+mkdir -p $output_files
+
+for player_file in $input_files/*.bz2; do
+ f=${player_file##*/}
+ p_name=${f%.pgn.bz2}
+ p_dir=$output_files/$p_name
+
+ f_size=$(du -sb ${output_files}/${p_name} | cut -f1)
+ if [ $((f_size)) -lt 50000 ]; then
+ echo $f_size $p_dir
+ rm -rv $p_dir/*
+ else
+ continue
+ fi
+
+ split_dir=$output_files/$p_name/split
+ mkdir -p $p_dir
+ mkdir -p $split_dir
+ echo $p_name $p_dir
+ python split_by_player.py $player_file $p_name $split_dir/games
+
+ for c in "white" "black"; do
+ python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac
+
+ cd $p_dir
+ mkdir -p pgns
+ for s in "train" "validate" "test"; do
+ mkdir -p $s
+ mkdir $s/$c
+
+ #using tool from:
+ #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
+ bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000
+
+ cat *.pgn > pgns/${s}_${c}.pgn
+ rm -v *.pgn
+
+ #using tool from:
+ #https://github.com/DanielUranga/trainingdata-tool
+ screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
+ done
+ cd -
+ done
+ while [ `screen -ls | wc -l` -gt 20 ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+done
diff --git a/1-data_generation/1-1-make_val_csvs.sh b/1-data_generation/1-1-make_val_csvs.sh
new file mode 100755
index 0000000..8bcfbee
--- /dev/null
+++ b/1-data_generation/1-1-make_val_csvs.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+data_dir="../../data/transfer_players_validate"
+
+for player_dir in $data_dir/*; do
+ player_name=`basename ${player_dir}`
+ mkdir -p $player_dir/csvs
+ for c in "white" "black"; do
+ for s in "train" "validate" "test"; do
+ target=$player_dir/split/${s}_${c}.pgn.bz2
+ output=$player_dir/csvs/${s}_${c}.csv.bz2
+ echo ${player_name} ${s} ${c}
+ screen -S "csv-${player_name}-${c}-${s}" -dm bash -c "python3 ../../data_generators/pgn_to_csv.py ${target} ${output}"
+ done
+ done
+ while [ `screen -ls | wc -l` -gt 50 ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+done
diff --git a/1-data_generation/2-make_testing_datasets.sh b/1-data_generation/2-make_testing_datasets.sh
new file mode 100755
index 0000000..3f94048
--- /dev/null
+++ b/1-data_generation/2-make_testing_datasets.sh
@@ -0,0 +1,49 @@
+#!/bin/bash
+
+train_frac=80
+val_frac=10
+test_frac=10
+
+input_files="/maiadata/transfer_players_data/test"
+output_files="/maiadata/transfer_players_test"
+mkdir -p $output_files
+
+for player_file in $input_files/*.bz2; do
+ f=${player_file##*/}
+ p_name=${f%.pgn.bz2}
+ p_dir=$output_files/$p_name
+ split_dir=$output_files/$p_name/split
+ mkdir -p $p_dir
+ mkdir -p $split_dir
+ echo $p_name $p_dir
+ python split_by_player.py $player_file $p_name $split_dir/games
+
+
+ for c in "white" "black"; do
+ python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 $split_dir/test_$c.pgn.bz2 --ratios $train_frac $val_frac $test_frac
+
+ cd $p_dir
+ mkdir -p pgns
+ for s in "train" "validate" "test"; do
+ mkdir -p $s
+ mkdir $s/$c
+
+ #using tool from:
+ #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
+ bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000
+
+ cat *.pgn > pgns/${s}_${c}.pgn
+ rm -v *.pgn
+
+ #using tool from:
+ #https://github.com/DanielUranga/trainingdata-tool
+ screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
+ done
+ cd -
+ done
+ while [ `screen -ls | wc -l` -gt 20 ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+
+done
diff --git a/1-data_generation/9-pgn_to_training_data.sh b/1-data_generation/9-pgn_to_training_data.sh
new file mode 100755
index 0000000..f3677a9
--- /dev/null
+++ b/1-data_generation/9-pgn_to_training_data.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+set -e
+
+#args input_path output_dir player
+
+player_file=${1}
+p_dir=${2}
+p_name=${3}
+
+train_frac=90
+val_frac=10
+
+split_dir=$p_dir/split
+
+mkdir -p ${p_dir}
+mkdir -p ${split_dir}
+
+echo "${p_name} to ${p_dir}"
+
+python split_by_player.py $player_file $p_name $split_dir/games
+
+for c in "white" "black"; do
+ python pgn_fractional_split.py $split_dir/games_$c.pgn.bz2 $split_dir/train_$c.pgn.bz2 $split_dir/validate_$c.pgn.bz2 --ratios $train_frac $val_frac
+
+ cd $p_dir
+ mkdir -p pgns
+ for s in "train" "validate"; do
+ mkdir -p $s
+ mkdir -p $s/$c
+
+ #using tool from:
+ #https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/
+ bzcat $split_dir/${s}_${c}.pgn.bz2 | pgn-extract -7 -C -N -#1000
+
+ cat *.pgn > pgns/${s}_${c}.pgn
+ rm -v *.pgn
+
+ #using tool from:
+ #https://github.com/DanielUranga/trainingdata-tool
+ screen -S "${p_name}-${c}-${s}" -dm bash -c "cd ${s}/${c}; trainingdata-tool -v ../../pgns/${s}_${c}.pgn"
+ done
+ cd -
+done
diff --git a/1-data_generation/pgn_fractional_split.py b/1-data_generation/pgn_fractional_split.py
new file mode 100755
index 0000000..99ed8ca
--- /dev/null
+++ b/1-data_generation/pgn_fractional_split.py
@@ -0,0 +1,52 @@
+import backend
+
+import argparse
+import bz2
+import random
+
+@backend.logged_main
+def main():
+    parser = argparse.ArgumentParser(description='Split games into some number of subsets, by percentage', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('input', help='input pgn')
+
+ parser.add_argument('outputs', nargs='+', help='output pgn files ', type = str)
+
+ parser.add_argument('--ratios', nargs='+', help='ratios of games for the outputs', required = True, type = float)
+
+ parser.add_argument('--no_shuffle', action='store_false', help='Stop output shuffling')
+ parser.add_argument('--seed', type=int, help='random seed', default = 1)
+ args = parser.parse_args()
+
+ if len(args.ratios) != len(args.outputs):
+ raise RuntimeError(f"Invalid outputs specified: {args.outputs} and {args.ratios}")
+
+ random.seed(args.seed)
+ games = backend.GamesFile(args.input)
+
+ game_strs = []
+
+ for i, (d, l) in enumerate(games):
+ game_strs.append(l)
+ if i % 10000 == 0:
+ backend.printWithDate(f"{i} done from {args.input}", end = '\r')
+ backend.printWithDate(f"{i} done total from {args.input}")
+ if not args.no_shuffle:
+ random.shuffle(game_strs)
+
+ split_indices = [int(r * len(game_strs) / sum(args.ratios)) for r in args.ratios]
+
+ #Correction for rounding, not very precise
+ split_indices[0] += len(game_strs) - sum(split_indices)
+
+ for p, c in zip(args.outputs, split_indices):
+ backend.printWithDate(f"Writing {c} games to: {p}")
+ with bz2.open(p, 'wt') as f:
+ f.write(''.join(
+ [game_strs.pop() for i in range(c)]
+ ))
+
+ backend.printWithDate("done")
+
+if __name__ == '__main__':
+ main()
diff --git a/1-data_generation/player_splits.sh b/1-data_generation/player_splits.sh
new file mode 100755
index 0000000..eeeeba6
--- /dev/null
+++ b/1-data_generation/player_splits.sh
@@ -0,0 +1,18 @@
+#!/bin/bash -e
+
+input_files="../../data/top_player_games"
+output_files="../../data/transfer_players_pgns_split"
+mkdir -p $output_files
+
+train_frac=80
+val_frac=10
+test_frac=10
+
+for p in $input_files/*; do
+ name=`basename $p`
+ p_name=${name%.pgn.bz2}
+ split_dir=$output_files/$name
+ mkdir $split_dir
+
+ screen -S "${p_name}" -dm bash -c "python3 pgn_fractional_split.py $p $split_dir/train.pgn.bz2 $split_dir/validate.pgn.bz2 $split_dir/test.pgn.bz2 --ratios $train_frac $val_frac $test_frac"
+done
diff --git a/1-data_generation/split_by_player.py b/1-data_generation/split_by_player.py
new file mode 100755
index 0000000..d184365
--- /dev/null
+++ b/1-data_generation/split_by_player.py
@@ -0,0 +1,48 @@
+import backend
+
+import argparse
+import bz2
+import random
+
+@backend.logged_main
+def main():
+    parser = argparse.ArgumentParser(description='Split games into games where the target player was White or Black', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('input', help='input pgn')
+ parser.add_argument('player', help='target player name')
+ parser.add_argument('output', help='output pgn prefix')
+ parser.add_argument('--no_shuffle', action='store_false', help='Stop output shuffling')
+ parser.add_argument('--seed', type=int, help='random seed', default = 1)
+ args = parser.parse_args()
+
+ random.seed(args.seed)
+
+ games = backend.GamesFile(args.input)
+
+ outputs_white = []
+ outputs_black = []
+
+ for i, (d, l) in enumerate(games):
+ if d['White'] == args.player:
+ outputs_white.append(l)
+ elif d['Black'] == args.player:
+ outputs_black.append(l)
+ else:
+ raise ValueError(f"{args.player} not found in game {i}:\n{l}")
+ if i % 10000 == 0:
+ backend.printWithDate(f"{i} done with {len(outputs_white)}:{len(outputs_black)} players from {args.input}", end = '\r')
+ backend.printWithDate(f"{i} found totals of {len(outputs_white)}:{len(outputs_black)} players from {args.input}")
+ backend.printWithDate("Writing white")
+ with bz2.open(f"{args.output}_white.pgn.bz2", 'wt') as f:
+ if not args.no_shuffle:
+ random.shuffle(outputs_white)
+ f.write(''.join(outputs_white))
+ backend.printWithDate("Writing black")
+ with bz2.open(f"{args.output}_black.pgn.bz2", 'wt') as f:
+ if not args.no_shuffle:
+ random.shuffle(outputs_black)
+ f.write(''.join(outputs_black))
+ backend.printWithDate("done")
+
+if __name__ == '__main__':
+ main()
diff --git a/2-training/extended_configs/frozen_copy/frozen_copy.yaml b/2-training/extended_configs/frozen_copy/frozen_copy.yaml
new file mode 100755
index 0000000..4729957
--- /dev/null
+++ b/2-training/extended_configs/frozen_copy/frozen_copy.yaml
@@ -0,0 +1,37 @@
+%YAML 1.2
+---
+#gpu: 1
+
+dataset:
+ path: '/data/transfer_players_extended/'
+ #name: ''
+
+training:
+ precision: 'half'
+ batch_size: 256
+ num_batch_splits: 1
+ test_steps: 2000
+ train_avg_report_steps: 50
+ total_steps: 150000
+ checkpoint_steps: 500
+ shuffle_size: 256
+ lr_values:
+ - 0.01
+ - 0.001
+ - 0.0001
+ - 0.00001
+ lr_boundaries:
+ - 35000
+ - 80000
+ - 110000
+ policy_loss_weight: 1.0
+ value_loss_weight: 1.0
+
+model:
+ filters: 64
+ residual_blocks: 6
+ se_ratio: 8
+ path: "maia/1900"
+ keep_weights: true
+ back_prop_blocks: 3
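+  # keep_weights: true starts from the maia/1900 weights ("copy"), false re-initializes them ("random");
+  # back_prop_blocks limits how many of the final residual blocks are updated, so 3 keeps the earlier blocks frozen (99 in the unfrozen configs trains everything)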
+...
diff --git a/2-training/extended_configs/frozen_random/frozen_random.yaml b/2-training/extended_configs/frozen_random/frozen_random.yaml
new file mode 100755
index 0000000..3cca9e6
--- /dev/null
+++ b/2-training/extended_configs/frozen_random/frozen_random.yaml
@@ -0,0 +1,37 @@
+%YAML 1.2
+---
+#gpu: 1
+
+dataset:
+ path: '/data/transfer_players_extended/'
+ #name: ''
+
+training:
+ precision: 'half'
+ batch_size: 256
+ num_batch_splits: 1
+ test_steps: 2000
+ train_avg_report_steps: 50
+ total_steps: 150000
+ checkpoint_steps: 500
+ shuffle_size: 256
+ lr_values:
+ - 0.01
+ - 0.001
+ - 0.0001
+ - 0.00001
+ lr_boundaries:
+ - 35000
+ - 80000
+ - 110000
+ policy_loss_weight: 1.0
+ value_loss_weight: 1.0
+
+model:
+ filters: 64
+ residual_blocks: 6
+ se_ratio: 8
+ path: "maia/1900"
+ keep_weights: false
+ back_prop_blocks: 3
+...
diff --git a/2-training/extended_configs/unfrozen_copy/unfrozen_copy.yaml b/2-training/extended_configs/unfrozen_copy/unfrozen_copy.yaml
new file mode 100755
index 0000000..947ca46
--- /dev/null
+++ b/2-training/extended_configs/unfrozen_copy/unfrozen_copy.yaml
@@ -0,0 +1,37 @@
+%YAML 1.2
+---
+#gpu: 1
+
+dataset:
+ path: '/data/transfer_players_extended/'
+ #name: ''
+
+training:
+ precision: 'half'
+ batch_size: 256
+ num_batch_splits: 1
+ test_steps: 2000
+ train_avg_report_steps: 50
+ total_steps: 150000
+ checkpoint_steps: 500
+ shuffle_size: 256
+ lr_values:
+ - 0.01
+ - 0.001
+ - 0.0001
+ - 0.00001
+ lr_boundaries:
+ - 35000
+ - 80000
+ - 110000
+ policy_loss_weight: 1.0
+ value_loss_weight: 1.0
+
+model:
+ filters: 64
+ residual_blocks: 6
+ se_ratio: 8
+ path: "maia/1900"
+ keep_weights: true
+ back_prop_blocks: 99
+...
diff --git a/2-training/extended_configs/unfrozen_random/unfrozen_random.yaml b/2-training/extended_configs/unfrozen_random/unfrozen_random.yaml
new file mode 100755
index 0000000..2fae0a4
--- /dev/null
+++ b/2-training/extended_configs/unfrozen_random/unfrozen_random.yaml
@@ -0,0 +1,37 @@
+%YAML 1.2
+---
+#gpu: 1
+
+dataset:
+ path: '/data/transfer_players_extended/'
+ #name: ''
+
+training:
+ precision: 'half'
+ batch_size: 256
+ num_batch_splits: 1
+ test_steps: 2000
+ train_avg_report_steps: 50
+ total_steps: 150000
+ checkpoint_steps: 500
+ shuffle_size: 256
+ lr_values:
+ - 0.01
+ - 0.001
+ - 0.0001
+ - 0.00001
+ lr_boundaries:
+ - 35000
+ - 80000
+ - 110000
+ policy_loss_weight: 1.0
+ value_loss_weight: 1.0
+
+model:
+ filters: 64
+ residual_blocks: 6
+ se_ratio: 8
+ path: "maia/1900"
+ keep_weights: false
+ back_prop_blocks: 99
+...
diff --git a/2-training/final_config.yaml b/2-training/final_config.yaml
new file mode 100755
index 0000000..8a4f0af
--- /dev/null
+++ b/2-training/final_config.yaml
@@ -0,0 +1,37 @@
+%YAML 1.2
+---
+gpu: 0
+
+dataset:
+ path: 'path to player data'
+ #name: ''
+
+training:
+ precision: 'half'
+ batch_size: 256
+ num_batch_splits: 1
+ test_steps: 2000
+ train_avg_report_steps: 50
+ total_steps: 150000
+ checkpoint_steps: 500
+ shuffle_size: 256
+ lr_values:
+ - 0.01
+ - 0.001
+ - 0.0001
+ - 0.00001
+ lr_boundaries:
+ - 35000
+ - 80000
+ - 110000
+ policy_loss_weight: 1.0
+ value_loss_weight: 1.0
+
+model:
+ filters: 64
+ residual_blocks: 6
+ se_ratio: 8
+ path: "maia-1900"
+ keep_weights: true
+ back_prop_blocks: 99
+...
diff --git a/2-training/train_transfer.py b/2-training/train_transfer.py
new file mode 100755
index 0000000..3e6dffe
--- /dev/null
+++ b/2-training/train_transfer.py
@@ -0,0 +1,144 @@
+import argparse
+import os
+import os.path
+import yaml
+import sys
+import glob
+import gzip
+import random
+import multiprocessing
+import shutil
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+import tensorflow as tf
+
+
+import backend
+import backend.tf_transfer
+
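+# position sampling rate handed to ChunkParser: roughly one position in SKIP is drawn from each game chunk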
+SKIP = 32
+
+@backend.logged_main
+def main(config_path, name, collection_name, player_name, gpu, num_workers):
+ output_name = os.path.join('models', collection_name, name + '.txt')
+
+ with open(config_path) as f:
+ cfg = yaml.safe_load(f.read())
+
+ if player_name is not None:
+ cfg['dataset']['name'] = player_name
+ if gpu is not None:
+ cfg['gpu'] = gpu
+
+ backend.printWithDate(yaml.dump(cfg, default_flow_style=False))
+
+ train_chunks_white, train_chunks_black = backend.tf_transfer.get_latest_chunks(os.path.join(
+ cfg['dataset']['path'],
+ cfg['dataset']['name'],
+ 'train',
+ ))
+ val_chunks_white, val_chunks_black = backend.tf_transfer.get_latest_chunks(os.path.join(
+ cfg['dataset']['path'],
+ cfg['dataset']['name'],
+ 'validate',
+ ))
+
+ shuffle_size = cfg['training']['shuffle_size']
+ total_batch_size = cfg['training']['batch_size']
+ backend.tf_transfer.ChunkParser.BATCH_SIZE = total_batch_size
+ tfprocess = backend.tf_transfer.TFProcess(cfg, name, collection_name)
+
+ train_parser = backend.tf_transfer.ChunkParser(
+ backend.tf_transfer.FileDataSrc(train_chunks_white.copy(), train_chunks_black.copy()),
+ shuffle_size=shuffle_size,
+ sample=SKIP,
+ batch_size=backend.tf_transfer.ChunkParser.BATCH_SIZE,
+ workers=num_workers,
+ )
+ train_dataset = tf.data.Dataset.from_generator(
+ train_parser.parse,
+ output_types=(
+ tf.string, tf.string, tf.string, tf.string
+ ),
+ )
+ train_dataset = train_dataset.map(
+ backend.tf_transfer.ChunkParser.parse_function)
+ train_dataset = train_dataset.prefetch(4)
+
+ test_parser = backend.tf_transfer.ChunkParser(
+ backend.tf_transfer.FileDataSrc(val_chunks_white.copy(), val_chunks_black.copy()),
+ shuffle_size=shuffle_size,
+ sample=SKIP,
+ batch_size=backend.tf_transfer.ChunkParser.BATCH_SIZE,
+ workers=num_workers,
+ )
+ test_dataset = tf.data.Dataset.from_generator(
+ test_parser.parse,
+ output_types=(tf.string, tf.string, tf.string, tf.string),
+ )
+ test_dataset = test_dataset.map(
+ backend.tf_transfer.ChunkParser.parse_function)
+ test_dataset = test_dataset.prefetch(4)
+
+ tfprocess.init_v2(train_dataset, test_dataset)
+
+ tfprocess.restore_v2()
+
+ num_evals = cfg['training'].get('num_test_positions', (len(val_chunks_white) + len(val_chunks_black)) * 10)
+ num_evals = max(1, num_evals // backend.tf_transfer.ChunkParser.BATCH_SIZE)
+ print("Using {} evaluation batches".format(num_evals))
+ try:
+ tfprocess.process_loop_v2(total_batch_size, num_evals, batch_splits=1)
+ except KeyboardInterrupt:
+ backend.printWithDate("KeyboardInterrupt: Stopping")
+ train_parser.shutdown()
+ test_parser.shutdown()
+ raise
+ tfprocess.save_leelaz_weights_v2(output_name)
+
+ train_parser.shutdown()
+ test_parser.shutdown()
+ return cfg
+
+def make_model_files(cfg, name, collection_name, save_dir):
+ output_name = os.path.join(save_dir, collection_name, name)
+ models_dir = os.path.join('models', collection_name, name)
+ models = [(int(p.name.split('-')[1]), p.name, p.path) for p in os.scandir(models_dir) if p.name.endswith('.pb.gz')]
+ top_model = max(models, key = lambda x : x[0])
+
+ os.makedirs(output_name, exist_ok=True)
+ model_file_name = top_model[1].replace('ckpt', name)
+ shutil.copy(top_model[2], os.path.join(output_name, model_file_name))
+ with open(os.path.join(output_name, "config.yaml"), 'w') as f:
+ cfg_yaml = yaml.dump(cfg).replace('\n', '\n ').strip()
+ f.write(f"""
+%YAML 1.2
+---
+name: {name}
+display_name: {name.replace('_', ' ')}
+engine: lc0_23
+options:
+ weightsPath: {model_file_name}
+full_config:
+ {cfg_yaml}
+...""")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='Tensorflow pipeline for training Leela Chess.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('config', help='config file for model / training')
+ parser.add_argument('player_name', nargs='?', help='player name to train on', default=None)
+ parser.add_argument('--gpu', help='gpu to use', default = 0, type = int)
+ parser.add_argument('--num_workers', help='number of worker threads to use', default = max(1, multiprocessing.cpu_count() - 2), type = int)
+ parser.add_argument('--copy_dir', help='dir to save final models in', default = 'final_models')
+ args = parser.parse_args()
+
+ collection_name = os.path.basename(os.path.dirname(args.config)).replace('configs_', '')
+ name = os.path.basename(args.config).split('.')[0]
+
+ if args.player_name is not None:
+ name = f"{args.player_name}_{name}"
+
+ multiprocessing.set_start_method('spawn')
+ cfg = main(args.config, name, collection_name, args.player_name, args.gpu, args.num_workers)
+ make_model_files(cfg, name, collection_name, args.copy_dir)
diff --git a/3-analysis/1-0-baselines_results.sh b/3-analysis/1-0-baselines_results.sh
new file mode 100755
index 0000000..aa7b3f3
--- /dev/null
+++ b/3-analysis/1-0-baselines_results.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+targets_dir="../../data/transfer_players_train"
+outputs_dir="../../data/transfer_results/train"
+
+maias="../../models/maia"
+stockfish="../../models/stockfish/stockfish_d15"
+leela="../../models/leela/sergio"
+
+
+mkdir -p $outputs_dir
+
+for player_dir in $targets_dir/*; do
+ player=`basename ${player_dir}`
+ player_ret_dir=$outputs_dir/$player
+
+ echo $player_dir
+
+ mkdir -p $player_ret_dir
+ mkdir -p $player_ret_dir/maia
+ mkdir -p $player_ret_dir/leela
+ #mkdir -p $player_ret_dir/stockfish
+ for c in "white" "black"; do
+ player_files=$player_dir/csvs/test_${c}.csv.bz2
+ #screen -S "baseline-tests-${player}-leela-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $leela ${player_files} ${player_ret_dir}/leela/segio_${c}.csv.bz2"
+ for maia_path in $maias/*; do
+ maia_name=`basename ${maia_path}`
+ printf "$maia_name\r"
+ screen -S "baseline-tests-${player}-${maia_name}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $maia_path ${player_files} ${player_ret_dir}/maia/${maia_name}_${c}.csv.bz2"
+ done
+ done
+done
diff --git a/3-analysis/1-1-baselines_results_validation.sh b/3-analysis/1-1-baselines_results_validation.sh
new file mode 100755
index 0000000..fbf4e01
--- /dev/null
+++ b/3-analysis/1-1-baselines_results_validation.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+targets_dir="../../data/transfer_players_validate"
+outputs_dir="../../data/transfer_results/validate"
+
+maias="../../models/maia"
+stockfish="../../models/stockfish/stockfish_d15"
+leela="../../models/leela/sergio"
+
+
+mkdir -p $outputs_dir
+
+for player_dir in $targets_dir/*; do
+ player=`basename ${player_dir}`
+ player_ret_dir=$outputs_dir/$player
+
+ echo $player_dir
+
+ mkdir -p $player_ret_dir
+ mkdir -p $player_ret_dir/maia
+ mkdir -p $player_ret_dir/leela
+ #mkdir -p $player_ret_dir/stockfish
+ for c in "white" "black"; do
+ player_files=$player_dir/csvs/test_${c}.csv.bz2
+ #screen -S "baseline-tests-${player}-leela-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $leela ${player_files} ${player_ret_dir}/leela/segio_${c}.csv.bz2"
+ for maia_path in $maias/1{1..9..2}00; do
+ maia_name=`basename ${maia_path}`
+ printf "$maia_name\r"
+ screen -S "baseline-tests-${player}-${maia_name}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $maia_path ${player_files} ${player_ret_dir}/maia/${maia_name}_${c}.csv.bz2"
+ done
+ done
+ while [ `screen -ls | wc -l` -gt 70 ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+done
diff --git a/3-analysis/2-0-baseline_results.sh b/3-analysis/2-0-baseline_results.sh
new file mode 100755
index 0000000..3bb931b
--- /dev/null
+++ b/3-analysis/2-0-baseline_results.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+set -e
+
+max_screens=40
+
+targets_dir="../../data/transfer_players"
+outputs_dir="../../data/transfer_results"
+kdd_path="../../datasets/10000_full_2019-12.csv.bz2"
+
+mkdir -p $outputs_dir
+
+maias_dir=../../models/maia
+
+for t in "train" "extended" "validate"; do
+ for player_dir in ${targets_dir}_${t}/*; do
+ for model in $maias_dir/1{1..9..2}00; do
+ maia_type=`basename ${model}`
+            player=`basename ${player_dir}`
+            player_ret_dir=$outputs_dir/$t/$player/maia
+            mkdir -p $player_ret_dir
+ echo $t $maia_type $player
+ for c in "white" "black"; do
+ player_files=${player_dir}/csvs/test_${c}.csv.bz2
+ screen -S "baselines-${player}-${maia_type}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model ${player_files} ${player_ret_dir}/${maia_type}_${c}.csv.bz2"
+ done
+ while [ `screen -ls | wc -l` -gt $max_screens ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+ done
+ done
+done
+
diff --git a/3-analysis/2-1-model_results.sh b/3-analysis/2-1-model_results.sh
new file mode 100755
index 0000000..1f9d9fe
--- /dev/null
+++ b/3-analysis/2-1-model_results.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+set -e
+
+max_screens=50
+
+targets_dir="../../data/transfer_players"
+outputs_dir="../../data/transfer_results"
+kdd_path="../../datasets/10000_full_2019-12.csv.bz2"
+
+
+models_dir="../../transfer_training/final_models"
+mkdir -p $outputs_dir
+
+for model in $models_dir/*/*; do
+ player=`python3 get_models_player.py ${model}`
+ model_type=`dirname ${model}`
+ model_type=`basename ${model_type}`
+ model_name=`basename ${model}`
+ #echo $model $model_type $model_name $player
+
+ for c in "white" "black"; do
+ for t in "train" "extended"; do
+ player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2
+ if [ -f "$player_files" ]; then
+ echo $player_files
+ player_ret_dir=$outputs_dir/$t/$player/transfer/$model_type
+ mkdir -p $player_ret_dir
+ screen -S "transfer-tests-${player}-${model_type}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model ${player_files} ${player_ret_dir}/${model_name}_${c}.csv.bz2"
+ fi
+ done
+ done
+ while [ `screen -ls | wc -l` -gt $max_screens ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+ #screen -S "transfer-tests-${player}-${model_type}-kdd" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py $model ${kdd_path} ${player_ret_dir}/${model_name}_kdd.csv.bz2"
+done
+
+
diff --git a/3-analysis/2-2-model_results_val.sh b/3-analysis/2-2-model_results_val.sh
new file mode 100755
index 0000000..a7a5946
--- /dev/null
+++ b/3-analysis/2-2-model_results_val.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+set -e
+
+max_screens=50
+
+targets_dir="../../data/transfer_players"
+outputs_dir="../../data/transfer_results_val"
+summaries_dir="../../data/transfer_summaries"
+kdd_path="../../data/reduced_kdd_test_set.csv.bz2"
+
+
+models_dir="../../transfer_training/final_models_val"
+mkdir -p $outputs_dir
+mkdir -p $summaries_dir
+
+for model in $models_dir/*/*; do
+ player=`python3 get_models_player.py ${model}`
+ model_type=`dirname ${model}`
+ model_type=`basename ${model_type}`
+ model_name=`basename ${model}`
+ #echo $model $model_type $model_name $player
+ for t in "train" "validate"; do
+ for c in "white" "black"; do
+ player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2
+ if [ -f "$player_files" ]; then
+ echo $player_files
+ player_ret_dir=$outputs_dir/$t/$player/${t}/$model_type
+ player_sum_dir=$summaries_dir/$t/$player/${t}/$model_type
+                mkdir -p $player_ret_dir $player_sum_dir
+ screen -S "val-tests-${player}-${model_type}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model ${player_files} ${player_ret_dir}/${model_name}_${c}.csv.bz2;python3 make_summary.py ${player_ret_dir}/${model_name}_${c}.csv.bz2 ${player_sum_dir}/${model_name}_${c}.json"
+ fi
+ done
+ done
+ screen -S "val-tests-${player}-${model_type}-kdd" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py $model ${kdd_path} ${player_ret_dir}/${model_name}_kdd_reduced.csv.bz2;python3 make_summary.py ${player_ret_dir}/${model_name}_kdd_reduced.csv.bz2
+ ${player_sum_dir}/${model_name}_kdd_reduced.csv.bz2"
+ while [ `screen -ls | wc -l` -gt $max_screens ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+ #screen -S "transfer-tests-${player}-${model_type}-kdd" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py $model ${kdd_path} ${player_ret_dir}/${model_name}_kdd.csv.bz2"
+done
+
+
diff --git a/3-analysis/3-0-model_cross_table.sh b/3-analysis/3-0-model_cross_table.sh
new file mode 100755
index 0000000..217b1ac
--- /dev/null
+++ b/3-analysis/3-0-model_cross_table.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set -e
+
+max_screens=40
+
+targets_dir="../../data/transfer_players"
+outputs_dir="../../data/transfer_results_cross"
+
+models_dir="../../transfer_training/final_models"
+
+target_models=`echo ../../transfer_training/final_models/{no_stop,unfrozen_copy}/*`
+
+mkdir -p $outputs_dir
+
+for model in $target_models; do
+ player=`python3 get_models_player.py ${model}`
+ model_type=`dirname ${model}`
+ model_type=`basename ${model_type}`
+ model_name=`basename ${model}`
+ echo $player $model_type $model
+ for c in "white" "black"; do
+ for t in "train" "extended"; do
+ player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2
+ if [ -f "$player_files" ]; then
+ player_ret_dir=$outputs_dir/$player
+ mkdir -p $player_ret_dir
+ echo $player_files
+ for model2 in $target_models; do
+ model2_name=`basename ${model2}`
+ model2_player=`python3 get_models_player.py ${model2}`
+ screen -S "cross-${player}-${model2_player}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model2 ${player_files} ${player_ret_dir}/${model2_player}_${c}.csv.bz2"
+ done
+ while [ `screen -ls | wc -l` -gt $max_screens ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+ fi
+ done
+ done
+done
diff --git a/3-analysis/3-1-model_cross_table_val_generation.sh b/3-analysis/3-1-model_cross_table_val_generation.sh
new file mode 100755
index 0000000..cd55258
--- /dev/null
+++ b/3-analysis/3-1-model_cross_table_val_generation.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+set -e
+
+max_screens=40
+
+targets_dir="../../data/transfer_players_validate"
+outputs_dir="../../data/transfer_players_validate_cross_csvs"
+
+mkdir -p $outputs_dir
+
+for player_dir in $targets_dir/*; do
+ player=`basename ${player_dir}`
+ echo $player
+ mkdir -p ${outputs_dir}/${player}
+ for c in "white" "black"; do
+ screen -S "cross-${player}-${c}" -dm bash -c "sourcer ~/.basrc; python3 csv_trimmer.py ${player_dir}/csvs/test_${c}.csv.bz2 ${outputs_dir}/${player}/test_${c}_reduced.csv.bz2"
+ done
+done
+
diff --git a/3-analysis/3-2-model_cross_table_val.sh b/3-analysis/3-2-model_cross_table_val.sh
new file mode 100755
index 0000000..e0c1562
--- /dev/null
+++ b/3-analysis/3-2-model_cross_table_val.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+set -e
+
+max_screens=40
+
+targets_dir="../../data/transfer_players"
+outputs_dir="../../data/transfer_results_cross_val"
+
+models_dir="../../transfer_training/final_models"
+
+target_models=`echo ../../transfer_training/final_models_val/unfrozen_copy/* ../../transfer_training/final_models/{no_stop,unfrozen_copy}/* `
+
+mkdir -p $outputs_dir
+
+for model in $target_models; do
+ player=`python3 get_models_player.py ${model}`
+ model_type=`dirname ${model}`
+ model_type=`basename ${model_type}`
+ model_name=`basename ${model}`
+ echo $player $model_type $model
+ for c in "white" "black"; do
+ for t in "train" "extended" "validate"; do
+ player_files=${targets_dir}_${t}/$player/csvs/test_${c}.csv.bz2
+ if [ -f "$player_files" ]; then
+ player_ret_dir=$outputs_dir/$player
+ mkdir -p $player_ret_dir
+ echo $player_files
+ for model2 in $target_models; do
+ model2_name=`basename ${model2}`
+ model2_player=`python3 get_models_player.py ${model2}`
+ screen -S "cross-${player}-${model2_player}-${c}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player} $model2 ${player_files} ${player_ret_dir}/${model2_player}_${c}.csv.bz2"
+ done
+ while [ `screen -ls | wc -l` -gt $max_screens ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+ fi
+ done
+ done
+done
diff --git a/3-analysis/4-0-result_summaries.sh b/3-analysis/4-0-result_summaries.sh
new file mode 100755
index 0000000..f60e316
--- /dev/null
+++ b/3-analysis/4-0-result_summaries.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+set -e
+
+max_screens=80
+
+targets_dir="../../data/transfer_results"
+outputs_dir="../../data/transfer_summaries"
+mkdir -p $outputs_dir
+
+for p in $targets_dir/*/*/*/*.bz2 $targets_dir/*/*/*/*/*.bz2; do
+ #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g")
+ out_path=${p/$targets_dir/$outputs_dir}
+ out_path=${out_path/.csv.bz2/.json}
+ base=`dirname ${p/$targets_dir/}`
+ base=${base//\//-}
+ mkdir -p `dirname $out_path`
+ echo $base
+ #"${${}/${outputs_dir}/${targets_dir}}"
+ screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}"
+ while [ `screen -ls | wc -l` -gt $max_screens ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+done
+
+
+
diff --git a/3-analysis/4-1-result_summaries_cross.sh b/3-analysis/4-1-result_summaries_cross.sh
new file mode 100755
index 0000000..8db770d
--- /dev/null
+++ b/3-analysis/4-1-result_summaries_cross.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+set -e
+
+max_screens=80
+
+targets_dir="../../data/transfer_results_cross"
+outputs_dir="../../data/transfer_results_cross_summaries"
+mkdir -p $outputs_dir
+
+for p in $targets_dir/*/*.bz2; do
+ #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g")
+ out_path=${p/$targets_dir/$outputs_dir}
+ out_path=${out_path/.csv.bz2/.json}
+ base=`dirname ${p/$targets_dir/}`
+ base=${base//\//-}
+ mkdir -p `dirname $out_path`
+ echo $base
+ #"${${}/${outputs_dir}/${targets_dir}}"
+ screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}"
+ while [ `screen -ls | wc -l` -gt $max_screens ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+done
+
+
+
diff --git a/3-analysis/4-2-result_summaries_val.sh b/3-analysis/4-2-result_summaries_val.sh
new file mode 100755
index 0000000..b2268ad
--- /dev/null
+++ b/3-analysis/4-2-result_summaries_val.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+set -e
+
+max_screens=80
+
+targets_dir="../../data/transfer_results_val/validate"
+outputs_dir="../../data/transfer_summaries_val"
+mkdir -p $outputs_dir
+
+for p in $targets_dir/*/*/*/*.bz2; do
+ #result=$(echo "$p" | sed "s/$targets_dir/$outputs_dir/g")
+ out_path=${p/$targets_dir/$outputs_dir}
+ out_path=${out_path/.csv.bz2/.json}
+ base=`dirname ${p/$targets_dir/}`
+ base=${base//\//-}
+ mkdir -p `dirname $out_path`
+ echo $base
+ #"${${}/${outputs_dir}/${targets_dir}}"
+ screen -S "summary${base}" -dm bash -c "source ~/.bashrc; python3 make_summary.py ${p} ${out_path}"
+ while [ `screen -ls | wc -l` -gt $max_screens ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+done
+
+
+
diff --git a/3-analysis/csv_trimmer.py b/3-analysis/csv_trimmer.py
new file mode 100755
index 0000000..9f27754
--- /dev/null
+++ b/3-analysis/csv_trimmer.py
@@ -0,0 +1,49 @@
+import argparse
+import os
+import os.path
+import bz2
+import csv
+import multiprocessing
+import humanize
+import time
+import queue
+import json
+import pandas
+
+import chess
+
+import backend
+
+#@backend.logged_main
+def main():
+ parser = argparse.ArgumentParser(description='Run model on all the lines of the csv', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('input', help='input CSV')
+ parser.add_argument('output', help='output CSV')
+ parser.add_argument('--ngames', type=int, help='number of games to read in', default = 10)
+ parser.add_argument('--min_ply', type=int, help='look at games with ply above this', default = 50)
+ parser.add_argument('--max_ply', type=int, help='look at games with ply below this', default = 100)
+ args = parser.parse_args()
+ backend.printWithDate(f"Starting {args.input} to {args.output}")
+
+ with bz2.open(args.input, 'rt') as fin, bz2.open(args.output, 'wt') as fout:
+ reader = csv.DictReader(fin)
+ writer = csv.DictWriter(fout, reader.fieldnames)
+ writer.writeheader()
+ games_count = 0
+ current_game = None
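+        # copy rows from games whose num_ply lies strictly between --min_ply and --max_ply, stopping once --ngames distinct games have been written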
+ for row in reader:
+ if args.min_ply is not None and int(row['num_ply']) <= args.min_ply:
+ continue
+ elif args.max_ply is not None and int(row['num_ply']) >= args.max_ply:
+ continue
+ elif row['game_id'] != current_game:
+ current_game = row['game_id']
+ games_count += 1
+ if args.ngames is not None and games_count >args.ngames:
+ break
+ writer.writerow(row)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/3-analysis/get_accuracy.py b/3-analysis/get_accuracy.py
new file mode 100755
index 0000000..a5f0194
--- /dev/null
+++ b/3-analysis/get_accuracy.py
@@ -0,0 +1,22 @@
+import argparse
+import os
+import os.path
+
+import pandas
+
+def main():
+ parser = argparse.ArgumentParser(description='Quick helper for getting model accuracies', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('inputs', nargs = '+', help='input CSVs')
+ parser.add_argument('--nrows', help='num lines', type = int, default=None)
+ args = parser.parse_args()
+
+ for p in args.inputs:
+ try:
+ df = pandas.read_csv(p, nrows = args.nrows)
+ except EOFError:
+ print(f"{os.path.abspath(p).split('.')[0]} EOF")
+ else:
+ print(f"{os.path.abspath(p).split('.')[0]} {df['model_correct'].mean() * 100:.2f}%")
+
+if __name__ == "__main__":
+ main()
diff --git a/3-analysis/get_models_player.py b/3-analysis/get_models_player.py
new file mode 100755
index 0000000..5da4dfe
--- /dev/null
+++ b/3-analysis/get_models_player.py
@@ -0,0 +1,30 @@
+import argparse
+import os
+import os.path
+import yaml
+
+import pandas
+
+def main():
+ parser = argparse.ArgumentParser(description='Quick helper for getting model players', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('input', help='input model dir')
+ args = parser.parse_args()
+
+ conf_path = os.path.abspath(os.path.join(args.input, "config.yaml"))
+ if os.path.isfile(conf_path):
+ with open(conf_path) as f:
+ cfg = yaml.safe_load(f)
+ try:
+ print(cfg['full_config']['name'])
+ except (KeyError, TypeError):
+ #some have corrupted configs
+ if 'Eron_Capivara' in args.input:
+ print('Eron_Capivara') #hack
+ else:
+
+ print(os.path.basename(os.path.dirname(conf_path)).split('_')[0])
+ else:
+ raise FileNotFoundError(f"Not a config path: {conf_path}")
+
+if __name__ == "__main__":
+ main()
diff --git a/3-analysis/make_summary.py b/3-analysis/make_summary.py
new file mode 100755
index 0000000..291caf9
--- /dev/null
+++ b/3-analysis/make_summary.py
@@ -0,0 +1,150 @@
+import argparse
+import os
+import os.path
+import json
+import re
+import glob
+
+import pandas
+import numpy as np
+
+root_dir = os.path.relpath("../..", start=os.path.dirname(os.path.abspath(__file__)))
+root_dir = os.path.abspath(root_dir)
+cats_ply = {
+ 'early' : (0, 10),
+ 'mid' : (11, 50),
+ 'late' : (51, 999),
+ 'kdd' : (11, 999),
+ }
+
+last_n = [2**n for n in range(12)]
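+# summary buckets: move-ply ranges (early/mid/late/kdd) plus "last N games" windows with N = 1, 2, 4, ..., 2048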
+
+def main():
+ parser = argparse.ArgumentParser(description='Create summary json from results csv', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('input', help='input CSV')
+ parser.add_argument('output', help='output JSON')
+ parser.add_argument('--players_infos', default="/ada/projects/chess/backend-backend/data/players_infos.json")#os.path.join(root_dir, 'data/players_infos.json'))
+ args = parser.parse_args()
+
+ with open(args.players_infos) as f:
+ player_to_dat = json.load(f)
+
+ fixed_data_paths = glob.glob("/ada/projects/chess/backend-backend/data/top_2000_player_data/*.csv.bz2")
+ fixed_data_lookup = {p.split('/')[-1].replace('.csv.bz2','') :p for p in fixed_data_paths}
+
+
+    df = collect_results_csv(args.input, player_to_dat, fixed_data_lookup)
+    if df is None:
+        # collect_results_csv returns None for empty or unreadable inputs
+        return
+
+    r_d = dict(df.iloc[0])
+ sum_dict = {
+ 'count' : len(df),
+ 'player' : r_d['player_name'],
+ 'model' : r_d['model_name'],
+ 'backend' : bool(r_d['backend']),
+ #'model_correct' : float(df['model_correct'].mean()),
+ 'elo' : r_d['elo'],
+ }
+ c = 'white' if 'white' in args.input.split('/')[-1].split('_')[-1] else 'black'
+
+ add_infos(sum_dict, "full", df)
+
+ try:
+ csv_raw_path = glob.glob(f"/ada/projects/chess/backend-backend/data/transfer_players_*/{r_d['player_name']}/csvs/test_{c}.csv.bz2")[0]
+ except IndexError:
+ if args.input.endswith('kdd.csv.bz2'):
+ csv_raw_path = "/ada/projects/chess/backend-backend/data/reduced_kdd_test_set.csv.bz2"
+ else:
+ csv_raw_path = None
+ if csv_raw_path is not None:
+ csv_base = pandas.read_csv(csv_raw_path, index_col=['game_id', 'move_ply'], low_memory=False)
+ csv_base['winrate_no_0'] = np.where(csv_base.reset_index()['move_ply'] < 2,np.nan, csv_base['winrate'])
+ csv_base['next_wr'] = 1 - csv_base['winrate_no_0'].shift(-1)
+ csv_base['move_delta_wr'] = csv_base['next_wr'] - csv_base['winrate']
+ csv_base_dropped = csv_base[~csv_base['winrate_loss'].isna()]
+ csv_base_dropped = csv_base_dropped.join(df.set_index(['game_id', 'move_ply']), how = 'inner', lsuffix = 'r_')
+ csv_base_dropped['move_delta_wr_rounded'] = (csv_base_dropped['move_delta_wr'] * 1).round(2) / 1
+
+ for dr in csv_base_dropped['move_delta_wr_rounded'].unique():
+ if dr < 0 and dr >-.32:
+ df_dr = csv_base_dropped[csv_base_dropped['move_delta_wr_rounded'] == dr]
+ add_infos(sum_dict, f"delta_wr_{dr}", df_dr)
+
+ for k, v in player_to_dat[r_d['player_name']].items():
+ if k != 'name':
+ sum_dict[k] = v
+ if r_d['backend']:
+ sum_dict['backend_elo'] = int(r_d['model_name'].split('_')[-1])
+
+
+ for c, (p_min, p_max) in cats_ply.items():
+ df_c = df[(df['move_ply'] >= p_min) & (df['move_ply'] <= p_max)]
+ add_infos(sum_dict, c, df_c)
+
+ for year in df['UTCDate'].dt.year.unique():
+ df_y = df[df['UTCDate'].dt.year == year]
+ add_infos(sum_dict, int(year), df_y)
+
+ for ply in range(50):
+ df_p = df[df['move_ply'] == ply]
+ if len(df_p) > 0:
+ # ignore the 50% missing ones
+ add_infos(sum_dict, f"ply_{ply}", df_p)
+
+ for won in [True, False]:
+ df_w = df[df['won'] == won]
+ add_infos(sum_dict, "won" if won else "lost", df_w)
+
+ games = list(df.groupby('game_id').first().sort_values('UTCDate').index)
+
+ for n in last_n:
+ df_n = df[df['game_id'].isin(games[-n:])]
+ add_infos(sum_dict, f"last_{n}", df_n)
+
+ p_min, p_max = cats_ply['kdd']
+ df_n_kdd = df_n[(df_n['move_ply'] >= p_min) & (df_n['move_ply'] <= p_max)]
+ add_infos(sum_dict, f"last_{n}_kdd", df_n_kdd)
+
+ with open(args.output, 'wt') as f:
+ json.dump(sum_dict, f)
+
+def collect_results_csv(path, player_to_dat, fixed_data_lookup):
+ try:
+ df = pandas.read_csv(path, low_memory=False)
+ except EOFError:
+ print(f"Error on: {path}")
+ return None
+ if len(df) < 1:
+ return None
+    df['colour'] = re.search(r"(black|white)\.csv\.bz2", path).group(1)
+ #df['class'] = re.search(f"{base_dir}/([a-z]*)/", path).group(1)
+ backend = 'final_backend_' in df['model_name'].iloc[0]
+ df['backend'] = backend
+ try:
+ df['player'] = df['player_name']
+ except KeyError:
+ pass
+ try:
+ if backend:
+ df['model_type'] = df['model_name'].iloc[0].replace('final_', '')
+ else:
+ df['model_type'] = df['model_name'].iloc[0].replace(f'{df.iloc[0]["player"]}_', '')
+ for k, v in player_to_dat[df.iloc[0]["player"]].items():
+ if k != 'name':
+ df[k] = v
+ games_df = pandas.read_csv(fixed_data_lookup[df['player'].iloc[0]],
+ low_memory=False, parse_dates = ['UTCDate'], index_col = 'game_id')
+ df = df.join(games_df, how= 'left', on = 'game_id', rsuffix='_per_game')
+ except Exception as e:
+ print(f"{e} : {path}")
+ raise
+ return df
+
+def add_infos(target_dict, name, df_sub):
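+    # Record mean accuracy (overall and averaged per game), row count, accuracy std,
+    # and number of games for this slice, under keys suffixed with `name`.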
+ target_dict[f'model_correct_{name}'] = float(df_sub['model_correct'].dropna().mean())
+ target_dict[f'model_correct_per_game_{name}'] = float(df_sub.groupby(['game_id']).mean()['model_correct'].dropna().mean())
+ target_dict[f'count_{name}'] = len(df_sub)
+ target_dict[f'std_{name}'] = float(df_sub['model_correct'].dropna().std())
+ target_dict[f'num_games_{name}'] = len(df_sub.groupby('game_id').count())
+
+if __name__ == "__main__":
+ main()
diff --git a/3-analysis/move_predictions.sh b/3-analysis/move_predictions.sh
new file mode 100755
index 0000000..5922fe2
--- /dev/null
+++ b/3-analysis/move_predictions.sh
@@ -0,0 +1,29 @@
+#!/bin/bash
+set -e
+
+data_dir="../../data/transfer_players_train"
+outputs_dir="../../data/transfer_players_train_results/weights_testing"
+maia_path="../../models/maia/1900"
+models_path="../models/weights_testing"
+
+kdd_path="../../datasets/10000_full_2019-12.csv.bz2"
+
+mkdir -p $outputs_dir
+
+for player_dir in $data_dir/*; do
+ player_name=`basename ${player_dir}`
+ echo $player_name
+ mkdir -p $outputs_dir/$player_name
+
+ for c in "white" "black"; do
+ #echo "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${models_path}/${player_name}*/ ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/${player_name}/transfer_test_${c}.csv.bz2"
+ screen -S "test-transfer-${c}-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${models_path}/${player_name}*/ ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/${player_name}/transfer_test_${c}.csv.bz2"
+ screen -S "test-maia-${c}-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py --target_player ${player_name} ${maia_path} ${data_dir}/${player_name}/csvs/test_${c}.csv.bz2 $outputs_dir/$player_name/maia_test_${c}.csv.bz2"
+ done
+ screen -S "kdd-transfer-${player_name}" -dm bash -c "source ~/.bashrc; python3 prediction_generator.py ${models_path}/${player_name}*/ ${kdd_path} $outputs_dir/$player_name/transfer_kdd.csv.bz2"
+
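+  # Throttle: wait until the number of lines reported by `screen -ls` (roughly the
+  # number of running sessions) drops below 250 before starting the next player.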
+ while [ `screen -ls | wc -l` -gt 250 ]; do
+ printf "waiting\r"
+ sleep 10
+ done
+done
diff --git a/3-analysis/prediction_generator.py b/3-analysis/prediction_generator.py
new file mode 100755
index 0000000..95029d0
--- /dev/null
+++ b/3-analysis/prediction_generator.py
@@ -0,0 +1,203 @@
+import argparse
+import os
+import os.path
+import bz2
+import csv
+import multiprocessing
+import humanize
+import time
+import queue
+import json
+import pandas
+
+import chess
+
+import backend
+
+#@backend.logged_main
+def main():
+ parser = argparse.ArgumentParser(description='Run model on all the lines of the csv', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+ parser.add_argument('model', help='model dir or file')
+ parser.add_argument('input', help='input CSV')
+ parser.add_argument('output', help='CSV')
+ parser.add_argument('model_name', nargs='?',help='model name')
+ parser.add_argument('--target_player', type=str, help='Only look at board by this player', default = None)
+ parser.add_argument('--nrows', type=int, help='number of rows to read in', default = None)
+ parser.add_argument('--pool_size', type=int, help='Number of models to run in parallel', default = 4)
+ parser.add_argument('--overwrite', help='Overwrite successful runs', default = False, action="store_true")
+ args = parser.parse_args()
+ backend.printWithDate(f"Starting model {args.model} analysis of {args.input} to {args.output}")
+
+ if not args.overwrite and os.path.isfile(args.output):
+ try:
+ df_out = pandas.read_csv(args.output, low_memory=False)
+ out_player_name = df_out.iloc[0]['player_name']
+ except (EOFError, KeyError):
+ backend.printWithDate("Found corrupted output file, overwriting")
+ else:
+ df_in = pandas.read_csv(args.input,low_memory=False)
+ if len(df_out) < .9 * len(df_in) * (.5 if args.target_player else 1.):
+ backend.printWithDate(f"Found truncated {len(df_out)} instead of {len(df_in) * (.5 if args.target_player else 1.)} output file, overwriting")
+ elif out_player_name != args.target_player:
+ backend.printWithDate(f"Found incorrect player {out_player_name} instead of {args.target_player}")
+ elif 'player_name' not in df_out.columns:
+ backend.printWithDate("Found output file missing player, overwriting")
+ df_out['player_name'] = args.target_player if args.target_player is not None else ''
+ df_out.to_csv(args.output, index = None)
+ return
+ else:
+ backend.printWithDate(f"Found completed output file, ending job")
+ return
+ model = None
+ model_name = args.model_name
+ if os.path.isfile(args.model):
+ model = args.model
+ else:
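+        # Walk the model directory looking for a config.yaml (a trained model dir);
+        # if none is found, fall back to the .pb.gz weights file with the largest
+        # numeric field in its filename.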
+ for name, _, files in os.walk(args.model):
+ if 'config.yaml' in files:
+ model = name
+ break
+ if model is None:
+ model_files = sorted([p.path for p in os.scandir(args.model) if p.name.endswith('pb.gz')], key = lambda x : int(x.split('/')[-1].split('-')[1]))
+ if len(model_files) > 0:
+ model = model_files[-1]
+ if model_name is None:
+ model_name = f"{os.path.basename(os.path.dirname(model))}_{model.split('/')[-1].split('-')[1]}"
+ else:
+ raise RuntimeError(f"No model or config found for: {args.model}")
+ backend.printWithDate(f"Found model: {model}")
+ multiProc = backend.Multiproc(args.pool_size)
+ multiProc.reader_init(CSV_reader, args.input, args.nrows, args.target_player)
+ multiProc.processor_init(line_processor, model, model_name)
+ multiProc.writer_init(CSV_writer, args.output, args.target_player)
+
+ multiProc.run()
+
+class CSV_reader(backend.MultiprocIterable):
+ def __init__(self, path, max_rows, target_player):
+ self.path = path
+ self.max_rows = max_rows
+ self.target_player = target_player
+ self.f = bz2.open(self.path, 'rt')
+ self.reader = csv.DictReader(self.f)
+ self.reader_iter = enumerate(self.reader)
+ self.board = chess.Board()
+ self.current_game = None
+ self.tstart = time.time()
+
+ def __del__(self):
+ try:
+ self.f.close()
+ except:
+ pass
+
+ def __next__(self):
+ i, row = next(self.reader_iter)
+ if self.max_rows is not None and i >= self.max_rows:
+ raise StopIteration("Hit max number of rows")
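+        # When a new game starts, rebuild the board from the row's FEN; the board is
+        # copied before the move is pushed so downstream code sees the pre-move position.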
+ if row['game_id'] != self.current_game:
+ self.current_game = row['game_id']
+ self.board = chess.Board(fen = row['board'])
+ send_board = self.board.copy()
+ try:
+ self.board.push_uci(row['move'])
+ except ValueError:
+ self.current_game = row['game_id']
+ if i % 1000 == 0:
+ backend.printWithDate(f"Row {i} in {self.delta_start()}", end = '\r', flush = True)
+ if self.target_player is None or row['active_player'] == self.target_player:
+ return (send_board, {
+ 'game_id' : row['game_id'],
+ 'move_ply' : row['move_ply'],
+ 'move' : row['move'],
+ }
+ )
+ else:
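+            # Row is not by the target player: recurse to fetch the next matching row.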
+ return next(self)
+
+ def delta_start(self):
+ return humanize.naturaldelta(time.time() - self.tstart)
+
+class line_processor(backend.MultiprocWorker):
+ def __init__(self, model_path, model_name):
+ self.model_path = model_path
+ self.model_name = model_name
+
+ if os.path.isdir(self.model_path):
+ self.model = backend.model_from_config(self.model_path)
+ self.name = self.model.config['name']
+ self.display_name = self.model.config['display_name']
+ else:
+ self.model = backend.LC0_Engine(self.model_path)
+ if self.model_name is not None:
+ self.name = self.model_name
+ self.display_name = self.model_name
+ else:
+ self.name = os.path.basename(self.model_path)
+ self.display_name = os.path.basename(self.model_path)
+
+ def __call__(self, board, row):
+ try:
+ v, ps = self.model.board_pv(board)
+ except KeyError:
+ raise backend.SkipCallMultiProc(f"No moves for this board: {board.fen()}")
+ top_move = sorted(ps.items(), key = lambda x : x[1])[-1][0]
+ move_dat = {
+ 'model_move': top_move,
+ 'top_p' : ps[top_move],
+ }
+ try:
+ move_dat['act_p'] = ps[row['move']]
+ except KeyError:
+ move_dat['act_p'] = 0.
+ try:
+ second_move = sorted(ps.items(), key = lambda x : x[1])[-2][0]
+ move_dat['second_move'] = second_move
+ move_dat['second_p'] = ps[second_move]
+ except IndexError:
+ pass
+ return (v,
+ move_dat,
+ row,
+ self.name,
+ self.display_name,
+ )
+
+class CSV_writer(backend.MultiprocWorker):
+ def __init__(self, output_path, player_name):
+ self.output_path = output_path
+ self.player_name = player_name if player_name is not None else ''
+ self.f = bz2.open(self.output_path, 'wt')
+ self.writer = csv.DictWriter(self.f,
+ ['game_id', 'move_ply', 'player_move', 'model_move', 'model_v', 'model_correct', 'model_name', 'model_display_name', 'player_name', 'rl_depth', 'top_p', 'act_p', 'second_move', 'second_p'
+ ])
+ self.writer.writeheader()
+ self.c = 0
+
+ def __del__(self):
+ try:
+ self.f.close()
+ except:
+ pass
+
+ def __call__(self, v, move_dat, row, name, display_name):
+ write_dict = {
+ 'game_id' : row['game_id'],
+ 'move_ply' : row['move_ply'],
+ 'player_move' : row['move'],
+ 'model_correct' : row['move'] == move_dat['model_move'],
+ 'model_name' : name,
+ 'model_display_name' : display_name,
+ 'player_name' : self.player_name,
+ 'rl_depth' : 0,
+ 'model_v' : v
+ }
+ write_dict.update(move_dat)
+ self.writer.writerow(write_dict)
+ self.c += 1
+ if self.c % 10000 == 0:
+ self.f.flush()
+
+if __name__ == "__main__":
+ main()
diff --git a/3-analysis/run-kdd-tests.sh b/3-analysis/run-kdd-tests.sh
new file mode 100755
index 0000000..c807235
--- /dev/null
+++ b/3-analysis/run-kdd-tests.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+mkdir -p ../data
+mkdir -p ../data/kdd_sweeps
+
+screen -S "kdd-sweep" -dm bash -c "source ~/.bashrc; python3 ../../analysis/move_prediction_csv.py ../../transfer_models ../../datasets/10000_full_2019-12.csv.bz2 ../data/kdd_sweeps"
diff --git a/4-cp_loss_stylo_baseline/README.md b/4-cp_loss_stylo_baseline/README.md
new file mode 100755
index 0000000..702b709
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/README.md
@@ -0,0 +1,31 @@
+`cp_loss_hist`: This is where I save the train/validation/test data aggregated over all games.
+
+`cp_loss_count_per_game`: This is where I save the train/validation/test data per game. Note that the counts have not been normalized.
+
+`cp_loss_hist_per_move`: This is where I save the train/validation/test data per move, aggregated over all games.
+
+`cp_loss_hist_per_move_per_game`: This is where I save the train/validation/test data per move, per game.
+
+`cp_loss_hist_per_move_per_game_count`: This is where I save the train/validation/test data per move, per game, as raw counts so they can be added up later.
+
+`get_cp_loss.py`: Parsing code that computes cp_loss and its histograms for both train and extended players and saves them as **.npy** files.
+
+`get_cp_loss_per_game.py`: Parsing code that computes cp_loss and its histograms (counts) for extended players for each game and saves them as **.npy** files. Note that I do not normalize when saving, so I can later sweep over the number of games.
+
+`get_cp_loss_per_move.py`: Parsing code that computes cp_loss and its histograms for both train and extended players over all games, split by move, and saves them as **.npy** files.
+
+`get_cp_loss_per_move_per_game.py`: Parsing code that computes cp_loss and its histograms for both train and extended players for each game, split by move, and saves them as **.npy** files.
+
+`get_cp_loss_per_move_per_game_count.py`: Parsing code that computes cp_loss and its histograms (counts) for both train and extended players for each game, split by move, and saves them as **.npy** files.
+
+`test_all_games.py`: Baseline that tests accuracy using all games instead of individual games, with Euclidean distance or Naive Bayes. Data comes from `cp_loss_hist`.
+
+`sweep_num_games.py`: Baseline using Euclidean distance or Naive Bayes. Training data comes from `cp_loss_hist` and test data from `cp_loss_count_per_game`. Sweeps over [1, 2, 4, 8, 16] games.
+
+`sweep_moves_per_game.py`: Naive Bayes on per-move evaluation, reporting the average accuracy for each game. Training data comes from `cp_loss_hist_per_move`, test data from `cp_loss_hist_per_move_per_game`.
+
+`sweep_moves_all_games.py`: Naive Bayes on per-move evaluation, aggregated over all games. Data comes from `cp_loss_hist_per_move`.
+
+`sweep_moves_num_games.py`: Naive Bayes on per-move evaluation given a number of games. Training data comes from `cp_loss_hist_per_move`, test data from `cp_loss_hist_per_move_per_game_count`. Setting the number of games to 1 is equivalent to `sweep_moves_per_game.py`.
+
+`train_cploss_per_game.py`: Baseline using a simple neural network with 2 fully-connected layers. Trains on each game and also evaluates per-game accuracy. **This currently gives NaN values when training on 30 players.**
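
The Euclidean-distance baseline described in this README can be made concrete with the minimal sketch below; it is not part of the repository. It assumes the `<player>_<split>.npy` layout produced by `save_npy` in `get_cp_loss.py` (shown below), loads each player's normalized training histogram, and assigns a held-out test histogram to the nearest training histogram. The player names are hypothetical.

```python
import os
import numpy as np

def load_hist(saved_dir, player, split):
    # get_cp_loss.py saves one normalized 50-bin histogram per player and split
    # as "<player>_<split>.npy", where split is 'train', 'validation', or 'test'
    return np.load(os.path.join(saved_dir, f"{player}_{split}.npy"))

def nearest_player(query_hist, train_hists):
    # Euclidean-distance baseline: predict the player whose training histogram
    # is closest to the query histogram
    distances = {name: np.linalg.norm(query_hist - hist)
                 for name, hist in train_hists.items()}
    return min(distances, key=distances.get)

if __name__ == "__main__":
    saved_dir = "cp_loss_hist"            # directory written by get_cp_loss.py
    players = ["player_a", "player_b"]    # hypothetical player names
    train_hists = {p: load_hist(saved_dir, p, "train") for p in players}
    correct = 0
    for p in players:
        guess = nearest_player(load_hist(saved_dir, p, "test"), train_hists)
        correct += int(guess == p)
    print(f"accuracy: {correct / len(players):.2%}")
```

The Naive Bayes variant mentioned above would replace `nearest_player` with a classifier fit on the same histograms; the data-loading path is identical.
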
diff --git a/4-cp_loss_stylo_baseline/get_cp_loss.py b/4-cp_loss_stylo_baseline/get_cp_loss.py
new file mode 100755
index 0000000..1a7d274
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/get_cp_loss.py
@@ -0,0 +1,146 @@
+
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+# import matplotlib
+# matplotlib.use('TkAgg')
+import matplotlib.pyplot as plt
+import multiprocessing
+from functools import partial
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--input_dir', default='/data/transfer_players_validate')
+ parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy')
+ parser.add_argument('--saved_dir', default='cp_loss_hist')
+ parser.add_argument('--will_save', default=True)
+
+ return parser.parse_args()
+
+def normalize(data):
+ norm = np.linalg.norm(data)
+ data_norm = data/norm
+ return data_norm
+
+def prepare_dataset(players, player_name, cp_hist, dataset):
+ # add up black and white games (counts can be directly added)
+ if players[player_name][dataset] is None:
+ players[player_name][dataset] = cp_hist
+ else:
+ players[player_name][dataset] = players[player_name][dataset] + cp_hist
+
+def save_npy(saved_dir, players, player_name, dataset):
+ if not os.path.exists(saved_dir):
+ os.mkdir(saved_dir)
+
+ saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset))
+ print('saving data to {}'.format(saved))
+ np.save(saved, players[player_name][dataset])
+
+def multi_parse(input_dir, saved_dir, players, save, player_name):
+
+ print("=============================================")
+ print("parsing data for {}".format(player_name))
+ players[player_name] = {'train': None, 'validation': None, 'test': None}
+
+ csv_dir = os.path.join(input_dir, player_name, 'csvs')
+ # for each csv, add up black and white games (counts can be directly added)
+ for csv_fname in os.listdir(csv_dir):
+ path = os.path.join(csv_dir, csv_fname)
+ # parse bz2 file
+ source_file = bz2.BZ2File(path, "r")
+ cp_hist, num_games = get_cp_loss_from_csv(player_name, source_file)
+ print(path)
+
+ if csv_fname.startswith('train'):
+ prepare_dataset(players, player_name, cp_hist, 'train')
+
+ elif csv_fname.startswith('validate'):
+ prepare_dataset(players, player_name, cp_hist, 'validation')
+
+ elif csv_fname.startswith('test'):
+ prepare_dataset(players, player_name, cp_hist, 'test')
+
+ # normalize the histogram to range [0, 1]
+ players[player_name]['train'] = normalize(players[player_name]['train'])
+ players[player_name]['validation'] = normalize(players[player_name]['validation'])
+ players[player_name]['test'] = normalize(players[player_name]['test'])
+
+ # save for future use, parsing takes too long...
+ if save:
+ save_npy(saved_dir, players, player_name, 'train')
+ save_npy(saved_dir, players, player_name, 'validation')
+ save_npy(saved_dir, players, player_name, 'test')
+
+def construct_datasets(player_names, input_dir, saved_dir, will_save):
+ players = {}
+
+ pool = multiprocessing.Pool(25)
+ func = partial(multi_parse, input_dir, saved_dir, players, will_save)
+ pool.map(func, player_names)
+ pool.close()
+ pool.join()
+
+def get_cp_loss_from_csv(player_name, path):
+ cp_losses = []
+ games = {}
+    # `path` is an already-open bz2.BZ2File handle, so iterate over its decompressed
+    # byte lines directly; each data line is decoded from UTF-8 below.
+    for i, line in enumerate(path):
+ if i > 0:
+ line = line.decode("utf-8")
+ row = line.rstrip().split(',')
+ # avoid empty line
+ if row[0] == '':
+ continue
+
+ game_id = row[0]
+ cp_loss = row[17]
+ active_player = row[25]
+ if player_name != active_player:
+ continue
+
+ # ignore cases like -inf, inf, nan
+ if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan':
+ # append cp loss per move
+ cp_losses.append(float(cp_loss))
+
+ # for purpose of counting how many games
+ if game_id not in games:
+ games[game_id] = 1
+
+
+ ######################## plot for viewing ########################
+ # plt.hist(cp_losses, density=False, bins=50)
+ # plt.ylabel('Count')
+ # plt.xlabel('Cp Loss')
+ # plt.show()
+
+ cp_hist = np.histogram(cp_losses, density=False, bins=50, range=(0, 5)) # density=False for counts
+
+ cp_hist = cp_hist[0] # cp_hist in format of (hist count, range)
+
+ print("number of games: {}".format(len(games)))
+
+ return cp_hist, len(games)
+
+
+def get_player_names(player_name_dir):
+
+ player_names = []
+ for player_name in os.listdir(player_name_dir):
+ player = player_name.replace("_unfrozen_copy", "")
+ player_names.append(player)
+
+ # print(player_names)
+ return player_names
+
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_names = get_player_names(args.player_name_dir)
+
+ construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save)
diff --git a/4-cp_loss_stylo_baseline/get_cp_loss_per_game.py b/4-cp_loss_stylo_baseline/get_cp_loss_per_game.py
new file mode 100755
index 0000000..2dd0896
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/get_cp_loss_per_game.py
@@ -0,0 +1,132 @@
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+# import matplotlib
+# matplotlib.use('TkAgg')
+import matplotlib.pyplot as plt
+import multiprocessing
+from functools import partial
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--input_dir', default='/data/transfer_players_validate')
+ parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy')
+ parser.add_argument('--saved_dir', default='cp_loss_count_per_game')
+ parser.add_argument('--will_save', default=True)
+
+ return parser.parse_args()
+
+def normalize(data):
+ norm = np.linalg.norm(data)
+ data_norm = data/norm
+ return data_norm
+
+def prepare_dataset(players, player_name, games, dataset):
+ # add up black and white games
+ if players[player_name][dataset] is None:
+ players[player_name][dataset] = games
+ else:
+ players[player_name][dataset].update(games)
+
+def save_npy(saved_dir, players, player_name, dataset):
+ if not os.path.exists(saved_dir):
+ os.mkdir(saved_dir)
+
+ saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset))
+ print('saving data to {}'.format(saved))
+ print('total number of games: {}'.format(len(players[player_name][dataset])))
+ np.save(saved, players[player_name][dataset])
+
+def multi_parse(input_dir, saved_dir, players, save, player_name):
+
+ print("=============================================")
+ print("parsing data for {}".format(player_name))
+ players[player_name] = {'train': None, 'validation': None, 'test': None}
+
+ csv_dir = os.path.join(input_dir, player_name, 'csvs')
+ for csv_fname in os.listdir(csv_dir):
+ path = os.path.join(csv_dir, csv_fname)
+ print(path)
+ source_file = bz2.BZ2File(path, "r")
+ games, num_games = get_cp_loss_from_csv(player_name, source_file)
+
+ if csv_fname.startswith('train'):
+ prepare_dataset(players, player_name, games, 'train')
+
+ elif csv_fname.startswith('validate'):
+ prepare_dataset(players, player_name, games, 'validation')
+
+ elif csv_fname.startswith('test'):
+ prepare_dataset(players, player_name, games, 'test')
+
+ if save:
+ save_npy(saved_dir, players, player_name, 'train')
+ save_npy(saved_dir, players, player_name, 'validation')
+ save_npy(saved_dir, players, player_name, 'test')
+
+def construct_datasets(player_names, input_dir, saved_dir, will_save):
+ players = {}
+
+ pool = multiprocessing.Pool(25)
+ func = partial(multi_parse, input_dir, saved_dir, players, will_save)
+ pool.map(func, player_names)
+ pool.close()
+ pool.join()
+
+def get_cp_loss_from_csv(player_name, path):
+ cp_losses = []
+ games = {}
+    # `path` is an already-open bz2.BZ2File handle, so iterate over its decompressed
+    # byte lines directly; each data line is decoded from UTF-8 below.
+    for i, line in enumerate(path):
+ if i > 0:
+ line = line.decode("utf-8")
+ row = line.rstrip().split(',')
+ # avoid empty line
+ if row[0] == '':
+ continue
+
+ game_id = row[0]
+ cp_loss = row[17]
+ active_player = row[25]
+ if player_name != active_player:
+ continue
+
+ if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan':
+ cp_losses.append(float(cp_loss))
+
+ # for purpose of counting how many games
+ if game_id not in games:
+ games[game_id] = [float(cp_loss)]
+ else:
+ games[game_id].append(float(cp_loss))
+
+ # get per game histogram
+ for key, value in games.items():
+ games[key] = np.histogram(games[key], density=False, bins=50, range=(0, 5))
+ # cp_hist in format (hist, range)
+ games[key] = games[key][0]
+ # games[key] = normalize(games[key])
+
+ print("number of games: {}".format(len(games)))
+
+ return games, len(games)
+
+def get_player_names(player_name_dir):
+
+ player_names = []
+ for player_name in os.listdir(player_name_dir):
+ player = player_name.replace("_unfrozen_copy", "")
+ player_names.append(player)
+
+ # print(player_names)
+ return player_names
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_names = get_player_names(args.player_name_dir)
+
+ construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save)
diff --git a/4-cp_loss_stylo_baseline/get_cp_loss_per_move.py b/4-cp_loss_stylo_baseline/get_cp_loss_per_move.py
new file mode 100755
index 0000000..336848f
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/get_cp_loss_per_move.py
@@ -0,0 +1,167 @@
+
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+# import matplotlib
+# matplotlib.use('TkAgg')
+import matplotlib.pyplot as plt
+import multiprocessing
+from functools import partial
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--input_dir', default='/data/transfer_players_validate')
+ parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy')
+ parser.add_argument('--saved_dir', default='cp_loss_hist_per_move')
+ parser.add_argument('--will_save', default=True)
+
+ return parser.parse_args()
+
+def normalize(data):
+ for i in range(101):
+ # update in-place
+ if any(v != 0 for v in data['start_after'][i]):
+ start_after_norm = np.linalg.norm(data['start_after'][i])
+ data['start_after'][i] = data['start_after'][i] / start_after_norm
+ if any(v != 0 for v in data['stop_after'][i]):
+ stop_after_norm = np.linalg.norm(data['stop_after'][i])
+ data['stop_after'][i] = data['stop_after'][i] / stop_after_norm
+
+
+def prepare_dataset(players, player_name, cp_loss_hist_dict, dataset):
+ # add up black and white games (counts can be directly added)
+ if players[player_name][dataset] is None:
+ players[player_name][dataset] = cp_loss_hist_dict
+ else:
+ # add up each move
+ for i in range(101):
+ players[player_name][dataset]['start_after'][i] = players[player_name][dataset]['start_after'][i] + cp_loss_hist_dict['start_after'][i]
+ players[player_name][dataset]['stop_after'][i] = players[player_name][dataset]['stop_after'][i] + cp_loss_hist_dict['stop_after'][i]
+
+def save_npy(saved_dir, players, player_name, dataset):
+ if not os.path.exists(saved_dir):
+ os.mkdir(saved_dir)
+
+ saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset))
+ print('saving data to {}'.format(saved))
+ np.save(saved, players[player_name][dataset])
+
+def multi_parse(input_dir, saved_dir, players, save, player_name):
+
+ print("=============================================")
+ print("parsing data for {}".format(player_name))
+ players[player_name] = {'train': None, 'validation': None, 'test': None}
+
+ csv_dir = os.path.join(input_dir, player_name, 'csvs')
+ # for each csv, add up black and white games (counts can be directly added)
+ for csv_fname in os.listdir(csv_dir):
+ path = os.path.join(csv_dir, csv_fname)
+ # parse bz2 file
+ source_file = bz2.BZ2File(path, "r")
+ cp_loss_hist_dict, num_games = get_cp_loss_from_csv(player_name, source_file)
+ print(path)
+
+ if csv_fname.startswith('train'):
+ prepare_dataset(players, player_name, cp_loss_hist_dict, 'train')
+
+ elif csv_fname.startswith('validate'):
+ prepare_dataset(players, player_name, cp_loss_hist_dict, 'validation')
+
+ elif csv_fname.startswith('test'):
+ prepare_dataset(players, player_name, cp_loss_hist_dict, 'test')
+
+ # normalize the histogram to range [0, 1]
+ normalize(players[player_name]['train'])
+ normalize(players[player_name]['validation'])
+ normalize(players[player_name]['test'])
+
+ # save for future use, parsing takes too long...
+ if save:
+ save_npy(saved_dir, players, player_name, 'train')
+ save_npy(saved_dir, players, player_name, 'validation')
+ save_npy(saved_dir, players, player_name, 'test')
+
+def construct_datasets(player_names, input_dir, saved_dir, will_save):
+ players = {}
+
+ pool = multiprocessing.Pool(25)
+ func = partial(multi_parse, input_dir, saved_dir, players, will_save)
+ pool.map(func, player_names)
+ pool.close()
+ pool.join()
+
+def get_cp_loss_start_after(cp_losses, move_start=0):
+ return cp_losses[move_start:]
+
+# 0 will be empty
+def get_cp_loss_stop_after(cp_losses, move_stop=100):
+ return cp_losses[:move_stop]
+
+# move_stop is in range [0, 100]
+def get_cp_loss_from_csv(player_name, path):
+ cp_loss_hist_dict = {'start_after': {}, 'stop_after': {}}
+ # cp_losses = []
+ games = {}
+    # `path` is an already-open bz2.BZ2File handle, so iterate over its decompressed
+    # byte lines directly; each data line is decoded from UTF-8 below.
+    for i, line in enumerate(path):
+ if i > 0:
+ line = line.decode("utf-8")
+ row = line.rstrip().split(',')
+ # avoid empty line
+ if row[0] == '':
+ continue
+
+ active_player = row[25]
+ if player_name != active_player:
+ continue
+
+ # move_ply starts from 0, need to add 1, move will be parsed in order
+ move_ply = int(row[13])
+ move = move_ply // 2 + 1
+
+ game_id = row[0]
+ cp_loss = row[17]
+
+ # ignore cases like -inf, inf, nan
+ if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan':
+ if game_id not in games:
+ games[game_id] = [float(cp_loss)]
+ else:
+ games[game_id].append(float(cp_loss))
+
+ # get per game histogram
+ for i in range(101):
+ cp_loss_hist_dict['start_after'][i] = []
+ cp_loss_hist_dict['stop_after'][i] = []
+ for key, value in games.items():
+ cp_loss_hist_dict['start_after'][i].extend(get_cp_loss_start_after(value, i))
+ cp_loss_hist_dict['stop_after'][i].extend(get_cp_loss_stop_after(value, i))
+
+ # transform into counts
+ cp_loss_hist_dict['start_after'][i] = np.histogram(cp_loss_hist_dict['start_after'][i], density=False, bins=50, range=(0, 5))[0]
+ cp_loss_hist_dict['stop_after'][i] = np.histogram(cp_loss_hist_dict['stop_after'][i], density=False, bins=50, range=(0, 5))[0]
+
+
+ print("number of games: {}".format(len(games)))
+
+ return cp_loss_hist_dict, len(games)
+
+def get_player_names(player_name_dir):
+
+ player_names = []
+ for player_name in os.listdir(player_name_dir):
+ player = player_name.replace("_unfrozen_copy", "")
+ player_names.append(player)
+
+ # print(player_names)
+ return player_names
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_names = get_player_names(args.player_name_dir)
+
+ construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save)
diff --git a/4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game.py b/4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game.py
new file mode 100755
index 0000000..84c7a5f
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game.py
@@ -0,0 +1,160 @@
+
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+# import matplotlib
+# matplotlib.use('TkAgg')
+import matplotlib.pyplot as plt
+import multiprocessing
+from functools import partial
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--input_dir', default='/data/transfer_players_validate')
+ parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy')
+ parser.add_argument('--saved_dir', default='cp_loss_hist_per_move_per_game')
+ parser.add_argument('--will_save', default=True)
+
+ return parser.parse_args()
+
+def normalize(data):
+ data_norm = data
+
+ if any(v != 0 for v in data):
+ norm = np.linalg.norm(data)
+ data_norm = data/norm
+
+ return data_norm
+
+def prepare_dataset(players, player_name, games, dataset):
+ # add up black and white games
+ if players[player_name][dataset] is None:
+ players[player_name][dataset] = games
+ else:
+ players[player_name][dataset].update(games)
+
+
+def save_npy(saved_dir, players, player_name, dataset):
+ if not os.path.exists(saved_dir):
+ os.mkdir(saved_dir)
+
+ saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset))
+ print('saving data to {}'.format(saved))
+ np.save(saved, players[player_name][dataset])
+
+def multi_parse(input_dir, saved_dir, players, save, player_name):
+
+ print("=============================================")
+ print("parsing data for {}".format(player_name))
+ players[player_name] = {'train': None, 'validation': None, 'test': None}
+
+ csv_dir = os.path.join(input_dir, player_name, 'csvs')
+ # for each csv, add up black and white games (counts can be directly added)
+ for csv_fname in os.listdir(csv_dir):
+ path = os.path.join(csv_dir, csv_fname)
+ print(path)
+ # parse bz2 file
+ source_file = bz2.BZ2File(path, "r")
+ games, num_games = get_cp_loss_from_csv(player_name, source_file)
+
+ if csv_fname.startswith('train'):
+ prepare_dataset(players, player_name, games, 'train')
+
+ elif csv_fname.startswith('validate'):
+ prepare_dataset(players, player_name, games, 'validation')
+
+ elif csv_fname.startswith('test'):
+ prepare_dataset(players, player_name, games, 'test')
+
+ # save for future use, parsing takes too long...
+ if save:
+ save_npy(saved_dir, players, player_name, 'train')
+ save_npy(saved_dir, players, player_name, 'validation')
+ save_npy(saved_dir, players, player_name, 'test')
+
+def construct_datasets(player_names, input_dir, saved_dir, will_save):
+ players = {}
+
+ pool = multiprocessing.Pool(25)
+ func = partial(multi_parse, input_dir, saved_dir, players, will_save)
+ pool.map(func, player_names)
+ pool.close()
+ pool.join()
+
+
+def get_cp_loss_start_after(cp_losses, move_start=0):
+ return cp_losses[move_start:]
+
+# 0 will be empty
+def get_cp_loss_stop_after(cp_losses, move_stop=100):
+ return cp_losses[:move_stop]
+
+# move_stop is in range [0, 100]
+def get_cp_loss_from_csv(player_name, path):
+ # cp_losses = []
+ games = {}
+    # `path` is an already-open bz2.BZ2File handle, so iterate over its decompressed
+    # byte lines directly; each data line is decoded from UTF-8 below.
+    for i, line in enumerate(path):
+ if i > 0:
+ line = line.decode("utf-8")
+ row = line.rstrip().split(',')
+ # avoid empty line
+ if row[0] == '':
+ continue
+
+ active_player = row[25]
+ if player_name != active_player:
+ continue
+
+ # move_ply starts from 0, need to add 1, move will be parsed in order
+ move_ply = int(row[13])
+ move = move_ply // 2 + 1
+
+ game_id = row[0]
+ cp_loss = row[17]
+
+ # ignore cases like -inf, inf, nan
+ if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan':
+ if game_id not in games:
+ games[game_id] = [float(cp_loss)]
+ else:
+ games[game_id].append(float(cp_loss))
+
+ final_games = {key: {'start_after': {}, 'stop_after': {}} for key in games.keys()}
+
+ # get per game histogram
+ for i in range(101):
+ for key, value in games.items():
+ final_games[key]['start_after'][i] = get_cp_loss_start_after(value, i)
+ final_games[key]['stop_after'][i] = get_cp_loss_stop_after(value, i)
+
+ final_games[key]['start_after'][i] = np.histogram(final_games[key]['start_after'][i], density=False, bins=50, range=(0, 5))[0]
+ final_games[key]['stop_after'][i] = np.histogram(final_games[key]['stop_after'][i], density=False, bins=50, range=(0, 5))[0]
+
+ final_games[key]['start_after'][i] = normalize(final_games[key]['start_after'][i])
+ final_games[key]['stop_after'][i] = normalize(final_games[key]['stop_after'][i])
+
+ print("number of games: {}".format(len(games)))
+
+ return final_games, len(games)
+
+
+def get_player_names(player_name_dir):
+
+ player_names = []
+ for player_name in os.listdir(player_name_dir):
+ player = player_name.replace("_unfrozen_copy", "")
+ player_names.append(player)
+
+ # print(player_names)
+ return player_names
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_names = get_player_names(args.player_name_dir)
+
+ construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save)
diff --git a/4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game_count.py b/4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game_count.py
new file mode 100755
index 0000000..bbed7f5
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/get_cp_loss_per_move_per_game_count.py
@@ -0,0 +1,168 @@
+
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+# import matplotlib
+# matplotlib.use('TkAgg')
+import matplotlib.pyplot as plt
+import multiprocessing
+from functools import partial
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--input_dir', default='/data/transfer_players_validate')
+ parser.add_argument('--player_name_dir', default='../transfer_training/final_models_val/unfrozen_copy')
+ parser.add_argument('--saved_dir', default='cp_loss_hist_per_move_per_game_count')
+ parser.add_argument('--will_save', default=True)
+
+ return parser.parse_args()
+
+def normalize(data):
+ data_norm = data
+
+ if any(v != 0 for v in data):
+ norm = np.linalg.norm(data)
+ data_norm = data/norm
+
+ return data_norm
+
+def prepare_dataset(players, player_name, games, dataset):
+ # add up black and white games
+ if players[player_name][dataset] is None:
+ players[player_name][dataset] = games
+ else:
+ players[player_name][dataset].update(games)
+
+
+def save_npy(saved_dir, players, player_name, dataset):
+ if not os.path.exists(saved_dir):
+ os.mkdir(saved_dir)
+
+ saved = os.path.join(saved_dir, player_name + '_{}.npy'.format(dataset))
+ print('saving data to {}'.format(saved))
+ np.save(saved, players[player_name][dataset])
+
+def multi_parse(input_dir, saved_dir, players, save, player_name):
+
+ print("=============================================")
+ print("parsing data for {}".format(player_name))
+ players[player_name] = {'train': None, 'validation': None, 'test': None}
+
+ csv_dir = os.path.join(input_dir, player_name, 'csvs')
+ # for each csv, add up black and white games (counts can be directly added)
+ for csv_fname in os.listdir(csv_dir):
+ path = os.path.join(csv_dir, csv_fname)
+ print(path)
+ # parse bz2 file
+ source_file = bz2.BZ2File(path, "r")
+ games, num_games = get_cp_loss_from_csv(player_name, source_file)
+
+ if csv_fname.startswith('train'):
+ prepare_dataset(players, player_name, games, 'train')
+
+ elif csv_fname.startswith('validate'):
+ prepare_dataset(players, player_name, games, 'validation')
+
+ elif csv_fname.startswith('test'):
+ prepare_dataset(players, player_name, games, 'test')
+
+ # save for future use, parsing takes too long...
+ if save:
+ save_npy(saved_dir, players, player_name, 'train')
+ save_npy(saved_dir, players, player_name, 'validation')
+ save_npy(saved_dir, players, player_name, 'test')
+
+def construct_datasets(player_names, input_dir, saved_dir, will_save):
+ players = {}
+
+ pool = multiprocessing.Pool(40)
+ func = partial(multi_parse, input_dir, saved_dir, players, will_save)
+ pool.map(func, player_names)
+ pool.close()
+ pool.join()
+
+
+def get_cp_loss_start_after(cp_losses, move_start=0):
+ return cp_losses[move_start:]
+
+# 0 will be empty
+def get_cp_loss_stop_after(cp_losses, move_stop=100):
+ return cp_losses[:move_stop]
+
+# move_stop is in range [0, 100]
+def get_cp_loss_from_csv(player_name, path):
+ # cp_losses = []
+ games = {}
+    # `path` is an already-open bz2.BZ2File handle, so iterate over its decompressed
+    # byte lines directly; each data line is decoded from UTF-8 below.
+    for i, line in enumerate(path):
+ if i > 0:
+ line = line.decode("utf-8")
+ row = line.rstrip().split(',')
+ # avoid empty line
+ if row[0] == '':
+ continue
+
+ active_player = row[25]
+ if player_name != active_player:
+ continue
+
+ # move_ply starts from 0, need to add 1, move will be parsed in order
+ move_ply = int(row[13])
+ move = move_ply // 2 + 1
+
+ game_id = row[0]
+ cp_loss = row[17]
+
+ if game_id in games:
+ if cp_loss == str(-1 * np.inf) or cp_loss == str(np.inf) or cp_loss == 'nan':
+ cp_loss = float(-100)
+
+ # ignore cases like -inf, inf, nan
+ if cp_loss != str(-1 * np.inf) and cp_loss != str(np.inf) and cp_loss != 'nan':
+ if game_id not in games:
+ games[game_id] = [float(cp_loss)]
+ else:
+ games[game_id].append(float(cp_loss))
+
+ final_games = {key: {'start_after': {}, 'stop_after': {}} for key, value in games.items() if len(value) > 25 and len(value) < 50}
+ # final_games = {key: {'start_after': {}, 'stop_after': {}} for key in games.keys()}
+
+ # get per game histogram
+ for i in range(101):
+ for key, value in games.items():
+ if len(value) > 25 and len(value) < 50:
+ if key not in final_games:
+ print(key)
+
+ final_games[key]['start_after'][i] = get_cp_loss_start_after(value, i)
+ final_games[key]['stop_after'][i] = get_cp_loss_stop_after(value, i)
+
+ final_games[key]['start_after'][i] = np.histogram(final_games[key]['start_after'][i], density=False, bins=50, range=(0, 5))[0]
+ final_games[key]['stop_after'][i] = np.histogram(final_games[key]['stop_after'][i], density=False, bins=50, range=(0, 5))[0]
+
+ # final_games[key]['start_after'][i] = normalize(final_games[key]['start_after'][i])
+ # final_games[key]['stop_after'][i] = normalize(final_games[key]['stop_after'][i])
+
+ print("number of games: {}".format(len(games)))
+ return final_games, len(games)
+
+
+def get_player_names(player_name_dir):
+
+ player_names = []
+ for player_name in os.listdir(player_name_dir):
+ player = player_name.replace("_unfrozen_copy", "")
+ player_names.append(player)
+
+ # print(player_names)
+ return player_names
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_names = get_player_names(args.player_name_dir)
+
+ construct_datasets(player_names, args.input_dir, args.saved_dir, args.will_save)
diff --git a/4-cp_loss_stylo_baseline/results/games_accuracy.csv b/4-cp_loss_stylo_baseline/results/games_accuracy.csv
new file mode 100755
index 0000000..865021d
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results/games_accuracy.csv
@@ -0,0 +1,6 @@
+num_games,accuracy
+1,0.05579916684937514
+2,0.06811689738519007
+4,0.08245149911816578
+8,0.11002661934338953
+16,0.1223021582733813
diff --git a/4-cp_loss_stylo_baseline/results/start_after.csv b/4-cp_loss_stylo_baseline/results/start_after.csv
new file mode 100755
index 0000000..4156ac6
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results/start_after.csv
@@ -0,0 +1,102 @@
+move,accuracy
+0,0.059309021113243765
+1,0.05508637236084453
+2,0.053358925143953934
+3,0.052207293666026874
+4,0.050287907869481764
+5,0.054318618042226485
+6,0.05182341650671785
+7,0.05105566218809981
+8,0.05259117082533589
+9,0.05356114417354579
+10,0.05586484929928969
+11,0.05682472643501632
+12,0.05779569892473118
+13,0.05915114269252929
+14,0.05859750240153699
+15,0.05573707476455891
+16,0.05714835482008851
+17,0.05772200772200772
+18,0.05515773175924134
+19,0.05539358600583091
+20,0.05190989226248776
+21,0.05554457402648745
+22,0.05923836389280677
+23,0.06203627370156636
+24,0.0579647917561185
+25,0.053575482406356414
+26,0.053108174253548704
+27,0.05705944798301486
+28,0.061177152797912436
+29,0.048
+30,0.05115452930728242
+31,0.045636509207365894
+32,0.040467625899280574
+33,0.04314329738058552
+34,0.03882915173237754
+35,0.03529411764705882
+36,0.0575831305758313
+37,0.04664723032069971
+38,0.05052878965922444
+39,0.04850213980028531
+40,0.06204379562043796
+41,0.06697459584295612
+42,0.06382978723404255
+43,0.05761316872427984
+44,0.08928571428571429
+45,0.13274336283185842
+46,0.1506849315068493
+47,0.2571428571428571
+48,0.3076923076923077
+49,0
+50,0
+51,0
+52,0
+53,0
+54,0
+55,0
+56,0
+57,0
+58,0
+59,0
+60,0
+61,0
+62,0
+63,0
+64,0
+65,0
+66,0
+67,0
+68,0
+69,0
+70,0
+71,0
+72,0
+73,0
+74,0
+75,0
+76,0
+77,0
+78,0
+79,0
+80,0
+81,0
+82,0
+83,0
+84,0
+85,0
+86,0
+87,0
+88,0
+89,0
+90,0
+91,0
+92,0
+93,0
+94,0
+95,0
+96,0
+97,0
+98,0
+99,0
+100,0
diff --git a/4-cp_loss_stylo_baseline/results/start_after_all_game.csv b/4-cp_loss_stylo_baseline/results/start_after_all_game.csv
new file mode 100755
index 0000000..c4c59b7
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results/start_after_all_game.csv
@@ -0,0 +1,102 @@
+move,accuracy
+0,0.6333333333333333
+1,0.6666666666666666
+2,0.5666666666666667
+3,0.5666666666666667
+4,0.5333333333333333
+5,0.5333333333333333
+6,0.5666666666666667
+7,0.5333333333333333
+8,0.5
+9,0.5666666666666667
+10,0.5
+11,0.4666666666666667
+12,0.4666666666666667
+13,0.5
+14,0.5
+15,0.43333333333333335
+16,0.3
+17,0.23333333333333334
+18,0.3
+19,0.3333333333333333
+20,0.3
+21,0.36666666666666664
+22,0.3333333333333333
+23,0.26666666666666666
+24,0.3
+25,0.23333333333333334
+26,0.26666666666666666
+27,0.36666666666666664
+28,0.26666666666666666
+29,0.3
+30,0.4666666666666667
+31,0.43333333333333335
+32,0.36666666666666664
+33,0.36666666666666664
+34,0.36666666666666664
+35,0.3333333333333333
+36,0.4
+37,0.36666666666666664
+38,0.4
+39,0.4
+40,0.4
+41,0.23333333333333334
+42,0.3333333333333333
+43,0.2857142857142857
+44,0.3333333333333333
+45,0.2222222222222222
+46,0.28
+47,0.5
+48,0.3333333333333333
+49,0
+50,0
+51,0
+52,0
+53,0
+54,0
+55,0
+56,0
+57,0
+58,0
+59,0
+60,0
+61,0
+62,0
+63,0
+64,0
+65,0
+66,0
+67,0
+68,0
+69,0
+70,0
+71,0
+72,0
+73,0
+74,0
+75,0
+76,0
+77,0
+78,0
+79,0
+80,0
+81,0
+82,0
+83,0
+84,0
+85,0
+86,0
+87,0
+88,0
+89,0
+90,0
+91,0
+92,0
+93,0
+94,0
+95,0
+96,0
+97,0
+98,0
+99,0
+100,0
diff --git a/4-cp_loss_stylo_baseline/results/stop_after.csv b/4-cp_loss_stylo_baseline/results/stop_after.csv
new file mode 100755
index 0000000..26194d8
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results/stop_after.csv
@@ -0,0 +1,102 @@
+move,accuracy
+0,0.01689059500959693
+1,0.10134357005758157
+2,0.07044145873320537
+3,0.06967370441458733
+4,0.0783109404990403
+5,0.06756238003838771
+6,0.06641074856046066
+7,0.06602687140115163
+8,0.06506717850287908
+9,0.06641074856046066
+10,0.06238003838771593
+11,0.05911708253358925
+12,0.0581573896353167
+13,0.059309021113243765
+14,0.0600767754318618
+15,0.06065259117082534
+16,0.06218809980806142
+17,0.06065259117082534
+18,0.061420345489443376
+19,0.061036468330134354
+20,0.060460652591170824
+21,0.061420345489443376
+22,0.060844529750479846
+23,0.06161228406909789
+24,0.0600767754318618
+25,0.05969289827255278
+26,0.060844529750479846
+27,0.06238003838771593
+28,0.05969289827255278
+29,0.05969289827255278
+30,0.06333973128598848
+31,0.061036468330134354
+32,0.061036468330134354
+33,0.06161228406909789
+34,0.060844529750479846
+35,0.05854126679462572
+36,0.05969289827255278
+37,0.0600767754318618
+38,0.059309021113243765
+39,0.05892514395393474
+40,0.059884836852207295
+41,0.05892514395393474
+42,0.05911708253358925
+43,0.05873320537428023
+44,0.05873320537428023
+45,0.05873320537428023
+46,0.05854126679462572
+47,0.0581573896353167
+48,0.05873320537428023
+49,0.059309021113243765
+50,0.059309021113243765
+51,0.059309021113243765
+52,0.059309021113243765
+53,0.059309021113243765
+54,0.059309021113243765
+55,0.059309021113243765
+56,0.059309021113243765
+57,0.059309021113243765
+58,0.059309021113243765
+59,0.059309021113243765
+60,0.059309021113243765
+61,0.059309021113243765
+62,0.059309021113243765
+63,0.059309021113243765
+64,0.059309021113243765
+65,0.059309021113243765
+66,0.059309021113243765
+67,0.059309021113243765
+68,0.059309021113243765
+69,0.059309021113243765
+70,0.059309021113243765
+71,0.059309021113243765
+72,0.059309021113243765
+73,0.059309021113243765
+74,0.059309021113243765
+75,0.059309021113243765
+76,0.059309021113243765
+77,0.059309021113243765
+78,0.059309021113243765
+79,0.059309021113243765
+80,0.059309021113243765
+81,0.059309021113243765
+82,0.059309021113243765
+83,0.059309021113243765
+84,0.059309021113243765
+85,0.059309021113243765
+86,0.059309021113243765
+87,0.059309021113243765
+88,0.059309021113243765
+89,0.059309021113243765
+90,0.059309021113243765
+91,0.059309021113243765
+92,0.059309021113243765
+93,0.059309021113243765
+94,0.059309021113243765
+95,0.059309021113243765
+96,0.059309021113243765
+97,0.059309021113243765
+98,0.059309021113243765
+99,0.059309021113243765
+100,0.059309021113243765
diff --git a/4-cp_loss_stylo_baseline/results/stop_after_all_game.csv b/4-cp_loss_stylo_baseline/results/stop_after_all_game.csv
new file mode 100755
index 0000000..7d2a525
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results/stop_after_all_game.csv
@@ -0,0 +1,102 @@
+move,accuracy
+0,0.03333333333333333
+1,0.16666666666666666
+2,0.5333333333333333
+3,0.6333333333333333
+4,0.5
+5,0.5
+6,0.5
+7,0.43333333333333335
+8,0.43333333333333335
+9,0.43333333333333335
+10,0.5
+11,0.43333333333333335
+12,0.4666666666666667
+13,0.5
+14,0.5333333333333333
+15,0.5
+16,0.5333333333333333
+17,0.5666666666666667
+18,0.5333333333333333
+19,0.5666666666666667
+20,0.6
+21,0.5333333333333333
+22,0.5666666666666667
+23,0.5333333333333333
+24,0.6
+25,0.6
+26,0.6
+27,0.6666666666666666
+28,0.6333333333333333
+29,0.6333333333333333
+30,0.6333333333333333
+31,0.6666666666666666
+32,0.6666666666666666
+33,0.6666666666666666
+34,0.6
+35,0.6
+36,0.6
+37,0.6
+38,0.5666666666666667
+39,0.6
+40,0.6
+41,0.6
+42,0.6333333333333333
+43,0.6
+44,0.6333333333333333
+45,0.6333333333333333
+46,0.6333333333333333
+47,0.6333333333333333
+48,0.6333333333333333
+49,0.6333333333333333
+50,0.6333333333333333
+51,0.6333333333333333
+52,0.6333333333333333
+53,0.6333333333333333
+54,0.6333333333333333
+55,0.6333333333333333
+56,0.6333333333333333
+57,0.6333333333333333
+58,0.6333333333333333
+59,0.6333333333333333
+60,0.6333333333333333
+61,0.6333333333333333
+62,0.6333333333333333
+63,0.6333333333333333
+64,0.6333333333333333
+65,0.6333333333333333
+66,0.6333333333333333
+67,0.6333333333333333
+68,0.6333333333333333
+69,0.6333333333333333
+70,0.6333333333333333
+71,0.6333333333333333
+72,0.6333333333333333
+73,0.6333333333333333
+74,0.6333333333333333
+75,0.6333333333333333
+76,0.6333333333333333
+77,0.6333333333333333
+78,0.6333333333333333
+79,0.6333333333333333
+80,0.6333333333333333
+81,0.6333333333333333
+82,0.6333333333333333
+83,0.6333333333333333
+84,0.6333333333333333
+85,0.6333333333333333
+86,0.6333333333333333
+87,0.6333333333333333
+88,0.6333333333333333
+89,0.6333333333333333
+90,0.6333333333333333
+91,0.6333333333333333
+92,0.6333333333333333
+93,0.6333333333333333
+94,0.6333333333333333
+95,0.6333333333333333
+96,0.6333333333333333
+97,0.6333333333333333
+98,0.6333333333333333
+99,0.6333333333333333
+100,0.6333333333333333
diff --git a/4-cp_loss_stylo_baseline/results_validation/games_accuracy.csv b/4-cp_loss_stylo_baseline/results_validation/games_accuracy.csv
new file mode 100755
index 0000000..02a6e93
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results_validation/games_accuracy.csv
@@ -0,0 +1,6 @@
+num_games,accuracy
+1,0.005322294500295683
+2,0.007152042305413879
+4,0.00927616894222686
+8,0.013657853265627736
+16,0.020702709097361882
diff --git a/4-cp_loss_stylo_baseline/results_validation/start_after.csv b/4-cp_loss_stylo_baseline/results_validation/start_after.csv
new file mode 100755
index 0000000..9da3c7d
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results_validation/start_after.csv
@@ -0,0 +1,102 @@
+move,accuracy
+0,0.008001119393492254
+1,0.007988399012898465
+2,0.006487394102831557
+3,0.00449029434960694
+4,0.0042867682601063425
+5,0.004032360648230595
+6,0.004032360648230595
+7,0.004235940620508059
+8,0.004312373586393762
+9,0.004134439242825158
+10,0.004134649636150832
+11,0.0044149829507863
+12,0.004517631488527761
+13,0.004315392840776007
+14,0.004419762835780974
+15,0.004282436910527657
+16,0.004453405132262304
+17,0.0038727983844167794
+18,0.003896336930609315
+19,0.0038616499543038087
+20,0.0037848837962902956
+21,0.0037588076590617386
+22,0.004161569962240068
+23,0.0038119613902767757
+24,0.0037130110684436414
+25,0.003143389199255121
+26,0.002826501275963433
+27,0.002573004599686305
+28,0.0027073019801980196
+29,0.002663825253063399
+30,0.0027465372321534274
+31,0.0026176626123744053
+32,0.0026072529035316427
+33,0.0031696249833177634
+34,0.0031620252200083815
+35,0.0035925520262869663
+36,0.003196184871391609
+37,0.0030963439323567943
+38,0.0032433194669674965
+39,0.004067796610169492
+40,0.003640902943930095
+41,0.003702234563004099
+42,0.004691572545612511
+43,0.004493850520340586
+44,0.0050150451354062184
+45,0.010349926071956629
+46,0.008507347254447023
+47,0.008253094910591471
+48,0.038338658146964855
+49,0
+50,0
+51,0
+52,0
+53,0
+54,0
+55,0
+56,0
+57,0
+58,0
+59,0
+60,0
+61,0
+62,0
+63,0
+64,0
+65,0
+66,0
+67,0
+68,0
+69,0
+70,0
+71,0
+72,0
+73,0
+74,0
+75,0
+76,0
+77,0
+78,0
+79,0
+80,0
+81,0
+82,0
+83,0
+84,0
+85,0
+86,0
+87,0
+88,0
+89,0
+90,0
+91,0
+92,0
+93,0
+94,0
+95,0
+96,0
+97,0
+98,0
+99,0
+100,0
diff --git a/4-cp_loss_stylo_baseline/results_validation/start_after_4games.csv b/4-cp_loss_stylo_baseline/results_validation/start_after_4games.csv
new file mode 100755
index 0000000..66a64e2
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results_validation/start_after_4games.csv
@@ -0,0 +1,102 @@
+move,accuracy
+0,0.013087661671114761
+1,0.01365222746869226
+2,0.011599260932046808
+3,0.00949497023198522
+4,0.008160541983165674
+5,0.007647300349004311
+6,0.008160541983165674
+7,0.0083662680285377
+8,0.00810963403993225
+9,0.00810963403993225
+10,0.00805913454134798
+11,0.00765004877547877
+12,0.008062445437272121
+13,0.008527689304428234
+14,0.008069490131578948
+15,0.007920588386565858
+16,0.007926295743476247
+17,0.006804825239715435
+18,0.006308822008480711
+19,0.005659691572771172
+20,0.0067447453727909655
+21,0.0066029264169880095
+22,0.0075660012878300065
+23,0.006154184295840431
+24,0.006217557469625236
+25,0.005656517029726802
+26,0.004635977799542932
+27,0.0037816625044595075
+28,0.00454331818893937
+29,0.0036316472114137485
+30,0.0045231450293523245
+31,0.004412397761515282
+32,0.0063978754225012075
+33,0.006824075337791729
+34,0.006715602061533656
+35,0.007404731804226115
+36,0.009222385244183609
+37,0.0061942517343904855
+38,0.006505026611472502
+39,0.00972972972972973
+40,0.009379187137114784
+41,0.01678240740740741
+42,0.01564945226917058
+43,0.022598870056497175
+44,0.029209621993127148
+45,0.03932584269662921
+46,0.06349206349206349
+47,0.0641025641025641
+48,0.375
+49,0
+50,0
+51,0
+52,0
+53,0
+54,0
+55,0
+56,0
+57,0
+58,0
+59,0
+60,0
+61,0
+62,0
+63,0
+64,0
+65,0
+66,0
+67,0
+68,0
+69,0
+70,0
+71,0
+72,0
+73,0
+74,0
+75,0
+76,0
+77,0
+78,0
+79,0
+80,0
+81,0
+82,0
+83,0
+84,0
+85,0
+86,0
+87,0
+88,0
+89,0
+90,0
+91,0
+92,0
+93,0
+94,0
+95,0
+96,0
+97,0
+98,0
+99,0
+100,0
diff --git a/4-cp_loss_stylo_baseline/results_validation/stop_after.csv b/4-cp_loss_stylo_baseline/results_validation/stop_after.csv
new file mode 100755
index 0000000..a2e33a7
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results_validation/stop_after.csv
@@ -0,0 +1,102 @@
+move,accuracy
+0,0.0015773271936296335
+1,0.003510825043885313
+2,0.01217340422825451
+3,0.011651868623909227
+4,0.010201745236217467
+5,0.010367110183936703
+6,0.010061821049685806
+7,0.009985498766123082
+8,0.009743811534841123
+9,0.00961660772890325
+10,0.00927315745287099
+11,0.009285877833464778
+12,0.009196835169308266
+13,0.009311318594652352
+14,0.009362200117027502
+15,0.009234996311089629
+16,0.008878825654463582
+17,0.008853384893276008
+18,0.008993309079807667
+19,0.00884066451268222
+20,0.009018749840995242
+21,0.00927315745287099
+22,0.009069631363370393
+23,0.008967868318620092
+24,0.008777062609713282
+25,0.008916986796244943
+26,0.008866105273869794
+27,0.008878825654463582
+28,0.008853384893276008
+29,0.00884066451268222
+30,0.008929707176838731
+31,0.008904266415651157
+32,0.008891546035057369
+33,0.008929707176838731
+34,0.008815223751494645
+35,0.008649858803775409
+36,0.008700740326150558
+37,0.008713460706744346
+38,0.00854809575902511
+39,0.00830640852774315
+40,0.008433612333681024
+41,0.008420891953087236
+42,0.008319128908336937
+43,0.008344569669524512
+44,0.00820464548299285
+45,0.007975678632304679
+46,0.008090162057648766
+47,0.008039280535273615
+48,0.008001119393492254
+49,0.008001119393492254
+50,0.008001119393492254
+51,0.008001119393492254
+52,0.008001119393492254
+53,0.008001119393492254
+54,0.008001119393492254
+55,0.008001119393492254
+56,0.008001119393492254
+57,0.008001119393492254
+58,0.008001119393492254
+59,0.008001119393492254
+60,0.008001119393492254
+61,0.008001119393492254
+62,0.008001119393492254
+63,0.008001119393492254
+64,0.008001119393492254
+65,0.008001119393492254
+66,0.008001119393492254
+67,0.008001119393492254
+68,0.008001119393492254
+69,0.008001119393492254
+70,0.008001119393492254
+71,0.008001119393492254
+72,0.008001119393492254
+73,0.008001119393492254
+74,0.008001119393492254
+75,0.008001119393492254
+76,0.008001119393492254
+77,0.008001119393492254
+78,0.008001119393492254
+79,0.008001119393492254
+80,0.008001119393492254
+81,0.008001119393492254
+82,0.008001119393492254
+83,0.008001119393492254
+84,0.008001119393492254
+85,0.008001119393492254
+86,0.008001119393492254
+87,0.008001119393492254
+88,0.008001119393492254
+89,0.008001119393492254
+90,0.008001119393492254
+91,0.008001119393492254
+92,0.008001119393492254
+93,0.008001119393492254
+94,0.008001119393492254
+95,0.008001119393492254
+96,0.008001119393492254
+97,0.008001119393492254
+98,0.008001119393492254
+99,0.008001119393492254
+100,0.008001119393492254
diff --git a/4-cp_loss_stylo_baseline/results_validation/stop_after_4games.csv b/4-cp_loss_stylo_baseline/results_validation/stop_after_4games.csv
new file mode 100755
index 0000000..7ac9e2a
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/results_validation/stop_after_4games.csv
@@ -0,0 +1,102 @@
+move,accuracy
+0,0.0015910490659002258
+1,0.006056251283104086
+2,0.016115787312666805
+3,0.017193594744405665
+4,0.014884007390679532
+5,0.016013138985834532
+6,0.015499897351673168
+7,0.015448573188257032
+8,0.014986655717511805
+9,0.014678710737014987
+10,0.01591049065900226
+11,0.015294600698008622
+12,0.014524738246766578
+13,0.014730034900431123
+14,0.015037979880927942
+15,0.014730034900431123
+16,0.014832683227263395
+17,0.015345924861424758
+18,0.015089304044344077
+19,0.014114144939437486
+20,0.014422089919934305
+21,0.015037979880927942
+22,0.014678710737014987
+23,0.013960172449189078
+24,0.013857524122356805
+25,0.013395606651611578
+26,0.013087661671114761
+27,0.014165469102853623
+28,0.014576062410182713
+29,0.01365222746869226
+30,0.014319441593102033
+31,0.015037979880927942
+32,0.014576062410182713
+33,0.01478135906384726
+34,0.014422089919934305
+35,0.014576062410182713
+36,0.013960172449189078
+37,0.013754875795524533
+38,0.013600903305276125
+39,0.013908848285772941
+40,0.014319441593102033
+41,0.014114144939437486
+42,0.014268117429685897
+43,0.01365222746869226
+44,0.013395606651611578
+45,0.013292958324779306
+46,0.013190309997947033
+47,0.013292958324779306
+48,0.013190309997947033
+49,0.013087661671114761
+50,0.013087661671114761
+51,0.013087661671114761
+52,0.013087661671114761
+53,0.013087661671114761
+54,0.013087661671114761
+55,0.013087661671114761
+56,0.013087661671114761
+57,0.013087661671114761
+58,0.013087661671114761
+59,0.013087661671114761
+60,0.013087661671114761
+61,0.013087661671114761
+62,0.013087661671114761
+63,0.013087661671114761
+64,0.013087661671114761
+65,0.013087661671114761
+66,0.013087661671114761
+67,0.013087661671114761
+68,0.013087661671114761
+69,0.013087661671114761
+70,0.013087661671114761
+71,0.013087661671114761
+72,0.013087661671114761
+73,0.013087661671114761
+74,0.013087661671114761
+75,0.013087661671114761
+76,0.013087661671114761
+77,0.013087661671114761
+78,0.013087661671114761
+79,0.013087661671114761
+80,0.013087661671114761
+81,0.013087661671114761
+82,0.013087661671114761
+83,0.013087661671114761
+84,0.013087661671114761
+85,0.013087661671114761
+86,0.013087661671114761
+87,0.013087661671114761
+88,0.013087661671114761
+89,0.013087661671114761
+90,0.013087661671114761
+91,0.013087661671114761
+92,0.013087661671114761
+93,0.013087661671114761
+94,0.013087661671114761
+95,0.013087661671114761
+96,0.013087661671114761
+97,0.013087661671114761
+98,0.013087661671114761
+99,0.013087661671114761
+100,0.013087661671114761
diff --git a/4-cp_loss_stylo_baseline/sweep_moves_all_games.py b/4-cp_loss_stylo_baseline/sweep_moves_all_games.py
new file mode 100755
index 0000000..21c7f95
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/sweep_moves_all_games.py
@@ -0,0 +1,145 @@
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+from sklearn.naive_bayes import GaussianNB
+from matplotlib import pyplot as plt
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--input_dir', default='cp_loss_hist_per_move')
+ parser.add_argument('--output_start_after_csv', default='start_after_all_game.csv')
+ parser.add_argument('--output_stop_after_csv', default='stop_after_all_game.csv')
+ parser.add_argument('--saved_plot', default='plot_all_game.png')
+
+ return parser.parse_args()
+
+def read_npy(input_dir):
+
+ player_list = {}
+ for input_data in os.listdir(input_dir):
+ # will split into [player_name, 'train/test/val']
+ input_name = input_data.split('_')
+ if len(input_name) > 2:
+ player_name = input_name[:-1]
+ player_name = '_'.join(player_name)
+ else:
+ player_name = input_name[0]
+ # add into player list
+ if player_name not in player_list:
+ player_list[player_name] = 1
+
+ player_list = list(player_list.keys())
+
+ player_data = {}
+ for player_name in player_list:
+ player_data[player_name] = {'train': None, 'validation': None, 'test': None}
+ train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train'))
+ val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation'))
+ test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test'))
+
+ player_data[player_name]['train'] = np.load(train_path, allow_pickle=True)
+ player_data[player_name]['train'] = player_data[player_name]['train'].item()
+ player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True)
+ player_data[player_name]['validation'] = player_data[player_name]['validation'].item()
+ player_data[player_name]['test'] = np.load(test_path, allow_pickle=True)
+ player_data[player_name]['test'] = player_data[player_name]['test'].item()
+
+ return player_data
+
+
+def construct_train_set(player_data, is_start_after, move_stop):
+ player_index = {}
+ train_list = []
+ train_label = []
+
+ i = 0
+ for player in player_data.keys():
+ # if player in os.listdir('/data/csvs'):
+ player_index[player] = i
+ train_label.append(i)
+ if is_start_after:
+ train_list.append(player_data[player]['train']['start_after'][move_stop])
+ else:
+ train_list.append(player_data[player]['train']['stop_after'][move_stop])
+
+ i += 1
+
+ train_label = np.asarray(train_label)
+ # one_hot = np.zeros((train_label.size, train_label.max()+1))
+ # one_hot[np.arange(train_label.size),train_label] = 1
+ # print(one_hot.shape)
+
+ train_data = np.stack(train_list, 0)
+ return train_data, train_label, player_index
+
+
+def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop):
+ correct = 0
+ total = 0
+ model = GaussianNB()
+ model.fit(train_data, train_label)
+ for player in player_data.keys():
+ test = player_data[player]['test']
+ if is_start_after:
+ test = test['start_after'][move_stop]
+ if all(v == 0 for v in test):
+ continue
+ else:
+ test = test['stop_after'][move_stop]
+
+ predicted = model.predict(np.expand_dims(test, axis=0))
+ index = predicted[0]
+ if index == player_index[player]:
+ correct += 1
+ total += 1
+
+ if total == 0:
+ accuracy = 0
+ else:
+ accuracy = correct / total
+
+ print(accuracy)
+ return accuracy
+
+
+def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name):
+ plt.plot(moves, start_after_accuracies, label="Start after x moves")
+ plt.plot(moves, stop_after_accuracies, label="Stop after x moves")
+ plt.legend()
+ plt.xlabel("Moves")
+ plt.savefig(plot_name)
+
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_data = read_npy(args.input_dir)
+ moves = [i for i in range(101)]
+ start_after_accuracies = []
+ stop_after_accuracies = []
+ output_start_csv = open(args.output_start_after_csv, 'w', newline='')
+ writer_start = csv.writer(output_start_csv)
+ writer_start.writerow(['move', 'accuracy'])
+
+ output_stop_csv = open(args.output_stop_after_csv, 'w', newline='')
+ writer_stop = csv.writer(output_stop_csv)
+ writer_stop.writerow(['move', 'accuracy'])
+
+ for is_start_after in (True, False):
+ for i in range(101):
+ print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i))
+ train_data, train_label, player_index = construct_train_set(player_data, is_start_after, i)
+
+ accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i)
+
+ if is_start_after:
+ start_after_accuracies.append(accuracy)
+ writer_start.writerow([i, accuracy])
+ else:
+ stop_after_accuracies.append(accuracy)
+ writer_stop.writerow([i, accuracy])
+
+ make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot)
diff --git a/4-cp_loss_stylo_baseline/sweep_moves_num_games.py b/4-cp_loss_stylo_baseline/sweep_moves_num_games.py
new file mode 100755
index 0000000..e07cb3e
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/sweep_moves_num_games.py
@@ -0,0 +1,195 @@
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+from sklearn.naive_bayes import GaussianNB
+from matplotlib import pyplot as plt
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--train_dir', default='cp_loss_hist_per_move')
+ parser.add_argument('--input_dir', default='cp_loss_hist_per_move_per_game_count')
+ parser.add_argument('--output_start_after_csv', default='start_after_4games.csv')
+ parser.add_argument('--output_stop_after_csv', default='stop_after_4games.csv')
+    parser.add_argument('--num_games', default=4, type=int)
+ parser.add_argument('--saved_plot', default='plot_4games.png')
+
+ return parser.parse_args()
+
+def read_npy(train_dir, input_dir):
+
+ player_list = {}
+ for input_data in os.listdir(input_dir):
+ # will split into [player_name, 'train/test/val']
+ input_name = input_data.split('_')
+ if len(input_name) > 2:
+ player_name = input_name[:-1]
+ player_name = '_'.join(player_name)
+ else:
+ player_name = input_name[0]
+ # add into player list
+ if player_name not in player_list:
+ player_list[player_name] = 1
+
+ player_list = list(player_list.keys())
+
+ player_data = {}
+ for player_name in player_list:
+ player_data[player_name] = {'train': None, 'validation': None, 'test': None}
+ train_path = os.path.join(train_dir, player_name + '_{}.npy'.format('train'))
+ val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation'))
+ test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test'))
+
+ player_data[player_name]['train'] = np.load(train_path, allow_pickle=True)
+ player_data[player_name]['train'] = player_data[player_name]['train'].item()
+ player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True)
+ player_data[player_name]['validation'] = player_data[player_name]['validation'].item()
+ player_data[player_name]['test'] = np.load(test_path, allow_pickle=True)
+ player_data[player_name]['test'] = player_data[player_name]['test'].item()
+
+ return player_data
+
+def normalize(data):
+ data_norm = data
+
+ if any(v != 0 for v in data):
+ norm = np.linalg.norm(data)
+ data_norm = data/norm
+
+ return data_norm
+
+def construct_train_set(player_data, is_start_after, move_stop):
+ player_index = {}
+ train_list = []
+ train_label = []
+
+ i = 0
+ for player in player_data.keys():
+ # if player in os.listdir('/data/csvs'):
+ player_index[player] = i
+ train_label.append(i)
+ if is_start_after:
+ train_list.append(player_data[player]['train']['start_after'][move_stop])
+ else:
+ train_list.append(player_data[player]['train']['stop_after'][move_stop])
+
+ i += 1
+
+ train_label = np.asarray(train_label)
+ # one_hot = np.zeros((train_label.size, train_label.max()+1))
+ # one_hot[np.arange(train_label.size),train_label] = 1
+ # print(one_hot.shape)
+
+ train_data = np.stack(train_list, 0)
+ return train_data, train_label, player_index
+
+
+def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop, num_games):
+ correct = 0
+ total = 0
+ model = GaussianNB()
+ model.fit(train_data, train_label)
+ results = None
+ for player in player_data.keys():
+ test_game = None
+ tmp_game = None
+ test_games = []
+ test = player_data[player]['test']
+ count = 1
+
+ # key is game id
+ for key, value in test.items():
+ # get which game to use
+ if is_start_after:
+ tmp_game = test[key]['start_after'][move_stop]
+ # ignore all 0 cases, essentially there's no more move in this game
+ if all(v == 0 for v in tmp_game):
+ continue
+ else:
+ tmp_game = test[key]['stop_after'][move_stop]
+
+ # add up counts in each game
+ if test_game is None:
+ test_game = tmp_game
+ else:
+ test_game = test_game + tmp_game
+
+ if count == num_games:
+ # test_game is addition of counts, need to normalize before testing
+ test_game = normalize(test_game)
+ test_games.append(test_game)
+
+ # reset
+ test_game = None
+ tmp_game = None
+ count = 1
+
+ else:
+ count += 1
+
+ # skip player if all games are beyond move_stop
+ if not test_games:
+ continue
+
+ test_games = np.stack(test_games, axis=0)
+ predicted = model.predict(test_games)
+ result = (predicted == player_index[player]).astype(float)
+
+ # append to the overall result
+ if results is None:
+ results = result
+ else:
+ results = np.append(results, result, 0)
+
+ if results is None:
+ accuracy = 0
+
+ else:
+ accuracy = np.mean(results)
+
+ print(accuracy)
+
+ return accuracy
+
+
+def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name):
+ plt.plot(moves, start_after_accuracies, label="Start after x moves")
+ plt.plot(moves, stop_after_accuracies, label="Stop after x moves")
+ plt.legend()
+ plt.xlabel("Moves")
+ plt.savefig(plot_name)
+
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_data = read_npy(args.train_dir, args.input_dir)
+ moves = [i for i in range(101)]
+ start_after_accuracies = []
+ stop_after_accuracies = []
+ output_start_csv = open(args.output_start_after_csv, 'w', newline='')
+ writer_start = csv.writer(output_start_csv)
+ writer_start.writerow(['move', 'accuracy'])
+
+ output_stop_csv = open(args.output_stop_after_csv, 'w', newline='')
+ writer_stop = csv.writer(output_stop_csv)
+ writer_stop.writerow(['move', 'accuracy'])
+
+ for is_start_after in (True, False):
+ for i in range(101):
+ print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i))
+ train_data, train_label, player_index = construct_train_set(player_data, is_start_after, i)
+
+ accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i, args.num_games)
+
+ if is_start_after:
+ start_after_accuracies.append(accuracy)
+ writer_start.writerow([i, accuracy])
+ else:
+ stop_after_accuracies.append(accuracy)
+ writer_stop.writerow([i, accuracy])
+
+ make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot)
diff --git a/4-cp_loss_stylo_baseline/sweep_moves_per_game.py b/4-cp_loss_stylo_baseline/sweep_moves_per_game.py
new file mode 100755
index 0000000..5b37033
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/sweep_moves_per_game.py
@@ -0,0 +1,166 @@
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+from sklearn.naive_bayes import GaussianNB
+from matplotlib import pyplot as plt
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--train_dir', default='cp_loss_hist_per_move')
+ parser.add_argument('--input_dir', default='cp_loss_hist_per_move_per_game')
+ parser.add_argument('--output_start_after_csv', default='start_after.csv')
+ parser.add_argument('--output_stop_after_csv', default='stop_after.csv')
+ parser.add_argument('--saved_plot', default='plot.png')
+
+ return parser.parse_args()
+
+def read_npy(train_dir, input_dir):
+
+ player_list = {}
+ for input_data in os.listdir(input_dir):
+ # will split into [player_name, 'train/test/val']
+ input_name = input_data.split('_')
+ if len(input_name) > 2:
+ player_name = input_name[:-1]
+ player_name = '_'.join(player_name)
+ else:
+ player_name = input_name[0]
+ # add into player list
+ if player_name not in player_list:
+ player_list[player_name] = 1
+
+ player_list = list(player_list.keys())
+
+ player_data = {}
+ for player_name in player_list:
+ player_data[player_name] = {'train': None, 'validation': None, 'test': None}
+ train_path = os.path.join(train_dir, player_name + '_{}.npy'.format('train'))
+ val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation'))
+ test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test'))
+
+ player_data[player_name]['train'] = np.load(train_path, allow_pickle=True)
+ player_data[player_name]['train'] = player_data[player_name]['train'].item()
+ player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True)
+ player_data[player_name]['validation'] = player_data[player_name]['validation'].item()
+ player_data[player_name]['test'] = np.load(test_path, allow_pickle=True)
+ player_data[player_name]['test'] = player_data[player_name]['test'].item()
+
+ return player_data
+
+
+def construct_train_set(player_data, is_start_after, move_stop):
+ player_index = {}
+ train_list = []
+ train_label = []
+
+ i = 0
+ for player in player_data.keys():
+ # if player in os.listdir('/data/csvs'):
+ player_index[player] = i
+ train_label.append(i)
+ if is_start_after:
+ train_list.append(player_data[player]['train']['start_after'][move_stop])
+ else:
+ train_list.append(player_data[player]['train']['stop_after'][move_stop])
+
+ i += 1
+
+ train_label = np.asarray(train_label)
+ # one_hot = np.zeros((train_label.size, train_label.max()+1))
+ # one_hot[np.arange(train_label.size),train_label] = 1
+ # print(one_hot.shape)
+
+ train_data = np.stack(train_list, 0)
+ return train_data, train_label, player_index
+
+
+def predict(train_data, train_label, player_data, player_index, is_start_after, move_stop):
+ correct = 0
+ total = 0
+ model = GaussianNB()
+ model.fit(train_data, train_label)
+ results = None
+ for player in player_data.keys():
+ test_game = None
+ test_games = []
+ test = player_data[player]['test']
+
+ # key is game id
+ for key, value in test.items():
+ if is_start_after:
+ test_game = test[key]['start_after'][move_stop]
+ # ignore all 0 cases, essentially there's no more move in this game
+ if all(v == 0 for v in test_game):
+ continue
+ else:
+ test_game = test[key]['stop_after'][move_stop]
+
+ test_games.append(test_game)
+
+ # skip player if all games are beyond move_stop
+ if not test_games:
+ continue
+
+ test_games = np.stack(test_games, axis=0)
+ predicted = model.predict(test_games)
+ result = (predicted == player_index[player]).astype(float)
+
+ # append to the overall result
+ if results is None:
+ results = result
+ else:
+ results = np.append(results, result, 0)
+
+ if results is None:
+ accuracy = 0
+
+ else:
+ accuracy = np.mean(results)
+
+ print(accuracy)
+
+ return accuracy
+
+
+def make_plots(moves, start_after_accuracies, stop_after_accuracies, plot_name):
+ plt.plot(moves, start_after_accuracies, label="Start after x moves")
+ plt.plot(moves, stop_after_accuracies, label="Stop after x moves")
+ plt.legend()
+ plt.xlabel("Moves")
+ plt.savefig(plot_name)
+
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_data = read_npy(args.train_dir, args.input_dir)
+ moves = [i for i in range(101)]
+ start_after_accuracies = []
+ stop_after_accuracies = []
+ output_start_csv = open(args.output_start_after_csv, 'w', newline='')
+ writer_start = csv.writer(output_start_csv)
+ writer_start.writerow(['move', 'accuracy'])
+
+ output_stop_csv = open(args.output_stop_after_csv, 'w', newline='')
+ writer_stop = csv.writer(output_stop_csv)
+ writer_stop.writerow(['move', 'accuracy'])
+
+ for is_start_after in (True, False):
+ for i in range(101):
+ print('testing {} move {}'.format('start_after' if is_start_after else 'stop_after', i))
+ train_data, train_label, player_index = construct_train_set(player_data, is_start_after, i)
+
+ accuracy = predict(train_data, train_label, player_data, player_index, is_start_after, i)
+
+ if is_start_after:
+ start_after_accuracies.append(accuracy)
+ writer_start.writerow([i, accuracy])
+ else:
+ stop_after_accuracies.append(accuracy)
+ writer_stop.writerow([i, accuracy])
+
+ make_plots(moves, start_after_accuracies, stop_after_accuracies, args.saved_plot)
diff --git a/4-cp_loss_stylo_baseline/sweep_num_games.py b/4-cp_loss_stylo_baseline/sweep_num_games.py
new file mode 100755
index 0000000..0caba45
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/sweep_num_games.py
@@ -0,0 +1,225 @@
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+from sklearn.naive_bayes import GaussianNB
+import multiprocessing
+from functools import partial
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--train_dir', default='cp_loss_hist')
+ parser.add_argument('--input_dir', default='cp_loss_count_per_game')
+ parser.add_argument('--use_bayes', default=True)
+    parser.add_argument('--num_games_list', default=[1, 2, 4, 8, 16], nargs='+', type=int)
+ parser.add_argument('--output_csv', default='games_accuracy.csv')
+
+ return parser.parse_args()
+
+def read_npy(train_dir, input_dir):
+ player_list = {}
+ for input_data in os.listdir(input_dir):
+ # will split into [player_name, 'train/test/val']
+ input_name = input_data.split('_')
+ if len(input_name) > 2:
+ player_name = input_name[:-1]
+ player_name = '_'.join(player_name)
+ else:
+ player_name = input_name[0]
+ # add into player list
+ if player_name not in player_list:
+ player_list[player_name] = 1
+
+ player_list = list(player_list.keys())
+
+
+ player_data = {}
+ for player_name in player_list:
+ player_data[player_name] = {'train': None, 'validation': None, 'test': None}
+ train_path = os.path.join(train_dir, player_name + '_{}.npy'.format('train'))
+ val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation'))
+ test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test'))
+
+ player_data[player_name]['train'] = np.load(train_path)
+ player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True)
+ player_data[player_name]['validation'] = player_data[player_name]['validation'].item()
+ player_data[player_name]['test'] = np.load(test_path, allow_pickle=True)
+ player_data[player_name]['test'] = player_data[player_name]['test'].item()
+
+ return player_data
+
+def normalize(data):
+ norm = np.linalg.norm(data)
+ data_norm = data/norm
+ return data_norm
+
+# =============================== Naive Bayes ===============================
+def construct_train_set(player_data):
+ player_index = {}
+ train_list = []
+ train_label = []
+
+ i = 0
+ for player in player_data.keys():
+ player_index[player] = i
+ train_label.append(i)
+ train_list.append(player_data[player]['train'])
+ i += 1
+
+ train_label = np.asarray(train_label)
+ # one_hot = np.zeros((train_label.size, train_label.max()+1))
+ # one_hot[np.arange(train_label.size),train_label] = 1
+ # print(one_hot.shape)
+
+ train_data = np.stack(train_list, 0)
+ return train_data, train_label, player_index
+
+def predict(train_data, train_label, player_data, num_games_list, player_index):
+ print(player_index)
+
+ model = GaussianNB()
+ model.fit(train_data, train_label)
+
+ accuracies = []
+
+ for num_games in num_games_list:
+ results = None
+ print("evaluating with {} games".format(num_games))
+
+ for player in player_data.keys():
+ test = player_data[player]['test']
+ count = 1
+ test_game = None
+ test_games = []
+ for key, value in test.items():
+ if test_game is None:
+ test_game = value
+ else:
+ test_game = test_game + value
+ if count == num_games:
+ # test_game is addition of counts, need to normalize before testing
+ test_game = normalize(test_game)
+ test_games.append(test_game)
+
+ # reset
+ test_game = None
+ count = 1
+
+ else:
+ count += 1
+
+ test_games = np.stack(test_games, axis=0)
+ predicted = model.predict(test_games)
+ result = (predicted == player_index[player]).astype(float)
+
+ if results is None:
+ results = result
+ else:
+ results = np.append(results, result, 0)
+
+ if results is None:
+ accuracy = 0
+ else:
+ accuracy = np.mean(results)
+
+ accuracies.append([num_games, accuracy])
+ print("num_games: {}, accuracy: {}".format(num_games, accuracy))
+
+ return accuracies
+
+# =============================== Euclidean Distance ===============================
+def construct_train_list(player_data):
+ # player_index is {player_name: id} mapping
+ player_index = {}
+ train_list = []
+ i = 0
+ for player in player_data.keys():
+ player_index[player] = i
+ train_list.append(player_data[player]['train'])
+ i += 1
+
+ return train_list, player_index
+
+
+def test_euclidean_dist(train_list, player_data, num_games_list, player_index):
+ accuracies = []
+
+ for num_games in num_games_list:
+ print("evaluating with {} games".format(num_games))
+ correct = 0
+ total = 0
+
+ # loop through each player and test their 'test set'
+ for player in player_data.keys():
+ test = player_data[player]['test']
+ count = 1
+ test_game = None
+
+ for key, value in test.items():
+ if test_game is None:
+ test_game = value
+ else:
+ test_game = test_game + value
+
+ if count == num_games:
+ test_game = normalize(test_game)
+
+ dist_list = []
+ # save distance for each (test, train)
+ for train_data in train_list:
+ dist = np.linalg.norm(train_data - test_game)
+ dist_list.append(dist)
+
+ # find minimum distance and its index
+ min_index = dist_list.index(min(dist_list))
+ if min_index == player_index[player]:
+ correct += 1
+ total += 1
+
+ # reset
+ test_game = None
+ count = 1
+
+ else:
+ count += 1
+
+ accuracies.append([num_games, correct / total])
+ print("num_games: {}, accuracy: {}".format(num_games, correct / total))
+
+ return accuracies
+
+# =============================== run bayes or euclidean ===============================
+def run_bayes(player_data, output_csv, num_games_list):
+
+ train_data, train_label, player_index = construct_train_set(player_data)
+ accuracies = predict(train_data, train_label, player_data, num_games_list, player_index)
+
+ output_csv = open(output_csv, 'w', newline='')
+ writer = csv.writer(output_csv)
+ writer.writerow(['num_games', 'accuracy'])
+ for i in range(len(accuracies)):
+ writer.writerow(accuracies[i])
+
+def run_euclidean_dist(player_data, output_csv, num_games_list):
+ train_list, player_index = construct_train_list(player_data)
+ accuracies = test_euclidean_dist(train_list, player_data, num_games_list, player_index)
+
+ output_csv = open(output_csv, 'w', newline='')
+ writer = csv.writer(output_csv)
+ writer.writerow(['num_games', 'accuracy'])
+ for i in range(len(accuracies)):
+ writer.writerow(accuracies[i])
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_data = read_npy(args.train_dir, args.input_dir)
+
+ if args.use_bayes:
+ run_bayes(player_data, args.output_csv, args.num_games_list)
+ else:
+ run_euclidean_dist(player_data, args.output_csv, args.num_games_list)
+
+
diff --git a/4-cp_loss_stylo_baseline/test_all_games.py b/4-cp_loss_stylo_baseline/test_all_games.py
new file mode 100755
index 0000000..b057a2c
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/test_all_games.py
@@ -0,0 +1,136 @@
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+from sklearn.naive_bayes import GaussianNB
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--input_dir', default='cp_loss_hist')
+ parser.add_argument('--use_bayes', default=True)
+
+ return parser.parse_args()
+
+def read_npy(input_dir):
+
+ player_list = {}
+ for input_data in os.listdir(input_dir):
+ # will split into [player_name, 'train/test/val']
+ input_name = input_data.split('_')
+ if len(input_name) > 2:
+ player_name = input_name[:-1]
+ player_name = '_'.join(player_name)
+ else:
+ player_name = input_name[0]
+ # add into player list
+ if player_name not in player_list:
+ player_list[player_name] = 1
+
+ player_list = list(player_list.keys())
+
+ player_data = {}
+ for player_name in player_list:
+ player_data[player_name] = {'train': None, 'validation': None, 'test': None}
+ train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train'))
+ val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation'))
+ test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test'))
+
+ player_data[player_name]['train'] = np.load(train_path)
+ player_data[player_name]['validation'] = np.load(val_path)
+ player_data[player_name]['test'] = np.load(test_path)
+
+ return player_data
+
+# =============================== Naive Bayes ===============================
+def construct_train_set(player_data):
+ player_index = {}
+ train_list = []
+ train_label = []
+
+ i = 0
+ for player in player_data.keys():
+ player_index[player] = i
+ train_label.append(i)
+ train_list.append(player_data[player]['train'])
+ i += 1
+
+ train_label = np.asarray(train_label)
+
+ train_data = np.stack(train_list, 0)
+ return train_data, train_label, player_index
+
+def predict(train_data, train_label, player_data, player_index):
+ print(player_index)
+ correct = 0
+ total = 0
+ model = GaussianNB()
+ model.fit(train_data, train_label)
+
+ for player in player_data.keys():
+ test = player_data[player]['test']
+ predicted = model.predict(np.expand_dims(test, axis=0))
+ index = predicted[0]
+ if index == player_index[player]:
+ correct += 1
+ total += 1
+
+ print('accuracy is {}'.format(correct / total))
+
+# =============================== Euclidean Distance ===============================
+def construct_train_list(player_data):
+ # player_index is {player_name: id} mapping
+ player_index = {}
+ train_list = []
+ i = 0
+ for player in player_data.keys():
+ player_index[player] = i
+ train_list.append(player_data[player]['train'])
+ i += 1
+
+ return train_list, player_index
+
+def test_euclidean_dist(train_list, player_data, player_index):
+ print(player_index)
+ correct = 0
+ total = 0
+ # loop through each player and test their 'test set'
+ for player in player_data.keys():
+ dist_list = []
+ test = player_data[player]['test']
+
+ # save distance for each (test, train)
+ for train_data in train_list:
+ dist = np.linalg.norm(train_data - test)
+ dist_list.append(dist)
+
+ # find minimum distance and its index
+ min_index = dist_list.index(min(dist_list))
+ if min_index == player_index[player]:
+ correct += 1
+ total += 1
+
+ print('accuracy is {}'.format(correct / total))
+
+# =============================== run bayes or euclidean ===============================
+def run_bayes(player_data):
+ print("Using Naive Bayes")
+ train_data, train_label, player_index = construct_train_set(player_data)
+ predict(train_data, train_label, player_data, player_index)
+
+def run_euclidean_dist(player_data):
+ print("Using Euclidean Distance")
+ train_list, player_index = construct_train_list(player_data)
+ test_euclidean_dist(train_list, player_data, player_index)
+
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ player_data = read_npy(args.input_dir)
+
+ if args.use_bayes:
+ run_bayes(player_data)
+ else:
+ run_euclidean_dist(player_data)
\ No newline at end of file
diff --git a/4-cp_loss_stylo_baseline/train_cploss_per_game.py b/4-cp_loss_stylo_baseline/train_cploss_per_game.py
new file mode 100755
index 0000000..ef5d85c
--- /dev/null
+++ b/4-cp_loss_stylo_baseline/train_cploss_per_game.py
@@ -0,0 +1,144 @@
+import bz2
+import csv
+import argparse
+import os
+import numpy as np
+import tensorflow as tf
+from sklearn.naive_bayes import GaussianNB
+
+def parse_argument():
+ parser = argparse.ArgumentParser(description='arg parser')
+
+ parser.add_argument('--input_dir', default='cp_loss_count_per_game')
+ parser.add_argument('--gpu', default=0, type=int)
+
+ return parser.parse_args()
+
+def normalize(data):
+ norm = np.linalg.norm(data)
+ data_norm = data/norm
+ return data_norm
+
+def read_npy(input_dir):
+
+ player_list = {}
+ for input_data in os.listdir(input_dir):
+ # will split into [player_name, 'train/test/val']
+ input_name = input_data.split('_')
+ if len(input_name) > 2:
+ player_name = input_name[:-1]
+ player_name = '_'.join(player_name)
+ else:
+ player_name = input_name[0]
+ # add into player list
+ if player_name not in player_list:
+ player_list[player_name] = 1
+
+ player_list = list(player_list.keys())
+
+ player_data = {}
+ for player_name in player_list:
+ player_data[player_name] = {'train': None, 'validation': None, 'test': None}
+ train_path = os.path.join(input_dir, player_name + '_{}.npy'.format('train'))
+ val_path = os.path.join(input_dir, player_name + '_{}.npy'.format('validation'))
+ test_path = os.path.join(input_dir, player_name + '_{}.npy'.format('test'))
+
+ player_data[player_name]['train'] = np.load(train_path, allow_pickle=True)
+ player_data[player_name]['train'] = player_data[player_name]['train'].item()
+ player_data[player_name]['validation'] = np.load(val_path, allow_pickle=True)
+ player_data[player_name]['validation'] = player_data[player_name]['validation'].item()
+ player_data[player_name]['test'] = np.load(test_path, allow_pickle=True)
+ player_data[player_name]['test'] = player_data[player_name]['test'].item()
+
+ return player_data
+
+def construct_datasets(player_data):
+ player_index = {}
+ train_list = []
+ train_labels = []
+ validation_list = []
+ validation_labels = []
+ test_list = []
+ test_labels = []
+ i = 0
+ for player in player_data.keys():
+ label = i
+ player_index[player] = i
+ for key, value in player_data[player]['train'].items():
+ train_list.append(normalize(value))
+ train_labels.append(label)
+
+ for key, value in player_data[player]['validation'].items():
+ validation_list.append(normalize(value))
+ validation_labels.append(label)
+
+ for key, value in player_data[player]['test'].items():
+ test_list.append(normalize(value))
+ test_labels.append(label)
+
+ i += 1
+ # convert lists into numpy arrays
+ train_list_np = np.stack(train_list, axis=0)
+ validation_list_np = np.stack(validation_list, axis=0)
+ test_list_np = np.stack(test_list, axis=0)
+
+ train_labels_np = np.stack(train_labels, axis=0)
+ validation_labels_np = np.stack(validation_labels, axis=0)
+ test_labels_np = np.stack(test_labels, axis=0)
+
+ return train_list_np, train_labels_np, validation_list_np, validation_labels_np, test_list_np, test_labels_np, player_index
+
+
+def init_net(output_size):
+ l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001))
+ input_var = tf.keras.Input(shape=(50, ))
+ dense_1 = tf.keras.layers.Dense(40, kernel_initializer='glorot_normal', kernel_regularizer=l2reg, bias_regularizer=l2reg, activation='relu')(input_var)
+    # the output layer emits one logit per player; softmax is applied by the loss during training
+    dense_2 = tf.keras.layers.Dense(output_size, kernel_initializer='glorot_normal', kernel_regularizer=l2reg, bias_regularizer=l2reg)(dense_1)
+
+    model = tf.keras.Model(inputs=input_var, outputs=dense_2)
+ return model
+
+def train(train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels, player_index):
+ net = init_net(max(test_labels) + 1)
+ net.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001, clipnorm=1),
+ loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+ metrics=['accuracy'])
+
+ net.fit(train_dataset, train_labels, batch_size=32, epochs=10, validation_data=(val_dataset, val_labels))
+
+ test_loss, test_acc = net.evaluate(test_dataset, test_labels, verbose=2)
+
+ print('\nTest accuracy:', test_acc)
+
+ return net
+
+# predict is to verify if keras test is correct
+def predict(net, test, test_labels):
+ probability_model = tf.keras.Sequential([net,
+ tf.keras.layers.Softmax()])
+ predictions = probability_model.predict(test)
+
+ correct = 0
+ total = 0
+ for i, prediction in enumerate(predictions):
+ if test_labels[i] == np.argmax(prediction):
+ correct += 1
+ total += 1
+
+ print('test accuracy is: {}'.format(correct / total))
+
+if __name__ == '__main__':
+ args = parse_argument()
+
+ gpus = tf.config.experimental.list_physical_devices('GPU')
+ tf.config.experimental.set_visible_devices(gpus[args.gpu], 'GPU')
+ tf.config.experimental.set_memory_growth(gpus[args.gpu], True)
+
+ player_data = read_npy(args.input_dir)
+
+ train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels, player_index = construct_datasets(player_data)
+
+ net = train(train_dataset, train_labels, val_dataset, val_labels, test_dataset, test_labels, player_index)
+
+ # predict is to verify if test is correct
+ # predict(net, test_dataset, test_labels)
diff --git a/9-reduced-data/configs/Best_frozen.yaml b/9-reduced-data/configs/Best_frozen.yaml
new file mode 100644
index 0000000..4415519
--- /dev/null
+++ b/9-reduced-data/configs/Best_frozen.yaml
@@ -0,0 +1,56 @@
+%YAML 1.2
+---
+dataset:
+ name: ''
+ train_path: ''
+ validate_path: ''
+gpu: 3
+model:
+ back_prop_blocks: 3
+ filters: 64
+ keep_weights: true
+ path: maia/1700
+ residual_blocks: 6
+ se_ratio: 8
+training:
+ batch_size: 16
+ checkpoint_small_steps:
+ - 100
+ - 200
+ - 400
+ - 800
+ - 1600
+ - 2500
+ checkpoint_steps: 5000
+ lr_boundaries:
+ - 50000
+ - 110000
+ - 160000
+ lr_values:
+ - 1.0e-05
+ - 1.0e-06
+ - 1.0e-07
+ - 1.0e-08
+ num_batch_splits: 1
+ policy_loss_weight: 1.0
+ precision: half
+ shuffle_size: 256
+ small_mode: true
+ test_small_boundaries:
+ - 20000
+ - 40000
+ - 60000
+ - 80000
+ - 100000
+ test_small_steps:
+ - 100
+ - 200
+ - 400
+ - 800
+ - 1600
+ - 2500
+ test_steps: 2000
+ total_steps: 200000
+ train_avg_report_steps: 50
+ value_loss_weight: 1.0
+...
diff --git a/9-reduced-data/configs/NFP.yaml b/9-reduced-data/configs/NFP.yaml
new file mode 100644
index 0000000..04fba56
--- /dev/null
+++ b/9-reduced-data/configs/NFP.yaml
@@ -0,0 +1,38 @@
+%YAML 1.2
+---
+gpu: 0
+
+dataset:
+ train_path: ''
+ validate_path: ''
+ name: ''
+
+training:
+ precision: 'half'
+ batch_size: 256
+ num_batch_splits: 1
+ test_steps: 1000
+ train_avg_report_steps: 50
+ total_steps: 150000
+ checkpoint_steps: 5000
+ shuffle_size: 256
+ lr_values:
+ - 0.01
+ - 0.001
+ - 0.0001
+ - 0.00001
+ lr_boundaries:
+ - 35000
+ - 80000
+ - 110000
+ policy_loss_weight: 1.0
+ value_loss_weight: 1.0
+
+model:
+ filters: 64
+ residual_blocks: 6
+ se_ratio: 8
+ path: "maia/1900"
+ keep_weights: true
+ back_prop_blocks: 99
+...
diff --git a/9-reduced-data/configs/Tuned.yaml b/9-reduced-data/configs/Tuned.yaml
new file mode 100644
index 0000000..b173853
--- /dev/null
+++ b/9-reduced-data/configs/Tuned.yaml
@@ -0,0 +1,53 @@
+%YAML 1.2
+---
+gpu: 0
+model:
+ back_prop_blocks: 99
+ filters: 64
+ keep_weights: true
+ path: maia/1900
+ residual_blocks: 6
+ se_ratio: 8
+training:
+ batch_size: 16
+ checkpoint_small_steps:
+ - 50
+ - 200
+ - 400
+ - 800
+ - 1600
+ - 2500
+ checkpoint_steps: 5000
+ early_stopping_steps: 10000
+ lr_boundaries:
+ - 50000
+ - 110000
+ - 160000
+ lr_values:
+ - 1.0e-05
+ - 1.0e-06
+ - 1.0e-07
+ - 1.0e-08
+ num_batch_splits: 1
+ policy_loss_weight: 1.0
+ precision: half
+ shuffle_size: 256
+ small_mode: true
+ test_small_boundaries:
+ - 20000
+ - 40000
+ - 60000
+ - 80000
+ - 100000
+ test_small_steps:
+ - 50
+ - 200
+ - 400
+ - 800
+ - 1600
+ - 2500
+ test_steps: 2000
+ total_steps: 200000
+ train_avg_report_steps: 50
+ value_loss_weight: 1.0
+...
diff --git a/README.md b/README.md
index 72c4a5d..afc14c7 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,75 @@
# Learning Models of Individual Behavior in Chess
-## Code release is in progress
+## [website](https://maiachess.com)/[paper](https://arxiv.org/abs/2008.10086)/[code](https://github.com/CSSLab/maia-individual)
-Full paper is on [ArXiv](https://arxiv.org/abs/2008.10086)
+
+
+
+
+## Overview
+
+The main code used in this project is stored in `backend`, which is set up as a Python package; running `python setup.py install` will install it, after which the various scripts can be used. We also recommend using the virtual environment config we include in `environment.yml`, as some packages are required to be up to date. In addition, generating training data requires two more tools: [`pgn-extract`](https://www.cs.kent.ac.uk/people/staff/djb/pgn-extract/) to clean the PGNs and [`trainingdata-tool`](https://github.com/DanielUranga/trainingdata-tool) to convert them into training data.
+
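+For reference, the setup can be sketched as follows (this assumes a conda-style environment; the environment name inside `environment.yml` may differ on your system):
+
+```bash
+# create the virtual environment shipped with the repo and activate it
+conda env create -f environment.yml
+conda activate maia-individual   # hypothetical name, check environment.yml
+
+# install the backend package
+python setup.py install
+```
+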
+To run a model as a chess engine, [`lc0`](https://github.com/LeelaChessZero/lc0) version 23 has been tested and should work with all models (you can specify a path to a model with the `-w` argument).
+
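+For example, a model can be loaded like this (the weights path is a placeholder for whatever file your training run produced):
+
+```bash
+lc0 -w final_models/example_player/final_model.pb.gz
+```
+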
+All testing was done on Ubuntu 18.04 with CUDA version 10.
+
+## Models
+
+We do not have any models of individual players publicly available at this time. Because of the stylometry results shown in Section 5.2 of the paper, we cannot release even anonymized models.
+
+We have included the Maia model from [https://github.com/CSSLab/maia-chess](https://github.com/CSSLab/maia-chess) that was used as the base.
+
+## Running the code
+
+The code for this project is divided into sections, each with a series of numbered shell scripts. If run in order, they generate the training data and then the final models. For the full release we plan to make the process of generating a model more streamlined.
+
+### Quick Run
+
+To get the model for a single player from a single PGN, a simpler system can be used:
+
+1. Run `1-data_generation/9-pgn_to_training_data.sh input_PGN_file output_directory player_name`
+2. Create a config file by copying `2-training/final_config.yaml` and adding `output_directory` and `player_name`
+3. Run `python 2-training/train_transfer.py path_to_config`
+4. The final model will be written to `final_models`; read the `--help` of `train_transfer.py` for more information. The commands are sketched below.
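+
+Putting the steps together (the PGN path, output directory, player name, and config filename below are placeholders):
+
+```bash
+1-data_generation/9-pgn_to_training_data.sh games.pgn data/example_player example_player
+cp 2-training/final_config.yaml example_config.yaml   # then fill in output_directory and player_name
+python 2-training/train_transfer.py example_config.yaml
+```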
+
+### Full Run
+
+Where applicable, each script starts with a list of variables; these will need to be edited to match the paths on your system.
+
+The list of players we used was selected using the code in `0-player_counting`. The standard games from Lichess ([database.lichess.org](https://database.lichess.org)) up to April are required to reproduce our exact results, but the code should work with other sets, even non-Lichess ones, with a bit of work.
+
+Then the players' games are extracted and the various sets are constructed from them in `1-data_generation`.
+
+Finally, `2-training` has the main training script along with the configuration files that specify the hyperparameters; all four configurations discussed in the main text are included.
+
+### Extras
+
+The analysis code (`3-analysis`) is included for completeness, but since it generates the data used in the plots and relies on various hard-coded paths, we have not tested it. That said, `3-analysis/prediction_generator.py` is the main workhorse and has a `--help`; note that it is designed for files output by `backend.gameToCSVlines`, but less complicated CSVs could be used.
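+
+For example, to list its options:
+
+```bash
+python 3-analysis/prediction_generator.py --help
+```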
+
+The baseline model's code and results are in `4-cp_loss_stylo_baseline`. This is a simple baseline model to compare our results against, included for completeness.
+
+## Reduced Data
+
+The model configurations for the reduced-data training are included in `9-reduced-data/configs`. To train them yourself, simply use the configs with the quick-run training procedure.
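+
+For example, using one of the included configs with the quick-run training script (the config still needs its dataset paths and player name filled in, as in the Quick Run section):
+
+```bash
+python 2-training/train_transfer.py 9-reduced-data/configs/Tuned.yaml
+```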
+
+## Citation
+
+```
+@article{McIlroyYoung_Learning_Models_Chess_2022,
+author = {McIlroy-Young, Reid and Sen, Siddhartha and Kleinberg, Jon and Anderson, Ashton},
+doi = {10.1145/3534678.3539367},
+journal = {KDD '22: Proceedings of the 28th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining},
+month = {8},
+title = {{Learning Models of Individual Behavior in Chess}},
+year = {2022}
+}
+```
+
+## License
+
+The software is available under the GPL License and includes code from the [Leela Chess Zero](https://github.com/LeelaChessZero/lczero-training) project.
+
+## Contact
+
+Please [open an issue](https://github.com/CSSLab/maia-individual/issues/new) or email [Reid McIlroy-Young](https://reidmcy.com/) to get in touch.
diff --git a/backend/__init__.py b/backend/__init__.py
new file mode 100755
index 0000000..fdcf55b
--- /dev/null
+++ b/backend/__init__.py
@@ -0,0 +1,8 @@
+from .utils import *
+from .uci_engine import *
+from .pgn_parsering import *
+from .multiproc import *
+from .fen_to_vec import fen_to_vec, array_to_fen, array_to_board, game_to_vecs
+from .pgn_to_csv import *
+
+__version__ = '1.0.0'
diff --git a/backend/fen_to_vec.py b/backend/fen_to_vec.py
new file mode 100755
index 0000000..1c558df
--- /dev/null
+++ b/backend/fen_to_vec.py
@@ -0,0 +1,172 @@
+import re
+
+import chess
+import numpy as np
+
+# Generate the regexs
+boardRE = re.compile(r"(([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)/([^/]+)) ((w)|(b)) ((-)|(K)?(Q)?(k)?(q)?)( ((-)|(\w+)))?( \d+)?( \d+)?")
+
+replaceRE = re.compile(r'[1-8/]')
+
+pieceMapWhite = {'E' : [False] * 12}
+pieceMapBlack = {'E' : [False] * 12}
+
+piece_reverse_lookup = {}
+
+all_pieces = 'PNBRQK'
+
+for i, p in enumerate(all_pieces):
+ #White pieces first
+ mP = [False] * 12
+ mP[i] = True
+ pieceMapBlack[p] = mP
+ piece_reverse_lookup[i] = p
+
+ #then black
+ mP = [False] * 12
+ mP[i + len(all_pieces)] = True
+ pieceMapBlack[p.lower()] = mP
+ piece_reverse_lookup[i + len(all_pieces)] = p.lower()
+
+
+ #Black pieces first
+ mP = [False] * 12
+ mP[i] = True
+ pieceMapWhite[p.lower()] = mP
+
+ #then white
+ mP = [False] * 12
+ mP[i + len(all_pieces)] = True
+ pieceMapWhite[p] = mP
+
+iSs = [str(i + 1) for i in range(8)]
+eSubss = [('E' * i, str(i)) for i in range(8,0, -1)]
+castling_vals = 'KQkq'
+
+def toByteBuff(l):
+ return b''.join([b'\1' if e else b'\0' for e in l])
+
+pieceMapBin = {k : toByteBuff(v) for k,v in pieceMapBlack.items()}
+
+def toBin(c):
+ return pieceMapBin[c]
+
+castlesMap = {True : b'\1'*64, False : b'\0'*64}
+
+#Some previous lines are left in just in case
+
+# using N,C,H,W format
+
+move_letters = list('abcdefgh')
+
+moves_lookup = {}
+move_ind = 0
+for r_1 in range(8):
+ for c_1 in range(8):
+ for r_2 in range(8):
+ for c_2 in range(8):
+ moves_lookup[f"{move_letters[r_1]}{c_1+1}{move_letters[r_2]}{c_2+1}"] = move_ind
+ move_ind += 1
+
+def move_to_index(move_str):
+ return moves_lookup[move_str[:4]]
+
+def array_to_preproc(a_target):
+ if not isinstance(a_target, np.ndarray):
+        #check if torch Tensor without importing torch
+ a_target = a_target.cpu().numpy()
+ if a_target.dtype != np.bool_:
+ a_target = a_target.astype(np.bool_)
+ piece_layers = a_target[:12]
+ board_a = np.moveaxis(piece_layers, 2, 0).reshape(64, 12)
+ board_str = ''
+ is_white = bool(a_target[12, 0, 0])
+ castling = [bool(l[0,0]) for l in a_target[13:]]
+ board = [['E'] * 8 for i in range(8)]
+ for i in range(12):
+ for x in range(8):
+ for y in range(8):
+ if piece_layers[i,x,y]:
+ board[x][y] = piece_reverse_lookup[i]
+ board = [''.join(r) for r in board]
+ return ''.join(board), is_white, tuple(castling)
+
+def preproc_to_fen(boardStr, is_white, castling):
+ rows = [boardStr[(i*8):(i*8)+8] for i in range(8)]
+
+ if not is_white:
+ castling = castling[2:] + castling[:2]
+ new_rows = []
+ for b in rows:
+ new_rows.append(b.swapcase()[::-1].replace('e', 'E'))
+
+ rows = reversed(new_rows)
+ row_strs = []
+ for r in rows:
+ for es, i in eSubss:
+ if es in r:
+ r = r.replace(es, i)
+ row_strs.append(r)
+ castle_str = ''
+ for i, v in enumerate(castling):
+ if v:
+ castle_str += castling_vals[i]
+ if len(castle_str) < 1:
+ castle_str = '-'
+
+ is_white_str = 'w' if is_white else 'b'
+ board_str = '/'.join(row_strs)
+ return f"{board_str} {is_white_str} {castle_str} - 0 1"
+
+def array_to_fen(a_target):
+ return preproc_to_fen(*array_to_preproc(a_target))
+
+def array_to_board(a_target):
+ return chess.Board(fen = array_to_fen(a_target))
+
+def simple_fen_vec(boardStr, is_white, castling):
+ castles = [np.frombuffer(castlesMap[c], dtype='bool').reshape(1, 8, 8) for c in castling]
+ board_buff_map = map(toBin, boardStr)
+ board_buff = b''.join(board_buff_map)
+ a = np.frombuffer(board_buff, dtype='bool')
+ a = a.reshape(8, 8, -1)
+ a = np.moveaxis(a, 2, 0)
+ if is_white:
+ colour_plane = np.ones((1, 8, 8), dtype='bool')
+ else:
+ colour_plane = np.zeros((1, 8, 8), dtype='bool')
+
+ return np.concatenate([a, colour_plane, *castles], axis = 0)
+
+def preproc_fen(fenstr):
+ r = boardRE.match(fenstr)
+ if r.group(14):
+ castling = (False, False, False, False)
+ else:
+ castling = (bool(r.group(15)), bool(r.group(16)), bool(r.group(17)), bool(r.group(18)))
+ if r.group(11):
+ is_white = True
+ rows_lst = r.group(1).split('/')
+ else:
+ is_white = False
+ castling = castling[2:] + castling[:2]
+ rows_lst = r.group(1).swapcase().split('/')
+ rows_lst = reversed([s[::-1] for s in rows_lst])
+
+ rowsS = ''.join(rows_lst)
+ for i, iS in enumerate(iSs):
+ if iS in rowsS:
+ rowsS = rowsS.replace(iS, 'E' * (i + 1))
+ return rowsS, is_white, castling
+
+def fen_to_vec(fenstr):
+ return simple_fen_vec(*preproc_fen(fenstr))
+
+def game_to_vecs(game):
+ boards = []
+ board = game.board()
+ for i, node in enumerate(game.mainline()):
+ fen = str(board.fen())
+ board.push(node.move)
+        boards.append(fen_to_vec(fen))
+ return np.stack(boards, axis = 0)
diff --git a/backend/multiproc.py b/backend/multiproc.py
new file mode 100755
index 0000000..0444a60
--- /dev/null
+++ b/backend/multiproc.py
@@ -0,0 +1,163 @@
+import multiprocessing
+import collections.abc
+import time
+import sys
+import traceback
+import functools
+import pickle
+
+class Multiproc(object):
+ def __init__(self, num_procs, max_queue_size = 1000, proc_check_interval = .1):
+ self.num_procs = num_procs
+ self.max_queue_size = max_queue_size
+ self.proc_check_interval = proc_check_interval
+
+ self.reader = MultiprocIterable
+ self.reader_args = []
+ self.reader_kwargs = {}
+
+ self.processor = MultiprocWorker
+ self.processor_args = []
+ self.processor_kwargs = {}
+
+ self.writer = MultiprocWorker
+ self.writer_args = []
+ self.writer_kwargs = {}
+
+ def reader_init(self, reader_cls, *reader_args, **reader_kwargs):
+ self.reader = reader_cls
+ self.reader_args = reader_args
+ self.reader_kwargs = reader_kwargs
+
+ def processor_init(self, processor_cls, *processor_args, **processor_kwargs):
+ self.processor = processor_cls
+ self.processor_args = processor_args
+ self.processor_kwargs = processor_kwargs
+
+ def writer_init(self, writer_cls, *writer_args, **writer_kwargs):
+ self.writer = writer_cls
+ self.writer_args = writer_args
+ self.writer_kwargs = writer_kwargs
+
+
+ def run(self):
+ with multiprocessing.Pool(self.num_procs + 2) as pool, multiprocessing.Manager() as manager:
+ inputQueue = manager.Queue(self.max_queue_size)
+ resultsQueue = manager.Queue(self.max_queue_size)
+ reader_proc = pool.apply_async(reader_loop, (inputQueue, self.num_procs, self.reader, self.reader_args, self.reader_kwargs))
+
+ worker_procs = []
+ for _ in range(self.num_procs):
+ wp = pool.apply_async(processor_loop, (inputQueue, resultsQueue, self.processor, self.processor_args, self.processor_kwargs))
+ worker_procs.append(wp)
+
+ writer_proc = pool.apply_async(writer_loop, (resultsQueue, self.num_procs, self.writer, self.writer_args, self.writer_kwargs))
+
+ self.cleanup(reader_proc, worker_procs, writer_proc)
+
+ def cleanup(self, reader_proc, worker_procs, writer_proc):
+ reader_working = True
+ processor_working = True
+ writer_working = True
+ while reader_working or processor_working or writer_working:
+ if reader_working and reader_proc.ready():
+ reader_proc.get()
+ reader_working = False
+
+ if processor_working:
+ new_procs = []
+ for p in worker_procs:
+ if p.ready():
+ p.get()
+ else:
+ new_procs.append(p)
+ if len(new_procs) < 1:
+ processor_working = False
+ else:
+ worker_procs = new_procs
+
+ if writer_working and writer_proc.ready():
+ writer_proc.get()
+ writer_working = False
+ time.sleep(self.proc_check_interval)
+
+def catch_remote_exceptions(wrapped_function):
+ """ https://stackoverflow.com/questions/6126007/python-getting-a-traceback """
+
+ @functools.wraps(wrapped_function)
+ def new_function(*args, **kwargs):
+ try:
+ return wrapped_function(*args, **kwargs)
+
+ except:
+ raise Exception( "".join(traceback.format_exception(*sys.exc_info())) )
+
+ return new_function
+
+@catch_remote_exceptions
+def reader_loop(inputQueue, num_workers, reader_cls, reader_args, reader_kwargs):
+ with reader_cls(*reader_args, **reader_kwargs) as R:
+ for dat in R:
+ inputQueue.put(dat)
+ for i in range(num_workers):
+ inputQueue.put(_QueueDone(count = i))
+
+@catch_remote_exceptions
+def processor_loop(inputQueue, resultsQueue, processor_cls, processor_args, processor_kwargs):
+ with processor_cls(*processor_args, **processor_kwargs) as Proc:
+ while True:
+ dat = inputQueue.get()
+ if isinstance(dat, _QueueDone):
+ resultsQueue.put(dat)
+ break
+ try:
+ if isinstance(dat, tuple):
+ procced_dat = Proc(*dat)
+ else:
+ procced_dat = Proc(dat)
+ except SkipCallMultiProc:
+ pass
+ except:
+ raise
+ resultsQueue.put(procced_dat)
+
+@catch_remote_exceptions
+def writer_loop(resultsQueue, num_workers, writer_cls, writer_args, writer_kwargs):
+ complete_workers = 0
+ with writer_cls(*writer_args, **writer_kwargs) as W:
+ if W is None:
+ raise AttributeError(f"Worker was created, but closure failed to form")
+ while complete_workers < num_workers:
+ dat = resultsQueue.get()
+ if isinstance(dat, _QueueDone):
+ complete_workers += 1
+ else:
+ if isinstance(dat, tuple):
+ W(*dat)
+ else:
+ W(dat)
+
+class SkipCallMultiProc(Exception):
+ pass
+
+class _QueueDone(object):
+ def __init__(self, count = 0):
+ self.count = count
+
+class MultiprocWorker(collections.abc.Callable):
+
+ def __call__(self, *args):
+ return None
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+class MultiprocIterable(MultiprocWorker, collections.abc.Iterator):
+ def __next__(self):
+ raise StopIteration
+
+ def __call__(self, *args):
+ return next(self)
diff --git a/backend/pgn_parsering.py b/backend/pgn_parsering.py
new file mode 100755
index 0000000..e776e96
--- /dev/null
+++ b/backend/pgn_parsering.py
@@ -0,0 +1,58 @@
+import re
+import bz2
+
+import chess.pgn
+
+moveRegex = re.compile(r'\d+[.][ \.](\S+) (?:{[^}]*} )?(\S+)')
+
+class GamesFile(object):
+ def __init__(self, path):
+ if path.endswith('bz2'):
+ self.f = bz2.open(path, 'rt')
+ else:
+ self.f = open(path, 'r')
+ self.path = path
+ self.i = 0
+
+ def __iter__(self):
+ try:
+ while True:
+ yield next(self)
+ except StopIteration:
+ return
+
+ def __del__(self):
+ try:
+ self.f.close()
+ except AttributeError:
+ pass
+
+ def __next__(self):
+
+ ret = {}
+ lines = ''
+ for l in self.f:
+ self.i += 1
+ lines += l
+ if len(l) < 2:
+ if len(ret) >= 2:
+ break
+ else:
+ raise RuntimeError(l)
+ else:
+ try:
+ k, v, _ = l.split('"')
+ except ValueError:
+ #bad line
+ if l == 'null\n':
+ pass
+ else:
+ raise
+ else:
+ ret[k[1:-1]] = v
+ nl = self.f.readline()
+ lines += nl
+ lines += self.f.readline()
+ if len(lines) < 1:
+ raise StopIteration
+ return ret, lines
diff --git a/backend/pgn_to_csv.py b/backend/pgn_to_csv.py
new file mode 100755
index 0000000..89dcb8c
--- /dev/null
+++ b/backend/pgn_to_csv.py
@@ -0,0 +1,437 @@
+import chess
+import chess.pgn
+import re
+import datetime
+import json
+import os.path
+import io
+import bz2
+import multiprocessing
+import functools
+
+import pandas
+low_time_threshold = 30
+winrate_blunder_threshold = .1
+
+time_regex = re.compile(r'\[%clk (\d+):(\d+):(\d+)\]')
+eval_regex = re.compile(r'\[%eval ([0-9.+-]+)|(#(-)?[0-9]+)\]')
+low_time_re = re.compile(r'(\d+\. )?\S+ \{ \[%clk 0:00:[210]\d\]')
+
+class NoStockfishEvals(Exception):
+ pass
+
+pieces = {
+ 'pawn' : "P",
+ 'knight' : "N",
+ 'bishop' : "B",
+ 'rook' : "R",
+ 'queen' : "Q",
+ #'king' : "K", #Ignoring kings for the counts
+}
+
+board_stats_header = [
+ 'active_bishop_count',
+ 'active_knight_count',
+ 'active_pawn_count',
+ 'active_queen_count',
+ 'active_rook_count',
+ 'is_check',
+ 'num_legal_moves',
+ 'opp_bishop_count',
+ 'opp_knight_count',
+ 'opp_pawn_count',
+ 'opp_queen_count',
+ 'opp_rook_count',
+]
+
+moveRe = re.compile(r"^\S+")
+probRe = re.compile(r"\(P: +([^)%]+)%\)")
+uRe = re.compile(r"\(U: +([^)]+)\)")
+qRe = re.compile(r"\(Q: +([^)]+)\)")
+nRe = re.compile(r" N: +(\d+) \(")
+
+fenComps = 'rrqn2k1/8/pPp4p/2Pp1pp1/3Pp3/4P1P1/R2NB1PP/1Q4K1 w KQkq - 0 1'.split()
+
+cpLookup = None
+cpLookup_simple = None
+
+def cp_to_winrate(cp, lookup_file = os.path.join(os.path.dirname(__file__), '../data/cp_winrate_lookup_simple.json'), allow_nan = False):
+ global cpLookup_simple
+ try:
+ cp = int(float(cp) * 10) / 10
+ except OverflowError:
+ return float("nan")
+ except ValueError:
+ #This can be caused by a bunch of other things too so this option is dangerous
+ if allow_nan:
+ return float("nan")
+ else:
+ raise
+ if cpLookup_simple is None:
+ with open(lookup_file) as f:
+ cpLookup_str = json.load(f)
+ cpLookup_simple = {float(k) : wr for k, wr in cpLookup_str.items()}
+ try:
+ return cpLookup_simple[cp]
+ except KeyError:
+ return float("nan")
+
+def cp_to_winrate_elo(cp, elo = 1500, lookup_file = os.path.join(os.path.dirname(__file__), '../data/cp_winrate_lookup.json'), allow_nan = False):
+ global cpLookup
+ try:
+ cp = int(float(cp) * 10) / 10
+ elo = int(float(elo)//100) * 100
+ except OverflowError:
+ return float("nan")
+ except ValueError:
+ #This can be caused by a bunch of other things too so this option is dangerous
+ if allow_nan:
+ return float("nan")
+ else:
+ raise
+ if cpLookup is None:
+ with open(lookup_file) as f:
+ cpLookup_str = json.load(f)
+ cpLookup = {}
+ for k, v in cpLookup_str.items():
+ cpLookup[int(k)] = {float(k) : wr for k, wr in v.items()}
+ try:
+ return cpLookup[elo][cp]
+ except KeyError:
+ return float("nan")
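+
+# Usage sketch (illustrative; relies on the lookup JSONs shipped under data/):
+#
+#   cp_to_winrate(0.0)                # win probability of the mover at equality
+#   cp_to_winrate_elo(1.5, elo=1850)  # same, looked up in the 1800 Elo bucket
+#
+# Both truncate the centipawn value to one decimal place (and floor the Elo to
+# the nearest 100) before the table lookup, returning NaN on a missing key.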
+
+def board_stats(input_board, board_fen = None):
+ if isinstance(input_board, str):
+ board = chess.Board(fen=input_board)
+ board_fen = input_board
+ else:
+ board = input_board
+ if board_fen is None:
+ board_fen = input_board.fen()
+ board_str = board_fen.split(' ')[0]
+ dat = {
+ 'num_legal_moves' : len(list(board.legal_moves)),
+ 'is_check' : int(board.is_check())
+ }
+ for name, p in pieces.items():
+ if active_is_white(board_fen):
+ dat[f'active_{name}_count'] = board_fen.count(p)
+ dat[f'opp_{name}_count'] = board_fen.count(p.lower())
+ else:
+ dat[f'active_{name}_count'] = board_fen.count(p.lower())
+ dat[f'opp_{name}_count'] = board_fen.count(p)
+ return dat
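+
+# Worked example (illustrative): on the starting position, board_stats(chess.Board())
+# gives num_legal_moves == 20, is_check == 0, and pawn/knight/bishop/rook/queen
+# counts of 8/2/2/2/1 for both the active and opposing side.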
+
+def active_is_white(fen_str):
+ return fen_str.split(' ')[1] == 'w'
+
+def time_control_to_secs(timeStr, moves_per_game = 35):
+ if timeStr == '-':
+ return 10800 # 180 minutes per side max on lichess
+ else:
+ t_base, t_add = timeStr.split('+')
+ return int(t_base) + int(t_add) * moves_per_game
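+
+# Worked example (illustrative): a "300+3" blitz control yields 300 + 3 * 35 = 405
+# seconds of expected thinking time per player, while "-" (no clock) is treated
+# as the lichess maximum of 10800 seconds per side.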
+
+def fen_extend(s):
+ splitS = s.split()
+ return ' '.join(splitS + fenComps[len(splitS):])
+
+def fen(s):
+ return chess.Board(fen_extend(s))
+
+def gameToFenSeq(game):
+ headers = dict(game)
+ moves = getBoardMoveMap(game)
+ return {'headers' : headers, 'moves' : moves}
+
+def getMoveStats(s):
+ return {
+ 'move' : moveRe.match(s).group(0),
+ 'prob' : float(probRe.search(s).group(1)) / 100,
+ 'U' : float(uRe.search(s).group(1)),
+ 'Q' : float(qRe.search(s).group(1)),
+ 'N' : float(nRe.search(s).group(1)),
+ }
+
+def movesToUCI(moves, board):
+ if isinstance(board, str):
+ board = fen(board)
+ moveMap = {}
+ for m in moves:
+ board.push_san(m)
+ moveMap[m] = board.pop().uci()
+ return moveMap
+
+def getSeqs(inputNode):
+ retSeqs = []
+ for k, v in list(inputNode.items()):
+ if k == 'hits' or k == 'sfeval':
+ pass
+ elif len(v) <= 2:
+ retSeqs.append([k])
+ else:
+ retSeqs += [[k] + s for s in getSeqs(v)]
+ return retSeqs
+
+def moveSeqToBoard(seq):
+ board = chess.Board()
+ for m in seq:
+ board.push_san(m.replace('?', '').replace('!', ''))
+ return board
+
+def makeFEN(seq):
+ board = moveSeqToBoard(seq)
+ return ','.join(seq), board.fen(), len(list(board.legal_moves))
+
+def moveTreeLookup(d, procs = 64):
+ sequences = getSeqs(d)
+ with multiprocessing.Pool(procs) as pool:
+ maps = pool.map(makeFEN, sequences)
+ return maps
+
+colours = {
+ 'blue' : '\033[94m',
+ 'green' : '\033[92m',
+ 'yellow' : '\033[93m',
+ 'red' : '\033[91m',
+ 'pink' : '\033[95m',
+}
+endColour = '\033[0m'
+
+all_per_game_vals = [
+ 'game_id',
+ 'type',
+ 'result',
+ 'white_player',
+ 'black_player',
+ 'white_elo',
+ 'black_elo',
+ 'time_control',
+ 'num_ply',
+ 'termination',
+ 'white_won',
+ 'black_won',
+ 'no_winner',
+]
+
+
+per_game_funcs = {
+ 'game_id' : lambda x : x['Site'].split('/')[-1],
+ 'type' : lambda x : x['Event'].split(' tournament')[0].replace(' game', '').replace('Rated ', ''),
+ 'result' : lambda x : x['Result'],
+ 'white_player' : lambda x : x['White'],
+ 'black_player' : lambda x : x['Black'],
+ 'white_elo' : lambda x : x['WhiteElo'],
+ 'black_elo' : lambda x : x['BlackElo'],
+ 'time_control' : lambda x : x['TimeControl'],
+ 'termination' : lambda x : x['Termination'],
+ 'white_won' : lambda x : x['Result'] == '1-0',
+ 'black_won' : lambda x : x['Result'] == '0-1',
+ 'no_winner' : lambda x : x['Result'] not in ['1-0', '0-1'],
+}
+
+all_per_move_vals = [
+ 'move_ply',
+ 'move',
+ 'cp',
+ 'cp_rel',
+ 'cp_loss',
+ 'is_blunder_cp',
+ 'winrate',
+ 'winrate_elo',
+ 'winrate_loss',
+ 'is_blunder_wr',
+ 'opp_winrate',
+ 'white_active',
+ 'active_player',
+ 'active_elo',
+ 'opponent_elo',
+ 'active_won',
+ 'is_capture',
+ 'clock',
+ 'opp_clock',
+ 'clock_percent',
+ 'opp_clock_percent',
+ 'low_time',
+ 'board',
+]
+
+per_move_funcs = {
+ 'move_ply' : lambda x : x['i'],
+ 'move' : lambda x : x['node'].move,
+ 'cp' : lambda x : x['cp_str_last'],
+ 'cp_rel' : lambda x : x['cp_rel_str_last'],
+ 'cp_loss' : lambda x : f"{x['cp_loss']:.2f}",
+ 'is_blunder_cp' : lambda x : x['cp_loss'] >= 2,
+ 'winrate' : lambda x : f"{x['winrate_current']:.4f}",
+ 'winrate_elo' : lambda x : f"{x['winrate_current_elo']:.4f}",
+    'winrate_loss' : lambda x : f"{x['winrate_loss']:.4f}",
+ 'is_blunder_wr' : lambda x : x['winrate_loss'] > winrate_blunder_threshold,
+ 'opp_winrate' : lambda x : f"{x['winrate_opp']:.4f}",
+ 'white_active' : lambda x : x['is_white'],
+ 'active_player' : lambda x : x['white_player'] if x['is_white'] else x['black_player'],
+ 'active_elo' : lambda x : x['act_elo'],
+ 'opponent_elo' : lambda x : x['opp_elo'],
+ 'active_won' : lambda x : x['act_won'],
+ 'is_capture' : lambda x : x['board'].is_capture(x['node'].move),
+ 'clock' : lambda x : x['clock_seconds'],
+ 'opp_clock' : lambda x : x['last_clock_seconds'],
+ 'clock_percent' : lambda x : '' if x['no_time'] else f"{1 - x['clock_seconds']/x['time_per_player']:.3f}",
+ 'opp_clock_percent' : lambda x : '' if x['no_time'] else f"{1 - x['last_clock_seconds']/x['time_per_player']:.3f}",
+ 'low_time' : lambda x : '' if x['no_time'] else x['clock_seconds'] < low_time_threshold,
+ 'board' : lambda x : x['fen'],
+}
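+
+# Note (illustrative): every per-move lambda receives the locals() dict of the
+# move loop in gameToCSVlines below, so per_move_funcs['active_elo'](x) simply
+# reads x['act_elo'], the value computed for the current ply.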
+
+full_csv_header = all_per_game_vals + all_per_move_vals + board_stats_header
+
+def gameToDF(input_game, per_game_vals = None, per_move_vals = None, with_board_stats = True, allow_non_sf = False):
+    """Hack to make dataframes instead of CSVs while keeping the same code as much as possible"""
+
+    csv_lines = gameToCSVlines(input_game, per_game_vals = per_game_vals, per_move_vals = per_move_vals, with_board_stats = with_board_stats, allow_non_sf = allow_non_sf)
+
+    # Mirror the defaulting done in gameToCSVlines so the header matches the rows.
+    if per_game_vals is None:
+        per_game_vals = all_per_game_vals
+    if per_move_vals is None:
+        per_move_vals = all_per_move_vals
+    csv_header = list(per_game_vals) + list(per_move_vals)
+    if with_board_stats:
+        csv_header = csv_header + board_stats_header
+
+    # a hack, but it keeps things consistent
+ return pandas.read_csv(io.StringIO('\n'.join(csv_lines)), names = csv_header)
+
+def gameToCSVlines(input_game, per_game_vals = None, per_move_vals = None, with_board_stats = True, allow_non_sf = True):
+    """Main function for creating the datasets.
+
+    There are per-game and per-move values to calculate; with_board_stats just adds a set of material counts.
+
+    The individual functions are simple and mostly stored in two dicts: per_game_funcs and per_move_funcs. The per-move functions are more involved and can depend on many intermediate values, so they are handed locals() as input. This is a hack, but it works: they all used to live in the local namespace, and passing locals() was much simpler than rewriting them.
+    """
+    if isinstance(input_game, str):
+        game = chess.pgn.read_game(io.StringIO(input_game))
+    else:
+        game = input_game
+
+    # defaults to everything
+ if per_game_vals is None:
+ per_game_vals = all_per_game_vals
+ if per_move_vals is None:
+ per_move_vals = all_per_move_vals
+
+ gameVals = []
+ retVals = []
+
+ for n in per_game_vals:
+ try:
+ gameVals.append(per_game_funcs[n](game.headers))
+ except KeyError as e:
+ if n == 'num_ply':
+ gameVals.append(len(list(game.mainline())))
+ else:
+ raise KeyError(f"{e} for header: {game.headers}\ngame: {input_game}")
+
+ gameVals = [str(v) for v in gameVals]
+
+ white_won = game.headers['Result'] == '1-0'
+ no_winner = game.headers['Result'] not in ['1-0', '0-1']
+
+ time_per_player = time_control_to_secs(game.headers['TimeControl'])
+
+ board = game.board()
+ cp_board = .1
+ cp_str_last = '0.1'
+ cp_rel_str_last = '0.1'
+ no_time = False
+ last_clock_seconds = -1
+ white_player = game.headers['White']
+ black_player = game.headers['Black']
+
+ for i, node in enumerate(game.mainline()):
+ comment = node.comment.replace('\n', ' ')
+ moveVals = []
+ fen = str(board.fen())
+ is_white = fen.split(' ')[1] == 'w'
+
+ try:
+ cp_re = eval_regex.search(comment)
+ cp_str = cp_re.group(1)
+ except AttributeError:
+ if i > 2:
+ #Indicates mate
+ if not is_white:
+ cp_str = '#-0'
+ cp_after = float('-inf')
+ else:
+ cp_str = '#0'
+ cp_after = float('inf')
+ else:
+ if not allow_non_sf:
+ break
+ else:
+ cp_str = 'nan'
+ cp_after = float('nan')
+ #raise AttributeError(f"weird comment found: {node.comment}")
+ else:
+ if cp_str is not None:
+ try:
+ cp_after = float(cp_str)
+ except ValueError:
+ if '-' in comment:
+ cp_after = float('-inf')
+ else:
+ cp_after = float('inf')
+ else:
+ if cp_re.group(3) is None:
+ cp_after = float('inf')
+ else:
+ cp_after = float('-inf')
+ if not is_white:
+ cp_after *= -1
+ cp_rel_str = str(-cp_after)
+ if not no_time:
+ try:
+ timesRe = time_regex.search(comment)
+
+ clock_seconds = int(timesRe.group(1)) * 60 * 60 + int(timesRe.group(2)) * 60 + int(timesRe.group(3))
+
+ except AttributeError:
+ no_time = True
+ clock_seconds = ''
+
+ # make equal on first move
+ if last_clock_seconds < 0:
+ last_clock_seconds = clock_seconds
+
+ act_elo = game.headers['WhiteElo'] if is_white else game.headers['BlackElo']
+ opp_elo = game.headers['BlackElo'] if is_white else game.headers['WhiteElo']
+ if no_winner:
+ act_won = False
+ elif is_white:
+ act_won = white_won
+ else:
+ act_won = not white_won
+
+ cp_loss = cp_board - cp_after # CPs are all relative
+
+ winrate_current_elo = cp_to_winrate_elo(cp_board, elo = act_elo, allow_nan = allow_non_sf)
+ winrate_current = cp_to_winrate(cp_board, allow_nan = allow_non_sf)
+
+        winrate_loss = winrate_current - cp_to_winrate(cp_after, allow_nan = allow_non_sf)
+
+ winrate_opp = cp_to_winrate(-cp_board, allow_nan = allow_non_sf)
+
+ for n in per_move_vals:
+ moveVals.append(per_move_funcs[n](locals()))
+
+ if with_board_stats:
+ moveVals += [str(v) for k,v in sorted(board_stats(board, fen).items(), key = lambda x : x[0])]
+
+ board.push(node.move)
+
+ moveVals = [str(v) for v in moveVals]
+
+ retVals.append(','.join(gameVals + moveVals))
+ cp_board = -1 * cp_after
+ cp_str_last = cp_str
+ cp_rel_str_last = cp_rel_str
+ last_clock_seconds = clock_seconds
+ if len(retVals) < 1 and not allow_non_sf:
+ raise NoStockfishEvals("No evals found in game")
+ return retVals
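+
+if __name__ == '__main__':
+    # Minimal demonstration (not part of the original module): convert the games
+    # in a plain-text PGN file given on the command line into CSV rows on stdout,
+    # skipping games that carry no Stockfish evals.
+    import sys
+    print(','.join(full_csv_header))
+    with open(sys.argv[1]) as f:
+        while True:
+            game = chess.pgn.read_game(f)
+            if game is None:
+                break
+            try:
+                for row in gameToCSVlines(game, allow_non_sf = False):
+                    print(row)
+            except NoStockfishEvals:
+                continue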
diff --git a/backend/proto/__init__.py b/backend/proto/__init__.py
new file mode 100644
index 0000000..914b673
--- /dev/null
+++ b/backend/proto/__init__.py
@@ -0,0 +1 @@
+from .net_pb2 import Net, NetworkFormat
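+
+# Usage sketch (illustrative; the weights path is hypothetical). lc0 network files
+# are typically distributed as gzip-compressed serialized Net messages, so loading
+# one looks roughly like:
+#
+#   import gzip
+#   net = Net()
+#   with gzip.open('maia-1100.pb.gz', 'rb') as f:
+#       net.ParseFromString(f.read())
+#   print(net.format.network_format.network)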
diff --git a/backend/proto/net.proto b/backend/proto/net.proto
new file mode 100644
index 0000000..0262b2e
--- /dev/null
+++ b/backend/proto/net.proto
@@ -0,0 +1,163 @@
+/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2018 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
+
+ Additional permission under GNU GPL version 3 section 7
+
+ If you modify this Program, or any covered work, by linking or
+ combining it with NVIDIA Corporation's libraries from the NVIDIA CUDA
+ Toolkit and the NVIDIA CUDA Deep Neural Network library (or a
+ modified version of those libraries), containing parts covered by the
+ terms of the respective license agreement, the licensors of this
+ Program grant you additional permission to convey the resulting work.
+*/
+syntax = "proto2";
+
+package pblczero;
+
+message EngineVersion {
+ optional uint32 major = 1;
+ optional uint32 minor = 2;
+ optional uint32 patch = 3;
+}
+
+message Weights {
+ message Layer {
+ optional float min_val = 1;
+ optional float max_val = 2;
+ optional bytes params = 3;
+ }
+
+ message ConvBlock {
+ optional Layer weights = 1;
+ optional Layer biases = 2;
+ optional Layer bn_means = 3;
+ optional Layer bn_stddivs = 4;
+ optional Layer bn_gammas = 5;
+ optional Layer bn_betas = 6;
+ }
+
+ message SEunit {
+ // Squeeze-excitation unit (https://arxiv.org/abs/1709.01507)
+ // weights and biases of the two fully connected layers.
+ optional Layer w1 = 1;
+ optional Layer b1 = 2;
+ optional Layer w2 = 3;
+ optional Layer b2 = 4;
+ }
+
+ message Residual {
+ optional ConvBlock conv1 = 1;
+ optional ConvBlock conv2 = 2;
+ optional SEunit se = 3;
+ }
+
+ // Input convnet.
+ optional ConvBlock input = 1;
+
+ // Residual tower.
+ repeated Residual residual = 2;
+
+ // Policy head
+ // Extra convolution for AZ-style policy head
+ optional ConvBlock policy1 = 11;
+ optional ConvBlock policy = 3;
+ optional Layer ip_pol_w = 4;
+ optional Layer ip_pol_b = 5;
+
+ // Value head
+ optional ConvBlock value = 6;
+ optional Layer ip1_val_w = 7;
+ optional Layer ip1_val_b = 8;
+ optional Layer ip2_val_w = 9;
+ optional Layer ip2_val_b = 10;
+}
+
+message TrainingParams {
+ optional uint32 training_steps = 1;
+ optional float learning_rate = 2;
+ optional float mse_loss = 3;
+ optional float policy_loss = 4;
+ optional float accuracy = 5;
+ optional string lc0_params = 6;
+}
+
+message NetworkFormat {
+ // Format to encode the input planes with. Used by position encoder.
+ enum InputFormat {
+ INPUT_UNKNOWN = 0;
+ INPUT_CLASSICAL_112_PLANE = 1;
+ // INPUT_WITH_COORDINATE_PLANES = 2; // Example. Uncomment/rename.
+ }
+ optional InputFormat input = 1;
+
+ // Output format of the NN. Used by search code to interpret results.
+ enum OutputFormat {
+ OUTPUT_UNKNOWN = 0;
+ OUTPUT_CLASSICAL = 1;
+ OUTPUT_WDL = 2;
+ }
+ optional OutputFormat output = 2;
+
+ // Network architecture. Used by backends to build the network.
+ enum NetworkStructure {
+ // Networks without PolicyFormat or ValueFormat specified
+ NETWORK_UNKNOWN = 0;
+ NETWORK_CLASSICAL = 1;
+ NETWORK_SE = 2;
+ // Networks with PolicyFormat and ValueFormat specified
+ NETWORK_CLASSICAL_WITH_HEADFORMAT = 3;
+ NETWORK_SE_WITH_HEADFORMAT = 4;
+ }
+ optional NetworkStructure network = 3;
+
+ // Policy head architecture
+ enum PolicyFormat {
+ POLICY_UNKNOWN = 0;
+ POLICY_CLASSICAL = 1;
+ POLICY_CONVOLUTION = 2;
+ }
+ optional PolicyFormat policy = 4;
+
+ // Value head architecture
+ enum ValueFormat {
+ VALUE_UNKNOWN = 0;
+ VALUE_CLASSICAL = 1;
+ VALUE_WDL = 2;
+ }
+ optional ValueFormat value = 5;
+}
+
+message Format {
+ enum Encoding {
+ UNKNOWN = 0;
+ LINEAR16 = 1;
+ }
+
+ optional Encoding weights_encoding = 1;
+ // If network_format is missing, it's assumed to have
+ // INPUT_CLASSICAL_112_PLANE / OUTPUT_CLASSICAL / NETWORK_CLASSICAL format.
+ optional NetworkFormat network_format = 2;
+}
+
+message Net {
+ optional fixed32 magic = 1;
+ optional string license = 2;
+ optional EngineVersion min_version = 3;
+ optional Format format = 4;
+ optional TrainingParams training_params = 5;
+ optional Weights weights = 10;
+}
diff --git a/backend/proto/net_pb2.py b/backend/proto/net_pb2.py
new file mode 100644
index 0000000..fe28f04
--- /dev/null
+++ b/backend/proto/net_pb2.py
@@ -0,0 +1,895 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: net.proto
+
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor.FileDescriptor(
+ name='net.proto',
+ package='pblczero',
+ syntax='proto2',
+ serialized_options=None,
+ serialized_pb=b'\n\tnet.proto\x12\x08pblczero\"<\n\rEngineVersion\x12\r\n\x05major\x18\x01 \x01(\r\x12\r\n\x05minor\x18\x02 \x01(\r\x12\r\n\x05patch\x18\x03 \x01(\r\"\xe5\x08\n\x07Weights\x12*\n\x05input\x18\x01 \x01(\x0b\x32\x1b.pblczero.Weights.ConvBlock\x12,\n\x08residual\x18\x02 \x03(\x0b\x32\x1a.pblczero.Weights.Residual\x12,\n\x07policy1\x18\x0b \x01(\x0b\x32\x1b.pblczero.Weights.ConvBlock\x12+\n\x06policy\x18\x03 \x01(\x0b\x32\x1b.pblczero.Weights.ConvBlock\x12)\n\x08ip_pol_w\x18\x04 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12)\n\x08ip_pol_b\x18\x05 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12*\n\x05value\x18\x06 \x01(\x0b\x32\x1b.pblczero.Weights.ConvBlock\x12*\n\tip1_val_w\x18\x07 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12*\n\tip1_val_b\x18\x08 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12*\n\tip2_val_w\x18\t \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12*\n\tip2_val_b\x18\n \x01(\x0b\x32\x17.pblczero.Weights.Layer\x1a\x39\n\x05Layer\x12\x0f\n\x07min_val\x18\x01 \x01(\x02\x12\x0f\n\x07max_val\x18\x02 \x01(\x02\x12\x0e\n\x06params\x18\x03 \x01(\x0c\x1a\x8d\x02\n\tConvBlock\x12(\n\x07weights\x18\x01 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12\'\n\x06\x62iases\x18\x02 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12)\n\x08\x62n_means\x18\x03 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12+\n\nbn_stddivs\x18\x04 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12*\n\tbn_gammas\x18\x05 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12)\n\x08\x62n_betas\x18\x06 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x1a\x9c\x01\n\x06SEunit\x12#\n\x02w1\x18\x01 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12#\n\x02\x62\x31\x18\x02 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12#\n\x02w2\x18\x03 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x12#\n\x02\x62\x32\x18\x04 \x01(\x0b\x32\x17.pblczero.Weights.Layer\x1a\x88\x01\n\x08Residual\x12*\n\x05\x63onv1\x18\x01 \x01(\x0b\x32\x1b.pblczero.Weights.ConvBlock\x12*\n\x05\x63onv2\x18\x02 \x01(\x0b\x32\x1b.pblczero.Weights.ConvBlock\x12$\n\x02se\x18\x03 \x01(\x0b\x32\x18.pblczero.Weights.SEunit\"\x8c\x01\n\x0eTrainingParams\x12\x16\n\x0etraining_steps\x18\x01 \x01(\r\x12\x15\n\rlearning_rate\x18\x02 \x01(\x02\x12\x10\n\x08mse_loss\x18\x03 \x01(\x02\x12\x13\n\x0bpolicy_loss\x18\x04 \x01(\x02\x12\x10\n\x08\x61\x63\x63uracy\x18\x05 \x01(\x02\x12\x12\n\nlc0_params\x18\x06 \x01(\t\"\xd9\x05\n\rNetworkFormat\x12\x32\n\x05input\x18\x01 \x01(\x0e\x32#.pblczero.NetworkFormat.InputFormat\x12\x34\n\x06output\x18\x02 \x01(\x0e\x32$.pblczero.NetworkFormat.OutputFormat\x12\x39\n\x07network\x18\x03 \x01(\x0e\x32(.pblczero.NetworkFormat.NetworkStructure\x12\x34\n\x06policy\x18\x04 \x01(\x0e\x32$.pblczero.NetworkFormat.PolicyFormat\x12\x32\n\x05value\x18\x05 \x01(\x0e\x32#.pblczero.NetworkFormat.ValueFormat\"?\n\x0bInputFormat\x12\x11\n\rINPUT_UNKNOWN\x10\x00\x12\x1d\n\x19INPUT_CLASSICAL_112_PLANE\x10\x01\"H\n\x0cOutputFormat\x12\x12\n\x0eOUTPUT_UNKNOWN\x10\x00\x12\x14\n\x10OUTPUT_CLASSICAL\x10\x01\x12\x0e\n\nOUTPUT_WDL\x10\x02\"\x95\x01\n\x10NetworkStructure\x12\x13\n\x0fNETWORK_UNKNOWN\x10\x00\x12\x15\n\x11NETWORK_CLASSICAL\x10\x01\x12\x0e\n\nNETWORK_SE\x10\x02\x12%\n!NETWORK_CLASSICAL_WITH_HEADFORMAT\x10\x03\x12\x1e\n\x1aNETWORK_SE_WITH_HEADFORMAT\x10\x04\"P\n\x0cPolicyFormat\x12\x12\n\x0ePOLICY_UNKNOWN\x10\x00\x12\x14\n\x10POLICY_CLASSICAL\x10\x01\x12\x16\n\x12POLICY_CONVOLUTION\x10\x02\"D\n\x0bValueFormat\x12\x11\n\rVALUE_UNKNOWN\x10\x00\x12\x13\n\x0fVALUE_CLASSICAL\x10\x01\x12\r\n\tVALUE_WDL\x10\x02\"\x95\x01\n\x06\x46ormat\x12\x33\n\x10weights_encoding\x18\x01 
\x01(\x0e\x32\x19.pblczero.Format.Encoding\x12/\n\x0enetwork_format\x18\x02 \x01(\x0b\x32\x17.pblczero.NetworkFormat\"%\n\x08\x45ncoding\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0c\n\x08LINEAR16\x10\x01\"\xcc\x01\n\x03Net\x12\r\n\x05magic\x18\x01 \x01(\x07\x12\x0f\n\x07license\x18\x02 \x01(\t\x12,\n\x0bmin_version\x18\x03 \x01(\x0b\x32\x17.pblczero.EngineVersion\x12 \n\x06\x66ormat\x18\x04 \x01(\x0b\x32\x10.pblczero.Format\x12\x31\n\x0ftraining_params\x18\x05 \x01(\x0b\x32\x18.pblczero.TrainingParams\x12\"\n\x07weights\x18\n \x01(\x0b\x32\x11.pblczero.Weights'
+)
+
+
+
+_NETWORKFORMAT_INPUTFORMAT = _descriptor.EnumDescriptor(
+ name='InputFormat',
+ full_name='pblczero.NetworkFormat.InputFormat',
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name='INPUT_UNKNOWN', index=0, number=0,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='INPUT_CLASSICAL_112_PLANE', index=1, number=1,
+ serialized_options=None,
+ type=None),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=1645,
+ serialized_end=1708,
+)
+_sym_db.RegisterEnumDescriptor(_NETWORKFORMAT_INPUTFORMAT)
+
+_NETWORKFORMAT_OUTPUTFORMAT = _descriptor.EnumDescriptor(
+ name='OutputFormat',
+ full_name='pblczero.NetworkFormat.OutputFormat',
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name='OUTPUT_UNKNOWN', index=0, number=0,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='OUTPUT_CLASSICAL', index=1, number=1,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='OUTPUT_WDL', index=2, number=2,
+ serialized_options=None,
+ type=None),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=1710,
+ serialized_end=1782,
+)
+_sym_db.RegisterEnumDescriptor(_NETWORKFORMAT_OUTPUTFORMAT)
+
+_NETWORKFORMAT_NETWORKSTRUCTURE = _descriptor.EnumDescriptor(
+ name='NetworkStructure',
+ full_name='pblczero.NetworkFormat.NetworkStructure',
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name='NETWORK_UNKNOWN', index=0, number=0,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='NETWORK_CLASSICAL', index=1, number=1,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='NETWORK_SE', index=2, number=2,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='NETWORK_CLASSICAL_WITH_HEADFORMAT', index=3, number=3,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='NETWORK_SE_WITH_HEADFORMAT', index=4, number=4,
+ serialized_options=None,
+ type=None),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=1785,
+ serialized_end=1934,
+)
+_sym_db.RegisterEnumDescriptor(_NETWORKFORMAT_NETWORKSTRUCTURE)
+
+_NETWORKFORMAT_POLICYFORMAT = _descriptor.EnumDescriptor(
+ name='PolicyFormat',
+ full_name='pblczero.NetworkFormat.PolicyFormat',
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name='POLICY_UNKNOWN', index=0, number=0,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='POLICY_CLASSICAL', index=1, number=1,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='POLICY_CONVOLUTION', index=2, number=2,
+ serialized_options=None,
+ type=None),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=1936,
+ serialized_end=2016,
+)
+_sym_db.RegisterEnumDescriptor(_NETWORKFORMAT_POLICYFORMAT)
+
+_NETWORKFORMAT_VALUEFORMAT = _descriptor.EnumDescriptor(
+ name='ValueFormat',
+ full_name='pblczero.NetworkFormat.ValueFormat',
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name='VALUE_UNKNOWN', index=0, number=0,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='VALUE_CLASSICAL', index=1, number=1,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='VALUE_WDL', index=2, number=2,
+ serialized_options=None,
+ type=None),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=2018,
+ serialized_end=2086,
+)
+_sym_db.RegisterEnumDescriptor(_NETWORKFORMAT_VALUEFORMAT)
+
+_FORMAT_ENCODING = _descriptor.EnumDescriptor(
+ name='Encoding',
+ full_name='pblczero.Format.Encoding',
+ filename=None,
+ file=DESCRIPTOR,
+ values=[
+ _descriptor.EnumValueDescriptor(
+ name='UNKNOWN', index=0, number=0,
+ serialized_options=None,
+ type=None),
+ _descriptor.EnumValueDescriptor(
+ name='LINEAR16', index=1, number=1,
+ serialized_options=None,
+ type=None),
+ ],
+ containing_type=None,
+ serialized_options=None,
+ serialized_start=2201,
+ serialized_end=2238,
+)
+_sym_db.RegisterEnumDescriptor(_FORMAT_ENCODING)
+
+
+_ENGINEVERSION = _descriptor.Descriptor(
+ name='EngineVersion',
+ full_name='pblczero.EngineVersion',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='major', full_name='pblczero.EngineVersion.major', index=0,
+ number=1, type=13, cpp_type=3, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='minor', full_name='pblczero.EngineVersion.minor', index=1,
+ number=2, type=13, cpp_type=3, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='patch', full_name='pblczero.EngineVersion.patch', index=2,
+ number=3, type=13, cpp_type=3, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=23,
+ serialized_end=83,
+)
+
+
+_WEIGHTS_LAYER = _descriptor.Descriptor(
+ name='Layer',
+ full_name='pblczero.Weights.Layer',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='min_val', full_name='pblczero.Weights.Layer.min_val', index=0,
+ number=1, type=2, cpp_type=6, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='max_val', full_name='pblczero.Weights.Layer.max_val', index=1,
+ number=2, type=2, cpp_type=6, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='params', full_name='pblczero.Weights.Layer.params', index=2,
+ number=3, type=12, cpp_type=9, label=1,
+ has_default_value=False, default_value=b"",
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=584,
+ serialized_end=641,
+)
+
+_WEIGHTS_CONVBLOCK = _descriptor.Descriptor(
+ name='ConvBlock',
+ full_name='pblczero.Weights.ConvBlock',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='weights', full_name='pblczero.Weights.ConvBlock.weights', index=0,
+ number=1, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='biases', full_name='pblczero.Weights.ConvBlock.biases', index=1,
+ number=2, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bn_means', full_name='pblczero.Weights.ConvBlock.bn_means', index=2,
+ number=3, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bn_stddivs', full_name='pblczero.Weights.ConvBlock.bn_stddivs', index=3,
+ number=4, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bn_gammas', full_name='pblczero.Weights.ConvBlock.bn_gammas', index=4,
+ number=5, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='bn_betas', full_name='pblczero.Weights.ConvBlock.bn_betas', index=5,
+ number=6, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=644,
+ serialized_end=913,
+)
+
+_WEIGHTS_SEUNIT = _descriptor.Descriptor(
+ name='SEunit',
+ full_name='pblczero.Weights.SEunit',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='w1', full_name='pblczero.Weights.SEunit.w1', index=0,
+ number=1, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='b1', full_name='pblczero.Weights.SEunit.b1', index=1,
+ number=2, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='w2', full_name='pblczero.Weights.SEunit.w2', index=2,
+ number=3, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='b2', full_name='pblczero.Weights.SEunit.b2', index=3,
+ number=4, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=916,
+ serialized_end=1072,
+)
+
+_WEIGHTS_RESIDUAL = _descriptor.Descriptor(
+ name='Residual',
+ full_name='pblczero.Weights.Residual',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='conv1', full_name='pblczero.Weights.Residual.conv1', index=0,
+ number=1, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='conv2', full_name='pblczero.Weights.Residual.conv2', index=1,
+ number=2, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='se', full_name='pblczero.Weights.Residual.se', index=2,
+ number=3, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=1075,
+ serialized_end=1211,
+)
+
+_WEIGHTS = _descriptor.Descriptor(
+ name='Weights',
+ full_name='pblczero.Weights',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='input', full_name='pblczero.Weights.input', index=0,
+ number=1, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='residual', full_name='pblczero.Weights.residual', index=1,
+ number=2, type=11, cpp_type=10, label=3,
+ has_default_value=False, default_value=[],
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='policy1', full_name='pblczero.Weights.policy1', index=2,
+ number=11, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='policy', full_name='pblczero.Weights.policy', index=3,
+ number=3, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='ip_pol_w', full_name='pblczero.Weights.ip_pol_w', index=4,
+ number=4, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='ip_pol_b', full_name='pblczero.Weights.ip_pol_b', index=5,
+ number=5, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='value', full_name='pblczero.Weights.value', index=6,
+ number=6, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='ip1_val_w', full_name='pblczero.Weights.ip1_val_w', index=7,
+ number=7, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='ip1_val_b', full_name='pblczero.Weights.ip1_val_b', index=8,
+ number=8, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='ip2_val_w', full_name='pblczero.Weights.ip2_val_w', index=9,
+ number=9, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='ip2_val_b', full_name='pblczero.Weights.ip2_val_b', index=10,
+ number=10, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[_WEIGHTS_LAYER, _WEIGHTS_CONVBLOCK, _WEIGHTS_SEUNIT, _WEIGHTS_RESIDUAL, ],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=86,
+ serialized_end=1211,
+)
+
+
+_TRAININGPARAMS = _descriptor.Descriptor(
+ name='TrainingParams',
+ full_name='pblczero.TrainingParams',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='training_steps', full_name='pblczero.TrainingParams.training_steps', index=0,
+ number=1, type=13, cpp_type=3, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='learning_rate', full_name='pblczero.TrainingParams.learning_rate', index=1,
+ number=2, type=2, cpp_type=6, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='mse_loss', full_name='pblczero.TrainingParams.mse_loss', index=2,
+ number=3, type=2, cpp_type=6, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='policy_loss', full_name='pblczero.TrainingParams.policy_loss', index=3,
+ number=4, type=2, cpp_type=6, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='accuracy', full_name='pblczero.TrainingParams.accuracy', index=4,
+ number=5, type=2, cpp_type=6, label=1,
+ has_default_value=False, default_value=float(0),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='lc0_params', full_name='pblczero.TrainingParams.lc0_params', index=5,
+ number=6, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=b"".decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=1214,
+ serialized_end=1354,
+)
+
+
+_NETWORKFORMAT = _descriptor.Descriptor(
+ name='NetworkFormat',
+ full_name='pblczero.NetworkFormat',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='input', full_name='pblczero.NetworkFormat.input', index=0,
+ number=1, type=14, cpp_type=8, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='output', full_name='pblczero.NetworkFormat.output', index=1,
+ number=2, type=14, cpp_type=8, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='network', full_name='pblczero.NetworkFormat.network', index=2,
+ number=3, type=14, cpp_type=8, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='policy', full_name='pblczero.NetworkFormat.policy', index=3,
+ number=4, type=14, cpp_type=8, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='value', full_name='pblczero.NetworkFormat.value', index=4,
+ number=5, type=14, cpp_type=8, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ _NETWORKFORMAT_INPUTFORMAT,
+ _NETWORKFORMAT_OUTPUTFORMAT,
+ _NETWORKFORMAT_NETWORKSTRUCTURE,
+ _NETWORKFORMAT_POLICYFORMAT,
+ _NETWORKFORMAT_VALUEFORMAT,
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=1357,
+ serialized_end=2086,
+)
+
+
+_FORMAT = _descriptor.Descriptor(
+ name='Format',
+ full_name='pblczero.Format',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='weights_encoding', full_name='pblczero.Format.weights_encoding', index=0,
+ number=1, type=14, cpp_type=8, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='network_format', full_name='pblczero.Format.network_format', index=1,
+ number=2, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ _FORMAT_ENCODING,
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=2089,
+ serialized_end=2238,
+)
+
+
+_NET = _descriptor.Descriptor(
+ name='Net',
+ full_name='pblczero.Net',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='magic', full_name='pblczero.Net.magic', index=0,
+ number=1, type=7, cpp_type=3, label=1,
+ has_default_value=False, default_value=0,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='license', full_name='pblczero.Net.license', index=1,
+ number=2, type=9, cpp_type=9, label=1,
+ has_default_value=False, default_value=b"".decode('utf-8'),
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='min_version', full_name='pblczero.Net.min_version', index=2,
+ number=3, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='format', full_name='pblczero.Net.format', index=3,
+ number=4, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='training_params', full_name='pblczero.Net.training_params', index=4,
+ number=5, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ _descriptor.FieldDescriptor(
+ name='weights', full_name='pblczero.Net.weights', index=5,
+ number=10, type=11, cpp_type=10, label=1,
+ has_default_value=False, default_value=None,
+ message_type=None, enum_type=None, containing_type=None,
+ is_extension=False, extension_scope=None,
+ serialized_options=None, file=DESCRIPTOR),
+ ],
+ extensions=[
+ ],
+ nested_types=[],
+ enum_types=[
+ ],
+ serialized_options=None,
+ is_extendable=False,
+ syntax='proto2',
+ extension_ranges=[],
+ oneofs=[
+ ],
+ serialized_start=2241,
+ serialized_end=2445,
+)
+
+_WEIGHTS_LAYER.containing_type = _WEIGHTS
+_WEIGHTS_CONVBLOCK.fields_by_name['weights'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_CONVBLOCK.fields_by_name['biases'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_CONVBLOCK.fields_by_name['bn_means'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_CONVBLOCK.fields_by_name['bn_stddivs'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_CONVBLOCK.fields_by_name['bn_gammas'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_CONVBLOCK.fields_by_name['bn_betas'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_CONVBLOCK.containing_type = _WEIGHTS
+_WEIGHTS_SEUNIT.fields_by_name['w1'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_SEUNIT.fields_by_name['b1'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_SEUNIT.fields_by_name['w2'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_SEUNIT.fields_by_name['b2'].message_type = _WEIGHTS_LAYER
+_WEIGHTS_SEUNIT.containing_type = _WEIGHTS
+_WEIGHTS_RESIDUAL.fields_by_name['conv1'].message_type = _WEIGHTS_CONVBLOCK
+_WEIGHTS_RESIDUAL.fields_by_name['conv2'].message_type = _WEIGHTS_CONVBLOCK
+_WEIGHTS_RESIDUAL.fields_by_name['se'].message_type = _WEIGHTS_SEUNIT
+_WEIGHTS_RESIDUAL.containing_type = _WEIGHTS
+_WEIGHTS.fields_by_name['input'].message_type = _WEIGHTS_CONVBLOCK
+_WEIGHTS.fields_by_name['residual'].message_type = _WEIGHTS_RESIDUAL
+_WEIGHTS.fields_by_name['policy1'].message_type = _WEIGHTS_CONVBLOCK
+_WEIGHTS.fields_by_name['policy'].message_type = _WEIGHTS_CONVBLOCK
+_WEIGHTS.fields_by_name['ip_pol_w'].message_type = _WEIGHTS_LAYER
+_WEIGHTS.fields_by_name['ip_pol_b'].message_type = _WEIGHTS_LAYER
+_WEIGHTS.fields_by_name['value'].message_type = _WEIGHTS_CONVBLOCK
+_WEIGHTS.fields_by_name['ip1_val_w'].message_type = _WEIGHTS_LAYER
+_WEIGHTS.fields_by_name['ip1_val_b'].message_type = _WEIGHTS_LAYER
+_WEIGHTS.fields_by_name['ip2_val_w'].message_type = _WEIGHTS_LAYER
+_WEIGHTS.fields_by_name['ip2_val_b'].message_type = _WEIGHTS_LAYER
+_NETWORKFORMAT.fields_by_name['input'].enum_type = _NETWORKFORMAT_INPUTFORMAT
+_NETWORKFORMAT.fields_by_name['output'].enum_type = _NETWORKFORMAT_OUTPUTFORMAT
+_NETWORKFORMAT.fields_by_name['network'].enum_type = _NETWORKFORMAT_NETWORKSTRUCTURE
+_NETWORKFORMAT.fields_by_name['policy'].enum_type = _NETWORKFORMAT_POLICYFORMAT
+_NETWORKFORMAT.fields_by_name['value'].enum_type = _NETWORKFORMAT_VALUEFORMAT
+_NETWORKFORMAT_INPUTFORMAT.containing_type = _NETWORKFORMAT
+_NETWORKFORMAT_OUTPUTFORMAT.containing_type = _NETWORKFORMAT
+_NETWORKFORMAT_NETWORKSTRUCTURE.containing_type = _NETWORKFORMAT
+_NETWORKFORMAT_POLICYFORMAT.containing_type = _NETWORKFORMAT
+_NETWORKFORMAT_VALUEFORMAT.containing_type = _NETWORKFORMAT
+_FORMAT.fields_by_name['weights_encoding'].enum_type = _FORMAT_ENCODING
+_FORMAT.fields_by_name['network_format'].message_type = _NETWORKFORMAT
+_FORMAT_ENCODING.containing_type = _FORMAT
+_NET.fields_by_name['min_version'].message_type = _ENGINEVERSION
+_NET.fields_by_name['format'].message_type = _FORMAT
+_NET.fields_by_name['training_params'].message_type = _TRAININGPARAMS
+_NET.fields_by_name['weights'].message_type = _WEIGHTS
+DESCRIPTOR.message_types_by_name['EngineVersion'] = _ENGINEVERSION
+DESCRIPTOR.message_types_by_name['Weights'] = _WEIGHTS
+DESCRIPTOR.message_types_by_name['TrainingParams'] = _TRAININGPARAMS
+DESCRIPTOR.message_types_by_name['NetworkFormat'] = _NETWORKFORMAT
+DESCRIPTOR.message_types_by_name['Format'] = _FORMAT
+DESCRIPTOR.message_types_by_name['Net'] = _NET
+_sym_db.RegisterFileDescriptor(DESCRIPTOR)
+
+EngineVersion = _reflection.GeneratedProtocolMessageType('EngineVersion', (_message.Message,), {
+ 'DESCRIPTOR' : _ENGINEVERSION,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.EngineVersion)
+ })
+_sym_db.RegisterMessage(EngineVersion)
+
+Weights = _reflection.GeneratedProtocolMessageType('Weights', (_message.Message,), {
+
+ 'Layer' : _reflection.GeneratedProtocolMessageType('Layer', (_message.Message,), {
+ 'DESCRIPTOR' : _WEIGHTS_LAYER,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.Weights.Layer)
+ })
+ ,
+
+ 'ConvBlock' : _reflection.GeneratedProtocolMessageType('ConvBlock', (_message.Message,), {
+ 'DESCRIPTOR' : _WEIGHTS_CONVBLOCK,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.Weights.ConvBlock)
+ })
+ ,
+
+ 'SEunit' : _reflection.GeneratedProtocolMessageType('SEunit', (_message.Message,), {
+ 'DESCRIPTOR' : _WEIGHTS_SEUNIT,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.Weights.SEunit)
+ })
+ ,
+
+ 'Residual' : _reflection.GeneratedProtocolMessageType('Residual', (_message.Message,), {
+ 'DESCRIPTOR' : _WEIGHTS_RESIDUAL,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.Weights.Residual)
+ })
+ ,
+ 'DESCRIPTOR' : _WEIGHTS,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.Weights)
+ })
+_sym_db.RegisterMessage(Weights)
+_sym_db.RegisterMessage(Weights.Layer)
+_sym_db.RegisterMessage(Weights.ConvBlock)
+_sym_db.RegisterMessage(Weights.SEunit)
+_sym_db.RegisterMessage(Weights.Residual)
+
+TrainingParams = _reflection.GeneratedProtocolMessageType('TrainingParams', (_message.Message,), {
+ 'DESCRIPTOR' : _TRAININGPARAMS,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.TrainingParams)
+ })
+_sym_db.RegisterMessage(TrainingParams)
+
+NetworkFormat = _reflection.GeneratedProtocolMessageType('NetworkFormat', (_message.Message,), {
+ 'DESCRIPTOR' : _NETWORKFORMAT,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.NetworkFormat)
+ })
+_sym_db.RegisterMessage(NetworkFormat)
+
+Format = _reflection.GeneratedProtocolMessageType('Format', (_message.Message,), {
+ 'DESCRIPTOR' : _FORMAT,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.Format)
+ })
+_sym_db.RegisterMessage(Format)
+
+Net = _reflection.GeneratedProtocolMessageType('Net', (_message.Message,), {
+ 'DESCRIPTOR' : _NET,
+ '__module__' : 'net_pb2'
+ # @@protoc_insertion_point(class_scope:pblczero.Net)
+ })
+_sym_db.RegisterMessage(Net)
+
+
+# @@protoc_insertion_point(module_scope)
diff --git a/backend/tf_transfer/__init__.py b/backend/tf_transfer/__init__.py
new file mode 100755
index 0000000..f0cf9f6
--- /dev/null
+++ b/backend/tf_transfer/__init__.py
@@ -0,0 +1,5 @@
+from .tfprocess import TFProcess
+from .chunkparser import ChunkParser
+from .net import *
+from .training_shared import *
+from .utils import *
diff --git a/backend/tf_transfer/chunkparser.py b/backend/tf_transfer/chunkparser.py
new file mode 100755
index 0000000..e09809d
--- /dev/null
+++ b/backend/tf_transfer/chunkparser.py
@@ -0,0 +1,458 @@
+#!/usr/bin/env python3
+#
+# This file is part of Leela Chess.
+# Copyright (C) 2018 Folkert Huizinga
+# Copyright (C) 2017-2018 Gian-Carlo Pascutto
+#
+# Leela Chess is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Leela Chess is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
+
+import itertools
+import multiprocessing as mp
+import numpy as np
+import random
+from .shufflebuffer import ShuffleBuffer
+from ..utils import printWithDate
+import struct
+import tensorflow as tf
+import unittest
+
+V4_VERSION = struct.pack('i', 4)
+V3_VERSION = struct.pack('i', 3)
+V4_STRUCT_STRING = '4s7432s832sBBBBBBBbffff'
+V3_STRUCT_STRING = '4s7432s832sBBBBBBBb'
+
+# Interface for a chunk data source.
+class ChunkDataSrc:
+ def __init__(self, items):
+ self.items = items
+ def next(self):
+ if not self.items:
+ return None
+ return self.items.pop()
+
+
+
+class ChunkParser:
+ # static batch size
+ BATCH_SIZE = 8
+ def __init__(self, chunkdatasrc, shuffle_size=1, sample=1, buffer_size=1, batch_size=256, workers=None):
+ """
+ Read data and yield batches of raw tensors.
+
+        'chunkdatasrc' is an object yielding chunkdata
+ 'shuffle_size' is the size of the shuffle buffer.
+ 'sample' is the rate to down-sample.
+ 'workers' is the number of child workers to use.
+
+ The data is represented in a number of formats through this dataflow
+ pipeline. In order, they are:
+
+ chunk: The name of a file containing chunkdata
+
+ chunkdata: type Bytes. Multiple records of v4 format where each record
+ consists of (state, policy, result, q)
+
+        raw: A byte string holding raw tensors concatenated together. This is
+ used to pass data from the workers to the parent. Exists because
+ TensorFlow doesn't have a fast way to unpack bit vectors. 7950 bytes
+ long.
+ """
+
+ # Build 2 flat float32 planes with values 0,1
+ self.flat_planes = []
+ for i in range(2):
+ self.flat_planes.append(np.zeros(64, dtype=np.float32) + i)
+
+ # set the down-sampling rate
+ self.sample = sample
+ # set the mini-batch size
+ self.batch_size = batch_size
+ # set number of elements in the shuffle buffer.
+ self.shuffle_size = shuffle_size
+ # Start worker processes, leave 2 for TensorFlow
+ if workers is None:
+ workers = max(1, mp.cpu_count() - 2)
+
+ printWithDate("Using {} worker processes.".format(workers))
+
+ # Start the child workers running
+ readers = []
+ writers = []
+ processes = []
+ for i in range(workers):
+ read, write = mp.Pipe(duplex=False)
+ p = mp.Process(target=self.task, args=(chunkdatasrc, write))
+ processes.append(p)
+ p.start()
+ readers.append(read)
+ writers.append(write)
+ printWithDate(f"{len(processes)} tasks started", end = '\r')
+ self.init_structs()
+ self.readers = readers
+ self.writers = writers
+ self.processes = processes
+ printWithDate(f"{len(processes)} tasks started")
+
+ def shutdown(self):
+ """
+ Terminates all the workers
+ """
+ for i in range(len(self.readers)):
+ self.processes[i].terminate()
+ self.processes[i].join()
+ self.readers[i].close()
+ self.writers[i].close()
+
+
+ def init_structs(self):
+ """
+ struct.Struct doesn't pickle, so it needs to be separately
+ constructed in workers.
+
+ V4 Format (8292 bytes total)
+ int32 version (4 bytes)
+ 1858 float32 probabilities (7432 bytes) (removed 66*4 = 264 bytes unused under-promotions)
+ 104 (13*8) packed bit planes of 8 bytes each (832 bytes) (no rep2 plane)
+ uint8 castling us_ooo (1 byte)
+ uint8 castling us_oo (1 byte)
+ uint8 castling them_ooo (1 byte)
+ uint8 castling them_oo (1 byte)
+ uint8 side_to_move (1 byte) aka us_black
+ uint8 rule50_count (1 byte)
+ uint8 move_count (1 byte)
+ int8 result (1 byte)
+ float32 root_q (4 bytes)
+ float32 best_q (4 bytes)
+ float32 root_d (4 bytes)
+ float32 best_d (4 bytes)
+ """
+ self.v4_struct = struct.Struct(V4_STRUCT_STRING)
+ self.v3_struct = struct.Struct(V3_STRUCT_STRING)
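+        # Size check (illustrative): the V4 fields documented above sum to
+        # 4 + 7432 + 832 + 7 + 1 + 16 = 8292 bytes, i.e. self.v4_struct.size;
+        # a V3 record is the same minus the four trailing floats (8276 bytes).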
+
+
+ @staticmethod
+ def parse_function(planes, probs, winner, q):
+ """
+ Convert unpacked record batches to tensors for tensorflow training
+ """
+ planes = tf.io.decode_raw(planes, tf.float32)
+ probs = tf.io.decode_raw(probs, tf.float32)
+ winner = tf.io.decode_raw(winner, tf.float32)
+ q = tf.io.decode_raw(q, tf.float32)
+
+ planes = tf.reshape(planes, (ChunkParser.BATCH_SIZE, 112, 8*8))
+ probs = tf.reshape(probs, (ChunkParser.BATCH_SIZE, 1858))
+ winner = tf.reshape(winner, (ChunkParser.BATCH_SIZE, 3))
+ q = tf.reshape(q, (ChunkParser.BATCH_SIZE, 3))
+
+ return (planes, probs, winner, q)
+
+
+ def convert_v4_to_tuple(self, content):
+ """
+ Unpack a v4 binary record to 4-tuple (state, policy pi, result, q)
+
+ v4 struct format is (8292 bytes total)
+ int32 version (4 bytes)
+ 1858 float32 probabilities (7432 bytes)
+ 104 (13*8) packed bit planes of 8 bytes each (832 bytes)
+ uint8 castling us_ooo (1 byte)
+ uint8 castling us_oo (1 byte)
+ uint8 castling them_ooo (1 byte)
+ uint8 castling them_oo (1 byte)
+ uint8 side_to_move (1 byte)
+ uint8 rule50_count (1 byte)
+ uint8 move_count (1 byte)
+ int8 result (1 byte)
+ float32 root_q (4 bytes)
+ float32 best_q (4 bytes)
+ float32 root_d (4 bytes)
+ float32 best_d (4 bytes)
+ """
+ (ver, probs, planes, us_ooo, us_oo, them_ooo, them_oo, stm, rule50_count, move_count, winner, root_q, best_q, root_d, best_d) = self.v4_struct.unpack(content)
+ # Enforce move_count to 0
+ move_count = 0
+
+ # Unpack bit planes and cast to 32 bit float
+ planes = np.unpackbits(np.frombuffer(planes, dtype=np.uint8)).astype(np.float32)
+ rule50_plane = (np.zeros(8*8, dtype=np.float32) + rule50_count) / 99
+
+ # Concatenate all byteplanes. Make the last plane all 1's so the NN can
+ # detect edges of the board more easily
+ planes = planes.tobytes() + \
+ self.flat_planes[us_ooo].tobytes() + \
+ self.flat_planes[us_oo].tobytes() + \
+ self.flat_planes[them_ooo].tobytes() + \
+ self.flat_planes[them_oo].tobytes() + \
+ self.flat_planes[stm].tobytes() + \
+ rule50_plane.tobytes() + \
+ self.flat_planes[move_count].tobytes() + \
+ self.flat_planes[1].tobytes()
+
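+        # 8 history steps * 13 bit planes + 8 scalar planes = 112 planes of
+        # 64 float32 squares each (112 * 64 * 4 = 28672 bytes).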
+ assert len(planes) == ((8*13*1 + 8*1*1) * 8 * 8 * 4)
+ winner = float(winner)
+ assert winner == 1.0 or winner == -1.0 or winner == 0.0
+ winner = struct.pack('fff', winner == 1.0, winner == 0.0, winner == -1.0)
+
+ best_q_w = 0.5 * (1.0 - best_d + best_q)
+ best_q_l = 0.5 * (1.0 - best_d - best_q)
+ assert -1.0 <= best_q <= 1.0 and 0.0 <= best_d <= 1.0
+ best_q = struct.pack('fff', best_q_w, best_d, best_q_l)
+
+ return (planes, probs, winner, best_q)
+
+
+ def sample_record(self, chunkdata, is_white):
+ """
+ Randomly sample through the v4 chunk data and select records
+ """
+ version = chunkdata[0:4]
+ if version == V4_VERSION:
+ record_size = self.v4_struct.size
+ elif version == V3_VERSION:
+ record_size = self.v3_struct.size
+ else:
+ return
+
+ for i in range(0, len(chunkdata), record_size):
+ if self.sample > 1:
+ # Downsample, using only 1/Nth of the items.
+ if random.randint(0, self.sample-1) != 0:
+ continue # Skip this record.
+ record = chunkdata[i:i+record_size]
+ if version == V3_VERSION:
+ # add 16 bytes of fake root_q, best_q, root_d, best_d to match V4 format
+ record += 16 * b'\x00'
+
+ (ver, probs, planes, us_ooo, us_oo, them_ooo, them_oo, stm, rule50_count, move_count, winner, root_q, best_q, root_d, best_d) = self.v4_struct.unpack(record)
+
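+            # 'stm' is the us_black flag from the record header (0 when white
+            # is to move), so only positions where the requested colour is the
+            # side to move are yielded.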
+ if is_white and not stm:
+ yield record
+ elif not is_white and stm:
+ yield record
+ else:
+ continue
+
+ def task(self, chunkdatasrc, writer):
+ """
+        Runs in a forked worker process: reads data from chunkdatasrc, parses
+        and samples it, and sends v4 records through the pipe back to the main
+        process.
+ """
+ self.init_structs()
+ while True:
+            item = chunkdatasrc.next()
+            if item is None:
+                break
+            chunkdata, is_white = item
+ for item in self.sample_record(chunkdata, is_white):
+ # NOTE: This requires some more thinking, we can't just apply a
+ # reflection along the horizontal or vertical axes as we would
+ # also have to apply the reflection to the move probabilities
+ # which is non trivial for chess.
+ try:
+ writer.send_bytes(item)
+ except KeyboardInterrupt:
+ return
+
+
+ def v4_gen(self):
+ """
+ Read v4 records from child workers, shuffle, and yield
+ records.
+ """
+ sbuff = ShuffleBuffer(self.v4_struct.size, self.shuffle_size)
+ while len(self.readers):
+ #for r in mp.connection.wait(self.readers):
+ for r in self.readers:
+ try:
+ s = r.recv_bytes()
+ s = sbuff.insert_or_replace(s)
+ if s is None:
+ continue # shuffle buffer not yet full
+ yield s
+ except EOFError:
+ printWithDate("Reader EOF")
+ self.readers.remove(r)
+ # drain the shuffle buffer.
+ while True:
+ s = sbuff.extract()
+ if s is None:
+ return
+ yield s
+
+
+ def tuple_gen(self, gen):
+ """
+        Take a generator producing v4 records and convert them to tuples.
+        No symmetry is applied here (see the note in task(): board reflections
+        are not straightforward for chess move probabilities).
+ """
+ for r in gen:
+ yield self.convert_v4_to_tuple(r)
+
+
+ def batch_gen(self, gen):
+ """
+ Pack multiple records into a single batch
+ """
+ # Get N records. We flatten the returned generator to
+ # a list because we need to reuse it.
+ while True:
+ s = list(itertools.islice(gen, self.batch_size))
+ if not len(s):
+ return
+ yield ( b''.join([x[0] for x in s]),
+ b''.join([x[1] for x in s]),
+ b''.join([x[2] for x in s]),
+ b''.join([x[3] for x in s]))
+
+
+ def parse(self):
+ """
+ Read data from child workers and yield batches of unpacked records
+ """
+ gen = self.v4_gen() # read from workers
+ gen = self.tuple_gen(gen) # convert v4->tuple
+ gen = self.batch_gen(gen) # assemble into batches
+ for b in gen:
+ yield b
+
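+# A sketch of one way to consume ChunkParser.parse() (illustrative only; the
+# names below are not defined in this module). Each yielded batch is a 4-tuple
+# of byte strings that parse_function() decodes into tensors, so batch_size
+# must match ChunkParser.BATCH_SIZE:
+#
+#   parser = ChunkParser(src, shuffle_size=..., batch_size=ChunkParser.BATCH_SIZE)
+#   dataset = tf.data.Dataset.from_generator(
+#       parser.parse, output_types=(tf.string, tf.string, tf.string, tf.string))
+#   dataset = dataset.map(ChunkParser.parse_function)
+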
+
+
+# Tests to check that records parse correctly
+class ChunkParserTest(unittest.TestCase):
+ def setUp(self):
+ self.v4_struct = struct.Struct(V4_STRUCT_STRING)
+
+ def generate_fake_pos(self):
+ """
+ Generate a random game position.
+        Returns (planes [104 x 64 bits], integer fields [7], probs [1858],
+        winner, best_q, best_d)
+ """
+ # 0. 104 binary planes of length 64
+ planes = [np.random.randint(2, size=64).tolist() for plane in range(104)]
+
+ # 1. generate the other integer data
+ integer = np.zeros(7, dtype=np.int32)
+ for i in range(5):
+ integer[i] = np.random.randint(2)
+ integer[5] = np.random.randint(100)
+
+ # 2. 1858 probs
+ probs = np.random.randint(9, size=1858, dtype=np.int32)
+
+ # 3. And a winner: 1, 0, -1
+ winner = np.random.randint(3) - 1
+
+ # 4. evaluation after search
+ best_q = np.random.uniform(-1, 1)
+ best_d = np.random.uniform(0, 1 - np.abs(best_q))
+ return (planes, integer, probs, winner, best_q, best_d)
+
+
+ def v4_record(self, planes, i, probs, winner, best_q, best_d):
+ pl = []
+ for plane in planes:
+ pl.append(np.packbits(plane))
+ pl = np.array(pl).flatten().tobytes()
+ pi = probs.tobytes()
+ root_q, root_d = 0.0, 0.0
+ return self.v4_struct.pack(V4_VERSION, pi, pl, i[0], i[1], i[2], i[3], i[4], i[5], i[6], winner, root_q, best_q, root_d, best_d)
+
+
+ def test_structsize(self):
+ """
+ Test struct size
+ """
+ self.assertEqual(self.v4_struct.size, 8292)
+
+
+ def test_parsing(self):
+ """
+ Test game position decoding pipeline.
+ """
+ truth = self.generate_fake_pos()
+ batch_size = 4
+ records = []
+ for i in range(batch_size):
+ record = b''
+ for j in range(2):
+ record += self.v4_record(*truth)
+ records.append(record)
+
+ parser = ChunkParser(ChunkDataSrc(records), shuffle_size=1, workers=1, batch_size=batch_size)
+ batchgen = parser.parse()
+ data = next(batchgen)
+
+ batch = ( np.reshape(np.frombuffer(data[0], dtype=np.float32), (batch_size, 112, 64)),
+ np.reshape(np.frombuffer(data[1], dtype=np.int32), (batch_size, 1858)),
+ np.reshape(np.frombuffer(data[2], dtype=np.float32), (batch_size, 3)),
+ np.reshape(np.frombuffer(data[3], dtype=np.float32), (batch_size, 3)) )
+
+ fltplanes = truth[1].astype(np.float32)
+ fltplanes[5] /= 99
+ for i in range(batch_size):
+ data = (batch[0][i][:104], np.array([batch[0][i][j][0] for j in range(104,111)]), batch[1][i], batch[2][i], batch[3][i])
+ self.assertTrue((data[0] == truth[0]).all())
+ self.assertTrue((data[1] == fltplanes).all())
+ self.assertTrue((data[2] == truth[2]).all())
+ scalar_win = data[3][0] - data[3][-1]
+ self.assertTrue(np.abs(scalar_win - truth[3]) < 1e-6)
+ scalar_q = data[4][0] - data[4][-1]
+ self.assertTrue(np.abs(scalar_q - truth[4]) < 1e-6)
+
+ parser.shutdown()
+
+
+ def test_tensorflow_parsing(self):
+ """
+ Test game position decoding pipeline including tensorflow.
+ """
+ truth = self.generate_fake_pos()
+ batch_size = 4
+ ChunkParser.BATCH_SIZE = batch_size
+ records = []
+ for i in range(batch_size):
+ record = b''
+ for j in range(2):
+ record += self.v4_record(*truth)
+ records.append(record)
+
+ parser = ChunkParser(ChunkDataSrc(records), shuffle_size=1, workers=1, batch_size=batch_size)
+ batchgen = parser.parse()
+ data = next(batchgen)
+
+ planes = np.frombuffer(data[0], dtype=np.float32, count=112*8*8*batch_size)
+ planes = planes.reshape(batch_size, 112, 8*8)
+ probs = np.frombuffer(data[1], dtype=np.float32, count=1858*batch_size)
+ probs = probs.reshape(batch_size, 1858)
+ winner = np.frombuffer(data[2], dtype=np.float32, count=3*batch_size)
+ winner = winner.reshape(batch_size, 3)
+ best_q = np.frombuffer(data[3], dtype=np.float32, count=3*batch_size)
+ best_q = best_q.reshape(batch_size, 3)
+
+ # Pass it through tensorflow
+ with tf.compat.v1.Session() as sess:
+ graph = ChunkParser.parse_function(data[0], data[1], data[2], data[3])
+ tf_planes, tf_probs, tf_winner, tf_q = sess.run(graph)
+
+ for i in range(batch_size):
+ self.assertTrue((probs[i] == tf_probs[i]).all())
+ self.assertTrue((planes[i] == tf_planes[i]).all())
+ self.assertTrue((winner[i] == tf_winner[i]).all())
+ self.assertTrue((best_q[i] == tf_q[i]).all())
+
+ parser.shutdown()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/backend/tf_transfer/decode_training.py b/backend/tf_transfer/decode_training.py
new file mode 100755
index 0000000..5d1c4c9
--- /dev/null
+++ b/backend/tf_transfer/decode_training.py
@@ -0,0 +1,2130 @@
+#!/usr/bin/env python3
+#
+# This file is part of Leela Chess.
+# Copyright (C) 2018 Folkert Huizinga
+# Copyright (C) 2017-2018 Gian-Carlo Pascutto
+#
+# Leela Chess is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Leela Chess is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
+
+import array
+import binascii
+import chunkparser
+import glob
+import gzip
+import itertools
+import math
+import numpy as np
+import random
+import re
+import os
+import shutil
+import struct
+import sys
+import threading
+import time
+import unittest
+import argparse
+from collections import defaultdict
+
+# VERSION of the training data file format
+# 1 - Text, oldflip
+# 2 - Binary, oldflip
+# 3 - Binary, newflip
+# b'\1\0\0\0' - Invalid, see issue #119
+#
+# Note: VERSION1 does not include a version in the header, it starts with
+# text hex characters. This means any VERSION that is also a valid ASCII
+# hex string could potentially be a training file. Most of the time it
+# will be "00ff", but maybe games could get split and then there are more
+# "00??" possibilities.
+#
+# Also note "0002" is actually b'\0x30\0x30\0x30\0x32' (or maybe reversed?)
+# so it doesn't collide with VERSION2.
+#
+VERSION3 = chunkparser.V3_VERSION
+VERSION4 = chunkparser.V4_VERSION
+
+V3_BYTES = 8276
+V4_BYTES = 8292
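+# A v4 record is a v3 record plus 16 trailing bytes (root_q, best_q, root_d,
+# best_d as float32), which is why v3 records are padded with 16 zero bytes
+# before being unpacked with the v4 struct.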
+
+# Us -- uppercase
+# Them -- lowercase
+PIECES = "PNBRQKpnbrqk"
+
+MOVES = [
+ "a1b1",
+ "a1c1",
+ "a1d1",
+ "a1e1",
+ "a1f1",
+ "a1g1",
+ "a1h1",
+ "a1a2",
+ "a1b2",
+ "a1c2",
+ "a1a3",
+ "a1b3",
+ "a1c3",
+ "a1a4",
+ "a1d4",
+ "a1a5",
+ "a1e5",
+ "a1a6",
+ "a1f6",
+ "a1a7",
+ "a1g7",
+ "a1a8",
+ "a1h8",
+ "b1a1",
+ "b1c1",
+ "b1d1",
+ "b1e1",
+ "b1f1",
+ "b1g1",
+ "b1h1",
+ "b1a2",
+ "b1b2",
+ "b1c2",
+ "b1d2",
+ "b1a3",
+ "b1b3",
+ "b1c3",
+ "b1d3",
+ "b1b4",
+ "b1e4",
+ "b1b5",
+ "b1f5",
+ "b1b6",
+ "b1g6",
+ "b1b7",
+ "b1h7",
+ "b1b8",
+ "c1a1",
+ "c1b1",
+ "c1d1",
+ "c1e1",
+ "c1f1",
+ "c1g1",
+ "c1h1",
+ "c1a2",
+ "c1b2",
+ "c1c2",
+ "c1d2",
+ "c1e2",
+ "c1a3",
+ "c1b3",
+ "c1c3",
+ "c1d3",
+ "c1e3",
+ "c1c4",
+ "c1f4",
+ "c1c5",
+ "c1g5",
+ "c1c6",
+ "c1h6",
+ "c1c7",
+ "c1c8",
+ "d1a1",
+ "d1b1",
+ "d1c1",
+ "d1e1",
+ "d1f1",
+ "d1g1",
+ "d1h1",
+ "d1b2",
+ "d1c2",
+ "d1d2",
+ "d1e2",
+ "d1f2",
+ "d1b3",
+ "d1c3",
+ "d1d3",
+ "d1e3",
+ "d1f3",
+ "d1a4",
+ "d1d4",
+ "d1g4",
+ "d1d5",
+ "d1h5",
+ "d1d6",
+ "d1d7",
+ "d1d8",
+ "e1a1",
+ "e1b1",
+ "e1c1",
+ "e1d1",
+ "e1f1",
+ "e1g1",
+ "e1h1",
+ "e1c2",
+ "e1d2",
+ "e1e2",
+ "e1f2",
+ "e1g2",
+ "e1c3",
+ "e1d3",
+ "e1e3",
+ "e1f3",
+ "e1g3",
+ "e1b4",
+ "e1e4",
+ "e1h4",
+ "e1a5",
+ "e1e5",
+ "e1e6",
+ "e1e7",
+ "e1e8",
+ "f1a1",
+ "f1b1",
+ "f1c1",
+ "f1d1",
+ "f1e1",
+ "f1g1",
+ "f1h1",
+ "f1d2",
+ "f1e2",
+ "f1f2",
+ "f1g2",
+ "f1h2",
+ "f1d3",
+ "f1e3",
+ "f1f3",
+ "f1g3",
+ "f1h3",
+ "f1c4",
+ "f1f4",
+ "f1b5",
+ "f1f5",
+ "f1a6",
+ "f1f6",
+ "f1f7",
+ "f1f8",
+ "g1a1",
+ "g1b1",
+ "g1c1",
+ "g1d1",
+ "g1e1",
+ "g1f1",
+ "g1h1",
+ "g1e2",
+ "g1f2",
+ "g1g2",
+ "g1h2",
+ "g1e3",
+ "g1f3",
+ "g1g3",
+ "g1h3",
+ "g1d4",
+ "g1g4",
+ "g1c5",
+ "g1g5",
+ "g1b6",
+ "g1g6",
+ "g1a7",
+ "g1g7",
+ "g1g8",
+ "h1a1",
+ "h1b1",
+ "h1c1",
+ "h1d1",
+ "h1e1",
+ "h1f1",
+ "h1g1",
+ "h1f2",
+ "h1g2",
+ "h1h2",
+ "h1f3",
+ "h1g3",
+ "h1h3",
+ "h1e4",
+ "h1h4",
+ "h1d5",
+ "h1h5",
+ "h1c6",
+ "h1h6",
+ "h1b7",
+ "h1h7",
+ "h1a8",
+ "h1h8",
+ "a2a1",
+ "a2b1",
+ "a2c1",
+ "a2b2",
+ "a2c2",
+ "a2d2",
+ "a2e2",
+ "a2f2",
+ "a2g2",
+ "a2h2",
+ "a2a3",
+ "a2b3",
+ "a2c3",
+ "a2a4",
+ "a2b4",
+ "a2c4",
+ "a2a5",
+ "a2d5",
+ "a2a6",
+ "a2e6",
+ "a2a7",
+ "a2f7",
+ "a2a8",
+ "a2g8",
+ "b2a1",
+ "b2b1",
+ "b2c1",
+ "b2d1",
+ "b2a2",
+ "b2c2",
+ "b2d2",
+ "b2e2",
+ "b2f2",
+ "b2g2",
+ "b2h2",
+ "b2a3",
+ "b2b3",
+ "b2c3",
+ "b2d3",
+ "b2a4",
+ "b2b4",
+ "b2c4",
+ "b2d4",
+ "b2b5",
+ "b2e5",
+ "b2b6",
+ "b2f6",
+ "b2b7",
+ "b2g7",
+ "b2b8",
+ "b2h8",
+ "c2a1",
+ "c2b1",
+ "c2c1",
+ "c2d1",
+ "c2e1",
+ "c2a2",
+ "c2b2",
+ "c2d2",
+ "c2e2",
+ "c2f2",
+ "c2g2",
+ "c2h2",
+ "c2a3",
+ "c2b3",
+ "c2c3",
+ "c2d3",
+ "c2e3",
+ "c2a4",
+ "c2b4",
+ "c2c4",
+ "c2d4",
+ "c2e4",
+ "c2c5",
+ "c2f5",
+ "c2c6",
+ "c2g6",
+ "c2c7",
+ "c2h7",
+ "c2c8",
+ "d2b1",
+ "d2c1",
+ "d2d1",
+ "d2e1",
+ "d2f1",
+ "d2a2",
+ "d2b2",
+ "d2c2",
+ "d2e2",
+ "d2f2",
+ "d2g2",
+ "d2h2",
+ "d2b3",
+ "d2c3",
+ "d2d3",
+ "d2e3",
+ "d2f3",
+ "d2b4",
+ "d2c4",
+ "d2d4",
+ "d2e4",
+ "d2f4",
+ "d2a5",
+ "d2d5",
+ "d2g5",
+ "d2d6",
+ "d2h6",
+ "d2d7",
+ "d2d8",
+ "e2c1",
+ "e2d1",
+ "e2e1",
+ "e2f1",
+ "e2g1",
+ "e2a2",
+ "e2b2",
+ "e2c2",
+ "e2d2",
+ "e2f2",
+ "e2g2",
+ "e2h2",
+ "e2c3",
+ "e2d3",
+ "e2e3",
+ "e2f3",
+ "e2g3",
+ "e2c4",
+ "e2d4",
+ "e2e4",
+ "e2f4",
+ "e2g4",
+ "e2b5",
+ "e2e5",
+ "e2h5",
+ "e2a6",
+ "e2e6",
+ "e2e7",
+ "e2e8",
+ "f2d1",
+ "f2e1",
+ "f2f1",
+ "f2g1",
+ "f2h1",
+ "f2a2",
+ "f2b2",
+ "f2c2",
+ "f2d2",
+ "f2e2",
+ "f2g2",
+ "f2h2",
+ "f2d3",
+ "f2e3",
+ "f2f3",
+ "f2g3",
+ "f2h3",
+ "f2d4",
+ "f2e4",
+ "f2f4",
+ "f2g4",
+ "f2h4",
+ "f2c5",
+ "f2f5",
+ "f2b6",
+ "f2f6",
+ "f2a7",
+ "f2f7",
+ "f2f8",
+ "g2e1",
+ "g2f1",
+ "g2g1",
+ "g2h1",
+ "g2a2",
+ "g2b2",
+ "g2c2",
+ "g2d2",
+ "g2e2",
+ "g2f2",
+ "g2h2",
+ "g2e3",
+ "g2f3",
+ "g2g3",
+ "g2h3",
+ "g2e4",
+ "g2f4",
+ "g2g4",
+ "g2h4",
+ "g2d5",
+ "g2g5",
+ "g2c6",
+ "g2g6",
+ "g2b7",
+ "g2g7",
+ "g2a8",
+ "g2g8",
+ "h2f1",
+ "h2g1",
+ "h2h1",
+ "h2a2",
+ "h2b2",
+ "h2c2",
+ "h2d2",
+ "h2e2",
+ "h2f2",
+ "h2g2",
+ "h2f3",
+ "h2g3",
+ "h2h3",
+ "h2f4",
+ "h2g4",
+ "h2h4",
+ "h2e5",
+ "h2h5",
+ "h2d6",
+ "h2h6",
+ "h2c7",
+ "h2h7",
+ "h2b8",
+ "h2h8",
+ "a3a1",
+ "a3b1",
+ "a3c1",
+ "a3a2",
+ "a3b2",
+ "a3c2",
+ "a3b3",
+ "a3c3",
+ "a3d3",
+ "a3e3",
+ "a3f3",
+ "a3g3",
+ "a3h3",
+ "a3a4",
+ "a3b4",
+ "a3c4",
+ "a3a5",
+ "a3b5",
+ "a3c5",
+ "a3a6",
+ "a3d6",
+ "a3a7",
+ "a3e7",
+ "a3a8",
+ "a3f8",
+ "b3a1",
+ "b3b1",
+ "b3c1",
+ "b3d1",
+ "b3a2",
+ "b3b2",
+ "b3c2",
+ "b3d2",
+ "b3a3",
+ "b3c3",
+ "b3d3",
+ "b3e3",
+ "b3f3",
+ "b3g3",
+ "b3h3",
+ "b3a4",
+ "b3b4",
+ "b3c4",
+ "b3d4",
+ "b3a5",
+ "b3b5",
+ "b3c5",
+ "b3d5",
+ "b3b6",
+ "b3e6",
+ "b3b7",
+ "b3f7",
+ "b3b8",
+ "b3g8",
+ "c3a1",
+ "c3b1",
+ "c3c1",
+ "c3d1",
+ "c3e1",
+ "c3a2",
+ "c3b2",
+ "c3c2",
+ "c3d2",
+ "c3e2",
+ "c3a3",
+ "c3b3",
+ "c3d3",
+ "c3e3",
+ "c3f3",
+ "c3g3",
+ "c3h3",
+ "c3a4",
+ "c3b4",
+ "c3c4",
+ "c3d4",
+ "c3e4",
+ "c3a5",
+ "c3b5",
+ "c3c5",
+ "c3d5",
+ "c3e5",
+ "c3c6",
+ "c3f6",
+ "c3c7",
+ "c3g7",
+ "c3c8",
+ "c3h8",
+ "d3b1",
+ "d3c1",
+ "d3d1",
+ "d3e1",
+ "d3f1",
+ "d3b2",
+ "d3c2",
+ "d3d2",
+ "d3e2",
+ "d3f2",
+ "d3a3",
+ "d3b3",
+ "d3c3",
+ "d3e3",
+ "d3f3",
+ "d3g3",
+ "d3h3",
+ "d3b4",
+ "d3c4",
+ "d3d4",
+ "d3e4",
+ "d3f4",
+ "d3b5",
+ "d3c5",
+ "d3d5",
+ "d3e5",
+ "d3f5",
+ "d3a6",
+ "d3d6",
+ "d3g6",
+ "d3d7",
+ "d3h7",
+ "d3d8",
+ "e3c1",
+ "e3d1",
+ "e3e1",
+ "e3f1",
+ "e3g1",
+ "e3c2",
+ "e3d2",
+ "e3e2",
+ "e3f2",
+ "e3g2",
+ "e3a3",
+ "e3b3",
+ "e3c3",
+ "e3d3",
+ "e3f3",
+ "e3g3",
+ "e3h3",
+ "e3c4",
+ "e3d4",
+ "e3e4",
+ "e3f4",
+ "e3g4",
+ "e3c5",
+ "e3d5",
+ "e3e5",
+ "e3f5",
+ "e3g5",
+ "e3b6",
+ "e3e6",
+ "e3h6",
+ "e3a7",
+ "e3e7",
+ "e3e8",
+ "f3d1",
+ "f3e1",
+ "f3f1",
+ "f3g1",
+ "f3h1",
+ "f3d2",
+ "f3e2",
+ "f3f2",
+ "f3g2",
+ "f3h2",
+ "f3a3",
+ "f3b3",
+ "f3c3",
+ "f3d3",
+ "f3e3",
+ "f3g3",
+ "f3h3",
+ "f3d4",
+ "f3e4",
+ "f3f4",
+ "f3g4",
+ "f3h4",
+ "f3d5",
+ "f3e5",
+ "f3f5",
+ "f3g5",
+ "f3h5",
+ "f3c6",
+ "f3f6",
+ "f3b7",
+ "f3f7",
+ "f3a8",
+ "f3f8",
+ "g3e1",
+ "g3f1",
+ "g3g1",
+ "g3h1",
+ "g3e2",
+ "g3f2",
+ "g3g2",
+ "g3h2",
+ "g3a3",
+ "g3b3",
+ "g3c3",
+ "g3d3",
+ "g3e3",
+ "g3f3",
+ "g3h3",
+ "g3e4",
+ "g3f4",
+ "g3g4",
+ "g3h4",
+ "g3e5",
+ "g3f5",
+ "g3g5",
+ "g3h5",
+ "g3d6",
+ "g3g6",
+ "g3c7",
+ "g3g7",
+ "g3b8",
+ "g3g8",
+ "h3f1",
+ "h3g1",
+ "h3h1",
+ "h3f2",
+ "h3g2",
+ "h3h2",
+ "h3a3",
+ "h3b3",
+ "h3c3",
+ "h3d3",
+ "h3e3",
+ "h3f3",
+ "h3g3",
+ "h3f4",
+ "h3g4",
+ "h3h4",
+ "h3f5",
+ "h3g5",
+ "h3h5",
+ "h3e6",
+ "h3h6",
+ "h3d7",
+ "h3h7",
+ "h3c8",
+ "h3h8",
+ "a4a1",
+ "a4d1",
+ "a4a2",
+ "a4b2",
+ "a4c2",
+ "a4a3",
+ "a4b3",
+ "a4c3",
+ "a4b4",
+ "a4c4",
+ "a4d4",
+ "a4e4",
+ "a4f4",
+ "a4g4",
+ "a4h4",
+ "a4a5",
+ "a4b5",
+ "a4c5",
+ "a4a6",
+ "a4b6",
+ "a4c6",
+ "a4a7",
+ "a4d7",
+ "a4a8",
+ "a4e8",
+ "b4b1",
+ "b4e1",
+ "b4a2",
+ "b4b2",
+ "b4c2",
+ "b4d2",
+ "b4a3",
+ "b4b3",
+ "b4c3",
+ "b4d3",
+ "b4a4",
+ "b4c4",
+ "b4d4",
+ "b4e4",
+ "b4f4",
+ "b4g4",
+ "b4h4",
+ "b4a5",
+ "b4b5",
+ "b4c5",
+ "b4d5",
+ "b4a6",
+ "b4b6",
+ "b4c6",
+ "b4d6",
+ "b4b7",
+ "b4e7",
+ "b4b8",
+ "b4f8",
+ "c4c1",
+ "c4f1",
+ "c4a2",
+ "c4b2",
+ "c4c2",
+ "c4d2",
+ "c4e2",
+ "c4a3",
+ "c4b3",
+ "c4c3",
+ "c4d3",
+ "c4e3",
+ "c4a4",
+ "c4b4",
+ "c4d4",
+ "c4e4",
+ "c4f4",
+ "c4g4",
+ "c4h4",
+ "c4a5",
+ "c4b5",
+ "c4c5",
+ "c4d5",
+ "c4e5",
+ "c4a6",
+ "c4b6",
+ "c4c6",
+ "c4d6",
+ "c4e6",
+ "c4c7",
+ "c4f7",
+ "c4c8",
+ "c4g8",
+ "d4a1",
+ "d4d1",
+ "d4g1",
+ "d4b2",
+ "d4c2",
+ "d4d2",
+ "d4e2",
+ "d4f2",
+ "d4b3",
+ "d4c3",
+ "d4d3",
+ "d4e3",
+ "d4f3",
+ "d4a4",
+ "d4b4",
+ "d4c4",
+ "d4e4",
+ "d4f4",
+ "d4g4",
+ "d4h4",
+ "d4b5",
+ "d4c5",
+ "d4d5",
+ "d4e5",
+ "d4f5",
+ "d4b6",
+ "d4c6",
+ "d4d6",
+ "d4e6",
+ "d4f6",
+ "d4a7",
+ "d4d7",
+ "d4g7",
+ "d4d8",
+ "d4h8",
+ "e4b1",
+ "e4e1",
+ "e4h1",
+ "e4c2",
+ "e4d2",
+ "e4e2",
+ "e4f2",
+ "e4g2",
+ "e4c3",
+ "e4d3",
+ "e4e3",
+ "e4f3",
+ "e4g3",
+ "e4a4",
+ "e4b4",
+ "e4c4",
+ "e4d4",
+ "e4f4",
+ "e4g4",
+ "e4h4",
+ "e4c5",
+ "e4d5",
+ "e4e5",
+ "e4f5",
+ "e4g5",
+ "e4c6",
+ "e4d6",
+ "e4e6",
+ "e4f6",
+ "e4g6",
+ "e4b7",
+ "e4e7",
+ "e4h7",
+ "e4a8",
+ "e4e8",
+ "f4c1",
+ "f4f1",
+ "f4d2",
+ "f4e2",
+ "f4f2",
+ "f4g2",
+ "f4h2",
+ "f4d3",
+ "f4e3",
+ "f4f3",
+ "f4g3",
+ "f4h3",
+ "f4a4",
+ "f4b4",
+ "f4c4",
+ "f4d4",
+ "f4e4",
+ "f4g4",
+ "f4h4",
+ "f4d5",
+ "f4e5",
+ "f4f5",
+ "f4g5",
+ "f4h5",
+ "f4d6",
+ "f4e6",
+ "f4f6",
+ "f4g6",
+ "f4h6",
+ "f4c7",
+ "f4f7",
+ "f4b8",
+ "f4f8",
+ "g4d1",
+ "g4g1",
+ "g4e2",
+ "g4f2",
+ "g4g2",
+ "g4h2",
+ "g4e3",
+ "g4f3",
+ "g4g3",
+ "g4h3",
+ "g4a4",
+ "g4b4",
+ "g4c4",
+ "g4d4",
+ "g4e4",
+ "g4f4",
+ "g4h4",
+ "g4e5",
+ "g4f5",
+ "g4g5",
+ "g4h5",
+ "g4e6",
+ "g4f6",
+ "g4g6",
+ "g4h6",
+ "g4d7",
+ "g4g7",
+ "g4c8",
+ "g4g8",
+ "h4e1",
+ "h4h1",
+ "h4f2",
+ "h4g2",
+ "h4h2",
+ "h4f3",
+ "h4g3",
+ "h4h3",
+ "h4a4",
+ "h4b4",
+ "h4c4",
+ "h4d4",
+ "h4e4",
+ "h4f4",
+ "h4g4",
+ "h4f5",
+ "h4g5",
+ "h4h5",
+ "h4f6",
+ "h4g6",
+ "h4h6",
+ "h4e7",
+ "h4h7",
+ "h4d8",
+ "h4h8",
+ "a5a1",
+ "a5e1",
+ "a5a2",
+ "a5d2",
+ "a5a3",
+ "a5b3",
+ "a5c3",
+ "a5a4",
+ "a5b4",
+ "a5c4",
+ "a5b5",
+ "a5c5",
+ "a5d5",
+ "a5e5",
+ "a5f5",
+ "a5g5",
+ "a5h5",
+ "a5a6",
+ "a5b6",
+ "a5c6",
+ "a5a7",
+ "a5b7",
+ "a5c7",
+ "a5a8",
+ "a5d8",
+ "b5b1",
+ "b5f1",
+ "b5b2",
+ "b5e2",
+ "b5a3",
+ "b5b3",
+ "b5c3",
+ "b5d3",
+ "b5a4",
+ "b5b4",
+ "b5c4",
+ "b5d4",
+ "b5a5",
+ "b5c5",
+ "b5d5",
+ "b5e5",
+ "b5f5",
+ "b5g5",
+ "b5h5",
+ "b5a6",
+ "b5b6",
+ "b5c6",
+ "b5d6",
+ "b5a7",
+ "b5b7",
+ "b5c7",
+ "b5d7",
+ "b5b8",
+ "b5e8",
+ "c5c1",
+ "c5g1",
+ "c5c2",
+ "c5f2",
+ "c5a3",
+ "c5b3",
+ "c5c3",
+ "c5d3",
+ "c5e3",
+ "c5a4",
+ "c5b4",
+ "c5c4",
+ "c5d4",
+ "c5e4",
+ "c5a5",
+ "c5b5",
+ "c5d5",
+ "c5e5",
+ "c5f5",
+ "c5g5",
+ "c5h5",
+ "c5a6",
+ "c5b6",
+ "c5c6",
+ "c5d6",
+ "c5e6",
+ "c5a7",
+ "c5b7",
+ "c5c7",
+ "c5d7",
+ "c5e7",
+ "c5c8",
+ "c5f8",
+ "d5d1",
+ "d5h1",
+ "d5a2",
+ "d5d2",
+ "d5g2",
+ "d5b3",
+ "d5c3",
+ "d5d3",
+ "d5e3",
+ "d5f3",
+ "d5b4",
+ "d5c4",
+ "d5d4",
+ "d5e4",
+ "d5f4",
+ "d5a5",
+ "d5b5",
+ "d5c5",
+ "d5e5",
+ "d5f5",
+ "d5g5",
+ "d5h5",
+ "d5b6",
+ "d5c6",
+ "d5d6",
+ "d5e6",
+ "d5f6",
+ "d5b7",
+ "d5c7",
+ "d5d7",
+ "d5e7",
+ "d5f7",
+ "d5a8",
+ "d5d8",
+ "d5g8",
+ "e5a1",
+ "e5e1",
+ "e5b2",
+ "e5e2",
+ "e5h2",
+ "e5c3",
+ "e5d3",
+ "e5e3",
+ "e5f3",
+ "e5g3",
+ "e5c4",
+ "e5d4",
+ "e5e4",
+ "e5f4",
+ "e5g4",
+ "e5a5",
+ "e5b5",
+ "e5c5",
+ "e5d5",
+ "e5f5",
+ "e5g5",
+ "e5h5",
+ "e5c6",
+ "e5d6",
+ "e5e6",
+ "e5f6",
+ "e5g6",
+ "e5c7",
+ "e5d7",
+ "e5e7",
+ "e5f7",
+ "e5g7",
+ "e5b8",
+ "e5e8",
+ "e5h8",
+ "f5b1",
+ "f5f1",
+ "f5c2",
+ "f5f2",
+ "f5d3",
+ "f5e3",
+ "f5f3",
+ "f5g3",
+ "f5h3",
+ "f5d4",
+ "f5e4",
+ "f5f4",
+ "f5g4",
+ "f5h4",
+ "f5a5",
+ "f5b5",
+ "f5c5",
+ "f5d5",
+ "f5e5",
+ "f5g5",
+ "f5h5",
+ "f5d6",
+ "f5e6",
+ "f5f6",
+ "f5g6",
+ "f5h6",
+ "f5d7",
+ "f5e7",
+ "f5f7",
+ "f5g7",
+ "f5h7",
+ "f5c8",
+ "f5f8",
+ "g5c1",
+ "g5g1",
+ "g5d2",
+ "g5g2",
+ "g5e3",
+ "g5f3",
+ "g5g3",
+ "g5h3",
+ "g5e4",
+ "g5f4",
+ "g5g4",
+ "g5h4",
+ "g5a5",
+ "g5b5",
+ "g5c5",
+ "g5d5",
+ "g5e5",
+ "g5f5",
+ "g5h5",
+ "g5e6",
+ "g5f6",
+ "g5g6",
+ "g5h6",
+ "g5e7",
+ "g5f7",
+ "g5g7",
+ "g5h7",
+ "g5d8",
+ "g5g8",
+ "h5d1",
+ "h5h1",
+ "h5e2",
+ "h5h2",
+ "h5f3",
+ "h5g3",
+ "h5h3",
+ "h5f4",
+ "h5g4",
+ "h5h4",
+ "h5a5",
+ "h5b5",
+ "h5c5",
+ "h5d5",
+ "h5e5",
+ "h5f5",
+ "h5g5",
+ "h5f6",
+ "h5g6",
+ "h5h6",
+ "h5f7",
+ "h5g7",
+ "h5h7",
+ "h5e8",
+ "h5h8",
+ "a6a1",
+ "a6f1",
+ "a6a2",
+ "a6e2",
+ "a6a3",
+ "a6d3",
+ "a6a4",
+ "a6b4",
+ "a6c4",
+ "a6a5",
+ "a6b5",
+ "a6c5",
+ "a6b6",
+ "a6c6",
+ "a6d6",
+ "a6e6",
+ "a6f6",
+ "a6g6",
+ "a6h6",
+ "a6a7",
+ "a6b7",
+ "a6c7",
+ "a6a8",
+ "a6b8",
+ "a6c8",
+ "b6b1",
+ "b6g1",
+ "b6b2",
+ "b6f2",
+ "b6b3",
+ "b6e3",
+ "b6a4",
+ "b6b4",
+ "b6c4",
+ "b6d4",
+ "b6a5",
+ "b6b5",
+ "b6c5",
+ "b6d5",
+ "b6a6",
+ "b6c6",
+ "b6d6",
+ "b6e6",
+ "b6f6",
+ "b6g6",
+ "b6h6",
+ "b6a7",
+ "b6b7",
+ "b6c7",
+ "b6d7",
+ "b6a8",
+ "b6b8",
+ "b6c8",
+ "b6d8",
+ "c6c1",
+ "c6h1",
+ "c6c2",
+ "c6g2",
+ "c6c3",
+ "c6f3",
+ "c6a4",
+ "c6b4",
+ "c6c4",
+ "c6d4",
+ "c6e4",
+ "c6a5",
+ "c6b5",
+ "c6c5",
+ "c6d5",
+ "c6e5",
+ "c6a6",
+ "c6b6",
+ "c6d6",
+ "c6e6",
+ "c6f6",
+ "c6g6",
+ "c6h6",
+ "c6a7",
+ "c6b7",
+ "c6c7",
+ "c6d7",
+ "c6e7",
+ "c6a8",
+ "c6b8",
+ "c6c8",
+ "c6d8",
+ "c6e8",
+ "d6d1",
+ "d6d2",
+ "d6h2",
+ "d6a3",
+ "d6d3",
+ "d6g3",
+ "d6b4",
+ "d6c4",
+ "d6d4",
+ "d6e4",
+ "d6f4",
+ "d6b5",
+ "d6c5",
+ "d6d5",
+ "d6e5",
+ "d6f5",
+ "d6a6",
+ "d6b6",
+ "d6c6",
+ "d6e6",
+ "d6f6",
+ "d6g6",
+ "d6h6",
+ "d6b7",
+ "d6c7",
+ "d6d7",
+ "d6e7",
+ "d6f7",
+ "d6b8",
+ "d6c8",
+ "d6d8",
+ "d6e8",
+ "d6f8",
+ "e6e1",
+ "e6a2",
+ "e6e2",
+ "e6b3",
+ "e6e3",
+ "e6h3",
+ "e6c4",
+ "e6d4",
+ "e6e4",
+ "e6f4",
+ "e6g4",
+ "e6c5",
+ "e6d5",
+ "e6e5",
+ "e6f5",
+ "e6g5",
+ "e6a6",
+ "e6b6",
+ "e6c6",
+ "e6d6",
+ "e6f6",
+ "e6g6",
+ "e6h6",
+ "e6c7",
+ "e6d7",
+ "e6e7",
+ "e6f7",
+ "e6g7",
+ "e6c8",
+ "e6d8",
+ "e6e8",
+ "e6f8",
+ "e6g8",
+ "f6a1",
+ "f6f1",
+ "f6b2",
+ "f6f2",
+ "f6c3",
+ "f6f3",
+ "f6d4",
+ "f6e4",
+ "f6f4",
+ "f6g4",
+ "f6h4",
+ "f6d5",
+ "f6e5",
+ "f6f5",
+ "f6g5",
+ "f6h5",
+ "f6a6",
+ "f6b6",
+ "f6c6",
+ "f6d6",
+ "f6e6",
+ "f6g6",
+ "f6h6",
+ "f6d7",
+ "f6e7",
+ "f6f7",
+ "f6g7",
+ "f6h7",
+ "f6d8",
+ "f6e8",
+ "f6f8",
+ "f6g8",
+ "f6h8",
+ "g6b1",
+ "g6g1",
+ "g6c2",
+ "g6g2",
+ "g6d3",
+ "g6g3",
+ "g6e4",
+ "g6f4",
+ "g6g4",
+ "g6h4",
+ "g6e5",
+ "g6f5",
+ "g6g5",
+ "g6h5",
+ "g6a6",
+ "g6b6",
+ "g6c6",
+ "g6d6",
+ "g6e6",
+ "g6f6",
+ "g6h6",
+ "g6e7",
+ "g6f7",
+ "g6g7",
+ "g6h7",
+ "g6e8",
+ "g6f8",
+ "g6g8",
+ "g6h8",
+ "h6c1",
+ "h6h1",
+ "h6d2",
+ "h6h2",
+ "h6e3",
+ "h6h3",
+ "h6f4",
+ "h6g4",
+ "h6h4",
+ "h6f5",
+ "h6g5",
+ "h6h5",
+ "h6a6",
+ "h6b6",
+ "h6c6",
+ "h6d6",
+ "h6e6",
+ "h6f6",
+ "h6g6",
+ "h6f7",
+ "h6g7",
+ "h6h7",
+ "h6f8",
+ "h6g8",
+ "h6h8",
+ "a7a1",
+ "a7g1",
+ "a7a2",
+ "a7f2",
+ "a7a3",
+ "a7e3",
+ "a7a4",
+ "a7d4",
+ "a7a5",
+ "a7b5",
+ "a7c5",
+ "a7a6",
+ "a7b6",
+ "a7c6",
+ "a7b7",
+ "a7c7",
+ "a7d7",
+ "a7e7",
+ "a7f7",
+ "a7g7",
+ "a7h7",
+ "a7a8",
+ "a7b8",
+ "a7c8",
+ "b7b1",
+ "b7h1",
+ "b7b2",
+ "b7g2",
+ "b7b3",
+ "b7f3",
+ "b7b4",
+ "b7e4",
+ "b7a5",
+ "b7b5",
+ "b7c5",
+ "b7d5",
+ "b7a6",
+ "b7b6",
+ "b7c6",
+ "b7d6",
+ "b7a7",
+ "b7c7",
+ "b7d7",
+ "b7e7",
+ "b7f7",
+ "b7g7",
+ "b7h7",
+ "b7a8",
+ "b7b8",
+ "b7c8",
+ "b7d8",
+ "c7c1",
+ "c7c2",
+ "c7h2",
+ "c7c3",
+ "c7g3",
+ "c7c4",
+ "c7f4",
+ "c7a5",
+ "c7b5",
+ "c7c5",
+ "c7d5",
+ "c7e5",
+ "c7a6",
+ "c7b6",
+ "c7c6",
+ "c7d6",
+ "c7e6",
+ "c7a7",
+ "c7b7",
+ "c7d7",
+ "c7e7",
+ "c7f7",
+ "c7g7",
+ "c7h7",
+ "c7a8",
+ "c7b8",
+ "c7c8",
+ "c7d8",
+ "c7e8",
+ "d7d1",
+ "d7d2",
+ "d7d3",
+ "d7h3",
+ "d7a4",
+ "d7d4",
+ "d7g4",
+ "d7b5",
+ "d7c5",
+ "d7d5",
+ "d7e5",
+ "d7f5",
+ "d7b6",
+ "d7c6",
+ "d7d6",
+ "d7e6",
+ "d7f6",
+ "d7a7",
+ "d7b7",
+ "d7c7",
+ "d7e7",
+ "d7f7",
+ "d7g7",
+ "d7h7",
+ "d7b8",
+ "d7c8",
+ "d7d8",
+ "d7e8",
+ "d7f8",
+ "e7e1",
+ "e7e2",
+ "e7a3",
+ "e7e3",
+ "e7b4",
+ "e7e4",
+ "e7h4",
+ "e7c5",
+ "e7d5",
+ "e7e5",
+ "e7f5",
+ "e7g5",
+ "e7c6",
+ "e7d6",
+ "e7e6",
+ "e7f6",
+ "e7g6",
+ "e7a7",
+ "e7b7",
+ "e7c7",
+ "e7d7",
+ "e7f7",
+ "e7g7",
+ "e7h7",
+ "e7c8",
+ "e7d8",
+ "e7e8",
+ "e7f8",
+ "e7g8",
+ "f7f1",
+ "f7a2",
+ "f7f2",
+ "f7b3",
+ "f7f3",
+ "f7c4",
+ "f7f4",
+ "f7d5",
+ "f7e5",
+ "f7f5",
+ "f7g5",
+ "f7h5",
+ "f7d6",
+ "f7e6",
+ "f7f6",
+ "f7g6",
+ "f7h6",
+ "f7a7",
+ "f7b7",
+ "f7c7",
+ "f7d7",
+ "f7e7",
+ "f7g7",
+ "f7h7",
+ "f7d8",
+ "f7e8",
+ "f7f8",
+ "f7g8",
+ "f7h8",
+ "g7a1",
+ "g7g1",
+ "g7b2",
+ "g7g2",
+ "g7c3",
+ "g7g3",
+ "g7d4",
+ "g7g4",
+ "g7e5",
+ "g7f5",
+ "g7g5",
+ "g7h5",
+ "g7e6",
+ "g7f6",
+ "g7g6",
+ "g7h6",
+ "g7a7",
+ "g7b7",
+ "g7c7",
+ "g7d7",
+ "g7e7",
+ "g7f7",
+ "g7h7",
+ "g7e8",
+ "g7f8",
+ "g7g8",
+ "g7h8",
+ "h7b1",
+ "h7h1",
+ "h7c2",
+ "h7h2",
+ "h7d3",
+ "h7h3",
+ "h7e4",
+ "h7h4",
+ "h7f5",
+ "h7g5",
+ "h7h5",
+ "h7f6",
+ "h7g6",
+ "h7h6",
+ "h7a7",
+ "h7b7",
+ "h7c7",
+ "h7d7",
+ "h7e7",
+ "h7f7",
+ "h7g7",
+ "h7f8",
+ "h7g8",
+ "h7h8",
+ "a8a1",
+ "a8h1",
+ "a8a2",
+ "a8g2",
+ "a8a3",
+ "a8f3",
+ "a8a4",
+ "a8e4",
+ "a8a5",
+ "a8d5",
+ "a8a6",
+ "a8b6",
+ "a8c6",
+ "a8a7",
+ "a8b7",
+ "a8c7",
+ "a8b8",
+ "a8c8",
+ "a8d8",
+ "a8e8",
+ "a8f8",
+ "a8g8",
+ "a8h8",
+ "b8b1",
+ "b8b2",
+ "b8h2",
+ "b8b3",
+ "b8g3",
+ "b8b4",
+ "b8f4",
+ "b8b5",
+ "b8e5",
+ "b8a6",
+ "b8b6",
+ "b8c6",
+ "b8d6",
+ "b8a7",
+ "b8b7",
+ "b8c7",
+ "b8d7",
+ "b8a8",
+ "b8c8",
+ "b8d8",
+ "b8e8",
+ "b8f8",
+ "b8g8",
+ "b8h8",
+ "c8c1",
+ "c8c2",
+ "c8c3",
+ "c8h3",
+ "c8c4",
+ "c8g4",
+ "c8c5",
+ "c8f5",
+ "c8a6",
+ "c8b6",
+ "c8c6",
+ "c8d6",
+ "c8e6",
+ "c8a7",
+ "c8b7",
+ "c8c7",
+ "c8d7",
+ "c8e7",
+ "c8a8",
+ "c8b8",
+ "c8d8",
+ "c8e8",
+ "c8f8",
+ "c8g8",
+ "c8h8",
+ "d8d1",
+ "d8d2",
+ "d8d3",
+ "d8d4",
+ "d8h4",
+ "d8a5",
+ "d8d5",
+ "d8g5",
+ "d8b6",
+ "d8c6",
+ "d8d6",
+ "d8e6",
+ "d8f6",
+ "d8b7",
+ "d8c7",
+ "d8d7",
+ "d8e7",
+ "d8f7",
+ "d8a8",
+ "d8b8",
+ "d8c8",
+ "d8e8",
+ "d8f8",
+ "d8g8",
+ "d8h8",
+ "e8e1",
+ "e8e2",
+ "e8e3",
+ "e8a4",
+ "e8e4",
+ "e8b5",
+ "e8e5",
+ "e8h5",
+ "e8c6",
+ "e8d6",
+ "e8e6",
+ "e8f6",
+ "e8g6",
+ "e8c7",
+ "e8d7",
+ "e8e7",
+ "e8f7",
+ "e8g7",
+ "e8a8",
+ "e8b8",
+ "e8c8",
+ "e8d8",
+ "e8f8",
+ "e8g8",
+ "e8h8",
+ "f8f1",
+ "f8f2",
+ "f8a3",
+ "f8f3",
+ "f8b4",
+ "f8f4",
+ "f8c5",
+ "f8f5",
+ "f8d6",
+ "f8e6",
+ "f8f6",
+ "f8g6",
+ "f8h6",
+ "f8d7",
+ "f8e7",
+ "f8f7",
+ "f8g7",
+ "f8h7",
+ "f8a8",
+ "f8b8",
+ "f8c8",
+ "f8d8",
+ "f8e8",
+ "f8g8",
+ "f8h8",
+ "g8g1",
+ "g8a2",
+ "g8g2",
+ "g8b3",
+ "g8g3",
+ "g8c4",
+ "g8g4",
+ "g8d5",
+ "g8g5",
+ "g8e6",
+ "g8f6",
+ "g8g6",
+ "g8h6",
+ "g8e7",
+ "g8f7",
+ "g8g7",
+ "g8h7",
+ "g8a8",
+ "g8b8",
+ "g8c8",
+ "g8d8",
+ "g8e8",
+ "g8f8",
+ "g8h8",
+ "h8a1",
+ "h8h1",
+ "h8b2",
+ "h8h2",
+ "h8c3",
+ "h8h3",
+ "h8d4",
+ "h8h4",
+ "h8e5",
+ "h8h5",
+ "h8f6",
+ "h8g6",
+ "h8h6",
+ "h8f7",
+ "h8g7",
+ "h8h7",
+ "h8a8",
+ "h8b8",
+ "h8c8",
+ "h8d8",
+ "h8e8",
+ "h8f8",
+ "h8g8",
+ "a7a8q",
+ "a7a8r",
+ "a7a8b",
+ "a7b8q",
+ "a7b8r",
+ "a7b8b",
+ "b7a8q",
+ "b7a8r",
+ "b7a8b",
+ "b7b8q",
+ "b7b8r",
+ "b7b8b",
+ "b7c8q",
+ "b7c8r",
+ "b7c8b",
+ "c7b8q",
+ "c7b8r",
+ "c7b8b",
+ "c7c8q",
+ "c7c8r",
+ "c7c8b",
+ "c7d8q",
+ "c7d8r",
+ "c7d8b",
+ "d7c8q",
+ "d7c8r",
+ "d7c8b",
+ "d7d8q",
+ "d7d8r",
+ "d7d8b",
+ "d7e8q",
+ "d7e8r",
+ "d7e8b",
+ "e7d8q",
+ "e7d8r",
+ "e7d8b",
+ "e7e8q",
+ "e7e8r",
+ "e7e8b",
+ "e7f8q",
+ "e7f8r",
+ "e7f8b",
+ "f7e8q",
+ "f7e8r",
+ "f7e8b",
+ "f7f8q",
+ "f7f8r",
+ "f7f8b",
+ "f7g8q",
+ "f7g8r",
+ "f7g8b",
+ "g7f8q",
+ "g7f8r",
+ "g7f8b",
+ "g7g8q",
+ "g7g8r",
+ "g7g8b",
+ "g7h8q",
+ "g7h8r",
+ "g7h8b",
+ "h7g8q",
+ "h7g8r",
+ "h7g8b",
+ "h7h8q",
+ "h7h8r",
+ "h7h8b"
+]
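+
+# MOVES enumerates the flat policy ordering used by the training data (the
+# same ordering as policy_index in policy_index.py); knight under-promotions
+# are encoded as ordinary pawn moves, so only q/r/b promotions appear.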
+
+class Board:
+ def __init__(self):
+ self.clear_board()
+
+ def clear_board(self):
+ self.board = []
+ for rank in range(8):
+ self.board.append(list("."*8))
+ self.reps = 0
+
+ def describe(self):
+ s = []
+ for rank in range(8):
+ s.append("".join(self.board[rank]))
+ s.append("reps {} ".format(self.reps))
+ return s
+
+class TrainingStep:
+ def __init__(self, version):
+ self.version = version
+        # Construct a fake parser just to get access to its variables
+ self.parser = chunkparser.ChunkParser(chunkparser.ChunkDataSrc([]), workers=1)
+ self.NUM_HIST = 8
+ self.NUM_PIECE_TYPES = 6
+ self.V3_NUM_PLANES = self.NUM_PIECE_TYPES*2+1 # = 13 (6*2 us/them pieces, rep1 (no rep2))
+ self.NUM_PLANES = self.V3_NUM_PLANES
+ self.NUM_REALS = 7 # 4 castling, 1 color, 1 50rule, 1 movecount
+ self.NUM_OUTPUTS = 2 # policy, value
+        self.NUM_PLANES_BYTES = self.NUM_PLANES*4
+
+ self.V3_NUM_POLICY_MOVES = 1858 # (7432 bytes)
+ self.NUM_POLICY_MOVES = self.V3_NUM_POLICY_MOVES
+
+ self.init_structs()
+ self.init_move_map()
+ self.history = []
+ self.probs = []
+ for history in range(self.NUM_HIST):
+ self.history.append(Board())
+ self.us_ooo = 0
+ self.us_oo = 0
+ self.them_ooo = 0
+ self.them_oo = 0
+ self.us_black = 0
+ self.rule50_count = 0
+ self.winner = None
+ self.q = None
+
+ def init_structs(self):
+ self.v4_struct = self.parser.v4_struct
+ self.this_struct = self.v4_struct
+
+ def init_move_map(self):
+ self.new_white_move_map = defaultdict(lambda:-1)
+ self.new_black_move_map = defaultdict(lambda:-1)
+ self.old_rev_move_map = {}
+ self.new_rev_white_move_map = {}
+ self.new_rev_black_move_map = {}
+
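+        # The black map mirrors the rank digits (1<->8, 2<->7, ...), i.e. a
+        # vertical board flip, so a move recorded from black's point of view
+        # maps to the same policy index as the mirrored white move.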
+ for idx, m in enumerate(MOVES):
+ self.new_white_move_map[m] = idx
+ self.new_rev_white_move_map[idx] = m
+ m_black = m.translate(str.maketrans("12345678", "87654321"))
+ self.new_black_move_map[m_black] = idx
+ self.new_rev_black_move_map[idx] = m_black
+
+ def clear_hist(self):
+ for hist in range(self.NUM_HIST):
+            self.history[hist].clear_board()
+
+ def update_board(self, hist, piece, bit_board):
+ """
+ Update the ASCII board representation
+ """
+ for r in range(8):
+ for f in range(8):
+ # Note: Using 8-1-f because both the text and binary have the
+            # column bits reversed from what this code expects
+ if bit_board & (1<<(r*8+(8-1-f))):
+ assert(self.history[hist].board[r][f] == ".")
+ self.history[hist].board[r][f] = piece
+
+ def describe(self):
+ s = ""
+ if self.us_black:
+ s += "us = Black"
+ else:
+ s += "us = White"
+ if self.winner == 1:
+ s += " won\n"
+ elif self.winner == -1:
+ s += " lost\n"
+ elif self.winner == 0:
+ s += " draw\n"
+ else:
+ raise Exception("Invalid winner: {}".format(self.winner))
+ s += "Root Q = {} (diff to result: {}) \n".format(self.root_q, abs(self.winner - self.root_q))
+ s += "Best Q = {} (diff to result: {}) \n".format(self.best_q, abs(self.winner - self.best_q))
+ if self.us_black:
+ s += "(Note the black pieces are CAPS, black moves up, but A1 is in lower left)\n"
+ s += "rule50_count {} b_ooo b_oo, w_ooo, w_oo {} {} {} {}\n".format(
+ self.rule50_count, self.us_ooo, self.us_oo, self.them_ooo, self.them_oo)
+ s += " abcdefgh\n"
+ rank_strings = [[]]
+ for rank in reversed(range(8)):
+ rank_strings[0].append("{}".format(rank+1))
+ rank_strings[0].append(" ")
+ for hist in range(self.NUM_HIST):
+ rank_strings.append(self.history[hist].describe())
+ for hist in range(self.NUM_HIST+1):
+ for rank in range(8+1):
+ #if hist == 8 and rank == 0: continue
+ s += rank_strings[rank][hist] + " "
+ s += "\n"
+ sum = 0.0
+ top_moves = {}
+ for idx, prob in enumerate(self.probs):
+ # Include all moves with at least 1 visit.
+ condition = prob > 0 if self.version == 3 else prob >= 0
+ if condition:
+ top_moves[idx] = prob
+ sum += prob
+ for idx, prob in sorted(top_moves.items(), key=lambda x:-x[1]):
+ s += "{} {:4.1f}%\n".format(self.new_rev_white_move_map[idx], prob*100)
+ #print("debug prob sum", sum, "cnt", len(self.probs))
+ return s
+
+ def update_reals(self, text_item):
+ self.us_ooo = int(text_item[self.NUM_HIST*self.NUM_PLANES+0])
+ self.us_oo = int(text_item[self.NUM_HIST*self.NUM_PLANES+1])
+ self.them_ooo = int(text_item[self.NUM_HIST*self.NUM_PLANES+2])
+ self.them_oo = int(text_item[self.NUM_HIST*self.NUM_PLANES+3])
+ self.us_black = int(text_item[self.NUM_HIST*self.NUM_PLANES+4])
+ self.rule50_count = min(int(text_item[self.NUM_HIST*self.NUM_PLANES+5]), 255)
+ # should be around 99-102ish
+ assert self.rule50_count < 105
+
+ def flip_single_v1_plane(self, plane):
+ # Split hexstring into bytes (2 ascii chars), reverse, rejoin
+ # This causes a vertical flip
+ return "".join([plane[x:x+2] for x in reversed(range(0, len(plane), 2))])
+
+ def display_v4(self, ply, content):
+ (ver, probs, planes, us_ooo, us_oo, them_ooo, them_oo, us_black, rule50_count, move_count, winner, root_q, best_q, root_d, best_d) = self.this_struct.unpack(content)
+ assert self.version == int.from_bytes(ver, byteorder="little")
+ # Enforce move_count to 0
+ move_count = 0
+ # Unpack planes.
+ for hist in range(self.NUM_HIST):
+ for idx, piece in enumerate(PIECES):
+ start = hist*self.NUM_PLANES*8+idx*8
+ end = start + 8
+ self.update_board(hist, piece, int.from_bytes(planes[start:end], byteorder="big"))
+ if planes[hist*self.NUM_PLANES*8+12*8:hist*self.NUM_PLANES*8+12*8+8] != struct.pack('II', 0, 0):
+ self.history[hist].reps = 1
+ assert planes[hist*self.NUM_PLANES*8+12*8:hist*self.NUM_PLANES*8+12*8+8] == struct.pack('II', 0xffffffff, 0xffffffff)
+ self.us_ooo = us_ooo
+ self.us_oo = us_oo
+ self.them_ooo = them_ooo
+ self.them_oo = them_oo
+ self.us_black = us_black
+ self.rule50_count = rule50_count
+ self.winner = winner
+ self.root_q = root_q
+ self.best_q = best_q
+ for idx in range(0, len(probs), 4):
+ self.probs.append(struct.unpack("f", probs[idx:idx+4])[0])
+ print("ply {} move {} (Not actually part of training data)".format(
+ ply+1, (ply+2)//2))
+ print(self.describe())
+
+def main(args):
+ for filename in args.files:
+ #print("Parsing {}".format(filename))
+ with gzip.open(filename, 'rb') as f:
+ chunkdata = f.read()
+ version = chunkdata[0:4]
+ if version in {VERSION4, VERSION3}:
+ if version == VERSION3:
+ record_size = V3_BYTES
+ else:
+ record_size = V4_BYTES
+ for i in range(0, len(chunkdata), record_size):
+ ts = TrainingStep(4 if version == VERSION4 else 3)
+ record = chunkdata[i:i+record_size]
+ if chunkdata[0:4] == VERSION3:
+ record += 16 * b'\x00'
+ ts.display_v4(i//record_size, record)
+ else:
+ print("Invalid version")
+
+if __name__ == '__main__':
+ usage_str = """
+Parse training files and display them."""
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ description=usage_str)
+ parser.add_argument("files", type=str, nargs="+",
+ help="training*.gz")
+ args = parser.parse_args()
+
+ main(args)
diff --git a/backend/tf_transfer/lc0_az_policy_map.py b/backend/tf_transfer/lc0_az_policy_map.py
new file mode 100755
index 0000000..61e9281
--- /dev/null
+++ b/backend/tf_transfer/lc0_az_policy_map.py
@@ -0,0 +1,156 @@
+#!/usr/bin/env python3
+import sys
+import numpy as np
+from .policy_index import policy_index
+
+columns = 'abcdefgh'
+rows = '12345678'
+promotions = 'rbq' # N is encoded as normal move
+
+col_index = {columns[i] : i for i in range(len(columns))}
+row_index = {rows[i] : i for i in range(len(rows))}
+
+def index_to_position(x):
+ return columns[x[0]] + rows[x[1]]
+
+def position_to_index(p):
+ return col_index[p[0]], row_index[p[1]]
+
+def valid_index(i):
+ if i[0] > 7 or i[0] < 0:
+ return False
+ if i[1] > 7 or i[1] < 0:
+ return False
+ return True
+
+def queen_move(start, direction, steps):
+ i = position_to_index(start)
+ dir_vectors = {'N': (0, 1), 'NE': (1, 1), 'E': (1, 0), 'SE': (1, -1),
+ 'S':(0, -1), 'SW':(-1, -1), 'W': (-1, 0), 'NW': (-1, 1)}
+ v = dir_vectors[direction]
+ i = i[0] + v[0] * steps, i[1] + v[1] * steps
+ if not valid_index(i):
+ return None
+ return index_to_position(i)
+
+def knight_move(start, direction, steps):
+ i = position_to_index(start)
+ dir_vectors = {'N': (1, 2), 'NE': (2, 1), 'E': (2, -1), 'SE': (1, -2),
+ 'S':(-1, -2), 'SW':(-2, -1), 'W': (-2, 1), 'NW': (-1, 2)}
+ v = dir_vectors[direction]
+ i = i[0] + v[0] * steps, i[1] + v[1] * steps
+ if not valid_index(i):
+ return None
+ return index_to_position(i)
+
+def make_map(kind='matrix'):
+ # 56 planes of queen moves
+ moves = []
+ for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
+ for steps in range(1, 8):
+ for r0 in rows:
+ for c0 in columns:
+ start = c0 + r0
+ end = queen_move(start, direction, steps)
+ if end == None:
+ moves.append('illegal')
+ else:
+ moves.append(start+end)
+
+ # 8 planes of knight moves
+ for direction in ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']:
+ for r0 in rows:
+ for c0 in columns:
+ start = c0 + r0
+ end = knight_move(start, direction, 1)
+ if end == None:
+ moves.append('illegal')
+ else:
+ moves.append(start+end)
+
+ # 9 promotions
+ for direction in ['NW', 'N', 'NE']:
+ for promotion in promotions:
+ for r0 in rows:
+ for c0 in columns:
+ # Promotion only in the second last rank
+ if r0 != '7':
+ moves.append('illegal')
+ continue
+ start = c0 + r0
+ end = queen_move(start, direction, 1)
+ if end == None:
+ moves.append('illegal')
+ else:
+ moves.append(start+end+promotion)
+
+ for m in policy_index:
+ if m not in moves:
+ raise ValueError('Missing move: {}'.format(m))
+
+ az_to_lc0 = np.zeros((80*8*8, len(policy_index)), dtype=np.float32)
+ indices = []
+ legal_moves = 0
+ for e, m in enumerate(moves):
+ if m == 'illegal':
+ indices.append(-1)
+ continue
+ legal_moves += 1
+ # Check for missing moves
+ if m not in policy_index:
+ raise ValueError('Missing move: {}'.format(m))
+ i = policy_index.index(m)
+ indices.append(i)
+ az_to_lc0[e][i] = 1
+
+ assert legal_moves == len(policy_index)
+ assert np.sum(az_to_lc0) == legal_moves
+ for e in range(80*8*8):
+ for i in range(len(policy_index)):
+ pass
+ if kind == 'matrix':
+ return az_to_lc0
+ elif kind == 'index':
+ return indices
+
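+# Example (sketch; 'az_policy' is illustrative): mapping an AlphaZero-style
+# policy tensor of shape (80, 8, 8) onto lc0's 1858-move policy vector using
+# the matrix form returned above:
+#
+#   m = make_map('matrix')                  # (80*8*8, 1858), one-hot rows
+#   lc0_policy = az_policy.reshape(-1) @ m  # 'illegal' slots hit all-zero rows
+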
+if __name__ == "__main__":
+ # Generate policy map include file for lc0
+ if len(sys.argv) != 2:
+ raise ValueError("Output filename is needed as a command line argument")
+
+ az_to_lc0 = np.ravel(make_map('index'))
+ header = \
+"""/*
+ This file is part of Leela Chess Zero.
+ Copyright (C) 2019 The LCZero Authors
+
+ Leela Chess is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ Leela Chess is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+#pragma once
+
+namespace lczero {
+"""
+ line_length = 12
+ with open(sys.argv[1], 'w') as f:
+ f.write(header+'\n')
+ f.write('const short kConvPolicyMap[] = {\\\n')
+ for e, i in enumerate(az_to_lc0):
+ if e % line_length == 0 and e > 0:
+ f.write('\n')
+ f.write(str(i).rjust(5))
+ if e != len(az_to_lc0)-1:
+ f.write(',')
+ f.write('};\n\n')
+ f.write('} // namespace lczero')
diff --git a/backend/tf_transfer/net.py b/backend/tf_transfer/net.py
new file mode 100755
index 0000000..d3ed5b5
--- /dev/null
+++ b/backend/tf_transfer/net.py
@@ -0,0 +1,357 @@
+#!/usr/bin/env python3
+
+import argparse
+import gzip
+import os
+import numpy as np
+import backend.proto.net_pb2 as pb
+
+LC0_MAJOR = 0
+LC0_MINOR = 21
+LC0_PATCH = 0
+WEIGHTS_MAGIC = 0x1c0
+
+
+class Net:
+ def __init__(self,
+ net=pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT,
+ input=pb.NetworkFormat.INPUT_CLASSICAL_112_PLANE,
+ value=pb.NetworkFormat.VALUE_CLASSICAL,
+ policy=pb.NetworkFormat.POLICY_CLASSICAL):
+
+ if net == pb.NetworkFormat.NETWORK_SE:
+ net = pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT
+ if net == pb.NetworkFormat.NETWORK_CLASSICAL:
+ net = pb.NetworkFormat.NETWORK_CLASSICAL_WITH_HEADFORMAT
+
+ self.pb = pb.Net()
+ self.pb.magic = WEIGHTS_MAGIC
+ self.pb.min_version.major = LC0_MAJOR
+ self.pb.min_version.minor = LC0_MINOR
+ self.pb.min_version.patch = LC0_PATCH
+ self.pb.format.weights_encoding = pb.Format.LINEAR16
+
+ self.weights = []
+
+ self.set_networkformat(net)
+ self.pb.format.network_format.input = input
+ self.set_policyformat(policy)
+ self.set_valueformat(value)
+
+ def set_networkformat(self, net):
+ self.pb.format.network_format.network = net
+
+ def set_policyformat(self, policy):
+ self.pb.format.network_format.policy = policy
+
+ def set_valueformat(self, value):
+ self.pb.format.network_format.value = value
+
+ # OutputFormat is for search to know which kind of value the net returns.
+ if value == pb.NetworkFormat.VALUE_WDL:
+ self.pb.format.network_format.output = pb.NetworkFormat.OUTPUT_WDL
+ else:
+ self.pb.format.network_format.output = pb.NetworkFormat.OUTPUT_CLASSICAL
+
+ def get_weight_amounts(self):
+ value_weights = 8
+ policy_weights = 6
+ head_weights = value_weights + policy_weights
+ if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:
+ # Batch norm gammas in head convolutions.
+ head_weights += 2
+ if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:
+ return {"input": 5, "residual": 14, "head": head_weights}
+ else:
+ return {"input": 4, "residual": 8, "head": head_weights}
+
+ def fill_layer(self, layer, weights):
+ """Normalize and populate 16bit layer in protobuf"""
+ params = np.array(weights.pop(), dtype=np.float32)
+ layer.min_val = 0 if len(params) == 1 else float(np.min(params))
+ layer.max_val = 1 if len(params) == 1 and np.max(params) == 0 else float(np.max(params))
+ if layer.max_val == layer.min_val:
+ # Avoid division by zero if max == min.
+ params = (params - layer.min_val)
+ else:
+ params = (params - layer.min_val) / (layer.max_val - layer.min_val)
+ params *= 0xffff
+ params = np.round(params)
+ layer.params = params.astype(np.uint16).tobytes()
+
+ def fill_conv_block(self, convblock, weights, gammas):
+ """Normalize and populate 16bit convblock in protobuf"""
+ if gammas:
+ self.fill_layer(convblock.bn_stddivs, weights)
+ self.fill_layer(convblock.bn_means, weights)
+ self.fill_layer(convblock.bn_betas, weights)
+ self.fill_layer(convblock.bn_gammas, weights)
+ self.fill_layer(convblock.weights, weights)
+ else:
+ self.fill_layer(convblock.bn_stddivs, weights)
+ self.fill_layer(convblock.bn_means, weights)
+ self.fill_layer(convblock.biases, weights)
+ self.fill_layer(convblock.weights, weights)
+
+ def fill_plain_conv(self, convblock, weights):
+ """Normalize and populate 16bit convblock in protobuf"""
+ self.fill_layer(convblock.biases, weights)
+ self.fill_layer(convblock.weights, weights)
+
+ def fill_se_unit(self, se_unit, weights):
+ self.fill_layer(se_unit.b2, weights)
+ self.fill_layer(se_unit.w2, weights)
+ self.fill_layer(se_unit.b1, weights)
+ self.fill_layer(se_unit.w1, weights)
+
+ def denorm_layer(self, layer, weights):
+ """Denormalize a layer from protobuf"""
+ params = np.frombuffer(layer.params, np.uint16).astype(np.float32)
+ params /= 0xffff
+ weights.insert(0, params * (layer.max_val - layer.min_val) + layer.min_val)
+
+ def denorm_conv_block(self, convblock, weights):
+ """Denormalize a convblock from protobuf"""
+ se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT
+
+ if se:
+ self.denorm_layer(convblock.bn_stddivs, weights)
+ self.denorm_layer(convblock.bn_means, weights)
+ self.denorm_layer(convblock.bn_betas, weights)
+ self.denorm_layer(convblock.bn_gammas, weights)
+ self.denorm_layer(convblock.weights, weights)
+ else:
+ self.denorm_layer(convblock.bn_stddivs, weights)
+ self.denorm_layer(convblock.bn_means, weights)
+ self.denorm_layer(convblock.biases, weights)
+ self.denorm_layer(convblock.weights, weights)
+
+ def denorm_plain_conv(self, convblock, weights):
+ """Denormalize a plain convolution from protobuf"""
+ self.denorm_layer(convblock.biases, weights)
+ self.denorm_layer(convblock.weights, weights)
+
+ def denorm_se_unit(self, convblock, weights):
+ """Denormalize SE-unit from protobuf"""
+ se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT
+
+ assert se
+
+ self.denorm_layer(convblock.b2, weights)
+ self.denorm_layer(convblock.w2, weights)
+ self.denorm_layer(convblock.b1, weights)
+ self.denorm_layer(convblock.w1, weights)
+
+ def save_txt(self, filename):
+ """Save weights as txt file"""
+ weights = self.get_weights()
+
+ if len(filename.split('.')) == 1:
+ filename += ".txt.gz"
+
+ # Legacy .txt files are version 2, SE is version 3.
+
+ version = 2
+ if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT:
+ version = 3
+
+ if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:
+ version = 4
+
+ with gzip.open(filename, 'wb') as f:
+ f.write("{}\n".format(version).encode('utf-8'))
+ for row in weights:
+ f.write((" ".join(map(str, row.tolist())) + "\n").encode('utf-8'))
+
+ size = os.path.getsize(filename) / 1024**2
+ print("saved as '{}' {}M".format(filename, round(size, 2)))
+
+ def save_proto(self, filename):
+ """Save weights gzipped protobuf file"""
+ if len(filename.split('.')) == 1:
+ filename += ".pb.gz"
+
+ with gzip.open(filename, 'wb') as f:
+ data = self.pb.SerializeToString()
+ f.write(data)
+
+ size = os.path.getsize(filename) / 1024**2
+ print("saved as '{}' {}M".format(filename, round(size, 2)))
+
+ def get_weights(self):
+ """Returns the weights as floats per layer"""
+ se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT
+ if self.weights == []:
+ self.denorm_layer(self.pb.weights.ip2_val_b, self.weights)
+ self.denorm_layer(self.pb.weights.ip2_val_w, self.weights)
+ self.denorm_layer(self.pb.weights.ip1_val_b, self.weights)
+ self.denorm_layer(self.pb.weights.ip1_val_w, self.weights)
+ self.denorm_conv_block(self.pb.weights.value, self.weights)
+
+ if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:
+ self.denorm_plain_conv(self.pb.weights.policy, self.weights)
+ self.denorm_conv_block(self.pb.weights.policy1, self.weights)
+ else:
+ self.denorm_layer(self.pb.weights.ip_pol_b, self.weights)
+ self.denorm_layer(self.pb.weights.ip_pol_w, self.weights)
+ self.denorm_conv_block(self.pb.weights.policy, self.weights)
+
+ for res in reversed(self.pb.weights.residual):
+ if se:
+ self.denorm_se_unit(res.se, self.weights)
+ self.denorm_conv_block(res.conv2, self.weights)
+ self.denorm_conv_block(res.conv1, self.weights)
+
+ self.denorm_conv_block(self.pb.weights.input, self.weights)
+
+ return self.weights
+
+ def filters(self):
+ w = self.get_weights()
+ return len(w[1])
+
+ def blocks(self):
+ w = self.get_weights()
+
+ ws = self.get_weight_amounts()
+ blocks = len(w) - (ws['input'] + ws['head'])
+
+ if blocks % ws['residual'] != 0:
+ raise ValueError("Inconsistent number of weights in the file")
+
+ return blocks // ws['residual']
+
+ def print_stats(self):
+ print("Blocks: {}".format(self.blocks()))
+ print("Filters: {}".format(self.filters()))
+ print_pb_stats(self.pb)
+ print()
+
+ def parse_proto(self, filename):
+ with gzip.open(filename, 'rb') as f:
+ self.pb = self.pb.FromString(f.read())
+ # Populate policyFormat and valueFormat fields in old protobufs
+ # without these fields.
+ if self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE:
+ self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT)
+            self.set_valueformat(pb.NetworkFormat.VALUE_CLASSICAL)
+            self.set_policyformat(pb.NetworkFormat.POLICY_CLASSICAL)
+ elif self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_CLASSICAL:
+ self.set_networkformat(pb.NetworkFormat.NETWORK_CLASSICAL_WITH_HEADFORMAT)
+            self.set_valueformat(pb.NetworkFormat.VALUE_CLASSICAL)
+            self.set_policyformat(pb.NetworkFormat.POLICY_CLASSICAL)
+
+ def parse_txt(self, filename):
+ weights = []
+
+ with open(filename, 'r') as f:
+ try:
+ version = int(f.readline()[0])
+ except:
+ raise ValueError('Unable to read version.')
+ for e, line in enumerate(f):
+ weights.append(list(map(float, line.split(' '))))
+
+ if version == 3:
+ self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT)
+
+ if version == 4:
+ self.set_networkformat(pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT)
+ self.set_policyformat(pb.NetworkFormat.POLICY_CONVOLUTION)
+
+ self.fill_net(weights)
+
+ def fill_net(self, weights):
+ self.weights = []
+ # Batchnorm gammas in ConvBlock?
+ se = self.pb.format.network_format.network == pb.NetworkFormat.NETWORK_SE_WITH_HEADFORMAT
+ gammas = se
+
+ ws = self.get_weight_amounts()
+
+ blocks = len(weights) - (ws['input'] + ws['head'])
+
+ if blocks % ws['residual'] != 0:
+ raise ValueError("Inconsistent number of weights in the file")
+ blocks //= ws['residual']
+
+ self.pb.format.weights_encoding = pb.Format.LINEAR16
+ self.fill_layer(self.pb.weights.ip2_val_b, weights)
+ self.fill_layer(self.pb.weights.ip2_val_w, weights)
+ self.fill_layer(self.pb.weights.ip1_val_b, weights)
+ self.fill_layer(self.pb.weights.ip1_val_w, weights)
+ self.fill_conv_block(self.pb.weights.value, weights, gammas)
+
+ if self.pb.format.network_format.policy == pb.NetworkFormat.POLICY_CONVOLUTION:
+ self.fill_plain_conv(self.pb.weights.policy, weights)
+ self.fill_conv_block(self.pb.weights.policy1, weights, gammas)
+ else:
+ self.fill_layer(self.pb.weights.ip_pol_b, weights)
+ self.fill_layer(self.pb.weights.ip_pol_w, weights)
+ self.fill_conv_block(self.pb.weights.policy, weights, gammas)
+
+ del self.pb.weights.residual[:]
+ tower = []
+ for i in range(blocks):
+ tower.append(self.pb.weights.residual.add())
+
+ for res in reversed(tower):
+ if se:
+ self.fill_se_unit(res.se, weights)
+ self.fill_conv_block(res.conv2, weights, gammas)
+ self.fill_conv_block(res.conv1, weights, gammas)
+
+ self.fill_conv_block(self.pb.weights.input, weights, gammas)
+
+def print_pb_stats(obj, parent=None):
+ for descriptor in obj.DESCRIPTOR.fields:
+ value = getattr(obj, descriptor.name)
+ if descriptor.name == "weights": return
+ if descriptor.type == descriptor.TYPE_MESSAGE:
+ if descriptor.label == descriptor.LABEL_REPEATED:
+                for item in value:
+                    print_pb_stats(item, obj)
+ else:
+ print_pb_stats(value, obj)
+ elif descriptor.type == descriptor.TYPE_ENUM:
+ enum_name = descriptor.enum_type.values[value].name
+ print("%s: %s" % (descriptor.full_name, enum_name))
+ else:
+ print("%s: %s" % (descriptor.full_name, value))
+
+def main(argv):
+ net = Net()
+
+ if argv.input.endswith(".txt"):
+ print('Found .txt network')
+ net.parse_txt(argv.input)
+ net.print_stats()
+ if argv.output == None:
+ argv.output = argv.input.replace('.txt', '.pb.gz')
+ assert argv.output.endswith('.pb.gz')
+ print('Writing output to: {}'.format(argv.output))
+ net.save_proto(argv.output)
+ elif argv.input.endswith(".pb.gz"):
+ print('Found .pb.gz network')
+ net.parse_proto(argv.input)
+ net.print_stats()
+ if argv.output == None:
+ argv.output = argv.input.replace('.pb.gz', '.txt.gz')
+ print('Writing output to: {}'.format(argv.output))
+ assert argv.output.endswith('.txt.gz')
+ if argv.output.endswith(".pb.gz"):
+ net.save_proto(argv.output)
+ else:
+ net.save_txt(argv.output)
+ else:
+ print('Unable to detect the network format. '
+ 'Filename should end in ".txt" or ".pb.gz"')
+
+
+if __name__ == "__main__":
+ argparser = argparse.ArgumentParser(
+ description='Convert network textfile to proto.')
+ argparser.add_argument('-i', '--input', type=str,
+ help='input network weight text file')
+ argparser.add_argument('-o', '--output', type=str,
+ help='output filepath without extension')
+ main(argparser.parse_args())
diff --git a/backend/tf_transfer/net_to_model.py b/backend/tf_transfer/net_to_model.py
new file mode 100755
index 0000000..b1e0c70
--- /dev/null
+++ b/backend/tf_transfer/net_to_model.py
@@ -0,0 +1,39 @@
+#!/usr/bin/env python3
+import argparse
+import tensorflow as tf
+import os
+import yaml
+from .tfprocess import TFProcess
+from .net import Net
+
+argparser = argparse.ArgumentParser(description='Convert net to model.')
+argparser.add_argument('net', type=str,
+ help='Net file to be converted to a model checkpoint.')
+argparser.add_argument('--start', type=int, default=0,
+ help='Offset to set global_step to.')
+argparser.add_argument('--cfg', type=argparse.FileType('r'),
+ help='yaml configuration with training parameters')
+args = argparser.parse_args()
+cfg = yaml.safe_load(args.cfg.read())
+print(yaml.dump(cfg, default_flow_style=False))
+START_FROM = args.start
+net = Net()
+net.parse_proto(args.net)
+
+filters, blocks = net.filters(), net.blocks()
+if cfg['model']['filters'] != filters:
+ raise ValueError("Number of filters in YAML doesn't match the network")
+if cfg['model']['residual_blocks'] != blocks:
+ raise ValueError("Number of blocks in YAML doesn't match the network")
+weights = net.get_weights()
+
+tfp = TFProcess(cfg)
+tfp.init_net_v2()
+tfp.replace_weights_v2(weights)
+tfp.global_step.assign(START_FROM)
+
+root_dir = os.path.join(cfg['training']['path'], cfg['name'])
+if not os.path.exists(root_dir):
+ os.makedirs(root_dir)
+tfp.manager.save()
+print("Wrote model to {}".format(tfp.manager.latest_checkpoint))
diff --git a/backend/tf_transfer/policy_index.py b/backend/tf_transfer/policy_index.py
new file mode 100755
index 0000000..d641bf5
--- /dev/null
+++ b/backend/tf_transfer/policy_index.py
@@ -0,0 +1,233 @@
+policy_index = ["a1b1", "a1c1", "a1d1", "a1e1", "a1f1", "a1g1", "a1h1", "a1a2",
+"a1b2", "a1c2", "a1a3", "a1b3", "a1c3", "a1a4", "a1d4", "a1a5",
+"a1e5", "a1a6", "a1f6", "a1a7", "a1g7", "a1a8", "a1h8", "b1a1",
+"b1c1", "b1d1", "b1e1", "b1f1", "b1g1", "b1h1", "b1a2", "b1b2",
+"b1c2", "b1d2", "b1a3", "b1b3", "b1c3", "b1d3", "b1b4", "b1e4",
+"b1b5", "b1f5", "b1b6", "b1g6", "b1b7", "b1h7", "b1b8", "c1a1",
+"c1b1", "c1d1", "c1e1", "c1f1", "c1g1", "c1h1", "c1a2", "c1b2",
+"c1c2", "c1d2", "c1e2", "c1a3", "c1b3", "c1c3", "c1d3", "c1e3",
+"c1c4", "c1f4", "c1c5", "c1g5", "c1c6", "c1h6", "c1c7", "c1c8",
+"d1a1", "d1b1", "d1c1", "d1e1", "d1f1", "d1g1", "d1h1", "d1b2",
+"d1c2", "d1d2", "d1e2", "d1f2", "d1b3", "d1c3", "d1d3", "d1e3",
+"d1f3", "d1a4", "d1d4", "d1g4", "d1d5", "d1h5", "d1d6", "d1d7",
+"d1d8", "e1a1", "e1b1", "e1c1", "e1d1", "e1f1", "e1g1", "e1h1",
+"e1c2", "e1d2", "e1e2", "e1f2", "e1g2", "e1c3", "e1d3", "e1e3",
+"e1f3", "e1g3", "e1b4", "e1e4", "e1h4", "e1a5", "e1e5", "e1e6",
+"e1e7", "e1e8", "f1a1", "f1b1", "f1c1", "f1d1", "f1e1", "f1g1",
+"f1h1", "f1d2", "f1e2", "f1f2", "f1g2", "f1h2", "f1d3", "f1e3",
+"f1f3", "f1g3", "f1h3", "f1c4", "f1f4", "f1b5", "f1f5", "f1a6",
+"f1f6", "f1f7", "f1f8", "g1a1", "g1b1", "g1c1", "g1d1", "g1e1",
+"g1f1", "g1h1", "g1e2", "g1f2", "g1g2", "g1h2", "g1e3", "g1f3",
+"g1g3", "g1h3", "g1d4", "g1g4", "g1c5", "g1g5", "g1b6", "g1g6",
+"g1a7", "g1g7", "g1g8", "h1a1", "h1b1", "h1c1", "h1d1", "h1e1",
+"h1f1", "h1g1", "h1f2", "h1g2", "h1h2", "h1f3", "h1g3", "h1h3",
+"h1e4", "h1h4", "h1d5", "h1h5", "h1c6", "h1h6", "h1b7", "h1h7",
+"h1a8", "h1h8", "a2a1", "a2b1", "a2c1", "a2b2", "a2c2", "a2d2",
+"a2e2", "a2f2", "a2g2", "a2h2", "a2a3", "a2b3", "a2c3", "a2a4",
+"a2b4", "a2c4", "a2a5", "a2d5", "a2a6", "a2e6", "a2a7", "a2f7",
+"a2a8", "a2g8", "b2a1", "b2b1", "b2c1", "b2d1", "b2a2", "b2c2",
+"b2d2", "b2e2", "b2f2", "b2g2", "b2h2", "b2a3", "b2b3", "b2c3",
+"b2d3", "b2a4", "b2b4", "b2c4", "b2d4", "b2b5", "b2e5", "b2b6",
+"b2f6", "b2b7", "b2g7", "b2b8", "b2h8", "c2a1", "c2b1", "c2c1",
+"c2d1", "c2e1", "c2a2", "c2b2", "c2d2", "c2e2", "c2f2", "c2g2",
+"c2h2", "c2a3", "c2b3", "c2c3", "c2d3", "c2e3", "c2a4", "c2b4",
+"c2c4", "c2d4", "c2e4", "c2c5", "c2f5", "c2c6", "c2g6", "c2c7",
+"c2h7", "c2c8", "d2b1", "d2c1", "d2d1", "d2e1", "d2f1", "d2a2",
+"d2b2", "d2c2", "d2e2", "d2f2", "d2g2", "d2h2", "d2b3", "d2c3",
+"d2d3", "d2e3", "d2f3", "d2b4", "d2c4", "d2d4", "d2e4", "d2f4",
+"d2a5", "d2d5", "d2g5", "d2d6", "d2h6", "d2d7", "d2d8", "e2c1",
+"e2d1", "e2e1", "e2f1", "e2g1", "e2a2", "e2b2", "e2c2", "e2d2",
+"e2f2", "e2g2", "e2h2", "e2c3", "e2d3", "e2e3", "e2f3", "e2g3",
+"e2c4", "e2d4", "e2e4", "e2f4", "e2g4", "e2b5", "e2e5", "e2h5",
+"e2a6", "e2e6", "e2e7", "e2e8", "f2d1", "f2e1", "f2f1", "f2g1",
+"f2h1", "f2a2", "f2b2", "f2c2", "f2d2", "f2e2", "f2g2", "f2h2",
+"f2d3", "f2e3", "f2f3", "f2g3", "f2h3", "f2d4", "f2e4", "f2f4",
+"f2g4", "f2h4", "f2c5", "f2f5", "f2b6", "f2f6", "f2a7", "f2f7",
+"f2f8", "g2e1", "g2f1", "g2g1", "g2h1", "g2a2", "g2b2", "g2c2",
+"g2d2", "g2e2", "g2f2", "g2h2", "g2e3", "g2f3", "g2g3", "g2h3",
+"g2e4", "g2f4", "g2g4", "g2h4", "g2d5", "g2g5", "g2c6", "g2g6",
+"g2b7", "g2g7", "g2a8", "g2g8", "h2f1", "h2g1", "h2h1", "h2a2",
+"h2b2", "h2c2", "h2d2", "h2e2", "h2f2", "h2g2", "h2f3", "h2g3",
+"h2h3", "h2f4", "h2g4", "h2h4", "h2e5", "h2h5", "h2d6", "h2h6",
+"h2c7", "h2h7", "h2b8", "h2h8", "a3a1", "a3b1", "a3c1", "a3a2",
+"a3b2", "a3c2", "a3b3", "a3c3", "a3d3", "a3e3", "a3f3", "a3g3",
+"a3h3", "a3a4", "a3b4", "a3c4", "a3a5", "a3b5", "a3c5", "a3a6",
+"a3d6", "a3a7", "a3e7", "a3a8", "a3f8", "b3a1", "b3b1", "b3c1",
+"b3d1", "b3a2", "b3b2", "b3c2", "b3d2", "b3a3", "b3c3", "b3d3",
+"b3e3", "b3f3", "b3g3", "b3h3", "b3a4", "b3b4", "b3c4", "b3d4",
+"b3a5", "b3b5", "b3c5", "b3d5", "b3b6", "b3e6", "b3b7", "b3f7",
+"b3b8", "b3g8", "c3a1", "c3b1", "c3c1", "c3d1", "c3e1", "c3a2",
+"c3b2", "c3c2", "c3d2", "c3e2", "c3a3", "c3b3", "c3d3", "c3e3",
+"c3f3", "c3g3", "c3h3", "c3a4", "c3b4", "c3c4", "c3d4", "c3e4",
+"c3a5", "c3b5", "c3c5", "c3d5", "c3e5", "c3c6", "c3f6", "c3c7",
+"c3g7", "c3c8", "c3h8", "d3b1", "d3c1", "d3d1", "d3e1", "d3f1",
+"d3b2", "d3c2", "d3d2", "d3e2", "d3f2", "d3a3", "d3b3", "d3c3",
+"d3e3", "d3f3", "d3g3", "d3h3", "d3b4", "d3c4", "d3d4", "d3e4",
+"d3f4", "d3b5", "d3c5", "d3d5", "d3e5", "d3f5", "d3a6", "d3d6",
+"d3g6", "d3d7", "d3h7", "d3d8", "e3c1", "e3d1", "e3e1", "e3f1",
+"e3g1", "e3c2", "e3d2", "e3e2", "e3f2", "e3g2", "e3a3", "e3b3",
+"e3c3", "e3d3", "e3f3", "e3g3", "e3h3", "e3c4", "e3d4", "e3e4",
+"e3f4", "e3g4", "e3c5", "e3d5", "e3e5", "e3f5", "e3g5", "e3b6",
+"e3e6", "e3h6", "e3a7", "e3e7", "e3e8", "f3d1", "f3e1", "f3f1",
+"f3g1", "f3h1", "f3d2", "f3e2", "f3f2", "f3g2", "f3h2", "f3a3",
+"f3b3", "f3c3", "f3d3", "f3e3", "f3g3", "f3h3", "f3d4", "f3e4",
+"f3f4", "f3g4", "f3h4", "f3d5", "f3e5", "f3f5", "f3g5", "f3h5",
+"f3c6", "f3f6", "f3b7", "f3f7", "f3a8", "f3f8", "g3e1", "g3f1",
+"g3g1", "g3h1", "g3e2", "g3f2", "g3g2", "g3h2", "g3a3", "g3b3",
+"g3c3", "g3d3", "g3e3", "g3f3", "g3h3", "g3e4", "g3f4", "g3g4",
+"g3h4", "g3e5", "g3f5", "g3g5", "g3h5", "g3d6", "g3g6", "g3c7",
+"g3g7", "g3b8", "g3g8", "h3f1", "h3g1", "h3h1", "h3f2", "h3g2",
+"h3h2", "h3a3", "h3b3", "h3c3", "h3d3", "h3e3", "h3f3", "h3g3",
+"h3f4", "h3g4", "h3h4", "h3f5", "h3g5", "h3h5", "h3e6", "h3h6",
+"h3d7", "h3h7", "h3c8", "h3h8", "a4a1", "a4d1", "a4a2", "a4b2",
+"a4c2", "a4a3", "a4b3", "a4c3", "a4b4", "a4c4", "a4d4", "a4e4",
+"a4f4", "a4g4", "a4h4", "a4a5", "a4b5", "a4c5", "a4a6", "a4b6",
+"a4c6", "a4a7", "a4d7", "a4a8", "a4e8", "b4b1", "b4e1", "b4a2",
+"b4b2", "b4c2", "b4d2", "b4a3", "b4b3", "b4c3", "b4d3", "b4a4",
+"b4c4", "b4d4", "b4e4", "b4f4", "b4g4", "b4h4", "b4a5", "b4b5",
+"b4c5", "b4d5", "b4a6", "b4b6", "b4c6", "b4d6", "b4b7", "b4e7",
+"b4b8", "b4f8", "c4c1", "c4f1", "c4a2", "c4b2", "c4c2", "c4d2",
+"c4e2", "c4a3", "c4b3", "c4c3", "c4d3", "c4e3", "c4a4", "c4b4",
+"c4d4", "c4e4", "c4f4", "c4g4", "c4h4", "c4a5", "c4b5", "c4c5",
+"c4d5", "c4e5", "c4a6", "c4b6", "c4c6", "c4d6", "c4e6", "c4c7",
+"c4f7", "c4c8", "c4g8", "d4a1", "d4d1", "d4g1", "d4b2", "d4c2",
+"d4d2", "d4e2", "d4f2", "d4b3", "d4c3", "d4d3", "d4e3", "d4f3",
+"d4a4", "d4b4", "d4c4", "d4e4", "d4f4", "d4g4", "d4h4", "d4b5",
+"d4c5", "d4d5", "d4e5", "d4f5", "d4b6", "d4c6", "d4d6", "d4e6",
+"d4f6", "d4a7", "d4d7", "d4g7", "d4d8", "d4h8", "e4b1", "e4e1",
+"e4h1", "e4c2", "e4d2", "e4e2", "e4f2", "e4g2", "e4c3", "e4d3",
+"e4e3", "e4f3", "e4g3", "e4a4", "e4b4", "e4c4", "e4d4", "e4f4",
+"e4g4", "e4h4", "e4c5", "e4d5", "e4e5", "e4f5", "e4g5", "e4c6",
+"e4d6", "e4e6", "e4f6", "e4g6", "e4b7", "e4e7", "e4h7", "e4a8",
+"e4e8", "f4c1", "f4f1", "f4d2", "f4e2", "f4f2", "f4g2", "f4h2",
+"f4d3", "f4e3", "f4f3", "f4g3", "f4h3", "f4a4", "f4b4", "f4c4",
+"f4d4", "f4e4", "f4g4", "f4h4", "f4d5", "f4e5", "f4f5", "f4g5",
+"f4h5", "f4d6", "f4e6", "f4f6", "f4g6", "f4h6", "f4c7", "f4f7",
+"f4b8", "f4f8", "g4d1", "g4g1", "g4e2", "g4f2", "g4g2", "g4h2",
+"g4e3", "g4f3", "g4g3", "g4h3", "g4a4", "g4b4", "g4c4", "g4d4",
+"g4e4", "g4f4", "g4h4", "g4e5", "g4f5", "g4g5", "g4h5", "g4e6",
+"g4f6", "g4g6", "g4h6", "g4d7", "g4g7", "g4c8", "g4g8", "h4e1",
+"h4h1", "h4f2", "h4g2", "h4h2", "h4f3", "h4g3", "h4h3", "h4a4",
+"h4b4", "h4c4", "h4d4", "h4e4", "h4f4", "h4g4", "h4f5", "h4g5",
+"h4h5", "h4f6", "h4g6", "h4h6", "h4e7", "h4h7", "h4d8", "h4h8",
+"a5a1", "a5e1", "a5a2", "a5d2", "a5a3", "a5b3", "a5c3", "a5a4",
+"a5b4", "a5c4", "a5b5", "a5c5", "a5d5", "a5e5", "a5f5", "a5g5",
+"a5h5", "a5a6", "a5b6", "a5c6", "a5a7", "a5b7", "a5c7", "a5a8",
+"a5d8", "b5b1", "b5f1", "b5b2", "b5e2", "b5a3", "b5b3", "b5c3",
+"b5d3", "b5a4", "b5b4", "b5c4", "b5d4", "b5a5", "b5c5", "b5d5",
+"b5e5", "b5f5", "b5g5", "b5h5", "b5a6", "b5b6", "b5c6", "b5d6",
+"b5a7", "b5b7", "b5c7", "b5d7", "b5b8", "b5e8", "c5c1", "c5g1",
+"c5c2", "c5f2", "c5a3", "c5b3", "c5c3", "c5d3", "c5e3", "c5a4",
+"c5b4", "c5c4", "c5d4", "c5e4", "c5a5", "c5b5", "c5d5", "c5e5",
+"c5f5", "c5g5", "c5h5", "c5a6", "c5b6", "c5c6", "c5d6", "c5e6",
+"c5a7", "c5b7", "c5c7", "c5d7", "c5e7", "c5c8", "c5f8", "d5d1",
+"d5h1", "d5a2", "d5d2", "d5g2", "d5b3", "d5c3", "d5d3", "d5e3",
+"d5f3", "d5b4", "d5c4", "d5d4", "d5e4", "d5f4", "d5a5", "d5b5",
+"d5c5", "d5e5", "d5f5", "d5g5", "d5h5", "d5b6", "d5c6", "d5d6",
+"d5e6", "d5f6", "d5b7", "d5c7", "d5d7", "d5e7", "d5f7", "d5a8",
+"d5d8", "d5g8", "e5a1", "e5e1", "e5b2", "e5e2", "e5h2", "e5c3",
+"e5d3", "e5e3", "e5f3", "e5g3", "e5c4", "e5d4", "e5e4", "e5f4",
+"e5g4", "e5a5", "e5b5", "e5c5", "e5d5", "e5f5", "e5g5", "e5h5",
+"e5c6", "e5d6", "e5e6", "e5f6", "e5g6", "e5c7", "e5d7", "e5e7",
+"e5f7", "e5g7", "e5b8", "e5e8", "e5h8", "f5b1", "f5f1", "f5c2",
+"f5f2", "f5d3", "f5e3", "f5f3", "f5g3", "f5h3", "f5d4", "f5e4",
+"f5f4", "f5g4", "f5h4", "f5a5", "f5b5", "f5c5", "f5d5", "f5e5",
+"f5g5", "f5h5", "f5d6", "f5e6", "f5f6", "f5g6", "f5h6", "f5d7",
+"f5e7", "f5f7", "f5g7", "f5h7", "f5c8", "f5f8", "g5c1", "g5g1",
+"g5d2", "g5g2", "g5e3", "g5f3", "g5g3", "g5h3", "g5e4", "g5f4",
+"g5g4", "g5h4", "g5a5", "g5b5", "g5c5", "g5d5", "g5e5", "g5f5",
+"g5h5", "g5e6", "g5f6", "g5g6", "g5h6", "g5e7", "g5f7", "g5g7",
+"g5h7", "g5d8", "g5g8", "h5d1", "h5h1", "h5e2", "h5h2", "h5f3",
+"h5g3", "h5h3", "h5f4", "h5g4", "h5h4", "h5a5", "h5b5", "h5c5",
+"h5d5", "h5e5", "h5f5", "h5g5", "h5f6", "h5g6", "h5h6", "h5f7",
+"h5g7", "h5h7", "h5e8", "h5h8", "a6a1", "a6f1", "a6a2", "a6e2",
+"a6a3", "a6d3", "a6a4", "a6b4", "a6c4", "a6a5", "a6b5", "a6c5",
+"a6b6", "a6c6", "a6d6", "a6e6", "a6f6", "a6g6", "a6h6", "a6a7",
+"a6b7", "a6c7", "a6a8", "a6b8", "a6c8", "b6b1", "b6g1", "b6b2",
+"b6f2", "b6b3", "b6e3", "b6a4", "b6b4", "b6c4", "b6d4", "b6a5",
+"b6b5", "b6c5", "b6d5", "b6a6", "b6c6", "b6d6", "b6e6", "b6f6",
+"b6g6", "b6h6", "b6a7", "b6b7", "b6c7", "b6d7", "b6a8", "b6b8",
+"b6c8", "b6d8", "c6c1", "c6h1", "c6c2", "c6g2", "c6c3", "c6f3",
+"c6a4", "c6b4", "c6c4", "c6d4", "c6e4", "c6a5", "c6b5", "c6c5",
+"c6d5", "c6e5", "c6a6", "c6b6", "c6d6", "c6e6", "c6f6", "c6g6",
+"c6h6", "c6a7", "c6b7", "c6c7", "c6d7", "c6e7", "c6a8", "c6b8",
+"c6c8", "c6d8", "c6e8", "d6d1", "d6d2", "d6h2", "d6a3", "d6d3",
+"d6g3", "d6b4", "d6c4", "d6d4", "d6e4", "d6f4", "d6b5", "d6c5",
+"d6d5", "d6e5", "d6f5", "d6a6", "d6b6", "d6c6", "d6e6", "d6f6",
+"d6g6", "d6h6", "d6b7", "d6c7", "d6d7", "d6e7", "d6f7", "d6b8",
+"d6c8", "d6d8", "d6e8", "d6f8", "e6e1", "e6a2", "e6e2", "e6b3",
+"e6e3", "e6h3", "e6c4", "e6d4", "e6e4", "e6f4", "e6g4", "e6c5",
+"e6d5", "e6e5", "e6f5", "e6g5", "e6a6", "e6b6", "e6c6", "e6d6",
+"e6f6", "e6g6", "e6h6", "e6c7", "e6d7", "e6e7", "e6f7", "e6g7",
+"e6c8", "e6d8", "e6e8", "e6f8", "e6g8", "f6a1", "f6f1", "f6b2",
+"f6f2", "f6c3", "f6f3", "f6d4", "f6e4", "f6f4", "f6g4", "f6h4",
+"f6d5", "f6e5", "f6f5", "f6g5", "f6h5", "f6a6", "f6b6", "f6c6",
+"f6d6", "f6e6", "f6g6", "f6h6", "f6d7", "f6e7", "f6f7", "f6g7",
+"f6h7", "f6d8", "f6e8", "f6f8", "f6g8", "f6h8", "g6b1", "g6g1",
+"g6c2", "g6g2", "g6d3", "g6g3", "g6e4", "g6f4", "g6g4", "g6h4",
+"g6e5", "g6f5", "g6g5", "g6h5", "g6a6", "g6b6", "g6c6", "g6d6",
+"g6e6", "g6f6", "g6h6", "g6e7", "g6f7", "g6g7", "g6h7", "g6e8",
+"g6f8", "g6g8", "g6h8", "h6c1", "h6h1", "h6d2", "h6h2", "h6e3",
+"h6h3", "h6f4", "h6g4", "h6h4", "h6f5", "h6g5", "h6h5", "h6a6",
+"h6b6", "h6c6", "h6d6", "h6e6", "h6f6", "h6g6", "h6f7", "h6g7",
+"h6h7", "h6f8", "h6g8", "h6h8", "a7a1", "a7g1", "a7a2", "a7f2",
+"a7a3", "a7e3", "a7a4", "a7d4", "a7a5", "a7b5", "a7c5", "a7a6",
+"a7b6", "a7c6", "a7b7", "a7c7", "a7d7", "a7e7", "a7f7", "a7g7",
+"a7h7", "a7a8", "a7b8", "a7c8", "b7b1", "b7h1", "b7b2", "b7g2",
+"b7b3", "b7f3", "b7b4", "b7e4", "b7a5", "b7b5", "b7c5", "b7d5",
+"b7a6", "b7b6", "b7c6", "b7d6", "b7a7", "b7c7", "b7d7", "b7e7",
+"b7f7", "b7g7", "b7h7", "b7a8", "b7b8", "b7c8", "b7d8", "c7c1",
+"c7c2", "c7h2", "c7c3", "c7g3", "c7c4", "c7f4", "c7a5", "c7b5",
+"c7c5", "c7d5", "c7e5", "c7a6", "c7b6", "c7c6", "c7d6", "c7e6",
+"c7a7", "c7b7", "c7d7", "c7e7", "c7f7", "c7g7", "c7h7", "c7a8",
+"c7b8", "c7c8", "c7d8", "c7e8", "d7d1", "d7d2", "d7d3", "d7h3",
+"d7a4", "d7d4", "d7g4", "d7b5", "d7c5", "d7d5", "d7e5", "d7f5",
+"d7b6", "d7c6", "d7d6", "d7e6", "d7f6", "d7a7", "d7b7", "d7c7",
+"d7e7", "d7f7", "d7g7", "d7h7", "d7b8", "d7c8", "d7d8", "d7e8",
+"d7f8", "e7e1", "e7e2", "e7a3", "e7e3", "e7b4", "e7e4", "e7h4",
+"e7c5", "e7d5", "e7e5", "e7f5", "e7g5", "e7c6", "e7d6", "e7e6",
+"e7f6", "e7g6", "e7a7", "e7b7", "e7c7", "e7d7", "e7f7", "e7g7",
+"e7h7", "e7c8", "e7d8", "e7e8", "e7f8", "e7g8", "f7f1", "f7a2",
+"f7f2", "f7b3", "f7f3", "f7c4", "f7f4", "f7d5", "f7e5", "f7f5",
+"f7g5", "f7h5", "f7d6", "f7e6", "f7f6", "f7g6", "f7h6", "f7a7",
+"f7b7", "f7c7", "f7d7", "f7e7", "f7g7", "f7h7", "f7d8", "f7e8",
+"f7f8", "f7g8", "f7h8", "g7a1", "g7g1", "g7b2", "g7g2", "g7c3",
+"g7g3", "g7d4", "g7g4", "g7e5", "g7f5", "g7g5", "g7h5", "g7e6",
+"g7f6", "g7g6", "g7h6", "g7a7", "g7b7", "g7c7", "g7d7", "g7e7",
+"g7f7", "g7h7", "g7e8", "g7f8", "g7g8", "g7h8", "h7b1", "h7h1",
+"h7c2", "h7h2", "h7d3", "h7h3", "h7e4", "h7h4", "h7f5", "h7g5",
+"h7h5", "h7f6", "h7g6", "h7h6", "h7a7", "h7b7", "h7c7", "h7d7",
+"h7e7", "h7f7", "h7g7", "h7f8", "h7g8", "h7h8", "a8a1", "a8h1",
+"a8a2", "a8g2", "a8a3", "a8f3", "a8a4", "a8e4", "a8a5", "a8d5",
+"a8a6", "a8b6", "a8c6", "a8a7", "a8b7", "a8c7", "a8b8", "a8c8",
+"a8d8", "a8e8", "a8f8", "a8g8", "a8h8", "b8b1", "b8b2", "b8h2",
+"b8b3", "b8g3", "b8b4", "b8f4", "b8b5", "b8e5", "b8a6", "b8b6",
+"b8c6", "b8d6", "b8a7", "b8b7", "b8c7", "b8d7", "b8a8", "b8c8",
+"b8d8", "b8e8", "b8f8", "b8g8", "b8h8", "c8c1", "c8c2", "c8c3",
+"c8h3", "c8c4", "c8g4", "c8c5", "c8f5", "c8a6", "c8b6", "c8c6",
+"c8d6", "c8e6", "c8a7", "c8b7", "c8c7", "c8d7", "c8e7", "c8a8",
+"c8b8", "c8d8", "c8e8", "c8f8", "c8g8", "c8h8", "d8d1", "d8d2",
+"d8d3", "d8d4", "d8h4", "d8a5", "d8d5", "d8g5", "d8b6", "d8c6",
+"d8d6", "d8e6", "d8f6", "d8b7", "d8c7", "d8d7", "d8e7", "d8f7",
+"d8a8", "d8b8", "d8c8", "d8e8", "d8f8", "d8g8", "d8h8", "e8e1",
+"e8e2", "e8e3", "e8a4", "e8e4", "e8b5", "e8e5", "e8h5", "e8c6",
+"e8d6", "e8e6", "e8f6", "e8g6", "e8c7", "e8d7", "e8e7", "e8f7",
+"e8g7", "e8a8", "e8b8", "e8c8", "e8d8", "e8f8", "e8g8", "e8h8",
+"f8f1", "f8f2", "f8a3", "f8f3", "f8b4", "f8f4", "f8c5", "f8f5",
+"f8d6", "f8e6", "f8f6", "f8g6", "f8h6", "f8d7", "f8e7", "f8f7",
+"f8g7", "f8h7", "f8a8", "f8b8", "f8c8", "f8d8", "f8e8", "f8g8",
+"f8h8", "g8g1", "g8a2", "g8g2", "g8b3", "g8g3", "g8c4", "g8g4",
+"g8d5", "g8g5", "g8e6", "g8f6", "g8g6", "g8h6", "g8e7", "g8f7",
+"g8g7", "g8h7", "g8a8", "g8b8", "g8c8", "g8d8", "g8e8", "g8f8",
+"g8h8", "h8a1", "h8h1", "h8b2", "h8h2", "h8c3", "h8h3", "h8d4",
+"h8h4", "h8e5", "h8h5", "h8f6", "h8g6", "h8h6", "h8f7", "h8g7",
+"h8h7", "h8a8", "h8b8", "h8c8", "h8d8", "h8e8", "h8f8", "h8g8",
+"a7a8q", "a7a8r", "a7a8b", "a7b8q", "a7b8r", "a7b8b", "b7a8q", "b7a8r",
+"b7a8b", "b7b8q", "b7b8r", "b7b8b", "b7c8q", "b7c8r", "b7c8b", "c7b8q",
+"c7b8r", "c7b8b", "c7c8q", "c7c8r", "c7c8b", "c7d8q", "c7d8r", "c7d8b",
+"d7c8q", "d7c8r", "d7c8b", "d7d8q", "d7d8r", "d7d8b", "d7e8q", "d7e8r",
+"d7e8b", "e7d8q", "e7d8r", "e7d8b", "e7e8q", "e7e8r", "e7e8b", "e7f8q",
+"e7f8r", "e7f8b", "f7e8q", "f7e8r", "f7e8b", "f7f8q", "f7f8r", "f7f8b",
+"f7g8q", "f7g8r", "f7g8b", "g7f8q", "g7f8r", "g7f8b", "g7g8q", "g7g8r",
+"g7g8b", "g7h8q", "g7h8r", "g7h8b", "h7g8q", "h7g8r", "h7g8b", "h7h8q",
+"h7h8r", "h7h8b"]
diff --git a/backend/tf_transfer/shufflebuffer.py b/backend/tf_transfer/shufflebuffer.py
new file mode 100755
index 0000000..06c369d
--- /dev/null
+++ b/backend/tf_transfer/shufflebuffer.py
@@ -0,0 +1,121 @@
+#!/usr/bin/env python3
+#
+# This file is part of Leela Chess.
+# Copyright (C) 2018 Michael O
+#
+# Leela Chess is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Leela Chess is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Leela Chess. If not, see <http://www.gnu.org/licenses/>.
+
+import random
+import unittest
+
+class ShuffleBuffer:
+ def __init__(self, elem_size, elem_count):
+ """
+ A shuffle buffer for fixed sized elements.
+
+ Manages 'elem_count' items in a fixed buffer, each item being exactly
+ 'elem_size' bytes.
+ """
+ assert elem_size > 0, elem_size
+ assert elem_count > 0, elem_count
+ # Size of each element.
+ self.elem_size = elem_size
+ # Number of elements in the buffer.
+ self.elem_count = elem_count
+ # Fixed size buffer used to hold all the elements.
+ self.buffer = bytearray(elem_size * elem_count)
+ # Number of elements actually contained in the buffer.
+ self.used = 0
+
+ def extract(self):
+ """
+ Return an item from the shuffle buffer.
+
+ If the buffer is empty, returns None
+ """
+ if self.used < 1:
+ return None
+ # The items in the shuffle buffer are held in shuffled order
+ # so returning the last item is sufficient.
+ self.used -= 1
+ i = self.used
+ return self.buffer[i * self.elem_size : (i+1) * self.elem_size]
+
+ def insert_or_replace(self, item):
+ """
+ Inserts 'item' into the shuffle buffer, returning
+ a random item.
+
+ If the buffer is not yet full, returns None
+ """
+ assert len(item) == self.elem_size, len(item)
+ # putting the new item in a random location, and appending
+ # the displaced item to the end of the buffer achieves a full
+ # random shuffle (Fisher-Yates)
+ if self.used > 0:
+ # swap 'item' with random item in buffer.
+ i = random.randint(0, self.used-1)
+ old_item = self.buffer[i * self.elem_size : (i+1) * self.elem_size]
+ self.buffer[i * self.elem_size : (i+1) * self.elem_size] = item
+ item = old_item
+ # If the buffer isn't yet full, append 'item' to the end of the buffer.
+ if self.used < self.elem_count:
+ # Not yet full, so place the returned item at the end of the buffer.
+ i = self.used
+ self.buffer[i * self.elem_size : (i+1) * self.elem_size] = item
+ self.used += 1
+ return None
+ return item
+
+
+class ShuffleBufferTest(unittest.TestCase):
+ def test_extract(self):
+ sb = ShuffleBuffer(3, 1)
+ r = sb.extract()
+ assert r == None, r # empty buffer => None
+ r = sb.insert_or_replace(b'111')
+ assert r == None, r # buffer not yet full => None
+ r = sb.extract()
+ assert r == b'111', r # one item in buffer => item
+ r = sb.extract()
+ assert r == None, r # buffer empty => None
+ def test_wrong_size(self):
+ sb = ShuffleBuffer(3, 1)
+ try:
+ sb.insert_or_replace(b'1') # wrong length, so should throw.
+ assert False # Should not be reached.
+ except:
+ pass
+ def test_insert_or_replace(self):
+ n=10 # number of test items.
+ items=[bytes([x,x,x]) for x in range(n)]
+ sb = ShuffleBuffer(elem_size=3, elem_count=2)
+ out=[]
+ for i in items:
+ r = sb.insert_or_replace(i)
+ if not r is None:
+ out.append(r)
+ # Buffer size is 2, 10 items, should be 8 seen so far.
+ assert len(out) == n - 2, len(out)
+ # Get the last two items.
+ out.append(sb.extract())
+ out.append(sb.extract())
+ assert sorted(items) == sorted(out), (items, out)
+ # Check that buffer is empty
+ r = sb.extract()
+ assert r is None, r
+
+
+if __name__ == '__main__':
+ unittest.main()
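
As a usage sketch mirroring the unit tests above: fixed-size byte records stream through `insert_or_replace`, which starts handing back uniformly shuffled records once the buffer is full, and `extract` drains what is left at the end. The import path is an assumption:

```python
# Minimal sketch of streaming fixed-size records through ShuffleBuffer.
from backend.tf_transfer.shufflebuffer import ShuffleBuffer  # assumed import path

sb = ShuffleBuffer(elem_size=4, elem_count=3)
records = [i.to_bytes(4, 'little') for i in range(10)]

shuffled = []
for rec in records:
    out = sb.insert_or_replace(rec)
    if out is not None:        # None until the buffer has filled up
        shuffled.append(out)

# Drain the elem_count records still held in the buffer.
item = sb.extract()
while item is not None:
    shuffled.append(item)
    item = sb.extract()

assert sorted(shuffled) == sorted(records)
```
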
diff --git a/backend/tf_transfer/tfprocess.py b/backend/tf_transfer/tfprocess.py
new file mode 100755
index 0000000..6352789
--- /dev/null
+++ b/backend/tf_transfer/tfprocess.py
@@ -0,0 +1,829 @@
+#!/usr/bin/env python3
+#
+# This file is part of Leela Zero.
+# Copyright (C) 2017-2018 Gian-Carlo Pascutto
+#
+# Leela Zero is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Leela Zero is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.
+
+import numpy as np
+import os
+import os.path
+import random
+import tensorflow as tf
+import time
+import bisect
+from .lc0_az_policy_map import make_map
+import backend.proto.net_pb2 as pb
+
+from .net import Net
+
+from ..utils import printWithDate
+
+import natsort
+
+
+def model_path_gen(short_path):
+ models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../models'))
+ return os.path.join(models_path, short_path, 'ckpt/ckpt-40')
+
+class ApplySqueezeExcitation(tf.keras.layers.Layer):
+ def __init__(self, **kwargs):
+ super(ApplySqueezeExcitation, self).__init__(**kwargs)
+
+ def build(self, input_dimens):
+ self.reshape_size = input_dimens[1][1]
+
+ def call(self, inputs):
+ x = inputs[0]
+ excited = inputs[1]
+ gammas, betas = tf.split(tf.reshape(excited, [-1, self.reshape_size, 1, 1]), 2, axis=1)
+ return tf.nn.sigmoid(gammas) * x + betas
+
+
+class ApplyPolicyMap(tf.keras.layers.Layer):
+ def __init__(self, **kwargs):
+ super(ApplyPolicyMap, self).__init__(**kwargs)
+ self.fc1 = tf.constant(make_map())
+
+ def call(self, inputs):
+ h_conv_pol_flat = tf.reshape(inputs, [-1, 80*8*8])
+ return tf.matmul(h_conv_pol_flat, tf.cast(self.fc1, h_conv_pol_flat.dtype))
+
+class TFProcess:
+ def __init__(self, cfg, name, collection_name):
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+ self.cfg = cfg
+ self.name = name
+ self.collection_name = collection_name
+ self.net = Net()
+ self.root_dir = os.path.join('models', self.collection_name, self.name)
+
+ # Network structure
+ self.RESIDUAL_FILTERS = self.cfg['model']['filters']
+ self.RESIDUAL_BLOCKS = self.cfg['model']['residual_blocks']
+ self.SE_ratio = self.cfg['model']['se_ratio']
+ self.policy_channels = self.cfg['model'].get('policy_channels', 32)
+ precision = self.cfg['training'].get('precision', 'single')
+ loss_scale = self.cfg['training'].get('loss_scale', 128)
+
+ if precision == 'single':
+ self.model_dtype = tf.float32
+ elif precision == 'half':
+ self.model_dtype = tf.float16
+ else:
+ raise ValueError("Unknown precision: {}".format(precision))
+
+ # Scale the loss to prevent gradient underflow
+ self.loss_scale = 1 if self.model_dtype == tf.float32 else loss_scale
+
+ self.VALUE_HEAD = None
+
+ self.POLICY_HEAD = pb.NetworkFormat.POLICY_CONVOLUTION
+
+ self.net.set_policyformat(self.POLICY_HEAD)
+
+ self.VALUE_HEAD = pb.NetworkFormat.VALUE_WDL
+ self.wdl = True
+
+
+ self.net.set_valueformat(self.VALUE_HEAD)
+
+ self.swa_enabled = self.cfg['training'].get('swa', False)
+
+ # Limit momentum of SWA exponential average to 1 - 1/(swa_max_n + 1)
+ self.swa_max_n = self.cfg['training'].get('swa_max_n', 0)
+
+ self.renorm_enabled = self.cfg['training'].get('renorm', False)
+ self.renorm_max_r = self.cfg['training'].get('renorm_max_r', 1)
+ self.renorm_max_d = self.cfg['training'].get('renorm_max_d', 0)
+ self.renorm_momentum = self.cfg['training'].get('renorm_momentum', 0.99)
+
+ gpus = tf.config.experimental.list_physical_devices('GPU')
+ tf.config.experimental.set_visible_devices(gpus[self.cfg['gpu']], 'GPU')
+ tf.config.experimental.set_memory_growth(gpus[self.cfg['gpu']], True)
+ if self.model_dtype == tf.float16:
+ tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+
+ self.global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int64)
+
+ def init_v2(self, train_dataset, test_dataset):
+ self.train_dataset = train_dataset
+ self.train_iter = iter(train_dataset)
+ self.test_dataset = test_dataset
+ self.test_iter = iter(test_dataset)
+ self.init_net_v2()
+
+ def init_net_v2(self):
+ self.l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001))
+ input_var = tf.keras.Input(shape=(112, 8*8))
+ x_planes = tf.keras.layers.Reshape([112, 8, 8])(input_var)
+
+ base_ckpt_path = model_path_gen(self.cfg['model']['path'])
+
+ self.model_maia = tf.keras.Model(inputs=input_var, outputs=self.construct_net_complete(x_planes))
+ self.checkpoint_restore = tf.train.Checkpoint(model=self.model_maia)
+ self.restore_ckpt(base_ckpt_path)
+
+ # The tf names use natural numbers with no prefixes
+ # so to index layers correctly we need to sort them naturally
+
+ natsort_key = natsort.natsort_keygen()
+
+ self.model = tf.keras.Model(
+ inputs=input_var,
+ outputs=self.construct_with_stops(
+ x_planes,
+ self.cfg['model'].get('back_prop_blocks', 3),
+ ))
+ maia_layers = sorted(
+ self.model_maia.layers,
+ key = lambda x : natsort_key(x.name),
+ )
+ model_layers = sorted(
+ [l for l in self.model.layers if 'lambda' not in l.name],
+ key = lambda x : natsort_key(x.name),
+ )
+
+ layer_map = {model_layer.name : maia_layer for model_layer, maia_layer in zip(model_layers, maia_layers)}
+
+ for i, model_layer in enumerate(self.model.layers):
+ if not self.cfg['model'].get('keep_weights', False) and self.cfg['model'].get('back_prop_blocks', 3) > self.RESIDUAL_BLOCKS + 4:
+ printWithDate(f"ending at depth {i}: {model_layer.name}")
+ break
+ if 'lambda' not in model_layer.name:
+ l_maia = layer_map[model_layer.name]
+ model_layer.set_weights([w.numpy() for w in l_maia.weights])
+ elif not self.cfg['model'].get('keep_weights', False):
+ printWithDate(f"ending at depth {i}: {model_layer.name}")
+ break
+
+ printWithDate("Setting up lc0 stuff")
+ # swa_count initialized regardless to make checkpoint code simpler.
+ self.swa_count = tf.Variable(0., name='swa_count', trainable=False)
+ self.swa_weights = None
+ if self.swa_enabled:
+ # Count of networks accumulated into SWA
+ self.swa_weights = [tf.Variable(w, trainable=False) for w in self.model.weights]
+
+ self.active_lr = 0.01
+ self.optimizer = tf.keras.optimizers.SGD(learning_rate=lambda: self.active_lr, momentum=0.9, nesterov=True)
+ self.orig_optimizer = self.optimizer
+ if self.loss_scale != 1:
+ self.optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(self.optimizer, self.loss_scale)
+ def correct_policy(target, output):
+ output = tf.cast(output, tf.float32)
+ # Calculate loss on policy head
+ if self.cfg['training'].get('mask_legal_moves'):
+ # extract mask for legal moves from target policy
+ move_is_legal = tf.greater_equal(target, 0)
+ # replace logits of illegal moves with large negative value (so that it doesn't affect policy of legal moves) without gradient
+ illegal_filler = tf.zeros_like(output) - 1.0e10
+ output = tf.where(move_is_legal, output, illegal_filler)
+ # y_ still has -1 on illegal moves, flush them to 0
+ target = tf.nn.relu(target)
+ return target, output
+ def policy_loss(target, output):
+ target, output = correct_policy(target, output)
+ policy_cross_entropy = \
+ tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(target),
+ logits=output)
+ return tf.reduce_mean(input_tensor=policy_cross_entropy)
+ self.policy_loss_fn = policy_loss
+ def policy_accuracy(target, output):
+ target, output = correct_policy(target, output)
+ return tf.reduce_mean(tf.cast(tf.equal(tf.argmax(input=target, axis=1), tf.argmax(input=output, axis=1)), tf.float32))
+ self.policy_accuracy_fn = policy_accuracy
+
+
+ q_ratio = self.cfg['training'].get('q_ratio', 0)
+ assert 0 <= q_ratio <= 1
+
+ # Linear conversion to scalar to compute MSE with, for comparison to old values
+ wdl = tf.expand_dims(tf.constant([1.0, 0.0, -1.0]), 1)
+
+ self.qMix = lambda z, q: q * q_ratio + z *(1 - q_ratio)
+ # Loss on value head
+ if self.wdl:
+ def value_loss(target, output):
+ output = tf.cast(output, tf.float32)
+ value_cross_entropy = \
+ tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(target),
+ logits=output)
+ return tf.reduce_mean(input_tensor=value_cross_entropy)
+ self.value_loss_fn = value_loss
+ def mse_loss(target, output):
+ output = tf.cast(output, tf.float32)
+ scalar_z_conv = tf.matmul(tf.nn.softmax(output), wdl)
+ scalar_target = tf.matmul(target, wdl)
+ return tf.reduce_mean(input_tensor=tf.math.squared_difference(scalar_target, scalar_z_conv))
+ self.mse_loss_fn = mse_loss
+ else:
+ def value_loss(target, output):
+ return tf.constant(0)
+ self.value_loss_fn = value_loss
+ def mse_loss(target, output):
+ output = tf.cast(output, tf.float32)
+ scalar_target = tf.matmul(target, wdl)
+ return tf.reduce_mean(input_tensor=tf.math.squared_difference(scalar_target, output))
+ self.mse_loss_fn = mse_loss
+
+ pol_loss_w = self.cfg['training']['policy_loss_weight']
+ val_loss_w = self.cfg['training']['value_loss_weight']
+ self.lossMix = lambda policy, value: pol_loss_w * policy + val_loss_w * value
+
+ def accuracy(target, output):
+ output = tf.cast(output, tf.float32)
+ return tf.reduce_mean(tf.cast(tf.equal(tf.argmax(input=target, axis=1), tf.argmax(input=output, axis=1)), tf.float32))
+ self.accuracy_fn = accuracy
+
+ self.avg_policy_loss = []
+ self.avg_value_loss = []
+ self.avg_mse_loss = []
+ self.avg_reg_term = []
+ self.time_start = None
+ self.last_steps = None
+ # Set adaptive learning rate during training
+ self.cfg['training']['lr_boundaries'].sort()
+ self.warmup_steps = self.cfg['training'].get('warmup_steps', 0)
+ self.lr = self.cfg['training']['lr_values'][0]
+ self.test_writer = tf.summary.create_file_writer(os.path.join(
+ 'runs',
+ self.collection_name,
+ self.name + '-test',
+ ))
+ self.train_writer = tf.summary.create_file_writer(os.path.join(
+ 'runs',
+ self.collection_name,
+ self.name + '-train',
+ ))
+ if self.swa_enabled:
+ self.swa_writer = tf.summary.create_file_writer(os.path.join(
+ 'runs',
+ self.collection_name,
+ self.name + '-swa-test',
+ ))
+ self.checkpoint = tf.train.Checkpoint(optimizer=self.orig_optimizer, model=self.model, global_step=self.global_step, swa_count=self.swa_count)
+ self.checkpoint.listed = self.swa_weights
+ self.manager = tf.train.CheckpointManager(
+ self.checkpoint, directory=self.root_dir, max_to_keep=50, keep_checkpoint_every_n_hours=24)
+
+ def replace_weights_v2(self, new_weights_orig):
+ new_weights = [w for w in new_weights_orig]
+ # self.model.weights ordering doesn't match up nicely, so first shuffle the new weights to match up.
+ # input order is (for convolutional policy):
+ # policy conv
+ # policy bn * 4
+ # policy raw conv and bias
+ # value conv
+ # value bn * 4
+ # value dense with bias
+ # value dense with bias
+ #
+ # output order is (for convolutional policy):
+ # value conv
+ # policy conv
+ # value bn * 4
+ # policy bn * 4
+ # policy raw conv and bias
+ # value dense with bias
+ # value dense with bias
+ new_weights[-5] = new_weights_orig[-10]
+ new_weights[-6] = new_weights_orig[-11]
+ new_weights[-7] = new_weights_orig[-12]
+ new_weights[-8] = new_weights_orig[-13]
+ new_weights[-9] = new_weights_orig[-14]
+ new_weights[-10] = new_weights_orig[-15]
+ new_weights[-11] = new_weights_orig[-5]
+ new_weights[-12] = new_weights_orig[-6]
+ new_weights[-13] = new_weights_orig[-7]
+ new_weights[-14] = new_weights_orig[-8]
+ new_weights[-15] = new_weights_orig[-16]
+ new_weights[-16] = new_weights_orig[-9]
+
+ all_evals = []
+ offset = 0
+ last_was_gamma = False
+ for e, weights in enumerate(self.model.weights):
+ source_idx = e+offset
+ if weights.shape.ndims == 4:
+ # Rescale rule50 related weights as clients do not normalize the input.
+ if e == 0:
+ num_inputs = 112
+ # 50 move rule is the 110th input, or 109 starting from 0.
+ rule50_input = 109
+ for i in range(len(new_weights[source_idx])):
+ if (i % (num_inputs*9))//9 == rule50_input:
+ new_weights[source_idx][i] = new_weights[source_idx][i]*99
+
+ # Convolution weights need a transpose
+ #
+ # TF (kYXInputOutput)
+ # [filter_height, filter_width, in_channels, out_channels]
+ #
+ # Leela/cuDNN/Caffe (kOutputInputYX)
+ # [output, input, filter_size, filter_size]
+ s = weights.shape.as_list()
+ shape = [s[i] for i in [3, 2, 0, 1]]
+ new_weight = tf.constant(new_weights[source_idx], shape=shape)
+ weights.assign(
+ tf.transpose(a=new_weight, perm=[2, 3, 1, 0]))
+ elif weights.shape.ndims == 2:
+ # Fully connected layers are [in, out] in TF
+ #
+ # [out, in] in Leela
+ #
+ s = weights.shape.as_list()
+ shape = [s[i] for i in [1, 0]]
+ new_weight = tf.constant(new_weights[source_idx], shape=shape)
+ weights.assign(
+ tf.transpose(a=new_weight, perm=[1, 0]))
+ else:
+ # Can't populate renorm weights, but the current new_weight will need using elsewhere.
+ if 'renorm' in weights.name:
+ offset-=1
+ continue
+ # betas without gammas need to skip the gamma in the input.
+ if 'beta:' in weights.name and not last_was_gamma:
+ source_idx+=1
+ offset+=1
+ # Biases, batchnorm etc
+ new_weight = tf.constant(new_weights[source_idx], shape=weights.shape)
+ if 'stddev:' in weights.name:
+ weights.assign(tf.math.sqrt(new_weight + 1e-5))
+ else:
+ weights.assign(new_weight)
+ # need to use the variance to also populate the stddev for renorm, so adjust offset.
+ if 'variance:' in weights.name and self.renorm_enabled:
+ offset-=1
+ last_was_gamma = 'gamma:' in weights.name
+ # Replace the SWA weights as well, ensuring swa accumulation is reset.
+ if self.swa_enabled:
+ self.swa_count.assign(tf.constant(0.))
+ self.update_swa_v2()
+ # This should result in identical file to the starting one
+ # self.save_leelaz_weights_v2('restored.pb.gz')
+
+ def restore_v2(self):
+ if self.manager.latest_checkpoint is not None:
+ print("Restoring from {0}".format(self.manager.latest_checkpoint))
+ self.checkpoint.restore(self.manager.latest_checkpoint)
+
+ def restore_ckpt(self, ckpt_path):
+ print("loading lower weights from {}".format(ckpt_path))
+ self.checkpoint_restore.restore(ckpt_path)
+
+ def process_loop_v2(self, batch_size, test_batches, batch_splits=1):
+ # Get the initial steps value in case this is a resume from a step count
+ # which is not a multiple of total_steps.
+ steps = self.global_step.read_value()
+ total_steps = self.cfg['training']['total_steps']
+ for _ in range(steps % total_steps, total_steps):
+ self.process_v2(batch_size, test_batches, batch_splits=batch_splits)
+
+ @tf.function()
+ def read_weights(self):
+ return [w.read_value() for w in self.model.weights]
+
+ @tf.function()
+ def process_inner_loop(self, x, y, z, q):
+ with tf.GradientTape() as tape:
+ policy, value = self.model(x, training=True)
+ policy_loss = self.policy_loss_fn(y, policy)
+ reg_term = sum(self.model.losses)
+ if self.wdl:
+ value_loss = self.value_loss_fn(self.qMix(z, q), value)
+ total_loss = self.lossMix(policy_loss, value_loss) + reg_term
+ else:
+ mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
+ total_loss = self.lossMix(policy_loss, mse_loss) + reg_term
+ if self.loss_scale != 1:
+ total_loss = self.optimizer.get_scaled_loss(total_loss)
+ if self.wdl:
+ mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
+ else:
+ value_loss = self.value_loss_fn(self.qMix(z, q), value)
+ return policy_loss, value_loss, mse_loss, reg_term, tape.gradient(total_loss, self.model.trainable_weights)
+
+ def process_v2(self, batch_size, test_batches, batch_splits=1):
+ if not self.time_start:
+ self.time_start = time.time()
+
+ # Get the initial steps value before we do a training step.
+ steps = self.global_step.read_value()
+ if not self.last_steps:
+ self.last_steps = steps
+
+ if self.swa_enabled:
+ # split half of test_batches between testing regular weights and SWA weights
+ test_batches //= 2
+ # Run test before first step to see delta since end of last run.
+ if steps % self.cfg['training']['total_steps'] == 0:
+ # Steps is given as one higher than current in order to avoid it
+ # being equal to the value the end of a run is stored against.
+ self.calculate_test_summaries_v2(test_batches, steps + 1)
+ if self.swa_enabled:
+ self.calculate_swa_summaries_v2(test_batches, steps + 1)
+
+ # Make sure that ghost batch norm can be applied
+ if batch_size % 64 != 0:
+ # Adjust required batch size for batch splitting.
+ required_factor = 64 * \
+ self.cfg['training'].get('num_batch_splits', 1)
+ raise ValueError(
+ 'batch_size must be a multiple of {}'.format(required_factor))
+
+ # Determine learning rate
+ lr_values = self.cfg['training']['lr_values']
+ lr_boundaries = self.cfg['training']['lr_boundaries']
+ steps_total = steps % self.cfg['training']['total_steps']
+ self.lr = lr_values[bisect.bisect_right(lr_boundaries, steps_total)]
+ if self.warmup_steps > 0 and steps < self.warmup_steps:
+ self.lr = self.lr * tf.cast(steps + 1, tf.float32) / self.warmup_steps
+
+ # need to add 1 to steps because steps will be incremented after gradient update
+ if (steps + 1) % self.cfg['training']['train_avg_report_steps'] == 0 or (steps + 1) % self.cfg['training']['total_steps'] == 0:
+ before_weights = self.read_weights()
+
+ x, y, z, q = next(self.train_iter)
+ policy_loss, value_loss, mse_loss, reg_term, grads = self.process_inner_loop(x, y, z, q)
+ # Keep running averages
+ # Google's paper scales MSE by 1/4 to a [0, 1] range, so do the same to
+ # get comparable values.
+ mse_loss /= 4.0
+ self.avg_policy_loss.append(policy_loss)
+ if self.wdl:
+ self.avg_value_loss.append(value_loss)
+ self.avg_mse_loss.append(mse_loss)
+ self.avg_reg_term.append(reg_term)
+
+ # Gradients of batch splits are summed, not averaged like usual, so need to scale lr accordingly to correct for this.
+ self.active_lr = self.lr / batch_splits
+ if self.loss_scale != 1:
+ grads = self.optimizer.get_unscaled_gradients(grads)
+ max_grad_norm = self.cfg['training'].get('max_grad_norm', 10000.0) * batch_splits
+ grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
+ self.optimizer.apply_gradients([
+ (grad, var)
+ for (grad, var) in zip(grads, self.model.trainable_variables)
+ if grad is not None
+ ])
+
+ # Update steps.
+ self.global_step.assign_add(1)
+ steps = self.global_step.read_value()
+
+ if steps % self.cfg['training']['train_avg_report_steps'] == 0 or steps % self.cfg['training']['total_steps'] == 0:
+ pol_loss_w = self.cfg['training']['policy_loss_weight']
+ val_loss_w = self.cfg['training']['value_loss_weight']
+ time_end = time.time()
+ speed = 0
+ if self.time_start:
+ elapsed = time_end - self.time_start
+ steps_elapsed = steps - self.last_steps
+ speed = batch_size * (tf.cast(steps_elapsed, tf.float32) / elapsed)
+ avg_policy_loss = np.mean(self.avg_policy_loss or [0])
+ avg_value_loss = np.mean(self.avg_value_loss or [0])
+ avg_mse_loss = np.mean(self.avg_mse_loss or [0])
+ avg_reg_term = np.mean(self.avg_reg_term or [0])
+ printWithDate("step {}, lr={:g} policy={:g} value={:g} mse={:g} reg={:g} total={:g} ({:g} pos/s)".format(
+ steps, self.lr, avg_policy_loss, avg_value_loss, avg_mse_loss, avg_reg_term,
+ pol_loss_w * avg_policy_loss + val_loss_w * avg_value_loss + avg_reg_term,
+ speed))
+
+ after_weights = self.read_weights()
+ with self.train_writer.as_default():
+ tf.summary.scalar("Policy Loss", avg_policy_loss, step=steps)
+ tf.summary.scalar("Value Loss", avg_value_loss, step=steps)
+ tf.summary.scalar("Reg term", avg_reg_term, step=steps)
+ tf.summary.scalar("LR", self.lr, step=steps)
+ tf.summary.scalar("Gradient norm", grad_norm / batch_splits, step=steps)
+ tf.summary.scalar("MSE Loss", avg_mse_loss, step=steps)
+ self.compute_update_ratio_v2(
+ before_weights, after_weights, steps)
+ self.train_writer.flush()
+ self.time_start = time_end
+ self.last_steps = steps
+ self.avg_policy_loss, self.avg_value_loss, self.avg_mse_loss, self.avg_reg_term = [], [], [], []
+
+ if self.swa_enabled and steps % self.cfg['training']['swa_steps'] == 0:
+ self.update_swa_v2()
+
+ # Calculate test values every 'test_steps', but also ensure there is
+ # one at the final step so the delta to the first step can be calculated.
+ if steps % self.cfg['training']['test_steps'] == 0 or steps % self.cfg['training']['total_steps'] == 0:
+ self.calculate_test_summaries_v2(test_batches, steps)
+ if self.swa_enabled:
+ self.calculate_swa_summaries_v2(test_batches, steps)
+
+ # Save session and weights at end, and also optionally every 'checkpoint_steps'.
+ if steps % self.cfg['training']['total_steps'] == 0 or (
+ 'checkpoint_steps' in self.cfg['training'] and steps % self.cfg['training']['checkpoint_steps'] == 0):
+ self.manager.save()
+ print("Model saved in file: {}".format(self.manager.latest_checkpoint))
+ evaled_steps = steps.numpy()
+ leela_path = self.manager.latest_checkpoint + "-" + str(evaled_steps)
+ swa_path = self.manager.latest_checkpoint + "-swa-" + str(evaled_steps)
+ self.net.pb.training_params.training_steps = evaled_steps
+ self.save_leelaz_weights_v2(leela_path)
+ print("Weights saved in file: {}".format(leela_path))
+ if self.swa_enabled:
+ self.save_swa_weights_v2(swa_path)
+ print("SWA Weights saved in file: {}".format(swa_path))
+
+ def calculate_swa_summaries_v2(self, test_batches, steps):
+ backup = self.read_weights()
+ for (swa, w) in zip(self.swa_weights, self.model.weights):
+ w.assign(swa.read_value())
+ true_test_writer, self.test_writer = self.test_writer, self.swa_writer
+ print('swa', end=' ')
+ self.calculate_test_summaries_v2(test_batches, steps)
+ self.test_writer = true_test_writer
+ for (old, w) in zip(backup, self.model.weights):
+ w.assign(old)
+
+ @tf.function()
+ def calculate_test_summaries_inner_loop(self, x, y, z, q):
+ policy, value = self.model(x, training=False)
+ policy_loss = self.policy_loss_fn(y, policy)
+ policy_accuracy = self.policy_accuracy_fn(y, policy)
+ if self.wdl:
+ value_loss = self.value_loss_fn(self.qMix(z, q), value)
+ mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
+ value_accuracy = self.accuracy_fn(self.qMix(z,q), value)
+ else:
+ value_loss = self.value_loss_fn(self.qMix(z, q), value)
+ mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
+ value_accuracy = tf.constant(0.)
+ return policy_loss, value_loss, mse_loss, policy_accuracy, value_accuracy
+
+ def calculate_test_summaries_v2(self, test_batches, steps):
+ sum_policy_accuracy = 0
+ sum_value_accuracy = 0
+ sum_mse = 0
+ sum_policy = 0
+ sum_value = 0
+ for _ in range(0, test_batches):
+ x, y, z, q = next(self.test_iter)
+ policy_loss, value_loss, mse_loss, policy_accuracy, value_accuracy = self.calculate_test_summaries_inner_loop(x, y, z, q)
+ sum_policy_accuracy += policy_accuracy
+ sum_mse += mse_loss
+ sum_policy += policy_loss
+ if self.wdl:
+ sum_value_accuracy += value_accuracy
+ sum_value += value_loss
+ sum_policy_accuracy /= test_batches
+ sum_policy_accuracy *= 100
+ sum_policy /= test_batches
+ sum_value /= test_batches
+ if self.wdl:
+ sum_value_accuracy /= test_batches
+ sum_value_accuracy *= 100
+ # Additionally rescale to [0, 1] so divide by 4
+ sum_mse /= (4.0 * test_batches)
+ self.net.pb.training_params.learning_rate = self.lr
+ self.net.pb.training_params.mse_loss = sum_mse
+ self.net.pb.training_params.policy_loss = sum_policy
+ # TODO store value and value accuracy in pb
+ self.net.pb.training_params.accuracy = sum_policy_accuracy
+ with self.test_writer.as_default():
+ tf.summary.scalar("Policy Loss", sum_policy, step=steps)
+ tf.summary.scalar("Value Loss", sum_value, step=steps)
+ tf.summary.scalar("MSE Loss", sum_mse, step=steps)
+ tf.summary.scalar("Policy Accuracy", sum_policy_accuracy, step=steps)
+ if self.wdl:
+ tf.summary.scalar("Value Accuracy", sum_value_accuracy, step=steps)
+ for w in self.model.weights:
+ tf.summary.histogram(w.name, w, buckets=1000, step=steps)
+ self.test_writer.flush()
+
+ printWithDate("step {}, policy={:g} value={:g} policy accuracy={:g}% value accuracy={:g}% mse={:g}".\
+ format(steps, sum_policy, sum_value, sum_policy_accuracy, sum_value_accuracy, sum_mse))
+
+ @tf.function()
+ def compute_update_ratio_v2(self, before_weights, after_weights, steps):
+ """Compute the ratio of gradient norm to weight norm.
+
+ Adapted from https://github.com/tensorflow/minigo/blob/c923cd5b11f7d417c9541ad61414bf175a84dc31/dual_net.py#L567
+ """
+ deltas = [after - before for after,
+ before in zip(after_weights, before_weights)]
+ delta_norms = [tf.math.reduce_euclidean_norm(d) for d in deltas]
+ weight_norms = [tf.math.reduce_euclidean_norm(w) for w in before_weights]
+ ratios = [(tensor.name, tf.cond(w != 0., lambda: d / w, lambda: -1.)) for d, w, tensor in zip(delta_norms, weight_norms, self.model.weights) if not 'moving' in tensor.name]
+ for name, ratio in ratios:
+ tf.summary.scalar('update_ratios/' + name, ratio, step=steps)
+ # Filtering is hard, so just push infinities/NaNs to an unreasonably large value.
+ ratios = [tf.cond(r > 0, lambda: tf.math.log(r) / 2.30258509299, lambda: 200.) for (_, r) in ratios]
+ tf.summary.histogram('update_ratios_log10', tf.stack(ratios), buckets=1000, step=steps)
+
+ def update_swa_v2(self):
+ num = self.swa_count.read_value()
+ for (w, swa) in zip(self.model.weights, self.swa_weights):
+ swa.assign(swa.read_value() * (num / (num + 1.)) + w.read_value() * (1. / (num + 1.)))
+ self.swa_count.assign(min(num + 1., self.swa_max_n))
+
+ def save_swa_weights_v2(self, filename):
+ backup = self.read_weights()
+ for (swa, w) in zip(self.swa_weights, self.model.weights):
+ w.assign(swa.read_value())
+ self.save_leelaz_weights_v2(filename)
+ for (old, w) in zip(backup, self.model.weights):
+ w.assign(old)
+
+ def save_leelaz_weights_v2(self, filename):
+ all_tensors = []
+ all_weights = []
+ last_was_gamma = False
+ for weights in self.model.weights:
+ work_weights = None
+ if weights.shape.ndims == 4:
+ # Convolution weights need a transpose
+ #
+ # TF (kYXInputOutput)
+ # [filter_height, filter_width, in_channels, out_channels]
+ #
+ # Leela/cuDNN/Caffe (kOutputInputYX)
+ # [output, input, filter_size, filter_size]
+ work_weights = tf.transpose(a=weights, perm=[3, 2, 0, 1])
+ elif weights.shape.ndims == 2:
+ # Fully connected layers are [in, out] in TF
+ #
+ # [out, in] in Leela
+ #
+ work_weights = tf.transpose(a=weights, perm=[1, 0])
+ else:
+ # batch renorm has extra weights, but we don't know what to do with them.
+ if 'renorm' in weights.name:
+ continue
+ # renorm has variance, but it is not the primary source of truth
+ if 'variance:' in weights.name and self.renorm_enabled:
+ continue
+ # Renorm has moving stddev not variance, undo the transform to make it compatible.
+ if 'stddev:' in weights.name:
+ all_tensors.append(tf.math.square(weights) - 1e-5)
+ continue
+ # Biases, batchnorm etc
+ # pb expects every batch norm to have gammas, but not all of our
+ # batch norms have gammas, so manually add pretend gammas.
+ if 'beta:' in weights.name and not last_was_gamma:
+ all_tensors.append(tf.ones_like(weights))
+ work_weights = weights.read_value()
+ all_tensors.append(work_weights)
+ last_was_gamma = 'gamma:' in weights.name
+
+ # HACK: model weights ordering is some kind of breadth first traversal,
+ # but pb expects a specific ordering which BFT is not a match for once
+ # we get to the heads. Apply manual permutation.
+ # This is fragile and at minimum should have some checks to ensure it isn't breaking things.
+ #TODO: also support classic policy head as it has a different set of layers and hence changes the permutation.
+ permuted_tensors = [w for w in all_tensors]
+ permuted_tensors[-5] = all_tensors[-11]
+ permuted_tensors[-6] = all_tensors[-12]
+ permuted_tensors[-7] = all_tensors[-13]
+ permuted_tensors[-8] = all_tensors[-14]
+ permuted_tensors[-9] = all_tensors[-16]
+ permuted_tensors[-10] = all_tensors[-5]
+ permuted_tensors[-11] = all_tensors[-6]
+ permuted_tensors[-12] = all_tensors[-7]
+ permuted_tensors[-13] = all_tensors[-8]
+ permuted_tensors[-14] = all_tensors[-9]
+ permuted_tensors[-15] = all_tensors[-10]
+ permuted_tensors[-16] = all_tensors[-15]
+ all_tensors = permuted_tensors
+
+ for e, nparray in enumerate(all_tensors):
+ # Rescale rule50 related weights as clients do not normalize the input.
+ if e == 0:
+ num_inputs = 112
+ # 50 move rule is the 110th input, or 109 starting from 0.
+ rule50_input = 109
+ wt_flt = []
+ for i, weight in enumerate(np.ravel(nparray)):
+ if (i % (num_inputs*9))//9 == rule50_input:
+ wt_flt.append(weight/99)
+ else:
+ wt_flt.append(weight)
+ else:
+ wt_flt = [wt for wt in np.ravel(nparray)]
+ all_weights.append(wt_flt)
+
+ self.net.fill_net(all_weights)
+ self.net.save_proto(filename)
+
+ def batch_norm_v2(self, input, scale=False):
+ if self.renorm_enabled:
+ clipping = {
+ "rmin": 1.0/self.renorm_max_r,
+ "rmax": self.renorm_max_r,
+ "dmax": self.renorm_max_d
+ }
+ return tf.keras.layers.BatchNormalization(
+ epsilon=1e-5, axis=1, fused=False, center=True,
+ scale=scale, renorm=True, renorm_clipping=clipping,
+ renorm_momentum=self.renorm_momentum)(input)
+ else:
+ return tf.keras.layers.BatchNormalization(
+ epsilon=1e-5, axis=1, fused=False, center=True,
+ scale=scale, virtual_batch_size=64)(input)
+
+ def squeeze_excitation_v2(self, inputs, channels):
+ assert channels % self.SE_ratio == 0
+
+ pooled = tf.keras.layers.GlobalAveragePooling2D(data_format='channels_first')(inputs)
+ squeezed = tf.keras.layers.Activation('relu')(tf.keras.layers.Dense(channels // self.SE_ratio, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg)(pooled))
+ excited = tf.keras.layers.Dense(2 * channels, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg)(squeezed)
+ return ApplySqueezeExcitation()([inputs, excited])
+
+ def conv_block_v2(self, inputs, filter_size, output_channels, bn_scale=False):
+ conv = tf.keras.layers.Conv2D(output_channels, filter_size, use_bias=False, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, data_format='channels_first')(inputs)
+ return tf.keras.layers.Activation('relu')(self.batch_norm_v2(conv, scale=bn_scale))
+
+ def residual_block_v2(self, inputs, channels):
+ conv1 = tf.keras.layers.Conv2D(channels, 3, use_bias=False, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, data_format='channels_first')(inputs)
+ out1 = tf.keras.layers.Activation('relu')(self.batch_norm_v2(conv1, scale=False))
+ conv2 = tf.keras.layers.Conv2D(channels, 3, use_bias=False, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, data_format='channels_first')(out1)
+ out2 = self.squeeze_excitation_v2(self.batch_norm_v2(conv2, scale=True), channels)
+ return tf.keras.layers.Activation('relu')(tf.keras.layers.add([inputs, out2]))
+
+ def construct_net_complete(self, inputs):
+ flow = self.conv_block_v2(inputs, filter_size=3, output_channels=self.RESIDUAL_FILTERS, bn_scale=True)
+
+ for _ in range(0, self.RESIDUAL_BLOCKS):
+ flow = self.residual_block_v2(flow, self.RESIDUAL_FILTERS)
+
+ # Policy head
+ conv_pol = self.conv_block_v2(flow, filter_size=3, output_channels=self.RESIDUAL_FILTERS)
+ conv_pol2 = tf.keras.layers.Conv2D(80, 3, use_bias=True, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg, data_format='channels_first')(conv_pol)
+ h_fc1 = ApplyPolicyMap()(conv_pol2)
+
+ # Value head
+ conv_val = self.conv_block_v2(flow, filter_size=1, output_channels=32)
+ h_conv_val_flat = tf.keras.layers.Flatten()(conv_val)
+ h_fc2 = tf.keras.layers.Dense(128, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation='relu')(h_conv_val_flat)
+ h_fc3 = tf.keras.layers.Dense(3, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg)(h_fc2)
+ return h_fc1, h_fc3
+
+ def construct_with_stops(self, inputs, num_sections, from_top = True):
+
+ tot_sections = self.RESIDUAL_BLOCKS + 4
+ if from_top:
+ num_sections = tot_sections - num_sections
+ # a num_sections value of 0 or less means no stop_gradient is inserted
+ section_count = 1
+
+ flow_p = self.conv_block_v2(inputs, filter_size=3, output_channels=self.RESIDUAL_FILTERS, bn_scale=True)
+ if section_count == num_sections:
+ flow_p = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(flow_p)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ for _ in range(0, self.RESIDUAL_BLOCKS):
+
+ section_count += 1
+ flow_p = self.residual_block_v2(flow_p, self.RESIDUAL_FILTERS)
+ if section_count == num_sections:
+ flow_p = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(flow_p)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ section_count += 1
+ conv_pol = self.conv_block_v2(flow_p, filter_size=3, output_channels=self.RESIDUAL_FILTERS)
+ if section_count == num_sections:
+ conv_pol = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(conv_pol)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ section_count += 1
+ conv_pol2 = tf.keras.layers.Conv2D(80, 3, use_bias=True, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg, data_format='channels_first')(conv_pol)
+ if section_count == num_sections:
+ conv_pol2 = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(conv_pol2)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ section_count += 1
+ h_fc1 = ApplyPolicyMap()(conv_pol2)
+ if section_count == num_sections:
+ h_fc1 = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(h_fc1)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ if section_count != tot_sections:
+ raise RuntimeError(f"Number of sections was calculated to be {tot_sections}, but is actually {section_count}")
+
+ # Value head
+ flow_v = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(flow_p)
+ conv_val = self.conv_block_v2(flow_v, filter_size=1, output_channels=32)
+ h_conv_val_flat = tf.keras.layers.Flatten()(conv_val)
+ h_fc2 = tf.keras.layers.Dense(128, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation='relu')(h_conv_val_flat)
+ h_fc3 = tf.keras.layers.Dense(3, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg)(h_fc2)
+
+
+ return h_fc1, h_fc3
+
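
Gathering up the configuration keys this `TFProcess` reads, a rough sketch of the expected structure, again as the Python dict `yaml.safe_load` would yield. Key names are taken from the accesses in the class; all values are placeholders of mine, and several optional keys (renorm, SWA, q_ratio, move masking) are omitted:

```python
# Hypothetical config for TFProcess(cfg, name, collection_name); values are illustrative.
cfg = {
    'gpu': 0,                          # index into the visible GPU list
    'model': {
        'filters': 64,
        'residual_blocks': 6,
        'se_ratio': 8,
        'path': 'maia_base',           # passed to model_path_gen() for the base checkpoint
        'back_prop_blocks': 3,         # sections, counted from the top, that keep gradients
        'keep_weights': False,         # if True, keep copying base weights past the stop_gradient layer
    },
    'training': {
        'precision': 'single',
        'policy_loss_weight': 1.0,
        'value_loss_weight': 1.0,
        'lr_values': [0.01, 0.001],
        'lr_boundaries': [80000],
        'warmup_steps': 0,
        'total_steps': 100000,
        'train_avg_report_steps': 50,
        'test_steps': 1000,
        'checkpoint_steps': 10000,
        'swa': False,
    },
}
```
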
diff --git a/backend/tf_transfer/tfprocess_reg_lr_noise.py b/backend/tf_transfer/tfprocess_reg_lr_noise.py
new file mode 100755
index 0000000..71cfcdb
--- /dev/null
+++ b/backend/tf_transfer/tfprocess_reg_lr_noise.py
@@ -0,0 +1,923 @@
+#!/usr/bin/env python3
+#
+# This file is part of Leela Zero.
+# Copyright (C) 2017-2018 Gian-Carlo Pascutto
+#
+# Leela Zero is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# Leela Zero is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with Leela Zero. If not, see <http://www.gnu.org/licenses/>.
+
+import numpy as np
+import os
+import os.path
+import random
+import tensorflow as tf
+import time
+import bisect
+from .lc0_az_policy_map import make_map
+import maia.proto.net_pb2 as pb
+
+from .net import Net
+
+from ..utils import printWithDate
+
+import natsort
+
+
+def model_path_gen(short_path):
+ models_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../models'))
+ return os.path.join(models_path, short_path, 'ckpt/ckpt-40')
+
+class ApplySqueezeExcitation(tf.keras.layers.Layer):
+ def __init__(self, **kwargs):
+ super(ApplySqueezeExcitation, self).__init__(**kwargs)
+
+ def build(self, input_dimens):
+ self.reshape_size = input_dimens[1][1]
+
+ def call(self, inputs):
+ x = inputs[0]
+ excited = inputs[1]
+ gammas, betas = tf.split(tf.reshape(excited, [-1, self.reshape_size, 1, 1]), 2, axis=1)
+ return tf.nn.sigmoid(gammas) * x + betas
+
+
+class ApplyPolicyMap(tf.keras.layers.Layer):
+ def __init__(self, **kwargs):
+ super(ApplyPolicyMap, self).__init__(**kwargs)
+ self.fc1 = tf.constant(make_map())
+
+ def call(self, inputs):
+ h_conv_pol_flat = tf.reshape(inputs, [-1, 80*8*8])
+ return tf.matmul(h_conv_pol_flat, tf.cast(self.fc1, h_conv_pol_flat.dtype))
+
+class TFProcess:
+ def __init__(self, cfg, name, collection_name):
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+ self.cfg = cfg
+ self.name = name
+ self.collection_name = collection_name
+ self.net = Net()
+ self.root_dir = os.path.join('models', self.collection_name, self.name)
+
+ # Network structure
+ self.RESIDUAL_FILTERS = self.cfg['model']['filters']
+ self.RESIDUAL_BLOCKS = self.cfg['model']['residual_blocks']
+ self.SE_ratio = self.cfg['model']['se_ratio']
+ self.policy_channels = self.cfg['model'].get('policy_channels', 32)
+ precision = self.cfg['training'].get('precision', 'single')
+ loss_scale = self.cfg['training'].get('loss_scale', 128)
+
+ if precision == 'single':
+ self.model_dtype = tf.float32
+ elif precision == 'half':
+ self.model_dtype = tf.float16
+ else:
+ raise ValueError("Unknown precision: {}".format(precision))
+
+ # Scale the loss to prevent gradient underflow
+ self.loss_scale = 1 if self.model_dtype == tf.float32 else loss_scale
+
+ self.VALUE_HEAD = None
+
+ self.POLICY_HEAD = pb.NetworkFormat.POLICY_CONVOLUTION
+
+ self.net.set_policyformat(self.POLICY_HEAD)
+
+ self.VALUE_HEAD = pb.NetworkFormat.VALUE_WDL
+ self.wdl = True
+
+
+ self.net.set_valueformat(self.VALUE_HEAD)
+
+ self.swa_enabled = self.cfg['training'].get('swa', False)
+
+ # Limit momentum of SWA exponential average to 1 - 1/(swa_max_n + 1)
+ self.swa_max_n = self.cfg['training'].get('swa_max_n', 0)
+
+ self.renorm_enabled = self.cfg['training'].get('renorm', False)
+ self.renorm_max_r = self.cfg['training'].get('renorm_max_r', 1)
+ self.renorm_max_d = self.cfg['training'].get('renorm_max_d', 0)
+ self.renorm_momentum = self.cfg['training'].get('renorm_momentum', 0.99)
+
+ gpus = tf.config.experimental.list_physical_devices('GPU')
+ tf.config.experimental.set_visible_devices(gpus[self.cfg['gpu']], 'GPU')
+ tf.config.experimental.set_memory_growth(gpus[self.cfg['gpu']], True)
+ if self.model_dtype == tf.float16:
+ tf.keras.mixed_precision.experimental.set_policy('mixed_float16')
+
+ self.global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int64)
+
+ def init_v2(self, train_dataset, test_dataset):
+ self.train_dataset = train_dataset
+ self.train_iter = iter(train_dataset)
+ self.test_dataset = test_dataset
+ self.test_iter = iter(test_dataset)
+ self.init_net_v2()
+
+ def init_net_v2(self):
+ self.l2reg = tf.keras.regularizers.l2(l=0.5 * (0.0001))
+ input_var = tf.keras.Input(shape=(112, 8*8))
+ x_planes = tf.keras.layers.Reshape([112, 8, 8])(input_var)
+
+ base_ckpt_path = model_path_gen(self.cfg['model']['path'])
+
+ self.model_maia = tf.keras.Model(inputs=input_var, outputs=self.construct_net_complete(x_planes))
+ self.checkpoint_restore = tf.train.Checkpoint(model=self.model_maia)
+ self.restore_ckpt(base_ckpt_path)
+
+ # The tf names use natural numbers with no prefixes
+        # so to index layers correctly we need to sort them naturally
+
+ natsort_key = natsort.natsort_keygen()
+
+ self.model = tf.keras.Model(
+ inputs=input_var,
+ outputs=self.construct_with_stops(
+ x_planes,
+ self.cfg['model'].get('back_prop_blocks', 3),
+ ))
+ maia_layers = sorted(
+ self.model_maia.layers,
+ key = lambda x : natsort_key(x.name),
+ )
+ model_layers = sorted(
+ [l for l in self.model.layers if 'lambda' not in l.name],
+ key = lambda x : natsort_key(x.name),
+ )
+
+ layer_map = {model_layer.name : maia_layer for model_layer, maia_layer in zip(model_layers, maia_layers)}
+
+ for i, model_layer in enumerate(self.model.layers):
+ if not self.cfg['model'].get('keep_weights', False) and self.cfg['model'].get('back_prop_blocks', 3) > self.RESIDUAL_BLOCKS + 4:
+ printWithDate(f"ending at depth {i}: {model_layer.name}")
+ break
+ # modify pretrained weights with gaussian noise
+ if 'lambda' not in model_layer.name:
+ # l_maia = layer_map[model_layer.name]
+ # model_layer.set_weights([w.numpy() for w in l_maia.weights])
+
+ multiplier = 0.01
+ l_maia = layer_map[model_layer.name]
+ new_weights = []
+ for w in l_maia.weights:
+ layer_weight = w.numpy()
+ noise = np.random.normal(loc=0, scale=multiplier * np.std(layer_weight), size=layer_weight.shape)
+ layer_weight = layer_weight + noise
+ new_weights.append(layer_weight)
+
+ model_layer.set_weights(new_weights)
+
+ elif not self.cfg['model'].get('keep_weights', False):
+ printWithDate(f"ending at depth {i}: {model_layer.name}")
+ break
+
+ printWithDate("Setting up lc0 stuff")
+        # swa_count initialized regardless to make the checkpoint code simpler.
+ self.swa_count = tf.Variable(0., name='swa_count', trainable=False)
+ self.swa_weights = None
+ if self.swa_enabled:
+ # Count of networks accumulated into SWA
+ self.swa_weights = [tf.Variable(w, trainable=False) for w in self.model.weights]
+
+ self.active_lr = 0.01
+ self.optimizer = tf.keras.optimizers.SGD(learning_rate=lambda: self.active_lr, momentum=0.9, nesterov=True)
+ self.orig_optimizer = self.optimizer
+ if self.loss_scale != 1:
+ self.optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(self.optimizer, self.loss_scale)
+ def correct_policy(target, output):
+ output = tf.cast(output, tf.float32)
+ # Calculate loss on policy head
+ if self.cfg['training'].get('mask_legal_moves'):
+ # extract mask for legal moves from target policy
+ move_is_legal = tf.greater_equal(target, 0)
+ # replace logits of illegal moves with large negative value (so that it doesn't affect policy of legal moves) without gradient
+ illegal_filler = tf.zeros_like(output) - 1.0e10
+ output = tf.where(move_is_legal, output, illegal_filler)
+ # y_ still has -1 on illegal moves, flush them to 0
+ target = tf.nn.relu(target)
+ return target, output
+ def policy_loss(target, output):
+ target, output = correct_policy(target, output)
+ policy_cross_entropy = \
+ tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(target),
+ logits=output)
+ return tf.reduce_mean(input_tensor=policy_cross_entropy)
+ self.policy_loss_fn = policy_loss
+ def policy_accuracy(target, output):
+ target, output = correct_policy(target, output)
+ return tf.reduce_mean(tf.cast(tf.equal(tf.argmax(input=target, axis=1), tf.argmax(input=output, axis=1)), tf.float32))
+ self.policy_accuracy_fn = policy_accuracy
+
+
+ q_ratio = self.cfg['training'].get('q_ratio', 0)
+ assert 0 <= q_ratio <= 1
+
+ # Linear conversion to scalar to compute MSE with, for comparison to old values
+ wdl = tf.expand_dims(tf.constant([1.0, 0.0, -1.0]), 1)
+
+ self.qMix = lambda z, q: q * q_ratio + z *(1 - q_ratio)
+ # Loss on value head
+ if self.wdl:
+ def value_loss(target, output):
+ output = tf.cast(output, tf.float32)
+ value_cross_entropy = \
+ tf.nn.softmax_cross_entropy_with_logits(labels=tf.stop_gradient(target),
+ logits=output)
+ return tf.reduce_mean(input_tensor=value_cross_entropy)
+ self.value_loss_fn = value_loss
+ def mse_loss(target, output):
+ output = tf.cast(output, tf.float32)
+ scalar_z_conv = tf.matmul(tf.nn.softmax(output), wdl)
+ scalar_target = tf.matmul(target, wdl)
+ return tf.reduce_mean(input_tensor=tf.math.squared_difference(scalar_target, scalar_z_conv))
+ self.mse_loss_fn = mse_loss
+ else:
+ def value_loss(target, output):
+ return tf.constant(0)
+ self.value_loss_fn = value_loss
+ def mse_loss(target, output):
+ output = tf.cast(output, tf.float32)
+ scalar_target = tf.matmul(target, wdl)
+ return tf.reduce_mean(input_tensor=tf.math.squared_difference(scalar_target, output))
+ self.mse_loss_fn = mse_loss
+
+ pol_loss_w = self.cfg['training']['policy_loss_weight']
+ val_loss_w = self.cfg['training']['value_loss_weight']
+ self.lossMix = lambda policy, value: pol_loss_w * policy + val_loss_w * value
+
+ def accuracy(target, output):
+ output = tf.cast(output, tf.float32)
+ return tf.reduce_mean(tf.cast(tf.equal(tf.argmax(input=target, axis=1), tf.argmax(input=output, axis=1)), tf.float32))
+ self.accuracy_fn = accuracy
+
+ self.avg_policy_loss = []
+ self.avg_value_loss = []
+ self.avg_mse_loss = []
+ self.avg_reg_term = []
+ self.time_start = None
+ self.last_steps = None
+ # Set adaptive learning rate during training
+ self.cfg['training']['lr_boundaries'].sort()
+ self.warmup_steps = self.cfg['training'].get('warmup_steps', 0)
+ self.lr = self.cfg['training']['lr_values'][0]
+ self.test_writer = tf.summary.create_file_writer(os.path.join(
+ 'runs',
+ self.collection_name,
+ self.name + '-test',
+ ))
+ self.train_writer = tf.summary.create_file_writer(os.path.join(
+ 'runs',
+ self.collection_name,
+ self.name + '-train',
+ ))
+ if self.swa_enabled:
+ self.swa_writer = tf.summary.create_file_writer(os.path.join(
+ 'runs',
+ self.collection_name,
+ self.name + '-swa-test',
+ ))
+ self.checkpoint = tf.train.Checkpoint(optimizer=self.orig_optimizer, model=self.model, global_step=self.global_step, swa_count=self.swa_count)
+ self.checkpoint.listed = self.swa_weights
+ self.manager = tf.train.CheckpointManager(
+ self.checkpoint, directory=self.root_dir, max_to_keep=50, keep_checkpoint_every_n_hours=24)
+
+ def replace_weights_v2(self, new_weights_orig):
+ new_weights = [w for w in new_weights_orig]
+ # self.model.weights ordering doesn't match up nicely, so first shuffle the new weights to match up.
+ # input order is (for convolutional policy):
+ # policy conv
+ # policy bn * 4
+ # policy raw conv and bias
+ # value conv
+ # value bn * 4
+ # value dense with bias
+ # value dense with bias
+ #
+ # output order is (for convolutional policy):
+ # value conv
+ # policy conv
+ # value bn * 4
+ # policy bn * 4
+ # policy raw conv and bias
+ # value dense with bias
+ # value dense with bias
+ new_weights[-5] = new_weights_orig[-10]
+ new_weights[-6] = new_weights_orig[-11]
+ new_weights[-7] = new_weights_orig[-12]
+ new_weights[-8] = new_weights_orig[-13]
+ new_weights[-9] = new_weights_orig[-14]
+ new_weights[-10] = new_weights_orig[-15]
+ new_weights[-11] = new_weights_orig[-5]
+ new_weights[-12] = new_weights_orig[-6]
+ new_weights[-13] = new_weights_orig[-7]
+ new_weights[-14] = new_weights_orig[-8]
+ new_weights[-15] = new_weights_orig[-16]
+ new_weights[-16] = new_weights_orig[-9]
+
+ all_evals = []
+ offset = 0
+ last_was_gamma = False
+ for e, weights in enumerate(self.model.weights):
+ source_idx = e+offset
+ if weights.shape.ndims == 4:
+ # Rescale rule50 related weights as clients do not normalize the input.
+ if e == 0:
+ num_inputs = 112
+ # 50 move rule is the 110th input, or 109 starting from 0.
+ rule50_input = 109
+ for i in range(len(new_weights[source_idx])):
+ if (i % (num_inputs*9))//9 == rule50_input:
+ new_weights[source_idx][i] = new_weights[source_idx][i]*99
+
+ # Convolution weights need a transpose
+ #
+ # TF (kYXInputOutput)
+ # [filter_height, filter_width, in_channels, out_channels]
+ #
+ # Leela/cuDNN/Caffe (kOutputInputYX)
+ # [output, input, filter_size, filter_size]
+ s = weights.shape.as_list()
+ shape = [s[i] for i in [3, 2, 0, 1]]
+ new_weight = tf.constant(new_weights[source_idx], shape=shape)
+ weights.assign(
+ tf.transpose(a=new_weight, perm=[2, 3, 1, 0]))
+ elif weights.shape.ndims == 2:
+ # Fully connected layers are [in, out] in TF
+ #
+ # [out, in] in Leela
+ #
+ s = weights.shape.as_list()
+ shape = [s[i] for i in [1, 0]]
+ new_weight = tf.constant(new_weights[source_idx], shape=shape)
+ weights.assign(
+ tf.transpose(a=new_weight, perm=[1, 0]))
+ else:
+                # Can't populate renorm weights, but the current source weight is still needed elsewhere, so adjust the offset.
+ if 'renorm' in weights.name:
+ offset-=1
+ continue
+                # betas without gammas need to skip the gamma in the input.
+ if 'beta:' in weights.name and not last_was_gamma:
+ source_idx+=1
+ offset+=1
+ # Biases, batchnorm etc
+ new_weight = tf.constant(new_weights[source_idx], shape=weights.shape)
+ if 'stddev:' in weights.name:
+ weights.assign(tf.math.sqrt(new_weight + 1e-5))
+ else:
+ weights.assign(new_weight)
+ # need to use the variance to also populate the stddev for renorm, so adjust offset.
+ if 'variance:' in weights.name and self.renorm_enabled:
+ offset-=1
+ last_was_gamma = 'gamma:' in weights.name
+ # Replace the SWA weights as well, ensuring swa accumulation is reset.
+ if self.swa_enabled:
+ self.swa_count.assign(tf.constant(0.))
+ self.update_swa_v2()
+        # This should result in a file identical to the starting one
+ # self.save_leelaz_weights_v2('restored.pb.gz')
+
+ def restore_v2(self):
+ if self.manager.latest_checkpoint is not None:
+ print("Restoring from {0}".format(self.manager.latest_checkpoint))
+ self.checkpoint.restore(self.manager.latest_checkpoint)
+
+ def restore_ckpt(self, ckpt_path):
+ print("loading lower weights from {}".format(ckpt_path))
+ self.checkpoint_restore.restore(ckpt_path)
+
+ def process_loop_v2(self, batch_size, test_batches, batch_splits=1):
+ # Get the initial steps value in case this is a resume from a step count
+ # which is not a multiple of total_steps.
+ steps = self.global_step.read_value()
+ total_steps = self.cfg['training']['total_steps']
+ for _ in range(steps % total_steps, total_steps):
+ self.process_v2(batch_size, test_batches, batch_splits=batch_splits)
+
+ @tf.function()
+ def read_weights(self):
+ return [w.read_value() for w in self.model.weights]
+
+ def get_reg_loss(self):
+        '''
+        Collect regularization losses based on where the model is frozen.
+        Each entry in reg_loss_dict holds the regularization losses of one
+        section of the network; the losses of all unfrozen sections are summed.
+        '''
+ # from 0 to 10, 0 is after apply policy map, 10 is no freezing at all
+ stop_point = self.cfg['model'].get('back_prop_blocks', 3)
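+        # Worked example for the mapping built below, assuming 6 residual blocks
+        # (total_sections = 10): key 10 holds conv_block1, keys 9..4 hold
+        # res_0..res_5, key 3 holds policy_head/conv_pol1 and key 2 holds
+        # conv_pol2; keys 0 and 1 stay empty because the policy map has no
+        # regularized weights. With stop_point = 3 only the conv_pol1 and
+        # conv_pol2 terms are summed.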
+ # empty for freezing at layer 0 and 1
+ if stop_point == 0 or stop_point == 1:
+ print("stopping at 0 or 1, no reg term applied")
+ return []
+
+ total_sections = self.RESIDUAL_BLOCKS + 4
+ # stop_point will be <= 10
+ if stop_point > total_sections:
+ stop_point = total_sections
+
+ stops = [i for i in range(0, total_sections + 1)]
+
+ # make regularization loss dict
+ reg_loss_dict = dict.fromkeys(stops, [])
+ # define number of layers before residual tower to determine where each res block is in dict
+ num_layers_before_res = 1
+
+ for model_layer in self.model.layers:
+ if 'conv_block1' in model_layer.name:
+ reg_loss_dict[10] = reg_loss_dict[10] + model_layer.losses
+
+ elif 'policy_head/conv_pol2' in model_layer.name:
+ reg_loss_dict[2] = reg_loss_dict[2] + model_layer.losses
+
+ elif 'policy_head/conv_pol1' in model_layer.name:
+ reg_loss_dict[3] = reg_loss_dict[3] + model_layer.losses
+ else:
+ for i in range(0, self.RESIDUAL_BLOCKS):
+ res_block_name = 'res_{}'.format(i)
+ if res_block_name in model_layer.name:
+ res_index_in_section = total_sections - i - num_layers_before_res
+ reg_loss_dict[res_index_in_section] = reg_loss_dict[res_index_in_section] + model_layer.losses
+
+ # for key, value in reg_loss_dict.items():
+ # print(key, len(value))
+
+ reg_loss = []
+ for i in range(2, stop_point + 1):
+ reg_loss += reg_loss_dict[i]
+ # print(reg_loss)
+ print("stopping at {}".format(stop_point))
+ print("collected {} regularization loss".format(len(reg_loss)))
+ return reg_loss
+
+ @tf.function()
+ def process_inner_loop(self, x, y, z, q):
+ with tf.GradientTape() as tape:
+ policy, value = self.model(x, training=True)
+ policy_loss = self.policy_loss_fn(y, policy)
+ # reg_term = sum(self.model.losses)
+ reg_term = sum(self.get_reg_loss())
+
+ if self.wdl:
+ value_loss = self.value_loss_fn(self.qMix(z, q), value)
+ total_loss = self.lossMix(policy_loss, value_loss) + reg_term
+ else:
+ mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
+ total_loss = self.lossMix(policy_loss, mse_loss) + reg_term
+ if self.loss_scale != 1:
+ total_loss = self.optimizer.get_scaled_loss(total_loss)
+ if self.wdl:
+ mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
+ else:
+ value_loss = self.value_loss_fn(self.qMix(z, q), value)
+ return policy_loss, value_loss, mse_loss, reg_term, tape.gradient(total_loss, self.model.trainable_weights)
+
+ def process_v2(self, batch_size, test_batches, batch_splits=1):
+ if not self.time_start:
+ self.time_start = time.time()
+
+ # Get the initial steps value before we do a training step.
+ steps = self.global_step.read_value()
+ if not self.last_steps:
+ self.last_steps = steps
+
+ if self.swa_enabled:
+            # split test_batches evenly between testing the regular weights and the SWA weights
+ test_batches //= 2
+ # Run test before first step to see delta since end of last run.
+ if steps % self.cfg['training']['total_steps'] == 0:
+ # Steps is given as one higher than current in order to avoid it
+ # being equal to the value the end of a run is stored against.
+ self.calculate_test_summaries_v2(test_batches, steps + 1)
+ if self.swa_enabled:
+ self.calculate_swa_summaries_v2(test_batches, steps + 1)
+
+ # Make sure that ghost batch norm can be applied
+ if batch_size % 64 != 0:
+ # Adjust required batch size for batch splitting.
+ required_factor = 64 * \
+ self.cfg['training'].get('num_batch_splits', 1)
+ raise ValueError(
+ 'batch_size must be a multiple of {}'.format(required_factor))
+
+ # Determine learning rate
+ lr_values = self.cfg['training']['lr_values']
+ lr_boundaries = self.cfg['training']['lr_boundaries']
+ steps_total = steps % self.cfg['training']['total_steps']
+ self.lr = lr_values[bisect.bisect_right(lr_boundaries, steps_total)]
+ if self.warmup_steps > 0 and steps < self.warmup_steps:
+ self.lr = self.lr * tf.cast(steps + 1, tf.float32) / self.warmup_steps
+
+ # need to add 1 to steps because steps will be incremented after gradient update
+ if (steps + 1) % self.cfg['training']['train_avg_report_steps'] == 0 or (steps + 1) % self.cfg['training']['total_steps'] == 0:
+ before_weights = self.read_weights()
+
+ x, y, z, q = next(self.train_iter)
+ policy_loss, value_loss, mse_loss, reg_term, grads = self.process_inner_loop(x, y, z, q)
+
+ # apply different learning rates to different layers
+ new_grads_zip = zip(grads, self.model.trainable_variables)
+ new_grads = []
+ for (grad, var) in new_grads_zip:
+ if 'conv_block1' in var.name:
+ new_grads.append(grad * 0.05)
+ elif 'res_0' in var.name:
+ new_grads.append(grad * 0.1)
+ elif 'res_1' in var.name:
+ new_grads.append(grad * 0.25)
+ elif 'res_2' in var.name:
+ new_grads.append(grad * 0.4)
+ elif 'res_3' in var.name:
+ new_grads.append(grad * 0.55)
+ elif 'res_4' in var.name:
+ new_grads.append(grad * 0.7)
+ elif 'res_5' in var.name:
+ new_grads.append(grad * 0.85)
+ else:
+ new_grads.append(grad)
+
+ grads = new_grads
+
+ # Keep running averages
+ # Google's paper scales MSE by 1/4 to a [0, 1] range, so do the same to
+ # get comparable values.
+ mse_loss /= 4.0
+ self.avg_policy_loss.append(policy_loss)
+ if self.wdl:
+ self.avg_value_loss.append(value_loss)
+ self.avg_mse_loss.append(mse_loss)
+ self.avg_reg_term.append(reg_term)
+
+ # Gradients of batch splits are summed, not averaged like usual, so need to scale lr accordingly to correct for this.
+ self.active_lr = self.lr / batch_splits
+ if self.loss_scale != 1:
+ grads = self.optimizer.get_unscaled_gradients(grads)
+ max_grad_norm = self.cfg['training'].get('max_grad_norm', 10000.0) * batch_splits
+ grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm)
+ self.optimizer.apply_gradients([
+ (grad, var)
+ for (grad, var) in zip(grads, self.model.trainable_variables)
+ if grad is not None
+ ])
+
+ # Update steps.
+ self.global_step.assign_add(1)
+ steps = self.global_step.read_value()
+
+ if steps % self.cfg['training']['train_avg_report_steps'] == 0 or steps % self.cfg['training']['total_steps'] == 0:
+ pol_loss_w = self.cfg['training']['policy_loss_weight']
+ val_loss_w = self.cfg['training']['value_loss_weight']
+ time_end = time.time()
+ speed = 0
+ if self.time_start:
+ elapsed = time_end - self.time_start
+ steps_elapsed = steps - self.last_steps
+ speed = batch_size * (tf.cast(steps_elapsed, tf.float32) / elapsed)
+ avg_policy_loss = np.mean(self.avg_policy_loss or [0])
+ avg_value_loss = np.mean(self.avg_value_loss or [0])
+ avg_mse_loss = np.mean(self.avg_mse_loss or [0])
+ avg_reg_term = np.mean(self.avg_reg_term or [0])
+ printWithDate("step {}, lr={:g} policy={:g} value={:g} mse={:g} reg={:g} total={:g} ({:g} pos/s)".format(
+ steps, self.lr, avg_policy_loss, avg_value_loss, avg_mse_loss, avg_reg_term,
+ pol_loss_w * avg_policy_loss + val_loss_w * avg_value_loss + avg_reg_term,
+ speed))
+
+ after_weights = self.read_weights()
+ with self.train_writer.as_default():
+ tf.summary.scalar("Policy Loss", avg_policy_loss, step=steps)
+ tf.summary.scalar("Value Loss", avg_value_loss, step=steps)
+ tf.summary.scalar("Reg term", avg_reg_term, step=steps)
+ tf.summary.scalar("LR", self.lr, step=steps)
+ tf.summary.scalar("Gradient norm", grad_norm / batch_splits, step=steps)
+ tf.summary.scalar("MSE Loss", avg_mse_loss, step=steps)
+ self.compute_update_ratio_v2(
+ before_weights, after_weights, steps)
+ self.train_writer.flush()
+ self.time_start = time_end
+ self.last_steps = steps
+ self.avg_policy_loss, self.avg_value_loss, self.avg_mse_loss, self.avg_reg_term = [], [], [], []
+
+ if self.swa_enabled and steps % self.cfg['training']['swa_steps'] == 0:
+ self.update_swa_v2()
+
+ # Calculate test values every 'test_steps', but also ensure there is
+        # one at the final step so the delta to the first step can be calculated.
+ if steps % self.cfg['training']['test_steps'] == 0 or steps % self.cfg['training']['total_steps'] == 0:
+ self.calculate_test_summaries_v2(test_batches, steps)
+ if self.swa_enabled:
+ self.calculate_swa_summaries_v2(test_batches, steps)
+
+ # Save session and weights at end, and also optionally every 'checkpoint_steps'.
+ if steps % self.cfg['training']['total_steps'] == 0 or (
+ 'checkpoint_steps' in self.cfg['training'] and steps % self.cfg['training']['checkpoint_steps'] == 0):
+ self.manager.save()
+ print("Model saved in file: {}".format(self.manager.latest_checkpoint))
+ evaled_steps = steps.numpy()
+ leela_path = self.manager.latest_checkpoint + "-" + str(evaled_steps)
+ swa_path = self.manager.latest_checkpoint + "-swa-" + str(evaled_steps)
+ self.net.pb.training_params.training_steps = evaled_steps
+ self.save_leelaz_weights_v2(leela_path)
+ print("Weights saved in file: {}".format(leela_path))
+ if self.swa_enabled:
+ self.save_swa_weights_v2(swa_path)
+ print("SWA Weights saved in file: {}".format(swa_path))
+
+ def calculate_swa_summaries_v2(self, test_batches, steps):
+ backup = self.read_weights()
+ for (swa, w) in zip(self.swa_weights, self.model.weights):
+ w.assign(swa.read_value())
+ true_test_writer, self.test_writer = self.test_writer, self.swa_writer
+ print('swa', end=' ')
+ self.calculate_test_summaries_v2(test_batches, steps)
+ self.test_writer = true_test_writer
+ for (old, w) in zip(backup, self.model.weights):
+ w.assign(old)
+
+ @tf.function()
+ def calculate_test_summaries_inner_loop(self, x, y, z, q):
+ policy, value = self.model(x, training=False)
+ policy_loss = self.policy_loss_fn(y, policy)
+ policy_accuracy = self.policy_accuracy_fn(y, policy)
+ if self.wdl:
+ value_loss = self.value_loss_fn(self.qMix(z, q), value)
+ mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
+ value_accuracy = self.accuracy_fn(self.qMix(z,q), value)
+ else:
+ value_loss = self.value_loss_fn(self.qMix(z, q), value)
+ mse_loss = self.mse_loss_fn(self.qMix(z, q), value)
+ value_accuracy = tf.constant(0.)
+ return policy_loss, value_loss, mse_loss, policy_accuracy, value_accuracy
+
+ def calculate_test_summaries_v2(self, test_batches, steps):
+ sum_policy_accuracy = 0
+ sum_value_accuracy = 0
+ sum_mse = 0
+ sum_policy = 0
+ sum_value = 0
+ for _ in range(0, test_batches):
+ x, y, z, q = next(self.test_iter)
+ policy_loss, value_loss, mse_loss, policy_accuracy, value_accuracy = self.calculate_test_summaries_inner_loop(x, y, z, q)
+ sum_policy_accuracy += policy_accuracy
+ sum_mse += mse_loss
+ sum_policy += policy_loss
+ if self.wdl:
+ sum_value_accuracy += value_accuracy
+ sum_value += value_loss
+ sum_policy_accuracy /= test_batches
+ sum_policy_accuracy *= 100
+ sum_policy /= test_batches
+ sum_value /= test_batches
+ if self.wdl:
+ sum_value_accuracy /= test_batches
+ sum_value_accuracy *= 100
+ # Additionally rescale to [0, 1] so divide by 4
+ sum_mse /= (4.0 * test_batches)
+ self.net.pb.training_params.learning_rate = self.lr
+ self.net.pb.training_params.mse_loss = sum_mse
+ self.net.pb.training_params.policy_loss = sum_policy
+ # TODO store value and value accuracy in pb
+ self.net.pb.training_params.accuracy = sum_policy_accuracy
+ with self.test_writer.as_default():
+ tf.summary.scalar("Policy Loss", sum_policy, step=steps)
+ tf.summary.scalar("Value Loss", sum_value, step=steps)
+ tf.summary.scalar("MSE Loss", sum_mse, step=steps)
+ tf.summary.scalar("Policy Accuracy", sum_policy_accuracy, step=steps)
+ if self.wdl:
+ tf.summary.scalar("Value Accuracy", sum_value_accuracy, step=steps)
+ for w in self.model.weights:
+ tf.summary.histogram(w.name, w, buckets=1000, step=steps)
+ self.test_writer.flush()
+
+ printWithDate("step {}, policy={:g} value={:g} policy accuracy={:g}% value accuracy={:g}% mse={:g}".\
+ format(steps, sum_policy, sum_value, sum_policy_accuracy, sum_value_accuracy, sum_mse))
+
+ @tf.function()
+ def compute_update_ratio_v2(self, before_weights, after_weights, steps):
+ """Compute the ratio of gradient norm to weight norm.
+
+ Adapted from https://github.com/tensorflow/minigo/blob/c923cd5b11f7d417c9541ad61414bf175a84dc31/dual_net.py#L567
+ """
+ deltas = [after - before for after,
+ before in zip(after_weights, before_weights)]
+ delta_norms = [tf.math.reduce_euclidean_norm(d) for d in deltas]
+ weight_norms = [tf.math.reduce_euclidean_norm(w) for w in before_weights]
+ ratios = [(tensor.name, tf.cond(w != 0., lambda: d / w, lambda: -1.)) for d, w, tensor in zip(delta_norms, weight_norms, self.model.weights) if not 'moving' in tensor.name]
+ for name, ratio in ratios:
+ tf.summary.scalar('update_ratios/' + name, ratio, step=steps)
+ # Filtering is hard, so just push infinities/NaNs to an unreasonably large value.
+ ratios = [tf.cond(r > 0, lambda: tf.math.log(r) / 2.30258509299, lambda: 200.) for (_, r) in ratios]
+ tf.summary.histogram('update_ratios_log10', tf.stack(ratios), buckets=1000, step=steps)
+
+ def update_swa_v2(self):
+ num = self.swa_count.read_value()
+ for (w, swa) in zip(self.model.weights, self.swa_weights):
+ swa.assign(swa.read_value() * (num / (num + 1.)) + w.read_value() * (1. / (num + 1.)))
+ self.swa_count.assign(min(num + 1., self.swa_max_n))
+
+ def save_swa_weights_v2(self, filename):
+ backup = self.read_weights()
+ for (swa, w) in zip(self.swa_weights, self.model.weights):
+ w.assign(swa.read_value())
+ self.save_leelaz_weights_v2(filename)
+ for (old, w) in zip(backup, self.model.weights):
+ w.assign(old)
+
+ def save_leelaz_weights_v2(self, filename):
+ all_tensors = []
+ all_weights = []
+ last_was_gamma = False
+ for weights in self.model.weights:
+ work_weights = None
+ if weights.shape.ndims == 4:
+ # Convolution weights need a transpose
+ #
+ # TF (kYXInputOutput)
+ # [filter_height, filter_width, in_channels, out_channels]
+ #
+ # Leela/cuDNN/Caffe (kOutputInputYX)
+ # [output, input, filter_size, filter_size]
+ work_weights = tf.transpose(a=weights, perm=[3, 2, 0, 1])
+ elif weights.shape.ndims == 2:
+ # Fully connected layers are [in, out] in TF
+ #
+ # [out, in] in Leela
+ #
+ work_weights = tf.transpose(a=weights, perm=[1, 0])
+ else:
+ # batch renorm has extra weights, but we don't know what to do with them.
+ if 'renorm' in weights.name:
+ continue
+ # renorm has variance, but it is not the primary source of truth
+ if 'variance:' in weights.name and self.renorm_enabled:
+ continue
+ # Renorm has moving stddev not variance, undo the transform to make it compatible.
+ if 'stddev:' in weights.name:
+ all_tensors.append(tf.math.square(weights) - 1e-5)
+ continue
+ # Biases, batchnorm etc
+ # pb expects every batch norm to have gammas, but not all of our
+ # batch norms have gammas, so manually add pretend gammas.
+ if 'beta:' in weights.name and not last_was_gamma:
+ all_tensors.append(tf.ones_like(weights))
+ work_weights = weights.read_value()
+ all_tensors.append(work_weights)
+ last_was_gamma = 'gamma:' in weights.name
+
+ # HACK: model weights ordering is some kind of breadth first traversal,
+ # but pb expects a specific ordering which BFT is not a match for once
+ # we get to the heads. Apply manual permutation.
+ # This is fragile and at minimum should have some checks to ensure it isn't breaking things.
+ #TODO: also support classic policy head as it has a different set of layers and hence changes the permutation.
+ permuted_tensors = [w for w in all_tensors]
+ permuted_tensors[-5] = all_tensors[-11]
+ permuted_tensors[-6] = all_tensors[-12]
+ permuted_tensors[-7] = all_tensors[-13]
+ permuted_tensors[-8] = all_tensors[-14]
+ permuted_tensors[-9] = all_tensors[-16]
+ permuted_tensors[-10] = all_tensors[-5]
+ permuted_tensors[-11] = all_tensors[-6]
+ permuted_tensors[-12] = all_tensors[-7]
+ permuted_tensors[-13] = all_tensors[-8]
+ permuted_tensors[-14] = all_tensors[-9]
+ permuted_tensors[-15] = all_tensors[-10]
+ permuted_tensors[-16] = all_tensors[-15]
+ all_tensors = permuted_tensors
+
+ for e, nparray in enumerate(all_tensors):
+ # Rescale rule50 related weights as clients do not normalize the input.
+ if e == 0:
+ num_inputs = 112
+ # 50 move rule is the 110th input, or 109 starting from 0.
+ rule50_input = 109
+ wt_flt = []
+ for i, weight in enumerate(np.ravel(nparray)):
+ if (i % (num_inputs*9))//9 == rule50_input:
+ wt_flt.append(weight/99)
+ else:
+ wt_flt.append(weight)
+ else:
+ wt_flt = [wt for wt in np.ravel(nparray)]
+ all_weights.append(wt_flt)
+
+ self.net.fill_net(all_weights)
+ self.net.save_proto(filename)
+
+ def set_name(self, name, suffix):
+ return None if not name else "{}/{}".format(name, suffix)
+
+ def batch_norm_v2(self, input, scale=False, name=None):
+ if self.renorm_enabled:
+ clipping = {
+ "rmin": 1.0/self.renorm_max_r,
+ "rmax": self.renorm_max_r,
+ "dmax": self.renorm_max_d
+ }
+ return tf.keras.layers.BatchNormalization(
+ epsilon=1e-5, axis=1, fused=False, center=True,
+ scale=scale, renorm=True, renorm_clipping=clipping,
+ renorm_momentum=self.renorm_momentum, name=self.set_name(name,'batchnorm'))(input)
+ else:
+ return tf.keras.layers.BatchNormalization(
+ epsilon=1e-5, axis=1, fused=False, center=True,
+ scale=scale, virtual_batch_size=64, name=self.set_name(name,'batchnorm'))(input)
+
+ def squeeze_excitation_v2(self, inputs, channels, name):
+ assert channels % self.SE_ratio == 0
+
+ pooled = tf.keras.layers.GlobalAveragePooling2D(data_format='channels_first', name=self.set_name(name,'global_avgpool'))(inputs)
+ squeezed = tf.keras.layers.Activation('relu', name=self.set_name(name,'activation'))(tf.keras.layers.Dense(channels // self.SE_ratio, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, name=self.set_name(name,'dense_1'))(pooled))
+ excited = tf.keras.layers.Dense(2 * channels, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, name=self.set_name(name,'dense_2'))(squeezed)
+ return ApplySqueezeExcitation(name=self.set_name(name,'squeeze_excitation'))([inputs, excited])
+
+ def conv_block_v2(self, inputs, filter_size, output_channels, bn_scale=False, name=None):
+ conv = tf.keras.layers.Conv2D(output_channels, filter_size, use_bias=False, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, data_format='channels_first', name=self.set_name(name,'conv2d'))(inputs)
+ return tf.keras.layers.Activation('relu', name=self.set_name(name,'activation'))(self.batch_norm_v2(conv, scale=bn_scale, name=None if not name else name))
+
+ def residual_block_v2(self, inputs, channels, name=None):
+ conv1 = tf.keras.layers.Conv2D(channels, 3, use_bias=False, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, data_format='channels_first', name=self.set_name(name,'conv2d_1'))(inputs)
+ out1 = tf.keras.layers.Activation('relu', name=self.set_name(name,'activation_1'))(self.batch_norm_v2(conv1, scale=False, name = None if not name else (name + '/bn_1')))
+ conv2 = tf.keras.layers.Conv2D(channels, 3, use_bias=False, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, data_format='channels_first', name=self.set_name(name,'conv2d_2'))(out1)
+ out2 = self.squeeze_excitation_v2(self.batch_norm_v2(conv2, scale=True, name = None if not name else (name + '/bn_2')), channels, None if not name else (name + '/se_block'))
+ return tf.keras.layers.Activation('relu', name=self.set_name(name,'activation_2'))(tf.keras.layers.add([inputs, out2], name=self.set_name(name,'add')))
+
+ def construct_net_complete(self, inputs):
+ flow = self.conv_block_v2(inputs, filter_size=3, output_channels=self.RESIDUAL_FILTERS, bn_scale=True, name='conv_block1')
+
+ for _ in range(0, self.RESIDUAL_BLOCKS):
+ flow = self.residual_block_v2(flow, self.RESIDUAL_FILTERS, name='res_tower/res_{}'.format(_))
+
+ # Policy head
+ conv_pol = self.conv_block_v2(flow, filter_size=3, output_channels=self.RESIDUAL_FILTERS, name='policy_head/conv_pol1')
+ conv_pol2 = tf.keras.layers.Conv2D(80, 3, use_bias=True, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg, data_format='channels_first', name='policy_head/conv_pol2')(conv_pol)
+ h_fc1 = ApplyPolicyMap(name='policy_head/h_fc1')(conv_pol2)
+
+ # Value head
+ conv_val = self.conv_block_v2(flow, filter_size=1, output_channels=32, name='value_head/conv_val')
+ h_conv_val_flat = tf.keras.layers.Flatten(name='value_head/flatten')(conv_val)
+ h_fc2 = tf.keras.layers.Dense(128, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation='relu', name='value_head/dense_1')(h_conv_val_flat)
+ h_fc3 = tf.keras.layers.Dense(3, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg, name='value_head/dense_2')(h_fc2)
+ return h_fc1, h_fc3
+
+ def construct_with_stops(self, inputs, num_sections, from_top = True):
+
+ tot_sections = self.RESIDUAL_BLOCKS + 4
+ if from_top:
+ num_sections = tot_sections - num_sections
+        # num_sections of 0 or less means no stop_gradient is inserted anywhere
+ section_count = 1
+
+ flow_p = self.conv_block_v2(inputs, filter_size=3, output_channels=self.RESIDUAL_FILTERS, bn_scale=True, name='conv_block1')
+ if section_count == num_sections:
+ flow_p = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(flow_p)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ for _ in range(0, self.RESIDUAL_BLOCKS):
+
+ section_count += 1
+ flow_p = self.residual_block_v2(flow_p, self.RESIDUAL_FILTERS, name='res_towers/res_{}'.format(_))
+ if section_count == num_sections:
+ flow_p = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(flow_p)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ section_count += 1
+ conv_pol = self.conv_block_v2(flow_p, filter_size=3, output_channels=self.RESIDUAL_FILTERS, name='policy_head/conv_pol1')
+ if section_count == num_sections:
+ conv_pol = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(conv_pol)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ section_count += 1
+ conv_pol2 = tf.keras.layers.Conv2D(80, 3, use_bias=True, padding='same', kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg, data_format='channels_first', name='policy_head/conv_pol2')(conv_pol)
+ if section_count == num_sections:
+ conv_pol2 = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(conv_pol2)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ section_count += 1
+ h_fc1 = ApplyPolicyMap(name='policy_head/h_fc1')(conv_pol2)
+ if section_count == num_sections:
+ h_fc1 = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(h_fc1)
+ printWithDate(f"Adding stop at depth: {section_count}")
+
+ if section_count != tot_sections:
+ raise RuntimeError(f"Number of sections was calculated to be {tot_sections}, but is actually {section_count}")
+
+ # Value head
+ flow_v = tf.keras.layers.Lambda(lambda x: tf.keras.backend.stop_gradient(x))(flow_p)
+ conv_val = self.conv_block_v2(flow_v, filter_size=1, output_channels=32, name='value_head/conv_val')
+ h_conv_val_flat = tf.keras.layers.Flatten(name='value_head/flatten')(conv_val)
+ h_fc2 = tf.keras.layers.Dense(128, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, activation='relu', name='value_head/dense_1')(h_conv_val_flat)
+ h_fc3 = tf.keras.layers.Dense(3, kernel_initializer='glorot_normal', kernel_regularizer=self.l2reg, bias_regularizer=self.l2reg, name='value_head/dense_2')(h_fc2)
+
+
+ return h_fc1, h_fc3
+
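+# A minimal sketch of the config dictionary TFProcess expects, for illustration
+# only. Key names mirror the accesses above; every value shown is a placeholder
+# rather than a setting used for any released model, and most optional keys with
+# defaults (precision, swa, renorm, q_ratio, ...) are left out.
+_EXAMPLE_CFG = {
+    'name': 'example-run',            # run name, passed separately to TFProcess in this repo
+    'gpu': 0,                         # index into the visible GPU list
+    'model': {
+        'filters': 64,
+        'residual_blocks': 6,
+        'se_ratio': 8,
+        'path': 'maia-1900',          # base checkpoint dir, resolved by model_path_gen
+        'back_prop_blocks': 3,        # sections from the output side that receive gradients
+    },
+    'training': {
+        'lr_values': [0.01, 0.001],
+        'lr_boundaries': [80000],
+        'total_steps': 100000,
+        'train_avg_report_steps': 50,
+        'test_steps': 2000,
+        'checkpoint_steps': 10000,
+        'policy_loss_weight': 1.0,
+        'value_loss_weight': 1.0,
+    },
+}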
diff --git a/backend/tf_transfer/training_shared.py b/backend/tf_transfer/training_shared.py
new file mode 100755
index 0000000..c3ef52a
--- /dev/null
+++ b/backend/tf_transfer/training_shared.py
@@ -0,0 +1,96 @@
+from ..utils import printWithDate
+
+import tensorflow as tf
+
+import glob
+import os
+import os.path
+import random
+import gzip
+import sys
+
+def get_latest_chunks(path):
+ chunks = []
+ printWithDate(f"found {glob.glob(path)} chunk dirs")
+ whites = []
+ blacks = []
+
+ for d in glob.glob(path):
+ for root, dirs, files in os.walk(d):
+ for fpath in files:
+ if fpath.endswith('.gz'):
+ #TODO: Make less sketchy
+ if 'black' in root:
+ blacks.append(os.path.join(root, fpath))
+ elif 'white' in root:
+ whites.append(os.path.join(root, fpath))
+ else:
+ raise RuntimeError(
+ f"invalid chunk path found:{os.path.join(root, fpath)}")
+
+ printWithDate(
+ f"found {len(whites)} white {len(blacks)} black chunks", end='\r')
+ printWithDate(f"found {len(whites) + len(blacks)} chunks total")
+ if len(whites) < 1 or len(blacks) < 1:
+        print("Not enough chunks: {} white, {} black".format(len(whites), len(blacks)))
+ sys.exit(1)
+
+ print("sorting {} B chunks...".format(len(blacks)), end='')
+ blacks.sort(key=os.path.getmtime, reverse=True)
+ print("sorting {} W chunks...".format(len(whites)), end='')
+ whites.sort(key=os.path.getmtime, reverse=True)
+ print("[done]")
+ print("{} - {}".format(os.path.basename(whites[-1]), os.path.basename(whites[0])))
+ print("{} - {}".format(os.path.basename(blacks[-1]), os.path.basename(blacks[0])))
+ random.shuffle(blacks)
+ random.shuffle(whites)
+ return whites, blacks
+
+
+class FileDataSrc:
+ """
+ data source yielding chunkdata from chunk files.
+ """
+ def __init__(self, white_chunks, black_chunks):
+ self.white_chunks = []
+ self.white_done = white_chunks
+
+ self.black_chunks = []
+ self.black_done = black_chunks
+
+ self.next_is_white = True
+
+ def next(self):
+ self.next_is_white = not self.next_is_white
+ return self.next_by_colour(not self.next_is_white)
+
+ def next_by_colour(self, is_white):
+ if is_white:
+ if not self.white_chunks:
+ self.white_chunks, self.white_done = self.white_done, self.white_chunks
+ random.shuffle(self.white_chunks)
+ if not self.white_chunks:
+                    return None, True
+ while len(self.white_chunks):
+ filename = self.white_chunks.pop()
+ try:
+ with gzip.open(filename, 'rb') as chunk_file:
+ self.white_done.append(filename)
+ return chunk_file.read(), True
+ except:
+ print("failed to parse {}".format(filename))
+ else:
+ if not self.black_chunks:
+ self.black_chunks, self.black_done = self.black_done, self.black_chunks
+ random.shuffle(self.black_chunks)
+ if not self.black_chunks:
+ return None, False
+ while len(self.black_chunks):
+ filename = self.black_chunks.pop()
+ try:
+ with gzip.open(filename, 'rb') as chunk_file:
+ self.black_done.append(filename)
+ return chunk_file.read(), False
+ except:
+ print("failed to parse {}".format(filename))
+
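+# Illustrative smoke test of the chunk loader; the glob below is a placeholder,
+# not a path used elsewhere in this repo. Because of the relative import above,
+# run it as a module: python -m backend.tf_transfer.training_shared
+if __name__ == '__main__':
+    whites, blacks = get_latest_chunks('../../data/example_player/supervised-*')
+    src = FileDataSrc(whites, blacks)
+    for _ in range(4):
+        item = src.next()
+        if item is None or item[0] is None:
+            break
+        chunkdata, is_white = item
+        printWithDate(f"read {len(chunkdata)} bytes from a {'white' if is_white else 'black'} chunk")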
diff --git a/backend/tf_transfer/update_steps.py b/backend/tf_transfer/update_steps.py
new file mode 100755
index 0000000..8956ec8
--- /dev/null
+++ b/backend/tf_transfer/update_steps.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+import argparse
+import os
+import yaml
+import sys
+import tensorflow as tf
+from .tfprocess import TFProcess
+
+START_FROM = 0
+
+def main(cmd):
+ cfg = yaml.safe_load(cmd.cfg.read())
+ print(yaml.dump(cfg, default_flow_style=False))
+
+ root_dir = os.path.join(cfg['training']['path'], cfg['name'])
+ if not os.path.exists(root_dir):
+ os.makedirs(root_dir)
+
+    # TFProcess in this repo also takes a run name and a collection name (used to
+    # build models/<collection>/<name>); using cfg['name'] for both is an assumption.
+    tfprocess = TFProcess(cfg, cfg['name'], cfg['name'])
+ tfprocess.init_net_v2()
+
+ tfprocess.restore_v2()
+
+ START_FROM = cmd.start
+
+ tfprocess.global_step.assign(START_FROM)
+ tfprocess.manager.save()
+
+if __name__ == "__main__":
+ argparser = argparse.ArgumentParser(description=\
+ 'Convert current checkpoint to new step count.')
+ argparser.add_argument('--cfg', type=argparse.FileType('r'),
+ help='yaml configuration with training parameters')
+ argparser.add_argument('--start', type=int, default=0,
+ help='Offset to set global_step to.')
+
+ main(argparser.parse_args())
diff --git a/backend/tf_transfer/utils.py b/backend/tf_transfer/utils.py
new file mode 100755
index 0000000..492cf52
--- /dev/null
+++ b/backend/tf_transfer/utils.py
@@ -0,0 +1,18 @@
+import io
+import tempfile
+
+import tensorflow as tf
+
+def show_model(model, filename = None, detailed = False, show_shapes = False):
+ if filename is None:
+ tempf = tempfile.NamedTemporaryFile(suffix='.png')
+ filename = tempf.name
+ return tf.keras.utils.plot_model(
+ model,
+ to_file=filename,
+ show_shapes=show_shapes,
+ show_layer_names=True,
+ rankdir='TB',
+ expand_nested=detailed,
+ dpi=96,
+ )
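+
+# Illustrative usage of show_model: plot a toy Keras model to a PNG. Rendering
+# needs the pydot and graphviz packages from environment.yml; the model below
+# is a throwaway example, not one of the networks used in this project.
+if __name__ == '__main__':
+    toy = tf.keras.Sequential([
+        tf.keras.layers.Dense(8, activation='relu', input_shape=(112, 64)),
+        tf.keras.layers.Dense(3),
+    ])
+    show_model(toy, filename='toy_model.png', show_shapes=True)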
diff --git a/backend/uci_engine.py b/backend/uci_engine.py
new file mode 100755
index 0000000..7950638
--- /dev/null
+++ b/backend/uci_engine.py
@@ -0,0 +1,209 @@
+import subprocess
+import os.path
+import re
+import datetime
+import concurrent.futures
+
+import yaml
+
+import chess
+import chess.engine
+import chess.pgn
+
+from .utils import tz
+#from .proto import Board, Node, Game
+
+p_re = re.compile(r"(\S+) [^P]+P\: +([0-9.]+)")
+
+#You will probably need to set these manually
+
+lc0_path = 'lc0'
+
+sf_path = 'stockfish'
+
+def model_from_config(config_dir_path, nodes = None):
+ with open(os.path.join(config_dir_path, 'config.yaml')) as f:
+ config = yaml.safe_load(f.read())
+
+ if config['engine'] == 'stockfish':
+ model = Stockfish_Engine(**config['options'])
+
+ elif config['engine'] in ['lc0', 'lc0_23']:
+ config['options']['weightsPath'] = os.path.join(config_dir_path, config['options']['weightsPath'])
+ if nodes is not None:
+ config['options']['nodes'] = nodes
+ model = LC0_Engine(**config['options'])
+ else:
+ raise NotImplementedError(f"{config['engine']} is not a known engine type")
+
+ model.config = config
+
+ return model
+
+class Shallow_Board_Query(Exception):
+ pass
+
+def is_shallow_board(board):
+ #https://github.com/niklasf/python-chess/blob/master/chess/engine.py#L1183
+ if len(board.move_stack) < 1 and board.fen().split()[0] != chess.STARTING_BOARD_FEN:
+ return True
+ return False
+
+
+class UCI_Engine(object):
+ def __init__(self, engine, movetime = None, nodes = None, depth = None):
+ self.engine = engine
+ self.limits = chess.engine.Limit(
+ time = movetime,
+ depth = depth,
+ nodes = nodes,
+ )
+ self.config = None
+ self.query_counter = 0
+
+ def __del__(self):
+ try:
+ try:
+ self.engine.quit()
+ except (chess.engine.EngineTerminatedError, concurrent.futures._base.TimeoutError):
+ pass
+ except AttributeError:
+ pass
+
+ def getMove(self, board, allow_shallow = False):
+ return self.board_info(board, multipv = 1, allow_shallow = allow_shallow)[0][0]
+
+ def board_info(self, board, multipv = 1, allow_shallow = False):
+ """Basic board info"""
+ if is_shallow_board(board) and not allow_shallow:
+ raise Shallow_Board_Query(f"{board.fen()} has no history")
+ r = self.engine.analyse(
+ board,
+ self.limits,
+ multipv = multipv,
+ info = chess.engine.INFO_ALL,
+ game = self.query_counter,
+ )
+ self.query_counter += 1
+ return [(p['pv'][0], p) for p in r]
+
+ def board_info_full(self, board, multipv = 1, allow_shallow = False):
+ """All the info string stuff"""
+ if is_shallow_board(board) and not allow_shallow:
+ raise Shallow_Board_Query(f"{board.fen()} has no history")
+ r = self.engine.analysis(
+ board,
+ self.limits,
+ multipv = multipv,
+ info = chess.engine.INFO_ALL,
+ game = self.query_counter,
+ )
+ self.query_counter += 1
+ r.wait()
+ info_strs = []
+ while not r.empty():
+ try:
+ info_strs.append(r.get()['string'])
+ except KeyError:
+ pass
+ return [(p['pv'][0], p) for p in r.multipv], '\n'.join(info_strs)
+
+class LC0_Engine(UCI_Engine):
+ def __init__(self, weightsPath, movetime = None, nodes = 1, depth = None, binary_path = lc0_path, threads = 2):
+ E = chess.engine.SimpleEngine.popen_uci([binary_path, '-w', weightsPath, '--verbose-move-stats', f'--threads={threads}'], stderr=subprocess.DEVNULL)
+
+ super().__init__(E, movetime = movetime, nodes = nodes, depth = depth)
+
+ def board_parsed_p_values(self, board, allow_shallow = False):
+ dat, info_str = self.board_info_full(board, multipv = 1, allow_shallow = allow_shallow)
+ p_vals = {}
+ for l in info_str.split('\n'):
+            r = p_re.match(l)
+            if r is not None:
+                p_vals[r.group(1)] = float(r.group(2))
+ return dat, p_vals
+
+ def board_pv(self, board, allow_shallow = False):
+ d, p = self.board_parsed_p_values(board, allow_shallow = allow_shallow)
+ return d[0][1]['score'].relative.cp, p
+
+ def make_tree_node(self, board, depth = 2, width = 10, allow_shallow = False):
+ N = Node()
+ try:
+ v, p_dict = self.board_pv(board, allow_shallow=allow_shallow)
+ except KeyError:
+ if board.is_game_over():
+ N.depth = depth
+ N.value = 0
+ return N
+ else:
+ raise
+ N.depth = depth
+ N.value = v
+ children = sorted(p_dict.items(), key = lambda x : x[1], reverse=True)[:width]
+ N.child_values.extend([p for m, p in children])
+ N.child_moves.extend([m for m, p in children])
+ chunks = []
+ if depth > 0:
+ for m, pv in children:
+ b = board.copy()
+ b.push_uci(m)
+ chunks.append(self.make_tree_node(b, depth = depth - 1, width = width, allow_shallow=allow_shallow))
+ N.children.extend(chunks)
+ return N
+
+ def make_game_file(self, game, depth = 2, width = 10, intial_skip = 4, allow_shallow = False):
+ G = Game()
+ G.game_id = game.headers['Site'].split('/')[-1]
+ G.black_elo = int(game.headers['BlackElo'])
+ G.white_elo = int(game.headers['WhiteElo'])
+ boards = []
+ for i, (mm, mb) in list(enumerate(zip(list(game.mainline())[1:], list(game.mainline())[:-1])))[intial_skip:-1]:
+ board = mb.board()
+ b_node = self.make_tree_node(board, depth = depth, width = width, allow_shallow = allow_shallow)
+ proto_board = Board()
+ proto_board.tree.MergeFrom(b_node)
+ proto_board.fen = mb.board().fen()
+ proto_board.ply = i
+ proto_board.move = str(mm.move)
+ try:
+ proto_board.move_index = list(b_node.child_moves).index(str(mm.move))
+ except ValueError:
+ proto_board.move_index = -1
+ boards.append(proto_board)
+ G.boards.extend(boards)
+ return G
+
+class Stockfish_Engine(UCI_Engine):
+ def __init__(self, movetime = None, nodes = None, depth = 15, binary_path = sf_path, threads = 2, hash = 256):
+ E = chess.engine.SimpleEngine.popen_uci([binary_path])
+
+ super().__init__(E, movetime = movetime, nodes = nodes, depth = depth)
+ self.engine.configure({"Threads": threads, "Hash": hash})
+
+def play_game(E1, E2, round = None, startingFen = None, notes = None):
+
+ timeStarted = datetime.datetime.now(tz)
+ i = 0
+ if startingFen is not None:
+ board = chess.Board(fen=startingFen)
+ else:
+ board = chess.Board()
+
+ players = [E1, E2]
+
+ while not board.is_game_over():
+ E = players[i % 2]
+ board.push(E.getMove(board))
+ i += 1
+ pgnGame = chess.pgn.Game.from_board(board)
+
+ pgnGame.headers['Event'] = f"{E1.config['name']} vs {E2.config['name']}"
+ pgnGame.headers['White'] = E1.config['name']
+ pgnGame.headers['Black'] = E2.config['name']
+ pgnGame.headers['Date'] = timeStarted.strftime("%Y-%m-%d %H:%M:%S")
+ if round is not None:
+ pgnGame.headers['Round'] = round
+ if notes is not None:
+ for k, v in notes.items():
+ pgnGame.headers[k] = v
+ return pgnGame
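+
+# Illustrative only: pit two Stockfish instances against each other with
+# play_game. Assumes a 'stockfish' binary on the PATH (see sf_path above); the
+# config dicts are normally attached by model_from_config, so minimal ones are
+# set by hand here just to fill the PGN headers. Because of the relative import
+# above, run it as a module: python -m backend.uci_engine
+if __name__ == '__main__':
+    shallow = Stockfish_Engine(depth=2)
+    shallow.config = {'name': 'stockfish-d2'}
+    deeper = Stockfish_Engine(depth=6)
+    deeper.config = {'name': 'stockfish-d6'}
+    print(play_game(shallow, deeper))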
diff --git a/backend/utils.py b/backend/utils.py
new file mode 100755
index 0000000..44370bf
--- /dev/null
+++ b/backend/utils.py
@@ -0,0 +1,135 @@
+import functools
+import sys
+import time
+import datetime
+import os
+import os.path
+import traceback
+
+import pytz
+
+
+min_run_time = 60 * 10 # 10 minutes
+infos_dir_name = 'runinfos'
+tz = pytz.timezone('Canada/Eastern')
+
+colours = {
+ 'blue' : '\033[94m',
+ 'green' : '\033[92m',
+ 'yellow' : '\033[93m',
+ 'red' : '\033[91m',
+ 'pink' : '\033[95m',
+}
+endColour = '\033[0m'
+
+def printWithDate(s, colour = None, **kwargs):
+ if colour is None:
+ print(f"{datetime.datetime.now(tz).strftime('%Y-%m-%d %H:%M:%S')} {s}", **kwargs)
+ else:
+ print(f"{datetime.datetime.now(tz).strftime('%Y-%m-%d %H:%M:%S')}{colours[colour]} {s}{endColour}", **kwargs)
+
+class Tee(object):
+ #Based on https://stackoverflow.com/a/616686
+ def __init__(self, fname, is_err = False):
+ self.file = open(fname, 'a')
+ self.is_err = is_err
+ if is_err:
+ self.stdstream = sys.stderr
+ sys.stderr = self
+ else:
+ self.stdstream = sys.stdout
+ sys.stdout = self
+ def __del__(self):
+ if self.is_err:
+ sys.stderr = self.stdstream
+ else:
+ sys.stdout = self.stdstream
+ self.file.close()
+ def write(self, data):
+ self.file.write(data)
+ self.stdstream.write(data)
+ def flush(self):
+ self.file.flush()
+
+class LockedName(object):
+ def __init__(self, script_name, start_time):
+ self.script_name = script_name
+ self.start_time = start_time
+ os.makedirs(infos_dir_name, exist_ok = True)
+ os.makedirs(os.path.join(infos_dir_name, self.script_name), exist_ok = True)
+
+ self.file_prefix = self.get_name_prefix()
+ self.full_prefix = self.file_prefix + f"-{start_time.strftime('%Y-%m-%d-%H%M')}_"
+ self.lock = None
+ self.lock_name = None
+
+ def __enter__(self):
+ try:
+ self.lock_name = self.file_prefix + '.lock'
+ self.lock = open(self.lock_name, 'x')
+ except FileExistsError:
+ self.file_prefix = self.get_name_prefix()
+            self.full_prefix = self.file_prefix + f"-{self.start_time.strftime('%Y-%m-%d-%H%M')}_"
+ return self.__enter__()
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ try:
+ self.lock.close()
+ os.remove(self.lock_name)
+ except:
+ pass
+
+ def get_name_prefix(self):
+ fdir = os.path.join(infos_dir_name, self.script_name)
+ prefixes = [n.name.split('-')[0] for n in os.scandir(fdir) if n.is_file()]
+ file_num = 1
+ nums = []
+ for p in set(prefixes):
+ try:
+ nums.append(int(p))
+ except ValueError:
+ pass
+ if len(nums) > 0:
+ file_num = max(nums) + 1
+
+ return os.path.join(fdir, f"{file_num:04.0f}")
+
+def logged_main(mainFunc):
+ @functools.wraps(mainFunc)
+ def wrapped_main(*args, **kwds):
+ start_time = datetime.datetime.now(tz)
+ script_name = os.path.basename(sys.argv[0])[:-3]
+
+ with LockedName(script_name, start_time) as name_lock:
+ tee_out = Tee(name_lock.full_prefix + 'stdout.log', is_err = False)
+ tee_err = Tee(name_lock.full_prefix + 'stderr.log', is_err = True)
+ logs_prefix = name_lock.full_prefix
+ printWithDate(' '.join(sys.argv), colour = 'blue')
+ printWithDate(f"Starting {script_name}", colour = 'blue')
+ try:
+ tstart = time.time()
+ val = mainFunc(*args, **kwds)
+ except (Exception, KeyboardInterrupt) as e:
+ printWithDate(f"Error encountered", colour = 'blue')
+ if (time.time() - tstart) > min_run_time:
+ makeLog(logs_prefix, start_time, tstart, True, 'Error', e, traceback.format_exc())
+ raise
+ else:
+ printWithDate(f"Run completed", colour = 'blue')
+ if (time.time() - tstart) > min_run_time:
+ makeLog(logs_prefix, start_time, tstart, False, 'Successful')
+ tee_out.flush()
+ tee_err.flush()
+ return val
+ return wrapped_main
+
+def makeLog(logs_prefix, start_time, tstart, is_error, *notes):
+ fname = f'error.log' if is_error else f'run.log'
+ with open(logs_prefix + fname, 'w') as f:
+ f.write(f"start: {start_time.strftime('%Y-%m-%d-%H:%M:%S')}\n")
+ f.write(f"stop: {datetime.datetime.now(tz).strftime('%Y-%m-%d-%H:%M:%S')}\n")
+        f.write(f"duration: {int(time.time() - tstart)}s\n")
+ f.write(f"dir: {os.path.abspath(os.getcwd())}\n")
+ f.write(f"{' '.join(sys.argv)}\n")
+ f.write('\n'.join([str(n) for n in notes]))
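+
+# Illustrative only: a minimal sketch of how scripts in this repo use the
+# logged_main decorator. Calling the wrapped function creates a lock file and
+# timestamped stdout/stderr logs under runinfos/<script name>/ before running.
+@logged_main
+def _example_main():
+    printWithDate("doing some work", colour='green')
+
+if __name__ == '__main__':
+    _example_main()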
diff --git a/environment.yml b/environment.yml
new file mode 100755
index 0000000..65a0517
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,186 @@
+name: transfer_chess
+channels:
+ - pytorch
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _tflow_select=2.1.0=gpu
+ - absl-py=0.9.0=py37_0
+ - asn1crypto=1.3.0=py37_0
+ - astor=0.8.0=py37_0
+ - attrs=19.3.0=py_0
+ - backcall=0.1.0=py37_0
+ - blas=1.0=mkl
+ - bleach=3.1.4=py_0
+ - blinker=1.4=py37_0
+ - c-ares=1.15.0=h7b6447c_1001
+ - ca-certificates=2020.1.1=0
+ - cachetools=3.1.1=py_0
+ - cairo=1.14.12=h8948797_3
+ - certifi=2020.4.5.1=py37_0
+ - cffi=1.13.2=py37h2e261b9_0
+ - chardet=3.0.4=py37_1003
+ - click=7.0=py37_0
+ - conda=4.8.3=py37_0
+ - conda-package-handling=1.6.0=py37h7b6447c_0
+ - cryptography=2.8=py37h1ba5d50_0
+ - cudatoolkit=10.1.243=h6bb024c_0
+ - cudnn=7.6.5=cuda10.1_0
+ - cupti=10.1.168=0
+ - dbus=1.13.12=h746ee38_0
+ - decorator=4.4.1=py_0
+ - defusedxml=0.6.0=py_0
+ - entrypoints=0.3=py37_0
+ - expat=2.2.6=he6710b0_0
+ - fontconfig=2.13.0=h9420a91_0
+ - freetype=2.9.1=h8a8886c_1
+ - fribidi=1.0.5=h7b6447c_0
+ - gast=0.2.2=py37_0
+ - glib=2.63.1=h5a9c865_0
+ - gmp=6.1.2=h6c8ec71_1
+ - google-auth=1.11.2=py_0
+ - google-auth-oauthlib=0.4.1=py_2
+ - google-pasta=0.1.8=py_0
+ - graphite2=1.3.13=h23475e2_0
+ - graphviz=2.40.1=h21bd128_2
+ - grpcio=1.27.2=py37hf8bcb03_0
+ - gst-plugins-base=1.14.0=hbbd80ab_1
+ - gstreamer=1.14.0=hb453b48_1
+ - h5py=2.10.0=py37h7918eee_0
+ - harfbuzz=1.8.8=hffaf4a1_0
+ - hdf5=1.10.4=hb1b8bf9_0
+ - icu=58.2=h9c2bf20_1
+ - idna=2.8=py37_0
+ - importlib_metadata=1.4.0=py37_0
+ - intel-openmp=2020.0=166
+ - ipykernel=5.1.4=py37h39e3cac_0
+ - ipython=7.11.1=py37h39e3cac_0
+ - ipython_genutils=0.2.0=py37_0
+ - ipywidgets=7.5.1=py_0
+ - jedi=0.16.0=py37_0
+ - jinja2=2.11.1=py_0
+ - joblib=0.14.1=py_0
+ - jpeg=9b=h024ee3a_2
+ - jsonschema=3.2.0=py37_0
+ - jupyter=1.0.0=py37_7
+ - jupyter_client=5.3.4=py37_0
+ - jupyter_console=6.1.0=py_0
+ - jupyter_core=4.6.1=py37_0
+ - keras-applications=1.0.8=py_0
+ - keras-preprocessing=1.1.0=py_1
+ - libedit=3.1.20181209=hc058e9b_0
+ - libffi=3.2.1=hd88cf55_4
+ - libgcc-ng=9.1.0=hdf63c60_0
+ - libgfortran-ng=7.3.0=hdf63c60_0
+ - libpng=1.6.37=hbc83047_0
+ - libprotobuf=3.11.4=hd408876_0
+ - libsodium=1.0.16=h1bed415_0
+ - libstdcxx-ng=9.1.0=hdf63c60_0
+ - libtiff=4.1.0=h2733197_0
+ - libuuid=1.0.3=h1bed415_2
+ - libxcb=1.13=h1bed415_1
+ - libxml2=2.9.9=hea5a465_1
+ - markdown=3.1.1=py37_0
+ - markupsafe=1.1.1=py37h7b6447c_0
+ - meson=0.52.0=py_0
+ - mistune=0.8.4=py37h7b6447c_0
+ - mkl=2020.0=166
+ - mkl-service=2.3.0=py37he904b0f_0
+ - mkl_fft=1.0.15=py37ha843d7b_0
+ - mkl_random=1.1.0=py37hd6b4f25_0
+ - more-itertools=8.0.2=py_0
+ - nb_conda_kernels=2.2.2=py37_0
+ - nbconvert=5.6.1=py37_0
+ - nbformat=5.0.4=py_0
+ - ncurses=6.1=he6710b0_1
+ - ninja=1.9.0=py37hfd86e86_0
+ - notebook=6.0.3=py37_0
+ - numpy=1.18.1=py37h4f9e942_0
+ - numpy-base=1.18.1=py37hde5b4d6_1
+ - oauthlib=3.1.0=py_0
+ - olefile=0.46=py37_0
+ - openssl=1.1.1f=h7b6447c_0
+ - opt_einsum=3.1.0=py_0
+ - pandoc=2.2.3.2=0
+ - pandocfilters=1.4.2=py37_1
+ - pango=1.42.4=h049681c_0
+ - parso=0.6.0=py_0
+ - pcre=8.43=he6710b0_0
+ - pexpect=4.8.0=py37_0
+ - pickleshare=0.7.5=py37_0
+ - pillow=7.0.0=py37hb39fc2d_0
+ - pip=20.0.2=py37_1
+ - pixman=0.38.0=h7b6447c_0
+ - prometheus_client=0.7.1=py_0
+ - prompt_toolkit=3.0.3=py_0
+ - protobuf=3.11.4=py37he6710b0_0
+ - ptyprocess=0.6.0=py37_0
+ - pyasn1=0.4.8=py_0
+ - pyasn1-modules=0.2.7=py_0
+ - pycosat=0.6.3=py37h7b6447c_0
+ - pycparser=2.19=py37_0
+ - pydot=1.4.1=py37_0
+ - pygments=2.5.2=py_0
+ - pyjwt=1.7.1=py37_0
+ - pyopenssl=19.1.0=py37_0
+ - pyparsing=2.4.6=py_0
+ - pyqt=5.9.2=py37h05f1152_2
+ - pyrsistent=0.15.7=py37h7b6447c_0
+ - pysocks=1.7.1=py37_0
+ - python=3.7.4=h265db76_1
+ - python-dateutil=2.8.1=py_0
+ - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0
+ - pyyaml=5.3=py37h7b6447c_0
+ - pyzmq=18.1.1=py37he6710b0_0
+ - qt=5.9.7=h5867ecd_1
+ - qtconsole=4.6.0=py_1
+ - readline=7.0=h7b6447c_5
+ - requests=2.22.0=py37_1
+ - requests-oauthlib=1.3.0=py_0
+ - rsa=4.0=py_0
+ - ruamel_yaml=0.15.87=py37h7b6447c_0
+ - scikit-learn=0.22.1=py37hd81dba3_0
+ - scipy=1.4.1=py37h0b6359f_0
+ - send2trash=1.5.0=py37_0
+ - setuptools=45.1.0=py37_0
+ - sip=4.19.8=py37hf484d3e_0
+ - six=1.14.0=py37_0
+ - sqlite=3.30.1=h7b6447c_0
+ - tensorboard=2.1.0=py3_0
+ - tensorflow=2.1.0=gpu_py37h7a4bb67_0
+ - tensorflow-base=2.1.0=gpu_py37h6c5654b_0
+ - tensorflow-estimator=2.1.0=pyhd54b08b_0
+ - tensorflow-gpu=2.1.0=h0d30ee6_0
+ - termcolor=1.1.0=py37_1
+ - terminado=0.8.3=py37_0
+ - testpath=0.4.4=py_0
+ - tk=8.6.8=hbc83047_0
+ - torchvision=0.5.0=py37_cu101
+ - tornado=6.0.3=py37h7b6447c_0
+ - tqdm=4.42.0=py_0
+ - traitlets=4.3.3=py37_0
+ - urllib3=1.25.8=py37_0
+ - wcwidth=0.1.9=py_0
+ - webencodings=0.5.1=py37_1
+ - werkzeug=1.0.0=py_0
+ - wheel=0.34.1=py37_0
+ - widgetsnbextension=3.5.1=py37_0
+ - wrapt=1.11.2=py37h7b6447c_0
+ - xz=5.2.4=h14c3975_4
+ - yaml=0.1.7=had09818_2
+ - zeromq=4.3.1=he6710b0_3
+ - zipp=2.2.0=py_0
+ - zlib=1.2.11=h7b6447c_3
+ - zstd=1.3.7=h0b5b093_0
+ - pip:
+ - cycler==0.10.0
+ - humanize==2.4.0
+ - kiwisolver==1.2.0
+ - matplotlib==3.2.1
+ - natsort==7.0.1
+ - pandas==1.0.3
+ - python-chess==0.30.1
+ - pytz==2019.3
+ - seaborn==0.10.0
+ - tensorboardx==2.0
+
diff --git a/images/kdd_indiv_final.jpg b/images/kdd_indiv_final.jpg
new file mode 100644
index 0000000..9c5a0f4
Binary files /dev/null and b/images/kdd_indiv_final.jpg differ
diff --git a/models/maia-1900/ckpt/checkpoint b/models/maia-1900/ckpt/checkpoint
new file mode 100755
index 0000000..8533215
--- /dev/null
+++ b/models/maia-1900/ckpt/checkpoint
@@ -0,0 +1,82 @@
+model_checkpoint_path: "ckpt-40"
+all_model_checkpoint_paths: "ckpt-1"
+all_model_checkpoint_paths: "ckpt-2"
+all_model_checkpoint_paths: "ckpt-3"
+all_model_checkpoint_paths: "ckpt-4"
+all_model_checkpoint_paths: "ckpt-5"
+all_model_checkpoint_paths: "ckpt-6"
+all_model_checkpoint_paths: "ckpt-7"
+all_model_checkpoint_paths: "ckpt-8"
+all_model_checkpoint_paths: "ckpt-9"
+all_model_checkpoint_paths: "ckpt-10"
+all_model_checkpoint_paths: "ckpt-11"
+all_model_checkpoint_paths: "ckpt-12"
+all_model_checkpoint_paths: "ckpt-13"
+all_model_checkpoint_paths: "ckpt-14"
+all_model_checkpoint_paths: "ckpt-15"
+all_model_checkpoint_paths: "ckpt-16"
+all_model_checkpoint_paths: "ckpt-17"
+all_model_checkpoint_paths: "ckpt-18"
+all_model_checkpoint_paths: "ckpt-19"
+all_model_checkpoint_paths: "ckpt-20"
+all_model_checkpoint_paths: "ckpt-21"
+all_model_checkpoint_paths: "ckpt-22"
+all_model_checkpoint_paths: "ckpt-23"
+all_model_checkpoint_paths: "ckpt-24"
+all_model_checkpoint_paths: "ckpt-25"
+all_model_checkpoint_paths: "ckpt-26"
+all_model_checkpoint_paths: "ckpt-27"
+all_model_checkpoint_paths: "ckpt-28"
+all_model_checkpoint_paths: "ckpt-29"
+all_model_checkpoint_paths: "ckpt-30"
+all_model_checkpoint_paths: "ckpt-31"
+all_model_checkpoint_paths: "ckpt-32"
+all_model_checkpoint_paths: "ckpt-33"
+all_model_checkpoint_paths: "ckpt-34"
+all_model_checkpoint_paths: "ckpt-35"
+all_model_checkpoint_paths: "ckpt-36"
+all_model_checkpoint_paths: "ckpt-37"
+all_model_checkpoint_paths: "ckpt-38"
+all_model_checkpoint_paths: "ckpt-39"
+all_model_checkpoint_paths: "ckpt-40"
+all_model_checkpoint_timestamps: 1580106783.8790061
+all_model_checkpoint_timestamps: 1580113034.5215666
+all_model_checkpoint_timestamps: 1580119167.9981554
+all_model_checkpoint_timestamps: 1580125270.5550704
+all_model_checkpoint_timestamps: 1580131382.6197543
+all_model_checkpoint_timestamps: 1580138060.0350215
+all_model_checkpoint_timestamps: 1580144931.4751053
+all_model_checkpoint_timestamps: 1580151357.3907902
+all_model_checkpoint_timestamps: 1580157406.0482683
+all_model_checkpoint_timestamps: 1580163445.5980349
+all_model_checkpoint_timestamps: 1580169474.1105049
+all_model_checkpoint_timestamps: 1580175510.0387604
+all_model_checkpoint_timestamps: 1580181567.815861
+all_model_checkpoint_timestamps: 1580187622.8185244
+all_model_checkpoint_timestamps: 1580193674.1944962
+all_model_checkpoint_timestamps: 1580199721.2665217
+all_model_checkpoint_timestamps: 1580205792.755944
+all_model_checkpoint_timestamps: 1580211859.5465987
+all_model_checkpoint_timestamps: 1580217928.1305025
+all_model_checkpoint_timestamps: 1580223989.668282
+all_model_checkpoint_timestamps: 1580231494.4801118
+all_model_checkpoint_timestamps: 1580240895.8979034
+all_model_checkpoint_timestamps: 1580250465.895426
+all_model_checkpoint_timestamps: 1580259628.7052832
+all_model_checkpoint_timestamps: 1580268883.0895178
+all_model_checkpoint_timestamps: 1580278314.7480402
+all_model_checkpoint_timestamps: 1580288003.8131309
+all_model_checkpoint_timestamps: 1580297809.2752874
+all_model_checkpoint_timestamps: 1580307735.15046
+all_model_checkpoint_timestamps: 1580318164.597156
+all_model_checkpoint_timestamps: 1580328825.4124599
+all_model_checkpoint_timestamps: 1580339783.5046844
+all_model_checkpoint_timestamps: 1580347138.0900939
+all_model_checkpoint_timestamps: 1580354427.078483
+all_model_checkpoint_timestamps: 1580360702.8677912
+all_model_checkpoint_timestamps: 1580366508.5701687
+all_model_checkpoint_timestamps: 1580372158.3093505
+all_model_checkpoint_timestamps: 1580377816.579277
+all_model_checkpoint_timestamps: 1580383466.9756734
+all_model_checkpoint_timestamps: 1580389118.3248632
+last_preserved_timestamp: 1580099931.4647074
diff --git a/models/maia-1900/ckpt/ckpt-40-400000.pb.gz b/models/maia-1900/ckpt/ckpt-40-400000.pb.gz
new file mode 100755
index 0000000..52d22a8
Binary files /dev/null and b/models/maia-1900/ckpt/ckpt-40-400000.pb.gz differ
diff --git a/models/maia-1900/ckpt/ckpt-40.data-00000-of-00002 b/models/maia-1900/ckpt/ckpt-40.data-00000-of-00002
new file mode 100755
index 0000000..6cdb51a
Binary files /dev/null and b/models/maia-1900/ckpt/ckpt-40.data-00000-of-00002 differ
diff --git a/models/maia-1900/ckpt/ckpt-40.data-00001-of-00002 b/models/maia-1900/ckpt/ckpt-40.data-00001-of-00002
new file mode 100755
index 0000000..2268437
Binary files /dev/null and b/models/maia-1900/ckpt/ckpt-40.data-00001-of-00002 differ
diff --git a/models/maia-1900/ckpt/ckpt-40.index b/models/maia-1900/ckpt/ckpt-40.index
new file mode 100755
index 0000000..f102c21
Binary files /dev/null and b/models/maia-1900/ckpt/ckpt-40.index differ
diff --git a/models/maia-1900/config.yaml b/models/maia-1900/config.yaml
new file mode 100755
index 0000000..68f2f41
--- /dev/null
+++ b/models/maia-1900/config.yaml
@@ -0,0 +1,11 @@
+%YAML 1.2
+---
+name: final_maia_1900
+display_name: Final Maia 1900
+engine: lc0_23
+options:
+  nodes: 1
+  weightsPath: final_1900-40.pb.gz
+  movetime: 10
+  threads: 8
+...
diff --git a/models/maia-1900/final_1900-40.pb.gz b/models/maia-1900/final_1900-40.pb.gz
new file mode 100755
index 0000000..52d22a8
Binary files /dev/null and b/models/maia-1900/final_1900-40.pb.gz differ
diff --git a/setup.py b/setup.py
new file mode 100755
index 0000000..e93ecff
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,25 @@
+from setuptools import setup, find_packages
+import re
+
+with open('backend/__init__.py') as f:
+    versionString = re.search(r"__version__ = '(.+)'", f.read()).group(1)
+
+if __name__ == '__main__':
+    setup(name='backend',
+          version=versionString,
+          author="Anon",
+          author_email="anon@anon",
+          packages=find_packages(),
+          install_requires=[
+              'numpy',
+              'matplotlib',
+              'pandas',
+              'seaborn',
+              'python-chess>=0.30.0',
+              'pytz',
+              'natsort',
+              'humanize',
+              'pyyaml',
+              'tensorboardX',
+          ],
+          )