RFC: Integrating bitsandbytes 8-bit optimizer / adding Embedding Norm #14819

Closed
stas00 opened this issue Dec 17, 2021 · 13 comments · Fixed by #15622
Labels
WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress


@stas00
Contributor

stas00 commented Dec 17, 2021

🚀 Feature request

  1. BNB AdamW Optimizer: https://github.com/facebookresearch/bitsandbytes, created by @TimDettmers, uses an 8-bit quantization technique that reduces the memory used by the AdamW optimizer states from 8 bytes to 2 bytes per parameter. This is a huge memory saving, and I think our users will benefit a lot from it.

  2. Additionally, we discovered that one of BNB's components, Embedding Norm, on its own made a huge improvement to the training stability of large models @bigscience.

Therefore this is two features in one request.
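To size up the saving in item 1: Adam(W) keeps two optimizer states per parameter (first and second moments), which is 2 × 4 = 8 bytes in 32-bit and 2 × 1 = 2 bytes in 8-bit. A rough estimate for a 104B-parameter model, counting only the optimizer states (not weights, gradients, or bnb's small quantization statistics):

$$
104 \times 10^{9} \times 8\,\text{B} = 832\,\text{GB}
\quad\longrightarrow\quad
104 \times 10^{9} \times 2\,\text{B} = 208\,\text{GB}
$$

That is roughly 624 GB less optimizer memory, before accounting for any layers kept in 32-bit.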

Performance

We ran experiments at BigScience with a 104B model, and while we didn't have a chance to run a full training to the end, BNB performed on par with regular AdamW quality-wise.

I'm also currently running a full 1.3B model training with embed norm, to compare scaling laws against the same training without embed norm. It should finish in a few days.

Tech

This technology comes in 2 components.

  1. 8-bit quantization optimizer
  2. required Embedding Norm

The optimizer itself is a drop-in replacement for Adam:

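```python
import bitsandbytes as bnb

# drop-in replacement for torch.optim.Adam; optim_bits=8 stores the optimizer states in 8 bits
optim = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=8)
```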

but there is an important requirement: an embed norm, which is needed to ensure training stability and which we currently don't have.

In fact, at BigScience we discovered that adding embed norm on its own, without BNB, made a huge difference to training stability, and we are most likely going to enable it for the 200B GPT model training, since the current 104B GPT model results are the best when embed norm is enabled. So once we release the 200B model, we will most likely want embed norm in transformers for the custom architecture of that model.

According to Tim, embedding norm currently appears to be the new default for Google and OpenAI models.

BNB comes with StableEmbedding, which replaces nn.Embedding.

So the only integration needed on the HF side (other than adding --optim=adamw_bnb to HF Trainer) is to add an embed norm and a config option to enable or disable it. It also wants xavier_uniform init, but that's a minor detail.

Finetuning

Existing pre-trained transformers models could be used as is, with 8-bit optimizers for all weights but 32-bit optimizers for the embedding layer. This will improve stability for fine-tuning. Tim shared that for GLUE fine-tuning it is fine to use 8-bit optimizers for the embedding layer too, but in general 32-bit should be more stable.

Pretraining

For pretraining, it would make sense to implement the full stable embedding layer, i.e. add a configurable embed norm at the end of Embedding.forward. Here we would want to implement it ourselves rather than reuse StableEmbedding from BNB, so that a model trained with BNB can later be loaded from the hub without depending on BNB.
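A minimal sketch of what an in-house version might look like, assuming a hypothetical `EmbeddingWithNorm` class and `use_embed_norm` flag (neither exists in transformers or BNB). It follows the recipe described above: xavier_uniform init plus a LayerNorm at the end of forward, built only from plain PyTorch modules, so the resulting checkpoint does not need BNB to load:

```python
import torch
from torch import nn


class EmbeddingWithNorm(nn.Embedding):
    """Hypothetical nn.Embedding variant with an optional LayerNorm on its output."""

    def __init__(self, num_embeddings: int, embedding_dim: int, use_embed_norm: bool = True, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        # nn.Identity has no parameters, so with the flag off the state dict
        # is the same as a plain nn.Embedding's.
        self.norm = nn.LayerNorm(embedding_dim) if use_embed_norm else nn.Identity()

    def reset_parameters(self) -> None:
        # Called from nn.Embedding.__init__: xavier-uniform instead of the default normal init.
        nn.init.xavier_uniform_(self.weight)
        self._fill_padding_idx_with_zero()

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Regular embedding lookup, followed by the (optional) embed norm.
        return self.norm(super().forward(input_ids))


emb = EmbeddingWithNorm(1000, 64)
print(sorted(emb.state_dict()))  # ['norm.bias', 'norm.weight', 'weight']
```

In a model, the flag would come from a config attribute, so existing checkpoints keep loading with it switched off.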

We obviously can't make this the default for all our models, but perhaps we could start enabling it for some models where we know it makes a huge difference, or at least recommend it.

@TimDettmers, please let me know if I missed something or you'd like to add anything to my summary here. Thank you!

Comments are welcome.

@patrickvonplaten, @patil-suraj, @LysandreJik, @sgugger

@patrickvonplaten
Contributor

Regarding the optimizers for Trainer, I think we can have a "small" breaking change in general and completely remove our implementation of AdamW and instead make use of torch's native AdamW implementation.

I think it's a good idea to add a --optim arg to Trainer
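An `--optim` argument essentially maps a string to an optimizer class. A minimal sketch, assuming hypothetical names (`OPTIMIZERS`, `get_optimizer_cls`) and option strings in the spirit of this thread; it is not the Trainer's actual code. The BNB entry is mapped here to `bnb.optim.Adam8bit`, the class discussed later in this thread, and bitsandbytes is imported lazily so it stays an optional dependency:

```python
import importlib
from typing import Callable, Dict

import torch


def _bnb_adam8bit() -> type:
    # Lazy import: bitsandbytes is only required if this option is actually chosen.
    bnb = importlib.import_module("bitsandbytes")
    return bnb.optim.Adam8bit


# Hypothetical registry: --optim value -> zero-argument factory returning an optimizer class.
OPTIMIZERS: Dict[str, Callable[[], type]] = {
    "adamw_torch": lambda: torch.optim.AdamW,  # torch's native AdamW
    "adamw_bnb": _bnb_adam8bit,                # 8-bit Adam from bitsandbytes
}


def get_optimizer_cls(name: str) -> type:
    """Resolve an --optim value to an optimizer class (hypothetical helper)."""
    if name not in OPTIMIZERS:
        raise ValueError(f"unknown --optim value {name!r}, expected one of {sorted(OPTIMIZERS)}")
    return OPTIMIZERS[name]()


model = torch.nn.Linear(4, 2)
optimizer = get_optimizer_cls("adamw_torch")(model.parameters(), lr=1e-3)
```

Keeping factories rather than classes in the table means users without bitsandbytes installed never hit an import error unless they ask for it.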

@patrickvonplaten
Contributor

Regarding the StableEmbedding, this has to be handled in each model file respectively IMO and should be done in a second PR (if necessary)

@sgugger
Collaborator

sgugger commented Dec 20, 2021

I'm very torn about adding an option for the StableEmbedding in the config of some (all?) models, so I feel I need more information. Specifically, let's say we added that option to GPT-2 models:

  • can a current checkpoint for GPT-2 (like gpt2) be used with that option enabled in the config and produce good results, or would it need to be retrained?
  • can a checkpoint trained with StableEmbedding be used with a regular Embedding instead if someone disables the config option?

I'm trying to see whether it's something like gradient checkpointing, for instance, which you can enable for training without changing anything for the user who ends up with your model if they don't want it, or whether it impacts the checkpoints in any way. Depending on the answer to that, we will see if we need new model files or not.

@stas00
Contributor Author

stas00 commented Dec 21, 2021

Excellent questions, @sgugger

Gradient checkpointing doesn't change the math. Layer norm does (that's how it makes training more stable).

The layer norm has 2 weights (gain and bias), which are trained.
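For reference, the standard layer norm over a d-dimensional hidden vector x, in standard notation, where the gain γ and the bias β are the two trained weights (they are the `weight` and `bias` of `torch.nn.LayerNorm`):

$$
\mathrm{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta,
\qquad
\mu = \frac{1}{d}\sum_{i=1}^{d} x_i,
\qquad
\sigma^{2} = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^{2}
$$

Unlike gradient checkpointing, this is a change to the forward function itself, and γ and β become part of the checkpoint.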

  • can a current checkpoint for GPT-2 (like gpt2) be used with that option enabled in the config and produce good results, or would it need to be retrained?

Because layer norm changes the hidden representation seen by the next stage, some finetuning will almost certainly be needed. I'm not sure how much.

  • can a checkpoint trained with StableEmbedding be used with a regular Embedding instead if someone disables the config option?

Same answer as above: removing this transform will change the hidden representation.
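A toy illustration of why. The snippet uses a random stand-in for a pretrained embedding table (the std of 0.02 is an assumed, typical init scale) and puts a freshly initialized LayerNorm after it. The layers downstream were trained on `before` but would now receive `after`, which has zero mean per token and a very different scale. Removing a trained norm likewise changes what the next layer sees:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Random stand-in for a pretrained embedding table (std 0.02 is an assumption).
pretrained = nn.Embedding(100, 16)
nn.init.normal_(pretrained.weight, std=0.02)

# A freshly initialized LayerNorm (gain=1, bias=0) switched on after loading.
new_norm = nn.LayerNorm(16)

ids = torch.tensor([3, 14, 15])
with torch.no_grad():
    before = pretrained(ids)   # what the next layer was trained on
    after = new_norm(before)   # what it would see with the norm enabled

print(before.norm(dim=-1))  # per-token vector lengths without the norm
print(after.norm(dim=-1))   # after normalization: close to sqrt(16) = 4
```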

I have only used it in training from scratch so far. But perhaps @TimDettmers has some suggestions. I know he is on the road, so let's perhaps wait for him to follow up.

@TimDettmers
Contributor

Thanks Stas! The current version of bnb also features the normal 32-bit embedding layer without layer norm for the very reason of compatibility. What I would do is use this 32-bit optimizer version as the default when using bnb optimizers and have an option for the stable embedding layer if one wants to pretrain a model.

From my experience, the differences between 8-bit and 32-bit optimizers for the embedding layer, with and without layer norm, are as follows:

  • 8-bit: unstable training and poor performance for pretraining; successful finetuning on GLUE; finetuning on more complicated objectives (seq2seq) might be unstable.
  • 32-bit: stable training for all objectives for models below 1.5B parameters.
  • 32-bit + layer norm: stable training for all models and all objectives, and improved performance.

  • can a current checkpoint for GPT-2 (like gpt2) be used with that option enabled in the config and produce good results, or would it need to be retrained?

I experimented with this. For pretrained checkpoints adding a layer norm after loading the model makes training difficult and leads to poor performance. I tinkered a bit with low learning rates to adapt the layer norm first before regular finetuning, but that did not work well and is a mess. So for pretrained models, the best would be to use the 32-bit optimized embedding layer (bnb.nn.Embedding) and no layer norm if the pretrained model was not trained with a layer norm.

  • can a checkpoint trained with StableEmbedding be used with a regular Embedding instead if someone disables the config option?

This is basically the same as above. If a StableEmbedding has been used for pretraining it needs to be used for fine-tuning. Removing/adding a layer after pretraining makes finetuning difficult. The performance is usually a bit better with StableEmbedding layer (using fairseq for language modeling, masked language modeling, machine translation, multi-lingual machine translation). Pretraining is usually also easier with the layer norm. That is why it is standard for Google/OpenAI models.
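Tim's recommendations can be condensed into a small lookup. A minimal sketch; `EmbeddingSetup` and `recommend_embedding_setup` are hypothetical names for illustration, and the bit widths refer to optimizer-state precision, not to the model weights:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EmbeddingSetup:
    embed_norm: bool            # LayerNorm after the embedding (StableEmbedding-style)?
    embedding_optim_bits: int   # optimizer-state precision for the embedding weights
    other_optim_bits: int       # optimizer-state precision for all remaining weights


def recommend_embedding_setup(pretraining: bool, checkpoint_has_embed_norm: bool = False) -> EmbeddingSetup:
    """Hypothetical helper encoding the recommendations in this thread."""
    if pretraining:
        # From scratch: full stable embedding, i.e. 32-bit embedding optimizer + layer norm.
        return EmbeddingSetup(embed_norm=True, embedding_optim_bits=32, other_optim_bits=8)
    # Fine-tuning: don't add or remove the norm after pretraining, match the checkpoint,
    # and keep the embedding optimizer in 32-bit (the bnb.nn.Embedding setup).
    return EmbeddingSetup(embed_norm=checkpoint_has_embed_norm, embedding_optim_bits=32, other_optim_bits=8)


print(recommend_embedding_setup(pretraining=False))
# EmbeddingSetup(embed_norm=False, embedding_optim_bits=32, other_optim_bits=8)
```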

I'm trying to see whether it's something like gradient checkpointing, for instance, which you can enable for training without changing anything for the user who ends up with your model if they don't want it, or whether it impacts the checkpoints in any way. Depending on the answer to that, we will see if we need new model files or not.

Like Stas said, the optimizer should not have any effect on gradient checkpointing with or without the layer norm. It just requires consistency between pretrained/finetuned checkpoints.

Let me know if there are any more questions!

@stas00
Contributor Author

stas00 commented Dec 31, 2021

Thank you very much for this detailed answer, Tim!

So to use Adam8bit with any normally pre-trained model we can do:

  1. Load the optimizer:
```python
import bitsandbytes as bnb
optim_cls = bnb.optim.Adam8bit  # instantiated later, once the model is loaded
```
  2. Fix up the model architecture: extend the nn.Embedding class with the logic of bnb.nn.Embedding.__init__ (which makes the optimizer run the embedding in 32-bit, while the rest of the model is optimized in 8-bit). This must be done before loading the model, since we can't miss the init:
    https://github.com/facebookresearch/bitsandbytes/blob/4e60e7dc62c50b6ba9b6becf6e779a1d48906be2/bitsandbytes/nn/modules.py#L51

Perhaps something like:

```python
import torch
from transformers import GPTNeoForCausalLM
from bitsandbytes.optim import GlobalOptimManager

# keep a handle on the original constructor
torch.nn.modules.sparse.Embedding.orig__init__ = torch.nn.modules.sparse.Embedding.__init__
def bnb_embed_init(self, *args, **kwargs):
    torch.nn.modules.sparse.Embedding.orig__init__(self, *args, **kwargs)
    # tell bnb to keep this embedding's optimizer states in 32-bit
    GlobalOptimManager.get_instance().register_module_override(self, 'weight', {'optim_bits': 32})
# every nn.Embedding created from now on registers itself for 32-bit optimization
torch.nn.modules.sparse.Embedding.__init__ = bnb_embed_init
```

which of course can be turned into a wrapper so it won't be an eyesore. There is also a neater way to do it with functools.wraps:

```python
import functools
import torch
from bitsandbytes.optim import GlobalOptimManager

def run_after(f):
    """Wrap __init__ so the 32-bit override is registered right after construction."""
    @functools.wraps(f)
    def wrapper(module, *args, **kwargs):
        f(module, *args, **kwargs)
        GlobalOptimManager.get_instance().register_module_override(module, 'weight', {'optim_bits': 32})
    return wrapper

cls = torch.nn.modules.sparse.Embedding
cls._old_init = cls.__init__  # kept so the patch can be undone
cls.__init__ = run_after(cls.__init__)
```
  3. Load the model as normal:
```python
model = GPTNeoForCausalLM.from_pretrained(...)
```

Or maybe it's easier to first load the model, then traverse it and tell Adam8bit to run the embedding layers in fp32:

```python
import torch
import bitsandbytes as bnb
from transformers import GPTNeoForCausalLM
from bitsandbytes.optim import GlobalOptimManager

def set_optim_to_run_embedding_in_fp32(model):
    # register a 32-bit optimizer override for every embedding weight in the model
    for module in model.modules():
        if isinstance(module, torch.nn.Embedding):
            GlobalOptimManager.get_instance().register_module_override(module, 'weight', {'optim_bits': 32})

mname = "EleutherAI/gpt-neo-125M"
model = GPTNeoForCausalLM.from_pretrained(mname)
set_optim_to_run_embedding_in_fp32(model)
```

This does look simpler. @TimDettmers, if this is useful, perhaps bnb_embedding_in_fp32 could become part of the BNB API, but then it should probably take an optional embed_class=torch.nn.Embedding argument in case the user has a custom Embedding class.
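A minimal sketch of the traversal with the optional embed_class argument added. `find_embedding_modules` is a hypothetical name; since isinstance also accepts a tuple, several embedding classes can be matched at once. The bnb call is left as a comment so the sketch runs without bitsandbytes:

```python
from typing import List, Tuple, Type, Union

from torch import nn


def find_embedding_modules(
    model: nn.Module,
    embed_class: Union[Type[nn.Module], Tuple[Type[nn.Module], ...]] = nn.Embedding,
) -> List[Tuple[str, nn.Module]]:
    """Return (qualified_name, module) for every submodule that is an instance of embed_class."""
    return [(name, m) for name, m in model.named_modules() if isinstance(m, embed_class)]


# With bitsandbytes installed, each match would then get the 32-bit override:
# for _, module in find_embedding_modules(model):
#     GlobalOptimManager.get_instance().register_module_override(module, 'weight', {'optim_bits': 32})


class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.wte = nn.Embedding(100, 16)  # token embeddings
        self.wpe = nn.Embedding(32, 16)   # position embeddings
        self.head = nn.Linear(16, 100)


print([name for name, _ in find_embedding_modules(ToyLM())])  # ['wte', 'wpe']
```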

I suppose it is OK to run register_module_override after the model has been fully loaded. Do we need to import bnb.optim.Adam8bit first, by any chance?

If we add support for StableEmbedding in transformers archs, then for new trainings bnb.optim.Adam8bit could be used directly.

Once we merge #14744 we can add --optim adam_bnb_8bit to HF Trainer and give it a try.

@stas00
Contributor Author

stas00 commented Dec 31, 2021

The difficult question is how HF Trainer would know when to enable the fp32 embedding optimizer and when not to. The model will need a way to convey that info.

@stas00
Contributor Author

stas00 commented Jan 2, 2022

@TimDettmers, I'm curious whether you have experimented with AdaNorm as part of StableEmbedding for cases where the model wasn't pretrained with StableEmbedding.

If I understand correctly, AdaNorm doesn't have LayerNorm's usual trainable gain and bias parameters and uses a fixed hparam instead, and the paper reports performance very close to, and at times better than, LayerNorm across several experiments: https://arxiv.org/abs/1911.07013
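As I read the paper, AdaNorm normalizes x as usual and then replaces the learned gain and bias with a fixed function of the normalized value:

$$
y = \frac{x - \mu}{\sigma},
\qquad
z = C\,(1 - k\,y) \odot y
$$

Here C is a hyperparameter, k is a constant (the paper suggests k = 1/10), and the gradient through C(1 − ky) is detached during backpropagation. There are no trainable parameters, which is what makes it attractive as a drop-in for already-pretrained models.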

If it worked, then instead of doing the embeddings in fp32, AdaNorm could perhaps be a simpler solution that saves further memory. The user would then just have to swap nn.Embedding -> bnb.nn.StableEmbeddingAdaNorm and supply an additional hparam (I have no idea how easy it is to get that right, though, so perhaps it's not that easy).

@sgugger
Collaborator

sgugger commented Jan 10, 2022

Mmm, I don't think we will "fixup the model architectures". The test for keeping a current architecture and adding support for Embedding norm as a config argument does not pass, so we will need new architectures with the proper embedding layers IMO, and only those will support training/fine-tuning with --optim adam_bnb_8bit

@TimDettmers
Contributor

Mmm, I don't think we will "fixup the model architectures". The test for keeping a current architecture and adding support for Embedding norm as a config argument does not pass, so we will need new architectures with the proper embedding layers IMO, and only those will support training/fine-tuning with --optim adam_bnb_8bit

How about only using the bnb.nn.Embedding that does not use embedding layer norm? If using a different class is problematic, the 32-bit optimizers can also be configured by passing the weight attribute of the respective class to the bitsandbytes library, like so:

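```python
from bitsandbytes.optim import GlobalOptimManager

# emb_module: the embedding module whose weight should get a 32-bit optimizer
GlobalOptimManager.get_instance().register_module_override(emb_module, 'weight', {'optim_bits': 32})
```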

No architecture reconfiguration should be needed with this option, and the embedding will run with 32-bit optimizers. This is indeed the configuration that I use to fine-tune models in the 8-bit optimizer paper -- no embedding norm required! Do you think this could make more sense?

@stas00
Contributor Author

stas00 commented Jan 24, 2022

I think so, Tim. Thank you for your practical suggestion! We should try it out and see how it fares.

@manuelciosici, you were asking earlier if there is something else to work on. Would you like to try to work on adding --optim adamw_bnb to HF Trainer?

If so, please read this whole thread to get a sense of the requirement for an embedding with layernorm, which we don't have; Tim has proposed a workaround above that should be a middle ground memory-saving-wise.

So basically the optimizer would use 3/4 less memory for all params except the embedding, where there will be no saving at all.
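How close the overall saving gets to 3/4 depends on the embedding's share of the parameters. A small sketch of the arithmetic, counting only Adam's two optimizer states per parameter (8 bytes in 32-bit, 2 bytes in 8-bit) and ignoring weights, gradients, quantization statistics and any small tensors bnb may keep in 32-bit anyway. The function name and example sizes are illustrative assumptions: 124M total parameters, with an embedding the size of GPT-2 small's 50257 × 768 token table:

```python
def adam_state_bytes(total_params: int, embedding_params: int,
                     embed_bits: int = 32, other_bits: int = 8) -> int:
    """Bytes of Adam optimizer state (2 moments per parameter), hypothetical helper."""
    bytes_per_param = {32: 2 * 4, 8: 2 * 1}  # two states x (4 or 1) bytes each
    other_params = total_params - embedding_params
    return other_params * bytes_per_param[other_bits] + embedding_params * bytes_per_param[embed_bits]


total = 124_000_000      # hypothetical, roughly GPT-2-small-sized
embed = 50257 * 768      # 38_597_376 embedding parameters

full_32 = adam_state_bytes(total, embed, embed_bits=32, other_bits=32)  # 992_000_000
full_8 = adam_state_bytes(total, embed, embed_bits=8, other_bits=8)     # 248_000_000
mixed = adam_state_bytes(total, embed)                                  # 479_584_256
print(full_32, full_8, mixed)
```

With these numbers the mixed setup needs roughly 48% of the full 32-bit state, so the saving is about half rather than 3/4, because the embedding is almost a third of such a small model. In a much larger model the embedding is typically a far smaller fraction, and the saving approaches 3/4.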

Additionally I hope that in the future we will have model archs with embed norm, and we will need to figure out how to activate the bnb optim for those archs. But we can discuss that in the PR.

If you're busy, it's no problem; I then hope to be able to try this myself sometime soon.

Thank you!

@manuelciosici
Contributor

@stas00 Thank you for the opportunity. I will read the thread today and add --optim adamw_bnb.

@manuelciosici
Contributor

Thank you github-actions. I plan to work on this issue's PR this weekend.
