diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f5a43b9a..0c823044a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,12 +35,14 @@ Ref: https://keepachangelog.com/en/1.0.0/ # Changelog -## v1.4.x - 2021.10.xx +## [Unreleased] + +## [v1.4.1](https://github.com/tendermint/liquidity/releases) - 2021.10.25 * [\#455](https://github.com/tendermint/liquidity/pull/455) (sdk) Bump SDK version to [v0.44.2](https://github.com/cosmos/cosmos-sdk/releases/tag/v0.44.2) * [\#446](https://github.com/tendermint/liquidity/pull/446) Fix: Pool Coin Decimal Truncation During Deposit +* [\#448](https://github.com/tendermint/liquidity/pull/448) Fix: add overflow checking and test codes for cover edge cases -## [Unreleased] ## [v1.4.0](https://github.com/tendermint/liquidity/releases/tag/v1.4.0) - 2021.09.07 diff --git a/x/liquidity/keeper/batch.go b/x/liquidity/keeper/batch.go index 948382523..4e261ac7d 100644 --- a/x/liquidity/keeper/batch.go +++ b/x/liquidity/keeper/batch.go @@ -236,7 +236,14 @@ func (k Keeper) WithdrawWithinBatch(ctx sdk.Context, msg *types.MsgWithdrawWithi // In order to deal with the batch at the same time, the coins of msgs are deposited in escrow. func (k Keeper) SwapWithinBatch(ctx sdk.Context, msg *types.MsgSwapWithinBatch, orderExpirySpanHeight int64) (*types.SwapMsgState, error) { - if err := k.ValidateMsgSwapWithinBatch(ctx, *msg); err != nil { + pool, found := k.GetPool(ctx, msg.PoolId) + if !found { + return nil, types.ErrPoolNotExists + } + if k.IsDepletedPool(ctx, pool) { + return nil, types.ErrDepletedPool + } + if err := k.ValidateMsgSwapWithinBatch(ctx, *msg, pool); err != nil { return nil, err } poolBatch, found := k.GetPoolBatch(ctx, msg.PoolId) diff --git a/x/liquidity/keeper/liquidity_pool.go b/x/liquidity/keeper/liquidity_pool.go index 8391079ed..961fe4c14 100644 --- a/x/liquidity/keeper/liquidity_pool.go +++ b/x/liquidity/keeper/liquidity_pool.go @@ -245,12 +245,18 @@ func (k Keeper) ExecuteDeposit(ctx sdk.Context, msg types.DepositMsgState, batch depositCoinA := depositCoins[0] depositCoinB := depositCoins[1] - poolCoinTotalSupply := k.GetPoolCoinTotalSupply(ctx, pool) + poolCoinTotalSupply := k.GetPoolCoinTotalSupply(ctx, pool).ToDec() + if err := types.CheckOverflowWithDec(poolCoinTotalSupply, depositCoinA.Amount.ToDec()); err != nil { + return err + } + if err := types.CheckOverflowWithDec(poolCoinTotalSupply, depositCoinB.Amount.ToDec()); err != nil { + return err + } poolCoinMintAmt := sdk.MinDec( - poolCoinTotalSupply.ToDec().MulTruncate(depositCoinA.Amount.ToDec()).QuoTruncate(lastReserveCoinA.Amount.ToDec()), - poolCoinTotalSupply.ToDec().MulTruncate(depositCoinB.Amount.ToDec()).QuoTruncate(lastReserveCoinB.Amount.ToDec()), + poolCoinTotalSupply.MulTruncate(depositCoinA.Amount.ToDec()).QuoTruncate(lastReserveCoinA.Amount.ToDec()), + poolCoinTotalSupply.MulTruncate(depositCoinB.Amount.ToDec()).QuoTruncate(lastReserveCoinB.Amount.ToDec()), ) - mintRate := poolCoinMintAmt.TruncateDec().QuoTruncate(poolCoinTotalSupply.ToDec()) + mintRate := poolCoinMintAmt.TruncateDec().QuoTruncate(poolCoinTotalSupply) acceptedCoins := sdk.NewCoins( sdk.NewCoin(depositCoins[0].Denom, lastReserveCoinA.Amount.ToDec().Mul(mintRate).TruncateInt()), sdk.NewCoin(depositCoins[1].Denom, lastReserveCoinB.Amount.ToDec().Mul(mintRate).TruncateInt()), @@ -301,7 +307,7 @@ func (k Keeper) ExecuteDeposit(ctx sdk.Context, msg types.DepositMsgState, batch afterReserveCoinA := afterReserveCoins[0].Amount afterReserveCoinB := afterReserveCoins[1].Amount - MintingPoolCoinsInvariant(poolCoinTotalSupply, mintPoolCoin.Amount, depositCoinA.Amount, depositCoinB.Amount, + MintingPoolCoinsInvariant(poolCoinTotalSupply.TruncateInt(), mintPoolCoin.Amount, depositCoinA.Amount, depositCoinB.Amount, lastReserveCoinA.Amount, lastReserveCoinB.Amount, refundedCoinA.Amount, refundedCoinB.Amount) DepositInvariant(lastReserveCoinA.Amount, lastReserveCoinB.Amount, depositCoinA.Amount, depositCoinB.Amount, afterReserveCoinA, afterReserveCoinB, refundedCoinA.Amount, refundedCoinB.Amount) @@ -378,6 +384,12 @@ func (k Keeper) ExecuteWithdrawal(ctx sdk.Context, msg types.WithdrawMsgState, b } else { // Calculate withdraw amount of respective reserve coin considering fees and pool coin's totally supply for _, reserveCoin := range reserveCoins { + if err := types.CheckOverflow(reserveCoin.Amount, msg.Msg.PoolCoin.Amount); err != nil { + return err + } + if err := types.CheckOverflow(reserveCoin.Amount.Mul(msg.Msg.PoolCoin.Amount).ToDec().TruncateInt(), poolCoinTotalSupply); err != nil { + return err + } // WithdrawAmount = ReserveAmount * PoolCoinAmount * WithdrawFeeProportion / TotalSupply withdrawAmtWithFee := reserveCoin.Amount.Mul(msg.Msg.PoolCoin.Amount).ToDec().TruncateInt().Quo(poolCoinTotalSupply) withdrawAmt := reserveCoin.Amount.Mul(msg.Msg.PoolCoin.Amount).ToDec().MulTruncate(withdrawProportion).TruncateInt().Quo(poolCoinTotalSupply) @@ -770,21 +782,12 @@ func (k Keeper) ValidateMsgWithdrawWithinBatch(ctx sdk.Context, msg types.MsgWit } // ValidateMsgSwapWithinBatch validates MsgSwapWithinBatch. -func (k Keeper) ValidateMsgSwapWithinBatch(ctx sdk.Context, msg types.MsgSwapWithinBatch) error { - pool, found := k.GetPool(ctx, msg.PoolId) - if !found { - return types.ErrPoolNotExists - } - +func (k Keeper) ValidateMsgSwapWithinBatch(ctx sdk.Context, msg types.MsgSwapWithinBatch, pool types.Pool) error { denomA, denomB := types.AlphabeticalDenomPair(msg.OfferCoin.Denom, msg.DemandCoinDenom) if denomA != pool.ReserveCoinDenoms[0] || denomB != pool.ReserveCoinDenoms[1] { return types.ErrNotMatchedReserveCoin } - if k.IsDepletedPool(ctx, pool) { - return types.ErrDepletedPool - } - params := k.GetParams(ctx) // can not exceed max order ratio of reserve coins that can be ordered at a order @@ -800,6 +803,10 @@ func (k Keeper) ValidateMsgSwapWithinBatch(ctx sdk.Context, msg types.MsgSwapWit return types.ErrBadOfferCoinFee } + if err := types.CheckOverflowWithDec(msg.OfferCoin.Amount.ToDec(), msg.OrderPrice); err != nil { + return err + } + if !msg.OfferCoinFee.Equal(types.GetOfferCoinFee(msg.OfferCoin, params.SwapFeeRate)) { return types.ErrBadOfferCoinFee } diff --git a/x/liquidity/keeper/liquidity_pool_test.go b/x/liquidity/keeper/liquidity_pool_test.go index 84b34c529..bce4803d1 100644 --- a/x/liquidity/keeper/liquidity_pool_test.go +++ b/x/liquidity/keeper/liquidity_pool_test.go @@ -11,6 +11,7 @@ import ( "github.com/tendermint/liquidity/app" "github.com/tendermint/liquidity/x/liquidity" + "github.com/tendermint/liquidity/x/liquidity/keeper" "github.com/tendermint/liquidity/x/liquidity/types" ) @@ -74,6 +75,28 @@ func TestCreatePool(t *testing.T) { require.ErrorIs(t, err, types.ErrPoolAlreadyExists) } +func TestCreatePoolInsufficientAmount(t *testing.T) { + simapp, ctx := createTestInput() + params := simapp.LiquidityKeeper.GetParams(ctx) + + depositCoins := sdk.NewCoins(sdk.NewInt64Coin(DenomX, 1000), sdk.NewInt64Coin(DenomY, 1000)) + creator := app.AddRandomTestAddr(simapp, ctx, depositCoins.Add(params.PoolCreationFee...)) + + // Depositing coins that are less than params.MinInitDepositAmount. + _, err := simapp.LiquidityKeeper.CreatePool(ctx, types.NewMsgCreatePool(creator, types.DefaultPoolTypeID, depositCoins)) + require.ErrorIs(t, err, types.ErrLessThanMinInitDeposit) + + fakeDepositCoins := depositCoins.Add( + sdk.NewCoin(DenomX, params.MinInitDepositAmount), + sdk.NewCoin(DenomY, params.MinInitDepositAmount), + ) + // Depositing coins that are greater than the depositor has. + _, err = simapp.LiquidityKeeper.CreatePool( + ctx, types.NewMsgCreatePool(creator, types.DefaultPoolTypeID, fakeDepositCoins), + ) + require.ErrorIs(t, err, types.ErrInsufficientBalance) +} + func TestPoolCreationFee(t *testing.T) { simapp, ctx := createTestInput() simapp.LiquidityKeeper.SetParams(ctx, types.DefaultParams()) @@ -483,6 +506,73 @@ func TestExecuteWithdrawal(t *testing.T) { require.Equal(t, deposit.AmountOf(pool.ReserveCoinDenoms[1]), withdrawerDenomBBalance.Amount) } +func TestSmallWithdrawalCase(t *testing.T) { + simapp, ctx := createTestInput() + params := types.DefaultParams() + params.InitPoolCoinMintAmount = sdk.NewInt(1_000000_000000) + simapp.LiquidityKeeper.SetParams(ctx, params) + + poolTypeID := types.DefaultPoolTypeID + addrs := app.AddTestAddrs(simapp, ctx, 3, params.PoolCreationFee) + + denomA := "uETH" + denomB := "uUSD" + denomA, denomB = types.AlphabeticalDenomPair(denomA, denomB) + + deposit := sdk.NewCoins(sdk.NewCoin(denomA, sdk.NewInt(1250001*1000000)), sdk.NewCoin(denomB, sdk.NewInt(9*1000000))) + app.SaveAccount(simapp, ctx, addrs[0], deposit) + + depositA := simapp.BankKeeper.GetBalance(ctx, addrs[0], denomA) + depositB := simapp.BankKeeper.GetBalance(ctx, addrs[0], denomB) + depositBalance := sdk.NewCoins(depositA, depositB) + + require.Equal(t, deposit, depositBalance) + + createMsg := types.NewMsgCreatePool(addrs[0], poolTypeID, depositBalance) + + _, err := simapp.LiquidityKeeper.CreatePool(ctx, createMsg) + require.NoError(t, err) + + pools := simapp.LiquidityKeeper.GetAllPools(ctx) + pool := pools[0] + + // Case for normal withdrawing + poolCoinBefore := simapp.LiquidityKeeper.GetPoolCoinTotalSupply(ctx, pool) + withdrawerPoolCoinBefore := simapp.BankKeeper.GetBalance(ctx, addrs[0], pool.PoolCoinDenom) + + withdrawerDenomABalanceBefore := simapp.BankKeeper.GetBalance(ctx, addrs[0], pool.ReserveCoinDenoms[0]) + withdrawerDenomBBalanceBefore := simapp.BankKeeper.GetBalance(ctx, addrs[0], pool.ReserveCoinDenoms[1]) + + require.Equal(t, poolCoinBefore, withdrawerPoolCoinBefore.Amount) + withdrawMsg := types.NewMsgWithdrawWithinBatch(addrs[0], pool.Id, sdk.NewCoin(pool.PoolCoinDenom, sdk.NewInt(1))) + + _, err = simapp.LiquidityKeeper.WithdrawWithinBatch(ctx, withdrawMsg) + require.NoError(t, err) + + poolBatch, found := simapp.LiquidityKeeper.GetPoolBatch(ctx, withdrawMsg.PoolId) + require.True(t, found) + msgs := simapp.LiquidityKeeper.GetAllPoolBatchWithdrawMsgStates(ctx, poolBatch) + require.Equal(t, 1, len(msgs)) + + liquidity.EndBlocker(ctx, simapp.LiquidityKeeper) + liquidity.BeginBlocker(ctx, simapp.LiquidityKeeper) + + poolCoinAfter := simapp.LiquidityKeeper.GetPoolCoinTotalSupply(ctx, pool) + withdrawerPoolCoinAfter := simapp.BankKeeper.GetBalance(ctx, addrs[0], pool.PoolCoinDenom) + + require.Equal(t, poolCoinAfter, poolCoinBefore) + require.Equal(t, withdrawerPoolCoinAfter.Amount, withdrawerPoolCoinBefore.Amount) + withdrawerDenomABalance := simapp.BankKeeper.GetBalance(ctx, addrs[0], pool.ReserveCoinDenoms[0]) + withdrawerDenomBBalance := simapp.BankKeeper.GetBalance(ctx, addrs[0], pool.ReserveCoinDenoms[1]) + + reservePoolBalanceA := simapp.BankKeeper.GetBalance(ctx, pool.GetReserveAccount(), pool.ReserveCoinDenoms[0]) + reservePoolBalanceB := simapp.BankKeeper.GetBalance(ctx, pool.GetReserveAccount(), pool.ReserveCoinDenoms[1]) + require.Equal(t, deposit.AmountOf(pool.ReserveCoinDenoms[0]), reservePoolBalanceA.Amount) + require.Equal(t, deposit.AmountOf(pool.ReserveCoinDenoms[1]), reservePoolBalanceB.Amount) + require.Equal(t, withdrawerDenomABalanceBefore, withdrawerDenomABalance) + require.Equal(t, withdrawerDenomBBalanceBefore, withdrawerDenomBBalance) +} + func TestReinitializePool(t *testing.T) { simapp, ctx := createTestInput() simapp.LiquidityKeeper.SetParams(ctx, types.DefaultParams()) @@ -1016,3 +1106,218 @@ func TestDepositWithCoinsSent(t *testing.T) { require.True(sdk.IntEq(t, sdk.NewInt(0), balances.AmountOf(DenomY))) require.True(sdk.IntEq(t, sdk.NewInt(1000000), balances.AmountOf(pool.PoolCoinDenom))) } + +func TestCreatePoolEqualDenom(t *testing.T) { + simapp, ctx := createTestInput() + params := types.DefaultParams() + simapp.LiquidityKeeper.SetParams(ctx, params) + addrs := app.AddTestAddrs(simapp, ctx, 1, params.PoolCreationFee) + + msg := types.NewMsgCreatePool(addrs[0], types.DefaultPoolTypeID, + sdk.Coins{ + sdk.NewCoin(DenomA, sdk.NewInt(1000000)), + sdk.NewCoin(DenomA, sdk.NewInt(1000000))}) + _, err := simapp.LiquidityKeeper.CreatePool(ctx, msg) + require.ErrorIs(t, err, types.ErrEqualDenom) +} + +func TestOverflowAndZeroCases(t *testing.T) { + simapp, ctx := createTestInput() + params := types.DefaultParams() + simapp.LiquidityKeeper.SetParams(ctx, params) + keeper.BatchLogicInvariantCheckFlag = false + + poolTypeID := types.DefaultPoolTypeID + addrs := app.AddTestAddrs(simapp, ctx, 3, params.PoolCreationFee) + + denomA := "uETH" + denomB := "uUSD" + denomA, denomB = types.AlphabeticalDenomPair(denomA, denomB) + + // Check overflow case on deposit + deposit := sdk.NewCoins( + sdk.NewCoin(denomA, sdk.NewInt(1_000_000)), + sdk.NewCoin(denomB, sdk.NewInt(2_000_000_000_000*1_000_000).MulRaw(1_000_000))) + hugeCoins := sdk.NewCoins( + sdk.NewCoin(denomA, sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000)), + sdk.NewCoin(denomB, sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000))) + hugeCoins2 := sdk.NewCoins( + sdk.NewCoin(denomA, sdk.NewInt(1_000_000_000_000_000_000)), + sdk.NewCoin(denomB, sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000))) + app.SaveAccount(simapp, ctx, addrs[0], deposit.Add(hugeCoins.Add(hugeCoins2...)...)) + + msg := types.NewMsgCreatePool(addrs[0], poolTypeID, deposit) + _, err := simapp.LiquidityKeeper.CreatePool(ctx, msg) + require.NoError(t, err) + pools := simapp.LiquidityKeeper.GetAllPools(ctx) + poolCoin := simapp.LiquidityKeeper.GetPoolCoinTotalSupply(ctx, pools[0]) + + depositorBalance := simapp.BankKeeper.GetAllBalances(ctx, addrs[0]) + depositMsg := types.NewMsgDepositWithinBatch(addrs[0], pools[0].Id, hugeCoins) + depositMsg2 := types.NewMsgDepositWithinBatch(addrs[0], pools[0].Id, hugeCoins2) + _, err = simapp.LiquidityKeeper.DepositWithinBatch(ctx, depositMsg) + _, err = simapp.LiquidityKeeper.DepositWithinBatch(ctx, depositMsg2) + require.NoError(t, err) + + poolBatch, found := simapp.LiquidityKeeper.GetPoolBatch(ctx, depositMsg.PoolId) + require.True(t, found) + msgs := simapp.LiquidityKeeper.GetAllPoolBatchDepositMsgs(ctx, poolBatch) + require.Equal(t, 2, len(msgs)) + err = simapp.LiquidityKeeper.ExecuteDeposit(ctx, msgs[0], poolBatch) + require.ErrorIs(t, err, types.ErrOverflowAmount) + err = simapp.LiquidityKeeper.RefundDeposit(ctx, msgs[0], poolBatch) + require.NoError(t, err) + err = simapp.LiquidityKeeper.ExecuteDeposit(ctx, msgs[1], poolBatch) + require.ErrorIs(t, err, types.ErrOverflowAmount) + err = simapp.LiquidityKeeper.RefundDeposit(ctx, msgs[1], poolBatch) + require.NoError(t, err) + + poolCoinAfter := simapp.LiquidityKeeper.GetPoolCoinTotalSupply(ctx, pools[0]) + depositorPoolCoinBalance := simapp.BankKeeper.GetBalance(ctx, addrs[0], pools[0].PoolCoinDenom) + require.Equal(t, poolCoin, poolCoinAfter) + require.Equal(t, poolCoinAfter, depositorPoolCoinBalance.Amount) + require.Equal(t, depositorBalance.AmountOf(pools[0].PoolCoinDenom), depositorPoolCoinBalance.Amount) + + hugeCoins3 := sdk.NewCoins( + sdk.NewCoin(denomA, sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000)), + sdk.NewCoin(denomB, sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000))) + depositMsg = types.NewMsgDepositWithinBatch(addrs[0], pools[0].Id, hugeCoins3) + _, err = simapp.LiquidityKeeper.DepositWithinBatch(ctx, depositMsg) + require.NoError(t, err) + msgs = simapp.LiquidityKeeper.GetAllPoolBatchDepositMsgs(ctx, poolBatch) + require.Equal(t, 3, len(msgs)) + err = simapp.LiquidityKeeper.ExecuteDeposit(ctx, msgs[2], poolBatch) + require.NoError(t, err) + + // Check overflow case on withdraw + depositorPoolCoinBalance = simapp.BankKeeper.GetBalance(ctx, addrs[0], pools[0].PoolCoinDenom) + _, err = simapp.LiquidityKeeper.WithdrawWithinBatch(ctx, types.NewMsgWithdrawWithinBatch(addrs[0], pools[0].Id, depositorPoolCoinBalance.SubAmount(sdk.NewInt(1)))) + require.NoError(t, err) + + poolBatch, found = simapp.LiquidityKeeper.GetPoolBatch(ctx, depositMsg.PoolId) + require.True(t, found) + withdrawMsgs := simapp.LiquidityKeeper.GetAllPoolBatchWithdrawMsgStates(ctx, poolBatch) + require.Equal(t, 1, len(withdrawMsgs)) + err = simapp.LiquidityKeeper.ExecuteWithdrawal(ctx, withdrawMsgs[0], poolBatch) + require.ErrorIs(t, err, types.ErrOverflowAmount) + err = simapp.LiquidityKeeper.RefundWithdrawal(ctx, withdrawMsgs[0], poolBatch) + require.NoError(t, err) + + // Check overflow, division by zero case on swap + swapUserBalanceBefore := simapp.BankKeeper.GetAllBalances(ctx, addrs[0]) + offerCoinA := sdk.NewCoin(denomA, sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000_000)) + orderPriceA := sdk.MustNewDecFromStr("110000000000000000000000000000000000000000000000000000000000.000000000000000001") + offerCoinB := sdk.NewCoin(denomB, sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000)) + orderPriceB := sdk.MustNewDecFromStr("0.000000000000000001") + liquidity.BeginBlocker(ctx, simapp.LiquidityKeeper) + _, err = simapp.LiquidityKeeper.SwapWithinBatch( + ctx, + types.NewMsgSwapWithinBatch(addrs[0], pools[0].Id, types.DefaultSwapTypeID, offerCoinA, denomB, orderPriceA, params.SwapFeeRate), + 0) + require.ErrorIs(t, err, types.ErrOverflowAmount) + _, err = simapp.LiquidityKeeper.SwapWithinBatch( + ctx, + types.NewMsgSwapWithinBatch(addrs[0], pools[0].Id, types.DefaultSwapTypeID, offerCoinB, denomA, orderPriceB, params.SwapFeeRate), + 0) + require.NoError(t, err) + liquidity.EndBlocker(ctx, simapp.LiquidityKeeper) + liquidity.BeginBlocker(ctx, simapp.LiquidityKeeper) + swapUserBalanceAfter := simapp.BankKeeper.GetAllBalances(ctx, addrs[0]) + require.Equal(t, swapUserBalanceBefore, swapUserBalanceAfter) + depositMsgs := simapp.LiquidityKeeper.GetAllPoolBatchDepositMsgs(ctx, poolBatch) + require.Equal(t, 0, len(depositMsgs)) + withdrawMsgs = simapp.LiquidityKeeper.GetAllPoolBatchWithdrawMsgStates(ctx, poolBatch) + require.Equal(t, 0, len(withdrawMsgs)) + swapMsgs := simapp.LiquidityKeeper.GetAllPoolBatchSwapMsgStates(ctx, poolBatch) + require.Equal(t, 0, len(swapMsgs)) +} + +func TestExecuteBigDeposit(t *testing.T) { + simapp, ctx := createTestInput() + simapp.LiquidityKeeper.SetParams(ctx, types.DefaultParams()) + params := simapp.LiquidityKeeper.GetParams(ctx) + keeper.BatchLogicInvariantCheckFlag = false + + poolTypeID := types.DefaultPoolTypeID + addrs := app.AddTestAddrs(simapp, ctx, 3, params.PoolCreationFee) + + denomA := "uETH" + denomB := "uUSD" + denomA, denomB = types.AlphabeticalDenomPair(denomA, denomB) + + // 2^63-1 + hugeInt := int64(9223372036854775807) + initDeposit := sdk.NewCoins(sdk.NewCoin(denomA, sdk.NewInt(hugeInt)), sdk.NewCoin(denomB, sdk.NewInt(hugeInt))) + app.SaveAccount(simapp, ctx, addrs[0], initDeposit) + app.SaveAccount(simapp, ctx, addrs[1], initDeposit) + app.SaveAccount(simapp, ctx, addrs[2], initDeposit) + + createBalance := sdk.NewCoins(sdk.NewCoin(denomA, sdk.NewInt(1*1000000)), sdk.NewCoin(denomB, sdk.NewInt(1*1000000))) + + createMsg := types.NewMsgCreatePool(addrs[0], poolTypeID, createBalance) + + _, err := simapp.LiquidityKeeper.CreatePool(ctx, createMsg) + require.NoError(t, err) + + pools := simapp.LiquidityKeeper.GetAllPools(ctx) + pool := pools[0] + + poolCoinInit := simapp.LiquidityKeeper.GetPoolCoinTotalSupply(ctx, pool) + require.Equal(t, poolCoinInit, sdk.NewInt(1*1000000)) + + depositMsg := types.NewMsgDepositWithinBatch(addrs[1], pool.Id, initDeposit) + _, err = simapp.LiquidityKeeper.DepositWithinBatch(ctx, depositMsg) + require.NoError(t, err) + + poolBatch, found := simapp.LiquidityKeeper.GetPoolBatch(ctx, depositMsg.PoolId) + require.True(t, found) + msgs := simapp.LiquidityKeeper.GetAllPoolBatchDepositMsgs(ctx, poolBatch) + require.Equal(t, 1, len(msgs)) + + err = simapp.LiquidityKeeper.ExecuteDeposit(ctx, msgs[0], poolBatch) + require.NoError(t, err) + + poolCoin := simapp.LiquidityKeeper.GetPoolCoinTotalSupply(ctx, pool) + require.Equal(t, poolCoin.Sub(poolCoinInit), simapp.BankKeeper.GetBalance(ctx, addrs[1], pool.PoolCoinDenom).Amount) + + simapp.LiquidityKeeper.DeleteAllReadyPoolBatchDepositMsgStates(ctx, poolBatch) + + depositMsg = types.NewMsgDepositWithinBatch(addrs[2], pool.Id, initDeposit) + _, err = simapp.LiquidityKeeper.DepositWithinBatch(ctx, depositMsg) + require.NoError(t, err) + + poolBatch, found = simapp.LiquidityKeeper.GetPoolBatch(ctx, depositMsg.PoolId) + require.True(t, found) + msgs = simapp.LiquidityKeeper.GetAllPoolBatchDepositMsgs(ctx, poolBatch) + require.Equal(t, 1, len(msgs)) + + err = simapp.LiquidityKeeper.ExecuteDeposit(ctx, msgs[0], poolBatch) + require.NoError(t, err) + + poolCoinAfter := simapp.LiquidityKeeper.GetPoolCoinTotalSupply(ctx, pool) + require.Equal(t, poolCoinAfter.Sub(poolCoin), simapp.BankKeeper.GetBalance(ctx, addrs[2], pool.PoolCoinDenom).Amount) + require.Equal(t, simapp.BankKeeper.GetBalance(ctx, addrs[1], pool.PoolCoinDenom).Amount, simapp.BankKeeper.GetBalance(ctx, addrs[2], pool.PoolCoinDenom).Amount) + + require.True(t, simapp.BankKeeper.GetBalance(ctx, addrs[1], denomA).IsZero()) + require.True(t, simapp.BankKeeper.GetBalance(ctx, addrs[1], denomB).IsZero()) + + // Error due to decimal operation exceeding precision + require.Equal(t, sdk.NewInt(8), simapp.BankKeeper.GetBalance(ctx, addrs[2], denomA).Amount) + require.Equal(t, sdk.NewInt(8), simapp.BankKeeper.GetBalance(ctx, addrs[2], denomB).Amount) + + poolCoinAmt := simapp.BankKeeper.GetBalance(ctx, addrs[1], pool.PoolCoinDenom) + state, err := simapp.LiquidityKeeper.WithdrawWithinBatch(ctx, types.NewMsgWithdrawWithinBatch(addrs[1], pool.Id, poolCoinAmt)) + require.NoError(t, err) + + err = simapp.LiquidityKeeper.ExecuteWithdrawal(ctx, state, poolBatch) + require.NoError(t, err) + + balanceAfter := simapp.BankKeeper.GetAllBalances(ctx, addrs[1]) + liquidity.EndBlocker(ctx, simapp.LiquidityKeeper) + liquidity.BeginBlocker(ctx, simapp.LiquidityKeeper) + + // Error due to decimal operation exceeding precision + require.Equal(t, sdk.ZeroInt(), balanceAfter.AmountOf(pool.PoolCoinDenom)) + require.Equal(t, sdk.NewInt(-4), balanceAfter.AmountOf(denomA).SubRaw(hugeInt)) + require.Equal(t, sdk.NewInt(-4), balanceAfter.AmountOf(denomB).SubRaw(hugeInt)) +} diff --git a/x/liquidity/keeper/swap.go b/x/liquidity/keeper/swap.go index 0cc558669..220b1492b 100644 --- a/x/liquidity/keeper/swap.go +++ b/x/liquidity/keeper/swap.go @@ -22,6 +22,10 @@ func (k Keeper) SwapExecution(ctx sdk.Context, poolBatch types.PoolBatch) (uint6 return 0, types.ErrPoolNotExists } + if k.IsDepletedPool(ctx, pool) { + return 0, types.ErrDepletedPool + } + currentHeight := ctx.BlockHeight() // set executed states of all messages to true executedMsgCount := uint64(0) @@ -32,7 +36,7 @@ func (k Keeper) SwapExecution(ctx sdk.Context, poolBatch types.PoolBatch) (uint6 if currentHeight > sms.OrderExpiryHeight { sms.ToBeDeleted = true } - if err := k.ValidateMsgSwapWithinBatch(ctx, *sms.Msg); err != nil { + if err := k.ValidateMsgSwapWithinBatch(ctx, *sms.Msg, pool); err != nil { sms.ToBeDeleted = true } if !sms.ToBeDeleted { @@ -79,7 +83,7 @@ func (k Keeper) SwapExecution(ctx sdk.Context, poolBatch types.PoolBatch) (uint6 // check orderbook validity and compute batchResult(direction, swapPrice, ..) result, found := orderBook.Match(X, Y) - if !found { + if !found || X.Quo(Y).IsZero() { err := k.RefundSwaps(ctx, pool, swapMsgStates) return executedMsgCount, err } diff --git a/x/liquidity/types/errors.go b/x/liquidity/types/errors.go index a70fb0764..82f73dc4d 100644 --- a/x/liquidity/types/errors.go +++ b/x/liquidity/types/errors.go @@ -46,4 +46,5 @@ var ( ErrExceededReserveCoinLimit = sdkerrors.Register(ModuleName, 38, "can not exceed reserve coin limit amount") ErrDepletedPool = sdkerrors.Register(ModuleName, 39, "the pool is depleted of reserve coin, reinitializing is required by deposit") ErrCircuitBreakerEnabled = sdkerrors.Register(ModuleName, 40, "circuit breaker is triggered") + ErrOverflowAmount = sdkerrors.Register(ModuleName, 41, "invalid amount that can cause overflow") ) diff --git a/x/liquidity/types/swap.go b/x/liquidity/types/swap.go index 07d6c2b4c..ccd46320b 100644 --- a/x/liquidity/types/swap.go +++ b/x/liquidity/types/swap.go @@ -156,6 +156,9 @@ func (orderBook OrderBook) Match(x, y sdk.Dec) (BatchResult, bool) { // Check orderbook validity naively func (orderBook OrderBook) Validate(currentPrice sdk.Dec) bool { + if !currentPrice.IsPositive() { + return false + } maxBuyOrderPrice := sdk.ZeroDec() minSellOrderPrice := sdk.NewDec(1000000000000) for _, order := range orderBook { @@ -324,7 +327,6 @@ func (orderBook OrderBook) PriceDirection(currentPrice sdk.Dec) PriceDirection { sellAmtUnderCurrentPrice = sellAmtUnderCurrentPrice.Add(order.SellOfferAmt.ToDec()) } } - if buyAmtOverCurrentPrice.GT(currentPrice.Mul(sellAmtUnderCurrentPrice.Add(sellAmtAtCurrentPrice))) { return Increasing } else if currentPrice.Mul(sellAmtUnderCurrentPrice).GT(buyAmtOverCurrentPrice.Add(buyAmtAtCurrentPrice)) { diff --git a/x/liquidity/types/swap_test.go b/x/liquidity/types/swap_test.go index b74c546b6..dbc3f16b1 100644 --- a/x/liquidity/types/swap_test.go +++ b/x/liquidity/types/swap_test.go @@ -550,44 +550,57 @@ func TestMakeOrderMapEdgeCase(t *testing.T) { } func TestOrderbookValidate(t *testing.T) { - currentPrice := sdk.MustNewDecFromStr("1.0") for _, testCase := range []struct { - buyPrice string - sellPrice string - valid bool + currentPrice string + buyPrice string + sellPrice string + valid bool }{ { - buyPrice: "0.99", - sellPrice: "1.01", - valid: true, + currentPrice: "1.0", + buyPrice: "0.99", + sellPrice: "1.01", + valid: true, }, { // maxBuyOrderPrice > minSellOrderPrice - buyPrice: "1.01", - sellPrice: "0.99", - valid: false, + currentPrice: "1.0", + buyPrice: "1.01", + sellPrice: "0.99", + valid: false, }, { - buyPrice: "1.1", - sellPrice: "1.2", - valid: true, + currentPrice: "1.0", + buyPrice: "1.1", + sellPrice: "1.2", + valid: true, }, { // maxBuyOrderPrice/currentPrice > 1.10 - buyPrice: "1.11", - sellPrice: "1.2", - valid: false, + currentPrice: "1.0", + buyPrice: "1.11", + sellPrice: "1.2", + valid: false, }, { - buyPrice: "0.8", - sellPrice: "0.9", - valid: true, + currentPrice: "1.0", + buyPrice: "0.8", + sellPrice: "0.9", + valid: true, }, { // minSellOrderPrice/currentPrice < 0.90 - buyPrice: "0.8", - sellPrice: "0.89", - valid: false, + currentPrice: "1.0", + buyPrice: "0.8", + sellPrice: "0.89", + valid: false, + }, + { + // not positive price + currentPrice: "0.0", + buyPrice: "0.00000000001", + sellPrice: "0.000000000011", + valid: false, }, } { buyPrice := sdk.MustNewDecFromStr(testCase.buyPrice) @@ -605,7 +618,7 @@ func TestOrderbookValidate(t *testing.T) { }, } orderBook := orderMap.SortOrderBook() - require.Equal(t, testCase.valid, orderBook.Validate(currentPrice)) + require.Equal(t, testCase.valid, orderBook.Validate(sdk.MustNewDecFromStr(testCase.currentPrice))) } } diff --git a/x/liquidity/types/utils.go b/x/liquidity/types/utils.go index 510d20b5e..d8e1fde0c 100644 --- a/x/liquidity/types/utils.go +++ b/x/liquidity/types/utils.go @@ -98,3 +98,27 @@ func MustParseCoinsNormalized(coinStr string) sdk.Coins { } return coins } + +func CheckOverflow(a, b sdk.Int) (err error) { + defer func() { + if r := recover(); r != nil { + err = ErrOverflowAmount + } + }() + a.Mul(b) + a.Quo(b) + b.Quo(a) + return nil +} + +func CheckOverflowWithDec(a, b sdk.Dec) (err error) { + defer func() { + if r := recover(); r != nil { + err = ErrOverflowAmount + } + }() + a.Mul(b) + a.Quo(b) + b.Quo(a) + return nil +} diff --git a/x/liquidity/types/utils_test.go b/x/liquidity/types/utils_test.go index a2e8bf21f..639aab338 100644 --- a/x/liquidity/types/utils_test.go +++ b/x/liquidity/types/utils_test.go @@ -280,3 +280,61 @@ func TestGetOfferCoinFee(t *testing.T) { }) } } + +func TestCheckOverflow(t *testing.T) { + testCases := []struct { + name string + a sdk.Int + b sdk.Int + expectErr error + }{ + { + name: "valid case", + a: sdk.NewInt(10000), + b: sdk.NewInt(100), + expectErr: nil, + }, + { + name: "overflow case", + a: sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000), + b: sdk.NewInt(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000).MulRaw(1_000_000_000_000_000_000), + expectErr: types.ErrOverflowAmount, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := types.CheckOverflow(tc.a, tc.b) + require.ErrorIs(t, err, tc.expectErr) + }) + } +} + +func TestCheckOverflowWithDec(t *testing.T) { + testCases := []struct { + name string + a sdk.Dec + b sdk.Dec + expectErr error + }{ + { + name: "valid case", + a: sdk.MustNewDecFromStr("1.0"), + b: sdk.MustNewDecFromStr("0.0000001"), + expectErr: nil, + }, + { + name: "overflow case", + a: sdk.MustNewDecFromStr("100000000000000000000000000000000000000000000000000000000000.0").MulInt64(10), + b: sdk.MustNewDecFromStr("0.000000000000000001"), + expectErr: types.ErrOverflowAmount, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := types.CheckOverflowWithDec(tc.a, tc.b) + require.ErrorIs(t, err, tc.expectErr) + }) + } +}