Skip to content

Commit

Permalink
testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Butanium committed Nov 22, 2024
1 parent 1f026b9 commit 9ef2acd
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions dictionary_learning/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from torch.nn.functional import relu
import einops
from warnings import warn
import tempfile


class Dictionary(ABC, nn.Module, PyTorchModelHubMixin):
Expand Down Expand Up @@ -593,29 +592,29 @@ def forward(self, x: th.Tensor, output_features=False):
else:
return x_hat

@classmethod
def from_pretrained(
cls,
path: str,
dtype: th.dtype = th.float32,
device: th.device | None = None,
):
"""
Load a pretrained cross-coder from a file.
"""
state_dict = th.load(path, map_location="cpu", weights_only=True)
if "encoder.weight" not in state_dict:
warn(
"Cross-coder state dict was saved while torch.compiled was enabled. Fixing..."
)
state_dict = {k.split("_orig_mod.")[1]: v for k, v in state_dict.items()}
num_layers, activation_dim, dict_size = state_dict["encoder.weight"].shape
cross_coder = cls(activation_dim, dict_size, num_layers)
cross_coder.load_state_dict(state_dict)

if device is not None:
cross_coder = cross_coder.to(device)
return cross_coder.to(dtype=dtype)
# @classmethod
# def from_pretrained(
# cls,
# path: str,
# dtype: th.dtype = th.float32,
# device: th.device | None = None,
# ):
# """
# Load a pretrained cross-coder from a file.
# """
# state_dict = th.load(path, map_location="cpu", weights_only=True)
# if "encoder.weight" not in state_dict:
# warn(
# "Cross-coder state dict was saved while torch.compiled was enabled. Fixing..."
# )
# state_dict = {k.split("_orig_mod.")[1]: v for k, v in state_dict.items()}
# num_layers, activation_dim, dict_size = state_dict["encoder.weight"].shape
# cross_coder = cls(activation_dim, dict_size, num_layers)
# cross_coder.load_state_dict(state_dict)

# if device is not None:
# cross_coder = cross_coder.to(device)
# return cross_coder.to(dtype=dtype)

def resample_neurons(self, deads, activations):
# https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling
Expand Down

0 comments on commit 9ef2acd

Please sign in to comment.