1.18.0 fast-forward merge (#278)
Merging missing features into 1.18.0.
kzawora-intel authored Sep 12, 2024
2 parents 73af823 + 6a734f4 commit 4173298
Showing 11 changed files with 542 additions and 474 deletions.
22 changes: 11 additions & 11 deletions README_GAUDI.md
@@ -455,33 +455,33 @@ Environment variables
- `VLLM_{phase}_{dim}_BUCKET_{param}` - collection of 12 environment
variables configuring ranges of bucketing mechanism
- `{phase}` is either `PROMPT` or `DECODE`
- `{dim}` is either `BS` or `SEQ`
- `{dim}` is either `BS`, `SEQ` or `BLOCK`
- `{param}` is either `MIN`, `STEP` or `MAX`
- Default values:
- Prompt:
- batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`): `1`
- batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `32`
- batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`): `min(max_num_seqs, 32)`
- batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`):
`min(max_num_seqs, 64)`
- sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`):
`block_size`
- sequence length step
(`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size`
- sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`):
`1024`
`max_model_len`

- Decode:
- batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1`
- batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `min(max_num_seqs, 32)`
- batch size step (`VLLM_DECODE_BS_BUCKET_STEP`):
`128`
`min(max_num_seqs, 32)`
- batch size max (`VLLM_DECODE_BS_BUCKET_MAX`):
`max_num_seqs`
- sequence length min (`VLLM_DECODE_SEQ_BUCKET_MIN`):
`block_size`
- sequence length step
(`VLLM_DECODE_SEQ_BUCKET_STEP`): `block_size`
- sequence length max (`VLLM_DECODE_SEQ_BUCKET_MAX`):
`2048`
- block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`):
`128`
- block size step
(`VLLM_DECODE_BLOCK_BUCKET_STEP`): `128`
- block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`):
`max(128, (max_num_seqs*max_model_len)/block_size)`

Additionally, there are HPU PyTorch Bridge environment variables
impacting vLLM execution:
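For context on how the `MIN`/`STEP`/`MAX` triples above are used: the bucketing mechanism expands each triple into a set of bucket boundaries, ramping up from `MIN` by powers of two until `STEP` is reached and then stepping linearly in multiples of `STEP` up to `MAX`. The sketch below is an illustrative reconstruction of that expansion, not code from this commit; the function name is made up for the example and it assumes `MIN <= STEP <= MAX`.

def bucket_range(bmin: int, bstep: int, bmax: int) -> list[int]:
    # Illustrative only: expand one MIN/STEP/MAX triple into bucket
    # boundaries, assuming the "powers of two up to STEP, then linear
    # multiples of STEP" ramp described in README_GAUDI.md.
    buckets = []
    value = bmin
    while value < bstep and value <= bmax:  # ramp-up phase
        buckets.append(value)
        value *= 2
    value = bstep
    while value <= bmax:  # stable phase
        buckets.append(value)
        value += bstep
    return buckets

# Example: decode block buckets with MIN=128, STEP=128, MAX=512
# expand to [128, 256, 384, 512].
print(bucket_range(128, 128, 512))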
14 changes: 7 additions & 7 deletions docs/source/getting_started/gaudi-installation.rst
@@ -335,19 +335,19 @@ Environment variables

- Prompt:
- batch size min (``VLLM_PROMPT_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``32``
- batch size step (``VLLM_PROMPT_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_PROMPT_BS_BUCKET_MAX``): ``min(max_num_seqs, 64)``
- sequence length min (``VLLM_PROMPT_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_PROMPT_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``1024``
- sequence length max (``VLLM_PROMPT_SEQ_BUCKET_MAX``): ``max_model_len``

- Decode:
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``1``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``128``
- batch size min (``VLLM_DECODE_BS_BUCKET_MIN``): ``min(max_num_seqs, 32)``
- batch size step (``VLLM_DECODE_BS_BUCKET_STEP``): ``min(max_num_seqs, 32)``
- batch size max (``VLLM_DECODE_BS_BUCKET_MAX``): ``max_num_seqs``
- sequence length min (``VLLM_DECODE_SEQ_BUCKET_MIN``): ``block_size``
- sequence length step (``VLLM_DECODE_SEQ_BUCKET_STEP``): ``block_size``
- sequence length max (``VLLM_DECODE_SEQ_BUCKET_MAX``): ``2048``
- block size min (``VLLM_DECODE_BLOCK_BUCKET_MIN``): ``128``
- block size step (``VLLM_DECODE_BLOCK_BUCKET_STEP``): ``128``
- block size max (``VLLM_DECODE_BLOCK_BUCKET_MAX``): ``max(128, (max_num_seqs*max_model_len)/block_size)``


Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM execution:
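Since these are ordinary environment variables, they can be overridden before the vLLM engine is created. A minimal sketch, assuming a Python entry point; the model name and the chosen values are placeholders, not recommendations from this commit:

import os

# Hypothetical overrides for decode bucketing; the variable names come
# from the documentation above, the values are arbitrary examples.
os.environ["VLLM_DECODE_BS_BUCKET_MIN"] = "1"
os.environ["VLLM_DECODE_BS_BUCKET_STEP"] = "32"
os.environ["VLLM_DECODE_BLOCK_BUCKET_STEP"] = "256"

from vllm import LLM  # set the environment before the engine is built

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # placeholder model
print(llm.generate(["Hello, Gaudi!"])[0].outputs[0].text)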
93 changes: 64 additions & 29 deletions tests/lora/test_lora_hpu.py
@@ -1,6 +1,7 @@
import pytest
import torch

from vllm.hpu.ops import LoraMask
from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice

from .utils import DummyLoRAManager
@@ -19,7 +20,19 @@
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
MAX_LORAS = 8


def createLoraMask(indices, batch_size, seq_len, max_loras, max_lora_rank,
lora_dtype):
indices = indices.view(-1, 1)
mask = torch.arange(max_loras * max_lora_rank, device=indices.device)
mask = mask.view(1, -1)
mask = ((mask >= ((indices) * max_lora_rank)) *
(mask < ((indices + 1) * max_lora_rank))).to(dtype=lora_dtype)
mask = mask.view(batch_size, 1,
-1).expand(batch_size, seq_len,
-1).reshape(batch_size * seq_len, -1)
return mask
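
For readers skimming the diff: createLoraMask builds a (batch_size * seq_len, max_loras * max_lora_rank) 0/1 mask in which each row enables only the rank-wide column block belonging to that row's LoRA index; an index of -1 selects no columns, which is what the zero-output assertions in the tests below rely on. A tiny illustrative check, not part of the commit; it reuses the helper defined above and the torch import at the top of the file:

# Illustrative only: two LoRA slots of rank 2, so row i of the mask
# enables columns [indices[i] * 2, indices[i] * 2 + 2).
indices = torch.tensor([0, 1, -1])
mask = createLoraMask(indices, batch_size=3, seq_len=1, max_loras=2,
                      max_lora_rank=2, lora_dtype=torch.float32)
# mask == [[1., 1., 0., 0.],
#          [0., 0., 1., 1.],
#          [0., 0., 0., 0.]]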


@pytest.mark.parametrize("m", TENSOR_SIZES)
@@ -39,32 +52,40 @@ def test_apply_lora(m, n, k, rank, dtype) -> None:
input = torch.rand(k, n, device="hpu", dtype=dtype)
expected = input @ lora.lora_a @ lora.lora_b * lora.scaling

lora_a_stack = torch.zeros(MAX_LORAS + 1,
lora_a_stack = torch.zeros(8,
1,
lora.lora_a.shape[1],
lora.lora_a.shape[0],
device="hpu",
dtype=dtype)
lora_b_stack = torch.zeros(MAX_LORAS + 1,
lora_b_stack = torch.zeros(8,
1,
lora.lora_b.shape[1],
lora.lora_b.shape[0],
device="hpu",
dtype=dtype)
for i in range(MAX_LORAS):
for i in range(lora_a_stack.shape[0]):
lora_a_stack[i][0] = lora.lora_a.T
lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T

output = torch.zeros(k, m, device="hpu", dtype=dtype)
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"),
output)
indices = torch.randint(0,
lora_a_stack.shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora(input, lora_a_stack, lora_b_stack, indices, output)

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.full((len(input), ), -1, device="hpu"), output)
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora(input, lora_a_stack, lora_b_stack, indices, output)
assert torch.allclose(torch.zeros_like(output), output)

manager.reset_lora()
@@ -99,39 +120,46 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
dim=1)

lora_a_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_1.lora_a.shape[1],
lora_1.lora_a.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_1.lora_b.shape[1],
lora_1.lora_b.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
for i in range(MAX_LORAS):
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_1.lora_a.T
lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
lora_a_stacks[1][i][0] = lora_2.lora_a.T
lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T

output = torch.zeros(k, m, device="hpu", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
(m // 2, m // 2))
indices = torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
output, (m // 2, m // 2))

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="hpu"),
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
output, (m // 2, m // 2))
assert torch.allclose(torch.zeros_like(output), output)

@@ -166,36 +194,36 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
dim=1)

lora_a_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_q.lora_a.shape[1],
lora_q.lora_a.shape[0],
device="hpu",
dtype=dtype)
] + [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_k.lora_a.shape[1],
lora_k.lora_a.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_q.lora_b.shape[1],
lora_q.lora_b.shape[0],
device="hpu",
dtype=dtype)
] + [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_k.lora_b.shape[1],
lora_k.lora_b.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
for i in range(MAX_LORAS):
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_q.lora_a.T
lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
lora_a_stacks[1][i][0] = lora_k.lora_a.T
@@ -204,17 +232,24 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T

output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
(qkv[0], qkv[1], qkv[2]))
indices = torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
output, (qkv[0], qkv[1], qkv[2]))

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="hpu"),
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
output, (qkv[0], qkv[1], qkv[2]))
assert torch.allclose(torch.zeros_like(output), output)

