diff --git a/contracts/solidity/Utils/Math.sol b/contracts/solidity/Utils/Math.sol index b06d023..22c7e4b 100644 --- a/contracts/solidity/Utils/Math.sol +++ b/contracts/solidity/Utils/Math.sol @@ -10,7 +10,10 @@ library Math { * Constants */ // This is equal to 1 in our calculations - uint public constant ONE = 0x10000000000000000; + uint public constant ONE_SHIFT = 64; + uint public constant ONE = 0x10000000000000000; + uint public constant LN2 = 0xb17217f7d1cf79ac; + uint public constant LOG2_E = 0x171547652b82fe177; /* * Public functions @@ -23,76 +26,119 @@ library Math { constant returns (uint) { - /* This is equivalent to ln(2) */ - uint ln2 = 0xb17217f7d1cf79ac; - uint y = x * ONE / ln2; - uint shift = 2**(y / ONE); - uint z = y % ONE; + // Transform so that e^x = 2^x + x = x * ONE / LN2; + uint shift = x / ONE; + + // 2^x = 2^whole(x) * 2^frac(x) + // ^^^^^^^^^^ is a bit shift + // so Taylor expand on z = frac(x) + uint z = x % ONE; + + // 2^x = 1 + (ln 2) x + (ln 2)^2/2! x^2 + ... + // + // Can generate the z coefficients using mpmath and the following lines + // >>> from mpmath import mp + // >>> mp.dps = 100 + // >>> ONE = 0x10000000000000000 + // >>> print('\n'.join(hex(int(mp.log(2)**i / mp.factorial(i) * ONE)) for i in range(1, 7))) + // 0xb17217f7d1cf79ab + // 0x3d7f7bff058b1d50 + // 0xe35846b82505fc5 + // 0x276556df749cee5 + // 0x5761ff9e299cc4 + // 0xa184897c363c3 + uint zpow = z; uint result = ONE; - result += 0xb172182739bc0e46 * zpow / ONE; - zpow = zpow * z / ONE; - result += 0x3d7f78a624cfb9b5 * zpow / ONE; - zpow = zpow * z / ONE; - result += 0xe359bcfeb6e4531 * zpow / ONE; - zpow = zpow * z / ONE; - result += 0x27601df2fc048dc * zpow / ONE; - zpow = zpow * z / ONE; - result += 0x5808a728816ee8 * zpow / ONE; - zpow = zpow * z / ONE; - result += 0x95dedef350bc9 * zpow / ONE; - result += 0x16aee6e8ef; - return shift * result; - } - - /// @dev Returns natural logarithm value of given x - /// @param x x - /// @return Returns ln(x) - function ln(uint x) - public - constant - returns (uint) - { - uint log2e = 0x171547652b82fe177; - // binary search for floor(log2(x)) - uint ilog2 = floorLog2(x); - // lagrange interpolation for log2 - uint z = x / (2**ilog2); - uint zpow = ONE; - uint const = ONE * 10; - uint result = const; - result -= 0x443b9c5adb08cc45f * zpow / ONE; + result += 0xb17217f7d1cf79ab * zpow / ONE; zpow = zpow * z / ONE; - result += 0xf0a52590f17c71a3f * zpow / ONE; + result += 0x3d7f7bff058b1d50 * zpow / ONE; zpow = zpow * z / ONE; - result -= 0x2478f22e787502b023 * zpow / ONE; + result += 0xe35846b82505fc5 * zpow / ONE; zpow = zpow * z / ONE; - result += 0x48c6de1480526b8d4c * zpow / ONE; + result += 0x276556df749cee5 * zpow / ONE; zpow = zpow * z / ONE; - result -= 0x70c18cae824656408c * zpow / ONE; + result += 0x5761ff9e299cc4 * zpow / ONE; zpow = zpow * z / ONE; - result += 0x883c81ec0ce7abebb2 * zpow / ONE; + result += 0xa184897c363c3 * zpow / ONE; zpow = zpow * z / ONE; - result -= 0x81814da94fe52ca9f5 * zpow / ONE; + result += 0xffe5fe2c4586 * zpow / ONE; zpow = zpow * z / ONE; - result += 0x616361924625d1acf5 * zpow / ONE; + result += 0x162c0223a5c8 * zpow / ONE; zpow = zpow * z / ONE; - result -= 0x39f9a16fb9292a608d * zpow / ONE; + result += 0x1b5253d395e * zpow / ONE; zpow = zpow * z / ONE; - result += 0x1b3049a5740b21d65f * zpow / ONE; + result += 0x1e4cf5158b * zpow / ONE; zpow = zpow * z / ONE; - result -= 0x9ee1408bd5ad96f3e * zpow / ONE; + result += 0x1e8cac735 * zpow / ONE; zpow = zpow * z / ONE; - result += 0x2c465c91703b7a7f4 * zpow / ONE; + result += 0x1c3bd650 * zpow / ONE; zpow = zpow * z / ONE; - result -= 0x918d2d5f045a4d63 * zpow / ONE; + result += 0x1816193 * zpow / ONE; zpow = zpow * z / ONE; - result += 0x14ca095145f44f78 * zpow / ONE; + result += 0x131496 * zpow / ONE; zpow = zpow * z / ONE; - result -= 0x1d806fc412c1b99 * zpow / ONE; + result += 0xe1b7 * zpow / ONE; zpow = zpow * z / ONE; - result += 0x13950b4e1e89cc * zpow / ONE; - return (ilog2 * ONE + result - const) * ONE / log2e; + result += 0x9c7 * zpow / ONE; + return result << shift; + } + + /// @dev Returns natural logarithm value of given x + /// @param x x + /// @return Returns ln(x) + function ln(uint x) + public + constant + returns (int) + { + require(x > 0); + + // binary search for floor(log2(x)) + int ilog2 = floorLog2(x); + + int z; + if(ilog2 < 0) + z = int(x << uint(ilog2)); + else + z = int(x >> uint(ilog2)); + + int zpow = int(ONE); + int const = int(ONE) * 10; + int result = const; + result -= 0x443b9c5adb08cc45f * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result += 0xf0a52590f17c71a3f * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result -= 0x2478f22e787502b023 * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result += 0x48c6de1480526b8d4c * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result -= 0x70c18cae824656408c * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result += 0x883c81ec0ce7abebb2 * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result -= 0x81814da94fe52ca9f5 * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result += 0x616361924625d1acf5 * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result -= 0x39f9a16fb9292a608d * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result += 0x1b3049a5740b21d65f * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result -= 0x9ee1408bd5ad96f3e * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result += 0x2c465c91703b7a7f4 * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result -= 0x918d2d5f045a4d63 * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result += 0x14ca095145f44f78 * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result -= 0x1d806fc412c1b99 * zpow / int(ONE); + zpow = zpow * z / int(ONE); + result += 0x13950b4e1e89cc * zpow / int(ONE); + return (ilog2 * int(ONE) + result - const) * int(ONE) / int(LOG2_E); } /// @dev Returns base 2 logarithm value of given x @@ -101,14 +147,13 @@ library Math { function floorLog2(uint x) public constant - returns (uint lo) + returns (int lo) { - lo = 0; - uint y = x / ONE; - uint hi = 191; - uint mid = (hi + lo) / 2; + lo = -133; + int hi = 133; + int mid = (hi + lo) / 2; while ((lo + 1) != hi) { - if (y < 2**mid) + if (mid < 0 && x << uint(-mid) < ONE || mid >= 0 && x >> uint(mid) < ONE ) hi = mid; else lo = mid; diff --git a/contracts/tests/utils/test_math.py b/contracts/tests/utils/test_math.py index 59c838a..c9c285e 100644 --- a/contracts/tests/utils/test_math.py +++ b/contracts/tests/utils/test_math.py @@ -1,5 +1,19 @@ -from ..abstract_test import AbstractTestContract +from itertools import chain +from functools import partial + import math +import random + +from ethereum.tester import TransactionFailed + +from ..abstract_test import AbstractTestContract + +if hasattr(math, 'isclose'): + isclose = math.isclose +else: + # PEP 485 + def isclose(a, b, *, rel_tol=1e-9, abs_tol=0.0): + return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) class TestContract(AbstractTestContract): @@ -12,15 +26,32 @@ def __init__(self, *args, **kwargs): self.math = self.create_contract('Utils/Math.sol') def test(self): + ONE = 0x10000000000000000 + RELATIVE_TOLERANCE = 2.0**-1 + MAX_POWER = math.floor(math.log((2**256 - 1) / ONE) * ONE) + # LN - x = 2 - self.assertAlmostEqual(self.math.ln(x * 2 ** 64) / 2.0 ** 64, math.log(x), places=2) + self.assertRaises(TransactionFailed, partial(self.math.ln, 0)) + for x in chain( + (1, 2**254-1), + (random.randrange(1, ONE) for _ in range(10)), + (random.randrange(ONE, 2**256) for _ in range(10)), + ): + actual, expected = self.math.ln(x) / ONE, math.log(x / ONE) + assert isclose(actual, expected, rel_tol=RELATIVE_TOLERANCE) + # EXP - x = 10 - self.assertAlmostEqual(self.math.exp(x * 2 ** 64) / 2.0 ** 64, math.exp(x), places=2) + for x in chain( + (0, MAX_POWER), + (random.randrange(MAX_POWER) for _ in range(10)), + ): + actual, expected = self.math.exp(x) / ONE, math.exp(x / ONE) + assert isclose(actual, expected, rel_tol=RELATIVE_TOLERANCE) + # Safe to add self.assertFalse(self.math.safeToAdd(2**256 - 1, 1)) self.assertTrue(self.math.safeToAdd(1, 1)) + # Safe to subtract self.assertFalse(self.math.safeToSubtract(1, 2)) self.assertTrue(self.math.safeToSubtract(1, 1))