model.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from config import config
class RLbase(nn.Module):
    """Shared trunk feeding separate actor and critic heads, all supplied by config."""

    def __init__(self) -> None:
        super().__init__()
        self.common_network = config["common_network"]
        self.actor_network = config["actor_network"]
        self.critic_network = config["critic_network"]

    def get_device(self):
        return "cuda" if torch.cuda.is_available() else "cpu"

    def forward(self, state):
        # Shared features feed both heads: an action distribution and a state-value estimate.
        x = self.common_network(state)
        return self.actor_network(x), self.critic_network(x)
class Actor_Critic():
    def __init__(self, base) -> None:
        self.network = base
        # Keep the parameters on the same device that sample() sends inputs to.
        self.network.to(self.network.get_device())
        # Optimizer class (looked up by name) and hyperparameters come from config.
        self.optimizer = getattr(optim, config["optimizer"])(self.network.parameters(), **config["optim_hparas"])

    def sample(self, state):
        # Actor head: action probabilities; critic head: state-value estimate.
        action_prob, value_estimate = self.network(torch.FloatTensor(state).to(self.network.get_device()))
        action_dist = Categorical(action_prob)
        action = action_dist.sample()
        # Return the chosen action, its probability, and the critic's value estimate.
        return action.item(), torch.exp(action_dist.log_prob(action)), value_estimate

    def learn(self, ac_losses, cr_losses):
        # Average the summed actor and critic losses over a batch of episodes,
        # then take a single optimizer step.
        loss = 0
        for i in range(config["episode_per_batch"]):
            loss += (ac_losses[i] + cr_losses[i]).sum()
        loss /= config["episode_per_batch"]
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def save(self, PATH, progress, total_rewards, total_losses):
        # Checkpoint model and optimizer state together with the training history.
        torch.save({
            "actor_critic_net": self.network.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "total_rewards": total_rewards,
            "total_losses": total_losses,
            "progress": progress
        }, PATH)

    def load(self, PATH):
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        checkpoint = torch.load(PATH, map_location=self.network.get_device())
        self.network.load_state_dict(checkpoint["actor_critic_net"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        return checkpoint["progress"], checkpoint["total_rewards"], checkpoint["total_losses"]
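
For context, a minimal sketch of how this file might be wired up: an illustrative config (only the key names come from model.py above) and a toy training loop that builds the per-episode ac_losses / cr_losses that learn() expects. The layer sizes, DummyEnv, run_episode, discount factor, and loss formulas are assumptions for demonstration, not part of this repository.

# --- Hypothetical contents of config.py (sizes and hyperparameters are assumptions) ---
import torch.nn as nn

config = {
    "common_network": nn.Sequential(nn.Linear(4, 64), nn.ReLU()),          # assumes a 4-dim state
    "actor_network": nn.Sequential(nn.Linear(64, 2), nn.Softmax(dim=-1)),  # assumes 2 discrete actions
    "critic_network": nn.Linear(64, 1),
    "optimizer": "Adam",
    "optim_hparas": {"lr": 1e-3},
    "episode_per_batch": 4,
}

# --- Hypothetical training sketch (environment, discounting, and loss formulas are assumptions) ---
import torch
from config import config
from model import RLbase, Actor_Critic

class DummyEnv:
    # Stand-in for a Gym-style environment, only here to make the sketch runnable.
    def reset(self):
        self.t = 0
        return [0.0, 0.0, 0.0, 0.0]

    def step(self, action):
        self.t += 1
        state = [0.1 * self.t] * 4
        reward = 1.0
        done = self.t >= 10
        return state, reward, done

def run_episode(agent, env, gamma=0.99):
    # Collect one episode and build per-step actor and critic losses
    # in the shape that Actor_Critic.learn() expects.
    state = env.reset()
    log_probs, values, rewards = [], [], []
    done = False
    while not done:
        action, prob, value = agent.sample(state)
        state, reward, done = env.step(action)
        log_probs.append(torch.log(prob))
        values.append(value.squeeze())
        rewards.append(reward)
    values = torch.stack(values)
    log_probs = torch.stack(log_probs)
    # Discounted returns, accumulated backwards through the episode.
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.insert(0, g)
    returns = torch.tensor(returns, device=values.device)
    advantages = returns - values.detach()
    actor_loss = -log_probs * advantages      # policy-gradient term
    critic_loss = (values - returns) ** 2     # value-regression term
    return actor_loss, critic_loss

base = RLbase()            # classes defined in model.py above
agent = Actor_Critic(base)
env = DummyEnv()
ac_losses, cr_losses = [], []
for _ in range(config["episode_per_batch"]):
    a_loss, c_loss = run_episode(agent, env)
    ac_losses.append(a_loss)
    cr_losses.append(c_loss)
agent.learn(ac_losses, cr_losses)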