Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement the min_tokens sampling parameter #3124

Merged
merged 14 commits into from
Mar 25, 2024

Conversation

tjohnson31415
Copy link
Contributor

@tjohnson31415 tjohnson31415 commented Feb 29, 2024

Adds the min_tokens sampling parameter to ensure a minimum number of generated tokens.

The implementation here is meant to align with https://github.com/IBM/text-generation-inference (IBM's fork of HF TGI). In particular, we want to ignore stop sequences and penalize the EOS token until min_tokens have been generated. Stop sequence can be generated within min_tokens tokens but generation will not terminate. stop_token_ids are treated like the EOS token and penalized so that they are not generated until min_tokens tokens have been generated.

Related PR that stalled: #1945

This can be used to prevent the EOS token (and other stop tokens) from being generated by the model when using min_tokens.

Signed-off-by: Travis Johnson <[email protected]>
@tjohnson31415
Copy link
Contributor Author

tjohnson31415 commented Feb 29, 2024

@simon-mo I am unable to add you as a reviewer but tagging you RE: #1945 (comment).

I made the change to use a Logits Processor to penalize the tokens, but the changes are more than just the logits processor. The additional changes are to skip the stop sequences check if min_tokens have not yet been generated. Let me know if you had something else in mind for the Processor. I'm also looking for a good way to inject the MinNewTokensProcessor automatically when min_tokens is specified in the sampling params.

@simon-mo simon-mo self-assigned this Feb 29, 2024
@njhill
Copy link
Member

njhill commented Feb 29, 2024

@tjohnson31415 IMHO since an explicit parameter is being introduced for this, it would be best for the eos token suppression to be tied to that without having to additionally pass a LogitsProcessor. There is already an ignore_eos parameter which is similar but I think is practically only useful for performance testing; I think min_tokens wouldn't have much utility unless used in conjunction with the MinTokensProcessor anyhow.

An additional advantage is that it would then be possible to have a vectorized implementation (not necessarily in this first PR).

One question in this case is whether other provided stop tokens (if any) should be suppressed in addition to eos, I'd lean towards yes but would be good to get input from others on that too.

@@ -139,6 +142,35 @@ def _get_bin_counts_and_mask(
return bin_counts, mask


def _apply_min_tokens_penalty(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great @tjohnson31415. But I think technically the token_ids_to_penalize should be determined per seq_group (i.e. also within the loop) since they may be different per seq group. The indexing gets a bit tricker but I think it might be possible with scatter_ with src=-torch.inf. Or else could group the sequences that share the same list of tokens to pernalize.

Copy link
Contributor Author

@tjohnson31415 tjohnson31415 Mar 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heh, yup. Thanks for pointing that out. I still need to write some tests for this 😅. I pushed a fix to build a list of coordinates to penalize within the loop so the stop ids are per seq_group.

I was trying to use scatter initially, but couldn't figure out how to get it to work. In particular, scatter uses a rectangular tensor and doesn't seem to have a way to "skip" rows where we don't want to scatter into. So I think a gather-modify-scatter (where we gather across all sequences and stop token ids) would work, but we'd still need to index into the gather'd tensor to set the -inf values.

Copy link
Member

@njhill njhill left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tjohnson31415, LGTM!

We should add a test for this too, though perhaps lets wait for confirmation from the maintainers that this would be accepted.

vllm/model_executor/layers/sampler.py Outdated Show resolved Hide resolved
vllm/model_executor/logits_processors.py Outdated Show resolved Hide resolved
@simon-mo
Copy link
Collaborator

Thank you for the contribution, and thank Nick for the review.

My original intention is that min_tokens can be implemented using a built-in logits processor so the interface is cleaner. But current approach is fine as well.

There are some readability issue with _apply_min_tokens_penalty, please redo the list comprehension to make future devs easier to understand. And please add this to OpenAI compatible server as well (see protocols.py in the entrypoints/openai` directory).

@tjohnson31415 tjohnson31415 marked this pull request as ready for review March 20, 2024 21:22
* upstream/main:
  [Misc] Bump up transformers to v4.39.0 & Remove StarCoder2Config (vllm-project#3551)
  [Misc][Log] Add log for tokenizer length not equal to vocabulary size (vllm-project#3500)
  [🚀 Ready to be merged] Added support for Jais models (vllm-project#3183)
  Fix 1D query issue from `_prune_hidden_states` (vllm-project#3539)
  [PREFIX CACHING FOLLOW UP] OrderedDict-based evictor (vllm-project#3431)
  [BugFix] Hot fix in setup.py for neuron build (vllm-project#3537)
  Migrate `logits` computation and gather to `model_runner` (vllm-project#3233)
  [1/n][Chunked Prefill] Refactor input query shapes (vllm-project#3236)
  [1/n] Triton sampling kernel (vllm-project#3186)
  [Bugfix] Fix ROCm support in CMakeLists.txt (vllm-project#3534)
@njhill
Copy link
Member

njhill commented Mar 21, 2024

@simon-mo this should be ready now!

vllm/engine/llm_engine.py Outdated Show resolved Hide resolved
@simon-mo simon-mo merged commit c13ad1b into vllm-project:main Mar 25, 2024
32 checks passed
@tjohnson31415 tjohnson31415 deleted the min-new-tokens branch March 25, 2024 17:16
@njhill njhill mentioned this pull request Mar 25, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Mar 31, 2024
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants