forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
VolumetricAdaptiveMaxPooling.cu
207 lines (180 loc) · 6.67 KB
/
VolumetricAdaptiveMaxPooling.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
#include "THCUNN.h"
#include "TH/THHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCAtomics.cuh"
#include "THCTensor.hpp"
#define CUDA_MAX_THREADS 1024 // this is safe, in reality 256 is our limit

// Adaptive pooling window bounds along one dimension, where a is the output
// index, b the output size and c the input size:
//   START_IND(a,b,c) = floor(a * c / b)
//   END_IND(a,b,c)   = ceil((a + 1) * c / b)
// The window of output index a is the half-open input range
// [START_IND, END_IND); it is non-empty whenever c >= 1.
//
// Evaluated in exact integer arithmetic (arguments are assumed non-negative
// integers with b > 0). A float-based formulation loses precision once a * c
// exceeds 2^24 and can then produce off-by-one window bounds.
// Every argument is parenthesized so expression arguments expand correctly.
#define START_IND(a,b,c) ((int)(((a) * (c)) / (b)))
#define END_IND(a,b,c) ((int)((((a) + 1) * (c) + (b) - 1) / (b)))
// 5d tensor B x D x T x H x W
// (NOTE(review): the kernel itself only indexes a slice d with one stride
// istrideD; presumably the host folds any batch dimension into d — confirm
// against generic/VolumetricAdaptiveMaxPooling.cu.)
/*
 * Description:
 *    Adaptively max-pools an input volume over its T, H and W dimensions.
 *    Each output element (d, ot, oh, ow) is the maximum of the input window
 *    [START_IND, END_IND) along each of T, H and W.
 *
 *    Work distribution:
 *      - blockIdx.x + offsetZ selects one output plane o_plane, which
 *        enumerates (slice d, output frame ot) as o_plane = d * osizeT + ot.
 *        offsetZ lets the host cover all planes over several launches.
 *      - output rows oh are strided over by blockIdx.y / threadIdx.y
 *        (step gridDim.y * blockDim.y); output columns ow are strided over by
 *        threadIdx.x alone (step blockDim.x).
 *
 *    input   : read through the explicit strides istrideD/T/H/W.
 *    output  : written as contiguous osizeH x osizeW planes, one per o_plane.
 *    indices : per output element, the flat index
 *              t * isizeH * isizeW + h * isizeW + w (plus TH_INDEX_BASE) of the
 *              selected element within its isizeT x isizeH x isizeW slice.
 *
 *    NaN propagates: if the window holds a NaN, the output is NaN and the
 *    recorded index is that of the last NaN visited.
 */
