Commit: Fix rebase issues
Cyrilvallez committed May 14, 2024
1 parent d6a6189 commit 8095a4a
Showing 3 changed files with 46 additions and 26 deletions.
50 changes: 33 additions & 17 deletions src/transformers/cache_utils.py
@@ -61,14 +61,6 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
if max_length is not None and previous_seq_length + new_seq_length > max_length:
return max_length - new_seq_length
return previous_seq_length

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
@@ -192,7 +184,7 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens."""
@@ -298,15 +290,22 @@ def update(
"""
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
if isinstance(key_states, list):
self._seen_tokens += sum(x.shape[-2] for x in key_states)
else:
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
if isinstance(key_states, list):
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
self.key_cache.append([key_states])
self.value_cache.append([value_states])
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
self.key_cache[layer_idx].append(key_states)
self.value_cache[layer_idx].append(value_states)

# Whenever we have more than `self.restack_limit` new K-V values, cat() them. That way, we keep a relatively low number
# of tensors in self.key_cache[layer_idx], which makes it more efficient to later cat() them all, and we only
@@ -329,10 +328,9 @@ def update(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]
return sum(x.shape[-2] for x in self.key_cache[layer_idx])

def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
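The hunks above switch the per-layer storage from a single concatenated tensor to a list of K/V chunks: `update()` appends a chunk instead of calling `torch.cat()` on every step, `get_seq_length()` sums the chunk lengths, and `restack_limit` bounds how long the list may grow before the chunks are `cat()`ed back together. A minimal, self-contained sketch of that layout follows; the class name is hypothetical, it only handles tensor `key_states` (the diff also accepts lists), and the restack rule is simplified, so treat it as an illustration rather than the actual `EfficientDynamicCache` implementation.

```python
import torch


class ListBackedCacheSketch:
    """Hypothetical stand-in for the list-backed cache described above:
    each layer stores a list of K/V chunks instead of one growing tensor."""

    def __init__(self, restack_limit: int = 16):
        self.key_cache: list[list[torch.Tensor]] = []
        self.value_cache: list[list[torch.Tensor]] = []
        self.restack_limit = restack_limit
        self._seen_tokens = 0

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int):
        # Count new tokens once per forward pass (layer 0 only).
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            # First update for this layer: start new chunk lists.
            self.key_cache.append([key_states])
            self.value_cache.append([value_states])
        else:
            # O(1) list append instead of a torch.cat() on every decoding step.
            self.key_cache[layer_idx].append(key_states)
            self.value_cache[layer_idx].append(value_states)

        # Simplified restack rule: once the chunk list grows past the limit,
        # collapse it back into a single tensor so later concatenations stay cheap.
        if len(self.key_cache[layer_idx]) > self.restack_limit:
            self.key_cache[layer_idx] = [torch.cat(self.key_cache[layer_idx], dim=-2)]
            self.value_cache[layer_idx] = [torch.cat(self.value_cache[layer_idx], dim=-2)]

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: int = 0) -> int:
        # The cached length is the sum of the chunk lengths for that layer.
        if len(self.key_cache) <= layer_idx:
            return 0
        return sum(x.shape[-2] for x in self.key_cache[layer_idx])
```

The payoff is that most decoding steps cost only a list append, and the expensive `torch.cat()` runs at most once every `restack_limit` steps (or once at the end, when a contiguous tensor is finally needed).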
@@ -432,11 +430,29 @@ def from_legacy_cache(
restack_limit: Optional[int] = None,
) -> "EfficientDynamicCache":
"""Converts a cache in the legacy cache format Tuple[Tuple[torch.Tensor]] or Tuple[Tuple[List[torch.Tensor]]] into an equivalent `EfficientDynamicCache`."""
# Small check to ensure that model implementations will not use `from_legacy_cache()` with an already existing EfficientDynamicCache instance.
# That would result in a copy that would defeat the purpose of this class, and since the old implementation used to trigger this case,
# it could arise again in the future when adding Cache classes to more models.
if isinstance(past_key_values, Cache):
raise ValueError("Cannot use `from_legacy_cache()` with a Cache instance.")

cache = cls(restack_limit=restack_limit)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
if layer_idx == 0:
if isinstance(key_states, list):
cache._seen_tokens += sum(x.shape[-2] for x in key_states)
else:
cache._seen_tokens += key_states.shape[-2]

# Update the cache
if isinstance(key_states, list):
cache.key_cache.append(key_states)
cache.value_cache.append(value_states)
else:
cache.key_cache.append([key_states])
cache.value_cache.append([value_states])
return cache
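A hedged usage sketch of the conversion path guarded above; since `EfficientDynamicCache` exists only on this branch and is not part of a released `transformers` API, the import, the positional `past_key_values` argument, and the tensor shapes below are assumptions for illustration.

```python
import torch

# Branch-only class added by this PR; not a released transformers API.
from transformers.cache_utils import EfficientDynamicCache

# Legacy format: one (key, value) pair per layer, each (batch, heads, seq_len, head_dim).
legacy = tuple(
    (torch.zeros(1, 8, 5, 64), torch.zeros(1, 8, 5, 64))
    for _ in range(2)
)

cache = EfficientDynamicCache.from_legacy_cache(legacy, restack_limit=16)
print(cache.get_seq_length())  # 5

# Passing an existing Cache instance is rejected, since copying it would defeat
# the purpose of the class.
try:
    EfficientDynamicCache.from_legacy_cache(cache)
except ValueError as err:
    print(err)
```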


1 change: 1 addition & 0 deletions src/transformers/generation/candidate_generator.py
@@ -19,6 +19,7 @@
import torch

from ..cache_utils import DynamicCache
from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor


if TYPE_CHECKING:
21 changes: 12 additions & 9 deletions src/transformers/generation/utils.py
@@ -24,7 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import Cache, DynamicCache, EfficientDynamicCache, StaticCache
from ..cache_utils import Cache, DynamicCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -1620,7 +1620,7 @@ def generate(
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_cache_class:
past = model_kwargs.get('past_key_values', None)
past = model_kwargs.get("past_key_values", None)
if past is None:
model_kwargs["past_key_values"] = DynamicCache()
elif isinstance(past, tuple):
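The collapsed view cuts off the `elif isinstance(past, tuple):` branch above. Roughly, the default-cache selection amounts to the sketch below; the body of the tuple branch is an assumption based on the released `DynamicCache.from_legacy_cache` API, not something visible in this hunk.

```python
from typing import Any, Dict

from transformers.cache_utils import DynamicCache


def normalize_cache(model_kwargs: Dict[str, Any]) -> None:
    """Sketch of the default-cache selection in generate()."""
    past = model_kwargs.get("past_key_values", None)
    if past is None:
        # No cache provided: start a fresh DynamicCache so generation avoids the
        # legacy tuple format, which keeps copying the cache and uses more memory.
        model_kwargs["past_key_values"] = DynamicCache()
    elif isinstance(past, tuple):
        # Assumed continuation of the truncated branch: convert a legacy tuple
        # cache into a DynamicCache.
        model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
```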
@@ -2029,8 +2029,7 @@ def _contrastive_search(
)
elif (
not isinstance(past_key_values[0], (tuple, torch.Tensor))
or (isinstance(past_key_values[0][0], torch.Tensor) and past_key_values[0][0].shape[0] != batch_size)
or (isinstance(past_key_values[0][0], list) and past_key_values[0][0][0].shape[0] != batch_size)
or past_key_values[0][0].shape[0] != batch_size
):
raise ValueError(
f"{self.__class__.__name__} does not have a standard cache format and therefore **can't** be "
@@ -2085,7 +2084,7 @@ def _contrastive_search(
for item in layer:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(tuple(items))

past = tuple(new_key_values)

model_kwargs["past_key_values"] = past
@@ -2145,7 +2144,7 @@ def _contrastive_search(
selected_idx = selected_idx.to("cpu")

# This will be used instead of the previous inefficient torch.stack(torch.split())
augmented_idx = torch.tensor([x + i*top_k for i, x in enumerate(selected_idx)])
augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])

# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
@@ -2179,8 +2178,12 @@ def _contrastive_search(
# Do it in-place layer per layer to save memory
if isinstance(next_past_key_values, DynamicCache):
for layer_idx in range(len(next_past_key_values)):
next_past_key_values.key_cache[layer_idx] = next_past_key_values.key_cache[layer_idx][augmented_idx, ...]
next_past_key_values.value_cache[layer_idx] = next_past_key_values.value_cache[layer_idx][augmented_idx, ...]
next_past_key_values.key_cache[layer_idx] = next_past_key_values.key_cache[layer_idx][
augmented_idx, ...
]
next_past_key_values.value_cache[layer_idx] = next_past_key_values.value_cache[layer_idx][
augmented_idx, ...
]
else:
new_key_values = []
for layer in next_past_key_values:
@@ -2192,7 +2195,6 @@ def _contrastive_search(

next_past_key_values = tuple(new_key_values)


logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]

# Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
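A hedged sketch of the two re-selection paths in the hunk above, written against the `DynamicCache` API as it stood around this commit (a plain list with one tensor per layer in `key_cache`/`value_cache`); the shapes and the two-layer legacy tuple are illustrative.

```python
import torch

from transformers.cache_utils import DynamicCache

batch_size, top_k, heads, seq, dim = 3, 4, 2, 5, 8
augmented_idx = torch.tensor([1, 6, 9])  # one flat index per batch item, as computed above

# DynamicCache path: index each layer's tensors in place, keeping the same cache object.
cache = DynamicCache()
cache.update(
    torch.randn(batch_size * top_k, heads, seq, dim),
    torch.randn(batch_size * top_k, heads, seq, dim),
    layer_idx=0,
)
for layer_idx in range(len(cache)):
    cache.key_cache[layer_idx] = cache.key_cache[layer_idx][augmented_idx, ...]
    cache.value_cache[layer_idx] = cache.value_cache[layer_idx][augmented_idx, ...]

# Legacy tuple path: rebuild a fresh tuple of (key, value) pairs layer by layer.
legacy = tuple(
    (torch.randn(batch_size * top_k, heads, seq, dim), torch.randn(batch_size * top_k, heads, seq, dim))
    for _ in range(2)
)
legacy = tuple(tuple(item[augmented_idx, ...] for item in layer) for layer in legacy)
```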
@@ -3648,6 +3650,7 @@ def _assisted_decoding(
else:
return input_ids


def _speculative_sampling(
candidate_input_ids,
candidate_logits,
