
Migrate logits computation and gather to model_runner #3233

Merged: 22 commits into vllm-project:main on Mar 20, 2024

Conversation

esmeetu (Collaborator) commented Mar 6, 2024

  1. Finish the TODO from #2221 (comment): use NCCL instead of Ray for control-plane communication to remove serialization overhead.
  2. This might replace #2398 (Fix vocab_size inconsistency for sampler); see also #3500 ([Misc][Log] Add log for tokenizer length not equal to vocabulary size).

Description

This PR migrates logits computation and gathering to model_runner. This change makes the Sampler simple and clean.
Furthermore, I want to remove the sample method from the model files (e.g. llama.py), because the model and sampling run at different stages and should be decoupled from each other. However, according to this comment on PR #3183 (#3183 (comment)), that model scales logits in the sampler, so I will keep the sample method as it is. This PR will also better support the #3183 model integration.

Pipeline:

  1. Prepare inputs
  2. Model
  3. LogitsProcessor (supports logits scaling and custom logits processor functions; see the sketch after this list)
  4. Sampler
  5. Output
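
Below is a minimal, hedged sketch of what the LogitsProcessor stage could look like as a standalone layer. The class shape and the tensor-parallel behavior in the comments reflect this PR's design, but the exact names and signatures here are illustrative, not the merged vLLM API:

```python
import torch
import torch.nn as nn

class LogitsProcessor(nn.Module):
    """Stage 3 of the pipeline: turns hidden states into vocabulary logits,
    sitting between the model forward pass and the Sampler."""

    def __init__(self, vocab_size: int, scale: float = 1.0):
        super().__init__()
        self.vocab_size = vocab_size
        self.scale = scale  # logit scaling (e.g. for models such as Jais, #3183)

    def forward(self, lm_head_weight: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # Project hidden states onto the vocabulary.
        logits = hidden_states @ lm_head_weight.t()
        # In this single-process sketch, logits is always a real tensor; in
        # vLLM's tensor-parallel path the gather that would happen here
        # returns None on ranks > 0, hence the guard.
        if logits is not None:
            logits = logits * self.scale
        return logits
```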

TODO

grandiose-pizza (Contributor) commented Mar 6, 2024

Hi @esmeetu, important update.

Regarding this comment, #3183 (comment)

I found the bug. It will be useful for this PR.

Please add this change while making the PR for scaling logits: e04e56d

This is needed because, in a multi-GPU setting, logits are None on every get_tensor_model_parallel_rank() greater than 0 (the ranks are 0 and 1 in a 2-GPU setting).
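
A hedged paraphrase of what the referenced fix (e04e56d) amounts to; this is a sketch of the guard, not the exact diff:

```python
def scale_logits(logits, logit_scale):
    # After the tensor-parallel gather, only rank 0 holds the full logits
    # tensor; every other rank receives None, so scaling must be skipped there.
    if logits is not None:
        logits = logits * logit_scale
    return logits
```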

zhuohan123 self-assigned this Mar 7, 2024
Yard1 (Collaborator) commented Mar 7, 2024

While I agree with this change in principle, I think it's important to ensure we have an API that can support different use cases. For example, I would suggest making the logit generation a layer (or some other sort of abstraction at the model level). The fact that it has now been moved out of the model into the model runner makes the code harder to understand, especially considering the sampler remains a layer.

In other words, I would suggest adding a new logit generator layer to the models (or the sampler, though the models would be better; the output of a model should be logits IMO) and not putting that logic inside the model runner.

esmeetu marked this pull request as draft March 10, 2024 08:36
esmeetu marked this pull request as ready for review March 13, 2024 13:42
esmeetu (Collaborator, Author) commented Mar 13, 2024

@Yard1 @zhuohan123 I redesigned this PR to make the logits processor an individual layer. PTAL!
Once this design is accepted, I will fix the other model files. A sketch of how the layer could be wired into a model follows below.
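
As a rough illustration of the redesign (reusing the LogitsProcessor sketch above; LlamaModel and all names here are placeholders, not the exact merged code), the model keeps a logits-processor layer and exposes a compute_logits hook that model_runner calls between the forward pass and sampling:

```python
class LlamaForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.model = LlamaModel(config)  # transformer backbone (assumed defined)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(self, input_ids, positions):
        # The model itself now stops at hidden states.
        return self.model(input_ids, positions)

    def compute_logits(self, hidden_states):
        # model_runner calls this after forward() and hands the result
        # to the Sampler, keeping the two stages decoupled.
        return self.logits_processor(self.lm_head.weight, hidden_states)
```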

esmeetu requested review from Yard1 and zhuohan123 March 13, 2024 13:46
Yard1 (Collaborator) commented Mar 13, 2024

This needs changes to work with the LoRA path.

GennVa mentioned this pull request Mar 14, 2024
esmeetu (Collaborator, Author) commented Mar 15, 2024

@Yard1 All CI checks passed; please review this again. cc @zhuohan123
Also, should we fold the fix for #2398 into this PR? Is there a better solution?

zhuohan123 (Member) left a comment

Thanks for the fix! This looks very good, much better than what we had before :) Can you merge the PR after fixing the merge conflicts?

Review comment on vllm/worker/model_runner.py (resolved)
zhuohan123 enabled auto-merge (squash) March 20, 2024 22:59
zhuohan123 merged commit f1c0fc3 into vllm-project:main Mar 20, 2024
32 checks passed
grandiose-pizza pushed a commit to grandiose-pizza/vllm-jais that referenced this pull request Mar 21, 2024
tjohnson31415 added a commit to tjohnson31415/vllm that referenced this pull request Mar 21, 2024
* upstream/main:
  [Misc] Bump up transformers to v4.39.0 & Remove StarCoder2Config (vllm-project#3551)
  [Misc][Log] Add log for tokenizer length not equal to vocabulary size (vllm-project#3500)
  [🚀 Ready to be merged] Added support for Jais models (vllm-project#3183)
  Fix 1D query issue from `_prune_hidden_states` (vllm-project#3539)
  [PREFIX CACHING FOLLOW UP] OrderedDict-based evictor (vllm-project#3431)
  [BugFix] Hot fix in setup.py for neuron build (vllm-project#3537)
  Migrate `logits` computation and gather to `model_runner` (vllm-project#3233)
  [1/n][Chunked Prefill] Refactor input query shapes (vllm-project#3236)
  [1/n] Triton sampling kernel (vllm-project#3186)
  [Bugfix] Fix ROCm support in CMakeLists.txt (vllm-project#3534)
esmeetu deleted the perf-sampler branch March 23, 2024 11:10
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024