Skip to content

Commit

Permalink
Websocket: Use atomics instead of mutex
Browse files Browse the repository at this point in the history
This was spurred by looking at the setState call in trafficMonitor and
the effect on blocking and efficiency.
With the new atomic types in Go 1.19, and the small types in use here,
atomics should be safe for our usage. bools should be truly atomic,
and uint32 is atomic when the accepted value range is less than one byte/uint8 since
that can be written atomicly by concurrent processors.
Maybe that's not even a factor any more, however we don't even have to worry enough to check.
  • Loading branch information
gbjk committed Feb 15, 2024
1 parent b65f102 commit 1c7441c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 114 deletions.
93 changes: 27 additions & 66 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
if s.ExchangeConfig.Features == nil {
return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil)
}
w.enabled = s.ExchangeConfig.Features.Enabled.Websocket
w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket)

if s.Connector == nil {
return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset)
Expand Down Expand Up @@ -384,9 +384,7 @@ func (w *Websocket) connectionMonitor() error {
if w.checkAndSetMonitorRunning() {
return errAlreadyRunning
}
w.fieldMutex.RLock()
delay := w.connectionMonitorDelay
w.fieldMutex.RUnlock()

go func() {
timer := time.NewTimer(delay)
Expand Down Expand Up @@ -614,104 +612,73 @@ func (w *Websocket) trafficMonitor() {
}()
}

// IsInitialised returns whether the websocket has been Setup() already
func (w *Websocket) IsInitialised() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.state != uninitialised
func (w *Websocket) setState(s uint32) {
w.state.Store(s)
}

func (w *Websocket) setState(s state) {
w.fieldMutex.Lock()
w.state = s
w.fieldMutex.Unlock()
// IsInitialised returns whether the websocket has been Setup() already
func (w *Websocket) IsInitialised() bool {
return w.state.Load() != uninitialised
}

// IsConnected returns whether the websocket is connected
func (w *Websocket) IsConnected() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.state == connected
return w.state.Load() == connected
}

// IsConnecting returns whether the websocket is connecting
func (w *Websocket) IsConnecting() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.state == connecting
return w.state.Load() == connecting
}

func (w *Websocket) setEnabled(b bool) {
w.fieldMutex.Lock()
w.enabled = b
w.fieldMutex.Unlock()
w.enabled.Store(b)
}

// IsEnabled returns whether the websocket is enabled
func (w *Websocket) IsEnabled() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.enabled
return w.enabled.Load()
}

func (w *Websocket) setTrafficMonitorRunning(b bool) {
w.fieldMutex.Lock()
w.trafficMonitorRunning = b
w.fieldMutex.Unlock()
w.trafficMonitorRunning.Store(b)
}

// IsTrafficMonitorRunning returns status of the traffic monitor
func (w *Websocket) IsTrafficMonitorRunning() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.trafficMonitorRunning
return w.trafficMonitorRunning.Load()
}

func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) {
w.fieldMutex.Lock()
defer w.fieldMutex.Unlock()
if w.connectionMonitorRunning {
return true
}
w.connectionMonitorRunning = true
return false
return !w.connectionMonitorRunning.CompareAndSwap(false, true)
}

func (w *Websocket) setConnectionMonitorRunning(b bool) {
w.fieldMutex.Lock()
w.connectionMonitorRunning = b
w.fieldMutex.Unlock()
w.connectionMonitorRunning.Store(b)
}

// IsConnectionMonitorRunning returns status of connection monitor
func (w *Websocket) IsConnectionMonitorRunning() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.connectionMonitorRunning
return w.connectionMonitorRunning.Load()
}

func (w *Websocket) setDataMonitorRunning(b bool) {
w.fieldMutex.Lock()
w.dataMonitorRunning = b
w.fieldMutex.Unlock()
w.dataMonitorRunning.Store(b)
}

// IsDataMonitorRunning returns status of data monitor
func (w *Websocket) IsDataMonitorRunning() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.dataMonitorRunning
return w.dataMonitorRunning.Load()
}

