base.py
import numpy as np
import torch
import torch.nn as nn
from ..utils.mpi import sync_networks


class BaseAgent(nn.Module):
    def __init__(self):
        super().__init__()

    def train(self, training=True):
        # Switch the agent and its actor/critic sub-modules between train and eval mode.
        self.training = training
        self.actor.train(training)
        self.critic.train(training)

    def get_samples(self, replay_buffer):
        # sample a batch of transitions from the replay buffer
        transitions = replay_buffer.sample()
        # clip observations/goals, then normalize and concatenate them
        o, o_next, g = transitions['obs'], transitions['obs_next'], transitions['g']
        transitions['obs'], transitions['g'] = self._preproc_og(o, g)
        transitions['obs_next'], transitions['g_next'] = self._preproc_og(o_next, g)
        obs_norm = self.o_norm.normalize(transitions['obs'])
        g_norm = self.g_norm.normalize(transitions['g'])
        inputs_norm = np.concatenate([obs_norm, g_norm], axis=1)
        obs_next_norm = self.o_norm.normalize(transitions['obs_next'])
        g_next_norm = self.g_norm.normalize(transitions['g_next'])
        inputs_next_norm = np.concatenate([obs_next_norm, g_next_norm], axis=1)
        # convert everything to torch tensors on the agent's device
        obs = self.to_torch(inputs_norm)
        next_obs = self.to_torch(inputs_next_norm)
        action = self.to_torch(transitions['actions'])
        reward = self.to_torch(transitions['r'])
        done = self.to_torch(transitions['dones'])
        return obs, action, reward, done, next_obs

    def update_target(self):
        # Update the frozen target critic with Polyak averaging:
        # theta_target <- tau * theta + (1 - tau) * theta_target
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.soft_target_tau * param.data
                                    + (1 - self.soft_target_tau) * target_param.data)

    def update_normalizer(self, episode_batch):
        mb_obs, mb_ag, mb_g, mb_actions, dones = episode_batch.obs, episode_batch.ag, episode_batch.g, \
            episode_batch.actions, episode_batch.dones
        mb_obs_next = mb_obs[:, 1:, :]
        mb_ag_next = mb_ag[:, 1:, :]
        # get the number of normalization transitions
        num_transitions = mb_actions.shape[1]
        # create the new buffer to store them
        buffer_temp = {'obs': mb_obs,
                       'ag': mb_ag,
                       'g': mb_g,
                       'actions': mb_actions,
                       'obs_next': mb_obs_next,
                       'ag_next': mb_ag_next,
                       }
        # relabel goals with HER before computing the normalization statistics
        transitions = self.her_sampler.sample_her_transitions(buffer_temp, num_transitions)
        obs, g = transitions['obs'], transitions['g']
        # preprocess the obs and g
        transitions['obs'], transitions['g'] = self._preproc_og(obs, g)
        # update the running statistics
        self.o_norm.update(transitions['obs'])
        self.g_norm.update(transitions['g'])
        # recompute the stats
        self.o_norm.recompute_stats()
        self.g_norm.recompute_stats()

    def _preproc_og(self, o, g):
        # clip observations and goals to [-clip_obs, clip_obs]
        o = np.clip(o, -self.clip_obs, self.clip_obs)
        g = np.clip(g, -self.clip_obs, self.clip_obs)
        return o, g

    def _preproc_inputs(self, o, g, dim=0, device=None):
        o_norm = self.o_norm.normalize(o, device=device)
        g_norm = self.g_norm.normalize(g, device=device)
        if isinstance(o_norm, np.ndarray):
            # numpy path: concatenate, add a batch dimension and move to the agent's device
            inputs = np.concatenate([o_norm, g_norm], axis=dim)
            inputs = torch.tensor(inputs, dtype=torch.float32).unsqueeze(0).to(self.device)
        elif isinstance(o_norm, torch.Tensor):
            # tensor path: inputs are assumed to already carry a batch dimension
            inputs = torch.cat([o_norm, g_norm], dim=1)
        return inputs

    def to_torch(self, array, copy=True):
        # copy=True builds a new float32 tensor; copy=False reuses the underlying memory when possible
        if copy:
            return torch.tensor(array, dtype=torch.float32).to(self.device)
        return torch.as_tensor(array).to(self.device)

    def sync_networks(self):
        # synchronize this agent's parameters across MPI workers
        sync_networks(self)
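

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): BaseAgent never
# creates the attributes its methods rely on, so a concrete agent is expected
# to define them. The subclass name, helper classes, and numeric values below
# are hypothetical placeholders, not values taken from this repository:
#
#     class MyAgent(BaseAgent):                            # hypothetical subclass
#         def __init__(self, env_params, args):
#             super().__init__()
#             self.actor = Actor(env_params)               # policy network
#             self.critic = Critic(env_params)             # Q-network
#             self.critic_target = Critic(env_params)      # frozen copy used by update_target()
#             self.o_norm = Normalizer(env_params['obs'])  # observation normalizer
#             self.g_norm = Normalizer(env_params['goal']) # goal normalizer
#             self.her_sampler = her_sampler               # HER transition sampler
#             self.clip_obs = 200.0                        # observation clipping range (assumed)
#             self.soft_target_tau = 0.05                  # Polyak coefficient (assumed)
#             self.device = torch.device('cpu')
# ---------------------------------------------------------------------------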