Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[FEATURE] Add binomial sampling and fix multinomial sampling (#20734)
Browse files Browse the repository at this point in the history
* implement binomial sampling

* add correct multinomial implementation

* small fix in binomial symbol api doc

* change npx_categorical to npx_multinomial

* small sanity fix

* fix python unit tests

* rename previous multinomial implementation to categorical

* small fix
  • Loading branch information
IrishWhiskey authored Feb 5, 2022
1 parent 1cb4d1d commit e9becb9
Show file tree
Hide file tree
Showing 14 changed files with 868 additions and 68 deletions.
4 changes: 4 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,8 @@
'_random_exponential_like',
'_random_gamma',
'_random_gamma_like',
'_random_binomial',
'_random_binomial_like',
'_random_generalized_negative_binomial',
'_random_generalized_negative_binomial_like',
'_random_negative_binomial',
Expand All @@ -353,7 +355,9 @@
'_rnn_param_concat',
'_sample_exponential',
'_sample_gamma',
'_sample_binomial',
'_sample_generalized_negative_binomial',
'_sample_categorical',
'_sample_multinomial',
'_sample_negative_binomial',
'_sample_normal',
Expand Down
134 changes: 121 additions & 13 deletions python/mxnet/ndarray/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from .ndarray import NDArray


__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma',
'multinomial', 'negative_binomial', 'generalized_negative_binomial',
__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'binomial',
'categorical', 'multinomial', 'negative_binomial', 'generalized_negative_binomial',
'shuffle', 'randint']


Expand Down Expand Up @@ -383,6 +383,59 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, ctx=None, out=None, **kwarg
[alpha, beta], shape, dtype, ctx, out, kwargs)


def binomial(n=1, p=0.5, shape=_Null, dtype=_Null, ctx=None, out=None, **kwargs):
    """Draw random samples from a binomial distribution.

    Samples are distributed according to a binomial distribution parametrized
    by *n* (number of trials) and *p* (success probability).

    Parameters
    ----------
    n : float or NDArray, optional
        Number of experiments, > 0.
    p : float or NDArray, optional
        Success probability in each experiment, >= 0 and <= 1.
    shape : int or tuple of ints, optional
        The number of samples to draw. If shape is, e.g., `(s, t)` and `n` and
        `p` are scalars, output shape will be `(s, t)`. If `n` and `p`
        are NDArrays with shape, e.g., `(x, y)`, then output will have shape
        `(x, y, s, t)`, where `s*t` samples are drawn for each `(n, p)` pair.
    dtype : {'float16', 'float32', 'float64'}, optional
        Data type of output samples. Default is 'float32'
    ctx : Context, optional
        Device context of output. Default is current context. Overridden by
        `n.context` when `n` is an NDArray.
    out : NDArray, optional
        Store output to an existing NDArray.

    Returns
    -------
    NDArray
        If `shape` is, e.g., `(s, t)` and `n` and `p` are scalars, output
        shape will be `(s, t)`. If `n` and `p` are NDArrays with shape, e.g.,
        `(x, y)`, then output will have shape `(x, y, s, t)`, where `s*t`
        samples are drawn for each `(n, p)` pair.

    Examples
    --------
    >>> mx.nd.random.binomial(10, 0.1)
    [ 1.]
    <NDArray 1 @cpu(0)>
    >>> mx.nd.random.binomial(10, 0.6, shape=(2,))
    [ 4.  6.]
    <NDArray 2 @cpu(0)>
    >>> n = mx.nd.array([10,2,3])
    >>> p = mx.nd.array([0.2,0.3,0.4])
    >>> mx.nd.random.binomial(n, p, shape=2)
    [[ 1.  4.]
     [ 0.  2.]
     [ 1.  1.]]
    <NDArray 3x2 @cpu(0)>
    """
    # `_random_helper` is handed both operators (the `_random_*` and the
    # `_sample_*` variant) and chooses between them; the parameters are passed
    # as one list so they can be inspected together.
    return _random_helper(_internal._random_binomial, _internal._sample_binomial,
                          [n, p], shape, dtype, ctx, out, kwargs)


