diff --git a/src/balancer/BalancerController.sol b/src/balancer/BalancerController.sol index 4ed940c..8fb2ceb 100644 --- a/src/balancer/BalancerController.sol +++ b/src/balancer/BalancerController.sol @@ -177,21 +177,35 @@ contract BalancerController is IController { returns (bool, address[] memory, address[] memory) { ( - , + IVault.SwapKind kind, IVault.BatchSwapStep[] memory swaps, IAsset[] memory assets, , , ) = abi.decode(data, ( - uint8, IVault.BatchSwapStep[], IAsset[], IVault.FundManagement, uint256[], uint256 + IVault.SwapKind, + IVault.BatchSwapStep[], + IAsset[], + IVault.FundManagement, + uint256[], + uint256 ) ); - if (!isMultiHopSwap(swaps)) - return (false, new address[](0), new address[](0)); - - uint tokenInIndex = swaps[swaps.length - 1].assetOutIndex; - uint tokenOutIndex = swaps[0].assetInIndex; + uint tokenInIndex; + uint tokenOutIndex; + + if (kind == IVault.SwapKind.GIVEN_IN) { + if (!isMultiHopSwapGivenIn(swaps)) + return (false, new address[](0), new address[](0)); + tokenInIndex = swaps[swaps.length - 1].assetOutIndex; + tokenOutIndex = swaps[0].assetInIndex; + } else { + if (!isMultiHopSwapGivenOut(swaps)) + return (false, new address[](0), new address[](0)); + tokenOutIndex = swaps[swaps.length - 1].assetInIndex; + tokenInIndex = swaps[0].assetOutIndex; + } address[] memory tokensIn; address[] memory tokensOut; @@ -228,14 +242,33 @@ contract BalancerController is IController { ); } - function isMultiHopSwap(IVault.BatchSwapStep[] memory swaps) + function isMultiHopSwapGivenIn(IVault.BatchSwapStep[] memory swaps) + internal + pure + returns (bool) + { + uint steps = swaps.length; + for (uint i; i < steps - 1; i++) { + if ( + swaps[i].assetOutIndex != swaps[i+1].assetInIndex || + swaps[i+1].amount > 0 + ) + return false; + } + return true; + } + + function isMultiHopSwapGivenOut(IVault.BatchSwapStep[] memory swaps) internal pure returns (bool) { uint steps = swaps.length; for (uint i; i < steps - 1; i++) { - if (swaps[i].assetOutIndex != swaps[i+1].assetInIndex) + if ( + swaps[i].assetInIndex != swaps[i+1].assetOutIndex || + swaps[i+1].amount > 0 + ) return false; } return true; diff --git a/src/balancer/IVault.sol b/src/balancer/IVault.sol index 631c630..270d468 100644 --- a/src/balancer/IVault.sol +++ b/src/balancer/IVault.sol @@ -59,7 +59,7 @@ interface IVault { } function batchSwap( - uint8 kind, + SwapKind kind, BatchSwapStep[] memory swaps, IAsset[] memory assets, FundManagement memory funds, @@ -74,4 +74,6 @@ interface IVault { uint256 amount; bytes userData; } + + enum SwapKind { GIVEN_IN, GIVEN_OUT } } \ No newline at end of file diff --git a/src/tests/Balancer.t.sol b/src/tests/Balancer.t.sol index eddf6ad..b921b14 100644 --- a/src/tests/Balancer.t.sol +++ b/src/tests/Balancer.t.sol @@ -148,4 +148,161 @@ contract TestBalancer is TestBase { assertEq(tokensOut[0], token); assertEq(tokensIn.length, 1); } + + function testCanSwap() public { + // Setup + controllerFacade.toggleTokenAllowance(0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF); + + IVault.SingleSwap memory swap = IVault.SingleSwap( + "0", + 0, + IAsset(0xCfCA23cA9CA720B6E98E3Eb9B6aa0fFC4a5C08B9), + IAsset(0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF), + 0, + "0" + ); + + IVault.FundManagement memory funds = IVault.FundManagement( + address(0), + false, + payable(address(0)), + false + ); + + bytes memory data = abi.encodeWithSelector(0x52bbbe29, + swap, + funds, + 0, + 0 + ); + + // Test + (bool canCall, address[] memory tokensIn, address[] memory tokensOut) + = controllerFacade.canCall(vault, true, data); + + // Assert + assertTrue(canCall); + assertEq(tokensOut[0], 0xCfCA23cA9CA720B6E98E3Eb9B6aa0fFC4a5C08B9); + assertEq(tokensIn[0], 0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF); + assertEq(tokensIn.length, 1); + assertEq(tokensOut.length, 1); + } + + function testCanBatchSwapGivenIn() public { + // Setup + controllerFacade.toggleTokenAllowance(0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF); + + int256[] memory limits = new int256[](3); + + IAsset[] memory assets = new IAsset[](3); + assets[0] = IAsset(0xBA12222222228d8Ba445958a75a0704d566BF2C8); + assets[1] = IAsset(0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2); + assets[2] = IAsset(0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF); + + IVault.BatchSwapStep memory swap1 = IVault.BatchSwapStep( + "0", + 0, + 1, + 10, + "0" + ); + + IVault.BatchSwapStep memory swap2 = IVault.BatchSwapStep( + "0", + 1, + 2, + 0, + "0" + ); + + IVault.FundManagement memory funds = IVault.FundManagement( + address(0), + false, + payable(address(0)), + false + ); + + IVault.BatchSwapStep[] memory swaps = new IVault.BatchSwapStep[](2); + swaps[0] = swap1; + swaps[1] = swap2; + + bytes memory data = abi.encodeWithSelector(0x945bcec9, + IVault.SwapKind.GIVEN_IN, + swaps, + assets, + funds, + limits, + 0 + ); + + // Test + (bool canCall, address[] memory tokensIn, address[] memory tokensOut) + = controllerFacade.canCall(vault, true, data); + + // Assert + assertTrue(canCall); + assertEq(tokensOut[0], 0xBA12222222228d8Ba445958a75a0704d566BF2C8); + assertEq(tokensIn[0], 0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF); + assertEq(tokensIn.length, 1); + assertEq(tokensOut.length, 1); + } + + function testCanBatchSwapGivenOut() public { + // Setup + controllerFacade.toggleTokenAllowance(0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF); + + int256[] memory limits = new int256[](3); + + IAsset[] memory assets = new IAsset[](3); + assets[0] = IAsset(0xBA12222222228d8Ba445958a75a0704d566BF2C8); + assets[1] = IAsset(0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2); + assets[2] = IAsset(0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF); + + IVault.BatchSwapStep memory swap1 = IVault.BatchSwapStep( + "0", + 1, + 2, + 10, + "0" + ); + + IVault.BatchSwapStep memory swap2 = IVault.BatchSwapStep( + "0", + 0, + 1, + 0, + "0" + ); + + IVault.FundManagement memory funds = IVault.FundManagement( + address(0), + false, + payable(address(0)), + false + ); + + IVault.BatchSwapStep[] memory swaps = new IVault.BatchSwapStep[](2); + swaps[0] = swap1; + swaps[1] = swap2; + + bytes memory data = abi.encodeWithSelector(0x945bcec9, + IVault.SwapKind.GIVEN_OUT, + swaps, + assets, + funds, + limits, + 0 + ); + + // Test + (bool canCall, address[] memory tokensIn, address[] memory tokensOut) + = controllerFacade.canCall(vault, true, data); + + // Assert + assertTrue(canCall); + assertEq(tokensOut[0], 0xBA12222222228d8Ba445958a75a0704d566BF2C8); + assertEq(tokensIn[0], 0xC0c293ce456fF0ED870ADd98a0828Dd4d2903DBF); + assertEq(tokensIn.length, 1); + assertEq(tokensOut.length, 1); + } } \ No newline at end of file