diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 1f5a164815aaed..c4f7578be4d043 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -23,6 +23,13 @@ logger = logging.get_logger(__name__) +class CacheInfo: + + def __init__(self, position, length): + self.position = position + self._length = length + + @dataclass class Cache: """ @@ -854,7 +861,7 @@ def update( Return: A tuple containing the updated key and value states. """ - cache_position = cache_kwargs.get("cache_position") + cache_info = cache_kwargs.get("cache_info") k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] @@ -862,8 +869,8 @@ def update( k_out.copy_(key_states) v_out.copy_(value_states) else: - k_out[:, :, cache_position] = key_states - v_out[:, :, cache_position] = value_states + k_out[:, :, cache_info.position] = key_states + v_out[:, :, cache_info.position] = value_states return k_out, v_out diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index c0da2530fe2c4e..35816922678451 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, StaticCache +from ...cache_utils import Cache, CacheInfo, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -266,7 +266,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -282,8 +282,8 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # sin and cos are specific to RoPE models; cache_info needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -340,7 +340,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if isinstance(past_key_value, StaticCache): raise ValueError( @@ -367,8 +367,8 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # sin and cos are specific to RoPE models; cache_info needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -531,7 +531,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. @@ -546,7 +546,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, + cache_info=cache_info, ) bsz, q_len, _ = hidden_states.size() @@ -563,8 +563,8 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # sin and cos are specific to RoPE models; cache_info needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -585,6 +585,11 @@ def forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. is_causal = True if causal_mask is None and q_len > 1 else False + if cache_info._length > 0: + key_states = key_states[:, :, :cache_info._length, :] + value_states = value_states[:, :, :cache_info._length, :] + causal_mask = causal_mask[:, :, :, :cache_info._length] if causal_mask is not None else causal_mask + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, @@ -628,7 +633,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -662,7 +667,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, + cache_info=cache_info, ) hidden_states = residual + hidden_states @@ -798,7 +803,7 @@ def _init_weights(self, module): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + cache_info (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, this tensor is not affected by padding. It is used to update the cache in the correct position and to infer the complete sequence length. @@ -850,7 +855,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -878,17 +883,18 @@ def forward( return_legacy_cache = True # noqa: F841 past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if cache_position is None: + if cache_info is None: past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) if position_ids is None: - position_ids = cache_position.unsqueeze(0) + position_ids = cache_info.position.unsqueeze(0) causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + attention_mask, inputs_embeds, cache_info, past_key_values, output_attentions ) # embed positions @@ -925,7 +931,7 @@ def forward( past_key_values, output_attentions, use_cache, - cache_position, + cache_info, ) else: layer_outputs = decoder_layer( @@ -935,7 +941,7 @@ def forward( past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, - cache_position=cache_position, + cache_info=cache_info, ) hidden_states = layer_outputs[0] @@ -969,7 +975,7 @@ def _update_causal_mask( self, attention_mask: torch.Tensor, input_tensor: torch.Tensor, - cache_position: torch.Tensor, + cache_info: torch.Tensor, past_key_values: Cache, output_attentions: bool, ): @@ -1022,7 +1028,7 @@ def _update_causal_mask( ) if sequence_length != 1: causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask *= torch.arange(target_length, device=device) > cache_info.position.reshape(-1, 1) causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) if attention_mask is not None: causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit @@ -1090,7 +1096,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_info: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1134,7 +1140,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - cache_position=cache_position, + cache_info=cache_info, ) hidden_states = outputs[0] @@ -1171,14 +1177,14 @@ def prepare_inputs_for_generation( past_key_values=None, attention_mask=None, inputs_embeds=None, - cache_position=None, + cache_info=None, use_cache=True, **kwargs, ): past_length = 0 if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_info.position[0] if cache_position is not None else past_key_values.get_seq_length() max_cache_length = ( torch.tensor(past_key_values.get_max_length(), device=input_ids.device) if past_key_values.get_max_length() is not None @@ -1223,15 +1229,17 @@ def prepare_inputs_for_generation( model_inputs = {"input_ids": input_ids.contiguous()} input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - if cache_position is None: + if cache_info is None: cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) elif use_cache: - cache_position = cache_position[-input_length:] + cache_position = cache_info.position[-input_length:] + cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1) model_inputs.update( { "position_ids": position_ids, - "cache_position": cache_position, + "cache_info": cache_info, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask,