diff --git a/python/mxnet/ndarray/numpy_extension/random.py b/python/mxnet/ndarray/numpy_extension/random.py
index b0472a4ab122..5520a1f1b4d0 100644
--- a/python/mxnet/ndarray/numpy_extension/random.py
+++ b/python/mxnet/ndarray/numpy_extension/random.py
@@ -21,7 +21,7 @@
 from ..numpy import _internal as _npi
 
 
-__all__ = ['bernoulli']
+__all__ = ['bernoulli', 'normal_n', 'uniform_n']
 
 
 def bernoulli(prob, logit, size, dtype, ctx, out):
@@ -102,3 +102,166 @@ def bernoulli(prob, logit, size, dtype, ctx, out):
     else:
         return _npi.bernoulli(prob=None, logit=logit, is_logit=True,
                               size=size, ctx=ctx, dtype=dtype, out=out)
+
+
+def uniform_n(low=0.0, high=1.0, batch_shape=None, dtype=None, ctx=None):
+    r"""Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval
+    ``[low, high)`` (includes low, but excludes high). In other words,
+    any value within the given interval is equally likely to be drawn
+    by `uniform`.
+
+    Parameters
+    ----------
+    low : float, ndarray, optional
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low. The default value is 0.
+    high : float, ndarray, optional
+        Upper boundary of the output interval. All values generated will be
+        less than high. The default value is 1.0.
+    batch_shape : int or tuple of ints, optional
+        Batch shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k * broadcast(low, high).size`` samples are drawn.
+        If batch_shape is ``None`` (default), a scalar tensor containing
+        a single value is returned if ``low`` and ``high`` are both
+        scalars. Otherwise, ``np.broadcast(low, high).size`` samples
+        are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+
+    Returns
+    -------
+    out : ndarray
+        Drawn samples from the parameterized uniform distribution.
+
+    See Also
+    --------
+    randint : Discrete uniform distribution, yielding integers.
+    rand : Convenience function that accepts dimensions as input, e.g.,
+        ``rand(2, 2)`` would generate a 2-by-2 array of floats,
+        uniformly distributed over ``[0, 1)``.
+
+    Notes
+    -----
+    The probability density function of the uniform distribution is
+
+    .. math:: p(x) = \frac{1}{b - a}
+
+    anywhere within the interval ``[a, b)``, and zero elsewhere.
+
+    When ``high`` == ``low``, values of ``low`` will be returned.
+    If ``high`` < ``low``, the results are officially undefined and may
+    eventually raise an error; do not rely on this function to behave
+    when passed arguments satisfying that inequality condition.
+    """
+    from ...numpy import ndarray as np_ndarray
+    input_type = (isinstance(low, np_ndarray), isinstance(high, np_ndarray))
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if batch_shape == ():
+        batch_shape = None
+    if input_type == (True, True):
+        return _npi.uniform_n(low, high, low=None, high=None, size=batch_shape,
+                              ctx=ctx, dtype=dtype)
+    elif input_type == (False, True):
+        return _npi.uniform_n(high, low=low, high=None, size=batch_shape,
+                              ctx=ctx, dtype=dtype)
+    elif input_type == (True, False):
+        return _npi.uniform_n(low, low=None, high=high, size=batch_shape,
+                              ctx=ctx, dtype=dtype)
+    else:
+        return _npi.uniform_n(low=low, high=high, size=batch_shape,
+                              ctx=ctx, dtype=dtype)
+
+
+def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
+    r"""Draw random samples from a normal (Gaussian) distribution.
+
+    Samples are distributed according to a normal distribution parametrized
+    by *loc* (mean) and *scale* (standard deviation).
+
+    Parameters
+    ----------
+    loc : float, optional
+        Mean (centre) of the distribution.
+    scale : float, optional
+        Standard deviation (spread or "width") of the distribution.
+    batch_shape : int or tuple of ints, optional
+        Batch shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k * broadcast(loc, scale).size`` samples are drawn.
+        If batch_shape is ``None`` (default), a scalar tensor containing
+        a single value is returned if ``loc`` and ``scale`` are both
+        scalars. Otherwise, ``np.broadcast(loc, scale).size`` samples
+        are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+
+    Returns
+    -------
+    out : ndarray
+        Drawn samples from the parameterized normal distribution.
+
+    Notes
+    -----
+    The probability density for the Gaussian distribution is
+
+    .. math:: p(x) = \frac{1}{\sqrt{2 \pi \sigma^2}}
+              e^{-\frac{(x - \mu)^2}{2 \sigma^2}},
+
+    where :math:`\mu` is the mean and :math:`\sigma` the standard
+    deviation. The square of the standard deviation, :math:`\sigma^2`,
+    is called the variance.
+
+    The function has its peak at the mean, and its "spread" increases with
+    the standard deviation (the function reaches 0.607 times its maximum at
+    :math:`x + \sigma` and :math:`x - \sigma` [2]_). This implies that
+    `normal_n` is more likely to return samples lying close to the mean,
+    rather than those far away.
+
+    References
+    ----------
+    .. [1] Wikipedia, "Normal distribution",
+           https://en.wikipedia.org/wiki/Normal_distribution
+    .. [2] P. R. Peebles Jr., "Central Limit Theorem" in "Probability,
+           Random Variables and Random Signal Principles", 4th ed., 2001,
+           pp. 51, 51, 125.
+
+    Examples
+    --------
+    >>> mu, sigma = 0, 0.1  # mean and standard deviation
+    >>> s = npx.random.normal_n(mu, sigma, batch_shape=(1000,))
+
+    Verify the mean:
+
+    >>> np.abs(mu - np.mean(s)) < 0.01
+    array(True)
+    """
+    from ...numpy import ndarray as np_ndarray
+    input_type = (isinstance(loc, np_ndarray), isinstance(scale, np_ndarray))
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if batch_shape == ():
+        batch_shape = None
+    if input_type == (True, True):
+        return _npi.normal_n(loc, scale, loc=None, scale=None, size=batch_shape,
+                             ctx=ctx, dtype=dtype)
+    elif input_type == (False, True):
+        return _npi.normal_n(scale, loc=loc, scale=None, size=batch_shape,
+                             ctx=ctx, dtype=dtype)
+    elif input_type == (True, False):
+        return _npi.normal_n(loc, loc=None, scale=scale, size=batch_shape,
+                             ctx=ctx, dtype=dtype)
+    else:
+        return _npi.normal_n(loc=loc, scale=scale, size=batch_shape,
+                             ctx=ctx, dtype=dtype)
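Taken together, these front ends implement sample-n semantics: `batch_shape` is prepended to the broadcast of the parameter shapes, rather than having to contain it the way `size` must in the classic API. A minimal usage sketch (not part of the patch; shapes are illustrative):

    from mxnet import np, npx
    npx.set_np()

    loc = np.zeros((3, 4))    # parameter shapes broadcast to (3, 4)
    scale = np.ones((4,))
    s = npx.random.normal_n(loc, scale, batch_shape=(2, 5))
    # batch_shape is prepended to broadcast(loc, scale).shape
    assert s.shape == (2, 5, 3, 4)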
diff --git a/python/mxnet/numpy_extension/random.py b/python/mxnet/numpy_extension/random.py
index 316760b6e2d6..3735abe14d51 100644
--- a/python/mxnet/numpy_extension/random.py
+++ b/python/mxnet/numpy_extension/random.py
@@ -22,7 +22,7 @@
 from ..ndarray import numpy_extension as _mx_nd_npx
 
 
-__all__ = ['seed', 'bernoulli']
+__all__ = ['seed', 'bernoulli', 'normal_n', 'uniform_n']
 
 
 def seed(seed, ctx='all'):  # pylint: disable=redefined-outer-name
@@ -126,3 +126,128 @@ def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None):
         [1., 0., 1., 0.]])
     """
     return _mx_nd_npx.random.bernoulli(prob, logit, size, dtype, ctx, out)
+
+
+def uniform_n(low=0.0, high=1.0, batch_shape=None, dtype=None, ctx=None):
+    r"""Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval
+    ``[low, high)`` (includes low, but excludes high). In other words,
+    any value within the given interval is equally likely to be drawn
+    by `uniform`.
+
+    Parameters
+    ----------
+    low : float, ndarray, optional
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low. The default value is 0.
+    high : float, ndarray, optional
+        Upper boundary of the output interval. All values generated will be
+        less than high. The default value is 1.0.
+    batch_shape : int or tuple of ints, optional
+        Batch shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k * broadcast(low, high).size`` samples are drawn.
+        If batch_shape is ``None`` (default), a scalar tensor containing
+        a single value is returned if ``low`` and ``high`` are both
+        scalars. Otherwise, ``np.broadcast(low, high).size`` samples
+        are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+
+    Returns
+    -------
+    out : ndarray
+        Drawn samples from the parameterized uniform distribution.
+
+    See Also
+    --------
+    randint : Discrete uniform distribution, yielding integers.
+    rand : Convenience function that accepts dimensions as input, e.g.,
+        ``rand(2, 2)`` would generate a 2-by-2 array of floats,
+        uniformly distributed over ``[0, 1)``.
+
+    Notes
+    -----
+    The probability density function of the uniform distribution is
+
+    .. math:: p(x) = \frac{1}{b - a}
+
+    anywhere within the interval ``[a, b)``, and zero elsewhere.
+
+    When ``high`` == ``low``, values of ``low`` will be returned.
+    If ``high`` < ``low``, the results are officially undefined and may
+    eventually raise an error; do not rely on this function to behave
+    when passed arguments satisfying that inequality condition.
+    """
+    return _mx_nd_npx.random.uniform_n(low, high, batch_shape=batch_shape, ctx=ctx, dtype=dtype)
+
+
+def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
+    r"""Draw random samples from a normal (Gaussian) distribution.
+
+    Samples are distributed according to a normal distribution parametrized
+    by *loc* (mean) and *scale* (standard deviation).
+
+    Parameters
+    ----------
+    loc : float, optional
+        Mean (centre) of the distribution.
+    scale : float, optional
+        Standard deviation (spread or "width") of the distribution.
+    batch_shape : int or tuple of ints, optional
+        Batch shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k * broadcast(loc, scale).size`` samples are drawn.
+        If batch_shape is ``None`` (default), a scalar tensor containing
+        a single value is returned if ``loc`` and ``scale`` are both
+        scalars. Otherwise, ``np.broadcast(loc, scale).size`` samples
+        are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+
+    Returns
+    -------
+    out : ndarray
+        Drawn samples from the parameterized normal distribution.
+
+    Notes
+    -----
+    The probability density for the Gaussian distribution is
+
+    .. math:: p(x) = \frac{1}{\sqrt{2 \pi \sigma^2}}
+              e^{-\frac{(x - \mu)^2}{2 \sigma^2}},
+
+    where :math:`\mu` is the mean and :math:`\sigma` the standard
+    deviation. The square of the standard deviation, :math:`\sigma^2`,
+    is called the variance.
+
+    The function has its peak at the mean, and its "spread" increases with
+    the standard deviation (the function reaches 0.607 times its maximum at
+    :math:`x + \sigma` and :math:`x - \sigma` [2]_). This implies that
+    `normal_n` is more likely to return samples lying close to the mean,
+    rather than those far away.
+
+    References
+    ----------
+    .. [1] Wikipedia, "Normal distribution",
+           https://en.wikipedia.org/wiki/Normal_distribution
+    .. [2] P. R. Peebles Jr., "Central Limit Theorem" in "Probability,
+           Random Variables and Random Signal Principles", 4th ed., 2001,
+           pp. 51, 51, 125.
+
+    Examples
+    --------
+    >>> mu, sigma = 0, 0.1  # mean and standard deviation
+    >>> s = npx.random.normal_n(mu, sigma, batch_shape=(1000,))
+
+    Verify the mean:
+
+    >>> np.abs(mu - np.mean(s)) < 0.01
+    array(True)
+    """
+    return _mx_nd_npx.random.normal_n(loc, scale, batch_shape=batch_shape, dtype=dtype, ctx=ctx)
diff --git a/python/mxnet/symbol/numpy_extension/random.py b/python/mxnet/symbol/numpy_extension/random.py
index a557a75d56f7..8d2dc4e03fee 100644
--- a/python/mxnet/symbol/numpy_extension/random.py
+++ b/python/mxnet/symbol/numpy_extension/random.py
@@ -21,7 +21,7 @@
 from ...context import current_context
 from ..numpy import _internal as _npi
 
-__all__ = ['bernoulli']
+__all__ = ['bernoulli', 'normal_n', 'uniform_n']
 
 
 def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None):
@@ -102,3 +102,166 @@ def bernoulli(prob=None, logit=None, size=None, dtype=None, ctx=None, out=None):
     else:
         return _npi.bernoulli(prob=None, logit=logit, is_logit=True,
                               size=size, ctx=ctx, dtype=dtype, out=out)
+
+
+def uniform_n(low=0.0, high=1.0, batch_shape=None, dtype=None, ctx=None):
+    r"""Draw samples from a uniform distribution.
+
+    Samples are uniformly distributed over the half-open interval
+    ``[low, high)`` (includes low, but excludes high). In other words,
+    any value within the given interval is equally likely to be drawn
+    by `uniform`.
+
+    Parameters
+    ----------
+    low : float, ndarray, optional
+        Lower boundary of the output interval. All values generated will be
+        greater than or equal to low. The default value is 0.
+    high : float, ndarray, optional
+        Upper boundary of the output interval. All values generated will be
+        less than high. The default value is 1.0.
+    batch_shape : int or tuple of ints, optional
+        Batch shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k * broadcast(low, high).size`` samples are drawn.
+        If batch_shape is ``None`` (default), a scalar tensor containing
+        a single value is returned if ``low`` and ``high`` are both
+        scalars. Otherwise, ``np.broadcast(low, high).size`` samples
+        are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+
+    Returns
+    -------
+    out : _Symbol
+        Drawn samples from the parameterized uniform distribution.
+
+    See Also
+    --------
+    randint : Discrete uniform distribution, yielding integers.
+    rand : Convenience function that accepts dimensions as input, e.g.,
+        ``rand(2, 2)`` would generate a 2-by-2 array of floats,
+        uniformly distributed over ``[0, 1)``.
+
+    Notes
+    -----
+    The probability density function of the uniform distribution is
+
+    .. math:: p(x) = \frac{1}{b - a}
+
+    anywhere within the interval ``[a, b)``, and zero elsewhere.
+
+    When ``high`` == ``low``, values of ``low`` will be returned.
+    If ``high`` < ``low``, the results are officially undefined and may
+    eventually raise an error; do not rely on this function to behave
+    when passed arguments satisfying that inequality condition.
+    """
+    from ..numpy import _Symbol as np_symbol
+    input_type = (isinstance(low, np_symbol), isinstance(high, np_symbol))
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if batch_shape == ():
+        batch_shape = None
+    if input_type == (True, True):
+        return _npi.uniform_n(low, high, low=None, high=None, size=batch_shape,
+                              ctx=ctx, dtype=dtype)
+    elif input_type == (False, True):
+        return _npi.uniform_n(high, low=low, high=None, size=batch_shape,
+                              ctx=ctx, dtype=dtype)
+    elif input_type == (True, False):
+        return _npi.uniform_n(low, low=None, high=high, size=batch_shape,
+                              ctx=ctx, dtype=dtype)
+    else:
+        return _npi.uniform_n(low=low, high=high, size=batch_shape,
+                              ctx=ctx, dtype=dtype)
+
+
+def normal_n(loc=0.0, scale=1.0, batch_shape=None, dtype=None, ctx=None):
+    r"""Draw random samples from a normal (Gaussian) distribution.
+
+    Samples are distributed according to a normal distribution parametrized
+    by *loc* (mean) and *scale* (standard deviation).
+
+    Parameters
+    ----------
+    loc : float, optional
+        Mean (centre) of the distribution.
+    scale : float, optional
+        Standard deviation (spread or "width") of the distribution.
+    batch_shape : int or tuple of ints, optional
+        Batch shape. If the given shape is, e.g., ``(m, n, k)``, then
+        ``m * n * k * broadcast(loc, scale).size`` samples are drawn.
+        If batch_shape is ``None`` (default), a scalar tensor containing
+        a single value is returned if ``loc`` and ``scale`` are both
+        scalars. Otherwise, ``np.broadcast(loc, scale).size`` samples
+        are drawn.
+    dtype : {'float16', 'float32', 'float64'}, optional
+        Data type of output samples. Default is 'float32'.
+    ctx : Context, optional
+        Device context of output. Default is current context.
+
+    Returns
+    -------
+    out : _Symbol
+        Drawn samples from the parameterized normal distribution.
+
+    Notes
+    -----
+    The probability density for the Gaussian distribution is
+
+    .. math:: p(x) = \frac{1}{\sqrt{2 \pi \sigma^2}}
+              e^{-\frac{(x - \mu)^2}{2 \sigma^2}},
+
+    where :math:`\mu` is the mean and :math:`\sigma` the standard
+    deviation. The square of the standard deviation, :math:`\sigma^2`,
+    is called the variance.
+
+    The function has its peak at the mean, and its "spread" increases with
+    the standard deviation (the function reaches 0.607 times its maximum at
+    :math:`x + \sigma` and :math:`x - \sigma` [2]_). This implies that
+    `normal_n` is more likely to return samples lying close to the mean,
+    rather than those far away.
+
+    References
+    ----------
+    .. [1] Wikipedia, "Normal distribution",
+           https://en.wikipedia.org/wiki/Normal_distribution
+    .. [2] P. R. Peebles Jr., "Central Limit Theorem" in "Probability,
+           Random Variables and Random Signal Principles", 4th ed., 2001,
+           pp. 51, 51, 125.
+
+    Examples
+    --------
+    >>> mu, sigma = 0, 0.1  # mean and standard deviation
+    >>> s = npx.random.normal_n(mu, sigma, batch_shape=(1000,))
+
+    Verify the mean:
+
+    >>> np.abs(mu - np.mean(s)) < 0.01
+    array(True)
+    """
+    from ..numpy import _Symbol as np_symbol
+    input_type = (isinstance(loc, np_symbol), isinstance(scale, np_symbol))
+    if dtype is None:
+        dtype = 'float32'
+    if ctx is None:
+        ctx = current_context()
+    if batch_shape == ():
+        batch_shape = None
+    if input_type == (True, True):
+        return _npi.normal_n(loc, scale, loc=None, scale=None, size=batch_shape,
+                             ctx=ctx, dtype=dtype)
+    elif input_type == (False, True):
+        return _npi.normal_n(scale, loc=loc, scale=None, size=batch_shape,
+                             ctx=ctx, dtype=dtype)
+    elif input_type == (True, False):
+        return _npi.normal_n(loc, loc=None, scale=scale, size=batch_shape,
+                             ctx=ctx, dtype=dtype)
+    else:
+        return _npi.normal_n(loc=loc, scale=scale, size=batch_shape,
+                             ctx=ctx, dtype=dtype)
diff --git a/src/operator/numpy/random/dist_common.h b/src/operator/numpy/random/dist_common.h
index 6394375883aa..e8358294eaf0 100644
--- a/src/operator/numpy/random/dist_common.h
+++ b/src/operator/numpy/random/dist_common.h
@@ -205,6 +205,50 @@ inline bool UnaryDistOpShape(const nnvm::NodeAttrs &attrs,
   return shape_is_known(out_attrs->at(0));
 }
 
+
+// Shape inference for sample_n ops:
+// output_shape = (param.size,) + broadcast(param1.shape, param2.shape).
+template <typename DistParam>
+inline bool TwoparamsDistOpConcatShape(const nnvm::NodeAttrs &attrs,
+                                       std::vector<TShape> *in_attrs,
+                                       std::vector<TShape> *out_attrs) {
+  const DistParam &param = nnvm::get<DistParam>(attrs.parsed);
+  // broadcast(param1.shape, param2.shape).
+  mxnet::TShape param_broadcast_shape;
+  if (in_attrs->size() == 2U) {
+    // Both params from ndarray.
+    mxnet::TShape &param1 = (*in_attrs)[0];
+    mxnet::TShape &param2 = (*in_attrs)[1];
+    mxnet::TShape out(std::max(param1.ndim(), param2.ndim()), -1);
+    InferBroadcastShape(param1, param2, &out);
+    param_broadcast_shape = out;
+  } else if (in_attrs->size() == 1U) {
+    // One param from ndarray.
+    param_broadcast_shape = in_attrs->at(0);
+  } else if (in_attrs->size() == 0) {
+    // Two scalar case.
+    param_broadcast_shape = TShape(0, -1);
+  }
+  if (param.size.has_value()) {
+    // Size declared: prepend it to the broadcast parameter shape.
+    std::vector<dim_t> oshape_vec;
+    const mxnet::Tuple<int> &size = param.size.value();
+    for (int i = 0; i < size.ndim(); ++i) {
+      oshape_vec.emplace_back(size[i]);
+    }
+    for (int i = 0; i < param_broadcast_shape.ndim(); ++i) {
+      oshape_vec.emplace_back(param_broadcast_shape[i]);
+    }
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec));
+  } else {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_broadcast_shape);
+  }
+  if (out_attrs->size() == 2U) {
+    // The second (hidden) output shares the sample shape.
+    SHAPE_ASSIGN_CHECK(*out_attrs, 1, out_attrs->at(0));
+  }
+  return true;
+}
+
 }  // namespace op
 }  // namespace mxnet
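The scalar and missing-size branches above determine the front-end corner cases. A sketch of the expected behavior (not part of the patch; shapes are illustrative):

    from mxnet import np, npx
    npx.set_np()

    # Both parameters scalar: the output shape is batch_shape alone.
    assert npx.random.uniform_n(0.0, 1.0, batch_shape=(6,)).shape == (6,)
    # batch_shape omitted: the output shape is broadcast(low, high).shape.
    assert npx.random.uniform_n(np.zeros((2, 3)), 1.0).shape == (2, 3)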
diff --git a/src/operator/numpy/random/np_normal_op.cc b/src/operator/numpy/random/np_normal_op.cc
index 992e93c8af16..e2aa1397af67 100644
--- a/src/operator/numpy/random/np_normal_op.cc
+++ b/src/operator/numpy/random/np_normal_op.cc
@@ -69,6 +69,46 @@ NNVM_REGISTER_OP(_npi_normal)
 .add_argument("input2", "NDArray-or-Symbol", "Source input")
 .add_arguments(NumpyNormalParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_npi_normal_n)
+.describe("Draw samples from a normal distribution; the output shape is "
+          "batch_shape + broadcast(loc, scale).shape.")
+.set_num_inputs(
+  [](const nnvm::NodeAttrs& attrs) {
+    const NumpyNormalParam& param = nnvm::get<NumpyNormalParam>(attrs.parsed);
+    int num_inputs = 2;
+    if (param.loc.has_value()) num_inputs -= 1;
+    if (param.scale.has_value()) num_inputs -= 1;
+    return num_inputs;
+  }
+)
+.set_num_outputs(2)
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+  [](const NodeAttrs& attrs) {
+    return 1;
+})
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const NumpyNormalParam& param = nnvm::get<NumpyNormalParam>(attrs.parsed);
+    int num_inputs = 2;
+    if (param.loc.has_value()) num_inputs -= 1;
+    if (param.scale.has_value()) num_inputs -= 1;
+    if (num_inputs == 0) return std::vector<std::string>();
+    if (num_inputs == 1) return std::vector<std::string>{"input1"};
+    return std::vector<std::string>{"input1", "input2"};
+  })
+.set_attr_parser(ParamParser<NumpyNormalParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpConcatShape<NumpyNormalParam>)
+.set_attr<nnvm::FInferType>("FInferType", NumpyNormalOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{
+        ResourceRequest::kRandom, ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", NumpyNormalForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_normal"})
+.add_argument("input1", "NDArray-or-Symbol", "Source input")
+.add_argument("input2", "NDArray-or-Symbol", "Source input")
+.add_arguments(NumpyNormalParam::__FIELDS__());
+
 NNVM_REGISTER_OP(_backward_broadcast_normal)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr_parser(ParamParser<NumpyNormalParam>)
diff --git a/src/operator/numpy/random/np_normal_op.cu b/src/operator/numpy/random/np_normal_op.cu
index 0eab089abbc9..d45bc2321bd7 100644
--- a/src/operator/numpy/random/np_normal_op.cu
+++ b/src/operator/numpy/random/np_normal_op.cu
@@ -29,10 +29,13 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_normal)
-    .set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);
 
 NNVM_REGISTER_OP(_backward_broadcast_normal)
 .set_attr<FCompute>("FCompute<gpu>", NormalReparamBackward<gpu>);
 
+NNVM_REGISTER_OP(_npi_normal_n)
+.set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
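Because `_npi_normal_n` registers `_backward_broadcast_normal` as its gradient (with a hidden second output carrying the sampling noise), samples should be differentiable with respect to `loc` and `scale` via the reparameterization trick, mirroring the existing `_npi_normal`. A sketch of what that enables, assuming the backward kernel behaves like the existing one:

    from mxnet import np, npx, autograd
    npx.set_np()

    loc = np.zeros((2,))
    loc.attach_grad()
    with autograd.record():
        s = npx.random.normal_n(loc, 1.0, batch_shape=(4,))  # shape (4, 2)
        loss = s.sum()
    loss.backward()
    # With s = loc + scale * eps, d(loss)/d(loc) accumulates one unit per draw.
    assert loc.grad.shape == (2,)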
diff --git a/src/operator/numpy/random/np_uniform_op.cc b/src/operator/numpy/random/np_uniform_op.cc
index 7307b7744d5e..1f4be230996c 100644
--- a/src/operator/numpy/random/np_uniform_op.cc
+++ b/src/operator/numpy/random/np_uniform_op.cc
@@ -65,5 +65,41 @@ NNVM_REGISTER_OP(_npi_uniform)
 .add_argument("input2", "NDArray-or-Symbol", "Source input")
 .add_arguments(NumpyUniformParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_npi_uniform_n)
+.describe("Draw samples from a uniform distribution; the output shape is "
+          "batch_shape + broadcast(low, high).shape.")
+.set_num_inputs(
+  [](const nnvm::NodeAttrs& attrs) {
+    const NumpyUniformParam& param = nnvm::get<NumpyUniformParam>(attrs.parsed);
+    int num_inputs = 2;
+    if (param.low.has_value()) num_inputs -= 1;
+    if (param.high.has_value()) num_inputs -= 1;
+    return num_inputs;
+  }
+)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    const NumpyUniformParam& param = nnvm::get<NumpyUniformParam>(attrs.parsed);
+    int num_inputs = 2;
+    if (param.low.has_value()) num_inputs -= 1;
+    if (param.high.has_value()) num_inputs -= 1;
+    if (num_inputs == 0) return std::vector<std::string>();
+    if (num_inputs == 1) return std::vector<std::string>{"input1"};
+    return std::vector<std::string>{"input1", "input2"};
+  })
+.set_attr_parser(ParamParser<NumpyUniformParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpConcatShape<NumpyUniformParam>)
+.set_attr<nnvm::FInferType>("FInferType", NumpyUniformOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const nnvm::NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{
+        ResourceRequest::kRandom, ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute<cpu>", NumpyUniformForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("input1", "NDArray-or-Symbol", "Source input")
+.add_argument("input2", "NDArray-or-Symbol", "Source input")
+.add_arguments(NumpyUniformParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/random/np_uniform_op.cu b/src/operator/numpy/random/np_uniform_op.cu
index d997bc57d3be..be21eae55647 100644
--- a/src/operator/numpy/random/np_uniform_op.cu
+++ b/src/operator/numpy/random/np_uniform_op.cu
@@ -31,5 +31,8 @@ namespace op {
 NNVM_REGISTER_OP(_npi_uniform)
 .set_attr<FCompute>("FCompute<gpu>", NumpyUniformForward<gpu>);
 
+NNVM_REGISTER_OP(_npi_uniform_n)
+.set_attr<FCompute>("FCompute<gpu>", NumpyUniformForward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
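Note the asymmetry with the normal op: `_npi_uniform_n` registers `MakeZeroGradNodes`, so its samples contribute no gradient to `low` or `high`. A sketch of the observable consequence (assuming backward through the sampler is the registered no-op):

    from mxnet import np, npx, autograd
    npx.set_np()

    low = np.zeros((2,))
    low.attach_grad()
    with autograd.record():
        u = npx.random.uniform_n(low, 1.0, batch_shape=(3,))
    u.backward()
    assert low.grad.asnumpy().sum() == 0.0  # zero-gradient node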
diff --git a/tests/nightly/test_np_random.py b/tests/nightly/test_np_random.py
new file mode 100644
index 000000000000..345fb86b1222
--- /dev/null
+++ b/tests/nightly/test_np_random.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+from __future__ import absolute_import
+from __future__ import division
+import itertools
+import os
+import sys
+from os import path
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.append(os.path.join(curr_path, '../python/common/'))
+sys.path.append(os.path.join(curr_path, '../python/unittest/'))
+sys.path.insert(0, os.path.join(curr_path, '../../../python'))
+import unittest
+import numpy as _np
+import mxnet as mx
+from mxnet import np, npx, autograd
+from mxnet.gluon import HybridBlock
+from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, use_np
+from common import with_seed
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, assert_exception, is_op_runnable, collapse_sum_like
+from mxnet.ndarray.ndarray import py_slice
+from mxnet.base import integer_types
+import scipy.stats as ss
+
+
+@retry(5)
+@with_seed()
+@use_np
+def test_np_uniform():
+    types = [None, "float32", "float64"]
+    ctx = mx.context.current_context()
+    samples = 1000000
+    # Generation test
+    trials = 8
+    num_buckets = 5
+    for dtype in types:
+        for low, high in [(-100.0, -98.0), (99.0, 101.0)]:
+            scale = high - low
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), num_buckets)
+            buckets = np.array(buckets, dtype=dtype).tolist()
+            probs = [(buckets[i][1] - buckets[i][0]) / scale for i in range(num_buckets)]
+            generator_mx_np = lambda x: mx.np.random.uniform(low, high, size=x, ctx=ctx, dtype=dtype).asnumpy()
+            verify_generator(generator=generator_mx_np, buckets=buckets, probs=probs, nsamples=samples, nrepeat=trials)
+
+
+@retry(5)
+@with_seed()
+@use_np
+def test_np_normal():
+    types = [None, "float32", "float64"]
+    ctx = mx.context.current_context()
+    samples = 1000000
+    # Generation test
+    trials = 8
+    num_buckets = 5
+    for dtype in types:
+        for loc, scale in [(0.0, 1.0), (1.0, 5.0)]:
+            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.norm.ppf(x, loc=loc, scale=scale), num_buckets)
+            buckets = np.array(buckets, dtype=dtype).tolist()
+            # Bucket probabilities come from the normal CDF, not the bucket widths.
+            probs = [(ss.norm.cdf(buckets[i][1], loc, scale) -
+                      ss.norm.cdf(buckets[i][0], loc, scale)) for i in range(num_buckets)]
+            generator_mx_np = lambda x: np.random.normal(loc, scale, size=x, ctx=ctx, dtype=dtype).asnumpy()
+            verify_generator(generator=generator_mx_np, buckets=buckets, probs=probs, nsamples=samples, nrepeat=trials)
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 8e46f03e79bc..2c94ec06f4c7 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -974,40 +974,6 @@ def test_np_save_load_ndarrays():
         assert _np.array_equal(v.asnumpy(), arr_dict[k].asnumpy())
 
 
-@retry(5)
-@with_seed()
-@use_np
-def test_np_uniform():
-    types = [None, "float32", "float64"]
-    ctx = mx.context.current_context()
-    samples = 1000000
-    # Generation test
-    trials = 8
-    num_buckets = 5
-    for dtype in types:
-        for low, high in [(-100.0, -98.0), (99.0, 101.0)]:
-            scale = high - low
-            buckets, probs = gen_buckets_probs_with_ppf(lambda x: ss.uniform.ppf(x, loc=low, scale=scale), num_buckets)
-            buckets = np.array(buckets, dtype=dtype).tolist()
-            probs = [(buckets[i][1] - buckets[i][0])/scale for i in range(num_buckets)]
-            generator_mx_np = lambda x: mx.np.random.uniform(low, high, size=x, ctx=ctx, dtype=dtype).asnumpy()
-            verify_generator(generator=generator_mx_np, buckets=buckets, probs=probs, nsamples=samples, nrepeat=trials)
-
-    # Broadcasting test
-    params = [
-        (1.0, mx.np.ones((4,4)) + 2.0),
-        (mx.np.zeros((4,4)) + 1, 2.0),
-        (mx.np.zeros((1,4)), mx.np.ones((4,4)) + mx.np.array([1, 2, 3, 4])),
-        (mx.np.array([1, 2, 3, 4]), mx.np.ones((2,4,4)) * 5)
-    ]
-    for dtype in types:
-        for low, high in params:
-            expect_mean = (low + high) / 2
-            expanded_size = (samples,) + expect_mean.shape
-            uniform_samples = mx.np.random.uniform(low, high, size=expanded_size, dtype=dtype)
-            mx.test_utils.assert_almost_equal(uniform_samples.asnumpy().mean(0), expect_mean.asnumpy(), rtol=0.20, atol=1e-1)
-
-
 @retry(5)
 @with_seed()
 @use_np
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index dc99fc6fc251..8adf9d78720b 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2670,6 +2670,45 @@ def hybrid_forward(self, F, loc, scale):
     assert_almost_equal(loc.grad.asnumpy().sum(), _np.ones(out_shape).sum(), rtol=1e-3, atol=1e-5)
 
 
+@with_seed()
+@use_np
+def test_npx_sample_n():
+    def shape_formatter(s):
+        if s is None:
+            return ()
+        if isinstance(s, tuple):
+            return s
+        # scalar case
+        return (s,)
+
+    class TestSampleN(HybridBlock):
+        def __init__(self, shape, op_name):
+            super(TestSampleN, self).__init__()
+            self._shape = shape
+            self._op_name = op_name
+
+        def hybrid_forward(self, F, param1, param2):
+            op = getattr(F.npx.random, self._op_name, None)
+            assert op is not None
+            # return param1 + param2 + op(batch_shape=self._shape)
+            return op(param1, param2, batch_shape=self._shape)
+
+    batch_shapes = [(10,), (2, 3), 6, (), None]
+    event_shapes = [(), (2,), (2, 2)]
+    dtypes = ['float16', 'float32', 'float64']
+    op_names = ['uniform_n', 'normal_n']
+
+    for bshape, eshape, dtype, op in itertools.product(batch_shapes, event_shapes, dtypes, op_names):
+        for hybridize in [True, False]:
+            net = TestSampleN(bshape, op)
+            if hybridize:
+                net.hybridize()
+            expected_shape = (shape_formatter(bshape) +
+                              shape_formatter(eshape))
+            out = net(np.ones(shape=eshape), np.ones(shape=eshape))
+            assert out.shape == expected_shape
+
+
 @with_seed()
 @use_np
 def test_np_random():
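Beyond the shape checks above, a quick moment-matching smoke test of the new front ends can be sketched as follows (tolerances are illustrative, not taken from the test suite):

    from mxnet import np, npx
    npx.set_np()

    s = npx.random.normal_n(2.0, 3.0, batch_shape=(1000000,)).asnumpy()
    assert abs(s.mean() - 2.0) < 0.05   # sample mean approaches loc
    assert abs(s.std() - 3.0) < 0.05    # sample std approaches scale

    u = npx.random.uniform_n(-1.0, 1.0, batch_shape=(1000000,)).asnumpy()
    assert u.min() >= -1.0 and u.max() < 1.0  # half-open interval [low, high)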