Skip to content

Commit

Permalink
Websocket: Simplify Connecting/Connected state
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Jan 29, 2024
1 parent e8395a5 commit 7dbb6cc
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 52 deletions.
33 changes: 14 additions & 19 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ var (
errChannelAlreadySubscribed = errors.New("channel already subscribed")
errInvalidChannelState = errors.New("invalid Channel state")
errSameProxyAddress = errors.New("cannot set proxy address to the same address")
errNoConnectFunc = errors.New("connect func not set")
errAlreadyConnected = errors.New("already connected")
)

var globalReporter Reporter
Expand All @@ -75,8 +77,8 @@ func SetupGlobalReporter(r Reporter) {
globalReporter = r
}

// New initialises the websocket struct
func New() *Websocket {
// NewWebsocket initialises the websocket struct
func NewWebsocket() *Websocket {
return &Websocket{
DataHandler: make(chan interface{}, defaultJobBuffer),
ToRoutine: make(chan interface{}, defaultJobBuffer),
Expand Down Expand Up @@ -270,15 +272,14 @@ func (w *Websocket) Connect() error {

w.dataMonitor()
w.trafficMonitor()
w.setConnectingStatus(true)
w.setState(connecting)

err := w.connector()
if err != nil {
w.setConnectingStatus(false)
w.setState(disconnected)
return fmt.Errorf("%v Error connecting %w", w.exchangeName, err)
}
w.setConnectedStatus(true)
w.setConnectingStatus(false)
w.setState(connected)

if !w.IsConnectionMonitorRunning() {
err = w.connectionMonitor()
Expand Down Expand Up @@ -306,6 +307,7 @@ func (w *Websocket) Connect() error {
}

// Disable disables the exchange websocket protocol
// Note that connectionMonitor will be responsible for shutting down the websocket after disabling
func (w *Websocket) Disable() error {
if !w.IsEnabled() {
return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName)
Expand Down Expand Up @@ -405,7 +407,7 @@ func (w *Websocket) connectionMonitor() error {
case err := <-w.ReadMessageErrors:
if IsDisconnectionError(err) {
log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err)
w.setConnectedStatus(false)
w.setState(disconnected)
}

w.DataHandler <- err
Expand Down Expand Up @@ -470,8 +472,7 @@ func (w *Websocket) Shutdown() error {
close(w.ShutdownC)
w.Wg.Wait()
w.ShutdownC = make(chan struct{})
w.setConnectedStatus(false)
w.setConnectingStatus(false)
w.setState(disconnected)
if w.verbose {
log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName)
}
Expand Down Expand Up @@ -565,7 +566,7 @@ func (w *Websocket) trafficMonitor() {
default:
}
}
w.setConnectedStatus(true)
w.setState(connected)
trafficTimer.Reset(w.trafficTimeout)
case <-trafficTimer.C: // Falls through when timer runs out
if w.verbose {
Expand Down Expand Up @@ -606,9 +607,9 @@ func (w *Websocket) trafficMonitor() {
}()
}

func (w *Websocket) setConnectedStatus(b bool) {
func (w *Websocket) setState(s state) {
w.fieldMutex.Lock()
w.state = connected
w.state = s
w.fieldMutex.Unlock()
}

Expand All @@ -619,17 +620,11 @@ func (w *Websocket) IsConnected() bool {
return w.state == connected
}

func (w *Websocket) setConnectingStatus(b bool) {
w.fieldMutex.Lock()
w.state = connecting
w.fieldMutex.Unlock()
}

// IsConnecting returns status of connecting
func (w *Websocket) IsConnecting() bool {
w.fieldMutex.RLock()
defer w.fieldMutex.RUnlock()
return w.connecting
return w.state == connecting
}

func (w *Websocket) setEnabled(b bool) {
Expand Down
65 changes: 35 additions & 30 deletions exchanges/stream/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestSetup(t *testing.T) {
t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised)
}

w.Init = true
w.setState(disconnected)
err = w.Setup(websocketSetup)
if !errors.Is(err, errExchangeConfigIsNil) {
t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil)
Expand Down Expand Up @@ -214,7 +214,7 @@ func TestSetup(t *testing.T) {

func TestTrafficMonitorTimeout(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()
if err := ws.Setup(defaultSetup); err != nil {
t.Fatal(err)
}
Expand All @@ -232,7 +232,7 @@ func TestTrafficMonitorTimeout(t *testing.T) {
t.Fatal("traffic monitor should be running")
}
// prevent shutdown routine
ws.setConnectedStatus(false)
ws.setState(disconnected)
// await timeout closure
ws.Wg.Wait()
if ws.IsTrafficMonitorRunning() {
Expand Down Expand Up @@ -284,21 +284,21 @@ func TestConnectionMessageErrors(t *testing.T) {
}

wsWrong.setEnabled(true)
wsWrong.setConnectingStatus(true)
wsWrong.setState(connecting)
wsWrong.Wg = &sync.WaitGroup{}
err = wsWrong.Connect()
if err == nil {
t.Fatal("error cannot be nil")
}

wsWrong.setConnectedStatus(false)
wsWrong.setState(disconnected)
wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") }
err = wsWrong.Connect()
if err == nil {
t.Fatal("error cannot be nil")
}

ws := *New()
ws := NewWebsocket()
err = ws.Setup(defaultSetup)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -346,17 +346,22 @@ outer:
func TestWebsocket(t *testing.T) {
t.Parallel()
wsInit := Websocket{}
err := wsInit.Setup(&WebsocketSetup{
ExchangeConfig: &config.Exchange{
Features: &config.FeaturesConfig{
Enabled: config.FeaturesEnabledConfig{Websocket: true},
s :=
&WebsocketSetup{
ExchangeConfig: &config.Exchange{
Features: &config.FeaturesConfig{
Enabled: config.FeaturesEnabledConfig{Websocket: true},
},
Name: "test",
},
Name: "test",
},
})
assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "SetProxyAddress should error correctly")
}
err := wsInit.Setup(s)
assert.NoError(t, err, "Setup should not error first time around")

err = wsInit.Setup(s)
assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly")

ws := *New()
ws := NewWebsocket()
err = ws.SetProxyAddress("garbagio")
assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly")

Expand All @@ -367,7 +372,7 @@ func TestWebsocket(t *testing.T) {
err = ws.SetProxyAddress("https://192.168.0.1:1337")
assert.NoError(t, err, "SetProxyAddress should not error when not yet connected")

ws.setConnectedStatus(true)
ws.setState(connected)
ws.ShutdownC = make(chan struct{})
ws.Wg = &sync.WaitGroup{}

Expand Down Expand Up @@ -405,20 +410,20 @@ func TestWebsocket(t *testing.T) {
err = ws.Shutdown()
assert.ErrorIs(t, err, ErrNotConnected, "Shutdown should error when not Connected")

ws.setConnectedStatus(true)
ws.setState(connected)
ws.Conn = &dodgyConnection{}
err = ws.Shutdown()
assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy conn")

ws.Conn = &WebsocketConnection{}

ws.setConnectedStatus(true)
ws.setState(connected)
ws.AuthConn = &dodgyConnection{}
err = ws.Shutdown()
assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn")

ws.AuthConn = &WebsocketConnection{}
ws.setConnectedStatus(false)
ws.setState(disconnected)

err = ws.Connect()
assert.NoError(t, err, "Connect should not error")
Expand Down Expand Up @@ -456,7 +461,7 @@ func TestWebsocket(t *testing.T) {
// TestSubscribe logic test
func TestSubscribeUnsubscribe(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")

fnSub := func(subs []subscription.Subscription) error {
Expand Down Expand Up @@ -501,7 +506,7 @@ func TestSubscribeUnsubscribe(t *testing.T) {
// TestResubscribe tests Resubscribing to existing subscriptions
func TestResubscribe(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()

wackedOutSetup := *defaultSetup
wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1
Expand Down Expand Up @@ -532,7 +537,7 @@ func TestResubscribe(t *testing.T) {
// TestSubscriptionState tests Subscription state changes
func TestSubscriptionState(t *testing.T) {
t.Parallel()
ws := New()
ws := NewWebsocket()

c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState}
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error")
Expand All @@ -558,7 +563,7 @@ func TestSubscriptionState(t *testing.T) {
// TestRemoveSubscriptions tests removing a subscription
func TestRemoveSubscriptions(t *testing.T) {
t.Parallel()
ws := New()
ws := NewWebsocket()

c := &subscription.Subscription{Key: 42, Channel: "Unite!"}
assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
Expand All @@ -571,7 +576,7 @@ func TestRemoveSubscriptions(t *testing.T) {
// TestConnectionMonitorNoConnection logic test
func TestConnectionMonitorNoConnection(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()
ws.connectionMonitorDelay = 500
ws.DataHandler = make(chan interface{}, 1)
ws.ShutdownC = make(chan struct{}, 1)
Expand Down Expand Up @@ -626,7 +631,7 @@ func TestGetSubscriptions(t *testing.T) {
// TestSetCanUseAuthenticatedEndpoints logic test
func TestSetCanUseAuthenticatedEndpoints(t *testing.T) {
t.Parallel()
ws := *New()
ws := NewWebsocket()
result := ws.CanUseAuthenticatedEndpoints()
if result {
t.Error("expected `canUseAuthenticatedEndpoints` to be false")
Expand Down Expand Up @@ -925,7 +930,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) {
if resp {
t.Error("Expected false, `connected` is false")
}
ws.setConnectedStatus(true)
ws.setState(connected)
resp = ws.CanUseAuthenticatedWebsocketForWrapper()
if resp {
t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false")
Expand Down Expand Up @@ -1109,7 +1114,7 @@ func TestFlushChannels(t *testing.T) {

web := Websocket{
enabled: true,
connected: true,
state: connected,
connector: connect,
ShutdownC: make(chan struct{}),
Subscriber: newgen.SUBME,
Expand Down Expand Up @@ -1204,7 +1209,7 @@ func TestFlushChannels(t *testing.T) {
t.Fatal(err)
}

web.setConnectedStatus(true)
web.setState(connected)
web.features.Unsubscribe = true
err = web.FlushChannels()
if err != nil {
Expand All @@ -1216,7 +1221,7 @@ func TestDisable(t *testing.T) {
t.Parallel()
web := Websocket{
enabled: true,
connected: true,
state: connected,
ShutdownC: make(chan struct{}),
}
err := web.Disable()
Expand Down Expand Up @@ -1284,7 +1289,7 @@ func TestSetupNewConnection(t *testing.T) {
connector: connect,
Wg: new(sync.WaitGroup),
ShutdownC: make(chan struct{}),
Init: true,
state: disconnected,
TrafficAlert: make(chan struct{}),
ReadMessageErrors: make(chan error),
DataHandler: make(chan interface{}),
Expand Down
6 changes: 3 additions & 3 deletions exchanges/stream/websocket_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ const (

type subscriptionMap map[any]*subscription.Subscription

type State int
type state int

const (
uninitialised State = iota
uninitialised state = iota
disconnected
connecting
connected
Expand All @@ -39,7 +39,7 @@ const (
type Websocket struct {
canUseAuthenticatedEndpoints bool
enabled bool
state State
state state
connecting bool
verbose bool
connectionMonitorRunning bool
Expand Down

0 comments on commit 7dbb6cc

Please sign in to comment.