
[RFC]: Reimplement and separate beam search on top of vLLM core #8306

Closed
youkaichao opened this issue Sep 9, 2024 · 21 comments · Fixed by #9105

@youkaichao (Member)

Motivation.

A rework of #6226

After discussing further with the community, we find that the common use case for beam search is:

  1. throughput-oriented
  2. mainly offline batch inference
  3. one set of beam search parameters for all the prompts in the batch

After discussing with many contributors, we find:

Because beam search is a search algorithm, it conflicts with all the other sampling algorithms. As a result, many features in vLLM already assert directly that beam search is not used, e.g.:

```python
assert len(input_seq_group_metadata.seq_data) == 1, (
    "Beam search "
    "not supported in speculative decoding")
```

```python
assert len(seqs) == 1, (
    "Beam search not supported in multi-step decoding.")
seq = seqs[0]
```

Keeping beam search as-is in the codebase will not benefit current beam search users, as no optimization will target beam search performance. Worse, very few developers understand beam search. Keeping it as-is will not only introduce more beam search bugs as the codebase evolves, but also increase the maintenance cost for all contributors.

In search of a win-win solution, on behalf of the vLLM team, I propose to separate beam search from the vLLM core and reimplement it on top of the core code.

To be specific, we can:

  1. remove the beam search logic from the scheduler;
  2. add an LLM.beam_search interface that calls the engine to generate one token with logprobs at every step, and keep the beam search logic only in the LLM.beam_search function;
  3. add a beam search emulator on top of the commonly used OpenAI API server, which internally calls the generation endpoint to generate one step with logprobs, and keep the beam search logic only in the emulator.
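The per-step loop described in items 2 and 3 can be sketched in a few lines. This is a minimal sketch under stated assumptions, not vLLM's actual LLM.beam_search: `step` is a hypothetical stand-in for one engine call that generates a single token and returns the top candidates with their logprobs, and `toy_step` is a made-up distribution used only for the usage example. All beam bookkeeping lives outside the engine.

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for one engine call: given the full token sequence,
# generate one token and return the top candidates as {token_id: logprob}.
StepFn = Callable[[Tuple[int, ...]], Dict[int, float]]
Beam = Tuple[Tuple[int, ...], float]  # (token ids, cumulative logprob)


def beam_search(prompt: Tuple[int, ...], step: StepFn,
                beam_width: int, max_tokens: int) -> List[Beam]:
    """Beam search kept entirely outside the engine; best beam first."""
    beams: List[Beam] = [(prompt, 0.0)]
    for _ in range(max_tokens):
        candidates: List[Beam] = []
        for tokens, score in beams:
            # Each beam becomes an independent one-token request.
            for token, logprob in step(tokens).items():
                candidates.append((tokens + (token,), score + logprob))
        # Keep the beam_width highest cumulative logprobs.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams


def toy_step(tokens: Tuple[int, ...]) -> Dict[int, float]:
    """Toy next-token distribution that depends only on the last token."""
    if tokens[-1] == 0:
        return {0: math.log(0.4), 1: math.log(0.6)}
    return {0: math.log(0.9), 1: math.log(0.1)}


result = beam_search((0,), toy_step, beam_width=2, max_tokens=2)
# result[0] is ((0, 1, 0), log(0.54)): 0.6 * 0.9 beats every other path.
```

Each step resubmits the full sequence to the engine, which is why the efficiency of this design hinges on reusing the KV cache of the shared prefix.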

One concern raised in the initial discussion is the efficiency of such an implementation, since from the vLLM core's perspective requests will come and go over and over again. This should be solvable in two ways:

  1. turning on prefix caching lets us reuse the computation from the last step, so that we don't need to recompute the KV cache of the prompt again and again;
  2. after separating beam search from the vLLM core, the two can be optimized individually. The simplified code will be much easier to optimize.

vLLM is a community project, and we'd like to seek not only opinions but also contributions from beam search users. Your help is truly needed to shape the future of beam search support in vLLM.

Proposed Change.

Summary of the change: implement beam search on top of the vLLM core and add wrappers for users; remove beam search from the vLLM core (scheduler).

Feedback Period.

1 week, from 9/9 to 9/15 (both inclusive)

CC List.

@hrsmanian @zhouyuan @lanking520 @nightflight-dk @HeegonJin @SemMulder @darabos @DhruvaBansal00 @tmostak @physicsrob @YooSungHyun @denadai2 @sjmielke @Reichenbachian @AaronFriel @hinnefe2 @mflaxman10
@WoosukKwon @zhuohan123 @simon-mo

Any Other Things.

No response

youkaichao added the RFC label on Sep 9, 2024
simon-mo changed the title from "[RFC]: Reimplement and separate beam search on top of vllm core" to "[RFC]: Reimplement and separate beam search on top of vLLM core" on Sep 9, 2024
@AaronFriel commented Sep 9, 2024

> after separating beam search and the vllm core, they can be optimized individually. The simplified code will be much easier to optimize.

This is a good goal to work toward, as ensuring that API interfaces (OpenAI, beam search, or otherwise) can efficiently and reliably schedule new sequences benefits all consumers.

> turning on prefix caching can reuse computation from the last step so that we don't need to recompute the kv cache of prompt again and again.

The flood of vLLM notifications is hard to keep up with, so I may be out of date. My understanding was that prefix caching was not precise and was block-based, resulting in some amount of excess computation. Is there an issue to allow APIs to specify the "prefix length" that should be cached?

This new approach could see performance degrade when the sequence length approaches a multiple of the KV block length, if each arm of the beam search schedules a new sequence and must prefill O(kv_block_size) tokens plus decode O(1) tokens. Ideally both would be O(1), with a hint allowing beam search to cache the entire prefix.

@youkaichao (Member, Author) commented Sep 9, 2024

> My understanding was that prefix caching was not precise and was block based, resulting in some amount of excess computation

We can set the block size to 1 for the vLLM instance when we use beam search; then we don't waste any computation.
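Both AaronFriel's concern and the block-size-1 workaround can be seen in a small cost model. This is a simplified sketch, not a description of vLLM's prefix-caching code: it assumes (i) only complete KV blocks can be reused and (ii) the final token of a resubmitted beam must always be recomputed so that its logits are available. `tokens_to_prefill` is a hypothetical helper.

```python
def tokens_to_prefill(seq_len: int, block_size: int) -> int:
    """Tokens that must be recomputed when a beam of length seq_len is
    resubmitted, under a simplified model of block-granular prefix caching:
    only complete blocks are reusable, and the last token is always
    recomputed so that its logits are available."""
    # Longest reusable prefix: whole blocks inside the first seq_len - 1 tokens.
    cached = ((seq_len - 1) // block_size) * block_size
    return seq_len - cached


# With 16-token blocks, the cost peaks when seq_len is a multiple of 16.
print([tokens_to_prefill(n, 16) for n in (31, 32, 33)])  # [15, 16, 1]
# With 1-token blocks, only the final token is ever recomputed.
print([tokens_to_prefill(n, 1) for n in (31, 32, 33)])   # [1, 1, 1]
```

Under this model, the per-step prefill cost ranges from 1 to block_size tokens, which matches the O(kv_block_size) worst case described above; a block size of 1 pins it at one token, at the price of a longer block table per sequence.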

@simon-mo (Collaborator) commented Sep 9, 2024

There are also alternative implementations that move this functionality into a special class of Worker or Executor, which can be configured when beam search is turned on for any engine that needs it.

@AaronFriel

@youkaichao How well does the KV cache handle a block size of 1, in terms of compute or memory overhead?

@youkaichao (Member, Author)

@AaronFriel I don't think setting the block size to 1 will affect performance much, but we need to test and measure the impact.

@youkaichao (Member, Author)

@simon-mo Can you explain more? What special functions or interfaces would these new Worker or Executor classes need?

youkaichao pinned this issue on Sep 12, 2024
@nFunctor (Contributor)

Hello, thanks for your work! I wasn't sure whether to open a new issue for this; technically these are still comments.

I have done some manual testing of your new beam search implementation, and here are some observations, together with a late RFC response:

  • I assume the new offline method will be made available in server completions, just like before? The current implementation in the OpenAI server is very basic anyway, and personally I would not need much more.

  • The speedup becomes more apparent the longer the model generates. The current (soon to be legacy?) implementation suffers from a lot of GPU idle time when generating around 500 tokens or more.

  • I have not seen v2_block_manager make any difference in my tests, and the same goes for the scheduler. This is probably expected, since my GPU usage was already near its maximum.

  • I believe the new implementation (the same as HF's) differs from the old one (admittedly, I have not read its code)? The results for the same prompts and num_beams differ between the old and new implementations.

  • We can now change the temperature in beam_search_params ([Feature]: Beam Search with Temperature > 0 #8067, opened a while ago).

The tests were run with Llama 3.1 8B AWQ on an RTX 3090, using three basic completion prompts such as "Today is a good day". I can provide more details if that is useful.

youkaichao unpinned this issue on Sep 27, 2024
@youkaichao (Member, Author)

@nFunctor It's great to hear that you find the new implementation faster! We do plan to add beam search back to the OpenAI server, with an implementation similar to LLM.beam_search. Please stay tuned.

Regarding exact equivalence with the old implementation: we cannot guarantee that generating 500 tokens gives exactly the same output as the HuggingFace implementation (or the old one). As long as the algorithm still follows beam search, it should be fine. We have checked that the first 64 tokens are the same, which should be enough for practical use.

@nFunctor (Contributor)

Thanks for your response, @youkaichao. What do you think about the temperature implementation?

The new method can still be slower when generating with fewer beams and fewer tokens. As you say, the block_size parameter is what holds the method back. I tried enabling FlashInfer with block_size 8, and it did give some speedup (12 s -> 10 s in one experiment).

As a byproduct, I found that the Llama model mentioned above degenerates completely on FlashInfer with the old beam search method, repeating the same phrase (admittedly, I am running an instruct-tuned model outside its chat template format). So, regarding what you said about the results being similar, the new results may actually be more "numerically stable" in some cases.

@youkaichao (Member, Author)

> What do you think about the temperature implementation?

Beam search is a search algorithm. I don't see how it relates to temperature.

> the new results will be actually more "numerically stable" in some cases

Glad to hear that.

@nFunctor (Contributor)

@youkaichao If I understand correctly, .generate returns log(softmax(logits/T)) as logprobs, so the temperature affects the sequences' cumulative scores, which can lead to significant deviations (cf. this explanation). In the new implementation, setting a non-zero temperature in beam_search_params in vllm.py does change the generated sequence.

We don't change the logic of the algorithm, but since the next-token distribution changes, the results may change as well.
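The effect nFunctor describes can be reproduced on a toy example. With logprobs computed as log p_T(x) = z_x/T - log Σ_j exp(z_j/T), dividing by T never changes which token ranks first within a single distribution, but beam search compares cumulative scores across beams whose distributions differ, so the winning sequence can change. The sketch below is an illustration, not vLLM's code: `TOY_LOGITS` is a hypothetical two-token model, and `best_sequence` is a plain beam search that scores each step with temperature-scaled log-softmax.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical toy model: next-token logits over a 2-token vocabulary,
# keyed by the tokens generated so far.
TOY_LOGITS: Dict[Tuple[int, ...], List[float]] = {
    (): [1.0, 0.0],
    (0,): [0.5, 0.0],
    (1,): [4.0, 0.0],
}


def log_softmax(logits: List[float], temperature: float) -> List[float]:
    """log(softmax(logits / T)), computed in a numerically stable way."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


def best_sequence(temperature: float, beam_width: int = 2,
                  steps: int = 2) -> Tuple[int, ...]:
    """Plain beam search over TOY_LOGITS; returns the top-scoring sequence."""
    beams: List[Tuple[Tuple[int, ...], float]] = [((), 0.0)]
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            for tok, lp in enumerate(log_softmax(TOY_LOGITS[tokens], temperature)):
                candidates.append((tokens + (tok,), score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams[0][0]


print(best_sequence(1.0))  # (0, 0)
print(best_sequence(4.0))  # (1, 0)
```

In this toy model token 0 is the argmax at every step for both temperatures; only the comparison between the two partial beams flips as T grows.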

@youkaichao (Member, Author)

Sorry, I don't get it. What is your ask?

@yunyipower commented Oct 10, 2024

@youkaichao Hi, it's great to see your design. Does it support multi-batch beam search? I mean at the operator level, not a loop over the list of prompts.

@youkaichao (Member, Author)

> multi-batch beam-search

What is multi-batch beam search?

@varuniyer commented Oct 13, 2024

@youkaichao The beam search docs for vllm.LLM still list these TODO items:

> TODO: how does beam search work together with length penalty, frequency penalty, and stopping criteria, etc.?

I see you mentioned that beam search conflicts with the sampling algorithms. However, logit processors (currently an argument of the constructor of the SamplingParams object passed into generate) can be used to add penalties like these. They can affect the top k beams selected at each iteration of the search, even without sampling. Is there progress on supporting logit processors in the new beam search implementation? The closest issue I found is #9253, regarding stop conditions but not logit processors.
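varuniyer's point, that a logits processor can change which beams survive even though nothing is sampled, can be illustrated in plain Python. The processor type, `repetition_penalty`, and `top_candidates` below are hypothetical stand-ins (lists instead of tensors), not vLLM's API; the penalty uses the common form that divides positive logits and multiplies negative ones for tokens already present in the beam.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical processor shape: (previous token ids, logits) -> new logits.
LogitsProcessor = Callable[[Sequence[int], List[float]], List[float]]


def repetition_penalty(penalty: float) -> LogitsProcessor:
    """Penalize tokens that already appear in the beam."""
    def process(token_ids: Sequence[int], logits: List[float]) -> List[float]:
        out = list(logits)
        for t in set(token_ids):
            # Shrink positive logits, push negative logits further down.
            out[t] = out[t] / penalty if out[t] > 0 else out[t] * penalty
        return out
    return process


def log_softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def top_candidates(token_ids: Sequence[int], logits: List[float], k: int,
                   processors: Sequence[LogitsProcessor] = ()) -> List[int]:
    """Return the k best next tokens for one beam, after the processors run."""
    for proc in processors:
        logits = proc(token_ids, logits)
    logprobs = log_softmax(logits)
    return sorted(range(len(logprobs)), key=lambda t: logprobs[t],
                  reverse=True)[:k]


history = [0, 0]            # token 0 has already been generated twice
logits = [2.0, 1.5, 1.2]
print(top_candidates(history, logits, k=2))  # [0, 1]
print(top_candidates(history, logits, k=2,
                     processors=[repetition_penalty(2.0)]))  # [1, 2]
```

Because the processed logits also feed the logprobs that accumulate into each beam's score, a penalty changes both which candidates are kept at this step and how the surviving beams are ranked later.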

@yunyipower

>> multi-batch beam-search
>
> what is multi-batch beam-search?

Inference in a batch, say of 16?

@liho00 commented Oct 14, 2024

How do I enable beam search in the vLLM OpenAI server? There seem to be no engine args available for it.

@HeegonJin commented Oct 16, 2024

As mentioned in #9253, the current implementation does not stop generating when the EOS token is encountered and continues until it reaches the maximum token limit. This appears to be the major issue.

@nFunctor (Contributor)

@HeegonJin yes, and I tried my workaround in #9264 (we will see if the team approves).

As an external contributor, I lack a deeper understanding of what's going on, but it seems to me that a beam finishes but is never pushed to the completed beams because the EOS check is not applied properly. The proposed stop conditions seem to fix that, but they are not exactly elegant.
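The missing step nFunctor describes, moving a beam to the completed list as soon as it emits EOS, can be sketched as a single beam search step. This is a minimal sketch, not the code in #9264 or vLLM: `advance_beams`, its arguments, and the example token ids are hypothetical, and it uses one simple policy (EOS candidates among the top beam_width are retired, so the active set may shrink).

```python
from typing import Dict, List, Tuple

Beam = Tuple[Tuple[int, ...], float]  # (token ids, cumulative logprob)


def advance_beams(beams: List[Beam], step_logprobs: List[Dict[int, float]],
                  beam_width: int,
                  eos_token_id: int) -> Tuple[List[Beam], List[Beam]]:
    """One beam search step with an explicit EOS check.

    step_logprobs[i] maps candidate next tokens of beams[i] to logprobs.
    Returns (still-active beams, beams that just finished with EOS).
    """
    candidates = [
        (tokens + (tok,), score + lp)
        for (tokens, score), lps in zip(beams, step_logprobs)
        for tok, lp in lps.items()
    ]
    candidates.sort(key=lambda c: c[1], reverse=True)
    active: List[Beam] = []
    finished: List[Beam] = []
    for cand in candidates[:beam_width]:
        if cand[0][-1] == eos_token_id:
            # The beam just emitted EOS: record it and never extend it again.
            finished.append(cand)
        else:
            active.append(cand)
    return active, finished


EOS = 2
beams = [((5,), -0.1), ((6,), -0.5)]
step = [{EOS: -0.2, 7: -1.0}, {8: -0.3, EOS: -2.0}]
active, finished = advance_beams(beams, step, beam_width=2, eos_token_id=EOS)
# finished holds (5, 2); only (6, 8) is extended on the next step.
```

A driver loop would feed `active` into the next step, accumulate `finished`, and stop once no active beams remain or max_tokens is reached. Some implementations instead examine more than beam_width candidates so that the active set can be refilled after beams finish.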

@youkaichao (Member, Author)

> @youkaichao The beam search docs for vllm.LLM still list these TODO items: [...] Is there progress on supporting logit processors in the new beam search implementation?

We plan to improve beam search so that all the sampling parameters work.

@denadai2

To all: I opened a feature request for a more powerful beam search (as it was in the old vLLM): #10754
