diff --git a/.gitignore b/.gitignore
index 2aea0a4cb..a5d6d42b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,6 @@ _modidx.py
env
dist/
docs/build
-.coverage
+.coverage*
.Ds_Store
.pylintrc
diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb
new file mode 100644
index 000000000..6d227ebea
--- /dev/null
+++ b/demos/BERT.ipynb
@@ -0,0 +1,258 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# BERT in TransformerLens\n",
+ "This demo shows how to use BERT in TransformerLens for the Masked Language Modelling task."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Setup\n",
+ "(No need to read)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running as a Jupyter notebook - intended for development only!\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
+ "DEVELOPMENT_MODE = False\n",
+ "try:\n",
+ " import google.colab\n",
+ " IN_COLAB = True\n",
+ " print(\"Running as a Colab notebook\")\n",
+ " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
+ " %pip install circuitsvis\n",
+ " \n",
+ " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
+ " # # Install another version of node that makes PySvelte work way faster\n",
+ " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
+ " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
+ "except:\n",
+ " IN_COLAB = False\n",
+ " print(\"Running as a Jupyter notebook - intended for development only!\")\n",
+ " from IPython import get_ipython\n",
+ "\n",
+ " ipython = get_ipython()\n",
+ " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
+ " ipython.magic(\"load_ext autoreload\")\n",
+ " ipython.magic(\"autoreload 2\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using renderer: colab\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
+ "import plotly.io as pio\n",
+ "if IN_COLAB or not DEVELOPMENT_MODE:\n",
+ " pio.renderers.default = \"colab\"\n",
+ "else:\n",
+ " pio.renderers.default = \"notebook_connected\"\n",
+ "print(f\"Using renderer: {pio.renderers.default}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import circuitsvis as cv\n",
+ "# Testing that the library works\n",
+ "cv.examples.hello(\"Neel\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Import stuff\n",
+ "import torch\n",
+ "\n",
+ "from transformers import AutoTokenizer\n",
+ "\n",
+ "from transformer_lens import HookedEncoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "torch.set_grad_enabled(False)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# BERT\n",
+ "\n",
+ "In this section, we will load a pretrained BERT model and use it for the Masked Language Modelling task"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Moving model to device: cpu\n",
+ "Loaded pretrained model bert-base-cased into HookedTransformer\n"
+ ]
+ }
+ ],
+ "source": [
+ "bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n",
+ "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Use the \"[MASK]\" token to mask any tokens which you would like the model to predict."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt = \"BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\"\n",
+ "\n",
+ "input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
+ "mask_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Prompt: BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\n",
+ "Prediction: \"Systems\"\n"
+ ]
+ }
+ ],
+ "source": [
+ "logprobs = bert(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n",
+ "prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
+ "\n",
+ "print(f\"Prompt: {prompt}\")\n",
+ "print(f\"Prediction: \\\"{prediction}\\\"\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Better luck next time, BERT."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.10"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb
index aed5af6ba..6922d18ff 100644
--- a/demos/Main_Demo.ipynb
+++ b/demos/Main_Demo.ipynb
@@ -1091,6 +1091,7 @@
]
},
{
+ "attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -1110,6 +1111,9 @@
"* **GPT-NeoX** - Eleuther's 20B parameter model, trained on the Pile\n",
"* **Stanford CRFM models** - a replication of GPT-2 Small and GPT-2 Medium, trained on 5 different random seeds.\n",
" * Notably, 600 checkpoints were taken during training per model, and these are available in the library with eg `HookedTransformer.from_pretrained(\"stanford-gpt2-small-a\", checkpoint_index=265)`.\n",
+ "- **BERT** - Google's bidirectional encoder-only transformer.\n",
+ " - Size Base (108M), trained on English Wikipedia and BooksCorpus.\n",
+ " \n",
""
]
},
diff --git a/pyproject.toml b/pyproject.toml
index 3e5e9b9aa..5a56131dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,6 +52,7 @@ docs = ["Sphinx", "sphinx-autobuild", "sphinxcontrib-napoleon", "furo", "myst_pa
[tool.pytest.ini_options]
filterwarnings = [
+ "ignore:pkg_resources is deprecated as an API:DeprecationWarning",
# Ignore numpy.distutils deprecation warning caused by pandas
# More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
"ignore:distutils Version classes are deprecated:DeprecationWarning"
diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py
new file mode 100644
index 000000000..6f95d9b27
--- /dev/null
+++ b/tests/acceptance/test_hooked_encoder.py
@@ -0,0 +1,155 @@
+from typing import List
+
+import pytest
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
+from torch.testing import assert_close
+from transformers import AutoTokenizer, BertForMaskedLM
+
+from transformer_lens import HookedEncoder
+
+MODEL_NAME = "bert-base-cased"
+
+
+@pytest.fixture(scope="module")
+def our_bert():
+ return HookedEncoder.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(scope="module")
+def huggingface_bert():
+ return BertForMaskedLM.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture(scope="module")
+def tokenizer():
+ return AutoTokenizer.from_pretrained(MODEL_NAME)
+
+
+@pytest.fixture
+def hello_world_tokens(tokenizer):
+ return tokenizer("Hello, world!", return_tensors="pt")["input_ids"]
+
+
+def test_full_model(our_bert, huggingface_bert, tokenizer):
+ sequences = [
+ "Hello, world!",
+ "this is another sequence of tokens",
+ ]
+ tokenized = tokenizer(sequences, return_tensors="pt", padding=True)
+ input_ids = tokenized["input_ids"]
+ attention_mask = tokenized["attention_mask"]
+
+ huggingface_bert_out = huggingface_bert(
+ input_ids, attention_mask=attention_mask
+ ).logits
+ our_bert_out = our_bert(input_ids, one_zero_attention_mask=attention_mask)
+ assert_close(huggingface_bert_out, our_bert_out, rtol=1.3e-6, atol=4e-5)
+
+
+def test_embed_one_sentence(our_bert, huggingface_bert, hello_world_tokens):
+ huggingface_embed = huggingface_bert.bert.embeddings
+ our_embed = our_bert.embed
+
+ huggingface_embed_out = huggingface_embed(hello_world_tokens)[0]
+ our_embed_out = our_embed(hello_world_tokens).squeeze(0)
+ assert_close(huggingface_embed_out, our_embed_out)
+
+
+def test_embed_two_sentences(our_bert, huggingface_bert, tokenizer):
+ encoding = tokenizer("First sentence.", "Second sentence.", return_tensors="pt")
+ input_ids = encoding["input_ids"]
+ token_type_ids = encoding["token_type_ids"]
+
+ huggingface_embed_out = huggingface_bert.bert.embeddings(
+ input_ids, token_type_ids=token_type_ids
+ )[0]
+ our_embed_out = our_bert.embed(input_ids, token_type_ids=token_type_ids).squeeze(0)
+ assert_close(huggingface_embed_out, our_embed_out)
+
+
+def test_attention(our_bert, huggingface_bert, hello_world_tokens):
+ huggingface_embed = huggingface_bert.bert.embeddings
+ huggingface_attn = huggingface_bert.bert.encoder.layer[0].attention
+
+ embed_out = huggingface_embed(hello_world_tokens)
+
+ our_attn = our_bert.blocks[0].attn
+
+ our_attn_out = our_attn(embed_out, embed_out, embed_out)
+ huggingface_self_attn_out = huggingface_attn.self(embed_out)[0]
+ huggingface_attn_out = huggingface_attn.output.dense(huggingface_self_attn_out)
+ assert_close(our_attn_out, huggingface_attn_out)
+
+
+def test_bert_block(our_bert, huggingface_bert, hello_world_tokens):
+ huggingface_embed = huggingface_bert.bert.embeddings
+ huggingface_block = huggingface_bert.bert.encoder.layer[0]
+
+ embed_out = huggingface_embed(hello_world_tokens)
+
+ our_block = our_bert.blocks[0]
+
+ our_block_out = our_block(embed_out)
+ huggingface_block_out = huggingface_block(embed_out)[0]
+ assert_close(our_block_out, huggingface_block_out)
+
+
+def test_unembed(our_bert, huggingface_bert, hello_world_tokens):
+ huggingface_bert_core_outputs = huggingface_bert.bert(
+ hello_world_tokens
+ ).last_hidden_state
+
+ our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs)
+ our_unembed_out = our_bert.unembed(our_mlm_head_out)
+ huggingface_predictions_out = huggingface_bert.cls.predictions(
+ huggingface_bert_core_outputs
+ )
+
+ assert_close(our_unembed_out, huggingface_predictions_out, rtol=1.3e-6, atol=4e-5)
+
+
+def test_mlm_head(our_bert, huggingface_bert, hello_world_tokens):
+ huggingface_bert_core_outputs = huggingface_bert.bert(
+ hello_world_tokens
+ ).last_hidden_state
+
+ our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs)
+ huggingface_predictions_out = huggingface_bert.cls.predictions.transform(
+ huggingface_bert_core_outputs
+ )
+
+ print((our_mlm_head_out - huggingface_predictions_out).abs().max())
+ assert_close(our_mlm_head_out, huggingface_predictions_out, rtol=1.3e-3, atol=1e-5)
+
+
+def test_run_with_cache(our_bert, hello_world_tokens):
+ logits, cache = our_bert.run_with_cache(hello_world_tokens)
+
+ # check that an arbitrary subset of the keys exist
+ assert "embed.hook_embed" in cache
+ assert "blocks.0.attn.hook_q" in cache
+ assert "blocks.3.attn.hook_attn_scores" in cache
+ assert "blocks.7.hook_resid_post" in cache
+ assert "mlm_head.ln.hook_normalized" in cache
+
+
+def test_predictions(our_bert, huggingface_bert, tokenizer):
+ input_ids = tokenizer("The [MASK] sat on the mat", return_tensors="pt")["input_ids"]
+
+ def get_predictions(
+ logits: Float[torch.Tensor, "batch pos d_vocab"], positions: List[int]
+ ):
+ logits_at_position = logits.squeeze(0)[positions]
+ predicted_tokens = F.softmax(logits_at_position, dim=-1).argmax(dim=-1)
+ return tokenizer.batch_decode(predicted_tokens)
+
+ our_bert_out = our_bert(input_ids)
+ our_prediction = get_predictions(our_bert_out, [2])
+
+ huggingface_bert_out = huggingface_bert(input_ids).logits
+ huggingface_prediction = get_predictions(huggingface_bert_out, [2])
+
+ assert our_prediction == huggingface_prediction
diff --git a/tests/acceptance/test_transformer_lens.py b/tests/acceptance/test_hooked_transformer.py
similarity index 100%
rename from tests/acceptance/test_transformer_lens.py
rename to tests/acceptance/test_hooked_transformer.py
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
new file mode 100644
index 000000000..7e3de330e
--- /dev/null
+++ b/tests/unit/test_config.py
@@ -0,0 +1,25 @@
+"""
+Tests that verify that an arbitrary component (e.g. Embed) can be initialized using both the dict and object versions of HookedTransformerConfig.
+"""
+
+from transformer_lens.components import Embed
+from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
+
+
+def test_hooked_transformer_config_object():
+ hooked_transformer_config = HookedTransformerConfig(
+ n_layers=2, d_vocab=100, d_model=6, n_ctx=5, d_head=2, attn_only=True
+ )
+ Embed(hooked_transformer_config)
+
+
+def test_hooked_transformer_config_dict():
+ hooked_transformer_config_dict = {
+ "n_layers": 2,
+ "d_vocab": 100,
+ "d_model": 6,
+ "n_ctx": 5,
+ "d_head": 2,
+ "attn_only": True,
+ }
+ Embed(hooked_transformer_config_dict)
diff --git a/tests/unit/test_create_hooked_encoder.py b/tests/unit/test_create_hooked_encoder.py
new file mode 100644
index 000000000..a1adc7ef3
--- /dev/null
+++ b/tests/unit/test_create_hooked_encoder.py
@@ -0,0 +1,35 @@
+import pytest
+from transformers import AutoTokenizer, BertTokenizerFast
+
+from transformer_lens import HookedEncoder, HookedTransformerConfig
+
+
+@pytest.fixture
+def cfg():
+ return HookedTransformerConfig(
+ d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu"
+ )
+
+
+def test_pass_tokenizer(cfg):
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+ model = HookedEncoder(cfg, tokenizer=tokenizer)
+ assert model.tokenizer == tokenizer
+
+
+def test_load_tokenizer_from_config(cfg):
+ cfg.tokenizer_name = "bert-base-cased"
+ model = HookedEncoder(cfg)
+ assert isinstance(model.tokenizer, BertTokenizerFast)
+
+
+def test_load_without_tokenizer(cfg):
+ cfg.d_vocab = 22
+ model = HookedEncoder(cfg)
+ assert model.tokenizer is None
+
+
+def test_cannot_load_without_tokenizer_or_d_vocab(cfg):
+ with pytest.raises(AssertionError) as e:
+ HookedEncoder(cfg)
+ assert "Must provide a tokenizer if d_vocab is not provided" in str(e.value)
diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py
new file mode 100644
index 000000000..3ee5e2ad7
--- /dev/null
+++ b/transformer_lens/HookedEncoder.py
@@ -0,0 +1,397 @@
+from __future__ import annotations
+
+import logging
+from functools import lru_cache
+from typing import Dict, Optional, Tuple, Union, cast, overload
+
+import torch
+from einops import repeat
+from jaxtyping import Float, Int
+from torch import nn
+from transformers import AutoTokenizer
+from typeguard import typeguard_ignore
+from typing_extensions import Literal
+
+import transformer_lens.loading_from_pretrained as loading
+from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformerConfig
+from transformer_lens.components import BertBlock, BertEmbed, BertMLMHead, Unembed
+from transformer_lens.hook_points import HookedRootModule, HookPoint
+from transformer_lens.utilities import devices
+
+
+class HookedEncoder(HookedRootModule):
+ """
+ This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule.
+
+ Limitations:
+ - The current MVP implementation supports only the masked language modelling (MLM) task. Next sentence prediction (NSP), causal language modelling, and other tasks are not yet supported.
+ - Also note that the model does not include dropout, which may lead to inconsistent results when training or fine-tuning.
+
+ Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported:
+ - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
+ - The model only accepts tokens as inputs, not strings or lists of strings
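+
+ Minimal usage sketch (illustrative, not a tested doctest; assumes a matching HuggingFace tokenizer):
+
+ >>> encoder = HookedEncoder.from_pretrained("bert-base-cased")
+ >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+ >>> tokens = tokenizer("Paris is the [MASK] of France", return_tensors="pt")["input_ids"]
+ >>> logits = encoder(tokens)  # [batch, pos, d_vocab]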
+ """
+
+ def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig(**cfg)
+ elif isinstance(cfg, str):
+ raise ValueError(
+ "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead."
+ )
+ self.cfg = cfg
+
+ assert (
+ self.cfg.n_devices == 1
+ ), "Multiple devices not supported for HookedEncoder"
+
+ if tokenizer is not None:
+ self.tokenizer = tokenizer
+ elif self.cfg.tokenizer_name is not None:
+ self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name)
+ else:
+ self.tokenizer = None
+
+ if self.cfg.d_vocab == -1:
+ # If we have a tokenizer, vocab size can be inferred from it.
+ assert (
+ self.tokenizer is not None
+ ), "Must provide a tokenizer if d_vocab is not provided"
+ self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1
+ if self.cfg.d_vocab_out == -1:
+ self.cfg.d_vocab_out = self.cfg.d_vocab
+
+ self.embed = BertEmbed(self.cfg)
+ self.blocks = nn.ModuleList(
+ [BertBlock(self.cfg) for _ in range(self.cfg.n_layers)]
+ )
+ self.mlm_head = BertMLMHead(self.cfg)
+ self.unembed = Unembed(self.cfg)
+
+ self.hook_full_embed = HookPoint()
+
+ if move_to_device:
+ self.to(self.cfg.device)
+
+ self.setup()
+
+ @overload
+ def forward(
+ self,
+ input: Int[torch.Tensor, "batch pos"],
+ return_type: Literal["logits"],
+ token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
+ one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
+ ) -> Float[torch.Tensor, "batch pos d_vocab"]:
+ ...
+
+ @overload
+ def forward(
+ self,
+ input: Int[torch.Tensor, "batch pos"],
+ return_type: Literal[None],
+ token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
+ one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
+ ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
+ ...
+
+ def forward(
+ self,
+ input: Int[torch.Tensor, "batch pos"],
+ return_type: Optional[str] = "logits",
+ token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
+ one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None,
+ ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]:
+ """Input must be a batch of tokens. Strings and lists of strings are not yet supported.
+
+ return_type Optional[str]: The type of output to return. Can be one of: None (return nothing, don't calculate logits), or 'logits' (return logits).
+
+ token_type_ids Optional[torch.Tensor]: Binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length).
+
+ one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to.
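+
+ Illustrative example (a sketch, not a tested doctest; assumes `model` and `tokenizer` have been loaded for the same checkpoint, e.g. via `HookedEncoder.from_pretrained("bert-base-cased")` and `AutoTokenizer.from_pretrained("bert-base-cased")`):
+
+ >>> encoding = tokenizer("First sentence.", "Second sentence.", return_tensors="pt")
+ >>> logits = model(
+ ... encoding["input_ids"],
+ ... token_type_ids=encoding["token_type_ids"],
+ ... one_zero_attention_mask=encoding["attention_mask"],
+ ... )  # [batch, pos, d_vocab]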
+ """
+
+ tokens = input
+
+ if tokens.device.type != self.cfg.device:
+ tokens = tokens.to(self.cfg.device)
+ if one_zero_attention_mask is not None:
+ one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device)
+
+ resid = self.hook_full_embed(self.embed(tokens, token_type_ids))
+
+ large_negative_number = -1e5
+ additive_attention_mask = (
+ large_negative_number
+ * repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos")
+ if one_zero_attention_mask is not None
+ else None
+ )
+
+ for block in self.blocks:
+ resid = block(resid, additive_attention_mask)
+ resid = self.mlm_head(resid)
+
+ if return_type is None:
+ return
+
+ logits = self.unembed(resid)
+ return logits
+
+ @overload
+ def run_with_cache(
+ self, *model_args, return_cache_object: Literal[True] = True, **kwargs
+ ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]:
+ ...
+
+ @overload
+ def run_with_cache(
+ self, *model_args, return_cache_object: Literal[False] = False, **kwargs
+ ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]:
+ ...
+
+ def run_with_cache(
+ self,
+ *model_args,
+ return_cache_object: bool = True,
+ remove_batch_dim: bool = False,
+ **kwargs,
+ ) -> Tuple[
+ Float[torch.Tensor, "batch pos d_vocab"],
+ Union[ActivationCache, Dict[str, torch.Tensor]],
+ ]:
+ """
+ Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer.
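+
+ Example (sketch): `logits, cache = encoder.run_with_cache(tokens)`, after which individual activations can be read out, e.g. `cache["blocks.0.attn.hook_q"]` or `cache["mlm_head.ln.hook_normalized"]`.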
+ """
+ out, cache_dict = super().run_with_cache(
+ *model_args, remove_batch_dim=remove_batch_dim, **kwargs
+ )
+ if return_cache_object:
+ cache = ActivationCache(
+ cache_dict, self, has_batch_dim=not remove_batch_dim
+ )
+ return out, cache
+ else:
+ return out, cache_dict
+
+ def to(
+ self,
+ device_or_dtype: Union[torch.device, str, torch.dtype],
+ print_details: bool = True,
+ ):
+ return devices.move_to_and_update_config(self, device_or_dtype, print_details)
+
+ def cuda(self):
+ # Wrapper around cuda that also changes self.cfg.device
+ return self.to("cuda")
+
+ def cpu(self):
+ # Wrapper around cpu that also changes self.cfg.device
+ return self.to("cpu")
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ model_name: str,
+ checkpoint_index: Optional[int] = None,
+ checkpoint_value: Optional[int] = None,
+ hf_model=None,
+ device: Optional[str] = None,
+ **model_kwargs,
+ ) -> HookedEncoder:
+ """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model."""
+ logging.warning(
+ "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature "
+ "parity with HookedTransformer and has been tested on real research tasks. Until then, backward "
+ "compatibility is not guaranteed. Please see the docs for information on the limitations of the current "
+ "implementation."
+ "\n"
+ "If using BERT for interpretability research, keep in mind that BERT has some significant architectural "
+ "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning "
+ "that the last LayerNorm in a block cannot be folded."
+ )
+
+ official_model_name = loading.get_official_model_name(model_name)
+
+ cfg = loading.get_pretrained_model_config(
+ official_model_name,
+ checkpoint_index=checkpoint_index,
+ checkpoint_value=checkpoint_value,
+ fold_ln=False,
+ device=device,
+ n_devices=1,
+ )
+
+ state_dict = loading.get_pretrained_state_dict(
+ official_model_name, cfg, hf_model
+ )
+
+ model = cls(cfg, **model_kwargs)
+
+ model.load_state_dict(state_dict, strict=False)
+
+ print(f"Loaded pretrained model {model_name} into HookedTransformer")
+
+ return model
+
+ @property
+ @typeguard_ignore
+ def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]:
+ """
+ Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits)
+ """
+ return self.unembed.W_U
+
+ @property
+ @typeguard_ignore
+ def b_U(self) -> Float[torch.Tensor, "d_vocab"]:
+ return self.unembed.b_U
+
+ @property
+ @typeguard_ignore
+ def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]:
+ """
+ Convenience to get the embedding matrix
+ """
+ return self.embed.embed.W_E
+
+ @property
+ @typeguard_ignore
+ def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]:
+ """
+ Convenience function to get the positional embedding. Only works on models with absolute positional embeddings!
+ """
+ return self.embed.pos_embed.W_pos
+
+ @property
+ @typeguard_ignore
+ def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]:
+ """
+ Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits.
+ """
+ return torch.cat([self.W_E, self.W_pos], dim=0)
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
+ """Stacks the key weights across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
+ """Stacks the query weights across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]:
+ """Stacks the value weights across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]:
+ """Stacks the attn output weights across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]:
+ """Stacks the MLP input weights across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]:
+ """Stacks the MLP output weights across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
+ """Stacks the key biases across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
+ """Stacks the query biases across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]:
+ """Stacks the value biases across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]:
+ """Stacks the attn output biases across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]:
+ """Stacks the MLP input biases across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ @lru_cache(maxsize=None)
+ def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]:
+ """Stacks the MLP output biases across all layers"""
+ return torch.stack(
+ [cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0
+ )
+
+ @property
+ @typeguard_ignore
+ def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model]
+ return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
+
+ @property
+ @typeguard_ignore
+ def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model]
+ return FactoredMatrix(self.W_V, self.W_O)
+
+ def all_head_labels(self) -> list[str]:
+ return [
+ f"L{l}H{h}"
+ for l in range(self.cfg.n_layers)
+ for h in range(self.cfg.n_heads)
+ ]
diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 6d2489044..6533c85a7 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -721,26 +721,12 @@ def tokens_to_residual_directions(
residual_direction = self.W_U[:, token]
return residual_direction
- def to(self, device_or_dtype, print_details=True):
- """
- Wrapper around to that also changes self.cfg.device if it's a torch.device or string.
- If torch.dtype, just passes through
- """
- if isinstance(device_or_dtype, torch.device):
- self.cfg.device = device_or_dtype.type
- if print_details:
- print("Moving model to device: ", self.cfg.device)
- elif isinstance(device_or_dtype, str):
- self.cfg.device = device_or_dtype
- if print_details:
- print("Moving model to device: ", self.cfg.device)
- elif isinstance(device_or_dtype, torch.dtype):
- if print_details:
- print("Changing model dtype to", device_or_dtype)
- # change state_dict dtypes
- for k, v in self.state_dict().items():
- self.state_dict()[k] = v.to(device_or_dtype)
- return nn.Module.to(self, device_or_dtype)
+ def to(
+ self,
+ device_or_dtype: Union[torch.device, str, torch.dtype],
+ print_details: bool = True,
+ ):
+ return devices.move_to_and_update_config(self, device_or_dtype, print_details)
def cuda(self):
# Wrapper around cuda that also changes self.cfg.device
@@ -1033,33 +1019,7 @@ def load_and_process_state_dict(
self.load_state_dict(state_dict)
def fill_missing_keys(self, state_dict):
- """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
-
- This function is assumed to be run before weights are initialized.
-
- Args:
- state_dict (dict): State dict from a pretrained model
-
- Returns:
- dict: State dict with missing keys filled in
- """
- # Get the default state dict
- default_state_dict = self.state_dict()
- # Get the keys that are missing from the pretrained model
- missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
- # Fill in the missing keys with the default initialization
- for key in missing_keys:
- if "hf_model" in key:
- # Skip keys that are from the HuggingFace model, if loading from HF.
- continue
- if "W_" in key:
- logging.warning(
- "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
- key
- )
- )
- state_dict[key] = default_state_dict[key]
- return state_dict
+ return loading.fill_missing_keys(self, state_dict)
def fold_layer_norm(self, state_dict: Dict[str, torch.Tensor]):
"""Takes in a state dict from a pretrained model, formatted to be consistent with HookedTransformer but with
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index b33ffcefa..becf62821 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import logging
import pprint
import random
@@ -226,7 +228,7 @@ def __post_init__(self):
), f"Not enough CUDA devices to support n_devices {self.n_devices}"
@classmethod
- def from_dict(cls, config_dict: Dict[str, Any]):
+ def from_dict(cls, config_dict: Dict[str, Any]) -> HookedTransformerConfig:
"""
Instantiates a `HookedTransformerConfig` from a Python dictionary of
parameters.
diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py
index 2994e9768..fce9f93fa 100644
--- a/transformer_lens/__init__.py
+++ b/transformer_lens/__init__.py
@@ -10,6 +10,7 @@
from .FactoredMatrix import FactoredMatrix
from .ActivationCache import ActivationCache
from .HookedTransformer import HookedTransformer
+from .HookedEncoder import HookedEncoder
from . import head_detector
from . import loading_from_pretrained as loading
from . import patching
diff --git a/transformer_lens/components.py b/transformer_lens/components.py
index 16fa7053b..09b1b069b 100644
--- a/transformer_lens/components.py
+++ b/transformer_lens/components.py
@@ -90,6 +90,97 @@ def forward(
return broadcast_pos_embed.clone()
+class TokenTypeEmbed(nn.Module):
+ """
+ The token-type embedding is a binary id indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length).
+
+ See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf
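+
+ Illustrative example: for a sentence pair tokenized as "[CLS] A [SEP] B [SEP]" (5 tokens), `token_type_ids` would be `torch.tensor([[0, 0, 0, 1, 1]])`, and the forward pass returns the corresponding rows of `W_token_type`, giving an output of shape [batch, pos, d_model].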
+ """
+
+ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+ self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model))
+
+ def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]):
+ return self.W_token_type[token_type_ids, :]
+
+
+class BertEmbed(nn.Module):
+ """
+ Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result.
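+
+ Schematically (matching the forward pass below): `ln(embed(input_ids) + pos_embed(positions) + token_type_embed(token_type_ids))`.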
+ """
+
+ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+ self.embed = Embed(cfg)
+ self.pos_embed = PosEmbed(cfg)
+ self.token_type_embed = TokenTypeEmbed(cfg)
+ self.ln = LayerNorm(cfg)
+
+ self.hook_embed = HookPoint()
+ self.hook_pos_embed = HookPoint()
+ self.hook_token_type_embed = HookPoint()
+
+ def forward(
+ self,
+ input_ids: Int[torch.Tensor, "batch pos"],
+ token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None,
+ ):
+ base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device)
+ index_ids = einops.repeat(
+ base_index_id, "pos -> batch pos", batch=input_ids.shape[0]
+ )
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ word_embeddings_out = self.hook_embed(self.embed(input_ids))
+ position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids))
+ token_type_embeddings_out = self.hook_token_type_embed(
+ self.token_type_embed(token_type_ids)
+ )
+
+ embeddings_out = (
+ word_embeddings_out + position_embeddings_out + token_type_embeddings_out
+ )
+ layer_norm_out = self.ln(embeddings_out)
+ return layer_norm_out
+
+
+class BertMLMHead(nn.Module):
+ """
+ Transforms the final BERT residual stream into a form ready for unembedding (a dense layer, GELU activation, then LayerNorm). Its purpose is to support predicting masked tokens in a sentence.
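+
+ Schematically (matching the forward pass below): `ln(gelu(resid @ W.T + b))`; the projection to vocabulary logits is done separately by the Unembed module.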
+ """
+
+ def __init__(self, cfg: Union[Dict, HookedTransformerConfig]):
+ super().__init__()
+ if isinstance(cfg, Dict):
+ cfg = HookedTransformerConfig.from_dict(cfg)
+ self.cfg = cfg
+ self.W = nn.Parameter(torch.empty(cfg.d_model, cfg.d_model))
+ self.b = nn.Parameter(torch.zeros(cfg.d_model))
+ self.act_fn = nn.GELU()
+ self.ln = LayerNorm(cfg)
+
+ def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor:
+ resid = (
+ einsum(
+ "batch pos d_model_in, d_model_out d_model_in -> batch pos d_model_out",
+ resid,
+ self.W,
+ )
+ + self.b
+ )
+ resid = self.act_fn(resid)
+ resid = self.ln(resid)
+ return resid
+
+
# LayerNormPre
# I fold the LayerNorm weights and biases into later weights and biases.
# This is just the 'center and normalise' part of LayerNorm
@@ -368,10 +459,12 @@ def forward(
Float[torch.Tensor, "batch pos head_index d_model"],
],
past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
+ additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
) -> Float[torch.Tensor, "batch pos d_model"]:
"""
shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details
past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None
+ additive_attention_mask is an optional mask that is added to the attention scores before the softmax. Defaults to None.
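+ (As used by HookedEncoder, this mask contains 0 at positions that should be attended to and a large negative number, -1e5, at padding positions, so adding it to the attention scores before the softmax effectively removes those positions from the attention pattern.)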
"""
if self.cfg.use_split_qkv_input:
@@ -421,8 +514,8 @@ def forward(
attn_scores = (
einsum(
"batch query_pos head_index d_head, \
- batch key_pos head_index d_head \
- -> batch head_index query_pos key_pos",
+ batch key_pos head_index d_head \
+ -> batch head_index query_pos key_pos",
q,
k,
)
@@ -433,6 +526,9 @@ def forward(
attn_scores = self.apply_causal_mask(
attn_scores, kv_cache_pos_offset
) # [batch, head_index, query_pos, key_pos]
+ if additive_attention_mask is not None:
+ attn_scores += additive_attention_mask
+
attn_scores = self.hook_attn_scores(attn_scores)
pattern = self.hook_pattern(
F.softmax(attn_scores, dim=-1)
@@ -451,8 +547,8 @@ def forward(
(
einsum(
"batch pos head_index d_head, \
- head_index d_head d_model -> \
- batch pos d_model",
+ head_index d_head d_model -> \
+ batch pos d_model",
z,
self.W_O,
)
@@ -853,3 +949,70 @@ def add_head_dimension(tensor):
resid_pre + attn_out
) # [batch, pos, d_model]
return resid_post
+
+
+class BertBlock(nn.Module):
+ """
+ BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before.
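+
+ Schematically (matching the forward pass below): `mid = ln1(x + attn(x)); out = ln2(mid + mlp(mid))`.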
+ """
+
+ def __init__(self, cfg: HookedTransformerConfig):
+ super().__init__()
+ self.cfg = cfg
+
+ self.attn = Attention(cfg)
+ self.ln1 = LayerNorm(cfg)
+ self.mlp = MLP(cfg)
+ self.ln2 = LayerNorm(cfg)
+
+ self.hook_q_input = HookPoint() # [batch, pos, d_model]
+ self.hook_k_input = HookPoint() # [batch, pos, d_model]
+ self.hook_v_input = HookPoint() # [batch, pos, d_model]
+
+ self.hook_attn_out = HookPoint() # [batch, pos, d_model]
+ self.hook_mlp_out = HookPoint() # [batch, pos, d_model]
+ self.hook_resid_pre = HookPoint() # [batch, pos, d_model]
+ self.hook_resid_mid = HookPoint() # [batch, pos, d_model]
+ self.hook_resid_post = HookPoint() # [batch, pos, d_model]
+ self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model]
+
+ def forward(
+ self,
+ resid_pre: Float[torch.Tensor, "batch pos d_model"],
+ additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None,
+ ):
+ resid_pre = self.hook_resid_pre(resid_pre)
+
+ query_input = resid_pre
+ key_input = resid_pre
+ value_input = resid_pre
+
+ if self.cfg.use_split_qkv_input:
+
+ def add_head_dimension(tensor):
+ return einops.repeat(
+ tensor,
+ "batch pos d_model -> batch pos n_heads d_model",
+ n_heads=self.cfg.n_heads,
+ ).clone()
+
+ query_input = self.hook_q_input(add_head_dimension(query_input))
+ key_input = self.hook_k_input(add_head_dimension(key_input))
+ value_input = self.hook_v_input(add_head_dimension(value_input))
+
+ attn_out = self.hook_attn_out(
+ self.attn(
+ query_input,
+ key_input,
+ value_input,
+ additive_attention_mask=additive_attention_mask,
+ )
+ )
+ resid_mid = self.hook_resid_mid(resid_pre + attn_out)
+ normalized_resid_mid = self.ln1(resid_mid)
+
+ mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))
+ resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out)
+ normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post))
+
+ return normalized_resid_post
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 8c51bdea1..b638cc97c 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -6,7 +6,7 @@
import einops
import torch
from huggingface_hub import HfApi
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, BertForPreTraining
import transformer_lens.utils as utils
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
@@ -104,6 +104,7 @@
"llama-30b-hf",
"llama-65b-hf",
"Baidicoot/Othello-GPT-Transformer-Lens",
+ "bert-base-cased",
]
# Model Aliases:
@@ -442,7 +443,7 @@ def get_official_model_name(model_name: str):
return official_model_name
-def convert_hf_model_config(official_model_name: str):
+def convert_hf_model_config(model_name: str):
"""
Returns the model config for a HuggingFace model, converted to a dictionary
in the HookedTransformerConfig format.
@@ -450,7 +451,7 @@ def convert_hf_model_config(official_model_name: str):
Takes the official_model_name as an input.
"""
# In case the user passed in an alias
- official_model_name = get_official_model_name(official_model_name)
+ official_model_name = get_official_model_name(model_name)
# Load HuggingFace model config
if "llama" not in official_model_name:
hf_config = AutoConfig.from_pretrained(official_model_name)
@@ -614,6 +615,19 @@ def convert_hf_model_config(official_model_name: str):
}
rotary_pct = hf_config.rotary_pct
cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"])
+ elif architecture == "BertForMaskedLM":
+ cfg_dict = {
+ "d_model": hf_config.hidden_size,
+ "d_head": hf_config.hidden_size // hf_config.num_attention_heads,
+ "n_heads": hf_config.num_attention_heads,
+ "d_mlp": hf_config.intermediate_size,
+ "n_layers": hf_config.num_hidden_layers,
+ "n_ctx": hf_config.max_position_embeddings,
+ "eps": hf_config.layer_norm_eps,
+ "d_vocab": hf_config.vocab_size,
+ "act_fn": "gelu",
+ "attention_dir": "bidirectional",
+ }
else:
raise NotImplementedError(f"{architecture} is not currently supported.")
# All of these models use LayerNorm
@@ -726,7 +740,6 @@ def get_pretrained_model_config(
cfg_dict["normalization_type"] = "LNPre"
else:
logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
- pass
if checkpoint_index is not None or checkpoint_value is not None:
checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
@@ -863,6 +876,8 @@ def get_pretrained_state_dict(
elif hf_model is None:
if "llama" in official_model_name:
raise NotImplementedError("Must pass in hf_model for LLaMA models")
+ elif "bert" in official_model_name:
+ hf_model = BertForPreTraining.from_pretrained(official_model_name)
else:
hf_model = AutoModelForCausalLM.from_pretrained(official_model_name)
@@ -879,6 +894,8 @@ def get_pretrained_state_dict(
state_dict = convert_neox_weights(hf_model, cfg)
elif cfg.original_architecture == "LLaMAForCausalLM":
state_dict = convert_llama_weights(hf_model, cfg)
+ elif cfg.original_architecture == "BertForMaskedLM":
+ state_dict = convert_bert_weights(hf_model, cfg)
else:
raise ValueError(
f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature."
@@ -887,6 +904,36 @@ def get_pretrained_state_dict(
return state_dict
+def fill_missing_keys(model, state_dict):
+ """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.
+
+ This function is assumed to be run before weights are initialized.
+
+ Args:
+ state_dict (dict): State dict from a pretrained model
+
+ Returns:
+ dict: State dict with missing keys filled in
+ """
+ # Get the default state dict
+ default_state_dict = model.state_dict()
+ # Get the keys that are missing from the pretrained model
+ missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
+ # Fill in the missing keys with the default initialization
+ for key in missing_keys:
+ if "hf_model" in key:
+ # Skip keys that are from the HuggingFace model, if loading from HF.
+ continue
+ if "W_" in key:
+ logging.warning(
+ "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
+ key
+ )
+ )
+ state_dict[key] = default_state_dict[key]
+ return state_dict
+
+
# %%
def convert_state_dict(
state_dict: dict,
@@ -964,7 +1011,7 @@ def convert_gpt2_weights(gpt2, cfg: HookedTransformerConfig):
W_out = gpt2.transformer.h[l].mlp.c_proj.weight
state_dict[f"blocks.{l}.mlp.W_out"] = W_out
state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias
- state_dict[f"unembed.W_U"] = gpt2.lm_head.weight.T
+ state_dict["unembed.W_U"] = gpt2.lm_head.weight.T
state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight
state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias
@@ -1270,8 +1317,8 @@ def convert_opt_weights(opt, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.mlp.b_in"] = opt.model.decoder.layers[l].fc1.bias
state_dict[f"blocks.{l}.mlp.b_out"] = opt.model.decoder.layers[l].fc2.bias
- state_dict[f"ln_final.w"] = opt.model.decoder.final_layer_norm.weight
- state_dict[f"ln_final.b"] = opt.model.decoder.final_layer_norm.bias
+ state_dict["ln_final.w"] = opt.model.decoder.final_layer_norm.weight
+ state_dict["ln_final.b"] = opt.model.decoder.final_layer_norm.bias
state_dict["unembed.W_U"] = opt.lm_head.weight.T
state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab)
return state_dict
@@ -1368,7 +1415,7 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.mlp.W_out"] = W_out.T
state_dict[f"blocks.{l}.mlp.b_out"] = old_state_dict[f"blocks.{l}.mlp.2.bias"]
- state_dict[f"unembed.W_U"] = old_state_dict["head.weight"].T
+ state_dict["unembed.W_U"] = old_state_dict["head.weight"].T
state_dict["ln_final.w"] = old_state_dict["ln_f.weight"]
state_dict["ln_final.b"] = old_state_dict["ln_f.bias"]
@@ -1376,4 +1423,63 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig):
return state_dict
-# %%
+def convert_bert_weights(bert, cfg: HookedTransformerConfig):
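+ """Converts the weights of a HuggingFace BERT model (loaded e.g. as BertForPreTraining) into the HookedEncoder state dict format. The fused (n_heads * d_head) dimension of each attention weight is split into separate head and head-dimension axes via einops.rearrange."""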
+ embeddings = bert.bert.embeddings
+ state_dict = {
+ "embed.embed.W_E": embeddings.word_embeddings.weight,
+ "embed.pos_embed.W_pos": embeddings.position_embeddings.weight,
+ "embed.token_type_embed.W_token_type": embeddings.token_type_embeddings.weight,
+ "embed.ln.w": embeddings.LayerNorm.weight,
+ "embed.ln.b": embeddings.LayerNorm.bias,
+ }
+
+ for l in range(cfg.n_layers):
+ block = bert.bert.encoder.layer[l]
+ state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange(
+ block.attention.self.query.weight, "(i h) m -> i m h", i=cfg.n_heads
+ )
+ state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange(
+ block.attention.self.query.bias, "(i h) -> i h", i=cfg.n_heads
+ )
+ state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange(
+ block.attention.self.key.weight, "(i h) m -> i m h", i=cfg.n_heads
+ )
+ state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange(
+ block.attention.self.key.bias, "(i h) -> i h", i=cfg.n_heads
+ )
+ state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange(
+ block.attention.self.value.weight, "(i h) m -> i m h", i=cfg.n_heads
+ )
+ state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange(
+ block.attention.self.value.bias, "(i h) -> i h", i=cfg.n_heads
+ )
+ state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange(
+ block.attention.output.dense.weight,
+ "m (i h) -> i h m",
+ i=cfg.n_heads,
+ )
+ state_dict[f"blocks.{l}.attn.b_O"] = block.attention.output.dense.bias
+ state_dict[f"blocks.{l}.ln1.w"] = block.attention.output.LayerNorm.weight
+ state_dict[f"blocks.{l}.ln1.b"] = block.attention.output.LayerNorm.bias
+ state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange(
+ block.intermediate.dense.weight, "mlp model -> model mlp"
+ )
+ state_dict[f"blocks.{l}.mlp.b_in"] = block.intermediate.dense.bias
+ state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange(
+ block.output.dense.weight, "model mlp -> mlp model"
+ )
+ state_dict[f"blocks.{l}.mlp.b_out"] = block.output.dense.bias
+ state_dict[f"blocks.{l}.ln2.w"] = block.output.LayerNorm.weight
+ state_dict[f"blocks.{l}.ln2.b"] = block.output.LayerNorm.bias
+
+ mlm_head = bert.cls.predictions
+ state_dict["mlm_head.W"] = mlm_head.transform.dense.weight
+ state_dict["mlm_head.b"] = mlm_head.transform.dense.bias
+ state_dict["mlm_head.ln.w"] = mlm_head.transform.LayerNorm.weight
+ state_dict["mlm_head.ln.b"] = mlm_head.transform.LayerNorm.bias
+ # Note: BERT uses tied embeddings
+ state_dict["unembed.W_U"] = embeddings.word_embeddings.weight.T
+ # "unembed.W_U": mlm_head.decoder.weight.T,
+ state_dict["unembed.b_U"] = mlm_head.bias
+
+ return state_dict
diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py
index bf260ba5f..881a63ecb 100644
--- a/transformer_lens/utilities/devices.py
+++ b/transformer_lens/utilities/devices.py
@@ -1,6 +1,13 @@
-from typing import Optional, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional, Union
import torch
+from torch import nn
+
+if TYPE_CHECKING:
+ from transformer_lens.HookedEncoder import HookedEncoder
+ from transformer_lens.HookedTransformer import HookedTransformer
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig
@@ -32,3 +39,29 @@ def get_device_for_block_index(
device = torch.device(device)
device_index = (device.index or 0) + (index // layers_per_device)
return torch.device(device.type, device_index)
+
+
+def move_to_and_update_config(
+ model: Union[HookedTransformer, HookedEncoder],
+ device_or_dtype: Union[torch.device, str, torch.dtype],
+ print_details=True,
+):
+ """
+ Wrapper around `nn.Module.to` that also updates `model.cfg.device` if the argument is a torch.device or string.
+ If the argument is a torch.dtype, the config is left unchanged and the call is simply passed through.
+ """
+ if isinstance(device_or_dtype, torch.device):
+ model.cfg.device = device_or_dtype.type
+ if print_details:
+ print("Moving model to device: ", model.cfg.device)
+ elif isinstance(device_or_dtype, str):
+ model.cfg.device = device_or_dtype
+ if print_details:
+ print("Moving model to device: ", model.cfg.device)
+ elif isinstance(device_or_dtype, torch.dtype):
+ if print_details:
+ print("Changing model dtype to", device_or_dtype)
+ # change state_dict dtypes
+ for k, v in model.state_dict().items():
+ model.state_dict()[k] = v.to(device_or_dtype)
+ return nn.Module.to(model, device_or_dtype)