Adjust kQuantizeBlockwise to work with WARP size 64
arlo-phoenix committed Jan 22, 2024
1 parent 32cd5e0 commit e03a8bd
Showing 2 changed files with 23 additions and 20 deletions.
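The central change is in csrc/kernels.cu: the number of items each thread hands to the cub block primitives is no longer hard-wired to NUM_PER_TH but derived from the wavefront size, so that BLOCK_SIZE divided by the per-thread count stays a multiple of the wavefront (cub's WARP_TRANSPOSE load/store layouts expect the block's thread count to be a multiple of the warp size). The sketch below only illustrates that arithmetic with assumed example values (BLOCK_SIZE = 64, NUM_PER_TH = 2, wavefront 64); the real values come from the kernel's template instantiations and are not taken from this diff.

// Illustration only: reproduces the CUB_NUM_PER_TH selection from the diff
// with assumed example values, not the kernel's actual instantiation.
#include <cstdio>

int main()
{
    const int WAVEFRONT_SIZE = 64;  // stands in for __AMDGCN_WAVEFRONT_SIZE
    const int BLOCK_SIZE     = 64;  // example: elements quantized per block
    const int NUM_PER_TH     = 2;   // example: elements per thread

    // Same expression as in the kernel: halve the per-thread count when
    // BLOCK_SIZE/NUM_PER_TH threads would not be a multiple of the wavefront.
    const int CUB_NUM_PER_TH =
        (BLOCK_SIZE / NUM_PER_TH % WAVEFRONT_SIZE == 0) ? NUM_PER_TH : NUM_PER_TH / 2;

    // With these values: 32 threads is not a multiple of 64, so the cub
    // primitives end up configured for 64 threads with 1 item each.
    printf("cub threads: %d, items per thread: %d\n",
           BLOCK_SIZE / CUB_NUM_PER_TH, CUB_NUM_PER_TH);
    return 0;
}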
37 changes: 22 additions & 15 deletions csrc/kernels.cu
@@ -740,21 +740,28 @@ template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TY
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
+#ifdef BNB_USE_HIP
+const int CUB_NUM_PER_TH=(BLOCK_SIZE/NUM_PER_TH % __AMDGCN_WAVEFRONT_SIZE == 0) ? NUM_PER_TH : NUM_PER_TH/2;
+#else
+const int CUB_NUM_PER_TH=NUM_PER_TH;
+#endif
+const int DATA_NUM_PER_TH=(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH;
+
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);

-T vals[NUM_PER_TH];
-float rand_vals[NUM_PER_TH];
-unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
+T vals[CUB_NUM_PER_TH];
+float rand_vals[CUB_NUM_PER_TH];
+unsigned char qvals[DATA_NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;

-typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
-typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
-typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
-typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+typedef cub::BlockLoad<T, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+typedef cub::BlockStore<unsigned char, BLOCK_SIZE/CUB_NUM_PER_TH, DATA_NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+typedef cub::BlockReduce<float, BLOCK_SIZE/CUB_NUM_PER_TH> BlockReduce;
+typedef cub::BlockLoad<float, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

__shared__ typename LoadT::TempStorage loadt;
__shared__ typename LoadFloat::TempStorage loadf;
@@ -779,8 +786,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
// 2. broadcast local max
// 3. normalize inputs and quantize

-#pragma unroll NUM_PER_TH
-for(int j = 0; j < NUM_PER_TH; j++)
+#pragma unroll CUB_NUM_PER_TH
+for(int j = 0; j < CUB_NUM_PER_TH; j++)
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
@@ -809,8 +816,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
switch(DATA_TYPE)
{
case General8bit:
-#pragma unroll NUM_PER_TH
-for(int j = 0; j < NUM_PER_TH; j++)
+#pragma unroll CUB_NUM_PER_TH
+for(int j = 0; j < CUB_NUM_PER_TH; j++)
{
if(!STOCHASTIC)
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
@@ -819,17 +826,17 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
}
break;
case FP4:
-#pragma unroll NUM_PER_TH
-for(int j = 0; j < NUM_PER_TH/2; j++)
+#pragma unroll CUB_NUM_PER_TH
+for(int j = 0; j < DATA_NUM_PER_TH; j++)
{
packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
}
break;
case NF4:
-#pragma unroll NUM_PER_TH
-for(int j = 0; j < NUM_PER_TH/2; j++)
+#pragma unroll CUB_NUM_PER_TH
+for(int j = 0; j < DATA_NUM_PER_TH; j++)
{
packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
6 changes: 1 addition & 5 deletions csrc/ops.cuh
@@ -14,10 +14,6 @@


#ifdef BNB_USE_HIP
-// check rocminfo | grep "Wavefront Size". Should be supported on all new GPU's
-// dirty hack to force wavefront_size 32 so this compiles
-// RDNA 2 defaults to 64 which conflicts with kQuantizeBlockwise
-#define __AMDGCN_WAVEFRONT_SIZE 32

#include <hip/hip_runtime_api.h>
#include <hip/hip_fp16.h>
@@ -58,7 +54,7 @@
#define cublasLtHandle_t hipblasLtHandle_t
#define cublasLtCreate hipblasLtCreate
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
-#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT

#else
#include <cuda_runtime_api.h>
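With the forced #define __AMDGCN_WAVEFRONT_SIZE 32 dropped from ops.cuh, the build now follows whatever wavefront size the compiler targets. The removed comment suggested checking rocminfo | grep "Wavefront Size"; as an alternative not part of this commit, the value can also be read at runtime through the HIP device properties, as in the small sketch below.

// Not part of this commit: a minimal HIP sketch that prints the wavefront
// size the runtime reports for device 0 (the value rocminfo lists as
// "Wavefront Size").
#include <hip/hip_runtime.h>
#include <cstdio>

int main()
{
    hipDeviceProp_t prop;
    if (hipGetDeviceProperties(&prop, 0) != hipSuccess)
    {
        printf("failed to query device 0\n");
        return 1;
    }
    printf("wavefront size: %d\n", prop.warpSize);  // typically 32 or 64 on AMD GPUs
    return 0;
}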
