
Added logits processor API to sampling params #1469

Merged (6 commits) on Nov 3, 2023

Conversation

@noamgat (Contributor) commented Oct 25, 2023

This PR adds a new optional parameter logits_processors to SamplingParams.

The idea (which exists in huggingface transformers, llama.cpp, and other inference engines) is to let custom code modify the logit scores after the model produces them and before tokens are sampled from them.

This opens integration possibilities with a lot of solutions, such as LM Format Enforcer (my library), Guidance, JsonFormer and Outlines.

For example, this allows limiting vLLM to only generate outputs that conform to a specific JSON Schema or regular expression.
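
Below is a minimal sketch (not code from this PR) of how the parameter is used; the model name and the banned token ids are placeholders. Each processor is a callable that receives the token ids generated so far for a sequence together with the logits for the next token, and returns the (possibly modified) logits.

```python
# Minimal usage sketch of logits_processors; model name and token ids are placeholders.
from typing import List

import torch
from vllm import LLM, SamplingParams

BANNED_TOKEN_IDS = [42, 1337]  # hypothetical token ids to suppress

def ban_tokens(previous_token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
    # Setting a logit to -inf makes the corresponding token impossible to sample.
    logits[BANNED_TOKEN_IDS] = float("-inf")
    return logits

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(max_tokens=32, logits_processors=[ban_tokens])
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```

Masking logits to -inf like this is the basic building block that constrained-decoding libraries (JSON Schema, regex, grammars) build on.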

The design principles behind this PR were to have as little impact as possible on vLLM itself (so no extra dependencies, no runtime penalty if the option is not used) and to be as consistent as possible with other inference engines.

LM Format Enforcer already has an example notebook on integrating with vLLM that shows the benefits; however, it relies on monkey patching due to the lack of an API, which makes it less robust for production use.

I added a test to check the integration, and format.sh did not add any new remarks.

@noamgat (Contributor, Author) commented Oct 31, 2023

This seems to be a similar PR:
#535
(But this one is ready to merge because it was written against the most up-to-date vLLM.)

@c3-adam commented Oct 31, 2023

My team and I really really want this!!

@flexorRegev commented

Benchmarked the RegexParser: works great. Over 1000 examples in offline inference it is about 40% slower than running without it (this will of course vary with the type of regex you are enforcing).
This is super useful for a lot of things, and the integration is very simple.
Great job @noamgat!

The next phase will be acceleration for the cases where the logits step is simply a selection, as Guidance did in the past.

@simon-mo (Collaborator) commented Nov 1, 2023

Hi @noamgat, thank you for this amazing contribution. The team (+@zhuohan123 @WoosukKwon @LiuXiaoxuanPKU) discussed this PR a bit, and we think it's very promising. Can you help address the following:

  • Zero-cost abstraction: how can we ensure that inference is not slowed down when there are no logits processors? Can we skip the entire for loop when no processor is present?
  • Potential for batching: is it possible to make the logits processor accept a batch? In particular, we are interested in reducing the performance penalty in the common case of a single logits processor shared by all requests.
  • Documentation: can you add documentation and examples for this feature? Stressing the performance penalty would also help guide our users.
  • Error handling: currently, if no token is available, the request fails with an assertion error. Is it possible to fail more gracefully with a custom error, somehow allowing other requests to continue?

Thank you again for this contribution. We are really looking forward to bringing this to vLLM.

@noamgat (Contributor, Author) commented Nov 1, 2023

> (Quoting @simon-mo's questions above.)

Thanks for the feedback!

Replies here:

  • Zero-cost abstraction: I believe this is already the case. _apply_logits_processors() only loops over the sequence groups and checks whether a processor exists. No buffers are copied if none exist; the modifications happen in place.
  • Batching: the contract I went with mimics the design decisions of llama.cpp and huggingface transformers: the API allows the logits processing to depend on the tokens that were generated in previous steps, so each sequence can receive different processing (see the sketch after this list). If the caller wants to do something simpler (for example, disabling a single token), they can, and the performance won't be very different.
  • Documentation: there is no inherent performance penalty. The 40% slowdown that @flexorRegev mentioned comes from the processing time inside LM Format Enforcer's logits processor, not from the vLLM pipeline integration. After this PR is approved, I will update my library (LM Format Enforcer) to use the new API and submit integration examples to the vLLM documentation, so users can use this API to generate outputs that conform to a JSON Schema or regular expression.
  • Error handling: this is how other inference engines behave as well. Is there a way to fail only one (or some) of the requests in the minibatch?
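
To make the per-sequence contract concrete, here is a small sketch (not vLLM internals); EOS_TOKEN_ID and MAX_NEW_TOKENS are placeholder values. Because the processor receives the token ids already generated for its sequence, it can make stateful decisions, such as forcing EOS once a length budget is exhausted:

```python
# Sketch of a stateful per-sequence processor; constants are placeholders.
from typing import List

import torch

EOS_TOKEN_ID = 2
MAX_NEW_TOKENS = 16

def force_eos_after_budget(previous_token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
    if len(previous_token_ids) >= MAX_NEW_TOKENS:
        # Once the budget is exhausted, keep only EOS sampleable.
        mask = torch.full_like(logits, float("-inf"))
        mask[EOS_TOKEN_ID] = 0.0
        return logits + mask
    return logits
```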

(Review comment on vllm/model_executor/layers/sampler.py: outdated, resolved)
@simon-mo (Collaborator) commented Nov 3, 2023

Thank you for the response. We will accept the PR pending one question above.

@noamgat (Contributor, Author) commented Nov 3, 2023

Nice catch! It was a bug I introduced during the PR process while making a small modification to reduce the footprint when no logits processors are present. Confirmed and updated.

@simon-mo merged commit 555bdcc into vllm-project:main on Nov 3, 2023 (2 checks passed)
@Cppowboy commented Nov 9, 2023

When will the logits processor feature be released?

@veltz1 commented Dec 13, 2023

Is it possible to use this PR to implement more complex methods such as contrastive decoding?

@mmoskal (Contributor) commented May 11, 2024

Does anyone know how the logits processor functions are passed to the other workers when using Ray? I understand that the "driver" worker where the sampling happens is in fact another thread within the main vLLM process, so there is probably no problem there. However, because SamplingParams are passed to all workers (as part of SequenceGroupMetadata), would Ray end up copying lots of data if the processor references it, passing it around without ever using it?

(The case of copying logits processors locally was also addressed in #3099, but I don't think that applies to Ray.)
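
For illustration only (this is not vLLM or Ray internals, and MaskedProcessor is a hypothetical class): a processor object that carries large state makes any by-value serialization of the SamplingParams holding it correspondingly large, which is the cost the question above is about.

```python
# Self-contained illustration of the serialization concern; not library code.
import pickle

import torch

class MaskedProcessor:
    def __init__(self, allowed_token_ids):
        # Potentially large precomputed structure carried by the processor.
        self.allowed_token_ids = list(allowed_token_ids)

    def __call__(self, previous_token_ids, logits):
        mask = torch.full_like(logits, float("-inf"))
        mask[self.allowed_token_ids] = 0.0
        return logits + mask

proc = MaskedProcessor(range(100_000))
payload = pickle.dumps(proc)
# The payload grows with the captured state, which is what would be copied
# if the processor were shipped to remote workers by value.
print(f"pickled processor size: {len(payload)} bytes")
```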
