Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refactor and reduce float types for some functions, also add bitwise_xor
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Nov 15, 2019
1 parent 4a27b5c commit 5e1ee4f
Show file tree
Hide file tree
Showing 14 changed files with 613 additions and 321 deletions.
40 changes: 39 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize',
'nan_to_num']
Expand Down Expand Up @@ -4291,6 +4291,44 @@ def hypot(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.
Parameters
----------
x1, x2 : ndarray or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.bitwise_xor(13, 17)
28
>>> np.bitwise_xor(31, 5)
26
>>> np.bitwise_xor(np.array([31,3], dtype='int32'), 5)
array([26, 6])
>>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32'))
array([26, 5])
>>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, False])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def ldexp(x1, x2, out=None, **kwargs):
Expand Down
42 changes: 40 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num']

Expand Down Expand Up @@ -6198,6 +6198,44 @@ def hypot(x1, x2, out=None, **kwargs):
return _mx_nd_np.hypot(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.
Parameters
----------
x1, x2 : ndarray or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.bitwise_xor(13, 17)
28
>>> np.bitwise_xor(31, 5)
26
>>> np.bitwise_xor(np.array([31,3], dtype=np.int32), 5)
array([26, 6])
>>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32'))
array([26, 5])
>>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, False])
"""
return _mx_nd_np.bitwise_xor(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def ldexp(x1, x2, out=None, **kwargs):
Expand Down
35 changes: 29 additions & 6 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num']
Expand Down Expand Up @@ -4058,17 +4058,16 @@ def hypot(x1, x2, out=None, **kwargs):
Parameters
----------
x1, x2 : array_like
x1, x2 : _Symbol or scalar
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
out : _Symbol or None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned. A tuple (possible only as a
keyword argument) must have length equal to the number of outputs.
a freshly-allocated array is returned.
Returns
-------
z : ndarray
z : _Symbol or scalar
The hypotenuse of the triangle(s).
This is a scalar if both `x1` and `x2` are scalars.
Expand All @@ -4080,6 +4079,30 @@ def hypot(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.
Parameters
----------
x1, x2 : _Symbol or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : _Symbol or None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
out : _Symbol or scalar
Result.
"""
return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out)


@set_module('mxnet.symbol.numpy')
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/operator/elemwise_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs,
CHECK(in_attrs->at(0) == mshadow::kInt64 ||
in_attrs->at(0) == mshadow::kInt32 ||
in_attrs->at(0) == mshadow::kInt8 ||
in_attrs->at(0) == mshadow::kUint8) << "Only supports integer types.";
in_attrs->at(0) == mshadow::kUint8 ||
in_attrs->at(0) == mshadow::kBool) << "Only supports integer types.";
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
}
Expand Down
Loading

0 comments on commit 5e1ee4f

Please sign in to comment.