Why does bf16 of LightningLite use GPU Memory much more than pytorch_lightning.Trainer #10371

Closed
gitabtion opened this issue Nov 5, 2021 · 7 comments · Fixed by #10429

Comments

gitabtion commented Nov 5, 2021

🐛 Bug

I tested two ways to train my model, LightningLite and pytorch_lightning.Trainer, on 2 A100 machines with the same configuration. The GPU memory used is very different:

  • LightningLite(bf16): about 24GB
  • LightningLite(fp16): about 27GB
  • LightningLite(fp32): about 27GB
  • LightningLite(deepspeed, fp16): about 14GB
  • pytorch_lightning.Trainer(bf16): about 17GB

To Reproduce

import logging
import os
import random
import time
import traceback
from logging import getLogger

from tqdm import tqdm
from pytorch_lightning import seed_everything
from pytorch_lightning.lite import LightningLite

# project-specific helpers (args_parse, BertsForPretraining, make_optimizer,
# build_lr_scheduler, make_loaders, get_abs_path) come from my own codebase


class Lite(LightningLite):
    def run(self, model, optimizer, lr_scheduler, logger, num_files=20, global_step=0):
        self.setup(model, optimizer)
        while True:
            fps = self.get_fps(model.cfg)
            for i in range(len(fps) // num_files):
                train_loader = self.get_data(model.cfg, i, fps, num_files)
                train_loader = self.setup_dataloaders(train_loader)
                model.train()
                metric = {'loss': 0.0, 'mlm_acc': 0.0, 'sop_acc': 0.0, 'log_step': 0.0}
                start_time = time.time()
                for j, batch in enumerate(tqdm(train_loader)):
                    outputs = model(**batch, output_hidden_states=True)
                    labels = batch['labels']

                    loss = outputs.loss
                    optimizer.zero_grad()
                    self.backward(loss)
                    optimizer.step()
                    lr_scheduler.step()
                    global_step += 1

                    if global_step >= model.cfg.SOLVER.MAX_ITERS:
                        break
                if global_step >= model.cfg.SOLVER.MAX_ITERS:
                    break
            if global_step >= model.cfg.SOLVER.MAX_ITERS:
                break

    def get_fps(self, cfg):
        """
        get files path for train.
        """
        if ',' in cfg.DATASETS.TRAIN:
            dirs = cfg.DATASETS.TRAIN.split(',')
            dirs = [d.strip() for d in dirs]
        else:
            dirs = [cfg.DATASETS.TRAIN]
        fps = []
        for d in dirs:
            _fns = os.listdir(get_abs_path(d))
            _fns = [fn for fn in _fns if fn[:4] in ['part', 'ppar']]
            _fps = [get_abs_path(d, fn) for fn in _fns]
            fps += _fps
        random.shuffle(fps)
        return fps

    def get_data(self, cfg, idx, fps, num=20):
        parts = get_parts(fps, idx, num)
        train_loader, _, _ = make_loaders(cfg, fps=parts, num_proc=8)
        return train_loader


def get_parts(fps, i, n):
    if i > len(fps) // n:
        i = 0
    elif i == len(fps) // n and len(fps) % n == 0:
        i = 0
    return fps[i * n:(i + 1) * n]


def main():
    cfg = args_parse('deberta/deberta_v2.yml')
    model = BertsForPretraining(cfg)
    optimizer = make_optimizer(cfg, model)
    lr_scheduler = build_lr_scheduler(cfg, optimizer, offset=0)['scheduler']
    logger = getLogger(cfg.MODEL.NAME)
    logger.setLevel(logging.DEBUG)
    gpus = cfg.MODEL.GPUS if cfg.MODEL.GPUS != 0 else cfg.MODEL.GPU_IDS
    lite = Lite(
        strategy='ddp',
        gpus=gpus,
        accelerator='gpu',
        precision='bf16',
        num_nodes=2
    )
    try:
        lite.run(model, optimizer, lr_scheduler, logger, 20, 2650000)
    except Exception:
        logger.error(traceback.format_exc())


if __name__ == '__main__':
    seed_everything()
    main()

Expected behavior

Environment

  • PyTorch Lightning Version (e.g., 1.3.0): 1.5.0
  • PyTorch Version (e.g., 1.8): 1.10+cu113
  • Python version: 3.8.12
  • OS (e.g., Linux): ubuntu20.04
  • CUDA/cuDNN version: 11.4
  • Device: 2 machines, 8*A100 each.
  • GPU models and configuration:
  • How you installed PyTorch (conda, pip, source):
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Additional context

Here is the environment Dockerfile:

FROM nvcr.io/nvidia/pytorch:21.09-py3
RUN apt-get update && apt-get -y install openssh-server htop tmux \
    && mkdir -p /var/run/sshd
RUN pip install -U pytorch-lightning transformers yacs ujson pkuseg pypinyin deepspeed datasets tqdm wandb \
    && pip uninstall -y torch \
    && pip install torch==1.10.0+cu113 torchvision==0.11.1+cu113 torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html \
    && pip uninstall -y torchtext
@gitabtion gitabtion added the bug and help wanted labels Nov 5, 2021
@awaelchli awaelchli changed the title from "Why dose bf16 of LightningLite use GPU Memory much more than pytorch_lightning.Trainer" to "Why does bf16 of LightningLite use GPU Memory much more than pytorch_lightning.Trainer" Nov 5, 2021
@awaelchli awaelchli self-assigned this Nov 5, 2021
@tchaton tchaton added the fabric label Nov 5, 2021

tchaton commented Nov 5, 2021

Dear @gitabtion,

Thanks for looking into LightningLite. We will look into this promptly.
Would you mind sharing your thoughts on Lite? Any feedback?

Best,
T.C


awaelchli commented Nov 5, 2021

It's possible that part of the memory increase comes from the fact that we convert the output of the model back to float32. It happens here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/348fc4b49f0e74acb9785b5179abb8fd01beb45a/pytorch_lightning/lite/wrappers.py#L99-L103

The reason we do this is that the inputs to the model are also float32, and we don't know what the user does with the outputs (e.g., compute loss terms). In Lightning, all of this would happen inside the training_step method, so such a conversion is not necessary there.
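For context, here is a simplified, hypothetical sketch of what that wrapper roughly does (not the exact code from the link above; to_dtype stands for the dtype derived from the configured precision):

import torch
from torch import Tensor, nn
from pytorch_lightning.utilities.apply_func import apply_to_collection


class _WrappedModuleSketch(nn.Module):
    """Hypothetical stand-in for the Lite module wrapper linked above."""

    def __init__(self, module: nn.Module, to_dtype: torch.dtype):
        super().__init__()
        self.module = module
        self.to_dtype = to_dtype  # e.g. torch.bfloat16 for precision="bf16"

    def forward(self, *args, **kwargs):
        # cast all tensor inputs to the configured precision dtype ...
        args, kwargs = apply_to_collection(
            [args, kwargs], dtype=Tensor, function=lambda t: t.to(self.to_dtype)
        )
        output = self.module(*args, **kwargs)
        # ... and cast all tensor outputs back to the default dtype (float32),
        # which is where the extra memory for large outputs can come from
        return apply_to_collection(
            output, dtype=Tensor, function=lambda t: t.to(torch.get_default_dtype())
        )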


gitabtion commented Nov 8, 2021

@awaelchli Thanks for your reply, I have tried to print the outputs of the Trainer, it is float32 too. And bf16 of GPU memory used is close to fp32, only about 3GB. I have updated the GPU Memory used of some precison, hope that is helpful for you.

@gitabtion
Author

@tchaton
LightningLite is a great feature for me. I use many tricks in training, and LightningLite is sometimes more flexible and makes it easier to track down problems than pytorch_lightning.Trainer.

@awaelchli
Contributor

@gitabtion so far I haven't been able to detect a difference between Lightning and Lite in training with bf16 precision on our basic examples.

Lite:

python pl_examples/basic_examples/mnist_examples/image_classifier_2_lite.py

with Lite(strategy="ddp", gpus=2, accelerator="gpu", precision="bf16")

(2462 MB)

vs.

python pl_examples/basic_examples/mnist_examples/image_classifier_4_lightning_module.py  --trainer.gpus 2 --trainer.accelerator gpu --trainer.strategy ddp --trainer.precision bf16

(2464 MB)

on 2x A100 GPUs.

Your issue is not strictly related to multi-node training, right?


gitabtion commented Nov 9, 2021

@awaelchli Yes, the difference can be detected on a single node. You can use the following script to reproduce this bug.

import os
import torch
import transformers as tfs
from pytorch_lightning.lite import LightningLite
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from pytorch_lightning import seed_everything


class Lite(LightningLite):
    def run(self, model, optimizer):
        self.setup(model, optimizer)

        for i in range(10):
            train_loader = self.get_data()
            train_loader = self.setup_dataloaders(train_loader)
            model.train()
            for j, batch in enumerate(tqdm(train_loader)):
                input_ids, labels = batch
                outputs = model(input_ids=input_ids, labels=labels, output_hidden_states=True)

                loss = outputs.loss
                optimizer.zero_grad()
                self.backward(loss)
                optimizer.step()

    def get_data(self):
        input_ids = torch.randint(21128, (4096, 512))
        labels = torch.randint(21128, (4096, 512))
        dataset = TensorDataset(input_ids, labels)
        train_loader = DataLoader(dataset, batch_size=16)
        return train_loader


def main():
    model = tfs.BertForMaskedLM.from_pretrained('bert-base-chinese')
    optimizer = torch.optim.Adam(model.parameters())
    lite = Lite(
        strategy='deepspeed',
        gpus=4,
        accelerator='gpu',
        precision=16,
        num_nodes=int(os.environ.get('NUM_NODES', '1'))
    )
    lite.run(model, optimizer)


if __name__ == '__main__':
    seed_everything()
    main()

DeepSpeed FP16: about 11GB
FP16, FP32, BF16: about 19GB

Maybe the line self.setup(model, optimizer) should be model, optimizer = self.setup(model, optimizer), but I get the following error with the latter:

Traceback (most recent call last):
  File "lite_toy_test.py", line 52, in <module>
    main()
  File "lite_toy_test.py", line 47, in main
    lite.run(model, optimizer)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/lite/lite.py", line 406, in _run_impl
    return run_method(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/lite/lite.py", line 410, in _run_with_sharded_context
    return run_method(*args, **kwargs)
  File "lite_toy_test.py", line 22, in run
    outputs = model(input_ids=input_ids, labels=labels, output_hidden_states=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/lite/wrappers.py", line 100, in forward
    output = self.module(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 1328, in forward
    outputs = self.bert(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 983, in forward
    embedding_output = self.embeddings(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/bert/modeling_bert.py", line 215, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/sparse.py", line 158, in forward
    return F.embedding(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 2044, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.HalfTensor instead (while checking arguments for embedding)

I think the DDP training type plugin may be forcibly casting tensors to float, but some of the model's inputs (the input_ids indices) are torch.long.

@awaelchli
Contributor

@gitabtion Yes, it definitely should be

model, optimizer = self.setup(model, optimizer)

Apologies for not catching that immediately.
The model must be wrapped for precision to work properly, which should explain your memory problem.
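For reference, your reproduction script's run() with that one change would look like this (imports and get_data() as in your script above):

class Lite(LightningLite):
    def run(self, model, optimizer):
        # keep the wrapped objects returned by setup(); without the reassignment,
        # the raw float32 model and optimizer keep being used for training
        model, optimizer = self.setup(model, optimizer)

        for i in range(10):
            train_loader = self.setup_dataloaders(self.get_data())
            model.train()
            for batch in tqdm(train_loader):
                input_ids, labels = batch
                outputs = model(input_ids=input_ids, labels=labels, output_hidden_states=True)
                optimizer.zero_grad()
                self.backward(outputs.loss)
                optimizer.step()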

But the error you get is a bug. We are currently converting all inputs unconditionally to the given precision type, but we should only do that for floating-point tensors, not for integer types like Long and Int.
Here is where it happens:

https://github.com/PyTorchLightning/pytorch-lightning/blob/0ed5e3dc8abcec40aacd64cc9175590bb1409759/pytorch_lightning/lite/wrappers.py#L97-L104

I will create a fix for this.
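A rough sketch of the direction such a fix could take (hypothetical helper, not necessarily the actual patch that will land): cast only floating-point tensors and leave integer index tensors untouched.

import torch
from torch import Tensor
from pytorch_lightning.utilities.apply_func import apply_to_collection


def _cast_if_float(t: Tensor, dst_dtype: torch.dtype) -> Tensor:
    # convert only floating-point tensors; Long/Int index tensors keep their dtype
    return t.to(dst_dtype) if torch.is_floating_point(t) else t


# inside the wrapper's forward, instead of casting every tensor:
# args, kwargs = apply_to_collection(
#     [args, kwargs], dtype=Tensor, function=_cast_if_float, dst_dtype=to_dtype
# )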

@awaelchli awaelchli modified the milestones: 1.6.x, 1.5.x Nov 9, 2021
@awaelchli awaelchli added the priority: 0 label Nov 9, 2021