Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Support VLMs with fine-grained scheduling #9871

Merged
merged 3 commits into from
Nov 13, 2024
Merged

Conversation

WoosukKwon
Copy link
Collaborator

@WoosukKwon WoosukKwon commented Oct 31, 2024

This PR implements the basic vision language model support in V1.

Motivation

Multi-modal inputs are difficult to deal with because they often have complex (or non-trivial) dependencies. For example, the model can take a prompt with interleaved texts and images like

Screenshot 2024-11-07 at 10 05 10 PM

Here, different colors represent different types of dependencies:

  • Red: Can be computed independently of each other
  • Yellow: Depends on Image Embedding 0
  • Green: Depends on Image Embedding 1

In V0, we didn't consider those dependencies. V0 circumvented it by always processing the entire prompt (all images & text) at once. However, this is not desirable, since it doesn't fit with other optimizations such as chunked prefills and prefix caching.

Proposal

To address this limitation, this PR proposes to make the V1 scheduler consider & track these dependencies explicitly, and do flexible & fine-grained scheduling based on it. One example can be like following:
Screenshot 2024-11-07 at 10 06 17 PM

  1. The scheduler leverages chunked prefills for the decoder inputs, so that TPOT is under control.
  2. Furthermore, the scheduler ensures that not too many images are processed by the vision encoder in the same step, because this can cause a spike in TTFT/TPOT.
  3. This fine-grained scheduling will also allow using prefix caching for VLMs, although it's not implemented in this PR.

Implementation

  • The scheduler has “encoder budget” (e.g., number of input image tokens in ViT) and “decoder budget” (number of input tokens).
  • The scheduler explicitly schedules the encoder and decoder inputs, considering the input dependencies.
    • The vision encoder and LLM decoder will live in the same GPU.
    • In every step, the model runner will first (optionally) run the vision encoder, and then run the LLM decoder possibly with the output of the encoder.
  • The model runner caches the encoder outputs (e.g., image embeddings) in encoder cache on GPU until the entire tensor is consumed by the decoder.
    • We should limit the maximum size of the cache, since the encoder outputs can be large. This will work as a scheduling constraint in the scheduler.

Limitations

  • Currently, the design only consider Llava-style model architectures (e.g., Pixtral, Molmo). It didn't consider other model architectures like multi-modal Llama.
  • Currently, the implementation in the PR only supports Llava v1.5 and Phi3v because of the necessary changes in model's input processor. Support for other models will be implemented in a followup PR.
  • Currently, the encoder cache is just a pool of tensors. For more precise memory management, we need to store it in paged memory, just like the paged KV cache. I leave this as future work.
  • Currently, the scheduling logic for encoder inputs is a bit hacky because of some limitations on the V1 model runner. This needs to be further refined in the next PR.

Misc

To reduce the conflicts, I reverted back the changes in detokenizer. Plus, the MM input mapper will run on the same process as the engine (scheduler) for now. We will move it to a separate process later.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@ywang96 ywang96 self-assigned this Oct 31, 2024
Copy link
Collaborator

@alexm-neuralmagic alexm-neuralmagic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new code looks great. Also the performance should be better. Some nit comments

continue
if not self.encoder_cache_manager.can_allocate(request, i):
# Cannot schedule because the encoder cache is full.
num_new_tokens = start_pos - num_computed_tokens
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the meaning of this num_new_tokens update here? For example, start_pos can be < num_computed_tokens, and then the result may be potentially negative?

Copy link
Member

@ywang96 ywang96 Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's possible to have start_pos < num_computed_tokens here: This is because num_computed_tokens are tokens already processed, which means if there were an image with start_pos < num_computed_tokens, it should have been already processed in the previous iteration (either stored in KV cache, or cached in encoder cache).

If I understand correctly, the point of this update is that if we cannot run encoder here, then we want to stop at exactly before where the first encoder position is, to run decoder only processing for this current iteration. However, I think it is possible to have start_pos == num_computed_tokens for a running request? (e.g, the first image token in a placeholder is exactly the first scheduled token, but the cache cannot allocate).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible when prefix caching is enabled (whilst we currently don't support prefix caching for VLMs).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we want to stop at exactly before where the first encoder position is, to run decoder only processing for this current iteration.

Exactly.

vllm/v1/core/scheduler.py Outdated Show resolved Hide resolved
request.num_tokens)

# Encoder-related.
if encoder_inputs_to_schedule:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a code duplication with the running case. Maybe the duplication can be avoided somehow.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code was simplified a bit. I found it difficult to further refactor, since it's only 5 lines of code, and it involves updating the local variables like scheduled_encoder_inputs and encoder_budget. The code looks ok to me. WDYT?

