Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate: Deprecate returning legacy cache by default; Handle use_cache=False #32863

Merged
merged 9 commits into from
Aug 22, 2024

Conversation

gante
Copy link
Member

@gante gante commented Aug 17, 2024

What does this PR do?

Another step towards using Cache everywhere 💪

This PR makes the following [Cache+generate]-related changes:

  1. Don't initialize a cache when use_cache=False (fixes Cache updating when use_cache = False #32843 )
  2. generate tests now explicitly pass use_cache, instead of setting it in model.config 🤢 We were relying on a LOT of side effects, and missing the incorrect case mentioned in Cache updating when use_cache = False #32843
  3. Add sanity-checks on cache-related parameters, in generation_config
  4. Add a deprecation cycle on the default cache return type, so we start returning a Cache instance by default on generate
  5. Isolate all cache initialization logic in generate into a single function, and reorganize the logic by blocks

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante gante changed the title Generate: Update cache initialization Generate: Deprecate returning legacy cache by default; Handle use_cache=False Aug 17, 2024
@@ -130,9 +130,29 @@ class GenerationConfig(PushToHubMixin):
[this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details.
penalty_alpha (`float`, *optional*):
The values balance the model confidence and the degeneration penalty in contrastive search decoding.
dola_layers (`str` or `List[int]`, *optional*):
Copy link
Member Author

@gante gante Aug 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved up to this documentation section (Parameters that control the generation strategy used), which makes more sense

`'high'` to improve short-answer tasks. Check the [documentation](https://github.com/huggingface/transformers/blob/main/docs/source/en/generation_strategies.md)
or [the paper](https://arxiv.org/abs/2309.03883) for more details.

> Parameters that control the cache
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new cache-related docs section in GenerationConfig, moved all cache-related flags here

@@ -544,8 +539,9 @@ def validate(self, is_init=False):
raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
if self.pad_token_id is not None and self.pad_token_id < 0:
warnings.warn(
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch generating, if there is padding. "
"Please set `pad_token_id` explicitly by `model.generation_config.pad_token_id=PAD_TOKEN_ID` to avoid errors in generation, and ensure your `input_ids` input does not have negative values."
f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(>120 chars/line)

@@ -675,6 +671,14 @@ def validate(self, is_init=False):
group_error_prefix
+ "`diversity_penalty` should be greater than `0.0`, otherwise your groups will be identical."
)
# DoLa generation
if self.dola_layers is not None and (self.repetition_penalty is None or self.repetition_penalty < 1.2):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(moved)

@@ -136,27 +136,23 @@ class GenerateDecoderOnlyOutput(ModelOutput):
sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
if all batches finished early due to the `eos_token_id`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our docs we often mention that there are two ways to parameterize generate (generation_config or pass arg to generate). I don't think we need to be verbose here.

Also, setting through config is deprecated 😉

Comment on lines +154 to +155
Returns the model cache, used to speed up decoding. Different models have a different cache format, check
the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
Copy link
Member Author

@gante gante Aug 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rewrote this one.

The old description was outdated (legacy cache), and we now know that different models have different caches, so we shouldn't be precise here. The model class docs can be more precise, let's redirect users there.

@@ -328,6 +312,7 @@ class GenerateBeamEncoderDecoderOutput(ModelOutput):
past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None


# TODO (joao): remove the equivalent classes and typing shortcuts below in v5
Copy link
Member Author

@gante gante Aug 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(These aliases made sense in the past, not anymore. They are, however, hard to deprecate!)

@@ -1497,6 +1482,127 @@ def _supports_default_dynamic_cache(self) -> bool:
"""
return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower()

def _prepare_cache_for_generation(
Copy link
Member Author

@gante gante Aug 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New function, moving the cache logic from generate. I've organized the logic in blocks, putting the cases where we DON'T prepare a new cache at the top.

It is doing essentially the same, except for the Quick escape route 2, which is new. Added the warning in Quick escape route 3.

if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
result.past_key_values = result.past_key_values.to_legacy_cache()
# Convert to legacy cache format if requested
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic is expanded to handle a deprecation cycle

@@ -194,6 +194,7 @@ def _greedy_generate(
output_attentions=False,
output_hidden_states=False,
return_dict_in_generate=False,
use_cache=True,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changes in this file: pass use_cache to generate, instead of relying on model.config.use_cache=False and its side-effects

added a check to confirm that the cache is None when we pass use_cache=False

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Let's make sure slow tests all pass as well here!

src/transformers/generation/configuration_utils.py Outdated Show resolved Hide resolved
src/transformers/generation/utils.py Outdated Show resolved Hide resolved
Comment on lines +1545 to +1550
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
# which is only supported in dynamic caches atm
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's create an issue and leave it up to the community in the mean time!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gante gante merged commit a26de15 into huggingface:main Aug 22, 2024
23 checks passed
@gante gante deleted the update_cache_kwargs branch August 22, 2024 19:01
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Cache updating when use_cache = False
3 participants