forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
THCTensorIndex.cu
293 lines (259 loc) · 11.3 KB
/
THCTensorIndex.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#include <THC/THC.h>
#include <THC/THCTensorMath.h>
#include <THC/THCGeneral.h>
#include <THC/THCBlas.h>
#include <THC/THCTensorCopy.h>
#include <TH/THHalf.h>
#include <THC/THCApply.cuh>
#include <THC/THCReduce.cuh>
#include <THC/THCDeviceUtils.cuh>
#include <THC/THCNumerics.cuh>
#include <THC/THCAtomics.cuh>
#include <THC/THCThrustAllocator.cuh>
#include <THC/THCTensorSort.cuh>
#include <THC/THCTensor.hpp>
#include <THC/THCStorage.hpp>
#include <thrust/device_ptr.h>
#include <thrust/sort.h>
#include <algorithm> // for std::min
#include <c10/macros/Macros.h>
#include <ATen/WrapDimUtils.h>
// We prefer this kernel to avoid reloading index points if the number
// of indices is a small number.
// This kernel in fact works for all choices of problem size, but if
// the number of indices chosen is large, then the
// indexCopyLargeIndex kernel is a better choice to increase
// parallelism.
//
// Semantics: for every position srcIndex in [0, indices.sizes[0]), the slice
// of `src` at srcIndex along srcCopyDim is copied into the slice of `dst` at
// indices[srcIndex] along dstCopyDim. Each slice has innerSize elements, whose
// offsets are computed with IndexToOffset (DstDim / SrcDim select the
// specialization). Index values are 0-based and must satisfy
// 0 <= value < dstCopyDimSize; this is enforced with CUDA_KERNEL_ASSERT on the
// full 64-bit value, before it is narrowed to IndexType.
//
// Launch: any 1-D grid/block works, because the per-slice work is a
// grid-stride loop over innerSize. No shared memory is used.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim>
__global__ void indexCopySmallIndex(TensorInfo<T, IndexType> dst,
                                    TensorInfo<T, IndexType> src,
                                    TensorInfo<int64_t, IndexType> indices,
                                    int dstCopyDim,
                                    int srcCopyDim,
                                    IndexType innerSize,
                                    int64_t dstCopyDimSize) {
  // In order to avoid reloading the index that we are copying, load
  // it once to handle all of the points that are being selected, so
  // it can be reused as much as possible. This kernel is chosen when
  // this is a good choice (small number of chosen indices), since
  // re-accessing indices in addition to src elements can be slow.
  const IndexType gridStride = gridDim.x * blockDim.x;
  for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) {
    // Indices are 0-based. Bounds-check the raw 64-bit value *before*
    // narrowing: with a 32-bit IndexType a value such as 2^32 + 1 would
    // otherwise truncate to 1 and slip past the check, and with a signed
    // IndexType a negative value would not be caught at all.
    const int64_t rawDstIndex =
      indices.data[IndexToOffset<int64_t, IndexType, IdxDim>::get(srcIndex, indices)];
    CUDA_KERNEL_ASSERT(rawDstIndex >= 0 && rawDstIndex < dstCopyDimSize);
    const IndexType dstIndex = static_cast<IndexType>(rawDstIndex);

    // The base offsets of the selected slices do not change in the inner loop.
    const IndexType dstSliceOffset = dstIndex * dst.strides[dstCopyDim];
    const IndexType srcSliceOffset = srcIndex * src.strides[srcCopyDim];

    // We stride over the output ignoring the indexed dimension
    // (innerSize), whose offset calculation is handled differently
    for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
         linearIndex < innerSize;
         linearIndex += gridStride) {
      const IndexType dstOffset =
        IndexToOffset<T, IndexType, DstDim>::get(linearIndex, dst) + dstSliceOffset;
      const IndexType srcOffset =
        IndexToOffset<T, IndexType, SrcDim>::get(linearIndex, src) + srcSliceOffset;
      dst.data[dstOffset] = src.data[srcOffset];
    }
  }
}
// We prefer this kernel to balance parallelism across index points,
// if there are a large number of indices.
// This kernel in fact works for all choices of problem size, but if
// the number of indices chosen is small, then the
// indexCopySmallIndex kernel is a better choice to reduce memory
// accesses.
//
// Semantics: each linear position in [0, totalSize) is split into an index
// position (srcIndex) and an element-within-slice (elementInSlice). The
// element at elementInSlice of the src slice srcIndex (along srcCopyDim) is
// copied to the same element of the dst slice indices[srcIndex] (along
// dstCopyDim). Index values are 0-based and must satisfy
// 0 <= value < dstCopyDimSize; this is asserted on the full 64-bit value,
// before it is narrowed to IndexType.
//
// IndexIsMajor (a compile-time constant) selects the decomposition:
//  - true:  srcIndex = linear / innerSize, elementInSlice = linear % innerSize,
//           so innerSize is the slice size.
//  - false: the roles are swapped, so innerSize is the modulus for srcIndex.
//           NOTE(review): presumably the launcher passes the number of
//           indices as innerSize in this case; confirm at the call site.
//
// Launch: any 1-D grid/block works, because the kernel uses a grid-stride
// loop over totalSize. No shared memory is used.
template <typename T, typename IndexType, int DstDim, int SrcDim, int IdxDim,
          bool IndexIsMajor>
