diff --git a/CHANGELOG.md b/CHANGELOG.md index 803d47ddc8c..fa6001fa312 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * [#4336](https://github.com/osmosis-labs/osmosis/pull/4336) Move epochs module into its own go.mod * [#4658](https://github.com/osmosis-labs/osmosis/pull/4658) Deprecate x/gamm Pool query. The new one is located in x/poolmanager. * [#4682](https://github.com/osmosis-labs/osmosis/pull/4682) Deprecate x/gamm SpotPrice v2 query. The new one is located in x/poolmanager. +* [#4801](https://github.com/osmosis-labs/osmosis/pull/4801) remove GetTotalShares, GetTotalLiquidity and GetExitFee from PoolI. Define all on CFMMPoolI, define GetTotalLiquidity on PoolModuleI only. + ## v15.0.0 diff --git a/Makefile b/Makefile index d6d03ed5a28..c6eb55cc4e0 100644 --- a/Makefile +++ b/Makefile @@ -432,6 +432,13 @@ localnet-state-export-clean: localnet-clean localnet-cl-create-positions: go run tests/cl-go-client/main.go +############################################################################### +### Go Mock ### +############################################################################### + +go-mock-update-pool-module: + mockgen -source=x/poolmanager/types/routes.go -destination=tests/mocks/pool_module.go -package=mocks + .PHONY: all build-linux install format lint \ go-mod-cache draw-deps clean build build-contract-tests-hooks \ test test-all test-build test-cover test-unit test-race benchmark diff --git a/app/upgrades/v15/upgrade_test.go b/app/upgrades/v15/upgrade_test.go index 82aa69018fe..8f102c50a2e 100644 --- a/app/upgrades/v15/upgrade_test.go +++ b/app/upgrades/v15/upgrade_test.go @@ -130,10 +130,11 @@ func (suite *UpgradeTestSuite) TestMigrateBalancerToStablePools() { suite.Require().NoError(err) // shares before migration - balancerPool, err := gammKeeper.GetPool(suite.Ctx, poolID) + balancerPool, err := gammKeeper.GetCFMMPool(suite.Ctx, poolID) + balancerLiquidity, err := gammKeeper.GetTotalPoolLiquidity(suite.Ctx, balancerPool.GetId()) suite.Require().NoError(err) + balancerShares := balancerPool.GetTotalShares() - balancerLiquidity := balancerPool.GetTotalPoolLiquidity(ctx).String() // check balancer pool liquidity using the bank module balancerBalances := suite.App.BankKeeper.GetAllBalances(ctx, balancerPool.GetAddress()) @@ -141,13 +142,16 @@ func (suite *UpgradeTestSuite) TestMigrateBalancerToStablePools() { v15.MigrateBalancerPoolToSolidlyStable(ctx, gammKeeper, poolmanagerKeeper, suite.App.BankKeeper, poolID) // check that the pool is now a stable pool - stablepool, err := gammKeeper.GetPool(ctx, poolID) + stablepool, err := gammKeeper.GetCFMMPool(ctx, poolID) suite.Require().NoError(err) suite.Require().Equal(stablepool.GetType(), poolmanagertypes.Stableswap) + // check that the number of stableswap LP shares is the same as the number of balancer LP shares suite.Require().Equal(balancerShares.String(), stablepool.GetTotalShares().String()) // check that the pool liquidity is the same - suite.Require().Equal(balancerLiquidity, stablepool.GetTotalPoolLiquidity(ctx).String()) + stableLiquidity, err := gammKeeper.GetTotalPoolLiquidity(suite.Ctx, balancerPool.GetId()) + suite.Require().NoError(err) + suite.Require().Equal(balancerLiquidity.String(), stableLiquidity.String()) // check pool liquidity using the bank module stableBalances := suite.App.BankKeeper.GetAllBalances(ctx, stablepool.GetAddress()) suite.Require().Equal(balancerBalances, stableBalances) diff --git a/app/upgrades/v15/upgrades.go b/app/upgrades/v15/upgrades.go index 6f3a5a85f51..e5a0f82d04c 100644 --- a/app/upgrades/v15/upgrades.go +++ b/app/upgrades/v15/upgrades.go @@ -89,7 +89,12 @@ func migrateBalancerPoolsToSolidlyStable(ctx sdk.Context, gammKeeper *gammkeeper func migrateBalancerPoolToSolidlyStable(ctx sdk.Context, gammKeeper *gammkeeper.Keeper, poolmanagerKeeper *poolmanager.Keeper, bankKeeper bankkeeper.Keeper, poolId uint64) { // fetch the pool with the given poolId - balancerPool, err := gammKeeper.GetPool(ctx, poolId) + balancerPool, err := gammKeeper.GetCFMMPool(ctx, poolId) + if err != nil { + panic(err) + } + + balancerPoolLiquidity, err := gammKeeper.GetTotalPoolLiquidity(ctx, poolId) if err != nil { panic(err) } @@ -98,7 +103,7 @@ func migrateBalancerPoolToSolidlyStable(ctx sdk.Context, gammKeeper *gammkeeper. stableswapPool, err := stableswap.NewStableswapPool( poolId, stableswap.PoolParams{SwapFee: balancerPool.GetSwapFee(ctx), ExitFee: balancerPool.GetExitFee(ctx)}, - balancerPool.GetTotalPoolLiquidity(ctx), + balancerPoolLiquidity, []uint64{1, 1}, "osmo1k8c2m5cn322akk5wy8lpt87dd2f4yh9afcd7af", // Stride Foundation 2/3 multisig "", diff --git a/tests/mocks/pool_module.go b/tests/mocks/pool_module.go index a2f220d0cbb..871f12f2df6 100644 --- a/tests/mocks/pool_module.go +++ b/tests/mocks/pool_module.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: /root/osmosis/x/poolmanager/types/routes.go +// Source: x/poolmanager/types/routes.go // Package mocks is a generated GoMock package. package mocks @@ -8,10 +8,160 @@ import ( reflect "reflect" types "github.com/cosmos/cosmos-sdk/types" + types0 "github.com/cosmos/cosmos-sdk/x/auth/types" + types1 "github.com/cosmos/cosmos-sdk/x/bank/types" gomock "github.com/golang/mock/gomock" types2 "github.com/osmosis-labs/osmosis/v15/x/poolmanager/types" ) +// MockAccountI is a mock of AccountI interface. +type MockAccountI struct { + ctrl *gomock.Controller + recorder *MockAccountIMockRecorder +} + +// MockAccountIMockRecorder is the mock recorder for MockAccountI. +type MockAccountIMockRecorder struct { + mock *MockAccountI +} + +// NewMockAccountI creates a new mock instance. +func NewMockAccountI(ctrl *gomock.Controller) *MockAccountI { + mock := &MockAccountI{ctrl: ctrl} + mock.recorder = &MockAccountIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAccountI) EXPECT() *MockAccountIMockRecorder { + return m.recorder +} + +// GetAccount mocks base method. +func (m *MockAccountI) GetAccount(ctx types.Context, addr types.AccAddress) types0.AccountI { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAccount", ctx, addr) + ret0, _ := ret[0].(types0.AccountI) + return ret0 +} + +// GetAccount indicates an expected call of GetAccount. +func (mr *MockAccountIMockRecorder) GetAccount(ctx, addr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccount", reflect.TypeOf((*MockAccountI)(nil).GetAccount), ctx, addr) +} + +// NewAccount mocks base method. +func (m *MockAccountI) NewAccount(arg0 types.Context, arg1 types0.AccountI) types0.AccountI { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewAccount", arg0, arg1) + ret0, _ := ret[0].(types0.AccountI) + return ret0 +} + +// NewAccount indicates an expected call of NewAccount. +func (mr *MockAccountIMockRecorder) NewAccount(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewAccount", reflect.TypeOf((*MockAccountI)(nil).NewAccount), arg0, arg1) +} + +// SetAccount mocks base method. +func (m *MockAccountI) SetAccount(ctx types.Context, acc types0.AccountI) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccount", ctx, acc) +} + +// SetAccount indicates an expected call of SetAccount. +func (mr *MockAccountIMockRecorder) SetAccount(ctx, acc interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccount", reflect.TypeOf((*MockAccountI)(nil).SetAccount), ctx, acc) +} + +// MockBankI is a mock of BankI interface. +type MockBankI struct { + ctrl *gomock.Controller + recorder *MockBankIMockRecorder +} + +// MockBankIMockRecorder is the mock recorder for MockBankI. +type MockBankIMockRecorder struct { + mock *MockBankI +} + +// NewMockBankI creates a new mock instance. +func NewMockBankI(ctrl *gomock.Controller) *MockBankI { + mock := &MockBankI{ctrl: ctrl} + mock.recorder = &MockBankIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBankI) EXPECT() *MockBankIMockRecorder { + return m.recorder +} + +// SendCoins mocks base method. +func (m *MockBankI) SendCoins(ctx types.Context, fromAddr, toAddr types.AccAddress, amt types.Coins) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendCoins", ctx, fromAddr, toAddr, amt) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendCoins indicates an expected call of SendCoins. +func (mr *MockBankIMockRecorder) SendCoins(ctx, fromAddr, toAddr, amt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendCoins", reflect.TypeOf((*MockBankI)(nil).SendCoins), ctx, fromAddr, toAddr, amt) +} + +// SetDenomMetaData mocks base method. +func (m *MockBankI) SetDenomMetaData(ctx types.Context, denomMetaData types1.Metadata) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetDenomMetaData", ctx, denomMetaData) +} + +// SetDenomMetaData indicates an expected call of SetDenomMetaData. +func (mr *MockBankIMockRecorder) SetDenomMetaData(ctx, denomMetaData interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDenomMetaData", reflect.TypeOf((*MockBankI)(nil).SetDenomMetaData), ctx, denomMetaData) +} + +// MockCommunityPoolI is a mock of CommunityPoolI interface. +type MockCommunityPoolI struct { + ctrl *gomock.Controller + recorder *MockCommunityPoolIMockRecorder +} + +// MockCommunityPoolIMockRecorder is the mock recorder for MockCommunityPoolI. +type MockCommunityPoolIMockRecorder struct { + mock *MockCommunityPoolI +} + +// NewMockCommunityPoolI creates a new mock instance. +func NewMockCommunityPoolI(ctrl *gomock.Controller) *MockCommunityPoolI { + mock := &MockCommunityPoolI{ctrl: ctrl} + mock.recorder = &MockCommunityPoolIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCommunityPoolI) EXPECT() *MockCommunityPoolIMockRecorder { + return m.recorder +} + +// FundCommunityPool mocks base method. +func (m *MockCommunityPoolI) FundCommunityPool(ctx types.Context, amount types.Coins, sender types.AccAddress) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FundCommunityPool", ctx, amount, sender) + ret0, _ := ret[0].(error) + return ret0 +} + +// FundCommunityPool indicates an expected call of FundCommunityPool. +func (mr *MockCommunityPoolIMockRecorder) FundCommunityPool(ctx, amount, sender interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FundCommunityPool", reflect.TypeOf((*MockCommunityPoolI)(nil).FundCommunityPool), ctx, amount, sender) +} // MockPoolModuleI is a mock of PoolModuleI interface. type MockPoolModuleI struct { @@ -126,6 +276,21 @@ func (mr *MockPoolModuleIMockRecorder) GetPools(ctx interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPools", reflect.TypeOf((*MockPoolModuleI)(nil).GetPools), ctx) } +// GetTotalPoolLiquidity mocks base method. +func (m *MockPoolModuleI) GetTotalPoolLiquidity(ctx types.Context, poolId uint64) (types.Coins, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTotalPoolLiquidity", ctx, poolId) + ret0, _ := ret[0].(types.Coins) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTotalPoolLiquidity indicates an expected call of GetTotalPoolLiquidity. +func (mr *MockPoolModuleIMockRecorder) GetTotalPoolLiquidity(ctx, poolId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTotalPoolLiquidity", reflect.TypeOf((*MockPoolModuleI)(nil).GetTotalPoolLiquidity), ctx, poolId) +} + // InitializePool mocks base method. func (m *MockPoolModuleI) InitializePool(ctx types.Context, pool types2.PoolI, creatorAddress types.AccAddress) error { m.ctrl.T.Helper() @@ -168,4 +333,106 @@ func (m *MockPoolModuleI) SwapExactAmountOut(ctx types.Context, sender types.Acc func (mr *MockPoolModuleIMockRecorder) SwapExactAmountOut(ctx, sender, pool, tokenInDenom, tokenInMaxAmount, tokenOut, swapFee interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SwapExactAmountOut", reflect.TypeOf((*MockPoolModuleI)(nil).SwapExactAmountOut), ctx, sender, pool, tokenInDenom, tokenInMaxAmount, tokenOut, swapFee) -} \ No newline at end of file +} + +// MockPoolIncentivesKeeperI is a mock of PoolIncentivesKeeperI interface. +type MockPoolIncentivesKeeperI struct { + ctrl *gomock.Controller + recorder *MockPoolIncentivesKeeperIMockRecorder +} + +// MockPoolIncentivesKeeperIMockRecorder is the mock recorder for MockPoolIncentivesKeeperI. +type MockPoolIncentivesKeeperIMockRecorder struct { + mock *MockPoolIncentivesKeeperI +} + +// NewMockPoolIncentivesKeeperI creates a new mock instance. +func NewMockPoolIncentivesKeeperI(ctrl *gomock.Controller) *MockPoolIncentivesKeeperI { + mock := &MockPoolIncentivesKeeperI{ctrl: ctrl} + mock.recorder = &MockPoolIncentivesKeeperIMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPoolIncentivesKeeperI) EXPECT() *MockPoolIncentivesKeeperIMockRecorder { + return m.recorder +} + +// IsPoolIncentivized mocks base method. +func (m *MockPoolIncentivesKeeperI) IsPoolIncentivized(ctx types.Context, poolId uint64) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsPoolIncentivized", ctx, poolId) + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsPoolIncentivized indicates an expected call of IsPoolIncentivized. +func (mr *MockPoolIncentivesKeeperIMockRecorder) IsPoolIncentivized(ctx, poolId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsPoolIncentivized", reflect.TypeOf((*MockPoolIncentivesKeeperI)(nil).IsPoolIncentivized), ctx, poolId) +} + +// MockMultihopRoute is a mock of MultihopRoute interface. +type MockMultihopRoute struct { + ctrl *gomock.Controller + recorder *MockMultihopRouteMockRecorder +} + +// MockMultihopRouteMockRecorder is the mock recorder for MockMultihopRoute. +type MockMultihopRouteMockRecorder struct { + mock *MockMultihopRoute +} + +// NewMockMultihopRoute creates a new mock instance. +func NewMockMultihopRoute(ctrl *gomock.Controller) *MockMultihopRoute { + mock := &MockMultihopRoute{ctrl: ctrl} + mock.recorder = &MockMultihopRouteMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMultihopRoute) EXPECT() *MockMultihopRouteMockRecorder { + return m.recorder +} + +// IntermediateDenoms mocks base method. +func (m *MockMultihopRoute) IntermediateDenoms() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IntermediateDenoms") + ret0, _ := ret[0].([]string) + return ret0 +} + +// IntermediateDenoms indicates an expected call of IntermediateDenoms. +func (mr *MockMultihopRouteMockRecorder) IntermediateDenoms() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IntermediateDenoms", reflect.TypeOf((*MockMultihopRoute)(nil).IntermediateDenoms)) +} + +// Length mocks base method. +func (m *MockMultihopRoute) Length() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Length") + ret0, _ := ret[0].(int) + return ret0 +} + +// Length indicates an expected call of Length. +func (mr *MockMultihopRouteMockRecorder) Length() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Length", reflect.TypeOf((*MockMultihopRoute)(nil).Length)) +} + +// PoolIds mocks base method. +func (m *MockMultihopRoute) PoolIds() []uint64 { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PoolIds") + ret0, _ := ret[0].([]uint64) + return ret0 +} + +// PoolIds indicates an expected call of PoolIds. +func (mr *MockMultihopRouteMockRecorder) PoolIds() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PoolIds", reflect.TypeOf((*MockMultihopRoute)(nil).PoolIds)) +} diff --git a/x/concentrated-liquidity/lp_test.go b/x/concentrated-liquidity/lp_test.go index 7a827de5b46..94f4464c41a 100644 --- a/x/concentrated-liquidity/lp_test.go +++ b/x/concentrated-liquidity/lp_test.go @@ -945,6 +945,9 @@ func (s *KeeperTestSuite) TestInverseRelation_CreatePosition_WithdrawPosition() poolBefore, err := clKeeper.GetPool(s.Ctx, poolID) s.Require().NoError(err) + liquidityBefore, err := s.App.ConcentratedLiquidityKeeper.GetTotalPoolLiquidity(s.Ctx, poolID) + s.Require().NoError(err) + // Pre-set fee growth accumulator if !tc.preSetChargeFee.IsZero() { err = clKeeper.ChargeFee(s.Ctx, 1, tc.preSetChargeFee) @@ -990,10 +993,12 @@ func (s *KeeperTestSuite) TestInverseRelation_CreatePosition_WithdrawPosition() s.Require().Equal(sdk.Dec{}, positionLiquidity) // 4. Check that pool has come back to original state - poolAfter, err := clKeeper.GetPool(s.Ctx, poolID) + + liquidityAfter, err := s.App.ConcentratedLiquidityKeeper.GetTotalPoolLiquidity(s.Ctx, poolID) + s.Require().NoError(err) + s.Require().NoError(err) - s.Require().Equal(poolBefore.GetTotalShares(), poolAfter.GetTotalShares()) - s.Require().Equal(poolBefore.GetTotalPoolLiquidity(s.Ctx), poolAfter.GetTotalPoolLiquidity(s.Ctx)) + s.Require().Equal(liquidityBefore, liquidityAfter) }) } } diff --git a/x/concentrated-liquidity/model/pool.go b/x/concentrated-liquidity/model/pool.go index 47a378b96fe..fd1dc9a980d 100644 --- a/x/concentrated-liquidity/model/pool.go +++ b/x/concentrated-liquidity/model/pool.go @@ -95,11 +95,6 @@ func (p Pool) GetSwapFee(ctx sdk.Context) sdk.Dec { return p.SwapFee } -// GetExitFee returns the exit fee of the pool -func (p Pool) GetExitFee(ctx sdk.Context) sdk.Dec { - return sdk.ZeroDec() -} - // IsActive returns true if the pool is active func (p Pool) IsActive(ctx sdk.Context) bool { return true @@ -124,11 +119,6 @@ func (p Pool) SpotPrice(ctx sdk.Context, baseAssetDenom string, quoteAssetDenom return sdk.NewDec(1).Quo(p.CurrentSqrtPrice.Power(2)), nil } -// GetTotalShares returns the total shares of the pool -func (p Pool) GetTotalShares() sdk.Int { - return sdk.Int{} -} - // GetToken0 returns the token0 of the pool func (p Pool) GetToken0() string { return p.Token0 @@ -278,8 +268,3 @@ func (p *Pool) ApplySwap(newLiquidity sdk.Dec, newCurrentTick sdk.Int, newCurren return nil } - -// TODO: finish this function -func (p Pool) GetTotalPoolLiquidity(ctx sdk.Context) sdk.Coins { - return sdk.Coins{} -} diff --git a/x/concentrated-liquidity/pool.go b/x/concentrated-liquidity/pool.go index e19c216f256..12d649774ad 100644 --- a/x/concentrated-liquidity/pool.go +++ b/x/concentrated-liquidity/pool.go @@ -146,6 +146,22 @@ func (k Keeper) CalculateSpotPrice( return price, nil } +// GetTotalPoolLiquidity returns the coins in the pool owned by all LPs +func (k Keeper) GetTotalPoolLiquidity(ctx sdk.Context, poolId uint64) (sdk.Coins, error) { + pool, err := k.getPoolById(ctx, poolId) + if err != nil { + return nil, err + } + + poolBalance := k.bankKeeper.GetAllBalances(ctx, pool.GetAddress()) + + // This is to ensure that malicious actor cannot send dust to + // a pool address. + filteredPoolBalance := poolBalance.FilterDenoms([]string{pool.GetToken0(), pool.GetToken1()}) + + return filteredPoolBalance, nil +} + // convertConcentratedToPoolInterface takes a types.ConcentratedPoolExtension and attempts to convert it to a // poolmanagertypes.PoolI. If the conversion is successful, the converted value is returned. If the conversion fails, // an error is returned. diff --git a/x/concentrated-liquidity/pool_test.go b/x/concentrated-liquidity/pool_test.go index f4a8da1c1a3..af7561e6264 100644 --- a/x/concentrated-liquidity/pool_test.go +++ b/x/concentrated-liquidity/pool_test.go @@ -257,6 +257,82 @@ func (s *KeeperTestSuite) TestCalculateSpotPrice() { s.Require().True(spotPrice.IsNil()) } +func (s *KeeperTestSuite) TestGetTotalPoolLiquidity() { + var ( + defaultPoolCoinOne = sdk.NewCoin(USDC, sdk.OneInt()) + defaultPoolCoinTwo = sdk.NewCoin(ETH, sdk.NewInt(2)) + nonPoolCool = sdk.NewCoin("uosmo", sdk.NewInt(3)) + + defaultCoins = sdk.NewCoins(defaultPoolCoinOne, defaultPoolCoinTwo) + ) + + tests := []struct { + name string + poolId uint64 + poolLiquidity sdk.Coins + expectedResult sdk.Coins + expectedErr error + }{ + { + name: "valid with 2 coins", + poolId: defaultPoolId, + poolLiquidity: defaultCoins, + expectedResult: defaultCoins, + }, + { + name: "valid with 1 coin", + poolId: defaultPoolId, + poolLiquidity: sdk.NewCoins(defaultPoolCoinTwo), + expectedResult: sdk.NewCoins(defaultPoolCoinTwo), + }, + { + // can only happen if someone sends extra tokens to pool + // address. Should not occur in practice. + name: "valid with 3 coins", + poolId: defaultPoolId, + poolLiquidity: sdk.NewCoins(defaultPoolCoinTwo, defaultPoolCoinOne, nonPoolCool), + expectedResult: defaultCoins, + }, + { + // this can happen if someone sends random dust to pool address. + name: "only non-pool coin - does not show up in result", + poolId: defaultPoolId, + poolLiquidity: sdk.NewCoins(nonPoolCool), + expectedResult: sdk.Coins(nil), + }, + { + name: "invalid pool id", + poolId: defaultPoolId + 1, + expectedErr: types.PoolNotFoundError{PoolId: defaultPoolId + 1}, + }, + } + + for _, tc := range tests { + tc := tc + s.Run(tc.name, func() { + s.SetupTest() + + // Create default CL pool + pool := s.PrepareConcentratedPool() + + s.FundAcc(pool.GetAddress(), tc.poolLiquidity) + + // Get pool defined in test case + actual, err := s.App.ConcentratedLiquidityKeeper.GetTotalPoolLiquidity(s.Ctx, tc.poolId) + + if tc.expectedErr != nil { + s.Require().Error(err) + s.Require().ErrorIs(err, tc.expectedErr) + s.Require().Nil(actual) + return + } + + s.Require().NoError(err) + s.Require().Equal(tc.expectedResult, actual) + }) + } +} + func (s *KeeperTestSuite) TestValidateSwapFee() { tests := []struct { name string diff --git a/x/concentrated-liquidity/swaps_test.go b/x/concentrated-liquidity/swaps_test.go index d7709607b8e..35d1528d2d1 100644 --- a/x/concentrated-liquidity/swaps_test.go +++ b/x/concentrated-liquidity/swaps_test.go @@ -1527,7 +1527,6 @@ func (s *KeeperTestSuite) TestCalcAndSwapOutAmtGivenIn() { s.Require().Equal(poolBeforeCalc.GetCurrentSqrtPrice(), poolAfterCalc.GetCurrentSqrtPrice()) s.Require().Equal(poolBeforeCalc.GetCurrentTick(), poolAfterCalc.GetCurrentTick()) - s.Require().Equal(poolBeforeCalc.GetTotalShares(), poolAfterCalc.GetTotalShares()) s.Require().Equal(poolBeforeCalc.GetLiquidity(), poolAfterCalc.GetLiquidity()) s.Require().Equal(poolBeforeCalc.GetTickSpacing(), poolAfterCalc.GetTickSpacing()) } @@ -1761,7 +1760,6 @@ func (s *KeeperTestSuite) TestCalcAndSwapInAmtGivenOut() { s.Require().Equal(poolBeforeCalc.GetCurrentSqrtPrice(), poolAfterCalc.GetCurrentSqrtPrice()) s.Require().Equal(poolBeforeCalc.GetCurrentTick(), poolAfterCalc.GetCurrentTick()) - s.Require().Equal(poolBeforeCalc.GetTotalShares(), poolAfterCalc.GetTotalShares()) s.Require().Equal(poolBeforeCalc.GetLiquidity(), poolAfterCalc.GetLiquidity()) s.Require().Equal(poolBeforeCalc.GetTickSpacing(), poolAfterCalc.GetTickSpacing()) } @@ -2319,7 +2317,6 @@ func (s *KeeperTestSuite) TestCalcOutAmtGivenInWriteCtx() { s.Require().Equal(poolBeforeCalc.GetCurrentSqrtPrice(), poolAfterCalc.GetCurrentSqrtPrice()) s.Require().Equal(poolBeforeCalc.GetCurrentTick(), poolAfterCalc.GetCurrentTick()) - s.Require().Equal(poolBeforeCalc.GetTotalShares(), poolAfterCalc.GetTotalShares()) s.Require().Equal(poolBeforeCalc.GetLiquidity(), poolAfterCalc.GetLiquidity()) s.Require().Equal(poolBeforeCalc.GetTickSpacing(), poolAfterCalc.GetTickSpacing()) @@ -2405,7 +2402,6 @@ func (s *KeeperTestSuite) TestCalcInAmtGivenOutWriteCtx() { s.Require().Equal(poolBeforeCalc.GetCurrentSqrtPrice(), poolAfterCalc.GetCurrentSqrtPrice()) s.Require().Equal(poolBeforeCalc.GetCurrentTick(), poolAfterCalc.GetCurrentTick()) - s.Require().Equal(poolBeforeCalc.GetTotalShares(), poolAfterCalc.GetTotalShares()) s.Require().Equal(poolBeforeCalc.GetLiquidity(), poolAfterCalc.GetLiquidity()) s.Require().Equal(poolBeforeCalc.GetTickSpacing(), poolAfterCalc.GetTickSpacing()) @@ -2690,6 +2686,9 @@ func (s *KeeperTestSuite) inverseRelationshipInvariants(firstTokenIn, firstToken pool, ok := poolBefore.(cltypes.ConcentratedPoolExtension) s.Require().True(ok) + liquidityBefore, err := s.App.ConcentratedLiquidityKeeper.GetTotalPoolLiquidity(s.Ctx, pool.GetId()) + s.Require().NoError(err) + // The output of the first swap should be exactly the same as the input of the second swap. // The input of the first swap should be within a margin of error of the output of the second swap. if outGivenIn { @@ -2704,9 +2703,11 @@ func (s *KeeperTestSuite) inverseRelationshipInvariants(firstTokenIn, firstToken poolAfter, err := s.App.ConcentratedLiquidityKeeper.GetPool(s.Ctx, poolBefore.GetId()) s.Require().NoError(err) + liquidityAfter, err := s.App.ConcentratedLiquidityKeeper.GetTotalPoolLiquidity(s.Ctx, pool.GetId()) + s.Require().NoError(err) + // After both swaps, the pool should have the same total shares and total liquidity. - s.Require().Equal(poolBefore.GetTotalShares(), poolAfter.GetTotalShares()) - s.Require().Equal(poolBefore.GetTotalPoolLiquidity(s.Ctx), poolAfter.GetTotalPoolLiquidity(s.Ctx)) + s.Require().Equal(liquidityBefore, liquidityAfter) // Within a margin of error, the spot price should be the same before and after the swap oldSpotPrice, err := poolBefore.SpotPrice(s.Ctx, pool.GetToken0(), pool.GetToken1()) diff --git a/x/concentrated-liquidity/types/expected_keepers.go b/x/concentrated-liquidity/types/expected_keepers.go index 46cc5e095f6..737f47567b6 100644 --- a/x/concentrated-liquidity/types/expected_keepers.go +++ b/x/concentrated-liquidity/types/expected_keepers.go @@ -10,6 +10,7 @@ import ( // BankKeeper defines the banking contract that must be fulfilled when // creating a x/concentrated-liquidity keeper. type BankKeeper interface { + GetAllBalances(ctx sdk.Context, addr sdk.AccAddress) sdk.Coins GetDenomMetaData(ctx sdk.Context, denom string) (banktypes.Metadata, bool) SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error HasBalance(ctx sdk.Context, addr sdk.AccAddress, amt sdk.Coin) bool diff --git a/x/gamm/keeper/grpc_query.go b/x/gamm/keeper/grpc_query.go index 1201b954ce1..2b836a4710c 100644 --- a/x/gamm/keeper/grpc_query.go +++ b/x/gamm/keeper/grpc_query.go @@ -153,7 +153,7 @@ func (q Querier) CalcJoinPoolShares(ctx context.Context, req *types.QueryCalcJoi } sdkCtx := sdk.UnwrapSDKContext(ctx) - pool, err := q.Keeper.getPoolForSwap(sdkCtx, req.PoolId) + pool, err := q.Keeper.GetCFMMPool(sdkCtx, req.PoolId) if err != nil { return nil, err } diff --git a/x/gamm/keeper/migrate.go b/x/gamm/keeper/migrate.go index 18932f1b105..b93be110c81 100644 --- a/x/gamm/keeper/migrate.go +++ b/x/gamm/keeper/migrate.go @@ -132,7 +132,10 @@ func (k Keeper) validateRecords(ctx sdk.Context, records []types.BalancerToConce } // Ensure the balancer pools denoms are the same as the concentrated pool denoms - balancerPoolAssets := balancerPool.GetTotalPoolLiquidity(ctx) + balancerPoolAssets, err := k.GetTotalPoolLiquidity(ctx, balancerPool.GetId()) + if err != nil { + return err + } if len(balancerPoolAssets) != 2 { return fmt.Errorf("Balancer pool ID #%d does not contain exactly 2 tokens", record.BalancerPoolId) diff --git a/x/gamm/keeper/pool.go b/x/gamm/keeper/pool.go index 73a924a9960..2f7abf31517 100644 --- a/x/gamm/keeper/pool.go +++ b/x/gamm/keeper/pool.go @@ -68,8 +68,10 @@ func (k Keeper) GetPoolAndPoke(ctx sdk.Context, poolId uint64) (types.CFMMPoolI, return pool, nil } -// Get pool and check if the pool is active, i.e. allowed to be swapped against. -func (k Keeper) getPoolForSwap(ctx sdk.Context, poolId uint64) (types.CFMMPoolI, error) { +// GetCFMMPool gets CFMMPool and checks if the pool is active, i.e. allowed to be swapped against. +// The difference from GetPools is that this function returns an error if the pool is inactive. +// Additionally, it returns x/gamm specific CFMMPool type. +func (k Keeper) GetCFMMPool(ctx sdk.Context, poolId uint64) (types.CFMMPoolI, error) { pool, err := k.GetPoolAndPoke(ctx, poolId) if err != nil { return &balancer.Pool{}, err @@ -286,6 +288,15 @@ func (k Keeper) GetPoolType(ctx sdk.Context, poolId uint64) (poolmanagertypes.Po } } +// GetTotalPoolLiquidity returns the coins in the pool owned by all LPs +func (k Keeper) GetTotalPoolLiquidity(ctx sdk.Context, poolId uint64) (sdk.Coins, error) { + pool, err := k.GetCFMMPool(ctx, poolId) + if err != nil { + return nil, err + } + return pool.GetTotalPoolLiquidity(ctx), nil +} + // setStableSwapScalingFactors sets the stable swap scaling factors. // errors if the pool does not exist, the sender is not the scaling factor controller, or due to other // internal errors. @@ -308,7 +319,6 @@ func (k Keeper) setStableSwapScalingFactors(ctx sdk.Context, poolId uint64, scal // convertToCFMMPool converts PoolI to CFMMPoolI by casting the input. // Returns the pool of the CFMMPoolI or error if the given pool does not implement // CFMMPoolI. -// nolint: unused func convertToCFMMPool(pool poolmanagertypes.PoolI) (types.CFMMPoolI, error) { cfmmPool, ok := pool.(types.CFMMPoolI) if !ok { diff --git a/x/gamm/keeper/pool_service.go b/x/gamm/keeper/pool_service.go index e3139d301f9..ae1597af59e 100644 --- a/x/gamm/keeper/pool_service.go +++ b/x/gamm/keeper/pool_service.go @@ -63,8 +63,18 @@ func (k Keeper) CalculateSpotPrice( // - Records total liquidity increase // - Calls the AfterPoolCreated hook func (k Keeper) InitializePool(ctx sdk.Context, pool poolmanagertypes.PoolI, sender sdk.AccAddress) (err error) { + cfmmPool, err := convertToCFMMPool(pool) + if err != nil { + return err + } + + exitFee := cfmmPool.GetExitFee(ctx) + if !exitFee.Equal(sdk.ZeroDec()) { + return fmt.Errorf("can not create pool with non zero exit fee, got %d", exitFee) + } + // Mint the initial pool shares share token to the sender - err = k.MintPoolShareToAccount(ctx, pool, sender, pool.GetTotalShares()) + err = k.MintPoolShareToAccount(ctx, pool, sender, cfmmPool.GetTotalShares()) if err != nil { return err } @@ -97,7 +107,7 @@ func (k Keeper) InitializePool(ctx sdk.Context, pool poolmanagertypes.PoolI, sen } k.hooks.AfterPoolCreated(ctx, sender, pool.GetId()) - k.RecordTotalLiquidityIncrease(ctx, pool.GetTotalPoolLiquidity(ctx)) + k.RecordTotalLiquidityIncrease(ctx, cfmmPool.GetTotalPoolLiquidity(ctx)) return nil } @@ -216,7 +226,7 @@ func (k Keeper) JoinSwapExactAmountIn( } }() - pool, err := k.getPoolForSwap(ctx, poolId) + pool, err := k.GetCFMMPool(ctx, poolId) if err != nil { return sdk.Int{}, err } @@ -260,7 +270,7 @@ func (k Keeper) JoinSwapShareAmountOut( } }() - pool, err := k.getPoolForSwap(ctx, poolId) + pool, err := k.GetCFMMPool(ctx, poolId) if err != nil { return sdk.Int{}, err } @@ -376,7 +386,7 @@ func (k Keeper) ExitSwapExactAmountOut( tokenOut sdk.Coin, shareInMaxAmount sdk.Int, ) (shareInAmount sdk.Int, err error) { - pool, err := k.getPoolForSwap(ctx, poolId) + pool, err := k.GetCFMMPool(ctx, poolId) if err != nil { return sdk.Int{}, err } diff --git a/x/gamm/keeper/pool_service_test.go b/x/gamm/keeper/pool_service_test.go index b0514689311..9a0c96409ff 100644 --- a/x/gamm/keeper/pool_service_test.go +++ b/x/gamm/keeper/pool_service_test.go @@ -18,7 +18,6 @@ import ( ) var ( - defaultPoolParams = balancer.PoolParams{ SwapFee: defaultSwapFee, ExitFee: defaultZeroExitFee, @@ -251,10 +250,9 @@ func (suite *KeeperTestSuite) TestInitializePool() { testAccount := suite.TestAccs[0] tests := []struct { - name string - createPool func() poolmanagertypes.PoolI - expectPass bool - expectPanic bool + name string + createPool func() poolmanagertypes.PoolI + expectPass bool }{ { name: "initialize balancer pool with default assets", @@ -288,91 +286,107 @@ func (suite *KeeperTestSuite) TestInitializePool() { expectPass: true, }, { - name: "initialize a CL pool which cause panic", + name: "initialize a CL pool which cause error", createPool: func() poolmanagertypes.PoolI { return suite.PrepareConcentratedPool() }, - expectPanic: true, + expectPass: false, + }, + { + name: "initialize pool with non-zero exit fee", + createPool: func() poolmanagertypes.PoolI { + balancerPool, err := balancer.NewBalancerPool( + defaultPoolId, + balancer.PoolParams{ + SwapFee: defaultSwapFee, + ExitFee: sdk.NewDecWithPrec(5, 1), + }, + defaultPoolAssets, + "", + time.Now(), + ) + require.NoError(suite.T(), err) + return &balancerPool + }, + expectPass: false, }, } for _, test := range tests { test := test suite.Run(test.name, func() { - osmoassert.ConditionalPanic(suite.T(), test.expectPanic, func() { - suite.SetupTest() + suite.SetupTest() - gammKeeper := suite.App.GAMMKeeper - bankKeeper := suite.App.BankKeeper - poolIncentivesKeeper := suite.App.PoolIncentivesKeeper + gammKeeper := suite.App.GAMMKeeper + bankKeeper := suite.App.BankKeeper + poolIncentivesKeeper := suite.App.PoolIncentivesKeeper - // sender test account - sender := testAccount - senderBalBeforeNewPool := bankKeeper.GetAllBalances(suite.Ctx, sender) + // sender test account + sender := testAccount + senderBalBeforeNewPool := bankKeeper.GetAllBalances(suite.Ctx, sender) - // initializePool with a poolI - // initializePool shoould be called by pool manager in practice. - // We set pool route here to make sure hooks from InitializePool do not break - suite.App.PoolManagerKeeper.SetPoolRoute(suite.Ctx, defaultPoolId, poolmanagertypes.Balancer) - err := gammKeeper.InitializePool(suite.Ctx, test.createPool(), sender) + // initializePool with a poolI + // initializePool shoould be called by pool manager in practice. + // We set pool route here to make sure hooks from InitializePool do not break + suite.App.PoolManagerKeeper.SetPoolRoute(suite.Ctx, defaultPoolId, poolmanagertypes.Balancer) + err := gammKeeper.InitializePool(suite.Ctx, test.createPool(), sender) + + if test.expectPass { + suite.Require().NoError(err, "test: %v", test.name) + + // check to make sure new pool exists and has minted the correct number of pool shares + pool, err := gammKeeper.GetPoolAndPoke(suite.Ctx, defaultPoolId) + suite.Require().NoError(err, "test: %v", test.name) + suite.Require().Equal(types.InitPoolSharesSupply.String(), pool.GetTotalShares().String(), + fmt.Sprintf("share token should be minted as %s initially", types.InitPoolSharesSupply), + ) - if test.expectPass { + // check to make sure user user balance increase correct number of pool shares + suite.Require().Equal( + senderBalBeforeNewPool.Add(sdk.NewCoin(types.GetPoolShareDenom(pool.GetId()), types.InitPoolSharesSupply)), + bankKeeper.GetAllBalances(suite.Ctx, sender), + ) + + // get expected tokens in new pool and corresponding pool shares + expectedPoolTokens := sdk.NewCoins() + for _, asset := range pool.GetTotalPoolLiquidity(suite.Ctx) { + expectedPoolTokens = expectedPoolTokens.Add(asset) + } + expectedPoolShares := sdk.NewCoin(types.GetPoolShareDenom(pool.GetId()), types.InitPoolSharesSupply) + + // make sure expected pool tokens and expected pool shares matches the actual tokens and shares in the pool + suite.Require().Equal(expectedPoolTokens.String(), pool.GetTotalPoolLiquidity(suite.Ctx).String()) + suite.Require().Equal(expectedPoolShares.Amount.String(), pool.GetTotalShares().String()) + + // check pool metadata + poolShareBaseDenom := types.GetPoolShareDenom(pool.GetId()) + poolShareDisplayDenom := fmt.Sprintf("GAMM-%d", pool.GetId()) + metadata, found := bankKeeper.GetDenomMetaData(suite.Ctx, poolShareBaseDenom) + suite.Require().Equal(found, true, fmt.Sprintf("Pool share denom %s is not set", poolShareDisplayDenom)) + suite.Require().Equal(metadata.Base, poolShareBaseDenom, fmt.Sprintf("Pool share base denom %s is not correctly set", poolShareBaseDenom)) + suite.Require().Equal(metadata.Display, poolShareDisplayDenom, fmt.Sprintf("Pool share display denom %s is not correctly set", poolShareDisplayDenom)) + suite.Require().Equal(metadata.DenomUnits[0].Denom, poolShareBaseDenom) + suite.Require().Equal(metadata.DenomUnits[0].Exponent, uint32(0x0)) + suite.Require().Equal(metadata.DenomUnits[0].Aliases, []string{ + "attopoolshare", + }) + suite.Require().Equal(metadata.DenomUnits[1].Denom, poolShareDisplayDenom) + suite.Require().Equal(metadata.DenomUnits[1].Exponent, uint32(types.OneShareExponent)) + suite.Require().Equal(metadata.DenomUnits[1].Aliases, []string(nil)) + + // check AfterPoolCreated hook + for _, lockableDuration := range poolIncentivesKeeper.GetLockableDurations(suite.Ctx) { + gaugeId, err := poolIncentivesKeeper.GetPoolGaugeId(suite.Ctx, defaultPoolId, lockableDuration) suite.Require().NoError(err, "test: %v", test.name) - // check to make sure new pool exists and has minted the correct number of pool shares - pool, err := gammKeeper.GetPoolAndPoke(suite.Ctx, defaultPoolId) + poolIdFromPoolIncentives, err := poolIncentivesKeeper.GetPoolIdFromGaugeId(suite.Ctx, gaugeId, lockableDuration) suite.Require().NoError(err, "test: %v", test.name) - suite.Require().Equal(types.InitPoolSharesSupply.String(), pool.GetTotalShares().String(), - fmt.Sprintf("share token should be minted as %s initially", types.InitPoolSharesSupply), - ) - - // check to make sure user user balance increase correct number of pool shares - suite.Require().Equal( - senderBalBeforeNewPool.Add(sdk.NewCoin(types.GetPoolShareDenom(pool.GetId()), types.InitPoolSharesSupply)), - bankKeeper.GetAllBalances(suite.Ctx, sender), - ) - - // get expected tokens in new pool and corresponding pool shares - expectedPoolTokens := sdk.NewCoins() - for _, asset := range pool.GetTotalPoolLiquidity(suite.Ctx) { - expectedPoolTokens = expectedPoolTokens.Add(asset) - } - expectedPoolShares := sdk.NewCoin(types.GetPoolShareDenom(pool.GetId()), types.InitPoolSharesSupply) - - // make sure expected pool tokens and expected pool shares matches the actual tokens and shares in the pool - suite.Require().Equal(expectedPoolTokens.String(), pool.GetTotalPoolLiquidity(suite.Ctx).String()) - suite.Require().Equal(expectedPoolShares.Amount.String(), pool.GetTotalShares().String()) - - // check pool metadata - poolShareBaseDenom := types.GetPoolShareDenom(pool.GetId()) - poolShareDisplayDenom := fmt.Sprintf("GAMM-%d", pool.GetId()) - metadata, found := bankKeeper.GetDenomMetaData(suite.Ctx, poolShareBaseDenom) - suite.Require().Equal(found, true, fmt.Sprintf("Pool share denom %s is not set", poolShareDisplayDenom)) - suite.Require().Equal(metadata.Base, poolShareBaseDenom, fmt.Sprintf("Pool share base denom %s is not correctly set", poolShareBaseDenom)) - suite.Require().Equal(metadata.Display, poolShareDisplayDenom, fmt.Sprintf("Pool share display denom %s is not correctly set", poolShareDisplayDenom)) - suite.Require().Equal(metadata.DenomUnits[0].Denom, poolShareBaseDenom) - suite.Require().Equal(metadata.DenomUnits[0].Exponent, uint32(0x0)) - suite.Require().Equal(metadata.DenomUnits[0].Aliases, []string{ - "attopoolshare", - }) - suite.Require().Equal(metadata.DenomUnits[1].Denom, poolShareDisplayDenom) - suite.Require().Equal(metadata.DenomUnits[1].Exponent, uint32(types.OneShareExponent)) - suite.Require().Equal(metadata.DenomUnits[1].Aliases, []string(nil)) - - // check AfterPoolCreated hook - for _, lockableDuration := range poolIncentivesKeeper.GetLockableDurations(suite.Ctx) { - gaugeId, err := poolIncentivesKeeper.GetPoolGaugeId(suite.Ctx, defaultPoolId, lockableDuration) - suite.Require().NoError(err, "test: %v", test.name) - - poolIdFromPoolIncentives, err := poolIncentivesKeeper.GetPoolIdFromGaugeId(suite.Ctx, gaugeId, lockableDuration) - suite.Require().NoError(err, "test: %v", test.name) - suite.Require().Equal(poolIdFromPoolIncentives, defaultPoolId) - } - - } else { - suite.Require().Error(err, "test: %v", test.name) + suite.Require().Equal(poolIdFromPoolIncentives, defaultPoolId) } - }) + + } else { + suite.Require().Error(err, "test: %v", test.name) + } }) } } diff --git a/x/gamm/keeper/swap.go b/x/gamm/keeper/swap.go index e37d9f6a623..c9d0d94e9f9 100644 --- a/x/gamm/keeper/swap.go +++ b/x/gamm/keeper/swap.go @@ -97,7 +97,12 @@ func (k Keeper) SwapExactAmountOut( } }() - poolOutBal := pool.GetTotalPoolLiquidity(ctx).AmountOf(tokenOut.Denom) + liquidity, err := k.GetTotalPoolLiquidity(ctx, pool.GetId()) + if err != nil { + return sdk.Int{}, err + } + + poolOutBal := liquidity.AmountOf(tokenOut.Denom) if tokenOut.Amount.GTE(poolOutBal) { return sdk.Int{}, sdkerrors.Wrapf(types.ErrTooManyTokensOut, "can't get more tokens out than there are tokens in the pool") diff --git a/x/gamm/pool-models/balancer/msgs_test.go b/x/gamm/pool-models/balancer/msgs_test.go index 5deefb1dbfe..44b3268e4ca 100644 --- a/x/gamm/pool-models/balancer/msgs_test.go +++ b/x/gamm/pool-models/balancer/msgs_test.go @@ -285,8 +285,12 @@ func (suite *KeeperTestSuite) TestMsgCreateBalancerPool() { for _, asset := range tc.msg.PoolAssets { expectedPoolLiquidity = expectedPoolLiquidity.Add(asset.Token) } - suite.Require().Equal(expectedPoolLiquidity, pool.GetTotalPoolLiquidity(suite.Ctx)) - suite.Require().Equal(types.InitPoolSharesSupply, pool.GetTotalShares()) + + cfmmPool, ok := pool.(types.CFMMPoolI) + suite.Require().True(ok) + + suite.Require().Equal(expectedPoolLiquidity, cfmmPool.GetTotalPoolLiquidity(suite.Ctx)) + suite.Require().Equal(types.InitPoolSharesSupply, cfmmPool.GetTotalShares()) }) } } diff --git a/x/gamm/pool-models/stableswap/msgs_test.go b/x/gamm/pool-models/stableswap/msgs_test.go index 89702f94f03..ea20a88917f 100644 --- a/x/gamm/pool-models/stableswap/msgs_test.go +++ b/x/gamm/pool-models/stableswap/msgs_test.go @@ -348,8 +348,12 @@ func (suite *TestSuite) TestMsgCreateStableswapPool() { suite.Require().NoError(err) suite.Require().Equal(tc.poolId, pool.GetId()) - suite.Require().Equal(tc.msg.InitialPoolLiquidity, pool.GetTotalPoolLiquidity(suite.Ctx)) - suite.Require().Equal(types.InitPoolSharesSupply, pool.GetTotalShares()) + + cfmmPool, ok := pool.(types.CFMMPoolI) + suite.Require().True(ok) + + suite.Require().Equal(tc.msg.InitialPoolLiquidity, cfmmPool.GetTotalPoolLiquidity(suite.Ctx)) + suite.Require().Equal(types.InitPoolSharesSupply, cfmmPool.GetTotalShares()) }) } } diff --git a/x/gamm/types/pool.go b/x/gamm/types/pool.go index 6bac8ca41a2..890dccecc39 100644 --- a/x/gamm/types/pool.go +++ b/x/gamm/types/pool.go @@ -50,6 +50,13 @@ type CFMMPoolI interface { // CalcInAmtGivenOut returns how many coins SwapInAmtGivenOut would return on these arguments. // This does not mutate the pool, or state. CalcInAmtGivenOut(ctx sdk.Context, tokenOut sdk.Coins, tokenInDenom string, swapFee sdk.Dec) (tokenIn sdk.Coin, err error) + // GetTotalShares returns the total number of LP shares in the pool + GetTotalShares() sdk.Int + // GetTotalPoolLiquidity returns the coins in the pool owned by all LPs + GetTotalPoolLiquidity(ctx sdk.Context) sdk.Coins + // GetExitFee returns the pool's exit fee, based on the current state. + // Pools may choose to make their exit fees dependent upon state. + GetExitFee(ctx sdk.Context) sdk.Dec } // PoolAmountOutExtension is an extension of the PoolI diff --git a/x/poolmanager/create_pool.go b/x/poolmanager/create_pool.go index 7e956e397cc..87a812db6a4 100644 --- a/x/poolmanager/create_pool.go +++ b/x/poolmanager/create_pool.go @@ -61,11 +61,6 @@ func (k Keeper) CreatePool(ctx sdk.Context, msg types.CreatePoolMsg) (uint64, er return 0, err } - exitFee := pool.GetExitFee(ctx) - if !exitFee.Equal(sdk.ZeroDec()) { - return 0, fmt.Errorf("can not create pool with non zero exit fee, got %d", exitFee) - } - k.SetPoolRoute(ctx, poolId, msg.GetPoolType()) if err := k.validateCreatedPool(ctx, poolId, pool); err != nil { diff --git a/x/poolmanager/types/pool.go b/x/poolmanager/types/pool.go index dd8e3473294..cf113beb246 100644 --- a/x/poolmanager/types/pool.go +++ b/x/poolmanager/types/pool.go @@ -19,15 +19,9 @@ type PoolI interface { // (prior TWAPs, network downtime, other pool states, etc.) // hence Context is provided as an argument. GetSwapFee(ctx sdk.Context) sdk.Dec - // GetExitFee returns the pool's exit fee, based on the current state. - // Pools may choose to make their exit fees dependent upon state. - GetExitFee(ctx sdk.Context) sdk.Dec // Returns whether the pool has swaps enabled at the moment IsActive(ctx sdk.Context) bool - // GetTotalShares returns the total number of LP shares in the pool - GetTotalShares() sdk.Int - // GetTotalPoolLiquidity returns the coins in the pool owned by all LPs - GetTotalPoolLiquidity(ctx sdk.Context) sdk.Coins + // Returns the spot price of the 'base asset' in terms of the 'quote asset' in the pool, // errors if either baseAssetDenom, or quoteAssetDenom does not exist. // For example, if this was a UniV2 50-50 pool, with 2 ETH, and 8000 UST diff --git a/x/poolmanager/types/routes.go b/x/poolmanager/types/routes.go index 663f6385c7f..1664552cc9c 100644 --- a/x/poolmanager/types/routes.go +++ b/x/poolmanager/types/routes.go @@ -81,6 +81,9 @@ type PoolModuleI interface { tokenInDenom string, swapFee sdk.Dec, ) (tokenIn sdk.Coin, err error) + + // GetTotalPoolLiquidity returns the coins in the pool owned by all LPs + GetTotalPoolLiquidity(ctx sdk.Context, poolId uint64) (sdk.Coins, error) } type PoolIncentivesKeeperI interface {