[RFC] Drop beam search support #6226
Comments
We should superficially disable beam search in the next vLLM release (e.g. assert False, "beam search is deprecated, see #6226") and see the reaction. If there is a lot of noise, then we should consider taking a path that maintains compatibility. |
Beam search gives consistent results and is used in production-level systems where predictable results are important, so dropping beam search would be a bad idea IMHO. Setting temperature=0 provides somewhat predictable results, but not always. |
The MLPerf inference benchmark requires the beam search feature, so I think this is still useful in the industry. Here's the link to the MLPerf inference rules: thanks, -yuan |
Regarding MLPerf Inference @zhouyuan , it is only needed for the GPT-J benchmark (which was the first LLM task they added) and is not used for Llama 2 70B or Mixtral 8x7B (which are more recent). I don't believe beam search will be used in future tasks since it is generally not practical for cost-effective deployment. |
As an alternative that still serves customers who want similar features, I would like to propose a new param for our current vLLM system; let's call it num_q (number of queries, i.e. the former num_beams). Say we set num_q=5: it behaves similarly to best_of or n, but instead it takes the top 5 tokens for the first generated token and continues generation from each of them. Request:
Response:
Customers are guaranteed to get 5 different responses along with their logprobs, and in the meantime they can still conduct beam search themselves by choosing the sequence with the best logprobs. Doing this introduces far less complication into vLLM, and it also gives users the freedom to decide which sequence they want to use.
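A minimal sketch of this fan-out idea, written with Hugging Face transformers rather than vLLM internals (the function name, the gpt2 default, and the greedy continuation of each branch are illustrative assumptions, not part of the proposal):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def num_q_generate(prompt, num_q=5, max_new_tokens=64, model_name="gpt2"):
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    input_ids = tok(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        first_logits = model(input_ids).logits[0, -1]                 # distribution over the first new token
    top = torch.topk(torch.log_softmax(first_logits, dim=-1), num_q)  # top num_q candidate first tokens
    branches = []
    for logprob, token_id in zip(top.values, top.indices):
        branch = torch.cat([input_ids, token_id.view(1, 1)], dim=-1)  # prompt + one distinct first token
        out = model.generate(branch, max_new_tokens=max_new_tokens - 1, do_sample=False,
                             output_scores=True, return_dict_in_generate=True)
        scores = model.compute_transition_scores(out.sequences, out.scores, normalize_logits=True)
        text = tok.decode(out.sequences[0][input_ids.shape[1]:], skip_special_tokens=True)
        branches.append((text, logprob.item() + scores.sum().item())) # completion and its cumulative logprob
    return sorted(branches, key=lambda b: b[1], reverse=True)         # caller picks by logprob
```

The caller gets num_q distinct completions ranked by cumulative logprob and can keep only the best one, reproducing a rough beam-search-like selection outside the engine.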
|
@cadedaniel Thanks for the suggestion! Here's what we've decided to do:
|
+1, our teams observe benefits for reliability and occasionally even latency from beam search, highly relevant in Prod |
Yes. The most commonly used now is top-p and top-k sampling. |
I kindly suggest maintaining beam search support, as it is the primary option for translation tasks, even with LLMs. |
@nightflight-dk Thanks for your input! Are you using vLLM in production? If so, we'd be happy to discuss our plan with you. |
A potential use-case we have is that sometimes using … Adding the fact that the model should choose from … |
I think the typical use case for taking multiple samples is when you have a method for "trying" a sample. Perhaps the first sample "fails", and then you want to try the second sample, etc. (Our specific use case is formal proof search.) Beam search is well suited for this application, because the beams provide diversity. With random sampling I could end up retrying the same "almost surely good" idea over and over, instead of continuing to the second idea. It's true that beams ranking lower are likely bad. But trying a bad idea still beats trying the same good idea twice. That said, I'm a fan of simpler code. If random sampling is much faster than beam search, we can just deduplicate the samples or something. I will run some experiments to measure how this will affect us. |
We have noticed that token level logprobs from beam search are quite informational compared to those from nucleus sampling. A lot of our workflows depend on these logprobs and I'd suggest keeping beam search support as well! |
We heavily depend on beam search at Heavy.ai in VLLM in production with customers to give optimal accuracy for text-to-SQL tasks (https://www.heavy.ai/heavyiq/overview), and would lose significant accuracy with it turned off. Perhaps we could implement it ourselves using the log probabilities (would be nervous about the performance though) or freeze our version to 0.5.2, but neither is ideal at all. We are also looking at various sampled approaches using a judge model to pick the best, and here again taking the top-n beam search generations provides better accuracy than setting a non-zero temperature and taking n samples. From the above I understand the motives but I'd request that this be reconsidered. It's not just us either, pretty much all the SOTA text-to-SQL approaches use beam search to get best accuracy. |
Beam search is a deal breaker for our use case. We use it extensively in prod. We have found that it increases the accuracy of our LLM's responses by roughly 1%, which is absolutely critical for our use case. Unfortunately if vLLM stops supporting beam search we'll have to switch to an unoptimized inference engine. |
We are considering using beam search as it actually improves performance, and we are reviewing its use at the production level. Dropping it might alone make us reconsider using vLLM. The speed and complexity of the implementation can be seen as a trade-off for better performance and the ability to infer the model's choice paths. Must it really be deleted? We do not want that. |
We, at Spotify, use vLLM beam search to sample multiple answers from the same prompt in multiple textual tasks. This breaking change would hurt us significantly and we may have to reconsider vllm usage for some of our use cases, if there are no alternatives :( please, reconsider it Feel free to DM me |
We are very much relying on beam search for our biomedical industry applications to significantly boost performance in our setting. That benefit is large enough to consider alternative projects for serving, but we would hate to have to abandon vllm :( |
We are using beam search in production and would appreciate its continued support |
For production use cases, please also indicate why you chose beam search and not the other sampling methods. Many public API services do not provide beam search; what would you do if you didn't have beam search (i.e. any workaround)? A possible workaround: LLMs are very smart at present, so if you just want output diversity, how about adding a system prompt instructing the model to produce more diverse output? |
As a user of guidance/AICI/other methods of constraining LLM output, I find that disabling beam search can reduce the quality of outputs, for the reasons users describe above. We've noticed that across a wide array of models, these two facts interact poorly:
For vLLM with open source models, beam search helps overcome this obstacle, in effect giving the model a weak form of backtracking. With LLM APIs, we maintain a list of tokens to which we add a small negative weight; however, this list is not exhaustive and, of course, we need to derive the token IDs for each unique tokenizer. In my experience, beam search works better than negative weighting these tokens, and is more straightforward and adaptable to multiple models. This is a sample of our "verboten tokens" file:

```json
[
  "\",",
  "],",
  "[\"",
  "[]",
  ",\"",
  "\"]",
  "][",
  "},",
  "\",\"",
  "{{",
  "\"\"",
  "}}",
  "{\"",
  "]]"
]
```
@AaronFriel have you tried the guided decoding at https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api ? |
I'm familiar with Guidance, yes, I mentioned it in my reply. |
My understanding is that guided decoding in particular benefits from beam search, for reasons alluded to here and here, i.e. you can get into nasty situations with guided decoding where the probabilities of earlier tokens can be skewed by the probabilities of later tokens, even if some of those combinations are disallowed by the guided choice/regex/json. |
We also use beam search in a production deployment of vllm and would probably have to migrate off of vllm without it. We're optimizing for accuracy on a structured task, not diversity of output, and have found that beam search produces the best results. |
My HEAVY.AI colleagues have already commented, but to add a little detail ... we use beam search so that we can performantly constrain things to our specific SQL syntax. We've found it to be faster than alternatives and are using it in production across multiple accounts. |
@hrsmanian @zhouyuan @lanking520 @nightflight-dk @HeegonJin @SemMulder @darabos @DhruvaBansal00 @tmostak @physicsrob @YooSungHyun @denadai2 @sjmielke @Reichenbachian @AaronFriel @hinnefe2 @mflaxman10 Due to strong pushback from the community, we have decided to reconsider this proposal. vLLM will continue to support beam search until further notice. We will try to find other ways to optimize the system overheads and get back to this proposal after more exploration. Thanks everyone for the feedback! |
@WoosukKwon thank you. Super appreciated! To give more context, Spotify, among other use cases, needs to have long exact inference (e.g. for recommendations). Thus, beam search is great for this :) |
An update on this thread: For users who need beam search, I'd like to know how sensitive you are w.r.t. the latency and throughput of the inference. Per my understanding, beam search is quite slow in terms of both latency and throughput. If you use beam search, I assume you are not very sensitive to the speed, but just want the quality of generation from beam search. Why do I ask this? Because I'd like to move the beam search logic one level above the current vLLM. Say we have an inference engine that exposes an OpenAI API server; it seems we can emulate an API server with beam search by asking the underlying server to produce one token at a time, with multiple logprobs:

```python
def beam_search_proxy(sequence, beam_width, max_tokens):
    candidates = [sequence]
    finished = []
    while candidates:
        new_candidates = []
        for seq in candidates:
            # Ask the API server for a single token, with beam_width logprobs.
            for token, logprob in generate(seq, max_tokens=1, logprobs=beam_width):
                new_candidates.append(new_seq(seq, token, logprob))
        # Move completed sequences out of the active beam; is_finished() is assumed
        # to account for EOS and the max_tokens budget.
        finished += [x for x in new_candidates if x.is_finished()]
        new_candidates = [x for x in new_candidates if not x.is_finished()]
        # Keep the beam_width best partial sequences by cumulative logprob.
        new_candidates.sort(key=lambda x: x.cumulative_logprobs, reverse=True)
        candidates = new_candidates[:beam_width]
    finished.sort(key=lambda x: x.cumulative_logprobs, reverse=True)
    return finished[:beam_width]
```

The sharing of memory and computation among sequences can be achieved via prefix caching. Disclaimer: I'm not familiar with beam search, and the semantics of the above function can be wrong. Please just read it for the idea: emulate beam search with a normal OpenAI API server. If we can go in this direction, the outcome would be:
|
We are very sensitive to throughput, but not latency. We need the highest possible throughput with beam search. If there's a substantial drop in overall compute efficiency, or drop of beam search support, we would migrate our inference elsewhere (or possibly fork, although TBH we don't want to be in the business of optimizing inference.) For what it's worth, I think it's unlikely that moving to a higher level abstraction would work without a substantial drop in throughput. My weak evidence for this: #1646 We currently monkeypatch our VLLM in production to make the fork operation performant. I honestly hate that we do this, but the cost implications of not doing it are unacceptable. |
@physicsrob can you elaborate on that? |
The point is, very few developers understand beam search, and many new features directly hard fail when beam search is used; see, for example, vllm/vllm/spec_decode/batch_expansion.py (lines 303 to 305 in 6e36f4f) and vllm/vllm/engine/output_processor/multi_step.py (lines 100 to 103 in 6e36f4f).
I'm pretty sure this will happen more often in the future. If we keep beam search in vLLM, even if the performance is untouched, you will find more and more bugs related to beam search. By separating beam search and vLLM, both of them can be optimized separately. And it is even possible that, in the end, beam search on top of the new vLLM turns out better than the current beam search inside vLLM. |
Same here. But we find throughput very poor already. 2,500 t/s with n=1, 100 t/s with 10 beams. (I was always wondering if we're doing something wrong. 😅)
I'm not qualified to judge, but I like the idea of moving beam search to a higher layer. I can imagine it may make it easier to do batching for the beams. E.g. in your example:
Perhaps this could be replaced with a batch completion:
So we only make one inference call per token, which covers all beams at once. |
we should definitely add batching for the beams. please take the code snippet as just a demonstration of the idea lol.
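For illustration, a rough sketch of what one batched beam-expansion step could look like (generate_batch and the candidate methods are hypothetical helpers, not an existing vLLM or OpenAI API):

```python
def expand_beams_batched(candidates, beam_width):
    # One beam-search step: a single batched call returns, for every candidate
    # prefix, its top-beam_width (token, logprob) continuations.
    results = generate_batch([c.text for c in candidates], max_tokens=1, logprobs=beam_width)
    new_candidates = []
    for cand, top_k in zip(candidates, results):
        for token, logprob in top_k:
            new_candidates.append(cand.extend(token, logprob))
    # Keep only the best beam_width continuations across all beams.
    new_candidates.sort(key=lambda c: c.cumulative_logprobs, reverse=True)
    return new_candidates[:beam_width]
```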
this is possible, because beam search is a very complicated search algorithm. in normal decoding, you can stream back every token you generate. However, in beam search, the tokens you get might be discarded later. In fact, many tokens will be decoded and then discarded.
thank you for your support! |
Another update on this thread: For people who use beam search, what are the common sampling parameters? Besides basic beam search, do you need to compose beam search with the other features, e.g. beam search + guided decoding? Beam search with temperature? Beam search with presence_penalty/frequency_penalty/repetition_penalty/length_penalty? Beam search with logprobs? It is hard for me to imagine what beam search + guided decoding means, and what the comparison criterion is for beam search with presence_penalty/frequency_penalty/repetition_penalty/length_penalty (i.e. is the penalty included when judging the quality of candidates?). Basically, because beam search is a search algorithm, it usually conflicts with all the other sampling algorithms. And as I mentioned before, many features in vLLM already directly assert that beam search is not used. Please provide further feedback on your specific use cases of beam search. If throughput is the only concern with moving beam search one level above the vLLM core, I'm pretty sure we should be able to optimize the speed to be as fast as the current vLLM implementation. |
For me, it's just beam search with up to 100 beams. No other sampling features used. Temperature=0. Currently we rely on getting back logprobs for the generated samples. We use these to get an overall "sentence likelihood" which is then used for comparisons across different generations. (It's a best-first search.) There are some conceptual issues with this and we plan to switch to a better scoring function. |
@darabos thanks for the response! followup questions: do you use the openai api server for beam search, or do you use the LLM class?
when you create one vLLM instance to use beam search, do you need a different beam search width for different prompts, or do they all have the same beam search width? |
Oh my, this feature will be personalized for my needs! 😄
The LLM class. Here's an excerpt from our code that hopefully includes the bits you're looking for. The model is often a DeepSeek-Coder 1.3b fine-tune.

```python
def __init__(self):
    self.model = vllm.LLM(model=model_name, gpu_memory_utilization=0.5, max_model_len=10240)

def candidates(self, ...):
    sp = vllm.SamplingParams(temperature=0, n=num_beams, use_beam_search=True, max_tokens=100, early_stopping=True)
    outputs = self.model.generate(prompts, sp, use_tqdm=False)
```

I don't think there is a lot of thought behind how we set these parameters.
The same. The way our code works, and I think this may be a typical use of beam search, is that we want to try the best generation, then the second best, etc. Generating 16 samples is just a compromise. Often we won't use all 16, other times we would need more than 16. The ideal for us would be if we could pull N samples one by one, without guessing N ahead of time. I know this is not on the table with beam search. |
@darabos thanks! your explanation helps a lot |
Hi, sorry for the late response but I was in parental leave. I cannot go into details ATM but I want to give further details about our use case. However, we use beam search to predict some catalog codes for recommendation purposes. We do this for offline inference but we are considering doing online inference as well. Why beam search? Because each catalog item corresponds to a sequence of codes and our model has to predict existing sequences. Early results with top-k sampling are significantly worse than beam search. We usually have ~ 30 or 50 beams and sequences that are between 3 and 15 long. I know, speed is crucial but this task's quite intense :( |
to all: I added a feature request for a more powerful beam search (as it was in the old vllm) here #10754 |
TL;DR: To reduce system complexity and enable future optimizations, we propose discontinuing beam search support.
Due to strong pushback from the community, we have decided to reconsider this proposal. vLLM will continue to support beam search until further notice. Thanks everyone for the feedback!
Motivation.
Currently, vLLM supports 3 types of sampling: greedy, random, and beam search. Beam search, which dynamically creates and removes top-k branches at each step, is the most complex of the three. Traditionally, beam search has been popular for NLP tasks like translation and summarization. However, in the LLM era, beam search has become less common. Major LLM APIs such as GPT, Gemini, and Claude do not support it.
In vLLM, beam search initially motivated the idea of PagedAttention. In fact, vLLM excels at beam search compared to other inference engines, since PagedAttention can efficiently handle the dynamic nature of beam search and minimize its KV cache usage. Despite this, implementing beam search introduces significant system complexity and hinders potential optimizations: it complicates the system while being rarely used.
To resolve this, we propose eliminating beam search support, which will provide the following benefits:
Reduced Complexity in Sampling and Output Processing
More Predictable Block Table
Potential Future Removal of SequenceGroup
Proposed Change.
We plan to execute this in 3 steps:
We are open to reintroducing beam search if there is strong demand from the community. Please share any concerns regarding this decision. We apologize for any inconvenience caused by this change.
Feedback Period.
No response
CC List.
No response
Any Other Things.
No response