diff --git a/contracts/libraries/Simulate.sol b/contracts/libraries/Simulate.sol new file mode 100644 index 000000000..c94674d66 --- /dev/null +++ b/contracts/libraries/Simulate.sol @@ -0,0 +1,196 @@ +// SPDX-License-Identifier: GPL-2.0-or-later +pragma solidity ^0.8.0; + +import {SwapMath} from './SwapMath.sol'; +import {SafeCast} from './SafeCast.sol'; +import {TickMath} from './TickMath.sol'; +import {TickBitmap} from './TickBitmap.sol'; +import {BitMath} from './BitMath.sol'; + +import {UniswapV3Pool} from '../UniswapV3Pool.sol'; + +import {IUniswapV3Pool} from '../interfaces/IUniswapV3Pool.sol'; + +/// @title Library for simulating swaps in a view function. +library Simulate { + using TickBitmapExtended for function(int16) external view returns (uint256); + using SafeCast for uint256; + + struct Cache { + // price at the beginning of the swap + uint160 sqrtPriceX96Start; + // tick at the beginning of the swap + int24 tickStart; + // liquidity at the beginning of the swap + uint128 liquidityStart; + // the lp fee of the pool + uint24 fee; + // the tick spacing of the pool + int24 tickSpacing; + } + + struct State { + // the amount remaining to be swapped in/out of the input/output asset + int256 amountSpecifiedRemaining; + // the amount already swapped out/in of the output/input asset + int256 amountCalculated; + // current sqrt(price) + uint160 sqrtPriceX96; + // the tick associated with the current price + int24 tick; + // the current liquidity in range + uint128 liquidity; + } + + function simulateSwap( + IUniswapV3Pool pool, + bool zeroForOne, + int256 amountSpecified, + uint160 sqrtPriceLimitX96 + ) internal view returns (int256 amount0, int256 amount1) { + require(amountSpecified != 0, 'AS'); + + (uint160 sqrtPriceX96, int24 tick, , , , , ) = pool.slot0(); + + require( + zeroForOne + ? sqrtPriceLimitX96 < sqrtPriceX96 && sqrtPriceLimitX96 > TickMath.MIN_SQRT_RATIO + : sqrtPriceLimitX96 > sqrtPriceX96 && sqrtPriceLimitX96 < TickMath.MAX_SQRT_RATIO, + 'SPL' + ); + + Cache memory cache = Cache({ + sqrtPriceX96Start: sqrtPriceX96, + tickStart: tick, + liquidityStart: pool.liquidity(), + fee: pool.fee(), + tickSpacing: pool.tickSpacing() + }); + + bool exactInput = amountSpecified > 0; + + State memory state = State({ + amountSpecifiedRemaining: amountSpecified, + amountCalculated: 0, + sqrtPriceX96: cache.sqrtPriceX96Start, + tick: cache.tickStart, + liquidity: cache.liquidityStart + }); + + while (state.amountSpecifiedRemaining != 0 && state.sqrtPriceX96 != sqrtPriceLimitX96) { + UniswapV3Pool.StepComputations memory step; + + step.sqrtPriceStartX96 = state.sqrtPriceX96; + + (step.tickNext, step.initialized) = pool.tickBitmap.nextInitializedTickWithinOneWord( + state.tick, + cache.tickSpacing, + zeroForOne + ); + + if (step.tickNext < TickMath.MIN_TICK) { + step.tickNext = TickMath.MIN_TICK; + } else if (step.tickNext > TickMath.MAX_TICK) { + step.tickNext = TickMath.MAX_TICK; + } + + step.sqrtPriceNextX96 = TickMath.getSqrtRatioAtTick(step.tickNext); + + (state.sqrtPriceX96, step.amountIn, step.amountOut, step.feeAmount) = SwapMath.computeSwapStep( + state.sqrtPriceX96, + (zeroForOne ? step.sqrtPriceNextX96 < sqrtPriceLimitX96 : step.sqrtPriceNextX96 > sqrtPriceLimitX96) + ? sqrtPriceLimitX96 + : step.sqrtPriceNextX96, + state.liquidity, + state.amountSpecifiedRemaining, + cache.fee + ); + + if (exactInput) { + unchecked { + state.amountSpecifiedRemaining -= (step.amountIn + step.feeAmount).toInt256(); + } + state.amountCalculated -= step.amountOut.toInt256(); + } else { + unchecked { + state.amountSpecifiedRemaining += step.amountOut.toInt256(); + } + state.amountCalculated += (step.amountIn + step.feeAmount).toInt256(); + } + + if (state.sqrtPriceX96 == step.sqrtPriceNextX96) { + if (step.initialized) { + (, int128 liquidityNet, , , , , , ) = pool.ticks(step.tickNext); + unchecked { + if (zeroForOne) liquidityNet = -liquidityNet; + } + + state.liquidity = liquidityNet < 0 + ? state.liquidity - uint128(-liquidityNet) + : state.liquidity + uint128(liquidityNet); + } + + unchecked { + state.tick = zeroForOne ? step.tickNext - 1 : step.tickNext; + } + } else if (state.sqrtPriceX96 != step.sqrtPriceStartX96) { + // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved + state.tick = TickMath.getTickAtSqrtRatio(state.sqrtPriceX96); + } + } + + (amount0, amount1) = zeroForOne == exactInput + ? (amountSpecified - state.amountSpecifiedRemaining, state.amountCalculated) + : (state.amountCalculated, amountSpecified - state.amountSpecifiedRemaining); + } +} + +/// @title Wrapper for TickBitmap that uses a function pointer instead of a bitmap +library TickBitmapExtended { + /// @notice Returns the next initialized tick contained in the same word (or adjacent word) as the tick that is either + /// to the left (less than or equal to) or right (greater than) of the given tick + /// @param self The function which fetches ticks + /// @param tick The starting tick + /// @param tickSpacing The spacing between usable ticks + /// @param lte Whether to search for the next initialized tick to the left (less than or equal to the starting tick) + /// @return next The next initialized or uninitialized tick up to 256 ticks away from the current tick + /// @return initialized Whether the next tick is initialized, as the function only searches within up to 256 ticks + function nextInitializedTickWithinOneWord( + function(int16) external view returns (uint256) self, + int24 tick, + int24 tickSpacing, + bool lte + ) internal view returns (int24 next, bool initialized) { + unchecked { + int24 compressed = tick / tickSpacing; + if (tick < 0 && tick % tickSpacing != 0) compressed--; // round towards negative infinity + + if (lte) { + (int16 wordPos, uint8 bitPos) = TickBitmap.position(compressed); + // all the 1s at or to the right of the current bitPos + uint256 mask = (1 << bitPos) - 1 + (1 << bitPos); + uint256 masked = self(wordPos) & mask; + + // if there are no initialized ticks to the right of or at the current tick, return rightmost in the word + initialized = masked != 0; + // overflow/underflow is possible, but prevented externally by limiting both tickSpacing and tick + next = initialized + ? (compressed - int24(uint24(bitPos - BitMath.mostSignificantBit(masked)))) * tickSpacing + : (compressed - int24(uint24(bitPos))) * tickSpacing; + } else { + // start from the word of the next tick, since the current tick state doesn't matter + (int16 wordPos, uint8 bitPos) = TickBitmap.position(compressed + 1); + // all the 1s at or to the left of the bitPos + uint256 mask = ~((1 << bitPos) - 1); + uint256 masked = self(wordPos) & mask; + + // if there are no initialized ticks to the left of the current tick, return leftmost in the word + initialized = masked != 0; + // overflow/underflow is possible, but prevented externally by limiting both tickSpacing and tick + next = initialized + ? (compressed + 1 + int24(uint24(BitMath.leastSignificantBit(masked) - bitPos))) * tickSpacing + : (compressed + 1 + int24(uint24(type(uint8).max - bitPos))) * tickSpacing; + } + } + } +} diff --git a/contracts/libraries/TickBitmap.sol b/contracts/libraries/TickBitmap.sol index 1cc5bb9ce..2239135ce 100644 --- a/contracts/libraries/TickBitmap.sol +++ b/contracts/libraries/TickBitmap.sol @@ -11,7 +11,7 @@ library TickBitmap { /// @param tick The tick for which to compute the position /// @return wordPos The key in the mapping containing the word in which the bit is stored /// @return bitPos The bit position in the word where the flag is stored - function position(int24 tick) private pure returns (int16 wordPos, uint8 bitPos) { + function position(int24 tick) internal pure returns (int16 wordPos, uint8 bitPos) { unchecked { wordPos = int16(tick >> 8); bitPos = uint8(int8(tick % 256));