// CanUseAuthenticatedWebsocketForWrapper Handles a common check to
// verify whether a wrapper can use an authenticated websocket endpoint
func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool {
if w.IsConnected() && w.CanUseAuthenticatedEndpoints() {
return true
} else if w.IsConnected() && !w.CanUseAuthenticatedEndpoints() {
log.Infof(log.WebsocketMgr,
WebsocketNotAuthenticatedUsingRest,
w.exchangeName)
if w.IsConnected() {
if w.CanUseAuthenticatedEndpoints() {
return true
}
log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName)
}
return false
}
Expand Down Expand Up @@ -990,20 +957,14 @@ func (w *Websocket) GetSubscriptions() []subscription.Subscription {
return subs
}

// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in
// a thread safe manner
func (w *Websocket) SetCanUseAuthenticatedEndpoints(val bool) {
w.fieldMutex.Lock()
defer w.fieldMutex.Unlock()
w.canUseAuthenticatedEndpoints = val
// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner
func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) {
w.canUseAuthenticatedEndpoints.Store(b)
}

// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in
// a thread safe manner
// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner
func (w *Websocket) CanUseAuthenticatedEndpoints() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.canUseAuthenticatedEndpoints
return w.canUseAuthenticatedEndpoints.Load()
}

// IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error
Expand Down
77 changes: 39 additions & 38 deletions exchanges/stream/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ func TestConnectionMonitorNoConnection(t *testing.T) {
ws.ShutdownC = make(chan struct{}, 1)
ws.exchangeName = "hello"
ws.Wg = &sync.WaitGroup{}
ws.enabled = true
ws.setEnabled(true)
err := ws.connectionMonitor()
require.NoError(t, err, "connectionMonitor must not error")
assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true")
Expand Down Expand Up @@ -792,9 +792,10 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) {
assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false")

ws.setState(connected)
require.True(t, ws.IsConnected(), "IsConnected must return true")
assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false")

ws.canUseAuthenticatedEndpoints = true
ws.SetCanUseAuthenticatedEndpoints(true)
assert.True(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return true")
}

Expand Down Expand Up @@ -951,9 +952,7 @@ func TestFlushChannels(t *testing.T) {
err = dodgyWs.FlushChannels()
assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly")

web := Websocket{
enabled: true,
state: connected,
w := Websocket{
connector: connect,
ShutdownC: make(chan struct{}),
Subscriber: newgen.SUBME,
Expand All @@ -966,6 +965,8 @@ func TestFlushChannels(t *testing.T) {
// in FlushChannels() so the traffic monitor doesn't time out and turn
// this to an unconnected state
}
w.setEnabled(true)
w.setState(connected)

problemFunc := func() ([]subscription.Subscription, error) {
return nil, errDastardlyReason
Expand All @@ -978,40 +979,40 @@ func TestFlushChannels(t *testing.T) {
// Disable pair and flush system
newgen.EnabledPairs = []currency.Pair{
currency.NewPair(currency.BTC, currency.AUD)}
web.GenerateSubs = func() ([]subscription.Subscription, error) {
w.GenerateSubs = func() ([]subscription.Subscription, error) {
return []subscription.Subscription{{Channel: "test"}}, nil
}
err = web.FlushChannels()
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")

web.features.FullPayloadSubscribe = true
web.GenerateSubs = problemFunc
err = web.FlushChannels() // error on full subscribeToChannels
w.features.FullPayloadSubscribe = true
w.GenerateSubs = problemFunc
err = w.FlushChannels() // error on full subscribeToChannels
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly")

web.GenerateSubs = noSub
err = web.FlushChannels() // No subs to unsub
w.GenerateSubs = noSub
err = w.FlushChannels() // No subs to unsub
assert.NoError(t, err, "FlushChannels should not error")

web.GenerateSubs = newgen.generateSubs
subs, err := web.GenerateSubs()
w.GenerateSubs = newgen.generateSubs
subs, err := w.GenerateSubs()
require.NoError(t, err, "GenerateSubs must not error")

web.AddSuccessfulSubscriptions(subs...)
err = web.FlushChannels()
w.AddSuccessfulSubscriptions(subs...)
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
web.features.FullPayloadSubscribe = false
web.features.Subscribe = true
w.features.FullPayloadSubscribe = false
w.features.Subscribe = true

web.GenerateSubs = problemFunc
err = web.FlushChannels()
w.GenerateSubs = problemFunc
err = w.FlushChannels()
assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly")

web.GenerateSubs = newgen.generateSubs
err = web.FlushChannels()
w.GenerateSubs = newgen.generateSubs
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
web.subscriptionMutex.Lock()
web.subscriptions = subscriptionMap{
w.subscriptionMutex.Lock()
w.subscriptions = subscriptionMap{
41: {
Key: 41,
Channel: "match channel",
Expand All @@ -1023,34 +1024,34 @@ func TestFlushChannels(t *testing.T) {
Pair: currency.NewPair(currency.THETA, currency.USDT),
},
}
web.subscriptionMutex.Unlock()
w.subscriptionMutex.Unlock()

err = web.FlushChannels()
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")

err = web.FlushChannels()
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")

web.setState(connected)
web.features.Unsubscribe = true
err = web.FlushChannels()
w.setState(connected)
w.features.Unsubscribe = true
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
}

func TestDisable(t *testing.T) {
t.Parallel()
web := Websocket{
enabled: true,
state: connected,
w := Websocket{
ShutdownC: make(chan struct{}),
}
require.NoError(t, web.Disable(), "Disable must not error")
assert.ErrorIs(t, web.Disable(), ErrAlreadyDisabled, "Disable should error correctly")
w.setEnabled(true)
w.setState(connected)
require.NoError(t, w.Disable(), "Disable must not error")
assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly")
}

func TestEnable(t *testing.T) {
t.Parallel()
web := Websocket{
w := Websocket{
connector: connect,
Wg: new(sync.WaitGroup),
ShutdownC: make(chan struct{}),
Expand All @@ -1060,8 +1061,8 @@ func TestEnable(t *testing.T) {
Subscriber: func(cs []subscription.Subscription) error { return nil },
}

require.NoError(t, web.Enable(), "Enable must not error")
assert.ErrorIs(t, web.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly")
require.NoError(t, w.Enable(), "Enable must not error")
assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly")
}

func TestSetupNewConnection(t *testing.T) {
Expand Down
18 changes: 8 additions & 10 deletions exchanges/stream/websocket_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package stream

import (
"sync"
"sync/atomic"
"time"

"github.com/gorilla/websocket"
Expand All @@ -23,10 +24,8 @@ const (

type subscriptionMap map[any]*subscription.Subscription

type state int

const (
uninitialised state = iota
uninitialised uint32 = iota
disconnected
connecting
connected
Expand All @@ -35,13 +34,13 @@ const (
// Websocket defines a return type for websocket connections via the interface
// wrapper for routine processing
type Websocket struct {
canUseAuthenticatedEndpoints bool
enabled bool
state state
canUseAuthenticatedEndpoints atomic.Bool
enabled atomic.Bool
state atomic.Uint32
verbose bool
connectionMonitorRunning bool
trafficMonitorRunning bool
dataMonitorRunning bool
connectionMonitorRunning atomic.Bool
trafficMonitorRunning atomic.Bool
dataMonitorRunning atomic.Bool
trafficTimeout time.Duration
connectionMonitorDelay time.Duration
proxyAddr string
Expand All @@ -51,7 +50,6 @@ type Websocket struct {
runningURLAuth string
exchangeName string
m sync.Mutex
fieldMutex sync.RWMutex
connector func() error

subscriptionMutex sync.RWMutex
Expand Down

0 comments on commit 1c7441c

Please sign in to comment.