-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
ppo_trainer.py
271 lines (225 loc) · 10.5 KB
/
ppo_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import torch.nn.functional as F
import sys
import os
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
# Add the parent of this file's directory to the import path so that the
# project-local `from utils.utils import print_rank_0` below can resolve.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
from utils.utils import print_rank_0
def print_all_ranks(tag, value, rank):
    """Collect one scalar from every process and print the gathered vector.

    Each process writes ``value`` into its own slot of a zero-initialized,
    ``world_size``-long float32 CUDA tensor; a SUM all-reduce then fills in
    every slot on every process, and the result is handed to
    ``print_rank_0`` with the given ``rank``.

    Args:
        tag: Text prefix for the printed line.
        value: Scalar contributed by the calling process.
        rank: Passed through to ``print_rank_0`` (presumably decides which
            process prints -- confirm in utils). Callers in this file pass
            ``args.local_rank``.
    """
    world_size = torch.distributed.get_world_size()
    # Index by the *global* rank. The caller-supplied ``rank`` may be a
    # per-node local rank, which repeats across nodes: processes on different
    # nodes would then write into (and be summed into) the same slot while
    # the remaining slots stayed zero.
    global_rank = torch.distributed.get_rank()
    all_tensor = torch.zeros(world_size, dtype=torch.float32).cuda()
    all_tensor[global_rank] = value
    torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
    print_rank_0(f'{tag} {all_tensor}', rank)
def get_model_norm(model):
    """Return the sum of the L2 norms of all parameters of ``model``.

    Note this is the *sum of per-parameter norms*, not the L2 norm of the
    concatenated parameter vector. Parameters that carry a ``ds_id``
    attribute and whose ``ds_status`` is ``ZeroParamStatus.NOT_AVAILABLE``
    (i.e. ZeRO-3 partitioned away from this rank) are temporarily gathered
    before their norm is taken. Runs without autograd tracking.
    """
    running_sum = 0.0
    with torch.no_grad():
        for p in model.parameters():
            is_partitioned = (hasattr(p, 'ds_id')
                              and p.ds_status == ZeroParamStatus.NOT_AVAILABLE)
            # GatheredParameters is a no-op when enabled=False.
            with deepspeed.zero.GatheredParameters(p, enabled=is_partitioned):
                running_sum += float(p.float().norm())
    return running_sum
def gather_log_probs(logits, labels):
    """Return the log-probability that ``logits`` assigns to each label.

    A log-softmax is taken over the last dimension of ``logits`` and the
    entry selected by ``labels`` is picked out. ``labels`` holds integer
    indices into that last dimension and has one fewer dimension than
    ``logits``; the result has the same shape as ``labels``.
    """
    normalized = F.log_softmax(logits, dim=-1)
    picked = torch.gather(normalized, -1, labels.unsqueeze(-1))
    return picked.squeeze(-1)
