diff --git a/EigenRand/Dists/Basic.h b/EigenRand/Dists/Basic.h index cc4a0da..75e296d 100644 --- a/EigenRand/Dists/Basic.h +++ b/EigenRand/Dists/Basic.h @@ -32,6 +32,11 @@ namespace Eigen class GenBase { public: + /** + * @brief Return a reference to the derived type. + */ + DerivedGen &derived() { return static_cast(*this); } + /** * @brief generate random values from its distribution * @@ -48,7 +53,7 @@ namespace Eigen generate(Index rows, Index cols, Urng&& urng) { return { - rows, cols, { std::forward(urng), static_cast(*this) } + rows, cols, { std::forward(urng), derived() } }; } @@ -67,7 +72,7 @@ namespace Eigen generateLike(const Derived& o, Urng&& urng) { return { - o.rows(), o.cols(), { std::forward(urng), static_cast(*this) } + o.rows(), o.cols(), { std::forward(urng), derived() } }; } }; @@ -76,6 +81,11 @@ namespace Eigen class UnaryGenBase { public: + /** + * @brief Return a reference to the derived type. + */ + DerivedGen &derived() { return static_cast(*this); } + /** * @brief generate random values from its distribution * @@ -93,7 +103,7 @@ namespace Eigen > generate(Urng&& urng, const ArrayBase& a) { return { - a, { std::forward(urng), static_cast(*this) } + a, { std::forward(urng), derived() } }; } }; @@ -102,6 +112,11 @@ namespace Eigen class BinaryGenBase { public: + /** + * @brief Return a reference to the derived type. + */ + DerivedGen &derived() { return static_cast(*this); } + /** * @brief generate random values from its distribution * @@ -119,7 +134,7 @@ namespace Eigen > generate(Urng&& urng, const ArrayBase& a, const ArrayBase& b) { return { - a, b, { std::forward(urng), static_cast(*this) } + a, b, { std::forward(urng), derived() } }; } @@ -131,7 +146,7 @@ namespace Eigen { return { a, { a.rows(), a.cols(), internal::scalar_constant_op{ b } }, - { std::forward(urng), static_cast(*this) } + { std::forward(urng), derived() } }; } @@ -143,7 +158,7 @@ namespace Eigen { return { { b.rows(), b.cols(), internal::scalar_constant_op{ a } }, b, - { std::forward(urng), static_cast(*this) } + { std::forward(urng), derived() } }; } }; @@ -159,10 +174,15 @@ namespace Eigen class MvVecGenBase { public: + /** + * @brief Return a reference to the derived type. + */ + DerivedGen &derived() { return static_cast(*this); } + /** * @brief returns the dimensions of vectors to be generated */ - Index dims() const { return static_cast(*this).dims(); } + Index dims() const { return derived().dims(); } /** * @brief generates multiple samples at once @@ -176,7 +196,7 @@ namespace Eigen template inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng, Index samples) { - return static_cast(*this).generatr(std::forward(urng), samples); + return derived().generate(std::forward(urng), samples); } /** @@ -189,7 +209,7 @@ namespace Eigen template inline Matrix<_Scalar, Dim, 1> generate(Urng&& urng) { - return static_cast(*this).generatr(std::forward(urng)); + return derived().generate(std::forward(urng)); } }; @@ -204,10 +224,15 @@ namespace Eigen class MvMatGenBase { public: + /** + * @brief Return a reference to the derived type. + */ + DerivedGen &derived() { return static_cast(*this); } + /** * @brief returns the dimensions of matrices to be generated */ - Index dims() const { return static_cast(*this).dims(); } + Index dims() const { return derived().dims(); } /** * @brief generates multiple samples at once @@ -221,7 +246,7 @@ namespace Eigen template inline Matrix<_Scalar, Dim, -1> generate(Urng&& urng, Index samples) { - return static_cast(*this).generate(std::forward(urng), samples); + return derived().generate(std::forward(urng), samples); } /** @@ -234,7 +259,7 @@ namespace Eigen template inline Matrix<_Scalar, Dim, Dim> generate(Urng&& urng) { - return static_cast(*this).generate(std::forward(urng)); + return derived().generate(std::forward(urng)); } }; diff --git a/EigenRand/MvDists/MvNormal.h b/EigenRand/MvDists/MvNormal.h index e3555bf..b13c4a7 100644 --- a/EigenRand/MvDists/MvNormal.h +++ b/EigenRand/MvDists/MvNormal.h @@ -115,6 +115,20 @@ namespace Eigen } }; + namespace detail { + template + constexpr bool either_is_dynamic() { + return (MatrixBase::RowsAtCompileTime == Eigen::Dynamic) || + (MatrixBase::RowsAtCompileTime == Eigen::Dynamic); + } + + template + constexpr bool normal_check_dims() { + return (either_is_dynamic() || MatrixBase::RowsAtCompileTime == MatrixBase::RowsAtCompileTime) && + MatrixBase::RowsAtCompileTime == MatrixBase::ColsAtCompileTime; + } + } + /** * @brief helper function constructing Eigen::Rand::MvNormal * @@ -132,8 +146,7 @@ namespace Eigen "Derived::Scalar must be the same with `mean` and `cov`'s Scalar." ); static_assert( - MatrixBase::RowsAtCompileTime == MatrixBase::RowsAtCompileTime && - MatrixBase::RowsAtCompileTime == MatrixBase::ColsAtCompileTime, + detail::normal_check_dims(), "assert: mean.RowsAtCompileTime == cov.RowsAtCompileTime && cov.RowsAtCompileTime == cov.ColsAtCompileTime" ); return { mean, cov }; @@ -156,8 +169,7 @@ namespace Eigen "Derived::Scalar must be the same with `mean` and `lt`'s Scalar." ); static_assert( - MatrixBase::RowsAtCompileTime == MatrixBase::RowsAtCompileTime && - MatrixBase::RowsAtCompileTime == MatrixBase::ColsAtCompileTime, + detail::normal_check_dims(), "assert: mean.RowsAtCompileTime == lt.RowsAtCompileTime && lt.RowsAtCompileTime == lt.ColsAtCompileTime" ); return { mean, lt, lower_triangular };