Update caching to work with FA2
taha-yassine committed Dec 17, 2024
1 parent 15df167 commit 33225ff
Showing 4 changed files with 182 additions and 1 deletion.
46 changes: 46 additions & 0 deletions .devcontainer/devcontainer.json
@@ -0,0 +1,46 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
{
    "name": "SAE dev",
    "build": {
        // Sets the run context to one level up instead of the .devcontainer folder.
        "context": "../..",
        // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
        "dockerfile": "../../Dockerfile",
        "target": "gpu"
    },

    "runArgs": [
        "--gpus=all",
        "--shm-size=8g"
    ],

    // Features to add to the dev container. More info: https://containers.dev/features.
    // "features": {},

    // Use 'forwardPorts' to make a list of ports inside the container available locally.
    // "forwardPorts": [],

    // Uncomment the next line to run commands after the container is created.
    // "postCreateCommand": "cat /etc/os-release",

    "mounts": [
        "source=/home/tyassine/.cache/huggingface,target=/root/.cache/huggingface,type=bind"
    ],

    // Configure tool-specific properties.
    "customizations": {
        "vscode": {
            "extensions": [
                "ms-python.python",
                "ms-python.vscode-pylance",
                "ms-toolsai.jupyter",
                "mhutchie.git-graph"
            ]
        }
    }

    // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
    // "remoteUser": "devcontainer"
}
7 changes: 7 additions & 0 deletions requirements.txt
@@ -0,0 +1,7 @@
torch
datasets
# flash-attn --no-build-isolation
nnsight
setuptools
ipykernel
git+https://github.com/taha-yassine/transformers.git@patch_fa2
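(Note: flash-attn is left commented out above because --no-build-isolation is a pip-level flag that cannot be set per line in a requirements file; presumably it is installed in a separate step, e.g. pip install flash-attn --no-build-isolation.)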
8 changes: 7 additions & 1 deletion sae_auto_interp/features/cache.py
@@ -154,6 +154,11 @@ def __init__(
            batch_size (int): Size of batches for processing.
            filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features.
        """

        # Model must use FA2 to allow for efficient packing
        if not hasattr(model.config, "_attn_implementation") or model.config._attn_implementation != "flash_attention_2":
            raise ValueError("Model must use FlashAttention-2. Please enable it before initializing FeatureCache.")

        self.model = model
        self.submodule_dict = submodule_dict

@@ -224,7 +229,8 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):

        with torch.no_grad():
            buffer = {}
            # position_ids is required for FA2
            with self.model.trace({"input_ids": batch["input_ids"]}, position_ids=batch["position_ids"]):
                for module_path, submodule in self.submodule_dict.items():
                    buffer[module_path] = submodule.ae.output.save()
            for module_path, latents in buffer.items():
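For context, a minimal sketch (not part of the commit; token ids are hypothetical) of the packed batch layout the new trace call consumes. position_ids restart at 0 at each sequence boundary, which is what the FA2 path appears to rely on to separate packed sequences in place of an attention mask:

import torch

# Two tokenized sequences packed into a single row (hypothetical ids)
seqs = [[101, 7592, 2088], [101, 2307]]
input_ids = [t for s in seqs for t in s]                 # [101, 7592, 2088, 101, 2307]
position_ids = [p for s in seqs for p in range(len(s))]  # [0, 1, 2, 0, 1]
batch = {
    "input_ids": torch.tensor([input_ids]),        # shape (1, total_tokens)
    "position_ids": torch.tensor([position_ids]),  # resets mark sequence boundaries
}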
122 changes: 122 additions & 0 deletions sae_auto_interp/flash_attn.py
@@ -0,0 +1,122 @@
# %% Imports
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, default_data_collator
from typing import Dict, List
import time

# %% Functions
def data_collator(features: Dict[str, List[List[int]]], return_tensors: str = "pt"):
    batch = {"input_ids": [], "position_ids": []}
    for x in features["input_ids"]:
        batch["input_ids"] += x
        batch["position_ids"] += list(range(len(x)))

    return default_data_collator([batch], return_tensors=return_tensors)
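
# Example: input_ids [[1, 2, 3], [4, 5]] packs to
#   input_ids    -> tensor([[1, 2, 3, 4, 5]])
#   position_ids -> tensor([[0, 1, 2, 0, 1]])
# The position resets are the only record of sequence boundaries.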

def prepare_packed_dataset(
    texts: List[str],
    tokenizer: AutoTokenizer,
) -> Dict[str, torch.Tensor]:
    """
    Prepare a packed dataset using continuous batching without padding.
    """
    # Tokenize all texts
    tokenized = tokenizer(
        texts,
        add_special_tokens=False,
        return_attention_mask=False,
    )

    # Use collator to flatten and pack sequences
    packed = data_collator(tokenized, return_tensors="pt")

    return packed

def prepare_padded_dataset(
    texts: List[str],
    tokenizer: AutoTokenizer,
    max_seq_length: int = 2048
) -> Dict[str, torch.Tensor]:
    """
    Prepare a dataset using traditional padding.
    """
    return tokenizer(
        texts,
        padding=True,
        truncation=True,
        add_special_tokens=False,
        max_length=max_seq_length,
        return_tensors="pt"
    )

# %% Load model and tokenizer
model_name = "EleutherAI/pythia-70m" # Small model for testing
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# %% Load and prepare datasets
dataset = load_dataset("Open-Orca/FLAN", split="train", streaming=True).take(100)

dataset_padded = dataset.map(lambda x: prepare_padded_dataset(x["inputs"], tokenizer), batched=True, batch_size=10)

dataset_packed = dataset.map(lambda x: prepare_packed_dataset(x["inputs"], tokenizer), batched=True, batch_size=10, remove_columns=dataset.column_names)

# %% Test traditional padding approach
start_time = time.time()

with torch.no_grad():
    for batch in dataset_padded:
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        padded_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
padding_time = time.time() - start_time

# %% Test sequence packing approach
start_time = time.time()

# Each packed batch is fed to the model as a single row

with torch.no_grad():
    for batch in dataset_packed:
        input_ids = batch["input_ids"].to(model.device)
        position_ids = batch["position_ids"].to(model.device)
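        # The whole packed stream goes in as one batch row; the FA2 path relies
        # on the position_ids resets (rather than an attention mask) to keep
        # packed sequences from attending to one another.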
        packed_output = model(
            input_ids=input_ids.unsqueeze(0),
            position_ids=position_ids.unsqueeze(0),
        )
packing_time = time.time() - start_time

# %%
print(f"Traditional padding processing time: {padding_time:.2f} seconds")
print(f"Sequence packing processing time: {packing_time:.2f} seconds")
print(f"Speedup: {padding_time/packing_time:.2f}x")

# %%
from nnsight import LanguageModel
nnsight_model = LanguageModel(
    model_name,
    torch_dtype=torch.float16,
    device_map='cuda:0',
    attn_implementation='flash_attention_2'
)
# print(nnsight_model)

nnsight_model.tokenizer = tokenizer

batch = next(iter(dataset_packed))

with nnsight_model.trace({"input_ids": batch["input_ids"]}, position_ids=batch["position_ids"].to(model.device).unsqueeze(0)):
    logits = nnsight_model.embed_out.output.save()

print(logits.value)

# %%
