diff --git a/dictionary_learning/__init__.py b/dictionary_learning/__init__.py index 35d543d..2e4e46a 100644 --- a/dictionary_learning/__init__.py +++ b/dictionary_learning/__init__.py @@ -1,2 +1,2 @@ from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder, CrossCoder -from .buffer import ActivationBuffer \ No newline at end of file +from .buffer import ActivationBuffer diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py index 86f24f9..e0ba458 100644 --- a/dictionary_learning/buffer.py +++ b/dictionary_learning/buffer.py @@ -6,9 +6,9 @@ from .config import DEBUG if DEBUG: - tracer_kwargs = {'scan' : True, 'validate' : True} + tracer_kwargs = {"scan": True, "validate": True} else: - tracer_kwargs = {'scan' : False, 'validate' : False} + tracer_kwargs = {"scan": False, "validate": False} class ActivationBuffer: @@ -16,30 +16,34 @@ class ActivationBuffer: Implements a buffer of activations. The buffer stores activations from a model, yields them in batches, and refreshes them when the buffer is less than half full. """ - def __init__(self, - data, # generator which yields text data - model : LanguageModel, # LanguageModel from which to extract activations - submodule, # submodule of the model from which to extract activations - d_submodule=None, # submodule dimension; if None, try to detect automatically - io='out', # can be 'in' or 'out'; whether to extract input or output activations - n_ctxs=3e4, # approximate number of contexts to store in the buffer - ctx_len=128, # length of each context - refresh_batch_size=512, # size of batches in which to process the data when adding to buffer - out_batch_size=8192, # size of batches in which to yield activations - device='cpu' # device on which to store the activations - ): - - if io not in ['in', 'out']: + + def __init__( + self, + data, # generator which yields text data + model: LanguageModel, # LanguageModel from which to extract activations + submodule, # submodule of the model from which to extract activations + d_submodule=None, # submodule dimension; if None, try to detect automatically + io="out", # can be 'in' or 'out'; whether to extract input or output activations + n_ctxs=3e4, # approximate number of contexts to store in the buffer + ctx_len=128, # length of each context + refresh_batch_size=512, # size of batches in which to process the data when adding to buffer + out_batch_size=8192, # size of batches in which to yield activations + device="cpu", # device on which to store the activations + ): + + if io not in ["in", "out"]: raise ValueError("io must be either 'in' or 'out'") if d_submodule is None: try: - if io == 'in': + if io == "in": d_submodule = submodule.in_features else: d_submodule = submodule.out_features except: - raise ValueError("d_submodule cannot be inferred and must be specified directly") + raise ValueError( + "d_submodule cannot be inferred and must be specified directly" + ) self.activations = t.empty(0, d_submodule, device=device) self.read = t.zeros(0).bool() @@ -54,7 +58,7 @@ def __init__(self, self.refresh_batch_size = refresh_batch_size self.out_batch_size = out_batch_size self.device = device - + def __iter__(self): return self @@ -69,10 +73,12 @@ def __next__(self): # return a batch unreads = (~self.read).nonzero().squeeze() - idxs = unreads[t.randperm(len(unreads), device=unreads.device)[:self.out_batch_size]] + idxs = unreads[ + t.randperm(len(unreads), device=unreads.device)[: self.out_batch_size] + ] self.read[idxs] = True return self.activations[idxs] - + def 
text_batch(self, batch_size=None): """ Return a list of text @@ -80,12 +86,10 @@ def text_batch(self, batch_size=None): if batch_size is None: batch_size = self.refresh_batch_size try: - return [ - next(self.data) for _ in range(batch_size) - ] + return [next(self.data) for _ in range(batch_size)] except StopIteration: raise StopIteration("End of data stream reached") - + def tokenized_batch(self, batch_size=None): """ Return a batch of tokenized inputs. @@ -93,10 +97,10 @@ def tokenized_batch(self, batch_size=None): texts = self.text_batch(batch_size=batch_size) return self.model.tokenizer( texts, - return_tensors='pt', + return_tensors="pt", max_length=self.ctx_len, padding=True, - truncation=True + truncation=True, ) def refresh(self): @@ -105,7 +109,9 @@ def refresh(self): self.activations = self.activations[~self.read] current_idx = len(self.activations) - new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device) + new_activations = t.empty( + self.activation_buffer_size, self.d_submodule, device=self.device + ) new_activations[: len(self.activations)] = self.activations self.activations = new_activations @@ -135,8 +141,8 @@ def refresh(self): assert remaining_space > 0 hidden_states = hidden_states[:remaining_space] - self.activations[current_idx : current_idx + len(hidden_states)] = hidden_states.to( - self.device + self.activations[current_idx : current_idx + len(hidden_states)] = ( + hidden_states.to(self.device) ) current_idx += len(hidden_states) @@ -148,13 +154,13 @@ def refresh(self): @property def config(self): return { - 'd_submodule' : self.d_submodule, - 'io' : self.io, - 'n_ctxs' : self.n_ctxs, - 'ctx_len' : self.ctx_len, - 'refresh_batch_size' : self.refresh_batch_size, - 'out_batch_size' : self.out_batch_size, - 'device' : self.device + "d_submodule": self.d_submodule, + "io": self.io, + "n_ctxs": self.n_ctxs, + "ctx_len": self.ctx_len, + "refresh_batch_size": self.refresh_batch_size, + "out_batch_size": self.out_batch_size, + "device": self.device, } def close(self): @@ -166,28 +172,30 @@ def close(self): class HeadActivationBuffer: """ - This is specifically designed for training SAEs for individual attn heads in Llama3. + This is specifically designed for training SAEs for individual attn heads in Llama3. Much redundant code; can eventually be merged to ActivationBuffer. Implements a buffer of activations. The buffer stores activations from a model, yields them in batches, and refreshes them when the buffer is less than half full. 
""" - def __init__(self, - data, # generator which yields text data - model : LanguageModel, # LanguageModel from which to extract activations - layer, # submodule of the model from which to extract activations - n_ctxs=3e4, # approximate number of contexts to store in the buffer - ctx_len=128, # length of each context - refresh_batch_size=512, # size of batches in which to process the data when adding to buffer - out_batch_size=8192, # size of batches in which to yield activations - device='cpu', # device on which to store the activations - apply_W_O = False, - remote = False, - ): - + + def __init__( + self, + data, # generator which yields text data + model: LanguageModel, # LanguageModel from which to extract activations + layer, # submodule of the model from which to extract activations + n_ctxs=3e4, # approximate number of contexts to store in the buffer + ctx_len=128, # length of each context + refresh_batch_size=512, # size of batches in which to process the data when adding to buffer + out_batch_size=8192, # size of batches in which to yield activations + device="cpu", # device on which to store the activations + apply_W_O=False, + remote=False, + ): + self.layer = layer self.n_heads = model.config.num_attention_heads - self.resid_dim = model.config.hidden_size - self.head_dim = self.resid_dim //self.n_heads + self.resid_dim = model.config.hidden_size + self.head_dim = self.resid_dim // self.n_heads self.data = data self.model = model self.n_ctxs = n_ctxs @@ -198,9 +206,11 @@ def __init__(self, self.apply_W_O = apply_W_O self.remote = remote - self.activations = t.empty(0, self.n_heads, self.head_dim, device=device) # [seq-pos, n_layers, n_head, head_dim] + self.activations = t.empty( + 0, self.n_heads, self.head_dim, device=device + ) # [seq-pos, n_layers, n_head, head_dim] self.read = t.zeros(0).bool() - + def __iter__(self): return self @@ -215,10 +225,12 @@ def __next__(self): # return a batch unreads = (~self.read).nonzero().squeeze() - idxs = unreads[t.randperm(len(unreads), device=unreads.device)[:self.out_batch_size]] + idxs = unreads[ + t.randperm(len(unreads), device=unreads.device)[: self.out_batch_size] + ] self.read[idxs] = True return self.activations[idxs] - + def text_batch(self, batch_size=None): """ Return a list of text @@ -226,12 +238,10 @@ def text_batch(self, batch_size=None): if batch_size is None: batch_size = self.refresh_batch_size try: - return [ - next(self.data) for _ in range(batch_size) - ] + return [next(self.data) for _ in range(batch_size)] except StopIteration: raise StopIteration("End of data stream reached") - + def tokenized_batch(self, batch_size=None): """ Return a batch of tokenized inputs. 
@@ -239,10 +249,10 @@ def tokenized_batch(self, batch_size=None): texts = self.text_batch(batch_size=batch_size) return self.model.tokenizer( texts, - return_tensors='pt', + return_tensors="pt", max_length=self.ctx_len, padding=True, - truncation=True + truncation=True, ) def refresh(self): @@ -250,43 +260,67 @@ def refresh(self): while len(self.activations) < self.n_ctxs * self.ctx_len: with t.no_grad(): - with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote): + with self.model.trace( + self.text_batch(), + **tracer_kwargs, + invoker_args={"truncation": True, "max_length": self.ctx_len}, + remote=self.remote, + ): input = self.model.input.save() - hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save() + hidden_states = self.model.model.layers[ + self.layer + ].self_attn.o_proj.input[0][ + 0 + ] # .save() if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] # Reshape by head - new_shape = hidden_states.size()[:-1] + (self.n_heads, self.head_dim) # (batch_size, seq_len, n_heads, head_dim) + new_shape = hidden_states.size()[:-1] + ( + self.n_heads, + self.head_dim, + ) # (batch_size, seq_len, n_heads, head_dim) hidden_states = hidden_states.view(*new_shape) # Optionally map from head dim to resid dim if self.apply_W_O: - hidden_states_W_O_shape = hidden_states.size()[:-1] + (self.model.config.hidden_size,) # (batch_size, seq_len, n_heads, resid_dim) - hidden_states_W_O = t.zeros(hidden_states_W_O_shape, device=hidden_states.device) - for h in range (self.n_heads): - start = h*self.head_dim - end = (h+1)*self.head_dim - hidden_states_W_O[..., h, start:end] = hidden_states[..., h, :] - hidden_states = self.model.model.layers[self.layer].self_attn.o_proj(hidden_states_W_O).save() + hidden_states_W_O_shape = hidden_states.size()[:-1] + ( + self.model.config.hidden_size, + ) # (batch_size, seq_len, n_heads, resid_dim) + hidden_states_W_O = t.zeros( + hidden_states_W_O_shape, device=hidden_states.device + ) + for h in range(self.n_heads): + start = h * self.head_dim + end = (h + 1) * self.head_dim + hidden_states_W_O[..., h, start:end] = hidden_states[ + ..., h, : + ] + hidden_states = ( + self.model.model.layers[self.layer] + .self_attn.o_proj(hidden_states_W_O) + .save() + ) # Apply attention mask - attn_mask = input.value[1]['attention_mask'] + attn_mask = input.value[1]["attention_mask"] hidden_states = hidden_states[attn_mask != 0] # Save results - self.activations = t.cat([self.activations, hidden_states.to(self.device)], dim=0) + self.activations = t.cat( + [self.activations, hidden_states.to(self.device)], dim=0 + ) self.read = t.zeros(len(self.activations), dtype=t.bool, device=self.device) @property def config(self): return { - 'layer': self.layer, - 'n_ctxs' : self.n_ctxs, - 'ctx_len' : self.ctx_len, - 'refresh_batch_size' : self.refresh_batch_size, - 'out_batch_size' : self.out_batch_size, - 'device' : self.device + "layer": self.layer, + "n_ctxs": self.n_ctxs, + "ctx_len": self.ctx_len, + "refresh_batch_size": self.refresh_batch_size, + "out_batch_size": self.out_batch_size, + "device": self.device, } def close(self): @@ -326,8 +360,10 @@ def __init__( else: d_submodule = submodule.out_features except: - raise ValueError("d_submodule cannot be inferred and must be specified directly") - + raise ValueError( + "d_submodule cannot be inferred and must be specified directly" + ) + if io in ["in", "out"]: self.activations = t.empty(0, d_submodule, 
device=device) elif io == "in_and_out": @@ -360,18 +396,23 @@ def __next__(self): # return a batch unreads = (~self.read).nonzero().squeeze() - idxs = unreads[t.randperm(len(unreads), device=unreads.device)[: self.out_batch_size]] + idxs = unreads[ + t.randperm(len(unreads), device=unreads.device)[: self.out_batch_size] + ] self.read[idxs] = True return self.activations[idxs] - def tokenized_batch(self, batch_size=None): """ Return a batch of tokenized inputs. """ texts = self.text_batch(batch_size=batch_size) return self.model.tokenizer( - texts, return_tensors="pt", max_length=self.ctx_len, padding=True, truncation=True + texts, + return_tensors="pt", + max_length=self.ctx_len, + padding=True, + truncation=True, ) def token_batch(self, batch_size=None): @@ -381,10 +422,12 @@ def token_batch(self, batch_size=None): if batch_size is None: batch_size = self.refresh_batch_size try: - return t.tensor([next(self.data) for _ in range(batch_size)], device=self.device) + return t.tensor( + [next(self.data) for _ in range(batch_size)], device=self.device + ) except StopIteration: raise StopIteration("End of data stream reached") - + def text_batch(self, batch_size=None): """ Return a list of text @@ -425,10 +468,16 @@ def refresh(self): elif self.io == "out": hidden_states = self._reshaped_activations(hidden_states_out) elif self.io == "in_and_out": - hidden_states_in = self._reshaped_activations(hidden_states_in).unsqueeze(1) - hidden_states_out = self._reshaped_activations(hidden_states_out).unsqueeze(1) + hidden_states_in = self._reshaped_activations( + hidden_states_in + ).unsqueeze(1) + hidden_states_out = self._reshaped_activations( + hidden_states_out + ).unsqueeze(1) hidden_states = t.cat([hidden_states_in, hidden_states_out], dim=1) - self.activations = t.cat([self.activations, hidden_states.to(self.device)], dim=0) + self.activations = t.cat( + [self.activations, hidden_states.to(self.device)], dim=0 + ) self.read = t.zeros(len(self.activations), dtype=t.bool, device=self.device) @property diff --git a/dictionary_learning/cache.py b/dictionary_learning/cache.py index a8e2455..18f2bbe 100644 --- a/dictionary_learning/cache.py +++ b/dictionary_learning/cache.py @@ -10,18 +10,21 @@ import json from .config import DEBUG + if DEBUG: - tracer_kwargs = {'scan' : True, 'validate' : True} + tracer_kwargs = {"scan": True, "validate": True} else: - tracer_kwargs = {'scan' : False, 'validate' : False} + tracer_kwargs = {"scan": False, "validate": False} class ActivationShard: - def __init__(self, store_dir : str, shard_idx : int): + def __init__(self, store_dir: str, shard_idx: int): self.shard_file = os.path.join(store_dir, f"shard_{shard_idx}.memmap") with open(self.shard_file.replace(".memmap", ".meta"), "r") as f: self.shape = tuple(json.load(f)["shape"]) - self.activations = np.memmap(self.shard_file, dtype=np.float32, mode='r', shape=self.shape) + self.activations = np.memmap( + self.shard_file, dtype=np.float32, mode="r", shape=self.shape + ) def __len__(self): return self.activations.shape[0] @@ -29,94 +32,138 @@ def __len__(self): def __getitem__(self, *indices): return th.tensor(self.activations[*indices], dtype=th.float32) + class ActivationCache: - def __init__(self, store_dir : str): + def __init__(self, store_dir: str): self.store_dir = store_dir self.config = json.load(open(os.path.join(store_dir, "config.json"), "r")) - self.shards = [ActivationShard(store_dir, i) for i in range(self.config["shard_count"])] + self.shards = [ + ActivationShard(store_dir, i) for i in 
range(self.config["shard_count"]) + ] self._range_to_shard_idx = np.cumsum([0] + [s.shape[0] for s in self.shards]) - + def __len__(self): return self.config["total_size"] - - def __getitem__(self, index : int): + + def __getitem__(self, index: int): shard_idx = np.searchsorted(self._range_to_shard_idx, index, side="right") - 1 offset = index - self._range_to_shard_idx[shard_idx] shard = self.shards[shard_idx] return shard[offset] - + @staticmethod - def get_activations(submodule : nn.Module, io : str): + def get_activations(submodule: nn.Module, io: str): if io == "in": return submodule.input[0] else: return submodule.output[0] @staticmethod - def collate_store_shards(store_dirs : Tuple[str], shard_count : int, activation_cache : List[th.Tensor], submodule_names : Tuple[str], shuffle_shards : bool = True, io : str = "out"): + def collate_store_shards( + store_dirs: Tuple[str], + shard_count: int, + activation_cache: List[th.Tensor], + submodule_names: Tuple[str], + shuffle_shards: bool = True, + io: str = "out", + ): for i, name in enumerate(submodule_names): - activations = th.cat(activation_cache[i], dim=0) # (N x B x T) x D (N = number of batches per shard) - print(f"Storing activation shard ({activations.shape}) for {name} {io}") - if shuffle_shards: - idx = np.random.permutation(activations.shape[0]) - activations = activations[idx] - # use memmap to store activations - memmap_file = os.path.join(store_dirs[i], f"shard_{shard_count}.memmap") - memmap_file_meta = memmap_file.replace(".memmap", ".meta") - memmap = np.memmap(memmap_file, dtype=np.float32, mode='w+', shape=(activations.shape[0], activations.shape[1])) - memmap[:] = activations.numpy() - memmap.flush() - with open(memmap_file_meta, "w") as f: - json.dump({"shape" : list(activations.shape)}, f) - del memmap + activations = th.cat( + activation_cache[i], dim=0 + ) # (N x B x T) x D (N = number of batches per shard) + print(f"Storing activation shard ({activations.shape}) for {name} {io}") + if shuffle_shards: + idx = np.random.permutation(activations.shape[0]) + activations = activations[idx] + # use memmap to store activations + memmap_file = os.path.join(store_dirs[i], f"shard_{shard_count}.memmap") + memmap_file_meta = memmap_file.replace(".memmap", ".meta") + memmap = np.memmap( + memmap_file, + dtype=np.float32, + mode="w+", + shape=(activations.shape[0], activations.shape[1]), + ) + memmap[:] = activations.numpy() + memmap.flush() + with open(memmap_file_meta, "w") as f: + json.dump({"shape": list(activations.shape)}, f) + del memmap @th.no_grad() @staticmethod - def collect(data : Dataset, - submodules : Tuple[nn.Module], - submodule_names : Tuple[str], - model : LanguageModel, - store_dir : str, - batch_size : int = 64, - context_len : int = 128, - shard_size : int = 10**6, - d_model : int = 1024, - shuffle_shards : bool = False, - io : str = "out", - num_workers : int = 8, - max_total_tokens : int = 10**8, - last_submodule : nn.Module = None): + def collect( + data: Dataset, + submodules: Tuple[nn.Module], + submodule_names: Tuple[str], + model: LanguageModel, + store_dir: str, + batch_size: int = 64, + context_len: int = 128, + shard_size: int = 10**6, + d_model: int = 1024, + shuffle_shards: bool = False, + io: str = "out", + num_workers: int = 8, + max_total_tokens: int = 10**8, + last_submodule: nn.Module = None, + ): dataloader = DataLoader(data, batch_size=batch_size, num_workers=num_workers) activation_cache = [[] for _ in submodules] - store_dirs = [os.path.join(store_dir, f"{submodule_names[i]}_{io}") for 
i in range(len(submodules))] + store_dirs = [ + os.path.join(store_dir, f"{submodule_names[i]}_{io}") + for i in range(len(submodules)) + ] for store_dir in store_dirs: os.makedirs(store_dir, exist_ok=True) total_size = 0 current_size = 0 shard_count = 0 for batch in tqdm(dataloader, desc="Collecting activations"): - tokens = model.tokenizer(batch, max_length=context_len, truncation=True, return_tensors="pt", padding=True).to(model.device) + tokens = model.tokenizer( + batch, + max_length=context_len, + truncation=True, + return_tensors="pt", + padding=True, + ).to(model.device) attention_mask = tokens["attention_mask"] with model.trace( tokens, **tracer_kwargs, ): for i, submodule in enumerate(submodules): - local_activations = ActivationCache.get_activations(submodule, io).reshape(-1, d_model).save() # (B x T) x D + local_activations = ( + ActivationCache.get_activations(submodule, io) + .reshape(-1, d_model) + .save() + ) # (B x T) x D activation_cache[i].append(local_activations) if last_submodule is not None: last_submodule.output.stop() for i in range(len(submodules)): - activation_cache[i][-1] = activation_cache[i][-1].value[attention_mask.reshape(-1).bool()].cpu().to(th.float32) # remove padding tokens + activation_cache[i][-1] = ( + activation_cache[i][-1] + .value[attention_mask.reshape(-1).bool()] + .cpu() + .to(th.float32) + ) # remove padding tokens current_size += activation_cache[0][-1].shape[0] if current_size > shard_size: - ActivationCache.collate_store_shards(store_dirs, shard_count, activation_cache, submodule_names, shuffle_shards, io) + ActivationCache.collate_store_shards( + store_dirs, + shard_count, + activation_cache, + submodule_names, + shuffle_shards, + io, + ) shard_count += 1 total_size += current_size @@ -128,37 +175,60 @@ def collect(data : Dataset, break if current_size > 0: - ActivationCache.collate_store_shards(store_dirs, shard_count, activation_cache, submodule_names, shuffle_shards, io) + ActivationCache.collate_store_shards( + store_dirs, + shard_count, + activation_cache, + submodule_names, + shuffle_shards, + io, + ) # store configs for i, store_dir in enumerate(store_dirs): with open(os.path.join(store_dir, "config.json"), "w") as f: - json.dump({"batch_size" : batch_size, "context_len" : context_len, "shard_size" : shard_size, "d_model" : d_model, "shuffle_shards" : shuffle_shards, "io" : io, "total_size" : total_size, "shard_count" : shard_count}, f) + json.dump( + { + "batch_size": batch_size, + "context_len": context_len, + "shard_size": shard_size, + "d_model": d_model, + "shuffle_shards": shuffle_shards, + "io": io, + "total_size": total_size, + "shard_count": shard_count, + }, + f, + ) print(f"Finished collecting activations. 
Total size: {total_size}") class PairedActivationCache: - def __init__(self, store_dir_1 : str, store_dir_2 : str): + def __init__(self, store_dir_1: str, store_dir_2: str): self.activation_cache_1 = ActivationCache(store_dir_1) self.activation_cache_2 = ActivationCache(store_dir_2) assert len(self.activation_cache_1) == len(self.activation_cache_2) def __len__(self): return len(self.activation_cache_1) - - def __getitem__(self, index : int): - return th.stack((self.activation_cache_1[index], self.activation_cache_2[index]), dim=0) + + def __getitem__(self, index: int): + return th.stack( + (self.activation_cache_1[index], self.activation_cache_2[index]), dim=0 + ) class ActivationCacheTuple: - def __init__(self, *store_dirs : str): - self.activation_caches = [ActivationCache(store_dir) for store_dir in store_dirs] + def __init__(self, *store_dirs: str): + self.activation_caches = [ + ActivationCache(store_dir) for store_dir in store_dirs + ] assert len(self.activation_caches) > 0 for i in range(1, len(self.activation_caches)): assert len(self.activation_caches[i]) == len(self.activation_caches[0]) def __len__(self): return len(self.activation_caches[0]) - - def __getitem__(self, index : int): + + def __getitem__(self, index: int): return th.stack([cache[index] for cache in self.activation_caches], dim=0) diff --git a/dictionary_learning/config.py b/dictionary_learning/config.py index 7edd09a..0ce00b0 100644 --- a/dictionary_learning/config.py +++ b/dictionary_learning/config.py @@ -1,2 +1,2 @@ # debugging flag for use in other scripts -DEBUG = False \ No newline at end of file +DEBUG = False diff --git a/dictionary_learning/evaluation.py b/dictionary_learning/evaluation.py index 6b3b0e5..ca01d20 100644 --- a/dictionary_learning/evaluation.py +++ b/dictionary_learning/evaluation.py @@ -16,7 +16,10 @@ def loss_recovered( max_len=None, # max context length for loss recovered normalize_batch=False, # normalize batch before passing through dictionary io="out", # can be 'in', 'out', or 'in_and_out' - tracer_args = {'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. + tracer_args={ + "use_cache": False, + "output_attentions": False, + }, # minimize cache during model trace. 
): """ How much of the model's loss is recovered by replacing the component output @@ -26,33 +29,36 @@ def loss_recovered( if max_len is None: invoker_args = {} else: - invoker_args = {"truncation": True, "max_length": max_len } + invoker_args = {"truncation": True, "max_length": max_len} # unmodified logits with model.trace(text, invoker_args=invoker_args): logits_original = model.output.save() logits_original = logits_original.value - + # logits when replacing component activations with reconstruction by autoencoder with model.trace(text, **tracer_args, invoker_args=invoker_args): - if io == 'in': + if io == "in": x = submodule.input[0] - if type(submodule.input.shape) == tuple: x = x[0] + if type(submodule.input.shape) == tuple: + x = x[0] if normalize_batch: - scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() + scale = (dictionary.activation_dim**0.5) / x.norm(dim=-1).mean() x = x * scale - elif io == 'out': + elif io == "out": x = submodule.output - if type(submodule.output.shape) == tuple: x = x[0] + if type(submodule.output.shape) == tuple: + x = x[0] if normalize_batch: - scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() + scale = (dictionary.activation_dim**0.5) / x.norm(dim=-1).mean() x = x * scale - elif io == 'in_and_out': + elif io == "in_and_out": x = submodule.input[0] - if type(submodule.input.shape) == tuple: x = x[0] - print(f'x.shape: {x.shape}') + if type(submodule.input.shape) == tuple: + x = x[0] + print(f"x.shape: {x.shape}") if normalize_batch: - scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() + scale = (dictionary.activation_dim**0.5) / x.norm(dim=-1).mean() x = x * scale else: raise ValueError(f"Invalid value for io: {io}") @@ -63,28 +69,28 @@ def loss_recovered( # intervene with `x_hat` with model.trace(text, **tracer_args, invoker_args=invoker_args): - if io == 'in': + if io == "in": x = submodule.input[0] if normalize_batch: - scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() + scale = (dictionary.activation_dim**0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale if type(submodule.input.shape) == tuple: submodule.input[0][:] = x_hat else: submodule.input = x_hat - elif io == 'out': + elif io == "out": x = submodule.output if normalize_batch: - scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() + scale = (dictionary.activation_dim**0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale if type(submodule.output.shape) == tuple: submodule.output = (x_hat,) else: submodule.output = x_hat - elif io == 'in_and_out': + elif io == "in_and_out": x = submodule.input[0] if normalize_batch: - scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() + scale = (dictionary.activation_dim**0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale submodule.output = x_hat else: @@ -95,13 +101,13 @@ def loss_recovered( # logits when replacing component activations with zeros with model.trace(text, **tracer_args, invoker_args=invoker_args): - if io == 'in': + if io == "in": x = submodule.input[0] if type(submodule.input.shape) == tuple: submodule.input[0][:] = t.zeros_like(x[0]) else: submodule.input = t.zeros_like(x) - elif io in ['out', 'in_and_out']: + elif io in ["out", "in_and_out"]: x = submodule.output if type(submodule.output.shape) == tuple: submodule.output[0][:] = t.zeros_like(x[0]) @@ -109,7 +115,7 @@ def loss_recovered( submodule.output = t.zeros_like(x) else: raise ValueError(f"Invalid value for io: {io}") - + input = model.input.save() logits_zero = model.output.save() logits_zero 
= logits_zero.value @@ -126,14 +132,14 @@ def loss_recovered( tokens = text else: try: - tokens = input[1]['input_ids'] + tokens = input[1]["input_ids"] except: - tokens = input[1]['input'] + tokens = input[1]["input"] # compute losses losses = [] - if hasattr(model, 'tokenizer') and model.tokenizer is not None: - loss_kwargs = {'ignore_index': model.tokenizer.pad_token_id} + if hasattr(model, "tokenizer") and model.tokenizer is not None: + loss_kwargs = {"ignore_index": model.tokenizer.pad_token_id} else: loss_kwargs = {} for logits in [logits_original, logits_reconstructed, logits_zero]: @@ -147,12 +153,15 @@ def loss_recovered( def evaluate( dictionary, # a dictionary - activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered + activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered max_len=128, # max context length for loss recovered batch_size=128, # batch size for loss recovered io="out", # can be 'in', 'out', or 'in_and_out' - normalize_batch=False, # normalize batch before passing through dictionary - tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. + normalize_batch=False, # normalize batch before passing through dictionary + tracer_args={ + "use_cache": False, + "output_attentions": False, + }, # minimize cache during model trace. device="cpu", ): with t.no_grad(): @@ -162,7 +171,7 @@ def evaluate( try: x = next(activations).to(device) if normalize_batch: - x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5) + x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim**0.5) except StopIteration: raise StopIteration( @@ -173,7 +182,9 @@ def evaluate( l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() l1_loss = f.norm(p=1, dim=-1).mean() l0 = (f != 0).float().sum(dim=-1).mean() - frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size + frac_alive = ( + t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size + ) # cosine similarity between x and x_hat x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True) @@ -183,13 +194,13 @@ def evaluate( # l2 ratio l2_ratio = (t.linalg.norm(x_hat, dim=-1) / t.linalg.norm(x, dim=-1)).mean() - #compute variance explained + # compute variance explained total_variance = t.var(x, dim=0).sum() residual_variance = t.var(x - x_hat, dim=0).sum() - frac_variance_explained = (1 - residual_variance / total_variance) + frac_variance_explained = 1 - residual_variance / total_variance # Equation 10 from https://arxiv.org/abs/2404.16014 - x_hat_norm_squared = t.linalg.norm(x_hat, dim=-1, ord=2)**2 + x_hat_norm_squared = t.linalg.norm(x_hat, dim=-1, ord=2) ** 2 x_dot_x_hat = (x * x_hat).sum(dim=-1) relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean() @@ -200,7 +211,7 @@ def evaluate( out["frac_variance_explained"] = frac_variance_explained.item() out["cossim"] = cossim.item() out["l2_ratio"] = l2_ratio.item() - out['relative_reconstruction_bias'] = relative_reconstruction_bias.item() + out["relative_reconstruction_bias"] = relative_reconstruction_bias.item() if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)): return out @@ -214,10 +225,10 @@ def evaluate( max_len=max_len, normalize_batch=normalize_batch, io=io, - tracer_args=tracer_args + tracer_args=tracer_args, ) frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero) - + out["loss_original"] = loss_original.item() out["loss_reconstructed"] = 
loss_reconstructed.item() out["loss_zero"] = loss_zero.item() diff --git a/dictionary_learning/grad_pursuit.py b/dictionary_learning/grad_pursuit.py index 0e1e6d7..fb2f181 100644 --- a/dictionary_learning/grad_pursuit.py +++ b/dictionary_learning/grad_pursuit.py @@ -6,27 +6,30 @@ import torch as t -def _grad_pursuit_update_step(signal, weights, dictionary, batch_arange, selected_features): +def _grad_pursuit_update_step( + signal, weights, dictionary, batch_arange, selected_features +): """ signal: b x d, weights: b x n, dictionary: d x n, batch_arange: b, selected_features: b x n """ - residual = signal - t.einsum('bn,dn -> bd', weights, dictionary) + residual = signal - t.einsum("bn,dn -> bd", weights, dictionary) # choose the element with largest inner product with residual, as in matched pursuit. - inner_products = t.einsum('dn,bd -> bn', dictionary, residual) + inner_products = t.einsum("dn,bd -> bn", dictionary, residual) idxs = t.argmax(inner_products, dim=1) - # add the new feature to the active set. + # add the new feature to the active set. selected_features[batch_arange, idxs] = 1 # the gradient for the weights is the inner product, restricted to the chosen features grad = selected_features * inner_products # the next two steps compute the optimal step size - c = t.einsum('bn,dn -> bd', grad, dictionary) - step_size = t.einsum('bd,bd -> b', c, residual) / t.einsum('bd,bd -> b ', c, c) - weights = weights + t.einsum('b,bn -> bn', step_size, grad) - weights = t.clip(weights, min=0) # clip the weights to be positive + c = t.einsum("bn,dn -> bd", grad, dictionary) + step_size = t.einsum("bd,bd -> b", c, residual) / t.einsum("bd,bd -> b ", c, c) + weights = weights + t.einsum("b,bn -> bn", step_size, grad) + weights = t.clip(weights, min=0) # clip the weights to be positive return weights, selected_features -def grad_pursuit(signal, dictionary, target_l0 : int = 20, device : str = 'cpu'): + +def grad_pursuit(signal, dictionary, target_l0: int = 20, device: str = "cpu"): """ Inputs: signal: b x d, dictionary: d x n, target_l0: int, device: str Outputs: weights: b x n @@ -38,5 +41,6 @@ def grad_pursuit(signal, dictionary, target_l0 : int = 20, device : str = 'cpu') selected_features = t.zeros((signal.shape[0], dictionary.shape[1])).to(device) for _ in range(target_l0): weights, selected_features = _grad_pursuit_update_step( - signal, weights, dictionary, batch_arange, selected_features) - return weights \ No newline at end of file + signal, weights, dictionary, batch_arange, selected_features + ) + return weights diff --git a/dictionary_learning/interp.py b/dictionary_learning/interp.py index 283965b..3a9de05 100644 --- a/dictionary_learning/interp.py +++ b/dictionary_learning/interp.py @@ -80,7 +80,14 @@ def feature_effect( def examine_dimension( - model, submodule, buffer, dictionary=None, max_length=128, n_inputs=512, dim_idx=None, k=30 + model, + submodule, + buffer, + dictionary=None, + max_length=128, + n_inputs=512, + dim_idx=None, + k=30, ): tracer_kwargs = { @@ -142,7 +149,9 @@ def _list_decode(x): top_affected = feature_effect( model, submodule, dictionary, dim_idx, tokens, max_length=max_length, k=k ) - top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)] + top_affected = [ + (model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected) + ] return namedtuple("featureProfile", ["top_contexts", "top_tokens", "top_affected"])( top_contexts, top_tokens, top_affected @@ -160,7 +169,8 @@ def feature_umap( feat_idxs=None, 
# if not none, indicate the feature with a red dot ): """ - Fit a UMAP embedding of the dictionary features and return a plotly plot of the result.""" + Fit a UMAP embedding of the dictionary features and return a plotly plot of the result. + """ if weight == "encoder": df = pd.DataFrame(dictionary.encoder.weight.cpu().detach().numpy()) else: @@ -177,9 +187,13 @@ def feature_umap( if isinstance(feat_idxs, int): feat_idxs = [feat_idxs] else: - colors = ["blue" if i not in feat_idxs else "red" for i in range(embedding.shape[0])] + colors = [ + "blue" if i not in feat_idxs else "red" for i in range(embedding.shape[0]) + ] if n_components == 2: - return px.scatter(x=embedding[:, 0], y=embedding[:, 1], hover_name=df.index, color=colors) + return px.scatter( + x=embedding[:, 0], y=embedding[:, 1], hover_name=df.index, color=colors + ) if n_components == 3: return px.scatter_3d( x=embedding[:, 0], diff --git a/dictionary_learning/trainers/__init__.py b/dictionary_learning/trainers/__init__.py index 7e2d946..1049b96 100644 --- a/dictionary_learning/trainers/__init__.py +++ b/dictionary_learning/trainers/__init__.py @@ -5,4 +5,4 @@ from .top_k import TrainerTopK from .jumprelu import TrainerJumpRelu from .batch_top_k import TrainerBatchTopK, BatchTopKSAE -from .crosscoder import CrossCoderTrainer \ No newline at end of file +from .crosscoder import CrossCoderTrainer diff --git a/dictionary_learning/trainers/batch_top_k.py b/dictionary_learning/trainers/batch_top_k.py index e684d9a..9c6d391 100644 --- a/dictionary_learning/trainers/batch_top_k.py +++ b/dictionary_learning/trainers/batch_top_k.py @@ -185,7 +185,11 @@ def loss(self, x, step=None, logging=False): x, x_hat, f, - {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()}, + { + "l2_loss": l2_loss.item(), + "auxk_loss": auxk_loss.item(), + "loss": loss.item(), + }, ) def update(self, step, x): diff --git a/dictionary_learning/trainers/crosscoder.py b/dictionary_learning/trainers/crosscoder.py index 24561df..bc08943 100644 --- a/dictionary_learning/trainers/crosscoder.py +++ b/dictionary_learning/trainers/crosscoder.py @@ -1,34 +1,38 @@ """ Implements the standard SAE training scheme. """ + import torch as th from ..trainers.trainer import SAETrainer from ..config import DEBUG from ..dictionary import CrossCoder from collections import namedtuple + class CrossCoderTrainer(SAETrainer): """ Standard SAE training scheme for cross-coding. 
""" - def __init__(self, - dict_class=CrossCoder, - num_layers=2, - activation_dim=512, - dict_size=64*512, - lr=1e-3, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name='CrossCoderTrainer', - submodule_name=None, - compile=False, - dict_class_kwargs={}, - pretrained_ae=None, + + def __init__( + self, + dict_class=CrossCoder, + num_layers=2, + activation_dim=512, + dict_size=64 * 512, + lr=1e-3, + l1_penalty=1e-1, + warmup_steps=1000, # lr warmup period at start of training and after each resample + resample_steps=None, # how often to resample neurons + seed=None, + device=None, + layer=None, + lm_name=None, + wandb_name="CrossCoderTrainer", + submodule_name=None, + compile=False, + dict_class_kwargs={}, + pretrained_ae=None, ): super().__init__(seed) @@ -43,19 +47,21 @@ def __init__(self, # initialize dictionary if pretrained_ae is None: - self.ae = dict_class(activation_dim, dict_size, num_layers=num_layers, **dict_class_kwargs) + self.ae = dict_class( + activation_dim, dict_size, num_layers=num_layers, **dict_class_kwargs + ) else: self.ae = pretrained_ae - + if compile: self.ae = th.compile(self.ae) self.lr = lr - self.l1_penalty=l1_penalty + self.l1_penalty = l1_penalty self.warmup_steps = warmup_steps self.wandb_name = wandb_name if device is None: - self.device = 'cuda' if th.cuda.is_available() else 'cpu' + self.device = "cuda" if th.cuda.is_available() else "cpu" else: self.device = device self.ae.to(self.device) @@ -64,35 +70,43 @@ def __init__(self, if self.resample_steps is not None: # how many steps since each neuron was last activated? - self.steps_since_active = th.zeros(self.ae.dict_size, dtype=int).to(self.device) + self.steps_since_active = th.zeros(self.ae.dict_size, dtype=int).to( + self.device + ) else: - self.steps_since_active = None + self.steps_since_active = None self.optimizer = th.optim.Adam(self.ae.parameters(), lr=lr) if resample_steps is None: + def warmup_fn(step): - return min(step / warmup_steps, 1.) + return min(step / warmup_steps, 1.0) + else: + def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = th.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + return min((step % resample_steps) / warmup_steps, 1.0) + + self.scheduler = th.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=warmup_fn + ) def resample_neurons(self, deads, activations): with th.no_grad(): - if deads.sum() == 0: return + if deads.sum() == 0: + return self.ae.resample_neurons(deads, activations) # reset Adam parameters for dead neurons - state_dict = self.optimizer.state_dict()['state'] + state_dict = self.optimizer.state_dict()["state"] ## encoder weight - state_dict[0]['exp_avg'][:, :, deads] = 0. - state_dict[0]['exp_avg_sq'][:, :, deads] = 0. + state_dict[0]["exp_avg"][:, :, deads] = 0.0 + state_dict[0]["exp_avg_sq"][:, :, deads] = 0.0 ## encoder bias - state_dict[1]['exp_avg'][deads] = 0. - state_dict[1]['exp_avg_sq'][deads] = 0. + state_dict[1]["exp_avg"][deads] = 0.0 + state_dict[1]["exp_avg_sq"][deads] = 0.0 ## decoder weight - state_dict[3]['exp_avg'][:, deads, :] = 0. - state_dict[3]['exp_avg_sq'][:, deads, :] = 0. 
- + state_dict[3]["exp_avg"][:, deads, :] = 0.0 + state_dict[3]["exp_avg_sq"][:, deads, :] = 0.0 def loss(self, x, logging=False, return_deads=False, **kwargs): x_hat, f = self.ae(x, output_features=True) @@ -103,24 +117,25 @@ def loss(self, x, logging=False, return_deads=False, **kwargs): # update steps_since_active self.steps_since_active[deads] += 1 self.steps_since_active[~deads] = 0 - + loss = l2_loss + self.l1_penalty * l1_loss if not logging: return loss else: - return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( - x, x_hat, f, + return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( + x, + x_hat, + f, { - 'l2_loss' : l2_loss.item(), - 'mse_loss' : (x - x_hat).pow(2).sum(dim=-1).mean().item(), - 'sparsity_loss' : l1_loss.item(), - 'loss' : loss.item(), - 'deads' : deads if return_deads else None - } + "l2_loss": l2_loss.item(), + "mse_loss": (x - x_hat).pow(2).sum(dim=-1).mean().item(), + "sparsity_loss": l1_loss.item(), + "loss": loss.item(), + "deads": deads if return_deads else None, + }, ) - def update(self, step, activations): activations = activations.to(self.device) @@ -131,23 +146,28 @@ def update(self, step, activations): self.scheduler.step() if self.resample_steps is not None and step % self.resample_steps == 0: - self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) + self.resample_neurons( + self.steps_since_active > self.resample_steps / 2, activations + ) @property def config(self): return { - 'dict_class': self.ae.__class__.__name__ if not self.compile else self.ae._orig_mod.__class__.__name__, - 'trainer_class' : self.__class__.__name__, - 'activation_dim': self.ae.activation_dim, - 'dict_size': self.ae.dict_size, - 'lr' : self.lr, - 'l1_penalty' : self.l1_penalty, - 'warmup_steps' : self.warmup_steps, - 'resample_steps' : self.resample_steps, - 'device' : self.device, - 'layer' : self.layer, - 'lm_name' : self.lm_name, - 'wandb_name': self.wandb_name, - 'submodule_name': self.submodule_name, + "dict_class": ( + self.ae.__class__.__name__ + if not self.compile + else self.ae._orig_mod.__class__.__name__ + ), + "trainer_class": self.__class__.__name__, + "activation_dim": self.ae.activation_dim, + "dict_size": self.ae.dict_size, + "lr": self.lr, + "l1_penalty": self.l1_penalty, + "warmup_steps": self.warmup_steps, + "resample_steps": self.resample_steps, + "device": self.device, + "layer": self.layer, + "lm_name": self.lm_name, + "wandb_name": self.wandb_name, + "submodule_name": self.submodule_name, } - diff --git a/dictionary_learning/trainers/gated_anneal.py b/dictionary_learning/trainers/gated_anneal.py index 664904b..4221c1f 100644 --- a/dictionary_learning/trainers/gated_anneal.py +++ b/dictionary_learning/trainers/gated_anneal.py @@ -8,14 +8,16 @@ from ..dictionary import GatedAutoEncoder from collections import namedtuple + class ConstrainedAdam(t.optim.Adam): """ A variant of Adam where some of the parameters are constrained to have unit norm. """ + def __init__(self, params, constrained_params, lr): super().__init__(params, lr=lr, betas=(0, 0.999)) self.constrained_params = list(constrained_params) - + def step(self, closure=None): with t.no_grad(): for p in self.constrained_params: @@ -28,31 +30,34 @@ def step(self, closure=None): # renormalize the constrained parameters p /= p.norm(dim=0, keepdim=True) + class GatedAnnealTrainer(SAETrainer): """ Gated SAE training scheme with p-annealing. 
""" - def __init__(self, - dict_class=GatedAutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=3e-4, - warmup_steps=1000, # lr warmup period at start of training and after each resample - sparsity_function='Lp^p', # Lp or Lp^p - initial_sparsity_penalty=1e-1, # equal to l1 penalty in standard trainer - anneal_start=15000, # step at which to start annealing p - anneal_end=None, # step at which to stop annealing, defaults to steps-1 - p_start=1, # starting value of p (constant throughout warmup) - p_end=0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded - n_sparsity_updates = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times - sparsity_queue_length = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty - resample_steps=None, # number of steps after which to resample dead neurons - steps=None, # total number of steps to train for - device=None, - seed=42, - layer=None, - lm_name=None, - wandb_name='GatedAnnealTrainer', + + def __init__( + self, + dict_class=GatedAutoEncoder, + activation_dim=512, + dict_size=64 * 512, + lr=3e-4, + warmup_steps=1000, # lr warmup period at start of training and after each resample + sparsity_function="Lp^p", # Lp or Lp^p + initial_sparsity_penalty=1e-1, # equal to l1 penalty in standard trainer + anneal_start=15000, # step at which to start annealing p + anneal_end=None, # step at which to stop annealing, defaults to steps-1 + p_start=1, # starting value of p (constant throughout warmup) + p_end=0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded + n_sparsity_updates=10, # number of times to update the sparsity penalty, at most steps-anneal_start times + sparsity_queue_length=10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty + resample_steps=None, # number of steps after which to resample dead neurons + steps=None, # total number of steps to train for + device=None, + seed=42, + layer=None, + lm_name=None, + wandb_name="GatedAnnealTrainer", ): super().__init__(seed) @@ -69,37 +74,45 @@ def __init__(self, self.activation_dim = activation_dim self.dict_size = dict_size self.ae = dict_class(activation_dim, dict_size) - + if device is None: - self.device = 'cuda' if t.cuda.is_available() else 'cpu' + self.device = "cuda" if t.cuda.is_available() else "cpu" else: self.device = device self.ae.to(self.device) - + self.lr = lr self.sparsity_function = sparsity_function self.anneal_start = anneal_start self.anneal_end = anneal_end if anneal_end is not None else steps self.p_start = p_start self.p_end = p_end - self.p = p_start # p is set in self.loss() - self.next_p = None # set in self.loss() - self.lp_loss = None # set in self.loss() - self.scaled_lp_loss = None # set in self.loss() + self.p = p_start # p is set in self.loss() + self.next_p = None # set in self.loss() + self.lp_loss = None # set in self.loss() + self.scaled_lp_loss = None # set in self.loss() if n_sparsity_updates == "continuous": - self.n_sparsity_updates = self.anneal_end - anneal_start +1 + self.n_sparsity_updates = self.anneal_end - anneal_start + 1 else: self.n_sparsity_updates = n_sparsity_updates - self.sparsity_update_steps = t.linspace(anneal_start, self.anneal_end, self.n_sparsity_updates, dtype=int) + self.sparsity_update_steps = t.linspace( + anneal_start, self.anneal_end, self.n_sparsity_updates, dtype=int + ) self.p_values = t.linspace(p_start, p_end, self.n_sparsity_updates) self.p_step_count = 0 - 
self.sparsity_coeff = initial_sparsity_penalty # alpha + self.sparsity_coeff = initial_sparsity_penalty # alpha self.sparsity_queue_length = sparsity_queue_length self.sparsity_queue = [] self.warmup_steps = warmup_steps self.steps = steps - self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff'] + self.logging_parameters = [ + "p", + "next_p", + "lp_loss", + "scaled_lp_loss", + "sparsity_coeff", + ] self.seed = seed self.wandb_name = wandb_name @@ -108,20 +121,29 @@ def __init__(self, # how many steps since each neuron was last activated? self.steps_since_active = t.zeros(self.dict_size, dtype=int).to(self.device) else: - self.steps_since_active = None + self.steps_since_active = None - self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + self.optimizer = ConstrainedAdam( + self.ae.parameters(), self.ae.decoder.parameters(), lr=lr + ) if resample_steps is None: + def warmup_fn(step): - return min(step / warmup_steps, 1.) + return min(step / warmup_steps, 1.0) + else: + def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) - + return min((step % resample_steps) / warmup_steps, 1.0) + + self.scheduler = t.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=warmup_fn + ) + def resample_neurons(self, deads, activations): with t.no_grad(): - if deads.sum() == 0: return + if deads.sum() == 0: + return print(f"resampling {deads.sum().item()} neurons") # compute loss for each activation @@ -135,40 +157,43 @@ def resample_neurons(self, deads, activations): # reset encoder/decoder weights for dead neurons alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() self.ae.encoder.weight[deads][:n_resample] = sampled_vecs * alive_norm * 0.2 - self.ae.decoder.weight[:,deads][:,:n_resample] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T - self.ae.encoder.bias[deads][:n_resample] = 0. - + self.ae.decoder.weight[:, deads][:, :n_resample] = ( + sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True) + ).T + self.ae.encoder.bias[deads][:n_resample] = 0.0 # reset Adam parameters for dead neurons - state_dict = self.optimizer.state_dict()['state'] + state_dict = self.optimizer.state_dict()["state"] ## encoder weight - state_dict[1]['exp_avg'][deads] = 0. - state_dict[1]['exp_avg_sq'][deads] = 0. + state_dict[1]["exp_avg"][deads] = 0.0 + state_dict[1]["exp_avg_sq"][deads] = 0.0 ## encoder bias - state_dict[2]['exp_avg'][deads] = 0. - state_dict[2]['exp_avg_sq'][deads] = 0. + state_dict[2]["exp_avg"][deads] = 0.0 + state_dict[2]["exp_avg_sq"][deads] = 0.0 ## decoder weight - state_dict[3]['exp_avg'][:,deads] = 0. - state_dict[3]['exp_avg_sq'][:,deads] = 0. 
- + state_dict[3]["exp_avg"][:, deads] = 0.0 + state_dict[3]["exp_avg_sq"][:, deads] = 0.0 + def lp_norm(self, f, p): norm_sq = f.pow(p).sum(dim=-1) - if self.sparsity_function == 'Lp^p': + if self.sparsity_function == "Lp^p": return norm_sq.mean() - elif self.sparsity_function == 'Lp': - return norm_sq.pow(1/p).mean() + elif self.sparsity_function == "Lp": + return norm_sq.pow(1 / p).mean() else: raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") - + def loss(self, x, step, logging=False, **kwargs): f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) - x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() + x_hat_gate = ( + f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() + ) L_recon = (x - x_hat).pow(2).sum(dim=-1).mean() L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean() - fs = f_gate # feature activation that we use for sparsity term + fs = f_gate # feature activation that we use for sparsity term lp_loss = self.lp_norm(fs, self.p) scaled_lp_loss = lp_loss * self.sparsity_coeff self.lp_loss = lp_loss @@ -177,20 +202,27 @@ def loss(self, x, step, logging=False, **kwargs): if self.next_p is not None: lp_loss_next = self.lp_norm(fs, self.next_p) self.sparsity_queue.append([self.lp_loss.item(), lp_loss_next.item()]) - self.sparsity_queue = self.sparsity_queue[-self.sparsity_queue_length:] - + self.sparsity_queue = self.sparsity_queue[-self.sparsity_queue_length :] + if step in self.sparsity_update_steps: # check to make sure we don't update on repeat step: if step >= self.sparsity_update_steps[self.p_step_count]: # Adapt sparsity penalty alpha if self.next_p is not None: - local_sparsity_new = t.tensor([i[0] for i in self.sparsity_queue]).mean() - local_sparsity_old = t.tensor([i[1] for i in self.sparsity_queue]).mean() - self.sparsity_coeff = self.sparsity_coeff * (local_sparsity_new / local_sparsity_old).item() + local_sparsity_new = t.tensor( + [i[0] for i in self.sparsity_queue] + ).mean() + local_sparsity_old = t.tensor( + [i[1] for i in self.sparsity_queue] + ).mean() + self.sparsity_coeff = ( + self.sparsity_coeff + * (local_sparsity_new / local_sparsity_old).item() + ) # Update p self.p = self.p_values[self.p_step_count].item() - if self.p_step_count < self.n_sparsity_updates-1: - self.next_p = self.p_values[self.p_step_count+1].item() + if self.p_step_count < self.n_sparsity_updates - 1: + self.next_p = self.p_values[self.p_step_count + 1].item() else: self.next_p = self.p_end self.p_step_count += 1 @@ -200,27 +232,29 @@ def loss(self, x, step, logging=False, **kwargs): # update steps_since_active deads = (f == 0).all(dim=0) self.steps_since_active[deads] += 1 - self.steps_since_active[~deads] = 0 - + self.steps_since_active[~deads] = 0 + loss = L_recon + scaled_lp_loss + L_aux - + if not logging: return loss else: - return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( - x, x_hat, f, + return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( + x, + x_hat, + f, { - 'mse_loss' : L_recon.item(), - 'aux_loss' : L_aux.item(), - 'loss' : loss.item(), - 'p' : self.p, - 'next_p' : self.next_p, - 'lp_loss' : lp_loss.item(), - 'sparsity_loss' : scaled_lp_loss.item(), - 'sparsity_coeff' : self.sparsity_coeff, - } + "mse_loss": L_recon.item(), + "aux_loss": L_aux.item(), + "loss": loss.item(), + "p": self.p, + "next_p": self.next_p, + "lp_loss": lp_loss.item(), + "sparsity_loss": scaled_lp_loss.item(), + "sparsity_coeff": self.sparsity_coeff, + }, ) - + def update(self, step, activations): 
activations = activations.to(self.device) @@ -230,8 +264,13 @@ def update(self, step, activations): self.optimizer.step() self.scheduler.step() - if self.resample_steps is not None and step % self.resample_steps == self.resample_steps - 1: - self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) + if ( + self.resample_steps is not None + and step % self.resample_steps == self.resample_steps - 1 + ): + self.resample_neurons( + self.steps_since_active > self.resample_steps / 2, activations + ) # @property # def config(self): @@ -245,27 +284,27 @@ def update(self, step, activations): # 'device' : self.device, # 'wandb_name': self.wandb_name, # } - + @property def config(self): return { - 'trainer_class' : "GatedAnnealTrainer", - 'dict_class' : "GatedAutoEncoder", - 'activation_dim' : self.activation_dim, - 'dict_size' : self.dict_size, - 'lr' : self.lr, - 'sparsity_function' : self.sparsity_function, - 'sparsity_penalty' : self.sparsity_coeff, - 'p_start' : self.p_start, - 'p_end' : self.p_end, - 'anneal_start' : self.anneal_start, - 'sparsity_queue_length' : self.sparsity_queue_length, - 'n_sparsity_updates' : self.n_sparsity_updates, - 'warmup_steps' : self.warmup_steps, - 'resample_steps' : self.resample_steps, - 'steps' : self.steps, - 'seed' : self.seed, - 'layer' : self.layer, - 'lm_name' : self.lm_name, - 'wandb_name' : self.wandb_name, + "trainer_class": "GatedAnnealTrainer", + "dict_class": "GatedAutoEncoder", + "activation_dim": self.activation_dim, + "dict_size": self.dict_size, + "lr": self.lr, + "sparsity_function": self.sparsity_function, + "sparsity_penalty": self.sparsity_coeff, + "p_start": self.p_start, + "p_end": self.p_end, + "anneal_start": self.anneal_start, + "sparsity_queue_length": self.sparsity_queue_length, + "n_sparsity_updates": self.n_sparsity_updates, + "warmup_steps": self.warmup_steps, + "resample_steps": self.resample_steps, + "steps": self.steps, + "seed": self.seed, + "layer": self.layer, + "lm_name": self.lm_name, + "wandb_name": self.wandb_name, } diff --git a/dictionary_learning/trainers/gdm.py b/dictionary_learning/trainers/gdm.py index 47ea772..89607af 100644 --- a/dictionary_learning/trainers/gdm.py +++ b/dictionary_learning/trainers/gdm.py @@ -8,14 +8,16 @@ from ..dictionary import GatedAutoEncoder from collections import namedtuple + class ConstrainedAdam(t.optim.Adam): """ A variant of Adam where some of the parameters are constrained to have unit norm. """ + def __init__(self, params, constrained_params, lr): super().__init__(params, lr=lr, betas=(0, 0.999)) self.constrained_params = list(constrained_params) - + def step(self, closure=None): with t.no_grad(): for p in self.constrained_params: @@ -28,24 +30,27 @@ def step(self, closure=None): # renormalize the constrained parameters p /= p.norm(dim=0, keepdim=True) + class GatedSAETrainer(SAETrainer): """ Gated SAE training scheme. 
""" - def __init__(self, - dict_class=GatedAutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=5e-5, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name='GatedSAETrainer', - submodule_name=None, + + def __init__( + self, + dict_class=GatedAutoEncoder, + activation_dim=512, + dict_size=64 * 512, + lr=5e-5, + l1_penalty=1e-1, + warmup_steps=1000, # lr warmup period at start of training and after each resample + resample_steps=None, # how often to resample neurons + seed=None, + device=None, + layer=None, + lm_name=None, + wandb_name="GatedSAETrainer", + submodule_name=None, ): super().__init__(seed) @@ -62,29 +67,31 @@ def __init__(self, self.ae = dict_class(activation_dim, dict_size) self.lr = lr - self.l1_penalty=l1_penalty + self.l1_penalty = l1_penalty self.warmup_steps = warmup_steps self.wandb_name = wandb_name if device is None: - self.device = 'cuda' if t.cuda.is_available() else 'cpu' + self.device = "cuda" if t.cuda.is_available() else "cpu" else: self.device = device self.ae.to(self.device) self.optimizer = ConstrainedAdam( - self.ae.parameters(), - self.ae.decoder.parameters(), - lr=lr + self.ae.parameters(), self.ae.decoder.parameters(), lr=lr ) + def warmup_fn(step): return min(1, step / warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn) def loss(self, x, logging=False, **kwargs): f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) - x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() + x_hat_gate = ( + f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() + ) L_recon = (x - x_hat).pow(2).sum(dim=-1).mean() L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean() @@ -95,16 +102,18 @@ def loss(self, x, logging=False, **kwargs): if not logging: return loss else: - return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( - x, x_hat, f, + return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( + x, + x_hat, + f, { - 'mse_loss' : L_recon.item(), - 'sparsity_loss' : L_sparse.item(), - 'aux_loss' : L_aux.item(), - 'loss' : loss.item() - } + "mse_loss": L_recon.item(), + "sparsity_loss": L_sparse.item(), + "aux_loss": L_aux.item(), + "loss": loss.item(), + }, ) - + def update(self, step, x): x = x.to(self.device) self.optimizer.zero_grad() @@ -116,16 +125,16 @@ def update(self, step, x): @property def config(self): return { - 'dict_class': 'GatedAutoEncoder', - 'trainer_class' : 'GatedSAETrainer', - 'activation_dim' : self.ae.activation_dim, - 'dict_size' : self.ae.dict_size, - 'lr' : self.lr, - 'l1_penalty' : self.l1_penalty, - 'warmup_steps' : self.warmup_steps, - 'device' : self.device, - 'layer' : self.layer, - 'lm_name' : self.lm_name, - 'wandb_name': self.wandb_name, - 'submodule_name': self.submodule_name, + "dict_class": "GatedAutoEncoder", + "trainer_class": "GatedSAETrainer", + "activation_dim": self.ae.activation_dim, + "dict_size": self.ae.dict_size, + "lr": self.lr, + "l1_penalty": self.l1_penalty, + "warmup_steps": self.warmup_steps, + "device": self.device, + "layer": self.layer, + "lm_name": self.lm_name, + "wandb_name": self.wandb_name, + "submodule_name": self.submodule_name, } diff --git a/dictionary_learning/trainers/jumprelu.py b/dictionary_learning/trainers/jumprelu.py index f87785a..bad49fd 100644 --- a/dictionary_learning/trainers/jumprelu.py +++ 
b/dictionary_learning/trainers/jumprelu.py @@ -66,6 +66,7 @@ class TrainerJumpRelu(nn.Module, SAETrainer): Note does not use learning rate or sparsity scheduling as in the paper. """ + def __init__( self, dict_class=JumpReluAutoEncoder, diff --git a/dictionary_learning/trainers/p_anneal.py b/dictionary_learning/trainers/p_anneal.py index 4a157b9..037d090 100644 --- a/dictionary_learning/trainers/p_anneal.py +++ b/dictionary_learning/trainers/p_anneal.py @@ -8,14 +8,16 @@ from ..trainers.trainer import SAETrainer from ..config import DEBUG + class ConstrainedAdam(t.optim.Adam): """ A variant of Adam where some of the parameters are constrained to have unit norm. """ + def __init__(self, params, constrained_params, lr): super().__init__(params, lr=lr) self.constrained_params = list(constrained_params) - def step(self, closure=None): with t.no_grad(): for p in self.constrained_params: @@ -28,33 +30,36 @@ def step(self, closure=None): # renormalize the constrained parameters p /= p.norm(dim=0, keepdim=True) + class PAnnealTrainer(SAETrainer): """ SAE training scheme with the option to anneal the sparsity parameter p. You can further choose to use Lp or Lp^p sparsity. """ - def __init__(self, - dict_class=AutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=1e-3, - warmup_steps=1000, # lr warmup period at start of training and after each resample - sparsity_function='Lp', # Lp or Lp^p - initial_sparsity_penalty=1e-1, # equal to l1 penalty in standard trainer - anneal_start=15000, # step at which to start annealing p - anneal_end=None, # step at which to stop annealing, defaults to steps-1 - p_start=1, # starting value of p (constant throughout warmup) - p_end=0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded - n_sparsity_updates = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times - sparsity_queue_length = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty - resample_steps=None, # number of steps after which to resample dead neurons - steps=None, # total number of steps to train for - device=None, - seed=42, - layer=None, - lm_name=None, - wandb_name='PAnnealTrainer', - submodule_name: str = None, + + def __init__( + self, + dict_class=AutoEncoder, + activation_dim=512, + dict_size=64 * 512, + lr=1e-3, + warmup_steps=1000, # lr warmup period at start of training and after each resample + sparsity_function="Lp", # Lp or Lp^p + initial_sparsity_penalty=1e-1, # equal to l1 penalty in standard trainer + anneal_start=15000, # step at which to start annealing p + anneal_end=None, # step at which to stop annealing, defaults to steps-1 + p_start=1, # starting value of p (constant throughout warmup) + p_end=0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded + n_sparsity_updates=10, # number of times to update the sparsity penalty, at most steps-anneal_start times + sparsity_queue_length=10, # number of recent sparsity loss terms, only needed for adaptive_sparsity_penalty + resample_steps=None, # number of steps after which to resample dead neurons + steps=None, # total number of steps to train for + device=None, + seed=42, + layer=None, + lm_name=None, + wandb_name="PAnnealTrainer", + submodule_name: str = None, ): super().__init__(seed) @@ -68,7 +73,7 @@ def __init__(self, t.cuda.manual_seed_all(seed) if device is None: - self.device = t.device('cuda' if t.cuda.is_available() else 'cpu') + self.device = t.device("cuda" if t.cuda.is_available() else "cpu") 
else: self.device = device @@ -77,7 +82,7 @@ def __init__(self, self.dict_size = dict_size self.ae = dict_class(activation_dim, dict_size) self.ae.to(self.device) - + self.lr = lr self.sparsity_function = sparsity_function self.anneal_start = anneal_start @@ -87,19 +92,27 @@ def __init__(self, self.p = p_start self.next_p = None if n_sparsity_updates == "continuous": - self.n_sparsity_updates = self.anneal_end - anneal_start +1 + self.n_sparsity_updates = self.anneal_end - anneal_start + 1 else: self.n_sparsity_updates = n_sparsity_updates - self.sparsity_update_steps = t.linspace(anneal_start, self.anneal_end, self.n_sparsity_updates, dtype=int) + self.sparsity_update_steps = t.linspace( + anneal_start, self.anneal_end, self.n_sparsity_updates, dtype=int + ) self.p_values = t.linspace(p_start, p_end, self.n_sparsity_updates) self.p_step_count = 0 - self.sparsity_coeff = initial_sparsity_penalty # alpha + self.sparsity_coeff = initial_sparsity_penalty # alpha self.sparsity_queue_length = sparsity_queue_length self.sparsity_queue = [] self.warmup_steps = warmup_steps self.steps = steps - self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff'] + self.logging_parameters = [ + "p", + "next_p", + "lp_loss", + "scaled_lp_loss", + "sparsity_coeff", + ] self.seed = seed self.wandb_name = wandb_name @@ -108,23 +121,32 @@ def __init__(self, # how many steps since each neuron was last activated? self.steps_since_active = t.zeros(self.dict_size, dtype=int).to(self.device) else: - self.steps_since_active = None + self.steps_since_active = None - self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + self.optimizer = ConstrainedAdam( + self.ae.parameters(), self.ae.decoder.parameters(), lr=lr + ) if resample_steps is None: + def warmup_fn(step): - return min(step / warmup_steps, 1.) + return min(step / warmup_steps, 1.0) + else: + def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) - - if (self.sparsity_update_steps.unique(return_counts=True)[1] >1).any(): + return min((step % resample_steps) / warmup_steps, 1.0) + + self.scheduler = t.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=warmup_fn + ) + + if (self.sparsity_update_steps.unique(return_counts=True)[1] > 1).any(): print("Warning! Duplicates in self.sparsity_update_steps detected!") def resample_neurons(self, deads, activations): with t.no_grad(): - if deads.sum() == 0: return + if deads.sum() == 0: + return print(f"resampling {deads.sum().item()} neurons") # compute loss for each activation @@ -138,31 +160,32 @@ def resample_neurons(self, deads, activations): # reset encoder/decoder weights for dead neurons alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() self.ae.encoder.weight[deads][:n_resample] = sampled_vecs * alive_norm * 0.2 - self.ae.decoder.weight[:,deads][:,:n_resample] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T - self.ae.encoder.bias[deads][:n_resample] = 0. - + self.ae.decoder.weight[:, deads][:, :n_resample] = ( + sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True) + ).T + self.ae.encoder.bias[deads][:n_resample] = 0.0 # reset Adam parameters for dead neurons - state_dict = self.optimizer.state_dict()['state'] + state_dict = self.optimizer.state_dict()["state"] ## encoder weight - state_dict[1]['exp_avg'][deads] = 0. - state_dict[1]['exp_avg_sq'][deads] = 0. 
+ state_dict[1]["exp_avg"][deads] = 0.0 + state_dict[1]["exp_avg_sq"][deads] = 0.0 ## encoder bias - state_dict[2]['exp_avg'][deads] = 0. - state_dict[2]['exp_avg_sq'][deads] = 0. + state_dict[2]["exp_avg"][deads] = 0.0 + state_dict[2]["exp_avg_sq"][deads] = 0.0 ## decoder weight - state_dict[3]['exp_avg'][:,deads] = 0. - state_dict[3]['exp_avg_sq'][:,deads] = 0. + state_dict[3]["exp_avg"][:, deads] = 0.0 + state_dict[3]["exp_avg_sq"][:, deads] = 0.0 def lp_norm(self, f, p): norm_sq = f.pow(p).sum(dim=-1) - if self.sparsity_function == 'Lp^p': + if self.sparsity_function == "Lp^p": return norm_sq.mean() - elif self.sparsity_function == 'Lp': - return norm_sq.pow(1/p).mean() + elif self.sparsity_function == "Lp": + return norm_sq.pow(1 / p).mean() else: raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") - + def loss(self, x, step, logging=False): # Compute loss terms x_hat, f = self.ae(x, output_features=True) @@ -175,20 +198,27 @@ def loss(self, x, step, logging=False): if self.next_p is not None: lp_loss_next = self.lp_norm(f, self.next_p) self.sparsity_queue.append([self.lp_loss.item(), lp_loss_next.item()]) - self.sparsity_queue = self.sparsity_queue[-self.sparsity_queue_length:] - + self.sparsity_queue = self.sparsity_queue[-self.sparsity_queue_length :] + if step in self.sparsity_update_steps: # check to make sure we don't update on repeat step: if step >= self.sparsity_update_steps[self.p_step_count]: # Adapt sparsity penalty alpha if self.next_p is not None: - local_sparsity_new = t.tensor([i[0] for i in self.sparsity_queue]).mean() - local_sparsity_old = t.tensor([i[1] for i in self.sparsity_queue]).mean() - self.sparsity_coeff = self.sparsity_coeff * (local_sparsity_new / local_sparsity_old).item() + local_sparsity_new = t.tensor( + [i[0] for i in self.sparsity_queue] + ).mean() + local_sparsity_old = t.tensor( + [i[1] for i in self.sparsity_queue] + ).mean() + self.sparsity_coeff = ( + self.sparsity_coeff + * (local_sparsity_new / local_sparsity_old).item() + ) # Update p self.p = self.p_values[self.p_step_count].item() - if self.p_step_count < self.n_sparsity_updates-1: - self.next_p = self.p_values[self.p_step_count+1].item() + if self.p_step_count < self.n_sparsity_updates - 1: + self.next_p = self.p_values[self.p_step_count + 1].item() else: self.next_p = self.p_end self.p_step_count += 1 @@ -198,21 +228,20 @@ def loss(self, x, step, logging=False): # update steps_since_active deads = (f == 0).all(dim=0) self.steps_since_active[deads] += 1 - self.steps_since_active[~deads] = 0 - + self.steps_since_active[~deads] = 0 + if logging is False: return l2_loss + scaled_lp_loss - else: + else: loss_log = { - 'p' : self.p, - 'next_p' : self.next_p, - 'lp_loss' : lp_loss.item(), - 'scaled_lp_loss' : scaled_lp_loss.item(), - 'sparsity_coeff' : self.sparsity_coeff, + "p": self.p, + "next_p": self.next_p, + "lp_loss": lp_loss.item(), + "scaled_lp_loss": scaled_lp_loss.item(), + "sparsity_coeff": self.sparsity_coeff, } return x, x_hat, f, loss_log - - + def update(self, step, activations): activations = activations.to(self.device) @@ -222,30 +251,35 @@ def update(self, step, activations): self.optimizer.step() self.scheduler.step() - if self.resample_steps is not None and step % self.resample_steps == self.resample_steps - 1: - self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) + if ( + self.resample_steps is not None + and step % self.resample_steps == self.resample_steps - 1 + ): + self.resample_neurons( + self.steps_since_active > 
self.resample_steps / 2, activations + ) @property def config(self): return { - 'trainer_class' : "PAnnealTrainer", - 'dict_class' : "AutoEncoder", - 'activation_dim' : self.activation_dim, - 'dict_size' : self.dict_size, - 'lr' : self.lr, - 'sparsity_function' : self.sparsity_function, - 'sparsity_penalty' : self.sparsity_coeff, - 'p_start' : self.p_start, - 'p_end' : self.p_end, - 'anneal_start' : self.anneal_start, - 'sparsity_queue_length' : self.sparsity_queue_length, - 'n_sparsity_updates' : self.n_sparsity_updates, - 'warmup_steps' : self.warmup_steps, - 'resample_steps' : self.resample_steps, - 'steps' : self.steps, - 'seed' : self.seed, - 'layer' : self.layer, - 'lm_name' : self.lm_name, - 'wandb_name' : self.wandb_name, - 'submodule_name' : self.submodule_name, + "trainer_class": "PAnnealTrainer", + "dict_class": "AutoEncoder", + "activation_dim": self.activation_dim, + "dict_size": self.dict_size, + "lr": self.lr, + "sparsity_function": self.sparsity_function, + "sparsity_penalty": self.sparsity_coeff, + "p_start": self.p_start, + "p_end": self.p_end, + "anneal_start": self.anneal_start, + "sparsity_queue_length": self.sparsity_queue_length, + "n_sparsity_updates": self.n_sparsity_updates, + "warmup_steps": self.warmup_steps, + "resample_steps": self.resample_steps, + "steps": self.steps, + "seed": self.seed, + "layer": self.layer, + "lm_name": self.lm_name, + "wandb_name": self.wandb_name, + "submodule_name": self.submodule_name, } diff --git a/dictionary_learning/trainers/standard.py b/dictionary_learning/trainers/standard.py index 1c82431..848b81f 100644 --- a/dictionary_learning/trainers/standard.py +++ b/dictionary_learning/trainers/standard.py @@ -1,20 +1,23 @@ """ Implements the standard SAE training scheme. """ + import torch as t from ..trainers.trainer import SAETrainer from ..config import DEBUG from ..dictionary import AutoEncoder from collections import namedtuple + class ConstrainedAdam(t.optim.Adam): """ A variant of Adam where some of the parameters are constrained to have unit norm. """ + def __init__(self, params, constrained_params, lr): super().__init__(params, lr=lr) self.constrained_params = list(constrained_params) - + def step(self, closure=None): with t.no_grad(): for p in self.constrained_params: @@ -27,25 +30,28 @@ def step(self, closure=None): # renormalize the constrained parameters p /= p.norm(dim=0, keepdim=True) + class StandardTrainer(SAETrainer): """ Standard SAE training scheme. 
""" - def __init__(self, - dict_class=AutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=1e-3, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name='StandardTrainer', - submodule_name=None, - compile=False, + + def __init__( + self, + dict_class=AutoEncoder, + activation_dim=512, + dict_size=64 * 512, + lr=1e-3, + l1_penalty=1e-1, + warmup_steps=1000, # lr warmup period at start of training and after each resample + resample_steps=None, # how often to resample neurons + seed=None, + device=None, + layer=None, + lm_name=None, + wandb_name="StandardTrainer", + submodule_name=None, + compile=False, ): super().__init__(seed) @@ -63,37 +69,47 @@ def __init__(self, if compile: self.ae = t.compile(self.ae) self.lr = lr - self.l1_penalty=l1_penalty + self.l1_penalty = l1_penalty self.warmup_steps = warmup_steps self.wandb_name = wandb_name if device is None: - self.device = 'cuda' if t.cuda.is_available() else 'cpu' + self.device = "cuda" if t.cuda.is_available() else "cpu" else: self.device = device self.ae.to(self.device) self.resample_steps = resample_steps - if self.resample_steps is not None: # how many steps since each neuron was last activated? - self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) + self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to( + self.device + ) else: - self.steps_since_active = None + self.steps_since_active = None - self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + self.optimizer = ConstrainedAdam( + self.ae.parameters(), self.ae.decoder.parameters(), lr=lr + ) if resample_steps is None: + def warmup_fn(step): - return min(step / warmup_steps, 1.) + return min(step / warmup_steps, 1.0) + else: + def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + return min((step % resample_steps) / warmup_steps, 1.0) + + self.scheduler = t.optim.lr_scheduler.LambdaLR( + self.optimizer, lr_lambda=warmup_fn + ) def resample_neurons(self, deads, activations): with t.no_grad(): - if deads.sum() == 0: return + if deads.sum() == 0: + return print(f"resampling {deads.sum().item()} neurons") # compute loss for each activation @@ -110,22 +126,23 @@ def resample_neurons(self, deads, activations): # resample first n_resample dead neurons deads[deads.nonzero()[n_resample:]] = False self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 - self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T - self.ae.encoder.bias[deads] = 0. - + self.ae.decoder.weight[:, deads] = ( + sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True) + ).T + self.ae.encoder.bias[deads] = 0.0 # reset Adam parameters for dead neurons - state_dict = self.optimizer.state_dict()['state'] + state_dict = self.optimizer.state_dict()["state"] ## encoder weight - state_dict[1]['exp_avg'][deads] = 0. - state_dict[1]['exp_avg_sq'][deads] = 0. + state_dict[1]["exp_avg"][deads] = 0.0 + state_dict[1]["exp_avg_sq"][deads] = 0.0 ## encoder bias - state_dict[2]['exp_avg'][deads] = 0. - state_dict[2]['exp_avg_sq'][deads] = 0. + state_dict[2]["exp_avg"][deads] = 0.0 + state_dict[2]["exp_avg_sq"][deads] = 0.0 ## decoder weight - state_dict[3]['exp_avg'][:,deads] = 0. - state_dict[3]['exp_avg_sq'][:,deads] = 0. 
- + state_dict[3]["exp_avg"][:, deads] = 0.0 + state_dict[3]["exp_avg_sq"][:, deads] = 0.0 + def loss(self, x, logging=False, **kwargs): x_hat, f = self.ae(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() @@ -136,23 +153,24 @@ def loss(self, x, logging=False, **kwargs): deads = (f == 0).all(dim=0) self.steps_since_active[deads] += 1 self.steps_since_active[~deads] = 0 - + loss = l2_loss + self.l1_penalty * l1_loss if not logging: return loss else: - return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( - x, x_hat, f, + return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( + x, + x_hat, + f, { - 'l2_loss' : l2_loss.item(), - 'mse_loss' : (x - x_hat).pow(2).sum(dim=-1).mean().item(), - 'sparsity_loss' : l1_loss.item(), - 'loss' : loss.item() - } + "l2_loss": l2_loss.item(), + "mse_loss": (x - x_hat).pow(2).sum(dim=-1).mean().item(), + "sparsity_loss": l1_loss.item(), + "loss": loss.item(), + }, ) - def update(self, step, activations): activations = activations.to(self.device) @@ -163,23 +181,24 @@ def update(self, step, activations): self.scheduler.step() if self.resample_steps is not None and step % self.resample_steps == 0: - self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) + self.resample_neurons( + self.steps_since_active > self.resample_steps / 2, activations + ) @property def config(self): return { - 'dict_class': 'AutoEncoder', - 'trainer_class' : 'StandardTrainer', - 'activation_dim': self.ae.activation_dim, - 'dict_size': self.ae.dict_size, - 'lr' : self.lr, - 'l1_penalty' : self.l1_penalty, - 'warmup_steps' : self.warmup_steps, - 'resample_steps' : self.resample_steps, - 'device' : self.device, - 'layer' : self.layer, - 'lm_name' : self.lm_name, - 'wandb_name': self.wandb_name, - 'submodule_name': self.submodule_name, + "dict_class": "AutoEncoder", + "trainer_class": "StandardTrainer", + "activation_dim": self.ae.activation_dim, + "dict_size": self.ae.dict_size, + "lr": self.lr, + "l1_penalty": self.l1_penalty, + "warmup_steps": self.warmup_steps, + "resample_steps": self.resample_steps, + "device": self.device, + "layer": self.layer, + "lm_name": self.lm_name, + "wandb_name": self.wandb_name, + "submodule_name": self.submodule_name, } - diff --git a/dictionary_learning/trainers/top_k.py b/dictionary_learning/trainers/top_k.py index 33046f5..ca244c5 100644 --- a/dictionary_learning/trainers/top_k.py +++ b/dictionary_learning/trainers/top_k.py @@ -78,7 +78,9 @@ def encode(self, x: t.Tensor, return_topk: bool = False): top_indices_BK = post_topk.indices buffer_BF = t.zeros_like(post_relu_feat_acts_BF) - encoded_acts_BF = buffer_BF.scatter_(dim=-1, index=top_indices_BK, src=tops_acts_BK) + encoded_acts_BF = buffer_BF.scatter_( + dim=-1, index=top_indices_BK, src=tops_acts_BK + ) if return_topk: return encoded_acts_BF, tops_acts_BK, top_indices_BK @@ -180,7 +182,9 @@ def __init__( self.dead_feature_threshold = 10_000_000 # Optimizer and scheduler - self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + self.optimizer = t.optim.Adam( + self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999) + ) def lr_fn(step): if step < decay_start: @@ -242,7 +246,9 @@ def loss(self, x, step=None, logging=False): auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) auxk_buffer_BF = t.zeros_like(f) - auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + auxk_acts_BF = auxk_buffer_BF.scatter_( + dim=-1, index=auxk_indices, src=auxk_acts + ) # Encourage the top 
~50% of dead latents to predict the residual of the # top k living latents @@ -263,7 +269,11 @@ def loss(self, x, step=None, logging=False): x, x_hat, f, - {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()}, + { + "l2_loss": l2_loss.item(), + "auxk_loss": auxk_loss.item(), + "loss": loss.item(), + }, ) def update(self, step, x): diff --git a/dictionary_learning/trainers/trainer.py b/dictionary_learning/trainers/trainer.py index 04170b9..e199f1e 100644 --- a/dictionary_learning/trainers/trainer.py +++ b/dictionary_learning/trainers/trainer.py @@ -2,15 +2,17 @@ class SAETrainer: """ Generic class for implementing SAE training algorithms """ + def __init__(self, seed=None): self.seed = seed self.logging_parameters = [] - def update(self, - step, # index of step in training - activations, # of shape [batch_size, d_submodule] - ): - pass # implemented by subclasses + def update( + self, + step, # index of step in training + activations, # of shape [batch_size, d_submodule] + ): + pass # implemented by subclasses def get_logging_parameters(self): stats = {} @@ -20,9 +22,9 @@ def get_logging_parameters(self): else: print(f"Warning: {param} not found in {self}") return stats - + @property def config(self): return { - 'wandb_name': 'trainer', + "wandb_name": "trainer", } diff --git a/dictionary_learning/training.py b/dictionary_learning/training.py index 24f18e6..ef64198 100644 --- a/dictionary_learning/training.py +++ b/dictionary_learning/training.py @@ -15,13 +15,16 @@ from .trainers.standard import StandardTrainer from .trainers.crosscoder import CrossCoderTrainer + def get_stats( trainer, act: t.Tensor, - deads_sum: bool=True, + deads_sum: bool = True, ): with t.no_grad(): - act, act_hat, f, losslog = trainer.loss(act, step=0, logging=True, return_deads=True) + act, act_hat, f, losslog = trainer.loss( + act, step=0, logging=True, return_deads=True + ) # L0 l0 = (f != 0).float().detach().cpu().sum(dim=-1).mean().item() @@ -32,12 +35,16 @@ def get_stats( } if losslog["deads"] is not None: total_feats = losslog["deads"].shape[0] - out["frac_deads"] = losslog["deads"].sum().item() / total_feats if deads_sum else losslog["deads"] + out["frac_deads"] = ( + losslog["deads"].sum().item() / total_feats + if deads_sum + else losslog["deads"] + ) # fraction of variance explained if act.dim() == 2: # act.shape: [batch, d_model] - # fraction of variance explained + # fraction of variance explained total_variance = t.var(act, dim=0).sum() residual_variance = t.var(act - act_hat, dim=0).sum() frac_variance_explained = 1 - residual_variance / total_variance @@ -45,11 +52,15 @@ def get_stats( # act.shape: [batch, layer, d_model] total_variance_per_layer = [] residual_variance_per_layer = [] - + for l in range(act.shape[1]): total_variance_per_layer.append(t.var(act[:, l, :], dim=0).cpu().sum()) - residual_variance_per_layer.append(t.var(act[:, l, :] - act_hat[:, l, :], dim=0).cpu().sum()) - out[f"cl{l}_frac_variance_explained"] = 1 - residual_variance_per_layer[l] / total_variance_per_layer[l] + residual_variance_per_layer.append( + t.var(act[:, l, :] - act_hat[:, l, :], dim=0).cpu().sum() + ) + out[f"cl{l}_frac_variance_explained"] = ( + 1 - residual_variance_per_layer[l] / total_variance_per_layer[l] + ) total_variance = sum(total_variance_per_layer) residual_variance = sum(residual_variance_per_layer) frac_variance_explained = 1 - residual_variance / total_variance @@ -58,13 +69,14 @@ def get_stats( return out + def log_stats( trainer, step: int, act: t.Tensor, 
activations_split_by_head: bool, transcoder: bool, - stage: str="train", + stage: str = "train", ): with t.no_grad(): log = {} @@ -87,11 +99,12 @@ def log_stats( wandb.log(log, step=step) + @t.no_grad() def run_validation( trainer, validation_data, - step: int=None, + step: int = None, ): l0 = [] frac_variance_explained = [] @@ -106,7 +119,9 @@ def run_validation( frac_variance_explained.append(stats["frac_variance_explained"]) if isinstance(trainer, CrossCoderTrainer): for l in range(act.shape[1]): - frac_variance_explained_per_layer[l].append(stats[f"cl{l}_frac_variance_explained"]) + frac_variance_explained_per_layer[l].append( + stats[f"cl{l}_frac_variance_explained"] + ) log = {} log["val/frac_deads"] = t.stack(deads).all(dim=0).float().mean().item() @@ -114,11 +129,14 @@ def run_validation( log["val/frac_variance_explained"] = t.tensor(frac_variance_explained).mean() if isinstance(trainer, CrossCoderTrainer): for l in range(act.shape[1]): - log[f"val/cl{l}_frac_variance_explained"] = t.tensor(frac_variance_explained_per_layer[l]).mean() + log[f"val/cl{l}_frac_variance_explained"] = t.tensor( + frac_variance_explained_per_layer[l] + ).mean() if step is not None: log["step"] = step wandb.log(log, step=step) + def trainSAE( data, trainer_config, @@ -138,14 +156,22 @@ def trainSAE( """ Train SAE using the given trainer """ - assert not(validation_data is None and validate_every_n_steps is not None), "Must provide validation data if validate_every_n_steps is not None" + assert not ( + validation_data is None and validate_every_n_steps is not None + ), "Must provide validation data if validate_every_n_steps is not None" trainer_class = trainer_config["trainer"] del trainer_config["trainer"] trainer = trainer_class(**trainer_config) wandb_config = trainer.config | run_cfg - wandb.init(entity=wandb_entity, project=wandb_project, config=wandb_config, name=wandb_config["wandb_name"], mode="disabled" if not use_wandb else "online") + wandb.init( + entity=wandb_entity, + project=wandb_project, + config=wandb_config, + name=wandb_config["wandb_name"], + mode="disabled" if not use_wandb else "online", + ) # make save dir, export config if save_dir is not None: @@ -166,23 +192,34 @@ def trainSAE( # logging if log_steps is not None and step % log_steps == 0 and step != 0: - log_stats( - trainer, step, act, activations_split_by_head, transcoder - ) + log_stats(trainer, step, act, activations_split_by_head, transcoder) # saving if save_steps is not None and step % save_steps == 0: if save_dir is not None: - os.makedirs(os.path.join(save_dir, trainer.config["wandb_name"].lower()), exist_ok=True) + os.makedirs( + os.path.join(save_dir, trainer.config["wandb_name"].lower()), + exist_ok=True, + ) t.save( - trainer.ae.state_dict() if not trainer_config["compile"] else trainer.ae._orig_mod.state_dict(), - os.path.join(save_dir, trainer.config["wandb_name"].lower(), f"ae_{step}.pt"), + ( + trainer.ae.state_dict() + if not trainer_config["compile"] + else trainer.ae._orig_mod.state_dict() + ), + os.path.join( + save_dir, trainer.config["wandb_name"].lower(), f"ae_{step}.pt" + ), ) # training trainer.update(step, act) - if validate_every_n_steps is not None and step % validate_every_n_steps == 0 and step != 0: + if ( + validate_every_n_steps is not None + and step % validate_every_n_steps == 0 + and step != 0 + ): print(f"Validating at step {step}") run_validation(trainer, validation_data, step=step) @@ -193,8 +230,19 @@ def trainSAE( # save final SAE if save_dir is not None: - 
os.makedirs(os.path.join(save_dir, trainer.config["wandb_name"].lower()), exist_ok=True) - t.save(trainer.ae.state_dict() if not trainer_config["compile"] else trainer.ae._orig_mod.state_dict(), os.path.join(save_dir, trainer.config["wandb_name"].lower(), f"ae_final.pt")) + os.makedirs( + os.path.join(save_dir, trainer.config["wandb_name"].lower()), exist_ok=True + ) + t.save( + ( + trainer.ae.state_dict() + if not trainer_config["compile"] + else trainer.ae._orig_mod.state_dict() + ), + os.path.join( + save_dir, trainer.config["wandb_name"].lower(), "ae_final.pt" + ), + ) if use_wandb: wandb.finish() diff --git a/dictionary_learning/utils.py b/dictionary_learning/utils.py index 8641f05..bf629fa 100644 --- a/dictionary_learning/utils.py +++ b/dictionary_learning/utils.py @@ -3,25 +3,29 @@ import io import json -def hf_dataset_to_generator(dataset_name, split='train', streaming=True): + +def hf_dataset_to_generator(dataset_name, split="train", streaming=True): dataset = load_dataset(dataset_name, split=split, streaming=streaming) - + def gen(): for x in iter(dataset): - yield x['text'] - + yield x["text"] + return gen() + def zst_to_generator(data_path): """ Load a dataset from a .jsonl.zst file. Each jsonl entry is assumed to have a 'text' field """ - compressed_file = open(data_path, 'rb') + compressed_file = open(data_path, "rb") dctx = zstd.ZstdDecompressor() reader = dctx.stream_reader(compressed_file) - text_stream = io.TextIOWrapper(reader, encoding='utf-8') + text_stream = io.TextIOWrapper(reader, encoding="utf-8") + def generator(): for line in text_stream: - yield json.loads(line)['text'] - return generator() \ No newline at end of file + yield json.loads(line)["text"] + + return generator()
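
A note on `ConstrainedAdam`, which this patch reformats in identical copies inside gdm.py, p_anneal.py, and standard.py: it keeps the decoder columns (the dictionary elements) at unit norm. The hunks above only show the renormalization tail of `step`; the sketch below fills in the gradient projection as an assumption about the elided lines, so read it as the general pattern rather than the file's exact contents.

import torch as t

class UnitNormAdam(t.optim.Adam):
    # Sketch of the ConstrainedAdam pattern. `constrained_params` are kept
    # at unit column norm; everything else is plain Adam. The projection
    # step is an assumption, since those lines fall outside the hunks above.
    def __init__(self, params, constrained_params, lr):
        super().__init__(params, lr=lr)
        self.constrained_params = list(constrained_params)

    def step(self, closure=None):
        with t.no_grad():
            for p in self.constrained_params:
                normed = p / p.norm(dim=0, keepdim=True)
                # drop the gradient component parallel to each unit-norm
                # column, so the update moves along the sphere, not off it
                p.grad -= (p.grad * normed).sum(dim=0, keepdim=True) * normed
        super().step(closure=closure)
        with t.no_grad():
            for p in self.constrained_params:
                # renormalize the constrained parameters (as in the hunk above)
                p /= p.norm(dim=0, keepdim=True)

Keeping decoder columns at unit norm closes off the degenerate solution where the sparsity penalty is shrunk by scaling features down in the encoder and up in the decoder.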
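
The least obvious arithmetic in `PAnnealTrainer.loss` (and its twin in the gated anneal trainer) is the sparsity-coefficient update. Each step the queue records the Lp loss at the current `p` and at the upcoming `next_p`; when `p` switches, the coefficient is rescaled by the ratio of their recent means so that the scaled penalty `sparsity_coeff * lp_loss` stays roughly continuous across the switch. A minimal sketch of that arithmetic, with made-up numbers:

import torch as t

# Each entry is [lp_loss at current p, lp_loss at next_p]; values invented.
sparsity_queue = [[12.0, 30.0], [11.0, 28.0], [13.0, 31.0]]
sparsity_coeff = 1e-1

loss_at_p = t.tensor([pair[0] for pair in sparsity_queue]).mean()
loss_at_next_p = t.tensor([pair[1] for pair in sparsity_queue]).mean()

# After this, sparsity_coeff * loss_at_next_p equals the old
# sparsity_coeff * loss_at_p: the total sparsity term does not jump
# when p changes.
sparsity_coeff = sparsity_coeff * (loss_at_p / loss_at_next_p).item()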
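
The magic indices in the resampling code (`state_dict[1]`, `[2]`, `[3]`) deserve a gloss: Adam's `state_dict()["state"]` keys entries by each parameter's position in the order it was handed to the optimizer. The toy below reproduces the layout those indices imply; that ordering (bias first, then encoder weight and bias, then decoder weight) is an inference from the patch, not something the diff states.

import torch as t
import torch.nn as nn

class ToyAE(nn.Module):
    def __init__(self, d=8, n=16):
        super().__init__()
        self.bias = nn.Parameter(t.zeros(d))        # param 0
        self.encoder = nn.Linear(d, n, bias=True)   # params 1 (weight), 2 (bias)
        self.decoder = nn.Linear(n, d, bias=False)  # param 3 (weight)

ae = ToyAE()
opt = t.optim.Adam(ae.parameters(), lr=1e-3)
ae.decoder(ae.encoder(t.randn(4, 8) + ae.bias)).pow(2).sum().backward()
opt.step()  # populates exp_avg / exp_avg_sq for all four parameters

deads = t.zeros(16, dtype=t.bool)
deads[:3] = True  # pretend the first three features are dead
state = opt.state_dict()["state"]  # the tensors are shared, not copied
# Zeroing the moments for resampled features stops stale momentum from
# immediately dragging the re-initialized directions back toward zero.
state[1]["exp_avg"][deads] = 0.0     # encoder weight rows
state[1]["exp_avg_sq"][deads] = 0.0
state[2]["exp_avg"][deads] = 0.0     # encoder bias entries
state[2]["exp_avg_sq"][deads] = 0.0
state[3]["exp_avg"][:, deads] = 0.0  # decoder weight columns
state[3]["exp_avg_sq"][:, deads] = 0.0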
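
In top_k.py, the rewrapped `scatter_` call is the whole TopK encoder: keep the k largest post-ReLU activations per row and zero the rest. A self-contained illustration of the trick:

import torch as t

post_relu = t.relu(t.randn(4, 16))  # [batch, dict_size]
k = 3
tops = post_relu.topk(k, dim=-1)
# Write the k winning activations back into a zero buffer at their
# original feature indices; everything else stays exactly zero.
encoded = t.zeros_like(post_relu).scatter_(
    dim=-1, index=tops.indices, src=tops.values
)
assert (encoded != 0).sum(dim=-1).le(k).all()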
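
Finally, a minimal way to drive the reformatted `trainSAE`, to make the `trainer_config` contract concrete. Assumptions: `data` can be any iterable of activation batches (an `ActivationBuffer` in real use; a random generator here), keyword arguments not visible in this patch keep their defaults, and the sizes are placeholders.

import torch as t

from dictionary_learning.trainers.standard import StandardTrainer
from dictionary_learning.training import trainSAE

def fake_activation_stream(n_batches=100, batch_size=256, d=512):
    # stand-in for an ActivationBuffer: yields random activation batches
    for _ in range(n_batches):
        yield t.randn(batch_size, d)

trainer_config = {
    "trainer": StandardTrainer,  # popped by trainSAE before instantiation
    "activation_dim": 512,       # placeholder sizes
    "dict_size": 64 * 512,
    "lr": 1e-3,
    "l1_penalty": 1e-1,
    "warmup_steps": 1000,
    "compile": False,  # trainSAE reads this back when saving checkpoints
}

trainSAE(fake_activation_stream(), trainer_config, use_wandb=False)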