def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, ctx=None,
out=None, **kwargs):
"""Draw random samples from a negative binomial distribution.
Expand Down Expand Up @@ -496,9 +549,8 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, ctx=N
_internal._sample_generalized_negative_binomial,
[mu, alpha], shape, dtype, ctx, out, kwargs)


def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
"""Concurrent sampling from multiple multinomial distributions.
def categorical(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kwargs):
"""Concurrent sampling from multiple categorical distributions.
.. note:: The input distribution must be normalized, i.e. `data` must sum to
1 along its last dimension.
Expand All @@ -507,8 +559,8 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
----------
data : NDArray
An *n* dimensional array whose last dimension has length `k`, where
`k` is the number of possible outcomes of each multinomial distribution.
For example, data with shape `(m, n, k)` specifies `m*n` multinomial
`k` is the number of possible outcomes of each categorical distribution.
For example, data with shape `(m, n, k)` specifies `m*n` categorical
distributions each with `k` possible outcomes.
shape : int or tuple of ints, optional
The number of samples to draw from each distribution. If shape is empty
Expand All @@ -530,7 +582,7 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
For input `data` with `n` dimensions and shape `(d1, d2, ..., dn-1, k)`, and input
`shape` with shape `(s1, s2, ..., sx)`, returns an NDArray with shape
`(d1, d2, ... dn-1, s1, s2, ..., sx)`. The `s1, s2, ... sx` dimensions of the
returned NDArray consist of 0-indexed values sampled from each respective multinomial
returned NDArray consist of 0-indexed values sampled from each respective categorical
distribution provided in the `k` dimension of `data`.
For the case `n`=1, and `x`=1 (one shape dimension), returned NDArray has shape `(s1,)`.
Expand All @@ -542,24 +594,80 @@ def multinomial(data, shape=_Null, get_prob=False, out=None, dtype='int32', **kw
Examples
--------
>>> probs = mx.nd.array([0, 0.1, 0.2, 0.3, 0.4])
>>> mx.nd.random.multinomial(probs)
>>> mx.nd.random.categorical(probs)
[3]
<NDArray 1 @cpu(0)>
>>> probs = mx.nd.array([[0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0]])
>>> mx.nd.random.multinomial(probs)
>>> mx.nd.random.categorical(probs)
[3 1]
<NDArray 2 @cpu(0)>
>>> mx.nd.random.multinomial(probs, shape=2)
>>> mx.nd.random.categorical(probs, shape=2)
[[4 4]
[1 2]]
<NDArray 2x2 @cpu(0)>
>>> mx.nd.random.multinomial(probs, get_prob=True)
>>> mx.nd.random.categorical(probs, get_prob=True)
[3 2]
<NDArray 2 @cpu(0)>
[-1.20397282 -1.60943794]
<NDArray 2 @cpu(0)>
"""
return _internal._sample_multinomial(data, shape, get_prob, out=out, dtype=dtype, **kwargs)
return _internal._sample_categorical(data, shape, get_prob, out=out, dtype=dtype, **kwargs)


