Skip to content

Commit

Permalink
Fix RedeemOptimizerFIFO bug with partial withdraws
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasia committed Oct 6, 2024
1 parent 3092975 commit f7030f5
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 74 deletions.
1 change: 1 addition & 0 deletions packages/contracts/.gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Compiler files
cache/
out/
test-reports/

# Ignores development broadcast logs
!/broadcast
Expand Down
17 changes: 4 additions & 13 deletions packages/contracts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,10 @@ forge test
## Advanced Testing

### Code Coverage Summary
```bash
forge coverage
```

### Code Coverage Report
Pre-requisite: install [genhtml](https://manpages.ubuntu.com/manpages/focal/man1/genhtml.1.html)

```bash
forge coverage --report lcov

# to ignore errors, add flags "--ignore-errors inconsistent" or "--ignore-errors corrupt"
genhtml lcov.info -o out/coverage --ignore-errors inconsistent
```
Pre-requisite: install [genhtml](https://manpages.ubuntu.com/manpages/focal/man1/genhtml.1.html)
```bash
yarn coverage
```

## Reset Submodules

Expand Down
1 change: 1 addition & 0 deletions packages/contracts/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"dev": "yarn rm-dbdata && anvil --config-out localhost.json & make deploy-local",
"build": "forge build && yarn gen-types",
"test": "forge test",
"coverage": "forge coverage --report lcov && genhtml lcov.info -o out/test-reports/coverage --ignore-errors inconsistent",
"format": "forge fmt && prettier './script/**/*.js' --write",
"lint": "forge fmt && eslint --fix --ignore-path .gitignore && yarn solhint './*(test|src)/**/*.sol'",
"db-check": "tsc && node ./script/utils/checkDb.js",
Expand Down
58 changes: 37 additions & 21 deletions packages/contracts/src/token/ERC1155/RedeemOptimizerFIFO.sol
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ pragma solidity ^0.8.20;

import { IMultiTokenVault } from "@credbull/token/ERC1155/IMultiTokenVault.sol";
import { IRedeemOptimizer } from "@credbull/token/ERC1155/IRedeemOptimizer.sol";
import { Math } from "@openzeppelin/contracts/utils/math/Math.sol";

/**
* @title RedeemOptimizerFIFO
* @dev Optimizes the redemption of shares using a FIFO strategy.
*/
contract RedeemOptimizerFIFO is IRedeemOptimizer {
using Math for uint256;

error RedeemOptimizer__InvalidDepositPeriodRange(uint256 fromPeriod, uint256 toPeriod);
error RedeemOptimizer__FutureToDepositPeriod(uint256 toPeriod, uint256 currentPeriod);
error RedeemOptimizer__OptimizerFailed(uint256 amountFound, uint256 amountToFind);
Expand Down Expand Up @@ -59,50 +62,63 @@ contract RedeemOptimizerFIFO is IRedeemOptimizer {
}

/// @notice Returns deposit periods and corresponding amounts (shares or assets) within the specified range.
function _findAmount(IMultiTokenVault vault, OptimizerParams memory params)
function _findAmount(IMultiTokenVault vault, OptimizerParams memory optimizerParams)
internal
view
returns (uint256[] memory depositPeriods, uint256[] memory sharesAtPeriods)
{
if (params.fromDepositPeriod > params.toDepositPeriod) {
revert RedeemOptimizer__InvalidDepositPeriodRange(params.fromDepositPeriod, params.toDepositPeriod);
if (optimizerParams.fromDepositPeriod > optimizerParams.toDepositPeriod) {
revert RedeemOptimizer__InvalidDepositPeriodRange(
optimizerParams.fromDepositPeriod, optimizerParams.toDepositPeriod
);
}

if (params.toDepositPeriod > vault.currentPeriodsElapsed()) {
revert RedeemOptimizer__FutureToDepositPeriod(params.toDepositPeriod, vault.currentPeriodsElapsed());
if (optimizerParams.toDepositPeriod > vault.currentPeriodsElapsed()) {
revert RedeemOptimizer__FutureToDepositPeriod(
optimizerParams.toDepositPeriod, vault.currentPeriodsElapsed()
);
}

// Create local caching arrays that can contain the maximum number of results.
uint256[] memory cacheDepositPeriods = new uint256[]((params.toDepositPeriod - params.fromDepositPeriod) + 1);
uint256[] memory cacheSharesAtPeriods = new uint256[]((params.toDepositPeriod - params.fromDepositPeriod) + 1);
uint256[] memory cacheDepositPeriods =
new uint256[]((optimizerParams.toDepositPeriod - optimizerParams.fromDepositPeriod) + 1);
uint256[] memory cacheSharesAtPeriods =
new uint256[]((optimizerParams.toDepositPeriod - optimizerParams.fromDepositPeriod) + 1);

uint256 arrayIndex = 0;
uint256 amountFound = 0;

// Iterate over the from/to period range, inclusive of from and to.
for (uint256 depositPeriod = params.fromDepositPeriod; depositPeriod <= params.toDepositPeriod; ++depositPeriod)
{
uint256 sharesAtPeriod = vault.sharesAtPeriod(params.owner, depositPeriod);

uint256 amountAtPeriod = params.amountType == AmountType.Shares
for (
uint256 depositPeriod = optimizerParams.fromDepositPeriod;
depositPeriod <= optimizerParams.toDepositPeriod;
++depositPeriod
) {
uint256 sharesAtPeriod = vault.sharesAtPeriod(optimizerParams.owner, depositPeriod);

uint256 amountAtPeriod = optimizerParams.amountType == AmountType.Shares
? sharesAtPeriod
: vault.convertToAssetsForDepositPeriod(sharesAtPeriod, depositPeriod, params.redeemPeriod);
: vault.convertToAssetsForDepositPeriod(sharesAtPeriod, depositPeriod, optimizerParams.redeemPeriod);

// If there is an Amount, store the value.
if (amountAtPeriod > 0) {
cacheDepositPeriods[arrayIndex] = depositPeriod;

// check if we will go "over" the Amount To Find.
if (amountFound + amountAtPeriod > params.amountToFind) {
uint256 amountToInclude = params.amountToFind - amountFound; // we only need the amount that brings us to amountToFind
if (amountFound + amountAtPeriod > optimizerParams.amountToFind) {
uint256 amountToInclude = optimizerParams.amountToFind - amountFound; // we only need the amount that brings us to amountToFind

// the assets here include principal + interest. to utilize convertToShares() we want principal amount only
// the following ratio holds: partialShares/totalShares = partialReturns/totalReturns
// so, partialShares = ((partialReturns * totalShares) / totalReturns)
uint256 sharesToInclude = sharesAtPeriod.mulDiv(amountToInclude, amountAtPeriod); // also works for shares, when shares = amounts, sharesAtPeriod / sharesAtPeriod = 1

// only include equivalent amount of shares for the amountToInclude assets
cacheSharesAtPeriods[arrayIndex] = params.amountType == AmountType.Shares
? amountToInclude
: vault.convertToAssetsForDepositPeriod(amountToInclude, depositPeriod, params.redeemPeriod);
cacheSharesAtPeriods[arrayIndex] =
optimizerParams.amountType == AmountType.Shares ? amountToInclude : sharesToInclude;

// optimization succeeded - return here to be explicit we exit the function at this point
return _trimToSize(arrayIndex, cacheDepositPeriods, cacheSharesAtPeriods);
return _trimToSize(arrayIndex + 1, cacheDepositPeriods, cacheSharesAtPeriods);
} else {
cacheSharesAtPeriods[arrayIndex] = sharesAtPeriod;
}
Expand All @@ -112,8 +128,8 @@ contract RedeemOptimizerFIFO is IRedeemOptimizer {
}
}

if (amountFound < params.amountToFind) {
revert RedeemOptimizer__OptimizerFailed(amountFound, params.amountToFind);
if (amountFound < optimizerParams.amountToFind) {
revert RedeemOptimizer__OptimizerFailed(amountFound, optimizerParams.amountToFind);
}

return _trimToSize(arrayIndex, cacheDepositPeriods, cacheSharesAtPeriods);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ contract MultiTokenVaultTest is IMultiTokenVaultTestBase {

IMultiTokenVault vault = _createMultiTokenVault(_asset, assetToSharesRatio, 10);

uint256[] memory shares = _testDepositOnly(_alice, vault, testParamsArray.getAll());
uint256[] memory depositPeriods = testParamsArray.getAllDepositPeriods();
uint256[] memory shares = _testDepositOnly(_alice, vault, testParamsArray.all());
uint256[] memory depositPeriods = testParamsArray.depositPeriods();

// ------------------------ batch convert to assets ------------------------
uint256[] memory assets = vault.convertToAssetsForDepositPeriods(shares, depositPeriods, redeemPeriod);
Expand Down Expand Up @@ -244,8 +244,8 @@ contract MultiTokenVaultTest is IMultiTokenVaultTestBase {
IMTVTestParamArray testParamsArray,
uint256 assetToSharesRatio
) internal view returns (uint256[] memory balances_) {
address[] memory accounts = testParamsArray.accountArray(account, testParamsArray.length());
uint256[] memory balances = vault.balanceOfBatch(accounts, testParamsArray.getAllDepositPeriods());
address[] memory accounts = testParamsArray.createAccountArray(account, testParamsArray.length());
uint256[] memory balances = vault.balanceOfBatch(accounts, testParamsArray.depositPeriods());
assertEq(3, balances.length, "balances size incorrect");

assertEq(testParamsArray.get(0).principal / assetToSharesRatio, balances[0], "balance mismatch period 0");
Expand Down
116 changes: 87 additions & 29 deletions packages/contracts/test/src/token/ERC1155/RedeemOptimizerTest.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,121 @@ import { RedeemOptimizerFIFO } from "@credbull/token/ERC1155/RedeemOptimizerFIFO
import { IMultiTokenVault } from "@credbull/token/ERC1155/IMultiTokenVault.sol";

import { MultiTokenVaultTest } from "@test/src/token/ERC1155/MultiTokenVaultTest.t.sol";
import { IMTVTestParamArray } from "@test/test/token/ERC1155/IMTVTestParamArray.t.sol";

contract RedeemOptimizerTest is MultiTokenVaultTest {
address private _owner = makeAddr("owner");
address private _alice = makeAddr("alice");

IMTVTestParamArray private testParamsArr;

function setUp() public override {
super.setUp();

testParamsArr = new IMTVTestParamArray();
testParamsArr.addTestParam(_testParams1);
testParamsArr.addTestParam(_testParams2);
testParamsArr.addTestParam(_testParams3);
}

// Scenario: Calculating returns for a standard investment
function test__RedeemOptimizerTest__RedeemAllShares() public {
uint256 assetToSharesRatio = 2;

// setup
IMultiTokenVault multiTokenVault = _createMultiTokenVault(_asset, assetToSharesRatio, 10);
IRedeemOptimizer redeemOptimizer = new RedeemOptimizerFIFO(multiTokenVault.currentPeriodsElapsed());

(, uint256[] memory depositShares) = _testDeposits(_alice, multiTokenVault); // make a few deposits
uint256[] memory depositShares = _testDepositOnly(_alice, multiTokenVault, testParamsArr.all());
uint256 totalDepositShares = depositShares[0] + depositShares[1] + depositShares[2];

// warp vault ahead redeemPeriod
// warp vault ahead to redeemPeriod
uint256 redeemPeriod = _testParams3.redeemPeriod;
_warpToPeriod(multiTokenVault, redeemPeriod);

// check full redeem
(uint256[] memory redeemDepositPeriods, uint256[] memory sharesAtPeriods) =
redeemOptimizer.optimizeRedeemShares(multiTokenVault, _alice, totalDepositShares, redeemPeriod);

assertEq(3, redeemDepositPeriods.length, "depositPeriods wrong length - full redeem");
assertEq(3, sharesAtPeriods.length, "sharesAtPeriods wrong length - full redeem");
assertEq(testParamsArr.depositPeriods(), redeemDepositPeriods, "optimizeRedeem - depositPeriods not correct");
assertEq(depositShares, sharesAtPeriods, "optimizeRedeem - shares not correct");
}

function test__RedeemOptimizerTest__WithdrawAllShares() public {
uint256 assetToSharesRatio = 2;
uint256 redeemPeriod = _testParams3.redeemPeriod;

// setup
IMultiTokenVault multiTokenVault = _createMultiTokenVault(_asset, assetToSharesRatio, 10);
IRedeemOptimizer redeemOptimizer = new RedeemOptimizerFIFO(multiTokenVault.currentPeriodsElapsed());

uint256[] memory depositShares = _testDepositOnly(_alice, multiTokenVault, testParamsArr.all());
uint256[] memory depositAssets = multiTokenVault.convertToAssetsForDepositPeriods(
depositShares, testParamsArr.depositPeriods(), redeemPeriod
);
assertEq(depositShares.length, depositAssets.length, "mismatch in convertToAssets");
uint256 totalAssets = depositAssets[0] + depositAssets[1] + depositAssets[2];

// warp vault ahead to redeemPeriod
_warpToPeriod(multiTokenVault, redeemPeriod);

// check full withdraw
(uint256[] memory withdrawDepositPeriods, uint256[] memory sharesAtPeriods) =
redeemOptimizer.optimizeWithdrawAssets(multiTokenVault, _alice, totalAssets, redeemPeriod);

assertEq(testParamsArr.depositPeriods(), withdrawDepositPeriods, "optimizeRedeem - depositPeriods not correct");
assertEq(depositShares, sharesAtPeriods, "optimizeRedeem - shares not correct");
}

function test__RedeemOptimizerTest__PartialWithdraw() public {
uint256 assetToSharesRatio = 2;
uint256 redeemPeriod = _testParams3.redeemPeriod;

assertEq(_testParams1.depositPeriod, redeemDepositPeriods[0], "optimizeRedeem - wrong depositPeriod 0");
assertEq(depositShares[0], sharesAtPeriods[0], "optimizeRedeem - wrong shares 0");
// ---------------------- setup ----------------------
IMultiTokenVault multiTokenVault = _createMultiTokenVault(_asset, assetToSharesRatio, 10);
IRedeemOptimizer redeemOptimizer = new RedeemOptimizerFIFO(multiTokenVault.currentPeriodsElapsed());

assertEq(_testParams2.depositPeriod, redeemDepositPeriods[1], "optimizeRedeem - wrong depositPeriod 1");
assertEq(depositShares[1], sharesAtPeriods[1], "optimizeRedeem - wrong shares 1");
uint256[] memory depositShares = _testDepositOnly(_alice, multiTokenVault, testParamsArr.all());
uint256[] memory depositAssets = multiTokenVault.convertToAssetsForDepositPeriods(
depositShares, testParamsArr.depositPeriods(), redeemPeriod
);
assertEq(depositShares.length, depositAssets.length, "mismatch in convertToAssets");
uint256 totalAssets = depositAssets[0] + depositAssets[1] + depositAssets[2];

assertEq(
(_testParams1.principal + _testParams2.principal + _testParams3.principal) / 2,
depositShares[0] + depositShares[1] + depositShares[2],
"shares are wrong"
);

assertEq(_testParams3.depositPeriod, redeemDepositPeriods[2], "optimizeRedeem - wrong depositPeriod 2");
assertEq(depositShares[2], sharesAtPeriods[2], "optimizeRedeem - wrong shares 2");
uint256 oneAssetWithReturns = multiTokenVault.convertToAssetsForDepositPeriod(
1 * _scale / assetToSharesRatio, _testParams3.depositPeriod, redeemPeriod
);
uint256 assetsToWithdraw = totalAssets - oneAssetWithReturns;

// ---------------------- redeem ----------------------
_warpToPeriod(multiTokenVault, redeemPeriod); // warp vault ahead to redeemPeriod

(uint256[] memory actualDepositPeriods, uint256[] memory actualSharesAtPeriods) =
redeemOptimizer.optimizeWithdrawAssets(multiTokenVault, _alice, assetsToWithdraw, redeemPeriod);

// first two periods should be fully withdrawn - third period should be partial
assertEq(depositShares[0], actualSharesAtPeriods[0], "optimizeRedeem - wrong shares 0");
assertEq(depositShares[1], actualSharesAtPeriods[1], "optimizeRedeem - wrong shares 1");
// TODO - add check in for partial, should be full amount - equivalent of one asset

// convert shares to asset equivalent for further validation
uint256[] memory actualAssetsAtPeriods =
multiTokenVault.convertToAssetsForDepositPeriods(actualSharesAtPeriods, actualDepositPeriods, redeemPeriod);
assertEq(
testParamsArr.depositPeriods().length,
actualAssetsAtPeriods.length,
"convertToAssetsForDepositPeriods (partial) - length incorrect"
);
assertEq(
assetsToWithdraw,
actualAssetsAtPeriods[0] + actualAssetsAtPeriods[1] + actualAssetsAtPeriods[2],
"convertToAssetsForDepositPeriods (partial) - total incorrect"
);
}

function test__RedeemOptimizerTest__InsufficientSharesShouldRevert() public {
Expand Down Expand Up @@ -117,24 +193,6 @@ contract RedeemOptimizerTest is MultiTokenVaultTest {
})
);
}

function _testDeposits(address receiver, IMultiTokenVault vault)
internal
returns (uint256[] memory depositPeriods_, uint256[] memory shares_)
{
uint256[] memory depositPeriods = new uint256[](3);
uint256[] memory shares = new uint256[](3);

depositPeriods[0] = _testParams1.depositPeriod;
depositPeriods[1] = _testParams2.depositPeriod;
depositPeriods[2] = _testParams3.depositPeriod;

shares[0] = _testDepositOnly(receiver, vault, _testParams1);
shares[1] = _testDepositOnly(receiver, vault, _testParams2);
shares[2] = _testDepositOnly(receiver, vault, _testParams3);

return (depositPeriods, shares);
}
}

contract RedeemOptimizerFIFOMock is RedeemOptimizerFIFO {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,21 @@ contract IMTVTestParamArray {
return _allTestParams[index];
}

function getAll() public view returns (IMultiTokenVaultTestBase.TestParam[] memory testParamArr) {
function all() public view returns (IMultiTokenVaultTestBase.TestParam[] memory testParamArr) {
return _allTestParams;
}

function getAllDepositPeriods() public view returns (uint256[] memory depositPeriods_) {
uint256[] memory depositPeriods = new uint256[](_allTestParams.length);
function depositPeriods() public view returns (uint256[] memory depositPeriods_) {
uint256[] memory _depositPeriods = new uint256[](_allTestParams.length);

for (uint256 i = 0; i < _allTestParams.length; i++) {
depositPeriods[i] = _allTestParams[i].depositPeriod;
_depositPeriods[i] = _allTestParams[i].depositPeriod;
}

return depositPeriods;
return _depositPeriods;
}

function accountArray(address account, uint256 size) public pure returns (address[] memory accounts_) {
function createAccountArray(address account, uint256 size) public pure returns (address[] memory accounts_) {
address[] memory accounts = new address[](size);

for (uint256 i = 0; i < size; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ abstract contract IMultiTokenVaultTestBase is Test {
function testVaultAtOffsets(address account, IMultiTokenVault vault, TestParam memory testParam) internal {
IMTVTestParamArray testParamsArr = new IMTVTestParamArray();
testParamsArr.initUsingOffsets(testParam);
testVaultAtAllPeriods(account, vault, testParamsArr.getAll());
testVaultAtAllPeriods(account, vault, testParamsArr.all());
}

/// @dev test Vault at specified redeemPeriod and other "interesting" redeem periods
Expand Down

0 comments on commit f7030f5

Please sign in to comment.