Add option to replace attention with flash attention #25

Closed
ryanchesler opened this issue Apr 8, 2023 · 7 comments

Comments

@ryanchesler
Contributor

Flash attention has already been integrated into gpt-neox models here: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py#L215

We can add the swapped model definition as an option to the training and generation scripts and then benchmark the speed difference.
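One way to run that comparison is to time each variant on identical inputs and compare medians. Below is a minimal, framework-agnostic sketch. The helper names `benchmark` and `compare` are hypothetical and do not come from any of the linked repos. When timing GPU code, each callable should call `torch.cuda.synchronize()` before returning, because CUDA kernel launches are asynchronous and the timer would otherwise stop before the work finishes.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 3, iters: int = 10) -> float:
    """Return the median wall-clock time (seconds) of fn() over `iters` timed runs."""
    for _ in range(warmup):
        fn()  # untimed warm-up runs (caches, lazy init, kernel compilation)
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return statistics.median(times)


def compare(variants: Dict[str, Callable[[], object]], **kwargs) -> Dict[str, float]:
    """Time each variant; report speedup relative to the first entry (the baseline)."""
    results = {name: benchmark(fn, **kwargs) for name, fn in variants.items()}
    baseline = next(iter(results.values()))
    for name, t in results.items():
        speedup = baseline / t if t > 0 else float("inf")
        print(f"{name:>12}: {t * 1e3:8.3f} ms  ({speedup:.2f}x vs baseline)")
    return results
```

For example, `compare({"vanilla": run_vanilla, "flash": run_flash})`, where `run_vanilla` and `run_flash` are hypothetical closures that each run one training or generation step of the respective model on the same batch. The median is used instead of the mean so that one slow iteration (e.g. a GC pause) does not skew the result.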

Converting LLaMA and other models might be more work. It uses a pretty standard-looking attention, but I'm not sure how it differs from the PyTorch default. We might just need to remap some layer names: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L160
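If the attention really is standard, the weight conversion may mostly amount to renaming state-dict keys. Below is a minimal sketch using ordered regex rename rules. The source keys follow the Hugging Face LLaMA naming (`model.layers.N.self_attn.q_proj.weight`, ...). The target names are placeholders that would have to be matched to whatever the flash-attention model definition actually expects. If the target uses a single fused QKV projection, the q/k/v tensors would also need to be concatenated, which renaming alone does not cover.

```python
import re
from typing import Dict, List, Tuple

# Hypothetical rename rules: (regex on the source key, replacement template).
# Source side: Hugging Face LLaMA key layout. Target side: placeholder names only,
# not the parameter names of any real flash-attention model.
RULES: List[Tuple[str, str]] = [
    (r"^model\.layers\.(\d+)\.self_attn\.o_proj\.", r"layers.\1.attn.out_proj."),
    (r"^model\.layers\.(\d+)\.self_attn\.([qkv])_proj\.", r"layers.\1.attn.\2_proj."),
    (r"^model\.", ""),  # strip the remaining top-level prefix
]


def remap_keys(state_dict: Dict[str, object],
               rules: List[Tuple[str, str]] = RULES) -> Dict[str, object]:
    """Rename checkpoint keys; the first matching rule wins, unmatched keys are kept."""
    out: Dict[str, object] = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, repl in rules:
            if re.search(pattern, key):
                new_key = re.sub(pattern, repl, key)
                break
        if new_key in out:
            # Two source tensors collapsing onto one name means the rules are wrong.
            raise KeyError(f"two source keys map to {new_key!r}")
        out[new_key] = value
    return out
```

Rules are tried in order, so specific patterns (like `o_proj`) must come before general ones (like the bare `model.` prefix), and the duplicate check catches rules that accidentally map two tensors onto the same name.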

@pseudotensor
Collaborator

Some old notes from Slack last week:

Right now, nothing in torch/Hugging Face can be used directly to do flash attention. One would need to swap out the attention layer, which is possible, since that is what the gpt-neox repo does. I'll have to look at this approach more carefully to see how to do it, similar to how the other Vicuna repo does it for LLaMA.
An alternative is to use the gpt-neox repo directly with its training code, which is probably fine. I installed all of its dependencies and nothing had issues:

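# Assumes gpt-neox has already been cloned into the current directory (clone step not shown).
source ~/.bashrc.mamba
mamba create -n gptneox
conda activate gptneox
mamba install python=3.8 -y
mamba install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia -y
cd gpt-neox/
pip install -r requirements/requirements.txt
# CUDA toolkit with nvcc and headers, needed to compile the fused kernels
mamba install cudatoolkit-dev=11.7 cudatoolkit=11.7 -c conda-forge -c nvidia -y
# clear any system CUDA_HOME so the build looks up nvcc on PATH (the conda toolkit)
unset CUDA_HOME
python ./megatron/fused_kernels/setup.py install
pip install -r ./requirements/requirements-flashattention.txt
cd ..
# DeeperSpeed: EleutherAI's fork of DeepSpeed used by gpt-neox
git clone https://github.com/EleutherAI/DeeperSpeed.git
cd DeeperSpeed
./install.sh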

CUDA 11.7 is required.

@pseudotensor
Collaborator

There is a WIP port of NeoX with flash attention in Hugging Face transformers, but there has been no work on it for the last 3 months, so it is probably dead: https://github.com/conceptofmind/flash-gpt

@pseudotensor
Collaborator

pseudotensor commented Apr 11, 2023

Related announcement from Amazon: https://aws.amazon.com/blogs/machine-learning/new-performance-improvements-in-amazon-sagemaker-model-parallel-library/

To help our customers further minimize training costs and accelerate time-to-market, we are thrilled to introduce two new performance improvements in SageMaker model parallel — SMDDP Collectives and FlashAttention. SMDDP Collectives is the most performant collective library on AWS infrastructure for large model training offered by SageMaker distributed data parallel library. FlashAttention is introduced in Dao et al., which re-implements the attention mechanism in an IO-aware manner, reducing the memory bandwidth requirement and saving on attention speed and memory footprint. These two components collectively push our sharded data parallel technique to be 30.58% faster when training a 100B parameter GPT-NeoX model on 32 p4d.24xlarge instances. For customers who are already using sharded data parallel on supported models, no code changes are necessary to benefit from the performance boost offered by these latest features.

So maybe we should use SageMaker. I think I noticed this somewhere else before.
But I'm unsure how compatible it is with other weights, e.g. Hugging Face.

100B parameter GPT-NeoX model on 32 p4d.24xlarge instances
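To make the "IO-aware" part of the quote concrete: the core trick in FlashAttention (Dao et al.) is to walk over K/V in blocks while keeping a running max and a running softmax denominator (an online softmax), so the full N×N score matrix never has to be stored. The pure-Python sketch below shows only that recurrence, one query row at a time, and checks it against a naive implementation. It does not model the real kernel's SRAM tiling, fused CUDA code, or backward pass, and the function names are illustrative.

```python
import math
from typing import List

Matrix = List[List[float]]


def naive_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference: softmax(q k^T / sqrt(d)) v, materializing each full score row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


def tiled_attention(q: Matrix, k: Matrix, v: Matrix, block: int = 2) -> Matrix:
    """Same result, but K/V are consumed in blocks with an online softmax."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for qi in q:
        m = -math.inf     # running max of the scores seen so far
        z = 0.0           # running softmax denominator, relative to m
        acc = [0.0] * dv  # running unnormalized output, relative to m
        for start in range(0, len(k), block):
            kb, vb = k[start:start + block], v[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in kb]
            m_new = max(m, max(scores))
            correction = math.exp(m - m_new)  # rescale earlier partial sums to the new max
            z *= correction
            acc = [a * correction for a in acc]
            for s, vj in zip(scores, vb):
                p = math.exp(s - m_new)
                z += p
                acc = [a + p * x for a, x in zip(acc, vj)]
            m = m_new
        out.append([a / z for a in acc])
    return out
```

Any `block` size gives the same output up to floating-point rounding; the rescaling by `exp(m - m_new)` is what keeps the partial sums from earlier blocks consistent with the updated running max.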

@pseudotensor
Collaborator

You can use the same install as above and then make LLaMA use flash attention with the wrappers/patches from the Vicuna model:
https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py
https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train_mem.py#L5
So we can already do that for the LLaMA case if we are interested.
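Swapping the layer without editing the model definition is typically done by monkey patching: reassigning the attention class's `forward` before training starts. The sketch below shows only that mechanism with stand-in names. `Attention`, `flash_forward`, and `replace_attn_with_flash` are hypothetical and are not FastChat or transformers code. A real replacement must keep the original `forward` signature and return values.

```python
class Attention:
    """Stand-in for a library attention module (hypothetical)."""

    def forward(self, x):
        return f"vanilla({x})"


def flash_forward(self, x):
    # Same signature as the original; a real patch would call a fused kernel here.
    return f"flash({x})"


def replace_attn_with_flash() -> None:
    # Rebind the method on the class, so lookups through the class, including
    # those on already-created instances, now resolve to the new function.
    Attention.forward = flash_forward


layer = Attention()
print(layer.forward("h"))  # prints "vanilla(h)"
replace_attn_with_flash()
print(layer.forward("h"))  # prints "flash(h)"
```

Because the patch replaces the class attribute, it also affects instances created before the call, but not code that already holds a direct reference to the original function; applying the patch as early as possible in the training script avoids that case.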

@pseudotensor
Collaborator

EleutherAI/gpt-neox#725

@arnocandel
Member

arnocandel commented Apr 11, 2023

@arnocandel
Member

#128


3 participants