From b45c4c3a1f6ce57ad3078493f78d6f8bd4668d92 Mon Sep 17 00:00:00 2001
From: Taha YASSINE
Date: Tue, 17 Dec 2024 22:10:30 +0100
Subject: [PATCH] Update caching to work with FA2

---
 sae_auto_interp/features/cache.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/sae_auto_interp/features/cache.py b/sae_auto_interp/features/cache.py
index 400f45e..3dafc27 100644
--- a/sae_auto_interp/features/cache.py
+++ b/sae_auto_interp/features/cache.py
@@ -154,6 +154,11 @@ def __init__(
             batch_size (int): Size of batches for processing.
             filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features.
         """
+
+        # Model must use FA2 to allow for efficient packing
+        if not hasattr(model.config, "_attn_implementation") or model.config._attn_implementation != "flash_attention_2":
+            raise ValueError("Model must use FlashAttention-2. Please enable it before initializing FeatureCache.")
+
         self.model = model
         self.submodule_dict = submodule_dict
 
@@ -224,7 +229,8 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):
 
         with torch.no_grad():
             buffer = {}
-            with self.model.trace(batch):
+            # position_ids is required for FA2
+            with self.model.trace({"input_ids": batch["input_ids"]}, position_ids=batch["position_ids"]):
                 for module_path, submodule in self.submodule_dict.items():
                     buffer[module_path] = submodule.ae.output.save()
             for module_path, latents in buffer.items():
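---

Note for reviewers (not part of the patch): a minimal sketch of how a caller
might satisfy the new requirement before constructing the cache. The
checkpoint name, the `seq_lens` variable, and the `packed_position_ids`
helper are illustrative assumptions, not identifiers from this repository;
`attn_implementation="flash_attention_2"` is the standard Hugging Face
transformers kwarg that the new `_attn_implementation` check looks for.

    import torch
    from transformers import AutoModelForCausalLM

    # Load the model with FA2 enabled so the new FeatureCache check passes.
    # The checkpoint name is an assumption; any FA2-capable model works.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.1-8B",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    def packed_position_ids(seq_lens):
        """Hypothetical helper: build position_ids for a packed batch,
        restarting at 0 for each sequence so FA2 can recover the
        sequence boundaries, e.g. [3, 2] -> [[0, 1, 2, 0, 1]]."""
        return torch.cat([torch.arange(n) for n in seq_lens]).unsqueeze(0)

Restarting position ids at each boundary is what lets FA2 treat a packed row
as independent sequences; passing a single 0..N-1 range instead would silently
attend across sequence boundaries.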