Skip to content

Commit

Permalink
⚡️ Optimized sqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
transmissions11 committed Sep 11, 2022
1 parent 9266c69 commit 2657280
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 47 deletions.
20 changes: 10 additions & 10 deletions .gas-snapshot
Original file line number Diff line number Diff line change
Expand Up @@ -240,17 +240,17 @@ ERC721Test:testTransferFromApproveAll() (gas: 92898)
ERC721Test:testTransferFromApproveAll(uint256,address) (runs: 256, μ: 93182, ~: 93182)
ERC721Test:testTransferFromSelf() (gas: 64776)
ERC721Test:testTransferFromSelf(uint256,address) (runs: 256, μ: 65061, ~: 65061)
FixedPointMathLibTest:testDifferentiallyFuzzSqrt(uint256) (runs: 256, μ: 14009, ~: 4421)
FixedPointMathLibTest:testDifferentiallyFuzzSqrt(uint256) (runs: 256, μ: 13827, ~: 4181)
FixedPointMathLibTest:testDivWadDown() (gas: 841)
FixedPointMathLibTest:testDivWadDown(uint256,uint256) (runs: 256, μ: 717, ~: 820)
FixedPointMathLibTest:testDivWadDown(uint256,uint256) (runs: 256, μ: 718, ~: 820)
FixedPointMathLibTest:testDivWadDownEdgeCases() (gas: 446)
FixedPointMathLibTest:testDivWadUp() (gas: 1003)
FixedPointMathLibTest:testDivWadUp(uint256,uint256) (runs: 256, μ: 807, ~: 972)
FixedPointMathLibTest:testDivWadUp(uint256,uint256) (runs: 256, μ: 809, ~: 972)
FixedPointMathLibTest:testDivWadUpEdgeCases() (gas: 462)
FixedPointMathLibTest:testFailDivWadDownOverflow(uint256,uint256) (runs: 256, μ: 444, ~: 419)
FixedPointMathLibTest:testFailDivWadDownOverflow(uint256,uint256) (runs: 256, μ: 443, ~: 419)
FixedPointMathLibTest:testFailDivWadDownZeroDenominator() (gas: 342)
FixedPointMathLibTest:testFailDivWadDownZeroDenominator(uint256) (runs: 256, μ: 397, ~: 397)
FixedPointMathLibTest:testFailDivWadUpOverflow(uint256,uint256) (runs: 256, μ: 399, ~: 374)
FixedPointMathLibTest:testFailDivWadUpOverflow(uint256,uint256) (runs: 256, μ: 398, ~: 374)
FixedPointMathLibTest:testFailDivWadUpZeroDenominator() (gas: 342)
FixedPointMathLibTest:testFailDivWadUpZeroDenominator(uint256) (runs: 256, μ: 396, ~: 396)
FixedPointMathLibTest:testFailMulDivDownOverflow(uint256,uint256,uint256) (runs: 256, μ: 437, ~: 414)
Expand All @@ -274,11 +274,11 @@ FixedPointMathLibTest:testMulWadUp() (gas: 981)
FixedPointMathLibTest:testMulWadUp(uint256,uint256) (runs: 256, μ: 835, ~: 1073)
FixedPointMathLibTest:testMulWadUpEdgeCases() (gas: 959)
FixedPointMathLibTest:testRPow() (gas: 2164)
FixedPointMathLibTest:testSqrt() (gas: 3168)
FixedPointMathLibTest:testSqrt(uint256) (runs: 256, μ: 1102, ~: 1109)
FixedPointMathLibTest:testSqrtBack(uint256) (runs: 256, μ: 18628, ~: 340)
FixedPointMathLibTest:testSqrtBackHashed(uint256) (runs: 256, μ: 71987, ~: 72172)
FixedPointMathLibTest:testSqrtBackHashedSingle() (gas: 71489)
FixedPointMathLibTest:testSqrt() (gas: 2580)
FixedPointMathLibTest:testSqrt(uint256) (runs: 256, μ: 997, ~: 1013)
FixedPointMathLibTest:testSqrtBack(uint256) (runs: 256, μ: 15210, ~: 340)
FixedPointMathLibTest:testSqrtBackHashed(uint256) (runs: 256, μ: 59040, ~: 59500)
FixedPointMathLibTest:testSqrtBackHashedSingle() (gas: 58937)
LibStringTest:testDifferentiallyFuzzToString(uint256,bytes) (runs: 256, μ: 20749, ~: 8925)
LibStringTest:testToString() (gas: 10047)
LibStringTest:testToStringDirty() (gas: 8123)
Expand Down
81 changes: 44 additions & 37 deletions src/utils/FixedPointMathLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -165,43 +165,51 @@ library FixedPointMathLib {

function sqrt(uint256 x) internal pure returns (uint256 z) {
assembly {
// Start off with z at 1.
z := 1
let y := x // We start y at x, which will help us make our initial estimate.

// Used below to help find a nearby power of 2.
let y := x
z := 181 // The "correct" value is 1, but this saves a multiplication later.

// Find the lowest power of 2 that is at least sqrt(x).
if iszero(lt(y, 0x100000000000000000000000000000000)) {
y := shr(128, y) // Like dividing by 2 ** 128.
z := shl(64, z) // Like multiplying by 2 ** 64.
}
if iszero(lt(y, 0x10000000000000000)) {
y := shr(64, y) // Like dividing by 2 ** 64.
z := shl(32, z) // Like multiplying by 2 ** 32.
}
if iszero(lt(y, 0x100000000)) {
y := shr(32, y) // Like dividing by 2 ** 32.
z := shl(16, z) // Like multiplying by 2 ** 16.
}
if iszero(lt(y, 0x10000)) {
y := shr(16, y) // Like dividing by 2 ** 16.
z := shl(8, z) // Like multiplying by 2 ** 8.
// This segment is to get a reasonable initial estimate for the Babylonian method. With a bad
// start, the correct # of bits increases ~linearly each iteration instead of ~quadratically.

// We check y >= 2^(k + 8) but shift right by k bits
// each branch to ensure that if x >= 256, then y >= 256.
if iszero(lt(y, 0x10000000000000000000000000000000000)) {
y := shr(128, y)
z := shl(64, z)
}
if iszero(lt(y, 0x100)) {
y := shr(8, y) // Like dividing by 2 ** 8.
z := shl(4, z) // Like multiplying by 2 ** 4.
if iszero(lt(y, 0x1000000000000000000)) {
y := shr(64, y)
z := shl(32, z)
}
if iszero(lt(y, 0x10)) {
y := shr(4, y) // Like dividing by 2 ** 4.
z := shl(2, z) // Like multiplying by 2 ** 2.
if iszero(lt(y, 0x10000000000)) {
y := shr(32, y)
z := shl(16, z)
}
if iszero(lt(y, 0x8)) {
// Equivalent to 2 ** z.
z := shl(1, z)
if iszero(lt(y, 0x1000000)) {
y := shr(16, y)
z := shl(8, z)
}

// Shifting right by 1 is like dividing by 2.
// Goal was to get z*z*y within a small factor of x. More iterations could
// get y in a tighter range. Currently, we will have y in [256, 256*2^16).
// We ensured y >= 256 so that the relative difference between y and y+1 is small.
// That's not possible if x < 256 but we can just verify those cases exhaustively.

// Now, z*z*y <= x < z*z*(y+1), and y <= 2^(16+8), and either y >= 256, or x < 256.
// Correctness can be checked exhaustively for x < 256, so we assume y >= 256.
// Then z*sqrt(y) is within sqrt(257)/sqrt(256) of sqrt(x), or about 20bps.

// For s in the range [1/256, 256], the estimate f(s) = (181/1024) * (s+1) is in the range
// (1/2.84 * sqrt(s), 2.84 * sqrt(s)), with largest error when s = 1 and when s = 256 or 1/256.

// Since y is in [256, 256*2^16), let a = y/65536, so that a is in [1/256, 256). Then we can estimate
// sqrt(y) using sqrt(65536) * 181/1024 * (a + 1) = 181/4 * (y + 65536)/65536 = 181 * (y + 65536)/2^18.

// There is no overflow risk here since y < 2^136 after the first branch above.
z := shr(18, mul(z, add(y, 65536))) // A mul() is saved from starting z at 181.

// Given the worst case multiplicative error of 2.84 above, 7 iterations should be enough.
z := shr(1, add(z, div(x, z)))
z := shr(1, add(z, div(x, z)))
z := shr(1, add(z, div(x, z)))
Expand All @@ -210,13 +218,12 @@ library FixedPointMathLib {
z := shr(1, add(z, div(x, z)))
z := shr(1, add(z, div(x, z)))

// Compute a rounded down version of z.
let zRoundDown := div(x, z)

// If zRoundDown is smaller, use it.
if lt(zRoundDown, z) {
z := zRoundDown
}
// If x+1 is a perfect square, the Babylonian method cycles between
// floor(sqrt(x)) and ceil(sqrt(x)). This statement ensures we return floor.
// See: https://en.wikipedia.org/wiki/Integer_square_root#Using_only_integer_division
// Since the ceil is rare, we save gas on the assignment and repeat division in the rare case.
// If you don't care whether the floor or ceil square root is returned, you can remove this statement.
z := sub(z, lt(div(x, z), z))
}
}

Expand Down

0 comments on commit 2657280

Please sign in to comment.