Commit 33225ff (1 parent: 15df167)
Showing 4 changed files with 182 additions and 1 deletion.
devcontainer.json (new file)
@@ -0,0 +1,46 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
{
    "name": "SAE dev",
    "build": {
        // Sets the run context to one level up instead of the .devcontainer folder.
        "context": "../..",
        // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
        "dockerfile": "../../Dockerfile",
        "target": "gpu"
    },

    "runArgs": [
        "--gpus=all",
        "--shm-size=8g"
    ],

    // Features to add to the dev container. More info: https://containers.dev/features.
    // "features": {},

    // Use 'forwardPorts' to make a list of ports inside the container available locally.
    // "forwardPorts": [],

    // Uncomment the next line to run commands after the container is created.
    // "postCreateCommand": "cat /etc/os-release",

    "mounts": [
        "source=/home/tyassine/.cache/huggingface,target=/root/.cache/huggingface,type=bind"
    ],

    // Configure tool-specific properties.
    "customizations": {
        "vscode": {
            "extensions": [
                "ms-python.python",
                "ms-python.vscode-pylance",
                "ms-toolsai.jupyter",
                "mhutchie.git-graph"
            ]
        }
    }

    // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
    // "remoteUser": "devcontainer"
}
requirements.txt (new file)
@@ -0,0 +1,7 @@
torch
datasets
# flash-attn --no-build-isolation
nnsight
setuptools
ipykernel
git+https://github.com/taha-yassine/transformers.git@patch_fa2
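flash-attn is commented out because it generally has to be installed in a separate step with build isolation disabled (pip install flash-attn --no-build-isolation, as the comment hints) so that its build can see the already-installed torch. The pinned transformers fork (branch patch_fa2) appears to carry a FlashAttention-2 patch that the packing test below relies on.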
[diff of one additional changed file is not rendered in this capture]

Packing benchmark script (new file)
@@ -0,0 +1,122 @@
# %% Imports
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, default_data_collator
from typing import Dict, List
import time

# %% Functions
def data_collator(features: Dict[str, List[List[int]]], return_tensors: str = "pt"):
    """Flatten a batch of tokenized sequences into a single packed row,
    restarting position_ids at the start of each sequence."""
    batch = {"input_ids": [], "position_ids": []}
    for x in features["input_ids"]:
        batch["input_ids"] += x
        batch["position_ids"] += list(range(len(x)))

    return default_data_collator([batch], return_tensors=return_tensors)
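
# A quick illustration (illustrative sketch): two toy sequences are
# flattened into one row, and position_ids restarts at each sequence
# boundary, e.g.
#   data_collator({"input_ids": [[1, 2, 3], [4, 5]]})
#   -> {"input_ids": tensor([[1, 2, 3, 4, 5]]),
#       "position_ids": tensor([[0, 1, 2, 0, 1]])}
# The patched transformers build appears to derive sequence boundaries
# from these position_ids resets, so FlashAttention-2 can keep attention
# from crossing them without any attention mask or padding.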

def prepare_packed_dataset(
    texts: List[str],
    tokenizer: AutoTokenizer,
) -> Dict[str, torch.Tensor]:
    """
    Prepare a packed dataset using continuous batching without padding.
    """
    # Tokenize all texts
    tokenized = tokenizer(
        texts,
        add_special_tokens=False,
        return_attention_mask=False,
    )

    # Use the collator to flatten and pack the tokenized sequences
    packed = data_collator(tokenized, return_tensors="pt")

    return packed

def prepare_padded_dataset(
    texts: List[str],
    tokenizer: AutoTokenizer,
    max_seq_length: int = 2048
) -> Dict[str, torch.Tensor]:
    """
    Prepare a dataset using traditional padding.
    """
    return tokenizer(
        texts,
        padding=True,
        truncation=True,  # without truncation, max_length is not enforced
        add_special_tokens=False,
        max_length=max_seq_length,
        return_tensors="pt"
    )
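
# For contrast (illustrative sketch): padding the same two toy sequences
# to a common length gives
#   input_ids      -> [[1, 2, 3], [4, 5, pad]]
#   attention_mask -> [[1, 1, 1], [1, 1, 0]]
# and every pad position still goes through the forward pass, which is
# exactly the cost that packing avoids.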

# %% Load model and tokenizer
model_name = "EleutherAI/pythia-70m"  # Small model for testing
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

# %% Load and prepare datasets
dataset = load_dataset("Open-Orca/FLAN", split="train", streaming=True).take(100)

dataset_padded = dataset.map(lambda x: prepare_padded_dataset(x["inputs"], tokenizer), batched=True, batch_size=10)

dataset_packed = dataset.map(lambda x: prepare_packed_dataset(x["inputs"], tokenizer), batched=True, batch_size=10, remove_columns=dataset.column_names)

# %% Test traditional padding approach
start_time = time.time()

with torch.no_grad():
    for batch in dataset_padded:
        # Iterating the mapped streaming dataset yields one row at a time
        # as plain lists, so rebuild tensors and add a batch dimension.
        input_ids = torch.tensor(batch["input_ids"]).unsqueeze(0).to(model.device)
        attention_mask = torch.tensor(batch["attention_mask"]).unsqueeze(0).to(model.device)
        padded_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
padding_time = time.time() - start_time

# %% Test sequence packing approach
start_time = time.time()

# Note: each packed row is fed to the model whole; chunking rows to fit
# the model's context window is not implemented here.
with torch.no_grad():
    for batch in dataset_packed:
        input_ids = torch.tensor(batch["input_ids"]).to(model.device)
        position_ids = torch.tensor(batch["position_ids"]).to(model.device)
        packed_output = model(
            input_ids=input_ids.unsqueeze(0),
            position_ids=position_ids.unsqueeze(0),
        )
packing_time = time.time() - start_time

# %%
print(f"Traditional padding processing time: {padding_time:.2f} seconds")
print(f"Sequence packing processing time: {packing_time:.2f} seconds")
print(f"Speedup: {padding_time/packing_time:.2f}x")
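
# %% Illustrative sketch: the speedup partly reflects how many positions
# each approach pushes through the model, since padded rows include pad
# tokens while packed rows do not.
padded_positions = sum(len(ex["input_ids"]) for ex in dataset_padded)
packed_positions = sum(len(ex["input_ids"]) for ex in dataset_packed)
print(f"Positions processed - padded: {padded_positions}, packed: {packed_positions}")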

# %%
from nnsight import LanguageModel

nnsight_model = LanguageModel(
    model_name,
    torch_dtype=torch.float16,
    device_map='cuda:0',
    attn_implementation='flash_attention_2'
)
# print(nnsight_model)

nnsight_model.tokenizer = tokenizer

# As above, rebuild tensors from the streamed row and add a batch dimension.
batch = next(iter(dataset_packed))
input_ids = torch.tensor(batch["input_ids"]).unsqueeze(0)
position_ids = torch.tensor(batch["position_ids"]).unsqueeze(0).to(model.device)

with nnsight_model.trace({"input_ids": input_ids}, position_ids=position_ids):
    logits = nnsight_model.embed_out.output.save()

print(logits.value)

# %%
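# Illustrative sketch, assuming the packed layout above: the saved
# embed_out output for a packed row has shape (1, total_packed_length,
# vocab_size), and per-sequence logits can be recovered by splitting
# wherever position_ids resets to 0.
boundaries = torch.nonzero(position_ids.squeeze(0) == 0).squeeze(-1)[1:].cpu()
logits_per_seq = torch.tensor_split(logits.value.squeeze(0), boundaries)
print([t.shape for t in logits_per_seq])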