forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialReflectionPadding.cu
87 lines (73 loc) · 2.87 KB
/
SpatialReflectionPadding.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include "THCUNN.h"
#include "THCTensor.hpp"
#include "common.h"
#include "THCDeviceTensor.cuh"
#include "THCDeviceTensorUtils.cuh"
#include "THCDeviceUtils.cuh"
#include "THCReduceApplyUtils.cuh"
#include <THC/THCApply.cuh>
#include "TH/THHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCAtomics.cuh"
// Forward pass of 2-D reflection padding: each thread writes one output element.
//
// Launch layout read by the kernel:
//   blockIdx.x * blockDim.x + threadIdx.x  flattened position inside one output
//                                          plane, row-major (col = idx % W_out,
//                                          row = idx / W_out)
//   blockIdx.y                             plane index (tensor dim 1)
//   blockIdx.z                             batch index (tensor dim 0)
// Threads whose flattened position is >= H_out * W_out return immediately.
//
// Column mapping (rows follow the same rule with padT and dim 2):
//   out col x <  padL                  -> in col padL - x               (left mirror)
//   padL <= x <= W_in + padL - 1       -> in col x - padL               (interior)
//   x >  W_in + padL - 1               -> in col 2*(W_in-1) + padL - x  (right mirror)
// The border element itself is not repeated (out col padL-1 reads in col 1).
// padB and padR are accepted but not read; the output extent already encodes them.
//
// NOTE(review): the reflected index is not bounds-checked here. This assumes the
// host side guarantees each pad is smaller than the matching input size — confirm.
template<typename Dtype>
__global__ void SpatialReflectionPadding_updateOutput(
  THCDeviceTensor<Dtype, 4> input,
  THCDeviceTensor<Dtype, 4> output,
  int padT, int padB, int padL, int padR) {
  const int outW = output.getSize(3);
  const int outH = output.getSize(2);
  const int flat = blockIdx.x * blockDim.x + threadIdx.x;
  if (flat >= outH * outW) {
    return;
  }

  const int n = blockIdx.z;
  const int c = blockIdx.y;
  const int ox = flat % outW;
  const int oy = flat / outW;

  const int inW = input.getSize(3);
  const int inH = input.getSize(2);

  // Combined start offset between input and output coordinates:
  // max(0, -pad) - max(0, pad).
  const int shiftX = max(0, -padL) - max(0, padL);
  const int shiftY = max(0, -padT) - max(0, padT);

  // Branch-free form of the piecewise mapping documented above.
  const int ix = abs(ox - padL) - abs(ox - (inW + padL - 1))
                 - ox + 2 * padL + inW - 1 + shiftX;
  const int iy = abs(oy - padT) - abs(oy - (inH + padT - 1))
                 - oy + 2 * padT + inH - 1 + shiftY;

  // NOTE(review): the element goes through a Dtype local on purpose. The
  // subscript chain presumably yields a proxy type, so keep the explicit
  // conversion rather than assigning subscript-to-subscript — confirm before
  // simplifying.
  const Dtype value = input[n][c][iy][ix];
  output[n][c][oy][ox] = value;
}
// Backward pass of 2-D reflection padding: each thread takes one gradOutput
// element and adds it into the gradInput element it was copied from.
//
// Launch layout read by the kernel:
//   blockIdx.x * blockDim.x + threadIdx.x  flattened position inside one
//                                          gradOutput plane, row-major
//                                          (col = idx % W_out, row = idx / W_out)
//   blockIdx.y                             plane index (tensor dim 1)
//   blockIdx.z                             batch index (tensor dim 0)
// Threads whose flattened position is >= H_out * W_out return immediately.
//
// The output->input index mapping matches the forward kernel:
//   out col x <  padL                  -> in col padL - x
//   padL <= x <= W_in + padL - 1       -> in col x - padL
//   x >  W_in + padL - 1               -> in col 2*(W_in-1) + padL - x
// Rows follow the same rule with padT.
//
// Several output positions (an interior element and its mirror images) map to
// the same input position, so the update uses atomicAdd. gradInput is
// accumulated into, not overwritten.
// NOTE(review): presumably the host wrapper zeroes gradInput before launch, and
// the Dtype atomicAdd overloads come from THCAtomics.cuh — confirm.
// padB and padR are accepted but not read.
template <typename Dtype>
__global__ void SpatialReflectionPadding_updateGradInput(
  THCDeviceTensor<Dtype, 4> gradInput,
  THCDeviceTensor<Dtype, 4> gradOutput,
  int padT, int padB, int padL, int padR) {
  const int outW = gradOutput.getSize(3);
  const int outH = gradOutput.getSize(2);
  const int flat = blockIdx.x * blockDim.x + threadIdx.x;
  if (flat >= outH * outW) {
    return;
  }

  const int n = blockIdx.z;
  const int c = blockIdx.y;
  const int ox = flat % outW;
  const int oy = flat / outW;

  const int inW = gradInput.getSize(3);
  const int inH = gradInput.getSize(2);

  // Combined start offset between input and output coordinates:
  // max(0, -pad) - max(0, pad).
  const int shiftX = max(0, -padL) - max(0, padL);
  const int shiftY = max(0, -padT) - max(0, padT);

  // Branch-free form of the piecewise mapping documented above.
  const int ix = abs(ox - padL) - abs(ox - (inW + padL - 1))
                 - ox + 2 * padL + inW - 1 + shiftX;
  const int iy = abs(oy - padT) - abs(oy - (inH + padT - 1))
                 - oy + 2 * padT + inH - 1 + shiftY;

  // Keep the read in a Dtype local and pass the element address via '&' on the
  // subscript chain, exactly as before.
  const Dtype grad = gradOutput[n][c][oy][ox];
  atomicAdd(&gradInput[n][c][iy][ix], grad);
}
#include "generic/SpatialReflectionPadding.cu"
#include "THCGenerateFloatTypes.h"