diff --git a/README.md b/README.md
index 3ec89c5..89b3bce 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,49 @@
 This repo contains a few new features compared to the original repo:
 
 - It is `pip` installable.
 - A new `Crosscoder` class for training CrossCoders as described in [the anthropic paper](https://transformer-circuits.pub/drafts/crosscoders/index.html#model-diffing).
+```py
+# Install with: pip install git+https://github.com/jkminder/dictionary_learning
+from dictionary_learning import CrossCoder
+from nnsight import LanguageModel
+import torch as th
+
+crosscoder = CrossCoder.from_pretrained("Butanium/gemma-2-2b-crosscoder-l13-mu4.1e-02-lr1e-04", from_hub=True)
+gemma_2 = LanguageModel("google/gemma-2-2b", device_map="cuda:0")
+gemma_2_it = LanguageModel("google/gemma-2-2b-it", device_map="cuda:1")
+prompt = "quick fox brown"
+
+with gemma_2.trace(prompt):
+    l13_act_base = gemma_2.model.layers[13].output[0][:, -1].save()  # (1, 2304)
+    gemma_2.model.layers[13].output.stop()
+
+with gemma_2_it.trace(prompt):
+    l13_act_it = gemma_2_it.model.layers[13].output[0][:, -1].save()  # (1, 2304)
+    gemma_2_it.model.layers[13].output.stop()
+
+
+crosscoder_input = th.cat([l13_act_base, l13_act_it], dim=0).unsqueeze(0).cpu()  # (batch, 2, 2304)
+print(crosscoder_input.shape)
+reconstruction, features = crosscoder(crosscoder_input, output_features=True)
+
+# print metrics
+print(f"MSE loss: {th.nn.functional.mse_loss(reconstruction, crosscoder_input).item():.2f}")
+print(f"L1 sparsity: {features.abs().sum():.1f}")
+print(f"L0 sparsity: {(features > 1e-4).sum()}")
+```
 - A way to cache activations in order to load them later to train a SAE or Crosscoder in `cache.py`.
 - A script for training a Crosscoder using pre-computed activations in `scripts/train_crosscoder.py`.
-
+- You can now load and push dictionaries to the Hugging Face Hub.
+```py
+my_super_cool_dictionary.push_to_hub("username/my-super-cool-dictionary")
+loaded_dictionary = MyDictionary.from_pretrained("username/my-super-cool-dictionary", from_hub=True)
+```
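+
+You can also pass `device` and `dtype`; the base `Dictionary.from_pretrained` applies them with `.to()` after loading. A minimal sketch, reusing the placeholder id and the `th` alias from the examples above:
+```py
+loaded_dictionary = MyDictionary.from_pretrained(
+    "username/my-super-cool-dictionary", from_hub=True, device="cuda", dtype=th.float32
+)
+```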
 
 # Original README
diff --git a/dictionary_learning/dictionary.py b/dictionary_learning/dictionary.py
index bd16916..b817917 100644
--- a/dictionary_learning/dictionary.py
+++ b/dictionary_learning/dictionary.py
@@ -3,6 +3,8 @@
 """
 
 from abc import ABC, abstractclassmethod, abstractmethod
+from huggingface_hub import PyTorchModelHubMixin
+
 import torch as th
 import torch.nn as nn
 import torch.nn.init as init
@@ -11,7 +13,7 @@
 from warnings import warn
 
 
-class Dictionary(ABC, nn.Module):
+class Dictionary(ABC, nn.Module, PyTorchModelHubMixin):
     """
     A dictionary consists of a collection of vectors, an encoder, and a decoder.
     """
@@ -35,11 +37,25 @@ def decode(self, f):
 
     @classmethod
     @abstractmethod
-    def from_pretrained(cls, path, device=None, **kwargs) -> "Dictionary":
+    def from_pretrained(
+        cls, path, from_hub=False, device=None, dtype=None, **kwargs
+    ) -> "Dictionary":
         """
-        Load a pretrained dictionary from a file.
+        Load a pretrained dictionary from a file or the Hugging Face Hub.
+
+        Args:
+            path: Path to a local file or a hub model id
+            from_hub: If True, load from the Hugging Face Hub using PyTorchModelHubMixin
+            device: Device to move the model to
+            dtype: Data type to cast the model to
+            **kwargs: Additional arguments passed to the loading function
         """
-        pass
+        model = super(Dictionary, cls).from_pretrained(path, **kwargs)
+        if device is not None:
+            model.to(device)
+        if dtype is not None:
+            model.to(dtype=dtype)
+        return model
 
 
 class AutoEncoder(Dictionary, nn.Module):
@@ -96,10 +111,13 @@ def forward(self, x, output_features=False, ghost_mask=None):
             return x_hat, x_ghost
 
     @classmethod
-    def from_pretrained(cls, path, dtype=th.float, device=None):
-        """
-        Load a pretrained autoencoder from a file.
-        """
+    def from_pretrained(
+        cls, path, dtype=th.float, from_hub=False, device=None, **kwargs
+    ):
+        if from_hub:
+            return super().from_pretrained(path, dtype=dtype, device=device, **kwargs)
+
+        # Existing custom loading logic
         state_dict = th.load(path)
         dict_size, activation_dim = state_dict["encoder.weight"].shape
         autoencoder = cls(activation_dim, dict_size)
@@ -218,13 +236,15 @@ def forward(self, x, output_features=False):
         else:
             return x_hat
 
-    def from_pretrained(path, device=None):
-        """
-        Load a pretrained autoencoder from a file.
-        """
+    @classmethod
+    def from_pretrained(cls, path, from_hub=False, device=None, dtype=None, **kwargs):
+        if from_hub:
+            return super().from_pretrained(path, device=device, dtype=dtype, **kwargs)
+
+        # Existing custom loading logic
         state_dict = th.load(path)
         dict_size, activation_dim = state_dict["encoder.weight"].shape
-        autoencoder = GatedAutoEncoder(activation_dim, dict_size)
+        autoencoder = cls(activation_dim, dict_size)
         autoencoder.load_state_dict(state_dict)
         if device is not None:
             autoencoder.to(device)
@@ -288,6 +308,7 @@ def from_pretrained(
         cls,
         path: str | None = None,
         load_from_sae_lens: bool = False,
+        from_hub: bool = False,
         dtype: th.dtype = th.float32,
         device: th.device | None = None,
         **kwargs,
@@ -298,9 +319,13 @@
         loading function.
         """
         if not load_from_sae_lens:
+            if from_hub:
+                return super().from_pretrained(
+                    path, device=device, dtype=dtype, **kwargs
+                )
             state_dict = th.load(path)
             dict_size, activation_dim = state_dict["W_enc"].shape
-            autoencoder = JumpReluAutoEncoder(activation_dim, dict_size)
+            autoencoder = cls(activation_dim, dict_size)
             autoencoder.load_state_dict(state_dict)
         else:
             from sae_lens import SAE
@@ -364,13 +389,14 @@ def forward(self, x, output_features=False):
             f = f * self.decoder.weight.norm(dim=0, keepdim=True)
         return x_hat, f
 
-    def from_pretrained(path, device=None):
-        """
-        Load a pretrained autoencoder from a file.
-        """
+    @classmethod
+    def from_pretrained(cls, path, device=None, from_hub=False, dtype=None, **kwargs):
+        if from_hub:
+            return super().from_pretrained(path, device=device, dtype=dtype, **kwargs)
+
         state_dict = th.load(path)
         dict_size, activation_dim = state_dict["encoder.weight"].shape
-        autoencoder = AutoEncoderNew(activation_dim, dict_size)
+        autoencoder = cls(activation_dim, dict_size)
         autoencoder.load_state_dict(state_dict)
         if device is not None:
             autoencoder.to(device)
@@ -596,10 +622,15 @@ def from_pretrained(
         path: str,
         dtype: th.dtype = th.float32,
         device: th.device | None = None,
+        from_hub: bool = False,
+        **kwargs,
     ):
         """
         Load a pretrained cross-coder from a file.
         """
+        if from_hub:
+            return super().from_pretrained(path, device=device, dtype=dtype, **kwargs)
+
         state_dict = th.load(path, map_location="cpu", weights_only=True)
         if "encoder.weight" not in state_dict:
             warn(