Skip to content

Commit

Permalink
fix: handle negative order amount returned from RangedPool
Browse files Browse the repository at this point in the history
also add more tests
  • Loading branch information
hallazzang committed Jun 28, 2022
1 parent 29349c0 commit 24a9408
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 84 deletions.
2 changes: 1 addition & 1 deletion x/liquidity/amm/match_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func TestFindMatchableAmountAtSinglePrice(t *testing.T) {
}
}

func TestMatch_EdgeCase1(t *testing.T) {
func TestMatch_edgecase1(t *testing.T) {
orders := []amm.Order{
newOrder(amm.Sell, utils.ParseDec("0.100"), sdk.NewInt(10000)),
newOrder(amm.Sell, utils.ParseDec("0.099"), sdk.NewInt(9995)),
Expand Down
62 changes: 42 additions & 20 deletions x/liquidity/amm/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,14 @@ func (pool *BasicPool) LowestSellPrice() (price sdk.Dec, found bool) {

func (pool *BasicPool) BuyAmountOver(price sdk.Dec, _ bool) (amt sdk.Int) {
if price.GTE(pool.Price()) {
return sdk.ZeroInt()
return zeroInt
}
if price.LT(MinPoolPrice) {
price = MinPoolPrice
}
dx := pool.rx.ToDec().Sub(price.MulInt(pool.ry)).TruncateInt()
if dx.IsZero() {
return sdk.ZeroInt()
if !dx.IsPositive() {
return zeroInt
}
utils.SafeMath(func() {
amt = pool.rx.ToDec().QuoTruncate(price).Sub(pool.ry.ToDec()).TruncateInt()
Expand All @@ -130,19 +130,23 @@ func (pool *BasicPool) BuyAmountOver(price sdk.Dec, _ bool) (amt sdk.Int) {
return
}

func (pool *BasicPool) SellAmountUnder(price sdk.Dec, _ bool) sdk.Int {
func (pool *BasicPool) SellAmountUnder(price sdk.Dec, _ bool) (amt sdk.Int) {
if price.LTE(pool.Price()) {
return sdk.ZeroInt()
return zeroInt
}
if price.GT(MaxPoolPrice) {
price = MaxPoolPrice
}
return pool.ry.ToDec().Sub(pool.rx.ToDec().QuoRoundUp(price)).TruncateInt()
amt = pool.ry.ToDec().Sub(pool.rx.ToDec().QuoRoundUp(price)).TruncateInt()
if !amt.IsPositive() {
return zeroInt
}
return
}

func (pool *BasicPool) BuyAmountTo(price sdk.Dec) (amt sdk.Int) {
if price.GTE(pool.Price()) {
return sdk.ZeroInt()
return zeroInt
}
if price.LT(MinPoolPrice) {
price = MinPoolPrice
Expand All @@ -151,8 +155,8 @@ func (pool *BasicPool) BuyAmountTo(price sdk.Dec) (amt sdk.Int) {
sqrtRy := utils.DecApproxSqrt(pool.ry.ToDec())
sqrtPrice := utils.DecApproxSqrt(price)
dx := pool.rx.ToDec().Sub(sqrtPrice.Mul(sqrtRx.Mul(sqrtRy))) // dx = rx - sqrt(P * rx * ry)
if dx.IsZero() {
return sdk.ZeroInt()
if !dx.IsPositive() {
return zeroInt
}
utils.SafeMath(func() {
amt = dx.QuoTruncate(price).TruncateInt() // dy = dx / P
Expand All @@ -165,9 +169,9 @@ func (pool *BasicPool) BuyAmountTo(price sdk.Dec) (amt sdk.Int) {
return
}

func (pool *BasicPool) SellAmountTo(price sdk.Dec) sdk.Int {
func (pool *BasicPool) SellAmountTo(price sdk.Dec) (amt sdk.Int) {
if price.LTE(pool.Price()) {
return sdk.ZeroInt()
return zeroInt
}
if price.GT(MaxPoolPrice) {
price = MaxPoolPrice
Expand All @@ -176,7 +180,11 @@ func (pool *BasicPool) SellAmountTo(price sdk.Dec) sdk.Int {
sqrtRy := utils.DecApproxSqrt(pool.ry.ToDec())
sqrtPrice := utils.DecApproxSqrt(price)
// dy = ry - sqrt(rx * ry / P)
return pool.ry.ToDec().Sub(sqrtRx.Mul(sqrtRy).Quo(sqrtPrice)).TruncateInt()
amt = pool.ry.ToDec().Sub(sqrtRx.Mul(sqrtRy).Quo(sqrtPrice)).TruncateInt()
if !amt.IsPositive() {
return zeroInt
}
return
}

func (pool *BasicPool) Clone() Pool {
Expand Down Expand Up @@ -297,6 +305,14 @@ func (pool *RangedPool) Translation() (transX, transY sdk.Dec) {
return DeriveTranslation(pool.rx, pool.ry, pool.minPrice, pool.maxPrice)
}

func (pool *RangedPool) MinPrice() sdk.Dec {
return pool.minPrice
}

func (pool *RangedPool) MaxPrice() sdk.Dec {
return pool.maxPrice
}

// Price returns the pool price.
func (pool *RangedPool) Price() sdk.Dec {
if pool.rx.IsZero() && pool.ry.IsZero() {
Expand Down Expand Up @@ -328,15 +344,15 @@ func (pool *RangedPool) LowestSellPrice() (price sdk.Dec, found bool) {
// or equal to given price.
func (pool *RangedPool) BuyAmountOver(price sdk.Dec, _ bool) (amt sdk.Int) {
if price.GTE(pool.Price()) {
return sdk.ZeroInt()
return zeroInt
}
if price.LT(pool.minPrice) {
price = pool.minPrice
}
// dx = (rx + transX) - P * (ry + transY)
dx := pool.xComp.Sub(price.Mul(pool.yComp))
if dx.IsZero() {
return sdk.ZeroInt()
if !dx.IsPositive() {
return zeroInt
} else if dx.GT(pool.rx.ToDec()) {
dx = pool.rx.ToDec()
}
Expand All @@ -353,7 +369,7 @@ func (pool *RangedPool) BuyAmountOver(price sdk.Dec, _ bool) (amt sdk.Int) {

func (pool *RangedPool) SellAmountUnder(price sdk.Dec, _ bool) (amt sdk.Int) {
if price.LTE(pool.Price()) {
return sdk.ZeroInt()
return zeroInt
}
if price.GT(pool.maxPrice) {
price = pool.maxPrice
Expand All @@ -363,12 +379,15 @@ func (pool *RangedPool) SellAmountUnder(price sdk.Dec, _ bool) (amt sdk.Int) {
if amt.GT(pool.ry) {
amt = pool.ry
}
if !amt.IsPositive() {
return zeroInt
}
return
}

func (pool *RangedPool) BuyAmountTo(price sdk.Dec) (amt sdk.Int) {
if price.GTE(pool.Price()) {
return sdk.ZeroInt()
return zeroInt
}
if price.LT(pool.minPrice) {
price = pool.minPrice
Expand All @@ -378,8 +397,8 @@ func (pool *RangedPool) BuyAmountTo(price sdk.Dec) (amt sdk.Int) {
sqrtPrice := utils.DecApproxSqrt(price)
// dx = rx - (sqrt(P * (rx + transX) * (ry + transY)) - transX)
dx := pool.rx.ToDec().Sub(sqrtPrice.Mul(sqrtXComp.Mul(sqrtYComp)).Sub(pool.transX))
if dx.IsZero() {
return sdk.ZeroInt()
if !dx.IsPositive() {
return zeroInt
} else if dx.GT(pool.rx.ToDec()) {
dx = pool.rx.ToDec()
}
Expand All @@ -396,7 +415,7 @@ func (pool *RangedPool) BuyAmountTo(price sdk.Dec) (amt sdk.Int) {

func (pool *RangedPool) SellAmountTo(price sdk.Dec) (amt sdk.Int) {
if price.LTE(pool.Price()) {
return sdk.ZeroInt()
return zeroInt
}
if price.GT(pool.maxPrice) {
price = pool.maxPrice
Expand All @@ -409,6 +428,9 @@ func (pool *RangedPool) SellAmountTo(price sdk.Dec) (amt sdk.Int) {
if amt.GT(pool.ry) {
amt = pool.ry
}
if !amt.IsPositive() {
return zeroInt
}
return
}

Expand Down
8 changes: 8 additions & 0 deletions x/liquidity/amm/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,14 @@ func TestRangedPool_BuyAmountTo(t *testing.T) {
{pool, utils.ParseDec("1.0"), sdk.ZeroInt()},
{pool, utils.ParseDec("0.8"), sdk.NewInt(450560)},
{pool, utils.ParseDec("0.7"), sdk.NewInt(796682)},
{
amm.NewRangedPool(
sdk.NewInt(957322), sdk.NewInt(3351038710333311), sdk.Int{},
utils.ParseDec("0.9"), utils.ParseDec("1.1"),
),
utils.ParseDec("0.899580000000000000"),
sdk.NewInt(0),
},
} {
t.Run("", func(t *testing.T) {
amt := tc.pool.BuyAmountTo(tc.price)
Expand Down
21 changes: 9 additions & 12 deletions x/liquidity/keeper/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func (s *KeeperTestSuite) TestDefaultGenesis() {

func (s *KeeperTestSuite) TestImportExportGenesis() {
s.ctx = s.ctx.WithBlockHeight(1).WithBlockTime(utils.ParseTime("2022-01-01T00:00:00Z"))
k, ctx := s.keeper, s.ctx

pair := s.createPair(s.addr(0), "denom1", "denom2", true)
pool := s.createPool(s.addr(0), pair.Id, utils.ParseCoins("1000000denom1,1000000denom2"), true)
Expand All @@ -40,42 +39,40 @@ func (s *KeeperTestSuite) TestImportExportGenesis() {
withdrawReq := s.withdraw(s.addr(1), pool.Id, poolCoin)
order := s.sellLimitOrder(s.addr(3), pair.Id, utils.ParseDec("1.0"), newInt(1000), 0, true)

genState := k.ExportGenesis(ctx)
genState := s.keeper.ExportGenesis(s.ctx)

bz := s.app.AppCodec().MustMarshalJSON(genState)

s.SetupTest()
s.ctx = s.ctx.WithBlockHeight(1).WithBlockTime(utils.ParseTime("2022-01-01T00:00:00Z"))
k, ctx = s.keeper, s.ctx

var genState2 types.GenesisState
s.app.AppCodec().MustUnmarshalJSON(bz, &genState2)
k.InitGenesis(ctx, genState2)
genState3 := k.ExportGenesis(ctx)
s.keeper.InitGenesis(s.ctx, genState2)
genState3 := s.keeper.ExportGenesis(s.ctx)

s.Require().Equal(*genState, *genState3)

depositReq2, found := k.GetDepositRequest(ctx, depositReq.PoolId, depositReq.Id)
depositReq2, found := s.keeper.GetDepositRequest(s.ctx, depositReq.PoolId, depositReq.Id)
s.Require().True(found)
s.Require().Equal(depositReq, depositReq2)
withdrawReq2, found := k.GetWithdrawRequest(ctx, withdrawReq.PoolId, withdrawReq.Id)
withdrawReq2, found := s.keeper.GetWithdrawRequest(s.ctx, withdrawReq.PoolId, withdrawReq.Id)
s.Require().True(found)
s.Require().Equal(withdrawReq, withdrawReq2)
order2, found := k.GetOrder(ctx, order.PairId, order.Id)
order2, found := s.keeper.GetOrder(s.ctx, order.PairId, order.Id)
s.Require().True(found)
s.Require().Equal(order, order2)
}

func (s *KeeperTestSuite) TestImportExportGenesisEmpty() {
k, ctx := s.keeper, s.ctx
genState := k.ExportGenesis(ctx)
genState := s.keeper.ExportGenesis(s.ctx)

var genState2 types.GenesisState
bz := s.app.AppCodec().MustMarshalJSON(genState)
s.app.AppCodec().MustUnmarshalJSON(bz, &genState2)
k.InitGenesis(ctx, genState2)
s.keeper.InitGenesis(s.ctx, genState2)

genState3 := k.ExportGenesis(ctx)
genState3 := s.keeper.ExportGenesis(s.ctx)
s.Require().Equal(*genState, genState2)
s.Require().Equal(genState2, *genState3)
}
Expand Down
23 changes: 19 additions & 4 deletions x/liquidity/keeper/keeper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
"time"

"github.com/stretchr/testify/suite"
abci "github.com/tendermint/tendermint/abci/types"

sdk "github.com/cosmos/cosmos-sdk/types"
tmproto "github.com/tendermint/tendermint/proto/tendermint/types"

chain "github.com/cosmosquad-labs/squad/app"
"github.com/cosmosquad-labs/squad/x/liquidity"
utils "github.com/cosmosquad-labs/squad/types"
"github.com/cosmosquad-labs/squad/x/liquidity/amm"
"github.com/cosmosquad-labs/squad/x/liquidity/keeper"
"github.com/cosmosquad-labs/squad/x/liquidity/types"
Expand All @@ -33,7 +34,13 @@ func TestKeeperTestSuite(t *testing.T) {

func (s *KeeperTestSuite) SetupTest() {
s.app = chain.Setup(false)
s.ctx = s.app.BaseApp.NewContext(false, tmproto.Header{})
hdr := tmproto.Header{
Height: 1,
Time: utils.ParseTime("2022-01-01T00:00:00Z"),
}
s.app.BeginBlock(abci.RequestBeginBlock{Header: hdr})
s.ctx = s.app.BaseApp.NewContext(false, hdr)
s.app.BeginBlocker(s.ctx, abci.RequestBeginBlock{Header: hdr})
s.keeper = s.app.LiquidityKeeper
s.querier = keeper.Querier{Keeper: s.keeper}
s.msgServer = keeper.NewMsgServerImpl(s.keeper)
Expand All @@ -55,8 +62,16 @@ func (s *KeeperTestSuite) sendCoins(fromAddr, toAddr sdk.AccAddress, amt sdk.Coi
}

func (s *KeeperTestSuite) nextBlock() {
liquidity.EndBlocker(s.ctx, s.keeper)
liquidity.BeginBlocker(s.ctx, s.keeper)
s.T().Helper()
s.app.EndBlock(abci.RequestEndBlock{})
s.app.Commit()
hdr := tmproto.Header{
Height: s.app.LastBlockHeight() + 1,
Time: s.ctx.BlockTime().Add(5 * time.Second),
}
s.app.BeginBlock(abci.RequestBeginBlock{Header: hdr})
s.ctx = s.app.BaseApp.NewContext(false, hdr)
s.app.BeginBlocker(s.ctx, abci.RequestBeginBlock{Header: hdr})
}

// Below are useful helpers to write test code easily.
Expand Down
Loading

0 comments on commit 24a9408

Please sign in to comment.