This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Implements ldexp operator #15845

Merged
merged 1 commit on Sep 24, 2019
42 changes: 40 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
@@ -35,7 +35,7 @@
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique']
'unique', 'ldexp']


@set_module('mxnet.ndarray.numpy')
@@ -3246,7 +3246,7 @@ def hypot(x1, x2, out=None):
Notes
-----
This function differs from the original numpy.hypot in the following aspects:
- Only support float16, float32 and float64.

Examples
--------
@@ -3263,3 +3263,41 @@ def hypot(x1, x2, out=None):
[ 5., 5., 5.]])
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
def ldexp(x1, x2, out=None):
"""
Returns x1 * 2**x2, element-wise.
The mantissas `x1` and twos exponents `x2` are used to construct
floating point numbers ``x1 * 2**x2``.

Parameters
----------
x1 : ndarray or scalar
Array of multipliers.
x2 : ndarray or scalar
Array of twos exponents.
out : ndarray, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.

Returns
-------
y : ndarray or scalar
The result of ``x1 * 2**x2``.
This is a scalar if both `x1` and `x2` are scalars.

Notes
-----
Complex dtypes are not supported; they will raise a TypeError.
Unlike the official NumPy implementation, `x2` may be a float as well as an int.
`ldexp` is useful as the inverse of `frexp`; if used by itself, it is
clearer to simply use the expression ``x1 * 2**x2``.

Examples
--------
>>> np.ldexp(5, np.arange(4))
array([ 5., 10., 20., 40.])
"""
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
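
For context, the `_ufunc_helper` call above selects among the four registered kernels depending on which inputs are scalars. Below is a simplified sketch of that dispatch, not the actual MXNet source; `_ufunc_helper_sketch` is a hypothetical stand-in whose parameters mirror the real helper's:

```python
from numbers import Number

def _ufunc_helper_sketch(x1, x2, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
    """Roughly how ldexp picks among _npi.ldexp, _npi.ldexp_scalar and _npi.rldexp_scalar."""
    if isinstance(x1, Number) and isinstance(x2, Number):
        return fn_scalar(x1, x2)                  # both scalar: plain numpy (_np.ldexp)
    if isinstance(x2, Number):
        return lfn_scalar(x1, float(x2), out=out)  # scalar exponent: _npi.ldexp_scalar
    if isinstance(x1, Number):
        return rfn_scalar(x2, float(x1), out=out)  # scalar mantissa: _npi.rldexp_scalar
    return fn_array(x1, x2, out=out)               # two arrays: broadcast kernel _npi.ldexp
```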
40 changes: 39 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -54,7 +54,7 @@
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique']
'rad2deg', 'deg2rad', 'unique', 'ldexp']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -4792,3 +4792,41 @@ def hypot(x1, x2, out=None):
[ 5., 5., 5.]])
"""
return _mx_nd_np.hypot(x1, x2, out=out)


@set_module('mxnet.numpy')
def ldexp(x1, x2, out=None):
"""
Returns x1 * 2**x2, element-wise.
The mantissas `x1` and twos exponents `x2` are used to construct
floating point numbers ``x1 * 2**x2``.

Parameters
----------
x1 : ndarray or scalar
Array of multipliers.
x2 : ndarray or scalar
Array of twos exponents.
out : ndarray, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.

Returns
-------
y : ndarray or scalar
The result of ``x1 * 2**x2``.
This is a scalar if both `x1` and `x2` are scalars.

Notes
-----
Complex dtypes are not supported; they will raise a TypeError.
Unlike the official NumPy implementation, `x2` may be a float as well as an int.
`ldexp` is useful as the inverse of `frexp`; if used by itself, it is
clearer to simply use the expression ``x1 * 2**x2``.

Examples
--------
>>> np.ldexp(5, np.arange(4))
array([ 5., 10., 20., 40.])
"""
return _mx_nd_np.ldexp(x1, x2, out)
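
A minimal imperative usage sketch for the `mxnet.numpy` front end, assuming an MXNet build that includes this change (with `npx.set_np()` enabling NumPy-compatible semantics):

```python
from mxnet import np, npx
npx.set_np()  # enable NumPy-compatible array semantics

x = np.array([1., 2., 3.], dtype='float32')

print(np.ldexp(x, np.array([1., 2., 3.])))  # array-array, broadcasts: [ 2.  8. 24.]
print(np.ldexp(x, 2))                       # scalar exponent:         [ 4.  8. 12.]
print(np.ldexp(2, x))                       # scalar mantissa:         [ 4.  8. 16.]
```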
31 changes: 30 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -37,7 +37,7 @@
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique']
'unique', 'ldexp']


def _num_outputs(sym):
@@ -3394,4 +3394,33 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
return _npi.unique(ar, return_index, return_inverse, return_counts, axis)


@set_module('mxnet.symbol.numpy')
def ldexp(x1, x2, out=None):
"""
ldexp(x1, x2, out=None)
Returns x1 * 2**x2, element-wise.
The mantissas `x1` and twos exponents `x2` are used to construct
floating point numbers ``x1 * 2**x2``.
Parameters
----------
x1 : _Symbol
Array of multipliers.
x2 : _Symbol
Array of twos exponents.
out : _Symbol or None
    Dummy parameter to keep consistency with the ndarray counterpart.
Returns
-------
y : _Symbol
The result of ``x1 * 2**x2``.
Notes
-----
Complex dtypes are not supported; they will raise a TypeError.
Unlike the official NumPy implementation, `x2` may be a float as well as an int.
`ldexp` is useful as the inverse of `frexp`; if used by itself, it is
clearer to simply use the expression ``x1 * 2**x2``.
"""
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
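
Because the symbol front end mirrors the ndarray API, `ldexp` also works inside hybridized Gluon blocks. A minimal sketch — `ScaleByPow2` is a hypothetical example block, assuming a build with this change:

```python
from mxnet import np, npx
from mxnet.gluon import HybridBlock
npx.set_np()

class ScaleByPow2(HybridBlock):
    """Multiply the input by 2**exponent via F.np.ldexp."""
    def __init__(self, exponent):
        super(ScaleByPow2, self).__init__()
        self._exponent = exponent

    def hybrid_forward(self, F, x):
        # F is mxnet.ndarray imperatively, mxnet.symbol once hybridized
        return F.np.ldexp(x, self._exponent)

block = ScaleByPow2(3)
block.hybridize()                 # routes through mxnet.symbol.numpy.ldexp
print(block(np.array([1., 2.])))  # [ 8. 16.]
```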


_set_np_symbol_class(_Symbol)
11 changes: 11 additions & 0 deletions src/operator/mshadow_op.h
@@ -367,6 +367,17 @@ MXNET_UNARY_MATH_OP(reciprocal_cube_root, 1.0f / math::cbrt(a));

MXNET_UNARY_MATH_OP(reciprocal_cube_root_grad, -1.0f / (3.0f * math::cbrt(a) * math::id(a)));

/*! \brief used to generate element of ldexp */
MXNET_BINARY_MATH_OP(ldexp, math::id(a) * math::pow(2.0f, b));

MXNET_BINARY_MATH_OP(ldexp_grad, math::pow(2.0f, b));

MXNET_BINARY_MATH_OP(ldexp_rgrad, math::id(a) * math::pow(2.0f, b) * math::log(2.0f));

MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a));  // a is the tensor exponent, b the scalar mantissa (operands swapped when the scalar comes first).

MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));
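
These macros encode the analytic derivatives of f(a, b) = a * 2^b: df/da = 2^b (`ldexp_grad`) and df/db = a * 2^b * ln 2 (`ldexp_rgrad`); `rldexp_grad` is the same exponent-derivative for the swapped form b * 2^a. A quick finite-difference sanity check in plain NumPy (illustrative only, not part of the PR):

```python
import numpy as np

def f(a, b):
    return a * np.power(2.0, b)  # the ldexp kernel

a, b, eps = 1.5, 3.0, 1e-6
grad_a = np.power(2.0, b)                    # ldexp_grad
grad_b = a * np.power(2.0, b) * np.log(2.0)  # ldexp_rgrad

# central differences agree with the closed forms
assert np.isclose((f(a + eps, b) - f(a - eps, b)) / (2 * eps), grad_a)
assert np.isclose((f(a, b + eps) - f(a, b - eps)) / (2 * eps), grad_b)
```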

/*! \brief used to generate element of round */
MXNET_SIMPLE_UNARY_MATH_OP(round);

37 changes: 37 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -263,5 +263,42 @@ NNVM_REGISTER_OP(_backward_npi_hypot)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::hypot_grad_left,
mshadow_op::hypot_grad_right>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::ldexp>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_ldexp_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::ldexp>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp_scalar"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rldexp_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rldexp>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rldexp_scalar"});

NNVM_REGISTER_OP(_backward_npi_ldexp)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::ldexp_grad,
mshadow_op::ldexp_rgrad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::ldexp_grad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rldexp_grad>);

} // namespace op
} // namespace mxnet
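
With these registrations, autograd wires the forward ops to their backward counterparts automatically. A minimal end-to-end gradient check — a sketch, assuming a build that includes this change:

```python
import mxnet as mx
from mxnet import np, npx
npx.set_np()

x1 = np.array([1., 2., 4.])
x2 = np.array([1., 2., 3.])
x1.attach_grad()
x2.attach_grad()

with mx.autograd.record():
    y = np.ldexp(x1, x2)   # forward: _npi_ldexp
y.backward()               # backward: _backward_npi_ldexp

print(x1.grad)  # 2**x2             -> [2. 4. 8.]
print(x2.grad)  # x1 * 2**x2 * ln 2 -> [~1.39  ~5.55  ~22.18]
```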
19 changes: 19 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -112,5 +112,24 @@ NNVM_REGISTER_OP(_npi_rarctan2_scalar)
NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rarctan2_grad>);

NNVM_REGISTER_OP(_npi_ldexp)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::ldexp>);

NNVM_REGISTER_OP(_npi_ldexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::ldexp>);

NNVM_REGISTER_OP(_npi_rldexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rldexp>);

NNVM_REGISTER_OP(_backward_npi_ldexp)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::ldexp_grad,
mshadow_op::ldexp_rgrad>);

NNVM_REGISTER_OP(_backward_npi_ldexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::ldexp_grad>);

NNVM_REGISTER_OP(_backward_npi_rldexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::rldexp_grad>);

} // namespace op
} // namespace mxnet
5 changes: 5 additions & 0 deletions src/operator/operator_tune.cc
@@ -372,6 +372,11 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ldexp); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad); // NOLINT()
/*!
* \brief Tuner objects, *not* automatically generated
*/
61 changes: 61 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -218,6 +218,67 @@ def test_np_dot():
assert False


@with_seed()
@use_np
def test_np_ldexp():
class TestLdexp(HybridBlock):
def __init__(self):
super(TestLdexp, self).__init__()

def hybrid_forward(self, F, x1, x2):
return F.np.ldexp(x1, x2)

def _np_ldexp(x1, x2):
return x1 * _np.power(2.0, x2)

def dldx(x1, x2):
grad_a = _np.power(2.0, x2)
grad_b = _np_ldexp(x1, x2) * _np.log(2.0)
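        # sum the gradient over the broadcast axis when an input has size 1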
if len(x1) == 1:
grad_a = _np.sum(grad_a)
if len(x2) == 1:
grad_b = _np.sum(grad_b)
return [grad_a, grad_b]

shapes = [
((3, 1), (3, 1)),
((3, 1, 2), (3, 1, 2)),
((1, ), (1, )),
((1, ), (2, )),
((3, ), (1, )),
((3, 0), (3, 0)), # zero-size shape
((0, 1), (0, 1)), # zero-size shape
((2, 0, 2), (2, 0, 2)), # zero-size shape
]

for hybridize in [True, False]:
for shape1, shape2 in shapes:
for dtype in [_np.float16, _np.float32, _np.float64]:
test_ldexp = TestLdexp()
if hybridize:
test_ldexp.hybridize()
x1 = rand_ndarray(shape=shape1, dtype=dtype).as_np_ndarray()
x1.attach_grad()
x2 = rand_ndarray(shape=shape2, dtype=dtype).as_np_ndarray()
x2.attach_grad()

np_out = _np_ldexp(x1.asnumpy(), x2.asnumpy())
with mx.autograd.record():
mx_out = test_ldexp(x1, x2)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)

mx_out.backward()
np_backward = dldx(x1.asnumpy(), x2.asnumpy())
assert_almost_equal(x1.grad.asnumpy(), np_backward[0], atol=1e-1, rtol=1e-1)
assert_almost_equal(x2.grad.asnumpy(), np_backward[1], atol=1e-1, rtol=1e-1)

# Test imperative once again
mx_out = np.ldexp(x1, x2)
np_out = _np_ldexp(x1.asnumpy(), x2.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)


@with_seed()
@use_np
def test_np_sum():