-
Notifications
You must be signed in to change notification settings - Fork 3
/
fast_gemv.cu
273 lines (247 loc) · 10.8 KB
/
fast_gemv.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <driver_functions.h>
#include <stdio.h>
#include "fast_gemv.cuh"
#include "utility.cuh"
///////////////////////////// NORMAL //////////////////////////////
// fp16 GEMV: res[row] = sum_i mat[row * n + i] * vec[i], accumulated in fp32
// and written back as half.
//
// Launch layout (as read by this kernel):
//   - blockDim.x threads (thread_per_block) cooperate on one row. Each thread
//     reads 8 halves (one float4) per iteration for (num_per_thread >> 3)
//     iterations; consecutive threads read consecutive float4 chunks.
//   - row = blockIdx.y * blockDim.y + threadIdx.y. There is no row bound
//     check, so the grid must not produce rows past the end of mat/res.
//   - blockDim.y <= SHARED_MEM_MAX_ROWS (size of the shared partial-sum buffer).
//   - warpReduceSum uses full-warp shuffles, so blockDim.x should be a power of
//     two <= WARP_SIZE or a multiple of WARP_SIZE.
// Preconditions (not checked here):
//   - mat and vec must be 16-byte aligned (they are read through float4).
//   - Only the first 8 * (n >> 3) columns are read (a trailing n % 8 is
//     ignored), and only chunks j < blockDim.x * (num_per_thread >> 3) are
//     visited. Presumably the host passes num_per_thread = n / blockDim.x;
//     TODO confirm against the launcher.
__global__ void gemv_fp16(half* mat, half* vec, half* res, unsigned int n,
                          unsigned int num_per_thread) {
  float sum = 0.0f;
  const unsigned int tid = threadIdx.x;
  const unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
  const unsigned int n8 = n >> 3;                    // row length in float4 chunks
  const unsigned int iters = num_per_thread >> 3;    // float4 chunks per thread
  const float4* mat4 = reinterpret_cast<const float4*>(mat);
  const float4* vec4 = reinterpret_cast<const float4*>(vec);

#pragma unroll
  for (unsigned int iter = 0; iter < iters; iter++) {
    const unsigned int j = tid + iter * blockDim.x;
    if (j < n8) {
      const float4 vec_val = vec4[j];
      const float4 mat_val = mat4[row * n8 + j];
      // View each 16-byte float4 as four packed half2 pairs.
      const half2* vec_h = reinterpret_cast<const half2*>(&vec_val);
      const half2* mat_h = reinterpret_cast<const half2*>(&mat_val);
#pragma unroll
      for (int k = 0; k < 4; k++) {
        sum += static_cast<float>(vec_h[k].x) * static_cast<float>(mat_h[k].x);
        sum += static_cast<float>(vec_h[k].y) * static_cast<float>(mat_h[k].y);
      }
    }
  }

  // Stage 1: reduce inside each warp; lane 0 ends up with the warp's partial.
  sum = warpReduceSum(sum, blockDim.x);
  if (blockDim.x <= WARP_SIZE) {
    // One warp covers the whole row, so lane 0 already holds the dot product.
    // blockDim.x is uniform, so either every thread returns here or none does
    // and the __syncthreads() below is reached by the whole block.
    if (tid == 0) {
      res[row] = __float2half(sum);
    }
    return;
  }

  // Stage 2: combine per-warp partials through shared memory (one slot per
  // warp, one row of slots per threadIdx.y).
  static __shared__ float warpLevelSums[SHARED_MEM_MAX_ROWS][WARP_SIZE];
  const unsigned int laneId = tid % WARP_SIZE;
  const unsigned int warpId = tid / WARP_SIZE;
  if (laneId == 0) warpLevelSums[threadIdx.y][warpId] = sum;
  __syncthreads();
  // Only slots of warps that exist are read; other lanes contribute 0.
  // 0.0f (not 0.0) keeps the selection in fp32.
  sum = (tid < blockDim.x / WARP_SIZE) ? warpLevelSums[threadIdx.y][laneId]
                                       : 0.0f;
  // Final reduce of the per-warp partials by the first warp.
  if (warpId == 0) sum = warpReduceSum(sum, blockDim.x / WARP_SIZE);
  if (tid == 0) {
    res[row] = __float2half(sum);
  }
}
///////////////////////////// QUANTIZED-INT8 //////////////////////////////
// int8-weight GEMV:
//   res[row] = scale * sum_i vec[i] * (mat[row * n + i] - zero_point)
// The sum is accumulated in fp32 and written back as half.
//
// Launch layout (as read by this kernel):
//   - blockDim.x threads cooperate on one row. Each thread handles 8 weights
//     plus 8 vector halves per iteration for (num_per_thread >> 3) iterations.
//   - row = blockIdx.y * blockDim.y + threadIdx.y, with no row bound check.
//   - blockDim.y <= SHARED_MEM_MAX_ROWS (size of the shared partial-sum buffer).
//   - warpReduceSum uses full-warp shuffles, so blockDim.x should be a power of
//     two <= WARP_SIZE or a multiple of WARP_SIZE.
// Preconditions (not checked here):
//   - Weights are loaded 8 at a time through half4, and each half-sized member
//     is reinterpreted as an int8_2. This assumes sizeof(half4) == 8 with x/y/z/w
//     halves; TODO confirm against the half4 definition.
//   - vec must be 16-byte aligned and mat 8-byte aligned.
//   - Only the first 8 * (n >> 3) columns are read (a trailing n % 8 is
//     ignored).
__global__ void gemv_quantized_int8(int8_t* mat, half* vec, half* res,
                                    unsigned int n, half scale, half zero_point,
                                    unsigned int num_per_thread) {
  float sum = 0.0f;
  const unsigned int tid = threadIdx.x;
  const unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
  const unsigned int n8 = n >> 3;                    // row length in 8-element chunks
  const unsigned int iters = num_per_thread >> 3;    // chunks per thread
  const half4* mat4 = reinterpret_cast<const half4*>(mat);
  const float4* vec4 = reinterpret_cast<const float4*>(vec);
  const float zero_point_f = static_cast<float>(zero_point);
  const float scale_f = static_cast<float>(scale);

#pragma unroll
  for (unsigned int iter = 0; iter < iters; iter++) {
    const unsigned int j = tid + iter * blockDim.x;
    if (j < n8) {
      const float4 vec_val = vec4[j];
      const half4 mat_val = mat4[row * n8 + j];
      // Eight fp16 inputs packed as four half2 pairs.
      const half2* vec_h = reinterpret_cast<const half2*>(&vec_val);
      // Eight int8 weights: each 2-byte member of half4 carries one int8_2.
      const int8_2* mat_h[4] = {(const int8_2*)&mat_val.x,
                                (const int8_2*)&mat_val.y,
                                (const int8_2*)&mat_val.z,
                                (const int8_2*)&mat_val.w};
#pragma unroll
      for (int k = 0; k < 4; k++) {
        sum += static_cast<float>(vec_h[k].x) *
               (static_cast<float>(mat_h[k]->x) - zero_point_f);
        sum += static_cast<float>(vec_h[k].y) *
               (static_cast<float>(mat_h[k]->y) - zero_point_f);
      }
    }
  }
  // The dequantization scale is linear, so it is applied once per thread.
  sum *= scale_f;

  // Stage 1: reduce inside each warp; lane 0 ends up with the warp's partial.
  sum = warpReduceSum(sum, blockDim.x);
  if (blockDim.x <= WARP_SIZE) {
    // Block-uniform branch: a single warp already produced the full result.
    if (tid == 0) {
      res[row] = __float2half(sum);
    }
    return;
  }

  // Stage 2: combine per-warp partials through shared memory.
  static __shared__ float warpLevelSums[SHARED_MEM_MAX_ROWS][WARP_SIZE];
  const unsigned int laneId = tid % WARP_SIZE;
  const unsigned int warpId = tid / WARP_SIZE;
  if (laneId == 0) warpLevelSums[threadIdx.y][warpId] = sum;
  __syncthreads();
  // Only slots of warps that exist are read; other lanes contribute 0.
  // 0.0f keeps the selection in fp32.
  sum = (tid < blockDim.x / WARP_SIZE) ? warpLevelSums[threadIdx.y][laneId]
                                       : 0.0f;
  // Final reduce of the per-warp partials by the first warp.
  if (warpId == 0) sum = warpReduceSum(sum, blockDim.x / WARP_SIZE);
  if (tid == 0) {
    res[row] = __float2half(sum);
  }
}
///////////////////////////// QUANTIZED-INT4 //////////////////////////////
// Adds the 8 products vec[i] * (mat[i] - zero_point_f) of one 8-element chunk
// to `sum`, in element order.
//   vec_val: eight fp16 values packed as four half2 pairs.
//   mat_val: eight 4-bit weights as four uint4_2, two per uint4_2, read through
//            getX()/getY(). Presumably unsigned nibbles; TODO confirm against
//            the uint4_2 definition.
static __device__ __forceinline__ void gemv_int4_accumulate_chunk(
    float& sum, const float4 vec_val, const uint4_2_4 mat_val,
    float zero_point_f) {
  const half2* vec_h = reinterpret_cast<const half2*>(&vec_val);
  const uint4_2* mat_h[4] = {(const uint4_2*)&mat_val.x,
                             (const uint4_2*)&mat_val.y,
                             (const uint4_2*)&mat_val.z,
                             (const uint4_2*)&mat_val.w};
#pragma unroll
  for (int k = 0; k < 4; k++) {
    sum += static_cast<float>(vec_h[k].x) *
           (static_cast<float>(mat_h[k]->getX()) - zero_point_f);
    sum += static_cast<float>(vec_h[k].y) *
           (static_cast<float>(mat_h[k]->getY()) - zero_point_f);
  }
}

