Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix random op signature
Browse files Browse the repository at this point in the history
Fix weight shape inference using uniform/normal

Delete absolute_import
  • Loading branch information
reminisce committed Oct 7, 2019
1 parent 295fc14 commit 4940ec0
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 47 deletions.
8 changes: 4 additions & 4 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def __init__(self, scale=0.07):

def _init_weight(self, _, arr):
uniform_fn = _mx_np.random.uniform if is_np_array() else random.uniform
uniform_fn(-self.scale, self.scale, out=arr)
uniform_fn(-self.scale, self.scale, arr.shape, out=arr)

@register
class Normal(Initializer):
Expand Down Expand Up @@ -539,7 +539,7 @@ def __init__(self, sigma=0.01):

def _init_weight(self, _, arr):
normal_fn = _mx_np.random.normal if is_np_array() else random.normal
normal_fn(0, self.sigma, out=arr)
normal_fn(0, self.sigma, arr.shape, out=arr)

@register
class Orthogonal(Initializer):
Expand Down Expand Up @@ -639,10 +639,10 @@ def _init_weight(self, name, arr):
scale = np.sqrt(self.magnitude / factor)
if self.rnd_type == "uniform":
uniform_fn = _mx_np.random.uniform if is_np_array() else random.uniform
uniform_fn(-scale, scale, out=arr)
uniform_fn(-scale, scale, arr.shape, out=arr)
elif self.rnd_type == "gaussian":
normal_fn = _mx_np.random.normal if is_np_array() else random.normal
normal_fn(0, scale, out=arr)
normal_fn(0, scale, arr.shape, out=arr)
else:
raise ValueError("Unknown random type")

Expand Down
18 changes: 7 additions & 11 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
__all__ = ['randint', 'uniform', 'normal', "choice"]


def randint(low, high=None, size=None, dtype=None, **kwargs):
def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
"""Return random integers from `low` (inclusive) to `high` (exclusive).
Return random integers from the "discrete uniform" distribution of
Expand Down Expand Up @@ -75,14 +75,12 @@ def randint(low, high=None, size=None, dtype=None, **kwargs):
array([[4, 0, 2, 1],
[3, 2, 2, 0]])
"""
ctx = kwargs.pop('ctx', None)
out = kwargs.pop('out', None)
if dtype is None:
dtype = 'int'
if ctx is None:
ctx = current_context()
if size is None:
size = 1
size = ()
if high is None:
high = low
low = 0
Expand Down Expand Up @@ -114,6 +112,8 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
Expand All @@ -126,8 +126,6 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
dtype = 'float32'
if ctx is None:
ctx = current_context()
if out is not None:
size = out.shape
if size == ():
size = None
if input_type == (True, True):
Expand All @@ -144,7 +142,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
ctx=ctx, dtype=dtype, out=out)


def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
Expand All @@ -165,6 +163,8 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
Data type of output samples. Default is 'float32'
ctx : Context, optional
Device context of output. Default is current context.
out : ``ndarray``, optional
Store output to an existing ``ndarray``.
Returns
-------
Expand All @@ -173,14 +173,10 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
"""
from ...numpy import ndarray as np_ndarray
input_type = (isinstance(loc, np_ndarray), isinstance(scale, np_ndarray))
ctx = kwargs.pop('ctx', None)
out = kwargs.pop('out', None)
if dtype is None:
dtype = 'float32'
if ctx is None:
ctx = current_context()
if out is not None:
size = out.shape
if size == ():
size = None
if input_type == (True, True):
Expand Down
8 changes: 4 additions & 4 deletions python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
__all__ = ["randint", "uniform", "normal", "choice"]


def randint(low, high=None, size=None, dtype=None, **kwargs):
def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
"""Return random integers from `low` (inclusive) to `high` (exclusive).
Return random integers from the "discrete uniform" distribution of
Expand Down Expand Up @@ -73,7 +73,7 @@ def randint(low, high=None, size=None, dtype=None, **kwargs):
array([[4, 0, 2, 1],
[3, 2, 2, 0]])
"""
return _mx_nd_np.random.randint(low, high, size, dtype, **kwargs)
return _mx_nd_np.random.randint(low, high, size, dtype, ctx, out)


def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
Expand Down Expand Up @@ -110,7 +110,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
return _mx_nd_np.random.uniform(low, high, size=size, ctx=ctx, dtype=dtype, out=out)


def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
Expand Down Expand Up @@ -139,7 +139,7 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
out : ndarray
Drawn samples from the parameterized normal distribution.
"""
return _mx_nd_np.random.normal(loc, scale, size, dtype, **kwargs)
return _mx_nd_np.random.normal(loc, scale, size, dtype, ctx, out)


