Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Problem: mempool iteration is not thread safe #699

Merged
merged 8 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
### Bug Fixes

* (x/bank) [#20028](https://github.com/cosmos/cosmos-sdk/pull/20028) Align query with multi denoms for send-enabled.
* (baseapp) [#699](https://github.com/crypto-org-chain/cosmos-sdk/pull/699) Fix data race in mempool iteration.

## [Unreleased-Upstream]

Expand Down
32 changes: 19 additions & 13 deletions baseapp/abci_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,16 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
return &abci.ResponsePrepareProposal{Txs: h.txSelector.SelectedTxs(ctx)}, nil
}

iterator := h.mempool.Select(ctx, req.Txs)
selectedTxsSignersSeqs := make(map[string]uint64)
var selectedTxsNums int
for iterator != nil {
memTx := iterator.Tx()
signerData, err := h.signerExtAdapter.GetSigners(memTx.Tx)
var (
err error
selectedTxsNums int
)
h.mempool.SelectBy(ctx, req.Txs, func(memTx mempool.Tx) bool {
var signerData []mempool.SignerData
signerData, err = h.signerExtAdapter.GetSigners(memTx.Tx)
if err != nil {
return nil, err
return false
}

// If the signers aren't in selectedTxsSignersSeqs then we haven't seen them before
Expand All @@ -315,24 +317,24 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
txSignersSeqs[signer.Signer.String()] = signer.Sequence
}
if !shouldAdd {
iterator = iterator.Next()
continue
return true
}

// NOTE: Since transaction verification was already executed in CheckTx,
// which calls mempool.Insert, in theory everything in the pool should be
// valid. But some mempool implementations may insert invalid txs, so we
// check again.
txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx.Tx)
var txBz []byte
txBz, err = h.txVerifier.PrepareProposalVerifyTx(memTx.Tx)
if err != nil {
err := h.mempool.Remove(memTx.Tx)
err = h.mempool.Remove(memTx.Tx)
if err != nil && !errors.Is(err, mempool.ErrTxNotFound) {
return nil, err
return false
}
} else {
stop := h.txSelector.SelectTxForProposal(ctx, uint64(req.MaxTxBytes), maxBlockGas, memTx.Tx, txBz, memTx.GasWanted)
if stop {
break
return false
}

txsLen := len(h.txSelector.SelectedTxs(ctx))
Expand All @@ -353,7 +355,11 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
selectedTxsNums = txsLen
}

iterator = iterator.Next()
return true
})

if err != nil {
return nil, err
}

return &abci.ResponsePrepareProposal{Txs: h.txSelector.SelectedTxs(ctx)}, nil
Expand Down
3 changes: 3 additions & 0 deletions types/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ type Mempool interface {
// closed by the caller.
Select(context.Context, [][]byte) Iterator

// SelectBy use callback to iterate over the mempool.
SelectBy(context.Context, [][]byte, func(Tx) bool)

// CountTx returns the number of transactions currently in the mempool.
CountTx() int

Expand Down
1 change: 1 addition & 0 deletions types/mempool/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ type NoOpMempool struct{}
func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil }
func (NoOpMempool) InsertWithGasWanted(context.Context, sdk.Tx, uint64) error { return nil }
func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil }
func (NoOpMempool) SelectBy(context.Context, [][]byte, func(Tx) bool) {}
func (NoOpMempool) CountTx() int { return 0 }
func (NoOpMempool) Remove(sdk.Tx) error { return nil }
21 changes: 21 additions & 0 deletions types/mempool/priority_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,27 @@ func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterato
return iterator.iteratePriority()
}

func (mp *PriorityNonceMempool[C]) SelectBy(_ context.Context, _ [][]byte, callback func(Tx) bool) {
mp.mtx.Lock()
defer mp.mtx.Unlock()

if mp.priorityIndex.Len() == 0 {
return
}

mp.reorderPriorityTies()

iterator := &PriorityNonceIterator[C]{
mempool: mp,
senderCursors: make(map[string]*skiplist.Element),
}

iter := iterator.iteratePriority()
for iter != nil && callback(iter.Tx()) {
iter = iter.Next()
}
}

type reorderKey[C comparable] struct {
deleteKey txMeta[C]
insertKey txMeta[C]
Expand Down
85 changes: 85 additions & 0 deletions types/mempool/priority_nonce_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package mempool_test

import (
"context"
"fmt"
"math"
"math/rand"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -396,6 +398,89 @@ func (s *MempoolTestSuite) TestIterator() {
}
}

func (s *MempoolTestSuite) TestIteratorConcurrency() {
t := s.T()
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

tests := []struct {
txs []txSpec
fail bool
}{
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: 8, n: 2, a: sb},
},
},
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: math.MinInt64, n: 2, a: sb},
},
},
}

for i, tt := range tests {
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
pool := mempool.DefaultPriorityMempool()

// create test txs and insert into mempool
for i, ts := range tt.txs {
tx := testTx{id: i, priority: int64(ts.p), nonce: uint64(ts.n), address: ts.a}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}

// iterate through txs
stdCtx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

id := len(tt.txs)
for {
select {
case <-stdCtx.Done():
return
default:
id++
tx := testTx{id: id, priority: int64(rand.Intn(100)), nonce: uint64(id), address: sa}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}
}
}()

var i int
pool.SelectBy(ctx, nil, func(memTx mempool.Tx) bool {
tx := memTx.Tx.(testTx)
if tx.id < len(tt.txs) {
require.Equal(t, tt.txs[tx.id].p, int(tx.priority))
require.Equal(t, tt.txs[tx.id].n, int(tx.nonce))
require.Equal(t, tt.txs[tx.id].a, tx.address)
i++
}
return i < len(tt.txs)
})
require.Equal(t, i, len(tt.txs))
cancel()
wg.Wait()
})
}
}

func (s *MempoolTestSuite) TestPriorityTies() {
ctx := sdk.NewContext(nil, cmtproto.Header{}, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 3)
Expand Down
33 changes: 33 additions & 0 deletions types/mempool/sender_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,39 @@
return iter.Next()
}

func (snm *SenderNonceMempool) SelectBy(_ context.Context, _ [][]byte, callback func(Tx) bool) {
snm.mtx.Lock()
defer snm.mtx.Unlock()
var senders []string

senderCursors := make(map[string]*skiplist.Element)
orderedSenders := skiplist.New(skiplist.String)

// #nosec
for s := range snm.senders {
orderedSenders.Set(s, s)
}
Fixed Show fixed Hide fixed

s := orderedSenders.Front()
for s != nil {
sender := s.Value.(string)
senders = append(senders, sender)
senderCursors[sender] = snm.senders[sender].Front()
s = s.Next()
}

iterator := &senderNonceMempoolIterator{
senders: senders,
rnd: snm.rnd,
senderCursors: senderCursors,
}

iter := iterator.Next()
for iter != nil && callback(iter.Tx()) {
iter = iter.Next()
}
}

// CountTx returns the total count of txs in the mempool.
func (snm *SenderNonceMempool) CountTx() int {
snm.mtx.Lock()
Expand Down
Loading