import copy
import functools
import random
from collections import deque
from pathlib import Path
from typing import List, Union
import gymnasium as gym
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils import tensorboard as tb
from tqdm import tqdm, trange
plt.switch_backend("agg")


class ReplayBuffer:
    def __init__(self, capicity: int) -> None:
        self.capicity = capicity
        self.buffer = deque(maxlen=self.capicity)

    @property
    def size(self):
        return len(self.buffer)

    def push(self, s, a, r, next_s, t):
        # The deque's maxlen already evicts the oldest entry; the explicit pop
        # just keeps the intent obvious.
        if self.size == self.capicity:
            self.buffer.popleft()
        self.buffer.append([s, a, r, next_s, t])

    def is_full(self):
        return self.size == self.capicity

    def sample(self, N: int, device: str):
        """Sample N transitions and pack them into batched tensors."""
        assert N <= self.size, "batch is too big"
        samples = random.sample(self.buffer, N)
        states, actions, rewards, next_states, terminated = zip(*samples)
        return (
            torch.from_numpy(np.vstack(states)).float().to(device),
            torch.from_numpy(np.vstack(actions)).float().to(device),
            torch.from_numpy(np.vstack(rewards)).float().to(device),
            torch.from_numpy(np.vstack(next_states)).float().to(device),
            torch.from_numpy(np.vstack(terminated)).float().to(device),
        )
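
# Shape sketch (assuming the common case of 1-D observations and actions, which
# is not stated in the original file): every tensor returned by
# ReplayBuffer.sample has a leading batch dimension, e.g. for N = 64,
# state_dim = 3, action_dim = 1:
#   states -> (64, 3), actions -> (64, 1), rewards -> (64, 1),
#   next_states -> (64, 3), terminated -> (64, 1)
# np.vstack is what introduces the batch dimension for scalar rewards/flags.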


class Actor(nn.Module):
    """Actor network: maps a state to a continuous action inside the action bounds."""

    def __init__(
        self,
        state_dim,
        hidden_dim,
        action_dim,
        action_scope: Union[List, np.ndarray],
    ):
        super().__init__()
        assert len(action_scope) == action_dim
        self.action_scope = action_scope
        # Network layers
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, action_dim)
        # tanh bounds the output to (-1, +1)
        self.tanh = nn.Tanh()
        # Frozen affine layer mapping (-1, +1) to the environment's action range
        self.map_layer = nn.Linear(action_dim, action_dim)
        self.map_layer.weight.data.copy_(torch.diag(self.action_mapping[:, 0]))
        self.map_layer.bias.data.copy_(self.action_mapping[:, 1])
        self.map_layer.requires_grad_(False)

    def forward(self, state_tensor):
        x = self.fc1(state_tensor)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        x = self.tanh(x)
        x = self.map_layer(x)
        return x

    @property
    @functools.lru_cache
    def action_mapping(self) -> torch.Tensor:
        """Per-dimension (scale, offset) pairs mapping (-1, +1) onto each action range."""
        maps = []
        for x1, x2 in self.action_scope:
            # Solve k*(-1) + b = x1 and k*(+1) + b = x2 for (k, b)
            A = [[-1, 1], [1, 1]]
            B = [x1, x2]
            k, b = np.linalg.solve(A, B)
            maps.append([k, b])
        return torch.tensor(maps).float()
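
# The 2x2 system solved above has a simple closed form (a hand-worked check,
# not additional behaviour): for an action range [low, high],
#     k = (high - low) / 2,   b = (high + low) / 2,
# so a tanh output u in (-1, +1) is mapped to k*u + b in (low, high).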


class Critic(nn.Module):
    """Q network: maps (s, a) to a scalar Q-value."""

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state_tensor, action_tensor):
        """The network takes both state and action, so concatenate them first."""
        x = torch.cat([state_tensor, action_tensor], dim=-1)
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x


class DDPG:
    def __init__(
        self,
        env: str,
        id: int,
        heter: np.ndarray,
        seed: Union[int, None] = None,
        lr: float = 1e-3,
        tau: float = 0.005,
        gamma: float = 0.98,
        hidden_dim: Union[int, list] = [400, 400],
        buffer_capicity: int = 10000,
        buffer_init_ratio: float = 0.30,
        batch_size: int = 64,
        train_batchs: Union[int, None] = None,
        device: str = "cpu",
        save_dir: Union[str, None] = None,
        **kwargs,
    ):
        self.env = gym.make(env, heter=heter, seed=seed)
        self.env_for_test = gym.make(env, heter=heter, is_test=True, seed=seed)
        self.env_name = self.env.spec.id
        self.id = id
        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.shape[0]
        action_scope = list(zip(self.env.action_space.low, self.env.action_space.high))
        # Actor/Critic expect a single hidden width; accept a list such as
        # [400, 400] (both hidden layers share the width) as well as an int.
        if isinstance(hidden_dim, (list, tuple)):
            hidden_dim = hidden_dim[0]
        self.actor = Actor(state_dim, hidden_dim, action_dim, action_scope).to(device)
        self.critic = Critic(state_dim, hidden_dim, action_dim).to(device)
        self.actor_target = Actor(state_dim, hidden_dim, action_dim, action_scope).to(device)
        self.critic_target = Critic(state_dim, hidden_dim, action_dim).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())
        # The actor uses a 10x smaller learning rate than the critic
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr / 10)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr)
        self.tau = tau
        self.gamma = gamma
        self.device = device
        self.replay_buffer = ReplayBuffer(buffer_capicity)
        self.buffer_init_ratio = buffer_init_ratio
        self.batch_size = batch_size
        # Bookkeeping used during training
        self.episode = 0
        self.episode_reward = 0
        self.episode_reward_list = []
        self.episode_len = 0
        self.global_step = 0
        self.total_train_batchs = train_batchs
        assert save_dir is not None
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True)
        self.logger = tb.SummaryWriter(self.save_dir / "log")
        self.first_train = True
        self.state, _ = self.env.reset()

    @torch.no_grad()
    def get_action(self, s: Union[np.ndarray, torch.Tensor], eps: float = 0.0):
        """Get a continuous action; with probability eps take a random exploratory action."""
        if isinstance(s, np.ndarray):
            s = torch.from_numpy(s).float().to(self.device)
        if np.random.uniform() < eps:
            a = self.env.action_space.sample()
        else:
            s_ = self.env.unwrapped.normalize_state(s)
            a = self.actor(s_)
            a = a.cpu().numpy()
        return a

    def collect_exp_before_train(self):
        """Pre-fill the replay buffer with a fraction of its capacity before training starts."""
        assert 0 < self.buffer_init_ratio < 1
        num = self.buffer_init_ratio * self.replay_buffer.capicity
        bar = tqdm(range(int(num)), leave=False, ncols=80)
        bar.set_description_str(f"env:{self.id}->")
        s, _ = self.env.reset()
        while self.replay_buffer.size < num:
            a = self.get_action(s, 1.0)
            ns, r, t1, t2, _ = self.env.step(a)
            self.replay_buffer.push(
                self.env.unwrapped.normalize_state(s),
                a,
                r,
                self.env.unwrapped.normalize_state(ns),
                t1,
            )
            # Reset whenever the episode ends (terminated or truncated)
            s = ns if not (t1 or t2) else self.env.reset()[0]
            bar.update()

    def soft_sync_target(self):
        """Soft-update the target networks: θ' ← τ·θ + (1 − τ)·θ'."""
        net_groups = [
            (self.actor, self.actor_target),
            (self.critic, self.critic_target),
        ]
        for net, net_ in net_groups:
            for p, p_ in zip(net.parameters(), net_.parameters()):
                p_.data.copy_(p.data * self.tau + p_.data * (1 - self.tau))

    @property
    def epsilon(self):
        """Exponentially decaying exploration rate, floored at 0.10."""
        prog = self.global_step / self.total_train_batchs
        assert 0 <= prog <= 1.0
        return max(np.exp(-4 * prog), 0.10)
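    # Worked values of the schedule: exp(-4*0) = 1.00 at the start,
    # exp(-4*0.5) ≈ 0.14 halfway through, exp(-4*1) ≈ 0.02 at the end,
    # so the 0.10 floor takes over after roughly 58% of training
    # (exp(-4*p) = 0.10 at p ≈ 0.58).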

    def train_one_batch(self):
        # Sample a batch of transitions from the replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size, self.device)
        # Critic update: regress Q(s, a) towards the TD target
        with torch.no_grad():
            td_targets = rewards + self.gamma * (1 - dones) * self.critic_target(
                next_states, self.actor_target(next_states)
            )
        td_errors = td_targets - self.critic(states, actions)
        critic_loss = torch.pow(td_errors, 2).mean()
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
        # Actor update: maximize Q(s, μ(s)), i.e. minimize its negative
        actor_loss = -self.critic(states, self.actor(states)).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        # Soft-update the target networks
        self.soft_sync_target()
        return actor_loss.detach(), critic_loss.detach()
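    # For reference, the method above implements the standard DDPG updates
    # (a summary of the code, not additional behaviour):
    #   y   = r + γ (1 − d) Q'(s', μ'(s'))   # TD target
    #   L_Q = mean((y − Q(s, a))²)           # critic loss
    #   L_μ = −mean(Q(s, μ(s)))              # actor loss
    #   θ'  ← τ θ + (1 − τ) θ'               # target soft update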

    def train(self, train_batchs: int, disable_prog_bar: bool = True):
        if self.first_train:
            self.collect_exp_before_train()
            self.first_train = False
        # Training loop: one environment step and one gradient step per iteration
        for _ in trange(train_batchs, disable=disable_prog_bar, ncols=80, leave=False):
            a = self.get_action(self.state, self.epsilon)
            ns, r, t1, t2, _ = self.env.step(a)
            self.episode_reward += r
            self.replay_buffer.push(
                self.env.unwrapped.normalize_state(self.state),
                a,
                r,
                self.env.unwrapped.normalize_state(ns),
                t1,
            )
            if t1 or t2:
                self.log_info_per_episode()
                self.state, _ = self.env.reset()
            else:
                self.state = ns
            actor_loss, critic_loss = self.train_one_batch()
            self.log_info_per_batch(actor_loss, critic_loss)

    def log_info_per_episode(self):
        self.logger.add_scalar("Train/episode_reward", self.episode_reward, self.episode)
        self.logger.add_scalar("Train/buffer_size", self.replay_buffer.size, self.episode)
        self.logger.add_scalar("Episode/episode_len", self.episode_len, self.episode)
        self.episode_reward_list.append(self.episode_reward)
        self.episode += 1
        self.episode_len = 0
        self.episode_reward = 0

    def log_info_per_batch(self, actor_loss, critic_loss):
        self.logger.add_scalar("Loss/actor_loss", actor_loss, self.global_step)
        self.logger.add_scalar("Loss/critic_loss", critic_loss, self.global_step)
        self.logger.add_scalar("Train/epsilon", self.epsilon, self.global_step)
        self.global_step += 1
        self.episode_len += 1

    def save(self, save_path: str):
        params = {"actor": self.actor.state_dict(), "critic": self.critic.state_dict()}
        torch.save(params, save_path)


class Server:
    """The server role in the federated setup."""

    def __init__(self, points: List[DDPG], device: str = "cpu") -> None:
        """To protect user privacy, nothing is read from a node except its network parameters."""
        self.points = points
        self.device = device
        self.actor = copy.deepcopy(self.points[0].actor).to(self.device)
        self.actor_target = copy.deepcopy(self.points[0].actor_target).to(self.device)
        self.critic = copy.deepcopy(self.points[0].critic).to(self.device)
        self.critic_target = copy.deepcopy(self.points[0].critic_target).to(self.device)

    def merge_params(self, merge_target: bool = False) -> None:
        """Average the nodes' parameters into the server models, then broadcast them back."""
        for name, param in self.actor.state_dict().items():
            avg_param = torch.stack([p.actor.state_dict()[name] for p in self.points]).mean(dim=0)
            param.data.copy_(avg_param.data)
        for name, param in self.critic.state_dict().items():
            avg_param = torch.stack([p.critic.state_dict()[name] for p in self.points]).mean(dim=0)
            param.data.copy_(avg_param.data)
        if merge_target:
            for name, param in self.actor_target.state_dict().items():
                avg_param = torch.stack([p.actor_target.state_dict()[name] for p in self.points]).mean(dim=0)
                param.data.copy_(avg_param.data)
            for name, param in self.critic_target.state_dict().items():
                avg_param = torch.stack([p.critic_target.state_dict()[name] for p in self.points]).mean(dim=0)
                param.data.copy_(avg_param.data)
        for p in self.points:
            p.actor.load_state_dict(self.actor.state_dict())
            p.critic.load_state_dict(self.critic.state_dict())
            if merge_target:
                p.actor_target.load_state_dict(self.actor_target.state_dict())
                p.critic_target.load_state_dict(self.critic_target.state_dict())
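
# The unweighted average above is the plain FedAvg rule with equal client
# weights (a summary of the code, not additional behaviour):
#   θ_server = (1/K) Σ_{k=1..K} θ_k,  then θ_k ← θ_server for every node k.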


class FedDDPG:
    def __init__(
        self,
        point_configs: List[dict],
        merge_num: int,
        merge_interval: int,
        merge_target: bool,
        episode_num_eval: int,
        save_dir: Union[str, None] = None,
        save_interval: int = 10,
        device: str = "cpu",
    ) -> None:
        assert save_dir is not None, "save_dir can't be empty"
        self.device = device
        self.point_configs = point_configs
        self.merge_num = merge_num
        self.merge_interval = merge_interval
        self.merge_target = merge_target
        self.episode_num_eval = episode_num_eval
        # Keep save_dir as a Path so the "/" joins below work
        self.save_dir = Path(save_dir)
        self.save_interval = save_interval
        self.points = [DDPG(**c) for c in point_configs]
        self.server = Server(self.points, device=self.device)
        self.logger = tb.SummaryWriter(self.save_dir / "global" / "log")

    def train(self):
        """Run self.merge_num rounds of local training followed by parameter merging."""
        with trange(self.merge_num, ncols=80) as prog_bar:
            for m in range(self.merge_num):
                for p in tqdm(self.points, leave=False, disable=True):
                    p.train(self.merge_interval)
                self.server.merge_params(self.merge_target)
                prog_bar.update()
                if m % self.save_interval == 0 and m != 0:
                    self.save(self.save_dir / "server" / f"merge_{m}.pt")
        self.summarize_point_reward()
        for p in self.points:
            self.save(p.save_dir / "latest.pt")
            p.logger.close()
        self.logger.close()

    def train_baseline(self):
        """Train each node independently (no merging) as a baseline for comparison."""
        for p in tqdm(self.points, ncols=80):
            p.train(self.merge_num * self.merge_interval, disable_prog_bar=False)
        self.summarize_point_reward()
        for p in self.points:
            p.save(p.save_dir / "latest.pt")
            p.logger.close()
        self.logger.close()

    def evaluate_point_reward(self, point: DDPG):
        """Evaluate one node's average episode reward (without touching its training env)."""
        env = point.env_for_test
        point_r = 0
        for _ in range(self.episode_num_eval):
            s, _ = env.reset()
            while True:
                a = point.get_action(s)
                ns, r, t1, t2, _ = env.step(a)
                point_r += r
                s = ns
                if t1 or t2:
                    break
        return point_r / self.episode_num_eval

    def evaluate_avg_reward(self):
        """Evaluate every node and average their rewards."""
        reward_list = []
        for p in self.points:
            point_r = self.evaluate_point_reward(p)
            reward_list.append(point_r)
        return sum(reward_list) / len(reward_list)

    def summarize_point_reward(self):
        """Collect each point's completed-episode rewards, truncate to the shortest list, and plot the average."""
        min_length = min([len(p.episode_reward_list) for p in self.points])
        table = []
        for p in self.points:
            table.append(p.episode_reward_list[:min_length])
            np.save(p.save_dir / "episode_reward_list.npy", np.array(p.episode_reward_list))
        avg_episode_reward = np.array(table).mean(0)
        plt.plot(range(min_length), avg_episode_reward)
        plt.grid()
        plt.title("average episode reward")
        plt.savefig(self.save_dir / "global" / "average_episode_reward.svg")
        plt.close()

    def save(self, save_path):
        """Save the server's actor/critic weights."""
        Path(save_path).parent.mkdir(exist_ok=True)
        params = {
            "actor": self.server.actor.state_dict(),
            "critic": self.server.critic.state_dict(),
        }
        torch.save(params, save_path)
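

# Minimal usage sketch (an illustration, not part of the original file). The
# env id, heterogeneity values and directories below are hypothetical
# placeholders; the env is assumed to be a custom Gymnasium environment that
# accepts `heter`, `seed` and `is_test` kwargs and exposes `normalize_state`.
#
#     point_configs = [
#         dict(
#             env="CustomEnv-v0",        # hypothetical registered env id
#             id=i,
#             heter=np.array([1.0 + 0.1 * i]),
#             seed=i,
#             train_batchs=200 * 50,     # should equal merge_num * merge_interval
#             save_dir=f"runs/point_{i}",
#             device="cpu",
#         )
#         for i in range(4)
#     ]
#     fed = FedDDPG(
#         point_configs=point_configs,
#         merge_num=200,
#         merge_interval=50,
#         merge_target=True,
#         episode_num_eval=5,
#         save_dir="runs",
#         device="cpu",
#     )
#     fed.train()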