Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Add PaliGemma #5189

Merged
merged 41 commits into from
Jul 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
e6352e5
initial
ywang96 Jun 2, 2024
7dfbe44
remove lm head
ywang96 Jun 2, 2024
3fd77fe
Merge branch 'main' into paligemma
ywang96 Jun 7, 2024
ccb0f25
Merge branch 'main' into paligemma
ywang96 Jun 8, 2024
9b5269d
update tests
ywang96 Jun 9, 2024
af11afa
fix test
ywang96 Jun 9, 2024
a465e85
format
ywang96 Jun 9, 2024
3e9a12b
fix model loading
ywang96 Jun 9, 2024
c734a17
fix input args
ywang96 Jun 9, 2024
2d7de4d
fix model loading
ywang96 Jun 9, 2024
2f65bf7
add embedding method to gemma
ywang96 Jun 9, 2024
04e4ace
fix linear output
ywang96 Jun 9, 2024
4a9551d
update gemma forward
ywang96 Jun 9, 2024
6fd10f1
update
ywang96 Jun 9, 2024
d08db94
fix test
ywang96 Jun 9, 2024
e325630
remove extra bos
ywang96 Jun 10, 2024
cbb7c49
format
ywang96 Jun 10, 2024
7ea7265
add gemma to model test
ywang96 Jun 10, 2024
9a8cd85
try normal caption
ywang96 Jun 11, 2024
9069831
Merge branch 'main' into paligemma
ywang96 Jun 12, 2024
7cb1cbb
Merge branch 'main' into paligemma
ywang96 Jun 25, 2024
7db6122
[Model] Add Gemma 2
WoosukKwon Jun 27, 2024
df2c007
Remove supports_lora=True
WoosukKwon Jun 27, 2024
9ba7aac
[Bugfix] Fix precision issues in Gemma 1
WoosukKwon Jun 27, 2024
6b32a1e
Minor
WoosukKwon Jun 27, 2024
6bfba0a
Merge branch 'main' into woosuk-gemma1
WoosukKwon Jun 27, 2024
bdf9334
Merge branch 'woosuk-gemma1' of https://github.com/vllm-project/vllm …
WoosukKwon Jun 27, 2024
524db49
Merge branch 'main' into woosuk-gemma1
WoosukKwon Jun 28, 2024
7e6f0fd
Merge branch 'main' into paligemma
ywang96 Jun 28, 2024
50ae420
Merge remote-tracking branch 'upstream/woosuk-gemma1' into paligemma
ywang96 Jun 28, 2024
e0828b0
Merge branch 'main' into paligemma
ywang96 Jul 5, 2024
c4fa37f
update paligemma
ywang96 Jul 6, 2024
b09066e
update test
ywang96 Jul 6, 2024
bf4bb58
update
ywang96 Jul 6, 2024
c1b9ebf
update
ywang96 Jul 6, 2024
5c0d2ec
add model to doc
ywang96 Jul 6, 2024
0b76ac1
address comments
ywang96 Jul 6, 2024
4823852
fix eos
ywang96 Jul 6, 2024
02b7c21
Merge branch 'main' into paligemma
ywang96 Jul 6, 2024
1651b15
move doc
ywang96 Jul 6, 2024
2f94007
Update docs/source/models/supported_models.rst
ywang96 Jul 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ Vision Language Models
- LLaVA-NeXT
- :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc.
-
* - :code:`PaliGemmaForConditionalGeneration`
- PaliGemma
- :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc.
-
* - :code:`Phi3VForCausalLM`
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
Expand Down
52 changes: 52 additions & 0 deletions examples/paligemma_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
import subprocess

from PIL import Image

from vllm import LLM

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them


def run_paligemma():
llm = LLM(model="google/paligemma-3b-mix-224")

prompt = "caption es"

image = Image.open("images/stop_sign.jpg")

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
},
})

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def main():
run_paligemma()


if __name__ == "__main__":
# Download from s3
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"

# Make sure the local directory exists or create it
os.makedirs(local_directory, exist_ok=True)

# Use AWS CLI to sync the directory, assume anonymous access
subprocess.check_call([
"aws",
"s3",
"sync",
s3_bucket_path,
local_directory,
"--no-sign-request",
])
main()
147 changes: 147 additions & 0 deletions tests/models/test_paligemma.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this test is also a bit redundant with test_llava.py. Can we refactor test_llava.py to cover both models?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment above.

Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from typing import List, Optional, Tuple, Type

import pytest
from transformers import AutoTokenizer

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign": "caption es",
"cherry_blossom": "What is in the picture?",
"boardwalk": "What is in the picture?",
})

IMAGE_TOKEN_ID = 257152

models = ["google/paligemma-3b-mix-224"]


def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
model: str):
"""Sanitize vllm output to be comparable with hf output."""
output_ids, output_str, out_logprobs = vllm_output

tokenizer = AutoTokenizer.from_pretrained(model)
eos_token_id = tokenizer.eos_token_id

hf_output_ids = [
token_id for idx, token_id in enumerate(output_ids)
if token_id != IMAGE_TOKEN_ID or output_ids[idx - 1] != IMAGE_TOKEN_ID
]

hf_output_str = output_str

if hf_output_ids[-1] == eos_token_id:
hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

return hf_output_ids, hf_output_str, out_logprobs


def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.

All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]

with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=[
vllm_to_hf_output(vllm_output, model)
for vllm_output in vllm_outputs
],
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PaliGemmaForConditionalGeneration":
("paligemma", "PaliGemmaForConditionalGeneration"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
Expand Down
10 changes: 8 additions & 2 deletions vllm/model_executor/models/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,16 +268,22 @@ def __init__(
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer

residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
Expand Down
Loading
Loading