
Commit d9b8060

[Hardware][Intel] Isolate CPUModelRunner and ModelRunner for better maintenance (vllm-project#3824)

bigPYJ1151 authored Apr 11, 2024
1 parent 1bf454a commit d9b8060
Showing 5 changed files with 443 additions and 61 deletions.
72 changes: 24 additions & 48 deletions vllm/attention/backends/torch_sdpa.py
@@ -50,20 +50,15 @@ def copy_blocks(
 
 
 @dataclass
-class TorchSDPAMetadata(AttentionMetadataPerStage, PagedAttentionMetadata):
+class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata,
+                        AttentionMetadataPerStage):
     """Metadata for TorchSDPABackend.
     """
     # Currently, input sequences can only contain all prompts
     # or all decoding. True if all sequences are prompts.
     is_prompt: bool
+    slot_mapping: torch.Tensor
     prompt_lens: Optional[List[int]]
-    prompt_lens_tensor: Optional[torch.Tensor]
-
-    max_subquery_len: Optional[int] = None
-    max_prompt_len: Optional[int] = None
-    subquery_start_loc: Optional[torch.Tensor] = None
-    seq_start_loc: Optional[torch.Tensor] = None
-    use_cuda_graph: bool = False
 
     def __post_init__(self):
         # Set during the execution of the first attention op.
@@ -111,7 +106,7 @@ def forward(
         key: torch.Tensor,
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
-        attn_metadata: AttentionMetadata[TorchSDPAMetadata],
+        attn_metadata: TorchSDPAMetadata,
         kv_scale: float,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
@@ -140,51 +135,36 @@ def forward(
                                                 attn_metadata.kv_cache_dtype,
                                                 kv_scale)
 
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
-        output = torch.empty_like(query)
-        # Query for decode. KV is not needed because it is already cached.
-        decode_query = query[num_prefill_tokens:]
-        # QKV for prefill.
-        query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens
-        assert decode_query.shape[0] == num_decode_tokens
-
-        if prefill_meta := attn_metadata.prefill_metadata:
-            if (kv_cache is None or prefill_meta.block_tables.numel() == 0):
+        if attn_metadata.is_prompt:
+            if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
                 if self.num_kv_heads != self.num_heads:
                     key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
                     value = value.repeat_interleave(self.num_queries_per_kv,
                                                     dim=1)
 
-                if prefill_meta.attn_bias is None:
+                if attn_metadata.attn_bias is None:
                     if self.alibi_slopes is not None:
                         att_masks = _make_alibi_bias(
                             self.alibi_slopes, query.dtype,
-                            prefill_meta.prompt_lens)  # type: ignore
+                            attn_metadata.prompt_lens)  # type: ignore
                     elif self.sliding_window is not None:
                         att_masks = _make_sliding_window_bias(
-                            prefill_meta.prompt_lens, self.sliding_window,
+                            attn_metadata.prompt_lens, self.sliding_window,
                             query.dtype)  # type: ignore
                     else:
-                        att_masks = [None] * len(prefill_meta.prompt_lens)
-                    prefill_meta.attn_bias = att_masks
+                        att_masks = [None] * len(attn_metadata.prompt_lens)
+                    attn_metadata.attn_bias = att_masks
 
                 query = query.movedim(0, query.dim() - 2)
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
 
                 start = 0
-                out = torch.empty((num_tokens, self.num_heads, self.head_size),
-                                  dtype=query.dtype)
-                for prompt_len, mask in zip(prefill_meta.prompt_lens,
-                                            prefill_meta.attn_bias):
+                output = torch.empty(
+                    (num_tokens, self.num_heads, self.head_size),
+                    dtype=query.dtype)
+                for prompt_len, mask in zip(attn_metadata.prompt_lens,
+                                            attn_metadata.attn_bias):
                     end = start + prompt_len
                     sub_out = scaled_dot_product_attention(
                         query[:, start:end, :],
@@ -194,32 +174,28 @@ def forward(
                         dropout_p=0.0,
                         is_causal=not self.need_mask,
                         scale=self.scale).movedim(query.dim() - 2, 0)
-                    out[start:end, :, :] = sub_out
+                    output[start:end, :, :] = sub_out
                     start = end
-                assert out.shape == output[:num_prefill_tokens].shape
-                output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
                 raise RuntimeError(
                     "Torch SDPA backend doesn't support prefix decoding.")
 
-        if decode_meta := attn_metadata.decode_metadata:
+        else:
             # Decoding run.
-            out = PagedAttention.forward_decode(
-                decode_query,
+            output = PagedAttention.forward_decode(
+                query,
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.context_lens,
-                decode_meta.max_context_len,
+                attn_metadata.block_tables,
+                attn_metadata.context_lens,
+                attn_metadata.max_context_len,
                 attn_metadata.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
                 kv_scale,
             )
-            assert out.shape == output[num_prefill_tokens:].shape
-            output[num_prefill_tokens:]
 
         # Reshape the output tensor.
         return output.view(-1, self.num_heads * self.head_size)
@@ -241,7 +217,7 @@ def _make_alibi_bias(
     bias = bias[None, :] - bias[:, None]
 
     num_heads = alibi_slopes.shape[0]
-    bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
+    bias = bias[None, :].repeat((num_heads, 1, 1))
     bias.mul_(alibi_slopes[:, None, None])
     inf_mask = torch.empty(
         (1, prompt_len, prompt_len),
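
One reading of the expand() -> repeat() change in the hunk above (the motivation is not stated in the diff itself): the bias is scaled in place by bias.mul_() right after it is built, and PyTorch refuses in-place writes to the overlapping broadcast view that expand() returns, while repeat() gives every head its own copy. A small self-contained sketch of that difference, with toy sizes and variable names chosen here for illustration only:

import torch

prompt_len, num_heads = 4, 2
bias = torch.arange(prompt_len, dtype=torch.float32)
bias = bias[None, :] - bias[:, None]              # (prompt_len, prompt_len)
alibi_slopes = torch.tensor([0.5, 0.25])

# repeat() materializes an independent copy per head, so in-place scaling is safe.
repeated = bias[None, :].repeat((num_heads, 1, 1))
repeated.mul_(alibi_slopes[:, None, None])

# expand() returns a broadcast view with shared storage; the same in-place
# mul_() on it raises a RuntimeError about overlapping memory.
expanded = bias[None, :].expand(num_heads, prompt_len, prompt_len)
# expanded.mul_(alibi_slopes[:, None, None])      # would raise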
10 changes: 10 additions & 0 deletions vllm/executor/cpu_executor.py
@@ -25,6 +25,7 @@ def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
         assert lora_config is None, "cpu backend doesn't support LoRA"
         model_config = _verify_and_get_model_config(model_config)
         cache_config = _verify_and_get_cache_config(cache_config)
+        scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
 
         self.model_config = model_config
         self.cache_config = cache_config
@@ -116,6 +117,15 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
     return config
 
 
+def _verify_and_get_scheduler_config(
+        config: SchedulerConfig) -> SchedulerConfig:
+    if config.chunked_prefill_enabled:
+        logger.warning("Chunked prefill is not supported on CPU, disable it.")
+        config.chunked_prefill_enabled = False
+
+    return config
+
+
 def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
     _GB = 1 << 30
     if config.enable_prefix_caching:
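
The new _verify_and_get_scheduler_config follows the same pattern as the existing model- and cache-config checks: when chunked prefill is requested, the CPU executor logs a warning and clears the flag before anything is scheduled. A rough usage sketch; the SchedulerConfig constructor keyword arguments shown below are assumptions for illustration and may differ between vLLM versions:

from vllm.config import SchedulerConfig
from vllm.executor.cpu_executor import _verify_and_get_scheduler_config

# Assumed keyword arguments; only the chunked-prefill flag matters for this check.
scheduler_config = SchedulerConfig(max_num_batched_tokens=4096,
                                   max_num_seqs=16,
                                   max_model_len=4096,
                                   enable_chunked_prefill=True)

scheduler_config = _verify_and_get_scheduler_config(scheduler_config)
assert scheduler_config.chunked_prefill_enabled is False  # disabled after the warning is logged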
1 change: 0 additions & 1 deletion vllm/utils.py
@@ -372,7 +372,6 @@ def is_pin_memory_available() -> bool:
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
     elif is_cpu():
-        print_warning_once("Pin memory is not supported on CPU.")
         return False
     return True
 
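
For context on the utils.py change: is_pin_memory_available() is typically consulted as a gate when allocating host-side buffers, so on the CPU backend callers simply skip pinning, and after this commit the helper no longer emits a CPU-specific warning. A minimal sketch of that calling pattern (the buffer itself is illustrative, not taken from vLLM):

import torch
from vllm.utils import is_pin_memory_available

# False on the CPU backend; the CPU warning was removed in this commit.
pin_memory = is_pin_memory_available()
host_buffer = torch.empty(1024, dtype=torch.float32, pin_memory=pin_memory)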
