From c93088f0015728ce677e8bf39cdbc16c6f19c946 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Tue, 14 May 2024 22:46:35 +0200
Subject: [PATCH] Model: Expose chunk forward fn and allow attn_params as input to forward pass

---
 exllamav2/model.py | 67 ++++++++++++++++++++++++----------------------
 1 file changed, 35 insertions(+), 32 deletions(-)

diff --git a/exllamav2/model.py b/exllamav2/model.py
index ffbcc93d..e2800ab1 100644
--- a/exllamav2/model.py
+++ b/exllamav2/model.py
@@ -683,16 +683,16 @@ def forward(self,

             assert q_len <= effective_max_input_len, "Maximum input length exceeded in model.forward"

-            result, last_state = self._forward(input_ids = input_ids,
-                                               cache = cache,
-                                               input_mask = input_mask,
-                                               preprocess_only = preprocess_only,
-                                               last_id_only = last_id_only,
-                                               loras = loras,
-                                               return_last_state = return_last_state,
-                                               position_offsets = position_offsets,
-                                               abort_event = abort_event,
-                                               **kwargs)
+            result, last_state = self.forward_chunk(input_ids = input_ids,
+                                                    cache = cache,
+                                                    input_mask = input_mask,
+                                                    preprocess_only = preprocess_only,
+                                                    last_id_only = last_id_only,
+                                                    loras = loras,
+                                                    return_last_state = return_last_state,
+                                                    position_offsets = position_offsets,
+                                                    abort_event = abort_event,
+                                                    **kwargs)

             if abort_event and abort_event.is_set(): return

@@ -744,16 +744,16 @@ def forward(self,
             _last_id_only = last_id_only
             _preprocess_only = preprocess_only or (chunk_end < q_len and last_id_only)

-            r, ls = self._forward(input_ids = input_ids[:, chunk_begin : chunk_end],
-                                  cache = cache,
-                                  input_mask = input_mask,
-                                  preprocess_only = _preprocess_only,
-                                  last_id_only = _last_id_only,
-                                  loras = loras,
-                                  return_last_state = return_last_state and remaining_q_len <= chunk_size,
-                                  position_offsets = position_offsets,
-                                  abort_event = abort_event,
-                                  **kwargs)
+            r, ls = self.forward_chunk(input_ids = input_ids[:, chunk_begin : chunk_end],
+                                       cache = cache,
+                                       input_mask = input_mask,
+                                       preprocess_only = _preprocess_only,
+                                       last_id_only = _last_id_only,
+                                       loras = loras,
+                                       return_last_state = return_last_state and remaining_q_len <= chunk_size,
+                                       position_offsets = position_offsets,
+                                       abort_event = abort_event,
+                                       **kwargs)

             if abort_event and abort_event.is_set(): return

@@ -772,17 +772,18 @@ def forward(self,


     @torch.inference_mode()
-    def _forward(self,
-                 input_ids: torch.Tensor,
-                 cache: ExLlamaV2CacheBase | list[ExLlamaV2CacheBase] | None = None,
-                 input_mask: torch.Tensor | None = None,
-                 preprocess_only: bool = False,
-                 last_id_only: bool = False,
-                 loras: list[ExLlamaV2Lora] | None = None,
-                 return_last_state: bool = False,
-                 position_offsets: torch.Tensor | None = None,
-                 abort_event: threading.Event | None = None,
-                 **kwargs) \
+    def forward_chunk(self,
+                      input_ids: torch.Tensor,
+                      cache: ExLlamaV2CacheBase | list[ExLlamaV2CacheBase] | None = None,
+                      input_mask: torch.Tensor | None = None,
+                      preprocess_only: bool = False,
+                      last_id_only: bool = False,
+                      loras: list[ExLlamaV2Lora] | None = None,
+                      return_last_state: bool = False,
+                      position_offsets: torch.Tensor | None = None,
+                      abort_event: threading.Event | None = None,
+                      attn_params: ExLlamaV2Attention.Params | None = None,
+                      **kwargs) \
         -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

         batch_size, seq_len = input_ids.shape
@@ -802,7 +803,9 @@ def _forward(self,
         # assert cache is None or isinstance(cache, list) or batch_size <= cache.batch_size

         x = input_ids
-        attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
+
+        if not attn_params:
+            attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None
         last_module = None
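
Usage note (illustrative, not part of the patch): with forward_chunk() now public and accepting an attn_params argument, a caller can build the ExLlamaV2Attention.Params object itself and pass it in, instead of letting forward_chunk construct it from batch_size, seq_len, past_len, input_mask and position_offsets. A minimal sketch follows; the model path, prompt, and loading boilerplate are placeholders, and the (logits, last_state) unpacking mirrors how forward() itself consumes the return value.

from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.attn import ExLlamaV2Attention

# Standard model setup; adjust model_dir to a real EXL2 model.
config = ExLlamaV2Config()
config.model_dir = "/path/to/exl2-model"
config.prepare()

model = ExLlamaV2(config)
model.load()
cache = ExLlamaV2Cache(model)
tokenizer = ExLlamaV2Tokenizer(config)

input_ids = tokenizer.encode("Once upon a time")
batch_size, seq_len = input_ids.shape

# Build the attention params explicitly; when attn_params is None,
# forward_chunk falls back to constructing this same object internally.
attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, cache.current_seq_len, None, None)

# forward_chunk processes the input in a single pass (no chunking), so
# seq_len must fit within the configured max_input_len.
logits, _ = model.forward_chunk(input_ids = input_ids,
                                cache = cache,
                                attn_params = attn_params)
print(logits.shape)  # expected: (batch_size, seq_len, vocab_size)

Because attn_params defaults to None, existing callers are unaffected: forward() still calls forward_chunk without it and gets the previous behaviour.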