Can't release memory occupied by model after trainer.train() with del model and gc.collect(). #26571
Comments
WDYT @muellerzr @pacman100?
I'm not sure we really can reduce it all to zero, due to PyTorch itself. Take the below example, which removes all major transformers and accelerate code and does everything in pure Python, with bare freeing of memory (it's the same thing as what you do manually there). No matter what, we are still left with 8.125 in allocated, 20.0 in reserved. Also: this only shows up after getting an output from the model. Script in question:

```python
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from accelerate.utils import release_memory, send_to_device

config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}
MAX_GPU_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 32


def memory_stats():
    return torch.cuda.memory_summary()


def get_dataloader(batch_size: int = 16):
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    train_dataset = load_dataset("glue", "mrpc", split="train[:64]")

    def tokenize_function(examples):
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    tokenized_datasets = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        max_length = None
        pad_to_multiple_of = None
        return tokenizer.pad(
            examples,
            padding="longest",
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

    train_dataloader = DataLoader(
        tokenized_datasets, shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )
    return train_dataloader


batch_size = int(config["batch_size"])
train_dataloader = get_dataloader(batch_size)

model = AutoModel.from_pretrained("bert-base-cased")
model = model.to("cuda")
model.eval()

with torch.inference_mode():
    batch = next(iter(train_dataloader))
    batch = batch.to("cuda")
    out = model.forward(batch["input_ids"], batch["attention_mask"])
    out = send_to_device(out, "cpu")

model.cpu()
model, batch = release_memory(model, batch)
print(
    f"Memory allocated: {torch.cuda.memory_allocated()/1024**2}\nMemory reserved: {torch.cuda.memory_reserved()/1024**2}"
)
```

To make sure this is actually PyTorch and not something to do with transformers, I then checked with a basic PyTorch model:
```python
import torch
from accelerate.utils import release_memory


def memory_stats():
    return torch.cuda.memory_summary()


class TinyModel(torch.nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.linear1 = torch.nn.Linear(100, 200)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(200, 10)
        self.softmax = torch.nn.Softmax(dim=0)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x


model = TinyModel().cuda()
batch = torch.rand(64, 100).cuda()
_ = model(batch)
model, batch = release_memory(model, batch)
print(
    f"Memory allocated: {torch.cuda.memory_allocated()/1024**2}\nMemory reserved: {torch.cuda.memory_reserved()/1024**2}"
)
```

If you run this you will find that, yet again, we have a similar leftover memory allocation. So I'm not 100% convinced that this is a problem we can solve. If you can release those memory allocations then we can work with that solution, but after extensive research I have found it impossible to free all of it entirely after the model has been run on an input. Likely this is some intermediate activation memory that somehow remains allocated and is never freed. Note: including
Official response from the torch team:
So, that 16/40 will always remain and there isn't anything else we can do aside from that.
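For anyone who wants to see that baseline in isolation, here is a minimal sketch (not from the original thread) that reproduces the same effect with nothing but a single matmul. It assumes a CUDA-enabled PyTorch build; the exact leftover figures depend on the PyTorch version and GPU, and the CUDA context shown by nvidia-smi is only released when the process exits.

```python
# Minimal sketch of the "leftover baseline" effect described above (assumes a CUDA build).
import gc

import torch

x = torch.rand(1024, 1024, device="cuda")
y = x @ x  # the first matmul initializes cuBLAS workspaces on this device

del x, y
gc.collect()
torch.cuda.empty_cache()

# Even with every tensor freed, allocated/reserved may not drop to exactly zero,
# and nvidia-smi still shows the CUDA context until the process exits.
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024**2}")
print(f"Memory reserved: {torch.cuda.memory_reserved() / 1024**2}")
```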
Thank you for your response!
With the inclusion of this line, I'm unable to release the memory occupied by the model:
After deleting this line I got the expected behavior:
@muellerzr Could you please provide any suggestions or make changes to transformers/trainer.py?
You need to fully remove the model off CUDA, yes.
Do you mean using ...? However, when I delete line 1988 in transformers/trainer.py ... How can I fully remove the model off CUDA? Could you please recheck this? Thank you! @muellerzr
@hanrui4248 I was successful after the following:

```python
...
trainer.train()

del model, trainer
gc.collect()
torch.cuda.empty_cache()
```

This got me to the tiny amount of memory allocated after. My full script:

```python
import gc

import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer


def memory_stats():
    return f"Memory allocated: {torch.cuda.memory_allocated()/1024**2}\nMemory reserved: {torch.cuda.memory_reserved()/1024**2}"


imdb = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)


tokenized_imdb = imdb.map(preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)
print('Model memory:', memory_stats())

training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    max_steps=10,
    weight_decay=0.01,
    save_strategy="no",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_imdb["train"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

print('\nMemory stats before release:', memory_stats())
del trainer
del model
gc.collect()
torch.cuda.empty_cache()
print('\nMemory stats after release:', memory_stats())
```

Print statements:
Thank you! But it didn't work with my script. Could this be because I used PEFT and LoRA in it? I think I've tried every possible way to release the memory, but I still can't free up the entire model's memory. After a lot of tries, here are my conclusions:

Here is my script. Could you confirm my conclusions by executing it? @muellerzr Thank you!
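Since the failing case involves PEFT, here is a hypothetical sketch (not the author's script, which is not reproduced above) of the same teardown pattern applied to a LoRA-wrapped flan-t5-small. The model name and LoRA settings are placeholders; the point is that both the PEFT wrapper and the base model must be moved off CUDA and dereferenced before `gc.collect()` and `torch.cuda.empty_cache()` can help.

```python
# Hypothetical sketch: LoRA-wrapped flan-t5-small, then the CPU-offload + delete + gc pattern.
# The model name and LoRA settings are placeholders, not the author's actual configuration.
import gc

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSeq2SeqLM

base_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
peft_model = get_peft_model(base_model, LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, r=8))
peft_model.to("cuda")

# ... training with Trainer would happen here ...

# The wrapper and the base model share the same CUDA parameters, so move them to CPU
# and drop every name that can still reach them before emptying the cache.
peft_model.to("cpu")
del peft_model, base_model
gc.collect()
torch.cuda.empty_cache()

print(f"Memory allocated: {torch.cuda.memory_allocated() / 1024**2}")
print(f"Memory reserved: {torch.cuda.memory_reserved() / 1024**2}")
```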
In that case the issue stems from peft, so I'd recommend migrating/opening this issue there, as I'm not sure what it could be :)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
Who can help?
@muellerzr @pacman100 @ArthurZucker @younesbelkada
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
I'm using LoRA and the flan-t5-small model for a summarization task, and I want to release the memory occupied by the model. However, it didn't work even though I tried using `del model` and `gc.collect()`. Following is my code:

Output:
output of nvidia-smi after training looks like:
Expected behavior
The output after release should be 0 for both allocated and reserved memory.
I also tried moving the release-memory operations before `trainer.train()`, as shown below:

The memory is successfully released after I made this change:
This leads me to suspect that there might be some internal references to the model within trainer.train(). So I delved into the source code of trainer.train() and, by copying the entire original method and removing certain lines, I identified potential places that could cause memory leakage:
transformers/trainer.py line 1988
transformers/trainer.py line 1891

After removing only place 1, and placing the memory release operations after `trainer.train()`, the output is:

After removing both place 1 and place 2, and placing the memory release operations after `trainer.train()`, the output is:

I'm trying to understand what's causing this behavior, but it seems so magical. Is this a bug? How can I fully release the memory? I need to instantiate and train the model multiple times, but I don't have enough memory to instantiate the XXL model twice (without releasing memory in between).
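As a general illustration of the suspicion above (and not a claim about what trainer.py actually does at those lines), the following sketch shows why `del model` plus `gc.collect()` cannot free CUDA memory while any other object still holds a reference to the module; the dict is a hypothetical stand-in for whatever internal reference might survive `trainer.train()`.

```python
# Illustrative sketch of the reference-holding mechanism; `holder` is a stand-in for
# any internal object that still references the model after training.
import gc

import torch

model = torch.nn.Linear(4096, 4096).cuda()  # roughly 64 MiB of parameters
holder = {"model": model}                   # hypothetical lingering reference

del model
gc.collect()
torch.cuda.empty_cache()
print(f"With a lingering reference: {torch.cuda.memory_allocated() / 1024**2:.1f} MiB")  # still ~64

holder.clear()                              # drop the last reference
gc.collect()
torch.cuda.empty_cache()
print(f"After dropping it: {torch.cuda.memory_allocated() / 1024**2:.1f} MiB")           # ~0
```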
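Finally, since the underlying need is to train several models in one process without holding two of them in GPU memory at once, here is a hedged sketch of that loop using a tiny pure-PyTorch stand-in (an assumption, not the flan-t5/PEFT setup above); with a transformers Trainer the same cleanup order applies: move to CPU, delete the trainer and model, `gc.collect()`, then `torch.cuda.empty_cache()`.

```python
# Hedged sketch of "instantiate, train, release, repeat" with a tiny stand-in model.
import gc

import torch


def train_once(steps: int = 10) -> None:
    model = torch.nn.Sequential(
        torch.nn.Linear(100, 200), torch.nn.ReLU(), torch.nn.Linear(200, 10)
    ).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    for _ in range(steps):
        loss = model(torch.rand(16, 100, device="cuda")).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    # The optimizer holds CUDA state tensors too, so drop it together with the model.
    model.cpu()
    del model, optimizer
    gc.collect()
    torch.cuda.empty_cache()


for run in range(3):
    train_once()
    print(f"After run {run}: {torch.cuda.memory_allocated() / 1024**2:.2f} MiB allocated")
```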