
Commit

Model: Expose chunk forward fn and allow attn_params as input to forward pass
turboderp committed May 14, 2024
1 parent c771d63 commit c93088f
Showing 1 changed file with 35 additions and 32 deletions.
67 changes: 35 additions & 32 deletions exllamav2/model.py
@@ -683,16 +683,16 @@ def forward(self,

         assert q_len <= effective_max_input_len, "Maximum input length exceeded in model.forward"

-        result, last_state = self._forward(input_ids = input_ids,
-                                           cache = cache,
-                                           input_mask = input_mask,
-                                           preprocess_only = preprocess_only,
-                                           last_id_only = last_id_only,
-                                           loras = loras,
-                                           return_last_state = return_last_state,
-                                           position_offsets = position_offsets,
-                                           abort_event = abort_event,
-                                           **kwargs)
+        result, last_state = self.forward_chunk(input_ids = input_ids,
+                                                cache = cache,
+                                                input_mask = input_mask,
+                                                preprocess_only = preprocess_only,
+                                                last_id_only = last_id_only,
+                                                loras = loras,
+                                                return_last_state = return_last_state,
+                                                position_offsets = position_offsets,
+                                                abort_event = abort_event,
+                                                **kwargs)

         if abort_event and abort_event.is_set(): return

@@ -744,16 +744,16 @@ def forward(self,
             _last_id_only = last_id_only
             _preprocess_only = preprocess_only or (chunk_end < q_len and last_id_only)

-            r, ls = self._forward(input_ids = input_ids[:, chunk_begin : chunk_end],
-                                  cache = cache,
-                                  input_mask = input_mask,
-                                  preprocess_only = _preprocess_only,
-                                  last_id_only = _last_id_only,
-                                  loras = loras,
-                                  return_last_state = return_last_state and remaining_q_len <= chunk_size,
-                                  position_offsets = position_offsets,
-                                  abort_event = abort_event,
-                                  **kwargs)
+            r, ls = self.forward_chunk(input_ids = input_ids[:, chunk_begin : chunk_end],
+                                       cache = cache,
+                                       input_mask = input_mask,
+                                       preprocess_only = _preprocess_only,
+                                       last_id_only = _last_id_only,
+                                       loras = loras,
+                                       return_last_state = return_last_state and remaining_q_len <= chunk_size,
+                                       position_offsets = position_offsets,
+                                       abort_event = abort_event,
+                                       **kwargs)

             if abort_event and abort_event.is_set(): return

@@ -772,17 +772,18 @@ def forward(self,


     @torch.inference_mode()
-    def _forward(self,
-                 input_ids: torch.Tensor,
-                 cache: ExLlamaV2CacheBase | list[ExLlamaV2CacheBase] | None = None,
-                 input_mask: torch.Tensor | None = None,
-                 preprocess_only: bool = False,
-                 last_id_only: bool = False,
-                 loras: list[ExLlamaV2Lora] | None = None,
-                 return_last_state: bool = False,
-                 position_offsets: torch.Tensor | None = None,
-                 abort_event: threading.Event | None = None,
-                 **kwargs) \
+    def forward_chunk(self,
+                      input_ids: torch.Tensor,
+                      cache: ExLlamaV2CacheBase | list[ExLlamaV2CacheBase] | None = None,
+                      input_mask: torch.Tensor | None = None,
+                      preprocess_only: bool = False,
+                      last_id_only: bool = False,
+                      loras: list[ExLlamaV2Lora] | None = None,
+                      return_last_state: bool = False,
+                      position_offsets: torch.Tensor | None = None,
+                      abort_event: threading.Event | None = None,
+                      attn_params: ExLlamaV2Attention.Params | None = None,
+                      **kwargs) \
         -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:

         batch_size, seq_len = input_ids.shape
@@ -802,7 +803,9 @@ def _forward(self,
         # assert cache is None or isinstance(cache, list) or batch_size <= cache.batch_size

         x = input_ids
-        attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
+
+        if not attn_params:
+            attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, input_mask, position_offsets)
         last_state = None
         last_module = None

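Usage note (illustrative sketch, not part of the commit): the snippet below drives the newly exposed forward_chunk with a caller-supplied attn_params object. The model/cache setup lines and the model directory are assumptions for the example; only forward_chunk's keyword arguments and the ExLlamaV2Attention.Params construction are taken from the diff above.

import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache
from exllamav2.attn import ExLlamaV2Attention

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"            # hypothetical model directory
config.prepare()
model = ExLlamaV2(config)
model.load()
cache = ExLlamaV2Cache(model)

input_ids = torch.randint(0, 1000, (1, 128))   # placeholder token ids, batch size 1
batch_size, seq_len = input_ids.shape
past_len = cache.current_seq_len

# Build the attention params up front instead of letting forward_chunk create
# the default ones, mirroring the fallback construction added in this commit.
attn_params = ExLlamaV2Attention.Params(batch_size, seq_len, past_len, None, None)

result, last_state = model.forward_chunk(input_ids = input_ids,
                                         cache = cache,
                                         attn_params = attn_params)

When attn_params is omitted, forward_chunk falls back to building the same default Params from input_mask and position_offsets, so existing callers of forward are unaffected.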
