diff --git a/R-package/R/symbol.R b/R-package/R/symbol.R index d2fd67bc45c0..091a6468c1a1 100644 --- a/R-package/R/symbol.R +++ b/R-package/R/symbol.R @@ -158,4 +158,16 @@ init.symbol.methods <- function() { setMethod("/", signature(e1 = "Rcpp_MXSymbol", e2 = "numeric"), function(e1, e2) { mx.varg.symbol.internal.DivScalar(list(e1, scalar = e2)) }) + setMethod("%%", signature(e1 = "Rcpp_MXSymbol", e2 = "Rcpp_MXSymbol"), function(e1, e2) { + mx.varg.symbol.internal.Mod(list(e1, e2)) + }) + setMethod("%%", signature(e1 = "Rcpp_MXSymbol", e2 = "numeric"), function(e1, e2) { + mx.varg.symbol.internal.ModScalar(list(e1, scalar = e2)) + }) + setMethod("%/%", signature(e1 = "Rcpp_MXSymbol", e2 = "Rcpp_MXSymbol"), function(e1, e2) { + mx.varg.symbol.internal.Mod(list(e1, e2)) + }) + setMethod("%/%", signature(e1 = "Rcpp_MXSymbol", e2 = "numeric"), function(e1, e2) { + mx.varg.symbol.internal.ModScalar(list(e1, scalar = e2)) + }) } diff --git a/R-package/src/ndarray.cc b/R-package/src/ndarray.cc index 335e5d47b486..c2bfe0c945a6 100644 --- a/R-package/src/ndarray.cc +++ b/R-package/src/ndarray.cc @@ -541,6 +541,9 @@ NDArray::RObjectType DispatchOps(SEXP op, SEXP lhs, SEXP rhs) { static OpHandle div = NDArrayFunction::FindHandle("_div"); static OpHandle div_scalar = NDArrayFunction::FindHandle("_div_scalar"); static OpHandle rdiv_scalar = NDArrayFunction::FindHandle("_rdiv_scalar"); + static OpHandle mod = NDArrayFunction::FindHandle("_mod"); + static OpHandle mod_scalar = NDArrayFunction::FindHandle("_mod_scalar"); + static OpHandle rmod_scalar = NDArrayFunction::FindHandle("_rmod_scalar"); // parse the arguments std::string values[2]; NDArrayHandle handles[2]; @@ -591,6 +594,16 @@ NDArray::RObjectType DispatchOps(SEXP op, SEXP lhs, SEXP rhs) { } break; } + case '%': { + if (lhs_nd && rhs_nd) { + out = BinaryOp(mod, handles); + } else if (lhs_nd && !rhs_nd) { + out = BinaryScalarOp(mod_scalar, handles[0], values[1]); + } else { + out = BinaryScalarOp(rmod_scalar, handles[1], 
values[0]); + } + break; + } default: { RLOG_FATAL << "Operator " << sop << "not supported for MXNDArray"; } diff --git a/cpp-package/include/mxnet-cpp/ndarray.h b/cpp-package/include/mxnet-cpp/ndarray.h index 58376a8ef6da..f908b4ff38eb 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.h +++ b/cpp-package/include/mxnet-cpp/ndarray.h @@ -145,10 +145,12 @@ class NDArray { NDArray operator-(mx_float scalar); NDArray operator*(mx_float scalar); NDArray operator/(mx_float scalar); + NDArray operator%(mx_float scalar); NDArray operator+(const NDArray &); NDArray operator-(const NDArray &); NDArray operator*(const NDArray &); NDArray operator/(const NDArray &); + NDArray operator%(const NDArray &); /*! * \brief set all the elements in ndarray to be scalar * \param scalar the scalar to set @@ -184,6 +186,13 @@ class NDArray { */ NDArray &operator/=(mx_float scalar); /*! + * \brief elementwise modulo of current ndarray by a scalar + * this mutates the current NDArray + * \param scalar the scalar divisor of the modulo + * \return reference of self + */ + NDArray &operator%=(mx_float scalar); + /*! * \brief elementwise add to current space * this mutate the current NDArray * \param src the data to add @@ -211,6 +220,13 @@ class NDArray { * \return reference of self */ NDArray &operator/=(const NDArray &src); + /*! + * \brief elementwise modulo of current ndarray by another ndarray + * this mutates the current NDArray + * \param src the divisor of the modulo + * \return reference of self + */ + NDArray &operator%=(const NDArray &src); NDArray ArgmaxChannel(); /*! * \brief Do a synchronize copy from a continugous CPU memory region.
diff --git a/cpp-package/include/mxnet-cpp/ndarray.hpp b/cpp-package/include/mxnet-cpp/ndarray.hpp index 69d1082bf8fa..6157a6600cb4 100644 --- a/cpp-package/include/mxnet-cpp/ndarray.hpp +++ b/cpp-package/include/mxnet-cpp/ndarray.hpp @@ -93,6 +93,11 @@ inline NDArray NDArray::operator/(mx_float scalar) { Operator("_div_scalar")(*this, scalar).Invoke(ret); return ret; } +inline NDArray NDArray::operator%(mx_float scalar) { + NDArray ret; + Operator("_mod_scalar")(*this, scalar).Invoke(ret); + return ret; +} inline NDArray NDArray::operator+(const NDArray &rhs) { NDArray ret; Operator("_plus")(*this, rhs).Invoke(ret); @@ -113,6 +118,11 @@ inline NDArray NDArray::operator/(const NDArray &rhs) { Operator("_div")(*this, rhs).Invoke(ret); return ret; } +inline NDArray NDArray::operator%(const NDArray &rhs) { + NDArray ret; + Operator("_mod")(*this, rhs).Invoke(ret); + return ret; +} inline NDArray &NDArray::operator=(mx_float scalar) { Operator("_set_value")(scalar).Invoke(*this); return *this; @@ -133,6 +143,10 @@ inline NDArray &NDArray::operator/=(mx_float scalar) { Operator("_div_scalar")(*this, scalar).Invoke(*this); return *this; } +inline NDArray &NDArray::operator%=(mx_float scalar) { + Operator("_mod_scalar")(*this, scalar).Invoke(*this); + return *this; +} inline NDArray &NDArray::operator+=(const NDArray &rhs) { Operator("_plus")(*this, rhs).Invoke(*this); return *this; @@ -149,6 +163,10 @@ inline NDArray &NDArray::operator/=(const NDArray &rhs) { Operator("_div")(*this, rhs).Invoke(*this); return *this; } +inline NDArray &NDArray::operator%=(const NDArray &rhs) { + Operator("_mod")(*this, rhs).Invoke(*this); + return *this; +} inline NDArray NDArray::ArgmaxChannel() { NDArray ret; diff --git a/cpp-package/include/mxnet-cpp/op_suppl.h b/cpp-package/include/mxnet-cpp/op_suppl.h index 9381a1ecade9..c40449cc9f89 100644 --- a/cpp-package/include/mxnet-cpp/op_suppl.h +++ b/cpp-package/include/mxnet-cpp/op_suppl.h @@ -35,6 +35,10 @@ inline Symbol _Div(Symbol lhs, 
Symbol rhs) { return Operator("_Div")(lhs, rhs) .CreateSymbol(); } +inline Symbol _Mod(Symbol lhs, Symbol rhs) { + return Operator("_Mod")(lhs, rhs) + .CreateSymbol(); +} inline Symbol _Power(Symbol lhs, Symbol rhs) { return Operator("_Power")(lhs, rhs) .CreateSymbol(); @@ -77,6 +81,16 @@ inline Symbol _RDivScalar(mx_float scalar, Symbol rhs) { .SetParam("scalar", scalar) .CreateSymbol(); } +inline Symbol _ModScalar(Symbol lhs, mx_float scalar) { + return Operator("_ModScalar")(lhs) + .SetParam("scalar", scalar) + .CreateSymbol(); +} +inline Symbol _RModScalar(mx_float scalar, Symbol rhs) { + return Operator("_RModScalar")(rhs) + .SetParam("scalar", scalar) + .CreateSymbol(); +} inline Symbol _PowerScalar(Symbol lhs, mx_float scalar) { return Operator("_PowerScalar")(lhs) .SetParam("scalar", scalar) diff --git a/cpp-package/include/mxnet-cpp/symbol.h b/cpp-package/include/mxnet-cpp/symbol.h index 03a8409f8087..e853c2617ea4 100644 --- a/cpp-package/include/mxnet-cpp/symbol.h +++ b/cpp-package/include/mxnet-cpp/symbol.h @@ -72,11 +72,13 @@ class Symbol { Symbol operator-(const Symbol &rhs) const; Symbol operator*(const Symbol &rhs) const; Symbol operator/(const Symbol &rhs) const; + Symbol operator%(const Symbol &rhs) const; Symbol operator+(mx_float scalar) const; Symbol operator-(mx_float scalar) const; Symbol operator*(mx_float scalar) const; Symbol operator/(mx_float scalar) const; + Symbol operator%(mx_float scalar) const; Symbol Copy() const; /*! 
* \brief construct a variable Symbol @@ -252,6 +254,7 @@ Symbol operator+(mx_float lhs, const Symbol &rhs); Symbol operator-(mx_float lhs, const Symbol &rhs); Symbol operator*(mx_float lhs, const Symbol &rhs); Symbol operator/(mx_float lhs, const Symbol &rhs); +Symbol operator%(mx_float lhs, const Symbol &rhs); } // namespace cpp } // namespace mxnet #endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_SYMBOL_H_ diff --git a/cpp-package/include/mxnet-cpp/symbol.hpp b/cpp-package/include/mxnet-cpp/symbol.hpp index 40108325d594..26962ba5c99b 100644 --- a/cpp-package/include/mxnet-cpp/symbol.hpp +++ b/cpp-package/include/mxnet-cpp/symbol.hpp @@ -38,6 +38,7 @@ inline Symbol Symbol::operator+(const Symbol &rhs) const { return _Plus(*this, r inline Symbol Symbol::operator-(const Symbol &rhs) const { return _Minus(*this, rhs); } inline Symbol Symbol::operator*(const Symbol &rhs) const { return _Mul(*this, rhs); } inline Symbol Symbol::operator/(const Symbol &rhs) const { return _Div(*this, rhs); } +inline Symbol Symbol::operator%(const Symbol &rhs) const { return _Mod(*this, rhs); } inline Symbol Symbol::operator+(mx_float scalar) const { return _PlusScalar(*this, scalar); } @@ -50,6 +51,9 @@ inline Symbol Symbol::operator*(mx_float scalar) const { inline Symbol Symbol::operator/(mx_float scalar) const { return _DivScalar(*this, scalar); } +inline Symbol Symbol::operator%(mx_float scalar) const { + return _ModScalar(*this, scalar); +} inline Symbol Symbol::operator[](int index) { SymbolHandle out; MXSymbolGetOutput(GetHandle(), index, &out); @@ -337,6 +341,9 @@ inline Symbol operator*(mx_float lhs, const Symbol &rhs) { return rhs * lhs; } inline Symbol operator/(mx_float lhs, const Symbol &rhs) { return mxnet::cpp::_RDivScalar(lhs, rhs); } +inline Symbol operator%(mx_float lhs, const Symbol &rhs) { + return mxnet::cpp::_RModScalar(lhs, rhs); +} } // namespace cpp } // namespace mxnet diff --git a/docs/api/python/ndarray.md b/docs/api/python/ndarray.md index 2581c2c3354b..a782b910e656 
100644 --- a/docs/api/python/ndarray.md +++ b/docs/api/python/ndarray.md @@ -120,6 +120,8 @@ In the rest of this document, we first overview the methods provided by the NDArray.__mul__ NDArray.__div__ NDArray.__rdiv__ + NDArray.__mod__ + NDArray.__rmod__ NDArray.__pow__ ``` @@ -133,6 +135,7 @@ In the rest of this document, we first overview the methods provided by the NDArray.__isub__ NDArray.__imul__ NDArray.__idiv__ + NDArray.__imod__ ``` ### Comparison operators @@ -259,6 +262,7 @@ In the rest of this document, we first overview the methods provided by the negative multiply divide + modulo dot batch_dot add_n diff --git a/docs/api/python/symbol.md b/docs/api/python/symbol.md index 14ed06b9db9b..f99bee2bd79b 100644 --- a/docs/api/python/symbol.md +++ b/docs/api/python/symbol.md @@ -86,6 +86,8 @@ Composite multiple symbols into a new one by an operator. Symbol.__mul__ Symbol.__div__ Symbol.__rdiv__ + Symbol.__mod__ + Symbol.__rmod__ Symbol.__pow__ ``` @@ -249,6 +251,7 @@ Composite multiple symbols into a new one by an operator. 
broadcast_sub broadcast_mul broadcast_div + broadcast_mod negative dot batch_dot diff --git a/mshadow b/mshadow index eda261eef135..8db65bd081c7 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit eda261eef135a51e7388e680b295996d18d4e4d1 +Subproject commit 8db65bd081c7e243028ace93ef0acc9efc4383ba diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 8900843f5937..9ec4d47bbb81 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -206,6 +206,25 @@ def __rtruediv__(self, other): def __itruediv__(self, other): return self.__idiv__(other) + def __mod__(self, other): + """x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y) """ + return modulo(self, other) + + def __rmod__(self, other): + """x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x) """ + return modulo(other, self) + + def __imod__(self, other): + """x.__imod__(y) <=> x%=y """ + if not self.writable: + raise ValueError('trying to take modulo from a readonly NDArray') + if isinstance(other, NDArray): + return broadcast_mod(self, other, out=self) + elif isinstance(other, numeric_types): + return _internal._mod_scalar(self, float(other), out=self) + else: + raise TypeError('type %s not supported' % str(type(other))) + def __pow__(self, other): """x.__pow__(y) <=> x**y <=> mx.nd.power(x,y) """ return power(self, other) @@ -1516,6 +1535,62 @@ def divide(lhs, rhs): _internal._rdiv_scalar) # pylint: enable= no-member, protected-access +def modulo(lhs, rhs): + """Returns element-wise modulo of the input arrays with broadcasting. + + Equivalent to ``lhs % rhs`` and ``mx.nd.broadcast_mod(lhs, rhs)``. + + .. note:: + + If the corresponding dimensions of two arrays have the same size or one of them has size 1, + then the arrays are broadcastable to a common shape. + + Parameters + ---------- + lhs : scalar or array + First array in modulo. + rhs : scalar or array + Second array in modulo. + The arrays to be taken modulo.
If ``lhs.shape != rhs.shape``, they must be + broadcastable to a common shape. + + Returns + ------- + NDArray + The element-wise modulo of the input arrays. + + Examples + -------- + >>> x = mx.nd.ones((2,3))*6 + >>> y = mx.nd.ones((2,1))*4 + >>> x.asnumpy() + array([[ 6., 6., 6.], + [ 6., 6., 6.]], dtype=float32) + >>> y.asnumpy() + array([[ 4.], + [ 4.]], dtype=float32) + >>> x%5 + <NDArray 2x3 @cpu(0)> + >>> (x%5).asnumpy() + array([[ 1., 1., 1.], + [ 1., 1., 1.]], dtype=float32) + >>> (x%y).asnumpy() + array([[ 2., 2., 2.], + [ 2., 2., 2.]], dtype=float32) + >>> mx.nd.modulo(x,y).asnumpy() + array([[ 2., 2., 2.], + [ 2., 2., 2.]], dtype=float32) + """ + # pylint: disable= no-member, protected-access + return _ufunc_helper( + lhs, + rhs, + broadcast_mod, + operator.mod, + _internal._mod_scalar, + _internal._rmod_scalar) + # pylint: enable= no-member, protected-access + def power(base, exp): """Returns result of first array elements raised to powers from second array, element-wise with broadcasting. diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 14203e59862d..bd0aca65f521 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -172,6 +172,36 @@ def __rdiv__(self, other): else: raise TypeError('type %s not supported' % str(type(other))) + def __mod__(self, other): + """x.__mod__(y) <=> x%y + + Scalar input is supported. + Broadcasting is not supported. Use `broadcast_mod` instead. """ + if isinstance(other, Symbol): + return _internal._Mod(self, other) + if isinstance(other, Number): + return _internal._ModScalar(self, scalar=other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __rmod__(self, other): + """x.__rmod__(y) <=> y%x + + Only a scalar (`Number`) left operand is supported for now.
+ + Example usage: + ---------- + >>> x = mx.nd.ones((2,3))*3 + >>> y = mx.nd.ones((2,3)) + >>> x.__rmod__(y).asnumpy() + array([[ 1., 1., 1.], + [ 1., 1., 1.]], dtype=float32) + """ + if isinstance(other, Number): + return _internal._RModScalar(self, scalar=other) + else: + raise TypeError('type %s not supported' % str(type(other))) + def __truediv__(self, other): return self.__div__(other) diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala index 49eea3dc9d05..94ce0086f6e4 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/NDArray.scala @@ -880,6 +880,30 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle, NDArray.lesserEqual(this, other) } + def %(other: NDArray): NDArray = { + NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other)) + } + + def %(other: Float): NDArray = { + NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other)) + } + + def %=(other: NDArray): NDArray = { + if (!writable) { + throw new IllegalArgumentException("trying to take modulo from a readonly NDArray") + } + NDArray.genericNDArrayFunctionInvoke("_mod", Seq(this, other), Map("out" -> this)) + this + } + + def %=(other: Float): NDArray = { + if (!writable) { + throw new IllegalArgumentException("trying to take modulo from a readonly NDArray") + } + NDArray.genericNDArrayFunctionInvoke("_mod_scalar", Seq(this, other), Map("out" -> this)) + this + } + /** * Return a copied flat java array of current array (row-major). * @return A copy of array content.
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala index de60e472e76c..4e8d4c2bd9f9 100644 --- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/Symbol.scala @@ -85,6 +85,11 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) { def <=(other: Symbol): Symbol = Symbol.lesserEqual(this, other) def <=[@specialized(Int, Float, Double) V](other: V): Symbol = Symbol.lesserEqual(this, other) + def %(other: Symbol): Symbol = Symbol.createFromListedSymbols("_Mod")(Array(this, other)) + def %[@specialized(Int, Float, Double) V](other: V): Symbol = { + Symbol.createFromListedSymbols("_ModScalar")(Array(this), Map("scalar" -> other.toString)) + } + override def clone(): Symbol = { val clonedHandle = new SymbolHandleRef checkCall(_LIB.mxSymbolCopy(handle, clonedHandle)) @@ -1236,6 +1241,11 @@ class SymbolConversions[@specialized(Int, Float, Double) V](val value: V) { def <=(other: Symbol): Symbol = { other >= value } + + def %(other: Symbol): Symbol = { + Symbol.createFromListedSymbols("_RModScalar")( + Array(other), Map("scalar" -> value.toString)) + } } trait SymbolGenerator { diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h index 00dd3d0e959a..479f6f99f07a 100644 --- a/src/ndarray/ndarray_function.h +++ b/src/ndarray/ndarray_function.h @@ -41,6 +41,10 @@ struct Div : public BinaryBase { typedef mshadow::op::div mshadow_op; }; +struct Mod : public BinaryBase { + typedef op::mshadow_op::mod mshadow_op; +}; + struct ClipMin : public BinaryBase { struct mshadow_op { template diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 7e950c980e53..c63739ba5085 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -8,8 +8,13 @@ #define MXNET_OPERATOR_MSHADOW_OP_H_ #include +#include #include "special_functions-inl.h" +#ifdef __CUDACC__ +#include +#endif + 
namespace mxnet { namespace op { namespace mshadow_op { @@ -24,14 +29,14 @@ using std::isnan; struct identity { template MSHADOW_XINLINE static DType Map(DType a) { - return DType(a); + return a; } }; struct identity_grad { template MSHADOW_XINLINE static DType Map(DType a) { - return DType(DType(1.0f)); + return DType(1.0f); } }; @@ -434,15 +439,15 @@ struct abs { struct sign { template MSHADOW_XINLINE static DType Map(DType a) { - if (a < 0.0f) return DType(-DType(1.0f)); - if (a > 0.0f) return DType(DType(1.0f)); - return DType(DType(0.0f)); + if (a < 0.0f) return DType(-1.0f); + if (a > 0.0f) return DType(1.0f); + return DType(0.0f); } }; struct sign_grad { template MSHADOW_XINLINE static DType Map(DType a) { - return DType(DType(0.0f)); + return DType(0.0f); } }; /*! \brief used for generate element of power */ @@ -664,6 +669,172 @@ struct rdiv_grad { } }; +struct mod { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (b == DType(0)) { + return DType(0); + } else if (b < DType(0)) { + if (a < DType(0)) { + return DType(-::fmod(-a, -b)); + } else { + return DType(::fmod(a, -b) + (::fmod(a, -b) != DType(0) ? b : DType(0))); + } + } else { + if (a < DType(0)) { + return DType(-::fmod(-a, b) + (::fmod(-a, b) != DType(0) ? 
b : DType(0))); + } else { + return DType(::fmod(a, b)); + } + } + } +}; +#ifdef __CUDACC__ +template<> +MSHADOW_XINLINE mshadow::half::half2_t mod::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { + return a%b; +} +#endif + +struct mod_grad { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return DType(0); + } +}; +template<> +MSHADOW_XINLINE double mod_grad::Map(double a, double b) { + return 1.0f; +} +template<> +MSHADOW_XINLINE float mod_grad::Map(float a, float b) { + return 1.0f; +} +#ifdef __CUDACC__ +template<> +MSHADOW_XINLINE mshadow::half::half_t mod_grad::Map + (mshadow::half::half_t a, + mshadow::half::half_t b) { + return mshadow::half::half_t(1.0f); +} +template<> +MSHADOW_XINLINE mshadow::half::half2_t mod_grad::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { + mshadow::half::half2_t result = mshadow::half::half2_t(); +#if MSHADOW_CUDA_HALF2 + result.half2_ = ::__float2half2_rn(1.0f); +#else + result.half_t2[0] = mshadow::half::half_t(0.0f); + result.half_t2[1] = mshadow::half::half_t(1.0f); +#endif + return result; +} +#endif + +struct mod_rgrad { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return DType(0); + } +}; +template<> +MSHADOW_XINLINE double mod_rgrad::Map(double a, double b) { + return -::floor(a/b); +} +template<> +MSHADOW_XINLINE float mod_rgrad::Map(float a, float b) { + return -::floorf(a/b); +} +#ifdef __CUDACC__ +template<> +MSHADOW_XINLINE mshadow::half::half_t mod_rgrad::Map + (mshadow::half::half_t a, + mshadow::half::half_t b) { + return mshadow::half::half_t(-::floorf(static_cast(a/b))); +} +template<> +MSHADOW_XINLINE mshadow::half::half2_t mod_rgrad::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { +#if MSHADOW_CUDA_HALF2 + return mshadow::half::half2_t(__hneg2(::h2floor((a/b).half2_))); +#else + return mshadow::half::half2_t(mshadow::half::half_t(-::floorf( + static_cast(a.half_t2[0]/b.half_t2[0]))), + mshadow::half::half_t(-::floorf( + 
static_cast(a.half_t2[1]/b.half_t2[1])))); +#endif +} +#endif + +struct rmod { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (a == DType(0)) { + return DType(0); + } else if (a < DType(0)) { + if (b < DType(0)) { + return DType(-::fmod(-b, -a)); + } else { + return DType(::fmod(b, -a) + (::fmod(b, -a) != DType(0) ? a : DType(0))); + } + } else { + if (b < DType(0)) { + return DType(-::fmod(-b, a) + (::fmod(-b, a) != DType(0) ? a : DType(0))); + } else { + return DType(::fmod(b, a)); + } + } + } +}; +#ifdef __CUDACC__ +template<> +MSHADOW_XINLINE mshadow::half::half2_t rmod::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { + return b%a; +} +#endif + +struct rmod_grad { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + return DType(0); + } +}; +template<> +MSHADOW_XINLINE double rmod_grad::Map(double a, double b) { + return -::floor(b/a); +} +template<> +MSHADOW_XINLINE float rmod_grad::Map(float a, float b) { + return -::floorf(b/a); +} +#ifdef __CUDACC__ +template<> +MSHADOW_XINLINE mshadow::half::half_t rmod_grad::Map + (mshadow::half::half_t a, + mshadow::half::half_t b) { + return mshadow::half::half_t(-::floorf(static_cast(b/a))); +} +template<> +MSHADOW_XINLINE mshadow::half::half2_t rmod_grad::Map + (mshadow::half::half2_t a, + mshadow::half::half2_t b) { +#if MSHADOW_CUDA_HALF2 + return mshadow::half::half2_t(::__hneg2(::h2floor((b/a).half2_))); +#else + return mshadow::half::half2_t(mshadow::half::half_t(-::floorf( + static_cast(b.half_t2[0]/a.half_t2[0]))), + mshadow::half::half_t(-::floorf( + static_cast(b.half_t2[1]/a.half_t2[1])))); +#endif +} +#endif + struct clip { template MSHADOW_XINLINE static DType Map(DType x, DType bound) { diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc index 0d0a1d8b5df0..27a4b5f25c82 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc +++ 
b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc @@ -153,5 +153,38 @@ NNVM_REGISTER_OP(_backward_broadcast_div) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_mod) +.describe(R"code(Returns element-wise modulo of the input arrays with broadcasting. + +Example:: + + x = [[ 8., 8., 8.], + [ 8., 8., 8.]] + + y = [[ 2.], + [ 3.]] + + broadcast_mod(x, y) = [[ 0., 0., 0.], + [ 2., 2., 2.]] + +)code" ADD_FILELINE) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"}); + +NNVM_REGISTER_OP(_backward_broadcast_mod) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu index f23d3d0cbad8..ef0e679d6166 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cu @@ -37,5 +37,12 @@ NNVM_REGISTER_OP(_backward_broadcast_div) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(broadcast_mod) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_broadcast_mod) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index be4c1d88e983..1f363a114375 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -78,5 +78,21 @@ NNVM_REGISTER_OP(_backward_div) .set_attr("FCompute", 
BinaryBackwardUseIn); +MXNET_OPERATOR_REGISTER_BINARY(_mod) +.add_alias("_Mod") +.set_attr("FCompute", BinaryCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_mod"}); + +NNVM_REGISTER_OP(_backward_mod) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FCompute", BinaryBackwardUseIn); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index ff432380d6d1..6355c4e5cf01 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -40,5 +40,12 @@ NNVM_REGISTER_OP(_backward_div) .set_attr("FCompute", BinaryBackwardUseInWithHalf2); +NNVM_REGISTER_OP(_mod) +.set_attr("FCompute", BinaryComputeWithHalf2); + +NNVM_REGISTER_OP(_backward_mod) +.set_attr("FCompute", BinaryBackwardUseInWithHalf2); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc index ddbba4d10f2c..bd0b5335e3ae 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cc @@ -44,5 +44,25 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_rdiv_scalar) .set_attr_parser([](NodeAttrs* attrs) {attrs->parsed = std::stod(attrs->dict["scalar"]);}) .set_attr("FCompute", BinaryScalarBackward); +MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_mod_scalar) +.set_attr("FCompute", BinaryScalarCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_mod_scalar"}) +.add_alias("_ModScalar"); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_mod_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs* attrs) {attrs->parsed = std::stod(attrs->dict["scalar"]);}) +.set_attr("FCompute", BinaryScalarBackward); + 
+MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_rmod_scalar) +.set_attr("FCompute", BinaryScalarCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_rmod_scalar"}) +.add_alias("_RModScalar"); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_rmod_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs* attrs) {attrs->parsed = std::stod(attrs->dict["scalar"]);}) +.set_attr("FCompute", BinaryScalarBackward); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu index 356b34901670..ae19aa8a72f6 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu @@ -30,5 +30,17 @@ NNVM_REGISTER_OP(_rdiv_scalar) NNVM_REGISTER_OP(_backward_rdiv_scalar) .set_attr("FCompute", BinaryScalarBackward); +NNVM_REGISTER_OP(_mod_scalar) +.set_attr("FCompute", BinaryScalarCompute); + +NNVM_REGISTER_OP(_backward_mod_scalar) +.set_attr("FCompute", BinaryScalarBackward); + +NNVM_REGISTER_OP(_rmod_scalar) +.set_attr("FCompute", BinaryScalarCompute); + +NNVM_REGISTER_OP(_backward_rmod_scalar) +.set_attr("FCompute", BinaryScalarBackward); + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index fece5414dbc3..55c1d2488d6e 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -994,23 +994,41 @@ def gen_broadcast_data(idx): def gen_broadcast_data_int(idx): d = gen_broadcast_data(idx); - return [np.round(d[0]*100), np.round(d[1]*100)] + return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)] def gen_binary_data(dummy): ndim = np.random.randint(1, 6) shape = np.random.randint(1, 6, size=(ndim,)) return [np.random.random(shape), np.random.random(shape)] -def check_binary_op_forward(symbol, baseline, gen_data): +def gen_binary_data_int(dummy): + d = 
gen_binary_data(dummy); + return [np.round(d[0]*100).astype(int), np.round(d[1]*100).astype(int)] + +def check_binary_op_forward(symbol, baseline, gen_data, rtol=1e-3, atol=1e-5): sample_num = 200 for i in range(sample_num): d = gen_data(i) x = baseline(d[0], d[1]) y = symbol.bind(default_context(), args={'a': mx.nd.array(d[0]), 'b' : mx.nd.array(d[1])}) y.forward(is_train=True) - assert_allclose(x, y.outputs[0].asnumpy(), rtol=1e-3, atol=1e-5) - -def check_binary_op_backward(symbol, baseline, gen_data): + y = y.outputs[0].asnumpy() + idx = np.abs(x-y) > atol+rtol*np.abs(x) + if idx.any(): + print('found precision problem') + d[0] = np.broadcast_to(d[0], x.shape) + d[1] = np.broadcast_to(d[1], x.shape) + print('a: {}'.format(d[0][idx])) + print('b: {}'.format(d[1][idx])) + import struct + print('a hex: {}'.format(struct.pack('d', d[0][idx]).encode('hex'))) + print('b hex: {}'.format(struct.pack('d', np.broadcast_to(d[1], x.shape)[idx]).encode('hex'))) + print('in baseline(a, b): {}'.format(x[idx])) + print('in symbol(a, b): {}'.format(y[idx])) + print('diff: {}'.format(np.abs(x-y)[idx] - atol-rtol*np.abs(x)[idx])) + assert_allclose(y, x, rtol=rtol, atol=atol) + +def check_binary_op_backward(symbol, baseline, gen_data, rtol=1e-3, atol=1e-5): sample_num = 200 for i in range(sample_num): d = gen_data(i) @@ -1033,8 +1051,8 @@ def reduce_op(shape, x): args_grad=[y_1, y_2]) y.forward(is_train=True) y.backward([mx.nd.array(out)]) - assert_allclose(x_1, y_1.asnumpy(), rtol=1e-3, atol=1e-5) - assert_allclose(x_2, y_2.asnumpy(), rtol=1e-3, atol=1e-5) + assert_allclose(y_1.asnumpy(), x_1, rtol=rtol, atol=atol) + assert_allclose(y_2.asnumpy(), x_2, rtol=rtol, atol=atol) def test_binary_op(): a = mx.sym.Variable('a') @@ -1060,6 +1078,16 @@ def test_bdiv(a, b): check_binary_op_forward(c, lambda a, b: a / b, gen_binary_data) check_binary_op_backward(c, lambda g_out, a, b: (g_out / b, - g_out * a / (b * b)), gen_binary_data) + def test_bmod(a, b): + c = a % b + 
check_binary_op_forward(c, lambda a, b: a % b, gen_binary_data) + check_binary_op_backward(c, lambda g_out, a, b: (g_out, - g_out * (a // b)), gen_binary_data) + + def test_bmod_int(a, b): + c = mx.sym.cast(a, dtype='int32') % mx.sym.cast(b, dtype='int32') + check_binary_op_forward(c, lambda a, b: a % b, gen_binary_data_int) + check_binary_op_backward(c, lambda g_out, a, b: (np.zeros_like(a), np.zeros_like(b)), gen_binary_data_int) + def test_bpow(a, b): c = a ** b check_binary_op_forward(c, lambda a, b: a ** b, gen_binary_data) @@ -1075,6 +1103,8 @@ def test_bneq(a, b): test_bminus(a, b) test_bmul(a, b) test_bdiv(a, b) + test_bmod(a, b) + test_bmod_int(a, b) test_bpow(a, b) test_bneq(a, b) @@ -1102,6 +1132,16 @@ def test_bdiv(a, b): check_binary_op_forward(c, lambda a, b: a / b, gen_broadcast_data) check_binary_op_backward(c, lambda g_out, a, b: (g_out / b, - g_out * a / (b * b)), gen_broadcast_data) + def test_bmod(a, b): + c = mx.sym.broadcast_mod(a, b) + check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data, atol=1) + check_binary_op_backward(c, lambda g_out, a, b: (g_out, - g_out * (a // b)), gen_broadcast_data, atol=1) + + def test_bmod_int(a, b): + c = mx.sym.broadcast_mod(mx.sym.cast(a, dtype='int32'), mx.sym.cast(b, dtype='int32')) + check_binary_op_forward(c, lambda a, b: a % b, gen_broadcast_data_int) + check_binary_op_backward(c, lambda g_out, a, b: (np.zeros_like(a), np.zeros_like(b)), gen_broadcast_data_int) + def test_bpow(a, b): c = mx.sym.broadcast_power(a, b) check_binary_op_forward(c, lambda a, b: a ** b, gen_broadcast_data) @@ -1117,6 +1157,8 @@ def test_bequal(a, b): test_bminus(a, b) test_bmul(a, b) test_bdiv(a, b) + test_bmod(a, b) + test_bmod_int(a, b) test_bpow(a, b) test_bequal(a, b) @@ -3276,7 +3318,7 @@ def test_laop(): # Currently no support for GPU. Will be added soon # so keep these tests here in this file and activate - # gpu-testing when it is ready. + # gpu-testing when it is ready. 
dev = default_context() if dev.device_type == 'gpu': return @@ -3293,37 +3335,37 @@ def test_laop(): shape2 = (3, 2) shape3 = (3, 3) shape4 = (2, 2) - #Ensure that ithis tests don't get changed by other calls to random. + #Ensure that ithis tests don't get changed by other calls to random. np.random.seed(42) - data_in1 = np.random.uniform(1, 10, shape1) - data_in2 = np.random.uniform(1, 10, shape2) - data_in3 = np.random.uniform(1, 10, shape3) - data_in4 = np.random.uniform(1, 10, shape4) + data_in1 = np.random.uniform(1, 10, shape1) + data_in2 = np.random.uniform(1, 10, shape2) + data_in3 = np.random.uniform(1, 10, shape3) + data_in4 = np.random.uniform(1, 10, shape4) # Check all transpositions of gemm operator. - data_in1_t = np.transpose(data_in1) - data_in2_t = np.transpose(data_in2) + data_in1_t = np.transpose(data_in1) + data_in2_t = np.transpose(data_in2) res_gemm = 4*np.dot(data_in1,data_in2)+7*data_in4 - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7) + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7) check_symbolic_forward(test_gemm, [data_in1, data_in2, data_in4], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in2, data_in4], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1_t,data_in2_t)+7*data_in3 - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1, transpose_b = 1) + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1, transpose_b = 1) check_symbolic_forward(test_gemm, [data_in1, data_in2, data_in3], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in2, data_in3], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1_t,data_in1)+7*data_in3 - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1) + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_a = 1) 
check_symbolic_forward(test_gemm, [data_in1, data_in1, data_in3], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in1, data_in3], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1,data_in1_t)+7*data_in4 - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_b = 1) + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7, transpose_b = 1) check_symbolic_forward(test_gemm, [data_in1, data_in1, data_in4], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in1, data_in4], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - # Check batch of gemm. + # Check batch of gemm. a = np.tile(np.array(data_in1).flatten(),3) a = np.reshape(a,(3,1,2,3)) b = np.tile(np.array(data_in2).flatten(),3) @@ -3333,34 +3375,34 @@ def test_laop(): r = 4*np.dot(data_in1,data_in2)+7*data_in4 r = np.tile(r.flatten(),3) r = np.reshape(r,(3,1,2,2)) - test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7) + test_gemm = mx.sym.linalg_gemm(data1, data2, data3, alpha = 4, beta = 7) check_symbolic_forward(test_gemm, [a, b, c], [r]) if grad_check == 1: check_numeric_gradient(test_gemm, [a, b, c], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - # Check gemm2 operator same way as gemm. + # Check gemm2 operator same way as gemm. 
res_gemm = 4*np.dot(data_in1,data_in2) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4) check_symbolic_forward(test_gemm, [data_in1, data_in2], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in2], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1_t, data_in2_t) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_a = 1, transpose_b = 1) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_a = 1, transpose_b = 1) check_symbolic_forward(test_gemm, [data_in1, data_in2], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in2], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1_t,data_in1) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_a = 1) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_a = 1) check_symbolic_forward(test_gemm, [data_in1, data_in1], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in1], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) res_gemm = 4*np.dot(data_in1,data_in1_t) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_b = 1) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4, transpose_b = 1) check_symbolic_forward(test_gemm, [data_in1, data_in1], [res_gemm]) if grad_check == 1: check_numeric_gradient(test_gemm, [data_in1, data_in1], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - # Check batch of gemm2. + # Check batch of gemm2. 
a = np.tile(np.array(data_in1).flatten(),3) a = np.reshape(a,(3,1,2,3)) b = np.tile(np.array(data_in2).flatten(),3) @@ -3368,12 +3410,12 @@ def test_laop(): r = 4*np.dot(data_in1,data_in2) r = np.tile(r.flatten(),3) r = np.reshape(r,(3,1,2,2)) - test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4) + test_gemm = mx.sym.linalg_gemm2(data1, data2, alpha = 4) check_symbolic_forward(test_gemm, [a, b], [r]) if grad_check == 1: check_numeric_gradient(test_gemm, [a, b], numeric_eps=1e-3, rtol=1e-1, atol=1e-1) - # Now test all the other operators. + # Now test all the other operators. # Tests with trivial 1x1 matrices. shape = (4, 4, 1, 1 ) @@ -3404,7 +3446,7 @@ def test_laop(): if grad_check == 1: check_numeric_gradient(test_trmm, [trian_in,data_in], atol = 0.02, rtol = 2.0) # test sumlogdiag - res_sumlogdiag = np.reshape(np.log(data_in),(4,4)) + res_sumlogdiag = np.reshape(np.log(data_in),(4,4)) test_sumlogdiag = mx.sym.linalg_sumlogdiag(data1) check_symbolic_forward(test_sumlogdiag, [data_in], [res_sumlogdiag]) if grad_check == 1: @@ -3417,9 +3459,9 @@ def test_laop(): inv = [ 2.98333, 0.01667, 2.65, -0.83333, 0.01667, 0.05, 0.05, 0, 2.65, 0.05, 2.5, -0.75, -0.83333, 0, -0.75, 0.25 ] ident = [ 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1 ] - # Tests for numeric gradients for potrf/potri/trmm/trsm are suppressed by default - # as they are very volatile and may often report false negatives which - # have to be excluded by manual inspection. + # Tests for numeric gradients for potrf/potri/trmm/trsm are suppressed by default + # as they are very volatile and may often report false negatives which + # have to be excluded by manual inspection. 
grad_check = 0 # test potrf @@ -3430,7 +3472,7 @@ def test_laop(): check_symbolic_forward(test_potrf, [a], [r]) if grad_check == 1: check_numeric_gradient(test_potrf, [a], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - + #test potri a = np.tile(np.array(trian),3) a = np.reshape(a,(3,1,4,4)) @@ -3450,7 +3492,7 @@ def test_laop(): check_symbolic_forward(test_trsm, [a,b], [r]) if grad_check == 1: check_numeric_gradient(test_trsm, [a,b], numeric_eps=1e-3, rtol=1e-2, atol=1e-1) - + test_trsm2 = mx.sym.linalg_trsm(data1,data2,alpha = -2, rightside = 1, transpose = 1) r = -2*np.reshape(np.array(trian),(4,4)) r = np.reshape(np.tile(np.reshape(r,(16)),3),(3,1,4,4))