From e51317a0903fd84e25da97af31ae1347c5d6f2af Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Thu, 14 Sep 2023 19:32:01 +0700 Subject: [PATCH] Websockets: Add key to websocket subscriptions --- exchanges/stream/stream_types.go | 7 +++++ exchanges/stream/websocket.go | 48 ++++++++++++++++++----------- exchanges/stream/websocket_types.go | 4 ++- 3 files changed, 40 insertions(+), 19 deletions(-) diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 053e5688a30..bc17df6dc0b 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -31,8 +31,15 @@ type Response struct { Raw []byte } +type defaultChannelKey struct { + Channel string + Currency currency.Pair + Asset asset.Item +} + // ChannelSubscription container for streaming subscriptions type ChannelSubscription struct { + Key any Channel string Currency currency.Pair Asset asset.Item diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 7d9c9e4e171..1abbcde514b 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -461,7 +461,7 @@ func (w *Websocket) Shutdown() error { // flush any subscriptions from last connection if needed w.subscriptionMutex.Lock() - w.subscriptions = nil + w.subscriptions = subscriptionMap{} w.subscriptionMutex.Unlock() close(w.ShutdownC) @@ -521,7 +521,7 @@ func (w *Websocket) FlushChannels() error { if len(newsubs) != 0 { // Purge subscription list as there will be conflicts w.subscriptionMutex.Lock() - w.subscriptions = nil + w.subscriptions = subscriptionMap{} w.subscriptionMutex.Unlock() return w.SubscribeToChannels(newsubs) } @@ -850,19 +850,19 @@ func (w *Websocket) GetChannelDifference(genSubs []ChannelSubscription) (sub, un defer w.subscriptionMutex.Unlock() oldsubs: - for x := range w.subscriptions { + for _, x := range w.subscriptions { for y := range genSubs { - if w.subscriptions[x].Equal(&genSubs[y]) { + if x.Equal(&genSubs[y]) { continue oldsubs } } - unsub = append(unsub, w.subscriptions[x]) + unsub = append(unsub, x) } newsubs: for x := range genSubs { - for y := range w.subscriptions { - if genSubs[x].Equal(&w.subscriptions[y]) { + for _, y := range w.subscriptions { + if genSubs[x].Equal(&y) { continue newsubs } } @@ -881,8 +881,8 @@ func (w *Websocket) UnsubscribeChannels(channels []ChannelSubscription) error { channels: for x := range channels { - for y := range w.subscriptions { - if channels[x].Equal(&w.subscriptions[y]) { + for _, y := range w.subscriptions { + if channels[x].Equal(&y) { continue channels } } @@ -912,8 +912,8 @@ func (w *Websocket) SubscribeToChannels(channels []ChannelSubscription) error { } w.subscriptionMutex.Lock() for x := range channels { - for y := range w.subscriptions { - if channels[x].Equal(&w.subscriptions[y]) { + for _, y := range w.subscriptions { + if channels[x].Equal(&y) { w.subscriptionMutex.Unlock() return fmt.Errorf("%s websocket: %v already subscribed", w.exchangeName, @@ -932,7 +932,17 @@ func (w *Websocket) SubscribeToChannels(channels []ChannelSubscription) error { // has been successfully subscribed func (w *Websocket) AddSuccessfulSubscriptions(channels ...ChannelSubscription) { w.subscriptionMutex.Lock() - w.subscriptions = append(w.subscriptions, channels...) + for i := range channels { + key := channels[i].Key + if key == nil { + key = defaultChannelKey{ + Channel: channels[i].Channel, + Asset: channels[i].Asset, + Currency: channels[i].Currency, + } + } + w.subscriptions[key] = channels[i] + } w.subscriptionMutex.Unlock() } @@ -942,11 +952,9 @@ func (w *Websocket) RemoveSuccessfulUnsubscriptions(channels ...ChannelSubscript w.subscriptionMutex.Lock() defer w.subscriptionMutex.Unlock() for x := range channels { - for y := range w.subscriptions { - if channels[x].Equal(&w.subscriptions[y]) { - w.subscriptions[y] = w.subscriptions[len(w.subscriptions)-1] - w.subscriptions[len(w.subscriptions)-1] = ChannelSubscription{} - w.subscriptions = w.subscriptions[:len(w.subscriptions)-1] + for _, y := range w.subscriptions { + if channels[x].Equal(&y) { + delete(w.subscriptions, y.Key) break } } @@ -964,7 +972,11 @@ func (w *ChannelSubscription) Equal(s *ChannelSubscription) bool { func (w *Websocket) GetSubscriptions() []ChannelSubscription { w.subscriptionMutex.Lock() defer w.subscriptionMutex.Unlock() - return append(w.subscriptions[:0:0], w.subscriptions...) + subs := make([]ChannelSubscription, len(w.subscriptions)) + for _, c := range w.subscriptions { + subs = append(subs, c) + } + return subs } // SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 8bf2e811ca7..9573908b01e 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -22,6 +22,8 @@ const ( UnhandledMessage = " - Unhandled websocket message: " ) +type subscriptionMap map[any]ChannelSubscription + // Websocket defines a return type for websocket connections via the interface // wrapper for routine processing type Websocket struct { @@ -47,7 +49,7 @@ type Websocket struct { connector func() error subscriptionMutex sync.Mutex - subscriptions []ChannelSubscription + subscriptions subscriptionMap Subscribe chan []ChannelSubscription Unsubscribe chan []ChannelSubscription