def multinomial(n=[1], p=[[1.0]], shape=_Null, dtype='float32', ctx=None, out=None, **kwargs):
    """Concurrent sampling from multiple multinomial distributions.

    .. note:: The input distribution must be normalized, i.e. `p` must sum to
              1 along its last dimension.

    Parameters
    ----------
    n : NDArray
        A *d* dimensional array containing the number of trials of each
        multinomial distribution.
    p : NDArray
        A *d+1* dimensional array containing the probabilities of each multinomial
        distribution. Its last dimension has length `k`, where `k` is the number
        of possible outcomes of each multinomial distribution.
        For example, `p` with shape `(x, y, k)` specifies `x*y` multinomial
        distributions each with `k` possible outcomes.
    shape : int or tuple of ints, optional
        The number of samples to draw from each distribution. If shape is empty
        one sample will be drawn from each distribution.
    dtype : {'float16', 'float32', 'float64'}, optional
        Data type of output samples. Default is 'float32'
    ctx : Context, optional
        Device context of output. Default is current context. Overridden by
        `n.context` when `n` is an NDArray.
    out : NDArray, optional
        Store output to an existing NDArray.

    Returns
    -------
    NDArray
        If `shape` is, e.g., `(s, t)` and `n` and `p` are a scalar and an array
        of length `k` respectively, output shape will be `(s, t, k)`. If `n` and
        `p` are NDArrays with shape, e.g., `(x, y)` and `(x, y, k)`, then output
        will have shape `(x, y, s, t, k)`, where `s*t` samples are drawn for
        each `(n, p)` pair.

    Examples
    --------
    >>> mx.nd.random.multinomial(mx.nd.array([10]), mx.nd.array([[0.1, 0.9]]))
    [[ 1.  9.]]
    <NDArray 1x2 @cpu(0)>
    >>> mx.nd.random.multinomial(mx.nd.array([10]), mx.nd.array([[0.6, 0.4]]), shape=(2,))
    [[[ 5.  5.]
      [ 6.  4.]]]
    <NDArray 1x2x2 @cpu(0)>
    >>> n = mx.nd.array([10, 2, 3])
    >>> p = mx.nd.array([[0.2, 0.8], [0.3, 0.7], [0.4, 0.6]])
    >>> mx.nd.random.multinomial(n, p)
    [[ 2.  8.]
     [ 1.  1.]
     [ 1.  2.]]
    <NDArray 3x2 @cpu(0)>
    """
    # NOTE(review): the list defaults for `n` and `p` are shared between calls.
    # This function only forwards them and never mutates them, so they are kept
    # as-is for interface compatibility. Whether the operator accepts plain
    # lists (rather than NDArrays) is not visible here -- confirm.
    return _internal._sample_multinomial(n, p, shape=shape, out=out, ctx=ctx, dtype=dtype, **kwargs)


def shuffle(data, **kwargs):
Expand Down
82 changes: 74 additions & 8 deletions python/mxnet/symbol/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from .symbol import Symbol


__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'multinomial',
'negative_binomial', 'generalized_negative_binomial', 'shuffle', 'randint']
__all__ = ['uniform', 'normal', 'randn', 'poisson', 'exponential', 'gamma', 'categorical', 'multinomial',
'binomial', 'negative_binomial', 'generalized_negative_binomial', 'shuffle', 'randint']


def _random_helper(random, sampler, params, shape, dtype, kwargs):
Expand Down Expand Up @@ -240,6 +240,38 @@ def gamma(alpha=1, beta=1, shape=_Null, dtype=_Null, **kwargs):
[alpha, beta], shape, dtype, kwargs)


def binomial(n=1, p=0.5, shape=_Null, dtype=_Null, **kwargs):
    """Draw random samples from a binomial distribution.

    Samples are distributed according to a binomial distribution parametrized
    by *n* (number of trials) and *p* (success probability).

    Parameters
    ----------
    n : float or Symbol, optional
        Number of experiments, > 0.
    p : float or Symbol, optional
        Success probability in each experiment, >= 0 and <= 1.
    shape : int or tuple of ints, optional
        The number of samples to draw. If shape is, e.g., `(s, t)` and `n` and
        `p` are scalars, output shape will be `(s, t)`. If `n` and `p`
        are Symbols with shape, e.g., `(x, y)`, then output will have shape
        `(x, y, s, t)`, where `s*t` samples are drawn for each `(n, p)` pair.
    dtype : {'float16', 'float32', 'float64'}, optional
        Data type of output samples. Default is 'float32'

    Returns
    -------
    Symbol
        If `shape` is, e.g., `(s, t)` and `n` and `p` are scalars, output
        shape will be `(s, t)`. If `n` and `p` are Symbols with shape, e.g.,
        `(x, y)`, then output will have shape `(x, y, s, t)`, where `s*t`
        samples are drawn for each `(n, p)` pair.
    """
    # `_random_helper` is handed both operators (the `_random_*` and the
    # `_sample_*` variant) and chooses between them; the parameters are passed
    # as one list so they can be inspected together.
    return _random_helper(_internal._random_binomial, _internal._sample_binomial,
                          [n, p], shape, dtype, kwargs)


