Update caching to work with FA2
taha-yassine committed Jan 10, 2025
1 parent 5b41632 commit b45c4c3
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion sae_auto_interp/features/cache.py
@@ -154,6 +154,11 @@ def __init__(
             batch_size (int): Size of batches for processing.
             filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features.
         """
+
+        # Model must use FA2 to allow for efficient packing
+        if not hasattr(model.config, "_attn_implementation") or model.config._attn_implementation != "flash_attention_2":
+            raise ValueError("Model must use FlashAttention-2. Please enable it before initializing FeatureCache.")
+
         self.model = model
         self.submodule_dict = submodule_dict
@@ -224,7 +229,8 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):
 
         with torch.no_grad():
             buffer = {}
-            with self.model.trace(batch):
+            # position_ids is required for FA2
+            with self.model.trace({"input_ids": batch["input_ids"]}, position_ids=batch["position_ids"]):
                 for module_path, submodule in self.submodule_dict.items():
                     buffer[module_path] = submodule.ae.output.save()
             for module_path, latents in buffer.items():
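The explicit position_ids are what make packing work: when several documents are packed into one batch row, FA2 can use position_ids that restart at 0 at each document boundary to keep attention from crossing between packed sequences. A minimal sketch of building such a tensor (the helper name and the seq_lens input are illustrative, not from this repo):

import torch

def make_packed_position_ids(seq_lens):
    # Positions restart at 0 for each packed sequence, which is how
    # the boundaries between documents are signalled to FA2.
    return torch.cat([torch.arange(n) for n in seq_lens]).unsqueeze(0)

print(make_packed_position_ids([3, 4]))
# tensor([[0, 1, 2, 0, 1, 2, 3]])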
