dense.cu
#include <stdio.h>
#include <stdlib.h>
// CUDA runtime
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include "dense_help_func.hpp"
// compute the offset of element (row, col); in a row-major matrix, ld is the leading dimension (the row width)
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
// reinterpret four consecutive floats as a single float4 for vectorized 128-bit access
#define FETCH_FLOAT4(pointer) (reinterpret_cast<float4*>(&(pointer))[0])
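// Illustrative use (not in this kernel, which loads scalars): a line such as
//   FETCH_FLOAT4(Bs[i][col]) = FETCH_FLOAT4(B[OFFSET(row, col, N)]);
// would move four floats in one 128-bit transaction, provided both addresses are 16-byte aligned.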
template <
const int BLOCK_SIZE_M, // height of the C tile that each thread block computes
const int BLOCK_SIZE_K, // width of the A tile (and height of the B tile) that each thread block loads into shared memory
const int BLOCK_SIZE_N, // width of the C tile that each thread block computes
const int THREAD_SIZE_Y, // height of the C sub-tile that each thread computes
const int THREAD_SIZE_X, // width of the C sub-tile that each thread computes
const bool ENABLE_DOUBLE_BUFFER // whether to enable double buffering (not used by this kernel)
>
__global__ void MatrixMulCUDA6(
float * __restrict__ A,
float * __restrict__ B,
float * __restrict__ C,
const int M,
const int K,
const int N,
float alpha,
float beta
) {
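// Overall scheme: each thread block computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C.
// It walks over K in steps of BLOCK_SIZE_K, cooperatively staging the corresponding A and B
// tiles in shared memory, accumulates per-thread partial sums in registers, and finally
// writes C = alpha * A*B + beta * C, with bounds checks only for the edge blocks.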
// size of thread block
const int bszx = BLOCK_SIZE_N / THREAD_SIZE_X;
const int bszy = BLOCK_SIZE_M / THREAD_SIZE_Y;
const int THREAD_NUM_PER_BLOCK = bszy * bszx;
// thread id
const int tid = threadIdx.y * bszx + threadIdx.x;
// shared memory
__shared__ float As[BLOCK_SIZE_M][BLOCK_SIZE_K]; // avoid bank conflict
__shared__ float Bs[BLOCK_SIZE_K][BLOCK_SIZE_N];
// registers for C
float accum[THREAD_SIZE_Y][THREAD_SIZE_X] = {0};
// row and column within the tile that this thread is responsible for loading
const int A_TILE_ROW = tid / BLOCK_SIZE_K;
const int B_TILE_ROW = tid / BLOCK_SIZE_N;
const int A_TILE_COL = tid % BLOCK_SIZE_K;
const int B_TILE_COL = tid % BLOCK_SIZE_N;
// row stride each thread uses to cover multiple rows of a tile
const int A_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K;
const int B_TILE_ROW_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_N;
const int A_S = BLOCK_SIZE_M / THREAD_SIZE_Y;
const int B_S = BLOCK_SIZE_N / THREAD_SIZE_X;
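// A_S and B_S are the row/column strides of the interleaved thread-to-output mapping:
// each thread owns THREAD_SIZE_Y x THREAD_SIZE_X elements of the C tile that are spaced
// A_S rows and B_S columns apart, rather than forming a contiguous sub-block.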
// cannot be fully unrolled since K is not known at compile time
for (int tile_idx = 0 ; tile_idx < K ; tile_idx += BLOCK_SIZE_K) {
// load A from global memory to shared memory
// #pragma unroll
// for ( int i = A_TILE_ROW ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
// const int row = BLOCK_SIZE_M * blockIdx.y + i;
// const int col = A_TILE_COL + tile_idx;
// if (blockIdx.x == gridDim.x -1 || blockIdx.y == gridDim.y - 1) {
// As[i][A_TILE_COL] = row < M && col < K ? A[OFFSET(
// row, // row
// col, // col
// K )] : 0;
// } else {
// As[i][A_TILE_COL] = A[OFFSET(
// row, // row
// col, // col
// K )];
// }
// }
// // load B from global memory to shared memory
// #pragma unroll
// for ( int i = B_TILE_ROW ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
// const int row = tile_idx + i;
// const int col = B_TILE_COL + BLOCK_SIZE_N * blockIdx.x;
// if (blockIdx.x == gridDim.x -1 || blockIdx.y == gridDim.y - 1) {
// Bs[i][B_TILE_COL] = row < K && col < N ? B[OFFSET(
// row, // row
// col, // col
// N )] : 0;
// } else {
// Bs[i][B_TILE_COL] = B[OFFSET(
// row, // row
// col, // col
// N )];
// }
// }
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_M ; i += A_TILE_ROW_STRIDE) {
const int row = BLOCK_SIZE_M * blockIdx.y + i + A_TILE_ROW ;
const int col = A_TILE_COL + tile_idx;
if (blockIdx.x == gridDim.x -1 || blockIdx.y == gridDim.y - 1) {
As[i + A_TILE_ROW ][A_TILE_COL] = row < M && col < K ? A[OFFSET(
row, // row
col, // col
K )] : 0;
} else {
As[i + A_TILE_ROW ][A_TILE_COL] = A[OFFSET(
row, // row
col, // col
K )];
}
}
// load B from global memory to shared memory
#pragma unroll
for ( int i = 0 ; i < BLOCK_SIZE_K; i += B_TILE_ROW_STRIDE) {
const int row = tile_idx + i + B_TILE_ROW;
const int col = B_TILE_COL + BLOCK_SIZE_N * blockIdx.x;
if (blockIdx.x == gridDim.x -1 || blockIdx.y == gridDim.y - 1) {
Bs[i + B_TILE_ROW][B_TILE_COL] = row < K && col < N ? B[OFFSET(
row, // row
col, // col
N )] : 0;
} else {
Bs[i + B_TILE_ROW][B_TILE_COL] = B[OFFSET(
row, // row
col, // col
N )];
}
}
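// make sure the full A and B tiles are resident in shared memory before computing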
__syncthreads();
// compute partial products for this K tile and accumulate into registers
#pragma unroll
for (int k = 0; k < BLOCK_SIZE_K; ++ k) {
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
// accum[thread_y][thread_x] += frag_a[thread_y] * frag_b[thread_x];
accum[thread_y][thread_x] += As[thread_y * A_S + threadIdx.y][k] * Bs[k][thread_x * B_S + threadIdx.x];
}
}
}
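// wait until every thread has finished reading As/Bs before the next iteration overwrites them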
__syncthreads();
}
// store the accumulators back to C: C = alpha * accum + beta * C (bounds-checked only for edge blocks)
#pragma unroll
for (int thread_y = 0; thread_y < THREAD_SIZE_Y; ++thread_y) {
#pragma unroll
for (int thread_x = 0; thread_x < THREAD_SIZE_X; ++thread_x) {
const int row = BLOCK_SIZE_M * blockIdx.y + thread_y * A_S + threadIdx.y;
const int col = BLOCK_SIZE_N * blockIdx.x + thread_x * B_S + threadIdx.x;
if (blockIdx.x == gridDim.x -1 || blockIdx.y == gridDim.y - 1) {
if (row < M && col < N) {
C[OFFSET(row, col, N)] = C[OFFSET(row, col, N)] * beta + accum[thread_y][thread_x] * alpha;
}
} else {
C[OFFSET(row, col, N)] = C[OFFSET(row, col, N)] * beta + accum[thread_y][thread_x] * alpha;
}
}
}
}
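// ---------------------------------------------------------------------------
// Usage sketch (added for illustration, not part of the original kernel code):
// a minimal host-side driver showing one way this kernel might be launched.
// The tile sizes below are assumptions chosen so the thread block covers the
// shared-memory tiles exactly; the original repository configures them elsewhere,
// and error checking is omitted for brevity.
// ---------------------------------------------------------------------------
void launch_matrix_mul_example(const float* hA, const float* hB, float* hC,
                               int M, int K, int N, float alpha, float beta) {
    // illustrative tile configuration: 64x64 C tile, K depth of 8, 4x4 outputs per thread
    constexpr int BM = 64, BK = 8, BN = 64, TY = 4, TX = 4;

    float *dA, *dB, *dC;
    cudaMalloc(&dA, sizeof(float) * M * K);
    cudaMalloc(&dB, sizeof(float) * K * N);
    cudaMalloc(&dC, sizeof(float) * M * N);
    cudaMemcpy(dA, hA, sizeof(float) * M * K, cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB, sizeof(float) * K * N, cudaMemcpyHostToDevice);
    cudaMemcpy(dC, hC, sizeof(float) * M * N, cudaMemcpyHostToDevice); // C is read when beta != 0

    // one thread per TY x TX sub-tile of each BM x BN tile of C
    dim3 block(BN / TX, BM / TY);
    dim3 grid((N + BN - 1) / BN, (M + BM - 1) / BM);
    MatrixMulCUDA6<BM, BK, BN, TY, TX, false><<<grid, block>>>(dA, dB, dC, M, K, N, alpha, beta);
    cudaDeviceSynchronize();

    cudaMemcpy(hC, dC, sizeof(float) * M * N, cudaMemcpyDeviceToHost);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
}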