Skip to content

Commit

Permalink
fix(YAUDIT-COVE-18): reverse traversing in previewMints and `previe…
Browse files Browse the repository at this point in the history
…wWithdraws` (#298)

* fix(YAUDIT-COVE-18): change ordering of preview mint to be reversed

* fix: reverse traversing in previewWithdraw

---------

Co-authored-by: Sunil Srivatsa <[email protected]>
  • Loading branch information
penandlim and alphastorm authored Mar 28, 2024
1 parent 637e7ed commit 3a40ac3
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 22 deletions.
39 changes: 22 additions & 17 deletions src/Yearn4626RouterExt.sol
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ contract Yearn4626RouterExt is IYearn4626RouterExt, Yearn4626Router {
* @param path The array of addresses that represents the path from input to output.
* @param sharesOut The amount of shares to mint from the last vault.
* @return assetsIn The amount of assets required at each step. The length of the array is `path.length - 1`.
* @dev sharesOut is the expected result at the last vault, and the path = [tokenIn, vault0, vault1, ..., vaultN].
* First calculate the amount of assets in to get the desired sharesOut from the last vault, then using that amount
* as the next sharesOut to get the amount of assets in for the penultimate vault.
*/
function previewMints(
address[] calldata path,
Expand All @@ -272,21 +275,21 @@ contract Yearn4626RouterExt is IYearn4626RouterExt, Yearn4626Router {
if (path.length < 2) revert PreviewPathIsTooShort();
uint256 assetsInLength = path.length - 1;
assetsIn = new uint256[](assetsInLength);
for (uint256 i; i < assetsInLength;) {
address vault = path[i + 1];
for (uint256 i = assetsInLength; i > 0;) {
address vault = path[i];
if (!Address.isContract(vault)) {
revert PreviewNonVaultAddressInPath(vault);
}
address vaultAsset = address(0);
(bool success, bytes memory data) = vault.staticcall(abi.encodeCall(IERC4626.asset, ()));
if (success) {
vaultAsset = abi.decode(data, (address));
assetsIn[i] = IERC4626(vault).previewMint(sharesOut);
assetsIn[i - 1] = IERC4626(vault).previewMint(sharesOut);
} else {
(success, data) = vault.staticcall(abi.encodeCall(IYearnVaultV2.token, ()));
if (success) {
vaultAsset = abi.decode(data, (address));
assetsIn[i] = Math.mulDiv(
assetsIn[i - 1] = Math.mulDiv(
sharesOut,
IYearnVaultV2(vault).pricePerShare(),
10 ** IERC20Metadata(vault).decimals(),
Expand All @@ -297,26 +300,28 @@ contract Yearn4626RouterExt is IYearn4626RouterExt, Yearn4626Router {
}
}

if (vaultAsset != path[i]) {
if (vaultAsset != path[i - 1]) {
revert PreviewVaultMismatch();
}
sharesOut = assetsIn[i];
sharesOut = assetsIn[i - 1];

/// @dev Increment the loop counter within an unchecked block to avoid redundant gas cost associated with
/// overflow checking. This is safe because the loop's exit condition ensures that `i` will not exceed
/// `assetsInLength - 1`, preventing overflow.
/// @dev Decrement the loop counter within an unchecked block to avoid redundant gas cost associated with
/// underflow checking. This is safe because the loop's initialization and exit condition ensure that `i`
/// will not underflow.
unchecked {
++i;
--i;
}
}
}

/**
* @notice Calculate the amount of shares required to withdraw a given amount of assets from a series of withdraws
* from
* ERC4626 vaults or Yearn Vault V2.
* from ERC4626 vaults or Yearn Vault V2.
* @param path The array of addresses that represents the path from input to output.
* @param assetsOut The amount of assets to withdraw from the last vault.
* @dev assetsOut is the desired result of the output token, and the path = [vault0, vault1, ..., vaultN, tokenOut].
* First calculate the amount of shares in to get the desired assetsOut from the last vault, then using that amount
* as the next assetsOut to get the amount of shares in for the penultimate vault.
* @return sharesIn The amount of shares required at each step. The length of the array is `path.length - 1`.
*/
function previewWithdraws(
Expand All @@ -330,7 +335,7 @@ contract Yearn4626RouterExt is IYearn4626RouterExt, Yearn4626Router {
if (path.length < 2) revert PreviewPathIsTooShort();
uint256 sharesInLength = path.length - 1;
sharesIn = new uint256[](sharesInLength);
for (uint256 i; i < sharesInLength;) {
for (uint256 i = path.length - 2;;) {
address vault = path[i];
if (!Address.isContract(vault)) {
revert PreviewNonVaultAddressInPath(vault);
Expand Down Expand Up @@ -365,13 +370,13 @@ contract Yearn4626RouterExt is IYearn4626RouterExt, Yearn4626Router {
if (vaultAsset != path[i + 1]) {
revert PreviewVaultMismatch();
}
if (i == 0) return sharesIn;
assetsOut = sharesIn[i];

/// @dev Increment the loop counter without checking for overflow. This is safe because the for loop
/// naturally ensures that `i` will not overflow as it is bounded by `sharesInLength`, which is derived from
/// the length of the `path` array.
/// @dev Decrement the loop counter without checking for overflow. This is safe because the for loop
/// naturally ensures that `i` will not underflow as it is bounded by i == 0 check.
unchecked {
++i;
--i;
}
}
}
Expand Down
96 changes: 91 additions & 5 deletions test/forked/Router.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import { IPermit2 } from "permit2/interfaces/IPermit2.sol";
import { IERC20Permit } from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Permit.sol";
import { IStakeDaoGauge } from "src/interfaces/deps/stakeDAO/IStakeDaoGauge.sol";
import { ERC20 } from "solmate/tokens/ERC20.sol";
import { ERC4626Mock } from "@openzeppelin/contracts/mocks/ERC4626Mock.sol";
import { ERC20Mock } from "@openzeppelin/contracts/mocks/ERC20Mock.sol";

contract Router_ForkedTest is BaseTest {
Yearn4626RouterExt public router;
Expand Down Expand Up @@ -305,7 +307,48 @@ contract Router_ForkedTest is BaseTest {
uint256[] memory assetsIn = router.previewMints(path, shareOutAmount);
assertEq(assetsIn.length, 2);
assertEq(assetsIn[0], 1 ether);
assertEq(assetsIn[1], 1 ether);
assertEq(assetsIn[1], 949_289_266_142_683_599);
}

function test_previewMints_Multiple4626() public {
ERC20Mock baseAsset = new ERC20Mock();
ERC4626Mock mock1 = new ERC4626Mock(address(baseAsset));
ERC4626Mock mock2 = new ERC4626Mock(address(mock1));
ERC4626Mock mock3 = new ERC4626Mock(address(mock2));

baseAsset.mint(address(this), 10e18);
baseAsset.approve(address(mock1), 10e18);
mock1.approve(address(mock2), 10e18);
mock2.approve(address(mock3), 10e18);

mock1.deposit(2e18, address(this));
baseAsset.transfer(address(mock1), 1e18);

mock2.deposit(1e18, address(this));
mock1.transfer(address(mock2), 1e18);

mock3.deposit(0.5e18, address(this));
mock2.transfer(address(mock3), 0.5e18);

uint256 expectedAssetIn2 = mock3.previewMint(1e18);
uint256 expectedAssetIn1 = mock2.previewMint(expectedAssetIn2);
uint256 expectedBaseAssetIn = mock1.previewMint(expectedAssetIn1);

assertEq(expectedBaseAssetIn, 5_999_999_999_999_999_995);
assertEq(expectedAssetIn1, 3_999_999_999_999_999_997);
assertEq(expectedAssetIn2, 1_999_999_999_999_999_999);

address[] memory path = new address[](4);
path[0] = address(baseAsset);
path[1] = address(mock1);
path[2] = address(mock2);
path[3] = address(mock3);

uint256[] memory assetsIn = router.previewMints(path, 1e18);
assertEq(assetsIn.length, 3);
assertEq(assetsIn[0], expectedBaseAssetIn);
assertEq(assetsIn[1], expectedAssetIn1);
assertEq(assetsIn[2], expectedAssetIn2);
}

function test_previewMints_v2Vault() public {
Expand All @@ -314,9 +357,9 @@ contract Router_ForkedTest is BaseTest {
path[0] = MAINNET_USDC;
path[1] = MAINNET_YVUSDC_VAULT_V2;

uint256[] memory sharesOut = router.previewMints(path, assetInAmount);
assertEq(sharesOut.length, 1);
assertEq(sharesOut[0], 1_058_324);
uint256[] memory assetsIn = router.previewMints(path, assetInAmount);
assertEq(assetsIn.length, 1);
assertEq(assetsIn[0], 1_058_324);
}

function test_previewMints_revertWhen_PreviewPathIsTooShort() public {
Expand Down Expand Up @@ -368,10 +411,53 @@ contract Router_ForkedTest is BaseTest {

uint256[] memory sharesIn = router.previewWithdraws(path, assetOutAmount);
assertEq(sharesIn.length, 2);
assertEq(sharesIn[0], 999_999_999_999_999_999);
assertEq(sharesIn[0], 949_289_266_142_683_600);
assertEq(sharesIn[1], 949_289_266_142_683_600);
}

function test_previewWithdraws_Multiple4626() public {
// asset to vault flow:
// baseAsset -> mock1 -> mock2 -> mock3
ERC20Mock baseAsset = new ERC20Mock();
ERC4626Mock mock1 = new ERC4626Mock(address(baseAsset));
ERC4626Mock mock2 = new ERC4626Mock(address(mock1));
ERC4626Mock mock3 = new ERC4626Mock(address(mock2));

baseAsset.mint(address(this), 10e18);
baseAsset.approve(address(mock1), 10e18);
mock1.approve(address(mock2), 10e18);
mock2.approve(address(mock3), 10e18);

mock1.deposit(2e18, address(this));
baseAsset.transfer(address(mock1), 1e18);

mock2.deposit(1e18, address(this));
mock1.transfer(address(mock2), 1e18);

mock3.deposit(0.5e18, address(this));
mock2.transfer(address(mock3), 0.5e18);

uint256 expectedShareIn1 = mock1.previewWithdraw(1e18);
uint256 expectedShareIn2 = mock2.previewWithdraw(expectedShareIn1);
uint256 expectedShareIn3 = mock3.previewWithdraw(expectedShareIn2);

assertEq(expectedShareIn3, 166_666_666_666_666_668);
assertEq(expectedShareIn2, 333_333_333_333_333_334);
assertEq(expectedShareIn1, 666_666_666_666_666_667);

address[] memory path = new address[](4);
path[0] = address(mock3);
path[1] = address(mock2);
path[2] = address(mock1);
path[3] = address(baseAsset);

uint256[] memory sharesIn = router.previewWithdraws(path, 1e18);
assertEq(sharesIn.length, 3);
assertEq(sharesIn[0], expectedShareIn3);
assertEq(sharesIn[1], expectedShareIn2);
assertEq(sharesIn[2], expectedShareIn1);
}

function test_previewWithdraws_v2Vault() public {
uint256 assetInAmount = 1e6;
address[] memory path = new address[](2);
Expand Down

0 comments on commit 3a40ac3

Please sign in to comment.