template <typename T>
__global__ void cunn_VolumetricAdaptiveMaxPooling_updateOutput_kernel(
                        T *input, T *output, THCIndex_t *indices,
                        int isizeT, int isizeH, int isizeW,
                        int osizeT, int osizeH, int osizeW,
                        int64_t istrideD,
                        int64_t istrideT, int64_t istrideH, int64_t istrideW,
                        int64_t offsetZ)
{
  // iterators on output pixels
  int ot, oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  ot = o_plane % osizeT;     // output frame/time
  int d = o_plane / osizeT;  // slice/feature

  // input frame/time range is fixed for the whole plane
  int istartT = START_IND(ot, osizeT, isizeT);
  int iendT = END_IND(ot, osizeT, isizeT);
  int kT = iendT - istartT;

  // input offset by slice/feature and earliest relevant frame/time
  T *input_dt = input + d*istrideD + istartT*istrideT;
  // output offset by slice/feature and frame/time
  T *output_dt = output + o_plane*osizeH*osizeW;
  // indices offset by slice/feature and frame/time
  THCIndex_t *indices_dt = indices + o_plane*osizeH*osizeW;

  // Elements per input frame, as int64_t so flat argmax indices into large
  // volumes are not computed (and overflowed) in 32-bit int arithmetic.
  int64_t isizeHW = (int64_t)isizeH * isizeW;

  // For all output pixels...
  for(oh = ostartH; oh < oendH; oh += ostepH) {
    int istartH = START_IND(oh, osizeH, isizeH);
    int iendH = END_IND(oh, osizeH, isizeH);
    int kH = iendH - istartH;

    for(ow = ostartW; ow < oendW; ow += ostepW) {
      int istartW = START_IND(ow, osizeW, isizeW);
      int iendW = END_IND(ow, osizeW, isizeW);
      int kW = iendW - istartW;

      // Compute the max pooling over the corresponding input window
      T *ptr_input = input_dt + istartH*istrideH + istartW*istrideW;
      T *ptr_output = output_dt + oh*osizeW + ow;
      THCIndex_t *ptr_ind = indices_dt + oh*osizeW + ow;

      // argmax stays -1 only until the first window element is visited; that
      // element is always accepted. Relying on the sentinel `max` alone left
      // argmax at -1 for windows whose values never compare greater than it
      // (e.g. all -inf), and the backward kernels then used -1 as an
      // out-of-bounds gradInput offset.
      int64_t argmax = -1;
      T max = THCNumerics<T>::min();
      int it, ih, iw;
      for(it = 0; it < kT; ++it) {
        for(ih = 0; ih < kH; ++ih) {
          for(iw = 0; iw < kW; ++iw) {
            T val = ptr_input[ih*istrideH + iw*istrideW];
            if (argmax < 0 || (val > max) || THCNumerics<T>::isnan(val)) {
              max = val;
              argmax = (it+istartT)*isizeHW
                     + (int64_t)(ih+istartH)*isizeW
                     + iw+istartW;
            }
          }
        }
        ptr_input += istrideT; // next input frame
      }

      // Update output and argmax
      *ptr_output = max;
      *ptr_ind = argmax + TH_INDEX_BASE;
    }
  }
}
/*
 * Description:
 *    This function computes the gradInput from gradOutput.
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ); output rows are strided over by
 *    blockIdx.y / threadIdx.y and output columns by threadIdx.x.
 *
 *    Each gradOutput element is added (+=) to the gradInput element whose
 *    flat index (minus TH_INDEX_BASE) is stored in indices. gradInput is
 *    addressed as contiguous isizeT x isizeH x isizeW volumes, one per slice
 *    d = o_plane / osizeT. Because it accumulates, the caller is expected to
 *    have initialized gradInput (presumably zeroed — confirm in the host code).
 *
 *    Uses a plain, non-atomic read-modify-write, so it assumes that input size
 *    can be perfectly divided by output size, i.e. each input pixel can only
 *    be argmax of one output pixel. Otherwise concurrent updates to the same
 *    gradInput element race; the atomicAdd-based kernel covers that case.
 */
