feat: implement the min_tokens sampling parameter #3124
Conversation
Signed-off-by: Travis Johnson <[email protected]>
This can be used to prevent the EOS token (and other stop tokens) from being generated by the model when using min_tokens. Signed-off-by: Travis Johnson <[email protected]>
@simon-mo I am unable to add you as a reviewer, but tagging you RE: #1945 (comment). I made the change to use a Logits Processor to penalize the tokens, but the changes are more than just the logits processor: the additional changes skip the stop sequences check if min_tokens have not yet been generated. Let me know if you had something else in mind for the Processor. I'm also looking for a good way to inject the MinNewTokensProcessor automatically when …
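For reference, a minimal sketch of the per-request logits-processor approach described above, assuming vLLM's logits-processor callback signature of (tokens generated so far, next-token logits). The class name MinNewTokensProcessor and its constructor arguments are illustrative, not the code in this PR.

```python
from typing import List

import torch


class MinNewTokensProcessor:
    """Mask EOS/stop token logits until `min_tokens` tokens have been generated."""

    def __init__(self, min_tokens: int, stop_token_ids: List[int]):
        self.min_tokens = min_tokens
        self.stop_token_ids = stop_token_ids

    def __call__(self, token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        # `token_ids` holds the tokens generated so far for this sequence.
        if len(token_ids) < self.min_tokens:
            logits[self.stop_token_ids] = -float("inf")
        return logits
```

Such a processor would have to be passed per request (e.g. via SamplingParams(logits_processors=[...])), which is the extra plumbing the discussion below is about avoiding.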
@tjohnson31415 IMHO, since an explicit parameter is being introduced for this, it would be best for the eos token suppression to be tied to that without having to additionally pass a LogitsProcessor. There is already an … An additional advantage is that it would then be possible to have a vectorized implementation (not necessarily in this first PR). One question in this case is whether other provided stop tokens (if any) should be suppressed in addition to eos; I'd lean towards yes, but it would be good to get input from others on that too.
Signed-off-by: Travis Johnson <[email protected]>
@@ -139,6 +142,35 @@ def _get_bin_counts_and_mask(
    return bin_counts, mask


def _apply_min_tokens_penalty(
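A simplified, standalone sketch of what a function like `_apply_min_tokens_penalty` does, per the discussion below: build (row, stop-token-id) coordinates for sequences that have not yet produced min_tokens tokens, then set those logits to -inf. The real function reads this state from vLLM's SamplingMetadata; the plain-tuple input here is illustrative only.

```python
from typing import List, Tuple

import torch


def apply_min_tokens_penalty(
    logits: torch.Tensor,  # [num_seqs, vocab_size]
    seq_groups: List[Tuple[List[int], int, List[int]]],
    # each entry: (generated-token counts per seq, min_tokens, stop token ids)
) -> torch.Tensor:
    coords: List[Tuple[int, int]] = []  # (row, token_id) pairs to set to -inf
    row = 0
    for generated_counts, min_tokens, stop_token_ids in seq_groups:
        for num_generated in generated_counts:
            if num_generated < min_tokens:
                coords.extend((row, tok) for tok in stop_token_ids)
            row += 1
    if coords:
        rows, cols = zip(*coords)
        logits[list(rows), list(cols)] = -float("inf")
    return logits
```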
This looks great @tjohnson31415. But I think technically the token_ids_to_penalize should be determined per seq_group (i.e. also within the loop) since they may be different per seq group. The indexing gets a bit trickier, but I think it might be possible with scatter_ with src=-torch.inf. Or else we could group the sequences that share the same list of tokens to penalize.
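For illustration only, the kind of scatter_ call being suggested: write -inf into chosen column indices for each row. Note that scatter_ takes a rectangular index tensor (the same number of indices per row), which is part of what makes it awkward when seq groups have different stop-token lists, as discussed in the reply below.

```python
import torch

logits = torch.zeros(3, 10)
# One row of token-id indices per sequence; every row must list the same
# number of indices, so this only works directly when the penalized token
# sets line up across sequences.
index = torch.tensor([[0, 9], [1, 9], [2, 9]])
logits.scatter_(1, index, float("-inf"))  # in-place: logits[i, index[i, j]] = -inf
```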
Heh, yup. Thanks for pointing that out. I still need to write some tests for this 😅. I pushed a fix to build a list of coordinates to penalize within the loop so the stop ids are per seq_group.
I was trying to use scatter initially, but couldn't figure out how to get it to work. In particular, scatter uses a rectangular tensor and doesn't seem to have a way to "skip" rows where we don't want to scatter into. So I think a gather-modify-scatter (where we gather across all sequences and stop token ids) would work, but we'd still need to index into the gather'd tensor to set the -inf values.
Signed-off-by: Travis Johnson <[email protected]>
Thanks @tjohnson31415, LGTM!
We should add a test for this too, though perhaps let's wait for confirmation from the maintainers that this would be accepted.
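A hedged sketch of the sort of end-to-end test being discussed, assuming this PR's min_tokens parameter is available on SamplingParams; the model choice and thresholds are illustrative.

```python
from vllm import LLM, SamplingParams


def test_min_tokens_generates_at_least_min_tokens():
    llm = LLM(model="facebook/opt-125m")  # small model, illustrative choice
    params = SamplingParams(min_tokens=16, max_tokens=32, temperature=0.0)
    output = llm.generate(["Hello, my name is"], params)[0].outputs[0]
    # With the EOS/stop penalty applied, at least min_tokens tokens are produced.
    assert len(output.token_ids) >= 16
```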
Co-authored-by: Nick Hill <[email protected]>
Thank you for the contribution, and thanks to Nick for the review. My original intention is that … There are some readability issues with …
Signed-off-by: Travis Johnson <[email protected]>
* upstream/main:
  [Misc] Bump up transformers to v4.39.0 & Remove StarCoder2Config (vllm-project#3551)
  [Misc][Log] Add log for tokenizer length not equal to vocabulary size (vllm-project#3500)
  [🚀 Ready to be merged] Added support for Jais models (vllm-project#3183)
  Fix 1D query issue from `_prune_hidden_states` (vllm-project#3539)
  [PREFIX CACHING FOLLOW UP] OrderedDict-based evictor (vllm-project#3431)
  [BugFix] Hot fix in setup.py for neuron build (vllm-project#3537)
  Migrate `logits` computation and gather to `model_runner` (vllm-project#3233)
  [1/n][Chunked Prefill] Refactor input query shapes (vllm-project#3236)
  [1/n] Triton sampling kernel (vllm-project#3186)
  [Bugfix] Fix ROCm support in CMakeLists.txt (vllm-project#3534)
@simon-mo this should be ready now!
Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Travis Johnson <[email protected]> Co-authored-by: Nick Hill <[email protected]>
Signed-off-by: Travis Johnson <[email protected]> Co-authored-by: Nick Hill <[email protected]>
Adds the min_tokens sampling parameter to ensure a minimum number of generated tokens.

The implementation here is meant to align with https://github.com/IBM/text-generation-inference (IBM's fork of HF TGI). In particular, we want to ignore stop sequences and penalize the EOS token until min_tokens have been generated. Stop sequences can be generated within the first min_tokens tokens, but generation will not terminate. stop_token_ids are treated like the EOS token and penalized so that they are not generated until min_tokens tokens have been generated.

Related PR that stalled: #1945
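A short usage sketch of the behavior described above, assuming this PR's min_tokens is available on SamplingParams; the model and the stop values are illustrative.

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
params = SamplingParams(
    min_tokens=32,           # no EOS/stop-driven termination before 32 new tokens
    max_tokens=128,
    stop=["\n\n"],           # stop strings are not checked until min_tokens is reached
    stop_token_ids=[2],      # illustrative id; treated like EOS and penalized to -inf
)
print(llm.generate(["Write a short story about a robot:"], params)[0].outputs[0].text)
```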