forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SpatialSubSampling.cu
265 lines (224 loc) · 7.27 KB
/
SpatialSubSampling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
#include <THCUNN/THCUNN.h>
#include <THC/THCTensor.hpp>
#include <TH/THHalf.h>
#include <THC/THCNumerics.cuh>
#include <THC/THCAtomics.cuh>
#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit
/*
 * Description:
 *    Forward pass of spatial sub-sampling. For every output pixel the kernel
 *    sums a kH x kW window of the selected input plane and writes
 *    weight[k] * sum + bias[k].
 *
 * Indexing, as used by the code below:
 *    - blockIdx.x selects the plane (the same index is used for input and
 *      output); k = blockIdx.x % input_n selects the weight/bias entry.
 *      NOTE(review): presumably blockIdx.x enumerates batch * input_n planes;
 *      confirm against the host-side launch.
 *    - output columns are strided over threadIdx.x with step blockDim.x.
 *    - output rows are strided over blockDim.y*blockIdx.y + threadIdx.y with
 *      step blockDim.y*gridDim.y.
 *
 * Parameters:
 *    input, output     base pointers; each plane is stored contiguously,
 *                      row-major (input_h x input_w, output_h x output_w)
 *    weight, bias      per-plane scale and offset, indexed by k
 *    input_n           number of weight/bias entries (modulus for k)
 *    input_h, input_w  input plane size
 *    kH, kW            window size
 *    dH, dW            window stride
 */
template <typename Dtype, typename Acctype>
__global__ void subsample(Dtype *input, Dtype *output, Dtype *weight, Dtype *bias,
                          int input_n, int input_h, int input_w,
                          int kH, int kW, int dH, int dW)
{
  // iterators
  int xx, yy;

  // output size
  int output_w = (input_w - kW) / dW + 1;
  int output_h = (input_h - kH) / dH + 1;

  // compute offsets based on thread/block ID
  int o = blockIdx.x;
  int i = o;
  int k = blockIdx.x % input_n;

  int xx_start = threadIdx.x;
  int xx_end = output_w;
  int xx_step = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end = output_h;
  int yy_step = blockDim.y*gridDim.y;

  // select input/output plane
  // The plane offset is computed in 64 bits: plane_index * plane_size can
  // exceed INT_MAX for large tensors even when each factor fits in an int.
  output = output + (int64_t)o*output_w*output_h;
  input = input + (int64_t)i*input_w*input_h;

  // Per-plane scale
  Dtype the_weight = weight[k];

  // Per-plane offset
  Dtype the_bias = bias[k];

  // For all output pixels...
  for(yy = yy_start; yy < yy_end; yy+=yy_step) {
    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
      // Top-left corner of the input window for this output pixel
      Dtype *ptr_input = input + yy*dH*input_w + xx*dW;
      Dtype *ptr_output = output + yy*output_w + xx;

      // Sum (not mean) of the kH x kW window, accumulated in Acctype
      Acctype sum = 0;
      int kx, ky;
      for(ky = 0; ky < kH; ky++) {
        for(kx = 0; kx < kW; kx++)
          sum += ptr_input[kx];
        ptr_input += input_w; // next input line
      }

      // Update output
      *ptr_output = ScalarConvert<Acctype, Dtype>::to(the_weight*sum + the_bias);
    }
  }
}
/*
 * Description:
 *    Accumulates the weight and bias gradients for one plane per block:
 *      gradWeight[k] += scale * sum over output pixels of
 *                       gradOutput(y,x) * (sum of the kH x kW input window)
 *      gradBias[k]   += scale * sum over output pixels of gradOutput(y,x)
 *
 * Indexing, as used by the code below:
 *    - blockIdx.x selects the plane; k = blockIdx.x % input_n selects the
 *      gradWeight/gradBias entry.
 *    - the whole plane is covered by a single block: rows are strided over
 *      threadIdx.y (step blockDim.y), columns over threadIdx.x (step
 *      blockDim.x). blockIdx.y is not used.
 *
 * Preconditions:
 *    - blockDim.x*blockDim.y must not exceed CUDA_MAX_THREADS, the size of the
 *      shared partial-sum array.
 *    - NOTE(review): gradWeight[k] / gradBias[k] are updated with a plain
 *      (non-atomic) +=. If two blocks of the same launch map to the same k
 *      (gridDim.x > input_n) the updates would race; confirm that the host
 *      code launches at most input_n blocks per call.
 *
 * Parameters:
 *    input, gradOutput     base pointers; each plane is contiguous, row-major
 *    gradWeight, gradBias  per-plane gradient accumulators, indexed by k
 *    input_n               number of gradWeight/gradBias entries
 *    input_h, input_w      input plane size
 *    kH, kW                window size
 *    dH, dW                window stride
 *    scale                 factor applied to both gradients before accumulation
 */