__global__ void indexCopyLargeIndex(TensorInfo<T, IndexType> dst,
                                    TensorInfo<T, IndexType> src,
                                    TensorInfo<int64_t, IndexType> indices,
                                    int dstCopyDim,
                                    int srcCopyDim,
                                    IndexType totalSize,
                                    IndexType innerSize,
                                    int64_t dstCopyDimSize) {
  // We stride over the output including the indexed dimension
  // (totalSize), and calculate the destination index point based on that
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalSize;
       linearIndex += gridDim.x * blockDim.x) {
    IndexType srcIndex, elementInSlice;
    if (IndexIsMajor) {
      srcIndex = linearIndex / innerSize;
      elementInSlice = linearIndex % innerSize;
    } else {
      elementInSlice = linearIndex / innerSize;
      srcIndex = linearIndex % innerSize;
    }

    // Indices are 0-based. Bounds-check the raw 64-bit value *before*
    // narrowing so that out-of-range values cannot be truncated into range.
    const int64_t rawDstIndex =
      indices.data[IndexToOffset<int64_t, IndexType, IdxDim>::get(srcIndex, indices)];
    CUDA_KERNEL_ASSERT(rawDstIndex >= 0 && rawDstIndex < dstCopyDimSize);
    const IndexType dstIndex = static_cast<IndexType>(rawDstIndex);

    const IndexType dstOffset =
      IndexToOffset<T, IndexType, DstDim>::get(elementInSlice, dst) +
      dstIndex * dst.strides[dstCopyDim];
    const IndexType srcOffset =
      IndexToOffset<T, IndexType, SrcDim>::get(elementInSlice, src) +
      srcIndex * src.strides[srcCopyDim];
    dst.data[dstOffset] = src.data[srcOffset];
  }
}
// We prefer this kernel to avoid reloading index points if the number
// of indices is a small number.
// This kernel in fact works for all choices of problem size, but if
// the number of indices chosen is large, then the
// indexFillLargeIndex kernel is a better choice to increase
// parallelism.
//
// Semantics: for every position in [0, indices.sizes[0]), every one of the
// innerSize elements of the `dst` slice at indices[position] along dstFillDim
// is set to `val`. Index values are 0-based and must satisfy
// 0 <= value < dstFillDimSize; this is asserted on the full 64-bit value,
// before it is narrowed to IndexType.
//
// Launch: any 1-D grid/block works, because the per-slice work is a
// grid-stride loop over innerSize. No shared memory is used.
template <typename T, typename IndexType, int DstDim, int IdxDim>
__global__ void indexFillSmallIndex(TensorInfo<T, IndexType> dst,
                                    TensorInfo<int64_t, IndexType> indices,
                                    int dstFillDim,
                                    IndexType innerSize,
                                    int64_t dstFillDimSize,
                                    T val) {
  // Each selected index is loaded once and reused for every element of its
  // slice. This kernel is chosen when that is a good trade-off (small number
  // of chosen indices), since re-reading the index for every element would
  // add memory traffic.
  const IndexType gridStride = gridDim.x * blockDim.x;
  for (IndexType indexPos = 0; indexPos < indices.sizes[0]; ++indexPos) {
    // Indices are 0-based. Bounds-check the raw 64-bit value *before*
    // narrowing: with a 32-bit IndexType a value such as 2^32 + 1 would
    // otherwise truncate to 1 and slip past the check.
    const int64_t rawDstIndex =
      indices.data[IndexToOffset<int64_t, IndexType, IdxDim>::get(indexPos, indices)];
    CUDA_KERNEL_ASSERT(rawDstIndex >= 0 && rawDstIndex < dstFillDimSize);
    const IndexType dstSliceOffset =
      static_cast<IndexType>(rawDstIndex) * dst.strides[dstFillDim];

    // We stride over the output ignoring the indexed dimension
    // (innerSize), whose offset calculation is handled differently
    for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
         linearIndex < innerSize;
         linearIndex += gridStride) {
      const IndexType dstOffset =
        IndexToOffset<T, IndexType, DstDim>::get(linearIndex, dst) + dstSliceOffset;
      dst.data[dstOffset] = val;
    }
  }
}
// We prefer this kernel to balance parallelism across index points,
// if there are a large number of indices.
// This kernel in fact works for all choices of problem size, but if
// the number of indices chosen is small, then the
// indexFillSmallIndex kernel is a better choice to reduce memory
// accesses.
//
// Semantics: each linear position in [0, totalSize) is split into an index
// position and an element-within-slice. That element of the `dst` slice at
// indices[position] along dstFillDim is set to `val`. Index values are
// 0-based and must satisfy 0 <= value < dstFillDimSize; this is asserted on
// the full 64-bit value, before it is narrowed to IndexType.
//
// IndexIsMajor (a compile-time constant) selects the decomposition:
//  - true:  position = linear / innerSize, element = linear % innerSize,
//           so innerSize is the slice size.
//  - false: the roles are swapped, so innerSize is the modulus for the
//           position. NOTE(review): presumably the launcher passes the
//           number of indices as innerSize here; confirm at the call site.
//
// Launch: any 1-D grid/block works, because the kernel uses a grid-stride
// loop over totalSize. No shared memory is used.
template <typename T, typename IndexType, int DstDim, int IdxDim,
          bool IndexIsMajor>
