-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathothello_utils.py
93 lines (76 loc) · 2.99 KB
/
othello_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch as t
from datasets import load_dataset
from othello_engine_utils import OthelloBoardState, stoi, itos
def board_state_to_RRC(board_state, flip: int = 1):
    """One-hot encode a single board into an ``(8, 8, 3)`` int8 tensor.

    The board is copied into an int8 tensor and multiplied by ``flip`` before
    encoding, so ``flip=-1`` swaps the roles of the ``-1`` and ``1`` cells.

    Channels of the result, per cell of the (flipped) board:
        0 -> cell equals -1
        1 -> cell equals 0
        2 -> cell equals 1

    Args:
        board_state: 8x8 array-like of cell values (anything ``t.tensor``
            accepts).
        flip: multiplier applied to every cell before encoding.

    Returns:
        int8 tensor of shape ``(8, 8, 3)``.
    """
    signed_board = t.tensor(board_state, dtype=t.int8)
    signed_board *= flip

    encoded = t.zeros((8, 8, 3), dtype=t.int8)
    # Channel order is fixed: -1, then 0, then 1.
    for channel, cell_value in enumerate((-1, 0, 1)):
        encoded[..., channel] = (signed_board == cell_value).int()
    return encoded
# TODO: Factor the per-game replay loop shared by the three games_batch_to_* functions below into a single helper.
def games_batch_to_state_stack_BLRRC(batch_str_moves):
    """Replay a batch of games and one-hot encode the board after every move.

    Each game is replayed on a fresh ``OthelloBoardState``: every move is fed
    to ``umpire`` and the resulting ``board.state`` is encoded with
    ``board_state_to_RRC``.

    Args:
        batch_str_moves: iterable of games, each a sequence of moves in the
            dataset format. Tensor games are flattened before iteration.
            All games must have the same length so they can be stacked.

    Returns:
        Tensor of shape ``(batch, seq_len, 8, 8, 3)``.
    """
    per_game = []
    for game in batch_str_moves:
        move_seq = game.flatten() if isinstance(game, t.Tensor) else game
        engine = OthelloBoardState()
        snapshots = []
        for mv in move_seq:
            engine.umpire(mv)
            snapshots.append(board_state_to_RRC(engine.state))
        per_game.append(t.stack(snapshots, dim=0))
    return t.stack(per_game, dim=0)
def games_batch_to_valid_moves_BLRRC(batch_str_moves):
    """Replay a batch of games and mark the valid moves after every move.

    Each game is replayed on a fresh ``OthelloBoardState``. After each move is
    fed to ``umpire``, ``get_valid_moves()`` is queried and every returned
    square is set to 1 in an otherwise-zero board mask. Squares are treated
    as flat indices: ``row = square // 8``, ``col = square % 8``.

    Args:
        batch_str_moves: iterable of games, each a sequence of moves in the
            dataset format. Tensor games are flattened before iteration.
            All games must have the same length so they can be stacked.

    Returns:
        int8 tensor of shape ``(batch, seq_len, 8, 8, 1)``.
    """
    game_stack = []
    for game in batch_str_moves:
        move_seq = game.flatten() if isinstance(game, t.Tensor) else game
        board = OthelloBoardState()
        states = []
        for move in move_seq:
            moves_board = t.zeros(8, 8, 1, dtype=t.int8)
            board.umpire(move)
            # Distinct name for the inner loop so the played move is not shadowed.
            for square in board.get_valid_moves():
                moves_board[square // 8, square % 8] = 1
            states.append(moves_board)
        game_stack.append(t.stack(states, dim=0))
    return t.stack(game_stack, dim=0)
def games_batch_to_state_stack_mine_yours_BLRRC(batch_str_moves):
    """Replay a batch of games and one-hot encode boards with alternating sign.

    Each game is replayed on a fresh ``OthelloBoardState``. After the move at
    index ``i`` is fed to ``umpire``, ``board.state`` is encoded with
    ``board_state_to_RRC`` using ``flip=-1`` when ``i`` is odd and ``flip=1``
    when ``i`` is even, so channels 0 and 2 swap on odd steps.

    Args:
        batch_str_moves: iterable of games, each a sequence of moves in the
            dataset format. Tensor games are flattened before iteration.
            All games must have the same length so they can be stacked.

    Returns:
        Tensor of shape ``(batch, seq_len, 8, 8, 3)``.
    """
    per_game = []
    for game in batch_str_moves:
        move_seq = game.flatten() if isinstance(game, t.Tensor) else game
        engine = OthelloBoardState()
        snapshots = []
        for step, mv in enumerate(move_seq):
            sign = -1 if step % 2 == 1 else 1
            engine.umpire(mv)
            snapshots.append(board_state_to_RRC(engine.state, sign))
        per_game.append(t.stack(snapshots, dim=0))
    return t.stack(per_game, dim=0)
# Names of the batch-conversion functions defined in this module, in a fixed order.
othello_functions = [
    fn.__name__
    for fn in (
        games_batch_to_state_stack_BLRRC,
        games_batch_to_state_stack_mine_yours_BLRRC,
        games_batch_to_valid_moves_BLRRC,
    )
]
def get_othello_even_list_indices(tokens_list: list[int]) -> list[int]:
    """Return the even positions ``0, 2, 4, ...`` of ``tokens_list``.

    Only the length of ``tokens_list`` is used; its contents are ignored.

    Args:
        tokens_list: sequence whose positions are being enumerated.

    Returns:
        Every even index ``i`` with ``0 <= i < len(tokens_list)``, ascending.
        Empty when ``tokens_list`` is empty.
    """
    return list(range(0, len(tokens_list), 2))
def get_othello_all_list_indices(tokens_list: list[int]) -> list[int]:
    """Return every position ``0, 1, 2, ...`` of ``tokens_list``.

    Only the length of ``tokens_list`` is used; its contents are ignored.

    Args:
        tokens_list: sequence whose positions are being enumerated.

    Returns:
        All indices ``i`` with ``0 <= i < len(tokens_list)``, ascending.
        Empty when ``tokens_list`` is empty.
    """
    return list(range(len(tokens_list)))