Trainer does not release all CUDA memory #567

Open
lopozz opened this issue Oct 22, 2024 · 8 comments

@lopozz

lopozz commented Oct 22, 2024

I am currently trying to run a k-fold training loop. At the end of each iteration I free memory using gc.collect() and torch.cuda.empty_cache(), but this does not seem to release it completely. Here is the code:

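import gc
import os

import numpy as np
import setfit
import torch
from datasets import DatasetDict, load_from_disk
from sklearn.model_selection import StratifiedKFold

# cfg, model_init_fn and custom_metrics_fn are defined elsewhere in the project.
dataset = load_from_disk(os.path.join("data", cfg.kfold_kwargs.kfold_dataset_name))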

folds = StratifiedKFold(n_splits=cfg.kfold_kwargs.n_splits)
splits = list(folds.split(np.zeros(dataset.num_rows), dataset[cfg.label_column]))

args = setfit.TrainingArguments(**cfg.train_kwargs)

all_metrics = []

for train_idxs, test_idxs in splits:
    fold_dataset = DatasetDict(
        {
            "train": dataset.select(train_idxs),
            "test": dataset.select(test_idxs),
        }
    )

    trainer = setfit.Trainer(
        model_init=model_init_fn(cfg.model_kwargs),
        args=args,
        train_dataset=fold_dataset["train"],
        eval_dataset=fold_dataset["test"],
        metric=custom_metrics_fn(fold_dataset, cfg.label_column),
        column_mapping={"text": "text", cfg.label_column: "label"},
    )

    trainer.train()

    # metrics = trainer.evaluate(fold_dataset["test"])
    # all_metrics.append(metrics)
    # print(metrics)

    def memory_stats():
        print(f"Memory allocated: {torch.cuda.memory_allocated()/1024**2}\nMemory reserved: {torch.cuda.memory_reserved()/1024**2}")

    memory_stats()

    del trainer.model.model_head, trainer.model.model_body
    del fold_dataset, trainer
    # torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()
    
    memory_stats()
    print('\n')

And here is my setup:

accelerate==1.0.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
alembic==1.13.3
antlr4-python3-runtime==4.9.3
async-timeout==4.0.3
attrs==24.2.0
certifi==2024.8.30
charset-normalizer==3.4.0
colorlog==6.8.2
datasets==3.0.1
dill==0.3.8
evaluate==0.4.3
filelock==3.16.1
frozenlist==1.4.1
fsspec==2024.6.1
greenlet==3.1.1
huggingface-hub==0.26.1
hydra-core==1.3.2
idna==3.10
Jinja2==3.1.4
joblib==1.4.2
Mako==1.3.6
MarkupSafe==3.0.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
numpy==2.1.2
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
omegaconf==2.3.0
optuna==4.0.0
packaging==24.1
pandas==2.2.3
pillow==11.0.0
propcache==0.2.0
psutil==6.1.0
pyarrow==17.0.0
python-dateutil==2.9.0.post0
pytz==2024.2
PyYAML==6.0.2
regex==2024.9.11
requests==2.32.3
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
sentence-transformers==3.2.1
setfit==1.1.0
six==1.16.0
SQLAlchemy==2.0.36
sympy==1.13.1
threadpoolctl==3.5.0
tokenizers==0.20.1
torch==2.5.0
tqdm==4.66.5
transformers==4.45.2
triton==3.1.0
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
xxhash==3.5.0
yarl==1.16.0

Here is the memory printed at each iteration (in MiB, before and after the cleanup):

Memory allocated: 279.685546875
Memory reserved: 596.0
Memory allocated: 279.685546875
Memory reserved: 342.0


Memory allocated: 411.4501953125
Memory reserved: 738.0
Memory allocated: 411.4501953125
Memory reserved: 484.0


Memory allocated: 542.93359375
Memory reserved: 876.0
Memory allocated: 542.93359375
Memory reserved: 626.0


Memory allocated: 674.4638671875
Memory reserved: 1052.0
Memory allocated: 674.4638671875
Memory reserved: 780.0
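One way to read these numbers is to compute the differences directly. The short sketch below only re-uses the values printed above (in MiB, since `memory_stats()` divides bytes by 1024**2). Within each fold, allocated memory is identical before and after the cleanup, while reserved memory drops by roughly 250 to 270 MiB; across folds, allocated memory grows by about 131.5 MiB each time. This fits how PyTorch reports memory: `torch.cuda.empty_cache()` only returns cached blocks that no tensor is using, and `memory_allocated()` counts memory held by live tensors. A steady increase in allocated memory therefore suggests, without proving, that tensors from every previous fold are still referenced somewhere.

```python
# Numbers copied from the output above, in MiB (memory_stats() divides bytes by 1024**2).
# Each tuple: (allocated_before, allocated_after, reserved_before, reserved_after).
folds = [
    (279.685546875, 279.685546875, 596.0, 342.0),
    (411.4501953125, 411.4501953125, 738.0, 484.0),
    (542.93359375, 542.93359375, 876.0, 626.0),
    (674.4638671875, 674.4638671875, 1052.0, 780.0),
]

# How much the cleanup released within each fold.
freed_allocated = [before - after for before, after, _, _ in folds]
freed_reserved = [before - after for _, _, before, after in folds]

# How much still-allocated memory is added from one fold to the next.
allocated_after = [after for _, after, _, _ in folds]
growth = [b - a for a, b in zip(allocated_after, allocated_after[1:])]

print("freed allocated:", freed_allocated)  # [0.0, 0.0, 0.0, 0.0]
print("freed reserved:", freed_reserved)    # [254.0, 254.0, 250.0, 272.0]
print("growth per fold:", [round(g, 2) for g in growth])  # [131.76, 131.48, 131.53]
```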

Could anyone suggest the reason?

@muhammadravi251001

I am currently getting the same issue.

@muhammadravi251001

Have you found a solution yet, @lopozz?

@chschroeder

Hi, this is likely a problem in sentence-transformers. I started collecting information here and linked to it from a few other issues; this might give you some pointers: UKPLab/sentence-transformers#1793.

As far as I know, this is still unsolved.

@muhammadravi251001

muhammadravi251001 commented Nov 24, 2024

Thank you very much for your response, @chschroeder! Since this issue is still unsolved, I tried exploring alternative approaches to address it.

The main reason I encountered this error is that I need to initialize the SetFitModel in every iteration of the loop (this is essential for my research, as I aim to train the SetFitModel in isolation using only one language at a time for each iteration).

To work around this issue, I modified my approach by:

  1. Initializing the SetFitModel outside the loop.
  2. Adding a reset_parameters(model) function within the loop.

Here’s the snippet for the reset_parameters function:

import torch

def reset_parameters(model):
  def reset_model_body(model_body):
    if model_body is not None:
      def init_weights(module):
        if hasattr(module, 'reset_parameters'):
          # Covers Linear, Embedding, LayerNorm, etc.
          module.reset_parameters()
          print("Model body parameters reset using `reset_parameters` function.")
        elif isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
          # Fallback; Embedding has no bias attribute, hence getattr.
          torch.nn.init.xavier_uniform_(module.weight)
          if getattr(module, 'bias', None) is not None:
            torch.nn.init.zeros_(module.bias)
            print("Model body parameters reset using Xavier initialization and zero bias.")

      # apply() must be called outside init_weights, otherwise the body is never reset.
      model_body.apply(init_weights)
      print("Model body parameters have been successfully reset.")

  if model.model_body is not None:
    reset_model_body(model.model_body)

  # Only the differentiable SetFitHead has _init_weight and .linear; the default
  # sklearn head has neither (it is assumed to be refit from scratch on training).
  if hasattr(model.model_head, '_init_weight'):
    model.model_head.apply(model.model_head._init_weight)
    print("Model head parameters reset using `_init_weight` function.")

  linear = getattr(model.model_head, 'linear', None)
  if hasattr(linear, 'reset_parameters'):
    linear.reset_parameters()
    print("Model head linear parameters reset using `reset_parameters` function.")

# Use this inside the loop
reset_parameters(model)

Note that this approach relies on a "private" SetFit method, model.model_head._init_weight; its implementation is on this line in the SetFit repository.

Using this approach, I managed to avoid the CUDA memory issue for now. I hope this helps anyone facing a similar challenge!

@chschroeder

Interesting, thanks for the feedback. I was wondering: why would resetting the weights free memory? Does this mean you suspect the gradients are causing the memory increase?

@muhammadravi251001

Yup, I believe gradients could be one of the possible causes, but I don’t rule out the possibility of SetFitModel artifacts also contributing to the issue. Actually, the main purpose of my code above is to avoid creating multiple SetFitModel artifacts (one for each iteration), keeping it to just a single object while still ensuring a fresh start (zero knowledge) for each language-specific training session.
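The reasoning in this exchange can be illustrated without a GPU. Suppose, as an assumption not established in this thread, that some long-lived object keeps a strong reference to every model it sees. Then creating a fresh model per iteration leaves one model behind per iteration, while reusing and resetting a single model keeps the count at one. The sketch below uses a hypothetical `FakeModel` and a hypothetical `holder` list as stand-ins. It counts the surviving models with `weakref.WeakSet`, which does not keep them alive itself.

```python
import gc
import weakref


class FakeModel:
    """Hypothetical stand-in for a SetFitModel; the real one would own GPU tensors."""

    def reset(self):
        # Stand-in for re-initialising the weights in place (e.g. reset_parameters).
        pass


def models_left_alive(n_folds, reuse, holder):
    """Run n_folds iterations and return how many distinct models survive.

    `holder` is a hypothetical long-lived list that keeps a strong reference to
    every model it sees; pass None to simulate "nothing else holds a reference".
    """
    alive = weakref.WeakSet()  # tracks models without keeping them alive
    model = FakeModel() if reuse else None
    for _ in range(n_folds):
        if reuse:
            model.reset()
        else:
            model = FakeModel()  # the previous model is freed unless something holds it
        if holder is not None:
            holder.append(model)
        alive.add(model)
    del model
    gc.collect()
    return len(alive)


print(models_left_alive(5, reuse=False, holder=[]))    # 5: one leaked model per fold
print(models_left_alive(5, reuse=True, holder=[]))     # 1: the single reused model
print(models_left_alive(5, reuse=False, holder=None))  # 0: everything was freed
```

This only illustrates the reference-counting argument. It does not identify which object, if any, holds the references in SetFit, sentence-transformers or transformers, and it does not rule gradients in or out.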

@munterkalmsteiner

I have the same issue: I get an OOM error when using nested k-fold cross-validation. An outer KFold loop tests the optimal parameters identified in an inner KFold loop, and each Optuna trial runs the inner loop.

The workaround by @muhammadravi251001 reduces the memory consumption in the inner loop, i.e. the memory consumption for each inner fold stays the same. But with each trial, the memory consumption increases and ends in an OOM error.

Without any cross-validation loops, I can run 100s of trials with hyperparameter_search without OOM error (note: without any particular workaround!).

It would be great to find a solution for this, given that cross-validation is somewhat necessary to express how confident one can be about a model trained on few examples.
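In a nested setup the number of training runs multiplies across the loops. The sketch below counts how many models get built, assuming that every inner fold of every trial trains a new model and that each outer fold refits once with the selected parameters. The refit step is an assumption about the setup described above. `kfold_indices` and `count_model_inits` are hypothetical helpers written for illustration; in practice sklearn's `KFold` or `StratifiedKFold` would generate the splits.

```python
def kfold_indices(n_samples, k):
    """Yield (train, test) index lists for plain, unshuffled k-fold CV."""
    # The first n_samples % k folds get one extra sample.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        yield train, test
        start += size


def count_model_inits(n_samples, outer_k, inner_k, n_trials):
    """Count how many models a nested CV + hyperparameter search would build."""
    inits = 0
    for outer_train, _outer_test in kfold_indices(n_samples, outer_k):
        for _trial in range(n_trials):  # e.g. one Optuna trial
            for _inner in kfold_indices(len(outer_train), inner_k):
                inits += 1  # a fresh model/trainer per inner fold
        inits += 1  # refit with the best parameters, scored on the outer test fold
    return inits


print(count_model_inits(n_samples=100, outer_k=5, inner_k=5, n_trials=20))  # 505
```

With 5 outer folds, 5 inner folds and 20 trials, that is 5 × (20 × 5 + 1) = 505 models. Even a small amount of memory retained per model could therefore accumulate into an OOM.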

@munterkalmsteiner

Here is a minimal example, straight from the setfit tutorial, to reproduce this memory leak.

from setfit import SetFitModel, Trainer, TrainingArguments, sample_dataset
from datasets import load_dataset
import torch
import gc

def memory_stats():
    return f"\nMemory allocated: {torch.cuda.memory_allocated()/1024**2}\nMemory reserved: {torch.cuda.memory_reserved()/1024**2}"

# Initializing a new SetFit model
model = SetFitModel.from_pretrained("BAAI/bge-small-en-v1.5", labels=["negative", "positive"])

# Preparing the dataset
dataset = load_dataset("SetFit/sst2")
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=8)
test_dataset = dataset["test"]

# Preparing the training arguments
args = TrainingArguments(
    batch_size=32,
    num_epochs=10,
)

# Preparing the trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
)
trainer.train()

# Evaluating
metrics = trainer.evaluate(test_dataset)
print(metrics)
# => {'accuracy': 0.8511806699615596}

print('\nMemory stats before release:',memory_stats())

del trainer
del model
gc.collect()
torch.cuda.empty_cache()

print('\nMemory stats after release:',memory_stats())

Output of the print statements:

Memory stats before release: 
Memory allocated: 398.34228515625
Memory reserved: 1478.0

Memory stats after release: 
Memory allocated: 398.34228515625
Memory reserved: 510.0
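`memory_allocated()` reports the same 398 MiB before and after `del trainer` and `del model`, so something else apparently still references the model's tensors. A generic way to confirm this is to take a `weakref.ref` to the model before deleting it and check afterwards whether it was actually freed. If it was not, `gc.get_referrers()` can list which objects still point at it. The sketch below shows the pattern with hypothetical stand-in classes, and `leaked` is a hypothetical placeholder for whatever holds on to the model. In the real script one would presumably use `weakref.ref(trainer.model)`, assuming the model object supports weak references, as ordinary Python classes do.

```python
import gc
import weakref


class Model:
    """Hypothetical stand-in for SetFitModel."""


class Trainer:
    """Hypothetical stand-in for setfit.Trainer, which keeps a reference to the model."""

    def __init__(self, model):
        self.model = model


def is_collected(ref):
    """Run the garbage collector, then report whether the referent was freed."""
    gc.collect()
    return ref() is None


model = Model()
trainer = Trainer(model)
model_ref = weakref.ref(model)  # a weak reference does not keep the model alive

leaked = [trainer]  # hypothetical hidden reference, e.g. a cache or callback list
del trainer, model
freed_before = is_collected(model_ref)
print(freed_before)  # False: the model is still reachable through `leaked`

if not freed_before:
    # List the types of objects that still point at the model (output varies by interpreter).
    print([type(r).__name__ for r in gc.get_referrers(model_ref())])

leaked.clear()
freed_after = is_collected(model_ref)
print(freed_after)  # True: the last strong reference is gone
```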

There is a "fix", reported in this issue, involving commenting out this if statement in trainer.py of transformers.

Output of the print statements after the hacky fix:

Memory stats before release: 
Memory allocated: 398.34228515625
Memory reserved: 1478.0

Memory stats after release: 
Memory allocated: 16.25
Memory reserved: 40.0

But as @muellerzr has shown, the issue does not originate in transformers but in the libraries built on top of it. @tomaarsen, any idea how this can be resolved?
