Generate: remove deprecated public decoding functions and streamline logic 🧼 #29956
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Tests on all models are passing 🙌 (the failing pipeline test seems unrelated, and passing locally on my end)
Yaaay, thanks for clean-up! Looks so much nicer
Force-pushed from e3c994f to db8cc74
@@ -65,25 +65,16 @@ class GenerationConfig(PushToHubMixin):
    Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
    for text-decoder, text-to-text, speech-to-text, and vision-to-text models:

-    - *greedy decoding* by calling [`~generation.GenerationMixin._greedy_search`] if `num_beams=1` and
functions not public -> no docs -> remove link to the docs
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
(see comment on L1434)
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
Instead of pulling information from `generation_config` to pass around through `model_kwargs`, let's use `generation_config` directly. A single object to hold all generation parameterization.
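To illustrate the suggestion, here is a minimal, hypothetical sketch (the class and function names are simplified stand-ins, not the actual transformers code): copying flags into `model_kwargs` and reading them from the config object directly are equivalent, but the latter keeps a single source of truth.

```python
from dataclasses import dataclass


# Hypothetical, simplified stand-in for transformers' GenerationConfig.
@dataclass
class GenerationConfig:
    output_attentions: bool = False
    output_hidden_states: bool = True


def prepare_kwargs_before(generation_config, model_kwargs):
    # Old pattern: copy flags out of the config and thread them through kwargs.
    model_kwargs["output_attentions"] = generation_config.output_attentions
    model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
    return model_kwargs


def read_flags_after(generation_config):
    # New pattern: read the flags from the single config object where needed.
    return {
        "output_attentions": generation_config.output_attentions,
        "output_hidden_states": generation_config.output_hidden_states,
    }


config = GenerationConfig()
assert prepare_kwargs_before(config, {}) == read_flags_after(config)
```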
src/transformers/generation/utils.py
Outdated
output_attentions=generation_config.output_attentions,
output_hidden_states=generation_config.output_hidden_states,
These were being passed through `model_kwargs` before.
top_k: Optional[int] = 1,
penalty_alpha: Optional[float] = 0,
logits_processor: Optional[LogitsProcessorList] = None,
logits_warper: Optional[LogitsProcessorList] = None,
This one was not needed -- contrastive search does not sample
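A small self-contained sketch (not transformers code) of why the argument was dead weight: warpers such as temperature scaling reshape the score distribution for *sampling*, but they preserve the ordering of scores, so a deterministic (greedy-style) pick like contrastive search's candidate selection is unaffected by them.

```python
def temperature_warp(scores, temperature):
    # Minimal temperature "warper": rescales scores before sampling.
    # For any positive temperature, the ordering of scores is preserved.
    return [s / temperature for s in scores]


def pick_deterministic(scores):
    # Deterministic selection, as in greedy/contrastive candidate picking.
    return max(range(len(scores)), key=lambda i: scores[i])


scores = [0.1, 2.5, -0.3, 1.7]
# A warper cannot change a deterministic pick -- which is why passing a
# logits_warper into contrastive search had no effect.
assert pick_deterministic(scores) == pick_deterministic(temperature_warp(scores, 0.7)) == 1
```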
src/transformers/generation/utils.py
Outdated
@@ -1674,6 +1680,9 @@ def generate(
    logits_processor=prepared_logits_processor,
    stopping_criteria=prepared_stopping_criteria,
    pad_token_id=generation_config.pad_token_id,
    eos_token_id=generation_config.eos_token_id,
On most decoding methods `eos_token_id` doesn't need to be passed -- it was only used when the decoding method was called directly and `stopping_criteria` was not passed. However, beam methods still need it.
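The idea can be sketched as follows (hypothetical, simplified names -- not the actual transformers implementation): `generate` folds `eos_token_id` into the stopping criteria once, so the inner decoding loop only ever consults `stopping_criteria` and never needs the raw id itself.

```python
def make_eos_stopping_criterion(eos_token_id):
    # `generate` builds this once from generation_config.eos_token_id.
    def criterion(generated_ids):
        return len(generated_ids) > 0 and generated_ids[-1] == eos_token_id
    return criterion


def decode_loop(next_tokens, stopping_criteria):
    # The inner decoding loop checks only the stopping criteria; it never
    # touches eos_token_id directly. (Beam methods are the exception: they
    # still need the id to finalize hypotheses.)
    generated = []
    for token in next_tokens:
        generated.append(token)
        if stopping_criteria(generated):
            break
    return generated


stop = make_eos_stopping_criterion(eos_token_id=2)
assert decode_loop([5, 7, 2, 9], stop) == [5, 7, 2]
```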
@@ -1945,69 +1940,9 @@ def _contrastive_search(
    [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
    `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
    `model.config.is_encoder_decoder=True`.

    Examples:
function not public -> let's remove the example (preparing its inputs will be more challenging now, as we no longer have API guarantees)
if eos_token_id is not None:
    if pad_token_id is None:
        raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
In the main `generate` body, we set `pad_token_id` to `eos_token_id` in this situation -- this exception will never be reached.
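A simplified sketch (hypothetical helper name) of the defaulting that makes the guard dead code: by the time any private decoding method runs, `pad_token_id` is defined whenever `eos_token_id` is, so the `ValueError` branch above can never fire.

```python
def resolve_pad_token(pad_token_id, eos_token_id):
    # Sketch of the fallback done once in `generate`: if the user set an
    # EOS token but no padding token, pad with the EOS token.
    if pad_token_id is None and eos_token_id is not None:
        pad_token_id = eos_token_id
    return pad_token_id


# pad is always defined when eos is, so the old per-method check is unreachable.
assert resolve_pad_token(pad_token_id=None, eos_token_id=2) == 2
assert resolve_pad_token(pad_token_id=0, eos_token_id=2) == 0
```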
ping @ArthurZucker -- ready for review :)
Late review but very nice cleanup sir! 🤗
src/transformers/generation/utils.py
Outdated
pad_token_id: Optional[int],
output_attentions: bool,
output_hidden_states: bool,
output_scores: bool,
output_logits: bool,
return_dict_in_generate: bool,
theoretically some of these can be taken from the config / the generation config if it inherits them. But nit
Yes! That's a good idea
Force-pushed from db8cc74 to a8b6e58
Reran slow tests locally, all seems good 👍
What does this PR do?
🧼 🧼 🧼
Calling the internal decoding functions as part of our public API was scheduled for removal in v4.41 (the next release). The removal was motivated by flexibility and conciseness: having multiple public interfaces for the same functionality forced us to add repeated logic in many places, and that logic grew every time we added a new decoding method.
Due to this removal from the public API, a few things were changed/removed as a logical consequence:
- the deprecation warnings and related legacy handling were removed from `generate`;
- the decoding functions are now private and only called from `generate`. As such, we can remove a lot of boilerplate (`x = x if x is not None else self.generation_config.x`).

Tests ran locally:
- `pytest --doctest-modules src/transformers/generation -vv`
- `RUN_SLOW=1 py.test tests/generation/test_utils.py -vv`
- `RUN_SLOW=1 py.test tests/test_cache_utils.py -vv` -- same failures as in `main`
- `RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py -vv`
- `RUN_SLOW=1 py.test tests/models/whisper/test_modeling_whisper.py -vv` -- same failures as in `main`
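The boilerplate pattern mentioned above can be sketched with hypothetical, simplified stand-ins (not the actual transformers classes): when a decoding method is publicly callable, every argument needs a `x = x if x is not None else self.generation_config.x` fallback; once it is only called from `generate`, the fully resolved config can be passed down instead.

```python
from dataclasses import dataclass


# Hypothetical stand-in for transformers' GenerationConfig.
@dataclass
class GenerationConfig:
    num_beams: int = 1
    top_k: int = 50


class Model:
    def __init__(self):
        self.generation_config = GenerationConfig()

    def _decode_before(self, num_beams=None, top_k=None):
        # Old pattern, repeated per argument in every decoding method:
        num_beams = num_beams if num_beams is not None else self.generation_config.num_beams
        top_k = top_k if top_k is not None else self.generation_config.top_k
        return num_beams, top_k

    def _decode_after(self, generation_config):
        # New pattern: `generate` resolves defaults once and passes the
        # fully populated config object down.
        return generation_config.num_beams, generation_config.top_k


model = Model()
assert model._decode_before() == model._decode_after(model.generation_config) == (1, 50)
```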