forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
/
full_finetune_single_device.py
510 lines (437 loc) · 21.9 KB
/
full_finetune_single_device.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys
import time
from functools import partial
from typing import Any, Dict, Optional, Tuple
from warnings import warn
import torch
from omegaconf import DictConfig, ListConfig
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, utils
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from tqdm import tqdm
# Module-level logger shared by the recipe for status messages. The "DEBUG"
# argument is forwarded to torchtune's ``utils.get_logger`` (presumably the
# log level — confirm against torchtune.utils).
log = utils.get_logger("DEBUG")
class FullFinetuneRecipeSingleDevice(FTRecipeInterface):
    """
    Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe is optimized
    for single GPU training. Training on CPU is not supported.

    Features:
        - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
            activations in memory and instead recompute them during the backward pass. This is especially
            helpful for larger batch sizes when you're memory constrained. But these savings in memory
            come at the cost of training performance. In most cases training can slow-down quite a bit as
            a result of this activation recomputation.

        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
            most cases this should halve the memory footprint of full precision (fp32) training, without
            loss in model quality (will depend on the model, training data and other settings). For
            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
            precision are currently not supported.

        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
            controlled using the ``gradient_accumulation_steps`` flag.

                Total Batch Size = batch_size * gradient accumulation steps.

            For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32.

            Gradient accumulation is especially useful when you are memory constrained. In this case,
            accumulating gradients might give you better training speed than enabling activation
            checkpointing.

        - Optimizer in Backward. Fusing the optimizer step into the backward pass helps reduce the memory
            footprint associated with gradients. This can be especially helpful when you are memory
            constrained. Note that users can only use ONE of gradient accumulation or optimizer in backward.
            These features currently do not work together. For more details on optimizer in backward, please
            see this tutorial: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html

        - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes
            library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with
            8-bit AdamW and Paged AdamW. These optimizers are especially helpful when you are memory constrained
            since they help reduce the memory footprint associated with the optimizer states.

        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
            training. Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are
            only saved at the end of a given epoch and used in case of resuming training.

            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
            currently not supported.

            For more details on the checkpointer, please take a look at
            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).

        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
    has example commands for how to kick-off training.

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        RuntimeError: If ``dtype`` is set to fp16.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
        RuntimeError: If ``gradient_accumulation_steps > 1`` and ``optimizer_in_bwd`` is `True`.
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(cfg.dtype, device=self._device)

        # Disable for fp16, as we haven't validated "full" fp16 with this recipe, nor
        # enabled necessary features such as gradient scaling.
        if self._dtype == torch.float16:
            raise RuntimeError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        # logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

        # Training cfg
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
        self._optimizer_in_bwd = cfg.optimizer_in_bwd

        # TODO: find a better place / way to perform validation of args that don't yet
        # compose with each other.
        if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
            # Optimizer-in-backward steps each parameter as soon as its gradient is
            # ready, so there is no window in which gradients could be accumulated.
            raise RuntimeError(
                "Gradient accumulation is not supported with optimizer in bwd. "
                "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
            )

        # These are public properties which are updated by the checkpoint loader
        # when ``resume_from_checkpoint`` is `True` or validated in tests
        self.seed = utils.set_seed(seed=cfg.seed)
        self.epochs_run = 0
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0

    def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
        """
        Instantiate the checkpointer, load the checkpoint from file and return it.

        If ``resume_from_checkpoint`` is True, the recipe state (seed, epochs, etc.)
        stored in the checkpoint is also applied to this recipe.

        Args:
            cfg_checkpointer (DictConfig): config used to instantiate the checkpointer.

        Returns:
            Dict[str, Any]: the state dict returned by the checkpointer.
        """
        self._checkpointer = config.instantiate(
            cfg_checkpointer,
            resume_from_checkpoint=self._resume_from_checkpoint,
        )
        checkpoint_dict = self._checkpointer.load_checkpoint()

        if self._resume_from_checkpoint:
            self._update_recipe_state(checkpoint_dict)
        return checkpoint_dict

    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
        """
        Update the recipe state (seed, epochs run, total epochs, max steps per
        epoch) from a checkpoint dict.

        If the configured seed, total epochs or max steps per epoch differ from the
        checkpointed values, a warning is emitted and the checkpointed values win.

        Raises:
            KeyError: if ``ckpt_dict`` is missing any of the required recipe-state keys.
        """
        try:
            if (
                self.seed != ckpt_dict[utils.SEED_KEY]
                or self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]
                or self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]
            ):
                warn(
                    message="""Configured value for seed, epochs or max_steps_per_epoch
                    does not match the value stored in checkpoint."""
                )
            self.seed = utils.set_seed(seed=ckpt_dict[utils.SEED_KEY])
            self.epochs_run = ckpt_dict[utils.EPOCHS_KEY]
            self.total_epochs = ckpt_dict[utils.TOTAL_EPOCHS_KEY]
            self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY]
        except KeyError as e:
            # Build a new KeyError with a helpful message and chain the original
            # one as its cause (``from e``) so the missing key stays visible.
            raise KeyError(
                "Checkpoint does not contain the required keys needed for updating recipe state. "
                "Are you sure you passed in the right recipe checkpoint?"
            ) from e

    def setup(self, cfg: DictConfig) -> None:
        """
        Set up all recipe components: metric logger, checkpoint, model, tokenizer,
        optimizer, loss, sampler and dataloader. Recipe attributes are set based on
        the ``resume_from_checkpoint`` flag.

        Components are created in dependency order; see the inline comments.
        """
        self._metric_logger = config.instantiate(cfg.metric_logger)

        # log config with parameter override
        self._metric_logger.log_config(cfg)

        ckpt_dict = self.load_checkpoint(cfg.checkpointer)

        # ``_setup_model`` handles initialization and loading the state dict. This method
        # should be called before ``_setup_optimizer`` since transforming the optimizer
        # state dict requires the model
        self._model_compile = cfg.compile
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            compile_model=self._model_compile,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )
        self._tokenizer = config.instantiate(cfg.tokenizer)
        log.info("Tokenizer is initialized from file.")

        # _setup_optimizer should take in ckpt_dict only if training is resumed from
        # checkpoint. Transforming the opt state dict is handled by this method
        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            optimizer_in_bwd=cfg.optimizer_in_bwd,
            opt_state_dict=(
                ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None
            ),
        )

        self._loss_fn = config.instantiate(cfg.loss)
        log.info("Loss is initialized.")

        # sampler and dataloader depend on the tokenizer and loss_fn and should be
        # setup after both of these are initialized
        self._sampler, self._dataloader = self._setup_data(
            cfg_dataset=cfg.dataset,
            shuffle=cfg.shuffle,
            batch_size=cfg.batch_size,
        )

        # Finally update the recipe state which can only be correctly set after all of the
        # other components have been initialized and updated.
        #
        # Number of training steps in each epoch depends on the number of batches produced
        # by the dataloader, the max_steps_per_epoch param set by the user and the
        # gradient_accumulation_steps param. This value is used for logging and tracking
        # training state. The computation should happen after the dataloader has been setup
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        # When resuming, continue the global step count from the completed epochs.
        self.global_step = self.epochs_run * self._steps_per_epoch

    def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        compile_model: bool,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        """
        Instantiate the model on ``self._device`` with ``self._dtype`` as the default
        dtype, optionally enable activation checkpointing, load ``model_state_dict``,
        validate parameter dtypes and optionally compile the model.

        Returns:
            nn.Module: the initialized model.
        """
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(cfg_model)

        if enable_activation_checkpointing:
            utils.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerDecoderLayer}
            )

        model.load_state_dict(model_state_dict)

        # Validate model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        log.info(f"Model is initialized with precision {self._dtype}.")

        # Compile model, if enabled. The backend can be overridden through the
        # TORCH_COMPILE_BACKEND environment variable.
        if compile_model:
            log.info("Compiling model with torch.compile...")
            backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
            model.compile(backend=backend)

        if self._device.type == "cuda":
            memory_stats = utils.get_memory_stats(device=self._device)
            utils.log_memory_stats(memory_stats)

        return model

    def _setup_optimizer(
        self,
        cfg_optimizer: DictConfig,
        optimizer_in_bwd: bool = False,
        opt_state_dict: Optional[Dict[str, Any]] = None,
    ) -> Optional[Optimizer]:
        """
        Set up the optimizer and, if given, load ``opt_state_dict`` into it.

        With ``optimizer_in_bwd=True`` one optimizer is created per parameter and
        stepped from hooks registered on the model; in that case ``None`` is returned
        and checkpointing goes through ``self._optim_ckpt_wrapper``.

        Returns:
            Optional[Optimizer]: the optimizer, or ``None`` when running optimizer in backward.

        Raises:
            RuntimeError: if in-backward optimizer states cannot be loaded.
        """
        if optimizer_in_bwd:
            # Maintain a dict of optims for every parameter.
            optim_dict = {
                p: config.instantiate(cfg_optimizer, [p])
                for p in self._model.parameters()
            }
            # Register optimizer step hooks on the model to run optimizer in backward.
            utils.register_optim_in_bwd_hooks(model=self._model, optim_dict=optim_dict)
            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
            self._optim_ckpt_wrapper = utils.create_optim_in_bwd_wrapper(
                model=self._model, optim_dict=optim_dict
            )
            # Load optimizer states. If optimizer states are being restored in an optimizer in backward
            # run, these need to have been saved with the same setting. Cannot restore from runs that did not
            # use optimizer in backward.
            if opt_state_dict is not None:
                try:
                    self._optim_ckpt_wrapper.load_state_dict(opt_state_dict)
                except Exception as e:
                    # Exception (not BaseException) so Ctrl-C / SystemExit are not
                    # misreported as a checkpoint mismatch.
                    raise RuntimeError(
                        "Failed loading in-backward optimizer checkpoints. "
                        "Please make sure run being restored from was using in-backward optimizer."
                    ) from e
            log.info("In-backward optimizers are set up.")
            return None
        else:
            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())

            if opt_state_dict:
                optimizer.load_state_dict(opt_state_dict)
            log.info("Optimizer is initialized.")
            return optimizer

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
    ) -> Tuple[DistributedSampler, DataLoader]:
        """
        All data related setup happens here. Currently this recipe only supports the
        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
        iterable datasets and streaming datasets are not supported.

        A ``ListConfig`` of datasets is concatenated into one ``ConcatDataset`` (and is
        treated as not packed). A single dataset config may set ``packed``; packed
        datasets use the DataLoader's default collation instead of padding.

        Returns:
            Tuple[DistributedSampler, DataLoader]: the sampler and the dataloader using it.
        """
        if isinstance(cfg_dataset, ListConfig):
            datasets = [
                config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer)
                for single_cfg_dataset in cfg_dataset
            ]
            ds = ConcatDataset(datasets=datasets)
            packed = False
        else:
            ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer)
            packed = cfg_dataset.get("packed", False)

        # Single device: one replica with rank 0, so the sampler covers the whole dataset.
        sampler = DistributedSampler(
            ds,
            num_replicas=1,
            rank=0,
            shuffle=shuffle,
            seed=0,
        )
        dataloader = DataLoader(
            dataset=ds,
            batch_size=batch_size,
            sampler=sampler,
            collate_fn=(
                partial(
                    utils.padded_collate,
                    padding_idx=self._tokenizer.pad_id,
                    ignore_idx=self._loss_fn.ignore_index,
                )
                if not packed
                else None
            ),
        )

        log.info("Dataset and Sampler are initialized.")

        return sampler, dataloader

    def save_checkpoint(self, epoch: int) -> None:
        """
        Save state dict to file. The recipe save_checkpoint method is responsible for
        correctly creating the checkpoint dict and passing to the checkpointer.

        Model weights are always saved. For every epoch except the last one, the
        recipe state and optimizer state are saved as well, and the checkpoint is
        marked as intermediate.

        Args:
            epoch (int): the epoch that just finished (0-indexed).
        """
        ckpt_dict = {utils.MODEL_KEY: self._model.state_dict()}
        # if training is in-progress, checkpoint the recipe and optimizer state as well
        if epoch + 1 < self.total_epochs:
            ckpt_dict.update(
                {
                    utils.SEED_KEY: self.seed,
                    utils.EPOCHS_KEY: self.epochs_run,
                    utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                    utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
                }
            )
            if not self._optimizer_in_bwd:
                ckpt_dict[utils.OPT_KEY] = self._optimizer.state_dict()
            else:
                ckpt_dict[utils.OPT_KEY] = self._optim_ckpt_wrapper.state_dict()
        self._checkpointer.save_checkpoint(
            ckpt_dict,
            epoch=epoch,
            intermediate_checkpoint=(epoch + 1 < self.total_epochs),
        )

    def train(self) -> None:
        """
        The core training loop. Supports training on subsets of the dataset using the
        ``max_steps_per_epoch``.

        One optimizer step (and one progress-bar tick) happens every
        ``gradient_accumulation_steps`` micro-batches. Metrics are logged every
        ``log_every_n_steps`` optimizer steps, and a checkpoint is saved at the end
        of every epoch.
        """
        if self._model_compile:
            log.info(
                "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration."
            )
        # zero out the gradients before starting training
        if not self._optimizer_in_bwd:
            self._optimizer.zero_grad()

        # Initialize tokens count and running loss (for grad accumulation)
        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

            pbar = tqdm(total=self._steps_per_epoch)
            for idx, batch in enumerate(self._dataloader):
                # Stop once max_steps_per_epoch optimizer steps have been taken.
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                # Both are shape [b, s]
                tokens, labels = batch["tokens"], batch["labels"]
                # Get the attention mask and position ids from the dataset if they
                # exist. Currently, only sample packing in PackedDataset returns these
                mask = batch.get("mask", None)  # shape [b, s, s]
                input_pos = batch.get("input_pos", None)  # shape [b, s]

                tokens = tokens.to(self._device)
                num_tokens += tokens.numel()
                labels = labels.to(self._device)
                mask = mask.to(self._device) if mask is not None else None
                input_pos = (
                    input_pos.to(self._device) if input_pos is not None else None
                )

                logits = self._model(tokens, mask=mask, input_pos=input_pos)
                # Shift so that tokens < n predict n
                logits = logits[..., :-1, :].contiguous()
                labels = labels[..., 1:].contiguous()
                # Presumably moves the vocab dim to dim 1, as class-index losses such
                # as nn.CrossEntropyLoss expect — confirm against cfg.loss.
                logits = logits.transpose(1, 2)
                # Compute loss
                loss = self._loss_fn(logits, labels)
                # Scale so the accumulated gradient averages over micro-batches.
                loss = loss / self._gradient_accumulation_steps
                # running_loss is only used for logging, so detach it to avoid
                # keeping references to the autograd graph of every micro-batch.
                running_loss += loss.detach()
                loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    if not self._optimizer_in_bwd:
                        self._optimizer.step()
                        self._optimizer.zero_grad(set_to_none=True)

                    self.global_step += 1

                    loss_to_log = running_loss.item()
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}"
                    )

                    # Log per-step metrics
                    if self.global_step % self._log_every_n_steps == 0:
                        time_per_step = time.perf_counter() - t0
                        log_dict = {
                            "loss": loss_to_log,
                            # NOTE: for optim in backward, this assumes all optimizers have the same LR. This is currently
                            # true since we don't expose the ability to configure this yet.
                            "lr": (
                                self._optim_ckpt_wrapper.get_optim_key("lr")
                                if self._optimizer_in_bwd
                                else self._optimizer.param_groups[0]["lr"]
                            ),
                            "tokens_per_second_per_gpu": num_tokens / time_per_step,
                        }
                        if self._device.type == "cuda" and self._log_peak_memory_stats:
                            log_dict.update(utils.get_memory_stats(device=self._device))
                        self._metric_logger.log_dict(
                            log_dict,
                            step=self.global_step,
                        )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

            # NOTE: if the number of micro-batches in an epoch is not a multiple of
            # gradient_accumulation_steps, the leftover gradients, running_loss and
            # num_tokens are not reset here and carry into the next epoch's first step.
            pbar.close()
            self.epochs_run += 1
            self.save_checkpoint(epoch=curr_epoch)

    def cleanup(self) -> None:
        """Release recipe resources by closing the metric logger."""
        self._metric_logger.close()
@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in config (see available configs through ``tune ls``)
        - Overwritten by arguments from the command-line
    """
    config.log_config(recipe_name="FullFinetuneRecipeSingleDevice", cfg=cfg)
    recipe = FullFinetuneRecipeSingleDevice(cfg=cfg)
    recipe.setup(cfg=cfg)
    # Run cleanup (which closes the metric logger) even if training raises.
    # The try starts after setup() because the metric logger only exists once
    # setup has created it.
    try:
        recipe.train()
    finally:
        recipe.cleanup()
# Script entry point: run the recipe and hand its return value to sys.exit.
if __name__ == "__main__":
    sys.exit(recipe_main())