Skip to content

Commit

Permalink
Websocket: Add suffix to state consts
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Feb 22, 2024
1 parent 2ecad08 commit 4b0d2f1
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 32 deletions.
18 changes: 9 additions & 9 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions)
}
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
w.setState(disconnected)
w.setState(disconnectedState)

return nil
}
Expand Down Expand Up @@ -279,14 +279,14 @@ func (w *Websocket) Connect() error {

w.dataMonitor()
w.trafficMonitor()
w.setState(connecting)
w.setState(connectingState)

err := w.connector()
if err != nil {
w.setState(disconnected)
w.setState(disconnectedState)
return fmt.Errorf("%v Error connecting %w", w.exchangeName, err)
}
w.setState(connected)
w.setState(connectedState)

if !w.IsConnectionMonitorRunning() {
err = w.connectionMonitor()
Expand Down Expand Up @@ -406,7 +406,7 @@ func (w *Websocket) connectionMonitor() error {
case err := <-w.ReadMessageErrors:
if IsDisconnectionError(err) {
log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err)
w.setState(disconnected)
w.setState(disconnectedState)
}

w.DataHandler <- err
Expand Down Expand Up @@ -466,7 +466,7 @@ func (w *Websocket) Shutdown() error {
// flush any subscriptions from last connection if needed
w.subscriptions.Clear()

w.setState(disconnected)
w.setState(disconnectedState)

close(w.ShutdownC)
w.Wg.Wait()
Expand Down Expand Up @@ -597,17 +597,17 @@ func (w *Websocket) setState(s uint32) {

// IsInitialised returns whether the websocket has been Setup() already
func (w *Websocket) IsInitialised() bool {
return w.state.Load() != uninitialised
return w.state.Load() != uninitialisedState
}

// IsConnected returns whether the websocket is connected
func (w *Websocket) IsConnected() bool {
return w.state.Load() == connected
return w.state.Load() == connectedState
}

// IsConnecting returns whether the websocket is connecting
func (w *Websocket) IsConnecting() bool {
return w.state.Load() == connecting
return w.state.Load() == connectingState
}

func (w *Websocket) setEnabled(b bool) {
Expand Down
38 changes: 19 additions & 19 deletions exchanges/stream/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,13 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) {
signal := struct{}{}
patience := 10 * time.Millisecond
ws.trafficTimeout = 200 * time.Millisecond
ws.state.Store(connected)
ws.state.Store(connectedState)

thenish := time.Now()
ws.trafficMonitor()

assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
require.Equal(t, connected, ws.state.Load(), "websocket must be connected")
require.Equal(t, connectedState, ws.state.Load(), "websocket must be connected")

for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass
select {
Expand Down Expand Up @@ -228,7 +228,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) {
}

require.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected")
assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected")
assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down")
}, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts")
}
Expand All @@ -240,16 +240,16 @@ func TestTrafficMonitorConnecting(t *testing.T) {
err := ws.Setup(defaultSetup)
require.NoError(t, err, "Setup must not error")

ws.state.Store(connecting)
ws.state.Store(connectingState)
ws.trafficTimeout = 50 * time.Millisecond
ws.trafficMonitor()
require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting")
require.Equal(t, connectingState, ws.state.Load(), "websocket must be connecting")
<-time.After(4 * ws.trafficTimeout)
require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks")
ws.state.Store(connected)
require.Equal(t, connectingState, ws.state.Load(), "websocket must still be connecting after several checks")
ws.state.Store(connectedState)
require.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected")
assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected")
assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down")
}, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes")
}
Expand All @@ -261,7 +261,7 @@ func TestTrafficMonitorShutdown(t *testing.T) {
err := ws.Setup(defaultSetup)
require.NoError(t, err, "Setup must not error")

ws.state.Store(connected)
ws.state.Store(connectedState)
ws.trafficTimeout = time.Minute
ws.trafficMonitor()
assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running")
Expand Down Expand Up @@ -307,16 +307,16 @@ func TestConnectionMessageErrors(t *testing.T) {
assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly")

wsWrong.setEnabled(true)
wsWrong.setState(connecting)
wsWrong.setState(connectingState)
err = wsWrong.Connect()
assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly")

wsWrong.setState(disconnected)
wsWrong.setState(disconnectedState)
err = wsWrong.Connect()
assert.ErrorIs(t, err, common.ErrNilPointer, "Connect should get a nil pointer error, presumably on subs")

wsWrong.subscriptions = subscription.NewStore()
wsWrong.setState(disconnected)
wsWrong.setState(disconnectedState)
wsWrong.connector = func() error { return errDastardlyReason }
err = wsWrong.Connect()
assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly")
Expand Down Expand Up @@ -382,7 +382,7 @@ func TestWebsocket(t *testing.T) {
err = ws.SetProxyAddress("https://192.168.0.1:1337")
assert.NoError(t, err, "SetProxyAddress should not error when not yet connected")

ws.setState(connected)
ws.setState(connectedState)

err = ws.SetProxyAddress("https://192.168.0.1:1336")
assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there")
Expand All @@ -405,14 +405,14 @@ func TestWebsocket(t *testing.T) {
assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly")
assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly")

ws.setState(connected)
ws.setState(connectedState)
ws.AuthConn = &dodgyConnection{}
err = ws.Shutdown()
assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn")
assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn")

ws.AuthConn = &WebsocketConnection{}
ws.setState(disconnected)
ws.setState(disconnectedState)

err = ws.Connect()
assert.NoError(t, err, "Connect should not error")
Expand Down Expand Up @@ -859,7 +859,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) {
ws := &Websocket{}
assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false")

ws.setState(connected)
ws.setState(connectedState)
require.True(t, ws.IsConnected(), "IsConnected must return true")
assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false")

Expand Down Expand Up @@ -1021,7 +1021,7 @@ func TestFlushChannels(t *testing.T) {
w.trafficTimeout = time.Second * 30

w.setEnabled(true)
w.setState(connected)
w.setState(connectedState)

problemFunc := func() (subscription.List, error) {
return nil, errDastardlyReason
Expand Down Expand Up @@ -1077,7 +1077,7 @@ func TestFlushChannels(t *testing.T) {
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")

w.setState(connected)
w.setState(connectedState)
w.features.Unsubscribe = true
err = w.FlushChannels()
assert.NoError(t, err, "FlushChannels should not error")
Expand All @@ -1087,7 +1087,7 @@ func TestDisable(t *testing.T) {
t.Parallel()
w := NewWebsocket()
w.setEnabled(true)
w.setState(connected)
w.setState(connectedState)
require.NoError(t, w.Disable(), "Disable must not error")
assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly")
}
Expand Down
8 changes: 4 additions & 4 deletions exchanges/stream/websocket_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ const (
)

const (
uninitialised uint32 = iota
disconnected
connecting
connected
uninitialisedState uint32 = iota
disconnectedState
connectingState
connectedState
)

// Websocket defines a return type for websocket connections via the interface
Expand Down

0 comments on commit 4b0d2f1

Please sign in to comment.