Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] re-enable multi-modality input in the new beam search implementation #9427

Merged

Conversation

FerdinandZhong
Copy link
Contributor

@FerdinandZhong FerdinandZhong commented Oct 16, 2024

FIX #9577

Changes in this PR:

This PR introduces the following changes based on the updated beam search implementation:

  • Re-enable multi-modality input:
    Support for multi-modality input has been re-enabled for beam search with OpenAI-compatible endpoints.
  • Logprobs handling in ChatCompletionRequest:
    Added additional validation to disable logprobs when use_beam_search=True. Since the beam search selects results based on cumulative logprobs and determines step logprobs by beam_width, it ignores the top_logprobs and logprobs parameters passed in with the request.

Unit Test

Added two additional test cases in tests/entrypoints/openai/test_vision.py.

Manual Testing

The following command was used to launch the server for manual testing: vllm serve microsoft/Phi-3.5-vision-instruct --api-key token-abc123 --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2

Client script used to test the changes:

import openai
import asyncio


url = "http://localhost:"
client = openai.AsyncOpenAI(
    base_url = "http://localhost:8000/v1",
    api_key="token-abc123"
)


# Image URLs
img_urls = [
    "https://upload.wikimedia.org/wikipedia/commons/c/cb/Brachiosaurus_DB_flipped.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/3/3d/Allosaurus_Revised.jpg"
]

# Define the messages for the chat completion
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": img_urls[0]
                }
            },
            {
                "type": "image_url",
                "image_url": {
                    "url": img_urls[1]
                }
            },
            {
                "type": "text",
                "text": "what are the animals in the images?"
            }
        ]
    }
]

async def make_request():
    try:
        response = await client.chat.completions.create(
            model="microsoft/Phi-3.5-vision-instruct",
            max_tokens=32,
            temperature=0,
            messages=messages,
            n=2,
            extra_body={"use_beam_search": True}
        )
        for choice in response.choices:
            print(choice.message.content)

    except openai.BadRequestError as e:
        print(f"Error: {e.code}")

asyncio.run(make_request())

Verified the functionality of multi-image input handling and correct response generation using beam search with the above manual tests.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@FerdinandZhong
Copy link
Contributor Author

Hi @simon-mo @khluu could you please add me to your Buildkite org to unblock the full CI run?

@simon-mo
Copy link
Collaborator

Added your email to our buildkite org.

vllm/engine/protocol.py Outdated Show resolved Hide resolved
vllm/beam_search.py Show resolved Hide resolved
Comment on lines 76 to 78
tokenizer = await self.get_tokenizer()
self.input_preprocessor = InputPreprocessor(model_config,
self.tokenizer)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw why is this function defined inside a protocol? Perhaps we should move this to LLMEngine? Then we can make use of the existing input_preprocessor defined there. @youkaichao

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I made a mistake here, I just pushed another commit to change it from self.input_preprocessor and self.tokenizer to input_preprocessor and tokenizer.

By the way, in the current release v0.6.3, this function is inside AsyncLLMEngine and it has been moved here in a recent pr: 9296

@FerdinandZhong
Copy link
Contributor Author

Hi @DarkLight1337, thank you for the review!

I noticed the changes in PR #9473 and will merge the latest code once that PR is merged. Regarding the conflicts between the two PRs:

  • Logprobs: I'll align with the fix in #9473. In my PR, I directly prevent the return of logprobs when using beam search, as the number of logprobs in each step is determined by beam_width.

  • PromptType: I've kept it as the input type for handling multi-modality data. To correctly process the prompt passed to the function, I use inputpreprocess to parse the content, as suggested by @DarkLight1337. Additionaly, I also have prompt_text (parsed from the prompt) set as RequestOutput.prompt, which resolves the error related to PromptType being the input.

@njhill, could you please review these changes?

@njhill
Copy link
Member

njhill commented Oct 18, 2024

Thanks @FerdinandZhong, I'll review tomorrow (Friday US time)

@FerdinandZhong
Copy link
Contributor Author

