Skip to content

Commit

Permalink
Introducing HookedEncoder (#276)
Browse files Browse the repository at this point in the history
Introducing HookedEncoder, a BERT-style encoder inheriting from HookedRootModule. Supports ActivationCache, with hooks on all interesting activations. Weights can be loaded from the huggingface bert-base-cased pretrained model. 

Unlike HookedTransformer, it does not (yet) do any pre-processing of the weights (e.g. folding LayerNorm). Another difference is that the model can currently only be run with tokens, and not strings or list of strings. Currently, the supported task/architecture is masked language modelling. Next sentence prediction, causal language modelling, and other tasks are not supported. The HookedEncoder does not contain dropouts, which may lead to inconsistent results when pretraining.

Changes:
- Add new class HookedEncoder
- Add new components
	- TokenTypeEmbed
	- BertEmbed
	- BertMLMHead
	- BertBlock
 - Add `additive_attention_mask` parameter to forward method of `Attention` component
 - Add BERT config and state dict to loading_from_pretrained
- Extract methods from HookedTransformer for reuse: 
	- devices.move_to_and_update_config
	- lodaing.fill_missing_keys
- Add demo notebook `demos/BERT.ipynb`
- Update Available Models list in Main Demo
- Testing
	- Unit and acceptance tests for HookedEncoder and sub-components
	- New demo in `demos/BERT.ipynb` also acts as a test
	- I also added some tests for existing components e.g. HookedTransformerConfig

Future work: #277
  • Loading branch information
rusheb authored May 19, 2023
1 parent 966ce1c commit c268a71
Show file tree
Hide file tree
Showing 15 changed files with 1,203 additions and 63 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ _modidx.py
env
dist/
docs/build
.coverage
.coverage*
.Ds_Store
.pylintrc
258 changes: 258 additions & 0 deletions demos/BERT.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# BERT in TransformerLens\n",
"This demo shows how to use BERT in TransformerLens for the Masked Language Modelling task."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup\n",
"(No need to read)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Running as a Jupyter notebook - intended for development only!\n"
]
}
],
"source": [
"# Janky code to do different setup when run in a Colab notebook vs VSCode\n",
"DEVELOPMENT_MODE = False\n",
"try:\n",
" import google.colab\n",
" IN_COLAB = True\n",
" print(\"Running as a Colab notebook\")\n",
" %pip install git+https://github.com/neelnanda-io/TransformerLens.git\n",
" %pip install circuitsvis\n",
" \n",
" # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working\n",
" # # Install another version of node that makes PySvelte work way faster\n",
" # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n",
" # %pip install git+https://github.com/neelnanda-io/PySvelte.git\n",
"except:\n",
" IN_COLAB = False\n",
" print(\"Running as a Jupyter notebook - intended for development only!\")\n",
" from IPython import get_ipython\n",
"\n",
" ipython = get_ipython()\n",
" # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
" ipython.magic(\"load_ext autoreload\")\n",
" ipython.magic(\"autoreload 2\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using renderer: colab\n"
]
}
],
"source": [
"# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
"import plotly.io as pio\n",
"if IN_COLAB or not DEVELOPMENT_MODE:\n",
" pio.renderers.default = \"colab\"\n",
"else:\n",
" pio.renderers.default = \"notebook_connected\"\n",
"print(f\"Using renderer: {pio.renderers.default}\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div id=\"circuits-vis-3dab7238-6dd6\" style=\"margin: 15px 0;\"/>\n",
" <script crossorigin type=\"module\">\n",
" import { render, Hello } from \"https://unpkg.com/[email protected]/dist/cdn/esm.js\";\n",
" render(\n",
" \"circuits-vis-3dab7238-6dd6\",\n",
" Hello,\n",
" {\"name\": \"Neel\"}\n",
" )\n",
" </script>"
],
"text/plain": [
"<circuitsvis.utils.render.RenderedHTML at 0x1090aa4d0>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import circuitsvis as cv\n",
"# Testing that the library works\n",
"cv.examples.hello(\"Neel\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Import stuff\n",
"import torch\n",
"\n",
"from transformers import AutoTokenizer\n",
"\n",
"from transformer_lens import HookedEncoder"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch.autograd.grad_mode.set_grad_enabled at 0x104e56b60>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.set_grad_enabled(False)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# BERT\n",
"\n",
"In this section, we will load a pretrained BERT model and use it for the Masked Language Modelling task"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:HookedEncoder is still in beta. Please be aware that model preprocessing (e.g. LayerNorm folding) is not yet supported and backward compatibility is not guaranteed.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Moving model to device: cpu\n",
"Loaded pretrained model bert-base-cased into HookedTransformer\n"
]
}
],
"source": [
"bert = HookedEncoder.from_pretrained(\"bert-base-cased\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the \"[MASK]\" token to mask any tokens which you would like the model to predict."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\"\n",
"\n",
"input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n",
"mask_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prompt: BERT: Pre-training of Deep Bidirectional [MASK] for Language Understanding\n",
"Prediction: \"Systems\"\n"
]
}
],
"source": [
"logprobs = bert(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n",
"prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n",
"\n",
"print(f\"Prompt: {prompt}\")\n",
"print(f\"Prediction: \\\"{prediction}\\\"\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Better luck next time, BERT."
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 4 additions & 0 deletions demos/Main_Demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -1110,6 +1111,9 @@
"* **GPT-NeoX** - Eleuther's 20B parameter model, trained on the Pile\n",
"* **Stanford CRFM models** - a replication of GPT-2 Small and GPT-2 Medium, trained on 5 different random seeds.\n",
" * Notably, 600 checkpoints were taken during training per model, and these are available in the library with eg `HookedTransformer.from_pretrained(\"stanford-gpt2-small-a\", checkpoint_index=265)`.\n",
"- **BERT** - Google's bidirectional encoder-only transformer.\n",
" - Size Base (108M), trained on English Wikipedia and BooksCorpus.\n",
" \n",
"</details>"
]
},
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ docs = ["Sphinx", "sphinx-autobuild", "sphinxcontrib-napoleon", "furo", "myst_pa

[tool.pytest.ini_options]
filterwarnings = [
"ignore:pkg_resources is deprecated as an API:DeprecationWarning",
# Ignore numpy.distutils deprecation warning caused by pandas
# More info: https://numpy.org/doc/stable/reference/distutils.html#module-numpy.distutils
"ignore:distutils Version classes are deprecated:DeprecationWarning"
Expand Down
Loading

0 comments on commit c268a71

Please sign in to comment.