diff --git a/.github/workflows/nv-a6000.yml b/.github/workflows/nv-a6000.yml index d7db447f5d26..960e0203919e 100644 --- a/.github/workflows/nv-a6000.yml +++ b/.github/workflows/nv-a6000.yml @@ -47,7 +47,8 @@ jobs: - name: Install deepspeed run: | python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja - python -m pip install .[dev,1bit,autotuning] + python -m pip install pydantic==1.10.11 + python -m pip install .[dev,1bit,autotuning,inf] ds_report - name: Python environment run: | diff --git a/deepspeed/inference/v2/config_v2.py b/deepspeed/inference/v2/config_v2.py index 64e7e29b1844..85e4b7a0e0a0 100644 --- a/deepspeed/inference/v2/config_v2.py +++ b/deepspeed/inference/v2/config_v2.py @@ -3,8 +3,8 @@ # DeepSpeed Team +from typing import Optional from deepspeed.pydantic_v1 import Field - from deepspeed.runtime.config_utils import DeepSpeedConfigModel from .ragged import DSStateManagerConfig @@ -16,6 +16,16 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel): """ Number of devices to split the model across using tensor parallelism. """ +class QuantizationConfig(DeepSpeedConfigModel): + """ Configure tensor parallelism settings """ + + quantization_mode: Optional[str] = None + """ The quantization mode in string format. The supported modes are as follows: + - 'wf6af16', weight-only quantization with FP6 weight and FP16 activation. + """ + # TODO: may reuse the constants in deepspeed/compression/constants.py + + class RaggedInferenceEngineConfig(DeepSpeedConfigModel): """ Sets parameters for DeepSpeed Inference Engine. """ @@ -29,3 +39,5 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel): """ Configuration for managing persistent state """ + + quantization: QuantizationConfig = {} diff --git a/deepspeed/inference/v2/kernels/core_ops/__init__.py b/deepspeed/inference/v2/kernels/core_ops/__init__.py index bbb53e5b58a2..1d16b484a560 100644 --- a/deepspeed/inference/v2/kernels/core_ops/__init__.py +++ b/deepspeed/inference/v2/kernels/core_ops/__init__.py @@ -8,3 +8,4 @@ from .cuda_layer_norm import * from .cuda_rms_norm import * from .gated_activations import * +from .cuda_linear import * diff --git a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp index 58df88e56136..2397b0694696 100644 --- a/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp +++ b/deepspeed/inference/v2/kernels/core_ops/core_ops.cpp @@ -8,6 +8,7 @@ #include "bias_activation.h" #include "blas.h" +#include "cuda_linear_kernels.h" #include "gated_activation_kernels.h" #include "layer_norm.h" #include "rms_norm.h" @@ -33,4 +34,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // rms_norm.h m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA"); m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA"); + + // cuda_linear_kernels.h + m.def("cuda_wf6af16_linear", &cuda_wf6af16_linear, "DeepSpeed Wf6Af16 linear in CUDA"); + m.def( + "preprocess_weight", &preprocess_weight, "preprocess the FP16 weight to be 2bit and 4 bit"); } diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py new file mode 100644 index 000000000000..cd08409c0a7a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. 
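Illustrative sketch (not part of the patch): with the QuantizationConfig added to config_v2.py above, FP6 weight-only quantization can be requested through the ragged engine config. Assuming the usual pydantic-style dict initialization that DeepSpeedConfigModel supports:

from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig

# 'wf6af16' selects FP6 weights with FP16 activations, per the docstring above.
config = RaggedInferenceEngineConfig(quantization={"quantization_mode": "wf6af16"})
assert config.quantization.quantization_mode == "wf6af16"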
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .cuda_linear import * diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py new file mode 100644 index 000000000000..69aa9e8920e2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py @@ -0,0 +1,207 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch + +from ....inference_utils import DtypeEnum +from ....logging import inference_logger +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ... import DSKernelBase + + +class CUDAWf6Af16Linear(DSKernelBase): + """ + Wrapper around the CUDA kernel of Wf6Af16 quantized linear. + + Performs z = x @ y + """ + supported_dtypes = [DtypeEnum.fp16] + + def __init__(self): + self.inf_module = InferenceCoreBuilder().load() + self.inf_module.create_handle() + self.kernel = self.inf_module.cuda_wf6af16_linear + # The split_k_map is profiled on A100-80G GPU for some common shapes. + # It is an array of dictionaries, where the array index is the tokens chunk id. + # The dictionary is the mapping from the output channel to the split-K size. + self.split_k_map = [ + { # tokens: [1, 64] + 3072: 18, + 4096: 13, + 5120: 10, + 6144: 9, + 8192: 6, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 7 + }, + { # tokens: [65:128] + 3072: 9, + 4096: 6, + 5120: 5, + 6144: 9, + 8192: 3, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 6 + }, + { # tokens: [129:192] + 3072: 6, + 4096: 4, + 5120: 7, + 6144: 3, + 8192: 2, + 10240: 5, + 14336: 5, + 28672: 5, + 57344: 4 + }, + { # tokens: [193:256] + 3072: 9, + 4096: 3, + 5120: 5, + 6144: 2, + 8192: 5, + 10240: 4, + 14336: 8, + 28672: 6, + 57344: 4 + }, + { # tokens: [257:320] + 3072: 7, + 4096: 5, + 5120: 2, + 6144: 5, + 8192: 4, + 10240: 1, + 14336: 3, + 28672: 3, + 57344: 4 + }, + { # tokens: [321:384] + 3072: 3, + 4096: 2, + 5120: 5, + 6144: 3, + 8192: 1, + 10240: 8, + 14336: 3, + 28672: 4, + 57344: 3 + }, + { # tokens: [385:448] + 3072: 5, + 4096: 7, + 5120: 3, + 6144: 5, + 8192: 7, + 10240: 3, + 14336: 1, + 28672: 1, + 57344: 3 + }, + { # tokens: [449:512] + 3072: 2, + 4096: 5, + 5120: 4, + 6144: 1, + 8192: 5, + 10240: 2, + 14336: 6, + 28672: 4, + 57344: 1 + }, + { # tokens: [513:576] + 3072: 2, + 4096: 3, + 5120: 1, + 6144: 1, + 8192: 3, + 10240: 3, + 14336: 3, + 28672: 1, + 57344: 1 + }, + { # tokens: [577:640] + 3072: 5, + 4096: 4, + 5120: 1, + 6144: 4, + 8192: 2, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [641:704] + 3072: 3, + 4096: 1, + 5120: 2, + 6144: 2, + 8192: 1, + 10240: 2, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [705:768] + 3072: 3, + 4096: 1, + 5120: 3, + 6144: 2, + 8192: 1, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + } + ] + + def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2bit: torch.Tensor, + weights_4bit: torch.Tensor, scale: torch.Tensor, out_channels, tokens, in_channels) -> torch.Tensor: + """ + Matmul kernel of FP6 weight-only quantized linear. All inputs should be contiguous. + It does not support batched-matmul. + + Parameters: + output (torch.Tensor): Output tensor. Shape is of [token_number, out_features] + hidden_states (torch.Tensor): Input tensor. Shape is of [token_number, in_features] + weights_2bit (torch.Tensor): Input tensor of the 2-bit slice. Shape is of [out_features*2/8, in_features] + weights_4bit (torch.Tensor): Input tensor of the 4-bit slice. 
Shape is of [out_features*4/8, in_features] + scale (torch.Tensor): Input tensor. Shape is of [out_features], since the scale is per output channel + out_channels (int): The number of output channels + tokens (int): The number of tokens + in_channels (int): The number of input channels + """ + + if out_channels % 256 != 0 or in_channels % 64 != 0: + raise ValueError("The out and in channel should be multiple of 256 and 64 respectively.") + + # TODO: add a more general heuristic to determine the split-K. + split_k = -1 # not initialized + if tokens <= 768: + # Try to find the split-K from the pre-profiled map. + tokens_chunk_id = (tokens - 1) // 64 + split_k = self.split_k_map[tokens_chunk_id].get(out_channels, -1) + if split_k == -1: + split_k = 1 + inference_logger().warning( + f"The split-K setting may be suboptimal for shape {tokens}x{in_channels}x{out_channels}...") + + workspace = self.get_workspace(out_channels, tokens, in_channels, split_k, torch.float, hidden_states.device) + self.kernel(output, hidden_states, weights_2bit, weights_4bit, scale, workspace, out_channels, tokens, + in_channels, split_k) + + def get_workspace(self, out_channels: int, tokens: int, in_channels: int, split_k: int, dtype, + device) -> torch.Tensor: + """ + Allocate workspace for the kernel. The workspace is used to store the intermediate results of the matmul before + split-K. The split-K size is determined by the size of the matmul. + """ + workspace = torch.empty((split_k, out_channels, tokens), dtype=dtype, device=device) + + return workspace diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp new file mode 100644 index 000000000000..677bec22ded8 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp @@ -0,0 +1,224 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#include + +#include "cuda_linear_kernels.h" + +namespace { + +// For bit-level debugging. +template +void print_bits(T num) +{ + char bits[sizeof(T) * 8 + 1] = {'\0'}; + for (int bit = 0; bit < (sizeof(T) * 8); bit++) { + bits[sizeof(T) * 8 - 1 - bit] = '0' + (num & 0x01); + num = num >> 1; + } + printf("%s\n", bits); +} + +void print_bits(half num) +{ + char bits[sizeof(half) * 8 + 1] = {'\0'}; + auto int_num = *reinterpret_cast(&num); + for (int bit = 0; bit < (sizeof(half) * 8); bit++) { + bits[sizeof(half) * 8 - 1 - bit] = '0' + (int_num & 0x01); + int_num = int_num >> 1; + } + printf("%s\n", bits); +} + +/* + * Function to pack 4 fake quantized FP16 value into continuously stored 4 FP6 values. + */ +void cast_fp16_fp6(uint16_t* FP16x4, uint8_t* FP6x4) +{ + // Constants for FP6 + constexpr int exponent_nbits_fp6 = 3; + constexpr int mantissa_nbits_fp6 = 2; + constexpr int exp_bias_fp6 = (1 << (exponent_nbits_fp6 - 1)) - 1; + // Constants for FP16 + constexpr int exponent_nbits_fp16 = 5; + constexpr int mantissa_nbits_fp16 = 10; + constexpr int exp_bias_fp16 = (1 << (exponent_nbits_fp16 - 1)) - 1; + + int fp6_temp[4]; + + float absmin_nonzero_fp6 = 0.0625; + // Note that we regard the exponent of '111' as a regular value rather than NaN or inf. This is + // the same with that in qtorch. 
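+    // With 3 exponent bits (bias = 3) and 2 mantissa bits, the largest representable FP6
+    // magnitude is 1.75 * 2^(7-3) = 28 and the smallest nonzero (subnormal) magnitude is
+    // 0.25 * 2^(1-3) = 0.0625, which is where absmin_nonzero_fp6 above and absmax_fp6
+    // below come from.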
+ float absmax_fp6 = 28; + + for (int i = 0; i < 4; ++i) { + uint16_t source = FP16x4[i]; + float fp6_value_abs = std::abs(__half2float(*((half*)(&source)))); + if ((fp6_value_abs != 0 && fp6_value_abs < absmin_nonzero_fp6) || + fp6_value_abs > absmax_fp6) { + // TODO(zhen): a better way may be rounding it to the nearest FP6 value. + throw std::invalid_argument("Input value out of range for FP6."); + } + + // It is not safe to do shift operation on uint16_t. So we promote it to int. + int source_promote = int(source); + + int sign_bit = (source_promote >> 15); + // Extracting exponent represented in FP16. The sign mask 0x7FFF is '0111 1111 1111 1111' + int exp_bit = (source_promote & 0x7FFF) >> mantissa_nbits_fp16; + // Extracting mantissa represented in FP16 + int mant_bit = source_promote & ((1 << mantissa_nbits_fp16) - 1); + + int new_exp_bit; + int new_mant_bit; + + if (exp_bit == 0) { + // Subnormal FP16 number. Too small for FP6. + new_exp_bit = 0; + new_mant_bit = 0; + } else { + new_mant_bit = mant_bit >> (mantissa_nbits_fp16 - mantissa_nbits_fp6); + new_exp_bit = exp_bit - exp_bias_fp16 + exp_bias_fp6; + + // Deal with subnormal FP6 values. + int target_exp_val = exp_bit - exp_bias_fp16; + int min_fp6_exp_val = -exp_bias_fp6 + 1; + bool subnormal_fp6 = target_exp_val < min_fp6_exp_val; + if (subnormal_fp6) { + // TODO(zhen): add the rounding logic. + new_exp_bit = 0; + // The implicit 1 in the mantissa of FP16 is not present in subnormal FP6. Thus we + // need to add it + new_mant_bit = (new_mant_bit | (1 << mantissa_nbits_fp6)) >> + (min_fp6_exp_val - target_exp_val); + } + } + + fp6_temp[i] = (sign_bit << (exponent_nbits_fp6 + mantissa_nbits_fp6)) | + (new_exp_bit << mantissa_nbits_fp6) | new_mant_bit; + } + // Pack the values + FP6x4[0] = fp6_temp[0] << 2 | (fp6_temp[1] >> 4); + FP6x4[1] = (fp6_temp[1] & 0x0F) << 4 | (fp6_temp[2] >> 2); + FP6x4[2] = (fp6_temp[2] & 0x03) << 6 | fp6_temp[3]; +} + +/* + * Function to prepack FP16 weights into continuous FP6 values. + * + * Parameters: + * weight_16bit: input weight in FP16, size M*K + * weight_6bit: output weight in packed FP6, continuously stored, size M*K*6/8 + * M, K: the shape of the weight + */ +void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, + uint8_t* weight_6bit_packed, + size_t M, + size_t K) +{ + // Every four 16-bit elements are packed into three 6-bit values (4*6bit == 3*8bit). + if (K * 6 % 8 != 0) { throw std::invalid_argument("(K * 6 % 8) should be 0"); } + size_t K_fp6_packed = K * 6 / 8; + // #pragma omp parallel for + for (auto m = 0; m < M; m++) { + uint8_t* ptr_6bit = weight_6bit_packed + m * K_fp6_packed; + uint16_t* ptr_16bit = weight_16bit + m * K; + for (auto k = 0; k < K; k += 4) { + cast_fp16_fp6(ptr_16bit, ptr_6bit); + ptr_16bit += 4; + ptr_6bit += 3; + } + } +} + +} // namespace + +/* + * Function to execute the FP6 linear kernel. 
+ * + * Parameters: + * output: output tensor, size M*N + * hidden_states: input activation tensor, size N*K + * weights_2bit: packed 2bit weights, size M*K*2/8 + * weights_4bit: packed 4bit weights, size M*K*4/8 + * scales: scale tensor, size M + * workspace: workspace tensor, size M*N*split_k + * M: the output channel number of the weight + * N: the token number of the activation + * K: the input channel number of the weight + * split_k: the split size of the GEMM calculation + */ +void cuda_wf6af16_linear(torch::Tensor& output, + torch::Tensor& hidden_states, + torch::Tensor& weights_2bit, + torch::Tensor& weights_4bit, + torch::Tensor& scales, + torch::Tensor& workspace, + int M, + int N, + int K, + int split_k) +{ + TORCH_CHECK(weights_2bit.device().type() == torch::kCUDA, "weight_2bit must be on CUDA"); + TORCH_CHECK(weights_4bit.device().type() == torch::kCUDA, "weight_4bit must be on CUDA"); + TORCH_CHECK(hidden_states.device().type() == torch::kCUDA, "X must be on CUDA"); + TORCH_CHECK(scales.device().type() == torch::kCUDA, "scales must be on CUDA"); + + auto status = fp6_linear_kernel(at::cuda::getCurrentCUDAStream(), + (uint4*)(weights_2bit.data_ptr()), + (uint4*)(weights_4bit.data_ptr()), + (half*)(scales.data_ptr()), + (half*)(hidden_states.data_ptr()), + (half*)(output.data_ptr()), + M, + N, + K, + workspace.data_ptr(), + split_k); + if (status != cudaSuccess) { + AT_ERROR("fp6_linear_kernel failed with error: ", cudaGetErrorString(status)); + } +} + +/* + * Function to prepack the fake 6-bit-quantized FP16 weights into 2bit and 4bit. + * + * Parameters: + * weight: input weight in FP16 (containing the quantized FP6-ranged value), size M*K + * Returns: + * weight_2bit: output weight in 2bit, size M*K*2/8 + * weight_4bit: output weight in 4bit, size M*K*4/8 + */ +std::vector preprocess_weight(torch::Tensor& weight) +{ + TORCH_CHECK(weight.dim() == 2, "weight must be 2-dimensional"); + TORCH_CHECK(weight.scalar_type() == torch::kFloat16, "weight must be FP16"); + TORCH_CHECK(weight.is_contiguous(), "weight must be contiguous"); + TORCH_CHECK(weight.device().type() == torch::kCPU, "weight must be on CPU"); + auto M = weight.size(0); + auto K = weight.size(1); + TORCH_CHECK(K % 4 == 0, "K must be multiple of 4"); + + // Pack weight from FP16 to FP6. + uint16_t* weight_16bit_ptr = reinterpret_cast(weight.data_ptr()); + std::vector weight_6bit_packed(M * K * 6 / 8); + uint8_t* weight_6bit_ptr = weight_6bit_packed.data(); + weight_prepacking_fp16_to_fp6(weight_16bit_ptr, weight_6bit_ptr, M, K); + + // Split weight into 2bit and 4bit. + weight_matrix_prepacking(reinterpret_cast(weight_6bit_ptr), M, K); + uint8_t* weight_2bit_ptr = weight_6bit_ptr; + + // Make sure that the new split tensor does not share the underlying memory with the original + // one. Otherwise it will incur some problems when the original tensor is deleted. It also + // makes the memory flattern risky. 
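+    // After weight_matrix_prepacking(), the packed buffer stores all 2-bit fragments first
+    // (M * K * 2 / 8 bytes) followed by all 4-bit fragments (M * K * 4 / 8 bytes), so the
+    // two output tensors below are plain offsets into the same contiguous buffer.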
+ auto weight_2bit = + torch::from_blob(weight_2bit_ptr, {M * K * 2 / 8}, torch::kUInt8).clone().detach(); + uint8_t* weight_4bit_ptr = weight_2bit_ptr + M * K * 2 / 8; + auto weight_4bit = + torch::from_blob(weight_4bit_ptr, {M * K * 4 / 8}, torch::kUInt8).clone().detach(); + + return {weight_2bit, weight_4bit}; +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h new file mode 100644 index 000000000000..0f5882d519ca --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#pragma once + +#include +#include +#include "ds_kernel_utils.h" + +#include "fp6_linear.cuh" + +void cuda_wf6af16_linear(torch::Tensor& output, + torch::Tensor& hidden_states, + torch::Tensor& weights_2bit, + torch::Tensor& weights_4bit, + torch::Tensor& scale, + torch::Tensor& workspace, + int M, + int N, + int K, + int split_k); + +std::vector preprocess_weight(torch::Tensor& Weight); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu new file mode 100644 index 000000000000..64e06a5435c6 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu @@ -0,0 +1,315 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +// clang-format off +// Put the torch headers at the front to avoid conflict with other headers on +// `at::nullopt` and `at::optional`. +#include +#include +// clang-format on + +#include "include/kernel_matmul.cuh" +#include "include/kernel_reduction.cuh" +#include "include/weight_prepacking.h" + +#include +#include + +template +static void Kernel_Ex(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + OutputDataType* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) +{ +#ifdef DEBUG_MODE + printf("\n"); + printf("Launcher.cu->Kernel_Ex():\n"); + printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); + printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", + TilingConfig::TILE_M, + TilingConfig::TILE_K, + TilingConfig::TILE_N); +#endif + static size_t SHMEM_SZ = + max(TilingConfig::SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, + TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + SHMEM_SZ); + size_t dimN = (N_Global - 1) / TilingConfig::TILE_N + 1; + size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; + dim3 GridDim(dimN, dimM, 1); + dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1); +// +#ifdef DEBUG_MODE + printf( + "GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: " + "%d SHMEM_SZ: %d\n", + GridDim.x, + GridDim.y, + GridDim.z, + BlockDim.x, + BlockDim.y, + BlockDim.z, + SHMEM_SZ); + printf("\n"); +#endif + QUANT_GEMM_Kernel<<>>( + Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); +} + +/* + * + */ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t 
K_Global, + float* Reduction_Workspace, // Reduction_Workspace_Size = Split_K * + // M_Global * N_Global * sizeof(fp32) + int Split_K) +{ + assert(M_Global % 256 == 0); + assert(K_Global % 64 == 0); + assert(N_Global > 0); + + // Work around to support more N shapes: + size_t N_PowerOf2; + if (N_Global > 0 && N_Global <= 8) N_PowerOf2 = 8; + if (N_Global > 8 && N_Global <= 16) N_PowerOf2 = 16; + if (N_Global > 16 && N_Global <= 32) N_PowerOf2 = 32; + if (N_Global > 32 && N_Global <= 64) N_PowerOf2 = 64; + if (N_Global > 64 && N_Global <= 128) N_PowerOf2 = 128; + if (N_Global > 128) N_PowerOf2 = ((N_Global - 1) / 128 + 1) * 128; + + if (Split_K == 1) { + switch (N_PowerOf2) { + case 8: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 16: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 32: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 64: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + case 128: + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, half>( + stream, Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K); + break; + } + } else { + switch (N_PowerOf2) { + case 8: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 16: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 32: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 64: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + case 128: + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + default: + if (N_PowerOf2 % 128 != 0) { + printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); + return cudaErrorUnknown; + } + Kernel_Ex, float>(stream, + Weight1, + Weight2, + Scales, + B, + Reduction_Workspace, + M_Global, + N_Global, + K_Global, + Split_K); + break; + } + // Reduction for SplitK + dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); + dim3 BlockDim(WARP_SIZE, 1, 1); + SplitK_Reduction<<>>( + C, Reduction_Workspace, M_Global, N_Global, Split_K); + } + return cudaGetLastError(); +} + +/* +Computes FP6-FP16 GEMM (PyTorch interface). + +[Mathematical Formula] +Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in +row-major. After Equivalent transformation : trans(Out) = W * trans(In). Note that we do not +perform "transpose" during runtime, we instead interpret the In/Out as column-major matrices when +calling our CUDA kernel. + +[Inputs] + _in_feats: tensor of shape [B, IC]; // half + _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. 
+ _scales: tensor of shape [OC]; // half + splitK: splitting the MatMul problem along K dimension for higher GPU utilization, default 1. +[Outputs] + _out_feats: tensor of shape [B, OC]; // half +*/ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int splitK = 1) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + int num_out_channels = _weights.size(0); + assert(num_in_channels % 64 == 0); + assert((num_in_channels / 16 * 3) == + _weights.size(1)); // Making sure the K dimension is matched. + // + int M = num_out_channels; + int K = num_in_channels; + int N = num_in_feats; + // Input Tensors + auto weight1 = reinterpret_cast( + _weights.data_ptr()); // weights is [OC, IC] but in FP6. + auto weight2 = weight1 + num_in_channels * num_out_channels * 2 / 128; + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto scales = reinterpret_cast(_scales.data_ptr()); + // Output Tensors + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({num_in_feats, num_out_channels}, options); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + + float* Reduction_Workspace = nullptr; + if (splitK != 1) { + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); + at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); + auto Reduction_Workspace = reinterpret_cast( + _out_feats.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * + // N_Global * sizeof(fp32) + } + + fp6_linear_kernel(0, // Using default stream here. + weight1, + weight2, + scales, + in_feats, + out_feats, + M, + N, + K, + Reduction_Workspace, + splitK); + + return _out_feats; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ + +/* + * Weight prepacking (Pytorch interface). + * [Input & Output] + * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + * [Output] + * packed_tensor: int tensor of shape [OC, IC // 16 * 3]; + */ +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t OC, size_t IC) +{ + assert((OC % 256 == 0) && (IC % 64 == 0)); + assert((fp6_tensor.size(0) == OC) && (fp6_tensor.size(1) == IC / 16 * 3)); + // auto packed_tensor = torch::empty_like(fp6_tensor); + // auto packed_tensor_ptr = reinterpret_cast(packed_tensor.data_ptr()); + auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); + weight_matrix_prepacking(fp6_tensor_ptr, OC, IC); + return fp6_tensor; +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh new file mode 100644 index 000000000000..95f7f6050c15 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cuh @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#include +#include +#include + +#include + +/* + * Computes FP6-FP16 GEMM (C++ interface). 
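+ * Constraints (asserted in the launcher): M_Global % 256 == 0, K_Global % 64 == 0, N_Global > 0.
+ * When Split_K > 1, Reduction_Workspace must provide Split_K * M_Global * N_Global float32 entries.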
+ */ +cudaError_t fp6_linear_kernel(cudaStream_t stream, + const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + half* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + float* Reduction_Workspace, // Reduction_Workspace_Size = Split_K * + // M_Global * N_Global * sizeof(fp32) + int Split_K); + +/* + * Computes FP6-FP16 GEMM (PyTorch interface). + */ +torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int splitK = 1); + +/* + * In-place weight prepacking (C++ interface). + */ +void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K); + +/* + * Weight prepacking (Pytorch interface). + */ +torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t M, size_t K); diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h new file mode 100644 index 000000000000..76e8eda2d35e --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/configs.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef CONFIGS_H +#define CONFIGS_H + +// #define DEBUG_MODE +#define PIPELINE_LEVEL_GMEM 2 +#define PIPELINE_LEVEL_SMEM 2 // only support 2 + +/************************ Hardware Parameters ************************/ +#define WARP_SIZE 32 +#define REG_BIT_WIDTH 32 +// mma: M=16 K=16 N=8 +#define MMA_8 8 +#define MMA_16 16 +// for memory access +#define THREAD_OPT_ACCESS_BIT_WIDTH_128 128 // LDS.128, cp_async.128, ... +#define BIT_WIDTH_PER_HALF 16 // Half precision: FP16 + +/******************** Register Allocation For GEMM ********************/ +#define REG_PER_THREAD_C_TENSOR_16_16 8 // 8 for FP32 Accumulation +/********************** Memory Padding Parameters **********************/ +// Eliminating bank-conflict +#define PADDING_BYTES_16 16 // Padding 16 bytes each column +#define PADDING_SHARED_MEM_FOR_B_8 \ + 8 // Padding 8 half each column, during CopyFromGlobalToShared() for B +#define PADDING_SHARED_MEM_FOR_C_4 \ + 4 // Padding 4 float each column, during StoreToSharedMemoryFromRegister() for C +/************************* WARP Tiling part-1 *************************/ +#define WARP_ROW_MMA_TENSORS 4 +#define WARP_M (WARP_ROW_MMA_TENSORS * MMA_16) // 64 +#define WARP_K_MMA_TENSORS 4 +#define WARP_K (WARP_K_MMA_TENSORS * MMA_16) // 64 +template +struct TilingConfig { + // Depending on "n" dimension of the GEMM + static constexpr int BLOCK_ROW_WARPS = BLOCK_ROW_WARPS_; + static constexpr int BLOCK_COL_WARPS = BLOCK_COL_WARPS_; + static constexpr int WARP_COL_MMA_TENSORS = WARP_COL_MMA_TENSORS_; + /************************* WARP Tiling part-2 *************************/ + static constexpr int WARP_N = WARP_COL_MMA_TENSORS * MMA_8; + /*************************Thread Block Tiling *************************/ + static constexpr int TILE_M = WARP_M * BLOCK_ROW_WARPS; + static constexpr int TILE_N = MMA_8 * WARP_COL_MMA_TENSORS * BLOCK_COL_WARPS; + static constexpr int TILE_K = WARP_K; + /********************** #Thread per Thread Block **********************/ + static constexpr int BLOCK_WARPS = BLOCK_ROW_WARPS * BLOCK_COL_WARPS; + static constexpr int BLOCK_THREADS = BLOCK_WARPS * WARP_SIZE; + /******************************* Others *******************************/ + static constexpr int 
SMEM_SIZE_B_TILE = TILE_N * (TILE_K + PADDING_BYTES_16) * 2 * + PIPELINE_LEVEL_GMEM; // sizeof(half)=2, doubleBuffer=2 + static constexpr int SMEM_SIZE_C_TILE = + TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 +}; + +/************************ General Config for Quant-LLM **********************/ +#define WEIGHT_FRAG1_BIT_WIDTH 2 +#define WEIGHT_FRAG2_BIT_WIDTH 4 +#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH + WEIGHT_FRAG2_BIT_WIDTH) // 6 +// #define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // +// QuantGroupSize: 4*64 = 256 +/*************************** 64*64 Weghts of A WARP *************************/ +#define WEIGHT_PER_UNIT (WARP_M * WARP_K) // 64*64 +#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 \ + (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / \ + 8) // 1024 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 \ + (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / \ + 8) // 2048 Bytes #doubleBuffer not takedn into consideration +#define SMEM_SIZE_A1_TILE \ + (SMEM_SIZE_IN_BYTES_PER_WARP_A1 * 4 * \ + PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 12 KB; double + // buffer for 2-level pipeline A= 8 KB. +#define SMEM_SIZE_A2_TILE \ + (SMEM_SIZE_IN_BYTES_PER_WARP_A2 * 4 * \ + PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 24 KB; double + // buffer for 2-level pipeline A= 16 KB. +/******************** Global Memory Layout For QUANTIZED DATA ******************/ +#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT * WEIGHT_FRAG1_BIT_WIDTH / 128) // 64 +#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT * WEIGHT_FRAG2_BIT_WIDTH / 128) // 128 +/******************** Register Allocation For QUANTIZED DATA ******************/ +#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT / WARP_SIZE) // 128 +#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 2) // 8 +#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD / REG_BIT_WIDTH * 4) // 16 +/******************** Register Allocation For QUANT Scales ******************/ +#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers +#define WARP_REG_QUANT_SCALE_DISTRIBUTED \ + 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for + // each thread + +#endif // CONFIGS_H diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh new file mode 100644 index 000000000000..aa6ea6c4b1c2 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_matmul.cuh @@ -0,0 +1,261 @@ +// Copyright (c) Microsoft Corporation. 
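Illustrative sketch (not part of the patch): the tile and shared-memory sizes defined by configs.h reduce to a few lines of arithmetic. A minimal Python rendering, assuming the macro values above (WARP_M = WARP_K = 64, MMA_8 = 8, PADDING_BYTES_16 = 16, PIPELINE_LEVEL_GMEM = 2):

def tiling_sizes(block_row_warps: int, block_col_warps: int, warp_col_mma_tensors: int):
    WARP_M = WARP_K = 64
    MMA_8, PADDING, PIPELINE_GMEM = 8, 16, 2
    tile_m = WARP_M * block_row_warps                         # TILE_M
    tile_n = MMA_8 * warp_col_mma_tensors * block_col_warps   # TILE_N
    tile_k = WARP_K                                           # TILE_K
    smem_b = tile_n * (tile_k + PADDING) * 2 * PIPELINE_GMEM  # SMEM_SIZE_B_TILE, sizeof(half)=2
    smem_c = tile_n * (tile_m + PADDING) * 4                  # SMEM_SIZE_C_TILE, sizeof(float)=4
    return tile_m, tile_n, tile_k, smem_b, smem_c

For example, tiling_sizes(4, 1, 8) gives TILE_M = 256, TILE_N = 64, TILE_K = 64; the launcher then requests max(SMEM_SIZE_B_TILE + SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE, SMEM_SIZE_C_TILE) bytes of dynamic shared memory.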
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#include "configs.h" +#include "utils_core.cuh" +#include "utils_gmem.cuh" + +/* + * C = A*B + * A: row major with ahead-of-time layout transformation, FP6 + * B: col major, FP16 + * C: col major, FP16 + */ +template +__global__ void QUANT_GEMM_Kernel(const uint4* Weight1, + const uint4* Weight2, + const half* Scales, + const half* B, + OutputDataType* C, + const size_t M_Global, + const size_t N_Global, + const size_t K_Global, + int Split_K) +{ +#ifdef DEBUG_MODE + assert(K_Global % TilingConfig::TILE_K == 0); + assert(M_Global % TilingConfig::TILE_M == 0); + assert(gridDim.y == Split_K * (M_Global / TilingConfig::TILE_M)); +#endif + extern __shared__ __align__(128) + half smem[]; // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned + half(*smem_array)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + reinterpret_cast( + smem + + (SMEM_SIZE_A1_TILE + SMEM_SIZE_A2_TILE) / 2); // Dynamic shared memory for FP16 B tiles + __shared__ half QuantScales[64 * TilingConfig::BLOCK_WARPS]; // static shared memory for + // quantization scales, 64 row per + // warp * 4 warps = 512 Bytes + // Thread Block Mapping, considering SplitK + const size_t BatchID = blockIdx.y / (M_Global / TilingConfig::TILE_M); + const size_t x = blockIdx.x; // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t y = + blockIdx.y % + (M_Global / TilingConfig::TILE_M); // Output Block ID: (BlockID_Row = y; BlockID_Col = x ) + const size_t Tile_Start_M = y * TilingConfig::TILE_M; + const size_t Tile_Start_N = x * TilingConfig::TILE_N; + const size_t NumColumnToCopy = (N_Global - Tile_Start_N) < TilingConfig::TILE_N + ? (N_Global - Tile_Start_N) + : TilingConfig::TILE_N; + const size_t NumBlock_K = K_Global / TilingConfig::TILE_K; + const size_t AverageNumBlock_K = NumBlock_K / Split_K; + const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; + size_t NumIter = AverageNumBlock_K; + if (BatchID < ExtraNumBlock_K) NumIter++; + size_t StartBlockID_K = AverageNumBlock_K * BatchID; + if (BatchID < ExtraNumBlock_K) + StartBlockID_K += BatchID; + else + StartBlockID_K += ExtraNumBlock_K; + // Warp ID. + const int warpId = threadIdx.x / WARP_SIZE; + int WARP_i = + warpId / TilingConfig::BLOCK_COL_WARPS; // WARP_i: row number; WARP_j: column number + // int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + // Global Memory Address for Matrix A (Weight) + // ///////////////////////////////////////////////////////////////////////// StartPTR for each + // ThreadBlock(TB) + const uint4* TB_StartGPTR_A1 = + Weight1 + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + const uint4* TB_StartGPTR_A2 = + Weight2 + (y * TilingConfig::BLOCK_ROW_WARPS) * NumBlock_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // StartPTR for each WARP. 
+ const uint4* WARP_StartGPTR_A1 = + TB_StartGPTR_A1 + WARP_i * NumBlock_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + const uint4* WARP_StartGPTR_A2 = + TB_StartGPTR_A2 + WARP_i * NumBlock_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // StartPTR for each WARP, considering SplitK + const size_t WARP_Start_UnitID_K = StartBlockID_K; + WARP_StartGPTR_A1 += WARP_Start_UnitID_K * NUM_INT4_PER_UNIT_2BIT_FRAG; + WARP_StartGPTR_A2 += WARP_Start_UnitID_K * NUM_INT4_PER_UNIT_4BIT_FRAG; + // Copying A tile from Global to Shared, using double-buffer + // ////////////////////////////////////////////////////////// StartSPTR for each ThreadBlock + uint32_t* AFrag_2BIT_SPTR = reinterpret_cast(smem); + uint32_t* AFrag_4BIT_SPTR = + AFrag_2BIT_SPTR + + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * TilingConfig::BLOCK_WARPS * + PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers + // StartSPTR for each WARP + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4; + // Pre-fetch of A tile + for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) { + CopyFromGlobalToShared_A( + AFrag_2BIT_SPTR + i * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * 4, WARP_StartGPTR_A1); + CopyFromGlobalToShared_A( + AFrag_4BIT_SPTR + i * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * 4, WARP_StartGPTR_A2); + WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; + WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; + } + // Global Memory Address for Matrix A (QuantScale) + // ///////////////////////////////////////////////////////////////////// + const half* TB_StartGPTR_A_Scale = Scales + (y * TilingConfig::BLOCK_ROW_WARPS) * 64; + const half* WARP_StartGPTR_A_Scales = TB_StartGPTR_A_Scale + WARP_i * 64; + CopyFromGlobalToShared_Scales(QuantScales + WARP_i * 64, WARP_StartGPTR_A_Scales); + // Copying B tile from Global to Shared, considering SplitK + // ///////////////////////////////////////////////////////////// + const half* BTile_GPTR = B + Tile_Start_N * K_Global + StartBlockID_K * TilingConfig::TILE_K; + for (int i = 0; i < PIPELINE_LEVEL_GMEM - 1; i++) { + CopyFromGlobalToShared( + smem_array + i * TilingConfig::TILE_N, BTile_GPTR, K_Global, NumColumnToCopy); + BTile_GPTR += TilingConfig::TILE_K; + } + // Register Allocation for A,B, and C, Initilazed to Zeros + // ///////////////////////////////////////////////////////////////////// + constexpr int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 
1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block +#ifdef PIPELINE_LEVEL_SMEM + uint32_t a[NumRegSets_a * PIPELINE_LEVEL_SMEM] + [4]; // double/Trible buffer is used // Registers to store decompressed FP6 + uint32_t b[NumRegSets_b * PIPELINE_LEVEL_SMEM] + [4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) +#endif + float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; + for (int i = 0; i < NumRegSets_a * NumRegSets_b; i++) + for (int j = 0; j < REG_PER_THREAD_C_TENSOR_16_16; j++) c[i][j] = 0.0f; + // + cp_async_wait_all(); + __syncthreads(); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + uint32_t Scales_RPTR[4]; // 4 Registers per thread for Quantization Scales + ExtractFromSharedToReg_Scales(Scales_RPTR, QuantScales + WARP_i * 64); +#ifdef PIPELINE_LEVEL_SMEM + // Initializing the Software Pipeline: writing registers. + // //////////////////////////////////////////////////////////////////////////////////////////////// + initialize_mma_slice( + a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); +#endif +// The outer loop. +// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +#pragma unroll(1) + for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) { + // Trible-Buffer for A Tile + uint32_t* __restrict__ read_SPTR_Frag1 = + AFrag_2BIT_SPTR + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * + 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag2 = + AFrag_4BIT_SPTR + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * + 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 +#ifdef PIPELINE_LEVEL_SMEM + uint32_t* __restrict__ read2_SPTR_Frag1 = + AFrag_2BIT_SPTR + + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * 4; + uint32_t* __restrict__ read2_SPTR_Frag2 = + AFrag_4BIT_SPTR + + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * 4; +#endif + uint32_t* __restrict__ write_SPTR_Frag1 = + AFrag_2BIT_SPTR + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 4 * + 4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag2 = + AFrag_4BIT_SPTR + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 4 * + 4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + // Trible-Buffer for B Tile + half __restrict__(*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + ((tile_id_k + 0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#ifdef PIPELINE_LEVEL_SMEM + half __restrict__(*read2_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + ((tile_id_k + 1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; +#endif + half __restrict__(*write_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8] = + smem_array + + ((tile_id_k + (PIPELINE_LEVEL_GMEM - 1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; + // + bool GlobalCopy = (tile_id_k + PIPELINE_LEVEL_GMEM - 1) < NumIter; + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + CopyFromGlobalToShared_A( + write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); + CopyFromGlobalToShared_A( + write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); + // copying B tile from 
GlobalMemory to SharedMemory + CopyFromGlobalToShared( + write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); + cp_async_group_commit(); +#ifdef PIPELINE_LEVEL_SMEM + core_mma_slice(c, + a, + b, + read_SPTR_Frag1, + read_SPTR_Frag2, + read_SPTR, + Scales_RPTR, + 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each + // WARP; read_SPTR is shared among WARPs + core_mma_slice( + c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); + core_mma_slice( + c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); + core_mma_slice( + c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 0); + // Updating global PTRs + WARP_StartGPTR_A1 += + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; +#else + PipelinedCoreLoop( + c, + read_SPTR, + read_SPTR_Frag1, + read_SPTR_Frag2, + Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; + // read_SPTR is shared among WARPs + // Updating global PTRs + WARP_StartGPTR_A1 += + SMEM_SIZE_IN_BYTES_PER_WARP_A1 / 16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A2 += + SMEM_SIZE_IN_BYTES_PER_WARP_A2 / 16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + BTile_GPTR += TilingConfig::TILE_K; + // Barriers and Synchronizations + cp_async_wait_group(); + __syncthreads(); +#endif + } + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Store the C fragments to shared memory. + float(*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4] = + reinterpret_cast(smem); + StoreToSharedMemoryFromRegister(smem_CFrag, c); + __syncthreads(); + // Now that shared memory contains all the D tiles, stream them to global memory. + OutputDataType* BlockGlobalPTR = + C + BatchID * (M_Global * N_Global) + Tile_Start_M + Tile_Start_N * M_Global; + for (size_t i = warpId; i < NumColumnToCopy; i += TilingConfig::BLOCK_WARPS) // i-th column +#pragma unroll + for (size_t j = threadIdx.x % WARP_SIZE; j < TilingConfig::TILE_M; + j += WARP_SIZE) // j-th row + { + if constexpr (std::is_same::value) + BlockGlobalPTR[j + i * M_Global] = __float2half_rn(smem_CFrag[i][j]); + else + BlockGlobalPTR[j + i * M_Global] = smem_CFrag[i][j]; + } +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh new file mode 100644 index 000000000000..8c49f8b0b3a5 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/kernel_reduction.cuh @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. 
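Illustrative sketch (not part of the patch): the launch geometry produced by Kernel_Ex() and consumed by QUANT_GEMM_Kernel above is one thread block per (N tile, M tile, split-K slice). A minimal rendering of the grid computation:

def launch_geometry(m_global: int, n_global: int, split_k: int, tile_m: int, tile_n: int):
    # Mirrors Kernel_Ex(): grid.x walks the N tiles, grid.y packs M tiles * split-K slices.
    grid_x = (n_global - 1) // tile_n + 1
    grid_y = m_global * split_k // tile_m
    return grid_x, grid_y

With Split_K == 1 each block writes half results directly into C; with Split_K > 1 each slice writes float32 partial results into its own M_Global * N_Global chunk of the reduction workspace, which the reduction kernel below sums into C.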
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#include +#include +#include + +#define REDUCTION_ELEMENT_PER_THREADBLOCK 256 +#define HALF_PER_128BIT 8 + +__global__ void SplitK_Reduction(half* C, + float* Reduction_Workspace, + size_t M_Global, + size_t N_Global, + int Split_K) +{ + half* WARP_GPTR_C = C + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + float* WARP_GPTR_R = Reduction_Workspace + REDUCTION_ELEMENT_PER_THREADBLOCK * blockIdx.x; + half* THREAD_GPTR_C = WARP_GPTR_C + threadIdx.x * HALF_PER_128BIT; + float* THREAD_GPTR_R = WARP_GPTR_R + threadIdx.x * HALF_PER_128BIT; + // Initializing Thread-Local Results + float Results[HALF_PER_128BIT]; +#pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) Results[i] = 0.0f; + // Reduction + for (int i = 0; i < Split_K; i++) { +#pragma unroll + for (int j = 0; j < HALF_PER_128BIT; j++) Results[j] += THREAD_GPTR_R[j]; + THREAD_GPTR_R += M_Global * N_Global; + } +// Writing to global memory +#pragma unroll + for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh new file mode 100644 index 000000000000..7f36cfd5d961 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_cp.async.cuh @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef PTX_CP_ASYNC_CUH +#define PTX_CP_ASYNC_CUH + +#include +#include +#include + +template +__device__ __forceinline__ void cp_async(half* smem_ptr, + const half* global_ptr, + bool pred_guard = true) +{ + static_assert(SizeInBytes == 16, "Size is not supported"); + unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr); + asm volatile( + "{ \n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred_guard), + "r"(smem_int_ptr), + "l"(global_ptr), + "n"(SizeInBytes)); +} + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +__device__ __forceinline__ void cp_async_group_commit() +{ + asm volatile("cp.async.commit_group;\n" ::); +} + +/// Blocks until all but previous cp.async.commit_group operations have committed. +template +__device__ __forceinline__ void cp_async_wait_group() +{ + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +} + +/// Blocks until all previous cp.async.commit_group operations have committed. +// cp.async.wait_all is equivalent to : +// cp.async.commit_group; +// cp.async.wait_group 0; +__device__ __forceinline__ void cp_async_wait_all() { asm volatile("cp.async.wait_all;\n" ::); } + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh new file mode 100644 index 000000000000..f13abe036279 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/ptx_mma.cuh @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. 
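Illustrative sketch (not part of the patch): numerically, SplitK_Reduction above is a sum over the split-K slices followed by a round-to-nearest cast to half, so a reference check can be written as:

import torch

def splitk_reduce_reference(workspace: torch.Tensor) -> torch.Tensor:
    # workspace: [split_k, M_Global * N_Global] float32 partial results,
    # flattened in the same element order as the final C tensor.
    return workspace.sum(dim=0).to(torch.float16)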
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef PTX_MMA_CUH +#define PTX_MMA_CUH + +#include +#include +#include + +#include +#include "configs.h" + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void B_FromSharedToReg( + uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + int slice_id) +{ +#ifdef DEBUG_MODE + static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0)); +#endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * + WARP_j; // each warp may start from reading warp_start_col'th column of + // the B tile in shared memory +#ifdef DEBUG_MODE + assert(warp_start_col == 0); +#endif + + int col = (lane_id % 8) + (lane_id / 16) * 8; + int row = (lane_id % 16) / 8 * 8; + uint32_t smem_local_ptr = static_cast( + __cvta_generic_to_shared(&read_SPTR[warp_start_col + col][slice_id * MMA_16 + row])); + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } else { +#pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +} +#else +// Debug: Whether ldmatrix.trans is required??? 
+// B is in column-major +template +__device__ __forceinline__ void B_FromSharedToReg( + uint32_t __restrict__ Reg[][4], + half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + int k_offset) +{ +#ifdef DEBUG_MODE + static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0)); +#endif + + const int warpId = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; + int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * + WARP_j; // each warp may start from reading warp_start_col'th column of + // the B tile in shared memory +#ifdef DEBUG_MODE + assert(warp_start_col == 0); +#endif + + int col = (lane_id % 8) + (lane_id / 16) * 8; + int row = (lane_id % 16) / 8 * 8; + uint32_t smem_local_ptr = static_cast( + __cvta_generic_to_shared(&read_SPTR[warp_start_col + col][k_offset + row])); + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(Reg[0][0]), "=r"(Reg[0][1]) + : "r"(smem_local_ptr)); + } else { +#pragma unroll + for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS / 2; i++) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) + : "r"(smem_local_ptr)); + smem_local_ptr += 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); + } + } +} +#endif + +__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[], + uint32_t __restrict__* a, + uint32_t __restrict__* b) +{ + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{ %0, %1, %2, %3}," + "{ %4, %5, %6, %7 }," + "{ %8, %9 }," + "{ %10, %11, %12, %13 };" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), + "r"(a[1]), + "r"(a[2]), + "r"(a[3]), + "r"(b[0]), + "r"(b[1]), + "r"(c[0]), + "r"(c[1]), + "r"(c[2]), + "r"(c[3])); +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh new file mode 100644 index 000000000000..713cebc57e33 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_core.cuh @@ -0,0 +1,246 @@ +// Copyright (c) Microsoft Corporation. 
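Illustrative sketch (not part of the patch): ignoring the per-thread fragment layouts that the ldmatrix/mma wrappers above take care of, each MMA_FP16_M16N8K16 call performs a 16x8x16 multiply-accumulate per warp, roughly:

import torch

def mma_m16n8k16_reference(a_f16: torch.Tensor, b_f16: torch.Tensor, c_f32: torch.Tensor) -> torch.Tensor:
    # a: [16, 16] fp16 tile, b: [16, 8] fp16 tile, c: [16, 8] fp32 accumulator.
    assert a_f16.shape == (16, 16) and b_f16.shape == (16, 8) and c_f32.shape == (16, 8)
    return a_f16.float() @ b_f16.float() + c_f32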
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef UTILS_CORE_CUH +#define UTILS_CORE_CUH + +#include + +#include "configs.h" +#include "ptx_mma.cuh" +#include "utils_paralleldequant.cuh" + +#ifdef PIPELINE_LEVEL_SMEM +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], + uint32_t* SPTR, + int slice_id) +{ + SPTR += slice_id * (NUM_INT_PER_THREAD * WARP_SIZE); + int lane_id = threadIdx.x % WARP_SIZE; +#pragma unroll + for (int i = 0; i < NUM_INT_PER_THREAD; i++) { Reg[i] = SPTR[lane_id + i * WARP_SIZE]; } +} + +template +__device__ __forceinline__ void initialize_mma_slice( + uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales) +{ + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2>(a_1, A1_SPTR_read, 0); + CopyFromSharedToRegister_AFrag<4>(a_2, A2_SPTR_read, 0); + Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at + // register level, dequantizing a slice each time + B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers +} + +template +__device__ __forceinline__ void core_mma_slice( + float c[][REG_PER_THREAD_C_TENSOR_16_16], + uint32_t (*a)[4], + uint32_t (*b)[4], + uint32_t* __restrict__ A1_SPTR_read, + uint32_t* __restrict__ A2_SPTR_read, + half __restrict__ (*B_SPTR_read)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* RPTR_Scales, + int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching +{ +#ifdef DEBUG_MODE + assert( + (TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == + 0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block +#endif + const int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 
1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block + uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = + reinterpret_cast( + c); // Registers for accumulated FP32 results + + // Setting RPTRs for double buffers + uint32_t(*a_read)[4] = a; + uint32_t(*a_write)[4] = a; + uint32_t(*b_read)[4] = b; + uint32_t(*b_write)[4] = b; + if (slice_id % 2 == 1) { + b_write += NumRegSets_b; + a_write += NumRegSets_a; + } else { + b_read += NumRegSets_b; + a_read += NumRegSets_a; + } + +// Reading registers and issuing core tensor core computations (a slice of A and B tile in shared +// memory) +#pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) { + MMA_FP16_M16N8K16(c_uint_ptr[i], a_read[i], b_read[0]); + } else { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) { + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a_read[i], b_read[j]); + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, + a_read[i], + b_read[j] + 2); // c+4; b+2 + } + } + } + + // Writing registers + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per thread; + uint32_t a_1[2]; // NO double buffer + uint32_t a_2[4]; // NO double buffer + CopyFromSharedToRegister_AFrag<2>(a_1, A1_SPTR_read, slice_id); + CopyFromSharedToRegister_AFrag<4>(a_2, A2_SPTR_read, slice_id); + Dequant_32FP6_4Way( + a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register + // level, dequantizing a slice each time + B_FromSharedToReg( + b_write, B_SPTR_read, slice_id); // Loading B from shared to registers +} + +#else +// Old version with naive pipeline design +template +__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) +{ + int lane_id = threadIdx.x % WARP_SIZE; +#pragma unroll + for (int i = 0; i < NUM_INT_PER_THREAD; i++) { Reg[i] = SPTR[lane_id + i * WARP_SIZE]; } +} +template +__device__ __forceinline__ void PipelinedCoreLoop( + float c[][REG_PER_THREAD_C_TENSOR_16_16], + half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + uint32_t* __restrict__ read_SPTR_Frag1, + uint32_t* __restrict__ read_SPTR_Frag2, + uint32_t* RPTR_Scales) +{ +#ifdef DEBUG_MODE + assert( + (TilingConfig::WARP_COL_MMA_TENSORS == 1) || + (TilingConfig::WARP_COL_MMA_TENSORS % 2 == + 0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block +#endif + const int NumRegSets_a = + WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block + const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS == 1) + ? 
1 + : TilingConfig::WARP_COL_MMA_TENSORS / + 2; // 1 set = 4 registers, containing a 16*16 MMA block + + // Registers to store FP32 results + uint32_t(*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = + reinterpret_cast(c); + // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 + // register per thread; + uint32_t a_1[2 * 2]; // double buffer is used + uint32_t a_2[4 * 2]; // double buffer is used + // Registers to store decompressed FP6 + uint32_t a[NumRegSets_a * 1][4]; // No double buffer + // Register to store FP16 B matrix (a slice) + uint32_t b[NumRegSets_b * 2][4]; // double buffer is used + + // Overlapped Smem and TC pipeline: pre-loading from shared to registers + CopyFromSharedToRegister_AFrag<2>(a_1, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4>(a_2, read_SPTR_Frag2); + B_FromSharedToReg(b, read_SPTR, 0); + +#pragma unroll + for (int k = 0; k < WARP_K_MMA_TENSORS; k++) { + uint32_t(*b_read)[4] = b; + uint32_t(*b_write)[4] = b; + uint32_t* a_1_read = a_1; + uint32_t* a_1_write = a_1; + uint32_t* a_2_read = a_2; + uint32_t* a_2_write = a_2; + if (k % 2 == 0) { + b_write += NumRegSets_b; + a_1_write += 2; + a_2_write += 4; + } else { + b_read += NumRegSets_b; + a_1_read += 2; + a_2_read += 4; + } + // data loading + if (k + 1 < WARP_K_MMA_TENSORS) { + // updating SPTR for fragment1 and fragment2 + read_SPTR_Frag1 += 2 * WARP_SIZE; + read_SPTR_Frag2 += 4 * WARP_SIZE; + CopyFromSharedToRegister_AFrag<2>(a_1_write, read_SPTR_Frag1); + CopyFromSharedToRegister_AFrag<4>(a_2_write, read_SPTR_Frag2); + B_FromSharedToReg(b_write, read_SPTR, (k + 1) * MMA_16); + } + // SIMT Dequant + Tensor Core computations + Dequant_32FP6_4Way( + a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, + // dequantizing a slice each time +#pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { + if (TilingConfig::WARP_COL_MMA_TENSORS == 1) + MMA_FP16_M16N8K16(c_uint_ptr[i], a[i], b_read[0]); + else { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS / 2; j++) { + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j]); + MMA_FP16_M16N8K16(c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, + a[i], + b_read[j] + 2); // c+4; b+2 + } + } + } + } +} +#endif // #ifdef PIPELINE_LEVEL_SMEM + +template +__device__ __forceinline__ void StoreToSharedMemoryFromRegister( + float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], + float c[][REG_PER_THREAD_C_TENSOR_16_16]) +{ + const int lane_id = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + int warp_row_offset = warpId * (MMA_16 * WARP_ROW_MMA_TENSORS); +#pragma unroll + for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { +#pragma unroll + for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS; + j++) { // Dealing with one 16*8 Tensor + int RegSetID = i + (j / 2) * WARP_ROW_MMA_TENSORS; + int RegOffset = (j % 2) * (REG_PER_THREAD_C_TENSOR_16_16 / 2); + int Tensor_row_offset = warp_row_offset + i * MMA_16; + int Tensor_col_offset = j * MMA_8; +#pragma unroll + for (int r = 0; r < REG_PER_THREAD_C_TENSOR_16_16 / 2; r++) { + int row_offset = lane_id / 4; + if (r >= 2) row_offset += 8; + int col_offset = (lane_id % 4) * 2; + if (r % 2 == 1) col_offset += 1; + smem_CFrag[Tensor_col_offset + col_offset][Tensor_row_offset + row_offset] = + c[RegSetID][r + RegOffset]; + } + } + } +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh 
b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh new file mode 100644 index 000000000000..62b77edaa37a --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_gmem.cuh @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef UTILS_GMEM_CUH +#define UTILS_GMEM_CUH + +#include +#include "configs.h" +#include "ptx_cp.async.cuh" + +/* + * Copying A1/A2 from global memory to shared memory. + * Usually 1024 or 2048 Bytes + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared_A(uint32_t* SPTR, + const uint4* GPTR, + bool pred_guard = true) +{ +#ifdef DEBUG_MODE + static_assert(SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE % 16 == 0); +#endif + int lane_id = threadIdx.x % WARP_SIZE; + half* SPTR_HALF = reinterpret_cast(SPTR); + const half* GPTR_HALF = reinterpret_cast(GPTR); + SPTR_HALF += lane_id * 8; + GPTR_HALF += lane_id * 8; +#pragma unroll + for (int i = 0; i < SMEM_SIZE_IN_BYTES_PER_WARP / WARP_SIZE / 16; i++) { + cp_async<16>(SPTR_HALF, GPTR_HALF, pred_guard); + SPTR_HALF += 256; // Forward 512 Bytes + GPTR_HALF += 256; // Forward 512 Bytes + } +} + +/* + * Copying 64 Quant Scales (FP16) from global memory to shared memory. + */ +__device__ __forceinline__ void CopyFromGlobalToShared_Scales(half* SPTR_QuantScales, + const half* GPTR_A_Scales) +{ + int lane_id = threadIdx.x % WARP_SIZE; + int Offset_Shared = lane_id * 2; + int Offset_Global = lane_id / 4 + (lane_id % 4) * 16; + for (int i = 0; i < 2; i++) + SPTR_QuantScales[Offset_Shared + i] = GPTR_A_Scales[Offset_Global + i * 8]; +} + +/* + * (1) Copying X rows * 64 columns of FP16 values, originally in row major + * (2) Copying 64 rows * X columns of FP16 values, originally in column major + * 16 Bytes per thread -> 512 Bytes per WARP = 4 line per WARP = 1 line per 8 Threads + */ +template +__device__ __forceinline__ void CopyFromGlobalToShared( + half __restrict__ (*SharedPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8], + const half* GlobalPTR, + const int GlobalStride, + const int NumOfLinesLeft, // To support arbitrary N dimensions. + bool Pred = true) +{ + // static parameters: 1 Group (8 Threads) can copy 1 line (64 FP16) each time + const int NumOfThreads = BLOCK_WARPS * WARP_SIZE; + const int NumOfGroups = NumOfThreads / 8; + const int MaxIteration = (MaxNumOfLinesToCopy - 1) / NumOfGroups + 1; + // runtime variables + const int line_id = threadIdx.x / 8; + const int line_offset = (threadIdx.x % 8) * 8; + // PTR for source global memory and target shared memory + GlobalPTR += line_id * GlobalStride + line_offset; + SharedPTR += line_id; +#pragma unroll + for (int i = 0; i < MaxIteration; i++) { + bool AsyncCopyPred = (line_id + i * NumOfGroups) < NumOfLinesLeft && Pred; + cp_async<16>(&(*SharedPTR)[line_offset], GlobalPTR, AsyncCopyPred); + // + GlobalPTR += NumOfGroups * GlobalStride; + SharedPTR += NumOfGroups; + } +} + +#endif diff --git a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh new file mode 100644 index 000000000000..ff13868c1347 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/utils_paralleldequant.cuh @@ -0,0 +1,127 @@ +// Copyright (c) Microsoft Corporation. 
+// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#ifndef UTILS_PARALLELDEQUANT_CUH +#define UTILS_PARALLELDEQUANT_CUH + +#include +#include +#include + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is applied. + */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way(u_int32_t* R1, u_int32_t* R2) +{ + *R2 = *R1 & 0x80808080; + *R1 = *R1 >> 2; + *R1 = *R1 & 0x1f1f1f1f; + *R2 = *R2 | *R1; + *R1 = *R2 & 0x9f009f00; + *R2 = *R2 & 0x009f009f; + *R2 = *R2 << 8; +} + +/* + * Input: R1 + * Outputs: R1, R2 + * Note: Simplified Exponent calculation is NOT applied. + */ +__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(u_int32_t* R1, u_int32_t* R2) +{ + //*R2 = *R1 & 0x80808080; + *R2 = *R1 & 0xc0c0c0c0; + *R1 = *R1 >> 2; + //*R1 = *R1 & 0x1f1f1f1f; + *R1 = *R1 & 0x0f0f0f0f; + *R2 = *R2 | *R1; + // + //*R1 = *R2 & 0x9f009f00; + //*R2 = *R2 & 0x009f009f; + *R1 = *R2 & 0xcf00cf00; + if (!(*R1 & 0x40000000) && (*R1 & 0x0c000000)) *R1 = *R1 | 0x30000000; + if (!(*R1 & 0x00004000) && (*R1 & 0x00000c00)) *R1 = *R1 | 0x00003000; + *R2 = *R2 & 0x00cf00cf; + if (!(*R2 & 0x00400000) && (*R2 & 0x000c0000)) *R2 = *R2 | 0x00300000; + if (!(*R2 & 0x00000040) && (*R2 & 0x0000000c)) *R2 = *R2 | 0x00000030; + // + *R2 = *R2 << 8; + //*R1 = 0x3c003c00; + //*R2 = 0x3c003c00; +} + +__device__ __forceinline__ u_int32_t MultScale(u_int32_t PackedFP16Pair, half Scale) +{ + half* FP16_1 = reinterpret_cast(&PackedFP16Pair); + half* FP16_2 = FP16_1 + 1; + uint32_t output; + half* output_half_ptr = reinterpret_cast(&output); + output_half_ptr[0] = __hmul(__hmul(*FP16_1, __float2half(4096.0f)), Scale); + output_half_ptr[1] = __hmul(__hmul(*FP16_2, __float2half(4096.0f)), Scale); + return output; +} + +__device__ __forceinline__ void Dequant_32FP6_4Way(u_int32_t __restrict__ Reg[][4], + u_int32_t __restrict__* read_RPTR_Frag1, + u_int32_t __restrict__* read_RPTR_Frag2, + u_int32_t* Scales) +{ + u_int32_t* OutputRegs = reinterpret_cast(Reg); + u_int32_t* Frag1_PTR = read_RPTR_Frag1; + u_int32_t* Frag2_PTR = read_RPTR_Frag2; + half* Scale_RPTR = reinterpret_cast(Scales); + u_int32_t Packed_FP6 = 0; + u_int32_t tmp = 0; +// Dequantizing 32 FP6, each Loop dequantizing 4 FP6 +#pragma unroll(8) + for (int i = 0; i < 8; i++) { + // Frag1 + Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0; + if (i % 4 == 3) + Frag1_PTR++; + else + (*Frag1_PTR) = (*Frag1_PTR) << 2; + // Frag2 + tmp = (*Frag2_PTR) & 0xf0f0f0f0; + tmp = tmp >> 2; + if (i % 2 == 1) + Frag2_PTR++; + else + (*Frag2_PTR) = (*Frag2_PTR) << 4; + // Packed_FP6 + Packed_FP6 = Packed_FP6 | tmp; + // + FP6_FP16_Cast_4Way(&Packed_FP6, &tmp); + // + *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0]); // Muliply FP16 scales + OutputRegs += 1; + *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales + OutputRegs += 1; + // Updating offset for FP16 scales for every two iterations + if (i % 2 == 1) Scale_RPTR += 2; + } +} + +/* + * + */ +__device__ __forceinline__ void ExtractFromSharedToReg_Scales(uint32_t* Scales, + half* WARP_SPTR_Scales) +{ + int lane_id = threadIdx.x % WARP_SIZE; + uint32_t* SPTR_uint = reinterpret_cast(WARP_SPTR_Scales); + uint32_t tmpReg = SPTR_uint[lane_id]; +#pragma unroll + for (int i = 0; i < 4; i++) { + // T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize); + Scales[i] = __shfl_sync(0xffffffff, tmpReg, i, 4); + } +} + +#endif diff --git 
a/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h new file mode 100644 index 000000000000..c8cc7243f341 --- /dev/null +++ b/deepspeed/inference/v2/kernels/core_ops/cuda_linear/include/weight_prepacking.h @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112 + +#include +#include +#include + +using namespace std; + +void Padding_8_FP6_To_8_Bytes(unsigned char Padded_FP6[], + unsigned char* FP6_Array) // padding 0 to the lowerest bit location +{ + Padded_FP6[0] = FP6_Array[0] & 0xfc; + Padded_FP6[1] = (FP6_Array[0] << 6) | ((FP6_Array[1] >> 2) & 0xfc); + Padded_FP6[2] = (FP6_Array[1] << 4) | ((FP6_Array[2] >> 4) & 0xfc); + Padded_FP6[3] = FP6_Array[2] << 2; + Padded_FP6[4] = FP6_Array[3] & 0xfc; + Padded_FP6[5] = (FP6_Array[3] << 6) | ((FP6_Array[4] >> 2) & 0xfc); + Padded_FP6[6] = (FP6_Array[4] << 4) | ((FP6_Array[5] >> 4) & 0xfc); + Padded_FP6[7] = FP6_Array[5] << 2; +} + +unsigned char Extract_2_Bits_From_4_PaddedFP6(unsigned char B1, + unsigned char B2, + unsigned char B3, + unsigned char B4) +{ + unsigned char out; + out = (B1 & 0xc0) | ((B2 & 0xc0) >> 2) | ((B3 & 0xc0) >> 4) | ((B4 & 0xc0) >> 6); + return out; +} + +unsigned char Extract_4_Bits_From_2_PaddedFP6( + unsigned char B1, + unsigned char + B2) // The highest two bits are already extracted by Extract_2_Bits_From_4_PaddedFP6(); +{ + unsigned char out; + out = ((B1 << 2) & 0xf0) | ((B2 >> 2) & 0x0f); + return out; +} + +// dealing with 4 1*8 blocks of FP6 +void Assign_32_FP6_To_4_Thread(vector Seg_2bit[], + vector Seg_4bit[], + unsigned char* PTR_1, + unsigned char* PTR_2, + unsigned char* PTR_3, + unsigned char* PTR_4) +{ + unsigned char Padded_8_FP8[4][8]; + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[0], PTR_1); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[1], PTR_2); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[2], PTR_3); + Padding_8_FP6_To_8_Bytes(Padded_8_FP8[3], PTR_4); + // + unsigned char Seg1_Byte1_T[4]; + unsigned char Seg1_Byte2_T[4]; + unsigned char Seg2_Byte1_T[4]; + unsigned char Seg2_Byte2_T[4]; + unsigned char Seg2_Byte3_T[4]; + unsigned char Seg2_Byte4_T[4]; + for (int t = 0; t < 4; t++) { + Seg1_Byte1_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[0][0 + t * 2], + Padded_8_FP8[0][1 + t * 2], + Padded_8_FP8[1][0 + t * 2], + Padded_8_FP8[1][1 + t * 2]); + Seg1_Byte2_T[t] = Extract_2_Bits_From_4_PaddedFP6(Padded_8_FP8[2][0 + t * 2], + Padded_8_FP8[2][1 + t * 2], + Padded_8_FP8[3][0 + t * 2], + Padded_8_FP8[3][1 + t * 2]); + Seg2_Byte1_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[0][0 + t * 2], Padded_8_FP8[0][1 + t * 2]); + Seg2_Byte2_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[1][0 + t * 2], Padded_8_FP8[1][1 + t * 2]); + Seg2_Byte3_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[2][0 + t * 2], Padded_8_FP8[2][1 + t * 2]); + Seg2_Byte4_T[t] = + Extract_4_Bits_From_2_PaddedFP6(Padded_8_FP8[3][0 + t * 2], Padded_8_FP8[3][1 + t * 2]); + } + // + for (int t = 0; t < 4; t++) { + Seg_2bit[t].push_back(Seg1_Byte1_T[t]); + Seg_2bit[t].push_back(Seg1_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte1_T[t]); + Seg_4bit[t].push_back(Seg2_Byte2_T[t]); + Seg_4bit[t].push_back(Seg2_Byte3_T[t]); + Seg_4bit[t].push_back(Seg2_Byte4_T[t]); + } + return; +} + +void BitInterleaving_2bit(unsigned char* PTR_4Bytes) +{ + unsigned int* PTR_UINT = 
reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + // int order_2bit[16] = {1,5,9,13,3,7,11,15,2,6,10,14,4,8,12,16}; // pre-defined order for + // bit-interleaving in QuantLLM + int order_2bit[16] = { + 2, 6, 10, 14, 4, 8, 12, 16, 1, 5, 9, 13, 3, 7, 11, 15}; // pre-defined order for + // bit-interleaving in QuantLLM + unsigned int Frags_2bit[16]; // The highest 2 bits are used to store the extracted fragments. + for (int i = 0; i < 16; i++) Frags_2bit[i] = (input << 2 * (order_2bit[i] - 1)) & 0xc0000000; + // + unsigned int output = 0x00000000; + for (int i = 0; i < 16; i++) output |= (Frags_2bit[i] >> (i * 2)); + // + *PTR_UINT = output; +} + +void BitInterleaving_4bit(unsigned char* PTR_4Bytes) +{ + unsigned int* PTR_UINT = reinterpret_cast(PTR_4Bytes); + unsigned int input = *PTR_UINT; + // + // int order_4bit[8] = {1,5,3,7,2,6,4,8}; // pre-defined order for bit-interleaving in QuantLLM + int order_4bit[8] = { + 2, 6, 4, 8, 1, 5, 3, 7}; // pre-defined order for bit-interleaving in QuantLLM + unsigned int Frags_4bit[8]; // The highest4 bits are used to store the extracted fragments. + for (int i = 0; i < 8; i++) Frags_4bit[i] = (input << 4 * (order_4bit[i] - 1)) & 0xf0000000; + // + unsigned int output = 0x00000000; + for (int i = 0; i < 8; i++) output |= (Frags_4bit[i] >> (i * 4)); + // + *PTR_UINT = output; +} + +/* + * Inputs: + * (1) unsigned char Weight_6bit [M*K*6/8] + * Outputs: + * (1) unsigned char Weight_2bit [M*K*2/8] + * (2) unsigned char Weight_4bit [M*K*4/8] + * + * Assumption: Weight_6bit, Weight_2bit, Weight_4bit all stored continuously in row-major. + * 8 FP6 = 6 Bytes + * 8 FP4 = 4 Bytes + * 8 FP2 = 2 Bytes + */ +void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K) +{ + assert(M % 64 == 0); + assert(K % 64 == 0); + // + unsigned char* Weight_6bit = reinterpret_cast(FP6Weights); + unsigned char* Weight_2bit = Weight_6bit; + unsigned char* Weight_4bit = Weight_6bit + M * K * 2 / 8; + // + vector A_Segment_2bit[32]; + vector A_Segment_4bit[32]; + // + size_t BytesPerRow = K * 6 / 8; + // Pass-1: (1) 2+4 split; (2) assign weights to 32 threads. + for (size_t i = 0; i < M / 64; i++) // + { + for (size_t j = 0; j < K / 16; j++) { + for (size_t k = 0; k < 64 / 16; k++) { + size_t row = i * 64 + k * 16; + size_t col = j * 16; + unsigned char* StartPTR_1 = Weight_6bit + row * BytesPerRow + col * 6 / 8; + unsigned char* StartPTR_2 = StartPTR_1 + 8 * BytesPerRow; + unsigned char* StartPTR_3 = StartPTR_1 + 8 * 6 / 8; + unsigned char* StartPTR_4 = StartPTR_2 + 8 * 6 / 8; + // Dealing with each 16*16 blocks then... 
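+                // Each iteration below hands four 1x8 runs of FP6 (32 values) from this
+                // 16*16 block to Assign_32_FP6_To_4_Thread(), which left-aligns every FP6
+                // value in a byte (Padding_8_FP6_To_8_Bytes), splits it into its top 2 bits
+                // (Extract_2_Bits_From_4_PaddedFP6, mask 0xc0) and the remaining 4 bits
+                // (Extract_4_Bits_From_2_PaddedFP6), and appends the pieces to the
+                // per-thread A_Segment_2bit / A_Segment_4bit buffers consumed by Pass-2 below.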
+ for (int l = 0; l < 8; l++) + Assign_32_FP6_To_4_Thread(&A_Segment_2bit[l * 4], + &A_Segment_4bit[l * 4], + StartPTR_1 + l * BytesPerRow, + StartPTR_2 + l * BytesPerRow, + StartPTR_3 + l * BytesPerRow, + StartPTR_4 + l * BytesPerRow); + } + } + } + // Verifying the length of 2_bit segments and 4_bit segments + size_t BytesPerThread_2bit = M * K * 2 / 8 / 32; + size_t BytesPerThread_4bit = M * K * 4 / 8 / 32; + for (int i = 0; i < 32; i++) { + assert(A_Segment_2bit[i].size() == BytesPerThread_2bit); + assert(A_Segment_4bit[i].size() == BytesPerThread_4bit); + } + // Pass-2: Optimizing coleasced global memory access + for (size_t i = 0; i < BytesPerThread_2bit / 4; i++) + for (int t = 0; t < 32; t++) + for (int b = 0; b < 4; b++) + Weight_2bit[i * 128 + t * 4 + (3 - b)] = + A_Segment_2bit[t] + [i * 4 + b]; // why (3-b): special byte order within a register + for (size_t i = 0; i < BytesPerThread_4bit / 4; i++) + for (int t = 0; t < 32; t++) + for (int b = 0; b < 4; b++) + Weight_4bit[i * 128 + t * 4 + (3 - b)] = + A_Segment_4bit[t][i * 4 + b]; // why (3-b):special byte order within a register + // Pass-3: Bit-level interleaving + for (size_t i = 0; i < BytesPerThread_2bit * 32 / 4; i++) + BitInterleaving_2bit(Weight_2bit + 4 * i); + for (size_t i = 0; i < BytesPerThread_4bit * 32 / 4; i++) + BitInterleaving_4bit(Weight_4bit + 4 * i); +} diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu index 84a9906cf037..cfa62f94596a 100644 --- a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu +++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu @@ -119,7 +119,10 @@ void launch_gated_activation_impl(T* output, DISPATCH_UNROLL(5); } else if (unroll == 6) { DISPATCH_UNROLL(6); + } else if (unroll == 7) { + DISPATCH_UNROLL(7); } else { + // TODO: provide a kernel with an outer loop to handle larger columns. throw std::runtime_error( "Called with more columns than supported, please report this bug and this limit will " "be increased."); diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py index f9da7ac5d23e..ebdb59bca920 100644 --- a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py +++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py @@ -164,7 +164,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=tensor.stride(), offset=cur_offset) - cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) + cur_offset += pad_to_aligned_offset(elem_size(tensor.dtype) * tensor.numel()) layer_metadata.params[p_name] = param_metadata diff --git a/deepspeed/inference/v2/modules/heuristics.py b/deepspeed/inference/v2/modules/heuristics.py index b89e95c0d834..d176206f3c60 100644 --- a/deepspeed/inference/v2/modules/heuristics.py +++ b/deepspeed/inference/v2/modules/heuristics.py @@ -86,8 +86,15 @@ def instantiate_linear(linear_config: DSLinearConfig, engine_config: RaggedInfer A linear module implementing the given configuration. """ - # Currently, we only have one implementation, so we just return it. 
- config = ConfigBundle(name="blas_fp_linear", config=linear_config) + quantization_mode = engine_config.quantization.quantization_mode + if quantization_mode is None: + config = ConfigBundle(name="blas_fp_linear", config=linear_config) + else: + # Currently, we only support ``quantized_wf6af16_linear``. + if quantization_mode == "wf6af16": + config = ConfigBundle(name="quantized_wf6af16_linear", config=linear_config) + else: + raise ValueError(f"Unsupported quantization mode: {quantization_mode}") return DSLinearRegistry.instantiate_config(config) diff --git a/deepspeed/inference/v2/modules/implementations/linear/__init__.py b/deepspeed/inference/v2/modules/implementations/linear/__init__.py index e76aab71c4cf..0501af54c4e6 100644 --- a/deepspeed/inference/v2/modules/implementations/linear/__init__.py +++ b/deepspeed/inference/v2/modules/implementations/linear/__init__.py @@ -4,3 +4,4 @@ # DeepSpeed Team from .blas_fp_linear import BlasFPLinear +from .quantized_linear import QuantizedWf6Af16Linear, fp_quantize diff --git a/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py new file mode 100644 index 000000000000..933cf55b2391 --- /dev/null +++ b/deepspeed/inference/v2/modules/implementations/linear/quantized_linear.py @@ -0,0 +1,205 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any, Dict, Optional + +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import InferenceCoreBuilder +from ....allocator import empty_from +from ....inference_utils import is_gated +from ....kernels.core_ops import ( + CUDAWf6Af16Linear, + CUDABiasActivation, + CUDAGatedActivation, +) + +from ...interfaces import DSLinearBase, DSLinearRegistry +from ...configs import DSLinearConfig +from ....inference_parameter import InferenceParameter + + +def fp_quantize(input: torch.FloatTensor, + num_bits: int = 6, + exp_bits: int = 3, + min_value: torch.FloatTensor = None, + max_value: torch.FloatTensor = None, + group_size: int = -1): + """ + Args: + inputs (`torch.FloatTensor`) + The input which needs to be quantized + num_bits (int, >=4) + Number of bits to use for quantization + exp_bits: + fp exp_bits + min_value/max_vlue (torch.FloatTensor) + Used for static activation quantization + group_size (int) N + The quantization block size, each N numbers has its own scaling + factor and off-site. -1 means use the last dim as the group_size + Returns: + quantized_fake_fp6 + The quantized weights, in fp16 format and contains fp6 value. 
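+            (each element is an FP16 number whose value fits the 1-sign/3-exponent/2-mantissa
+            FP6 format; the true dequantized weight is this tensor multiplied by `scales`)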
+ scales + Quantization scales + """ + + try: + from qtorch.quant import float_quantize + except ImportError: + raise ImportError("Please install qtorch to use this function") + + assert (min_value is None and max_value is None) or (min_value is not None and max_value is not None) + + assert input.dtype == torch.float16 + + orig_device = input.device + input = input.to(torch.float32).to(get_accelerator().current_device()) + if num_bits == 6 and exp_bits == 3: # this is default + q_range = 28 + else: + raise NotImplementedError + + man_bits = num_bits - exp_bits - 1 + input_shape = input.shape + + if group_size == -1: + group_size = input_shape[-1] + else: + # Only support per-channel quantization + raise NotImplementedError + num_groups = input.numel() // group_size + input = input.reshape(num_groups, -1) + + if min_value is None: + max_input = torch.amax(torch.abs(input), dim=-1).view(num_groups, -1) + else: + max_input = torch.max(min_value.abs(), max_value) # .view(-1) + scales = max_input / q_range # q_range + 1 + scales[scales == 0] = 1 # avoid zero scales + scaled_input = input / scales + + quantized_fake_fp6 = float_quantize(scaled_input, exp_bits, man_bits, rounding="nearest") + + quantized_fake_fp6 = quantized_fake_fp6.reshape(input_shape).contiguous().to(torch.float16).to(orig_device) + scales = scales.to(torch.float16).to(orig_device) + # Now the dequantized value is quantized_fake_fp6 * scales + + return quantized_fake_fp6, scales + + +@DSLinearRegistry.register_module +class QuantizedWf6Af16Linear(DSLinearBase): + """ + Linear DSModule for FP6 weight-only quantization kernel, where weight is FP6 + and activation is FP16. + """ + + @staticmethod + def name(): + return 'quantized_wf6af16_linear' + + @staticmethod + def supports_config(config: DSLinearConfig) -> bool: + if config.input_dtype != config.output_dtype: + return False + + # As for fp6 data items, they are packed and stored in a set of fp16 + # tensors. E.g., 8 fp6 data items are stored in 3 fp16 tensor. + if config.input_dtype != torch.float16: + return False + + if is_gated(config.activation): + try: + _ = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + else: + try: + _ = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) + except ValueError: + return False + + return True + + def __init__(self, config: DSLinearConfig, implementation_config: Dict[str, Any]) -> None: + super().__init__(config, implementation_config) + + self._linear_impl = CUDAWf6Af16Linear() + + if is_gated(config.activation): + # In the FP6 kernel implementation, the MatMul is W * A, where W is + # the weight and A is activation. M is the output channel size. 
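+            # The gated path packs the gate and linear projections into one weight, so the
+            # GEMM output width (M) is doubled here; CUDAGatedActivation later folds the
+            # [tokens, out_channels * 2] staging buffer back down to `config.out_channels`.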
+ self.out_channels = self._config.out_channels * 2 + self.in_channels = self._config.in_channels + self._is_gated = True + self._act_fn = CUDAGatedActivation(config.out_channels, config.output_dtype, config.activation) + self._double_buffer = torch.empty((config.max_tokens, config.out_channels * 2), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + else: + self.out_channels = self._config.out_channels + self.in_channels = self._config.in_channels + self._is_gated = False + self._act_fn = CUDABiasActivation(config.out_channels, config.output_dtype, config.activation) + + self._output = torch.empty((config.max_tokens, config.out_channels), + dtype=config.output_dtype, + device=get_accelerator().current_device()) + + self.inf_module = InferenceCoreBuilder().load() + self.inf_module.create_handle() + self.preprocess_weight = self.inf_module.preprocess_weight + + self.quantizer = fp_quantize + + def transform_param(self, param: torch.Tensor) -> InferenceParameter: + """ + Converts param to same data type as input and output. + + Parameters: + param (torch.Tensor): Weight or bias tensor. + """ + # It expects that the quantization scales are store in the attribute `scales`. + + if param.ndim == 1: # bias, do nothing + return InferenceParameter.initialize(param) + + quantized_fake_fp6, scales = self.quantizer(param, num_bits=6, exp_bits=3) + + # This is for debugging, will delete before release. + assert (quantized_fake_fp6.dtype == torch.float16) + assert quantized_fake_fp6.shape[0] == self.out_channels + assert scales.numel() == self.out_channels + + weights_2bit, weights_4bit = self.preprocess_weight(quantized_fake_fp6) + + return InferenceParameter.initialize(weights_2bit, weights_4bit=weights_4bit, scales=scales) + + def forward(self, hidden_states: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor: + weights_2bit = w + weights_4bit = w.weights_4bit + scales = w.scales + output = empty_from(self._output, (hidden_states.shape[0], self._config.out_channels)) + if self._is_gated: + staging_output = empty_from(self._double_buffer, (hidden_states.shape[0], self.out_channels)) + self._linear_impl(staging_output, hidden_states, weights_2bit, weights_4bit, scales, self.out_channels, + hidden_states.shape[0], self.in_channels) + self._act_fn(output, staging_output, b) + else: + self._linear_impl(output, hidden_states, weights_2bit, weights_4bit, scales, self.out_channels, + hidden_states.shape[0], self.in_channels) + self._act_fn(output, b) + + return output + + @property + def output(self) -> torch.Tensor: + """ + Return the padded, pre-allocated output Tensor. + """ + return self._output diff --git a/op_builder/inference_core_ops.py b/op_builder/inference_core_ops.py index 8073b63ad16b..3c53774d0a50 100755 --- a/op_builder/inference_core_ops.py +++ b/op_builder/inference_core_ops.py @@ -57,6 +57,8 @@ def get_prefix(self): return "deepspeed" if os.path.isdir(ds_path) else ".." def sources(self): + import torch + sources = [ "inference/v2/kernels/core_ops/core_ops.cpp", "inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp", @@ -69,6 +71,15 @@ def sources(self): "inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu", ] + # The source files with specific GPU architecture requirements. 
+ if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda + cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda + if cuda_capability != 8: + self.warning("FP6 quantization kernel is only supported on Ampere architectures") + else: + sources.append("inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu") + sources.append("inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp") + prefix = self.get_prefix() sources = [os.path.join(prefix, src) for src in sources] return sources @@ -83,6 +94,7 @@ def include_paths(self): 'inference/v2/kernels/core_ops/cuda_layer_norm', 'inference/v2/kernels/core_ops/cuda_rms_norm', 'inference/v2/kernels/core_ops/gated_activations', + 'inference/v2/kernels/core_ops/cuda_linear', 'inference/v2/kernels/includes', ] diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt index 7a40ae814cbe..b7fd13787e8b 100644 --- a/requirements/requirements-inf.txt +++ b/requirements/requirements-inf.txt @@ -1,6 +1,7 @@ google lm-eval==0.3.0 protobuf +qtorch safetensors sentencepiece transformers>=4.32.1 diff --git a/tests/unit/inference/v2/modules/test_quantized_linear_module.py b/tests/unit/inference/v2/modules/test_quantized_linear_module.py new file mode 100644 index 000000000000..a7bd965072ac --- /dev/null +++ b/tests/unit/inference/v2/modules/test_quantized_linear_module.py @@ -0,0 +1,184 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Optional + +import pytest +import torch + +from deepspeed.accelerator import get_accelerator +from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum, is_gated +from deepspeed.inference.v2.modules import ConfigBundle +from deepspeed.inference.v2.modules.configs import DSLinearConfig +from deepspeed.inference.v2.modules.interfaces import DSLinearRegistry +from ...v2.inference_test_utils import allclose + + +def reference_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + out_states = torch.nn.functional.linear(hidden_states, weight, bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _fp6_quant_dequant_weights(weight: torch.Tensor) -> torch.Tensor: + from deepspeed.inference.v2.modules.implementations.linear.quantized_linear import fp_quantize + weight_quantized_fake_fp6, scales = fp_quantize(weight, num_bits=6, exp_bits=3) + return weight_quantized_fake_fp6 * scales + + +def quant_dequant_implementation(hidden_states: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], + act_type: ActivationType) -> torch.Tensor: + dtype = hidden_states.dtype + weight_dequantized = _fp6_quant_dequant_weights(weight) + out_states = torch.nn.functional.linear(hidden_states, weight_dequantized, 
bias) + out_states.float() + + if is_gated(act_type): + act_func_map = { + ActivationType.ReGLU: torch.nn.functional.relu, + ActivationType.GEGLU: lambda x: torch.nn.functional.gelu(x, approximate="tanh"), + ActivationType.SiGLU: torch.nn.functional.silu, + } + + act_act = out_states[..., ::2] + act_linear = out_states[..., 1::2] + + act_act = act_func_map[act_type](act_act) + out_states = act_act * act_linear + else: + act_func_map = { + ActivationType.RELU: torch.nn.functional.relu, + ActivationType.GELU: torch.nn.functional.gelu, + ActivationType.SILU: torch.nn.functional.silu, + ActivationType.IDENTITY: lambda x: x, + } + + out_states = act_func_map[act_type](out_states) + return out_states.to(dtype) + + +def _fp6_quantized_linear_helper(tokens: int, + in_channels: int, + out_channels: int, + dtype: DtypeEnum, + act_fn: ActivationType, + use_bias: bool = True, + expect_failure: bool = False) -> None: + # The current FP6 kernel only supports NVIDIA Ampere GPUs. + if not 'cuda' in get_accelerator().current_device_name(): + return + major, _ = torch.cuda.get_device_capability() #ignore-cuda + if major != 8: + return + + # Input vals + hidden_states = torch.randn( + (tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + + weight_out_channels = 2 * \ + out_channels if is_gated(act_fn) else out_channels + weight = torch.randn( + (weight_out_channels, in_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + if use_bias: + bias = torch.randn( + (weight_out_channels), dtype=dtype.value, device=get_accelerator().current_device_name()) * .01 + else: + bias = None + + # quantize and dequantize output + ref_quant_dequant_output = quant_dequant_implementation(hidden_states, weight, bias, act_fn) + + linear_config = DSLinearConfig(max_tokens=2048, + in_channels=in_channels, + out_channels=out_channels, + activation=act_fn, + input_dtype=dtype, + output_dtype=dtype) + bundle = ConfigBundle(name='quantized_wf6af16_linear', config=linear_config) + fp6_linear_module = DSLinearRegistry.instantiate_config(bundle) + weight_fp6 = fp6_linear_module.transform_param(weight.clone().cpu()).to(get_accelerator().current_device_name()) + + if expect_failure: + with pytest.raises(ValueError) as excinfo: + ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) + assert "The out and in channel should be multiple of 256 and 64 respectively." in str(excinfo.value) + else: + ds_output = fp6_linear_module(hidden_states, weight_fp6, bias) + # The current FP6 kernel uses FP16 Tensor Core. 
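+        # The kernel multiplies FP16 operands on tensor cores while the reference result
+        # comes from a separate fake-quantized torch.nn.functional.linear call, so bit-exact
+        # agreement is not expected and fp16-level tolerances are used instead.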
+ tolerances = (3e-2, 2e-3) # tolerances for fp16 + + # Check DeepSpeed implementation + assert allclose(ds_output, ref_quant_dequant_output, tolerances=tolerances) + + +all_acts = [ + ActivationType.RELU, + ActivationType.GELU, + ActivationType.SILU, + ActivationType.GEGLU, + ActivationType.ReGLU, + ActivationType.SiGLU, +] +all_tokens = [1, 37] +all_in_out_channels = [ + (4096, 4096), + (8192, 28672), +] + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens", all_tokens) +@pytest.mark.parametrize("in_channels, out_channels", all_in_out_channels) +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_fp6_quantized_linear_act_fn(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, + use_bias: bool) -> None: + _fp6_quantized_linear_helper(tokens=tokens, + in_channels=in_channels, + out_channels=out_channels, + dtype=DtypeEnum.fp16, + act_fn=act_fn, + use_bias=use_bias) + + +# Other shapes, not supported by FP6 kernels. Will raise ValueError. +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize("tokens", all_tokens) +@pytest.mark.parametrize("in_channels, out_channels", [(4608, 1728)]) +@pytest.mark.parametrize("act_fn", all_acts) +@pytest.mark.parametrize("use_bias", [True, False]) +def test_fp6_quantized_linear_act_fn_fail(tokens: int, in_channels: int, out_channels: int, act_fn: ActivationType, + use_bias: bool) -> None: + _fp6_quantized_linear_helper(tokens=tokens, + in_channels=in_channels, + out_channels=out_channels, + dtype=DtypeEnum.fp16, + act_fn=act_fn, + use_bias=use_bias, + expect_failure=True) diff --git a/tests/unit/inference/v2/ragged/test_manager_configs.py b/tests/unit/inference/v2/ragged/test_manager_configs.py index bdd513445ddb..a5f270cced8c 100644 --- a/tests/unit/inference/v2/ragged/test_manager_configs.py +++ b/tests/unit/inference/v2/ragged/test_manager_configs.py @@ -5,7 +5,7 @@ import pytest -from pydantic import ValidationError +from deepspeed.pydantic_v1 import ValidationError from deepspeed.inference.v2.ragged import DSStateManagerConfig
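
For reference, below is a minimal usage sketch of the pieces added in this diff, mirroring test_quantized_linear_module.py: it drives the new quantized_wf6af16_linear module directly through the linear-module registry. It is illustrative only and assumes an Ampere (SM80) GPU, the qtorch dependency added to requirements-inf.txt above, and shapes that satisfy the kernel's channel constraints (out/in channels multiples of 256 and 64 respectively).

import torch

from deepspeed.accelerator import get_accelerator
from deepspeed.inference.v2.inference_utils import ActivationType, DtypeEnum
from deepspeed.inference.v2.modules import ConfigBundle
from deepspeed.inference.v2.modules.configs import DSLinearConfig
from deepspeed.inference.v2.modules.interfaces import DSLinearRegistry

device = get_accelerator().current_device_name()

# Describe a 4096 -> 4096 FP16 linear layer with GELU, served by the FP6 kernel.
config = DSLinearConfig(max_tokens=2048,
                        in_channels=4096,
                        out_channels=4096,
                        activation=ActivationType.GELU,
                        input_dtype=DtypeEnum.fp16,
                        output_dtype=DtypeEnum.fp16)
module = DSLinearRegistry.instantiate_config(
    ConfigBundle(name="quantized_wf6af16_linear", config=config))

# transform_param() fake-quantizes the FP16 weight to FP6, prepacks the 2-bit and
# 4-bit slices, and attaches the quantization scales to the returned parameter.
weight = torch.randn(4096, 4096, dtype=torch.float16) * 0.01
packed_weight = module.transform_param(weight).to(device)
bias = torch.randn(4096, dtype=torch.float16, device=device) * 0.01

hidden_states = torch.randn(37, 4096, dtype=torch.float16, device=device) * 0.01
output = module(hidden_states, packed_weight, bias)  # shape: [37, 4096]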