def multinomial(n, pvals, size=None, **kwargs):
Expand Down
7 changes: 3 additions & 4 deletions python/mxnet/numpy_op_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

"""Make builtin ops' signatures compatible with NumPy."""

from __future__ import absolute_import # pylint: disable=reimported
from sys import version, version_info
import sys
import warnings
from . import _numpy_op_doc
from . import numpy as mx_np
Expand Down Expand Up @@ -55,11 +54,11 @@ def _get_builtin_op(op_name):


def _register_op_signatures():
if version_info.major < 3 or version_info.minor < 5:
if sys.version_info.major < 3 or sys.version_info.minor < 5:
warnings.warn('Some mxnet.numpy operator signatures may not be displayed consistently with '
'their counterparts in the official NumPy package due to too-low Python '
'version {}. Python >= 3.5 is required to make the signatures display correctly.'
.format(str(version)))
.format(str(sys.version)))
return

import inspect
Expand Down
13 changes: 3 additions & 10 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
__all__ = ['randint', 'uniform', 'normal']


def randint(low, high=None, size=None, dtype=None, **kwargs):
def randint(low, high=None, size=None, dtype=None, ctx=None, out=None):
"""Return random integers from `low` (inclusive) to `high` (exclusive).
Return random integers from the "discrete uniform" distribution of
Expand Down Expand Up @@ -74,14 +74,12 @@ def randint(low, high=None, size=None, dtype=None, **kwargs):
array([[4, 0, 2, 1],
[3, 2, 2, 0]])
"""
ctx = kwargs.pop('ctx', None)
out = kwargs.pop('out', None)
if dtype is None:
dtype = 'int'
if ctx is None:
ctx = current_context()
if size is None:
size = 1
size = ()
if high is None:
high = low
low = 0
Expand Down Expand Up @@ -143,7 +141,7 @@ def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None):
ctx=ctx, dtype=dtype, out=out)


