forked from OpenZeppelin/openzeppelin-contracts
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from cairoeth/single-lookup-log2
Use single lookup and binary search for log2
- Loading branch information
Showing
1 changed file
with
39 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -537,100 +537,47 @@ library Math { | |
* @dev Return the log in base 2 of a positive value rounded towards zero. | ||
* Returns 0 if given 0. | ||
*/ | ||
function log2(uint256 x) internal pure returns (uint256) { | ||
// Efficient branchless algorithm to compute floor(log2(x)). The algorithm works as follows: | ||
// | ||
// 1. First round down `x` to the closest power of 2 using a modified version of Seander's `Round up | ||
// to the next power of 2` algorithm; the version used here is modified to round down instead of up. | ||
// ref: https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 | ||
// | ||
// 2. Compute `n = x mod 255`, the `n` is used to split the equation in two parts that can be more efficiently | ||
// computed using the `Logarithm's Product Rule`: | ||
// - lemma: log2(x) == log2(x / n) + log2(n) | ||
// this `n` guarantees that `x / n` is a power of two multiple of 256, we use this fact to build a lookup table | ||
// that calculates `log2(x / n)` and `log2(n)` very efficiently. | ||
// | ||
// 3. Compute `log2n = log2(n)`, given `x % 255` is a power of two, there are only 8 possible values for `n`, so | ||
// `log2(n)` can be easily computed using a lookup table, here we use the opcode `BYTE` to lookup a 32-byte word. | ||
// | ||
// 4. Compute `log2x_n = log2(x / n)`. We use the fact that `x / n` is a power of two multiple of 256, so there are exactly | ||
// 32 possible distinct values for `log2x_n`, respectively: 0, 8, 16, ..., 240, 248. This allows an efficient lookup | ||
// table to be created using a single 32-byte word; we extract `log2(x/n)` from the table by computing: | ||
// log2x_n = log2(x / n) = (x / n * table) >> 248 | ||
// | ||
// 5. The final result is simply the sum of `log2n` and `log2x_n` calculated previously: | ||
// log2(x) = log2(x / n) + log2(n) = log2x_n + log2n | ||
// | ||
// @author Lohann Ferreira <[email protected]> | ||
unchecked { | ||
// Round `x` down to the closest power of 2 using a modified version of Seander's algorithm. | ||
// Reference: https://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2 | ||
x |= x >> 1; | ||
x |= x >> 2; | ||
x |= x >> 4; | ||
x |= x >> 8; | ||
x |= x >> 16; | ||
x |= x >> 32; | ||
x |= x >> 64; | ||
x |= x >> 128; | ||
// Note: `x = 0` results in 1 here, although the closest power of two is actually zero (given | ||
// `2**-infinity == 0`). Strictly we should do `(x >> 1) + toUint(x > 0)` instead, but this is not | ||
// necessary given that this function returns 0 for both `log2(0)` and `log2(1)` anyway. | ||
x = (x >> 1) + 1; | ||
|
||
// 1. Compute `n = x mod 255`, given `x` is power of two the resulting `n` can only be one of the | ||
// following values: 1, 2, 4, 8, 16, 32, 64 or 128. | ||
uint256 n = x % 255; | ||
|
||
// 2. Compute `log2n = log2(n)` using the OPCODE BYTE to lookup a 32-byte word, which we use as table. | ||
// Notice the last index in our table is 31, but `n` can be greater than 31, so first we map `n` | ||
// into a unique index between 0~31 by computing `n % 11`, as demonstrated below. | ||
|
||
// log2(n) Lookup Table | ||
// | n = x % 255 | index = n % 11 | table[index] = log2(n) | | ||
// |-------------|----------------|---------------------------| | ||
// | 1 | 1 % 11 == 1 | table[ 1] = log2( 1) = 0 | | ||
// | 2 | 2 % 11 == 2 | table[ 2] = log2( 2) = 1 | | ||
// | 4 | 4 % 11 == 4 | table[ 4] = log2( 4) = 2 | | ||
// | 8 | 8 % 11 == 8 | table[ 8] = log2( 8) = 3 | | ||
// | 16 | 16 % 11 == 5 | table[ 5] = log2( 16) = 4 | | ||
// | 32 | 32 % 11 == 10 | table[10] = log2( 32) = 5 | | ||
// | 64 | 64 % 11 == 9 | table[ 9] = log2( 64) = 6 | | ||
// | 128 | 128 % 11 == 7 | table[ 7] = log2(128) = 7 | | ||
// Below we perform the table lookup, the table stores the result of `log2(n)` | ||
// in the corresponding byte at `index`. | ||
uint256 log2n; | ||
assembly ("memory-safe") { | ||
log2n := byte(mod(n, 11), 0x0000010002040007030605000000000000000000000000000000000000000000) | ||
} | ||
|
||
// 4. Compute `log2x_n = log2(x / n)`, notice `x / n` is a power of two multiple of 256, we use this | ||
// fact to build a very efficient lookup table using a single 32-byte word, described below: | ||
|
||
// log2(x/n) Lookup Table | ||
// | x | x / n | log2x_n = log2(x/n) | index = log2(x/n) / 8 | | ||
// |-----------------|-------|---------------------|-----------------------| | ||
// | 1 ≤ x < 2⁸ | 1 | log2(1 ) == 0 | 0 / 8 == 0 | | ||
// | 2⁸ ≤ x < 2¹⁶ | 2⁸ | log2(2⁸ ) == 8 | 8 / 8 == 1 | | ||
// | 2¹⁶ ≤ x < 2²⁴ | 2¹⁶ | log2(2¹⁶ ) == 16 | 16 / 8 == 2 | | ||
// | 2²⁴ ≤ x < 2³² | 2²⁴ | log2(2²⁴ ) == 24 | 24 / 8 == 3 | | ||
// | ... | ... | ... | ... | | ||
// | 2²³² ≤ x < 2²⁴⁰ | 2²³² | log2(2²³²) == 232 | 232 / 8 == 29 | | ||
// | 2²⁴⁰ ≤ x < 2²⁴⁸ | 2²⁴⁰ | log2(2²⁴⁰) == 240 | 240 / 8 == 30 | | ||
// | 2²⁴⁸ ≤ x < 2²⁵⁶ | 2²⁴⁸ | log2(2²⁴⁸) == 248 | 248 / 8 == 31 | | ||
function log2(uint256 value) internal pure returns (uint256 result) { | ||
assembly { | ||
// If value has upper 128 bits set, log2 result is at least 128 | ||
result := shl(7, lt(0xffffffffffffffffffffffffffffffff, value)) | ||
// If upper 64 bits of 128-bit half set, add 64 to result | ||
result := or(result, shl(6, lt(0xffffffffffffffff, shr(result, value)))) | ||
// If upper 32 bits of 64-bit half set, add 32 to result | ||
result := or(result, shl(5, lt(0xffffffff, shr(result, value)))) | ||
// If upper 16 bits of 32-bit half set, add 16 to result | ||
result := or(result, shl(4, lt(0xffff, shr(result, value)))) | ||
// If upper 8 bits of 16-bit half set, add 8 to result | ||
result := or(result, shl(3, lt(0xff, shr(result, value)))) | ||
// If upper 4 bits of 8-bit half set, add 4 to result | ||
result := or(result, shl(2, lt(0xf, shr(result, value)))) | ||
// shr(result, value) shifts value right by the current result, leaving only its remaining (at most 4) significant bits, | ||
// i.e. a number in the range 0-15. byte(...) uses that number as an index into this lookup table: | ||
// | ||
// It is important to note that `value * (x / n)` is equal to `value << (8 * index)`. We use this fact | ||
// to build a single 32-byte word where `table[index] = log2(x/n)`. Multiplying the table by `x/n` moves the | ||
// result to the most significant byte, from which we can then extract `log2(x/n)` by shifting right. | ||
// | x (4 bits) | index | table[index] = MSB position | | ||
// |------------|---------|-----------------------------| | ||
// | 0000 | 0 | table[0] = 0 | | ||
// | 0001 | 1 | table[1] = 0 | | ||
// | 0010 | 2 | table[2] = 1 | | ||
// | 0011 | 3 | table[3] = 1 | | ||
// | 0100 | 4 | table[4] = 2 | | ||
// | 0101 | 5 | table[5] = 2 | | ||
// | 0110 | 6 | table[6] = 2 | | ||
// | 0111 | 7 | table[7] = 2 | | ||
// | 1000 | 8 | table[8] = 3 | | ||
// | 1001 | 9 | table[9] = 3 | | ||
// | 1010 | 10 | table[10] = 3 | | ||
// | 1011 | 11 | table[11] = 3 | | ||
// | 1100 | 12 | table[12] = 3 | | ||
// | 1101 | 13 | table[13] = 3 | | ||
// | 1110 | 14 | table[14] = 3 | | ||
// | 1111 | 15 | table[15] = 3 | | ||
// | ||
// log2x_n = log2(x / n) = (x / n * table) >> 248 | ||
uint256 log2x_n = x >> log2n; | ||
log2x_n *= 0x0008101820283038404850586068707880889098a0a8b0b8c0c8d0d8e0e8f0f8; | ||
log2x_n >>= 248; | ||
|
||
// 5. Sum both values to get the final result: | ||
// log2(x) = log2(x / n) + log2(n) = log2x_n + log2n | ||
return log2x_n + log2n; | ||
// The lookup table is represented as a 32-byte value with the MSB positions for 0-15 in the first (most significant) 16 bytes, since `byte(i, word)` indexes from the most significant byte. | ||
result := or( | ||
result, | ||
byte(shr(result, value), 0x0000010102020202030303030303030300000000000000000000000000000000) | ||
) | ||
} | ||
} | ||
|
||
|