forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Distributions.cpp
215 lines (196 loc) · 7.52 KB
/
Distributions.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#include "ATen/ATen.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"
#include "ATen/CPUGenerator.h"
#include "ATen/CheckGenerator.h"
#include "ATen/Generator.h"
#include "TH/THRandom.h"
namespace {
/*
* This section is a counterpart to Distributions.cu
*
*/
// The function `sample_poisson`
// is adapted from Numpy's distributions.c implementation.
// It is MIT licensed, so here is the copyright:
/* Copyright 2005 Robert Kern (robert.kern@gmail.com)
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
// Resolves the TH-level generator that the CPU samplers in this file draw from.
// `gen` is passed to at::check_generator together with the global context's
// default CPU generator, and the `generator` member of the resulting
// at::CPUGenerator is returned.
// NOTE(review): presumably check_generator substitutes the default when `gen`
// is null and rejects non-CPU generators -- confirm in ATen/CheckGenerator.h.
THGenerator* get_generator(at::Generator* gen) {
  auto& cpu_default = at::globalContext().defaultGenerator(at::Backend::CPU);
  return at::check_generator<at::CPUGenerator>(gen, &cpu_default)->generator;
}
// Draws a single Poisson(lambda) variate using uniforms pulled from `generator`.
//
//   lambda == 0  : returns 0 without consuming any random numbers.
//   lambda >= 10 : transformed rejection method (Hoermann, 1993); every attempt
//                  consumes two uniforms (first U, then V).
//   otherwise    : multiplies uniforms into a running product and returns how
//                  many draws kept that product strictly above exp(-lambda).
int64_t sample_poisson(double lambda, THGenerator* generator) {
  if (lambda == 0) {
    return 0;
  }

  // Written as a negated `>=` so that a NaN lambda lands here, exactly as it
  // would fall through to the final `else` of an if / else-if chain.
  if (!(lambda >= 10)) {
    const double threshold = std::exp(-lambda);
    double running = 1.0;
    for (int64_t count = 0;; ++count) {
      running *= THRandom_standard_uniform(generator);
      if (!(running > threshold)) {
        return count;
      }
    }
  }

  // transformed rejection method, (Hoermann, 1993)
  const double sqrt_lambda = std::sqrt(lambda);
  const double log_lambda = std::log(lambda);
  const double b = 0.931 + 2.53 * sqrt_lambda;
  const double a = -0.059 + 0.02483 * b;
  const double inv_alpha = 1.1239 + 1.1328 / (b - 3.4);
  const double v_r = 0.9277 - 3.6224 / (b - 2);

  for (;;) {
    // Draw order (U before V) is part of the observable random stream.
    const double u = THRandom_standard_uniform(generator) - 0.5;
    const double v = THRandom_standard_uniform(generator);
    const double us = 0.5 - std::fabs(u);
    const int64_t k =
        static_cast<int64_t>(std::floor((2 * a / us + b) * u + lambda + 0.43));

    // Quick acceptance.
    if ((us >= 0.07) && (v <= v_r)) {
      return k;
    }
    // Quick rejection: start a new attempt.
    if ((k < 0) || ((us < 0.013) && (v > us))) {
      continue;
    }
    // Full acceptance test, evaluated in log space.
    if ((std::log(v) + std::log(inv_alpha) - std::log(a / (us * us) + b)) <=
        (-lambda + k * log_lambda - std::lgamma(static_cast<double>(k) + 1))) {
      return k;
    }
  }
}
// Digamma function psi(x) = d/dx log|Gamma(x)| for a single scalar.
//
// Replaces the earlier central finite difference of lgamma (relative step
// 1e-3, absolute error around 1e-7..1e-6) with a direct evaluation:
//   * x <= 0, non-integer : reflection  psi(x) = psi(1 - x) - pi / tan(pi * x)
//   * x <  10             : recurrence  psi(x) = psi(x + 1) - 1 / x
//   * x >= 10             : asymptotic series in 1 / x^2
// All arithmetic is carried out in double and converted to scalar_t on return.
//
// Returns NaN at the poles (x == 0 and negative integers). For x == 0 this is
// the same value the finite-difference version produced.
template <typename scalar_t>
scalar_t digamma_one(scalar_t x) {
  const double kPi = 3.14159265358979323846;
  double z = static_cast<double>(x);
  double acc = 0.0;

  if (z <= 0) {
    if (z == std::floor(z)) {
      // Pole of the digamma function.
      return static_cast<scalar_t>(std::nan(""));
    }
    acc = -kPi / std::tan(kPi * z);
    z = 1.0 - z;
  }

  // Shift the argument upward until the asymptotic series is accurate; with
  // z >= 10 the first omitted term, 691 / (32760 z^12), is below 3e-14.
  while (z < 10.0) {
    acc -= 1.0 / z;
    z += 1.0;
  }

  // psi(z) ~ ln z - 1/(2z) - 1/(12 z^2) + 1/(120 z^4) - 1/(252 z^6)
  //          + 1/(240 z^8) - 1/(132 z^10)
  const double inv = 1.0 / z;
  const double inv2 = inv * inv;
  acc += std::log(z) - 0.5 * inv
      - inv2 * (1.0 / 12 - inv2 * (1.0 / 120 - inv2 * (1.0 / 252
      - inv2 * (1.0 / 240 - inv2 * (1.0 / 132)))));
  return static_cast<scalar_t>(acc);
}
// Reparameterized gradient of a standard Gamma(alpha) draw `x` with respect to
// `alpha`, i.e. -(d/dalpha cdf(x; alpha)) / pdf(x; alpha).
//
// Three regimes are tried in this order:
//   1. x < 0.8   : Taylor series expansion (terms up to i = 5).
//   2. alpha > 8 : Rice saddle point expansion, with a dedicated closed form
//                  when 0.9 * alpha <= x <= 1.1 * alpha.
//   3. otherwise : exp(p / q), a bivariate rational approximation in
//                  u = log(x / alpha) and v = log(alpha).
template <typename scalar_t>
scalar_t standard_gamma_grad_one(scalar_t alpha, scalar_t x) {
  // Regime 1: Taylor series expansion for small x.
  if (x < 0.8f) {
    scalar_t term = 1;
    scalar_t shifted_alpha = alpha;
    auto sum_inv = term / shifted_alpha;
    auto sum_inv_sq = term / (shifted_alpha * shifted_alpha);
    for (int n = 1; n <= 5; ++n) {
      term *= -x / n;
      shifted_alpha += 1;
      sum_inv += term / shifted_alpha;
      sum_inv_sq += term / (shifted_alpha * shifted_alpha);
    }
    const auto x_pow_alpha = std::pow(x, alpha);
    const auto pdf = std::pow(x, alpha - 1) * std::exp(-x);
    const auto cdf = x_pow_alpha * sum_inv;
    const auto dcdf_dalpha = (std::log(x) - digamma_one(alpha)) * cdf
        - x_pow_alpha * sum_inv_sq;
    const auto grad = -dcdf_dalpha / pdf;
    // A NaN result is reported as a zero gradient.
    if (std::isnan(grad)) {
      return 0;
    }
    return grad;
  }

  // Regime 2: Rice saddle point expansion for large alpha.
  if (alpha > 8.0f) {
    if (0.9f * alpha <= x && x <= 1.1f * alpha) {
      // Separate formula for x within 10% of alpha; the general formula
      // below divides by (alpha - x).
      const auto lead = 1 + 24 * alpha * (1 + 12 * alpha);
      const auto poly = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x)
          - 65 * x * x / alpha + alpha * (107 + 3600 * x);
      const auto scale = 1244160 * (alpha * alpha) * (alpha * alpha);
      return lead * poly / scale;
    }
    const auto root_8a = std::sqrt(8 * alpha);
    const auto ratio = root_8a / (alpha - x);
    const auto inv_pow = std::pow(x - alpha - alpha * std::log(x / alpha), -1.5f);
    const auto combined = (x < alpha) ? ratio - inv_pow : ratio + inv_pow;
    const auto log_part = std::log(x / alpha) * combined
        - std::sqrt(2 / alpha) * (alpha + x) / ((alpha - x) * (alpha - x));
    const auto stirling_corr = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha));
    const auto scaled = x * log_part;
    return -stirling_corr * scaled / root_8a;
  }

  // Regime 3: bivariate rational approximation to the reparameterized gradient.
  const auto log_ratio = std::log(x / alpha);
  const auto log_alpha = std::log(alpha);
  // Row r holds the coefficient of log_ratio^r; columns 0-3 feed the
  // numerator polynomial in log_alpha, columns 4-7 the denominator.
  static const scalar_t kCoef[3][8] = {
    {0.16009398, -0.094634809, 0.025146376, -0.0030648343,
     1, 0.32668115, 0.10406089, 0.0014179084},
    {0.53487893, 0.1298071, 0.065735949, -0.0015649758,
     0.16639465, 0.020070113, -0.0035938915, -0.00058392623},
    {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777,
     0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07},
  };
  scalar_t reduced[8];
  for (int col = 0; col < 8; ++col) {
    reduced[col] = kCoef[0][col] + log_ratio * (kCoef[1][col] + log_ratio * kCoef[2][col]);
  }
  const auto top = reduced[0] + log_alpha * (reduced[1] + log_alpha * (reduced[2] + log_alpha * reduced[3]));
  const auto bottom = reduced[4] + log_alpha * (reduced[5] + log_alpha * (reduced[6] + log_alpha * reduced[7]));
  return std::exp(top / bottom);
}
} // namespace
namespace at {
namespace native {
// In-place Bernoulli sampling with per-element probabilities.
// `p` is run through expand_inplace against `self`, at::bernoulli is called on
// the result with `generator`, and the sample is copied into `self`.
// Returns `self`.
Tensor& bernoulli_(Tensor& self, const Tensor& p, Generator* generator) {
  auto expanded = expand_inplace(self, p);
  const Tensor sampled = at::bernoulli(std::get<0>(expanded), generator);
  self.copy_(sampled);
  return self;
}
// In-place Bernoulli sampling with a single scalar probability.
// Builds a double-typed tensor from an empty size list, fills it with `p`,
// and forwards to the tensor-probability overload of native::bernoulli_.
Tensor& bernoulli_(Tensor& self, double p, Generator* generator) {
  auto& double_type = self.type().toScalarType(kDouble);
  Tensor probs = double_type.tensor({});
  probs.fill_(p);
  return native::bernoulli_(self, probs, generator);
}
// CPU implementation of the reparameterized gradient for standard Gamma
// samples. Allocates a tensor with the type and sizes of `self`, then writes
// standard_gamma_grad_one(self_elem, output_elem) into it element by element
// (`self` supplies the `alpha` argument, `output` the `x` argument).
// Dispatches over floating-point scalar types of `self`.
Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
  Tensor grad = self.type().tensor(self.sizes());
  AT_DISPATCH_FLOATING_TYPES(self.type(), "_standard_gamma_grad", [&] {
    auto grad_elem = [](scalar_t& grad_val, const scalar_t& alpha_val, const scalar_t& sample_val) {
      grad_val = standard_gamma_grad_one(alpha_val, sample_val);
    };
    CPU_tensor_apply3<scalar_t, scalar_t, scalar_t>(grad, self, output, grad_elem);
  });
  return grad;
}
// CUDA placeholder: there is no device implementation in this file, so the
// call only reports an error through runtime_error.
// NOTE(review): the function has no return statement; presumably
// runtime_error never returns (throws) -- confirm against its declaration.
Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) {
  // Neither argument is used by the stub.
  (void)self;
  (void)output;
  runtime_error("_standard_gamma_grad is not implemented for CUDA types");
}
// CPU Poisson sampler. Produces a tensor with the type and sizes of `lambda`
// whose elements are sample_poisson draws for the corresponding rates.
// The rates are converted to double once up front; the output is dispatched
// over floating-point scalar types. `gen` is resolved through get_generator.
Tensor _s_poisson_cpu(const Tensor& lambda, Generator *gen) {
  Tensor samples = at::zeros(lambda.type(), lambda.sizes());
  auto rates = lambda.toType(ScalarType::Double);
  AT_DISPATCH_FLOATING_TYPES(samples.type(), "poisson", [&] {
    THGenerator* rng = get_generator(gen);
    auto draw = [rng](scalar_t& sample, const double& rate) {
      // The int64_t count from sample_poisson is converted to scalar_t here.
      sample = sample_poisson(rate, rng);
    };
    CPU_tensor_apply2<scalar_t, double>(samples, rates, draw);
  });
  return samples;
}
}} // namespace at::native