diff --git a/CHANGELOG.md b/CHANGELOG.md index c615c52f17df..9fbd99c24c31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -148,7 +148,7 @@ Ref: https://keepachangelog.com/en/1.0.0/ * [\#11274](https://github.com/cosmos/cosmos-sdk/pull/11274) `types/errors.New` now is an alias for `types/errors.Register` and should only be used in initialization code. * (authz)[\#11060](https://github.com/cosmos/cosmos-sdk/pull/11060) `authz.NewMsgGrant` `expiration` is now a pointer. When `nil` is used then no expiration will be set (grant won't expire). * (x/distribution)[\#11457](https://github.com/cosmos/cosmos-sdk/pull/11457) Add amount field to `distr.MsgWithdrawDelegatorRewardResponse` and `distr.MsgWithdrawValidatorCommissionResponse`. - +* (x/auth/middleware) [#11413](https://github.com/cosmos/cosmos-sdk/pull/11413) Refactor tx middleware to be extensible on tx fee logic. Merged `MempoolFeeMiddleware` and `TxPriorityMiddleware` functionalities into `DeductFeeMiddleware`, make the logic extensible using the `TxFeeChecker` option, the current fee logic is preserved by the default `checkTxFeeWithValidatorMinGasPrices` implementation. Change `RejectExtensionOptionsMiddleware` to `NewExtensionOptionsMiddleware` which is extensible with the `ExtensionOptionChecker` option. Unpack the tx extension options `Any`s to interface `TxExtensionOptionI`. ### Client Breaking Changes diff --git a/testutil/testdata/codec.go b/testutil/testdata/codec.go index d5c6e8abd2f2..5124f506d0d5 100644 --- a/testutil/testdata/codec.go +++ b/testutil/testdata/codec.go @@ -6,6 +6,7 @@ import ( "github.com/cosmos/cosmos-sdk/codec/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/msgservice" + tx "github.com/cosmos/cosmos-sdk/types/tx" ) func NewTestInterfaceRegistry() types.InterfaceRegistry { @@ -31,6 +32,10 @@ func RegisterInterfaces(registry types.InterfaceRegistry) { (*HasHasAnimalI)(nil), &HasHasAnimal{}, ) + registry.RegisterImplementations( + (*tx.TxExtensionOptionI)(nil), + &Cat{}, + ) msgservice.RegisterMsgServiceDesc(registry, &_Msg_serviceDesc) } diff --git a/types/tx/ext.go b/types/tx/ext.go new file mode 100644 index 000000000000..fb2e1ed448bc --- /dev/null +++ b/types/tx/ext.go @@ -0,0 +1,21 @@ +package tx + +import ( + "github.com/cosmos/cosmos-sdk/codec/types" +) + +// TxExtensionOptionI defines the interface for tx extension options +type TxExtensionOptionI interface{} + +// unpackTxExtensionOptionsI unpacks Any's to TxExtensionOptionI's. +func unpackTxExtensionOptionsI(unpacker types.AnyUnpacker, anys []*types.Any) error { + for _, any := range anys { + var opt TxExtensionOptionI + err := unpacker.UnpackAny(any, &opt) + if err != nil { + return err + } + } + + return nil +} diff --git a/types/tx/types.go b/types/tx/types.go index 006941b35f17..cbdb50a43ca4 100644 --- a/types/tx/types.go +++ b/types/tx/types.go @@ -173,7 +173,19 @@ func (t *Tx) UnpackInterfaces(unpacker codectypes.AnyUnpacker) error { // UnpackInterfaces implements the UnpackInterfaceMessages.UnpackInterfaces method func (m *TxBody) UnpackInterfaces(unpacker codectypes.AnyUnpacker) error { - return UnpackInterfaces(unpacker, m.Messages) + if err := UnpackInterfaces(unpacker, m.Messages); err != nil { + return err + } + + if err := unpackTxExtensionOptionsI(unpacker, m.ExtensionOptions); err != nil { + return err + } + + if err := unpackTxExtensionOptionsI(unpacker, m.NonCriticalExtensionOptions); err != nil { + return err + } + + return nil } // UnpackInterfaces implements the UnpackInterfaceMessages.UnpackInterfaces method @@ -200,4 +212,6 @@ func RegisterInterfaces(registry codectypes.InterfaceRegistry) { registry.RegisterInterface("cosmos.tx.v1beta1.Tx", (*sdk.Tx)(nil)) registry.RegisterImplementations((*sdk.Tx)(nil), &Tx{}) + + registry.RegisterInterface("cosmos.tx.v1beta1.TxExtensionOptionI", (*TxExtensionOptionI)(nil)) } diff --git a/x/auth/middleware/branch_store_test.go b/x/auth/middleware/branch_store_test.go index ea675492c1b5..80e116fb2a82 100644 --- a/x/auth/middleware/branch_store_test.go +++ b/x/auth/middleware/branch_store_test.go @@ -72,7 +72,7 @@ func (s *MWTestSuite) TestBranchStore() { middleware.NewTxDecoderMiddleware(s.clientCtx.TxConfig.TxDecoder()), middleware.GasTxMiddleware, middleware.RecoveryTxMiddleware, - middleware.DeductFeeMiddleware(s.app.AccountKeeper, s.app.BankKeeper, s.app.FeeGrantKeeper), + middleware.DeductFeeMiddleware(s.app.AccountKeeper, s.app.BankKeeper, s.app.FeeGrantKeeper, nil), middleware.IncrementSequenceMiddleware(s.app.AccountKeeper), middleware.WithBranchedStore, middleware.ConsumeBlockGasMiddleware, diff --git a/x/auth/middleware/ext.go b/x/auth/middleware/ext.go index 783c3bbc476b..5159e3c5f1f8 100644 --- a/x/auth/middleware/ext.go +++ b/x/auth/middleware/ext.go @@ -14,27 +14,44 @@ type HasExtensionOptionsTx interface { GetNonCriticalExtensionOptions() []*codectypes.Any } +// ExtensionOptionChecker is a function that returns true if the extension option is accepted. +type ExtensionOptionChecker func(*codectypes.Any) bool + +// rejectExtensionOption is the default extension check that reject all tx +// extensions. +func rejectExtensionOption(*codectypes.Any) bool { + return false +} + type rejectExtensionOptionsTxHandler struct { - next tx.Handler + next tx.Handler + checker ExtensionOptionChecker } -// RejectExtensionOptionsMiddleware creates a new rejectExtensionOptionsMiddleware. -// rejectExtensionOptionsMiddleware is a middleware that rejects all extension -// options which can optionally be included in protobuf transactions. Users that -// need extension options should create a custom middleware chain that handles -// needed extension options properly and rejects unknown ones. -func RejectExtensionOptionsMiddleware(txh tx.Handler) tx.Handler { - return rejectExtensionOptionsTxHandler{ - next: txh, +// NewExtensionOptionsMiddleware creates a new middleware that rejects all extension +// options which can optionally be included in protobuf transactions that don't pass the checker. +// Users that need extension options should pass a custom checker that returns true for the +// needed extension options. +func NewExtensionOptionsMiddleware(checker ExtensionOptionChecker) tx.Middleware { + if checker == nil { + checker = rejectExtensionOption + } + return func(txh tx.Handler) tx.Handler { + return rejectExtensionOptionsTxHandler{ + next: txh, + checker: checker, + } } } var _ tx.Handler = rejectExtensionOptionsTxHandler{} -func checkExtOpts(tx sdk.Tx) error { +func checkExtOpts(tx sdk.Tx, checker ExtensionOptionChecker) error { if hasExtOptsTx, ok := tx.(HasExtensionOptionsTx); ok { - if len(hasExtOptsTx.GetExtensionOptions()) != 0 { - return sdkerrors.ErrUnknownExtensionOptions + for _, opt := range hasExtOptsTx.GetExtensionOptions() { + if !checker(opt) { + return sdkerrors.ErrUnknownExtensionOptions + } } } @@ -43,7 +60,7 @@ func checkExtOpts(tx sdk.Tx) error { // CheckTx implements tx.Handler.CheckTx. func (txh rejectExtensionOptionsTxHandler) CheckTx(ctx context.Context, req tx.Request, checkReq tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error) { - if err := checkExtOpts(req.Tx); err != nil { + if err := checkExtOpts(req.Tx, txh.checker); err != nil { return tx.Response{}, tx.ResponseCheckTx{}, err } @@ -52,7 +69,7 @@ func (txh rejectExtensionOptionsTxHandler) CheckTx(ctx context.Context, req tx.R // DeliverTx implements tx.Handler.DeliverTx. func (txh rejectExtensionOptionsTxHandler) DeliverTx(ctx context.Context, req tx.Request) (tx.Response, error) { - if err := checkExtOpts(req.Tx); err != nil { + if err := checkExtOpts(req.Tx, txh.checker); err != nil { return tx.Response{}, err } @@ -61,7 +78,7 @@ func (txh rejectExtensionOptionsTxHandler) DeliverTx(ctx context.Context, req tx // SimulateTx implements tx.Handler.SimulateTx method. func (txh rejectExtensionOptionsTxHandler) SimulateTx(ctx context.Context, req tx.Request) (tx.Response, error) { - if err := checkExtOpts(req.Tx); err != nil { + if err := checkExtOpts(req.Tx, txh.checker); err != nil { return tx.Response{}, err } diff --git a/x/auth/middleware/ext_test.go b/x/auth/middleware/ext_test.go index ec57f722c893..27c294794cfb 100644 --- a/x/auth/middleware/ext_test.go +++ b/x/auth/middleware/ext_test.go @@ -2,6 +2,7 @@ package middleware_test import ( "github.com/cosmos/cosmos-sdk/codec/types" + codectypes "github.com/cosmos/cosmos-sdk/codec/types" "github.com/cosmos/cosmos-sdk/testutil/testdata" sdk "github.com/cosmos/cosmos-sdk/types" typestx "github.com/cosmos/cosmos-sdk/types/tx" @@ -9,28 +10,45 @@ import ( "github.com/cosmos/cosmos-sdk/x/auth/tx" ) -func (s *MWTestSuite) TestRejectExtensionOptionsMiddleware() { - ctx := s.SetupTest(true) // setup - txBuilder := s.clientCtx.TxConfig.NewTxBuilder() +func (s *MWTestSuite) TestExtensionOptionsMiddleware() { + testCases := []struct { + msg string + allow bool + }{ + {"allow extension", true}, + {"reject extension", false}, + } + for _, tc := range testCases { + s.Run(tc.msg, func() { + ctx := s.SetupTest(true) // setup + txBuilder := s.clientCtx.TxConfig.NewTxBuilder() - txHandler := middleware.ComposeMiddlewares(noopTxHandler, middleware.RejectExtensionOptionsMiddleware) + txHandler := middleware.ComposeMiddlewares(noopTxHandler, middleware.NewExtensionOptionsMiddleware(func(_ *codectypes.Any) bool { + return tc.allow + })) - // no extension options should not trigger an error - theTx := txBuilder.GetTx() - _, _, err := txHandler.CheckTx(sdk.WrapSDKContext(ctx), typestx.Request{Tx: theTx}, typestx.RequestCheckTx{}) - s.Require().NoError(err) + // no extension options should not trigger an error + theTx := txBuilder.GetTx() + _, _, err := txHandler.CheckTx(sdk.WrapSDKContext(ctx), typestx.Request{Tx: theTx}, typestx.RequestCheckTx{}) + s.Require().NoError(err) - extOptsTxBldr, ok := txBuilder.(tx.ExtensionOptionsTxBuilder) - if !ok { - // if we can't set extension options, this middleware doesn't apply and we're done - return - } + extOptsTxBldr, ok := txBuilder.(tx.ExtensionOptionsTxBuilder) + if !ok { + // if we can't set extension options, this middleware doesn't apply and we're done + return + } - // setting any extension option should cause an error - any, err := types.NewAnyWithValue(testdata.NewTestMsg()) - s.Require().NoError(err) - extOptsTxBldr.SetExtensionOptions(any) - theTx = txBuilder.GetTx() - _, _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), typestx.Request{Tx: theTx}, typestx.RequestCheckTx{}) - s.Require().EqualError(err, "unknown extension options") + // set an extension option and check + any, err := types.NewAnyWithValue(testdata.NewTestMsg()) + s.Require().NoError(err) + extOptsTxBldr.SetExtensionOptions(any) + theTx = txBuilder.GetTx() + _, _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), typestx.Request{Tx: theTx}, typestx.RequestCheckTx{}) + if tc.allow { + s.Require().NoError(err) + } else { + s.Require().EqualError(err, "unknown extension options") + } + }) + } } diff --git a/x/auth/middleware/fee.go b/x/auth/middleware/fee.go index 718afba479d1..2ae83c37269e 100644 --- a/x/auth/middleware/fee.go +++ b/x/auth/middleware/fee.go @@ -10,73 +10,9 @@ import ( "github.com/cosmos/cosmos-sdk/x/auth/types" ) -var _ tx.Handler = mempoolFeeTxHandler{} - -type mempoolFeeTxHandler struct { - next tx.Handler -} - -// MempoolFeeMiddleware will check if the transaction's fee is at least as large -// as the local validator's minimum gasFee (defined in validator config). -// If fee is too low, middleware returns error and tx is rejected from mempool. -// Note this only applies when ctx.CheckTx = true -// If fee is high enough or not CheckTx, then call next middleware -// CONTRACT: Tx must implement FeeTx to use MempoolFeeMiddleware -func MempoolFeeMiddleware(txh tx.Handler) tx.Handler { - return mempoolFeeTxHandler{ - next: txh, - } -} - -// CheckTx implements tx.Handler.CheckTx. It is responsible for determining if a -// transaction's fees meet the required minimum of the processing node. Note, a -// node can have zero fees set as the minimum. If non-zero minimum fees are set -// and the transaction does not meet the minimum, the transaction is rejected. -// -// Recall, a transaction's fee is determined by ceil(minGasPrice * gasLimit). -func (txh mempoolFeeTxHandler) CheckTx(ctx context.Context, req tx.Request, checkReq tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error) { - sdkCtx := sdk.UnwrapSDKContext(ctx) - - feeTx, ok := req.Tx.(sdk.FeeTx) - if !ok { - return tx.Response{}, tx.ResponseCheckTx{}, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx") - } - - feeCoins := feeTx.GetFee() - gas := feeTx.GetGas() - - // Ensure that the provided fees meet a minimum threshold for the validator, - // if this is a CheckTx. This is only for local mempool purposes, and thus - // is only ran on check tx. - minGasPrices := sdkCtx.MinGasPrices() - if !minGasPrices.IsZero() { - requiredFees := make(sdk.Coins, len(minGasPrices)) - - // Determine the required fees by multiplying each required minimum gas - // price by the gas limit, where fee = ceil(minGasPrice * gasLimit). - glDec := sdk.NewDec(int64(gas)) - for i, gp := range minGasPrices { - fee := gp.Amount.Mul(glDec) - requiredFees[i] = sdk.NewCoin(gp.Denom, fee.Ceil().RoundInt()) - } - - if !feeCoins.IsAnyGTE(requiredFees) { - return tx.Response{}, tx.ResponseCheckTx{}, sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "insufficient fees; got: %s required: %s", feeCoins, requiredFees) - } - } - - return txh.next.CheckTx(ctx, req, checkReq) -} - -// DeliverTx implements tx.Handler.DeliverTx. -func (txh mempoolFeeTxHandler) DeliverTx(ctx context.Context, req tx.Request) (tx.Response, error) { - return txh.next.DeliverTx(ctx, req) -} - -// SimulateTx implements tx.Handler.SimulateTx. -func (txh mempoolFeeTxHandler) SimulateTx(ctx context.Context, req tx.Request) (tx.Response, error) { - return txh.next.SimulateTx(ctx, req) -} +// TxFeeChecker check if the provided fee is enough and returns the effective fee and tx priority, +// the effective fee should be deducted later, and the priority should be returned in abci response. +type TxFeeChecker func(ctx sdk.Context, tx sdk.Tx) (sdk.Coins, int64, error) var _ tx.Handler = deductFeeTxHandler{} @@ -84,6 +20,7 @@ type deductFeeTxHandler struct { accountKeeper AccountKeeper bankKeeper types.BankKeeper feegrantKeeper FeegrantKeeper + txFeeChecker TxFeeChecker next tx.Handler } @@ -91,19 +28,22 @@ type deductFeeTxHandler struct { // If the first signer does not have the funds to pay for the fees, return with InsufficientFunds error // Call next middleware if fees successfully deducted // CONTRACT: Tx must implement FeeTx interface to use deductFeeTxHandler -func DeductFeeMiddleware(ak AccountKeeper, bk types.BankKeeper, fk FeegrantKeeper) tx.Middleware { +func DeductFeeMiddleware(ak AccountKeeper, bk types.BankKeeper, fk FeegrantKeeper, tfc TxFeeChecker) tx.Middleware { + if tfc == nil { + tfc = checkTxFeeWithValidatorMinGasPrices + } return func(txh tx.Handler) tx.Handler { return deductFeeTxHandler{ accountKeeper: ak, bankKeeper: bk, feegrantKeeper: fk, + txFeeChecker: tfc, next: txh, } } } -func (dfd deductFeeTxHandler) checkDeductFee(ctx context.Context, sdkTx sdk.Tx) error { - sdkCtx := sdk.UnwrapSDKContext(ctx) +func (dfd deductFeeTxHandler) checkDeductFee(ctx sdk.Context, sdkTx sdk.Tx, fee sdk.Coins) error { feeTx, ok := sdkTx.(sdk.FeeTx) if !ok { return sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx") @@ -113,10 +53,8 @@ func (dfd deductFeeTxHandler) checkDeductFee(ctx context.Context, sdkTx sdk.Tx) return fmt.Errorf("Fee collector module account (%s) has not been set", types.FeeCollectorName) } - fee := feeTx.GetFee() feePayer := feeTx.FeePayer() feeGranter := feeTx.FeeGranter() - deductFeesFrom := feePayer // if feegranter set deduct fee from feegranter account. @@ -125,7 +63,7 @@ func (dfd deductFeeTxHandler) checkDeductFee(ctx context.Context, sdkTx sdk.Tx) if dfd.feegrantKeeper == nil { return sdkerrors.ErrInvalidRequest.Wrap("fee grants are not enabled") } else if !feeGranter.Equals(feePayer) { - err := dfd.feegrantKeeper.UseGrantedFees(sdkCtx, feeGranter, feePayer, fee, sdkTx.GetMsgs()) + err := dfd.feegrantKeeper.UseGrantedFees(ctx, feeGranter, feePayer, fee, sdkTx.GetMsgs()) if err != nil { return sdkerrors.Wrapf(err, "%s does not not allow to pay fees for %s", feeGranter, feePayer) } @@ -134,39 +72,52 @@ func (dfd deductFeeTxHandler) checkDeductFee(ctx context.Context, sdkTx sdk.Tx) deductFeesFrom = feeGranter } - deductFeesFromAcc := dfd.accountKeeper.GetAccount(sdkCtx, deductFeesFrom) + deductFeesFromAcc := dfd.accountKeeper.GetAccount(ctx, deductFeesFrom) if deductFeesFromAcc == nil { return sdkerrors.ErrUnknownAddress.Wrapf("fee payer address: %s does not exist", deductFeesFrom) } // deduct the fees - if !feeTx.GetFee().IsZero() { - err := DeductFees(dfd.bankKeeper, sdkCtx, deductFeesFromAcc, feeTx.GetFee()) + if !fee.IsZero() { + err := DeductFees(dfd.bankKeeper, ctx, deductFeesFromAcc, fee) if err != nil { return err } } events := sdk.Events{sdk.NewEvent(sdk.EventTypeTx, - sdk.NewAttribute(sdk.AttributeKeyFee, feeTx.GetFee().String()), + sdk.NewAttribute(sdk.AttributeKeyFee, fee.String()), )} - sdkCtx.EventManager().EmitEvents(events) + ctx.EventManager().EmitEvents(events) return nil } // CheckTx implements tx.Handler.CheckTx. func (dfd deductFeeTxHandler) CheckTx(ctx context.Context, req tx.Request, checkReq tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error) { - if err := dfd.checkDeductFee(ctx, req.Tx); err != nil { + sdkCtx := sdk.UnwrapSDKContext(ctx) + fee, priority, err := dfd.txFeeChecker(sdkCtx, req.Tx) + if err != nil { + return tx.Response{}, tx.ResponseCheckTx{}, err + } + if err := dfd.checkDeductFee(sdkCtx, req.Tx, fee); err != nil { return tx.Response{}, tx.ResponseCheckTx{}, err } - return dfd.next.CheckTx(ctx, req, checkReq) + res, checkRes, err := dfd.next.CheckTx(ctx, req, checkReq) + checkRes.Priority = priority + + return res, checkRes, err } // DeliverTx implements tx.Handler.DeliverTx. func (dfd deductFeeTxHandler) DeliverTx(ctx context.Context, req tx.Request) (tx.Response, error) { - if err := dfd.checkDeductFee(ctx, req.Tx); err != nil { + sdkCtx := sdk.UnwrapSDKContext(ctx) + fee, _, err := dfd.txFeeChecker(sdkCtx, req.Tx) + if err != nil { + return tx.Response{}, err + } + if err := dfd.checkDeductFee(sdkCtx, req.Tx, fee); err != nil { return tx.Response{}, err } @@ -174,7 +125,12 @@ func (dfd deductFeeTxHandler) DeliverTx(ctx context.Context, req tx.Request) (tx } func (dfd deductFeeTxHandler) SimulateTx(ctx context.Context, req tx.Request) (tx.Response, error) { - if err := dfd.checkDeductFee(ctx, req.Tx); err != nil { + sdkCtx := sdk.UnwrapSDKContext(ctx) + fee, _, err := dfd.txFeeChecker(sdkCtx, req.Tx) + if err != nil { + return tx.Response{}, err + } + if err := dfd.checkDeductFee(sdkCtx, req.Tx, fee); err != nil { return tx.Response{}, err } diff --git a/x/auth/middleware/fee_test.go b/x/auth/middleware/fee_test.go index e4673f3f8dd5..7c1f54463fff 100644 --- a/x/auth/middleware/fee_test.go +++ b/x/auth/middleware/fee_test.go @@ -13,14 +13,21 @@ func (s *MWTestSuite) TestEnsureMempoolFees() { ctx := s.SetupTest(true) // setup txBuilder := s.clientCtx.TxConfig.NewTxBuilder() - txHandler := middleware.ComposeMiddlewares(noopTxHandler, middleware.MempoolFeeMiddleware) + txHandler := middleware.ComposeMiddlewares(noopTxHandler, middleware.DeductFeeMiddleware( + s.app.AccountKeeper, + s.app.BankKeeper, + s.app.FeeGrantKeeper, + nil, + )) // keys and addresses priv1, _, addr1 := testdata.KeyTestPubAddr() // msg and signatures msg := testdata.NewTestMsg(addr1) - feeAmount := testdata.NewTestFeeAmount() + atomCoin := sdk.NewCoin("atom", sdk.NewInt(150)) + apeCoin := sdk.NewInt64Coin("ape", 1500000) + feeAmount := sdk.NewCoins(apeCoin, atomCoin) gasLimit := testdata.NewTestGasLimit() s.Require().NoError(txBuilder.SetMsgs(msg)) txBuilder.SetFeeAmount(feeAmount) @@ -39,16 +46,23 @@ func (s *MWTestSuite) TestEnsureMempoolFees() { _, _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx.Request{Tx: testTx}, tx.RequestCheckTx{}) s.Require().NotNil(err, "Middleware should have errored on too low fee for local gasPrice") - // txHandler should not error since we do not check minGasPrice in DeliverTx + // txHandler should fail since we also check minGasPrice in DeliverTx _, err = txHandler.DeliverTx(sdk.WrapSDKContext(ctx), tx.Request{Tx: testTx}) - s.Require().Nil(err, "MempoolFeeMiddleware returned error in DeliverTx") + s.Require().Error(err, "MempoolFeeMiddleware don't error in DeliverTx") atomPrice = sdk.NewDecCoinFromDec("atom", sdk.NewDec(0).Quo(sdk.NewDec(100000))) lowGasPrice := []sdk.DecCoin{atomPrice} ctx = ctx.WithMinGasPrices(lowGasPrice) - _, _, err = txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx.Request{Tx: testTx}, tx.RequestCheckTx{}) + // Set account with sufficient funds + acc := s.app.AccountKeeper.NewAccountWithAddress(ctx, addr1) + s.app.AccountKeeper.SetAccount(ctx, acc) + err = testutil.FundAccount(s.app.BankKeeper, ctx, addr1, feeAmount) + s.Require().NoError(err) + + _, checkTxRes, err := txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx.Request{Tx: testTx}, tx.RequestCheckTx{}) s.Require().Nil(err, "Middleware should not have errored on fee higher than local gasPrice") + s.Require().Equal(atomCoin.Amount.Int64(), checkTxRes.Priority, "priority should be atom amount") } func (s *MWTestSuite) TestDeductFees() { @@ -60,6 +74,7 @@ func (s *MWTestSuite) TestDeductFees() { s.app.AccountKeeper, s.app.BankKeeper, s.app.FeeGrantKeeper, + nil, ), ) diff --git a/x/auth/middleware/feegrant_test.go b/x/auth/middleware/feegrant_test.go index f2c36ad5efdd..547f57a589d1 100644 --- a/x/auth/middleware/feegrant_test.go +++ b/x/auth/middleware/feegrant_test.go @@ -36,6 +36,7 @@ func (s *MWTestSuite) TestDeductFeesNoDelegation() { s.app.AccountKeeper, s.app.BankKeeper, s.app.FeeGrantKeeper, + nil, ), ) diff --git a/x/auth/middleware/middleware.go b/x/auth/middleware/middleware.go index 87b1216363fb..a3fb30b0d6ae 100644 --- a/x/auth/middleware/middleware.go +++ b/x/auth/middleware/middleware.go @@ -43,11 +43,13 @@ type TxHandlerOptions struct { LegacyRouter sdk.Router MsgServiceRouter *MsgServiceRouter - AccountKeeper AccountKeeper - BankKeeper types.BankKeeper - FeegrantKeeper FeegrantKeeper - SignModeHandler authsigning.SignModeHandler - SigGasConsumer func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error + AccountKeeper AccountKeeper + BankKeeper types.BankKeeper + FeegrantKeeper FeegrantKeeper + SignModeHandler authsigning.SignModeHandler + SigGasConsumer func(meter sdk.GasMeter, sig signing.SignatureV2, params types.Params) error + ExtensionOptionChecker ExtensionOptionChecker + TxFeeChecker TxFeeChecker } // NewDefaultTxHandler defines a TxHandler middleware stacks that should work @@ -74,6 +76,16 @@ func NewDefaultTxHandler(options TxHandlerOptions) (tx.Handler, error) { sigGasConsumer = DefaultSigVerificationGasConsumer } + var extensionOptionChecker = options.ExtensionOptionChecker + if extensionOptionChecker == nil { + extensionOptionChecker = rejectExtensionOption + } + + var txFeeChecker = options.TxFeeChecker + if txFeeChecker == nil { + txFeeChecker = checkTxFeeWithValidatorMinGasPrices + } + return ComposeMiddlewares( NewRunMsgsTxHandler(options.MsgServiceRouter, options.LegacyRouter), NewTxDecoderMiddleware(options.TxDecoder), @@ -89,10 +101,8 @@ func NewDefaultTxHandler(options TxHandlerOptions) (tx.Handler, error) { // Choose which events to index in Tendermint. Make sure no events are // emitted outside of this middleware. NewIndexEventsTxMiddleware(options.IndexEvents), - // Reject all extension options which can optionally be included in the - // tx. - RejectExtensionOptionsMiddleware, - MempoolFeeMiddleware, + // Reject all extension options other than the ones needed by the feemarket. + NewExtensionOptionsMiddleware(extensionOptionChecker), ValidateBasicMiddleware, TxTimeoutHeightMiddleware, ValidateMemoMiddleware(options.AccountKeeper), @@ -101,8 +111,7 @@ func NewDefaultTxHandler(options TxHandlerOptions) (tx.Handler, error) { // ComposeMiddlewares godoc for details. // `DeductFeeMiddleware` and `IncrementSequenceMiddleware` should be put outside of `WithBranchedStore` middleware, // so their storage writes are not discarded when tx fails. - DeductFeeMiddleware(options.AccountKeeper, options.BankKeeper, options.FeegrantKeeper), - TxPriorityMiddleware, + DeductFeeMiddleware(options.AccountKeeper, options.BankKeeper, options.FeegrantKeeper, txFeeChecker), SetPubKeyMiddleware(options.AccountKeeper), ValidateSigCountMiddleware(options.AccountKeeper), SigGasConsumeMiddleware(options.AccountKeeper, sigGasConsumer), diff --git a/x/auth/middleware/priority.go b/x/auth/middleware/priority.go deleted file mode 100644 index a464fddd1e7d..000000000000 --- a/x/auth/middleware/priority.go +++ /dev/null @@ -1,63 +0,0 @@ -package middleware - -import ( - "context" - - sdk "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - "github.com/cosmos/cosmos-sdk/types/tx" -) - -var _ tx.Handler = txPriorityHandler{} - -type txPriorityHandler struct { - next tx.Handler -} - -// TxPriorityMiddleware implements tx handling middleware that determines a -// transaction's priority via a naive mechanism -- the total sum of fees provided. -// It sets the Priority in ResponseCheckTx only. -func TxPriorityMiddleware(txh tx.Handler) tx.Handler { - return txPriorityHandler{next: txh} -} - -// CheckTx implements tx.Handler.CheckTx. We set the Priority of the transaction -// to be ordered in the Tendermint mempool based naively on the total sum of all -// fees included. Applications that need more sophisticated mempool ordering -// should look to implement their own fee handling middleware instead of using -// TxPriorityHandler. -func (h txPriorityHandler) CheckTx(ctx context.Context, req tx.Request, checkReq tx.RequestCheckTx) (tx.Response, tx.ResponseCheckTx, error) { - feeTx, ok := req.Tx.(sdk.FeeTx) - if !ok { - return tx.Response{}, tx.ResponseCheckTx{}, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx") - } - - feeCoins := feeTx.GetFee() - - res, checkRes, err := h.next.CheckTx(ctx, req, checkReq) - checkRes.Priority = GetTxPriority(feeCoins) - - return res, checkRes, err -} - -func (h txPriorityHandler) DeliverTx(ctx context.Context, req tx.Request) (tx.Response, error) { - return h.next.DeliverTx(ctx, req) -} - -func (h txPriorityHandler) SimulateTx(ctx context.Context, req tx.Request) (tx.Response, error) { - return h.next.SimulateTx(ctx, req) -} - -// GetTxPriority returns a naive tx priority based on the amount of the smallest denomination of the fee -// provided in a transaction. -func GetTxPriority(fee sdk.Coins) int64 { - var priority int64 - for _, c := range fee { - p := c.Amount.Int64() - if priority == 0 || p < priority { - priority = p - } - } - - return priority -} diff --git a/x/auth/middleware/priority_test.go b/x/auth/middleware/priority_test.go deleted file mode 100644 index 388418644040..000000000000 --- a/x/auth/middleware/priority_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package middleware_test - -import ( - cryptotypes "github.com/cosmos/cosmos-sdk/crypto/types" - "github.com/cosmos/cosmos-sdk/testutil/testdata" - sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/cosmos/cosmos-sdk/types/tx" - "github.com/cosmos/cosmos-sdk/x/auth/middleware" -) - -func (s *MWTestSuite) TestPriority() { - ctx := s.SetupTest(true) // setup - txBuilder := s.clientCtx.TxConfig.NewTxBuilder() - - txHandler := middleware.ComposeMiddlewares(noopTxHandler, middleware.TxPriorityMiddleware) - - // keys and addresses - priv1, _, addr1 := testdata.KeyTestPubAddr() - - // msg and signatures - msg := testdata.NewTestMsg(addr1) - atomCoin := sdk.NewCoin("atom", sdk.NewInt(150)) - apeCoin := sdk.NewInt64Coin("ape", 1500000) - feeAmount := sdk.NewCoins(apeCoin, atomCoin) - gasLimit := testdata.NewTestGasLimit() - s.Require().NoError(txBuilder.SetMsgs(msg)) - txBuilder.SetFeeAmount(feeAmount) - txBuilder.SetGasLimit(gasLimit) - - privs, accNums, accSeqs := []cryptotypes.PrivKey{priv1}, []uint64{0}, []uint64{0} - testTx, _, err := s.createTestTx(txBuilder, privs, accNums, accSeqs, ctx.ChainID()) - s.Require().NoError(err) - - // txHandler errors with insufficient fees - _, checkTxRes, err := txHandler.CheckTx(sdk.WrapSDKContext(ctx), tx.Request{Tx: testTx}, tx.RequestCheckTx{}) - s.Require().NoError(err, "Middleware should not have errored on too low fee for local gasPrice") - s.Require().Equal(atomCoin.Amount.Int64(), checkTxRes.Priority, "priority should be atom amount") -} diff --git a/x/auth/middleware/validator_tx_fee.go b/x/auth/middleware/validator_tx_fee.go new file mode 100644 index 000000000000..687608da55a9 --- /dev/null +++ b/x/auth/middleware/validator_tx_fee.go @@ -0,0 +1,59 @@ +package middleware + +import ( + "math" + + sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" +) + +// checkTxFeeWithValidatorMinGasPrices implements the default fee logic, where the minimum price per +// unit of gas is fixed and set by each validator, can the tx priority is computed from the gas price. +func checkTxFeeWithValidatorMinGasPrices(ctx sdk.Context, tx sdk.Tx) (sdk.Coins, int64, error) { + feeTx, ok := tx.(sdk.FeeTx) + if !ok { + return nil, 0, sdkerrors.Wrap(sdkerrors.ErrTxDecode, "Tx must be a FeeTx") + } + + feeCoins := feeTx.GetFee() + gas := feeTx.GetGas() + + // Ensure that the provided fees meet a minimum threshold for the validator, + // This is only for local mempool purposes, if this is a DeliverTx, the `MinGasPrices` should be zero. + minGasPrices := ctx.MinGasPrices() + if !minGasPrices.IsZero() { + requiredFees := make(sdk.Coins, len(minGasPrices)) + + // Determine the required fees by multiplying each required minimum gas + // price by the gas limit, where fee = ceil(minGasPrice * gasLimit). + glDec := sdk.NewDec(int64(gas)) + for i, gp := range minGasPrices { + fee := gp.Amount.Mul(glDec) + requiredFees[i] = sdk.NewCoin(gp.Denom, fee.Ceil().RoundInt()) + } + + if !feeCoins.IsAnyGTE(requiredFees) { + return nil, 0, sdkerrors.Wrapf(sdkerrors.ErrInsufficientFee, "insufficient fees; got: %s required: %s", feeCoins, requiredFees) + } + } + + priority := getTxPriority(feeCoins) + return feeCoins, priority, nil +} + +// getTxPriority returns a naive tx priority based on the amount of the smallest denomination of the fee +// provided in a transaction. +func getTxPriority(fee sdk.Coins) int64 { + var priority int64 + for _, c := range fee { + p := int64(math.MaxInt64) + if c.Amount.IsInt64() { + p = c.Amount.Int64() + } + if priority == 0 || p < priority { + priority = p + } + } + + return priority +}