From 9b56503e82c4aaad66262498c0abc2aea65f4ee1 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Fri, 26 Jan 2024 14:32:45 +0700 Subject: [PATCH 01/35] Websocket: Remove IsInit and simplify SetProxyAddress IsInit was basically the same as IsConnected. Any time Connect was called both would be set to true. Any time we had a disconnect they'd both be set to false Shutdown() incorrectly didn't setInit(false) SetProxyAddress simplified to only reconnect a connected Websocket. Any other state means it hasn't been Connected, or it's about to reconnect anyway. There's no handling for IsConnecting previously, either, so I've wrapped that behind the main mutex. --- exchanges/stream/websocket.go | 104 ++++++++++------------------------ 1 file changed, 30 insertions(+), 74 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index aefc3400f60..b49835bac80 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -286,7 +286,6 @@ func (w *Websocket) Connect() error { } w.setConnectedStatus(true) w.setConnectingStatus(false) - w.setInit(true) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() @@ -369,9 +368,7 @@ func (w *Websocket) dataMonitor() { case <-w.ShutdownC: return default: - log.Warnf(log.WebsocketMgr, - "%s exchange backlog in websocket processing detected", - w.exchangeName) + log.Warnf(log.WebsocketMgr, "%s exchange backlog in websocket processing detected", w.exchangeName) select { case w.ToRoutine <- d: case <-w.ShutdownC: @@ -396,26 +393,19 @@ func (w *Websocket) connectionMonitor() error { timer := time.NewTimer(delay) for { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: running connection monitor cycle\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) } if !w.IsEnabled() { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connectionMonitor - websocket disabled, shutting down\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) } if w.IsConnected() { - err := w.Shutdown() - if err != nil { + if err := w.Shutdown(); err != nil { log.Errorln(log.WebsocketMgr, err) } } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connection monitor exiting\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) } timer.Stop() w.setConnectionMonitorRunning(false) @@ -424,10 +414,7 @@ func (w *Websocket) connectionMonitor() error { select { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { - w.setInit(false) - log.Warnf(log.WebsocketMgr, - "%v websocket has been disconnected. Reason: %v", - w.exchangeName, err) + log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) w.setConnectedStatus(false) } @@ -459,21 +446,16 @@ func (w *Websocket) Shutdown() error { defer w.m.Unlock() if !w.IsConnected() { - return fmt.Errorf("%v websocket: cannot shutdown %w", - w.exchangeName, - ErrNotConnected) + return fmt.Errorf("%v websocket: cannot shutdown %w", w.exchangeName, ErrNotConnected) } // TODO: Interrupt connection and or close connection when it is re-established. if w.IsConnecting() { - return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", - w.exchangeName) + return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", w.exchangeName) } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: shutting down websocket\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", w.exchangeName) } defer w.Orderbook.FlushBuffer() @@ -501,9 +483,7 @@ func (w *Websocket) Shutdown() error { w.setConnectedStatus(false) w.setConnectingStatus(false) if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: completed websocket shutdown\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } return nil } @@ -582,9 +562,7 @@ func (w *Websocket) trafficMonitor() { select { case <-w.ShutdownC: if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown message received\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) } trafficTimer.Stop() w.setTrafficMonitorRunning(false) @@ -601,10 +579,7 @@ func (w *Websocket) trafficMonitor() { trafficTimer.Reset(w.trafficTimeout) case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { - log.Warnf(log.WebsocketMgr, - "%v websocket: has not received a traffic alert in %v. Reconnecting", - w.exchangeName, - w.trafficTimeout) + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) } trafficTimer.Stop() w.setTrafficMonitorRunning(false) @@ -612,9 +587,7 @@ func (w *Websocket) trafficMonitor() { if !w.IsConnecting() && w.IsConnected() { err := w.Shutdown() if err != nil { - log.Errorf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown err: %s", - w.exchangeName, err) + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) } } @@ -682,19 +655,6 @@ func (w *Websocket) IsEnabled() bool { return w.enabled } -func (w *Websocket) setInit(b bool) { - w.fieldMutex.Lock() - w.Init = b - w.fieldMutex.Unlock() -} - -// IsInit returns status of init -func (w *Websocket) IsInit() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.Init -} - func (w *Websocket) setTrafficMonitorRunning(b bool) { w.fieldMutex.Lock() w.trafficMonitorRunning = b @@ -820,28 +780,22 @@ func (w *Websocket) GetWebsocketURL() string { // SetProxyAddress sets websocket proxy address func (w *Websocket) SetProxyAddress(proxyAddr string) error { + w.m.Lock() + if proxyAddr != "" { - _, err := url.ParseRequestURI(proxyAddr) - if err != nil { - return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", - w.exchangeName, - err) + if _, err := url.ParseRequestURI(proxyAddr); err != nil { + w.m.Unlock() + return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", w.exchangeName, err) } if w.proxyAddr == proxyAddr { - return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", - w.exchangeName, - w.proxyAddr) + w.m.Unlock() + return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", w.exchangeName, w.proxyAddr) } - log.Debugf(log.ExchangeSys, - "%s websocket: setting websocket proxy: %s\n", - w.exchangeName, - proxyAddr) + log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr) } else { - log.Debugf(log.ExchangeSys, - "%s websocket: removing websocket proxy\n", - w.exchangeName) + log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) } if w.Conn != nil { @@ -852,15 +806,17 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { } w.proxyAddr = proxyAddr - if w.IsInit() && w.IsEnabled() { - if w.IsConnected() { - err := w.Shutdown() - if err != nil { - return err - } + + if w.IsConnected() { + w.m.Unlock() + if err := w.Shutdown(); err != nil { + return err } return w.Connect() } + + w.m.Unlock() + return nil } From 2609d450a5b54fe0e6e59a8fc2482daf8fa7326a Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 27 Jan 2024 12:13:49 +0700 Subject: [PATCH 02/35] Websocket: Expand and Assertify tests --- exchanges/stream/websocket.go | 19 ++-- exchanges/stream/websocket_test.go | 141 ++++++++++------------------- 2 files changed, 55 insertions(+), 105 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index b49835bac80..bb23c9a8464 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -64,6 +64,7 @@ var ( errNoSubscriptionsSupplied = errors.New("no subscriptions supplied") errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") ) var globalReporter Reporter @@ -262,12 +263,10 @@ func (w *Websocket) Connect() error { return errors.New(WebsocketNotEnabled) } if w.IsConnecting() { - return fmt.Errorf("%v Websocket already attempting to connect", - w.exchangeName) + return fmt.Errorf("%v Websocket already attempting to connect", w.exchangeName) } if w.IsConnected() { - return fmt.Errorf("%v Websocket already connected", - w.exchangeName) + return fmt.Errorf("%v Websocket already connected", w.exchangeName) } w.subscriptionMutex.Lock() @@ -281,8 +280,7 @@ func (w *Websocket) Connect() error { err := w.connector() if err != nil { w.setConnectingStatus(false) - return fmt.Errorf("%v Error connecting %s", - w.exchangeName, err) + return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } w.setConnectedStatus(true) w.setConnectingStatus(false) @@ -290,10 +288,7 @@ func (w *Websocket) Connect() error { if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() if err != nil { - log.Errorf(log.WebsocketMgr, - "%s cannot start websocket connection monitor %v", - w.GetName(), - err) + log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) } } @@ -785,12 +780,12 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { if proxyAddr != "" { if _, err := url.ParseRequestURI(proxyAddr); err != nil { w.m.Unlock() - return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", w.exchangeName, err) + return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err) } if w.proxyAddr == proxyAddr { w.m.Unlock() - return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", w.exchangeName, w.proxyAddr) + return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr) } log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3ab49e0df10..650509b11d1 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -30,6 +31,10 @@ const ( proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) +var ( + errDastardlyReason = errors.New("cannot shutdown due to some dastardly reason") +) + var dialer websocket.Dialer type testStruct struct { @@ -92,12 +97,12 @@ type dodgyConnection struct { // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Shutdown() error { - return errors.New("cannot shutdown due to some dastardly reason") + return errDastardlyReason } // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Connect() error { - return errors.New("cannot connect due to some dastardly reason") + return errDastardlyReason } func TestSetup(t *testing.T) { @@ -349,152 +354,102 @@ func TestWebsocket(t *testing.T) { Name: "test", }, }) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } + assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "SetProxyAddress should error correctly") ws := *New() err = ws.SetProxyAddress("garbagio") - if err == nil { - t.Error("error cannot be nil") - } + assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") ws.Conn = &WebsocketConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) + err = ws.SetProxyAddress("https://192.168.0.1:1337") - if err == nil { - t.Error("error cannot be nil") - } + assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") + ws.setConnectedStatus(true) ws.ShutdownC = make(chan struct{}) ws.Wg = &sync.WaitGroup{} + err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } + assert.ErrorIs(t, err, errNoConnectFunc, "SetProxyAddress should call Connect and error from there") // This test asserts we actually set the proxy address, etc err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } + assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") ws.setEnabled(false) // removing proxy err = ws.SetProxyAddress("") - if err != nil { - t.Error(err) - } + assert.NoError(t, err, "SetProxyAddress should not error when removing proxy") + // reinstate proxy err = ws.SetProxyAddress("http://localhost:1337") - if err != nil { - t.Error(err) - } - // conflict proxy - err = ws.SetProxyAddress("http://localhost:1337") - if err == nil { - t.Error("error cannot be nil") - } - err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } - if ws.GetName() != "exchangeName" { - t.Error("WebsocketSetup") - } + assert.NoError(t, err, "SetProxyAddress should not error") - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } + err = ws.Setup(defaultSetup) // Sets to enabled again + require.NoError(t, err, "Setup may not error") + + assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly") + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") ws.setEnabled(false) - if ws.IsEnabled() { - t.Error("WebsocketSetup") - } + assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") + ws.setEnabled(true) - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") - if ws.GetProxyAddress() != "http://localhost:1337" { - t.Error("WebsocketSetup") - } + assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") + assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") + assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") - if ws.GetWebsocketURL() != "wss://testRunningURL" { - t.Error("WebsocketSetup") - } - if ws.trafficTimeout != time.Second*5 { - t.Error("WebsocketSetup") - } - // -- Not connected shutdown err = ws.Shutdown() - if err == nil { - t.Fatal("should not be connected to able to shut down") - } + assert.ErrorIs(t, err, ErrNotConnected, "Shutdown should error when not Connected") ws.setConnectedStatus(true) ws.Conn = &dodgyConnection{} err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy conn") ws.Conn = &WebsocketConnection{} ws.setConnectedStatus(true) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil ") - } + assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} ws.setConnectedStatus(false) - // -- Normal connect err = ws.Connect() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Connect should not error") ws.defaultURL = "ws://demos.kaazing.com/echo" ws.defaultURLAuth = "ws://demos.kaazing.com/echo" err = ws.SetWebsocketURL("", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("", true, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, false) - if err != nil { - t.Fatal(err) - } - // Attempt reconnect + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, true) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error on reconnect") + // -- initiate the reconnect which is usually handled by connection monitor err = ws.Connect() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "ReConnect called manually should not error") + err = ws.Connect() - if err == nil { - t.Fatal("should already be connected") - } - // -- Normal shutdown + assert.ErrorIs(t, err, errAlreadyConnected, "ReConnect should error when already connected") + err = ws.Shutdown() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Shutdown should not error") ws.Wg.Wait() } From 398cd7c45df011f727f10e57a8049b092be557b0 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sun, 28 Jan 2024 08:24:15 +0700 Subject: [PATCH 03/35] Websocket: Simplify state transistions --- exchanges/stream/websocket.go | 9 ++++----- exchanges/stream/websocket_types.go | 12 ++++++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index bb23c9a8464..8f27bca4c50 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -78,7 +78,6 @@ func SetupGlobalReporter(r Reporter) { // New initialises the websocket struct func New() *Websocket { return &Websocket{ - Init: true, DataHandler: make(chan interface{}, defaultJobBuffer), ToRoutine: make(chan interface{}, defaultJobBuffer), TrafficAlert: make(chan struct{}), @@ -99,7 +98,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } - if !w.Init { + if w.state != uninitialised { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } @@ -613,7 +612,7 @@ func (w *Websocket) trafficMonitor() { func (w *Websocket) setConnectedStatus(b bool) { w.fieldMutex.Lock() - w.connected = b + w.state = connected w.fieldMutex.Unlock() } @@ -621,12 +620,12 @@ func (w *Websocket) setConnectedStatus(b bool) { func (w *Websocket) IsConnected() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() - return w.connected + return w.state == connected } func (w *Websocket) setConnectingStatus(b bool) { w.fieldMutex.Lock() - w.connecting = b + w.state = connecting w.fieldMutex.Unlock() } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 925c34b907c..2db7fe67154 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -25,13 +25,21 @@ const ( type subscriptionMap map[any]*subscription.Subscription +type State int + +const ( + uninitialised State = iota + disconnected + connecting + connected +) + // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { canUseAuthenticatedEndpoints bool enabled bool - Init bool - connected bool + state State connecting bool verbose bool connectionMonitorRunning bool From 5be87e6ce8c9574815556ada7fa781f008903291 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 5 Feb 2024 10:47:57 +0700 Subject: [PATCH 04/35] fixup! Websocket: Simplify state transistions --- exchanges/stream/websocket_types.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 2db7fe67154..9f3ec68a317 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -39,8 +39,7 @@ const ( type Websocket struct { canUseAuthenticatedEndpoints bool enabled bool - state State - connecting bool + state state verbose bool connectionMonitorRunning bool trafficMonitorRunning bool From de01ec17b2587f2ea026049436dc23d780b1dfd5 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 29 Jan 2024 14:48:54 +0700 Subject: [PATCH 05/35] Websocket: Simplify Connecting/Connected state --- engine/websocketroutine_manager_test.go | 2 +- exchanges/binance/binance_wrapper.go | 2 +- exchanges/binanceus/binanceus_wrapper.go | 2 +- exchanges/bitfinex/bitfinex_wrapper.go | 2 +- exchanges/bithumb/bithumb_wrapper.go | 2 +- exchanges/bitmex/bitmex_wrapper.go | 2 +- exchanges/bitstamp/bitstamp_wrapper.go | 2 +- exchanges/btcmarkets/btcmarkets_wrapper.go | 2 +- exchanges/btse/btse_wrapper.go | 2 +- exchanges/bybit/bybit_wrapper.go | 2 +- exchanges/coinbasepro/coinbasepro_wrapper.go | 2 +- exchanges/coinut/coinut_wrapper.go | 2 +- exchanges/exchange_test.go | 6 +- exchanges/gateio/gateio_wrapper.go | 2 +- exchanges/gemini/gemini_wrapper.go | 2 +- exchanges/hitbtc/hitbtc_wrapper.go | 2 +- exchanges/huobi/huobi_wrapper.go | 2 +- exchanges/kraken/kraken_wrapper.go | 2 +- exchanges/kucoin/kucoin_wrapper.go | 2 +- exchanges/okcoin/okcoin_wrapper.go | 2 +- exchanges/okx/okx_wrapper.go | 2 +- exchanges/poloniex/poloniex_wrapper.go | 2 +- .../sharedtestvalues/sharedtestvalues.go | 1 - exchanges/stream/websocket.go | 37 ++++---- exchanges/stream/websocket_test.go | 86 +++++++++---------- exchanges/stream/websocket_types.go | 4 +- 26 files changed, 82 insertions(+), 94 deletions(-) diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index e082193bda4..c1c3541db31 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -293,7 +293,7 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { t.Fatal("unexpected data handlers registered") } - mock := stream.New() + mock := stream.NewWebsocket() mock.ToRoutine = make(chan interface{}) m.state = readyState err = m.websocketDataReceiver(mock) diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index bdf70e0d3ee..d54f89242ca 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -238,7 +238,7 @@ func (b *Binance) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/binanceus/binanceus_wrapper.go b/exchanges/binanceus/binanceus_wrapper.go index 3f078ede7a8..ed2e5d10f5e 100644 --- a/exchanges/binanceus/binanceus_wrapper.go +++ b/exchanges/binanceus/binanceus_wrapper.go @@ -162,7 +162,7 @@ func (bi *Binanceus) SetDefaults() { "%s setting default endpoints error %v", bi.Name, err) } - bi.Websocket = stream.New() + bi.Websocket = stream.NewWebsocket() bi.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit bi.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout bi.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 7e1ef64c766..2aba9987cf9 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -198,7 +198,7 @@ func (b *Bitfinex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 5dbc314674d..24ea8980b19 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -150,7 +150,7 @@ func (b *Bithumb) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 151f1ef28e9..a0810b1a302 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -175,7 +175,7 @@ func (b *Bitmex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index 888fae0b395..2fe7fa96d82 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -146,7 +146,7 @@ func (b *Bitstamp) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index 17ee7277c0a..8a924b08cf0 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -150,7 +150,7 @@ func (b *BTCMarkets) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index f7ec3c8f702..3cdbd56b2c5 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -176,7 +176,7 @@ func (b *BTSE) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 28d4f15041d..b6e7f2be224 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -216,7 +216,7 @@ func (by *Bybit) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - by.Websocket = stream.New() + by.Websocket = stream.NewWebsocket() by.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit by.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout by.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 47a20d9ed95..21a34e2ea22 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -145,7 +145,7 @@ func (c *CoinbasePro) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 503c2909461..db4af53badd 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -127,7 +127,7 @@ func (c *COINUT) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index d41f499d48d..ad0060495f3 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -198,7 +198,7 @@ func TestSetClientProxyAddress(t *testing.T) { Name: "rawr", Requester: requester} - newBase.Websocket = stream.New() + newBase.Websocket = stream.NewWebsocket() err = newBase.SetClientProxyAddress("") if err != nil { t.Error(err) @@ -1251,7 +1251,7 @@ func TestSetupDefaults(t *testing.T) { } // Test websocket support - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.Features.Supports.Websocket = true err = b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ @@ -1596,7 +1596,7 @@ func TestIsWebsocketEnabled(t *testing.T) { t.Error("exchange doesn't support websocket") } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() err := b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ Enabled: true, diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 3026efc0608..adea4d2fa6b 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -194,7 +194,7 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index fee75d6b1a2..d2d89eb5008 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -128,7 +128,7 @@ func (g *Gemini) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index 3b65fe649b5..7bcf7ada335 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -147,7 +147,7 @@ func (h *HitBTC) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 3d1d576b9d2..90d70491bd8 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -202,7 +202,7 @@ func (h *HUOBI) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index de631e0d904..5875917592c 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -209,7 +209,7 @@ func (k *Kraken) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - k.Websocket = stream.New() + k.Websocket = stream.NewWebsocket() k.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit k.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout k.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kucoin/kucoin_wrapper.go b/exchanges/kucoin/kucoin_wrapper.go index 8a98a77a02d..d767a2cf441 100644 --- a/exchanges/kucoin/kucoin_wrapper.go +++ b/exchanges/kucoin/kucoin_wrapper.go @@ -195,7 +195,7 @@ func (ku *Kucoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - ku.Websocket = stream.New() + ku.Websocket = stream.NewWebsocket() ku.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit ku.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout ku.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okcoin/okcoin_wrapper.go b/exchanges/okcoin/okcoin_wrapper.go index 8519cdc1b11..af0655ef3f1 100644 --- a/exchanges/okcoin/okcoin_wrapper.go +++ b/exchanges/okcoin/okcoin_wrapper.go @@ -150,7 +150,7 @@ func (o *Okcoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - o.Websocket = stream.New() + o.Websocket = stream.NewWebsocket() o.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit o.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout o.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index 0b5472b9e65..64e2e269877 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -190,7 +190,7 @@ func (ok *Okx) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - ok.Websocket = stream.New() + ok.Websocket = stream.NewWebsocket() ok.WebsocketResponseMaxLimit = okxWebsocketResponseMaxLimit ok.WebsocketResponseCheckTimeout = okxWebsocketResponseMaxLimit ok.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index 97eb9293d56..57f28fac98b 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -159,7 +159,7 @@ func (p *Poloniex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - p.Websocket = stream.New() + p.Websocket = stream.NewWebsocket() p.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit p.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout p.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index 60d51d82bfe..54acf9dcb5b 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -57,7 +57,6 @@ func GetWebsocketStructChannelOverride() chan struct{} { // NewTestWebsocket returns a test websocket object func NewTestWebsocket() *stream.Websocket { return &stream.Websocket{ - Init: true, DataHandler: make(chan interface{}, WebsocketChannelOverrideCapacity), ToRoutine: make(chan interface{}, 1000), TrafficAlert: make(chan struct{}), diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 8f27bca4c50..428812b0ee1 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -65,6 +65,8 @@ var ( errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("connect func not set") + errAlreadyConnected = errors.New("already connected") ) var globalReporter Reporter @@ -75,8 +77,8 @@ func SetupGlobalReporter(r Reporter) { globalReporter = r } -// New initialises the websocket struct -func New() *Websocket { +// NewWebsocket initialises the websocket struct +func NewWebsocket() *Websocket { return &Websocket{ DataHandler: make(chan interface{}, defaultJobBuffer), ToRoutine: make(chan interface{}, defaultJobBuffer), @@ -188,6 +190,8 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions) } w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection + w.setState(disconnected) + return nil } @@ -253,7 +257,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { // function func (w *Websocket) Connect() error { if w.connector == nil { - return errors.New("websocket connect function not set, cannot continue") + return errNoConnectFunc } w.m.Lock() defer w.m.Unlock() @@ -274,15 +278,14 @@ func (w *Websocket) Connect() error { w.dataMonitor() w.trafficMonitor() - w.setConnectingStatus(true) + w.setState(connecting) err := w.connector() if err != nil { - w.setConnectingStatus(false) + w.setState(disconnected) return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } - w.setConnectedStatus(true) - w.setConnectingStatus(false) + w.setState(connected) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() @@ -310,6 +313,7 @@ func (w *Websocket) Connect() error { } // Disable disables the exchange websocket protocol +// Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { if !w.IsEnabled() { return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName) @@ -409,7 +413,7 @@ func (w *Websocket) connectionMonitor() error { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) - w.setConnectedStatus(false) + w.setState(disconnected) } w.DataHandler <- err @@ -474,8 +478,7 @@ func (w *Websocket) Shutdown() error { close(w.ShutdownC) w.Wg.Wait() w.ShutdownC = make(chan struct{}) - w.setConnectedStatus(false) - w.setConnectingStatus(false) + w.setState(disconnected) if w.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } @@ -569,7 +572,7 @@ func (w *Websocket) trafficMonitor() { default: } } - w.setConnectedStatus(true) + w.setState(connected) trafficTimer.Reset(w.trafficTimeout) case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { @@ -610,9 +613,9 @@ func (w *Websocket) trafficMonitor() { }() } -func (w *Websocket) setConnectedStatus(b bool) { +func (w *Websocket) setState(s state) { w.fieldMutex.Lock() - w.state = connected + w.state = s w.fieldMutex.Unlock() } @@ -623,17 +626,11 @@ func (w *Websocket) IsConnected() bool { return w.state == connected } -func (w *Websocket) setConnectingStatus(b bool) { - w.fieldMutex.Lock() - w.state = connecting - w.fieldMutex.Unlock() -} - // IsConnecting returns status of connecting func (w *Websocket) IsConnecting() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() - return w.connecting + return w.state == connecting } func (w *Websocket) setEnabled(b bool) { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 650509b11d1..d3b142fa926 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -125,7 +125,7 @@ func TestSetup(t *testing.T) { t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) } - w.Init = true + w.setState(disconnected) err = w.Setup(websocketSetup) if !errors.Is(err, errExchangeConfigIsNil) { t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil) @@ -214,7 +214,7 @@ func TestSetup(t *testing.T) { func TestTrafficMonitorTimeout(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() if err := ws.Setup(defaultSetup); err != nil { t.Fatal(err) } @@ -232,7 +232,7 @@ func TestTrafficMonitorTimeout(t *testing.T) { t.Fatal("traffic monitor should be running") } // prevent shutdown routine - ws.setConnectedStatus(false) + ws.setState(disconnected) // await timeout closure ws.Wg.Wait() if ws.IsTrafficMonitorRunning() { @@ -284,21 +284,21 @@ func TestConnectionMessageErrors(t *testing.T) { } wsWrong.setEnabled(true) - wsWrong.setConnectingStatus(true) + wsWrong.setState(connecting) wsWrong.Wg = &sync.WaitGroup{} err = wsWrong.Connect() if err == nil { t.Fatal("error cannot be nil") } - wsWrong.setConnectedStatus(false) + wsWrong.setState(disconnected) wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") } err = wsWrong.Connect() if err == nil { t.Fatal("error cannot be nil") } - ws := *New() + ws := NewWebsocket() err = ws.Setup(defaultSetup) if err != nil { t.Fatal(err) @@ -345,31 +345,35 @@ outer: func TestWebsocket(t *testing.T) { t.Parallel() - wsInit := Websocket{} - err := wsInit.Setup(&WebsocketSetup{ - ExchangeConfig: &config.Exchange{ - Features: &config.FeaturesConfig{ - Enabled: config.FeaturesEnabledConfig{Websocket: true}, - }, - Name: "test", - }, - }) - assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "SetProxyAddress should error correctly") - ws := *New() - err = ws.SetProxyAddress("garbagio") + ws := NewWebsocket() + + err := ws.SetProxyAddress("garbagio") assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") ws.Conn = &WebsocketConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) + err = ws.Setup(defaultSetup) // Sets to enabled again + require.NoError(t, err, "Setup may not error") + + err = ws.Setup(defaultSetup) + assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly if called twice") + + assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly") + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") + + ws.setEnabled(false) + assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") + + ws.setEnabled(true) + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") + err = ws.SetProxyAddress("https://192.168.0.1:1337") assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") - ws.setConnectedStatus(true) - ws.ShutdownC = make(chan struct{}) - ws.Wg = &sync.WaitGroup{} + ws.setState(connected) err = ws.SetProxyAddress("https://192.168.0.1:1336") assert.ErrorIs(t, err, errNoConnectFunc, "SetProxyAddress should call Connect and error from there") // This test asserts we actually set the proxy address, etc @@ -386,18 +390,6 @@ func TestWebsocket(t *testing.T) { err = ws.SetProxyAddress("http://localhost:1337") assert.NoError(t, err, "SetProxyAddress should not error") - err = ws.Setup(defaultSetup) // Sets to enabled again - require.NoError(t, err, "Setup may not error") - - assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly") - assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") - - ws.setEnabled(false) - assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") - - ws.setEnabled(true) - assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") - assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") @@ -405,20 +397,20 @@ func TestWebsocket(t *testing.T) { err = ws.Shutdown() assert.ErrorIs(t, err, ErrNotConnected, "Shutdown should error when not Connected") - ws.setConnectedStatus(true) + ws.setState(connected) ws.Conn = &dodgyConnection{} err = ws.Shutdown() assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy conn") ws.Conn = &WebsocketConnection{} - ws.setConnectedStatus(true) + ws.setState(connected) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} - ws.setConnectedStatus(false) + ws.setState(disconnected) err = ws.Connect() assert.NoError(t, err, "Connect should not error") @@ -456,7 +448,7 @@ func TestWebsocket(t *testing.T) { // TestSubscribe logic test func TestSubscribeUnsubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error") fnSub := func(subs []subscription.Subscription) error { @@ -501,7 +493,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { // TestResubscribe tests Resubscribing to existing subscriptions func TestResubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() wackedOutSetup := *defaultSetup wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1 @@ -532,7 +524,7 @@ func TestResubscribe(t *testing.T) { // TestSubscriptionState tests Subscription state changes func TestSubscriptionState(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState} assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error") @@ -558,7 +550,7 @@ func TestSubscriptionState(t *testing.T) { // TestRemoveSubscriptions tests removing a subscription func TestRemoveSubscriptions(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Unite!"} assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") @@ -571,7 +563,7 @@ func TestRemoveSubscriptions(t *testing.T) { // TestConnectionMonitorNoConnection logic test func TestConnectionMonitorNoConnection(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() ws.connectionMonitorDelay = 500 ws.DataHandler = make(chan interface{}, 1) ws.ShutdownC = make(chan struct{}, 1) @@ -626,7 +618,7 @@ func TestGetSubscriptions(t *testing.T) { // TestSetCanUseAuthenticatedEndpoints logic test func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() result := ws.CanUseAuthenticatedEndpoints() if result { t.Error("expected `canUseAuthenticatedEndpoints` to be false") @@ -925,7 +917,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { if resp { t.Error("Expected false, `connected` is false") } - ws.setConnectedStatus(true) + ws.setState(connected) resp = ws.CanUseAuthenticatedWebsocketForWrapper() if resp { t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false") @@ -1109,7 +1101,7 @@ func TestFlushChannels(t *testing.T) { web := Websocket{ enabled: true, - connected: true, + state: connected, connector: connect, ShutdownC: make(chan struct{}), Subscriber: newgen.SUBME, @@ -1204,7 +1196,7 @@ func TestFlushChannels(t *testing.T) { t.Fatal(err) } - web.setConnectedStatus(true) + web.setState(connected) web.features.Unsubscribe = true err = web.FlushChannels() if err != nil { @@ -1216,7 +1208,7 @@ func TestDisable(t *testing.T) { t.Parallel() web := Websocket{ enabled: true, - connected: true, + state: connected, ShutdownC: make(chan struct{}), } err := web.Disable() @@ -1284,7 +1276,7 @@ func TestSetupNewConnection(t *testing.T) { connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), - Init: true, + state: disconnected, TrafficAlert: make(chan struct{}), ReadMessageErrors: make(chan error), DataHandler: make(chan interface{}), diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 9f3ec68a317..7af4ae03a1d 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -25,10 +25,10 @@ const ( type subscriptionMap map[any]*subscription.Subscription -type State int +type state int const ( - uninitialised State = iota + uninitialised state = iota disconnected connecting connected From cf654dd475b6f61e1409c96e4479d9376c0c46b3 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 5 Feb 2024 10:41:22 +0700 Subject: [PATCH 06/35] fixup! Websocket: Simplify Connecting/Connected state --- exchanges/stream/websocket.go | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 428812b0ee1..53b49ab93c3 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -100,7 +100,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } - if w.state != uninitialised { + if w.IsInitialised() { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } @@ -613,20 +613,27 @@ func (w *Websocket) trafficMonitor() { }() } +// IsInitialised returns whether the websocket has been Setup() already +func (w *Websocket) IsInitialised() bool { + w.fieldMutex.RLock() + defer w.fieldMutex.RUnlock() + return w.state != uninitialised +} + func (w *Websocket) setState(s state) { w.fieldMutex.Lock() w.state = s w.fieldMutex.Unlock() } -// IsConnected returns status of connection +// IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() return w.state == connected } -// IsConnecting returns status of connecting +// IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() @@ -639,7 +646,7 @@ func (w *Websocket) setEnabled(b bool) { w.fieldMutex.Unlock() } -// IsEnabled returns status of enabled +// IsEnabled returns whether the websocket is enabled func (w *Websocket) IsEnabled() bool { w.fieldMutex.RLock() defer w.fieldMutex.RUnlock() From 2dc5b2c35dec2b60762a15bdb9ddec6da850987c Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 3 Feb 2024 09:25:36 +0700 Subject: [PATCH 07/35] Websocket: Tests and errors for websocket --- exchanges/stream/websocket.go | 14 ++-- exchanges/stream/websocket_test.go | 126 +++++++++-------------------- 2 files changed, 47 insertions(+), 93 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 53b49ab93c3..4fd0685bf7b 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -65,8 +65,10 @@ var ( errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") errSameProxyAddress = errors.New("cannot set proxy address to the same address") - errNoConnectFunc = errors.New("connect func not set") - errAlreadyConnected = errors.New("already connected") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") ) var globalReporter Reporter @@ -266,10 +268,10 @@ func (w *Websocket) Connect() error { return errors.New(WebsocketNotEnabled) } if w.IsConnecting() { - return fmt.Errorf("%v Websocket already attempting to connect", w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting) } if w.IsConnected() { - return fmt.Errorf("%v Websocket already connected", w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected) } w.subscriptionMutex.Lock() @@ -444,12 +446,12 @@ func (w *Websocket) Shutdown() error { defer w.m.Unlock() if !w.IsConnected() { - return fmt.Errorf("%v websocket: cannot shutdown %w", w.exchangeName, ErrNotConnected) + return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected) } // TODO: Interrupt connection and or close connection when it is re-established. if w.IsConnecting() { - return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", w.exchangeName) + return fmt.Errorf("%v %w: %w ", w.exchangeName, errCannotShutdown, errAlreadyReconnecting) } if w.verbose { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index d3b142fa926..188c8fcb30c 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -32,7 +32,7 @@ const ( ) var ( - errDastardlyReason = errors.New("cannot shutdown due to some dastardly reason") + errDastardlyReason = errors.New("some dastardly reason") ) var dialer websocket.Dialer @@ -73,7 +73,7 @@ var defaultSetup = &WebsocketSetup{ AuthenticatedWebsocketSupport: true, }, WebsocketTrafficTimeout: time.Second * 5, - Name: "exchangeName", + Name: "GTX", }, DefaultURL: "testDefaultURL", RunningURL: "wss://testRunningURL", @@ -97,147 +97,106 @@ type dodgyConnection struct { // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Shutdown() error { - return errDastardlyReason + return fmt.Errorf("%w: %w", errCannotShutdown, errDastardlyReason) } // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Connect() error { - return errDastardlyReason + return fmt.Errorf("cannot connect: %w", errDastardlyReason) } func TestSetup(t *testing.T) { t.Parallel() var w *Websocket err := w.Setup(nil) - if !errors.Is(err, errWebsocketIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketIsNil) - } + assert.ErrorIs(t, err, errWebsocketIsNil, "Setup should error correctly") w = &Websocket{DataHandler: make(chan interface{})} err = w.Setup(nil) - if !errors.Is(err, errWebsocketSetupIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSetupIsNil) - } + assert.ErrorIs(t, err, errWebsocketSetupIsNil, "Setup should error correctly") websocketSetup := &WebsocketSetup{} - err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } - w.setState(disconnected) err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil) - } + assert.ErrorIs(t, err, errExchangeConfigIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigNameUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigNameUnset) - } - websocketSetup.ExchangeConfig.Name = "testname" + assert.ErrorIs(t, err, errExchangeConfigNameUnset, "Setup should error correctly") + websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketFeaturesIsUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketFeaturesIsUnset) - } + assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset, "Setup should error correctly") websocketSetup.Features = &protocol.Features{} err = w.Setup(websocketSetup) - if !errors.Is(err, errConfigFeaturesIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errConfigFeaturesIsNil) - } + assert.ErrorIs(t, err, errConfigFeaturesIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketConnectorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketConnectorUnset) - } + assert.ErrorIs(t, err, errWebsocketConnectorUnset, "Setup should error correctly") websocketSetup.Connector = func() error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly") websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil } websocketSetup.Features.Unsubscribe = true err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketUnsubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketUnsubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly") websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriptionsGeneratorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriptionsGeneratorUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly") websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errDefaultURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errDefaultURLIsEmpty) - } + assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly") websocketSetup.DefaultURL = "test" err = w.Setup(websocketSetup) - if !errors.Is(err, errRunningURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errRunningURLIsEmpty) - } + assert.ErrorIs(t, err, errRunningURLIsEmpty, "Setup should error correctly") websocketSetup.RunningURL = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURL = "wss://www.google.com" websocketSetup.RunningURLAuth = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURLAuth = "wss://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidTrafficTimeout) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidTrafficTimeout) - } + assert.ErrorIs(t, err, errInvalidTrafficTimeout, "Setup should error correctly") websocketSetup.ExchangeConfig.WebsocketTrafficTimeout = time.Minute err = w.Setup(websocketSetup) - if !errors.Is(err, nil) { - t.Fatalf("received: %v but expected: %v", err, nil) - } + assert.NoError(t, err, "Setup should not error") } func TestTrafficMonitorTimeout(t *testing.T) { t.Parallel() ws := NewWebsocket() - if err := ws.Setup(defaultSetup); err != nil { - t.Fatal(err) - } - ws.trafficTimeout = time.Second * 2 + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + ws.trafficTimeout = time.Millisecond * 42 ws.ShutdownC = make(chan struct{}) ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") - } + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + // Deploy traffic alert ws.TrafficAlert <- struct{}{} // try to add another traffic monitor ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") - } + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + // prevent shutdown routine ws.setState(disconnected) // await timeout closure ws.Wg.Wait() - if ws.IsTrafficMonitorRunning() { - t.Error("should be dead") - } + assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be not be running") } func TestIsDisconnectionError(t *testing.T) { @@ -351,7 +310,7 @@ func TestWebsocket(t *testing.T) { err := ws.SetProxyAddress("garbagio") assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") - ws.Conn = &WebsocketConnection{} + ws.Conn = &dodgyConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) @@ -361,7 +320,7 @@ func TestWebsocket(t *testing.T) { err = ws.Setup(defaultSetup) assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly if called twice") - assert.Equal(t, "exchangeName", ws.GetName(), "GetName should return correctly") + assert.Equal(t, "GTX", ws.GetName(), "GetName should return correctly") assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") ws.setEnabled(false) @@ -376,38 +335,31 @@ func TestWebsocket(t *testing.T) { ws.setState(connected) err = ws.SetProxyAddress("https://192.168.0.1:1336") - assert.ErrorIs(t, err, errNoConnectFunc, "SetProxyAddress should call Connect and error from there") // This test asserts we actually set the proxy address, etc + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") err = ws.SetProxyAddress("https://192.168.0.1:1336") assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") - ws.setEnabled(false) // removing proxy err = ws.SetProxyAddress("") - assert.NoError(t, err, "SetProxyAddress should not error when removing proxy") + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Shutdown and error from there") + assert.ErrorIs(t, err, errCannotShutdown, "SetProxyAddress should call Shutdown and error from there") + + ws.Conn = &WebsocketConnection{} + ws.setEnabled(true) // reinstate proxy err = ws.SetProxyAddress("http://localhost:1337") assert.NoError(t, err, "SetProxyAddress should not error") - assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") - err = ws.Shutdown() - assert.ErrorIs(t, err, ErrNotConnected, "Shutdown should error when not Connected") - - ws.setState(connected) - ws.Conn = &dodgyConnection{} - err = ws.Shutdown() - assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy conn") - - ws.Conn = &WebsocketConnection{} - ws.setState(connected) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") + assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} ws.setState(disconnected) From 5c35987b0bcf969ee94b0e9467e4fbae0b31c657 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 3 Feb 2024 16:18:38 +0700 Subject: [PATCH 08/35] Websocket: Make WebsocketNotEnabled a real error This allows for testing and avoids the repetition. If each returned error is a error.New() you can never use errors.Is() --- cmd/exchange_template/wrapper_file.tmpl | 2 +- exchanges/binance/binance_websocket.go | 2 +- exchanges/binanceus/binanceus_websocket.go | 2 +- exchanges/bitfinex/bitfinex_test.go | 2 +- exchanges/bitfinex/bitfinex_websocket.go | 2 +- exchanges/bithumb/bithumb_websocket.go | 3 +- exchanges/bitmex/bitmex_test.go | 2 +- exchanges/bitmex/bitmex_websocket.go | 2 +- exchanges/bitstamp/bitstamp_websocket.go | 2 +- exchanges/btcmarkets/btcmarkets_websocket.go | 2 +- exchanges/btse/btse_websocket.go | 2 +- exchanges/bybit/bybit.go | 2 - exchanges/bybit/bybit_inverse_websocket.go | 2 +- exchanges/bybit/bybit_linear_websocket.go | 2 +- exchanges/bybit/bybit_options_websocket.go | 2 +- exchanges/bybit/bybit_test.go | 7 +- exchanges/bybit/bybit_websocket.go | 2 +- exchanges/coinbasepro/coinbasepro_test.go | 2 +- .../coinbasepro/coinbasepro_websocket.go | 2 +- exchanges/coinut/coinut_test.go | 2 +- exchanges/coinut/coinut_websocket.go | 2 +- exchanges/gateio/gateio_websocket.go | 2 +- .../gateio/gateio_ws_delivery_futures.go | 2 +- exchanges/gateio/gateio_ws_futures.go | 2 +- exchanges/gateio/gateio_ws_option.go | 2 +- exchanges/gemini/gemini_test.go | 2 +- exchanges/gemini/gemini_websocket.go | 2 +- exchanges/hitbtc/hitbtc_test.go | 2 +- exchanges/hitbtc/hitbtc_websocket.go | 2 +- exchanges/huobi/huobi_test.go | 2 +- exchanges/huobi/huobi_websocket.go | 2 +- exchanges/kraken/kraken_test.go | 2 +- exchanges/kraken/kraken_websocket.go | 2 +- exchanges/kucoin/kucoin_websocket.go | 2 +- exchanges/okcoin/okcoin_websocket.go | 2 +- exchanges/okcoin/okcoin_ws_trade.go | 2 +- exchanges/okx/okx_websocket.go | 2 +- exchanges/poloniex/poloniex_test.go | 2 +- exchanges/poloniex/poloniex_websocket.go | 2 +- exchanges/stream/websocket.go | 30 +++--- exchanges/stream/websocket_test.go | 97 +++++-------------- exchanges/stream/websocket_types.go | 2 - 42 files changed, 80 insertions(+), 133 deletions(-) diff --git a/cmd/exchange_template/wrapper_file.tmpl b/cmd/exchange_template/wrapper_file.tmpl index d57f96b5451..e74ecbc320f 100644 --- a/cmd/exchange_template/wrapper_file.tmpl +++ b/cmd/exchange_template/wrapper_file.tmpl @@ -125,7 +125,7 @@ func ({{.Variable}} *{{.CapitalName}}) SetDefaults() { exchange.RestSpot: {{.Name}}APIURL, // exchange.WebsocketSpot: {{.Name}}WSAPIURL, }) - {{.Variable}}.Websocket = stream.New() + {{.Variable}}.Websocket = stream.NewWebsocket() {{.Variable}}.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit {{.Variable}}.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout {{.Variable}}.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 52cdd0cc392..bd96d546f4f 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -50,7 +50,7 @@ var ( // WsConnect initiates a websocket connection func (b *Binance) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 8f4d5c3cd6e..14098c1139b 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect initiates a websocket connection func (bi *Binanceus) WsConnect() error { if !bi.Websocket.IsEnabled() || !bi.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.HandshakeTimeout = bi.Config.HTTPTimeout diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index c2bba90d00b..e7be0fbe707 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1128,7 +1128,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !b.Websocket.IsEnabled() { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) if !b.API.AuthenticatedWebsocketSupport { diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index e1010eb1061..ae7cded7477 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -43,7 +43,7 @@ var cMtx sync.Mutex // WsConnect starts a new websocket connection func (b *Bitfinex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 3005f42359c..667ce131c71 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -2,7 +2,6 @@ package bithumb import ( "encoding/json" - "errors" "fmt" "net/http" "time" @@ -29,7 +28,7 @@ var ( // WsConnect initiates a websocket connection func (b *Bithumb) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/bitmex/bitmex_test.go b/exchanges/bitmex/bitmex_test.go index b59cb125812..e825ad8a07c 100644 --- a/exchanges/bitmex/bitmex_test.go +++ b/exchanges/bitmex/bitmex_test.go @@ -789,7 +789,7 @@ func TestGetDepositAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !b.Websocket.IsEnabled() && !b.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(b) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 6d04c106039..e1a475254f6 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -68,7 +68,7 @@ const ( // WsConnect initiates a new websocket connection func (b *Bitmex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index ab5465b574d..98aa6201df4 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect connects to a websocket feed func (b *Bitstamp) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 8c979b4d79b..01ba1a64d4c 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -39,7 +39,7 @@ var ( // WsConnect connects to a websocket feed func (b *BTCMarkets) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 41f25c95e35..e32fb9a8095 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -30,7 +30,7 @@ const ( // WsConnect connects the websocket client func (b *BTSE) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bybit/bybit.go b/exchanges/bybit/bybit.go index 5652c464057..3588336ff52 100644 --- a/exchanges/bybit/bybit.go +++ b/exchanges/bybit/bybit.go @@ -21,7 +21,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) // Bybit is the overarching type across this package @@ -90,7 +89,6 @@ var ( errAPIKeyIsNotUnified = errors.New("api key is not unified") errEndpointAvailableForNormalAPIKeyHolders = errors.New("endpoint available for normal API key holders only") errInvalidContractLength = errors.New("contract length cannot be less than or equal to zero") - errWebsocketNotEnabled = errors.New(stream.WebsocketNotEnabled) ) var ( diff --git a/exchanges/bybit/bybit_inverse_websocket.go b/exchanges/bybit/bybit_inverse_websocket.go index 77f387ace60..d1387c277f8 100644 --- a/exchanges/bybit/bybit_inverse_websocket.go +++ b/exchanges/bybit/bybit_inverse_websocket.go @@ -12,7 +12,7 @@ import ( // WsInverseConnect connects to inverse websocket feed func (by *Bybit) WsInverseConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.CoinMarginedFutures) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(inversePublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_linear_websocket.go b/exchanges/bybit/bybit_linear_websocket.go index efc2f68d1b8..9b3ed08426a 100644 --- a/exchanges/bybit/bybit_linear_websocket.go +++ b/exchanges/bybit/bybit_linear_websocket.go @@ -14,7 +14,7 @@ import ( // WsLinearConnect connects to linear a websocket feed func (by *Bybit) WsLinearConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.LinearContract) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(linearPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_options_websocket.go b/exchanges/bybit/bybit_options_websocket.go index 2f4abc7a76d..4bb25cef2a9 100644 --- a/exchanges/bybit/bybit_options_websocket.go +++ b/exchanges/bybit/bybit_options_websocket.go @@ -14,7 +14,7 @@ import ( // WsOptionsConnect connects to options a websocket feed func (by *Bybit) WsOptionsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Options) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(optionPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 58ca5bc46e8..092ca125462 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -20,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/margin" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -3064,7 +3065,7 @@ func TestWsLinearConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsLinearConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3074,7 +3075,7 @@ func TestWsInverseConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsInverseConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3084,7 +3085,7 @@ func TestWsOptionsConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsOptionsConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 857fd690afe..ff2698b8667 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -57,7 +57,7 @@ const ( // WsConnect connects to a websocket feed func (by *Bybit) WsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinbasepro/coinbasepro_test.go b/exchanges/coinbasepro/coinbasepro_test.go index 77f058384d1..0b04d4ab64b 100644 --- a/exchanges/coinbasepro/coinbasepro_test.go +++ b/exchanges/coinbasepro/coinbasepro_test.go @@ -681,7 +681,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index e4b02b764d8..5946cf778b5 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -31,7 +31,7 @@ const ( // WsConnect initiates a websocket connection func (c *CoinbasePro) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/coinut/coinut_test.go b/exchanges/coinut/coinut_test.go index 3431d60f730..0b165327846 100644 --- a/exchanges/coinut/coinut_test.go +++ b/exchanges/coinut/coinut_test.go @@ -66,7 +66,7 @@ func setupWSTestAuth(t *testing.T) { } if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } if sharedtestvalues.AreAPICredentialsSet(c) { c.Websocket.SetCanUseAuthenticatedEndpoints(true) diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 2453816e9ce..78b389879ca 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -41,7 +41,7 @@ var ( // WsConnect initiates a websocket connection func (c *COINUT) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index c26d04afe7c..3a3dddecf88 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -60,7 +60,7 @@ var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsConnect initiates a websocket connection func (g *Gateio) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Spot) if err != nil { diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index ba5c64afe3d..449181c1007 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -45,7 +45,7 @@ var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account func (g *Gateio) WsDeliveryFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.DeliveryFutures) if err != nil { diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index c0411a5816c..20e293b93af 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -64,7 +64,7 @@ var responseFuturesStream = make(chan stream.Response) // WsFuturesConnect initiates a websocket connection for futures account func (g *Gateio) WsFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) if err != nil { diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index d5340f0c350..3278914f21f 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -70,7 +70,7 @@ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. func (g *Gateio) WsOptionsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Options) if err != nil { diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index 6a3b35e1885..7477e78ff11 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -556,7 +556,7 @@ func TestWsAuth(t *testing.T) { if !g.Websocket.IsEnabled() && !g.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(g) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer go g.wsReadData() diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index 913c856ddd4..43c2135e021 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -39,7 +39,7 @@ var comms = make(chan stream.Response) // WsConnect initiates a websocket connection func (g *Gemini) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/hitbtc/hitbtc_test.go b/exchanges/hitbtc/hitbtc_test.go index 3e629d0a654..68b8495167b 100644 --- a/exchanges/hitbtc/hitbtc_test.go +++ b/exchanges/hitbtc/hitbtc_test.go @@ -466,7 +466,7 @@ func setupWsAuth(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index 705f584f01c..deb885424bf 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -34,7 +34,7 @@ const ( // WsConnect starts a new connection with the websocket API func (h *HitBTC) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/huobi/huobi_test.go b/exchanges/huobi/huobi_test.go index 4aafc7b74a5..16106daf566 100644 --- a/exchanges/huobi/huobi_test.go +++ b/exchanges/huobi/huobi_test.go @@ -78,7 +78,7 @@ func setupWsTests(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } comms = make(chan WsMessage, sharedtestvalues.WebsocketChannelOverrideCapacity) go h.wsReadData() diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index f601b03dbc2..92b5bf4c22e 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -62,7 +62,7 @@ var comms = make(chan WsMessage) // WsConnect initiates a new websocket connection func (h *HUOBI) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.wsDial(&dialer) diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index 263d70bf571..8530cb53b8d 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1215,7 +1215,7 @@ func setupWsTests(t *testing.T) { return } if !k.Websocket.IsEnabled() && !k.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(k) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := k.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 787d52b2c31..fd4325164e2 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -87,7 +87,7 @@ var cancelOrdersStatus = make(map[int64]*struct { // WsConnect initiates a websocket connection func (k *Kraken) WsConnect() error { if !k.Websocket.IsEnabled() || !k.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 3f8a0c783ca..917bc1a90e3 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -97,7 +97,7 @@ var ( // WsConnect creates a new websocket connection. func (ku *Kucoin) WsConnect() error { if !ku.Websocket.IsEnabled() || !ku.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } fetchedFuturesSnapshotOrderbook = map[string]bool{} var dialer websocket.Dialer diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index e714aba7d10..0787775e2f9 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -74,7 +74,7 @@ func isAuthenticatedChannel(channel string) bool { // WsConnect initiates a websocket connection func (o *Okcoin) WsConnect() error { if !o.Websocket.IsEnabled() || !o.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/okcoin/okcoin_ws_trade.go b/exchanges/okcoin/okcoin_ws_trade.go index cd0c7f86605..85b6102e24d 100644 --- a/exchanges/okcoin/okcoin_ws_trade.go +++ b/exchanges/okcoin/okcoin_ws_trade.go @@ -130,7 +130,7 @@ func (o *Okcoin) WsAmendMultipleOrder(args []AmendTradeOrderRequestParam) ([]Ame func (o *Okcoin) SendWebsocketRequest(operation string, data, result interface{}, authenticated bool) error { switch { case !o.Websocket.IsEnabled(): - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled case !o.Websocket.IsConnected(): return stream.ErrNotConnected case !o.Websocket.CanUseAuthenticatedEndpoints() && authenticated: diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index b4d211eec01..54bc6b0e486 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -216,7 +216,7 @@ const ( // WsConnect initiates a websocket connection func (ok *Okx) WsConnect() error { if !ok.Websocket.IsEnabled() || !ok.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/poloniex/poloniex_test.go b/exchanges/poloniex/poloniex_test.go index bab0ffd4a71..d74e2353d67 100644 --- a/exchanges/poloniex/poloniex_test.go +++ b/exchanges/poloniex/poloniex_test.go @@ -548,7 +548,7 @@ func TestGenerateNewAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !p.Websocket.IsEnabled() && !p.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(p) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 1be429a58ef..23774335fc8 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -55,7 +55,7 @@ var ( // WsConnect initiates a websocket connection func (p *Poloniex) WsConnect() error { if !p.Websocket.IsEnabled() || !p.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 4fd0685bf7b..4e73579a960 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -24,24 +24,20 @@ const ( defaultTrafficPeriod = time.Second ) +// Public websocket errors var ( - // ErrSubscriptionNotFound defines an error when a subscription is not found - ErrSubscriptionNotFound = errors.New("subscription not found") - // ErrSubscribedAlready defines an error when a channel is already subscribed - ErrSubscribedAlready = errors.New("duplicate subscription") - // ErrSubscriptionFailure defines an error when a subscription fails - ErrSubscriptionFailure = errors.New("subscription failure") - // ErrSubscriptionNotSupported defines an error when a subscription channel is not supported by an exchange + ErrWebsocketNotEnabled = errors.New("websocket not enabled") + ErrSubscriptionNotFound = errors.New("subscription not found") + ErrSubscribedAlready = errors.New("duplicate subscription") + ErrSubscriptionFailure = errors.New("subscription failure") ErrSubscriptionNotSupported = errors.New("subscription channel not supported ") - // ErrUnsubscribeFailure defines an error when a unsubscribe fails - ErrUnsubscribeFailure = errors.New("unsubscribe failure") - // ErrChannelInStateAlready defines an error when a subscription channel is already in a new state - ErrChannelInStateAlready = errors.New("channel already in state") - // ErrAlreadyDisabled is returned when you double-disable the websocket - ErrAlreadyDisabled = errors.New("websocket already disabled") - // ErrNotConnected defines an error when websocket is not connected - ErrNotConnected = errors.New("websocket is not connected") + ErrUnsubscribeFailure = errors.New("unsubscribe failure") + ErrChannelInStateAlready = errors.New("channel already in state") + ErrAlreadyDisabled = errors.New("websocket already disabled") + ErrNotConnected = errors.New("websocket is not connected") +) +var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") errWebsocketIsNil = errors.New("websocket is nil") @@ -265,7 +261,7 @@ func (w *Websocket) Connect() error { defer w.m.Unlock() if !w.IsEnabled() { - return errors.New(WebsocketNotEnabled) + return ErrWebsocketNotEnabled } if w.IsConnecting() { return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting) @@ -490,7 +486,7 @@ func (w *Websocket) Shutdown() error { // FlushChannels flushes channel subscriptions when there is a pair/asset change func (w *Websocket) FlushChannels() error { if !w.IsEnabled() { - return fmt.Errorf("%s websocket: service not enabled", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrWebsocketNotEnabled) } if !w.IsConnected() { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 188c8fcb30c..3ae1f388157 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -201,105 +201,62 @@ func TestTrafficMonitorTimeout(t *testing.T) { func TestIsDisconnectionError(t *testing.T) { t.Parallel() - isADisconnectionError := IsDisconnectionError(errors.New("errorText")) - if isADisconnectionError { - t.Error("Its not") - } - isADisconnectionError = IsDisconnectionError(&websocket.CloseError{ - Code: 1006, - Text: "errorText", - }) - if !isADisconnectionError { - t.Error("It is") - } - - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errClosedConnection, - }) - if isADisconnectionError { - t.Error("It's not") - } - - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errors.New("errText"), - }) - if !isADisconnectionError { - t.Error("It is") - } + assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsADisconnectionError should return false") + assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsADisconnectionError should return true") + assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsADisconnectionError should return false") + assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsADisconnectionError should return true") } func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} err := wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errNoConnectFunc, "Connect should error correctly") wsWrong.connector = func() error { return nil } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) wsWrong.setState(connecting) wsWrong.Wg = &sync.WaitGroup{} err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") wsWrong.setState(disconnected) - wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") } + wsWrong.connector = func() error { return errDastardlyReason } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") ws := NewWebsocket() err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Setup must not error") ws.trafficTimeout = time.Minute ws.connector = func() error { return nil } err = ws.Connect() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Connect must not error") ws.TrafficAlert <- struct{}{} - timer := time.NewTimer(900 * time.Millisecond) - ws.ReadMessageErrors <- errors.New("errorText") - select { - case err := <-ws.ToRoutine: - errText, ok := err.(error) - if !ok { - t.Error("unable to type assert error") - } else if errText.Error() != "errorText" { - t.Errorf("Expected 'errorText', received %v", err) - } - case <-timer.C: - t.Error("Timeout waiting for datahandler to receive error") - } - ws.ReadMessageErrors <- &websocket.CloseError{ - Code: 1006, - Text: "errorText", - } -outer: - for { + c := func(ta *assert.CollectT) { select { - case err := <-ws.ToRoutine: - if _, ok := err.(*websocket.CloseError); !ok { - t.Errorf("Error is not a disconnection error: %v", err) + case v := <-ws.ToRoutine: + switch err := v.(type) { + case *websocket.CloseError: + assert.Equal(t, "SpecialText", err.Text, "Should get correct Close Error") + case error: + assert.ErrorIs(t, err, errDastardlyReason, "Should get the correct error") } - case <-timer.C: - break outer + default: } } + + ws.ReadMessageErrors <- errDastardlyReason + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") + + ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") } func TestWebsocket(t *testing.T) { @@ -1041,9 +998,7 @@ func TestFlushChannels(t *testing.T) { dodgyWs := Websocket{} err := dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "FlushChannels should error correctly") dodgyWs.setEnabled(true) err = dodgyWs.FlushChannels() diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 7af4ae03a1d..31e1db58b09 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -15,8 +15,6 @@ import ( // Websocket functionality list and state consts const ( - // WebsocketNotEnabled alerts of a disabled websocket - WebsocketNotEnabled = "exchange_websocket_not_enabled" WebsocketNotAuthenticatedUsingRest = "%v - Websocket not authenticated, using REST\n" Ping = "ping" Pong = "pong" From 1ec16a569108cd52a36822aa59c0c2b0763d9f6e Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 5 Feb 2024 11:07:22 +0700 Subject: [PATCH 09/35] fixup! Websocket: Make WebsocketNotEnabled a real error --- exchanges/stream/websocket_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3ae1f388157..521fccb8e77 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -201,10 +201,10 @@ func TestTrafficMonitorTimeout(t *testing.T) { func TestIsDisconnectionError(t *testing.T) { t.Parallel() - assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsADisconnectionError should return false") - assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsADisconnectionError should return true") - assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsADisconnectionError should return false") - assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsADisconnectionError should return true") + assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsDisconnectionError should return true") + assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsDisconnectionError should return true") } func TestConnectionMessageErrors(t *testing.T) { From 208ca57a628886da14fed96ce9eea499e219dcdd Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Fri, 16 Feb 2024 10:36:07 +0700 Subject: [PATCH 10/35] fixup! Websocket: Make WebsocketNotEnabled a real error --- exchanges/stream/websocket_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 521fccb8e77..8459a5461a9 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -239,14 +239,14 @@ func TestConnectionMessageErrors(t *testing.T) { ws.TrafficAlert <- struct{}{} - c := func(ta *assert.CollectT) { + c := func(tb *assert.CollectT) { select { case v := <-ws.ToRoutine: switch err := v.(type) { case *websocket.CloseError: - assert.Equal(t, "SpecialText", err.Text, "Should get correct Close Error") + assert.Equal(tb, "SpecialText", err.Text, "Should get correct Close Error") case error: - assert.ErrorIs(t, err, errDastardlyReason, "Should get the correct error") + assert.ErrorIs(tb, err, errDastardlyReason, "Should get the correct error") } default: } From f35de757a1adde68da280d1184bdf0b96d4f01ab Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 3 Feb 2024 17:39:35 +0700 Subject: [PATCH 11/35] Websocket: Add more testable errors --- exchanges/stream/websocket.go | 26 ++- exchanges/stream/websocket_connection.go | 21 +- exchanges/stream/websocket_test.go | 269 +++++++---------------- 3 files changed, 94 insertions(+), 222 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 4e73579a960..9e527622b3a 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -40,19 +40,23 @@ var ( var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") + errExchangeConfigEmpty = errors.New("exchange config is empty") errWebsocketIsNil = errors.New("websocket is nil") errWebsocketSetupIsNil = errors.New("websocket setup is nil") errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") errConfigFeaturesIsNil = errors.New("exchange config features is nil") errDefaultURLIsEmpty = errors.New("default url is empty") errRunningURLIsEmpty = errors.New("running url cannot be empty") errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameUnset = errors.New("exchange config name unset") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") errClosedConnection = errors.New("use of closed network connection") errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") @@ -65,6 +69,7 @@ var ( errAlreadyConnected = errors.New("websocket already connected") errCannotShutdown = errors.New("websocket cannot shutdown") errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") ) var globalReporter Reporter @@ -107,7 +112,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } if s.ExchangeConfig.Name == "" { - return errExchangeConfigNameUnset + return errExchangeConfigNameEmpty } w.exchangeName = s.ExchangeConfig.Name w.verbose = s.ExchangeConfig.Verbose @@ -196,22 +201,22 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { // SetupNewConnection sets up an auth or unauth streaming connection func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if w == nil { - return errors.New("setting up new connection error: websocket is nil") + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } if c == (ConnectionSetup{}) { - return errors.New("setting up new connection error: websocket connection configuration empty") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) } if w.exchangeName == "" { - return errors.New("setting up new connection error: exchange name not set, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) } if w.TrafficAlert == nil { - return errors.New("setting up new connection error: traffic alert is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) } if w.ReadMessageErrors == nil { - return errors.New("setting up new connection error: read message errors is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) } connectionURL := w.GetWebsocketURL() @@ -314,7 +319,7 @@ func (w *Websocket) Connect() error { // Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { if !w.IsEnabled() { - return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrAlreadyDisabled) } w.setEnabled(false) @@ -324,8 +329,7 @@ func (w *Websocket) Disable() error { // Enable enables the exchange websocket protocol func (w *Websocket) Enable() error { if w.IsConnected() || w.IsEnabled() { - return fmt.Errorf("websocket is already enabled for exchange %s", - w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyEnabled) } w.setEnabled(true) @@ -490,7 +494,7 @@ func (w *Websocket) FlushChannels() error { } if !w.IsConnected() { - return fmt.Errorf("%s websocket: service not connected", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) } if w.features.Subscribe { diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 0bb1e660412..009b9e37459 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -50,9 +50,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(signature, request inter return payload, nil case <-timer.C: timer.Stop() - return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", - w.ExchangeName, - signature) + return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature) } } @@ -72,25 +70,14 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header w.Connection, conStatus, err = dialer.Dial(w.URL, headers) if err != nil { if conStatus != nil { - return fmt.Errorf("%s websocket connection: %v %v %v Error: %v", - w.ExchangeName, - w.URL, - conStatus, - conStatus.StatusCode, - err) + return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) } - return fmt.Errorf("%s websocket connection: %v Error: %v", - w.ExchangeName, - w.URL, - err) + return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } defer conStatus.Body.Close() if w.Verbose { - log.Infof(log.WebsocketMgr, - "%v Websocket connected to %s\n", - w.ExchangeName, - w.URL) + log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) } select { case w.Traffic <- struct{}{}: diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 8459a5461a9..653abbf5f5c 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -122,7 +122,7 @@ func TestSetup(t *testing.T) { websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - assert.ErrorIs(t, err, errExchangeConfigNameUnset, "Setup should error correctly") + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "Setup should error correctly") websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) @@ -480,16 +480,10 @@ func TestConnectionMonitorNoConnection(t *testing.T) { ws.Wg = &sync.WaitGroup{} ws.enabled = true err := ws.connectionMonitor() - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } - if !ws.IsConnectionMonitorRunning() { - t.Fatal("Should not have exited") - } + require.NoError(t, err, "connectionMonitor must not error") + assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") err = ws.connectionMonitor() - if !errors.Is(err, errAlreadyRunning) { - t.Fatalf("received: %v, but expected: %v", err, errAlreadyRunning) - } + assert.ErrorIs(t, err, errAlreadyRunning, "connectionMonitor should error correctly") } // TestGetSubscription logic test @@ -528,15 +522,9 @@ func TestGetSubscriptions(t *testing.T) { func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { t.Parallel() ws := NewWebsocket() - result := ws.CanUseAuthenticatedEndpoints() - if result { - t.Error("expected `canUseAuthenticatedEndpoints` to be false") - } + assert.False(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return false") ws.SetCanUseAuthenticatedEndpoints(true) - result = ws.CanUseAuthenticatedEndpoints() - if !result { - t.Error("expected `canUseAuthenticatedEndpoints` to be true") - } + assert.True(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return true") } // TestDial logic test @@ -773,69 +761,41 @@ func TestParseBinaryResponse(t *testing.T) { } var b bytes.Buffer - w := gzip.NewWriter(&b) - _, err := w.Write([]byte("hello")) - if err != nil { - t.Error(err) - } - err = w.Close() - if err != nil { - t.Error(err) - } - var resp []byte - resp, err = wc.parseBinaryResponse(b.Bytes()) - if err != nil { - t.Error(err) - } - if !strings.EqualFold(string(resp), "hello") { - t.Errorf("GZip conversion failed. Received: '%v', Expected: 'hello'", string(resp)) - } + g := gzip.NewWriter(&b) + _, err := g.Write([]byte("hello")) + require.NoError(t, err, "gzip.Write must not error") + assert.NoError(t, g.Close(), "Close should not error") + + resp, err := wc.parseBinaryResponse(b.Bytes()) + assert.NoError(t, err, "parseBinaryResponse should not error parsing gzip") + assert.EqualValues(t, "hello", resp, "parseBinaryResponse should decode gzip") + + b.Reset() + f, err := flate.NewWriter(&b, 1) + require.NoError(t, err, "flate.NewWriter must not error") + _, err = f.Write([]byte("goodbye")) + require.NoError(t, err, "flate.Write must not error") + assert.NoError(t, f.Close(), "Close should not error") - var b2 bytes.Buffer - w2, err2 := flate.NewWriter(&b2, 1) - if err2 != nil { - t.Error(err2) - } - _, err2 = w2.Write([]byte("hello")) - if err2 != nil { - t.Error(err) - } - err2 = w2.Close() - if err2 != nil { - t.Error(err) - } - resp2, err3 := wc.parseBinaryResponse(b2.Bytes()) - if err3 != nil { - t.Error(err3) - } - if !strings.EqualFold(string(resp2), "hello") { - t.Errorf("Deflate conversion failed. Received: '%v', Expected: 'hello'", string(resp2)) - } + resp, err = wc.parseBinaryResponse(b.Bytes()) + assert.NoError(t, err, "parseBinaryResponse should not error parsing inflate") + assert.EqualValues(t, "goodbye", resp, "parseBinaryResponse should deflate") - _, err4 := wc.parseBinaryResponse([]byte{}) - if err4 == nil || err4.Error() != "unexpected EOF" { - t.Error("Expected error 'unexpected EOF'") - } + _, err = wc.parseBinaryResponse([]byte{}) + assert.ErrorContains(t, err, "unexpected EOF", "parseBinaryResponse should error on empty input") } // TestCanUseAuthenticatedWebsocketForWrapper logic test func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { t.Parallel() ws := &Websocket{} - resp := ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is false") - } + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + ws.setState(connected) - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false") - } + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + ws.canUseAuthenticatedEndpoints = true - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if !resp { - t.Error("Expected true, `connected` and `CanUseAuthenticatedEndpoints` is true") - } + assert.True(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return true") } func TestGenerateMessageID(t *testing.T) { @@ -869,34 +829,22 @@ func BenchmarkGenerateMessageID_Low(b *testing.B) { func TestCheckWebsocketURL(t *testing.T) { err := checkWebsocketURL("") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on empty string") err = checkWebsocketURL("wowowow:wowowowo") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on bad format") err = checkWebsocketURL("://") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "missing protocol scheme", "checkWebsocketURL should error correctly on bad proto") err = checkWebsocketURL("http://www.google.com") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on wrong proto") err = checkWebsocketURL("wss://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") err = checkWebsocketURL("ws://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") } func TestGetChannelDifference(t *testing.T) { @@ -1002,9 +950,7 @@ func TestFlushChannels(t *testing.T) { dodgyWs.setEnabled(true) err = dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") web := Websocket{ enabled: true, @@ -1023,7 +969,7 @@ func TestFlushChannels(t *testing.T) { } problemFunc := func() ([]subscription.Subscription, error) { - return nil, errors.New("problems") + return nil, errDastardlyReason } noSub := func() ([]subscription.Subscription, error) { @@ -1037,47 +983,34 @@ func TestFlushChannels(t *testing.T) { return []subscription.Subscription{{Channel: "test"}}, nil } err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") web.features.FullPayloadSubscribe = true web.GenerateSubs = problemFunc err = web.FlushChannels() // error on full subscribeToChannels - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") web.GenerateSubs = noSub - err = web.FlushChannels() // No subs to sub - if err != nil { - t.Fatal(err) - } + err = web.FlushChannels() // No subs to unsub + assert.NoError(t, err, "FlushChannels should not error") web.GenerateSubs = newgen.generateSubs subs, err := web.GenerateSubs() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "GenerateSubs must not error") + web.AddSuccessfulSubscriptions(subs...) err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") web.features.FullPayloadSubscribe = false web.features.Subscribe = true web.GenerateSubs = problemFunc err = web.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") web.GenerateSubs = newgen.generateSubs err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") web.subscriptionMutex.Lock() web.subscriptions = subscriptionMap{ 41: { @@ -1094,21 +1027,15 @@ func TestFlushChannels(t *testing.T) { web.subscriptionMutex.Unlock() err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") web.setState(connected) web.features.Unsubscribe = true err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "FlushChannels should not error") } func TestDisable(t *testing.T) { @@ -1118,14 +1045,8 @@ func TestDisable(t *testing.T) { state: connected, ShutdownC: make(chan struct{}), } - err := web.Disable() - if err != nil { - t.Fatal(err) - } - err = web.Disable() - if err == nil { - t.Fatal("should already be disabled") - } + require.NoError(t, web.Disable(), "Disable must not error") + assert.ErrorIs(t, web.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } func TestEnable(t *testing.T) { @@ -1140,98 +1061,66 @@ func TestEnable(t *testing.T) { Subscriber: func([]subscription.Subscription) error { return nil }, } - err := web.Enable() - if err != nil { - t.Fatal(err) - } - - err = web.Enable() - if err == nil { - t.Fatal("should already be enabled") - } - - fmt.Print() + require.NoError(t, web.Enable(), "Enable must not error") + assert.ErrorIs(t, web.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { t.Parallel() var nonsenseWebsock *Websocket err := nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errWebsocketIsNil, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{exchangeName: "test"} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") nonsenseWebsock.TrafficAlert = make(chan struct{}) err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") web := Websocket{ connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), - state: disconnected, TrafficAlert: make(chan struct{}), ReadMessageErrors: make(chan error), DataHandler: make(chan interface{}), } err = web.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Setup should not error") + err = web.SetupNewConnection(ConnectionSetup{}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigEmpty, "SetupNewConnection should error correctly") + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err != nil { - t.Fatal(err) - } - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", - Authenticated: true}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetupNewConnection should not error") + + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", Authenticated: true}) + assert.NoError(t, err, "SetupNewConnection should not error") } func TestWebsocketConnectionShutdown(t *testing.T) { t.Parallel() wc := WebsocketConnection{} err := wc.Shutdown() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Shutdown should not error") err = wc.Dial(&websocket.Dialer{}, nil) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial must error correctly") wc.URL = websocketTestURL err = wc.Dial(&websocket.Dialer{}, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Dial must not error") err = wc.Shutdown() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Shutdown must not error") } // TestLatency logic test @@ -1285,27 +1174,19 @@ func TestCheckSubscriptions(t *testing.T) { t.Parallel() ws := Websocket{} err := ws.checkSubscriptions(nil) - if !errors.Is(err, errNoSubscriptionsSupplied) { - t.Fatalf("received: %v, but expected: %v", err, errNoSubscriptionsSupplied) - } + assert.ErrorIs(t, err, errNoSubscriptionsSupplied, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 1 err = ws.checkSubscriptions([]subscription.Subscription{{}, {}}) - if !errors.Is(err, errSubscriptionsExceedsLimit) { - t.Fatalf("received: %v, but expected: %v", err, errSubscriptionsExceedsLimit) - } + assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}} err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}}) - if !errors.Is(err, errChannelAlreadySubscribed) { - t.Fatalf("received: %v, but expected: %v", err, errChannelAlreadySubscribed) - } + assert.ErrorIs(t, err, errChannelAlreadySubscribed, "checkSubscriptions should error correctly") err = ws.checkSubscriptions([]subscription.Subscription{{}}) - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } + assert.NoError(t, err, "checkSubscriptions should not error") } From 78a7e833da14048c4504a3a927d3c9b1e863856a Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sun, 4 Feb 2024 09:30:32 +0700 Subject: [PATCH 12/35] Websocket: Improve GenerateMessageID test Testing just the last id doesn't feel very robust --- exchanges/stream/websocket_connection.go | 2 +- exchanges/stream/websocket_test.go | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 009b9e37459..910142caa9a 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -272,7 +272,7 @@ func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) { return standardMessage, reader.Close() } -// GenerateMessageID Creates a messageID to checkout +// GenerateMessageID Creates a random message ID func (w *WebsocketConnection) GenerateMessageID(highPrec bool) int64 { var min int64 = 1e8 var max int64 = 2e8 diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 653abbf5f5c..bdb0478606d 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -801,13 +801,12 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { func TestGenerateMessageID(t *testing.T) { t.Parallel() wc := WebsocketConnection{} - var id int64 - for i := 0; i < 10; i++ { - newID := wc.GenerateMessageID(true) - if id == newID { - t.Fatal("ID generation is not unique") - } - id = newID + const spins = 1000 + ids := make([]int64, spins) + for i := 0; i < spins; i++ { + id := wc.GenerateMessageID(true) + assert.NotContains(t, ids, id, "GenerateMessageID must not generate the same ID twice") + ids[i] = id } } From ebaa55d63833b4da6fd6841445a98fb60db64575 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 5 Feb 2024 13:37:41 +0700 Subject: [PATCH 13/35] Websocket: Protect Setup() from races --- exchanges/stream/websocket.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 9e527622b3a..af882464604 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -103,6 +103,9 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } + w.m.Lock() + defer w.m.Unlock() + if w.IsInitialised() { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } From 3e19e83e3d553664981d142ad2b541c19a58e65e Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 15 Feb 2024 19:39:03 +0700 Subject: [PATCH 14/35] Websocket: Use atomics instead of mutex This was spurred by looking at the setState call in trafficMonitor and the effect on blocking and efficiency. With the new atomic types in Go 1.19, and the small types in use here, atomics should be safe for our usage. bools should be truly atomic, and uint32 is atomic when the accepted value range is less than one byte/uint8 since that can be written atomicly by concurrent processors. Maybe that's not even a factor any more, however we don't even have to worry enough to check. --- exchanges/stream/websocket.go | 93 +++++++++-------------------- exchanges/stream/websocket_test.go | 77 ++++++++++++------------ exchanges/stream/websocket_types.go | 18 +++--- 3 files changed, 74 insertions(+), 114 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index af882464604..409e0953251 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -128,7 +128,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { if s.ExchangeConfig.Features == nil { return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil) } - w.enabled = s.ExchangeConfig.Features.Enabled.Websocket + w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) if s.Connector == nil { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset) @@ -388,9 +388,7 @@ func (w *Websocket) connectionMonitor() error { if w.checkAndSetMonitorRunning() { return errAlreadyRunning } - w.fieldMutex.RLock() delay := w.connectionMonitorDelay - w.fieldMutex.RUnlock() go func() { timer := time.NewTimer(delay) @@ -618,104 +616,73 @@ func (w *Websocket) trafficMonitor() { }() } -// IsInitialised returns whether the websocket has been Setup() already -func (w *Websocket) IsInitialised() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state != uninitialised +func (w *Websocket) setState(s uint32) { + w.state.Store(s) } -func (w *Websocket) setState(s state) { - w.fieldMutex.Lock() - w.state = s - w.fieldMutex.Unlock() +// IsInitialised returns whether the websocket has been Setup() already +func (w *Websocket) IsInitialised() bool { + return w.state.Load() != uninitialised } // IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state == connected + return w.state.Load() == connected } // IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.state == connecting + return w.state.Load() == connecting } func (w *Websocket) setEnabled(b bool) { - w.fieldMutex.Lock() - w.enabled = b - w.fieldMutex.Unlock() + w.enabled.Store(b) } // IsEnabled returns whether the websocket is enabled func (w *Websocket) IsEnabled() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.enabled + return w.enabled.Load() } func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.trafficMonitorRunning = b - w.fieldMutex.Unlock() + w.trafficMonitorRunning.Store(b) } // IsTrafficMonitorRunning returns status of the traffic monitor func (w *Websocket) IsTrafficMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.trafficMonitorRunning + return w.trafficMonitorRunning.Load() } func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - if w.connectionMonitorRunning { - return true - } - w.connectionMonitorRunning = true - return false + return !w.connectionMonitorRunning.CompareAndSwap(false, true) } func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.connectionMonitorRunning = b - w.fieldMutex.Unlock() + w.connectionMonitorRunning.Store(b) } // IsConnectionMonitorRunning returns status of connection monitor func (w *Websocket) IsConnectionMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connectionMonitorRunning + return w.connectionMonitorRunning.Load() } func (w *Websocket) setDataMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.dataMonitorRunning = b - w.fieldMutex.Unlock() + w.dataMonitorRunning.Store(b) } // IsDataMonitorRunning returns status of data monitor func (w *Websocket) IsDataMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.dataMonitorRunning + return w.dataMonitorRunning.Load() } // CanUseAuthenticatedWebsocketForWrapper Handles a common check to // verify whether a wrapper can use an authenticated websocket endpoint func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { - if w.IsConnected() && w.CanUseAuthenticatedEndpoints() { - return true - } else if w.IsConnected() && !w.CanUseAuthenticatedEndpoints() { - log.Infof(log.WebsocketMgr, - WebsocketNotAuthenticatedUsingRest, - w.exchangeName) + if w.IsConnected() { + if w.CanUseAuthenticatedEndpoints() { + return true + } + log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName) } return false } @@ -994,20 +961,14 @@ func (w *Websocket) GetSubscriptions() []subscription.Subscription { return subs } -// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in -// a thread safe manner -func (w *Websocket) SetCanUseAuthenticatedEndpoints(val bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - w.canUseAuthenticatedEndpoints = val +// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner +func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) { + w.canUseAuthenticatedEndpoints.Store(b) } -// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in -// a thread safe manner +// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner func (w *Websocket) CanUseAuthenticatedEndpoints() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.canUseAuthenticatedEndpoints + return w.canUseAuthenticatedEndpoints.Load() } // IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index bdb0478606d..4725cc9b064 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -478,7 +478,7 @@ func TestConnectionMonitorNoConnection(t *testing.T) { ws.ShutdownC = make(chan struct{}, 1) ws.exchangeName = "hello" ws.Wg = &sync.WaitGroup{} - ws.enabled = true + ws.setEnabled(true) err := ws.connectionMonitor() require.NoError(t, err, "connectionMonitor must not error") assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") @@ -792,9 +792,10 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") ws.setState(connected) + require.True(t, ws.IsConnected(), "IsConnected must return true") assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") - ws.canUseAuthenticatedEndpoints = true + ws.SetCanUseAuthenticatedEndpoints(true) assert.True(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return true") } @@ -951,9 +952,7 @@ func TestFlushChannels(t *testing.T) { err = dodgyWs.FlushChannels() assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") - web := Websocket{ - enabled: true, - state: connected, + w := Websocket{ connector: connect, ShutdownC: make(chan struct{}), Subscriber: newgen.SUBME, @@ -966,6 +965,8 @@ func TestFlushChannels(t *testing.T) { // in FlushChannels() so the traffic monitor doesn't time out and turn // this to an unconnected state } + w.setEnabled(true) + w.setState(connected) problemFunc := func() ([]subscription.Subscription, error) { return nil, errDastardlyReason @@ -978,40 +979,40 @@ func TestFlushChannels(t *testing.T) { // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} - web.GenerateSubs = func() ([]subscription.Subscription, error) { + w.GenerateSubs = func() ([]subscription.Subscription, error) { return []subscription.Subscription{{Channel: "test"}}, nil } - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.features.FullPayloadSubscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() // error on full subscribeToChannels + w.features.FullPayloadSubscribe = true + w.GenerateSubs = problemFunc + err = w.FlushChannels() // error on full subscribeToChannels assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - web.GenerateSubs = noSub - err = web.FlushChannels() // No subs to unsub + w.GenerateSubs = noSub + err = w.FlushChannels() // No subs to unsub assert.NoError(t, err, "FlushChannels should not error") - web.GenerateSubs = newgen.generateSubs - subs, err := web.GenerateSubs() + w.GenerateSubs = newgen.generateSubs + subs, err := w.GenerateSubs() require.NoError(t, err, "GenerateSubs must not error") - web.AddSuccessfulSubscriptions(subs...) - err = web.FlushChannels() + w.AddSuccessfulSubscriptions(subs...) + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.features.FullPayloadSubscribe = false - web.features.Subscribe = true + w.features.FullPayloadSubscribe = false + w.features.Subscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() + w.GenerateSubs = problemFunc + err = w.FlushChannels() assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - web.GenerateSubs = newgen.generateSubs - err = web.FlushChannels() + w.GenerateSubs = newgen.generateSubs + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.subscriptionMutex.Lock() - web.subscriptions = subscriptionMap{ + w.subscriptionMutex.Lock() + w.subscriptions = subscriptionMap{ 41: { Key: 41, Channel: "match channel", @@ -1023,34 +1024,34 @@ func TestFlushChannels(t *testing.T) { Pair: currency.NewPair(currency.THETA, currency.USDT), }, } - web.subscriptionMutex.Unlock() + w.subscriptionMutex.Unlock() - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - err = web.FlushChannels() + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - web.setState(connected) - web.features.Unsubscribe = true - err = web.FlushChannels() + w.setState(connected) + w.features.Unsubscribe = true + err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") } func TestDisable(t *testing.T) { t.Parallel() - web := Websocket{ - enabled: true, - state: connected, + w := Websocket{ ShutdownC: make(chan struct{}), } - require.NoError(t, web.Disable(), "Disable must not error") - assert.ErrorIs(t, web.Disable(), ErrAlreadyDisabled, "Disable should error correctly") + w.setEnabled(true) + w.setState(connected) + require.NoError(t, w.Disable(), "Disable must not error") + assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } func TestEnable(t *testing.T) { t.Parallel() - web := Websocket{ + w := Websocket{ connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), @@ -1060,8 +1061,8 @@ func TestEnable(t *testing.T) { Subscriber: func([]subscription.Subscription) error { return nil }, } - require.NoError(t, web.Enable(), "Enable must not error") - assert.ErrorIs(t, web.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") + require.NoError(t, w.Enable(), "Enable must not error") + assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 31e1db58b09..a783d585a4e 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -2,6 +2,7 @@ package stream import ( "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -23,10 +24,8 @@ const ( type subscriptionMap map[any]*subscription.Subscription -type state int - const ( - uninitialised state = iota + uninitialised uint32 = iota disconnected connecting connected @@ -35,13 +34,13 @@ const ( // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { - canUseAuthenticatedEndpoints bool - enabled bool - state state + canUseAuthenticatedEndpoints atomic.Bool + enabled atomic.Bool + state atomic.Uint32 verbose bool - connectionMonitorRunning bool - trafficMonitorRunning bool - dataMonitorRunning bool + connectionMonitorRunning atomic.Bool + trafficMonitorRunning atomic.Bool + dataMonitorRunning atomic.Bool trafficTimeout time.Duration connectionMonitorDelay time.Duration proxyAddr string @@ -51,7 +50,6 @@ type Websocket struct { runningURLAuth string exchangeName string m sync.Mutex - fieldMutex sync.RWMutex connector func() error subscriptionMutex sync.RWMutex From c3edef2cce3f6b6d900a31f66d6f47a09688ed99 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 15 Feb 2024 20:14:47 +0700 Subject: [PATCH 15/35] Websocket: Fix and simplify traffic monitor trafficMonitor had a check throttle at the end of the for loop to stop it just gobbling the (blocking) trafficAlert channel non-stop. That makes sense, except that nothing is sent to the trafficAlert channel if there's no listener. So that means that it's out by one second on the trafficAlert, because any traffic received during the pause is doesn't try to send a traffic alert. The unstopped timer is deliberately leaked for later GC when shutdown. It won't delay/block anything, and it's a trivial memory leak during an infrequent event. Deliberately Choosing to recreate the timer each time instead of using Stop, drain and reset --- exchanges/stream/websocket.go | 78 +++++++----------- exchanges/stream/websocket_connection.go | 2 +- exchanges/stream/websocket_test.go | 100 +++++++++++++++++++---- 3 files changed, 113 insertions(+), 67 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 409e0953251..bebc3abb595 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -16,12 +16,7 @@ import ( ) const ( - defaultJobBuffer = 5000 - // defaultTrafficPeriod defines a period of pause for the traffic monitor, - // as there are periods with large incoming traffic alerts which requires a - // timer reset, this limits work on this routine to a more effective rate - // of check. - defaultTrafficPeriod = time.Second + jobBuffer = 5000 ) // Public websocket errors @@ -37,6 +32,7 @@ var ( ErrNotConnected = errors.New("websocket is not connected") ) +// Private websocket errors var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") @@ -72,7 +68,10 @@ var ( errConnSetup = errors.New("error in connection setup") ) -var globalReporter Reporter +var ( + globalReporter Reporter + trafficCheckInterval = time.Second +) // SetupGlobalReporter sets a reporter interface to be used // for all exchange requests @@ -83,9 +82,9 @@ func SetupGlobalReporter(r Reporter) { // NewWebsocket initialises the websocket struct func NewWebsocket() *Websocket { return &Websocket{ - DataHandler: make(chan interface{}, defaultJobBuffer), - ToRoutine: make(chan interface{}, defaultJobBuffer), - TrafficAlert: make(chan struct{}), + DataHandler: make(chan interface{}, jobBuffer), + ToRoutine: make(chan interface{}, jobBuffer), + TrafficAlert: make(chan struct{}, 1), ReadMessageErrors: make(chan error), Subscribe: make(chan []subscription.Subscription), Unsubscribe: make(chan []subscription.Subscription), @@ -545,9 +544,9 @@ func (w *Websocket) FlushChannels() error { return w.Connect() } -// trafficMonitor uses a timer of WebsocketTrafficLimitTime and once it expires, -// it will reconnect if the TrafficAlert channel has not received any data. The -// trafficTimer will reset on each traffic alert +// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert +// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic +// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again func (w *Websocket) trafficMonitor() { if w.IsTrafficMonitorRunning() { return @@ -556,62 +555,45 @@ func (w *Websocket) trafficMonitor() { w.Wg.Add(1) go func() { - var trafficTimer = time.NewTimer(w.trafficTimeout) - pause := make(chan struct{}) + t := time.NewTimer(w.trafficTimeout) for { select { case <-w.ShutdownC: if w.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) } - trafficTimer.Stop() + t.Stop() w.setTrafficMonitorRunning(false) w.Wg.Done() return - case <-w.TrafficAlert: - if !trafficTimer.Stop() { - select { - case <-trafficTimer.C: - default: + case <-time.After(trafficCheckInterval): + select { + case <-w.TrafficAlert: + w.setState(connected) + if !t.Stop() { + <-t.C } + t.Reset(w.trafficTimeout) + default: + } + case <-t.C: + if w.IsConnecting() { + t.Reset(w.trafficTimeout) + break } - w.setState(connected) - trafficTimer.Reset(w.trafficTimeout) - case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) } - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() // without this the w.Shutdown() call below will deadlock - if !w.IsConnecting() && w.IsConnected() { + w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call + w.Wg.Done() // Without this the w.Shutdown() call below will deadlock + if w.IsConnected() { err := w.Shutdown() if err != nil { log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) } } - return } - - if w.IsConnected() { - // Routine pausing mechanism - go func(p chan<- struct{}) { - time.Sleep(defaultTrafficPeriod) - select { - case p <- struct{}{}: - default: - } - }(pause) - select { - case <-w.ShutdownC: - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() - return - case <-pause: - } - } } }() } diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 910142caa9a..b5c7e83d9e7 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -227,7 +227,7 @@ func (w *WebsocketConnection) ReadMessage() Response { select { case w.Traffic <- struct{}{}: - default: // causes contention, just bypass if there is no receiver. + default: // Non-Blocking write ensuses 1 buffered signal per trafficCheckInterval to avoid flooding } var standardMessage []byte diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 4725cc9b064..692d84bed4a 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -181,22 +181,93 @@ func TestTrafficMonitorTimeout(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - ws.trafficTimeout = time.Millisecond * 42 + signal := struct{}{} + patience := 10 * time.Millisecond + trafficCheckInterval = 50 * time.Millisecond + ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) + + thenish := time.Now() ws.trafficMonitor() + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, disconnected, ws.state.Load(), "websocket must be disconnected") - // Deploy traffic alert - ws.TrafficAlert <- struct{}{} - // try to add another traffic monitor + // Behaviour: Test multiple traffic alerts work and only process one trafficAlert per interval + for i := 0; i < 2; i++ { + ws.state.Store(disconnected) + + select { + case ws.TrafficAlert <- signal: + if i == 0 { + require.WithinDurationf(t, time.Now(), thenish, trafficCheckInterval, "First Non-blocking test must happen before the traffic is checked") + } + default: + require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) + } + + select { + case ws.TrafficAlert <- signal: + require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) + default: + if i == 0 { + require.WithinDuration(t, time.Now(), thenish, trafficCheckInterval, "First Blocking test must happen before the traffic is checked") + } + } + + require.EventuallyWithTf(t, func(c *assert.CollectT) { + assert.Truef(c, ws.IsConnected(), "state should been marked as connected; Check #%d", i) + assert.Emptyf(c, ws.TrafficAlert, "trafficAlert channel should be drained; Check #%d", i) + }, 2*trafficCheckInterval, patience, "trafficAlert should be read and state connected; Check #%d", i) + } + + // Behaviour: Shuts down websocket and exits on timeout + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor shound be shut down") + }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown") + + // Behaviour: connecting status doesn't trigger shutdown + ws.state.Store(connecting) + ws.trafficTimeout = 50 * time.Millisecond + ws.trafficMonitor() + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting") + <-time.After(4 * ws.trafficTimeout) + require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks") + ws.state.Store(connected) + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor shound be shut down") + }, 4*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown after connecting status changes") + + // Behaviour: shutdown is processed and waitgroup is cleared + trafficCheckInterval = 10 * time.Millisecond + ws.state.Store(connected) + ws.trafficTimeout = time.Minute ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - // prevent shutdown routine - ws.setState(disconnected) - // await timeout closure - ws.Wg.Wait() - assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be not be running") + wgReady := make(chan bool) + go func() { + ws.Wg.Wait() + close(wgReady) + }() + select { + case <-wgReady: + require.Failf(t, "", "WaitGroup should be blocking still") + case <-time.After(1 * trafficCheckInterval): + } + + close(ws.ShutdownC) + + <-time.After(2 * trafficCheckInterval) + assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be shutdown") + select { + case <-wgReady: + default: + require.Failf(t, "", "WaitGroup should be freed now") + } } func TestIsDisconnectionError(t *testing.T) { @@ -1079,18 +1150,11 @@ func TestSetupNewConnection(t *testing.T) { err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") - nonsenseWebsock.TrafficAlert = make(chan struct{}) + nonsenseWebsock.TrafficAlert = make(chan struct{}, 1) err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") - web := Websocket{ - connector: connect, - Wg: new(sync.WaitGroup), - ShutdownC: make(chan struct{}), - TrafficAlert: make(chan struct{}), - ReadMessageErrors: make(chan error), - DataHandler: make(chan interface{}), - } + web := NewWebsocket() err = web.Setup(defaultSetup) assert.NoError(t, err, "Setup should not error") From 4bcf97340b4492d5b01378933f45a877f70bf2a5 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 17 Feb 2024 09:05:59 +0700 Subject: [PATCH 16/35] fixup! Websocket: Fix and simplify traffic monitor Fix race on changing trafficCheckInterval --- exchanges/stream/websocket_test.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 692d84bed4a..69f6461c67e 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -105,6 +105,11 @@ func (d *dodgyConnection) Connect() error { return fmt.Errorf("cannot connect: %w", errDastardlyReason) } +func TestMain(_ *testing.M) { + // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing + trafficCheckInterval = 50 * time.Millisecond +} + func TestSetup(t *testing.T) { t.Parallel() var w *Websocket @@ -183,7 +188,7 @@ func TestTrafficMonitorTimeout(t *testing.T) { signal := struct{}{} patience := 10 * time.Millisecond - trafficCheckInterval = 50 * time.Millisecond + // trafficCheckInterval is changed in TestMain to avoid racing ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) From 61910caaa556ad7b5aaacf735ddf0c76e0a8e07c Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Tue, 20 Feb 2024 11:28:01 +0700 Subject: [PATCH 17/35] fixup! Websocket: Fix and simplify traffic monitor --- exchanges/stream/websocket_connection.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index b5c7e83d9e7..4d7681f8d13 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -227,7 +227,7 @@ func (w *WebsocketConnection) ReadMessage() Response { select { case w.Traffic <- struct{}{}: - default: // Non-Blocking write ensuses 1 buffered signal per trafficCheckInterval to avoid flooding + default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding } var standardMessage []byte From 56b62692c9df5e0c9319d2ec3f1b00092676bfea Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Tue, 20 Feb 2024 11:30:46 +0700 Subject: [PATCH 18/35] fixup! Websocket: Fix and simplify traffic monitor --- exchanges/stream/websocket_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 69f6461c67e..74acf8c3d24 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/http" + "os" "sort" "strconv" "strings" @@ -105,9 +106,10 @@ func (d *dodgyConnection) Connect() error { return fmt.Errorf("cannot connect: %w", errDastardlyReason) } -func TestMain(_ *testing.M) { +func TestMain(m *testing.M) { // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing trafficCheckInterval = 50 * time.Millisecond + os.Exit(m.Run()) } func TestSetup(t *testing.T) { From 7dd00dc94be82daa0356eea3fe2b4c9e28d791d5 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Tue, 20 Feb 2024 11:32:41 +0700 Subject: [PATCH 19/35] fixup! Websocket: Fix and simplify traffic monitor --- exchanges/stream/websocket_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 74acf8c3d24..957fd8d1621 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -249,7 +249,6 @@ func TestTrafficMonitorTimeout(t *testing.T) { }, 4*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown after connecting status changes") // Behaviour: shutdown is processed and waitgroup is cleared - trafficCheckInterval = 10 * time.Millisecond ws.state.Store(connected) ws.trafficTimeout = time.Minute ws.trafficMonitor() From d6918391e935a17574f63b410b5d7771205fa6d5 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Tue, 20 Feb 2024 14:57:20 +0700 Subject: [PATCH 20/35] Websocket: Split traficMonitor test on behaviours --- exchanges/stream/websocket_test.go | 32 ++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 957fd8d1621..4273f739765 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -182,7 +182,9 @@ func TestSetup(t *testing.T) { assert.NoError(t, err, "Setup should not error") } -func TestTrafficMonitorTimeout(t *testing.T) { +// TestTrafficMonitorTrafficAlerts ensures multiple traffic alerts work and only process one trafficAlert per interval +// ensures shutdown works after traffic alerts +func TestTrafficMonitorTrafficAlerts(t *testing.T) { t.Parallel() ws := NewWebsocket() err := ws.Setup(defaultSetup) @@ -190,7 +192,6 @@ func TestTrafficMonitorTimeout(t *testing.T) { signal := struct{}{} patience := 10 * time.Millisecond - // trafficCheckInterval is changed in TestMain to avoid racing ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) @@ -200,7 +201,6 @@ func TestTrafficMonitorTimeout(t *testing.T) { assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") require.Equal(t, disconnected, ws.state.Load(), "websocket must be disconnected") - // Behaviour: Test multiple traffic alerts work and only process one trafficAlert per interval for i := 0; i < 2; i++ { ws.state.Store(disconnected) @@ -228,17 +228,24 @@ func TestTrafficMonitorTimeout(t *testing.T) { }, 2*trafficCheckInterval, patience, "trafficAlert should be read and state connected; Check #%d", i) } - // Behaviour: Shuts down websocket and exits on timeout require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor shound be shut down") }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown") +} + +// TestTrafficMonitorConnecting ensure connecting status doesn't trigger shutdown +func TestTrafficMonitorConnecting(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") - // Behaviour: connecting status doesn't trigger shutdown + ws.ShutdownC = make(chan struct{}) ws.state.Store(connecting) ws.trafficTimeout = 50 * time.Millisecond ws.trafficMonitor() - assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting") <-time.After(4 * ws.trafficTimeout) require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks") @@ -246,9 +253,18 @@ func TestTrafficMonitorTimeout(t *testing.T) { require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor shound be shut down") - }, 4*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown after connecting status changes") + }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") +} + +// TestTrafficMonitorShutdown ensure shutdown is processed and waitgroup is cleared +func TestTrafficMonitorShutdown(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") - // Behaviour: shutdown is processed and waitgroup is cleared + ws.ShutdownC = make(chan struct{}) + ws.trafficTimeout = 50 * time.Millisecond ws.state.Store(connected) ws.trafficTimeout = time.Minute ws.trafficMonitor() From 2f6bfc0d59c00d42117fc87ac1dd4a5e635ca380 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Tue, 20 Feb 2024 15:46:33 +0700 Subject: [PATCH 21/35] Websocket: Remove trafficMonitor connected status trafficMonitor does not need to set the connection to be connected. Connect() does that. Anything after that should result in a full shutdown and restart. It can't and shouldn't become connected unexpectedly, and this is most likely a race anyway. Also dropped trafficCheckInterval to 100ms to mitigate races of traffic alerts being buffered for too long. --- exchanges/stream/websocket.go | 11 ++++++++--- exchanges/stream/websocket_test.go | 17 ++++++++--------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index bebc3abb595..7d1d402d161 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -70,7 +70,7 @@ var ( var ( globalReporter Reporter - trafficCheckInterval = time.Second + trafficCheckInterval = 100 * time.Millisecond ) // SetupGlobalReporter sets a reporter interface to be used @@ -569,7 +569,6 @@ func (w *Websocket) trafficMonitor() { case <-time.After(trafficCheckInterval): select { case <-w.TrafficAlert: - w.setState(connected) if !t.Stop() { <-t.C } @@ -577,7 +576,13 @@ func (w *Websocket) trafficMonitor() { default: } case <-t.C: - if w.IsConnecting() { + checkAgain := w.IsConnecting() + select { + case <-w.TrafficAlert: + checkAgain = true + default: + } + if checkAgain { t.Reset(w.trafficTimeout) break } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 4273f739765..b099adca17f 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -194,16 +194,15 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { patience := 10 * time.Millisecond ws.trafficTimeout = 200 * time.Millisecond ws.ShutdownC = make(chan struct{}) + ws.state.Store(connected) thenish := time.Now() ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, disconnected, ws.state.Load(), "websocket must be disconnected") - - for i := 0; i < 2; i++ { - ws.state.Store(disconnected) + require.Equal(t, connected, ws.state.Load(), "websocket must be connected") + for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass select { case ws.TrafficAlert <- signal: if i == 0 { @@ -223,18 +222,18 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { } require.EventuallyWithTf(t, func(c *assert.CollectT) { - assert.Truef(c, ws.IsConnected(), "state should been marked as connected; Check #%d", i) + assert.Truef(c, ws.IsConnected(), "state should still be connected; Check #%d", i) assert.Emptyf(c, ws.TrafficAlert, "trafficAlert channel should be drained; Check #%d", i) - }, 2*trafficCheckInterval, patience, "trafficAlert should be read and state connected; Check #%d", i) + }, 2*trafficCheckInterval, patience, "trafficAlert should be read; Check #%d", i) } require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor shound be shut down") - }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown") + }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") } -// TestTrafficMonitorConnecting ensure connecting status doesn't trigger shutdown +// TestTrafficMonitorConnecting ensures connecting status doesn't trigger shutdown func TestTrafficMonitorConnecting(t *testing.T) { t.Parallel() ws := NewWebsocket() @@ -256,7 +255,7 @@ func TestTrafficMonitorConnecting(t *testing.T) { }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") } -// TestTrafficMonitorShutdown ensure shutdown is processed and waitgroup is cleared +// TestTrafficMonitorShutdown ensures shutdown is processed and waitgroup is cleared func TestTrafficMonitorShutdown(t *testing.T) { t.Parallel() ws := NewWebsocket() From d0cd94ee1d9b736d6fa9485ba6a321ae328c611f Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 21 Feb 2024 08:16:58 +0700 Subject: [PATCH 22/35] fixup! Websocket: Split traficMonitor test on behaviours --- exchanges/stream/websocket_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index b099adca17f..ea7150c5808 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -229,7 +229,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") - assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor shound be shut down") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") } @@ -251,7 +251,7 @@ func TestTrafficMonitorConnecting(t *testing.T) { ws.state.Store(connected) require.EventuallyWithT(t, func(c *assert.CollectT) { assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") - assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor shound be shut down") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") } @@ -263,7 +263,6 @@ func TestTrafficMonitorShutdown(t *testing.T) { require.NoError(t, err, "Setup must not error") ws.ShutdownC = make(chan struct{}) - ws.trafficTimeout = 50 * time.Millisecond ws.state.Store(connected) ws.trafficTimeout = time.Minute ws.trafficMonitor() From e040f17b41a7f9a03d3b110ab3541c01b650fb9e Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 21 Feb 2024 08:39:02 +0700 Subject: [PATCH 23/35] fixup! Websocket: Fix and simplify traffic monitor --- exchanges/stream/websocket_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index ea7150c5808..b7d4c1f313c 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -276,7 +276,7 @@ func TestTrafficMonitorShutdown(t *testing.T) { select { case <-wgReady: require.Failf(t, "", "WaitGroup should be blocking still") - case <-time.After(1 * trafficCheckInterval): + case <-time.After(trafficCheckInterval): } close(ws.ShutdownC) From 89be3d85773d6732bac5be290a95ee9bcfe56945 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 21 Feb 2024 08:40:27 +0700 Subject: [PATCH 24/35] Websocket: Set disconnected earlier in Shutdown This caused a possible race where state is still connected, but we start to trigger interested actors via ShutdownC and Wait. They may check state and then call Shutdown again, such as trafficMonitor --- exchanges/stream/websocket.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 7d1d402d161..404d76d86f1 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -477,10 +477,11 @@ func (w *Websocket) Shutdown() error { w.subscriptions = subscriptionMap{} w.subscriptionMutex.Unlock() + w.setState(disconnected) + close(w.ShutdownC) w.Wg.Wait() w.ShutdownC = make(chan struct{}) - w.setState(disconnected) if w.verbose { log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } From 1b1719e3187b846159b57221d3a1a4567f3e22cf Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 27 Dec 2023 20:09:49 +0700 Subject: [PATCH 25/35] Common: Fix fmt msg lost in AppendError().Error() Retain the .msg field of a go fmt.Errorf .msg field returned by .Error() when wrapping multiple errors. This fixes a situation where a nested stack of errors would lose formatting information, which is often used to supply identifying context. e.g. ``` err = fmt.Errorf("%w `%s`: %w", errParsingField, fieldName, parsingError) errs = common.AppendError(errs, err) ``` This isn't really an issue with our implementation; Calling Unwrap() on a fmt.Errorf() which returns a wrapErrors will lose that formatting. Our issue was that we were using just Unwrap() to bind together our chain-of-custody. --- common/common.go | 55 +++++++++++++++++++++++++++++++++---------- common/common_test.go | 7 ++++++ 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/common/common.go b/common/common.go index 9e0d32db2bc..c4483a2d5cc 100644 --- a/common/common.go +++ b/common/common.go @@ -463,11 +463,22 @@ func InArray(val, array interface{}) (exists bool, index int) { return } +// fmtError holds a formatted msg and the errors which formatted it +type fmtError struct { + errs []error + msg string +} + // multiError holds errors as a slice type multiError struct { errs []error } +type unwrappable interface { + Unwrap() []error + Error() string +} + // AppendError appends an error to a list of exesting errors // Either argument may be: // * A vanilla error @@ -481,20 +492,35 @@ func AppendError(original, incoming error) error { if original == nil { return incoming } - newErrs := []error{incoming} - if u, ok := incoming.(interface{ Unwrap() []error }); ok { - newErrs = u.Unwrap() + if u, ok := incoming.(unwrappable); ok { + incoming = &fmtError{ + errs: u.Unwrap(), + msg: u.Error(), + } } - if u, ok := original.(interface{ Unwrap() []error }); ok { - return &multiError{ - errs: append(u.Unwrap(), newErrs...), + switch v := original.(type) { + case *multiError: + v.errs = append(v.errs, incoming) + return v + case unwrappable: + original = &fmtError{ + errs: v.Unwrap(), + msg: v.Error(), } } return &multiError{ - errs: append([]error{original}, newErrs...), + errs: append([]error{original}, incoming), } } +func (e *fmtError) Error() string { + return e.msg +} + +func (e *fmtError) Unwrap() []error { + return e.errs +} + // Error displays all errors comma separated func (e *multiError) Error() string { allErrors := make([]string, len(e.errs)) @@ -506,11 +532,16 @@ func (e *multiError) Error() string { // Unwrap returns all of the errors in the multiError func (e *multiError) Unwrap() []error { - return e.errs -} - -type unwrappable interface { - Unwrap() []error + errs := make([]error, 0, len(e.errs)) + for _, e := range e.errs { + switch v := e.(type) { + case unwrappable: + errs = append(errs, unwrapDeep(v)...) + default: + errs = append(errs, v) + } + } + return errs } // unwrapDeep walks down a stack of nested fmt.Errorf("%w: %w") errors diff --git a/common/common_test.go b/common/common_test.go index fd6d448d4c6..e9c8bf28412 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -703,6 +703,13 @@ func TestErrors(t *testing.T) { assert.ErrorIs(t, ExcludeError(err, e5), e3, "Excluding e5 should retain e3") assert.ErrorIs(t, ExcludeError(err, e5), e4, "Excluding e5 should retain the vanilla co-wrapped e4") assert.NotErrorIs(t, ExcludeError(err, e5), e5, "e4 should be excluded") + + // Formatting retention + err = AppendError(e1, fmt.Errorf("%w: Run out of `%s`: %w", e3, "sausages", e5)) + assert.ErrorIs(t, err, e1, "Should be an e1") + assert.ErrorIs(t, err, e3, "Should be an e3") + assert.ErrorIs(t, err, e5, "Should be an e5") + assert.ErrorContains(t, err, "sausages", "Should know about secret snausages") } func TestParseStartEndDate(t *testing.T) { From 8c88cb1b3a616d851b517c40f0f0ced1b516d4fc Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 19 Feb 2024 09:55:19 +0700 Subject: [PATCH 26/35] Tests: TestFixtureToDataHandler preserve WS If the exchange passed in already has a websocket, don't clobber it --- exchanges/sharedtestvalues/sharedtestvalues.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index 54acf9dcb5b..71792bbd977 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -165,10 +165,14 @@ func TestFixtureToDataHandler(t *testing.T, seed, e exchange.IBotExchange, fixtu assert.NoError(t, err, "Loading currency pairs should not error") b.Name = "fixture" - b.Websocket = &stream.Websocket{ - Wg: new(sync.WaitGroup), - DataHandler: make(chan interface{}, 128), + + if b.Websocket == nil { + b.Websocket = &stream.Websocket{ + Wg: new(sync.WaitGroup), + DataHandler: make(chan interface{}, 128), + } } + b.API.Endpoints = b.NewEndpoints() fixture, err := os.Open(fixturePath) From 07748446c20d7138e551b3d4796566c4a12c51af Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 12 Feb 2024 13:37:12 +0700 Subject: [PATCH 27/35] Tests: Fix WsAuth turned off by config checking --- internal/testing/exchange/exchange.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/testing/exchange/exchange.go b/internal/testing/exchange/exchange.go index 8a943677984..b8b20dfb624 100644 --- a/internal/testing/exchange/exchange.go +++ b/internal/testing/exchange/exchange.go @@ -94,6 +94,8 @@ func MockWSInstance[T any, PT interface { b := e.GetBase() b.SkipAuthCheck = true + b.API.AuthenticatedWebsocketSupport = true + err := b.API.Endpoints.SetRunning("RestSpotURL", s.URL) require.NoError(tb, err, "Endpoints.SetRunning should not error for RestSpotURL") for _, auth := range []bool{true, false} { From d92e451be438da6b7a3db1f1167c931920127bd0 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 19 Feb 2024 09:56:35 +0700 Subject: [PATCH 28/35] Websocket: Use ErrSubscribedAlready instead of errChannelAlreadySubscribed --- exchanges/stream/websocket.go | 3 +-- exchanges/stream/websocket_test.go | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 404d76d86f1..17968a79549 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -58,7 +58,6 @@ var ( errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") errNoSubscriptionsSupplied = errors.New("no subscriptions supplied") - errChannelAlreadySubscribed = errors.New("channel already subscribed") errInvalidChannelState = errors.New("invalid Channel state") errSameProxyAddress = errors.New("cannot set proxy address to the same address") errNoConnectFunc = errors.New("websocket connect func not set") @@ -1003,7 +1002,7 @@ func (w *Websocket) checkSubscriptions(subs []subscription.Subscription) error { for i := range subs { key := subs[i].EnsureKeyed() if _, ok := w.subscriptions[key]; ok { - return fmt.Errorf("%w for %+v", errChannelAlreadySubscribed, subs[i]) + return fmt.Errorf("%w for %+v", ErrSubscribedAlready, subs[i]) } } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index b7d4c1f313c..dd827336960 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -485,7 +485,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { } assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, ws.SubscribeToChannels(subs), errChannelAlreadySubscribed, "Subscribe should error when already subscribed") + assert.ErrorIs(t, ws.SubscribeToChannels(subs), ErrSubscribedAlready, "Subscribe should error when already subscribed") assert.ErrorIs(t, ws.SubscribeToChannels(nil), errNoSubscriptionsSupplied, "Subscribe to nil should error") assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error") } @@ -1269,7 +1269,7 @@ func TestCheckSubscriptions(t *testing.T) { ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}} err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}}) - assert.ErrorIs(t, err, errChannelAlreadySubscribed, "checkSubscriptions should error correctly") + assert.ErrorIs(t, err, ErrSubscribedAlready, "checkSubscriptions should error correctly") err = ws.checkSubscriptions([]subscription.Subscription{{}}) assert.NoError(t, err, "checkSubscriptions should not error") From aa5b69618c1c5e5761a17111b9db5f017c90eca9 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 10 Jan 2024 10:02:15 +0700 Subject: [PATCH 29/35] Bitfinex: Driveby removal of lint hint --- exchanges/bitfinex/bitfinex_wrapper.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 2aba9987cf9..23c7446882d 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -643,8 +643,8 @@ func (b *Bitfinex) SubmitOrder(ctx context.Context, o *order.Submit) (*order.Sub var orderID string status := order.New if b.Websocket.CanUseAuthenticatedWebsocketForWrapper() { - symbolStr, err := b.fixCasing(fPair, o.AssetType) //nolint:govet // intentional shadow of err - if err != nil { + var symbolStr string + if symbolStr, err = b.fixCasing(fPair, o.AssetType); err != nil { return nil, err } orderType := strings.ToUpper(o.Type.String()) From 9d90615d1c47f2b57cf3f0d162c4dbeaa3a274ce Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 19 Feb 2024 14:02:23 +0700 Subject: [PATCH 30/35] Subscriptions: Replace Pair with Pairs Given that some subscriptions have multiple pairs, support that as the standard. --- exchanges/exchange.go | 6 +- exchanges/exchange_test.go | 8 +- exchanges/interfaces.go | 6 +- exchanges/stream/websocket.go | 240 +++++-------- exchanges/stream/websocket_test.go | 333 ++++++++---------- exchanges/stream/websocket_types.go | 31 +- exchanges/subscription/list.go | 16 + exchanges/subscription/store.go | 176 +++++++++ exchanges/subscription/subscription.go | 133 ++++--- exchanges/subscription/subscription_test.go | 40 ++- internal/testing/exchange/exchange.go | 8 +- .../testing/subscriptions/subscriptions.go | 27 ++ 12 files changed, 587 insertions(+), 437 deletions(-) create mode 100644 exchanges/subscription/list.go create mode 100644 exchanges/subscription/store.go create mode 100644 internal/testing/subscriptions/subscriptions.go diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 25931e59983..31b0d66a86c 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -1164,7 +1164,7 @@ func (b *Base) FlushWebsocketChannels() error { // SubscribeToWebsocketChannels appends to ChannelsToSubscribe // which lets websocket.manageSubscriptions handle subscribing -func (b *Base) SubscribeToWebsocketChannels(channels []subscription.Subscription) error { +func (b *Base) SubscribeToWebsocketChannels(channels []*subscription.Subscription) error { if b.Websocket == nil { return common.ErrFunctionNotSupported } @@ -1173,7 +1173,7 @@ func (b *Base) SubscribeToWebsocketChannels(channels []subscription.Subscription // UnsubscribeToWebsocketChannels removes from ChannelsToSubscribe // which lets websocket.manageSubscriptions handle unsubscribing -func (b *Base) UnsubscribeToWebsocketChannels(channels []subscription.Subscription) error { +func (b *Base) UnsubscribeToWebsocketChannels(channels []*subscription.Subscription) error { if b.Websocket == nil { return common.ErrFunctionNotSupported } @@ -1181,7 +1181,7 @@ func (b *Base) UnsubscribeToWebsocketChannels(channels []subscription.Subscripti } // GetSubscriptions returns a copied list of subscriptions -func (b *Base) GetSubscriptions() ([]subscription.Subscription, error) { +func (b *Base) GetSubscriptions() ([]*subscription.Subscription, error) { if b.Websocket == nil { return nil, common.ErrFunctionNotSupported } diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index ad0060495f3..7c0b3014020 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -1263,8 +1263,8 @@ func TestSetupDefaults(t *testing.T) { DefaultURL: "ws://something.com", RunningURL: "ws://something.com", Connector: func() error { return nil }, - GenerateSubscriptions: func() ([]subscription.Subscription, error) { return []subscription.Subscription{}, nil }, - Subscriber: func([]subscription.Subscription) error { return nil }, + GenerateSubscriptions: func() (subscription.List, error) { return subscription.List{}, nil }, + Subscriber: func(subscription.List) error { return nil }, }) if err != nil { t.Fatal(err) @@ -3279,7 +3279,7 @@ func TestSetSubscriptionsFromConfig(t *testing.T) { Features: &config.FeaturesConfig{}, }, } - subs := []*subscription.Subscription{ + subs := subscription.List{ {Channel: subscription.CandlesChannel, Interval: kline.OneDay, Enabled: true}, } b.Features.Subscriptions = subs @@ -3287,7 +3287,7 @@ func TestSetSubscriptionsFromConfig(t *testing.T) { assert.ElementsMatch(t, subs, b.Config.Features.Subscriptions, "Config Subscriptions should be updated") assert.ElementsMatch(t, subs, b.Features.Subscriptions, "Subscriptions should be the same") - subs = []*subscription.Subscription{ + subs = subscription.List{ {Channel: subscription.OrderbookChannel, Interval: kline.OneDay, Enabled: true}, } b.Config.Features.Subscriptions = subs diff --git a/exchanges/interfaces.go b/exchanges/interfaces.go index b7cb5b5c3f6..1745c1492fb 100644 --- a/exchanges/interfaces.go +++ b/exchanges/interfaces.go @@ -71,9 +71,9 @@ type IBotExchange interface { EnableRateLimiter() error GetServerTime(ctx context.Context, ai asset.Item) (time.Time, error) GetWebsocket() (*stream.Websocket, error) - SubscribeToWebsocketChannels(channels []subscription.Subscription) error - UnsubscribeToWebsocketChannels(channels []subscription.Subscription) error - GetSubscriptions() ([]subscription.Subscription, error) + SubscribeToWebsocketChannels(channels []*subscription.Subscription) error + UnsubscribeToWebsocketChannels(channels []*subscription.Subscription) error + GetSubscriptions() ([]*subscription.Subscription, error) FlushWebsocketChannels() error AuthenticateWebsocket(ctx context.Context) error GetOrderExecutionLimits(a asset.Item, cp currency.Pair) (order.MinMaxLevel, error) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 17968a79549..82d8f3f1847 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -5,12 +5,12 @@ import ( "fmt" "net" "net/url" - "sync" "time" "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/log" ) @@ -22,12 +22,9 @@ const ( // Public websocket errors var ( ErrWebsocketNotEnabled = errors.New("websocket not enabled") - ErrSubscriptionNotFound = errors.New("subscription not found") - ErrSubscribedAlready = errors.New("duplicate subscription") ErrSubscriptionFailure = errors.New("subscription failure") ErrSubscriptionNotSupported = errors.New("subscription channel not supported ") ErrUnsubscribeFailure = errors.New("unsubscribe failure") - ErrChannelInStateAlready = errors.New("channel already in state") ErrAlreadyDisabled = errors.New("websocket already disabled") ErrNotConnected = errors.New("websocket is not connected") ) @@ -58,7 +55,6 @@ var ( errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") errNoSubscriptionsSupplied = errors.New("no subscriptions supplied") - errInvalidChannelState = errors.New("invalid Channel state") errSameProxyAddress = errors.New("cannot set proxy address to the same address") errNoConnectFunc = errors.New("websocket connect func not set") errAlreadyConnected = errors.New("websocket already connected") @@ -83,11 +79,12 @@ func NewWebsocket() *Websocket { return &Websocket{ DataHandler: make(chan interface{}, jobBuffer), ToRoutine: make(chan interface{}, jobBuffer), + ShutdownC: make(chan struct{}), TrafficAlert: make(chan struct{}, 1), ReadMessageErrors: make(chan error), - Subscribe: make(chan []subscription.Subscription), - Unsubscribe: make(chan []subscription.Subscription), Match: NewMatch(), + subscriptions: subscription.NewStore(), + features: &protocol.Features{}, } } @@ -180,7 +177,6 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout w.ShutdownC = make(chan struct{}) - w.Wg = new(sync.WaitGroup) w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport) if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil { @@ -242,7 +238,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { Traffic: w.TrafficAlert, readMessageErrors: w.ReadMessageErrors, ShutdownC: w.ShutdownC, - Wg: w.Wg, + Wg: &w.Wg, Match: w.Match, RateLimit: c.RateLimit, Reporter: c.ConnectionLevelReporter, @@ -276,9 +272,10 @@ func (w *Websocket) Connect() error { return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected) } - w.subscriptionMutex.Lock() - w.subscriptions = subscriptionMap{} - w.subscriptionMutex.Unlock() + if w.subscriptions == nil { + return common.ErrNilPointer + } + w.subscriptions.Clear() w.dataMonitor() w.trafficMonitor() @@ -302,17 +299,12 @@ func (w *Websocket) Connect() error { if err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } - if len(subs) == 0 { - return nil - } - err = w.checkSubscriptions(subs) - if err != nil { - return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) - } - err = w.Subscriber(subs) - if err != nil { - return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + if len(subs) != 0 { + if err := w.SubscribeToChannels(subs); err != nil { + return err + } } + return nil } @@ -472,9 +464,7 @@ func (w *Websocket) Shutdown() error { } // flush any subscriptions from last connection if needed - w.subscriptionMutex.Lock() - w.subscriptions = subscriptionMap{} - w.subscriptionMutex.Unlock() + w.subscriptions.Clear() w.setState(disconnected) @@ -530,9 +520,7 @@ func (w *Websocket) FlushChannels() error { if len(newsubs) != 0 { // Purge subscription list as there will be conflicts - w.subscriptionMutex.Lock() - w.subscriptions = subscriptionMap{} - w.subscriptionMutex.Unlock() + w.subscriptions.Clear() return w.SubscribeToChannels(newsubs) } return nil @@ -789,163 +777,117 @@ func (w *Websocket) GetName() string { // GetChannelDifference finds the difference between the subscribed channels // and the new subscription list when pairs are disabled or enabled. -func (w *Websocket) GetChannelDifference(genSubs []subscription.Subscription) (sub, unsub []subscription.Subscription) { - w.subscriptionMutex.RLock() - unsubMap := make(map[any]subscription.Subscription, len(w.subscriptions)) - for k, c := range w.subscriptions { - unsubMap[k] = *c - } - w.subscriptionMutex.RUnlock() - - for i := range genSubs { - key := genSubs[i].EnsureKeyed() - if _, ok := unsubMap[key]; ok { - delete(unsubMap, key) // If it's in both then we remove it from the unsubscribe list - } else { - sub = append(sub, genSubs[i]) // If it's in genSubs but not existing subs we want to subscribe - } - } - - for x := range unsubMap { - unsub = append(unsub, unsubMap[x]) +func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub subscription.List) { + if w.subscriptions == nil { + w.subscriptions = subscription.NewStore() } - - return + return w.subscriptions.Diff(newSubs) } -// UnsubscribeChannels unsubscribes from a websocket channel -func (w *Websocket) UnsubscribeChannels(channels []subscription.Subscription) error { +// UnsubscribeChannels unsubscribes from a list of websocket channel +func (w *Websocket) UnsubscribeChannels(channels subscription.List) error { if len(channels) == 0 { return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoSubscriptionsSupplied) } - w.subscriptionMutex.RLock() - - for i := range channels { - key := channels[i].EnsureKeyed() - if _, ok := w.subscriptions[key]; !ok { - w.subscriptionMutex.RUnlock() - return fmt.Errorf("%s websocket: %w: %+v", w.exchangeName, ErrSubscriptionNotFound, channels[i]) + if w.subscriptions == nil { + return common.ErrNilPointer + } + for _, s := range channels { + if w.subscriptions.Get(s) == nil { + return fmt.Errorf("%s websocket: %w: %s", w.exchangeName, subscription.ErrNotFound, s) } } - w.subscriptionMutex.RUnlock() return w.Unsubscriber(channels) } // ResubscribeToChannel resubscribes to channel -func (w *Websocket) ResubscribeToChannel(subscribedChannel *subscription.Subscription) error { - err := w.UnsubscribeChannels([]subscription.Subscription{*subscribedChannel}) +func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error { + l := subscription.List{s} + err := w.UnsubscribeChannels(l) if err != nil { return err } - return w.SubscribeToChannels([]subscription.Subscription{*subscribedChannel}) + return w.SubscribeToChannels(l) } -// SubscribeToChannels appends supplied channels to channelsToSubscribe -func (w *Websocket) SubscribeToChannels(channels []subscription.Subscription) error { - if err := w.checkSubscriptions(channels); err != nil { +// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method +// Errors are returned for duplicates or exceeding max Subscriptions +func (w *Websocket) SubscribeToChannels(subs subscription.List) error { + if err := w.checkSubscriptions(subs); err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } - if err := w.Subscriber(channels); err != nil { + if err := w.Subscriber(subs); err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } return nil } -// AddSubscription adds a subscription to the subscription lists -// Unlike AddSubscriptions this method will error if the subscription already exists -func (w *Websocket) AddSubscription(c *subscription.Subscription) error { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() - if w.subscriptions == nil { - w.subscriptions = subscriptionMap{} +// AddSubscription adds a subscription to the subscription store +func (w *Websocket) AddSubscription(s *subscription.Subscription) error { + if w == nil || s == nil { + return common.ErrNilPointer } - key := c.EnsureKeyed() - if _, ok := w.subscriptions[key]; ok { - return ErrSubscribedAlready + if w.subscriptions == nil { + w.subscriptions = subscription.NewStore() } - - n := *c // Fresh copy; we don't want to use the pointer we were given and allow encapsulation/locks to be bypassed - w.subscriptions[key] = &n - - return nil + return w.subscriptions.Add(s) } -// SetSubscriptionState sets an existing subscription state -// returns an error if the subscription is not found, or the new state is already set -func (w *Websocket) SetSubscriptionState(c *subscription.Subscription, state subscription.State) error { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() - if w.subscriptions == nil { - w.subscriptions = subscriptionMap{} - } - key := c.EnsureKeyed() - p, ok := w.subscriptions[key] - if !ok { - return ErrSubscriptionNotFound +// AddSubscriptions adds subscriptions to the subscription store +func (w *Websocket) AddSubscriptions(subs subscription.List) error { + if w == nil { + return common.ErrNilPointer } - if state == p.State { - return ErrChannelInStateAlready + if w.subscriptions == nil { + w.subscriptions = subscription.NewStore() } - if state > subscription.UnsubscribingState { - return errInvalidChannelState + var errs error + for _, s := range subs { + if err := w.subscriptions.Add(s); err != nil { + errs = common.AppendError(errs, err) + } } - p.State = state - return nil + return errs } -// AddSuccessfulSubscriptions adds subscriptions to the subscription lists that -// has been successfully subscribed -func (w *Websocket) AddSuccessfulSubscriptions(channels ...subscription.Subscription) { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() - if w.subscriptions == nil { - w.subscriptions = subscriptionMap{} - } - for _, cN := range channels { //nolint:gocritic // See below comment - c := cN // cN is an iteration var; Not safe to make a pointer to - key := c.EnsureKeyed() - c.State = subscription.SubscribedState - w.subscriptions[key] = &c +// RemoveSubscription removes a subscription from the subscription store +func (w *Websocket) RemoveSubscription(s *subscription.Subscription) error { + if w == nil || w.subscriptions == nil || s == nil { + return common.ErrNilPointer } + return w.subscriptions.Remove(s) } // RemoveSubscriptions removes subscriptions from the subscription list -func (w *Websocket) RemoveSubscriptions(channels ...subscription.Subscription) { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() - if w.subscriptions == nil { - w.subscriptions = subscriptionMap{} - } - for i := range channels { - key := channels[i].EnsureKeyed() - delete(w.subscriptions, key) +func (w *Websocket) RemoveSubscriptions(subs subscription.List) error { + if w == nil || w.subscriptions == nil { + return common.ErrNilPointer + } + var errs error + for _, s := range subs { + if err := w.subscriptions.Remove(s); err != nil { + errs = common.AppendError(errs, err) + } } + return errs } -// GetSubscription returns a pointer to a copy of the subscription at the key provided +// GetSubscription returns a subscription at the key provided // returns nil if no subscription is at that key or the key is nil +// Keys can implement subscription.MatchableKey in order to provide custom matching logic func (w *Websocket) GetSubscription(key any) *subscription.Subscription { - if key == nil || w == nil || w.subscriptions == nil { + if w == nil || w.subscriptions == nil || key == nil { return nil } - w.subscriptionMutex.RLock() - defer w.subscriptionMutex.RUnlock() - if s, ok := w.subscriptions[key]; ok { - c := *s - return &c - } - return nil + return w.subscriptions.Get(key) } // GetSubscriptions returns a new slice of the subscriptions -func (w *Websocket) GetSubscriptions() []subscription.Subscription { - w.subscriptionMutex.RLock() - defer w.subscriptionMutex.RUnlock() - subs := make([]subscription.Subscription, 0, len(w.subscriptions)) - for _, c := range w.subscriptions { - subs = append(subs, *c) +func (w *Websocket) GetSubscriptions() subscription.List { + if w == nil || w.subscriptions == nil { + return nil } - return subs + return w.subscriptions.List() } // SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner @@ -981,28 +923,28 @@ func checkWebsocketURL(s string) error { return nil } -// checkSubscriptions checks subscriptions against the max subscription limit -// and if the subscription already exists. -func (w *Websocket) checkSubscriptions(subs []subscription.Subscription) error { +// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists +// The subscription state is not considered when counting existing subscriptions +func (w *Websocket) checkSubscriptions(subs subscription.List) error { if len(subs) == 0 { return errNoSubscriptionsSupplied } + if w.subscriptions == nil { + return common.ErrNilPointer + } - w.subscriptionMutex.RLock() - defer w.subscriptionMutex.RUnlock() - - if w.MaxSubscriptionsPerConnection > 0 && len(w.subscriptions)+len(subs) > w.MaxSubscriptionsPerConnection { + existing := w.subscriptions.Len() + if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection { return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs", errSubscriptionsExceedsLimit, - len(w.subscriptions), + existing, len(subs), w.MaxSubscriptionsPerConnection) } - for i := range subs { - key := subs[i].EnsureKeyed() - if _, ok := w.subscriptions[key]; ok { - return fmt.Errorf("%w for %+v", ErrSubscribedAlready, subs[i]) + for _, s := range subs { + if found := w.subscriptions.Get(s); found != nil { + return fmt.Errorf("%w for %s", subscription.ErrDuplicate, s) } } diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index dd827336960..a10a5402f9c 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -20,6 +20,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -79,10 +80,10 @@ var defaultSetup = &WebsocketSetup{ DefaultURL: "testDefaultURL", RunningURL: "wss://testRunningURL", Connector: func() error { return nil }, - Subscriber: func([]subscription.Subscription) error { return nil }, - Unsubscriber: func([]subscription.Subscription) error { return nil }, - GenerateSubscriptions: func() ([]subscription.Subscription, error) { - return []subscription.Subscription{ + Subscriber: func(subscription.List) error { return nil }, + Unsubscriber: func(subscription.List) error { return nil }, + GenerateSubscriptions: func() (subscription.List, error) { + return subscription.List{ {Channel: "TestSub"}, {Channel: "TestSub2", Key: "purple"}, {Channel: "TestSub3", Key: testSubKey{"mauve"}}, @@ -147,16 +148,16 @@ func TestSetup(t *testing.T) { err = w.Setup(websocketSetup) assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly") - websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil } + websocketSetup.Subscriber = func(subscription.List) error { return nil } websocketSetup.Features.Unsubscribe = true err = w.Setup(websocketSetup) assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly") - websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil } + websocketSetup.Unsubscriber = func(subscription.List) error { return nil } err = w.Setup(websocketSetup) assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly") - websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil } + websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } err = w.Setup(websocketSetup) assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly") @@ -193,7 +194,6 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { signal := struct{}{} patience := 10 * time.Millisecond ws.trafficTimeout = 200 * time.Millisecond - ws.ShutdownC = make(chan struct{}) ws.state.Store(connected) thenish := time.Now() @@ -240,7 +240,6 @@ func TestTrafficMonitorConnecting(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - ws.ShutdownC = make(chan struct{}) ws.state.Store(connecting) ws.trafficTimeout = 50 * time.Millisecond ws.trafficMonitor() @@ -262,7 +261,6 @@ func TestTrafficMonitorShutdown(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - ws.ShutdownC = make(chan struct{}) ws.state.Store(connected) ws.trafficTimeout = time.Minute ws.trafficMonitor() @@ -310,10 +308,14 @@ func TestConnectionMessageErrors(t *testing.T) { wsWrong.setEnabled(true) wsWrong.setState(connecting) - wsWrong.Wg = &sync.WaitGroup{} err = wsWrong.Connect() assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") + wsWrong.setState(disconnected) + err = wsWrong.Connect() + assert.ErrorIs(t, err, common.ErrNilPointer, "Connect should get a nil pointer error, presumably on subs") + + wsWrong.subscriptions = subscription.NewStore() wsWrong.setState(disconnected) wsWrong.connector = func() error { return errDastardlyReason } err = wsWrong.Connect() @@ -323,7 +325,7 @@ func TestConnectionMessageErrors(t *testing.T) { err = ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") ws.trafficTimeout = time.Minute - ws.connector = func() error { return nil } + ws.connector = connect err = ws.Connect() require.NoError(t, err, "Connect must not error") @@ -445,34 +447,48 @@ func TestWebsocket(t *testing.T) { ws.Wg.Wait() } +func currySimpleSub(w *Websocket) func(subscription.List) error { + return func(subs subscription.List) error { + for _, s := range subs { + if err := s.SetState(subscription.SubscribedState); err != nil { + return err + } + } + return w.AddSubscriptions(subs) + } +} + +func currySimpleUnsub(w *Websocket) func(subscription.List) error { + return func(unsubs subscription.List) error { + for _, s := range unsubs { + if err := s.SetState(subscription.InactiveState); err != nil { + return err + } + } + return w.RemoveSubscriptions(unsubs) + } +} + // TestSubscribe logic test func TestSubscribeUnsubscribe(t *testing.T) { t.Parallel() ws := NewWebsocket() assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error") - fnSub := func(subs []subscription.Subscription) error { - ws.AddSuccessfulSubscriptions(subs...) - return nil - } - fnUnsub := func(unsubs []subscription.Subscription) error { - ws.RemoveSubscriptions(unsubs...) - return nil - } - ws.Subscriber = fnSub - ws.Unsubscriber = fnUnsub + ws.Subscriber = currySimpleSub(ws) + ws.Unsubscriber = currySimpleUnsub(ws) subs, err := ws.GenerateSubs() assert.NoError(t, err, "Generating test subscriptions should not error") assert.ErrorIs(t, ws.UnsubscribeChannels(nil), errNoSubscriptionsSupplied, "Unsubscribing from nil should error") - assert.ErrorIs(t, ws.UnsubscribeChannels(subs), ErrSubscriptionNotFound, "Unsubscribing should error when not subscribed") + assert.ErrorIs(t, ws.UnsubscribeChannels(subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error") assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") - byDefKey := ws.GetSubscription(subscription.DefaultKey{Channel: "TestSub"}) - if assert.NotNil(t, byDefKey, "GetSubscription by default key should find a channel") { - assert.Equal(t, "TestSub", byDefKey.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") - assert.NotSame(t, byDefKey, ws.subscriptions["TestSub"], "GetSubscription returns a fresh pointer") + bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"}) + if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { + assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") + assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer") } if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") { assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") @@ -485,7 +501,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { } assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, ws.SubscribeToChannels(subs), ErrSubscribedAlready, "Subscribe should error when already subscribed") + assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") assert.ErrorIs(t, ws.SubscribeToChannels(nil), errNoSubscriptionsSupplied, "Subscribe to nil should error") assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error") } @@ -503,48 +519,18 @@ func TestResubscribe(t *testing.T) { err = ws.Setup(defaultSetup) assert.NoError(t, err, "WS Setup should not error") - fnSub := func(subs []subscription.Subscription) error { - ws.AddSuccessfulSubscriptions(subs...) - return nil - } - fnUnsub := func(unsubs []subscription.Subscription) error { - ws.RemoveSubscriptions(unsubs...) - return nil - } - ws.Subscriber = fnSub - ws.Unsubscriber = fnUnsub + ws.Subscriber = currySimpleSub(ws) + ws.Unsubscriber = currySimpleUnsub(ws) - channel := []subscription.Subscription{{Channel: "resubTest"}} + channel := subscription.List{{Channel: "resubTest"}} - assert.ErrorIs(t, ws.ResubscribeToChannel(&channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet") + assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error") - assert.NoError(t, ws.ResubscribeToChannel(&channel[0]), "Resubscribe should not error now the channel is subscribed") + assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed") } -// TestSubscriptionState tests Subscription state changes -func TestSubscriptionState(t *testing.T) { - t.Parallel() - ws := NewWebsocket() - - c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState} - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error") - - assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") - found := ws.GetSubscription(42) - assert.NotNil(t, found, "Should find the subscription") - assert.Equal(t, subscription.SubscribingState, found.State, "Subscription should be Subscribing") - assert.ErrorIs(t, ws.AddSubscription(c), ErrSubscribedAlready, "Adding an already existing sub should error") - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.SubscribingState), ErrChannelInStateAlready, "Setting Same state should error") - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState+1), errInvalidChannelState, "Setting an invalid state should error") - - ws.AddSuccessfulSubscriptions(*c) - found = ws.GetSubscription(42) - assert.NotNil(t, found, "Should find the subscription") - assert.Equal(t, subscription.SubscribedState, found.State, "Subscription should be subscribed state") - - assert.NoError(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), "Setting Unsub state should not error") - found = ws.GetSubscription(42) - assert.Equal(t, subscription.UnsubscribingState, found.State, "Subscription should be unsubscribing state") +func TestAddSubscription(t *testing.T) { + t.Fatal("Not implemented, along with others") } // TestRemoveSubscriptions tests removing a subscription @@ -553,10 +539,11 @@ func TestRemoveSubscriptions(t *testing.T) { ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Unite!"} - assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") + require.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") assert.NotNil(t, ws.GetSubscription(42), "Added subscription should be findable") - ws.RemoveSubscriptions(*c) + err := ws.RemoveSubscriptions(subscription.List{c}) + require.NoError(t, err, "RemoveSubscriptions must not error") assert.Nil(t, ws.GetSubscription(42), "Remove should have removed the sub") } @@ -565,10 +552,7 @@ func TestConnectionMonitorNoConnection(t *testing.T) { t.Parallel() ws := NewWebsocket() ws.connectionMonitorDelay = 500 - ws.DataHandler = make(chan interface{}, 1) - ws.ShutdownC = make(chan struct{}, 1) ws.exchangeName = "hello" - ws.Wg = &sync.WaitGroup{} ws.setEnabled(true) err := ws.connectionMonitor() require.NoError(t, err, "connectionMonitor must not error") @@ -581,31 +565,24 @@ func TestConnectionMonitorNoConnection(t *testing.T) { func TestGetSubscription(t *testing.T) { t.Parallel() assert.Nil(t, (*Websocket).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil") - assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub map should return nil") - w := Websocket{ - subscriptions: subscriptionMap{ - 42: { - Channel: "hello3", - }, - }, - } - assert.Nil(t, w.GetSubscription(43), "GetSubscription with an invalid key should return nil") - c := w.GetSubscription(42) - if assert.NotNil(t, c, "GetSubscription with an valid key should return a channel") { - assert.Equal(t, "hello3", c.Channel, "GetSubscription should return the correct channel details") - } + assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil") + w := NewWebsocket() + assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil") + s := &subscription.Subscription{Key: 42, Channel: "hello3"} + w.AddSubscription(s) + assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store") } // TestGetSubscriptions logic test func TestGetSubscriptions(t *testing.T) { t.Parallel() - w := Websocket{ - subscriptions: subscriptionMap{ - 42: { - Channel: "hello3", - }, - }, - } + assert.Nil(t, (*Websocket).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil") + assert.Nil(t, (&Websocket{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil") + w := NewWebsocket() + w.AddSubscriptions(subscription.List{ + {Key: 42, Channel: "hello3"}, + {Key: 45, Channel: "hello4"}, + }) assert.Equal(t, "hello3", w.GetSubscriptions()[0].Channel, "GetSubscriptions should return the correct channel details") } @@ -942,47 +919,40 @@ func TestGetChannelDifference(t *testing.T) { t.Parallel() web := Websocket{} - newChans := []subscription.Subscription{ - { - Channel: "Test1", - }, - { - Channel: "Test2", - }, - { - Channel: "Test3", - }, + newChans := subscription.List{ + {Channel: "Test1"}, + {Channel: "Test2"}, + {Channel: "Test3"}, } subs, unsubs := web.GetChannelDifference(newChans) - assert.Len(t, subs, 3, "Should get the correct number of subs") - assert.Empty(t, unsubs, "Should get the correct number of unsubs") + assert.Implements(t, (*subscription.MatchableKey)(nil), subs[0].Key, "Sub key must be matchable") + assert.Equal(t, 3, len(subs), "Should get the correct number of subs") + assert.Empty(t, unsubs, "Should get no unsubs") - web.AddSuccessfulSubscriptions(subs...) + for _, s := range subs { + s.SetState(subscription.SubscribedState) + } - flushedSubs := []subscription.Subscription{ - { - Channel: "Test2", - }, + web.AddSubscriptions(subs) + + flushedSubs := subscription.List{ + {Channel: "Test2"}, } subs, unsubs = web.GetChannelDifference(flushedSubs) - assert.Empty(t, subs, "Should get the correct number of subs") - assert.Len(t, unsubs, 2, "Should get the correct number of unsubs") + assert.Empty(t, subs, "Should get no subs") + assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") - flushedSubs = []subscription.Subscription{ - { - Channel: "Test2", - }, - { - Channel: "Test4", - }, + flushedSubs = subscription.List{ + {Channel: "Test2"}, + {Channel: "Test4"}, } subs, unsubs = web.GetChannelDifference(flushedSubs) - if assert.Len(t, subs, 1, "Should get the correct number of subs") { + if assert.Equal(t, 1, len(subs), "Should get the correct number of subs") { assert.Equal(t, "Test4", subs[0].Channel, "Should subscribe to the right channel") } - if assert.Len(t, unsubs, 2, "Should get the correct number of unsubs") { + if assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") { sort.Slice(unsubs, func(i, j int) bool { return unsubs[i].Channel <= unsubs[j].Channel }) assert.Equal(t, "Test1", unsubs[0].Channel, "Should unsubscribe from the right channels") assert.Equal(t, "Test3", unsubs[1].Channel, "Should unsubscribe from the right channels") @@ -992,23 +962,23 @@ func TestGetChannelDifference(t *testing.T) { // GenSubs defines a theoretical exchange with pair management type GenSubs struct { EnabledPairs currency.Pairs - subscribos []subscription.Subscription - unsubscribos []subscription.Subscription + subscribos subscription.List + unsubscribos subscription.List } // generateSubs default subs created from the enabled pairs list -func (g *GenSubs) generateSubs() ([]subscription.Subscription, error) { - superduperchannelsubs := make([]subscription.Subscription, len(g.EnabledPairs)) +func (g *GenSubs) generateSubs() (subscription.List, error) { + superduperchannelsubs := make(subscription.List, len(g.EnabledPairs)) for i := range g.EnabledPairs { - superduperchannelsubs[i] = subscription.Subscription{ + superduperchannelsubs[i] = &subscription.Subscription{ Channel: "TEST:" + strconv.FormatInt(int64(i), 10), - Pair: g.EnabledPairs[i], + Pairs: currency.Pairs{g.EnabledPairs[i]}, } } return superduperchannelsubs, nil } -func (g *GenSubs) SUBME(subs []subscription.Subscription) error { +func (g *GenSubs) SUBME(subs subscription.List) error { if len(subs) == 0 { return errors.New("WOW") } @@ -1016,7 +986,7 @@ func (g *GenSubs) SUBME(subs []subscription.Subscription) error { return nil } -func (g *GenSubs) UNSUBME(unsubs []subscription.Subscription) error { +func (g *GenSubs) UNSUBME(unsubs subscription.List) error { if len(unsubs) == 0 { return errors.New("WOW") } @@ -1043,82 +1013,66 @@ func TestFlushChannels(t *testing.T) { err = dodgyWs.FlushChannels() assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") - w := Websocket{ - connector: connect, - ShutdownC: make(chan struct{}), - Subscriber: newgen.SUBME, - Unsubscriber: newgen.UNSUBME, - Wg: new(sync.WaitGroup), - features: &protocol.Features{ - // No features - }, - trafficTimeout: time.Second * 30, // Added for when we utilise connect() - // in FlushChannels() so the traffic monitor doesn't time out and turn - // this to an unconnected state - } + w := NewWebsocket() + w.connector = connect + w.Subscriber = newgen.SUBME + w.Unsubscriber = newgen.UNSUBME + // Added for when we utilise connect() in FlushChannels() so the traffic monitor doesn't time out and turn this to an unconnected state + w.trafficTimeout = time.Second * 30 + w.setEnabled(true) w.setState(connected) - problemFunc := func() ([]subscription.Subscription, error) { + problemFunc := func() (subscription.List, error) { return nil, errDastardlyReason } - noSub := func() ([]subscription.Subscription, error) { + noSub := func() (subscription.List, error) { return nil, nil } // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} - w.GenerateSubs = func() ([]subscription.Subscription, error) { - return []subscription.Subscription{{Channel: "test"}}, nil + w.GenerateSubs = func() (subscription.List, error) { + return subscription.List{{Channel: "test"}}, nil } err = w.FlushChannels() - assert.NoError(t, err, "FlushChannels should not error") + require.NoError(t, err, "Flush Channels must not error") w.features.FullPayloadSubscribe = true w.GenerateSubs = problemFunc err = w.FlushChannels() // error on full subscribeToChannels - assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs") w.GenerateSubs = noSub - err = w.FlushChannels() // No subs to unsub - assert.NoError(t, err, "FlushChannels should not error") + err = w.FlushChannels() // No subs to sub + assert.NoError(t, err, "Flush Channels should not error") w.GenerateSubs = newgen.generateSubs subs, err := w.GenerateSubs() require.NoError(t, err, "GenerateSubs must not error") - - w.AddSuccessfulSubscriptions(subs...) + for _, s := range subs { + s.SetState(subscription.SubscribedState) + } + w.AddSubscriptions(subs) err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") w.features.FullPayloadSubscribe = false w.features.Subscribe = true - w.GenerateSubs = problemFunc - err = w.FlushChannels() - assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly") - w.GenerateSubs = newgen.generateSubs - err = w.FlushChannels() - assert.NoError(t, err, "FlushChannels should not error") - w.subscriptionMutex.Lock() - w.subscriptions = subscriptionMap{ - 41: { - Key: 41, - Channel: "match channel", - Pair: currency.NewPair(currency.BTC, currency.AUD), - }, - 42: { - Key: 42, - Channel: "unsub channel", - Pair: currency.NewPair(currency.THETA, currency.USDT), - }, - } - w.subscriptionMutex.Unlock() - - err = w.FlushChannels() - assert.NoError(t, err, "FlushChannels should not error") + w.subscriptions = subscription.NewStore() + w.subscriptions.Add(&subscription.Subscription{ + Key: 41, + Channel: "match channel", + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.AUD)}, + }) + w.subscriptions.Add(&subscription.Subscription{ + Key: 42, + Channel: "unsub channel", + Pairs: currency.Pairs{currency.NewPair(currency.THETA, currency.USDT)}, + }) err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") @@ -1131,9 +1085,7 @@ func TestFlushChannels(t *testing.T) { func TestDisable(t *testing.T) { t.Parallel() - w := Websocket{ - ShutdownC: make(chan struct{}), - } + w := NewWebsocket() w.setEnabled(true) w.setState(connected) require.NoError(t, w.Disable(), "Disable must not error") @@ -1142,16 +1094,11 @@ func TestDisable(t *testing.T) { func TestEnable(t *testing.T) { t.Parallel() - w := Websocket{ - connector: connect, - Wg: new(sync.WaitGroup), - ShutdownC: make(chan struct{}), - GenerateSubs: func() ([]subscription.Subscription, error) { - return []subscription.Subscription{{Channel: "test"}}, nil - }, - Subscriber: func([]subscription.Subscription) error { return nil }, - } - + w := NewWebsocket() + w.connector = connect + w.Subscriber = func(subscription.List) error { return nil } + w.Unsubscriber = func(subscription.List) error { return nil } + w.GenerateSubs = func() (subscription.List, error) { return nil, nil } require.NoError(t, w.Enable(), "Enable must not error") assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } @@ -1262,15 +1209,21 @@ func TestCheckSubscriptions(t *testing.T) { ws.MaxSubscriptionsPerConnection = 1 - err = ws.checkSubscriptions([]subscription.Subscription{{}, {}}) + err = ws.checkSubscriptions(subscription.List{{}, {}}) + assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly when subscriptions is empty") + + ws.subscriptions = subscription.NewStore() + err = ws.checkSubscriptions(subscription.List{{}, {}}) assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 - ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}} - err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}}) - assert.ErrorIs(t, err, ErrSubscribedAlready, "checkSubscriptions should error correctly") + ws.subscriptions = subscription.NewStore() + err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"}) + require.NoError(t, err, "Add subscription must not error") + err = ws.checkSubscriptions(subscription.List{{Key: 42, Channel: "test"}}) + assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly") - err = ws.checkSubscriptions([]subscription.Subscription{{}}) + err = ws.checkSubscriptions(subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error") } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index a783d585a4e..353003d0a21 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -22,8 +22,6 @@ const ( UnhandledMessage = " - Unhandled websocket message: " ) -type subscriptionMap map[any]*subscription.Subscription - const ( uninitialised uint32 = iota disconnected @@ -53,19 +51,14 @@ type Websocket struct { connector func() error subscriptionMutex sync.RWMutex - subscriptions subscriptionMap - Subscribe chan []subscription.Subscription - Unsubscribe chan []subscription.Subscription - - // Subscriber function for package defined websocket subscriber - // functionality - Subscriber func([]subscription.Subscription) error - // Unsubscriber function for packaged defined websocket unsubscriber - // functionality - Unsubscriber func([]subscription.Subscription) error - // GenerateSubs function for package defined websocket generate - // subscriptions functionality - GenerateSubs func() ([]subscription.Subscription, error) + subscriptions *subscription.Store + + // Subscriber function for exchange specific subscribe implementation + Subscriber func(subscription.List) error + // Subscriber function for exchange specific unsubscribe implementation + Unsubscriber func(subscription.List) error + // GenerateSubs function for exchange specific generating subscriptions from Features.Subscriptions, Pairs and Assets + GenerateSubs func() (subscription.List, error) DataHandler chan interface{} ToRoutine chan interface{} @@ -74,7 +67,7 @@ type Websocket struct { // shutdown synchronises shutdown event across routines ShutdownC chan struct{} - Wg *sync.WaitGroup + Wg sync.WaitGroup // Orderbook is a local buffer of orderbooks Orderbook buffer.Orderbook @@ -112,9 +105,9 @@ type WebsocketSetup struct { RunningURL string RunningURLAuth string Connector func() error - Subscriber func([]subscription.Subscription) error - Unsubscriber func([]subscription.Subscription) error - GenerateSubscriptions func() ([]subscription.Subscription, error) + Subscriber func(subscription.List) error + Unsubscriber func(subscription.List) error + GenerateSubscriptions func() (subscription.List, error) Features *protocol.Features // Local orderbook buffer config values diff --git a/exchanges/subscription/list.go b/exchanges/subscription/list.go new file mode 100644 index 00000000000..4d3cef229e8 --- /dev/null +++ b/exchanges/subscription/list.go @@ -0,0 +1,16 @@ +package subscription + +import "slices" + +// List is a container of subscription pointers +type List []*Subscription + +// Strings returns a sorted slice of subscriptions +func (l List) Strings() []string { + s := make([]string, len(l)) + for i := range l { + s[i] = l[i].String() + } + slices.Sort(s) + return s +} diff --git a/exchanges/subscription/store.go b/exchanges/subscription/store.go new file mode 100644 index 00000000000..18db7fe338b --- /dev/null +++ b/exchanges/subscription/store.go @@ -0,0 +1,176 @@ +package subscription + +import ( + "maps" + "sync" + + "github.com/thrasher-corp/gocryptotrader/common" +) + +// Store is a container of subscription pointers +type Store struct { + m map[any]*Subscription + mu sync.RWMutex +} + +// NewStore creates a ready to use store and should always be used +func NewStore() *Store { + return &Store{ + m: map[any]*Subscription{}, + } +} + +// NewStoreFromList creates a Store from a List +func NewStoreFromList(l List) (*Store, error) { + s := NewStore() + for _, sub := range l { + if err := s.add(sub); err != nil { + return nil, err + } + } + return s, nil +} + +// Add adds a subscription to the store +// Key can be already set; if omitted EnsureKeyed will be used +// Errors if it already exists +func (s *Store) Add(sub *Subscription) error { + if s == nil || sub == nil { + return common.ErrNilPointer + } + s.mu.Lock() + defer s.mu.Unlock() + return s.add(sub) +} + +// Add adds a subscription to the store +// This method provides no locking protection +func (s *Store) add(sub *Subscription) error { + key := sub.EnsureKeyed() + if found := s.get(key); found != nil { + return ErrDuplicate + } + s.m[key] = sub + return nil +} + +// Get returns a pointer to a subscription or nil if not found +// If key implements MatchableKey then key.Match will be used +func (s *Store) Get(key any) *Subscription { + if s == nil { + return nil + } + s.mu.RLock() + defer s.mu.RUnlock() + return s.get(key) +} + +// get returns a pointer to subscription or nil if not found +// If the key passed in is a Subscription then its Key will be used; which may be a pointer to itself. +// If key implements MatchableKey then key.Match will be used; Note that *Subscription implements MatchableKey +// This method provides no locking protection +// returned subscriptions are implicitly guaranteed to have a Key +func (s *Store) get(key any) *Subscription { + switch v := key.(type) { + case Subscription: + key = v.EnsureKeyed() + case *Subscription: + key = v.EnsureKeyed() + } + + switch v := key.(type) { + case MatchableKey: + return s.match(v) + default: + return s.m[v] + } +} + +// Remove removes a subscription from the store +func (s *Store) Remove(sub *Subscription) error { + if s == nil || sub == nil { + return common.ErrNilPointer + } + s.mu.Lock() + defer s.mu.Unlock() + + if found := s.get(sub); found != nil { + delete(s.m, found.Key) + return nil + } + + return ErrNotFound +} + +// List returns a slice of Subscriptions pointers +func (s *Store) List() List { + if s == nil { + return List{} + } + s.mu.RLock() + defer s.mu.RUnlock() + subs := make(List, 0, len(s.m)) + for _, s := range s.m { + subs = append(subs, s) + } + return subs +} + +// Clear empties the subscription store +func (s *Store) Clear() { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + clear(s.m) +} + +// match returns the first subscription which matches the Key's Asset, Channel and Pairs +// If the key provided has: +// 1) Empty pairs then only Subscriptions without pairs will be considered +// 2) >=1 pairs then Subscriptions which contain all the pairs will be considered +// This method provides no locking protection +func (s *Store) match(key MatchableKey) *Subscription { + for anyKey, s := range s.m { + if key.Match(anyKey) { + return s + } + } + return nil +} + +// Diff returns a list of the added and missing subs from a new list +// The store Diff is invoked upon is read-lock protected +// The new store is assumed to be a new instance and enjoys no locking protection +func (s *Store) Diff(compare List) (added, removed List) { + if s == nil { + return + } + s.mu.RLock() + defer s.mu.RUnlock() + removedMap := maps.Clone(s.m) + for _, sub := range compare { + if found := s.get(sub); found != nil { + delete(removedMap, found.Key) + } else { + added = append(added, sub) + } + } + + for _, c := range removedMap { + removed = append(removed, c) + } + + return +} + +// Len returns the number of subscriptions +func (s *Store) Len() int { + if s == nil { + return 0 + } + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.m) +} diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 874822ba79a..3e340ebbc1c 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -1,92 +1,121 @@ package subscription import ( - "encoding/json" + "errors" "fmt" + "sync" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" ) -// DefaultKey is the fallback key for AddSuccessfulSubscriptions -type DefaultKey struct { - Channel string - Pair currency.Pair - Asset asset.Item -} - -// State tracks the status of a subscription channel -type State uint8 +// State constants +const ( + InactiveState State = iota + SubscribingState + SubscribedState + UnsubscribingState +) +// Ticker constants const ( - UnknownState State = iota // UnknownState subscription state is not registered, but doesn't imply Inactive - SubscribingState // SubscribingState means channel is in the process of subscribing - SubscribedState // SubscribedState means the channel has finished a successful and acknowledged subscription - UnsubscribingState // UnsubscribingState means the channel has started to unsubscribe, but not yet confirmed + TickerChannel = "ticker" + OrderbookChannel = "orderbook" + CandlesChannel = "candles" + AllOrdersChannel = "allOrders" + AllTradesChannel = "allTrades" + MyTradesChannel = "myTrades" + MyOrdersChannel = "myOrders" +) - TickerChannel = "ticker" // TickerChannel Subscription Type - OrderbookChannel = "orderbook" // OrderbookChannel Subscription Type - CandlesChannel = "candles" // CandlesChannel Subscription Type - AllOrdersChannel = "allOrders" // AllOrdersChannel Subscription Type - AllTradesChannel = "allTrades" // AllTradesChannel Subscription Type - MyTradesChannel = "myTrades" // MyTradesChannel Subscription Type - MyOrdersChannel = "myOrders" // MyOrdersChannel Subscription Type +// Public errors +var ( + ErrNotFound = errors.New("subscription not found") + ErrNotSinglePair = errors.New("only single pair subscriptions expected") + ErrInStateAlready = errors.New("subscription already in state") + ErrInvalidState = errors.New("invalid subscription state") + ErrDuplicate = errors.New("duplicate subscription") ) +// State tracks the status of a subscription channel +type State uint8 + // Subscription container for streaming subscriptions type Subscription struct { Enabled bool `json:"enabled"` Key any `json:"-"` Channel string `json:"channel,omitempty"` - Pair currency.Pair `json:"pair,omitempty"` + Pairs currency.Pairs `json:"pairs,omitempty"` Asset asset.Item `json:"asset,omitempty"` Params map[string]interface{} `json:"params,omitempty"` - State State `json:"-"` Interval kline.Interval `json:"interval,omitempty"` Levels int `json:"levels,omitempty"` Authenticated bool `json:"authenticated,omitempty"` + state State + m sync.RWMutex } -// MarshalJSON generates a JSON representation of a Subscription, specifically for config writing -// The only reason it exists is to avoid having to make Pair a pointer, since that would be generally painful -// If Pair becomes a pointer, this method is redundant and should be removed -func (s *Subscription) MarshalJSON() ([]byte, error) { - // None of the usual type embedding tricks seem to work for not emitting an nil Pair - // The embedded type's Pair always fills the empty value - type MaybePair struct { - Enabled bool `json:"enabled"` - Channel string `json:"channel,omitempty"` - Asset asset.Item `json:"asset,omitempty"` - Params map[string]interface{} `json:"params,omitempty"` - Interval kline.Interval `json:"interval,omitempty"` - Levels int `json:"levels,omitempty"` - Authenticated bool `json:"authenticated,omitempty"` - Pair *currency.Pair `json:"pair,omitempty"` - } - - k := MaybePair{s.Enabled, s.Channel, s.Asset, s.Params, s.Interval, s.Levels, s.Authenticated, nil} - if s.Pair != currency.EMPTYPAIR { - k.Pair = &s.Pair - } - - return json.Marshal(k) +// MatchableKey interface should be implemented by Key types which want a more complex matching than a simple key equality check +type MatchableKey interface { + Match(any) bool } // String implements the Stringer interface for Subscription, giving a human representation of the subscription func (s *Subscription) String() string { - return fmt.Sprintf("%s %s %s", s.Channel, s.Asset, s.Pair) + return fmt.Sprintf("%s %s %s", s.Channel, s.Asset, s.Pairs) +} + +// State returns the subscription state +func (s *Subscription) State() State { + s.m.RLock() + defer s.m.RUnlock() + return s.state +} + +// SetState sets the subscription state +// Errors if already in that state or the new state is not valid +func (s *Subscription) SetState(state State) error { + s.m.Lock() + defer s.m.Unlock() + if state == s.state { + return ErrInStateAlready + } + if state > UnsubscribingState { + return ErrInvalidState + } + s.state = state + return nil } // EnsureKeyed sets the default key on a channel if it doesn't have one // Returns key for convenience func (s *Subscription) EnsureKeyed() any { if s.Key == nil { - s.Key = DefaultKey{ - Channel: s.Channel, - Asset: s.Asset, - Pair: s.Pair, - } + s.Key = s } return s.Key } + +// Match returns if the two keys match Channels, Assets, Pairs, Interval and Levels: +// Key Pairs comparison: +// 1) Empty pairs then only Subscriptions without pairs match +// 2) >=1 pairs then Subscriptions which contain all the pairs match +// Such that a subscription for all enabled pairs will be matched when seaching for any one pair +func (s *Subscription) Match(key any) bool { + b, ok := key.(*Subscription) + switch { + case !ok, + s.Channel != b.Channel, + s.Asset != b.Asset, + len(b.Pairs) == 0 && len(s.Pairs) != 0, + // len(b.Pairs) == 0 && len(s.Pairs) == 0: Okay; continue to next non-pairs check + len(b.Pairs) != 0 && len(s.Pairs) == 0, + len(b.Pairs) != 0 && s.Pairs.ContainsAll(b.Pairs, true) != nil, + s.Levels != b.Levels, + s.Interval != b.Interval: + return false + } + + return true +} diff --git a/exchanges/subscription/subscription_test.go b/exchanges/subscription/subscription_test.go index 4f9a97ab979..38cabb8694a 100644 --- a/exchanges/subscription/subscription_test.go +++ b/exchanges/subscription/subscription_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" @@ -13,28 +14,25 @@ import ( // TestEnsureKeyed logic test func TestEnsureKeyed(t *testing.T) { t.Parallel() - c := Subscription{ + s := &Subscription{ Channel: "candles", Asset: asset.Spot, - Pair: currency.NewPair(currency.BTC, currency.USDT), + Pairs: []currency.Pair{currency.NewPair(currency.BTC, currency.USDT)}, } - k1, ok := c.EnsureKeyed().(DefaultKey) - if assert.True(t, ok, "EnsureKeyed should return a DefaultKey") { - assert.Exactly(t, k1, c.Key, "EnsureKeyed should set the same key") - assert.Equal(t, k1.Channel, c.Channel, "DefaultKey channel should be correct") - assert.Equal(t, k1.Asset, c.Asset, "DefaultKey asset should be correct") - assert.Equal(t, k1.Pair, c.Pair, "DefaultKey currency should be correct") + k1, ok := s.EnsureKeyed().(*Subscription) + if assert.True(t, ok, "EnsureKeyed should return a *Subscription") { + assert.Same(t, k1, s, "Key should point to the same struct") } type platypus string - c = Subscription{ + s = &Subscription{ Key: platypus("Gerald"), Channel: "orderbook", Asset: asset.Margin, - Pair: currency.NewPair(currency.ETH, currency.USDC), + Pairs: []currency.Pair{currency.NewPair(currency.ETH, currency.USDC)}, } - k2, ok := c.EnsureKeyed().(platypus) + k2, ok := s.EnsureKeyed().(platypus) if assert.True(t, ok, "EnsureKeyed should return a platypus") { - assert.Exactly(t, k2, c.Key, "EnsureKeyed should set the same key") + assert.Exactly(t, k2, s.Key, "ensureKeyed should set the same key") assert.EqualValues(t, "Gerald", k2, "key should have the correct value") } } @@ -50,11 +48,25 @@ func TestMarshaling(t *testing.T) { assert.NoError(t, err, "Marshalling should not error") assert.Equal(t, `{"enabled":true,"channel":"orderbook","interval":"5m","levels":4}`, string(j), "Marshalling should be clean and concise") - j, err = json.Marshal(&Subscription{Enabled: true, Channel: OrderbookChannel, Interval: kline.FiveMin, Levels: 4, Pair: currency.NewPair(currency.BTC, currency.USDT)}) + j, err = json.Marshal(&Subscription{Enabled: true, Channel: OrderbookChannel, Interval: kline.FiveMin, Levels: 4, Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USDT)}}) assert.NoError(t, err, "Marshalling should not error") - assert.Equal(t, `{"enabled":true,"channel":"orderbook","interval":"5m","levels":4,"pair":"BTCUSDT"}`, string(j), "Marshalling should be clean and concise") + assert.Equal(t, `{"enabled":true,"channel":"orderbook","pairs":"BTCUSDT","interval":"5m","levels":4}`, string(j), "Marshalling should be clean and concise") j, err = json.Marshal(&Subscription{Enabled: true, Channel: MyTradesChannel, Authenticated: true}) assert.NoError(t, err, "Marshalling should not error") assert.Equal(t, `{"enabled":true,"channel":"myTrades","authenticated":true}`, string(j), "Marshalling should be clean and concise") } + +// TestSetState tests Subscription state changes +func TestSetState(t *testing.T) { + t.Parallel() + + s := &Subscription{Key: 42, Channel: "Gophers"} + assert.Equal(t, InactiveState, s.State(), "State should start as unknown") + require.NoError(t, s.SetState(SubscribingState), "SetState should not error") + assert.Equal(t, SubscribingState, s.State(), "State should be set correctly") + assert.ErrorIs(t, s.SetState(SubscribingState), ErrInStateAlready, "SetState should error on same state") + assert.ErrorIs(t, s.SetState(UnsubscribingState+1), ErrInvalidState, "Setting an invalid state should error") + require.NoError(t, s.SetState(UnsubscribingState), "SetState should not error") + assert.Equal(t, UnsubscribingState, s.State(), "State should be set correctly") +} diff --git a/internal/testing/exchange/exchange.go b/internal/testing/exchange/exchange.go index b8b20dfb624..22e5fb28f13 100644 --- a/internal/testing/exchange/exchange.go +++ b/internal/testing/exchange/exchange.go @@ -148,13 +148,15 @@ func SetupWs(tb testing.TB, e exchange.IBotExchange) { } b := e.GetBase() - if !b.Websocket.IsEnabled() { + w, err := b.GetWebsocket() + if err != nil || !b.Websocket.IsEnabled() { tb.Skip("Websocket not enabled") } - if b.Websocket.IsConnected() { + if w.IsConnected() { return } - err := b.Websocket.Connect() + + err = w.Connect() require.NoError(tb, err, "WsConnect should not error") setupWsOnce[e] = true diff --git a/internal/testing/subscriptions/subscriptions.go b/internal/testing/subscriptions/subscriptions.go new file mode 100644 index 00000000000..1604be279b0 --- /dev/null +++ b/internal/testing/subscriptions/subscriptions.go @@ -0,0 +1,27 @@ +package subscriptionstest + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" +) + +func Equal(tb testing.TB, a, b subscription.List) { + tb.Helper() + s, err := subscription.NewStoreFromList(a) + require.NoError(t, err, "NewStoreFromList must not error") + added, missing := s.Diff(b) + if len(added) > 0 || len(missing) > 0 { + fail := "Differences:" + if len(added) > 0 { + fail = fail + "\n + " + strings.Join(added.Strings(), "\n + ") + } + if len(missing) > 0 { + fail = fail + "\n - " + strings.Join(missing.Strings(), "\n - ") + } + assert.Fail(tb, fail, "Subscriptions should be equal") + } +} From 2ca3de1de1a3b7f9b17a6ffddd9185622eaa5f60 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 19 Feb 2024 14:02:39 +0700 Subject: [PATCH 31/35] Subscriptions: Use Pairs in all exchanges Except Kraken, which needs atomicity to not collide with upcoming work --- exchanges/binanceus/binanceus_websocket.go | 2 +- exchanges/bitfinex/bitfinex_test.go | 26 ++--- exchanges/bitfinex/bitfinex_websocket.go | 95 +++++++++++++++---- exchanges/bithumb/bithumb_websocket.go | 4 +- exchanges/bitmex/bitmex_websocket.go | 4 +- exchanges/bitstamp/bitstamp_websocket.go | 4 +- exchanges/btcmarkets/btcmarkets_websocket.go | 12 +-- exchanges/btse/btse_websocket.go | 2 +- exchanges/coinbasepro/coinbasepro_test.go | 2 +- .../coinbasepro/coinbasepro_websocket.go | 6 +- exchanges/coinut/coinut_websocket.go | 6 +- exchanges/gateio/gateio_websocket.go | 6 +- .../gateio/gateio_ws_delivery_futures.go | 6 +- exchanges/gateio/gateio_ws_futures.go | 6 +- exchanges/gateio/gateio_ws_option.go | 8 +- exchanges/gemini/gemini_websocket.go | 10 +- exchanges/hitbtc/hitbtc_websocket.go | 8 +- exchanges/huobi/huobi_websocket.go | 2 +- exchanges/kucoin/kucoin_test.go | 2 +- exchanges/okcoin/okcoin_websocket.go | 10 +- exchanges/poloniex/poloniex_websocket.go | 6 +- 21 files changed, 142 insertions(+), 85 deletions(-) diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 14098c1139b..90d710997f5 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -560,7 +560,7 @@ subs: } subscriptions = append(subscriptions, subscription.Subscription{ Channel: lp.String() + channels[z], - Pair: pairs[y], + Pairs: currency.Pairs{pairs[y]}, Asset: asset.Spot, }) } diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index e7be0fbe707..4baace3186f 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1158,7 +1158,7 @@ func TestWsAuth(t *testing.T) { // See also TestSubscribeReq which covers key and symbol conversion func TestWsSubscribe(t *testing.T) { setupWs(t) - err := b.Subscribe([]subscription.Subscription{{Channel: wsTicker, Pair: currency.NewPair(currency.BTC, currency.USD), Asset: asset.Spot}}) + err := b.Subscribe([]subscription.Subscription{{Channel: wsTicker, Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USD)}, Asset: asset.Spot}}) assert.NoError(t, err, "Subrcribe should not error") catcher := func() (ok bool) { i := <-b.Websocket.DataHandler @@ -1171,7 +1171,7 @@ func TestWsSubscribe(t *testing.T) { assert.NoError(t, err, "GetSubscriptions should not error") assert.Len(t, subs, 1, "We should only have 1 subscription; subID subscription should have been Removed by subscribeToChan") - err = b.Subscribe([]subscription.Subscription{{Channel: wsTicker, Pair: currency.NewPair(currency.BTC, currency.USD), Asset: asset.Spot}}) + err = b.Subscribe([]subscription.Subscription{{Channel: wsTicker, Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USD)}, Asset: asset.Spot}}) assert.ErrorIs(t, err, stream.ErrSubscriptionFailure, "Duplicate subscription should error correctly") catcher = func() bool { i := <-b.Websocket.DataHandler @@ -1202,7 +1202,7 @@ func TestWsSubscribe(t *testing.T) { err = b.Subscribe([]subscription.Subscription{{ Channel: wsTicker, - Pair: currency.NewPair(currency.BTC, currency.USD), + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USD)}, Asset: asset.Spot, Params: map[string]interface{}{"key": "tBTCUSD"}, }}) @@ -1214,7 +1214,7 @@ func TestWsSubscribe(t *testing.T) { func TestSubscribeReq(t *testing.T) { c := &subscription.Subscription{ Channel: wsCandles, - Pair: currency.NewPair(currency.BTC, currency.USD), + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USD)}, Asset: asset.MarginFunding, Params: map[string]interface{}{ CandlesPeriodKey: "30", @@ -1233,14 +1233,14 @@ func TestSubscribeReq(t *testing.T) { c = &subscription.Subscription{ Channel: wsBook, - Pair: currency.NewPair(currency.BTC, currency.DOGE), + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.DOGE)}, Asset: asset.Spot, } r, err = subscribeReq(c) assert.NoError(t, err, "subscribeReq should not error") assert.Equal(t, "tBTC:DOGE", r["symbol"], "symbol should use colon delimiter if a currency is > 3 chars") - c.Pair = currency.NewPair(currency.BTC, currency.LTC) + c.Pairs = currency.Pairs{currency.NewPair(currency.BTC, currency.LTC)} r, err = subscribeReq(c) assert.NoError(t, err, "subscribeReq should not error") assert.Equal(t, "tBTCLTC", r["symbol"], "symbol should not use colon delimiter if both currencies < 3 chars") @@ -1353,7 +1353,7 @@ func TestWsSubscribedResponse(t *testing.T) { } func TestWsOrderBook(t *testing.T) { - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Key: 23405, Asset: asset.Spot, Pair: btcusdPair, Channel: wsBook}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Key: 23405, Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsBook}) pressXToJSON := `[23405,[[38334303613,9348.8,0.53],[38334308111,9348.8,5.98979404],[38331335157,9344.1,1.28965787],[38334302803,9343.8,0.08230094],[38334279092,9343,0.8],[38334307036,9342.938663676,0.8],[38332749107,9342.9,0.2],[38332277330,9342.8,0.85],[38329406786,9342,0.1432012],[38332841570,9341.947288638,0.3],[38332163238,9341.7,0.3],[38334303384,9341.6,0.324],[38332464840,9341.4,0.5],[38331935870,9341.2,0.5],[38334312082,9340.9,0.02126899],[38334261292,9340.8,0.26763],[38334138680,9340.625455254,0.12],[38333896802,9339.8,0.85],[38331627527,9338.9,1.57863959],[38334186713,9338.9,0.26769],[38334305819,9338.8,2.999],[38334211180,9338.75285796,3.999],[38334310699,9337.8,0.10679883],[38334307414,9337.5,1],[38334179822,9337.1,0.26773],[38334306600,9336.659955102,1.79],[38334299667,9336.6,1.1],[38334306452,9336.6,0.13979771],[38325672859,9336.3,1.25],[38334311646,9336.2,1],[38334258509,9336.1,0.37],[38334310592,9336,1.79],[38334310378,9335.6,1.43],[38334132444,9335.2,0.26777],[38331367325,9335,0.07],[38334310703,9335,0.10680562],[38334298209,9334.7,0.08757301],[38334304857,9334.456899462,0.291],[38334309940,9334.088390727,0.0725],[38334310377,9333.7,1.2868],[38334297615,9333.607784,0.1108],[38334095188,9333.3,0.26785],[38334228913,9332.7,0.40861186],[38334300526,9332.363996604,0.3884],[38334310701,9332.2,0.10680562],[38334303548,9332.005382871,0.07],[38334311798,9331.8,0.41285228],[38334301012,9331.7,1.7952],[38334089877,9331.4,0.2679],[38321942150,9331.2,0.2],[38334310670,9330,1.069],[38334063096,9329.6,0.26796],[38334310700,9329.4,0.10680562],[38334310404,9329.3,1],[38334281630,9329.1,6.57150597],[38334036864,9327.7,0.26801],[38334310702,9326.6,0.10680562],[38334311799,9326.1,0.50220625],[38334164163,9326,0.219638],[38334309722,9326,1.5],[38333051682,9325.8,0.26807],[38334302027,9325.7,0.75],[38334203435,9325.366592,0.32397696],[38321967613,9325,0.05],[38334298787,9324.9,0.3],[38334301719,9324.8,3.6227592],[38331316716,9324.763454646,0.71442],[38334310698,9323.8,0.10680562],[38334035499,9323.7,0.23431017],[38334223472,9322.670551788,0.42150603],[38334163459,9322.560399006,0.143967],[38321825171,9320.8,2],[38334075805,9320.467496148,0.30772633],[38334075800,9319.916732238,0.61457592],[38333682302,9319.7,0.0011],[38331323088,9319.116771762,0.12913],[38333677480,9319,0.0199],[38334277797,9318.6,0.89],[38325235155,9318.041088,1.20249],[38334310910,9317.82382938,1.79],[38334311811,9317.2,0.61079138],[38334311812,9317.2,0.71937652],[38333298214,9317.1,50],[38334306359,9317,1.79],[38325531545,9316.382823951,0.21263],[38333727253,9316.3,0.02316372],[38333298213,9316.1,45],[38333836479,9316,2.135],[38324520465,9315.9,2.7681],[38334307411,9315.5,1],[38330313617,9315.3,0.84455],[38334077770,9315.294024,0.01248397],[38334286663,9315.294024,1],[38325533762,9315.290315394,2.40498],[38334310018,9315.2,3],[38333682617,9314.6,0.0011],[38334304794,9314.6,0.76364676],[38334304798,9314.3,0.69242113],[38332915733,9313.8,0.0199],[38334084411,9312.8,1],[38334311893,9350.1,-1.015],[38334302734,9350.3,-0.26737],[38334300732,9350.8,-5.2],[38333957619,9351,-0.90677089],[38334300521,9351,-1.6457],[38334301600,9351.012829557,-0.0523],[38334308878,9351.7,-2.5],[38334299570,9351.921544,-0.1015],[38334279367,9352.1,-0.26732],[38334299569,9352.411802928,-0.4036],[38334202773,9353.4,-0.02139404],[38333918472,9353.7,-1.96412776],[38334278782,9354,-0.26731],[38334278606,9355,-1.2785],[38334302105,9355.439221251,-0.79191542],[38313897370,9355.569409242,-0.43363],[38334292995,9355.584296,-0.0979],[38334216989,9355.8,-0.03686414],[38333894025,9355.9,-0.26721],[38334293798,9355.936691952,-0.4311],[38331159479,9356,-0.4204022],[38333918888,9356.1,-1.10885563],[38334298205,9356.4,-0.20124428],[38328427481,9356.5,-0.1],[38333343289,9356.6,-0.41034213],[38334297205,9356.6,-0.08835018],[38334277927,9356.741101161,-0.0737],[38334311645,9356.8,-0.5],[38334309002,9356.9,-5],[38334309736,9357,-0.10680107],[38334306448,9357.4,-0.18645275],[38333693302,9357.7,-0.2672],[38332815159,9357.8,-0.0011],[38331239824,9358.2,-0.02],[38334271608,9358.3,-2.999],[38334311971,9358.4,-0.55],[38333919260,9358.5,-1.9972841],[38334265365,9358.5,-1.7841],[38334277960,9359,-3],[38334274601,9359.020969848,-3],[38326848839,9359.1,-0.84],[38334291080,9359.247048,-0.16199869],[38326848844,9359.4,-1.84],[38333680200,9359.6,-0.26713],[38331326606,9359.8,-0.84454],[38334309738,9359.8,-0.10680107],[38331314707,9359.9,-0.2],[38333919803,9360.9,-1.41177599],[38323651149,9361.33417827,-0.71442],[38333656906,9361.5,-0.26705],[38334035500,9361.5,-0.40861586],[38334091886,9362.4,-6.85940815],[38334269617,9362.5,-4],[38323629409,9362.545858872,-2.40497],[38334309737,9362.7,-0.10680107],[38334312380,9362.7,-3],[38325280830,9362.8,-1.75123],[38326622800,9362.8,-1.05145],[38333175230,9363,-0.0011],[38326848745,9363.2,-0.79],[38334308960,9363.206775564,-0.12],[38333920234,9363.3,-1.25318113],[38326848843,9363.4,-1.29],[38331239823,9363.4,-0.02],[38333209613,9363.4,-0.26719],[38334299964,9364,-0.05583123],[38323470224,9364.161816648,-0.12912],[38334284711,9365,-0.21346019],[38334299594,9365,-2.6757062],[38323211816,9365.073132585,-0.21262],[38334312456,9365.1,-0.11167861],[38333209612,9365.2,-0.26719],[38327770474,9365.3,-0.0073],[38334298788,9365.3,-0.3],[38334075803,9365.409831204,-0.30772637],[38334309740,9365.5,-0.10680107],[38326608767,9365.7,-2.76809],[38333920657,9365.7,-1.25848083],[38329594226,9366.6,-0.02587],[38334311813,9366.7,-4.72290945],[38316386301,9367.39258128,-2.37581],[38334302026,9367.4,-4.5],[38334228915,9367.9,-0.81725458],[38333921381,9368.1,-1.72213641],[38333175678,9368.2,-0.0011],[38334301150,9368.2,-2.654604],[38334297208,9368.3,-0.78036466],[38334309739,9368.3,-0.10680107],[38331227515,9368.7,-0.02],[38331184470,9369,-0.003975],[38334203436,9369.319616,-0.32397695],[38334269964,9369.7,-0.5],[38328386732,9370,-4.11759935],[38332719555,9370,-0.025],[38333921935,9370.5,-1.2224398],[38334258511,9370.5,-0.35],[38326848842,9370.8,-0.34],[38333985038,9370.9,-0.8551502],[38334283018,9370.9,-1],[38326848744,9371,-1.34]],5]` err := b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1370,7 +1370,7 @@ func TestWsOrderBook(t *testing.T) { } func TestWsTradeResponse(t *testing.T) { - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: btcusdPair, Channel: wsTrades, Key: 18788}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTrades, Key: 18788}) pressXToJSON := `[18788,[[412685577,1580268444802,11.1998,176.3],[412685575,1580268444802,5,176.29952759],[412685574,1580268374717,1.99069999,176.41],[412685573,1580268374717,1.00930001,176.41],[412685572,1580268358760,0.9907,176.47],[412685571,1580268324362,0.5505,176.44],[412685570,1580268297270,-0.39040819,176.39],[412685568,1580268297270,-0.39780162,176.46475676],[412685567,1580268283470,-0.09,176.41],[412685566,1580268256536,-2.31310783,176.48],[412685565,1580268256536,-0.59669217,176.49],[412685564,1580268256536,-0.9902,176.49],[412685562,1580268194474,0.9902,176.55],[412685561,1580268186215,0.1,176.6],[412685560,1580268185964,-2.17096773,176.5],[412685559,1580268185964,-1.82903227,176.51],[412685558,1580268181215,2.098914,176.53],[412685557,1580268169844,16.7302,176.55],[412685556,1580268169844,3.25,176.54],[412685555,1580268155725,0.23576115,176.45],[412685553,1580268155725,3,176.44596249],[412685552,1580268155725,3.25,176.44],[412685551,1580268155725,5,176.44],[412685550,1580268155725,0.65830078,176.41],[412685549,1580268155725,0.45063807,176.41],[412685548,1580268153825,-0.67604704,176.39],[412685547,1580268145713,2.5883,176.41],[412685543,1580268087513,12.92927,176.33],[412685542,1580268087513,0.40083,176.33],[412685533,1580268005756,-0.17096773,176.32]]]` err := b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1379,7 +1379,7 @@ func TestWsTradeResponse(t *testing.T) { } func TestWsTickerResponse(t *testing.T) { - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: btcusdPair, Channel: wsTicker, Key: 11534}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTicker, Key: 11534}) pressXToJSON := `[11534,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err := b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1389,7 +1389,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: pair, Channel: wsTicker, Key: 123412}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: pair, Channel: wsTicker, Key: 123412}) pressXToJSON = `[123412,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1399,7 +1399,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: pair, Channel: wsTicker, Key: 123413}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: pair, Channel: wsTicker, Key: 123413}) pressXToJSON = `[123413,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1409,7 +1409,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: pair, Channel: wsTicker, Key: 123414}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: pair, Channel: wsTicker, Key: 123414}) pressXToJSON = `[123414,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1418,7 +1418,7 @@ func TestWsTickerResponse(t *testing.T) { } func TestWsCandleResponse(t *testing.T) { - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: btcusdPair, Channel: wsCandles, Key: 343351}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsCandles, Key: 343351}) pressXToJSON := `[343351,[[1574698260000,7379.785503,7383.8,7388.3,7379.785503,1.68829482]]]` err := b.wsHandleData([]byte(pressXToJSON)) if err != nil { diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index ae7cded7477..df97ad55beb 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -506,7 +506,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { chanID, err := jsonparser.GetInt(respRaw, "chanId") if err != nil { - return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, errParsingWSField, err, c.Channel, c.Pair) + return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, errParsingWSField, err, c.Channel, c.Pairs) } // Note: chanID's int type avoids conflicts with the string type subID key because of the type difference @@ -516,7 +516,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { b.Websocket.AddSuccessfulSubscriptions(*c) if b.Verbose { - log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pair, chanID) + log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pairs, chanID) } if !b.Websocket.Match.IncomingWithData("subscribe:"+subID, respRaw) { return fmt.Errorf("%v channel subscribe listener not found", subID) @@ -525,6 +525,10 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { } func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType string, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if eventType == wsChecksum { return b.handleWSChecksum(c, d) } @@ -533,6 +537,10 @@ func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType return nil } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } + switch c.Channel { case wsBook: return b.handleWSBookUpdate(c, d) @@ -548,6 +556,9 @@ func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType } func (b *Bitfinex) handleWSChecksum(c *subscription.Subscription, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } var token int if f, ok := d[2].(float64); !ok { return common.GetTypeAssertError("float64", d[2], "checksum") @@ -579,6 +590,12 @@ func (b *Bitfinex) handleWSChecksum(c *subscription.Subscription, d []interface{ } func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } var newOrderbook []WebsocketBook obSnapBundle, ok := d[1].([]interface{}) if !ok { @@ -632,7 +649,7 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac Amount: rateAmount}) } } - if err := b.WsInsertSnapshot(c.Pair, c.Asset, newOrderbook, fundingRate); err != nil { + if err := b.WsInsertSnapshot(c.Pairs[0], c.Asset, newOrderbook, fundingRate); err != nil { return fmt.Errorf("inserting snapshot error: %s", err) } @@ -664,7 +681,7 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac Amount: amountRate}) } - if err := b.WsUpdateOrderbook(c, c.Pair, c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil { + if err := b.WsUpdateOrderbook(c, c.Pairs[0], c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil { return fmt.Errorf("updating orderbook error: %s", err) } @@ -674,6 +691,12 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac } func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } candleBundle, ok := d[1].([]interface{}) if !ok || len(candleBundle) == 0 { return nil @@ -712,7 +735,7 @@ func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interf } klineData.Exchange = b.Name klineData.AssetType = c.Asset - klineData.Pair = c.Pair + klineData.Pair = c.Pairs[0] b.Websocket.DataHandler <- klineData } case float64: @@ -741,13 +764,19 @@ func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interf } klineData.Exchange = b.Name klineData.AssetType = c.Asset - klineData.Pair = c.Pair + klineData.Pair = c.Pairs[0] b.Websocket.DataHandler <- klineData } return nil } func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } tickerData, ok := d[1].([]interface{}) if !ok { return errors.New("type assertion for tickerData") @@ -755,7 +784,7 @@ func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interf t := &ticker.Price{ AssetType: c.Asset, - Pair: c.Pair, + Pair: c.Pairs[0], ExchangeName: b.Name, } @@ -821,6 +850,12 @@ func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interf } func (b *Bitfinex) handleWSTradesUpdate(c *subscription.Subscription, eventType string, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } if !b.IsSaveTradeDataEnabled() { return nil } @@ -936,7 +971,7 @@ func (b *Bitfinex) handleWSTradesUpdate(c *subscription.Subscription, eventType } trades[i] = trade.Data{ TID: strconv.FormatInt(tradeHolder[i].ID, 10), - CurrencyPair: c.Pair, + CurrencyPair: c.Pairs[0], Timestamp: time.UnixMilli(tradeHolder[i].Timestamp), Price: price, Amount: newAmount, @@ -1510,6 +1545,12 @@ func (b *Bitfinex) WsInsertSnapshot(p currency.Pair, assetType asset.Item, books // WsUpdateOrderbook updates the orderbook list, removing and adding to the // orderbook sides func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pair, assetType asset.Item, book []WebsocketBook, sequenceNo int64, fundingRate bool) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } orderbookUpdate := orderbook.Update{ Asset: assetType, Pair: p, @@ -1592,7 +1633,9 @@ func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pa if err = validateCRC32(ob, checkme.Token); err != nil { log.Errorf(log.WebsocketMgr, "%s websocket orderbook update error, will resubscribe orderbook: %v", b.Name, err) - b.resubOrderbook(c) + if e2 := b.resubOrderbook(c); e2 != nil { + log.Errorf(log.WebsocketMgr, "%s error resubscribing orderbook: %v", b.Name, e2) + } return err } } @@ -1603,8 +1646,15 @@ func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pa // resubOrderbook resubscribes the orderbook after a consistency error, probably a failed checksum, // which forces a fresh snapshot. If we don't do this the orderbook will keep erroring and drifting. // Flushing the orderbook happens immediately, but the ReSub itself is a go routine to avoid blocking the WS data channel -func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) { - if err := b.Websocket.Orderbook.FlushOrderbook(c.Pair, c.Asset); err != nil { +func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } + if err := b.Websocket.Orderbook.FlushOrderbook(c.Pairs[0], c.Asset); err != nil { + // Non-fatal error log.Errorf(log.ExchangeSys, "%s error flushing orderbook: %v", b.Name, err) } @@ -1645,7 +1695,7 @@ func (b *Bitfinex) GenerateDefaultSubscriptions() ([]subscription.Subscription, subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[j], - Pair: enabledPairs[k], + Pairs: enabledPairs[k], Params: params, Asset: assets[i], }) @@ -1684,7 +1734,7 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error { c := chans[0] req, err := subscribeReq(&c) if err != nil { - return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair) + return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs) } // subId is a single round-trip identifier that provides linking sub requests to chanIDs @@ -1699,7 +1749,7 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error { c.State = subscription.SubscribingState err = b.Websocket.AddSubscription(&c) if err != nil { - return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pair, err) + return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err) } // Always remove the temporary subscription keyed by subID @@ -1707,11 +1757,11 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error { respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("subscribe:"+subID, req) if err != nil { - return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair) + return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs) } if err = b.getErrResp(respRaw); err != nil { - wErr := fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair) + wErr := fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs) b.Websocket.DataHandler <- wErr return wErr } @@ -1721,6 +1771,13 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error { // subscribeReq returns a map of request params for subscriptions func subscribeReq(c *subscription.Subscription) (map[string]interface{}, error) { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } + pair := c.Pairs[0] req := map[string]interface{}{ "event": "subscribe", "channel": c.Channel, @@ -1743,13 +1800,13 @@ func subscribeReq(c *subscription.Subscription) (map[string]interface{}, error) prefix = "f" } - needsDelimiter := c.Pair.Len() > 6 + needsDelimiter := pair.Len() > 6 var formattedPair string if needsDelimiter { - formattedPair = c.Pair.Format(currency.PairFormat{Uppercase: true, Delimiter: ":"}).String() + formattedPair = pair.Format(currency.PairFormat{Uppercase: true, Delimiter: ":"}).String() } else { - formattedPair = currency.PairFormat{Uppercase: true}.Format(c.Pair) + formattedPair = currency.PairFormat{Uppercase: true}.Format(pair) } if c.Channel == wsCandles { diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 667ce131c71..990d9c767e4 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -184,7 +184,7 @@ func (b *Bithumb) GenerateSubscriptions() ([]subscription.Subscription, error) { for y := range channels { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[y], - Pair: pairs[x].Format(pFmt), + Pairs: pairs[x].Format(pFmt), Asset: asset.Spot, }) } @@ -203,7 +203,7 @@ func (b *Bithumb) Subscribe(channelsToSubscribe []subscription.Subscription) err } subs[channelsToSubscribe[i].Channel] = s } - s.Symbols = append(s.Symbols, channelsToSubscribe[i].Pair) + s.Symbols = append(s.Symbols, channelsToSubscribe[i].Pairs) } tSub, ok := subs["ticker"] diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index e1a475254f6..82c36a6e69c 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -571,7 +571,7 @@ func (b *Bitmex) GenerateDefaultSubscriptions() ([]subscription.Subscription, er } subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[z] + ":" + pFmt.Format(contracts[y]), - Pair: contracts[y], + Pairs: contracts[y], Asset: assets[x], }) } @@ -621,7 +621,7 @@ func (b *Bitmex) GenerateAuthenticatedSubscriptions() ([]subscription.Subscripti for j := range contracts { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i] + ":" + pFmt.Format(contracts[j]), - Pair: contracts[j], + Pairs: contracts[j], Asset: asset.PerpetualContract, }) } diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index 98aa6201df4..094ec21e785 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -246,7 +246,7 @@ func (b *Bitstamp) generateDefaultSubscriptions() ([]subscription.Subscription, subscriptions = append(subscriptions, subscription.Subscription{ Channel: defaultSubChannels[j] + "_" + p.String(), Asset: asset.Spot, - Pair: p, + Pairs: currency.Pairs{p}, }) } if b.Websocket.CanUseAuthenticatedEndpoints() { @@ -254,7 +254,7 @@ func (b *Bitstamp) generateDefaultSubscriptions() ([]subscription.Subscription, subscriptions = append(subscriptions, subscription.Subscription{ Channel: defaultAuthSubChannels[j] + "_" + p.String(), Asset: asset.Spot, - Pair: p, + Pairs: currency.Pairs{p}, Params: map[string]interface{}{ "auth": struct{}{}, }, diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 01ba1a64d4c..f067ff868f5 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -336,7 +336,7 @@ func (b *BTCMarkets) generateDefaultSubscriptions() ([]subscription.Subscription for j := range enabledCurrencies { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i], - Pair: enabledCurrencies[j], + Pairs: enabledCurrencies[j], Asset: asset.Spot, }) } @@ -370,10 +370,10 @@ func (b *BTCMarkets) Subscribe(subs []subscription.Subscription) error { authenticate = true } payload.Channels = append(payload.Channels, subs[i].Channel) - if subs[i].Pair.IsEmpty() { + if subs[i].Pairs.IsEmpty() { continue } - pair := subs[i].Pair.String() + pair := subs[i].Pairs.String() if common.StringDataCompare(payload.MarketIDs, pair) { continue } @@ -415,11 +415,11 @@ func (b *BTCMarkets) Unsubscribe(subs []subscription.Subscription) error { } for i := range subs { payload.Channels = append(payload.Channels, subs[i].Channel) - if subs[i].Pair.IsEmpty() { + if subs[i].Pairs.IsEmpty() { continue } - pair := subs[i].Pair.String() + pair := subs[i].Pairs.String() if common.StringDataCompare(payload.MarketIDs, pair) { continue } @@ -439,7 +439,7 @@ func (b *BTCMarkets) Unsubscribe(subs []subscription.Subscription) error { func (b *BTCMarkets) ReSubscribeSpecificOrderbook(pair currency.Pair) error { sub := []subscription.Subscription{{ Channel: wsOB, - Pair: pair, + Pairs: pair, Asset: asset.Spot, }} if err := b.Unsubscribe(sub); err != nil { diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index e32fb9a8095..0f294a71b2a 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -377,7 +377,7 @@ func (b *BTSE) GenerateDefaultSubscriptions() ([]subscription.Subscription, erro for j := range pairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: fmt.Sprintf(channels[i], pairs[j]), - Pair: pairs[j], + Pairs: pairs[j], Asset: asset.Spot, }) } diff --git a/exchanges/coinbasepro/coinbasepro_test.go b/exchanges/coinbasepro/coinbasepro_test.go index 0b04d4ab64b..67915bf2924 100644 --- a/exchanges/coinbasepro/coinbasepro_test.go +++ b/exchanges/coinbasepro/coinbasepro_test.go @@ -693,7 +693,7 @@ func TestWsAuth(t *testing.T) { err = c.Subscribe([]subscription.Subscription{ { Channel: "user", - Pair: testPair, + Pairs: testPair, }, }) if err != nil { diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index 5946cf778b5..961c75d7029 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -390,7 +390,7 @@ func (c *CoinbasePro) GenerateDefaultSubscriptions() ([]subscription.Subscriptio } subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i], - Pair: fPair, + Pairs: fPair, Asset: asset.Spot, }) } @@ -414,7 +414,7 @@ func (c *CoinbasePro) Subscribe(channelsToSubscribe []subscription.Subscription) } productIDs := make([]string, 0, len(channelsToSubscribe)) for i := range channelsToSubscribe { - p := channelsToSubscribe[i].Pair.String() + p := channelsToSubscribe[i].Pairs.String() if p != "" && !common.StringDataCompare(productIDs, p) { // get all unique productIDs in advance as we generate by channels productIDs = append(productIDs, p) @@ -466,7 +466,7 @@ func (c *CoinbasePro) Unsubscribe(channelsToUnsubscribe []subscription.Subscript } productIDs := make([]string, 0, len(channelsToUnsubscribe)) for i := range channelsToUnsubscribe { - p := channelsToUnsubscribe[i].Pair.String() + p := channelsToUnsubscribe[i].Pairs.String() if p != "" && !common.StringDataCompare(productIDs, p) { // get all unique productIDs in advance as we generate by channels productIDs = append(productIDs, p) diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 78b389879ca..b1c1780fda6 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -609,7 +609,7 @@ func (c *COINUT) GenerateDefaultSubscriptions() ([]subscription.Subscription, er for j := range enabledPairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i], - Pair: enabledPairs[j], + Pairs: enabledPairs[j], Asset: asset.Spot, }) } @@ -621,7 +621,7 @@ func (c *COINUT) GenerateDefaultSubscriptions() ([]subscription.Subscription, er func (c *COINUT) Subscribe(channelsToSubscribe []subscription.Subscription) error { var errs error for i := range channelsToSubscribe { - fPair, err := c.FormatExchangeCurrency(channelsToSubscribe[i].Pair, asset.Spot) + fPair, err := c.FormatExchangeCurrency(channelsToSubscribe[i].Pairs, asset.Spot) if err != nil { errs = common.AppendError(errs, err) continue @@ -650,7 +650,7 @@ func (c *COINUT) Subscribe(channelsToSubscribe []subscription.Subscription) erro func (c *COINUT) Unsubscribe(channelToUnsubscribe []subscription.Subscription) error { var errs error for i := range channelToUnsubscribe { - fPair, err := c.FormatExchangeCurrency(channelToUnsubscribe[i].Pair, asset.Spot) + fPair, err := c.FormatExchangeCurrency(channelToUnsubscribe[i].Pairs, asset.Spot) if err != nil { errs = common.AppendError(errs, err) continue diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index 3a3dddecf88..809835d7f8c 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -680,7 +680,7 @@ func (g *Gateio) GenerateDefaultSubscriptions() ([]subscription.Subscription, er subscriptions = append(subscriptions, subscription.Subscription{ Channel: channelsToSubscribe[i], - Pair: fpair.Upper(), + Pairs: fpair.Upper(), Asset: assetType, Params: params, }) @@ -738,8 +738,8 @@ func (g *Gateio) generatePayload(event string, channelsToSubscribe []subscriptio for i := range channelsToSubscribe { var auth *WsAuthInput timestamp := time.Now() - channelsToSubscribe[i].Pair.Delimiter = currency.UnderscoreDelimiter - params := []string{channelsToSubscribe[i].Pair.String()} + channelsToSubscribe[i].Pairs.Delimiter = currency.UnderscoreDelimiter + params := []string{channelsToSubscribe[i].Pairs.String()} switch channelsToSubscribe[i].Channel { case spotOrderbookChannel: interval, okay := channelsToSubscribe[i].Params["interval"].(kline.Interval) diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index 449181c1007..cf57caecb0f 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -176,7 +176,7 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() ([]subscription.S } subscriptions = append(subscriptions, subscription.Subscription{ Channel: channelsToSubscribe[i], - Pair: fpair.Upper(), + Pairs: fpair.Upper(), Params: params, }) } @@ -246,7 +246,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib var auth *WsAuthInput timestamp := time.Now() var params []string - params = []string{channelsToSubscribe[i].Pair.String()} + params = []string{channelsToSubscribe[i].Pairs.String()} if g.Websocket.CanUseAuthenticatedEndpoints() { switch channelsToSubscribe[i].Channel { case futuresOrdersChannel, futuresUserTradesChannel, @@ -310,7 +310,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib params = append(params, intervalString) } } - if strings.HasPrefix(channelsToSubscribe[i].Pair.Quote.Upper().String(), "USDT") { + if strings.HasPrefix(channelsToSubscribe[i].Pairs.Quote.Upper().String(), "USDT") { payloads[0] = append(payloads[0], WsInput{ ID: g.Websocket.Conn.GenerateMessageID(false), Event: event, diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 20e293b93af..f2c01fc0ae9 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -156,7 +156,7 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() ([]subscription.Subscript } subscriptions[count] = subscription.Subscription{ Channel: channelsToSubscribe[i], - Pair: fpair.Upper(), + Pairs: fpair.Upper(), Params: params, } count++ @@ -324,7 +324,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe []subs var auth *WsAuthInput timestamp := time.Now() var params []string - params = []string{channelsToSubscribe[i].Pair.String()} + params = []string{channelsToSubscribe[i].Pairs.String()} if g.Websocket.CanUseAuthenticatedEndpoints() { switch channelsToSubscribe[i].Channel { case futuresOrdersChannel, futuresUserTradesChannel, @@ -388,7 +388,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe []subs params = append(params, intervalString) } } - if strings.HasPrefix(channelsToSubscribe[i].Pair.Quote.Upper().String(), "USDT") { + if strings.HasPrefix(channelsToSubscribe[i].Pairs.Quote.Upper().String(), "USDT") { payloads[0] = append(payloads[0], WsInput{ ID: g.Websocket.Conn.GenerateMessageID(false), Event: event, diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index 3278914f21f..d69acf5b173 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -165,7 +165,7 @@ getEnabledPairs: } subscriptions = append(subscriptions, subscription.Subscription{ Channel: channelsToSubscribe[i], - Pair: fpair.Upper(), + Pairs: fpair.Upper(), Params: params, }) } @@ -190,7 +190,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe []subs optionsUnderlyingPriceChannel, optionsUnderlyingCandlesticksChannel: var uly currency.Pair - uly, err = g.GetUnderlyingFromCurrencyPair(channelsToSubscribe[i].Pair) + uly, err = g.GetUnderlyingFromCurrencyPair(channelsToSubscribe[i].Pairs) if err != nil { return nil, err } @@ -198,8 +198,8 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe []subs case optionsBalancesChannel: // options.balance channel does not require underlying or contract default: - channelsToSubscribe[i].Pair.Delimiter = currency.UnderscoreDelimiter - params = append(params, channelsToSubscribe[i].Pair.String()) + channelsToSubscribe[i].Pairs.Delimiter = currency.UnderscoreDelimiter + params = append(params, channelsToSubscribe[i].Pairs.String()) } switch channelsToSubscribe[i].Channel { case optionsOrderbookChannel: diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index 43c2135e021..af3ca1ba4ba 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -80,7 +80,7 @@ func (g *Gemini) GenerateDefaultSubscriptions() ([]subscription.Subscription, er for y := range pairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[x], - Pair: pairs[y], + Pairs: pairs[y], Asset: asset.Spot, }) } @@ -100,10 +100,10 @@ func (g *Gemini) Subscribe(channelsToSubscribe []subscription.Subscription) erro var pairs currency.Pairs for x := range channelsToSubscribe { - if pairs.Contains(channelsToSubscribe[x].Pair, true) { + if pairs.Contains(channelsToSubscribe[x].Pairs, true) { continue } - pairs = append(pairs, channelsToSubscribe[x].Pair) + pairs = append(pairs, channelsToSubscribe[x].Pairs) } fmtPairs, err := g.FormatExchangeCurrencies(pairs, asset.Spot) @@ -144,10 +144,10 @@ func (g *Gemini) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) var pairs currency.Pairs for x := range channelsToUnsubscribe { - if pairs.Contains(channelsToUnsubscribe[x].Pair, true) { + if pairs.Contains(channelsToUnsubscribe[x].Pairs, true) { continue } - pairs = append(pairs, channelsToUnsubscribe[x].Pair) + pairs = append(pairs, channelsToUnsubscribe[x].Pairs) } fmtPairs, err := g.FormatExchangeCurrencies(pairs, asset.Spot) diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index deb885424bf..5b030b79740 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -492,7 +492,7 @@ func (h *HitBTC) GenerateDefaultSubscriptions() ([]subscription.Subscription, er enabledCurrencies[j].Delimiter = "" subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i], - Pair: fPair, + Pairs: fPair, Asset: asset.Spot, }) } @@ -509,8 +509,8 @@ func (h *HitBTC) Subscribe(channelsToSubscribe []subscription.Subscription) erro ID: h.Websocket.Conn.GenerateMessageID(false), } - if channelsToSubscribe[i].Pair.String() != "" { - subscribe.Params.Symbol = channelsToSubscribe[i].Pair.String() + if channelsToSubscribe[i].Pairs.String() != "" { + subscribe.Params.Symbol = channelsToSubscribe[i].Pairs.String() } if strings.EqualFold(channelsToSubscribe[i].Channel, "subscribeTrades") { subscribe.Params.Limit = 100 @@ -546,7 +546,7 @@ func (h *HitBTC) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) Method: unsubscribeChannel, } - unsubscribe.Params.Symbol = channelsToUnsubscribe[i].Pair.String() + unsubscribe.Params.Symbol = channelsToUnsubscribe[i].Pairs.String() if strings.EqualFold(unsubscribeChannel, "unsubscribeTrades") { unsubscribe.Params.Limit = 100 } else if strings.EqualFold(unsubscribeChannel, "unsubscribeCandles") { diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index 92b5bf4c22e..d9cdc84bad5 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -538,7 +538,7 @@ func (h *HUOBI) GenerateDefaultSubscriptions() ([]subscription.Subscription, err enabledCurrencies[j].Lower().String()) subscriptions = append(subscriptions, subscription.Subscription{ Channel: channel, - Pair: enabledCurrencies[j], + Pairs: enabledCurrencies[j], }) } } diff --git a/exchanges/kucoin/kucoin_test.go b/exchanges/kucoin/kucoin_test.go index 302af45eb0c..bc46ac1a241 100644 --- a/exchanges/kucoin/kucoin_test.go +++ b/exchanges/kucoin/kucoin_test.go @@ -2542,7 +2542,7 @@ func TestProcessMarketSnapshot(t *testing.T) { func TestSubscribeMarketSnapshot(t *testing.T) { t.Parallel() setupWS() - err := ku.Subscribe([]subscription.Subscription{{Channel: marketSymbolSnapshotChannel, Pair: currency.Pair{Base: currency.BTC}}}) + err := ku.Subscribe([]subscription.Subscription{{Channel: marketSymbolSnapshotChannel, Pairs: currency.Pair{Base: currency.BTC}}}) assert.NoError(t, err, "Subscribe to MarketSnapshot should not error") } diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index 0787775e2f9..f0d0ada2b47 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -584,7 +584,7 @@ func (o *Okcoin) wsProcessOrderbook(respRaw []byte, obChannel string) error { func (o *Okcoin) ReSubscribeSpecificOrderbook(obChannel string, p currency.Pair) error { subscription := []subscription.Subscription{{ Channel: obChannel, - Pair: p, + Pairs: p, }} if err := o.Unsubscribe(subscription); err != nil { return err @@ -801,7 +801,7 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er for p := range pairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[s], - Pair: pairs[p], + Pairs: pairs[p], }) } case wsStatus: @@ -836,7 +836,7 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er for p := range pairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[s], - Pair: pairs[p], + Pairs: pairs[p], Asset: asset.Spot, }) } @@ -891,8 +891,8 @@ func (o *Okcoin) handleSubscriptions(operation string, subs []subscription.Subsc if subs[i].Asset != asset.Empty { argument["instType"] = strings.ToUpper(subs[i].Asset.String()) } - if !subs[i].Pair.IsEmpty() { - argument["instId"] = subs[i].Pair.String() + if !subs[i].Pairs.IsEmpty() { + argument["instId"] = subs[i].Pairs.String() } if authenticatedChannelSubscription { authTemp.Arguments = append(authTemp.Arguments, argument) diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 23774335fc8..dd57a762016 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -562,7 +562,7 @@ func (p *Poloniex) GenerateDefaultSubscriptions() ([]subscription.Subscription, enabledPairs[j].Delimiter = currency.UnderscoreDelimiter subscriptions = append(subscriptions, subscription.Subscription{ Channel: "orderbook", - Pair: enabledPairs[j], + Pairs: enabledPairs[j], Asset: asset.Spot, }) } @@ -599,7 +599,7 @@ channels: sub[i].Channel): subscriptionRequest.Channel = wsTickerDataID default: - subscriptionRequest.Channel = sub[i].Pair.String() + subscriptionRequest.Channel = sub[i].Pairs.String() } err := p.Websocket.Conn.SendJSONMessage(subscriptionRequest) @@ -646,7 +646,7 @@ channels: unsub[i].Channel): unsubscriptionRequest.Channel = wsTickerDataID default: - unsubscriptionRequest.Channel = unsub[i].Pair.String() + unsubscriptionRequest.Channel = unsub[i].Pairs.String() } err := p.Websocket.Conn.SendJSONMessage(unsubscriptionRequest) if err != nil { From b7b906c35eb7ac051b40304293645ec6f1a1a96e Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Mon, 19 Feb 2024 10:20:37 +0700 Subject: [PATCH 32/35] Kraken: Add subscription Pairs support Note: This is a naieve implementation because we want to rebase the kraken websocket rewrite on top of this --- exchanges/kraken/kraken_test.go | 2 +- exchanges/kraken/kraken_websocket.go | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index 8530cb53b8d..c99627fc3fb 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1251,7 +1251,7 @@ func TestWebsocketSubscribe(t *testing.T) { err := k.Subscribe([]subscription.Subscription{ { Channel: defaultSubscribedChannels[0], - Pair: currency.NewPairWithDelimiter("XBT", "USD", "/"), + Pairs: currency.Pairs{currency.NewPairWithDelimiter("XBT", "USD", "/")}, }, }) if err != nil { diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index fd4325164e2..5657c2d908c 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -859,7 +859,7 @@ func (k *Kraken) wsProcessOrderBook(channelData *WebsocketChannelData, data map[ } }(&subscription.Subscription{ Channel: krakenWsOrderbook, - Pair: outbound, + Pairs: currency.Pairs{outbound}, Asset: asset.Spot, }) return err @@ -1221,7 +1221,7 @@ func (k *Kraken) GenerateDefaultSubscriptions() ([]subscription.Subscription, er enabledPairs[j].Delimiter = "/" subscriptions = append(subscriptions, subscription.Subscription{ Channel: defaultSubscribedChannels[i], - Pair: enabledPairs[j], + Pairs: currency.Pairs{enabledPairs[j]}, Asset: asset.Spot, }) } @@ -1248,7 +1248,7 @@ channels: } for j := range *s { - (*s)[j].Pairs = append((*s)[j].Pairs, channelsToSubscribe[i].Pair.String()) + (*s)[j].Pairs = append((*s)[j].Pairs, channelsToSubscribe[i].Pairs[0].String()) (*s)[j].Channels = append((*s)[j].Channels, channelsToSubscribe[i]) continue channels } @@ -1264,8 +1264,8 @@ channels: if channelsToSubscribe[i].Channel == "book" { outbound.Subscription.Depth = krakenWsOrderbookDepth } - if !channelsToSubscribe[i].Pair.IsEmpty() { - outbound.Pairs = []string{channelsToSubscribe[i].Pair.String()} + if !channelsToSubscribe[i].Pairs[0].IsEmpty() { + outbound.Pairs = []string{channelsToSubscribe[i].Pairs[0].String()} } if common.StringDataContains(authenticatedChannels, channelsToSubscribe[i].Channel) { outbound.Subscription.Token = authToken @@ -1306,7 +1306,7 @@ channels: for y := range unsubs { if unsubs[y].Subscription.Name == channelsToUnsubscribe[x].Channel { unsubs[y].Pairs = append(unsubs[y].Pairs, - channelsToUnsubscribe[x].Pair.String()) + channelsToUnsubscribe[x].Pairs[0].String()) unsubs[y].Channels = append(unsubs[y].Channels, channelsToUnsubscribe[x]) continue channels @@ -1326,7 +1326,7 @@ channels: unsub := WebsocketSubscriptionEventRequest{ Event: krakenWsUnsubscribe, - Pairs: []string{channelsToUnsubscribe[x].Pair.String()}, + Pairs: []string{channelsToUnsubscribe[x].Pairs[0].String()}, Subscription: WebsocketSubscriptionData{ Name: channelsToUnsubscribe[x].Channel, Depth: depth, From 2ecad0832a2bfc804422f277b83a20d40e35ef38 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Wed, 21 Feb 2024 15:09:20 +0700 Subject: [PATCH 33/35] Linter: Disable testifylint.Len We deliberately use Equal over Len to avoid spamming the contents of large Slices --- .golangci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.golangci.yml b/.golangci.yml index 585c5ca0f9b..b5324f959b5 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -146,6 +146,8 @@ linters-settings: disable: - require-error - float-compare + # We deliberately use Equal over Len to avoid spamming the contents of large Slices + - len issues: max-issues-per-linter: 0 From 4b0d2f13b1290f317c3be01b9e52022b04fd4c46 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 22 Feb 2024 09:57:57 +0700 Subject: [PATCH 34/35] Websocket: Add suffix to state consts --- exchanges/stream/websocket.go | 18 +++++++------- exchanges/stream/websocket_test.go | 38 ++++++++++++++--------------- exchanges/stream/websocket_types.go | 8 +++--- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 82d8f3f1847..d0029b208db 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -190,7 +190,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions) } w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection - w.setState(disconnected) + w.setState(disconnectedState) return nil } @@ -279,14 +279,14 @@ func (w *Websocket) Connect() error { w.dataMonitor() w.trafficMonitor() - w.setState(connecting) + w.setState(connectingState) err := w.connector() if err != nil { - w.setState(disconnected) + w.setState(disconnectedState) return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } - w.setState(connected) + w.setState(connectedState) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() @@ -406,7 +406,7 @@ func (w *Websocket) connectionMonitor() error { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) - w.setState(disconnected) + w.setState(disconnectedState) } w.DataHandler <- err @@ -466,7 +466,7 @@ func (w *Websocket) Shutdown() error { // flush any subscriptions from last connection if needed w.subscriptions.Clear() - w.setState(disconnected) + w.setState(disconnectedState) close(w.ShutdownC) w.Wg.Wait() @@ -597,17 +597,17 @@ func (w *Websocket) setState(s uint32) { // IsInitialised returns whether the websocket has been Setup() already func (w *Websocket) IsInitialised() bool { - return w.state.Load() != uninitialised + return w.state.Load() != uninitialisedState } // IsConnected returns whether the websocket is connected func (w *Websocket) IsConnected() bool { - return w.state.Load() == connected + return w.state.Load() == connectedState } // IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { - return w.state.Load() == connecting + return w.state.Load() == connectingState } func (w *Websocket) setEnabled(b bool) { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index a10a5402f9c..3c370fd3676 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -194,13 +194,13 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { signal := struct{}{} patience := 10 * time.Millisecond ws.trafficTimeout = 200 * time.Millisecond - ws.state.Store(connected) + ws.state.Store(connectedState) thenish := time.Now() ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, connected, ws.state.Load(), "websocket must be connected") + require.Equal(t, connectedState, ws.state.Load(), "websocket must be connected") for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass select { @@ -228,7 +228,7 @@ func TestTrafficMonitorTrafficAlerts(t *testing.T) { } require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") } @@ -240,16 +240,16 @@ func TestTrafficMonitorConnecting(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - ws.state.Store(connecting) + ws.state.Store(connectingState) ws.trafficTimeout = 50 * time.Millisecond ws.trafficMonitor() require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") - require.Equal(t, connecting, ws.state.Load(), "websocket must be connecting") + require.Equal(t, connectingState, ws.state.Load(), "websocket must be connecting") <-time.After(4 * ws.trafficTimeout) - require.Equal(t, connecting, ws.state.Load(), "websocket must still be connecting after several checks") - ws.state.Store(connected) + require.Equal(t, connectingState, ws.state.Load(), "websocket must still be connecting after several checks") + ws.state.Store(connectedState) require.EventuallyWithT(t, func(c *assert.CollectT) { - assert.Equal(c, disconnected, ws.state.Load(), "websocket must be disconnected") + assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") } @@ -261,7 +261,7 @@ func TestTrafficMonitorShutdown(t *testing.T) { err := ws.Setup(defaultSetup) require.NoError(t, err, "Setup must not error") - ws.state.Store(connected) + ws.state.Store(connectedState) ws.trafficTimeout = time.Minute ws.trafficMonitor() assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") @@ -307,16 +307,16 @@ func TestConnectionMessageErrors(t *testing.T) { assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) - wsWrong.setState(connecting) + wsWrong.setState(connectingState) err = wsWrong.Connect() assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") - wsWrong.setState(disconnected) + wsWrong.setState(disconnectedState) err = wsWrong.Connect() assert.ErrorIs(t, err, common.ErrNilPointer, "Connect should get a nil pointer error, presumably on subs") wsWrong.subscriptions = subscription.NewStore() - wsWrong.setState(disconnected) + wsWrong.setState(disconnectedState) wsWrong.connector = func() error { return errDastardlyReason } err = wsWrong.Connect() assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") @@ -382,7 +382,7 @@ func TestWebsocket(t *testing.T) { err = ws.SetProxyAddress("https://192.168.0.1:1337") assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") - ws.setState(connected) + ws.setState(connectedState) err = ws.SetProxyAddress("https://192.168.0.1:1336") assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") @@ -405,14 +405,14 @@ func TestWebsocket(t *testing.T) { assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") - ws.setState(connected) + ws.setState(connectedState) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} - ws.setState(disconnected) + ws.setState(disconnectedState) err = ws.Connect() assert.NoError(t, err, "Connect should not error") @@ -859,7 +859,7 @@ func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { ws := &Websocket{} assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") - ws.setState(connected) + ws.setState(connectedState) require.True(t, ws.IsConnected(), "IsConnected must return true") assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") @@ -1021,7 +1021,7 @@ func TestFlushChannels(t *testing.T) { w.trafficTimeout = time.Second * 30 w.setEnabled(true) - w.setState(connected) + w.setState(connectedState) problemFunc := func() (subscription.List, error) { return nil, errDastardlyReason @@ -1077,7 +1077,7 @@ func TestFlushChannels(t *testing.T) { err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") - w.setState(connected) + w.setState(connectedState) w.features.Unsubscribe = true err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") @@ -1087,7 +1087,7 @@ func TestDisable(t *testing.T) { t.Parallel() w := NewWebsocket() w.setEnabled(true) - w.setState(connected) + w.setState(connectedState) require.NoError(t, w.Disable(), "Disable must not error") assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 353003d0a21..707fc7dcb05 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -23,10 +23,10 @@ const ( ) const ( - uninitialised uint32 = iota - disconnected - connecting - connected + uninitialisedState uint32 = iota + disconnectedState + connectingState + connectedState ) // Websocket defines a return type for websocket connections via the interface From 579ec5571af153198d15de1adf10fc7e88e5ce7c Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 22 Feb 2024 10:52:13 +0700 Subject: [PATCH 35/35] Subscription: Test coverage --- currency/pairs.go | 13 +++ exchanges/subscription/list_test.go | 25 +++++ exchanges/subscription/subscription.go | 26 +++-- exchanges/subscription/subscription_test.go | 109 +++++++++++++++----- 4 files changed, 139 insertions(+), 34 deletions(-) create mode 100644 exchanges/subscription/list_test.go diff --git a/currency/pairs.go b/currency/pairs.go index 887b46fcbc9..a68edfa9eb8 100644 --- a/currency/pairs.go +++ b/currency/pairs.go @@ -52,6 +52,19 @@ func (p Pairs) Strings() []string { return list } +// String is a convenience method returning a comma-separated string of uppercase currencies using / as delimiter +func (p Pairs) String() string { + f := PairFormat{ + Delimiter: "/", + Uppercase: true, + } + l := make([]string, len(p)) + for i, pair := range p { + l[i] = f.Format(pair) + } + return strings.Join(l, ",") +} + // Join returns a comma separated list of currency pairs func (p Pairs) Join() string { return strings.Join(p.Strings(), ",") diff --git a/exchanges/subscription/list_test.go b/exchanges/subscription/list_test.go new file mode 100644 index 00000000000..e2293e7ea9d --- /dev/null +++ b/exchanges/subscription/list_test.go @@ -0,0 +1,25 @@ +package subscription + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" +) + +func TestListStrings(t *testing.T) { + l := List{ + &Subscription{ + Channel: TickerChannel, + Asset: asset.Spot, + Pairs: currency.Pairs{ethusdcPair, btcusdtPair}, + }, + &Subscription{ + Channel: OrderbookChannel, + Pairs: currency.Pairs{ethusdcPair}, + }, + } + exp := []string{"orderbook ETH/USDC", "ticker spot ETH/USDC,BTC/USDT"} + assert.ElementsMatch(t, exp, l.Strings(), "String must return correct sorted list") +} diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 3e340ebbc1c..32e799ad44c 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -88,8 +88,8 @@ func (s *Subscription) SetState(state State) error { return nil } -// EnsureKeyed sets the default key on a channel if it doesn't have one -// Returns key for convenience +// EnsureKeyed returns the subscription key +// If no key exists then a pointer to the subscription itself will be used, since Subscriptions implement MatchableKey func (s *Subscription) EnsureKeyed() any { if s.Key == nil { s.Key = s @@ -103,17 +103,25 @@ func (s *Subscription) EnsureKeyed() any { // 2) >=1 pairs then Subscriptions which contain all the pairs match // Such that a subscription for all enabled pairs will be matched when seaching for any one pair func (s *Subscription) Match(key any) bool { - b, ok := key.(*Subscription) + var b *Subscription + switch v := key.(type) { + case *Subscription: + b = v + case Subscription: + b = &v + default: + return false + } + switch { - case !ok, - s.Channel != b.Channel, - s.Asset != b.Asset, - len(b.Pairs) == 0 && len(s.Pairs) != 0, + case b.Channel != s.Channel, + b.Asset != s.Asset, // len(b.Pairs) == 0 && len(s.Pairs) == 0: Okay; continue to next non-pairs check + len(b.Pairs) == 0 && len(s.Pairs) != 0, len(b.Pairs) != 0 && len(s.Pairs) == 0, len(b.Pairs) != 0 && s.Pairs.ContainsAll(b.Pairs, true) != nil, - s.Levels != b.Levels, - s.Interval != b.Interval: + b.Levels != s.Levels, + b.Interval != s.Interval: return false } diff --git a/exchanges/subscription/subscription_test.go b/exchanges/subscription/subscription_test.go index 38cabb8694a..b9a71b4ae6a 100644 --- a/exchanges/subscription/subscription_test.go +++ b/exchanges/subscription/subscription_test.go @@ -11,36 +11,67 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" ) -// TestEnsureKeyed logic test -func TestEnsureKeyed(t *testing.T) { - t.Parallel() +var ( + btcusdtPair = currency.NewPair(currency.BTC, currency.USDT) + ethusdcPair = currency.NewPair(currency.ETH, currency.USDC) + ltcusdcPair = currency.NewPair(currency.LTC, currency.USDC) +) + +// TestSubscriptionString exercises the String method +func TestSubscriptionString(t *testing.T) { s := &Subscription{ Channel: "candles", Asset: asset.Spot, - Pairs: []currency.Pair{currency.NewPair(currency.BTC, currency.USDT)}, + Pairs: currency.Pairs{btcusdtPair, ethusdcPair.Format(currency.PairFormat{Delimiter: "/"})}, + } + assert.Equal(t, "candles spot BTC/USDT,ETH/USDC", s.String(), "Subscription String should return correct value") +} + +// TestState exercises the state getter +func TestState(t *testing.T) { + t.Parallel() + s := &Subscription{} + assert.Equal(t, InactiveState, s.State(), "State should return initial state") + s.state = SubscribedState + assert.Equal(t, SubscribedState, s.State(), "State should return correct state") +} + +// TestSetState exercises the state setter +func TestSetState(t *testing.T) { + t.Parallel() + + s := &Subscription{state: UnsubscribingState} + + for i := InactiveState; i <= UnsubscribingState; i++ { + assert.NoErrorf(t, s.SetState(i), "State should not error setting state %s", i) } + assert.ErrorIs(t, s.SetState(UnsubscribingState), ErrInStateAlready, "SetState should error on same state") + assert.ErrorIs(t, s.SetState(UnsubscribingState+1), ErrInvalidState, "Setting an invalid state should error") +} + +// TestEnsureKeyed exercises the key getter and ensures it sets a self-pointer key for non +func TestEnsureKeyed(t *testing.T) { + t.Parallel() + s := &Subscription{} k1, ok := s.EnsureKeyed().(*Subscription) if assert.True(t, ok, "EnsureKeyed should return a *Subscription") { - assert.Same(t, k1, s, "Key should point to the same struct") + assert.Same(t, s, k1, "Key should point to the same struct") } type platypus string s = &Subscription{ Key: platypus("Gerald"), Channel: "orderbook", - Asset: asset.Margin, - Pairs: []currency.Pair{currency.NewPair(currency.ETH, currency.USDC)}, - } - k2, ok := s.EnsureKeyed().(platypus) - if assert.True(t, ok, "EnsureKeyed should return a platypus") { - assert.Exactly(t, k2, s.Key, "ensureKeyed should set the same key") - assert.EqualValues(t, "Gerald", k2, "key should have the correct value") } + k2 := s.EnsureKeyed() + assert.IsType(t, platypus(""), k2, "EnsureKeyed should return a platypus") + assert.Equal(t, s.Key, k2, "Key should be the key provided") } -// TestMarshalling logic test -func TestMarshaling(t *testing.T) { +// TestSubscriptionMarshalling ensures json Marshalling is clean and concise +// Since there is no UnmarshalJSON, this just exercises the json field tags of Subscription, and regressions in conciseness +func TestSubscriptionMarshaling(t *testing.T) { t.Parallel() - j, err := json.Marshal(&Subscription{Channel: CandlesChannel}) + j, err := json.Marshal(&Subscription{Key: 42, Channel: CandlesChannel}) assert.NoError(t, err, "Marshalling should not error") assert.Equal(t, `{"enabled":false,"channel":"candles"}`, string(j), "Marshalling should be clean and concise") @@ -57,16 +88,44 @@ func TestMarshaling(t *testing.T) { assert.Equal(t, `{"enabled":true,"channel":"myTrades","authenticated":true}`, string(j), "Marshalling should be clean and concise") } -// TestSetState tests Subscription state changes -func TestSetState(t *testing.T) { +// TestSubscriptionMatch exercises the Subscription MatchableKey interface implementation +func TestSubscriptionMatch(t *testing.T) { t.Parallel() + require.Implements(t, (*MatchableKey)(nil), new(Subscription), "Must implement MatchableKey") + s := &Subscription{Channel: TickerChannel} + assert.NotNil(t, s.EnsureKeyed(), "EnsureKeyed should work") + assert.False(t, s.Match(42), "Match should reject an invalid key type") + try := &Subscription{Channel: OrderbookChannel} + require.False(t, s.Match(try), "Gate 1: Match must reject a bad Channel") + try = &Subscription{Channel: TickerChannel} + require.True(t, s.Match(Subscription{Channel: TickerChannel}), "Match must accept a pass-by-value subscription") + require.True(t, s.Match(try), "Gate 1: Match must accept a good Channel") + s.Asset = asset.Spot + require.False(t, s.Match(try), "Gate 2: Match must reject a bad Asset") + try.Asset = asset.Spot + require.True(t, s.Match(try), "Gate 2: Match must accept a good Asset") - s := &Subscription{Key: 42, Channel: "Gophers"} - assert.Equal(t, InactiveState, s.State(), "State should start as unknown") - require.NoError(t, s.SetState(SubscribingState), "SetState should not error") - assert.Equal(t, SubscribingState, s.State(), "State should be set correctly") - assert.ErrorIs(t, s.SetState(SubscribingState), ErrInStateAlready, "SetState should error on same state") - assert.ErrorIs(t, s.SetState(UnsubscribingState+1), ErrInvalidState, "Setting an invalid state should error") - require.NoError(t, s.SetState(UnsubscribingState), "SetState should not error") - assert.Equal(t, UnsubscribingState, s.State(), "State should be set correctly") + s.Pairs = currency.Pairs{btcusdtPair} + require.False(t, s.Match(try), "Gate 3: Match must reject a pair list when searching for no pairs") + try.Pairs = s.Pairs + s.Pairs = nil + require.False(t, s.Match(try), "Gate 4: Match must reject empty Pairs when searching for a list") + s.Pairs = try.Pairs + require.True(t, s.Match(try), "Gate 5: Match must accept matching pairs") + s.Pairs = currency.Pairs{ethusdcPair} + require.False(t, s.Match(try), "Gate 5: Match must reject mismatched pairs") + s.Pairs = currency.Pairs{btcusdtPair, ethusdcPair} + require.True(t, s.Match(try), "Gate 5: Match must accept one of the key pairs matching in sub pairs") + try.Pairs = currency.Pairs{btcusdtPair, ltcusdcPair} + require.False(t, s.Match(try), "Gate 5: Match must reject when sub pair list doesn't contain all key pairs") + s.Pairs = currency.Pairs{btcusdtPair, ethusdcPair, ltcusdcPair} + require.True(t, s.Match(try), "Gate 5: Match must accept all of the key pairs are contained in sub pairs") + s.Levels = 4 + require.False(t, s.Match(try), "Gate 6: Match must reject a bad Level") + try.Levels = 4 + require.True(t, s.Match(try), "Gate 6: Match must accept a good Level") + s.Interval = kline.FiveMin + require.False(t, s.Match(try), "Gate 7: Match must reject a bad Interval") + try.Interval = kline.FiveMin + require.True(t, s.Match(try), "Gate 7: Match must accept a good Inteval") }