vllm/v1/request.py Outdated Show resolved Hide resolved
vllm/v1/worker/gpu_model_runner.py Outdated Show resolved Hide resolved
num_computed_tokens = req_state.num_computed_tokens
mm_positions = req_state.mm_positions
for i, (start_pos, num_encoder_tokens) in enumerate(mm_positions):
start_idx = max(num_computed_tokens - start_pos, 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: A quick doc for this start/end indices computation would be helpful here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some comments above to help understand the logic.

vllm/v1/worker/gpu_model_runner.py Outdated Show resolved Hide resolved
Copy link

mergify bot commented Nov 1, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @WoosukKwon please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link
Member

@ywang96 ywang96 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the very delayed review - left some comments!

FWIW - I did some mini benchmark on this branch vs V0 on 1 x A100-80G.

Command: python vllm/examples/offline_inference_vision_language.py --num-prompts 1000

V0:

1000/1000 [01:33<00:00, 10.69it/s, est. speed input: 6369.87 toks/s, output: 682.44 toks/s]

V1 with this PR and default budget & cache:

1000/1000 [01:01<00:00, 16.13it/s, est. speed input: 9614.21 toks/s, output: 1029.49 toks/s]

V1 with encoder budget and cache size = 576 (This should be more or less equivalent to V1 with previous design of VLM)

1000/1000 [01:15<00:00, 13.18it/s, est. speed input: 7856.67 toks/s, output: 841.03 toks/s]

continue
if not self.encoder_cache_manager.can_allocate(request, i):
# Cannot schedule because the encoder cache is full.
num_new_tokens = start_pos - num_computed_tokens
Copy link
Member

@ywang96 ywang96 Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's possible to have start_pos < num_computed_tokens here: This is because num_computed_tokens are tokens already processed, which means if there were an image with start_pos < num_computed_tokens, it should have been already processed in the previous iteration (either stored in KV cache, or cached in encoder cache).

If I understand correctly, the point of this update is that if we cannot run encoder here, then we want to stop at exactly before where the first encoder position is, to run decoder only processing for this current iteration. However, I think it is possible to have start_pos == num_computed_tokens for a running request? (e.g, the first image token in a placeholder is exactly the first scheduled token, but the cache cannot allocate).

self._schedule_encoder_inputs(request,
request.num_computed_tokens,
num_new_tokens, encoder_budget))
assert num_new_tokens > 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my other comment on when num_new_tokens can be 0 for a running sequence.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also fixed this (for the above decoder tokens) in the prefix caching PR. Also to clarify the semantic of num_new_tokens:

  • Before calling _schedule_encoder_inputs, num_new_tokens would be the text tokens as well as image tokens (placeholder).
  • After calling _schedule_encoder_inputs, num_new_tokens may be the same as before if encoder budget allows; otherwise it would be reduced to only include text tokens.

Is this understanding correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@comaniac Yes, correct. When the encoder cache or budget is insufficient, num_new_tokens can decrease up to the point just before the encoder input (e.g., image placeholder).

vllm/v1/core/scheduler.py Outdated Show resolved Hide resolved
request.num_tokens)

# Encoder-related.
if encoder_inputs_to_schedule:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

vllm/v1/core/scheduler.py Outdated Show resolved Hide resolved
vllm/v1/core/scheduler.py Show resolved Hide resolved
@WoosukKwon
Copy link
Collaborator Author

@ywang96 Thanks for the review!

QQ: How did you measure the perf of V1 without this PR?

@ywang96
Copy link
Member

ywang96 commented Nov 5, 2024

@ywang96 Thanks for the review!

QQ: How did you measure the perf of V1 without this PR?

I have updated my original review comment - PTAL!

Copy link

mergify bot commented Nov 6, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @WoosukKwon please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 6, 2024
@alexm-neuralmagic
Copy link
Collaborator

FYI,

I did a quick performance benchmark for microsoft/Phi-3.5-vision-instruct when you have a separate process for mm_mapper (on the old version of this PR) and when you don't have a separate process. Results below show that separate process has large TTFT overhead, even when the RPS goes up (which is a bit surprising) - I think it is related to pickle/socket overheads most likely. I did some manual timings specifically on the roundtrip times to the separate process and I saw that mm_mapper is 5X slower with separate process than simply running directly.

RPS V0 - TTFT V1 - TTFT (separate process mm_mapper)
1 67.05 127.99
5 73.23 143.1
10 84.28 190.66
     
RPS V0 - TPOT V1 - TPOT (separate process mm_mapper)
1 14.27 14.44
5 17.59 18.89
10 25.47 27.8

When there is no separate process, the performance looks much better:

RPS V0 - TTFT V1 - TTFT (direct mm_mapper)
1 67.05 69.91
5 73.23 78.47
10 84.28 89.23
     
RPS V0 - TPOT V1 - TPOT (direct mm_mapper)
1 14.27 13.10
5 17.59 14.19
10 25.47 16.17

The commands are used are:

server: vllm serve microsoft/Phi-3.5-vision-instruct --trust-remote-code --max-model-len 4096 --enforce-eager --disable-async-output-proc

client: python benchmarks/benchmark_serving.py --backend openai-chat --base-url http://0.0.0.0:8000/v1 --endpoint /chat/completions --model microsoft/Phi-3.5-vision-instruct --dataset-path lmms-lab/LLaVA-OneVision-Data --dataset-name hf --hf-subset "chart2text(cauldron)" --hf-split train --num_prompts=100 --request-rate 5

