From ffe3a00bbd79e7c47db72209657ea7fc5dbbacd3 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Fri, 26 Jan 2024 14:32:45 +0700 Subject: [PATCH 01/15] Websocket: Remove IsInit and simplify SetProxyAddress IsInit was basically the same as IsConnected. Any time Connect was called both would be set to true. Any time we had a disconnect they'd both be set to false Shutdown() incorrectly didn't setInit(false) SetProxyAddress simplified to only reconnect a connected Websocket. Any other state means it hasn't been Connected, or it's about to reconnect anyway. There's no handling for IsConnecting previously, either, so I've wrapped that behind the main mutex. --- exchanges/stream/websocket.go | 104 ++++++++++------------------------ 1 file changed, 30 insertions(+), 74 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index aefc3400f60..b49835bac80 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -286,7 +286,6 @@ func (w *Websocket) Connect() error { } w.setConnectedStatus(true) w.setConnectingStatus(false) - w.setInit(true) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() @@ -369,9 +368,7 @@ func (w *Websocket) dataMonitor() { case <-w.ShutdownC: return default: - log.Warnf(log.WebsocketMgr, - "%s exchange backlog in websocket processing detected", - w.exchangeName) + log.Warnf(log.WebsocketMgr, "%s exchange backlog in websocket processing detected", w.exchangeName) select { case w.ToRoutine <- d: case <-w.ShutdownC: @@ -396,26 +393,19 @@ func (w *Websocket) connectionMonitor() error { timer := time.NewTimer(delay) for { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: running connection monitor cycle\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) } if !w.IsEnabled() { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connectionMonitor - websocket disabled, shutting down\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) } if w.IsConnected() { - err := w.Shutdown() - if err != nil { + if err := w.Shutdown(); err != nil { log.Errorln(log.WebsocketMgr, err) } } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connection monitor exiting\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) } timer.Stop() w.setConnectionMonitorRunning(false) @@ -424,10 +414,7 @@ func (w *Websocket) connectionMonitor() error { select { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { - w.setInit(false) - log.Warnf(log.WebsocketMgr, - "%v websocket has been disconnected. Reason: %v", - w.exchangeName, err) + log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) w.setConnectedStatus(false) } @@ -459,21 +446,16 @@ func (w *Websocket) Shutdown() error { defer w.m.Unlock() if !w.IsConnected() { - return fmt.Errorf("%v websocket: cannot shutdown %w", - w.exchangeName, - ErrNotConnected) + return fmt.Errorf("%v websocket: cannot shutdown %w", w.exchangeName, ErrNotConnected) } // TODO: Interrupt connection and or close connection when it is re-established. if w.IsConnecting() { - return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", - w.exchangeName) + return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", w.exchangeName) } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: shutting down websocket\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", w.exchangeName) } defer w.Orderbook.FlushBuffer() @@ -501,9 +483,7 @@ func (w *Websocket) Shutdown() error { w.setConnectedStatus(false) w.setConnectingStatus(false) if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: completed websocket shutdown\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } return nil } @@ -582,9 +562,7 @@ func (w *Websocket) trafficMonitor() { select { case <-w.ShutdownC: if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown message received\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) } trafficTimer.Stop() w.setTrafficMonitorRunning(false) @@ -601,10 +579,7 @@ func (w *Websocket) trafficMonitor() { trafficTimer.Reset(w.trafficTimeout) case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { - log.Warnf(log.WebsocketMgr, - "%v websocket: has not received a traffic alert in %v. Reconnecting", - w.exchangeName, - w.trafficTimeout) + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) } trafficTimer.Stop() w.setTrafficMonitorRunning(false) @@ -612,9 +587,7 @@ func (w *Websocket) trafficMonitor() { if !w.IsConnecting() && w.IsConnected() { err := w.Shutdown() if err != nil { - log.Errorf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown err: %s", - w.exchangeName, err) + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) } } @@ -682,19 +655,6 @@ func (w *Websocket) IsEnabled() bool { return w.enabled } -func (w *Websocket) setInit(b bool) { - w.fieldMutex.Lock() - w.Init = b - w.fieldMutex.Unlock() -} - -// IsInit returns status of init -func (w *Websocket) IsInit() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.Init -} - func (w *Websocket) setTrafficMonitorRunning(b bool) { w.fieldMutex.Lock() w.trafficMonitorRunning = b @@ -820,28 +780,22 @@ func (w *Websocket) GetWebsocketURL() string { // SetProxyAddress sets websocket proxy address func (w *Websocket) SetProxyAddress(proxyAddr string) error { + w.m.Lock() + if proxyAddr != "" { - _, err := url.ParseRequestURI(proxyAddr) - if err != nil { - return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", - w.exchangeName, - err) + if _, err := url.ParseRequestURI(proxyAddr); err != nil { + w.m.Unlock() + return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", w.exchangeName, err) } if w.proxyAddr == proxyAddr { - return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", - w.exchangeName, - w.proxyAddr) + w.m.Unlock() + return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", w.exchangeName, w.proxyAddr) } - log.Debugf(log.ExchangeSys, - "%s websocket: setting websocket proxy: %s\n", - w.exchangeName, - proxyAddr) + log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr) } else { - log.Debugf(log.ExchangeSys, - "%s websocket: removing websocket proxy\n", - w.exchangeName) + log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) } if w.Conn != nil { @@ -852,15 +806,17 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { } w.proxyAddr = proxyAddr - if w.IsInit() && w.IsEnabled() { - if w.IsConnected() { - err := w.Shutdown() - if err != nil { - return err - } + + if w.IsConnected() { + w.m.Unlock() + if err := w.Shutdown(); err != nil { + return err } return w.Connect() } + + w.m.Unlock() + return nil } From 2e086aeead674b99acb05e53f13426279b7d5a02 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 27 Jan 2024 12:13:49 +0700 Subject: [PATCH 02/15] Websocket: Expand and Assertify tests --- exchanges/stream/websocket.go | 19 ++-- exchanges/stream/websocket_test.go | 141 ++++++++++------------------- 2 files changed, 55 insertions(+), 105 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index b49835bac80..bb23c9a8464 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -64,6 +64,7 @@ var ( errNoSubscriptionsSupplied = errors.New("no subscriptions supplied") errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") ) var globalReporter Reporter @@ -262,12 +263,10 @@ func (w *Websocket) Connect() error { return errors.New(WebsocketNotEnabled) } if w.IsConnecting() { - return fmt.Errorf("%v Websocket already attempting to connect", - w.exchangeName) + return fmt.Errorf("%v Websocket already attempting to connect", w.exchangeName) } if w.IsConnected() { - return fmt.Errorf("%v Websocket already connected", - w.exchangeName) + return fmt.Errorf("%v Websocket already connected", w.exchangeName) } w.subscriptionMutex.Lock() @@ -281,8 +280,7 @@ func (w *Websocket) Connect() error { err := w.connector() if err != nil { w.setConnectingStatus(false) - return fmt.Errorf("%v Error connecting %s", - w.exchangeName, err) + return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } w.setConnectedStatus(true) w.setConnectingStatus(false) @@ -290,10 +288,7 @@ func (w *Websocket) Connect() error { if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() if err != nil { - log.Errorf(log.WebsocketMgr, - "%s cannot start websocket connection monitor %v", - w.GetName(), - err) + log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) } } @@ -785,12 +780,12 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { if proxyAddr != "" { if _, err := url.ParseRequestURI(proxyAddr); err != nil { w.m.Unlock() - return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", w.exchangeName, err) + return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err) } if w.proxyAddr == proxyAddr { w.m.Unlock() - return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", w.exchangeName, w.proxyAddr) + return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr) } log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3ab49e0df10..650509b11d1 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -30,6 +31,10 @@ const ( proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) +var ( + errDastardlyReason = errors.New("cannot shutdown due to some dastardly reason") +) + var dialer websocket.Dialer type testStruct struct { @@ -92,12 +97,12 @@ type dodgyConnection struct { // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Shutdown() error { - return errors.New("cannot shutdown due to some dastardly reason") + return errDastardlyReason } // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Connect() error { - return errors.New("cannot connect due to some dastardly reason") + return errDastardlyReason } func TestSetup(t *testing.T) { @@ -349,152 +354,102 @@ func TestWebsocket(t *testing.T) { Name: "test", }, }) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } + assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "SetProxyAddress should error correctly") ws := *New() err = ws.SetProxyAddress("garbagio") - if err == nil { - t.Error("error cannot be nil") - } + assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") ws.Conn = &WebsocketConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) + err = ws.SetProxyAddress("https://192.168.0.1:1337") - if err == nil { - t.Error("error cannot be nil") - } + assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") + ws.setConnectedStatus(true) ws.ShutdownC = make(chan struct{}) ws.Wg = &sync.WaitGroup{} + err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } + assert.ErrorIs(t, err, errNoConnectFunc, "SetProxyAddress should call Connect and error from there") // This test asserts we actually set the proxy address, etc err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } + assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") ws.setEnabled(false) // removing proxy err = ws.SetProxyAddress("") - if err != nil { - t.Error(err) - } + assert.NoError(t, err, "SetProxyAddress should not error when removing proxy") + // reinstate proxy err = ws.SetProxyAddress("http://localhost:1337") - if err != nil { - t.Error(err) - } - // conflict proxy - err = ws.SetProxyAddress("http://localhost:1337") - if err == nil { - t.Error("error cannot be nil") - } - err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } - if ws.GetName() != "exchangeName" { - t.Error("WebsocketSetup") - } + assert.NoError(t, err, "SetProxyAddress should not error") - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } + err = ws.Setup(defaultSetup) // Sets to enabled again + require.NoError(t, err, "Setup may not error") + + assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly") + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") ws.setEnabled(false) - if ws.IsEnabled() { - t.Error("WebsocketSetup") - } + assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") + ws.setEnabled(true) - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") - if ws.GetProxyAddress() != "http://localhost:1337" { - t.Error("WebsocketSetup") - } + assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") + assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") + assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") - if ws.GetWebsocketURL() != "wss://testRunningURL" { - t.Error("WebsocketSetup") - } - if ws.trafficTimeout != time.Second*5 { - t.Error("WebsocketSetup") - } - // -- Not connected shutdown err = ws.Shutdown() - if err == nil { - t.Fatal("should not be connected to able to shut down") - } + assert.ErrorIs(t, err, ErrNotConnected, "Shutdown should error when not Connected") ws.setConnectedStatus(true) ws.Conn = &dodgyConnection{} err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy conn") ws.Conn = &WebsocketConnection{} ws.setConnectedStatus(true) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil ") - } + assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} ws.setConnectedStatus(false) - // -- Normal connect err = ws.Connect() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Connect should not error") ws.defaultURL = "ws://demos.kaazing.com/echo" ws.defaultURLAuth = "ws://demos.kaazing.com/echo" err = ws.SetWebsocketURL("", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("", true, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, false) - if err != nil { - t.Fatal(err) - } - // Attempt reconnect + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, true) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error on reconnect") + // -- initiate the reconnect which is usually handled by connection monitor err = ws.Connect() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "ReConnect called manually should not error") + err = ws.Connect() - if err == nil { - t.Fatal("should already be connected") - } - // -- Normal shutdown + assert.ErrorIs(t, err, errAlreadyConnected, "ReConnect should error when already connected") + err = ws.Shutdown() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Shutdown should not error") ws.Wg.Wait() } From 30f36ca15ad56c9506f0d4f747101029dc6572fe Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sun, 28 Jan 2024 08:24:15 +0700 Subject: [PATCH 03/15] Websocket: Simplify state transistions --- exchanges/stream/websocket.go | 9 ++++----- exchanges/stream/websocket_types.go | 13 ++++++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index bb23c9a8464..8f27bca4c50 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -78,7 +78,6 @@ func SetupGlobalReporter(r Reporter) { // New initialises the websocket struct func New() *Websocket { return &Websocket{ - Init: true, DataHandler: make(chan interface{}, defaultJobBuffer), ToRoutine: make(chan interface{}, defaultJobBuffer), TrafficAlert: make(chan struct{}), @@ -99,7 +98,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } - if !w.Init { + if w.state != uninitialised { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } @@ -613,7 +612,7 @@ func (w *Websocket) trafficMonitor() { func (w *Websocket) setConnectedStatus(b bool) { w.fieldMutex.Lock() - w.connected = b + w.state = connected w.fieldMutex.Unlock() } @@ -621,12 +620,12 @@ func (w *Websocket) setConnectedStatus(b bool) { func (w *Websocket) IsConnected() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() - return w.connected + return w.state == connected } func (w *Websocket) setConnectingStatus(b bool) { w.fieldMutex.Lock() - w.connecting = b + w.state = connecting w.fieldMutex.Unlock() } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 925c34b907c..9f3ec68a317 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -25,14 +25,21 @@ const ( type subscriptionMap map[any]*subscription.Subscription +type State int + +const ( + uninitialised State = iota + disconnected + connecting + connected +) + // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { canUseAuthenticatedEndpoints bool enabled bool - Init bool - connected bool - connecting bool + state state verbose bool connectionMonitorRunning bool trafficMonitorRunning bool From e1c2055a670b0277b4cc3bda93975c532b33be2f Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 29 Jan 2024 14:48:54 +0700 Subject: [PATCH 04/15] Websocket: Simplify Connecting/Connected state --- engine/websocketroutine_manager_test.go | 2 +- exchanges/binance/binance_wrapper.go | 2 +- exchanges/binanceus/binanceus_wrapper.go | 2 +- exchanges/bitfinex/bitfinex_wrapper.go | 2 +- exchanges/bithumb/bithumb_wrapper.go | 2 +- exchanges/bitmex/bitmex_wrapper.go | 2 +- exchanges/bitstamp/bitstamp_wrapper.go | 2 +- exchanges/btcmarkets/btcmarkets_wrapper.go | 2 +- exchanges/btse/btse_wrapper.go | 2 +- exchanges/bybit/bybit_wrapper.go | 2 +- exchanges/coinbasepro/coinbasepro_wrapper.go | 2 +- exchanges/coinut/coinut_wrapper.go | 2 +- exchanges/exchange_test.go | 6 +- exchanges/gateio/gateio_wrapper.go | 2 +- exchanges/gemini/gemini_wrapper.go | 2 +- exchanges/hitbtc/hitbtc_wrapper.go | 2 +- exchanges/huobi/huobi_wrapper.go | 2 +- exchanges/kraken/kraken_wrapper.go | 2 +- exchanges/kucoin/kucoin_wrapper.go | 2 +- exchanges/okcoin/okcoin_wrapper.go | 2 +- exchanges/okx/okx_wrapper.go | 2 +- exchanges/poloniex/poloniex_wrapper.go | 2 +- .../sharedtestvalues/sharedtestvalues.go | 1 - exchanges/stream/websocket.go | 52 +++++------ exchanges/stream/websocket_test.go | 86 +++++++++---------- exchanges/stream/websocket_types.go | 4 +- 26 files changed, 93 insertions(+), 98 deletions(-) diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index e082193bda4..c1c3541db31 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -293,7 +293,7 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { t.Fatal("unexpected data handlers registered") } - mock := stream.New() + mock := stream.NewWebsocket() mock.ToRoutine = make(chan interface{}) m.state = readyState err = m.websocketDataReceiver(mock) diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index bdf70e0d3ee..d54f89242ca 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -238,7 +238,7 @@ func (b *Binance) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/binanceus/binanceus_wrapper.go b/exchanges/binanceus/binanceus_wrapper.go index 3f078ede7a8..ed2e5d10f5e 100644 --- a/exchanges/binanceus/binanceus_wrapper.go +++ b/exchanges/binanceus/binanceus_wrapper.go @@ -162,7 +162,7 @@ func (bi *Binanceus) SetDefaults() { "%s setting default endpoints error %v", bi.Name, err) } - bi.Websocket = stream.New() + bi.Websocket = stream.NewWebsocket() bi.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit bi.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout bi.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 7e1ef64c766..2aba9987cf9 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -198,7 +198,7 @@ func (b *Bitfinex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 5dbc314674d..24ea8980b19 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -150,7 +150,7 @@ func (b *Bithumb) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 151f1ef28e9..a0810b1a302 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -175,7 +175,7 @@ func (b *Bitmex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index 888fae0b395..2fe7fa96d82 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -146,7 +146,7 @@ func (b *Bitstamp) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index 17ee7277c0a..8a924b08cf0 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -150,7 +150,7 @@ func (b *BTCMarkets) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index f7ec3c8f702..3cdbd56b2c5 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -176,7 +176,7 @@ func (b *BTSE) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 28d4f15041d..b6e7f2be224 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -216,7 +216,7 @@ func (by *Bybit) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - by.Websocket = stream.New() + by.Websocket = stream.NewWebsocket() by.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit by.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout by.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 47a20d9ed95..21a34e2ea22 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -145,7 +145,7 @@ func (c *CoinbasePro) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 503c2909461..db4af53badd 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -127,7 +127,7 @@ func (c *COINUT) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index d41f499d48d..ad0060495f3 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -198,7 +198,7 @@ func TestSetClientProxyAddress(t *testing.T) { Name: "rawr", Requester: requester} - newBase.Websocket = stream.New() + newBase.Websocket = stream.NewWebsocket() err = newBase.SetClientProxyAddress("") if err != nil { t.Error(err) @@ -1251,7 +1251,7 @@ func TestSetupDefaults(t *testing.T) { } // Test websocket support - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.Features.Supports.Websocket = true err = b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ @@ -1596,7 +1596,7 @@ func TestIsWebsocketEnabled(t *testing.T) { t.Error("exchange doesn't support websocket") } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() err := b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ Enabled: true, diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 3026efc0608..adea4d2fa6b 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -194,7 +194,7 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index fee75d6b1a2..d2d89eb5008 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -128,7 +128,7 @@ func (g *Gemini) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index 3b65fe649b5..7bcf7ada335 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -147,7 +147,7 @@ func (h *HitBTC) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 3d1d576b9d2..90d70491bd8 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -202,7 +202,7 @@ func (h *HUOBI) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index de631e0d904..5875917592c 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -209,7 +209,7 @@ func (k *Kraken) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - k.Websocket = stream.New() + k.Websocket = stream.NewWebsocket() k.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit k.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout k.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kucoin/kucoin_wrapper.go b/exchanges/kucoin/kucoin_wrapper.go index 8a98a77a02d..d767a2cf441 100644 --- a/exchanges/kucoin/kucoin_wrapper.go +++ b/exchanges/kucoin/kucoin_wrapper.go @@ -195,7 +195,7 @@ func (ku *Kucoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - ku.Websocket = stream.New() + ku.Websocket = stream.NewWebsocket() ku.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit ku.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout ku.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okcoin/okcoin_wrapper.go b/exchanges/okcoin/okcoin_wrapper.go index 8519cdc1b11..af0655ef3f1 100644 --- a/exchanges/okcoin/okcoin_wrapper.go +++ b/exchanges/okcoin/okcoin_wrapper.go @@ -150,7 +150,7 @@ func (o *Okcoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - o.Websocket = stream.New() + o.Websocket = stream.NewWebsocket() o.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit o.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout o.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index 0b5472b9e65..64e2e269877 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -190,7 +190,7 @@ func (ok *Okx) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - ok.Websocket = stream.New() + ok.Websocket = stream.NewWebsocket() ok.WebsocketResponseMaxLimit = okxWebsocketResponseMaxLimit ok.WebsocketResponseCheckTimeout = okxWebsocketResponseMaxLimit ok.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index 97eb9293d56..57f28fac98b 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -159,7 +159,7 @@ func (p *Poloniex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - p.Websocket = stream.New() + p.Websocket = stream.NewWebsocket() p.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit p.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout p.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index 60d51d82bfe..54acf9dcb5b 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -57,7 +57,6 @@ func GetWebsocketStructChannelOverride() chan struct{} { // NewTestWebsocket returns a test websocket object func NewTestWebsocket() *stream.Websocket { return &stream.Websocket{ - Init: true, DataHandler: make(chan interface{}, WebsocketChannelOverrideCapacity), ToRoutine: make(chan interface{}, 1000), TrafficAlert: make(chan struct{}), diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 8f27bca4c50..53b49ab93c3 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -65,6 +65,8 @@ var ( errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("connect func not set") + errAlreadyConnected = errors.New("already connected") ) var globalReporter Reporter @@ -75,8 +77,8 @@ func SetupGlobalReporter(r Reporter) { globalReporter = r } -// New initialises the websocket struct -func New() *Websocket { +// NewWebsocket initialises the websocket struct +func NewWebsocket() *Websocket { return &Websocket{ DataHandler: make(chan interface{}, defaultJobBuffer), ToRoutine: make(chan interface{}, defaultJobBuffer), @@ -98,7 +100,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } - if w.state != uninitialised { + if w.IsInitialised() { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } @@ -188,6 +190,8 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions) } w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection + w.setState(disconnected) + return nil } @@ -253,7 +257,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { // function func (w *Websocket) Connect() error { if w.connector == nil { - return errors.New("websocket connect function not set, cannot continue") + return errNoConnectFunc } w.m.Lock() defer w.m.Unlock() @@ -274,15 +278,14 @@ func (w *Websocket) Connect() error { w.dataMonitor() w.trafficMonitor() - w.setConnectingStatus(true) + w.setState(connecting) err := w.connector() if err != nil { - w.setConnectingStatus(false) + w.setState(disconnected) return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } - w.setConnectedStatus(true) - w.setConnectingStatus(false) + w.setState(connected) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() @@ -310,6 +313,7 @@ func (w *Websocket) Connect() error { } // Disable disables the exchange websocket protocol +// Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { if !w.IsEnabled() { return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName) @@ -409,7 +413,7 @@ func (w *Websocket) connectionMonitor() error { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) - w.setConnectedStatus(false) + w.setState(disconnected) } w.DataHandler <- err @@ -474,8 +478,7 @@ func (w *Websocket) Shutdown() error { close(w.ShutdownC) w.Wg.Wait() w.ShutdownC = make(chan struct{}) - w.setConnectedStatus(false) - w.setConnectingStatus(false) + w.setState(disconnected) if w.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } @@ -569,7 +572,7 @@ func (w *Websocket) trafficMonitor() { default: } } - w.setConnectedStatus(true) + w.setState(connected) trafficTimer.Reset(w.trafficTimeout) case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { @@ -610,30 +613,31 @@ func (w *Websocket) trafficMonitor() { }() } -func (w *Websocket) setConnectedStatus(b bool) { +// IsInitialised returns whether the websocket has been Setup() already +func (w *Websocket) IsInitialised() bool { + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() + return w.state != uninitialised +} + +func (w *Websocket) setState(s state) { w.fieldMutex.Lock() - w.state = connected + w.state = s w.fieldMutex.Unlock() } -// IsConnected returns status of connection +// IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() return w.state == connected } -func (w *Websocket) setConnectingStatus(b bool) { - w.fieldMutex.Lock() - w.state = connecting - w.fieldMutex.Unlock() -} - -// IsConnecting returns status of connecting +// IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() - return w.connecting + return w.state == connecting } func (w *Websocket) setEnabled(b bool) { @@ -642,7 +646,7 @@ func (w *Websocket) setEnabled(b bool) { w.fieldMutex.Unlock() } -// IsEnabled returns status of enabled +// IsEnabled returns whether the websocket is enabled func (w *Websocket) IsEnabled() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 650509b11d1..d3b142fa926 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -125,7 +125,7 @@ func TestSetup(t *testing.T) { t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) } - w.Init = true + w.setState(disconnected) err = w.Setup(websocketSetup) if !errors.Is(err, errExchangeConfigIsNil) { t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil) @@ -214,7 +214,7 @@ func TestSetup(t *testing.T) { func TestTrafficMonitorTimeout(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() if err := ws.Setup(defaultSetup); err != nil { t.Fatal(err) } @@ -232,7 +232,7 @@ func TestTrafficMonitorTimeout(t *testing.T) { t.Fatal("traffic monitor should be running") } // prevent shutdown routine - ws.setConnectedStatus(false) + ws.setState(disconnected) // await timeout closure ws.Wg.Wait() if ws.IsTrafficMonitorRunning() { @@ -284,21 +284,21 @@ func TestConnectionMessageErrors(t *testing.T) { } wsWrong.setEnabled(true) - wsWrong.setConnectingStatus(true) + wsWrong.setState(connecting) wsWrong.Wg = &sync.WaitGroup{} err = wsWrong.Connect() if err == nil { t.Fatal("error cannot be nil") } - wsWrong.setConnectedStatus(false) + wsWrong.setState(disconnected) wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") } err = wsWrong.Connect() if err == nil { t.Fatal("error cannot be nil") } - ws := *New() + ws := NewWebsocket() err = ws.Setup(defaultSetup) if err != nil { t.Fatal(err) @@ -345,31 +345,35 @@ outer: func TestWebsocket(t *testing.T) { t.Parallel() - wsInit := Websocket{} - err := wsInit.Setup(&WebsocketSetup{ - ExchangeConfig: &config.Exchange{ - Features: &config.FeaturesConfig{ - Enabled: config.FeaturesEnabledConfig{Websocket: true}, - }, - Name: "test", - }, - }) - assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "SetProxyAddress should error correctly") - ws := *New() - err = ws.SetProxyAddress("garbagio") + ws := NewWebsocket() + + err := ws.SetProxyAddress("garbagio") assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") ws.Conn = &WebsocketConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) + err = ws.Setup(defaultSetup) // Sets to enabled again + require.NoError(t, err, "Setup may not error") + + err = ws.Setup(defaultSetup) + assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly if called twice") + + assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly") + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") + + ws.setEnabled(false) + assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") + + ws.setEnabled(true) + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") + err = ws.SetProxyAddress("https://192.168.0.1:1337") assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") - ws.setConnectedStatus(true) - ws.ShutdownC = make(chan struct{}) - ws.Wg = &sync.WaitGroup{} + ws.setState(connected) err = ws.SetProxyAddress("https://192.168.0.1:1336") assert.ErrorIs(t, err, errNoConnectFunc, "SetProxyAddress should call Connect and error from there") // This test asserts we actually set the proxy address, etc @@ -386,18 +390,6 @@ func TestWebsocket(t *testing.T) { err = ws.SetProxyAddress("http://localhost:1337") assert.NoError(t, err, "SetProxyAddress should not error") - err = ws.Setup(defaultSetup) // Sets to enabled again - require.NoError(t, err, "Setup may not error") - - assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly") - assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") - - ws.setEnabled(false) - assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") - - ws.setEnabled(true) - assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") - assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") @@ -405,20 +397,20 @@ func TestWebsocket(t *testing.T) { err = ws.Shutdown() assert.ErrorIs(t, err, ErrNotConnected, "Shutdown should error when not Connected") - ws.setConnectedStatus(true) + ws.setState(connected) ws.Conn = &dodgyConnection{} err = ws.Shutdown() assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy conn") ws.Conn = &WebsocketConnection{} - ws.setConnectedStatus(true) + ws.setState(connected) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} - ws.setConnectedStatus(false) + ws.setState(disconnected) err = ws.Connect() assert.NoError(t, err, "Connect should not error") @@ -456,7 +448,7 @@ func TestWebsocket(t *testing.T) { // TestSubscribe logic test func TestSubscribeUnsubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error") fnSub := func(subs []subscription.Subscription) error { @@ -501,7 +493,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { // TestResubscribe tests Resubscribing to existing subscriptions func TestResubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() wackedOutSetup := *defaultSetup wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1 @@ -532,7 +524,7 @@ func TestResubscribe(t *testing.T) { // TestSubscriptionState tests Subscription state changes func TestSubscriptionState(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState} assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error") @@ -558,7 +550,7 @@ func TestSubscriptionState(t *testing.T) { // TestRemoveSubscriptions tests removing a subscription func TestRemoveSubscriptions(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Unite!"} assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") @@ -571,7 +563,7 @@ func TestRemoveSubscriptions(t *testing.T) { // TestConnectionMonitorNoConnection logic test func TestConnectionMonitorNoConnection(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() ws.connectionMonitorDelay = 500 ws.DataHandler = make(chan interface{}, 1) ws.ShutdownC = make(chan struct{}, 1) @@ -626,7 +618,7 @@ func TestGetSubscriptions(t *testing.T) { // TestSetCanUseAuthenticatedEndpoints logic test func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() result := ws.CanUseAuthenticatedEndpoints() if result { t.Error("expected `canUseAuthenticatedEndpoints` to be false") @@ -925,7 +917,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { if resp { t.Error("Expected false, `connected` is false") } - ws.setConnectedStatus(true) + ws.setState(connected) resp = ws.CanUseAuthenticatedWebsocketForWrapper() if resp { t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false") @@ -1109,7 +1101,7 @@ func TestFlushChannels(t *testing.T) { web := Websocket{ enabled: true, - connected: true, + state: connected, connector: connect, ShutdownC: make(chan struct{}), Subscriber: newgen.SUBME, @@ -1204,7 +1196,7 @@ func TestFlushChannels(t *testing.T) { t.Fatal(err) } - web.setConnectedStatus(true) + web.setState(connected) web.features.Unsubscribe = true err = web.FlushChannels() if err != nil { @@ -1216,7 +1208,7 @@ func TestDisable(t *testing.T) { t.Parallel() web := Websocket{ enabled: true, - connected: true, + state: connected, ShutdownC: make(chan struct{}), } err := web.Disable() @@ -1284,7 +1276,7 @@ func TestSetupNewConnection(t *testing.T) { connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), - Init: true, + state: disconnected, TrafficAlert: make(chan struct{}), ReadMessageErrors: make(chan error), DataHandler: make(chan interface{}), diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 9f3ec68a317..7af4ae03a1d 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -25,10 +25,10 @@ const ( type subscriptionMap map[any]*subscription.Subscription -type State int +type state int const ( - uninitialised State = iota + uninitialised state = iota disconnected connecting connected From 5cbdc1a47da9ed7e0a0453d8a1514ea4b0d1c16c Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 3 Feb 2024 09:25:36 +0700 Subject: [PATCH 05/15] Websocket: Tests and errors for websocket --- exchanges/stream/websocket.go | 14 ++-- exchanges/stream/websocket_test.go | 126 +++++++++-------------------- 2 files changed, 47 insertions(+), 93 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 53b49ab93c3..4fd0685bf7b 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -65,8 +65,10 @@ var ( errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") errSameProxyAddress = errors.New("cannot set proxy address to the same address") - errNoConnectFunc = errors.New("connect func not set") - errAlreadyConnected = errors.New("already connected") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") ) var globalReporter Reporter @@ -266,10 +268,10 @@ func (w *Websocket) Connect() error { return errors.New(WebsocketNotEnabled) } if w.IsConnecting() { - return fmt.Errorf("%v Websocket already attempting to connect", w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting) } if w.IsConnected() { - return fmt.Errorf("%v Websocket already connected", w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected) } w.subscriptionMutex.Lock() @@ -444,12 +446,12 @@ func (w *Websocket) Shutdown() error { defer w.m.Unlock() if !w.IsConnected() { - return fmt.Errorf("%v websocket: cannot shutdown %w", w.exchangeName, ErrNotConnected) + return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected) } // TODO: Interrupt connection and or close connection when it is re-established. if w.IsConnecting() { - return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", w.exchangeName) + return fmt.Errorf("%v %w: %w ", w.exchangeName, errCannotShutdown, errAlreadyReconnecting) } if w.verbose { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index d3b142fa926..188c8fcb30c 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -32,7 +32,7 @@ const ( ) var ( - errDastardlyReason = errors.New("cannot shutdown due to some dastardly reason") + errDastardlyReason = errors.New("some dastardly reason") ) var dialer websocket.Dialer @@ -73,7 +73,7 @@ var defaultSetup = &WebsocketSetup{ AuthenticatedWebsocketSupport: true, }, WebsocketTrafficTimeout: time.Second * 5, - Name: "exchangeName", + Name: "GTX", }, DefaultURL: "testDefaultURL", RunningURL: "wss://testRunningURL", @@ -97,147 +97,106 @@ type dodgyConnection struct { // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Shutdown() error { - return errDastardlyReason + return fmt.Errorf("%w: %w", errCannotShutdown, errDastardlyReason) } // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Connect() error { - return errDastardlyReason + return fmt.Errorf("cannot connect: %w", errDastardlyReason) } func TestSetup(t *testing.T) { t.Parallel() var w *Websocket err := w.Setup(nil) - if !errors.Is(err, errWebsocketIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketIsNil) - } + assert.ErrorIs(t, err, errWebsocketIsNil, "Setup should error correctly") w = &Websocket{DataHandler: make(chan interface{})} err = w.Setup(nil) - if !errors.Is(err, errWebsocketSetupIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSetupIsNil) - } + assert.ErrorIs(t, err, errWebsocketSetupIsNil, "Setup should error correctly") websocketSetup := &WebsocketSetup{} - err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } - w.setState(disconnected) err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil) - } + assert.ErrorIs(t, err, errExchangeConfigIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigNameUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigNameUnset) - } - websocketSetup.ExchangeConfig.Name = "testname" + assert.ErrorIs(t, err, errExchangeConfigNameUnset, "Setup should error correctly") + websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketFeaturesIsUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketFeaturesIsUnset) - } + assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset, "Setup should error correctly") websocketSetup.Features = &protocol.Features{} err = w.Setup(websocketSetup) - if !errors.Is(err, errConfigFeaturesIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errConfigFeaturesIsNil) - } + assert.ErrorIs(t, err, errConfigFeaturesIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketConnectorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketConnectorUnset) - } + assert.ErrorIs(t, err, errWebsocketConnectorUnset, "Setup should error correctly") websocketSetup.Connector = func() error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly") websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil } websocketSetup.Features.Unsubscribe = true err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketUnsubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketUnsubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly") websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriptionsGeneratorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriptionsGeneratorUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly") websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errDefaultURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errDefaultURLIsEmpty) - } + assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly") websocketSetup.DefaultURL = "test" err = w.Setup(websocketSetup) - if !errors.Is(err, errRunningURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errRunningURLIsEmpty) - } + assert.ErrorIs(t, err, errRunningURLIsEmpty, "Setup should error correctly") websocketSetup.RunningURL = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURL = "wss://www.google.com" websocketSetup.RunningURLAuth = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURLAuth = "wss://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidTrafficTimeout) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidTrafficTimeout) - } + assert.ErrorIs(t, err, errInvalidTrafficTimeout, "Setup should error correctly") websocketSetup.ExchangeConfig.WebsocketTrafficTimeout = time.Minute err = w.Setup(websocketSetup) - if !errors.Is(err, nil) { - t.Fatalf("received: %v but expected: %v", err, nil) - } + assert.NoError(t, err, "Setup should not error") } func TestTrafficMonitorTimeout(t *testing.T) { t.Parallel() ws := NewWebsocket() - if err := ws.Setup(defaultSetup); err != nil { - t.Fatal(err) - } - ws.trafficTimeout = time.Second * 2 + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + ws.trafficTimeout = time.Millisecond * 42 ws.ShutdownC = make(chan struct{}) ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") - } + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + // Deploy traffic alert ws.TrafficAlert <- struct{}{} // try to add another traffic monitor ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") - } + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + // prevent shutdown routine ws.setState(disconnected) // await timeout closure ws.Wg.Wait() - if ws.IsTrafficMonitorRunning() { - t.Error("should be dead") - } + assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be not be running") } func TestIsDisconnectionError(t *testing.T) { @@ -351,7 +310,7 @@ func TestWebsocket(t *testing.T) { err := ws.SetProxyAddress("garbagio") assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") - ws.Conn = &WebsocketConnection{} + ws.Conn = &dodgyConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) @@ -361,7 +320,7 @@ func TestWebsocket(t *testing.T) { err = ws.Setup(defaultSetup) assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly if called twice") - assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly") + assert.Equal(t, "GTX", ws.GetName(), "GetName should return correctly") assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") ws.setEnabled(false) @@ -376,38 +335,31 @@ func TestWebsocket(t *testing.T) { ws.setState(connected) err = ws.SetProxyAddress("https://192.168.0.1:1336") - assert.ErrorIs(t, err, errNoConnectFunc, "SetProxyAddress should call Connect and error from there") // This test asserts we actually set the proxy address, etc + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") err = ws.SetProxyAddress("https://192.168.0.1:1336") assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") - ws.setEnabled(false) // removing proxy err = ws.SetProxyAddress("") - assert.NoError(t, err, "SetProxyAddress should not error when removing proxy") + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Shutdown and error from there") + assert.ErrorIs(t, err, errCannotShutdown, "SetProxyAddress should call Shutdown and error from there") + + ws.Conn = &WebsocketConnection{} + ws.setEnabled(true) // reinstate proxy err = ws.SetProxyAddress("http://localhost:1337") assert.NoError(t, err, "SetProxyAddress should not error") - assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") - err = ws.Shutdown() - assert.ErrorIs(t, err, ErrNotConnected, "Shutdown should error when not Connected") - - ws.setState(connected) - ws.Conn = &dodgyConnection{} - err = ws.Shutdown() - assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy conn") - - ws.Conn = &WebsocketConnection{} - ws.setState(connected) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") + assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} ws.setState(disconnected) From 5e2fa6f2dea2c6a2a928e35c8b3d964c108f7b90 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 3 Feb 2024 16:18:38 +0700 Subject: [PATCH 06/15] Websocket: Make WebsocketNotEnabled a real error This allows for testing and avoids the repetition. If each returned error is a error.New() you can never use errors.Is() --- cmd/exchange_template/wrapper_file.tmpl | 2 +- exchanges/binance/binance_websocket.go | 2 +- exchanges/binanceus/binanceus_websocket.go | 2 +- exchanges/bitfinex/bitfinex_test.go | 2 +- exchanges/bitfinex/bitfinex_websocket.go | 2 +- exchanges/bithumb/bithumb_websocket.go | 3 +- exchanges/bitmex/bitmex_test.go | 2 +- exchanges/bitmex/bitmex_websocket.go | 2 +- exchanges/bitstamp/bitstamp_websocket.go | 2 +- exchanges/btcmarkets/btcmarkets_websocket.go | 2 +- exchanges/btse/btse_websocket.go | 2 +- exchanges/bybit/bybit.go | 2 - exchanges/bybit/bybit_inverse_websocket.go | 2 +- exchanges/bybit/bybit_linear_websocket.go | 2 +- exchanges/bybit/bybit_options_websocket.go | 2 +- exchanges/bybit/bybit_test.go | 7 +- exchanges/bybit/bybit_websocket.go | 2 +- exchanges/coinbasepro/coinbasepro_test.go | 2 +- .../coinbasepro/coinbasepro_websocket.go | 2 +- exchanges/coinut/coinut_test.go | 2 +- exchanges/coinut/coinut_websocket.go | 2 +- exchanges/gateio/gateio_websocket.go | 2 +- .../gateio/gateio_ws_delivery_futures.go | 2 +- exchanges/gateio/gateio_ws_futures.go | 2 +- exchanges/gateio/gateio_ws_option.go | 2 +- exchanges/gemini/gemini_test.go | 2 +- exchanges/gemini/gemini_websocket.go | 2 +- exchanges/hitbtc/hitbtc_test.go | 2 +- exchanges/hitbtc/hitbtc_websocket.go | 2 +- exchanges/huobi/huobi_test.go | 2 +- exchanges/huobi/huobi_websocket.go | 2 +- exchanges/kraken/kraken_test.go | 2 +- exchanges/kraken/kraken_websocket.go | 2 +- exchanges/kucoin/kucoin_websocket.go | 2 +- exchanges/okcoin/okcoin_websocket.go | 2 +- exchanges/okcoin/okcoin_ws_trade.go | 2 +- exchanges/okx/okx_websocket.go | 2 +- exchanges/poloniex/poloniex_test.go | 2 +- exchanges/poloniex/poloniex_websocket.go | 2 +- exchanges/stream/websocket.go | 30 +++--- exchanges/stream/websocket_test.go | 97 +++++-------------- exchanges/stream/websocket_types.go | 2 - 42 files changed, 80 insertions(+), 133 deletions(-) diff --git a/cmd/exchange_template/wrapper_file.tmpl b/cmd/exchange_template/wrapper_file.tmpl index d57f96b5451..e74ecbc320f 100644 --- a/cmd/exchange_template/wrapper_file.tmpl +++ b/cmd/exchange_template/wrapper_file.tmpl @@ -125,7 +125,7 @@ func ({{.Variable}} *{{.CapitalName}}) SetDefaults() { exchange.RestSpot: {{.Name}}APIURL, // exchange.WebsocketSpot: {{.Name}}WSAPIURL, }) - {{.Variable}}.Websocket = stream.New() + {{.Variable}}.Websocket = stream.NewWebsocket() {{.Variable}}.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit {{.Variable}}.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout {{.Variable}}.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 52cdd0cc392..bd96d546f4f 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -50,7 +50,7 @@ var ( // WsConnect initiates a websocket connection func (b *Binance) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 8f4d5c3cd6e..14098c1139b 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect initiates a websocket connection func (bi *Binanceus) WsConnect() error { if !bi.Websocket.IsEnabled() || !bi.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.HandshakeTimeout = bi.Config.HTTPTimeout diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index c2bba90d00b..e7be0fbe707 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1128,7 +1128,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !b.Websocket.IsEnabled() { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) if !b.API.AuthenticatedWebsocketSupport { diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index e1010eb1061..ae7cded7477 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -43,7 +43,7 @@ var cMtx sync.Mutex // WsConnect starts a new websocket connection func (b *Bitfinex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 3005f42359c..667ce131c71 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -2,7 +2,6 @@ package bithumb import ( "encoding/json" - "errors" "fmt" "net/http" "time" @@ -29,7 +28,7 @@ var ( // WsConnect initiates a websocket connection func (b *Bithumb) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/bitmex/bitmex_test.go b/exchanges/bitmex/bitmex_test.go index b59cb125812..e825ad8a07c 100644 --- a/exchanges/bitmex/bitmex_test.go +++ b/exchanges/bitmex/bitmex_test.go @@ -789,7 +789,7 @@ func TestGetDepositAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !b.Websocket.IsEnabled() && !b.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(b) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 6d04c106039..e1a475254f6 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -68,7 +68,7 @@ const ( // WsConnect initiates a new websocket connection func (b *Bitmex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index ab5465b574d..98aa6201df4 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect connects to a websocket feed func (b *Bitstamp) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 8c979b4d79b..01ba1a64d4c 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -39,7 +39,7 @@ var ( // WsConnect connects to a websocket feed func (b *BTCMarkets) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 41f25c95e35..e32fb9a8095 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -30,7 +30,7 @@ const ( // WsConnect connects the websocket client func (b *BTSE) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bybit/bybit.go b/exchanges/bybit/bybit.go index 5652c464057..3588336ff52 100644 --- a/exchanges/bybit/bybit.go +++ b/exchanges/bybit/bybit.go @@ -21,7 +21,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) // Bybit is the overarching type across this package @@ -90,7 +89,6 @@ var ( errAPIKeyIsNotUnified = errors.New("api key is not unified") errEndpointAvailableForNormalAPIKeyHolders = errors.New("endpoint available for normal API key holders only") errInvalidContractLength = errors.New("contract length cannot be less than or equal to zero") - errWebsocketNotEnabled = errors.New(stream.WebsocketNotEnabled) ) var ( diff --git a/exchanges/bybit/bybit_inverse_websocket.go b/exchanges/bybit/bybit_inverse_websocket.go index 77f387ace60..d1387c277f8 100644 --- a/exchanges/bybit/bybit_inverse_websocket.go +++ b/exchanges/bybit/bybit_inverse_websocket.go @@ -12,7 +12,7 @@ import ( // WsInverseConnect connects to inverse websocket feed func (by *Bybit) WsInverseConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.CoinMarginedFutures) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(inversePublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_linear_websocket.go b/exchanges/bybit/bybit_linear_websocket.go index efc2f68d1b8..9b3ed08426a 100644 --- a/exchanges/bybit/bybit_linear_websocket.go +++ b/exchanges/bybit/bybit_linear_websocket.go @@ -14,7 +14,7 @@ import ( // WsLinearConnect connects to linear a websocket feed func (by *Bybit) WsLinearConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.LinearContract) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(linearPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_options_websocket.go b/exchanges/bybit/bybit_options_websocket.go index 2f4abc7a76d..4bb25cef2a9 100644 --- a/exchanges/bybit/bybit_options_websocket.go +++ b/exchanges/bybit/bybit_options_websocket.go @@ -14,7 +14,7 @@ import ( // WsOptionsConnect connects to options a websocket feed func (by *Bybit) WsOptionsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Options) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(optionPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 58ca5bc46e8..092ca125462 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -20,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/margin" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -3064,7 +3065,7 @@ func TestWsLinearConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsLinearConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3074,7 +3075,7 @@ func TestWsInverseConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsInverseConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3084,7 +3085,7 @@ func TestWsOptionsConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsOptionsConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 857fd690afe..ff2698b8667 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -57,7 +57,7 @@ const ( // WsConnect connects to a websocket feed func (by *Bybit) WsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinbasepro/coinbasepro_test.go b/exchanges/coinbasepro/coinbasepro_test.go index 77f058384d1..0b04d4ab64b 100644 --- a/exchanges/coinbasepro/coinbasepro_test.go +++ b/exchanges/coinbasepro/coinbasepro_test.go @@ -681,7 +681,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index e4b02b764d8..5946cf778b5 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -31,7 +31,7 @@ const ( // WsConnect initiates a websocket connection func (c *CoinbasePro) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinut/coinut_test.go b/exchanges/coinut/coinut_test.go index 3431d60f730..0b165327846 100644 --- a/exchanges/coinut/coinut_test.go +++ b/exchanges/coinut/coinut_test.go @@ -66,7 +66,7 @@ func setupWSTestAuth(t *testing.T) { } if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } if sharedtestvalues.AreAPICredentialsSet(c) { c.Websocket.SetCanUseAuthenticatedEndpoints(true) diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 2453816e9ce..78b389879ca 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -41,7 +41,7 @@ var ( // WsConnect initiates a websocket connection func (c *COINUT) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index c26d04afe7c..3a3dddecf88 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -60,7 +60,7 @@ var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsConnect initiates a websocket connection func (g *Gateio) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Spot) if err != nil { diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index ba5c64afe3d..449181c1007 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -45,7 +45,7 @@ var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account func (g *Gateio) WsDeliveryFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.DeliveryFutures) if err != nil { diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index c0411a5816c..20e293b93af 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -64,7 +64,7 @@ var responseFuturesStream = make(chan stream.Response) // WsFuturesConnect initiates a websocket connection for futures account func (g *Gateio) WsFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) if err != nil { diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index d5340f0c350..3278914f21f 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -70,7 +70,7 @@ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. func (g *Gateio) WsOptionsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Options) if err != nil { diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index 6a3b35e1885..7477e78ff11 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -556,7 +556,7 @@ func TestWsAuth(t *testing.T) { if !g.Websocket.IsEnabled() && !g.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(g) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer go g.wsReadData() diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index 913c856ddd4..43c2135e021 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -39,7 +39,7 @@ var comms = make(chan stream.Response) // WsConnect initiates a websocket connection func (g *Gemini) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/hitbtc/hitbtc_test.go b/exchanges/hitbtc/hitbtc_test.go index 3e629d0a654..68b8495167b 100644 --- a/exchanges/hitbtc/hitbtc_test.go +++ b/exchanges/hitbtc/hitbtc_test.go @@ -466,7 +466,7 @@ func setupWsAuth(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index 705f584f01c..deb885424bf 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -34,7 +34,7 @@ const ( // WsConnect starts a new connection with the websocket API func (h *HitBTC) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/huobi/huobi_test.go b/exchanges/huobi/huobi_test.go index 4aafc7b74a5..16106daf566 100644 --- a/exchanges/huobi/huobi_test.go +++ b/exchanges/huobi/huobi_test.go @@ -78,7 +78,7 @@ func setupWsTests(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } comms = make(chan WsMessage, sharedtestvalues.WebsocketChannelOverrideCapacity) go h.wsReadData() diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index f601b03dbc2..92b5bf4c22e 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -62,7 +62,7 @@ var comms = make(chan WsMessage) // WsConnect initiates a new websocket connection func (h *HUOBI) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.wsDial(&dialer) diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index 263d70bf571..8530cb53b8d 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1215,7 +1215,7 @@ func setupWsTests(t *testing.T) { return } if !k.Websocket.IsEnabled() && !k.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(k) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := k.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 787d52b2c31..fd4325164e2 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -87,7 +87,7 @@ var cancelOrdersStatus = make(map[int64]*struct { // WsConnect initiates a websocket connection func (k *Kraken) WsConnect() error { if !k.Websocket.IsEnabled() || !k.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 3f8a0c783ca..917bc1a90e3 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -97,7 +97,7 @@ var ( // WsConnect creates a new websocket connection. func (ku *Kucoin) WsConnect() error { if !ku.Websocket.IsEnabled() || !ku.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } fetchedFuturesSnapshotOrderbook = map[string]bool{} var dialer websocket.Dialer diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index e714aba7d10..0787775e2f9 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -74,7 +74,7 @@ func isAuthenticatedChannel(channel string) bool { // WsConnect initiates a websocket connection func (o *Okcoin) WsConnect() error { if !o.Websocket.IsEnabled() || !o.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/okcoin/okcoin_ws_trade.go b/exchanges/okcoin/okcoin_ws_trade.go index cd0c7f86605..85b6102e24d 100644 --- a/exchanges/okcoin/okcoin_ws_trade.go +++ b/exchanges/okcoin/okcoin_ws_trade.go @@ -130,7 +130,7 @@ func (o *Okcoin) WsAmendMultipleOrder(args []AmendTradeOrderRequestParam) ([]Ame func (o *Okcoin) SendWebsocketRequest(operation string, data, result interface{}, authenticated bool) error { switch { case !o.Websocket.IsEnabled(): - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled case !o.Websocket.IsConnected(): return stream.ErrNotConnected case !o.Websocket.CanUseAuthenticatedEndpoints() && authenticated: diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index b4d211eec01..54bc6b0e486 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -216,7 +216,7 @@ const ( // WsConnect initiates a websocket connection func (ok *Okx) WsConnect() error { if !ok.Websocket.IsEnabled() || !ok.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/poloniex/poloniex_test.go b/exchanges/poloniex/poloniex_test.go index bab0ffd4a71..d74e2353d67 100644 --- a/exchanges/poloniex/poloniex_test.go +++ b/exchanges/poloniex/poloniex_test.go @@ -548,7 +548,7 @@ func TestGenerateNewAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !p.Websocket.IsEnabled() && !p.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(p) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 1be429a58ef..23774335fc8 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -55,7 +55,7 @@ var ( // WsConnect initiates a websocket connection func (p *Poloniex) WsConnect() error { if !p.Websocket.IsEnabled() || !p.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 4fd0685bf7b..4e73579a960 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -24,24 +24,20 @@ const ( defaultTrafficPeriod = time.Second ) +// Public websocket errors var ( - // ErrSubscriptionNotFound defines an error when a subscription is not found - ErrSubscriptionNotFound = errors.New("subscription not found") - // ErrSubscribedAlready defines an error when a channel is already subscribed - ErrSubscribedAlready = errors.New("duplicate subscription") - // ErrSubscriptionFailure defines an error when a subscription fails - ErrSubscriptionFailure = errors.New("subscription failure") - // ErrSubscriptionNotSupported defines an error when a subscription channel is not supported by an exchange + ErrWebsocketNotEnabled = errors.New("websocket not enabled") + ErrSubscriptionNotFound = errors.New("subscription not found") + ErrSubscribedAlready = errors.New("duplicate subscription") + ErrSubscriptionFailure = errors.New("subscription failure") ErrSubscriptionNotSupported = errors.New("subscription channel not supported ") - // ErrUnsubscribeFailure defines an error when a unsubscribe fails - ErrUnsubscribeFailure = errors.New("unsubscribe failure") - // ErrChannelInStateAlready defines an error when a subscription channel is already in a new state - ErrChannelInStateAlready = errors.New("channel already in state") - // ErrAlreadyDisabled is returned when you double-disable the websocket - ErrAlreadyDisabled = errors.New("websocket already disabled") - // ErrNotConnected defines an error when websocket is not connected - ErrNotConnected = errors.New("websocket is not connected") + ErrUnsubscribeFailure = errors.New("unsubscribe failure") + ErrChannelInStateAlready = errors.New("channel already in state") + ErrAlreadyDisabled = errors.New("websocket already disabled") + ErrNotConnected = errors.New("websocket is not connected") +) +var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") errWebsocketIsNil = errors.New("websocket is nil") @@ -265,7 +261,7 @@ func (w *Websocket) Connect() error { defer w.m.Unlock() if !w.IsEnabled() { - return errors.New(WebsocketNotEnabled) + return ErrWebsocketNotEnabled } if w.IsConnecting() { return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting) @@ -490,7 +486,7 @@ func (w *Websocket) Shutdown() error { // FlushChannels flushes channel subscriptions when there is a pair/asset change func (w *Websocket) FlushChannels() error { if !w.IsEnabled() { - return fmt.Errorf("%s websocket: service not enabled", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrWebsocketNotEnabled) } if !w.IsConnected() { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 188c8fcb30c..8459a5461a9 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -201,105 +201,62 @@ func TestTrafficMonitorTimeout(t *testing.T) { func TestIsDisconnectionError(t *testing.T) { t.Parallel() - isADisconnectionError := IsDisconnectionError(errors.New("errorText")) - if isADisconnectionError { - t.Error("Its not") - } - isADisconnectionError = IsDisconnectionError(&websocket.CloseError{ - Code: 1006, - Text: "errorText", - }) - if !isADisconnectionError { - t.Error("It is") - } - - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errClosedConnection, - }) - if isADisconnectionError { - t.Error("It's not") - } - - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errors.New("errText"), - }) - if !isADisconnectionError { - t.Error("It is") - } + assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsDisconnectionError should return true") + assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsDisconnectionError should return true") } func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} err := wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errNoConnectFunc, "Connect should error correctly") wsWrong.connector = func() error { return nil } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) wsWrong.setState(connecting) wsWrong.Wg = &sync.WaitGroup{} err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") wsWrong.setState(disconnected) - wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") } + wsWrong.connector = func() error { return errDastardlyReason } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") ws := NewWebsocket() err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Setup must not error") ws.trafficTimeout = time.Minute ws.connector = func() error { return nil } err = ws.Connect() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Connect must not error") ws.TrafficAlert <- struct{}{} - timer := time.NewTimer(900 * time.Millisecond) - ws.ReadMessageErrors <- errors.New("errorText") - select { - case err := <-ws.ToRoutine: - errText, ok := err.(error) - if !ok { - t.Error("unable to type assert error") - } else if errText.Error() != "errorText" { - t.Errorf("Expected 'errorText', received %v", err) - } - case <-timer.C: - t.Error("Timeout waiting for datahandler to receive error") - } - ws.ReadMessageErrors <- &websocket.CloseError{ - Code: 1006, - Text: "errorText", - } -outer: - for { + c := func(tb *assert.CollectT) { select { - case err := <-ws.ToRoutine: - if _, ok := err.(*websocket.CloseError); !ok { - t.Errorf("Error is not a disconnection error: %v", err) + case v := <-ws.ToRoutine: + switch err := v.(type) { + case *websocket.CloseError: + assert.Equal(tb, "SpecialText", err.Text, "Should get correct Close Error") + case error: + assert.ErrorIs(tb, err, errDastardlyReason, "Should get the correct error") } - case <-timer.C: - break outer + default: } } + + ws.ReadMessageErrors <- errDastardlyReason + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") + + ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") } func TestWebsocket(t *testing.T) { @@ -1041,9 +998,7 @@ func TestFlushChannels(t *testing.T) { dodgyWs := Websocket{} err := dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "FlushChannels should error correctly") dodgyWs.setEnabled(true) err = dodgyWs.FlushChannels() diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 7af4ae03a1d..31e1db58b09 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -15,8 +15,6 @@ import ( // Websocket functionality list and state consts const ( - // WebsocketNotEnabled alerts of a disabled websocket - WebsocketNotEnabled = "exchange_websocket_not_enabled" WebsocketNotAuthenticatedUsingRest = "%v - Websocket not authenticated, using REST\n" Ping = "ping" Pong = "pong" From 3c38cb9f42ded47acf31ef1c1f2d4c65fdd27b71 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 3 Feb 2024 17:39:35 +0700 Subject: [PATCH 07/15] Websocket: Add more testable errors --- exchanges/stream/websocket.go | 26 ++- exchanges/stream/websocket_connection.go | 21 +- exchanges/stream/websocket_test.go | 269 +++++++---------------- 3 files changed, 94 insertions(+), 222 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 4e73579a960..9e527622b3a 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -40,19 +40,23 @@ var ( var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") + errExchangeConfigEmpty = errors.New("exchange config is empty") errWebsocketIsNil = errors.New("websocket is nil") errWebsocketSetupIsNil = errors.New("websocket setup is nil") errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") errConfigFeaturesIsNil = errors.New("exchange config features is nil") errDefaultURLIsEmpty = errors.New("default url is empty") errRunningURLIsEmpty = errors.New("running url cannot be empty") errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameUnset = errors.New("exchange config name unset") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") errClosedConnection = errors.New("use of closed network connection") errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") @@ -65,6 +69,7 @@ var ( errAlreadyConnected = errors.New("websocket already connected") errCannotShutdown = errors.New("websocket cannot shutdown") errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") ) var globalReporter Reporter @@ -107,7 +112,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } if s.ExchangeConfig.Name == "" { - return errExchangeConfigNameUnset + return errExchangeConfigNameEmpty } w.exchangeName = s.ExchangeConfig.Name w.verbose = s.ExchangeConfig.Verbose @@ -196,22 +201,22 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { // SetupNewConnection sets up an auth or unauth streaming connection func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if w == nil { - return errors.New("setting up new connection error: websocket is nil") + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } if c == (ConnectionSetup{}) { - return errors.New("setting up new connection error: websocket connection configuration empty") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) } if w.exchangeName == "" { - return errors.New("setting up new connection error: exchange name not set, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) } if w.TrafficAlert == nil { - return errors.New("setting up new connection error: traffic alert is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) } if w.ReadMessageErrors == nil { - return errors.New("setting up new connection error: read message errors is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) } connectionURL := w.GetWebsocketURL() @@ -314,7 +319,7 @@ func (w *Websocket) Connect() error { // Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { if !w.IsEnabled() { - return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrAlreadyDisabled) } w.setEnabled(false) @@ -324,8 +329,7 @@ func (w *Websocket) Disable() error { // Enable enables the exchange websocket protocol func (w *Websocket) Enable() error { if w.IsConnected() || w.IsEnabled() { - return fmt.Errorf("websocket is already enabled for exchange %s", - w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyEnabled) } w.setEnabled(true) @@ -490,7 +494,7 @@ func (w *Websocket) FlushChannels() error { } if !w.IsConnected() { - return fmt.Errorf("%s websocket: service not connected", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) } if w.features.Subscribe { diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 0bb1e660412..009b9e37459 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -50,9 +50,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(signature, request inter return payload, nil case <-timer.C: timer.Stop() - return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", - w.ExchangeName, - signature) + return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature) } } @@ -72,25 +70,14 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header w.Connection, conStatus, err = dialer.Dial(w.URL, headers) if err != nil { if conStatus != nil { - return fmt.Errorf("%s websocket connection: %v %v %v Error: %v", - w.ExchangeName, - w.URL, - conStatus, - conStatus.StatusCode, - err) + return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) } - return fmt.Errorf("%s websocket connection: %v Error: %v", - w.ExchangeName, - w.URL, - err) + return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } defer conStatus.Body.Close() if w.Verbose { - log.Infof(log.WebsocketMgr, - "%v Websocket connected to %s\n", - w.ExchangeName, - w.URL) + log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) } select { case w.Traffic <- struct{}{}: diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 8459a5461a9..653abbf5f5c 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -122,7 +122,7 @@ func TestSetup(t *testing.T) { websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errExchangeConfigNameUnset, "Setup should error correctly") + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "Setup should error correctly") websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) @@ -480,16 +480,10 @@ func TestConnectionMonitorNoConnection(t *testing.T) { ws.Wg = &sync.WaitGroup{} ws.enabled = true err := ws.connectionMonitor() - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } - if !ws.IsConnectionMonitorRunning() { - t.Fatal("Should not have exited") - } + require.NoError(t, err, "connectionMonitor must not error") + assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") err = ws.connectionMonitor() - if !errors.Is(err, errAlreadyRunning) { - t.Fatalf("received: %v, but expected: %v", err, errAlreadyRunning) - } + assert.ErrorIs(t, err, errAlreadyRunning, "connectionMonitor should error correctly") } // TestGetSubscription logic test @@ -528,15 +522,9 @@ func TestGetSubscriptions(t *testing.T) { func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { t.Parallel() ws := NewWebsocket() - result := ws.CanUseAuthenticatedEndpoints() - if result { - t.Error("expected `canUseAuthenticatedEndpoints` to be false") - } + assert.False(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return false") ws.SetCanUseAuthenticatedEndpoints(true) - result = ws.CanUseAuthenticatedEndpoints() - if !result { - t.Error("expected `canUseAuthenticatedEndpoints` to be true") - } + assert.True(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return true") } // TestDial logic test @@ -773,69 +761,41 @@ func TestParseBinaryResponse(t *testing.T) { } var b bytes.Buffer - w := gzip.NewWriter(&b) - _, err := w.Write([]byte("hello")) - if err != nil { - t.Error(err) - } - err = w.Close() - if err != nil { - t.Error(err) - } - var resp []byte - resp, err = wc.parseBinaryResponse(b.Bytes()) - if err != nil { - t.Error(err) - } - if !strings.EqualFold(string(resp), "hello") { - t.Errorf("GZip conversion failed. Received: '%v', Expected: 'hello'", string(resp)) - } + g := gzip.NewWriter(&b) + _, err := g.Write([]byte("hello")) + require.NoError(t, err, "gzip.Write must not error") + assert.NoError(t, g.Close(), "Close should not error") + + resp, err := wc.parseBinaryResponse(b.Bytes()) + assert.NoError(t, err, "parseBinaryResponse should not error parsing gzip") + assert.EqualValues(t, "hello", resp, "parseBinaryResponse should decode gzip") + + b.Reset() + f, err := flate.NewWriter(&b, 1) + require.NoError(t, err, "flate.NewWriter must not error") + _, err = f.Write([]byte("goodbye")) + require.NoError(t, err, "flate.Write must not error") + assert.NoError(t, f.Close(), "Close should not error") - var b2 bytes.Buffer - w2, err2 := flate.NewWriter(&b2, 1) - if err2 != nil { - t.Error(err2) - } - _, err2 = w2.Write([]byte("hello")) - if err2 != nil { - t.Error(err) - } - err2 = w2.Close() - if err2 != nil { - t.Error(err) - } - resp2, err3 := wc.parseBinaryResponse(b2.Bytes()) - if err3 != nil { - t.Error(err3) - } - if !strings.EqualFold(string(resp2), "hello") { - t.Errorf("Deflate conversion failed. Received: '%v', Expected: 'hello'", string(resp2)) - } + resp, err = wc.parseBinaryResponse(b.Bytes()) + assert.NoError(t, err, "parseBinaryResponse should not error parsing inflate") + assert.EqualValues(t, "goodbye", resp, "parseBinaryResponse should deflate") - _, err4 := wc.parseBinaryResponse([]byte{}) - if err4 == nil || err4.Error() != "unexpected EOF" { - t.Error("Expected error 'unexpected EOF'") - } + _, err = wc.parseBinaryResponse([]byte{}) + assert.ErrorContains(t, err, "unexpected EOF", "parseBinaryResponse should error on empty input") } // TestCanUseAuthenticatedWebsocketForWrapper logic test func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { t.Parallel() ws := &Websocket{} - resp := ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is false") - } + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + ws.setState(connected) - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false") - } + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + ws.canUseAuthenticatedEndpoints = true - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if !resp { - t.Error("Expected true, `connected` and `CanUseAuthenticatedEndpoints` is true") - } + assert.True(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return true") } func TestGenerateMessageID(t *testing.T) { @@ -869,34 +829,22 @@ func BenchmarkGenerateMessageID_Low(b *testing.B) { func TestCheckWebsocketURL(t *testing.T) { err := checkWebsocketURL("") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on empty string") err = checkWebsocketURL("wowowow:wowowowo") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on bad format") err = checkWebsocketURL("://") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "missing protocol scheme", "checkWebsocketURL should error correctly on bad proto") err = checkWebsocketURL("http://www.google.com") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on wrong proto") err = checkWebsocketURL("wss://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") err = checkWebsocketURL("ws://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") } func TestGetChannelDifference(t *testing.T) { @@ -1002,9 +950,7 @@ func TestFlushChannels(t *testing.T) { dodgyWs.setEnabled(true) err = dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") web := Websocket{ enabled: true, @@ -1023,7 +969,7 @@ func TestFlushChannels(t *testing.T) { } problemFunc := func() ([]subscription.Subscription, error) { - return nil, errors.New("problems") + return nil, errDastardlyReason } noSub := func() ([]subscription.Subscription, error) { @@ -1037,47 +983,34 @@ func TestFlushChannels(t *testing.T) { return []subscription.Subscription{{Channel: "test"}}, nil } err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") web.features.FullPayloadSubscribe = true web.GenerateSubs = problemFunc err = web.FlushChannels() // error on full subscribeToChannels - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") web.GenerateSubs = noSub - err = web.FlushChannels() // No subs to sub - if err != nil { - t.Fatal(err) - } + err = web.FlushChannels() // No subs to unsub + assert.NoError(t, err, "FlushChannels should not error") web.GenerateSubs = newgen.generateSubs subs, err := web.GenerateSubs() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "GenerateSubs must not error") + web.AddSuccessfulSubscriptions(subs...) err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") web.features.FullPayloadSubscribe = false web.features.Subscribe = true web.GenerateSubs = problemFunc err = web.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") web.GenerateSubs = newgen.generateSubs err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") web.subscriptionMutex.Lock() web.subscriptions = subscriptionMap{ 41: { @@ -1094,21 +1027,15 @@ func TestFlushChannels(t *testing.T) { web.subscriptionMutex.Unlock() err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") web.setState(connected) web.features.Unsubscribe = true err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") } func TestDisable(t *testing.T) { @@ -1118,14 +1045,8 @@ func TestDisable(t *testing.T) { state: connected, ShutdownC: make(chan struct{}), } - err := web.Disable() - if err != nil { - t.Fatal(err) - } - err = web.Disable() - if err == nil { - t.Fatal("should already be disabled") - } + require.NoError(t, web.Disable(), "Disable must not error") + assert.ErrorIs(t, web.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } func TestEnable(t *testing.T) { @@ -1140,98 +1061,66 @@ func TestEnable(t *testing.T) { Subscriber: func([]subscription.Subscription) error { return nil }, } - err := web.Enable() - if err != nil { - t.Fatal(err) - } - - err = web.Enable() - if err == nil { - t.Fatal("should already be enabled") - } - - fmt.Print() + require.NoError(t, web.Enable(), "Enable must not error") + assert.ErrorIs(t, web.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { t.Parallel() var nonsenseWebsock *Websocket err := nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errWebsocketIsNil, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{exchangeName: "test"} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") nonsenseWebsock.TrafficAlert = make(chan struct{}) err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") web := Websocket{ connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), - state: disconnected, TrafficAlert: make(chan struct{}), ReadMessageErrors: make(chan error), DataHandler: make(chan interface{}), } err = web.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Setup should not error") + err = web.SetupNewConnection(ConnectionSetup{}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigEmpty, "SetupNewConnection should error correctly") + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err != nil { - t.Fatal(err) - } - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", - Authenticated: true}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetupNewConnection should not error") + + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", Authenticated: true}) + assert.NoError(t, err, "SetupNewConnection should not error") } func TestWebsocketConnectionShutdown(t *testing.T) { t.Parallel() wc := WebsocketConnection{} err := wc.Shutdown() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Shutdown should not error") err = wc.Dial(&websocket.Dialer{}, nil) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial must error correctly") wc.URL = websocketTestURL err = wc.Dial(&websocket.Dialer{}, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Dial must not error") err = wc.Shutdown() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Shutdown must not error") } // TestLatency logic test @@ -1285,27 +1174,19 @@ func TestCheckSubscriptions(t *testing.T) { t.Parallel() ws := Websocket{} err := ws.checkSubscriptions(nil) - if !errors.Is(err, errNoSubscriptionsSupplied) { - t.Fatalf("received: %v, but expected: %v", err, errNoSubscriptionsSupplied) - } + assert.ErrorIs(t, err, errNoSubscriptionsSupplied, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 1 err = ws.checkSubscriptions([]subscription.Subscription{{}, {}}) - if !errors.Is(err, errSubscriptionsExceedsLimit) { - t.Fatalf("received: %v, but expected: %v", err, errSubscriptionsExceedsLimit) - } + assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}} err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}}) - if !errors.Is(err, errChannelAlreadySubscribed) { - t.Fatalf("received: %v, but expected: %v", err, errChannelAlreadySubscribed) - } + assert.ErrorIs(t, err, errChannelAlreadySubscribed, "checkSubscriptions should error correctly") err = ws.checkSubscriptions([]subscription.Subscription{{}}) - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } + assert.NoError(t, err, "checkSubscriptions should not error") } From 06d433e13398852777ceed9590efa0aee456af9c Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sun, 4 Feb 2024 09:30:32 +0700 Subject: [PATCH 08/15] Websocket: Improve GenerateMessageID test Testing just the last id doesn't feel very robust --- exchanges/stream/websocket_connection.go | 2 +- exchanges/stream/websocket_test.go | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 009b9e37459..910142caa9a 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -272,7 +272,7 @@ func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) { return standardMessage, reader.Close() } -// GenerateMessageID Creates a messageID to checkout +// GenerateMessageID Creates a random message ID func (w *WebsocketConnection) GenerateMessageID(highPrec bool) int64 { var min int64 = 1e8 var max int64 = 2e8 diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 653abbf5f5c..bdb0478606d 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -801,13 +801,12 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { func TestGenerateMessageID(t *testing.T) { t.Parallel() wc := WebsocketConnection{} - var id int64 - for i := 0; i < 10; i++ { - newID := wc.GenerateMessageID(true) - if id == newID { - t.Fatal("ID generation is not unique") - } - id = newID + const spins = 1000 + ids := make([]int64, spins) + for i := 0; i < spins; i++ { + id := wc.GenerateMessageID(true) + assert.NotContains(t, ids, id, "GenerateMessageID must not generate the same ID twice") + ids[i] = id } } From 48111063ccb6045c262ebd524008e890dad2c5cd Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 5 Feb 2024 13:37:41 +0700 Subject: [PATCH 09/15] Websocket: Protect Setup() from races --- exchanges/stream/websocket.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 9e527622b3a..af882464604 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -103,6 +103,9 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } + w.m.Lock() + defer w.m.Unlock() + if w.IsInitialised() { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } From beb7763ed568e724c598a650bc67c843968e1219 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 15 Feb 2024 19:39:03 +0700 Subject: [PATCH 10/15] Websocket: Use atomics instead of mutex This was spurred by looking at the setState call in trafficMonitor and the effect on blocking and efficiency. With the new atomic types in Go 1.19, and the small types in use here, atomics should be safe for our usage. bools should be truly atomic, and uint32 is atomic when the accepted value range is less than one byte/uint8 since that can be written atomicly by concurrent processors. Maybe that's not even a factor any more, however we don't even have to worry enough to check. --- exchanges/stream/websocket.go | 93 +++++++++-------------------- exchanges/stream/websocket_test.go | 77 ++++++++++++------------ exchanges/stream/websocket_types.go | 18 +++--- 3 files changed, 74 insertions(+), 114 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index af882464604..409e0953251 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -128,7 +128,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { if s.ExchangeConfig.Features == nil { return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil) } - w.enabled = s.ExchangeConfig.Features.Enabled.Websocket + w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) if s.Connector == nil { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset) @@ -388,9 +388,7 @@ func (w *Websocket) connectionMonitor() error { if w.checkAndSetMonitorRunning() { return errAlreadyRunning } - w.fieldMutex.RLock() delay := w.connectionMonitorDelay - w.fieldMutex.RUnlock() go func() { timer := time.NewTimer(delay) @@ -618,104 +616,73 @@ func (w *Websocket) trafficMonitor() { }() } -// IsInitialised returns whether the websocket has been Setup() already -func (w *Websocket) IsInitialised() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state != uninitialised +func (w *Websocket) setState(s uint32) { + w.state.Store(s) } -func (w *Websocket) setState(s state) { - w.fieldMutex.Lock() - w.state = s - w.fieldMutex.Unlock() +// IsInitialised returns whether the websocket has been Setup() already +func (w *Websocket) IsInitialised() bool { + return w.state.Load() != uninitialised } // IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state == connected + return w.state.Load() == connected } // IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state == connecting + return w.state.Load() == connecting } func (w *Websocket) setEnabled(b bool) { - w.fieldMutex.Lock() - w.enabled = b - w.fieldMutex.Unlock() + w.enabled.Store(b) } // IsEnabled returns whether the websocket is enabled func (w *Websocket) IsEnabled() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.enabled + return w.enabled.Load() } func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.trafficMonitorRunning = b - w.fieldMutex.Unlock() + w.trafficMonitorRunning.Store(b) } // IsTrafficMonitorRunning returns status of the traffic monitor func (w *Websocket) IsTrafficMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.trafficMonitorRunning + return w.trafficMonitorRunning.Load() } func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - if w.connectionMonitorRunning { - return true - } - w.connectionMonitorRunning = true - return false + return !w.connectionMonitorRunning.CompareAndSwap(false, true) } func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.connectionMonitorRunning = b - w.fieldMutex.Unlock() + w.connectionMonitorRunning.Store(b) } // IsConnectionMonitorRunning returns status of connection monitor func (w *Websocket) IsConnectionMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connectionMonitorRunning + return w.connectionMonitorRunning.Load() } func (w *Websocket) setDataMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.dataMonitorRunning = b - w.fieldMutex.Unlock() + w.dataMonitorRunning.Store(b) } // IsDataMonitorRunning returns status of data monitor func (w *Websocket) IsDataMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.dataMonitorRunning + return w.dataMonitorRunning.Load() } // CanUseAuthenticatedWebsocketForWrapper Handles a common check to // verify whether a wrapper can use an authenticated websocket endpoint func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { - if w.IsConnected() && w.CanUseAuthenticatedEndpoints() { - return true - } else if w.IsConnected() && !w.CanUseAuthenticatedEndpoints() { - log.Infof(log.WebsocketMgr, - WebsocketNotAuthenticatedUsingRest, - w.exchangeName) + if w.IsConnected() { + if w.CanUseAuthenticatedEndpoints() { + return true + } + log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName) } return false } @@ -994,20 +961,14 @@ func (w *Websocket) GetSubscriptions() []subscription.Subscription { return subs } -// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in -// a thread safe manner -func (w *Websocket) SetCanUseAuthenticatedEndpoints(val bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - w.canUseAuthenticatedEndpoints = val +// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner +func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) { + w.canUseAuthenticatedEndpoints.Store(b) } -// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in -// a thread safe manner +// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner func (w *Websocket) CanUseAuthenticatedEndpoints() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.canUseAuthenticatedEndpoints + return w.canUseAuthenticatedEndpoints.Load() } // IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index bdb0478606d..4725cc9b064 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -478,7 +478,7 @@ func TestConnectionMonitorNoConnection(t *testing.T) { ws.ShutdownC = make(chan struct{}, 1) ws.exchangeName = "hello" ws.Wg = &sync.WaitGroup{} - ws.enabled = true + ws.setEnabled(true) err := ws.connectionMonitor() require.NoError(t, err, "connectionMonitor must not error") assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") @@ -792,9 +792,10 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") ws.setState(connected) + require.True(t, ws.IsConnected(), "IsConnected must return true") assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") - ws.canUseAuthenticatedEndpoints = true + ws.SetCanUseAuthenticatedEndpoints(true) assert.True(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return true") } @@ -951,9 +952,7 @@ func TestFlushChannels(t *testing.T) { err = dodgyWs.FlushChannels() assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") - web := Websocket{ - enabled: true, - state: connected, + w := Websocket{ connector: connect, ShutdownC: make(chan struct{}), Subscriber: newgen.SUBME, @@ -966,6 +965,8 @@ func TestFlushChannels(t *testing.T) { // in FlushChannels() so the traffic monitor doesn't time out and turn // this to an unconnected state } + w.setEnabled(true) + w.setState(connected) problemFunc := func() ([]subscription.Subscription, error) { return nil, errDastardlyReason @@ -978,40 +979,40 @@ func TestFlushChannels(t *testing.T) { // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} - web.GenerateSubs = func() ([]subscription.Subscription, error) { + w.GenerateSubs = func() ([]subscription.Subscription, error) { return []subscription.Subscription{{Channel: "test"}}, nil } - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.features.FullPayloadSubscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() // error on full subscribeToChannels + w.features.FullPayloadSubscribe = true + w.GenerateSubs = problemFunc + err = w.FlushChannels() // error on full subscribeToChannels assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - web.GenerateSubs = noSub - err = web.FlushChannels() // No subs to unsub + w.GenerateSubs = noSub + err = w.FlushChannels() // No subs to unsub assert.NoError(t, err, "FlushChannels should not error") - web.GenerateSubs = newgen.generateSubs - subs, err := web.GenerateSubs() + w.GenerateSubs = newgen.generateSubs + subs, err := w.GenerateSubs() require.NoError(t, err, "GenerateSubs must not error") - web.AddSuccessfulSubscriptions(subs...) - err = web.FlushChannels() + w.AddSuccessfulSubscriptions(subs...) + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.features.FullPayloadSubscribe = false - web.features.Subscribe = true + w.features.FullPayloadSubscribe = false + w.features.Subscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() + w.GenerateSubs = problemFunc + err = w.FlushChannels() assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - web.GenerateSubs = newgen.generateSubs - err = web.FlushChannels() + w.GenerateSubs = newgen.generateSubs + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.subscriptionMutex.Lock() - web.subscriptions = subscriptionMap{ + w.subscriptionMutex.Lock() + w.subscriptions = subscriptionMap{ 41: { Key: 41, Channel: "match channel", @@ -1023,34 +1024,34 @@ func TestFlushChannels(t *testing.T) { Pair: currency.NewPair(currency.THETA, currency.USDT), }, } - web.subscriptionMutex.Unlock() + w.subscriptionMutex.Unlock() - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.setState(connected) - web.features.Unsubscribe = true - err = web.FlushChannels() + w.setState(connected) + w.features.Unsubscribe = true + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") } func TestDisable(t *testing.T) { t.Parallel() - web := Websocket{ - enabled: true, - state: connected, + w := Websocket{ ShutdownC: make(chan struct{}), } - require.NoError(t, web.Disable(), "Disable must not error") - assert.ErrorIs(t, web.Disable(), ErrAlreadyDisabled, "Disable should error correctly") + w.setEnabled(true) + w.setState(connected) + require.NoError(t, w.Disable(), "Disable must not error") + assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } func TestEnable(t *testing.T) { t.Parallel() - web := Websocket{ + w := Websocket{ connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), @@ -1060,8 +1061,8 @@ func TestEnable(t *testing.T) { Subscriber: func([]subscription.Subscription) error { return nil }, } - require.NoError(t, web.Enable(), "Enable must not error") - assert.ErrorIs(t, web.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") + require.NoError(t, w.Enable(), "Enable must not error") + assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 31e1db58b09..a783d585a4e 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -2,6 +2,7 @@ package stream import ( "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -23,10 +24,8 @@ const ( type subscriptionMap map[any]*subscription.Subscription -type state int - const ( - uninitialised state = iota + uninitialised uint32 = iota disconnected connecting connected @@ -35,13 +34,13 @@ const ( // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { - canUseAuthenticatedEndpoints bool - enabled bool - state state + canUseAuthenticatedEndpoints atomic.Bool + enabled atomic.Bool + state atomic.Uint32 verbose bool - connectionMonitorRunning bool - trafficMonitorRunning bool - dataMonitorRunning bool + connectionMonitorRunning atomic.Bool + trafficMonitorRunning atomic.Bool + dataMonitorRunning atomic.Bool trafficTimeout time.Duration connectionMonitorDelay time.Duration proxyAddr string @@ -51,7 +50,6 @@ type Websocket struct { runningURLAuth string exchangeName string m sync.Mutex - fieldMutex sync.RWMutex connector func() error subscriptionMutex sync.RWMutex From 6db7bcddd38936198087214fc51dea7c9ccb7806 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 15 Feb 2024 20:14:47 +0700 Subject: [PATCH 11/15] Websocket: Fix and simplify traffic monitor trafficMonitor had a check throttle at the end of the for loop to stop it just gobbling the (blocking) trafficAlert channel non-stop. That makes sense, except that nothing is sent to the trafficAlert channel if there's no listener. So that means that it's out by one second on the trafficAlert, because any traffic received during the pause is doesn't try to send a traffic alert. The unstopped timer is deliberately leaked for later GC when shutdown. It won't delay/block anything, and it's a trivial memory leak during an infrequent event. Deliberately Choosing to recreate the timer each time instead of using Stop, drain and reset --- exchanges/stream/websocket.go | 78 +++++++---------- exchanges/stream/websocket_connection.go | 2 +- exchanges/stream/websocket_test.go | 104 +++++++++++++++++++---- 3 files changed, 117 insertions(+), 67 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 409e0953251..bebc3abb595 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -16,12 +16,7 @@ import ( ) const ( - defaultJobBuffer = 5000 - // defaultTrafficPeriod defines a period of pause for the traffic monitor, - // as there are periods with large incoming traffic alerts which requires a - // timer reset, this limits work on this routine to a more effective rate - // of check. - defaultTrafficPeriod = time.Second + jobBuffer = 5000 ) // Public websocket errors @@ -37,6 +32,7 @@ var ( ErrNotConnected = errors.New("websocket is not connected") ) +// Private websocket errors var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") @@ -72,7 +68,10 @@ var ( errConnSetup = errors.New("error in connection setup") ) -var globalReporter Reporter +var ( + globalReporter Reporter + trafficCheckInterval = time.Second +) // SetupGlobalReporter sets a reporter interface to be used // for all exchange requests @@ -83,9 +82,9 @@ func SetupGlobalReporter(r Reporter) { // NewWebsocket initialises the websocket struct func NewWebsocket() *Websocket { return &Websocket{ - DataHandler: make(chan interface{}, defaultJobBuffer), - ToRoutine: make(chan interface{}, defaultJobBuffer), - TrafficAlert: make(chan struct{}), + DataHandler: make(chan interface{}, jobBuffer), + ToRoutine: make(chan interface{}, jobBuffer), + TrafficAlert: make(chan struct{}, 1), ReadMessageErrors: make(chan error), Subscribe: make(chan []subscription.Subscription), Unsubscribe: make(chan []subscription.Subscription), @@ -545,9 +544,9 @@ func (w *Websocket) FlushChannels() error { return w.Connect() } -// trafficMonitor uses a timer of WebsocketTrafficLimitTime and once it expires, -// it will reconnect if the TrafficAlert channel has not received any data. The -// trafficTimer will reset on each traffic alert +// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert +// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic +// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again func (w *Websocket) trafficMonitor() { if w.IsTrafficMonitorRunning() { return @@ -556,62 +555,45 @@ func (w *Websocket) trafficMonitor() { w.Wg.Add(1) go func() { - var trafficTimer = time.NewTimer(w.trafficTimeout) - pause := make(chan struct{}) + t := time.NewTimer(w.trafficTimeout) for { select { case <-w.ShutdownC: if w.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) } - trafficTimer.Stop() + t.Stop() w.setTrafficMonitorRunning(false) w.Wg.Done() return - case <-w.TrafficAlert: - if !trafficTimer.Stop() { - select { - case <-trafficTimer.C: - default: + case <-time.After(trafficCheckInterval): + select { + case <-w.TrafficAlert: + w.setState(connected) + if !t.Stop() { + <-t.C } + t.Reset(w.trafficTimeout) + default: + } + case <-t.C: + if w.IsConnecting() { + t.Reset(w.trafficTimeout) + break } - w.setState(connected) - trafficTimer.Reset(w.trafficTimeout) - case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) } - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() // without this the w.Shutdown() call below will deadlock - if !w.IsConnecting() && w.IsConnected() { + w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call + w.Wg.Done() // Without this the w.Shutdown() call below will deadlock + if w.IsConnected() { err := w.Shutdown() if err != nil { log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) } } - return } - - if w.IsConnected() { - // Routine pausing mechanism - go func(p chan<- struct{}) { - time.Sleep(defaultTrafficPeriod) - select { - case p <- struct{}{}: - default: - } - }(pause) - select { - case <-w.ShutdownC: - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() - return - case <-pause: - } - } } }() } diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 910142caa9a..4d7681f8d13 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -227,7 +227,7 @@ func (w *WebsocketConnection) ReadMessage() Response { select { case w.Traffic <- struct{}{}: - default: // causes contention, just bypass if there is no receiver. + default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding } var standardMessage []byte diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 4725cc9b064..a2ad5e41307 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/http" + "os" "sort" "strconv" "strings" @@ -105,6 +106,12 @@ func (d *dodgyConnection) Connect() error { return fmt.Errorf("cannot connect: %w", errDastardlyReason) } +func TestMain(m *testing.M) { + // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing + trafficCheckInterval = 50 * time.Millisecond + os.Exit(m.Run()) +} + func TestSetup(t *testing.T) { t.Parallel() var w *Websocket @@ -181,22 +188,90 @@ func TestTrafficMonitorTimeout(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - ws.trafficTimeout = time.Millisecond * 42 + signal := struct{}{} + patience := 10 * time.Millisecond + // trafficCheckInterval is changed in TestMain to avoid racing + ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) + + thenish := time.Now() ws.trafficMonitor() + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, disconnected, ws.state.Load(), "websocket must be disconnected") - // Deploy traffic alert - ws.TrafficAlert <- struct{}{} - // try to add another traffic monitor + // Behaviour: Test multiple traffic alerts work and only process one trafficAlert per interval + for i := 0; i < 2; i++ { + ws.state.Store(disconnected) + + select { + case ws.TrafficAlert <- signal: + if i == 0 { + require.WithinDurationf(t, time.Now(), thenish, trafficCheckInterval, "First Non-blocking test must happen before the traffic is checked") + } + default: + require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) + } + + select { + case ws.TrafficAlert <- signal: + require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) + default: + if i == 0 { + require.WithinDuration(t, time.Now(), thenish, trafficCheckInterval, "First Blocking test must happen before the traffic is checked") + } + } + + require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 6*trafficCheckInterval, patience, "trafficAlert should be drained; Check #%d", i) + assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) + } + + // Behaviour: Shuts down websocket and exits on timeout + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") + }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown") + + // Behaviour: connecting status doesn't trigger shutdown + ws.state.Store(connecting) + ws.trafficTimeout = 50 * time.Millisecond + ws.trafficMonitor() + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting") + <-time.After(4 * ws.trafficTimeout) + require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks") + ws.state.Store(connected) + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") + }, 4*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown after connecting status changes") + + // Behaviour: shutdown is processed and waitgroup is cleared + ws.state.Store(connected) + ws.trafficTimeout = time.Minute ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - // prevent shutdown routine - ws.setState(disconnected) - // await timeout closure - ws.Wg.Wait() - assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be not be running") + wgReady := make(chan bool) + go func() { + ws.Wg.Wait() + close(wgReady) + }() + select { + case <-wgReady: + require.Failf(t, "", "WaitGroup should be blocking still") + case <-time.After(trafficCheckInterval): + } + + close(ws.ShutdownC) + + <-time.After(2 * trafficCheckInterval) + assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be shutdown") + select { + case <-wgReady: + default: + require.Failf(t, "", "WaitGroup should be freed now") + } } func TestIsDisconnectionError(t *testing.T) { @@ -1079,18 +1154,11 @@ func TestSetupNewConnection(t *testing.T) { err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") - nonsenseWebsock.TrafficAlert = make(chan struct{}) + nonsenseWebsock.TrafficAlert = make(chan struct{}, 1) err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") - web := Websocket{ - connector: connect, - Wg: new(sync.WaitGroup), - ShutdownC: make(chan struct{}), - TrafficAlert: make(chan struct{}), - ReadMessageErrors: make(chan error), - DataHandler: make(chan interface{}), - } + web := NewWebsocket() err = web.Setup(defaultSetup) assert.NoError(t, err, "Setup should not error") From 248212b948368cf8ae369a98ee998fd415d295d3 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Tue, 20 Feb 2024 14:57:20 +0700 Subject: [PATCH 12/15] Websocket: Split traficMonitor test on behaviours --- exchanges/stream/websocket_test.go | 33 ++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index a2ad5e41307..5476cd650b9 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -182,7 +182,9 @@ func TestSetup(t *testing.T) { assert.NoError(t, err, "Setup should not error") } -func TestTrafficMonitorTimeout(t *testing.T) { +// TestTrafficMonitorTrafficAlerts ensures multiple traffic alerts work and only process one trafficAlert per interval +// ensures shutdown works after traffic alerts +func TestTrafficMonitorTrafficAlerts(t *testing.T) { t.Parallel() ws := NewWebsocket() err := ws.Setup(defaultSetup) @@ -190,7 +192,6 @@ func TestTrafficMonitorTimeout(t *testing.T) { signal := struct{}{} patience := 10 * time.Millisecond - // trafficCheckInterval is changed in TestMain to avoid racing ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) @@ -200,7 +201,6 @@ func TestTrafficMonitorTimeout(t *testing.T) { assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") require.Equal(t, disconnected, ws.state.Load(), "websocket must be disconnected") - // Behaviour: Test multiple traffic alerts work and only process one trafficAlert per interval for i := 0; i < 2; i++ { ws.state.Store(disconnected) @@ -226,17 +226,24 @@ func TestTrafficMonitorTimeout(t *testing.T) { assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) } - // Behaviour: Shuts down websocket and exits on timeout require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") - }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown") + }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") +} + +// TestTrafficMonitorConnecting ensure connecting status doesn't trigger shutdown +func TestTrafficMonitorConnecting(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") - // Behaviour: connecting status doesn't trigger shutdown + ws.ShutdownC = make(chan struct{}) ws.state.Store(connecting) ws.trafficTimeout = 50 * time.Millisecond ws.trafficMonitor() - assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting") <-time.After(4 * ws.trafficTimeout) require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks") @@ -244,9 +251,17 @@ func TestTrafficMonitorTimeout(t *testing.T) { require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") - }, 4*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown after connecting status changes") + }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") +} - // Behaviour: shutdown is processed and waitgroup is cleared +// TestTrafficMonitorShutdown ensure shutdown is processed and waitgroup is cleared +func TestTrafficMonitorShutdown(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + ws.ShutdownC = make(chan struct{}) ws.state.Store(connected) ws.trafficTimeout = time.Minute ws.trafficMonitor() From d79f8fb7e3dda6a101e655725962d4a54c3b90d0 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Tue, 20 Feb 2024 15:46:33 +0700 Subject: [PATCH 13/15] Websocket: Remove trafficMonitor connected status trafficMonitor does not need to set the connection to be connected. Connect() does that. Anything after that should result in a full shutdown and restart. It can't and shouldn't become connected unexpectedly, and this is most likely a race anyway. Also dropped trafficCheckInterval to 100ms to mitigate races of traffic alerts being buffered for too long. --- exchanges/stream/websocket.go | 11 ++++++++--- exchanges/stream/websocket_test.go | 11 +++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index bebc3abb595..7d1d402d161 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -70,7 +70,7 @@ var ( var ( globalReporter Reporter - trafficCheckInterval = time.Second + trafficCheckInterval = 100 * time.Millisecond ) // SetupGlobalReporter sets a reporter interface to be used @@ -569,7 +569,6 @@ func (w *Websocket) trafficMonitor() { case <-time.After(trafficCheckInterval): select { case <-w.TrafficAlert: - w.setState(connected) if !t.Stop() { <-t.C } @@ -577,7 +576,13 @@ func (w *Websocket) trafficMonitor() { default: } case <-t.C: - if w.IsConnecting() { + checkAgain := w.IsConnecting() + select { + case <-w.TrafficAlert: + checkAgain = true + default: + } + if checkAgain { t.Reset(w.trafficTimeout) break } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 5476cd650b9..cb07ab5175a 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -194,16 +194,15 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { patience := 10 * time.Millisecond ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) + ws.state.Store(connected) thenish := time.Now() ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, disconnected, ws.state.Load(), "websocket must be disconnected") - - for i := 0; i < 2; i++ { - ws.state.Store(disconnected) + require.Equal(t, connected, ws.state.Load(), "websocket must be connected") + for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass select { case ws.TrafficAlert <- signal: if i == 0 { @@ -232,7 +231,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") } -// TestTrafficMonitorConnecting ensure connecting status doesn't trigger shutdown +// TestTrafficMonitorConnecting ensures connecting status doesn't trigger shutdown func TestTrafficMonitorConnecting(t *testing.T) { t.Parallel() ws := NewWebsocket() @@ -254,7 +253,7 @@ func TestTrafficMonitorConnecting(t *testing.T) { }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") } -// TestTrafficMonitorShutdown ensure shutdown is processed and waitgroup is cleared +// TestTrafficMonitorShutdown ensures shutdown is processed and waitgroup is cleared func TestTrafficMonitorShutdown(t *testing.T) { t.Parallel() ws := NewWebsocket() From 9aca06e68a50ab42f96a259f293edff9508c863a Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 21 Feb 2024 08:40:27 +0700 Subject: [PATCH 14/15] Websocket: Set disconnected earlier in Shutdown This caused a possible race where state is still connected, but we start to trigger interested actors via ShutdownC and Wait. They may check state and then call Shutdown again, such as trafficMonitor --- exchanges/stream/websocket.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 7d1d402d161..404d76d86f1 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -477,10 +477,11 @@ func (w *Websocket) Shutdown() error { w.subscriptions = subscriptionMap{} w.subscriptionMutex.Unlock() + w.setState(disconnected) + close(w.ShutdownC) w.Wg.Wait() w.ShutdownC = make(chan struct{}) - w.setState(disconnected) if w.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } From f9ea5697fe1ca5e8c0aa7c8d823f2ce59b471860 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Fri, 23 Feb 2024 13:18:57 +0700 Subject: [PATCH 15/15] Websocket: Wait 5s for slow tests to pass traffic draining Keep getting failures upstream on test rigs. Think they can be very contended, so this pushes the boundary right out to 5s --- exchanges/stream/websocket_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index cb07ab5175a..0d4e9c02e57 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -221,7 +221,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { } } - require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 6*trafficCheckInterval, patience, "trafficAlert should be drained; Check #%d", i) + require.Eventuallyf(t, func() bool { return len(ws.TrafficAlert) == 0 }, 5*time.Second, patience, "trafficAlert should be drained; Check #%d", i) assert.Truef(t, ws.IsConnected(), "state should still be connected; Check #%d", i) }