Commit 0f3e5e8

added training script and precommit scripts. Updated readme

Butanium committed Nov 18, 2024
1 parent 12e8427 commit 0f3e5e8

Showing 4 changed files with 170 additions and 0 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/black-format-check.yml
@@ -0,0 +1,19 @@
name: Format Check

on: [push, pull_request]

jobs:
  black:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.x'
      - name: Install dependencies
        run: |
          pip install black
      - name: Check formatting
        run: |
          black --check .
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,5 @@
repos:
  - repo: https://github.com/psf/black
    rev: stable
    hooks:
      - id: black
10 changes: 10 additions & 0 deletions README.md
@@ -1,3 +1,13 @@
# Dictionary Learning and Crosscoders
This repo contains a few new features compared to the original repo:
- It is `pip` installable.
- A new `CrossCoder` class for training crosscoders as described in [the Anthropic paper](https://transformer-circuits.pub/drafts/crosscoders/index.html#model-diffing).
- Activation caching in `cache.py`: store model activations on disk once, then reload them later to train an SAE or CrossCoder.
- A script for training a CrossCoder from pre-computed activations in `scripts/train_crosscoder.py` (a minimal usage sketch follows this list).
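
As a rough sketch of how these pieces fit together (the imports match `scripts/train_crosscoder.py`; the activation directories and checkpoint filename are placeholders you would substitute with your own):

```python
from pathlib import Path

from dictionary_learning import CrossCoder
from dictionary_learning.cache import PairedActivationCache

# Pair up activations cached for two models on the same dataset (placeholder paths).
cache = PairedActivationCache(
    Path("activations/gemma-2-2b/fineweb/layer_13_out"),
    Path("activations/gemma-2-2b-it/fineweb/layer_13_out"),
)

# Reload a trained CrossCoder from a saved checkpoint (hypothetical filename).
crosscoder = CrossCoder.from_pretrained("checkpoints/crosscoder.pt")
```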


# Original README

This is a repository for doing dictionary learning via sparse autoencoders on neural network activations. It was developed by Samuel Marks and Aaron Mueller.

For accessing, saving, and intervening on NN activations, we use the [`nnsight`](http://nnsight.net/) package; as of March 2024, `nnsight` is under active development and may undergo breaking changes. That said, `nnsight` is easy to use and quick to learn; if you plan to modify this repo, then we recommend going through the main `nnsight` demo [here](https://nnsight.net/notebooks/tutorials/walkthrough/).
136 changes: 136 additions & 0 deletions dictionary_learning/scripts/train_crosscoder.py
@@ -0,0 +1,136 @@
"""
Train a Crosscoder using pre-computed activations.
Activations are assumed to be stored in the directory specified by `--activation-store-dir`, organized by model and dataset:
activations/<base-model>/<dataset>/<submodule-name>/
"""

import argparse
from pathlib import Path

import torch as th

from dictionary_learning import CrossCoder
from dictionary_learning.cache import PairedActivationCache
from dictionary_learning.trainers import CrossCoderTrainer
from dictionary_learning.training import trainSAE

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--activation-store-dir", type=str, default="activations")
    parser.add_argument("--base-model", type=str, default="gemma-2-2b")
    parser.add_argument("--instruct-model", type=str, default="gemma-2-2b-it")
    parser.add_argument("--layer", type=int, default=13)
    parser.add_argument("--wandb-entity", type=str, default="")
    parser.add_argument("--disable-wandb", action="store_true")
    parser.add_argument("--expansion-factor", type=int, default=32)
    parser.add_argument("--batch-size", type=int, default=2048)
    parser.add_argument("--workers", type=int, default=32)
    parser.add_argument("--mu", type=float, default=1e-1)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-steps", type=int, default=None)
    parser.add_argument("--validate-every-n-steps", type=int, default=10000)
    parser.add_argument("--same-init-for-all-layers", action="store_true")
    parser.add_argument("--norm-init-scale", type=float, default=0.005)
    parser.add_argument("--init-with-transpose", action="store_true")
    parser.add_argument("--run-name", type=str, default=None)
    parser.add_argument("--resample-steps", type=int, default=None)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--pretrained", type=str, default=None)
    parser.add_argument("--encoder-layers", type=int, default=None, nargs="+")
    parser.add_argument(
        "--dataset", type=str, nargs="+", default=["fineweb", "lmsys_chat"]
    )
    args = parser.parse_args()

print(f"Training args: {args}")
th.manual_seed(args.seed)
th.cuda.manual_seed_all(args.seed)

activation_store_dir = Path(args.activation_store_dir)

base_model_dir = activation_store_dir / args.base_model
instruct_model_dir = activation_store_dir / args.instruct_model
caches = []
submodule_name = f"layer_{args.layer}_out"

    # Each dataset contributes a PairedActivationCache that pairs, token by token,
    # the cached base-model and instruct-model activations for the chosen layer.
    for dataset in args.dataset:
        base_model_dataset = base_model_dir / dataset
        instruct_model_dataset = instruct_model_dir / dataset
        caches.append(
            PairedActivationCache(
                base_model_dataset / submodule_name,
                instruct_model_dataset / submodule_name,
            )
        )

    dataset = th.utils.data.ConcatDataset(caches)

    # Each item stacks the two models' activations, so shape[1] is the activation dim.
    activation_dim = dataset[0].shape[1]
    dictionary_size = args.expansion_factor * activation_dim
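    # For example, assuming gemma-2-2b's hidden size of 2304, the default
    # expansion factor of 32 gives a dictionary of 32 * 2304 = 73728 features.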

    device = "cuda" if th.cuda.is_available() else "cpu"
    print(f"Training on device={device}.")
    trainer_cfg = {
        "trainer": CrossCoderTrainer,
        "dict_class": CrossCoder,
        "activation_dim": activation_dim,
        "dict_size": dictionary_size,
        "lr": args.lr,
        "resample_steps": args.resample_steps,
        "device": device,
        "warmup_steps": 1000,
        "layer": args.layer,
        "lm_name": f"{args.instruct_model}-{args.base_model}",
        "compile": True,
        # With the defaults this yields run names like "L13-mu1.0e-01-lr1e-03".
        "wandb_name": f"L{args.layer}-mu{args.mu:.1e}-lr{args.lr:.0e}"
        + (f"-{args.run_name}" if args.run_name is not None else ""),
        "l1_penalty": args.mu,
        "dict_class_kwargs": {
            "same_init_for_all_layers": args.same_init_for_all_layers,
            "norm_init_scale": args.norm_init_scale,
            "init_with_transpose": args.init_with_transpose,
            "encoder_layers": args.encoder_layers,
        },
        "pretrained_ae": (
            CrossCoder.from_pretrained(args.pretrained)
            if args.pretrained is not None
            else None
        ),
    }

    # Hold out a fixed number of token activations for validation.
    validation_size = 10**6
    train_dataset, validation_dataset = th.utils.data.random_split(
        dataset, [len(dataset) - validation_size, validation_size]
    )
    print(f"Training on {len(train_dataset)} token activations.")
    dataloader = th.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )
    validation_dataloader = th.utils.data.DataLoader(
        validation_dataset,
        batch_size=8192,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    # Train the CrossCoder (the generic SAE training loop handles both).
    ae = trainSAE(
        data=dataloader,
        trainer_config=trainer_cfg,
        validate_every_n_steps=args.validate_every_n_steps,
        validation_data=validation_dataloader,
        use_wandb=not args.disable_wandb,
        wandb_entity=args.wandb_entity,
        wandb_project="crosscoder",
        log_steps=50,
        save_dir="checkpoints",
        steps=args.max_steps,
        save_steps=args.validate_every_n_steps,
    )
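
    # Example invocation, assuming activations were cached with the defaults above
    # (adjust paths and model names to your setup):
    #   python dictionary_learning/scripts/train_crosscoder.py \
    #       --activation-store-dir activations --layer 13 \
    #       --base-model gemma-2-2b --instruct-model gemma-2-2b-it \
    #       --dataset fineweb lmsys_chat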