template <typename Dtype, typename Acctype>
__global__ void subgradweight(Dtype *input, Dtype *gradOutput, Dtype *gradWeight, Dtype *gradBias,
                              int input_n, int input_h, int input_w,
                              int kH, int kW, int dH, int dW,
                              float scale)
{
  // iterators
  int xx, yy;

  // output size
  int output_w = (input_w - kW) / dW + 1;
  int output_h = (input_h - kH) / dH + 1;

  // compute offsets based on thread/block ID
  int o = blockIdx.x;
  int i = o;
  int k = blockIdx.x % input_n;

  int xx_start = threadIdx.x;
  int xx_end = output_w;
  int xx_step = blockDim.x;

  int yy_start = threadIdx.y;
  int yy_end = output_h;
  int yy_step = blockDim.y;

  // select input/output plane
  // The plane offset is computed in 64 bits: plane_index * plane_size can
  // exceed INT_MAX for large tensors even when each factor fits in an int.
  gradOutput = gradOutput + (int64_t)o*output_w*output_h;
  input = input + (int64_t)i*input_w*input_h;

  // linear thread ID within the block, and number of threads in the block
  int tid = blockDim.x*threadIdx.y + threadIdx.x;
  int nthreads = blockDim.x*blockDim.y;

  // one partial sum per thread
  __shared__ Acctype sums[CUDA_MAX_THREADS];

  // compute this thread's partial gradWeight sum in a register; operands are
  // converted to Acctype before multiplying so the product is not rounded to
  // Dtype when Dtype is narrower than Acctype
  Acctype partial = Acctype(0);
  for(yy = yy_start; yy < yy_end; yy+=yy_step) {
    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
      Dtype *ptr_input = input + yy*dH*input_w + xx*dW;
      Dtype *ptr_gradOutput = gradOutput + yy*output_w + xx;
      Acctype z = ScalarConvert<Dtype, Acctype>::to(*ptr_gradOutput);
      int kx, ky;
      for(ky = 0; ky < kH; ky++) {
        for(kx = 0; kx < kW; kx++) {
          partial += z * ScalarConvert<Dtype, Acctype>::to(ptr_input[kx]);
        }
        ptr_input += input_w; // next input line
      }
    }
  }
  sums[tid] = partial;
  __syncthreads();

  // reduce: thread (0,0) accumulates all partial sums to produce final gradWeight
  if ((threadIdx.x == 0) && (threadIdx.y == 0)) {
    Acctype scaledSums = Acctype(0);
    for(int t = 0; t < nthreads; t++) {
      scaledSums += scale*sums[t];
    }
    gradWeight[k] += ScalarConvert<Acctype, Dtype>::to(scaledSums);
  }
  // sums[] is reused below; wait until thread (0,0) has finished reading it
  __syncthreads();

  // compute gradBias: each thread sums a strided slice of the gradOutput plane
  partial = Acctype(0);
  for (int idx = tid; idx < output_w*output_h; idx += nthreads) {
    partial += gradOutput[idx];
  }
  sums[tid] = partial;
  __syncthreads();

  // reduce gradBias
  if ((threadIdx.x == 0) && (threadIdx.y == 0)) {
    Acctype scaledSums = Acctype(0);
    for (int t = 0; t < nthreads; t++) {
      scaledSums += scale*sums[t];
    }
    gradBias[k] += ScalarConvert<Acctype, Dtype>::to(scaledSums);
  }
}
/*
 * Description:
 *    Computes gradInput from weight and gradOutput using plain (non-atomic)
 *    accumulation: every input pixel inside the kH x kW window of an output
 *    pixel receives gradOutput(y,x) * weight[k].
 *
 *    Because the update is a non-atomic +=, this kernel is only correct when
 *    no two output pixels share an input pixel, i.e. when the windows do not
 *    overlap (kW <= dW and kH <= dH). Windows start dW / dH apart and span
 *    kW / kH elements.
 *    NOTE(review): the kernel adds to the existing contents of gradInput, so
 *    the caller is presumably expected to zero it first; confirm.
 *
 * Indexing, as used by the code below:
 *    - blockIdx.x selects the plane; k = blockIdx.x % input_n selects the
 *      weight entry.
 *    - output columns are strided over threadIdx.x with step blockDim.x.
 *    - output rows are strided over blockDim.y*blockIdx.y + threadIdx.y with
 *      step blockDim.y*gridDim.y.
 *
 * Parameters:
 *    gradInput, gradOutput  base pointers; each plane is contiguous, row-major
 *    weight                 per-plane scale, indexed by k
 *    input_n                number of weight entries (modulus for k)
 *    input_h, input_w       input plane size
 *    kH, kW                 window size
 *    dH, dW                 window stride
 */
