// MiscUtils.h
#pragma once
#include <ATen/ATen.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/PinnedMemoryAllocator.h>
#include <THC/THC.h> // for USE_MAGMA
#ifdef USE_MAGMA
#include <magma.h>
#include <magma_types.h>
#endif
namespace at {
namespace native {
#ifdef USE_MAGMA
// RAII wrapper for a MAGMA queue
struct MAGMAQueue {

  // A default-constructed MAGMAQueue would destroy a queue that was never
  // initialized, so the default constructor is deleted.
  MAGMAQueue() = delete;

  // Constructor
  explicit MAGMAQueue(int64_t device_id) {
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
#if CUDA_VERSION >= 11000
    // MAGMA operations are numerically sensitive, so TF32 should be off
    // regardless of the global flag.
    TORCH_CUDABLAS_CHECK(cublasGetMathMode(handle, &original_math_mode));
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif
    magma_queue_create_from_cuda(
      device_id,
      at::cuda::getCurrentCUDAStream(),
      handle,
      at::cuda::getCurrentCUDASparseHandle(),
      &magma_queue_);
  }

  // Getter
  magma_queue_t get_queue() const { return magma_queue_; }

  // Destructor
  ~MAGMAQueue() {
#if CUDA_VERSION >= 11000
    // The constructor forced the math mode to CUBLAS_DEFAULT_MATH; restore
    // the original math mode before the handle is reused.
    cublasHandle_t handle = magma_queue_get_cublas_handle(magma_queue_);
    cublasSetMathMode(handle, original_math_mode);
#endif
    magma_queue_destroy(magma_queue_);
  }

 private:
  magma_queue_t magma_queue_;
#if CUDA_VERSION >= 11000
  cublasMath_t original_math_mode;
#endif
};
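
// Usage sketch (hypothetical example, not part of the upstream header):
// MAGMAQueue ties MAGMA BLAS work to ATen's current CUDA stream and cuBLAS
// handle. The arguments are assumptions: dA, dB, dC must be valid device
// pointers to n-by-n column-major matrices.
static inline void example_magma_gemm(
    magma_int_t n, const double* dA, const double* dB, double* dC) {
  // The queue wraps the current stream, so the gemm below is stream-ordered
  // with surrounding ATen kernels; TF32 is disabled for the queue's lifetime.
  MAGMAQueue magma_queue(at::cuda::current_device());
  magmablas_dgemm(
      MagmaNoTrans, MagmaNoTrans, n, n, n,
      /*alpha=*/1.0, dA, n, dB, n,
      /*beta=*/0.0, dC, n,
      magma_queue.get_queue());
}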
static inline magma_int_t magma_int_cast(int64_t value, const char* varname) {
  auto result = static_cast<magma_int_t>(value);
  if (static_cast<int64_t>(result) != value) {
    AT_ERROR("magma: The value of ", varname, " (", (long long)value,
             ") is too large to fit into a magma_int_t (", sizeof(magma_int_t), " bytes)");
  }
  return result;
}
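
// Usage sketch (hypothetical example): ATen stores sizes as int64_t, while
// MAGMA takes magma_int_t (often 32-bit), so dimensions are narrowed with an
// overflow check before being handed to a MAGMA routine.
static inline void example_magma_int_cast(const Tensor& self) {
  magma_int_t m = magma_int_cast(self.size(-2), "m");
  magma_int_t n = magma_int_cast(self.size(-1), "n");
  (void)m; (void)n;  // would be passed to e.g. magma_dgetrf_gpu
}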
// MAGMA functions that don't take a magma_queue_t aren't stream safe.
// Work around this by synchronizing with the default stream.
struct MagmaStreamSyncGuard {
  MagmaStreamSyncGuard() {
    // Make sure pending work on the current stream finishes before MAGMA
    // runs on the default stream.
    auto stream = at::cuda::getCurrentCUDAStream();
    if (stream != at::cuda::getDefaultCUDAStream()) {
      AT_CUDA_CHECK(cudaStreamSynchronize(stream));
    }
  }

  ~MagmaStreamSyncGuard() noexcept(false) {
    // Make sure the MAGMA work on the default stream finishes before
    // execution continues on the current stream.
    auto default_stream = at::cuda::getDefaultCUDAStream();
    if (at::cuda::getCurrentCUDAStream() != default_stream) {
      AT_CUDA_CHECK(cudaStreamSynchronize(default_stream));
    }
  }
};
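
// Usage sketch (hypothetical example): magma_dgetrf_gpu takes no
// magma_queue_t, so a call to it is bracketed by MagmaStreamSyncGuard. The
// buffers are assumptions: dA is a device pointer, ipiv and info are
// caller-allocated host pointers.
static inline void example_guarded_getrf(
    magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
    magma_int_t* ipiv, magma_int_t* info) {
  MagmaStreamSyncGuard guard;  // syncs on construction and destruction
  magma_dgetrf_gpu(m, n, dA, ldda, ipiv, info);
}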
#endif
static inline int cuda_int_cast(int64_t value, const char* varname) {
  auto result = static_cast<int>(value);
  TORCH_CHECK(static_cast<int64_t>(result) == value,
              "cuda_int_cast: The value of ", varname, " (", (long long)value,
              ") is too large to fit into an int (", sizeof(int), " bytes)");
  return result;
}
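
// Usage sketch (hypothetical example): the cuSOLVER/cuBLAS analogue of
// magma_int_cast above, narrowing 64-bit tensor sizes to the 32-bit ints
// those libraries expect.
static inline void example_cuda_int_cast(const Tensor& self) {
  int n = cuda_int_cast(self.size(-1), "n");
  int lda = cuda_int_cast(self.size(-2), "lda");
  (void)n; (void)lda;  // would be passed to e.g. cusolverDnDgetrf
}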
// Creates a Storage of `size` elements of type T, backed by pinned
// (page-locked) host memory.
template<class T>
static inline Storage pin_memory(int64_t size) {
  auto* allocator = cuda::getPinnedMemoryAllocator();
  int64_t adjusted_size = size * sizeof(T);
  return Storage(
      Storage::use_byte_size_t(),
      adjusted_size,
      allocator,
      /*resizable=*/false);
}
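
// Usage sketch (hypothetical example): pinned host buffers keep copies to and
// from MAGMA/cuSOLVER asynchronous with respect to the host. The Storage owns
// the allocation, so it must outlive any raw pointer taken from it.
static inline void example_pin_memory(int64_t batch_size) {
  Storage infos_storage = pin_memory<int>(batch_size);
  auto* infos = static_cast<int*>(infos_storage.data());
  (void)infos;  // one info value per matrix in the batch
}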
// Heuristic:
// cublas_x_batched kernels don't perform well for small batch sizes: they are
// intended for many small matrices, where launch overhead is a significant
// factor. With use_loop_launch == true, we instead loop over the batch and
// launch a single-matrix cusolver/cublas kernel per matrix.
// (This heuristic was originally tuned on getrf + getrs (getri) and may not
// work well for other kernels.)
inline static bool use_loop_launch(int batch_size, int matrix_size) {
  return (batch_size <= 8) ||
         (/* batch_size > 8 && */ matrix_size >= 512);
}
} // namespace native
} // namespace at