@comaniac
Copy link
Collaborator

comaniac commented Nov 6, 2024

Thanks for the benchmarking. Could you also benchmark throughput? I suppose the benefit of separate processes should be more obvious in throughput instead of latency, as long as we pipeline mm_mapper well?

vllm/v1/core/encoder_cache_manager.py Show resolved Hide resolved
# in the "partial" state, where the request has some tokens computed
# but not all. The constraint is due to the persistent batch in the
# V1 model runner.
# TODO(woosuk): Remove this constraint after refactoring model runner.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what situation this limitation would hurt the performance?

self._schedule_encoder_inputs(request,
request.num_computed_tokens,
num_new_tokens, encoder_budget))
assert num_new_tokens > 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also fixed this (for the above decoder tokens) in the prefix caching PR. Also to clarify the semantic of num_new_tokens:

  • Before calling _schedule_encoder_inputs, num_new_tokens would be the text tokens as well as image tokens (placeholder).
  • After calling _schedule_encoder_inputs, num_new_tokens may be the same as before if encoder budget allows; otherwise it would be reduced to only include text tokens.

Is this understanding correct?

Comment on lines 297 to 303
def _schedule_encoder_inputs(
self,
request: Request,
num_computed_tokens: int,
num_new_tokens: int,
encoder_budget: int,
) -> Tuple[List[int], int]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please docstring this function for readability.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added. Thanks for the suggestion.

Comment on lines 520 to 524
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
inputs_embeds = self.get_inputs_embeds(input_ids,
vision_embeddings)
input_ids = None
Copy link
Member

@ywang96 ywang96 Nov 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we're putting the encoder forward pass and embedding merge at model_runner level, then I don't think the code here is needed? (Is it possible for inputs_embeds to be None here when there's multimodal data in the request? If not, we just need to call embed_tokens here to get the text embeddings)

nvm - I see that it's needed here to be compatible with V0 - I will add a note accordingly in my PR to indicate that this needs to be cleaned up after we fully deprecate v0

@mergify mergify bot removed the needs-rebase label Nov 8, 2024
@WoosukKwon WoosukKwon marked this pull request as ready for review November 8, 2024 06:27
@WoosukKwon WoosukKwon changed the title [V1] Support VLMs [V1] Support VLMs with fine-grained scheduling Nov 8, 2024
Copy link

mergify bot commented Nov 11, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @WoosukKwon.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Copy link

mergify bot commented Nov 11, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @WoosukKwon.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@WoosukKwon
Copy link
Collaborator Author

@ywang96 Addressed comments. PTAL.

@WoosukKwon
Copy link
Collaborator Author

@ywang96 This PR actually requires adding get_input_embeddings method to all models (while I only added it to llama, opt, llava, and phi3v in this PR), because it know executes the model's embedding layer and the other parts separately.

If we don't want to add this method to the text models, we can use self.model.model.get_input_embeddings instead, while it looks a bit hacky.

Copy link
Member

@ywang96 ywang96 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WoosukKwon Overall looks good to me! I left a few more comments mainly around code clarifications so please take a look.

vllm/model_executor/models/llava.py Outdated Show resolved Hide resolved
vllm/v1/core/scheduler.py Outdated Show resolved Hide resolved
vllm/v1/worker/gpu_model_runner.py Show resolved Hide resolved
vllm/v1/worker/gpu_model_runner.py Outdated Show resolved Hide resolved
vllm/v1/worker/gpu_model_runner.py Outdated Show resolved Hide resolved
vllm/v1/core/scheduler.py Outdated Show resolved Hide resolved
vllm/v1/worker/gpu_model_runner.py Show resolved Hide resolved
vllm/v1/core/scheduler.py Show resolved Hide resolved
@ywang96
Copy link
Member

ywang96 commented Nov 12, 2024

@WoosukKwon Everything looks good to me now - can you merge with main after #10272 is merged for the test fix? After that we can merge this.

Signed-off-by: Woosuk Kwon <[email protected]>
@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 13, 2024
@WoosukKwon WoosukKwon enabled auto-merge (squash) November 13, 2024 00:32
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Woosuk Kwon <[email protected]>
@WoosukKwon WoosukKwon merged commit bbd3e86 into main Nov 13, 2024
50 checks passed
@WoosukKwon WoosukKwon deleted the v1-vlm-sched branch November 13, 2024 04:53
rickyyx pushed a commit to rickyyx/vllm that referenced this pull request Nov 13, 2024
omer-dayan pushed a commit to omer-dayan/vllm that referenced this pull request Nov 14, 2024
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: OmerD <[email protected]>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
Comment on lines +100 to +105
# FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may
# take 10-50 ms, which can cause a spike in the latency. We should
# consider moving this to a separate thread.
if req.mm_data:
req.mm_inputs = self.mm_input_mapper.process_inputs(
req.mm_data, req.mm_processor_kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One very nice property of V0 + #8348 is that the input mapper can be skipped entirely if the multimodal item is covered by the prefix cache (in our use case with Ultravox we can have many audio chunks in each inference). Not sure if that's practical to preserve in V1?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants