Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Gamma reparameterization gradient (#18852)
Browse files Browse the repository at this point in the history
* gamma grad wip

* gamma grad wip

* test tbd

* fix grad

* change scale to the frontend

* fix bugs

* change distributions.gamma

* fix test and operator tune
  • Loading branch information
xidulu authored Aug 12, 2020
1 parent f2a8b97 commit 83d2af5
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/gluon/probability/distributions/gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def broadcast_to(self, batch_shape):
return new_instance

def sample(self, size=None):
return self.F.np.random.gamma(self.shape, self.scale, size)
return self.F.np.random.gamma(self.shape, 1, size) * self.scale

def sample_n(self, size=None):
return self.F.np.random.gamma(self.shape, self.scale, sample_n_shape_converter(size))
return self.F.np.random.gamma(self.shape, 1, sample_n_shape_converter(size)) * self.scale

@property
def mean(self):
Expand Down
66 changes: 66 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,72 @@ struct smooth_l1_gradient : public mxnet_op::tunable {
}
}; // struct smooth_l1_derivative

/* Implicti reparameterization gradient for standard x ~ Gamma(\alpha, 1)
* according to dx/da = -cdf(x;alpha) / pdf(x;alpha)
*/
struct gamma_implicit_grad {
template <typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType x) {
if (x < 0.8f) {
DType numer = 1;
DType denom = a;
DType series1 = numer / denom;
DType series2 = numer / (denom * denom);
for (int i = 1; i <= 5; i++) {
numer *= -x / static_cast<DType>(i);
denom += 1;
series1 += numer / denom;
series2 += numer / (denom * denom);
}
DType pow_x_alpha = math::pow(x, a);
DType gamma_pdf = math::pow(x, a - 1) * math::exp(-x);
DType gamma_cdf = pow_x_alpha * series1;
DType gamma_cdf_alpha =
(math::log(x) - DType(special_functions::cephes::psi<float>(a))) *
gamma_cdf -
pow_x_alpha * series2;
DType result = -gamma_cdf_alpha / gamma_pdf;
return IsNan(result) ? static_cast<DType>( 0.f ) : static_cast<DType>(result);
}
if (a > 8.0f) {
if (0.9f * a <= x && x <= 1.1f * a) {
DType numer_1 = 1 + 24 * a * (1 + 12 * a);
DType numer_2 = 1440 * (a * a) + 6 * x * (53 - 120 * x) -
65 * x * x / a + a * (107 + 3600 * x);
DType denom = 1244160 * (a * a) * (a * a);
return static_cast<DType>(numer_1 * numer_2 / denom);
}
DType denom = math::sqrt(8 * a);
DType term2 = denom / (a - x);
DType term3 =
math::pow(x - a - a * math::log(x / a), static_cast<DType>(-1.5));
DType term23 = (x < a) ? term2 - term3 : term2 + term3;
DType term1 = math::log(x / a) * term23 -
math::sqrt(2 / a) * (a + x) / ((a - x) * (a - x));
DType stirling = 1 + 1 / (12 * a) * (1 + 1 / (24 * a));
DType numer = x * term1;
return static_cast<DType>(-stirling * numer / denom);
}
DType u = math::log(x / a);
DType v = math::log(a);
DType coef_uv[3][8] = {
{0.16009398, -0.094634809, 0.025146376, -0.0030648343, 1, 0.32668115,
0.10406089, 0.0014179084},
{0.53487893, 0.1298071, 0.065735949, -0.0015649758, 0.16639465,
0.020070113, -0.0035938915, -0.00058392623},
{0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, 0.017050642,
-0.0021309326, 0.00085092367, -1.5247877e-07},
};
DType coef_v[8];
for (int i = 0; i < 8; i++) {
coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]);
}
DType p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3]));
DType q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7]));
return static_cast<DType>(math::exp(p / q));
}
}; // gamma_implicit_grad

/*! \brief product reducer */
struct product {
/*! \brief do reduction into dst */
Expand Down
31 changes: 30 additions & 1 deletion src/operator/numpy/random/np_gamma_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,39 @@ NNVM_REGISTER_OP(_npi_gamma)
ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyGammaForward<cpu, double>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_gamma_sample"})
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_argument("input2", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyGammaParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_gamma_sample)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<NumpyGammaParam>)
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
int num_inputs = 4;
if (param.shape.has_value()) num_inputs -= 1;
if (param.scale.has_value()) num_inputs -= 1;
return num_inputs;
}
)
.set_num_outputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyGammaParam& param = nnvm::get<NumpyGammaParam>(attrs.parsed);
int num_outputs = 2;
if (param.shape.has_value()) num_outputs -= 1;
if (param.scale.has_value()) num_outputs -= 1;
return num_outputs;
}
)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyGammaGrad<cpu>)
.add_arguments(NumpyGammaParam::__FIELDS__());


} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/numpy/random/np_gamma_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,8 @@ namespace op {
NNVM_REGISTER_OP(_npi_gamma)
.set_attr<FCompute>("FCompute<gpu>", NumpyGammaForward<gpu, double>);

NNVM_REGISTER_OP(_backward_gamma_sample)
.set_attr<FCompute>("FCompute<gpu>", NumpyGammaGrad<gpu>);

} // namespace op
} // namespace mxnet
82 changes: 82 additions & 0 deletions src/operator/numpy/random/np_gamma_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,13 @@ struct CheckSuccessKernel {
}
};

