Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
mod (#6698)
Browse files Browse the repository at this point in the history
* mod

* fix include

* fix cpplint

* fix R

* fix cpp header

* return NAN instead of 0 and remove warning for mod

* fix cpp header

* fix build for half

* fix half

* fix half

* 0 grad for mod for now

* add override for half2

* fix half2_t

* updated backward

* fix registration

* mod, working backward compatible with numpy

* update mshadow

* fix lint

* fix half2 neg

* fix scala
  • Loading branch information
szha authored and piiswrong committed Jun 20, 2017
1 parent 63392cf commit 63ac793
Show file tree
Hide file tree
Showing 23 changed files with 583 additions and 42 deletions.
12 changes: 12 additions & 0 deletions R-package/R/symbol.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,16 @@ init.symbol.methods <- function() {
setMethod("/", signature(e1 = "Rcpp_MXSymbol", e2 = "numeric"), function(e1, e2) {
mx.varg.symbol.internal.DivScalar(list(e1, scalar = e2))
})
setMethod("%%", signature(e1 = "Rcpp_MXSymbol", e2 = "Rcpp_MXSymbol"), function(e1, e2) {
mx.varg.symbol.internal.Mod(list(e1, e2))
})
setMethod("%%", signature(e1 = "Rcpp_MXSymbol", e2 = "numeric"), function(e1, e2) {
mx.varg.symbol.internal.ModScalar(list(e1, scalar = e2))
})
setMethod("%/%", signature(e1 = "Rcpp_MXSymbol", e2 = "Rcpp_MXSymbol"), function(e1, e2) {
mx.varg.symbol.internal.Mod(list(e1, e2))
})
setMethod("%/%", signature(e1 = "Rcpp_MXSymbol", e2 = "numeric"), function(e1, e2) {
mx.varg.symbol.internal.ModScalar(list(e1, scalar = e2))
})
}
13 changes: 13 additions & 0 deletions R-package/src/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,9 @@ NDArray::RObjectType DispatchOps(SEXP op, SEXP lhs, SEXP rhs) {
static OpHandle div = NDArrayFunction::FindHandle("_div");
static OpHandle div_scalar = NDArrayFunction::FindHandle("_div_scalar");
static OpHandle rdiv_scalar = NDArrayFunction::FindHandle("_rdiv_scalar");
static OpHandle mod = NDArrayFunction::FindHandle("_mod");
static OpHandle mod_scalar = NDArrayFunction::FindHandle("_mod_scalar");
static OpHandle rmod_scalar = NDArrayFunction::FindHandle("_rmod_scalar");
// parse the arguments
std::string values[2];
NDArrayHandle handles[2];
Expand Down Expand Up @@ -591,6 +594,16 @@ NDArray::RObjectType DispatchOps(SEXP op, SEXP lhs, SEXP rhs) {
}
break;
}
case '%': {
if (lhs_nd && rhs_nd) {
out = BinaryOp(mod, handles);
} else if (lhs_nd && !rhs_nd) {
out = BinaryScalarOp(mod_scalar, handles[0], values[1]);
} else {
out = BinaryScalarOp(rmod_scalar, handles[1], values[0]);
}
break;
}
default: {
RLOG_FATAL << "Operator " << sop << "not supported for MXNDArray";
}
Expand Down
16 changes: 16 additions & 0 deletions cpp-package/include/mxnet-cpp/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,12 @@ class NDArray {
NDArray operator-(mx_float scalar);
NDArray operator*(mx_float scalar);
NDArray operator/(mx_float scalar);
NDArray operator%(mx_float scalar);
NDArray operator+(const NDArray &);
NDArray operator-(const NDArray &);
NDArray operator*(const NDArray &);
NDArray operator/(const NDArray &);
NDArray operator%(const NDArray &);
/*!
* \brief set all the elements in ndarray to be scalar
* \param scalar the scalar to set
Expand Down Expand Up @@ -184,6 +186,13 @@ class NDArray {
*/
NDArray &operator/=(mx_float scalar);
/*!
* \brief elementwise modulo from current ndarray
* this mutate the current NDArray
* \param scalar the data to subtract
* \return reference of self
*/
NDArray &operator%=(mx_float scalar);
/*!
* \brief elementwise add to current space
* this mutate the current NDArray
* \param src the data to add
Expand Down Expand Up @@ -211,6 +220,13 @@ class NDArray {
* \return reference of self
*/
NDArray &operator/=(const NDArray &src);
/*!
* \brief elementwise modulo from current ndarray
* this mutate the current NDArray
* \param src the data to subtract
* \return reference of self
*/
NDArray &operator%=(const NDArray &src);
NDArray ArgmaxChannel();
/*!
* \brief Do a synchronize copy from a continugous CPU memory region.
Expand Down
18 changes: 18 additions & 0 deletions cpp-package/include/mxnet-cpp/ndarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ inline NDArray NDArray::operator/(mx_float scalar) {
Operator("_div_scalar")(*this, scalar).Invoke(ret);
return ret;
}
inline NDArray NDArray::operator%(mx_float scalar) {
NDArray ret;
Operator("_mod_scalar")(*this, scalar).Invoke(ret);
return ret;
}
inline NDArray NDArray::operator+(const NDArray &rhs) {
NDArray ret;
Operator("_plus")(*this, rhs).Invoke(ret);
Expand All @@ -113,6 +118,11 @@ inline NDArray NDArray::operator/(const NDArray &rhs) {
Operator("_div")(*this, rhs).Invoke(ret);
return ret;
}
inline NDArray NDArray::operator%(const NDArray &rhs) {
NDArray ret;
Operator("_mod")(*this, rhs).Invoke(ret);
return ret;
}
inline NDArray &NDArray::operator=(mx_float scalar) {
Operator("_set_value")(scalar).Invoke(*this);
return *this;
Expand All @@ -133,6 +143,10 @@ inline NDArray &NDArray::operator/=(mx_float scalar) {
Operator("_div_scalar")(*this, scalar).Invoke(*this);
return *this;
}
inline NDArray &NDArray::operator%=(mx_float scalar) {
Operator("_mod_scalar")(*this, scalar).Invoke(*this);
return *this;
}
inline NDArray &NDArray::operator+=(const NDArray &rhs) {
Operator("_plus")(*this, rhs).Invoke(*this);
return *this;
Expand All @@ -149,6 +163,10 @@ inline NDArray &NDArray::operator/=(const NDArray &rhs) {
Operator("_div")(*this, rhs).Invoke(*this);
return *this;
}
inline NDArray &NDArray::operator%=(const NDArray &rhs) {
Operator("_mod")(*this, rhs).Invoke(*this);
return *this;
}

inline NDArray NDArray::ArgmaxChannel() {
NDArray ret;
Expand Down
14 changes: 14 additions & 0 deletions cpp-package/include/mxnet-cpp/op_suppl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ inline Symbol _Div(Symbol lhs, Symbol rhs) {
return Operator("_Div")(lhs, rhs)
.CreateSymbol();
}
inline Symbol _Mod(Symbol lhs, Symbol rhs) {
return Operator("_Mod")(lhs, rhs)
.CreateSymbol();
}
inline Symbol _Power(Symbol lhs, Symbol rhs) {
return Operator("_Power")(lhs, rhs)
.CreateSymbol();
Expand Down Expand Up @@ -77,6 +81,16 @@ inline Symbol _RDivScalar(mx_float scalar, Symbol rhs) {
.SetParam("scalar", scalar)
.CreateSymbol();
}
inline Symbol _ModScalar(Symbol lhs, mx_float scalar) {
return Operator("_ModScalar")(lhs)
.SetParam("scalar", scalar)
.CreateSymbol();
}
inline Symbol _RModScalar(mx_float scalar, Symbol rhs) {
return Operator("_RModScalar")(rhs)
.SetParam("scalar", scalar)
.CreateSymbol();
}
inline Symbol _PowerScalar(Symbol lhs, mx_float scalar) {
return Operator("_PowerScalar")(lhs)
.SetParam("scalar", scalar)
Expand Down
3 changes: 3 additions & 0 deletions cpp-package/include/mxnet-cpp/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ class Symbol {
Symbol operator-(const Symbol &rhs) const;
Symbol operator*(const Symbol &rhs) const;
Symbol operator/(const Symbol &rhs) const;
Symbol operator%(const Symbol &rhs) const;

Symbol operator+(mx_float scalar) const;
Symbol operator-(mx_float scalar) const;
Symbol operator*(mx_float scalar) const;
Symbol operator/(mx_float scalar) const;
Symbol operator%(mx_float scalar) const;
Symbol Copy() const;
/*!
* \brief construct a variable Symbol
Expand Down Expand Up @@ -252,6 +254,7 @@ Symbol operator+(mx_float lhs, const Symbol &rhs);
Symbol operator-(mx_float lhs, const Symbol &rhs);
Symbol operator*(mx_float lhs, const Symbol &rhs);
Symbol operator/(mx_float lhs, const Symbol &rhs);
Symbol operator%(mx_float lhs, const Symbol &rhs);
} // namespace cpp
} // namespace mxnet
#endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_H_
7 changes: 7 additions & 0 deletions cpp-package/include/mxnet-cpp/symbol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ inline Symbol Symbol::operator+(const Symbol &rhs) const { return _Plus(*this, r
inline Symbol Symbol::operator-(const Symbol &rhs) const { return _Minus(*this, rhs); }
inline Symbol Symbol::operator*(const Symbol &rhs) const { return _Mul(*this, rhs); }
inline Symbol Symbol::operator/(const Symbol &rhs) const { return _Div(*this, rhs); }
inline Symbol Symbol::operator%(const Symbol &rhs) const { return _Mod(*this, rhs); }
inline Symbol Symbol::operator+(mx_float scalar) const {
return _PlusScalar(*this, scalar);
}
Expand All @@ -50,6 +51,9 @@ inline Symbol Symbol::operator*(mx_float scalar) const {
inline Symbol Symbol::operator/(mx_float scalar) const {
return _DivScalar(*this, scalar);
}
inline Symbol Symbol::operator%(mx_float scalar) const {
return _ModScalar(*this, scalar);
}
inline Symbol Symbol::operator[](int index) {
SymbolHandle out;
MXSymbolGetOutput(GetHandle(), index, &out);
Expand Down Expand Up @@ -337,6 +341,9 @@ inline Symbol operator*(mx_float lhs, const Symbol &rhs) { return rhs * lhs; }
inline Symbol operator/(mx_float lhs, const Symbol &rhs) {
return mxnet::cpp::_RDivScalar(lhs, rhs);
}
inline Symbol operator%(mx_float lhs, const Symbol &rhs) {
return mxnet::cpp::_RModScalar(lhs, rhs);
}
} // namespace cpp
} // namespace mxnet

Expand Down
4 changes: 4 additions & 0 deletions docs/api/python/ndarray.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ In the rest of this document, we first overview the methods provided by the
NDArray.__mul__
NDArray.__div__
NDArray.__rdiv__
NDArray.__mod__
NDArray.__rmod__
NDArray.__pow__
```

Expand All @@ -133,6 +135,7 @@ In the rest of this document, we first overview the methods provided by the
NDArray.__isub__
NDArray.__imul__
NDArray.__idiv__
NDArray.__imod__
```

### Comparison operators
Expand Down Expand Up @@ -259,6 +262,7 @@ In the rest of this document, we first overview the methods provided by the
negative
multiply
divide
modulo
dot
batch_dot
add_n
Expand Down
3 changes: 3 additions & 0 deletions docs/api/python/symbol.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ Composite multiple symbols into a new one by an operator.
Symbol.__mul__
Symbol.__div__
Symbol.__rdiv__
Symbol.__mod__
Symbol.__rmod__
Symbol.__pow__
```

Expand Down Expand Up @@ -249,6 +251,7 @@ Composite multiple symbols into a new one by an operator.
broadcast_sub
broadcast_mul
broadcast_div
broadcast_mod
negative
dot
batch_dot
Expand Down
2 changes: 1 addition & 1 deletion mshadow
75 changes: 75 additions & 0 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,25 @@ def __rtruediv__(self, other):
def __itruediv__(self, other):
return self.__idiv__(other)

def __mod__(self, other):
"""x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y) """
return modulo(self, other)

def __rmod__(self, other):
"""x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x) """
return modulo(other, self)

def __imod__(self, other):
"""x.__rmod__(y) <=> x%=y """
if not self.writable:
raise ValueError('trying to take modulo from a readonly NDArray')
if isinstance(other, NDArray):
return broadcast_mod(self, other, out=self)
elif isinstance(other, numeric_types):
return _internal._mod_scalar(self, float(other), out=self)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __pow__(self, other):
"""x.__pow__(y) <=> x**y <=> mx.nd.power(x,y) """
return power(self, other)
Expand Down Expand Up @@ -1516,6 +1535,62 @@ def divide(lhs, rhs):
_internal._rdiv_scalar)
# pylint: enable= no-member, protected-access

def modulo(lhs, rhs):
"""Returns element-wise modulo of the input arrays with broadcasting.
Equivalent to ``lhs % rhs`` and ``mx.nd.broadcast_mod(lhs, rhs)``.
.. note::
If the corresponding dimensions of two arrays have the same size or one of them has size 1,
then the arrays are broadcastable to a common shape.
Parameters
----------
lhs : scalar or array
First array in modulo.
rhs : scalar or array
Second array in modulo.
The arrays to be taken modulo. If ``lhs.shape != rhs.shape``, they must be
broadcastable to a common shape.
Returns
-------
NDArray
The element-wise modulo of the input arrays.
Examples
--------
>>> x = mx.nd.ones((2,3))*6
>>> y = mx.nd.ones((2,1))*4
>>> x.asnumpy()
array([[ 6., 6., 6.],
[ 6., 6., 6.]], dtype=float32)
>>> y.asnumpy()
array([[ 4.],
[ 4.]], dtype=float32)
>>> x%5
<NDArray 2x3 @cpu(0)>
>>> (x%5).asnumpy()
array([[ 1., 1., 1.],
[ 1., 1., 1.]], dtype=float32)
>>> (x%y).asnumpy()
array([[ 2., 2., 2.],
[ 2., 2., 2.]], dtype=float32)
>>> mx.nd.modulo(x,y).asnumpy()
array([[ 2., 2., 2.],
[ 2., 2., 2.]], dtype=float32)
"""
# pylint: disable= no-member, protected-access
return _ufunc_helper(
lhs,
rhs,
broadcast_mod,
operator.mod,
_internal._mod_scalar,
_internal._rmod_scalar)
# pylint: enable= no-member, protected-access

def power(base, exp):
"""Returns result of first array elements raised to powers from second array, element-wise
with broadcasting.
Expand Down
30 changes: 30 additions & 0 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,36 @@ def __rdiv__(self, other):
else:
raise TypeError('type %s not supported' % str(type(other)))

def __mod__(self, other):
"""x.__mod__(y) <=> x%y
Scalar input is supported.
Broadcasting is not supported. Use `broadcast_mod` instead. """
if isinstance(other, Symbol):
return _internal._Mod(self, other)
if isinstance(other, Number):
return _internal._ModScalar(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __rmod__(self, other):
"""x.__rmod__(y) <=> y%x
Only `NDArray` is supported for now.
Example usage:
----------
>>> x = mx.nd.ones((2,3))*3
>>> y = mx.nd.ones((2,3))
>>> x.__rmod__(y).asnumpy()
array([[ 1., 1., 1.,
[ 1., 1., 1., dtype=float32)
"""
if isinstance(other, Number):
return _internal._RModScalar(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))

def __truediv__(self, other):
return self.__div__(other)

Expand Down
Loading

0 comments on commit 63ac793

Please sign in to comment.