Can't release memory occupied by model after trainer.train() with del model and gc.collect(). #26571

Closed

hanrui4248 opened this issue Oct 3, 2023 · 10 comments

@hanrui4248

hanrui4248 commented Oct 3, 2023

System Info

  • torch==2.0.1
  • transformers==4.31.0
  • peft==0.4.0
  • accelerate==0.20.3
  • bitsandbytes==0.41.1
  • gpu : Quadro RTX 8000

Who can help?

@muellerzr @pacman100 @ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm using LoRA with the flan-t5-small model for a summarization task, and I want to release the memory occupied by the model after training. However, the memory isn't released even though I tried del model and gc.collect(). The following is my code:

import torch
import gc
from functools import partial
from peft import prepare_model_for_int8_training
from peft import LoraConfig, get_peft_model
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from metrics import compute_metrics
from ordalie_dataset import get_processed_ordalie_dataset

def memory_stats():
    print("memory allocated: ", torch.cuda.memory_allocated()/1024**2)
    print("memory reserved: ", torch.cuda.memory_reserved()/1024**2)

model_name = "google/flan-t5-small"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, load_in_8bit=True)

print("model's memory:")
memory_stats()

tokenizer = AutoTokenizer.from_pretrained(model_name)
lora_config = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q", "v"], lora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM"
)

model = prepare_model_for_int8_training(model)

model = get_peft_model(model, lora_config)

args = Seq2SeqTrainingArguments(
    "temp",  
    evaluation_strategy="epoch",
    learning_rate=5.6e-5,
    gradient_accumulation_steps=12,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    save_strategy = "no",
    predict_with_generate=True, 
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

dataset = get_processed_ordalie_dataset(
            tokenizer,
            512,
            42,
        )

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=partial(compute_metrics, tokenizer=tokenizer),
)


trainer.train()

print("memory before release:")
memory_stats()

del model
del data_collator
del trainer
gc.collect()
torch.cuda.empty_cache()
print("memory after release:")
memory_stats()

Output:

model's memory:
memory allocated:  130.3193359375
memory reserved:  136.0

memory before release:
memory allocated:  244.77490234375
memory reserved:  18390.0

memory after release:
memory allocated:  223.77490234375
memory reserved:  288.0

The output of nvidia-smi after training looks like:

Tue Oct  3 12:16:25 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Quadro RTX 8000                On  | 00000000:14:00.0 Off |                 Off* |
| 34%   25C    P8              16W / 260W |    497MiB / 49152MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     15268      C   ...i.huang/.conda/envs/llms/bin/python      494MiB |
+---------------------------------------------------------------------------------------+

Expected behavior

The output after release should be 0 for both allocated and reserved memory.
I also tried moving the memory release operations before trainer.train(), as shown below:

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=partial(compute_metrics, tokenizer=tokenizer),
)

del model
del data_collator
del trainer

print("memory after release:")
memory_stats()

trainer.train()

The memory is successfully released after I made this change:

memory after release:
memory allocated:  0.0
memory reserved:  0.0

This leads me to suspect that there might be some internal references to the model within trainer.train(). So I delved into the source code of trainer.train() and, by copying the entire original method and removing certain lines, I identified potential places that could cause memory leakage:

  1. transformers/trainer.py, line 1988:

if self.control.should_epoch_stop or self.control.should_training_stop:
    break

  2. transformers/trainer.py, line 1891:

with self.accelerator.accumulate(model):
    tr_loss_step = self.training_step(model, inputs)

After removing only place 1, and placing the memory release operations after trainer.train(), the output is:

memory after release:
memory allocated:  16.25
memory reserved:  40

After removing both place 1 and place 2, and placing the memory release operations after trainer.train(), the output is:

memory after release:
memory allocated:  0.0
memory reserved:  0.0

I'm trying to understand what's causing this behavior, but it seems quite mysterious. Is this a bug? How can I fully release the memory? I need to instantiate and train the model multiple times, and I don't have enough memory to instantiate the XXL model twice without releasing memory in between.
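
As a debugging aid (a sketch added for illustration, not part of my training script; find_live_cuda_tensors is just a hypothetical helper name), something like the following can list which CUDA tensors are still alive after the cleanup and what kinds of objects still reference them:

import gc
import torch

def find_live_cuda_tensors():
    # Debugging sketch: walk everything the garbage collector knows about and
    # report CUDA tensors that are still reachable, plus their referrers' types.
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                referrers = [type(r).__name__ for r in gc.get_referrers(obj)]
                print(type(obj).__name__, tuple(obj.shape), referrers[:3])
        except Exception:
            # Some objects raise on inspection; skip them.
            continue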

@LysandreJik
Member

WDYT @muellerzr @pacman100 ?

@muellerzr muellerzr self-assigned this Oct 4, 2023
@muellerzr
Contributor

muellerzr commented Oct 4, 2023

I'm not sure we can really reduce it all to zero, due to PyTorch itself. Take the example below, which removes all of the major transformers and accelerate code, does everything in plain Python, and does bare-bones freeing of memory (the same thing you do manually there). No matter what, we are still left with 8.125 MiB allocated and 20.0 MiB reserved. Also note: this only shows up after getting an output from the model.

Script in question:

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from accelerate.utils import release_memory, send_to_device

config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 16}

MAX_GPU_BATCH_SIZE = 16
EVAL_BATCH_SIZE = 32

def memory_stats():
    return torch.cuda.memory_summary()


def get_dataloader(batch_size: int = 16):
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    train_dataset = load_dataset("glue", "mrpc", split="train[:64]")

    def tokenize_function(examples):
        outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
        return outputs

    tokenized_datasets = train_dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )

    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        max_length = None
        pad_to_multiple_of = None

        return tokenizer.pad(
            examples,
            padding="longest",
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

    train_dataloader = DataLoader(
        tokenized_datasets, shuffle=True, collate_fn=collate_fn, batch_size=batch_size, drop_last=True
    )

    return train_dataloader

batch_size = int(config["batch_size"])
train_dataloader = get_dataloader(batch_size)
model = AutoModel.from_pretrained("bert-base-cased")
model = model.to("cuda")
model.eval()
with torch.inference_mode():
    batch = next(iter(train_dataloader))
    batch = batch.to("cuda")
    out = model.forward(batch["input_ids"], batch["attention_mask"])
    out = send_to_device(out, "cpu")

model.cpu()

model, batch = release_memory(model, batch)
print(
    f"Memory allocated: {torch.cuda.memory_allocated()/1024**2}\nMemory reserved: {torch.cuda.memory_reserved()/1024**2}"
)

To make sure this is actually PyTorch and not something to do with transformers, I then checked with a basic PyTorch model:

import torch
from accelerate.utils import release_memory

def memory_stats():
    return torch.cuda.memory_summary()

class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(100, 200)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(200, 10)
        self.softmax = torch.nn.Softmax(dim=0)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x
    
model = TinyModel().cuda()
batch = torch.rand(64,100).cuda()
_ = model(batch)
model, batch = release_memory(model, batch)
print(
    f"Memory allocated: {torch.cuda.memory_allocated()/1024**2}\nMemory reserved: {torch.cuda.memory_reserved()/1024**2}"
)

If you run this you will find that yet again, we have a similar leftover memory allocation.

So I'm not 100% convinced that this is a problem we can solve.

If you can release those memory allocations then we can work with that solution, but after extensive research I have found it impossible to free all of it entirely after the model has been run on an input. Most likely these are intermediate activations that somehow remain allocated and are never freed.

Note: including inference_mode/no_grad and model.eval() did not change those end allocation results

@muellerzr
Contributor

Official response from the torch team:

The memory used by the CUDA Context itself will still be there. So you won't be able to get the GPU back to 0 I'm afraid.

So that 16/40 will always remain, and there isn't anything else we can do about it.
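
For illustration (a minimal standalone sketch, not from the scripts above): even when the caching allocator reports essentially nothing, nvidia-smi keeps showing a few hundred MiB for the process, because that is the CUDA context itself:

import torch

# A single tiny allocation is enough to initialize the CUDA context.
x = torch.zeros(1, device="cuda")
del x
torch.cuda.empty_cache()

# The allocator's own counters go back to (almost) zero ...
print(torch.cuda.memory_allocated() / 1024**2)  # ~0 MiB
print(torch.cuda.memory_reserved() / 1024**2)   # ~0 MiB
# ... but nvidia-smi still shows a few hundred MiB for this process:
# that is the CUDA context, which cannot be released from Python.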

@hanrui4248
Author

Official response from the torch team:

The memory used by the CUDA Context itself will still be there. So you won't be able to get the GPU back to 0 I'm afraid.

So that 16/40 will always remain, and there isn't anything else we can do about it.

Thank you for your response!
I ran your code and obtained the same result. It's completely acceptable to have 16/40 remaining. However, as I mentioned, I can't release the memory occupied by the model itself after executing trainer.train(). To do so, I would need to customize the source code by deleting a specific line in transformers/trainer.py at line 1988, as shown below:

if self.control.should_epoch_stop or self.control.should_training_stop:
    break

With the inclusion of this line, I'm unable to release the memory occupied by the model:

memory after release:
memory allocated:  223.77490234375
memory reserved:  288.0

After deleting this line I got the expected behavior:

memory after release:
memory allocated:  16.25
memory reserved:  40

@muellerzr Could you please provide any suggestions, or make changes to trainer.py if this is indeed a bug? It would be much appreciated.

@muellerzr
Contributor

You need to fully remove the model off CUDA, yes
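
For example, something along these lines (a sketch of the idea only; moving 8-bit quantized weights back to the CPU may not be fully supported):

from accelerate.utils import release_memory

# Sketch: move the weights off the GPU first, then drop the references and
# let release_memory run gc.collect() + torch.cuda.empty_cache().
model.cpu()
model, trainer = release_memory(model, trainer)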

@hanrui4248
Author

You need to fully remove the model off CUDA, yes

Do you mean using model.cpu() to move the model from CUDA to the CPU? While this did free up more memory, it wasn't sufficient. When I scaled up to the XXL model and then applied model.cpu() along with release_memory, about 8000 MiB of allocated memory still remained.

However, when I delete line 1988 in trainer.py and then call release_memory, only 16.25 MiB of allocated memory remains, no matter the size of the model.

How can I fully remove the model off CUDA? Could you please recheck this? Thank you! @muellerzr

@muellerzr
Contributor

@hanrui4248 I was successful after the following:

...
trainer.train()
del model, trainer
gc.collect()
torch.cuda.empty_cache()

This got me down to only a tiny amount of allocated memory afterwards. My full script:

import gc
import torch

from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

def memory_stats():
    return f"Memory allocated: {torch.cuda.memory_allocated()/1024**2}\nMemory reserved: {torch.cuda.memory_reserved()/1024**2}"

imdb = load_dataset("imdb")
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

tokenized_imdb = imdb.map(preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}


model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)

print('Model memory:',memory_stats())

training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    max_steps=10,
    weight_decay=0.01,
    save_strategy="no",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_imdb["train"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

print('\nMemory stats before release:',memory_stats())

del trainer
del model
gc.collect()
torch.cuda.empty_cache()

print('\nMemory stats after release:',memory_stats())

Print statements:

Memory stats before release: Memory allocated: 786.18212890625
Memory reserved: 6290.0
Memory stats after release: Memory allocated: 17.13671875
Memory reserved: 44.0

@hanrui4248
Author

@hanrui4248 I was successful after the following:

...
trainer.train()
del model, trainer
gc.collect()
torch.cuda.empty_cache()

Thank you!

But it didn't work with my script. Could this be because I used peft and LoRA in it? I think I've tried every possible way to release the memory, but I still can't free up all of the model's memory. After a lot of attempts, here are my conclusions:
1. Only using del and gc.collect() doesn't work.
2. Combining del and gc.collect() with model.cpu() releases more memory, but a significant amount still remains (see the sketch after this list).
3. Deleting line 1988 in trainer.py and then doing the same memory release operations as in 1. completely removes the model from CUDA.
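
The sketch referred to in point 2 is roughly the following (illustrative only; the lines that clear the Trainer's internal handles are an extra idea for completeness, not something my script below does):

# Sketch of the cleanup from point 2, plus (hypothetically) clearing the
# Trainer's internal handles to the optimizer and wrapped model as well:
model.cpu()                   # may not fully apply to 8-bit quantized weights
trainer.optimizer = None      # optimizer state also lives on the GPU
trainer.lr_scheduler = None
trainer.model_wrapped = None
trainer.model = None

del trainer
del model
gc.collect()
torch.cuda.empty_cache()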

Here is my script. Could you confirm my conclusions by executing it, @muellerzr? Thank you!

import torch
import gc
from functools import partial
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset
from datasets.formatting.formatting import LazyBatch
from transformers import (
    AutoModelForSeq2SeqLM, 
    AutoTokenizer, 
    DataCollatorForSeq2Seq, 
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer, 
    PreTrainedTokenizer, 
    PreTrainedTokenizerFast
)
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model
from metrics import compute_metrics

def memory_stats():
    print("memory allocated: ", torch.cuda.memory_allocated()/1024**2)
    print("memory reserved: ", torch.cuda.memory_reserved()/1024**2)

def get_processed_ordalie_dataset(
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    max_length: int,
    seed: int,
) -> DatasetDict | Dataset | IterableDatasetDict | IterableDataset:
    # load dataset
    dataset = load_dataset("OrdalieTech/baby-ordalie")
    # since this dataset doesn't have validation split, create it manually.
    test_val_split = dataset["train"].train_test_split(test_size=len(dataset["test"]), seed=seed)
    dataset["train"] = test_val_split["train"]
    dataset["validation"] = test_val_split["test"]

    # Process data
    def process_data_to_model_inputs(examples: LazyBatch) -> LazyBatch:
        model_inputs = tokenizer(
            examples["input"],
            max_length=max_length,
            truncation=True,
        )
        labels = tokenizer(examples["output"])
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    tokenized_datasets = dataset.map(process_data_to_model_inputs, batched=True)
    tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    # Remove unnecessary columns
    tokenized_datasets = tokenized_datasets.remove_columns(dataset["train"].column_names)

    return tokenized_datasets


model_name = "google/flan-t5-small"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, load_in_8bit=True)

print("model's memory:")
memory_stats()

tokenizer = AutoTokenizer.from_pretrained(model_name)
lora_config = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q", "v"], lora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM"
)

model = prepare_model_for_int8_training(model)

model = get_peft_model(model, lora_config)

args = Seq2SeqTrainingArguments(
    "temp",  
    evaluation_strategy="epoch",
    learning_rate=5.6e-5,
    gradient_accumulation_steps=12,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=1,
    save_strategy = "no",
    predict_with_generate=True, 
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

dataset = get_processed_ordalie_dataset(
            tokenizer,
            512,
            42,
        )

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=partial(compute_metrics, tokenizer=tokenizer),
)


trainer.train()

print("memory before release:")
memory_stats()

del trainer
del model
gc.collect()
torch.cuda.empty_cache()
print("memory after release:")

memory_stats()

@muellerzr
Contributor

In that case the issue stems from peft, so I'd recommend migrating/opening this issue over there, as I'm not sure what it could be :)
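
As a general workaround in the meantime (a sketch, not specific to transformers or peft): if you need to train several models back to back, you can run each training run in its own process, so that all of its GPU memory, including the CUDA context, is returned to the driver when the process exits:

import multiprocessing as mp

def train_one_run(run_id: int):
    # Hypothetical stub: build the model, Trainer, etc. inside the child
    # process and call trainer.train() here. Everything this process allocates
    # on the GPU, including the CUDA context, is freed when it exits.
    ...

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # "spawn" is the safe start method with CUDA
    for run_id in range(3):
        p = ctx.Process(target=train_one_run, args=(run_id,))
        p.start()
        p.join()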


github-actions bot commented Nov 3, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
