diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 502a8b2c59b..ed656f33324 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -889,28 +889,45 @@ func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error { // SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method // Errors are returned for duplicates or exceeding max Subscriptions -func (w *Websocket) SubscribeToChannels(channels subscription.List) error { - if err := w.checkSubscriptions(channels); err != nil { +func (w *Websocket) SubscribeToChannels(subs subscription.List) error { + if err := w.checkSubscriptions(subs); err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } - if err := w.Subscriber(channels); err != nil { + if err := w.Subscriber(subs); err != nil { return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err)) } return nil } -// AddSubscription adds a subscription to the subscription lists -func (w *Websocket) AddSubscription(c *subscription.Subscription) error { - if w == nil || c == nil { +// AddSubscription adds a subscription to the subscription store +func (w *Websocket) AddSubscription(s *subscription.Subscription) error { + if w == nil || s == nil { return common.ErrNilPointer } if w.subscriptions == nil { w.subscriptions = subscription.NewStore() } - return w.subscriptions.Add(c) + return w.subscriptions.Add(s) } -// RemoveSubscriptions removes subscriptions from the subscription list +// AddSubscriptions adds subscriptions to the subscription store +func (w *Websocket) AddSubscriptions(subs subscription.List) error { + if w == nil { + return common.ErrNilPointer + } + if w.subscriptions == nil { + w.subscriptions = subscription.NewStore() + } + var errs error + for _, s := range subs { + if err := w.subscriptions.Add(s); err != nil { + errs = common.AppendError(errs, err) + } + } + return errs +} + +// RemoveSubscription removes a subscription from the subscription store func (w *Websocket) RemoveSubscription(s *subscription.Subscription) { if w == nil || w.subscriptions == nil || s == nil { return @@ -918,6 +935,16 @@ func (w *Websocket) RemoveSubscription(s *subscription.Subscription) { w.subscriptions.Remove(s) } +// RemoveSubscriptions removes subscriptions from the subscription list +func (w *Websocket) RemoveSubscriptions(subs subscription.List) { + if w == nil || w.subscriptions == nil { + return + } + for _, s := range subs { + w.subscriptions.Remove(s) + } +} + // GetSubscription returns a subscription at the key provided // returns nil if no subscription is at that key or the key is nil // Keys can implement subscription.MatchableKey in order to provide custom matching logic diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index d0ebadda420..69f70ccd8b4 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -80,8 +80,8 @@ var defaultSetup = &WebsocketSetup{ return subscription.List{ {Channel: "TestSub"}, {Channel: "TestSub2", Key: "purple"}, - {Channel: "TestSub3", key: testSubKey{"mauve"}}, - {Channel: "TestSub4", key: 42}, + {Channel: "TestSub3", Key: testSubKey{"mauve"}}, + {Channel: "TestSub4", Key: 42}, }, nil }, Features: &protocol.Features{Subscribe: true, Unsubscribe: true}, @@ -157,20 +157,20 @@ func TestSetup(t *testing.T) { t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriberUnset) } - websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil } + websocketSetup.Subscriber = func(subscription.List) error { return nil } websocketSetup.Features.Unsubscribe = true err = w.Setup(websocketSetup) if !errors.Is(err, errWebsocketUnsubscriberUnset) { t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketUnsubscriberUnset) } - websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil } + websocketSetup.Unsubscriber = func(subscription.List) error { return nil } err = w.Setup(websocketSetup) if !errors.Is(err, errWebsocketSubscriptionsGeneratorUnset) { t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriptionsGeneratorUnset) } - websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil } + websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil } err = w.Setup(websocketSetup) if !errors.Is(err, errDefaultURLIsEmpty) { t.Fatalf("received: '%v' but expected: '%v'", err, errDefaultURLIsEmpty) @@ -505,12 +505,15 @@ func TestSubscribeUnsubscribe(t *testing.T) { ws := *New() assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error") - fnSub := func(subs []subscription.Subscription) error { - ws.AddSuccessfulSubscriptions(subs...) + fnSub := func(subs subscription.List) error { + for _, s := range subs { + s.SetState(subscription.SubscribedState) + } + ws.AddSubscriptions(subs) return nil } - fnUnsub := func(unsubs []subscription.Subscription) error { - ws.RemoveSubscriptions(unsubs...) + fnUnsub := func(unsubs subscription.List) error { + ws.RemoveSubscriptions(unsubs) return nil } ws.Subscriber = fnSub @@ -523,10 +526,10 @@ func TestSubscribeUnsubscribe(t *testing.T) { assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error") assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") - byDefKey := ws.GetSubscription(subscription.Key{Channel: "TestSub"}) - if assert.NotNil(t, byDefKey, "GetSubscription by default key should find a channel") { - assert.Equal(t, "TestSub", byDefKey.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") - assert.NotSame(t, byDefKey, ws.subscriptions["TestSub"], "GetSubscription returns a fresh pointer") + bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"}) + if assert.NotNil(t, bySub, "GetSubscription by by subscription should find a channel") { + assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel") + assert.Same(t, bySub, subs[0], "GetSubscription returns a fresh pointer") } if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") { assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel") @@ -539,7 +542,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { } assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil") assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil") - assert.ErrorIs(t, ws.SubscribeToChannels(subs), ErrSubscribedAlready, "Subscribe should error when already subscribed") + assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed") assert.ErrorIs(t, ws.SubscribeToChannels(nil), errNoSubscriptionsSupplied, "Subscribe to nil should error") assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error") } @@ -557,48 +560,29 @@ func TestResubscribe(t *testing.T) { err = ws.Setup(defaultSetup) assert.NoError(t, err, "WS Setup should not error") - fnSub := func(subs []subscription.Subscription) error { - ws.AddSuccessfulSubscriptions(subs...) + fnSub := func(subs subscription.List) error { + for _, s := range subs { + s.SetState(subscription.SubscribedState) + } + ws.AddSubscriptions(subs) return nil } - fnUnsub := func(unsubs []subscription.Subscription) error { - ws.RemoveSubscriptions(unsubs...) + fnUnsub := func(unsubs subscription.List) error { + ws.RemoveSubscriptions(unsubs) return nil } ws.Subscriber = fnSub ws.Unsubscriber = fnUnsub - channel := []subscription.Subscription{{Channel: "resubTest"}} + channel := subscription.List{{Channel: "resubTest"}} - assert.ErrorIs(t, ws.ResubscribeToChannel(&channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet") + assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet") assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error") - assert.NoError(t, ws.ResubscribeToChannel(&channel[0]), "Resubscribe should not error now the channel is subscribed") + assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed") } -// TestSubscriptionState tests Subscription state changes -func TestSubscriptionState(t *testing.T) { - t.Parallel() - ws := New() - - c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState} - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error") - - assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") - found := ws.GetSubscription(42) - assert.NotNil(t, found, "Should find the subscription") - assert.Equal(t, subscription.SubscribingState, found.State, "Subscription should be Subscribing") - assert.ErrorIs(t, ws.AddSubscription(c), ErrSubscribedAlready, "Adding an already existing sub should error") - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.SubscribingState), ErrChannelInStateAlready, "Setting Same state should error") - assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState+1), errInvalidChannelState, "Setting an invalid state should error") - - ws.AddSuccessfulSubscriptions(*c) - found = ws.GetSubscription(42) - assert.NotNil(t, found, "Should find the subscription") - assert.Equal(t, found.State, subscription.SubscribedState, "Subscription should be subscribed state") - - assert.NoError(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), "Setting Unsub state should not error") - found = ws.GetSubscription(42) - assert.Equal(t, found.State, subscription.UnsubscribingState, "Subscription should be unsubscribing state") +func TestAddSubscription(t *testing.T) { + t.Fatal("Not implemented, along with others") } // TestRemoveSubscriptions tests removing a subscription @@ -610,14 +594,14 @@ func TestRemoveSubscriptions(t *testing.T) { assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") assert.NotNil(t, ws.GetSubscription(42), "Added subscription should be findable") - ws.RemoveSubscriptions(*c) + ws.RemoveSubscriptions(subscription.List{c}) assert.Nil(t, ws.GetSubscription(42), "Remove should have removed the sub") } // TestConnectionMonitorNoConnection logic test func TestConnectionMonitorNoConnection(t *testing.T) { t.Parallel() - ws := *New() + ws := New() ws.connectionMonitorDelay = 500 ws.DataHandler = make(chan interface{}, 1) ws.ShutdownC = make(chan struct{}, 1) @@ -641,31 +625,24 @@ func TestConnectionMonitorNoConnection(t *testing.T) { func TestGetSubscription(t *testing.T) { t.Parallel() assert.Nil(t, (*Websocket).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil") - assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub map should return nil") - w := Websocket{ - subscriptions: subscription.Map{ - 42: { - Channel: "hello3", - }, - }, - } - assert.Nil(t, w.GetSubscription(43), "GetSubscription with an invalid key should return nil") - c := w.GetSubscription(42) - if assert.NotNil(t, c, "GetSubscription with an valid key should return a channel") { - assert.Equal(t, "hello3", c.Channel, "GetSubscription should return the correct channel details") - } + assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil") + w := New() + assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil") + s := &subscription.Subscription{Key: 42, Channel: "hello3"} + w.AddSubscription(s) + assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store") } // TestGetSubscriptions logic test func TestGetSubscriptions(t *testing.T) { t.Parallel() - w := Websocket{ - subscriptions: subscription.Map{ - 42: { - Channel: "hello3", - }, - }, - } + assert.Nil(t, (*Websocket).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil") + assert.Nil(t, (&Websocket{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil") + w := New() + w.AddSubscriptions(subscription.List{ + {Key: 42, Channel: "hello3"}, + {Key: 45, Channel: "hello4"}, + }) assert.Equal(t, "hello3", w.GetSubscriptions()[0].Channel, "GetSubscriptions should return the correct channel details") } @@ -1048,41 +1025,33 @@ func TestGetChannelDifference(t *testing.T) { t.Parallel() w := Websocket{} - newChans := []subscription.Subscription{ - { - Channel: "Test1", - }, - { - Channel: "Test2", - }, - { - Channel: "Test3", - }, + newChans := subscription.List{ + {Channel: "Test1"}, + {Channel: "Test2"}, + {Channel: "Test3"}, } subs, unsubs := w.GetChannelDifference(newChans) require.Equal(t, 3, len(subs), "Should get the correct number of subs") assert.Implements(t, (*subscription.MatchableKey)(nil), subs[0].Key, "Sub key must be matchable") assert.Equal(t, 0, len(unsubs), "Should get the correct number of unsubs") - w.AddSuccessfulSubscriptions(subs...) + for _, s := range subs { + s.SetState(subscription.SubscribedState) + } - flushedSubs := []subscription.Subscription{ - { - Channel: "Test2", - }, + w.AddSubscriptions(subs) + + flushedSubs := subscription.List{ + {Channel: "Test2"}, } subs, unsubs = w.GetChannelDifference(flushedSubs) assert.Equal(t, 0, len(subs), "Should get the correct number of subs") assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") - flushedSubs = []subscription.Subscription{ - { - Channel: "Test2", - }, - { - Channel: "Test4", - }, + flushedSubs = subscription.List{ + {Channel: "Test2"}, + {Channel: "Test4"}, } subs, unsubs = w.GetChannelDifference(flushedSubs) @@ -1115,7 +1084,7 @@ func (g *GenSubs) generateSubs() ([]subscription.Subscription, error) { return superduperchannelsubs, nil } -func (g *GenSubs) SUBME(subs []subscription.Subscription) error { +func (g *GenSubs) SUBME(subs subscription.List) error { if len(subs) == 0 { return errors.New("WOW") } @@ -1123,7 +1092,7 @@ func (g *GenSubs) SUBME(subs []subscription.Subscription) error { return nil } -func (g *GenSubs) UNSUBME(unsubs []subscription.Subscription) error { +func (g *GenSubs) UNSUBME(unsubs subscription.List) error { if len(unsubs) == 0 { return errors.New("WOW") } diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 049ffc02cde..0be695a33a4 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -73,7 +73,7 @@ func (s *Subscription) State() State { func (s *Subscription) SetState(state State) error { s.m.Lock() defer s.m.Unlock() - if state == s.State() { + if state == s.state { return ErrInStateAlready } if state > UnsubscribingState { diff --git a/exchanges/subscription/subscription_test.go b/exchanges/subscription/subscription_test.go index fc95a7a0661..60f1c621ae6 100644 --- a/exchanges/subscription/subscription_test.go +++ b/exchanges/subscription/subscription_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" @@ -13,25 +14,25 @@ import ( // TestEnsureKeyed logic test func TestEnsureKeyed(t *testing.T) { t.Parallel() - c := &Subscription{ + s := &Subscription{ Channel: "candles", Asset: asset.Spot, Pairs: []currency.Pair{currency.NewPair(currency.BTC, currency.USDT)}, } - k1, ok := c.EnsureKeyed().(*Subscription) + k1, ok := s.EnsureKeyed().(*Subscription) if assert.True(t, ok, "EnsureKeyed should return a *Subscription") { - assert.Same(t, k1, c, "Key should point to the same struct") + assert.Same(t, k1, s, "Key should point to the same struct") } type platypus string - c = &Subscription{ + s = &Subscription{ Key: platypus("Gerald"), Channel: "orderbook", Asset: asset.Margin, Pairs: []currency.Pair{currency.NewPair(currency.ETH, currency.USDC)}, } - k2, ok := c.EnsureKeyed().(platypus) + k2, ok := s.EnsureKeyed().(platypus) if assert.True(t, ok, "EnsureKeyed should return a platypus") { - assert.Exactly(t, k2, c.Key, "ensureKeyed should set the same key") + assert.Exactly(t, k2, s.Key, "ensureKeyed should set the same key") assert.EqualValues(t, "Gerald", k2, "key should have the correct value") } } @@ -55,3 +56,17 @@ func TestMarshaling(t *testing.T) { assert.NoError(t, err, "Marshalling should not error") assert.Equal(t, `{"enabled":true,"channel":"myTrades","authenticated":true}`, string(j), "Marshalling should be clean and concise") } + +// TestSetState tests Subscription state changes +func TestSetState(t *testing.T) { + t.Parallel() + + s := &Subscription{Key: 42, Channel: "Gophers"} + assert.Equal(t, UnknownState, s.State(), "State should start as unknown") + require.NoError(t, s.SetState(SubscribingState), "SetState should not error") + assert.Equal(t, SubscribingState, s.State(), "State should be set correctly") + assert.ErrorIs(t, s.SetState(SubscribingState), ErrInStateAlready, "SetState should error on same state") + assert.ErrorIs(t, s.SetState(UnsubscribingState+1), ErrInvalidState, "Setting an invalid state should error") + require.NoError(t, s.SetState(UnsubscribingState), "SetState should not error") + assert.Equal(t, UnsubscribingState, s.State(), "State should be set correctly") +}