diff --git a/tests/lora/test_lora_hpu.py b/tests/lora/test_lora_hpu.py
index ddbab66e166b3..01b6472745e1c 100644
--- a/tests/lora/test_lora_hpu.py
+++ b/tests/lora/test_lora_hpu.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 
+from vllm.hpu.ops import LoraMask
 from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice
 
 from .utils import DummyLoRAManager
@@ -19,7 +20,19 @@
     torch.float16: (5e-3, 5e-3),
     torch.bfloat16: (3e-2, 2e-2),
 }
-MAX_LORAS = 8
+
+
+def createLoraMask(indices, batch_size, seq_len, max_loras, max_lora_rank,
+                   lora_dtype):
+    indices = indices.view(-1, 1)
+    mask = torch.arange(max_loras * max_lora_rank, device=indices.device)
+    mask = mask.view(1, -1)
+    mask = ((mask >= ((indices) * max_lora_rank)) *
+            (mask < ((indices + 1) * max_lora_rank))).to(dtype=lora_dtype)
+    mask = mask.view(batch_size, 1,
+                     -1).expand(batch_size, seq_len,
+                                -1).reshape(batch_size * seq_len, -1)
+    return mask
 
 
 @pytest.mark.parametrize("m", TENSOR_SIZES)
@@ -39,32 +52,40 @@ def test_apply_lora(m, n, k, rank, dtype) -> None:
     input = torch.rand(k, n, device="hpu", dtype=dtype)
     expected = input @ lora.lora_a @ lora.lora_b * lora.scaling
 
-    lora_a_stack = torch.zeros(MAX_LORAS + 1,
+    lora_a_stack = torch.zeros(8,
                                1,
                                lora.lora_a.shape[1],
                                lora.lora_a.shape[0],
                                device="hpu",
                                dtype=dtype)
-    lora_b_stack = torch.zeros(MAX_LORAS + 1,
+    lora_b_stack = torch.zeros(8,
                                1,
                                lora.lora_b.shape[1],
                                lora.lora_b.shape[0],
                                device="hpu",
                                dtype=dtype)
-    for i in range(MAX_LORAS):
+    for i in range(lora_a_stack.shape[0]):
         lora_a_stack[i][0] = lora.lora_a.T
         lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T
 
     output = torch.zeros(k, m, device="hpu", dtype=dtype)
-    _apply_lora(input, lora_a_stack, lora_b_stack,
-                torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"),
-                output)
+    indices = torch.randint(0,
+                            lora_a_stack.shape[0], (len(input), ),
+                            device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora(input, lora_a_stack, lora_b_stack, indices, output)
+
     rtol, atol = TOLERANCES[dtype]
     assert torch.allclose(expected, output, rtol=rtol, atol=atol)
 
     output[:] = 0
-    _apply_lora(input, lora_a_stack, lora_b_stack,
-                torch.full((len(input), ), -1, device="hpu"), output)
+    indices = torch.full((len(input), ), -1, device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora(input, lora_a_stack, lora_b_stack, indices, output)
     assert torch.allclose(torch.zeros_like(output), output)
 
     manager.reset_lora()
@@ -99,7 +120,7 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
                          dim=1)
 
     lora_a_stacks = [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_1.lora_a.shape[1],
                     lora_1.lora_a.shape[0],
@@ -107,31 +128,38 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
                     dtype=dtype) for i in range(2)
     ]
     lora_b_stacks = [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_1.lora_b.shape[1],
                     lora_1.lora_b.shape[0],
                     device="hpu",
                     dtype=dtype) for i in range(2)
     ]
-    for i in range(MAX_LORAS):
+    for i in range(lora_a_stacks[0].shape[0]):
         lora_a_stacks[0][i][0] = lora_1.lora_a.T
         lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
         lora_a_stacks[1][i][0] = lora_2.lora_a.T
         lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T
 
     output = torch.zeros(k, m, device="hpu", dtype=dtype)
-    _apply_lora_packed_nslice(
-        input, lora_a_stacks, lora_b_stacks,
-        torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
-        (m // 2, m // 2))
+    indices = torch.randint(0,
+                            lora_a_stacks[0].shape[0], (len(input), ),
+                            device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
+                              output, (m // 2, m // 2))
 
     rtol, atol = TOLERANCES[dtype]
     assert torch.allclose(expected, output, rtol=rtol, atol=atol)
 
     output[:] = 0
-    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
-                              torch.full((len(input), ), -1, device="hpu"),
+    indices = torch.full((len(input), ), -1, device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
                               output, (m // 2, m // 2))
     assert torch.allclose(torch.zeros_like(output), output)
 
@@ -166,14 +194,14 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
                          dim=1)
 
     lora_a_stacks = [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_q.lora_a.shape[1],
                     lora_q.lora_a.shape[0],
                     device="hpu",
                     dtype=dtype)
     ] + [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_k.lora_a.shape[1],
                     lora_k.lora_a.shape[0],
@@ -181,21 +209,21 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
                     dtype=dtype) for i in range(2)
     ]
     lora_b_stacks = [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_q.lora_b.shape[1],
                     lora_q.lora_b.shape[0],
                     device="hpu",
                     dtype=dtype)
     ] + [
-        torch.zeros(MAX_LORAS + 1,
+        torch.zeros(8,
                     1,
                     lora_k.lora_b.shape[1],
                     lora_k.lora_b.shape[0],
                     device="hpu",
                     dtype=dtype) for i in range(2)
     ]
-    for i in range(MAX_LORAS):
+    for i in range(lora_a_stacks[0].shape[0]):
         lora_a_stacks[0][i][0] = lora_q.lora_a.T
         lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
         lora_a_stacks[1][i][0] = lora_k.lora_a.T
@@ -204,17 +232,24 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
         lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T
 
     output = torch.zeros(k, sum(qkv), device="hpu", dtype=dtype)
-    _apply_lora_packed_nslice(
-        input, lora_a_stacks, lora_b_stacks,
-        torch.randint(0, MAX_LORAS, (len(input), ), device="hpu"), output,
-        (qkv[0], qkv[1], qkv[2]))
+    indices = torch.randint(0,
+                            lora_a_stacks[0].shape[0], (len(input), ),
+                            device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
+                              output, (qkv[0], qkv[1], qkv[2]))
 
     rtol, atol = TOLERANCES[dtype]
     assert torch.allclose(expected, output, rtol=rtol, atol=atol)
 
     output[:] = 0
-    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
-                              torch.full((len(input), ), -1, device="hpu"),
+    indices = torch.full((len(input), ), -1, device="hpu")
+    mask = createLoraMask(indices, k, 1, 8, rank, dtype)
+    LoraMask.setLoraMask(mask)
+
+    _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, indices,
                               output, (qkv[0], qkv[1], qkv[2]))
     assert torch.allclose(torch.zeros_like(output), output)
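
Reviewer note, not part of the patch: a minimal, CPU-only sketch of the mask layout that createLoraMask builds, so the -1 test cases above are easier to follow. The configuration used here (4 LoRAs of rank 2, 3 tokens) is purely illustrative and does not come from the test file.

# Standalone sketch (illustrative values, not from the patch).
import torch


def createLoraMask(indices, batch_size, seq_len, max_loras, max_lora_rank,
                   lora_dtype):
    # Same logic as the helper added in the patch: one row per token, with
    # ones covering the rank-sized column block of that token's LoRA index.
    indices = indices.view(-1, 1)
    mask = torch.arange(max_loras * max_lora_rank, device=indices.device)
    mask = mask.view(1, -1)
    mask = ((mask >= (indices * max_lora_rank)) *
            (mask < ((indices + 1) * max_lora_rank))).to(dtype=lora_dtype)
    return mask.view(batch_size, 1, -1).expand(batch_size, seq_len,
                                               -1).reshape(
                                                   batch_size * seq_len, -1)


indices = torch.tensor([0, 2, -1])  # LoRA index per token; -1 means "no LoRA"
mask = createLoraMask(indices, batch_size=3, seq_len=1, max_loras=4,
                      max_lora_rank=2, lora_dtype=torch.float32)
print(mask)
# tensor([[1., 1., 0., 0., 0., 0., 0., 0.],    <- token 0 selects LoRA 0's block
#         [0., 0., 0., 0., 1., 1., 0., 0.],    <- token 1 selects LoRA 2's block
#         [0., 0., 0., 0., 0., 0., 0., 0.]])   <- -1 selects nothing, so the
#                                                  LoRA contribution is zeroed

This is why the tests recompute the mask and call LoraMask.setLoraMask before every _apply_lora / _apply_lora_packed_nslice call: with index -1 the mask is all zeros and the output is expected to stay zero.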