From c268a7159a6f8d5c78236a3f958f2d704fbc940f Mon Sep 17 00:00:00 2001
From: Rusheb Shah
Date: Fri, 19 May 2023 10:31:51 +0100
Subject: [PATCH] Introducing HookedEncoder (#276)

Introducing HookedEncoder, a BERT-style encoder inheriting from HookedRootModule. Supports ActivationCache, with hooks on all interesting activations. Weights can be loaded from the huggingface bert-base-cased pretrained model.

Unlike HookedTransformer, it does not (yet) do any pre-processing of the weights (e.g. folding LayerNorm). Another difference is that the model can currently only be run with tokens, and not strings or lists of strings.

Currently, the supported task/architecture is masked language modelling. Next sentence prediction, causal language modelling, and other tasks are not supported. The HookedEncoder does not contain dropout layers, which may lead to inconsistent results when training or fine-tuning.

Changes:
- Add new class HookedEncoder
- Add new components
  - TokenTypeEmbed
  - BertEmbed
  - BertMLMHead
  - BertBlock
- Add `additive_attention_mask` parameter to forward method of `Attention` component
- Add BERT config and state dict to loading_from_pretrained
- Extract methods from HookedTransformer for reuse:
  - devices.move_to_and_update_config
  - loading.fill_missing_keys
- Add demo notebook `demos/BERT.ipynb`
- Update Available Models list in Main Demo
- Testing
  - Unit and acceptance tests for HookedEncoder and sub-components
  - New demo in `demos/BERT.ipynb` also acts as a test
  - I also added some tests for existing components, e.g. HookedTransformerConfig

Future work: https://github.com/neelnanda-io/TransformerLens/issues/277
---
 .gitignore | 2 +-
 demos/BERT.ipynb | 258 ++++++++++++
 demos/Main_Demo.ipynb | 4 +
 pyproject.toml | 1 +
 tests/acceptance/test_hooked_encoder.py | 155 +++++++
 ...mer_lens.py => test_hooked_transformer.py} | 0
 tests/unit/test_config.py | 25 ++
 tests/unit/test_create_hooked_encoder.py | 35 ++
 transformer_lens/HookedEncoder.py | 397 ++++++++++++++++++
 transformer_lens/HookedTransformer.py | 54 +--
 transformer_lens/HookedTransformerConfig.py | 4 +-
 transformer_lens/__init__.py | 1 +
 transformer_lens/components.py | 171 +++++++-
 transformer_lens/loading_from_pretrained.py | 124 +++++-
 transformer_lens/utilities/devices.py | 35 +-
 15 files changed, 1203 insertions(+), 63 deletions(-)
 create mode 100644 demos/BERT.ipynb
 create mode 100644 tests/acceptance/test_hooked_encoder.py
 rename tests/acceptance/{test_transformer_lens.py => test_hooked_transformer.py} (100%)
 create mode 100644 tests/unit/test_config.py
 create mode 100644 tests/unit/test_create_hooked_encoder.py
 create mode 100644 transformer_lens/HookedEncoder.py

diff --git a/.gitignore b/.gitignore
index 2aea0a4cb..a5d6d42b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,6 @@ _modidx.py
 env
 dist/
 docs/build
-.coverage
+.coverage*
 .Ds_Store
 .pylintrc
diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb
new file mode 100644
index 000000000..6d227ebea
--- /dev/null
+++ b/demos/BERT.ipynb
@@ -0,0 +1,258 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# BERT in TransformerLens\n",
+    "This demo shows how to use BERT in TransformerLens for the Masked Language Modelling task."
+ ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup\n", + "(No need to read)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + } + ], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", + " %pip install circuitsvis\n", + " \n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using renderer: colab\n" + ] + } + ], + "source": [ + "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", + "import plotly.io as pio\n", + "if IN_COLAB or not DEVELOPMENT_MODE:\n", + " pio.renderers.default = \"colab\"\n", + "else:\n", + " pio.renderers.default = \"notebook_connected\"\n", + "print(f\"Using renderer: {pio.renderers.default}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import circuitsvis as cv\n", + "# Testing that the library works\n", + "cv.examples.hello(\"Neel\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Import stuff\n", + "import torch\n", + "\n", + "from transformers import AutoTokenizer\n", + "\n", + "from transformer_lens import HookedEncoder" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.set_grad_enabled(False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BERT\n", + "\n", + "In this section, we will load a pretrained BERT model and use it for the Masked Language Modelling task" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Moving model to device: cpu\n", + "Loaded pretrained model bert-base-cased into HookedTransformer\n" + ] + } + ], + "source": [ + "bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n", + "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the \"[MASK]\" token to mask any tokens which you would like the model to predict." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\"\n", + "\n", + "input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n", + "mask_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\n", + "Prediction: \"Systems\"\n" + ] + } + ], + "source": [ + "logprobs = bert(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n", + "prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"Prediction: \\\"{prediction}\\\"\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Better luck next time, BERT." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index aed5af6ba..6922d18ff 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -1091,6 +1091,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1110,6 +1111,9 @@ "* **GPT-NeoX** - Eleuther's 20B parameter model, trained on the Pile\n", "* **Stanford CRFM models** - a replication of GPT-2 Small and GPT-2 Medium, trained on 5 different random seeds.\n", " * Notably, 600 checkpoints were taken during training per model, and these are available in the library with eg `HookedTransformer.from_pretrained(\"stanford-gpt2-small-a\", checkpoint_index=265)`.\n", + "- **BERT** - Google's bidirectional encoder-only transformer.\n", + " - Size Base (108M), trained on English Wikipedia and BooksCorpus.\n", + " \n", "" ] }, diff --git a/pyproject.toml b/pyproject.toml index 3e5e9b9aa..5a56131dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ docs = ["Sphinx", "sphinx-autobuild", "sphinxcontrib-napoleon", "furo", "myst_pa [tool.pytest.ini_options] filterwarnings = [ + "ignore:pkg_resources is deprecated as an API:DeprecationWarning", # Ignore numpy.distutils deprecation warning caused by pandas # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils "ignore:distutils Version classes are deprecated:DeprecationWarning" diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py new file mode 100644 index 000000000..6f95d9b27 --- /dev/null +++ b/tests/acceptance/test_hooked_encoder.py @@ -0,0 +1,155 @@ +from typing import List + +import pytest +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch.testing import assert_close +from transformers import AutoTokenizer, BertForMaskedLM + +from transformer_lens import HookedEncoder + +MODEL_NAME = "bert-base-cased" + + +@pytest.fixture(scope="module") +def our_bert(): + return HookedEncoder.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def huggingface_bert(): + return BertForMaskedLM.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture +def hello_world_tokens(tokenizer): + return tokenizer("Hello, world!", return_tensors="pt")["input_ids"] + + +def test_full_model(our_bert, huggingface_bert, tokenizer): + sequences = [ + "Hello, world!", + "this is another sequence of tokens", + ] + tokenized = tokenizer(sequences, return_tensors="pt", padding=True) + input_ids = tokenized["input_ids"] + attention_mask = tokenized["attention_mask"] + + huggingface_bert_out = huggingface_bert( + input_ids, attention_mask=attention_mask + ).logits + our_bert_out = our_bert(input_ids, one_zero_attention_mask=attention_mask) + assert_close(huggingface_bert_out, our_bert_out, rtol=1.3e-6, atol=4e-5) + + +def test_embed_one_sentence(our_bert, huggingface_bert, hello_world_tokens): + huggingface_embed = huggingface_bert.bert.embeddings + our_embed = our_bert.embed + + huggingface_embed_out = 
huggingface_embed(hello_world_tokens)[0] + our_embed_out = our_embed(hello_world_tokens).squeeze(0) + assert_close(huggingface_embed_out, our_embed_out) + + +def test_embed_two_sentences(our_bert, huggingface_bert, tokenizer): + encoding = tokenizer("First sentence.", "Second sentence.", return_tensors="pt") + input_ids = encoding["input_ids"] + token_type_ids = encoding["token_type_ids"] + + huggingface_embed_out = huggingface_bert.bert.embeddings( + input_ids, token_type_ids=token_type_ids + )[0] + our_embed_out = our_bert.embed(input_ids, token_type_ids=token_type_ids).squeeze(0) + assert_close(huggingface_embed_out, our_embed_out) + + +def test_attention(our_bert, huggingface_bert, hello_world_tokens): + huggingface_embed = huggingface_bert.bert.embeddings + huggingface_attn = huggingface_bert.bert.encoder.layer[0].attention + + embed_out = huggingface_embed(hello_world_tokens) + + our_attn = our_bert.blocks[0].attn + + our_attn_out = our_attn(embed_out, embed_out, embed_out) + huggingface_self_attn_out = huggingface_attn.self(embed_out)[0] + huggingface_attn_out = huggingface_attn.output.dense(huggingface_self_attn_out) + assert_close(our_attn_out, huggingface_attn_out) + + +def test_bert_block(our_bert, huggingface_bert, hello_world_tokens): + huggingface_embed = huggingface_bert.bert.embeddings + huggingface_block = huggingface_bert.bert.encoder.layer[0] + + embed_out = huggingface_embed(hello_world_tokens) + + our_block = our_bert.blocks[0] + + our_block_out = our_block(embed_out) + huggingface_block_out = huggingface_block(embed_out)[0] + assert_close(our_block_out, huggingface_block_out) + + +def test_mlm_head(our_bert, huggingface_bert, hello_world_tokens): + huggingface_bert_core_outputs = huggingface_bert.bert( + hello_world_tokens + ).last_hidden_state + + our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs) + our_unembed_out = our_bert.unembed(our_mlm_head_out) + huggingface_predictions_out = huggingface_bert.cls.predictions( + huggingface_bert_core_outputs + ) + + assert_close(our_unembed_out, huggingface_predictions_out, rtol=1.3e-6, atol=4e-5) + + +def test_unembed(our_bert, huggingface_bert, hello_world_tokens): + huggingface_bert_core_outputs = huggingface_bert.bert( + hello_world_tokens + ).last_hidden_state + + our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs) + huggingface_predictions_out = huggingface_bert.cls.predictions.transform( + huggingface_bert_core_outputs + ) + + print((our_mlm_head_out - huggingface_predictions_out).abs().max()) + assert_close(our_mlm_head_out, huggingface_predictions_out, rtol=1.3e-3, atol=1e-5) + + +def test_run_with_cache(our_bert, huggingface_bert, hello_world_tokens): + model = HookedEncoder.from_pretrained("bert-base-cased") + logits, cache = model.run_with_cache(hello_world_tokens) + + # check that an arbitrary subset of the keys exist + assert "embed.hook_embed" in cache + assert "blocks.0.attn.hook_q" in cache + assert "blocks.3.attn.hook_attn_scores" in cache + assert "blocks.7.hook_resid_post" in cache + assert "mlm_head.ln.hook_normalized" in cache + + +def test_predictions(our_bert, huggingface_bert, tokenizer): + input_ids = tokenizer("The [MASK] sat on the mat", return_tensors="pt")["input_ids"] + + def get_predictions( + logits: Float[torch.Tensor, "batch pos d_vocab"], positions: List[int] + ): + logits_at_position = logits.squeeze(0)[positions] + predicted_tokens = F.softmax(logits_at_position, dim=-1).argmax(dim=-1) + return tokenizer.batch_decode(predicted_tokens) + + our_bert_out = 
our_bert(input_ids) + our_prediction = get_predictions(our_bert_out, [2]) + + huggingface_bert_out = huggingface_bert(input_ids).logits + huggingface_prediction = get_predictions(huggingface_bert_out, [2]) + + assert our_prediction == huggingface_prediction diff --git a/tests/acceptance/test_transformer_lens.py b/tests/acceptance/test_hooked_transformer.py similarity index 100% rename from tests/acceptance/test_transformer_lens.py rename to tests/acceptance/test_hooked_transformer.py diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 000000000..7e3de330e --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,25 @@ +""" +Tests that verify than an arbitrary component (e.g. Embed) can be initialized using dict and object versions of HookedTransformerConfig and HookedEncoderConfig. +""" + +from transformer_lens.components import Embed +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def test_hooked_transformer_config_object(): + hooked_transformer_config = HookedTransformerConfig( + n_layers=2, d_vocab=100, d_model=6, n_ctx=5, d_head=2, attn_only=True + ) + Embed(hooked_transformer_config) + + +def test_hooked_transformer_config_dict(): + hooked_transformer_config_dict = { + "n_layers": 2, + "d_vocab": 100, + "d_model": 6, + "n_ctx": 5, + "d_head": 2, + "attn_only": True, + } + Embed(hooked_transformer_config_dict) diff --git a/tests/unit/test_create_hooked_encoder.py b/tests/unit/test_create_hooked_encoder.py new file mode 100644 index 000000000..a1adc7ef3 --- /dev/null +++ b/tests/unit/test_create_hooked_encoder.py @@ -0,0 +1,35 @@ +import pytest +from transformers import AutoTokenizer, BertTokenizerFast + +from transformer_lens import HookedEncoder, HookedTransformerConfig + + +@pytest.fixture +def cfg(): + return HookedTransformerConfig( + d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu" + ) + + +def test_pass_tokenizer(cfg): + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + model = HookedEncoder(cfg, tokenizer=tokenizer) + assert model.tokenizer == tokenizer + + +def test_load_tokenizer_from_config(cfg): + cfg.tokenizer_name = "bert-base-cased" + model = HookedEncoder(cfg) + assert isinstance(model.tokenizer, BertTokenizerFast) + + +def test_load_without_tokenizer(cfg): + cfg.d_vocab = 22 + model = HookedEncoder(cfg) + assert model.tokenizer is None + + +def test_cannot_load_without_tokenizer_or_d_vocab(cfg): + with pytest.raises(AssertionError) as e: + HookedEncoder(cfg) + assert "Must provide a tokenizer if d_vocab is not provided" in str(e.value) diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py new file mode 100644 index 000000000..3ee5e2ad7 --- /dev/null +++ b/transformer_lens/HookedEncoder.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +import logging +from functools import lru_cache +from typing import Dict, Optional, Tuple, Union, cast, overload + +import torch +from einops import repeat +from jaxtyping import Float, Int +from torch import nn +from transformers import AutoTokenizer +from typeguard import typeguard_ignore +from typing_extensions import Literal + +import transformer_lens.loading_from_pretrained as loading +from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformerConfig +from transformer_lens.components import BertBlock, BertEmbed, BertMLMHead, Unembed +from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.utilities import devices + + +class 
HookedEncoder(HookedRootModule): + """ + This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. + + Limitations: + - The current MVP implementation supports only the masked language modelling (MLM) task. Next sentence prediction (NSP), causal language modelling, and other tasks are not yet supported. + - Also note that model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. + + Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: + - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model + - The model only accepts tokens as inputs, and not strings, or lists of strings + """ + + def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig(**cfg) + elif isinstance(cfg, str): + raise ValueError( + "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead." + ) + self.cfg = cfg + + assert ( + self.cfg.n_devices == 1 + ), "Multiple devices not supported for HookedEncoder" + if move_to_device: + self.to(self.cfg.device) + + if tokenizer is not None: + self.tokenizer = tokenizer + elif self.cfg.tokenizer_name is not None: + self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name) + else: + self.tokenizer = None + + if self.cfg.d_vocab == -1: + # If we have a tokenizer, vocab size can be inferred from it. + assert ( + self.tokenizer is not None + ), "Must provide a tokenizer if d_vocab is not provided" + self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 + if self.cfg.d_vocab_out == -1: + self.cfg.d_vocab_out = self.cfg.d_vocab + + self.embed = BertEmbed(self.cfg) + self.blocks = nn.ModuleList( + [BertBlock(self.cfg) for _ in range(self.cfg.n_layers)] + ) + self.mlm_head = BertMLMHead(cfg) + self.unembed = Unembed(self.cfg) + + self.hook_full_embed = HookPoint() + + self.setup() + + @overload + def forward( + self, + input: Int[torch.Tensor, "batch pos"], + return_type: Literal["logits"], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_vocab"]: + ... + + @overload + def forward( + self, + input: Int[torch.Tensor, "batch pos"], + return_type: Literal[None], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]: + ... + + def forward( + self, + input: Int[torch.Tensor, "batch pos"], + return_type: Optional[str] = "logits", + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]: + """Input must be a batch of tokens. Strings and lists of strings are not yet supported. + + return_type Optional[str]: The type of output to return. Can be one of: None (return nothing, don't calculate logits), or 'logits' (return logits). + + token_type_ids Optional[torch.Tensor]: Binary ids indicating whether a token belongs to sequence A or B. 
For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). + + one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to. + """ + + tokens = input + + if tokens.device.type != self.cfg.device: + tokens = tokens.to(self.cfg.device) + if one_zero_attention_mask is not None: + one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) + + resid = self.hook_full_embed(self.embed(tokens, token_type_ids)) + + large_negative_number = -1e5 + additive_attention_mask = ( + large_negative_number + * repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") + if one_zero_attention_mask is not None + else None + ) + + for block in self.blocks: + resid = block(resid, additive_attention_mask) + resid = self.mlm_head(resid) + + if return_type is None: + return + + logits = self.unembed(resid) + return logits + + @overload + def run_with_cache( + self, *model_args, return_cache_object: Literal[True] = True, **kwargs + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: + ... + + @overload + def run_with_cache( + self, *model_args, return_cache_object: Literal[False] = False, **kwargs + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: + ... + + def run_with_cache( + self, + *model_args, + return_cache_object: bool = True, + remove_batch_dim: bool = False, + **kwargs, + ) -> Tuple[ + Float[torch.Tensor, "batch pos d_vocab"], + Union[ActivationCache, Dict[str, torch.Tensor]], + ]: + """ + Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. + """ + out, cache_dict = super().run_with_cache( + *model_args, remove_batch_dim=remove_batch_dim, **kwargs + ) + if return_cache_object: + cache = ActivationCache( + cache_dict, self, has_batch_dim=not remove_batch_dim + ) + return out, cache + else: + return out, cache_dict + + def to( + self, + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details: bool = True, + ): + return devices.move_to_and_update_config(self, device_or_dtype, print_details) + + def cuda(self): + # Wrapper around cuda that also changes self.cfg.device + return self.to("cuda") + + def cpu(self): + # Wrapper around cuda that also changes self.cfg.device + return self.to("cpu") + + @classmethod + def from_pretrained( + cls, + model_name: str, + checkpoint_index: Optional[int] = None, + checkpoint_value: Optional[int] = None, + hf_model=None, + device: Optional[str] = None, + **model_kwargs, + ) -> HookedEncoder: + """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. 
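Under the hood, the weights are fetched via HuggingFace's BertForPreTraining and converted with loading.convert_bert_weights.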
Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" + logging.warning( + "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature " + "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " + "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " + "implementation." + "\n" + "If using BERT for interpretability research, keep in mind that BERT has some significant architectural " + "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " + "that the last LayerNorm in a block cannot be folded." + ) + + official_model_name = loading.get_official_model_name(model_name) + + cfg = loading.get_pretrained_model_config( + official_model_name, + checkpoint_index=checkpoint_index, + checkpoint_value=checkpoint_value, + fold_ln=False, + device=device, + n_devices=1, + ) + + state_dict = loading.get_pretrained_state_dict( + official_model_name, cfg, hf_model + ) + + model = cls(cfg, **model_kwargs) + + model.load_state_dict(state_dict, strict=False) + + print(f"Loaded pretrained model {model_name} into HookedTransformer") + + return model + + @property + @typeguard_ignore + def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: + """ + Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits) + """ + return self.unembed.W_U + + @property + @typeguard_ignore + def b_U(self) -> Float[torch.Tensor, "d_vocab"]: + return self.unembed.b_U + + @property + @typeguard_ignore + def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: + """ + Convenience to get the embedding matrix + """ + return self.embed.embed.W_E + + @property + @typeguard_ignore + def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: + """ + Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! + """ + return self.embed.pos_embed.W_pos + + @property + @typeguard_ignore + def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: + """ + Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits. 
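+        The result has shape [d_vocab + n_ctx, d_model], with the d_vocab token rows first and the n_ctx positional rows after.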
+ """ + return torch.cat([self.W_E, self.W_pos], dim=0) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the key weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the query weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the value weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: + """Stacks the attn output weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: + """Stacks the MLP input weights across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: + """Stacks the MLP output weights across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the key biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the query biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the value biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the attn output biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: + """Stacks the MLP input biases across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the MLP output biases across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) + + @property + 
@typeguard_ignore + def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + return FactoredMatrix(self.W_V, self.W_O) + + def all_head_labels(self) -> list[str]: + return [ + f"L{l}H{h}" + for l in range(self.cfg.n_layers) + for h in range(self.cfg.n_heads) + ] diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 6d2489044..6533c85a7 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -721,26 +721,12 @@ def tokens_to_residual_directions( residual_direction = self.W_U[:, token] return residual_direction - def to(self, device_or_dtype, print_details=True): - """ - Wrapper around to that also changes self.cfg.device if it's a torch.device or string. - If torch.dtype, just passes through - """ - if isinstance(device_or_dtype, torch.device): - self.cfg.device = device_or_dtype.type - if print_details: - print("Moving model to device: ", self.cfg.device) - elif isinstance(device_or_dtype, str): - self.cfg.device = device_or_dtype - if print_details: - print("Moving model to device: ", self.cfg.device) - elif isinstance(device_or_dtype, torch.dtype): - if print_details: - print("Changing model dtype to", device_or_dtype) - # change state_dict dtypes - for k, v in self.state_dict().items(): - self.state_dict()[k] = v.to(device_or_dtype) - return nn.Module.to(self, device_or_dtype) + def to( + self, + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details: bool = True, + ): + return devices.move_to_and_update_config(self, device_or_dtype, print_details) def cuda(self): # Wrapper around cuda that also changes self.cfg.device @@ -1033,33 +1019,7 @@ def load_and_process_state_dict( self.load_state_dict(state_dict) def fill_missing_keys(self, state_dict): - """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. - - This function is assumed to be run before weights are initialized. - - Args: - state_dict (dict): State dict from a pretrained model - - Returns: - dict: State dict with missing keys filled in - """ - # Get the default state dict - default_state_dict = self.state_dict() - # Get the keys that are missing from the pretrained model - missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) - # Fill in the missing keys with the default initialization - for key in missing_keys: - if "hf_model" in key: - # Skip keys that are from the HuggingFace model, if loading from HF. 
- continue - if "W_" in key: - logging.warning( - "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( - key - ) - ) - state_dict[key] = default_state_dict[key] - return state_dict + return loading.fill_missing_keys(self, state_dict) def fold_layer_norm(self, state_dict: Dict[str, torch.Tensor]): """Takes in a state dict from a pretrained model, formatted to be consistent with HookedTransformer but with diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index b33ffcefa..becf62821 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import pprint import random @@ -226,7 +228,7 @@ def __post_init__(self): ), f"Not enough CUDA devices to support n_devices {self.n_devices}" @classmethod - def from_dict(cls, config_dict: Dict[str, Any]): + def from_dict(cls, config_dict: Dict[str, Any]) -> HookedTransformerConfig: """ Instantiates a `HookedTransformerConfig` from a Python dictionary of parameters. diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 2994e9768..fce9f93fa 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -10,6 +10,7 @@ from .FactoredMatrix import FactoredMatrix from .ActivationCache import ActivationCache from .HookedTransformer import HookedTransformer +from .HookedEncoder import HookedEncoder from . import head_detector from . import loading_from_pretrained as loading from . import patching diff --git a/transformer_lens/components.py b/transformer_lens/components.py index 16fa7053b..09b1b069b 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -90,6 +90,97 @@ def forward( return broadcast_pos_embed.clone() +class TokenTypeEmbed(nn.Module): + """ + The token-type embed is a binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). + + See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf + """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model)) + + def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]): + return self.W_token_type[token_type_ids, :] + + +class BertEmbed(nn.Module): + """ + Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result. 
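+    Schematically, the output is ln(embed(input_ids) + pos_embed(positions) + token_type_embed(token_type_ids)), as implemented in the forward method below.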
+ """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.embed = Embed(cfg) + self.pos_embed = PosEmbed(cfg) + self.token_type_embed = TokenTypeEmbed(cfg) + self.ln = LayerNorm(cfg) + + self.hook_embed = HookPoint() + self.hook_pos_embed = HookPoint() + self.hook_token_type_embed = HookPoint() + + def forward( + self, + input_ids: Int[torch.Tensor, "batch pos"], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + ): + base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device) + index_ids = einops.repeat( + base_index_id, "pos -> batch pos", batch=input_ids.shape[0] + ) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + word_embeddings_out = self.hook_embed(self.embed(input_ids)) + position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids)) + token_type_embeddings_out = self.hook_token_type_embed( + self.token_type_embed(token_type_ids) + ) + + embeddings_out = ( + word_embeddings_out + position_embeddings_out + token_type_embeddings_out + ) + layer_norm_out = self.ln(embeddings_out) + return layer_norm_out + + +class BertMLMHead(nn.Module): + """ + Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence. + """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W = nn.Parameter(torch.empty(cfg.d_model, cfg.d_model)) + self.b = nn.Parameter(torch.zeros(cfg.d_model)) + self.act_fn = nn.GELU() + self.ln = LayerNorm(cfg) + + def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor: + resid = ( + einsum( + "batch pos d_model_in, d_model_out d_model_in -> batch pos d_model_out", + resid, + self.W, + ) + + self.b + ) + resid = self.act_fn(resid) + resid = self.ln(resid) + return resid + + # LayerNormPre # I fold the LayerNorm weights and biases into later weights and biases. # This is just the 'center and normalise' part of LayerNorm @@ -368,10 +459,12 @@ def forward( Float[torch.Tensor, "batch pos head_index d_model"], ], past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, + additive_attention_mask: Float[torch.Tensor, "batch 1 1 pos"] = None, ) -> Float[torch.Tensor, "batch pos d_model"]: """ shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None + additive_attention_mask is an optional mask to add to the attention weights. Defaults to None. 
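In HookedEncoder, this mask is built as a large negative number (-1e5) multiplied by (1 - one_zero_attention_mask) and broadcast to shape [batch, 1, 1, pos], so that padded key positions receive approximately zero attention probability after the softmax.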
""" if self.cfg.use_split_qkv_input: @@ -421,8 +514,8 @@ def forward( attn_scores = ( einsum( "batch query_pos head_index d_head, \ - batch key_pos head_index d_head \ - -> batch head_index query_pos key_pos", + batch key_pos head_index d_head \ + -> batch head_index query_pos key_pos", q, k, ) @@ -433,6 +526,9 @@ def forward( attn_scores = self.apply_causal_mask( attn_scores, kv_cache_pos_offset ) # [batch, head_index, query_pos, key_pos] + if additive_attention_mask is not None: + attn_scores += additive_attention_mask + attn_scores = self.hook_attn_scores(attn_scores) pattern = self.hook_pattern( F.softmax(attn_scores, dim=-1) @@ -451,8 +547,8 @@ def forward( ( einsum( "batch pos head_index d_head, \ - head_index d_head d_model -> \ - batch pos d_model", + head_index d_head d_model -> \ + batch pos d_model", z, self.W_O, ) @@ -853,3 +949,70 @@ def add_head_dimension(tensor): resid_pre + attn_out ) # [batch, pos, d_model] return resid_post + + +class BertBlock(nn.Module): + """ + BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before. + """ + + def __init__(self, cfg: HookedTransformerConfig): + super().__init__() + self.cfg = cfg + + self.attn = Attention(cfg) + self.ln1 = LayerNorm(cfg) + self.mlp = MLP(cfg) + self.ln2 = LayerNorm(cfg) + + self.hook_q_input = HookPoint() # [batch, pos, d_model] + self.hook_k_input = HookPoint() # [batch, pos, d_model] + self.hook_v_input = HookPoint() # [batch, pos, d_model] + + self.hook_attn_out = HookPoint() # [batch, pos, d_model] + self.hook_mlp_out = HookPoint() # [batch, pos, d_model] + self.hook_resid_pre = HookPoint() # [batch, pos, d_model] + self.hook_resid_mid = HookPoint() # [batch, pos, d_model] + self.hook_resid_post = HookPoint() # [batch, pos, d_model] + self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model] + + def forward( + self, + resid_pre: Float[torch.Tensor, "batch pos d_model"], + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, + ): + resid_pre = self.hook_resid_pre(resid_pre) + + query_input = resid_pre + key_input = resid_pre + value_input = resid_pre + + if self.cfg.use_split_qkv_input: + + def add_head_dimension(tensor): + return einops.repeat( + tensor, + "batch pos d_model -> batch pos n_heads d_model", + n_heads=self.cfg.n_heads, + ).clone() + + query_input = self.hook_q_input(add_head_dimension(query_input)) + key_input = self.hook_k_input(add_head_dimension(key_input)) + value_input = self.hook_v_input(add_head_dimension(value_input)) + + attn_out = self.hook_attn_out( + self.attn( + query_input, + key_input, + value_input, + additive_attention_mask=additive_attention_mask, + ) + ) + resid_mid = self.hook_resid_mid(resid_pre + attn_out) + normalized_resid_mid = self.ln1(resid_mid) + + mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid)) + resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out) + normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post)) + + return normalized_resid_post diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8c51bdea1..b638cc97c 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -6,7 +6,7 @@ import einops import torch from huggingface_hub import HfApi -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, BertForPreTraining import transformer_lens.utils as 
utils from transformer_lens.HookedTransformerConfig import HookedTransformerConfig @@ -104,6 +104,7 @@ "llama-30b-hf", "llama-65b-hf", "Baidicoot/Othello-GPT-Transformer-Lens", + "bert-base-cased", ] # Model Aliases: @@ -442,7 +443,7 @@ def get_official_model_name(model_name: str): return official_model_name -def convert_hf_model_config(official_model_name: str): +def convert_hf_model_config(model_name: str): """ Returns the model config for a HuggingFace model, converted to a dictionary in the HookedTransformerConfig format. @@ -450,7 +451,7 @@ def convert_hf_model_config(official_model_name: str): Takes the official_model_name as an input. """ # In case the user passed in an alias - official_model_name = get_official_model_name(official_model_name) + official_model_name = get_official_model_name(model_name) # Load HuggingFace model config if "llama" not in official_model_name: hf_config = AutoConfig.from_pretrained(official_model_name) @@ -614,6 +615,19 @@ def convert_hf_model_config(official_model_name: str): } rotary_pct = hf_config.rotary_pct cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) + elif architecture == "BertForMaskedLM": + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.layer_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": "gelu", + "attention_dir": "bidirectional", + } else: raise NotImplementedError(f"{architecture} is not currently supported.") # All of these models use LayerNorm @@ -726,7 +740,6 @@ def get_pretrained_model_config( cfg_dict["normalization_type"] = "LNPre" else: logging.warning("Cannot fold in layer norm, normalization_type is not LN.") - pass if checkpoint_index is not None or checkpoint_value is not None: checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( @@ -863,6 +876,8 @@ def get_pretrained_state_dict( elif hf_model is None: if "llama" in official_model_name: raise NotImplementedError("Must pass in hf_model for LLaMA models") + elif "bert" in official_model_name: + hf_model = BertForPreTraining.from_pretrained(official_model_name) else: hf_model = AutoModelForCausalLM.from_pretrained(official_model_name) @@ -879,6 +894,8 @@ def get_pretrained_state_dict( state_dict = convert_neox_weights(hf_model, cfg) elif cfg.original_architecture == "LLaMAForCausalLM": state_dict = convert_llama_weights(hf_model, cfg) + elif cfg.original_architecture == "BertForMaskedLM": + state_dict = convert_bert_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." @@ -887,6 +904,36 @@ def get_pretrained_state_dict( return state_dict +def fill_missing_keys(model, state_dict): + """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. + + This function is assumed to be run before weights are initialized. 
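+    Keys belonging to a wrapped HuggingFace model (containing "hf_model") are skipped, and a warning is logged whenever a weight matrix (a key containing "W_") has to be filled in with its default initialization.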
+ + Args: + state_dict (dict): State dict from a pretrained model + + Returns: + dict: State dict with missing keys filled in + """ + # Get the default state dict + default_state_dict = model.state_dict() + # Get the keys that are missing from the pretrained model + missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) + # Fill in the missing keys with the default initialization + for key in missing_keys: + if "hf_model" in key: + # Skip keys that are from the HuggingFace model, if loading from HF. + continue + if "W_" in key: + logging.warning( + "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( + key + ) + ) + state_dict[key] = default_state_dict[key] + return state_dict + + # %% def convert_state_dict( state_dict: dict, @@ -964,7 +1011,7 @@ def convert_gpt2_weights(gpt2, cfg: HookedTransformerConfig): W_out = gpt2.transformer.h[l].mlp.c_proj.weight state_dict[f"blocks.{l}.mlp.W_out"] = W_out state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias - state_dict[f"unembed.W_U"] = gpt2.lm_head.weight.T + state_dict["unembed.W_U"] = gpt2.lm_head.weight.T state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias @@ -1270,8 +1317,8 @@ def convert_opt_weights(opt, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.mlp.b_in"] = opt.model.decoder.layers[l].fc1.bias state_dict[f"blocks.{l}.mlp.b_out"] = opt.model.decoder.layers[l].fc2.bias - state_dict[f"ln_final.w"] = opt.model.decoder.final_layer_norm.weight - state_dict[f"ln_final.b"] = opt.model.decoder.final_layer_norm.bias + state_dict["ln_final.w"] = opt.model.decoder.final_layer_norm.weight + state_dict["ln_final.b"] = opt.model.decoder.final_layer_norm.bias state_dict["unembed.W_U"] = opt.lm_head.weight.T state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab) return state_dict @@ -1368,7 +1415,7 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.mlp.W_out"] = W_out.T state_dict[f"blocks.{l}.mlp.b_out"] = old_state_dict[f"blocks.{l}.mlp.2.bias"] - state_dict[f"unembed.W_U"] = old_state_dict["head.weight"].T + state_dict["unembed.W_U"] = old_state_dict["head.weight"].T state_dict["ln_final.w"] = old_state_dict["ln_f.weight"] state_dict["ln_final.b"] = old_state_dict["ln_f.bias"] @@ -1376,4 +1423,63 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): return state_dict -# %% +def convert_bert_weights(bert, cfg: HookedTransformerConfig): + embeddings = bert.bert.embeddings + state_dict = { + "embed.embed.W_E": embeddings.word_embeddings.weight, + "embed.pos_embed.W_pos": embeddings.position_embeddings.weight, + "embed.token_type_embed.W_token_type": embeddings.token_type_embeddings.weight, + "embed.ln.w": embeddings.LayerNorm.weight, + "embed.ln.b": embeddings.LayerNorm.bias, + } + + for l in range(cfg.n_layers): + block = bert.bert.encoder.layer[l] + state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange( + block.attention.self.query.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange( + block.attention.self.query.bias, "(i h) -> i h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange( + block.attention.self.key.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange( + block.attention.self.key.bias, "(i h) -> i h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange( + 
block.attention.self.value.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange( + block.attention.self.value.bias, "(i h) -> i h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange( + block.attention.output.dense.weight, + "m (i h) -> i h m", + i=cfg.n_heads, + ) + state_dict[f"blocks.{l}.attn.b_O"] = block.attention.output.dense.bias + state_dict[f"blocks.{l}.ln1.w"] = block.attention.output.LayerNorm.weight + state_dict[f"blocks.{l}.ln1.b"] = block.attention.output.LayerNorm.bias + state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange( + block.intermediate.dense.weight, "mlp model -> model mlp" + ) + state_dict[f"blocks.{l}.mlp.b_in"] = block.intermediate.dense.bias + state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange( + block.output.dense.weight, "model mlp -> mlp model" + ) + state_dict[f"blocks.{l}.mlp.b_out"] = block.output.dense.bias + state_dict[f"blocks.{l}.ln2.w"] = block.output.LayerNorm.weight + state_dict[f"blocks.{l}.ln2.b"] = block.output.LayerNorm.bias + + mlm_head = bert.cls.predictions + state_dict["mlm_head.W"] = mlm_head.transform.dense.weight + state_dict["mlm_head.b"] = mlm_head.transform.dense.bias + state_dict["mlm_head.ln.w"] = mlm_head.transform.LayerNorm.weight + state_dict["mlm_head.ln.b"] = mlm_head.transform.LayerNorm.bias + # Note: BERT uses tied embeddings + state_dict["unembed.W_U"] = embeddings.word_embeddings.weight.T + # "unembed.W_U": mlm_head.decoder.weight.T, + state_dict["unembed.b_U"] = mlm_head.bias + + return state_dict diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index bf260ba5f..881a63ecb 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -1,6 +1,13 @@ -from typing import Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union import torch +from torch import nn + +if TYPE_CHECKING: + from transformer_lens.HookedEncoder import HookedEncoder + from transformer_lens.HookedTransformer import HookedTransformer from transformer_lens.HookedTransformerConfig import HookedTransformerConfig @@ -32,3 +39,29 @@ def get_device_for_block_index( device = torch.device(device) device_index = (device.index or 0) + (index // layers_per_device) return torch.device(device.type, device_index) + + +def move_to_and_update_config( + model: Union[HookedTransformer, HookedEncoder], + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details=True, +): + """ + Wrapper around to that also changes model.cfg.device if it's a torch.device or string. + If torch.dtype, just passes through + """ + if isinstance(device_or_dtype, torch.device): + model.cfg.device = device_or_dtype.type + if print_details: + print("Moving model to device: ", model.cfg.device) + elif isinstance(device_or_dtype, str): + model.cfg.device = device_or_dtype + if print_details: + print("Moving model to device: ", model.cfg.device) + elif isinstance(device_or_dtype, torch.dtype): + if print_details: + print("Changing model dtype to", device_or_dtype) + # change state_dict dtypes + for k, v in model.state_dict().items(): + model.state_dict()[k] = v.to(device_or_dtype) + return nn.Module.to(model, device_or_dtype)
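
Example usage of the new class, mirroring demos/BERT.ipynb and tests/acceptance/test_hooked_encoder.py (assumes the bert-base-cased weights can be downloaded):

    import torch
    from transformers import AutoTokenizer
    from transformer_lens import HookedEncoder

    torch.set_grad_enabled(False)
    bert = HookedEncoder.from_pretrained("bert-base-cased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

    # Tokenize a prompt containing a [MASK] token and run the model with caching
    input_ids = tokenizer("The [MASK] sat on the mat", return_tensors="pt")["input_ids"]
    logits, cache = bert.run_with_cache(input_ids)

    # Decode the model's prediction for the masked position
    mask_logits = logits[input_ids == tokenizer.mask_token_id]
    print(tokenizer.decode(mask_logits.argmax(dim=-1).item()))

    # Cached activations are available on every hook point, e.g.
    print(cache["blocks.0.attn.hook_q"].shape)  # [batch, pos, n_heads, d_head]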