__global__ void indexFillLargeIndex(TensorInfo<T, IndexType> dst,
                                    TensorInfo<int64_t, IndexType> indices,
                                    int dstFillDim,
                                    IndexType totalSize,
                                    IndexType innerSize,
                                    int64_t dstFillDimSize,
                                    T val) {
  // We stride over the output including the indexed dimension
  // (totalSize), and calculate the destination index point based on that
  for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
       linearIndex < totalSize;
       linearIndex += gridDim.x * blockDim.x) {
    IndexType indexPos, elementInSlice;
    if (IndexIsMajor) {
      indexPos = linearIndex / innerSize;
      elementInSlice = linearIndex % innerSize;
    } else {
      elementInSlice = linearIndex / innerSize;
      indexPos = linearIndex % innerSize;
    }

    // Indices are 0-based. Bounds-check the raw 64-bit value *before*
    // narrowing so that out-of-range values cannot be truncated into range.
    const int64_t rawDstIndex =
      indices.data[IndexToOffset<int64_t, IndexType, IdxDim>::get(indexPos, indices)];
    CUDA_KERNEL_ASSERT(rawDstIndex >= 0 && rawDstIndex < dstFillDimSize);

    const IndexType dstOffset =
      IndexToOffset<T, IndexType, DstDim>::get(elementInSlice, dst) +
      static_cast<IndexType>(rawDstIndex) * dst.strides[dstFillDim];
    dst.data[dstOffset] = val;
  }
}
// Converts a flat (linear) element index into a memory offset within `info`.
//
// `index` may be negative: values in [-size, 0) are wrapped by adding `size`,
// and values outside [-size, size) trip CUDA_KERNEL_ASSERT. The range check
// and the wrap are done in 64-bit arithmetic *before* narrowing to IndexType.
// Otherwise, with IndexType = int32_t (as dispatchTakePut selects for small
// tensors), an out-of-range value such as 2^32 + 1 would truncate to 1 and
// pass the check.
//
// Dims is forwarded to IndexToOffset to pick its specialization.
template <int Dims, typename T, typename IndexType>
__device__ __forceinline__ IndexType indexToOffset(
    const TensorInfo<T, IndexType>& info,
    int64_t index,
    IndexType size)
{
  const int64_t bound = static_cast<int64_t>(size);
  CUDA_KERNEL_ASSERT(index < bound && index >= -bound);
  if (index < 0) {
    index += bound;
  }
  // index is now in [0, size), so the narrowing conversion is lossless.
  return IndexToOffset<T, IndexType, Dims>::get(static_cast<IndexType>(index), info);
}
// Elementwise functor that maps a possibly-negative int64_t index into
// [0, size): inputs in [-size, 0) are shifted up by `size`, and inputs in
// [0, size) are written through unchanged. Inputs outside [-size, size)
// trip CUDA_KERNEL_ASSERT.
struct WrapIndexOp {
  WrapIndexOp(int64_t size) : size(size) {}

