Enable sequence packing with FlashAttention-2 #41

Draft · wants to merge 4 commits into v0.2
4 changes: 4 additions & 0 deletions README.md
@@ -14,6 +14,10 @@ Install this library as a local editable installation. Run the following command

```pip install -e .```

Then install FlashAttention-2.

```pip install flash-attn --no-build-isolation```

# Loading Autoencoders

This library uses NNsight to load and edit a model with autoencoders. We provide wrappers to load GPT-2 autoencoders trained by [OpenAI](https://github.com/openai/sparse_autoencoder), the [GemmaScope SAEs](https://arxiv.org/abs/2408.05147), and some SAEs trained by EleutherAI using [SAE](https://github.com/EleutherAI/sae). See the [examples](examples/loading_saes.ipynb) directory for specific examples.
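For context on the new requirement: `FeatureCache` (below) will refuse models that were not loaded with FlashAttention-2. A minimal sketch of enabling it, assuming a plain `transformers` causal LM on a CUDA device; the model name and dtype are illustrative, not part of this PR:

```python
import torch
from transformers import AutoModelForCausalLM

# FA2 needs a CUDA device and half precision (fp16 or bf16).
model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2",                  # illustrative model
    attn_implementation="flash_attention_2",  # sets config._attn_implementation
    torch_dtype=torch.bfloat16,
).cuda()

assert model.config._attn_implementation == "flash_attention_2"
```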
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -10,6 +10,7 @@ readme = "README.md"
requires-python = ">=3.10"
keywords = ["interpretability", "explainable-ai"]
dependencies = [
"transformers>=4.48.0",
"datasets",
"nnsight",
"orjson",
@@ -21,6 +22,7 @@ dependencies = [
"blobfile",
"transformer_lens",
"bitsandbytes",
# "flash-attn", # Install using: pip install flash-attn --no-build-isolation
]

[tool.pyright]
8 changes: 7 additions & 1 deletion sae_auto_interp/features/cache.py
@@ -154,6 +154,11 @@ def __init__(
batch_size (int): Size of batches for processing.
filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features.
"""

# Model must use FA2 to allow for efficient packing
if getattr(model.config, "_attn_implementation", None) != "flash_attention_2":
raise ValueError("Model must use FlashAttention-2. Please enable it before initializing FeatureCache.")

self.model = model
self.submodule_dict = submodule_dict

@@ -224,7 +229,8 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):

with torch.no_grad():
buffer = {}
with self.model.trace(batch):
# position_ids must be passed explicitly so FA2 can detect packed-sequence boundaries
with self.model.trace({"input_ids": batch["input_ids"]}, position_ids=batch["position_ids"]):
for module_path, submodule in self.submodule_dict.items():
buffer[module_path] = submodule.ae.output.save()
for module_path, latents in buffer.items():
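For intuition on the `position_ids` requirement: in a packed batch, documents are concatenated along the sequence dimension, and the position IDs restarting at 0 are what the transformers FA2 path uses to keep attention from crossing document boundaries. A toy illustration (token values made up):

```python
import torch

# Two documents, lengths 3 and 2, packed into a single row.
batch = {
    "input_ids": torch.tensor([[11, 12, 13, 21, 22]]),
    "position_ids": torch.tensor([[0, 1, 2, 0, 1]]),  # a reset marks a new document
}
# Under FA2 with these position_ids, tokens 21-22 never attend to 11-13.
```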
53 changes: 52 additions & 1 deletion sae_auto_interp/utils.py
@@ -1,6 +1,57 @@
from transformers import AutoTokenizer
from transformers import AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset as hf_load_dataset
from typing import List, Dict
import torch

def packed_data_collator(input_ids: List[List[int]], return_tensors: str = "pt"):
"""
Packs a list of tokenized sequences into a single batch.

Input IDs are concatenated into one row, and position IDs restart at 0 for each sequence, which is how FlashAttention-2 recovers the sequence boundaries.

Args:
input_ids (List[List[int]]): The tokenized sequences to pack.
return_tensors (str, optional): The type of tensors to return. Defaults to "pt".

Returns:
Dict: The collated batch with packed "input_ids" and "position_ids".
"""
batch = {"input_ids": [], "position_ids": []}
for seq in input_ids:
batch["input_ids"] += seq
batch["position_ids"] += list(range(len(seq)))

return default_data_collator([batch], return_tensors=return_tensors)
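A quick usage sketch of the collator (token values made up); note the output carries a leading batch dimension of 1:

```python
batch = packed_data_collator([[5, 6], [7, 8, 9]])
print(batch["input_ids"])     # tensor([[5, 6, 7, 8, 9]])
print(batch["position_ids"])  # tensor([[0, 1, 0, 1, 2]])
```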

def load_dataset(
ctx_len: int,
tokenizer: AutoTokenizer,
dataset_repo: str,
dataset_split: str,
dataset_name: str = "",
column_name: str = "raw_content",
seed: int = 22,
) -> Dataset:
"""
Load a Hugging Face dataset, shuffle it, tokenize it, and pack it.

Args:
ctx_len (int): Context length for tokenization (not currently used by this function).
tokenizer (AutoTokenizer): The tokenizer to use.
dataset_repo (str): The dataset repository name.
dataset_split (str): The dataset split to use.
dataset_name (str, optional): The dataset name. Defaults to "".
column_name (str, optional): The column name to use for tokenization. Defaults to "raw_content".
seed (int, optional): Random seed for shuffling. Defaults to 22.

Returns:
Dataset: The shuffled, tokenized, and packed dataset.
"""
dataset = hf_load_dataset(dataset_repo, name=dataset_name, split=dataset_split)
dataset = dataset.shuffle(seed)
dataset = dataset.map(lambda x: tokenizer(x[column_name], add_special_tokens=False, return_attention_mask=False), batched=True)
dataset = dataset.map(lambda x: packed_data_collator(x['input_ids']), batched=True, batch_size=10, remove_columns=dataset.column_names) # TODO: Change hardcoded batch size

return dataset

def load_tokenized_data(
ctx_len: int,
tokenizer: AutoTokenizer,
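Taken together, a hedged usage sketch of the new `load_dataset` helper (the dataset repo and split here are illustrative; `raw_content` is the function's default column):

```python
from transformers import AutoTokenizer
from sae_auto_interp.utils import load_dataset

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
dataset = load_dataset(
    ctx_len=256,
    tokenizer=tokenizer,
    dataset_repo="EleutherAI/rpj-v2-sample",  # illustrative dataset
    dataset_split="train[:1%]",
)
print(dataset.column_names)  # ['input_ids', 'position_ids']
```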