diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index 9b76b98ab3322..d2906914f927e 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -47,6 +47,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 13696) \
     f(in_T, out_T, W_T, narrow, 13824) \
     f(in_T, out_T, W_T, narrow, 14336) \
+    f(in_T, out_T, W_T, narrow, 15360) \
     f(in_T, out_T, W_T, narrow, 16384) \
     f(in_T, out_T, W_T, narrow, 20480) \
     f(in_T, out_T, W_T, narrow, 22016) \
diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py
index 2178266d2e0c8..5ab863eea94b3 100644
--- a/tests/lora/test_baichuan.py
+++ b/tests/lora/test_baichuan.py
@@ -62,7 +62,7 @@ def test_baichuan_lora(baichuan_lora_files):
 
 
 @pytest.mark.skip("Requires multiple GPUs")
-def test_llama_tensor_parallel_equality(baichuan_lora_files):
+def test_baichuan_tensor_parallel_equality(baichuan_lora_files):
     # Cannot use as it will initialize torch.cuda too early...
     # if torch.cuda.device_count() < 4:
     #     pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index cab8b44ccd2df..8b174f01d87d4 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -72,6 +72,7 @@ def _lora_ref_impl(
     11008,
     13824,
     14336,
+    15360,
     22016,
     24576,
     27392,
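Note: for readers unfamiliar with why a one-line addition to bgmv_config.h is sufficient, the header uses an X-macro list: each f(in_T, out_T, W_T, narrow, SIZE) entry is expanded by downstream macros into a kernel instantiation and a dispatch case for that feature size, so adding 15360 to the list (and to the size list in test_punica.py, which parametrizes the tests) is all that is needed. The following is a minimal, self-contained sketch of that pattern; the macro arity is simplified and the dispatch_bgmv/CASE names are hypothetical, not the real vLLM API.

// Simplified X-macro dispatch sketch (names and arity are illustrative only).
#include <cstdio>

// Each f(...) entry stands for one kernel specialization. Adding a new
// feature size to this list is the only change needed to "register" it.
#define FOR_BGMV_WIDE(f, narrow) \
  f(narrow, 13824)               \
  f(narrow, 14336)               \
  f(narrow, 15360)               \
  f(narrow, 16384)

// Hypothetical expansion target: emit one dispatch case per (narrow, wide)
// pair produced by the list above.
#define CASE(narrow, wide)                            \
  if (feat_out == (wide)) {                           \
    std::printf("dispatch %d x %d\n", narrow, wide);  \
    return true;                                      \
  }

// Returns false when feat_out was never instantiated in the list,
// which is the failure mode the PR fixes for 15360.
bool dispatch_bgmv(int feat_out) {
  FOR_BGMV_WIDE(CASE, 16)
  return false;
}

int main() {
  dispatch_bgmv(15360);  // matched only because 15360 is now in the list
}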