Skip to content

Commit

Permalink
added delta overflow checks (Uniswap#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
aadams authored and hyunchel committed Feb 21, 2024
1 parent 1d00c91 commit 8c9c83f
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/Pool.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import {PoolManager} from "../src/PoolManager.sol";
import {Position} from "../src/libraries/Position.sol";
import {TickMath} from "../src/libraries/TickMath.sol";
import {TickBitmap} from "../src/libraries/TickBitmap.sol";
import {LiquidityAmounts} from "./utils/LiquidityAmounts.sol";
import {Constants} from "./utils/Constants.sol";
import {SafeCast} from "../src/libraries/SafeCast.sol";

contract PoolTest is Test {
using Pool for Pool.State;
Expand Down Expand Up @@ -67,6 +70,19 @@ contract PoolTest is Test {
vm.expectRevert(
abi.encodeWithSelector(TickBitmap.TickMisaligned.selector, params.tickUpper, params.tickSpacing)
);
} else {
// We need the assumptions above to calculate this
uint256 maxInt128InTypeU256 = uint256(uint128(Constants.MAX_UINT128));
(uint256 amount0, uint256 amount1) = LiquidityAmounts.getAmountsForLiquidity(
sqrtPriceX96,
TickMath.getSqrtRatioAtTick(params.tickLower),
TickMath.getSqrtRatioAtTick(params.tickUpper),
uint128(params.liquidityDelta)
);

if ((amount0 > maxInt128InTypeU256) || (amount1 > maxInt128InTypeU256)) {
vm.expectRevert(abi.encodeWithSelector(SafeCast.SafeCastOverflow.selector));
}
}

params.owner = address(this);
Expand Down
134 changes: 134 additions & 0 deletions test/utils/LiquidityAmounts.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.20;

import "../../src/libraries/FullMath.sol";
import "../../src/libraries/FixedPoint96.sol";

/// @title Liquidity amount functions
/// @notice Provides functions for computing liquidity amounts from token amounts and prices
library LiquidityAmounts {
/// @notice Downcasts uint256 to uint128
/// @param x The uint258 to be downcasted
/// @return y The passed value, downcasted to uint128
function toUint128(uint256 x) private pure returns (uint128 y) {
require((y = uint128(x)) == x);
}

/// @notice Computes the amount of liquidity received for a given amount of token0 and price range
/// @dev Calculates amount0 * (sqrt(upper) * sqrt(lower)) / (sqrt(upper) - sqrt(lower))
/// @param sqrtRatioAX96 A sqrt price representing the first tick boundary
/// @param sqrtRatioBX96 A sqrt price representing the second tick boundary
/// @param amount0 The amount0 being sent in
/// @return liquidity The amount of returned liquidity
function getLiquidityForAmount0(uint160 sqrtRatioAX96, uint160 sqrtRatioBX96, uint256 amount0)
internal
pure
returns (uint128 liquidity)
{
if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96);
uint256 intermediate = FullMath.mulDiv(sqrtRatioAX96, sqrtRatioBX96, FixedPoint96.Q96);
return toUint128(FullMath.mulDiv(amount0, intermediate, sqrtRatioBX96 - sqrtRatioAX96));
}

/// @notice Computes the amount of liquidity received for a given amount of token1 and price range
/// @dev Calculates amount1 / (sqrt(upper) - sqrt(lower)).
/// @param sqrtRatioAX96 A sqrt price representing the first tick boundary
/// @param sqrtRatioBX96 A sqrt price representing the second tick boundary
/// @param amount1 The amount1 being sent in
/// @return liquidity The amount of returned liquidity
function getLiquidityForAmount1(uint160 sqrtRatioAX96, uint160 sqrtRatioBX96, uint256 amount1)
internal
pure
returns (uint128 liquidity)
{
if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96);
return toUint128(FullMath.mulDiv(amount1, FixedPoint96.Q96, sqrtRatioBX96 - sqrtRatioAX96));
}

/// @notice Computes the maximum amount of liquidity received for a given amount of token0, token1, the current
/// pool prices and the prices at the tick boundaries
/// @param sqrtRatioX96 A sqrt price representing the current pool prices
/// @param sqrtRatioAX96 A sqrt price representing the first tick boundary
/// @param sqrtRatioBX96 A sqrt price representing the second tick boundary
/// @param amount0 The amount of token0 being sent in
/// @param amount1 The amount of token1 being sent in
/// @return liquidity The maximum amount of liquidity received
function getLiquidityForAmounts(
uint160 sqrtRatioX96,
uint160 sqrtRatioAX96,
uint160 sqrtRatioBX96,
uint256 amount0,
uint256 amount1
) internal pure returns (uint128 liquidity) {
if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96);

if (sqrtRatioX96 <= sqrtRatioAX96) {
liquidity = getLiquidityForAmount0(sqrtRatioAX96, sqrtRatioBX96, amount0);
} else if (sqrtRatioX96 < sqrtRatioBX96) {
uint128 liquidity0 = getLiquidityForAmount0(sqrtRatioX96, sqrtRatioBX96, amount0);
uint128 liquidity1 = getLiquidityForAmount1(sqrtRatioAX96, sqrtRatioX96, amount1);

liquidity = liquidity0 < liquidity1 ? liquidity0 : liquidity1;
} else {
liquidity = getLiquidityForAmount1(sqrtRatioAX96, sqrtRatioBX96, amount1);
}
}