class DeepSpeedPPOTrainer():
    """PPO trainer driving the actor, critic, reference and reward engines.

    The four models and the tokenizer are taken from ``rlhf_engine``
    (``actor``, ``critic``, ``ref``, ``reward``, ``tokenizer``).
    ``generate_experience`` produces a batch of rollouts and ``train_rlhf``
    consumes it to take one actor step and one critic step.
    """

    def __init__(self, rlhf_engine, args):
        self.rlhf_engine = rlhf_engine
        self.actor_model = self.rlhf_engine.actor
        self.critic_model = self.rlhf_engine.critic
        self.ref_model = self.rlhf_engine.ref
        self.reward_model = self.rlhf_engine.reward
        self.tokenizer = self.rlhf_engine.tokenizer
        self.args = args
        self.max_answer_seq_len = args.max_answer_seq_len
        # Id of the last token produced when tokenizing the end-of-conversation
        # marker.
        self.end_of_conversation_token_id = self.tokenizer(
            args.end_of_conversation_token)['input_ids'][-1]

        # PPO hyper-parameters (these values can be changed).
        self.kl_ctl = 0.02  # weight of the per-token KL penalty
        self.clip_reward_value = 5  # reward score is clamped to +/- this
        self.cliprange = 0.2  # clip range of the policy probability ratio
        self.cliprange_value = 0.2  # clip range of the value prediction
        self.gamma = 1.0  # discount factor
        self.lam = 0.95  # GAE lambda

    def _generate_sequence(self, prompts):
        """Generate answers for ``prompts`` and drop (near-)empty ones.

        Generation is forced to exactly ``prompt_length + max_answer_seq_len``
        tokens (``min_length == max_length``). Rows whose answer part holds
        at most one non-pad token are dropped; this happens when users use a
        pre-trained checkpoint directly without supervised fine-tuning.
        NOTE: because of this filtering, each GPU may end up with a
        different number of examples.

        Side effect: stores ``self.prompt_length`` (``prompts.shape[1]``).

        Returns:
            The kept rows of the generated ``seq`` tensor, in original order.

        Raises:
            RuntimeError: if every generated answer was filtered out.
        """
        max_min_length = self.max_answer_seq_len + prompts.shape[1]

        with torch.no_grad():
            seq = self.actor_model.module.generate(prompts,
                                                   max_length=max_min_length,
                                                   min_length=max_min_length)

        prompt_length = prompts.shape[1]
        self.prompt_length = prompt_length
        ans = seq[:, prompt_length:]
        valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1)

        # Keep answers that contain more than one non-pad token.
        keep = valid_ans_len > 1
        if not bool(keep.any()):
            # Previously torch.cat([]) failed here with an opaque message;
            # an empty batch would also make later masked means divide by 0.
            raise RuntimeError(
                f'All {seq.shape[0]} generated answers contain at most one '
                'non-pad token, so there is nothing to train on. This usually '
                'happens when the actor has not been supervised fine-tuned.')
        return seq[keep]

    def generate_experience(self, prompts):
        """Roll out ``prompts`` and score the resulting sequences.

        Generation runs with all models in eval mode; ``self.train()`` is
        called right after generation, so the actor/critic forward passes
        below run in train mode.
        NOTE(review): dropout, if enabled, affects the stored logprobs and
        values -- confirm this is intended.

        Returns:
            dict with keys ``prompts``, ``logprobs`` (actor log-prob of each
            next token), ``ref_logprobs`` (same under the reference model),
            ``value`` (critic values, last position dropped), ``rewards``
            (reward model ``chosen_end_scores``), ``input_ids`` (generated
            sequences) and ``attention_mask`` (1 where token != pad).
        """
        self.eval()
        seq = self._generate_sequence(prompts)
        self.train()

        pad_token_id = self.tokenizer.pad_token_id
        attention_mask = seq.not_equal(pad_token_id).long()
        with torch.no_grad():
            output = self.actor_model(seq, attention_mask=attention_mask)
            output_ref = self.ref_model(seq, attention_mask=attention_mask)
            reward_score = self.reward_model.forward_value(
                seq, attention_mask,
                prompt_length=self.prompt_length)['chosen_end_scores'].detach(
                )
            values = self.critic_model.forward_value(
                seq, attention_mask, return_value_only=True).detach()[:, :-1]

        logits = output.logits
        logits_ref = output_ref.logits

        return {
            'prompts': prompts,
            'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
            'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:,
                                                                        1:]),
            'value': values,
            'rewards': reward_score,
            'input_ids': seq,
            "attention_mask": attention_mask
        }

    def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
                        action_mask):
        """Build per-token rewards: KL penalty plus the clipped final score.

        Every position gets ``-kl_ctl * (log_probs - ref_log_probs)``. For
        each row ``j`` the score ``reward_score[j]``, clamped to
        ``[-clip_reward_value, clip_reward_value]``, is added at position
        ``ends[j] - 1``, where ``start = prompts.shape[1] - 1`` and
        ``ends = start + action_mask[:, start:].sum(1)``.

        Returns:
            Tensor shaped like ``log_probs``.
        """
        kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)
        rewards = kl_divergence_estimate
        start = prompts.shape[1] - 1
        ends = start + action_mask[:, start:].sum(1)
        reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                                  self.clip_reward_value)
        batch_size = log_probs.shape[0]
        for j in range(batch_size):
            # Last position of row j's answer segment receives the score.
            rewards[j, start:ends[j]][-1] += reward_clip[j]

        return rewards

    def train_rlhf(self, inputs):
        """Take one PPO step for the actor and one for the critic.

        Args:
            inputs: dict as returned by ``generate_experience``.

        Returns:
            ``(actor_loss, critic_loss)``.
        """
        # Process the old (experience) outputs.
        prompts = inputs['prompts']
        log_probs = inputs['logprobs']
        ref_log_probs = inputs['ref_logprobs']
        reward_score = inputs['rewards']
        values = inputs['value']
        attention_mask = inputs['attention_mask']
        seq = inputs['input_ids']

        start = prompts.size()[-1] - 1
        action_mask = attention_mask[:, 1:]

        with torch.no_grad():
            old_rewards = self.compute_rewards(prompts, log_probs,
                                               ref_log_probs, reward_score,
                                               action_mask)
            # Positions at or after ``ends`` (same definition as in
            # compute_rewards) follow the end of the answer. Zero their
            # rewards and values so they cannot leak into the advantages and
            # returns of real answer tokens through GAE bootstrapping; the
            # last answer token thereby becomes terminal. masked_fill is
            # out-of-place, so the caller's tensors are not modified.
            ends = start + action_mask[:, start:].sum(1)
            positions = torch.arange(old_rewards.size(1),
                                     device=old_rewards.device)
            beyond_end = positions.unsqueeze(0) >= ends.unsqueeze(1)
            old_rewards = old_rewards.masked_fill(beyond_end, 0)
            old_values = values.masked_fill(beyond_end, 0)
            advantages, returns = self.get_advantages_and_returns(
                old_values, old_rewards, start)

        # Process the new outputs: actor update.
        batch = {'input_ids': seq, "attention_mask": attention_mask}
        actor_prob = self.actor_model(**batch, use_cache=False).logits
        actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], seq[:, 1:])
        actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
                                        log_probs[:, start:], advantages,
                                        action_mask[:, start:])
        self.actor_model.backward(actor_loss)
        self.actor_model.step()

        # Critic update.
        value = self.critic_model.forward_value(**batch,
                                                return_value_only=True,
                                                use_cache=False)[:, :-1]
        critic_loss = self.critic_loss_fn(value[:, start:],
                                          old_values[:, start:], returns,
                                          action_mask[:, start:])
        self.critic_model.backward(critic_loss)
        self.critic_model.step()

        return actor_loss, critic_loss

    def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
        """Clipped PPO policy-gradient loss, averaged over ``mask``.

        ratio = exp((logprobs - old_logprobs) * mask); the loss is the masked
        mean of max(-A * ratio, -A * clamp(ratio, 1 - cliprange,
        1 + cliprange)).
        """
        # Masking the log-ratio makes ratio exactly 1 at masked positions.
        log_ratio = (logprobs - old_logprobs) * mask
        ratio = torch.exp(log_ratio)
        pg_loss1 = -advantages * ratio
        pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                             1.0 + self.cliprange)
        pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
        return pg_loss

    def critic_loss_fn(self, values, old_values, returns, mask):
        """Clipped value loss: 0.5 * masked mean of the larger squared error.

        ``values`` is clamped to within ``cliprange_value`` of
        ``old_values``; the per-position loss is the max of the clipped and
        unclipped squared errors against ``returns``.
        """
        values_clipped = torch.clamp(
            values,
            old_values - self.cliprange_value,
            old_values + self.cliprange_value,
        )
        vf_loss1 = (values - returns)**2
        vf_loss2 = (values_clipped - returns)**2
        vf_loss = 0.5 * torch.sum(
            torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
        return vf_loss

    def get_advantages_and_returns(self, values, rewards, start):
        """Generalized Advantage Estimation over positions ``start:``.

        Iterates backwards from the last position; the value after the last
        position is taken as 0. Uses ``self.gamma`` and ``self.lam``.

        Returns:
            ``(advantages, returns)`` covering positions ``start:``;
            advantages are detached and ``returns = advantages +
            values[:, start:]``.
        """
        # Adapted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
        lastgaelam = 0
        advantages_reversed = []
        length = rewards.size()[-1]
        for t in reversed(range(start, length)):
            nextvalues = values[:, t + 1] if t < length - 1 else 0.0
            delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
            lastgaelam = delta + self.gamma * self.lam * lastgaelam
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)
        returns = advantages + values[:, start:]
        return advantages.detach(), returns

    def _validate_training_mode(self):
        """Assert that the actor and critic modules are in training mode."""
        assert self.actor_model.module.training
        assert self.critic_model.module.training

    def _validate_evaluation_mode(self):
        """Assert that all four model modules are in eval mode."""
        assert not self.actor_model.module.training
        assert not self.critic_model.module.training
        assert not self.ref_model.module.training
        assert not self.reward_model.module.training

    def train(self):
        """Put the trainable models (actor and critic) into train mode."""
        self.actor_model.train()
        self.critic_model.train()

    def eval(self):
        """Put all four models into eval mode."""
        self.actor_model.eval()
        self.critic_model.eval()
        self.reward_model.eval()
        self.ref_model.eval()

    def dump_model_norms(self, tag):
        """Print each model's ``get_model_norm`` value gathered across ranks.

        ``self.args.local_rank`` is passed as the ``rank`` argument of
        ``print_all_ranks``.
        """
        actor_model_norm = get_model_norm(self.actor_model)
        ref_model_norm = get_model_norm(self.ref_model)
        critic_model_norm = get_model_norm(self.critic_model)
        reward_model_norm = get_model_norm(self.reward_model)
        print_all_ranks(f'{tag} global_actor_model_norm', actor_model_norm,
                        self.args.local_rank)
        print_all_ranks(f'{tag} global_ref_model_norm', ref_model_norm,
                        self.args.local_rank)
        print_all_ranks(f'{tag} global_critic_model_norm', critic_model_norm,
                        self.args.local_rank)
        print_all_ranks(f'{tag} global_reward_model_norm', reward_model_norm,
                        self.args.local_rank)
class DeepSpeedPPOTrainerUnsupervised(DeepSpeedPPOTrainer):
    """PPO trainer that can additionally take unsupervised steps on the actor."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train_unsupervised(self, inputs, unsup_coef):
        """Take one actor step on the actor's own ``loss`` scaled by ``unsup_coef``.

        ``inputs`` is forwarded as keyword arguments to the actor (with
        ``use_cache=False``). The backward pass uses ``unsup_coef * loss``,
        but the *unscaled* loss is returned.
        """
        self._validate_training_mode()
        lm_loss = self.actor_model(**inputs, use_cache=False).loss
        weighted_loss = unsup_coef * lm_loss
        self.actor_model.backward(weighted_loss)
        self.actor_model.step()
        return lm_loss