Llama uses significantly more memory in 4.38 & 4.39 than 4.37 with identical code #30010
Comments
That is a really great report, thanks! 🤗
TL;DR: there must be something a bit wrong for FSDP
Yes, that's likely the issue.
FSDP is not the issue, it's just that in that case we're using a lot less memory for weights and doing gradient checkpointing, so the spikes show up better.
@fxmarty can you have a look? 🤗
As @johnowhitaker mentioned, FSDP is not the issue.
@ArthurZucker, I couldn't fit a large enough sequence length with 4-bit Llama 7B, so I'll illustrate with a single Attention layer. It will use an attention mask, and thus the SDPA Efficient kernel, if sdpa_flash=False, or is_causal=True and the Flash Attention kernel if sdpa_flash=True:
class Attention(torch.nn.Module):
def __init__(
self,
hidden_size: int = 2048,
num_heads: int = 16,
seq_len: int = 2048,
sdpa_flash: bool = False,
):
super().__init__()
self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
self.Wo = nn.Linear(hidden_size, hidden_size, bias=False)
self.nh = num_heads
self.sdpa_flash = sdpa_flash
if not sdpa_flash:
self.register_buffer(
"causal_mask",
torch.triu(torch.ones([seq_len, seq_len], dtype=torch.bool), diagonal=1)
.logical_not()
.view(1, 1, seq_len, seq_len),
)
def forward(self, x):
B, S, C = x.shape
x = self.Wqkv(x).reshape(B, S, 3, self.nh, C // self.nh)
query, key, value = x.transpose(3, 1).unbind(dim=2)
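        # With sdpa_flash=False an explicit boolean mask is passed, which (on PyTorch 2.2)
        # makes scaled_dot_product_attention fall back to the memory-efficient kernel;
        # with sdpa_flash=True no mask is passed and is_causal=True keeps the flash kernel eligible.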
attn = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=self.causal_mask[:, :, :S, :S] if not self.sdpa_flash else None,
is_causal=self.sdpa_flash,
)
        return self.Wo(attn.transpose(1, 2).reshape(B, S, C))
At a batch size of 128, the difference in memory usage between the two kernels is clearly visible.
Here is the script and the commands to recreate the single Attention layer memory measurements:
python example.py --batch_size 16
python example.py --batch_size 16 --sdpa_flash
python example.py --batch_size 128 --sdpa_flash
import argparse
from random import randint
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import Dataset
torch.set_float32_matmul_precision("high")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--max_sequence_length", type=int, default=2048)
parser.add_argument("--dataset_size", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--vocab_size", type=int, default=64)
parser.add_argument("--hidden_size", type=int, default=2048)
parser.add_argument("--num_heads", type=int, default=16)
parser.add_argument("--sdpa_flash", action="store_true")
parser.add_argument("--torch_compile", action="store_true")
parser.add_argument("--profile_memory", action="store_true")
return parser.parse_args()
def get_dataset(dataset_size, sequence_length, vocab_size=64):
dataset = Dataset.from_dict(
{"input_ids": [[randint(0, vocab_size) for _ in range(sequence_length)] for i in range(0, dataset_size)]}
)
return dataset
def data_collator(batch):
batch = torch.stack([b["input_ids"] for b in batch])
return {"input_ids": batch, "labels": batch}
def append_stats(batch, stats):
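    # Record torch.cuda.memory_reserved only on the first batch, so each stage
    # (before forward, before backward, before the optimizer step) is captured once.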
if batch == 0:
stats.append(torch.cuda.memory_reserved(0))
class Attention(torch.nn.Module):
def __init__(
self,
hidden_size: int = 2048,
num_heads: int = 16,
seq_len: int = 2048,
sdpa_flash: bool = False,
):
super().__init__()
self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
self.Wo = nn.Linear(hidden_size, hidden_size, bias=False)
self.nh = num_heads
self.sdpa_flash = sdpa_flash
if not sdpa_flash:
self.register_buffer(
"causal_mask",
torch.triu(torch.ones([seq_len, seq_len], dtype=torch.bool), diagonal=1)
.logical_not()
.view(1, 1, seq_len, seq_len),
)
def forward(self, x):
B, S, C = x.shape
x = self.Wqkv(x).reshape(B, S, 3, self.nh, C // self.nh)
query, key, value = x.transpose(3, 1).unbind(dim=2)
attn = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=self.causal_mask[:, :, :S, :S] if not self.sdpa_flash else None,
is_causal=self.sdpa_flash,
)
return self.Wo(attn.transpose(1, 2).reshape(B, S, C))
class OneLayerTransformer(nn.Module):
def __init__(
self,
vocab_size: int,
hidden_size: int,
num_heads: int,
seq_len: int,
sdpa_flash: bool = False,
loss_fn: nn.Module = nn.CrossEntropyLoss(),
):
super().__init__()
self.loss_fn = loss_fn
self.We = nn.Embedding(vocab_size, hidden_size)
self.attn = Attention(hidden_size, num_heads, seq_len, sdpa_flash)
self.norm = nn.LayerNorm(hidden_size)
def forward(self, input_ids, labels):
x = self.We(input_ids)
x = x + self.attn(self.norm(x))
return self.loss_fn(x.view(-1, x.shape[-1]), labels.view(-1))
def main():
args = parse_args()
print(args)
stats = []
model = OneLayerTransformer(
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
num_heads=args.num_heads,
seq_len=args.max_sequence_length,
sdpa_flash=args.sdpa_flash,
).to(dtype=torch.bfloat16, device="cuda")
optimizer = optim.SGD(model.parameters(), lr=0.001)
dataset = get_dataset(
args.dataset_size * args.batch_size,
args.max_sequence_length,
args.vocab_size - 1,
).with_format("torch")
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.num_workers,
collate_fn=data_collator,
)
if args.torch_compile:
print("Compiling model")
model = torch.compile(model)
torch.cuda.reset_peak_memory_stats(0)
for i, batch in enumerate(tqdm(dataloader)):
input_ids = batch["input_ids"].to("cuda")
labels = batch["labels"].to("cuda")
if args.profile_memory and i == 0:
torch.cuda.memory._record_memory_history()
append_stats(i, stats)
loss = model(input_ids=input_ids, labels=labels)
append_stats(i, stats)
loss.backward()
append_stats(i, stats)
optimizer.step()
optimizer.zero_grad()
if args.profile_memory and i == 0:
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
for label, stat in zip(["forward", "backward", "optimizer"], stats):
print(f"Before {label}: {stat/2**30:.2f} GiB")
print(f"Max Reserved: {torch.cuda.max_memory_reserved(0)/2**30:.2f} GiB")
if __name__ == "__main__":
    main()
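As a side note (not part of the original script), one way to confirm which SDPA kernel actually runs is to profile a single step with torch.profiler. The exact kernel names vary by PyTorch build, but the flash kernels typically contain "flash" in their names while the efficient kernel's do not. A minimal sketch, assuming the model and a batch from the script above:
from torch.profiler import profile, ProfilerActivity

def profile_one_step(model, input_ids, labels):
    # Profile CUDA kernels for one forward/backward pass and print the busiest ones.
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        loss = model(input_ids=input_ids, labels=labels)
        loss.backward()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))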
Hi @warner-benjamin, this will be fixed in #30070 in case you are not using torch.compile. If you use torch.compile, we will first need pytorch/pytorch#120400 (only disallow FA2 for compile with fullgraph=True).
System Info
Transformers 4.37.2, 4.38.2, & 4.39.3.
Python 3.11.
PyTorch 2.2.2 w/ CUDA 12.1
Who can help?
@ArthurZucker and @younesbelkada
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
QLoRA Llama-7B with SDPA Attention, when run with 4.37.2 on a 24GB card, uses less memory than with 4.38.2 and 4.39.3. You can reproduce with this script:
# 4.37.2
python train.py --batch_size 1 --max_sequence_length 1280
python train.py --batch_size 1 --max_sequence_length 1536
Expected behavior
While spot checking our FSDP+QLoRA script on Transformers 4.39.3, I noticed that the maximum batch size we could finetune on two 24 GB cards with a sequence length of 2048 was reduced from 12 to 5 compared to 4.37.2. This is due to Llama 2 in 4.38 and 4.39 using significantly more memory than 4.37.
This Llama memory issue persists post #29753, which resolved some but not all of the Llama rewrite memory issues mentioned in #29484 and other issues.
I also reproduced the issue without FSDP on one 24GB card using the script above.
4.39.3 uses almost a GB more memory at a sequence length of 1280 and errors out at 1536.
Using torch.cuda.memory._record_memory_history shows a peak in memory usage in 4.37.2 at the start of the backward pass, as expected (use the --peak_memory flag in the above script). Switching to 4.39.3 shows an unexpected result: consistently large memory spikes during the backward pass of the SDPA kernel:
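For reference, a minimal sketch of how such a trace can be captured (this mirrors the --profile_memory path in the single-layer script earlier in the thread and assumes a CUDA device); the resulting pickle can be inspected at https://pytorch.org/memory_viz:
import torch

torch.cuda.memory._record_memory_history()                # start recording allocator events
# ... run one forward/backward/optimizer step of the training script here ...
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)    # stop recording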
When sharding Llama across two cards with FSDP and gradient checkpointing, the memory spikes become quite visible as ~10GB outliers:
There are no spikes in 4.37 with FSDP.
The culprit appears to be the switch from using the SDPA Flash Attention kernel in 4.37.2 to the SDPA Efficient kernel in 4.38 & 4.39. By removing is_causal and using a custom causal_mask instead, scaled_dot_product_attention now selects the more memory-hungry Efficient kernel instead of the memory-efficient Flash Attention 2 kernel.
You can spot check this by switching to LlamaFlashAttention2 in the reproduction script with the --flash_attn flag. With LlamaFlashAttention2, 4.39 uses moderately more memory than 4.37, although this may grow to a significant amount at longer context sizes.
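To illustrate the kernel selection directly, a minimal sketch (not from the original report, assuming PyTorch 2.2 and an Ampere-or-newer CUDA GPU for bfloat16): restricting SDPA to the flash backend via torch.backends.cuda.sdp_kernel should fail when an explicit boolean mask is passed, but succeed with is_causal=True.
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 16, 2048, 128, device="cuda", dtype=torch.bfloat16)
causal = torch.ones(2048, 2048, dtype=torch.bool, device="cuda").tril().view(1, 1, 2048, 2048)

# Allow only the flash backend so the dispatch decision becomes visible.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    F.scaled_dot_product_attention(q, k, v, is_causal=True)        # flash kernel runs
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=causal)  # flash cannot take an explicit mask
    except RuntimeError as err:
        print("flash-only SDPA rejected the explicit mask:", err)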
When training, LlamaSdpaAttention should use is_causal=True when no attention mask is passed to Llama, instead of creating causal_mask.
The switch to a custom causal_mask also causes torch.compile errors that are not present in 4.37, due to data type mismatches in 4.39, particularly when training in pure bfloat16 instead of mixed precision.
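A hypothetical sketch (not the actual transformers implementation) of the dispatch this would imply: only pass an explicit mask when the caller supplies one, and otherwise rely on is_causal=True so SDPA can stay on the flash kernel during training.
from typing import Optional

import torch
import torch.nn.functional as F

def sdpa_attention(query, key, value, attention_mask: Optional[torch.Tensor] = None):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    if attention_mask is None:
        # No padding mask: let SDPA handle causality internally (flash-eligible).
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)
    # A user-supplied mask (e.g. padding merged with causality) must be passed
    # explicitly, which on PyTorch 2.2 falls back to the efficient/math kernels.
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)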