template <typename T>
__global__ void cunn_VolumetricAdaptiveMaxPooling_updateGradInput_kernel(
  T *gradInput, T *gradOutput, THCIndex_t *indices,
  int isizeT, int isizeH, int isizeW,
  int osizeT, int osizeH, int osizeW,
  int64_t offsetZ
)
{
  // iterators on output pixels
  int oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  int d = o_plane / osizeT; // output slice/feature

  // gradInput offset by slice/feature; computed in int64_t so the product
  // cannot overflow 32-bit int for large volumes
  T *gradInput_d = gradInput + (int64_t)d*isizeT*isizeH*isizeW;
  // gradOutput offset by slice/feature and frame/time
  T *gradOutput_dt = gradOutput + o_plane*osizeH*osizeW;
  // indices offset by slice/feature and frame/time
  THCIndex_t *indices_dt = indices + o_plane*osizeH*osizeW;

  // For all output pixels...
  for(oh = ostartH; oh < oendH; oh += ostepH) {
    for(ow = ostartW; ow < oendW; ow += ostepW) {
      // Compute the gradients for the argmax input pixel
      T *ptr_gradOutput = gradOutput_dt + oh*osizeW + ow;
      THCIndex_t *ptr_ind = indices_dt + oh*osizeW + ow;
      T grad_delta = *ptr_gradOutput;
      // int64_t: the stored index is 64-bit; truncating it to int would
      // address the wrong element in volumes with more than 2^31 elements
      int64_t argmax = (*ptr_ind) - TH_INDEX_BASE;
      gradInput_d[argmax] += grad_delta;
    }
  }
}
/*
 * Description:
 *    This function computes the gradInput from gradOutput.
 *
 *    gridDim.y blocks work together on a single 2D output plane specified by
 *    (blockIdx.x + offsetZ); output rows are strided over by
 *    blockIdx.y / threadIdx.y and output columns by threadIdx.x.
 *
 *    Each gradOutput element is added to the gradInput element whose flat
 *    index (minus TH_INDEX_BASE) is stored in indices. gradInput is addressed
 *    as contiguous isizeT x isizeH x isizeW volumes, one per slice
 *    d = o_plane / osizeT. Because it accumulates, the caller is expected to
 *    have initialized gradInput (presumably zeroed — confirm in the host code).
 *
 *    Uses atomic add, so it stays correct when several output pixels share
 *    the same argmax input pixel (input size not divisible by output size).
 */
template <typename T>
__global__ void cunn_atomic_VolumetricAdaptiveMaxPooling_updateGradInput_kernel(
  T *gradInput, T *gradOutput, THCIndex_t *indices,
  int isizeT, int isizeH, int isizeW,
  int osizeT, int osizeH, int osizeW,
  int64_t offsetZ
)
{
  // iterators on output pixels
  int oh, ow;

  // compute offsets based on thread/block ID
  int ostartH = blockIdx.y * blockDim.y + threadIdx.y;
  int oendH = osizeH;
  int ostepH = gridDim.y * blockDim.y;
  int ostartW = threadIdx.x;
  int oendW = osizeW;
  int ostepW = blockDim.x;

  // select output plane
  int64_t o_plane = blockIdx.x + offsetZ;
  int d = o_plane / osizeT; // output slice/feature

  // gradInput offset by slice/feature; computed in int64_t so the product
  // cannot overflow 32-bit int for large volumes
  T *gradInput_d = gradInput + (int64_t)d*isizeT*isizeH*isizeW;
  // gradOutput offset by slice/feature and frame/time
  T *gradOutput_dt = gradOutput + o_plane*osizeH*osizeW;
  // indices offset by slice/feature and frame/time
  THCIndex_t *indices_dt = indices + o_plane*osizeH*osizeW;

  // For all output pixels...
  for(oh = ostartH; oh < oendH; oh += ostepH) {
    for(ow = ostartW; ow < oendW; ow += ostepW) {
      // Compute the gradients for the argmax input pixel
      T *ptr_gradOutput = gradOutput_dt + oh*osizeW + ow;
      THCIndex_t *ptr_ind = indices_dt + oh*osizeW + ow;
      T grad_delta = *ptr_gradOutput;
      int64_t argmax = (*ptr_ind) - TH_INDEX_BASE;
      atomicAdd(&(gradInput_d[argmax]), grad_delta);
    }
  }
}
// Pull in the generic host-side code and instantiate it for the floating-point
// types (presumably via the THC generic-file mechanism, where
// THCGenerateFloatTypes.h re-includes the generic file once per type — confirm).
#include "generic/VolumetricAdaptiveMaxPooling.cu"
#include "THCGenerateFloatTypes.h"

// The helper macros are only needed up to this point (the kernels above and
// the generic file just included); undefine all of them, not just
// CUDA_MAX_THREADS, so none leak past this file's own code.
#undef CUDA_MAX_THREADS
#undef START_IND
#undef END_IND