Skip to content

Commit

Permalink
Websockets: Use RW mutexes, rename, etc
Browse files Browse the repository at this point in the history
* This switches all RO uses of the mutex to use a RLock method.
* The mutex used for discrete field access has had scope drift from
  name 'connectionMutex' so rename to more appropriate fieldsMutex
* The mutex used for Set/CanUseAuthEndpoints moves from the
  subscriptions endpoint to the fieldsMutex
  • Loading branch information
gbjk committed Sep 24, 2023
1 parent 4d815d9 commit 7bb2043
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 49 deletions.
92 changes: 46 additions & 46 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,9 @@ func (w *Websocket) connectionMonitor() error {
if w.checkAndSetMonitorRunning() {
return errAlreadyRunning
}
w.connectionMutex.RLock()
w.fieldMutex.RLock()
delay := w.connectionMonitorDelay
w.connectionMutex.RUnlock()
w.fieldMutex.RUnlock()

go func() {
timer := time.NewTimer(delay)
Expand Down Expand Up @@ -615,73 +615,73 @@ func (w *Websocket) trafficMonitor() {
}

func (w *Websocket) setConnectedStatus(b bool) {
w.connectionMutex.Lock()
w.fieldMutex.Lock()
w.connected = b
w.connectionMutex.Unlock()
w.fieldMutex.Unlock()
}

// IsConnected returns status of connection
func (w *Websocket) IsConnected() bool {
w.connectionMutex.RLock()
defer w.connectionMutex.RUnlock()
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.connected
}

func (w *Websocket) setConnectingStatus(b bool) {
w.connectionMutex.Lock()
w.fieldMutex.Lock()
w.connecting = b
w.connectionMutex.Unlock()
w.fieldMutex.Unlock()
}

// IsConnecting returns status of connecting
func (w *Websocket) IsConnecting() bool {
w.connectionMutex.RLock()
defer w.connectionMutex.RUnlock()
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.connecting
}

func (w *Websocket) setEnabled(b bool) {
w.connectionMutex.Lock()
w.fieldMutex.Lock()
w.enabled = b
w.connectionMutex.Unlock()
w.fieldMutex.Unlock()
}

// IsEnabled returns status of enabled
func (w *Websocket) IsEnabled() bool {
w.connectionMutex.RLock()
defer w.connectionMutex.RUnlock()
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.enabled
}

func (w *Websocket) setInit(b bool) {
w.connectionMutex.Lock()
w.fieldMutex.Lock()
w.Init = b
w.connectionMutex.Unlock()
w.fieldMutex.Unlock()
}

// IsInit returns status of init
func (w *Websocket) IsInit() bool {
w.connectionMutex.RLock()
defer w.connectionMutex.RUnlock()
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.Init
}

func (w *Websocket) setTrafficMonitorRunning(b bool) {
w.connectionMutex.Lock()
w.fieldMutex.Lock()
w.trafficMonitorRunning = b
w.connectionMutex.Unlock()
w.fieldMutex.Unlock()
}

// IsTrafficMonitorRunning returns status of the traffic monitor
func (w *Websocket) IsTrafficMonitorRunning() bool {
w.connectionMutex.RLock()
defer w.connectionMutex.RUnlock()
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.trafficMonitorRunning
}

func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) {
w.connectionMutex.Lock()
defer w.connectionMutex.Unlock()
w.fieldMutex.Lock()
defer w.fieldMutex.Unlock()
if w.connectionMonitorRunning {
return true
}
Expand All @@ -690,28 +690,28 @@ func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) {
}

func (w *Websocket) setConnectionMonitorRunning(b bool) {
w.connectionMutex.Lock()
w.fieldMutex.Lock()
w.connectionMonitorRunning = b
w.connectionMutex.Unlock()
w.fieldMutex.Unlock()
}

// IsConnectionMonitorRunning returns status of connection monitor
func (w *Websocket) IsConnectionMonitorRunning() bool {
w.connectionMutex.RLock()
defer w.connectionMutex.RUnlock()
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.connectionMonitorRunning
}

func (w *Websocket) setDataMonitorRunning(b bool) {
w.connectionMutex.Lock()
w.fieldMutex.Lock()
w.dataMonitorRunning = b
w.connectionMutex.Unlock()
w.fieldMutex.Unlock()
}

// IsDataMonitorRunning returns status of data monitor
func (w *Websocket) IsDataMonitorRunning() bool {
w.connectionMutex.RLock()
defer w.connectionMutex.RUnlock()
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.dataMonitorRunning
}

Expand Down Expand Up @@ -848,8 +848,8 @@ func (w *Websocket) GetName() string {
// GetChannelDifference finds the difference between the subscribed channels
// and the new subscription list when pairs are disabled or enabled.
func (w *Websocket) GetChannelDifference(genSubs []ChannelSubscription) (sub, unsub []ChannelSubscription) {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()

oldsubs:
for _, x := range w.subscriptions {
Expand Down Expand Up @@ -878,7 +878,7 @@ func (w *Websocket) UnsubscribeChannels(channels []ChannelSubscription) error {
if len(channels) == 0 {
return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoChannelsInArgs)
}
w.subscriptionMutex.Lock()
w.subscriptionMutex.RLock()

channels:
for x := range channels {
Expand All @@ -887,13 +887,13 @@ channels:
continue channels
}
}
w.subscriptionMutex.Unlock()
w.subscriptionMutex.RUnlock()
return fmt.Errorf("%s websocket: %w: %+v",
w.exchangeName,
errSubscriptionNotFound,
channels[x])
}
w.subscriptionMutex.Unlock()
w.subscriptionMutex.RUnlock()
return w.Unsubscriber(channels)
}

Expand All @@ -911,19 +911,19 @@ func (w *Websocket) SubscribeToChannels(channels []ChannelSubscription) error {
if len(channels) == 0 {
return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoChannelsInArgs)
}
w.subscriptionMutex.Lock()
w.subscriptionMutex.RLock()
for x := range channels {
for _, y := range w.subscriptions {
if channels[x].Equal(&y) { //nolint:gosec // for alias var is not closured or stored
w.subscriptionMutex.Unlock()
w.subscriptionMutex.RUnlock()
return fmt.Errorf("%s websocket: %v %w",
w.exchangeName,
channels[x],
errAlreadySubscribed)
}
}
}
w.subscriptionMutex.Unlock()
w.subscriptionMutex.RUnlock()
if err := w.Subscriber(channels); err != nil {
return fmt.Errorf("%v %w: %v", w.exchangeName, ErrSubscriptionFailure, err)
}
Expand Down Expand Up @@ -976,8 +976,8 @@ func (w *ChannelSubscription) Equal(s *ChannelSubscription) bool {
// GetSubscriptions returns a copied list of subscriptions
// and is a private member that cannot be manipulated
func (w *Websocket) GetSubscriptions() []ChannelSubscription {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
subs := make([]ChannelSubscription, 0, len(w.subscriptions))
for _, c := range w.subscriptions {
subs = append(subs, c)
Expand All @@ -988,16 +988,16 @@ func (w *Websocket) GetSubscriptions() []ChannelSubscription {
// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in
// a thread safe manner
func (w *Websocket) SetCanUseAuthenticatedEndpoints(val bool) {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
w.fieldMutex.Lock()
defer w.fieldMutex.Unlock()
w.canUseAuthenticatedEndpoints = val
}

// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in
// a thread safe manner
func (w *Websocket) CanUseAuthenticatedEndpoints() bool {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.canUseAuthenticatedEndpoints
}

Expand Down
6 changes: 3 additions & 3 deletions exchanges/stream/websocket_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ type Websocket struct {
runningURL string
runningURLAuth string
exchangeName string
m sync.Mutex
connectionMutex sync.RWMutex
m sync.RWMutex
fieldMutex sync.RWMutex
connector func() error

subscriptionMutex sync.Mutex
subscriptionMutex sync.RWMutex
subscriptions subscriptionMap
Subscribe chan []ChannelSubscription
Unsubscribe chan []ChannelSubscription
Expand Down

0 comments on commit 7bb2043

Please sign in to comment.