From b44feeb8d4a7f137f54dd252a8b001a53ef47d14 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Wed, 3 Apr 2024 19:49:53 -0400 Subject: [PATCH] Remove `linked.Hashmap` locking (#2911) --- .../inbound_msg_byte_throttler_test.go | 3 + snow/networking/router/chain_router_test.go | 17 +++++ utils/linked/hashmap.go | 64 ++----------------- vms/avm/txs/mempool/mempool.go | 12 +++- vms/platformvm/txs/mempool/mempool.go | 9 +++ 5 files changed, 46 insertions(+), 59 deletions(-) diff --git a/network/throttling/inbound_msg_byte_throttler_test.go b/network/throttling/inbound_msg_byte_throttler_test.go index 52ffcf83c67..4fc931e3f37 100644 --- a/network/throttling/inbound_msg_byte_throttler_test.go +++ b/network/throttling/inbound_msg_byte_throttler_test.go @@ -422,13 +422,16 @@ func TestMsgThrottlerNextMsg(t *testing.T) { // Release 1 byte throttler.release(&msgMetadata{msgSize: 1}, vdr1ID) + // Byte should have gone toward next validator message + throttler.lock.Lock() require.Equal(2, throttler.waitingToAcquire.Len()) require.Contains(throttler.nodeToWaitingMsgID, vdr1ID) firstMsgID := throttler.nodeToWaitingMsgID[vdr1ID] firstMsg, exists := throttler.waitingToAcquire.Get(firstMsgID) require.True(exists) require.Equal(maxBytes-2, firstMsg.bytesNeeded) + throttler.lock.Unlock() select { case <-doneVdr: diff --git a/snow/networking/router/chain_router_test.go b/snow/networking/router/chain_router_test.go index 43ccfa09dbf..18a224703ed 100644 --- a/snow/networking/router/chain_router_test.go +++ b/snow/networking/router/chain_router_test.go @@ -811,7 +811,9 @@ func TestRouterHonorsRequestedEngine(t *testing.T) { chainRouter.HandleInbound(context.Background(), msg) } + chainRouter.lock.Lock() require.Zero(chainRouter.timedRequests.Len()) + chainRouter.lock.Unlock() } func TestRouterClearTimeouts(t *testing.T) { @@ -897,7 +899,10 @@ func TestRouterClearTimeouts(t *testing.T) { ) chainRouter.HandleInbound(context.Background(), tt.responseMsg) + + chainRouter.lock.Lock() require.Zero(chainRouter.timedRequests.Len()) + chainRouter.lock.Unlock() }) } } @@ -1383,7 +1388,9 @@ func TestAppRequest(t *testing.T) { if tt.inboundMsg == nil || tt.inboundMsg.Op() == message.AppErrorOp { engine.AppRequestFailedF = func(_ context.Context, nodeID ids.NodeID, requestID uint32, appErr *common.AppError) error { defer wg.Done() + chainRouter.lock.Lock() require.Zero(chainRouter.timedRequests.Len()) + chainRouter.lock.Unlock() require.Equal(ids.EmptyNodeID, nodeID) require.Equal(wantRequestID, requestID) @@ -1395,7 +1402,9 @@ func TestAppRequest(t *testing.T) { } else if tt.inboundMsg.Op() == message.AppResponseOp { engine.AppResponseF = func(_ context.Context, nodeID ids.NodeID, requestID uint32, msg []byte) error { defer wg.Done() + chainRouter.lock.Lock() require.Zero(chainRouter.timedRequests.Len()) + chainRouter.lock.Unlock() require.Equal(ids.EmptyNodeID, nodeID) require.Equal(wantRequestID, requestID) @@ -1407,7 +1416,9 @@ func TestAppRequest(t *testing.T) { ctx := context.Background() chainRouter.RegisterRequest(ctx, ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, tt.responseOp, tt.timeoutMsg, engineType) + chainRouter.lock.Lock() require.Equal(1, chainRouter.timedRequests.Len()) + chainRouter.lock.Unlock() if tt.inboundMsg != nil { chainRouter.HandleInbound(ctx, tt.inboundMsg) @@ -1465,7 +1476,9 @@ func TestCrossChainAppRequest(t *testing.T) { if tt.inboundMsg == nil || tt.inboundMsg.Op() == message.CrossChainAppErrorOp { engine.CrossChainAppRequestFailedF = func(_ context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error { defer wg.Done() + chainRouter.lock.Lock() require.Zero(chainRouter.timedRequests.Len()) + chainRouter.lock.Unlock() require.Equal(ids.Empty, chainID) require.Equal(wantRequestID, requestID) @@ -1477,7 +1490,9 @@ func TestCrossChainAppRequest(t *testing.T) { } else if tt.inboundMsg.Op() == message.CrossChainAppResponseOp { engine.CrossChainAppResponseF = func(_ context.Context, chainID ids.ID, requestID uint32, msg []byte) error { defer wg.Done() + chainRouter.lock.Lock() require.Zero(chainRouter.timedRequests.Len()) + chainRouter.lock.Unlock() require.Equal(ids.Empty, chainID) require.Equal(wantRequestID, requestID) @@ -1489,7 +1504,9 @@ func TestCrossChainAppRequest(t *testing.T) { ctx := context.Background() chainRouter.RegisterRequest(ctx, ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, tt.responseOp, tt.timeoutMsg, engineType) + chainRouter.lock.Lock() require.Equal(1, chainRouter.timedRequests.Len()) + chainRouter.lock.Unlock() if tt.inboundMsg != nil { chainRouter.HandleInbound(ctx, tt.inboundMsg) diff --git a/utils/linked/hashmap.go b/utils/linked/hashmap.go index da68e97f663..b17b7b60972 100644 --- a/utils/linked/hashmap.go +++ b/utils/linked/hashmap.go @@ -3,11 +3,7 @@ package linked -import ( - "sync" - - "github.com/ava-labs/avalanchego/utils" -) +import "github.com/ava-labs/avalanchego/utils" type keyValue[K, V any] struct { key K @@ -18,7 +14,6 @@ type keyValue[K, V any] struct { // // Entries are tracked by insertion order. type Hashmap[K comparable, V any] struct { - lock sync.RWMutex entryMap map[K]*ListElement[keyValue[K, V]] entryList *List[keyValue[K, V]] freeList []*ListElement[keyValue[K, V]] @@ -31,49 +26,7 @@ func NewHashmap[K comparable, V any]() *Hashmap[K, V] { } } -func (lh *Hashmap[K, V]) Put(key K, val V) { - lh.lock.Lock() - defer lh.lock.Unlock() - - lh.put(key, val) -} - -func (lh *Hashmap[K, V]) Get(key K) (V, bool) { - lh.lock.RLock() - defer lh.lock.RUnlock() - - return lh.get(key) -} - -func (lh *Hashmap[K, V]) Delete(key K) bool { - lh.lock.Lock() - defer lh.lock.Unlock() - - return lh.delete(key) -} - -func (lh *Hashmap[K, V]) Len() int { - lh.lock.RLock() - defer lh.lock.RUnlock() - - return lh.len() -} - -func (lh *Hashmap[K, V]) Oldest() (K, V, bool) { - lh.lock.RLock() - defer lh.lock.RUnlock() - - return lh.oldest() -} - -func (lh *Hashmap[K, V]) Newest() (K, V, bool) { - lh.lock.RLock() - defer lh.lock.RUnlock() - - return lh.newest() -} - -func (lh *Hashmap[K, V]) put(key K, value V) { +func (lh *Hashmap[K, V]) Put(key K, value V) { if e, ok := lh.entryMap[key]; ok { lh.entryList.MoveToBack(e) e.Value = keyValue[K, V]{ @@ -100,14 +53,14 @@ func (lh *Hashmap[K, V]) put(key K, value V) { lh.entryList.PushBack(e) } -func (lh *Hashmap[K, V]) get(key K) (V, bool) { +func (lh *Hashmap[K, V]) Get(key K) (V, bool) { if e, ok := lh.entryMap[key]; ok { return e.Value.value, true } return utils.Zero[V](), false } -func (lh *Hashmap[K, V]) delete(key K) bool { +func (lh *Hashmap[K, V]) Delete(key K) bool { e, ok := lh.entryMap[key] if ok { lh.entryList.Remove(e) @@ -118,18 +71,18 @@ func (lh *Hashmap[K, V]) delete(key K) bool { return ok } -func (lh *Hashmap[K, V]) len() int { +func (lh *Hashmap[K, V]) Len() int { return len(lh.entryMap) } -func (lh *Hashmap[K, V]) oldest() (K, V, bool) { +func (lh *Hashmap[K, V]) Oldest() (K, V, bool) { if e := lh.entryList.Front(); e != nil { return e.Value.key, e.Value.value, true } return utils.Zero[K](), utils.Zero[V](), false } -func (lh *Hashmap[K, V]) newest() (K, V, bool) { +func (lh *Hashmap[K, V]) Newest() (K, V, bool) { if e := lh.entryList.Back(); e != nil { return e.Value.key, e.Value.value, true } @@ -160,9 +113,6 @@ func (it *Iterator[K, V]) Next() bool { return false } - it.lh.lock.RLock() - defer it.lh.lock.RUnlock() - // If the iterator was not yet initialized, do it now. if !it.initialized { it.initialized = true diff --git a/vms/avm/txs/mempool/mempool.go b/vms/avm/txs/mempool/mempool.go index 51086df34ae..c761ae09795 100644 --- a/vms/avm/txs/mempool/mempool.go +++ b/vms/avm/txs/mempool/mempool.go @@ -160,8 +160,10 @@ func (m *mempool) Add(tx *txs.Tx) error { } func (m *mempool) Get(txID ids.ID) (*txs.Tx, bool) { - tx, ok := m.unissuedTxs.Get(txID) - return tx, ok + m.lock.RLock() + defer m.lock.RUnlock() + + return m.unissuedTxs.Get(txID) } func (m *mempool) Remove(txs ...*txs.Tx) { @@ -190,6 +192,9 @@ func (m *mempool) Remove(txs ...*txs.Tx) { } func (m *mempool) Peek() (*txs.Tx, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + _, tx, exists := m.unissuedTxs.Oldest() return tx, exists } @@ -207,6 +212,9 @@ func (m *mempool) Iterate(f func(*txs.Tx) bool) { } func (m *mempool) RequestBuildBlock() { + m.lock.RLock() + defer m.lock.RUnlock() + if m.unissuedTxs.Len() == 0 { return } diff --git a/vms/platformvm/txs/mempool/mempool.go b/vms/platformvm/txs/mempool/mempool.go index d432896a76b..b45213719b6 100644 --- a/vms/platformvm/txs/mempool/mempool.go +++ b/vms/platformvm/txs/mempool/mempool.go @@ -174,6 +174,9 @@ func (m *mempool) Add(tx *txs.Tx) error { } func (m *mempool) Get(txID ids.ID) (*txs.Tx, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + return m.unissuedTxs.Get(txID) } @@ -203,6 +206,9 @@ func (m *mempool) Remove(txs ...*txs.Tx) { } func (m *mempool) Peek() (*txs.Tx, bool) { + m.lock.RLock() + defer m.lock.RUnlock() + _, tx, exists := m.unissuedTxs.Oldest() return tx, exists } @@ -240,6 +246,9 @@ func (m *mempool) GetDropReason(txID ids.ID) error { } func (m *mempool) RequestBuildBlock(emptyBlockPermitted bool) { + m.lock.RLock() + defer m.lock.RUnlock() + if !emptyBlockPermitted && m.unissuedTxs.Len() == 0 { return }