linting
fxmarty committed Jun 12, 2024
1 parent 72a8dbf commit aa58624
Showing 1 changed file with 112 additions and 108 deletions.
220 changes: 112 additions & 108 deletions csrc/custom/custom_kernels.cu
@@ -23,125 +23,129 @@ __device__ __forceinline__ float4 load_ntmprl(const float4* addr) {
  return make_float4(dat0, dat1, dat2, dat3);
}

// Each thread block fetches entire rows of A and the entire column of B (the
// K dimension); assume N=1 for the time being. The grid has M/A_NUM_ROWS
// blocks.
template <int NUM_A_ROWS_PER_BLOCK>
__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c,
                               const int K) {
  __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE];
  const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8;
  const int threadid = threadIdx.x;
  const int warp = threadIdx.x / WARP_SIZE;
  const int lane = threadIdx.x % WARP_SIZE;
  const int num_warps = blockDim.x / WARP_SIZE;
  const int qwarpid = threadid / 16;
  const int qthreadid = threadid % 16;
  float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
  __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
  float4 sum4;  //[NUM_A_ROWS_PER_BLOCK];
  float acc[NUM_A_ROWS_PER_BLOCK] = {0.0};
  __half2 acch2;
  __half2 oval;

  // As we later use warp shuffle operations, we may have more threads in the
  // block than the actual available data, hence the if guard here.
  if (threadid * 8 < K) {
#pragma unroll
    for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
      // rowA_elem4[i] holds 8 * half numbers seen as a single float4.
      rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
    }
  }

  colB_elem4x = bf4[threadid * 4 + 0];
  colB_elem4y = bf4[threadid * 4 + 1];
  colB_elem4z = bf4[threadid * 4 + 2];
  colB_elem4w = bf4[threadid * 4 + 3];

  __half2 Af2;
  __half2 Bf2;
  float2 S;

  auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4);
  __half2* ah2lptr;

#pragma unroll
  for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
    // Multiply-add on 8 halves.
    ah2lptr = Ah2ptr + i * 4;
    Af2 = *(ah2lptr);
    acch2 = __hmul2(Af2, colB_elem4x);
    Af2 = *(ah2lptr + 1);
    acch2 = __hfma2(Af2, colB_elem4y, acch2);
    Af2 = *(ah2lptr + 2);
    acch2 = __hfma2(Af2, colB_elem4z, acch2);
    Af2 = *(ah2lptr + 3);
    acch2 = __hfma2(Af2, colB_elem4w, acch2);
    S = __half22float2(acch2);

    // See comment above concerning the if guard.
    if (threadid * 8 < K) {
      acc[i] = S.x + S.y;  // accumulation in float
    }
  }

  // All-reduce across the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
#pragma unroll
    for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
      acc[i] += __shfl_xor(acc[i], mask);
    }
  }

  // Warp leaders store the data to shared memory.
  if (lane < NUM_A_ROWS_PER_BLOCK) {
    red_smem[lane][warp] = acc[lane];
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
    acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
#pragma unroll
    for (int mask = 16 / 2; mask >= 1; mask /= 2) {
      acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
    }
    float oval2 = __shfl_xor(acc[qwarpid], 16);

    if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) {
      oval = __float22half2_rn(make_float2(acc[qwarpid], oval2));
      c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
    }
  }
}
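// Note on the reductions above: both shuffle loops appear to follow the
// standard butterfly all-reduce pattern built on __shfl_xor (the HIP/ROCm
// intrinsic; plain CUDA code would use __shfl_xor_sync with a lane mask).
// A minimal single-warp sketch of the same idea, for illustration only and
// not part of this commit:
//
//   __device__ float warp_all_reduce_sum(float v) {
//     for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
//       v += __shfl_xor(v, mask);  // reads the value held by lane (lane_id ^ mask)
//     }
//     return v;  // every lane now holds the sum over the whole warp
//   }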

// Define the kernel calling code:
// template <typename T>
void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K,
             cudaStream_t stream, const int rows_per_block = 4) {
  float4* af4 = reinterpret_cast<float4*>(in_a);
  auto* bf4 = reinterpret_cast<__half2*>(in_b);
  auto* c = reinterpret_cast<__half2*>(out_c);

  // NUM_THREADS needs to be a multiple of WARP_SIZE, as we are using warp
  // shuffle operations.
  const int NUM_THREADS =
      K * 2 / 16 % WARP_SIZE == 0
          ? K * 2 / 16
          : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);

  int NUM_BLOCKS = M / rows_per_block;

  if (rows_per_block == 2) {
    LLGemm1_kernel<2><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
  } else if (rows_per_block == 4) {
    LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
  } else if (rows_per_block == 8) {
    LLGemm1_kernel<8><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
  } else if (rows_per_block == 16) {
    LLGemm1_kernel<16><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
  } else {
    NUM_BLOCKS = M / 4;
    LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
  }

  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err)
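For context, a minimal host-side sketch of how this launcher might be driven is shown below. It assumes A is an M x K row-major FP16 matrix, B a K-element FP16 vector, and C an M-element FP16 output, with K a multiple of 8 so the float4 (8 x half) loads line up and M divisible by rows_per_block. The buffer names, sizes, headers, and the forward declaration are illustrative only and not part of this commit.

#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <vector>

// Forward declaration matching the definition above (illustrative only).
void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K,
             cudaStream_t stream, const int rows_per_block);

int main() {
  const int M = 256;   // rows of A, divisible by rows_per_block
  const int K = 4096;  // columns of A / length of B, multiple of 8
  std::vector<__half> hA(M * K, __float2half(0.01f));
  std::vector<__half> hB(K, __float2half(1.0f));
  std::vector<__half> hC(M, __float2half(0.0f));

  __half *dA, *dB, *dC;
  cudaMalloc(&dA, M * K * sizeof(__half));
  cudaMalloc(&dB, K * sizeof(__half));
  cudaMalloc(&dC, M * sizeof(__half));
  cudaMemcpy(dA, hA.data(), M * K * sizeof(__half), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), K * sizeof(__half), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Each block reduces rows_per_block rows of A against the single column B.
  LLGemm1(dA, dB, dC, M, K, stream, /*rows_per_block=*/4);
  cudaStreamSynchronize(stream);

  cudaMemcpy(hC.data(), dC, M * sizeof(__half), cudaMemcpyDeviceToHost);

  cudaFree(dA);
  cudaFree(dB);
  cudaFree(dC);
  cudaStreamDestroy(stream);
  return 0;
}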