// 4-bit-weight GEMV:
//   res[row] = scale * sum_i vec[i] * (mat[row, i] - zero_point)
// The sum is accumulated in fp32 and written back as half.
//
// Launch layout (as read by this kernel):
//   - blockDim.x threads cooperate on one row. Each iteration a thread handles
//     two adjacent 8-element chunks (16 elements). There are
//     (num_per_thread >> 4) iterations, so num_per_thread must be >= 16 for
//     any work to happen (earlier experiments suggest values >= 16 are used).
//   - row = blockIdx.y * blockDim.y + threadIdx.y, with no row bound check.
//   - blockDim.y <= SHARED_MEM_MAX_ROWS (size of the shared partial-sum buffer).
//   - warpReduceSum uses full-warp shuffles, so blockDim.x should be a power of
//     two <= WARP_SIZE or a multiple of WARP_SIZE.
// Preconditions (not checked here):
//   - vec must be 16-byte aligned. Each row of mat is n >> 3 uint4_2_4 chunks.
//   - Only the first 8 * (n >> 3) columns are ever read.
__global__ void gemv_quantized_int4(uint4_2* mat, half* vec, half* res,
                                    unsigned int n, half scale, half zero_point,
                                    unsigned int num_per_thread) {
  float sum = 0.0f;
  const unsigned int tid = threadIdx.x;
  const unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
  const unsigned int n8 = n >> 3;                    // row length in 8-element chunks
  const unsigned int iters = num_per_thread >> 4;    // 16 elements per iteration
  const uint4_2_4* mat4 = reinterpret_cast<const uint4_2_4*>(mat);
  const float4* vec4 = reinterpret_cast<const float4*>(vec);
  const float zero_point_f = static_cast<float>(zero_point);
  const float scale_f = static_cast<float>(scale);

#pragma unroll
  for (unsigned int iter = 0; iter < iters; iter++) {
    // Index of the first of this thread's two adjacent chunks.
    const unsigned int j = 2 * (tid + iter * blockDim.x);
    if (j < n8) {
      const float4 vec_lo = vec4[j];
      const uint4_2_4 mat_lo = mat4[row * n8 + j];
      gemv_int4_accumulate_chunk(sum, vec_lo, mat_lo, zero_point_f);
      // The second chunk exists only if j + 1 < n8. Without this check an
      // odd n8 would read one float4 past the end of vec, and into the next
      // row (or past the end) of mat.
      if (j + 1 < n8) {
        const float4 vec_hi = vec4[j + 1];
        const uint4_2_4 mat_hi = mat4[row * n8 + j + 1];
        gemv_int4_accumulate_chunk(sum, vec_hi, mat_hi, zero_point_f);
      }
    }
  }
  // The dequantization scale is linear, so it is applied once per thread.
  sum *= scale_f;

  // Stage 1: reduce inside each warp; lane 0 ends up with the warp's partial.
  sum = warpReduceSum(sum, blockDim.x);
  if (blockDim.x <= WARP_SIZE) {
    // Block-uniform branch: a single warp already produced the full result.
    if (tid == 0) {
      res[row] = __float2half(sum);
    }
    return;
  }

  // Stage 2: combine per-warp partials through shared memory.
  static __shared__ float warpLevelSums[SHARED_MEM_MAX_ROWS][WARP_SIZE];
  const unsigned int laneId = tid % WARP_SIZE;
  const unsigned int warpId = tid / WARP_SIZE;
  if (laneId == 0) warpLevelSums[threadIdx.y][warpId] = sum;
  __syncthreads();
  // Only slots of warps that exist are read; other lanes contribute 0.
  // 0.0f keeps the selection in fp32.
  sum = (tid < blockDim.x / WARP_SIZE) ? warpLevelSums[threadIdx.y][laneId]
                                       : 0.0f;
  // Final reduce of the per-warp partials by the first warp.
  if (warpId == 0) sum = warpReduceSum(sum, blockDim.x / WARP_SIZE);
  if (tid == 0) {
    res[row] = __float2half(sum);
  }
}
///////////////////////////// REDUCE SUM //////////////////////////////
// Sums `sum` over the first threadNum lanes of the calling warp with
// shuffle-down steps. Lane 0 receives the total; other lanes hold partial
// sums.
//
// Every shuffle uses the full-warp mask 0xffffffff, so all 32 lanes of the warp
// must exist and call this with the same threadNum.
//
// A step with offset `offset` runs whenever threadNum > offset:
//   - For a power-of-two threadNum, and for any threadNum >= 32, this is the
//     usual halving schedule.
//   - For a non-power-of-two threadNum < 32 (e.g. 3 or 6 per-warp partials),
//     lane 0 gets the sum of lanes [0, next_pow2(threadNum)). This is the
//     correct total as long as lanes [threadNum, 32) hold 0.
//   - A schedule gated on threadNum >= 2 * offset would instead drop the lanes
//     above the largest power of two <= threadNum.
__device__ __forceinline__ float warpReduceSum(float sum,
                                               unsigned int threadNum) {
#pragma unroll
  for (unsigned int offset = 16; offset > 0; offset >>= 1) {
    // threadNum is uniform across the warp, so every lane takes the same branch
    // and the full-mask shuffle is executed by all 32 lanes.
    if (threadNum > offset) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
  }
  return sum;
}