diff --git a/.gitignore b/.gitignore index 2aea0a4cb..a5d6d42b7 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,6 @@ _modidx.py env dist/ docs/build -.coverage +.coverage* .Ds_Store .pylintrc diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb new file mode 100644 index 000000000..6d227ebea --- /dev/null +++ b/demos/BERT.ipynb @@ -0,0 +1,258 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BERT in TransformerLens\n", + "This demo shows how to use BERT in TransformerLens for the Masked Language Modelling task." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup\n", + "(No need to read)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + } + ], + "source": [ + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "DEVELOPMENT_MODE = False\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + " %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n", + " %pip install circuitsvis\n", + " \n", + " # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n", + " # # Install another version of node that makes PySvelte work way faster\n", + " # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + " # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using renderer: colab\n" + ] + } + ], + "source": [ + "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", + "import plotly.io as pio\n", + "if IN_COLAB or not DEVELOPMENT_MODE:\n", + " pio.renderers.default = \"colab\"\n", + "else:\n", + " pio.renderers.default = \"notebook_connected\"\n", + "print(f\"Using renderer: {pio.renderers.default}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import circuitsvis as cv\n", + "# Testing that the library works\n", + "cv.examples.hello(\"Neel\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Import stuff\n", + "import torch\n", + "\n", + "from transformers import AutoTokenizer\n", + "\n", + "from transformer_lens import HookedEncoder" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.set_grad_enabled(False)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# BERT\n", + "\n", + "In this section, we will load a pretrained BERT model and use it for the Masked Language Modelling task" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Moving model to device: cpu\n", + "Loaded pretrained model bert-base-cased into HookedTransformer\n" + ] + } + ], + "source": [ + "bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n", + "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Use the \"[MASK]\" token to mask any tokens which you would like the model to predict." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "prompt = \"BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\"\n", + "\n", + "input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n", + "mask_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Prompt: BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\n", + "Prediction: \"Systems\"\n" + ] + } + ], + "source": [ + "logprobs = bert(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n", + "prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n", + "\n", + "print(f\"Prompt: {prompt}\")\n", + "print(f\"Prediction: \\\"{prediction}\\\"\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Better luck next time, BERT." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.10" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index aed5af6ba..6922d18ff 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -1091,6 +1091,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1110,6 +1111,9 @@ "* **GPT-NeoX** - Eleuther's 20B parameter model, trained on the Pile\n", "* **Stanford CRFM models** - a replication of GPT-2 Small and GPT-2 Medium, trained on 5 different random seeds.\n", " * Notably, 600 checkpoints were taken during training per model, and these are available in the library with eg `HookedTransformer.from_pretrained(\"stanford-gpt2-small-a\", checkpoint_index=265)`.\n", + "- **BERT** - Google's bidirectional encoder-only transformer.\n", + " - Size Base (108M), trained on English Wikipedia and BooksCorpus.\n", + " \n", "" ] }, diff --git a/pyproject.toml b/pyproject.toml index 3e5e9b9aa..5a56131dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ docs = ["Sphinx", "sphinx-autobuild", "sphinxcontrib-napoleon", "furo", "myst_pa [tool.pytest.ini_options] filterwarnings = [ + "ignore:pkg_resources is deprecated as an API:DeprecationWarning", # Ignore numpy.distutils deprecation warning caused by pandas # More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils "ignore:distutils Version classes are deprecated:DeprecationWarning" diff --git a/tests/acceptance/test_hooked_encoder.py b/tests/acceptance/test_hooked_encoder.py new file mode 100644 index 000000000..6f95d9b27 --- /dev/null +++ b/tests/acceptance/test_hooked_encoder.py @@ -0,0 +1,155 @@ +from typing import List + +import pytest +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch.testing import assert_close +from transformers import AutoTokenizer, BertForMaskedLM + +from transformer_lens import HookedEncoder + +MODEL_NAME = "bert-base-cased" + + +@pytest.fixture(scope="module") +def our_bert(): + return HookedEncoder.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def huggingface_bert(): + return BertForMaskedLM.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def tokenizer(): + return AutoTokenizer.from_pretrained(MODEL_NAME) + + +@pytest.fixture +def hello_world_tokens(tokenizer): + return tokenizer("Hello, world!", return_tensors="pt")["input_ids"] + + +def test_full_model(our_bert, huggingface_bert, tokenizer): + sequences = [ + "Hello, world!", + "this is another sequence of tokens", + ] + tokenized = tokenizer(sequences, return_tensors="pt", padding=True) + input_ids = tokenized["input_ids"] + attention_mask = tokenized["attention_mask"] + + huggingface_bert_out = huggingface_bert( + input_ids, attention_mask=attention_mask + ).logits + our_bert_out = our_bert(input_ids, one_zero_attention_mask=attention_mask) + assert_close(huggingface_bert_out, our_bert_out, rtol=1.3e-6, atol=4e-5) + + +def test_embed_one_sentence(our_bert, huggingface_bert, hello_world_tokens): + huggingface_embed = huggingface_bert.bert.embeddings + our_embed = our_bert.embed + + huggingface_embed_out = 
huggingface_embed(hello_world_tokens)[0] + our_embed_out = our_embed(hello_world_tokens).squeeze(0) + assert_close(huggingface_embed_out, our_embed_out) + + +def test_embed_two_sentences(our_bert, huggingface_bert, tokenizer): + encoding = tokenizer("First sentence.", "Second sentence.", return_tensors="pt") + input_ids = encoding["input_ids"] + token_type_ids = encoding["token_type_ids"] + + huggingface_embed_out = huggingface_bert.bert.embeddings( + input_ids, token_type_ids=token_type_ids + )[0] + our_embed_out = our_bert.embed(input_ids, token_type_ids=token_type_ids).squeeze(0) + assert_close(huggingface_embed_out, our_embed_out) + + +def test_attention(our_bert, huggingface_bert, hello_world_tokens): + huggingface_embed = huggingface_bert.bert.embeddings + huggingface_attn = huggingface_bert.bert.encoder.layer[0].attention + + embed_out = huggingface_embed(hello_world_tokens) + + our_attn = our_bert.blocks[0].attn + + our_attn_out = our_attn(embed_out, embed_out, embed_out) + huggingface_self_attn_out = huggingface_attn.self(embed_out)[0] + huggingface_attn_out = huggingface_attn.output.dense(huggingface_self_attn_out) + assert_close(our_attn_out, huggingface_attn_out) + + +def test_bert_block(our_bert, huggingface_bert, hello_world_tokens): + huggingface_embed = huggingface_bert.bert.embeddings + huggingface_block = huggingface_bert.bert.encoder.layer[0] + + embed_out = huggingface_embed(hello_world_tokens) + + our_block = our_bert.blocks[0] + + our_block_out = our_block(embed_out) + huggingface_block_out = huggingface_block(embed_out)[0] + assert_close(our_block_out, huggingface_block_out) + + +def test_mlm_head(our_bert, huggingface_bert, hello_world_tokens): + huggingface_bert_core_outputs = huggingface_bert.bert( + hello_world_tokens + ).last_hidden_state + + our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs) + our_unembed_out = our_bert.unembed(our_mlm_head_out) + huggingface_predictions_out = huggingface_bert.cls.predictions( + huggingface_bert_core_outputs + ) + + assert_close(our_unembed_out, huggingface_predictions_out, rtol=1.3e-6, atol=4e-5) + + +def test_unembed(our_bert, huggingface_bert, hello_world_tokens): + huggingface_bert_core_outputs = huggingface_bert.bert( + hello_world_tokens + ).last_hidden_state + + our_mlm_head_out = our_bert.mlm_head(huggingface_bert_core_outputs) + huggingface_predictions_out = huggingface_bert.cls.predictions.transform( + huggingface_bert_core_outputs + ) + + print((our_mlm_head_out - huggingface_predictions_out).abs().max()) + assert_close(our_mlm_head_out, huggingface_predictions_out, rtol=1.3e-3, atol=1e-5) + + +def test_run_with_cache(our_bert, huggingface_bert, hello_world_tokens): + model = HookedEncoder.from_pretrained("bert-base-cased") + logits, cache = model.run_with_cache(hello_world_tokens) + + # check that an arbitrary subset of the keys exist + assert "embed.hook_embed" in cache + assert "blocks.0.attn.hook_q" in cache + assert "blocks.3.attn.hook_attn_scores" in cache + assert "blocks.7.hook_resid_post" in cache + assert "mlm_head.ln.hook_normalized" in cache + + +def test_predictions(our_bert, huggingface_bert, tokenizer): + input_ids = tokenizer("The [MASK] sat on the mat", return_tensors="pt")["input_ids"] + + def get_predictions( + logits: Float[torch.Tensor, "batch pos d_vocab"], positions: List[int] + ): + logits_at_position = logits.squeeze(0)[positions] + predicted_tokens = F.softmax(logits_at_position, dim=-1).argmax(dim=-1) + return tokenizer.batch_decode(predicted_tokens) + + our_bert_out = 
our_bert(input_ids) + our_prediction = get_predictions(our_bert_out, [2]) + + huggingface_bert_out = huggingface_bert(input_ids).logits + huggingface_prediction = get_predictions(huggingface_bert_out, [2]) + + assert our_prediction == huggingface_prediction diff --git a/tests/acceptance/test_transformer_lens.py b/tests/acceptance/test_hooked_transformer.py similarity index 100% rename from tests/acceptance/test_transformer_lens.py rename to tests/acceptance/test_hooked_transformer.py diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py new file mode 100644 index 000000000..7e3de330e --- /dev/null +++ b/tests/unit/test_config.py @@ -0,0 +1,25 @@ +""" +Tests that verify than an arbitrary component (e.g. Embed) can be initialized using dict and object versions of HookedTransformerConfig and HookedEncoderConfig. +""" + +from transformer_lens.components import Embed +from transformer_lens.HookedTransformerConfig import HookedTransformerConfig + + +def test_hooked_transformer_config_object(): + hooked_transformer_config = HookedTransformerConfig( + n_layers=2, d_vocab=100, d_model=6, n_ctx=5, d_head=2, attn_only=True + ) + Embed(hooked_transformer_config) + + +def test_hooked_transformer_config_dict(): + hooked_transformer_config_dict = { + "n_layers": 2, + "d_vocab": 100, + "d_model": 6, + "n_ctx": 5, + "d_head": 2, + "attn_only": True, + } + Embed(hooked_transformer_config_dict) diff --git a/tests/unit/test_create_hooked_encoder.py b/tests/unit/test_create_hooked_encoder.py new file mode 100644 index 000000000..a1adc7ef3 --- /dev/null +++ b/tests/unit/test_create_hooked_encoder.py @@ -0,0 +1,35 @@ +import pytest +from transformers import AutoTokenizer, BertTokenizerFast + +from transformer_lens import HookedEncoder, HookedTransformerConfig + + +@pytest.fixture +def cfg(): + return HookedTransformerConfig( + d_head=4, d_model=12, n_ctx=5, n_layers=3, act_fn="gelu" + ) + + +def test_pass_tokenizer(cfg): + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + model = HookedEncoder(cfg, tokenizer=tokenizer) + assert model.tokenizer == tokenizer + + +def test_load_tokenizer_from_config(cfg): + cfg.tokenizer_name = "bert-base-cased" + model = HookedEncoder(cfg) + assert isinstance(model.tokenizer, BertTokenizerFast) + + +def test_load_without_tokenizer(cfg): + cfg.d_vocab = 22 + model = HookedEncoder(cfg) + assert model.tokenizer is None + + +def test_cannot_load_without_tokenizer_or_d_vocab(cfg): + with pytest.raises(AssertionError) as e: + HookedEncoder(cfg) + assert "Must provide a tokenizer if d_vocab is not provided" in str(e.value) diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py new file mode 100644 index 000000000..3ee5e2ad7 --- /dev/null +++ b/transformer_lens/HookedEncoder.py @@ -0,0 +1,397 @@ +from __future__ import annotations + +import logging +from functools import lru_cache +from typing import Dict, Optional, Tuple, Union, cast, overload + +import torch +from einops import repeat +from jaxtyping import Float, Int +from torch import nn +from transformers import AutoTokenizer +from typeguard import typeguard_ignore +from typing_extensions import Literal + +import transformer_lens.loading_from_pretrained as loading +from transformer_lens import ActivationCache, FactoredMatrix, HookedTransformerConfig +from transformer_lens.components import BertBlock, BertEmbed, BertMLMHead, Unembed +from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.utilities import devices + + +class 
HookedEncoder(HookedRootModule): + """ + This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. + + Limitations: + - The current MVP implementation supports only the masked language modelling (MLM) task. Next sentence prediction (NSP), causal language modelling, and other tasks are not yet supported. + - Also note that model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. + + Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: + - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model + - The model only accepts tokens as inputs, and not strings, or lists of strings + """ + + def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig(**cfg) + elif isinstance(cfg, str): + raise ValueError( + "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead." + ) + self.cfg = cfg + + assert ( + self.cfg.n_devices == 1 + ), "Multiple devices not supported for HookedEncoder" + if move_to_device: + self.to(self.cfg.device) + + if tokenizer is not None: + self.tokenizer = tokenizer + elif self.cfg.tokenizer_name is not None: + self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.tokenizer_name) + else: + self.tokenizer = None + + if self.cfg.d_vocab == -1: + # If we have a tokenizer, vocab size can be inferred from it. + assert ( + self.tokenizer is not None + ), "Must provide a tokenizer if d_vocab is not provided" + self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 + if self.cfg.d_vocab_out == -1: + self.cfg.d_vocab_out = self.cfg.d_vocab + + self.embed = BertEmbed(self.cfg) + self.blocks = nn.ModuleList( + [BertBlock(self.cfg) for _ in range(self.cfg.n_layers)] + ) + self.mlm_head = BertMLMHead(cfg) + self.unembed = Unembed(self.cfg) + + self.hook_full_embed = HookPoint() + + self.setup() + + @overload + def forward( + self, + input: Int[torch.Tensor, "batch pos"], + return_type: Literal["logits"], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_vocab"]: + ... + + @overload + def forward( + self, + input: Int[torch.Tensor, "batch pos"], + return_type: Literal[None], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]: + ... + + def forward( + self, + input: Int[torch.Tensor, "batch pos"], + return_type: Optional[str] = "logits", + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]: + """Input must be a batch of tokens. Strings and lists of strings are not yet supported. + + return_type Optional[str]: The type of output to return. Can be one of: None (return nothing, don't calculate logits), or 'logits' (return logits). + + token_type_ids Optional[torch.Tensor]: Binary ids indicating whether a token belongs to sequence A or B. 
For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). + + one_zero_attention_mask: Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to. + """ + + tokens = input + + if tokens.device.type != self.cfg.device: + tokens = tokens.to(self.cfg.device) + if one_zero_attention_mask is not None: + one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) + + resid = self.hook_full_embed(self.embed(tokens, token_type_ids)) + + large_negative_number = -1e5 + additive_attention_mask = ( + large_negative_number + * repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") + if one_zero_attention_mask is not None + else None + ) + + for block in self.blocks: + resid = block(resid, additive_attention_mask) + resid = self.mlm_head(resid) + + if return_type is None: + return + + logits = self.unembed(resid) + return logits + + @overload + def run_with_cache( + self, *model_args, return_cache_object: Literal[True] = True, **kwargs + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], ActivationCache]: + ... + + @overload + def run_with_cache( + self, *model_args, return_cache_object: Literal[False] = False, **kwargs + ) -> Tuple[Float[torch.Tensor, "batch pos d_vocab"], Dict[str, torch.Tensor]]: + ... + + def run_with_cache( + self, + *model_args, + return_cache_object: bool = True, + remove_batch_dim: bool = False, + **kwargs, + ) -> Tuple[ + Float[torch.Tensor, "batch pos d_vocab"], + Union[ActivationCache, Dict[str, torch.Tensor]], + ]: + """ + Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. + """ + out, cache_dict = super().run_with_cache( + *model_args, remove_batch_dim=remove_batch_dim, **kwargs + ) + if return_cache_object: + cache = ActivationCache( + cache_dict, self, has_batch_dim=not remove_batch_dim + ) + return out, cache + else: + return out, cache_dict + + def to( + self, + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details: bool = True, + ): + return devices.move_to_and_update_config(self, device_or_dtype, print_details) + + def cuda(self): + # Wrapper around cuda that also changes self.cfg.device + return self.to("cuda") + + def cpu(self): + # Wrapper around cuda that also changes self.cfg.device + return self.to("cpu") + + @classmethod + def from_pretrained( + cls, + model_name: str, + checkpoint_index: Optional[int] = None, + checkpoint_value: Optional[int] = None, + hf_model=None, + device: Optional[str] = None, + **model_kwargs, + ) -> HookedEncoder: + """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. 
Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" + logging.warning( + "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature " + "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " + "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " + "implementation." + "\n" + "If using BERT for interpretability research, keep in mind that BERT has some significant architectural " + "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " + "that the last LayerNorm in a block cannot be folded." + ) + + official_model_name = loading.get_official_model_name(model_name) + + cfg = loading.get_pretrained_model_config( + official_model_name, + checkpoint_index=checkpoint_index, + checkpoint_value=checkpoint_value, + fold_ln=False, + device=device, + n_devices=1, + ) + + state_dict = loading.get_pretrained_state_dict( + official_model_name, cfg, hf_model + ) + + model = cls(cfg, **model_kwargs) + + model.load_state_dict(state_dict, strict=False) + + print(f"Loaded pretrained model {model_name} into HookedTransformer") + + return model + + @property + @typeguard_ignore + def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: + """ + Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits) + """ + return self.unembed.W_U + + @property + @typeguard_ignore + def b_U(self) -> Float[torch.Tensor, "d_vocab"]: + return self.unembed.b_U + + @property + @typeguard_ignore + def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: + """ + Convenience to get the embedding matrix + """ + return self.embed.embed.W_E + + @property + @typeguard_ignore + def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: + """ + Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! + """ + return self.embed.pos_embed.W_pos + + @property + @typeguard_ignore + def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: + """ + Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits. 
+ """ + return torch.cat([self.W_E, self.W_pos], dim=0) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the key weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the query weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the value weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: + """Stacks the attn output weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: + """Stacks the MLP input weights across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: + """Stacks the MLP output weights across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the key biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the query biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the value biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the attn output biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: + """Stacks the MLP input biases across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + @lru_cache(maxsize=None) + def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the MLP output biases across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0 + ) + + @property + @typeguard_ignore + def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) + + @property + 
@typeguard_ignore + def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + return FactoredMatrix(self.W_V, self.W_O) + + def all_head_labels(self) -> list[str]: + return [ + f"L{l}H{h}" + for l in range(self.cfg.n_layers) + for h in range(self.cfg.n_heads) + ] diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 6d2489044..6533c85a7 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -721,26 +721,12 @@ def tokens_to_residual_directions( residual_direction = self.W_U[:, token] return residual_direction - def to(self, device_or_dtype, print_details=True): - """ - Wrapper around to that also changes self.cfg.device if it's a torch.device or string. - If torch.dtype, just passes through - """ - if isinstance(device_or_dtype, torch.device): - self.cfg.device = device_or_dtype.type - if print_details: - print("Moving model to device: ", self.cfg.device) - elif isinstance(device_or_dtype, str): - self.cfg.device = device_or_dtype - if print_details: - print("Moving model to device: ", self.cfg.device) - elif isinstance(device_or_dtype, torch.dtype): - if print_details: - print("Changing model dtype to", device_or_dtype) - # change state_dict dtypes - for k, v in self.state_dict().items(): - self.state_dict()[k] = v.to(device_or_dtype) - return nn.Module.to(self, device_or_dtype) + def to( + self, + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details: bool = True, + ): + return devices.move_to_and_update_config(self, device_or_dtype, print_details) def cuda(self): # Wrapper around cuda that also changes self.cfg.device @@ -1033,33 +1019,7 @@ def load_and_process_state_dict( self.load_state_dict(state_dict) def fill_missing_keys(self, state_dict): - """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. - - This function is assumed to be run before weights are initialized. - - Args: - state_dict (dict): State dict from a pretrained model - - Returns: - dict: State dict with missing keys filled in - """ - # Get the default state dict - default_state_dict = self.state_dict() - # Get the keys that are missing from the pretrained model - missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) - # Fill in the missing keys with the default initialization - for key in missing_keys: - if "hf_model" in key: - # Skip keys that are from the HuggingFace model, if loading from HF. 
- continue - if "W_" in key: - logging.warning( - "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( - key - ) - ) - state_dict[key] = default_state_dict[key] - return state_dict + return loading.fill_missing_keys(self, state_dict) def fold_layer_norm(self, state_dict: Dict[str, torch.Tensor]): """Takes in a state dict from a pretrained model, formatted to be consistent with HookedTransformer but with diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index b33ffcefa..becf62821 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import pprint import random @@ -226,7 +228,7 @@ def __post_init__(self): ), f"Not enough CUDA devices to support n_devices {self.n_devices}" @classmethod - def from_dict(cls, config_dict: Dict[str, Any]): + def from_dict(cls, config_dict: Dict[str, Any]) -> HookedTransformerConfig: """ Instantiates a `HookedTransformerConfig` from a Python dictionary of parameters. diff --git a/transformer_lens/__init__.py b/transformer_lens/__init__.py index 2994e9768..fce9f93fa 100644 --- a/transformer_lens/__init__.py +++ b/transformer_lens/__init__.py @@ -10,6 +10,7 @@ from .FactoredMatrix import FactoredMatrix from .ActivationCache import ActivationCache from .HookedTransformer import HookedTransformer +from .HookedEncoder import HookedEncoder from . import head_detector from . import loading_from_pretrained as loading from . import patching diff --git a/transformer_lens/components.py b/transformer_lens/components.py index 16fa7053b..09b1b069b 100644 --- a/transformer_lens/components.py +++ b/transformer_lens/components.py @@ -90,6 +90,97 @@ def forward( return broadcast_pos_embed.clone() +class TokenTypeEmbed(nn.Module): + """ + The token-type embed is a binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). + + See the BERT paper for more information: https://arxiv.org/pdf/1810.04805.pdf + """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W_token_type = nn.Parameter(torch.empty(2, self.cfg.d_model)) + + def forward(self, token_type_ids: Int[torch.Tensor, "batch pos"]): + return self.W_token_type[token_type_ids, :] + + +class BertEmbed(nn.Module): + """ + Custom embedding layer for a BERT-like model. This module computes the sum of the token, positional and token-type embeddings and takes the layer norm of the result. 
+ """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.embed = Embed(cfg) + self.pos_embed = PosEmbed(cfg) + self.token_type_embed = TokenTypeEmbed(cfg) + self.ln = LayerNorm(cfg) + + self.hook_embed = HookPoint() + self.hook_pos_embed = HookPoint() + self.hook_token_type_embed = HookPoint() + + def forward( + self, + input_ids: Int[torch.Tensor, "batch pos"], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + ): + base_index_id = torch.arange(input_ids.shape[1], device=input_ids.device) + index_ids = einops.repeat( + base_index_id, "pos -> batch pos", batch=input_ids.shape[0] + ) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + word_embeddings_out = self.hook_embed(self.embed(input_ids)) + position_embeddings_out = self.hook_pos_embed(self.pos_embed(index_ids)) + token_type_embeddings_out = self.hook_token_type_embed( + self.token_type_embed(token_type_ids) + ) + + embeddings_out = ( + word_embeddings_out + position_embeddings_out + token_type_embeddings_out + ) + layer_norm_out = self.ln(embeddings_out) + return layer_norm_out + + +class BertMLMHead(nn.Module): + """ + Transforms BERT embeddings into logits. The purpose of this module is to predict masked tokens in a sentence. + """ + + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig.from_dict(cfg) + self.cfg = cfg + self.W = nn.Parameter(torch.empty(cfg.d_model, cfg.d_model)) + self.b = nn.Parameter(torch.zeros(cfg.d_model)) + self.act_fn = nn.GELU() + self.ln = LayerNorm(cfg) + + def forward(self, resid: Float[torch.Tensor, "batch pos d_model"]) -> torch.Tensor: + resid = ( + einsum( + "batch pos d_model_in, d_model_out d_model_in -> batch pos d_model_out", + resid, + self.W, + ) + + self.b + ) + resid = self.act_fn(resid) + resid = self.ln(resid) + return resid + + # LayerNormPre # I fold the LayerNorm weights and biases into later weights and biases. # This is just the 'center and normalise' part of LayerNorm @@ -368,10 +459,12 @@ def forward( Float[torch.Tensor, "batch pos head_index d_model"], ], past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, + additive_attention_mask: Float[torch.Tensor, "batch 1 1 pos"] = None, ) -> Float[torch.Tensor, "batch pos d_model"]: """ shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. See HookedTransformerConfig for more details past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None + additive_attention_mask is an optional mask to add to the attention weights. Defaults to None. 
""" if self.cfg.use_split_qkv_input: @@ -421,8 +514,8 @@ def forward( attn_scores = ( einsum( "batch query_pos head_index d_head, \ - batch key_pos head_index d_head \ - -> batch head_index query_pos key_pos", + batch key_pos head_index d_head \ + -> batch head_index query_pos key_pos", q, k, ) @@ -433,6 +526,9 @@ def forward( attn_scores = self.apply_causal_mask( attn_scores, kv_cache_pos_offset ) # [batch, head_index, query_pos, key_pos] + if additive_attention_mask is not None: + attn_scores += additive_attention_mask + attn_scores = self.hook_attn_scores(attn_scores) pattern = self.hook_pattern( F.softmax(attn_scores, dim=-1) @@ -451,8 +547,8 @@ def forward( ( einsum( "batch pos head_index d_head, \ - head_index d_head d_model -> \ - batch pos d_model", + head_index d_head d_model -> \ + batch pos d_model", z, self.W_O, ) @@ -853,3 +949,70 @@ def add_head_dimension(tensor): resid_pre + attn_out ) # [batch, pos, d_model] return resid_post + + +class BertBlock(nn.Module): + """ + BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before. + """ + + def __init__(self, cfg: HookedTransformerConfig): + super().__init__() + self.cfg = cfg + + self.attn = Attention(cfg) + self.ln1 = LayerNorm(cfg) + self.mlp = MLP(cfg) + self.ln2 = LayerNorm(cfg) + + self.hook_q_input = HookPoint() # [batch, pos, d_model] + self.hook_k_input = HookPoint() # [batch, pos, d_model] + self.hook_v_input = HookPoint() # [batch, pos, d_model] + + self.hook_attn_out = HookPoint() # [batch, pos, d_model] + self.hook_mlp_out = HookPoint() # [batch, pos, d_model] + self.hook_resid_pre = HookPoint() # [batch, pos, d_model] + self.hook_resid_mid = HookPoint() # [batch, pos, d_model] + self.hook_resid_post = HookPoint() # [batch, pos, d_model] + self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model] + + def forward( + self, + resid_pre: Float[torch.Tensor, "batch pos d_model"], + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]] = None, + ): + resid_pre = self.hook_resid_pre(resid_pre) + + query_input = resid_pre + key_input = resid_pre + value_input = resid_pre + + if self.cfg.use_split_qkv_input: + + def add_head_dimension(tensor): + return einops.repeat( + tensor, + "batch pos d_model -> batch pos n_heads d_model", + n_heads=self.cfg.n_heads, + ).clone() + + query_input = self.hook_q_input(add_head_dimension(query_input)) + key_input = self.hook_k_input(add_head_dimension(key_input)) + value_input = self.hook_v_input(add_head_dimension(value_input)) + + attn_out = self.hook_attn_out( + self.attn( + query_input, + key_input, + value_input, + additive_attention_mask=additive_attention_mask, + ) + ) + resid_mid = self.hook_resid_mid(resid_pre + attn_out) + normalized_resid_mid = self.ln1(resid_mid) + + mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid)) + resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out) + normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post)) + + return normalized_resid_post diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 8c51bdea1..b638cc97c 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -6,7 +6,7 @@ import einops import torch from huggingface_hub import HfApi -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, BertForPreTraining import transformer_lens.utils as 
utils from transformer_lens.HookedTransformerConfig import HookedTransformerConfig @@ -104,6 +104,7 @@ "llama-30b-hf", "llama-65b-hf", "Baidicoot/Othello-GPT-Transformer-Lens", + "bert-base-cased", ] # Model Aliases: @@ -442,7 +443,7 @@ def get_official_model_name(model_name: str): return official_model_name -def convert_hf_model_config(official_model_name: str): +def convert_hf_model_config(model_name: str): """ Returns the model config for a HuggingFace model, converted to a dictionary in the HookedTransformerConfig format. @@ -450,7 +451,7 @@ def convert_hf_model_config(official_model_name: str): Takes the official_model_name as an input. """ # In case the user passed in an alias - official_model_name = get_official_model_name(official_model_name) + official_model_name = get_official_model_name(model_name) # Load HuggingFace model config if "llama" not in official_model_name: hf_config = AutoConfig.from_pretrained(official_model_name) @@ -614,6 +615,19 @@ def convert_hf_model_config(official_model_name: str): } rotary_pct = hf_config.rotary_pct cfg_dict["rotary_dim"] = round(rotary_pct * cfg_dict["d_head"]) + elif architecture == "BertForMaskedLM": + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.layer_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": "gelu", + "attention_dir": "bidirectional", + } else: raise NotImplementedError(f"{architecture} is not currently supported.") # All of these models use LayerNorm @@ -726,7 +740,6 @@ def get_pretrained_model_config( cfg_dict["normalization_type"] = "LNPre" else: logging.warning("Cannot fold in layer norm, normalization_type is not LN.") - pass if checkpoint_index is not None or checkpoint_value is not None: checkpoint_labels, checkpoint_label_type = get_checkpoint_labels( @@ -863,6 +876,8 @@ def get_pretrained_state_dict( elif hf_model is None: if "llama" in official_model_name: raise NotImplementedError("Must pass in hf_model for LLaMA models") + elif "bert" in official_model_name: + hf_model = BertForPreTraining.from_pretrained(official_model_name) else: hf_model = AutoModelForCausalLM.from_pretrained(official_model_name) @@ -879,6 +894,8 @@ def get_pretrained_state_dict( state_dict = convert_neox_weights(hf_model, cfg) elif cfg.original_architecture == "LLaMAForCausalLM": state_dict = convert_llama_weights(hf_model, cfg) + elif cfg.original_architecture == "BertForMaskedLM": + state_dict = convert_bert_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." @@ -887,6 +904,36 @@ def get_pretrained_state_dict( return state_dict +def fill_missing_keys(model, state_dict): + """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization. + + This function is assumed to be run before weights are initialized. 
+ + Args: + state_dict (dict): State dict from a pretrained model + + Returns: + dict: State dict with missing keys filled in + """ + # Get the default state dict + default_state_dict = model.state_dict() + # Get the keys that are missing from the pretrained model + missing_keys = set(default_state_dict.keys()) - set(state_dict.keys()) + # Fill in the missing keys with the default initialization + for key in missing_keys: + if "hf_model" in key: + # Skip keys that are from the HuggingFace model, if loading from HF. + continue + if "W_" in key: + logging.warning( + "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format( + key + ) + ) + state_dict[key] = default_state_dict[key] + return state_dict + + # %% def convert_state_dict( state_dict: dict, @@ -964,7 +1011,7 @@ def convert_gpt2_weights(gpt2, cfg: HookedTransformerConfig): W_out = gpt2.transformer.h[l].mlp.c_proj.weight state_dict[f"blocks.{l}.mlp.W_out"] = W_out state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias - state_dict[f"unembed.W_U"] = gpt2.lm_head.weight.T + state_dict["unembed.W_U"] = gpt2.lm_head.weight.T state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias @@ -1270,8 +1317,8 @@ def convert_opt_weights(opt, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.mlp.b_in"] = opt.model.decoder.layers[l].fc1.bias state_dict[f"blocks.{l}.mlp.b_out"] = opt.model.decoder.layers[l].fc2.bias - state_dict[f"ln_final.w"] = opt.model.decoder.final_layer_norm.weight - state_dict[f"ln_final.b"] = opt.model.decoder.final_layer_norm.bias + state_dict["ln_final.w"] = opt.model.decoder.final_layer_norm.weight + state_dict["ln_final.b"] = opt.model.decoder.final_layer_norm.bias state_dict["unembed.W_U"] = opt.lm_head.weight.T state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab) return state_dict @@ -1368,7 +1415,7 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): state_dict[f"blocks.{l}.mlp.W_out"] = W_out.T state_dict[f"blocks.{l}.mlp.b_out"] = old_state_dict[f"blocks.{l}.mlp.2.bias"] - state_dict[f"unembed.W_U"] = old_state_dict["head.weight"].T + state_dict["unembed.W_U"] = old_state_dict["head.weight"].T state_dict["ln_final.w"] = old_state_dict["ln_f.weight"] state_dict["ln_final.b"] = old_state_dict["ln_f.bias"] @@ -1376,4 +1423,63 @@ def convert_mingpt_weights(old_state_dict, cfg: HookedTransformerConfig): return state_dict -# %% +def convert_bert_weights(bert, cfg: HookedTransformerConfig): + embeddings = bert.bert.embeddings + state_dict = { + "embed.embed.W_E": embeddings.word_embeddings.weight, + "embed.pos_embed.W_pos": embeddings.position_embeddings.weight, + "embed.token_type_embed.W_token_type": embeddings.token_type_embeddings.weight, + "embed.ln.w": embeddings.LayerNorm.weight, + "embed.ln.b": embeddings.LayerNorm.bias, + } + + for l in range(cfg.n_layers): + block = bert.bert.encoder.layer[l] + state_dict[f"blocks.{l}.attn.W_Q"] = einops.rearrange( + block.attention.self.query.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.b_Q"] = einops.rearrange( + block.attention.self.query.bias, "(i h) -> i h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.W_K"] = einops.rearrange( + block.attention.self.key.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.b_K"] = einops.rearrange( + block.attention.self.key.bias, "(i h) -> i h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.W_V"] = einops.rearrange( + 
block.attention.self.value.weight, "(i h) m -> i m h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.b_V"] = einops.rearrange( + block.attention.self.value.bias, "(i h) -> i h", i=cfg.n_heads + ) + state_dict[f"blocks.{l}.attn.W_O"] = einops.rearrange( + block.attention.output.dense.weight, + "m (i h) -> i h m", + i=cfg.n_heads, + ) + state_dict[f"blocks.{l}.attn.b_O"] = block.attention.output.dense.bias + state_dict[f"blocks.{l}.ln1.w"] = block.attention.output.LayerNorm.weight + state_dict[f"blocks.{l}.ln1.b"] = block.attention.output.LayerNorm.bias + state_dict[f"blocks.{l}.mlp.W_in"] = einops.rearrange( + block.intermediate.dense.weight, "mlp model -> model mlp" + ) + state_dict[f"blocks.{l}.mlp.b_in"] = block.intermediate.dense.bias + state_dict[f"blocks.{l}.mlp.W_out"] = einops.rearrange( + block.output.dense.weight, "model mlp -> mlp model" + ) + state_dict[f"blocks.{l}.mlp.b_out"] = block.output.dense.bias + state_dict[f"blocks.{l}.ln2.w"] = block.output.LayerNorm.weight + state_dict[f"blocks.{l}.ln2.b"] = block.output.LayerNorm.bias + + mlm_head = bert.cls.predictions + state_dict["mlm_head.W"] = mlm_head.transform.dense.weight + state_dict["mlm_head.b"] = mlm_head.transform.dense.bias + state_dict["mlm_head.ln.w"] = mlm_head.transform.LayerNorm.weight + state_dict["mlm_head.ln.b"] = mlm_head.transform.LayerNorm.bias + # Note: BERT uses tied embeddings + state_dict["unembed.W_U"] = embeddings.word_embeddings.weight.T + # "unembed.W_U": mlm_head.decoder.weight.T, + state_dict["unembed.b_U"] = mlm_head.bias + + return state_dict diff --git a/transformer_lens/utilities/devices.py b/transformer_lens/utilities/devices.py index bf260ba5f..881a63ecb 100644 --- a/transformer_lens/utilities/devices.py +++ b/transformer_lens/utilities/devices.py @@ -1,6 +1,13 @@ -from typing import Optional, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union import torch +from torch import nn + +if TYPE_CHECKING: + from transformer_lens.HookedEncoder import HookedEncoder + from transformer_lens.HookedTransformer import HookedTransformer from transformer_lens.HookedTransformerConfig import HookedTransformerConfig @@ -32,3 +39,29 @@ def get_device_for_block_index( device = torch.device(device) device_index = (device.index or 0) + (index // layers_per_device) return torch.device(device.type, device_index) + + +def move_to_and_update_config( + model: Union[HookedTransformer, HookedEncoder], + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details=True, +): + """ + Wrapper around to that also changes model.cfg.device if it's a torch.device or string. + If torch.dtype, just passes through + """ + if isinstance(device_or_dtype, torch.device): + model.cfg.device = device_or_dtype.type + if print_details: + print("Moving model to device: ", model.cfg.device) + elif isinstance(device_or_dtype, str): + model.cfg.device = device_or_dtype + if print_details: + print("Moving model to device: ", model.cfg.device) + elif isinstance(device_or_dtype, torch.dtype): + if print_details: + print("Changing model dtype to", device_or_dtype) + # change state_dict dtypes + for k, v in model.state_dict().items(): + model.state_dict()[k] = v.to(device_or_dtype) + return nn.Module.to(model, device_or_dtype)
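
Usage sketch (not part of the patch): the snippet below strings together the workflow this diff introduces, mirroring demos/BERT.ipynb and the acceptance tests above — loading the new HookedEncoder, predicting a masked token, and inspecting cached activations. The model name, prompt, and hook names are taken from the diff itself; treat the rest as an illustrative sketch rather than a definitive API reference.

# Minimal sketch of the HookedEncoder workflow added in this diff.
import torch
from transformers import AutoTokenizer
from transformer_lens import HookedEncoder

torch.set_grad_enabled(False)

# Load the pretrained BERT weights into the new encoder class.
bert = HookedEncoder.from_pretrained("bert-base-cased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# HookedEncoder currently accepts tokens only, not raw strings.
prompt = "BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding"
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

# Take logits at the [MASK] position; argmax gives the predicted token.
logits = bert(input_ids)
mask_logits = logits[input_ids == tokenizer.mask_token_id]
prediction = tokenizer.decode(mask_logits.argmax(dim=-1).item())
print(f"Prediction: {prediction!r}")

# Activations are exposed through HookPoints, e.g. via run_with_cache
# (hook names as exercised in tests/acceptance/test_hooked_encoder.py).
_, cache = bert.run_with_cache(input_ids)
assert "embed.hook_embed" in cache
assert "blocks.0.attn.hook_q" in cache
assert "mlm_head.ln.hook_normalized" in cache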