Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Gamma reparameterization gradient #18852

Merged
merged 8 commits into from
Aug 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/mxnet/gluon/probability/distributions/gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def broadcast_to(self, batch_shape):
return new_instance

def sample(self, size=None):
return self.F.np.random.gamma(self.shape, self.scale, size)
return self.F.np.random.gamma(self.shape, 1, size) * self.scale

def sample_n(self, size=None):
return self.F.np.random.gamma(self.shape, self.scale, sample_n_shape_converter(size))
return self.F.np.random.gamma(self.shape, 1, sample_n_shape_converter(size)) * self.scale

@property
def mean(self):
Expand Down
66 changes: 66 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,72 @@ struct smooth_l1_gradient : public mxnet_op::tunable {
}
}; // struct smooth_l1_derivative

/* Implicti reparameterization gradient for standard x ~ Gamma(\alpha, 1)
* according to dx/da = -cdf(x;alpha) / pdf(x;alpha)
*/
struct gamma_implicit_grad {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType x) {
if (x < 0.8f) {
DType numer = 1;
DType denom = a;
DType series1 = numer / denom;
DType series2 = numer / (denom * denom);
for (int i = 1; i <= 5; i++) {
numer *= -x / static_cast<DType>(i);
denom += 1;
series1 += numer / denom;
series2 += numer / (denom * denom);
}
DType pow_x_alpha = math::pow(x, a);
DType gamma_pdf = math::pow(x, a - 1) * math::exp(-x);
DType gamma_cdf = pow_x_alpha * series1;
DType gamma_cdf_alpha =
(math::log(x) - DType(special_functions::cephes::psi<float>(a))) *
gamma_cdf -
pow_x_alpha * series2;
DType result = -gamma_cdf_alpha / gamma_pdf;
return IsNan(result) ? static_cast<DType>( 0.f ) : static_cast<DType>(result);
}
if (a > 8.0f) {
if (0.9f * a <= x && x <= 1.1f * a) {
DType numer_1 = 1 + 24 * a * (1 + 12 * a);
DType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) -
65 * x * x / a + a * (107 + 3600 * x);
DType denom = 1244160 * (a * a) * (a * a);
return static_cast<DType>(numer_1 * numer_2 / denom);
}
DType denom = math::sqrt(8 * a);
DType term2 = denom / (a - x);
DType term3 =
math::pow(x - a - a * math::log(x / a), static_cast<DType>(-1.5));
DType term23 = (x < a) ? term2 - term3 : term2 + term3;
DType term1 = math::log(x / a) * term23 -
math::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
DType stirling = 1 + 1 / (12 * a) * (1 + 1 / (24 * a));
DType numer = x * term1;
return static_cast<DType>(-stirling * numer / denom);
}
DType u = math::log(x / a);
DType v = math::log(a);
DType coef_uv[3][8] = {
{0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115,
0.10406089, 0.0014179084},
{0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465,
0.020070113, -0.0035938915, -0.00058392623},
{0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642,
-0.0021309326, 0.00085092367, -1.5247877e-07},
};
DType coef_v[8];
for (int i = 0; i < 8; i++) {
coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
}
DType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
DType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
return static_cast<DType>(math::exp(p / q));
}
}; // gamma_implicit_grad

/*! \brief product reducer */
struct product {
/*! \brief do reduction into dst */
Expand Down
31 changes: 30 additions & 1 deletion src/operator/numpy/random/np_gamma_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,39 @@ NNVM_REGISTER_OP(_npi_gamma)
ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyGammaForward<cpu, double>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_gamma_sample"})
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_argument("input2", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyGammaParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_gamma_sample)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<NumpyGammaParam>)
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
int num_inputs = 4;
if (param.shape.has_value()) num_inputs -= 1;
if (param.scale.has_value()) num_inputs -= 1;
return num_inputs;
}
)
.set_num_outputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
int num_outputs = 2;
if (param.shape.has_value()) num_outputs -= 1;
if (param.scale.has_value()) num_outputs -= 1;
return num_outputs;
}
)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyGammaGrad<cpu>)
.add_arguments(NumpyGammaParam::__FIELDS__());


} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/numpy/random/np_gamma_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,8 @@ namespace op {
NNVM_REGISTER_OP(_npi_gamma)
.set_attr<FCompute>("FCompute<gpu>", NumpyGammaForward<gpu, double>);

NNVM_REGISTER_OP(_backward_gamma_sample)
.set_attr<FCompute>("FCompute<gpu>", NumpyGammaGrad<gpu>);

} // namespace op
} // namespace mxnet
82 changes: 82 additions & 0 deletions src/operator/numpy/random/np_gamma_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ struct CheckSuccessKernel {
}
};

