diff --git a/exchanges/kraken/kraken_test.go b/exchanges/kraken/kraken_test.go index c029a91de08..64d49516012 100644 --- a/exchanges/kraken/kraken_test.go +++ b/exchanges/kraken/kraken_test.go @@ -34,11 +34,12 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" + testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" ) var k = &Kraken{} -var wsConnected bool +var btcusdtPair = currency.NewPair(currency.BTC, currency.USDT) // Please add your own APIkeys here or in config/testdata.json to do correct due diligence testing const ( @@ -1336,6 +1337,32 @@ func TestWsOwnTradesSub(t *testing.T) { assert.Len(t, k.Websocket.GetSubscriptions(), 0, "Should have successfully removed channel") } +// TestGenerateSubscriptions tests the subscriptions generated from configuration +func TestGenerateSubscriptions(t *testing.T) { + t.Parallel() + + subs, err := k.GenerateSubscriptions() + require.NoError(t, err, "GenerateSubscriptions should not error") + expected := []subscription.Subscription{} + pairs, err := k.GetEnabledPairs(asset.Spot) + for i := range pairs { + pairs[i].Delimiter = "/" + } + require.NoError(t, err, "GetEnabledPairs must not error") + require.False(t, k.Websocket.CanUseAuthenticatedEndpoints(), "Websocket must not be authenticated by default") + for _, exp := range k.Features.Subscriptions { + if exp.Authenticated { + continue + } + s := *exp + s.Channel = channelName(s.Channel) + s.Asset = asset.Spot + s.Pairs = pairs + expected = append(expected, s) + } + testsubs.Equal(t, expected, subs) +} + func TestGetWSToken(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, k) diff --git a/exchanges/kraken/kraken_websocket.go b/exchanges/kraken/kraken_websocket.go index bff978d57b5..33cf068378f 100644 --- a/exchanges/kraken/kraken_websocket.go +++ b/exchanges/kraken/kraken_websocket.go @@ -33,7 +33,8 @@ const ( krakenAuthWSURL = "wss://ws-auth.kraken.com" krakenWSSandboxURL = "wss://sandbox.kraken.com" krakenWSSupportedVersion = "1.4.0" - // WS endpoints + + // Websocket Channels krakenWsHeartbeat = "heartbeat" krakenWsSystemStatus = "systemStatus" krakenWsSubscribe = "subscribe" @@ -58,21 +59,20 @@ const ( krakenWsCandlesDefaultTimeframe = 1 ) +var subscriptionNames = map[string]string{ + subscription.TickerChannel: krakenWsTicker, + subscription.OrderbookChannel: krakenWsOrderbook, + subscription.CandlesChannel: krakenWsOHLC, + subscription.AllTradesChannel: krakenWsTrade, + subscription.MyTradesChannel: krakenWsOwnTrades, + subscription.MyOrdersChannel: krakenWsOpenOrders, + // No equivalents for: AllOrders +} + var ( authToken string ) -// Channels require a topic and a currency -// Format [[ticker,but-t4u],[orderbook,nce-btt]] -var defaultSubscribedChannels = []string{ - krakenWsTicker, - krakenWsTrade, - krakenWsOrderbook, - krakenWsOHLC, - krakenWsSpread, -} -var authenticatedChannels = []string{krakenWsOwnTrades, krakenWsOpenOrders} - // WsConnect initiates a websocket connection func (k *Kraken) WsConnect() error { if !k.Websocket.IsEnabled() || !k.IsEnabled() { @@ -999,41 +999,37 @@ func (k *Kraken) wsProcessCandles(c *subscription.Subscription, response []any, return nil } -// GenerateDefaultSubscriptions Adds default subscriptions to websocket to be handled by ManageSubscriptions() -func (k *Kraken) GenerateDefaultSubscriptions() ([]subscription.Subscription, error) { +// channelName converts global channel Names used in config of channel input into kucoin channel names +// returns the name unchanged if no match is found +func channelName(name string) string { + if s, ok := subscriptionNames[name]; ok { + return s + } + return name +} + +// GenerateSubscriptions sets up the configured subscriptions for the websocket +func (k *Kraken) GenerateSubscriptions() ([]subscription.Subscription, error) { + subscriptions := []subscription.Subscription{} enabledPairs, err := k.GetEnabledPairs(asset.Spot) if err != nil { return nil, err } - var subscriptions []subscription.Subscription - for i := range defaultSubscribedChannels { - /* - for j := range enabledPairs { - enabledPairs[j].Delimiter = "/" - } - */ - c := subscription.Subscription{ - Channel: defaultSubscribedChannels[i], - Pairs: enabledPairs, - Asset: asset.Spot, - Params: map[string]any{}, - } - switch defaultSubscribedChannels[i] { - case krakenWsOrderbook: - c.Params[ChannelOrderbookDepthKey] = krakenWsOrderbookDefaultDepth - case krakenWsOHLC: - c.Params[ChannelCandlesTimeframeKey] = krakenWsCandlesDefaultTimeframe - } - subscriptions = append(subscriptions, c) + for i := range enabledPairs { + enabledPairs[i].Delimiter = "/" } - if k.Websocket.CanUseAuthenticatedEndpoints() { - for i := range authenticatedChannels { - subscriptions = append(subscriptions, subscription.Subscription{ - Channel: authenticatedChannels[i], - Asset: asset.Spot, - }) + authed := k.Websocket.CanUseAuthenticatedEndpoints() + for _, baseSub := range k.Features.Subscriptions { + if !authed && baseSub.Authenticated { + continue } + s := *baseSub + s.Channel = channelName(s.Channel) + s.Asset = asset.Spot + s.Pairs = enabledPairs + subscriptions = append(subscriptions, s) } + return subscriptions, nil } @@ -1048,105 +1044,105 @@ func (k *Kraken) Unsubscribe(channels []subscription.Subscription) error { } // subscribeToChan sends a websocket message to receive data from the channel -func (k *Kraken) subscribeToChan(chans []subscription.Subscription) error { - if len(chans) != 1 { +func (k *Kraken) subscribeToChan(subs []subscription.Subscription) error { + if len(subs) != 1 { return errors.New("Kraken subscription batching not yet implemented") } - c := chans[0] - r, err := k.reqForSub(krakenWsSubscribe, &c) + s := subs[0] + r, err := k.reqForSub(krakenWsSubscribe, &s) if err != nil { - return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err) + return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, s.Channel, s.Pairs, err) } - if !c.Asset.IsValid() { - c.Asset = asset.Spot + if !s.Asset.IsValid() { + s.Asset = asset.Spot } - err = ensureChannelKeyed(&c, r) + err = ensureChannelKeyed(&s, r) if err != nil { return err } - c.State = subscription.SubscribingState - err = k.Websocket.AddSubscription(&c) + s.State = subscription.SubscribingState + err = k.Websocket.AddSubscription(&s) if err != nil { - return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err) + return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, s.Channel, s.Pairs, err) } conn := k.Websocket.Conn - if common.StringDataContains(authenticatedChannels, r.Subscription.Name) { + if s.Authenticated { r.Subscription.Token = authToken conn = k.Websocket.AuthConn } respRaw, err := conn.SendMessageReturnResponse(r.RequestID, r) if err != nil { - k.Websocket.RemoveSubscriptions(c) - return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err) + k.Websocket.RemoveSubscriptions(s) + return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrSubscriptionFailure, s.Channel, s.Pairs, err) } if err = k.getErrResp(respRaw); err != nil { - wErr := fmt.Errorf("%w Channel: %s Pair: %s; %w", stream.ErrSubscriptionFailure, c.Channel, c.Pairs, err) + wErr := fmt.Errorf("%w Channel: %s Pair: %s; %w", stream.ErrSubscriptionFailure, s.Channel, s.Pairs, err) k.Websocket.DataHandler <- wErr // Currently all or nothing on pairs; Alternatively parse response and remove failing pairs and retry - k.Websocket.RemoveSubscriptions(c) + k.Websocket.RemoveSubscriptions(s) return wErr } - if err = k.Websocket.SetSubscriptionState(&c, subscription.SubscribedState); err != nil { + if err = k.Websocket.SetSubscriptionState(&s, subscription.SubscribedState); err != nil { log.Errorf(log.ExchangeSys, "%s error setting channel to subscribed: %s", k.Name, err) } if k.Verbose { - log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s\n", k.Name, c.Channel, c.Pairs) + log.Debugf(log.ExchangeSys, "%s Subscribed to Channel: %s Pair: %s\n", k.Name, s.Channel, s.Pairs) } return nil } // unsubscribeFromChan sends a websocket message to stop receiving data from a channel -func (k *Kraken) unsubscribeFromChan(chans []subscription.Subscription) error { - if len(chans) != 1 { +func (k *Kraken) unsubscribeFromChan(subs []subscription.Subscription) error { + if len(subs) != 1 { return errors.New("Kraken subscription batching not yet implemented") } - c := chans[0] - r, err := k.reqForSub(krakenWsUnsubscribe, &c) + s := subs[0] + r, err := k.reqForSub(krakenWsUnsubscribe, &s) if err != nil { - return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrUnsubscribeFailure, c.Channel, c.Pairs, err) + return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrUnsubscribeFailure, s.Channel, s.Pairs, err) } - c.EnsureKeyed() + s.EnsureKeyed() - if err = k.Websocket.SetSubscriptionState(&c, subscription.UnsubscribingState); err != nil { + if err = k.Websocket.SetSubscriptionState(&s, subscription.UnsubscribingState); err != nil { // err is probably ErrChannelInStateAlready, but we want to bubble it up to prevent an attempt to Subscribe again // We can catch and ignore it in our call to resub - return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrUnsubscribeFailure, c.Channel, c.Pairs, err) + return fmt.Errorf("%w Channel: %s Pair: %s Error: %w", stream.ErrUnsubscribeFailure, s.Channel, s.Pairs, err) } conn := k.Websocket.Conn - if common.StringDataContains(authenticatedChannels, c.Channel) { + if s.Authenticated { conn = k.Websocket.AuthConn r.Subscription.Token = authToken } respRaw, err := conn.SendMessageReturnResponse(r.RequestID, r) if err != nil { - if e2 := k.Websocket.SetSubscriptionState(&c, subscription.SubscribedState); e2 != nil { + if e2 := k.Websocket.SetSubscriptionState(&s, subscription.SubscribedState); e2 != nil { log.Errorf(log.ExchangeSys, "%s error setting channel to subscribed: %s", k.Name, e2) } return err } if err = k.getErrResp(respRaw); err != nil { - wErr := fmt.Errorf("%w Channel: %s Pair: %s; %w", stream.ErrUnsubscribeFailure, c.Channel, c.Pairs, err) + wErr := fmt.Errorf("%w Channel: %s Pair: %s; %w", stream.ErrUnsubscribeFailure, s.Channel, s.Pairs, err) k.Websocket.DataHandler <- wErr - if e2 := k.Websocket.SetSubscriptionState(&c, subscription.SubscribedState); e2 != nil { + if e2 := k.Websocket.SetSubscriptionState(&s, subscription.SubscribedState); e2 != nil { log.Errorf(log.ExchangeSys, "%s error setting channel to subscribed: %s", k.Name, e2) } return wErr } - k.Websocket.RemoveSubscriptions(c) + k.Websocket.RemoveSubscriptions(s) return nil } diff --git a/exchanges/kraken/kraken_wrapper.go b/exchanges/kraken/kraken_wrapper.go index a18b13695b8..02065cf4f57 100644 --- a/exchanges/kraken/kraken_wrapper.go +++ b/exchanges/kraken/kraken_wrapper.go @@ -28,6 +28,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" "github.com/thrasher-corp/gocryptotrader/log" @@ -191,6 +192,14 @@ func (k *Kraken) SetDefaults() { GlobalResultLimit: 720, }, }, + Subscriptions: []*subscription.Subscription{ + {Enabled: true, Channel: subscription.TickerChannel}, + {Enabled: true, Channel: subscription.AllTradesChannel}, + {Enabled: true, Channel: subscription.CandlesChannel, Interval: kline.OneMin}, + {Enabled: true, Channel: subscription.OrderbookChannel, Levels: 1000}, + {Enabled: true, Channel: subscription.MyOrdersChannel, Authenticated: true}, + {Enabled: true, Channel: subscription.MyTradesChannel, Authenticated: true}, + }, } k.Requester, err = request.New(k.Name, @@ -242,7 +251,7 @@ func (k *Kraken) Setup(exch *config.Exchange) error { Connector: k.WsConnect, Subscriber: k.Subscribe, Unsubscriber: k.Unsubscribe, - GenerateSubscriptions: k.GenerateDefaultSubscriptions, + GenerateSubscriptions: k.GenerateSubscriptions, Features: &k.Features.Supports.WebsocketCapabilities, OrderbookBufferConfig: buffer.Config{SortBuffer: true}, }) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index f8f0850d82d..68873cc760e 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -875,36 +875,10 @@ func (w *Websocket) GetName() string { // GetChannelDifference finds the difference between the subscribed channels // and the new subscription list when pairs are disabled or enabled. -func (w *Websocket) GetChannelDifference(genSubs []subscription.Subscription) (sub, unsub []subscription.Subscription) { +func (w *Websocket) GetChannelDifference(newSubs []subscription.Subscription) (sub, unsub []subscription.Subscription) { w.subscriptionMutex.RLock() - unsubMap := subscription.Map{} - for k, c := range w.subscriptions { - unsubMap[k] = c - } - w.subscriptionMutex.RUnlock() - - for i := range genSubs { - key := genSubs[i].EnsureKeyed() - - var found *subscription.Subscription - if m, ok := key.(subscription.MatchableKey); ok { - found = m.Match(unsubMap) - } else { - found = unsubMap[key] - } - - if found != nil { - delete(unsubMap, found.Key) // If it's in both then we remove it from the unsubscribe list - } else { - sub = append(sub, genSubs[i]) // If it's in genSubs but not existing subs we want to subscribe - } - } - - for _, c := range unsubMap { - unsub = append(unsub, *c) - } - - return + defer w.subscriptionMutex.RUnlock() + return w.subscriptions.Diff(subscription.ListToMap(newSubs)) } // UnsubscribeChannels unsubscribes from a websocket channel diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 36e44679ec8..a94b28985aa 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/config" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/protocol" @@ -1045,7 +1046,7 @@ func TestCheckWebsocketURL(t *testing.T) { func TestGetChannelDifference(t *testing.T) { t.Parallel() - web := Websocket{} + w := Websocket{} newChans := []subscription.Subscription{ { @@ -1058,12 +1059,12 @@ func TestGetChannelDifference(t *testing.T) { Channel: "Test3", }, } - subs, unsubs := web.GetChannelDifference(newChans) + subs, unsubs := w.GetChannelDifference(newChans) + require.Equal(t, 3, len(subs), "Should get the correct number of subs") assert.Implements(t, (*subscription.MatchableKey)(nil), subs[0].Key, "Sub key must be matchable") - assert.Equal(t, 3, len(subs), "Should get the correct number of subs") assert.Equal(t, 0, len(unsubs), "Should get the correct number of unsubs") - web.AddSuccessfulSubscriptions(subs...) + w.AddSuccessfulSubscriptions(subs...) flushedSubs := []subscription.Subscription{ { @@ -1071,7 +1072,7 @@ func TestGetChannelDifference(t *testing.T) { }, } - subs, unsubs = web.GetChannelDifference(flushedSubs) + subs, unsubs = w.GetChannelDifference(flushedSubs) assert.Equal(t, 0, len(subs), "Should get the correct number of subs") assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") @@ -1084,7 +1085,7 @@ func TestGetChannelDifference(t *testing.T) { }, } - subs, unsubs = web.GetChannelDifference(flushedSubs) + subs, unsubs = w.GetChannelDifference(flushedSubs) if assert.Equal(t, 1, len(subs), "Should get the correct number of subs") { assert.Equal(t, subs[0].Channel, "Test4", "Should subscribe to the right channel") } diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 66a8f0538c5..7cfdb670b4e 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -3,6 +3,8 @@ package subscription import ( "errors" "fmt" + "maps" + "slices" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" @@ -47,6 +49,9 @@ type Key struct { // Map is a container of subscription pointers type Map map[any]*Subscription +// List is a container of subscription pointers +type List []Subscription + type SubscriptionInterface interface { } @@ -87,8 +92,8 @@ func (s *Subscription) EnsureKeyed() any { // * Empty pairs then only Subscriptions without pairs will be considered // * >=1 pairs then Subscriptions which contain all the pairs will be considered func (k Key) Match(m Map) *Subscription { - for a, v := range m { - candidate, ok := a.(Key) + for anyKey, s := range m { + candidate, ok := anyKey.(Key) if !ok { continue } @@ -99,11 +104,56 @@ func (k Key) Match(m Map) *Subscription { continue } if (k.Pairs == nil || len(*k.Pairs) == 0) && (candidate.Pairs == nil || len(*candidate.Pairs) == 0) { - return v + return s } if err := candidate.Pairs.ContainsAll(*k.Pairs, true); err == nil { - return v + return s } } return nil } + +// ListToMap creates a Map from a slice of subscriptions +func ListToMap(s List) *Map { + n := Map{} + for _, c := range s { + n[c.EnsureKeyed()] = &c + } + return &n +} + +// Diff returns a list of the added and missing subs between two maps +func (m *Map) Diff(newSubs *Map) (sub, unsub List) { + oldSubs := maps.Clone(*m) + for _, s := range *newSubs { + key := s.EnsureKeyed() + + var found *Subscription + if m, ok := key.(MatchableKey); ok { + found = m.Match(oldSubs) + } else { + found = oldSubs[key] + } + + if found != nil { + delete(oldSubs, found.Key) // If it's in both then we remove it from the unsubscribe list + } else { + sub = append(sub, *s) // If it's in newSubs but not oldSubs subs we want to subscribe + } + } + + for _, c := range oldSubs { + unsub = append(unsub, *c) + } + + return +} + +func (l List) Strings() []string { + s := make([]string, len(l)) + for i := range l { + s[i] = l[i].String() + } + slices.Sort(s) + return s +} diff --git a/internal/testing/subscriptions/subscriptions.go b/internal/testing/subscriptions/subscriptions.go new file mode 100644 index 00000000000..f5cdc8e5e6d --- /dev/null +++ b/internal/testing/subscriptions/subscriptions.go @@ -0,0 +1,24 @@ +package subscriptionstest + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" +) + +func Equal(tb testing.TB, a, b subscription.List) { + tb.Helper() + added, missing := subscription.ListToMap(a).Diff(subscription.ListToMap(b)) + if len(added) > 0 || len(missing) > 0 { + fail := "Differences:" + if len(added) > 0 { + fail = fail + "\n + " + strings.Join(added.Strings(), "\n + ") + } + if len(missing) > 0 { + fail = fail + "\n - " + strings.Join(missing.Strings(), "\n - ") + } + assert.Fail(tb, fail, "Subscriptions should be equal") + } +}