-
Notifications
You must be signed in to change notification settings - Fork 52
/
drq.py
307 lines (232 loc) · 10.9 KB
/
drq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import utils
import hydra
class Encoder(nn.Module):
    """Convolutional encoder for image-based observations.

    Four 3x3 convolutions (the first with stride 2, the others with stride 1)
    followed by a linear projection and a LayerNorm. Every intermediate
    activation of the latest forward pass is cached in ``self.outputs`` so
    that ``log`` can report it.

    Args:
        obs_shape: 3-element observation shape. Index 0 is used as the number
            of input channels; indices 1 and 2 are used as the spatial size
            from which the width of the linear head is derived.
        feature_dim: size of the feature vector returned by ``forward``.

    Raises:
        ValueError: if the spatial size is too small for the conv stack.
    """

    def __init__(self, obs_shape, feature_dim):
        super().__init__()

        assert len(obs_shape) == 3
        self.num_layers = 4
        self.num_filters = 32
        self.output_logits = False
        self.feature_dim = feature_dim

        self.convs = nn.ModuleList([
            nn.Conv2d(obs_shape[0], self.num_filters, 3, stride=2),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1),
            nn.Conv2d(self.num_filters, self.num_filters, 3, stride=1)
        ])

        # Derive the flattened conv size from the declared observation shape
        # instead of hard-coding 35x35 (which is only right for 84x84 input).
        out_h, out_w = self._conv_output_hw(obs_shape[1], obs_shape[2])
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f'obs_shape {tuple(obs_shape)} is too small for the '
                f'{self.num_layers}-layer conv stack')
        # Height of the last feature map; equals 35 for 84x84 observations.
        self.output_dim = out_h

        self.head = nn.Sequential(
            nn.Linear(self.num_filters * out_h * out_w, self.feature_dim),
            nn.LayerNorm(self.feature_dim))

        self.outputs = dict()

    def _conv_output_hw(self, height, width):
        """Return the (height, width) produced by ``self.convs`` for an input."""
        for conv in self.convs:
            k_h, k_w = conv.kernel_size
            s_h, s_w = conv.stride
            p_h, p_w = conv.padding
            d_h, d_w = conv.dilation
            # Standard Conv2d output-size formula.
            height = (height + 2 * p_h - d_h * (k_h - 1) - 1) // s_h + 1
            width = (width + 2 * p_w - d_w * (k_w - 1) - 1) // s_w + 1
        return height, width

    def forward_conv(self, obs):
        """Run the conv stack and return the activations flattened per sample.

        The input is divided by 255 before the first convolution.
        """
        x = obs / 255.
        self.outputs['obs'] = x

        for idx, conv in enumerate(self.convs, start=1):
            x = torch.relu(conv(x))
            self.outputs[f'conv{idx}'] = x

        return x.view(x.size(0), -1)

    def forward(self, obs, detach=False):
        """Encode ``obs`` into a ``feature_dim``-sized vector.

        Args:
            obs: batch of image observations.
            detach: when true, the flattened conv features are detached so
                that no gradient reaches the convolutional layers.

        Returns:
            The head output, passed through ``tanh`` unless
            ``self.output_logits`` is set.
        """
        h = self.forward_conv(obs)

        if detach:
            h = h.detach()

        out = self.head(h)
        if not self.output_logits:
            out = torch.tanh(out)

        self.outputs['out'] = out

        return out

    def copy_conv_weights_from(self, source):
        """Tie convolutional layers to those of ``source`` via utils.tie_weights."""
        for src_conv, own_conv in zip(source.convs, self.convs):
            utils.tie_weights(src=src_conv, trg=own_conv)

    def log(self, logger, step):
        """Log cached activations and the conv layer parameters."""
        for k, v in self.outputs.items():
            logger.log_histogram(f'train_encoder/{k}_hist', v, step)
            # Anything with more than two dimensions is also logged as an
            # image, using the first sample of the batch.
            if len(v.shape) > 2:
                logger.log_image(f'train_encoder/{k}_img', v[0], step)

        for i in range(self.num_layers):
            logger.log_param(f'train_encoder/conv{i + 1}', self.convs[i], step)
