Skip to content

Commit

Permalink
Separate model runner
Browse files Browse the repository at this point in the history
  • Loading branch information
bigPYJ1151 committed Apr 3, 2024
1 parent 77a6572 commit 24a5a18
Show file tree
Hide file tree
Showing 4 changed files with 409 additions and 16 deletions.
6 changes: 3 additions & 3 deletions vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def forward(
attn_metadata.kv_cache_dtype)

if attn_metadata.is_prompt:
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if (kv_cache is None or attn_metadata.block_tables is None):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
value = value.repeat_interleave(self.num_queries_per_kv,
Expand Down Expand Up @@ -221,8 +221,8 @@ def _make_alibi_bias(
bias = bias[None, :] - bias[:, None]

num_heads = alibi_slopes.shape[0]
bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)
bias.mul_(alibi_slopes[:, None, None])
bias = bias[None, :].expand(num_heads, prompt_len, prompt_len)\
.mul(alibi_slopes[:, None, None])
inf_mask = torch.empty(
(1, prompt_len, prompt_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
Expand Down
1 change: 0 additions & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ def is_pin_memory_available() -> bool:
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif is_cpu():
print_warning_once("Pin memory is not supported on CPU.")
return False
return True

Expand Down
Loading

0 comments on commit 24a5a18

Please sign in to comment.