diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4a99985d9abc4..e9262b57d0867 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -212,23 +212,11 @@ define_gpu_extension_target(
 
 set(VLLM_PUNICA_EXT_SRC
   "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
-  "csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu"
-  "csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu"
   "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
-  "csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu"
-  "csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu"
   "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu"
   "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
   "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
-  "csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu"
   "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
-  "csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu"
-  "csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu"
   "csrc/punica/punica_ops.cc")
 
 #
diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
deleted file mode 100644
index e8202dff561d9..0000000000000
--- a/csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
deleted file mode 100644
index 3e7cf31dead0f..0000000000000
--- a/csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
deleted file mode 100644
index 68277fa6b7d56..0000000000000
--- a/csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
deleted file mode 100644
index 3b7531b8fbcfc..0000000000000
--- a/csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
deleted file mode 100644
index b3b74aa3ec904..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
deleted file mode 100644
index 3cc87f5df76a1..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
deleted file mode 100644
index 9eda98bd8ddcf..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
deleted file mode 100644
index 060f9ebb8c2b1..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
deleted file mode 100644
index b37e44570bf40..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
deleted file mode 100644
index 06718cbb0a3e9..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
deleted file mode 100644
index 41fb0e45ef4e6..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
deleted file mode 100644
index 50b7ead9fcefd..0000000000000
--- a/csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu
+++ /dev/null
@@ -1,4 +0,0 @@
-#include "bgmv_config.h"
-#include "bgmv_impl.cuh"
-
-FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py
index c347d4f2ab9f4..9bf7f6358880f 100644
--- a/csrc/punica/bgmv/generator.py
+++ b/csrc/punica/bgmv/generator.py
@@ -18,6 +18,26 @@
             if weight_dtype == "fp32":
                 # FP32 weights are not supported.
                 continue
+            if output_dtype == "fp32":
+                # LoRA A matrix.
+                if input_dtype != weight_dtype:
+                    # NOTE(woosuk): While Punica supports the case where the
+                    # input and weight dtypes are different, we only generate
+                    # the kernels with the same dtypes to reduce the binary size.
+                    continue
+            elif input_dtype == "fp32":
+                # LoRA B matrix.
+                if output_dtype != weight_dtype:
+                    # NOTE(woosuk): While Punica supports the case where the
+                    # output and weight dtypes are different, we only generate
+                    # the kernels with the same dtypes to reduce the binary size.
+                    continue
+            elif not (input_dtype == output_dtype == weight_dtype):
+                # NOTE(woosuk): While Punica supports mixed data types for
+                # input, output, and weight, we only generate the kernels with
+                # the same data types to reduce the binary size.
+                continue
+
             kernel_definition = TEMPLATE.format(
                 input_dtype=DTYPE_MAP[input_dtype],
                 output_dtype=DTYPE_MAP[output_dtype],
diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc
index 7ebfd851c4feb..a1eaa90e85f27 100644
--- a/csrc/punica/punica_ops.cc
+++ b/csrc/punica/punica_ops.cc
@@ -50,6 +50,23 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                                int64_t y_offset, int64_t full_y_size,
                                int64_t batch_size, int64_t num_layers,
                                int64_t layer_idx, float scale) {
+  // NOTE(woosuk): While Punica supports various combinations of input/output
+  // data types, we limit the supported data types to reduce the binary size.
+  constexpr bool is_input_float = std::is_same<in_T, float>::value;
+  constexpr bool is_output_float = std::is_same<out_T, float>::value;
+  if (is_input_float) {
+    if (!std::is_same<W_T, out_T>::value) {
+      return false;
+    }
+  } else if (is_output_float) {
+    if (!std::is_same<W_T, in_T>::value) {
+      return false;
+    }
+  } else if (!(std::is_same<in_T, W_T>::value &&
+               std::is_same<out_T, W_T>::value)) {
+    return false;
+  }
+
   switch (pack_u32(in_features, out_features)) {
 #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
   case pack_u32(feat_in, feat_out): \
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index e9e0c8554c1ef..1616fdfd4cff9 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -413,7 +413,9 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
 
     def _pretest():
         linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
-                                1024, vocab_size)
+                                1024,
+                                vocab_size,
+                                params_dtype=torch.float16)
         linear.weight.data = torch.rand_like(linear.weight.data)
         linear.weight.data[:, vocab_size:] = 0
         logits_processor = LogitsProcessor(
@@ -445,7 +447,7 @@ def _pretest():
         num_inputs=8 * num_loras,  # * 3,
         input_size=(1, 1024),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -494,7 +496,7 @@ def _pretest():
         num_inputs=8 * num_loras * 3,
         input_size=(1, 1024),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -533,11 +535,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
 
     def create_random_linear_parallel_layer():
         if orientation == "row":
-            linear = RowParallelLinear(4096, 4096, bias=False)
+            linear = RowParallelLinear(4096,
+                                       4096,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = RowParallelLinearWithLoRA(linear)
         else:
-            linear = ColumnParallelLinear(4096, 4096, bias=False)
+            linear = ColumnParallelLinear(4096,
+                                          4096,
+                                          bias=False,
+                                          params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = ColumnParallelLinearWithLoRA(linear)
         lora_linear.create_lora_weights(max_loras, lora_config)
@@ -561,7 +569,7 @@ def create_random_linear_parallel_layer():
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -600,7 +608,7 @@ def create_random_linear_parallel_layer():
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -633,15 +641,24 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
     def create_column_parallel_packed_layer():
         if repeats == 2:
             linear = MergedColumnParallelLinear(4096, [4096] * repeats,
-                                                bias=False)
+                                                bias=False,
+                                                params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = MergedColumnParallelLinearWithLoRA(linear)
         elif repeats == 3:
-            linear = QKVParallelLinear(4096, 64, 32, bias=False)
+            linear = QKVParallelLinear(4096,
+                                       64,
+                                       32,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = MergedQKVParallelLinearWithLora(linear)
         else:
-            linear = QKVParallelLinear(4096, 64, 32, bias=False)
+            linear = QKVParallelLinear(4096,
+                                       64,
+                                       32,
+                                       bias=False,
+                                       params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
             lora_linear = QKVParallelLinearWithLora(linear)
 
@@ -676,7 +693,7 @@ class FakeConfig:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
 
@@ -716,7 +733,7 @@ class FakeConfig:
         num_inputs=32 * num_loras,
         input_size=(1, 4096),
         input_range=(0, 1),
-        input_type=torch.float32,
+        input_type=torch.float16,
     )
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)