import numpy as np
import random, copy
import torch
import torch.optim as optim
import torch.nn.functional as F
from dqn import DQN
from nash_dqn import NashDQNBase
from equilibrium_solver import NashEquilibriumECOSSolver


class NashDQNExploiter(DQN):
    """
    Nash-DQN algorithm with an additional exploiter network that learns the best
    response against the first player's Nash strategy.
    """
    def __init__(self, env, args):
        super().__init__(env, args)
        self.num_envs = args.num_envs
        if args.num_process > 1:
            self.normal_nashQ.share_memory()
            self.exploiter_nashQ.share_memory()
            self.normal_nashQ_target.share_memory()
            self.exploiter_nashQ_target.share_memory()
        self.num_agents = env.num_agents[0] if isinstance(env.num_agents, list) else env.num_agents
        self._init_optimizer(args)
        self.exploiter_update_itr = args.algorithm_spec['exploiter_update_itr']
        self.env = env
        self.args = args

    def _init_model(self, env, args):
        """Overwrite DQN's models.

        :param env: environment
        :type env: object
        :param args: arguments
        :type args: dict
        """
        self.normal_nashQ = NashDQNBase(env, args.net_architecture, args.num_envs, two_side_obs=args.marl_spec['global_state']).to(self.device)
        self.exploiter_nashQ = NashDQNBase(env, args.net_architecture, args.num_envs, two_side_obs=args.marl_spec['global_state']).to(self.device)
        self.normal_nashQ_target = copy.deepcopy(self.normal_nashQ).to(self.device)
        self.exploiter_nashQ_target = copy.deepcopy(self.exploiter_nashQ).to(self.device)
        self.model = [self.normal_nashQ, self.exploiter_nashQ]
        self.target = [self.normal_nashQ_target, self.exploiter_nashQ_target]

    def _init_optimizer(self, args):
        self.normal_optimizer = optim.Adam(self.normal_nashQ.parameters(), lr=float(args.learning_rate))
        self.exploiter_optimizer = optim.Adam(self.exploiter_nashQ.parameters(), lr=float(args.learning_rate))

    def choose_action(self, state, Greedy=False, epsilon=None):
        if Greedy:
            epsilon = 0.
        elif epsilon is None:
            epsilon = self.epsilon_scheduler.get_epsilon()
        if not isinstance(state, torch.Tensor):
            state = torch.Tensor(state).to(self.device)
        if self.args.ram:
            if self.num_envs == 1:  # state: (agents, state_dim)
                state = state.unsqueeze(0).view(1, -1)  # (agents, state_dim) -> (1, agents*state_dim)
            else:  # state: (agents, envs, state_dim)
                state = torch.transpose(state, 0, 1)  # -> (envs, agents, state_dim)
                state = state.view(state.shape[0], -1)  # -> (envs, agents*state_dim)
        else:  # image-based input
            if self.num_envs == 1:  # state: (agents, C, H, W)
                state = state.unsqueeze(0).view(1, -1, state.shape[-2], state.shape[-1])  # -> (1, agents*C, H, W)
            else:  # state: (agents, envs, C, H, W)
                state = torch.transpose(state, 0, 1)  # -> (envs, agents, C, H, W)
                state = state.view(state.shape[0], -1, state.shape[-2], state.shape[-1])  # -> (envs, agents*C, H, W)
        if random.random() > epsilon:  # NoisyNet does not use epsilon-greedy
            with torch.no_grad():
                q_values = self.normal_nashQ(state).detach().cpu().numpy()  # expects state: (batch, agents*state_dim)
                exploiter_q_values = self.exploiter_nashQ(state).detach().cpu().numpy()
            try:  # the Nash computation may fail and would otherwise terminate the process
                actions, dists, ne_vs = self.compute_nash(q_values, exploiter_q_values)
            except Exception:
                print("Invalid Nash computation.")
                actions = np.random.randint(self.action_dim, size=(state.shape[0], self.num_agents))
        else:
            actions = np.random.randint(self.action_dim, size=(state.shape[0], self.num_agents))
        if self.num_envs == 1:
            actions = actions[0]  # single-env case: unwrap the batch dimension
        else:
            actions = np.array(actions).T  # -> (agents, envs, action_dim)
        return actions
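
    # Note on the Q-value layout assumed by compute_nash() and compute_nash_dist() below:
    # each network outputs a flat vector of length action_dim**2 per sample, which is
    # reshaped into an (action_dim, action_dim) payoff table whose rows index the first
    # player's actions and whose columns index the second player's actions.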

    def compute_nash(self, q_values, exploiter_q_values):
        q_tables = q_values.reshape(-1, self.action_dim, self.action_dim)
        exploiter_q_tables = exploiter_q_values.reshape(-1, self.action_dim, self.action_dim)
        all_actions = []
        all_dists = []
        all_ne_values = []
        # all_dists, all_ne_values = NashEquilibriumParallelMWUSolver(q_tables)
        for q_table in q_tables:
            dist, value = NashEquilibriumECOSSolver(q_table)
            all_dists.append(dist)
            all_ne_values.append(value)
        for ne, eqs in zip(all_dists, exploiter_q_tables):
            # Nash DQN with exploiter: the first player samples from its Nash strategy,
            # while the second player (exploiter) takes the best response against it
            # under the exploiter Q-table.
            first_player_expected_value = ne[0] @ eqs  # expected exploiter payoff of each second-player action under the first player's Nash strategy
            second_player_best_response = [np.argmin(first_player_expected_value)]
            try:
                sample_hist = np.random.multinomial(1, ne[0])  # one-hot sample from the multinomial
                first_player_sampled_action = np.where(sample_hist > 0)[0]
            except ValueError:
                print('Not a valid distribution from Nash equilibrium solution.')
                first_player_sampled_action = [np.random.randint(self.action_dim)]  # fall back to a uniform random action
            actions = np.array([first_player_sampled_action, second_player_best_response]).reshape(-1)
            all_actions.append(actions)
        return np.array(all_actions), all_dists, all_ne_values
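
    # Illustrative example for compute_nash() above (assumed values, not executed): with a
    # 2x2 exploiter Q-table eqs = [[1., -1.], [-1., 1.]] and a first-player Nash strategy
    # ne[0] = [0.5, 0.5], first_player_expected_value = [0., 0.]; np.argmin then returns
    # column 0 as the second player's best response, while the first player's action is
    # sampled from the [0.5, 0.5] mixture.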

    def compute_nash_dist(self, q_values):
        all_dists = []
        all_ne_values = []
        q_tables = q_values.reshape(-1, self.action_dim, self.action_dim)
        # all_dists, all_ne_values = NashEquilibriumParallelMWUSolver(q_tables)
        for q_table in q_tables:
            dist, value = NashEquilibriumECOSSolver(q_table)
            all_dists.append(dist)
            all_ne_values.append(value)
        return all_dists, all_ne_values

    def update(self):
        state, action, reward, next_state, done = self.buffer.sample(self.batch_size)

        state = torch.FloatTensor(np.float32(state)).to(self.device)
        next_state = torch.FloatTensor(np.float32(next_state)).to(self.device)
        action = torch.LongTensor(action).to(self.device)
        reward = torch.FloatTensor(reward).to(self.device)
        done = torch.FloatTensor(np.float32(done)).to(self.device)

        # Nash DQN update for the normal network
        q_values = self.normal_nashQ(state)
        action_dim = int(np.sqrt(q_values.shape[-1]))  # for the two-symmetric-agent case only
        action = torch.LongTensor([a[0] * action_dim + a[1] for a in action]).to(self.device)  # flatten the joint action into a single index
        # target_next_q_values_ = self.normal_nashQ(next_state)  # or use this one
        target_next_q_values_ = self.normal_nashQ_target(next_state)
        target_next_q_values = target_next_q_values_.detach().cpu().numpy()
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)

        try:  # the Nash computation may fail and would otherwise terminate the process
            nash_dists, next_q_value = self.compute_nash_dist(target_next_q_values)
        except Exception:
            print("Invalid Nash computation.")
            next_q_value = np.zeros_like(reward)
            nash_dists = np.ones((*action.shape, self.action_dim)) / float(self.action_dim)  # fall back to uniform strategies
        next_q_value = torch.FloatTensor(next_q_value).to(self.device)
        nash_dists_ = torch.FloatTensor(nash_dists).to(self.device)

        expected_q_value = reward + (self.gamma ** self.multi_step) * next_q_value * (1 - done)
        loss = F.mse_loss(q_value, expected_q_value.detach())
        loss = loss.mean()
        self.normal_optimizer.zero_grad()
        loss.backward()
        self.normal_optimizer.step()

        # Nash DQN update for the exploiter; may run several update steps per call (TODO)
        for _ in range(self.exploiter_update_itr):
            exploiter_q_values = self.exploiter_nashQ(state)
            exploiter_q_value = exploiter_q_values.gather(1, action.unsqueeze(1)).squeeze(1)
            # target_exploiter_next_q_values_ = self.exploiter_nashQ(next_state)  # or use this one
            target_exploiter_next_q_values_ = self.exploiter_nashQ_target(next_state)
            target_exploiter_next_q_values_ = target_exploiter_next_q_values_.reshape(-1, action_dim, action_dim)
            # expected Q of each second-player action under the first player's Nash strategy:
            # (batch, action_dim) x (batch, action_dim, action_dim) -> (batch, action_dim)
            first_player_expected_value = torch.einsum('bj,bjk->bk', nash_dists_[:, 0], target_exploiter_next_q_values_)
            exploiter_next_q_value, _ = torch.min(first_player_expected_value, dim=-1)  # second player takes the best response to the first player
            expected_exploiter_q_value = reward + (self.gamma ** self.multi_step) * exploiter_next_q_value * (1 - done)
            exploiter_loss = F.mse_loss(exploiter_q_value, expected_exploiter_q_value.detach())
            exploiter_loss = exploiter_loss.mean()
            self.exploiter_optimizer.zero_grad()
            exploiter_loss.backward()
            self.exploiter_optimizer.step()

        if self.update_cnt % self.target_update_interval == 0:
            self.update_target(self.normal_nashQ, self.normal_nashQ_target)
            self.update_target(self.exploiter_nashQ, self.exploiter_nashQ_target)
        self.update_cnt += 1
        return loss.item()

    def save_model(self, path):
        try:  # on PyTorch >= 1.7, keep the old serialization so lower versions can load the files
            torch.save(self.normal_nashQ.state_dict(), path + '_normal_model', _use_new_zipfile_serialization=False)
            torch.save(self.normal_nashQ_target.state_dict(), path + '_normal_target', _use_new_zipfile_serialization=False)
            torch.save(self.exploiter_nashQ.state_dict(), path + '_exploiter_model', _use_new_zipfile_serialization=False)
            torch.save(self.exploiter_nashQ_target.state_dict(), path + '_exploiter_target', _use_new_zipfile_serialization=False)
        except TypeError:  # older PyTorch versions do not accept the serialization flag
            torch.save(self.normal_nashQ.state_dict(), path + '_normal_model')
            torch.save(self.normal_nashQ_target.state_dict(), path + '_normal_target')
            torch.save(self.exploiter_nashQ.state_dict(), path + '_exploiter_model')
            torch.save(self.exploiter_nashQ_target.state_dict(), path + '_exploiter_target')

    def load_model(self, path, eval=True):
        self.normal_nashQ.load_state_dict(torch.load(path + '_normal_model'))
        self.normal_nashQ_target.load_state_dict(torch.load(path + '_normal_target'))
        self.exploiter_nashQ.load_state_dict(torch.load(path + '_exploiter_model'))
        self.exploiter_nashQ_target.load_state_dict(torch.load(path + '_exploiter_target'))

        if eval:
            self.normal_nashQ.eval()
            self.normal_nashQ_target.eval()
            self.exploiter_nashQ.eval()
            self.exploiter_nashQ_target.eval()
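

if __name__ == '__main__':
    # Minimal illustrative sketch (not part of the training pipeline): run the same
    # Nash / best-response computation used in compute_nash() on a hand-written zero-sum
    # payoff table. Assumes NashEquilibriumECOSSolver accepts a single 2D payoff table
    # and returns (per-player strategies, game value), as it is used above; the 3x3
    # Q-table below is arbitrary example data (rock-paper-scissors payoffs for player 1).
    toy_q_table = np.array([[ 0.,  1., -1.],
                            [-1.,  0.,  1.],
                            [ 1., -1.,  0.]])
    dist, value = NashEquilibriumECOSSolver(toy_q_table)
    first_player_strategy = np.asarray(dist[0], dtype=np.float64)
    first_player_strategy /= first_player_strategy.sum()  # guard against small numerical drift in the solver output
    first_player_expected_value = first_player_strategy @ toy_q_table  # expected payoff of each second-player action
    second_player_best_response = int(np.argmin(first_player_expected_value))
    first_player_sampled_action = int(np.where(np.random.multinomial(1, first_player_strategy) > 0)[0][0])
    print('Nash strategies:', dist)
    print('Game value:', value)
    print('Sampled action (player 1):', first_player_sampled_action,
          '| Best response (player 2):', second_player_best_response)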