FerdinandZhong commented Oct 20, 2024

Thanks @FerdinandZhong, I'll review tomorrow (Friday US time)

Thanks @njhill. May I know if changes look okay to you?

@njhill
Copy link
Member

njhill commented Oct 23, 2024

@FerdinandZhong sorry for the delay, bit behind with things.

The changes look good to me thanks. The InputProcessor change makes sense.

Re the logprobs, it's a good point that the number returned will be based on the beam width rather than how many are actually requested. I think we can improve this to request the max of these two and truncate as needed. But no need to change that for this PR.

I think we can improve the impl quite a bit more overall in some follow-on updates including:

  • Support most/all params.. I don't see any reason we can't, this should actually be easier with the "external" impl
  • Remove the separate beam search API, we can retain the function of the existing beam_search parameters, just have this layer intercept those. I'm not sure that we actually need separate BeamSearchParams.
  • Move the impl out of protocol.py .. at a minimum we can have it in beam_search.py and just call it from the abstract base class

@FerdinandZhong
Copy link
Contributor Author

@FerdinandZhong sorry for the delay, bit behind with things.

The changes look good to me thanks. The InputProcessor change makes sense.

Re the logprobs, it's a good point that the number returned will be based on the beam width rather than how many are actually requested. I think we can improve this to request the max of these two and truncate as needed. But no need to change that for this PR.

I think we can improve the impl quite a bit more overall in some follow-on updates including:

  • Support most/all params.. I don't see any reason we can't, this should actually be easier with the "external" impl
  • Remove the separate beam search API, we can retain the function of the existing beam_search parameters, just have this layer intercept those. I'm not sure that we actually need separate BeamSearchParams.
  • Move the impl out of protocol.py .. at a minimum we can have it in beam_search.py and just call it from the abstract base class

Hi @njhill, thank you for your comment!

I agree that taking the maximum of the beam_width and top_logprobs can be a good idea, and we can implement that change in the following PR. I'm also aligned with the action points you mentioned for improving beam search, and I'd be happy to collaborate on these enhancements moving forward.

@youkaichao
Copy link
Member

I'm surprised there are so many efforts for adding various features for beam-search ...

We will work for implementing beam search in another way so that all features for normal generation just works.

@FerdinandZhong
Copy link
Contributor Author

Hi @youkaichao, thank you for the feedback. I understand your concerns, and I agree that, in the long run, beam search can be properly designed and implemented. In the short term, I’m happy to continue providing feedback and contributing commits from a user’s perspective.

In the meantime, @DarkLight1337, could we consider merging this PR first, as it addresses the fix for #9577?

@DarkLight1337
Copy link
Member

DarkLight1337 commented Oct 24, 2024

@njhill do we need to rebuild this? Seem that you cancelled some of the tests. Nevermind, looks like some containers died, I'll just rerun those tests.

@mergify mergify bot added the frontend label Oct 29, 2024
@FerdinandZhong
Copy link
Contributor Author

Hi @DarkLight1337 , I’ve merged the latest main branch and rerun the tests from my end. May I ask for your advice on the next steps to do with this PR for merging? Additionally, can I check with you if rebase is needed to add "sign-off" for each commit? Thank you.

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) October 29, 2024 10:04
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Oct 29, 2024
@DarkLight1337
Copy link
Member

Sorry for missing this! You don't have to worry about signing-off for this PR as we can manually set that to pass.

@DarkLight1337
Copy link
Member

I have enabled auto-merge which should run all the tests and merge if they pass.

@FerdinandZhong
Copy link
Contributor Author

@DarkLight1337 got it, thanks!

@DarkLight1337 DarkLight1337 merged commit ef7865b into vllm-project:main Oct 29, 2024
75 checks passed
rasmith pushed a commit to rasmith/vllm that referenced this pull request Oct 30, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Oct 31, 2024
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Nov 4, 2024
JC1DA pushed a commit to JC1DA/vllm that referenced this pull request Nov 11, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
frontend ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Usage]: beam_search not work with multimodal input
5 participants