From a4562ab71bcf94623dfd4e4559df9332e2eb87a2 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Fri, 17 Nov 2023 05:09:33 +0000 Subject: [PATCH 01/31] FP6 quantization end-to-end. * Initialize the fp6-quant-kernel integration. * Add necessary parameters of kernel interfaces and the linear layer selection logic. * upload kernel code * The simple script for debugging. * fix typo * update * fix split k * Fix some errors and add test case. * Workspace for Inference Kernels (#1) * Add transform_param functions and update format. * kernel debug * fix include * Update core_ops.cpp * Add split k support * fix * Fix kernel error * update * update * Fix rebase errors. * Add missed include. * Fix the bug that the attribute uses the weight information for mem alloc. * Avoid GPU preallocation during weight loading. * Add support of larger shapes for gated activation kernel. * update * model update * fix all weight preprocessing * Add split-k heuristic. * Avoid reading scale attribute on non-quantized tensors. * Change the scales from attributes to new tensors. Provide the end-to-end script given HuggingFace model id. * Hard-coded commented out the scales in the kernel to workaround the bug. * Support the user config for quantization. Fix kernel bug. * Per operator test functions. * Multiply scales by 1e12 according to the kernel design. * Revert "Workspace for Inference Kernels (#1)". This reverts commit 1528732bd2ca54bae248846c6dac34729ac97cdf. * Remove the format-only changes. * Put the quantization into the transform_param function. --------- Co-authored-by: Shiyang Chen Co-authored-by: Haojun Xia --- debug/README.md | 1 + debug/clean.sh | 2 + debug/run_pipeline.py | 18 ++ deepspeed/inference/v2/config_v2.py | 12 + .../inference/v2/kernels/core_ops/__init__.py | 1 + .../v2/kernels/core_ops/core_ops.cpp | 7 + .../cuda_linear/GenMatrix_QuantLLM.cpp | 300 ++++++++++++++++++ .../core_ops/cuda_linear/GenMatrix_QuantLLM.h | 41 +++ .../core_ops/cuda_linear/Kernel_QuantGEMM.cuh | 211 ++++++++++++ .../core_ops/cuda_linear/Kernel_Reduction.cuh | 47 +++ .../kernels/core_ops/cuda_linear/Launcher.cu | 287 +++++++++++++++++ .../kernels/core_ops/cuda_linear/__init__.py | 6 + .../core_ops/cuda_linear/cuda_linear.py | 68 ++++ .../cuda_linear/cuda_linear_kernels.cpp | 94 ++++++ .../cuda_linear/cuda_linear_kernels.cu | 58 ++++ .../cuda_linear/cuda_linear_kernels.h | 25 ++ .../core_ops/cuda_linear/quant_gemm_api.cuh | 25 ++ .../core_ops/cuda_linear/utils/Configs.h | 89 ++++++ .../cuda_linear/utils/PTX_cp.async.cuh | 59 ++++ .../core_ops/cuda_linear/utils/PTX_mma.cuh | 113 +++++++ .../core_ops/cuda_linear/utils/Utils_Core.cuh | 212 +++++++++++++ .../core_ops/cuda_linear/utils/Utils_GMem.cuh | 97 ++++++ .../utils/Utils_ParallelDequant.cuh | 121 +++++++ .../gated_activation_kernels_cuda.cu | 3 + .../flat_model_helpers.py | 30 +- deepspeed/inference/v2/modules/heuristics.py | 16 +- .../implementations/linear/__init__.py | 2 + .../linear/quantized_linear.py | 236 ++++++++++++++ op_builder/inference_core_ops.py | 5 + 29 files changed, 2173 insertions(+), 13 deletions(-) create mode 100644 debug/README.md create mode 100644 debug/clean.sh create mode 100644 debug/run_pipeline.py create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh create mode 100644 
deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_Reduction.cuh create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/quant_gemm_api.cuh create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Configs.h create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_cp.async.cuh create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_mma.cuh create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_Core.cuh create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh create mode 100644 deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py diff --git a/debug/README.md b/debug/README.md new file mode 100644 index 000000000000..71e9e64c05cb --- /dev/null +++ b/debug/README.md @@ -0,0 +1 @@ +The files in this directory is only for debugging of FP6 quantization kernel integration. Will not merge. diff --git a/debug/clean.sh b/debug/clean.sh new file mode 100644 index 000000000000..7120c76321da --- /dev/null +++ b/debug/clean.sh @@ -0,0 +1,2 @@ +rm ~/.cache/torch_extensions/py38_cu118/inference_core_ops/*.o +rm ~/.cache/torch_extensions/py38_cu118/inference_core_ops/*.so diff --git a/debug/run_pipeline.py b/debug/run_pipeline.py new file mode 100644 index 000000000000..330c836abed8 --- /dev/null +++ b/debug/run_pipeline.py @@ -0,0 +1,18 @@ +import mii + + +def fake_request_texts(batch_size: int): + request_texts = ["Ha ha ha"] * batch_size + return request_texts + + +if __name__ == '__main__': + model_id = "meta-llama/Llama-2-7b-hf" + + batch_size = 32 + prompts = fake_request_texts(batch_size) + + pipe = mii.pipeline(model_name_or_path=model_id, + quantization_mode='wf6af16') + response = pipe(prompts, max_new_tokens=2) + print(f"{len(response)} responses.") diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 64e7e29b1844..60803ee39ccd 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -16,6 +16,16 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel): """ Number of devices to split the model across using tensor parallelism. """ +class QuantizationConfig(DeepSpeedConfigModel): + """ Configure tensor parallelism settings """ + + quantization_mode: str = None + """ The quantization mode in string format. The supported modes are as follows: + - 'wf6af16', weight-only quantization with FP6 weight and FP16 activation. + """ + # TODO: may reuse the constants in deepspeed/compression/constants.py + + class RaggedInferenceEngineConfig(DeepSpeedConfigModel): """ Sets parameters for DeepSpeed Inference Engine. 
""" @@ -29,3 +39,5 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel): """ Configuration for managing persistent state """ + + quantization: QuantizationConfig = Field({}, alias="quantization") diff --git a/deepspeed/inference/v2/kernels/core_ops/__init__.py b/deepspeed/inference/v2/kernels/core_ops/__init__.py index bbb53e5b58a2..1d16b484a560 100644 --- a/deepspeed/inference/v2/kernels/core_ops/__init__.py +++ b/deepspeed/inference/v2/kernels/core_ops/__init__.py @@ -8,3 +8,4 @@ from .cuda_layer_norm import * from .cuda_rms_norm import * from .gated_activations import * +from .cuda_linear import * diff --git a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp index 58df88e56136..fccb248816fc 100644 --- a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp @@ -8,6 +8,7 @@ #include "bias_activation.h" #include "blas.h" +#include "cuda_linear_kernels.h" #include "gated_activation_kernels.h" #include "layer_norm.h" #include "rms_norm.h" @@ -33,4 +34,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // rms_norm.h m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA"); m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA"); + + // cuda_linear_kernels.h + m.def("cuda_wf6af16_linear", &cuda_wf6af16_linear, "DeepSpeed Wf6Af16 linear in CUDA"); + m.def( + "preprocess_weight", &preprocess_weight, "preprocess the FP16 weight to be 2bit and 4 bit"); + m.def("preprocess_scales", &preprocess_scales, "preprocess the FP16 scales"); } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp new file mode 100644 index 000000000000..ba0592de08e4 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp @@ -0,0 +1,300 @@ +#include "GenMatrix_QuantLLM.h" +#include +#include +#include +#include + +using namespace std; + +void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], + unsigned char* FP6_Array) // padding 0 to the lowerest bit location +{ + Padded_FP6[0] = FP6_Array[0] & 0xfc; + Padded_FP6[1] = (FP6_Array[0] << 6) | ((FP6_Array[1] >> 2) & 0xfc); + Padded_FP6[2] = (FP6_Array[1] << 4) | ((FP6_Array[2] >> 4) & 0xfc); + Padded_FP6[3] = FP6_Array[2] << 2; + Padded_FP6[4] = FP6_Array[3] & 0xfc; + Padded_FP6[5] = (FP6_Array[3] << 6) | ((FP6_Array[4] >> 2) & 0xfc); + Padded_FP6[6] = (FP6_Array[4] << 4) | ((FP6_Array[5] >> 4) & 0xfc); + Padded_FP6[7] = FP6_Array[5] << 2; +} + +unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, + unsigned char B2, + unsigned char B3, + unsigned char B4) +{ + unsigned char out; + out = (B1 & 0xc0) | ((B2 & 0xc0) >> 2) | ((B3 & 0xc0) >> 4) | ((B4 & 0xc0) >> 6); + return out; +} + +unsigned char Extract_4_Bits_From_2_PaddedFP6( + unsigned char B1, + unsigned char + B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); +{ + unsigned char out; + out = ((B1 << 2) & 0xf0) | ((B2 >> 2) & 0x0f); + return out; +} + +// dealing with 4 1*8 blocks of FP6 +void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], + vector Seg_4bit[], + unsigned char* PTR_1, + unsigned char* PTR_2, + unsigned char* PTR_3, + unsigned char* PTR_4) +{ + unsigned char Padded_8_FP8[4][8]; + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); 
+ // + unsigned char Seg1_Byte1_T[4]; + unsigned char Seg1_Byte2_T[4]; + unsigned char Seg2_Byte1_T[4]; + unsigned char Seg2_Byte2_T[4]; + unsigned char Seg2_Byte3_T[4]; + unsigned char Seg2_Byte4_T[4]; + for (int t = 0; t < 4; t++) { + Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0 + t * 2], + Padded_8_FP8[0][1 + t * 2], + Padded_8_FP8[1][0 + t * 2], + Padded_8_FP8[1][1 + t * 2]); + Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0 + t * 2], + Padded_8_FP8[2][1 + t * 2], + Padded_8_FP8[3][0 + t * 2], + Padded_8_FP8[3][1 + t * 2]); + Seg2_Byte1_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0 + t * 2], Padded_8_FP8[0][1 + t * 2]); + Seg2_Byte2_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0 + t * 2], Padded_8_FP8[1][1 + t * 2]); + Seg2_Byte3_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0 + t * 2], Padded_8_FP8[2][1 + t * 2]); + Seg2_Byte4_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0 + t * 2], Padded_8_FP8[3][1 + t * 2]); + } + // + for (int t = 0; t < 4; t++) { + Seg_2bit[t].push_back(Seg1_Byte1_T[t]); + Seg_2bit[t].push_back(Seg1_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte1_T[t]); + Seg_4bit[t].push_back(Seg2_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte3_T[t]); + Seg_4bit[t].push_back(Seg2_Byte4_T[t]); + } + return; +} + +void BitInterleaving_2bit(unsigned char* PTR_4Bytes) +{ + unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + // int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for + // bit-interleaving in QuantLLM + int order_2bit[16] = { + 2, 6, 10, 14, 4, 8, 12, 16, 1, 5, 9, 13, 3, 7, 11, 15}; // pre-defined order for + // bit-interleaving in QuantLLM + unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. + for (int i = 0; i < 16; i++) Frags_2bit[i] = (input << 2 * (order_2bit[i] - 1)) & 0xc0000000; + // + unsigned int output = 0x00000000; + for (int i = 0; i < 16; i++) output |= (Frags_2bit[i] >> (i * 2)); + // + *PTR_UINT = output; +} + +void BitInterleaving_4bit(unsigned char* PTR_4Bytes) +{ + unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + // int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in QuantLLM + int order_4bit[8] = { + 2, 6, 4, 8, 1, 5, 3, 7}; // pre-defined order for bit-interleaving in QuantLLM + unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. + for (int i = 0; i < 8; i++) Frags_4bit[i] = (input << 4 * (order_4bit[i] - 1)) & 0xf0000000; + // + unsigned int output = 0x00000000; + for (int i = 0; i < 8; i++) output |= (Frags_4bit[i] >> (i * 4)); + // + *PTR_UINT = output; +} + +short GetShort(unsigned char* Scale_In, size_t row, size_t col, size_t BytesPerRow) +{ + unsigned char* PTR_8bit = Scale_In; + PTR_8bit += row * BytesPerRow + col * 2; + short* PTR_16bit = reinterpret_cast(PTR_8bit); + return (*PTR_16bit); +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. 
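+ *
+ * Worked size example (illustrative, not part of the layout definition): for a single 64x64
+ * warp tile (M = K = 64, i.e. 4096 weights), Weight_6bit occupies 4096*6/8 = 3072 bytes,
+ * Weight_2bit occupies 4096*2/8 = 1024 bytes, and Weight_4bit occupies 4096*4/8 = 2048 bytes.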
+ * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ +void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, + unsigned char* Weight_2bit, + unsigned char* Weight_4bit, + size_t M, + size_t K) +{ + assert(M % 64 == 0); + assert(K % 64 == 0); + // + vector A_Segment_2bit[32]; + vector A_Segment_4bit[32]; + // + size_t BytesPerRow = K * 6 / 8; +// Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. +#pragma omp parallel for + for (size_t i = 0; i < M / 64; i++) // + { + for (size_t j = 0; j < K / 16; j++) { + for (size_t k = 0; k < 64 / 16; k++) { + size_t row = i * 64 + k * 16; + size_t col = j * 16; + unsigned char* StartPTR_1 = Weight_6bit + row * BytesPerRow + col * 6 / 8; + unsigned char* StartPTR_2 = StartPTR_1 + 8 * BytesPerRow; + unsigned char* StartPTR_3 = StartPTR_1 + 8 * 6 / 8; + unsigned char* StartPTR_4 = StartPTR_2 + 8 * 6 / 8; + // Dealing with each 16*16 blocks then... + for (int l = 0; l < 8; l++) + Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l * 4], + &A_Segment_4bit[l * 4], + StartPTR_1 + l * BytesPerRow, + StartPTR_2 + l * BytesPerRow, + StartPTR_3 + l * BytesPerRow, + StartPTR_4 + l * BytesPerRow); + } + } + } + // Verifying the length of 2_bit segments and 4_bit segments + size_t BytesPerThread_2bit = M * K * 2 / 8 / 32; + size_t BytesPerThread_4bit = M * K * 4 / 8 / 32; + for (int i = 0; i < 32; i++) { + assert(A_Segment_2bit[i].size() == BytesPerThread_2bit); + assert(A_Segment_4bit[i].size() == BytesPerThread_4bit); + } +// Pass-2: Optimizing coleasced global memory access +#pragma omp parallel for collapse(2) + for (int t = 0; t < 32; t++) { + for (int b = 0; b < 4; b++) { + for (size_t i = 0; i < BytesPerThread_2bit / 4; i++) + Weight_2bit[i * 128 + t * 4 + (3 - b)] = + A_Segment_2bit[t] + [i * 4 + b]; // why (3-b): special byte order within a register + for (size_t i = 0; i < BytesPerThread_4bit / 4; i++) + Weight_4bit[i * 128 + t * 4 + (3 - b)] = + A_Segment_4bit[t] + [i * 4 + b]; // why (3-b): special byte order within a register + } + } + + // Pass-3: Bit-level interleaving + for (size_t i = 0; i < BytesPerThread_2bit * 32 / 4; i++) + BitInterleaving_2bit(Weight_2bit + 4 * i); + for (size_t i = 0; i < BytesPerThread_4bit * 32 / 4; i++) + BitInterleaving_4bit(Weight_4bit + 4 * i); +} + +/* + * Inputs: + * (1) unsigned char Scale_In[M*K/GroupSize*16/8] + * Outputs: + * (1) unsigned char Scale_Out[M*K/GroupSize*16/8] + */ +void GenMatrix_Scale_FP16(unsigned char* Scale_Out, + unsigned char* Scale_In, + size_t M, + size_t K, + int GroupSize) +{ + short* Out_PTR = reinterpret_cast(Scale_Out); + // + assert(K % GroupSize == 0); + size_t BytesPerRow = K / GroupSize * 2; + // + for (size_t i = 0; i < M / 64; i++) + for (size_t j = 0; j < K / GroupSize; j++) + for (int l = 0; l < 8; l++) { + *Out_PTR = GetShort(Scale_In, 0 + 64 * i + l, j, BytesPerRow); + Out_PTR += 1; + *Out_PTR = GetShort(Scale_In, 8 + 64 * i + l, j, BytesPerRow); + Out_PTR += 1; + *Out_PTR = GetShort(Scale_In, 16 + 64 * i + l, j, BytesPerRow); + Out_PTR += 1; + *Out_PTR = GetShort(Scale_In, 24 + 64 * i + l, j, BytesPerRow); + Out_PTR += 1; + *Out_PTR = GetShort(Scale_In, 32 + 64 * i + l, j, BytesPerRow); + Out_PTR += 1; + *Out_PTR = GetShort(Scale_In, 40 + 64 * i + l, j, BytesPerRow); + Out_PTR += 1; + *Out_PTR = GetShort(Scale_In, 48 + 64 * i + l, j, BytesPerRow); + Out_PTR += 1; + *Out_PTR = GetShort(Scale_In, 56 + 64 * i + l, j, BytesPerRow); + Out_PTR += 1; + } + return; +} + +void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) +{ + constexpr int exponent_bits_fp6 = 3; + constexpr 
int mantissa_bits_fp6 = 2; + // Constants for FP16 + constexpr int exponent_bits_fp16 = 5; + constexpr int mantissa_bits_fp16 = 10; + constexpr int exp_bias_fp16 = (1 << (exponent_bits_fp16 - 1)) - 1; + + uint8_t fp6_temp[4]; + + for (int i = 0; i < 4; ++i) { + int sign = (FP16x4[i] >> 15); + int exp = (FP16x4[i] >> mantissa_bits_fp16) & + ((1 << exponent_bits_fp16) - 1); // Extracting exponent + int mant = FP16x4[i] & ((1 << mantissa_bits_fp16) - 1); // Extracting mantissa + + int new_exp = exp - exp_bias_fp16; + int new_mant = mant >> (mantissa_bits_fp16 - mantissa_bits_fp6); + + fp6_temp[i] = (sign << (exponent_bits_fp6 + mantissa_bits_fp6)) | + (new_exp << mantissa_bits_fp6) | new_mant; + } + // Pack the values + FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); + FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); + FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; +} + +/* + * Inputs: + * (1) uint16_t Weight_16bit[M*K] + * Outputs: + * (1) unsigned char Weight_6bit[M*K*6/8] + */ +void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t M, size_t K) +{ +#pragma omp parallel for + for (auto m = 0; m < M; m++) { + uint8_t* ptr_6bit = Weight_6bit + m * K * 6 / 8; + uint16_t* ptr_16bit = Weight_16bit + m * K; + for (auto k = 0; k < K; k += 4) { + Cast_FP16_FP6(ptr_16bit, ptr_6bit); + ptr_16bit += 4; + ptr_6bit += 3; + } + } +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h new file mode 100644 index 000000000000..77810eb39ba5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h @@ -0,0 +1,41 @@ +#pragma once +#include +#include + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ +void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, + unsigned char* Weight_2bit, + unsigned char* Weight_4bit, + size_t M, + size_t K); + +/* + * Inputs: + * (1) unsigned char Scale_In[M*K/GroupSize*16/8] + * Outputs: + * (1) unsigned char Scale_Out[M*K/GroupSize*16/8] + */ +void GenMatrix_Scale_FP16(unsigned char* Scale_Out, + unsigned char* Scale_In, + size_t M, + size_t K, + int GroupSize); + +/* + * Inputs: + * (1) uint16_t Weight_16bit[M*K] + * Outputs: + * (1) unsigned char Weight_6bit[M*K*6/8] + */ +void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, unsigned char* Weight_6bit, size_t M, size_t K); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh new file mode 100644 index 000000000000..46cc0bc5a002 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh @@ -0,0 +1,211 @@ +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ***************************************************************************/ +#include "utils/Configs.h" +#include "utils/Utils_GMem.cuh" +#include "utils/Utils_Core.cuh" + +__device__ __forceinline__ void ExchangePTRs(void** PTR1, void** PTR2) { + void* tmp_PTR = *PTR1; + *PTR1 = *PTR2; + *PTR2 = tmp_PTR; +} + +/* + * C = A*B + * A: row major with ahead-of-time layout transformation, FP6 + * B: col major, FP16 + * C: col major, FP16 + */ + template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, const half* Scales, const int QUANT_GROUP_SIZE_DIVIDED_BY_64, + const half *B, + OutputDataType* C, + const size_t M_Global, const size_t N_Global, const size_t K_Global, + int Split_K) +{ + #ifdef DEBUG_MODE + assert(K_Global%TilingConfig::TILE_K==0); + assert(M_Global%TilingConfig::TILE_M==0); + assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M)); + #endif + extern __shared__ __align__(128) half smem[]; // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned + half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + (SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE)/2 ); // Dynamic shared memory for FP16 B tiles + __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS*2]; // static shared memory for quantization scales, 64 row per warp * 4 warps * 2 double buffer = 1 KB + // Thread Block Mapping, considering SplitK + const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); + const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t Tile_Start_M = y * TilingConfig::TILE_M; + const size_t Tile_Start_N = x * TilingConfig::TILE_N; + const size_t NumBlock_K = K_Global/TilingConfig::TILE_K; + const size_t AverageNumBlock_K = NumBlock_K/Split_K; + const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; + size_t NumIter = AverageNumBlock_K; + if(BatchID(smem); + uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR+SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers + // StartSPTR for each WARP + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4; + // Pre-fetch of A tile + for(int i=0; i(AFrag_2BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4, WARP_StartGPTR_A1); + CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4, WARP_StartGPTR_A2); + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; + } + // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// + #ifdef DEBUG_MODE + assert(NumBlock_K%QUANT_GROUP_SIZE_DIVIDED_BY_64==0); + #endif + const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS)* (NumBlock_K/QUANT_GROUP_SIZE_DIVIDED_BY_64) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 
(NumBlock_K/QUANT_GROUP_SIZE_DIVIDED_BY_64) * 64; + size_t UnitID_K = WARP_Start_UnitID_K; + size_t QuantGroup_K = UnitID_K / QUANT_GROUP_SIZE_DIVIDED_BY_64; + CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales + QuantGroup_K*64); + // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// + const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global); + BTile_GPTR += TilingConfig::TILE_K; + } + // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// + constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block +#ifdef PIPELINE_LEVEL_SMEM + uint32_t a [NumRegSets_a * PIPELINE_LEVEL_SMEM][4]; // double/Trible buffer is used // Registers to store decompressed FP6 + uint32_t b [NumRegSets_b * PIPELINE_LEVEL_SMEM][4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) +#endif + float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; + for(int i=0; i(a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); +#endif + // The outer loop. ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + #pragma unroll(1) + for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) + { + // Trible-Buffer for A Tile + uint32_t* __restrict__ read_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 +#ifdef PIPELINE_LEVEL_SMEM + uint32_t* __restrict__ read2_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; + uint32_t* __restrict__ read2_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; +#endif + uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + // Trible-Buffer for B Tile + half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#ifdef PIPELINE_LEVEL_SMEM + half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#endif + half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + // + bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; + /* + // Optionally Updating QuantScale for A Tile + UnitID_K = WARP_Start_UnitID_K + tile_id_k; + QuantGroup_K = UnitID_K / 
QUANT_GROUP_SIZE_DIVIDED_BY_64; + //bool SwitchQuantGroup = (UnitID_K % QUANT_GROUP_SIZE_DIVIDED_BY_64 == (QUANT_GROUP_SIZE_DIVIDED_BY_64-1)); + bool SwitchQuantGroup = false; + if(SwitchQuantGroup) CopyFromGlobalToShared_Scales(write_WARP_SPTR_Scales, WARP_StartGPTR_A_Scales + (QuantGroup_K+1)*64, GlobalCopy); // If the next loop need the new scales, load the scales from global to shared. + //bool IsNewQuantGroup = (UnitID_K % QUANT_GROUP_SIZE_DIVIDED_BY_64 == 0); + bool IsNewQuantGroup = false; + if(IsNewQuantGroup) ExtractFromSharedToReg_Scales(Scales_RPTR, read_WARP_SPTR_Scales); // If the curent loop need the new scales, load the scales from shared to registers. + // Exchanging the PTRs for double buffers + if(SwitchQuantGroup) ExchangePTRs((void**)&read_WARP_SPTR_Scales, (void**)&write_WARP_SPTR_Scales); // SPTRs for Scales + */ + + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + CopyFromGlobalToShared_A(write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); + CopyFromGlobalToShared_A(write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); + // copying B tile from GlobalMemory to SharedMemory + CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, GlobalCopy); + cp_async_group_commit(); + #ifdef PIPELINE_LEVEL_SMEM + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); + core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + core_mma_slice(c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 0); + // Updating global PTRs + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + #else + PipelinedCoreLoop(c, read_SPTR, read_SPTR_Frag1, read_SPTR_Frag2, Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs + // Updating global PTRs + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + #endif + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store the C fragments to shared memory. + float (*smem_CFrag) [TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4] = + reinterpret_cast (smem); + StoreToSharedMemoryFromRegister(smem_CFrag, c); + __syncthreads(); + // Now that shared memory contains all the D tiles, stream them to global memory. + OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global; + + size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? 
(N_Global-Tile_Start_N) : TilingConfig::TILE_N; + for(size_t i=warpId; i::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); + else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; + } +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_Reduction.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_Reduction.cuh new file mode 100644 index 000000000000..442de103b8d1 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_Reduction.cuh @@ -0,0 +1,47 @@ +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ***************************************************************************/ +// Used for the reduction of result matrix if Split-K is used +// Reduction_Workspace: (Split_K, M_Global, N_Global), column major +// C: (M_Global, N_Global), column major +// Each thread deals with 8 output elements, each elements is the sum of Split_K elements +// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp +// Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> 256 half per warp +// GridSize = (M_Global*N_Global) / 256 + +#include +#include +#include + +#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 +#define HALF_PER_128BIT 8 + +__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) +{ + half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + // Initializing Thread-Local Results + float Results[HALF_PER_128BIT]; + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; + // Reduction + for (int i = 0; i < Split_K; i++) { + #pragma unroll + for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; + THREAD_GPTR_R += M_Global * N_Global; + } + // Writing to global memory + #pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu new file mode 100644 index 000000000000..6f9a8daa499b --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu @@ -0,0 +1,287 @@ +#include "GenMatrix_QuantLLM.h" +#include "Kernel_QuantGEMM.cuh" +#include "Kernel_Reduction.cuh" + +#include +#include + +template +static void Kernel_QuantGEMM_Ex(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const int QUANT_GROUP_SIZE_DIVIDED_BY_64, + const half* B, + OutputDataType* C, + const size_t M_Global, + const size_t N_Global, + 
const size_t K_Global, + int Split_K) +{ +#ifdef DEBUG_MODE + printf("\n"); + printf("Launcher.cu->Kernel_QuantGEMM_Ex():\n"); + printf("M: %d, N: %d, K: %d, SplitK: %d, QUANT_GROUP_SIZE_DIVIDED_BY_64: %d\n", + M_Global, + N_Global, + K_Global, + Split_K, + QUANT_GROUP_SIZE_DIVIDED_BY_64); + printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", + TilingConfig::TILE_M, + TilingConfig::TILE_K, + TilingConfig::TILE_N); + // assert(N_Global % TilingConfig::TILE_N == 0); + // assert(M_Global*Split_K % TilingConfig::TILE_M == 0); + // assert(K_Global % TilingConfig::TILE_K == 0); +#endif + static size_t SHMEM_SZ = + max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, + TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + SHMEM_SZ); + size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; + size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; + dim3 GridDim(dimN, dimM, 1); + dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); +// +#ifdef DEBUG_MODE + printf( + "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: " + "%d SHMEM_SZ: %d\n", + GridDim.x, + GridDim.y, + GridDim.z, + BlockDim.x, + BlockDim.y, + BlockDim.z, + SHMEM_SZ); + printf("\n"); +#endif + QUANT_GEMM_Kernel + <<>>(Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + C, + M_Global, + N_Global, + K_Global, + Split_K); +} + +/* + *half* Reduction_Workspace: 1. Requiring an extra memory space in device memory for un-reducted + *intermediate output tensors + * 2. Reduction_Workspace_Size = Split_K * M_Global * N_Global * + *sizeof(fp32) + */ +cudaError_t QuantGEMM_API( + cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const int QUANT_GROUP_SIZE_DIVIDED_BY_64, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches + int Split_K) +{ + if (N_Global <= 0) { + printf("QuantLLM_API Error: Unsupported N dimension %ld!\n", N_Global); + return cudaErrorUnknown; + } + + // Work around to support more N shapes: Pretending that the input is 2^n + size_t N_PowerOf2; + if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8; + if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16; + if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32; + if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64; + if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128; + if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128; + // printf("N_Global:%d N_PowerOf2:%d\n", N_Global, N_PowerOf2); + + if (Split_K == 1) { + switch (N_PowerOf2) { + case 8: + Kernel_QuantGEMM_Ex, half>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + C, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 16: + Kernel_QuantGEMM_Ex, half>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + C, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 32: + Kernel_QuantGEMM_Ex, half>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + C, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 64: + Kernel_QuantGEMM_Ex, half>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + C, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 128: + Kernel_QuantGEMM_Ex, half>(stream, + Weight1, + Weight2, + Scales, + 
QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + C, + M_Global, + N_Global, + K_Global, + Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("QuantLLM_API Error: Unsupported N dimension %ld!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_QuantGEMM_Ex, half>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + C, + M_Global, + N_Global, + K_Global, + Split_K); + break; + } + } else { + switch (N_PowerOf2) { + case 8: + Kernel_QuantGEMM_Ex, float>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 16: + Kernel_QuantGEMM_Ex, float>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 32: + Kernel_QuantGEMM_Ex, float>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 64: + Kernel_QuantGEMM_Ex, float>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 128: + Kernel_QuantGEMM_Ex, float>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("QuantLLM_API Error: Unsupported N dimension %ld!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_QuantGEMM_Ex, float>(stream, + Weight1, + Weight2, + Scales, + QUANT_GROUP_SIZE_DIVIDED_BY_64, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + } + // Reduction for SplitK + dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); + dim3 BlockDim(WARP_SIZE, 1, 1); + SplitK_Reduction<<>>( + C, Reduction_Workspace, M_Global, N_Global, Split_K); + } + return cudaGetLastError(); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py new file mode 100644 index 000000000000..cd08409c0a7a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cuda_linear import * diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py new file mode 100644 index 000000000000..9431ddc96f41 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -0,0 +1,68 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from deepspeed.ops.op_builder import InferenceCoreBuilder +from typing import Tuple +from ... import DSKernelBase + + +class CUDAWf6Af16Linear(DSKernelBase): + """ + Wrapper around the CUDA kernel of Wf6Af16 quantized linear. 
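+
+    The weight is expected to be pre-packed offline into a 2-bit stream and a 4-bit stream
+    (see preprocess_weight) plus per-group FP16 scales; the kernel dequantizes the FP6 weights
+    to FP16 on the fly before the GEMM.
+
+    Usage sketch (names as defined in this class; allocation and preprocessing of the tensors
+    are left to the caller):
+
+        linear_kernel = CUDAWf6Af16Linear()
+        linear_kernel(output, hidden_states, weights_4bit, weights_2bit, scale, M, N, K)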
+ + Performs z = x @ y + """ + supported_dtypes = [DtypeEnum.fp16] + + def __init__(self): + self.inf_module = InferenceCoreBuilder().load() + self.inf_module.create_handle() + self.kernel = self.inf_module.cuda_wf6af16_linear + + def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_4bit: torch.Tensor, weights_2bit: torch.Tensor, scale: torch.Tensor, M, N, K) -> torch.Tensor: + """ + Matmul kernel as implemented via CUDA directly. The input must be 2D or larger. If + n-dimensional, the leading dimensions are folded into each other: + 2D: m = x.size(0) + 3D: m = x.size(0) * x.size(1) + 4D: m = x.size(0) * x.size(1) * x.size(2) (etc...) + All inputs should be contiguous. + + Parameters: + output (torch.Tensor): Output tensor. Shape is of [*, out_features] + hidden_states (torch.Tensor): Input tensor. Shape is of [*, in_features] + weights (torch.Tensor): Input tensor. Shape is of [out_features, in_features] + scale (torch.Tensor): Input tensor. Shape is of [1] or [out_features], since the scale is per output channel + + Returns: + z (torch.Tensor): Output tensor. Shape is of [m, n] + """ + + # TODO: deal with batched-matmul. As the current implementation only supports 2D input, we need to split the + # batched-matmul into multiple 2D matmul. + + # TODO: optimize the heuristic of split k selection. + split_k_dict = {15360: 3, 27648: 2, 5120: 10, 10240: 5, + 57344: 7, 8192: 6, 21504: 5, 7168: 7, 28672: 7} + split_k = 1 + if not N > 128 and M in split_k_dict: + split_k = split_k_dict[M] + workspace = self.get_workspace( + M, N, K, split_k, torch.float, hidden_states.device) + self.kernel(output, hidden_states, weights_4bit, + weights_2bit, scale, workspace, M, N, K, split_k) + + def get_workspace(self, M: int, N: int, K: int, split_k: int, dtype, device) -> torch.Tensor: + """ + Allocate workspace for the kernel. The workspace is used to store the intermediate results of the matmul before + split-K. The split-K size is determined by the size of the matmul. + """ + workspace = torch.empty((split_k, M, N), dtype=dtype, device=device) + # TODO: allocate workspace in advance to avoid memory allocation overhead + + return workspace diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp new file mode 100644 index 000000000000..eb4ad0b9318b --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include "cuda_linear_kernels.h" +#include "GenMatrix_QuantLLM.h" + +void Launch_QuantGEMM(torch::Tensor output, + torch::Tensor Weight1, // 2bit + torch::Tensor Weight2, // 4bit + torch::Tensor B, + torch::Tensor Scales, + const int M_Global, + const int N_Global, + const int K_Global, + const int Split_K, + torch::Tensor workspace); + +void cuda_wf6af16_linear(torch::Tensor& output, + torch::Tensor& hidden_states, + torch::Tensor& weights_4bit, + torch::Tensor& weights_2bit, + torch::Tensor& scale, + torch::Tensor& workspace, + int M, + int N, + int K, + int split_k) +{ + TORCH_CHECK(weights_2bit.device().type() == torch::kCUDA, "weight_2bit must be on CUDA"); + TORCH_CHECK(weights_4bit.device().type() == torch::kCUDA, "weight_4bit must be on CUDA"); + TORCH_CHECK(hidden_states.device().type() == torch::kCUDA, "X must be on CUDA"); + TORCH_CHECK(scale.device().type() == torch::kCUDA, "scale must be on CUDA"); + Launch_QuantGEMM( + output, weights_2bit, weights_4bit, hidden_states, scale, M, N, K, split_k, workspace); +} + +/* + * Inputs: + * (1) torch::Tensor weight[M, K] in FP16 + * Outputs: + * (1) torch::Tensor weight_2bit and weight_4bit + */ +std::vector preprocess_weight(torch::Tensor& Weight) +{ + TORCH_CHECK(Weight.dim() == 2, "weight must be 2-dimensional"); + TORCH_CHECK(Weight.scalar_type() == torch::kFloat16, "weight must be FP16"); + TORCH_CHECK(Weight.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(Weight.device().type() == torch::kCPU, "weight must be on CPU"); + auto M = Weight.size(0); + auto K = Weight.size(1); + TORCH_CHECK(K % 4 == 0, "K must be multiple of 4"); + + // Pack Weight + auto Weight_ptr = Weight.data_ptr(); + std::vector Weight_6bit_Packed(M * K * 6 / 8); + PackMatrix_Weight_FP6((uint16_t*)Weight_ptr, Weight_6bit_Packed.data(), M, K); + + // Split Weight + auto Weight_2bit = torch::empty({M * K * 2 / 8}, torch::kUInt8); + auto Weight_4bit = torch::empty({M * K * 4 / 8}, torch::kUInt8); + GenMatrix_Weight_FP6(Weight_6bit_Packed.data(), + Weight_2bit.data_ptr(), + Weight_4bit.data_ptr(), + M, + K); + + return {Weight_2bit, Weight_4bit}; +} + +/* + * Inputs: + * (1) torch::Tensor Scale_In[M, K/GroupSize] in FP16 + * Outputs: + * (1) torch::Tensor Scale_Out[M, K/GroupSize] in FP16 + */ + +torch::Tensor preprocess_scales(torch::Tensor& Scale, int M, int K) +{ + // Preprocess scales + TORCH_CHECK(Scale.dim() == 2, "scale must be 2-dimensional"); + TORCH_CHECK(Scale.size(0) == M, "scale must have same M as weight"); + TORCH_CHECK(Scale.is_contiguous(), "scale must be contiguous"); + TORCH_CHECK(Scale.device().type() == torch::kCPU, "scale must be on CPU"); + TORCH_CHECK(Scale.scalar_type() == torch::kFloat16, "scale must be FP16"); + auto GroupSize = K / Scale.size(1); + TORCH_CHECK(GroupSize % 64 == 0, "GroupSize must be multiple of 64"); + auto New_Scale = torch::empty_like(Scale); + auto Scale_out = New_Scale.data_ptr(); + auto Scale_in = New_Scale.data_ptr(); + GenMatrix_Scale_FP16((uint8_t*)Scale_out, (uint8_t*)Scale_in, M, K, GroupSize); + return New_Scale; +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu new file mode 100644 index 000000000000..194b9120a9eb --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include "cuda_linear_kernels.h" + +cudaError_t QuantGEMM_API( + cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const int QUANT_GROUP_SIZE_DIVIDED_BY_64, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches + int Split_K); + +void Launch_QuantGEMM(torch::Tensor C, + torch::Tensor Weight1, // 2bit + torch::Tensor Weight2, // 4bit + torch::Tensor B, + torch::Tensor Scales, + const int M_Global, + const int N_Global, + const int K_Global, + const int Split_K, + torch::Tensor workspace) +{ + auto C_ptr = C.data_ptr(); + auto B_ptr = B.data_ptr(); + auto W1_ptr = Weight1.data_ptr(); + auto W2_ptr = Weight2.data_ptr(); + auto Group_Size = K_Global / Scales.size(1); + + // auto workspace_size = M_Global * N_Global * Split_K; + // auto workspace = torch::empty({workspace_size}, torch::kFloat16); + + auto status = QuantGEMM_API(at::cuda::getCurrentCUDAStream(), + (uint4*)W1_ptr, + (uint4*)W2_ptr, + (half*)Scales.data_ptr(), + Group_Size / 64, + (half*)B_ptr, + (half*)C_ptr, + M_Global, + N_Global, + K_Global, + workspace.data_ptr(), + Split_K); + if (status != cudaSuccess) { + AT_ERROR("QuantGEMM_API failed with error: ", cudaGetErrorString(status)); + } +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h new file mode 100644 index 000000000000..0a2e2bd4197e --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ds_kernel_utils.h" + +void cuda_wf6af16_linear(torch::Tensor& output, + torch::Tensor& hidden_states, + torch::Tensor& weights_4bit, + torch::Tensor& weights_2bit, + torch::Tensor& scale, + torch::Tensor& workspace, + int M, + int N, + int K, + int split_k); + +std::vector preprocess_weight(torch::Tensor& Weight); + +torch::Tensor preprocess_scales(torch::Tensor& Scale, int M, int K); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/quant_gemm_api.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/quant_gemm_api.cuh new file mode 100644 index 000000000000..35d899aa642c --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/quant_gemm_api.cuh @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include +#include +#include + +/* + *half* Reduction_Workspace: 1. Requiring an extra memory space in device memory for un-reducted intermediate output tensors + * 2. 
Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp16) + */ +cudaError_t QuantGEMM_API(cudaStream_t stream, + const uint4 *Weight1, + const uint4 *Weight2, + const half *Scales, + const int QUANT_GROUP_SIZE_DIVIDED_BY_64, + const half *B, + half *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float *Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches + int Split_K); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Configs.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Configs.h new file mode 100644 index 000000000000..d910454fb819 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Configs.h @@ -0,0 +1,89 @@ +#ifndef CONFIGS_H +#define CONFIGS_H + +//#define DEBUG_MODE +#define PIPELINE_LEVEL_GMEM 2 +#define PIPELINE_LEVEL_SMEM 2 // only support 2 + +/************************ Hardware Parameters ************************/ +#define WARP_SIZE 32 +#define REG_BIT_WIDTH 32 +// mma: M=16 K=16 N=8 +#define MMA_8 8 +#define MMA_16 16 +// for memory access +#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ... +#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16 + +/******************** Register Allocation For GEMM ********************/ +#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation +/********************** Memory Padding Parameters **********************/ +// Eliminating bank-conflict +#define PADDING_BYTES_16 16 // Padding 16 bytes each column +#define PADDING_SHARED_MEM_FOR_B_8 \ + 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B +#define PADDING_SHARED_MEM_FOR_C_4 \ + 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C +/************************* WARP Tiling part-1 *************************/ +#define WARP_ROW_MMA_TENSORS 4 +#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64 +#define WARP_K_MMA_TENSORS 4 +#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64 +template +struct TilingConfig { + // Depending on "n" dimension of the GEMM + static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_; + static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_; + static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_; + /************************* WARP Tiling part-2 *************************/ + static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8; + /*************************Thread Block Tiling *************************/ + static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS; + static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS; + static constexpr int TILE_K = WARP_K; + /********************** #Thread per Thread Block **********************/ + static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS; + static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE; + /******************************* Others *******************************/ + static constexpr int SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * + PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2 + static constexpr int SMEM_SIZE_C_TILE = + TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 +}; + +/************************ General Config for Quant-LLM **********************/ +#define WEIGHT_FRAG1_BIT_WIDTH 2 +#define WEIGHT_FRAG2_BIT_WIDTH 4 +#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH + WEIGHT_FRAG2_BIT_WIDTH) // 6 +//#define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // +// QuantGroupSize: 4*64 = 256 +/*************************** 64*64 
Weghts of A WARP *************************/ +#define WEIGHT_PER_UNIT (WARP_M * WARP_K) // 64*64 +#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 \ + (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / \ + 8) // 1024 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 \ + (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / \ + 8) // 2048 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_A1_TILE \ + (SMEM_SIZE_IN_BYTES_PER_WARP_A1 * 4 * \ + PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 12 KB; double + // buffer for 2-level pipeline A= 8 KB. +#define SMEM_SIZE_A2_TILE \ + (SMEM_SIZE_IN_BYTES_PER_WARP_A2 * 4 * \ + PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 24 KB; double + // buffer for 2-level pipeline A= 16 KB. +/******************** Gloabl Memory Layout For QUANTIZED DATA ******************/ +#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / 128) // 64 +#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / 128) // 128 +/******************** Register Allocation For QUANTIZED DATA ******************/ +#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT / WARP_SIZE) // 128 +#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 2) // 8 +#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 4) // 16 +/******************** Register Allocation For QUANT Scales ******************/ +#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers +#define WARP_REG_QUANT_SCALE_DISTRIBUTED \ + 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for + // each thread + +#endif // CONFIGS_H \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_cp.async.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_cp.async.cuh new file mode 100644 index 000000000000..90b00c386f2f --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_cp.async.cuh @@ -0,0 +1,59 @@ +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ***************************************************************************/ +// Extended from CUTLASS's source code + +#ifndef PTX_CP_ASYNC_CUH +#define PTX_CP_ASYNC_CUH + +#include +#include +#include + +template +__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true) +{ + static_assert(SizeInBytes == 16, "Size is not supported"); + unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); + asm volatile("{ \n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), + "l"(global_ptr), + "n"(SizeInBytes)); +} + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. 
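+/// Usage sketch of the commit/wait idiom as used by Kernel_QuantGEMM.cuh (illustrative only):
+/// issue the cp_async() copies for the next pipeline stage, call cp_async_group_commit(),
+/// keep computing on the current stage, then call cp_async_wait_group<...>() followed by
+/// __syncthreads() before consuming the newly copied tile.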
+__device__ __forceinline__ void cp_async_group_commit() +{ + asm volatile("cp.async.commit_group;\n" ::); +} + +/// Blocks until all but previous cp.async.commit_group operations have committed. +template +__device__ __forceinline__ void cp_async_wait_group() +{ + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +} + +/// Blocks until all previous cp.async.commit_group operations have committed. +// cp.async.wait_all is equivalent to : +// cp.async.commit_group; +// cp.async.wait_group 0; +__device__ __forceinline__ void cp_async_wait_all() +{ + asm volatile("cp.async.wait_all;\n" ::); +} + +#endif \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_mma.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_mma.cuh new file mode 100644 index 000000000000..94b8e0c94d92 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_mma.cuh @@ -0,0 +1,113 @@ +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + ***************************************************************************/ +#ifndef PTX_MMA_CUH +#define PTX_MMA_CUH + +#include +#include +#include + +#include +#include "Configs.h" + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + int slice_id) { + #ifdef DEBUG_MODE + static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); + #endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory + #ifdef DEBUG_MODE + assert( warp_start_col==0 ); + #endif + + int col = (lane_id%8) + (lane_id/16)*8; + int row = (lane_id%16) / 8 * 8; + uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row])); + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } + else { + #pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) + { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +} +#else +// Debug: Whether ldmatrix.trans is required??? 
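+// NOTE: this fallback path is only compiled when PIPELINE_LEVEL_SMEM is undefined;
+// Configs.h defines PIPELINE_LEVEL_SMEM by default, so the variant above is the one normally built.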
+// B is in column-major +template +__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + int k_offset) { + #ifdef DEBUG_MODE + static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); + #endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory + #ifdef DEBUG_MODE + assert( warp_start_col==0 ); + #endif + + int col = (lane_id%8) + (lane_id/16)*8; + int row = (lane_id%16) / 8 * 8; + uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][k_offset + row])); + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } + else { + #pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) + { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +} +#endif + +__device__ __forceinline__ void +MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b) +{ + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + +#endif \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_Core.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_Core.cuh new file mode 100644 index 000000000000..a82d097a2c8a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_Core.cuh @@ -0,0 +1,212 @@ +/*************************************************************************** + * Copyright 2023 The FLash-LLM Authors. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ ***************************************************************************/ +#ifndef UTILS_CORE_CUH +#define UTILS_CORE_CUH + +#include + +#include "Configs.h" +#include "PTX_mma.cuh" +#include "Utils_ParallelDequant.cuh" + + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) { + SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE); + int lane_id = threadIdx.x % WARP_SIZE; + #pragma unroll + for(int i=0; i +__device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales) +{ + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, 0); + CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, 0); + Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers +} + +template +__device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], + uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales, + int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching +{ + #ifdef DEBUG_MODE + assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block + #endif + const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block + uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); // Reigsters for accumulated FP32 results + + // Setting RPTRs for double buffers + uint32_t (*a_read )[4] = a; + uint32_t (*a_write)[4] = a; + uint32_t (*b_read )[4] = b; + uint32_t (*b_write)[4] = b; + if(slice_id%2==1) { b_write += NumRegSets_b; a_write += NumRegSets_a;} + else { b_read += NumRegSets_b; a_read += NumRegSets_a;} + + // Reading registers and issuing core tensor core computations (a slice of A and B tile in shared memory) + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); + } + else { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 + } + } + } + + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, slice_id); + CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, slice_id); + Dequant_32FP6_4Way(a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers +} + +#else +// Old version with naive pipeline design +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) { + int lane_id = threadIdx.x % WARP_SIZE; + #pragma unroll + for(int i=0; i +__device__ __forceinline__ void PipelinedCoreLoop(float c[][REG_PER_THREAD_C_TENSOR_16_16], + half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + uint32_t* __restrict__ read_SPTR_Frag1, + uint32_t* __restrict__ read_SPTR_Frag2, + uint32_t* RPTR_Scales) +{ + #ifdef DEBUG_MODE + assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block + #endif + const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block + + // Reigsters to store FP32 results + uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; + uint32_t a_1[2*2]; // double buffer is used + uint32_t a_2[4*2]; // double buffer is used + // Registers to store decompressed FP6 + uint32_t a [NumRegSets_a * 1][4]; // No double buffer + // Register to store FP16 B matrix (a slice) + uint32_t b [NumRegSets_b * 2][4]; // double buffer is used + + // Overlapped Smem and TC pipeline: pre-loading from shared to registers + CopyFromSharedToRegister_AFrag<2> (a_1, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4> (a_2, read_SPTR_Frag2); + B_FromSharedToReg (b, read_SPTR, 0); + + #pragma unroll + for (int k = 0; k < WARP_K_MMA_TENSORS; k++) { + uint32_t (*b_read)[4] = b; + uint32_t (*b_write)[4] = b; + uint32_t *a_1_read = a_1; + uint32_t *a_1_write = a_1; + uint32_t *a_2_read = a_2; + uint32_t *a_2_write = a_2; + if(k%2==0) { + b_write += NumRegSets_b; + a_1_write += 2; + a_2_write += 4; + } + else { + b_read += NumRegSets_b; + a_1_read += 2; + a_2_read += 4; + } + // data loading + if (k + 1 < WARP_K_MMA_TENSORS) { + // updating SPTR for fragment1 and fragment2 + read_SPTR_Frag1 += 2*WARP_SIZE; + read_SPTR_Frag2 += 4*WARP_SIZE; + CopyFromSharedToRegister_AFrag<2>(a_1_write, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4>(a_2_write, read_SPTR_Frag2); + B_FromSharedToReg(b_write, read_SPTR, (k+1)*MMA_16); + } + // SIMT Dequant + Tensor Core computations + Dequant_32FP6_4Way(a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, dequantizing a slice each time + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if(TilingConfig::WARP_COL_MMA_TENSORS==1) + MMA_FP16_M16N8K16( c_uint_ptr[i], a[i], b_read[0] ); + else { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j] ); + MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a[i], b_read[j] + 2 ); // c+4; b+2 + } + } + } + } +} +#endif // #ifdef PIPELINE_LEVEL_SMEM + +template +__device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], + float c[][REG_PER_THREAD_C_TENSOR_16_16]) +{ + const int lane_id = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); + #pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + #pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) { // Dealing with one 16*8 Tensor + int RegSetID = i + (j/2)*WARP_ROW_MMA_TENSORS; + int RegOffset = (j%2)*(REG_PER_THREAD_C_TENSOR_16_16/2); + int Tensor_row_offset = warp_row_offset + i * MMA_16; + int Tensor_col_offset = j * MMA_8; + #pragma unroll + for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16/2; r++) { + int row_offset = lane_id / 4; + if (r >= 2) row_offset += 8; + int col_offset = (lane_id % 4) * 2; + if (r%2==1) col_offset += 1; + smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset]; + } + } + } +} + +#endif \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh 
b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh new file mode 100644 index 000000000000..c507b43d8111 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh @@ -0,0 +1,97 @@ +#ifndef UTILS_GMEM_CUH +#define UTILS_GMEM_CUH + +#include +#include "Configs.h" +#include "PTX_cp.async.cuh" + +template +__device__ __forceinline__ void CopyFromGlobalToReg_ByPassL1 + (uint32_t* RegArray, + const uint4* SourceGlobalAddr, + bool pred_guard = true) { + // Setting GMem pointer + int lane_id = threadIdx.x % WARP_SIZE; + SourceGlobalAddr += lane_id; + #pragma unroll + for(int i=0; i +__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, + const uint4* GPTR, + bool pred_guard = true) { + #ifdef DEBUG_MODE + static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); + #endif + int lane_id = threadIdx.x % WARP_SIZE; + half* SPTR_HALF = reinterpret_cast(SPTR); + const half* GPTR_HALF = reinterpret_cast(GPTR); + SPTR_HALF += lane_id*8; + GPTR_HALF += lane_id*8; + #pragma unroll + for(int i=0; i( SPTR_HALF, GPTR_HALF, pred_guard); + SPTR_HALF += 256; // Forward 512 Bytes + GPTR_HALF += 256; // Forward 512 Bytes + } + +} + + +/* + * Copying 64 Quant Scales (FP16) from global memory to shared memory. + */ +__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, + const half* GPTR_A_Scales, + bool pred_guard = true) { + int lane_id = threadIdx.x % WARP_SIZE; + GPTR_A_Scales += lane_id * 8; // 8 FP16 (16 Bytes) per thread + SPTR_QuantScales += lane_id * 8; // 8 FP16 (16 Bytes) per thread + cp_async<16>( SPTR_QuantScales, GPTR_A_Scales, pred_guard && (lane_id<8)); +} + +/* + * (1) Copying X rows * 64 columns of FP16 values, originally in row major + * (2) Copying 64 rows * X columns of FP16 values, originally in column major + * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, + const int GlobalStride, + bool Pred = true) { + // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time + const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; + const int NumOfGroups = NumOfThreads / 8; + const int MaxIteration = (NumOfLinesToCopy-1) / NumOfGroups + 1; + // runtime variables + const int line_id = threadIdx.x / 8; + const int line_offset = (threadIdx.x%8) * 8; + // PTR for source global memory and target shared memory + GlobalPTR += line_id * GlobalStride + line_offset; + SharedPTR += line_id; + #pragma unroll + for (int i = 0; i < MaxIteration; i++) { + bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesToCopy && Pred; + cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + // + GlobalPTR += NumOfGroups * GlobalStride; + SharedPTR += NumOfGroups; + } +} + +#endif \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh new file mode 100644 index 000000000000..d8694b400699 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh @@ -0,0 +1,121 @@ +#ifndef UTILS_PARALLELDEQUANT_CUH +#define UTILS_PARALLELDEQUANT_CUH + +#include +#include +#include +#include + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is applied. 
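+ * Each 32-bit input register packs four byte-aligned FP6 values; the byte-wise masks and
+ * shifts below only move the sign/exponent/mantissa bits into FP16 bit positions and do
+ * not re-bias the exponent (FP16 bias 15 vs. FP6 e3m2 bias 3). That gap appears to be
+ * what the host side compensates for by multiplying the quantization scales by 2^12.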
+ */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2) +{ + *R2 = *R1 & 0x80808080; + *R1 = *R1 >> 2; + *R1 = *R1 & 0x1f1f1f1f; + *R2 = *R2 | *R1; + *R1 = *R2 & 0x9f009f00; + *R2 = *R2 & 0x009f009f; + *R2 = *R2 << 8; +} + +// Will be removed in the future. +#define DEBUG 1 + +__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) +{ + half* FP16_1 = reinterpret_cast(&PackedFP16Pair); + half* FP16_2 = FP16_1 + 1; + uint32_t output; + half* output_half_ptr = reinterpret_cast(&output); + +#if DEBUG + // Only for testing. Note that the scales should be multiplied by e12 accoring to the Quant-LLM + // optimization. + bool apply_e12 = true; + if (apply_e12) { + // TODO: this will still lead to NaN. Need to fix it. + // 2^12 == 4096 + output_half_ptr[0] = __hmul(*FP16_1, half(4096)); + output_half_ptr[1] = __hmul(*FP16_2, half(4096)); + } else { + output_half_ptr[0] = *FP16_1; + output_half_ptr[1] = *FP16_2; + } +#else + output_half_ptr[0] = *FP16_1 * Scale; + output_half_ptr[1] = *FP16_2 * Scale; +#endif +#if 0 + auto res1 = output_half_ptr[0]; + auto res2 = output_half_ptr[1]; + // printf("FP16_1: %f\n", __half2float(*FP16_1)); + // The following two lines do not show NaN. + if (res1 != res1) { printf("res1 is NaN: %f\n", __half2float(res1)); } + if (res2 != res2) { printf("res2 is NaN: %f\n", __half2float(res2)); } +#endif + + return output; +} + +__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], + u_int32_t __restrict__* read_RPTR_Frag1, + u_int32_t __restrict__* read_RPTR_Frag2, + u_int32_t* Scales) +{ + u_int32_t* OutputRegs = reinterpret_cast(Reg); + u_int32_t* Frag1_PTR = read_RPTR_Frag1; + u_int32_t* Frag2_PTR = read_RPTR_Frag2; + half* Scale_RPTR = reinterpret_cast(Scales); + u_int32_t Packed_FP6 = 0; + u_int32_t tmp = 0; +// Dequantizing 32 FP6, each Loop dequantizing 4 FP6 +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + // Frag1 + Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0; + if (i % 4 == 3) + Frag1_PTR++; + else + (*Frag1_PTR) = (*Frag1_PTR) << 2; + // Frag2 + tmp = (*Frag2_PTR) & 0xf0f0f0f0; + tmp = tmp >> 2; + if (i % 2 == 1) + Frag2_PTR++; + else + (*Frag2_PTR) = (*Frag2_PTR) << 4; + // Packed_FP6 + Packed_FP6 = Packed_FP6 | tmp; + // + FP6_FP16_Cast_4Way(&Packed_FP6, &tmp); + // + *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0]); // Muliply FP16 scales + OutputRegs += 1; + *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales + OutputRegs += 1; + // Updating offset for FP16 scales for every two iterations + if (i % 2 == 1) Scale_RPTR += 2; + } +} + +/* + * + */ +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, + half* WARP_SPTR_Scales) +{ + int lane_id = threadIdx.x % WARP_SIZE; + uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); + uint32_t tmpReg = SPTR_uint[lane_id]; +#pragma unroll + for (int i = 0; i < 4; i++) { + // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); + } +} + +#endif \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu index 84a9906cf037..cfa62f94596a 100644 --- a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu @@ -119,7 +119,10 @@ 
void launch_gated_activation_impl(T* output, DISPATCH_UNROLL(5); } else if (unroll == 6) { DISPATCH_UNROLL(6); + } else if (unroll == 7) { + DISPATCH_UNROLL(7); } else { + // TODO: provide a kernel with an outer loop to handle larger columns. throw std::runtime_error( "Called with more columns than supported, please report this bug and this limit will " "be increased."); diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py index f9da7ac5d23e..5e8618f22c22 100644 --- a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py +++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py @@ -138,7 +138,8 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) try: _ = layer_container.is_populated except ValueError as e: - raise ValueError(f"Layer container {l_name} is not populated.") from e + raise ValueError( + f"Layer container {l_name} is not populated.") from e layer_metadata = LayerMetadata() @@ -156,7 +157,8 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=param.stride(), offset=cur_offset) - cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) + cur_offset += pad_to_aligned_offset( + elem_size(param.dtype) * param.numel()) for t_name, tensor in param.aux_attrs.items(): param_metadata.aux_params[t_name] = TensorMetadata(dtype=str(tensor.dtype), @@ -164,7 +166,8 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=tensor.stride(), offset=cur_offset) - cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) + cur_offset += pad_to_aligned_offset( + elem_size(tensor.dtype) * tensor.numel()) layer_metadata.params[p_name] = param_metadata @@ -178,7 +181,8 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) l_name = "non_transformer" total_size = process_layer(non_transformer_container, l_name, total_size) - buffer = torch.empty(total_size, dtype=torch.uint8, device=get_accelerator().current_device()) + buffer = torch.empty(total_size, dtype=torch.uint8, + device=get_accelerator().current_device()) def copy_layer(layer_container: LayerContainer, l_name: str) -> None: """ @@ -206,11 +210,13 @@ def copy_layer(layer_container: LayerContainer, l_name: str) -> None: aux_params = {} for t_name, tensor in param.aux_attrs.items(): - t_view = alloc_fn(tensor, buffer, p_metadata.aux_params[t_name].offset) + t_view = alloc_fn( + tensor, buffer, p_metadata.aux_params[t_name].offset) aux_params[t_name] = t_view t_view.copy_(tensor) - setattr(layer_container, p_name, InferenceParameter.initialize(core_param, **aux_params)) + setattr(layer_container, p_name, + InferenceParameter.initialize(core_param, **aux_params)) for i, layer in enumerate(transformer_containers): l_name = f"transformer_layer_{i}" @@ -259,19 +265,23 @@ def restore_layer(layer_container: LayerContainer, l_name: str) -> None: layer_container.direct_injection(p_name, None) continue - dummy_tensor = torch.empty([], dtype=STR_TO_DTYPE[p_metadata.core_param.dtype]) + dummy_tensor = torch.empty( + [], dtype=STR_TO_DTYPE[p_metadata.core_param.dtype]) core_param = alloc_fn(p_metadata.core_param.shape, p_metadata.core_param.strides, dummy_tensor, buffer, p_metadata.core_param.offset) aux_params = {} for t_name, t_metadata in p_metadata.aux_params.items(): - dummy_tensor = torch.empty([], dtype=STR_TO_DTYPE[t_metadata.dtype]) - t_view = 
alloc_fn(t_metadata.shape, t_metadata.strides, dummy_tensor, buffer, t_metadata.offset) + dummy_tensor = torch.empty( + [], dtype=STR_TO_DTYPE[t_metadata.dtype]) + t_view = alloc_fn(t_metadata.shape, t_metadata.strides, + dummy_tensor, buffer, t_metadata.offset) aux_params[t_name] = t_view - restored_param = InferenceParameter.initialize(core_param, **aux_params) + restored_param = InferenceParameter.initialize( + core_param, **aux_params) layer_container.direct_injection(p_name, restored_param) for i, layer in enumerate(transformer_containers): diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index b89e95c0d834..1ddf34d3920a 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -49,7 +49,8 @@ def instantiate_attention(attention_config: DSSelfAttentionConfig, """ # Currently, we only have one implementation, so we just return it. - config = ConfigBundle(name="dense_blocked_attention", config=attention_config) + config = ConfigBundle(name="dense_blocked_attention", + config=attention_config) return DSSelfAttentionRegistry.instantiate_config(config) @@ -86,8 +87,17 @@ def instantiate_linear(linear_config: DSLinearConfig, engine_config: RaggedInfer A linear module implementing the given configuration. """ - # Currently, we only have one implementation, so we just return it. - config = ConfigBundle(name="blas_fp_linear", config=linear_config) + quantization_mode = engine_config.quantization.quantization_mode + if quantization_mode is None: + config = ConfigBundle(name="blas_fp_linear", config=linear_config) + else: + # Currently, we only support ``quantized_wf6af16_linear``. + if quantization_mode == "wf6af16": + config = ConfigBundle( + name="quantized_wf6af16_linear", config=linear_config) + else: + raise ValueError( + f"Unsupported quantization mode: {quantization_mode}") return DSLinearRegistry.instantiate_config(config) diff --git a/deepspeed/inference/v2/modules/implementations/linear/__init__.py b/deepspeed/inference/v2/modules/implementations/linear/__init__.py index e76aab71c4cf..11a545846522 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/__init__.py +++ b/deepspeed/inference/v2/modules/implementations/linear/__init__.py @@ -4,3 +4,5 @@ # DeepSpeed Team from .blas_fp_linear import BlasFPLinear +from .quantized_linear import QuantizedWf6Af16Linear + diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py new file mode 100644 index 000000000000..927eaf90b6ef --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -0,0 +1,236 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from qtorch.quant import float_quantize +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ....allocator import empty_from +from ....inference_utils import is_gated +from ....kernels.core_ops import ( + CUDAWf6Af16Linear, + CUDABiasActivation, + CUDAGatedActivation, +) + +from ...interfaces import DSLinearBase, DSLinearRegistry +from ...configs import DSLinearConfig +from ....inference_parameter import InferenceParameter + + +def fp_quantize( + input: torch.FloatTensor, num_bits: int = 6, exp_bits: int = 3, + min_value: torch.FloatTensor = None, max_value: torch.FloatTensor = None, + group_size: int = -1): + """ + Args: + input (`torch.FloatTensor`) + The input tensor to be quantized + num_bits (int, >=4) + Number of bits to use for quantization + exp_bits (int) + Number of exponent bits in the target floating-point format + min_value/max_value (torch.FloatTensor) + Used for static activation quantization + group_size (int) + The quantization block size N; every group of N numbers shares one scaling factor + -1 means use the last dim as the group_size + Returns: + quantized_fake_fp6 + The quantized weights, stored in fp16 format but holding only FP6-representable values + scales + Quantization scales + """ + assert (min_value is None and max_value is None) or ( + min_value is not None and max_value is not None) + + assert input.dtype == torch.float16 + + input = input.to(torch.float32) + if num_bits == 6: + if exp_bits == 3: # this is the default + q_range = 28 + else: + raise NotImplementedError + + man_bits = num_bits - exp_bits - 1 + input_shape = input.shape + + if group_size == -1: + group_size = input_shape[-1] + else: + # Only support per-channel quantization + raise NotImplementedError + num_groups = input.numel() // group_size + input = input.reshape(num_groups, -1) + + if min_value is None: + max_input = torch.amax( + torch.abs(input), dim=-1).view(num_groups, -1) + else: + max_input = torch.max(min_value.abs(), max_value) # .view(-1) + scales = max_input / q_range # q_range + 1 + scales[scales == 0] = 1 # avoid zero scales + scaled_input = input / scales + # torch.cuda.synchronize() # for some reason this is needed to avoid the output being 0 + + quantized_fake_fp6 = float_quantize( + scaled_input, exp_bits, man_bits, rounding="nearest") + quantized_fake_fp6 = quantized_fake_fp6.reshape( + input_shape).contiguous().to(torch.float16) + scales = scales.to(torch.float16) + # Now the dequantized value is quantized_fake_fp6 * scales + + return quantized_fake_fp6, scales + + +@DSLinearRegistry.register_module +class QuantizedWf6Af16Linear(DSLinearBase): + """ + Linear DSModule for the FP6 weight-only quantization kernel, where the weight is FP6 and the activation is FP16. + """ + + @staticmethod + def name(): + return 'quantized_wf6af16_linear' + + @staticmethod + def supports_config(config: DSLinearConfig) -> bool: + if config.input_dtype != config.output_dtype: + return False + + # FP6 data items are packed and stored in a set of fp16 tensors. E.g., 8 fp6 data items are stored + # in 3 fp16 values.
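+ # (8 values * 6 bits = 48 bits = 3 * 16 bits, so the packed FP6 payload keeps an
+ # fp16 storage dtype even though the logical element type is FP6.)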
+ if config.input_dtype != torch.float16: + return False + + if is_gated(config.activation): + try: + _ = CUDAGatedActivation( + config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + else: + try: + _ = CUDABiasActivation( + config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + + return True + + def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self._linear_impl = CUDAWf6Af16Linear() + + if is_gated(config.activation): + # In the FP6 kernel implementation, the MatMul is W * A, where W is the weight and A is activation. + # M is the output channel size. + self.M = self._config.out_channels * 2 + self.K = self._config.in_channels + self._is_gated = True + self._act_fn = CUDAGatedActivation( + config.out_channels, config.output_dtype, config.activation) + self._double_buffer = torch.empty((config.max_tokens, config.out_channels * 2), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + else: + self.M = self._config.out_channels + self.K = self._config.in_channels + self._is_gated = False + self._act_fn = CUDABiasActivation( + config.out_channels, config.output_dtype, config.activation) + + self._output = torch.empty((config.max_tokens, config.out_channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + self.inf_module = InferenceCoreBuilder().load() + self.inf_module.create_handle() + self.preprocess_weight = self.inf_module.preprocess_weight + self.preprocess_scales = self.inf_module.preprocess_scales + + self.quantizer = fp_quantize + + self.DEBUG = True + + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + # It expects that the quantization scales are store in the attribute `scales`. + + if param.ndim == 1: # bias, do nothing + return InferenceParameter.initialize(param) + + if self.DEBUG: + self.weight = param.clone().cpu() + + quantized_fake_fp6, scales = self.quantizer( + param, num_bits=6, exp_bits=3) + + # This is for debugging, will delete after release. + assert (quantized_fake_fp6.dtype == torch.float16) + assert quantized_fake_fp6.shape[0] == self.M + assert scales.numel() == self.M + + weights_2bit, weights_4bit = self.preprocess_weight(quantized_fake_fp6) + + # According to the optimization in Quant-LLM, the scales need to be multiplied by 2^12. 
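+ # (2^12 appears to offset the exponent-bias gap between FP16 (bias 15) and FP6 e3m2
+ # (bias 3), since the in-kernel cast does not re-bias the exponent; see Utils_ParallelDequant.cuh.)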
+ scales = scales * (2 ** 12) + scales = self.preprocess_scales(scales, self.M, self.K) + + return InferenceParameter.initialize(weights_4bit, weights_2bit=weights_2bit, scales=scales) + + def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: + weights_4bit = w + weights_2bit = w.weights_2bit + scales = w.scales + output = empty_from( + self._output, (hidden_states.shape[0], self._config.out_channels)) + # N = hidden_states.shape[0] + if self._is_gated: + staging_output = empty_from( + self._double_buffer, (hidden_states.shape[0], self.M)) + self._linear_impl(staging_output, hidden_states, weights_4bit, + weights_2bit, scales, self.M, hidden_states.shape[0], self.K) + self._act_fn(output, staging_output, b) + else: + self._linear_impl(output, hidden_states, weights_4bit, + weights_2bit, scales, self.M, hidden_states.shape[0], self.K) + self._act_fn(output, b) + + if self.DEBUG: + orig_device = self.weight.device + self.weight = self.weight.to(output.device) + ground_truth = torch.nn.functional.linear( + hidden_states, self.weight, b) + self.weight = self.weight.to(orig_device) + shape = (hidden_states.shape[0], self.M, self.K) + if self._is_gated: + ismatch = torch.allclose( + ground_truth, staging_output, rtol=1e-3) + print(f"Linear shape: {shape}:\n\tIs correct: {ismatch}. " + f"Max diff: {torch.max(torch.abs(ground_truth - staging_output))}") + else: + ismatch = torch.allclose(ground_truth, output, rtol=1e-3) + print(f"Linear shape: {shape}:\n\tIs correct: {ismatch}. " + f"Max diff: {torch.max(torch.abs(ground_truth - output))}") + + return output + + @property + def output(self) -> torch.Tensor: + """ + Return the padded, pre-allocated output Tensor. + """ + return self._output diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index 8073b63ad16b..c1ef43f8170f 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -67,6 +67,10 @@ def sources(self): "inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu", "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp", "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu", + "inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu", + "inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp", + "inference/v2/kernels/core_ops/cuda_linear/Launcher.cu", + "inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp", ] prefix = self.get_prefix() @@ -83,6 +87,7 @@ def include_paths(self): 'inference/v2/kernels/core_ops/cuda_layer_norm', 'inference/v2/kernels/core_ops/cuda_rms_norm', 'inference/v2/kernels/core_ops/gated_activations', + 'inference/v2/kernels/core_ops/cuda_linear', 'inference/v2/kernels/includes', ] From 91bb4d794adc356a323dd8e746846d797fedb59a Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Fri, 26 Jan 2024 06:19:53 +0000 Subject: [PATCH 02/31] Update CUDA kernels and clean codes. 
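For orientation before the cleanups in this patch, the end-to-end flow of the FP6 path can be sketched as follows. This is an illustrative sketch only: the names, module paths and argument order are taken from quantized_linear.py as introduced in patch 01 (they may change in later patches of this series), and it is not a separately documented public API.

    import torch
    from deepspeed.ops.op_builder import InferenceCoreBuilder
    from deepspeed.inference.v2.kernels.core_ops import CUDAWf6Af16Linear
    from deepspeed.inference.v2.modules.implementations.linear.quantized_linear import fp_quantize

    M, K, tokens = 4096, 4096, 16  # output channels, input channels, tokens in the batch
    weight = torch.randn(M, K, dtype=torch.float16, device="cuda")

    # 1. Fake-quantize the FP16 weight to FP6 and obtain per-output-channel scales.
    fake_fp6, scales = fp_quantize(weight, num_bits=6, exp_bits=3)

    # 2. Ahead-of-time layout transform into the 2-bit / 4-bit fragments the kernel expects.
    inf_module = InferenceCoreBuilder().load()
    inf_module.create_handle()
    weights_2bit, weights_4bit = inf_module.preprocess_weight(fake_fp6)

    # 3. WF6AF16 GEMM: FP16 activations, FP6 weights dequantized on the fly in registers.
    #    (Patch 01 additionally multiplies the scales by 2**12 and runs preprocess_scales here.)
    hidden = torch.randn(tokens, K, dtype=torch.float16, device="cuda")
    output = torch.empty(tokens, M, dtype=torch.float16, device="cuda")
    kernel = CUDAWf6Af16Linear()
    kernel(output, hidden, weights_4bit, weights_2bit, scales, M, tokens, K)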
--- .../v2/kernels/core_ops/core_ops.cpp | 1 - .../cuda_linear/GenMatrix_QuantLLM.cpp | 300 ------------------ .../core_ops/cuda_linear/GenMatrix_QuantLLM.h | 204 ++++++++++-- .../core_ops/cuda_linear/Kernel_QuantGEMM.cuh | 53 +--- .../kernels/core_ops/cuda_linear/Launcher.cu | 122 ++----- .../core_ops/cuda_linear/cuda_linear.py | 6 +- .../cuda_linear/cuda_linear_kernels.cpp | 138 +++++--- .../cuda_linear/cuda_linear_kernels.cu | 58 ---- .../cuda_linear/cuda_linear_kernels.h | 4 +- .../core_ops/cuda_linear/quant_gemm_api.cuh | 25 -- .../core_ops/cuda_linear/utils/Utils_GMem.cuh | 38 +-- .../utils/Utils_ParallelDequant.cuh | 122 ++++--- .../linear/quantized_linear.py | 18 +- op_builder/inference_core_ops.py | 2 - 14 files changed, 385 insertions(+), 706 deletions(-) delete mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp delete mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu delete mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/quant_gemm_api.cuh diff --git a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp index fccb248816fc..2397b0694696 100644 --- a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp @@ -39,5 +39,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("cuda_wf6af16_linear", &cuda_wf6af16_linear, "DeepSpeed Wf6Af16 linear in CUDA"); m.def( "preprocess_weight", &preprocess_weight, "preprocess the FP16 weight to be 2bit and 4 bit"); - m.def("preprocess_scales", &preprocess_scales, "preprocess the FP16 scales"); } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp deleted file mode 100644 index ba0592de08e4..000000000000 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp +++ /dev/null @@ -1,300 +0,0 @@ -#include "GenMatrix_QuantLLM.h" -#include -#include -#include -#include - -using namespace std; - -void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], - unsigned char* FP6_Array) // padding 0 to the lowerest bit location -{ - Padded_FP6[0] = FP6_Array[0] & 0xfc; - Padded_FP6[1] = (FP6_Array[0] << 6) | ((FP6_Array[1] >> 2) & 0xfc); - Padded_FP6[2] = (FP6_Array[1] << 4) | ((FP6_Array[2] >> 4) & 0xfc); - Padded_FP6[3] = FP6_Array[2] << 2; - Padded_FP6[4] = FP6_Array[3] & 0xfc; - Padded_FP6[5] = (FP6_Array[3] << 6) | ((FP6_Array[4] >> 2) & 0xfc); - Padded_FP6[6] = (FP6_Array[4] << 4) | ((FP6_Array[5] >> 4) & 0xfc); - Padded_FP6[7] = FP6_Array[5] << 2; -} - -unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, - unsigned char B2, - unsigned char B3, - unsigned char B4) -{ - unsigned char out; - out = (B1 & 0xc0) | ((B2 & 0xc0) >> 2) | ((B3 & 0xc0) >> 4) | ((B4 & 0xc0) >> 6); - return out; -} - -unsigned char Extract_4_Bits_From_2_PaddedFP6( - unsigned char B1, - unsigned char - B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); -{ - unsigned char out; - out = ((B1 << 2) & 0xf0) | ((B2 >> 2) & 0x0f); - return out; -} - -// dealing with 4 1*8 blocks of FP6 -void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], - vector Seg_4bit[], - unsigned char* PTR_1, - unsigned char* PTR_2, - unsigned char* PTR_3, - unsigned char* PTR_4) -{ - unsigned char Padded_8_FP8[4][8]; - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); - 
Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); - Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); - // - unsigned char Seg1_Byte1_T[4]; - unsigned char Seg1_Byte2_T[4]; - unsigned char Seg2_Byte1_T[4]; - unsigned char Seg2_Byte2_T[4]; - unsigned char Seg2_Byte3_T[4]; - unsigned char Seg2_Byte4_T[4]; - for (int t = 0; t < 4; t++) { - Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0 + t * 2], - Padded_8_FP8[0][1 + t * 2], - Padded_8_FP8[1][0 + t * 2], - Padded_8_FP8[1][1 + t * 2]); - Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0 + t * 2], - Padded_8_FP8[2][1 + t * 2], - Padded_8_FP8[3][0 + t * 2], - Padded_8_FP8[3][1 + t * 2]); - Seg2_Byte1_T[t] = - Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0 + t * 2], Padded_8_FP8[0][1 + t * 2]); - Seg2_Byte2_T[t] = - Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0 + t * 2], Padded_8_FP8[1][1 + t * 2]); - Seg2_Byte3_T[t] = - Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0 + t * 2], Padded_8_FP8[2][1 + t * 2]); - Seg2_Byte4_T[t] = - Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0 + t * 2], Padded_8_FP8[3][1 + t * 2]); - } - // - for (int t = 0; t < 4; t++) { - Seg_2bit[t].push_back(Seg1_Byte1_T[t]); - Seg_2bit[t].push_back(Seg1_Byte2_T[t]); - Seg_4bit[t].push_back(Seg2_Byte1_T[t]); - Seg_4bit[t].push_back(Seg2_Byte2_T[t]); - Seg_4bit[t].push_back(Seg2_Byte3_T[t]); - Seg_4bit[t].push_back(Seg2_Byte4_T[t]); - } - return; -} - -void BitInterleaving_2bit(unsigned char* PTR_4Bytes) -{ - unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); - unsigned int input = *PTR_UINT; - // - // int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for - // bit-interleaving in QuantLLM - int order_2bit[16] = { - 2, 6, 10, 14, 4, 8, 12, 16, 1, 5, 9, 13, 3, 7, 11, 15}; // pre-defined order for - // bit-interleaving in QuantLLM - unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. - for (int i = 0; i < 16; i++) Frags_2bit[i] = (input << 2 * (order_2bit[i] - 1)) & 0xc0000000; - // - unsigned int output = 0x00000000; - for (int i = 0; i < 16; i++) output |= (Frags_2bit[i] >> (i * 2)); - // - *PTR_UINT = output; -} - -void BitInterleaving_4bit(unsigned char* PTR_4Bytes) -{ - unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); - unsigned int input = *PTR_UINT; - // - // int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in QuantLLM - int order_4bit[8] = { - 2, 6, 4, 8, 1, 5, 3, 7}; // pre-defined order for bit-interleaving in QuantLLM - unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. - for (int i = 0; i < 8; i++) Frags_4bit[i] = (input << 4 * (order_4bit[i] - 1)) & 0xf0000000; - // - unsigned int output = 0x00000000; - for (int i = 0; i < 8; i++) output |= (Frags_4bit[i] >> (i * 4)); - // - *PTR_UINT = output; -} - -short GetShort(unsigned char* Scale_In, size_t row, size_t col, size_t BytesPerRow) -{ - unsigned char* PTR_8bit = Scale_In; - PTR_8bit += row * BytesPerRow + col * 2; - short* PTR_16bit = reinterpret_cast(PTR_8bit); - return (*PTR_16bit); -} - -/* - * Inputs: - * (1) unsigned char Weight_6bit [M*K*6/8] - * Outputs: - * (1) unsigned char Weight_2bit [M*K*2/8] - * (2) unsigned char Weight_4bit [M*K*4/8] - * - * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. 
- * 8 FP6 = 6 Bytes - * 8 FP4 = 4 Bytes - * 8 FP2 = 2 Bytes - */ -void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, - unsigned char* Weight_2bit, - unsigned char* Weight_4bit, - size_t M, - size_t K) -{ - assert(M % 64 == 0); - assert(K % 64 == 0); - // - vector A_Segment_2bit[32]; - vector A_Segment_4bit[32]; - // - size_t BytesPerRow = K * 6 / 8; -// Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. -#pragma omp parallel for - for (size_t i = 0; i < M / 64; i++) // - { - for (size_t j = 0; j < K / 16; j++) { - for (size_t k = 0; k < 64 / 16; k++) { - size_t row = i * 64 + k * 16; - size_t col = j * 16; - unsigned char* StartPTR_1 = Weight_6bit + row * BytesPerRow + col * 6 / 8; - unsigned char* StartPTR_2 = StartPTR_1 + 8 * BytesPerRow; - unsigned char* StartPTR_3 = StartPTR_1 + 8 * 6 / 8; - unsigned char* StartPTR_4 = StartPTR_2 + 8 * 6 / 8; - // Dealing with each 16*16 blocks then... - for (int l = 0; l < 8; l++) - Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l * 4], - &A_Segment_4bit[l * 4], - StartPTR_1 + l * BytesPerRow, - StartPTR_2 + l * BytesPerRow, - StartPTR_3 + l * BytesPerRow, - StartPTR_4 + l * BytesPerRow); - } - } - } - // Verifying the length of 2_bit segments and 4_bit segments - size_t BytesPerThread_2bit = M * K * 2 / 8 / 32; - size_t BytesPerThread_4bit = M * K * 4 / 8 / 32; - for (int i = 0; i < 32; i++) { - assert(A_Segment_2bit[i].size() == BytesPerThread_2bit); - assert(A_Segment_4bit[i].size() == BytesPerThread_4bit); - } -// Pass-2: Optimizing coleasced global memory access -#pragma omp parallel for collapse(2) - for (int t = 0; t < 32; t++) { - for (int b = 0; b < 4; b++) { - for (size_t i = 0; i < BytesPerThread_2bit / 4; i++) - Weight_2bit[i * 128 + t * 4 + (3 - b)] = - A_Segment_2bit[t] - [i * 4 + b]; // why (3-b): special byte order within a register - for (size_t i = 0; i < BytesPerThread_4bit / 4; i++) - Weight_4bit[i * 128 + t * 4 + (3 - b)] = - A_Segment_4bit[t] - [i * 4 + b]; // why (3-b): special byte order within a register - } - } - - // Pass-3: Bit-level interleaving - for (size_t i = 0; i < BytesPerThread_2bit * 32 / 4; i++) - BitInterleaving_2bit(Weight_2bit + 4 * i); - for (size_t i = 0; i < BytesPerThread_4bit * 32 / 4; i++) - BitInterleaving_4bit(Weight_4bit + 4 * i); -} - -/* - * Inputs: - * (1) unsigned char Scale_In[M*K/GroupSize*16/8] - * Outputs: - * (1) unsigned char Scale_Out[M*K/GroupSize*16/8] - */ -void GenMatrix_Scale_FP16(unsigned char* Scale_Out, - unsigned char* Scale_In, - size_t M, - size_t K, - int GroupSize) -{ - short* Out_PTR = reinterpret_cast(Scale_Out); - // - assert(K % GroupSize == 0); - size_t BytesPerRow = K / GroupSize * 2; - // - for (size_t i = 0; i < M / 64; i++) - for (size_t j = 0; j < K / GroupSize; j++) - for (int l = 0; l < 8; l++) { - *Out_PTR = GetShort(Scale_In, 0 + 64 * i + l, j, BytesPerRow); - Out_PTR += 1; - *Out_PTR = GetShort(Scale_In, 8 + 64 * i + l, j, BytesPerRow); - Out_PTR += 1; - *Out_PTR = GetShort(Scale_In, 16 + 64 * i + l, j, BytesPerRow); - Out_PTR += 1; - *Out_PTR = GetShort(Scale_In, 24 + 64 * i + l, j, BytesPerRow); - Out_PTR += 1; - *Out_PTR = GetShort(Scale_In, 32 + 64 * i + l, j, BytesPerRow); - Out_PTR += 1; - *Out_PTR = GetShort(Scale_In, 40 + 64 * i + l, j, BytesPerRow); - Out_PTR += 1; - *Out_PTR = GetShort(Scale_In, 48 + 64 * i + l, j, BytesPerRow); - Out_PTR += 1; - *Out_PTR = GetShort(Scale_In, 56 + 64 * i + l, j, BytesPerRow); - Out_PTR += 1; - } - return; -} - -void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) -{ - constexpr int exponent_bits_fp6 = 3; - constexpr 
int mantissa_bits_fp6 = 2; - // Constants for FP16 - constexpr int exponent_bits_fp16 = 5; - constexpr int mantissa_bits_fp16 = 10; - constexpr int exp_bias_fp16 = (1 << (exponent_bits_fp16 - 1)) - 1; - - uint8_t fp6_temp[4]; - - for (int i = 0; i < 4; ++i) { - int sign = (FP16x4[i] >> 15); - int exp = (FP16x4[i] >> mantissa_bits_fp16) & - ((1 << exponent_bits_fp16) - 1); // Extracting exponent - int mant = FP16x4[i] & ((1 << mantissa_bits_fp16) - 1); // Extracting mantissa - - int new_exp = exp - exp_bias_fp16; - int new_mant = mant >> (mantissa_bits_fp16 - mantissa_bits_fp6); - - fp6_temp[i] = (sign << (exponent_bits_fp6 + mantissa_bits_fp6)) | - (new_exp << mantissa_bits_fp6) | new_mant; - } - // Pack the values - FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); - FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); - FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; -} - -/* - * Inputs: - * (1) uint16_t Weight_16bit[M*K] - * Outputs: - * (1) unsigned char Weight_6bit[M*K*6/8] - */ -void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t M, size_t K) -{ -#pragma omp parallel for - for (auto m = 0; m < M; m++) { - uint8_t* ptr_6bit = Weight_6bit + m * K * 6 / 8; - uint16_t* ptr_16bit = Weight_16bit + m * K; - for (auto k = 0; k < K; k += 4) { - Cast_FP16_FP6(ptr_16bit, ptr_6bit); - ptr_16bit += 4; - ptr_6bit += 3; - } - } -} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h index 77810eb39ba5..7ab71f957627 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h @@ -1,6 +1,127 @@ -#pragma once -#include -#include +#include +#include +#include + +using namespace std; + +void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], + unsigned char* FP6_Array) // padding 0 to the lowerest bit location +{ + Padded_FP6[0] = FP6_Array[0] & 0xfc; + Padded_FP6[1] = (FP6_Array[0] << 6) | ((FP6_Array[1] >> 2) & 0xfc); + Padded_FP6[2] = (FP6_Array[1] << 4) | ((FP6_Array[2] >> 4) & 0xfc); + Padded_FP6[3] = FP6_Array[2] << 2; + Padded_FP6[4] = FP6_Array[3] & 0xfc; + Padded_FP6[5] = (FP6_Array[3] << 6) | ((FP6_Array[4] >> 2) & 0xfc); + Padded_FP6[6] = (FP6_Array[4] << 4) | ((FP6_Array[5] >> 4) & 0xfc); + Padded_FP6[7] = FP6_Array[5] << 2; +} + +unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, + unsigned char B2, + unsigned char B3, + unsigned char B4) +{ + unsigned char out; + out = (B1 & 0xc0) | ((B2 & 0xc0) >> 2) | ((B3 & 0xc0) >> 4) | ((B4 & 0xc0) >> 6); + return out; +} + +unsigned char Extract_4_Bits_From_2_PaddedFP6( + unsigned char B1, + unsigned char + B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); +{ + unsigned char out; + out = ((B1 << 2) & 0xf0) | ((B2 >> 2) & 0x0f); + return out; +} + +// dealing with 4 1*8 blocks of FP6 +void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], + vector Seg_4bit[], + unsigned char* PTR_1, + unsigned char* PTR_2, + unsigned char* PTR_3, + unsigned char* PTR_4) +{ + unsigned char Padded_8_FP8[4][8]; + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); + // + unsigned char Seg1_Byte1_T[4]; + unsigned char Seg1_Byte2_T[4]; + unsigned char Seg2_Byte1_T[4]; + unsigned char Seg2_Byte2_T[4]; + unsigned char 
Seg2_Byte3_T[4]; + unsigned char Seg2_Byte4_T[4]; + for (int t = 0; t < 4; t++) { + Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0 + t * 2], + Padded_8_FP8[0][1 + t * 2], + Padded_8_FP8[1][0 + t * 2], + Padded_8_FP8[1][1 + t * 2]); + Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0 + t * 2], + Padded_8_FP8[2][1 + t * 2], + Padded_8_FP8[3][0 + t * 2], + Padded_8_FP8[3][1 + t * 2]); + Seg2_Byte1_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0 + t * 2], Padded_8_FP8[0][1 + t * 2]); + Seg2_Byte2_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0 + t * 2], Padded_8_FP8[1][1 + t * 2]); + Seg2_Byte3_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0 + t * 2], Padded_8_FP8[2][1 + t * 2]); + Seg2_Byte4_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0 + t * 2], Padded_8_FP8[3][1 + t * 2]); + } + // + for (int t = 0; t < 4; t++) { + Seg_2bit[t].push_back(Seg1_Byte1_T[t]); + Seg_2bit[t].push_back(Seg1_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte1_T[t]); + Seg_4bit[t].push_back(Seg2_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte3_T[t]); + Seg_4bit[t].push_back(Seg2_Byte4_T[t]); + } + return; +} + +void BitInterleaving_2bit(unsigned char* PTR_4Bytes) +{ + unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + // int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for + // bit-interleaving in QuantLLM + int order_2bit[16] = { + 2, 6, 10, 14, 4, 8, 12, 16, 1, 5, 9, 13, 3, 7, 11, 15}; // pre-defined order for + // bit-interleaving in QuantLLM + unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. + for (int i = 0; i < 16; i++) Frags_2bit[i] = (input << 2 * (order_2bit[i] - 1)) & 0xc0000000; + // + unsigned int output = 0x00000000; + for (int i = 0; i < 16; i++) output |= (Frags_2bit[i] >> (i * 2)); + // + *PTR_UINT = output; +} + +void BitInterleaving_4bit(unsigned char* PTR_4Bytes) +{ + unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + // int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in QuantLLM + int order_4bit[8] = { + 2, 6, 4, 8, 1, 5, 3, 7}; // pre-defined order for bit-interleaving in QuantLLM + unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. + for (int i = 0; i < 8; i++) Frags_4bit[i] = (input << 4 * (order_4bit[i] - 1)) & 0xf0000000; + // + unsigned int output = 0x00000000; + for (int i = 0; i < 8; i++) output |= (Frags_4bit[i] >> (i * 4)); + // + *PTR_UINT = output; +} /* * Inputs: @@ -18,24 +139,59 @@ void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, unsigned char* Weight_2bit, unsigned char* Weight_4bit, size_t M, - size_t K); - -/* - * Inputs: - * (1) unsigned char Scale_In[M*K/GroupSize*16/8] - * Outputs: - * (1) unsigned char Scale_Out[M*K/GroupSize*16/8] - */ -void GenMatrix_Scale_FP16(unsigned char* Scale_Out, - unsigned char* Scale_In, - size_t M, - size_t K, - int GroupSize); - -/* - * Inputs: - * (1) uint16_t Weight_16bit[M*K] - * Outputs: - * (1) unsigned char Weight_6bit[M*K*6/8] - */ -void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, unsigned char* Weight_6bit, size_t M, size_t K); + size_t K) +{ + assert(M % 64 == 0); + assert(K % 64 == 0); + // + vector A_Segment_2bit[32]; + vector A_Segment_4bit[32]; + // + size_t BytesPerRow = K * 6 / 8; + // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. 
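+ // (The 2+4 bit split gives each weight fragment a power-of-two width, so every thread's
+ // segment stays 128-bit aligned for the uint4 loads in CopyFromGlobalToReg_ByPassL1.)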
+ for (size_t i = 0; i < M / 64; i++) // + { + for (size_t j = 0; j < K / 16; j++) { + for (size_t k = 0; k < 64 / 16; k++) { + size_t row = i * 64 + k * 16; + size_t col = j * 16; + unsigned char* StartPTR_1 = Weight_6bit + row * BytesPerRow + col * 6 / 8; + unsigned char* StartPTR_2 = StartPTR_1 + 8 * BytesPerRow; + unsigned char* StartPTR_3 = StartPTR_1 + 8 * 6 / 8; + unsigned char* StartPTR_4 = StartPTR_2 + 8 * 6 / 8; + // Dealing with each 16*16 blocks then... + for (int l = 0; l < 8; l++) + Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l * 4], + &A_Segment_4bit[l * 4], + StartPTR_1 + l * BytesPerRow, + StartPTR_2 + l * BytesPerRow, + StartPTR_3 + l * BytesPerRow, + StartPTR_4 + l * BytesPerRow); + } + } + } + // Verifying the length of 2_bit segments and 4_bit segments + size_t BytesPerThread_2bit = M * K * 2 / 8 / 32; + size_t BytesPerThread_4bit = M * K * 4 / 8 / 32; + for (int i = 0; i < 32; i++) { + assert(A_Segment_2bit[i].size() == BytesPerThread_2bit); + assert(A_Segment_4bit[i].size() == BytesPerThread_4bit); + } + // Pass-2: Optimizing coleasced global memory access + for (size_t i = 0; i < BytesPerThread_2bit / 4; i++) + for (int t = 0; t < 32; t++) + for (int b = 0; b < 4; b++) + Weight_2bit[i * 128 + t * 4 + (3 - b)] = + A_Segment_2bit[t] + [i * 4 + b]; // why (3-b): special byte order within a register + for (size_t i = 0; i < BytesPerThread_4bit / 4; i++) + for (int t = 0; t < 32; t++) + for (int b = 0; b < 4; b++) + Weight_4bit[i * 128 + t * 4 + (3 - b)] = + A_Segment_4bit[t][i * 4 + b]; // why (3-b):special byte order within a register + // Pass-3: Bit-level interleaving + for (size_t i = 0; i < BytesPerThread_2bit * 32 / 4; i++) + BitInterleaving_2bit(Weight_2bit + 4 * i); + for (size_t i = 0; i < BytesPerThread_4bit * 32 / 4; i++) + BitInterleaving_4bit(Weight_4bit + 4 * i); +} \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh index 46cc0bc5a002..03d27433bc2b 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh @@ -14,12 +14,6 @@ #include "utils/Utils_GMem.cuh" #include "utils/Utils_Core.cuh" -__device__ __forceinline__ void ExchangePTRs(void** PTR1, void** PTR2) { - void* tmp_PTR = *PTR1; - *PTR1 = *PTR2; - *PTR2 = tmp_PTR; -} - /* * C = A*B * A: row major with ahead-of-time layout transformation, FP6 @@ -27,7 +21,7 @@ __device__ __forceinline__ void ExchangePTRs(void** PTR1, void** PTR2) { * C: col major, FP16 */ template -__global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, const half* Scales, const int QUANT_GROUP_SIZE_DIVIDED_BY_64, +__global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, const half* Scales, const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, @@ -40,13 +34,14 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, co #endif extern __shared__ __align__(128) half smem[]; // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + (SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE)/2 ); // Dynamic shared memory for FP16 B tiles - __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS*2]; // static shared memory for quantization scales, 64 row per warp * 4 warps * 2 double buffer = 1 KB + __shared__ half 
QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes // Thread Block Mapping, considering SplitK const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) const size_t Tile_Start_M = y * TilingConfig::TILE_M; const size_t Tile_Start_N = x * TilingConfig::TILE_N; + const size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? (N_Global-Tile_Start_N) : TilingConfig::TILE_N; const size_t NumBlock_K = K_Global/TilingConfig::TILE_K; const size_t AverageNumBlock_K = NumBlock_K/Split_K; const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; @@ -85,18 +80,13 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, co WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; } // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// - #ifdef DEBUG_MODE - assert(NumBlock_K%QUANT_GROUP_SIZE_DIVIDED_BY_64==0); - #endif - const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS)* (NumBlock_K/QUANT_GROUP_SIZE_DIVIDED_BY_64) * 64; - const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * (NumBlock_K/QUANT_GROUP_SIZE_DIVIDED_BY_64) * 64; - size_t UnitID_K = WARP_Start_UnitID_K; - size_t QuantGroup_K = UnitID_K / QUANT_GROUP_SIZE_DIVIDED_BY_64; - CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales + QuantGroup_K*64); + const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global); + CopyFromGlobalToShared (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); BTile_GPTR += TilingConfig::TILE_K; } // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// @@ -115,13 +105,8 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, co __syncthreads(); ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // For Quantization Scales - half *read_WARP_SPTR_Scales = QuantScales + WARP_i*64; - half *write_WARP_SPTR_Scales = read_WARP_SPTR_Scales + 64*TilingConfig::BLOCK_WARPS; - // 4 Registers per thread for Quantization Scales, Preparing Scales for the first loop even it is not the start of a new quant group (SplitK) /////////// - uint32_t Scales_RPTR[4]; - ExtractFromSharedToReg_Scales(Scales_RPTR, read_WARP_SPTR_Scales); - + uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales + ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i*64); #ifdef PIPELINE_LEVEL_SMEM // Initializing the Software Pipeline: writing registers. 
//////////////////////////////////////////////////////////////////////////////////////////////// initialize_mma_slice(a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); @@ -147,25 +132,11 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, co half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; // bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; - /* - // Optionally Updating QuantScale for A Tile - UnitID_K = WARP_Start_UnitID_K + tile_id_k; - QuantGroup_K = UnitID_K / QUANT_GROUP_SIZE_DIVIDED_BY_64; - //bool SwitchQuantGroup = (UnitID_K % QUANT_GROUP_SIZE_DIVIDED_BY_64 == (QUANT_GROUP_SIZE_DIVIDED_BY_64-1)); - bool SwitchQuantGroup = false; - if(SwitchQuantGroup) CopyFromGlobalToShared_Scales(write_WARP_SPTR_Scales, WARP_StartGPTR_A_Scales + (QuantGroup_K+1)*64, GlobalCopy); // If the next loop need the new scales, load the scales from global to shared. - //bool IsNewQuantGroup = (UnitID_K % QUANT_GROUP_SIZE_DIVIDED_BY_64 == 0); - bool IsNewQuantGroup = false; - if(IsNewQuantGroup) ExtractFromSharedToReg_Scales(Scales_RPTR, read_WARP_SPTR_Scales); // If the curent loop need the new scales, load the scales from shared to registers. - // Exchanging the PTRs for double buffers - if(SwitchQuantGroup) ExchangePTRs((void**)&read_WARP_SPTR_Scales, (void**)&write_WARP_SPTR_Scales); // SPTRs for Scales - */ - // Copying A tile from Global to Register, Bypassing L1, using double-buffer CopyFromGlobalToShared_A(write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); CopyFromGlobalToShared_A(write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); // copying B tile from GlobalMemory to SharedMemory - CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, GlobalCopy); + CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); cp_async_group_commit(); #ifdef PIPELINE_LEVEL_SMEM core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs @@ -199,8 +170,6 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, co __syncthreads(); // Now that shared memory contains all the D tiles, stream them to global memory. OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global; - - size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? 
(N_Global-Tile_Start_N) : TilingConfig::TILE_N; for(size_t i=warpId; i::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; } -} +} \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu index 6f9a8daa499b..0e8636a81452 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu @@ -10,7 +10,6 @@ static void Kernel_QuantGEMM_Ex(cudaStream_t stream, const uint4* Weight1, const uint4* Weight2, const half* Scales, - const int QUANT_GROUP_SIZE_DIVIDED_BY_64, const half* B, OutputDataType* C, const size_t M_Global, @@ -21,19 +20,11 @@ static void Kernel_QuantGEMM_Ex(cudaStream_t stream, #ifdef DEBUG_MODE printf("\n"); printf("Launcher.cu->Kernel_QuantGEMM_Ex():\n"); - printf("M: %d, N: %d, K: %d, SplitK: %d, QUANT_GROUP_SIZE_DIVIDED_BY_64: %d\n", - M_Global, - N_Global, - K_Global, - Split_K, - QUANT_GROUP_SIZE_DIVIDED_BY_64); + printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); - // assert(N_Global % TilingConfig::TILE_N == 0); - // assert(M_Global*Split_K % TilingConfig::TILE_M == 0); - // assert(K_Global % TilingConfig::TILE_K == 0); #endif static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, @@ -59,17 +50,8 @@ static void Kernel_QuantGEMM_Ex(cudaStream_t stream, SHMEM_SZ); printf("\n"); #endif - QUANT_GEMM_Kernel - <<>>(Weight1, - Weight2, - Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, - B, - C, - M_Global, - N_Global, - K_Global, - Split_K); + QUANT_GEMM_Kernel<<>>( + Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); } /* @@ -83,7 +65,6 @@ cudaError_t QuantGEMM_API( const uint4* Weight1, const uint4* Weight2, const half* Scales, - const int QUANT_GROUP_SIZE_DIVIDED_BY_64, const half* B, half* C, const size_t M_Global, @@ -92,10 +73,9 @@ cudaError_t QuantGEMM_API( float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches int Split_K) { - if (N_Global <= 0) { - printf("QuantLLM_API Error: Unsupported N dimension %ld!\n", N_Global); - return cudaErrorUnknown; - } + // assert(M_Global % TilingConfig::TILE_M == 0); + // assert(K_Global % TilingConfig::TILE_K == 0); + assert(N_Global > 0); // Work around to support more N shapes: Pretending that the input is 2^n size_t N_PowerOf2; @@ -110,86 +90,32 @@ cudaError_t QuantGEMM_API( if (Split_K == 1) { switch (N_PowerOf2) { case 8: - Kernel_QuantGEMM_Ex, half>(stream, - Weight1, - Weight2, - Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, - B, - C, - M_Global, - N_Global, - K_Global, - Split_K); + Kernel_QuantGEMM_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 16: - Kernel_QuantGEMM_Ex, half>(stream, - Weight1, - Weight2, - Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, - B, - C, - M_Global, - N_Global, - K_Global, - Split_K); + Kernel_QuantGEMM_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 32: - Kernel_QuantGEMM_Ex, half>(stream, - Weight1, - Weight2, - Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, - B, - C, - M_Global, - N_Global, - K_Global, - Split_K); + Kernel_QuantGEMM_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, 
N_Global, K_Global, Split_K); break; case 64: - Kernel_QuantGEMM_Ex, half>(stream, - Weight1, - Weight2, - Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, - B, - C, - M_Global, - N_Global, - K_Global, - Split_K); + Kernel_QuantGEMM_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; case 128: - Kernel_QuantGEMM_Ex, half>(stream, - Weight1, - Weight2, - Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, - B, - C, - M_Global, - N_Global, - K_Global, - Split_K); + Kernel_QuantGEMM_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { - printf("QuantLLM_API Error: Unsupported N dimension %ld!\n", N_PowerOf2); + printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } - Kernel_QuantGEMM_Ex, half>(stream, - Weight1, - Weight2, - Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, - B, - C, - M_Global, - N_Global, - K_Global, - Split_K); + Kernel_QuantGEMM_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } } else { @@ -199,7 +125,6 @@ cudaError_t QuantGEMM_API( Weight1, Weight2, Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, B, Reduction_Workspace, M_Global, @@ -212,7 +137,6 @@ cudaError_t QuantGEMM_API( Weight1, Weight2, Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, B, Reduction_Workspace, M_Global, @@ -225,7 +149,6 @@ cudaError_t QuantGEMM_API( Weight1, Weight2, Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, B, Reduction_Workspace, M_Global, @@ -238,7 +161,6 @@ cudaError_t QuantGEMM_API( Weight1, Weight2, Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, B, Reduction_Workspace, M_Global, @@ -251,7 +173,6 @@ cudaError_t QuantGEMM_API( Weight1, Weight2, Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, B, Reduction_Workspace, M_Global, @@ -261,14 +182,13 @@ cudaError_t QuantGEMM_API( break; default: if (N_PowerOf2 % 128 != 0) { - printf("QuantLLM_API Error: Unsupported N dimension %ld!\n", N_PowerOf2); + printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } Kernel_QuantGEMM_Ex, float>(stream, Weight1, Weight2, Scales, - QUANT_GROUP_SIZE_DIVIDED_BY_64, B, Reduction_Workspace, M_Global, @@ -284,4 +204,4 @@ cudaError_t QuantGEMM_API( C, Reduction_Workspace, M_Global, N_Global, Split_K); } return cudaGetLastError(); -} +} \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py index 9431ddc96f41..aa973daf81cb 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -24,7 +24,7 @@ def __init__(self): self.inf_module.create_handle() self.kernel = self.inf_module.cuda_wf6af16_linear - def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_4bit: torch.Tensor, weights_2bit: torch.Tensor, scale: torch.Tensor, M, N, K) -> torch.Tensor: + def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2bit: torch.Tensor, weights_4bit: torch.Tensor, scale: torch.Tensor, M, N, K) -> torch.Tensor: """ Matmul kernel as implemented via CUDA directly. The input must be 2D or larger. 
If n-dimensional, the leading dimensions are folded into each other: @@ -54,8 +54,8 @@ def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_4b split_k = split_k_dict[M] workspace = self.get_workspace( M, N, K, split_k, torch.float, hidden_states.device) - self.kernel(output, hidden_states, weights_4bit, - weights_2bit, scale, workspace, M, N, K, split_k) + self.kernel(output, hidden_states, weights_2bit, + weights_4bit, scale, workspace, M, N, K, split_k) def get_workspace(self, M: int, N: int, K: int, split_k: int, dtype, device) -> torch.Tensor: """ diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index eb4ad0b9318b..42ae7342d844 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -3,25 +3,84 @@ // DeepSpeed Team +#include + #include "cuda_linear_kernels.h" -#include "GenMatrix_QuantLLM.h" - -void Launch_QuantGEMM(torch::Tensor output, - torch::Tensor Weight1, // 2bit - torch::Tensor Weight2, // 4bit - torch::Tensor B, - torch::Tensor Scales, - const int M_Global, - const int N_Global, - const int K_Global, - const int Split_K, - torch::Tensor workspace); + +namespace { + +// Utils to prepack FP16 weights into continuous FP6 values. +// TODO: debug it. + +void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) +{ + constexpr int exponent_bits_fp6 = 3; + constexpr int mantissa_bits_fp6 = 2; + // Constants for FP16 + constexpr int exponent_bits_fp16 = 5; + constexpr int mantissa_bits_fp16 = 10; + constexpr int exp_bias_fp16 = (1 << (exponent_bits_fp16 - 1)) - 1; + + uint8_t fp6_temp[4]; + + for (int i = 0; i < 4; ++i) { + int sign = (FP16x4[i] >> 15); + int exp = (FP16x4[i] >> mantissa_bits_fp16) & + ((1 << exponent_bits_fp16) - 1); // Extracting exponent + int mant = FP16x4[i] & ((1 << mantissa_bits_fp16) - 1); // Extracting mantissa + + int new_exp = exp - exp_bias_fp16; + int new_mant = mant >> (mantissa_bits_fp16 - mantissa_bits_fp6); + + fp6_temp[i] = (sign << (exponent_bits_fp6 + mantissa_bits_fp6)) | + (new_exp << mantissa_bits_fp6) | new_mant; + } + // Pack the values + FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); + FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); + FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; +} + +/* + * Inputs: + * (1) uint16_t Weight_16bit[M*K] + * Outputs: + * (1) unsigned char Weight_6bit[M*K*6/8] + */ +void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t M, size_t K) +{ +#pragma omp parallel for + for (auto m = 0; m < M; m++) { + uint8_t* ptr_6bit = Weight_6bit + m * K * 6 / 8; + uint16_t* ptr_16bit = Weight_16bit + m * K; + for (auto k = 0; k < K; k += 4) { + Cast_FP16_FP6(ptr_16bit, ptr_6bit); + ptr_16bit += 4; + ptr_6bit += 3; + } + } +} + +} // namespace + +cudaError_t QuantGEMM_API( + cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches + int Split_K); void cuda_wf6af16_linear(torch::Tensor& output, torch::Tensor& hidden_states, - torch::Tensor& weights_4bit, torch::Tensor& weights_2bit, - torch::Tensor& scale, + torch::Tensor& weights_4bit, + torch::Tensor& scales, torch::Tensor& workspace, int M, int N, @@ -31,11 +90,30 @@ 
void cuda_wf6af16_linear(torch::Tensor& output, TORCH_CHECK(weights_2bit.device().type() == torch::kCUDA, "weight_2bit must be on CUDA"); TORCH_CHECK(weights_4bit.device().type() == torch::kCUDA, "weight_4bit must be on CUDA"); TORCH_CHECK(hidden_states.device().type() == torch::kCUDA, "X must be on CUDA"); - TORCH_CHECK(scale.device().type() == torch::kCUDA, "scale must be on CUDA"); - Launch_QuantGEMM( - output, weights_2bit, weights_4bit, hidden_states, scale, M, N, K, split_k, workspace); + TORCH_CHECK(scales.device().type() == torch::kCUDA, "scales must be on CUDA"); + + auto status = QuantGEMM_API(at::cuda::getCurrentCUDAStream(), + (uint4*)(weights_2bit.data_ptr()), + (uint4*)(weights_4bit.data_ptr()), + (half*)(scales.data_ptr()), + (half*)(hidden_states.data_ptr()), + (half*)(output.data_ptr()), + M, + N, + K, + workspace.data_ptr(), + split_k); + if (status != cudaSuccess) { + AT_ERROR("QuantGEMM_API failed with error: ", cudaGetErrorString(status)); + } } +void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, + unsigned char* Weight_2bit, + unsigned char* Weight_4bit, + size_t M, + size_t K); + /* * Inputs: * (1) torch::Tensor weight[M, K] in FP16 @@ -67,28 +145,4 @@ std::vector preprocess_weight(torch::Tensor& Weight) K); return {Weight_2bit, Weight_4bit}; -} - -/* - * Inputs: - * (1) torch::Tensor Scale_In[M, K/GroupSize] in FP16 - * Outputs: - * (1) torch::Tensor Scale_Out[M, K/GroupSize] in FP16 - */ - -torch::Tensor preprocess_scales(torch::Tensor& Scale, int M, int K) -{ - // Preprocess scales - TORCH_CHECK(Scale.dim() == 2, "scale must be 2-dimensional"); - TORCH_CHECK(Scale.size(0) == M, "scale must have same M as weight"); - TORCH_CHECK(Scale.is_contiguous(), "scale must be contiguous"); - TORCH_CHECK(Scale.device().type() == torch::kCPU, "scale must be on CPU"); - TORCH_CHECK(Scale.scalar_type() == torch::kFloat16, "scale must be FP16"); - auto GroupSize = K / Scale.size(1); - TORCH_CHECK(GroupSize % 64 == 0, "GroupSize must be multiple of 64"); - auto New_Scale = torch::empty_like(Scale); - auto Scale_out = New_Scale.data_ptr(); - auto Scale_in = New_Scale.data_ptr(); - GenMatrix_Scale_FP16((uint8_t*)Scale_out, (uint8_t*)Scale_in, M, K, GroupSize); - return New_Scale; -} +} \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu deleted file mode 100644 index 194b9120a9eb..000000000000 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) Microsoft Corporation. 
-// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -#include -#include "cuda_linear_kernels.h" - -cudaError_t QuantGEMM_API( - cudaStream_t stream, - const uint4* Weight1, - const uint4* Weight2, - const half* Scales, - const int QUANT_GROUP_SIZE_DIVIDED_BY_64, - const half* B, - half* C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches - int Split_K); - -void Launch_QuantGEMM(torch::Tensor C, - torch::Tensor Weight1, // 2bit - torch::Tensor Weight2, // 4bit - torch::Tensor B, - torch::Tensor Scales, - const int M_Global, - const int N_Global, - const int K_Global, - const int Split_K, - torch::Tensor workspace) -{ - auto C_ptr = C.data_ptr(); - auto B_ptr = B.data_ptr(); - auto W1_ptr = Weight1.data_ptr(); - auto W2_ptr = Weight2.data_ptr(); - auto Group_Size = K_Global / Scales.size(1); - - // auto workspace_size = M_Global * N_Global * Split_K; - // auto workspace = torch::empty({workspace_size}, torch::kFloat16); - - auto status = QuantGEMM_API(at::cuda::getCurrentCUDAStream(), - (uint4*)W1_ptr, - (uint4*)W2_ptr, - (half*)Scales.data_ptr(), - Group_Size / 64, - (half*)B_ptr, - (half*)C_ptr, - M_Global, - N_Global, - K_Global, - workspace.data_ptr(), - Split_K); - if (status != cudaSuccess) { - AT_ERROR("QuantGEMM_API failed with error: ", cudaGetErrorString(status)); - } -} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h index 0a2e2bd4197e..bbdbcb487235 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h @@ -11,8 +11,8 @@ void cuda_wf6af16_linear(torch::Tensor& output, torch::Tensor& hidden_states, - torch::Tensor& weights_4bit, torch::Tensor& weights_2bit, + torch::Tensor& weights_4bit, torch::Tensor& scale, torch::Tensor& workspace, int M, @@ -21,5 +21,3 @@ void cuda_wf6af16_linear(torch::Tensor& output, int split_k); std::vector preprocess_weight(torch::Tensor& Weight); - -torch::Tensor preprocess_scales(torch::Tensor& Scale, int M, int K); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/quant_gemm_api.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/quant_gemm_api.cuh deleted file mode 100644 index 35d899aa642c..000000000000 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/quant_gemm_api.cuh +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// SPDX-License-Identifier: Apache-2.0 - -// DeepSpeed Team - -#include -#include -#include - -/* - *half* Reduction_Workspace: 1. Requiring an extra memory space in device memory for un-reducted intermediate output tensors - * 2. 
Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp16) - */ -cudaError_t QuantGEMM_API(cudaStream_t stream, - const uint4 *Weight1, - const uint4 *Weight2, - const half *Scales, - const int QUANT_GROUP_SIZE_DIVIDED_BY_64, - const half *B, - half *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float *Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches - int Split_K); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh index c507b43d8111..fce0146b3fa3 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh @@ -5,27 +5,6 @@ #include "Configs.h" #include "PTX_cp.async.cuh" -template -__device__ __forceinline__ void CopyFromGlobalToReg_ByPassL1 - (uint32_t* RegArray, - const uint4* SourceGlobalAddr, - bool pred_guard = true) { - // Setting GMem pointer - int lane_id = threadIdx.x % WARP_SIZE; - SourceGlobalAddr += lane_id; - #pragma unroll - for(int i=0; i( SPTR_QuantScales, GPTR_A_Scales, pred_guard && (lane_id<8)); + int Offset_Shared = lane_id*2; + int Offset_Global = lane_id/4 + (lane_id%4)*16; + for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8]; } /* @@ -69,15 +46,16 @@ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantSc * (2) Copying 64 rows * X columns of FP16 values, originally in column major * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads */ -template +template __device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], const half* GlobalPTR, const int GlobalStride, + const int NumOfLinesLeft, // To support arbitrary N dimensions. bool Pred = true) { // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; const int NumOfGroups = NumOfThreads / 8; - const int MaxIteration = (NumOfLinesToCopy-1) / NumOfGroups + 1; + const int MaxIteration = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1; // runtime variables const int line_id = threadIdx.x / 8; const int line_offset = (threadIdx.x%8) * 8; @@ -86,7 +64,7 @@ __device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*Share SharedPTR += line_id; #pragma unroll for (int i = 0; i < MaxIteration; i++) { - bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesToCopy && Pred; + bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred; cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); // GlobalPTR += NumOfGroups * GlobalStride; diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh index d8694b400699..827667abc320 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh @@ -4,15 +4,13 @@ #include #include #include -#include /* * Input: R1 * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. 
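+ * The 32-bit word R1 holds four packed FP6 (E3M2) values; on return R1 and R2
+ * each hold two FP16 bit patterns (the scale multiply happens later in
+ * MultScale). "Simplified" means the FP6 exponent bits are placed into the low
+ * bits of the FP16 exponent field without re-biasing (FP6 bias 3 vs. FP16 bias
+ * 15), so every converted value is off by a factor of 2^-12; the quantization
+ * scales are pre-multiplied by 2^12 on the Python side to compensate.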
*/ -__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2) -{ +__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) { *R2 = *R1 & 0x80808080; *R1 = *R1 >> 2; *R1 = *R1 & 0x1f1f1f1f; @@ -22,99 +20,91 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2) *R2 = *R2 << 8; } -// Will be removed in the future. -#define DEBUG 1 +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is NOT applied. + */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) { + //*R2 = *R1 & 0x80808080; + *R2 = *R1 & 0xc0c0c0c0; + *R1 = *R1 >> 2; + //*R1 = *R1 & 0x1f1f1f1f; + *R1 = *R1 & 0x0f0f0f0f; + *R2 = *R2 | *R1; + // + //*R1 = *R2 & 0x9f009f00; + //*R2 = *R2 & 0x009f009f; + *R1 = *R2 & 0xcf00cf00; + if( !(*R1 & 0x40000000) && (*R1 & 0x0c000000) ) *R1 = *R1 | 0x30000000; + if( !(*R1 & 0x00004000) && (*R1 & 0x00000c00) ) *R1 = *R1 | 0x00003000; + *R2 = *R2 & 0x00cf00cf; + if( !(*R2 & 0x00400000) && (*R2 & 0x000c0000) ) *R2 = *R2 | 0x00300000; + if( !(*R2 & 0x00000040) && (*R2 & 0x0000000c) ) *R2 = *R2 | 0x00000030; + // + *R2 = *R2 << 8; + //*R1 = 0x3c003c00; + //*R2 = 0x3c003c00; +} -__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) -{ +__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) { half* FP16_1 = reinterpret_cast(&PackedFP16Pair); half* FP16_2 = FP16_1 + 1; uint32_t output; half* output_half_ptr = reinterpret_cast(&output); - -#if DEBUG - // Only for testing. Note that the scales should be multiplied by e12 accoring to the Quant-LLM - // optimization. - bool apply_e12 = true; - if (apply_e12) { - // TODO: this will still lead to NaN. Need to fix it. - // 2^12 == 4096 - output_half_ptr[0] = __hmul(*FP16_1, half(4096)); - output_half_ptr[1] = __hmul(*FP16_2, half(4096)); - } else { - output_half_ptr[0] = *FP16_1; - output_half_ptr[1] = *FP16_2; - } -#else output_half_ptr[0] = *FP16_1 * Scale; - output_half_ptr[1] = *FP16_2 * Scale; -#endif -#if 0 - auto res1 = output_half_ptr[0]; - auto res2 = output_half_ptr[1]; - // printf("FP16_1: %f\n", __half2float(*FP16_1)); - // The following two lines do not show NaN. 
- if (res1 != res1) { printf("res1 is NaN: %f\n", __half2float(res1)); } - if (res2 != res2) { printf("res2 is NaN: %f\n", __half2float(res2)); } -#endif - + output_half_ptr[1] = *FP16_2 * Scale; return output; } -__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], - u_int32_t __restrict__* read_RPTR_Frag1, - u_int32_t __restrict__* read_RPTR_Frag2, - u_int32_t* Scales) -{ - u_int32_t* OutputRegs = reinterpret_cast(Reg); - u_int32_t* Frag1_PTR = read_RPTR_Frag1; - u_int32_t* Frag2_PTR = read_RPTR_Frag2; - half* Scale_RPTR = reinterpret_cast(Scales); - u_int32_t Packed_FP6 = 0; - u_int32_t tmp = 0; -// Dequantizing 32 FP6, each Loop dequantizing 4 FP6 -#pragma unroll(8) - for (int i = 0; i < 8; i++) { +__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], + u_int32_t __restrict__ *read_RPTR_Frag1, + u_int32_t __restrict__ *read_RPTR_Frag2, + u_int32_t *Scales) { + u_int32_t *OutputRegs = reinterpret_cast (Reg); + u_int32_t *Frag1_PTR = read_RPTR_Frag1; + u_int32_t *Frag2_PTR = read_RPTR_Frag2; + half *Scale_RPTR = reinterpret_cast(Scales); + u_int32_t Packed_FP6 = 0; + u_int32_t tmp = 0; + // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 + #pragma unroll(8) + for(int i=0; i<8; i++) { // Frag1 Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0; - if (i % 4 == 3) - Frag1_PTR++; - else - (*Frag1_PTR) = (*Frag1_PTR) << 2; + if(i%4==3) Frag1_PTR++; + else (*Frag1_PTR) = (*Frag1_PTR) << 2; // Frag2 tmp = (*Frag2_PTR) & 0xf0f0f0f0; tmp = tmp >> 2; - if (i % 2 == 1) - Frag2_PTR++; - else - (*Frag2_PTR) = (*Frag2_PTR) << 4; + if(i%2==1) Frag2_PTR++; + else (*Frag2_PTR) = (*Frag2_PTR) << 4; // Packed_FP6 Packed_FP6 = Packed_FP6 | tmp; // FP6_FP16_Cast_4Way(&Packed_FP6, &tmp); // - *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0]); // Muliply FP16 scales + *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0] ); // Muliply FP16 scales OutputRegs += 1; - *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales + *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales OutputRegs += 1; // Updating offset for FP16 scales for every two iterations - if (i % 2 == 1) Scale_RPTR += 2; + if(i%2==1) Scale_RPTR += 2; } + } /* - * + * */ -__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, - half* WARP_SPTR_Scales) -{ +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) { int lane_id = threadIdx.x % WARP_SIZE; uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); uint32_t tmpReg = SPTR_uint[lane_id]; -#pragma unroll - for (int i = 0; i < 4; i++) { - // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); - Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); + #pragma unroll + for(int i=0; i<4; i++) { + // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); } } diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index 927eaf90b6ef..b0da842936b6 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -50,6 +50,8 @@ def fp_quantize( min_value is not None and max_value is not None) assert input.dtype == torch.float16 + + print(f"device of input: {input.device}") input = input.to(torch.float32) if num_bits == 6: @@ -154,7 +156,6 @@ def __init__(self, config: 
DSLinearConfig, implementation_config: Dict[str, Any] self.inf_module = InferenceCoreBuilder().load() self.inf_module.create_handle() self.preprocess_weight = self.inf_module.preprocess_weight - self.preprocess_scales = self.inf_module.preprocess_scales self.quantizer = fp_quantize @@ -187,13 +188,12 @@ def transform_param(self, param: torch.Tensor) -> InferenceParameter: # According to the optimization in Quant-LLM, the scales need to be multiplied by 2^12. scales = scales * (2 ** 12) - scales = self.preprocess_scales(scales, self.M, self.K) - return InferenceParameter.initialize(weights_4bit, weights_2bit=weights_2bit, scales=scales) + return InferenceParameter.initialize(weights_2bit, weights_4bit=weights_4bit, scales=scales) def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: - weights_4bit = w - weights_2bit = w.weights_2bit + weights_2bit = w + weights_4bit = w.weights_4bit scales = w.scales output = empty_from( self._output, (hidden_states.shape[0], self._config.out_channels)) @@ -201,12 +201,12 @@ def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torc if self._is_gated: staging_output = empty_from( self._double_buffer, (hidden_states.shape[0], self.M)) - self._linear_impl(staging_output, hidden_states, weights_4bit, - weights_2bit, scales, self.M, hidden_states.shape[0], self.K) + self._linear_impl(staging_output, hidden_states, weights_2bit, + weights_4bit, scales, self.M, hidden_states.shape[0], self.K) self._act_fn(output, staging_output, b) else: - self._linear_impl(output, hidden_states, weights_4bit, - weights_2bit, scales, self.M, hidden_states.shape[0], self.K) + self._linear_impl(output, hidden_states, weights_2bit, + weights_4bit, scales, self.M, hidden_states.shape[0], self.K) self._act_fn(output, b) if self.DEBUG: diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index c1ef43f8170f..7ebd2a8c323f 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -67,10 +67,8 @@ def sources(self): "inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu", "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp", "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu", - "inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cu", "inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp", "inference/v2/kernels/core_ops/cuda_linear/Launcher.cu", - "inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.cpp", ] prefix = self.get_prefix() From 1c2131d20eb58072a038ef2a744e9c11b846af1f Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Fri, 26 Jan 2024 08:50:43 +0000 Subject: [PATCH 03/31] Make the quantizer on GPU. 
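
The qtorch fake-quantization step in `fp_quantize` is now run on the
accelerator: the input's original device is recorded, the tensor is moved to
CUDA as float32 for `float_quantize`, and the quantized weights and scales are
moved back afterwards, so the CPU-side weight prepacking path is unchanged.
Roughly, the new flow looks like the sketch below (illustrative only -- the
group-wise scale computation of the real `fp_quantize` is simplified, and the
usual `from qtorch.quant import float_quantize` import is assumed):

    import torch
    from qtorch.quant import float_quantize

    def fp_quantize_sketch(weight: torch.Tensor, exp_bits: int = 3, man_bits: int = 2):
        orig_device = weight.device
        x = weight.to(torch.float32).cuda()
        # 28 is the largest magnitude representable in FP6 (E3M2); guard
        # all-zero rows so the division below cannot produce NaNs.
        scales = x.abs().amax(dim=-1, keepdim=True) / 28
        scales = scales.clamp(min=torch.finfo(torch.float32).tiny)
        fake_fp6 = float_quantize(x / scales, exp_bits, man_bits, rounding="nearest")
        # Hand the results back on the caller's device; prepacking stays on CPU.
        return (fake_fp6.to(torch.float16).to(orig_device),
                scales.to(torch.float16).to(orig_device))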
--- .../modules/implementations/linear/quantized_linear.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index b0da842936b6..db184674ed4f 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -50,10 +50,9 @@ def fp_quantize( min_value is not None and max_value is not None) assert input.dtype == torch.float16 - - print(f"device of input: {input.device}") - input = input.to(torch.float32) + orig_device = input.device + input = input.to(torch.float32).cuda() if num_bits == 6: if exp_bits == 3: # this is defulat q_range = 28 @@ -84,8 +83,8 @@ def fp_quantize( quantized_fake_fp6 = float_quantize( scaled_input, exp_bits, man_bits, rounding="nearest") quantized_fake_fp6 = quantized_fake_fp6.reshape( - input_shape).contiguous().to(torch.float16) - scales = scales.to(torch.float16) + input_shape).contiguous().to(torch.float16).to(orig_device) + scales = scales.to(torch.float16).to(orig_device) # Now the dequantized value is quantized_fake_fp6 * scales return quantized_fake_fp6, scales @@ -159,6 +158,7 @@ def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any] self.quantizer = fp_quantize + # This is for debugging, will delete after release. self.DEBUG = True def transform_param(self, param: torch.Tensor) -> InferenceParameter: From 1ba45fdd53ff2ebd86e8733c821eabcd3c0c4e30 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Fri, 26 Jan 2024 10:02:13 +0000 Subject: [PATCH 04/31] [WIP] Fix the bug of FP16-to-FP6 data packing. --- .../cuda_linear/cuda_linear_kernels.cpp | 15 +++++++---- .../linear/quantized_linear.py | 26 ++++++++++++------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index 42ae7342d844..d59742a94093 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -10,12 +10,15 @@ namespace { // Utils to prepack FP16 weights into continuous FP6 values. -// TODO: debug it. 
+// TODO: debug according to the qtorch float_quantize funcion: +// https://github.com/Tiiiger/QPyTorch/blob/f58bba72113e696099ef3e15e06cf421a06ff289/qtorch/quant/quant_cuda/float_kernel.cu#L41 void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) { + // Constants for FP6 constexpr int exponent_bits_fp6 = 3; constexpr int mantissa_bits_fp6 = 2; + constexpr int exp_bias_fp6 = (1 << (exponent_bits_fp6 - 1)) - 1; // Constants for FP16 constexpr int exponent_bits_fp16 = 5; constexpr int mantissa_bits_fp16 = 10; @@ -25,11 +28,13 @@ void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) for (int i = 0; i < 4; ++i) { int sign = (FP16x4[i] >> 15); - int exp = (FP16x4[i] >> mantissa_bits_fp16) & - ((1 << exponent_bits_fp16) - 1); // Extracting exponent - int mant = FP16x4[i] & ((1 << mantissa_bits_fp16) - 1); // Extracting mantissa + // Extracting exponent represented in FP16 + int exp = (FP16x4[i] << 1 >> (mantissa_bits_fp16 + 1)) & ((1 << exponent_bits_fp16) - 1); + // Extracting mantissa represented in FP16 + int mant = FP16x4[i] & ((1 << mantissa_bits_fp16) - 1); - int new_exp = exp - exp_bias_fp16; + int new_exp = exp - exp_bias_fp16 + exp_bias_fp6; + new_exp &= ((1 << exponent_bits_fp6) - 1); // To double check. int new_mant = mant >> (mantissa_bits_fp16 - mantissa_bits_fp6); fp6_temp[i] = (sign << (exponent_bits_fp6 + mantissa_bits_fp6)) | diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index db184674ed4f..7e5c58f173e8 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -87,6 +87,8 @@ def fp_quantize( scales = scales.to(torch.float16).to(orig_device) # Now the dequantized value is quantized_fake_fp6 * scales + # TODO: the conversion between float and half may make the fp6 value not accurate. To test and debug. + return quantized_fake_fp6, scales @@ -173,12 +175,12 @@ def transform_param(self, param: torch.Tensor) -> InferenceParameter: if param.ndim == 1: # bias, do nothing return InferenceParameter.initialize(param) - if self.DEBUG: - self.weight = param.clone().cpu() - quantized_fake_fp6, scales = self.quantizer( param, num_bits=6, exp_bits=3) + if self.DEBUG: + self.weight_dequantized = quantized_fake_fp6 * scales + # This is for debugging, will delete after release. assert (quantized_fake_fp6.dtype == torch.float16) assert quantized_fake_fp6.shape[0] == self.M @@ -210,21 +212,27 @@ def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torc self._act_fn(output, b) if self.DEBUG: - orig_device = self.weight.device - self.weight = self.weight.to(output.device) + orig_device = self.weight_dequantized.device + self.weight_dequantized = self.weight_dequantized.to(output.device) ground_truth = torch.nn.functional.linear( - hidden_states, self.weight, b) - self.weight = self.weight.to(orig_device) + hidden_states, self.weight_dequantized, b) + self.weight_dequantized = self.weight_dequantized.to(orig_device) shape = (hidden_states.shape[0], self.M, self.K) if self._is_gated: ismatch = torch.allclose( ground_truth, staging_output, rtol=1e-3) + abs_diff = torch.max(torch.abs(ground_truth - staging_output)) + rel_diff = torch.max( + torch.abs((ground_truth - staging_output) / ground_truth)) print(f"Linear shape: {shape}:\n\tIs correct: {ismatch}. 
" - f"Max diff: {torch.max(torch.abs(ground_truth - staging_output))}") + f"Max diff (abs, rel): ({abs_diff}, {rel_diff})") else: ismatch = torch.allclose(ground_truth, output, rtol=1e-3) + abs_diff = torch.max(torch.abs(ground_truth - output)) + rel_diff = torch.max( + torch.abs((ground_truth - output) / ground_truth)) print(f"Linear shape: {shape}:\n\tIs correct: {ismatch}. " - f"Max diff: {torch.max(torch.abs(ground_truth - output))}") + f"Max diff (abs, rel): ({abs_diff}, {rel_diff})") return output From ff6c3c3a72e8b2d209b75fa9cbf4098d4f7e092d Mon Sep 17 00:00:00 2001 From: Arash Bakhtiari Date: Sat, 27 Jan 2024 02:03:20 +0000 Subject: [PATCH 05/31] Add FP6 end-to-end unit tests --- .../modules/test_quantizied_linear_module.py | 114 ++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 tests/unit/inference/v2/modules/test_quantizied_linear_module.py diff --git a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py new file mode 100644 index 000000000000..d5d7fa4beff8 --- /dev/null +++ b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum, is_gated +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSLinearConfig +from deepspeed.inference.v2.modules.interfaces import DSLinearRegistry +from ...v2.inference_test_utils import allclose + + +def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + out_states = torch.nn.functional.linear(hidden_states, weight, bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _fp6_quantized_linear_helper(tokens: int, + in_channels: int, + out_channels: int, + dtype: DtypeEnum, + act_fn: ActivationType, + use_bias: bool = True) -> None: + + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + + weight_out_channels = 2 * out_channels if is_gated(act_fn) else out_channels + weight = torch.randn( + (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + if use_bias: + bias = torch.randn( + (weight_out_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + else: + bias = None + + linear_config = DSLinearConfig(max_tokens=2048, + in_channels=in_channels, + out_channels=out_channels, + activation=act_fn, + input_dtype=dtype, + 
output_dtype=dtype) + bundle = ConfigBundle(name='quantized_wf6af16_linear', config=linear_config) + fp6_linear_module = DSLinearRegistry.instantiate_config(bundle) + weight_fp6 = fp6_linear_module.transform_param(weight.clone().cpu()).to(get_accelerator().current_device_name()) + + # Reference output + ref_output = reference_implementation(hidden_states, weight, bias, act_fn) + + # New output + ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) + + # Check + assert allclose(ds_output, ref_output) + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens, in_channels, out_channels", [(1, 4608, 1728), (37, 8192, 4096), (1280, 3072, 6144)]) +def test_fp6_quantized_linear_shapes(tokens: int, in_channels: int, out_channels: int) -> None: + _fp6_quantized_linear_helper(tokens, + in_channels, + out_channels, + DtypeEnum.fp16, + ActivationType.IDENTITY, + use_bias=True) + + +all_acts = [ + ActivationType.RELU, + ActivationType.GELU, + ActivationType.SILU, + ActivationType.GEGLU, + ActivationType.ReGLU, + ActivationType.SiGLU, +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_fp6_quantized_linear_act_fn(act_fn: ActivationType, use_bias: bool) -> None: + _fp6_quantized_linear_helper(283, 512, 4096, DtypeEnum.fp16, act_fn, use_bias=use_bias) From 368a76309ec64603a48aee562988de06bf5d31f5 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Mon, 29 Jan 2024 12:29:40 +0000 Subject: [PATCH 06/31] Refine the FP16-to-FP6 cast logic. --- .../cuda_linear/cuda_linear_kernels.cpp | 112 +++++++++++++++--- .../linear/quantized_linear.py | 5 +- 2 files changed, 96 insertions(+), 21 deletions(-) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index d59742a94093..7fbb529c21e9 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -9,36 +9,112 @@ namespace { +// For bit-level debugging. +template +void printBits(T num) +{ + char bits[sizeof(T) * 8 + 1] = {'\0'}; + for (int bit = 0; bit < (sizeof(T) * 8); bit++) { + bits[sizeof(T) * 8 - 1 - bit] = '0' + (num & 0x01); + num = num >> 1; + } + printf("%s\n", bits); +} + +void printBits(half num) +{ + char bits[sizeof(half) * 8 + 1] = {'\0'}; + auto int_num = *reinterpret_cast(&num); + for (int bit = 0; bit < (sizeof(half) * 8); bit++) { + bits[sizeof(half) * 8 - 1 - bit] = '0' + (int_num & 0x01); + int_num = int_num >> 1; + } + printf("%s\n", bits); +} + // Utils to prepack FP16 weights into continuous FP6 values. 
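+// FP6 here is E3M2: 1 sign bit, 3 exponent bits (bias 3) and 2 mantissa bits.
+// Worked example: FP16 1.0 = 0|01111|0000000000; re-biasing the exponent
+// (15 - 15 + 3 = 3 = 0b011) and truncating the mantissa to two bits gives
+// FP6 0|011|00 = 0b001100. Representable magnitudes run from 0.0625 (the
+// smallest nonzero subnormal) up to 28.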
-// TODO: debug according to the qtorch float_quantize funcion: +// TODO: the following cast seems not the same with that in qtorch: // https://github.com/Tiiiger/QPyTorch/blob/f58bba72113e696099ef3e15e06cf421a06ff289/qtorch/quant/quant_cuda/float_kernel.cu#L41 void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) { // Constants for FP6 - constexpr int exponent_bits_fp6 = 3; - constexpr int mantissa_bits_fp6 = 2; - constexpr int exp_bias_fp6 = (1 << (exponent_bits_fp6 - 1)) - 1; + constexpr int exponent_nbits_fp6 = 3; + constexpr int mantissa_nbits_fp6 = 2; + constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; + // constexpr int max_exponent_fp6 = (1 << exponent_nbits_fp6) - 2; // Constants for FP16 - constexpr int exponent_bits_fp16 = 5; - constexpr int mantissa_bits_fp16 = 10; - constexpr int exp_bias_fp16 = (1 << (exponent_bits_fp16 - 1)) - 1; + constexpr int exponent_nbits_fp16 = 5; + constexpr int mantissa_nbits_fp16 = 10; + constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1; - uint8_t fp6_temp[4]; + int fp6_temp[4]; for (int i = 0; i < 4; ++i) { - int sign = (FP16x4[i] >> 15); - // Extracting exponent represented in FP16 - int exp = (FP16x4[i] << 1 >> (mantissa_bits_fp16 + 1)) & ((1 << exponent_bits_fp16) - 1); - // Extracting mantissa represented in FP16 - int mant = FP16x4[i] & ((1 << mantissa_bits_fp16) - 1); + uint16_t source = FP16x4[i]; + // It is not safe to do shift operation on uint16_t. So we promote it to int. + int source_promote = int(source); - int new_exp = exp - exp_bias_fp16 + exp_bias_fp6; - new_exp &= ((1 << exponent_bits_fp6) - 1); // To double check. - int new_mant = mant >> (mantissa_bits_fp16 - mantissa_bits_fp6); + int sign = (source_promote >> 15); + // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111' + int exp = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; + // Extracting mantissa represented in FP16 + int mant = source_promote & ((1 << mantissa_nbits_fp16) - 1); + + int new_exp; + int new_mant; + if (exp == 0 || exp == ((1 << exponent_nbits_fp16) - 1)) { + // When all bits of exponent are zero, the calculation of the float value will not + // include the spacific value of the exponent. Thus we can just copy the value to the + // new variable after bit cutting. + // TODO: a problem here is that the mantissa actually affects the results when all bits + // of exponent are one. We need to consider this case in the next version if it matters + // in practice. + new_exp = exp >> (exponent_nbits_fp16 - exponent_nbits_fp6); + new_mant = mant >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); + } else { + int target_exp = int(exp) - int(exp_bias_fp16); + constexpr int min_exp_fp6 = -((1 << (exponent_nbits_fp6 - 1)) - 2); + constexpr int max_exp_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; + if (target_exp < min_exp_fp6) { + // The exponent is too small to be represented in FP6. We need to round it to zero. + // Keep the sign. + // TODO: do we round it to zero or the smallest FP6 value? + new_exp = 0; + new_mant = 0; + } else if (target_exp > max_exp_fp6) { + // The exponent is too large to be represented in FP6. We need to round it to the + // largest value. Keep the sign. 
+ new_exp = max_exp_fp6 + exp_bias_fp6; + new_mant = (1 << mantissa_nbits_fp6) - 1; + } else { + new_exp = target_exp + exp_bias_fp6; + new_mant = mant >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); + } + +#if 0 + if (target_exp < min_exp_fp6) { + uint16_t casted = (sign << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | + (new_exp << mantissa_nbits_fp6) | new_mant; + printf("%f exp too small, new value is: %f\n", + __half2float(*((half*)(&source))), + __half2float(*((half*)(&casted)))); + printBits(source); + printBits(casted); + } else if (target_exp > max_exp_fp6) { + uint16_t casted = (sign << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | + (new_exp << mantissa_nbits_fp6) | new_mant; + printf("%f exp too large, new value is: %f\n", + __half2float(*((half*)(&source))), + __half2float(*((half*)(&casted)))); + printBits(source); + printBits(casted); + } +#endif + } - fp6_temp[i] = (sign << (exponent_bits_fp6 + mantissa_bits_fp6)) | - (new_exp << mantissa_bits_fp6) | new_mant; + fp6_temp[i] = (sign << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | + (new_exp << mantissa_nbits_fp6) | new_mant; } // Pack the values FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index 7e5c58f173e8..3d3e0cd32a66 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -82,13 +82,13 @@ def fp_quantize( quantized_fake_fp6 = float_quantize( scaled_input, exp_bits, man_bits, rounding="nearest") + # TODO: it seems the `float_quantize` will not clamp the value into the range of FP6 correctly. + # To double check it. If it is true, we need to clamp it manually. quantized_fake_fp6 = quantized_fake_fp6.reshape( input_shape).contiguous().to(torch.float16).to(orig_device) scales = scales.to(torch.float16).to(orig_device) # Now the dequantized value is quantized_fake_fp6 * scales - # TODO: the conversion between float and half may make the fp6 value not accurate. To test and debug. 
- return quantized_fake_fp6, scales @@ -199,7 +199,6 @@ def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torc scales = w.scales output = empty_from( self._output, (hidden_states.shape[0], self._config.out_channels)) - # N = hidden_states.shape[0] if self._is_gated: staging_output = empty_from( self._double_buffer, (hidden_states.shape[0], self.M)) From 6c45a84bb1261db3ce78d32ea4588e2d0dec5a66 Mon Sep 17 00:00:00 2001 From: Arash Bakhtiari Date: Tue, 30 Jan 2024 01:45:43 +0000 Subject: [PATCH 07/31] Add unit tests for FP6 quantizer --- .../implementations/linear/__init__.py | 2 +- .../modules/test_quantizied_linear_module.py | 96 +++++++++++++++---- 2 files changed, 78 insertions(+), 20 deletions(-) diff --git a/deepspeed/inference/v2/modules/implementations/linear/__init__.py b/deepspeed/inference/v2/modules/implementations/linear/__init__.py index 11a545846522..2843f8bf187a 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/__init__.py +++ b/deepspeed/inference/v2/modules/implementations/linear/__init__.py @@ -4,5 +4,5 @@ # DeepSpeed Team from .blas_fp_linear import BlasFPLinear -from .quantized_linear import QuantizedWf6Af16Linear +from .quantized_linear import QuantizedWf6Af16Linear, fp_quantize diff --git a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py index d5d7fa4beff8..314d5cc5ce90 100644 --- a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py +++ b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py @@ -46,13 +46,64 @@ def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, return out_states.to(dtype) +def _fp6_quant_dequant_weights(weight: torch.Tensor) -> torch.Tensor: + from deepspeed.inference.v2.modules.implementations.linear.quantized_linear import fp_quantize + weight_quantized_fake_fp6, scales = fp_quantize(weight, num_bits=6, exp_bits=3) + return weight_quantized_fake_fp6 * scales + + +def quant_dequant_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + weight_dequantized = _fp6_quant_dequant_weights(weight) + out_states = torch.nn.functional.linear(hidden_states, weight_dequantized, bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _fp6_quantized_weights_helper( + in_channels: int, + out_channels: int, + dtype: DtypeEnum, + act_fn: ActivationType, +) -> torch.Tensor: + weight_out_channels = 2 * out_channels if is_gated(act_fn) else out_channels + weight = torch.randn( + (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + weight_dequantized = _fp6_quant_dequant_weights(weight) + tolerances = (4.8e-1, 3.2e-2) # tolerances for bf16 + # tolerances = (3e-2, 
2e-3) # tolerances for fp16 + assert allclose(weight_dequantized, weight, tolerances=tolerances) + + def _fp6_quantized_linear_helper(tokens: int, in_channels: int, out_channels: int, dtype: DtypeEnum, act_fn: ActivationType, use_bias: bool = True) -> None: - # Input vals hidden_states = torch.randn( (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 @@ -66,6 +117,9 @@ def _fp6_quantized_linear_helper(tokens: int, else: bias = None + # quantize and dequantize output + ref_quant_dequant_output = quant_dequant_implementation(hidden_states, weight, bias, act_fn) + linear_config = DSLinearConfig(max_tokens=2048, in_channels=in_channels, out_channels=out_channels, @@ -75,26 +129,17 @@ def _fp6_quantized_linear_helper(tokens: int, bundle = ConfigBundle(name='quantized_wf6af16_linear', config=linear_config) fp6_linear_module = DSLinearRegistry.instantiate_config(bundle) weight_fp6 = fp6_linear_module.transform_param(weight.clone().cpu()).to(get_accelerator().current_device_name()) - - # Reference output - ref_output = reference_implementation(hidden_states, weight, bias, act_fn) - - # New output ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) - # Check - assert allclose(ds_output, ref_output) + tolerances = (4.8e-1, 3.2e-2) # tolerances for bf16 + # tolerances = (3e-2, 2e-3) # tolerances for fp16 + # Check DeepSpeed implementation + assert allclose(ds_output, ref_quant_dequant_output, tolerances=tolerances) -@pytest.mark.inference_v2_ops -@pytest.mark.parametrize("tokens, in_channels, out_channels", [(1, 4608, 1728), (37, 8192, 4096), (1280, 3072, 6144)]) -def test_fp6_quantized_linear_shapes(tokens: int, in_channels: int, out_channels: int) -> None: - _fp6_quantized_linear_helper(tokens, - in_channels, - out_channels, - DtypeEnum.fp16, - ActivationType.IDENTITY, - use_bias=True) + # # Check reference implementation + # ref_output = reference_implementation(hidden_states, weight, bias, act_fn) + # assert allclose(ds_output, ref_output, tolerances=tolerances) all_acts = [ @@ -105,10 +150,23 @@ def test_fp6_quantized_linear_shapes(tokens: int, in_channels: int, out_channels ActivationType.ReGLU, ActivationType.SiGLU, ] +all_tokens = [1, 37, 1280] +all_in_out_channels = [(4608, 1728), (8192, 4096), (3072, 6144)] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_fp6_quantized_weights(in_channels: int, out_channels: int, act_fn: ActivationType, use_bias: bool) -> None: + _fp6_quantized_weights_helper(in_channels, out_channels, DtypeEnum.fp16, act_fn) @pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens", all_tokens) +@pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) @pytest.mark.parametrize("act_fn", all_acts) @pytest.mark.parametrize("use_bias", [True, False]) -def test_fp6_quantized_linear_act_fn(act_fn: ActivationType, use_bias: bool) -> None: - _fp6_quantized_linear_helper(283, 512, 4096, DtypeEnum.fp16, act_fn, use_bias=use_bias) +def test_fp6_quantized_linear_act_fn(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, + use_bias: bool) -> None: + _fp6_quantized_linear_helper(tokens, in_channels, out_channels, DtypeEnum.fp16, act_fn, use_bias=use_bias) From 90b710d29c2ae7b2af641d23a1b3170abffd2a9c Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Tue, 30 Jan 2024 08:48:12 +0000 Subject: [PATCH 08/31] Fix FP16-FP6 
cast problems. --- .../cuda_linear/cuda_linear_kernels.cpp | 76 +++++-------------- 1 file changed, 19 insertions(+), 57 deletions(-) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index 7fbb529c21e9..e1cb2e64c689 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -11,7 +11,7 @@ namespace { // For bit-level debugging. template -void printBits(T num) +void print_bits(T num) { char bits[sizeof(T) * 8 + 1] = {'\0'}; for (int bit = 0; bit < (sizeof(T) * 8); bit++) { @@ -21,7 +21,7 @@ void printBits(T num) printf("%s\n", bits); } -void printBits(half num) +void print_bits(half num) { char bits[sizeof(half) * 8 + 1] = {'\0'}; auto int_num = *reinterpret_cast(&num); @@ -34,15 +34,12 @@ void printBits(half num) // Utils to prepack FP16 weights into continuous FP6 values. -// TODO: the following cast seems not the same with that in qtorch: -// https://github.com/Tiiiger/QPyTorch/blob/f58bba72113e696099ef3e15e06cf421a06ff289/qtorch/quant/quant_cuda/float_kernel.cu#L41 void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) { // Constants for FP6 constexpr int exponent_nbits_fp6 = 3; constexpr int mantissa_nbits_fp6 = 2; constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; - // constexpr int max_exponent_fp6 = (1 << exponent_nbits_fp6) - 2; // Constants for FP16 constexpr int exponent_nbits_fp16 = 5; constexpr int mantissa_nbits_fp16 = 10; @@ -50,8 +47,19 @@ void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) int fp6_temp[4]; + float absmin_nonzero_fp6 = 0.0625; + // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is + // the same with that in qtorch. + float absmax_fp6 = 28; + for (int i = 0; i < 4; ++i) { uint16_t source = FP16x4[i]; + float fp6_value_abs = std::abs(__half2float(*((half*)(&source)))); + if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) || + fp6_value_abs > absmax_fp6) { + throw std::invalid_argument("Input value out of range for FP6."); + } + // It is not safe to do shift operation on uint16_t. So we promote it to int. int source_promote = int(source); @@ -61,57 +69,8 @@ void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) // Extracting mantissa represented in FP16 int mant = source_promote & ((1 << mantissa_nbits_fp16) - 1); - int new_exp; - int new_mant; - if (exp == 0 || exp == ((1 << exponent_nbits_fp16) - 1)) { - // When all bits of exponent are zero, the calculation of the float value will not - // include the spacific value of the exponent. Thus we can just copy the value to the - // new variable after bit cutting. - // TODO: a problem here is that the mantissa actually affects the results when all bits - // of exponent are one. We need to consider this case in the next version if it matters - // in practice. - new_exp = exp >> (exponent_nbits_fp16 - exponent_nbits_fp6); - new_mant = mant >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); - } else { - int target_exp = int(exp) - int(exp_bias_fp16); - constexpr int min_exp_fp6 = -((1 << (exponent_nbits_fp6 - 1)) - 2); - constexpr int max_exp_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; - if (target_exp < min_exp_fp6) { - // The exponent is too small to be represented in FP6. We need to round it to zero. - // Keep the sign. - // TODO: do we round it to zero or the smallest FP6 value? 
- new_exp = 0; - new_mant = 0; - } else if (target_exp > max_exp_fp6) { - // The exponent is too large to be represented in FP6. We need to round it to the - // largest value. Keep the sign. - new_exp = max_exp_fp6 + exp_bias_fp6; - new_mant = (1 << mantissa_nbits_fp6) - 1; - } else { - new_exp = target_exp + exp_bias_fp6; - new_mant = mant >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); - } - -#if 0 - if (target_exp < min_exp_fp6) { - uint16_t casted = (sign << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | - (new_exp << mantissa_nbits_fp6) | new_mant; - printf("%f exp too small, new value is: %f\n", - __half2float(*((half*)(&source))), - __half2float(*((half*)(&casted)))); - printBits(source); - printBits(casted); - } else if (target_exp > max_exp_fp6) { - uint16_t casted = (sign << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | - (new_exp << mantissa_nbits_fp6) | new_mant; - printf("%f exp too large, new value is: %f\n", - __half2float(*((half*)(&source))), - __half2float(*((half*)(&casted)))); - printBits(source); - printBits(casted); - } -#endif - } + int new_exp = exp == 0 ? 0 : exp - exp_bias_fp16 + exp_bias_fp6; + int new_mant = mant >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); fp6_temp[i] = (sign << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | (new_exp << mantissa_nbits_fp6) | new_mant; @@ -130,9 +89,12 @@ void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) */ void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t M, size_t K) { + // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). + if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } + size_t K_fp6_packed = K * 6 / 8; #pragma omp parallel for for (auto m = 0; m < M; m++) { - uint8_t* ptr_6bit = Weight_6bit + m * K * 6 / 8; + uint8_t* ptr_6bit = Weight_6bit + m * K_fp6_packed; uint16_t* ptr_16bit = Weight_16bit + m * K; for (auto k = 0; k < K; k += 4) { Cast_FP16_FP6(ptr_16bit, ptr_6bit); From f8e3acfb63acf8ed96959ae7f84fa1d48b943c57 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Thu, 1 Feb 2024 04:47:56 +0000 Subject: [PATCH 09/31] Update FP6 kernels. 
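For reference, the FP16 -> FP6 cast above re-biases the exponent and truncates the mantissa after checking that the (fake-quantized) input already falls inside the representable FP6 (e3m2) range [0.0625, 28]. Below is a minimal Python sketch of that per-value cast, for illustration only; the function name and the NumPy bit view are mine and not part of the patch:

    import numpy as np

    EXP_BITS_FP6, MAN_BITS_FP6 = 3, 2
    EXP_BIAS_FP6 = (1 << (EXP_BITS_FP6 - 1)) - 1        # 3
    EXP_BIAS_FP16, MAN_BITS_FP16 = 15, 10
    ABS_MIN_NONZERO_FP6, ABS_MAX_FP6 = 0.0625, 28.0     # smallest subnormal / largest value

    def fp16_to_fp6_bits(x: float) -> int:
        v = abs(float(np.float16(x)))
        if (v != 0 and v < ABS_MIN_NONZERO_FP6) or v > ABS_MAX_FP6:
            raise ValueError("input value out of range for FP6")
        bits = int(np.array(x, dtype=np.float16).view(np.uint16))
        sign = bits >> 15
        exp = (bits >> MAN_BITS_FP16) & 0x1F
        mant = bits & ((1 << MAN_BITS_FP16) - 1)
        # FP16 subnormals (exponent field == 0) keep a zero FP6 exponent field here,
        # mirroring the simplified cast in this patch.
        new_exp = 0 if exp == 0 else exp - EXP_BIAS_FP16 + EXP_BIAS_FP6
        new_mant = mant >> (MAN_BITS_FP16 - MAN_BITS_FP6)
        return (sign << (EXP_BITS_FP6 + MAN_BITS_FP6)) | (new_exp << MAN_BITS_FP6) | new_mant

    assert fp16_to_fp6_bits(1.5) == 0b001110   # sign 0, exponent field 3 (2^0), mantissa .10

The kernel-side helper additionally emits three packed bytes for every four such 6-bit codes (4*6bit == 3*8bit) before the later 2-bit/4-bit split.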
--- .../cuda_linear/cuda_linear_kernels.cpp | 107 +++---- .../cuda_linear/cuda_linear_kernels.h | 2 + .../core_ops/cuda_linear/fp6_linear.cu | 303 ++++++++++++++++++ .../core_ops/cuda_linear/fp6_linear.cuh | 38 +++ .../cuda_linear/{utils => include}/Configs.h | 0 .../{utils => include}/PTX_cp.async.cuh | 0 .../{utils => include}/PTX_mma.cuh | 0 .../{utils => include}/Utils_Core.cuh | 0 .../{utils => include}/Utils_GMem.cuh | 0 .../Utils_ParallelDequant.cuh | 4 +- .../kernel_matmul.cuh} | 8 +- .../kernel_reduction.cuh} | 0 .../weight_prepacking.h} | 10 +- .../linear/quantized_linear.py | 35 +- op_builder/inference_core_ops.py | 16 +- 15 files changed, 450 insertions(+), 73 deletions(-) create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu create mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{utils => include}/Configs.h (100%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{utils => include}/PTX_cp.async.cuh (100%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{utils => include}/PTX_mma.cuh (100%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{utils => include}/Utils_Core.cuh (100%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{utils => include}/Utils_GMem.cuh (100%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{utils => include}/Utils_ParallelDequant.cuh (95%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{Kernel_QuantGEMM.cuh => include/kernel_matmul.cuh} (98%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{Kernel_Reduction.cuh => include/kernel_reduction.cuh} (100%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/{GenMatrix_QuantLLM.h => include/weight_prepacking.h} (96%) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index e1cb2e64c689..665360e7efda 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -34,7 +34,7 @@ void print_bits(half num) // Utils to prepack FP16 weights into continuous FP6 values. -void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) +void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) { // Constants for FP6 constexpr int exponent_nbits_fp6 = 3; @@ -87,7 +87,7 @@ void Cast_FP16_FP6(uint16_t* FP16x4, uint8_t* FP6x4) * Outputs: * (1) unsigned char Weight_6bit[M*K*6/8] */ -void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t M, size_t K) +void weight_prepacing_fp16_to_fp6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t M, size_t K) { // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). 
if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } @@ -97,7 +97,7 @@ void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t uint8_t* ptr_6bit = Weight_6bit + m * K_fp6_packed; uint16_t* ptr_16bit = Weight_16bit + m * K; for (auto k = 0; k < K; k += 4) { - Cast_FP16_FP6(ptr_16bit, ptr_6bit); + cast_fp16_fp6(ptr_16bit, ptr_6bit); ptr_16bit += 4; ptr_6bit += 3; } @@ -106,18 +106,18 @@ void PackMatrix_Weight_FP6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t } // namespace -cudaError_t QuantGEMM_API( - cudaStream_t stream, - const uint4* Weight1, - const uint4* Weight2, - const half* Scales, - const half* B, - half* C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches - int Split_K); +// cudaError_t QuantGEMM_API( +// cudaStream_t stream, +// const uint4* Weight1, +// const uint4* Weight2, +// const half* Scales, +// const half* B, +// half* C, +// const size_t M_Global, +// const size_t N_Global, +// const size_t K_Global, +// float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches +// int Split_K); void cuda_wf6af16_linear(torch::Tensor& output, torch::Tensor& hidden_states, @@ -135,27 +135,27 @@ void cuda_wf6af16_linear(torch::Tensor& output, TORCH_CHECK(hidden_states.device().type() == torch::kCUDA, "X must be on CUDA"); TORCH_CHECK(scales.device().type() == torch::kCUDA, "scales must be on CUDA"); - auto status = QuantGEMM_API(at::cuda::getCurrentCUDAStream(), - (uint4*)(weights_2bit.data_ptr()), - (uint4*)(weights_4bit.data_ptr()), - (half*)(scales.data_ptr()), - (half*)(hidden_states.data_ptr()), - (half*)(output.data_ptr()), - M, - N, - K, - workspace.data_ptr(), - split_k); + auto status = fp6_linear_kernel(at::cuda::getCurrentCUDAStream(), + (uint4*)(weights_2bit.data_ptr()), + (uint4*)(weights_4bit.data_ptr()), + (half*)(scales.data_ptr()), + (half*)(hidden_states.data_ptr()), + (half*)(output.data_ptr()), + M, + N, + K, + workspace.data_ptr(), + split_k); if (status != cudaSuccess) { - AT_ERROR("QuantGEMM_API failed with error: ", cudaGetErrorString(status)); + AT_ERROR("fp6_linear_kernel failed with error: ", cudaGetErrorString(status)); } } -void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, - unsigned char* Weight_2bit, - unsigned char* Weight_4bit, - size_t M, - size_t K); +// void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, +// unsigned char* Weight_2bit, +// unsigned char* Weight_4bit, +// size_t M, +// size_t K); /* * Inputs: @@ -163,29 +163,28 @@ void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, * Outputs: * (1) torch::Tensor weight_2bit and weight_4bit */ -std::vector preprocess_weight(torch::Tensor& Weight) +std::vector preprocess_weight(torch::Tensor& weight) { - TORCH_CHECK(Weight.dim() == 2, "weight must be 2-dimensional"); - TORCH_CHECK(Weight.scalar_type() == torch::kFloat16, "weight must be FP16"); - TORCH_CHECK(Weight.is_contiguous(), "weight must be contiguous"); - TORCH_CHECK(Weight.device().type() == torch::kCPU, "weight must be on CPU"); - auto M = Weight.size(0); - auto K = Weight.size(1); + TORCH_CHECK(weight.dim() == 2, "weight must be 2-dimensional"); + TORCH_CHECK(weight.scalar_type() == torch::kFloat16, "weight must be FP16"); + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(weight.device().type() == torch::kCPU, "weight must be on CPU"); + auto M = weight.size(0); + auto K = weight.size(1); TORCH_CHECK(K % 4 == 
0, "K must be multiple of 4"); - // Pack Weight - auto Weight_ptr = Weight.data_ptr(); - std::vector Weight_6bit_Packed(M * K * 6 / 8); - PackMatrix_Weight_FP6((uint16_t*)Weight_ptr, Weight_6bit_Packed.data(), M, K); - - // Split Weight - auto Weight_2bit = torch::empty({M * K * 2 / 8}, torch::kUInt8); - auto Weight_4bit = torch::empty({M * K * 4 / 8}, torch::kUInt8); - GenMatrix_Weight_FP6(Weight_6bit_Packed.data(), - Weight_2bit.data_ptr(), - Weight_4bit.data_ptr(), - M, - K); - - return {Weight_2bit, Weight_4bit}; + // Pack weight from FP16 to FP6. + uint16_t* weight_16bit_ptr = reinterpret_cast(weight.data_ptr()); + std::vector weight_6bit_packed(M * K * 6 / 8); + uint8_t* weight_6bit_ptr = weight_6bit_packed.data(); + weight_prepacing_fp16_to_fp6(weight_16bit_ptr, weight_6bit_ptr, M, K); + + // Split weight into 2bit and 4bit. + weight_matrix_prepacking(reinterpret_cast(weight_6bit_ptr), M, K); + uint8_t* weight_2bit_ptr = weight_6bit_ptr; + auto weight_2bit = torch::from_blob(weight_2bit_ptr, {M * K * 2 / 8}, torch::kUInt8); + uint8_t* weight_4bit_ptr = weight_2bit_ptr + M * K * 2 / 8; + auto weight_4bit = torch::from_blob(weight_4bit_ptr, {M * K * 4 / 8}, torch::kUInt8); + + return {weight_2bit, weight_4bit}; } \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h index bbdbcb487235..0f5882d519ca 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h @@ -9,6 +9,8 @@ #include #include "ds_kernel_utils.h" +#include "fp6_linear.cuh" + void cuda_wf6af16_linear(torch::Tensor& output, torch::Tensor& hidden_states, torch::Tensor& weights_2bit, diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu new file mode 100644 index 000000000000..d80d2764b050 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu @@ -0,0 +1,303 @@ +#include "include/kernel_matmul.cuh" +#include "include/kernel_reduction.cuh" +#include "include/weight_prepacking.h" + +#include +#include +#include +#include + +template +static void Kernel_Ex(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + OutputDataType* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) +{ +#ifdef DEBUG_MODE + printf("\n"); + printf("Launcher.cu->Kernel_Ex():\n"); + printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); + printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", + TilingConfig::TILE_M, + TilingConfig::TILE_K, + TilingConfig::TILE_N); +#endif + static size_t SHMEM_SZ = + max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, + TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + SHMEM_SZ); + size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; + size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; + dim3 GridDim(dimN, dimM, 1); + dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); +// +#ifdef DEBUG_MODE + printf( + "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: " + "%d SHMEM_SZ: %d\n", + GridDim.x, + GridDim.y, + GridDim.z, + BlockDim.x, + BlockDim.y, + BlockDim.z, + SHMEM_SZ); + printf("\n"); +#endif 
+ QUANT_GEMM_Kernel<<>>( + Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); +} + +/* + * + */ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Reduction_Workspace_Size = Split_K * + // M_Global * N_Global * sizeof(fp32) + int Split_K) +{ + assert(M_Global % 256 == 0); + assert(K_Global % 64 == 0); + assert(N_Global > 0); + + // Work around to support more N shapes: + size_t N_PowerOf2; + if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8; + if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16; + if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32; + if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64; + if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128; + if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128; + + if (Split_K == 1) { + switch (N_PowerOf2) { + case 8: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 16: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 32: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 64: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 128: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + } + } else { + switch (N_PowerOf2) { + case 8: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 16: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 32: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 64: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 128: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + } + // Reduction for SplitK + dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); + dim3 BlockDim(WARP_SIZE, 1, 1); + SplitK_Reduction<<>>( + C, Reduction_Workspace, M_Global, N_Global, Split_K); + } + return cudaGetLastError(); +} + +/* +Computes FP6-FP16 GEMM (PyTorch interface). + +[Mathmatical Formula] +Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in +row-major. After Equivalent transformation : trans(Out) = W * trans(In). 
Note that we do not +perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when +calling our CUDA kernel. + +[Inputs] + _in_feats: tensor of shape [B, IC]; // half + _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + _scales: tensor of shape [OC]; // half + splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. +[Outputs] + _out_feats: tensor of shape [B, OC]; // half +*/ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int splitK = 1) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int num_out_channels = _weights.size(0); + assert(num_in_channels % 64 == 0); + assert((num_in_channels / 16 * 3) == + _weights.size(1)); // Making sure the K dimension is matched. + // + int M = num_out_channels; + int K = num_in_channels; + int N = num_in_feats; + // Input Tensors + auto weight1 = reinterpret_cast( + _weights.data_ptr()); // weights is [OC, IC] but in FP6. + auto weight2 = weight1 + num_in_channels * num_out_channels * 2 / 128; + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + // Output Tensors + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + float* Reduction_Workspace = nullptr; + if (splitK != 1) { + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); + at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); + auto Reduction_Workspace = reinterpret_cast( + _out_feats.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * + // N_Global * sizeof(fp32) + } + + fp6_linear_kernel(0, // Using default stream here. + weight1, + weight2, + scales, + in_feats, + out_feats, + M, + N, + K, + Reduction_Workspace, + splitK); + + return _out_feats; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ + +/* + * Weight prepacking (Pytorch interface). + * [Input & Output] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. 
+ * [Output] + * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; + */ +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t OC, size_t IC) +{ + assert((OC % 256 == 0) && (IC % 64 == 0)); + assert((fp6_tensor.size(0) == OC) && (fp6_tensor.size(1) == IC / 16 * 3)); + // auto packed_tensor = torch::empty_like(fp6_tensor); + // auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + weight_matrix_prepacking(fp6_tensor_ptr, OC, IC); + return fp6_tensor; +} \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh new file mode 100644 index 000000000000..41389d1a91d8 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh @@ -0,0 +1,38 @@ +#include +#include +#include + +#include + +/* +* Computes FP6-FP16 GEMM (C++ interface). +*/ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4 *Weight1, + const uint4 *Weight2, + const half *Scales, + const half *B, + half *C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) + int Split_K); + +/* +* Computes FP6-FP16 GEMM (PyTorch interface). +*/ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int splitK=1); + +/* + * In-place weight prepacking (C++ interface). + */ +void weight_matrix_prepacking(int *FP6Weights, size_t M, size_t K); + +/* + * Weight prepacking (Pytorch interface). + */ +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t M, size_t K); \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Configs.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h similarity index 100% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Configs.h rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_cp.async.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh similarity index 100% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_cp.async.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_mma.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh similarity index 100% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/PTX_mma.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_Core.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh similarity index 100% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_Core.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh similarity index 100% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_GMem.cuh rename to 
deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh similarity index 95% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh index 827667abc320..9627741ce317 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/utils/Utils_ParallelDequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh @@ -52,8 +52,8 @@ __device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Sc half* FP16_2 = FP16_1 + 1; uint32_t output; half* output_half_ptr = reinterpret_cast(&output); - output_half_ptr[0] = *FP16_1 * Scale; - output_half_ptr[1] = *FP16_2 * Scale; + output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(4096.0f)), Scale); + output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(4096.0f)), Scale); return output; } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh similarity index 98% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh index 03d27433bc2b..7f120fdcb303 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_QuantGEMM.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh @@ -10,9 +10,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
***************************************************************************/ -#include "utils/Configs.h" -#include "utils/Utils_GMem.cuh" -#include "utils/Utils_Core.cuh" +#include "Configs.h" +#include "Utils_GMem.cuh" +#include "Utils_Core.cuh" /* * C = A*B @@ -174,7 +174,7 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, co #pragma unroll for(size_t j=threadIdx.x%WARP_SIZE; j::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); + if constexpr (std::is_same::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; } } \ No newline at end of file diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_Reduction.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh similarity index 100% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/Kernel_Reduction.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h similarity index 96% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h index 7ab71f957627..76ff5bbb6b8b 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/GenMatrix_QuantLLM.h +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h @@ -135,15 +135,15 @@ void BitInterleaving_4bit(unsigned char* PTR_4Bytes) * 8 FP4 = 4 Bytes * 8 FP2 = 2 Bytes */ -void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, - unsigned char* Weight_2bit, - unsigned char* Weight_4bit, - size_t M, - size_t K) +void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K) { assert(M % 64 == 0); assert(K % 64 == 0); // + unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); + unsigned char* Weight_2bit = Weight_6bit; + unsigned char* Weight_4bit = Weight_6bit + M * K * 2 / 8; + // vector A_Segment_2bit[32]; vector A_Segment_4bit[32]; // diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index 3d3e0cd32a66..f60b2ff72a59 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -84,11 +84,40 @@ def fp_quantize( scaled_input, exp_bits, man_bits, rounding="nearest") # TODO: it seems the `float_quantize` will not clamp the value into the range of FP6 correctly. # To double check it. If it is true, we need to clamp it manually. + + if False: + abs_quantized = torch.abs(quantized_fake_fp6) + max_fp6 = 14 + min_fp6 = 0.0625 + non_zero = abs_quantized[abs_quantized != 0] + larger_than_max = non_zero[non_zero > max_fp6] + smaller_than_min = non_zero[non_zero < min_fp6] + print( + f"FP32. value too large: {larger_than_max}, largest: {larger_than_max.max()}") + print( + f"FP32. value too small: {smaller_than_min}") + print(f"FP32. 
min: {non_zero.min()}") + # exit(0) + quantized_fake_fp6 = quantized_fake_fp6.reshape( input_shape).contiguous().to(torch.float16).to(orig_device) scales = scales.to(torch.float16).to(orig_device) # Now the dequantized value is quantized_fake_fp6 * scales + if False: + abs_quantized = torch.abs(quantized_fake_fp6) + max_fp6 = 14 + min_fp6 = 0.0625 + non_zero = abs_quantized[abs_quantized != 0] + larger_than_max = non_zero[non_zero > max_fp6] + smaller_than_min = non_zero[non_zero < min_fp6] + print( + f"FP16. value too large: {larger_than_max}, largest: {larger_than_max.max()}") + print( + f"FP16. value too small: {smaller_than_min}") + print(f"FP16. min: {non_zero.min()}") + # exit(0) + return quantized_fake_fp6, scales @@ -161,7 +190,7 @@ def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any] self.quantizer = fp_quantize # This is for debugging, will delete after release. - self.DEBUG = True + self.DEBUG = False def transform_param(self, param: torch.Tensor) -> InferenceParameter: """ @@ -186,10 +215,12 @@ def transform_param(self, param: torch.Tensor) -> InferenceParameter: assert quantized_fake_fp6.shape[0] == self.M assert scales.numel() == self.M + # Do not delete `quantized_fake_fp6` as the `preprocess_weight` is in-place operation. weights_2bit, weights_4bit = self.preprocess_weight(quantized_fake_fp6) # According to the optimization in Quant-LLM, the scales need to be multiplied by 2^12. - scales = scales * (2 ** 12) + # scales = scales * (2 ** 12) + # scales = torch.full_like(scales, 1) return InferenceParameter.initialize(weights_2bit, weights_4bit=weights_4bit, scales=scales) diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index 7ebd2a8c323f..64266cdf5e06 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -23,20 +23,24 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") + self.warning( + "Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True - if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + if not self.is_rocm_pytorch() and torch.cuda.is_available(): # ignore-cuda sys_cuda_major, _ = installed_cuda_version() torch_cuda_major = int(torch.version.cuda.split('.')[0]) - cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + cuda_capability = torch.cuda.get_device_properties( + 0).major # ignore-cuda if cuda_capability < 6: - self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") + self.warning( + "NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning("On Ampere and higher architectures please use CUDA 11+") + self.warning( + "On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay @@ -67,8 +71,8 @@ def sources(self): "inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu", "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp", "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu", + "inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu", "inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp", - "inference/v2/kernels/core_ops/cuda_linear/Launcher.cu", ] prefix = self.get_prefix() From 
b025c5adb98c99f258f381cce05c7b28e0d02c61 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Thu, 1 Feb 2024 12:09:57 +0000 Subject: [PATCH 10/31] Fix the bug of subnormal FP6 casting and the 2bit/4bit tensor allocation. --- .../kernels/core_ops/cuda_linear/Launcher.cu | 207 ------------------ .../cuda_linear/cuda_linear_kernels.cpp | 88 +++++--- .../linear/quantized_linear.py | 15 +- .../modules/test_quantizied_linear_module.py | 102 ++++++--- 4 files changed, 139 insertions(+), 273 deletions(-) delete mode 100644 deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu deleted file mode 100644 index 0e8636a81452..000000000000 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/Launcher.cu +++ /dev/null @@ -1,207 +0,0 @@ -#include "GenMatrix_QuantLLM.h" -#include "Kernel_QuantGEMM.cuh" -#include "Kernel_Reduction.cuh" - -#include -#include - -template -static void Kernel_QuantGEMM_Ex(cudaStream_t stream, - const uint4* Weight1, - const uint4* Weight2, - const half* Scales, - const half* B, - OutputDataType* C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - int Split_K) -{ -#ifdef DEBUG_MODE - printf("\n"); - printf("Launcher.cu->Kernel_QuantGEMM_Ex():\n"); - printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); - printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", - TilingConfig::TILE_M, - TilingConfig::TILE_K, - TilingConfig::TILE_N); -#endif - static size_t SHMEM_SZ = - max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, - TilingConfig::SMEM_SIZE_C_TILE); - cudaFuncSetAttribute(QUANT_GEMM_Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - SHMEM_SZ); - size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; - size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; - dim3 GridDim(dimN, dimM, 1); - dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); -// -#ifdef DEBUG_MODE - printf( - "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: " - "%d SHMEM_SZ: %d\n", - GridDim.x, - GridDim.y, - GridDim.z, - BlockDim.x, - BlockDim.y, - BlockDim.z, - SHMEM_SZ); - printf("\n"); -#endif - QUANT_GEMM_Kernel<<>>( - Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); -} - -/* - *half* Reduction_Workspace: 1. Requiring an extra memory space in device memory for un-reducted - *intermediate output tensors - * 2. 
Reduction_Workspace_Size = Split_K * M_Global * N_Global * - *sizeof(fp32) - */ -cudaError_t QuantGEMM_API( - cudaStream_t stream, - const uint4* Weight1, - const uint4* Weight2, - const half* Scales, - const half* B, - half* C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches - int Split_K) -{ - // assert(M_Global % TilingConfig::TILE_M == 0); - // assert(K_Global % TilingConfig::TILE_K == 0); - assert(N_Global > 0); - - // Work around to support more N shapes: Pretending that the input is 2^n - size_t N_PowerOf2; - if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8; - if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16; - if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32; - if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64; - if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128; - if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128; - // printf("N_Global:%d N_PowerOf2:%d\n", N_Global, N_PowerOf2); - - if (Split_K == 1) { - switch (N_PowerOf2) { - case 8: - Kernel_QuantGEMM_Ex, half>( - stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - break; - case 16: - Kernel_QuantGEMM_Ex, half>( - stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - break; - case 32: - Kernel_QuantGEMM_Ex, half>( - stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - break; - case 64: - Kernel_QuantGEMM_Ex, half>( - stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - break; - case 128: - Kernel_QuantGEMM_Ex, half>( - stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - break; - default: - if (N_PowerOf2 % 128 != 0) { - printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); - return cudaErrorUnknown; - } - Kernel_QuantGEMM_Ex, half>( - stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - break; - } - } else { - switch (N_PowerOf2) { - case 8: - Kernel_QuantGEMM_Ex, float>(stream, - Weight1, - Weight2, - Scales, - B, - Reduction_Workspace, - M_Global, - N_Global, - K_Global, - Split_K); - break; - case 16: - Kernel_QuantGEMM_Ex, float>(stream, - Weight1, - Weight2, - Scales, - B, - Reduction_Workspace, - M_Global, - N_Global, - K_Global, - Split_K); - break; - case 32: - Kernel_QuantGEMM_Ex, float>(stream, - Weight1, - Weight2, - Scales, - B, - Reduction_Workspace, - M_Global, - N_Global, - K_Global, - Split_K); - break; - case 64: - Kernel_QuantGEMM_Ex, float>(stream, - Weight1, - Weight2, - Scales, - B, - Reduction_Workspace, - M_Global, - N_Global, - K_Global, - Split_K); - break; - case 128: - Kernel_QuantGEMM_Ex, float>(stream, - Weight1, - Weight2, - Scales, - B, - Reduction_Workspace, - M_Global, - N_Global, - K_Global, - Split_K); - break; - default: - if (N_PowerOf2 % 128 != 0) { - printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); - return cudaErrorUnknown; - } - Kernel_QuantGEMM_Ex, float>(stream, - Weight1, - Weight2, - Scales, - B, - Reduction_Workspace, - M_Global, - N_Global, - K_Global, - Split_K); - break; - } - // Reduction for SplitK - dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); - dim3 BlockDim(WARP_SIZE, 1, 1); - SplitK_Reduction<<>>( - C, Reduction_Workspace, M_Global, N_Global, Split_K); - } - return cudaGetLastError(); -} \ No newline at end of file diff --git 
a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index 665360e7efda..6b592b458d71 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -32,8 +32,9 @@ void print_bits(half num) printf("%s\n", bits); } -// Utils to prepack FP16 weights into continuous FP6 values. - +/* + * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. + */ void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) { // Constants for FP6 @@ -69,7 +70,12 @@ void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) // Extracting mantissa represented in FP16 int mant = source_promote & ((1 << mantissa_nbits_fp16) - 1); - int new_exp = exp == 0 ? 0 : exp - exp_bias_fp16 + exp_bias_fp6; + int new_exp = exp - exp_bias_fp16 + exp_bias_fp6; + if (exp == 0) { + // SUbnormal FP6 number. But the value is a normal FP16 number. Thus it needs a special + // treatment. + new_exp += 1; + } int new_mant = mant >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); fp6_temp[i] = (sign << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | @@ -82,20 +88,25 @@ void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) } /* - * Inputs: - * (1) uint16_t Weight_16bit[M*K] - * Outputs: - * (1) unsigned char Weight_6bit[M*K*6/8] + * Function to prepack FP16 weights into continuous FP6 values. + * + * Parameters: + * weight_16bit: input weight in FP16, size M*K + * weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8 + * M, K: the shape of the weight */ -void weight_prepacing_fp16_to_fp6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, size_t M, size_t K) +void weight_prepacing_fp16_to_fp6(uint16_t* weight_16bit, + uint8_t* weight_6bit_packed, + size_t M, + size_t K) { // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } size_t K_fp6_packed = K * 6 / 8; -#pragma omp parallel for + // #pragma omp parallel for for (auto m = 0; m < M; m++) { - uint8_t* ptr_6bit = Weight_6bit + m * K_fp6_packed; - uint16_t* ptr_16bit = Weight_16bit + m * K; + uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed; + uint16_t* ptr_16bit = weight_16bit + m * K; for (auto k = 0; k < K; k += 4) { cast_fp16_fp6(ptr_16bit, ptr_6bit); ptr_16bit += 4; @@ -106,19 +117,21 @@ void weight_prepacing_fp16_to_fp6(uint16_t* Weight_16bit, uint8_t* Weight_6bit, } // namespace -// cudaError_t QuantGEMM_API( -// cudaStream_t stream, -// const uint4* Weight1, -// const uint4* Weight2, -// const half* Scales, -// const half* B, -// half* C, -// const size_t M_Global, -// const size_t N_Global, -// const size_t K_Global, -// float* Reduction_Workspace, // Identical workspace for all QuantGEMM kernel launches -// int Split_K); - +/* + * Function to execute the FP6 linear kernel. 
+ * + * Parameters: + * output: output tensor, size M*N + * hidden_states: input activation tensor, size N*K + * weights_2bit: packed 2bit weights, size M*K*2/8 + * weights_4bit: packed 4bit weights, size M*K*4/8 + * scales: scale tensor, size M + * workspace: workspace tensor, size M*N*split_k + * M: the output channel size of the weight + * N: the token number of the activation + * K: the input channel size of the weight + * split_k: the split size of the GEMM calculation + */ void cuda_wf6af16_linear(torch::Tensor& output, torch::Tensor& hidden_states, torch::Tensor& weights_2bit, @@ -151,17 +164,14 @@ void cuda_wf6af16_linear(torch::Tensor& output, } } -// void GenMatrix_Weight_FP6(unsigned char* Weight_6bit, -// unsigned char* Weight_2bit, -// unsigned char* Weight_4bit, -// size_t M, -// size_t K); - /* - * Inputs: - * (1) torch::Tensor weight[M, K] in FP16 - * Outputs: - * (1) torch::Tensor weight_2bit and weight_4bit + * Function to prepack the fake 6-bit-quantized FP16 weights into 2bit and 4bit. + * + * Parameters: + * weight: input weight in FP16 (containing the quantized FP6-ranged value), size M*K + * Returns: + * weight_2bit: output weight in 2bit, size M*K*2/8 + * weight_4bit: output weight in 4bit, size M*K*4/8 */ std::vector preprocess_weight(torch::Tensor& weight) { @@ -182,9 +192,15 @@ std::vector preprocess_weight(torch::Tensor& weight) // Split weight into 2bit and 4bit. weight_matrix_prepacking(reinterpret_cast(weight_6bit_ptr), M, K); uint8_t* weight_2bit_ptr = weight_6bit_ptr; - auto weight_2bit = torch::from_blob(weight_2bit_ptr, {M * K * 2 / 8}, torch::kUInt8); + + // Make sure that the new split tensor does not share the underlying memory with the original + // one. Otherwise it will incure some problems when the original tensor is deleted. It also + // makes the memory flattern risky. + auto weight_2bit = + torch::from_blob(weight_2bit_ptr, {M * K * 2 / 8}, torch::kUInt8).clone().detach(); uint8_t* weight_4bit_ptr = weight_2bit_ptr + M * K * 2 / 8; - auto weight_4bit = torch::from_blob(weight_4bit_ptr, {M * K * 4 / 8}, torch::kUInt8); + auto weight_4bit = + torch::from_blob(weight_4bit_ptr, {M * K * 4 / 8}, torch::kUInt8).clone().detach(); return {weight_2bit, weight_4bit}; } \ No newline at end of file diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index f60b2ff72a59..8deb8e3db21a 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -190,7 +190,7 @@ def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any] self.quantizer = fp_quantize # This is for debugging, will delete after release. - self.DEBUG = False + self.DEBUG = True def transform_param(self, param: torch.Tensor) -> InferenceParameter: """ @@ -218,15 +218,22 @@ def transform_param(self, param: torch.Tensor) -> InferenceParameter: # Do not delete `quantized_fake_fp6` as the `preprocess_weight` is in-place operation. weights_2bit, weights_4bit = self.preprocess_weight(quantized_fake_fp6) - # According to the optimization in Quant-LLM, the scales need to be multiplied by 2^12. 
- # scales = scales * (2 ** 12) - # scales = torch.full_like(scales, 1) + # print(f"weights_2bit: {weights_2bit}") + # print(f"weights_4bit: {weights_4bit}") return InferenceParameter.initialize(weights_2bit, weights_4bit=weights_4bit, scales=scales) def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: weights_2bit = w weights_4bit = w.weights_4bit + # print(f"shape of weights_2bit: {weights_2bit.shape}") + # print(f"shape of weights_4bit: {weights_4bit.shape}") + if False: + b2 = weights_2bit.cpu().numpy() + b4 = weights_4bit.cpu().numpy() + import numpy as np + np.savetxt("e2e-2b.txt", b2, fmt='%s', delimiter=',', newline=',') + np.savetxt("e2e-4b.txt", b4, fmt='%s', delimiter=',', newline=',') scales = w.scales output = empty_from( self._output, (hidden_states.shape[0], self._config.out_channels)) diff --git a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py index 314d5cc5ce90..2076e2862ba5 100644 --- a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py +++ b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py @@ -16,6 +16,18 @@ from ...v2.inference_test_utils import allclose +def save_tensor_to_file(tensor: torch.Tensor, file_name: str) -> None: + import numpy as np + np.savetxt(file_name, tensor.cpu().numpy(), delimiter=',', + newline='},\n{', header='half xxx={\n{', footer='}\n};', comments='', encoding=None) + + +def load_tensor_from_file(file_name: str) -> torch.Tensor: + import numpy as np + data = np.loadtxt(file_name, delimiter=',') + return torch.as_tensor(data, dtype=torch.float16, device=get_accelerator().current_device_name()) + + def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], act_type: ActivationType) -> torch.Tensor: dtype = hidden_states.dtype @@ -48,15 +60,24 @@ def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, def _fp6_quant_dequant_weights(weight: torch.Tensor) -> torch.Tensor: from deepspeed.inference.v2.modules.implementations.linear.quantized_linear import fp_quantize - weight_quantized_fake_fp6, scales = fp_quantize(weight, num_bits=6, exp_bits=3) + weight_quantized_fake_fp6, scales = fp_quantize( + weight, num_bits=6, exp_bits=3) + if False: + import numpy as np + save_tensor_to_file(weight_quantized_fake_fp6, 'quantized_weight.txt') + save_tensor_to_file(scales, 'scales.txt') + if False: + scales = torch.full_like(scales, 1) return weight_quantized_fake_fp6 * scales + # return weight_quantized_fake_fp6 def quant_dequant_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], act_type: ActivationType) -> torch.Tensor: dtype = hidden_states.dtype weight_dequantized = _fp6_quant_dequant_weights(weight) - out_states = torch.nn.functional.linear(hidden_states, weight_dequantized, bias) + out_states = torch.nn.functional.linear( + hidden_states, weight_dequantized, bias) out_states.float() if is_gated(act_type): @@ -89,7 +110,8 @@ def _fp6_quantized_weights_helper( dtype: DtypeEnum, act_fn: ActivationType, ) -> torch.Tensor: - weight_out_channels = 2 * out_channels if is_gated(act_fn) else out_channels + weight_out_channels = 2 * \ + out_channels if is_gated(act_fn) else out_channels weight = torch.randn( (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 weight_dequantized = _fp6_quant_dequant_weights(weight) @@ 
-108,9 +130,26 @@ def _fp6_quantized_linear_helper(tokens: int, hidden_states = torch.randn( (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 - weight_out_channels = 2 * out_channels if is_gated(act_fn) else out_channels + weight_out_channels = 2 * \ + out_channels if is_gated(act_fn) else out_channels weight = torch.randn( (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + if False: + # hidden_states = torch.full_like( + # hidden_states, 1.0, dtype=dtype.value, device=get_accelerator().current_device_name()) + # hidden_val = torch.as_tensor(range( + # 0, hidden_states.shape[0]), dtype=dtype.value, device=get_accelerator().current_device_name()) + # hidden_states = hidden_states * hidden_val.reshape(-1, 1) + # print(f"hidden is: {hidden_states}") + + # save_tensor_to_file(hidden_states, 'hidden_states.txt') + hidden_states = load_tensor_from_file('hidden_states.txt') + + weight = torch.full_like( + weight, 1, dtype=dtype.value, device=get_accelerator().current_device_name()) + weight_val = torch.as_tensor(range( + 0, weight.shape[1]), dtype=dtype.value, device=get_accelerator().current_device_name()) + weight = weight * weight_val.reshape(1, -1)*0.01 if use_bias: bias = torch.randn( (weight_out_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 @@ -118,7 +157,8 @@ def _fp6_quantized_linear_helper(tokens: int, bias = None # quantize and dequantize output - ref_quant_dequant_output = quant_dequant_implementation(hidden_states, weight, bias, act_fn) + ref_quant_dequant_output = quant_dequant_implementation( + hidden_states, weight, bias, act_fn) linear_config = DSLinearConfig(max_tokens=2048, in_channels=in_channels, @@ -126,13 +166,19 @@ def _fp6_quantized_linear_helper(tokens: int, activation=act_fn, input_dtype=dtype, output_dtype=dtype) - bundle = ConfigBundle(name='quantized_wf6af16_linear', config=linear_config) + bundle = ConfigBundle(name='quantized_wf6af16_linear', + config=linear_config) fp6_linear_module = DSLinearRegistry.instantiate_config(bundle) - weight_fp6 = fp6_linear_module.transform_param(weight.clone().cpu()).to(get_accelerator().current_device_name()) + weight_fp6 = fp6_linear_module.transform_param( + weight.clone().cpu()).to(get_accelerator().current_device_name()) ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) - tolerances = (4.8e-1, 3.2e-2) # tolerances for bf16 - # tolerances = (3e-2, 2e-3) # tolerances for fp16 + # tolerances = (4.8e-1, 3.2e-2) # tolerances for bf16 + tolerances = (3e-2, 2e-3) # tolerances for fp16 + + if True: + print(ds_output) + print(ref_quant_dequant_output) # Check DeepSpeed implementation assert allclose(ds_output, ref_quant_dequant_output, tolerances=tolerances) @@ -143,30 +189,34 @@ def _fp6_quantized_linear_helper(tokens: int, all_acts = [ - ActivationType.RELU, - ActivationType.GELU, - ActivationType.SILU, - ActivationType.GEGLU, - ActivationType.ReGLU, - ActivationType.SiGLU, + ActivationType.IDENTITY, + # ActivationType.RELU, + # ActivationType.GELU, + # ActivationType.SILU, + # ActivationType.GEGLU, + # ActivationType.ReGLU, + # ActivationType.SiGLU, ] -all_tokens = [1, 37, 1280] -all_in_out_channels = [(4608, 1728), (8192, 4096), (3072, 6144)] - +# all_tokens = [1, 37, 1280] +# all_in_out_channels = [(4608, 1728), (8192, 4096), (3072, 6144)] +all_tokens = [32] +all_in_out_channels = [(4096, 4096), (8192, 8192), (3072, 3072), (1024, 1024), (512, 512), (256, 256)] 
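# Note: all of the debug shapes above already satisfy the alignment the FP6 kernel
# asserts (out_channels % 256 == 0, in_channels % 64 == 0). A quick illustrative
# check; this snippet is not part of the patch:
debug_shapes = [(4096, 4096), (8192, 8192), (3072, 3072), (1024, 1024), (512, 512), (256, 256)]
assert all(oc % 256 == 0 and ic % 64 == 0 for ic, oc in debug_shapes)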
-@pytest.mark.inference_v2_ops -@pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) -@pytest.mark.parametrize("act_fn", all_acts) -@pytest.mark.parametrize("use_bias", [True, False]) -def test_fp6_quantized_weights(in_channels: int, out_channels: int, act_fn: ActivationType, use_bias: bool) -> None: - _fp6_quantized_weights_helper(in_channels, out_channels, DtypeEnum.fp16, act_fn) +# @pytest.mark.inference_v2_ops +# @pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) +# @pytest.mark.parametrize("act_fn", all_acts) +# @pytest.mark.parametrize("use_bias", [True, False]) +# def test_fp6_quantized_weights(in_channels: int, out_channels: int, act_fn: ActivationType, use_bias: bool) -> None: +# _fp6_quantized_weights_helper(in_channels, out_channels, DtypeEnum.fp16, act_fn) @pytest.mark.inference_v2_ops @pytest.mark.parametrize("tokens", all_tokens) @pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) @pytest.mark.parametrize("act_fn", all_acts) -@pytest.mark.parametrize("use_bias", [True, False]) +# @pytest.mark.parametrize("use_bias", [True, False]) +@pytest.mark.parametrize("use_bias", [False]) def test_fp6_quantized_linear_act_fn(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, use_bias: bool) -> None: - _fp6_quantized_linear_helper(tokens, in_channels, out_channels, DtypeEnum.fp16, act_fn, use_bias=use_bias) + _fp6_quantized_linear_helper( + tokens, in_channels, out_channels, DtypeEnum.fp16, act_fn, use_bias=use_bias) From 6ed67f774a76b3ba272e6c2b0a490e49b1a0254e Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Thu, 1 Feb 2024 12:29:50 +0000 Subject: [PATCH 11/31] Clean code. --- .../core_ops/cuda_linear/cuda_linear.py | 32 ++++--- .../linear/quantized_linear.py | 41 +-------- .../modules/test_quantizied_linear_module.py | 87 +++---------------- 3 files changed, 28 insertions(+), 132 deletions(-) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py index aa973daf81cb..1bd8e9767207 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -24,27 +24,26 @@ def __init__(self): self.inf_module.create_handle() self.kernel = self.inf_module.cuda_wf6af16_linear - def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2bit: torch.Tensor, weights_4bit: torch.Tensor, scale: torch.Tensor, M, N, K) -> torch.Tensor: + def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2bit: torch.Tensor, + weights_4bit: torch.Tensor, scale: torch.Tensor, M, N, K) -> torch.Tensor: """ - Matmul kernel as implemented via CUDA directly. The input must be 2D or larger. If - n-dimensional, the leading dimensions are folded into each other: - 2D: m = x.size(0) - 3D: m = x.size(0) * x.size(1) - 4D: m = x.size(0) * x.size(1) * x.size(2) (etc...) - All inputs should be contiguous. + Matmul kernel of FP6 weight-only quantized linear. All inputs should be contiguous. + It does not support batched-matmul. Parameters: - output (torch.Tensor): Output tensor. Shape is of [*, out_features] - hidden_states (torch.Tensor): Input tensor. Shape is of [*, in_features] - weights (torch.Tensor): Input tensor. Shape is of [out_features, in_features] - scale (torch.Tensor): Input tensor. Shape is of [1] or [out_features], since the scale is per output channel - - Returns: - z (torch.Tensor): Output tensor. 
Shape is of [m, n] + output (torch.Tensor): Output tensor. Shape is of [token_number, out_features] + hidden_states (torch.Tensor): Input tensor. Shape is of [token_number, in_features] + weights_2bit (torch.Tensor): Input tensor of the 2-bit slice. Shape is of [out_features*2/8, in_features] + weights_4bit (torch.Tensor): Input tensor of the 4-bit slice. Shape is of [out_features*4/8, in_features] + scale (torch.Tensor): Input tensor. Shape is of [out_features], since the scale is per output channel + M (int): The number of output channels + N (int): The number of tokens + K (int): The number of input channels """ - # TODO: deal with batched-matmul. As the current implementation only supports 2D input, we need to split the - # batched-matmul into multiple 2D matmul. + if M % 256 != 0 or K % 64 != 0: + raise ValueError( + "The out and in channel of the FP6 weight-only quantized linear should be multiple of 256 and 64 respectively.") # TODO: optimize the heuristic of split k selection. split_k_dict = {15360: 3, 27648: 2, 5120: 10, 10240: 5, @@ -63,6 +62,5 @@ def get_workspace(self, M: int, N: int, K: int, split_k: int, dtype, device) -> split-K. The split-K size is determined by the size of the matmul. """ workspace = torch.empty((split_k, M, N), dtype=dtype, device=device) - # TODO: allocate workspace in advance to avoid memory allocation overhead return workspace diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index 8deb8e3db21a..b6d43dba43bd 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -85,39 +85,11 @@ def fp_quantize( # TODO: it seems the `float_quantize` will not clamp the value into the range of FP6 correctly. # To double check it. If it is true, we need to clamp it manually. - if False: - abs_quantized = torch.abs(quantized_fake_fp6) - max_fp6 = 14 - min_fp6 = 0.0625 - non_zero = abs_quantized[abs_quantized != 0] - larger_than_max = non_zero[non_zero > max_fp6] - smaller_than_min = non_zero[non_zero < min_fp6] - print( - f"FP32. value too large: {larger_than_max}, largest: {larger_than_max.max()}") - print( - f"FP32. value too small: {smaller_than_min}") - print(f"FP32. min: {non_zero.min()}") - # exit(0) - quantized_fake_fp6 = quantized_fake_fp6.reshape( input_shape).contiguous().to(torch.float16).to(orig_device) scales = scales.to(torch.float16).to(orig_device) # Now the dequantized value is quantized_fake_fp6 * scales - if False: - abs_quantized = torch.abs(quantized_fake_fp6) - max_fp6 = 14 - min_fp6 = 0.0625 - non_zero = abs_quantized[abs_quantized != 0] - larger_than_max = non_zero[non_zero > max_fp6] - smaller_than_min = non_zero[non_zero < min_fp6] - print( - f"FP16. value too large: {larger_than_max}, largest: {larger_than_max.max()}") - print( - f"FP16. value too small: {smaller_than_min}") - print(f"FP16. min: {non_zero.min()}") - # exit(0) - return quantized_fake_fp6, scales @@ -190,7 +162,7 @@ def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any] self.quantizer = fp_quantize # This is for debugging, will delete after release. 
- self.DEBUG = True + self.DEBUG = False def transform_param(self, param: torch.Tensor) -> InferenceParameter: """ @@ -218,22 +190,11 @@ def transform_param(self, param: torch.Tensor) -> InferenceParameter: # Do not delete `quantized_fake_fp6` as the `preprocess_weight` is in-place operation. weights_2bit, weights_4bit = self.preprocess_weight(quantized_fake_fp6) - # print(f"weights_2bit: {weights_2bit}") - # print(f"weights_4bit: {weights_4bit}") - return InferenceParameter.initialize(weights_2bit, weights_4bit=weights_4bit, scales=scales) def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: weights_2bit = w weights_4bit = w.weights_4bit - # print(f"shape of weights_2bit: {weights_2bit.shape}") - # print(f"shape of weights_4bit: {weights_4bit.shape}") - if False: - b2 = weights_2bit.cpu().numpy() - b4 = weights_4bit.cpu().numpy() - import numpy as np - np.savetxt("e2e-2b.txt", b2, fmt='%s', delimiter=',', newline=',') - np.savetxt("e2e-4b.txt", b4, fmt='%s', delimiter=',', newline=',') scales = w.scales output = empty_from( self._output, (hidden_states.shape[0], self._config.out_channels)) diff --git a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py index 2076e2862ba5..f2f37b2a2728 100644 --- a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py +++ b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py @@ -16,18 +16,6 @@ from ...v2.inference_test_utils import allclose -def save_tensor_to_file(tensor: torch.Tensor, file_name: str) -> None: - import numpy as np - np.savetxt(file_name, tensor.cpu().numpy(), delimiter=',', - newline='},\n{', header='half xxx={\n{', footer='}\n};', comments='', encoding=None) - - -def load_tensor_from_file(file_name: str) -> torch.Tensor: - import numpy as np - data = np.loadtxt(file_name, delimiter=',') - return torch.as_tensor(data, dtype=torch.float16, device=get_accelerator().current_device_name()) - - def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], act_type: ActivationType) -> torch.Tensor: dtype = hidden_states.dtype @@ -62,14 +50,7 @@ def _fp6_quant_dequant_weights(weight: torch.Tensor) -> torch.Tensor: from deepspeed.inference.v2.modules.implementations.linear.quantized_linear import fp_quantize weight_quantized_fake_fp6, scales = fp_quantize( weight, num_bits=6, exp_bits=3) - if False: - import numpy as np - save_tensor_to_file(weight_quantized_fake_fp6, 'quantized_weight.txt') - save_tensor_to_file(scales, 'scales.txt') - if False: - scales = torch.full_like(scales, 1) return weight_quantized_fake_fp6 * scales - # return weight_quantized_fake_fp6 def quant_dequant_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], @@ -104,22 +85,6 @@ def quant_dequant_implementation(hidden_states: torch.Tensor, weight: torch.Tens return out_states.to(dtype) -def _fp6_quantized_weights_helper( - in_channels: int, - out_channels: int, - dtype: DtypeEnum, - act_fn: ActivationType, -) -> torch.Tensor: - weight_out_channels = 2 * \ - out_channels if is_gated(act_fn) else out_channels - weight = torch.randn( - (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 - weight_dequantized = _fp6_quant_dequant_weights(weight) - tolerances = (4.8e-1, 3.2e-2) # tolerances for bf16 - # tolerances = (3e-2, 2e-3) # tolerances for fp16 - assert 
allclose(weight_dequantized, weight, tolerances=tolerances) - - def _fp6_quantized_linear_helper(tokens: int, in_channels: int, out_channels: int, @@ -134,22 +99,6 @@ def _fp6_quantized_linear_helper(tokens: int, out_channels if is_gated(act_fn) else out_channels weight = torch.randn( (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 - if False: - # hidden_states = torch.full_like( - # hidden_states, 1.0, dtype=dtype.value, device=get_accelerator().current_device_name()) - # hidden_val = torch.as_tensor(range( - # 0, hidden_states.shape[0]), dtype=dtype.value, device=get_accelerator().current_device_name()) - # hidden_states = hidden_states * hidden_val.reshape(-1, 1) - # print(f"hidden is: {hidden_states}") - - # save_tensor_to_file(hidden_states, 'hidden_states.txt') - hidden_states = load_tensor_from_file('hidden_states.txt') - - weight = torch.full_like( - weight, 1, dtype=dtype.value, device=get_accelerator().current_device_name()) - weight_val = torch.as_tensor(range( - 0, weight.shape[1]), dtype=dtype.value, device=get_accelerator().current_device_name()) - weight = weight * weight_val.reshape(1, -1)*0.01 if use_bias: bias = torch.randn( (weight_out_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 @@ -174,12 +123,9 @@ def _fp6_quantized_linear_helper(tokens: int, ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) # tolerances = (4.8e-1, 3.2e-2) # tolerances for bf16 + # The current FP6 kernel uses FP16 Tensor Core. tolerances = (3e-2, 2e-3) # tolerances for fp16 - if True: - print(ds_output) - print(ref_quant_dequant_output) - # Check DeepSpeed implementation assert allclose(ds_output, ref_quant_dequant_output, tolerances=tolerances) @@ -189,33 +135,24 @@ def _fp6_quantized_linear_helper(tokens: int, all_acts = [ - ActivationType.IDENTITY, - # ActivationType.RELU, - # ActivationType.GELU, - # ActivationType.SILU, - # ActivationType.GEGLU, - # ActivationType.ReGLU, - # ActivationType.SiGLU, + ActivationType.RELU, + ActivationType.GELU, + ActivationType.SILU, + ActivationType.GEGLU, + ActivationType.ReGLU, + ActivationType.SiGLU, ] -# all_tokens = [1, 37, 1280] -# all_in_out_channels = [(4608, 1728), (8192, 4096), (3072, 6144)] -all_tokens = [32] -all_in_out_channels = [(4096, 4096), (8192, 8192), (3072, 3072), (1024, 1024), (512, 512), (256, 256)] - -# @pytest.mark.inference_v2_ops -# @pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) -# @pytest.mark.parametrize("act_fn", all_acts) -# @pytest.mark.parametrize("use_bias", [True, False]) -# def test_fp6_quantized_weights(in_channels: int, out_channels: int, act_fn: ActivationType, use_bias: bool) -> None: -# _fp6_quantized_weights_helper(in_channels, out_channels, DtypeEnum.fp16, act_fn) +all_tokens = [1, 37, 1280] +# TODO: some of the shapes are not supported. The output channels should be a multiple of 256. +# The input channel should be a multiple of 64. 
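The test helpers above boil down to a quantize/dequantize round trip through `fp_quantize`. A minimal standalone version of that check, assuming a CUDA device as in the tests (`fp6_round_trip_error` is an illustrative name):

```python
import torch

from deepspeed.inference.v2.modules.implementations.linear.quantized_linear import fp_quantize


def fp6_round_trip_error(weight: torch.Tensor) -> float:
    # Fake-quantize to FP6 (6 bits total, 3 exponent bits), then apply the
    # per-output-channel scales, mirroring _fp6_quant_dequant_weights above.
    fake_fp6, scales = fp_quantize(weight, num_bits=6, exp_bits=3)
    dequantized = fake_fp6 * scales
    return (dequantized - weight).abs().max().item()


w = torch.randn(1728, 4608, dtype=torch.float16, device="cuda") * 0.01
# Expected to be small relative to the FP16 tolerances used in this test file.
print(fp6_round_trip_error(w))
```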
+all_in_out_channels = [(4608, 1728), (8192, 4096), (3072, 6144)] @pytest.mark.inference_v2_ops @pytest.mark.parametrize("tokens", all_tokens) @pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) @pytest.mark.parametrize("act_fn", all_acts) -# @pytest.mark.parametrize("use_bias", [True, False]) -@pytest.mark.parametrize("use_bias", [False]) +@pytest.mark.parametrize("use_bias", [True, False]) def test_fp6_quantized_linear_act_fn(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, use_bias: bool) -> None: _fp6_quantized_linear_helper( From 20b543caec506370b23a2000f1a500d123139da3 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Thu, 1 Feb 2024 12:41:57 +0000 Subject: [PATCH 12/31] pre-commit --- debug/run_pipeline.py | 3 +- .../core_ops/cuda_linear/cuda_linear.py | 12 +- .../cuda_linear/cuda_linear_kernels.cpp | 2 +- .../core_ops/cuda_linear/fp6_linear.cu | 2 +- .../core_ops/cuda_linear/fp6_linear.cuh | 37 +- .../core_ops/cuda_linear/include/Configs.h | 2 +- .../cuda_linear/include/PTX_cp.async.cuh | 32 +- .../core_ops/cuda_linear/include/PTX_mma.cuh | 133 +++--- .../cuda_linear/include/Utils_Core.cuh | 301 ++++++++------ .../cuda_linear/include/Utils_GMem.cuh | 84 ++-- .../include/Utils_ParallelDequant.cuh | 83 ++-- .../cuda_linear/include/kernel_matmul.cuh | 382 +++++++++++------- .../cuda_linear/include/kernel_reduction.cuh | 37 +- .../cuda_linear/include/weight_prepacking.h | 2 +- .../flat_model_helpers.py | 30 +- deepspeed/inference/v2/modules/heuristics.py | 9 +- .../implementations/linear/__init__.py | 1 - .../linear/quantized_linear.py | 67 ++- op_builder/inference_core_ops.py | 12 +- .../modules/test_quantizied_linear_module.py | 18 +- 20 files changed, 685 insertions(+), 564 deletions(-) diff --git a/debug/run_pipeline.py b/debug/run_pipeline.py index 330c836abed8..bd0d8d8939aa 100644 --- a/debug/run_pipeline.py +++ b/debug/run_pipeline.py @@ -12,7 +12,6 @@ def fake_request_texts(batch_size: int): batch_size = 32 prompts = fake_request_texts(batch_size) - pipe = mii.pipeline(model_name_or_path=model_id, - quantization_mode='wf6af16') + pipe = mii.pipeline(model_name_or_path=model_id, quantization_mode='wf6af16') response = pipe(prompts, max_new_tokens=2) print(f"{len(response)} responses.") diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py index 1bd8e9767207..6bda28719de9 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -43,18 +43,16 @@ def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2b if M % 256 != 0 or K % 64 != 0: raise ValueError( - "The out and in channel of the FP6 weight-only quantized linear should be multiple of 256 and 64 respectively.") + "The out and in channel of the FP6 weight-only quantized linear should be multiple of 256 and 64 respectively." + ) # TODO: optimize the heuristic of split k selection. 
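As context for the shape restrictions noted above: the wrapper's split-K heuristic in cuda_linear.py only engages for small token counts (N <= 128) and a fixed table of weight heights M, and each split accumulates an FP32 partial result that is reduced afterwards. A sketch that mirrors that logic; `choose_split_k` and `reduction_workspace` are illustrative names:

```python
import torch

# Copied from the split-k table in cuda_linear.py; keys are M (out channels).
_SPLIT_K_BY_M = {15360: 3, 27648: 2, 5120: 10, 10240: 5, 57344: 7,
                 8192: 6, 21504: 5, 7168: 7, 28672: 7}


def choose_split_k(M: int, N: int) -> int:
    # Split-K only pays off for small batches on shapes present in the table.
    return _SPLIT_K_BY_M.get(M, 1) if N <= 128 else 1


def reduction_workspace(M: int, N: int, split_k: int, device) -> torch.Tensor:
    # Each split writes an FP32 [M, N] partial result into its own slice.
    return torch.empty((split_k, M, N), dtype=torch.float, device=device)


print(choose_split_k(M=10240, N=32))  # -> 5
```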
- split_k_dict = {15360: 3, 27648: 2, 5120: 10, 10240: 5, - 57344: 7, 8192: 6, 21504: 5, 7168: 7, 28672: 7} + split_k_dict = {15360: 3, 27648: 2, 5120: 10, 10240: 5, 57344: 7, 8192: 6, 21504: 5, 7168: 7, 28672: 7} split_k = 1 if not N > 128 and M in split_k_dict: split_k = split_k_dict[M] - workspace = self.get_workspace( - M, N, K, split_k, torch.float, hidden_states.device) - self.kernel(output, hidden_states, weights_2bit, - weights_4bit, scale, workspace, M, N, K, split_k) + workspace = self.get_workspace(M, N, K, split_k, torch.float, hidden_states.device) + self.kernel(output, hidden_states, weights_2bit, weights_4bit, scale, workspace, M, N, K, split_k) def get_workspace(self, M: int, N: int, K: int, split_k: int, dtype, device) -> torch.Tensor: """ diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index 6b592b458d71..1bdb94787eef 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -203,4 +203,4 @@ std::vector preprocess_weight(torch::Tensor& weight) torch::from_blob(weight_4bit_ptr, {M * K * 4 / 8}, torch::kUInt8).clone().detach(); return {weight_2bit, weight_4bit}; -} \ No newline at end of file +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu index d80d2764b050..27c0cd0579a7 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu @@ -300,4 +300,4 @@ torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t OC, auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); weight_matrix_prepacking(fp6_tensor_ptr, OC, IC); return fp6_tensor; -} \ No newline at end of file +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh index 41389d1a91d8..4cdbfed2fe42 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh @@ -5,34 +5,35 @@ #include /* -* Computes FP6-FP16 GEMM (C++ interface). -*/ -cudaError_t fp6_linear_kernel(cudaStream_t stream, - const uint4 *Weight1, - const uint4 *Weight2, - const half *Scales, - const half *B, - half *C, - const size_t M_Global, - const size_t N_Global, - const size_t K_Global, - float *Reduction_Workspace, // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - int Split_K); + * Computes FP6-FP16 GEMM (C++ interface). + */ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Reduction_Workspace_Size = Split_K * + // M_Global * N_Global * sizeof(fp32) + int Split_K); /* -* Computes FP6-FP16 GEMM (PyTorch interface). -*/ + * Computes FP6-FP16 GEMM (PyTorch interface). + */ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, torch::Tensor _weights, torch::Tensor _scales, - int splitK=1); + int splitK = 1); /* * In-place weight prepacking (C++ interface). 
*/ -void weight_matrix_prepacking(int *FP6Weights, size_t M, size_t K); +void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K); /* * Weight prepacking (Pytorch interface). */ -torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t M, size_t K); \ No newline at end of file +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t M, size_t K); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h index d910454fb819..aa5f9e527249 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h @@ -86,4 +86,4 @@ struct TilingConfig { 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for // each thread -#endif // CONFIGS_H \ No newline at end of file +#endif // CONFIGS_H diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh index 90b00c386f2f..9ab917da8183 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh @@ -19,19 +19,22 @@ #include #include -template -__device__ __forceinline__ void cp_async(half* smem_ptr, const half* global_ptr, bool pred_guard = true) +template +__device__ __forceinline__ void cp_async(half* smem_ptr, + const half* global_ptr, + bool pred_guard = true) { static_assert(SizeInBytes == 16, "Size is not supported"); unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); - asm volatile("{ \n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred_guard), - "r"(smem_int_ptr), - "l"(global_ptr), - "n"(SizeInBytes)); + asm volatile( + "{ \n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), + "l"(global_ptr), + "n"(SizeInBytes)); } /// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. @@ -41,7 +44,7 @@ __device__ __forceinline__ void cp_async_group_commit() } /// Blocks until all but previous cp.async.commit_group operations have committed. 
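On the `Reduction_Workspace` argument in the interface above: with Split_K > 1 each partition writes an FP32 partial [M, N] result, and a separate pass (kernel_reduction.cuh in this patch) folds the partials into the FP16 output C. The sketch below only illustrates the assumed math of that reduction, not how it is actually launched:

```python
import torch


def reduce_split_k(workspace: torch.Tensor) -> torch.Tensor:
    # workspace: FP32 partial results of shape [Split_K, M, N]. Assumed
    # semantics: sum over the split dimension, cast to the FP16 output dtype.
    return workspace.sum(dim=0).to(torch.float16)


partials = torch.randn(5, 10240, 32, dtype=torch.float)
print(reduce_split_k(partials).shape)  # torch.Size([10240, 32])
```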
-template +template __device__ __forceinline__ void cp_async_wait_group() { asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); @@ -51,9 +54,6 @@ __device__ __forceinline__ void cp_async_wait_group() // cp.async.wait_all is equivalent to : // cp.async.commit_group; // cp.async.wait_group 0; -__device__ __forceinline__ void cp_async_wait_all() -{ - asm volatile("cp.async.wait_all;\n" ::); -} +__device__ __forceinline__ void cp_async_wait_all() { asm volatile("cp.async.wait_all;\n" ::); } -#endif \ No newline at end of file +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh index 94b8e0c94d92..632be75959ed 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh @@ -22,37 +22,41 @@ #ifdef PIPELINE_LEVEL_SMEM template -__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], - half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - int slice_id) { - #ifdef DEBUG_MODE - static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); - #endif - - const int warpId = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; +__device__ __forceinline__ void B_FromSharedToReg( + uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + int slice_id) +{ +#ifdef DEBUG_MODE + static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0)); +#endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; - int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory - #ifdef DEBUG_MODE - assert( warp_start_col==0 ); - #endif + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * + WARP_j; // each warp may start from reading warp_start_col'th column of + // the B tile in shared memory +#ifdef DEBUG_MODE + assert(warp_start_col == 0); +#endif - int col = (lane_id%8) + (lane_id/16)*8; - int row = (lane_id%16) / 8 * 8; - uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][slice_id*MMA_16 + row])); - if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + int col = (lane_id % 8) + (lane_id / 16) * 8; + int row = (lane_id % 16) / 8 * 8; + uint32_t smem_local_ptr = static_cast( + __cvta_generic_to_shared(&read_SPTR[warp_start_col + col][slice_id * MMA_16 + row])); + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(Reg[0][0]), "=r"(Reg[0][1]) : "r"(smem_local_ptr)); - } - else { - #pragma unroll - for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) - { + } else { +#pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) { asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) : "r"(smem_local_ptr)); - smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + smem_local_ptr += 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); } } } @@ -60,54 +64,67 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[ // Debug: Whether 
ldmatrix.trans is required??? // B is in column-major template -__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], - half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - int k_offset) { - #ifdef DEBUG_MODE - static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); - #endif - - const int warpId = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; +__device__ __forceinline__ void B_FromSharedToReg( + uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + int k_offset) +{ +#ifdef DEBUG_MODE + static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0)); +#endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; - int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory - #ifdef DEBUG_MODE - assert( warp_start_col==0 ); - #endif + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * + WARP_j; // each warp may start from reading warp_start_col'th column of + // the B tile in shared memory +#ifdef DEBUG_MODE + assert(warp_start_col == 0); +#endif - int col = (lane_id%8) + (lane_id/16)*8; - int row = (lane_id%16) / 8 * 8; - uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][k_offset + row])); - if(TilingConfig::WARP_COL_MMA_TENSORS==1) { + int col = (lane_id % 8) + (lane_id / 16) * 8; + int row = (lane_id % 16) / 8 * 8; + uint32_t smem_local_ptr = static_cast( + __cvta_generic_to_shared(&read_SPTR[warp_start_col + col][k_offset + row])); + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(Reg[0][0]), "=r"(Reg[0][1]) : "r"(smem_local_ptr)); - } - else { - #pragma unroll - for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) - { + } else { +#pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) { asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) : "r"(smem_local_ptr)); - smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + smem_local_ptr += 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); } } } #endif -__device__ __forceinline__ void -MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b) +__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[], + uint32_t __restrict__* a, + uint32_t __restrict__* b) { - asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{ %0, %1, %2, %3}," - "{ %4, %5, %6, %7 }," - "{ %8, %9 }," - "{ %10, %11, %12, %13 };" - : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), - "r"(b[0]), "r"(b[1]), - "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), + "r"(a[1]), + "r"(a[2]), + "r"(a[3]), + "r"(b[0]), + "r"(b[1]), + "r"(c[0]), + "r"(c[1]), + "r"(c[2]), + "r"(c[3])); } -#endif \ No newline at end of file +#endif diff --git 
a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh index a82d097a2c8a..6635b897e64f 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh @@ -19,194 +19,233 @@ #include "PTX_mma.cuh" #include "Utils_ParallelDequant.cuh" - #ifdef PIPELINE_LEVEL_SMEM -template -__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) { - SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE); - int lane_id = threadIdx.x % WARP_SIZE; - #pragma unroll - for(int i=0; i +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], + uint32_t* SPTR, + int slice_id) +{ + SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE); + int lane_id = threadIdx.x % WARP_SIZE; +#pragma unroll + for (int i = 0; i < NUM_INT_PER_THREAD; i++) { Reg[i] = SPTR[lane_id + i * WARP_SIZE]; } } template -__device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], - uint32_t (*b)[4], - uint32_t* __restrict__ A1_SPTR_read, - uint32_t* __restrict__ A2_SPTR_read, - half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - uint32_t* RPTR_Scales) +__device__ __forceinline__ void initialize_mma_slice( + uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales) { // Writing registers - // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2]; // NO double buffer - uint32_t a_2[4]; // NO double buffer - CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, 0); - CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, 0); - Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time - B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2>(a_1, A1_SPTR_read, 0); + CopyFromSharedToRegister_AFrag<4>(a_2, A2_SPTR_read, 0); + Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at + // register level, dequantizing a slice each time + B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers } template -__device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], - uint32_t (*a)[4], - uint32_t (*b)[4], - uint32_t* __restrict__ A1_SPTR_read, - uint32_t* __restrict__ A2_SPTR_read, - half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - uint32_t* RPTR_Scales, - int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching +__device__ __forceinline__ void core_mma_slice( + float c[][REG_PER_THREAD_C_TENSOR_16_16], + uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales, + int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching { - #ifdef DEBUG_MODE - 
assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block - #endif - const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block - const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block - uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); // Reigsters for accumulated FP32 results +#ifdef DEBUG_MODE + assert( + (TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == + 0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block +#endif + const int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block + uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = + reinterpret_cast( + c); // Reigsters for accumulated FP32 results // Setting RPTRs for double buffers - uint32_t (*a_read )[4] = a; - uint32_t (*a_write)[4] = a; - uint32_t (*b_read )[4] = b; - uint32_t (*b_write)[4] = b; - if(slice_id%2==1) { b_write += NumRegSets_b; a_write += NumRegSets_a;} - else { b_read += NumRegSets_b; a_read += NumRegSets_a;} + uint32_t(*a_read)[4] = a; + uint32_t(*a_write)[4] = a; + uint32_t(*b_read)[4] = b; + uint32_t(*b_write)[4] = b; + if (slice_id % 2 == 1) { + b_write += NumRegSets_b; + a_write += NumRegSets_a; + } else { + b_read += NumRegSets_b; + a_read += NumRegSets_a; + } - // Reading registers and issuing core tensor core computations (a slice of A and B tile in shared memory) - #pragma unroll +// Reading registers and issuing core tensor core computations (a slice of A and B tile in shared +// memory) +#pragma unroll for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { - if(TilingConfig::WARP_COL_MMA_TENSORS==1) { - MMA_FP16_M16N8K16( c_uint_ptr[i], a_read[i], b_read[0] ); - } - else { - #pragma unroll - for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j] ); - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a_read[i], b_read[j] + 2 ); // c+4; b+2 + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]); + } else { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) { + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j]); + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, + a_read[i], + b_read[j] + 2); // c+4; b+2 } } } // Writing registers - // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2]; // NO double buffer - uint32_t a_2[4]; // NO double buffer - CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, slice_id); - CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, slice_id); - Dequant_32FP6_4Way(a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time - B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per 
thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2>(a_1, A1_SPTR_read, slice_id); + CopyFromSharedToRegister_AFrag<4>(a_2, A2_SPTR_read, slice_id); + Dequant_32FP6_4Way( + a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register + // level, dequantizing a slice each time + B_FromSharedToReg( + b_write, B_SPTR_read, slice_id); // Loading B from shared to registers } #else // Old version with naive pipeline design -template -__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) { - int lane_id = threadIdx.x % WARP_SIZE; - #pragma unroll - for(int i=0; i +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) +{ + int lane_id = threadIdx.x % WARP_SIZE; +#pragma unroll + for (int i = 0; i < NUM_INT_PER_THREAD; i++) { Reg[i] = SPTR[lane_id + i * WARP_SIZE]; } } template -__device__ __forceinline__ void PipelinedCoreLoop(float c[][REG_PER_THREAD_C_TENSOR_16_16], - half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - uint32_t* __restrict__ read_SPTR_Frag1, - uint32_t* __restrict__ read_SPTR_Frag2, - uint32_t* RPTR_Scales) +__device__ __forceinline__ void PipelinedCoreLoop( + float c[][REG_PER_THREAD_C_TENSOR_16_16], + half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* __restrict__ read_SPTR_Frag1, + uint32_t* __restrict__ read_SPTR_Frag2, + uint32_t* RPTR_Scales) { - #ifdef DEBUG_MODE - assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block - #endif - const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block - const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block +#ifdef DEBUG_MODE + assert( + (TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == + 0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block +#endif + const int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 
1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block // Reigsters to store FP32 results - uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); - // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2*2]; // double buffer is used - uint32_t a_2[4*2]; // double buffer is used + uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = + reinterpret_cast(c); + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per thread; + uint32_t a_1[2 * 2]; // double buffer is used + uint32_t a_2[4 * 2]; // double buffer is used // Registers to store decompressed FP6 - uint32_t a [NumRegSets_a * 1][4]; // No double buffer + uint32_t a[NumRegSets_a * 1][4]; // No double buffer // Register to store FP16 B matrix (a slice) - uint32_t b [NumRegSets_b * 2][4]; // double buffer is used + uint32_t b[NumRegSets_b * 2][4]; // double buffer is used // Overlapped Smem and TC pipeline: pre-loading from shared to registers - CopyFromSharedToRegister_AFrag<2> (a_1, read_SPTR_Frag1); - CopyFromSharedToRegister_AFrag<4> (a_2, read_SPTR_Frag2); - B_FromSharedToReg (b, read_SPTR, 0); + CopyFromSharedToRegister_AFrag<2>(a_1, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4>(a_2, read_SPTR_Frag2); + B_FromSharedToReg(b, read_SPTR, 0); - #pragma unroll +#pragma unroll for (int k = 0; k < WARP_K_MMA_TENSORS; k++) { - uint32_t (*b_read)[4] = b; - uint32_t (*b_write)[4] = b; - uint32_t *a_1_read = a_1; - uint32_t *a_1_write = a_1; - uint32_t *a_2_read = a_2; - uint32_t *a_2_write = a_2; - if(k%2==0) { - b_write += NumRegSets_b; - a_1_write += 2; - a_2_write += 4; - } - else { - b_read += NumRegSets_b; - a_1_read += 2; - a_2_read += 4; + uint32_t(*b_read)[4] = b; + uint32_t(*b_write)[4] = b; + uint32_t* a_1_read = a_1; + uint32_t* a_1_write = a_1; + uint32_t* a_2_read = a_2; + uint32_t* a_2_write = a_2; + if (k % 2 == 0) { + b_write += NumRegSets_b; + a_1_write += 2; + a_2_write += 4; + } else { + b_read += NumRegSets_b; + a_1_read += 2; + a_2_read += 4; } // data loading if (k + 1 < WARP_K_MMA_TENSORS) { // updating SPTR for fragment1 and fragment2 - read_SPTR_Frag1 += 2*WARP_SIZE; - read_SPTR_Frag2 += 4*WARP_SIZE; + read_SPTR_Frag1 += 2 * WARP_SIZE; + read_SPTR_Frag2 += 4 * WARP_SIZE; CopyFromSharedToRegister_AFrag<2>(a_1_write, read_SPTR_Frag1); CopyFromSharedToRegister_AFrag<4>(a_2_write, read_SPTR_Frag2); - B_FromSharedToReg(b_write, read_SPTR, (k+1)*MMA_16); + B_FromSharedToReg(b_write, read_SPTR, (k + 1) * MMA_16); } // SIMT Dequant + Tensor Core computations - Dequant_32FP6_4Way(a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, dequantizing a slice each time - #pragma unroll + Dequant_32FP6_4Way( + a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, + // dequantizing a slice each time +#pragma unroll for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { - if(TilingConfig::WARP_COL_MMA_TENSORS==1) - MMA_FP16_M16N8K16( c_uint_ptr[i], a[i], b_read[0] ); - else { - #pragma unroll - for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j] ); - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a[i], b_read[j] + 2 ); // c+4; b+2 + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) + MMA_FP16_M16N8K16(c_uint_ptr[i], a[i], b_read[0]); + else { +#pragma unroll + for (int j = 
0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) { + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j]); + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, + a[i], + b_read[j] + 2); // c+4; b+2 } } } } } -#endif // #ifdef PIPELINE_LEVEL_SMEM +#endif // #ifdef PIPELINE_LEVEL_SMEM template -__device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], - float c[][REG_PER_THREAD_C_TENSOR_16_16]) +__device__ __forceinline__ void StoreToSharedMemoryFromRegister( + float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], + float c[][REG_PER_THREAD_C_TENSOR_16_16]) { - const int lane_id = threadIdx.x % WARP_SIZE; - const int warpId = threadIdx.x / WARP_SIZE; - int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); - #pragma unroll + const int lane_id = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); +#pragma unroll for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { - #pragma unroll - for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; j++) { // Dealing with one 16*8 Tensor - int RegSetID = i + (j/2)*WARP_ROW_MMA_TENSORS; - int RegOffset = (j%2)*(REG_PER_THREAD_C_TENSOR_16_16/2); - int Tensor_row_offset = warp_row_offset + i * MMA_16; - int Tensor_col_offset = j * MMA_8; - #pragma unroll - for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16/2; r++) { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; + j++) { // Dealing with one 16*8 Tensor + int RegSetID = i + (j / 2) * WARP_ROW_MMA_TENSORS; + int RegOffset = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2); + int Tensor_row_offset = warp_row_offset + i * MMA_16; + int Tensor_col_offset = j * MMA_8; +#pragma unroll + for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) { int row_offset = lane_id / 4; if (r >= 2) row_offset += 8; int col_offset = (lane_id % 4) * 2; - if (r%2==1) col_offset += 1; - smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = c[RegSetID][r + RegOffset]; + if (r % 2 == 1) col_offset += 1; + smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = + c[RegSetID][r + RegOffset]; } } } } -#endif \ No newline at end of file +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh index fce0146b3fa3..b1c023dae24d 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh @@ -5,71 +5,75 @@ #include "Configs.h" #include "PTX_cp.async.cuh" -/* +/* * Copying A1/A2 from global memory to shared memory. 
* Usually 1024 or 2048 Bytes */ -template -__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, - const uint4* GPTR, - bool pred_guard = true) { - #ifdef DEBUG_MODE - static_assert(SMEM_SIZE_IN_BYTES_PER_WARP/WARP_SIZE % 16 == 0); - #endif - int lane_id = threadIdx.x % WARP_SIZE; +template +__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, + const uint4* GPTR, + bool pred_guard = true) +{ +#ifdef DEBUG_MODE + static_assert(SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE % 16 == 0); +#endif + int lane_id = threadIdx.x % WARP_SIZE; half* SPTR_HALF = reinterpret_cast(SPTR); const half* GPTR_HALF = reinterpret_cast(GPTR); - SPTR_HALF += lane_id*8; - GPTR_HALF += lane_id*8; - #pragma unroll - for(int i=0; i( SPTR_HALF, GPTR_HALF, pred_guard); - SPTR_HALF += 256; // Forward 512 Bytes - GPTR_HALF += 256; // Forward 512 Bytes + SPTR_HALF += lane_id * 8; + GPTR_HALF += lane_id * 8; +#pragma unroll + for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE / 16; i++) { + cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard); + SPTR_HALF += 256; // Forward 512 Bytes + GPTR_HALF += 256; // Forward 512 Bytes } - } -/* +/* * Copying 64 Quant Scales (FP16) from global memory to shared memory. */ __device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, - const half* GPTR_A_Scales) { - int lane_id = threadIdx.x % WARP_SIZE; - int Offset_Shared = lane_id*2; - int Offset_Global = lane_id/4 + (lane_id%4)*16; - for(int i=0; i<2; i++) SPTR_QuantScales[Offset_Shared+i] = GPTR_A_Scales[Offset_Global+i*8]; + const half* GPTR_A_Scales) +{ + int lane_id = threadIdx.x % WARP_SIZE; + int Offset_Shared = lane_id * 2; + int Offset_Global = lane_id / 4 + (lane_id % 4) * 16; + for (int i = 0; i < 2; i++) + SPTR_QuantScales[Offset_Shared + i] = GPTR_A_Scales[Offset_Global + i * 8]; } -/* +/* * (1) Copying X rows * 64 columns of FP16 values, originally in row major * (2) Copying 64 rows * X columns of FP16 values, originally in column major * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads */ -template -__device__ __forceinline__ void CopyFromGlobalToShared(half __restrict__ (*SharedPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - const half* GlobalPTR, - const int GlobalStride, - const int NumOfLinesLeft, // To support arbitrary N dimensions. - bool Pred = true) { +template +__device__ __forceinline__ void CopyFromGlobalToShared( + half __restrict__ (*SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, + const int GlobalStride, + const int NumOfLinesLeft, // To support arbitrary N dimensions. 
+ bool Pred = true) +{ // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time - const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; - const int NumOfGroups = NumOfThreads / 8; - const int MaxIteration = (MaxNumOfLinesToCopy-1) / NumOfGroups + 1; - // runtime variables - const int line_id = threadIdx.x / 8; - const int line_offset = (threadIdx.x%8) * 8; + const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; + const int NumOfGroups = NumOfThreads / 8; + const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1; + // runtime variables + const int line_id = threadIdx.x / 8; + const int line_offset = (threadIdx.x % 8) * 8; // PTR for source global memory and target shared memory GlobalPTR += line_id * GlobalStride + line_offset; SharedPTR += line_id; - #pragma unroll +#pragma unroll for (int i = 0; i < MaxIteration; i++) { - bool AsyncCopyPred = (line_id+i*NumOfGroups) < NumOfLinesLeft && Pred; - cp_async<16>( &(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + bool AsyncCopyPred = (line_id + i * NumOfGroups) < NumOfLinesLeft && Pred; + cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); // GlobalPTR += NumOfGroups * GlobalStride; SharedPTR += NumOfGroups; } } -#endif \ No newline at end of file +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh index 9627741ce317..cfc25dea9180 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh @@ -10,7 +10,8 @@ * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. */ -__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) { +__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2) +{ *R2 = *R1 & 0x80808080; *R1 = *R1 >> 2; *R1 = *R1 & 0x1f1f1f1f; @@ -25,7 +26,8 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t *R1, u_int32_t *R2) * Outputs: R1, R2 * Note: Simplified Exponent calculation is NOT applied. 
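A note on the simplified-exponent cast above: once the 2-bit and 4-bit fragments are re-joined (see Dequant_32FP6_4Way later in this file), each byte holds one FP6 value in its top six bits; the cast drops those bits into an FP16 pattern, and the MultScale helper that follows multiplies by 4096, which matches the exponent-bias gap between a bias-3 FP6 format and FP16 (2^(15-3) = 4096). A small Python sketch of that decode path for a single value, under those assumptions (helper names are illustrative):

```python
import struct


def fp6_e3m2_to_float(code: int, scale: float = 1.0) -> float:
    # Decode one 6-bit value (bit layout: s eee mm), assuming an exponent bias
    # of 3. The five payload bits land in bits 12..8 of an FP16 pattern
    # (exponent field 00eee, top mantissa bits mm); the bias difference is then
    # fixed up by the same 4096 factor MultScale applies.
    sign = (code >> 5) & 0x1
    payload = code & 0x1F
    fp16_bits = (sign << 15) | (payload << 8)
    (value,) = struct.unpack("<e", struct.pack("<H", fp16_bits))
    return value * 4096.0 * scale  # 4096 == 2 ** (15 - 3)


def join_slices(two_bit: int, four_bit: int) -> int:
    # Re-join the per-value bit-slices the way the 0xc0c0c0c0 / 0xf0f0f0f0
    # masks do: top two bits from one fragment, bottom four from the other.
    return ((two_bit & 0x3) << 4) | (four_bit & 0xF)


assert fp6_e3m2_to_float(0b001100) == 1.0   # +2**0 * 1.00
assert fp6_e3m2_to_float(0b110010) == -3.0  # -2**1 * 1.50
```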
*/ -__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_t *R2) { +__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_t* R2) +{ //*R2 = *R1 & 0x80808080; *R2 = *R1 & 0xc0c0c0c0; *R1 = *R1 >> 2; @@ -36,76 +38,83 @@ __device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t *R1, u_int32_ //*R1 = *R2 & 0x9f009f00; //*R2 = *R2 & 0x009f009f; *R1 = *R2 & 0xcf00cf00; - if( !(*R1 & 0x40000000) && (*R1 & 0x0c000000) ) *R1 = *R1 | 0x30000000; - if( !(*R1 & 0x00004000) && (*R1 & 0x00000c00) ) *R1 = *R1 | 0x00003000; + if (!(*R1 & 0x40000000) && (*R1 & 0x0c000000)) *R1 = *R1 | 0x30000000; + if (!(*R1 & 0x00004000) && (*R1 & 0x00000c00)) *R1 = *R1 | 0x00003000; *R2 = *R2 & 0x00cf00cf; - if( !(*R2 & 0x00400000) && (*R2 & 0x000c0000) ) *R2 = *R2 | 0x00300000; - if( !(*R2 & 0x00000040) && (*R2 & 0x0000000c) ) *R2 = *R2 | 0x00000030; + if (!(*R2 & 0x00400000) && (*R2 & 0x000c0000)) *R2 = *R2 | 0x00300000; + if (!(*R2 & 0x00000040) && (*R2 & 0x0000000c)) *R2 = *R2 | 0x00000030; // *R2 = *R2 << 8; //*R1 = 0x3c003c00; //*R2 = 0x3c003c00; } -__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) { +__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) +{ half* FP16_1 = reinterpret_cast(&PackedFP16Pair); half* FP16_2 = FP16_1 + 1; uint32_t output; half* output_half_ptr = reinterpret_cast(&output); - output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(4096.0f)), Scale); - output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(4096.0f)), Scale); + output_half_ptr[0] = __hmul(__hmul(*FP16_1, __float2half(4096.0f)), Scale); + output_half_ptr[1] = __hmul(__hmul(*FP16_2, __float2half(4096.0f)), Scale); return output; } -__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], - u_int32_t __restrict__ *read_RPTR_Frag1, - u_int32_t __restrict__ *read_RPTR_Frag2, - u_int32_t *Scales) { - u_int32_t *OutputRegs = reinterpret_cast (Reg); - u_int32_t *Frag1_PTR = read_RPTR_Frag1; - u_int32_t *Frag2_PTR = read_RPTR_Frag2; - half *Scale_RPTR = reinterpret_cast(Scales); - u_int32_t Packed_FP6 = 0; - u_int32_t tmp = 0; - // Dequantizing 32 FP6, each Loop dequantizing 4 FP6 - #pragma unroll(8) - for(int i=0; i<8; i++) { +__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], + u_int32_t __restrict__* read_RPTR_Frag1, + u_int32_t __restrict__* read_RPTR_Frag2, + u_int32_t* Scales) +{ + u_int32_t* OutputRegs = reinterpret_cast(Reg); + u_int32_t* Frag1_PTR = read_RPTR_Frag1; + u_int32_t* Frag2_PTR = read_RPTR_Frag2; + half* Scale_RPTR = reinterpret_cast(Scales); + u_int32_t Packed_FP6 = 0; + u_int32_t tmp = 0; +// Dequantizing 32 FP6, each Loop dequantizing 4 FP6 +#pragma unroll(8) + for (int i = 0; i < 8; i++) { // Frag1 Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0; - if(i%4==3) Frag1_PTR++; - else (*Frag1_PTR) = (*Frag1_PTR) << 2; + if (i % 4 == 3) + Frag1_PTR++; + else + (*Frag1_PTR) = (*Frag1_PTR) << 2; // Frag2 tmp = (*Frag2_PTR) & 0xf0f0f0f0; tmp = tmp >> 2; - if(i%2==1) Frag2_PTR++; - else (*Frag2_PTR) = (*Frag2_PTR) << 4; + if (i % 2 == 1) + Frag2_PTR++; + else + (*Frag2_PTR) = (*Frag2_PTR) << 4; // Packed_FP6 Packed_FP6 = Packed_FP6 | tmp; // FP6_FP16_Cast_4Way(&Packed_FP6, &tmp); // - *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0] ); // Muliply FP16 scales + *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0]); // Muliply FP16 scales OutputRegs += 1; - *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales + *OutputRegs = 
MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales OutputRegs += 1; // Updating offset for FP16 scales for every two iterations - if(i%2==1) Scale_RPTR += 2; + if (i % 2 == 1) Scale_RPTR += 2; } - } /* - * + * */ -__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, half* WARP_SPTR_Scales) { +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, + half* WARP_SPTR_Scales) +{ int lane_id = threadIdx.x % WARP_SIZE; uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); uint32_t tmpReg = SPTR_uint[lane_id]; - #pragma unroll - for(int i=0; i<4; i++) { - // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); - Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); +#pragma unroll + for (int i = 0; i < 4; i++) { + // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); } } -#endif \ No newline at end of file +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh index 7f120fdcb303..342beaf518e9 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh @@ -11,170 +11,256 @@ * limitations under the License. ***************************************************************************/ #include "Configs.h" -#include "Utils_GMem.cuh" #include "Utils_Core.cuh" +#include "Utils_GMem.cuh" /* * C = A*B * A: row major with ahead-of-time layout transformation, FP6 * B: col major, FP16 * C: col major, FP16 - */ - template -__global__ void QUANT_GEMM_Kernel(const uint4* Weight1, const uint4* Weight2, const half* Scales, - const half *B, + */ +template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, OutputDataType* C, - const size_t M_Global, const size_t N_Global, const size_t K_Global, - int Split_K) + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) { - #ifdef DEBUG_MODE - assert(K_Global%TilingConfig::TILE_K==0); - assert(M_Global%TilingConfig::TILE_M==0); - assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M)); - #endif - extern __shared__ __align__(128) half smem[]; // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned - half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + (SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE)/2 ); // Dynamic shared memory for FP16 B tiles - __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes - // Thread Block Mapping, considering SplitK - const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); - const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) - const size_t y = blockIdx.y % (M_Global/TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) - const size_t Tile_Start_M = y * TilingConfig::TILE_M; - const size_t Tile_Start_N = x * TilingConfig::TILE_N; - const size_t NumColumnToCopy = (N_Global-Tile_Start_N) < TilingConfig::TILE_N ? 
(N_Global-Tile_Start_N) : TilingConfig::TILE_N; - const size_t NumBlock_K = K_Global/TilingConfig::TILE_K; - const size_t AverageNumBlock_K = NumBlock_K/Split_K; - const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; - size_t NumIter = AverageNumBlock_K; - if(BatchID(smem); - uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR+SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers - // StartSPTR for each WARP - AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4; - AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4; - // Pre-fetch of A tile - for(int i=0; i(AFrag_2BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4, WARP_StartGPTR_A1); - CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4, WARP_StartGPTR_A2); - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; - } - // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// - const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; - const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; - CopyFromGlobalToShared_Scales(QuantScales+WARP_i*64, WARP_StartGPTR_A_Scales); - // Copying B tile from Global to Shared, considering SplitK ///////////////////////////////////////////////////////////// - const half *BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; - for(int i=0; i (smem_array+i*TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); - BTile_GPTR += TilingConfig::TILE_K; - } - // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// - constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block - constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block +#ifdef DEBUG_MODE + assert(K_Global % TilingConfig::TILE_K == 0); + assert(M_Global % TilingConfig::TILE_M == 0); + assert(gridDim.y == Split_K * (M_Global / TilingConfig::TILE_M)); +#endif + extern __shared__ __align__(128) + half smem[]; // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned + half(*smem_array)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + reinterpret_cast( + smem + + (SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE) / 2); // Dynamic shared memory for FP16 B tiles + __shared__ half QuantScales[64 * TilingConfig::BLOCK_WARPS]; // static shared memory for + // quantization scales, 64 row per + // warp * 4 warps = 512 Bytes + // Thread Block Mapping, considering SplitK + const size_t BatchID = blockIdx.y / (M_Global / TilingConfig::TILE_M); + const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t y = + blockIdx.y % + (M_Global / TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t Tile_Start_M = y * TilingConfig::TILE_M; + const size_t Tile_Start_N = x * TilingConfig::TILE_N; + const size_t NumColumnToCopy = (N_Global - Tile_Start_N) < TilingConfig::TILE_N + ? 
(N_Global - Tile_Start_N) + : TilingConfig::TILE_N; + const size_t NumBlock_K = K_Global / TilingConfig::TILE_K; + const size_t AverageNumBlock_K = NumBlock_K / Split_K; + const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; + size_t NumIter = AverageNumBlock_K; + if (BatchID < ExtraNumBlock_K) NumIter++; + size_t StartBlockID_K = AverageNumBlock_K * BatchID; + if (BatchID < ExtraNumBlock_K) + StartBlockID_K += BatchID; + else + StartBlockID_K += ExtraNumBlock_K; + // Warp ID. + const int warpId = threadIdx.x / WARP_SIZE; + int WARP_i = + warpId / TilingConfig::BLOCK_COL_WARPS; // WARP_i: row number; WARP_j: column number + // int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + // Global Memory Address for Matrix A (Weight) + // ///////////////////////////////////////////////////////////////////////// StartPTR for each + // ThreadBlock(TB) + const uint4* TB_StartGPTR_A1 = + Weight1 + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + const uint4* TB_StartGPTR_A2 = + Weight2 + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // StartPTR for each WARP. + const uint4* WARP_StartGPTR_A1 = + TB_StartGPTR_A1 + WARP_i * NumBlock_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + const uint4* WARP_StartGPTR_A2 = + TB_StartGPTR_A2 + WARP_i * NumBlock_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // StartPTR for each WARP, considering SplitK + const size_t WARP_Start_UnitID_K = StartBlockID_K; + WARP_StartGPTR_A1 += WARP_Start_UnitID_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + WARP_StartGPTR_A2 += WARP_Start_UnitID_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // Copying A tile from Global to Shared, using double-buffer + // ////////////////////////////////////////////////////////// StartSPTR for each ThreadBlock + uint32_t* AFrag_2BIT_SPTR = reinterpret_cast(smem); + uint32_t* AFrag_4BIT_SPTR = + AFrag_2BIT_SPTR + + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * TilingConfig::BLOCK_WARPS * + PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers + // StartSPTR for each WARP + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4; + // Pre-fetch of A tile + for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) { + CopyFromGlobalToShared_A( + AFrag_2BIT_SPTR + i * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * 4, WARP_StartGPTR_A1); + CopyFromGlobalToShared_A( + AFrag_4BIT_SPTR + i * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * 4, WARP_StartGPTR_A2); + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; + } + // Global Memory Address for Matrix A (QuantScale) + // ///////////////////////////////////////////////////////////////////// + const half* TB_StartGPTR_A_Scale = Scales + (y * TilingConfig::BLOCK_ROW_WARPS) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales + WARP_i * 64, WARP_StartGPTR_A_Scales); + // Copying B tile from Global to Shared, considering SplitK + // ///////////////////////////////////////////////////////////// + const half* BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) { + CopyFromGlobalToShared( + smem_array + i * TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); + BTile_GPTR += TilingConfig::TILE_K; + } + // Register Allocation for A,B, and C, Initilazed to Zeros + // 
///////////////////////////////////////////////////////////////////// + constexpr int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block #ifdef PIPELINE_LEVEL_SMEM - uint32_t a [NumRegSets_a * PIPELINE_LEVEL_SMEM][4]; // double/Trible buffer is used // Registers to store decompressed FP6 - uint32_t b [NumRegSets_b * PIPELINE_LEVEL_SMEM][4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) + uint32_t a[NumRegSets_a * PIPELINE_LEVEL_SMEM] + [4]; // double/Trible buffer is used // Registers to store decompressed FP6 + uint32_t b[NumRegSets_b * PIPELINE_LEVEL_SMEM] + [4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) #endif - float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; - for(int i=0; i(a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); + // Initializing the Software Pipeline: writing registers. + // //////////////////////////////////////////////////////////////////////////////////////////////// + initialize_mma_slice( + a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); #endif - // The outer loop. ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - #pragma unroll(1) - for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) - { - // Trible-Buffer for A Tile - uint32_t* __restrict__ read_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ read_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 +// The outer loop. 
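For reference, the NumIter / StartBlockID_K computation earlier in this hunk hands out the K-dimension blocks so that the first NumBlock_K % Split_K partitions each take one extra block. An illustrative translation of just that indexing (the function name is not from the kernel):

```python
def k_blocks_for_split(num_block_k: int, split_k: int, batch_id: int):
    # Mirrors NumIter and StartBlockID_K in QUANT_GEMM_Kernel above.
    average = num_block_k // split_k
    extra = num_block_k - average * split_k
    num_iter = average + (1 if batch_id < extra else 0)
    start = average * batch_id + (batch_id if batch_id < extra else extra)
    return start, num_iter


# 10 K-blocks over Split_K = 4 -> [(0, 3), (3, 3), (6, 2), (8, 2)]
print([k_blocks_for_split(10, 4, b) for b in range(4)])
```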
+// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma unroll(1) + for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) { + // Trible-Buffer for A Tile + uint32_t* __restrict__ read_SPTR_Frag1 = + AFrag_2BIT_SPTR + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * + 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag2 = + AFrag_4BIT_SPTR + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * + 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 #ifdef PIPELINE_LEVEL_SMEM - uint32_t* __restrict__ read2_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; - uint32_t* __restrict__ read2_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; + uint32_t* __restrict__ read2_SPTR_Frag1 = + AFrag_2BIT_SPTR + + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * 4; + uint32_t* __restrict__ read2_SPTR_Frag2 = + AFrag_4BIT_SPTR + + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * 4; #endif - uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - // Trible-Buffer for B Tile - half __restrict__ (*read_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + uint32_t* __restrict__ write_SPTR_Frag1 = + AFrag_2BIT_SPTR + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * + 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag2 = + AFrag_4BIT_SPTR + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * + 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + // Trible-Buffer for B Tile + half __restrict__(*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; #ifdef PIPELINE_LEVEL_SMEM - half __restrict__ (*read2_SPTR )[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + half __restrict__(*read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#endif + half __restrict__(*write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + // + bool GlobalCopy = (tile_id_k + PIPELINE_LEVEL_GMEM - 1) < NumIter; + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + CopyFromGlobalToShared_A( + write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); + CopyFromGlobalToShared_A( + write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); + // copying B tile from GlobalMemory to SharedMemory + CopyFromGlobalToShared( + write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + cp_async_group_commit(); +#ifdef PIPELINE_LEVEL_SMEM + core_mma_slice(c, + a, + b, + read_SPTR_Frag1, + 
read_SPTR_Frag2, + read_SPTR, + Scales_RPTR, + 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each + // WARP; read_SPTR is shared among WARPs + core_mma_slice( + c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); + core_mma_slice( + c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + core_mma_slice( + c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 0); + // Updating global PTRs + WARP_StartGPTR_A1 += + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; +#else + PipelinedCoreLoop( + c, + read_SPTR, + read_SPTR_Frag1, + read_SPTR_Frag2, + Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; + // read_SPTR is shared among WARPs + // Updating global PTRs + WARP_StartGPTR_A1 += + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); #endif - half __restrict__ (*write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; - // - bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; - // Copying A tile from Global to Register, Bypassing L1, using double-buffer - CopyFromGlobalToShared_A(write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); - CopyFromGlobalToShared_A(write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); - // copying B tile from GlobalMemory to SharedMemory - CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); - cp_async_group_commit(); - #ifdef PIPELINE_LEVEL_SMEM - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); - // Barriers and Synchronizations - cp_async_wait_group(); - __syncthreads(); - core_mma_slice(c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 0); - // Updating global PTRs - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 - BTile_GPTR += TilingConfig::TILE_K; - #else - PipelinedCoreLoop(c, read_SPTR, read_SPTR_Frag1, read_SPTR_Frag2, Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs - // Updating global PTRs - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 - BTile_GPTR += TilingConfig::TILE_K; - // Barriers and Synchronizations - cp_async_wait_group(); - __syncthreads(); - #endif - } - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - 
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Store the C fragments to shared memory. - float (*smem_CFrag) [TilingConfig::TILE_M+PADDING_SHARED_MEM_FOR_C_4] = - reinterpret_cast (smem); - StoreToSharedMemoryFromRegister(smem_CFrag, c); - __syncthreads(); - // Now that shared memory contains all the D tiles, stream them to global memory. - OutputDataType* BlockGlobalPTR = C + BatchID*(M_Global*N_Global) + Tile_Start_M + Tile_Start_N*M_Global; - for(size_t i=warpId; i::value) BlockGlobalPTR[j+i*M_Global] = __float2half_rn(smem_CFrag[i][j]); - else BlockGlobalPTR[j+i*M_Global] = smem_CFrag[i][j]; } -} \ No newline at end of file + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store the C fragments to shared memory. + float(*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4] = + reinterpret_cast(smem); + StoreToSharedMemoryFromRegister(smem_CFrag, c); + __syncthreads(); + // Now that shared memory contains all the D tiles, stream them to global memory. + OutputDataType* BlockGlobalPTR = + C + BatchID * (M_Global * N_Global) + Tile_Start_M + Tile_Start_N * M_Global; + for (size_t i = warpId; i < NumColumnToCopy; i += TilingConfig::BLOCK_WARPS) // i-th column +#pragma unroll + for (size_t j = threadIdx.x % WARP_SIZE; j < TilingConfig::TILE_M; + j += WARP_SIZE) // j-th row + { + if constexpr (std::is_same::value) + BlockGlobalPTR[j + i * M_Global] = __float2half_rn(smem_CFrag[i][j]); + else + BlockGlobalPTR[j + i * M_Global] = smem_CFrag[i][j]; + } +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh index 442de103b8d1..c1d5915abecb 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh @@ -14,34 +14,39 @@ // Reduction_Workspace: (Split_K, M_Global, N_Global), column major // C: (M_Global, N_Global), column major // Each thread deals with 8 output elements, each elements is the sum of Split_K elements -// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> 256 float per warp -// Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 half_per_thread (128bit) -> 256 half per warp +// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> +// 256 float per warp Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 +// half_per_thread (128bit) -> 256 half per warp // GridSize = (M_Global*N_Global) / 256 #include #include #include -#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 -#define HALF_PER_128BIT 8 +#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 +#define HALF_PER_128BIT 8 -__global__ void SplitK_Reduction(half* C, float* Reduction_Workspace, size_t M_Global, size_t N_Global, int Split_K) +__global__ void SplitK_Reduction(half* C, + float* Reduction_Workspace, + size_t M_Global, + size_t N_Global, + int Split_K) { - half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; - half* THREAD_GPTR_C = 
WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; - float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; // Initializing Thread-Local Results float Results[HALF_PER_128BIT]; - #pragma unroll - for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; +#pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; // Reduction for (int i = 0; i < Split_K; i++) { - #pragma unroll - for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; +#pragma unroll + for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; THREAD_GPTR_R += M_Global * N_Global; } - // Writing to global memory - #pragma unroll - for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); +// Writing to global memory +#pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h index 76ff5bbb6b8b..3f84a486070a 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h @@ -194,4 +194,4 @@ void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K) BitInterleaving_2bit(Weight_2bit + 4 * i); for (size_t i = 0; i < BytesPerThread_4bit * 32 / 4; i++) BitInterleaving_4bit(Weight_4bit + 4 * i); -} \ No newline at end of file +} diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py index 5e8618f22c22..ebdb59bca920 100644 --- a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py +++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py @@ -138,8 +138,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) try: _ = layer_container.is_populated except ValueError as e: - raise ValueError( - f"Layer container {l_name} is not populated.") from e + raise ValueError(f"Layer container {l_name} is not populated.") from e layer_metadata = LayerMetadata() @@ -157,8 +156,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=param.stride(), offset=cur_offset) - cur_offset += pad_to_aligned_offset( - elem_size(param.dtype) * param.numel()) + cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) for t_name, tensor in param.aux_attrs.items(): param_metadata.aux_params[t_name] = TensorMetadata(dtype=str(tensor.dtype), @@ -166,8 +164,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=tensor.stride(), offset=cur_offset) - cur_offset += pad_to_aligned_offset( - elem_size(tensor.dtype) * tensor.numel()) + cur_offset += pad_to_aligned_offset(elem_size(tensor.dtype) * tensor.numel()) layer_metadata.params[p_name] = param_metadata @@ -181,8 +178,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) l_name = "non_transformer" total_size = process_layer(non_transformer_container, l_name, total_size) - buffer = torch.empty(total_size, 
dtype=torch.uint8, - device=get_accelerator().current_device()) + buffer = torch.empty(total_size, dtype=torch.uint8, device=get_accelerator().current_device()) def copy_layer(layer_container: LayerContainer, l_name: str) -> None: """ @@ -210,13 +206,11 @@ def copy_layer(layer_container: LayerContainer, l_name: str) -> None: aux_params = {} for t_name, tensor in param.aux_attrs.items(): - t_view = alloc_fn( - tensor, buffer, p_metadata.aux_params[t_name].offset) + t_view = alloc_fn(tensor, buffer, p_metadata.aux_params[t_name].offset) aux_params[t_name] = t_view t_view.copy_(tensor) - setattr(layer_container, p_name, - InferenceParameter.initialize(core_param, **aux_params)) + setattr(layer_container, p_name, InferenceParameter.initialize(core_param, **aux_params)) for i, layer in enumerate(transformer_containers): l_name = f"transformer_layer_{i}" @@ -265,23 +259,19 @@ def restore_layer(layer_container: LayerContainer, l_name: str) -> None: layer_container.direct_injection(p_name, None) continue - dummy_tensor = torch.empty( - [], dtype=STR_TO_DTYPE[p_metadata.core_param.dtype]) + dummy_tensor = torch.empty([], dtype=STR_TO_DTYPE[p_metadata.core_param.dtype]) core_param = alloc_fn(p_metadata.core_param.shape, p_metadata.core_param.strides, dummy_tensor, buffer, p_metadata.core_param.offset) aux_params = {} for t_name, t_metadata in p_metadata.aux_params.items(): - dummy_tensor = torch.empty( - [], dtype=STR_TO_DTYPE[t_metadata.dtype]) - t_view = alloc_fn(t_metadata.shape, t_metadata.strides, - dummy_tensor, buffer, t_metadata.offset) + dummy_tensor = torch.empty([], dtype=STR_TO_DTYPE[t_metadata.dtype]) + t_view = alloc_fn(t_metadata.shape, t_metadata.strides, dummy_tensor, buffer, t_metadata.offset) aux_params[t_name] = t_view - restored_param = InferenceParameter.initialize( - core_param, **aux_params) + restored_param = InferenceParameter.initialize(core_param, **aux_params) layer_container.direct_injection(p_name, restored_param) for i, layer in enumerate(transformer_containers): diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index 1ddf34d3920a..d176206f3c60 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -49,8 +49,7 @@ def instantiate_attention(attention_config: DSSelfAttentionConfig, """ # Currently, we only have one implementation, so we just return it. - config = ConfigBundle(name="dense_blocked_attention", - config=attention_config) + config = ConfigBundle(name="dense_blocked_attention", config=attention_config) return DSSelfAttentionRegistry.instantiate_config(config) @@ -93,11 +92,9 @@ def instantiate_linear(linear_config: DSLinearConfig, engine_config: RaggedInfer else: # Currently, we only support ``quantized_wf6af16_linear``. 
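
For readers following the metadata-driven restore path in flat_model_helpers.py above: the gist is that every parameter is rebuilt as a typed, strided view into one flat uint8 buffer, using only the recorded dtype/shape/strides/offset. The helper below is an illustrative sketch of that idea only; the name, signature, and behavior of the real alloc_fn used by restore_layer may differ.

    import torch

    def view_param_from_flat_buffer(buffer, shape, strides, dtype, byte_offset):
        # Sketch only: reinterpret a slice of the flat uint8 checkpoint buffer as a
        # typed, strided tensor without copying, driven purely by stored metadata.
        # Assumes a recent PyTorch where Tensor.view(dtype) supports dtypes of a
        # different element size.
        elem_bytes = torch.empty((), dtype=dtype).element_size()
        needed = 1 + sum((d - 1) * s for d, s in zip(shape, strides))
        flat = buffer[byte_offset:byte_offset + needed * elem_bytes].view(dtype)
        return torch.as_strided(flat, size=list(shape), stride=list(strides))

    # Example: view 8 fp16 values stored at byte offset 64 of a flat buffer.
    buf = torch.zeros(1024, dtype=torch.uint8)
    p = view_param_from_flat_buffer(buf, (2, 4), (4, 1), torch.float16, 64)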
if quantization_mode == "wf6af16": - config = ConfigBundle( - name="quantized_wf6af16_linear", config=linear_config) + config = ConfigBundle(name="quantized_wf6af16_linear", config=linear_config) else: - raise ValueError( - f"Unsupported quantization mode: {quantization_mode}") + raise ValueError(f"Unsupported quantization mode: {quantization_mode}") return DSLinearRegistry.instantiate_config(config) diff --git a/deepspeed/inference/v2/modules/implementations/linear/__init__.py b/deepspeed/inference/v2/modules/implementations/linear/__init__.py index 2843f8bf187a..0501af54c4e6 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/__init__.py +++ b/deepspeed/inference/v2/modules/implementations/linear/__init__.py @@ -5,4 +5,3 @@ from .blas_fp_linear import BlasFPLinear from .quantized_linear import QuantizedWf6Af16Linear, fp_quantize - diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index b6d43dba43bd..3c3da5c3d981 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -23,10 +23,12 @@ from ....inference_parameter import InferenceParameter -def fp_quantize( - input: torch.FloatTensor, num_bits: int = 6, exp_bits: int = 3, - min_value: torch.FloatTensor = None, max_value: torch.FloatTensor = None, - group_size: int = -1): +def fp_quantize(input: torch.FloatTensor, + num_bits: int = 6, + exp_bits: int = 3, + min_value: torch.FloatTensor = None, + max_value: torch.FloatTensor = None, + group_size: int = -1): """ Args: inputs (`torch.FloatTensor`) @@ -39,20 +41,19 @@ def fp_quantize( Used for static activation quantization group_size (int) N The quantization block size, each N numbers has its own scaling factor and off-site - -1 means use the last dim as the group_size + -1 means use the last dim as the group_size Returns: quantized_fake_fp6 The quantized weights, in fp16 format and contains fp6 value. scales Quantization scales """ - assert (min_value is None and max_value is None) or ( - min_value is not None and max_value is not None) + assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None) assert input.dtype == torch.float16 orig_device = input.device - input = input.to(torch.float32).cuda() + input = input.to(torch.float32).to(get_accelerator().current_device()) if num_bits == 6: if exp_bits == 3: # this is defulat q_range = 28 @@ -71,8 +72,7 @@ def fp_quantize( input = input.reshape(num_groups, -1) if min_value is None: - max_input = torch.amax( - torch.abs(input), dim=-1).view(num_groups, -1) + max_input = torch.amax(torch.abs(input), dim=-1).view(num_groups, -1) else: max_input = torch.max(min_value.abs(), max_value) # .view(-1) scales = max_input / q_range # q_range + 1 @@ -80,13 +80,11 @@ def fp_quantize( scaled_input = input / scales # torch.cuda.synchronize() # for some reason this is needed to avoid the output being 0 - quantized_fake_fp6 = float_quantize( - scaled_input, exp_bits, man_bits, rounding="nearest") + quantized_fake_fp6 = float_quantize(scaled_input, exp_bits, man_bits, rounding="nearest") # TODO: it seems the `float_quantize` will not clamp the value into the range of FP6 correctly. # To double check it. If it is true, we need to clamp it manually. 
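
Regarding the TODO above about float_quantize possibly not saturating out-of-range values: a minimal sketch of the manual clamp it alludes to, assuming the default E3M2 layout whose largest representable magnitude is q_range = 28. This is illustration only, not part of the shipped code.

    import torch

    def clamp_to_fp6_range(x: torch.Tensor, q_range: float = 28.0) -> torch.Tensor:
        # Saturate anything outside the representable FP6 (E3M2) range before the
        # values are packed; 28 = 1.75 * 2**4 is the largest normal E3M2 magnitude.
        return x.clamp(min=-q_range, max=q_range)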
- quantized_fake_fp6 = quantized_fake_fp6.reshape( - input_shape).contiguous().to(torch.float16).to(orig_device) + quantized_fake_fp6 = quantized_fake_fp6.reshape(input_shape).contiguous().to(torch.float16).to(orig_device) scales = scales.to(torch.float16).to(orig_device) # Now the dequantized value is quantized_fake_fp6 * scales @@ -115,14 +113,12 @@ def supports_config(config: DSLinearConfig) -> bool: if is_gated(config.activation): try: - _ = CUDAGatedActivation( - config.out_channels, config.output_dtype, config.activation) + _ = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) except ValueError: return False else: try: - _ = CUDABiasActivation( - config.out_channels, config.output_dtype, config.activation) + _ = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) except ValueError: return False @@ -139,8 +135,7 @@ def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any] self.M = self._config.out_channels * 2 self.K = self._config.in_channels self._is_gated = True - self._act_fn = CUDAGatedActivation( - config.out_channels, config.output_dtype, config.activation) + self._act_fn = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) self._double_buffer = torch.empty((config.max_tokens, config.out_channels * 2), dtype=config.output_dtype, device=get_accelerator().current_device()) @@ -148,8 +143,7 @@ def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any] self.M = self._config.out_channels self.K = self._config.in_channels self._is_gated = False - self._act_fn = CUDABiasActivation( - config.out_channels, config.output_dtype, config.activation) + self._act_fn = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) self._output = torch.empty((config.max_tokens, config.out_channels), dtype=config.output_dtype, @@ -176,8 +170,7 @@ def transform_param(self, param: torch.Tensor) -> InferenceParameter: if param.ndim == 1: # bias, do nothing return InferenceParameter.initialize(param) - quantized_fake_fp6, scales = self.quantizer( - param, num_bits=6, exp_bits=3) + quantized_fake_fp6, scales = self.quantizer(param, num_bits=6, exp_bits=3) if self.DEBUG: self.weight_dequantized = quantized_fake_fp6 * scales @@ -196,39 +189,33 @@ def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torc weights_2bit = w weights_4bit = w.weights_4bit scales = w.scales - output = empty_from( - self._output, (hidden_states.shape[0], self._config.out_channels)) + output = empty_from(self._output, (hidden_states.shape[0], self._config.out_channels)) if self._is_gated: - staging_output = empty_from( - self._double_buffer, (hidden_states.shape[0], self.M)) - self._linear_impl(staging_output, hidden_states, weights_2bit, - weights_4bit, scales, self.M, hidden_states.shape[0], self.K) + staging_output = empty_from(self._double_buffer, (hidden_states.shape[0], self.M)) + self._linear_impl(staging_output, hidden_states, weights_2bit, weights_4bit, scales, self.M, + hidden_states.shape[0], self.K) self._act_fn(output, staging_output, b) else: - self._linear_impl(output, hidden_states, weights_2bit, - weights_4bit, scales, self.M, hidden_states.shape[0], self.K) + self._linear_impl(output, hidden_states, weights_2bit, weights_4bit, scales, self.M, + hidden_states.shape[0], self.K) self._act_fn(output, b) if self.DEBUG: orig_device = self.weight_dequantized.device self.weight_dequantized = self.weight_dequantized.to(output.device) - ground_truth 
= torch.nn.functional.linear( - hidden_states, self.weight_dequantized, b) + ground_truth = torch.nn.functional.linear(hidden_states, self.weight_dequantized, b) self.weight_dequantized = self.weight_dequantized.to(orig_device) shape = (hidden_states.shape[0], self.M, self.K) if self._is_gated: - ismatch = torch.allclose( - ground_truth, staging_output, rtol=1e-3) + ismatch = torch.allclose(ground_truth, staging_output, rtol=1e-3) abs_diff = torch.max(torch.abs(ground_truth - staging_output)) - rel_diff = torch.max( - torch.abs((ground_truth - staging_output) / ground_truth)) + rel_diff = torch.max(torch.abs((ground_truth - staging_output) / ground_truth)) print(f"Linear shape: {shape}:\n\tIs correct: {ismatch}. " f"Max diff (abs, rel): ({abs_diff}, {rel_diff})") else: ismatch = torch.allclose(ground_truth, output, rtol=1e-3) abs_diff = torch.max(torch.abs(ground_truth - output)) - rel_diff = torch.max( - torch.abs((ground_truth - output) / ground_truth)) + rel_diff = torch.max(torch.abs((ground_truth - output) / ground_truth)) print(f"Linear shape: {shape}:\n\tIs correct: {ismatch}. " f"Max diff (abs, rel): ({abs_diff}, {rel_diff})") diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index 64266cdf5e06..a9211d3f39c7 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -23,24 +23,20 @@ def is_compatible(self, verbose=True): try: import torch except ImportError: - self.warning( - "Please install torch if trying to pre-compile inference kernels") + self.warning("Please install torch if trying to pre-compile inference kernels") return False cuda_okay = True if not self.is_rocm_pytorch() and torch.cuda.is_available(): # ignore-cuda sys_cuda_major, _ = installed_cuda_version() torch_cuda_major = int(torch.version.cuda.split('.')[0]) - cuda_capability = torch.cuda.get_device_properties( - 0).major # ignore-cuda + cuda_capability = torch.cuda.get_device_properties(0).major # ignore-cuda if cuda_capability < 6: - self.warning( - "NVIDIA Inference is only supported on Pascal and newer architectures") + self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False if cuda_capability >= 8: if torch_cuda_major < 11 or sys_cuda_major < 11: - self.warning( - "On Ampere and higher architectures please use CUDA 11+") + self.warning("On Ampere and higher architectures please use CUDA 11+") cuda_okay = False return super().is_compatible(verbose) and cuda_okay diff --git a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py index f2f37b2a2728..d62daf171a0a 100644 --- a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py +++ b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py @@ -48,8 +48,7 @@ def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, def _fp6_quant_dequant_weights(weight: torch.Tensor) -> torch.Tensor: from deepspeed.inference.v2.modules.implementations.linear.quantized_linear import fp_quantize - weight_quantized_fake_fp6, scales = fp_quantize( - weight, num_bits=6, exp_bits=3) + weight_quantized_fake_fp6, scales = fp_quantize(weight, num_bits=6, exp_bits=3) return weight_quantized_fake_fp6 * scales @@ -57,8 +56,7 @@ def quant_dequant_implementation(hidden_states: torch.Tensor, weight: torch.Tens act_type: ActivationType) -> torch.Tensor: dtype = hidden_states.dtype weight_dequantized = _fp6_quant_dequant_weights(weight) - out_states = 
torch.nn.functional.linear( - hidden_states, weight_dequantized, bias) + out_states = torch.nn.functional.linear(hidden_states, weight_dequantized, bias) out_states.float() if is_gated(act_type): @@ -106,8 +104,7 @@ def _fp6_quantized_linear_helper(tokens: int, bias = None # quantize and dequantize output - ref_quant_dequant_output = quant_dequant_implementation( - hidden_states, weight, bias, act_fn) + ref_quant_dequant_output = quant_dequant_implementation(hidden_states, weight, bias, act_fn) linear_config = DSLinearConfig(max_tokens=2048, in_channels=in_channels, @@ -115,11 +112,9 @@ def _fp6_quantized_linear_helper(tokens: int, activation=act_fn, input_dtype=dtype, output_dtype=dtype) - bundle = ConfigBundle(name='quantized_wf6af16_linear', - config=linear_config) + bundle = ConfigBundle(name='quantized_wf6af16_linear', config=linear_config) fp6_linear_module = DSLinearRegistry.instantiate_config(bundle) - weight_fp6 = fp6_linear_module.transform_param( - weight.clone().cpu()).to(get_accelerator().current_device_name()) + weight_fp6 = fp6_linear_module.transform_param(weight.clone().cpu()).to(get_accelerator().current_device_name()) ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) # tolerances = (4.8e-1, 3.2e-2) # tolerances for bf16 @@ -155,5 +150,4 @@ def _fp6_quantized_linear_helper(tokens: int, @pytest.mark.parametrize("use_bias", [True, False]) def test_fp6_quantized_linear_act_fn(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, use_bias: bool) -> None: - _fp6_quantized_linear_helper( - tokens, in_channels, out_channels, DtypeEnum.fp16, act_fn, use_bias=use_bias) + _fp6_quantized_linear_helper(tokens, in_channels, out_channels, DtypeEnum.fp16, act_fn, use_bias=use_bias) From c43947a273e89615e84797ce6823092d1d26e1d1 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Fri, 2 Feb 2024 08:48:28 +0000 Subject: [PATCH 13/31] Deal with the subnormal FP6 and FP16 values and refine the UT. --- .../core_ops/cuda_linear/cuda_linear.py | 4 +- .../cuda_linear/cuda_linear_kernels.cpp | 42 +++++++++++++------ .../modules/test_quantizied_linear_module.py | 26 +++++++----- 3 files changed, 46 insertions(+), 26 deletions(-) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py index 6bda28719de9..48386720a91c 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -42,9 +42,7 @@ def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2b """ if M % 256 != 0 or K % 64 != 0: - raise ValueError( - "The out and in channel of the FP6 weight-only quantized linear should be multiple of 256 and 64 respectively." - ) + raise ValueError("The out and in channel should be multiple of 256 and 64 respectively.") # TODO: optimize the heuristic of split k selection. 
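
As background for the split-k selection below: when split_k > 1 the GEMM writes per-split partial results into a float32 workspace of shape (split_k, M, N), and the SplitK_Reduction kernel shown earlier collapses it into the final fp16 output. A torch sketch of what that reduction computes (illustrative, not the code path actually used):

    import torch

    def splitk_reduce(workspace: torch.Tensor) -> torch.Tensor:
        # workspace: (split_k, M, N) float32 partial products, one slice per K split.
        # The kernel sums over the split dimension and writes fp16, 8 elements per
        # thread; functionally that is just a sum followed by a cast.
        return workspace.sum(dim=0).to(torch.float16)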
split_k_dict = {15360: 3, 27648: 2, 5120: 10, 10240: 5, 57344: 7, 8192: 6, 21504: 5, 7168: 7, 28672: 7} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index 1bdb94787eef..96505116f217 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -58,28 +58,46 @@ void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) float fp6_value_abs = std::abs(__half2float(*((half*)(&source)))); if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) || fp6_value_abs > absmax_fp6) { + // TODO(zhen): a better way may be rounding it to the nearest FP6 value. throw std::invalid_argument("Input value out of range for FP6."); } // It is not safe to do shift operation on uint16_t. So we promote it to int. int source_promote = int(source); - int sign = (source_promote >> 15); + int sign_bit = (source_promote >> 15); // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111' - int exp = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; + int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; // Extracting mantissa represented in FP16 - int mant = source_promote & ((1 << mantissa_nbits_fp16) - 1); - - int new_exp = exp - exp_bias_fp16 + exp_bias_fp6; - if (exp == 0) { - // SUbnormal FP6 number. But the value is a normal FP16 number. Thus it needs a special - // treatment. - new_exp += 1; + int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1); + + int new_exp_bit; + int new_mant_bit; + + if (exp_bit == 0) { + // Subnormal FP16 number. Too small for FP6. + new_exp_bit = 0; + new_mant_bit = 0; + } else { + new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); + new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6; + + // Deal with subnormal FP6 values. + int target_exp_val = exp_bit - exp_bias_fp16; + int min_fp6_exp_val = -exp_bias_fp6 + 1; + bool subnormal_fp6 = target_exp_val < min_fp6_exp_val; + if (subnormal_fp6) { + // TODO(zhen): add the rounding logic. + new_exp_bit = 0; + // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. 
Thus we + // need to add it + new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >> + (min_fp6_exp_val - target_exp_val); + } } - int new_mant = mant >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); - fp6_temp[i] = (sign << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | - (new_exp << mantissa_nbits_fp6) | new_mant; + fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | + (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit; } // Pack the values FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); diff --git a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py index d62daf171a0a..e0064b7fb2df 100644 --- a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py +++ b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py @@ -7,6 +7,7 @@ import pytest import torch +import warnings from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum, is_gated @@ -115,21 +116,24 @@ def _fp6_quantized_linear_helper(tokens: int, bundle = ConfigBundle(name='quantized_wf6af16_linear', config=linear_config) fp6_linear_module = DSLinearRegistry.instantiate_config(bundle) weight_fp6 = fp6_linear_module.transform_param(weight.clone().cpu()).to(get_accelerator().current_device_name()) - ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) - - # tolerances = (4.8e-1, 3.2e-2) # tolerances for bf16 - # The current FP6 kernel uses FP16 Tensor Core. - tolerances = (3e-2, 2e-3) # tolerances for fp16 - - # Check DeepSpeed implementation - assert allclose(ds_output, ref_quant_dequant_output, tolerances=tolerances) + try: + ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) + except ValueError as e: + if str(e) != "The out and in channel should be multiple of 256 and 64 respectively.": + raise + else: + warnings.warn("The out and in channel should be multiple of 256 and 64 respectively. Skipping the test. " + f"tokens: {tokens}, in_channels: {in_channels}, out_channels: {out_channels}") + else: + # The current FP6 kernel uses FP16 Tensor Core. + tolerances = (3e-2, 2e-3) # tolerances for fp16 - # # Check reference implementation - # ref_output = reference_implementation(hidden_states, weight, bias, act_fn) - # assert allclose(ds_output, ref_output, tolerances=tolerances) + # Check DeepSpeed implementation + assert allclose(ds_output, ref_quant_dequant_output, tolerances=tolerances) all_acts = [ + ActivationType.IDENTITY, ActivationType.RELU, ActivationType.GELU, ActivationType.SILU, From a6d2f2f0075a0b4d1d4354568c3564273bdbbda9 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Mon, 5 Feb 2024 08:49:17 +0000 Subject: [PATCH 14/31] Update according to review comments. 
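
For reference, the FP16 -> FP6 (E3M2) bit manipulation updated in the previous commit can be summarized by the Python sketch below. It is illustration only: it mirrors the sign/exponent/mantissa handling of cast_fp16_fp6, including the subnormal-FP6 shift, and assumes the input pattern has already been range-checked as in the C++ code (truncation is used instead of rounding, matching the current TODO).

    def fp16_bits_to_fp6_bits(source: int) -> int:
        # source: 16-bit pattern of an FP16 value already inside the FP6 range.
        mant_fp16, bias_fp16 = 10, 15
        mant_fp6, exp_fp6_bits, bias_fp6 = 2, 3, 3
        sign = source >> 15
        exp = (source & 0x7FFF) >> mant_fp16
        mant = source & ((1 << mant_fp16) - 1)
        if exp == 0:                       # subnormal FP16: too small for FP6
            new_exp, new_mant = 0, 0
        else:
            new_mant = mant >> (mant_fp16 - mant_fp6)
            new_exp = exp - bias_fp16 + bias_fp6
            target_exp = exp - bias_fp16
            min_fp6_exp = -bias_fp6 + 1
            if target_exp < min_fp6_exp:   # value becomes a subnormal FP6
                new_exp = 0
                # restore the implicit leading 1 and shift it into the mantissa
                new_mant = (new_mant | (1 << mant_fp6)) >> (min_fp6_exp - target_exp)
        return (sign << (exp_fp6_bits + mant_fp6)) | (new_exp << mant_fp6) | new_mant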
--- debug/run_pipeline.py | 5 ++ .../core_ops/cuda_linear/cuda_linear.py | 27 +++++------ .../cuda_linear/cuda_linear_kernels.cpp | 6 +-- .../core_ops/cuda_linear/fp6_linear.cu | 11 ++++- .../core_ops/cuda_linear/fp6_linear.cuh | 7 +++ .../include/{Configs.h => configs.h} | 15 ++++-- .../cuda_linear/include/kernel_matmul.cuh | 25 ++++------ .../cuda_linear/include/kernel_reduction.cuh | 26 +++-------- .../{PTX_cp.async.cuh => ptx_cp.async.cuh} | 19 +++----- .../include/{PTX_mma.cuh => ptx_mma.cuh} | 21 ++++----- .../{Utils_Core.cuh => utils_core.cuh} | 29 +++++------- .../{Utils_GMem.cuh => utils_gmem.cuh} | 11 ++++- ...lDequant.cuh => utils_paralleldequant.cuh} | 7 +++ .../cuda_linear/include/weight_prepacking.h | 7 +++ .../linear/quantized_linear.py | 46 +++++++++---------- op_builder/inference_core_ops.py | 4 +- .../modules/test_quantizied_linear_module.py | 23 ++++++++-- 17 files changed, 158 insertions(+), 131 deletions(-) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/{Configs.h => configs.h} (92%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/{PTX_cp.async.cuh => ptx_cp.async.cuh} (64%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/{PTX_mma.cuh => ptx_mma.cuh} (84%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/{Utils_Core.cuh => utils_core.cuh} (91%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/{Utils_GMem.cuh => utils_gmem.cuh} (92%) rename deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/{Utils_ParallelDequant.cuh => utils_paralleldequant.cuh} (95%) diff --git a/debug/run_pipeline.py b/debug/run_pipeline.py index bd0d8d8939aa..dab39e13ae86 100644 --- a/debug/run_pipeline.py +++ b/debug/run_pipeline.py @@ -1,3 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + import mii diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py index 48386720a91c..49c3ab353517 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -7,7 +7,6 @@ from ....inference_utils import DtypeEnum from deepspeed.ops.op_builder import InferenceCoreBuilder -from typing import Tuple from ... import DSKernelBase @@ -25,7 +24,7 @@ def __init__(self): self.kernel = self.inf_module.cuda_wf6af16_linear def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2bit: torch.Tensor, - weights_4bit: torch.Tensor, scale: torch.Tensor, M, N, K) -> torch.Tensor: + weights_4bit: torch.Tensor, scale: torch.Tensor, out_channels, tokens, in_channels) -> torch.Tensor: """ Matmul kernel of FP6 weight-only quantized linear. All inputs should be contiguous. It does not support batched-matmul. @@ -36,27 +35,29 @@ def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2b weights_2bit (torch.Tensor): Input tensor of the 2-bit slice. Shape is of [out_features*2/8, in_features] weights_4bit (torch.Tensor): Input tensor of the 4-bit slice. Shape is of [out_features*4/8, in_features] scale (torch.Tensor): Input tensor. 
Shape is of [out_features], since the scale is per output channel - M (int): The number of output channels - N (int): The number of tokens - K (int): The number of input channels + out_channels (int): The number of output channels + tokens (int): The number of tokens + in_channels (int): The number of input channels """ - if M % 256 != 0 or K % 64 != 0: + if out_channels % 256 != 0 or in_channels % 64 != 0: raise ValueError("The out and in channel should be multiple of 256 and 64 respectively.") # TODO: optimize the heuristic of split k selection. split_k_dict = {15360: 3, 27648: 2, 5120: 10, 10240: 5, 57344: 7, 8192: 6, 21504: 5, 7168: 7, 28672: 7} split_k = 1 - if not N > 128 and M in split_k_dict: - split_k = split_k_dict[M] - workspace = self.get_workspace(M, N, K, split_k, torch.float, hidden_states.device) - self.kernel(output, hidden_states, weights_2bit, weights_4bit, scale, workspace, M, N, K, split_k) - - def get_workspace(self, M: int, N: int, K: int, split_k: int, dtype, device) -> torch.Tensor: + if not tokens > 128 and out_channels in split_k_dict: + split_k = split_k_dict[out_channels] + workspace = self.get_workspace(out_channels, tokens, in_channels, split_k, torch.float, hidden_states.device) + self.kernel(output, hidden_states, weights_2bit, weights_4bit, scale, workspace, out_channels, tokens, + in_channels, split_k) + + def get_workspace(self, out_channels: int, tokens: int, in_channels: int, split_k: int, dtype, + device) -> torch.Tensor: """ Allocate workspace for the kernel. The workspace is used to store the intermediate results of the matmul before split-K. The split-K size is determined by the size of the matmul. """ - workspace = torch.empty((split_k, M, N), dtype=dtype, device=device) + workspace = torch.empty((split_k, out_channels, tokens), dtype=dtype, device=device) return workspace diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index 96505116f217..e1678a02a1a3 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -145,9 +145,9 @@ void weight_prepacing_fp16_to_fp6(uint16_t* weight_16bit, * weights_4bit: packed 4bit weights, size M*K*4/8 * scales: scale tensor, size M * workspace: workspace tensor, size M*N*split_k - * M: the output channel size of the weight + * M: the output channel number of the weight * N: the token number of the activation - * K: the input channel size of the weight + * K: the input channel number of the weight * split_k: the split size of the GEMM calculation */ void cuda_wf6af16_linear(torch::Tensor& output, @@ -212,7 +212,7 @@ std::vector preprocess_weight(torch::Tensor& weight) uint8_t* weight_2bit_ptr = weight_6bit_ptr; // Make sure that the new split tensor does not share the underlying memory with the original - // one. Otherwise it will incure some problems when the original tensor is deleted. It also + // one. Otherwise it will incur some problems when the original tensor is deleted. It also // makes the memory flattern risky. 
auto weight_2bit = torch::from_blob(weight_2bit_ptr, {M * K * 2 / 8}, torch::kUInt8).clone().detach(); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu index 27c0cd0579a7..5a90f4344996 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu @@ -1,3 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + #include "include/kernel_matmul.cuh" #include "include/kernel_reduction.cuh" #include "include/weight_prepacking.h" @@ -207,7 +214,7 @@ cudaError_t fp6_linear_kernel(cudaStream_t stream, /* Computes FP6-FP16 GEMM (PyTorch interface). -[Mathmatical Formula] +[Mathematical Formula] Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when @@ -217,7 +224,7 @@ calling our CUDA kernel. _in_feats: tensor of shape [B, IC]; // half _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. _scales: tensor of shape [OC]; // half - splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. + splitK: splitting the MatMul problem along K dimension for higher GPU utilization, default 1. [Outputs] _out_feats: tensor of shape [B, OC]; // half */ diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh index 4cdbfed2fe42..95f7f6050c15 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh @@ -1,3 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + #include #include #include diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h similarity index 92% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h index aa5f9e527249..76e8eda2d35e 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Configs.h +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h @@ -1,7 +1,14 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + #ifndef CONFIGS_H #define CONFIGS_H -//#define DEBUG_MODE +// #define DEBUG_MODE #define PIPELINE_LEVEL_GMEM 2 #define PIPELINE_LEVEL_SMEM 2 // only support 2 @@ -55,8 +62,8 @@ struct TilingConfig { #define WEIGHT_FRAG1_BIT_WIDTH 2 #define WEIGHT_FRAG2_BIT_WIDTH 4 #define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH + WEIGHT_FRAG2_BIT_WIDTH) // 6 -//#define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // -// QuantGroupSize: 4*64 = 256 +// #define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // +// QuantGroupSize: 4*64 = 256 /*************************** 64*64 Weghts of A WARP *************************/ #define WEIGHT_PER_UNIT (WARP_M * WARP_K) // 64*64 #define SMEM_SIZE_IN_BYTES_PER_WARP_A1 \ @@ -73,7 +80,7 @@ struct TilingConfig { (SMEM_SIZE_IN_BYTES_PER_WARP_A2 * 4 * \ PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 24 KB; double // buffer for 2-level pipeline A= 16 KB. -/******************** Gloabl Memory Layout For QUANTIZED DATA ******************/ +/******************** Global Memory Layout For QUANTIZED DATA ******************/ #define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / 128) // 64 #define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / 128) // 128 /******************** Register Allocation For QUANTIZED DATA ******************/ diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh index 342beaf518e9..aa6ea6c4b1c2 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh @@ -1,18 +1,13 @@ -/*************************************************************************** - * Copyright 2023 The FLash-LLM Authors. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ***************************************************************************/ -#include "Configs.h" -#include "Utils_Core.cuh" -#include "Utils_GMem.cuh" +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#include "configs.h" +#include "utils_core.cuh" +#include "utils_gmem.cuh" /* * C = A*B diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh index c1d5915abecb..8c49f8b0b3a5 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh @@ -1,23 +1,9 @@ -/*************************************************************************** - * Copyright 2023 The FLash-LLM Authors. All rights reserved. 
- * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ***************************************************************************/ -// Used for the reduction of result matrix if Split-K is used -// Reduction_Workspace: (Split_K, M_Global, N_Global), column major -// C: (M_Global, N_Global), column major -// Each thread deals with 8 output elements, each elements is the sum of Split_K elements -// Read Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 float_per_thread (256bit) -> -// 256 float per warp Write Global: Each Warp/ThreadBlock: 32 threads_per_warp * 8 -// half_per_thread (128bit) -> 256 half per warp -// GridSize = (M_Global*N_Global) / 256 +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 #include #include diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh similarity index 64% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh index 9ab917da8183..7f36cfd5d961 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_cp.async.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh @@ -1,16 +1,9 @@ -/*************************************************************************** - * Copyright 2023 The FLash-LLM Authors. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ***************************************************************************/ -// Extended from CUTLASS's source code +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 #ifndef PTX_CP_ASYNC_CUH #define PTX_CP_ASYNC_CUH diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh similarity index 84% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh index 632be75959ed..f13abe036279 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/PTX_mma.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh @@ -1,15 +1,10 @@ -/*************************************************************************** - * Copyright 2023 The FLash-LLM Authors. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ***************************************************************************/ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + #ifndef PTX_MMA_CUH #define PTX_MMA_CUH @@ -18,7 +13,7 @@ #include #include -#include "Configs.h" +#include "configs.h" #ifdef PIPELINE_LEVEL_SMEM template diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh similarity index 91% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh index 6635b897e64f..713cebc57e33 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_Core.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh @@ -1,23 +1,18 @@ -/*************************************************************************** - * Copyright 2023 The FLash-LLM Authors. All rights reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - ***************************************************************************/ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + #ifndef UTILS_CORE_CUH #define UTILS_CORE_CUH #include -#include "Configs.h" -#include "PTX_mma.cuh" -#include "Utils_ParallelDequant.cuh" +#include "configs.h" +#include "ptx_mma.cuh" +#include "utils_paralleldequant.cuh" #ifdef PIPELINE_LEVEL_SMEM template @@ -77,7 +72,7 @@ __device__ __forceinline__ void core_mma_slice( 2; // 1 set = 4 registers, containing a 16*16 MMA block uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast( - c); // Reigsters for accumulated FP32 results + c); // Registers for accumulated FP32 results // Setting RPTRs for double buffers uint32_t(*a_read)[4] = a; @@ -153,7 +148,7 @@ __device__ __forceinline__ void PipelinedCoreLoop( : TilingConfig::WARP_COL_MMA_TENSORS / 2; // 1 set = 4 registers, containing a 16*16 MMA block - // Reigsters to store FP32 results + // Registers to store FP32 results uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh similarity index 92% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh index b1c023dae24d..62b77edaa37a 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_GMem.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh @@ -1,9 +1,16 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + #ifndef UTILS_GMEM_CUH #define UTILS_GMEM_CUH #include -#include "Configs.h" -#include "PTX_cp.async.cuh" +#include "configs.h" +#include "ptx_cp.async.cuh" /* * Copying A1/A2 from global memory to shared memory. diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh similarity index 95% rename from deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh rename to deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh index cfc25dea9180..ff13868c1347 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/Utils_ParallelDequant.cuh +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh @@ -1,3 +1,10 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + #ifndef UTILS_PARALLELDEQUANT_CUH #define UTILS_PARALLELDEQUANT_CUH diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h index 3f84a486070a..c8cc7243f341 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h @@ -1,3 +1,10 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + #include #include #include diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index 3c3da5c3d981..5e460ad9336f 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -40,8 +40,8 @@ def fp_quantize(input: torch.FloatTensor, min_value/max_vlue (torch.FloatTensor) Used for static activation quantization group_size (int) N - The quantization block size, each N numbers has its own scaling factor and off-site - -1 means use the last dim as the group_size + The quantization block size, each N numbers has its own scaling + factor and off-site. -1 means use the last dim as the group_size Returns: quantized_fake_fp6 The quantized weights, in fp16 format and contains fp6 value. @@ -78,11 +78,8 @@ def fp_quantize(input: torch.FloatTensor, scales = max_input / q_range # q_range + 1 scales[scales == 0] = 1 # avoid zero scales scaled_input = input / scales - # torch.cuda.synchronize() # for some reason this is needed to avoid the output being 0 quantized_fake_fp6 = float_quantize(scaled_input, exp_bits, man_bits, rounding="nearest") - # TODO: it seems the `float_quantize` will not clamp the value into the range of FP6 correctly. - # To double check it. If it is true, we need to clamp it manually. quantized_fake_fp6 = quantized_fake_fp6.reshape(input_shape).contiguous().to(torch.float16).to(orig_device) scales = scales.to(torch.float16).to(orig_device) @@ -94,7 +91,8 @@ def fp_quantize(input: torch.FloatTensor, @DSLinearRegistry.register_module class QuantizedWf6Af16Linear(DSLinearBase): """ - Linear DSModule for FP6 weight-only quantization kernel, where weight is FP6 and activation is FP16. + Linear DSModule for FP6 weight-only quantization kernel, where weight is FP6 + and activation is FP16. """ @staticmethod @@ -106,8 +104,8 @@ def supports_config(config: DSLinearConfig) -> bool: if config.input_dtype != config.output_dtype: return False - # As for fp6 data items, they are packed and stored in a set of fp16 tensors. E.g., 8 fp6 data items are stored - # in 3 fp16 tensor. + # As for fp6 data items, they are packed and stored in a set of fp16 + # tensors. E.g., 8 fp6 data items are stored in 3 fp16 tensor. if config.input_dtype != torch.float16: return False @@ -130,18 +128,18 @@ def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any] self._linear_impl = CUDAWf6Af16Linear() if is_gated(config.activation): - # In the FP6 kernel implementation, the MatMul is W * A, where W is the weight and A is activation. - # M is the output channel size. - self.M = self._config.out_channels * 2 - self.K = self._config.in_channels + # In the FP6 kernel implementation, the MatMul is W * A, where W is + # the weight and A is activation. M is the output channel size. 
+ self.out_channels = self._config.out_channels * 2 + self.in_channels = self._config.in_channels self._is_gated = True self._act_fn = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) self._double_buffer = torch.empty((config.max_tokens, config.out_channels * 2), dtype=config.output_dtype, device=get_accelerator().current_device()) else: - self.M = self._config.out_channels - self.K = self._config.in_channels + self.out_channels = self._config.out_channels + self.in_channels = self._config.in_channels self._is_gated = False self._act_fn = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) @@ -172,15 +170,15 @@ def transform_param(self, param: torch.Tensor) -> InferenceParameter: quantized_fake_fp6, scales = self.quantizer(param, num_bits=6, exp_bits=3) + # This is for debugging, will delete before release. if self.DEBUG: self.weight_dequantized = quantized_fake_fp6 * scales - # This is for debugging, will delete after release. + # This is for debugging, will delete before release. assert (quantized_fake_fp6.dtype == torch.float16) - assert quantized_fake_fp6.shape[0] == self.M - assert scales.numel() == self.M + assert quantized_fake_fp6.shape[0] == self.out_channels + assert scales.numel() == self.out_channels - # Do not delete `quantized_fake_fp6` as the `preprocess_weight` is in-place operation. weights_2bit, weights_4bit = self.preprocess_weight(quantized_fake_fp6) return InferenceParameter.initialize(weights_2bit, weights_4bit=weights_4bit, scales=scales) @@ -191,13 +189,13 @@ def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torc scales = w.scales output = empty_from(self._output, (hidden_states.shape[0], self._config.out_channels)) if self._is_gated: - staging_output = empty_from(self._double_buffer, (hidden_states.shape[0], self.M)) - self._linear_impl(staging_output, hidden_states, weights_2bit, weights_4bit, scales, self.M, - hidden_states.shape[0], self.K) + staging_output = empty_from(self._double_buffer, (hidden_states.shape[0], self.out_channels)) + self._linear_impl(staging_output, hidden_states, weights_2bit, weights_4bit, scales, self.out_channels, + hidden_states.shape[0], self.in_channels) self._act_fn(output, staging_output, b) else: - self._linear_impl(output, hidden_states, weights_2bit, weights_4bit, scales, self.M, - hidden_states.shape[0], self.K) + self._linear_impl(output, hidden_states, weights_2bit, weights_4bit, scales, self.out_channels, + hidden_states.shape[0], self.in_channels) self._act_fn(output, b) if self.DEBUG: @@ -205,7 +203,7 @@ def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torc self.weight_dequantized = self.weight_dequantized.to(output.device) ground_truth = torch.nn.functional.linear(hidden_states, self.weight_dequantized, b) self.weight_dequantized = self.weight_dequantized.to(orig_device) - shape = (hidden_states.shape[0], self.M, self.K) + shape = (hidden_states.shape[0], self.out_channels, self.in_channels) if self._is_gated: ismatch = torch.allclose(ground_truth, staging_output, rtol=1e-3) abs_diff = torch.max(torch.abs(ground_truth - staging_output)) diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index a9211d3f39c7..b1635a741950 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -27,10 +27,10 @@ def is_compatible(self, verbose=True): return False cuda_okay = True - if not self.is_rocm_pytorch() and torch.cuda.is_available(): # ignore-cuda + if not 
self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda sys_cuda_major, _ = installed_cuda_version() torch_cuda_major = int(torch.version.cuda.split('.')[0]) - cuda_capability = torch.cuda.get_device_properties(0).major # ignore-cuda + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda if cuda_capability < 6: self.warning("NVIDIA Inference is only supported on Pascal and newer architectures") cuda_okay = False diff --git a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py index e0064b7fb2df..518561e88b77 100644 --- a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py +++ b/tests/unit/inference/v2/modules/test_quantizied_linear_module.py @@ -90,6 +90,13 @@ def _fp6_quantized_linear_helper(tokens: int, dtype: DtypeEnum, act_fn: ActivationType, use_bias: bool = True) -> None: + # The current FP6 kernel only supports NVIDIA Ampere GPUs. + if not 'cuda' in get_accelerator().current_device_name(): + return + major, _ = torch.cuda.get_device_capability() #ignore-cuda + if major != 8: + return + # Input vals hidden_states = torch.randn( (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 @@ -142,9 +149,19 @@ def _fp6_quantized_linear_helper(tokens: int, ActivationType.SiGLU, ] all_tokens = [1, 37, 1280] -# TODO: some of the shapes are not supported. The output channels should be a multiple of 256. -# The input channel should be a multiple of 64. -all_in_out_channels = [(4608, 1728), (8192, 4096), (3072, 6144)] +all_in_out_channels = [ + # Llama 2 7B shapes + (4096, 4096), + (4096, 11008), + (11008, 4096), + # Llama 2 70B shapes + (8192, 8192), + (8192, 28672), + (28672, 8192), + # Other shapes, not supported by FP6 kernels. Will raise ValueError. + (4608, 1728), + (3072, 6144) +] @pytest.mark.inference_v2_ops From 62a2d495a6e3367c842ae3360efcfb7b508275bb Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Tue, 6 Feb 2024 08:56:24 +0000 Subject: [PATCH 15/31] Fix the CI workflow problem for FP6 end-to-end. 
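This installs the `inf` extra (`.[dev,1bit,autotuning,inf]`) in the A6000 workflow, adds qtorch to requirements-inf.txt, makes the qtorch import in fp_quantize a guarded, function-local import so that importing the module does not hard-require it, and fixes the weight_prepacing -> weight_prepacking typo in the C++ weight prepacking helper. For reference, a minimal sketch of the size bookkeeping that helper relies on is below; `packed_fp6_bytes` is an illustrative name that is not part of this patch, and it covers only the four-values-in-three-bytes arithmetic, not the kernel's actual 2-bit/4-bit split layout.

```cpp
#include <cstddef>
#include <stdexcept>

// Four 6-bit values occupy exactly three bytes, so an M x K FP6 weight packs
// into M * K * 6 / 8 bytes; K * 6 must therefore be divisible by 8, matching
// the precondition checked in weight_prepacking_fp16_to_fp6.
size_t packed_fp6_bytes(size_t M, size_t K)
{
    if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); }
    return M * K * 6 / 8;
}
```
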
--- .github/workflows/nv-a6000.yml | 2 +- .../core_ops/cuda_linear/cuda_linear_kernels.cpp | 10 +++++----- .../modules/implementations/linear/quantized_linear.py | 7 ++++++- requirements/requirements-inf.txt | 1 + 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index d7db447f5d26..b707d52a55ff 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -47,7 +47,7 @@ jobs: - name: Install deepspeed run: | python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja - python -m pip install .[dev,1bit,autotuning] + python -m pip install .[dev,1bit,autotuning,inf] ds_report - name: Python environment run: | diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp index e1678a02a1a3..677bec22ded8 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -113,10 +113,10 @@ void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) * weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8 * M, K: the shape of the weight */ -void weight_prepacing_fp16_to_fp6(uint16_t* weight_16bit, - uint8_t* weight_6bit_packed, - size_t M, - size_t K) +void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, + uint8_t* weight_6bit_packed, + size_t M, + size_t K) { // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } @@ -205,7 +205,7 @@ std::vector preprocess_weight(torch::Tensor& weight) uint16_t* weight_16bit_ptr = reinterpret_cast(weight.data_ptr()); std::vector weight_6bit_packed(M * K * 6 / 8); uint8_t* weight_6bit_ptr = weight_6bit_packed.data(); - weight_prepacing_fp16_to_fp6(weight_16bit_ptr, weight_6bit_ptr, M, K); + weight_prepacking_fp16_to_fp6(weight_16bit_ptr, weight_6bit_ptr, M, K); // Split weight into 2bit and 4bit. weight_matrix_prepacking(reinterpret_cast(weight_6bit_ptr), M, K); diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index 5e460ad9336f..bfb4b112a31f 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -3,7 +3,6 @@ # DeepSpeed Team -from qtorch.quant import float_quantize from typing import Any, Dict, Optional import torch @@ -48,6 +47,12 @@ def fp_quantize(input: torch.FloatTensor, scales Quantization scales """ + + try: + from qtorch.quant import float_quantize + except ImportError: + raise ImportError("Please install qtorch to use this function") + assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None) assert input.dtype == torch.float16 diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt index 7a40ae814cbe..b7fd13787e8b 100644 --- a/requirements/requirements-inf.txt +++ b/requirements/requirements-inf.txt @@ -1,6 +1,7 @@ google lm-eval==0.3.0 protobuf +qtorch safetensors sentencepiece transformers>=4.32.1 From 118af370cb52c9e81d99fcf27241621315633058 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Tue, 6 Feb 2024 10:08:22 +0000 Subject: [PATCH 16/31] Fix at::nullopt and at::optional conflicts. 
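Compiling fp6_linear.cu conflicted on `at::nullopt` and `at::optional` when the FP6 kernel headers were included before the torch extension headers. The fix is an include-order change only, wrapped in clang-format off/on so the formatter does not reorder it. A minimal sketch of the intended ordering is below; the torch header names are elided in the hunk that follows, so `torch/extension.h` and `ATen/ATen.h` appear here only as representative assumptions.

```cpp
// clang-format off
// Torch headers go first so that their at::optional / at::nullopt definitions
// are seen before the FP6 kernel headers pull in anything that conflicts.
#include <torch/extension.h>  // representative torch header (assumed)
#include <ATen/ATen.h>        // representative ATen header (assumed)
// clang-format on

#include "include/kernel_matmul.cuh"
#include "include/kernel_reduction.cuh"
#include "include/weight_prepacking.h"
```
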
--- .../v2/kernels/core_ops/cuda_linear/fp6_linear.cu | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu index 5a90f4344996..64e06a5435c6 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu @@ -5,14 +5,19 @@ // This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 +// clang-format off +// Put the torch headers at the front to avoid conflict with other headers on +// `at::nullopt` and `at::optional`. +#include +#include +// clang-format on + #include "include/kernel_matmul.cuh" #include "include/kernel_reduction.cuh" #include "include/weight_prepacking.h" -#include #include #include -#include template static void Kernel_Ex(cudaStream_t stream, From 56eb8b90d2e27a33e05a04af577e6eda248927b9 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Fri, 23 Feb 2024 15:40:52 +0000 Subject: [PATCH 17/31] Refine split-k setting. --- .../core_ops/cuda_linear/cuda_linear.py | 154 +++++++++++++++++- .../linear/quantized_linear.py | 26 --- 2 files changed, 149 insertions(+), 31 deletions(-) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py index 49c3ab353517..69aa9e8920e2 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -6,6 +6,7 @@ import torch from ....inference_utils import DtypeEnum +from ....logging import inference_logger from deepspeed.ops.op_builder import InferenceCoreBuilder from ... import DSKernelBase @@ -22,6 +23,143 @@ def __init__(self): self.inf_module = InferenceCoreBuilder().load() self.inf_module.create_handle() self.kernel = self.inf_module.cuda_wf6af16_linear + # The split_k_map is profiled on A100-80G GPU for some common shapes. + # It is an array of dictionaries, where the array index is the tokens chunk id. + # The dictionary is the mapping from the output channel to the split-K size. 
+ self.split_k_map = [ + { # tokens: [1, 64] + 3072: 18, + 4096: 13, + 5120: 10, + 6144: 9, + 8192: 6, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 7 + }, + { # tokens: [65:128] + 3072: 9, + 4096: 6, + 5120: 5, + 6144: 9, + 8192: 3, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 6 + }, + { # tokens: [129:192] + 3072: 6, + 4096: 4, + 5120: 7, + 6144: 3, + 8192: 2, + 10240: 5, + 14336: 5, + 28672: 5, + 57344: 4 + }, + { # tokens: [193:256] + 3072: 9, + 4096: 3, + 5120: 5, + 6144: 2, + 8192: 5, + 10240: 4, + 14336: 8, + 28672: 6, + 57344: 4 + }, + { # tokens: [257:320] + 3072: 7, + 4096: 5, + 5120: 2, + 6144: 5, + 8192: 4, + 10240: 1, + 14336: 3, + 28672: 3, + 57344: 4 + }, + { # tokens: [321:384] + 3072: 3, + 4096: 2, + 5120: 5, + 6144: 3, + 8192: 1, + 10240: 8, + 14336: 3, + 28672: 4, + 57344: 3 + }, + { # tokens: [385:448] + 3072: 5, + 4096: 7, + 5120: 3, + 6144: 5, + 8192: 7, + 10240: 3, + 14336: 1, + 28672: 1, + 57344: 3 + }, + { # tokens: [449:512] + 3072: 2, + 4096: 5, + 5120: 4, + 6144: 1, + 8192: 5, + 10240: 2, + 14336: 6, + 28672: 4, + 57344: 1 + }, + { # tokens: [513:576] + 3072: 2, + 4096: 3, + 5120: 1, + 6144: 1, + 8192: 3, + 10240: 3, + 14336: 3, + 28672: 1, + 57344: 1 + }, + { # tokens: [577:640] + 3072: 5, + 4096: 4, + 5120: 1, + 6144: 4, + 8192: 2, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [641:704] + 3072: 3, + 4096: 1, + 5120: 2, + 6144: 2, + 8192: 1, + 10240: 2, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [705:768] + 3072: 3, + 4096: 1, + 5120: 3, + 6144: 2, + 8192: 1, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + } + ] def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2bit: torch.Tensor, weights_4bit: torch.Tensor, scale: torch.Tensor, out_channels, tokens, in_channels) -> torch.Tensor: @@ -43,11 +181,17 @@ def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2b if out_channels % 256 != 0 or in_channels % 64 != 0: raise ValueError("The out and in channel should be multiple of 256 and 64 respectively.") - # TODO: optimize the heuristic of split k selection. - split_k_dict = {15360: 3, 27648: 2, 5120: 10, 10240: 5, 57344: 7, 8192: 6, 21504: 5, 7168: 7, 28672: 7} - split_k = 1 - if not tokens > 128 and out_channels in split_k_dict: - split_k = split_k_dict[out_channels] + # TODO: add a more general heuristic to determine the split-K. + split_k = -1 # not initialized + if tokens <= 768: + # Try to find the split-K from the pre-profiled map. + tokens_chunk_id = (tokens - 1) // 64 + split_k = self.split_k_map[tokens_chunk_id].get(out_channels, -1) + if split_k == -1: + split_k = 1 + inference_logger().warning( + f"The split-K setting may be suboptimal for shape {tokens}x{in_channels}x{out_channels}...") + workspace = self.get_workspace(out_channels, tokens, in_channels, split_k, torch.float, hidden_states.device) self.kernel(output, hidden_states, weights_2bit, weights_4bit, scale, workspace, out_channels, tokens, in_channels, split_k) diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index bfb4b112a31f..5b8b89f95cb3 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -158,9 +158,6 @@ def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any] self.quantizer = fp_quantize - # This is for debugging, will delete after release. 
- self.DEBUG = False - def transform_param(self, param: torch.Tensor) -> InferenceParameter: """ Converts param to same data type as input and output. @@ -175,10 +172,6 @@ def transform_param(self, param: torch.Tensor) -> InferenceParameter: quantized_fake_fp6, scales = self.quantizer(param, num_bits=6, exp_bits=3) - # This is for debugging, will delete before release. - if self.DEBUG: - self.weight_dequantized = quantized_fake_fp6 * scales - # This is for debugging, will delete before release. assert (quantized_fake_fp6.dtype == torch.float16) assert quantized_fake_fp6.shape[0] == self.out_channels @@ -203,25 +196,6 @@ def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torc hidden_states.shape[0], self.in_channels) self._act_fn(output, b) - if self.DEBUG: - orig_device = self.weight_dequantized.device - self.weight_dequantized = self.weight_dequantized.to(output.device) - ground_truth = torch.nn.functional.linear(hidden_states, self.weight_dequantized, b) - self.weight_dequantized = self.weight_dequantized.to(orig_device) - shape = (hidden_states.shape[0], self.out_channels, self.in_channels) - if self._is_gated: - ismatch = torch.allclose(ground_truth, staging_output, rtol=1e-3) - abs_diff = torch.max(torch.abs(ground_truth - staging_output)) - rel_diff = torch.max(torch.abs((ground_truth - staging_output) / ground_truth)) - print(f"Linear shape: {shape}:\n\tIs correct: {ismatch}. " - f"Max diff (abs, rel): ({abs_diff}, {rel_diff})") - else: - ismatch = torch.allclose(ground_truth, output, rtol=1e-3) - abs_diff = torch.max(torch.abs(ground_truth - output)) - rel_diff = torch.max(torch.abs((ground_truth - output) / ground_truth)) - print(f"Linear shape: {shape}:\n\tIs correct: {ismatch}. " - f"Max diff (abs, rel): ({abs_diff}, {rel_diff})") - return output @property From 0ddbfd116ef573aa66e6dd90a3c8893214eebeb0 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Mon, 4 Mar 2024 12:03:55 +0000 Subject: [PATCH 18/31] Remove debug files. --- debug/README.md | 1 - debug/clean.sh | 2 -- debug/run_pipeline.py | 22 ---------------------- 3 files changed, 25 deletions(-) delete mode 100644 debug/README.md delete mode 100644 debug/clean.sh delete mode 100644 debug/run_pipeline.py diff --git a/debug/README.md b/debug/README.md deleted file mode 100644 index 71e9e64c05cb..000000000000 --- a/debug/README.md +++ /dev/null @@ -1 +0,0 @@ -The files in this directory is only for debugging of FP6 quantization kernel integration. Will not merge. diff --git a/debug/clean.sh b/debug/clean.sh deleted file mode 100644 index 7120c76321da..000000000000 --- a/debug/clean.sh +++ /dev/null @@ -1,2 +0,0 @@ -rm ~/.cache/torch_extensions/py38_cu118/inference_core_ops/*.o -rm ~/.cache/torch_extensions/py38_cu118/inference_core_ops/*.so diff --git a/debug/run_pipeline.py b/debug/run_pipeline.py deleted file mode 100644 index dab39e13ae86..000000000000 --- a/debug/run_pipeline.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Microsoft Corporation. 
-# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -import mii - - -def fake_request_texts(batch_size: int): - request_texts = ["Ha ha ha"] * batch_size - return request_texts - - -if __name__ == '__main__': - model_id = "meta-llama/Llama-2-7b-hf" - - batch_size = 32 - prompts = fake_request_texts(batch_size) - - pipe = mii.pipeline(model_name_or_path=model_id, quantization_mode='wf6af16') - response = pipe(prompts, max_new_tokens=2) - print(f"{len(response)} responses.") From 35c82f25bd7ad9e5ddc10c3444a2cf6c56915f8d Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Mon, 4 Mar 2024 12:53:45 +0000 Subject: [PATCH 19/31] Only compiler the kernel body for SM >= 8.0. --- .../inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu index 64e06a5435c6..dcbb8cec9a59 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu @@ -31,6 +31,9 @@ static void Kernel_Ex(cudaStream_t stream, const size_t K_Global, int Split_K) { +// The current implementation only supports SM >= 8.0 +#if __CUDA_ARCH__ >= 800 + #ifdef DEBUG_MODE printf("\n"); printf("Launcher.cu->Kernel_Ex():\n"); @@ -66,6 +69,8 @@ static void Kernel_Ex(cudaStream_t stream, #endif QUANT_GEMM_Kernel<<>>( Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + +#endif // __CUDA_ARCH__ >= 800 } /* From 63489d17592857544977fe1816b77335a83c9dde Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Tue, 5 Mar 2024 03:17:31 +0000 Subject: [PATCH 20/31] Fix the GPU architecture requirement of FP6 kernel. --- .../kernels/core_ops/cuda_linear/fp6_linear.cu | 5 ----- op_builder/inference_core_ops.py | 17 +++++++++++++++-- ...odule.py => test_quantized_linear_module.py} | 0 3 files changed, 15 insertions(+), 7 deletions(-) rename tests/unit/inference/v2/modules/{test_quantizied_linear_module.py => test_quantized_linear_module.py} (100%) diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu index dcbb8cec9a59..64e06a5435c6 100644 --- a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu @@ -31,9 +31,6 @@ static void Kernel_Ex(cudaStream_t stream, const size_t K_Global, int Split_K) { -// The current implementation only supports SM >= 8.0 -#if __CUDA_ARCH__ >= 800 - #ifdef DEBUG_MODE printf("\n"); printf("Launcher.cu->Kernel_Ex():\n"); @@ -69,8 +66,6 @@ static void Kernel_Ex(cudaStream_t stream, #endif QUANT_GEMM_Kernel<<>>( Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); - -#endif // __CUDA_ARCH__ >= 800 } /* diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index b1635a741950..08ae711fb3fb 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -57,6 +57,12 @@ def get_prefix(self): return "deepspeed" if os.path.isdir(ds_path) else ".." 
def sources(self): + try: + import torch + except ImportError: + self.warning("Please install torch if trying to pre-compile inference kernels") + return False + sources = [ "inference/v2/kernels/core_ops/core_ops.cpp", "inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp", @@ -67,10 +73,17 @@ def sources(self): "inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu", "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp", "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu", - "inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu", - "inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp", ] + # The source files with specific GPU architecture requirements. + if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability != 8: + self.warning("FP6 quantization kernel is only supported on Ampere architectures") + else: + sources.append("inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu") + sources.append("inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp") + prefix = self.get_prefix() sources = [os.path.join(prefix, src) for src in sources] return sources diff --git a/tests/unit/inference/v2/modules/test_quantizied_linear_module.py b/tests/unit/inference/v2/modules/test_quantized_linear_module.py similarity index 100% rename from tests/unit/inference/v2/modules/test_quantizied_linear_module.py rename to tests/unit/inference/v2/modules/test_quantized_linear_module.py From ed00ac92f4bf3baa3ff0329263f78399e6171167 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 5 Mar 2024 13:03:58 -0800 Subject: [PATCH 21/31] Update deepspeed/inference/v2/config_v2.py --- deepspeed/inference/v2/config_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 60803ee39ccd..93ccf13ffb9d 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -19,7 +19,7 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel): class QuantizationConfig(DeepSpeedConfigModel): """ Configure tensor parallelism settings """ - quantization_mode: str = None + quantization_mode: Optional[str] = None """ The quantization mode in string format. The supported modes are as follows: - 'wf6af16', weight-only quantization with FP6 weight and FP16 activation. 
""" From b15a1a103e56ce544fb863e7a675309b2bb4708a Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 5 Mar 2024 13:04:08 -0800 Subject: [PATCH 22/31] Update deepspeed/inference/v2/config_v2.py --- deepspeed/inference/v2/config_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 93ccf13ffb9d..1ce8bf9470a6 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -40,4 +40,4 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel): Configuration for managing persistent state """ - quantization: QuantizationConfig = Field({}, alias="quantization") + quantization: QuantizationConfig = {} From c2e6ebb9c74cb6f15d31095331ef4c0b05db299f Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 5 Mar 2024 13:51:52 -0800 Subject: [PATCH 23/31] refactor fp6 tests, fix import error --- deepspeed/inference/v2/config_v2.py | 2 +- .../modules/test_quantized_linear_module.py | 45 +++++++++++++------ 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 1ce8bf9470a6..85e4b7a0e0a0 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -3,8 +3,8 @@ # DeepSpeed Team +from typing import Optional from deepspeed.pydantic_v1 import Field - from deepspeed.runtime.config_utils import DeepSpeedConfigModel from .ragged import DSStateManagerConfig diff --git a/tests/unit/inference/v2/modules/test_quantized_linear_module.py b/tests/unit/inference/v2/modules/test_quantized_linear_module.py index 518561e88b77..b03322b5aff1 100644 --- a/tests/unit/inference/v2/modules/test_quantized_linear_module.py +++ b/tests/unit/inference/v2/modules/test_quantized_linear_module.py @@ -7,7 +7,6 @@ import pytest import torch -import warnings from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum, is_gated @@ -89,7 +88,8 @@ def _fp6_quantized_linear_helper(tokens: int, out_channels: int, dtype: DtypeEnum, act_fn: ActivationType, - use_bias: bool = True) -> None: + use_bias: bool = True, + expect_failure: bool = False) -> None: # The current FP6 kernel only supports NVIDIA Ampere GPUs. if not 'cuda' in get_accelerator().current_device_name(): return @@ -123,15 +123,13 @@ def _fp6_quantized_linear_helper(tokens: int, bundle = ConfigBundle(name='quantized_wf6af16_linear', config=linear_config) fp6_linear_module = DSLinearRegistry.instantiate_config(bundle) weight_fp6 = fp6_linear_module.transform_param(weight.clone().cpu()).to(get_accelerator().current_device_name()) - try: - ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) - except ValueError as e: - if str(e) != "The out and in channel should be multiple of 256 and 64 respectively.": - raise - else: - warnings.warn("The out and in channel should be multiple of 256 and 64 respectively. Skipping the test. " - f"tokens: {tokens}, in_channels: {in_channels}, out_channels: {out_channels}") + + if expect_failure: + with pytest.raises(ValueError) as excinfo: + ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) + assert "The out and in channel should be multiple of 256 and 64 respectively." in str(excinfo.value) else: + ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) # The current FP6 kernel uses FP16 Tensor Core. 
tolerances = (3e-2, 2e-3) # tolerances for fp16 @@ -158,9 +156,6 @@ def _fp6_quantized_linear_helper(tokens: int, (8192, 8192), (8192, 28672), (28672, 8192), - # Other shapes, not supported by FP6 kernels. Will raise ValueError. - (4608, 1728), - (3072, 6144) ] @@ -171,4 +166,26 @@ def _fp6_quantized_linear_helper(tokens: int, @pytest.mark.parametrize("use_bias", [True, False]) def test_fp6_quantized_linear_act_fn(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, use_bias: bool) -> None: - _fp6_quantized_linear_helper(tokens, in_channels, out_channels, DtypeEnum.fp16, act_fn, use_bias=use_bias) + _fp6_quantized_linear_helper(tokens=tokens, + in_channels=in_channels, + out_channels=out_channels, + dtype=DtypeEnum.fp16, + act_fn=act_fn, + use_bias=use_bias) + + +# Other shapes, not supported by FP6 kernels. Will raise ValueError. +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens", all_tokens) +@pytest.mark.parametrize("in_channels, out_channels", [(4608, 1728)]) +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_fp6_quantized_linear_act_fn_fail(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, + use_bias: bool) -> None: + _fp6_quantized_linear_helper(tokens=tokens, + in_channels=in_channels, + out_channels=out_channels, + dtype=DtypeEnum.fp16, + act_fn=act_fn, + use_bias=use_bias, + expect_failure=True) From fb8887c981ef02b3479373963d7da30bc99af3d0 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 5 Mar 2024 13:55:25 -0800 Subject: [PATCH 24/31] Update deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py --- .../v2/modules/implementations/linear/quantized_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index 5b8b89f95cb3..e5bb55686b6b 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -60,7 +60,7 @@ def fp_quantize(input: torch.FloatTensor, orig_device = input.device input = input.to(torch.float32).to(get_accelerator().current_device()) if num_bits == 6: - if exp_bits == 3: # this is defulat + if exp_bits == 3: # this is default q_range = 28 else: raise NotImplementedError From 77f3883d658558658ca604177ac5bbccba01b973 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 5 Mar 2024 14:04:23 -0800 Subject: [PATCH 25/31] Update requirements.txt --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 80c9f9b3287a..e083a633960a 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,7 +4,7 @@ numpy packaging>=20.0 psutil py-cpuinfo -pydantic +pydantic>=2.0.0 pynvml torch tqdm From f6bcdee00c7e45a97622230ff210cf12710165f4 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 5 Mar 2024 14:27:01 -0800 Subject: [PATCH 26/31] revert testing to fix A6000 test --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index e083a633960a..80c9f9b3287a 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -4,7 +4,7 @@ numpy packaging>=20.0 psutil py-cpuinfo -pydantic>=2.0.0 +pydantic pynvml torch 
tqdm From e1a4ce0401d9223d78a42da9c0286b3c84cba269 Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Tue, 5 Mar 2024 17:07:08 -0800 Subject: [PATCH 27/31] Update pydantic version --- .github/workflows/nv-a6000.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index b707d52a55ff..9c63a8dfee90 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -47,6 +47,7 @@ jobs: - name: Install deepspeed run: | python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja + python -m pip install pydantic --upgrade python -m pip install .[dev,1bit,autotuning,inf] ds_report - name: Python environment From e86611fc5911cd6aec615bd6402f4df307203b54 Mon Sep 17 00:00:00 2001 From: Michael Wyatt Date: Tue, 5 Mar 2024 17:11:40 -0800 Subject: [PATCH 28/31] fix pydantic import --- tests/unit/inference/v2/ragged/test_manager_configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/inference/v2/ragged/test_manager_configs.py b/tests/unit/inference/v2/ragged/test_manager_configs.py index bdd513445ddb..a5f270cced8c 100644 --- a/tests/unit/inference/v2/ragged/test_manager_configs.py +++ b/tests/unit/inference/v2/ragged/test_manager_configs.py @@ -5,7 +5,7 @@ import pytest -from pydantic import ValidationError +from deepspeed.pydantic_v1 import ValidationError from deepspeed.inference.v2.ragged import DSStateManagerConfig From 7e28144d7f7d68c92ad66e0f18690634919b1e61 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Wed, 6 Mar 2024 01:11:52 +0000 Subject: [PATCH 29/31] Fix some review comments. --- .../modules/implementations/linear/quantized_linear.py | 9 ++++----- op_builder/inference_core_ops.py | 6 ------ .../inference/v2/modules/test_quantized_linear_module.py | 9 +-------- 3 files changed, 5 insertions(+), 19 deletions(-) diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py index e5bb55686b6b..933cf55b2391 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -59,11 +59,10 @@ def fp_quantize(input: torch.FloatTensor, orig_device = input.device input = input.to(torch.float32).to(get_accelerator().current_device()) - if num_bits == 6: - if exp_bits == 3: # this is default - q_range = 28 - else: - raise NotImplementedError + if num_bits == 6 and exp_bits == 3: # this is default + q_range = 28 + else: + raise NotImplementedError man_bits = num_bits - exp_bits - 1 input_shape = input.shape diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index 08ae711fb3fb..061b78a57f49 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -57,12 +57,6 @@ def get_prefix(self): return "deepspeed" if os.path.isdir(ds_path) else ".." 
def sources(self): - try: - import torch - except ImportError: - self.warning("Please install torch if trying to pre-compile inference kernels") - return False - sources = [ "inference/v2/kernels/core_ops/core_ops.cpp", "inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp", diff --git a/tests/unit/inference/v2/modules/test_quantized_linear_module.py b/tests/unit/inference/v2/modules/test_quantized_linear_module.py index b03322b5aff1..a7bd965072ac 100644 --- a/tests/unit/inference/v2/modules/test_quantized_linear_module.py +++ b/tests/unit/inference/v2/modules/test_quantized_linear_module.py @@ -138,7 +138,6 @@ def _fp6_quantized_linear_helper(tokens: int, all_acts = [ - ActivationType.IDENTITY, ActivationType.RELU, ActivationType.GELU, ActivationType.SILU, @@ -146,16 +145,10 @@ def _fp6_quantized_linear_helper(tokens: int, ActivationType.ReGLU, ActivationType.SiGLU, ] -all_tokens = [1, 37, 1280] +all_tokens = [1, 37] all_in_out_channels = [ - # Llama 2 7B shapes (4096, 4096), - (4096, 11008), - (11008, 4096), - # Llama 2 70B shapes - (8192, 8192), (8192, 28672), - (28672, 8192), ] From f8454a08b4b706ec442b4f4d5d02885a4c7e12bf Mon Sep 17 00:00:00 2001 From: Logan Adams Date: Tue, 5 Mar 2024 17:18:20 -0800 Subject: [PATCH 30/31] Pin pydantic to latest version --- .github/workflows/nv-a6000.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index 9c63a8dfee90..960e0203919e 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -47,7 +47,7 @@ jobs: - name: Install deepspeed run: | python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja - python -m pip install pydantic --upgrade + python -m pip install pydantic==1.10.11 python -m pip install .[dev,1bit,autotuning,inf] ds_report - name: Python environment From bed775e184e9276fafcbefc53249ec4cd67fe125 Mon Sep 17 00:00:00 2001 From: "ZHENG, Zhen" Date: Wed, 6 Mar 2024 01:24:05 +0000 Subject: [PATCH 31/31] Add the missed torch import. --- op_builder/inference_core_ops.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index 061b78a57f49..3c53774d0a50 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -57,6 +57,8 @@ def get_prefix(self): return "deepspeed" if os.path.isdir(ds_path) else ".." def sources(self): + import torch + sources = [ "inference/v2/kernels/core_ops/core_ops.cpp", "inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp",