forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TemporalReflectionPadding.cu
70 lines (58 loc) · 2.06 KB
/
TemporalReflectionPadding.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include "THCUNN.h"
#include "common.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#include "THCReduceApplyUtils.cuh"
#include <THC/THCApply.cuh>
#include "THCTensor.hpp"
#include "THCStorage.hpp"
#include "TH/THHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCAtomics.cuh"
// Forward pass of 1-D reflection padding along dimension 2 of a 3-D
// THCDeviceTensor view indexed as [batch][plane][column].
//
// Launch layout this kernel reads:
//   blockIdx.x * blockDim.x + threadIdx.x -> output column
//   blockIdx.y                            -> plane index
//   blockIdx.z                            -> batch index
// Threads whose column is >= output.getSize(2) return without doing anything.
//
// Each output column is mapped back to an input column by mirroring about the
// first and last input columns (the edge column itself is not duplicated):
// the column just left of the data reads input column 1, the column just right
// reads input column W-2. A negative padL shifts the window right, i.e. crops.
// padR is not read: the right-hand extent is implied by output.getSize(2).
template<typename Dtype>
__global__ void TemporalReflectionPadding_updateOutput(
    THCDeviceTensor<Dtype, 3> input,
    THCDeviceTensor<Dtype, 3> output,
    int padL, int padR) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= output.getSize(2)) {
    return;
  }
  const int plane = blockIdx.y;
  const int batch = blockIdx.z;

  // Position of this output column relative to input column 0, and the index
  // of the last input column.
  const int rel = col - padL;
  const int lastIn = input.getSize(2) - 1;

  // Branchless mirror of rel into [0, lastIn]:
  //   rel < 0            -> -rel
  //   0 <= rel <= lastIn -> rel
  //   rel > lastIn       -> 2 * lastIn - rel
  const int src = abs(rel) - abs(rel - lastIn) - rel + lastIn;

  // Materialize the element as a Dtype so the store below assigns a value,
  // rather than assigning one subtensor handle to another.
  const Dtype value = input[batch][plane][src];
  output[batch][plane][col] = value;
}
// Backward pass of 1-D reflection padding along dimension 2 of a 3-D
// THCDeviceTensor view indexed as [batch][plane][column].
//
// Launch layout this kernel reads:
//   blockIdx.x * blockDim.x + threadIdx.x -> gradOutput column
//   blockIdx.y                            -> plane index
//   blockIdx.z                            -> batch index
// Threads whose column is >= gradOutput.getSize(2) return without doing anything.
//
// Each gradOutput column is added into the gradInput column that the forward
// pass read it from (same mirroring about the first/last input column).
// Several output columns can map to one input column, so the add is done
// with atomicAdd. gradInput is accumulated into, not overwritten; it
// presumably must be zeroed by the caller first — TODO confirm against the
// host wrapper. padR is not read.
template <typename Dtype>
__global__ void TemporalReflectionPadding_updateGradInput(
    THCDeviceTensor<Dtype, 3> gradInput,
    THCDeviceTensor<Dtype, 3> gradOutput,
    int padL, int padR) {
  const int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (col >= gradOutput.getSize(2)) {
    return;
  }
  const int plane = blockIdx.y;
  const int batch = blockIdx.z;

  // Position of this output column relative to input column 0, and the index
  // of the last input column.
  const int rel = col - padL;
  const int lastIn = gradInput.getSize(2) - 1;

  // Branchless mirror of rel into [0, lastIn] (identical to the forward map).
  const int dst = abs(rel) - abs(rel - lastIn) - rel + lastIn;

  // Read the incoming gradient as a plain Dtype before the atomic so the
  // atomicAdd overload is selected on (Dtype*, Dtype).
  const Dtype grad = gradOutput[batch][plane][col];
  atomicAdd(&gradInput[batch][plane][dst], grad);
}
#include "generic/TemporalReflectionPadding.cu"
#include "THCGenerateFloatTypes.h"