def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
def normal(loc=0.0, scale=1.0, size=None, dtype=None, ctx=None, out=None):
"""Draw random samples from a normal (Gaussian) distribution.
Samples are distributed according to a normal distribution parametrized
Expand Down Expand Up @@ -172,15 +170,10 @@ def normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs):
"""
from ._symbol import _Symbol as np_symbol
input_type = (isinstance(loc, np_symbol), isinstance(scale, np_symbol))
ctx = kwargs.pop('ctx', None)
out = kwargs.pop('out', None)
if dtype is None:
dtype = 'float32'
if ctx is None:
ctx = current_context()
out = kwargs.pop('out', None)
if out is not None:
size = out.shape
if size == ():
size = None
if input_type == (True, True):
Expand Down
2 changes: 0 additions & 2 deletions src/operator/random/sample_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ DMLC_REGISTER_PARAMETER(SampleGenNegBinomialLikeParam);
MXNET_OPERATOR_REGISTER_SAMPLE(_random_uniform, SampleUniformParam)
.add_alias("uniform")
.add_alias("random_uniform")
.add_alias("_npi_random_uniform")
.describe(R"code(Draw random samples from a uniform distribution.
.. note:: The existing alias ``uniform`` is deprecated.
Expand All @@ -100,7 +99,6 @@ Example::
MXNET_OPERATOR_REGISTER_SAMPLE(_random_normal, SampleNormalParam)
.add_alias("normal")
.add_alias("random_normal")
.add_alias("_npi_random_normal")
.describe(R"code(Draw random samples from a normal (Gaussian) distribution.
.. note:: The existing alias ``normal`` is deprecated.
Expand Down
15 changes: 14 additions & 1 deletion tests/python/unittest/test_exc_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from common import setup_module, with_seed, teardown
from mxnet.gluon import nn
from mxnet.base import MXNetError
from mxnet.test_utils import assert_exception, default_context, set_default_context
from mxnet.test_utils import assert_exception, default_context, set_default_context, use_np
from nose.tools import assert_raises


@with_seed()
def test_exc_imperative():
def imperative(exec_numpy=True):
Expand Down Expand Up @@ -181,6 +182,7 @@ def run_training_iteration(data):
mx.nd.waitall()
mx.profiler.set_state("stop")


@with_seed()
def test_opencv_exception():
def check_resize():
Expand All @@ -199,6 +201,17 @@ def test_np_reshape_exception():
assert_raises(MXNetError, lambda: mx.np.reshape(a, (-1, 3)))


@with_seed()
@use_np
def test_np_random_incorrect_named_arguments():
random_ops = ['uniform', 'normal', 'randint']
for op_name in random_ops:
op = getattr(mx.np.random, op_name, None)
assert op is not None
assert_raises(TypeError, op, shape=())
assert_raises(TypeError, op, shape=None)


if __name__ == '__main__':
import nose
nose.runmodule()
15 changes: 9 additions & 6 deletions tests/python/unittest/test_numpy_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
from common import with_seed


@with_seed()
def test_create_np_param():
M, K, N = 10, 9, 20

def check_block_params(x, TestBlock, hybridize, expected_type):
def check_block_params(x, TestBlock, hybridize, expected_type, initializer):
net = TestBlock()
net.initialize()
net.initialize(initializer())
if hybridize:
net.hybridize()
net(x)
Expand Down Expand Up @@ -59,12 +60,14 @@ def hybrid_forward(self, F, x, w):
return F.np.dot(x, w)

x = mx.nd.random.uniform(shape=(M, K))
check_block_params(x, TestBlock1, False, mx.nd.NDArray)
check_block_params(x, TestBlock1, True, mx.nd.NDArray)
check_block_params(x.as_np_ndarray(), TestBlock2, False, np.ndarray)
check_block_params(x.as_np_ndarray(), TestBlock2, True, np.ndarray)
for initializer in [mx.initializer.Uniform, mx.initializer.Normal]:
check_block_params(x, TestBlock1, False, mx.nd.NDArray, initializer)
check_block_params(x, TestBlock1, True, mx.nd.NDArray, initializer)
check_block_params(x.as_np_ndarray(), TestBlock2, False, np.ndarray, initializer)
check_block_params(x.as_np_ndarray(), TestBlock2, True, np.ndarray, initializer)


@with_seed()
@use_np
def test_optimizer_with_np_ndarrays():
class LinearRegression(gluon.HybridBlock):
Expand Down
10 changes: 5 additions & 5 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,11 +695,9 @@ def _test_np_exception(func, shape, dim):
if hybridize:
test_gluon.hybridize()
if is_int(itype):
x = mx.nd.arange(120).reshape((2, 3, 4, 5))
x = mx.nd.array(x)
x = np.arange(120).reshape((2, 3, 4, 5))
else:
x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype)
x = x.as_np_ndarray()
x = np.random.uniform(-1.0, 1.0, size=shape, dtype=itype)
x.attach_grad()
if func == 'max':
expected_ret = _np.amax(x.asnumpy(), axis=axis, keepdims=keepdims)
Expand Down Expand Up @@ -1812,6 +1810,8 @@ def test_np_randint():
(5, None)
]
shapes = [
None,
(),
(3, 3),
(3, 4),
(0, 0),
Expand All @@ -1825,7 +1825,7 @@ def test_np_randint():
for shape in shapes:
for (low, high) in params:
data_mx = np.random.randint(low, high, size=shape)
assert data_mx.shape == shape
assert data_mx.shape == (shape if shape is not None else ())

# test generator
for dtype in ['int32', 'int64']:
Expand Down

0 comments on commit 4940ec0

Please sign in to comment.