Attn: Add paged mode for forward pass
turboderp committed May 14, 2024
1 parent c93088f commit affc350
Showing 1 changed file with 177 additions and 26 deletions.
203 changes: 177 additions & 26 deletions exllamav2/attn.py
@@ -24,17 +24,33 @@
# Detect flash-attn

has_flash_attn = False
has_flash_attn_with_paged = False

try:
import flash_attn
flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))

if flash_attn_ver >= [2, 2, 1] and is_ampere_or_newer_gpu:

if not is_ampere_or_newer_gpu:
print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.")

if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
from flash_attn import flash_attn_func
has_flash_attn = True

if [2, 5, 7] <= flash_attn_ver:
from flash_attn import flash_attn_func, flash_attn_with_kvcache
has_flash_attn = True
has_flash_attn_with_paged = True

except ModuleNotFoundError:
pass

def assert_paged_attn():
global has_flash_attn_with_paged
assert has_flash_attn_with_paged, \
"Paged attention required Flash Attention 2.5.7 or later"


class ExLlamaV2Attention(ExLlamaV2Module):

@@ -77,15 +93,23 @@ class Params:
attn_masks: torch.Tensor | None
position_offsets: torch.Tensor | None
past_lens_tensor: torch.Tensor | None

def __init__(self,
batch_size: int,
seq_len: int,
past_len: int | list[int],
input_mask: torch.Tensor,
position_offsets: torch.Tensor):
paged: bool

def __init__(
self,
batch_size: int,
seq_len: int | None = None,
past_len: int | list[int] | None = None,
input_mask: torch.Tensor | None = None,
position_offsets: torch.Tensor | None = None,
paged = False
):

self.batch_size = batch_size
self.paged = paged

if paged: return

self.seq_len = seq_len
if isinstance(past_len, list):
self.past_len = None
@@ -102,74 +126,58 @@

self.position_offsets = position_offsets
self.past_lens_tensor = None
self.paged = paged


def is_causal(self) -> bool:

return self.input_mask is None


def get_position_offsets(self, device) -> torch.Tensor | None:

assert self.position_offsets is not None
if self.position_offsets.device != device:
self.position_offsets = safe_move_tensor(self.position_offsets, device)
return self.position_offsets


def get_past_lens(self, device) -> torch.Tensor | None:

assert self.past_lens is not None
if self.past_lens_tensor is None:
self.past_lens_tensor = torch.tensor(self.past_lens, dtype = torch.int, device = device)
elif self.past_lens_tensor.device != device:
self.past_lens_tensor = safe_move_tensor(self.past_lens_tensor, device)
return self.past_lens_tensor


def get_attn_mask(self, device) -> torch.Tensor | None:

if self.attn_mask is None:
self.attn_mask = self.build_attn_mask(device)
elif self.attn_mask.device != device:
self.attn_mask = safe_move_tensor(self.attn_mask, device)
return self.attn_mask


def get_attn_masks(self, device) -> torch.Tensor | None:

if self.attn_masks is None:
self.attn_masks = self.build_attn_masks(device)
elif self.attn_masks[0] is not None and self.attn_masks[0].device != device:
self.attn_masks = [(safe_move_tensor(m, device) if m is not None else None) for m in self.attn_masks]
return self.attn_masks


def build_single_attn_mask(self, batch_size, seq_len, past_len, device, input_mask):

attn_mask = torch.zeros((batch_size, 1, seq_len, past_len + seq_len), dtype = torch.float16, device = device)
attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), -65504.0))
attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu

if input_mask is not None:
min_mask_width = min(input_mask.shape[-1], seq_len + past_len)
input_mask_part = safe_move_tensor(input_mask[:, :min_mask_width], attn_mask.device)
input_mask_part = input_mask_part.unsqueeze(1).unsqueeze(2)
attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part)

return attn_mask
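
As a standalone illustration of the layout produced by build_single_attn_mask (a sketch with made-up sizes, not part of the commit): each of the first seq_len - 1 query rows gets -65504 over the positions it must not attend to, while cached positions and the diagonal stay at zero.

import torch

seq_len, past_len = 3, 2
mask = torch.zeros((1, 1, seq_len, past_len + seq_len), dtype = torch.float16)
triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), -65504.0))
mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = triu

# mask[0, 0] is:
#   [[0, 0, 0, -65504, -65504],     row 0: 2 cached tokens + query 0 itself
#    [0, 0, 0,      0, -65504],     row 1: also query 1 itself
#    [0, 0, 0,      0,      0]]     row 2 (last query): attends to everything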


def build_attn_mask(self, device) -> torch.Tensor | None:
assert not self.multi_cache, "Building single mask for multiple caches"

if self.input_mask is None and self.seq_len == 1: return None
return self.build_single_attn_mask(self.batch_size, self.seq_len, self.past_len, device, self.input_mask)


def build_attn_masks(self, device) -> torch.Tensor | None:
assert self.multi_cache, "Building multiple masks for single cache"

attn_masks = []
for i, past_len in enumerate(self.past_lens):
if self.input_mask is None and self.seq_len == 1:
@@ -179,6 +187,38 @@ def build_attn_masks(self, device) -> torch.Tensor | None:
return attn_masks


class PagedParams(Params):

block_index: torch.Tensor

def __init__(
self,
batch_size: int,
block_index: torch.Tensor,
cache_seqlens: torch.Tensor
):
super().__init__(
batch_size = batch_size,
paged = True
)

self.block_index = block_index
self.cache_seqlens = cache_seqlens

def get_attn_mask(self, device):
raise NotImplementedError()

def get_block_index(self, device) -> torch.Tensor:
if self.block_index.device != device:
self.block_index = safe_move_tensor(self.block_index, device)
return self.block_index

def get_cache_seqlens(self, device) -> torch.Tensor:
if self.cache_seqlens.device != device:
self.cache_seqlens = safe_move_tensor(self.cache_seqlens, device)
return self.cache_seqlens
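
A hedged sketch of how a PagedParams instance might be constructed; the page size of 256 and the block assignments are assumptions for illustration, not values taken from this commit:

import torch

# Two sequences with 300 and 40 tokens already in the cache, assuming pages of
# 256 tokens; each row of block_index lists that sequence's cache pages, padded
# to the longest row (padding entries are never read).
cache_seqlens = torch.tensor([300, 40], dtype = torch.int32)
block_index = torch.tensor([
    [0, 1],   # ceil(300 / 256) = 2 pages
    [2, 2],   # ceil(40 / 256) = 1 page, second entry is padding
], dtype = torch.int32)

params = ExLlamaV2Attention.PagedParams(
    batch_size = 2,
    block_index = block_index,
    cache_seqlens = cache_seqlens
)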


def __init__(self,
model: ExLlamaV2,
key: str,
@@ -447,6 +487,108 @@ def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states


def forward_paged(self,
hidden_states: torch.Tensor,
cache: ExLlamaV2CacheBase | None = None,
attn_params: ExLlamaV2Attention.PagedParams | None = None,
loras: list[ExLlamaV2Lora] | None = None,
**kwargs) -> torch.Tensor:

cfg = self.model.config
constants = self.model.get_device_tensors(self.device_idx)

batch_size, q_len, _ = hidden_states.shape
q = torch.empty((batch_size, q_len, cfg.num_attention_heads, cfg.head_dim), device = hidden_states.device, dtype = torch.half)
k = torch.empty((batch_size, q_len, cfg.num_key_value_heads, cfg.head_dim), device = hidden_states.device, dtype = torch.half)
v = torch.empty((batch_size, q_len, cfg.num_key_value_heads, cfg.head_dim), device = hidden_states.device, dtype = torch.half)

# TODO: Support paged Q4 cache and maybe FP8?
k_cache, v_cache = cache.get_kv_state(self.layer_idx, batch_size, 0, 0)

if loras is None or self.temp_lora_size == 0:
pass_loras, pass_lora_temp = [], none_tensor
else:
pass_loras, pass_lora_temp = [id(x) for x in loras], torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)

ext_c.q_attn_forward_1(
self.q_handle,
hidden_states,
batch_size,
q_len,
0,
attn_params.get_cache_seqlens(self.device()),
q,
k,
v,
constants.sin,
constants.cos,
pass_loras,
pass_lora_temp
)

attn_output = flash_attn_with_kvcache(
q = q,
k = k,
v = v,
k_cache = k_cache,
v_cache = v_cache,
cache_seqlens = attn_params.get_cache_seqlens(self.device()),
block_table = attn_params.get_block_index(self.device()),
causal = True
)
attn_output = attn_output.view((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))

# Output projection

ext_c.q_attn_forward_2(
self.q_handle,
hidden_states,
attn_output,
batch_size,
q_len,
pass_loras,
pass_lora_temp
)

return hidden_states
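
For reference, a self-contained sketch of the flash_attn_with_kvcache call pattern used above, outside the module; the shapes, the page size of 256, and the CUDA device are assumptions of the sketch, while the keyword names match the call in forward_paged:

import torch
from flash_attn import flash_attn_with_kvcache

bsz, q_len, n_heads, n_kv_heads, head_dim = 2, 1, 8, 2, 128
num_pages, page_size = 8, 256

q = torch.randn(bsz, q_len, n_heads, head_dim, dtype = torch.half, device = "cuda")
k = torch.randn(bsz, q_len, n_kv_heads, head_dim, dtype = torch.half, device = "cuda")
v = torch.randn(bsz, q_len, n_kv_heads, head_dim, dtype = torch.half, device = "cuda")

# Paged KV cache: (num_pages, page_size, kv_heads, head_dim), addressed per
# sequence through block_table.
k_cache = torch.zeros(num_pages, page_size, n_kv_heads, head_dim, dtype = torch.half, device = "cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([300, 40], dtype = torch.int32, device = "cuda")
block_table = torch.tensor([[0, 1], [2, 2]], dtype = torch.int32, device = "cuda")

# The new k/v are appended into the paged cache at cache_seqlens, then causal
# attention runs over each sequence's pages.
out = flash_attn_with_kvcache(
    q = q, k_cache = k_cache, v_cache = v_cache, k = k, v = v,
    cache_seqlens = cache_seqlens, block_table = block_table, causal = True
)
# out: (2, 1, 8, 128), same layout as q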


def _attn_matmul(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):

q_states = q_states.transpose(1, 2)
k_states = k_states.transpose(1, 2)
v_states = v_states.transpose(1, 2)

k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
k_states = k_states.transpose(-1, -2)

attn_weights = torch.matmul(q_states, k_states)

attn_weights *= 1 / math.sqrt(cfg.head_dim)
attn_mask = attn_params.get_attn_mask(attn_weights.device)
if attn_mask is not None: attn_weights = attn_weights + attn_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16)

v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
attn_output = torch.matmul(attn_weights, v_states)

attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
return attn_output
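
The matmul path above is plain scaled dot-product attention with the K/V heads repeated for grouped-query attention; a small sanity-check sketch (shapes assumed, masking omitted) against torch's built-in SDPA:

import math
import torch
import torch.nn.functional as F

bsz, seq, n_heads, n_kv_heads, head_dim = 1, 5, 4, 2, 16
groups = n_heads // n_kv_heads  # num_key_value_groups

q = torch.randn(bsz, n_heads, seq, head_dim)
k = torch.randn(bsz, n_kv_heads, seq, head_dim)
v = torch.randn(bsz, n_kv_heads, seq, head_dim)

# Equivalent in effect to repeat_kv: each KV head serves its group of query heads
k_rep = k.repeat_interleave(groups, dim = 1)
v_rep = v.repeat_interleave(groups, dim = 1)

weights = (q @ k_rep.transpose(-1, -2)) / math.sqrt(head_dim)
manual = F.softmax(weights, dim = -1) @ v_rep

reference = F.scaled_dot_product_attention(q, k_rep, v_rep)
print(torch.allclose(manual, reference, atol = 1e-5))  # True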


def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):

attn_output = flash_attn_func(
q_states,
k_states,
v_states,
causal = True
)
attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
return attn_output


def forward(self,
hidden_states: torch.Tensor,
cache: ExLlamaV2CacheBase | None = None,
@@ -458,6 +600,15 @@ def forward(self,

global has_flash_attn

if isinstance(attn_params, ExLlamaV2Attention.PagedParams):
return self.forward_paged(
hidden_states,
cache,
attn_params,
loras = loras,
**kwargs
)

if self.q_handle is None or intermediates:
return self.forward_torch(
hidden_states,
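
Finally, a hedged end-to-end sketch of how the new dispatch is intended to be driven from the caller side; the cache object and the non-paged argument values are assumptions beyond what the diff shows:

def attn_forward(attn, hidden_states, cache, block_index, cache_seqlens, paged):
    # attn is an ExLlamaV2Attention module; cache is a paged-capable KV cache.
    if paged:
        params = ExLlamaV2Attention.PagedParams(
            batch_size = hidden_states.shape[0],
            block_index = block_index,
            cache_seqlens = cache_seqlens
        )
    else:
        params = ExLlamaV2Attention.Params(
            batch_size = hidden_states.shape[0],
            seq_len = hidden_states.shape[1],
            past_len = 0
        )
    # forward() routes PagedParams instances to forward_paged(); everything
    # else takes the existing flash-attn / matmul paths.
    return attn.forward(hidden_states, cache = cache, attn_params = params)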
