Merge pull request #3 from jkminder/hf_hub
Added push and load from hub
Butanium authored Nov 22, 2024
2 parents daddac6 + ee9c812 commit de73513
Showing 2 changed files with 84 additions and 20 deletions.
35 changes: 34 additions & 1 deletion README.md
@@ -2,9 +2,42 @@
This repo contains a few new features compared to the original repo:
- It is `pip` installable.
- A new `CrossCoder` class for training crosscoders as described in [the Anthropic paper](https://transformer-circuits.pub/drafts/crosscoders/index.html#model-diffing):
```py
# pip install git+https://github.com/jkminder/dictionary_learning
from dictionary_learning import CrossCoder
from nnsight import LanguageModel
import torch as th

crosscoder = CrossCoder.from_pretrained("Butanium/gemma-2-2b-crosscoder-l13-mu4.1e-02-lr1e-04", from_hub=True)
gemma_2 = LanguageModel("google/gemma-2-2b", device_map="cuda:0")
gemma_2_it = LanguageModel("google/gemma-2-2b-it", device_map="cuda:1")
prompt = "quick fox brown"

with gemma_2.trace(prompt):
    l13_act_base = gemma_2.model.layers[13].output[0][:, -1].save()  # (1, 2304)
    gemma_2.model.layers[13].output.stop()

with gemma_2_it.trace(prompt):
    l13_act_it = gemma_2_it.model.layers[13].output[0][:, -1].save()  # (1, 2304)
    gemma_2_it.model.layers[13].output.stop()


crosscoder_input = th.cat([l13_act_base, l13_act_it], dim=0).unsqueeze(0).cpu() # (batch, 2, 2304)
print(crosscoder_input.shape)
reconstruction, features = crosscoder(crosscoder_input, output_features=True)

# print metrics
print(f"MSE loss: {th.nn.functional.mse_loss(reconstruction, crosscoder_input).item():.2f}")
print(f"L1 sparsity: {features.abs().sum():.1f}")
print(f"L0 sparsity: {(features > 1e-4).sum()}")
```
- `cache.py` provides a way to cache activations so that they can be loaded later to train an SAE or CrossCoder (see the sketch after this list).
- `scripts/train_crosscoder.py` is a script for training a CrossCoder from such pre-computed activations.

- You can now push dictionaries to and load them from the Hugging Face model hub.
```py
my_super_cool_dictionary.push_to_hub("username/my-super-cool-dictionary")
loaded_dictionary = MyDictionary.from_pretrained("username/my-super-cool-dictionary", from_hub=True)
```
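Below is a minimal sketch of the activation-caching idea mentioned above. It is not the actual `cache.py` API; it only illustrates collecting residual-stream activations with `nnsight` (as in the CrossCoder example) and saving them with `torch.save` so they can later be consumed by the training script. The model, layer, prompts, and file name are placeholders.
```py
# Illustrative sketch only -- not the cache.py interface.
import torch as th
from nnsight import LanguageModel

model = LanguageModel("google/gemma-2-2b", device_map="cuda:0")  # placeholder model
prompts = ["The quick brown fox", "jumps over the lazy dog"]     # placeholder data

cached = []
for prompt in prompts:
    with model.trace(prompt):
        # last-token residual stream at layer 13, shape (1, hidden_dim)
        act = model.model.layers[13].output[0][:, -1].save()
        model.model.layers[13].output.stop()
    cached.append(act)

# save to disk so an SAE / CrossCoder can be trained on these activations later
th.save(th.cat(cached, dim=0).cpu(), "layer13_activations.pt")
```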

# Original README

69 changes: 50 additions & 19 deletions dictionary_learning/dictionary.py
@@ -3,6 +3,8 @@
"""

from abc import ABC, abstractclassmethod, abstractmethod
from huggingface_hub import PyTorchModelHubMixin

import torch as th
import torch.nn as nn
import torch.nn.init as init
@@ -11,7 +13,7 @@
from warnings import warn


class Dictionary(ABC, nn.Module):
class Dictionary(ABC, nn.Module, PyTorchModelHubMixin):
"""
A dictionary consists of a collection of vectors, an encoder, and a decoder.
"""
@@ -35,11 +37,24 @@ def decode(self, f):

    @classmethod
    @abstractmethod
    def from_pretrained(cls, path, device=None, **kwargs) -> "Dictionary":
    def from_pretrained(
        cls, path, from_hub=False, device=None, dtype=None, **kwargs
    ) -> "Dictionary":
        """
        Load a pretrained dictionary from a file.
        Load a pretrained dictionary from a file or hub.
        Args:
            path: Path to local file or hub model id
            from_hub: If True, load from HuggingFace hub using PyTorchModelHubMixin
            device: Device to load the model to
            **kwargs: Additional arguments passed to loading function
        """
        pass
        model = super(Dictionary, cls).from_pretrained(path, **kwargs)
        if device is not None:
            model.to(device)
        if dtype is not None:
            model.to(dtype=dtype)
        return model
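With this base implementation, every `Dictionary` subclass gains hub loading through `PyTorchModelHubMixin`. A minimal usage sketch, assuming a hypothetical repo id:
```py
# Usage sketch (repo id is a placeholder): load a subclass from the hub and move it to a device.
ae = AutoEncoder.from_pretrained("username/my-sae", from_hub=True, device="cuda")
```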


class AutoEncoder(Dictionary, nn.Module):
@@ -96,10 +111,13 @@ def forward(self, x, output_features=False, ghost_mask=None):
                return x_hat, x_ghost

    @classmethod
    def from_pretrained(cls, path, dtype=th.float, device=None):
        """
        Load a pretrained autoencoder from a file.
        """
    def from_pretrained(
        cls, path, dtype=th.float, from_hub=False, device=None, **kwargs
    ):
        if from_hub:
            return super().from_pretrained(path, dtype=dtype, device=device, **kwargs)

        # Existing custom loading logic
        state_dict = th.load(path)
        dict_size, activation_dim = state_dict["encoder.weight"].shape
        autoencoder = cls(activation_dim, dict_size)
@@ -218,13 +236,15 @@ def forward(self, x, output_features=False):
        else:
            return x_hat

    def from_pretrained(path, device=None):
        """
        Load a pretrained autoencoder from a file.
        """
    @classmethod
    def from_pretrained(cls, path, from_hub=False, device=None, dtype=None, **kwargs):
        if from_hub:
            return super().from_pretrained(path, device=device, dtype=dtype, **kwargs)

        # Existing custom loading logic
        state_dict = th.load(path)
        dict_size, activation_dim = state_dict["encoder.weight"].shape
        autoencoder = GatedAutoEncoder(activation_dim, dict_size)
        autoencoder = cls(activation_dim, dict_size)
        autoencoder.load_state_dict(state_dict)
        if device is not None:
            autoencoder.to(device)
@@ -288,6 +308,7 @@ def from_pretrained(
        cls,
        path: str | None = None,
        load_from_sae_lens: bool = False,
        from_hub: bool = False,
        dtype: th.dtype = th.float32,
        device: th.device | None = None,
        **kwargs,
@@ -298,9 +319,13 @@
        loading function.
        """
        if not load_from_sae_lens:
            if from_hub:
                return super().from_pretrained(
                    path, device=device, dtype=dtype, **kwargs
                )
            state_dict = th.load(path)
            dict_size, activation_dim = state_dict["W_enc"].shape
            autoencoder = JumpReluAutoEncoder(activation_dim, dict_size)
            autoencoder = cls(activation_dim, dict_size)
            autoencoder.load_state_dict(state_dict)
        else:
            from sae_lens import SAE
@@ -364,13 +389,14 @@ def forward(self, x, output_features=False):
            f = f * self.decoder.weight.norm(dim=0, keepdim=True)
            return x_hat, f

    def from_pretrained(path, device=None):
        """
        Load a pretrained autoencoder from a file.
        """
    @classmethod
    def from_pretrained(cls, path, device=None, from_hub=False, dtype=None, **kwargs):
        if from_hub:
            return super().from_pretrained(path, device=device, dtype=dtype, **kwargs)

        state_dict = th.load(path)
        dict_size, activation_dim = state_dict["encoder.weight"].shape
        autoencoder = AutoEncoderNew(activation_dim, dict_size)
        autoencoder = cls(activation_dim, dict_size)
        autoencoder.load_state_dict(state_dict)
        if device is not None:
            autoencoder.to(device)
@@ -596,10 +622,15 @@ def from_pretrained(
        path: str,
        dtype: th.dtype = th.float32,
        device: th.device | None = None,
        from_hub: bool = False,
        **kwargs,
    ):
        """
        Load a pretrained cross-coder from a file.
        """
        if from_hub:
            return super().from_pretrained(path, device=device, dtype=dtype, **kwargs)

        state_dict = th.load(path, map_location="cpu", weights_only=True)
        if "encoder.weight" not in state_dict:
            warn(
