// rms_norm.cu -- RMS normalization CUDA kernels (fp32 / fp16) with PyTorch bindings.
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <vector>
#include <algorithm>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <torch/types.h>
#include <torch/extension.h>
#define WARP_SIZE 32
#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
// -------------------------------------- FP32 --------------------------------------
// Warp-level sum (fp32) using an XOR-shuffle butterfly.
// Lanes are combined in aligned groups of kWarpSize; after the last round each
// lane of a group holds that group's sum. With kWarpSize == 1 the loop body
// never runs and `val` is returned unchanged.
// The shuffle uses the full-warp mask 0xffffffff.
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f32(float val) {
  // Exchange distance halves every round: kWarpSize/2, ..., 2, 1.
  #pragma unroll
  for (int lane_mask = kWarpSize / 2; lane_mask > 0; lane_mask /= 2) {
    const float peer = __shfl_xor_sync(0xffffffff, val, lane_mask);
    val = val + peer;
  }
  return val;
}
// Block-wide sum (fp32) for a 1D thread block; helper for Layer/RMS Norm,
// Softmax and similar row reductions.
// NUM_THREADS is the compile-time block size (<= 1024, hence <= 32 warps).
// Two stages: every warp reduces its own values, lane 0 of each warp parks the
// partial in shared memory, then each warp reduces those partials again.
// After the call, lanes 0..NUM_WARPS-1 of every warp hold the block total
// (the kernels in this file consume it from thread 0).
template<const int NUM_THREADS=256>
__device__ __forceinline__ float block_reduce_sum_f32(float val) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  static __shared__ float warp_sums[NUM_WARPS];
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;

  // Stage 1: reduce inside each warp and publish one partial per warp.
  const float warp_total = warp_reduce_sum_f32<WARP_SIZE>(val);
  if (lane_id == 0) {
    warp_sums[warp_id] = warp_total;
  }
  // All partials must be visible before any warp reads them back.
  __syncthreads();

  // Stage 2: the first NUM_WARPS lanes pick up one partial each, others add 0.
  float partial = 0.0f;
  if (lane_id < NUM_WARPS) {
    partial = warp_sums[lane_id];
  }
  return warp_reduce_sum_f32<NUM_WARPS>(partial);
}
// RMS Norm, fp32, one element per thread.
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous floats; g: scalar scale; eps = 1e-5.
// The dispatch macro in this file launches it with grid(N), block(K), so each
// block covers one row of K elements (N = batch_size * seq_len,
// K = hidden_size, K <= 1024). NUM_THREADS must equal blockDim.x.
template<const int NUM_THREADS=256>
__global__ void rms_norm_f32_kernel(float* x, float* y, float g, int N, int K) {
  const float epsilon = 1e-5f;
  __shared__ float s_variance; // 1/rms of this block's row, written by thread 0
  const int tid = threadIdx.x;
  const int idx = blockIdx.x * blockDim.x + tid;
  const bool in_range = idx < N * K;

  // Load the element once; out-of-range threads contribute 0 to the sum.
  float value = 0.0f;
  if (in_range) {
    value = x[idx];
  }

  const float sum_sq = block_reduce_sum_f32<NUM_THREADS>(value * value);
  if (tid == 0) {
    s_variance = rsqrtf(sum_sq / (float) K + epsilon);
  }
  // Every thread needs s_variance before scaling its element.
  __syncthreads();

  if (in_range) {
    y[idx] = (value * s_variance) * g;
  }
}
// RMS Norm, fp32, four elements (one float4) per thread.
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous floats, accessed as float4 (16-byte aligned, K % 4 == 0);
// g: scalar scale; eps = 1e-5.
// The dispatch macro in this file launches it with grid(N), block(K/4), so
// each block covers one row of K elements. NUM_THREADS must equal blockDim.x.
template<const int NUM_THREADS=256/4>
__global__ void rms_norm_f32x4_kernel(float* x, float* y, float g, int N, int K) {
  int tid = threadIdx.x; // 0..K/4-1
  int bid = blockIdx.x; // 0..N-1
  int idx = (bid * blockDim.x + threadIdx.x) * 4;
  const float epsilon = 1e-5f;
  __shared__ float s_variance; // shared within block
  const bool in_range = idx < N * K;
  // Guard the vector load with the same bound as the store; out-of-range
  // threads use zeros so they add nothing to the reduction.
  float4 reg_x = in_range ? FLOAT4(x[idx]) : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
  float variance = reg_x.x * reg_x.x + reg_x.y * reg_x.y
                 + reg_x.z * reg_x.z + reg_x.w * reg_x.w;
  variance = block_reduce_sum_f32<NUM_THREADS>(variance);
  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();
  float4 reg_y;
  reg_y.x = reg_x.x * s_variance * g;
  reg_y.y = reg_x.y * s_variance * g;
  reg_y.z = reg_x.z * s_variance * g;
  reg_y.w = reg_x.w * s_variance * g;
  if (in_range) FLOAT4(y[idx]) = reg_y;
}
// -------------------------------------- FP16 --------------------------------------
// Warp-level sum for half values, accumulated in half precision.
// XOR-shuffle butterfly over aligned groups of kWarpSize lanes; every lane of
// a group ends with that group's sum. Uses the full-warp mask 0xffffffff.
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ half warp_reduce_sum_f16_f16(half val) {
  // Exchange distance halves every round: kWarpSize/2, ..., 2, 1.
  #pragma unroll
  for (int lane_mask = kWarpSize / 2; lane_mask > 0; lane_mask /= 2) {
    const half peer = __shfl_xor_sync(0xffffffff, val, lane_mask);
    val = val + peer;
  }
  return val;
}
// Warp-level sum of half inputs, accumulated in fp32.
// The input is widened to float once, then reduced with an XOR-shuffle
// butterfly over aligned groups of kWarpSize lanes (full-warp mask).
template<const int kWarpSize = WARP_SIZE>
__device__ __forceinline__ float warp_reduce_sum_f16_f32(half val) {
  float acc = __half2float(val);
  #pragma unroll
  for (int lane_mask = kWarpSize / 2; lane_mask > 0; lane_mask /= 2) {
    const float peer = __shfl_xor_sync(0xffffffff, acc, lane_mask);
    acc = acc + peer;
  }
  return acc;
}
// Block-wide sum of half values, accumulated in half precision throughout.
// NUM_THREADS is the compile-time block size (<= 1024, hence <= 32 warps).
// Each warp reduces its values, lane 0 stores the partial in shared memory,
// then every warp reduces the partials; lanes 0..NUM_WARPS-1 end with the
// block total.
template<const int NUM_THREADS=256>
__device__ half block_reduce_sum_f16_f16(half val) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  static __shared__ half warp_sums[NUM_WARPS];
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;

  // Stage 1: per-warp reduction in half, one partial per warp.
  const half warp_total = warp_reduce_sum_f16_f16<WARP_SIZE>(val);
  if (lane_id == 0) {
    warp_sums[warp_id] = warp_total;
  }
  // Partials must be visible to all warps before they are read back.
  __syncthreads();

  // Stage 2: the first NUM_WARPS lanes take one partial each, the rest add 0.
  half partial = __float2half(0.0f);
  if (lane_id < NUM_WARPS) {
    partial = warp_sums[lane_id];
  }
  return warp_reduce_sum_f16_f16<NUM_WARPS>(partial); // half
}
// Block-wide sum of half inputs, accumulated in fp32.
// NUM_THREADS is the compile-time block size (<= 1024, hence <= 32 warps).
// Each warp reduces in float, lane 0 stores the partial in shared memory,
// then every warp reduces the partials; lanes 0..NUM_WARPS-1 end with the
// block total.
template<const int NUM_THREADS=256>
__device__ float block_reduce_sum_f16_f32(half val) {
  constexpr int NUM_WARPS = (NUM_THREADS + WARP_SIZE - 1) / WARP_SIZE;
  static __shared__ float warp_sums[NUM_WARPS];
  const int warp_id = threadIdx.x / WARP_SIZE;
  const int lane_id = threadIdx.x % WARP_SIZE;

  // Stage 1: per-warp reduction in float, one partial per warp.
  const float warp_total = warp_reduce_sum_f16_f32<WARP_SIZE>(val);
  if (lane_id == 0) {
    warp_sums[warp_id] = warp_total;
  }
  // Partials must be visible to all warps before they are read back.
  __syncthreads();

  // Stage 2: the first NUM_WARPS lanes take one partial each, the rest add 0.
  float partial = 0.0f;
  if (lane_id < NUM_WARPS) {
    partial = warp_sums[lane_id];
  }
  return warp_reduce_sum_f32<NUM_WARPS>(partial); // float
}
// RMS Norm, fp16 storage, fp16 accumulation, one element per thread.
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous halfs; g: scalar scale (converted to half); eps = 1e-5.
// The dispatch macro in this file launches it with grid(N), block(K), so each
// block covers one row of K elements. NUM_THREADS must equal blockDim.x.
// NOTE: the sum of squares is accumulated in half, so it loses precision and
// can overflow for rows with large values.
template<const int NUM_THREADS=256>
__global__ void rms_norm_f16_f16_kernel(half* x, half* y, float g, int N, int K) {
  int tid = threadIdx.x; // 0..K-1
  int bid = blockIdx.x; // 0..N-1
  int idx = bid * blockDim.x + threadIdx.x;
  const half epsilon = __float2half(1e-5f);
  const half g_ = __float2half(g);
  const half K_ = __int2half_rn(K);
  __shared__ half s_variance; // shared within block
  half value = (idx < N * K) ? x[idx] : __float2half(0.0f); // load once only
  half variance = value * value;
  variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
  // epsilon is added to the mean of squares (not to K), matching the fp32
  // kernels; this keeps rsqrt finite for an all-zero row.
  if (tid == 0) s_variance = hrsqrt(variance / K_ + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();
  if (idx < N * K) y[idx] = (value * s_variance) * g_;
}
// RMS Norm, fp16 storage, fp16 accumulation, two elements (one half2) per thread.
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous halfs accessed as half2 (K % 2 == 0); g: scalar scale
// (converted to half); eps = 1e-5.
// The dispatch macro in this file launches it with grid(N), block(K/2), so
// each block covers one row of K elements. NUM_THREADS must equal blockDim.x.
// NOTE: the sum of squares is accumulated in half, so it loses precision and
// can overflow for rows with large values.
template<const int NUM_THREADS=256>
__global__ void rms_norm_f16x2_f16_kernel(half* x, half* y, float g, int N, int K) {
  int tid = threadIdx.x; // 0..K/2-1
  int bid = blockIdx.x; // 0..N-1
  int idx = (bid * blockDim.x + threadIdx.x) * 2;
  const half epsilon = __float2half(1e-5f);
  const half g_ = __float2half(g);
  const half K_ = __int2half_rn(K);
  __shared__ half s_variance; // shared within block
  const bool in_range = idx < N * K;
  // Guard the vector load with the same bound as the store; out-of-range
  // threads use zeros so they add nothing to the reduction.
  half2 reg_x = in_range ? HALF2(x[idx]) : __float2half2_rn(0.0f);
  half variance = reg_x.x * reg_x.x + reg_x.y * reg_x.y;
  variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
  // epsilon is added to the mean of squares (not to K), matching the fp32
  // kernels; this keeps rsqrt finite for an all-zero row.
  if (tid == 0) s_variance = hrsqrt(variance / K_ + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();
  half2 reg_y;
  reg_y.x = reg_x.x * s_variance * g_;
  reg_y.y = reg_x.y * s_variance * g_;
  if (in_range) HALF2(y[idx]) = reg_y;
}
// Helper macros for the fp16x8 kernels. They expand inside a kernel body and
// rely on the locals `idx`, `N`, `K` and `s_variance` being in scope.

// Sum of squares of the two lanes of `reg` (half math); 0 when element
// idx + i lies outside the N*K buffer.
#define HALF2_VARIANCE(reg, i) \
  (((idx + (i)) < N * K) ? ((reg).x * (reg).x + (reg).y * (reg).y) : __float2half(0.0f))
// Same as HALF2_VARIANCE, for a float2 register (float math).
#define FLOAT2_VARIANCE(reg, i) \
  (((idx + (i)) < N * K) ? ((reg).x * (reg).x + (reg).y * (reg).y) : 0.0f)
// reg_y = reg_x * s_variance * g for both lanes.
// Wrapped in do { } while (0) so the two assignments act as one statement,
// e.g. under an unbraced `if`; invoke with a trailing semicolon.
#define HALF2_RMS_NORM(reg_y, reg_x, g) \
  do { \
    (reg_y).x = (reg_x).x * s_variance * (g); \
    (reg_y).y = (reg_x).y * s_variance * (g); \
  } while (0)
// float2 counterpart of HALF2_RMS_NORM.
#define FLOAT2_RMS_NORM(reg_y, reg_x, g) \
  do { \
    (reg_y).x = (reg_x).x * s_variance * (g); \
    (reg_y).y = (reg_x).y * s_variance * (g); \
  } while (0)
// RMS Norm, fp16 storage, fp16 accumulation, eight elements per thread
// (four half2 loads/stores, manually unrolled).
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous halfs accessed as half2; g: scalar scale (converted to
// half); eps = 1e-5.
// Expected launch layout: grid(N), block(K/8), so each block covers one row of
// K elements. NUM_THREADS must equal blockDim.x.
// NOTE: the sum of squares is accumulated in half, so it loses precision and
// can overflow for rows with large values.
template<const int NUM_THREADS=256>
__global__ void rms_norm_f16x8_f16_kernel(half* x, half* y, float g, int N, int K) {
  int tid = threadIdx.x; // 0..K/8-1
  int bid = blockIdx.x; // 0..N-1
  int idx = (bid * blockDim.x + threadIdx.x) * 8;
  const half epsilon = __float2half(1e-5f);
  const half g_ = __float2half(g);
  const half K_ = __int2half_rn(K);
  __shared__ half s_variance; // shared within block
  // Manual unroll: each thread handles 8 consecutive halfs as 4 half2 values.
  // Author's rationale for fp16x8 per thread (cache-line reuse):
  // https://zhuanlan.zhihu.com/p/641639133
  // Each load is guarded with the same bound as its store; out-of-range pairs
  // are zero so they add nothing to the reduction.
  const half2 zero2 = __float2half2_rn(0.0f);
  half2 reg_x_0 = ((idx + 0) < N * K) ? HALF2(x[idx + 0]) : zero2;
  half2 reg_x_1 = ((idx + 2) < N * K) ? HALF2(x[idx + 2]) : zero2;
  half2 reg_x_2 = ((idx + 4) < N * K) ? HALF2(x[idx + 4]) : zero2;
  half2 reg_x_3 = ((idx + 6) < N * K) ? HALF2(x[idx + 6]) : zero2;
  half variance = HALF2_VARIANCE(reg_x_0, 0);
  variance += HALF2_VARIANCE(reg_x_1, 2);
  variance += HALF2_VARIANCE(reg_x_2, 4);
  variance += HALF2_VARIANCE(reg_x_3, 6);
  variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
  // epsilon is added to the mean of squares (not to K), matching the fp32
  // kernels; this keeps rsqrt finite for an all-zero row.
  if (tid == 0) s_variance = hrsqrt(variance / K_ + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();
  // manual unroll
  half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
  HALF2_RMS_NORM(reg_y_0, reg_x_0, g_);
  HALF2_RMS_NORM(reg_y_1, reg_x_1, g_);
  HALF2_RMS_NORM(reg_y_2, reg_x_2, g_);
  HALF2_RMS_NORM(reg_y_3, reg_x_3, g_);
  if ((idx + 0) < N * K) { HALF2(y[idx + 0]) = reg_y_0; }
  if ((idx + 2) < N * K) { HALF2(y[idx + 2]) = reg_y_1; }
  if ((idx + 4) < N * K) { HALF2(y[idx + 4]) = reg_y_2; }
  if ((idx + 6) < N * K) { HALF2(y[idx + 6]) = reg_y_3; }
}
// RMS Norm, fp16 storage, fp32 accumulation, eight elements per thread
// (four half2 loads/stores, manually unrolled, math done in float).
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous halfs accessed as half2; g: scalar scale; eps = 1e-5.
// Expected launch layout: grid(N), block(K/8), so each block covers one row of
// K elements. NUM_THREADS must equal blockDim.x.
template<const int NUM_THREADS=256>
__global__ void rms_norm_f16x8_f32_kernel(half* x, half* y, float g, int N, int K) {
  int tid = threadIdx.x; // 0..K/8-1
  int bid = blockIdx.x; // 0..N-1
  int idx = (bid * blockDim.x + threadIdx.x) * 8;
  const float epsilon = 1e-5f;
  __shared__ float s_variance; // shared within block
  // Manual unroll: each thread handles 8 consecutive halfs as 4 half2 values,
  // widened to float2 right after the load.
  // Author's rationale for fp16x8 per thread (cache-line reuse):
  // https://zhuanlan.zhihu.com/p/641639133
  // Each load is guarded with the same bound as its store; out-of-range pairs
  // are zero so they add nothing to the reduction.
  const float2 zero2 = make_float2(0.0f, 0.0f);
  float2 reg_x_0 = ((idx + 0) < N * K) ? __half22float2(HALF2(x[idx + 0])) : zero2;
  float2 reg_x_1 = ((idx + 2) < N * K) ? __half22float2(HALF2(x[idx + 2])) : zero2;
  float2 reg_x_2 = ((idx + 4) < N * K) ? __half22float2(HALF2(x[idx + 4])) : zero2;
  float2 reg_x_3 = ((idx + 6) < N * K) ? __half22float2(HALF2(x[idx + 6])) : zero2;
  float variance = FLOAT2_VARIANCE(reg_x_0, 0);
  variance += FLOAT2_VARIANCE(reg_x_1, 2);
  variance += FLOAT2_VARIANCE(reg_x_2, 4);
  variance += FLOAT2_VARIANCE(reg_x_3, 6);
  variance = block_reduce_sum_f32<NUM_THREADS>(variance);
  // epsilon is added to the mean of squares (not to K), matching
  // rms_norm_f32_kernel; this keeps rsqrt finite for an all-zero row.
  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();
  // manual unroll
  float2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
  FLOAT2_RMS_NORM(reg_y_0, reg_x_0, g);
  FLOAT2_RMS_NORM(reg_y_1, reg_x_1, g);
  FLOAT2_RMS_NORM(reg_y_2, reg_x_2, g);
  FLOAT2_RMS_NORM(reg_y_3, reg_x_3, g);
  if ((idx + 0) < N * K) { HALF2(y[idx + 0]) = __float22half2_rn(reg_y_0); }
  if ((idx + 2) < N * K) { HALF2(y[idx + 2]) = __float22half2_rn(reg_y_1); }
  if ((idx + 4) < N * K) { HALF2(y[idx + 4]) = __float22half2_rn(reg_y_2); }
  if ((idx + 6) < N * K) { HALF2(y[idx + 6]) = __float22half2_rn(reg_y_3); }
}
// RMS Norm, fp16 storage, fp32 accumulation, one element per thread.
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous halfs; g: scalar scale; eps = 1e-5.
// The dispatch macro in this file launches it with grid(N), block(K), so each
// block covers one row of K elements. NUM_THREADS must equal blockDim.x.
template<const int NUM_THREADS=256>
__global__ void rms_norm_f16_f32_kernel(half* x, half* y, float g, int N, int K) {
  int tid = threadIdx.x; // 0..K-1
  int bid = blockIdx.x; // 0..N-1
  int idx = bid * blockDim.x + threadIdx.x;
  const float epsilon = 1e-5f;
  __shared__ float s_variance; // shared within block
  float value = (idx < N * K) ? __half2float(x[idx]) : 0.0f; // load once only
  float variance = value * value;
  variance = block_reduce_sum_f32<NUM_THREADS>(variance);
  // epsilon is added to the mean of squares (not to K), matching
  // rms_norm_f32_kernel; this keeps rsqrt finite for an all-zero row.
  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();
  if (idx < N * K) {
    y[idx] = __float2half((value * s_variance) * g);
  }
}
// RMS Norm, fp16 storage, fp16 accumulation, eight elements per thread moved
// with a single 128-bit load and a single 128-bit store.
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous halfs, 16-byte aligned, K % 8 == 0; g: scalar scale
// (converted to half); eps = 1e-5.
// The dispatch macro in this file launches it with grid(N), block(K/8), so
// each block covers one row of K elements. NUM_THREADS must equal blockDim.x.
// NOTE: the sum of squares is accumulated in half, so it loses precision and
// can overflow for rows with large values.
template<const int NUM_THREADS=256>
__global__ void rms_norm_f16x8_pack_f16_kernel(half* x, half* y, float g, int N, int K) {
  int tid = threadIdx.x; // 0..K/8-1
  int bid = blockIdx.x; // 0..N-1
  int idx = (bid * blockDim.x + threadIdx.x) * 8;
  const half epsilon = __float2half(1e-5f);
  const half g_ = __float2half(g);
  const half K_ = __int2half_rn(K);
  const half z_ = __float2half(0.0f);
  __shared__ half s_variance; // shared within block
  // Per-thread staging buffers, 8 x 16 bits = 128 bits each. They are accessed
  // through a float4 reinterpretation, so force 16-byte alignment.
  __align__(16) half pack_x[8];
  __align__(16) half pack_y[8];
  // The 128-bit load uses the same bound as the 128-bit store below; a thread
  // whose 8-element window is not fully inside the buffer works on zeros.
  const bool in_range = (idx + 7) < N * K;
  if (in_range) {
    LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
  } else {
    #pragma unroll
    for (int i = 0; i < 8; ++i) {
      pack_x[i] = z_;
    }
  }
  half variance = z_;
  #pragma unroll
  for (int i = 0; i < 8; ++i) {
    variance += ((idx + i) < N * K ? pack_x[i] * pack_x[i] : z_);
  }
  variance = block_reduce_sum_f16_f16<NUM_THREADS>(variance);
  // epsilon is added to the mean of squares (not to K), matching the fp32
  // kernels; this keeps rsqrt finite for an all-zero row.
  if (tid == 0) s_variance = hrsqrt(variance / K_ + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();
  #pragma unroll
  for (int i = 0; i < 8; ++i) {
    pack_y[i] = pack_x[i] * s_variance * g_;
  }
  // reinterpret as float4 and store 128 bits in 1 memory issue.
  if (in_range) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
  // TODO: support non 8-multiple K here
}
// RMS Norm, fp16 storage, fp32 accumulation, eight elements per thread moved
// with a single 128-bit load and a single 128-bit store.
//   y = x * rsqrt(sum(x^2) / K + eps) * g, computed per thread block.
// x, y: N*K contiguous halfs, 16-byte aligned, K % 8 == 0; g: scalar scale;
// eps = 1e-5.
// The dispatch macro in this file launches it with grid(N), block(K/8), so
// each block covers one row of K elements. NUM_THREADS must equal blockDim.x.
template<const int NUM_THREADS=256>
__global__ void rms_norm_f16x8_pack_f32_kernel(half* x, half* y, float g, int N, int K) {
  int tid = threadIdx.x; // 0..K/8-1
  int bid = blockIdx.x; // 0..N-1
  int idx = (bid * blockDim.x + threadIdx.x) * 8;
  const float epsilon = 1e-5f;
  __shared__ float s_variance; // shared within block
  // Per-thread staging buffers, 8 x 16 bits = 128 bits each. They are accessed
  // through float4 / half2 reinterpretation, so force 16-byte alignment.
  __align__(16) half pack_x[8];
  __align__(16) half pack_y[8];
  // The 128-bit load uses the same bound as the 128-bit store below; a thread
  // whose 8-element window is not fully inside the buffer works on zeros.
  const bool in_range = (idx + 7) < N * K;
  if (in_range) {
    LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]); // load 128 bits
  } else {
    #pragma unroll
    for (int i = 0; i < 8; ++i) {
      pack_x[i] = __float2half(0.0f);
    }
  }
  float variance = 0.0f;
  #pragma unroll
  for (int i = 0; i < 8; ++i) {
    float v = __half2float(pack_x[i]);
    variance += ((idx + i) < N * K ? v * v : 0.0f);
  }
  variance = block_reduce_sum_f32<NUM_THREADS>(variance);
  // epsilon is added to the mean of squares (not to K), matching
  // rms_norm_f32_kernel; this keeps rsqrt finite for an all-zero row.
  if (tid == 0) s_variance = rsqrtf(variance / (float) K + epsilon);
  // wait for s_variance in shared memory to be ready for all threads
  __syncthreads();
  // Scale in float, two elements at a time, and round back to half.
  #pragma unroll
  for (int i = 0; i < 8; i += 2) {
    float2 v2 = __half22float2(HALF2(pack_x[i]));
    float2 y2 = {v2.x * s_variance * g, v2.y * s_variance * g};
    HALF2(pack_y[i]) = __float22half2_rn(y2);
  }
  // reinterpret as float4 and store 128 bits in 1 memory issue.
  if (in_range) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
  // TODO: support non 8-multiple K here
}
// --------------------- PyTorch bindings for custom kernel -----------------------
// Turns its argument into a string literal.
#define STRINGFY(str) #str
// Registers `func` on the pybind11 module `m`, using the function's own name
// as both the Python-visible name and the docstring. The expansion already
// ends with ';', so call sites in this file omit the trailing semicolon.
#define TORCH_BINDING_COMMON_EXTENSION(func) \
m.def(STRINGFY(func), &func, STRINGFY(func));
// Throws std::runtime_error when tensor T's dtype differs from th_type; the
// tensor options are printed to stdout first. Expands to a braced `if`, so
// call sites in this file omit the trailing semicolon.
#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
if(((T).options().dtype() != (th_type))) { \
std::cout << "Tensor Info:" << (T).options() << std::endl; \
throw std::runtime_error("values must be "#th_type); \
}
// Throws std::runtime_error unless T1 and T2 have the same rank and the same
// size in every dimension. The rank check is a real runtime check rather than
// an assert, so it is still enforced when NDEBUG is defined. Expands to braced
// statements, so call sites in this file omit the trailing semicolon.
#define CHECK_TORCH_TENSOR_SHAPE(T1, T2) \
  if ((T1).dim() != (T2).dim()) { \
    throw std::runtime_error("Tensor dim mismatch!"); \
  } \
  for (int i = 0; i < (T1).dim(); ++i) { \
    if ((T2).size(i) != (T1).size(i)) { \
      throw std::runtime_error("Tensor size mismatch!"); \
    } \
  }
// Launches rms_norm_f32_kernel with NUM_THREADS = K (K must be a literal).
// Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F32_KERNEL(K) \
  rms_norm_f32_kernel<(K)><<<grid, block>>>( \
      reinterpret_cast<float*>(x.data_ptr()), \
      reinterpret_cast<float*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K threads) and `grid` (N blocks) in the caller's scope and
// launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F32_KERNEL(N, K) \
  dim3 block((K)); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F32_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F32_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F32_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F32_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F32_KERNEL(1024) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/128/256/512/1024"); \
      break; \
  }
// Launches rms_norm_f32x4_kernel with NUM_THREADS = K/4 (K must be a literal).
// Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F32x4_KERNEL(K) \
  rms_norm_f32x4_kernel<(K)/4><<<grid, block>>>( \
      reinterpret_cast<float*>(x.data_ptr()), \
      reinterpret_cast<float*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K/4 threads) and `grid` (N blocks) in the caller's scope
// and launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F32x4_KERNEL(N, K) \
  dim3 block((K)/4); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F32x4_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F32x4_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F32x4_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F32x4_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F32x4_KERNEL(1024) break; \
    case 2048: LANUCH_RMS_NORM_F32x4_KERNEL(2048) break; \
    case 4096: LANUCH_RMS_NORM_F32x4_KERNEL(4096) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/.../512/1024*4"); \
      break; \
  }
// Shared argument validation for the fp32 entry points.
// Throws std::runtime_error unless x and y are float32, identically shaped,
// 2D, contiguous CUDA tensors (the kernels index them as flat N*K buffers).
static void rms_norm_f32_check_inputs(const torch::Tensor& x, const torch::Tensor& y) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kFloat32)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kFloat32)
  CHECK_TORCH_TENSOR_SHAPE(x, y)
  if (x.dim() != 2) {
    throw std::runtime_error("x and y must be 2D tensors of shape (N, K)");
  }
  if (!x.is_cuda() || !y.is_cuda()) {
    throw std::runtime_error("x and y must be CUDA tensors");
  }
  if (!x.is_contiguous() || !y.is_contiguous()) {
    throw std::runtime_error("x and y must be contiguous");
  }
}

// Surfaces a failed kernel launch as std::runtime_error instead of leaving the
// error pending for an unrelated later CUDA call.
static void rms_norm_f32_check_launch() {
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
}

// y = rms_norm(x) * g, fp32, one element per thread.
// x, y: float32 CUDA tensors of shape (N, K); supported K is listed in
// DISPATCH_RMS_NORM_F32_KERNEL. Throws std::runtime_error on invalid
// arguments or a failed launch.
void rms_norm_f32(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f32_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return; // no rows; a grid of 0 blocks is an invalid launch
  DISPATCH_RMS_NORM_F32_KERNEL(N, K)
  rms_norm_f32_check_launch();
}

// y = rms_norm(x) * g, fp32, four elements (float4) per thread.
// x, y: float32 CUDA tensors of shape (N, K); supported K is listed in
// DISPATCH_RMS_NORM_F32x4_KERNEL. Throws std::runtime_error on invalid
// arguments or a failed launch.
void rms_norm_f32x4(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f32_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return; // no rows; a grid of 0 blocks is an invalid launch
  DISPATCH_RMS_NORM_F32x4_KERNEL(N, K)
  rms_norm_f32_check_launch();
}
// fp16
// Launches rms_norm_f16_f16_kernel with NUM_THREADS = K (K must be a literal).
// Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F16F16_KERNEL(K) \
  rms_norm_f16_f16_kernel<(K)><<<grid, block>>>( \
      reinterpret_cast<half*>(x.data_ptr()), \
      reinterpret_cast<half*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K threads) and `grid` (N blocks) in the caller's scope and
// launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F16F16_KERNEL(N, K) \
  dim3 block((K)); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F16F16_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F16F16_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F16F16_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F16F16_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F16F16_KERNEL(1024) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/128/256/512/1024"); \
      break; \
  }
// Launches rms_norm_f16_f32_kernel with NUM_THREADS = K (K must be a literal).
// Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F16F32_KERNEL(K) \
  rms_norm_f16_f32_kernel<(K)><<<grid, block>>>( \
      reinterpret_cast<half*>(x.data_ptr()), \
      reinterpret_cast<half*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K threads) and `grid` (N blocks) in the caller's scope and
// launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F16F32_KERNEL(N, K) \
  dim3 block((K)); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F16F32_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F16F32_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F16F32_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F16F32_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F16F32_KERNEL(1024) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/128/256/512/1024"); \
      break; \
  }
// Launches rms_norm_f16x2_f16_kernel with NUM_THREADS = K/2 (K must be a
// literal). Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F16x2F16_KERNEL(K) \
  rms_norm_f16x2_f16_kernel<(K)/2><<<grid, block>>>( \
      reinterpret_cast<half*>(x.data_ptr()), \
      reinterpret_cast<half*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K/2 threads) and `grid` (N blocks) in the caller's scope
// and launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F16x2F16_KERNEL(N, K) \
  dim3 block((K)/2); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F16x2F16_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F16x2F16_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F16x2F16_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F16x2F16_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F16x2F16_KERNEL(1024) break; \
    case 2048: LANUCH_RMS_NORM_F16x2F16_KERNEL(2048) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/128/.../1024*2"); \
      break; \
  }
// Launches rms_norm_f16x8_f16_kernel with NUM_THREADS = K/8 (K must be a
// literal). Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F16x8F16_KERNEL(K) \
  rms_norm_f16x8_f16_kernel<(K)/8><<<grid, block>>>( \
      reinterpret_cast<half*>(x.data_ptr()), \
      reinterpret_cast<half*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K/8 threads) and `grid` (N blocks) in the caller's scope
// and launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F16x8F16_KERNEL(N, K) \
  dim3 block((K)/8); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F16x8F16_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F16x8F16_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F16x8F16_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F16x8F16_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F16x8F16_KERNEL(1024) break; \
    case 2048: LANUCH_RMS_NORM_F16x8F16_KERNEL(2048) break; \
    case 4096: LANUCH_RMS_NORM_F16x8F16_KERNEL(4096) break; \
    case 8192: LANUCH_RMS_NORM_F16x8F16_KERNEL(8192) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/128/.../1024*8"); \
      break; \
  }
// Launches rms_norm_f16x8_f32_kernel (fp32 accumulation) with
// NUM_THREADS = K/8 (K must be a literal). This previously launched the
// fp16-accumulate kernel rms_norm_f16x8_f16_kernel by mistake.
// Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F16x8F32_KERNEL(K) \
  rms_norm_f16x8_f32_kernel<(K)/8><<<grid, block>>>( \
      reinterpret_cast<half*>(x.data_ptr()), \
      reinterpret_cast<half*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K/8 threads) and `grid` (N blocks) in the caller's scope
// and launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F16x8F32_KERNEL(N, K) \
  dim3 block((K)/8); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F16x8F32_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F16x8F32_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F16x8F32_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F16x8F32_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F16x8F32_KERNEL(1024) break; \
    case 2048: LANUCH_RMS_NORM_F16x8F32_KERNEL(2048) break; \
    case 4096: LANUCH_RMS_NORM_F16x8F32_KERNEL(4096) break; \
    case 8192: LANUCH_RMS_NORM_F16x8F32_KERNEL(8192) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/128/.../1024*8"); \
      break; \
  }
// Launches rms_norm_f16x8_pack_f16_kernel with NUM_THREADS = K/8 (K must be a
// literal). Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(K) \
  rms_norm_f16x8_pack_f16_kernel<(K)/8><<<grid, block>>>( \
      reinterpret_cast<half*>(x.data_ptr()), \
      reinterpret_cast<half*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K/8 threads) and `grid` (N blocks) in the caller's scope
// and launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F16x8_PACK_F16_KERNEL(N, K) \
  dim3 block((K)/8); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(1024) break; \
    case 2048: LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(2048) break; \
    case 4096: LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(4096) break; \
    case 8192: LANUCH_RMS_NORM_F16x8_PACK_F16_KERNEL(8192) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/128/.../1024*8"); \
      break; \
  }
// Launches rms_norm_f16x8_pack_f32_kernel with NUM_THREADS = K/8 (K must be a
// literal). Uses `grid`, `block`, `x`, `y`, `g` and `N` from the expansion site.
#define LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(K) \
  rms_norm_f16x8_pack_f32_kernel<(K)/8><<<grid, block>>>( \
      reinterpret_cast<half*>(x.data_ptr()), \
      reinterpret_cast<half*>(y.data_ptr()), \
      g, N, (K));

// Declares `block` (K/8 threads) and `grid` (N blocks) in the caller's scope
// and launches the instantiation matching the runtime K; throws
// std::runtime_error for any other K.
#define DISPATCH_RMS_NORM_F16x8_PACK_F32_KERNEL(N, K) \
  dim3 block((K)/8); \
  dim3 grid((N)); \
  switch ((K)) { \
    case 64: LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(64) break; \
    case 128: LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(128) break; \
    case 256: LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(256) break; \
    case 512: LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(512) break; \
    case 1024: LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(1024) break; \
    case 2048: LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(2048) break; \
    case 4096: LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(4096) break; \
    case 8192: LANUCH_RMS_NORM_F16x8_PACK_F32_KERNEL(8192) break; \
    default: \
      throw std::runtime_error( \
          "only support K: 64/128/.../1024*8"); \
      break; \
  }
// Shared argument validation for the fp16 entry points.
// Throws std::runtime_error unless x and y are half, identically shaped, 2D,
// contiguous CUDA tensors (the kernels index them as flat N*K buffers).
static void rms_norm_f16_check_inputs(const torch::Tensor& x, const torch::Tensor& y) {
  CHECK_TORCH_TENSOR_DTYPE(x, torch::kHalf)
  CHECK_TORCH_TENSOR_DTYPE(y, torch::kHalf)
  CHECK_TORCH_TENSOR_SHAPE(x, y)
  if (x.dim() != 2) {
    throw std::runtime_error("x and y must be 2D tensors of shape (N, K)");
  }
  if (!x.is_cuda() || !y.is_cuda()) {
    throw std::runtime_error("x and y must be CUDA tensors");
  }
  if (!x.is_contiguous() || !y.is_contiguous()) {
    throw std::runtime_error("x and y must be contiguous");
  }
}

// Surfaces a failed kernel launch as std::runtime_error instead of leaving the
// error pending for an unrelated later CUDA call.
static void rms_norm_f16_check_launch() {
  const cudaError_t err = cudaGetLastError();
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
}

// All entry points below compute y = rms_norm(x) * g for half CUDA tensors of
// shape (N, K). The supported K values are listed in the matching DISPATCH_*
// macro. Each throws std::runtime_error on invalid arguments or a failed
// launch, and returns without launching when N == 0 (a grid of 0 blocks is an
// invalid launch).

// fp16 accumulation, one element per thread.
void rms_norm_f16_f16(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f16_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return;
  DISPATCH_RMS_NORM_F16F16_KERNEL(N, K)
  rms_norm_f16_check_launch();
}

// fp16 accumulation, two elements (half2) per thread.
void rms_norm_f16x2_f16(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f16_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return;
  DISPATCH_RMS_NORM_F16x2F16_KERNEL(N, K)
  rms_norm_f16_check_launch();
}

// fp16 accumulation, eight elements (4 x half2) per thread.
void rms_norm_f16x8_f16(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f16_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return;
  DISPATCH_RMS_NORM_F16x8F16_KERNEL(N, K)
  rms_norm_f16_check_launch();
}

// Eight elements (4 x half2) per thread, dispatched through
// DISPATCH_RMS_NORM_F16x8F32_KERNEL.
void rms_norm_f16x8_f32(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f16_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return;
  DISPATCH_RMS_NORM_F16x8F32_KERNEL(N, K)
  rms_norm_f16_check_launch();
}

// fp32 accumulation, one element per thread.
void rms_norm_f16_f32(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f16_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return;
  DISPATCH_RMS_NORM_F16F32_KERNEL(N, K)
  rms_norm_f16_check_launch();
}

// pack
// fp16 accumulation, eight elements per thread via one 128-bit load/store.
void rms_norm_f16x8_pack_f16(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f16_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return;
  DISPATCH_RMS_NORM_F16x8_PACK_F16_KERNEL(N, K)
  rms_norm_f16_check_launch();
}

// fp32 accumulation, eight elements per thread via one 128-bit load/store.
void rms_norm_f16x8_pack_f32(torch::Tensor x, torch::Tensor y, float g) {
  rms_norm_f16_check_inputs(x, y);
  const int N = x.size(0);
  const int K = x.size(1);
  if (N == 0) return;
  DISPATCH_RMS_NORM_F16x8_PACK_F32_KERNEL(N, K)
  rms_norm_f16_check_launch();
}
// Python bindings: every entry point is exported under its C++ name, and the
// same name doubles as the docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // fp32
  m.def("rms_norm_f32", &rms_norm_f32, "rms_norm_f32");
  m.def("rms_norm_f32x4", &rms_norm_f32x4, "rms_norm_f32x4");
  // fp16 storage, fp16 accumulation
  m.def("rms_norm_f16_f16", &rms_norm_f16_f16, "rms_norm_f16_f16");
  m.def("rms_norm_f16x2_f16", &rms_norm_f16x2_f16, "rms_norm_f16x2_f16");
  m.def("rms_norm_f16x8_f16", &rms_norm_f16x8_f16, "rms_norm_f16x8_f16");
  m.def("rms_norm_f16x8_pack_f16", &rms_norm_f16x8_pack_f16, "rms_norm_f16x8_pack_f16");
  // fp16 storage, fp32 accumulation
  m.def("rms_norm_f16x8_f32", &rms_norm_f16x8_f32, "rms_norm_f16x8_f32");
  m.def("rms_norm_f16x8_pack_f32", &rms_norm_f16x8_pack_f32, "rms_norm_f16x8_pack_f32");
  m.def("rms_norm_f16_f32", &rms_norm_f16_f32, "rms_norm_f16_f32");
}