
Commit: cache_info
ydshieh committed Jul 4, 2024
1 parent dc72fd7 commit 9168904
Showing 2 changed files with 47 additions and 32 deletions.
13 changes: 10 additions & 3 deletions src/transformers/cache_utils.py
@@ -23,6 +23,13 @@
logger = logging.get_logger(__name__)


class CacheInfo:

def __init__(self, position, length):
self.position = position
self._length = length


@dataclass
class Cache:
"""
@@ -854,16 +861,16 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
cache_position = cache_kwargs.get("cache_position")
cache_info = cache_kwargs.get("cache_info")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]

if cache_info is None:
k_out.copy_(key_states)
v_out.copy_(value_states)
else:
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
k_out[:, :, cache_info.position] = key_states
v_out[:, :, cache_info.position] = value_states

return k_out, v_out
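A rough sketch of what the updated StaticCache.update now expects from the attention layers. The tiny config sizes, the "cpu" device, and the single-token decode step are all made-up illustration values, not part of this commit:

import torch
from transformers import GemmaConfig
from transformers.cache_utils import CacheInfo, StaticCache

config = GemmaConfig(num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=1, head_dim=8, hidden_size=32)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)

# One decode step: a single new key/value pair is scattered into slot 7 of the preallocated buffers.
key_states = torch.randn(1, 1, 1, 8)    # (batch, num_key_value_heads, seq_len, head_dim)
value_states = torch.randn(1, 1, 1, 8)
cache_info = CacheInfo(position=torch.tensor([7]), length=8)

k_out, v_out = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={"cache_info": cache_info})
print(k_out.shape)  # the full preallocated buffer, torch.Size([1, 1, 16, 8]), with slot 7 now filled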

66 changes: 37 additions & 29 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -29,7 +29,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...cache_utils import Cache, CacheInfo, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -266,7 +266,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_info: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()

@@ -282,8 +282,8 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# sin and cos are specific to RoPE models; cache_info needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -340,7 +340,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_info: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if isinstance(past_key_value, StaticCache):
raise ValueError(
@@ -367,8 +367,8 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# sin and cos are specific to RoPE models; cache_info needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
@@ -531,7 +531,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
cache_info: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -546,7 +546,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
cache_info=cache_info,
)

bsz, q_len, _ = hidden_states.size()
@@ -563,8 +563,8 @@ def forward(
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# sin and cos are specific to RoPE models; cache_info needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_info": cache_info}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -585,6 +585,11 @@ def forward(
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
is_causal = True if causal_mask is None and q_len > 1 else False

if cache_info._length > 0:
key_states = key_states[:, :, :cache_info._length, :]
value_states = value_states[:, :, :cache_info._length, :]
causal_mask = causal_mask[:, :, :, :cache_info._length] if causal_mask is not None else causal_mask

attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
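The cache_info._length slicing added a few lines above trims the (possibly preallocated) key/value buffers and the matching mask columns down to the tokens that have actually been written, so scaled_dot_product_attention does not attend over empty static-cache slots. A toy illustration of the shapes involved (all sizes are made up):

import torch

max_cache_len, seen = 16, 6                           # preallocated size vs. tokens written so far
key_states = torch.randn(1, 4, max_cache_len, 8)      # what StaticCache.update returns
value_states = torch.randn(1, 4, max_cache_len, 8)
causal_mask = torch.zeros(1, 1, 1, max_cache_len)

key_states = key_states[:, :, :seen, :]
value_states = value_states[:, :, :seen, :]
causal_mask = causal_mask[:, :, :, :seen]
print(key_states.shape, causal_mask.shape)            # torch.Size([1, 4, 6, 8]) torch.Size([1, 1, 1, 6])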
@@ -628,7 +633,7 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
cache_info: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
@@ -662,7 +667,7 @@ def forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
cache_info=cache_info,
)
hidden_states = residual + hidden_states

@@ -798,7 +803,7 @@ def _init_weights(self, module):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
cache_info (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
@@ -850,7 +855,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cache_info: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -878,17 +883,18 @@ def forward(
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)

if cache_position is None:
if cache_info is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)
position_ids = cache_info.position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
attention_mask, inputs_embeds, cache_info, past_key_values, output_attentions
)

# embed positions
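As a concrete illustration of the fallback above (the token counts are made up): with 10 tokens already in the cache and one new input embedding, past_seen_tokens is 10, cache_position becomes tensor([10]), and the model builds CacheInfo(position=tensor([10]), length=11), i.e. "write slot 10, the cache now spans 11 tokens". position_ids then defaults to cache_info.position.unsqueeze(0), which is tensor([[10]]).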
@@ -925,7 +931,7 @@ def forward(
past_key_values,
output_attentions,
use_cache,
cache_position,
cache_info,
)
else:
layer_outputs = decoder_layer(
@@ -935,7 +941,7 @@
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
cache_info=cache_info,
)

hidden_states = layer_outputs[0]
@@ -969,7 +975,7 @@ def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
cache_info: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
@@ -1022,7 +1028,7 @@ def _update_causal_mask(
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask *= torch.arange(target_length, device=device) > cache_info.position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
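A small worked example of the masking line above for a single decode step (dtype and sizes are illustrative):

import torch

min_dtype = torch.finfo(torch.float32).min
sequence_length, target_length = 1, 8
position = torch.tensor([5])                          # cache_info.position for this step

mask = torch.full((sequence_length, target_length), fill_value=min_dtype)
if sequence_length != 1:                              # skipped here; only relevant for prefill
    mask = torch.triu(mask, diagonal=1)
mask = mask * (torch.arange(target_length) > position.reshape(-1, 1))
# Columns 0..5 (already-written slots plus the current one) end up at 0 and stay visible;
# columns 6 and 7 keep the large negative fill and are masked out.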
@@ -1090,7 +1096,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
cache_info: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -1134,7 +1140,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
cache_info=cache_info,
)

hidden_states = outputs[0]
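A sketch of driving the causal LM directly with an explicit CacheInfo on this branch. The tiny randomly initialized config is only there to make the call self-contained; all sizes are invented:

import torch
from transformers import GemmaConfig, GemmaForCausalLM
from transformers.cache_utils import CacheInfo

config = GemmaConfig(vocab_size=128, hidden_size=32, intermediate_size=64, num_hidden_layers=2,
                     num_attention_heads=4, num_key_value_heads=1, head_dim=8)
model = GemmaForCausalLM(config).eval()

input_ids = torch.randint(0, 128, (1, 4))
cache_info = CacheInfo(position=torch.arange(0, 4), length=4)   # prefill of 4 tokens

with torch.no_grad():
    outputs = model(input_ids, cache_info=cache_info, use_cache=True)
print(outputs.logits.shape)   # torch.Size([1, 4, 128])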
@@ -1171,14 +1177,14 @@ def prepare_inputs_for_generation(
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
cache_position=None,
cache_info=None,
use_cache=True,
**kwargs,
):
past_length = 0
if past_key_values is not None:
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
past_length = cache_info.position[0] if cache_info is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
@@ -1223,15 +1229,17 @@ def prepare_inputs_for_generation(
model_inputs = {"input_ids": input_ids.contiguous()}

input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
if cache_position is None:
if cache_info is None:
cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)
elif use_cache:
cache_position = cache_position[-input_length:]
cache_position = cache_info.position[-input_length:]
cache_info = CacheInfo(position=cache_position, length=int(cache_position[-1]) + 1)

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"cache_info": cache_info,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
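Finally, a rough end-to-end sketch of what prepare_inputs_for_generation now hands to the model on this branch (again using a tiny randomly initialized model; the sizes are invented):

import torch
from transformers import GemmaConfig, GemmaForCausalLM

config = GemmaConfig(vocab_size=128, hidden_size=32, intermediate_size=64, num_hidden_layers=2,
                     num_attention_heads=4, num_key_value_heads=1, head_dim=8)
model = GemmaForCausalLM(config)

input_ids = torch.randint(0, 128, (1, 6))
model_inputs = model.prepare_inputs_for_generation(input_ids, use_cache=True)

print(sorted(model_inputs))                    # now includes "cache_info" alongside the usual keys
print(model_inputs["cache_info"].position)     # tensor([0, 1, 2, 3, 4, 5])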
