Repeat.cu
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/Repeat.h>

// Grid-stride loop over the repeats: for element i, write the value i into
// result[cumsum[i] - repeat[i], cumsum[i]), i.e. repeat[i] consecutive copies.
__global__ static void compute_cuda_kernel(
    int64_t* repeat_ptr,
    int64_t* cumsum_ptr,
    int64_t* result_ptr,
    int64_t size) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  int64_t stride = blockDim.x * gridDim.x;
  for (int64_t i = idx; i < size; i += stride) {
    int64_t end = cumsum_ptr[i];
    int64_t repeat = repeat_ptr[i];
    int64_t start = end - repeat;
    for (int64_t j = start; j < end; j++) {
      result_ptr[j] = i;
    }
  }
}

// Launch the kernel on the current CUDA stream, capping the grid at 2048 blocks.
static void compute_cuda(
    int64_t* repeat_ptr,
    int64_t* cumsum_ptr,
    int64_t* result_ptr,
    int64_t size) {
  int64_t block = 512;
  int64_t grid = std::min<int64_t>((size + block - 1) / block, 2048L);
  compute_cuda_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
      repeat_ptr, cumsum_ptr, result_ptr, size);
}

namespace at { namespace native {

Tensor repeat_interleave_cuda(const Tensor& repeat) {
  return repeat_interleave_common<compute_cuda>(repeat);
}

}} // namespace at::native
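
// Usage sketch (an illustration added here, not part of the original file):
// at the ATen level, at::repeat_interleave(repeats) on a CUDA int64 tensor of
// repeat counts dispatches to repeat_interleave_cuda above, producing the
// tensor of indices where index i appears repeats[i] times, e.g.
//
//   at::Tensor repeats = at::tensor({2, 0, 3},
//                                   at::dtype(at::kLong).device(at::kCUDA));
//   at::Tensor idx = at::repeat_interleave(repeats);
//   // idx is [0, 0, 2, 2, 2]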