From 8401b96f8b243d4de9b1a5787060019edce9facb Mon Sep 17 00:00:00 2001 From: Anthony Fieroni Date: Tue, 18 Jan 2022 14:34:38 +0200 Subject: [PATCH] Return uint32_t multiplication in base uint Signed-off-by: Anthony Fieroni --- src/arith_uint256.cpp | 33 +++++++++++++++++++++++++ src/arith_uint256.h | 42 +++++++++++++++++++++++++------- src/test/arith_uint256_tests.cpp | 10 +++++--- 3 files changed, 73 insertions(+), 12 deletions(-) diff --git a/src/arith_uint256.cpp b/src/arith_uint256.cpp index b092977a4d..b67e49dcfc 100644 --- a/src/arith_uint256.cpp +++ b/src/arith_uint256.cpp @@ -54,6 +54,39 @@ base_uint& base_uint::operator>>=(unsigned int shift) return *this; } +template +base_uint& base_uint::operator*=(uint32_t b32) +{ + uint64_t carry = 0; + for (int i = 0; i < WIDTH; i++) { + uint64_t n = carry + (uint64_t)b32 * pn[i]; + pn[i] = n & 0xffffffff; + carry = n >> 32; + } + return *this; +} + +template +base_uint& base_uint::operator*=(int32_t b32) +{ + (*this) *= uint32_t(b32); + return *this; +} + +template +base_uint& base_uint::operator*=(int64_t b64) +{ + (*this) *= base_uint(b64); + return *this; +} + +template +base_uint& base_uint::operator*=(uint64_t b64) +{ + (*this) *= base_uint(b64); + return *this; +} + template base_uint& base_uint::operator*=(const base_uint& b) { diff --git a/src/arith_uint256.h b/src/arith_uint256.h index bdcb0e43a9..e0086f9306 100644 --- a/src/arith_uint256.h +++ b/src/arith_uint256.h @@ -42,15 +42,7 @@ class base_uint { static_assert(BITS/32 > 0 && BITS%32 == 0, "Template parameter BITS must be a positive multiple of 32."); - for (int i = 0; i < WIDTH; i++) - pn[i] = b.pn[i]; - } - - base_uint& operator=(const base_uint& b) - { - for (int i = 0; i < WIDTH; i++) - pn[i] = b.pn[i]; - return *this; + (*this) = b; } base_uint(uint64_t b) @@ -63,6 +55,12 @@ class base_uint pn[i] = 0; } + template + base_uint(const base_uint& b) + { + (*this) = b; + } + explicit base_uint(const std::string& str); const base_uint operator~() const @@ -84,6 +82,24 @@ class base_uint double getdouble() const; + template + base_uint& operator=(const base_uint& b) + { + auto bits = std::min(BITS, BITS1); + for (int i = 0; i < bits; i++) + pn[i] = b.pn[i]; + for (int i = bits; i < BITS; i++) + pn[i] = 0; + return *this; + } + + base_uint& operator=(const base_uint& b) + { + for (int i = 0; i < WIDTH; i++) + pn[i] = b.pn[i]; + return *this; + } + base_uint& operator=(uint64_t b) { pn[0] = (unsigned int)b; @@ -165,6 +181,10 @@ class base_uint return *this; } + base_uint& operator*=(int32_t b32); + base_uint& operator*=(uint32_t b32); + base_uint& operator*=(int64_t b64); + base_uint& operator*=(uint64_t b64); base_uint& operator*=(const base_uint& b); base_uint& operator/=(const base_uint& b); @@ -214,6 +234,10 @@ class base_uint friend inline const base_uint operator^(const base_uint& a, const base_uint& b) { return base_uint(a) ^= b; } friend inline const base_uint operator>>(const base_uint& a, int shift) { return base_uint(a) >>= shift; } friend inline const base_uint operator<<(const base_uint& a, int shift) { return base_uint(a) <<= shift; } + friend inline const base_uint operator*(const base_uint& a, int32_t b) { return base_uint(a) *= b; } + friend inline const base_uint operator*(const base_uint& a, uint32_t b) { return base_uint(a) *= b; } + friend inline const base_uint operator*(const base_uint& a, int64_t b) { return base_uint(a) *= b; } + friend inline const base_uint operator*(const base_uint& a, uint64_t b) { return base_uint(a) *= b; } friend inline bool operator==(const base_uint& a, const base_uint& b) { return memcmp(a.pn, b.pn, sizeof(a.pn)) == 0; } friend inline bool operator!=(const base_uint& a, const base_uint& b) { return memcmp(a.pn, b.pn, sizeof(a.pn)) != 0; } friend inline bool operator>(const base_uint& a, const base_uint& b) { return a.CompareTo(b) > 0; } diff --git a/src/test/arith_uint256_tests.cpp b/src/test/arith_uint256_tests.cpp index 7133572c9d..2320b558dc 100644 --- a/src/test/arith_uint256_tests.cpp +++ b/src/test/arith_uint256_tests.cpp @@ -336,10 +336,14 @@ BOOST_AUTO_TEST_CASE( multiply ) BOOST_CHECK(MaxL * MaxL == OneL); - BOOST_CHECK((R1L * 0) == 0); - BOOST_CHECK((R1L * 1) == R1L); - BOOST_CHECK((R1L * 3).ToString() == "7759b1c0ed14047f961ad09b20ff83687876a0181a367b813634046f91def7d4"); + BOOST_CHECK((R1L * 0u) == 0u); + BOOST_CHECK((R1L * 1u) == R1L); + BOOST_CHECK((R1L * 3u).ToString() == "7759b1c0ed14047f961ad09b20ff83687876a0181a367b813634046f91def7d4"); BOOST_CHECK((R2L * 0x87654321UL).ToString() == "23f7816e30c4ae2017257b7a0fa64d60402f5234d46e746b61c960d09a26d070"); + + BOOST_CHECK((R1L * 10000000) == (R1L * arith_uint256(10000000))); + BOOST_CHECK((R1L * COIN) == (R1L * arith_uint256(COIN))); + BOOST_CHECK((R1L * (COIN * COIN)) == (R1L * arith_uint256(COIN * COIN))); } BOOST_AUTO_TEST_CASE( divide )