forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LinearAlgebra.cu
497 lines (430 loc) · 15.9 KB
/
LinearAlgebra.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
#include <ATen/ATen.h>
#include <ATen/LegacyTHFunctionsCUDA.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
namespace at { namespace native {
// Returns a 2-D matrix whose memory layout cuBLAS (column-major) can consume,
// and reports through `transpose_tensor` how cuBLAS should interpret it.
//
//  * strides[0] == 1 with a large-enough column stride: already column-major,
//    returned as-is with transpose_tensor = false.
//  * strides[1] == 1 with a large-enough row stride: row-major, i.e. the
//    column-major view of the transpose; returned as-is with
//    transpose_tensor = true.
//  * anything else: a contiguous (row-major) clone is returned and
//    transpose_tensor = true.
Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) {
  const IntArrayRef strides = tensor.strides();
  const IntArrayRef sizes = tensor.sizes();

  const bool column_major =
      strides[0] == 1 && strides[1] >= std::max<int64_t>(1, sizes[0]);
  if (column_major) {
    transpose_tensor = false;
    return tensor;
  }

  // Both remaining cases present the data to cuBLAS as a transposed matrix.
  transpose_tensor = true;
  const bool row_major =
      strides[1] == 1 && strides[0] >= std::max<int64_t>(1, sizes[1]);
  if (row_major) {
    return tensor;
  }
  return tensor.clone(at::MemoryFormat::Contiguous);
}
// Batched analogue of prepare_matrix_for_cublas for a 3-D tensor whose dim 0
// is the batch. Returns a tensor cuBLAS can read directly, and outputs:
//   transpose_tensor - whether cuBLAS should treat each matrix as transposed;
//   ld_tensor        - the leading dimension (stride) to pass to cuBLAS.
// `transpose_result` selects which of dims 1/2 is the "fast" (unit-stride)
// dimension expected for a non-transposed operand. `m` and `n` are the
// operand's dimensions in cuBLAS terms (baddbmm passes (m, k) for the first
// operand and (k, n) for the second) and bound the acceptable leading stride.
// Operands with no usable layout are made contiguous (cloned if needed).
Tensor prepare_batch_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, int64_t& ld_tensor, bool transpose_result, int64_t m, int64_t n) {
  const IntArrayRef strides = tensor.strides();
  const int fast_dim = transpose_result ? 2 : 1;
  const int leading_dim = transpose_result ? 1 : 2;

  // Layout already matches: unit stride along fast_dim, ld >= rows.
  if (strides[fast_dim] == 1 &&
      strides[leading_dim] >= std::max<int64_t>(1, m)) {
    transpose_tensor = false;
    ld_tensor = strides[leading_dim];
    return tensor;
  }

  // Layout matches the transposed view: unit stride along leading_dim.
  if (strides[leading_dim] == 1 &&
      strides[fast_dim] >= std::max<int64_t>(1, n)) {
    transpose_tensor = true;
    ld_tensor = strides[fast_dim];
    return tensor;
  }

  // Fallback: use a contiguous copy (no copy if already contiguous).
  transpose_tensor = !transpose_result;
  Tensor packed = tensor.is_contiguous()
      ? tensor
      : tensor.clone(at::MemoryFormat::Contiguous);
  ld_tensor = packed.stride(1);
  return packed;
}
namespace {
// Computes result = beta * self + alpha * (mat1 @ mat2) for 2-D CUDA tensors
// using cuBLAS gemm (at::cuda::blas::gemm).
//
// When `result` and `self` are different objects, `self` is broadcast to
// [mat1.size(0), mat2.size(1)] and, if beta != 0, copied into `result`, which
// then serves as the gemm accumulator. When they are the same object (in-place
// call), `self` is used as-is and must already be 2-D.
//
// cuBLAS is column-major. If `result` is laid out row-major
// (transpose_result), the product is computed as result^T = mat2^T @ mat1^T by
// swapping the operands. Operands whose strides cuBLAS cannot use directly are
// cloned to contiguous memory. If `result` itself had to be cloned, the answer
// is copied back at the end.
//
// Returns `result`. Fails a TORCH_CHECK on non-2-D inputs, mismatched shapes,
// or tensors on different GPUs.
Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
  TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D");

  TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {mat1, "mat1", 2}, {mat2, "mat2", 3}};
  checkAllSameGPU("addmm", args);

  Tensor self_;
  if (&result != &self) {
    std::tie(self_) = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm");
  } else {
    self_ = self;
  }
  // In the in-place case self_ is not expanded. Its rank must be validated
  // before self__sizes[1] is read below; otherwise a lower-rank self would be
  // indexed past the end of its sizes.
  TORCH_CHECK(self_.dim() == 2, "tensors must be 2-D");

  IntArrayRef mat1_sizes = mat1.sizes();
  IntArrayRef mat2_sizes = mat2.sizes();
  IntArrayRef self__sizes = self_.sizes();
  TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0");
  TORCH_CHECK(self__sizes[0] == mat1_sizes[0], "self_ dim 0 must match mat1 dim 0");
  TORCH_CHECK(self__sizes[1] == mat2_sizes[1], "self_ dim 1 must match mat2 dim 1");

  if (&result != &self) {
    at::native::resize_as_(result, self_);
    // With beta == 0, self must not contribute (NaN/Inf must not propagate),
    // so it is only copied into the accumulator when beta != 0.
    if (beta.toComplexDouble() != 0.0) {
      at::native::copy_(result, self_);
    }
  }

  TORCH_CHECK(result.dim() == 2 && self_.dim() == 2, "tensors must be 2-D");

  IntArrayRef result_sizes = result.sizes();
  if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
    return result;
  }

  // Inner dimension k == 0: the matrix product is all zeros, so the answer is
  // beta * self_. Handle this before any cuBLAS preparation (which may clone).
  if (mat1.numel() == 0) {
    // By definition, when beta==0, values in self should be ignored. nans and infs
    // should not propagate
    if (beta.toComplexDouble() == 0.) {
      return result.zero_();
    }
    // Use the broadcast self_ (not the raw self) so that result keeps its
    // [mat1.size(0), mat2.size(1)] shape when self was expanded.
    return at::native::mul_out(result, self_, at::native::scalar_tensor(beta, at::device(at::kCPU).dtype(self_.scalar_type())));
  }

  bool transpose_result;
  Tensor result_ = prepare_matrix_for_cublas(result, transpose_result);
  bool transpose_mat1;
  bool transpose_mat2;
  // For a row-major result, compute result^T = mat2^T @ mat1^T instead.
  Tensor mat1_ = transpose_result ? mat2 : mat1;
  Tensor mat2_ = transpose_result ? mat1 : mat2;
  mat1_ = prepare_matrix_for_cublas(mat1_, transpose_mat1);
  mat2_ = prepare_matrix_for_cublas(mat2_, transpose_mat2);

  if (transpose_result) {
    transpose_mat1 = !transpose_mat1;
    transpose_mat2 = !transpose_mat2;
    mat1_sizes = mat1_.sizes();
    mat2_sizes = mat2_.sizes();
  }

  // gemm dimensions in cuBLAS terms: (m x k) @ (k x n).
  int64_t m = mat1_sizes[transpose_result ? 1 : 0];
  int64_t k = mat1_sizes[transpose_result ? 0 : 1];
  int64_t n = mat2_sizes[transpose_result ? 0 : 1];
  int64_t mat1_ld = mat1_.stride((transpose_mat1 == transpose_result) ? 1 : 0);
  int64_t mat2_ld = mat2_.stride((transpose_mat2 == transpose_result) ? 1 : 0);
  int64_t result_ld = result_.stride(transpose_result ? 0 : 1);
  at::ScalarType scalar_type = self_.scalar_type();

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
    scalar_t alpha_val = alpha.to<scalar_t>();
    scalar_t beta_val = beta.to<scalar_t>();
    scalar_t* mat1_ptr = mat1_.data_ptr<scalar_t>();
    scalar_t* mat2_ptr = mat2_.data_ptr<scalar_t>();
    scalar_t* result_ptr = result_.data_ptr<scalar_t>();
    at::cuda::blas::gemm<scalar_t>(
      transpose_mat1 ? 't' : 'n',
      transpose_mat2 ? 't' : 'n',
      m, n, k,
      alpha_val,
      mat1_ptr, mat1_ld,
      mat2_ptr, mat2_ld,
      beta_val,
      result_ptr, result_ld
    );
  });

  // result_ is a temporary clone when result's layout was unusable by cuBLAS.
  if (result.data_ptr() != result_.data_ptr()) {
    result.copy_(result_);
  }
  return result;
}
// Batched counterpart of addmm. For each batch index b it computes
//   result[b] = beta * self[b] + alpha * (batch1[b] @ batch2[b])
// with a single strided-batched call to at::cuda::blas::bgemm.
// self, batch1 and batch2 must all be 3-D with matching batch/row/column
// sizes. `self` is NOT broadcast here, so it must already have the output
// shape (callers such as baddbmm_out_cuda expand it first).
// Returns `result`.
Tensor& baddbmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
TORCH_CHECK(self.dim() == 3, "self must be a 3D tensor");
TORCH_CHECK(batch1.dim() == 3, "batch1 must be a 3D tensor");
TORCH_CHECK(batch2.dim() == 3, "batch2 must be a 3D tensor");
TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {batch1, "batch1", 2}, {batch2, "batch2", 3}};
checkAllSameGPU("baddbmm", args);
IntArrayRef batch1_sizes = batch1.sizes();
IntArrayRef batch2_sizes = batch2.sizes();
IntArrayRef self_sizes = self.sizes();
TORCH_CHECK(self_sizes[0] == batch1_sizes[0], "self dim 0 must match batch1 dim 0");
TORCH_CHECK(self_sizes[0] == batch2_sizes[0], "self dim 0 must match batch2 dim 0");
TORCH_CHECK(self_sizes[1] == batch1_sizes[1], "self dim 1 must match batch1 dim 1");
TORCH_CHECK(self_sizes[2] == batch2_sizes[2], "self dim 2 must match batch2 dim 2");
TORCH_CHECK(batch1_sizes[2] == batch2_sizes[1], "batch1 dim 2 must match batch2 dim 1");
// Out-of-place call: result takes self's shape and, unless beta == 0, self's
// values. It then acts as the bgemm accumulator (the "C" operand).
if (!result.is_same(self)) {
result.resize_as_(self);
if (beta.to<c10::complex<double>>() != 0.0) {
result.copy_(self);
}
}
// handle pathological cases that blas may not like
// - empty output: nothing to compute.
// - inner dimension k == 0: the product contributes nothing, so the answer
//   is beta * self (already in result), or zeros when beta == 0.
if (result.numel() == 0) {
return result;
} else if (batch1_sizes[2] == 0) {
if (beta.to<c10::complex<double>>() == 0.0) {
return result.zero_();
} else {
return result.mul_(beta);
}
}
// Pick an output buffer whose per-batch matrices cuBLAS (column-major) can
// write directly:
//  - stride[1] == 1: each matrix is column-major, so use result as-is.
//  - stride[2] == 1: each matrix is row-major, so compute the transposed
//    product instead (result^T = batch2^T @ batch1^T; transpose_result).
//  - otherwise: clone into a column-major buffer (transpose, contiguous
//    clone, transpose back). It is copied into result at the end.
// A dimension of size 1 makes the other dimension's stride irrelevant,
// which is why the size-1 alternatives are accepted.
bool transpose_result = false;
Tensor result_;
IntArrayRef result_strides = result.strides();
IntArrayRef result_sizes = result.sizes();
if ((result_strides[1] == 1) &&
((result_sizes[2] == 1) || (result_strides[2] >= std::max<int64_t>(1, result_sizes[1])))) {
result_ = result;
} else if ((result_strides[2] == 1) &&
(result_sizes[1] == 1 || (result_strides[1] >= std::max<int64_t>(1, result_sizes[2])))) {
transpose_result = true;
result_ = result;
} else {
result_ = result.transpose(1, 2).clone(at::MemoryFormat::Contiguous);
result_ = result_.transpose(1, 2);
}
// In cuBLAS terms each output matrix is m x n with inner dimension k. With
// transpose_result the roles of batch1 and batch2 are swapped.
int leading_dim = transpose_result ? 1 : 2;
Tensor batch1_ = transpose_result ? batch2 : batch1;
Tensor batch2_ = transpose_result ? batch1 : batch2;
int64_t m = result_sizes[transpose_result ? 2 : 1];
int64_t n = result_sizes[leading_dim];
int64_t k = batch1_.size(leading_dim);
int64_t lda, ldb, ldc;
bool transpose_batch1, transpose_batch2;
// Make each operand cuBLAS-readable and obtain its transpose flag and
// leading dimension (may clone non-conforming operands).
batch1_ = prepare_batch_matrix_for_cublas(batch1_, transpose_batch1, lda, transpose_result, m, k);
batch2_ = prepare_batch_matrix_for_cublas(batch2_, transpose_batch2, ldb, transpose_result, k, n);
ldc = result_.stride(leading_dim);
int64_t num_batches = result_.size(0);
// stride(0) of each tensor is passed as the per-batch offset between matrices.
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] {
scalar_t alpha_val = alpha.to<scalar_t>();
scalar_t beta_val = beta.to<scalar_t>();
scalar_t* batch1_ptr = batch1_.data_ptr<scalar_t>();
scalar_t* batch2_ptr = batch2_.data_ptr<scalar_t>();
scalar_t* result_ptr = result_.data_ptr<scalar_t>();
at::cuda::blas::bgemm<scalar_t>(
transpose_batch1 ? 't' : 'n',
transpose_batch2 ? 't' : 'n',
m, n, k,
alpha_val,
batch1_ptr, lda, batch1_.stride(0),
batch2_ptr, ldb, batch2_.stride(0),
beta_val,
result_ptr, ldc, result_.stride(0),
num_batches
);
});
// If a temporary output buffer was used, write the answer back into result.
if (!result.is_same(result_)) {
result.copy_(result_);
}
return result;
}
} // anonymous namespace
// out= variant of matrix multiply: result = self @ mat2.
// `result` is resized to [self.size(0), mat2.size(1)] and passed to addmm as
// both the output and the additive term. That term is ignored because beta == 0.
Tensor& mm_out_cuda(Tensor& result, const Tensor& self, const Tensor& mat2) {
  const int64_t rows = self.size(0);
  const int64_t cols = mat2.size(1);
  result.resize_({rows, cols});
  return addmm_out_cuda_impl(result, result, self, mat2, /*beta=*/0, /*alpha=*/1);
}
// Functional matrix multiply: returns a new tensor equal to self @ mat2.
// The freshly allocated output doubles as addmm's (ignored, beta == 0)
// additive term, so its uninitialized contents are never read.
Tensor mm_cuda(const Tensor& self, const Tensor& mat2) {
  Tensor out = at::empty({self.size(0), mat2.size(1)}, self.options());
  return addmm_out_cuda_impl(out, out, self, mat2, /*beta=*/0, /*alpha=*/1);
}
// out= variant of addmm: out = beta * self + alpha * (mat1 @ mat2).
// The computation runs with named-tensor inference disabled. Names are then
// propagated onto `out` from mat1, mat2 and self.
// Returns `out`.
Tensor& addmm_out_cuda(Tensor &out, const Tensor &self,
                       const Tensor &mat1, const Tensor &mat2,
                       Scalar beta, Scalar alpha) {
  {
    at::NoNamesGuard guard;
    // The impl writes into and returns `out`; the returned reference is not
    // needed here, so it is not bound to an (unused) local.
    addmm_out_cuda_impl(out, self, mat1, mat2, beta, alpha);
  }
  at::namedinference::propagate_names_for_addmm(out, mat1, mat2, self);
  return out;
}
// Functional addmm: returns a new tensor equal to
// beta * self + alpha * (mat1 @ mat2).
// The output starts as an empty tensor with self's options and is resized by
// the out= variant.
Tensor addmm_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2,
                  Scalar beta, Scalar alpha) {
  Tensor result = at::empty({0}, self.options());
  return addmm_out_cuda(result, self, mat1, mat2, beta, alpha);
}
// In-place addmm: self = beta * self + alpha * (mat1 @ mat2).
// Passing `self` as both output and input makes addmm_out_cuda take its
// in-place path (no broadcasting of self). addmm_out_cuda returns its `out`
// argument, which is `self` here.
Tensor& addmm__cuda(Tensor& self, const Tensor& mat1, const Tensor& mat2,
                    Scalar beta, Scalar alpha) {
  return addmm_out_cuda(self, self, mat1, mat2, beta, alpha);
}
// out= variant of baddbmm: result = beta * self + alpha * (batch1 @ batch2).
// An out-of-place call broadcasts `self` to
// [batch1.size(0), batch1.size(1), batch2.size(2)]. An in-place call (result
// and self are the same object) uses `self` directly. The kernel runs with
// named-tensor inference disabled; output names are computed from the original
// inputs and propagated afterwards.
Tensor& baddbmm_out_cuda(Tensor &result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  const bool in_place = (&result == &self);

  Tensor self_;
  if (in_place) {
    self_ = self;
  } else {
    std::tie(self_) = expand_size(self, {batch1.size(0), batch1.size(1), batch2.size(2)}, "baddbmm");
  }

  {
    at::NoNamesGuard guard;
    baddbmm_out_cuda_impl(result, self_, batch1, batch2, beta, alpha);
  }

  namedinference::propagate_names_if_nonempty(
      result,
      namedinference::compute_baddbmm_outnames(result, batch1, batch2, self));
  return result;
}
// Functional baddbmm: returns a new tensor equal to
// beta * self + alpha * (batch1 @ batch2).
// The output starts empty (with self's options) and is resized by the
// out= variant.
Tensor baddbmm_cuda(const Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  Tensor result = at::empty({0}, self.options());
  return baddbmm_out_cuda(result, self, batch1, batch2, beta, alpha);
}
// In-place baddbmm: self = beta * self + alpha * (batch1 @ batch2).
// Passing `self` as both output and input selects baddbmm_out_cuda's
// in-place path (no broadcasting); the returned reference is `self`.
Tensor& baddbmm__cuda(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  return baddbmm_out_cuda(
      /*result=*/self, /*self=*/self, batch1, batch2, beta, alpha);
}
// out= variant of batched matrix multiply: result = batch1 @ batch2.
// `result` is resized to [batch1.size(0), batch1.size(1), batch2.size(2)] and
// reused as baddbmm's additive term with beta == 0, so its prior contents do
// not contribute. Names are propagated after the unnamed computation.
Tensor& bmm_out_cuda(Tensor &result, const Tensor& batch1, const Tensor& batch2) {
  result.resize_({batch1.size(0), batch1.size(1), batch2.size(2)});
  {
    NoNamesGuard guard;
    baddbmm_out_cuda_impl(result, result, batch1, batch2,
                          /*beta=*/Scalar(0.0), /*alpha=*/Scalar(1.0));
  }
  namedinference::propagate_names_if_nonempty(
      result,
      namedinference::compute_bmm_outnames(result, batch1, batch2));
  return result;
}
// Functional batched matrix multiply: returns a new tensor equal to self @ mat2.
// The output starts empty (with self's options) and is resized by bmm_out_cuda.
Tensor bmm_cuda(const Tensor& self, const Tensor& mat2) {
  Tensor out = at::empty({0}, self.options());
  return native::bmm_out_cuda(out, self, mat2);
}
namespace {

// Validates the arguments of dot/vdot before they are handed to cuBLAS.
// Both tensors must be 1-D, share a dtype, have the same number of elements,
// and live on the same device. Their element count and strides must also fit
// in an `int`, because dot_cuda/vdot_cuda narrow these values with
// static_cast<int> before calling cuBLAS.
inline void dot_check(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(
      self.dim() == 1 && other.dim() == 1,
      "1D tensors expected, but got ",
      self.dim(),
      "D and ",
      other.dim(),
      "D tensors");

  TORCH_CHECK(
      self.scalar_type() == other.scalar_type(),
      "dot : expected both vectors to have same dtype, but found ",
      self.scalar_type(),
      " and ",
      other.scalar_type());

  TORCH_CHECK(
      self.numel() == other.numel(),
      "inconsistent tensor size, expected tensor [",
      self.numel(),
      "] and src [",
      other.numel(),
      "] to have the same number of elements, but got ",
      self.numel(),
      " and ",
      other.numel(),
      " elements respectively");

  TORCH_CHECK(
      self.device() == other.device(),
      "expected all tensors to be on the same device. Found: ",
      self.device(),
      ", ",
      other.device());

  // TORCH_CHECK concatenates its message arguments; it is not printf-style.
  // The bound is therefore appended directly instead of through a "%d"
  // specifier, which previously produced "... <= %d2147483647".
  TORCH_CHECK(
      (self.numel() <= INT_MAX) && (self.stride(0) <= INT_MAX) &&
          (other.stride(0) <= INT_MAX),
      "dot only supports n, incx, incy with the bound [val] <= ",
      INT_MAX);
}

} // anonymous namespace
// Dot product of two 1-D CUDA tensors, returned as a 0-dim tensor on the same
// device. cuBLAS writes the scalar straight into the result's device memory
// (CUBLAS_POINTER_MODE_DEVICE), so no host round-trip is needed here.
Tensor dot_cuda(const Tensor& self, const Tensor& other) {
  at::NoNamesGuard guard;
  dot_check(self, other);

  const int n = static_cast<int>(self.numel());
  // For a single element the stride is irrelevant, so a unit increment is
  // used instead — presumably to avoid handing cuBLAS a degenerate stride
  // (e.g. 0). TODO confirm.
  const int incx = (n == 1) ? 1 : static_cast<int>(self.stride(0));
  const int incy = (n == 1) ? 1 : static_cast<int>(other.stride(0));

  return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(ScalarType::Half, self.scalar_type(), "dot", [&] {
    Tensor out = at::empty({}, self.options());
    auto blas_handle = at::cuda::getCurrentCUDABlasHandle();
    at::cuda::blas::PointerModeGuard device_pointer_mode(blas_handle, CUBLAS_POINTER_MODE_DEVICE);
    at::cuda::blas::dot<scalar_t>(
        blas_handle, n,
        self.data_ptr<scalar_t>(), incx,
        other.data_ptr<scalar_t>(), incy,
        out.data_ptr<scalar_t>());
    return out;
  });
}
// vdot for 1-D CUDA tensors, returned as a 0-dim tensor. Real dtypes are
// forwarded to dot_cuda. Complex dtypes call at::cuda::blas::vdot, which
// presumably conjugates the first argument (cuBLAS dotc semantics). TODO confirm.
Tensor vdot_cuda(const Tensor& self, const Tensor& other) {
  // Conjugation is a no-op for real numbers, so reuse the plain dot path.
  if (!self.is_complex()) {
    return dot_cuda(self, other);
  }

  at::NoNamesGuard guard;
  dot_check(self, other);

  const int n = static_cast<int>(self.numel());
  // Single-element vectors use unit increments regardless of their stride.
  const int incx = (n == 1) ? 1 : static_cast<int>(self.stride(0));
  const int incy = (n == 1) ? 1 : static_cast<int>(other.stride(0));

  return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
    Tensor out = at::empty({}, self.options());
    auto blas_handle = at::cuda::getCurrentCUDABlasHandle();
    at::cuda::blas::PointerModeGuard device_pointer_mode(
        blas_handle, CUBLAS_POINTER_MODE_DEVICE);
    at::cuda::blas::vdot<scalar_t>(
        blas_handle, n,
        self.data_ptr<scalar_t>(), incx,
        other.data_ptr<scalar_t>(), incy,
        out.data_ptr<scalar_t>());
    return out;
  });
}
namespace {

// Elementwise CUDA kernel behind addr_stub:
//   out = beta * self + alpha * vec1 * vec2
// evaluated by gpu_kernel over `iter`. Each lambda receives
// (self_val, vec1_val, vec2_val); the operand setup happens in the caller.
// When beta is zero (false for Bool), self is not read into the result, so
// NaN/Inf in self do not propagate. Bool tensors use logical and/or in place
// of multiply/add.
void addr_kernel_cuda(TensorIterator &iter, Scalar beta, Scalar alpha) {
  if (iter.dtype() == ScalarType::Bool) {
    using scalar_t = bool;
    auto beta_flag = beta.to<scalar_t>();
    auto alpha_flag = alpha.to<scalar_t>();

    if (!beta_flag) {
      // beta is false: self is ignored entirely.
      gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t self_val, scalar_t lhs, scalar_t rhs) -> scalar_t {
        return alpha_flag && lhs && rhs;
      });
    } else {
      gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t self_val, scalar_t lhs, scalar_t rhs) -> scalar_t {
        return (beta_flag && self_val) || (alpha_flag && lhs && rhs);
      });
    }
    return;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "addr_cuda", [&] {
    auto beta_coef = beta.to<scalar_t>();
    auto alpha_coef = alpha.to<scalar_t>();
    scalar_t zero_val(0);
    const bool ignore_self = (beta_coef == zero_val);

    if (ignore_self) {
      // beta == 0: self must not contribute, so NaN/Inf in it are dropped.
      gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t self_val, scalar_t lhs, scalar_t rhs) -> scalar_t {
        return alpha_coef * lhs * rhs;
      });
    } else {
      gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t self_val, scalar_t lhs, scalar_t rhs) -> scalar_t {
        return beta_coef * self_val + alpha_coef * lhs * rhs;
      });
    }
  });
}

} // anonymous namespace
// Register addr_kernel_cuda as this file's implementation for addr_stub
// (presumably the CUDA entry, since this translation unit is CUDA code).
REGISTER_DISPATCH(addr_stub, &addr_kernel_cuda);
}}