template<typename DType>
struct StandarizeKernel {
MSHADOW_XINLINE static void Map(int i, DType* samples, float scale) {
samples[i] /= scale;
}
};

template <int ndim, typename IType, typename OType, typename FType>
struct gamma_kernel {
MSHADOW_XINLINE static void Map(index_t i, const Shape<ndim> &lstride,
Expand Down Expand Up @@ -394,6 +401,81 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}
}

template<typename xpu, int ndim, typename DType>
inline void GammaReparamBackwardImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& new_ishape,
const mxnet::TShape& new_oshape,
const float scale) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob igrad = outputs[0].reshape(new_ishape);
// inputs: [grad_from_samples, alpha_tensor, samples]
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob alpha = inputs[1].reshape(new_ishape);
TBlob samples = inputs[2].reshape(new_oshape);
size_t workspace_size =
ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
// Convert samples to standard gamma
Kernel<op_with_req<mshadow_op::div, kWriteTo>, xpu>::Launch(
s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), DType(scale));
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::gamma_implicit_grad>(
xidulu marked this conversation as resolved.
Show resolved Hide resolved
xidulu marked this conversation as resolved.
Show resolved Hide resolved
s, igrad, req[0], workspace, ograd, alpha, samples);
Kernel<op_with_req<mshadow_op::mul, kWriteTo>, xpu>::Launch(
s, igrad.Size(), igrad.dptr<DType>(), igrad.dptr<DType>(), DType(scale));
// Convert samples back, otherwise the output would be corrupted.
Kernel<op_with_req<mshadow_op::mul, kWriteTo>, xpu>::Launch(
s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), DType(scale));
}

// Allow gamma sampling to be differentiable,
// using implicit reparameterization gradient:
// -(d/d\alpha cdf(x;alpha)) / pdf(x;alpha)
template<typename xpu>
void NumpyGammaGrad(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
// skip kernel launch for zero-size tensors
if (inputs[0].shape_.Size() == 0U) {
return;
}
// [scalar, scalar] case
if (outputs.size() == 0U) {
return;
}
const NumpyGammaParam &param = nnvm::get<NumpyGammaParam>(attrs.parsed);
// [tensor tensor] case, not supported.
if (inputs.size() == 5U) {
LOG(FATAL) << "ValueError: two tensor case not supported";
}

// [tensor, scalar] case, only scalar scale is supported.
if (inputs.size() == 3U) {
if (param.shape.has_value()) {
LOG(FATAL) << "ValueError: tensor scale case not supported";
}
mxnet::TShape new_ishape, new_oshape;
int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_,
&new_ishape, &new_ishape, &new_oshape);
auto scale = param.scale.value();
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
GammaReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_ishape, new_oshape, scale);
});
});
}
}

} // namespace op
} // namespace mxnet

Expand Down
40 changes: 40 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4777,6 +4777,46 @@ def _test_gamma_exception(shape, scale):
assertRaises(ValueError, _test_gamma_exception, shape, scale)


@with_seed()
@use_np
@pytest.mark.parametrize("shape", [(1,), (2, 2), (4, 2, 2)])
@pytest.mark.parametrize("a", [2.0, 5.0, 10.0])
@pytest.mark.parametrize("b", [0.5, 1.0, 1.5])
def test_gamma_grad(shape, a, b):
class TestGammaGrad(HybridBlock):
def __init__(self, size, beta):
super(TestGammaGrad, self).__init__()
self._size = size
self._beta = beta

def hybrid_forward(self, F, a):
return F.np.random.gamma(a, self._beta, size=self._size)

for hybridize in [True, False]:
param = np.ones(shape) * a
param.attach_grad()
net = TestGammaGrad(shape, b)
if hybridize:
net.hybridize()
with mx.autograd.record():
samples = net(param)
samples.backward()
# Check shape
assert param.grad.shape == param.shape
# Check correctness
cdf = ss.gamma.cdf
log_pdf = ss.gamma.logpdf
eps = (0.01 * param / (1.0 + param ** 0.5)).asnumpy()
x = samples.asnumpy().astype('float64') / b
# d(cdf(x;alpha,beta))/d(alpha)
cdf_alpha = (cdf(x, param.asnumpy() + eps) -
cdf(x, param.asnumpy() - eps)) / (2 * eps)
# d(cdf(x;alpha,beta))/d(x)
log_cdf_x = log_pdf(x, param.asnumpy())
expected_grad = -b * cdf_alpha / _np.exp(log_cdf_x)
assert_almost_equal(expected_grad, param.grad.asnumpy(), rtol=1e-2, atol=1e-3)


@with_seed()
@use_np
@pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/18600')
Expand Down