def negative_binomial(k=1, p=1, shape=_Null, dtype=_Null, **kwargs):
"""Draw random samples from a negative binomial distribution.
Expand Down Expand Up @@ -311,8 +343,8 @@ def generalized_negative_binomial(mu=1, alpha=1, shape=_Null, dtype=_Null, **kwa
[mu, alpha], shape, dtype, kwargs)


def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
"""Concurrent sampling from multiple multinomial distributions.
def categorical(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
"""Concurrent sampling from multiple categorical distributions.
.. note:: The input distribution must be normalized, i.e. `data` must sum to
1 along its last dimension.
Expand All @@ -321,8 +353,8 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
----------
data : Symbol
An *n* dimensional array whose last dimension has length `k`, where
`k` is the number of possible outcomes of each multinomial distribution.
For example, data with shape `(m, n, k)` specifies `m*n` multinomial
`k` is the number of possible outcomes of each categorical distribution.
For example, data with shape `(m, n, k)` specifies `m*n` categorical
distributions each with `k` possible outcomes.
shape : int or tuple of ints, optional
The number of samples to draw from each distribution. If shape is empty
Expand All @@ -343,7 +375,7 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
`shape` with shape `(s1, s2, ..., sx)`, returns a Symbol that resolves to shape
`(d1, d2, ... dn-1, s1, s2, ..., sx)`. The `s1, s2, ... sx` dimensions of the
returned Symbol's resolved value will consist of 0-indexed values sampled from each
respective multinomial distribution provided in the `k` dimension of `data`.
respective categorical distribution provided in the `k` dimension of `data`.
For the case `n`=1, and `x`=1 (one shape dimension), returned Symbol will resolve to
shape `(s1,)`.
Expand All @@ -352,7 +384,41 @@ def multinomial(data, shape=_Null, get_prob=True, dtype='int32', **kwargs):
outputs: `[ndarray_output, log_likelihood_output]`, where `log_likelihood_output` will resolve
to the same shape as the sampled outputs in ndarray_output.
"""
return _internal._sample_multinomial(data, shape, get_prob, dtype=dtype, **kwargs)
return _internal._sample_categorical(data, shape, get_prob, dtype=dtype, **kwargs)


def multinomial(n=[1], p=[[1.0]], shape=_Null, dtype='float32', **kwargs):
    """Concurrent sampling from multiple multinomial distributions.

    .. note:: The input distribution must be normalized, i.e. `p` must sum to
              1 along its last dimension.

    Parameters
    ----------
    n : Symbol
        A *d* dimensional array containing the number of trials of each
        multinomial distribution.
    p : Symbol
        A *d+1* dimensional array containing the probabilities of each multinomial
        distribution. Its last dimension has length `k`, where `k` is the number
        of possible outcomes of each multinomial distribution.
        For example, `p` with shape `(x, y, k)` specifies `x*y` multinomial
        distributions each with `k` possible outcomes.
    shape : int or tuple of ints, optional
        The number of samples to draw from each distribution. If shape is empty
        one sample will be drawn from each distribution.
    dtype : {'float16', 'float32', 'float64'}, optional
        Data type of output samples. Default is 'float32'

    Returns
    -------
    Symbol
        If `shape` is, e.g., `(s, t)` and `n` and `p` are a scalar and an array
        of length `k` respectively, output shape will be `(s, t, k)`. If `n` and
        `p` are Symbols with shape, e.g., `(x, y)` and `(x, y, k)`, then output
        will have shape `(x, y, s, t, k)`, where `s*t` samples are drawn for
        each `(n, p)` pair.
    """
    # NOTE(review): the list defaults for `n` and `p` are shared between calls.
    # This function only forwards them and never mutates them, so they are kept
    # as-is for interface compatibility. Whether the symbolic operator accepts
    # plain lists (rather than Symbols) is not visible here -- confirm.
    return _internal._sample_multinomial(n, p, shape, dtype=dtype, **kwargs)


def shuffle(data, **kwargs):
Expand Down
38 changes: 38 additions & 0 deletions src/operator/random/multisample_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,37 @@ Examples::
)code");
}

