Skip to content

Commit

Permalink
[Bugfix][VLM] Fix Fuyu batching inference with max_num_seqs>1 (vllm…
Browse files Browse the repository at this point in the history
  • Loading branch information
Isotr0py authored and liuyanyi committed Oct 6, 2024
1 parent c387943 commit 83c3bd1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 20 deletions.
6 changes: 2 additions & 4 deletions tests/models/decoder_only/vision_language/test_fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def run_test(

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=2560,
max_num_seqs=1,
max_model_len=2048,
max_num_seqs=2,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
Expand All @@ -80,8 +80,6 @@ def run_test(
]

with hf_runner(model, dtype=dtype) as hf_model:
hf_model.model.get_output_embeddings = lambda: \
hf_model.model.language_model.get_output_embeddings()
eos_token_id = hf_model.processor.tokenizer.eos_token_id
hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
Expand Down
51 changes: 35 additions & 16 deletions vllm/model_executor/models/fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
SequenceData)

from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
from .utils import flatten_bn, merge_multimodal_embeddings

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
Expand Down Expand Up @@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
model_config.model)

model_image_input = _fuyu_image_preprocess(image_processor, image_data)
image_patches = torch.stack([
image_patches = torch.cat([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
Expand Down Expand Up @@ -210,7 +210,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
])

# image has been processed with prompt in input processor
return MultiModalInputs({"image_patches": data})
return MultiModalInputs({"pixel_values": data})


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
Expand Down Expand Up @@ -242,23 +242,42 @@ def __init__(self,
cache_config=cache_config,
quant_config=quant_config)

def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:

h = w = self.config.patch_size
num_channels = self.config.num_channels
expected_dims = num_channels * h * w

def _validate_shape(d: torch.Tensor):
actual_dims = d.size(-1)

if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
"The expected shape of pixel values per image per batch "
f" per patch is {expected_expr}. "
f"You supplied {tuple(d.shape)}.")

for d in data:
_validate_shape(d)

return data.to(self.vision_embed_tokens.weight.dtype)

def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
image_patches = kwargs.pop("image_patches", None)
pixel_values = kwargs.pop("pixel_values", None)

if isinstance(image_patches, torch.Tensor):
# Remove the N dimension until multiple images are supported.
image_patches = image_patches.squeeze(1)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(pixel_values)}")

return FuyuImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
)

expected_feature_size = self.image_feature_size
if image_patches.size(-1) != expected_feature_size:
raise ValueError(
f"Expected image patches to have the last dimension of "
f"{expected_feature_size}, got {image_patches.size(-1)}")
image_patches = image_patches.to(
self.vision_embed_tokens.weight.dtype)
return FuyuImagePixelInputs(type="pixel_values",
data=image_patches)
return None

def _process_image_input(
Expand Down

0 comments on commit 83c3bd1

Please sign in to comment.