diff --git a/consensus/mempool_test.go b/consensus/mempool_test.go index b4b338172..73beb8990 100644 --- a/consensus/mempool_test.go +++ b/consensus/mempool_test.go @@ -182,7 +182,17 @@ func TestMempoolRmBadTx(t *testing.T) { // check for the tx for { - txs := assertMempool(cs.txNotifier).ReapMaxBytesMaxGas(int64(len(txBytes)), -1) + txs := assertMempool(cs.txNotifier).ReapMaxTxs(1) + if len(txs) == 0 { + emptyMempoolCh <- struct{}{} + return + } + txs = assertMempool(cs.txNotifier).ReapMaxBytesMaxGasMaxTxs(int64(len(txBytes)), -1, 1) + if len(txs) == 0 { + emptyMempoolCh <- struct{}{} + return + } + txs = assertMempool(cs.txNotifier).ReapMaxBytesMaxGas(int64(len(txBytes)), -1) if len(txs) == 0 { emptyMempoolCh <- struct{}{} return diff --git a/mempool/clist_mempool_test.go b/mempool/clist_mempool_test.go index 1f0b7c5c1..f0716cfbe 100644 --- a/mempool/clist_mempool_test.go +++ b/mempool/clist_mempool_test.go @@ -114,26 +114,35 @@ func TestReapMaxBytesMaxGas(t *testing.T) { maxBytes int64 maxGas int64 expectedNumTxs int + maxTxs int64 }{ - {20, -1, -1, 20}, - {20, -1, 0, 0}, - {20, -1, 10, 10}, - {20, -1, 30, 20}, - {20, 0, -1, 0}, - {20, 0, 10, 0}, - {20, 10, 10, 0}, - {20, 24, 10, 1}, - {20, 240, 5, 5}, - {20, 240, -1, 10}, - {20, 240, 10, 10}, - {20, 240, 15, 10}, - {20, 20000, -1, 20}, - {20, 20000, 5, 5}, - {20, 20000, 30, 20}, + {20, -1, -1, 20, 0}, + {20, -1, 0, 0, 0}, + {20, -1, 10, 10, 0}, + {20, -1, 30, 20, 0}, + {20, 0, -1, 0, 0}, + {20, 0, 10, 0, 0}, + {20, 10, 10, 0, 0}, + {20, 24, 10, 1, 0}, + {20, 240, 5, 5, 0}, + {20, 240, -1, 10, 0}, + {20, 240, 10, 10, 0}, + {20, 240, 15, 10, 0}, + {20, 20000, -1, 20, 0}, + {20, 20000, 5, 5, 0}, + {20, 20000, 30, 20, 0}, + {20, 20000, 30, 20, 0}, + {20, 20000, 30, 10, 10}, + {20, 20000, 30, 20, 100}, } for tcIndex, tt := range tests { checkTxs(t, mempool, tt.numTxsToCreate, UnknownPeerID) - got := mempool.ReapMaxBytesMaxGas(tt.maxBytes, tt.maxGas) + var got types.Txs + if tt.maxTxs <= 0 { + got = mempool.ReapMaxBytesMaxGas(tt.maxBytes, tt.maxGas) + } else { + got = mempool.ReapMaxBytesMaxGasMaxTxs(tt.maxBytes, tt.maxGas, tt.maxTxs) + } assert.Equal(t, tt.expectedNumTxs, len(got), "Got %d txs, expected %d, tc #%d", len(got), tt.expectedNumTxs, tcIndex) mempool.Flush()