// Returns the reST description text that is passed as the description argument
// when the binomial multi-sample operator is registered.
//
// In the `shape=(2)` example each output row holds the samples drawn for one
// input distribution, so every value in row 0 must not exceed n = 20 and every
// value in row 1 must not exceed n = 49.
inline std::string binomial_desc() {
  return std::string(R"code(Concurrent sampling from multiple
binomial distributions with parameters *n* (number of trials) and *p* (success probability).

The parameters of the distributions are provided as input arrays.
Let *[s]* be the shape of the input arrays, *d* be the dimension of *[s]*, *[t]*
be the shape specified as the parameter of the operator, and *m* be the dimension
of *[t]*. Then the output will be a *(d+m)*-dimensional array with shape *[s]x[t]*.

For any valid *d*-dimensional index *i* with respect to the input arrays, *output[i]*
will be an *m*-dimensional array that holds randomly drawn samples from the distribution
which is parameterized by the input values at index *i*. If the shape parameter of the
operator is not set, then one sample will be drawn per distribution and the output array
has the same shape as the input arrays.

Samples will always be returned as a floating point data type.

Examples::

   n = [ 20, 49 ]
   p = [ 0.4 , 0.77 ]

   // Draw a single sample for each distribution
   sample_binomial(n, p) = [ 5., 36.]

   // Draw a vector containing two samples for each distribution
   sample_binomial(n, p, shape=(2)) = [[ 5., 11.],
                                       [ 40., 35.]]
)code");
}

inline std::string negative_binomial_desc() {
return std::string(R"code(Concurrent sampling from multiple
negative binomial distributions with parameters *k* (failure limit) and *p* (failure probability).
Expand Down Expand Up @@ -312,6 +343,13 @@ MXNET_OPERATOR_REGISTER_SAMPLING1(poisson,
"Lambda (rate) parameters of the distributions.",
poisson_desc)
.add_alias("_npx_tensor_poisson");
// Registers the two-parameter binomial multi-sample operator for CPU, sampled
// by BinomialSampler<cpu>. The string arguments are, in order: the names of the
// two inputs ("n", "p"), their per-input descriptions, and the function that
// supplies the operator's long description.
// NOTE(review): the registered operator name is presumably derived from the
// first macro argument (i.e. `_sample_binomial`) -- confirm against the
// MXNET_OPERATOR_REGISTER_SAMPLING2 definition.
MXNET_OPERATOR_REGISTER_SAMPLING2(binomial,
BinomialSampler<cpu>,
"n",
"p",
"Number of experiments.",
"Success probabilities in each experiment.",
binomial_desc);
MXNET_OPERATOR_REGISTER_SAMPLING2(negative_binomial,
NegativeBinomialSampler<cpu>,
"k",
Expand Down
3 changes: 3 additions & 0 deletions src/operator/random/multisample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ NNVM_REGISTER_OP(_sample_exponential)
// GPU forward implementation for the Poisson multi-sample operator.
// NOTE(review): the trailing template argument presumably is the number of
// distribution-parameter inputs (1 here, 2 below) -- confirm against
// MultiSampleOpForward.
NNVM_REGISTER_OP(_sample_poisson)
.set_attr<FCompute>("FCompute<gpu>", MultiSampleOpForward<gpu, PoissonSampler<gpu>, 1>);

// GPU forward implementation for the binomial multi-sample operator, backed by
// BinomialSampler<gpu>.
NNVM_REGISTER_OP(_sample_binomial)
.set_attr<FCompute>("FCompute<gpu>", MultiSampleOpForward<gpu, BinomialSampler<gpu>, 2>);

// GPU forward implementation for the negative binomial multi-sample operator.
NNVM_REGISTER_OP(_sample_negative_binomial)
.set_attr<FCompute>("FCompute<gpu>",
MultiSampleOpForward<gpu, NegativeBinomialSampler<gpu>, 2>);
Expand Down
Loading

0 comments on commit e9becb9

Please sign in to comment.