RoPE scaling, document how to convert HuggingFace checkpoints (#111)
epwalsh authored Nov 21, 2024
1 parent 7655a3b commit 0bcc840
Showing 14 changed files with 291 additions and 35 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
for loading checkpoints with different key names.
- Added `load_key_mapping` field to the trainer, same idea as the new `key_mapping` argument above.
- Added an implementation of nGPT called `NormalizedTransformer`.
- Added an example showing how to convert a HuggingFace Llama 3.2 checkpoint into the right format for OLMo-core.
- Added an API for scaling RoPE embeddings.

### Changed

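For context on the RoPE scaling entry above: the new API is exercised by the conversion script added in this commit, where a `RoPEScalingConfig` is passed to the Llama 3.2 model config. A minimal sketch of that wiring (not part of this diff; the fields of `RoPEScalingConfig` are not shown here, so the argument-free call is assumed to match Llama 3.2's scaled RoPE):

```python
# Sketch only: enable RoPE scaling when building a Llama 3.2 model config,
# mirroring src/examples/huggingface/convert_checkpoint.py from this commit.
from olmo_core.data.tokenizer import TokenizerConfig
from olmo_core.nn.rope import RoPEScalingConfig
from olmo_core.nn.transformer import TransformerConfig

tokenizer_config = TokenizerConfig.from_hf("meta-llama/Llama-3.2-1B")
model_config = TransformerConfig.llama3_1B(
    tokenizer_config.vocab_size,
    fused_ops=False,
    use_flash=False,
    rope_scaling=RoPEScalingConfig(),  # defaults assumed to match Llama 3.2's scaled RoPE
)
```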
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -41,6 +41,7 @@
"sphinx.ext.viewcode",
"sphinx_copybutton",
"sphinx_autodoc_typehints",
"sphinx_inline_tabs",
]

# Tell myst-parser to assign header anchors for h1-h3.
19 changes: 19 additions & 0 deletions docs/source/examples/huggingface.rst
@@ -0,0 +1,19 @@
HuggingFace models
==================

The OLMo-core :class:`~olmo_core.train.Trainer` can be used to fine-tune language models from HuggingFace's ``transformers`` library.

One way to do this would be to manually apply a data parallel wrapper (like DDP or FSDP) to your ``AutoModelForCausalLM`` and then pass that model directly to the trainer. The downside of this approach is that you won't be able to take advantage of all of the optimizations in this library.

Instead, we recommend converting your HuggingFace checkpoint, when possible, into a format that can be loaded into an equivalent OLMo-core :class:`~olmo_core.nn.transformer.Transformer` model using the functions provided by :mod:`olmo_core.distributed.checkpoint`.

Below is an example that shows how to convert a Llama-3.2 checkpoint on HuggingFace into the right format for OLMo-core.
It would be straightforward to adapt this script to convert in the other direction as well.

.. seealso::
    See the `train a Llama model <llama.html>`_ example to learn how to use OLMo-core's training API to pretrain or fine-tune any Llama-like language model.

.. tab:: ``src/examples/huggingface/convert_checkpoint.py``

    .. literalinclude:: ../../../src/examples/huggingface/convert_checkpoint.py
       :language: py
24 changes: 11 additions & 13 deletions docs/source/examples/llama.rst
@@ -1,19 +1,17 @@
-``Train a Llama model``
-=======================
+Train a Llama model
+===================
 
-The following snippet is the code from ``src/examples/llama/train.py``.
-It's a script meant to be launched via ``torchrun``.
+The following snippets can be found in `src/examples/llama/ <https://github.com/allenai/OLMo-core/tree/main/src/examples/llama>`_.
+The ``train.py`` script is meant to be launched via ``torchrun``.
 You can also use the :mod:`olmo_core.launch` API to quickly launch this script on Beaker.
-See below for an example of that.
+See the ``train_launch.py`` snippet for an example of that.
 
-``src/examples/llama/train.py``
--------------------------------
+.. tab:: ``train.py``
 
-.. literalinclude:: ../../../src/examples/llama/train.py
-   :language: py
+    .. literalinclude:: ../../../src/examples/llama/train.py
+       :language: py
 
-``src/examples/llama/train_launch.py``
---------------------------------------
+.. tab:: ``train_launch.py``
 
-.. literalinclude:: ../../../src/examples/llama/train_launch.py
-   :language: py
+    .. literalinclude:: ../../../src/examples/llama/train_launch.py
+       :language: py
13 changes: 6 additions & 7 deletions docs/source/examples/ngpt.rst
@@ -1,11 +1,10 @@
-``Train an nGPT model``
-=======================
+Train an nGPT model
+===================
 
-The following snippet is the code from ``src/examples/ngpt/train.py``.
+The following snippet can be found in `src/examples/ngpt/ <https://github.com/allenai/OLMo-core/tree/main/src/examples/ngpt>`_.
 It's a script meant to be launched via ``torchrun``.
 
-``src/examples/ngpt/train.py``
-------------------------------
+.. tab:: ``train.py``
 
-.. literalinclude:: ../../../src/examples/ngpt/train.py
-   :language: py
+    .. literalinclude:: ../../../src/examples/ngpt/train.py
+       :language: py
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -35,6 +35,7 @@ specific to your environment. Then you can install OLMo-core from PyPI with:
   :maxdepth: 2
   :caption: Examples

   examples/huggingface.rst
   examples/llama.rst
   examples/ngpt.rst

9 changes: 5 additions & 4 deletions pyproject.toml
@@ -46,12 +46,13 @@ dev = [
"build",
"boto3",
"google-cloud-storage",
"Sphinx>=6.0,<7.0.2",
"Sphinx>=6.0,<9.0",
"furo==2024.8.6",
"myst-parser>=1.0,<2.1",
"sphinx-copybutton==0.5.2",
"sphinx-autobuild==2021.3.14",
"myst-parser>=1.0",
"sphinx-copybutton",
"sphinx-autobuild",
"sphinx-autodoc-typehints==1.23.3",
"sphinx-inline-tabs",
]
beaker = [
"beaker-py>=1.32.0",
Empty file.
123 changes: 123 additions & 0 deletions src/examples/huggingface/convert_checkpoint.py
@@ -0,0 +1,123 @@
"""
Example script showing how you could convert model weights on HuggingFace for a Llama-3.2 model
into a format that can be loaded by OLMo-core for fine-tuning.
Note that this script is architecture-dependent, meaning it may only work for Llama-3.2 models on
HuggingFace.
"""

import logging

import torch
from transformers import AutoModelForCausalLM

from olmo_core.data.tokenizer import TokenizerConfig
from olmo_core.distributed.checkpoint import load_model_and_optim_state, save_state_dict
from olmo_core.io import clear_directory, dir_is_empty
from olmo_core.nn.rope import RoPEScalingConfig
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.utils import get_default_device, prepare_cli_environment

log = logging.getLogger(__name__)

HF_MODEL = "meta-llama/Llama-3.2-1B"
SAVE_PATH = f"/tmp/checkpoints/{HF_MODEL}"
SAVE_OVERWRITE = False

TOKENIZER_CONFIG = TokenizerConfig.from_hf(HF_MODEL)
MODEL_CONFIG = TransformerConfig.llama3_1B(
TOKENIZER_CONFIG.vocab_size, fused_ops=False, use_flash=False, rope_scaling=RoPEScalingConfig()
)


def convert_checkpoint() -> AutoModelForCausalLM:
log.info(f"Loading HF checkpoint '{HF_MODEL}'")
hf_model = AutoModelForCausalLM.from_pretrained(HF_MODEL)
print(hf_model)

if not dir_is_empty(SAVE_PATH):
if SAVE_OVERWRITE:
log.warning(f"Clearing existing checkpoint at '{SAVE_PATH}'")
clear_directory(SAVE_PATH)
else:
log.warning(f"Using existing checkpoint at '{SAVE_PATH}'")
return hf_model

n_layers = len(hf_model.model.layers)
state_dict = hf_model.state_dict()

# Map old keys to OLMo-core keys.
new_state_dict = {
"embeddings.weight": state_dict.pop("model.embed_tokens.weight"),
"lm_head.norm.weight": state_dict.pop("model.norm.weight"),
"lm_head.w_out.weight": state_dict.pop("lm_head.weight"),
}
for block in range(n_layers):
# Attention.
new_state_dict[f"blocks.{block}.attention.w_q.weight"] = state_dict.pop(
f"model.layers.{block}.self_attn.q_proj.weight"
)
new_state_dict[f"blocks.{block}.attention.w_k.weight"] = state_dict.pop(
f"model.layers.{block}.self_attn.k_proj.weight"
)
new_state_dict[f"blocks.{block}.attention.w_v.weight"] = state_dict.pop(
f"model.layers.{block}.self_attn.v_proj.weight"
)
new_state_dict[f"blocks.{block}.attention.w_out.weight"] = state_dict.pop(
f"model.layers.{block}.self_attn.o_proj.weight"
)

# MLP.
new_state_dict[f"blocks.{block}.feed_forward.w1.weight"] = state_dict.pop(
f"model.layers.{block}.mlp.gate_proj.weight"
)
new_state_dict[f"blocks.{block}.feed_forward.w2.weight"] = state_dict.pop(
f"model.layers.{block}.mlp.down_proj.weight"
)
new_state_dict[f"blocks.{block}.feed_forward.w3.weight"] = state_dict.pop(
f"model.layers.{block}.mlp.up_proj.weight"
)

# Attention layer norm.
new_state_dict[f"blocks.{block}.attention_norm.weight"] = state_dict.pop(
f"model.layers.{block}.input_layernorm.weight"
)

# MLP layer norm.
new_state_dict[f"blocks.{block}.feed_forward_norm.weight"] = state_dict.pop(
f"model.layers.{block}.post_attention_layernorm.weight"
)

assert len(state_dict) == 0

log.info(f"Saving converted model checkpoint '{SAVE_PATH}'...")
save_state_dict(SAVE_PATH, {"model": new_state_dict})

return hf_model


def validate_conversion(hf_model):
log.info("Loading converted checkpoint for validation...")

device = get_default_device()

model = MODEL_CONFIG.build(device=device, max_seq_len=131072).eval()
load_model_and_optim_state(SAVE_PATH, model)

hf_model = hf_model.to(device).eval()

B, T = 1, 120
input_ids = torch.randint(0, TOKENIZER_CONFIG.vocab_size, (B, T)).to(device)

with torch.no_grad():
logits = model(input_ids=input_ids)
hf_logits, *_ = hf_model(input_ids=input_ids, return_dict=False)
torch.testing.assert_close(hf_logits, logits)

log.info("Conversion successful")


if __name__ == "__main__":
prepare_cli_environment()
hf_model = convert_checkpoint()
validate_conversion(hf_model)
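As the docs note, converting in the other direction is mostly a matter of inverting the key mapping. Below is a rough sketch of that reverse pass, written as if appended to the script above so it can reuse its imports, `HF_MODEL`, `SAVE_PATH`, and `MODEL_CONFIG`; it is not part of this commit, and the export path and helper structure are assumptions.

```python
# Hypothetical reverse conversion: load an OLMo-core checkpoint saved by the script
# above and write the weights back into a HuggingFace Llama-3.2 model.
def convert_back(load_path: str = SAVE_PATH, export_path: str = "/tmp/hf-export"):
    hf_model = AutoModelForCausalLM.from_pretrained(HF_MODEL)

    model = MODEL_CONFIG.build(device=torch.device("cpu"), max_seq_len=131072)
    load_model_and_optim_state(load_path, model)
    state_dict = model.state_dict()

    # Invert the mapping used in convert_checkpoint(): OLMo-core names -> HF names.
    hf_state_dict = {
        "model.embed_tokens.weight": state_dict.pop("embeddings.weight"),
        "model.norm.weight": state_dict.pop("lm_head.norm.weight"),
        "lm_head.weight": state_dict.pop("lm_head.w_out.weight"),
    }
    per_block = {
        "attention.w_q.weight": "self_attn.q_proj.weight",
        "attention.w_k.weight": "self_attn.k_proj.weight",
        "attention.w_v.weight": "self_attn.v_proj.weight",
        "attention.w_out.weight": "self_attn.o_proj.weight",
        "feed_forward.w1.weight": "mlp.gate_proj.weight",
        "feed_forward.w2.weight": "mlp.down_proj.weight",
        "feed_forward.w3.weight": "mlp.up_proj.weight",
        "attention_norm.weight": "input_layernorm.weight",
        "feed_forward_norm.weight": "post_attention_layernorm.weight",
    }
    for block in range(len(hf_model.model.layers)):
        for olmo_key, hf_key in per_block.items():
            hf_state_dict[f"model.layers.{block}.{hf_key}"] = state_dict.pop(
                f"blocks.{block}.{olmo_key}"
            )

    # Assumes the OLMo-core state dict holds exactly the mapped parameter tensors,
    # mirroring the assert in convert_checkpoint().
    assert len(state_dict) == 0

    hf_model.load_state_dict(hf_state_dict)
    hf_model.save_pretrained(export_path)
```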
25 changes: 24 additions & 1 deletion src/olmo_core/data/tokenizer.py
@@ -98,8 +98,31 @@ def gpt2(cls) -> "TokenizerConfig":
         Get a :data:`~TokenizerName.gpt2` tokenizer config.
         """
         return cls(
-            vocab_size=50280,
+            vocab_size=50257,
             eos_token_id=50256,
             bos_token_id=50256,
             pad_token_id=50256,
             identifier=TokenizerName.gpt2,
         )
+
+    @classmethod
+    def from_hf(cls, identifier: str) -> "TokenizerConfig":
+        """
+        Initialize a tokenizer config from a model on HuggingFace.
+
+        :param identifier: The HF model identifier, e.g. "meta-llama/Llama-3.2-1B".
+        """
+        import json
+
+        from cached_path import cached_path
+
+        with cached_path(f"hf://{identifier}/config.json").open() as f:
+            config = json.load(f)
+
+        return cls(
+            vocab_size=config["vocab_size"],
+            eos_token_id=config["eos_token_id"],
+            pad_token_id=config.get("pad_token_id", config["eos_token_id"]),
+            bos_token_id=config.get("bos_token_id"),
+            identifier=identifier,
+        )
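A minimal usage sketch for the new `from_hf` helper (this mirrors how the conversion script above builds its tokenizer and model configs; note that `pad_token_id` falls back to `eos_token_id` when the HF config omits it):

```python
# Sketch: read tokenizer metadata from a HuggingFace model's config.json and use it
# to size an OLMo-core model config.
from olmo_core.data.tokenizer import TokenizerConfig
from olmo_core.nn.transformer import TransformerConfig

tokenizer_config = TokenizerConfig.from_hf("meta-llama/Llama-3.2-1B")
print(tokenizer_config.vocab_size, tokenizer_config.eos_token_id, tokenizer_config.pad_token_id)

model_config = TransformerConfig.llama3_1B(tokenizer_config.vocab_size)
```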
8 changes: 8 additions & 0 deletions src/olmo_core/nn/layer_norm.py
@@ -120,6 +120,14 @@ def __init__(
self.register_parameter("bias", None)
self.register_parameter("weight", None)

def extra_repr(self):
if self.weight is not None and self.bias is not None:
return f"{tuple(self.weight.shape)}, bias=True, eps={self.eps}"
elif self.weight is not None:
return f"{tuple(self.weight.shape)}, eps={self.eps}"
else:
return f"eps={self.eps}"

def reset_parameters(self):
if self.weight is not None:
torch.nn.init.ones_(self.weight)
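For context, `extra_repr` is the standard `torch.nn.Module` hook that gets folded into `repr(module)`, so this change makes a norm layer's shape and epsilon visible when a model is printed (as `convert_checkpoint.py` does with `print(hf_model)`). A standalone illustration with a stand-in module, not the actual OLMo-core class:

```python
# Standalone illustration of how extra_repr surfaces in printed module trees.
# DummyNorm is a stand-in, not the OLMo-core layer norm implementation.
import torch


class DummyNorm(torch.nn.Module):
    def __init__(self, size: int = 2048, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.register_parameter("bias", None)

    def extra_repr(self):
        if self.weight is not None and self.bias is not None:
            return f"{tuple(self.weight.shape)}, bias=True, eps={self.eps}"
        elif self.weight is not None:
            return f"{tuple(self.weight.shape)}, eps={self.eps}"
        else:
            return f"eps={self.eps}"


print(DummyNorm())  # -> DummyNorm((2048,), eps=1e-05)
```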