diff --git a/.forge-snapshots/donate gas with 1 token.snap b/.forge-snapshots/donate gas with 1 token.snap index 3ad88ae0b..90e7d69c3 100644 --- a/.forge-snapshots/donate gas with 1 token.snap +++ b/.forge-snapshots/donate gas with 1 token.snap @@ -1 +1 @@ -139060 \ No newline at end of file +139280 \ No newline at end of file diff --git a/.forge-snapshots/donate gas with 2 tokens.snap b/.forge-snapshots/donate gas with 2 tokens.snap index 205667613..b9ab45c61 100644 --- a/.forge-snapshots/donate gas with 2 tokens.snap +++ b/.forge-snapshots/donate gas with 2 tokens.snap @@ -1 +1 @@ -186482 \ No newline at end of file +186702 \ No newline at end of file diff --git a/.forge-snapshots/gas overhead of no-op lock.snap b/.forge-snapshots/gas overhead of no-op lock.snap index fd0c301b5..1a78998d5 100644 --- a/.forge-snapshots/gas overhead of no-op lock.snap +++ b/.forge-snapshots/gas overhead of no-op lock.snap @@ -1 +1 @@ -15224 \ No newline at end of file +15246 \ No newline at end of file diff --git a/.forge-snapshots/initialize.snap b/.forge-snapshots/initialize.snap index 7351c1a9a..c63109919 100644 --- a/.forge-snapshots/initialize.snap +++ b/.forge-snapshots/initialize.snap @@ -1 +1 @@ -78563 \ No newline at end of file +75775 \ No newline at end of file diff --git a/.forge-snapshots/mint with empty hook.snap b/.forge-snapshots/mint with empty hook.snap index 6f5ed0e11..41a0ed1c7 100644 --- a/.forge-snapshots/mint with empty hook.snap +++ b/.forge-snapshots/mint with empty hook.snap @@ -1 +1 @@ -319063 \ No newline at end of file +318670 \ No newline at end of file diff --git a/.forge-snapshots/mint with native token.snap b/.forge-snapshots/mint with native token.snap index 7ed549245..163d757d5 100644 --- a/.forge-snapshots/mint with native token.snap +++ b/.forge-snapshots/mint with native token.snap @@ -1 +1 @@ -201844 \ No newline at end of file +201474 \ No newline at end of file diff --git a/.forge-snapshots/mint.snap b/.forge-snapshots/mint.snap index d86297139..20733ca89 100644 --- a/.forge-snapshots/mint.snap +++ b/.forge-snapshots/mint.snap @@ -1 +1 @@ -201764 \ No newline at end of file +201394 \ No newline at end of file diff --git a/.forge-snapshots/simple swap.snap b/.forge-snapshots/simple swap.snap index 563c04884..b894092bb 100644 --- a/.forge-snapshots/simple swap.snap +++ b/.forge-snapshots/simple swap.snap @@ -1 +1 @@ -205899 \ No newline at end of file +205733 \ No newline at end of file diff --git a/.forge-snapshots/swap against liquidity with native token.snap b/.forge-snapshots/swap against liquidity with native token.snap index 8e00a61bc..b9d2fc13c 100644 --- a/.forge-snapshots/swap against liquidity with native token.snap +++ b/.forge-snapshots/swap against liquidity with native token.snap @@ -1 +1 @@ -127924 \ No newline at end of file +127793 \ No newline at end of file diff --git a/.forge-snapshots/swap against liquidity.snap b/.forge-snapshots/swap against liquidity.snap index 4e05ab0b5..3780a96f4 100644 --- a/.forge-snapshots/swap against liquidity.snap +++ b/.forge-snapshots/swap against liquidity.snap @@ -1 +1 @@ -115415 \ No newline at end of file +115284 \ No newline at end of file diff --git a/.forge-snapshots/swap with dynamic fee.snap b/.forge-snapshots/swap with dynamic fee.snap index 654ec0f6e..cb5df9a3e 100644 --- a/.forge-snapshots/swap with dynamic fee.snap +++ b/.forge-snapshots/swap with dynamic fee.snap @@ -1 +1 @@ -192325 \ No newline at end of file +192160 \ No newline at end of file diff --git a/.forge-snapshots/swap with hooks.snap b/.forge-snapshots/swap with hooks.snap index b57fc3012..81bb0fc29 100644 --- a/.forge-snapshots/swap with hooks.snap +++ b/.forge-snapshots/swap with hooks.snap @@ -1 +1 @@ -115393 \ No newline at end of file +115262 \ No newline at end of file diff --git a/src/Fees.sol b/src/Fees.sol index 3eaed8f4b..5aee9ef2e 100644 --- a/src/Fees.sol +++ b/src/Fees.sol @@ -3,7 +3,6 @@ pragma solidity ^0.8.19; import {Currency, CurrencyLibrary} from "./types/Currency.sol"; import {IProtocolFeeController} from "./interfaces/IProtocolFeeController.sol"; -import {IHookFeeManager} from "./interfaces/IHookFeeManager.sol"; import {IFees} from "./interfaces/IFees.sol"; import {FeeLibrary} from "./libraries/FeeLibrary.sol"; import {Pool} from "./libraries/Pool.sol"; @@ -22,8 +21,6 @@ abstract contract Fees is IFees, Owned { mapping(Currency currency => uint256) public protocolFeesAccrued; - mapping(address hookAddress => mapping(Currency currency => uint256)) public hookFeesAccrued; - IProtocolFeeController public protocolFeeController; uint256 private immutable controllerGasLimit; @@ -35,15 +32,16 @@ abstract contract Fees is IFees, Owned { /// @notice Fetch the protocol fees for a given pool, returning false if the call fails or the returned fees are invalid. /// @dev to prevent an invalid protocol fee controller from blocking pools from being initialized /// the success of this function is NOT checked on initialize and if the call fails, the protocol fees are set to 0. - /// @dev the success of this function must be checked when called in setProtocolFees - function _fetchProtocolFees(PoolKey memory key) internal returns (bool success, uint24 protocolFees) { + /// @dev the success of this function must be checked when called in setProtocolFee + function _fetchProtocolFee(PoolKey memory key) internal returns (bool success, uint16 protocolFees) { if (address(protocolFeeController) != address(0)) { // note that EIP-150 mandates that calls requesting more than 63/64ths of remaining gas // will be allotted no more than this amount, so controllerGasLimit must be set with this // in mind. if (gasleft() < controllerGasLimit) revert ProtocolFeeCannotBeFetched(); + (bool _success, bytes memory _data) = address(protocolFeeController).call{gas: controllerGasLimit}( - abi.encodeWithSelector(IProtocolFeeController.protocolFeesForPool.selector, key) + abi.encodeWithSelector(IProtocolFeeController.protocolFeeForPool.selector, key) ); // Ensure that the return data fits within a word if (!_success || _data.length > 32) return (false, 0); @@ -52,44 +50,22 @@ abstract contract Fees is IFees, Owned { assembly { returnData := mload(add(_data, 0x20)) } - // Ensure return data does not overflow a uint24 and that the underlying fees are within bounds. - (success, protocolFees) = returnData == uint24(returnData) && _isValidProtocolFees(uint24(returnData)) - ? (true, uint24(returnData)) + // Ensure return data does not overflow a uint16 and that the underlying fees are within bounds. + (success, protocolFees) = returnData == uint16(returnData) && _isFeeWithinBounds(uint16(returnData)) + ? (true, uint16(returnData)) : (false, 0); } } - /// @notice There is no cap on the hook fee, but it is specified as a percentage taken on the amount after the protocol fee is applied, if there is a protocol fee. - function _fetchHookFees(PoolKey memory key) internal view returns (uint24 hookFees) { - if (address(key.hooks) != address(0)) { - try IHookFeeManager(address(key.hooks)).getHookFees(key) returns (uint24 hookFeesRaw) { - uint24 swapFeeMask = key.fee.hasHookSwapFee() ? 0xFFF000 : 0; - uint24 withdrawFeeMask = key.fee.hasHookWithdrawFee() ? 0xFFF : 0; - uint24 fullFeeMask = swapFeeMask | withdrawFeeMask; - hookFees = hookFeesRaw & fullFeeMask; - } catch {} - } - } - function _fetchDynamicSwapFee(PoolKey memory key) internal view returns (uint24 dynamicSwapFee) { dynamicSwapFee = IDynamicFeeManager(address(key.hooks)).getFee(msg.sender, key); if (dynamicSwapFee >= MAX_SWAP_FEE) revert FeeTooLarge(); } - function _isValidProtocolFees(uint24 protocolFees) internal pure returns (bool) { - if (protocolFees != 0) { - uint16 protocolSwapFee = uint16(protocolFees >> 12); - uint16 protocolWithdrawFee = uint16(protocolFees & 0xFFF); - return _isFeeWithinBounds(protocolSwapFee) && _isFeeWithinBounds(protocolWithdrawFee); - } - return true; - } - - /// @dev Only the lower 12 bits are used here to encode the fee denominator. function _isFeeWithinBounds(uint16 fee) internal pure returns (bool) { if (fee != 0) { - uint16 fee0 = fee % 64; - uint16 fee1 = fee >> 6; + uint16 fee0 = fee % 256; + uint16 fee1 = fee >> 8; // The fee is specified as a denominator so it cannot be LESS than the MIN_PROTOCOL_FEE_DENOMINATOR (unless it is 0). if ( (fee0 != 0 && fee0 < MIN_PROTOCOL_FEE_DENOMINATOR) || (fee1 != 0 && fee1 < MIN_PROTOCOL_FEE_DENOMINATOR) @@ -115,17 +91,4 @@ abstract contract Fees is IFees, Owned { protocolFeesAccrued[currency] -= amountCollected; currency.transfer(recipient, amountCollected); } - - function collectHookFees(address recipient, Currency currency, uint256 amount) - external - returns (uint256 amountCollected) - { - address hookAddress = msg.sender; - - amountCollected = (amount == 0) ? hookFeesAccrued[hookAddress][currency] : amount; - recipient = (recipient == address(0)) ? hookAddress : recipient; - - hookFeesAccrued[hookAddress][currency] -= amountCollected; - currency.transfer(recipient, amountCollected); - } } diff --git a/src/PoolManager.sol b/src/PoolManager.sol index a0a97bb55..b7f2170c5 100644 --- a/src/PoolManager.sol +++ b/src/PoolManager.sol @@ -13,7 +13,6 @@ import {NoDelegateCall} from "./NoDelegateCall.sol"; import {Owned} from "./Owned.sol"; import {IHooks} from "./interfaces/IHooks.sol"; import {IDynamicFeeManager} from "./interfaces/IDynamicFeeManager.sol"; -import {IHookFeeManager} from "./interfaces/IHookFeeManager.sol"; import {IPoolManager} from "./interfaces/IPoolManager.sol"; import {ILockCallback} from "./interfaces/callback/ILockCallback.sol"; import {Fees} from "./Fees.sol"; @@ -57,11 +56,11 @@ contract PoolManager is IPoolManager, Fees, NoDelegateCall, ERC6909Claims { external view override - returns (uint160 sqrtPriceX96, int24 tick, uint24 protocolFees, uint24 hookFees) + returns (uint160 sqrtPriceX96, int24 tick, uint16 protocolFee) { Pool.Slot0 memory slot0 = pools[id].slot0; - return (slot0.sqrtPriceX96, slot0.tick, slot0.protocolFees, slot0.hookFees); + return (slot0.sqrtPriceX96, slot0.tick, slot0.protocolFee); } /// @inheritdoc IPoolManager @@ -130,10 +129,10 @@ contract PoolManager is IPoolManager, Fees, NoDelegateCall, ERC6909Claims { } PoolId id = key.toId(); - (, uint24 protocolFees) = _fetchProtocolFees(key); + (, uint16 protocolFee) = _fetchProtocolFee(key); uint24 swapFee = key.fee.isDynamicFee() ? _fetchDynamicSwapFee(key) : key.fee.getStaticFee(); - tick = pools[id].initialize(sqrtPriceX96, protocolFees, _fetchHookFees(key), swapFee); + tick = pools[id].initialize(sqrtPriceX96, protocolFee, swapFee); if (key.hooks.shouldCallAfterInitialize()) { if ( @@ -217,8 +216,7 @@ contract PoolManager is IPoolManager, Fees, NoDelegateCall, ERC6909Claims { } } - Pool.FeeAmounts memory feeAmounts; - (delta, feeAmounts) = pools[id].modifyPosition( + delta = pools[id].modifyPosition( Pool.ModifyPositionParams({ owner: msg.sender, tickLower: params.tickLower, @@ -230,21 +228,6 @@ contract PoolManager is IPoolManager, Fees, NoDelegateCall, ERC6909Claims { _accountPoolBalanceDelta(key, delta); - unchecked { - if (feeAmounts.feeForProtocol0 > 0) { - protocolFeesAccrued[key.currency0] += feeAmounts.feeForProtocol0; - } - if (feeAmounts.feeForProtocol1 > 0) { - protocolFeesAccrued[key.currency1] += feeAmounts.feeForProtocol1; - } - if (feeAmounts.feeForHook0 > 0) { - hookFeesAccrued[address(key.hooks)][key.currency0] += feeAmounts.feeForHook0; - } - if (feeAmounts.feeForHook1 > 0) { - hookFeesAccrued[address(key.hooks)][key.currency1] += feeAmounts.feeForHook1; - } - } - if (key.hooks.shouldCallAfterModifyPosition()) { if ( key.hooks.afterModifyPosition(msg.sender, key, params, delta, hookData) @@ -286,10 +269,9 @@ contract PoolManager is IPoolManager, Fees, NoDelegateCall, ERC6909Claims { } uint256 feeForProtocol; - uint256 feeForHook; uint24 swapFee; Pool.SwapState memory state; - (delta, feeForProtocol, feeForHook, swapFee, state) = pools[id].swap( + (delta, feeForProtocol, swapFee, state) = pools[id].swap( Pool.SwapParams({ tickSpacing: key.tickSpacing, zeroForOne: params.zeroForOne, @@ -305,9 +287,6 @@ contract PoolManager is IPoolManager, Fees, NoDelegateCall, ERC6909Claims { if (feeForProtocol > 0) { protocolFeesAccrued[params.zeroForOne ? key.currency0 : key.currency1] += feeForProtocol; } - if (feeForHook > 0) { - hookFeesAccrued[address(key.hooks)][params.zeroForOne ? key.currency0 : key.currency1] += feeForHook; - } } if (key.hooks.shouldCallAfterSwap()) { @@ -391,19 +370,12 @@ contract PoolManager is IPoolManager, Fees, NoDelegateCall, ERC6909Claims { _burnFrom(from, currency.toId(), amount); } - function setProtocolFees(PoolKey memory key) external { - (bool success, uint24 newProtocolFees) = _fetchProtocolFees(key); + function setProtocolFee(PoolKey memory key) external { + (bool success, uint16 newProtocolFee) = _fetchProtocolFee(key); if (!success) revert ProtocolFeeControllerCallFailedOrInvalidResult(); PoolId id = key.toId(); - pools[id].setProtocolFees(newProtocolFees); - emit ProtocolFeeUpdated(id, newProtocolFees); - } - - function setHookFees(PoolKey memory key) external { - uint24 newHookFees = _fetchHookFees(key); - PoolId id = key.toId(); - pools[id].setHookFees(newHookFees); - emit HookFeeUpdated(id, newHookFees); + pools[id].setProtocolFee(newProtocolFee); + emit ProtocolFeeUpdated(id, newProtocolFee); } function updateDynamicSwapFee(PoolKey memory key) external { diff --git a/src/interfaces/IFees.sol b/src/interfaces/IFees.sol index 286240c74..68fd59942 100644 --- a/src/interfaces/IFees.sol +++ b/src/interfaces/IFees.sol @@ -20,7 +20,4 @@ interface IFees { /// @notice Given a currency address, returns the protocol fees accrued in that currency function protocolFeesAccrued(Currency) external view returns (uint256); - - /// @notice Given a hook and a currency address, returns the fees accrued - function hookFeesAccrued(address, Currency) external view returns (uint256); } diff --git a/src/interfaces/IHookFeeManager.sol b/src/interfaces/IHookFeeManager.sol deleted file mode 100644 index 8d491a4f6..000000000 --- a/src/interfaces/IHookFeeManager.sol +++ /dev/null @@ -1,13 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.20; - -import {PoolKey} from "../types/PoolKey.sol"; - -/// @notice The interface for setting a fee on swap or fee on withdraw to the hook -/// @dev This callback is only made if the Fee.HOOK_SWAP_FEE_FLAG or Fee.HOOK_WITHDRAW_FEE_FLAG in set in the pool's key.fee. -interface IHookFeeManager { - /// @notice Gets the fee a hook can take at swap/withdraw. Upper bits used for swap and lower bits for withdraw. - /// @param key The pool key - /// @return The hook fees for swapping (upper bits set) and withdrawing (lower bits set). - function getHookFees(PoolKey calldata key) external view returns (uint24); -} diff --git a/src/interfaces/IPoolManager.sol b/src/interfaces/IPoolManager.sol index 638976216..cbac02c4f 100644 --- a/src/interfaces/IPoolManager.sol +++ b/src/interfaces/IPoolManager.sol @@ -81,9 +81,7 @@ interface IPoolManager is IFees { uint24 fee ); - event ProtocolFeeUpdated(PoolId indexed id, uint24 protocolFees); - - event HookFeeUpdated(PoolId indexed id, uint24 hookFees); + event ProtocolFeeUpdated(PoolId indexed id, uint16 protocolFee); event DynamicSwapFeeUpdated(PoolId indexed id, uint24 dynamicSwapFee); @@ -94,10 +92,7 @@ interface IPoolManager is IFees { function MIN_TICK_SPACING() external view returns (int24); /// @notice Get the current value in slot0 of the given pool - function getSlot0(PoolId id) - external - view - returns (uint160 sqrtPriceX96, int24 tick, uint24 protocolFees, uint24 hookFees); + function getSlot0(PoolId id) external view returns (uint160 sqrtPriceX96, int24 tick, uint16 protocolFee); /// @notice Get the current value of liquidity of the given pool function getLiquidity(PoolId id) external view returns (uint128 liquidity); @@ -187,12 +182,9 @@ interface IPoolManager is IFees { /// @notice Called by the user to pay what is owed function settle(Currency token) external payable returns (uint256 paid); - /// @notice Sets the protocol's swap and withdrawal fees for the given pool - /// Protocol fees are always a portion of a fee that is owed. If that underlying fee is 0, no protocol fees will accrue even if it is set to > 0. - function setProtocolFees(PoolKey memory key) external; - - /// @notice Sets the hook's swap and withdrawal fees for the given pool - function setHookFees(PoolKey memory key) external; + /// @notice Sets the protocol's swap fee for the given pool + /// Protocol fees are always a portion of the LP swap fee that is owed. If that fee is 0, no protocol fees will accrue even if it is set to > 0. + function setProtocolFee(PoolKey memory key) external; /// @notice Updates the pools swap fees for the a pool that has enabled dynamic swap fees. function updateDynamicSwapFee(PoolKey memory key) external; diff --git a/src/interfaces/IProtocolFeeController.sol b/src/interfaces/IProtocolFeeController.sol index 18612a37f..fbe8799c8 100644 --- a/src/interfaces/IProtocolFeeController.sol +++ b/src/interfaces/IProtocolFeeController.sol @@ -7,5 +7,5 @@ interface IProtocolFeeController { /// @notice Returns the protocol fees for a pool given the conditions of this contract /// @param poolKey The pool key to identify the pool. The controller may want to use attributes on the pool /// to determine the protocol fee, hence the entire key is needed. - function protocolFeesForPool(PoolKey memory poolKey) external view returns (uint24); + function protocolFeeForPool(PoolKey memory poolKey) external view returns (uint16); } diff --git a/src/libraries/FeeLibrary.sol b/src/libraries/FeeLibrary.sol index 3b5484dd2..98718cceb 100644 --- a/src/libraries/FeeLibrary.sol +++ b/src/libraries/FeeLibrary.sol @@ -4,21 +4,11 @@ pragma solidity ^0.8.20; library FeeLibrary { uint24 public constant STATIC_FEE_MASK = 0x0FFFFF; uint24 public constant DYNAMIC_FEE_FLAG = 0x800000; // 1000 - uint24 public constant HOOK_SWAP_FEE_FLAG = 0x400000; // 0100 - uint24 public constant HOOK_WITHDRAW_FEE_FLAG = 0x200000; // 0010 function isDynamicFee(uint24 self) internal pure returns (bool) { return self & DYNAMIC_FEE_FLAG != 0; } - function hasHookSwapFee(uint24 self) internal pure returns (bool) { - return self & HOOK_SWAP_FEE_FLAG != 0; - } - - function hasHookWithdrawFee(uint24 self) internal pure returns (bool) { - return self & HOOK_WITHDRAW_FEE_FLAG != 0; - } - function isStaticFeeTooLarge(uint24 self) internal pure returns (bool) { return self & STATIC_FEE_MASK >= 1000000; } diff --git a/src/libraries/Hooks.sol b/src/libraries/Hooks.sol index 3577a4934..d2bdc8b3a 100644 --- a/src/libraries/Hooks.sol +++ b/src/libraries/Hooks.sol @@ -74,13 +74,11 @@ library Hooks { ) { return false; } - // If there is no hook contract set, then fee cannot be dynamic and there cannot be a hook fee on swap or withdrawal. + // If there is no hook contract set, then fee cannot be dynamic + // If a hook contract is set, it must have at least 1 flag set, or have a dynamic fee return address(hook) == address(0) - ? !fee.isDynamicFee() && !fee.hasHookSwapFee() && !fee.hasHookWithdrawFee() - : ( - uint160(address(hook)) >= ACCESS_LOCK_FLAG || fee.isDynamicFee() || fee.hasHookSwapFee() - || fee.hasHookWithdrawFee() - ); + ? !fee.isDynamicFee() + : (uint160(address(hook)) >= ACCESS_LOCK_FLAG || fee.isDynamicFee()); } function shouldCallBeforeInitialize(IHooks self) internal pure returns (bool) { diff --git a/src/libraries/Pool.sol b/src/libraries/Pool.sol index b93265c62..89b437373 100644 --- a/src/libraries/Pool.sol +++ b/src/libraries/Pool.sol @@ -59,22 +59,16 @@ library Pool { /// @notice Thrown by donate if there is currently 0 liquidity, since the fees will not go to any liquidity providers error NoLiquidityToReceiveFees(); - /// Each uint24 variable packs both the swap fees and the withdraw fees represented as integer denominators (1/x). The upper 12 bits are the swap fees, and the lower 12 bits - /// are the withdraw fees. For swap fees, the upper 6 bits are the fee for trading 1 for 0, and the lower 6 are for 0 for 1 and are taken as a percentage of the lp swap fee. - /// For withdraw fees the upper 6 bits are the fee on amount1, and the lower 6 are for amount0 and are taken as a percentage of the principle amount of the underlying position. - /// bits 24 22 20 18 16 14 12 10 8 6 4 2 0 - /// | swapFees | withdrawFees | - /// ┌────────┬────────┬────────┬────────┐ - /// protocolFees: | 1->0 | 0->1 | fee1 | fee0 | - /// hookFees: | 1->0 | 0->1 | fee1 | fee0 | - /// └────────┴────────┴────────┴────────┘ struct Slot0 { // the current price uint160 sqrtPriceX96; // the current tick int24 tick; - uint24 protocolFees; - uint24 hookFees; + // protocol swap fee represented as integer denominator (1/x), taken as a % of the LP swap fee + // upper 8 bits are for 1->0, and the lower 8 are for 0->1 + // the minimum permitted denominator is 4 - meaning the maximum protocol fee is 25% + // granularity is increments of 0.38% (100/type(uint8).max) + uint16 protocolFee; // used for the swap fee, either static at initialize or dynamic via hook uint24 swapFee; } @@ -109,7 +103,7 @@ library Pool { if (tickUpper > TickMath.MAX_TICK) revert TickUpperOutOfBounds(tickUpper); } - function initialize(State storage self, uint160 sqrtPriceX96, uint24 protocolFees, uint24 hookFees, uint24 swapFee) + function initialize(State storage self, uint160 sqrtPriceX96, uint16 protocolFee, uint24 swapFee) internal returns (int24 tick) { @@ -117,33 +111,13 @@ library Pool { tick = TickMath.getTickAtSqrtRatio(sqrtPriceX96); - self.slot0 = Slot0({ - sqrtPriceX96: sqrtPriceX96, - tick: tick, - protocolFees: protocolFees, - hookFees: hookFees, - swapFee: swapFee - }); - } - - function getSwapFee(uint24 feesStorage) internal pure returns (uint16) { - return uint16(feesStorage >> 12); - } - - function getWithdrawFee(uint24 feesStorage) internal pure returns (uint16) { - return uint16(feesStorage & 0xFFF); - } - - function setProtocolFees(State storage self, uint24 protocolFees) internal { - if (self.isNotInitialized()) revert PoolNotInitialized(); - - self.slot0.protocolFees = protocolFees; + self.slot0 = Slot0({sqrtPriceX96: sqrtPriceX96, tick: tick, protocolFee: protocolFee, swapFee: swapFee}); } - function setHookFees(State storage self, uint24 hookFees) internal { + function setProtocolFee(State storage self, uint16 protocolFee) internal { if (self.isNotInitialized()) revert PoolNotInitialized(); - self.slot0.hookFees = hookFees; + self.slot0.protocolFee = protocolFee; } /// @notice Only dynamic fee pools may update the swap fee. @@ -173,20 +147,13 @@ library Pool { uint256 feeGrowthInside1X128; } - struct FeeAmounts { - uint256 feeForProtocol0; - uint256 feeForProtocol1; - uint256 feeForHook0; - uint256 feeForHook1; - } - /// @notice Effect changes to a position in a pool /// @dev PoolManager checks that the pool is initialized before calling /// @param params the position details and the change to the position's liquidity to effect /// @return result the deltas of the token balances of the pool function modifyPosition(State storage self, ModifyPositionParams memory params) internal - returns (BalanceDelta result, FeeAmounts memory fees) + returns (BalanceDelta result) { checkTicks(params.tickLower, params.tickUpper); @@ -280,70 +247,15 @@ library Pool { } } - if (params.liquidityDelta < 0 && getWithdrawFee(self.slot0.hookFees) > 0) { - // Only take fees if the hook withdraw fee is set and the liquidity is being removed. - fees = _calculateExternalFees(self, result); - - // Amounts are balances owed to the pool. When negative, they represent the balance a user can take. - // Since protocol and hook fees are extracted on the balance a user can take - // they are owed (added) back to the pool where they are kept to be collected by the fee recipients. - result = result - + toBalanceDelta( - fees.feeForHook0.toInt128() + fees.feeForProtocol0.toInt128(), - fees.feeForHook1.toInt128() + fees.feeForProtocol1.toInt128() - ); - } - // Fees earned from LPing are removed from the pool balance. result = result - toBalanceDelta(feesOwed0.toInt128(), feesOwed1.toInt128()); } - function _calculateExternalFees(State storage self, BalanceDelta result) - internal - view - returns (FeeAmounts memory fees) - { - int128 amount0 = result.amount0(); - int128 amount1 = result.amount1(); - - Slot0 memory slot0Cache = self.slot0; - uint24 hookFees = slot0Cache.hookFees; - uint24 protocolFees = slot0Cache.protocolFees; - - uint16 hookFee0 = getWithdrawFee(hookFees) % 64; - uint16 hookFee1 = getWithdrawFee(hookFees) >> 6; - - uint16 protocolFee0 = getWithdrawFee(protocolFees) % 64; - uint16 protocolFee1 = getWithdrawFee(protocolFees) >> 6; - - if (amount0 < 0 && hookFee0 > 0) { - fees.feeForHook0 = uint128(-amount0) / hookFee0; - } - if (amount1 < 0 && hookFee1 > 0) { - fees.feeForHook1 = uint128(-amount1) / hookFee1; - } - - // A protocol fee is only applied if the hook fee is applied. - if (protocolFee0 > 0 && fees.feeForHook0 > 0) { - fees.feeForProtocol0 = fees.feeForHook0 / protocolFee0; - fees.feeForHook0 -= fees.feeForProtocol0; - } - - if (protocolFee1 > 0 && fees.feeForHook1 > 0) { - fees.feeForProtocol1 = fees.feeForHook1 / protocolFee1; - fees.feeForHook1 -= fees.feeForProtocol1; - } - - return fees; - } - struct SwapCache { // liquidity at the beginning of the swap uint128 liquidityStart; // the protocol fee for the input token - uint16 protocolFee; - // the hook fee for the input token - uint16 hookFee; + uint8 protocolFee; } // the top level state of the swap, the results of which are recorded in storage at the end @@ -390,13 +302,7 @@ library Pool { /// @dev PoolManager checks that the pool is initialized before calling function swap(State storage self, SwapParams memory params) internal - returns ( - BalanceDelta result, - uint256 feeForProtocol, - uint256 feeForHook, - uint24 swapFee, - SwapState memory state - ) + returns (BalanceDelta result, uint256 feeForProtocol, uint24 swapFee, SwapState memory state) { if (params.amountSpecified == 0) revert SwapAmountCannotBeZero(); @@ -420,10 +326,7 @@ library Pool { SwapCache memory cache = SwapCache({ liquidityStart: self.liquidity, - protocolFee: params.zeroForOne - ? (getSwapFee(slot0Start.protocolFees) % 64) - : (getSwapFee(slot0Start.protocolFees) >> 6), - hookFee: params.zeroForOne ? (getSwapFee(slot0Start.hookFees) % 64) : (getSwapFee(slot0Start.hookFees) >> 6) + protocolFee: params.zeroForOne ? uint8(slot0Start.protocolFee % 256) : uint8(slot0Start.protocolFee >> 8) }); bool exactInput = params.amountSpecified > 0; @@ -492,15 +395,6 @@ library Pool { } } - if (cache.hookFee > 0) { - // step.feeAmount has already been updated to account for the protocol fee - uint256 delta = step.feeAmount / cache.hookFee; - unchecked { - step.feeAmount -= delta; - feeForHook += delta; - } - } - // update global fee tracker if (state.liquidity > 0) { unchecked { diff --git a/src/test/AccessLockHook.sol b/src/test/AccessLockHook.sol index 97d7d73a9..ff2d07ebe 100644 --- a/src/test/AccessLockHook.sol +++ b/src/test/AccessLockHook.sol @@ -13,6 +13,7 @@ import {ILockCallback} from "../interfaces/callback/ILockCallback.sol"; import {MockERC20} from "solmate/test/utils/mocks/MockERC20.sol"; import {Constants} from "../../test/utils/Constants.sol"; import {PoolIdLibrary} from "../types/PoolId.sol"; +import {BalanceDelta} from "../types/BalanceDelta.sol"; contract AccessLockHook is Test, BaseTestHooks { using PoolIdLibrary for PoolKey; @@ -221,3 +222,59 @@ contract AccessLockHook3 is Test, ILockCallback, BaseTestHooks { return data; } } + +contract AccessLockFeeHook is Test, BaseTestHooks { + IPoolManager manager; + + uint256 constant WITHDRAWAL_FEE_BIPS = 40; // 40/10000 = 0.4% + uint256 constant SWAP_FEE_BIPS = 55; // 55/10000 = 0.55% + uint256 constant TOTAL_BIPS = 10000; + + constructor(IPoolManager _manager) { + manager = _manager; + } + + function afterModifyPosition( + address, /* sender **/ + PoolKey calldata key, + IPoolManager.ModifyPositionParams calldata, /* params **/ + BalanceDelta delta, + bytes calldata /* hookData **/ + ) external override returns (bytes4) { + int128 amount0 = delta.amount0(); + int128 amount1 = delta.amount1(); + + // positive delta => user owes money => liquidity addition + if (amount0 >= 0 && amount1 >= 0) return IHooks.afterModifyPosition.selector; + + // negative delta => user is owed money => liquidity withdrawal + uint256 amount0Fee = uint128(-amount0) * WITHDRAWAL_FEE_BIPS / TOTAL_BIPS; + uint256 amount1Fee = uint128(-amount1) * WITHDRAWAL_FEE_BIPS / TOTAL_BIPS; + + manager.take(key.currency0, address(this), amount0Fee); + manager.take(key.currency1, address(this), amount1Fee); + + return IHooks.afterModifyPosition.selector; + } + + function afterSwap( + address, /* sender **/ + PoolKey calldata key, + IPoolManager.SwapParams calldata params, + BalanceDelta delta, + bytes calldata /* hookData **/ + ) external override returns (bytes4) { + int128 amount0 = delta.amount0(); + int128 amount1 = delta.amount1(); + + // fee on output token - output delta will be negative + (Currency feeCurrency, uint256 outputAmount) = + (params.zeroForOne) ? (key.currency1, uint128(-amount1)) : (key.currency0, uint128(-amount0)); + + uint256 feeAmount = outputAmount * SWAP_FEE_BIPS / TOTAL_BIPS; + + manager.take(feeCurrency, address(this), feeAmount); + + return IHooks.afterSwap.selector; + } +} diff --git a/src/test/MockHooks.sol b/src/test/MockHooks.sol index e6a9ffa23..f1865e7a1 100644 --- a/src/test/MockHooks.sol +++ b/src/test/MockHooks.sol @@ -6,10 +6,9 @@ import {IHooks} from "../interfaces/IHooks.sol"; import {IPoolManager} from "../interfaces/IPoolManager.sol"; import {PoolKey} from "../types/PoolKey.sol"; import {BalanceDelta} from "../types/BalanceDelta.sol"; -import {IHookFeeManager} from "../interfaces/IHookFeeManager.sol"; import {PoolId, PoolIdLibrary} from "../types/PoolId.sol"; -contract MockHooks is IHooks, IHookFeeManager { +contract MockHooks is IHooks { using PoolIdLibrary for PoolKey; using Hooks for IHooks; @@ -26,8 +25,6 @@ contract MockHooks is IHooks, IHookFeeManager { mapping(PoolId => uint16) public swapFees; - mapping(PoolId => uint16) public withdrawFees; - function beforeInitialize(address, PoolKey calldata, uint160, bytes calldata hookData) external override @@ -113,10 +110,6 @@ contract MockHooks is IHooks, IHookFeeManager { return returnValues[selector] == bytes4(0) ? selector : returnValues[selector]; } - function getHookFees(PoolKey calldata key) external view override returns (uint24) { - return (uint24(swapFees[key.toId()]) << 12 | withdrawFees[key.toId()]); - } - function setReturnValue(bytes4 key, bytes4 value) external { returnValues[key] = value; } @@ -124,8 +117,4 @@ contract MockHooks is IHooks, IHookFeeManager { function setSwapFee(PoolKey calldata key, uint16 value) external { swapFees[key.toId()] = value; } - - function setWithdrawFee(PoolKey calldata key, uint16 value) external { - withdrawFees[key.toId()] = value; - } } diff --git a/src/test/PoolModifyPositionTest.sol b/src/test/PoolModifyPositionTest.sol index 5dc8be435..c3b7a2365 100644 --- a/src/test/PoolModifyPositionTest.sol +++ b/src/test/PoolModifyPositionTest.sol @@ -53,7 +53,7 @@ contract PoolModifyPositionTest is Test, PoolTestBase { (,,, int256 delta1) = _fetchBalances(data.key.currency1, data.sender); // These assertions only apply in non lock-accessing pools. - if (!data.key.hooks.hasPermissionToAccessLock() && !data.key.fee.hasHookWithdrawFee()) { + if (!data.key.hooks.hasPermissionToAccessLock()) { if (data.params.liquidityDelta > 0) { assert(delta0 > 0 || delta1 > 0 || data.key.hooks.hasPermissionToNoOp()); assert(!(delta0 < 0 || delta1 < 0)); diff --git a/src/test/ProtocolFeeControllerTest.sol b/src/test/ProtocolFeeControllerTest.sol index ebb82079b..1b46080d8 100644 --- a/src/test/ProtocolFeeControllerTest.sol +++ b/src/test/ProtocolFeeControllerTest.sol @@ -10,32 +10,27 @@ contract ProtocolFeeControllerTest is IProtocolFeeController { using PoolIdLibrary for PoolKey; mapping(PoolId => uint16) public swapFeeForPool; - mapping(PoolId => uint16) public withdrawFeeForPool; - function protocolFeesForPool(PoolKey memory key) external view returns (uint24) { - return (uint24(swapFeeForPool[key.toId()]) << 12 | withdrawFeeForPool[key.toId()]); + function protocolFeeForPool(PoolKey memory key) external view returns (uint16) { + return swapFeeForPool[key.toId()]; } // for tests to set pool protocol fees function setSwapFeeForPool(PoolId id, uint16 fee) external { swapFeeForPool[id] = fee; } - - function setWithdrawFeeForPool(PoolId id, uint16 fee) external { - withdrawFeeForPool[id] = fee; - } } /// @notice Reverts on call contract RevertingProtocolFeeControllerTest is IProtocolFeeController { - function protocolFeesForPool(PoolKey memory /* key */ ) external pure returns (uint24) { + function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint16) { revert(); } } /// @notice Returns an out of bounds protocol fee contract OutOfBoundsProtocolFeeControllerTest is IProtocolFeeController { - function protocolFeesForPool(PoolKey memory /* key */ ) external pure returns (uint24) { + function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint16) { // set both swap and withdraw fees to 1, which is less than MIN_PROTOCOL_FEE_DENOMINATOR return 0x001001; } @@ -43,7 +38,7 @@ contract OutOfBoundsProtocolFeeControllerTest is IProtocolFeeController { /// @notice Return a value that overflows a uint24 contract OverflowProtocolFeeControllerTest is IProtocolFeeController { - function protocolFeesForPool(PoolKey memory /* key */ ) external pure returns (uint24) { + function protocolFeeForPool(PoolKey memory /* key */ ) external pure returns (uint16) { assembly { let ptr := mload(0x40) mstore(ptr, 0xFFFFAAA001) @@ -54,7 +49,7 @@ contract OverflowProtocolFeeControllerTest is IProtocolFeeController { /// @notice Returns data that is larger than a word contract InvalidReturnSizeProtocolFeeControllerTest is IProtocolFeeController { - function protocolFeesForPool(PoolKey memory /* key */ ) external view returns (uint24) { + function protocolFeeForPool(PoolKey memory /* key */ ) external view returns (uint16) { address a = address(this); assembly { let ptr := mload(0x40) diff --git a/test/AccessLock.t.sol b/test/AccessLock.t.sol index 0fc6b8ea5..d7cb7f1c5 100644 --- a/test/AccessLock.t.sol +++ b/test/AccessLock.t.sol @@ -2,7 +2,7 @@ pragma solidity ^0.8.20; import {Test} from "forge-std/Test.sol"; -import {AccessLockHook, AccessLockHook2, AccessLockHook3} from "../src/test/AccessLockHook.sol"; +import {AccessLockHook, AccessLockHook2, AccessLockHook3, AccessLockFeeHook} from "../src/test/AccessLockHook.sol"; import {IPoolManager} from "../src/interfaces/IPoolManager.sol"; import {PoolModifyPositionTest} from "../src/test/PoolModifyPositionTest.sol"; import {PoolSwapTest} from "../src/test/PoolSwapTest.sol"; @@ -14,7 +14,7 @@ import {Currency, CurrencyLibrary} from "../src/types/Currency.sol"; import {MockERC20} from "solmate/test/utils/mocks/MockERC20.sol"; import {Hooks} from "../src/libraries/Hooks.sol"; import {IHooks} from "../src/interfaces/IHooks.sol"; -import {BalanceDelta} from "../src/types/BalanceDelta.sol"; +import {BalanceDelta, BalanceDeltaLibrary} from "../src/types/BalanceDelta.sol"; import {Pool} from "../src/libraries/Pool.sol"; import {TickMath} from "../src/libraries/TickMath.sol"; import {PoolIdLibrary} from "../src/types/PoolId.sol"; @@ -23,12 +23,17 @@ contract AccessLockTest is Test, Deployers { using Pool for Pool.State; using CurrencyLibrary for Currency; using PoolIdLibrary for PoolKey; + using BalanceDeltaLibrary for BalanceDelta; AccessLockHook accessLockHook; AccessLockHook noAccessLockHook; AccessLockHook2 accessLockHook2; AccessLockHook3 accessLockHook3; - AccessLockHook accessLockHook4; + AccessLockHook accessLockNoOpHook; + AccessLockFeeHook accessLockFeeHook; + + // global for stack too deep errors + BalanceDelta delta; uint128 amount = 1e18; @@ -47,9 +52,8 @@ contract AccessLockTest is Test, Deployers { deployCodeTo("AccessLockHook.sol:AccessLockHook", abi.encode(manager), accessLockAddress); accessLockHook = AccessLockHook(accessLockAddress); - (key,) = initPool( - currency0, currency1, IHooks(address(accessLockHook)), Constants.FEE_MEDIUM, SQRT_RATIO_1_1, ZERO_BYTES - ); + (key,) = + initPool(currency0, currency1, IHooks(accessLockAddress), Constants.FEE_MEDIUM, SQRT_RATIO_1_1, ZERO_BYTES); // Create AccessLockHook2. address accessLockAddress2 = address(uint160(Hooks.ACCESS_LOCK_FLAG | Hooks.BEFORE_MODIFY_POSITION_FLAG)); @@ -70,14 +74,20 @@ contract AccessLockTest is Test, Deployers { noAccessLockHook = AccessLockHook(noAccessLockHookAddress); // Create AccessLockHook with NoOp. - address accessLockHook4Address = address( + address accessLockNoOpHookAddress = address( uint160( Hooks.NO_OP_FLAG | Hooks.ACCESS_LOCK_FLAG | Hooks.BEFORE_INITIALIZE_FLAG | Hooks.BEFORE_SWAP_FLAG | Hooks.BEFORE_MODIFY_POSITION_FLAG | Hooks.BEFORE_DONATE_FLAG ) ); - deployCodeTo("AccessLockHook.sol:AccessLockHook", abi.encode(manager), accessLockHook4Address); - accessLockHook4 = AccessLockHook(accessLockHook4Address); + deployCodeTo("AccessLockHook.sol:AccessLockHook", abi.encode(manager), accessLockNoOpHookAddress); + accessLockNoOpHook = AccessLockHook(accessLockNoOpHookAddress); + + // Create AccessLockFeeHook + address accessLockFeeHookAddress = + address(uint160(Hooks.ACCESS_LOCK_FLAG | Hooks.AFTER_SWAP_FLAG | Hooks.AFTER_MODIFY_POSITION_FLAG)); + deployCodeTo("AccessLockHook.sol:AccessLockFeeHook", abi.encode(manager), accessLockFeeHookAddress); + accessLockFeeHook = AccessLockFeeHook(accessLockFeeHookAddress); } function test_onlyByLocker_revertsForNoAccessLockPool() public { @@ -124,7 +134,7 @@ contract AccessLockTest is Test, Deployers { uint256 balanceOfBefore1 = MockERC20(Currency.unwrap(currency1)).balanceOf(address(this)); uint256 balanceOfBefore0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); - BalanceDelta delta = modifyPositionRouter.modifyPosition( + delta = modifyPositionRouter.modifyPosition( key, IPoolManager.ModifyPositionParams(0, 60, 1e18), abi.encode(amount, AccessLockHook.LockAction.Mint) ); @@ -150,7 +160,7 @@ contract AccessLockTest is Test, Deployers { uint256 balanceOfBefore0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); // Hook only takes currency 1 rn. - BalanceDelta delta = modifyPositionRouter.modifyPosition( + delta = modifyPositionRouter.modifyPosition( key, IPoolManager.ModifyPositionParams(-60, 60, 1e18), abi.encode(amount, AccessLockHook.LockAction.Take) ); uint256 balanceOfAfter0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); @@ -235,7 +245,7 @@ contract AccessLockTest is Test, Deployers { ZERO_BYTES ); - BalanceDelta delta = swapRouter.swap( + delta = swapRouter.swap( key, IPoolManager.SwapParams(true, 10000, TickMath.MIN_SQRT_RATIO + 1), PoolSwapTest.TestSettings({withdrawTokens: false, settleUsingTransfer: true, currencyAlreadySent: false}), @@ -310,7 +320,7 @@ contract AccessLockTest is Test, Deployers { uint256 balanceOfBefore0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); // Small amount to swap (like NoOp). This way we can expect balances to just be from the hook applied delta. - BalanceDelta delta = swapRouter.swap( + delta = swapRouter.swap( key, IPoolManager.SwapParams(true, 1, TickMath.MIN_SQRT_RATIO + 1), PoolSwapTest.TestSettings({withdrawTokens: true, settleUsingTransfer: true, currencyAlreadySent: false}), @@ -340,7 +350,7 @@ contract AccessLockTest is Test, Deployers { // Hook only takes currency 1 rn. // Use small amount to NoOp. - BalanceDelta delta = swapRouter.swap( + delta = swapRouter.swap( key, IPoolManager.SwapParams(true, 1, TickMath.MIN_SQRT_RATIO + 1), PoolSwapTest.TestSettings({withdrawTokens: true, settleUsingTransfer: true, currencyAlreadySent: false}), @@ -453,7 +463,7 @@ contract AccessLockTest is Test, Deployers { uint256 balanceOfBefore1 = MockERC20(Currency.unwrap(currency1)).balanceOf(address(this)); uint256 balanceOfBefore0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); - BalanceDelta delta = donateRouter.donate(key, 1e18, 1e18, abi.encode(amount, AccessLockHook.LockAction.Mint)); + delta = donateRouter.donate(key, 1e18, 1e18, abi.encode(amount, AccessLockHook.LockAction.Mint)); uint256 balanceOfAfter0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); uint256 balanceOfAfter1 = MockERC20(Currency.unwrap(currency1)).balanceOf(address(this)); @@ -477,7 +487,7 @@ contract AccessLockTest is Test, Deployers { uint256 balanceOfBefore0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); // Hook only takes currency 1 rn. - BalanceDelta delta = donateRouter.donate(key, 1e18, 1e18, abi.encode(amount, AccessLockHook.LockAction.Take)); + delta = donateRouter.donate(key, 1e18, 1e18, abi.encode(amount, AccessLockHook.LockAction.Take)); // Take applies a positive delta in currency1. // Donate applies a positive delta in currency0 and currency1. uint256 balanceOfAfter0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); @@ -566,7 +576,7 @@ contract AccessLockTest is Test, Deployers { currency1: currency1, fee: Constants.FEE_MEDIUM, tickSpacing: 60, - hooks: IHooks(address(accessLockHook4)) + hooks: IHooks(address(accessLockNoOpHook)) }); initializeRouter.initialize(key1, SQRT_RATIO_1_1, abi.encode(amount, AccessLockHook.LockAction.Mint)); @@ -580,7 +590,7 @@ contract AccessLockTest is Test, Deployers { currency1: currency1, fee: Constants.FEE_MEDIUM, tickSpacing: 60, - hooks: IHooks(address(accessLockHook4)) + hooks: IHooks(address(accessLockNoOpHook)) }); // Add liquidity to a different pool there is something to take. @@ -592,7 +602,7 @@ contract AccessLockTest is Test, Deployers { initializeRouter.initialize(key1, SQRT_RATIO_1_1, abi.encode(amount, AccessLockHook.LockAction.Take)); - assertEq(MockERC20(Currency.unwrap(currency1)).balanceOf(address(accessLockHook4)), amount); + assertEq(MockERC20(Currency.unwrap(currency1)).balanceOf(address(accessLockNoOpHook)), amount); } function test_beforeInitialize_swap_revertsOnPoolNotInitialized() public { @@ -601,7 +611,7 @@ contract AccessLockTest is Test, Deployers { currency1: currency1, fee: Constants.FEE_MEDIUM, tickSpacing: 60, - hooks: IHooks(address(accessLockHook4)) + hooks: IHooks(address(accessLockNoOpHook)) }); vm.expectRevert(IPoolManager.PoolNotInitialized.selector); @@ -614,7 +624,7 @@ contract AccessLockTest is Test, Deployers { currency1: currency1, fee: Constants.FEE_MEDIUM, tickSpacing: 60, - hooks: IHooks(address(accessLockHook4)) + hooks: IHooks(address(accessLockNoOpHook)) }); vm.expectRevert(IPoolManager.PoolNotInitialized.selector); @@ -627,13 +637,101 @@ contract AccessLockTest is Test, Deployers { currency1: currency1, fee: Constants.FEE_MEDIUM, tickSpacing: 60, - hooks: IHooks(address(accessLockHook4)) + hooks: IHooks(address(accessLockNoOpHook)) }); vm.expectRevert(IPoolManager.PoolNotInitialized.selector); initializeRouter.initialize(key1, SQRT_RATIO_1_1, abi.encode(amount, AccessLockHook.LockAction.Donate)); } + /** + * + * HOOK FEE TESTS + * + */ + + function test_hookFees_takesFeeOnWithdrawal() public { + (key,) = initPool( + currency0, currency1, IHooks(address(accessLockFeeHook)), Constants.FEE_MEDIUM, SQRT_RATIO_1_1, ZERO_BYTES + ); + + (uint256 userBalanceBefore0, uint256 poolBalanceBefore0, uint256 reservesBefore0) = _fetchBalances(currency0); + (uint256 userBalanceBefore1, uint256 poolBalanceBefore1, uint256 reservesBefore1) = _fetchBalances(currency1); + + // add liquidity + delta = modifyPositionRouter.modifyPosition(key, LIQ_PARAMS, ZERO_BYTES); + + (uint256 userBalanceAfter0, uint256 poolBalanceAfter0, uint256 reservesAfter0) = _fetchBalances(currency0); + (uint256 userBalanceAfter1, uint256 poolBalanceAfter1, uint256 reservesAfter1) = _fetchBalances(currency1); + + assert(delta.amount0() > 0 && delta.amount1() > 0); + assertEq(userBalanceBefore0 - uint128(delta.amount0()), userBalanceAfter0, "addLiq user balance currency0"); + assertEq(userBalanceBefore1 - uint128(delta.amount1()), userBalanceAfter1, "addLiq user balance currency1"); + assertEq(poolBalanceBefore0 + uint128(delta.amount0()), poolBalanceAfter0, "addLiq pool balance currency0"); + assertEq(poolBalanceBefore1 + uint128(delta.amount1()), poolBalanceAfter1, "addLiq pool balance currency1"); + assertEq(reservesBefore0 + uint128(delta.amount0()), reservesAfter0, "addLiq reserves currency0"); + assertEq(reservesBefore1 + uint128(delta.amount1()), reservesAfter1, "addLiq reserves currency1"); + + (userBalanceBefore0, poolBalanceBefore0, reservesBefore0) = + (userBalanceAfter0, poolBalanceAfter0, reservesAfter0); + (userBalanceBefore1, poolBalanceBefore1, reservesBefore1) = + (userBalanceAfter1, poolBalanceAfter1, reservesAfter1); + + // remove liquidity, a 40 bip fee should be taken + LIQ_PARAMS.liquidityDelta *= -1; + delta = modifyPositionRouter.modifyPosition(key, LIQ_PARAMS, ZERO_BYTES); + + (userBalanceAfter0, poolBalanceAfter0, reservesAfter0) = _fetchBalances(currency0); + (userBalanceAfter1, poolBalanceAfter1, reservesAfter1) = _fetchBalances(currency1); + + assert(delta.amount0() < 0 && delta.amount1() < 0); + + uint256 totalWithdraw0 = uint128(-delta.amount0()) - (uint128(-delta.amount0()) * 40 / 10000); + uint256 totalWithdraw1 = uint128(-delta.amount1()) - (uint128(-delta.amount1()) * 40 / 10000); + + assertEq(userBalanceBefore0 + totalWithdraw0, userBalanceAfter0, "removeLiq user balance currency0"); + assertEq(userBalanceBefore1 + totalWithdraw1, userBalanceAfter1, "removeLiq user balance currency1"); + assertEq(poolBalanceBefore0 - uint128(-delta.amount0()), poolBalanceAfter0, "removeLiq pool balance currency0"); + assertEq(poolBalanceBefore1 - uint128(-delta.amount1()), poolBalanceAfter1, "removeLiq pool balance currency1"); + assertEq(reservesBefore0 - uint128(-delta.amount0()), reservesAfter0, "removeLiq reserves currency0"); + assertEq(reservesBefore1 - uint128(-delta.amount1()), reservesAfter1, "removeLiq reserves currency1"); + } + + function test_hookFees_takesFeeOnInputOfSwap() public { + (key,) = initPool( + currency0, currency1, IHooks(address(accessLockFeeHook)), Constants.FEE_MEDIUM, SQRT_RATIO_1_1, ZERO_BYTES + ); + + // add liquidity + delta = modifyPositionRouter.modifyPosition(key, LIQ_PARAMS, ZERO_BYTES); + + // now swap, with a hook fee of 55 bips + (uint256 userBalanceBefore0, uint256 poolBalanceBefore0, uint256 reservesBefore0) = _fetchBalances(currency0); + (uint256 userBalanceBefore1, uint256 poolBalanceBefore1, uint256 reservesBefore1) = _fetchBalances(currency1); + + delta = swapRouter.swap( + key, + IPoolManager.SwapParams({zeroForOne: true, amountSpecified: 100000, sqrtPriceLimitX96: SQRT_RATIO_1_2}), + PoolSwapTest.TestSettings({withdrawTokens: true, settleUsingTransfer: true, currencyAlreadySent: false}), + ZERO_BYTES + ); + + assert(delta.amount0() > 0 && delta.amount1() < 0); + + uint256 amountIn0 = uint128(delta.amount0()); + uint256 userAmountOut1 = uint128(-delta.amount1()) - (uint128(-delta.amount1()) * 55 / 10000); + + (uint256 userBalanceAfter0, uint256 poolBalanceAfter0, uint256 reservesAfter0) = _fetchBalances(currency0); + (uint256 userBalanceAfter1, uint256 poolBalanceAfter1, uint256 reservesAfter1) = _fetchBalances(currency1); + + assertEq(userBalanceBefore0 - amountIn0, userBalanceAfter0, "swap user balance currency0"); + assertEq(userBalanceBefore1 + userAmountOut1, userBalanceAfter1, "swap user balance currency1"); + assertEq(poolBalanceBefore0 + amountIn0, poolBalanceAfter0, "swap pool balance currency0"); + assertEq(poolBalanceBefore1 - uint128(-delta.amount1()), poolBalanceAfter1, "swap pool balance currency1"); + assertEq(reservesBefore0 + amountIn0, reservesAfter0, "swap reserves currency0"); + assertEq(reservesBefore1 - uint128(-delta.amount1()), reservesAfter1, "swap reserves currency1"); + } + /** * * EDGE CASE TESTS @@ -645,7 +743,7 @@ contract AccessLockTest is Test, Deployers { uint256 balanceOfBefore1 = MockERC20(Currency.unwrap(currency1)).balanceOf(address(this)); uint256 balanceOfBefore0 = MockERC20(Currency.unwrap(currency0)).balanceOf(address(this)); - BalanceDelta delta = modifyPositionRouter.modifyPosition( + delta = modifyPositionRouter.modifyPosition( key, IPoolManager.ModifyPositionParams(0, 60, 1e18), abi.encode(amount, AccessLockHook.LockAction.Mint) ); @@ -721,7 +819,7 @@ contract AccessLockTest is Test, Deployers { function test_getCurrentHook_isClearedAfterNoOpOnAllHooks() public { (PoolKey memory noOpKey,) = - initPool(currency0, currency1, IHooks(accessLockHook4), Constants.FEE_MEDIUM, SQRT_RATIO_1_1, ZERO_BYTES); + initPool(currency0, currency1, IHooks(accessLockNoOpHook), Constants.FEE_MEDIUM, SQRT_RATIO_1_1, ZERO_BYTES); // Assertions for current hook address in AccessLockHook and respective routers. // beforeModifyPosition noOp @@ -742,4 +840,14 @@ contract AccessLockTest is Test, Deployers { abi.encode(0, AccessLockHook.LockAction.NoOp) ); } + + function _fetchBalances(Currency currency) + internal + view + returns (uint256 userBalance, uint256 poolBalance, uint256 reserves) + { + userBalance = currency.balanceOf(address(this)); + poolBalance = currency.balanceOf(address(manager)); + reserves = manager.reservesOf(currency); + } } diff --git a/test/Fees.t.sol b/test/Fees.t.sol deleted file mode 100644 index c73b46d93..000000000 --- a/test/Fees.t.sol +++ /dev/null @@ -1,588 +0,0 @@ -// SPDX-License-Identifier: UNLICENSED -pragma solidity ^0.8.20; - -import {Test} from "forge-std/Test.sol"; -import {Vm} from "forge-std/Vm.sol"; -import {IHooks} from "../src/interfaces/IHooks.sol"; -import {Hooks} from "../src/libraries/Hooks.sol"; -import {FeeLibrary} from "../src/libraries/FeeLibrary.sol"; -import {IPoolManager} from "../src/interfaces/IPoolManager.sol"; -import {IFees} from "../src/interfaces/IFees.sol"; -import {PoolManager} from "../src/PoolManager.sol"; -import {TickMath} from "../src/libraries/TickMath.sol"; -import {Pool} from "../src/libraries/Pool.sol"; -import {PoolIdLibrary} from "../src/types/PoolId.sol"; -import {Deployers} from "./utils/Deployers.sol"; -import {PoolModifyPositionTest} from "../src/test/PoolModifyPositionTest.sol"; -import {Currency, CurrencyLibrary} from "../src/types/Currency.sol"; -import {MockERC20} from "solmate/test/utils/mocks/MockERC20.sol"; -import {MockHooks} from "../src/test/MockHooks.sol"; -import {PoolSwapTest} from "../src/test/PoolSwapTest.sol"; -import {GasSnapshot} from "forge-gas-snapshot/GasSnapshot.sol"; -import {ProtocolFeeControllerTest} from "../src/test/ProtocolFeeControllerTest.sol"; -import {IProtocolFeeController} from "../src/interfaces/IProtocolFeeController.sol"; -import {Fees} from "../src/Fees.sol"; -import {BalanceDelta} from "../src/types/BalanceDelta.sol"; -import {PoolKey} from "../src/types/PoolKey.sol"; - -contract FeesTest is Test, Deployers, GasSnapshot { - using Hooks for IHooks; - using Pool for Pool.State; - using PoolIdLibrary for PoolKey; - using CurrencyLibrary for Currency; - - MockHooks hook; - - // key0 hook enabled fee on swap - PoolKey key0; - // key1 hook enabled fee on withdraw - PoolKey key1; - // key2 hook enabled fee on swap and withdraw - PoolKey key2; - // key3 no hook - PoolKey key3; - - bool _zeroForOne = true; - bool _oneForZero = false; - - function setUp() public { - deployFreshManagerAndRouters(); - (currency0, currency1) = deployMintAndApprove2Currencies(); - - address hookAddr = address(99); // can't be a zero address, but does not have to have any other hook flags specified - MockHooks impl = new MockHooks(); - vm.etch(hookAddr, address(impl).code); - hook = MockHooks(hookAddr); - - key0 = PoolKey({ - currency0: currency0, - currency1: currency1, - fee: FeeLibrary.HOOK_SWAP_FEE_FLAG | uint24(3000), - hooks: hook, - tickSpacing: 60 - }); - - key1 = PoolKey({ - currency0: currency0, - currency1: currency1, - fee: FeeLibrary.HOOK_WITHDRAW_FEE_FLAG | uint24(3000), - hooks: hook, - tickSpacing: 60 - }); - - key2 = PoolKey({ - currency0: currency0, - currency1: currency1, - fee: FeeLibrary.HOOK_WITHDRAW_FEE_FLAG | FeeLibrary.HOOK_SWAP_FEE_FLAG | uint24(3000), - hooks: hook, - tickSpacing: 60 - }); - - key3 = PoolKey({ - currency0: currency0, - currency1: currency1, - fee: uint24(3000), - hooks: IHooks(address(0)), - tickSpacing: 60 - }); - - initializeRouter.initialize(key0, SQRT_RATIO_1_1, ZERO_BYTES); - initializeRouter.initialize(key1, SQRT_RATIO_1_1, ZERO_BYTES); - initializeRouter.initialize(key2, SQRT_RATIO_1_1, ZERO_BYTES); - initializeRouter.initialize(key3, SQRT_RATIO_1_1, ZERO_BYTES); - } - - function testInitializeFailsNoHook() public { - PoolKey memory key4 = PoolKey({ - currency0: currency0, - currency1: currency1, - fee: FeeLibrary.HOOK_WITHDRAW_FEE_FLAG | FeeLibrary.HOOK_SWAP_FEE_FLAG | uint24(3000), - hooks: IHooks(address(0)), - tickSpacing: 60 - }); - - vm.expectRevert(abi.encodeWithSelector(Hooks.HookAddressNotValid.selector, address(0))); - initializeRouter.initialize(key4, SQRT_RATIO_1_1, ZERO_BYTES); - - key4 = PoolKey({ - currency0: currency0, - currency1: currency1, - fee: FeeLibrary.DYNAMIC_FEE_FLAG, - hooks: IHooks(address(0)), - tickSpacing: 60 - }); - - vm.expectRevert(abi.encodeWithSelector(Hooks.HookAddressNotValid.selector, address(0))); - initializeRouter.initialize(key4, SQRT_RATIO_1_1, ZERO_BYTES); - } - - function testInitializeHookSwapFee(uint16 fee) public { - fee = uint16(bound(fee, 0, (2 ** 12) - 1)); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.hookFees), 0); - - hook.setSwapFee(key0, fee); - manager.setHookFees(key0); - - (slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.hookFees), fee); - assertEq(getWithdrawFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.protocolFees), 0); - assertEq(getWithdrawFee(slot0.protocolFees), 0); - } - - function testInitializeHookWithdrawFee(uint16 fee) public { - fee = uint16(bound(fee, 0, (2 ** 12) - 1)); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key1.toId()); - assertEq(getWithdrawFee(slot0.hookFees), 0); - - hook.setWithdrawFee(key1, fee); - manager.setHookFees(key1); - - (slot0,,,) = manager.pools(key1.toId()); - assertEq(getWithdrawFee(slot0.hookFees), fee); - assertEq(getSwapFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.protocolFees), 0); - assertEq(getWithdrawFee(slot0.protocolFees), 0); - } - - function testInitializeBothHookFee(uint16 swapFee, uint16 withdrawFee) public { - swapFee = uint16(bound(swapFee, 0, (2 ** 12) - 1)); - withdrawFee = uint16(bound(withdrawFee, 0, (2 ** 12) - 1)); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key2.toId()); - assertEq(getSwapFee(slot0.hookFees), 0); - assertEq(getWithdrawFee(slot0.hookFees), 0); - - hook.setSwapFee(key2, swapFee); - hook.setWithdrawFee(key2, withdrawFee); - manager.setHookFees(key2); - - (slot0,,,) = manager.pools(key2.toId()); - assertEq(getSwapFee(slot0.hookFees), swapFee); - assertEq(getWithdrawFee(slot0.hookFees), withdrawFee); - } - - function testInitializeHookProtocolSwapFee(uint16 hookSwapFee, uint16 protocolSwapFee) public { - hookSwapFee = uint16(bound(hookSwapFee, 0, (2 ** 12) - 1)); - protocolSwapFee = uint16(bound(protocolSwapFee, 0, (2 ** 12) - 1)); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.protocolFees), 0); - - feeController.setSwapFeeForPool(key0.toId(), protocolSwapFee); - - uint16 protocolSwapFee1 = protocolSwapFee >> 6; - uint16 protocolSwapFee0 = protocolSwapFee % 64; - - if ((protocolSwapFee0 != 0 && protocolSwapFee0 < 4) || (protocolSwapFee1 != 0 && protocolSwapFee1 < 4)) { - protocolSwapFee = 0; - vm.expectRevert(IFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); - } - manager.setProtocolFees(key0); - - hook.setSwapFee(key0, hookSwapFee); - manager.setHookFees(key0); - - (slot0,,,) = manager.pools(key0.toId()); - - assertEq(getWithdrawFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.hookFees), hookSwapFee); - assertEq(getSwapFee(slot0.protocolFees), protocolSwapFee); - assertEq(getWithdrawFee(slot0.protocolFees), 0); - } - - function testInitializeAllFees( - uint16 hookSwapFee, - uint16 hookWithdrawFee, - uint16 protocolSwapFee, - uint16 protocolWithdrawFee - ) public { - hookSwapFee = uint16(bound(hookSwapFee, 0, (2 ** 12) - 1)); - hookWithdrawFee = uint16(bound(hookWithdrawFee, 0, (2 ** 12) - 1)); - protocolSwapFee = uint16(bound(protocolSwapFee, 0, (2 ** 12) - 1)); - protocolWithdrawFee = uint16(bound(protocolWithdrawFee, 0, (2 ** 12) - 1)); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key2.toId()); - assertEq(getSwapFee(slot0.hookFees), 0); - assertEq(getWithdrawFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.protocolFees), 0); - assertEq(getWithdrawFee(slot0.protocolFees), 0); - - feeController.setSwapFeeForPool(key2.toId(), protocolSwapFee); - feeController.setWithdrawFeeForPool(key2.toId(), protocolWithdrawFee); - - uint16 protocolSwapFee1 = protocolSwapFee >> 6; - uint16 protocolSwapFee0 = protocolSwapFee % 64; - uint16 protocolWithdrawFee1 = protocolWithdrawFee >> 6; - uint16 protocolWithdrawFee0 = protocolWithdrawFee % 64; - - if ( - (protocolSwapFee1 != 0 && protocolSwapFee1 < 4) || (protocolSwapFee0 != 0 && protocolSwapFee0 < 4) - || (protocolWithdrawFee1 != 0 && protocolWithdrawFee1 < 4) - || (protocolWithdrawFee0 != 0 && protocolWithdrawFee0 < 4) - ) { - protocolSwapFee = 0; - protocolWithdrawFee = 0; - vm.expectRevert(IFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); - } - manager.setProtocolFees(key2); - - hook.setSwapFee(key2, hookSwapFee); - hook.setWithdrawFee(key2, hookWithdrawFee); - manager.setHookFees(key2); - - (slot0,,,) = manager.pools(key2.toId()); - - assertEq(getWithdrawFee(slot0.hookFees), hookWithdrawFee); - assertEq(getSwapFee(slot0.hookFees), hookSwapFee); - assertEq(getSwapFee(slot0.protocolFees), protocolSwapFee); - assertEq(getWithdrawFee(slot0.protocolFees), protocolWithdrawFee); - } - - function testProtocolFeeOnWithdrawalRemainsZeroIfNoHookWithdrawalFeeSet( - uint16 hookSwapFee, - uint8 protocolWithdrawFee0, - uint8 protocolWithdrawFee1 - ) public { - hookSwapFee = uint16(bound(hookSwapFee, 0, (2 ** 12) - 1)); - - protocolWithdrawFee0 = uint8(bound(protocolWithdrawFee0, 4, (2 ** 6) - 1)); - protocolWithdrawFee1 = uint8(bound(protocolWithdrawFee1, 4, (2 ** 6) - 1)); - - uint16 protocolWithdrawFee = (uint16(protocolWithdrawFee0) << 6) | uint16(protocolWithdrawFee1); - - // On a pool whose hook has not set a withdraw fee, the protocol should not accrue any value even if it has set a withdraw fee. - hook.setSwapFee(key0, hookSwapFee); - manager.setHookFees(key0); - - // set fee on the fee controller - feeController.setWithdrawFeeForPool(key0.toId(), protocolWithdrawFee); - manager.setProtocolFees(key0); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); - assertEq(getWithdrawFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.hookFees), hookSwapFee); - assertEq(getSwapFee(slot0.protocolFees), 0); - assertEq(getWithdrawFee(slot0.protocolFees), protocolWithdrawFee); - - IPoolManager.ModifyPositionParams memory params = IPoolManager.ModifyPositionParams(-60, 60, 10e18); - modifyPositionRouter.modifyPosition(key0, params, ZERO_BYTES); - - IPoolManager.ModifyPositionParams memory params2 = IPoolManager.ModifyPositionParams(-60, 60, -10e18); - modifyPositionRouter.modifyPosition(key0, params2, ZERO_BYTES); - - // Fees dont accrue when key.fee does not specify a withdrawal param even if the protocol fee is set. - assertEq(manager.protocolFeesAccrued(currency0), 0); - assertEq(manager.protocolFeesAccrued(currency1), 0); - assertEq(manager.hookFeesAccrued(address(key0.hooks), currency0), 0); - assertEq(manager.hookFeesAccrued(address(key0.hooks), currency1), 0); - } - - // function testFeeOutOfBoundsReverts(uint16 newFee) external { - // newFee = uint16(bound(newFee, 2 ** 12, type(uint16).max)); - - // hook.setSwapFee(key0, newFee); - // vm.expectRevert(abi.encodeWithSelector(IFees.FeeDenominatorOutOfBounds.selector, newFee)); - // manager.setHookFees(key0); - - // hook.setWithdrawFee(key0, newFee); - // vm.expectRevert(abi.encodeWithSelector(IFees.FeeDenominatorOutOfBounds.selector, newFee)); - // manager.setHookFees(key0); - - // manager.setProtocolFeeController(IProtocolFeeController(feeController)); - - // feeController.setSwapFeeForPool(key0.toId(), newFee); - // vm.expectRevert(abi.encodeWithSelector(IFees.FeeDenominatorOutOfBounds.selector, newFee)); - // manager.setProtocolFees(key0); - - // feeController.setWithdrawFeeForPool(key0.toId(), newFee); - // vm.expectRevert(abi.encodeWithSelector(IFees.FeeDenominatorOutOfBounds.selector, newFee)); - // manager.setProtocolFees(key0); - // } - - function testHookWithdrawFeeProtocolWithdrawFee( - uint16 hookWithdrawFee, - uint8 protocolWithdrawFee0, - uint8 protocolWithdrawFee1 - ) public { - hookWithdrawFee = uint16(bound(hookWithdrawFee, 0, (2 ** 12) - 1)); - protocolWithdrawFee0 = uint8(bound(protocolWithdrawFee0, 4, (2 ** 6) - 1)); - protocolWithdrawFee1 = uint8(bound(protocolWithdrawFee1, 4, (2 ** 6) - 1)); - - uint16 protocolWithdrawFee = (uint16(protocolWithdrawFee0) << 6) | uint16(protocolWithdrawFee1); - - hook.setWithdrawFee(key1, hookWithdrawFee); - manager.setHookFees(key1); - - feeController.setWithdrawFeeForPool(key1.toId(), protocolWithdrawFee); - manager.setProtocolFees(key1); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key1.toId()); - - assertEq(getWithdrawFee(slot0.hookFees), hookWithdrawFee); - assertEq(getSwapFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.protocolFees), 0); - assertEq(getWithdrawFee(slot0.protocolFees), protocolWithdrawFee); - - int256 liquidityDelta = 10000; - // The underlying amount for a liquidity delta of 10000 is 29. - uint256 underlyingAmount0 = 29; - uint256 underlyingAmount1 = 29; - - IPoolManager.ModifyPositionParams memory params = IPoolManager.ModifyPositionParams(-60, 60, liquidityDelta); - BalanceDelta delta = modifyPositionRouter.modifyPosition(key1, params, ZERO_BYTES); - - // Fees dont accrue for positive liquidity delta. - assertEq(manager.protocolFeesAccrued(currency0), 0); - assertEq(manager.protocolFeesAccrued(currency1), 0); - assertEq(manager.hookFeesAccrued(address(key1.hooks), currency0), 0); - assertEq(manager.hookFeesAccrued(address(key1.hooks), currency1), 0); - - IPoolManager.ModifyPositionParams memory params2 = IPoolManager.ModifyPositionParams(-60, 60, -liquidityDelta); - delta = modifyPositionRouter.modifyPosition(key1, params2, ZERO_BYTES); - - uint16 hookFee0 = (hookWithdrawFee % 64); - uint16 hookFee1 = (hookWithdrawFee >> 6); - uint16 protocolFee0 = (protocolWithdrawFee % 64); - uint16 protocolFee1 = (protocolWithdrawFee >> 6); - - // Fees should accrue to both the protocol and hook. - uint256 initialHookAmount0 = hookFee0 == 0 ? 0 : underlyingAmount0 / hookFee0; - uint256 initialHookAmount1 = hookFee1 == 0 ? 0 : underlyingAmount1 / hookFee1; - - uint256 expectedProtocolAmount0 = protocolFee0 == 0 ? 0 : initialHookAmount0 / protocolFee0; - uint256 expectedProtocolAmount1 = protocolFee1 == 0 ? 0 : initialHookAmount1 / protocolFee1; - // Adjust the hook fee amounts after the protocol fee is taken. - uint256 expectedHookFee0 = initialHookAmount0 - expectedProtocolAmount0; - uint256 expectedHookFee1 = initialHookAmount1 - expectedProtocolAmount1; - - assertEq(manager.protocolFeesAccrued(currency0), expectedProtocolAmount0); - assertEq(manager.protocolFeesAccrued(currency1), expectedProtocolAmount1); - assertEq(manager.hookFeesAccrued(address(key1.hooks), currency0), expectedHookFee0); - assertEq(manager.hookFeesAccrued(address(key1.hooks), currency1), expectedHookFee1); - } - - function testNoHookProtocolFee( - uint8 protocolSwapFee0, - uint8 protocolSwapFee1, - uint8 protocolWithdrawFee0, - uint8 protocolWithdrawFee1 - ) public { - protocolWithdrawFee0 = uint8(bound(protocolWithdrawFee0, 4, (2 ** 6) - 1)); - protocolWithdrawFee1 = uint8(bound(protocolWithdrawFee1, 4, (2 ** 6) - 1)); - protocolSwapFee0 = uint8(bound(protocolSwapFee0, 4, (2 ** 6) - 1)); - protocolSwapFee1 = uint8(bound(protocolSwapFee1, 4, (2 ** 6) - 1)); - - uint16 protocolWithdrawFee = (uint16(protocolWithdrawFee0) << 6) | uint16(protocolWithdrawFee1); - uint16 protocolSwapFee = (uint16(protocolSwapFee1) << 6) | uint16(protocolSwapFee0); - - feeController.setSwapFeeForPool(key3.toId(), protocolSwapFee); - feeController.setWithdrawFeeForPool(key3.toId(), protocolWithdrawFee); - manager.setProtocolFees(key3); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key3.toId()); - assertEq(getWithdrawFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.hookFees), 0); - assertEq(getSwapFee(slot0.protocolFees), protocolSwapFee); - assertEq(getWithdrawFee(slot0.protocolFees), protocolWithdrawFee); - - int256 liquidityDelta = 10000; - IPoolManager.ModifyPositionParams memory params = IPoolManager.ModifyPositionParams(-60, 60, liquidityDelta); - modifyPositionRouter.modifyPosition(key3, params, ZERO_BYTES); - - // Fees dont accrue for positive liquidity delta. - assertEq(manager.protocolFeesAccrued(currency0), 0); - assertEq(manager.protocolFeesAccrued(currency1), 0); - assertEq(manager.hookFeesAccrued(address(key3.hooks), currency0), 0); - assertEq(manager.hookFeesAccrued(address(key3.hooks), currency1), 0); - - IPoolManager.ModifyPositionParams memory params2 = IPoolManager.ModifyPositionParams(-60, 60, -liquidityDelta); - modifyPositionRouter.modifyPosition(key3, params2, ZERO_BYTES); - - // No fees should accrue bc there is no hook so the protocol cant take withdraw fees. - assertEq(manager.protocolFeesAccrued(currency0), 0); - assertEq(manager.protocolFeesAccrued(currency1), 0); - - // add larger liquidity - params = IPoolManager.ModifyPositionParams(-60, 60, 10e18); - modifyPositionRouter.modifyPosition(key3, params, ZERO_BYTES); - - MockERC20(Currency.unwrap(currency1)).approve(address(swapRouter), type(uint256).max); - swapRouter.swap( - key3, - IPoolManager.SwapParams(false, 10000, TickMath.MAX_SQRT_RATIO - 1), - PoolSwapTest.TestSettings(true, true, false), - ZERO_BYTES - ); - // key3 pool is 30 bps => 10000 * 0.003 (.3%) = 30 - uint256 expectedSwapFeeAccrued = 30; - - uint256 expectedProtocolAmount1 = protocolSwapFee1 == 0 ? 0 : expectedSwapFeeAccrued / protocolSwapFee1; - assertEq(manager.protocolFeesAccrued(currency0), 0); - assertEq(manager.protocolFeesAccrued(currency1), expectedProtocolAmount1); - } - - function testProtocolSwapFeeAndHookSwapFeeSameDirection() public { - uint16 protocolFee = _computeFee(_oneForZero, 10); // 10% on 1 to 0 swaps - feeController.setSwapFeeForPool(key0.toId(), protocolFee); - manager.setProtocolFees(key0); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.protocolFees), protocolFee); - assertEq(getWithdrawFee(slot0.protocolFees), 0); - - uint16 hookFee = _computeFee(_oneForZero, 5); // 20% on 1 to 0 swaps - hook.setSwapFee(key0, hookFee); - manager.setHookFees(key0); - (slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.hookFees), hookFee); - assertEq(getWithdrawFee(slot0.hookFees), 0); - - IPoolManager.ModifyPositionParams memory params = IPoolManager.ModifyPositionParams(-120, 120, 10e18); - modifyPositionRouter.modifyPosition(key0, params, ZERO_BYTES); - // 1 for 0 swap - MockERC20(Currency.unwrap(currency1)).approve(address(swapRouter), type(uint256).max); - swapRouter.swap( - key0, - IPoolManager.SwapParams(false, 10000, TickMath.MAX_SQRT_RATIO - 1), - PoolSwapTest.TestSettings(true, true, false), - ZERO_BYTES - ); - - assertEq(manager.protocolFeesAccrued(currency1), 3); // 10% of 30 is 3 - assertEq(manager.hookFeesAccrued(address(key0.hooks), currency1), 5); // 27 * .2 is 5.4 so 5 rounding down - } - - function testInitializeWithSwapProtocolFeeAndHookFeeDifferentDirections() public { - uint16 protocolFee = _computeFee(_oneForZero, 10); // 10% fee on 1 to 0 swaps - feeController.setSwapFeeForPool(key0.toId(), protocolFee); - manager.setProtocolFees(key0); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.protocolFees), protocolFee); - assertEq(getWithdrawFee(slot0.protocolFees), 0); - - uint16 hookFee = _computeFee(_zeroForOne, 5); // 20% on 0 to 1 swaps - - hook.setSwapFee(key0, hookFee); - manager.setHookFees(key0); - (slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.hookFees), hookFee); - assertEq(getWithdrawFee(slot0.hookFees), 0); - - IPoolManager.ModifyPositionParams memory params = IPoolManager.ModifyPositionParams(-120, 120, 10e18); - modifyPositionRouter.modifyPosition(key0, params, ZERO_BYTES); - // 1 for 0 swap - MockERC20(Currency.unwrap(currency1)).approve(address(swapRouter), type(uint256).max); - swapRouter.swap( - key0, - IPoolManager.SwapParams(false, 10000, TickMath.MAX_SQRT_RATIO - 1), - PoolSwapTest.TestSettings(true, true, false), - ZERO_BYTES - ); - - assertEq(manager.protocolFeesAccrued(currency1), 3); // 10% of 30 is 3 - assertEq(manager.hookFeesAccrued(address(key0.hooks), currency1), 0); // hook fee only taken on 0 to 1 swaps - } - - function testSwapWithProtocolFeeAllAndHookFeeAllButOnlySwapFlag() public { - // Protocol should not be able to withdraw since the hook withdraw fee is not set - uint16 protocolFee = _computeFee(_oneForZero, 4) | _computeFee(_zeroForOne, 4); // max fees on both amounts - feeController.setWithdrawFeeForPool(key0.toId(), protocolFee); // - manager.setProtocolFees(key0); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.protocolFees), 0); - assertEq(getWithdrawFee(slot0.protocolFees), protocolFee); // successfully sets the fee, but is never applied - - uint16 hookSwapFee = _computeFee(_oneForZero, 4); // 25% on 1 to 0 swaps - uint16 hookWithdrawFee = _computeFee(_oneForZero, 4) | _computeFee(_zeroForOne, 4); // max fees on both amounts - hook.setSwapFee(key0, hookSwapFee); - hook.setWithdrawFee(key0, hookWithdrawFee); - manager.setHookFees(key0); - (slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.hookFees), hookSwapFee); - assertEq(getWithdrawFee(slot0.hookFees), 0); // Even though the contract sets a withdraw fee it will not be applied bc the pool key.fee did not assert a withdraw flag. - - IPoolManager.ModifyPositionParams memory params = IPoolManager.ModifyPositionParams(-120, 120, 10e18); - modifyPositionRouter.modifyPosition(key0, params, ZERO_BYTES); - // 1 for 0 swap - MockERC20(Currency.unwrap(currency1)).approve(address(swapRouter), type(uint256).max); - swapRouter.swap( - key0, - IPoolManager.SwapParams(false, 10000, TickMath.MAX_SQRT_RATIO - 1), - PoolSwapTest.TestSettings(true, true, false), - ZERO_BYTES - ); - - assertEq(manager.protocolFeesAccrued(currency1), 0); // No protocol fee was accrued on swap - assertEq(manager.protocolFeesAccrued(currency0), 0); // No protocol fee was accrued on swap - assertEq(manager.hookFeesAccrued(address(key0.hooks), currency1), 7); // 25% on 1 to 0, 25% of 30 is 7.5 so 7 - - modifyPositionRouter.modifyPosition(key0, IPoolManager.ModifyPositionParams(-120, 120, -10e18), ZERO_BYTES); - - assertEq(manager.protocolFeesAccrued(currency1), 0); // No protocol fee was accrued on withdraw - assertEq(manager.protocolFeesAccrued(currency0), 0); // No protocol fee was accrued on withdraw - assertEq(manager.hookFeesAccrued(address(key0.hooks), currency1), 7); // Same amount of fees for hook. - assertEq(manager.hookFeesAccrued(address(key0.hooks), currency0), 0); // Same amount of fees for hook. - } - - function testCollectFees() public { - uint16 protocolFee = _computeFee(_oneForZero, 10); // 10% on 1 to 0 swaps - feeController.setSwapFeeForPool(key0.toId(), protocolFee); - manager.setProtocolFees(key0); - - (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.protocolFees), protocolFee); - - uint16 hookFee = _computeFee(_oneForZero, 5); // 20% on 1 to 0 swaps - hook.setSwapFee(key0, hookFee); - manager.setHookFees(key0); - - (slot0,,,) = manager.pools(key0.toId()); - assertEq(getSwapFee(slot0.hookFees), hookFee); - - IPoolManager.ModifyPositionParams memory params = IPoolManager.ModifyPositionParams(-120, 120, 10e18); - modifyPositionRouter.modifyPosition(key0, params, ZERO_BYTES); - // 1 for 0 swap - MockERC20(Currency.unwrap(currency1)).approve(address(swapRouter), type(uint256).max); - swapRouter.swap( - key0, - IPoolManager.SwapParams(false, 10000, TickMath.MAX_SQRT_RATIO - 1), - PoolSwapTest.TestSettings(true, true, false), - ZERO_BYTES - ); - - uint256 expectedProtocolFees = 3; // 10% of 30 is 3 - vm.prank(address(feeController)); - manager.collectProtocolFees(address(feeController), currency1, 0); - assertEq(currency1.balanceOf(address(feeController)), expectedProtocolFees); - - uint256 expectedHookFees = 5; // 20% of 27 (30-3) is 5.4, round down is 5 - vm.prank(address(hook)); - // Addr(0) recipient will be the hook. - manager.collectHookFees(address(hook), currency1, 0); - assertEq(currency1.balanceOf(address(hook)), expectedHookFees); - } - - // If zeroForOne is true, then value is set on the lower bits. If zeroForOne is false, then value is set on the higher bits. - function _computeFee(bool zeroForOne, uint16 value) internal pure returns (uint16 fee) { - if (zeroForOne) { - fee = value % 64; - } else { - fee = value << 6; - } - } - - function getSwapFee(uint24 feesStorage) internal pure returns (uint16) { - return uint16(feesStorage >> 12); - } - - function getWithdrawFee(uint24 feesStorage) internal pure returns (uint16) { - return uint16(feesStorage & 0xFFF); - } -} diff --git a/test/Hooks.t.sol b/test/Hooks.t.sol index db0707640..c768c5187 100644 --- a/test/Hooks.t.sol +++ b/test/Hooks.t.sol @@ -46,7 +46,7 @@ contract HooksTest is Test, Deployers, GasSnapshot { function test_initialize_succeedsWithHook() public { initializeRouter.initialize(uninitializedKey, SQRT_RATIO_1_1, new bytes(123)); - (uint160 sqrtPriceX96,,,) = manager.getSlot0(uninitializedKey.toId()); + (uint160 sqrtPriceX96,,) = manager.getSlot0(uninitializedKey.toId()); assertEq(sqrtPriceX96, SQRT_RATIO_1_1); assertEq(mockHooks.beforeInitializeData(), new bytes(123)); assertEq(mockHooks.afterInitializeData(), new bytes(123)); diff --git a/test/Pool.t.sol b/test/Pool.t.sol index f105a2bee..668e62cb4 100644 --- a/test/Pool.t.sol +++ b/test/Pool.t.sol @@ -17,27 +17,14 @@ contract PoolTest is Test { Pool.State state; - function testPoolInitialize(uint160 sqrtPriceX96, uint16 protocolFee, uint16 hookFee, uint24 dynamicFee) public { - protocolFee = uint16(bound(protocolFee, 0, (2 ** 12) - 1)); - hookFee = uint16(bound(hookFee, 0, (2 ** 12) - 1)); - + function testPoolInitialize(uint160 sqrtPriceX96, uint16 protocolFee, uint24 dynamicFee) public { if (sqrtPriceX96 < TickMath.MIN_SQRT_RATIO || sqrtPriceX96 >= TickMath.MAX_SQRT_RATIO) { vm.expectRevert(TickMath.InvalidSqrtRatio.selector); - state.initialize( - sqrtPriceX96, - _formatSwapAndWithdrawFee(protocolFee, protocolFee), - _formatSwapAndWithdrawFee(hookFee, hookFee), - dynamicFee - ); + state.initialize(sqrtPriceX96, protocolFee, dynamicFee); } else { - state.initialize( - sqrtPriceX96, - _formatSwapAndWithdrawFee(protocolFee, protocolFee), - _formatSwapAndWithdrawFee(hookFee, hookFee), - dynamicFee - ); + state.initialize(sqrtPriceX96, protocolFee, dynamicFee); assertEq(state.slot0.sqrtPriceX96, sqrtPriceX96); - assertEq(state.slot0.protocolFees >> 12, protocolFee); + assertEq(state.slot0.protocolFee, protocolFee); assertEq(state.slot0.tick, TickMath.getTickAtSqrtRatio(sqrtPriceX96)); assertLt(state.slot0.tick, TickMath.MAX_TICK); assertGt(state.slot0.tick, TickMath.MIN_TICK - 1); @@ -48,7 +35,7 @@ contract PoolTest is Test { // Assumptions tested in PoolManager.t.sol params.tickSpacing = int24(bound(params.tickSpacing, TickMath.MIN_TICK_SPACING, TickMath.MAX_TICK_SPACING)); - testPoolInitialize(sqrtPriceX96, 0, 0, 0); + testPoolInitialize(sqrtPriceX96, 0, 0); if (params.tickLower >= params.tickUpper) { vm.expectRevert(abi.encodeWithSelector(Pool.TicksMisordered.selector, params.tickLower, params.tickUpper)); @@ -94,7 +81,7 @@ contract PoolTest is Test { params.tickSpacing = int24(bound(params.tickSpacing, TickMath.MIN_TICK_SPACING, TickMath.MAX_TICK_SPACING)); swapFee = uint24(bound(swapFee, 0, 999999)); - testPoolInitialize(sqrtPriceX96, 0, 0, 0); + testPoolInitialize(sqrtPriceX96, 0, 0); Pool.Slot0 memory slot0 = state.slot0; if (params.amountSpecified == 0) { @@ -129,8 +116,4 @@ contract PoolTest is Test { assertGe(state.slot0.sqrtPriceX96, params.sqrtPriceLimitX96); } } - - function _formatSwapAndWithdrawFee(uint16 swapFee, uint16 withdrawFee) internal pure returns (uint24) { - return (uint24(swapFee) << 12) | withdrawFee; - } } diff --git a/test/PoolManager.t.sol b/test/PoolManager.t.sol index ec283958e..b7cec77df 100644 --- a/test/PoolManager.t.sol +++ b/test/PoolManager.t.sol @@ -628,6 +628,45 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { snapEnd(); } + function test_swap_accruesProtocolFees(uint8 protocolFee1, uint8 protocolFee0) public { + protocolFee0 = uint8(bound(protocolFee0, 4, type(uint8).max)); + protocolFee1 = uint8(bound(protocolFee1, 4, type(uint8).max)); + + uint16 protocolFee = (uint16(protocolFee1) << 8) | (uint16(protocolFee0) & uint16(0xFF)); + + feeController.setSwapFeeForPool(key.toId(), protocolFee); + manager.setProtocolFee(key); + + (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); + assertEq(slot0.protocolFee, protocolFee); + + // Add liquidity - Fees dont accrue for positive liquidity delta. + IPoolManager.ModifyPositionParams memory params = LIQ_PARAMS; + modifyPositionRouter.modifyPosition(key, params, ZERO_BYTES); + + assertEq(manager.protocolFeesAccrued(currency0), 0); + assertEq(manager.protocolFeesAccrued(currency1), 0); + + // Remove liquidity - Fees dont accrue for negative liquidity delta. + params.liquidityDelta = -LIQ_PARAMS.liquidityDelta; + modifyPositionRouter.modifyPosition(key, params, ZERO_BYTES); + + assertEq(manager.protocolFeesAccrued(currency0), 0); + assertEq(manager.protocolFeesAccrued(currency1), 0); + + // Now re-add the liquidity to test swap + params.liquidityDelta = LIQ_PARAMS.liquidityDelta; + modifyPositionRouter.modifyPosition(key, params, ZERO_BYTES); + + IPoolManager.SwapParams memory swapParams = IPoolManager.SwapParams(false, 10000, TickMath.MAX_SQRT_RATIO - 1); + swapRouter.swap(key, swapParams, PoolSwapTest.TestSettings(true, true, false), ZERO_BYTES); + + uint256 expectedTotalSwapFee = uint256(swapParams.amountSpecified) * key.fee / 1e6; + uint256 expectedProtocolFee = expectedTotalSwapFee / protocolFee1; + assertEq(manager.protocolFeesAccrued(currency0), 0); + assertEq(manager.protocolFeesAccrued(currency1), expectedProtocolFee); + } + function test_donate_failsIfNotInitialized() public { vm.expectRevert(abi.encodeWithSelector(Pool.PoolNotInitialized.selector)); donateRouter.donate(uninitializedKey, 100, 100, ZERO_BYTES); @@ -754,59 +793,67 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { takeRouter.take{value: 1}(nativeKey, 1, 1); // assertions inside takeRouter because it takes then settles } - function test_setProtocolFee_updatesProtocolFeeForInitializedPool() public { - uint24 protocolFee = 4; - + function test_setProtocolFee_updatesProtocolFeeForInitializedPool(uint16 protocolFee) public { (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFees, 0); - feeController.setSwapFeeForPool(key.toId(), uint16(protocolFee)); - - vm.expectEmit(false, false, false, true); - emit ProtocolFeeUpdated(key.toId(), protocolFee << 12); - manager.setProtocolFees(key); + assertEq(slot0.protocolFee, 0); + feeController.setSwapFeeForPool(key.toId(), protocolFee); + + uint8 fee0 = uint8(protocolFee >> 8); + uint8 fee1 = uint8(protocolFee % 256); + if ((0 < fee0 && fee0 < 4) || (0 < fee1 && fee1 < 4)) { + vm.expectRevert(IFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); + manager.setProtocolFee(key); + } else { + vm.expectEmit(false, false, false, true); + emit ProtocolFeeUpdated(key.toId(), protocolFee); + manager.setProtocolFee(key); + + (slot0,,,) = manager.pools(key.toId()); + assertEq(slot0.protocolFee, protocolFee); + } } function test_setProtocolFee_failsWithInvalidProtocolFeeControllers() public { (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFees, 0); + assertEq(slot0.protocolFee, 0); manager.setProtocolFeeController(revertingFeeController); vm.expectRevert(IFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); - manager.setProtocolFees(key); + manager.setProtocolFee(key); manager.setProtocolFeeController(outOfBoundsFeeController); vm.expectRevert(IFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); - manager.setProtocolFees(key); + manager.setProtocolFee(key); manager.setProtocolFeeController(overflowFeeController); vm.expectRevert(IFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); - manager.setProtocolFees(key); + manager.setProtocolFee(key); manager.setProtocolFeeController(invalidReturnSizeFeeController); vm.expectRevert(IFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); - manager.setProtocolFees(key); + manager.setProtocolFee(key); } function test_collectProtocolFees_initializesWithProtocolFeeIfCalled() public { - uint24 protocolFee = 260; // 0001 00 00 0100 + uint16 protocolFee = 1028; // 00000100 00000100 // sets the upper 12 bits feeController.setSwapFeeForPool(uninitializedKey.toId(), uint16(protocolFee)); initializeRouter.initialize(uninitializedKey, SQRT_RATIO_1_1, ZERO_BYTES); (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFees, protocolFee << 12); + assertEq(slot0.protocolFee, protocolFee); } function test_collectProtocolFees_ERC20_allowsOwnerToAccumulateFees_gas() public { - uint24 protocolFee = 260; // 0001 00 00 0100 + uint16 protocolFee = 1028; // 00000100 00000100 uint256 expectedFees = 7; feeController.setSwapFeeForPool(key.toId(), uint16(protocolFee)); - manager.setProtocolFees(key); + manager.setProtocolFee(key); (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFees, protocolFee << 12); + assertEq(slot0.protocolFee, protocolFee); swapRouter.swap( key, @@ -1041,14 +1088,14 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { } function test_collectProtocolFees_ERC20_returnsAllFeesIf0IsProvidedAsParameter() public { - uint24 protocolFee = 260; // 0001 00 00 0100 + uint16 protocolFee = 1028; // 00000100 00000100 uint256 expectedFees = 7; feeController.setSwapFeeForPool(key.toId(), uint16(protocolFee)); - manager.setProtocolFees(key); + manager.setProtocolFee(key); (Pool.Slot0 memory slot0,,,) = manager.pools(key.toId()); - assertEq(slot0.protocolFees, protocolFee << 12); + assertEq(slot0.protocolFee, protocolFee); swapRouter.swap( key, @@ -1066,16 +1113,16 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { } function test_collectProtocolFees_nativeToken_allowsOwnerToAccumulateFees_gas() public { - uint24 protocolFee = 260; // 0001 00 00 0100 + uint16 protocolFee = 1028; // 00000100 00000100 uint256 expectedFees = 7; Currency nativeCurrency = CurrencyLibrary.NATIVE; // set protocol fee before initializing the pool as it is fetched on initialization feeController.setSwapFeeForPool(nativeKey.toId(), uint16(protocolFee)); - manager.setProtocolFees(nativeKey); + manager.setProtocolFee(nativeKey); (Pool.Slot0 memory slot0,,,) = manager.pools(nativeKey.toId()); - assertEq(slot0.protocolFees, protocolFee << 12); + assertEq(slot0.protocolFee, protocolFee); swapRouter.swap{value: 10000}( nativeKey, @@ -1095,15 +1142,15 @@ contract PoolManagerTest is Test, Deployers, GasSnapshot { } function test_collectProtocolFees_nativeToken_returnsAllFeesIf0IsProvidedAsParameter() public { - uint24 protocolFee = 260; // 0001 00 00 0100 + uint16 protocolFee = 1028; // 00000100 00000100 uint256 expectedFees = 7; Currency nativeCurrency = CurrencyLibrary.NATIVE; feeController.setSwapFeeForPool(nativeKey.toId(), uint16(protocolFee)); - manager.setProtocolFees(nativeKey); + manager.setProtocolFee(nativeKey); (Pool.Slot0 memory slot0,,,) = manager.pools(nativeKey.toId()); - assertEq(slot0.protocolFees, protocolFee << 12); + assertEq(slot0.protocolFee, protocolFee); swapRouter.swap{value: 10000}( nativeKey, diff --git a/test/PoolManagerInitialize.t.sol b/test/PoolManagerInitialize.t.sol index be8c4e27e..313e439d7 100644 --- a/test/PoolManagerInitialize.t.sol +++ b/test/PoolManagerInitialize.t.sol @@ -82,7 +82,7 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { (Pool.Slot0 memory slot0,,,) = manager.pools(key0.toId()); assertEq(slot0.sqrtPriceX96, sqrtPriceX96); - assertEq(slot0.protocolFees, 0); + assertEq(slot0.protocolFee, 0); } } @@ -104,7 +104,7 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); assertEq(slot0.sqrtPriceX96, sqrtPriceX96); - assertEq(slot0.protocolFees >> 12, 0); + assertEq(slot0.protocolFee, 0); assertEq(slot0.tick, TickMath.getTickAtSqrtRatio(sqrtPriceX96)); } @@ -198,19 +198,22 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { initializeRouter.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); } - function test_initialize_fetchFeeWhenController(uint160 sqrtPriceX96) public { - // Assumptions tested in Pool.t.sol - sqrtPriceX96 = uint160(bound(sqrtPriceX96, TickMath.MIN_SQRT_RATIO, TickMath.MAX_SQRT_RATIO - 1)); - + function test_initialize_fetchFeeWhenController(uint16 protocolFee) public { manager.setProtocolFeeController(feeController); - uint16 poolProtocolFee = 4; - feeController.setSwapFeeForPool(uninitializedKey.toId(), poolProtocolFee); + feeController.setSwapFeeForPool(uninitializedKey.toId(), protocolFee); - initializeRouter.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); + uint8 fee0 = uint8(protocolFee >> 8); + uint8 fee1 = uint8(protocolFee % 256); + + initializeRouter.initialize(uninitializedKey, SQRT_RATIO_1_1, ZERO_BYTES); (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.sqrtPriceX96, sqrtPriceX96); - assertEq(slot0.protocolFees >> 12, poolProtocolFee); + assertEq(slot0.sqrtPriceX96, SQRT_RATIO_1_1); + if ((0 < fee0 && fee0 < 4) || (0 < fee1 && fee1 < 4)) { + assertEq(slot0.protocolFee, 0); + } else { + assertEq(slot0.protocolFee, protocolFee); + } } function test_initialize_revertsWhenPoolAlreadyInitialized(uint160 sqrtPriceX96) public { @@ -317,11 +320,10 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { initializeRouter.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); // protocol fees should default to 0 (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFees >> 12, 0); - assertEq(slot0.protocolFees & 0xFFF, 0); - // call to setProtocolFees should also revert + assertEq(slot0.protocolFee, 0); + // call to setProtocolFee should also revert vm.expectRevert(IFees.ProtocolFeeControllerCallFailedOrInvalidResult.selector); - manager.setProtocolFees(uninitializedKey); + manager.setProtocolFee(uninitializedKey); } function test_initialize_succeedsWithRevertingFeeController(uint160 sqrtPriceX96) public { @@ -342,8 +344,7 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { initializeRouter.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); // protocol fees should default to 0 (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFees >> 12, 0); - assertEq(slot0.protocolFees & 0xFFF, 0); + assertEq(slot0.protocolFee, 0); } function test_initialize_succeedsWithOverflowFeeController(uint160 sqrtPriceX96) public { @@ -364,8 +365,7 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { initializeRouter.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); // protocol fees should default to 0 (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFees >> 12, 0); - assertEq(slot0.protocolFees & 0xFFF, 0); + assertEq(slot0.protocolFee, 0); } function test_initialize_succeedsWithWrongReturnSizeFeeController(uint160 sqrtPriceX96) public { @@ -386,41 +386,7 @@ contract PoolManagerInitializeTest is Test, Deployers, GasSnapshot { initializeRouter.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); // protocol fees should default to 0 (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.protocolFees >> 12, 0); - assertEq(slot0.protocolFees & 0xFFF, 0); - } - - function test_initialize_succeedsAndSetsHookFeeIfControllerReverts(uint160 sqrtPriceX96) public { - // Assumptions tested in Pool.t.sol - sqrtPriceX96 = uint160(bound(sqrtPriceX96, TickMath.MIN_SQRT_RATIO, TickMath.MAX_SQRT_RATIO - 1)); - - address hookAddr = address(99); // can't be a zero address, but does not have to have any other hook flags specified - MockHooks impl = new MockHooks(); - vm.etch(hookAddr, address(impl).code); - MockHooks hook = MockHooks(hookAddr); - - uninitializedKey = PoolKey({ - currency0: currency0, - currency1: currency1, - fee: FeeLibrary.HOOK_SWAP_FEE_FLAG | uint24(3000), - hooks: hook, - tickSpacing: 60 - }); - - manager.setProtocolFeeController(revertingFeeController); - // expect initialize to succeed even though the controller reverts - initializeRouter.initialize(uninitializedKey, sqrtPriceX96, ZERO_BYTES); - (Pool.Slot0 memory slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(slot0.sqrtPriceX96, sqrtPriceX96); - // protocol fees should default to 0 - assertEq(slot0.protocolFees >> 12, 0); - // hook fees can still be set - assertEq(uint16(slot0.hookFees >> 12), 0); - hook.setSwapFee(uninitializedKey, 3000); - manager.setHookFees(uninitializedKey); - - (slot0,,,) = manager.pools(uninitializedKey.toId()); - assertEq(uint16(slot0.hookFees >> 12), 3000); + assertEq(slot0.protocolFee, 0); } function test_initialize_gas() public {