diff --git a/crypto/keys/utils.go b/crypto/keys/utils.go deleted file mode 100644 index 2b81337d33c8..000000000000 --- a/crypto/keys/utils.go +++ /dev/null @@ -1,13 +0,0 @@ -package keys - -import ( - "math/big" - - "github.com/cosmos/cosmos-sdk/crypto/keys/internal/ecdsa" -) - -// Replicates https://github.com/cosmos/cosmos-sdk/blob/44fbb0df9cea049d588e76bf930177d777552cf3/crypto/ledger/ledger_secp256k1.go#L228 -// DO NOT USE. This is a temporary workaround that is cleaned-up in v0.47+ -func IsOverHalfOrder(sigS *big.Int) bool { - return !ecdsa.IsSNormalized(sigS) -} diff --git a/crypto/ledger/ledger_secp256k1.go b/crypto/ledger/ledger_secp256k1.go index 56ba46918ea0..29f50ad4e212 100644 --- a/crypto/ledger/ledger_secp256k1.go +++ b/crypto/ledger/ledger_secp256k1.go @@ -6,6 +6,7 @@ import ( "math/big" "os" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/cosmos/cosmos-sdk/crypto/hd" diff --git a/go.sum b/go.sum index 45009393e711..bc3cfe354753 100644 --- a/go.sum +++ b/go.sum @@ -152,9 +152,8 @@ github.com/boltdb/bolt v1.3.1/go.mod h1:clJnj/oiGkjum5o1McbSZDSLxVThjynRyGBgiAx2 github.com/btcsuite/btcd v0.0.0-20190315201642-aa6e0f35703c/go.mod h1:DrZx5ec/dmnfpw9KyYoQyYo7d0KEvTkk/5M/vbZjAr8= github.com/btcsuite/btcd v0.20.1-beta/go.mod h1:wVuoA8VJLEcwgqHBwHmzLRazpKxTv13Px/pDuV7OomQ= github.com/btcsuite/btcd v0.21.0-beta.0.20201114000516-e9c7a5ac6401/go.mod h1:Sv4JPQ3/M+teHz9Bo5jBpkNcP0x6r7rdihlNL/7tTAs= +github.com/btcsuite/btcd v0.22.1 h1:CnwP9LM/M9xuRrGSCGeMVs9iv09uMqwsVX7EeIpgV2c= github.com/btcsuite/btcd v0.22.1/go.mod h1:wqgTSL29+50LRkmOVknEdmt8ZojIzhuWvgu/iptuN7Y= -github.com/btcsuite/btcd v0.22.2 h1:vBZ+lGGd1XubpOWO67ITJpAEsICWhA0YzqkcpkgNBfo= -github.com/btcsuite/btcd v0.22.2/go.mod h1:wqgTSL29+50LRkmOVknEdmt8ZojIzhuWvgu/iptuN7Y= github.com/btcsuite/btcd/btcec/v2 v2.1.2/go.mod h1:ctjw4H1kknNJmRN4iP1R7bTQ+v3GJkZBd6mui8ZsAZE= github.com/btcsuite/btcd/btcec/v2 v2.3.2 h1:5n0X6hX0Zk+6omWcihdYvdAlGf2DfasC0GMf7DClJ3U= github.com/btcsuite/btcd/btcec/v2 v2.3.2/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= @@ -248,6 +247,7 @@ github.com/cosmos/gorocksdb v1.2.0/go.mod h1:aaKvKItm514hKfNJpUJXnnOWeBnk2GL4+Qw github.com/cosmos/iavl v0.19.5 h1:rGA3hOrgNxgRM5wYcSCxgQBap7fW82WZgY78V9po/iY= github.com/cosmos/iavl v0.19.5/go.mod h1:X9PKD3J0iFxdmgNLa7b2LYWdsGd90ToV5cAONApkEPw= github.com/cosmos/keyring v1.2.0 h1:8C1lBP9xhImmIabyXW4c3vFjjLiBdGCmfLUfeZlV1Yo= +github.com/cosmos/keyring v1.2.0/go.mod h1:fc+wB5KTk9wQ9sDx0kFXB3A0MaeGHM9AwRStKOQ5vOA= github.com/cosmos/ledger-cosmos-go v0.12.2 h1:/XYaBlE2BJxtvpkHiBm97gFGSGmYGKunKyF3nNqAXZA= github.com/cosmos/ledger-cosmos-go v0.12.2/go.mod h1:ZcqYgnfNJ6lAXe4HPtWgarNEY+B74i+2/8MhZw4ziiI= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= diff --git a/x/bank/app_test.go b/x/bank/app_test.go index d1e810277edd..11b4ed0c5730 100644 --- a/x/bank/app_test.go +++ b/x/bank/app_test.go @@ -59,31 +59,21 @@ var ( }, } multiSendMsg3 = &types.MsgMultiSend{ - Inputs: []types.Input{ - types.NewInput(addr1, coins), - types.NewInput(addr4, coins), - }, - Outputs: []types.Output{ - types.NewOutput(addr2, coins), - types.NewOutput(addr3, coins), - }, - } - multiSendMsg4 = &types.MsgMultiSend{ - Inputs: []types.Input{ - types.NewInput(addr2, coins), - }, + Inputs: []types.Input{types.NewInput(addr2, coins)}, Outputs: []types.Output{ types.NewOutput(addr1, coins), }, } - multiSendMsg5 = &types.MsgMultiSend{ - Inputs: []types.Input{ - types.NewInput(addr1, coins), - }, + multiSendMsg4 = &types.MsgMultiSend{ + Inputs: []types.Input{types.NewInput(addr1, coins)}, Outputs: []types.Output{ types.NewOutput(moduleAccAddr, coins), }, } + invalidMultiSendMsg = &types.MsgMultiSend{ + Inputs: []types.Input{types.NewInput(addr1, coins), types.NewInput(addr2, coins)}, + Outputs: []types.Output{}, + } ) func TestSendNotEnoughBalance(t *testing.T) { @@ -163,13 +153,22 @@ func TestMsgMultiSendWithAccounts(t *testing.T) { }, { desc: "wrong accSeq should not pass Simulate", - msgs: []sdk.Msg{multiSendMsg5}, + msgs: []sdk.Msg{multiSendMsg4}, accNums: []uint64{0}, accSeqs: []uint64{0}, // wrong account sequence expSimPass: false, expPass: false, privKeys: []cryptotypes.PrivKey{priv1}, }, + { + desc: "multiple inputs not allowed", + msgs: []sdk.Msg{invalidMultiSendMsg}, + accNums: []uint64{0}, + accSeqs: []uint64{0}, + expSimPass: false, + expPass: false, + privKeys: []cryptotypes.PrivKey{priv1}, + }, } for _, tc := range testCases { @@ -234,58 +233,6 @@ func TestMsgMultiSendMultipleOut(t *testing.T) { } } -func TestMsgMultiSendMultipleInOut(t *testing.T) { - acc1 := &authtypes.BaseAccount{ - Address: addr1.String(), - } - acc2 := &authtypes.BaseAccount{ - Address: addr2.String(), - } - acc4 := &authtypes.BaseAccount{ - Address: addr4.String(), - } - - genAccs := []authtypes.GenesisAccount{acc1, acc2, acc4} - app := simapp.SetupWithGenesisAccounts(t, genAccs) - ctx := app.BaseApp.NewContext(false, tmproto.Header{}) - - require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr1, sdk.NewCoins(sdk.NewInt64Coin("foocoin", 42)))) - - require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr2, sdk.NewCoins(sdk.NewInt64Coin("foocoin", 42)))) - - require.NoError(t, testutil.FundAccount(app.BankKeeper, ctx, addr4, sdk.NewCoins(sdk.NewInt64Coin("foocoin", 42)))) - - app.Commit() - - testCases := []appTestCase{ - { - msgs: []sdk.Msg{multiSendMsg3}, - accNums: []uint64{0, 2}, - accSeqs: []uint64{0, 0}, - expSimPass: true, - expPass: true, - privKeys: []cryptotypes.PrivKey{priv1, priv4}, - expectedBalances: []expectedBalance{ - {addr1, sdk.Coins{sdk.NewInt64Coin("foocoin", 32)}}, - {addr4, sdk.Coins{sdk.NewInt64Coin("foocoin", 32)}}, - {addr2, sdk.Coins{sdk.NewInt64Coin("foocoin", 52)}}, - {addr3, sdk.Coins{sdk.NewInt64Coin("foocoin", 10)}}, - }, - }, - } - - for _, tc := range testCases { - header := tmproto.Header{Height: app.LastBlockHeight() + 1} - txGen := simapp.MakeTestEncodingConfig().TxConfig - _, _, err := simapp.SignCheckDeliver(t, txGen, app.BaseApp, header, tc.msgs, "", tc.accNums, tc.accSeqs, tc.expSimPass, tc.expPass, tc.privKeys...) - require.NoError(t, err) - - for _, eb := range tc.expectedBalances { - simapp.CheckBalance(t, app, eb.addr, eb.coins) - } - } -} - func TestMsgMultiSendDependent(t *testing.T) { acc1 := authtypes.NewBaseAccountWithAddress(addr1) acc2 := authtypes.NewBaseAccountWithAddress(addr2) @@ -314,7 +261,7 @@ func TestMsgMultiSendDependent(t *testing.T) { }, }, { - msgs: []sdk.Msg{multiSendMsg4}, + msgs: []sdk.Msg{multiSendMsg3}, accNums: []uint64{1}, accSeqs: []uint64{0}, expSimPass: true, diff --git a/x/bank/keeper/export_test.go b/x/bank/keeper/export_test.go new file mode 100644 index 000000000000..6ed22da65456 --- /dev/null +++ b/x/bank/keeper/export_test.go @@ -0,0 +1,12 @@ +package keeper + +// This file exists in the keeper package to expose some private things +// for the purpose of testing in the keeper_test package. + +func (k BaseSendKeeper) SetSendRestriction(restriction SendRestrictionFn) { + k.sendRestriction.Fn = restriction +} + +func (k BaseSendKeeper) GetSendRestrictionFn() SendRestrictionFn { + return k.sendRestriction.Fn +} diff --git a/x/bank/keeper/keeper.go b/x/bank/keeper/keeper.go index 32591fb12fd9..bc3c7d2b39f3 100644 --- a/x/bank/keeper/keeper.go +++ b/x/bank/keeper/keeper.go @@ -61,8 +61,6 @@ type BaseKeeper struct { mintCoinsRestrictionFn MintingRestrictionFn } -type MintingRestrictionFn func(ctx sdk.Context, coins sdk.Coins) error - // GetPaginatedTotalSupply queries for the supply, ignoring 0 coins, with a given pagination func (k BaseKeeper) GetPaginatedTotalSupply(ctx sdk.Context, pagination *query.PageRequest) (sdk.Coins, *query.PageResponse, error) { store := ctx.KVStore(k.storeKey) @@ -112,7 +110,7 @@ func NewBaseKeeper( cdc: cdc, storeKey: storeKey, paramSpace: paramSpace, - mintCoinsRestrictionFn: func(ctx sdk.Context, coins sdk.Coins) error { return nil }, + mintCoinsRestrictionFn: NoOpMintingRestrictionFn, } } @@ -122,18 +120,7 @@ func NewBaseKeeper( // // bankKeeper.WithMintCoinsRestriction(restriction1).WithMintCoinsRestriction(restriction2) func (k BaseKeeper) WithMintCoinsRestriction(check MintingRestrictionFn) BaseKeeper { - oldRestrictionFn := k.mintCoinsRestrictionFn - k.mintCoinsRestrictionFn = func(ctx sdk.Context, coins sdk.Coins) error { - err := check(ctx, coins) - if err != nil { - return err - } - err = oldRestrictionFn(ctx, coins) - if err != nil { - return err - } - return nil - } + k.mintCoinsRestrictionFn = check.Then(k.mintCoinsRestrictionFn) return k } diff --git a/x/bank/keeper/keeper_test.go b/x/bank/keeper/keeper_test.go index 8e6e137ed8ae..18056f834094 100644 --- a/x/bank/keeper/keeper_test.go +++ b/x/bank/keeper/keeper_test.go @@ -354,14 +354,13 @@ func (suite *IntegrationTestSuite) TestInputOutputNewAccount() { suite.Require().Nil(app.AccountKeeper.GetAccount(ctx, addr2)) suite.Require().Empty(app.BankKeeper.GetAllBalances(ctx, addr2)) - inputs := []types.Input{ - {Address: addr1.String(), Coins: sdk.NewCoins(newFooCoin(30), newBarCoin(10))}, - } + input := types.Input{Address: addr1.String(), Coins: sdk.NewCoins(newFooCoin(30), newBarCoin(10))} + outputs := []types.Output{ {Address: addr2.String(), Coins: sdk.NewCoins(newFooCoin(30), newBarCoin(10))}, } - suite.Require().NoError(app.BankKeeper.InputOutputCoins(ctx, inputs, outputs)) + suite.Require().NoError(app.BankKeeper.InputOutputCoins(ctx, input, outputs)) expected := sdk.NewCoins(newFooCoin(30), newBarCoin(10)) acc2Balances := app.BankKeeper.GetAllBalances(ctx, addr2) @@ -385,30 +384,26 @@ func (suite *IntegrationTestSuite) TestInputOutputCoins() { acc3 := app.AccountKeeper.NewAccountWithAddress(ctx, addr3) app.AccountKeeper.SetAccount(ctx, acc3) - inputs := []types.Input{ - {Address: addr1.String(), Coins: sdk.NewCoins(newFooCoin(30), newBarCoin(10))}, - {Address: addr1.String(), Coins: sdk.NewCoins(newFooCoin(30), newBarCoin(10))}, - } + input := types.Input{Address: addr1.String(), Coins: sdk.NewCoins(newFooCoin(60), newBarCoin(20))} + outputs := []types.Output{ {Address: addr2.String(), Coins: sdk.NewCoins(newFooCoin(30), newBarCoin(10))}, {Address: addr3.String(), Coins: sdk.NewCoins(newFooCoin(30), newBarCoin(10))}, } - suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, inputs, []types.Output{})) - suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, inputs, outputs)) + suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, input, []types.Output{})) + suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, input, outputs)) suite.Require().NoError(testutil.FundAccount(app.BankKeeper, ctx, addr1, balances)) - insufficientInputs := []types.Input{ - {Address: addr1.String(), Coins: sdk.NewCoins(newFooCoin(300), newBarCoin(100))}, - {Address: addr1.String(), Coins: sdk.NewCoins(newFooCoin(300), newBarCoin(100))}, - } + insufficientInput := types.Input{Address: addr1.String(), Coins: sdk.NewCoins(newFooCoin(300), newBarCoin(100))} + insufficientOutputs := []types.Output{ {Address: addr2.String(), Coins: sdk.NewCoins(newFooCoin(300), newBarCoin(100))}, {Address: addr3.String(), Coins: sdk.NewCoins(newFooCoin(300), newBarCoin(100))}, } - suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, insufficientInputs, insufficientOutputs)) - suite.Require().NoError(app.BankKeeper.InputOutputCoins(ctx, inputs, outputs)) + suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, insufficientInput, insufficientOutputs)) + suite.Require().NoError(app.BankKeeper.InputOutputCoins(ctx, input, outputs)) acc1Balances := app.BankKeeper.GetAllBalances(ctx, addr1) expected := sdk.NewCoins(newFooCoin(30), newBarCoin(10)) @@ -602,58 +597,59 @@ func (suite *IntegrationTestSuite) TestMsgMultiSendEvents() { app.AccountKeeper.SetAccount(ctx, acc) app.AccountKeeper.SetAccount(ctx, acc2) + coins := sdk.NewCoins(sdk.NewInt64Coin(fooDenom, 50), sdk.NewInt64Coin(barDenom, 100)) newCoins := sdk.NewCoins(sdk.NewInt64Coin(fooDenom, 50)) newCoins2 := sdk.NewCoins(sdk.NewInt64Coin(barDenom, 100)) - inputs := []types.Input{ - {Address: addr.String(), Coins: newCoins}, - {Address: addr2.String(), Coins: newCoins2}, - } + input := types.Input{Address: addr.String(), Coins: coins} outputs := []types.Output{ {Address: addr3.String(), Coins: newCoins}, {Address: addr4.String(), Coins: newCoins2}, } - suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, inputs, outputs)) + suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, input, outputs)) events := ctx.EventManager().ABCIEvents() suite.Require().Equal(0, len(events)) // Set addr's coins but not addr2's coins - suite.Require().NoError(testutil.FundAccount(app.BankKeeper, ctx, addr, sdk.NewCoins(sdk.NewInt64Coin(fooDenom, 50)))) - suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, inputs, outputs)) + suite.Require().NoError(testutil.FundAccount(app.BankKeeper, ctx, addr, newCoins)) + suite.Require().Error(app.BankKeeper.InputOutputCoins(ctx, input, outputs)) events = ctx.EventManager().ABCIEvents() - suite.Require().Equal(8, len(events)) // 7 events because account funding causes extra minting + coin_spent + coin_recv events + suite.Require().Equal(6, len(events)) // 6 events because account funding causes extra minting + coin_spent + coin_recv events event1 := sdk.Event{ - Type: sdk.EventTypeMessage, + Type: types.EventTypeCoinReceived, Attributes: []abci.EventAttribute{}, } event1.Attributes = append( event1.Attributes, - abci.EventAttribute{Key: []byte(types.AttributeKeySender), Value: []byte(addr.String())}, + abci.EventAttribute{Key: []byte(types.AttributeKeyReceiver), Value: []byte(addr.String())}, + abci.EventAttribute{Key: []byte(sdk.AttributeKeyAmount), Value: []byte(newCoins.String())}, ) - suite.Require().Equal(abci.Event(event1), events[7]) + + suite.Require().Equal(abci.Event(event1), events[3]) // Set addr's coins and addr2's coins - suite.Require().NoError(testutil.FundAccount(app.BankKeeper, ctx, addr, sdk.NewCoins(sdk.NewInt64Coin(fooDenom, 50)))) + suite.Require().NoError(testutil.FundAccount(app.BankKeeper, ctx, addr, sdk.NewCoins(sdk.NewInt64Coin(fooDenom, 50), sdk.NewInt64Coin(barDenom, 100)))) newCoins = sdk.NewCoins(sdk.NewInt64Coin(fooDenom, 50)) suite.Require().NoError(testutil.FundAccount(app.BankKeeper, ctx, addr2, sdk.NewCoins(sdk.NewInt64Coin(barDenom, 100)))) newCoins2 = sdk.NewCoins(sdk.NewInt64Coin(barDenom, 100)) - suite.Require().NoError(app.BankKeeper.InputOutputCoins(ctx, inputs, outputs)) + suite.Require().NoError(app.BankKeeper.InputOutputCoins(ctx, input, outputs)) events = ctx.EventManager().ABCIEvents() - suite.Require().Equal(28, len(events)) // 25 due to account funding + coin_spent + coin_recv events + suite.Require().Equal(24, len(events)) // 24 due to account funding + coin_spent + coin_recv events event2 := sdk.Event{ - Type: sdk.EventTypeMessage, + Type: types.EventTypeCoinReceived, Attributes: []abci.EventAttribute{}, } event2.Attributes = append( event2.Attributes, - abci.EventAttribute{Key: []byte(types.AttributeKeySender), Value: []byte(addr2.String())}, + abci.EventAttribute{Key: []byte(types.AttributeKeyReceiver), Value: []byte(addr2.String())}, + abci.EventAttribute{Key: []byte(sdk.AttributeKeyAmount), Value: []byte(newCoins2.String())}, ) event3 := sdk.Event{ Type: types.EventTypeTransfer, @@ -678,11 +674,10 @@ func (suite *IntegrationTestSuite) TestMsgMultiSendEvents() { event4.Attributes, abci.EventAttribute{Key: []byte(sdk.AttributeKeyAmount), Value: []byte(newCoins2.String())}, ) - // events are shifted due to the funding account events - suite.Require().Equal(abci.Event(event1), events[21]) - suite.Require().Equal(abci.Event(event2), events[23]) - suite.Require().Equal(abci.Event(event3), events[25]) - suite.Require().Equal(abci.Event(event4), events[27]) + //events are shifted due to the funding account events + suite.Require().Equal(abci.Event(event2), events[15]) + suite.Require().Equal(abci.Event(event3), events[21]) + suite.Require().Equal(abci.Event(event4), events[23]) } func (suite *IntegrationTestSuite) TestSpendableCoins() { diff --git a/x/bank/keeper/msg_server.go b/x/bank/keeper/msg_server.go index 4e9237631d54..f2358fc15cad 100644 --- a/x/bank/keeper/msg_server.go +++ b/x/bank/keeper/msg_server.go @@ -88,7 +88,7 @@ func (k msgServer) MultiSend(goCtx context.Context, msg *types.MsgMultiSend) (*t } } - err := k.InputOutputCoins(ctx, msg.Inputs, msg.Outputs) + err := k.InputOutputCoins(ctx, msg.Inputs[0], msg.Outputs) if err != nil { return nil, err } diff --git a/x/bank/keeper/restrictions.go b/x/bank/keeper/restrictions.go new file mode 100644 index 000000000000..c47b69dff56d --- /dev/null +++ b/x/bank/keeper/restrictions.go @@ -0,0 +1,118 @@ +package keeper + +import sdk "github.com/cosmos/cosmos-sdk/types" + +// A MintingRestrictionFn can restrict minting of coins. +type MintingRestrictionFn func(ctx sdk.Context, coins sdk.Coins) error + +var _ MintingRestrictionFn = NoOpMintingRestrictionFn + +// NoOpMintingRestrictionFn is a no-op MintingRestrictionFn. +func NoOpMintingRestrictionFn(_ sdk.Context, _ sdk.Coins) error { + return nil +} + +// Then creates a composite restriction that runs this one then the provided second one. +func (r MintingRestrictionFn) Then(second MintingRestrictionFn) MintingRestrictionFn { + return ComposeMintingRestrictions(r, second) +} + +// ComposeMintingRestrictions combines multiple MintingRestrictionFn into one. +// nil entries are ignored. +// If all entries are nil, nil is returned. +// If exactly one entry is not nil, it is returned. +// Otherwise, a new MintingRestrictionFn is returned that runs the non-nil restrictions in the order they are given. +// The composition runs each minting restriction until an error is encountered and returns that error. +func ComposeMintingRestrictions(restrictions ...MintingRestrictionFn) MintingRestrictionFn { + toRun := make([]MintingRestrictionFn, 0, len(restrictions)) + for _, r := range restrictions { + if r != nil { + toRun = append(toRun, r) + } + } + switch len(toRun) { + case 0: + return nil + case 1: + return toRun[0] + } + return func(ctx sdk.Context, coins sdk.Coins) error { + for _, r := range toRun { + err := r(ctx, coins) + if err != nil { + return err + } + } + return nil + } +} + +// A SendRestrictionFn can restrict sends and/or provide a new receiver address. +type SendRestrictionFn func(ctx sdk.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) (newToAddr sdk.AccAddress, err error) + +var _ SendRestrictionFn = NoOpSendRestrictionFn + +// NoOpSendRestrictionFn is a no-op SendRestrictionFn. +func NoOpSendRestrictionFn(_ sdk.Context, _, toAddr sdk.AccAddress, _ sdk.Coins) (sdk.AccAddress, error) { + return toAddr, nil +} + +// Then creates a composite restriction that runs this one then the provided second one. +func (r SendRestrictionFn) Then(second SendRestrictionFn) SendRestrictionFn { + return ComposeSendRestrictions(r, second) +} + +// ComposeSendRestrictions combines multiple SendRestrictionFn into one. +// nil entries are ignored. +// If all entries are nil, nil is returned. +// If exactly one entry is not nil, it is returned. +// Otherwise, a new SendRestrictionFn is returned that runs the non-nil restrictions in the order they are given. +// The composition runs each send restriction until an error is encountered and returns that error, +// otherwise it returns the toAddr of the last send restriction. +func ComposeSendRestrictions(restrictions ...SendRestrictionFn) SendRestrictionFn { + toRun := make([]SendRestrictionFn, 0, len(restrictions)) + for _, r := range restrictions { + if r != nil { + toRun = append(toRun, r) + } + } + switch len(toRun) { + case 0: + return nil + case 1: + return toRun[0] + } + return func(ctx sdk.Context, fromAddr, toAddr sdk.AccAddress, amt sdk.Coins) (sdk.AccAddress, error) { + var err error + for _, r := range toRun { + toAddr, err = r(ctx, fromAddr, toAddr, amt) + if err != nil { + return toAddr, err + } + } + return toAddr, err + } +} + +// SendRestriction is a struct that houses a SendRestrictionFn. +// It exists so that the SendRestrictionFn can be updated in the SendKeeper without needing to have a pointer receiver. +type SendRestriction struct { + Fn SendRestrictionFn +} + +// NewSendRestriction creates a new SendRestriction with nil send restriction. +func NewSendRestriction() *SendRestriction { + return &SendRestriction{ + Fn: nil, + } +} + +// Append adds the provided restriction to this, to be run after the existing function. +func (r *SendRestriction) Append(restriction SendRestrictionFn) { + r.Fn = r.Fn.Then(restriction) +} + +// Prepend adds the provided restriction to this, to be run before the existing function. +func (r *SendRestriction) Prepend(restriction SendRestrictionFn) { + r.Fn = restriction.Then(r.Fn) +} diff --git a/x/bank/keeper/restrictoins_test.go b/x/bank/keeper/restrictoins_test.go new file mode 100644 index 000000000000..0b598002858c --- /dev/null +++ b/x/bank/keeper/restrictoins_test.go @@ -0,0 +1,918 @@ +package keeper_test + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/bank/keeper" +) + +// MintingRestrictionArgs are the args provided to a MintingRestrictionFn function. +type MintingRestrictionArgs struct { + Name string + Coins sdk.Coins +} + +// MintingRestrictionTestHelper is a struct with stuff helpful for testing the MintingRestrictionFn stuff. +type MintingRestrictionTestHelper struct { + Calls []*MintingRestrictionArgs +} + +func NewMintingRestrictionTestHelper() *MintingRestrictionTestHelper { + return &MintingRestrictionTestHelper{Calls: make([]*MintingRestrictionArgs, 0, 2)} +} + +// RecordCall makes note that the provided args were used as a funcion call. +func (s *MintingRestrictionTestHelper) RecordCall(name string, coins sdk.Coins) { + s.Calls = append(s.Calls, s.NewArgs(name, coins)) +} + +// NewCalls is just a shorter way to create a []*MintingRestrictionArgs. +func (s *MintingRestrictionTestHelper) NewCalls(args ...*MintingRestrictionArgs) []*MintingRestrictionArgs { + return args +} + +// NewArgs creates a new MintingRestrictionArgs. +func (s *MintingRestrictionTestHelper) NewArgs(name string, coins sdk.Coins) *MintingRestrictionArgs { + return &MintingRestrictionArgs{ + Name: name, + Coins: coins, + } +} + +// NamedRestriction creates a new MintingRestrictionFn function that records the arguments it's called with and returns nil. +func (s *MintingRestrictionTestHelper) NamedRestriction(name string) keeper.MintingRestrictionFn { + return func(_ sdk.Context, coins sdk.Coins) error { + s.RecordCall(name, coins) + return nil + } +} + +// ErrorRestriction creates a new MintingRestrictionFn function that returns an error. +func (s *MintingRestrictionTestHelper) ErrorRestriction(message string) keeper.MintingRestrictionFn { + return func(_ sdk.Context, coins sdk.Coins) error { + s.RecordCall(message, coins) + return errors.New(message) + } +} + +// MintingRestrictionTestParams are parameters to test regarding calling a MintingRestrictionFn. +type MintingRestrictionTestParams struct { + // ExpNil is whether to expect the provided MintingRestrictionFn to be nil. + // If it is true, the rest of these test params are ignored. + ExpNil bool + // Coins is the MintingRestrictionFn coins input. + Coins sdk.Coins + // ExpErr is the expected return error string. + ExpErr string + // ExpCalls is the args of all the MintingRestrictionFn calls that end up being made. + ExpCalls []*MintingRestrictionArgs +} + +// TestActual tests the provided MintingRestrictionFn using the provided test parameters. +func (s *MintingRestrictionTestHelper) TestActual(t *testing.T, tp *MintingRestrictionTestParams, actual keeper.MintingRestrictionFn) { + t.Helper() + if tp.ExpNil { + require.Nil(t, actual, "resulting MintingRestrictionFn") + } else { + require.NotNil(t, actual, "resulting MintingRestrictionFn") + s.Calls = s.Calls[:0] + err := actual(sdk.Context{}, tp.Coins) + if len(tp.ExpErr) != 0 { + assert.EqualError(t, err, tp.ExpErr, "composite MintingRestrictionFn output error") + } else { + assert.NoError(t, err, "composite MintingRestrictionFn output error") + } + assert.Equal(t, tp.ExpCalls, s.Calls, "args given to funcs in composite MintingRestrictionFn") + } +} + +func TestMintingRestriction_Then(t *testing.T) { + coins := sdk.NewCoins(sdk.NewInt64Coin("acoin", 2), sdk.NewInt64Coin("bcoin", 4)) + + h := NewMintingRestrictionTestHelper() + + tests := []struct { + name string + base keeper.MintingRestrictionFn + second keeper.MintingRestrictionFn + exp *MintingRestrictionTestParams + }{ + { + name: "nil nil", + base: nil, + second: nil, + exp: &MintingRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "nil noop", + base: nil, + second: h.NamedRestriction("noop"), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("noop", coins)), + }, + }, + { + name: "noop nil", + base: h.NamedRestriction("noop"), + second: nil, + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("noop", coins)), + }, + }, + { + name: "noop noop", + base: h.NamedRestriction("noop1"), + second: h.NamedRestriction("noop2"), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("noop1", coins), h.NewArgs("noop2", coins)), + }, + }, + { + name: "noop error", + base: h.NamedRestriction("noop"), + second: h.ErrorRestriction("this is a test error"), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "this is a test error", + ExpCalls: h.NewCalls(h.NewArgs("noop", coins), h.NewArgs("this is a test error", coins)), + }, + }, + { + name: "error noop", + base: h.ErrorRestriction("another test error"), + second: h.NamedRestriction("noop"), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "another test error", + ExpCalls: h.NewCalls(h.NewArgs("another test error", coins)), + }, + }, + { + name: "error error", + base: h.ErrorRestriction("first test error"), + second: h.ErrorRestriction("second test error"), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "first test error", + ExpCalls: h.NewCalls(h.NewArgs("first test error", coins)), + }, + }, + { + name: "double chain", + base: keeper.ComposeMintingRestrictions(h.NamedRestriction("r1"), h.NamedRestriction("r2")), + second: keeper.ComposeMintingRestrictions(h.NamedRestriction("r3"), h.NamedRestriction("r4")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls( + h.NewArgs("r1", coins), + h.NewArgs("r2", coins), + h.NewArgs("r3", coins), + h.NewArgs("r4", coins), + ), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var actual keeper.MintingRestrictionFn + testFunc := func() { + actual = tc.base.Then(tc.second) + } + require.NotPanics(t, testFunc, "MintingRestrictionFn.Then") + h.TestActual(t, tc.exp, actual) + }) + } +} + +func TestComposeMintingRestrictions(t *testing.T) { + rz := func(rs ...keeper.MintingRestrictionFn) []keeper.MintingRestrictionFn { + return rs + } + coins := sdk.NewCoins(sdk.NewInt64Coin("ccoin", 8), sdk.NewInt64Coin("dcoin", 16)) + + h := NewMintingRestrictionTestHelper() + + tests := []struct { + name string + input []keeper.MintingRestrictionFn + exp *MintingRestrictionTestParams + }{ + { + name: "nil list", + input: nil, + exp: &MintingRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "empty list", + input: rz(), + exp: &MintingRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "only nil entry", + input: rz(nil), + exp: &MintingRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "five nil entries", + input: rz(nil, nil, nil, nil, nil), + exp: &MintingRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "only noop entry", + input: rz(h.NamedRestriction("noop")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("noop", coins)), + }, + }, + { + name: "only error entry", + input: rz(h.ErrorRestriction("test error")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "test error", + ExpCalls: h.NewCalls(h.NewArgs("test error", coins)), + }, + }, + { + name: "noop nil nil", + input: rz(h.NamedRestriction("noop"), nil, nil), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("noop", coins)), + }, + }, + { + name: "nil noop nil", + input: rz(nil, h.NamedRestriction("noop"), nil), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("noop", coins)), + }, + }, + { + name: "nil nil noop", + input: rz(nil, nil, h.NamedRestriction("noop")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("noop", coins)), + }, + }, + { + name: "noop noop nil", + input: rz(h.NamedRestriction("r1"), h.NamedRestriction("r2"), nil), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("r1", coins), h.NewArgs("r2", coins)), + }, + }, + { + name: "noop nil noop", + input: rz(h.NamedRestriction("r1"), nil, h.NamedRestriction("r2")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("r1", coins), h.NewArgs("r2", coins)), + }, + }, + { + name: "nil noop noop", + input: rz(nil, h.NamedRestriction("r1"), h.NamedRestriction("r2")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("r1", coins), h.NewArgs("r2", coins)), + }, + }, + { + name: "noop noop noop", + input: rz(h.NamedRestriction("r1"), h.NamedRestriction("r2"), h.NamedRestriction("r3")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpCalls: h.NewCalls(h.NewArgs("r1", coins), h.NewArgs("r2", coins), h.NewArgs("r3", coins)), + }, + }, + { + name: "err noop noop", + input: rz(h.ErrorRestriction("first error"), h.NamedRestriction("r2"), h.NamedRestriction("r3")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "first error", + ExpCalls: h.NewCalls(h.NewArgs("first error", coins)), + }, + }, + { + name: "noop err noop", + input: rz(h.NamedRestriction("r1"), h.ErrorRestriction("second error"), h.NamedRestriction("r3")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "second error", + ExpCalls: h.NewCalls(h.NewArgs("r1", coins), h.NewArgs("second error", coins)), + }, + }, + { + name: "noop noop err", + input: rz(h.NamedRestriction("r1"), h.NamedRestriction("r2"), h.ErrorRestriction("third error")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "third error", + ExpCalls: h.NewCalls(h.NewArgs("r1", coins), h.NewArgs("r2", coins), h.NewArgs("third error", coins)), + }, + }, + { + name: "noop err err", + input: rz(h.NamedRestriction("r1"), h.ErrorRestriction("second error"), h.ErrorRestriction("third error")), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "second error", + ExpCalls: h.NewCalls(h.NewArgs("r1", coins), h.NewArgs("second error", coins)), + }, + }, + { + name: "big bang", + input: rz( + h.NamedRestriction("r1"), nil, h.NamedRestriction("r2"), nil, + h.NamedRestriction("r3"), h.NamedRestriction("r4"), h.NamedRestriction("r5"), + nil, h.NamedRestriction("r6"), h.NamedRestriction("r7"), nil, + h.NamedRestriction("r8"), nil, nil, h.ErrorRestriction("oops, an error"), + h.NamedRestriction("r9"), nil, h.NamedRestriction("ra"), // Not called. + ), + exp: &MintingRestrictionTestParams{ + Coins: coins, + ExpErr: "oops, an error", + ExpCalls: h.NewCalls( + h.NewArgs("r1", coins), + h.NewArgs("r2", coins), + h.NewArgs("r3", coins), + h.NewArgs("r4", coins), + h.NewArgs("r5", coins), + h.NewArgs("r6", coins), + h.NewArgs("r7", coins), + h.NewArgs("r8", coins), + h.NewArgs("oops, an error", coins), + ), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var actual keeper.MintingRestrictionFn + testFunc := func() { + actual = keeper.ComposeMintingRestrictions(tc.input...) + } + require.NotPanics(t, testFunc, "ComposeMintingRestrictions") + h.TestActual(t, tc.exp, actual) + }) + } +} + +func TestNoOpMintingRestrictionFn(t *testing.T) { + var err error + testFunc := func() { + err = keeper.NoOpMintingRestrictionFn(sdk.Context{}, sdk.Coins{}) + } + require.NotPanics(t, testFunc, "NoOpMintingRestrictionFn") + assert.NoError(t, err, "NoOpSendRestrictionFn error") +} + +// SendRestrictionArgs are the args provided to a SendRestrictionFn function. +type SendRestrictionArgs struct { + Name string + FromAddr sdk.AccAddress + ToAddr sdk.AccAddress + Coins sdk.Coins +} + +// SendRestrictionTestHelper is a struct with stuff helpful for testing the SendRestrictionFn stuff. +type SendRestrictionTestHelper struct { + Calls []*SendRestrictionArgs +} + +func NewSendRestrictionTestHelper() *SendRestrictionTestHelper { + return &SendRestrictionTestHelper{Calls: make([]*SendRestrictionArgs, 0, 2)} +} + +// RecordCall makes note that the provided args were used as a funcion call. +func (s *SendRestrictionTestHelper) RecordCall(name string, fromAddr, toAddr sdk.AccAddress, coins sdk.Coins) { + s.Calls = append(s.Calls, s.NewArgs(name, fromAddr, toAddr, coins)) +} + +// NewCalls is just a shorter way to create a []*SendRestrictionArgs. +func (s *SendRestrictionTestHelper) NewCalls(args ...*SendRestrictionArgs) []*SendRestrictionArgs { + return args +} + +// NewArgs creates a new SendRestrictionArgs. +func (s *SendRestrictionTestHelper) NewArgs(name string, fromAddr, toAddr sdk.AccAddress, coins sdk.Coins) *SendRestrictionArgs { + return &SendRestrictionArgs{ + Name: name, + FromAddr: fromAddr, + ToAddr: toAddr, + Coins: coins, + } +} + +// NamedRestriction creates a new SendRestrictionFn function that records the arguments it's called with and returns the provided toAddr. +func (s *SendRestrictionTestHelper) NamedRestriction(name string) keeper.SendRestrictionFn { + return func(_ sdk.Context, fromAddr, toAddr sdk.AccAddress, coins sdk.Coins) (sdk.AccAddress, error) { + s.RecordCall(name, fromAddr, toAddr, coins) + return toAddr, nil + } +} + +// NewToRestriction creates a new SendRestrictionFn function that returns a different toAddr than provided. +func (s *SendRestrictionTestHelper) NewToRestriction(name string, addr sdk.AccAddress) keeper.SendRestrictionFn { + return func(_ sdk.Context, fromAddr, toAddr sdk.AccAddress, coins sdk.Coins) (sdk.AccAddress, error) { + s.RecordCall(name, fromAddr, toAddr, coins) + return addr, nil + } +} + +// ErrorRestriction creates a new SendRestrictionFn function that returns a nil toAddr and an error. +func (s *SendRestrictionTestHelper) ErrorRestriction(message string) keeper.SendRestrictionFn { + return func(_ sdk.Context, fromAddr, toAddr sdk.AccAddress, coins sdk.Coins) (sdk.AccAddress, error) { + s.RecordCall(message, fromAddr, toAddr, coins) + return nil, errors.New(message) + } +} + +// SendRestrictionTestParams are parameters to test regarding calling a SendRestrictionFn. +type SendRestrictionTestParams struct { + // ExpNil is whether to expect the provided SendRestrictionFn to be nil. + // If it is true, the rest of these test params are ignored. + ExpNil bool + // FromAddr is the SendRestrictionFn fromAddr input. + FromAddr sdk.AccAddress + // ToAddr is the SendRestrictionFn toAddr input. + ToAddr sdk.AccAddress + // Coins is the SendRestrictionFn coins input. + Coins sdk.Coins + // ExpAddr is the expected return address. + ExpAddr sdk.AccAddress + // ExpErr is the expected return error string. + ExpErr string + // ExpCalls is the args of all the SendRestrictionFn calls that end up being made. + ExpCalls []*SendRestrictionArgs +} + +// TestActual tests the provided SendRestrictionFn using the provided test parameters. +func (s *SendRestrictionTestHelper) TestActual(t *testing.T, tp *SendRestrictionTestParams, actual keeper.SendRestrictionFn) { + t.Helper() + if tp.ExpNil { + require.Nil(t, actual, "resulting SendRestrictionFn") + } else { + require.NotNil(t, actual, "resulting SendRestrictionFn") + s.Calls = s.Calls[:0] + addr, err := actual(sdk.Context{}, tp.FromAddr, tp.ToAddr, tp.Coins) + if len(tp.ExpErr) != 0 { + assert.EqualError(t, err, tp.ExpErr, "composite SendRestrictionFn output error") + } else { + assert.NoError(t, err, "composite SendRestrictionFn output error") + } + assert.Equal(t, tp.ExpAddr, addr, "composite SendRestrictionFn output address") + assert.Equal(t, tp.ExpCalls, s.Calls, "args given to funcs in composite SendRestrictionFn") + } +} + +func TestSendRestriction_Then(t *testing.T) { + fromAddr := sdk.AccAddress("fromaddr____________") + addr0 := sdk.AccAddress("0addr_______________") + addr1 := sdk.AccAddress("1addr_______________") + addr2 := sdk.AccAddress("2addr_______________") + addr3 := sdk.AccAddress("3addr_______________") + addr4 := sdk.AccAddress("4addr_______________") + coins := sdk.NewCoins(sdk.NewInt64Coin("ecoin", 32), sdk.NewInt64Coin("fcoin", 64)) + + h := NewSendRestrictionTestHelper() + + tests := []struct { + name string + base keeper.SendRestrictionFn + second keeper.SendRestrictionFn + exp *SendRestrictionTestParams + }{ + { + name: "nil nil", + base: nil, + second: nil, + exp: &SendRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "nil noop", + base: nil, + second: h.NamedRestriction("noop"), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: addr1, + ExpCalls: h.NewCalls(h.NewArgs("noop", fromAddr, addr1, coins)), + }, + }, + { + name: "noop nil", + base: h.NamedRestriction("noop"), + second: nil, + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: addr1, + ExpCalls: h.NewCalls(h.NewArgs("noop", fromAddr, addr1, coins)), + }, + }, + { + name: "noop noop", + base: h.NamedRestriction("noop1"), + second: h.NamedRestriction("noop2"), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: addr1, + ExpCalls: h.NewCalls( + h.NewArgs("noop1", fromAddr, addr1, coins), + h.NewArgs("noop2", fromAddr, addr1, coins), + ), + }, + }, + { + name: "setter setter", + base: h.NewToRestriction("r1", addr2), + second: h.NewToRestriction("r2", addr3), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: addr3, + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr1, coins), + h.NewArgs("r2", fromAddr, addr2, coins), + ), + }, + }, + { + name: "setter error", + base: h.NewToRestriction("r1", addr2), + second: h.ErrorRestriction("this is a test error"), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: nil, + ExpErr: "this is a test error", + ExpCalls: h.NewCalls(h.NewArgs( + "r1", fromAddr, addr1, coins), + h.NewArgs("this is a test error", fromAddr, addr2, coins), + ), + }, + }, + { + name: "error setter", + base: h.ErrorRestriction("another test error"), + second: h.NewToRestriction("r2", addr3), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: nil, + ExpErr: "another test error", + ExpCalls: h.NewCalls(h.NewArgs("another test error", fromAddr, addr1, coins)), + }, + }, + { + name: "error error", + base: h.ErrorRestriction("first test error"), + second: h.ErrorRestriction("second test error"), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: nil, + ExpErr: "first test error", + ExpCalls: h.NewCalls(h.NewArgs("first test error", fromAddr, addr1, coins)), + }, + }, + { + name: "double chain", + base: keeper.ComposeSendRestrictions(h.NewToRestriction("r1", addr1), h.NewToRestriction("r2", addr2)), + second: keeper.ComposeSendRestrictions(h.NewToRestriction("r3", addr3), h.NewToRestriction("r4", addr4)), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr0, + Coins: coins, + ExpAddr: addr4, + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr0, coins), + h.NewArgs("r2", fromAddr, addr1, coins), + h.NewArgs("r3", fromAddr, addr2, coins), + h.NewArgs("r4", fromAddr, addr3, coins), + ), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var actual keeper.SendRestrictionFn + testFunc := func() { + actual = tc.base.Then(tc.second) + } + require.NotPanics(t, testFunc, "SendRestrictionFn.Then") + h.TestActual(t, tc.exp, actual) + }) + } +} + +func TestComposeSendRestrictions(t *testing.T) { + rz := func(rs ...keeper.SendRestrictionFn) []keeper.SendRestrictionFn { + return rs + } + fromAddr := sdk.AccAddress("fromaddr____________") + addr0 := sdk.AccAddress("0addr_______________") + addr1 := sdk.AccAddress("1addr_______________") + addr2 := sdk.AccAddress("2addr_______________") + addr3 := sdk.AccAddress("3addr_______________") + addr4 := sdk.AccAddress("4addr_______________") + coins := sdk.NewCoins(sdk.NewInt64Coin("gcoin", 128), sdk.NewInt64Coin("hcoin", 256)) + + h := NewSendRestrictionTestHelper() + + tests := []struct { + name string + input []keeper.SendRestrictionFn + exp *SendRestrictionTestParams + }{ + { + name: "nil list", + input: nil, + exp: &SendRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "empty list", + input: rz(), + exp: &SendRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "only nil entry", + input: rz(nil), + exp: &SendRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "five nil entries", + input: rz(nil, nil, nil, nil, nil), + exp: &SendRestrictionTestParams{ + ExpNil: true, + }, + }, + { + name: "only noop entry", + input: rz(h.NamedRestriction("noop")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr0, + Coins: coins, + ExpAddr: addr0, + ExpCalls: h.NewCalls(h.NewArgs("noop", fromAddr, addr0, coins)), + }, + }, + { + name: "only error entry", + input: rz(h.ErrorRestriction("test error")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr0, + Coins: coins, + ExpAddr: nil, + ExpErr: "test error", + ExpCalls: h.NewCalls(h.NewArgs("test error", fromAddr, addr0, coins)), + }, + }, + { + name: "noop nil nil", + input: rz(h.NamedRestriction("noop"), nil, nil), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr0, + Coins: coins, + ExpAddr: addr0, + ExpCalls: h.NewCalls(h.NewArgs("noop", fromAddr, addr0, coins)), + }, + }, + { + name: "nil noop nil", + input: rz(nil, h.NamedRestriction("noop"), nil), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: addr1, + ExpCalls: h.NewCalls(h.NewArgs("noop", fromAddr, addr1, coins)), + }, + }, + { + name: "nil nil noop", + input: rz(nil, nil, h.NamedRestriction("noop")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr2, + Coins: coins, + ExpAddr: addr2, + ExpCalls: h.NewCalls(h.NewArgs("noop", fromAddr, addr2, coins)), + }, + }, + { + name: "noop noop nil", + input: rz(h.NamedRestriction("r1"), h.NamedRestriction("r2"), nil), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr0, + Coins: coins, + ExpAddr: addr0, + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr0, coins), + h.NewArgs("r2", fromAddr, addr0, coins), + ), + }, + }, + { + name: "noop nil noop", + input: rz(h.NamedRestriction("r1"), nil, h.NamedRestriction("r2")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr1, + Coins: coins, + ExpAddr: addr1, + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr1, coins), + h.NewArgs("r2", fromAddr, addr1, coins), + ), + }, + }, + { + name: "nil noop noop", + input: rz(nil, h.NamedRestriction("r1"), h.NamedRestriction("r2")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr2, + Coins: coins, + ExpAddr: addr2, + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr2, coins), + h.NewArgs("r2", fromAddr, addr2, coins), + ), + }, + }, + { + name: "noop noop noop", + input: rz(h.NamedRestriction("r1"), h.NamedRestriction("r2"), h.NamedRestriction("r3")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr3, + Coins: coins, + ExpAddr: addr3, + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr3, coins), + h.NewArgs("r2", fromAddr, addr3, coins), + h.NewArgs("r3", fromAddr, addr3, coins), + ), + }, + }, + { + name: "err noop noop", + input: rz(h.ErrorRestriction("first error"), h.NamedRestriction("r2"), h.NamedRestriction("r3")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr4, + Coins: coins, + ExpAddr: nil, + ExpErr: "first error", + ExpCalls: h.NewCalls(h.NewArgs("first error", fromAddr, addr4, coins)), + }, + }, + { + name: "noop err noop", + input: rz(h.NamedRestriction("r1"), h.ErrorRestriction("second error"), h.NamedRestriction("r3")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr4, + Coins: coins, + ExpAddr: nil, + ExpErr: "second error", + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr4, coins), + h.NewArgs("second error", fromAddr, addr4, coins), + ), + }, + }, + { + name: "noop noop err", + input: rz(h.NamedRestriction("r1"), h.NamedRestriction("r2"), h.ErrorRestriction("third error")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr4, + Coins: coins, + ExpAddr: nil, + ExpErr: "third error", + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr4, coins), + h.NewArgs("r2", fromAddr, addr4, coins), + h.NewArgs("third error", fromAddr, addr4, coins), + ), + }, + }, + { + name: "new-to err err", + input: rz(h.NewToRestriction("r1", addr0), h.ErrorRestriction("second error"), h.ErrorRestriction("third error")), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr4, + Coins: coins, + ExpAddr: nil, + ExpErr: "second error", + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr4, coins), + h.NewArgs("second error", fromAddr, addr0, coins), + ), + }, + }, + { + name: "big bang", + input: rz( + h.NamedRestriction("r1"), nil, h.NewToRestriction("r2", addr1), // Called with orig toAddr. + nil, h.NamedRestriction("r3"), h.NewToRestriction("r4", addr2), // Called with addr1 toAddr. + h.NewToRestriction("r5", addr3), // Called with addr2 toAddr. + nil, h.NamedRestriction("r6"), h.NewToRestriction("r7", addr4), // Called with addr3 toAddr. + nil, h.NamedRestriction("r8"), nil, nil, h.ErrorRestriction("oops, an error"), // Called with addr4 toAddr. + h.NewToRestriction("r9", addr0), nil, h.NamedRestriction("ra"), // Not called. + ), + exp: &SendRestrictionTestParams{ + FromAddr: fromAddr, + ToAddr: addr0, + Coins: coins, + ExpAddr: nil, + ExpErr: "oops, an error", + ExpCalls: h.NewCalls( + h.NewArgs("r1", fromAddr, addr0, coins), + h.NewArgs("r2", fromAddr, addr0, coins), + h.NewArgs("r3", fromAddr, addr1, coins), + h.NewArgs("r4", fromAddr, addr1, coins), + h.NewArgs("r5", fromAddr, addr2, coins), + h.NewArgs("r6", fromAddr, addr3, coins), + h.NewArgs("r7", fromAddr, addr3, coins), + h.NewArgs("r8", fromAddr, addr4, coins), + h.NewArgs("oops, an error", fromAddr, addr4, coins), + ), + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var actual keeper.SendRestrictionFn + testFunc := func() { + actual = keeper.ComposeSendRestrictions(tc.input...) + } + require.NotPanics(t, testFunc, "ComposeSendRestrictions") + h.TestActual(t, tc.exp, actual) + }) + } +} + +func TestNoOpSendRestrictionFn(t *testing.T) { + expAddr := sdk.AccAddress("__expectedaddr__") + var addr sdk.AccAddress + var err error + testFunc := func() { + addr, err = keeper.NoOpSendRestrictionFn(sdk.Context{}, sdk.AccAddress("first_addr"), expAddr, sdk.Coins{}) + } + require.NotPanics(t, testFunc, "NoOpSendRestrictionFn") + assert.NoError(t, err, "NoOpSendRestrictionFn error") + assert.Equal(t, expAddr, addr, "NoOpSendRestrictionFn addr") +} diff --git a/x/bank/keeper/send.go b/x/bank/keeper/send.go index 356be5ac1c93..c98aa8920f82 100644 --- a/x/bank/keeper/send.go +++ b/x/bank/keeper/send.go @@ -17,7 +17,10 @@ import ( type SendKeeper interface { ViewKeeper - InputOutputCoins(ctx sdk.Context, inputs []types.Input, outputs []types.Output) error + AppendSendRestriction(restriction SendRestrictionFn) + PrependSendRestriction(restriction SendRestrictionFn) + + InputOutputCoins(ctx sdk.Context, input types.Input, outputs []types.Output) error SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error GetParams(ctx sdk.Context) types.Params @@ -43,21 +46,34 @@ type BaseSendKeeper struct { // list of addresses that are restricted from receiving transactions blockedAddrs map[string]bool + + sendRestriction *SendRestriction } func NewBaseSendKeeper( cdc codec.BinaryCodec, storeKey storetypes.StoreKey, ak types.AccountKeeper, paramSpace paramtypes.Subspace, blockedAddrs map[string]bool, ) BaseSendKeeper { return BaseSendKeeper{ - BaseViewKeeper: NewBaseViewKeeper(cdc, storeKey, ak), - cdc: cdc, - ak: ak, - storeKey: storeKey, - paramSpace: paramSpace, - blockedAddrs: blockedAddrs, + BaseViewKeeper: NewBaseViewKeeper(cdc, storeKey, ak), + cdc: cdc, + ak: ak, + storeKey: storeKey, + paramSpace: paramSpace, + blockedAddrs: blockedAddrs, + sendRestriction: NewSendRestriction(), } } +// AppendSendRestriction adds the provided SendRestrictionFn to run after previously provided restrictions. +func (k BaseSendKeeper) AppendSendRestriction(restriction SendRestrictionFn) { + k.sendRestriction.Append(restriction) +} + +// PrependSendRestriction adds the provided SendRestrictionFn to run before previously provided restrictions. +func (k BaseSendKeeper) PrependSendRestriction(restriction SendRestrictionFn) { + k.sendRestriction.Prepend(restriction) +} + // GetParams returns the total set of bank parameters. func (k BaseSendKeeper) GetParams(ctx sdk.Context) (params types.Params) { k.paramSpace.GetParamSet(ctx, ¶ms) @@ -69,42 +85,47 @@ func (k BaseSendKeeper) SetParams(ctx sdk.Context, params types.Params) { k.paramSpace.SetParamSet(ctx, ¶ms) } -// InputOutputCoins performs multi-send functionality. It accepts a series of -// inputs that correspond to a series of outputs. It returns an error if the -// inputs and outputs don't lineup or if any single transfer of tokens fails. -func (k BaseSendKeeper) InputOutputCoins(ctx sdk.Context, inputs []types.Input, outputs []types.Output) error { +// InputOutputCoins performs multi-send functionality. It accepts an +// input that corresponds to a series of outputs. It returns an error if the +// input and outputs don't line up or if any single transfer of tokens fails. +func (k BaseSendKeeper) InputOutputCoins(ctx sdk.Context, input types.Input, outputs []types.Output) error { // Safety check ensuring that when sending coins the keeper must maintain the // Check supply invariant and validity of Coins. - if err := types.ValidateInputsOutputs(inputs, outputs); err != nil { + if err := types.ValidateInputOutputs(input, outputs); err != nil { return err } - for _, in := range inputs { - inAddress, err := sdk.AccAddressFromBech32(in.Address) - if err != nil { - return err - } - - err = k.subUnlockedCoins(ctx, inAddress, in.Coins) - if err != nil { - return err - } + inAddress, err := sdk.AccAddressFromBech32(input.Address) + if err != nil { + return err + } - ctx.EventManager().EmitEvent( - sdk.NewEvent( - sdk.EventTypeMessage, - sdk.NewAttribute(types.AttributeKeySender, in.Address), - ), - ) + err = k.subUnlockedCoins(ctx, inAddress, input.Coins) + if err != nil { + return err } + ctx.EventManager().EmitEvent( + sdk.NewEvent( + sdk.EventTypeMessage, + sdk.NewAttribute(types.AttributeKeySender, input.Address), + ), + ) + for _, out := range outputs { outAddress, err := sdk.AccAddressFromBech32(out.Address) if err != nil { return err } - err = k.addCoins(ctx, outAddress, out.Coins) - if err != nil { + + if k.sendRestriction.Fn != nil { + outAddress, err = k.sendRestriction.Fn(ctx, inAddress, outAddress, out.Coins) + if err != nil { + return err + } + } + + if err := k.addCoins(ctx, outAddress, out.Coins); err != nil { return err } @@ -133,6 +154,14 @@ func (k BaseSendKeeper) InputOutputCoins(ctx sdk.Context, inputs []types.Input, // SendCoins transfers amt coins from a sending account to a receiving account. // An error is returned upon failure. func (k BaseSendKeeper) SendCoins(ctx sdk.Context, fromAddr sdk.AccAddress, toAddr sdk.AccAddress, amt sdk.Coins) error { + if k.sendRestriction.Fn != nil { + var err error + toAddr, err = k.sendRestriction.Fn(ctx, fromAddr, toAddr, amt) + if err != nil { + return err + } + } + err := k.subUnlockedCoins(ctx, fromAddr, amt) if err != nil { return err @@ -287,6 +316,7 @@ func (k BaseSendKeeper) setBalance(ctx sdk.Context, addr sdk.AccAddress, balance if err != nil { return err } + accountStore.Set([]byte(balance.Denom), amount) // Store a reverse index from denomination to account address with a diff --git a/x/bank/simulation/operations.go b/x/bank/simulation/operations.go index d474160b70fa..39bd31d94d2c 100644 --- a/x/bank/simulation/operations.go +++ b/x/bank/simulation/operations.go @@ -176,7 +176,7 @@ func SimulateMsgMultiSend(ak types.AccountKeeper, bk keeper.Keeper) simtypes.Ope accs []simtypes.Account, chainID string, ) (simtypes.OperationMsg, []simtypes.FutureOperation, error) { // random number of inputs/outputs between [1, 3] - inputs := make([]types.Input, r.Intn(3)+1) + inputs := make([]types.Input, r.Intn(1)+1) outputs := make([]types.Output, r.Intn(3)+1) // collect signer privKeys diff --git a/x/bank/simulation/operations_test.go b/x/bank/simulation/operations_test.go index f50c2b944885..3331c9d004e7 100644 --- a/x/bank/simulation/operations_test.go +++ b/x/bank/simulation/operations_test.go @@ -112,12 +112,12 @@ func (suite *SimTestSuite) TestSimulateMsgMultiSend() { types.ModuleCdc.UnmarshalJSON(operationMsg.Msg, &msg) require.True(operationMsg.OK) - require.Len(msg.Inputs, 3) - require.Equal("cosmos1p8wcgrjr4pjju90xg6u9cgq55dxwq8j7u4x9a0", msg.Inputs[1].Address) - require.Equal("185121068stake", msg.Inputs[1].Coins.String()) + require.Len(msg.Inputs, 1) + require.Equal("cosmos1tnh2q55v8wyygtt9srz5safamzdengsnqeycj3", msg.Inputs[0].Address) + require.Equal("114949958stake", msg.Inputs[0].Coins.String()) require.Len(msg.Outputs, 2) require.Equal("cosmos1ghekyjucln7y67ntx7cf27m9dpuxxemn4c8g4r", msg.Outputs[1].Address) - require.Equal("260469617stake", msg.Outputs[1].Coins.String()) + require.Equal("107287087stake", msg.Outputs[1].Coins.String()) require.Equal(types.TypeMsgMultiSend, msg.Type()) require.Equal(types.ModuleName, msg.Route()) require.Len(futureOperations, 0) diff --git a/x/bank/types/errors.go b/x/bank/types/errors.go index 8446d957b678..13af3dc4c7aa 100644 --- a/x/bank/types/errors.go +++ b/x/bank/types/errors.go @@ -12,4 +12,6 @@ var ( ErrSendDisabled = sdkerrors.Register(ModuleName, 5, "send transactions are disabled") ErrDenomMetadataNotFound = sdkerrors.Register(ModuleName, 6, "client denom metadata not found") ErrInvalidKey = sdkerrors.Register(ModuleName, 7, "invalid key") + ErrDuplicateEntry = sdkerrors.Register(ModuleName, 8, "duplicate entry") + ErrMultipleSenders = sdkerrors.Register(ModuleName, 9, "multiple senders not allowed") ) diff --git a/x/bank/types/msgs.go b/x/bank/types/msgs.go index d9806b55835e..d7b7392e9231 100644 --- a/x/bank/types/msgs.go +++ b/x/bank/types/msgs.go @@ -73,17 +73,22 @@ func (msg MsgMultiSend) Type() string { return TypeMsgMultiSend } // ValidateBasic Implements Msg. func (msg MsgMultiSend) ValidateBasic() error { - // this just makes sure all the inputs and outputs are properly formatted, + // this just makes sure the input and all the outputs are properly formatted, // not that they actually have the money inside + if len(msg.Inputs) == 0 { return ErrNoInputs } + if len(msg.Inputs) != 1 { + return ErrMultipleSenders + } + if len(msg.Outputs) == 0 { return ErrNoOutputs } - return ValidateInputsOutputs(msg.Inputs, msg.Outputs) + return ValidateInputOutputs(msg.Inputs[0], msg.Outputs) } // GetSignBytes Implements Msg. @@ -156,18 +161,15 @@ func NewOutput(addr sdk.AccAddress, coins sdk.Coins) Output { } } -// ValidateInputsOutputs validates that each respective input and output is +// ValidateInputOutputs validates that each respective input and output is // valid and that the sum of inputs is equal to the sum of outputs. -func ValidateInputsOutputs(inputs []Input, outputs []Output) error { +func ValidateInputOutputs(input Input, outputs []Output) error { var totalIn, totalOut sdk.Coins - for _, in := range inputs { - if err := in.ValidateBasic(); err != nil { - return err - } - - totalIn = totalIn.Add(in.Coins...) + if err := input.ValidateBasic(); err != nil { + return err } + totalIn = input.Coins for _, out := range outputs { if err := out.ValidateBasic(); err != nil { diff --git a/x/bank/types/msgs_test.go b/x/bank/types/msgs_test.go index 523a57f28f38..3a004009a7a2 100644 --- a/x/bank/types/msgs_test.go +++ b/x/bank/types/msgs_test.go @@ -179,18 +179,20 @@ func TestMsgMultiSendValidation(t *testing.T) { var emptyAddr sdk.AccAddress cases := []struct { - valid bool - tx MsgMultiSend + valid bool + tx MsgMultiSend + expErrMsg string }{ - {false, MsgMultiSend{}}, // no input or output - {false, MsgMultiSend{Inputs: []Input{input1}}}, // just input - {false, MsgMultiSend{Outputs: []Output{output1}}}, // just output + {false, MsgMultiSend{}, "no inputs to send transaction"}, // no input or output + {false, MsgMultiSend{Inputs: []Input{input1}}, "no outputs to send transaction"}, // just input + {false, MsgMultiSend{Outputs: []Output{output1}}, "no inputs to send transaction"}, // just output { false, MsgMultiSend{ Inputs: []Input{NewInput(emptyAddr, atom123)}, // invalid input Outputs: []Output{output1}, }, + "invalid input address", }, { false, @@ -198,6 +200,7 @@ func TestMsgMultiSendValidation(t *testing.T) { Inputs: []Input{input1}, Outputs: []Output{{emptyAddr.String(), atom123}}, // invalid output }, + "invalid output address", }, { false, @@ -205,6 +208,7 @@ func TestMsgMultiSendValidation(t *testing.T) { Inputs: []Input{input1}, Outputs: []Output{output2}, // amounts dont match }, + "sum inputs != sum outputs", }, { true, @@ -212,13 +216,15 @@ func TestMsgMultiSendValidation(t *testing.T) { Inputs: []Input{input1}, Outputs: []Output{output1}, }, + "", }, { - true, + false, MsgMultiSend{ Inputs: []Input{input1, input2}, Outputs: []Output{outputMulti}, }, + "multiple senders not allowed", }, { true, @@ -226,6 +232,7 @@ func TestMsgMultiSendValidation(t *testing.T) { Inputs: []Input{NewInput(addr2, atom123.MulInt(types.NewInt(2)))}, Outputs: []Output{output1, output1}, }, + "", }, } @@ -233,8 +240,10 @@ func TestMsgMultiSendValidation(t *testing.T) { err := tc.tx.ValidateBasic() if tc.valid { require.Nil(t, err, "%d: %+v", i, err) + require.Nil(t, err) } else { require.NotNil(t, err, "%d", i) + require.Contains(t, err.Error(), tc.expErrMsg) } } }