[Hardware][CPU] Enable mrope and support Qwen2-VL on CPU backend #8770

Merged: 9 commits, Sep 25, 2024

Changes from all commits
16 changes: 16 additions & 0 deletions vllm/model_executor/models/qwen2_vl.py
@@ -67,6 +67,7 @@
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu

from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory)
@@ -281,6 +282,21 @@ def forward(
            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif is_cpu():
            seq_length = q.size(1)
            q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
            attention_mask = torch.zeros([1, seq_length, seq_length],
                                         device=q.device,
                                         dtype=torch.bool)
            for i in range(1, len(cu_seqlens)):
                attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
                               cu_seqlens[i - 1]:cu_seqlens[i]] = True
            output = F.scaled_dot_product_attention(q,
                                                    k,
                                                    v,
                                                    attention_mask,
                                                    dropout_p=0.0)
            context_layer = rearrange(output, "b h s d -> b s h d")
        else:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask
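Note on the CPU branch above: it emulates xformers' `BlockDiagonalMask` with a boolean mask so tokens attend only within their own packed sequence, then calls PyTorch's `scaled_dot_product_attention` (a boolean mask means "True = may attend"). A minimal standalone sketch of the mask construction, using a toy `cu_seqlens` rather than real model data:

```python
import torch

# Two sequences of lengths 3 and 2 packed into one 5-token batch;
# cu_seqlens holds the cumulative boundaries, as in the diff above.
cu_seqlens = torch.tensor([0, 3, 5])
seq_length = int(cu_seqlens[-1])

mask = torch.zeros([1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    # Allow attention only inside each [start, end) block.
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
         cu_seqlens[i - 1]:cu_seqlens[i]] = True

print(mask.int())
# tensor([[[1, 1, 1, 0, 0],
#          [1, 1, 1, 0, 0],
#          [1, 1, 1, 0, 0],
#          [0, 0, 0, 1, 1],
#          [0, 0, 0, 1, 1]]])
```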
92 changes: 83 additions & 9 deletions vllm/worker/cpu_model_runner.py
@@ -12,11 +12,13 @@
                         SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -145,6 +147,38 @@ def build(self) -> ModelInputForCPU:
            query_lens=seq_lens,
        )

    def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data,
                                   computed_len: int):
        mm_kwargs = self.multi_modal_input_mapper(mm_data)

        # Special processing for mrope position deltas.
        mrope_positions = None
        if self.runner.model_is_mrope:
            image_grid_thw = mm_kwargs.get("image_grid_thw", None)
            video_grid_thw = mm_kwargs.get("video_grid_thw", None)
            assert image_grid_thw is not None or video_grid_thw is not None, (
                "mrope embedding type requires the multi-modal input mapper "
                "to return 'image_grid_thw' or 'video_grid_thw'.")

            hf_config = self.runner.model_config.hf_config
            token_ids = seq_data.get_token_ids()

            mrope_positions, mrope_position_delta = \
                MRotaryEmbedding.get_input_positions(
                    token_ids,
                    image_grid_thw=image_grid_thw,
                    video_grid_thw=video_grid_thw,
                    image_token_id=hf_config.image_token_id,
                    video_token_id=hf_config.video_token_id,
                    vision_start_token_id=hf_config.vision_start_token_id,
                    vision_end_token_id=hf_config.vision_end_token_id,
                    spatial_merge_size=hf_config.vision_config.
                    spatial_merge_size,
                    context_len=computed_len,
                )
            seq_data.mrope_position_delta = mrope_position_delta
        return mm_kwargs, mrope_positions

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
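For context on `_compute_multi_modal_input` above: as used in this PR, `get_input_positions` returns one position stream per mrope axis (temporal, height, width), so `mrope_positions` is a 3 × num_tokens structure, and `mrope_position_delta` records how far the 3-D positions drift from plain 1-D positions so it can be reused at decode time. A toy illustration of the extend pattern used below, with made-up position values (not produced by vLLM):

```python
# Made-up values: the first two tokens are text (all three axes advance
# in lockstep); the last three are image tokens sharing one temporal
# position while fanning out over height/width.
mrope_positions = [
    [0, 1, 2, 2, 2],  # temporal axis
    [0, 1, 2, 2, 3],  # height axis
    [0, 1, 2, 3, 2],  # width axis
]

input_mrope_positions: list[list[int]] = [[] for _ in range(3)]
for idx in range(3):
    input_mrope_positions[idx].extend(mrope_positions[idx])

assert input_mrope_positions == mrope_positions
```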
@@ -153,6 +187,8 @@ def _prepare_prompt(
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]

        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_inputs_list: List[MultiModalInputs] = []
@@ -171,14 +207,20 @@
            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            mrope_positions = None
            if (mm_data := seq_group_metadata.multi_modal_data):
                mm_kwargs, mrope_positions = self._compute_multi_modal_input(
                    seq_data, mm_data, computed_len)
                multi_modal_inputs_list.append(mm_kwargs)

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            if mrope_positions:
                for idx in range(3):
                    input_mrope_positions[idx].extend(mrope_positions[idx])
            else:
                input_positions.extend(list(range(computed_len, seq_len)))

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
@@ -202,12 +244,18 @@
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        if any(input_mrope_positions):
            input_positions = None  # type: ignore
        else:
            input_mrope_positions = None  # type: ignore

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions
                                       or input_mrope_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
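A note on the `input_positions or input_mrope_positions` pattern above: after the toggle, exactly one of the two lists is None, so `or` picks whichever survived. The resulting tensor is 1-D ([num_tokens]) for classic rope and 2-D ([3, num_tokens]) for mrope. A self-contained sketch of that behavior (hypothetical helper, not part of the diff):

```python
import torch

def positions_tensor(input_positions, input_mrope_positions):
    # Exactly one argument is expected to be None, mirroring the toggle
    # in _prepare_prompt/_prepare_decode above.
    return torch.tensor(input_positions or input_mrope_positions,
                        dtype=torch.long)

print(positions_tensor([0, 1, 2], None).shape)                 # torch.Size([3])
print(positions_tensor(None, [[0, 1], [0, 1], [0, 1]]).shape)  # torch.Size([3, 2])
```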
@@ -238,6 +286,7 @@ def _prepare_decode(
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        input_mrope_positions: List[List[int]] = [[] for _ in range(3)]
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []
@@ -255,7 +304,17 @@

                seq_len = seq_data.get_len()
                position = seq_len - 1
                if seq_data.mrope_position_delta is not None:
                    context_len = seq_data.get_num_computed_tokens()
                    next_pos = MRotaryEmbedding.get_next_input_positions(
                        seq_data.mrope_position_delta,
                        context_len,
                        seq_len,
                    )
                    for idx in range(3):
                        input_mrope_positions[idx].extend(next_pos[idx])
                else:
                    input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
@@ -273,12 +332,18 @@
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        if any(input_mrope_positions):
            input_positions = None  # type: ignore
        else:
            input_mrope_positions = None  # type: ignore

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions
                                       or input_mrope_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
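At decode time, the `mrope_position_delta` stored during the prompt phase lets the runner extend all three position streams without re-running the vision-aware position computation. A sketch of the expected shape of `get_next_input_positions`' output, assuming it offsets the 1-D position by the stored delta on every axis (hypothetical numbers):

```python
# 10 tokens already computed, one new token to decode, and a stored
# delta of -4 carried over from the prompt phase.
context_len, seq_len, mrope_position_delta = 10, 11, -4

next_pos = [
    list(range(context_len + mrope_position_delta,
               seq_len + mrope_position_delta))
    for _ in range(3)
]
assert next_pos == [[6], [6], [6]]  # one position per axis for the new token
```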
@@ -373,6 +438,15 @@ def __init__(
            raise NotImplementedError(
                STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU'])

    @property
    def model_is_mrope(self) -> bool:
        """Detect whether the model uses the "mrope" rope_scaling type.
        mrope requires keeping "rope_deltas" between the prompt and decoding
        phases."""
        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"

    def load_model(self) -> None:
        self.model = get_model(model_config=self.model_config,
                               load_config=self.load_config,
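Finally, the `model_is_mrope` property above keys off the HF config's `rope_scaling` dict. A minimal sketch of the check against a Qwen2-VL-style config (the dicts shown are illustrative, not copied from any model's config.json):

```python
# Illustrative rope_scaling values; Qwen2-VL ships a "type": "mrope" entry.
qwen2_vl_like = {"type": "mrope", "mrope_section": [16, 24, 24]}
llama_like = None  # many models have no rope_scaling at all

def is_mrope(rope_scaling) -> bool:
    # Mirrors the property above: tolerate a missing or None dict.
    if rope_scaling is None:
        return False
    return rope_scaling.get("type", None) == "mrope"

assert is_mrope(qwen2_vl_like)
assert not is_mrope(llama_like)
```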