forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lora_dpo_single_device.py
578 lines (495 loc) · 22.7 KB
/
lora_dpo_single_device.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from functools import partial
from typing import Any, Dict, Optional, Tuple
from warnings import warn
import torch
from omegaconf import DictConfig, ListConfig
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft.peft_utils import (
disable_adapter,
get_adapter_params,
get_merged_lora_ckpt,
set_trainable_params,
validate_state_dict_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from tqdm import tqdm
# Module-level logger used by the recipe for setup and progress messages
# (see the ``log.info`` calls throughout this file).
log = utils.get_logger("DEBUG")
class LoRADPORecipeSingleDevice(FTRecipeInterface):
    """
    LoRA DPO recipe for dense transformer-based LLMs such as Llama2 for
    single device training. This is based on HF's DPOTrainer in the
    TRL library: https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L65

    This recipe supports:
        - Activation checkpointing. This is enabled by default but is configurable.
        - Full bf16 training for supported HW architectures. We currently check bf16 support via
          the `torch.cuda.is_bf16_supported` API. This is disabled by default but can be enabled via
          setting `dtype=bf16` in configuration.
        - Checkpointing: of LoRA adapter parameters and their optimizer states. When resuming
          from a checkpoint, the adapter parameters are loaded from the checkpoint along
          with the base model weights. Note that intra-epoch resumption is not supported.
        - Logging to terminal, WandB, or TensorBoard.

    Assumptions:
        - Checkpoints are ONLY saved at epoch boundaries. In case of failure, work done
          in ongoing epoch is lost.
        - Datasets are Map-style and data fits in memory (not streamed).
        - Within an epoch, only complete gradient-accumulation windows are used: trailing
          batches that cannot fill a window are skipped rather than carried over.

    The following configs can be used to run this recipe:
        >>> tune ls
        RECIPE                          CONFIG
        lora_dpo_single_device          llama2/7B_lora_dpo_single_device

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        ValueError: If ``dtype`` is set to fp16.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)

        # Reduced precision logic
        self._dtype = utils.get_dtype(cfg.dtype, device=self._device)
        # fp16 precision is explicitly disabled as it is not supported in this
        # recipe (for example, no gradient scaling).
        if self._dtype == torch.float16:
            raise ValueError(
                "fp16 precision is not supported in this recipe. Please use fp32 or bf16."
            )
        # For non-CPU devices, check if the HW supports bf16 if bf16 is specified.
        # Compare the device *type* so an indexed device (e.g. ``cpu:0``) is
        # classified the same way as its un-indexed form.
        if (
            self._dtype == torch.bfloat16
            and self._device.type != "cpu"
            and not torch.cuda.is_bf16_supported()
        ):
            raise RuntimeError("Full bf16 training is not supported on this hardware.")

        # logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

        # These are public properties which are updated by the checkpoint loader
        # when ``resume_from_checkpoint`` is `True` or validated in tests
        self.seed = utils.set_seed(seed=cfg.seed)
        self.epochs_run = 0
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0

        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
        """
        Extract the checkpoint state from file and validate. This includes the
        base model weights. If resume_from_checkpoint is True, this also includes
        the adapter weights and recipe state.

        Args:
            cfg_checkpointer (DictConfig): config used to instantiate the checkpointer.

        Returns:
            Dict[str, Any]: the loaded checkpoint dict.

        Raises:
            ValueError: If resuming and the checkpoint has no adapter weights.
        """
        self._checkpointer = config.instantiate(
            cfg_checkpointer,
            resume_from_checkpoint=self._resume_from_checkpoint,
        )
        checkpoint_dict = self._checkpointer.load_checkpoint()

        if self._resume_from_checkpoint:
            if utils.ADAPTER_KEY not in checkpoint_dict:
                raise ValueError(
                    "Adapter weights not found. Please ensure a valid adapter checkpoint is provided."
                )
            # _update_recipe_state will throw an exception if the recipe state is not correctly loaded
            # no need to check here
            self._update_recipe_state(checkpoint_dict)
        return checkpoint_dict

    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
        """
        Updates the recipe state from checkpoint.

        Values stored in the checkpoint (seed, epochs run, total epochs and
        max steps per epoch) always win over the configured ones; a warning is
        emitted when they differ.
        """
        # If seed, total_epoch or max_steps_per_epoch don't match,
        # warn the user and overwrite
        if (
            self.seed != ckpt_dict[utils.SEED_KEY]
            or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
            or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
        ):
            warn(
                message="""Configured value for seed, epochs or max_steps_per_epoch
                does not match the value stored in checkpoint."""
            )
        self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
        self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
        self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
        self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]

    def setup(self, cfg: DictConfig) -> None:
        """
        Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True),
        model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader.
        """
        self._metric_logger = config.instantiate(cfg.metric_logger)

        # log config with parameter override
        self._metric_logger.log_config(cfg)

        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            base_model_state_dict=checkpoint_dict[utils.MODEL_KEY],
            lora_weights_state_dict=(
                checkpoint_dict[utils.ADAPTER_KEY]
                if self._resume_from_checkpoint
                else None
            ),
        )

        self._tokenizer = config.instantiate(cfg.tokenizer)
        log.info("Tokenizer is initialized from file.")

        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            opt_state_dict=(
                checkpoint_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
            ),
        )

        self._loss_fn = config.instantiate(cfg.loss)
        log.info("Loss is initialized.")

        # Dataloader depends on the tokenizer and loss_fn and should be
        # setup after all of these are setup
        self._sampler, self._dataloader = self._setup_data(
            cfg_dataset=cfg.dataset,
            shuffle=cfg.shuffle,
            batch_size=cfg.batch_size,
        )

        # Finally update the recipe state which can only be correctly set after all of the
        # other components have been initialized and updated.

        # Number of training steps in each epoch depends on the number of batches produced
        # by the dataloader and the max_steps_per_epoch param set by the user and is used
        # for logging and tracking training state. This should be computed after the dataloader
        # has been setup
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

        # Learning rate scheduler can only be set up after number of steps
        # has been computed
        self._lr_scheduler = self._setup_lr_scheduler(
            cfg_lr_scheduler=cfg.lr_scheduler,
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )

    def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
    ) -> nn.Module:
        """
        Instantiate the LoRA model on ``self._device`` in ``self._dtype``, mark
        only adapter params as trainable, optionally enable activation
        checkpointing, and load base (and, when resuming, adapter) weights.
        """
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(cfg_model)

        self._lora_rank = cfg_model.lora_rank
        self._lora_alpha = cfg_model.lora_alpha
        self.adapter_params = get_adapter_params(model)
        set_trainable_params(model, self.adapter_params)

        if enable_activation_checkpointing:
            utils.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )

        validate_state_dict_for_lora(
            lora_attn_modules=cfg_model.lora_attn_modules,
            apply_lora_to_mlp=cfg_model.apply_lora_to_mlp,
            apply_lora_to_output=cfg_model.apply_lora_to_output,
            full_model_state_dict_keys=model.state_dict().keys(),
            lora_state_dict_keys=(
                lora_weights_state_dict.keys()
                if lora_weights_state_dict is not None
                else None
            ),
            base_model_state_dict_keys=base_model_state_dict.keys(),
        )

        # strict=False: the base checkpoint has no adapter keys and the adapter
        # checkpoint has no base keys; key coverage was validated just above.
        model.load_state_dict(base_model_state_dict, strict=False)
        if lora_weights_state_dict:
            model.load_state_dict(lora_weights_state_dict, strict=False)

        # Validate model adapter params were loaded in with the expected dtype
        # TODO (rohan-varma): Further validation to ensure the appropriate base params
        # are NF4 vs bf16 based on the quantization config.
        utils.validate_expected_param_dtype(
            self.adapter_params.items(), dtype=self._dtype
        )

        log.info(f"Model is initialized with precision {self._dtype}.")
        # Compare the device type: an indexed device such as ``cuda:0`` is not
        # equal to ``torch.device("cuda")``, which previously meant memory
        # stats were never logged.
        if self._device.type == "cuda":
            memory_stats = utils.get_memory_stats(device=self._device)
            utils.log_memory_stats(memory_stats)
        return model

    def _setup_optimizer(
        self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None
    ) -> Optimizer:
        """
        Instantiate the optimizer over the model's parameters and restore its
        state when an (optional) optimizer state dict is supplied.
        """
        optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
        if opt_state_dict:
            optimizer.load_state_dict(opt_state_dict)

        log.info("Optimizer and loss are initialized.")
        return optimizer

    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: DictConfig,
        num_training_steps: int,
        last_epoch: int,
    ) -> "torch.optim.lr_scheduler.LRScheduler":
        """
        Instantiate the LR scheduler for ``self._optimizer``.

        Args:
            cfg_lr_scheduler (DictConfig): scheduler config.
            num_training_steps (int): total optimizer steps across all epochs.
            last_epoch (int): index of the last completed step (``-1`` for a
                fresh run), so a resumed run continues the schedule.

        Returns:
            The instantiated learning rate scheduler.
        """
        lr_scheduler = config.instantiate(
            cfg_lr_scheduler,
            self._optimizer,
            num_training_steps=num_training_steps,
            last_epoch=last_epoch,
        )

        log.info("Learning rate scheduler is initialized.")
        return lr_scheduler

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
    ) -> Tuple[DistributedSampler, DataLoader]:
        """
        All data related setup happens here. Currently this recipe only supports
        Map-style Datasets which fit into memory and an option for random shuffling.
        Custom samplers, iterable datasets, and streaming datasets are not supported.

        A list of dataset configs is concatenated into a single dataset. A
        single-replica ``DistributedSampler`` is used so shuffling can be
        re-seeded per epoch via ``set_epoch``.
        """
        if isinstance(cfg_dataset, ListConfig):
            datasets = [
                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
                for single_cfg_dataset in cfg_dataset
            ]
            ds = ConcatDataset(datasets=datasets)
        else:
            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)

        sampler = DistributedSampler(
            ds,
            num_replicas=1,
            rank=0,
            shuffle=shuffle,
            seed=0,
        )

        dataloader = DataLoader(
            dataset=ds,
            sampler=sampler,
            batch_size=batch_size,
            collate_fn=partial(
                utils.padded_collate_dpo,
                padding_idx=self._tokenizer.pad_id,
                ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
            ),
        )
        log.info("Dataset and Sampler are initialized.")

        return sampler, dataloader

    def save_checkpoint(self, epoch: int) -> None:
        """
        Checkpoint the state of the recipe. The constructed checkpoint state dict
        contains the following information:
            - Merged weights with key MODEL_KEY
            - Adapter weights with key ADAPTER_KEY
            - Relevant recipe state if training is not complete

        Checkpointer will save the merged weights, adapter weights and recipe state in
        different checkpoint files. To correctly resume from training, the adapter weights
        and recipe state must be provided along with the base model weights.
        """
        ckpt_dict = {}
        is_intermediate = epoch + 1 < self.total_epochs

        # if training is in-progress, checkpoint the optimizer state as well
        if is_intermediate:
            ckpt_dict.update(
                {
                    utils.OPT_KEY: self._optimizer.state_dict(),
                    utils.SEED_KEY: self.seed,
                    utils.EPOCHS_KEY: self.epochs_run,
                    utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                    utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
                }
            )

        # Move to CPU to avoid a copy on GPU. ``copy=True`` guarantees fresh
        # tensors even when the model already lives on CPU (where ``.cpu()``
        # returns the original tensor), so building the merged checkpoint
        # below can never write into the live policy/reference weights.
        # NOTE(review): this matters if ``get_merged_lora_ckpt`` updates base
        # weights in place - confirm against peft_utils.
        state_dict = {
            k: v.to("cpu", copy=True) for k, v in self._model.state_dict().items()
        }

        # Construct the full state dict with LoRA weights merged into base LLM weights
        merged_state_dict = get_merged_lora_ckpt(
            state_dict,
            rank=self._lora_rank,
            alpha=self._lora_alpha,
        )
        ckpt_dict.update({utils.MODEL_KEY: merged_state_dict})

        # Construct the adapter weights
        adapter_state_dict = {
            k: v
            for k, v in self._model.state_dict().items()
            if k in self.adapter_params
        }
        ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict})
        self._checkpointer.save_checkpoint(
            ckpt_dict,
            epoch=epoch,
            intermediate_checkpoint=is_intermediate,
        )

    def concatenated_forward(
        self, model: nn.Module, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run forward pass of the model with chosen and rejected samples concatenated.

        The batch is split in half along dim 0: the first half is treated as
        "chosen" and the second as "rejected" (this relies on the collate
        function stacking all chosen samples before all rejected ones).

        Args:
            model (nn.Module): The model to be used for the forward pass.
            batch (Tuple[torch.Tensor, torch.Tensor]): Tuple of input_ids and labels.

        Returns:
            Tuple of chosen log probs, rejected log probs, chosen logits, rejected logits.
        """
        concatenated_input_ids, concatenated_labels = batch
        concatenated_input_ids = concatenated_input_ids.to(self._device)
        concatenated_labels = concatenated_labels.to(self._device)

        # The batch is formed by concatenating an equal number of "chosen" and
        # "rejected" samples, so the first half of dim 0 is "chosen".
        len_chosen = concatenated_input_ids.shape[0] // 2

        all_logits = model(concatenated_input_ids)

        all_log_probs = self.get_batch_log_probs(all_logits, concatenated_labels)

        chosen_log_probs = all_log_probs[:len_chosen]
        rejected_log_probs = all_log_probs[len_chosen:]

        chosen_logits = all_logits[:len_chosen]
        rejected_logits = all_logits[len_chosen:]

        return (chosen_log_probs, rejected_log_probs, chosen_logits, rejected_logits)

    @staticmethod
    def get_batch_log_probs(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        label_pad_token_id: int = CROSS_ENTROPY_IGNORE_IDX,
    ) -> torch.FloatTensor:
        """
        Calculate log probabilities based on provided logits and labels.

        Args:
            logits (torch.FloatTensor): direct logits output of the model of shape (b, s, v)
            labels (torch.LongTensor): ground-truth labels to compute log probs with, shape (b, s).
                Label tokens with a value of label_pad_token_id are ignored.
            label_pad_token_id (int): token id to ignore in labels.

        Returns:
            Calculated log probs of shape (b, ): the sum over non-ignored
            positions of each sequence's per-token log probabilities.

        Raises:
            ValueError: If logits and labels have different shapes.
        """
        if logits.shape[:-1] != labels.shape:
            raise ValueError(
                "Logits (batch and sequence length dim) and labels must have the same shape."
            )

        # Shift by one so the logits at position t are scored against the
        # label at position t + 1.
        labels = labels[:, 1:].clone()
        logits = logits[:, :-1, :]
        loss_mask = labels != label_pad_token_id

        # Ignored positions get a dummy index 0 so ``gather`` stays in range;
        # their contribution is zeroed out by ``loss_mask`` below.
        labels[labels == label_pad_token_id] = 0
        # take log-likelihood of the labels given our model
        per_token_log_probs = torch.gather(
            logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
        ).squeeze(2)

        return (per_token_log_probs * loss_mask).sum(-1)

    def train(self) -> None:
        """
        The core training loop.

        For every batch the policy (LoRA-adapted model) and the reference (the
        same model with adapters disabled, under ``no_grad``) are run on the
        concatenated chosen/rejected inputs and the configured loss is computed
        from their log probs. Gradients are accumulated over
        ``gradient_accumulation_steps`` batches before each optimizer step, and
        a checkpoint is saved at the end of every epoch.
        """
        # Initialize tokens count and running loss (for grad accumulation)
        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

            pbar = tqdm(total=self._steps_per_epoch)
            for idx, batch in enumerate(self._dataloader):
                # ``_steps_per_epoch`` already honours ``max_steps_per_epoch``.
                # Stopping here also drops a trailing, incomplete accumulation
                # window: those batches can never trigger an optimizer step this
                # epoch, and their gradients and running loss would otherwise
                # leak into the first step of the next epoch.
                if (idx // self._gradient_accumulation_steps) == self._steps_per_epoch:
                    break

                # batch is input_ids, labels
                num_tokens += batch[0].numel()

                (
                    policy_chosen_log_probs,
                    policy_rejected_log_probs,
                    policy_chosen_logits,
                    policy_rejected_logits,
                ) = self.concatenated_forward(self._model, batch)

                # Reference model = same weights with LoRA adapters disabled.
                with torch.no_grad(), disable_adapter(self._model):
                    (
                        reference_chosen_log_probs,
                        reference_rejected_log_probs,
                        _,
                        _,
                    ) = self.concatenated_forward(self._model, batch)

                loss, chosen_rewards, rejected_rewards = self._loss_fn(
                    policy_chosen_log_probs,
                    policy_rejected_log_probs,
                    reference_chosen_log_probs,
                    reference_rejected_log_probs,
                )

                loss = loss.mean()
                reward_accuracies = (chosen_rewards > rejected_rewards).float()

                loss = loss / self._gradient_accumulation_steps
                # Detach so the running total does not hold on to autograd history.
                running_loss += loss.detach()
                loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)
                    self._lr_scheduler.step()

                    # Update the number of steps when the weights are updated
                    self.global_step += 1

                    loss_to_log = running_loss.item()
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}"
                    )

                    # Log per-step metrics
                    if self.global_step % self._log_every_n_steps == 0:
                        time_per_step = time.perf_counter() - t0
                        log_dict = {
                            "loss": loss_to_log,
                            "lr": self._optimizer.param_groups[0]["lr"],
                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
                            "rewards/chosen": chosen_rewards.mean().cpu(),
                            "rewards/rejected": rejected_rewards.mean().cpu(),
                            "rewards/accuracies": reward_accuracies.mean().cpu(),
                            "rewards/margins": (chosen_rewards - rejected_rewards)
                            .mean()
                            .cpu(),
                            "log_probs/rejected": policy_rejected_log_probs.detach()
                            .mean()
                            .cpu(),
                            "log_probs/chosen": policy_chosen_log_probs.detach()
                            .mean()
                            .cpu(),
                            "logits/rejected": policy_rejected_logits.detach()
                            .mean()
                            .cpu(),
                            "logits/chosen": policy_chosen_logits.detach().mean().cpu(),
                        }
                        if self._log_peak_memory_stats:
                            log_dict.update(utils.get_memory_stats(device=self._device))
                        self._metric_logger.log_dict(
                            log_dict,
                            step=self.global_step,
                        )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

            pbar.close()
            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

    def cleanup(self) -> None:
        """Close the metric logger created in ``setup``."""
        self._metric_logger.close()
@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in config (see available configs through ``tune ls``)
        - Overwritten by arguments from the command-line

    ``cleanup`` (which closes the metric logger) runs even if training raises,
    after which the exception propagates unchanged.
    """
    config.log_config(recipe_name="LoRADPORecipeSingleDevice", cfg=cfg)
    recipe = LoRADPORecipeSingleDevice(cfg=cfg)
    recipe.setup(cfg=cfg)
    try:
        recipe.train()
    finally:
        # Release the metric logger even on a failed run; previously a
        # training exception skipped cleanup entirely.
        recipe.cleanup()
# Script entry point. ``recipe_main`` is wrapped by ``config.parse``, which
# presumably builds ``cfg`` from the command line, hence the argument-less call.
if __name__ == "__main__":
    sys.exit(recipe_main())