# comparisons.py (forked from HoagyC/sparse_coding)
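# Fits PCA directions over a saved activation dataset as a baseline dictionary,
# then trains sparse autoencoder dictionaries on the same activations
# (run.run_real_data_model) for comparison.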
# import run
from utils import *

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from transformers import GPT2Tokenizer
from transformer_lens import HookedTransformer
import pickle


class BatchedPCA():
    """Streaming PCA: keeps a running mean and covariance over batches of
    activations, then eigendecomposes the covariance to get principal directions."""

    def __init__(self, n_dims, device):
        super().__init__()
        self.n_dims = n_dims
        self.device = device

        self.cov = torch.zeros((n_dims, n_dims), device=device)
        self.mean = torch.zeros((n_dims,), device=device)
        self.n_samples = 0

    def train_iter(self, activations):
        # activations: (batch_size, n_dims)
        batch_size = activations.shape[0]
        corrected = activations - self.mean.unsqueeze(0)
        # incremental (Welford-style) updates of the running mean and covariance
        new_mean = self.mean + torch.mean(corrected, dim=0) * batch_size / (self.n_samples + batch_size)
        cov_update = torch.einsum("bi,bj->bij", corrected, activations - new_mean.unsqueeze(0)).mean(dim=0)
        self.cov = self.cov * (self.n_samples / (self.n_samples + batch_size)) + cov_update * batch_size / (self.n_samples + batch_size)
        self.mean = new_mean
        self.n_samples += batch_size

    def get_dictionary(self):
        # eigenvalues come back in ascending order; eigenvectors are the columns of eigvecs
        eigvals, eigvecs = torch.linalg.eigh(self.cov)
        return eigvals.detach().cpu().numpy(), eigvecs.detach().cpu().numpy()
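
# Illustrative usage (not executed here; the sizes are made up): BatchedPCA can be
# driven on any stream of (batch_size, n_dims) activation tensors, e.g.
#
#   pca = BatchedPCA(n_dims=512, device="cpu")
#   for _ in range(10):
#       pca.train_iter(torch.randn(64, 512))
#   eigvals, eigvecs = pca.get_dictionary()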


def run_pca_on_activation_dataset(cfg: dotdict, outputs_folder):
    # prelim dataset setup
    if len(os.listdir(cfg.dataset_folder)) == 0:
        print("activation dataset not found, throwing error")
        raise Exception("activation dataset not found")

    # read the activation width from the first chunk
    with open(os.path.join(cfg.dataset_folder, "0.pkl"), "rb") as f:
        dataset = pickle.load(f)
    cfg.mlp_width = dataset.tensors[0][0].shape[-1]
    n_lines = cfg.max_lines
    del dataset

    # actual pca
    pca_model = BatchedPCA(cfg.mlp_width, cfg.device)

    n_chunks_in_folder = len(os.listdir(cfg.dataset_folder))
    for chunk_id in range(n_chunks_in_folder):
        chunk_loc = os.path.join(cfg.dataset_folder, str(chunk_id) + ".pkl")
        # realistically can just load the whole thing into memory (only 2GB/chunk)
        print(chunk_loc)
        with open(chunk_loc, "rb") as f:
            chunk = pickle.load(f)
        dataset = DataLoader(chunk, batch_size=cfg.batch_size, shuffle=False)
        for batch in dataset:
            activations = batch[0].to(cfg.device)
            pca_model.train_iter(activations)

    # eigenvalues ("components") and eigenvectors ("directions") of the activation covariance
    pca_components, pca_directions = pca_model.get_dictionary()
    pca_directions_loc = os.path.join(outputs_folder, "pca_results.pkl")
    with open(pca_directions_loc, "wb") as f:
        pickle.dump((pca_directions, pca_components), f)

    return pca_directions, pca_components
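
# Illustrative sketch: the saved results can be reloaded later with
#
#   with open(os.path.join(outputs_folder, "pca_results.pkl"), "rb") as f:
#       pca_directions, pca_components = pickle.load(f)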


def main():
    from argparser import parse_args

    cfg = parse_args()
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    from datetime import datetime

    start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    outputs_folder = os.path.join(cfg.outputs_folder, start_time)
    os.makedirs(outputs_folder, exist_ok=True)

    # imported here rather than at module level (see the commented-out `import run` above)
    from run import setup_data, run_real_data_model, AutoEncoder

    # overrides of the parsed config for this comparison run
    cfg.model_name = "EleutherAI/pythia-70m-deduped"
    cfg.use_wandb = False
    cfg.dict_ratio_exp_high = 5
    data_split = "train"
    cfg.l1_exp_low = -16
    cfg.l1_exp_high = -14

    model = None
    if cfg.model_name in ["gpt2", "EleutherAI/pythia-70m-deduped"]:
        model = HookedTransformer.from_pretrained(cfg.model_name, device=cfg.device)
        use_baukit = False

    if hasattr(model, "tokenizer"):
        tokenizer = model.tokenizer
    else:
        print("Using default tokenizer from gpt2")
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    dataset_name = cfg.dataset_name.split("/")[-1] + "-" + cfg.model_name + "-" + str(cfg.layer)
    cfg.dataset_folder = os.path.join(cfg.datasets_folder, dataset_name)
    os.makedirs(cfg.dataset_folder, exist_ok=True)

    if len(os.listdir(cfg.dataset_folder)) == 0:
        print(f"Activations in {cfg.dataset_folder} do not exist, creating them")
        n_lines = setup_data(cfg, tokenizer, model, use_baukit=use_baukit, split=data_split)
    else:
        print(f"Activations in {cfg.dataset_folder} already exist, loading them")
        # get mlp_width from first file
        with open(os.path.join(cfg.dataset_folder, "0.pkl"), "rb") as f:
            dataset = pickle.load(f)
        cfg.mlp_width = dataset.tensors[0][0].shape[-1]
        n_lines = cfg.max_lines
        del dataset

    # do pca on activations
    pca_directions, pca_components = run_pca_on_activation_dataset(cfg, outputs_folder=outputs_folder)

    # train sparse autoencoder dictionaries on the same activation dataset
    run_real_data_model(cfg)


if __name__ == "__main__":
    main()