diff --git a/.golangci.yml b/.golangci.yml index 585c5ca0f9b..b5324f959b5 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -146,6 +146,8 @@ linters-settings: disable: - require-error - float-compare + # We deliberately use Equal over Len to avoid spamming the contents of large Slices + - len issues: max-issues-per-linter: 0 diff --git a/cmd/exchange_template/wrapper_file.tmpl b/cmd/exchange_template/wrapper_file.tmpl index d57f96b5451..e74ecbc320f 100644 --- a/cmd/exchange_template/wrapper_file.tmpl +++ b/cmd/exchange_template/wrapper_file.tmpl @@ -125,7 +125,7 @@ func ({{.Variable}} *{{.CapitalName}}) SetDefaults() { exchange.RestSpot: {{.Name}}APIURL, // exchange.WebsocketSpot: {{.Name}}WSAPIURL, }) - {{.Variable}}.Websocket = stream.New() + {{.Variable}}.Websocket = stream.NewWebsocket() {{.Variable}}.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit {{.Variable}}.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout {{.Variable}}.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/common/common.go b/common/common.go index 9e0d32db2bc..c4483a2d5cc 100644 --- a/common/common.go +++ b/common/common.go @@ -463,11 +463,22 @@ func InArray(val, array interface{}) (exists bool, index int) { return } +// fmtError holds a formatted msg and the errors which formatted it +type fmtError struct { + errs []error + msg string +} + // multiError holds errors as a slice type multiError struct { errs []error } +type unwrappable interface { + Unwrap() []error + Error() string +} + // AppendError appends an error to a list of exesting errors // Either argument may be: // * A vanilla error @@ -481,20 +492,35 @@ func AppendError(original, incoming error) error { if original == nil { return incoming } - newErrs := []error{incoming} - if u, ok := incoming.(interface{ Unwrap() []error }); ok { - newErrs = u.Unwrap() + if u, ok := incoming.(unwrappable); ok { + incoming = &fmtError{ + errs: u.Unwrap(), + msg: u.Error(), + } } - if u, ok := original.(interface{ Unwrap() []error }); ok { - return &multiError{ - errs: append(u.Unwrap(), newErrs...), + switch v := original.(type) { + case *multiError: + v.errs = append(v.errs, incoming) + return v + case unwrappable: + original = &fmtError{ + errs: v.Unwrap(), + msg: v.Error(), } } return &multiError{ - errs: append([]error{original}, newErrs...), + errs: append([]error{original}, incoming), } } +func (e *fmtError) Error() string { + return e.msg +} + +func (e *fmtError) Unwrap() []error { + return e.errs +} + // Error displays all errors comma separated func (e *multiError) Error() string { allErrors := make([]string, len(e.errs)) @@ -506,11 +532,16 @@ func (e *multiError) Error() string { // Unwrap returns all of the errors in the multiError func (e *multiError) Unwrap() []error { - return e.errs -} - -type unwrappable interface { - Unwrap() []error + errs := make([]error, 0, len(e.errs)) + for _, e := range e.errs { + switch v := e.(type) { + case unwrappable: + errs = append(errs, unwrapDeep(v)...) + default: + errs = append(errs, v) + } + } + return errs } // unwrapDeep walks down a stack of nested fmt.Errorf("%w: %w") errors diff --git a/common/common_test.go b/common/common_test.go index fd6d448d4c6..e9c8bf28412 100644 --- a/common/common_test.go +++ b/common/common_test.go @@ -703,6 +703,13 @@ func TestErrors(t *testing.T) { assert.ErrorIs(t, ExcludeError(err, e5), e3, "Excluding e5 should retain e3") assert.ErrorIs(t, ExcludeError(err, e5), e4, "Excluding e5 should retain the vanilla co-wrapped e4") assert.NotErrorIs(t, ExcludeError(err, e5), e5, "e4 should be excluded") + + // Formatting retention + err = AppendError(e1, fmt.Errorf("%w: Run out of `%s`: %w", e3, "sausages", e5)) + assert.ErrorIs(t, err, e1, "Should be an e1") + assert.ErrorIs(t, err, e3, "Should be an e3") + assert.ErrorIs(t, err, e5, "Should be an e5") + assert.ErrorContains(t, err, "sausages", "Should know about secret snausages") } func TestParseStartEndDate(t *testing.T) { diff --git a/currency/pairs.go b/currency/pairs.go index 887b46fcbc9..a68edfa9eb8 100644 --- a/currency/pairs.go +++ b/currency/pairs.go @@ -52,6 +52,19 @@ func (p Pairs) Strings() []string { return list } +// String is a convenience method returning a comma-separated string of uppercase currencies using / as delimiter +func (p Pairs) String() string { + f := PairFormat{ + Delimiter: "/", + Uppercase: true, + } + l := make([]string, len(p)) + for i, pair := range p { + l[i] = f.Format(pair) + } + return strings.Join(l, ",") +} + // Join returns a comma separated list of currency pairs func (p Pairs) Join() string { return strings.Join(p.Strings(), ",") diff --git a/engine/websocketroutine_manager_test.go b/engine/websocketroutine_manager_test.go index e082193bda4..c1c3541db31 100644 --- a/engine/websocketroutine_manager_test.go +++ b/engine/websocketroutine_manager_test.go @@ -293,7 +293,7 @@ func TestRegisterWebsocketDataHandlerWithFunctionality(t *testing.T) { t.Fatal("unexpected data handlers registered") } - mock := stream.New() + mock := stream.NewWebsocket() mock.ToRoutine = make(chan interface{}) m.state = readyState err = m.websocketDataReceiver(mock) diff --git a/exchanges/binance/binance_websocket.go b/exchanges/binance/binance_websocket.go index 52cdd0cc392..bd96d546f4f 100644 --- a/exchanges/binance/binance_websocket.go +++ b/exchanges/binance/binance_websocket.go @@ -50,7 +50,7 @@ var ( // WsConnect initiates a websocket connection func (b *Binance) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer diff --git a/exchanges/binance/binance_wrapper.go b/exchanges/binance/binance_wrapper.go index bdf70e0d3ee..d54f89242ca 100644 --- a/exchanges/binance/binance_wrapper.go +++ b/exchanges/binance/binance_wrapper.go @@ -238,7 +238,7 @@ func (b *Binance) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/binanceus/binanceus_websocket.go b/exchanges/binanceus/binanceus_websocket.go index 8f4d5c3cd6e..90d710997f5 100644 --- a/exchanges/binanceus/binanceus_websocket.go +++ b/exchanges/binanceus/binanceus_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect initiates a websocket connection func (bi *Binanceus) WsConnect() error { if !bi.Websocket.IsEnabled() || !bi.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.HandshakeTimeout = bi.Config.HTTPTimeout @@ -560,7 +560,7 @@ subs: } subscriptions = append(subscriptions, subscription.Subscription{ Channel: lp.String() + channels[z], - Pair: pairs[y], + Pairs: currency.Pairs{pairs[y]}, Asset: asset.Spot, }) } diff --git a/exchanges/binanceus/binanceus_wrapper.go b/exchanges/binanceus/binanceus_wrapper.go index 3f078ede7a8..ed2e5d10f5e 100644 --- a/exchanges/binanceus/binanceus_wrapper.go +++ b/exchanges/binanceus/binanceus_wrapper.go @@ -162,7 +162,7 @@ func (bi *Binanceus) SetDefaults() { "%s setting default endpoints error %v", bi.Name, err) } - bi.Websocket = stream.New() + bi.Websocket = stream.NewWebsocket() bi.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit bi.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout bi.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitfinex/bitfinex_test.go b/exchanges/bitfinex/bitfinex_test.go index c2bba90d00b..4baace3186f 100644 --- a/exchanges/bitfinex/bitfinex_test.go +++ b/exchanges/bitfinex/bitfinex_test.go @@ -1128,7 +1128,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !b.Websocket.IsEnabled() { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } sharedtestvalues.SkipTestIfCredentialsUnset(t, b) if !b.API.AuthenticatedWebsocketSupport { @@ -1158,7 +1158,7 @@ func TestWsAuth(t *testing.T) { // See also TestSubscribeReq which covers key and symbol conversion func TestWsSubscribe(t *testing.T) { setupWs(t) - err := b.Subscribe([]subscription.Subscription{{Channel: wsTicker, Pair: currency.NewPair(currency.BTC, currency.USD), Asset: asset.Spot}}) + err := b.Subscribe([]subscription.Subscription{{Channel: wsTicker, Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USD)}, Asset: asset.Spot}}) assert.NoError(t, err, "Subrcribe should not error") catcher := func() (ok bool) { i := <-b.Websocket.DataHandler @@ -1171,7 +1171,7 @@ func TestWsSubscribe(t *testing.T) { assert.NoError(t, err, "GetSubscriptions should not error") assert.Len(t, subs, 1, "We should only have 1 subscription; subID subscription should have been Removed by subscribeToChan") - err = b.Subscribe([]subscription.Subscription{{Channel: wsTicker, Pair: currency.NewPair(currency.BTC, currency.USD), Asset: asset.Spot}}) + err = b.Subscribe([]subscription.Subscription{{Channel: wsTicker, Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USD)}, Asset: asset.Spot}}) assert.ErrorIs(t, err, stream.ErrSubscriptionFailure, "Duplicate subscription should error correctly") catcher = func() bool { i := <-b.Websocket.DataHandler @@ -1202,7 +1202,7 @@ func TestWsSubscribe(t *testing.T) { err = b.Subscribe([]subscription.Subscription{{ Channel: wsTicker, - Pair: currency.NewPair(currency.BTC, currency.USD), + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USD)}, Asset: asset.Spot, Params: map[string]interface{}{"key": "tBTCUSD"}, }}) @@ -1214,7 +1214,7 @@ func TestWsSubscribe(t *testing.T) { func TestSubscribeReq(t *testing.T) { c := &subscription.Subscription{ Channel: wsCandles, - Pair: currency.NewPair(currency.BTC, currency.USD), + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USD)}, Asset: asset.MarginFunding, Params: map[string]interface{}{ CandlesPeriodKey: "30", @@ -1233,14 +1233,14 @@ func TestSubscribeReq(t *testing.T) { c = &subscription.Subscription{ Channel: wsBook, - Pair: currency.NewPair(currency.BTC, currency.DOGE), + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.DOGE)}, Asset: asset.Spot, } r, err = subscribeReq(c) assert.NoError(t, err, "subscribeReq should not error") assert.Equal(t, "tBTC:DOGE", r["symbol"], "symbol should use colon delimiter if a currency is > 3 chars") - c.Pair = currency.NewPair(currency.BTC, currency.LTC) + c.Pairs = currency.Pairs{currency.NewPair(currency.BTC, currency.LTC)} r, err = subscribeReq(c) assert.NoError(t, err, "subscribeReq should not error") assert.Equal(t, "tBTCLTC", r["symbol"], "symbol should not use colon delimiter if both currencies < 3 chars") @@ -1353,7 +1353,7 @@ func TestWsSubscribedResponse(t *testing.T) { } func TestWsOrderBook(t *testing.T) { - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Key: 23405, Asset: asset.Spot, Pair: btcusdPair, Channel: wsBook}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Key: 23405, Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsBook}) pressXToJSON := `[23405,[[38334303613,9348.8,0.53],[38334308111,9348.8,5.98979404],[38331335157,9344.1,1.28965787],[38334302803,9343.8,0.08230094],[38334279092,9343,0.8],[38334307036,9342.938663676,0.8],[38332749107,9342.9,0.2],[38332277330,9342.8,0.85],[38329406786,9342,0.1432012],[38332841570,9341.947288638,0.3],[38332163238,9341.7,0.3],[38334303384,9341.6,0.324],[38332464840,9341.4,0.5],[38331935870,9341.2,0.5],[38334312082,9340.9,0.02126899],[38334261292,9340.8,0.26763],[38334138680,9340.625455254,0.12],[38333896802,9339.8,0.85],[38331627527,9338.9,1.57863959],[38334186713,9338.9,0.26769],[38334305819,9338.8,2.999],[38334211180,9338.75285796,3.999],[38334310699,9337.8,0.10679883],[38334307414,9337.5,1],[38334179822,9337.1,0.26773],[38334306600,9336.659955102,1.79],[38334299667,9336.6,1.1],[38334306452,9336.6,0.13979771],[38325672859,9336.3,1.25],[38334311646,9336.2,1],[38334258509,9336.1,0.37],[38334310592,9336,1.79],[38334310378,9335.6,1.43],[38334132444,9335.2,0.26777],[38331367325,9335,0.07],[38334310703,9335,0.10680562],[38334298209,9334.7,0.08757301],[38334304857,9334.456899462,0.291],[38334309940,9334.088390727,0.0725],[38334310377,9333.7,1.2868],[38334297615,9333.607784,0.1108],[38334095188,9333.3,0.26785],[38334228913,9332.7,0.40861186],[38334300526,9332.363996604,0.3884],[38334310701,9332.2,0.10680562],[38334303548,9332.005382871,0.07],[38334311798,9331.8,0.41285228],[38334301012,9331.7,1.7952],[38334089877,9331.4,0.2679],[38321942150,9331.2,0.2],[38334310670,9330,1.069],[38334063096,9329.6,0.26796],[38334310700,9329.4,0.10680562],[38334310404,9329.3,1],[38334281630,9329.1,6.57150597],[38334036864,9327.7,0.26801],[38334310702,9326.6,0.10680562],[38334311799,9326.1,0.50220625],[38334164163,9326,0.219638],[38334309722,9326,1.5],[38333051682,9325.8,0.26807],[38334302027,9325.7,0.75],[38334203435,9325.366592,0.32397696],[38321967613,9325,0.05],[38334298787,9324.9,0.3],[38334301719,9324.8,3.6227592],[38331316716,9324.763454646,0.71442],[38334310698,9323.8,0.10680562],[38334035499,9323.7,0.23431017],[38334223472,9322.670551788,0.42150603],[38334163459,9322.560399006,0.143967],[38321825171,9320.8,2],[38334075805,9320.467496148,0.30772633],[38334075800,9319.916732238,0.61457592],[38333682302,9319.7,0.0011],[38331323088,9319.116771762,0.12913],[38333677480,9319,0.0199],[38334277797,9318.6,0.89],[38325235155,9318.041088,1.20249],[38334310910,9317.82382938,1.79],[38334311811,9317.2,0.61079138],[38334311812,9317.2,0.71937652],[38333298214,9317.1,50],[38334306359,9317,1.79],[38325531545,9316.382823951,0.21263],[38333727253,9316.3,0.02316372],[38333298213,9316.1,45],[38333836479,9316,2.135],[38324520465,9315.9,2.7681],[38334307411,9315.5,1],[38330313617,9315.3,0.84455],[38334077770,9315.294024,0.01248397],[38334286663,9315.294024,1],[38325533762,9315.290315394,2.40498],[38334310018,9315.2,3],[38333682617,9314.6,0.0011],[38334304794,9314.6,0.76364676],[38334304798,9314.3,0.69242113],[38332915733,9313.8,0.0199],[38334084411,9312.8,1],[38334311893,9350.1,-1.015],[38334302734,9350.3,-0.26737],[38334300732,9350.8,-5.2],[38333957619,9351,-0.90677089],[38334300521,9351,-1.6457],[38334301600,9351.012829557,-0.0523],[38334308878,9351.7,-2.5],[38334299570,9351.921544,-0.1015],[38334279367,9352.1,-0.26732],[38334299569,9352.411802928,-0.4036],[38334202773,9353.4,-0.02139404],[38333918472,9353.7,-1.96412776],[38334278782,9354,-0.26731],[38334278606,9355,-1.2785],[38334302105,9355.439221251,-0.79191542],[38313897370,9355.569409242,-0.43363],[38334292995,9355.584296,-0.0979],[38334216989,9355.8,-0.03686414],[38333894025,9355.9,-0.26721],[38334293798,9355.936691952,-0.4311],[38331159479,9356,-0.4204022],[38333918888,9356.1,-1.10885563],[38334298205,9356.4,-0.20124428],[38328427481,9356.5,-0.1],[38333343289,9356.6,-0.41034213],[38334297205,9356.6,-0.08835018],[38334277927,9356.741101161,-0.0737],[38334311645,9356.8,-0.5],[38334309002,9356.9,-5],[38334309736,9357,-0.10680107],[38334306448,9357.4,-0.18645275],[38333693302,9357.7,-0.2672],[38332815159,9357.8,-0.0011],[38331239824,9358.2,-0.02],[38334271608,9358.3,-2.999],[38334311971,9358.4,-0.55],[38333919260,9358.5,-1.9972841],[38334265365,9358.5,-1.7841],[38334277960,9359,-3],[38334274601,9359.020969848,-3],[38326848839,9359.1,-0.84],[38334291080,9359.247048,-0.16199869],[38326848844,9359.4,-1.84],[38333680200,9359.6,-0.26713],[38331326606,9359.8,-0.84454],[38334309738,9359.8,-0.10680107],[38331314707,9359.9,-0.2],[38333919803,9360.9,-1.41177599],[38323651149,9361.33417827,-0.71442],[38333656906,9361.5,-0.26705],[38334035500,9361.5,-0.40861586],[38334091886,9362.4,-6.85940815],[38334269617,9362.5,-4],[38323629409,9362.545858872,-2.40497],[38334309737,9362.7,-0.10680107],[38334312380,9362.7,-3],[38325280830,9362.8,-1.75123],[38326622800,9362.8,-1.05145],[38333175230,9363,-0.0011],[38326848745,9363.2,-0.79],[38334308960,9363.206775564,-0.12],[38333920234,9363.3,-1.25318113],[38326848843,9363.4,-1.29],[38331239823,9363.4,-0.02],[38333209613,9363.4,-0.26719],[38334299964,9364,-0.05583123],[38323470224,9364.161816648,-0.12912],[38334284711,9365,-0.21346019],[38334299594,9365,-2.6757062],[38323211816,9365.073132585,-0.21262],[38334312456,9365.1,-0.11167861],[38333209612,9365.2,-0.26719],[38327770474,9365.3,-0.0073],[38334298788,9365.3,-0.3],[38334075803,9365.409831204,-0.30772637],[38334309740,9365.5,-0.10680107],[38326608767,9365.7,-2.76809],[38333920657,9365.7,-1.25848083],[38329594226,9366.6,-0.02587],[38334311813,9366.7,-4.72290945],[38316386301,9367.39258128,-2.37581],[38334302026,9367.4,-4.5],[38334228915,9367.9,-0.81725458],[38333921381,9368.1,-1.72213641],[38333175678,9368.2,-0.0011],[38334301150,9368.2,-2.654604],[38334297208,9368.3,-0.78036466],[38334309739,9368.3,-0.10680107],[38331227515,9368.7,-0.02],[38331184470,9369,-0.003975],[38334203436,9369.319616,-0.32397695],[38334269964,9369.7,-0.5],[38328386732,9370,-4.11759935],[38332719555,9370,-0.025],[38333921935,9370.5,-1.2224398],[38334258511,9370.5,-0.35],[38326848842,9370.8,-0.34],[38333985038,9370.9,-0.8551502],[38334283018,9370.9,-1],[38326848744,9371,-1.34]],5]` err := b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1370,7 +1370,7 @@ func TestWsOrderBook(t *testing.T) { } func TestWsTradeResponse(t *testing.T) { - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: btcusdPair, Channel: wsTrades, Key: 18788}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTrades, Key: 18788}) pressXToJSON := `[18788,[[412685577,1580268444802,11.1998,176.3],[412685575,1580268444802,5,176.29952759],[412685574,1580268374717,1.99069999,176.41],[412685573,1580268374717,1.00930001,176.41],[412685572,1580268358760,0.9907,176.47],[412685571,1580268324362,0.5505,176.44],[412685570,1580268297270,-0.39040819,176.39],[412685568,1580268297270,-0.39780162,176.46475676],[412685567,1580268283470,-0.09,176.41],[412685566,1580268256536,-2.31310783,176.48],[412685565,1580268256536,-0.59669217,176.49],[412685564,1580268256536,-0.9902,176.49],[412685562,1580268194474,0.9902,176.55],[412685561,1580268186215,0.1,176.6],[412685560,1580268185964,-2.17096773,176.5],[412685559,1580268185964,-1.82903227,176.51],[412685558,1580268181215,2.098914,176.53],[412685557,1580268169844,16.7302,176.55],[412685556,1580268169844,3.25,176.54],[412685555,1580268155725,0.23576115,176.45],[412685553,1580268155725,3,176.44596249],[412685552,1580268155725,3.25,176.44],[412685551,1580268155725,5,176.44],[412685550,1580268155725,0.65830078,176.41],[412685549,1580268155725,0.45063807,176.41],[412685548,1580268153825,-0.67604704,176.39],[412685547,1580268145713,2.5883,176.41],[412685543,1580268087513,12.92927,176.33],[412685542,1580268087513,0.40083,176.33],[412685533,1580268005756,-0.17096773,176.32]]]` err := b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1379,7 +1379,7 @@ func TestWsTradeResponse(t *testing.T) { } func TestWsTickerResponse(t *testing.T) { - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: btcusdPair, Channel: wsTicker, Key: 11534}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsTicker, Key: 11534}) pressXToJSON := `[11534,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err := b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1389,7 +1389,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: pair, Channel: wsTicker, Key: 123412}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: pair, Channel: wsTicker, Key: 123412}) pressXToJSON = `[123412,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1399,7 +1399,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: pair, Channel: wsTicker, Key: 123413}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: pair, Channel: wsTicker, Key: 123413}) pressXToJSON = `[123413,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1409,7 +1409,7 @@ func TestWsTickerResponse(t *testing.T) { if err != nil { t.Error(err) } - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: pair, Channel: wsTicker, Key: 123414}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: pair, Channel: wsTicker, Key: 123414}) pressXToJSON = `[123414,[61.304,2228.36155358,61.305,1323.2442970500003,0.395,0.0065,61.371,50973.3020771,62.5,57.421]]` err = b.wsHandleData([]byte(pressXToJSON)) if err != nil { @@ -1418,7 +1418,7 @@ func TestWsTickerResponse(t *testing.T) { } func TestWsCandleResponse(t *testing.T) { - b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pair: btcusdPair, Channel: wsCandles, Key: 343351}) + b.Websocket.AddSuccessfulSubscriptions(subscription.Subscription{Asset: asset.Spot, Pairs: currency.Pairs{btcusdPair}, Channel: wsCandles, Key: 343351}) pressXToJSON := `[343351,[[1574698260000,7379.785503,7383.8,7388.3,7379.785503,1.68829482]]]` err := b.wsHandleData([]byte(pressXToJSON)) if err != nil { diff --git a/exchanges/bitfinex/bitfinex_websocket.go b/exchanges/bitfinex/bitfinex_websocket.go index e1010eb1061..df97ad55beb 100644 --- a/exchanges/bitfinex/bitfinex_websocket.go +++ b/exchanges/bitfinex/bitfinex_websocket.go @@ -43,7 +43,7 @@ var cMtx sync.Mutex // WsConnect starts a new websocket connection func (b *Bitfinex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -506,7 +506,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { chanID, err := jsonparser.GetInt(respRaw, "chanId") if err != nil { - return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, errParsingWSField, err, c.Channel, c.Pair) + return fmt.Errorf("%w: %w 'chanId': %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, errParsingWSField, err, c.Channel, c.Pairs) } // Note: chanID's int type avoids conflicts with the string type subID key because of the type difference @@ -516,7 +516,7 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { b.Websocket.AddSuccessfulSubscriptions(*c) if b.Verbose { - log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pair, chanID) + log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s ChannelID: %d\n", b.Name, c.Channel, c.Pairs, chanID) } if !b.Websocket.Match.IncomingWithData("subscribe:"+subID, respRaw) { return fmt.Errorf("%v channel subscribe listener not found", subID) @@ -525,6 +525,10 @@ func (b *Bitfinex) handleWSSubscribed(respRaw []byte) error { } func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType string, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if eventType == wsChecksum { return b.handleWSChecksum(c, d) } @@ -533,6 +537,10 @@ func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType return nil } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } + switch c.Channel { case wsBook: return b.handleWSBookUpdate(c, d) @@ -548,6 +556,9 @@ func (b *Bitfinex) handleWSChannelUpdate(c *subscription.Subscription, eventType } func (b *Bitfinex) handleWSChecksum(c *subscription.Subscription, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } var token int if f, ok := d[2].(float64); !ok { return common.GetTypeAssertError("float64", d[2], "checksum") @@ -579,6 +590,12 @@ func (b *Bitfinex) handleWSChecksum(c *subscription.Subscription, d []interface{ } func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } var newOrderbook []WebsocketBook obSnapBundle, ok := d[1].([]interface{}) if !ok { @@ -632,7 +649,7 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac Amount: rateAmount}) } } - if err := b.WsInsertSnapshot(c.Pair, c.Asset, newOrderbook, fundingRate); err != nil { + if err := b.WsInsertSnapshot(c.Pairs[0], c.Asset, newOrderbook, fundingRate); err != nil { return fmt.Errorf("inserting snapshot error: %s", err) } @@ -664,7 +681,7 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac Amount: amountRate}) } - if err := b.WsUpdateOrderbook(c, c.Pair, c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil { + if err := b.WsUpdateOrderbook(c, c.Pairs[0], c.Asset, newOrderbook, int64(sequenceNo), fundingRate); err != nil { return fmt.Errorf("updating orderbook error: %s", err) } @@ -674,6 +691,12 @@ func (b *Bitfinex) handleWSBookUpdate(c *subscription.Subscription, d []interfac } func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } candleBundle, ok := d[1].([]interface{}) if !ok || len(candleBundle) == 0 { return nil @@ -712,7 +735,7 @@ func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interf } klineData.Exchange = b.Name klineData.AssetType = c.Asset - klineData.Pair = c.Pair + klineData.Pair = c.Pairs[0] b.Websocket.DataHandler <- klineData } case float64: @@ -741,13 +764,19 @@ func (b *Bitfinex) handleWSCandleUpdate(c *subscription.Subscription, d []interf } klineData.Exchange = b.Name klineData.AssetType = c.Asset - klineData.Pair = c.Pair + klineData.Pair = c.Pairs[0] b.Websocket.DataHandler <- klineData } return nil } func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } tickerData, ok := d[1].([]interface{}) if !ok { return errors.New("type assertion for tickerData") @@ -755,7 +784,7 @@ func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interf t := &ticker.Price{ AssetType: c.Asset, - Pair: c.Pair, + Pair: c.Pairs[0], ExchangeName: b.Name, } @@ -821,6 +850,12 @@ func (b *Bitfinex) handleWSTickerUpdate(c *subscription.Subscription, d []interf } func (b *Bitfinex) handleWSTradesUpdate(c *subscription.Subscription, eventType string, d []interface{}) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } if !b.IsSaveTradeDataEnabled() { return nil } @@ -936,7 +971,7 @@ func (b *Bitfinex) handleWSTradesUpdate(c *subscription.Subscription, eventType } trades[i] = trade.Data{ TID: strconv.FormatInt(tradeHolder[i].ID, 10), - CurrencyPair: c.Pair, + CurrencyPair: c.Pairs[0], Timestamp: time.UnixMilli(tradeHolder[i].Timestamp), Price: price, Amount: newAmount, @@ -1510,6 +1545,12 @@ func (b *Bitfinex) WsInsertSnapshot(p currency.Pair, assetType asset.Item, books // WsUpdateOrderbook updates the orderbook list, removing and adding to the // orderbook sides func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pair, assetType asset.Item, book []WebsocketBook, sequenceNo int64, fundingRate bool) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } orderbookUpdate := orderbook.Update{ Asset: assetType, Pair: p, @@ -1592,7 +1633,9 @@ func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pa if err = validateCRC32(ob, checkme.Token); err != nil { log.Errorf(log.WebsocketMgr, "%s websocket orderbook update error, will resubscribe orderbook: %v", b.Name, err) - b.resubOrderbook(c) + if e2 := b.resubOrderbook(c); e2 != nil { + log.Errorf(log.WebsocketMgr, "%s error resubscribing orderbook: %v", b.Name, e2) + } return err } } @@ -1603,8 +1646,15 @@ func (b *Bitfinex) WsUpdateOrderbook(c *subscription.Subscription, p currency.Pa // resubOrderbook resubscribes the orderbook after a consistency error, probably a failed checksum, // which forces a fresh snapshot. If we don't do this the orderbook will keep erroring and drifting. // Flushing the orderbook happens immediately, but the ReSub itself is a go routine to avoid blocking the WS data channel -func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) { - if err := b.Websocket.Orderbook.FlushOrderbook(c.Pair, c.Asset); err != nil { +func (b *Bitfinex) resubOrderbook(c *subscription.Subscription) error { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } + if err := b.Websocket.Orderbook.FlushOrderbook(c.Pairs[0], c.Asset); err != nil { + // Non-fatal error log.Errorf(log.ExchangeSys, "%s error flushing orderbook: %v", b.Name, err) } @@ -1645,7 +1695,7 @@ func (b *Bitfinex) GenerateDefaultSubscriptions() ([]subscription.Subscription, subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[j], - Pair: enabledPairs[k], + Pairs: enabledPairs[k], Params: params, Asset: assets[i], }) @@ -1684,7 +1734,7 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error { c := chans[0] req, err := subscribeReq(&c) if err != nil { - return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair) + return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs) } // subId is a single round-trip identifier that provides linking sub requests to chanIDs @@ -1699,7 +1749,7 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error { c.State = subscription.SubscribingState err = b.Websocket.AddSubscription(&c) if err != nil { - return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pair, err) + return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err) } // Always remove the temporary subscription keyed by subID @@ -1707,11 +1757,11 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error { respRaw, err := b.Websocket.Conn.SendMessageReturnResponse("subscribe:"+subID, req) if err != nil { - return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair) + return fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs) } if err = b.getErrResp(respRaw); err != nil { - wErr := fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pair) + wErr := fmt.Errorf("%w: %w; Channel: %s Pair: %s", stream.ErrSubscriptionFailure, err, c.Channel, c.Pairs) b.Websocket.DataHandler <- wErr return wErr } @@ -1721,6 +1771,13 @@ func (b *Bitfinex) subscribeToChan(chans []subscription.Subscription) error { // subscribeReq returns a map of request params for subscriptions func subscribeReq(c *subscription.Subscription) (map[string]interface{}, error) { + if c == nil { + return common.ErrNilPointer + } + if len(c.Pairs) != 1 { + return subscription.ErrNotSinglePair + } + pair := c.Pairs[0] req := map[string]interface{}{ "event": "subscribe", "channel": c.Channel, @@ -1743,13 +1800,13 @@ func subscribeReq(c *subscription.Subscription) (map[string]interface{}, error) prefix = "f" } - needsDelimiter := c.Pair.Len() > 6 + needsDelimiter := pair.Len() > 6 var formattedPair string if needsDelimiter { - formattedPair = c.Pair.Format(currency.PairFormat{Uppercase: true, Delimiter: ":"}).String() + formattedPair = pair.Format(currency.PairFormat{Uppercase: true, Delimiter: ":"}).String() } else { - formattedPair = currency.PairFormat{Uppercase: true}.Format(c.Pair) + formattedPair = currency.PairFormat{Uppercase: true}.Format(pair) } if c.Channel == wsCandles { diff --git a/exchanges/bitfinex/bitfinex_wrapper.go b/exchanges/bitfinex/bitfinex_wrapper.go index 7e1ef64c766..23c7446882d 100644 --- a/exchanges/bitfinex/bitfinex_wrapper.go +++ b/exchanges/bitfinex/bitfinex_wrapper.go @@ -198,7 +198,7 @@ func (b *Bitfinex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit @@ -643,8 +643,8 @@ func (b *Bitfinex) SubmitOrder(ctx context.Context, o *order.Submit) (*order.Sub var orderID string status := order.New if b.Websocket.CanUseAuthenticatedWebsocketForWrapper() { - symbolStr, err := b.fixCasing(fPair, o.AssetType) //nolint:govet // intentional shadow of err - if err != nil { + var symbolStr string + if symbolStr, err = b.fixCasing(fPair, o.AssetType); err != nil { return nil, err } orderType := strings.ToUpper(o.Type.String()) diff --git a/exchanges/bithumb/bithumb_websocket.go b/exchanges/bithumb/bithumb_websocket.go index 3005f42359c..990d9c767e4 100644 --- a/exchanges/bithumb/bithumb_websocket.go +++ b/exchanges/bithumb/bithumb_websocket.go @@ -2,7 +2,6 @@ package bithumb import ( "encoding/json" - "errors" "fmt" "net/http" "time" @@ -29,7 +28,7 @@ var ( // WsConnect initiates a websocket connection func (b *Bithumb) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer @@ -185,7 +184,7 @@ func (b *Bithumb) GenerateSubscriptions() ([]subscription.Subscription, error) { for y := range channels { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[y], - Pair: pairs[x].Format(pFmt), + Pairs: pairs[x].Format(pFmt), Asset: asset.Spot, }) } @@ -204,7 +203,7 @@ func (b *Bithumb) Subscribe(channelsToSubscribe []subscription.Subscription) err } subs[channelsToSubscribe[i].Channel] = s } - s.Symbols = append(s.Symbols, channelsToSubscribe[i].Pair) + s.Symbols = append(s.Symbols, channelsToSubscribe[i].Pairs) } tSub, ok := subs["ticker"] diff --git a/exchanges/bithumb/bithumb_wrapper.go b/exchanges/bithumb/bithumb_wrapper.go index 5dbc314674d..24ea8980b19 100644 --- a/exchanges/bithumb/bithumb_wrapper.go +++ b/exchanges/bithumb/bithumb_wrapper.go @@ -150,7 +150,7 @@ func (b *Bithumb) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout } diff --git a/exchanges/bitmex/bitmex_test.go b/exchanges/bitmex/bitmex_test.go index b59cb125812..e825ad8a07c 100644 --- a/exchanges/bitmex/bitmex_test.go +++ b/exchanges/bitmex/bitmex_test.go @@ -789,7 +789,7 @@ func TestGetDepositAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !b.Websocket.IsEnabled() && !b.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(b) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bitmex/bitmex_websocket.go b/exchanges/bitmex/bitmex_websocket.go index 6d04c106039..82c36a6e69c 100644 --- a/exchanges/bitmex/bitmex_websocket.go +++ b/exchanges/bitmex/bitmex_websocket.go @@ -68,7 +68,7 @@ const ( // WsConnect initiates a new websocket connection func (b *Bitmex) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -571,7 +571,7 @@ func (b *Bitmex) GenerateDefaultSubscriptions() ([]subscription.Subscription, er } subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[z] + ":" + pFmt.Format(contracts[y]), - Pair: contracts[y], + Pairs: contracts[y], Asset: assets[x], }) } @@ -621,7 +621,7 @@ func (b *Bitmex) GenerateAuthenticatedSubscriptions() ([]subscription.Subscripti for j := range contracts { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i] + ":" + pFmt.Format(contracts[j]), - Pair: contracts[j], + Pairs: contracts[j], Asset: asset.PerpetualContract, }) } diff --git a/exchanges/bitmex/bitmex_wrapper.go b/exchanges/bitmex/bitmex_wrapper.go index 151f1ef28e9..a0810b1a302 100644 --- a/exchanges/bitmex/bitmex_wrapper.go +++ b/exchanges/bitmex/bitmex_wrapper.go @@ -175,7 +175,7 @@ func (b *Bitmex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bitstamp/bitstamp_websocket.go b/exchanges/bitstamp/bitstamp_websocket.go index ab5465b574d..094ec21e785 100644 --- a/exchanges/bitstamp/bitstamp_websocket.go +++ b/exchanges/bitstamp/bitstamp_websocket.go @@ -45,7 +45,7 @@ var ( // WsConnect connects to a websocket feed func (b *Bitstamp) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -246,7 +246,7 @@ func (b *Bitstamp) generateDefaultSubscriptions() ([]subscription.Subscription, subscriptions = append(subscriptions, subscription.Subscription{ Channel: defaultSubChannels[j] + "_" + p.String(), Asset: asset.Spot, - Pair: p, + Pairs: currency.Pairs{p}, }) } if b.Websocket.CanUseAuthenticatedEndpoints() { @@ -254,7 +254,7 @@ func (b *Bitstamp) generateDefaultSubscriptions() ([]subscription.Subscription, subscriptions = append(subscriptions, subscription.Subscription{ Channel: defaultAuthSubChannels[j] + "_" + p.String(), Asset: asset.Spot, - Pair: p, + Pairs: currency.Pairs{p}, Params: map[string]interface{}{ "auth": struct{}{}, }, diff --git a/exchanges/bitstamp/bitstamp_wrapper.go b/exchanges/bitstamp/bitstamp_wrapper.go index 888fae0b395..2fe7fa96d82 100644 --- a/exchanges/bitstamp/bitstamp_wrapper.go +++ b/exchanges/bitstamp/bitstamp_wrapper.go @@ -146,7 +146,7 @@ func (b *Bitstamp) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btcmarkets/btcmarkets_websocket.go b/exchanges/btcmarkets/btcmarkets_websocket.go index 8c979b4d79b..f067ff868f5 100644 --- a/exchanges/btcmarkets/btcmarkets_websocket.go +++ b/exchanges/btcmarkets/btcmarkets_websocket.go @@ -39,7 +39,7 @@ var ( // WsConnect connects to a websocket feed func (b *BTCMarkets) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -336,7 +336,7 @@ func (b *BTCMarkets) generateDefaultSubscriptions() ([]subscription.Subscription for j := range enabledCurrencies { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i], - Pair: enabledCurrencies[j], + Pairs: enabledCurrencies[j], Asset: asset.Spot, }) } @@ -370,10 +370,10 @@ func (b *BTCMarkets) Subscribe(subs []subscription.Subscription) error { authenticate = true } payload.Channels = append(payload.Channels, subs[i].Channel) - if subs[i].Pair.IsEmpty() { + if subs[i].Pairs.IsEmpty() { continue } - pair := subs[i].Pair.String() + pair := subs[i].Pairs.String() if common.StringDataCompare(payload.MarketIDs, pair) { continue } @@ -415,11 +415,11 @@ func (b *BTCMarkets) Unsubscribe(subs []subscription.Subscription) error { } for i := range subs { payload.Channels = append(payload.Channels, subs[i].Channel) - if subs[i].Pair.IsEmpty() { + if subs[i].Pairs.IsEmpty() { continue } - pair := subs[i].Pair.String() + pair := subs[i].Pairs.String() if common.StringDataCompare(payload.MarketIDs, pair) { continue } @@ -439,7 +439,7 @@ func (b *BTCMarkets) Unsubscribe(subs []subscription.Subscription) error { func (b *BTCMarkets) ReSubscribeSpecificOrderbook(pair currency.Pair) error { sub := []subscription.Subscription{{ Channel: wsOB, - Pair: pair, + Pairs: pair, Asset: asset.Spot, }} if err := b.Unsubscribe(sub); err != nil { diff --git a/exchanges/btcmarkets/btcmarkets_wrapper.go b/exchanges/btcmarkets/btcmarkets_wrapper.go index 17ee7277c0a..8a924b08cf0 100644 --- a/exchanges/btcmarkets/btcmarkets_wrapper.go +++ b/exchanges/btcmarkets/btcmarkets_wrapper.go @@ -150,7 +150,7 @@ func (b *BTCMarkets) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/btse/btse_websocket.go b/exchanges/btse/btse_websocket.go index 41f25c95e35..0f294a71b2a 100644 --- a/exchanges/btse/btse_websocket.go +++ b/exchanges/btse/btse_websocket.go @@ -30,7 +30,7 @@ const ( // WsConnect connects the websocket client func (b *BTSE) WsConnect() error { if !b.Websocket.IsEnabled() || !b.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := b.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -377,7 +377,7 @@ func (b *BTSE) GenerateDefaultSubscriptions() ([]subscription.Subscription, erro for j := range pairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: fmt.Sprintf(channels[i], pairs[j]), - Pair: pairs[j], + Pairs: pairs[j], Asset: asset.Spot, }) } diff --git a/exchanges/btse/btse_wrapper.go b/exchanges/btse/btse_wrapper.go index f7ec3c8f702..3cdbd56b2c5 100644 --- a/exchanges/btse/btse_wrapper.go +++ b/exchanges/btse/btse_wrapper.go @@ -176,7 +176,7 @@ func (b *BTSE) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit b.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout b.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/bybit/bybit.go b/exchanges/bybit/bybit.go index 5652c464057..3588336ff52 100644 --- a/exchanges/bybit/bybit.go +++ b/exchanges/bybit/bybit.go @@ -21,7 +21,6 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/orderbook" "github.com/thrasher-corp/gocryptotrader/exchanges/request" - "github.com/thrasher-corp/gocryptotrader/exchanges/stream" ) // Bybit is the overarching type across this package @@ -90,7 +89,6 @@ var ( errAPIKeyIsNotUnified = errors.New("api key is not unified") errEndpointAvailableForNormalAPIKeyHolders = errors.New("endpoint available for normal API key holders only") errInvalidContractLength = errors.New("contract length cannot be less than or equal to zero") - errWebsocketNotEnabled = errors.New(stream.WebsocketNotEnabled) ) var ( diff --git a/exchanges/bybit/bybit_inverse_websocket.go b/exchanges/bybit/bybit_inverse_websocket.go index 77f387ace60..d1387c277f8 100644 --- a/exchanges/bybit/bybit_inverse_websocket.go +++ b/exchanges/bybit/bybit_inverse_websocket.go @@ -12,7 +12,7 @@ import ( // WsInverseConnect connects to inverse websocket feed func (by *Bybit) WsInverseConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.CoinMarginedFutures) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(inversePublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_linear_websocket.go b/exchanges/bybit/bybit_linear_websocket.go index efc2f68d1b8..9b3ed08426a 100644 --- a/exchanges/bybit/bybit_linear_websocket.go +++ b/exchanges/bybit/bybit_linear_websocket.go @@ -14,7 +14,7 @@ import ( // WsLinearConnect connects to linear a websocket feed func (by *Bybit) WsLinearConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.LinearContract) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(linearPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_options_websocket.go b/exchanges/bybit/bybit_options_websocket.go index 2f4abc7a76d..4bb25cef2a9 100644 --- a/exchanges/bybit/bybit_options_websocket.go +++ b/exchanges/bybit/bybit_options_websocket.go @@ -14,7 +14,7 @@ import ( // WsOptionsConnect connects to options a websocket feed func (by *Bybit) WsOptionsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Options) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } by.Websocket.Conn.SetURL(optionPublic) var dialer websocket.Dialer diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 58ca5bc46e8..092ca125462 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -20,6 +20,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/margin" "github.com/thrasher-corp/gocryptotrader/exchanges/order" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) @@ -3064,7 +3065,7 @@ func TestWsLinearConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsLinearConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3074,7 +3075,7 @@ func TestWsInverseConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsInverseConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } @@ -3084,7 +3085,7 @@ func TestWsOptionsConnect(t *testing.T) { t.Skip(skippingWebsocketFunctionsForMockTesting) } err := b.WsOptionsConnect() - if err != nil && !errors.Is(err, errWebsocketNotEnabled) { + if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { t.Error(err) } } diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 857fd690afe..ff2698b8667 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -57,7 +57,7 @@ const ( // WsConnect connects to a websocket feed func (by *Bybit) WsConnect() error { if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) { - return errWebsocketNotEnabled + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := by.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 28d4f15041d..b6e7f2be224 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -216,7 +216,7 @@ func (by *Bybit) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - by.Websocket = stream.New() + by.Websocket = stream.NewWebsocket() by.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit by.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout by.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinbasepro/coinbasepro_test.go b/exchanges/coinbasepro/coinbasepro_test.go index 77f058384d1..67915bf2924 100644 --- a/exchanges/coinbasepro/coinbasepro_test.go +++ b/exchanges/coinbasepro/coinbasepro_test.go @@ -681,7 +681,7 @@ func TestGetDepositAddress(t *testing.T) { // TestWsAuth dials websocket, sends login request. func TestWsAuth(t *testing.T) { if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -693,7 +693,7 @@ func TestWsAuth(t *testing.T) { err = c.Subscribe([]subscription.Subscription{ { Channel: "user", - Pair: testPair, + Pairs: testPair, }, }) if err != nil { diff --git a/exchanges/coinbasepro/coinbasepro_websocket.go b/exchanges/coinbasepro/coinbasepro_websocket.go index e4b02b764d8..961c75d7029 100644 --- a/exchanges/coinbasepro/coinbasepro_websocket.go +++ b/exchanges/coinbasepro/coinbasepro_websocket.go @@ -31,7 +31,7 @@ const ( // WsConnect initiates a websocket connection func (c *CoinbasePro) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -390,7 +390,7 @@ func (c *CoinbasePro) GenerateDefaultSubscriptions() ([]subscription.Subscriptio } subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i], - Pair: fPair, + Pairs: fPair, Asset: asset.Spot, }) } @@ -414,7 +414,7 @@ func (c *CoinbasePro) Subscribe(channelsToSubscribe []subscription.Subscription) } productIDs := make([]string, 0, len(channelsToSubscribe)) for i := range channelsToSubscribe { - p := channelsToSubscribe[i].Pair.String() + p := channelsToSubscribe[i].Pairs.String() if p != "" && !common.StringDataCompare(productIDs, p) { // get all unique productIDs in advance as we generate by channels productIDs = append(productIDs, p) @@ -466,7 +466,7 @@ func (c *CoinbasePro) Unsubscribe(channelsToUnsubscribe []subscription.Subscript } productIDs := make([]string, 0, len(channelsToUnsubscribe)) for i := range channelsToUnsubscribe { - p := channelsToUnsubscribe[i].Pair.String() + p := channelsToUnsubscribe[i].Pairs.String() if p != "" && !common.StringDataCompare(productIDs, p) { // get all unique productIDs in advance as we generate by channels productIDs = append(productIDs, p) diff --git a/exchanges/coinbasepro/coinbasepro_wrapper.go b/exchanges/coinbasepro/coinbasepro_wrapper.go index 47a20d9ed95..21a34e2ea22 100644 --- a/exchanges/coinbasepro/coinbasepro_wrapper.go +++ b/exchanges/coinbasepro/coinbasepro_wrapper.go @@ -145,7 +145,7 @@ func (c *CoinbasePro) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/coinut/coinut_test.go b/exchanges/coinut/coinut_test.go index 3431d60f730..0b165327846 100644 --- a/exchanges/coinut/coinut_test.go +++ b/exchanges/coinut/coinut_test.go @@ -66,7 +66,7 @@ func setupWSTestAuth(t *testing.T) { } if !c.Websocket.IsEnabled() && !c.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(c) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } if sharedtestvalues.AreAPICredentialsSet(c) { c.Websocket.SetCanUseAuthenticatedEndpoints(true) diff --git a/exchanges/coinut/coinut_websocket.go b/exchanges/coinut/coinut_websocket.go index 2453816e9ce..b1c1780fda6 100644 --- a/exchanges/coinut/coinut_websocket.go +++ b/exchanges/coinut/coinut_websocket.go @@ -41,7 +41,7 @@ var ( // WsConnect initiates a websocket connection func (c *COINUT) WsConnect() error { if !c.Websocket.IsEnabled() || !c.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := c.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -609,7 +609,7 @@ func (c *COINUT) GenerateDefaultSubscriptions() ([]subscription.Subscription, er for j := range enabledPairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i], - Pair: enabledPairs[j], + Pairs: enabledPairs[j], Asset: asset.Spot, }) } @@ -621,7 +621,7 @@ func (c *COINUT) GenerateDefaultSubscriptions() ([]subscription.Subscription, er func (c *COINUT) Subscribe(channelsToSubscribe []subscription.Subscription) error { var errs error for i := range channelsToSubscribe { - fPair, err := c.FormatExchangeCurrency(channelsToSubscribe[i].Pair, asset.Spot) + fPair, err := c.FormatExchangeCurrency(channelsToSubscribe[i].Pairs, asset.Spot) if err != nil { errs = common.AppendError(errs, err) continue @@ -650,7 +650,7 @@ func (c *COINUT) Subscribe(channelsToSubscribe []subscription.Subscription) erro func (c *COINUT) Unsubscribe(channelToUnsubscribe []subscription.Subscription) error { var errs error for i := range channelToUnsubscribe { - fPair, err := c.FormatExchangeCurrency(channelToUnsubscribe[i].Pair, asset.Spot) + fPair, err := c.FormatExchangeCurrency(channelToUnsubscribe[i].Pairs, asset.Spot) if err != nil { errs = common.AppendError(errs, err) continue diff --git a/exchanges/coinut/coinut_wrapper.go b/exchanges/coinut/coinut_wrapper.go index 503c2909461..db4af53badd 100644 --- a/exchanges/coinut/coinut_wrapper.go +++ b/exchanges/coinut/coinut_wrapper.go @@ -127,7 +127,7 @@ func (c *COINUT) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - c.Websocket = stream.New() + c.Websocket = stream.NewWebsocket() c.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit c.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout c.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 25931e59983..31b0d66a86c 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -1164,7 +1164,7 @@ func (b *Base) FlushWebsocketChannels() error { // SubscribeToWebsocketChannels appends to ChannelsToSubscribe // which lets websocket.manageSubscriptions handle subscribing -func (b *Base) SubscribeToWebsocketChannels(channels []subscription.Subscription) error { +func (b *Base) SubscribeToWebsocketChannels(channels []*subscription.Subscription) error { if b.Websocket == nil { return common.ErrFunctionNotSupported } @@ -1173,7 +1173,7 @@ func (b *Base) SubscribeToWebsocketChannels(channels []subscription.Subscription // UnsubscribeToWebsocketChannels removes from ChannelsToSubscribe // which lets websocket.manageSubscriptions handle unsubscribing -func (b *Base) UnsubscribeToWebsocketChannels(channels []subscription.Subscription) error { +func (b *Base) UnsubscribeToWebsocketChannels(channels []*subscription.Subscription) error { if b.Websocket == nil { return common.ErrFunctionNotSupported } @@ -1181,7 +1181,7 @@ func (b *Base) UnsubscribeToWebsocketChannels(channels []subscription.Subscripti } // GetSubscriptions returns a copied list of subscriptions -func (b *Base) GetSubscriptions() ([]subscription.Subscription, error) { +func (b *Base) GetSubscriptions() ([]*subscription.Subscription, error) { if b.Websocket == nil { return nil, common.ErrFunctionNotSupported } diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index d41f499d48d..7c0b3014020 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -198,7 +198,7 @@ func TestSetClientProxyAddress(t *testing.T) { Name: "rawr", Requester: requester} - newBase.Websocket = stream.New() + newBase.Websocket = stream.NewWebsocket() err = newBase.SetClientProxyAddress("") if err != nil { t.Error(err) @@ -1251,7 +1251,7 @@ func TestSetupDefaults(t *testing.T) { } // Test websocket support - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() b.Features.Supports.Websocket = true err = b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ @@ -1263,8 +1263,8 @@ func TestSetupDefaults(t *testing.T) { DefaultURL: "ws://something.com", RunningURL: "ws://something.com", Connector: func() error { return nil }, - GenerateSubscriptions: func() ([]subscription.Subscription, error) { return []subscription.Subscription{}, nil }, - Subscriber: func([]subscription.Subscription) error { return nil }, + GenerateSubscriptions: func() (subscription.List, error) { return subscription.List{}, nil }, + Subscriber: func(subscription.List) error { return nil }, }) if err != nil { t.Fatal(err) @@ -1596,7 +1596,7 @@ func TestIsWebsocketEnabled(t *testing.T) { t.Error("exchange doesn't support websocket") } - b.Websocket = stream.New() + b.Websocket = stream.NewWebsocket() err := b.Websocket.Setup(&stream.WebsocketSetup{ ExchangeConfig: &config.Exchange{ Enabled: true, @@ -3279,7 +3279,7 @@ func TestSetSubscriptionsFromConfig(t *testing.T) { Features: &config.FeaturesConfig{}, }, } - subs := []*subscription.Subscription{ + subs := subscription.List{ {Channel: subscription.CandlesChannel, Interval: kline.OneDay, Enabled: true}, } b.Features.Subscriptions = subs @@ -3287,7 +3287,7 @@ func TestSetSubscriptionsFromConfig(t *testing.T) { assert.ElementsMatch(t, subs, b.Config.Features.Subscriptions, "Config Subscriptions should be updated") assert.ElementsMatch(t, subs, b.Features.Subscriptions, "Subscriptions should be the same") - subs = []*subscription.Subscription{ + subs = subscription.List{ {Channel: subscription.OrderbookChannel, Interval: kline.OneDay, Enabled: true}, } b.Config.Features.Subscriptions = subs diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index c26d04afe7c..809835d7f8c 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -60,7 +60,7 @@ var fetchedCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsConnect initiates a websocket connection func (g *Gateio) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Spot) if err != nil { @@ -680,7 +680,7 @@ func (g *Gateio) GenerateDefaultSubscriptions() ([]subscription.Subscription, er subscriptions = append(subscriptions, subscription.Subscription{ Channel: channelsToSubscribe[i], - Pair: fpair.Upper(), + Pairs: fpair.Upper(), Asset: assetType, Params: params, }) @@ -738,8 +738,8 @@ func (g *Gateio) generatePayload(event string, channelsToSubscribe []subscriptio for i := range channelsToSubscribe { var auth *WsAuthInput timestamp := time.Now() - channelsToSubscribe[i].Pair.Delimiter = currency.UnderscoreDelimiter - params := []string{channelsToSubscribe[i].Pair.String()} + channelsToSubscribe[i].Pairs.Delimiter = currency.UnderscoreDelimiter + params := []string{channelsToSubscribe[i].Pairs.String()} switch channelsToSubscribe[i].Channel { case spotOrderbookChannel: interval, okay := channelsToSubscribe[i].Params["interval"].(kline.Interval) diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 3026efc0608..adea4d2fa6b 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -194,7 +194,7 @@ func (g *Gateio) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/gateio/gateio_ws_delivery_futures.go b/exchanges/gateio/gateio_ws_delivery_futures.go index ba5c64afe3d..cf57caecb0f 100644 --- a/exchanges/gateio/gateio_ws_delivery_futures.go +++ b/exchanges/gateio/gateio_ws_delivery_futures.go @@ -45,7 +45,7 @@ var fetchedFuturesCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsDeliveryFuturesConnect initiates a websocket connection for delivery futures account func (g *Gateio) WsDeliveryFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.DeliveryFutures) if err != nil { @@ -176,7 +176,7 @@ func (g *Gateio) GenerateDeliveryFuturesDefaultSubscriptions() ([]subscription.S } subscriptions = append(subscriptions, subscription.Subscription{ Channel: channelsToSubscribe[i], - Pair: fpair.Upper(), + Pairs: fpair.Upper(), Params: params, }) } @@ -246,7 +246,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib var auth *WsAuthInput timestamp := time.Now() var params []string - params = []string{channelsToSubscribe[i].Pair.String()} + params = []string{channelsToSubscribe[i].Pairs.String()} if g.Websocket.CanUseAuthenticatedEndpoints() { switch channelsToSubscribe[i].Channel { case futuresOrdersChannel, futuresUserTradesChannel, @@ -310,7 +310,7 @@ func (g *Gateio) generateDeliveryFuturesPayload(event string, channelsToSubscrib params = append(params, intervalString) } } - if strings.HasPrefix(channelsToSubscribe[i].Pair.Quote.Upper().String(), "USDT") { + if strings.HasPrefix(channelsToSubscribe[i].Pairs.Quote.Upper().String(), "USDT") { payloads[0] = append(payloads[0], WsInput{ ID: g.Websocket.Conn.GenerateMessageID(false), Event: event, diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index c0411a5816c..f2c01fc0ae9 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -64,7 +64,7 @@ var responseFuturesStream = make(chan stream.Response) // WsFuturesConnect initiates a websocket connection for futures account func (g *Gateio) WsFuturesConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Futures) if err != nil { @@ -156,7 +156,7 @@ func (g *Gateio) GenerateFuturesDefaultSubscriptions() ([]subscription.Subscript } subscriptions[count] = subscription.Subscription{ Channel: channelsToSubscribe[i], - Pair: fpair.Upper(), + Pairs: fpair.Upper(), Params: params, } count++ @@ -324,7 +324,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe []subs var auth *WsAuthInput timestamp := time.Now() var params []string - params = []string{channelsToSubscribe[i].Pair.String()} + params = []string{channelsToSubscribe[i].Pairs.String()} if g.Websocket.CanUseAuthenticatedEndpoints() { switch channelsToSubscribe[i].Channel { case futuresOrdersChannel, futuresUserTradesChannel, @@ -388,7 +388,7 @@ func (g *Gateio) generateFuturesPayload(event string, channelsToSubscribe []subs params = append(params, intervalString) } } - if strings.HasPrefix(channelsToSubscribe[i].Pair.Quote.Upper().String(), "USDT") { + if strings.HasPrefix(channelsToSubscribe[i].Pairs.Quote.Upper().String(), "USDT") { payloads[0] = append(payloads[0], WsInput{ ID: g.Websocket.Conn.GenerateMessageID(false), Event: event, diff --git a/exchanges/gateio/gateio_ws_option.go b/exchanges/gateio/gateio_ws_option.go index d5340f0c350..d69acf5b173 100644 --- a/exchanges/gateio/gateio_ws_option.go +++ b/exchanges/gateio/gateio_ws_option.go @@ -70,7 +70,7 @@ var fetchedOptionsCurrencyPairSnapshotOrderbook = make(map[string]bool) // WsOptionsConnect initiates a websocket connection to options websocket endpoints. func (g *Gateio) WsOptionsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } err := g.CurrencyPairs.IsAssetEnabled(asset.Options) if err != nil { @@ -165,7 +165,7 @@ getEnabledPairs: } subscriptions = append(subscriptions, subscription.Subscription{ Channel: channelsToSubscribe[i], - Pair: fpair.Upper(), + Pairs: fpair.Upper(), Params: params, }) } @@ -190,7 +190,7 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe []subs optionsUnderlyingPriceChannel, optionsUnderlyingCandlesticksChannel: var uly currency.Pair - uly, err = g.GetUnderlyingFromCurrencyPair(channelsToSubscribe[i].Pair) + uly, err = g.GetUnderlyingFromCurrencyPair(channelsToSubscribe[i].Pairs) if err != nil { return nil, err } @@ -198,8 +198,8 @@ func (g *Gateio) generateOptionsPayload(event string, channelsToSubscribe []subs case optionsBalancesChannel: // options.balance channel does not require underlying or contract default: - channelsToSubscribe[i].Pair.Delimiter = currency.UnderscoreDelimiter - params = append(params, channelsToSubscribe[i].Pair.String()) + channelsToSubscribe[i].Pairs.Delimiter = currency.UnderscoreDelimiter + params = append(params, channelsToSubscribe[i].Pairs.String()) } switch channelsToSubscribe[i].Channel { case optionsOrderbookChannel: diff --git a/exchanges/gemini/gemini_test.go b/exchanges/gemini/gemini_test.go index 6a3b35e1885..7477e78ff11 100644 --- a/exchanges/gemini/gemini_test.go +++ b/exchanges/gemini/gemini_test.go @@ -556,7 +556,7 @@ func TestWsAuth(t *testing.T) { if !g.Websocket.IsEnabled() && !g.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(g) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer go g.wsReadData() diff --git a/exchanges/gemini/gemini_websocket.go b/exchanges/gemini/gemini_websocket.go index 913c856ddd4..af3ca1ba4ba 100644 --- a/exchanges/gemini/gemini_websocket.go +++ b/exchanges/gemini/gemini_websocket.go @@ -39,7 +39,7 @@ var comms = make(chan stream.Response) // WsConnect initiates a websocket connection func (g *Gemini) WsConnect() error { if !g.Websocket.IsEnabled() || !g.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer @@ -80,7 +80,7 @@ func (g *Gemini) GenerateDefaultSubscriptions() ([]subscription.Subscription, er for y := range pairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[x], - Pair: pairs[y], + Pairs: pairs[y], Asset: asset.Spot, }) } @@ -100,10 +100,10 @@ func (g *Gemini) Subscribe(channelsToSubscribe []subscription.Subscription) erro var pairs currency.Pairs for x := range channelsToSubscribe { - if pairs.Contains(channelsToSubscribe[x].Pair, true) { + if pairs.Contains(channelsToSubscribe[x].Pairs, true) { continue } - pairs = append(pairs, channelsToSubscribe[x].Pair) + pairs = append(pairs, channelsToSubscribe[x].Pairs) } fmtPairs, err := g.FormatExchangeCurrencies(pairs, asset.Spot) @@ -144,10 +144,10 @@ func (g *Gemini) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) var pairs currency.Pairs for x := range channelsToUnsubscribe { - if pairs.Contains(channelsToUnsubscribe[x].Pair, true) { + if pairs.Contains(channelsToUnsubscribe[x].Pairs, true) { continue } - pairs = append(pairs, channelsToUnsubscribe[x].Pair) + pairs = append(pairs, channelsToUnsubscribe[x].Pairs) } fmtPairs, err := g.FormatExchangeCurrencies(pairs, asset.Spot) diff --git a/exchanges/gemini/gemini_wrapper.go b/exchanges/gemini/gemini_wrapper.go index fee75d6b1a2..d2d89eb5008 100644 --- a/exchanges/gemini/gemini_wrapper.go +++ b/exchanges/gemini/gemini_wrapper.go @@ -128,7 +128,7 @@ func (g *Gemini) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - g.Websocket = stream.New() + g.Websocket = stream.NewWebsocket() g.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit g.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout g.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/hitbtc/hitbtc_test.go b/exchanges/hitbtc/hitbtc_test.go index 3e629d0a654..68b8495167b 100644 --- a/exchanges/hitbtc/hitbtc_test.go +++ b/exchanges/hitbtc/hitbtc_test.go @@ -466,7 +466,7 @@ func setupWsAuth(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer diff --git a/exchanges/hitbtc/hitbtc_websocket.go b/exchanges/hitbtc/hitbtc_websocket.go index 705f584f01c..5b030b79740 100644 --- a/exchanges/hitbtc/hitbtc_websocket.go +++ b/exchanges/hitbtc/hitbtc_websocket.go @@ -34,7 +34,7 @@ const ( // WsConnect starts a new connection with the websocket API func (h *HitBTC) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -492,7 +492,7 @@ func (h *HitBTC) GenerateDefaultSubscriptions() ([]subscription.Subscription, er enabledCurrencies[j].Delimiter = "" subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[i], - Pair: fPair, + Pairs: fPair, Asset: asset.Spot, }) } @@ -509,8 +509,8 @@ func (h *HitBTC) Subscribe(channelsToSubscribe []subscription.Subscription) erro ID: h.Websocket.Conn.GenerateMessageID(false), } - if channelsToSubscribe[i].Pair.String() != "" { - subscribe.Params.Symbol = channelsToSubscribe[i].Pair.String() + if channelsToSubscribe[i].Pairs.String() != "" { + subscribe.Params.Symbol = channelsToSubscribe[i].Pairs.String() } if strings.EqualFold(channelsToSubscribe[i].Channel, "subscribeTrades") { subscribe.Params.Limit = 100 @@ -546,7 +546,7 @@ func (h *HitBTC) Unsubscribe(channelsToUnsubscribe []subscription.Subscription) Method: unsubscribeChannel, } - unsubscribe.Params.Symbol = channelsToUnsubscribe[i].Pair.String() + unsubscribe.Params.Symbol = channelsToUnsubscribe[i].Pairs.String() if strings.EqualFold(unsubscribeChannel, "unsubscribeTrades") { unsubscribe.Params.Limit = 100 } else if strings.EqualFold(unsubscribeChannel, "unsubscribeCandles") { diff --git a/exchanges/hitbtc/hitbtc_wrapper.go b/exchanges/hitbtc/hitbtc_wrapper.go index 3b65fe649b5..7bcf7ada335 100644 --- a/exchanges/hitbtc/hitbtc_wrapper.go +++ b/exchanges/hitbtc/hitbtc_wrapper.go @@ -147,7 +147,7 @@ func (h *HitBTC) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/huobi/huobi_test.go b/exchanges/huobi/huobi_test.go index 4aafc7b74a5..16106daf566 100644 --- a/exchanges/huobi/huobi_test.go +++ b/exchanges/huobi/huobi_test.go @@ -78,7 +78,7 @@ func setupWsTests(t *testing.T) { return } if !h.Websocket.IsEnabled() && !h.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(h) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } comms = make(chan WsMessage, sharedtestvalues.WebsocketChannelOverrideCapacity) go h.wsReadData() diff --git a/exchanges/huobi/huobi_websocket.go b/exchanges/huobi/huobi_websocket.go index f601b03dbc2..d9cdc84bad5 100644 --- a/exchanges/huobi/huobi_websocket.go +++ b/exchanges/huobi/huobi_websocket.go @@ -62,7 +62,7 @@ var comms = make(chan WsMessage) // WsConnect initiates a new websocket connection func (h *HUOBI) WsConnect() error { if !h.Websocket.IsEnabled() || !h.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := h.wsDial(&dialer) @@ -538,7 +538,7 @@ func (h *HUOBI) GenerateDefaultSubscriptions() ([]subscription.Subscription, err enabledCurrencies[j].Lower().String()) subscriptions = append(subscriptions, subscription.Subscription{ Channel: channel, - Pair: enabledCurrencies[j], + Pairs: enabledCurrencies[j], }) } } diff --git a/exchanges/huobi/huobi_wrapper.go b/exchanges/huobi/huobi_wrapper.go index 3d1d576b9d2..90d70491bd8 100644 --- a/exchanges/huobi/huobi_wrapper.go +++ b/exchanges/huobi/huobi_wrapper.go @@ -202,7 +202,7 @@ func (h *HUOBI) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - h.Websocket = stream.New() + h.Websocket = stream.NewWebsocket() h.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit h.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout h.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/interfaces.go b/exchanges/interfaces.go index b7cb5b5c3f6..1745c1492fb 100644 --- a/exchanges/interfaces.go +++ b/exchanges/interfaces.go @@ -71,9 +71,9 @@ type IBotExchange interface { EnableRateLimiter() error GetServerTime(ctx context.Context, ai asset.Item) (time.Time, error) GetWebsocket() (*stream.Websocket, error) - SubscribeToWebsocketChannels(channels []subscription.Subscription) error - UnsubscribeToWebsocketChannels(channels []subscription.Subscription) error - GetSubscriptions() ([]subscription.Subscription, error) + SubscribeToWebsocketChannels(channels []*subscription.Subscription) error + UnsubscribeToWebsocketChannels(channels []*subscription.Subscription) error + GetSubscriptions() ([]*subscription.Subscription, error) FlushWebsocketChannels() error AuthenticateWebsocket(ctx context.Context) error GetOrderExecutionLimits(a asset.Item, cp currency.Pair) (order.MinMaxLevel, error) diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index 263d70bf571..c99627fc3fb 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -1215,7 +1215,7 @@ func setupWsTests(t *testing.T) { return } if !k.Websocket.IsEnabled() && !k.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(k) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := k.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -1251,7 +1251,7 @@ func TestWebsocketSubscribe(t *testing.T) { err := k.Subscribe([]subscription.Subscription{ { Channel: defaultSubscribedChannels[0], - Pair: currency.NewPairWithDelimiter("XBT", "USD", "/"), + Pairs: currency.Pairs{currency.NewPairWithDelimiter("XBT", "USD", "/")}, }, }) if err != nil { diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index 787d52b2c31..5657c2d908c 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -87,7 +87,7 @@ var cancelOrdersStatus = make(map[int64]*struct { // WsConnect initiates a websocket connection func (k *Kraken) WsConnect() error { if !k.Websocket.IsEnabled() || !k.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer @@ -859,7 +859,7 @@ func (k *Kraken) wsProcessOrderBook(channelData *WebsocketChannelData, data map[ } }(&subscription.Subscription{ Channel: krakenWsOrderbook, - Pair: outbound, + Pairs: currency.Pairs{outbound}, Asset: asset.Spot, }) return err @@ -1221,7 +1221,7 @@ func (k *Kraken) GenerateDefaultSubscriptions() ([]subscription.Subscription, er enabledPairs[j].Delimiter = "/" subscriptions = append(subscriptions, subscription.Subscription{ Channel: defaultSubscribedChannels[i], - Pair: enabledPairs[j], + Pairs: currency.Pairs{enabledPairs[j]}, Asset: asset.Spot, }) } @@ -1248,7 +1248,7 @@ channels: } for j := range *s { - (*s)[j].Pairs = append((*s)[j].Pairs, channelsToSubscribe[i].Pair.String()) + (*s)[j].Pairs = append((*s)[j].Pairs, channelsToSubscribe[i].Pairs[0].String()) (*s)[j].Channels = append((*s)[j].Channels, channelsToSubscribe[i]) continue channels } @@ -1264,8 +1264,8 @@ channels: if channelsToSubscribe[i].Channel == "book" { outbound.Subscription.Depth = krakenWsOrderbookDepth } - if !channelsToSubscribe[i].Pair.IsEmpty() { - outbound.Pairs = []string{channelsToSubscribe[i].Pair.String()} + if !channelsToSubscribe[i].Pairs[0].IsEmpty() { + outbound.Pairs = []string{channelsToSubscribe[i].Pairs[0].String()} } if common.StringDataContains(authenticatedChannels, channelsToSubscribe[i].Channel) { outbound.Subscription.Token = authToken @@ -1306,7 +1306,7 @@ channels: for y := range unsubs { if unsubs[y].Subscription.Name == channelsToUnsubscribe[x].Channel { unsubs[y].Pairs = append(unsubs[y].Pairs, - channelsToUnsubscribe[x].Pair.String()) + channelsToUnsubscribe[x].Pairs[0].String()) unsubs[y].Channels = append(unsubs[y].Channels, channelsToUnsubscribe[x]) continue channels @@ -1326,7 +1326,7 @@ channels: unsub := WebsocketSubscriptionEventRequest{ Event: krakenWsUnsubscribe, - Pairs: []string{channelsToUnsubscribe[x].Pair.String()}, + Pairs: []string{channelsToUnsubscribe[x].Pairs[0].String()}, Subscription: WebsocketSubscriptionData{ Name: channelsToUnsubscribe[x].Channel, Depth: depth, diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index de631e0d904..5875917592c 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -209,7 +209,7 @@ func (k *Kraken) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - k.Websocket = stream.New() + k.Websocket = stream.NewWebsocket() k.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit k.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout k.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/kucoin/kucoin_test.go b/exchanges/kucoin/kucoin_test.go index 302af45eb0c..bc46ac1a241 100644 --- a/exchanges/kucoin/kucoin_test.go +++ b/exchanges/kucoin/kucoin_test.go @@ -2542,7 +2542,7 @@ func TestProcessMarketSnapshot(t *testing.T) { func TestSubscribeMarketSnapshot(t *testing.T) { t.Parallel() setupWS() - err := ku.Subscribe([]subscription.Subscription{{Channel: marketSymbolSnapshotChannel, Pair: currency.Pair{Base: currency.BTC}}}) + err := ku.Subscribe([]subscription.Subscription{{Channel: marketSymbolSnapshotChannel, Pairs: currency.Pair{Base: currency.BTC}}}) assert.NoError(t, err, "Subscribe to MarketSnapshot should not error") } diff --git a/exchanges/kucoin/kucoin_websocket.go b/exchanges/kucoin/kucoin_websocket.go index 3f8a0c783ca..917bc1a90e3 100644 --- a/exchanges/kucoin/kucoin_websocket.go +++ b/exchanges/kucoin/kucoin_websocket.go @@ -97,7 +97,7 @@ var ( // WsConnect creates a new websocket connection. func (ku *Kucoin) WsConnect() error { if !ku.Websocket.IsEnabled() || !ku.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } fetchedFuturesSnapshotOrderbook = map[string]bool{} var dialer websocket.Dialer diff --git a/exchanges/kucoin/kucoin_wrapper.go b/exchanges/kucoin/kucoin_wrapper.go index 8a98a77a02d..d767a2cf441 100644 --- a/exchanges/kucoin/kucoin_wrapper.go +++ b/exchanges/kucoin/kucoin_wrapper.go @@ -195,7 +195,7 @@ func (ku *Kucoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - ku.Websocket = stream.New() + ku.Websocket = stream.NewWebsocket() ku.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit ku.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout ku.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okcoin/okcoin_websocket.go b/exchanges/okcoin/okcoin_websocket.go index e714aba7d10..f0d0ada2b47 100644 --- a/exchanges/okcoin/okcoin_websocket.go +++ b/exchanges/okcoin/okcoin_websocket.go @@ -74,7 +74,7 @@ func isAuthenticatedChannel(channel string) bool { // WsConnect initiates a websocket connection func (o *Okcoin) WsConnect() error { if !o.Websocket.IsEnabled() || !o.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 @@ -584,7 +584,7 @@ func (o *Okcoin) wsProcessOrderbook(respRaw []byte, obChannel string) error { func (o *Okcoin) ReSubscribeSpecificOrderbook(obChannel string, p currency.Pair) error { subscription := []subscription.Subscription{{ Channel: obChannel, - Pair: p, + Pairs: p, }} if err := o.Unsubscribe(subscription); err != nil { return err @@ -801,7 +801,7 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er for p := range pairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[s], - Pair: pairs[p], + Pairs: pairs[p], }) } case wsStatus: @@ -836,7 +836,7 @@ func (o *Okcoin) GenerateDefaultSubscriptions() ([]subscription.Subscription, er for p := range pairs { subscriptions = append(subscriptions, subscription.Subscription{ Channel: channels[s], - Pair: pairs[p], + Pairs: pairs[p], Asset: asset.Spot, }) } @@ -891,8 +891,8 @@ func (o *Okcoin) handleSubscriptions(operation string, subs []subscription.Subsc if subs[i].Asset != asset.Empty { argument["instType"] = strings.ToUpper(subs[i].Asset.String()) } - if !subs[i].Pair.IsEmpty() { - argument["instId"] = subs[i].Pair.String() + if !subs[i].Pairs.IsEmpty() { + argument["instId"] = subs[i].Pairs.String() } if authenticatedChannelSubscription { authTemp.Arguments = append(authTemp.Arguments, argument) diff --git a/exchanges/okcoin/okcoin_wrapper.go b/exchanges/okcoin/okcoin_wrapper.go index 8519cdc1b11..af0655ef3f1 100644 --- a/exchanges/okcoin/okcoin_wrapper.go +++ b/exchanges/okcoin/okcoin_wrapper.go @@ -150,7 +150,7 @@ func (o *Okcoin) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - o.Websocket = stream.New() + o.Websocket = stream.NewWebsocket() o.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit o.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout o.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/okcoin/okcoin_ws_trade.go b/exchanges/okcoin/okcoin_ws_trade.go index cd0c7f86605..85b6102e24d 100644 --- a/exchanges/okcoin/okcoin_ws_trade.go +++ b/exchanges/okcoin/okcoin_ws_trade.go @@ -130,7 +130,7 @@ func (o *Okcoin) WsAmendMultipleOrder(args []AmendTradeOrderRequestParam) ([]Ame func (o *Okcoin) SendWebsocketRequest(operation string, data, result interface{}, authenticated bool) error { switch { case !o.Websocket.IsEnabled(): - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled case !o.Websocket.IsConnected(): return stream.ErrNotConnected case !o.Websocket.CanUseAuthenticatedEndpoints() && authenticated: diff --git a/exchanges/okx/okx_websocket.go b/exchanges/okx/okx_websocket.go index b4d211eec01..54bc6b0e486 100644 --- a/exchanges/okx/okx_websocket.go +++ b/exchanges/okx/okx_websocket.go @@ -216,7 +216,7 @@ const ( // WsConnect initiates a websocket connection func (ok *Okx) WsConnect() error { if !ok.Websocket.IsEnabled() || !ok.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer dialer.ReadBufferSize = 8192 diff --git a/exchanges/okx/okx_wrapper.go b/exchanges/okx/okx_wrapper.go index 0b5472b9e65..64e2e269877 100644 --- a/exchanges/okx/okx_wrapper.go +++ b/exchanges/okx/okx_wrapper.go @@ -190,7 +190,7 @@ func (ok *Okx) SetDefaults() { log.Errorln(log.ExchangeSys, err) } - ok.Websocket = stream.New() + ok.Websocket = stream.NewWebsocket() ok.WebsocketResponseMaxLimit = okxWebsocketResponseMaxLimit ok.WebsocketResponseCheckTimeout = okxWebsocketResponseMaxLimit ok.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/poloniex/poloniex_test.go b/exchanges/poloniex/poloniex_test.go index bab0ffd4a71..d74e2353d67 100644 --- a/exchanges/poloniex/poloniex_test.go +++ b/exchanges/poloniex/poloniex_test.go @@ -548,7 +548,7 @@ func TestGenerateNewAddress(t *testing.T) { func TestWsAuth(t *testing.T) { t.Parallel() if !p.Websocket.IsEnabled() && !p.API.AuthenticatedWebsocketSupport || !sharedtestvalues.AreAPICredentialsSet(p) { - t.Skip(stream.WebsocketNotEnabled) + t.Skip(stream.ErrWebsocketNotEnabled.Error()) } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) diff --git a/exchanges/poloniex/poloniex_websocket.go b/exchanges/poloniex/poloniex_websocket.go index 1be429a58ef..dd57a762016 100644 --- a/exchanges/poloniex/poloniex_websocket.go +++ b/exchanges/poloniex/poloniex_websocket.go @@ -55,7 +55,7 @@ var ( // WsConnect initiates a websocket connection func (p *Poloniex) WsConnect() error { if !p.Websocket.IsEnabled() || !p.IsEnabled() { - return errors.New(stream.WebsocketNotEnabled) + return stream.ErrWebsocketNotEnabled } var dialer websocket.Dialer err := p.Websocket.Conn.Dial(&dialer, http.Header{}) @@ -562,7 +562,7 @@ func (p *Poloniex) GenerateDefaultSubscriptions() ([]subscription.Subscription, enabledPairs[j].Delimiter = currency.UnderscoreDelimiter subscriptions = append(subscriptions, subscription.Subscription{ Channel: "orderbook", - Pair: enabledPairs[j], + Pairs: enabledPairs[j], Asset: asset.Spot, }) } @@ -599,7 +599,7 @@ channels: sub[i].Channel): subscriptionRequest.Channel = wsTickerDataID default: - subscriptionRequest.Channel = sub[i].Pair.String() + subscriptionRequest.Channel = sub[i].Pairs.String() } err := p.Websocket.Conn.SendJSONMessage(subscriptionRequest) @@ -646,7 +646,7 @@ channels: unsub[i].Channel): unsubscriptionRequest.Channel = wsTickerDataID default: - unsubscriptionRequest.Channel = unsub[i].Pair.String() + unsubscriptionRequest.Channel = unsub[i].Pairs.String() } err := p.Websocket.Conn.SendJSONMessage(unsubscriptionRequest) if err != nil { diff --git a/exchanges/poloniex/poloniex_wrapper.go b/exchanges/poloniex/poloniex_wrapper.go index 97eb9293d56..57f28fac98b 100644 --- a/exchanges/poloniex/poloniex_wrapper.go +++ b/exchanges/poloniex/poloniex_wrapper.go @@ -159,7 +159,7 @@ func (p *Poloniex) SetDefaults() { if err != nil { log.Errorln(log.ExchangeSys, err) } - p.Websocket = stream.New() + p.Websocket = stream.NewWebsocket() p.WebsocketResponseMaxLimit = exchange.DefaultWebsocketResponseMaxLimit p.WebsocketResponseCheckTimeout = exchange.DefaultWebsocketResponseCheckTimeout p.WebsocketOrderbookBufferLimit = exchange.DefaultWebsocketOrderbookBufferLimit diff --git a/exchanges/sharedtestvalues/sharedtestvalues.go b/exchanges/sharedtestvalues/sharedtestvalues.go index 60d51d82bfe..71792bbd977 100644 --- a/exchanges/sharedtestvalues/sharedtestvalues.go +++ b/exchanges/sharedtestvalues/sharedtestvalues.go @@ -57,7 +57,6 @@ func GetWebsocketStructChannelOverride() chan struct{} { // NewTestWebsocket returns a test websocket object func NewTestWebsocket() *stream.Websocket { return &stream.Websocket{ - Init: true, DataHandler: make(chan interface{}, WebsocketChannelOverrideCapacity), ToRoutine: make(chan interface{}, 1000), TrafficAlert: make(chan struct{}), @@ -166,10 +165,14 @@ func TestFixtureToDataHandler(t *testing.T, seed, e exchange.IBotExchange, fixtu assert.NoError(t, err, "Loading currency pairs should not error") b.Name = "fixture" - b.Websocket = &stream.Websocket{ - Wg: new(sync.WaitGroup), - DataHandler: make(chan interface{}, 128), + + if b.Websocket == nil { + b.Websocket = &stream.Websocket{ + Wg: new(sync.WaitGroup), + DataHandler: make(chan interface{}, 128), + } } + b.API.Endpoints = b.NewEndpoints() fixture, err := os.Open(fixturePath) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index aefc3400f60..d0029b208db 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -5,68 +5,68 @@ import ( "fmt" "net" "net/url" - "sync" "time" "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/log" ) const ( - defaultJobBuffer = 5000 - // defaultTrafficPeriod defines a period of pause for the traffic monitor, - // as there are periods with large incoming traffic alerts which requires a - // timer reset, this limits work on this routine to a more effective rate - // of check. - defaultTrafficPeriod = time.Second + jobBuffer = 5000 ) +// Public websocket errors var ( - // ErrSubscriptionNotFound defines an error when a subscription is not found - ErrSubscriptionNotFound = errors.New("subscription not found") - // ErrSubscribedAlready defines an error when a channel is already subscribed - ErrSubscribedAlready = errors.New("duplicate subscription") - // ErrSubscriptionFailure defines an error when a subscription fails - ErrSubscriptionFailure = errors.New("subscription failure") - // ErrSubscriptionNotSupported defines an error when a subscription channel is not supported by an exchange + ErrWebsocketNotEnabled = errors.New("websocket not enabled") + ErrSubscriptionFailure = errors.New("subscription failure") ErrSubscriptionNotSupported = errors.New("subscription channel not supported ") - // ErrUnsubscribeFailure defines an error when a unsubscribe fails - ErrUnsubscribeFailure = errors.New("unsubscribe failure") - // ErrChannelInStateAlready defines an error when a subscription channel is already in a new state - ErrChannelInStateAlready = errors.New("channel already in state") - // ErrAlreadyDisabled is returned when you double-disable the websocket - ErrAlreadyDisabled = errors.New("websocket already disabled") - // ErrNotConnected defines an error when websocket is not connected - ErrNotConnected = errors.New("websocket is not connected") + ErrUnsubscribeFailure = errors.New("unsubscribe failure") + ErrAlreadyDisabled = errors.New("websocket already disabled") + ErrNotConnected = errors.New("websocket is not connected") +) +// Private websocket errors +var ( errAlreadyRunning = errors.New("connection monitor is already running") errExchangeConfigIsNil = errors.New("exchange config is nil") + errExchangeConfigEmpty = errors.New("exchange config is empty") errWebsocketIsNil = errors.New("websocket is nil") errWebsocketSetupIsNil = errors.New("websocket setup is nil") errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") errConfigFeaturesIsNil = errors.New("exchange config features is nil") errDefaultURLIsEmpty = errors.New("default url is empty") errRunningURLIsEmpty = errors.New("running url cannot be empty") errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameUnset = errors.New("exchange config name unset") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") errClosedConnection = errors.New("use of closed network connection") errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") errNoSubscriptionsSupplied = errors.New("no subscriptions supplied") - errChannelAlreadySubscribed = errors.New("channel already subscribed") - errInvalidChannelState = errors.New("invalid Channel state") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") ) -var globalReporter Reporter +var ( + globalReporter Reporter + trafficCheckInterval = 100 * time.Millisecond +) // SetupGlobalReporter sets a reporter interface to be used // for all exchange requests @@ -74,17 +74,17 @@ func SetupGlobalReporter(r Reporter) { globalReporter = r } -// New initialises the websocket struct -func New() *Websocket { +// NewWebsocket initialises the websocket struct +func NewWebsocket() *Websocket { return &Websocket{ - Init: true, - DataHandler: make(chan interface{}, defaultJobBuffer), - ToRoutine: make(chan interface{}, defaultJobBuffer), - TrafficAlert: make(chan struct{}), + DataHandler: make(chan interface{}, jobBuffer), + ToRoutine: make(chan interface{}, jobBuffer), + ShutdownC: make(chan struct{}), + TrafficAlert: make(chan struct{}, 1), ReadMessageErrors: make(chan error), - Subscribe: make(chan []subscription.Subscription), - Unsubscribe: make(chan []subscription.Subscription), Match: NewMatch(), + subscriptions: subscription.NewStore(), + features: &protocol.Features{}, } } @@ -98,7 +98,10 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return errWebsocketSetupIsNil } - if !w.Init { + w.m.Lock() + defer w.m.Unlock() + + if w.IsInitialised() { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyInitialised) } @@ -107,7 +110,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { } if s.ExchangeConfig.Name == "" { - return errExchangeConfigNameUnset + return errExchangeConfigNameEmpty } w.exchangeName = s.ExchangeConfig.Name w.verbose = s.ExchangeConfig.Verbose @@ -120,7 +123,7 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { if s.ExchangeConfig.Features == nil { return fmt.Errorf("%s %w", w.exchangeName, errConfigFeaturesIsNil) } - w.enabled = s.ExchangeConfig.Features.Enabled.Websocket + w.setEnabled(s.ExchangeConfig.Features.Enabled.Websocket) if s.Connector == nil { return fmt.Errorf("%s %w", w.exchangeName, errWebsocketConnectorUnset) @@ -174,7 +177,6 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { w.trafficTimeout = s.ExchangeConfig.WebsocketTrafficTimeout w.ShutdownC = make(chan struct{}) - w.Wg = new(sync.WaitGroup) w.SetCanUseAuthenticatedEndpoints(s.ExchangeConfig.API.AuthenticatedWebsocketSupport) if err := w.Orderbook.Setup(s.ExchangeConfig, &s.OrderbookBufferConfig, w.DataHandler); err != nil { @@ -188,28 +190,30 @@ func (w *Websocket) Setup(s *WebsocketSetup) error { return fmt.Errorf("%s %w", w.exchangeName, errInvalidMaxSubscriptions) } w.MaxSubscriptionsPerConnection = s.MaxWebsocketSubscriptionsPerConnection + w.setState(disconnectedState) + return nil } // SetupNewConnection sets up an auth or unauth streaming connection func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { if w == nil { - return errors.New("setting up new connection error: websocket is nil") + return fmt.Errorf("%w: %w", errConnSetup, errWebsocketIsNil) } if c == (ConnectionSetup{}) { - return errors.New("setting up new connection error: websocket connection configuration empty") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigEmpty) } if w.exchangeName == "" { - return errors.New("setting up new connection error: exchange name not set, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errExchangeConfigNameEmpty) } if w.TrafficAlert == nil { - return errors.New("setting up new connection error: traffic alert is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errTrafficAlertNil) } if w.ReadMessageErrors == nil { - return errors.New("setting up new connection error: read message errors is nil, please call setup first") + return fmt.Errorf("%w: %w", errConnSetup, errReadMessageErrorsNil) } connectionURL := w.GetWebsocketURL() @@ -234,7 +238,7 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { Traffic: w.TrafficAlert, readMessageErrors: w.ReadMessageErrors, ShutdownC: w.ShutdownC, - Wg: w.Wg, + Wg: &w.Wg, Match: w.Match, RateLimit: c.RateLimit, Reporter: c.ConnectionLevelReporter, @@ -253,48 +257,41 @@ func (w *Websocket) SetupNewConnection(c ConnectionSetup) error { // function func (w *Websocket) Connect() error { if w.connector == nil { - return errors.New("websocket connect function not set, cannot continue") + return errNoConnectFunc } w.m.Lock() defer w.m.Unlock() if !w.IsEnabled() { - return errors.New(WebsocketNotEnabled) + return ErrWebsocketNotEnabled } if w.IsConnecting() { - return fmt.Errorf("%v Websocket already attempting to connect", - w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyReconnecting) } if w.IsConnected() { - return fmt.Errorf("%v Websocket already connected", - w.exchangeName) + return fmt.Errorf("%v %w", w.exchangeName, errAlreadyConnected) } - w.subscriptionMutex.Lock() - w.subscriptions = subscriptionMap{} - w.subscriptionMutex.Unlock() + if w.subscriptions == nil { + return common.ErrNilPointer + } + w.subscriptions.Clear() w.dataMonitor() w.trafficMonitor() - w.setConnectingStatus(true) + w.setState(connectingState) err := w.connector() if err != nil { - w.setConnectingStatus(false) - return fmt.Errorf("%v Error connecting %s", - w.exchangeName, err) + w.setState(disconnectedState) + return fmt.Errorf("%v Error connecting %w", w.exchangeName, err) } - w.setConnectedStatus(true) - w.setConnectingStatus(false) - w.setInit(true) + w.setState(connectedState) if !w.IsConnectionMonitorRunning() { err = w.connectionMonitor() if err != nil { - log.Errorf(log.WebsocketMgr, - "%s cannot start websocket connection monitor %v", - w.GetName(), - err) + log.Errorf(log.WebsocketMgr, "%s cannot start websocket connection monitor %v", w.GetName(), err) } } @@ -302,24 +299,20 @@ func (w *Websocket) Connect() error { if err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } - if len(subs) == 0 { - return nil - } - err = w.checkSubscriptions(subs) - if err != nil { - return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) - } - err = w.Subscriber(subs) - if err != nil { - return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) + if len(subs) != 0 { + if err := w.SubscribeToChannels(subs); err != nil { + return err + } } + return nil } // Disable disables the exchange websocket protocol +// Note that connectionMonitor will be responsible for shutting down the websocket after disabling func (w *Websocket) Disable() error { if !w.IsEnabled() { - return fmt.Errorf("%w for exchange '%s'", ErrAlreadyDisabled, w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrAlreadyDisabled) } w.setEnabled(false) @@ -329,8 +322,7 @@ func (w *Websocket) Disable() error { // Enable enables the exchange websocket protocol func (w *Websocket) Enable() error { if w.IsConnected() || w.IsEnabled() { - return fmt.Errorf("websocket is already enabled for exchange %s", - w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, errWebsocketAlreadyEnabled) } w.setEnabled(true) @@ -369,9 +361,7 @@ func (w *Websocket) dataMonitor() { case <-w.ShutdownC: return default: - log.Warnf(log.WebsocketMgr, - "%s exchange backlog in websocket processing detected", - w.exchangeName) + log.Warnf(log.WebsocketMgr, "%s exchange backlog in websocket processing detected", w.exchangeName) select { case w.ToRoutine <- d: case <-w.ShutdownC: @@ -388,34 +378,25 @@ func (w *Websocket) connectionMonitor() error { if w.checkAndSetMonitorRunning() { return errAlreadyRunning } - w.fieldMutex.RLock() delay := w.connectionMonitorDelay - w.fieldMutex.RUnlock() go func() { timer := time.NewTimer(delay) for { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: running connection monitor cycle\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: running connection monitor cycle", w.exchangeName) } if !w.IsEnabled() { if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connectionMonitor - websocket disabled, shutting down\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connectionMonitor - websocket disabled, shutting down", w.exchangeName) } if w.IsConnected() { - err := w.Shutdown() - if err != nil { + if err := w.Shutdown(); err != nil { log.Errorln(log.WebsocketMgr, err) } } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: connection monitor exiting\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: connection monitor exiting", w.exchangeName) } timer.Stop() w.setConnectionMonitorRunning(false) @@ -424,11 +405,8 @@ func (w *Websocket) connectionMonitor() error { select { case err := <-w.ReadMessageErrors: if IsDisconnectionError(err) { - w.setInit(false) - log.Warnf(log.WebsocketMgr, - "%v websocket has been disconnected. Reason: %v", - w.exchangeName, err) - w.setConnectedStatus(false) + log.Warnf(log.WebsocketMgr, "%v websocket has been disconnected. Reason: %v", w.exchangeName, err) + w.setState(disconnectedState) } w.DataHandler <- err @@ -459,21 +437,16 @@ func (w *Websocket) Shutdown() error { defer w.m.Unlock() if !w.IsConnected() { - return fmt.Errorf("%v websocket: cannot shutdown %w", - w.exchangeName, - ErrNotConnected) + return fmt.Errorf("%v %w: %w", w.exchangeName, errCannotShutdown, ErrNotConnected) } // TODO: Interrupt connection and or close connection when it is re-established. if w.IsConnecting() { - return fmt.Errorf("%v websocket: cannot shutdown, in the process of reconnection", - w.exchangeName) + return fmt.Errorf("%v %w: %w ", w.exchangeName, errCannotShutdown, errAlreadyReconnecting) } if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: shutting down websocket\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: shutting down websocket", w.exchangeName) } defer w.Orderbook.FlushBuffer() @@ -491,19 +464,15 @@ func (w *Websocket) Shutdown() error { } // flush any subscriptions from last connection if needed - w.subscriptionMutex.Lock() - w.subscriptions = subscriptionMap{} - w.subscriptionMutex.Unlock() + w.subscriptions.Clear() + + w.setState(disconnectedState) close(w.ShutdownC) w.Wg.Wait() w.ShutdownC = make(chan struct{}) - w.setConnectedStatus(false) - w.setConnectingStatus(false) if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: completed websocket shutdown\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: completed websocket shutdown", w.exchangeName) } return nil } @@ -511,11 +480,11 @@ func (w *Websocket) Shutdown() error { // FlushChannels flushes channel subscriptions when there is a pair/asset change func (w *Websocket) FlushChannels() error { if !w.IsEnabled() { - return fmt.Errorf("%s websocket: service not enabled", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrWebsocketNotEnabled) } if !w.IsConnected() { - return fmt.Errorf("%s websocket: service not connected", w.exchangeName) + return fmt.Errorf("%s %w", w.exchangeName, ErrNotConnected) } if w.features.Subscribe { @@ -551,9 +520,7 @@ func (w *Websocket) FlushChannels() error { if len(newsubs) != 0 { // Purge subscription list as there will be conflicts - w.subscriptionMutex.Lock() - w.subscriptions = subscriptionMap{} - w.subscriptionMutex.Unlock() + w.subscriptions.Clear() return w.SubscribeToChannels(newsubs) } return nil @@ -565,9 +532,9 @@ func (w *Websocket) FlushChannels() error { return w.Connect() } -// trafficMonitor uses a timer of WebsocketTrafficLimitTime and once it expires, -// it will reconnect if the TrafficAlert channel has not received any data. The -// trafficTimer will reset on each traffic alert +// trafficMonitor waits trafficCheckInterval before checking for a trafficAlert +// 1 slot buffer means that connection will only write to trafficAlert once per trafficCheckInterval to avoid read/write flood in high traffic +// Otherwise we Shutdown the connection after trafficTimeout, unless it's connecting. connectionMonitor is responsible for Connecting again func (w *Websocket) trafficMonitor() { if w.IsTrafficMonitorRunning() { return @@ -576,183 +543,121 @@ func (w *Websocket) trafficMonitor() { w.Wg.Add(1) go func() { - var trafficTimer = time.NewTimer(w.trafficTimeout) - pause := make(chan struct{}) + t := time.NewTimer(w.trafficTimeout) for { select { case <-w.ShutdownC: if w.verbose { - log.Debugf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown message received\n", - w.exchangeName) + log.Debugf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown message received", w.exchangeName) } - trafficTimer.Stop() + t.Stop() w.setTrafficMonitorRunning(false) w.Wg.Done() return - case <-w.TrafficAlert: - if !trafficTimer.Stop() { - select { - case <-trafficTimer.C: - default: + case <-time.After(trafficCheckInterval): + select { + case <-w.TrafficAlert: + if !t.Stop() { + <-t.C } + t.Reset(w.trafficTimeout) + default: + } + case <-t.C: + checkAgain := w.IsConnecting() + select { + case <-w.TrafficAlert: + checkAgain = true + default: + } + if checkAgain { + t.Reset(w.trafficTimeout) + break } - w.setConnectedStatus(true) - trafficTimer.Reset(w.trafficTimeout) - case <-trafficTimer.C: // Falls through when timer runs out if w.verbose { - log.Warnf(log.WebsocketMgr, - "%v websocket: has not received a traffic alert in %v. Reconnecting", - w.exchangeName, - w.trafficTimeout) + log.Warnf(log.WebsocketMgr, "%v websocket: has not received a traffic alert in %v. Reconnecting", w.exchangeName, w.trafficTimeout) } - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() // without this the w.Shutdown() call below will deadlock - if !w.IsConnecting() && w.IsConnected() { + w.setTrafficMonitorRunning(false) // Cannot defer lest Connect is called after Shutdown but before deferred call + w.Wg.Done() // Without this the w.Shutdown() call below will deadlock + if w.IsConnected() { err := w.Shutdown() if err != nil { - log.Errorf(log.WebsocketMgr, - "%v websocket: trafficMonitor shutdown err: %s", - w.exchangeName, err) + log.Errorf(log.WebsocketMgr, "%v websocket: trafficMonitor shutdown err: %s", w.exchangeName, err) } } - return } - - if w.IsConnected() { - // Routine pausing mechanism - go func(p chan<- struct{}) { - time.Sleep(defaultTrafficPeriod) - select { - case p <- struct{}{}: - default: - } - }(pause) - select { - case <-w.ShutdownC: - trafficTimer.Stop() - w.setTrafficMonitorRunning(false) - w.Wg.Done() - return - case <-pause: - } - } } }() } -func (w *Websocket) setConnectedStatus(b bool) { - w.fieldMutex.Lock() - w.connected = b - w.fieldMutex.Unlock() +func (w *Websocket) setState(s uint32) { + w.state.Store(s) } -// IsConnected returns status of connection -func (w *Websocket) IsConnected() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connected +// IsInitialised returns whether the websocket has been Setup() already +func (w *Websocket) IsInitialised() bool { + return w.state.Load() != uninitialisedState } -func (w *Websocket) setConnectingStatus(b bool) { - w.fieldMutex.Lock() - w.connecting = b - w.fieldMutex.Unlock() +// IsConnected returns whether the websocket is connected +func (w *Websocket) IsConnected() bool { + return w.state.Load() == connectedState } -// IsConnecting returns status of connecting +// IsConnecting returns whether the websocket is connecting func (w *Websocket) IsConnecting() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connecting + return w.state.Load() == connectingState } func (w *Websocket) setEnabled(b bool) { - w.fieldMutex.Lock() - w.enabled = b - w.fieldMutex.Unlock() + w.enabled.Store(b) } -// IsEnabled returns status of enabled +// IsEnabled returns whether the websocket is enabled func (w *Websocket) IsEnabled() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.enabled -} - -func (w *Websocket) setInit(b bool) { - w.fieldMutex.Lock() - w.Init = b - w.fieldMutex.Unlock() -} - -// IsInit returns status of init -func (w *Websocket) IsInit() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.Init + return w.enabled.Load() } func (w *Websocket) setTrafficMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.trafficMonitorRunning = b - w.fieldMutex.Unlock() + w.trafficMonitorRunning.Store(b) } // IsTrafficMonitorRunning returns status of the traffic monitor func (w *Websocket) IsTrafficMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.trafficMonitorRunning + return w.trafficMonitorRunning.Load() } func (w *Websocket) checkAndSetMonitorRunning() (alreadyRunning bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - if w.connectionMonitorRunning { - return true - } - w.connectionMonitorRunning = true - return false + return !w.connectionMonitorRunning.CompareAndSwap(false, true) } func (w *Websocket) setConnectionMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.connectionMonitorRunning = b - w.fieldMutex.Unlock() + w.connectionMonitorRunning.Store(b) } // IsConnectionMonitorRunning returns status of connection monitor func (w *Websocket) IsConnectionMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.connectionMonitorRunning + return w.connectionMonitorRunning.Load() } func (w *Websocket) setDataMonitorRunning(b bool) { - w.fieldMutex.Lock() - w.dataMonitorRunning = b - w.fieldMutex.Unlock() + w.dataMonitorRunning.Store(b) } // IsDataMonitorRunning returns status of data monitor func (w *Websocket) IsDataMonitorRunning() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.dataMonitorRunning + return w.dataMonitorRunning.Load() } // CanUseAuthenticatedWebsocketForWrapper Handles a common check to // verify whether a wrapper can use an authenticated websocket endpoint func (w *Websocket) CanUseAuthenticatedWebsocketForWrapper() bool { - if w.IsConnected() && w.CanUseAuthenticatedEndpoints() { - return true - } else if w.IsConnected() && !w.CanUseAuthenticatedEndpoints() { - log.Infof(log.WebsocketMgr, - WebsocketNotAuthenticatedUsingRest, - w.exchangeName) + if w.IsConnected() { + if w.CanUseAuthenticatedEndpoints() { + return true + } + log.Infof(log.WebsocketMgr, WebsocketNotAuthenticatedUsingRest, w.exchangeName) } return false } @@ -820,28 +725,22 @@ func (w *Websocket) GetWebsocketURL() string { // SetProxyAddress sets websocket proxy address func (w *Websocket) SetProxyAddress(proxyAddr string) error { + w.m.Lock() + if proxyAddr != "" { - _, err := url.ParseRequestURI(proxyAddr) - if err != nil { - return fmt.Errorf("%v websocket: cannot set proxy address error '%v'", - w.exchangeName, - err) + if _, err := url.ParseRequestURI(proxyAddr); err != nil { + w.m.Unlock() + return fmt.Errorf("%v websocket: cannot set proxy address: %w", w.exchangeName, err) } if w.proxyAddr == proxyAddr { - return fmt.Errorf("%v websocket: cannot set proxy address to the same address '%v'", - w.exchangeName, - w.proxyAddr) + w.m.Unlock() + return fmt.Errorf("%v websocket: %w '%v'", w.exchangeName, errSameProxyAddress, w.proxyAddr) } - log.Debugf(log.ExchangeSys, - "%s websocket: setting websocket proxy: %s\n", - w.exchangeName, - proxyAddr) + log.Debugf(log.ExchangeSys, "%s websocket: setting websocket proxy: %s", w.exchangeName, proxyAddr) } else { - log.Debugf(log.ExchangeSys, - "%s websocket: removing websocket proxy\n", - w.exchangeName) + log.Debugf(log.ExchangeSys, "%s websocket: removing websocket proxy", w.exchangeName) } if w.Conn != nil { @@ -852,15 +751,17 @@ func (w *Websocket) SetProxyAddress(proxyAddr string) error { } w.proxyAddr = proxyAddr - if w.IsInit() && w.IsEnabled() { - if w.IsConnected() { - err := w.Shutdown() - if err != nil { - return err - } + + if w.IsConnected() { + w.m.Unlock() + if err := w.Shutdown(); err != nil { + return err } return w.Connect() } + + w.m.Unlock() + return nil } @@ -876,179 +777,127 @@ func (w *Websocket) GetName() string { // GetChannelDifference finds the difference between the subscribed channels // and the new subscription list when pairs are disabled or enabled. -func (w *Websocket) GetChannelDifference(genSubs []subscription.Subscription) (sub, unsub []subscription.Subscription) { - w.subscriptionMutex.RLock() - unsubMap := make(map[any]subscription.Subscription, len(w.subscriptions)) - for k, c := range w.subscriptions { - unsubMap[k] = *c - } - w.subscriptionMutex.RUnlock() - - for i := range genSubs { - key := genSubs[i].EnsureKeyed() - if _, ok := unsubMap[key]; ok { - delete(unsubMap, key) // If it's in both then we remove it from the unsubscribe list - } else { - sub = append(sub, genSubs[i]) // If it's in genSubs but not existing subs we want to subscribe - } - } - - for x := range unsubMap { - unsub = append(unsub, unsubMap[x]) +func (w *Websocket) GetChannelDifference(newSubs subscription.List) (sub, unsub subscription.List) { + if w.subscriptions == nil { + w.subscriptions = subscription.NewStore() } - - return + return w.subscriptions.Diff(newSubs) } -// UnsubscribeChannels unsubscribes from a websocket channel -func (w *Websocket) UnsubscribeChannels(channels []subscription.Subscription) error { +// UnsubscribeChannels unsubscribes from a list of websocket channel +func (w *Websocket) UnsubscribeChannels(channels subscription.List) error { if len(channels) == 0 { return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoSubscriptionsSupplied) } - w.subscriptionMutex.RLock() - - for i := range channels { - key := channels[i].EnsureKeyed() - if _, ok := w.subscriptions[key]; !ok { - w.subscriptionMutex.RUnlock() - return fmt.Errorf("%s websocket: %w: %+v", w.exchangeName, ErrSubscriptionNotFound, channels[i]) + if w.subscriptions == nil { + return common.ErrNilPointer + } + for _, s := range channels { + if w.subscriptions.Get(s) == nil { + return fmt.Errorf("%s websocket: %w: %s", w.exchangeName, subscription.ErrNotFound, s) } } - w.subscriptionMutex.RUnlock() return w.Unsubscriber(channels) } // ResubscribeToChannel resubscribes to channel -func (w *Websocket) ResubscribeToChannel(subscribedChannel *subscription.Subscription) error { - err := w.UnsubscribeChannels([]subscription.Subscription{*subscribedChannel}) +func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error { + l := subscription.List{s} + err := w.UnsubscribeChannels(l) if err != nil { return err } - return w.SubscribeToChannels([]subscription.Subscription{*subscribedChannel}) + return w.SubscribeToChannels(l) } -// SubscribeToChannels appends supplied channels to channelsToSubscribe -func (w *Websocket) SubscribeToChannels(channels []subscription.Subscription) error { - if err := w.checkSubscriptions(channels); err != nil { +// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method +// Errors are returned for duplicates or exceeding max Subscriptions +func (w *Websocket) SubscribeToChannels(subs subscription.List) error { + if err := w.checkSubscriptions(subs); err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } - if err := w.Subscriber(channels); err != nil { + if err := w.Subscriber(subs); err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } return nil } -// AddSubscription adds a subscription to the subscription lists -// Unlike AddSubscriptions this method will error if the subscription already exists -func (w *Websocket) AddSubscription(c *subscription.Subscription) error { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() - if w.subscriptions == nil { - w.subscriptions = subscriptionMap{} +// AddSubscription adds a subscription to the subscription store +func (w *Websocket) AddSubscription(s *subscription.Subscription) error { + if w == nil || s == nil { + return common.ErrNilPointer } - key := c.EnsureKeyed() - if _, ok := w.subscriptions[key]; ok { - return ErrSubscribedAlready + if w.subscriptions == nil { + w.subscriptions = subscription.NewStore() } - - n := *c // Fresh copy; we don't want to use the pointer we were given and allow encapsulation/locks to be bypassed - w.subscriptions[key] = &n - - return nil + return w.subscriptions.Add(s) } -// SetSubscriptionState sets an existing subscription state -// returns an error if the subscription is not found, or the new state is already set -func (w *Websocket) SetSubscriptionState(c *subscription.Subscription, state subscription.State) error { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() - if w.subscriptions == nil { - w.subscriptions = subscriptionMap{} - } - key := c.EnsureKeyed() - p, ok := w.subscriptions[key] - if !ok { - return ErrSubscriptionNotFound +// AddSubscriptions adds subscriptions to the subscription store +func (w *Websocket) AddSubscriptions(subs subscription.List) error { + if w == nil { + return common.ErrNilPointer } - if state == p.State { - return ErrChannelInStateAlready + if w.subscriptions == nil { + w.subscriptions = subscription.NewStore() } - if state > subscription.UnsubscribingState { - return errInvalidChannelState + var errs error + for _, s := range subs { + if err := w.subscriptions.Add(s); err != nil { + errs = common.AppendError(errs, err) + } } - p.State = state - return nil + return errs } -// AddSuccessfulSubscriptions adds subscriptions to the subscription lists that -// has been successfully subscribed -func (w *Websocket) AddSuccessfulSubscriptions(channels ...subscription.Subscription) { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() - if w.subscriptions == nil { - w.subscriptions = subscriptionMap{} - } - for _, cN := range channels { //nolint:gocritic // See below comment - c := cN // cN is an iteration var; Not safe to make a pointer to - key := c.EnsureKeyed() - c.State = subscription.SubscribedState - w.subscriptions[key] = &c +// RemoveSubscription removes a subscription from the subscription store +func (w *Websocket) RemoveSubscription(s *subscription.Subscription) error { + if w == nil || w.subscriptions == nil || s == nil { + return common.ErrNilPointer } + return w.subscriptions.Remove(s) } // RemoveSubscriptions removes subscriptions from the subscription list -func (w *Websocket) RemoveSubscriptions(channels ...subscription.Subscription) { - w.subscriptionMutex.Lock() - defer w.subscriptionMutex.Unlock() - if w.subscriptions == nil { - w.subscriptions = subscriptionMap{} - } - for i := range channels { - key := channels[i].EnsureKeyed() - delete(w.subscriptions, key) +func (w *Websocket) RemoveSubscriptions(subs subscription.List) error { + if w == nil || w.subscriptions == nil { + return common.ErrNilPointer + } + var errs error + for _, s := range subs { + if err := w.subscriptions.Remove(s); err != nil { + errs = common.AppendError(errs, err) + } } + return errs } -// GetSubscription returns a pointer to a copy of the subscription at the key provided +// GetSubscription returns a subscription at the key provided // returns nil if no subscription is at that key or the key is nil +// Keys can implement subscription.MatchableKey in order to provide custom matching logic func (w *Websocket) GetSubscription(key any) *subscription.Subscription { - if key == nil || w == nil || w.subscriptions == nil { + if w == nil || w.subscriptions == nil || key == nil { return nil } - w.subscriptionMutex.RLock() - defer w.subscriptionMutex.RUnlock() - if s, ok := w.subscriptions[key]; ok { - c := *s - return &c - } - return nil + return w.subscriptions.Get(key) } // GetSubscriptions returns a new slice of the subscriptions -func (w *Websocket) GetSubscriptions() []subscription.Subscription { - w.subscriptionMutex.RLock() - defer w.subscriptionMutex.RUnlock() - subs := make([]subscription.Subscription, 0, len(w.subscriptions)) - for _, c := range w.subscriptions { - subs = append(subs, *c) - } - return subs +func (w *Websocket) GetSubscriptions() subscription.List { + if w == nil || w.subscriptions == nil { + return nil + } + return w.subscriptions.List() } -// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in -// a thread safe manner -func (w *Websocket) SetCanUseAuthenticatedEndpoints(val bool) { - w.fieldMutex.Lock() - defer w.fieldMutex.Unlock() - w.canUseAuthenticatedEndpoints = val +// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in a thread safe manner +func (w *Websocket) SetCanUseAuthenticatedEndpoints(b bool) { + w.canUseAuthenticatedEndpoints.Store(b) } -// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in -// a thread safe manner +// CanUseAuthenticatedEndpoints gets canUseAuthenticatedEndpoints val in a thread safe manner func (w *Websocket) CanUseAuthenticatedEndpoints() bool { - w.fieldMutex.RLock() - defer w.fieldMutex.RUnlock() - return w.canUseAuthenticatedEndpoints + return w.canUseAuthenticatedEndpoints.Load() } // IsDisconnectionError Determines if the error sent over chan ReadMessageErrors is a disconnection error @@ -1074,28 +923,28 @@ func checkWebsocketURL(s string) error { return nil } -// checkSubscriptions checks subscriptions against the max subscription limit -// and if the subscription already exists. -func (w *Websocket) checkSubscriptions(subs []subscription.Subscription) error { +// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists +// The subscription state is not considered when counting existing subscriptions +func (w *Websocket) checkSubscriptions(subs subscription.List) error { if len(subs) == 0 { return errNoSubscriptionsSupplied } + if w.subscriptions == nil { + return common.ErrNilPointer + } - w.subscriptionMutex.RLock() - defer w.subscriptionMutex.RUnlock() - - if w.MaxSubscriptionsPerConnection > 0 && len(w.subscriptions)+len(subs) > w.MaxSubscriptionsPerConnection { + existing := w.subscriptions.Len() + if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection { return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs", errSubscriptionsExceedsLimit, - len(w.subscriptions), + existing, len(subs), w.MaxSubscriptionsPerConnection) } - for i := range subs { - key := subs[i].EnsureKeyed() - if _, ok := w.subscriptions[key]; ok { - return fmt.Errorf("%w for %+v", errChannelAlreadySubscribed, subs[i]) + for _, s := range subs { + if found := w.subscriptions.Get(s); found != nil { + return fmt.Errorf("%w for %s", subscription.ErrDuplicate, s) } } diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 0bb1e660412..4d7681f8d13 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -50,9 +50,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(signature, request inter return payload, nil case <-timer.C: timer.Stop() - return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", - w.ExchangeName, - signature) + return nil, fmt.Errorf("%s websocket connection: timeout waiting for response with signature: %v", w.ExchangeName, signature) } } @@ -72,25 +70,14 @@ func (w *WebsocketConnection) Dial(dialer *websocket.Dialer, headers http.Header w.Connection, conStatus, err = dialer.Dial(w.URL, headers) if err != nil { if conStatus != nil { - return fmt.Errorf("%s websocket connection: %v %v %v Error: %v", - w.ExchangeName, - w.URL, - conStatus, - conStatus.StatusCode, - err) + return fmt.Errorf("%s websocket connection: %v %v %v Error: %w", w.ExchangeName, w.URL, conStatus, conStatus.StatusCode, err) } - return fmt.Errorf("%s websocket connection: %v Error: %v", - w.ExchangeName, - w.URL, - err) + return fmt.Errorf("%s websocket connection: %v Error: %w", w.ExchangeName, w.URL, err) } defer conStatus.Body.Close() if w.Verbose { - log.Infof(log.WebsocketMgr, - "%v Websocket connected to %s\n", - w.ExchangeName, - w.URL) + log.Infof(log.WebsocketMgr, "%v Websocket connected to %s\n", w.ExchangeName, w.URL) } select { case w.Traffic <- struct{}{}: @@ -240,7 +227,7 @@ func (w *WebsocketConnection) ReadMessage() Response { select { case w.Traffic <- struct{}{}: - default: // causes contention, just bypass if there is no receiver. + default: // Non-Blocking write ensures 1 buffered signal per trafficCheckInterval to avoid flooding } var standardMessage []byte @@ -285,7 +272,7 @@ func (w *WebsocketConnection) parseBinaryResponse(resp []byte) ([]byte, error) { return standardMessage, reader.Close() } -// GenerateMessageID Creates a messageID to checkout +// GenerateMessageID Creates a random message ID func (w *WebsocketConnection) GenerateMessageID(highPrec bool) int64 { var min int64 = 1e8 var max int64 = 2e8 diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 3ab49e0df10..3c370fd3676 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -9,6 +9,7 @@ import ( "fmt" "net" "net/http" + "os" "sort" "strconv" "strings" @@ -18,6 +19,8 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -30,6 +33,10 @@ const ( proxyURL = "http://212.186.171.4:80" // Replace with a usable proxy server ) +var ( + errDastardlyReason = errors.New("some dastardly reason") +) + var dialer websocket.Dialer type testStruct struct { @@ -68,15 +75,15 @@ var defaultSetup = &WebsocketSetup{ AuthenticatedWebsocketSupport: true, }, WebsocketTrafficTimeout: time.Second * 5, - Name: "exchangeName", + Name: "GTX", }, DefaultURL: "testDefaultURL", RunningURL: "wss://testRunningURL", Connector: func() error { return nil }, - Subscriber: func([]subscription.Subscription) error { return nil }, - Unsubscriber: func([]subscription.Subscription) error { return nil }, - GenerateSubscriptions: func() ([]subscription.Subscription, error) { - return []subscription.Subscription{ + Subscriber: func(subscription.List) error { return nil }, + Unsubscriber: func(subscription.List) error { return nil }, + GenerateSubscriptions: func() (subscription.List, error) { + return subscription.List{ {Channel: "TestSub"}, {Channel: "TestSub2", Key: "purple"}, {Channel: "TestSub3", Key: testSubKey{"mauve"}}, @@ -92,440 +99,396 @@ type dodgyConnection struct { // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Shutdown() error { - return errors.New("cannot shutdown due to some dastardly reason") + return fmt.Errorf("%w: %w", errCannotShutdown, errDastardlyReason) } // override websocket connection method to produce a wicked terrible error func (d *dodgyConnection) Connect() error { - return errors.New("cannot connect due to some dastardly reason") + return fmt.Errorf("cannot connect: %w", errDastardlyReason) +} + +func TestMain(m *testing.M) { + // Change trafficCheckInterval for TestTrafficMonitorTimeout before parallel tests to avoid racing + trafficCheckInterval = 50 * time.Millisecond + os.Exit(m.Run()) } func TestSetup(t *testing.T) { t.Parallel() var w *Websocket err := w.Setup(nil) - if !errors.Is(err, errWebsocketIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketIsNil) - } + assert.ErrorIs(t, err, errWebsocketIsNil, "Setup should error correctly") w = &Websocket{DataHandler: make(chan interface{})} err = w.Setup(nil) - if !errors.Is(err, errWebsocketSetupIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSetupIsNil) - } + assert.ErrorIs(t, err, errWebsocketSetupIsNil, "Setup should error correctly") websocketSetup := &WebsocketSetup{} - err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } - w.Init = true err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigIsNil) - } + assert.ErrorIs(t, err, errExchangeConfigIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig = &config.Exchange{} err = w.Setup(websocketSetup) - if !errors.Is(err, errExchangeConfigNameUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errExchangeConfigNameUnset) - } - websocketSetup.ExchangeConfig.Name = "testname" + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "Setup should error correctly") + websocketSetup.ExchangeConfig.Name = "testname" err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketFeaturesIsUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketFeaturesIsUnset) - } + assert.ErrorIs(t, err, errWebsocketFeaturesIsUnset, "Setup should error correctly") websocketSetup.Features = &protocol.Features{} err = w.Setup(websocketSetup) - if !errors.Is(err, errConfigFeaturesIsNil) { - t.Fatalf("received: '%v' but expected: '%v'", err, errConfigFeaturesIsNil) - } + assert.ErrorIs(t, err, errConfigFeaturesIsNil, "Setup should error correctly") websocketSetup.ExchangeConfig.Features = &config.FeaturesConfig{} err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketConnectorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketConnectorUnset) - } + assert.ErrorIs(t, err, errWebsocketConnectorUnset, "Setup should error correctly") websocketSetup.Connector = func() error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriberUnset, "Setup should error correctly") - websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil } + websocketSetup.Subscriber = func(subscription.List) error { return nil } websocketSetup.Features.Unsubscribe = true err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketUnsubscriberUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketUnsubscriberUnset) - } + assert.ErrorIs(t, err, errWebsocketUnsubscriberUnset, "Setup should error correctly") - websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil } + websocketSetup.Unsubscriber = func(subscription.List) error { return nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errWebsocketSubscriptionsGeneratorUnset) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriptionsGeneratorUnset) - } + assert.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset, "Setup should error correctly") - websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil } + websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } err = w.Setup(websocketSetup) - if !errors.Is(err, errDefaultURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errDefaultURLIsEmpty) - } + assert.ErrorIs(t, err, errDefaultURLIsEmpty, "Setup should error correctly") websocketSetup.DefaultURL = "test" err = w.Setup(websocketSetup) - if !errors.Is(err, errRunningURLIsEmpty) { - t.Fatalf("received: '%v' but expected: '%v'", err, errRunningURLIsEmpty) - } + assert.ErrorIs(t, err, errRunningURLIsEmpty, "Setup should error correctly") websocketSetup.RunningURL = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURL = "wss://www.google.com" websocketSetup.RunningURLAuth = "http://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidWebsocketURL) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidWebsocketURL) - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "Setup should error correctly") websocketSetup.RunningURLAuth = "wss://www.google.com" err = w.Setup(websocketSetup) - if !errors.Is(err, errInvalidTrafficTimeout) { - t.Fatalf("received: '%v' but expected: '%v'", err, errInvalidTrafficTimeout) - } + assert.ErrorIs(t, err, errInvalidTrafficTimeout, "Setup should error correctly") websocketSetup.ExchangeConfig.WebsocketTrafficTimeout = time.Minute err = w.Setup(websocketSetup) - if !errors.Is(err, nil) { - t.Fatalf("received: %v but expected: %v", err, nil) - } + assert.NoError(t, err, "Setup should not error") } -func TestTrafficMonitorTimeout(t *testing.T) { +// TestTrafficMonitorTrafficAlerts ensures multiple traffic alerts work and only process one trafficAlert per interval +// ensures shutdown works after traffic alerts +func TestTrafficMonitorTrafficAlerts(t *testing.T) { t.Parallel() - ws := *New() - if err := ws.Setup(defaultSetup); err != nil { - t.Fatal(err) - } - ws.trafficTimeout = time.Second * 2 - ws.ShutdownC = make(chan struct{}) + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + signal := struct{}{} + patience := 10 * time.Millisecond + ws.trafficTimeout = 200 * time.Millisecond + ws.state.Store(connectedState) + + thenish := time.Now() ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") + + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, connectedState, ws.state.Load(), "websocket must be connected") + + for i := 0; i < 6; i++ { // Timeout will happen at 200ms so we want 6 * 50ms checks to pass + select { + case ws.TrafficAlert <- signal: + if i == 0 { + require.WithinDurationf(t, time.Now(), thenish, trafficCheckInterval, "First Non-blocking test must happen before the traffic is checked") + } + default: + require.Failf(t, "", "TrafficAlert should not block; Check #%d", i) + } + + select { + case ws.TrafficAlert <- signal: + require.Failf(t, "", "TrafficAlert should block after first slot used; Check #%d", i) + default: + if i == 0 { + require.WithinDuration(t, time.Now(), thenish, trafficCheckInterval, "First Blocking test must happen before the traffic is checked") + } + } + + require.EventuallyWithTf(t, func(c *assert.CollectT) { + assert.Truef(c, ws.IsConnected(), "state should still be connected; Check #%d", i) + assert.Emptyf(c, ws.TrafficAlert, "trafficAlert channel should be drained; Check #%d", i) + }, 2*trafficCheckInterval, patience, "trafficAlert should be read; Check #%d", i) } - // Deploy traffic alert - ws.TrafficAlert <- struct{}{} - // try to add another traffic monitor + + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") + }, 2*ws.trafficTimeout, patience, "trafficTimeout should trigger a shutdown once we stop feeding trafficAlerts") +} + +// TestTrafficMonitorConnecting ensures connecting status doesn't trigger shutdown +func TestTrafficMonitorConnecting(t *testing.T) { + t.Parallel() + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") + + ws.state.Store(connectingState) + ws.trafficTimeout = 50 * time.Millisecond ws.trafficMonitor() - if !ws.IsTrafficMonitorRunning() { - t.Fatal("traffic monitor should be running") - } - // prevent shutdown routine - ws.setConnectedStatus(false) - // await timeout closure - ws.Wg.Wait() - if ws.IsTrafficMonitorRunning() { - t.Error("should be dead") - } + require.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + require.Equal(t, connectingState, ws.state.Load(), "websocket must be connecting") + <-time.After(4 * ws.trafficTimeout) + require.Equal(t, connectingState, ws.state.Load(), "websocket must still be connecting after several checks") + ws.state.Store(connectedState) + require.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Equal(c, disconnectedState, ws.state.Load(), "websocket must be disconnected") + assert.False(c, ws.IsTrafficMonitorRunning(), "trafficMonitor should be shut down") + }, 4*ws.trafficTimeout, 10*time.Millisecond, "trafficTimeout should trigger a shutdown after connecting status changes") } -func TestIsDisconnectionError(t *testing.T) { +// TestTrafficMonitorShutdown ensures shutdown is processed and waitgroup is cleared +func TestTrafficMonitorShutdown(t *testing.T) { t.Parallel() - isADisconnectionError := IsDisconnectionError(errors.New("errorText")) - if isADisconnectionError { - t.Error("Its not") - } - isADisconnectionError = IsDisconnectionError(&websocket.CloseError{ - Code: 1006, - Text: "errorText", - }) - if !isADisconnectionError { - t.Error("It is") - } + ws := NewWebsocket() + err := ws.Setup(defaultSetup) + require.NoError(t, err, "Setup must not error") - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errClosedConnection, - }) - if isADisconnectionError { - t.Error("It's not") + ws.state.Store(connectedState) + ws.trafficTimeout = time.Minute + ws.trafficMonitor() + assert.True(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be running") + + wgReady := make(chan bool) + go func() { + ws.Wg.Wait() + close(wgReady) + }() + select { + case <-wgReady: + require.Failf(t, "", "WaitGroup should be blocking still") + case <-time.After(trafficCheckInterval): } - isADisconnectionError = IsDisconnectionError(&net.OpError{ - Err: errors.New("errText"), - }) - if !isADisconnectionError { - t.Error("It is") + close(ws.ShutdownC) + + <-time.After(2 * trafficCheckInterval) + assert.False(t, ws.IsTrafficMonitorRunning(), "traffic monitor should be shutdown") + select { + case <-wgReady: + default: + require.Failf(t, "", "WaitGroup should be freed now") } } +func TestIsDisconnectionError(t *testing.T) { + t.Parallel() + assert.False(t, IsDisconnectionError(errors.New("errorText")), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&websocket.CloseError{Code: 1006, Text: "errorText"}), "IsDisconnectionError should return true") + assert.False(t, IsDisconnectionError(&net.OpError{Err: errClosedConnection}), "IsDisconnectionError should return false") + assert.True(t, IsDisconnectionError(&net.OpError{Err: errors.New("errText")}), "IsDisconnectionError should return true") +} + func TestConnectionMessageErrors(t *testing.T) { t.Parallel() var wsWrong = &Websocket{} err := wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errNoConnectFunc, "Connect should error correctly") wsWrong.connector = func() error { return nil } err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "Connect should error correctly") wsWrong.setEnabled(true) - wsWrong.setConnectingStatus(true) - wsWrong.Wg = &sync.WaitGroup{} + wsWrong.setState(connectingState) err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errAlreadyReconnecting, "Connect should error correctly") - wsWrong.setConnectedStatus(false) - wsWrong.connector = func() error { return errors.New("edge case error of dooooooom") } + wsWrong.setState(disconnectedState) err = wsWrong.Connect() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, common.ErrNilPointer, "Connect should get a nil pointer error, presumably on subs") - ws := *New() + wsWrong.subscriptions = subscription.NewStore() + wsWrong.setState(disconnectedState) + wsWrong.connector = func() error { return errDastardlyReason } + err = wsWrong.Connect() + assert.ErrorIs(t, err, errDastardlyReason, "Connect should error correctly") + + ws := NewWebsocket() err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Setup must not error") ws.trafficTimeout = time.Minute - ws.connector = func() error { return nil } + ws.connector = connect err = ws.Connect() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Connect must not error") ws.TrafficAlert <- struct{}{} - timer := time.NewTimer(900 * time.Millisecond) - ws.ReadMessageErrors <- errors.New("errorText") - select { - case err := <-ws.ToRoutine: - errText, ok := err.(error) - if !ok { - t.Error("unable to type assert error") - } else if errText.Error() != "errorText" { - t.Errorf("Expected 'errorText', received %v", err) - } - case <-timer.C: - t.Error("Timeout waiting for datahandler to receive error") - } - ws.ReadMessageErrors <- &websocket.CloseError{ - Code: 1006, - Text: "errorText", - } -outer: - for { + c := func(tb *assert.CollectT) { select { - case err := <-ws.ToRoutine: - if _, ok := err.(*websocket.CloseError); !ok { - t.Errorf("Error is not a disconnection error: %v", err) + case v := <-ws.ToRoutine: + switch err := v.(type) { + case *websocket.CloseError: + assert.Equal(tb, "SpecialText", err.Text, "Should get correct Close Error") + case error: + assert.ErrorIs(tb, err, errDastardlyReason, "Should get the correct error") } - case <-timer.C: - break outer + default: } } + + ws.ReadMessageErrors <- errDastardlyReason + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") + + ws.ReadMessageErrors <- &websocket.CloseError{Code: 1006, Text: "SpecialText"} + assert.EventuallyWithT(t, c, 900*time.Millisecond, 10*time.Millisecond, "Should get an error down the routine") } func TestWebsocket(t *testing.T) { t.Parallel() - wsInit := Websocket{} - err := wsInit.Setup(&WebsocketSetup{ - ExchangeConfig: &config.Exchange{ - Features: &config.FeaturesConfig{ - Enabled: config.FeaturesEnabledConfig{Websocket: true}, - }, - Name: "test", - }, - }) - if !errors.Is(err, errWebsocketAlreadyInitialised) { - t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketAlreadyInitialised) - } - ws := *New() - err = ws.SetProxyAddress("garbagio") - if err == nil { - t.Error("error cannot be nil") - } + ws := NewWebsocket() - ws.Conn = &WebsocketConnection{} + err := ws.SetProxyAddress("garbagio") + assert.ErrorContains(t, err, "invalid URI for request", "SetProxyAddress should error correctly") + + ws.Conn = &dodgyConnection{} ws.AuthConn = &WebsocketConnection{} ws.setEnabled(true) - err = ws.SetProxyAddress("https://192.168.0.1:1337") - if err == nil { - t.Error("error cannot be nil") - } - ws.setConnectedStatus(true) - ws.ShutdownC = make(chan struct{}) - ws.Wg = &sync.WaitGroup{} - err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } - err = ws.SetProxyAddress("https://192.168.0.1:1336") - if err == nil { - t.Error("SetProxyAddress", err) - } - ws.setEnabled(false) + err = ws.Setup(defaultSetup) // Sets to enabled again + require.NoError(t, err, "Setup may not error") - // removing proxy - err = ws.SetProxyAddress("") - if err != nil { - t.Error(err) - } - // reinstate proxy - err = ws.SetProxyAddress("http://localhost:1337") - if err != nil { - t.Error(err) - } - // conflict proxy - err = ws.SetProxyAddress("http://localhost:1337") - if err == nil { - t.Error("error cannot be nil") - } err = ws.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } - if ws.GetName() != "exchangeName" { - t.Error("WebsocketSetup") - } + assert.ErrorIs(t, err, errWebsocketAlreadyInitialised, "Setup should error correctly if called twice") - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } + assert.Equal(t, "GTX", ws.GetName(), "GetName should return correctly") + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by Setup") ws.setEnabled(false) - if ws.IsEnabled() { - t.Error("WebsocketSetup") - } + assert.False(t, ws.IsEnabled(), "Websocket should be disabled by setEnabled(false)") + ws.setEnabled(true) - if !ws.IsEnabled() { - t.Error("WebsocketSetup") - } + assert.True(t, ws.IsEnabled(), "Websocket should be enabled by setEnabled(true)") - if ws.GetProxyAddress() != "http://localhost:1337" { - t.Error("WebsocketSetup") - } + err = ws.SetProxyAddress("https://192.168.0.1:1337") + assert.NoError(t, err, "SetProxyAddress should not error when not yet connected") - if ws.GetWebsocketURL() != "wss://testRunningURL" { - t.Error("WebsocketSetup") - } - if ws.trafficTimeout != time.Second*5 { - t.Error("WebsocketSetup") - } - // -- Not connected shutdown - err = ws.Shutdown() - if err == nil { - t.Fatal("should not be connected to able to shut down") - } + ws.setState(connectedState) - ws.setConnectedStatus(true) - ws.Conn = &dodgyConnection{} - err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil") - } + err = ws.SetProxyAddress("https://192.168.0.1:1336") + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Connect and error from there") + + err = ws.SetProxyAddress("https://192.168.0.1:1336") + assert.ErrorIs(t, err, errSameProxyAddress, "SetProxyAddress should error correctly") + + // removing proxy + err = ws.SetProxyAddress("") + assert.ErrorIs(t, err, errDastardlyReason, "SetProxyAddress should call Shutdown and error from there") + assert.ErrorIs(t, err, errCannotShutdown, "SetProxyAddress should call Shutdown and error from there") ws.Conn = &WebsocketConnection{} + ws.setEnabled(true) - ws.setConnectedStatus(true) + // reinstate proxy + err = ws.SetProxyAddress("http://localhost:1337") + assert.NoError(t, err, "SetProxyAddress should not error") + assert.Equal(t, "http://localhost:1337", ws.GetProxyAddress(), "GetProxyAddress should return correctly") + assert.Equal(t, "wss://testRunningURL", ws.GetWebsocketURL(), "GetWebsocketURL should return correctly") + assert.Equal(t, time.Second*5, ws.trafficTimeout, "trafficTimeout should default correctly") + + ws.setState(connectedState) ws.AuthConn = &dodgyConnection{} err = ws.Shutdown() - if err == nil { - t.Fatal("error cannot be nil ") - } + assert.ErrorIs(t, err, errDastardlyReason, "Shutdown should error correctly with a dodgy authConn") + assert.ErrorIs(t, err, errCannotShutdown, "Shutdown should error correctly with a dodgy authConn") ws.AuthConn = &WebsocketConnection{} - ws.setConnectedStatus(false) + ws.setState(disconnectedState) - // -- Normal connect err = ws.Connect() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Connect should not error") ws.defaultURL = "ws://demos.kaazing.com/echo" ws.defaultURLAuth = "ws://demos.kaazing.com/echo" err = ws.SetWebsocketURL("", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", false, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("", true, false) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, false) - if err != nil { - t.Fatal(err) - } - // Attempt reconnect + assert.NoError(t, err, "SetWebsocketURL should not error") + err = ws.SetWebsocketURL("ws://demos.kaazing.com/echo", true, true) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetWebsocketURL should not error on reconnect") + // -- initiate the reconnect which is usually handled by connection monitor err = ws.Connect() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "ReConnect called manually should not error") + err = ws.Connect() - if err == nil { - t.Fatal("should already be connected") - } - // -- Normal shutdown + assert.ErrorIs(t, err, errAlreadyConnected, "ReConnect should error when already connected") + err = ws.Shutdown() - if err != nil { - t.Fatal("WebsocketSetup", err) - } + assert.NoError(t, err, "Shutdown should not error") ws.Wg.Wait() } +func currySimpleSub(w *Websocket) func(subscription.List) error { + return func(subs subscription.List) error { + for _, s := range subs { + if err := s.SetState(subscription.SubscribedState); err != nil { + return err + } + } + return w.AddSubscriptions(subs) + } +} + +func currySimpleUnsub(w *Websocket) func(subscription.List) error { + return func(unsubs subscription.List) error { + for _, s := range unsubs { + if err := s.SetState(subscription.InactiveState); err != nil { + return err + } + } + return w.RemoveSubscriptions(unsubs) + } +} + // TestSubscribe logic test func TestSubscribeUnsubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error") - fnSub := func(subs []subscription.Subscription) error { - ws.AddSuccessfulSubscriptions(subs...) - return nil - } - fnUnsub := func(unsubs []subscription.Subscription) error { - ws.RemoveSubscriptions(unsubs...) - return nil - } - ws.Subscriber = fnSub - ws.Unsubscriber = fnUnsub + ws.Subscriber = currySimpleSub(ws) + ws.Unsubscriber = currySimpleUnsub(ws) subs, err := ws.GenerateSubs() assert.NoError(t, err, "Generating test subscriptions should not error") assert.ErrorIs(t, ws.UnsubscribeChannels(nil), errNoSubscriptionsSupplied, "Unsubscribing from nil should error") - assert.ErrorIs(t, ws.UnsubscribeChannels(subs), ErrSubscriptionNotFound, "Unsubscribing should error when not subscribed") + assert.ErrorIs(t, ws.UnsubscribeChannels(subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error") assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") - byDefKey := ws.GetSubscription(subscription.DefaultKey{Channel: "TestSub"}) - if assert.NotNil(t, byDefKey, "GetSubscription by default key should find a channel") { - assert.Equal(t, "TestSub", byDefKey.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") - assert.NotSame(t, byDefKey, ws.subscriptions["TestSub"], "GetSubscription returns a fresh pointer") + bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"}) + if assert.NotNil(t, bySub, "GetSubscription by subscription should find a channel") { + assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") + assert.Same(t, bySub, subs[0], "GetSubscription returns the same pointer") } if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") { assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") @@ -538,7 +501,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { } assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, ws.SubscribeToChannels(subs), errChannelAlreadySubscribed, "Subscribe should error when already subscribed") + assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") assert.ErrorIs(t, ws.SubscribeToChannels(nil), errNoSubscriptionsSupplied, "Subscribe to nil should error") assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error") } @@ -546,7 +509,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { // TestResubscribe tests Resubscribing to existing subscriptions func TestResubscribe(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() wackedOutSetup := *defaultSetup wackedOutSetup.MaxWebsocketSubscriptionsPerConnection = -1 @@ -556,131 +519,80 @@ func TestResubscribe(t *testing.T) { err = ws.Setup(defaultSetup) assert.NoError(t, err, "WS Setup should not error") - fnSub := func(subs []subscription.Subscription) error { - ws.AddSuccessfulSubscriptions(subs...) - return nil - } - fnUnsub := func(unsubs []subscription.Subscription) error { - ws.RemoveSubscriptions(unsubs...) - return nil - } - ws.Subscriber = fnSub - ws.Unsubscriber = fnUnsub + ws.Subscriber = currySimpleSub(ws) + ws.Unsubscriber = currySimpleUnsub(ws) - channel := []subscription.Subscription{{Channel: "resubTest"}} + channel := subscription.List{{Channel: "resubTest"}} - assert.ErrorIs(t, ws.ResubscribeToChannel(&channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet") + assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error") - assert.NoError(t, ws.ResubscribeToChannel(&channel[0]), "Resubscribe should not error now the channel is subscribed") + assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed") } -// TestSubscriptionState tests Subscription state changes -func TestSubscriptionState(t *testing.T) { - t.Parallel() - ws := New() - - c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState} - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error") - - assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") - found := ws.GetSubscription(42) - assert.NotNil(t, found, "Should find the subscription") - assert.Equal(t, subscription.SubscribingState, found.State, "Subscription should be Subscribing") - assert.ErrorIs(t, ws.AddSubscription(c), ErrSubscribedAlready, "Adding an already existing sub should error") - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.SubscribingState), ErrChannelInStateAlready, "Setting Same state should error") - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState+1), errInvalidChannelState, "Setting an invalid state should error") - - ws.AddSuccessfulSubscriptions(*c) - found = ws.GetSubscription(42) - assert.NotNil(t, found, "Should find the subscription") - assert.Equal(t, subscription.SubscribedState, found.State, "Subscription should be subscribed state") - - assert.NoError(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), "Setting Unsub state should not error") - found = ws.GetSubscription(42) - assert.Equal(t, subscription.UnsubscribingState, found.State, "Subscription should be unsubscribing state") +func TestAddSubscription(t *testing.T) { + t.Fatal("Not implemented, along with others") } // TestRemoveSubscriptions tests removing a subscription func TestRemoveSubscriptions(t *testing.T) { t.Parallel() - ws := New() + ws := NewWebsocket() c := &subscription.Subscription{Key: 42, Channel: "Unite!"} - assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") + require.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") assert.NotNil(t, ws.GetSubscription(42), "Added subscription should be findable") - ws.RemoveSubscriptions(*c) + err := ws.RemoveSubscriptions(subscription.List{c}) + require.NoError(t, err, "RemoveSubscriptions must not error") assert.Nil(t, ws.GetSubscription(42), "Remove should have removed the sub") } // TestConnectionMonitorNoConnection logic test func TestConnectionMonitorNoConnection(t *testing.T) { t.Parallel() - ws := *New() + ws := NewWebsocket() ws.connectionMonitorDelay = 500 - ws.DataHandler = make(chan interface{}, 1) - ws.ShutdownC = make(chan struct{}, 1) ws.exchangeName = "hello" - ws.Wg = &sync.WaitGroup{} - ws.enabled = true + ws.setEnabled(true) err := ws.connectionMonitor() - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } - if !ws.IsConnectionMonitorRunning() { - t.Fatal("Should not have exited") - } + require.NoError(t, err, "connectionMonitor must not error") + assert.True(t, ws.IsConnectionMonitorRunning(), "IsConnectionMonitorRunning should return true") err = ws.connectionMonitor() - if !errors.Is(err, errAlreadyRunning) { - t.Fatalf("received: %v, but expected: %v", err, errAlreadyRunning) - } + assert.ErrorIs(t, err, errAlreadyRunning, "connectionMonitor should error correctly") } // TestGetSubscription logic test func TestGetSubscription(t *testing.T) { t.Parallel() assert.Nil(t, (*Websocket).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil") - assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub map should return nil") - w := Websocket{ - subscriptions: subscriptionMap{ - 42: { - Channel: "hello3", - }, - }, - } - assert.Nil(t, w.GetSubscription(43), "GetSubscription with an invalid key should return nil") - c := w.GetSubscription(42) - if assert.NotNil(t, c, "GetSubscription with an valid key should return a channel") { - assert.Equal(t, "hello3", c.Channel, "GetSubscription should return the correct channel details") - } + assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil") + w := NewWebsocket() + assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil") + s := &subscription.Subscription{Key: 42, Channel: "hello3"} + w.AddSubscription(s) + assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store") } // TestGetSubscriptions logic test func TestGetSubscriptions(t *testing.T) { t.Parallel() - w := Websocket{ - subscriptions: subscriptionMap{ - 42: { - Channel: "hello3", - }, - }, - } + assert.Nil(t, (*Websocket).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil") + assert.Nil(t, (&Websocket{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil") + w := NewWebsocket() + w.AddSubscriptions(subscription.List{ + {Key: 42, Channel: "hello3"}, + {Key: 45, Channel: "hello4"}, + }) assert.Equal(t, "hello3", w.GetSubscriptions()[0].Channel, "GetSubscriptions should return the correct channel details") } // TestSetCanUseAuthenticatedEndpoints logic test func TestSetCanUseAuthenticatedEndpoints(t *testing.T) { t.Parallel() - ws := *New() - result := ws.CanUseAuthenticatedEndpoints() - if result { - t.Error("expected `canUseAuthenticatedEndpoints` to be false") - } + ws := NewWebsocket() + assert.False(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return false") ws.SetCanUseAuthenticatedEndpoints(true) - result = ws.CanUseAuthenticatedEndpoints() - if !result { - t.Error("expected `canUseAuthenticatedEndpoints` to be true") - } + assert.True(t, ws.CanUseAuthenticatedEndpoints(), "CanUseAuthenticatedEndpoints should return true") } // TestDial logic test @@ -917,81 +829,53 @@ func TestParseBinaryResponse(t *testing.T) { } var b bytes.Buffer - w := gzip.NewWriter(&b) - _, err := w.Write([]byte("hello")) - if err != nil { - t.Error(err) - } - err = w.Close() - if err != nil { - t.Error(err) - } - var resp []byte - resp, err = wc.parseBinaryResponse(b.Bytes()) - if err != nil { - t.Error(err) - } - if !strings.EqualFold(string(resp), "hello") { - t.Errorf("GZip conversion failed. Received: '%v', Expected: 'hello'", string(resp)) - } + g := gzip.NewWriter(&b) + _, err := g.Write([]byte("hello")) + require.NoError(t, err, "gzip.Write must not error") + assert.NoError(t, g.Close(), "Close should not error") + + resp, err := wc.parseBinaryResponse(b.Bytes()) + assert.NoError(t, err, "parseBinaryResponse should not error parsing gzip") + assert.EqualValues(t, "hello", resp, "parseBinaryResponse should decode gzip") + + b.Reset() + f, err := flate.NewWriter(&b, 1) + require.NoError(t, err, "flate.NewWriter must not error") + _, err = f.Write([]byte("goodbye")) + require.NoError(t, err, "flate.Write must not error") + assert.NoError(t, f.Close(), "Close should not error") - var b2 bytes.Buffer - w2, err2 := flate.NewWriter(&b2, 1) - if err2 != nil { - t.Error(err2) - } - _, err2 = w2.Write([]byte("hello")) - if err2 != nil { - t.Error(err) - } - err2 = w2.Close() - if err2 != nil { - t.Error(err) - } - resp2, err3 := wc.parseBinaryResponse(b2.Bytes()) - if err3 != nil { - t.Error(err3) - } - if !strings.EqualFold(string(resp2), "hello") { - t.Errorf("Deflate conversion failed. Received: '%v', Expected: 'hello'", string(resp2)) - } + resp, err = wc.parseBinaryResponse(b.Bytes()) + assert.NoError(t, err, "parseBinaryResponse should not error parsing inflate") + assert.EqualValues(t, "goodbye", resp, "parseBinaryResponse should deflate") - _, err4 := wc.parseBinaryResponse([]byte{}) - if err4 == nil || err4.Error() != "unexpected EOF" { - t.Error("Expected error 'unexpected EOF'") - } + _, err = wc.parseBinaryResponse([]byte{}) + assert.ErrorContains(t, err, "unexpected EOF", "parseBinaryResponse should error on empty input") } // TestCanUseAuthenticatedWebsocketForWrapper logic test func TestCanUseAuthenticatedWebsocketForWrapper(t *testing.T) { t.Parallel() ws := &Websocket{} - resp := ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is false") - } - ws.setConnectedStatus(true) - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if resp { - t.Error("Expected false, `connected` is true and `CanUseAuthenticatedEndpoints` is false") - } - ws.canUseAuthenticatedEndpoints = true - resp = ws.CanUseAuthenticatedWebsocketForWrapper() - if !resp { - t.Error("Expected true, `connected` and `CanUseAuthenticatedEndpoints` is true") - } + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + + ws.setState(connectedState) + require.True(t, ws.IsConnected(), "IsConnected must return true") + assert.False(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return false") + + ws.SetCanUseAuthenticatedEndpoints(true) + assert.True(t, ws.CanUseAuthenticatedWebsocketForWrapper(), "CanUseAuthenticatedWebsocketForWrapper should return true") } func TestGenerateMessageID(t *testing.T) { t.Parallel() wc := WebsocketConnection{} - var id int64 - for i := 0; i < 10; i++ { - newID := wc.GenerateMessageID(true) - if id == newID { - t.Fatal("ID generation is not unique") - } - id = newID + const spins = 1000 + ids := make([]int64, spins) + for i := 0; i < spins; i++ { + id := wc.GenerateMessageID(true) + assert.NotContains(t, ids, id, "GenerateMessageID must not generate the same ID twice") + ids[i] = id } } @@ -1013,81 +897,62 @@ func BenchmarkGenerateMessageID_Low(b *testing.B) { func TestCheckWebsocketURL(t *testing.T) { err := checkWebsocketURL("") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on empty string") err = checkWebsocketURL("wowowow:wowowowo") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on bad format") err = checkWebsocketURL("://") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "missing protocol scheme", "checkWebsocketURL should error correctly on bad proto") err = checkWebsocketURL("http://www.google.com") - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errInvalidWebsocketURL, "checkWebsocketURL should error correctly on wrong proto") err = checkWebsocketURL("wss://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") err = checkWebsocketURL("ws://websocketconnection.place") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "checkWebsocketURL should not error") } func TestGetChannelDifference(t *testing.T) { t.Parallel() web := Websocket{} - newChans := []subscription.Subscription{ - { - Channel: "Test1", - }, - { - Channel: "Test2", - }, - { - Channel: "Test3", - }, + newChans := subscription.List{ + {Channel: "Test1"}, + {Channel: "Test2"}, + {Channel: "Test3"}, } subs, unsubs := web.GetChannelDifference(newChans) - assert.Len(t, subs, 3, "Should get the correct number of subs") - assert.Empty(t, unsubs, "Should get the correct number of unsubs") + assert.Implements(t, (*subscription.MatchableKey)(nil), subs[0].Key, "Sub key must be matchable") + assert.Equal(t, 3, len(subs), "Should get the correct number of subs") + assert.Empty(t, unsubs, "Should get no unsubs") - web.AddSuccessfulSubscriptions(subs...) + for _, s := range subs { + s.SetState(subscription.SubscribedState) + } - flushedSubs := []subscription.Subscription{ - { - Channel: "Test2", - }, + web.AddSubscriptions(subs) + + flushedSubs := subscription.List{ + {Channel: "Test2"}, } subs, unsubs = web.GetChannelDifference(flushedSubs) - assert.Empty(t, subs, "Should get the correct number of subs") - assert.Len(t, unsubs, 2, "Should get the correct number of unsubs") + assert.Empty(t, subs, "Should get no subs") + assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") - flushedSubs = []subscription.Subscription{ - { - Channel: "Test2", - }, - { - Channel: "Test4", - }, + flushedSubs = subscription.List{ + {Channel: "Test2"}, + {Channel: "Test4"}, } subs, unsubs = web.GetChannelDifference(flushedSubs) - if assert.Len(t, subs, 1, "Should get the correct number of subs") { + if assert.Equal(t, 1, len(subs), "Should get the correct number of subs") { assert.Equal(t, "Test4", subs[0].Channel, "Should subscribe to the right channel") } - if assert.Len(t, unsubs, 2, "Should get the correct number of unsubs") { + if assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") { sort.Slice(unsubs, func(i, j int) bool { return unsubs[i].Channel <= unsubs[j].Channel }) assert.Equal(t, "Test1", unsubs[0].Channel, "Should unsubscribe from the right channels") assert.Equal(t, "Test3", unsubs[1].Channel, "Should unsubscribe from the right channels") @@ -1097,23 +962,23 @@ func TestGetChannelDifference(t *testing.T) { // GenSubs defines a theoretical exchange with pair management type GenSubs struct { EnabledPairs currency.Pairs - subscribos []subscription.Subscription - unsubscribos []subscription.Subscription + subscribos subscription.List + unsubscribos subscription.List } // generateSubs default subs created from the enabled pairs list -func (g *GenSubs) generateSubs() ([]subscription.Subscription, error) { - superduperchannelsubs := make([]subscription.Subscription, len(g.EnabledPairs)) +func (g *GenSubs) generateSubs() (subscription.List, error) { + superduperchannelsubs := make(subscription.List, len(g.EnabledPairs)) for i := range g.EnabledPairs { - superduperchannelsubs[i] = subscription.Subscription{ + superduperchannelsubs[i] = &subscription.Subscription{ Channel: "TEST:" + strconv.FormatInt(int64(i), 10), - Pair: g.EnabledPairs[i], + Pairs: currency.Pairs{g.EnabledPairs[i]}, } } return superduperchannelsubs, nil } -func (g *GenSubs) SUBME(subs []subscription.Subscription) error { +func (g *GenSubs) SUBME(subs subscription.List) error { if len(subs) == 0 { return errors.New("WOW") } @@ -1121,7 +986,7 @@ func (g *GenSubs) SUBME(subs []subscription.Subscription) error { return nil } -func (g *GenSubs) UNSUBME(unsubs []subscription.Subscription) error { +func (g *GenSubs) UNSUBME(unsubs subscription.List) error { if len(unsubs) == 0 { return errors.New("WOW") } @@ -1142,242 +1007,151 @@ func TestFlushChannels(t *testing.T) { dodgyWs := Websocket{} err := dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrWebsocketNotEnabled, "FlushChannels should error correctly") dodgyWs.setEnabled(true) err = dodgyWs.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, ErrNotConnected, "FlushChannels should error correctly") - web := Websocket{ - enabled: true, - connected: true, - connector: connect, - ShutdownC: make(chan struct{}), - Subscriber: newgen.SUBME, - Unsubscriber: newgen.UNSUBME, - Wg: new(sync.WaitGroup), - features: &protocol.Features{ - // No features - }, - trafficTimeout: time.Second * 30, // Added for when we utilise connect() - // in FlushChannels() so the traffic monitor doesn't time out and turn - // this to an unconnected state - } + w := NewWebsocket() + w.connector = connect + w.Subscriber = newgen.SUBME + w.Unsubscriber = newgen.UNSUBME + // Added for when we utilise connect() in FlushChannels() so the traffic monitor doesn't time out and turn this to an unconnected state + w.trafficTimeout = time.Second * 30 + + w.setEnabled(true) + w.setState(connectedState) - problemFunc := func() ([]subscription.Subscription, error) { - return nil, errors.New("problems") + problemFunc := func() (subscription.List, error) { + return nil, errDastardlyReason } - noSub := func() ([]subscription.Subscription, error) { + noSub := func() (subscription.List, error) { return nil, nil } // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} - web.GenerateSubs = func() ([]subscription.Subscription, error) { - return []subscription.Subscription{{Channel: "test"}}, nil - } - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } - - web.features.FullPayloadSubscribe = true - web.GenerateSubs = problemFunc - err = web.FlushChannels() // error on full subscribeToChannels - if err == nil { - t.Fatal("error cannot be nil") - } - - web.GenerateSubs = noSub - err = web.FlushChannels() // No subs to sub - if err != nil { - t.Fatal(err) - } - - web.GenerateSubs = newgen.generateSubs - subs, err := web.GenerateSubs() - if err != nil { - t.Fatal(err) - } - web.AddSuccessfulSubscriptions(subs...) - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } - web.features.FullPayloadSubscribe = false - web.features.Subscribe = true - - web.GenerateSubs = problemFunc - err = web.FlushChannels() - if err == nil { - t.Fatal("error cannot be nil") - } - - web.GenerateSubs = newgen.generateSubs - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } - web.subscriptionMutex.Lock() - web.subscriptions = subscriptionMap{ - 41: { - Key: 41, - Channel: "match channel", - Pair: currency.NewPair(currency.BTC, currency.AUD), - }, - 42: { - Key: 42, - Channel: "unsub channel", - Pair: currency.NewPair(currency.THETA, currency.USDT), - }, - } - web.subscriptionMutex.Unlock() + w.GenerateSubs = func() (subscription.List, error) { + return subscription.List{{Channel: "test"}}, nil + } + err = w.FlushChannels() + require.NoError(t, err, "Flush Channels must not error") + + w.features.FullPayloadSubscribe = true + w.GenerateSubs = problemFunc + err = w.FlushChannels() // error on full subscribeToChannels + assert.ErrorIs(t, err, errDastardlyReason, "FlushChannels should error correctly on GenerateSubs") + + w.GenerateSubs = noSub + err = w.FlushChannels() // No subs to sub + assert.NoError(t, err, "Flush Channels should not error") + + w.GenerateSubs = newgen.generateSubs + subs, err := w.GenerateSubs() + require.NoError(t, err, "GenerateSubs must not error") + for _, s := range subs { + s.SetState(subscription.SubscribedState) + } + w.AddSubscriptions(subs) + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") + w.features.FullPayloadSubscribe = false + w.features.Subscribe = true + + w.GenerateSubs = newgen.generateSubs + w.subscriptions = subscription.NewStore() + w.subscriptions.Add(&subscription.Subscription{ + Key: 41, + Channel: "match channel", + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.AUD)}, + }) + w.subscriptions.Add(&subscription.Subscription{ + Key: 42, + Channel: "unsub channel", + Pairs: currency.Pairs{currency.NewPair(currency.THETA, currency.USDT)}, + }) - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } - - web.setConnectedStatus(true) - web.features.Unsubscribe = true - err = web.FlushChannels() - if err != nil { - t.Fatal(err) - } + w.setState(connectedState) + w.features.Unsubscribe = true + err = w.FlushChannels() + assert.NoError(t, err, "FlushChannels should not error") } func TestDisable(t *testing.T) { t.Parallel() - web := Websocket{ - enabled: true, - connected: true, - ShutdownC: make(chan struct{}), - } - err := web.Disable() - if err != nil { - t.Fatal(err) - } - err = web.Disable() - if err == nil { - t.Fatal("should already be disabled") - } + w := NewWebsocket() + w.setEnabled(true) + w.setState(connectedState) + require.NoError(t, w.Disable(), "Disable must not error") + assert.ErrorIs(t, w.Disable(), ErrAlreadyDisabled, "Disable should error correctly") } func TestEnable(t *testing.T) { t.Parallel() - web := Websocket{ - connector: connect, - Wg: new(sync.WaitGroup), - ShutdownC: make(chan struct{}), - GenerateSubs: func() ([]subscription.Subscription, error) { - return []subscription.Subscription{{Channel: "test"}}, nil - }, - Subscriber: func([]subscription.Subscription) error { return nil }, - } - - err := web.Enable() - if err != nil { - t.Fatal(err) - } - - err = web.Enable() - if err == nil { - t.Fatal("should already be enabled") - } - - fmt.Print() + w := NewWebsocket() + w.connector = connect + w.Subscriber = func(subscription.List) error { return nil } + w.Unsubscriber = func(subscription.List) error { return nil } + w.GenerateSubs = func() (subscription.List, error) { return nil, nil } + require.NoError(t, w.Enable(), "Enable must not error") + assert.ErrorIs(t, w.Enable(), errWebsocketAlreadyEnabled, "Enable should error correctly") } func TestSetupNewConnection(t *testing.T) { t.Parallel() var nonsenseWebsock *Websocket err := nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errWebsocketIsNil, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigNameEmpty, "SetupNewConnection should error correctly") nonsenseWebsock = &Websocket{exchangeName: "test"} err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errTrafficAlertNil, "SetupNewConnection should error correctly") - nonsenseWebsock.TrafficAlert = make(chan struct{}) + nonsenseWebsock.TrafficAlert = make(chan struct{}, 1) err = nonsenseWebsock.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errReadMessageErrorsNil, "SetupNewConnection should error correctly") - web := Websocket{ - connector: connect, - Wg: new(sync.WaitGroup), - ShutdownC: make(chan struct{}), - Init: true, - TrafficAlert: make(chan struct{}), - ReadMessageErrors: make(chan error), - DataHandler: make(chan interface{}), - } + web := NewWebsocket() err = web.Setup(defaultSetup) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Setup should not error") + err = web.SetupNewConnection(ConnectionSetup{}) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorIs(t, err, errExchangeConfigEmpty, "SetupNewConnection should error correctly") + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring"}) - if err != nil { - t.Fatal(err) - } - err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", - Authenticated: true}) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "SetupNewConnection should not error") + + err = web.SetupNewConnection(ConnectionSetup{URL: "urlstring", Authenticated: true}) + assert.NoError(t, err, "SetupNewConnection should not error") } func TestWebsocketConnectionShutdown(t *testing.T) { t.Parallel() wc := WebsocketConnection{} err := wc.Shutdown() - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err, "Shutdown should not error") err = wc.Dial(&websocket.Dialer{}, nil) - if err == nil { - t.Fatal("error cannot be nil") - } + assert.ErrorContains(t, err, "malformed ws or wss URL", "Dial must error correctly") wc.URL = websocketTestURL err = wc.Dial(&websocket.Dialer{}, nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Dial must not error") err = wc.Shutdown() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err, "Shutdown must not error") } // TestLatency logic test @@ -1431,27 +1205,25 @@ func TestCheckSubscriptions(t *testing.T) { t.Parallel() ws := Websocket{} err := ws.checkSubscriptions(nil) - if !errors.Is(err, errNoSubscriptionsSupplied) { - t.Fatalf("received: %v, but expected: %v", err, errNoSubscriptionsSupplied) - } + assert.ErrorIs(t, err, errNoSubscriptionsSupplied, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 1 - err = ws.checkSubscriptions([]subscription.Subscription{{}, {}}) - if !errors.Is(err, errSubscriptionsExceedsLimit) { - t.Fatalf("received: %v, but expected: %v", err, errSubscriptionsExceedsLimit) - } + err = ws.checkSubscriptions(subscription.List{{}, {}}) + assert.ErrorIs(t, err, common.ErrNilPointer, "checkSubscriptions should error correctly when subscriptions is empty") + + ws.subscriptions = subscription.NewStore() + err = ws.checkSubscriptions(subscription.List{{}, {}}) + assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 - ws.subscriptions = subscriptionMap{42: {Key: 42, Channel: "test"}} - err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}}) - if !errors.Is(err, errChannelAlreadySubscribed) { - t.Fatalf("received: %v, but expected: %v", err, errChannelAlreadySubscribed) - } + ws.subscriptions = subscription.NewStore() + err = ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"}) + require.NoError(t, err, "Add subscription must not error") + err = ws.checkSubscriptions(subscription.List{{Key: 42, Channel: "test"}}) + assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly") - err = ws.checkSubscriptions([]subscription.Subscription{{}}) - if !errors.Is(err, nil) { - t.Fatalf("received: %v, but expected: %v", err, nil) - } + err = ws.checkSubscriptions(subscription.List{{}}) + assert.NoError(t, err, "checkSubscriptions should not error") } diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 925c34b907c..707fc7dcb05 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -2,6 +2,7 @@ package stream import ( "sync" + "sync/atomic" "time" "github.com/gorilla/websocket" @@ -15,28 +16,29 @@ import ( // Websocket functionality list and state consts const ( - // WebsocketNotEnabled alerts of a disabled websocket - WebsocketNotEnabled = "exchange_websocket_not_enabled" WebsocketNotAuthenticatedUsingRest = "%v - Websocket not authenticated, using REST\n" Ping = "ping" Pong = "pong" UnhandledMessage = " - Unhandled websocket message: " ) -type subscriptionMap map[any]*subscription.Subscription +const ( + uninitialisedState uint32 = iota + disconnectedState + connectingState + connectedState +) // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { - canUseAuthenticatedEndpoints bool - enabled bool - Init bool - connected bool - connecting bool + canUseAuthenticatedEndpoints atomic.Bool + enabled atomic.Bool + state atomic.Uint32 verbose bool - connectionMonitorRunning bool - trafficMonitorRunning bool - dataMonitorRunning bool + connectionMonitorRunning atomic.Bool + trafficMonitorRunning atomic.Bool + dataMonitorRunning atomic.Bool trafficTimeout time.Duration connectionMonitorDelay time.Duration proxyAddr string @@ -46,23 +48,17 @@ type Websocket struct { runningURLAuth string exchangeName string m sync.Mutex - fieldMutex sync.RWMutex connector func() error subscriptionMutex sync.RWMutex - subscriptions subscriptionMap - Subscribe chan []subscription.Subscription - Unsubscribe chan []subscription.Subscription - - // Subscriber function for package defined websocket subscriber - // functionality - Subscriber func([]subscription.Subscription) error - // Unsubscriber function for packaged defined websocket unsubscriber - // functionality - Unsubscriber func([]subscription.Subscription) error - // GenerateSubs function for package defined websocket generate - // subscriptions functionality - GenerateSubs func() ([]subscription.Subscription, error) + subscriptions *subscription.Store + + // Subscriber function for exchange specific subscribe implementation + Subscriber func(subscription.List) error + // Subscriber function for exchange specific unsubscribe implementation + Unsubscriber func(subscription.List) error + // GenerateSubs function for exchange specific generating subscriptions from Features.Subscriptions, Pairs and Assets + GenerateSubs func() (subscription.List, error) DataHandler chan interface{} ToRoutine chan interface{} @@ -71,7 +67,7 @@ type Websocket struct { // shutdown synchronises shutdown event across routines ShutdownC chan struct{} - Wg *sync.WaitGroup + Wg sync.WaitGroup // Orderbook is a local buffer of orderbooks Orderbook buffer.Orderbook @@ -109,9 +105,9 @@ type WebsocketSetup struct { RunningURL string RunningURLAuth string Connector func() error - Subscriber func([]subscription.Subscription) error - Unsubscriber func([]subscription.Subscription) error - GenerateSubscriptions func() ([]subscription.Subscription, error) + Subscriber func(subscription.List) error + Unsubscriber func(subscription.List) error + GenerateSubscriptions func() (subscription.List, error) Features *protocol.Features // Local orderbook buffer config values diff --git a/exchanges/subscription/list.go b/exchanges/subscription/list.go new file mode 100644 index 00000000000..4d3cef229e8 --- /dev/null +++ b/exchanges/subscription/list.go @@ -0,0 +1,16 @@ +package subscription + +import "slices" + +// List is a container of subscription pointers +type List []*Subscription + +// Strings returns a sorted slice of subscriptions +func (l List) Strings() []string { + s := make([]string, len(l)) + for i := range l { + s[i] = l[i].String() + } + slices.Sort(s) + return s +} diff --git a/exchanges/subscription/list_test.go b/exchanges/subscription/list_test.go new file mode 100644 index 00000000000..e2293e7ea9d --- /dev/null +++ b/exchanges/subscription/list_test.go @@ -0,0 +1,25 @@ +package subscription + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" +) + +func TestListStrings(t *testing.T) { + l := List{ + &Subscription{ + Channel: TickerChannel, + Asset: asset.Spot, + Pairs: currency.Pairs{ethusdcPair, btcusdtPair}, + }, + &Subscription{ + Channel: OrderbookChannel, + Pairs: currency.Pairs{ethusdcPair}, + }, + } + exp := []string{"orderbook ETH/USDC", "ticker spot ETH/USDC,BTC/USDT"} + assert.ElementsMatch(t, exp, l.Strings(), "String must return correct sorted list") +} diff --git a/exchanges/subscription/store.go b/exchanges/subscription/store.go new file mode 100644 index 00000000000..18db7fe338b --- /dev/null +++ b/exchanges/subscription/store.go @@ -0,0 +1,176 @@ +package subscription + +import ( + "maps" + "sync" + + "github.com/thrasher-corp/gocryptotrader/common" +) + +// Store is a container of subscription pointers +type Store struct { + m map[any]*Subscription + mu sync.RWMutex +} + +// NewStore creates a ready to use store and should always be used +func NewStore() *Store { + return &Store{ + m: map[any]*Subscription{}, + } +} + +// NewStoreFromList creates a Store from a List +func NewStoreFromList(l List) (*Store, error) { + s := NewStore() + for _, sub := range l { + if err := s.add(sub); err != nil { + return nil, err + } + } + return s, nil +} + +// Add adds a subscription to the store +// Key can be already set; if omitted EnsureKeyed will be used +// Errors if it already exists +func (s *Store) Add(sub *Subscription) error { + if s == nil || sub == nil { + return common.ErrNilPointer + } + s.mu.Lock() + defer s.mu.Unlock() + return s.add(sub) +} + +// Add adds a subscription to the store +// This method provides no locking protection +func (s *Store) add(sub *Subscription) error { + key := sub.EnsureKeyed() + if found := s.get(key); found != nil { + return ErrDuplicate + } + s.m[key] = sub + return nil +} + +// Get returns a pointer to a subscription or nil if not found +// If key implements MatchableKey then key.Match will be used +func (s *Store) Get(key any) *Subscription { + if s == nil { + return nil + } + s.mu.RLock() + defer s.mu.RUnlock() + return s.get(key) +} + +// get returns a pointer to subscription or nil if not found +// If the key passed in is a Subscription then its Key will be used; which may be a pointer to itself. +// If key implements MatchableKey then key.Match will be used; Note that *Subscription implements MatchableKey +// This method provides no locking protection +// returned subscriptions are implicitly guaranteed to have a Key +func (s *Store) get(key any) *Subscription { + switch v := key.(type) { + case Subscription: + key = v.EnsureKeyed() + case *Subscription: + key = v.EnsureKeyed() + } + + switch v := key.(type) { + case MatchableKey: + return s.match(v) + default: + return s.m[v] + } +} + +// Remove removes a subscription from the store +func (s *Store) Remove(sub *Subscription) error { + if s == nil || sub == nil { + return common.ErrNilPointer + } + s.mu.Lock() + defer s.mu.Unlock() + + if found := s.get(sub); found != nil { + delete(s.m, found.Key) + return nil + } + + return ErrNotFound +} + +// List returns a slice of Subscriptions pointers +func (s *Store) List() List { + if s == nil { + return List{} + } + s.mu.RLock() + defer s.mu.RUnlock() + subs := make(List, 0, len(s.m)) + for _, s := range s.m { + subs = append(subs, s) + } + return subs +} + +// Clear empties the subscription store +func (s *Store) Clear() { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + clear(s.m) +} + +// match returns the first subscription which matches the Key's Asset, Channel and Pairs +// If the key provided has: +// 1) Empty pairs then only Subscriptions without pairs will be considered +// 2) >=1 pairs then Subscriptions which contain all the pairs will be considered +// This method provides no locking protection +func (s *Store) match(key MatchableKey) *Subscription { + for anyKey, s := range s.m { + if key.Match(anyKey) { + return s + } + } + return nil +} + +// Diff returns a list of the added and missing subs from a new list +// The store Diff is invoked upon is read-lock protected +// The new store is assumed to be a new instance and enjoys no locking protection +func (s *Store) Diff(compare List) (added, removed List) { + if s == nil { + return + } + s.mu.RLock() + defer s.mu.RUnlock() + removedMap := maps.Clone(s.m) + for _, sub := range compare { + if found := s.get(sub); found != nil { + delete(removedMap, found.Key) + } else { + added = append(added, sub) + } + } + + for _, c := range removedMap { + removed = append(removed, c) + } + + return +} + +// Len returns the number of subscriptions +func (s *Store) Len() int { + if s == nil { + return 0 + } + s.mu.RLock() + defer s.mu.RUnlock() + return len(s.m) +} diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 874822ba79a..32e799ad44c 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -1,92 +1,129 @@ package subscription import ( - "encoding/json" + "errors" "fmt" + "sync" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" ) -// DefaultKey is the fallback key for AddSuccessfulSubscriptions -type DefaultKey struct { - Channel string - Pair currency.Pair - Asset asset.Item -} - -// State tracks the status of a subscription channel -type State uint8 +// State constants +const ( + InactiveState State = iota + SubscribingState + SubscribedState + UnsubscribingState +) +// Ticker constants const ( - UnknownState State = iota // UnknownState subscription state is not registered, but doesn't imply Inactive - SubscribingState // SubscribingState means channel is in the process of subscribing - SubscribedState // SubscribedState means the channel has finished a successful and acknowledged subscription - UnsubscribingState // UnsubscribingState means the channel has started to unsubscribe, but not yet confirmed + TickerChannel = "ticker" + OrderbookChannel = "orderbook" + CandlesChannel = "candles" + AllOrdersChannel = "allOrders" + AllTradesChannel = "allTrades" + MyTradesChannel = "myTrades" + MyOrdersChannel = "myOrders" +) - TickerChannel = "ticker" // TickerChannel Subscription Type - OrderbookChannel = "orderbook" // OrderbookChannel Subscription Type - CandlesChannel = "candles" // CandlesChannel Subscription Type - AllOrdersChannel = "allOrders" // AllOrdersChannel Subscription Type - AllTradesChannel = "allTrades" // AllTradesChannel Subscription Type - MyTradesChannel = "myTrades" // MyTradesChannel Subscription Type - MyOrdersChannel = "myOrders" // MyOrdersChannel Subscription Type +// Public errors +var ( + ErrNotFound = errors.New("subscription not found") + ErrNotSinglePair = errors.New("only single pair subscriptions expected") + ErrInStateAlready = errors.New("subscription already in state") + ErrInvalidState = errors.New("invalid subscription state") + ErrDuplicate = errors.New("duplicate subscription") ) +// State tracks the status of a subscription channel +type State uint8 + // Subscription container for streaming subscriptions type Subscription struct { Enabled bool `json:"enabled"` Key any `json:"-"` Channel string `json:"channel,omitempty"` - Pair currency.Pair `json:"pair,omitempty"` + Pairs currency.Pairs `json:"pairs,omitempty"` Asset asset.Item `json:"asset,omitempty"` Params map[string]interface{} `json:"params,omitempty"` - State State `json:"-"` Interval kline.Interval `json:"interval,omitempty"` Levels int `json:"levels,omitempty"` Authenticated bool `json:"authenticated,omitempty"` + state State + m sync.RWMutex } -// MarshalJSON generates a JSON representation of a Subscription, specifically for config writing -// The only reason it exists is to avoid having to make Pair a pointer, since that would be generally painful -// If Pair becomes a pointer, this method is redundant and should be removed -func (s *Subscription) MarshalJSON() ([]byte, error) { - // None of the usual type embedding tricks seem to work for not emitting an nil Pair - // The embedded type's Pair always fills the empty value - type MaybePair struct { - Enabled bool `json:"enabled"` - Channel string `json:"channel,omitempty"` - Asset asset.Item `json:"asset,omitempty"` - Params map[string]interface{} `json:"params,omitempty"` - Interval kline.Interval `json:"interval,omitempty"` - Levels int `json:"levels,omitempty"` - Authenticated bool `json:"authenticated,omitempty"` - Pair *currency.Pair `json:"pair,omitempty"` - } - - k := MaybePair{s.Enabled, s.Channel, s.Asset, s.Params, s.Interval, s.Levels, s.Authenticated, nil} - if s.Pair != currency.EMPTYPAIR { - k.Pair = &s.Pair - } - - return json.Marshal(k) +// MatchableKey interface should be implemented by Key types which want a more complex matching than a simple key equality check +type MatchableKey interface { + Match(any) bool } // String implements the Stringer interface for Subscription, giving a human representation of the subscription func (s *Subscription) String() string { - return fmt.Sprintf("%s %s %s", s.Channel, s.Asset, s.Pair) + return fmt.Sprintf("%s %s %s", s.Channel, s.Asset, s.Pairs) +} + +// State returns the subscription state +func (s *Subscription) State() State { + s.m.RLock() + defer s.m.RUnlock() + return s.state +} + +// SetState sets the subscription state +// Errors if already in that state or the new state is not valid +func (s *Subscription) SetState(state State) error { + s.m.Lock() + defer s.m.Unlock() + if state == s.state { + return ErrInStateAlready + } + if state > UnsubscribingState { + return ErrInvalidState + } + s.state = state + return nil } -// EnsureKeyed sets the default key on a channel if it doesn't have one -// Returns key for convenience +// EnsureKeyed returns the subscription key +// If no key exists then a pointer to the subscription itself will be used, since Subscriptions implement MatchableKey func (s *Subscription) EnsureKeyed() any { if s.Key == nil { - s.Key = DefaultKey{ - Channel: s.Channel, - Asset: s.Asset, - Pair: s.Pair, - } + s.Key = s } return s.Key } + +// Match returns if the two keys match Channels, Assets, Pairs, Interval and Levels: +// Key Pairs comparison: +// 1) Empty pairs then only Subscriptions without pairs match +// 2) >=1 pairs then Subscriptions which contain all the pairs match +// Such that a subscription for all enabled pairs will be matched when seaching for any one pair +func (s *Subscription) Match(key any) bool { + var b *Subscription + switch v := key.(type) { + case *Subscription: + b = v + case Subscription: + b = &v + default: + return false + } + + switch { + case b.Channel != s.Channel, + b.Asset != s.Asset, + // len(b.Pairs) == 0 && len(s.Pairs) == 0: Okay; continue to next non-pairs check + len(b.Pairs) == 0 && len(s.Pairs) != 0, + len(b.Pairs) != 0 && len(s.Pairs) == 0, + len(b.Pairs) != 0 && s.Pairs.ContainsAll(b.Pairs, true) != nil, + b.Levels != s.Levels, + b.Interval != s.Interval: + return false + } + + return true +} diff --git a/exchanges/subscription/subscription_test.go b/exchanges/subscription/subscription_test.go index 4f9a97ab979..b9a71b4ae6a 100644 --- a/exchanges/subscription/subscription_test.go +++ b/exchanges/subscription/subscription_test.go @@ -5,44 +5,73 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" ) -// TestEnsureKeyed logic test -func TestEnsureKeyed(t *testing.T) { - t.Parallel() - c := Subscription{ +var ( + btcusdtPair = currency.NewPair(currency.BTC, currency.USDT) + ethusdcPair = currency.NewPair(currency.ETH, currency.USDC) + ltcusdcPair = currency.NewPair(currency.LTC, currency.USDC) +) + +// TestSubscriptionString exercises the String method +func TestSubscriptionString(t *testing.T) { + s := &Subscription{ Channel: "candles", Asset: asset.Spot, - Pair: currency.NewPair(currency.BTC, currency.USDT), + Pairs: currency.Pairs{btcusdtPair, ethusdcPair.Format(currency.PairFormat{Delimiter: "/"})}, + } + assert.Equal(t, "candles spot BTC/USDT,ETH/USDC", s.String(), "Subscription String should return correct value") +} + +// TestState exercises the state getter +func TestState(t *testing.T) { + t.Parallel() + s := &Subscription{} + assert.Equal(t, InactiveState, s.State(), "State should return initial state") + s.state = SubscribedState + assert.Equal(t, SubscribedState, s.State(), "State should return correct state") +} + +// TestSetState exercises the state setter +func TestSetState(t *testing.T) { + t.Parallel() + + s := &Subscription{state: UnsubscribingState} + + for i := InactiveState; i <= UnsubscribingState; i++ { + assert.NoErrorf(t, s.SetState(i), "State should not error setting state %s", i) } - k1, ok := c.EnsureKeyed().(DefaultKey) - if assert.True(t, ok, "EnsureKeyed should return a DefaultKey") { - assert.Exactly(t, k1, c.Key, "EnsureKeyed should set the same key") - assert.Equal(t, k1.Channel, c.Channel, "DefaultKey channel should be correct") - assert.Equal(t, k1.Asset, c.Asset, "DefaultKey asset should be correct") - assert.Equal(t, k1.Pair, c.Pair, "DefaultKey currency should be correct") + assert.ErrorIs(t, s.SetState(UnsubscribingState), ErrInStateAlready, "SetState should error on same state") + assert.ErrorIs(t, s.SetState(UnsubscribingState+1), ErrInvalidState, "Setting an invalid state should error") +} + +// TestEnsureKeyed exercises the key getter and ensures it sets a self-pointer key for non +func TestEnsureKeyed(t *testing.T) { + t.Parallel() + s := &Subscription{} + k1, ok := s.EnsureKeyed().(*Subscription) + if assert.True(t, ok, "EnsureKeyed should return a *Subscription") { + assert.Same(t, s, k1, "Key should point to the same struct") } type platypus string - c = Subscription{ + s = &Subscription{ Key: platypus("Gerald"), Channel: "orderbook", - Asset: asset.Margin, - Pair: currency.NewPair(currency.ETH, currency.USDC), - } - k2, ok := c.EnsureKeyed().(platypus) - if assert.True(t, ok, "EnsureKeyed should return a platypus") { - assert.Exactly(t, k2, c.Key, "EnsureKeyed should set the same key") - assert.EqualValues(t, "Gerald", k2, "key should have the correct value") } + k2 := s.EnsureKeyed() + assert.IsType(t, platypus(""), k2, "EnsureKeyed should return a platypus") + assert.Equal(t, s.Key, k2, "Key should be the key provided") } -// TestMarshalling logic test -func TestMarshaling(t *testing.T) { +// TestSubscriptionMarshalling ensures json Marshalling is clean and concise +// Since there is no UnmarshalJSON, this just exercises the json field tags of Subscription, and regressions in conciseness +func TestSubscriptionMarshaling(t *testing.T) { t.Parallel() - j, err := json.Marshal(&Subscription{Channel: CandlesChannel}) + j, err := json.Marshal(&Subscription{Key: 42, Channel: CandlesChannel}) assert.NoError(t, err, "Marshalling should not error") assert.Equal(t, `{"enabled":false,"channel":"candles"}`, string(j), "Marshalling should be clean and concise") @@ -50,11 +79,53 @@ func TestMarshaling(t *testing.T) { assert.NoError(t, err, "Marshalling should not error") assert.Equal(t, `{"enabled":true,"channel":"orderbook","interval":"5m","levels":4}`, string(j), "Marshalling should be clean and concise") - j, err = json.Marshal(&Subscription{Enabled: true, Channel: OrderbookChannel, Interval: kline.FiveMin, Levels: 4, Pair: currency.NewPair(currency.BTC, currency.USDT)}) + j, err = json.Marshal(&Subscription{Enabled: true, Channel: OrderbookChannel, Interval: kline.FiveMin, Levels: 4, Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.USDT)}}) assert.NoError(t, err, "Marshalling should not error") - assert.Equal(t, `{"enabled":true,"channel":"orderbook","interval":"5m","levels":4,"pair":"BTCUSDT"}`, string(j), "Marshalling should be clean and concise") + assert.Equal(t, `{"enabled":true,"channel":"orderbook","pairs":"BTCUSDT","interval":"5m","levels":4}`, string(j), "Marshalling should be clean and concise") j, err = json.Marshal(&Subscription{Enabled: true, Channel: MyTradesChannel, Authenticated: true}) assert.NoError(t, err, "Marshalling should not error") assert.Equal(t, `{"enabled":true,"channel":"myTrades","authenticated":true}`, string(j), "Marshalling should be clean and concise") } + +// TestSubscriptionMatch exercises the Subscription MatchableKey interface implementation +func TestSubscriptionMatch(t *testing.T) { + t.Parallel() + require.Implements(t, (*MatchableKey)(nil), new(Subscription), "Must implement MatchableKey") + s := &Subscription{Channel: TickerChannel} + assert.NotNil(t, s.EnsureKeyed(), "EnsureKeyed should work") + assert.False(t, s.Match(42), "Match should reject an invalid key type") + try := &Subscription{Channel: OrderbookChannel} + require.False(t, s.Match(try), "Gate 1: Match must reject a bad Channel") + try = &Subscription{Channel: TickerChannel} + require.True(t, s.Match(Subscription{Channel: TickerChannel}), "Match must accept a pass-by-value subscription") + require.True(t, s.Match(try), "Gate 1: Match must accept a good Channel") + s.Asset = asset.Spot + require.False(t, s.Match(try), "Gate 2: Match must reject a bad Asset") + try.Asset = asset.Spot + require.True(t, s.Match(try), "Gate 2: Match must accept a good Asset") + + s.Pairs = currency.Pairs{btcusdtPair} + require.False(t, s.Match(try), "Gate 3: Match must reject a pair list when searching for no pairs") + try.Pairs = s.Pairs + s.Pairs = nil + require.False(t, s.Match(try), "Gate 4: Match must reject empty Pairs when searching for a list") + s.Pairs = try.Pairs + require.True(t, s.Match(try), "Gate 5: Match must accept matching pairs") + s.Pairs = currency.Pairs{ethusdcPair} + require.False(t, s.Match(try), "Gate 5: Match must reject mismatched pairs") + s.Pairs = currency.Pairs{btcusdtPair, ethusdcPair} + require.True(t, s.Match(try), "Gate 5: Match must accept one of the key pairs matching in sub pairs") + try.Pairs = currency.Pairs{btcusdtPair, ltcusdcPair} + require.False(t, s.Match(try), "Gate 5: Match must reject when sub pair list doesn't contain all key pairs") + s.Pairs = currency.Pairs{btcusdtPair, ethusdcPair, ltcusdcPair} + require.True(t, s.Match(try), "Gate 5: Match must accept all of the key pairs are contained in sub pairs") + s.Levels = 4 + require.False(t, s.Match(try), "Gate 6: Match must reject a bad Level") + try.Levels = 4 + require.True(t, s.Match(try), "Gate 6: Match must accept a good Level") + s.Interval = kline.FiveMin + require.False(t, s.Match(try), "Gate 7: Match must reject a bad Interval") + try.Interval = kline.FiveMin + require.True(t, s.Match(try), "Gate 7: Match must accept a good Inteval") +} diff --git a/internal/testing/exchange/exchange.go b/internal/testing/exchange/exchange.go index 8a943677984..22e5fb28f13 100644 --- a/internal/testing/exchange/exchange.go +++ b/internal/testing/exchange/exchange.go @@ -94,6 +94,8 @@ func MockWSInstance[T any, PT interface { b := e.GetBase() b.SkipAuthCheck = true + b.API.AuthenticatedWebsocketSupport = true + err := b.API.Endpoints.SetRunning("RestSpotURL", s.URL) require.NoError(tb, err, "Endpoints.SetRunning should not error for RestSpotURL") for _, auth := range []bool{true, false} { @@ -146,13 +148,15 @@ func SetupWs(tb testing.TB, e exchange.IBotExchange) { } b := e.GetBase() - if !b.Websocket.IsEnabled() { + w, err := b.GetWebsocket() + if err != nil || !b.Websocket.IsEnabled() { tb.Skip("Websocket not enabled") } - if b.Websocket.IsConnected() { + if w.IsConnected() { return } - err := b.Websocket.Connect() + + err = w.Connect() require.NoError(tb, err, "WsConnect should not error") setupWsOnce[e] = true diff --git a/internal/testing/subscriptions/subscriptions.go b/internal/testing/subscriptions/subscriptions.go new file mode 100644 index 00000000000..1604be279b0 --- /dev/null +++ b/internal/testing/subscriptions/subscriptions.go @@ -0,0 +1,27 @@ +package subscriptionstest + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" +) + +func Equal(tb testing.TB, a, b subscription.List) { + tb.Helper() + s, err := subscription.NewStoreFromList(a) + require.NoError(t, err, "NewStoreFromList must not error") + added, missing := s.Diff(b) + if len(added) > 0 || len(missing) > 0 { + fail := "Differences:" + if len(added) > 0 { + fail = fail + "\n + " + strings.Join(added.Strings(), "\n + ") + } + if len(missing) > 0 { + fail = fail + "\n - " + strings.Join(missing.Strings(), "\n - ") + } + assert.Fail(tb, fail, "Subscriptions should be equal") + } +}