From 9ef2acd25e2f860ea66d2c47540abadc16e1ed63 Mon Sep 17 00:00:00 2001 From: Clement Dumas Date: Fri, 22 Nov 2024 13:45:09 +0100 Subject: [PATCH] testing --- dictionary_learning/dictionary.py | 47 +++++++++++++++---------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/dictionary_learning/dictionary.py b/dictionary_learning/dictionary.py index 1e82d31..2b02909 100644 --- a/dictionary_learning/dictionary.py +++ b/dictionary_learning/dictionary.py @@ -11,7 +11,6 @@ from torch.nn.functional import relu import einops from warnings import warn -import tempfile class Dictionary(ABC, nn.Module, PyTorchModelHubMixin): @@ -593,29 +592,29 @@ def forward(self, x: th.Tensor, output_features=False): else: return x_hat - @classmethod - def from_pretrained( - cls, - path: str, - dtype: th.dtype = th.float32, - device: th.device | None = None, - ): - """ - Load a pretrained cross-coder from a file. - """ - state_dict = th.load(path, map_location="cpu", weights_only=True) - if "encoder.weight" not in state_dict: - warn( - "Cross-coder state dict was saved while torch.compiled was enabled. Fixing..." - ) - state_dict = {k.split("_orig_mod.")[1]: v for k, v in state_dict.items()} - num_layers, activation_dim, dict_size = state_dict["encoder.weight"].shape - cross_coder = cls(activation_dim, dict_size, num_layers) - cross_coder.load_state_dict(state_dict) - - if device is not None: - cross_coder = cross_coder.to(device) - return cross_coder.to(dtype=dtype) + # @classmethod + # def from_pretrained( + # cls, + # path: str, + # dtype: th.dtype = th.float32, + # device: th.device | None = None, + # ): + # """ + # Load a pretrained cross-coder from a file. + # """ + # state_dict = th.load(path, map_location="cpu", weights_only=True) + # if "encoder.weight" not in state_dict: + # warn( + # "Cross-coder state dict was saved while torch.compiled was enabled. Fixing..." + # ) + # state_dict = {k.split("_orig_mod.")[1]: v for k, v in state_dict.items()} + # num_layers, activation_dim, dict_size = state_dict["encoder.weight"].shape + # cross_coder = cls(activation_dim, dict_size, num_layers) + # cross_coder.load_state_dict(state_dict) + + # if device is not None: + # cross_coder = cross_coder.to(device) + # return cross_coder.to(dtype=dtype) def resample_neurons(self, deads, activations): # https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling