From ad6feb5c4bf2f85895bb122b652574da2bcf8a87 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 5 Jun 2024 17:05:01 +0200 Subject: [PATCH] =?UTF-8?q?Reduce=20by=202=20the=20memory=20requirement=20?= =?UTF-8?q?in=20`generate()`=20=F0=9F=94=A5=F0=9F=94=A5=F0=9F=94=A5=20(#30?= =?UTF-8?q?536)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix contrastive_search for new cache structure, and improve performance by removing inneficient torch.stack(torch.split(x, top_k, dim=0)) * Fix _contrastive_search for non-standard cache using ellipsis slicing * Fix all outputs.logits memory leaks for all decoding strategies! * Fix small error in _contrastive_search() * Make all necessary change and revert for the new class * Apply coding style * Remove pipes in type hints for compatibility * correct type hint * apply style * Use DynamicCache by default and solve conflicts * Fix rebase issues * Add `_supports_dynamic_cache_class` in models for models that support DynamicCache but not other caches to make DynamicCache the default for more models * Create generation config to return legacy format by default, or to choose not to * style * Fix case when use_cache is False * Remove default DynamicCache in assiste_decoding if assistant_model does not support it + fix _seen_tokens when cropping cache * Update prepare_inputs_for_generation() for case with empty DynamicCache * Correct return of args in _assisted_decoding * Remove EfficientDynamicCache as it is no longer needed * Correct mistake in generation config * Move cache logic of assisted decoding to AssistedCandidateGenerator.__init__ * change DynamicCache function names from "split" to "batch_split" for readability + apply coding style * Remove `_supports_dynamic_cache_class` attribute after rebase * Correct missing line lost in conflict resolution during rebasing * Add special case for Jamba * Fix jamba test * Coding style * coding style * Correct missing import in rebasing * Simplify _validate_model_kwargs based on removal of _supports_dynamic_cache attribute * Simplify code paths in _contrastive_search * coding style * Update docstrings of cache methods * Update prepare_inputs_for_generation() -> past_key_values are always Cache objects --- src/transformers/cache_utils.py | 57 ++++- .../generation/candidate_generator.py | 18 +- .../generation/configuration_utils.py | 3 + src/transformers/generation/utils.py | 195 +++++++++++++----- .../models/cohere/modeling_cohere.py | 22 +- src/transformers/models/dbrx/modeling_dbrx.py | 22 +- .../models/gemma/modeling_gemma.py | 22 +- .../models/idefics2/modeling_idefics2.py | 14 +- .../models/llama/modeling_llama.py | 22 +- .../models/mistral/modeling_mistral.py | 24 +-- .../models/mixtral/modeling_mixtral.py | 14 +- src/transformers/models/olmo/modeling_olmo.py | 22 +- .../models/persimmon/modeling_persimmon.py | 14 +- src/transformers/models/phi/modeling_phi.py | 14 +- src/transformers/models/phi3/modeling_phi3.py | 14 +- .../models/qwen2/modeling_qwen2.py | 14 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 14 +- .../models/stablelm/modeling_stablelm.py | 14 +- .../models/starcoder2/modeling_starcoder2.py | 14 +- 19 files changed, 327 insertions(+), 206 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c7dd13ea59bcca..04ba337ef436b3 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -377,7 +377,8 @@ def get_max_length(self) -> Optional[int]: return None def to_legacy_cache(self) -> 
Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: - """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for + backward compatibility.""" legacy_cache = () for layer_idx in range(len(self)): legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) @@ -385,7 +386,8 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: @classmethod def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for + backward compatibility.""" cache = cls() if past_key_values is not None: for layer_idx in range(len(past_key_values)): @@ -393,6 +395,57 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens cache.update(key_states, value_states, layer_idx) return cache + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search.""" + + # In case it is negative + if maximum_length < 0: + maximum_length = self.get_seq_length() - abs(maximum_length) + + if self.get_seq_length() <= maximum_length: + return + + self._seen_tokens = maximum_length + for idx in range(len(self.key_cache)): + self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :] + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls() + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] 
+ self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + class QuantizedCache(DynamicCache): """ diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 52371d94dc56d1..ccbbb412bd1cdc 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -116,6 +116,19 @@ def __init__( value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value) ) + # Remove potential default DynamicCache if assistant does not support it + if "past_key_values" in assistant_kwargs.keys(): + if ( + isinstance(assistant_kwargs["past_key_values"], DynamicCache) + and not self.assistant_model._supports_cache_class + ): + # Cache is empty -> remove it from kwargs + if len(assistant_kwargs["past_key_values"]) == 0: + del assistant_kwargs["past_key_values"] + # Cache is not empty -> convert to legacy + else: + assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache() + if "assistant_encoder_outputs" in model_kwargs: assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] elif assistant_model.config.is_encoder_decoder: @@ -387,10 +400,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length): for idx in range(len(past_key_values)): past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] elif isinstance(past_key_values, DynamicCache): - for idx in range(len(past_key_values.key_cache)): - if past_key_values.value_cache[idx].shape[-1] != 0: - past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :] - past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :] + past_key_values.crop(maximum_length) elif past_key_values is not None: for idx in range(len(past_key_values)): diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index efcece82f579de..0e15014b34ceb7 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -313,6 +313,8 @@ class GenerationConfig(PushToHubMixin): Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and it will be converted to its repsective `CacheConfig` internally. Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`. + return_legacy_cache (`bool`, *optional*, default to `True`): + Whether to return the legacy or new format of the cache when `DynamicCache` is used by default. > Wild card @@ -404,6 +406,7 @@ def __init__(self, **kwargs): self.cache_config = cache_config_class() elif isinstance(self.cache_config, dict): self.cache_config = cache_config_class.from_dict(self.cache_config) + self.return_legacy_cache = kwargs.pop("return_legacy_cache", True) # Prompt lookup decoding self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 47ca012f22e7d6..967980c714fe09 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1448,6 +1448,16 @@ def _get_decoder_start_token_id( else: return + def _supports_default_dynamic_cache(self) -> bool: + """ + Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`. 
+ This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which + uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in + order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed + for `HybridMambaAttentionDynamicCache`). + """ + return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower() + def _prepare_special_tokens( self, generation_config: GenerationConfig, @@ -1709,6 +1719,7 @@ def generate( input_ids_length=input_ids_length, ) + use_dynamic_cache_by_default = False if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: raise ValueError( "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " @@ -1750,6 +1761,16 @@ def generate( ) model_kwargs["past_key_values"] = cache_class(cache_config) + # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that + # keeps copying the cache thus using much more memory + elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): + past = model_kwargs.get("past_key_values", None) + if past is None: + model_kwargs["past_key_values"] = DynamicCache() + use_dynamic_cache_by_default = True + elif isinstance(past, tuple): + model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past) + use_dynamic_cache_by_default = True self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -2018,6 +2039,11 @@ def typeerror(): **model_kwargs, ) + # Convert to legacy cache if needed + if use_dynamic_cache_by_default and generation_config.return_legacy_cache: + if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): + if isinstance(result.past_key_values, DynamicCache): + result.past_key_values = result.past_key_values.to_legacy_cache() return result def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: @@ -2185,7 +2211,10 @@ def _contrastive_search( while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step - if model_kwargs.get("past_key_values") is None: + if model_kwargs.get("past_key_values") is None or ( + isinstance(model_kwargs["past_key_values"], Cache) + and model_kwargs["past_key_values"].get_seq_length() == 0 + ): # prepare inputs model_kwargs["use_cache"] = True model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) @@ -2204,7 +2233,9 @@ def _contrastive_search( last_hidden_states = outputs.hidden_states[-1] # next logit for contrastive search to select top-k candidate tokens - logit_for_next_step = outputs.logits[:, -1, :] + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration + # (the clone itself is always small) + logit_for_next_step = outputs.logits[:, -1, :].clone() model_kwargs = self._update_model_kwargs_for_generation( outputs, @@ -2212,6 +2243,7 @@ def _contrastive_search( is_encoder_decoder=self.config.is_encoder_decoder, standardize_cache_format=True, ) + if not sequential: # Expands model inputs top_k times, for batched forward passes (akin to beam search). 
_, model_kwargs = self._expand_inputs_for_generation( @@ -2261,25 +2293,28 @@ def _contrastive_search( else (outputs.hidden_states,) ) - # Replicates the new past_key_values to match the `top_k` candidates - new_key_values = [] - past = model_kwargs["past_key_values"] - for layer in past: - items = [] - # item is either the key or the value matrix - for item in layer: - if sequential: - items.append(item.repeat_interleave(1, dim=0)) - else: - items.append(item.repeat_interleave(top_k, dim=0)) - new_key_values.append(tuple(items)) - if not isinstance(past, DynamicCache): - past = tuple(new_key_values) - else: - for layer_idx in range(len(new_key_values)): - past.key_cache[layer_idx] = new_key_values[layer_idx][0] - past.value_cache[layer_idx] = new_key_values[layer_idx][1] - model_kwargs["past_key_values"] = past + # This is needed to properly delete outputs.logits which may be very large for this first iteration + # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward() + del outputs + + if not sequential: + # Replicates the new past_key_values to match the `top_k` candidates + past = model_kwargs["past_key_values"] + # If it is a static cache, modify it in-place layer after layer to save memory + if isinstance(past, DynamicCache): + past.batch_repeat_interleave(top_k) + else: + new_key_values = [] + for layer in past: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item.repeat_interleave(top_k, dim=0)) + new_key_values.append(tuple(items)) + + past = tuple(new_key_values) + + model_kwargs["past_key_values"] = past if sequential: all_outputs = [] @@ -2293,6 +2328,12 @@ def _contrastive_search( output_hidden_states=True, output_attentions=output_attentions, ) + if isinstance(outputs["past_key_values"], DynamicCache): + # Remove past K-V from output since we don't need to stack later + outputs["past_key_values"] = None + # Remove last token from past K-V since we don't want to append it at this point + model_kwargs["past_key_values"].crop(-1) + all_outputs.append(outputs) outputs = stack_model_outputs(all_outputs) @@ -2307,6 +2348,11 @@ def _contrastive_search( output_hidden_states=True, output_attentions=output_attentions, ) + + # This is essential to avoid having a last reference to the big past K-V and double the necesary memory + # in the next loop + del next_model_inputs + # name is different for encoder-decoder and decoder-only models if self.config.is_encoder_decoder: next_hidden = outputs.decoder_hidden_states[-1] @@ -2316,7 +2362,6 @@ def _contrastive_search( full_hidden_states = outputs.hidden_states logits = outputs.logits[:, -1, :] - context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0) # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the @@ -2325,6 +2370,9 @@ def _contrastive_search( selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k) selected_idx = selected_idx.to("cpu") + # This will be used instead of the previous inneficient torch.stack(torch.split()) + augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)]) + # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores # (model confidence minus degeneration penalty); (6) decoder hidden_states @@ -2354,22 +2402,19 @@ def _contrastive_search( else: 
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) - new_key_values = [] - for layer in next_past_key_values: - items = [] - # item is either the key or the value matrix - for item in layer: - item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz] - item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz] - items += [item] - new_key_values += [items] - - if not isinstance(next_past_key_values, DynamicCache): - next_past_key_values = tuple(new_key_values) + # Do it in-place layer per layer to save memory + if isinstance(next_past_key_values, DynamicCache): + next_past_key_values.batch_select_indices(augmented_idx) else: - for layer_idx in range(len(new_key_values)): - next_past_key_values.key_cache[layer_idx] = new_key_values[layer_idx][0] - next_past_key_values.value_cache[layer_idx] = new_key_values[layer_idx][1] + new_key_values = [] + for layer in next_past_key_values: + items = [] + # item is either the key or the value matrix + for item in layer: + items.append(item[augmented_idx, ...]) + new_key_values.append(tuple(items)) + + next_past_key_values = tuple(new_key_values) logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :] @@ -2431,13 +2476,16 @@ def _contrastive_search( # Contrastive search works by forward looking at the next token, so we need to exclude it from # `past_key_values` to be consistent with the other decoding methods if model_kwargs.get("past_key_values") is not None: - past_key_values = [] - for layer in model_kwargs["past_key_values"]: - layer_past_key_values = [] - for item in layer: - layer_past_key_values.append(item[..., :-1, :]) - past_key_values.append(tuple(layer_past_key_values)) - model_kwargs["past_key_values"] = tuple(past_key_values) + if isinstance(model_kwargs["past_key_values"], DynamicCache): + model_kwargs["past_key_values"].crop(-1) + else: + past_key_values = [] + for layer in model_kwargs["past_key_values"]: + layer_past_key_values = [] + for item in layer: + layer_past_key_values.append(item[..., :-1, :]) + past_key_values.append(tuple(layer_past_key_values)) + model_kwargs["past_key_values"] = tuple(past_key_values) if self.config.is_encoder_decoder: return GenerateEncoderDecoderOutput( @@ -2588,7 +2636,9 @@ def _sample( if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) @@ -2639,6 +2689,10 @@ def _sample( unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) this_peer_finished = unfinished_sequences.max() == 0 + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + if streamer is not None: streamer.end() @@ -2846,7 +2900,9 @@ def _beam_search( cur_len = cur_len + 1 continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + 
next_token_logits = outputs.logits[:, -1, :].clone() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) @@ -2922,6 +2978,13 @@ def _beam_search( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx @@ -3125,7 +3188,9 @@ def _group_beam_search( if output_scores: processed_score = torch.zeros_like(outputs.logits[:, -1, :]) if output_logits: - raw_logit_score = outputs.logits[:, -1, :] + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + raw_logit_score = outputs.logits[:, -1, :].clone() for beam_group_idx in range(num_beam_groups): group_start_idx = beam_group_idx * num_sub_beams @@ -3142,6 +3207,7 @@ def _group_beam_search( group_input_ids = input_ids[batch_group_indices] # select outputs of beams of current group only + # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop next_token_logits = outputs.logits[batch_group_indices, -1, :] next_token_scores = nn.functional.log_softmax( @@ -3231,6 +3297,13 @@ def _group_beam_search( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], reordering_indices @@ -3393,7 +3466,9 @@ def _constrained_beam_search( cur_len = cur_len + 1 continue # don't waste resources running the code we don't need - next_token_logits = outputs.logits[:, -1, :] + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone() next_token_scores = nn.functional.log_softmax( next_token_logits, dim=-1 ) # (batch_size * num_beams, vocab_size) @@ -3461,6 +3536,13 @@ def _constrained_beam_search( model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, ) + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory + # (that way the memory peak does not include outputs.logits) + del outputs + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx @@ -3597,6 +3679,13 @@ def _assisted_decoding( 
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + # This is needed if return_dict_in_generate is True + if isinstance(model_kwargs.get("past_key_values", None), DynamicCache): + if len(model_kwargs["past_key_values"]) == 0: + start_from_empty_dynamic_cache = True + else: + start_from_empty_dynamic_cache = False + this_peer_finished = False while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): cur_len = input_ids.shape[-1] @@ -3709,8 +3798,10 @@ def _assisted_decoding( if output_logits: raw_logits += (next_token_logits,) - if "past_key_values" not in model_kwargs: + if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache: added_len = new_cur_len + # set it to false for other iterations + start_from_empty_dynamic_cache = False else: added_len = n_matches + 1 @@ -3909,6 +4000,9 @@ def _split(data, full_batch_size: int, split_size: int = None): return [None] * (full_batch_size // split_size) if isinstance(data, torch.Tensor): return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] + # New cache format + elif isinstance(data, DynamicCache): + return data.batch_split(full_batch_size, split_size) elif isinstance(data, tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0], tuple): @@ -4012,6 +4106,9 @@ def _concat(data): return None if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) + # New cache format + elif isinstance(data[0], DynamicCache): + return DynamicCache.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 9f01450bb6ed0d..f3d62af5bae5ab 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -1167,18 +1167,14 @@ def prepare_inputs_for_generation( ): past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1208,7 +1204,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if 
`inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index d34ce400ccf9ae..cbb2231b5a0e5b 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1443,18 +1443,14 @@ def prepare_inputs_for_generation( ): past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1484,7 +1480,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 2c948dd74c896d..c0a8c193d4cc5c 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -1163,18 +1163,14 @@ def prepare_inputs_for_generation( ): past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1204,7 +1200,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 6acabad0635b3f..bd44483c7992cd 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1876,15 +1876,13 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 # Omit tokens covered by past_key_values if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1915,7 +1913,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 92c3249247e336..7cef177fef99f5 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1218,18 +1218,14 @@ def prepare_inputs_for_generation( ): past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1259,7 +1255,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": 
inputs_embeds} else: # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 0e1241c301c3f0..d266c6b4f47216 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1240,21 +1240,17 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # Omit tokens covered by past_key_values past_length = 0 + # Omit tokens covered by past_key_values if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1294,7 +1290,7 @@ def prepare_inputs_for_generation( attention_mask = attention_mask[:, -past_key_values.max_cache_len :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids.contiguous()} diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 4c694de0c36a49..c64ae79a931df3 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1407,15 +1407,13 @@ def prepare_inputs_for_generation( output_router_logits=False, **kwargs, ): + past_length = 0 # Omit tokens covered by past_key_values if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1446,7 +1444,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them 
in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 1630297cd82d19..50ec015e521cb2 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -1198,18 +1198,14 @@ def prepare_inputs_for_generation( ): past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() - max_cache_length = ( - torch.tensor(past_key_values.get_max_length(), device=input_ids.device) - if past_key_values.get_max_length() is not None - else None - ) - cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) - # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1239,7 +1235,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: # The `contiguous()` here is necessary to have a static stride during decoding. 
torchdynamo otherwise diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 803169ddd57517..9458c3361d2e81 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -832,14 +832,12 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -870,7 +868,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index d8c1f4a9b4a6f9..a2c3793c01194b 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1211,14 +1211,12 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1249,7 +1247,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 9ce7e44dcebda1..e14785bd1f8b18 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1299,14 +1299,12 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): 
+ past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1337,7 +1335,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index a55bab6bc7d0cb..919c7442d69a0f 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1199,15 +1199,13 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 # Omit tokens covered by past_key_values if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1238,7 +1236,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 0e4b4b75e8120d..013368b1b82d3a 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1394,15 +1394,13 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 # Omit tokens covered by past_key_values if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # 
Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1433,7 +1431,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 6c19e999692f30..264bc3e9739442 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1208,14 +1208,12 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1246,7 +1244,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 32e5998884a7d8..97ea7f9509ebea 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1182,15 +1182,13 @@ def forward( def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + past_length = 0 # Omit tokens covered by past_key_values if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of 
input_ids, then we are in a setting where @@ -1221,7 +1219,7 @@ def prepare_inputs_for_generation( position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: + if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids}
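The patch adds several in-place manipulation helpers to DynamicCache: `crop()`, `batch_split()`, `from_batch_splits()`, `batch_repeat_interleave()` and `batch_select_indices()`. Below is a minimal sketch of the cropping and batch-splitting helpers on a toy cache; the two-layer setup and tensor shapes are illustrative only and assume a transformers version that already includes this patch.

import torch
from transformers.cache_utils import DynamicCache

# Toy cache: 2 layers, batch=4, 2 heads, 10 tokens, head_dim=8
cache = DynamicCache()
for layer_idx in range(2):
    cache.update(torch.randn(4, 2, 10, 8), torch.randn(4, 2, 10, 8), layer_idx)
print(cache.get_seq_length())        # 10

# Drop the last token in place (a negative value removes that many tokens),
# as done after assisted decoding and contrastive search.
cache.crop(-1)
print(cache.key_cache[0].shape)      # torch.Size([4, 2, 9, 8])

# Split along the batch dimension (used by `_split_model_inputs()` when
# `generate()` runs in smaller micro-batches) ...
splits = cache.batch_split(full_batch_size=4, split_size=2)
print(len(splits), splits[0].key_cache[0].shape)   # 2 torch.Size([2, 2, 9, 8])

# ... and merge the splits back together (used by `stack_model_outputs`).
merged = DynamicCache.from_batch_splits(splits)
print(merged.key_cache[0].shape)     # torch.Size([4, 2, 9, 8])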
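Contrastive search now expands and filters the cache in place instead of rebuilding it with torch.stack(torch.split(...)). A small sketch of the index arithmetic used in the patch: after `batch_repeat_interleave(top_k)`, candidate `selected_idx[i]` of batch element `i` sits at row `i * top_k + selected_idx[i]`, which is the `augmented_idx` handed to `batch_select_indices()`. The shapes and the hard-coded `selected_idx` are illustrative only.

import torch
from transformers.cache_utils import DynamicCache

batch_size, top_k = 2, 3

# Toy single-layer cache: batch=2, 1 head, 5 tokens, head_dim=4
cache = DynamicCache()
cache.update(torch.randn(batch_size, 1, 5, 4), torch.randn(batch_size, 1, 5, 4), 0)

# One row per candidate token, replicated in place.
cache.batch_repeat_interleave(top_k)
print(cache.key_cache[0].shape)      # torch.Size([6, 1, 5, 4])

# Pretend the ranking picked candidate 2 for sample 0 and candidate 0 for sample 1.
selected_idx = [2, 0]
augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
print(augmented_idx)                 # tensor([2, 3])

# Keep only the selected rows, in place, with no intermediate stack/split copies.
cache.batch_select_indices(augmented_idx)
print(cache.key_cache[0].shape)      # torch.Size([2, 1, 5, 4])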
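Most of the memory reduction in the decoding loops comes from two small changes: cloning the last-position logits instead of holding a view into the full `outputs.logits`, and deleting `outputs` before the next forward pass (and before the cache reordering in beam search). The sketch below only demonstrates the PyTorch behaviour that motivates this; the shapes are arbitrary and `untyped_storage()` assumes PyTorch 2.0 or later.

import torch

# Stand-in for outputs.logits on the first generation step:
# (batch, prompt_length, vocab_size) can easily be hundreds of MB.
logits = torch.randn(2, 512, 32000)

view = logits[:, -1, :]            # what the old code kept around
cloned = logits[:, -1, :].clone()  # what the patched code keeps

# The slice is a view sharing the full backing storage, so holding it (directly
# or through `outputs`) keeps ~131 MB alive; the clone owns only ~256 KB.
print(view.untyped_storage().nbytes())    # 131072000
print(cloned.untyped_storage().nbytes())  # 256000

# Once only the clone is referenced, the big tensor can actually be freed
# before the next forward pass, hence the explicit `del outputs` in the loops.
del logits, view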
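With this change, `generate()` initializes a `DynamicCache` by default for models that support it and converts the cache back to the legacy tuple format on return, unless the new `return_legacy_cache` flag is set to `False`. A rough usage sketch follows; the checkpoint name is a placeholder for any model whose `_supports_cache_class` is `True`, and the cache only appears on the output when `return_dict_in_generate=True` and `use_cache` is enabled.

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.cache_utils import DynamicCache

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("The memory savings come from", return_tensors="pt")

# Default: a DynamicCache is used internally but converted back to the legacy
# tuple-of-tuples format before being returned (return_legacy_cache=True).
out = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True)
print(type(out.past_key_values))                      # <class 'tuple'>

# Opt out of the conversion to keep working with the Cache object directly.
gen_config = GenerationConfig(max_new_tokens=5, return_legacy_cache=False)
out = model.generate(**inputs, generation_config=gen_config, return_dict_in_generate=True)
print(isinstance(out.past_key_values, DynamicCache))  # True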