[Kernel] support non-zero cuda devices in punica kernels (#3636)
jeejeelee authored Mar 27, 2024
1 parent 0dc7227 commit 566b57c
Showing 2 changed files with 29 additions and 62 deletions.
csrc/punica/punica_ops.cc (3 additions, 1 deletion)
@@ -1,7 +1,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <torch/extension.h>
-
+#include <c10/cuda/CUDAGuard.h>
 #include <cstdint>
 
 #include "bgmv/bgmv_config.h"
@@ -91,6 +91,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(w.size(2), h_out);
   CHECK_EQ(indicies.size(0), x.size(0));
   CHECK_EQ(y.size(0), x.size(0));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
   if (h_in < 65536 && h_out < 65536) {
     // TODO: See if we can get rid of this massive nested switch
@@ -322,6 +323,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(w.size(2), h_out);
   CHECK_EQ(indicies.size(0), x.size(0));
   CHECK_EQ(y.size(0), x.size(0));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
   if (h_in < 65536 && h_out < 65536) {
     // TODO: See if we can get rid of this massive nested switch
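The two added OptionalCUDAGuard lines are the core of the fix: before dispatching the BGMV kernels, the extension now switches the current CUDA device to the device that owns the input tensor x, so the kernels are no longer implicitly launched on cuda:0. A rough Python-level sketch of the same idea (not part of this commit; run_on_input_device is an illustrative stand-in for the extension call), using torch.cuda.device, which restores the previous device on exit much like the C++ RAII guard:

import torch

def run_on_input_device(y: torch.Tensor, x: torch.Tensor) -> None:
    # Illustrative stand-in for the punica extension call. The point is the
    # guard: make the inputs' device current, do the work there, and let the
    # context manager restore the previous device afterwards.
    with torch.cuda.device(x.device):
        y.add_(x.sum(dim=1, keepdim=True))  # placeholder for the real kernel

if torch.cuda.device_count() > 1:
    x = torch.randn(32, 128, dtype=torch.float16, device="cuda:1")
    y = torch.zeros(32, 1, dtype=torch.float16, device="cuda:1")
    run_on_input_device(y, x)
    assert y.device == torch.device("cuda:1")

Without such a guard, kernels launched by an extension run on whatever device happens to be current (usually cuda:0), regardless of where the input tensors actually live.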
tests/lora/test_punica.py (26 additions, 61 deletions)
@@ -49,40 +49,34 @@ def _lora_ref_impl(
     32768, 33024
 ]
 SEED = [0xabcdabcd987]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
 
 
 @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
 @pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_lora_correctness(dtype_str, h1, h2, seed):
+def test_lora_correctness(dtype_str, h1, h2, seed, device):
     torch.manual_seed(seed)
     num_loras = 4
     num_layers = 1
     r = 8
     bs = 32
     scale = 0.123
     dtype = getattr(torch, dtype_str)
-    device = torch.device("cuda")
-
-    wa_T_all = torch.randn(num_loras,
-                           num_layers,
-                           r,
-                           h1,
-                           dtype=dtype,
-                           device=device)
-    wb_T_all = torch.randn(num_loras,
-                           num_layers,
-                           h2,
-                           r,
-                           dtype=dtype,
-                           device=device)
-    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+    torch.set_default_device(device)
+
+    wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype)
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
 
     for layer_idx in range(num_layers):
-        x = torch.randn(bs, h1, dtype=dtype, device=device)
-        y = torch.randn(bs, h2, dtype=dtype, device=device)
+        x = torch.randn(bs, h1, dtype=dtype)
+        y = torch.randn(bs, h2, dtype=dtype)
 
         y_ref = y.clone()
         _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)
@@ -98,8 +92,9 @@ def test_lora_correctness(dtype_str, h1, h2, seed):
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
 @pytest.mark.parametrize("seed", SEED)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_lora_correctness_slice(dtype_str, h1, h2, seed):
+def test_lora_correctness_slice(dtype_str, h1, h2, seed, device):
     if h2 % 3 != 0 or h2 // 3 not in H1:
         pytest.skip("h2 must be divisible by 3 and in supported shapes")
     torch.manual_seed(seed)
@@ -109,50 +104,20 @@ def test_lora_correctness_slice(dtype_str, h1, h2, seed):
     bs = 32
     scale = 0.123
     dtype = getattr(torch, dtype_str)
-    device = torch.device("cuda")
-
-    wa_T_all_0 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wa_T_all_1 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wa_T_all_2 = torch.randn(num_loras,
-                             num_layers,
-                             r,
-                             h1,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_0 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_1 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-    wb_T_all_2 = torch.randn(num_loras,
-                             num_layers,
-                             h2 // 3,
-                             r,
-                             dtype=dtype,
-                             device=device)
-
-    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+    torch.set_default_device(device)
+
+    wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
+    wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+    wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
+
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
 
     for layer_idx in range(num_layers):
-        x = torch.randn(bs, h1, dtype=dtype, device=device)
-        y = torch.randn(bs, h2, dtype=dtype, device=device)
+        x = torch.randn(bs, h1, dtype=dtype)
+        y = torch.randn(bs, h2, dtype=dtype)
         s = h2 // 3
 
         y_ref = y.clone()
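
On the test side, each test is now parametrized over CUDA_DEVICES (cuda:0 on single-GPU machines, cuda:0 and cuda:1 otherwise), and the per-tensor device=... arguments are replaced by a single torch.set_default_device(device) call, so every factory function allocates on the device under test. A minimal sketch of the same pattern outside this repository (test_custom_op_on_each_device and _dummy_op are illustrative names, not vLLM APIs; torch.set_default_device requires PyTorch 2.0 or newer):

import pytest
import torch

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def _dummy_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Stand-in for a custom kernel; a correct op must not assume cuda:0.
    return x @ y


@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_custom_op_on_each_device(device):
    # Factory calls below (randn, etc.) inherit this default device, which is
    # why the explicit device=... keyword arguments could be dropped above.
    torch.set_default_device(device)

    x = torch.randn(32, 16, dtype=torch.float16)
    y = torch.randn(16, 8, dtype=torch.float16)
    out = _dummy_op(x, y)

    assert out.device == torch.device(device)

Running the suite on a multi-GPU machine therefore exercises the new device guard in the C++ extension with inputs living on cuda:1, which the previous hard-coded device = torch.device("cuda") (the current device, normally cuda:0) never did.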
