monitoring.py
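
# Roll out a trained PPO agent with invalid-action masking on
# gym_snakegame/SnakeGame-v0 and record each episode to video.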
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

import gym_snakegame
from gym_snakegame.wrappers import SnakeActionMask
from gymnasium.wrappers import DtypeObservation, RecordVideo, ReshapeObservation, TransformObservation
from gymnasium.spaces import Box
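
# Categorical distribution with invalid-action masking: logits of masked-out
# actions are replaced by a large negative value so they are never sampled,
# and masked entries are excluded from the entropy term.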
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=None):
        if masks is None or len(masks) == 0:
            # No mask given: behave exactly like a plain Categorical.
            self.masks = None
            super().__init__(probs, logits, validate_args)
        else:
            # Push the logits of invalid actions towards -inf so their
            # probability after the softmax is effectively zero.
            self.masks = masks.bool().to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e8, device=device))
            super().__init__(probs, logits, validate_args)

    def entropy(self):
        if self.masks is None:
            return super().entropy()
        # Only valid actions contribute to the entropy.
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.0, device=device))
        return -p_log_p.sum(-1)
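
# Actor-critic network. The convolutional trunk expects
# n_channel x board_size x board_size observations; with a 12x12 board, three
# 3x3 convolutions leave a 6x6 feature map, hence the 64 * 6 * 6 input to the
# linear layer.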
class Agent(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(n_channel, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 6 * 6, 1024),
            nn.ReLU(),
        )
        self.actor = nn.Linear(1024, env.action_space.n)
        self.critic = nn.Linear(1024, 1)

    def get_action(self, x, invalid_action_mask=None):
        hidden = self.network(x)
        logits = self.actor(hidden)
        probs = CategoricalMasked(logits=logits, masks=invalid_action_mask)
        return probs.sample()
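
# Evaluation settings. The checkpoint below was produced by a PPO training run
# with action masking on a 12x12 board (see the path for the run name).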
n_channel = 4
board_size = 12
model_path = (
"runs/gym_snakegame/SnakeGame-v0__ppo_v2_s12_action_mask__1__1710016433/cleanrl_ppo_v2_s12_action_mask_244140.pt"
)
n_episode = 3
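
# Environment pipeline: record every episode to video, expose the invalid-action
# mask via info["action_mask"], cast observations to float32, scale them into
# [0, 1], and add a leading batch dimension so they can be fed to the network.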
env = gym.make("gym_snakegame/SnakeGame-v0", board_size=board_size, n_channel=n_channel, render_mode="rgb_array")
env = RecordVideo(env, ".", episode_trigger=lambda x: True, name_prefix="episode")
env = SnakeActionMask(env)
env = DtypeObservation(env, np.float32)
env = TransformObservation(
env, lambda obs: obs / env.unwrapped.ITEM, Box(0, 1, (n_channel, board_size, board_size), dtype=np.float32)
)
env = ReshapeObservation(env, (1, n_channel, board_size, board_size))
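
# Restore the trained policy and roll out n_episode episodes under inference mode.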
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
agent = Agent().to(device)
agent.load_state_dict(torch.load(model_path, map_location=device))
with torch.inference_mode():
    observation, info = env.reset()
    while n_episode > 0:
        invalid_action_mask = torch.Tensor(info["action_mask"]).to(device)
        action = agent.get_action(torch.Tensor(observation).to(device), invalid_action_mask=invalid_action_mask)
        # env.step() expects a plain integer action, not a tensor.
        observation, reward, terminated, truncated, info = env.step(action.item())
        if terminated or truncated:
            n_episode -= 1
            if n_episode > 0:
                observation, info = env.reset()
env.close()