From bf8d9916fc987954ab14c9b0bfb85a1a9089fa2b Mon Sep 17 00:00:00 2001 From: codchen Date: Mon, 22 May 2023 14:24:08 +0800 Subject: [PATCH] Not charge gas for contract dependency loading --- x/dex/ante.go | 11 +-- x/dex/cache/cache.go | 18 ++-- x/dex/cache/cache_test.go | 8 +- x/dex/keeper/contract.go | 4 + x/dex/keeper/contract_test.go | 33 +++++++ .../msgserver/msg_server_cancel_orders.go | 8 +- .../msgserver/msg_server_place_orders.go | 8 +- x/dex/module_test.go | 96 +++---------------- 8 files changed, 63 insertions(+), 123 deletions(-) diff --git a/x/dex/ante.go b/x/dex/ante.go index 844757e3eb..3742370946 100644 --- a/x/dex/ante.go +++ b/x/dex/ante.go @@ -126,20 +126,13 @@ func (d CheckDexGasDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, simulate bo } else { memState = utils.GetMemState(ctx.Context()) } - contractLoader := func(addr string) *types.ContractInfoV2 { - contract, err := d.dexKeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &contract - } for _, msg := range tx.GetMsgs() { switch m := msg.(type) { case *types.MsgPlaceOrders: - numDependencies := len(memState.GetContractToDependencies(m.ContractAddr, contractLoader)) + numDependencies := len(memState.GetContractToDependencies(ctx, m.ContractAddr, d.dexKeeper.GetContractWithoutGasCharge)) dexGasRequired += params.DefaultGasPerOrder * uint64(len(m.Orders)*numDependencies) case *types.MsgCancelOrders: - numDependencies := len(memState.GetContractToDependencies(m.ContractAddr, contractLoader)) + numDependencies := len(memState.GetContractToDependencies(ctx, m.ContractAddr, d.dexKeeper.GetContractWithoutGasCharge)) dexGasRequired += params.DefaultGasPerCancel * uint64(len(m.Cancellations)*numDependencies) } } diff --git a/x/dex/cache/cache.go b/x/dex/cache/cache.go index c1d70240ec..a7b8470c80 100644 --- a/x/dex/cache/cache.go +++ b/x/dex/cache/cache.go @@ -88,13 +88,13 @@ func (s *MemState) GetDepositInfo(ctx sdk.Context, contractAddr types.ContractAd ) } -func (s *MemState) GetContractToDependencies(contractAddress string, loader func(addr string) *types.ContractInfoV2) []string { +func (s *MemState) GetContractToDependencies(ctx sdk.Context, contractAddress string, loader func(sdk.Context, string) (types.ContractInfoV2, error)) []string { s.contractsToDepsMtx.Lock() defer s.contractsToDepsMtx.Unlock() if deps, ok := s.contractsToDependencies.Load(contractAddress); ok { return deps } - loadedDownstreams := GetAllDownstreamContracts(contractAddress, loader) + loadedDownstreams := GetAllDownstreamContracts(ctx, contractAddress, loader) s.contractsToDependencies.Store(contractAddress, loadedDownstreams) return loadedDownstreams } @@ -106,8 +106,8 @@ func (s *MemState) ClearContractToDependencies() { s.contractsToDependencies = datastructures.NewTypedSyncMap[string, []string]() } -func (s *MemState) SetDownstreamsToProcess(contractAddress string, loader func(addr string) *types.ContractInfoV2) { - s.contractsToProcess.AddAll(s.GetContractToDependencies(contractAddress, loader)) +func (s *MemState) SetDownstreamsToProcess(ctx sdk.Context, contractAddress string, loader func(sdk.Context, string) (types.ContractInfoV2, error)) { + s.contractsToProcess.AddAll(s.GetContractToDependencies(ctx, contractAddress, loader)) } func (s *MemState) GetContractToProcess() *datastructures.SyncSet[string] { @@ -245,15 +245,15 @@ func DeepDelete(kvStore sdk.KVStore, storePrefix []byte, matcher func([]byte) bo // BFS traversal over a acyclic graph // Includes the root contract itself. -func GetAllDownstreamContracts(contractAddress string, loader func(addr string) *types.ContractInfoV2) []string { +func GetAllDownstreamContracts(ctx sdk.Context, contractAddress string, loader func(sdk.Context, string) (types.ContractInfoV2, error)) []string { res := []string{contractAddress} seen := datastructures.NewSyncSet(res) downstreams := []*types.ContractInfoV2{} populater := func(target *types.ContractInfoV2) { for _, dep := range target.Dependencies { - if downstream := loader(dep.Dependency); downstream != nil && !seen.Contains(downstream.ContractAddr) { + if downstream, err := loader(ctx, dep.Dependency); err == nil && !seen.Contains(downstream.ContractAddr) { if !downstream.Suspended { - downstreams = append(downstreams, downstream) + downstreams = append(downstreams, &downstream) seen.Add(downstream.ContractAddr) } } else { @@ -264,8 +264,8 @@ func GetAllDownstreamContracts(contractAddress string, loader func(addr string) } } // init first layer downstreams - if contract := loader(contractAddress); contract != nil { - populater(contract) + if contract, err := loader(ctx, contractAddress); err == nil { + populater(&contract) } else { return res } diff --git a/x/dex/cache/cache_test.go b/x/dex/cache/cache_test.go index 953494cfa7..d975c68e5d 100644 --- a/x/dex/cache/cache_test.go +++ b/x/dex/cache/cache_test.go @@ -222,11 +222,5 @@ func TestGetAllDownstreamContracts(t *testing.T) { "sei1ery8l6jquynn9a4cz2pff6khg8c68f7urt33l5n9dng2cwzz4c4q4hncrd", "sei1wl59k23zngj34l7d42y9yltask7rjlnxgccawc7ltrknp6n52fpsj6ctln", "sei1stwdtk6ja0705v8qmtukcp4vd422p5vy4jr5wdc4qk44c57k955qcannhd", - }, dex.GetAllDownstreamContracts("sei1ery8l6jquynn9a4cz2pff6khg8c68f7urt33l5n9dng2cwzz4c4q4hncrd", func(addr string) *types.ContractInfoV2 { - c, err := keeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - })) + }, dex.GetAllDownstreamContracts(ctx, "sei1ery8l6jquynn9a4cz2pff6khg8c68f7urt33l5n9dng2cwzz4c4q4hncrd", keeper.GetContractWithoutGasCharge)) } diff --git a/x/dex/keeper/contract.go b/x/dex/keeper/contract.go index 729e0642ab..170fd976c8 100644 --- a/x/dex/keeper/contract.go +++ b/x/dex/keeper/contract.go @@ -50,6 +50,10 @@ func (k Keeper) GetContract(ctx sdk.Context, contractAddr string) (types.Contrac return res, nil } +func (k Keeper) GetContractWithoutGasCharge(ctx sdk.Context, contractAddr string) (types.ContractInfoV2, error) { + return k.GetContract(ctx.WithGasMeter(sdk.NewInfiniteGasMeter()), contractAddr) +} + func (k Keeper) GetContractGasLimit(ctx sdk.Context, contractAddr sdk.AccAddress) (uint64, error) { bech32ContractAddr := contractAddr.String() contract, err := k.GetContract(ctx, bech32ContractAddr) diff --git a/x/dex/keeper/contract_test.go b/x/dex/keeper/contract_test.go index 1efb6a1272..62ea8edee9 100644 --- a/x/dex/keeper/contract_test.go +++ b/x/dex/keeper/contract_test.go @@ -283,3 +283,36 @@ func TestClearDependenciesForContract(t *testing.T) { require.Nil(t, err) require.Equal(t, int64(0), downB.NumIncomingDependencies) } + +func TestGetContractWithoutGasCharge(t *testing.T) { + keeper, ctx := keepertest.DexKeeper(t) + _ = keeper.SetContract(ctx, &types.ContractInfoV2{ + Creator: keepertest.TestAccount, + ContractAddr: keepertest.TestContract, + CodeId: 1, + RentBalance: 1000000, + }) + // regular gas meter case + ctx = ctx.WithGasMeter(sdk.NewGasMeter(10000)) + contract, err := keeper.GetContractWithoutGasCharge(ctx, keepertest.TestContract) + require.Nil(t, err) + require.Equal(t, keepertest.TestContract, contract.ContractAddr) + require.Equal(t, uint64(0), ctx.GasMeter().GasConsumed()) + require.Equal(t, uint64(10000), ctx.GasMeter().Limit()) + + // regular gas meter out of gas case + ctx = ctx.WithGasMeter(sdk.NewGasMeter(1)) + contract, err = keeper.GetContractWithoutGasCharge(ctx, keepertest.TestContract) + require.Nil(t, err) + require.Equal(t, keepertest.TestContract, contract.ContractAddr) + require.Equal(t, uint64(0), ctx.GasMeter().GasConsumed()) + require.Equal(t, uint64(1), ctx.GasMeter().Limit()) + + // infinite gas meter case + ctx = ctx.WithGasMeter(sdk.NewInfiniteGasMeter()) + contract, err = keeper.GetContractWithoutGasCharge(ctx, keepertest.TestContract) + require.Nil(t, err) + require.Equal(t, keepertest.TestContract, contract.ContractAddr) + require.Equal(t, uint64(0), ctx.GasMeter().GasConsumed()) + require.Equal(t, uint64(0), ctx.GasMeter().Limit()) +} diff --git a/x/dex/keeper/msgserver/msg_server_cancel_orders.go b/x/dex/keeper/msgserver/msg_server_cancel_orders.go index 8620475dbf..817d8aca37 100644 --- a/x/dex/keeper/msgserver/msg_server_cancel_orders.go +++ b/x/dex/keeper/msgserver/msg_server_cancel_orders.go @@ -56,12 +56,6 @@ func (k msgServer) CancelOrders(goCtx context.Context, msg *types.MsgCancelOrder } } ctx.EventManager().EmitEvents(events) - utils.GetMemState(ctx.Context()).SetDownstreamsToProcess(msg.ContractAddr, func(addr string) *types.ContractInfoV2 { - contract, err := k.GetContract(ctx, addr) - if err != nil { - return nil - } - return &contract - }) + utils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, msg.ContractAddr, k.GetContractWithoutGasCharge) return &types.MsgCancelOrdersResponse{}, nil } diff --git a/x/dex/keeper/msgserver/msg_server_place_orders.go b/x/dex/keeper/msgserver/msg_server_place_orders.go index 971c4e33c4..e13fbef923 100644 --- a/x/dex/keeper/msgserver/msg_server_place_orders.go +++ b/x/dex/keeper/msgserver/msg_server_place_orders.go @@ -87,13 +87,7 @@ func (k msgServer) PlaceOrders(goCtx context.Context, msg *types.MsgPlaceOrders) k.SetNextOrderID(ctx, msg.ContractAddr, nextID) ctx.EventManager().EmitEvents(events) - utils.GetMemState(ctx.Context()).SetDownstreamsToProcess(msg.ContractAddr, func(addr string) *types.ContractInfoV2 { - contract, err := k.GetContract(ctx, addr) - if err != nil { - return nil - } - return &contract - }) + utils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, msg.ContractAddr, k.GetContractWithoutGasCharge) return &types.MsgPlaceOrdersResponse{ OrderIds: idsInResp, }, nil diff --git a/x/dex/module_test.go b/x/dex/module_test.go index aade1d60a2..7426c300a9 100644 --- a/x/dex/module_test.go +++ b/x/dex/module_test.go @@ -118,13 +118,7 @@ func TestEndBlockMarketOrder(t *testing.T) { Amount: sdk.MustNewDecFromStr("2000000"), }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(1) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) @@ -147,13 +141,7 @@ func TestEndBlockMarketOrder(t *testing.T) { Data: "{\"position_effect\":\"Open\",\"leverage\":\"1\"}", }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(2) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) @@ -185,13 +173,7 @@ func TestEndBlockMarketOrder(t *testing.T) { Data: "{\"position_effect\":\"Open\",\"leverage\":\"1\"}", }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(3) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) @@ -283,13 +265,7 @@ func TestEndBlockLimitOrder(t *testing.T) { Amount: sdk.MustNewDecFromStr("2000000"), }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(1) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) @@ -329,13 +305,7 @@ func TestEndBlockLimitOrder(t *testing.T) { Data: "{\"position_effect\":\"Open\",\"leverage\":\"1\"}", }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(2) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) @@ -368,13 +338,7 @@ func TestEndBlockLimitOrder(t *testing.T) { Data: "{\"position_effect\":\"Open\",\"leverage\":\"1\"}", }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(3) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) @@ -414,13 +378,7 @@ func TestEndBlockRollback(t *testing.T) { PositionDirection: types.PositionDirection_LONG, }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(keepertest.TestContract, func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, keepertest.TestContract, dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(1) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) // No state change should've been persisted @@ -456,13 +414,7 @@ func TestEndBlockPartialRollback(t *testing.T) { PositionDirection: types.PositionDirection_LONG, }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(keepertest.TestContract, func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, keepertest.TestContract, dexkeeper.GetContractWithoutGasCharge) // GOOD CONTRACT testAccount, _ := sdk.AccAddressFromBech32("sei1yezq49upxhunjjhudql2fnj5dgvcwjj87pn2wx") amounts := sdk.NewCoins(sdk.NewCoin("usei", sdk.NewInt(1000000)), sdk.NewCoin("uusdc", sdk.NewInt(1000000))) @@ -511,13 +463,7 @@ func TestEndBlockPartialRollback(t *testing.T) { Amount: sdk.MustNewDecFromStr("10000"), }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(1) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) @@ -683,13 +629,7 @@ func TestEndBlockRollbackWithRentCharge(t *testing.T) { Amount: sdk.MustNewDecFromStr("10000"), }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) // overwrite params for testing params := dexkeeper.GetParams(ctx) params.MinProcessableRent = 0 @@ -846,13 +786,7 @@ func TestOrderCountUpdate(t *testing.T) { Amount: sdk.MustNewDecFromStr("2000000"), }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(1) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) @@ -889,13 +823,7 @@ func TestOrderCountUpdate(t *testing.T) { PositionDirection: types.PositionDirection_LONG, }, ) - dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(contractAddr.String(), func(addr string) *types.ContractInfoV2 { - c, err := dexkeeper.GetContract(ctx, addr) - if err != nil { - return nil - } - return &c - }) + dexutils.GetMemState(ctx.Context()).SetDownstreamsToProcess(ctx, contractAddr.String(), dexkeeper.GetContractWithoutGasCharge) ctx = ctx.WithBlockHeight(2) testApp.EndBlocker(ctx, abci.RequestEndBlock{}) require.Equal(t, uint64(2), dexkeeper.GetOrderCountState(ctx, contractAddr.String(), pair.PriceDenom, pair.AssetDenom, types.PositionDirection_LONG, sdk.NewDec(1)))