class Actor(nn.Module):
    """torch.distributions implementation of a diagonal Gaussian policy.

    An encoder (instantiated from ``encoder_cfg``) feeds an MLP trunk whose
    output is split into a mean and a log standard deviation per action
    dimension.
    """

    def __init__(self, encoder_cfg, action_shape, hidden_dim, hidden_depth,
                 log_std_bounds):
        super().__init__()

        self.encoder = hydra.utils.instantiate(encoder_cfg)
        self.log_std_bounds = log_std_bounds

        # The trunk emits two values (mean, log-std) per action dimension.
        trunk_out_dim = 2 * action_shape[0]
        self.trunk = utils.mlp(self.encoder.feature_dim, hidden_dim,
                               trunk_out_dim, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, detach_encoder=False):
        """Return the action distribution for ``obs``.

        Args:
            obs: batch of observations handed to the encoder.
            detach_encoder: forwarded to the encoder as ``detach``.

        Returns:
            ``utils.SquashedNormal(mu, std)``.
        """
        features = self.encoder(obs, detach=detach_encoder)
        mu, raw_log_std = self.trunk(features).chunk(2, dim=-1)

        # Squash to [-1, 1], then rescale into [log_std_min, log_std_max].
        squashed = torch.tanh(raw_log_std)
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (squashed +
                                                                     1)
        std = log_std.exp()

        self.outputs['mu'] = mu
        self.outputs['std'] = std

        return utils.SquashedNormal(mu, std)

    def log(self, logger, step):
        """Log the cached mean/std and every Linear layer of the trunk."""
        for name, value in self.outputs.items():
            logger.log_histogram(f'train_actor/{name}_hist', value, step)

        for idx, layer in enumerate(self.trunk):
            if type(layer) is nn.Linear:
                logger.log_param(f'train_actor/fc{idx}', layer, step)
class Critic(nn.Module):
    """Critic network, employs double Q-learning.

    An encoder (instantiated from ``encoder_cfg``) produces features that are
    concatenated with the action and fed to two separate MLP heads, Q1 and Q2.
    """

    def __init__(self, encoder_cfg, action_shape, hidden_dim, hidden_depth):
        super().__init__()

        self.encoder = hydra.utils.instantiate(encoder_cfg)

        # Both heads take the encoder features concatenated with the action.
        q_in_dim = self.encoder.feature_dim + action_shape[0]
        self.Q1 = utils.mlp(q_in_dim, hidden_dim, 1, hidden_depth)
        self.Q2 = utils.mlp(q_in_dim, hidden_dim, 1, hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs, action, detach_encoder=False):
        """Return the pair ``(q1, q2)`` for a batch of observations and actions.

        Args:
            obs: batch of observations handed to the encoder.
            action: batch of actions; must have the same batch size as ``obs``.
            detach_encoder: forwarded to the encoder as ``detach``.
        """
        assert obs.size(0) == action.size(0)
        features = self.encoder(obs, detach=detach_encoder)

        q_input = torch.cat([features, action], dim=-1)
        q1, q2 = self.Q1(q_input), self.Q2(q_input)

        self.outputs['q1'] = q1
        self.outputs['q2'] = q2

        return q1, q2

    def log(self, logger, step):
        """Log the encoder, the cached Q-values and the Linear layers of Q1/Q2."""
        self.encoder.log(logger, step)

        for name, value in self.outputs.items():
            logger.log_histogram(f'train_critic/{name}_hist', value, step)

        assert len(self.Q1) == len(self.Q2)
        for idx, (layer1, layer2) in enumerate(zip(self.Q1, self.Q2)):
            assert type(layer1) is type(layer2)
            if type(layer1) is not nn.Linear:
                continue
            logger.log_param(f'train_critic/q1_fc{idx}', layer1, step)
            logger.log_param(f'train_critic/q2_fc{idx}', layer2, step)