  __device__ __forceinline__ void operator()(int64_t* out, int64_t* in) {
    // Read before writing so that out == in is safe.
    const int64_t idx = *in;
    CUDA_KERNEL_ASSERT(idx < size && idx >= -size);
    if (idx < 0) {
      *out = idx + size;
    } else {
      *out = idx;
    }
  }

  // Bound used both for the range check and for wrapping negative inputs.
  int64_t size;
};
// Elementwise functor for a non-accumulating "put". It stores *value into the
// element of `info` addressed by the flat index *index. The offset comes from
// indexToOffset, which range-checks the index against numel and wraps
// negative indices.
//
// The two trailing int64_t* constructor parameters are ignored. They exist so
// this functor can be constructed with the same argument list as
// TensorPutAccumulateOp (dispatchTakePutImpl builds either one as
// Op(aInfo, numel, start, end)).
template <typename T, typename IndexType, int Dims>
struct TensorPutOp {
  TensorPutOp(TensorInfo<T, IndexType> info, IndexType numel, int64_t*, int64_t*)
    : info(info), numel(numel) {}

  __device__ __forceinline__ void operator()(T* value, int64_t* index) {
    info.data[indexToOffset<Dims>(info, *index, numel)] = *value;
  }

  // Destination tensor description (its data pointer is written through).
  const TensorInfo<T, IndexType> info;
  // Element count of the destination; used as the index bound.
  IndexType numel;
};
// Elementwise functor for an accumulating "put". It adds values into the
// elements of `info` addressed by flat indices, and handles each run of equal
// consecutive indices in a single thread.
//
// Work is done only by the element that starts a run: `index` is `start`, or
// *index differs from *(index - 1). That thread resolves the destination
// offset once via indexToOffset, which range-checks against numel and wraps
// negative indices. It then walks forward through the run, adding each *value
// into the same destination with THCNumerics<T>::add, and stops at `end` or at
// the first index that differs. Elements inside a run do nothing.
//
// Assumptions this code relies on but does not check:
//  - `index` points into the buffer [start, end), so index - 1 and index + 1
//    are valid neighbours. This requires the index tensor to be contiguous.
//  - value + k pairs with index + k. NOTE(review): presumably the value
//    tensor is contiguous and in the same order as index; TODO confirm.
//  - The read-modify-write on info.data is not atomic, so results are only
//    race-free if all equal indices are adjacent. NOTE(review): the caller
//    presumably sorts index (permuting values alongside) before dispatch;
//    verify.
template <typename T, typename IndexType, int Dims>
struct TensorPutAccumulateOp {
TensorPutAccumulateOp(TensorInfo<T, IndexType> info, IndexType numel, int64_t* start, int64_t* end)
: info(info), numel(numel), start(start), end(end) {}
__device__ __forceinline__ void operator()(T* value, int64_t* index) {
// Only the first element of a run of equal indices does the work.
if (index == start || *index != *(index - 1)) {
int64_t linear_index = *index;
// Resolve the destination once for the whole run.
auto offset = indexToOffset<Dims>(info, linear_index, numel);
// Sum every value belonging to this run into the same destination.
do {
info.data[offset] = THCNumerics<T>::add(info.data[offset], *value);
index++;
value++;
} while (index != end && *index == linear_index);
}
}
// Destination tensor description (its data pointer is written through).
const TensorInfo<T, IndexType> info;
// Element count of the destination; used as the index bound.
IndexType numel;
// First element of the raw index buffer.
int64_t* start;
// One past the last element of the raw index buffer.
int64_t* end;
};
// Runs the elementwise functor Op over (b, index) with THC_pointwiseApply2.
// The functor is bound to the tensor description of `a` (dims collapsed) and
// to a's element count. IndexType is chosen by the caller.
//
// Op is instantiated with Dims = -2 when a's collapsed layout is contiguous,
// and Dims = -1 otherwise. These Dims values are what indexToOffset passes on
// to IndexToOffset. NOTE(review): presumably -2 is the contiguous fast path
// and -1 the general strided path; confirm in the TensorInfo header.
template<typename IndexType, typename T, template<class, class, int> class Op, typename TensorType>
void dispatchTakePutImpl(THCState *state, TensorType *a, TensorType *b, THCudaLongTensor *index) {
  // Bounds of the raw index buffer. These are only valid if index is
  // contiguous.
  auto* const indexFirst = THCudaLongTensor_data(state, index);
  auto* const indexLast = indexFirst + THCudaLongTensor_numel(state, index);

  auto aInfo = getTensorInfo<T, TensorType, IndexType>(state, a);
  aInfo.collapseDims();
  auto const totalElems = THCTensor_nElement(state, a);

  if (!aInfo.isContiguous()) {
    Op<T, IndexType, -1> stridedOp(aInfo, totalElems, indexFirst, indexLast);
    THC_pointwiseApply2<T, int64_t>(state, b, index, stridedOp);
    return;
  }

  Op<T, IndexType, -2> contiguousOp(aInfo, totalElems, indexFirst, indexLast);
  THC_pointwiseApply2<T, int64_t>(state, b, index, contiguousOp);
}
// Picks the index arithmetic width for dispatchTakePutImpl. It uses int32_t
// when THCTensor_canUse32BitIndexMath reports that `a` fits within INT_MAX,
// and falls back to int64_t otherwise.
template<typename T, template<class, class, int> class Op, typename TensorType>
void dispatchTakePut(THCState *state, TensorType *a, TensorType *b, THCudaLongTensor *index) {
  const bool fitsIn32Bit = THCTensor_canUse32BitIndexMath(state, a, INT_MAX);
  if (!fitsIn32Bit) {
    dispatchTakePutImpl<int64_t, T, Op>(state, a, b, index);
    return;
  }
  dispatchTakePutImpl<int32_t, T, Op>(state, a, b, index);
}
#include <THC/generic/THCTensorIndex.cu>
#include <THC/THCGenerateAllTypes.h>
#include <THC/generic/THCTensorIndex.cu>
#include <THC/THCGenerateBoolType.h>
#include <THC/generic/THCTensorIndex.cu>
#include <THC/THCGenerateBFloat16Type.h>