diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 69f70ccd8b4..9dfd5f0ee3b 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -1068,15 +1068,15 @@ func TestGetChannelDifference(t *testing.T) { // GenSubs defines a theoretical exchange with pair management type GenSubs struct { EnabledPairs currency.Pairs - subscribos []subscription.Subscription - unsubscribos []subscription.Subscription + subscribos subscription.List + unsubscribos subscription.List } // generateSubs default subs created from the enabled pairs list -func (g *GenSubs) generateSubs() ([]subscription.Subscription, error) { - superduperchannelsubs := make([]subscription.Subscription, len(g.EnabledPairs)) +func (g *GenSubs) generateSubs() (subscription.List, error) { + superduperchannelsubs := make(subscription.List, len(g.EnabledPairs)) for i := range g.EnabledPairs { - superduperchannelsubs[i] = subscription.Subscription{ + superduperchannelsubs[i] = &subscription.Subscription{ Channel: "TEST:" + strconv.FormatInt(int64(i), 10), Pairs: currency.Pairs{g.EnabledPairs[i]}, } @@ -1139,19 +1139,19 @@ func TestFlushChannels(t *testing.T) { // this to an unconnected state } - problemFunc := func() ([]subscription.Subscription, error) { + problemFunc := func() (subscription.List, error) { return nil, errors.New("problems") } - noSub := func() ([]subscription.Subscription, error) { + noSub := func() (subscription.List, error) { return nil, nil } // Disable pair and flush system newgen.EnabledPairs = []currency.Pair{ currency.NewPair(currency.BTC, currency.AUD)} - web.GenerateSubs = func() ([]subscription.Subscription, error) { - return []subscription.Subscription{{Channel: "test"}}, nil + web.GenerateSubs = func() (subscription.List, error) { + return subscription.List{{Channel: "test"}}, nil } err = web.FlushChannels() if err != nil { @@ -1176,7 +1176,10 @@ func TestFlushChannels(t *testing.T) { if err != nil { t.Fatal(err) } - web.AddSuccessfulSubscriptions(subs...) + for _, s := range subs { + s.SetState(subscription.SubscribedState) + } + web.AddSubscriptions(subs) err = web.FlushChannels() if err != nil { t.Fatal(err) @@ -1195,20 +1198,17 @@ func TestFlushChannels(t *testing.T) { if err != nil { t.Fatal(err) } - web.subscriptionMutex.Lock() - web.subscriptions = subscription.Map{ - 41: { - Key: 41, - Channel: "match channel", - Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.AUD)}, - }, - 42: { - Key: 42, - Channel: "unsub channel", - Pairs: currency.Pairs{currency.NewPair(currency.THETA, currency.USDT)}, - }, - } - web.subscriptionMutex.Unlock() + web.subscriptions = subscription.NewStore() + web.subscriptions.Add(&subscription.Subscription{ + Key: 41, + Channel: "match channel", + Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.AUD)}, + }) + web.subscriptions.Add(&subscription.Subscription{ + Key: 42, + Channel: "unsub channel", + Pairs: currency.Pairs{currency.NewPair(currency.THETA, currency.USDT)}, + }) err = web.FlushChannels() if err != nil { @@ -1251,10 +1251,10 @@ func TestEnable(t *testing.T) { connector: connect, Wg: new(sync.WaitGroup), ShutdownC: make(chan struct{}), - GenerateSubs: func() ([]subscription.Subscription, error) { - return []subscription.Subscription{{Channel: "test"}}, nil + GenerateSubs: func() (subscription.List, error) { + return subscription.List{{Channel: "test"}}, nil }, - Subscriber: func(cs []subscription.Subscription) error { return nil }, + Subscriber: func(cs subscription.List) error { return nil }, } err := web.Enable() @@ -1406,15 +1406,16 @@ func TestCheckSubscriptions(t *testing.T) { ws.MaxSubscriptionsPerConnection = 1 - err = ws.checkSubscriptions([]subscription.Subscription{{}, {}}) + err = ws.checkSubscriptions(subscription.List{{}, {}}) assert.ErrorIs(t, err, errSubscriptionsExceedsLimit, "checkSubscriptions should error correctly") ws.MaxSubscriptionsPerConnection = 2 - ws.subscriptions = subscription.Map{42: {Key: 42, Channel: "test"}} - err = ws.checkSubscriptions([]subscription.Subscription{{Key: 42, Channel: "test"}}) - assert.ErrorIs(t, err, ErrSubscribedAlready, "checkSubscriptions should error correctly") + ws.subscriptions = subscription.NewStore() + ws.subscriptions.Add(&subscription.Subscription{Key: 42, Channel: "test"}) + err = ws.checkSubscriptions(subscription.List{{Key: 42, Channel: "test"}}) + assert.ErrorIs(t, err, subscription.ErrDuplicate, "checkSubscriptions should error correctly") - err = ws.checkSubscriptions([]subscription.Subscription{{}}) + err = ws.checkSubscriptions(subscription.List{{}}) assert.NoError(t, err, "checkSubscriptions should not error") }