[Under Construction] A transformer that does not hog your GPU memory

LeanTransformer implements a specific version of transformer with two goals in mind:

  • using as little GPU memory as possible
  • stable training for very large models

This code is under active development: if you want a stable and documented version, look at CALM or dalle-hivemind.

Basic usage: LeanTransformer works similarly to most models in Hugging Face Transformers. A model can be instantiated from a config, run forward and backward, and compute a loss. One can use the vanilla general-purpose LeanTransformer or one of the pre-implemented models:

from transformers import AutoTokenizer
from lean_transformer.models.gpt import LeanGPTConfig, LeanGPTModel

config = LeanGPTConfig(
    vocab_size=10 ** 4, hidden_size=768, num_hidden_layers=12, num_attention_heads=16,
    position_embedding_type="rotary", hidden_act_gated=True, tie_word_embeddings=True
)
model = LeanGPTModel(config)
tokenizer = AutoTokenizer.from_pretrained("gpt2-large")

dummy_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")
outputs = model(**dummy_inputs, labels=dummy_inputs['input_ids'])
outputs.loss.backward()

All models are batch-first, i.e. they operate on [batch, length, hid_size] or [batch, height, width, channels] tensors, like the rest of the Hugging Face ecosystem.
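As a minimal sketch of that convention, reusing model, tokenizer and dummy_inputs from the snippet above (the logits attribute name is an assumption based on the usual Hugging Face output layout and may differ in this repo):

batch_size, seq_length = dummy_inputs["input_ids"].shape  # tokenized inputs: [batch, length]

outputs = model(**dummy_inputs)
# hidden states and logits are batch-first as well, e.g. [batch, length, vocab_size];
# "logits" follows the Hugging Face convention and is assumed, not confirmed, here
assert outputs.logits.shape[:2] == (batch_size, seq_length)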

A day will come when we explain all these modifications and provide instructions on how to tune them. Until then, we'll happily answer any questions on our Discord.

How does it work?

The core philosophy of LeanTransformer is to replace torch.autograd with grad students. Automatic differentiation is great if you want to test ideas quickly, but less so if a single training run can cost over $4 million (or >1000 years in grad school). So, we made a ton of tweaks that minimize memory usage.
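As a rough illustration of the kind of memory/compute trade-off such tweaks exploit, here is a generic sketch of activation checkpointing with the standard torch.utils.checkpoint utility; this is a textbook example of the technique, not LeanTransformer's actual internals:

import torch
from torch.utils.checkpoint import checkpoint

# a toy feed-forward "block"; any nn.Module with differentiable outputs works
block = torch.nn.Sequential(
    torch.nn.Linear(768, 3072), torch.nn.GELU(), torch.nn.Linear(3072, 768)
)
x = torch.randn(2, 128, 768, requires_grad=True)  # [batch, length, hidden]

# checkpoint() discards the block's intermediate activations in the forward pass
# and recomputes them during backward, trading extra compute for lower peak memory
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()

The price is roughly one extra forward pass per checkpointed block; in exchange, activations no longer have to be stored for the full depth of the network.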

Related work: GSO

Our implementation partially replaces automatic differentiation with Grad Student Optimization (GSO) - a biologically inspired black-box optimization algorithm. In the past, GSO has seen widespread adoption thanks to its strong theoretical foundations and unparalleled cost efficiency (Chom et al.). Previous works successfully applied GSO to hyperparameter tuning and natural language generation. To the best of our knowledge, ours is the first work to successfully apply distributed fault-tolerant GSO to optimizing the memory footprint of transformers. We summarize our findings below:

Memory saving features:

Other features:

Acknowledgements:
