Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Efficient MXNet sampling in the multinomial distribution (#15311)
Browse files Browse the repository at this point in the history
* Effective multinomial

* Meaningful uniform data pointer as input

* Remove beginning Zeros from CDFs

* Double precision for accumulated var
  • Loading branch information
zixuanweeei authored and wkcn committed Jun 22, 2019
1 parent b4ce4e7 commit e6fad30
Showing 1 changed file with 24 additions and 18 deletions.
42 changes: 24 additions & 18 deletions src/operator/random/sample_multinomial_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,25 +122,29 @@ inline bool SampleMultinomialOpType(const nnvm::NodeAttrs& attrs,
struct SampleMultinomialKernel {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, index_t K, index_t M,
DType* dist, float* uniform, IType* out,
DType* prob) {
DType* dist, float* uniform, float* cum_table,
IType* out, DType* prob) {
double acc = 0.0;
// CDF table
for (index_t c = 0; c < K; ++c) {
acc += dist[i*K + c];
cum_table[i*K + c] = static_cast<float>(acc);
}
for (index_t j = 0; j < M; ++j) {
index_t left = 0, right = K;
index_t middle = left + (right - left) / 2;
DType loc = static_cast<DType>(uniform[i*M + j]);
DType acc = 0;
bool found = false;
for (index_t k = 0; k < K; ++k) {
acc += dist[i*K + k];
if (acc > loc) {
found = true;
out[i*M + j] = static_cast<IType>(k);
if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + k]);
break;
while (right - left > 0) {
middle = left + (right - left) / 2;
DType cum_prob = cum_table[i*K + middle];
if (cum_prob < loc) {
left = middle + 1;
} else {
right = middle;
}
}
if (!found) {
out[i*M + j] = static_cast<IType>(K-1);
if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + K - 1]);
}
out[i*M + j] = static_cast<IType>(left);
if (prob != nullptr) prob[i*M + j] = logf(dist[i*K + left]);
}
}
};
Expand All @@ -163,12 +167,14 @@ void SampleMultinomialForward(const nnvm::NodeAttrs& attrs,
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
Tensor<xpu, 1, float> uniform =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M), s);
Tensor<xpu, 1, float> workspace =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(N*M + N*K), s);
Tensor<xpu, 1, float> uniform(workspace.dptr_, Shape1(N*M));
prnd->SampleUniform(&uniform, 0, 1);
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, IType, {
Kernel<SampleMultinomialKernel, xpu>::Launch(
s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<IType>(),
s, N, K, M, inputs[0].dptr<DType>(), uniform.dptr_, workspace.dptr_ + N*M,
outputs[0].dptr<IType>(),
param.get_prob ? outputs[1].dptr<DType>() : nullptr);
});
});
Expand Down

0 comments on commit e6fad30

Please sign in to comment.