forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Activation.cpp
384 lines (330 loc) · 12.5 KB
/
Activation.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
#include <ATen/native/Activation.h>
#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/Parallel.h>
namespace at { namespace native {
// Fixed SELU constants: selu()/selu_() forward these to at::elu as the
// `alpha` and `scale` arguments respectively. Namespace-scope constexpr
// implies const, so these keep internal linkage just like `static const`.
constexpr double SELU_ALPHA = 1.6732632423543772848170429916717;
constexpr double SELU_SCALE = 1.0507009873554804934193349852946;
// Definitions of the dispatch stubs invoked in this file:
//   threshold_stub               -> used by the static threshold_out helper
//   hardshrink_cpu_stub          -> used by hardshrink_cpu
//   hardshrink_backward_cpu_stub -> used by hardshrink_backward_cpu
// The per-device kernels behind these stubs live outside this file
// (presumably registered by the kernel sources — TODO confirm).
DEFINE_DISPATCH(threshold_stub);
DEFINE_DISPATCH(hardshrink_cpu_stub);
DEFINE_DISPATCH(hardshrink_backward_cpu_stub);
// ReLU as a thresholding at zero: every element <= 0 becomes 0, all other
// elements pass through unchanged. Returns a new tensor.
Tensor relu(const Tensor & self) {
  const Scalar zero = 0;
  return at::threshold(self, /*threshold=*/zero, /*value=*/zero);
}
// In-place ReLU: overwrites every element of `self` that is <= 0 with 0 and
// returns the result of at::threshold_ on `self`.
Tensor & relu_(Tensor & self) {
  const Scalar zero = 0;
  return at::threshold_(self, /*threshold=*/zero, /*value=*/zero);
}
// SELU: ELU evaluated with the file-level SELU_ALPHA / SELU_SCALE constants.
// Returns a new tensor.
Tensor selu(const Tensor & self) {
  const Scalar alpha = SELU_ALPHA;
  const Scalar scale = SELU_SCALE;
  return at::elu(self, alpha, scale);
}
// In-place SELU: at::elu_ on `self` with the file-level SELU_ALPHA /
// SELU_SCALE constants.
Tensor & selu_(Tensor & self) {
  const Scalar alpha = SELU_ALPHA;
  const Scalar scale = SELU_SCALE;
  return at::elu_(self, alpha, scale);
}
// CELU, implemented by delegating to at::elu with the given `alpha`,
// scale = 1 and input_scale = 1 / alpha. Returns a new tensor.
// Throws (TORCH_CHECK) if alpha == 0: the 1 / alpha input scale is undefined
// there and previously produced a silent inf that propagated into the output.
Tensor celu(const Tensor & self, Scalar alpha) {
  const double alpha_val = alpha.to<double>();
  TORCH_CHECK(alpha_val != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");
  const double inv_alpha = 1. / alpha_val;
  return at::elu(self, alpha, Scalar(1.0), Scalar(inv_alpha));
}
// In-place CELU: at::elu_ on `self` with the given `alpha`, scale = 1 and
// input_scale = 1 / alpha.
// Throws (TORCH_CHECK) if alpha == 0, where 1 / alpha would silently become
// inf and corrupt `self` in place.
Tensor & celu_(Tensor & self, Scalar alpha) {
  const double alpha_val = alpha.to<double>();
  TORCH_CHECK(alpha_val != 0,
      "ZeroDivisionError: alpha cannot be 0 for CELU");
  const double inv_alpha = 1. / alpha_val;
  return at::elu_(self, alpha, Scalar(1.0), Scalar(inv_alpha));
}
// Randomized leaky ReLU. Allocates a scratch `noise` tensor shaped like
// `self` and forwards everything to at::rrelu_with_noise; only that call's
// result is returned (noise is presumably filled with the sampled slopes —
// see rrelu_with_noise).
Tensor rrelu(const Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
  Tensor noise = at::empty_like(self);
  return at::rrelu_with_noise(self, noise, lower, upper, training, generator);
}
// In-place randomized leaky ReLU: allocates a scratch `noise` tensor shaped
// like `self` and forwards to at::rrelu_with_noise_.
Tensor & rrelu_(Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
  Tensor noise = at::empty_like(self);
  return at::rrelu_with_noise_(self, noise, lower, upper, training, generator);
}
// Shared implementation for threshold(), threshold_(), threshold_out() and
// threshold_backward(). Computes elementwise
//   result = self <= threshold ? value : other
// where `other` is `self` in threshold() and `grad` in threshold_backward().
// If `opt_result` is empty, an undefined tensor is handed to the iterator
// and the output it ends up with is returned; otherwise the given tensor is
// used as the output.
static Tensor threshold_out(
    optional<Tensor> opt_result,
    const Tensor& self,
    Scalar threshold,
    Scalar value,
    const Tensor& other) {
  Tensor out = opt_result ? *opt_result : Tensor();
  auto iter = TensorIterator::binary_op(out, self, other);
  threshold_stub(iter.device_type(), iter, threshold, value);
  return iter.output();
}
// Out-of-place threshold: elements <= `threshold` are replaced by `value`;
// all others keep their value from `self`.
Tensor threshold(const Tensor& self, Scalar threshold, Scalar value) {
  const optional<Tensor> no_preallocated_output = nullopt;
  return threshold_out(no_preallocated_output, self, threshold, value, /*other=*/self);
}
// In-place threshold: `self` is both the output and the pass-through operand.
Tensor& threshold_(Tensor& self, Scalar threshold, Scalar value) {
  threshold_out(optional<Tensor>(self), self, threshold, value, /*other=*/self);
  return self;
}
// Out-variant of threshold: writes into the caller-provided `result` and
// returns it.
Tensor& threshold_out(Tensor& result, const Tensor& self, Scalar threshold, Scalar value) {
  const optional<Tensor> destination(result);
  threshold_out(destination, self, threshold, value, /*other=*/self);
  return result;
}
// Gradient of threshold w.r.t. its input: 0 where self <= threshold,
// otherwise the incoming `grad` passes through unchanged.
Tensor threshold_backward(const Tensor& grad, const Tensor& self, Scalar threshold) {
  const Scalar blocked_grad = 0;
  return threshold_out(nullopt, self, threshold, blocked_grad, /*other=*/grad);
}
// -----------------------------------
// prelu forward
// -----------------------------------
// PReLU forward for a single slope shared by every element:
//   result[i] = input[i] > 0 ? input[i] : slope * input[i]
// Indexes `input` and `result` as flat arrays of input.numel() elements, so
// both are assumed contiguous with equal numel (prelu_cpu passes
// input.contiguous() and empty_like of it). Work is split with
// at::parallel_for using a grain size of 1000.
template <typename scalar_t>
inline void prelu_cpu_kernel_share_weights(
    Tensor& result,
    const Tensor& input,
    const Tensor& weight) {
  scalar_t* out = result.data<scalar_t>();
  const scalar_t* in = input.data<scalar_t>();
  const scalar_t slope = weight.data<scalar_t>()[0];
  const int64_t count = input.numel();
  at::parallel_for(0, count, 1000, [&](int64_t begin, int64_t finish) {
    for (int64_t idx = begin; idx < finish; ++idx) {
      const scalar_t x = in[idx];
      // Pick the multiplier first, then multiply: keeps the body simple for
      // the compiler to optimize.
      const scalar_t factor = (x > 0) ? scalar_t(1) : slope;
      out[idx] = factor * x;
    }
  });
}
// PReLU forward with one slope per channel (dim 1 of the input).
// Walks the contiguous input as [input_dim0_size, channel_size, input_stride1]
// blocks, where input_stride1 is expected to be the number of elements in one
// (dim0, channel) slice:
//   result = input > 0 ? input : weight[channel] * input
// input_stride0 is part of the signature but not needed here: the batch
// offset is recomputed as i * channel_size * input_stride1.
// Parallelizes over dim 0 only when the input has more than 1000 elements.
template <typename scalar_t>
void inline prelu_cpu_kernel_multi_weights(
    Tensor& result,
    const Tensor& input,
    const Tensor& weight,
    int64_t input_dim0_size,
    int64_t channel_size,
    int64_t input_stride0,
    int64_t input_stride1) {
  scalar_t* result_data = result.data<scalar_t>();
  scalar_t* input_data = input.data<scalar_t>();
  scalar_t* weight_data = weight.data<scalar_t>();
  auto loop = [&](int64_t start, int64_t end) {
    for (int64_t i = start; i < end; ++i) {
      int64_t offset = i * channel_size * input_stride1;
      scalar_t* n_input_data = input_data + offset;
      scalar_t* n_result_data = result_data + offset;
      // Counters are int64_t: `auto j = 0` / `auto k = 0` deduce int, which
      // overflows when compared against large int64_t channel/slice sizes.
      for (int64_t j = 0; j < channel_size; ++j) {
        // Loop-invariant for the whole slice.
        const scalar_t channel_weight = weight_data[j];
        for (int64_t k = 0; k < input_stride1; ++k) {
          // to allow for compiler optimization, here splitting into two lines:
          scalar_t w = (n_input_data[k] > 0) ? scalar_t(1) : channel_weight;
          n_result_data[k] = w * n_input_data[k];
        }
        n_input_data += input_stride1;
        n_result_data += input_stride1;
      }
    }
  };
  if (input.numel() > 1000) {
    at::parallel_for(0, input_dim0_size, 0, loop);
  } else {
    loop(0, input_dim0_size);
  }
}
// PReLU forward on CPU.
//   weight.numel() == 1 : one slope shared by every element.
//   otherwise           : one slope per channel, where the channel is dim 1
//                         of the input; weight.numel() must equal
//                         input.size(1), and a 0-dim input is rejected.
// Input and weight are made contiguous first; returns a new tensor shaped
// like the (contiguous) input.
Tensor prelu_cpu(const Tensor& self, const Tensor& weight_) {
  auto input = self.contiguous();
  auto weight = weight_.contiguous();
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  int64_t weight_num = weight.numel();
  Tensor result = at::empty_like(input);
  // case1: shared weight for all channels
  if (weight_num == 1) {
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
      prelu_cpu_kernel_share_weights<scalar_t>(result, input, weight);
    });
  }
  else { // case2: multiple weights, one for each channel
    int64_t input_ndim = input.dim();
    TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
    int64_t channel_size = 1; // channel_size default to 1
    int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
    if (input_ndim > 1) {
      channel_size = input.size(1); // channel is the 2nd dim of input
      input_dim0_size = input.size(0);
      // Slice length (elements per (dim0, channel) pair) is derived from the
      // sizes, not from input.strides(). Empty tensors count as contiguous
      // and their strides are computed with max(size, 1), so e.g. a (2, 3, 0)
      // tensor has strides[1] == 1 even though it holds no data; using that
      // stride made the kernel read past the empty buffer. For non-empty
      // contiguous inputs the two are identical.
      int64_t inner_size = 1;
      for (int64_t d = 2; d < input_ndim; ++d) {
        inner_size *= input.size(d);
      }
      input_stride1 = inner_size;
      input_stride0 = channel_size * inner_size;
    }
    TORCH_CHECK(channel_size == weight_num,
      "Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
      " and channel size = ", channel_size, ".");
    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_cpu", [&] {
      prelu_cpu_kernel_multi_weights<scalar_t>(
        result,
        input,
        weight,
        input_dim0_size,
        channel_size,
        input_stride0,
        input_stride1);
    });
  }
  return result;
}
// -----------------------------------
// prelu backward
// -----------------------------------
// PReLU backward for a single shared slope. For every element i:
//   input_grad[i] = (input[i] > 0 ? 1 : slope) * grad_out[i]
// and weight_grad[0] receives the sum over all elements with input[i] <= 0
// of input[i] * grad_out[i], reduced with at::parallel_reduce (grain 1000).
// All tensors are indexed as flat arrays, so they are assumed contiguous with
// matching numel (prelu_backward_cpu passes contiguous tensors).
template <typename scalar_t>
inline void prelu_cpu_backward_kernel_share_weights(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& grad_out,
    Tensor& input_grad,
    Tensor& weight_grad) {
  const int64_t total = input.numel();
  const scalar_t* x = input.data<scalar_t>();
  const scalar_t slope = weight.data<scalar_t>()[0];
  const scalar_t* dy = grad_out.data<scalar_t>();
  scalar_t* dx = input_grad.data<scalar_t>();
  scalar_t* dw = weight_grad.data<scalar_t>();

  // Writes input_grad for [begin, finish) and returns this chunk's partial
  // weight-gradient sum, seeded with `init`.
  auto chunk_sum = [&](int64_t begin, int64_t finish, scalar_t init) -> scalar_t {
    scalar_t acc = init;
    for (int64_t idx = begin; idx < finish; ++idx) {
      const scalar_t xi = x[idx];
      const scalar_t gi = dy[idx];
      const bool positive = xi > 0;
      const scalar_t dx_scale = positive ? scalar_t(1) : slope;
      dx[idx] = dx_scale * gi;
      const scalar_t dw_mask = positive ? scalar_t(0) : scalar_t(1);
      acc += dw_mask * xi * gi;
    }
    return acc;
  };
  dw[0] = at::parallel_reduce(
      0, total, 1000, scalar_t(0), chunk_sum, std::plus<scalar_t>());
}
// PReLU backward with one slope per channel (dim 1). For the element at
// pos = i * input_stride0 + j * input_stride1 + k (batch i, channel j):
//   input_grad[pos]            = (input > 0 ? 1 : weight[j]) * grad_out
//   weight_grad_collector[pos] = (input > 0 ? 0 : 1) * input * grad_out
// The caller reduces weight_grad_collector to per-channel sums.
// Parallelizes over dim 0 only when the input has more than 1000 elements.
template <typename scalar_t>
void inline prelu_cpu_backward_kernel_multi_weights(
    const Tensor& input,
    const Tensor& weight,
    const Tensor& grad_out,
    Tensor& input_grad,
    Tensor& weight_grad_collector,
    int64_t input_dim0_size,
    int64_t channel_size,
    int64_t input_stride0,
    int64_t input_stride1) {
  auto input_data = input.data<scalar_t>();
  auto weight_data = weight.data<scalar_t>();
  auto grad_out_data = grad_out.data<scalar_t>();
  auto input_grad_data = input_grad.data<scalar_t>();
  auto weight_grad_collector_data = weight_grad_collector.data<scalar_t>();
  auto loop = [&](int64_t start, int64_t end) {
    for (int64_t i = start; i < end; i++) {
      // Counters are int64_t: `auto j = 0` / `auto k = 0` deduce int, which
      // overflows when compared against large int64_t channel/slice sizes.
      for (int64_t j = 0; j < channel_size; j++) {
        // Both values are invariant across the innermost loop.
        const scalar_t weight_data_val = weight_data[j];
        const int64_t slice_base = i * input_stride0 + j * input_stride1;
        for (int64_t k = 0; k < input_stride1; k++) {
          const int64_t pos = slice_base + k;
          scalar_t input_data_val = input_data[pos];
          scalar_t grad_out_data_val = grad_out_data[pos];
          // to allow for compiler optimization, here splitting into two lines:
          scalar_t w = (input_data_val > 0) ? scalar_t(1) : weight_data_val;
          input_grad_data[pos] = w * grad_out_data_val;
          // to allow for compiler optimization, here splitting into two lines:
          scalar_t mask = (input_data_val > 0) ? scalar_t(0) : scalar_t(1);
          weight_grad_collector_data[pos] = mask * input_data_val * grad_out_data_val;
        }
      }
    }
  };
  if (input.numel() > 1000) {
    at::parallel_for(0, input_dim0_size, 0, loop);
  } else {
    loop(0, input_dim0_size);
  }
}
std::tuple<Tensor, Tensor> prelu_backward_cpu(const Tensor& grad_out_, const Tensor& self, const Tensor& weight_) {
auto input = self.contiguous();
auto grad_out = grad_out_.contiguous();
auto weight = weight_.contiguous();
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(grad_out.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
int64_t weight_num = weight.numel();
auto strides = input.strides();
auto dims = input.dim();
Tensor input_grad = at::empty_like(input);
Tensor weight_grad = at::empty_like(weight);
Tensor weight_grad_collector = at::empty_like(input);
// case1: shared parameter for all channels
if (weight_num == 1) {
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] {
prelu_cpu_backward_kernel_share_weights<scalar_t>(input, weight, grad_out, input_grad, weight_grad);
});
}
else { // case2: multiple parameters, one for each channel
int64_t input_ndim = input.dim();
TORCH_CHECK(input_ndim > 0, "Not allow zero-dim input tensor.");
int64_t channel_size = 1; // channel_size default to 1
int64_t input_dim0_size = 1, input_stride0 = 1, input_stride1 = 1;
if (input_ndim > 1) {
channel_size = input.size(1); // channel is the 2nd dim of input
input_dim0_size = input.size(0);
input_stride0 = strides[0];
input_stride1 = strides[1];
}
TORCH_CHECK(channel_size == weight_num,
"Mismatch of parameter numbers and input channel size. Found parameter numbers = ", weight_num,
" and channel size = ", channel_size, ".");
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "prelu_backward_cpu", [&] {
prelu_cpu_backward_kernel_multi_weights<scalar_t>(
input,
weight,
grad_out,
input_grad,
weight_grad_collector,
input_dim0_size,
channel_size,
input_stride0,
input_stride1);
});
// update weight_grad
std::vector<int64_t> reduce_dims;
reduce_dims.push_back(0);
if (dims > 2) {
for(int64_t i = 2; i < dims; i++) reduce_dims.push_back(i);
}
weight_grad = weight_grad_collector.sum(reduce_dims);
}
return std::tuple<Tensor, Tensor>{input_grad, weight_grad};
}
// -----------------------------------
// hardshrink
// -----------------------------------
// Hardshrink forward on CPU: allocates an output shaped like `self`, builds a
// unary TensorIterator over it and runs hardshrink_cpu_stub with `lambd`.
// The elementwise rule lives in the registered kernel (presumably x where
// |x| > lambd, else 0 — confirm against the kernel).
Tensor hardshrink_cpu(const Tensor & self, Scalar lambd) {
  Tensor result = at::empty_like(self);
  auto unary_iter = TensorIterator::unary_op(result, self);
  hardshrink_cpu_stub(kCPU, unary_iter, lambd);
  return result;
}
// Hardshrink backward on CPU: allocates a gradient shaped like `self`, builds
// a binary TensorIterator over (grad, self) and runs
// hardshrink_backward_cpu_stub with `lambd`.
Tensor hardshrink_backward_cpu(const Tensor & grad, const Tensor & self, Scalar lambd) {
  Tensor grad_input = at::empty_like(self);
  auto binary_iter = TensorIterator::binary_op(grad_input, grad, self);
  hardshrink_backward_cpu_stub(kCPU, binary_iter, lambd);
  return grad_input;
}
// GELU forward on CPU: the kernel receives a contiguous view/copy of `self`
// and writes into a freshly allocated tensor of the same shape.
Tensor gelu_cpu(const Tensor& self) {
  const Tensor input = self.contiguous();
  Tensor output = at::native::empty_like(input);
  GeluKernel(kCPU, input, &output);
  return output;
}
// GELU forward on CUDA. Unlike gelu_cpu, `self` is passed to the kernel as-is
// (no contiguous() call here).
Tensor gelu_cuda(const Tensor& self) {
  Tensor output = at::native::empty_like(self);
  GeluKernel(kCUDA, self, &output);
  return output;
}
// GELU backward on CPU: both the incoming gradient and the saved input are
// made contiguous before the kernel writes the input gradient.
Tensor gelu_backward_cpu(const Tensor& grad, const Tensor& self) {
  const Tensor input = self.contiguous();
  Tensor grad_input = at::native::empty_like(input);
  const Tensor grad_output = grad.contiguous();
  GeluBackwardKernel(kCPU, grad_output, input, &grad_input);
  return grad_input;
}
// GELU backward on CUDA: `grad` and `self` are handed to the kernel as-is;
// the result is a new tensor shaped like `self`.
Tensor gelu_backward_cuda(const Tensor& grad, const Tensor& self) {
  Tensor grad_input = at::native::empty_like(self);
  GeluBackwardKernel(kCUDA, grad, self, &grad_input);
  return grad_input;
}
// Definitions of the GELU dispatch stubs called by gelu_cpu / gelu_cuda
// (GeluKernel) and gelu_backward_cpu / gelu_backward_cuda
// (GeluBackwardKernel). The device kernels themselves are not in this file.
DEFINE_DISPATCH(GeluKernel);
DEFINE_DISPATCH(GeluBackwardKernel);
}} // namespace at::native