Skip to content

Commit

Permalink
Subscriptions: State test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Feb 15, 2024
1 parent f40bc5e commit bc93447
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 106 deletions.
43 changes: 35 additions & 8 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -889,35 +889,62 @@ func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error {

// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
// Errors are returned for duplicates or exceeding max Subscriptions
func (w *Websocket) SubscribeToChannels(channels subscription.List) error {
if err := w.checkSubscriptions(channels); err != nil {
func (w *Websocket) SubscribeToChannels(subs subscription.List) error {
if err := w.checkSubscriptions(subs); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
if err := w.Subscriber(channels); err != nil {
if err := w.Subscriber(subs); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
return nil
}

// AddSubscription adds a subscription to the subscription lists
func (w *Websocket) AddSubscription(c *subscription.Subscription) error {
if w == nil || c == nil {
// AddSubscription adds a subscription to the subscription store
func (w *Websocket) AddSubscription(s *subscription.Subscription) error {
if w == nil || s == nil {
return common.ErrNilPointer
}
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
return w.subscriptions.Add(c)
return w.subscriptions.Add(s)
}

// RemoveSubscriptions removes subscriptions from the subscription list
// AddSubscriptions adds subscriptions to the subscription store
func (w *Websocket) AddSubscriptions(subs subscription.List) error {
if w == nil {
return common.ErrNilPointer
}
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
var errs error
for _, s := range subs {
if err := w.subscriptions.Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}

// RemoveSubscription removes a subscription from the subscription store
func (w *Websocket) RemoveSubscription(s *subscription.Subscription) {
if w == nil || w.subscriptions == nil || s == nil {
return
}
w.subscriptions.Remove(s)
}

// RemoveSubscriptions removes subscriptions from the subscription list
func (w *Websocket) RemoveSubscriptions(subs subscription.List) {
if w == nil || w.subscriptions == nil {
return
}
for _, s := range subs {
w.subscriptions.Remove(s)
}
}

// GetSubscription returns a subscription at the key provided
// returns nil if no subscription is at that key or the key is nil
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
Expand Down
151 changes: 60 additions & 91 deletions exchanges/stream/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ var defaultSetup = &WebsocketSetup{
return subscription.List{
{Channel: "TestSub"},
{Channel: "TestSub2", Key: "purple"},
{Channel: "TestSub3", key: testSubKey{"mauve"}},
{Channel: "TestSub4", key: 42},
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
{Channel: "TestSub4", Key: 42},
}, nil
},
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
Expand Down Expand Up @@ -157,20 +157,20 @@ func TestSetup(t *testing.T) {
t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriberUnset)
}

websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil }
websocketSetup.Subscriber = func(subscription.List) error { return nil }
websocketSetup.Features.Unsubscribe = true
err = w.Setup(websocketSetup)
if !errors.Is(err, errWebsocketUnsubscriberUnset) {
t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketUnsubscriberUnset)
}

websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil }
websocketSetup.Unsubscriber = func(subscription.List) error { return nil }
err = w.Setup(websocketSetup)
if !errors.Is(err, errWebsocketSubscriptionsGeneratorUnset) {
t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriptionsGeneratorUnset)
}

websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil }
websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil }
err = w.Setup(websocketSetup)
if !errors.Is(err, errDefaultURLIsEmpty) {
t.Fatalf("received: '%v' but expected: '%v'", err, errDefaultURLIsEmpty)
Expand Down Expand Up @@ -505,12 +505,15 @@ func TestSubscribeUnsubscribe(t *testing.T) {
ws := *New()
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")

fnSub := func(subs []subscription.Subscription) error {
ws.AddSuccessfulSubscriptions(subs...)
fnSub := func(subs subscription.List) error {
for _, s := range subs {
s.SetState(subscription.SubscribedState)
}
ws.AddSubscriptions(subs)
return nil
}
fnUnsub := func(unsubs []subscription.Subscription) error {
ws.RemoveSubscriptions(unsubs...)
fnUnsub := func(unsubs subscription.List) error {
ws.RemoveSubscriptions(unsubs)
return nil
}
ws.Subscriber = fnSub
Expand All @@ -523,10 +526,10 @@ func TestSubscribeUnsubscribe(t *testing.T) {
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error")
assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")
byDefKey := ws.GetSubscription(subscription.Key{Channel: "TestSub"})
if assert.NotNil(t, byDefKey, "GetSubscription by default key should find a channel") {
assert.Equal(t, "TestSub", byDefKey.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.NotSame(t, byDefKey, ws.subscriptions["TestSub"], "GetSubscription returns a fresh pointer")
bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"})
if assert.NotNil(t, bySub, "GetSubscription by by subscription should find a channel") {
assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.Same(t, bySub, subs[0], "GetSubscription returns a fresh pointer")
}
if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
Expand All @@ -539,7 +542,7 @@ func TestSubscribeUnsubscribe(t *testing.T) {
}
assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil")
assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil")
assert.ErrorIs(t, ws.SubscribeToChannels(subs), ErrSubscribedAlready, "Subscribe should error when already subscribed")
assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
assert.ErrorIs(t, ws.SubscribeToChannels(nil), errNoSubscriptionsSupplied, "Subscribe to nil should error")
assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error")
}
Expand All @@ -557,48 +560,29 @@ func TestResubscribe(t *testing.T) {
err = ws.Setup(defaultSetup)
assert.NoError(t, err, "WS Setup should not error")

fnSub := func(subs []subscription.Subscription) error {
ws.AddSuccessfulSubscriptions(subs...)
fnSub := func(subs subscription.List) error {
for _, s := range subs {
s.SetState(subscription.SubscribedState)
}
ws.AddSubscriptions(subs)
return nil
}
fnUnsub := func(unsubs []subscription.Subscription) error {
ws.RemoveSubscriptions(unsubs...)
fnUnsub := func(unsubs subscription.List) error {
ws.RemoveSubscriptions(unsubs)
return nil
}
ws.Subscriber = fnSub
ws.Unsubscriber = fnUnsub

channel := []subscription.Subscription{{Channel: "resubTest"}}
channel := subscription.List{{Channel: "resubTest"}}

assert.ErrorIs(t, ws.ResubscribeToChannel(&channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error")
assert.NoError(t, ws.ResubscribeToChannel(&channel[0]), "Resubscribe should not error now the channel is subscribed")
assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed")
}

// TestSubscriptionState tests Subscription state changes
func TestSubscriptionState(t *testing.T) {
t.Parallel()
ws := New()

c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState}
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error")

assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
found := ws.GetSubscription(42)
assert.NotNil(t, found, "Should find the subscription")
assert.Equal(t, subscription.SubscribingState, found.State, "Subscription should be Subscribing")
assert.ErrorIs(t, ws.AddSubscription(c), ErrSubscribedAlready, "Adding an already existing sub should error")
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.SubscribingState), ErrChannelInStateAlready, "Setting Same state should error")
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState+1), errInvalidChannelState, "Setting an invalid state should error")

ws.AddSuccessfulSubscriptions(*c)
found = ws.GetSubscription(42)
assert.NotNil(t, found, "Should find the subscription")
assert.Equal(t, found.State, subscription.SubscribedState, "Subscription should be subscribed state")

assert.NoError(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), "Setting Unsub state should not error")
found = ws.GetSubscription(42)
assert.Equal(t, found.State, subscription.UnsubscribingState, "Subscription should be unsubscribing state")
func TestAddSubscription(t *testing.T) {
t.Fatal("Not implemented, along with others")
}

// TestRemoveSubscriptions tests removing a subscription
Expand All @@ -610,14 +594,14 @@ func TestRemoveSubscriptions(t *testing.T) {
assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
assert.NotNil(t, ws.GetSubscription(42), "Added subscription should be findable")

ws.RemoveSubscriptions(*c)
ws.RemoveSubscriptions(subscription.List{c})
assert.Nil(t, ws.GetSubscription(42), "Remove should have removed the sub")
}

// TestConnectionMonitorNoConnection logic test
func TestConnectionMonitorNoConnection(t *testing.T) {
t.Parallel()
ws := *New()
ws := New()
ws.connectionMonitorDelay = 500
ws.DataHandler = make(chan interface{}, 1)
ws.ShutdownC = make(chan struct{}, 1)
Expand All @@ -641,31 +625,24 @@ func TestConnectionMonitorNoConnection(t *testing.T) {
func TestGetSubscription(t *testing.T) {
t.Parallel()
assert.Nil(t, (*Websocket).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub map should return nil")
w := Websocket{
subscriptions: subscription.Map{
42: {
Channel: "hello3",
},
},
}
assert.Nil(t, w.GetSubscription(43), "GetSubscription with an invalid key should return nil")
c := w.GetSubscription(42)
if assert.NotNil(t, c, "GetSubscription with an valid key should return a channel") {
assert.Equal(t, "hello3", c.Channel, "GetSubscription should return the correct channel details")
}
assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil")
w := New()
assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil")
s := &subscription.Subscription{Key: 42, Channel: "hello3"}
w.AddSubscription(s)
assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store")
}

// TestGetSubscriptions logic test
func TestGetSubscriptions(t *testing.T) {
t.Parallel()
w := Websocket{
subscriptions: subscription.Map{
42: {
Channel: "hello3",
},
},
}
assert.Nil(t, (*Websocket).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Websocket{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil")
w := New()
w.AddSubscriptions(subscription.List{
{Key: 42, Channel: "hello3"},
{Key: 45, Channel: "hello4"},
})
assert.Equal(t, "hello3", w.GetSubscriptions()[0].Channel, "GetSubscriptions should return the correct channel details")
}

Expand Down Expand Up @@ -1048,41 +1025,33 @@ func TestGetChannelDifference(t *testing.T) {
t.Parallel()
w := Websocket{}

newChans := []subscription.Subscription{
{
Channel: "Test1",
},
{
Channel: "Test2",
},
{
Channel: "Test3",
},
newChans := subscription.List{
{Channel: "Test1"},
{Channel: "Test2"},
{Channel: "Test3"},
}
subs, unsubs := w.GetChannelDifference(newChans)
require.Equal(t, 3, len(subs), "Should get the correct number of subs")
assert.Implements(t, (*subscription.MatchableKey)(nil), subs[0].Key, "Sub key must be matchable")
assert.Equal(t, 0, len(unsubs), "Should get the correct number of unsubs")

w.AddSuccessfulSubscriptions(subs...)
for _, s := range subs {
s.SetState(subscription.SubscribedState)
}

flushedSubs := []subscription.Subscription{
{
Channel: "Test2",
},
w.AddSubscriptions(subs)

flushedSubs := subscription.List{
{Channel: "Test2"},
}

subs, unsubs = w.GetChannelDifference(flushedSubs)
assert.Equal(t, 0, len(subs), "Should get the correct number of subs")
assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs")

flushedSubs = []subscription.Subscription{
{
Channel: "Test2",
},
{
Channel: "Test4",
},
flushedSubs = subscription.List{
{Channel: "Test2"},
{Channel: "Test4"},
}

subs, unsubs = w.GetChannelDifference(flushedSubs)
Expand Down Expand Up @@ -1115,15 +1084,15 @@ func (g *GenSubs) generateSubs() ([]subscription.Subscription, error) {
return superduperchannelsubs, nil
}

func (g *GenSubs) SUBME(subs []subscription.Subscription) error {
func (g *GenSubs) SUBME(subs subscription.List) error {
if len(subs) == 0 {
return errors.New("WOW")
}
g.subscribos = subs
return nil
}

func (g *GenSubs) UNSUBME(unsubs []subscription.Subscription) error {
func (g *GenSubs) UNSUBME(unsubs subscription.List) error {
if len(unsubs) == 0 {
return errors.New("WOW")
}
Expand Down
2 changes: 1 addition & 1 deletion exchanges/subscription/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (s *Subscription) State() State {
func (s *Subscription) SetState(state State) error {
s.m.Lock()
defer s.m.Unlock()
if state == s.State() {
if state == s.state {
return ErrInStateAlready
}
if state > UnsubscribingState {
Expand Down
Loading

0 comments on commit bc93447

Please sign in to comment.