diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 82d8f3f1847..d0029b208db 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -190,7 +190,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions) } w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection - w.setState(disconnected) + w.setState(disconnectedState) return nil } @@ -279,14 +279,14 @@ func (w *Websocket) Connect() error { w.dataMonitor() w.trafficMonitor() - w.setState(connecting) + w.setState(connectingState) err := w.connector() if err != nil { - w.setState(disconnected) + w.setState(disconnectedState) return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } - w.setState(connected) + w.setState(connectedState) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() @@ -406,7 +406,7 @@ func (w *Websocket) connectionMonitor() error { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) - w.setState(disconnected) + w.setState(disconnectedState) } w.DataHandler <- err @@ -466,7 +466,7 @@ func (w *Websocket) Shutdown() error { // flush any subscriptions from last connection if needed w.subscriptions.Clear() - w.setState(disconnected) + w.setState(disconnectedState) close(w.ShutdownC) w.Wg.Wait() @@ -597,17 +597,17 @@ func (w *Websocket) setState(s uint32) { // IsInitialised returns whether the websocket has been Setup() already func (w *Websocket) IsInitialised() bool { - return w.state.Load() != uninitialised + return w.state.Load() != uninitialisedState } // IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { - return w.state.Load() == connected + return w.state.Load() == connectedState } // IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { - return w.state.Load() == connecting + return w.state.Load() == connectingState } func (w *Websocket) setEnabled(b bool) { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index bdf92a6f662..e2f25056bb0 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -193,13 +193,13 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { signal := struct{}{} patience := 10 * time.Millisecond ws.trafficTimeout = 200 * time.Millisecond - ws.state.Store(connected) + ws.state.Store(connectedState) thenish := time.Now() ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, connected, ws.state.Load(), "websocket must be connected") + require.Equal(t, connectedState, ws.state.Load(), "websocket must be connected") for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass select { @@ -225,7 +225,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { } require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") } @@ -237,16 +237,16 @@ func TestTrafficMonitorConnecting(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - ws.state.Store(connecting) + ws.state.Store(connectingState) ws.trafficTimeout = 50 * time.Millisecond ws.trafficMonitor() require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting") + require.Equal(t, connectingState, ws.state.Load(), "websocket must be connecting") <-time.After(4 * ws.trafficTimeout) - require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks") - ws.state.Store(connected) + require.Equal(t, connectingState, ws.state.Load(), "websocket must still be connecting after several checks") + ws.state.Store(connectedState) require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") } @@ -258,7 +258,7 @@ func TestTrafficMonitorShutdown(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - ws.state.Store(connected) + ws.state.Store(connectedState) ws.trafficTimeout = time.Minute ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") @@ -304,16 +304,16 @@ func TestConnectionMessageErrors(t *testing.T) { assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) - wsWrong.setState(connecting) + wsWrong.setState(connectingState) err = wsWrong.Connect() assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") - wsWrong.setState(disconnected) + wsWrong.setState(disconnectedState) err = wsWrong.Connect() assert.ErrorIs(t, err, common.ErrNilPointer, "Connect should get a nil pointer error, presumably on subs") wsWrong.subscriptions = subscription.NewStore() - wsWrong.setState(disconnected) + wsWrong.setState(disconnectedState) wsWrong.connector = func() error { return errDastardlyReason } err = wsWrong.Connect() assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") @@ -379,7 +379,7 @@ func TestWebsocket(t *testing.T) { err = ws.SetProxyAddress("https://192.168.0.1:1337") assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") - ws.setState(connected) + ws.setState(connectedState) err = ws.SetProxyAddress("https://192.168.0.1:1336") assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") @@ -402,14 +402,14 @@ func TestWebsocket(t *testing.T) { assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") - ws.setState(connected) + ws.setState(connectedState) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} - ws.setState(disconnected) + ws.setState(disconnectedState) err = ws.Connect() assert.NoError(t, err, "Connect should not error") @@ -853,7 +853,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { ws := &Websocket{} assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") - ws.setState(connected) + ws.setState(connectedState) require.True(t, ws.IsConnected(), "IsConnected must return true") assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") @@ -987,7 +987,7 @@ func TestFlushChannels(t *testing.T) { w.trafficTimeout = time.Second * 30 w.setEnabled(true) - w.setState(connected) + w.setState(connectedState) problemFunc := func() (subscription.List, error) { return nil, errDastardlyReason @@ -1042,7 +1042,7 @@ func TestFlushChannels(t *testing.T) { err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - w.setState(connected) + w.setState(connectedState) w.features.Unsubscribe = true err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") @@ -1052,7 +1052,7 @@ func TestDisable(t *testing.T) { t.Parallel() w := NewWebsocket() w.setEnabled(true) - w.setState(connected) + w.setState(connectedState) require.NoError(t, w.Disable(), "Disable must not error") assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 353003d0a21..707fc7dcb05 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -23,10 +23,10 @@ const ( ) const ( - uninitialised uint32 = iota - disconnected - connecting - connected + uninitialisedState uint32 = iota + disconnectedState + connectingState + connectedState ) // Websocket defines a return type for websocket connections via the interface