diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 5990fb1f6b7d73..e270818426752f 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -1210,11 +1210,12 @@ def prepare_inputs_for_generation(
         input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
         if cache_info is None:
-            cache_info = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+            cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)
         elif use_cache:
-            cache_info = cache_info[-input_length:]
+            cache_position = cache_info.position[-input_length:]
+            cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)
-        cache_info = CacheInfo(position=cache_info, length=int(cache_info[-1]) + 1)

         model_inputs.update(
             {
                 "position_ids": position_ids,
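
The hunk moves the `CacheInfo` construction into the two branches, so `cache_info` is only rebuilt when a new `cache_position` tensor is actually computed, instead of being unconditionally re-wrapped after the `if`/`elif`. Below is a minimal sketch of the resulting branching logic, assuming `CacheInfo` is a simple dataclass exposing `position` and `length` (only its keyword constructor is visible in the diff); the `build_cache_info` helper name is hypothetical and exists only for illustration.

```python
from dataclasses import dataclass

import torch


@dataclass
class CacheInfo:
    # Hypothetical stand-in: only the (position=..., length=...) constructor appears in the diff.
    position: torch.Tensor  # absolute indices of the tokens processed in this forward pass
    length: int             # total sequence length seen so far (last position + 1)


def build_cache_info(cache_info, past_length, input_length, use_cache, device="cpu"):
    # Hypothetical helper mirroring the branching in the patched prepare_inputs_for_generation.
    if cache_info is None:
        # Prefill / no cache yet: positions cover the new tokens starting at past_length.
        cache_position = torch.arange(past_length, past_length + input_length, device=device)
        cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)
    elif use_cache:
        # Incremental decoding: keep only the positions of the tokens fed in this step.
        cache_position = cache_info.position[-input_length:]
        cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)
    return cache_info


# Prefill of a 5-token prompt: positions 0..4, total length 5.
info = build_cache_info(None, past_length=0, input_length=5, use_cache=True)
assert info.position.tolist() == [0, 1, 2, 3, 4] and info.length == 5

# Decoding step: an incoming position tensor is trimmed to the single new token.
step = build_cache_info(
    CacheInfo(position=torch.arange(0, 6), length=6), past_length=5, input_length=1, use_cache=True
)
assert step.position.tolist() == [5] and step.length == 6
```

The second assertion exercises the `elif use_cache` path: an incoming position tensor longer than the current step is sliced down to the last `input_length` entries before `cache_info` is rebuilt.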