Commit
Introducing HookedEncoder, a BERT-style encoder inheriting from HookedRootModule. It supports ActivationCache, with hooks on all interesting activations. Weights can be loaded from the Hugging Face `bert-base-cased` pretrained model. Unlike HookedTransformer, it does not (yet) do any preprocessing of the weights (e.g. folding LayerNorm). Another difference is that the model can currently only be run on tokens, not on strings or lists of strings. The only supported task/architecture at present is masked language modelling; next sentence prediction, causal language modelling, and other tasks are not supported. HookedEncoder does not contain dropout layers, which may lead to inconsistent results when pretraining.

Changes:
- Add new class HookedEncoder
- Add new components:
  - TokenTypeEmbed
  - BertEmbed
  - BertMLMHead
  - BertBlock
- Add an `additive_attention_mask` parameter to the forward method of the `Attention` component (see the sketch after this list)
- Add the BERT config and state dict to loading_from_pretrained
- Extract methods from HookedTransformer for reuse:
  - devices.move_to_and_update_config
  - loading.fill_missing_keys
- Add demo notebook `demos/BERT.ipynb`
- Update the Available Models list in the Main Demo
- Testing:
  - Unit and acceptance tests for HookedEncoder and its sub-components
  - The new demo in `demos/BERT.ipynb` also acts as a test
  - Tests for some existing components, e.g. HookedTransformerConfig

Future work: #277
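The `additive_attention_mask` change is the least self-explanatory item above, and the commit does not show how `Attention.forward` consumes the new parameter. Additive masks are a standard technique, though: a tensor that is zero at positions attention may attend to and a large negative number at positions it should ignore (typically padding), added to the attention scores before the softmax. The sketch below illustrates the general idea only; the helper name, shapes, and constant are illustrative assumptions, not code from this commit.

```python
import torch

def make_additive_attention_mask(padding_mask: torch.Tensor, big_negative: float = -1e5) -> torch.Tensor:
    """Convert a Hugging Face-style padding mask (1 = real token, 0 = padding)
    into an additive mask that broadcasts over attention scores.

    Illustrative helper only; not the implementation used in this commit.
    """
    # Shape [batch, 1, 1, key_pos] so it broadcasts over head and query dimensions.
    return (1 - padding_mask[:, None, None, :]) * big_negative

# Usage sketch: scores has shape [batch, head, query_pos, key_pos].
scores = torch.randn(2, 12, 5, 5)
padding_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
masked_scores = scores + make_additive_attention_mask(padding_mask)
pattern = masked_scores.softmax(dim=-1)  # padded key positions receive ~0 attention
```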
Showing 15 changed files with 1,203 additions and 63 deletions.
@@ -15,6 +15,6 @@ _modidx.py
 env
 dist/
 docs/build
-.coverage
+.coverage*
 .Ds_Store
 .pylintrc
@@ -0,0 +1,258 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT in TransformerLens\n",
    "This demo shows how to use BERT in TransformerLens for the Masked Language Modelling task."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup\n",
    "(No need to read)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running as a Jupyter notebook - intended for development only!\n"
     ]
    }
   ],
   "source": [
    "# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
    "DEVELOPMENT_MODE = False\n",
    "try:\n",
    "    import google.colab\n",
    "    IN_COLAB = True\n",
    "    print(\"Running as a Colab notebook\")\n",
    "    %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
    "    %pip install circuitsvis\n",
    " \n",
    "    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
    "    # # Install another version of node that makes PySvelte work way faster\n",
    "    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
    "    # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
    "except:\n",
    "    IN_COLAB = False\n",
    "    print(\"Running as a Jupyter notebook - intended for development only!\")\n",
    "    from IPython import get_ipython\n",
    "\n",
    "    ipython = get_ipython()\n",
    "    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
    "    ipython.magic(\"load_ext autoreload\")\n",
    "    ipython.magic(\"autoreload 2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using renderer: colab\n"
     ]
    }
   ],
   "source": [
    "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
    "import plotly.io as pio\n",
    "if IN_COLAB or not DEVELOPMENT_MODE:\n",
    "    pio.renderers.default = \"colab\"\n",
    "else:\n",
    "    pio.renderers.default = \"notebook_connected\"\n",
    "print(f\"Using renderer: {pio.renderers.default}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div id=\"circuits-vis-3dab7238-6dd6\" style=\"margin: 15px 0;\"/>\n",
       " <script crossorigin type=\"module\">\n",
       " import { render, Hello } from \"https://unpkg.com/[email protected]/dist/cdn/esm.js\";\n",
       " render(\n",
       " \"circuits-vis-3dab7238-6dd6\",\n",
       " Hello,\n",
       " {\"name\": \"Neel\"}\n",
       " )\n",
       " </script>"
      ],
      "text/plain": [
       "<circuitsvis.utils.render.RenderedHTML at 0x1090aa4d0>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import circuitsvis as cv\n",
    "# Testing that the library works\n",
    "cv.examples.hello(\"Neel\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import stuff\n",
    "import torch\n",
    "\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "from transformer_lens import HookedEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x104e56b60>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BERT\n",
    "\n",
    "In this section, we will load a pretrained BERT model and use it for the Masked Language Modelling task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Moving model to device: cpu\n",
      "Loaded pretrained model bert-base-cased into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the \"[MASK]\" token to mask any tokens which you would like the model to predict."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\"\n",
    "\n",
    "input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
    "mask_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\n",
      "Prediction: \"Systems\"\n"
     ]
    }
   ],
   "source": [
    "logprobs = bert(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n",
    "prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
    "\n",
    "print(f\"Prompt: {prompt}\")\n",
    "print(f\"Prediction: \\\"{prediction}\\\"\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Better luck next time, BERT."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
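The commit message states that HookedEncoder supports ActivationCache with hooks on all interesting activations, but the demo above stops at the masked-token prediction. Below is a minimal sketch of how the cache might be inspected, assuming that `run_with_cache` is inherited from HookedRootModule and returns `(output, cache)` as it does for HookedTransformer; the hook name used here follows the HookedTransformer naming convention and is an assumption rather than something shown in the commit.

```python
import torch
from transformers import AutoTokenizer
from transformer_lens import HookedEncoder

torch.set_grad_enabled(False)
bert = HookedEncoder.from_pretrained("bert-base-cased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

prompt = "BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding"
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

# Assumption: run_with_cache is inherited from HookedRootModule and behaves
# as it does for HookedTransformer, returning the output and an ActivationCache.
logits, cache = bert.run_with_cache(input_ids)

# Assumption: hook names follow the HookedTransformer convention, e.g. the
# attention pattern of the first BertBlock.
pattern = cache["blocks.0.attn.hook_pattern"]
print(pattern.shape)  # expected: [batch, n_heads, seq_len, seq_len]
```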