-
Notifications
You must be signed in to change notification settings - Fork 2
/
saturated_cast.cu
270 lines (247 loc) · 10.4 KB
/
saturated_cast.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
#include "include/saturated_cast.h"
#include "utils.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <limits>
#include <type_traits>
namespace driss_torch {
using namespace at;
namespace {
// Launch helpers for dispatch_best_kernel. Each macro expands to one kernel
// launch that refers by name to locals of the calling switch case: `grid`,
// `block`, `input`, `output`, `out_dtype`, `scale`, plus
//   DISPATCH_KERNEL_SINGLE:                n_rows, n_cols
//   DISPATCH_KERNEL_DOUBLE_COALESCED:      coarse_factor, n_rows, packed_col_size
//   DISPATCH_KERNEL_DOUBLE_COALESCED_FLAT: coarse_factor, packed_numel
// The packed kernels index their pointers in 2-element pack units, so they
// must be given pack counts (packed_col_size / packed_numel), never element
// counts. Launches are enqueued on PyTorch's current CUDA stream rather than
// the legacy default stream so they are ordered with surrounding ATen work.
#define DISPATCH_KERNEL_SINGLE(T)                                              \
  saturated_cast_kernel_single<T>                                              \
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(                  \
          static_cast<T *>(input.data_ptr()),                                  \
          static_cast<__nv_fp8_storage_t *>(output.data_ptr()), n_rows,        \
          n_cols, out_dtype, static_cast<float *>(scale.data_ptr()))
#define DISPATCH_KERNEL_DOUBLE_COALESCED(T)                                    \
  saturated_cast_kernel_double_coalesced<coarse_factor, T>                     \
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(                  \
          static_cast<T *>(input.data_ptr()),                                  \
          static_cast<__nv_fp8x2_storage_t *>(output.data_ptr()), n_rows,      \
          packed_col_size, out_dtype, static_cast<float *>(scale.data_ptr()))
#define DISPATCH_KERNEL_DOUBLE_COALESCED_FLAT(T)                               \
  saturated_cast_kernel_double_coalesced_flat<coarse_factor, T>                \
      <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(                  \
          static_cast<T *>(input.data_ptr()),                                  \
          static_cast<__nv_fp8x2_storage_t *>(output.data_ptr()),              \
          packed_numel, out_dtype, static_cast<float *>(scale.data_ptr()))
// Scaled, saturating cast to FP8 with one element per thread.
//
// Launch layout: 2D grid/block, x covering columns and y covering rows.
// Threads outside [0, n_rows) x [0, n_cols) exit without doing any work.
// input/output are addressed row-major as row * n_cols + col. scaler points
// to a single float that is read on the device.
// bfloat16 inputs are scaled with bfloat16 arithmetic (__hmul) before the
// conversion; float inputs are scaled in float. Results are clamped to the
// finite FP8 range (__NV_SATFINITE).
template <typename HPType>
__global__ void saturated_cast_kernel_single(
    HPType *input, __nv_fp8_storage_t *output, int n_rows, int n_cols,
    __nv_fp8_interpretation_t out_dtype, float *scaler) {
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= n_rows || col >= n_cols) {
    return;
  }
  // Widen before multiplying: n_rows * n_cols can exceed INT_MAX for large
  // tensors, and a 32-bit product would wrap to a wrong (possibly negative)
  // offset.
  const int64_t global_index = static_cast<int64_t>(row) * n_cols + col;
  if constexpr (std::is_same_v<HPType, nv_bfloat16>) {
    const HPType scaled_input = __hmul(input[global_index], (*scaler));
    output[global_index] = __nv_cvt_bfloat16raw_to_fp8(
        scaled_input, __nv_saturation_t::__NV_SATFINITE, out_dtype);
  } else {
    const HPType scaled_input = input[global_index] * (*scaler);
    output[global_index] = __nv_cvt_float_to_fp8(
        scaled_input, __nv_saturation_t::__NV_SATFINITE, out_dtype);
  }
}
// Flat, packed variant: treats the buffer as `numels` consecutive 2-element
// packs (nv_bfloat162 or float2) and writes one __nv_fp8x2_storage_t per pack
// at the same position in `output`.
//
// Launch layout: 1D grid/block. Thread t owns the `coarse_factor` consecutive
// packs starting at t * coarse_factor; packs at or beyond `numels` are
// skipped. Each thread works in three passes over a register array: load all
// owned packs, scale them, then convert and store them.
// scaler points to a single float that is read on the device. bfloat16 packs
// are scaled with __hmul2; float2 packs are scaled lane by lane.
template <int coarse_factor, typename PackedHPType>
__global__ void saturated_cast_kernel_double_coalesced_flat(
    PackedHPType const *__restrict input,
    __nv_fp8x2_storage_t *__restrict output, const int numels,
    __nv_fp8_interpretation_t out_dtype, float const *scaler) {
  // 64-bit start index: when numels is close to INT_MAX the rounded-up grid
  // yields start offsets past INT_MAX, which would wrap negative in an int
  // and slip through the `< numels` guard below.
  const int64_t first =
      static_cast<int64_t>(blockIdx.x * blockDim.x + threadIdx.x) *
      coarse_factor;
  const float scale = *scaler;
  const PackedHPType scale_2 = {scale, scale};
  PackedHPType scaled_inputs[coarse_factor];
  // Pass 1: load every owned pack before doing any math.
#pragma unroll
  for (int i{0}; i < coarse_factor; ++i) {
    const int64_t pos = first + i;
    if (pos < numels) {
      scaled_inputs[i] = input[pos];
    }
  }
  // Pass 2: scale.
#pragma unroll
  for (int i{0}; i < coarse_factor; ++i) {
    if (first + i < numels) {
      if constexpr (std::is_same_v<PackedHPType, nv_bfloat162>) {
        scaled_inputs[i] = __hmul2(scaled_inputs[i], scale_2);
      } else {
        // No packed float2 multiply is used here; scale each lane separately.
        scaled_inputs[i] = {scaled_inputs[i].x * scale,
                            scaled_inputs[i].y * scale};
      }
    }
  }
  // Pass 3: convert with saturation to the finite FP8 range and store.
#pragma unroll
  for (int i{0}; i < coarse_factor; ++i) {
    const int64_t pos = first + i;
    if (pos < numels) {
      __nv_fp8x2_storage_t out;
      if constexpr (std::is_same_v<PackedHPType, nv_bfloat162>) {
        out = __nv_cvt_bfloat16raw2_to_fp8x2(
            scaled_inputs[i], __nv_saturation_t::__NV_SATFINITE, out_dtype);
      } else {
        out = __nv_cvt_float2_to_fp8x2(
            scaled_inputs[i], __nv_saturation_t::__NV_SATFINITE, out_dtype);
      }
      output[pos] = out;
    }
  }
}
// 2D packed variant: views the buffer as an n_rows x n_cols row-major matrix
// of 2-element packs (nv_bfloat162 or float2). Each pack becomes one
// __nv_fp8x2_storage_t at the same position in `output`.
//
// n_cols is the number of PACKS per row (the logical column count / 2). Both
// pointers are indexed in pack units, so row r starts at pack r * n_cols.
//
// Launch layout: 2D grid/block; y covers rows, and thread x owns the
// `coarse_factor` consecutive packs starting at x * coarse_factor within its
// row. Out-of-range rows and packs are skipped.
// scaler points to a single float that is read on the device.
template <int coarse_factor, typename PackedHPType>
__global__ void saturated_cast_kernel_double_coalesced(
    PackedHPType const *__restrict input,
    __nv_fp8x2_storage_t *__restrict output, int n_rows, int n_cols,
    __nv_fp8_interpretation_t out_dtype, float const *scaler) {
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  const int col = (blockIdx.x * blockDim.x + threadIdx.x) * coarse_factor;
  if (row >= n_rows) {
    return;
  }
  // 64-bit row offset: row * n_cols can exceed INT_MAX for large tensors.
  const int64_t row_start = static_cast<int64_t>(row) * n_cols;
  const float scale = *scaler;
  const PackedHPType scale_2 = {scale, scale};
  PackedHPType scaled_inputs[coarse_factor];
  // Pass 1: load every owned pack before doing any math.
#pragma unroll
  for (int i{0}; i < coarse_factor; ++i) {
    const int temp_col = col + i;
    if (temp_col < n_cols) {
      scaled_inputs[i] = input[row_start + temp_col];
    }
  }
  // Pass 2: scale.
#pragma unroll
  for (int i{0}; i < coarse_factor; ++i) {
    const int temp_col = col + i;
    if (temp_col < n_cols) {
      if constexpr (std::is_same_v<PackedHPType, nv_bfloat162>) {
        scaled_inputs[i] = __hmul2(scaled_inputs[i], scale_2);
      } else {
        // No packed float2 multiply is used here; scale each lane separately.
        scaled_inputs[i] = {scaled_inputs[i].x * scale,
                            scaled_inputs[i].y * scale};
      }
    }
  }
  // Pass 3: convert with saturation to the finite FP8 range and store.
#pragma unroll
  for (int i{0}; i < coarse_factor; ++i) {
    const int temp_col = col + i;
    if (temp_col < n_cols) {
      __nv_fp8x2_storage_t out;
      if constexpr (std::is_same_v<PackedHPType, nv_bfloat162>) {
        out = __nv_cvt_bfloat16raw2_to_fp8x2(
            scaled_inputs[i], __nv_saturation_t::__NV_SATFINITE, out_dtype);
      } else {
        out = __nv_cvt_float2_to_fp8x2(
            scaled_inputs[i], __nv_saturation_t::__NV_SATFINITE, out_dtype);
      }
      output[row_start + temp_col] = out;
    }
  }
}
// Translates an ATen FP8 dtype into the __nv_fp8_interpretation_t consumed by
// the cuda_fp8.h conversion intrinsics:
//   kFloat8_e4m3fn -> __NV_E4M3
//   kFloat8_e5m2   -> __NV_E5M2
// Any other dtype fails a TORCH_CHECK.
__nv_fp8_interpretation_t dtype_map(const ScalarType dtype) {
  if (dtype == at::kFloat8_e4m3fn) {
    return __nv_fp8_interpretation_t::__NV_E4M3;
  }
  if (dtype == at::kFloat8_e5m2) {
    return __nv_fp8_interpretation_t::__NV_E5M2;
  }
  TORCH_CHECK(false, "Invalid dtype");
}
// Which cast kernel dispatch_best_kernel launches.
enum KernelChoice { single, coalesced, coalesced_flat };

// Picks one of the cast kernels above and launches it to write the scaled,
// saturated FP8 encoding of `input` into `output`.
//
// Expects a 2D bfloat16/float CUDA `input` and an `output` allocated with the
// same layout (e.g. via empty_like): every kernel writes each result at the
// same buffer offset it read from. `scale` must hold one float on the device.
//
// Kernel selection:
//   coalesced_flat - numel even and !transpose: the buffer is processed as
//                    numel / 2 consecutive 2-element packs.
//   coalesced      - otherwise, if n_cols is even: 2D launch over packs with
//                    n_cols / 2 packs per row.
//   single         - fallback: 2D launch, one element per thread.
// NOTE(review): `transpose` only influences which kernel is chosen; none of
// the kernels reorders data.
void dispatch_best_kernel(const Tensor &input, const Tensor &output,
                          __nv_fp8_interpretation_t out_dtype,
                          const Tensor &scale, bool transpose) {
  const auto numel = input.numel();
  // Empty tensors need no work. A zero-sized grid is an invalid launch
  // configuration, which C10_CUDA_KERNEL_LAUNCH_CHECK would raise as an error.
  if (numel == 0) {
    return;
  }
  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  TORCH_CHECK(input.size(0) <= kIntMax && input.size(1) <= kIntMax,
              "saturated_cast: input dimensions must fit in int32, but got ",
              input.sizes());
  const auto hp_dtype = input.scalar_type();
  TORCH_CHECK(hp_dtype == at::kBFloat16 || hp_dtype == at::kFloat,
              "saturated_cast: unsupported input dtype ", hp_dtype);

  // Make input's device current so the launch (and the current-stream
  // lookup) targets the GPU that owns the data.
  const c10::cuda::CUDAGuard device_guard(input.device());

  const int n_rows = static_cast<int>(input.size(0));
  const int n_cols = static_cast<int>(input.size(1));
  const int block_size_x = 32;
  const int block_size_y = 32;

  KernelChoice kernel_choice = KernelChoice::single;
  if (numel % 2 == 0 && !transpose) {
    kernel_choice = KernelChoice::coalesced_flat;
  } else if (n_cols % 2 == 0) {
    kernel_choice = KernelChoice::coalesced;
  }

  switch (kernel_choice) {
  case KernelChoice::single: {
    const dim3 block(block_size_x, block_size_y);
    const dim3 grid(ceil_div(n_cols, block_size_x),
                    ceil_div(n_rows, block_size_y));
    if (hp_dtype == at::kBFloat16) {
      DISPATCH_KERNEL_SINGLE(nv_bfloat16);
    } else {
      DISPATCH_KERNEL_SINGLE(float);
    }
    break;
  }
  case KernelChoice::coalesced: {
    // The kernel works on 2-element packs (nv_bfloat162 or float2), so each
    // row holds n_cols / 2 packs.
    const auto packed_col_size = n_cols / 2;
    // Found 4 to be the best factor for the coalesced kernel
    constexpr int coarse_factor = 4;
    const dim3 block(block_size_x, block_size_y);
    const dim3 grid(ceil_div(packed_col_size, block_size_x * coarse_factor),
                    ceil_div(n_rows, block_size_y));
    if (hp_dtype == at::kBFloat16) {
      DISPATCH_KERNEL_DOUBLE_COALESCED(nv_bfloat162);
    } else {
      DISPATCH_KERNEL_DOUBLE_COALESCED(float2);
    }
    break;
  }
  case KernelChoice::coalesced_flat: {
    constexpr int coarse_factor = 4;
    const dim3 block(256);
    // The buffer is processed as numel / 2 two-element packs; the kernel
    // takes that count as an int.
    const int64_t packed_numel_64 = numel / 2;
    TORCH_CHECK(packed_numel_64 <= kIntMax,
                "saturated_cast: input has too many elements (", numel,
                ") for the flat kernel");
    const int packed_numel = static_cast<int>(packed_numel_64);
    const dim3 grid(ceil_div(packed_numel, block.x * coarse_factor));
    if (hp_dtype == at::kBFloat16) {
      DISPATCH_KERNEL_DOUBLE_COALESCED_FLAT(nv_bfloat162);
    } else {
      DISPATCH_KERNEL_DOUBLE_COALESCED_FLAT(float2);
    }
    break;
  }
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace
// Shape/dtype-only counterpart of saturated_cast. It runs the same argument
// validation and returns an uninitialized tensor shaped like `input` with the
// requested FP8 dtype, without launching any kernel.
// NOTE(review): presumably registered as the Meta / fake-tensor
// implementation; confirm at the op registration site.
// `transpose` is accepted for signature parity and is unused here.
Tensor saturated_cast_meta(const Tensor &input, const Tensor &scale,
                           ScalarType dtype, bool transpose) {
  TORCH_CHECK(dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2,
              "Output tensor must be of type Float8_e4m3fn or Float8_e5m2");
  TORCH_CHECK(input.scalar_type() == at::kBFloat16 ||
                  input.scalar_type() == at::kFloat,
              "Input tensor must be of type BFloat16 or Float, but got ",
              input.dtype());
  TORCH_CHECK(scale.scalar_type() == at::kFloat,
              "Scale tensor must be of type Float, but got ", scale.dtype());
  // Mirror the shape checks of saturated_cast so tracing rejects the same
  // inputs the real kernel would.
  TORCH_CHECK(input.dim() == 2, "Input tensor must be 2D, but got ",
              input.dim());
  TORCH_CHECK(scale.numel() == 1, "Scale tensor must be a scalar, but got ",
              scale.numel());
  auto output = torch::empty_like(input, input.options().dtype(dtype));
  return output;
}
// Casts a 2D bfloat16/float CUDA tensor to the FP8 `dtype`. Every element is
// multiplied by the single float in `scale` and saturated to the finite FP8
// range before encoding.
//
// Row-major (contiguous) or dense column-major inputs with zero storage
// offset are used as-is; any other layout is first copied via
// input.contiguous(). The output comes from empty_like on the (possibly
// copied) input, so it shares that tensor's shape and layout.
// `transpose` is forwarded to dispatch_best_kernel.
Tensor saturated_cast(const Tensor &input, const Tensor &scale,
                      ScalarType dtype, bool transpose) {
  TORCH_CHECK(dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2,
              "Output tensor must be of type Float8_e4m3fn or Float8_e5m2");
  TORCH_CHECK(input.scalar_type() == at::kBFloat16 ||
                  input.scalar_type() == at::kFloat,
              "Input tensor must be of type BFloat16 or Float, but got ",
              input.dtype());
  TORCH_CHECK(scale.scalar_type() == at::kFloat,
              "Scale tensor must be of type Float, but got ", scale.dtype());
  TORCH_CHECK(input.dim() == 2, "Input tensor must be 2D, but got ",
              input.dim());
  TORCH_CHECK(scale.numel() == 1, "Scale tensor must be a scalar, but got ",
              scale.numel());
  // The kernels dereference both buffers (including scale.data_ptr()) on the
  // GPU. A CPU tensor, or a scale on another device, would be an illegal
  // memory access that poisons the CUDA context.
  TORCH_CHECK(input.is_cuda(), "Input tensor must be a CUDA tensor, but got ",
              input.device());
  TORCH_CHECK(scale.device() == input.device(),
              "Scale tensor must be on the same device as input (",
              input.device(), "), but got ", scale.device());

  // Input must either be transposed (dense column-major) or contiguous.
  const auto strides = input.strides();
  const bool is_contiguous = input.is_contiguous();
  const bool is_transposed = strides[0] == 1 && strides[1] == input.size(0);
  // NOTE(review): for an already-contiguous input with a non-zero storage
  // offset, contiguous() returns the same tensor; that is fine because
  // data_ptr() already accounts for the offset.
  const bool check_allowed_strides =
      (is_contiguous || is_transposed) && input.storage_offset() == 0;
  auto contig_input = check_allowed_strides ? input : input.contiguous();
  auto output =
      torch::empty_like(contig_input, contig_input.options().dtype(dtype));
  dispatch_best_kernel(contig_input, output, dtype_map(dtype), scale,
                       transpose);
  return output;
}
} // namespace driss_torch