
Commit 4edbd2f

[Model][OpenVINO] Fix regressions from vllm-project#8346 (vllm-project#10045)

Signed-off-by: Peter Salas <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
petersalas authored and sumitd2 committed Nov 14, 2024
1 parent 7002c64 commit 4edbd2f
Showing 3 changed files with 15 additions and 5 deletions.

.buildkite/run-openvino-test.sh (2 changes: 1 addition & 1 deletion)

@@ -11,4 +11,4 @@ trap remove_docker_container EXIT
 remove_docker_container
 
 # Run the image and launch offline inference
-docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/vllm/examples/offline_inference.py
+docker run --network host --env VLLM_OPENVINO_KVCACHE_SPACE=1 --name openvino-test openvino-test python3 /workspace/examples/offline_inference.py

vllm/attention/backends/openvino.py (12 changes: 11 additions & 1 deletion)

@@ -1,12 +1,13 @@
 from dataclasses import dataclass
-from typing import List, Tuple, Type
+from typing import Dict, List, Optional, Tuple, Type
 
 import openvino as ov
 import torch
 
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.multimodal import MultiModalPlaceholderMap
 
 
 def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor,

@@ -128,3 +129,12 @@ class OpenVINOAttentionMetadata:
     # Shape: scalar
     # Type: i32
     max_context_len: torch.Tensor
+
+    # The index maps that relate multi-modal embeddings to the corresponding
+    # placeholders.
+    #
+    # N.B. These aren't really related to attention and don't belong on this
+    # type -- this is just a temporary solution to make them available to
+    # `model_executable`.
+    multi_modal_placeholder_index_maps: Optional[Dict[
+        str, MultiModalPlaceholderMap.IndexMap]]
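
The new field carries, for each modality, an index map relating rows of the
flattened multi-modal embedding tensor to the placeholder positions in the
token sequence. A minimal sketch of how a consumer such as `model_executable`
might apply it; the helper name, the `src`/`dest` field semantics, and the
tensor shapes are illustrative assumptions, not vLLM's actual implementation:

from typing import Dict, Optional

import torch

from vllm.multimodal import MultiModalPlaceholderMap


def scatter_mm_embeddings(
    inputs_embeds: torch.Tensor,         # [num_tokens, hidden_size]
    mm_embeds: Dict[str, torch.Tensor],  # modality -> [num_mm_tokens, hidden_size]
    index_maps: Optional[Dict[str, MultiModalPlaceholderMap.IndexMap]],
) -> torch.Tensor:
    # Hypothetical helper, not vLLM's code: each IndexMap is assumed to
    # hold parallel integer lists -- `src` selects rows of the modality's
    # embedding tensor, `dest` the placeholder positions they replace.
    if not index_maps:
        return inputs_embeds
    for modality, index_map in index_maps.items():
        inputs_embeds[index_map.dest] = mm_embeds[modality][index_map.src]
    return inputs_embeds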

vllm/model_executor/models/molmo.py (6 changes: 3 additions & 3 deletions)

@@ -21,8 +21,8 @@
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather)
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
-                         token_inputs)
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
+                         InputContext, token_inputs)
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm

@@ -915,7 +915,7 @@ def dummy_data_for_molmo(ctx: InputContext, seq_len: int,
     if "image_masks" in out:
         dummy_imgdata["image_masks"] = out["image_masks"]
     dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long)
-    return dummy_seqdata, {"image": dummy_imgdata}
+    return DummyData(dummy_seqdata, {"image": dummy_imgdata})
 
 
 def pad_images(
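
For context on the second hunk: `dummy_data_for_molmo` previously returned a
bare `(seq_data, multi_modal_data)` tuple, and the input registry now expects
a `DummyData` wrapper (hence the new import in the first hunk). A minimal
sketch of that shape, assuming a NamedTuple-style definition whose field names
are inferred from this diff rather than copied from the vLLM source:

from typing import NamedTuple, Optional


class DummyData(NamedTuple):
    # Illustrative stand-in for vllm.inputs.DummyData; in vLLM, seq_data
    # would be a SequenceData instance and the dicts have stricter types.
    seq_data: object
    multi_modal_data: Optional[dict] = None
    multi_modal_placeholders: Optional[dict] = None


# Positional construction keeps the call site close to the old tuple form
# while letting consumers access the parts by name:
data = DummyData("dummy_seqdata", {"image": {"seq_len": 123}})
print(data.multi_modal_data)  # {'image': {'seq_len': 123}}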
