Skip to content

Commit

Permalink
Enable warp size 64
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonix committed Jun 14, 2024
1 parent c80bc8a commit 85b3f13
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
18 changes: 10 additions & 8 deletions llmc/amd_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,20 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
}
#else
// Sums `x` across the lanes of a warp/wavefront and returns the total in every lane.
//
// XOR-butterfly all-reduce: on each step a lane adds the value held by the lane
// whose index differs in the `mask` bit, halving `mask` until it reaches zero.
// The shuffle width is 64 when WAVEFRONTSIZE64 is defined and 32 otherwise.
// NOTE(review): this uses the mask-less `__shfl_xor`, so it presumably expects
// all lanes of the group to be active at the call site — confirm against callers.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#ifdef WAVEFRONTSIZE64
    constexpr int kWidth = 64;
#else
    constexpr int kWidth = 32;
#endif
    // Exactly one reduction pass: running the butterfly a second time would add
    // the already-complete total to itself on every step.
    #pragma unroll
    for (int mask = kWidth / 2; mask > 0; mask >>= 1) {
        x += __shfl_xor(x, mask, kWidth);
    }
    return x;
}

// Returns, in every lane, the maximum of `x` across the lanes of a warp/wavefront.
//
// XOR-butterfly all-reduce: on each step a lane takes fmaxf of its own value and
// the value held by the lane whose index differs in the `mask` bit, halving
// `mask` until it reaches zero. The shuffle width is 64 when WAVEFRONTSIZE64 is
// defined and 32 otherwise.
// NOTE(review): this uses the mask-less `__shfl_xor`, so it presumably expects
// all lanes of the group to be active at the call site — confirm against callers.
static __device__ __forceinline__ float warp_reduce_max(float x) {
#ifdef WAVEFRONTSIZE64
    constexpr int kWidth = 64;
#else
    constexpr int kWidth = 32;
#endif
    // A single pass is sufficient; repeating it would only redo the same shuffles.
    #pragma unroll
    for (int mask = kWidth / 2; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor(x, mask, kWidth));
    }
    return x;
}
#endif
Expand Down
4 changes: 4 additions & 0 deletions llmc/cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@ extern cudaDeviceProp deviceProp;

// WarpSize is not a compile time constant
// Defining here like this possibly allows the compiler to optimize better
// Number of lanes per warp/wavefront, fixed at compile time:
// 64 when WAVEFRONTSIZE64 is defined, 32 otherwise.
// NOTE(review): WAVEFRONTSIZE64 is presumably supplied by the build for
// AMD wave64 targets — confirm where it is defined.
#ifdef WAVEFRONTSIZE64
#define WARP_SIZE 64U
#else
#define WARP_SIZE 32U
#endif

// try to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance
// this needs to be defines rather than queried to be used for __launch_bounds__
Expand Down
2 changes: 1 addition & 1 deletion llmc/matmul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,

const int block_size = deviceProp.maxThreadsPerMultiProcessor == 1536 ? 768 : 1024;

dim3 block_dim = {4, 8, (unsigned)block_size/WARP_SIZE};
dim3 block_dim = {4, WARP_SIZE/4, (unsigned)block_size/WARP_SIZE};
const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16
const int grid_size_x = CEIL_DIV(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16
const int grid_size_y = max(1, deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / (block_size * grid_size_x)); // full GPU!
Expand Down

0 comments on commit 85b3f13

Please sign in to comment.