Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LoRA test by handling mask creation inside the test #270

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 64 additions & 29 deletions tests/lora/test_lora_hpu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import torch

from vllm.hpu.ops import LoraMask
from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice

from .utils import DummyLoRAManager
Expand All @@ -19,7 +20,19 @@
torch.float16: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
MAX_LORAS = 8


def createLoraMask(indices, batch_size, seq_len, max_loras, max_lora_rank,
lora_dtype):
indices = indices.view(-1, 1)
mask = torch.arange(max_loras * max_lora_rank, device=indices.device)
mask = mask.view(1, -1)
mask = ((mask >= ((indices) * max_lora_rank)) *
(mask < ((indices + 1) * max_lora_rank))).to(dtype=lora_dtype)
mask = mask.view(batch_size, 1,
-1).expand(batch_size, seq_len,
-1).reshape(batch_size * seq_len, -1)
return mask


@pytest.mark.parametrize("m", TENSOR_SIZES)
Expand All @@ -39,32 +52,40 @@ def test_apply_lora(m, n, k, rank, dtype) -> None:
input = torch.rand(k, n, device="hpu", dtype=dtype)
expected = input @ lora.lora_a @ lora.lora_b * lora.scaling

lora_a_stack = torch.zeros(MAX_LORAS + 1,
lora_a_stack = torch.zeros(8,
hlahkar marked this conversation as resolved.
Show resolved Hide resolved
1,
lora.lora_a.shape[1],
lora.lora_a.shape[0],
device="hpu",
dtype=dtype)
lora_b_stack = torch.zeros(MAX_LORAS + 1,
lora_b_stack = torch.zeros(8,
1,
lora.lora_b.shape[1],
lora.lora_b.shape[0],
device="hpu",
dtype=dtype)
for i in range(MAX_LORAS):
for i in range(lora_a_stack.shape[0]):
lora_a_stack[i][0] = lora.lora_a.T
lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T

output = torch.zeros(k, m, device="hpu", dtype=dtype)
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"),
output)
indices = torch.randint(0,
lora_a_stack.shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora(input, lora_a_stack, lora_b_stack, indices, output)

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora(input, lora_a_stack, lora_b_stack,
torch.full((len(input), ), -1, device="hpu"), output)
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora(input, lora_a_stack, lora_b_stack, indices, output)
assert torch.allclose(torch.zeros_like(output), output)

manager.reset_lora()
Expand Down Expand Up @@ -99,39 +120,46 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
dim=1)

lora_a_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_1.lora_a.shape[1],
lora_1.lora_a.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_1.lora_b.shape[1],
lora_1.lora_b.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
for i in range(MAX_LORAS):
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_1.lora_a.T
lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
lora_a_stacks[1][i][0] = lora_2.lora_a.T
lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T

output = torch.zeros(k, m, device="hpu", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
(m // 2, m // 2))
indices = torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
output, (m // 2, m // 2))

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="hpu"),
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
output, (m // 2, m // 2))
assert torch.allclose(torch.zeros_like(output), output)

Expand Down Expand Up @@ -166,36 +194,36 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
dim=1)

lora_a_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_q.lora_a.shape[1],
lora_q.lora_a.shape[0],
device="hpu",
dtype=dtype)
] + [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_k.lora_a.shape[1],
lora_k.lora_a.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
lora_b_stacks = [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_q.lora_b.shape[1],
lora_q.lora_b.shape[0],
device="hpu",
dtype=dtype)
] + [
torch.zeros(MAX_LORAS + 1,
torch.zeros(8,
1,
lora_k.lora_b.shape[1],
lora_k.lora_b.shape[0],
device="hpu",
dtype=dtype) for i in range(2)
]
for i in range(MAX_LORAS):
for i in range(lora_a_stacks[0].shape[0]):
lora_a_stacks[0][i][0] = lora_q.lora_a.T
lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
lora_a_stacks[1][i][0] = lora_k.lora_a.T
Expand All @@ -204,17 +232,24 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T

output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype)
_apply_lora_packed_nslice(
input, lora_a_stacks, lora_b_stacks,
torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
(qkv[0], qkv[1], qkv[2]))
indices = torch.randint(0,
lora_a_stacks[0].shape[0], (len(input), ),
device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
output, (qkv[0], qkv[1], qkv[2]))

rtol, atol = TOLERANCES[dtype]
assert torch.allclose(expected, output, rtol=rtol, atol=atol)

output[:] = 0
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
torch.full((len(input), ), -1, device="hpu"),
indices = torch.full((len(input), ), -1, device="hpu")
mask = createLoraMask(indices, k, 1, 8, rank, dtype)
LoraMask.setLoraMask(mask)

_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
output, (qkv[0], qkv[1], qkv[2]))
assert torch.allclose(torch.zeros_like(output), output)

Expand Down
Loading