[Misc][XPU] Upgrade to Pytorch 2.5 for xpu backend (vllm-project#9823)
Signed-off-by: Kunshang Ji <[email protected]>
Signed-off-by: yan ma <[email protected]>
Co-authored-by: Kunshang Ji <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
2 people authored and JC1DA committed Nov 11, 2024
1 parent 3b59c00 commit 93968ea
Showing 4 changed files with 43 additions and 46 deletions.
12 changes: 11 additions & 1 deletion Dockerfile.xpu
@@ -30,9 +30,19 @@ COPY requirements-common.txt /workspace/vllm/requirements-common.txt

RUN --mount=type=cache,target=/root/.cache/pip \
    pip install --no-cache-dir \
    --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \
    -r requirements-xpu.txt

RUN git clone https://github.com/intel/pti-gpu && \
    cd pti-gpu/sdk && \
    git checkout 6c491f07a777ed872c2654ca9942f1d0dde0a082 && \
    mkdir build && \
    cd build && \
    cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \
    make -j && \
    cmake --install . --config Release --prefix "/usr/local"

ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/"

COPY . .
ARG GIT_REPO_CHECK
RUN --mount=type=bind,source=.git,target=.git \
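As a quick check that an image built from this Dockerfile picked up the upgraded stack, the sketch below (not part of the commit) imports the pinned packages and queries XPU availability; it assumes the container exposes an Intel GPU. The PTI SDK built above appears to supply the Level Zero tracing libraries that PyTorch 2.5's XPU profiling path loads at runtime, which is why /usr/local/lib is appended to LD_LIBRARY_PATH.

# Minimal sanity check for the upgraded XPU stack (assumes it runs inside the
# image built from Dockerfile.xpu with an Intel GPU visible to the container).
import torch
import intel_extension_for_pytorch as ipex

print("torch:", torch.__version__)              # expect a 2.5.x build
print("ipex:", ipex.__version__)                # expect a 2.5.x+xpu build
print("xpu available:", torch.xpu.is_available())
print("xpu devices:", torch.xpu.device_count())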
8 changes: 4 additions & 4 deletions requirements-xpu.txt
@@ -8,9 +8,9 @@ packaging
setuptools-scm>=8
wheel
jinja2
# Following pkgs retrieved from https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
torch == 2.3.1+cxx11.abi
intel-extension-for-pytorch == 2.3.110+xpu
oneccl_bind_pt == 2.3.100+xpu

torch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/torch-2.5.0a0%2Bgite84e33f-cp310-cp310-linux_x86_64.whl
intel-extension-for-pytorch @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/intel_extension_for_pytorch-2.5.10%2Bgit9d489a8-cp310-cp310-linux_x86_64.whl
oneccl_bind_pt @ https://intel-extension-for-pytorch.s3.us-east-1.amazonaws.com/ipex_dev/xpu/oneccl_bind_pt-2.5.0%2Bxpu-cp310-cp310-linux_x86_64.whl

triton-xpu == 3.0.0b1
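The pins above replace the 2.3-series packages from the Intel extension index with 2.5 pre-release wheels fetched by direct URL. The sketch below (not from the repository) smoke-tests the oneccl_bind_pt wheel by initializing the "ccl" torch.distributed backend used for distributed runs on XPU; the single-process world size and the rendezvous address are placeholders for illustration, and an XPU device is assumed to be present.

# Importing oneccl_bindings_for_pytorch registers the "ccl" backend with
# torch.distributed; world_size=1 keeps this a one-process smoke test.
import os

import torch
import torch.distributed as dist
import oneccl_bindings_for_pytorch  # noqa: F401

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="ccl", rank=0, world_size=1)

t = torch.ones(4, device="xpu")
dist.all_reduce(t)  # trivial at world_size=1, but exercises the oneCCL path
print(t)
dist.destroy_process_group()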
33 changes: 8 additions & 25 deletions vllm/_ipex_ops.py
@@ -74,20 +74,12 @@ def paged_attention_v1(
assert kv_cache_dtype == "auto"
num_heads = out.size(1)
num_queries_per_tokens = num_heads // num_kv_heads
head_mapping = torch.arange(
0,
num_kv_heads,
device=query.device,
dtype=torch.int32,
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v1( # type: ignore
ipex.llm.modules.PagedAttention.single_query_kv_attention(
out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
num_queries_per_tokens,
scale,
block_tables,
context_lens,
@@ -124,26 +116,15 @@ def paged_attention_v2(
assert kv_cache_dtype == "auto"
num_heads = out.size(1)
num_queries_per_tokens = num_heads // num_kv_heads
head_mapping = torch.arange(
0,
num_kv_heads,
dtype=torch.int32,
device=query.device,
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v2( # type: ignore
ipex.llm.modules.PagedAttention.single_query_kv_attention(
out,
exp_sum,
max_logits,
tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
num_queries_per_tokens,
scale,
block_tables,
context_lens,
scale,
block_size,
max_context_len,
alibi_slopes,
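In both hunks above, the deleted head_mapping block built an explicit query-head-to-KV-head index for grouped-query attention, which the old torch.xpu kernels required; the IPEX 2.5 entry point takes num_queries_per_tokens instead and presumably derives that mapping internally. The standalone snippet below (not vLLM code) reproduces what the removed lines computed.

# With 8 query heads and 2 KV heads, query head i reads KV head i // 4.
import torch

num_heads, num_kv_heads = 8, 2
num_queries_per_tokens = num_heads // num_kv_heads  # 4

head_mapping = torch.arange(
    0,
    num_kv_heads,
    dtype=torch.int32,
).view(num_kv_heads, 1).repeat_interleave(num_queries_per_tokens).flatten()

print(head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)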
@@ -202,6 +183,7 @@ def varlen_attention(
is_causal: bool,
return_softmax: bool,
gen_: torch.Generator,
logits_soft_cap: float,
) -> None:
ipex.llm.functional.varlen_attention(query.contiguous(),
key.contiguous(),
Expand All @@ -210,7 +192,8 @@ def varlen_attention(
max_seqlen_q, max_seqlen_k,
pdropout, softmax_scale,
zero_tensors, is_causal,
return_softmax, gen_)
return_softmax, gen_,
logits_soft_cap)

@staticmethod
def reshape_and_cache(
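The varlen_attention wrapper above also gains a logits_soft_cap parameter, forwarded straight to ipex.llm.functional.varlen_attention. Elsewhere in vLLM, logit soft-capping is the tanh squashing used by models such as Gemma 2; the snippet below shows that transform numerically (illustrative only, not the IPEX kernel, and the cap <= 0 "disabled" convention is an assumption that mirrors the backend's default of 0 shown in the next file).

# Tanh soft-capping: logits are squashed smoothly into (-cap, cap) rather than
# clipped hard, so extreme scores saturate instead of being truncated.
import torch


def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    if cap <= 0:  # assumed convention: non-positive cap means no capping
        return scores
    return cap * torch.tanh(scores / cap)


scores = torch.tensor([-100.0, -10.0, 0.0, 10.0, 100.0])
print(soft_cap(scores, 30.0))  # bounded to roughly +/-30
print(soft_cap(scores, 0.0))   # unchanged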
36 changes: 20 additions & 16 deletions vllm/attention/backends/ipex_attn.py
@@ -119,8 +119,6 @@ def __init__(
if blocksparse_params is not None:
raise ValueError(
"IPEX backend does not support block-sparse attention.")
if logits_soft_cap is not None:
raise ValueError("IPEX backend does not support logits_soft_cap.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
@@ -135,6 +133,9 @@ def __init__(
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
if logits_soft_cap is None:
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap

supported_head_sizes = PagedAttention.get_supported_head_sizes()
if head_size not in supported_head_sizes:
@@ -239,20 +240,23 @@ def forward(
(num_tokens, self.num_heads, self.head_size),
dtype=query.dtype,
device=query.device)
ipex_ops.varlen_attention(query,
key,
value,
output,
attn_metadata.seqlen_q,
attn_metadata.seqlen_q,
attn_metadata.max_seqlen,
attn_metadata.max_seqlen,
pdropout=0.0,
softmax_scale=self.scale,
zero_tensors=False,
is_causal=True,
return_softmax=False,
gen_=None)
ipex_ops.varlen_attention(
query,
key,
value,
output,
attn_metadata.seqlen_q,
attn_metadata.seqlen_q,
attn_metadata.max_seqlen,
attn_metadata.max_seqlen,
pdropout=0.0,
softmax_scale=self.scale,
zero_tensors=False,
is_causal=True,
return_softmax=False,
gen_=None,
logits_soft_cap=self.logits_soft_cap,
)
else:
# prefix-enabled attention
raise RuntimeError(
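With the change above, the IPEX backend stops rejecting logits_soft_cap, normalizes None to 0, and passes the value through to the prefill kernel. The eager-mode reference below sketches the computation being delegated, causal attention over one sequence with an optional soft cap on the logits; it is illustrative only and assumes num_heads == num_kv_heads and that a cap of 0 disables capping, matching the None-to-0 default.

# Unoptimized reference: single-sequence causal attention with an optional tanh
# soft cap, roughly what the varlen kernel computes per sequence at prefill.
import math

import torch


def causal_attention_ref(q, k, v, scale, logits_soft_cap=0.0):
    # q, k, v: [seq_len, num_heads, head_size]
    logits = torch.einsum("qhd,khd->hqk", q, k) * scale
    if logits_soft_cap > 0:
        logits = logits_soft_cap * torch.tanh(logits / logits_soft_cap)
    seq_len = q.shape[0]
    causal = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    logits = logits.masked_fill(causal, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v)


q = k = v = torch.randn(5, 8, 64)
out = causal_attention_ref(q, k, v, scale=1.0 / math.sqrt(64),
                           logits_soft_cap=30.0)
print(out.shape)  # torch.Size([5, 8, 64])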
