Skip to content

Commit

Permalink
Websocket: Simplify Connecting/Connected state
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Jan 30, 2024
1 parent e8395a5 commit 47f0df3
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 70 deletions.
37 changes: 17 additions & 20 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ var (
errChannelAlreadySubscribed = errors.New("channel already subscribed")
errInvalidChannelState = errors.New("invalid Channel state")
errSameProxyAddress = errors.New("cannot set proxy address to the same address")
errNoConnectFunc = errors.New("connect func not set")
errAlreadyConnected = errors.New("already connected")
)

var globalReporter Reporter
Expand All @@ -75,8 +77,8 @@ func SetupGlobalReporter(r Reporter) {
globalReporter = r
}

// New initialises the websocket struct
func New() *Websocket {
// NewWebsocket initialises the websocket struct
func NewWebsocket() *Websocket {
return &Websocket{
DataHandler: make(chan interface{}, defaultJobBuffer),
ToRoutine: make(chan interface{}, defaultJobBuffer),
Expand Down Expand Up @@ -188,6 +190,8 @@ func (w *Websocket) Setup(s *WebsocketSetup) error {
return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions)
}
w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection
w.setState(disconnected)

return nil
}

Expand Down Expand Up @@ -253,7 +257,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error {
// function
func (w *Websocket) Connect() error {
if w.connector == nil {
return errors.New("websocket connect function not set, cannot continue")
return errNoConnectFunc
}
w.m.Lock()
defer w.m.Unlock()
Expand All @@ -270,15 +274,14 @@ func (w *Websocket) Connect() error {

w.dataMonitor()
w.trafficMonitor()
w.setConnectingStatus(true)
w.setState(connecting)

err := w.connector()
if err != nil {
w.setConnectingStatus(false)
w.setState(disconnected)
return fmt.Errorf("%v Error connecting %w", w.exchangeName, err)
}
w.setConnectedStatus(true)
w.setConnectingStatus(false)
w.setState(connected)

if !w.IsConnectionMonitorRunning() {
err = w.connectionMonitor()
Expand Down Expand Up @@ -306,6 +309,7 @@ func (w *Websocket) Connect() error {
}

// Disable disables the exchange websocket protocol
// Note that connectionMonitor will be responsible for shutting down the websocket after disabling
func (w *Websocket) Disable() error {
if !w.IsEnabled() {
return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName)
Expand Down Expand Up @@ -405,7 +409,7 @@ func (w *Websocket) connectionMonitor() error {
case err := <-w.ReadMessageErrors:
if IsDisconnectionError(err) {
log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err)
w.setConnectedStatus(false)
w.setState(disconnected)
}

w.DataHandler <- err
Expand Down Expand Up @@ -470,8 +474,7 @@ func (w *Websocket) Shutdown() error {
close(w.ShutdownC)
w.Wg.Wait()
w.ShutdownC = make(chan struct{})
w.setConnectedStatus(false)
w.setConnectingStatus(false)
w.setState(disconnected)
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName)
}
Expand Down Expand Up @@ -565,7 +568,7 @@ func (w *Websocket) trafficMonitor() {
default:
}
}
w.setConnectedStatus(true)
w.setState(connected)
trafficTimer.Reset(w.trafficTimeout)
case <-trafficTimer.C: // Falls through when timer runs out
if w.verbose {
Expand Down Expand Up @@ -606,9 +609,9 @@ func (w *Websocket) trafficMonitor() {
}()
}

func (w *Websocket) setConnectedStatus(b bool) {
func (w *Websocket) setState(s state) {
w.fieldMutex.Lock()
w.state = connected
w.state = s
w.fieldMutex.Unlock()
}

Expand All @@ -619,17 +622,11 @@ func (w *Websocket) IsConnected() bool {
return w.state == connected
}

func (w *Websocket) setConnectingStatus(b bool) {
w.fieldMutex.Lock()
w.state = connecting
w.fieldMutex.Unlock()
}

// IsConnecting returns status of connecting
func (w *Websocket) IsConnecting() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.connecting
return w.state == connecting
}

func (w *Websocket) setEnabled(b bool) {
Expand Down
86 changes: 39 additions & 47 deletions exchanges/stream/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestSetup(t *testing.T) {
t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised)
}

w.Init = true
w.setState(disconnected)
err = w.Setup(websocketSetup)
if !errors.Is(err, errExchangeConfigIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil)
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestSetup(t *testing.T) {

func TestTrafficMonitorTimeout(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()
if err := ws.Setup(defaultSetup); err != nil {
t.Fatal(err)
}
Expand All @@ -232,7 +232,7 @@ func TestTrafficMonitorTimeout(t *testing.T) {
t.Fatal("traffic monitor should be running")
}
// prevent shutdown routine
ws.setConnectedStatus(false)
ws.setState(disconnected)
// await timeout closure
ws.Wg.Wait()
if ws.IsTrafficMonitorRunning() {
Expand Down Expand Up @@ -284,21 +284,21 @@ func TestConnectionMessageErrors(t *testing.T) {
}

wsWrong.setEnabled(true)
wsWrong.setConnectingStatus(true)
wsWrong.setState(connecting)
wsWrong.Wg = &sync.WaitGroup{}
err = wsWrong.Connect()
if err == nil {
t.Fatal("error cannot be nil")
}

wsWrong.setConnectedStatus(false)
wsWrong.setState(disconnected)
wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") }
err = wsWrong.Connect()
if err == nil {
t.Fatal("error cannot be nil")
}

ws := *New()
ws := NewWebsocket()
err = ws.Setup(defaultSetup)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -345,31 +345,35 @@ outer:

func TestWebsocket(t *testing.T) {
t.Parallel()
wsInit := Websocket{}
err := wsInit.Setup(&WebsocketSetup{
ExchangeConfig: &config.Exchange{
Features: &config.FeaturesConfig{
Enabled: config.FeaturesEnabledConfig{Websocket: true},
},
Name: "test",
},
})
assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "SetProxyAddress should error correctly")

ws := *New()
err = ws.SetProxyAddress("garbagio")
ws := NewWebsocket()

err := ws.SetProxyAddress("garbagio")
assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly")

ws.Conn = &WebsocketConnection{}
ws.AuthConn = &WebsocketConnection{}
ws.setEnabled(true)

err = ws.Setup(defaultSetup) // Sets to enabled again
require.NoError(t, err, "Setup may not error")

err = ws.Setup(defaultSetup)
assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly if called twice")

assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly")
assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup")

ws.setEnabled(false)
assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)")

ws.setEnabled(true)
assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)")

err = ws.SetProxyAddress("https://192.168.0.1:1337")
assert.NoError(t, err, "SetProxyAddress should not error when not yet connected")

ws.setConnectedStatus(true)
ws.ShutdownC = make(chan struct{})
ws.Wg = &sync.WaitGroup{}
ws.setState(connected)

err = ws.SetProxyAddress("https://192.168.0.1:1336")
assert.ErrorIs(t, err, errNoConnectFunc, "SetProxyAddress should call Connect and error from there") // This test asserts we actually set the proxy address, etc
Expand All @@ -386,39 +390,27 @@ func TestWebsocket(t *testing.T) {
err = ws.SetProxyAddress("http://localhost:1337")
assert.NoError(t, err, "SetProxyAddress should not error")

err = ws.Setup(defaultSetup) // Sets to enabled again
require.NoError(t, err, "Setup may not error")

assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly")
assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup")

ws.setEnabled(false)
assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)")

ws.setEnabled(true)
assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)")

assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly")
assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly")
assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly")

err = ws.Shutdown()
assert.ErrorIs(t, err, ErrNotConnected, "Shutdown should error when not Connected")

ws.setConnectedStatus(true)
ws.setState(connected)
ws.Conn = &dodgyConnection{}
err = ws.Shutdown()
assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy conn")

ws.Conn = &WebsocketConnection{}

ws.setConnectedStatus(true)
ws.setState(connected)
ws.AuthConn = &dodgyConnection{}
err = ws.Shutdown()
assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn")

ws.AuthConn = &WebsocketConnection{}
ws.setConnectedStatus(false)
ws.setState(disconnected)

err = ws.Connect()
assert.NoError(t, err, "Connect should not error")
Expand Down Expand Up @@ -456,7 +448,7 @@ func TestWebsocket(t *testing.T) {
// TestSubscribe logic test
func TestSubscribeUnsubscribe(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")

fnSub := func(subs []subscription.Subscription) error {
Expand Down Expand Up @@ -501,7 +493,7 @@ func TestSubscribeUnsubscribe(t *testing.T) {
// TestResubscribe tests Resubscribing to existing subscriptions
func TestResubscribe(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()

wackedOutSetup := *defaultSetup
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
Expand Down Expand Up @@ -532,7 +524,7 @@ func TestResubscribe(t *testing.T) {
// TestSubscriptionState tests Subscription state changes
func TestSubscriptionState(t *testing.T) {
t.Parallel()
ws := New()
ws := NewWebsocket()

c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState}
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error")
Expand All @@ -558,7 +550,7 @@ func TestSubscriptionState(t *testing.T) {
// TestRemoveSubscriptions tests removing a subscription
func TestRemoveSubscriptions(t *testing.T) {
t.Parallel()
ws := New()
ws := NewWebsocket()

c := &subscription.Subscription{Key: 42, Channel: "Unite!"}
assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
Expand All @@ -571,7 +563,7 @@ func TestRemoveSubscriptions(t *testing.T) {
// TestConnectionMonitorNoConnection logic test
func TestConnectionMonitorNoConnection(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()
ws.connectionMonitorDelay = 500
ws.DataHandler = make(chan interface{}, 1)
ws.ShutdownC = make(chan struct{}, 1)
Expand Down Expand Up @@ -626,7 +618,7 @@ func TestGetSubscriptions(t *testing.T) {
// TestSetCanUseAuthenticatedEndpoints logic test
func TestSetCanUseAuthenticatedEndpoints(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()
result := ws.CanUseAuthenticatedEndpoints()
if result {
t.Error("expected `canUseAuthenticatedEndpoints` to be false")
Expand Down Expand Up @@ -925,7 +917,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) {
if resp {
t.Error("Expected false, `connected` is false")
}
ws.setConnectedStatus(true)
ws.setState(connected)
resp = ws.CanUseAuthenticatedWebsocketForWrapper()
if resp {
t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false")
Expand Down Expand Up @@ -1109,7 +1101,7 @@ func TestFlushChannels(t *testing.T) {

web := Websocket{
enabled: true,
connected: true,
state: connected,
connector: connect,
ShutdownC: make(chan struct{}),
Subscriber: newgen.SUBME,
Expand Down Expand Up @@ -1204,7 +1196,7 @@ func TestFlushChannels(t *testing.T) {
t.Fatal(err)
}

web.setConnectedStatus(true)
web.setState(connected)
web.features.Unsubscribe = true
err = web.FlushChannels()
if err != nil {
Expand All @@ -1216,7 +1208,7 @@ func TestDisable(t *testing.T) {
t.Parallel()
web := Websocket{
enabled: true,
connected: true,
state: connected,
ShutdownC: make(chan struct{}),
}
err := web.Disable()
Expand Down Expand Up @@ -1284,7 +1276,7 @@ func TestSetupNewConnection(t *testing.T) {
connector: connect,
Wg: new(sync.WaitGroup),
ShutdownC: make(chan struct{}),
Init: true,
state: disconnected,
TrafficAlert: make(chan struct{}),
ReadMessageErrors: make(chan error),
DataHandler: make(chan interface{}),
Expand Down
6 changes: 3 additions & 3 deletions exchanges/stream/websocket_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ const (

type subscriptionMap map[any]*subscription.Subscription

type State int
type state int

const (
uninitialised State = iota
uninitialised state = iota
disconnected
connecting
connected
Expand All @@ -39,7 +39,7 @@ const (
type Websocket struct {
canUseAuthenticatedEndpoints bool
enabled bool
state State
state state
connecting bool
verbose bool
connectionMonitorRunning bool
Expand Down

0 comments on commit 47f0df3

Please sign in to comment.