Why does bf16 in LightningLite use much more GPU memory than pytorch_lightning.Trainer? #10371
Comments
Dear @gitabtion, thanks for looking into LightningLite. We will look into this promptly. Best,
It's possible that (part of) the memory increase comes from the fact that we convert the output of the model back to float32. It happens here: The reason we do this is that the inputs to the model are also float32 and we don't know what the user does with the outputs (e.g., compute loss terms). In Lightning, this would all be under the
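For readers following along, here is a minimal sketch of the behaviour described above (illustrative only, not Lightning's actual source): run the forward pass in reduced precision, then hand float32 outputs back to the user. Keeping that extra full-precision copy of the outputs around is one place additional memory can come from.

```python
import torch
from torch import nn


class Float32OutputWrapper(nn.Module):
    """Minimal sketch of the behaviour described above (not Lightning's actual code):
    run the wrapped module's forward under bf16 autocast, then cast floating-point
    outputs back to float32 so downstream user code (e.g. extra loss terms) still
    sees full-precision tensors."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = self.module(*args, **kwargs)
        # The float32 copy lives alongside the bf16 activations, which can
        # increase peak GPU memory compared to staying in bf16 end to end.
        if isinstance(outputs, torch.Tensor) and outputs.is_floating_point():
            outputs = outputs.float()
        return outputs
```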
@awaelchli Thanks for your reply. I printed the outputs of the Trainer and they are float32 too. With the Trainer, the GPU memory used with bf16 is close to fp32 (only about a 3GB difference). I have updated the GPU memory figures for several precisions above; I hope that is helpful.
@tchaton
@gitabtion So far I haven't been able to detect a difference between Lightning and Lite in training with bf16 precision on our basic examples: Lite (2462 MB) vs. Trainer (2464 MB) on 2x A100 GPUs. Your issue is not strictly related to multi-node training, right?
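For anyone reproducing these measurements, one simple way to get comparable numbers (not necessarily how the figures above were obtained) is to query PyTorch's allocator statistics after a few training steps:

```python
import torch


def report_gpu_memory(tag: str) -> None:
    # Call after a few training iterations on each rank to compare
    # Lite vs. Trainer runs with the same batch size and precision setting.
    allocated = torch.cuda.max_memory_allocated() / 2**20
    reserved = torch.cuda.max_memory_reserved() / 2**20
    print(f"[{tag}] peak allocated: {allocated:.0f} MB, peak reserved: {reserved:.0f} MB")
```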
@awaelchli Yes, the difference can be detected on a single node; you can use the following script to reproduce this bug.

```python
import os
import torch
import transformers as tfs
from pytorch_lightning.lite import LightningLite
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from pytorch_lightning import seed_everything


class Lite(LightningLite):
    def run(self, model, optimizer):
        self.setup(model, optimizer)
        for i in range(10):
            train_loader = self.get_data()
            train_loader = self.setup_dataloaders(train_loader)
            model.train()
            for j, batch in enumerate(tqdm(train_loader)):
                input_ids, labels = batch
                outputs = model(input_ids=input_ids, labels=labels, output_hidden_states=True)
                loss = outputs.loss
                optimizer.zero_grad()
                self.backward(loss)
                optimizer.step()

    def get_data(self):
        input_ids = torch.randint(21128, (4096, 512))
        labels = torch.randint(21128, (4096, 512))
        dataset = TensorDataset(input_ids, labels)
        train_loader = DataLoader(dataset, batch_size=16)
        return train_loader


def main():
    model = tfs.BertForMaskedLM.from_pretrained('bert-base-chinese')
    optimizer = torch.optim.Adam(model.parameters())
    lite = Lite(
        strategy='deepspeed',
        gpus=4,
        accelerator='gpu',
        precision=16,
        num_nodes=int(os.environ.get('NUM_NODES', '1'))
    )
    lite.run(model, optimizer)


if __name__ == '__main__':
    seed_everything()
    main()
```

DeepSpeed FP16: about 11GB. Maybe it is this line: I think the DDP training type plugin may be forcibly setting everything to float, but some of the model's inputs are torch.long.
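To illustrate why the hypothesis above matters for this script: BERT's first layer is an nn.Embedding, and embedding lookups require integer indices, so an input_ids tensor cast to a floating-point dtype fails immediately. A standalone demonstration (illustrative, not taken from the issue; shapes match the reproduction script above):

```python
import torch
from torch import nn

# Same vocabulary size and sequence shape as the reproduction script above.
embedding = nn.Embedding(num_embeddings=21128, embedding_dim=768)
input_ids = torch.randint(21128, (16, 512))  # torch.long, as a tokenizer would produce

print(embedding(input_ids).shape)  # works: torch.Size([16, 512, 768])

try:
    # What an unconditional cast of all inputs to the precision dtype would do:
    embedding(input_ids.to(torch.bfloat16))
except RuntimeError as err:
    print(f"Embedding lookup with non-integer indices fails: {err}")
```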
@gitabtion Yes, definitely it should be
Apologies for not catching that immediately. But the error you get is a bug: we currently convert all inputs unconditionally to the given precision type, but we should only do that for floating-point tensors and not for types such as Long and Int. I will create a fix for this.
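A sketch of the kind of fix described above (an assumption about the shape of the change, not the actual patch): convert only floating-point tensors to the target precision and leave integer tensors such as token ids and labels untouched.

```python
import torch


def convert_fp_tensors_only(obj, dtype: torch.dtype):
    """Recursively cast only floating-point tensors to `dtype`; leave Long/Int/Bool
    tensors (token ids, labels, masks) and non-tensor values unchanged."""
    if isinstance(obj, torch.Tensor):
        return obj.to(dtype) if obj.is_floating_point() else obj
    if isinstance(obj, (list, tuple)):
        return type(obj)(convert_fp_tensors_only(x, dtype) for x in obj)
    if isinstance(obj, dict):
        return {k: convert_fp_tensors_only(v, dtype) for k, v in obj.items()}
    return obj


# Example: integer token ids pass through untouched, the float mask is cast.
batch = {
    "input_ids": torch.randint(21128, (16, 512)),
    "attention_mask": torch.ones(16, 512),
}
converted = convert_fp_tensors_only(batch, torch.bfloat16)
print(converted["input_ids"].dtype, converted["attention_mask"].dtype)
# -> torch.int64 torch.bfloat16
```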
🐛 Bug
I tested two ways to train my model, LightningLite and pytorch_lightning.Trainer, on 2 A100 machines with the same configuration. The GPU memory used is very different.
To Reproduce
Expected behavior
Environment
- How you installed PyTorch (conda, pip, source):
- If compiling from source, the output of torch.__config__.show():

Additional context
Here is the environment Dockerfile.