From 91eaf452bcec083de2df63bff62b4e9df815ef7c Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 15 Feb 2024 19:39:03 +0700 Subject: [PATCH] Websocket: Use atomics instead of mutex This was spurred by looking at the setState call in trafficMonitor and the effect on blocking and efficiency. With the new atomic types in Go 1.19, and the small types in use here, atomics should be safe for our usage. bools should be truly atomic, and uint32 is atomic when the accepted value range is less than one byte/uint8 since that can be written atomicly by concurrent processors. Maybe that's not even a factor any more, however we don't even have to worry enough to check. --- exchanges/stream/websocket.go | 66 +++++++------------------ exchanges/stream/websocket_test.go | 74 ++++++++++++++--------------- exchanges/stream/websocket_types.go | 15 +++--- 3 files changed, 61 insertions(+), 94 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 8c9c1151912..080d78f5ba8 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -128,7 +128,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { if s.ExchangeConfig.Features == nil { return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil) } - w.enabled = s.ExchangeConfig.Features.Enabled.Websocket + w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) if s.Connector == nil { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset) @@ -384,9 +384,7 @@ func (w *Websocket) connectionMonitor() error { if w.checkAndSetMonitorRunning() { return errAlreadyRunning } - w.fieldMutex.RLock() delay := w.connectionMonitorDelay - w.fieldMutex.RUnlock() go func() { timer := time.NewTimer(delay) @@ -614,93 +612,63 @@ func (w *Websocket) trafficMonitor() { }() } -// IsInitialised returns whether the websocket has been Setup() already -func (w *Websocket) IsInitialised() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state != uninitialised +func (w *Websocket) setState(s uint32) { + w.state.Store(s) } -func (w *Websocket) setState(s state) { - w.fieldMutex.Lock() - w.state = s - w.fieldMutex.Unlock() +// IsInitialised returns whether the websocket has been Setup() already +func (w *Websocket) IsInitialised() bool { + return w.state.Load() != uninitialised } // IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state == connected + return w.state.Load() != connected } // IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state == connecting + return w.state.Load() == connecting } func (w *Websocket) setEnabled(b bool) { - w.fieldMutex.Lock() - w.enabled = b - w.fieldMutex.Unlock() + w.enabled.Store(b) } // IsEnabled returns whether the websocket is enabled func (w *Websocket) IsEnabled() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.enabled + return w.enabled.Load() } func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.trafficMonitorRunning = b - w.fieldMutex.Unlock() + w.trafficMonitorRunning.Store(b) } // IsTrafficMonitorRunning returns status of the traffic monitor func (w *Websocket) IsTrafficMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.trafficMonitorRunning + return w.trafficMonitorRunning.Load() } func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - if w.connectionMonitorRunning { - return true - } - w.connectionMonitorRunning = true - return false + return !w.connectionMonitorRunning.CompareAndSwap(false, true) } func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.connectionMonitorRunning = b - w.fieldMutex.Unlock() + w.connectionMonitorRunning.Store(b) } // IsConnectionMonitorRunning returns status of connection monitor func (w *Websocket) IsConnectionMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connectionMonitorRunning + return w.connectionMonitorRunning.Load() } func (w *Websocket) setDataMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.dataMonitorRunning = b - w.fieldMutex.Unlock() + w.dataMonitorRunning.Store(b) } // IsDataMonitorRunning returns status of data monitor func (w *Websocket) IsDataMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.dataMonitorRunning + return w.dataMonitorRunning.Load() } // CanUseAuthenticatedWebsocketForWrapper Handles a common check to diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index d65313f732d..dab17d08b54 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -478,7 +478,7 @@ func TestConnectionMonitorNoConnection(t *testing.T) { ws.ShutdownC = make(chan struct{}, 1) ws.exchangeName = "hello" ws.Wg = &sync.WaitGroup{} - ws.enabled = true + ws.setEnabled(true) err := ws.connectionMonitor() require.NoError(t, err, "connectionMonitor must not error") assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") @@ -951,9 +951,7 @@ func TestFlushChannels(t *testing.T) { err = dodgyWs.FlushChannels() assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") - web := Websocket{ - enabled: true, - state: connected, + w := Websocket{ connector: connect, ShutdownC: make(chan struct{}), Subscriber: newgen.SUBME, @@ -966,6 +964,8 @@ func TestFlushChannels(t *testing.T) { // in FlushChannels() so the traffic monitor doesn't time out and turn // this to an unconnected state } + w.setEnabled(true) + w.setState(connected) problemFunc := func() ([]subscription.Subscription, error) { return nil, errDastardlyReason @@ -978,40 +978,40 @@ func TestFlushChannels(t *testing.T) { // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} - web.GenerateSubs = func() ([]subscription.Subscription, error) { + w.GenerateSubs = func() ([]subscription.Subscription, error) { return []subscription.Subscription{{Channel: "test"}}, nil } - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.features.FullPayloadSubscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() // error on full subscribeToChannels + w.features.FullPayloadSubscribe = true + w.GenerateSubs = problemFunc + err = w.FlushChannels() // error on full subscribeToChannels assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - web.GenerateSubs = noSub - err = web.FlushChannels() // No subs to unsub + w.GenerateSubs = noSub + err = w.FlushChannels() // No subs to unsub assert.NoError(t, err, "FlushChannels should not error") - web.GenerateSubs = newgen.generateSubs - subs, err := web.GenerateSubs() + w.GenerateSubs = newgen.generateSubs + subs, err := w.GenerateSubs() require.NoError(t, err, "GenerateSubs must not error") - web.AddSuccessfulSubscriptions(subs...) - err = web.FlushChannels() + w.AddSuccessfulSubscriptions(subs...) + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.features.FullPayloadSubscribe = false - web.features.Subscribe = true + w.features.FullPayloadSubscribe = false + w.features.Subscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() + w.GenerateSubs = problemFunc + err = w.FlushChannels() assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - web.GenerateSubs = newgen.generateSubs - err = web.FlushChannels() + w.GenerateSubs = newgen.generateSubs + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.subscriptionMutex.Lock() - web.subscriptions = subscriptionMap{ + w.subscriptionMutex.Lock() + w.subscriptions = subscriptionMap{ 41: { Key: 41, Channel: "match channel", @@ -1023,34 +1023,34 @@ func TestFlushChannels(t *testing.T) { Pair: currency.NewPair(currency.THETA, currency.USDT), }, } - web.subscriptionMutex.Unlock() + w.subscriptionMutex.Unlock() - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.setState(connected) - web.features.Unsubscribe = true - err = web.FlushChannels() + w.setState(connected) + w.features.Unsubscribe = true + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") } func TestDisable(t *testing.T) { t.Parallel() - web := Websocket{ - enabled: true, - state: connected, + w := Websocket{ ShutdownC: make(chan struct{}), } - require.NoError(t, web.Disable(), "Disable must not error") - assert.ErrorIs(t, web.Disable(), ErrAlreadyDisabled, "Disable should error correctly") + w.setEnabled(true) + w.setState(connected) + require.NoError(t, w.Disable(), "Disable must not error") + assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } func TestEnable(t *testing.T) { t.Parallel() - web := Websocket{ + w := Websocket{ connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), @@ -1060,8 +1060,8 @@ func TestEnable(t *testing.T) { Subscriber: func(cs []subscription.Subscription) error { return nil }, } - require.NoError(t, web.Enable(), "Enable must not error") - assert.ErrorIs(t, web.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") + require.NoError(t, w.Enable(), "Enable must not error") + assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 31e1db58b09..02946ea8022 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -2,6 +2,7 @@ package stream import ( "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -23,10 +24,8 @@ const ( type subscriptionMap map[any]*subscription.Subscription -type state int - const ( - uninitialised state = iota + uninitialised uint32 = iota disconnected connecting connected @@ -36,12 +35,12 @@ const ( // wrapper for routine processing type Websocket struct { canUseAuthenticatedEndpoints bool - enabled bool - state state + enabled atomic.Bool + state atomic.Uint32 verbose bool - connectionMonitorRunning bool - trafficMonitorRunning bool - dataMonitorRunning bool + connectionMonitorRunning atomic.Bool + trafficMonitorRunning atomic.Bool + dataMonitorRunning atomic.Bool trafficTimeout time.Duration connectionMonitorDelay time.Duration proxyAddr string