template<typename DType>
struct StandarizeKernel {
MSHADOW_XINLINE static void Map(int i, DType* samples, float scale) {
samples[i] /= scale;
}
};

template <int ndim, typename IType, typename OType, typename FType>
struct gamma_kernel {
MSHADOW_XINLINE static void Map(index_t i, const Shape<ndim> &lstride,
Expand Down Expand Up @@ -394,6 +401,81 @@ void NumpyGammaForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}
}

template<typename xpu, int ndim, typename DType>
inline void GammaReparamBackwardImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& new_ishape,
const mxnet::TShape& new_oshape,
const float scale) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
using namespace mxnet_op;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob igrad = outputs[0].reshape(new_ishape);
// inputs: [grad_from_samples, alpha_tensor, samples]
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob alpha = inputs[1].reshape(new_ishape);
TBlob samples = inputs[2].reshape(new_oshape);
size_t workspace_size =
ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
// Convert samples to standard gamma
Kernel<op_with_req<mshadow_op::div, kWriteTo>, xpu>::Launch(
s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), DType(scale));
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::gamma_implicit_grad>(
s, igrad, req[0], workspace, ograd, alpha, samples);
Kernel<op_with_req<mshadow_op::mul, kWriteTo>, xpu>::Launch(
s, igrad.Size(), igrad.dptr<DType>(), igrad.dptr<DType>(), DType(scale));
// Convert samples back, otherwise the output would be corrupted.
Kernel<op_with_req<mshadow_op::mul, kWriteTo>, xpu>::Launch(
s, samples.Size(), samples.dptr<DType>(), samples.dptr<DType>(), DType(scale));
}

// Allow gamma sampling to be differentiable,
// using implicit reparameterization gradient:
// -(d/d\alpha cdf(x;alpha)) / pdf(x;alpha)
template<typename xpu>
void NumpyGammaGrad(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
// skip kernel launch for zero-size tensors
if (inputs[0].shape_.Size() == 0U) {
return;
}
// [scalar, scalar] case
if (outputs.size() == 0U) {
return;
}
const NumpyGammaParam &param = nnvm::get<NumpyGammaParam>(attrs.parsed);
// [tensor tensor] case, not supported.
if (inputs.size() == 5U) {
LOG(FATAL) << "ValueError: two tensor case not supported";
}

// [tensor, scalar] case, only scalar scale is supported.
if (inputs.size() == 3U) {
if (param.shape.has_value()) {
LOG(FATAL) << "ValueError: tensor scale case not supported";
}
mxnet::TShape new_ishape, new_oshape;
int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_,
&new_ishape, &new_ishape, &new_oshape);
auto scale = param.scale.value();
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
GammaReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, req, outputs, new_ishape, new_oshape, scale);
});
});
}
}

} // namespace op
} // namespace mxnet

Expand Down
40 changes: 40 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4801,6 +4801,46 @@ def _test_gamma_exception(shape, scale):
assertRaises(ValueError, _test_gamma_exception, shape, scale)


@with_seed()
@use_np
@pytest.mark.parametrize("shape", [(1,), (2, 2), (4, 2, 2)])
@pytest.mark.parametrize("a", [2.0, 5.0, 10.0])
@pytest.mark.parametrize("b", [0.5, 1.0, 1.5])
def test_gamma_grad(shape, a, b):
class TestGammaGrad(HybridBlock):
def __init__(self, size, beta):
super(TestGammaGrad, self).__init__()
self._size = size
self._beta = beta

def hybrid_forward(self, F, a):
return F.np.random.gamma(a, self._beta, size=self._size)

for hybridize in [True, False]:
param = np.ones(shape) * a
param.attach_grad()
net = TestGammaGrad(shape, b)
if hybridize:
net.hybridize()
with mx.autograd.record():
samples = net(param)
samples.backward()
# Check shape
assert param.grad.shape == param.shape
# Check correctness
cdf = ss.gamma.cdf
log_pdf = ss.gamma.logpdf
eps = (0.01 * param / (1.0 + param ** 0.5)).asnumpy()
x = samples.asnumpy().astype('float64') / b
# d(cdf(x;alpha,beta))/d(alpha)
cdf_alpha = (cdf(x, param.asnumpy() + eps) -
cdf(x, param.asnumpy() - eps)) / (2 * eps)
# d(cdf(x;alpha,beta))/d(x)
log_cdf_x = log_pdf(x, param.asnumpy())
expected_grad = -b * cdf_alpha / _np.exp(log_cdf_x)
assert_almost_equal(expected_grad, param.grad.asnumpy(), rtol=1e-2, atol=1e-3)


@with_seed()
@use_np
@pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/18600')
Expand Down

0 comments on commit 83d2af5

Please sign in to comment.