From 26572802743101f160f2d07556edfc162896115e Mon Sep 17 00:00:00 2001 From: t11s Date: Sun, 11 Sep 2022 00:33:18 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Optimized=20sqrt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gas-snapshot | 20 ++++---- src/utils/FixedPointMathLib.sol | 81 ++++++++++++++++++--------------- 2 files changed, 54 insertions(+), 47 deletions(-) diff --git a/.gas-snapshot b/.gas-snapshot index b867b503..df75175a 100644 --- a/.gas-snapshot +++ b/.gas-snapshot @@ -240,17 +240,17 @@ ERC721Test:testTransferFromApproveAll() (gas: 92898) ERC721Test:testTransferFromApproveAll(uint256,address) (runs: 256, μ: 93182, ~: 93182) ERC721Test:testTransferFromSelf() (gas: 64776) ERC721Test:testTransferFromSelf(uint256,address) (runs: 256, μ: 65061, ~: 65061) -FixedPointMathLibTest:testDifferentiallyFuzzSqrt(uint256) (runs: 256, μ: 14009, ~: 4421) +FixedPointMathLibTest:testDifferentiallyFuzzSqrt(uint256) (runs: 256, μ: 13827, ~: 4181) FixedPointMathLibTest:testDivWadDown() (gas: 841) -FixedPointMathLibTest:testDivWadDown(uint256,uint256) (runs: 256, μ: 717, ~: 820) +FixedPointMathLibTest:testDivWadDown(uint256,uint256) (runs: 256, μ: 718, ~: 820) FixedPointMathLibTest:testDivWadDownEdgeCases() (gas: 446) FixedPointMathLibTest:testDivWadUp() (gas: 1003) -FixedPointMathLibTest:testDivWadUp(uint256,uint256) (runs: 256, μ: 807, ~: 972) +FixedPointMathLibTest:testDivWadUp(uint256,uint256) (runs: 256, μ: 809, ~: 972) FixedPointMathLibTest:testDivWadUpEdgeCases() (gas: 462) -FixedPointMathLibTest:testFailDivWadDownOverflow(uint256,uint256) (runs: 256, μ: 444, ~: 419) +FixedPointMathLibTest:testFailDivWadDownOverflow(uint256,uint256) (runs: 256, μ: 443, ~: 419) FixedPointMathLibTest:testFailDivWadDownZeroDenominator() (gas: 342) FixedPointMathLibTest:testFailDivWadDownZeroDenominator(uint256) (runs: 256, μ: 397, ~: 397) -FixedPointMathLibTest:testFailDivWadUpOverflow(uint256,uint256) (runs: 256, μ: 399, ~: 374) +FixedPointMathLibTest:testFailDivWadUpOverflow(uint256,uint256) (runs: 256, μ: 398, ~: 374) FixedPointMathLibTest:testFailDivWadUpZeroDenominator() (gas: 342) FixedPointMathLibTest:testFailDivWadUpZeroDenominator(uint256) (runs: 256, μ: 396, ~: 396) FixedPointMathLibTest:testFailMulDivDownOverflow(uint256,uint256,uint256) (runs: 256, μ: 437, ~: 414) @@ -274,11 +274,11 @@ FixedPointMathLibTest:testMulWadUp() (gas: 981) FixedPointMathLibTest:testMulWadUp(uint256,uint256) (runs: 256, μ: 835, ~: 1073) FixedPointMathLibTest:testMulWadUpEdgeCases() (gas: 959) FixedPointMathLibTest:testRPow() (gas: 2164) -FixedPointMathLibTest:testSqrt() (gas: 3168) -FixedPointMathLibTest:testSqrt(uint256) (runs: 256, μ: 1102, ~: 1109) -FixedPointMathLibTest:testSqrtBack(uint256) (runs: 256, μ: 18628, ~: 340) -FixedPointMathLibTest:testSqrtBackHashed(uint256) (runs: 256, μ: 71987, ~: 72172) -FixedPointMathLibTest:testSqrtBackHashedSingle() (gas: 71489) +FixedPointMathLibTest:testSqrt() (gas: 2580) +FixedPointMathLibTest:testSqrt(uint256) (runs: 256, μ: 997, ~: 1013) +FixedPointMathLibTest:testSqrtBack(uint256) (runs: 256, μ: 15210, ~: 340) +FixedPointMathLibTest:testSqrtBackHashed(uint256) (runs: 256, μ: 59040, ~: 59500) +FixedPointMathLibTest:testSqrtBackHashedSingle() (gas: 58937) LibStringTest:testDifferentiallyFuzzToString(uint256,bytes) (runs: 256, μ: 20749, ~: 8925) LibStringTest:testToString() (gas: 10047) LibStringTest:testToStringDirty() (gas: 8123) diff --git a/src/utils/FixedPointMathLib.sol b/src/utils/FixedPointMathLib.sol index 6920c438..95d80de4 100644 --- a/src/utils/FixedPointMathLib.sol +++ b/src/utils/FixedPointMathLib.sol @@ -165,43 +165,51 @@ library FixedPointMathLib { function sqrt(uint256 x) internal pure returns (uint256 z) { assembly { - // Start off with z at 1. - z := 1 + let y := x // We start y at x, which will help us make our initial estimate. - // Used below to help find a nearby power of 2. - let y := x + z := 181 // The "correct" value is 1, but this saves a multiplication later. - // Find the lowest power of 2 that is at least sqrt(x). - if iszero(lt(y, 0x100000000000000000000000000000000)) { - y := shr(128, y) // Like dividing by 2 ** 128. - z := shl(64, z) // Like multiplying by 2 ** 64. - } - if iszero(lt(y, 0x10000000000000000)) { - y := shr(64, y) // Like dividing by 2 ** 64. - z := shl(32, z) // Like multiplying by 2 ** 32. - } - if iszero(lt(y, 0x100000000)) { - y := shr(32, y) // Like dividing by 2 ** 32. - z := shl(16, z) // Like multiplying by 2 ** 16. - } - if iszero(lt(y, 0x10000)) { - y := shr(16, y) // Like dividing by 2 ** 16. - z := shl(8, z) // Like multiplying by 2 ** 8. + // This segment is to get a reasonable initial estimate for the Babylonian method. With a bad + // start, the correct # of bits increases ~linearly each iteration instead of ~quadratically. + + // We check y >= 2^(k + 8) but shift right by k bits + // each branch to ensure that if x >= 256, then y >= 256. + if iszero(lt(y, 0x10000000000000000000000000000000000)) { + y := shr(128, y) + z := shl(64, z) } - if iszero(lt(y, 0x100)) { - y := shr(8, y) // Like dividing by 2 ** 8. - z := shl(4, z) // Like multiplying by 2 ** 4. + if iszero(lt(y, 0x1000000000000000000)) { + y := shr(64, y) + z := shl(32, z) } - if iszero(lt(y, 0x10)) { - y := shr(4, y) // Like dividing by 2 ** 4. - z := shl(2, z) // Like multiplying by 2 ** 2. + if iszero(lt(y, 0x10000000000)) { + y := shr(32, y) + z := shl(16, z) } - if iszero(lt(y, 0x8)) { - // Equivalent to 2 ** z. - z := shl(1, z) + if iszero(lt(y, 0x1000000)) { + y := shr(16, y) + z := shl(8, z) } - // Shifting right by 1 is like dividing by 2. + // Goal was to get z*z*y within a small factor of x. More iterations could + // get y in a tighter range. Currently, we will have y in [256, 256*2^16). + // We ensured y >= 256 so that the relative difference between y and y+1 is small. + // That's not possible if x < 256 but we can just verify those cases exhaustively. + + // Now, z*z*y <= x < z*z*(y+1), and y <= 2^(16+8), and either y >= 256, or x < 256. + // Correctness can be checked exhaustively for x < 256, so we assume y >= 256. + // Then z*sqrt(y) is within sqrt(257)/sqrt(256) of sqrt(x), or about 20bps. + + // For s in the range [1/256, 256], the estimate f(s) = (181/1024) * (s+1) is in the range + // (1/2.84 * sqrt(s), 2.84 * sqrt(s)), with largest error when s = 1 and when s = 256 or 1/256. + + // Since y is in [256, 256*2^16), let a = y/65536, so that a is in [1/256, 256). Then we can estimate + // sqrt(y) using sqrt(65536) * 181/1024 * (a + 1) = 181/4 * (y + 65536)/65536 = 181 * (y + 65536)/2^18. + + // There is no overflow risk here since y < 2^136 after the first branch above. + z := shr(18, mul(z, add(y, 65536))) // A mul() is saved from starting z at 181. + + // Given the worst case multiplicative error of 2.84 above, 7 iterations should be enough. z := shr(1, add(z, div(x, z))) z := shr(1, add(z, div(x, z))) z := shr(1, add(z, div(x, z))) @@ -210,13 +218,12 @@ library FixedPointMathLib { z := shr(1, add(z, div(x, z))) z := shr(1, add(z, div(x, z))) - // Compute a rounded down version of z. - let zRoundDown := div(x, z) - - // If zRoundDown is smaller, use it. - if lt(zRoundDown, z) { - z := zRoundDown - } + // If x+1 is a perfect square, the Babylonian method cycles between + // floor(sqrt(x)) and ceil(sqrt(x)). This statement ensures we return floor. + // See: https://en.wikipedia.org/wiki/Integer_square_root#Using_only_integer_division + // Since the ceil is rare, we save gas on the assignment and repeat division in the rare case. + // If you don't care whether the floor or ceil square root is returned, you can remove this statement. + z := sub(z, lt(div(x, z), z)) } }