/// @notice Computes the amount of token0 for a given amount of liquidity and a price range
/// @param sqrtRatioAX96 A sqrt price representing the first tick boundary
/// @param sqrtRatioBX96 A sqrt price representing the second tick boundary
/// @param liquidity The liquidity being valued
/// @return amount0 The amount of token0
function getAmount0ForLiquidity(uint160 sqrtRatioAX96, uint160 sqrtRatioBX96, uint128 liquidity)
internal
pure
returns (uint256 amount0)
{
if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96);

return FullMath.mulDiv(
uint256(liquidity) << FixedPoint96.RESOLUTION, sqrtRatioBX96 - sqrtRatioAX96, sqrtRatioBX96
) / sqrtRatioAX96;
}

/// @notice Computes the amount of token1 for a given amount of liquidity and a price range
/// @param sqrtRatioAX96 A sqrt price representing the first tick boundary
/// @param sqrtRatioBX96 A sqrt price representing the second tick boundary
/// @param liquidity The liquidity being valued
/// @return amount1 The amount of token1
function getAmount1ForLiquidity(uint160 sqrtRatioAX96, uint160 sqrtRatioBX96, uint128 liquidity)
internal
pure
returns (uint256 amount1)
{
if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96);

return FullMath.mulDiv(liquidity, sqrtRatioBX96 - sqrtRatioAX96, FixedPoint96.Q96);
}

/// @notice Computes the token0 and token1 value for a given amount of liquidity, the current
/// pool prices and the prices at the tick boundaries
/// @param sqrtRatioX96 A sqrt price representing the current pool prices
/// @param sqrtRatioAX96 A sqrt price representing the first tick boundary
/// @param sqrtRatioBX96 A sqrt price representing the second tick boundary
/// @param liquidity The liquidity being valued
/// @return amount0 The amount of token0
/// @return amount1 The amount of token1
function getAmountsForLiquidity(
uint160 sqrtRatioX96,
uint160 sqrtRatioAX96,
uint160 sqrtRatioBX96,
uint128 liquidity
) internal pure returns (uint256 amount0, uint256 amount1) {
if (sqrtRatioAX96 > sqrtRatioBX96) (sqrtRatioAX96, sqrtRatioBX96) = (sqrtRatioBX96, sqrtRatioAX96);

if (sqrtRatioX96 <= sqrtRatioAX96) {
amount0 = getAmount0ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity);
} else if (sqrtRatioX96 < sqrtRatioBX96) {
amount0 = getAmount0ForLiquidity(sqrtRatioX96, sqrtRatioBX96, liquidity);
amount1 = getAmount1ForLiquidity(sqrtRatioAX96, sqrtRatioX96, liquidity);
} else {
amount1 = getAmount1ForLiquidity(sqrtRatioAX96, sqrtRatioBX96, liquidity);
}
}
}

0 comments on commit 8c9c83f

Please sign in to comment.