// batch_moments_op.cu
// Forked from pytorch/pytorch.
#include "caffe2/operators/batch_moments_op.h"
#include <cub/block/block_reduce.cuh>
#include "caffe2/core/context_gpu.h"
namespace caffe2 {
namespace {
// Shorthand for cub::BlockReduce over element type T with the block size fixed
// to CAFFE_CUDA_NUM_THREADS. Kernels using it are launched below with
// CAFFE_CUDA_NUM_THREADS threads per block.
template <typename T>
using BlockReduce = cub::BlockReduce<T, CAFFE_CUDA_NUM_THREADS>;
// Per-channel first and second raw moments of X.
//
// For every channel c in [0, C) this kernel reduces the N * HxW elements that
// belong to that channel and writes
//   mu[c]  = sum(x)     / (N * HxW)
//   var[c] = sum(x * x) / (N * HxW)
// Note that var[c] is the mean of squares; the mean is not subtracted here.
//
// Work distribution: block b processes channels b, b + gridDim.x, ...; inside
// a channel, thread t visits elements t, t + blockDim.x, ... and the per-thread
// partial sums are combined with BlockReduce. Only thread 0 stores the result.
//
// kOrder picks the flat-index formula: for StorageOrder::NCHW the element
// (n, c, hw) lives at (n * C + c) * HxW + hw; otherwise the channel is the
// innermost dimension and element k of channel c lives at k * C + c.
template <typename T, StorageOrder kOrder>
__global__ void BatchMomentsCUDAKernel(
    const int N,
    const int C,
    const int HxW,
    const T* X,
    T* mu,
    T* var) {
  const int reduce_size = N * HxW;
  // Separate scratch areas so the two reductions can run back to back without
  // a barrier in between.
  __shared__ typename BlockReduce<T>::TempStorage m_storage;
  __shared__ typename BlockReduce<T>::TempStorage v_storage;
  for (int c = blockIdx.x; c < C; c += gridDim.x) {
    T sum_x = 0;
    T sum_xx = 0;
    for (int k = threadIdx.x; k < reduce_size; k += blockDim.x) {
      int offset;
      if (kOrder == StorageOrder::NCHW) {
        offset = (k / HxW * C + c) * HxW + k % HxW;
      } else {
        offset = k * C + c;
      }
#if __CUDA_ARCH__ >= 350
      const T x = __ldg(X + offset);
#else
      const T x = X[offset];
#endif
      sum_x += x;
      sum_xx += x * x;
    }
    sum_x = BlockReduce<T>(m_storage).Reduce(sum_x, cub::Sum());
    sum_xx = BlockReduce<T>(v_storage).Reduce(sum_xx, cub::Sum());
    if (threadIdx.x == 0) {
      const T count = static_cast<T>(N * HxW);
      mu[c] = sum_x / count;
      var[c] = sum_xx / count;
    }
    // The shared scratch is reused by the next channel; make sure every thread
    // is done with it before starting over.
    __syncthreads();
  }
}
// Gradient of the batch moments with respect to X.
//
// For every flat element index i in [0, N * C * HxW):
//   dX[i] = (dmu[c] + dvar[c] * 2 * X[i]) / (N * HxW)
// where c is the channel of element i: (i / HxW) % C for StorageOrder::NCHW,
// and i % C otherwise (channel innermost).
//
// Elements are distributed over threads by CUDA_1D_KERNEL_LOOP.
template <typename T, StorageOrder kOrder>
__global__ void BatchMomentsGradientCUDAKernel(
    const int N,
    const int C,
    const int HxW,
    const T* dmu,
    const T* dvar,
    const T* X,
    T* dX) {
  const int total = N * C * HxW;
  const T inv_count = T(1) / static_cast<T>(N * HxW);
  CUDA_1D_KERNEL_LOOP(idx, total) {
    int c;
    if (kOrder == StorageOrder::NCHW) {
      c = idx / (HxW) % C;
    } else {
      c = idx % C;
    }
#if __CUDA_ARCH__ >= 350
    const T g_mu = __ldg(dmu + c);
    const T g_var = __ldg(dvar + c);
    const T x = __ldg(X + idx);
#else
    const T g_mu = dmu[c];
    const T g_var = dvar[c];
    const T x = X[idx];
#endif
    dX[idx] = (g_mu + g_var * T(2) * x) * inv_count;
  }
}
} // namespace
// Launches BatchMomentsCUDAKernel for NCHW-ordered input on the context's CUDA
// stream.
//
// Params:
//   N, C, HxW - batch size, channel count, and spatial size of X.
//   X         - device pointer to the input.
//   mu, var   - device pointers receiving C values each (per-channel mean and
//               per-channel mean of squares, as computed by the kernel).
// Returns true unconditionally; the launch result is not inspected here.
template <>
bool BatchMomentsOp<float, CUDAContext>::ComputeBatchMomentsNCHW(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* mu,
    float* var) {
  // There is one output per channel; with no channels there is nothing to
  // compute, and a zero-block launch would be an invalid configuration.
  if (C <= 0) {
    return true;
  }
  // Inside the kernel each block walks channels (blockIdx.x, stepping by
  // gridDim.x, bounded by C), so the grid is sized by C. Sizing it by the
  // per-channel reduction length N * HxW would launch blocks that find no
  // channel to process, or too few blocks when C is the larger quantity.
  const int num_blocks = std::min(C, CAFFE_MAXIMUM_NUM_BLOCKS);
  BatchMomentsCUDAKernel<float, StorageOrder::NCHW>
      <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N, C, HxW, X, mu, var);
  return true;
}
// Launches BatchMomentsCUDAKernel for NHWC-ordered input on the context's CUDA
// stream.
//
// Params:
//   N, C, HxW - batch size, channel count, and spatial size of X.
//   X         - device pointer to the input.
//   mu, var   - device pointers receiving C values each (per-channel mean and
//               per-channel mean of squares, as computed by the kernel).
// Returns true unconditionally; the launch result is not inspected here.
template <>
bool BatchMomentsOp<float, CUDAContext>::ComputeBatchMomentsNHWC(
    const int N,
    const int C,
    const int HxW,
    const float* X,
    float* mu,
    float* var) {
  // There is one output per channel; with no channels there is nothing to
  // compute, and a zero-block launch would be an invalid configuration.
  if (C <= 0) {
    return true;
  }
  // Inside the kernel each block walks channels (blockIdx.x, stepping by
  // gridDim.x, bounded by C), so the grid is sized by C. Sizing it by the
  // per-channel reduction length N * HxW would launch blocks that find no
  // channel to process, or too few blocks when C is the larger quantity.
  const int num_blocks = std::min(C, CAFFE_MAXIMUM_NUM_BLOCKS);
  BatchMomentsCUDAKernel<float, StorageOrder::NHWC>
      <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N, C, HxW, X, mu, var);
  return true;
}
// Launches BatchMomentsGradientCUDAKernel for NCHW-ordered data on the
// context's CUDA stream.
//
// Params:
//   N, C, HxW - batch size, channel count, and spatial size of X.
//   dmu, dvar - device pointers to the per-channel upstream gradients.
//   X         - device pointer to the forward input.
//   dX        - device pointer receiving N * C * HxW gradient values.
// Returns true unconditionally; the launch result is not inspected here.
template <>
bool BatchMomentsGradientOp<float, CUDAContext>::
    ComputeBatchMomentsGradientNCHW(
        const int N,
        const int C,
        const int HxW,
        const float* dmu,
        const float* dvar,
        const float* X,
        float* dX) {
  // The kernel loops over every element of dX, so the grid is derived from
  // the total element count.
  const int num_elements = N * C * HxW;
  const auto num_blocks = CAFFE_GET_BLOCKS(num_elements);
  BatchMomentsGradientCUDAKernel<float, StorageOrder::NCHW>
      <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N, C, HxW, dmu, dvar, X, dX);
  return true;
}
// Launches BatchMomentsGradientCUDAKernel for NHWC-ordered data on the
// context's CUDA stream.
//
// Params:
//   N, C, HxW - batch size, channel count, and spatial size of X.
//   dmu, dvar - device pointers to the per-channel upstream gradients.
//   X         - device pointer to the forward input.
//   dX        - device pointer receiving N * C * HxW gradient values.
// Returns true unconditionally; the launch result is not inspected here.
template <>
bool BatchMomentsGradientOp<float, CUDAContext>::
    ComputeBatchMomentsGradientNHWC(
        const int N,
        const int C,
        const int HxW,
        const float* dmu,
        const float* dvar,
        const float* X,
        float* dX) {
  // The kernel loops over every element of dX, so the grid is derived from
  // the total element count.
  const int num_elements = N * C * HxW;
  const auto num_blocks = CAFFE_GET_BLOCKS(num_elements);
  BatchMomentsGradientCUDAKernel<float, StorageOrder::NHWC>
      <<<num_blocks, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
          N, C, HxW, dmu, dvar, X, dX);
  return true;
}
// Register the float specialisations above as the CUDA implementations of the
// BatchMoments and BatchMomentsGradient operators.
REGISTER_CUDA_OPERATOR(BatchMoments, BatchMomentsOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(
BatchMomentsGradient,
BatchMomentsGradientOp<float, CUDAContext>);
} // namespace caffe2