Add support for Flash attention #725

Merged 5 commits on Dec 10, 2022
5 changes: 5 additions & 0 deletions README.md
@@ -83,6 +83,11 @@ from the repository root.
</aside>


### Flash Attention

To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
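
For example, a 12-layer model could use flash attention in every layer by setting the `attention_config` option (e.g. `[[["flash"], 12]]`); the full format of this option is described in [configs/neox_arguments.md](./configs/neox_arguments.md).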


### Containerized Setup

We also provide a Dockerfile if you prefer to run NeoX in a container. To use this option, first build an image named `gpt-neox` from the repository root directory with `docker build -t gpt-neox -f Dockerfile .`. We also host pre-built images on Docker Hub at `leogao2/gpt-neox`.
4 changes: 3 additions & 1 deletion configs/neox_arguments.md
@@ -334,7 +334,7 @@ Model Arguments
The first item in the list specifies the attention type(s), and should be a list of strings. The second item
specifies the number of times to repeat those attention types in the full list.

attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird]
attention type choices: [global, local, sparse_fixed, sparse_variable, bigbird, bslongformer, gmlp, amlp, flash]

So a 12 layer network with only global attention could be specified like:
[[[`global`], 12]]
@@ -344,6 +344,8 @@ Model Arguments

If none is specified, this defaults to
[[[`global`], n_layers]]

"flash" attention refers to optimized global attention for Ampere (and some other) generation GPUs described here [Flash-Attention](https://github.com/HazyResearch/flash-attention).



86 changes: 86 additions & 0 deletions megatron/model/flash_attention.py
@@ -0,0 +1,86 @@
# Based on: https://github.com/HazyResearch/flash-attention/blob/4a6eaa9f27df6fff7ffb2c24e894938a687dd870/flash_attn/flash_attn_interface.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import flash_attn_cuda


def _flash_attn_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
dropout_p, softmax_scale, causal, return_softmax, num_splits=0,
generator=None):
"""
num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means
it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking.
Don't change it unless you know what you're doing.
"""
softmax_lse, *rest = flash_attn_cuda.fwd(
q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
softmax_scale, False, causal, return_softmax, num_splits, generator
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
S_dmask = rest[0] if return_softmax else None
return out, softmax_lse, S_dmask


def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, num_splits=0,
generator=None):
"""
num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or
not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic.
Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel
as num_splits=3), so effectively the choices are 0, 1, and 2.
This hyperparameter can be tuned for performance, but default value (heuristic) should work fine.
"""
_, _, _, softmax_d = flash_attn_cuda.bwd(
dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, num_splits, generator)
# if dq.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d


class FlashAttnQKVPackedFunc(torch.autograd.Function):

@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
# Save rng_state because the backward pass will regenerate the dropout mask
rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, softmax_lse, S_dmask = _flash_attn_forward(
qkv[:, 0], qkv[:, 1], qkv[:, 2], torch.empty_like(qkv[:, 0]), cu_seqlens, cu_seqlens,
max_seqlen, max_seqlen, dropout_p, softmax_scale, causal=causal,
return_softmax=return_softmax
)
ctx.save_for_backward(qkv, out, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, dout, *args):
qkv, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
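# Temporarily swap in the forward-pass RNG state so dropout regenerates the
# same mask; the current state is restored after the backward kernel runs.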
if rng_state is not None:
cur_rng_state = torch.cuda.get_rng_state()
torch.cuda.set_rng_state(rng_state)
dqkv = torch.empty_like(qkv)
_flash_attn_backward(
dout, qkv[:, 0], qkv[:, 1], qkv[:, 2], out, softmax_lse,
dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], cu_seqlens, cu_seqlens,
ctx.max_seqlen, ctx.max_seqlen, ctx.dropout_p, ctx.softmax_scale, ctx.causal
)
if rng_state is not None:
torch.cuda.set_rng_state(cur_rng_state)
return dqkv, None, None, None, None, None, None


def flash_attn_unpadded_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None,
causal=False, return_attn_probs=False):
return FlashAttnQKVPackedFunc.apply(qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale,
causal, return_attn_probs)
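
A minimal usage sketch for the packed-QKV wrapper above (an illustration, assuming the `flash-attn` CUDA extension from `requirements-flashattention.txt` is installed and an fp16-capable GPU is available; shapes mirror those built by `flash_attention` in `megatron/model/transformer.py` below):

```python
import torch

from megatron.model.flash_attention import flash_attn_unpadded_qkvpacked_func

batch, seqlen, num_heads, head_dim = 2, 128, 8, 64

# Packed q/k/v: [batch * seqlen, 3, num_heads, head_dim], fp16 on the GPU.
qkv = torch.randn(
    batch * seqlen, 3, num_heads, head_dim, dtype=torch.float16, device="cuda"
)

# Cumulative sequence boundaries: [0, seqlen, 2*seqlen, ..., batch*seqlen].
cu_seqlens = torch.arange(
    0, (batch + 1) * seqlen, step=seqlen, dtype=torch.int32, device="cuda"
)

# Causal self-attention over each packed sequence; dropout disabled here.
out = flash_attn_unpadded_qkvpacked_func(
    qkv, cu_seqlens, seqlen, dropout_p=0.0, causal=True
)
# out: [batch * seqlen, num_heads, head_dim]
```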
63 changes: 53 additions & 10 deletions megatron/model/transformer.py
@@ -259,6 +259,7 @@ def __init__(
self.rotary_emb = None

self.attention_type = neox_args.attention_config[layer_number]
self.use_flash_attention = self.attention_type == "flash"
self.sparse = self.attention_type != "global"
if self.sparse:
self.sparse_attn = configure_sparse_attention(
@@ -268,19 +269,26 @@
mpu=mpu,
)
else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
input_in_bf16=self.bf16,
fusion_type=get_fusion_type(neox_args),
mask_func=self.attention_mask_func,
softmax_in_fp32=self.attention_softmax_in_fp32,
scale=coeff,
)
if self.use_flash_attention:
from megatron.model.flash_attention import flash_attn_unpadded_qkvpacked_func
self.flash_attention_function = flash_attn_unpadded_qkvpacked_func
if self.pos_emb == "alibi":
raise ValueError('Flash attention is currently not compatible with ALiBi positional embeddings. Use sinusoidal, learned, or rotary embeddings instead.')
else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
input_in_bf16=self.bf16,
fusion_type=get_fusion_type(neox_args),
mask_func=self.attention_mask_func,
softmax_in_fp32=self.attention_softmax_in_fp32,
scale=coeff,
)

# Dropout. Note that for a single iteration, this layer will generate
# different outputs for different numbers of parallel partitions, but
# on average it should not be partition dependent.
self.attention_dropout = nn.Dropout(neox_args.attention_dropout)
self.dropout_p = neox_args.attention_dropout
self.attention_dropout = nn.Dropout(self.dropout_p)

# Output.
self.dense = mpu.RowParallelLinear(
@@ -396,6 +404,37 @@ def attention(
context_layer = context_layer.view(*output_size)
return context_layer

def flash_attention(self, query_layer, key_layer, value_layer):
# [b, np, sq, sk]
output_size = (
query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0),
)
# [s, b, np, hn] -> [b, s, np, hn] -> [b * s, 1, np, hn]
query_layer = query_layer.transpose(0, 1).reshape(output_size[0] * output_size[2], 1, output_size[1], -1)
key_layer = key_layer.transpose(0, 1).reshape(output_size[0] * output_size[3], 1, output_size[1], -1)
value_layer = value_layer.transpose(0, 1).reshape(output_size[0] * output_size[3], 1, output_size[1], -1)

# Combined q/k/v into [b * s, 3, np, hn].
qkv = torch.concat([query_layer, key_layer, value_layer], dim=1)

batch_size = output_size[0]
seqlen = output_size[2]
max_s = seqlen
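# cu_seqlens holds the cumulative start offset of each sequence in the packed
# [b * s, ...] layout: [0, s, 2s, ..., b * s].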
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device)
output = self.flash_attention_function(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=None, causal=True
)
# [b * sq, np, hn] -> [b, sq, np, hn]
matmul_result = output.view(output_size[0], output_size[2], output.shape[1], output.shape[2])
# [b, sq, np, hn] -> [b, np, sq, hn]
matmul_result = matmul_result.transpose(1, 2)

return matmul_result

def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
# TODO: sparse attn dropout?
# TODO: pad to block size
@@ -483,7 +522,11 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
if self.use_cache:
present = torch.stack((key_layer, value_layer))

if not self.sparse:
if self.use_flash_attention:
context_layer = self.flash_attention(
query_layer, key_layer, value_layer
)
elif not self.sparse:
context_layer = self.attention(
query_layer, key_layer, value_layer, layer_past, attention_mask
)
1 change: 1 addition & 0 deletions megatron/neox_arguments/neox_args.py
@@ -20,6 +20,7 @@
"bslongformer",
"gmlp",
"amlp",
"flash",
]


1 change: 1 addition & 0 deletions requirements/requirements-flashattention.txt
@@ -0,0 +1 @@
flash-attn==0.2.2