class DRQAgent(object):
    """Data regularized Q: actor-critic method for learning from pixels.

    Holds an actor, a critic and a target critic (all instantiated from hydra
    configs), a learnable log-temperature, and one Adam optimizer for each of
    actor, critic and temperature.
    """

    def __init__(self, obs_shape, action_shape, action_range, device,
                 encoder_cfg, critic_cfg, actor_cfg, discount,
                 init_temperature, lr, actor_update_frequency, critic_tau,
                 critic_target_update_frequency, batch_size):
        self.action_range = action_range
        self.device = device
        self.discount = discount
        self.critic_tau = critic_tau
        self.actor_update_frequency = actor_update_frequency
        self.critic_target_update_frequency = critic_target_update_frequency
        self.batch_size = batch_size

        self.actor = hydra.utils.instantiate(actor_cfg).to(self.device)

        self.critic = hydra.utils.instantiate(critic_cfg).to(self.device)
        self.critic_target = hydra.utils.instantiate(critic_cfg).to(
            self.device)
        # The target critic starts as an exact copy of the critic.
        self.critic_target.load_state_dict(self.critic.state_dict())

        # tie conv layers between actor and critic
        self.actor.encoder.copy_conv_weights_from(self.critic.encoder)

        # The temperature is optimized in log space.
        self.log_alpha = torch.tensor(np.log(init_temperature)).to(device)
        self.log_alpha.requires_grad_(True)
        # set target entropy to -|A|
        self.target_entropy = -action_shape[0]

        # optimizers
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=lr)
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=lr)

        self.train()
        self.critic_target.train()

    def train(self, training=True):
        """Switch the actor and the critic between train and eval mode."""
        self.training = training
        for module in (self.actor, self.critic):
            module.train(training)

    @property
    def alpha(self):
        """Entropy temperature, i.e. ``exp(log_alpha)``."""
        return self.log_alpha.exp()

    def act(self, obs, sample=False):
        """Return an action for a single observation.

        Args:
            obs: one observation (no batch dimension).
            sample: draw from the policy when true, otherwise use its mean.

        Returns:
            The action clamped to ``action_range`` and converted with
            ``utils.to_np``.
        """
        obs_batch = torch.FloatTensor(obs).to(self.device).unsqueeze(0)
        dist = self.actor(obs_batch)
        if sample:
            action = dist.sample()
        else:
            action = dist.mean
        action = action.clamp(*self.action_range)
        assert action.ndim == 2 and action.shape[0] == 1
        return utils.to_np(action[0])

    def _soft_target_q(self, next_obs, reward, not_done):
        """Compute ``reward + not_done * discount * V(next_obs)``.

        ``V`` is the minimum of the two target-critic values for an action
        sampled from the current policy, minus ``alpha`` times that action's
        log-probability. Callers are expected to wrap this in ``no_grad``.
        """
        dist = self.actor(next_obs)
        next_action = dist.rsample()
        log_prob = dist.log_prob(next_action).sum(-1, keepdim=True)
        target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
        target_V = torch.min(target_Q1,
                             target_Q2) - self.alpha.detach() * log_prob
        return reward + (not_done * self.discount * target_V)

    def update_critic(self, obs, obs_aug, action, reward, next_obs,
                      next_obs_aug, not_done, logger, step):
        """Take one gradient step on the critic.

        The regression target is the average of the targets computed from
        ``next_obs`` and ``next_obs_aug``; the loss sums the MSE of both
        Q-heads evaluated on ``obs`` and on ``obs_aug``.
        """
        with torch.no_grad():
            # Order matters for the RNG: plain view first, then the second
            # view.
            target_plain = self._soft_target_q(next_obs, reward, not_done)
            target_aug = self._soft_target_q(next_obs_aug, reward, not_done)
            target_Q = (target_plain + target_aug) / 2

        # get current Q estimates
        current_Q1, current_Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(
            current_Q2, target_Q)

        # same target, Q estimates from the second view of the observation
        Q1_aug, Q2_aug = self.critic(obs_aug, action)
        critic_loss += F.mse_loss(Q1_aug, target_Q) + F.mse_loss(
            Q2_aug, target_Q)

        logger.log('train_critic/loss', critic_loss, step)

        # Optimize the critic
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        self.critic.log(logger, step)

    def update_actor_and_alpha(self, obs, logger, step):
        """Take one gradient step on the actor, then one on the temperature."""
        # detach conv filters, so we don't update them with the actor loss
        dist = self.actor(obs, detach_encoder=True)
        action = dist.rsample()
        log_prob = dist.log_prob(action).sum(-1, keepdim=True)
        # detach conv filters, so we don't update them with the actor loss
        actor_Q1, actor_Q2 = self.critic(obs, action, detach_encoder=True)

        min_Q = torch.min(actor_Q1, actor_Q2)
        actor_loss = (self.alpha.detach() * log_prob - min_Q).mean()

        logger.log('train_actor/loss', actor_loss, step)
        logger.log('train_actor/target_entropy', self.target_entropy, step)
        logger.log('train_actor/entropy', -log_prob.mean(), step)

        # optimize the actor
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.actor.log(logger, step)

        # optimize the temperature; the entropy gap is treated as a constant
        self.log_alpha_optimizer.zero_grad()
        entropy_gap = (-log_prob - self.target_entropy).detach()
        alpha_loss = (self.alpha * entropy_gap).mean()
        logger.log('train_alpha/loss', alpha_loss, step)
        logger.log('train_alpha/value', self.alpha, step)
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

    def update(self, replay_buffer, logger, step):
        """Sample a batch and run one update.

        The critic is updated every call; the actor/temperature and the
        target critic are updated whenever ``step`` is a multiple of their
        respective frequencies.
        """
        batch = replay_buffer.sample(self.batch_size)
        obs, action, reward, next_obs, not_done, obs_aug, next_obs_aug = batch

        logger.log('train/batch_reward', reward.mean(), step)

        self.update_critic(obs, obs_aug, action, reward, next_obs,
                           next_obs_aug, not_done, logger, step)

        if step % self.actor_update_frequency == 0:
            self.update_actor_and_alpha(obs, logger, step)

        if step % self.critic_target_update_frequency == 0:
            utils.soft_update_params(self.critic, self.critic_target,
                                     self.critic_tau)