diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 633201892c8..ccbca2a92bd 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -359,9 +359,9 @@ func (w *Websocket) connectionMonitor() error { if w.checkAndSetMonitorRunning() { return errAlreadyRunning } - w.connectionMutex.RLock() + w.fieldMutex.RLock() delay := w.connectionMonitorDelay - w.connectionMutex.RUnlock() + w.fieldMutex.RUnlock() go func() { timer := time.NewTimer(delay) @@ -615,73 +615,73 @@ func (w *Websocket) trafficMonitor() { } func (w *Websocket) setConnectedStatus(b bool) { - w.connectionMutex.Lock() + w.fieldMutex.Lock() w.connected = b - w.connectionMutex.Unlock() + w.fieldMutex.Unlock() } // IsConnected returns status of connection func (w *Websocket) IsConnected() bool { - w.connectionMutex.RLock() - defer w.connectionMutex.RUnlock() + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() return w.connected } func (w *Websocket) setConnectingStatus(b bool) { - w.connectionMutex.Lock() + w.fieldMutex.Lock() w.connecting = b - w.connectionMutex.Unlock() + w.fieldMutex.Unlock() } // IsConnecting returns status of connecting func (w *Websocket) IsConnecting() bool { - w.connectionMutex.RLock() - defer w.connectionMutex.RUnlock() + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() return w.connecting } func (w *Websocket) setEnabled(b bool) { - w.connectionMutex.Lock() + w.fieldMutex.Lock() w.enabled = b - w.connectionMutex.Unlock() + w.fieldMutex.Unlock() } // IsEnabled returns status of enabled func (w *Websocket) IsEnabled() bool { - w.connectionMutex.RLock() - defer w.connectionMutex.RUnlock() + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() return w.enabled } func (w *Websocket) setInit(b bool) { - w.connectionMutex.Lock() + w.fieldMutex.Lock() w.Init = b - w.connectionMutex.Unlock() + w.fieldMutex.Unlock() } // IsInit returns status of init func (w *Websocket) IsInit() bool { - w.connectionMutex.RLock() - defer w.connectionMutex.RUnlock() + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() return w.Init } func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.connectionMutex.Lock() + w.fieldMutex.Lock() w.trafficMonitorRunning = b - w.connectionMutex.Unlock() + w.fieldMutex.Unlock() } // IsTrafficMonitorRunning returns status of the traffic monitor func (w *Websocket) IsTrafficMonitorRunning() bool { - w.connectionMutex.RLock() - defer w.connectionMutex.RUnlock() + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() return w.trafficMonitorRunning } func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - w.connectionMutex.Lock() - defer w.connectionMutex.Unlock() + w.fieldMutex.Lock() + defer w.fieldMutex.Unlock() if w.connectionMonitorRunning { return true } @@ -690,28 +690,28 @@ func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { } func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.connectionMutex.Lock() + w.fieldMutex.Lock() w.connectionMonitorRunning = b - w.connectionMutex.Unlock() + w.fieldMutex.Unlock() } // IsConnectionMonitorRunning returns status of connection monitor func (w *Websocket) IsConnectionMonitorRunning() bool { - w.connectionMutex.RLock() - defer w.connectionMutex.RUnlock() + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() return w.connectionMonitorRunning } func (w *Websocket) setDataMonitorRunning(b bool) { - w.connectionMutex.Lock() + w.fieldMutex.Lock() w.dataMonitorRunning = b - w.connectionMutex.Unlock() + w.fieldMutex.Unlock() } // IsDataMonitorRunning returns status of data monitor func (w *Websocket) IsDataMonitorRunning() bool { - w.connectionMutex.RLock() - defer w.connectionMutex.RUnlock() + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() return w.dataMonitorRunning } @@ -848,8 +848,8 @@ func (w *Websocket) GetName() string { // GetChannelDifference finds the difference between the subscribed channels // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(genSubs []ChannelSubscription) (sub, unsub []ChannelSubscription) { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() + w.subscriptionMutex.RLock() + defer w.subscriptionMutex.RUnlock() oldsubs: for _, x := range w.subscriptions { @@ -878,7 +878,7 @@ func (w *Websocket) UnsubscribeChannels(channels []ChannelSubscription) error { if len(channels) == 0 { return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoChannelsInArgs) } - w.subscriptionMutex.Lock() + w.subscriptionMutex.RLock() channels: for x := range channels { @@ -887,13 +887,13 @@ channels: continue channels } } - w.subscriptionMutex.Unlock() + w.subscriptionMutex.RUnlock() return fmt.Errorf("%s websocket: %w: %+v", w.exchangeName, errSubscriptionNotFound, channels[x]) } - w.subscriptionMutex.Unlock() + w.subscriptionMutex.RUnlock() return w.Unsubscriber(channels) } @@ -911,11 +911,11 @@ func (w *Websocket) SubscribeToChannels(channels []ChannelSubscription) error { if len(channels) == 0 { return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoChannelsInArgs) } - w.subscriptionMutex.Lock() + w.subscriptionMutex.RLock() for x := range channels { for _, y := range w.subscriptions { if channels[x].Equal(&y) { //nolint:gosec // for alias var is not closured or stored - w.subscriptionMutex.Unlock() + w.subscriptionMutex.RUnlock() return fmt.Errorf("%s websocket: %v %w", w.exchangeName, channels[x], @@ -923,7 +923,7 @@ func (w *Websocket) SubscribeToChannels(channels []ChannelSubscription) error { } } } - w.subscriptionMutex.Unlock() + w.subscriptionMutex.RUnlock() if err := w.Subscriber(channels); err != nil { return fmt.Errorf("%v %w: %v", w.exchangeName, ErrSubscriptionFailure, err) } @@ -976,8 +976,8 @@ func (w *ChannelSubscription) Equal(s *ChannelSubscription) bool { // GetSubscriptions returns a copied list of subscriptions // and is a private member that cannot be manipulated func (w *Websocket) GetSubscriptions() []ChannelSubscription { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() + w.subscriptionMutex.RLock() + defer w.subscriptionMutex.RUnlock() subs := make([]ChannelSubscription, 0, len(w.subscriptions)) for _, c := range w.subscriptions { subs = append(subs, c) @@ -988,16 +988,16 @@ func (w *Websocket) GetSubscriptions() []ChannelSubscription { // SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in // a thread safe manner func (w *Websocket) SetCanUseAuthenticatedEndpoints(val bool) { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() + w.fieldMutex.Lock() + defer w.fieldMutex.Unlock() w.canUseAuthenticatedEndpoints = val } // CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in // a thread safe manner func (w *Websocket) CanUseAuthenticatedEndpoints() bool { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() return w.canUseAuthenticatedEndpoints } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index c9fae47dfb9..3ebf888293b 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -44,11 +44,11 @@ type Websocket struct { runningURL string runningURLAuth string exchangeName string - m sync.Mutex - connectionMutex sync.RWMutex + m sync.RWMutex + fieldMutex sync.RWMutex connector func() error - subscriptionMutex sync.Mutex + subscriptionMutex sync.RWMutex subscriptions subscriptionMap Subscribe chan []ChannelSubscription Unsubscribe chan []ChannelSubscription