-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathMCTS.py
159 lines (134 loc) · 4.58 KB
/
MCTS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#MCTS
import copy
import math
import random
from utils import game_result,get_childs
from game import GameSettings
def simulate(game, player_i):
    """Play uniformly random moves on ``game`` until the game is decided.

    The board is modified in place, so callers should pass a copy
    (``Node.simulate`` passes a ``copy.deepcopy`` of its board).

    Args:
        game: Board indexed as ``game[i][j]``; mutated in place.
        player_i: Index (0 or 1) into ``GameSettings.players`` of the player
            who moves first.

    Returns:
        The first non-``None`` value reported by ``game_result`` for the
        board. Callers in this file treat it as the winning player's index,
        with ``-1`` meaning a draw.
    """
    # Inspect the position before moving: a board that is already decided
    # must not receive further moves, which could otherwise create a second
    # winning line and change the reported result.
    result = game_result(game)
    while result is None:
        choices = get_childs(game)
        if not choices:
            # NOTE(review): if game_result returns None for a board with no
            # legal moves left, this loop never terminates -- confirm that
            # utils.game_result reports a result for full boards.
            result = game_result(game)
            continue
        i, j = random.choice(choices)
        game[i][j] = GameSettings.players[player_i]
        result = game_result(game)
        player_i = 1 - player_i  # players alternate turns
    return result
class Node:
    """A node of the Monte Carlo search tree.

    A node holds one board position and the index of the player to move in
    it (``player_i``). The ``wins`` / ``losses`` / ``draws`` counters are
    credited from the point of view of that player -- the one about to move
    at this node, not the one whose move produced it.
    """

    # Exploration constant of the UCB1 formula (sqrt(2)).
    c = 2 ** .5

    def __init__(self, game, player_i, pos=None, parent=None) -> None:
        """
        Args:
            game: Board position held by this node (stored as-is, not copied).
            player_i: Index (0 or 1) of the player to move in ``game``.
            pos: ``(i, j)`` move that produced this position; ``None`` for
                the root.
            parent: Parent node, or ``None`` for the root.
        """
        self.wins = 0
        self.losses = 0
        self.draws = 0
        self.n_sims = 0
        self.game = game
        self.parent = parent
        self.pos = pos
        self.childs = []
        self.player_i = player_i

    def __repr__(self):
        # The root has no producing move (pos is None), so it is labelled
        # "root" instead of being indexed. The 1..9 cell number and the
        # occupied-cell count (9 - legal moves) assume a 3x3 board.
        if self.pos is None:
            cell = "root"
        else:
            cell = self.pos[0] * 3 + self.pos[1] + 1
        mark = GameSettings.players[self.player_i]
        return (f"Node<{cell}/{9 - len(self.get_childs())}>"
                f"({mark}) [{self.wins}/{self.n_sims}]")

    def get_childs(self):
        """Return the legal ``(i, j)`` moves of this node's board."""
        return get_childs(self.game)

    def game_result(self):
        """Return ``utils.game_result`` for this node's board."""
        return game_result(self.game)

    def won(self):
        """Record a win for ``player_i``."""
        self.wins += 1

    def drawn(self):
        """Record a draw."""
        self.draws += 1

    def lost(self):
        """Record a loss for ``player_i``."""
        self.losses += 1

    def simulate(self):
        """Run one random playout from a copy of this node's board.

        Returns:
            1 if ``player_i`` wins, -1 if the other player wins, 0 otherwise.
        """
        result = simulate(copy.deepcopy(self.game), self.player_i)
        if result == self.player_i:
            return 1
        elif result == 1 - self.player_i:
            return -1
        return 0

    def get_sim_score(self):
        """Return the win count (indirection kept so the score can be changed)."""
        return self.wins

    def getUCB(self, t):
        """Return the UCB1 score of this node from ``player_i``'s perspective.

        Args:
            t: Visit count used in the exploration (log) term, normally the
                parent's ``n_sims``.

        Returns:
            ``inf`` for an unvisited node, otherwise
            ``wins / n_sims + c * sqrt(ln(t) / n_sims)``.

        Because ``wins`` belongs to the player to move here, a parent ranking
        its children with this score is ranking them by its opponent's
        success.
        """
        if self.n_sims == 0:
            return float('inf')
        exploit_score = self.wins / self.n_sims
        return exploit_score + self.c * (math.log(t) / self.n_sims) ** .5

    def expand(self):
        """Add one child per legal move that does not already have a child.

        Each child gets a deep copy of the board with the current player's
        mark at the move, and the other player to move.

        Returns:
            True (unconditionally).
        """
        # Positions are normalised to tuples so the membership test works
        # whether get_childs yields tuples or lists; skipping existing moves
        # keeps repeated calls from duplicating children.
        existing = {tuple(child.pos) for child in self.childs}
        next_player_i = 1 - self.player_i
        for i, j in self.get_childs():
            if (i, j) in existing:
                continue
            game = copy.deepcopy(self.game)
            game[i][j] = GameSettings.players[self.player_i]
            self.childs.append(Node(game, next_player_i, (i, j), self))
        return True
def MCTS_sim(root:Node,n_iters:int):
t=0
root.expand()
while t<n_iters:
#selection
node = root
while True:
childs = node.childs
if len(childs)==0:
break
max_childs=[]
max_val=childs[0].getUCB(node.n_sims)
for nd in childs:
val = nd.getUCB(node.n_sims)
if val>max_val:
max_val=val
max_childs=[nd]
elif val==max_val:
max_childs.append(nd)
# print("selection:",end='')
# print(max_val,max_childs)
node = random.choice(max_childs)
#expansion
#this extra check was "required" so that all no unsimulated nodes would be
#expanded, which increases the breadth of search region
if node.n_sims>0:
node.expand()
else:
child = node
#simulation
childs2 = node.childs
if len(childs2)>0:
child = random.choice(childs2)
result = child.simulate()
elif node.n_sims==0:
result = child.simulate()
else:
result = 0 #draw
child = node
#BackPropagation
parent = child
while parent is not None:
match result:
case 1:
parent.won()
case -1:
parent.lost()
case 0:
parent.drawn()
parent.n_sims+=1
parent = parent.parent
result = -result # win for child is loss for parent and so on...
t+=1
#get UCB position
priority_childs = sorted(root.childs,key=lambda nd: nd.get_sim_score(),reverse=False)
if len(priority_childs)>0:
max_wins = priority_childs[0].get_sim_score()
priority_childs = [child for child in priority_childs if child.get_sim_score()==max_wins]
# print(root.childs)
return priority_childs