Skip to content

Commit

Permalink
update prepare_inputs_for_generation
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Jun 7, 2024
1 parent 169ed37 commit c0300c3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1210,11 +1210,12 @@ def prepare_inputs_for_generation(

input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
if cache_info is None:
cache_info = torch.arange(past_length, past_length + input_length, device=input_ids.device)
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)
elif use_cache:
cache_info = cache_info[-input_length:]
cache_position = cache_info.position[-input_length:]
cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)

cache_info = CacheInfo(position=cache_info, length=int(cache_info[-1]) + 1)
model_inputs.update(
{
"position_ids": position_ids,
Expand Down

0 comments on commit c0300c3

Please sign in to comment.