[Major Change][Undecided yet] Move to FlashDecoding instead of PagedAttention kernel. (#1940)

* Using flash decoding

Conditional flash decoding.

Fix max_q.

Working kvcache

Working version with flash decoding.

Make it work for mistral.

Fix after rebase.

Less intrusive.

Revert changes in modeling.

Speed up flash decoding.

Hack to make other models work.

Fixing the non-flash-decoding llama path.

Router logic knows about page size.

Missing 2 models.

Missing cohere.

Fixing cohere flash decoding.

Revamped all this architecture.

Fix cohere.

Fixing falcon.

Enabling custom block size schedule.

Update router/src/infer.rs

Not sending preallocated output.

* Making it work on non-flash decoding.

* Fix Cohere.

* Fix non-flash-decoding paths.

* Rebased.

* No need for cache_manager anymore.

* Update?

* "ipex" -> "cpu"

* These do not belong.

* Factoring cu_seqlen_qk to better abstract over every model.

* Fixing non-flash tests/imports.

* Changing return everywhere.

* Update mistral past.

* Fixing Mi{s,x}tral (non-functional in Flash Decoding mode, though).

* Fix up Mistral clamping (had issues with CUDA graphs).

* No need to recreate anything actually.
Narsil authored and glegendre01 committed Jul 2, 2024
1 parent b01c689 commit 1d542af
Showing 24 changed files with 222 additions and 74 deletions.
9 changes: 8 additions & 1 deletion router/src/infer/v2/scheduler.rs
@@ -39,7 +39,14 @@ impl SchedulerV2 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
let queue = Queue::new(requires_padding, 16, window_size, speculate);
// Infer shared state
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(requires_padding, block_size, window_size, speculate);
let batching_task_notifier = Arc::new(Notify::new());

// Spawn batching background task that contains all the inference logic
8 changes: 7 additions & 1 deletion router/src/infer/v3/scheduler.rs
@@ -39,9 +39,15 @@ impl SchedulerV3 {
speculate: u32,
generation_health: Arc<AtomicBool>,
) -> Self {
let flashdecoding = if let Ok(flashdecoding) = std::env::var("FLASH_DECODING") {
matches!(flashdecoding.to_lowercase().as_str(), "1" | "true")
} else {
false
};
let block_size = if flashdecoding { 256 } else { 16 };
let queue = Queue::new(
requires_padding,
16,
block_size,
window_size,
speculate,
max_batch_total_tokens,
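Both schedulers above read the `FLASH_DECODING` environment variable and widen the queue's block size from 16 to 256 when it is enabled. The Python server imports matching `FLASH_DECODING` and `BLOCK_SIZE` globals from `text_generation_server.models.globals`, a file that is not shown in this diff. A minimal sketch of what that module might contain, assuming it simply mirrors the router logic:

```python
# Hypothetical sketch of text_generation_server/models/globals.py; the real
# module is not part of this diff and may differ.
import os

# Same convention as the Rust schedulers: "1" or "true" enables flash decoding.
FLASH_DECODING: bool = os.getenv("FLASH_DECODING", "").lower() in ("1", "true")

# Flash decoding operates on much larger KV-cache blocks than paged attention.
BLOCK_SIZE: int = 256 if FLASH_DECODING else 16
```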
2 changes: 2 additions & 0 deletions server/text_generation_server/layers/attention/__init__.py
@@ -1,6 +1,8 @@
from text_generation_server.utils.import_utils import SYSTEM
import os

from .common import Seqlen

if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
if SYSTEM == "cuda":
44 changes: 44 additions & 0 deletions server/text_generation_server/layers/attention/common.py
@@ -0,0 +1,44 @@
from dataclasses import dataclass
from text_generation_server.models.globals import FLASH_DECODING
import torch
from typing import Optional


if FLASH_DECODING:

@dataclass
class Seqlen:
input_lengths: torch.Tensor
cu_seqlen_q: Optional[torch.Tensor]
cu_seqlen_k: Optional[torch.Tensor]

def __init__(self, input_lengths):
self.input_lengths = input_lengths
device = self.input_lengths.device
shape = self.input_lengths.shape
cu_seqlen_q = torch.arange(
shape[0] + 1,
device=device,
dtype=torch.int32,
)
cu_seqlen_k = torch.zeros(shape[-1] + 1, device=device, dtype=torch.int32)
# cuda graphs don't like this and this is necessary to clamp within mistral
# Although FA2 might not want the clamping
# cu_seqlen_k[0] = 0
torch.cumsum(self.input_lengths, -1, out=cu_seqlen_k[1:])

self.cu_seqlen_q = cu_seqlen_q
self.cu_seqlen_k = cu_seqlen_k

def clamp(self, max):
# Flash decoding doesn't need to clamp
return self

else:

@dataclass
class Seqlen:
input_lengths: torch.Tensor

def clamp(self, max):
return Seqlen(torch.clamp(self.input_lengths, max=max))
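In flash-decoding mode, `Seqlen` precomputes the cumulative-length tensors that the varlen flash-attention kernel expects: `cu_seqlen_q` assumes exactly one query token per sequence during decode, and `cu_seqlen_k` is the running sum of cached key lengths. A small usage sketch, assuming `FLASH_DECODING` is enabled in the environment before import:

```python
import torch
from text_generation_server.layers.attention import Seqlen

input_lengths = torch.tensor([3, 5, 2], dtype=torch.int32)
seqlen = Seqlen(input_lengths)

# With FLASH_DECODING enabled:
#   seqlen.cu_seqlen_q -> tensor([0, 1, 2, 3])   one decode token per sequence
#   seqlen.cu_seqlen_k -> tensor([0, 3, 8, 10])  cumulative key lengths
# Without it, Seqlen is a thin wrapper and clamp() caps the lengths instead:
#   Seqlen(input_lengths).clamp(4).input_lengths -> tensor([3, 4, 2])
```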
137 changes: 92 additions & 45 deletions server/text_generation_server/layers/attention/cuda.py
@@ -1,5 +1,7 @@
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING, BLOCK_SIZE
from text_generation_server.layers.attention import Seqlen

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
@@ -21,7 +23,14 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)


def paged_attention(
@@ -32,7 +41,7 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -53,7 +62,8 @@ #
#

# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
# block_size = value_cache.shape[3]
block_size = BLOCK_SIZE
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE

@@ -62,58 +72,95 @@
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
from vllm._C import ops
if FLASH_DECODING:
max_q = 1
max_k = max_s
import flash_attn_2_cuda

use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
ops.paged_attention_v1(
out,
# TODO fixme when flash contains the fix.
# Number of splits is not correctly handled
# by the current path
# https://github.com/Dao-AILab/flash-attention/blob/320fb59487658f033f56711efd3d61b7c7a6f8f3/csrc/flash_attn/flash_api.cpp#L577
# This fails because we're using causal, therefore window_right is set to 0 and the split logic is never applied.
out2 = flash_attn_2_cuda.varlen_fwd(
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
None,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_k,
None,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
max_q,
max_k,
0.0, # dropout
softmax_scale,
False, # zero_tensors
True, # causal
-1, # Window_left
-1, # Window right
False, # return softmax
None, # generator
)
return out2[0]
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)
input_lengths = seqlen.input_lengths
from vllm._C import ops

ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
use_v1 = max_s <= 8192 and (
max_num_partitions == 1 or num_seqs * num_heads > 512
)
if use_v1:
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=out.dtype,
device=out.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=out.device,
)
max_logits = torch.empty_like(exp_sums)

ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
return out


try:
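The flash-decoding branch of `reshape_and_cache` skips the vLLM `cache_ops` kernel and writes new key/value vectors straight into their flat slot positions through a view. A toy, CPU-only illustration of that indexing; the dimensions are made up, and the sketch assumes the flash-decoding cache layout is `[num_blocks, block_size, num_heads, head_size]`, which is what the flattening implies:

```python
import torch

# Toy cache: 2 blocks of 4 slots, 1 KV head, head_size 2 (assumed layout).
num_blocks, block_size, num_heads, head_size = 2, 4, 1, 2
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size)

# Two new tokens land in flat slots 1 and 5 (block 0 offset 1, block 1 offset 1).
slots = torch.tensor([1, 5])
key = torch.randn(2, num_heads, head_size)

# Same indexing as the diff: flatten everything but the last two dims,
# then scatter the new vectors by slot index.
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key

assert torch.equal(key_cache[0, 1], key[0])
assert torch.equal(key_cache[1, 1], key[1])
```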
5 changes: 3 additions & 2 deletions server/text_generation_server/layers/attention/ipex.py
@@ -55,7 +55,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int,
):
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
@@ -66,7 +67,7 @@
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
cu_seqlen_q,
BLOCK_SIZE,
max_s,
None,
15 changes: 13 additions & 2 deletions server/text_generation_server/layers/attention/rocm.py
@@ -1,6 +1,7 @@
import os
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import FLASH_DECODING
from loguru import logger

major, minor = torch.cuda.get_device_capability()
@@ -26,7 +27,14 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)
if FLASH_DECODING:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)


def paged_attention(
@@ -37,7 +45,8 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: torch.Tensor,
cu_seqlen_q: torch.Tensor,
cu_seqlen_k: torch.Tensor,
max_s: int,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -61,6 +70,7 @@ def paged_attention(
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = cu_seqlen_k

# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
@@ -119,6 +129,7 @@ def paged_attention(
"auto",
1.0,
)
return out


if ENGINE != "triton":
3 changes: 2 additions & 1 deletion server/text_generation_server/models/__init__.py
@@ -12,7 +12,6 @@
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
from text_generation_server.models.seq2seq_lm import Seq2SeqLM
@@ -53,6 +52,7 @@
FLASH_ATTENTION = True

try:
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.flash_rw import FlashRWSharded
from text_generation_server.models.flash_gpt2 import FlashGPT2
from text_generation_server.models.flash_neox import FlashNeoXSharded
@@ -92,6 +92,7 @@
FLASH_ATTENTION = False

if FLASH_ATTENTION:
__all__.append(FlashCausalLM)
__all__.append(FlashGPT2)
__all__.append(FlashNeoXSharded)
__all__.append(FlashRWSharded)
@@ -30,6 +30,7 @@
attention,
reshape_and_cache,
)
from text_generation_server.models.globals import FLASH_DECODING
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.layers import (
TensorParallelRowLinear,
@@ -259,8 +260,8 @@ def forward(
cu_seqlen_prefill,
kv_cache,
block_tables,
slots,
input_lengths,
slots,
max_s,
):
qkv = self.query_key_value(hidden_states)
@@ -304,7 +305,7 @@ def forward(
)
# Decode
else:
paged_attention(
attn_output = paged_attention(
attn_output,
query,
kv_cache[0],
@@ -464,6 +465,7 @@ def forward(
)

residual = None

for i, layer in enumerate(self.layers):
hidden_states, residual = layer(
hidden_states,
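The decode path in the modeling code now assigns the result of `paged_attention` instead of relying on in-place mutation. That matters for flash decoding: the `varlen_fwd` branch in `cuda.py` allocates and returns a fresh output tensor (`out2[0]`) rather than writing into the preallocated `out`. A sketch of the updated call site; the first three arguments appear in the diff above, the rest are assumed from the `paged_attention` signature:

```python
# Decode: capture the return value, since the flash-decoding branch
# returns a new tensor instead of filling attn_output in place.
attn_output = paged_attention(
    attn_output,           # preallocated output, still used by the paged path
    query,
    kv_cache[0],           # key cache
    kv_cache[1],           # value cache (assumed)
    self.kv_head_mapping,  # assumed
    self.softmax_scale,    # assumed
    block_tables,
    input_lengths,         # a Seqlen in flash-decoding mode
    max_s,
)
```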