template <typename Dtype>
__global__ void subgradinput(Dtype *gradInput, Dtype *gradOutput, Dtype *weight,
                             int input_n, int input_h, int input_w,
                             int kH, int kW, int dH, int dW)
{
  // iterators
  int xx, yy;

  // output size
  int output_w = (input_w - kW) / dW + 1;
  int output_h = (input_h - kH) / dH + 1;

  // compute offsets based on thread/block ID
  int o = blockIdx.x;
  int i = o;
  int k = blockIdx.x % input_n;

  int xx_start = threadIdx.x;
  int xx_end = output_w;
  int xx_step = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end = output_h;
  int yy_step = blockDim.y*gridDim.y;

  // select input/output plane
  // The plane offset is computed in 64 bits: plane_index * plane_size can
  // exceed INT_MAX for large tensors even when each factor fits in an int.
  gradOutput = gradOutput + (int64_t)o*output_w*output_h;
  gradInput = gradInput + (int64_t)i*input_w*input_h;

  // get weight
  Dtype the_weight = weight[k];

  // compute gradInput
  for(yy = yy_start; yy < yy_end; yy+=yy_step) {
    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
      // Top-left corner of the input window for this output pixel
      Dtype *ptr_gradInput = gradInput + yy*dH*input_w + xx*dW;
      Dtype *ptr_gradOutput = gradOutput + yy*output_w + xx;
      // Value added to every pixel of the window
      Dtype z = *ptr_gradOutput * the_weight;
      int kx, ky;
      for(ky = 0; ky < kH; ky++) {
        for(kx = 0; kx < kW; kx++) {
          // FIXME: should this be done at accreal precision?
          ptr_gradInput[kx] += z;
        }
        ptr_gradInput += input_w; // next gradInput line
      }
    }
  }
}
/*
 * Description:
 *    Computes gradInput from weight and gradOutput using atomicAdd: every
 *    input pixel inside the kH x kW window of an output pixel receives
 *    gradOutput(y,x) * weight[k].
 *
 *    Unlike subgradinput, the accumulation is atomic, so windows of different
 *    output pixels may overlap (kW > dW or kH > dH) without lost updates.
 *    NOTE(review): the kernel adds to the existing contents of gradInput, so
 *    the caller is presumably expected to zero it first; confirm.
 *
 * Indexing, as used by the code below:
 *    - blockIdx.x selects the plane; k = blockIdx.x % input_n selects the
 *      weight entry.
 *    - output columns are strided over threadIdx.x with step blockDim.x.
 *    - output rows are strided over blockDim.y*blockIdx.y + threadIdx.y with
 *      step blockDim.y*gridDim.y.
 *
 * Parameters:
 *    gradInput, gradOutput  base pointers; each plane is contiguous, row-major
 *    weight                 per-plane scale, indexed by k
 *    input_n                number of weight entries (modulus for k)
 *    input_h, input_w       input plane size
 *    kH, kW                 window size
 *    dH, dW                 window stride
 */
template <typename Dtype>
__global__ void subgradinputAtomic(Dtype *gradInput, Dtype *gradOutput, Dtype *weight,
                                   int input_n, int input_h, int input_w,
                                   int kH, int kW, int dH, int dW)
{
  // iterators
  int xx, yy;

  // output size
  int output_w = (input_w - kW) / dW + 1;
  int output_h = (input_h - kH) / dH + 1;

  // compute offsets based on thread/block ID
  int o = blockIdx.x;
  int i = o;
  int k = blockIdx.x % input_n;

  int xx_start = threadIdx.x;
  int xx_end = output_w;
  int xx_step = blockDim.x;

  int yy_start = blockDim.y*blockIdx.y + threadIdx.y;
  int yy_end = output_h;
  int yy_step = blockDim.y*gridDim.y;

  // select input/output plane
  // The plane offset is computed in 64 bits: plane_index * plane_size can
  // exceed INT_MAX for large tensors even when each factor fits in an int.
  gradOutput = gradOutput + (int64_t)o*output_w*output_h;
  gradInput = gradInput + (int64_t)i*input_w*input_h;

  // get weight
  Dtype the_weight = weight[k];

  // compute gradInput
  for(yy = yy_start; yy < yy_end; yy+=yy_step) {
    for(xx = xx_start; xx < xx_end; xx+=xx_step) {
      // Top-left corner of the input window for this output pixel
      Dtype *ptr_gradInput = gradInput + yy*dH*input_w + xx*dW;
      Dtype *ptr_gradOutput = gradOutput + yy*output_w + xx;
      // Value added to every pixel of the window
      Dtype z = *ptr_gradOutput * the_weight;
      int kx, ky;
      for(ky = 0; ky < kH; ky++) {
        for(kx = 0; kx < kW; kx++) {
          // FIXME: should this be done at accreal precision?
          atomicAdd(&(ptr_gradInput[kx]), z);
        }
        ptr_gradInput += input_w; // next gradInput line
      }
    }
  }
}
#include <THCUNN/generic/SpatialSubSampling.cu>
#include <THC/THCGenerateFloatTypes.h>
#undef CUDA_MAX_THREADS