diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index c16e549ffc8..f8f0850d82d 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -877,23 +877,31 @@ func (w *Websocket) GetName() string { // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(genSubs []subscription.Subscription) (sub, unsub []subscription.Subscription) { w.subscriptionMutex.RLock() - unsubMap := make(map[any]subscription.Subscription, len(w.subscriptions)) + unsubMap := subscription.Map{} for k, c := range w.subscriptions { - unsubMap[k] = *c + unsubMap[k] = c } w.subscriptionMutex.RUnlock() for i := range genSubs { key := genSubs[i].EnsureKeyed() - if _, ok := unsubMap[key]; ok { - delete(unsubMap, key) // If it's in both then we remove it from the unsubscribe list + + var found *subscription.Subscription + if m, ok := key.(subscription.MatchableKey); ok { + found = m.Match(unsubMap) + } else { + found = unsubMap[key] + } + + if found != nil { + delete(unsubMap, found.Key) // If it's in both then we remove it from the unsubscribe list } else { sub = append(sub, genSubs[i]) // If it's in genSubs but not existing subs we want to subscribe } } for _, c := range unsubMap { - unsub = append(unsub, c) + unsub = append(unsub, *c) } return diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 5a2b0a924f4..36e44679ec8 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1059,8 +1059,9 @@ func TestGetChannelDifference(t *testing.T) { }, } subs, unsubs := web.GetChannelDifference(newChans) - assert.Len(t, subs, 3, "Should get the correct number of subs") - assert.Len(t, unsubs, 0, "Should get the correct number of unsubs") + assert.Implements(t, (*subscription.MatchableKey)(nil), subs[0].Key, "Sub key must be matchable") + assert.Equal(t, 3, len(subs), "Should get the correct number of subs") + assert.Equal(t, 0, len(unsubs), "Should get the correct number of unsubs") web.AddSuccessfulSubscriptions(subs...) @@ -1071,8 +1072,8 @@ func TestGetChannelDifference(t *testing.T) { } subs, unsubs = web.GetChannelDifference(flushedSubs) - assert.Len(t, subs, 0, "Should get the correct number of subs") - assert.Len(t, unsubs, 2, "Should get the correct number of unsubs") + assert.Equal(t, 0, len(subs), "Should get the correct number of subs") + assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") flushedSubs = []subscription.Subscription{ { @@ -1084,10 +1085,10 @@ func TestGetChannelDifference(t *testing.T) { } subs, unsubs = web.GetChannelDifference(flushedSubs) - if assert.Len(t, subs, 1, "Should get the correct number of subs") { + if assert.Equal(t, 1, len(subs), "Should get the correct number of subs") { assert.Equal(t, subs[0].Channel, "Test4", "Should subscribe to the right channel") } - if assert.Len(t, unsubs, 2, "Should get the correct number of unsubs") { + if assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") { sort.Slice(unsubs, func(i, j int) bool { return unsubs[i].Channel <= unsubs[j].Channel }) assert.Equal(t, unsubs[0].Channel, "Test1", "Should unsubscribe from the right channels") assert.Equal(t, unsubs[1].Channel, "Test3", "Should unsubscribe from the right channels") diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 8a01d9bd2bc..68570ac0bfb 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -86,7 +86,7 @@ func (s *Subscription) EnsureKeyed() any { // If the key provided has: // * Empty pairs then only Subscriptions without pairs will be considered // * >=1 pairs then Subscriptions which contain all the pairs will be considered -func (k *Key) Match(m Map) *Subscription { +func (k Key) Match(m Map) *Subscription { for a, v := range m { candidate, ok := a.(Key) if !ok {