diff --git a/exllamav2/attn.py b/exllamav2/attn.py
index 5a558454..f2550c98 100644
--- a/exllamav2/attn.py
+++ b/exllamav2/attn.py
@@ -24,17 +24,33 @@
 # Detect flash-attn
 
 has_flash_attn = False
+has_flash_attn_with_paged = False
+
 try:
     import flash_attn
     flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
     is_ampere_or_newer_gpu = any(torch.cuda.get_device_properties(i).major >= 8 for i in range(torch.cuda.device_count()))
-
-    if flash_attn_ver >= [2, 2, 1] and is_ampere_or_newer_gpu:
+
+    if not is_ampere_or_newer_gpu:
+        print(" ## Warning: Flash Attention is installed but unsupported GPUs were detected.")
+
+    if [2, 2, 1] <= flash_attn_ver < [2, 5, 7]:
         from flash_attn import flash_attn_func
         has_flash_attn = True
+
+    if [2, 5, 7] <= flash_attn_ver:
+        from flash_attn import flash_attn_func, flash_attn_with_kvcache
+        has_flash_attn = True
+        has_flash_attn_with_paged = True
+
 except ModuleNotFoundError:
     pass
 
+
+def assert_paged_attn():
+    global has_flash_attn_with_paged
+    assert has_flash_attn_with_paged, \
+        "Paged attention requires Flash Attention 2.5.7 or later"
+
 
 class ExLlamaV2Attention(ExLlamaV2Module):
 
@@ -77,15 +93,23 @@ class Params:
         attn_masks: torch.Tensor | None
         position_offsets: torch.Tensor | None
         past_lens_tensor: torch.Tensor | None
-
-        def __init__(self,
-                     batch_size: int,
-                     seq_len: int,
-                     past_len: int | list[int],
-                     input_mask: torch.Tensor,
-                     position_offsets: torch.Tensor):
+        paged: bool
+
+        def __init__(
+            self,
+            batch_size: int,
+            seq_len: int | None = None,
+            past_len: int | list[int] | None = None,
+            input_mask: torch.Tensor | None = None,
+            position_offsets: torch.Tensor | None = None,
+            paged = False
+        ):
 
             self.batch_size = batch_size
+            self.paged = paged
+
+            if paged: return
+
             self.seq_len = seq_len
             if isinstance(past_len, list):
                 self.past_len = None
@@ -102,23 +126,19 @@ def __init__(self,
             self.position_offsets = position_offsets
             self.past_lens_tensor = None
+            self.paged = paged
 
         def is_causal(self) -> bool:
-
             return self.input_mask is None
 
-
         def get_position_offsets(self, device) -> torch.Tensor | None:
-
             assert self.position_offsets is not None
             if self.position_offsets.device != device:
                 self.position_offsets = safe_move_tensor(self.position_offsets, device)
             return self.position_offsets
 
-
         def get_past_lens(self, device) -> torch.Tensor | None:
-
             assert self.past_lens is not None
             if self.past_lens_tensor is None:
                 self.past_lens_tensor = torch.tensor(self.past_lens, dtype = torch.int, device = device)
@@ -126,50 +146,38 @@ def get_past_lens(self, device) -> torch.Tensor | None:
                 self.past_lens_tensor = safe_move_tensor(self.past_lens_tensor, device)
             return self.past_lens_tensor
 
-
         def get_attn_mask(self, device) -> torch.Tensor | None:
-
             if self.attn_mask is None:
                 self.attn_mask = self.build_attn_mask(device)
             elif self.attn_mask.device != device:
                 self.attn_mask = safe_move_tensor(self.attn_mask, device)
             return self.attn_mask
 
-
         def get_attn_masks(self, device) -> torch.Tensor | None:
-
             if self.attn_masks is None:
                 self.attn_masks = self.build_attn_masks(device)
             elif self.attn_masks[0] is not None and self.attn_masks[0].device != device:
                 self.attn_masks = [(safe_move_tensor(m, device) if m is not None else None) for m in self.attn_masks]
             return self.attn_masks
 
-
         def build_single_attn_mask(self, batch_size, seq_len, past_len, device, input_mask):
-
             attn_mask = torch.zeros((batch_size, 1, seq_len, past_len + seq_len), dtype = torch.float16, device = device)
             attn_mask_triu = torch.triu(torch.full((seq_len - 1, seq_len - 1), -65504.0))
             attn_mask[:, :, : seq_len - 1, past_len + 1: past_len + seq_len] = attn_mask_triu
-
             if input_mask is not None:
                 min_mask_width = min(input_mask.shape[-1], seq_len + past_len)
                 input_mask_part = safe_move_tensor(input_mask[:, :min_mask_width], attn_mask.device)
                 input_mask_part = input_mask_part.unsqueeze(1).unsqueeze(2)
                 attn_mask[:, :, :, :min_mask_width] = torch.minimum(attn_mask[:, :, :, :min_mask_width], input_mask_part)
-
             return attn_mask
 
-
         def build_attn_mask(self, device) -> torch.Tensor | None:
             assert not self.multi_cache, "Building single mask for multiple caches"
-
             if self.input_mask is None and self.seq_len == 1: return None
             return self.build_single_attn_mask(self.batch_size, self.seq_len, self.past_len, device, self.input_mask)
 
-
         def build_attn_masks(self, device) -> torch.Tensor | None:
             assert self.multi_cache, "Building multiple masks for single cache"
-
             attn_masks = []
             for i, past_len in enumerate(self.past_lens):
                 if self.input_mask is None and self.seq_len == 1:
@@ -179,6 +187,38 @@ def build_attn_masks(self, device) -> torch.Tensor | None:
             return attn_masks
 
 
+    class PagedParams(Params):
+
+        block_index: torch.Tensor
+
+        def __init__(
+            self,
+            batch_size: int,
+            block_index: torch.Tensor,
+            cache_seqlens: torch.Tensor
+        ):
+            super().__init__(
+                batch_size = batch_size,
+                paged = True
+            )
+
+            self.block_index = block_index
+            self.cache_seqlens = cache_seqlens
+
+        def get_attn_mask(self, device):
+            raise NotImplementedError()
+
+        def get_block_index(self, device) -> torch.Tensor:
+            if self.block_index.device != device:
+                self.block_index = safe_move_tensor(self.block_index, device)
+            return self.block_index
+
+        def get_cache_seqlens(self, device) -> torch.Tensor:
+            if self.cache_seqlens.device != device:
+                self.cache_seqlens = safe_move_tensor(self.cache_seqlens, device)
+            return self.cache_seqlens
+
+
     def __init__(self,
                  model: ExLlamaV2,
                  key: str,
@@ -447,6 +487,108 @@ def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
         return hidden_states
 
 
+    def forward_paged(self,
+                      hidden_states: torch.Tensor,
+                      cache: ExLlamaV2CacheBase | None = None,
+                      attn_params: ExLlamaV2Attention.PagedParams | None = None,
+                      loras: list[ExLlamaV2Lora] | None = None,
+                      **kwargs) -> torch.Tensor:
+
+        cfg = self.model.config
+        constants = self.model.get_device_tensors(self.device_idx)
+
+        batch_size, q_len, _ = hidden_states.shape
+        q = torch.empty((batch_size, q_len, cfg.num_attention_heads, cfg.head_dim), device = hidden_states.device, dtype = torch.half)
+        k = torch.empty((batch_size, q_len, cfg.num_key_value_heads, cfg.head_dim), device = hidden_states.device, dtype = torch.half)
+        v = torch.empty((batch_size, q_len, cfg.num_key_value_heads, cfg.head_dim), device = hidden_states.device, dtype = torch.half)
+
+        # TODO: Support paged Q4 cache and maybe FP8?
+        k_cache, v_cache = cache.get_kv_state(self.layer_idx, batch_size, 0, 0)
+
+        if loras is None or self.temp_lora_size == 0:
+            pass_loras, pass_lora_temp = [], none_tensor
+        else:
+            pass_loras, pass_lora_temp = [id(x) for x in loras], torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)
+
+        ext_c.q_attn_forward_1(
+            self.q_handle,
+            hidden_states,
+            batch_size,
+            q_len,
+            0,
+            attn_params.get_cache_seqlens(self.device()),
+            q,
+            k,
+            v,
+            constants.sin,
+            constants.cos,
+            pass_loras,
+            pass_lora_temp
+        )
+
+        attn_output = flash_attn_with_kvcache(
+            q = q,
+            k = k,
+            v = v,
+            k_cache = k_cache,
+            v_cache = v_cache,
+            cache_seqlens = attn_params.get_cache_seqlens(self.device()),
+            block_table = attn_params.get_block_index(self.device()),
+            causal = True
+        )
+        attn_output = attn_output.view((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
+
+        # Output projection
+
+        ext_c.q_attn_forward_2(
+            self.q_handle,
+            hidden_states,
+            attn_output,
+            batch_size,
+            q_len,
+            pass_loras,
+            pass_lora_temp
+        )
+
+        return hidden_states
+
+
+    def _attn_matmul(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
+
+        q_states = q_states.transpose(1, 2)
+        k_states = k_states.transpose(1, 2)
+        v_states = v_states.transpose(1, 2)
+
+        k_states = self.repeat_kv(k_states, cfg.num_key_value_groups)
+        k_states = k_states.transpose(-1, -2)
+
+        attn_weights = torch.matmul(q_states, k_states)
+
+        attn_weights *= 1 / math.sqrt(cfg.head_dim)
+        attn_mask = attn_params.get_attn_mask(attn_weights.device)
+        if attn_mask is not None: attn_weights = attn_weights + attn_mask
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float16)
+
+        v_states = self.repeat_kv(v_states, cfg.num_key_value_groups)
+        attn_output = torch.matmul(attn_weights, v_states)
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
+        return attn_output
+
+
+    def _attn_flash(self, batch_size, q_len, q_states, k_states, v_states, attn_params, cfg):
+
+        attn_output = flash_attn_func(
+            q_states,
+            k_states,
+            v_states,
+            causal = True
+        )
+        attn_output = attn_output.reshape((batch_size, q_len, cfg.num_attention_heads * cfg.head_dim))
+        return attn_output
+
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 cache: ExLlamaV2CacheBase | None = None,
@@ -458,6 +600,15 @@ def forward(self,
 
         global has_flash_attn
 
+        if isinstance(attn_params, ExLlamaV2Attention.PagedParams):
+            return self.forward_paged(
+                hidden_states,
+                cache,
+                attn_params,
+                loras = loras,
+                **kwargs
+            )
+
         if self.q_handle is None or intermediates:
             return self.forward_torch(
                 hidden_states,
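
Reviewer note (sketch, not part of the patch): the new forward_paged() path reduces to a single flash_attn_with_kvcache() call over a paged KV cache addressed through a block table, with PagedParams carrying the per-batch block_index and cache_seqlens tensors. The snippet below exercises that call in isolation, assuming flash-attn >= 2.5.7 on an Ampere-or-newer GPU; the tensor shapes, page size and sequence lengths are illustrative values and are not taken from the patch. In the patch itself, k_cache/v_cache come from the cache module's get_kv_state() and block_index/cache_seqlens from a PagedParams instance built by the caller.

import torch
from flash_attn import flash_attn_with_kvcache

batch_size, q_len, n_heads, n_kv_heads, head_dim = 2, 1, 32, 8, 128
page_size = 256                  # paged KV block size; early paged flash-attn releases require a multiple of 256
pages_per_seq = 4
num_pages = batch_size * pages_per_seq

# New queries/keys/values for this forward pass (here: one token per sequence)
q = torch.randn(batch_size, q_len, n_heads, head_dim, dtype = torch.half, device = "cuda")
k = torch.randn(batch_size, q_len, n_kv_heads, head_dim, dtype = torch.half, device = "cuda")
v = torch.randn(batch_size, q_len, n_kv_heads, head_dim, dtype = torch.half, device = "cuda")

# Paged KV cache laid out as (num_pages, page_size, n_kv_heads, head_dim)
k_cache = torch.zeros(num_pages, page_size, n_kv_heads, head_dim, dtype = torch.half, device = "cuda")
v_cache = torch.zeros_like(k_cache)

# block_index maps each sequence to the pages it owns; cache_seqlens holds each
# sequence's current length, i.e. the position at which the new k/v are appended
block_index = torch.arange(num_pages, dtype = torch.int32, device = "cuda").view(batch_size, pages_per_seq)
cache_seqlens = torch.tensor([17, 90], dtype = torch.int32, device = "cuda")

# Mirrors the call in forward_paged(): k/v are written into the paged cache at
# cache_seqlens, then causal attention runs over the pages listed in block_table
attn_output = flash_attn_with_kvcache(
    q = q,
    k = k,
    v = v,
    k_cache = k_cache,
    v_cache = v_cache,
    cache_seqlens = cache_seqlens,
    block_table = block_index,
    causal = True
)
print(attn_output.shape)  # torch.Size([2, 1, 32, 128])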