feat: add expectimax search
Quentin18 committed Jan 27, 2024
1 parent 9e0d3a9 commit 81f9670
Showing 3 changed files with 72 additions and 0 deletions.
2 changes: 2 additions & 0 deletions scripts/enjoy.py
@@ -11,6 +11,7 @@
    NTupleNetworkTDPolicy,
    NTupleNetworkTDPolicySmall,
)
from gymnasium_2048.agents.ntuple.search import ExpectimaxSearch


def parse_args() -> argparse.Namespace:
@@ -117,6 +118,7 @@ def enjoy() -> None:
    env = gym.make(args.env, render_mode="human")

    policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)
    policy = ExpectimaxSearch(policy=policy)

    for _ in trange(args.n_episodes, desc="Enjoy"):
        play_game(env=env, policy=policy)
2 changes: 2 additions & 0 deletions scripts/evaluate.py
@@ -11,6 +11,7 @@
    NTupleNetworkTDPolicy,
    NTupleNetworkTDPolicySmall,
)
from gymnasium_2048.agents.ntuple.search import ExpectimaxSearch

plt.style.use("ggplot")

@@ -181,6 +182,7 @@ def evaluate() -> None:
    env = make_env(env_id=args.env)
    if args.algo is not None and args.trained_agent is not None:
        policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)
        policy = ExpectimaxSearch(policy=policy)
    else:
        policy = None

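In both scripts the change is the same two lines: the trained N-tuple policy returned by make_policy is wrapped in ExpectimaxSearch, so action selection runs a shallow expectimax search on top of the learned evaluation. The new class (shown below) also exposes a max_depth argument, which both scripts leave at its default of 3; a deeper search could presumably be requested as in the following sketch, at the cost of slower moves (the max_depth=5 value is purely illustrative, not part of the commit):

policy = make_policy(algo=args.algo, trained_agent=args.trained_agent)
policy = ExpectimaxSearch(policy=policy, max_depth=5)  # illustrative override; the commit keeps the default of 3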
68 changes: 68 additions & 0 deletions src/gymnasium_2048/agents/ntuple/search.py
@@ -0,0 +1,68 @@
import numpy as np

from gymnasium_2048.agents.ntuple.policy import NTupleNetworkBasePolicy
from gymnasium_2048.envs import TwentyFortyEightEnv


class ExpectimaxSearch:
    def __init__(
        self,
        policy: NTupleNetworkBasePolicy,
        max_depth: int = 3,
    ) -> None:
        self.policy = policy
        self.max_depth = max_depth
        self.min_value = 0.0

    def _evaluate(self, state: np.ndarray) -> tuple[float, int]:
        # Score the state with the learned N-tuple network for each of the
        # four moves and return the best value together with its action.
        values = [
            self.policy.evaluate(state=state, action=action) for action in range(4)
        ]
        max_action = np.argmax(values)
        return max(self.min_value, values[max_action]), max_action

    def _maximize(self, state: np.ndarray, depth: int) -> tuple[float, int]:
        # Max node: the player picks the legal move with the highest expected value.
        if depth >= self.max_depth:
            return self._evaluate(state=state)

        max_value = self.min_value
        max_action = 0

        for action in range(4):
            after_state, _, is_legal = TwentyFortyEightEnv.apply_action(
                board=state,
                action=action,
            )
            if not is_legal:
                continue

            value = self._chance(after_state=after_state, depth=depth + 1)
            if value > max_value:
                max_value = value
                max_action = action

        return max_value, max_action

    def _chance(self, after_state: np.ndarray, depth: int) -> float:
        # Chance node: average over every possible tile spawn, weighting a
        # 2-tile (exponent 1) with 0.9 and a 4-tile (exponent 2) with 0.1.
        if depth >= self.max_depth:
            return self._evaluate(state=after_state)[0]

        values, weights = [], []

        for row in range(after_state.shape[0]):
            for col in range(after_state.shape[1]):
                if after_state[row, col] != 0:
                    continue

                for value, prob in ((1, 0.9), (2, 0.1)):
                    # Place the tile in the empty cell, recurse, then undo it.
                    after_state[row, col] = value
                    values.append(self._maximize(state=after_state, depth=depth + 1)[0])
                    weights.append(prob)
                after_state[row, col] = 0

        return np.average(values, weights=weights)

    def predict(self, state: np.ndarray) -> int:
        value, action = self._maximize(state=state, depth=0)
        print(value, action)  # debug output: root value and chosen action
        return action
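For reference, a minimal usage sketch of the new class. It assumes the board is a 4x4 integer array holding tile exponents (1 for a 2-tile, 2 for a 4-tile), which matches the spawn values used by the chance nodes above, and it substitutes a hypothetical duck-typed stand-in for the trained NTupleNetworkBasePolicy, since the search only ever calls policy.evaluate(state=..., action=...):

import numpy as np

from gymnasium_2048.agents.ntuple.search import ExpectimaxSearch


class CountEmptyPolicy:
    # Hypothetical stand-in for a trained policy: any object exposing
    # evaluate(state, action) -> float is enough for ExpectimaxSearch.
    def evaluate(self, state: np.ndarray, action: int) -> float:
        # Crude heuristic: prefer positions with many empty cells.
        return float(np.count_nonzero(state == 0))


# Cells store exponents: 1 -> tile 2, 2 -> tile 4 (assumed encoding).
board = np.array(
    [
        [1, 0, 0, 0],
        [0, 2, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 0],
    ]
)

search = ExpectimaxSearch(policy=CountEmptyPolicy(), max_depth=3)
action = search.predict(state=board)  # index of the chosen move (0-3)

In the scripts above, the real trained policy takes the place of the stand-in, and play_game and evaluate then presumably drive the wrapper through the same predict interface as the bare policy.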
