Skip to content

Commit

Permalink
Reduce by 2 the memory requirement in generate() 🔥🔥🔥 (#30536)
Browse files Browse the repository at this point in the history
* Fix contrastive_search for new cache structure, and improve performance by removing inneficient torch.stack(torch.split(x, top_k, dim=0))

* Fix _contrastive_search for non-standard cache using ellipsis slicing

* Fix all outputs.logits memory leaks for all decoding strategies!

* Fix small error in _contrastive_search()

* Make all necessary change and revert for the new class

* Apply coding style

* Remove pipes in type hints for compatibility

* correct type hint

* apply style

* Use DynamicCache by default and solve conflicts

* Fix rebase issues

* Add `_supports_dynamic_cache_class` in models for models that support DynamicCache but not other caches to make DynamicCache the default for more models

* Create generation config to return legacy format by default, or to choose not to

* style

* Fix case when use_cache is False

* Remove default DynamicCache in assiste_decoding if assistant_model does not support it + fix _seen_tokens when cropping cache

* Update prepare_inputs_for_generation() for case with empty DynamicCache

* Correct return of args in _assisted_decoding

* Remove EfficientDynamicCache as it is no longer needed

* Correct mistake in generation config

* Move cache logic of assisted decoding to AssistedCandidateGenerator.__init__

* change DynamicCache function names from "split" to "batch_split" for readability + apply coding style

* Remove `_supports_dynamic_cache_class` attribute after rebase

* Correct missing line lost in conflict resolution during rebasing

* Add special case for Jamba

* Fix jamba test

* Coding style

* coding style

* Correct missing import in rebasing

* Simplify _validate_model_kwargs based on removal of _supports_dynamic_cache attribute

* Simplify code paths in _contrastive_search

* coding style

* Update docstrings of cache methods

* Update prepare_inputs_for_generation() -> past_key_values are always Cache objects
  • Loading branch information
Cyrilvallez authored Jun 5, 2024
1 parent d6276f0 commit bd5091d
Show file tree
Hide file tree
Showing 19 changed files with 327 additions and 206 deletions.
57 changes: 55 additions & 2 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,22 +377,75 @@ def get_max_length(self) -> Optional[int]:
return None

def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""

# In case it is negative
if maximum_length < 0:
maximum_length = self.get_seq_length() - abs(maximum_length)

if self.get_seq_length() <= maximum_length:
return

self._seen_tokens = maximum_length
for idx in range(len(self.key_cache)):
self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :]

def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicCache()
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
return out

@classmethod
def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
cache = cls()
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
cache.update(layer_keys, layer_values, idx)
return cache

def batch_repeat_interleave(self, repeats: int):
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
for layer_idx in range(len(self)):
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)

def batch_select_indices(self, indices: torch.Tensor):
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
for layer_idx in range(len(self)):
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


class QuantizedCache(DynamicCache):
"""
Expand Down
18 changes: 14 additions & 4 deletions src/transformers/generation/candidate_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,19 @@ def __init__(
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)

# Remove potential default DynamicCache if assistant does not support it
if "past_key_values" in assistant_kwargs.keys():
if (
isinstance(assistant_kwargs["past_key_values"], DynamicCache)
and not self.assistant_model._supports_cache_class
):
# Cache is empty -> remove it from kwargs
if len(assistant_kwargs["past_key_values"]) == 0:
del assistant_kwargs["past_key_values"]
# Cache is not empty -> convert to legacy
else:
assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache()

if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
elif assistant_model.config.is_encoder_decoder:
Expand Down Expand Up @@ -387,10 +400,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
elif isinstance(past_key_values, DynamicCache):
for idx in range(len(past_key_values.key_cache)):
if past_key_values.value_cache[idx].shape[-1] != 0:
past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]
past_key_values.crop(maximum_length)

elif past_key_values is not None:
for idx in range(len(past_key_values)):
Expand Down
3 changes: 3 additions & 0 deletions src/transformers/generation/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ class GenerationConfig(PushToHubMixin):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its repsective `CacheConfig` internally.
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, default to `True`):
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
> Wild card
Expand Down Expand Up @@ -404,6 +406,7 @@ def __init__(self, **kwargs):
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)

# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
Expand Down
Loading

0 comments on commit bd5091d

Please sign in to comment.