Skip to content

Commit

Permalink
Remove linked.Hashmap locking (#2911)
Browse files Browse the repository at this point in the history
  • Loading branch information
StephenButtolph authored Apr 3, 2024
1 parent 0fea82e commit b44feeb
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 59 deletions.
3 changes: 3 additions & 0 deletions network/throttling/inbound_msg_byte_throttler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,13 +422,16 @@ func TestMsgThrottlerNextMsg(t *testing.T) {

// Release 1 byte
throttler.release(&msgMetadata{msgSize: 1}, vdr1ID)

// Byte should have gone toward next validator message
throttler.lock.Lock()
require.Equal(2, throttler.waitingToAcquire.Len())
require.Contains(throttler.nodeToWaitingMsgID, vdr1ID)
firstMsgID := throttler.nodeToWaitingMsgID[vdr1ID]
firstMsg, exists := throttler.waitingToAcquire.Get(firstMsgID)
require.True(exists)
require.Equal(maxBytes-2, firstMsg.bytesNeeded)
throttler.lock.Unlock()

select {
case <-doneVdr:
Expand Down
17 changes: 17 additions & 0 deletions snow/networking/router/chain_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,9 @@ func TestRouterHonorsRequestedEngine(t *testing.T) {
chainRouter.HandleInbound(context.Background(), msg)
}

chainRouter.lock.Lock()
require.Zero(chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()
}

func TestRouterClearTimeouts(t *testing.T) {
Expand Down Expand Up @@ -897,7 +899,10 @@ func TestRouterClearTimeouts(t *testing.T) {
)

chainRouter.HandleInbound(context.Background(), tt.responseMsg)

chainRouter.lock.Lock()
require.Zero(chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()
})
}
}
Expand Down Expand Up @@ -1383,7 +1388,9 @@ func TestAppRequest(t *testing.T) {
if tt.inboundMsg == nil || tt.inboundMsg.Op() == message.AppErrorOp {
engine.AppRequestFailedF = func(_ context.Context, nodeID ids.NodeID, requestID uint32, appErr *common.AppError) error {
defer wg.Done()
chainRouter.lock.Lock()
require.Zero(chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()

require.Equal(ids.EmptyNodeID, nodeID)
require.Equal(wantRequestID, requestID)
Expand All @@ -1395,7 +1402,9 @@ func TestAppRequest(t *testing.T) {
} else if tt.inboundMsg.Op() == message.AppResponseOp {
engine.AppResponseF = func(_ context.Context, nodeID ids.NodeID, requestID uint32, msg []byte) error {
defer wg.Done()
chainRouter.lock.Lock()
require.Zero(chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()

require.Equal(ids.EmptyNodeID, nodeID)
require.Equal(wantRequestID, requestID)
Expand All @@ -1407,7 +1416,9 @@ func TestAppRequest(t *testing.T) {

ctx := context.Background()
chainRouter.RegisterRequest(ctx, ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, tt.responseOp, tt.timeoutMsg, engineType)
chainRouter.lock.Lock()
require.Equal(1, chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()

if tt.inboundMsg != nil {
chainRouter.HandleInbound(ctx, tt.inboundMsg)
Expand Down Expand Up @@ -1465,7 +1476,9 @@ func TestCrossChainAppRequest(t *testing.T) {
if tt.inboundMsg == nil || tt.inboundMsg.Op() == message.CrossChainAppErrorOp {
engine.CrossChainAppRequestFailedF = func(_ context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error {
defer wg.Done()
chainRouter.lock.Lock()
require.Zero(chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()

require.Equal(ids.Empty, chainID)
require.Equal(wantRequestID, requestID)
Expand All @@ -1477,7 +1490,9 @@ func TestCrossChainAppRequest(t *testing.T) {
} else if tt.inboundMsg.Op() == message.CrossChainAppResponseOp {
engine.CrossChainAppResponseF = func(_ context.Context, chainID ids.ID, requestID uint32, msg []byte) error {
defer wg.Done()
chainRouter.lock.Lock()
require.Zero(chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()

require.Equal(ids.Empty, chainID)
require.Equal(wantRequestID, requestID)
Expand All @@ -1489,7 +1504,9 @@ func TestCrossChainAppRequest(t *testing.T) {

ctx := context.Background()
chainRouter.RegisterRequest(ctx, ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, tt.responseOp, tt.timeoutMsg, engineType)
chainRouter.lock.Lock()
require.Equal(1, chainRouter.timedRequests.Len())
chainRouter.lock.Unlock()

if tt.inboundMsg != nil {
chainRouter.HandleInbound(ctx, tt.inboundMsg)
Expand Down
64 changes: 7 additions & 57 deletions utils/linked/hashmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@

package linked

import (
"sync"

"github.com/ava-labs/avalanchego/utils"
)
import "github.com/ava-labs/avalanchego/utils"

type keyValue[K, V any] struct {
key K
Expand All @@ -18,7 +14,6 @@ type keyValue[K, V any] struct {
//
// Entries are tracked by insertion order.
type Hashmap[K comparable, V any] struct {
lock sync.RWMutex
entryMap map[K]*ListElement[keyValue[K, V]]
entryList *List[keyValue[K, V]]
freeList []*ListElement[keyValue[K, V]]
Expand All @@ -31,49 +26,7 @@ func NewHashmap[K comparable, V any]() *Hashmap[K, V] {
}
}

func (lh *Hashmap[K, V]) Put(key K, val V) {
lh.lock.Lock()
defer lh.lock.Unlock()

lh.put(key, val)
}

func (lh *Hashmap[K, V]) Get(key K) (V, bool) {
lh.lock.RLock()
defer lh.lock.RUnlock()

return lh.get(key)
}

func (lh *Hashmap[K, V]) Delete(key K) bool {
lh.lock.Lock()
defer lh.lock.Unlock()

return lh.delete(key)
}

func (lh *Hashmap[K, V]) Len() int {
lh.lock.RLock()
defer lh.lock.RUnlock()

return lh.len()
}

func (lh *Hashmap[K, V]) Oldest() (K, V, bool) {
lh.lock.RLock()
defer lh.lock.RUnlock()

return lh.oldest()
}

func (lh *Hashmap[K, V]) Newest() (K, V, bool) {
lh.lock.RLock()
defer lh.lock.RUnlock()

return lh.newest()
}

func (lh *Hashmap[K, V]) put(key K, value V) {
func (lh *Hashmap[K, V]) Put(key K, value V) {
if e, ok := lh.entryMap[key]; ok {
lh.entryList.MoveToBack(e)
e.Value = keyValue[K, V]{
Expand All @@ -100,14 +53,14 @@ func (lh *Hashmap[K, V]) put(key K, value V) {
lh.entryList.PushBack(e)
}

func (lh *Hashmap[K, V]) get(key K) (V, bool) {
func (lh *Hashmap[K, V]) Get(key K) (V, bool) {
if e, ok := lh.entryMap[key]; ok {
return e.Value.value, true
}
return utils.Zero[V](), false
}

func (lh *Hashmap[K, V]) delete(key K) bool {
func (lh *Hashmap[K, V]) Delete(key K) bool {
e, ok := lh.entryMap[key]
if ok {
lh.entryList.Remove(e)
Expand All @@ -118,18 +71,18 @@ func (lh *Hashmap[K, V]) delete(key K) bool {
return ok
}

func (lh *Hashmap[K, V]) len() int {
func (lh *Hashmap[K, V]) Len() int {
return len(lh.entryMap)
}

func (lh *Hashmap[K, V]) oldest() (K, V, bool) {
func (lh *Hashmap[K, V]) Oldest() (K, V, bool) {
if e := lh.entryList.Front(); e != nil {
return e.Value.key, e.Value.value, true
}
return utils.Zero[K](), utils.Zero[V](), false
}

func (lh *Hashmap[K, V]) newest() (K, V, bool) {
func (lh *Hashmap[K, V]) Newest() (K, V, bool) {
if e := lh.entryList.Back(); e != nil {
return e.Value.key, e.Value.value, true
}
Expand Down Expand Up @@ -160,9 +113,6 @@ func (it *Iterator[K, V]) Next() bool {
return false
}

it.lh.lock.RLock()
defer it.lh.lock.RUnlock()

// If the iterator was not yet initialized, do it now.
if !it.initialized {
it.initialized = true
Expand Down
12 changes: 10 additions & 2 deletions vms/avm/txs/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,10 @@ func (m *mempool) Add(tx *txs.Tx) error {
}

func (m *mempool) Get(txID ids.ID) (*txs.Tx, bool) {
tx, ok := m.unissuedTxs.Get(txID)
return tx, ok
m.lock.RLock()
defer m.lock.RUnlock()

return m.unissuedTxs.Get(txID)
}

func (m *mempool) Remove(txs ...*txs.Tx) {
Expand Down Expand Up @@ -190,6 +192,9 @@ func (m *mempool) Remove(txs ...*txs.Tx) {
}

func (m *mempool) Peek() (*txs.Tx, bool) {
m.lock.RLock()
defer m.lock.RUnlock()

_, tx, exists := m.unissuedTxs.Oldest()
return tx, exists
}
Expand All @@ -207,6 +212,9 @@ func (m *mempool) Iterate(f func(*txs.Tx) bool) {
}

func (m *mempool) RequestBuildBlock() {
m.lock.RLock()
defer m.lock.RUnlock()

if m.unissuedTxs.Len() == 0 {
return
}
Expand Down
9 changes: 9 additions & 0 deletions vms/platformvm/txs/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ func (m *mempool) Add(tx *txs.Tx) error {
}

func (m *mempool) Get(txID ids.ID) (*txs.Tx, bool) {
m.lock.RLock()
defer m.lock.RUnlock()

return m.unissuedTxs.Get(txID)
}

Expand Down Expand Up @@ -203,6 +206,9 @@ func (m *mempool) Remove(txs ...*txs.Tx) {
}

func (m *mempool) Peek() (*txs.Tx, bool) {
m.lock.RLock()
defer m.lock.RUnlock()

_, tx, exists := m.unissuedTxs.Oldest()
return tx, exists
}
Expand Down Expand Up @@ -240,6 +246,9 @@ func (m *mempool) GetDropReason(txID ids.ID) error {
}

func (m *mempool) RequestBuildBlock(emptyBlockPermitted bool) {
m.lock.RLock()
defer m.lock.RUnlock()

if !emptyBlockPermitted && m.unissuedTxs.Len() == 0 {
return
}
Expand Down

0 comments on commit b44feeb

Please sign in to comment.