From f92b843786a63515fd960a42e24edb168dcd3f2c Mon Sep 17 00:00:00 2001 From: Steven Landers Date: Mon, 29 Jan 2024 02:29:17 -0500 Subject: [PATCH] [EVM] prevent duplicate txs from getting inserted (#196) * prevent duplicates in mempool * use timestamp in priority queue --- internal/consensus/mempool_test.go | 5 +- internal/mempool/mempool.go | 32 ++++++--- internal/mempool/priority_queue.go | 90 +++++++++++++++++++------ internal/mempool/priority_queue_test.go | 1 + internal/mempool/tx.go | 6 ++ 5 files changed, 102 insertions(+), 32 deletions(-) diff --git a/internal/consensus/mempool_test.go b/internal/consensus/mempool_test.go index 84badbeeb..6ef3849ad 100644 --- a/internal/consensus/mempool_test.go +++ b/internal/consensus/mempool_test.go @@ -139,7 +139,7 @@ func checkTxsRange(ctx context.Context, t *testing.T, cs *State, start, end int) var rCode uint32 err := assertMempool(t, cs.txNotifier).CheckTx(ctx, txBytes, func(r *abci.ResponseCheckTx) { rCode = r.Code }, mempool.TxInfo{}) require.NoError(t, err, "error after checkTx") - require.Equal(t, code.CodeTypeOK, rCode, "checkTx code is error, txBytes %X", txBytes) + require.Equal(t, code.CodeTypeOK, rCode, "checkTx code is error, txBytes %X, index=%d", txBytes, i) } } @@ -166,7 +166,7 @@ func TestMempoolTxConcurrentWithCommit(t *testing.T) { require.NoError(t, err) newBlockHeaderCh := subscribe(ctx, t, cs.eventBus, types.EventQueryNewBlockHeader) - const numTxs int64 = 100 + const numTxs int64 = 50 go checkTxsRange(ctx, t, cs, 0, int(numTxs)) startTestRound(ctx, cs, cs.roundState.Height(), cs.roundState.Round()) @@ -331,7 +331,6 @@ func txAsUint64(tx []byte) uint64 { func (app *CounterApplication) Commit(context.Context) (*abci.ResponseCommit, error) { app.mu.Lock() defer app.mu.Unlock() - app.mempoolTxCount = app.txCount return &abci.ResponseCommit{}, nil } diff --git a/internal/mempool/mempool.go b/internal/mempool/mempool.go index 9bed03a8a..612a0bc48 100644 --- a/internal/mempool/mempool.go +++ b/internal/mempool/mempool.go @@ -332,6 +332,11 @@ func (txmp *TxMempool) CheckTx( return nil } +func (txmp *TxMempool) isInMempool(tx types.Tx) bool { + existingTx := txmp.txStore.GetTxByHash(tx.Key()) + return existingTx != nil && !existingTx.removed +} + func (txmp *TxMempool) RemoveTxByKey(txKey types.TxKey) error { txmp.Lock() defer txmp.Unlock() @@ -635,15 +640,17 @@ func (txmp *TxMempool) addNewTransaction(wtx *WrappedTx, res *abci.ResponseCheck txmp.metrics.Size.Set(float64(txmp.Size())) txmp.metrics.PendingSize.Set(float64(txmp.PendingSize())) - txmp.insertTx(wtx) - txmp.logger.Debug( - "inserted good transaction", - "priority", wtx.priority, - "tx", fmt.Sprintf("%X", wtx.tx.Hash()), - "height", txmp.height, - "num_txs", txmp.Size(), - ) - txmp.notifyTxsAvailable() + if txmp.insertTx(wtx) { + txmp.logger.Debug( + "inserted good transaction", + "priority", wtx.priority, + "tx", fmt.Sprintf("%X", wtx.tx.Hash()), + "height", txmp.height, + "num_txs", txmp.Size(), + ) + txmp.notifyTxsAvailable() + } + return nil } @@ -809,7 +816,11 @@ func (txmp *TxMempool) canAddTx(wtx *WrappedTx) error { return nil } -func (txmp *TxMempool) insertTx(wtx *WrappedTx) { +func (txmp *TxMempool) insertTx(wtx *WrappedTx) bool { + if txmp.isInMempool(wtx.tx) { + return false + } + txmp.txStore.SetTx(wtx) txmp.priorityIndex.PushTx(wtx) txmp.heightIndex.Insert(wtx) @@ -822,6 +833,7 @@ func (txmp *TxMempool) insertTx(wtx *WrappedTx) { wtx.gossipEl = gossipEl atomic.AddInt64(&txmp.sizeBytes, int64(wtx.Size())) + return true } func (txmp *TxMempool) removeTx(wtx *WrappedTx, removeFromCache bool) { diff --git a/internal/mempool/priority_queue.go b/internal/mempool/priority_queue.go index 6dbbfe9b2..3de4e810f 100644 --- a/internal/mempool/priority_queue.go +++ b/internal/mempool/priority_queue.go @@ -30,7 +30,7 @@ func binarySearch(queue []*WrappedTx, tx *WrappedTx) int { low, high := 0, len(queue) for low < high { mid := low + (high-low)/2 - if queue[mid].evmNonce <= tx.evmNonce { + if queue[mid].IsBefore(tx) { low = mid + 1 } else { high = mid @@ -118,11 +118,6 @@ func (pq *TxPriorityQueue) removeQueuedEvmTxUnsafe(tx *WrappedTx) { pq.evmQueue[tx.evmAddress] = append(queue[:i], queue[i+1:]...) if len(pq.evmQueue[tx.evmAddress]) == 0 { delete(pq.evmQueue, tx.evmAddress) - } else { - // only if removing the first item, then push next onto queue - if i == 0 { - heap.Push(pq, pq.evmQueue[tx.evmAddress][0]) - } } break } @@ -132,7 +127,7 @@ func (pq *TxPriorityQueue) removeQueuedEvmTxUnsafe(tx *WrappedTx) { func (pq *TxPriorityQueue) findTxIndexUnsafe(tx *WrappedTx) (int, bool) { for i, t := range pq.txs { - if t == tx { + if t.tx.Key() == tx.tx.Key() { return i, true } } @@ -146,9 +141,13 @@ func (pq *TxPriorityQueue) RemoveTx(tx *WrappedTx) { if idx, ok := pq.findTxIndexUnsafe(tx); ok { heap.Remove(pq, idx) - } - - if tx.isEVM { + if tx.isEVM { + pq.removeQueuedEvmTxUnsafe(tx) + if len(pq.evmQueue[tx.evmAddress]) > 0 { + heap.Push(pq, pq.evmQueue[tx.evmAddress][0]) + } + } + } else if tx.isEVM { pq.removeQueuedEvmTxUnsafe(tx) } } @@ -159,6 +158,7 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) { return } + // if there aren't other waiting txs, init and return queue, exists := pq.evmQueue[tx.evmAddress] if !exists { pq.evmQueue[tx.evmAddress] = []*WrappedTx{tx} @@ -166,29 +166,45 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) { return } + // this item is on the heap at the moment first := queue[0] - if tx.evmNonce < first.evmNonce { + + // the queue's first item (and ONLY the first item) must be on the heap + // if this tx is before the first item, then we need to remove the first + // item from the heap + if tx.IsBefore(first) { if idx, ok := pq.findTxIndexUnsafe(first); ok { heap.Remove(pq, idx) } heap.Push(pq, tx) } - pq.evmQueue[tx.evmAddress] = insertToEVMQueue(queue, tx, binarySearch(queue, tx)) - } // These are available if we need to test the invariant checks // these can be used to troubleshoot invariant violations //func (pq *TxPriorityQueue) checkInvariants(msg string) { -// // uniqHashes := make(map[string]bool) -// for _, tx := range pq.txs { +// for idx, tx := range pq.txs { +// if tx == nil { +// pq.print() +// panic(fmt.Sprintf("DEBUG PRINT: found nil item on heap: idx=%d\n", idx)) +// } +// if tx.tx == nil { +// pq.print() +// panic(fmt.Sprintf("DEBUG PRINT: found nil tx.tx on heap: idx=%d\n", idx)) +// } // if _, ok := uniqHashes[fmt.Sprintf("%x", tx.tx.Key())]; ok { // pq.print() // panic(fmt.Sprintf("INVARIANT (%s): duplicate hash=%x in heap", msg, tx.tx.Key())) // } // uniqHashes[fmt.Sprintf("%x", tx.tx.Key())] = true +// +// //if _, ok := pq.keys[tx.tx.Key()]; !ok { +// // pq.print() +// // panic(fmt.Sprintf("INVARIANT (%s): tx in heap but not in keys hash=%x", msg, tx.tx.Key())) +// //} +// // if tx.isEVM { // if queue, ok := pq.evmQueue[tx.evmAddress]; ok { // if queue[0].tx.Key() != tx.tx.Key() { @@ -213,6 +229,10 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) { // panic(fmt.Sprintf("INVARIANT (%s): did not find tx[0] hash=%x nonce=%d in heap", msg, tx.tx.Key(), tx.evmNonce)) // } // } +// //if _, ok := pq.keys[tx.tx.Key()]; !ok { +// // pq.print() +// // panic(fmt.Sprintf("INVARIANT (%s): tx in heap but not in keys hash=%x", msg, tx.tx.Key())) +// //} // if _, ok := hashes[fmt.Sprintf("%x", tx.tx.Key())]; ok { // pq.print() // panic(fmt.Sprintf("INVARIANT (%s): duplicate hash=%x in queue nonce=%d", msg, tx.tx.Key(), tx.evmNonce)) @@ -224,13 +244,31 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) { // for debugging situations where invariant violations occur //func (pq *TxPriorityQueue) print() { +// fmt.Println("PRINT PRIORITY QUEUE ****************** ") // for _, tx := range pq.txs { -// fmt.Printf("DEBUG PRINT: heap: nonce=%d, hash=%x\n", tx.evmNonce, tx.tx.Key()) +// if tx == nil { +// fmt.Printf("DEBUG PRINT: heap (nil): nonce=?, hash=?\n") +// continue +// } +// if tx.tx == nil { +// fmt.Printf("DEBUG PRINT: heap (%s): nonce=%d, tx.tx is nil \n", tx.evmAddress, tx.evmNonce) +// continue +// } +// fmt.Printf("DEBUG PRINT: heap (%s): nonce=%d, hash=%x, time=%d\n", tx.evmAddress, tx.evmNonce, tx.tx.Key(), tx.timestamp.UnixNano()) // } // -// for _, queue := range pq.evmQueue { +// for addr, queue := range pq.evmQueue { // for idx, tx := range queue { -// fmt.Printf("DEBUG PRINT: evmQueue[%d]: nonce=%d, hash=%x\n", idx, tx.evmNonce, tx.tx.Key()) +// if tx == nil { +// fmt.Printf("DEBUG PRINT: found nil item on evmQueue(%s): idx=%d\n", addr, idx) +// continue +// } +// if tx.tx == nil { +// fmt.Printf("DEBUG PRINT: found nil tx.tx on evmQueue(%s): idx=%d\n", addr, idx) +// continue +// } +// +// fmt.Printf("DEBUG PRINT: evmQueue(%s)[%d]: nonce=%d, hash=%x, time=%d\n", tx.evmAddress, idx, tx.evmNonce, tx.tx.Key(), tx.timestamp.UnixNano()) // } // } //} @@ -239,6 +277,7 @@ func (pq *TxPriorityQueue) pushTxUnsafe(tx *WrappedTx) { func (pq *TxPriorityQueue) PushTx(tx *WrappedTx) { pq.mtx.Lock() defer pq.mtx.Unlock() + pq.pushTxUnsafe(tx) } @@ -246,19 +285,31 @@ func (pq *TxPriorityQueue) popTxUnsafe() *WrappedTx { if len(pq.txs) == 0 { return nil } + + // remove the first item from the heap x := heap.Pop(pq) if x == nil { return nil } - tx := x.(*WrappedTx) + // non-evm transactions do not have txs waiting on a nonce if !tx.isEVM { return tx } + // evm transactions can have txs waiting on this nonce + // if there are any, we should replace the heap with the next nonce + // for the address + + // remove the first item from the evmQueue pq.removeQueuedEvmTxUnsafe(tx) + // if there is a next item, now it can be added to the heap + if len(pq.evmQueue[tx.evmAddress]) > 0 { + heap.Push(pq, pq.evmQueue[tx.evmAddress][0]) + } + return tx } @@ -266,6 +317,7 @@ func (pq *TxPriorityQueue) popTxUnsafe() *WrappedTx { func (pq *TxPriorityQueue) PopTx() *WrappedTx { pq.mtx.Lock() defer pq.mtx.Unlock() + return pq.popTxUnsafe() } diff --git a/internal/mempool/priority_queue_test.go b/internal/mempool/priority_queue_test.go index c1e17d278..8cb6d2e1a 100644 --- a/internal/mempool/priority_queue_test.go +++ b/internal/mempool/priority_queue_test.go @@ -196,6 +196,7 @@ func TestTxPriorityQueue(t *testing.T) { pq.PushTx(&WrappedTx{ priority: 1000, timestamp: now, + tx: []byte(fmt.Sprintf("%d", time.Now().UnixNano())), }) require.Equal(t, 1001, pq.NumTxs()) diff --git a/internal/mempool/tx.go b/internal/mempool/tx.go index 025cfda73..13ddb0a12 100644 --- a/internal/mempool/tx.go +++ b/internal/mempool/tx.go @@ -74,6 +74,12 @@ type WrappedTx struct { isEVM bool } +// IsBefore returns true if the WrappedTx is before the given WrappedTx +// this applies to EVM transactions only +func (wtx *WrappedTx) IsBefore(tx *WrappedTx) bool { + return wtx.evmNonce < tx.evmNonce || (wtx.evmNonce == tx.evmNonce && wtx.timestamp.Before(tx.timestamp)) +} + func (wtx *WrappedTx) Size() int { return len(wtx.tx) }