#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/CUDASolver.h>
#include <ATen/cuda/CUDABlas.h>
#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/cuda/MiscUtils.h>
#include <ATen/native/cuda/BatchLinearAlgebraLib.h>

#ifdef USE_CUSOLVER

namespace at {
namespace native {
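
// Builds a batch of identity matrices with the same batch shape, dtype, and
// device as `self`: diag_embed() over a tensor of ones produces the identity
// matrices, and the final transpose(-2, -1) flips the strides so each matrix
// is column-major, which is the layout cuSOLVER/cuBLAS expect.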
inline static Tensor column_major_identity_matrix_like(const Tensor& self) {
  auto size = self.sizes();
  auto size_slice = IntArrayRef(size.data(), size.size() - 1);
  return at::ones(size_slice, self.options()).diag_embed().transpose(-2, -1);
}
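
// Inverts a single n-by-n column-major matrix via cuSOLVER: getrf computes the
// LU factorization with partial pivoting in place, then getrs solves A * X = I
// into self_inv_ptr (which must already hold the identity).  info_ptr[0]
// receives the getrf status and info_ptr[1] the getrs status.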
template <typename scalar_t>
inline static void _apply_single_inverse_helper(scalar_t* self_ptr, scalar_t* self_inv_ptr, int* ipiv_ptr, int* info_ptr, int n) {
  // self_inv_ptr should already be an identity matrix
  auto handle = at::cuda::getCurrentCUDASolverDnHandle();
  at::cuda::solver::getrf<scalar_t>(handle, n, n, self_ptr, n, ipiv_ptr, info_ptr);
  at::cuda::solver::getrs<scalar_t>(handle, n, n, self_ptr, n, ipiv_ptr, self_inv_ptr, n, info_ptr + 1);
}
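
// Inverts a batch of matrices.  When use_loop_launch() prefers the loop path,
// each matrix is inverted with its own cuSOLVER getrf/getrs pair on a stream
// taken from the pool, synchronized with the main stream via CUDA events.
// Otherwise the cuBLAS batched kernels getrfBatched/getriBatched are used,
// which take device arrays of device pointers; those arrays are built with
// at::arange over the matrix base addresses.  `infos` receives two status
// integers per matrix (factorization and solve/inversion).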
template <typename scalar_t>
static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& infos) {
  const int batch_size = cuda_int_cast(batchCount(self), "batchCount");
  const int n = cuda_int_cast(self.size(-2), "self.size(-2)");

  auto self_data = self.data_ptr<scalar_t>();
  auto self_mat_stride = matrixStride(self);
  auto self_inv_data = self_inv.data_ptr<scalar_t>();
  auto self_inv_mat_stride = matrixStride(self_inv);

  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();

  if (use_loop_launch(batch_size, n)) {
    int* p_infos = infos.data_ptr<int>();
    auto main_stream = at::cuda::getCurrentCUDAStream();

    at::cuda::CUDAEvent main_event;
    main_event.record(main_stream);

    for (int64_t i = 0; i < batch_size; i++) {
      auto stream = at::cuda::getStreamFromPool();
      at::cuda::CUDAStreamGuard guard(stream);

      main_event.block(stream);

      auto dataPtr = allocator.allocate(sizeof(int) * n);
      int* pivot = reinterpret_cast<int*>(dataPtr.get());

      _apply_single_inverse_helper<scalar_t>(
          &self_data[i * self_mat_stride], &self_inv_data[i * self_inv_mat_stride], pivot, p_infos + i * 2, n);

      at::cuda::CUDAEvent finished;
      finished.record(stream);
      finished.block(main_stream);
    }
  } else {
    // cublas batched kernels require input be "device array of device pointers"
    Tensor self_array = at::arange(
        reinterpret_cast<int64_t>(self_data),
        reinterpret_cast<int64_t>(&self_data[(batch_size - 1) * self_mat_stride]) + 1,
        static_cast<int64_t>(self_mat_stride * sizeof(scalar_t)), self.options().dtype(at::kLong));
    Tensor self_inv_array = at::arange(
        reinterpret_cast<int64_t>(self_inv_data),
        reinterpret_cast<int64_t>(&self_inv_data[(batch_size - 1) * self_inv_mat_stride]) + 1,
        static_cast<int64_t>(self_inv_mat_stride * sizeof(scalar_t)), self.options().dtype(at::kLong));

    auto dataPtr = allocator.allocate(sizeof(int) * batch_size * n);
    int* ipiv_array = reinterpret_cast<int*>(dataPtr.get());

    Tensor _info1 = at::zeros({batch_size}, self.options().dtype(at::kInt));
    Tensor _info2 = at::zeros({batch_size}, self.options().dtype(at::kInt));

    at::cuda::blas::getrfBatched<scalar_t>(n, reinterpret_cast<scalar_t**>(self_array.data_ptr()), n,
        ipiv_array, _info1.data_ptr<int>(), batch_size);
    at::cuda::blas::getriBatched<scalar_t>(n, reinterpret_cast<scalar_t**>(self_array.data_ptr()), n,
        ipiv_array, _info2.data_ptr<int>(), batch_size, reinterpret_cast<scalar_t**>(self_inv_array.data_ptr()));

    infos = at::stack({_info1, _info2}, 1);
  }
}
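
// Single-matrix path: inverts one matrix with one cuSOLVER getrf/getrs pair.
// `info` must hold two ints, receiving the getrf and getrs statuses.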
template <typename scalar_t>
static void apply_single_inverse_lib(const Tensor& self, Tensor& self_inv, Tensor& info) {
  int n = cuda_int_cast(self.size(-2), "self.size(-2)");

  Tensor ipiv = at::empty({n}, self.options().dtype(at::kInt));

  _apply_single_inverse_helper<scalar_t>(
      self.data_ptr<scalar_t>(), self_inv.data_ptr<scalar_t>(), ipiv.data_ptr<int>(), info.data_ptr<int>(), n);
}
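
// Entry point for the cuSOLVER/cuBLAS inverse path.  Works on a column-major
// clone of the input, writes the result into an identity-initialized buffer,
// dispatches to the batched or single-matrix kernel, and checks the returned
// status codes with batchCheckErrors (two status ints per matrix).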
Tensor _inverse_helper_cuda_lib(const Tensor& self) {
  Tensor self_working_copy = cloneBatchedColumnMajor(self);
  Tensor self_inv_working_copy = column_major_identity_matrix_like(self_working_copy);
  const int batch_size = cuda_int_cast(batchCount(self), "batchCount");

  if (self.dim() > 2 && batch_size > 1) {
    Tensor infos = at::zeros({batchCount(self) * 2}, self.options().dtype(kInt));
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{
      apply_batched_inverse_lib<scalar_t>(
          self_working_copy, self_inv_working_copy, infos);
    });
    batchCheckErrors(infos, "inverse_cuda", false, 2);
  } else {
    Tensor info = at::zeros({2}, self.options().dtype(at::kInt));
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "inverse_cuda", [&]{
      apply_single_inverse_lib<scalar_t>(self_working_copy, self_inv_working_copy, info);
    });
    batchCheckErrors(info, "inverse_cuda", false, 2);
  }

  return self_inv_working_copy;
}
}} // namespace at::native
#endif // USE_CUSOLVER