Skip to content

Commit

Permalink
Subscriptions: State test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Feb 14, 2024
1 parent f40bc5e commit 916ecb3
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 75 deletions.
43 changes: 35 additions & 8 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -889,35 +889,62 @@ func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error {

// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
// Errors are returned for duplicates or exceeding max Subscriptions
func (w *Websocket) SubscribeToChannels(channels subscription.List) error {
if err := w.checkSubscriptions(channels); err != nil {
func (w *Websocket) SubscribeToChannels(subs subscription.List) error {
if err := w.checkSubscriptions(subs); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
if err := w.Subscriber(channels); err != nil {
if err := w.Subscriber(subs); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
return nil
}

// AddSubscription adds a subscription to the subscription lists
func (w *Websocket) AddSubscription(c *subscription.Subscription) error {
if w == nil || c == nil {
// AddSubscription adds a subscription to the subscription store
func (w *Websocket) AddSubscription(s *subscription.Subscription) error {
if w == nil || s == nil {
return common.ErrNilPointer
}
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
return w.subscriptions.Add(c)
return w.subscriptions.Add(s)
}

// RemoveSubscriptions removes subscriptions from the subscription list
// AddSubscriptions adds subscriptions to the subscription store
func (w *Websocket) AddSubscriptions(subs subscription.List) error {
if w == nil {
return common.ErrNilPointer
}
if w.subscriptions == nil {
w.subscriptions = subscription.NewStore()
}
var errs error
for _, s := range subs {
if err := w.subscriptions.Add(s); err != nil {
errs = common.AppendError(errs, err)
}
}
return errs
}

// RemoveSubscription removes a subscription from the subscription store
func (w *Websocket) RemoveSubscription(s *subscription.Subscription) {
if w == nil || w.subscriptions == nil || s == nil {
return
}
w.subscriptions.Remove(s)
}

// RemoveSubscriptions removes subscriptions from the subscription list
func (w *Websocket) RemoveSubscriptions(subs subscription.List) {
if w == nil || w.subscriptions == nil {
return
}
for _, s := range subs {
w.subscriptions.Remove(s)
}
}

// GetSubscription returns a subscription at the key provided
// returns nil if no subscription is at that key or the key is nil
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
Expand Down
97 changes: 37 additions & 60 deletions exchanges/stream/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ var defaultSetup = &WebsocketSetup{
return subscription.List{
{Channel: "TestSub"},
{Channel: "TestSub2", Key: "purple"},
{Channel: "TestSub3", key: testSubKey{"mauve"}},
{Channel: "TestSub4", key: 42},
{Channel: "TestSub3", Key: testSubKey{"mauve"}},
{Channel: "TestSub4", Key: 42},
}, nil
},
Features: &protocol.Features{Subscribe: true, Unsubscribe: true},
Expand Down Expand Up @@ -157,20 +157,20 @@ func TestSetup(t *testing.T) {
t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriberUnset)
}

websocketSetup.Subscriber = func([]subscription.Subscription) error { return nil }
websocketSetup.Subscriber = func(subscription.List) error { return nil }
websocketSetup.Features.Unsubscribe = true
err = w.Setup(websocketSetup)
if !errors.Is(err, errWebsocketUnsubscriberUnset) {
t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketUnsubscriberUnset)
}

websocketSetup.Unsubscriber = func([]subscription.Subscription) error { return nil }
websocketSetup.Unsubscriber = func(subscription.List) error { return nil }
err = w.Setup(websocketSetup)
if !errors.Is(err, errWebsocketSubscriptionsGeneratorUnset) {
t.Fatalf("received: '%v' but expected: '%v'", err, errWebsocketSubscriptionsGeneratorUnset)
}

websocketSetup.GenerateSubscriptions = func() ([]subscription.Subscription, error) { return nil, nil }
websocketSetup.GenerateSubscriptions = func() (subscription.List, error) { return nil, nil }
err = w.Setup(websocketSetup)
if !errors.Is(err, errDefaultURLIsEmpty) {
t.Fatalf("received: '%v' but expected: '%v'", err, errDefaultURLIsEmpty)
Expand Down Expand Up @@ -505,12 +505,15 @@ func TestSubscribeUnsubscribe(t *testing.T) {
ws := *New()
assert.NoError(t, ws.Setup(defaultSetup), "WS Setup should not error")

fnSub := func(subs []subscription.Subscription) error {
ws.AddSuccessfulSubscriptions(subs...)
fnSub := func(subs subscription.List) error {
for _, s := range subs {
s.SetState(subscription.SubscribedState)
}
ws.AddSubscriptions(subs)
return nil
}
fnUnsub := func(unsubs []subscription.Subscription) error {
ws.RemoveSubscriptions(unsubs...)
fnUnsub := func(unsubs subscription.List) error {
ws.RemoveSubscriptions(unsubs)
return nil
}
ws.Subscriber = fnSub
Expand All @@ -523,10 +526,10 @@ func TestSubscribeUnsubscribe(t *testing.T) {
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error")
assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")
byDefKey := ws.GetSubscription(subscription.Key{Channel: "TestSub"})
if assert.NotNil(t, byDefKey, "GetSubscription by default key should find a channel") {
assert.Equal(t, "TestSub", byDefKey.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.NotSame(t, byDefKey, ws.subscriptions["TestSub"], "GetSubscription returns a fresh pointer")
bySub := ws.GetSubscription(subscription.Subscription{Channel: "TestSub"})
if assert.NotNil(t, bySub, "GetSubscription by by subscription should find a channel") {
assert.Equal(t, "TestSub", bySub.Channel, "GetSubscription by default key should return a pointer a copy of the right channel")
assert.Same(t, bySub, subs[0], "GetSubscription returns a fresh pointer")
}
if assert.NotNil(t, ws.GetSubscription("purple"), "GetSubscription by string key should find a channel") {
assert.Equal(t, "TestSub2", ws.GetSubscription("purple").Channel, "GetSubscription by string key should return a pointer a copy of the right channel")
Expand All @@ -539,7 +542,7 @@ func TestSubscribeUnsubscribe(t *testing.T) {
}
assert.Nil(t, ws.GetSubscription(nil), "GetSubscription by nil should return nil")
assert.Nil(t, ws.GetSubscription(45), "GetSubscription by invalid key should return nil")
assert.ErrorIs(t, ws.SubscribeToChannels(subs), ErrSubscribedAlready, "Subscribe should error when already subscribed")
assert.ErrorIs(t, ws.SubscribeToChannels(subs), subscription.ErrDuplicate, "Subscribe should error when already subscribed")
assert.ErrorIs(t, ws.SubscribeToChannels(nil), errNoSubscriptionsSupplied, "Subscribe to nil should error")
assert.NoError(t, ws.UnsubscribeChannels(subs), "Unsubscribing should not error")
}
Expand All @@ -557,48 +560,29 @@ func TestResubscribe(t *testing.T) {
err = ws.Setup(defaultSetup)
assert.NoError(t, err, "WS Setup should not error")

fnSub := func(subs []subscription.Subscription) error {
ws.AddSuccessfulSubscriptions(subs...)
fnSub := func(subs subscription.List) error {
for _, s := range subs {
s.SetState(subscription.SubscribedState)
}
ws.AddSubscriptions(subs)
return nil
}
fnUnsub := func(unsubs []subscription.Subscription) error {
ws.RemoveSubscriptions(unsubs...)
fnUnsub := func(unsubs subscription.List) error {
ws.RemoveSubscriptions(unsubs)
return nil
}
ws.Subscriber = fnSub
ws.Unsubscriber = fnUnsub

channel := []subscription.Subscription{{Channel: "resubTest"}}
channel := subscription.List{{Channel: "resubTest"}}

assert.ErrorIs(t, ws.ResubscribeToChannel(&channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error")
assert.NoError(t, ws.ResubscribeToChannel(&channel[0]), "Resubscribe should not error now the channel is subscribed")
assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed")
}

// TestSubscriptionState tests Subscription state changes
func TestSubscriptionState(t *testing.T) {
t.Parallel()
ws := New()

c := &subscription.Subscription{Key: 42, Channel: "Gophers", State: subscription.SubscribingState}
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), ErrSubscriptionNotFound, "Setting an imaginary sub should error")

assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
found := ws.GetSubscription(42)
assert.NotNil(t, found, "Should find the subscription")
assert.Equal(t, subscription.SubscribingState, found.State, "Subscription should be Subscribing")
assert.ErrorIs(t, ws.AddSubscription(c), ErrSubscribedAlready, "Adding an already existing sub should error")
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.SubscribingState), ErrChannelInStateAlready, "Setting Same state should error")
assert.ErrorIs(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState+1), errInvalidChannelState, "Setting an invalid state should error")

ws.AddSuccessfulSubscriptions(*c)
found = ws.GetSubscription(42)
assert.NotNil(t, found, "Should find the subscription")
assert.Equal(t, found.State, subscription.SubscribedState, "Subscription should be subscribed state")

assert.NoError(t, ws.SetSubscriptionState(c, subscription.UnsubscribingState), "Setting Unsub state should not error")
found = ws.GetSubscription(42)
assert.Equal(t, found.State, subscription.UnsubscribingState, "Subscription should be unsubscribing state")
func TestAddSubscription(t *testing.T) {
t.Fatal("Not implemented, along with others")
}

// TestRemoveSubscriptions tests removing a subscription
Expand All @@ -610,14 +594,14 @@ func TestRemoveSubscriptions(t *testing.T) {
assert.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error")
assert.NotNil(t, ws.GetSubscription(42), "Added subscription should be findable")

ws.RemoveSubscriptions(*c)
ws.RemoveSubscriptions(subscription.List{c})
assert.Nil(t, ws.GetSubscription(42), "Remove should have removed the sub")
}

// TestConnectionMonitorNoConnection logic test
func TestConnectionMonitorNoConnection(t *testing.T) {
t.Parallel()
ws := *New()
ws := New()
ws.connectionMonitorDelay = 500
ws.DataHandler = make(chan interface{}, 1)
ws.ShutdownC = make(chan struct{}, 1)
Expand All @@ -641,19 +625,12 @@ func TestConnectionMonitorNoConnection(t *testing.T) {
func TestGetSubscription(t *testing.T) {
t.Parallel()
assert.Nil(t, (*Websocket).GetSubscription(nil, "imaginary"), "GetSubscription on a nil Websocket should return nil")
assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub map should return nil")
w := Websocket{
subscriptions: subscription.Map{
42: {
Channel: "hello3",
},
},
}
assert.Nil(t, w.GetSubscription(43), "GetSubscription with an invalid key should return nil")
c := w.GetSubscription(42)
if assert.NotNil(t, c, "GetSubscription with an valid key should return a channel") {
assert.Equal(t, "hello3", c.Channel, "GetSubscription should return the correct channel details")
}
assert.Nil(t, (&Websocket{}).GetSubscription("empty"), "GetSubscription on a Websocket with no sub store should return nil")
w := New()
assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil")
s := &subscription.Subscription{Key: 42, Channel: "hello3"}
w.AddSubscription(s)
assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store")
}

// TestGetSubscriptions logic test
Expand Down
2 changes: 1 addition & 1 deletion exchanges/subscription/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (s *Subscription) State() State {
func (s *Subscription) SetState(state State) error {
s.m.Lock()
defer s.m.Unlock()
if state == s.State() {
if state == s.state {
return ErrInStateAlready
}
if state > UnsubscribingState {
Expand Down
27 changes: 21 additions & 6 deletions exchanges/subscription/subscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/thrasher-corp/gocryptotrader/currency"
"github.com/thrasher-corp/gocryptotrader/exchanges/asset"
"github.com/thrasher-corp/gocryptotrader/exchanges/kline"
Expand All @@ -13,25 +14,25 @@ import (
// TestEnsureKeyed logic test
func TestEnsureKeyed(t *testing.T) {
t.Parallel()
c := &Subscription{
s := &Subscription{
Channel: "candles",
Asset: asset.Spot,
Pairs: []currency.Pair{currency.NewPair(currency.BTC, currency.USDT)},
}
k1, ok := c.EnsureKeyed().(*Subscription)
k1, ok := s.EnsureKeyed().(*Subscription)
if assert.True(t, ok, "EnsureKeyed should return a *Subscription") {
assert.Same(t, k1, c, "Key should point to the same struct")
assert.Same(t, k1, s, "Key should point to the same struct")
}
type platypus string
c = &Subscription{
s = &Subscription{
Key: platypus("Gerald"),
Channel: "orderbook",
Asset: asset.Margin,
Pairs: []currency.Pair{currency.NewPair(currency.ETH, currency.USDC)},
}
k2, ok := c.EnsureKeyed().(platypus)
k2, ok := s.EnsureKeyed().(platypus)
if assert.True(t, ok, "EnsureKeyed should return a platypus") {
assert.Exactly(t, k2, c.Key, "ensureKeyed should set the same key")
assert.Exactly(t, k2, s.Key, "ensureKeyed should set the same key")
assert.EqualValues(t, "Gerald", k2, "key should have the correct value")
}
}
Expand All @@ -55,3 +56,17 @@ func TestMarshaling(t *testing.T) {
assert.NoError(t, err, "Marshalling should not error")
assert.Equal(t, `{"enabled":true,"channel":"myTrades","authenticated":true}`, string(j), "Marshalling should be clean and concise")
}

// TestSetState tests Subscription state changes
func TestSetState(t *testing.T) {
t.Parallel()

s := &Subscription{Key: 42, Channel: "Gophers"}
assert.Equal(t, UnknownState, s.State(), "State should start as unknown")
require.NoError(t, s.SetState(SubscribingState), "SetState should not error")
assert.Equal(t, SubscribingState, s.State(), "State should be set correctly")
assert.ErrorIs(t, s.SetState(SubscribingState), ErrInStateAlready, "SetState should error on same state")
assert.ErrorIs(t, s.SetState(UnsubscribingState+1), ErrInvalidState, "Setting an invalid state should error")
require.NoError(t, s.SetState(UnsubscribingState), "SetState should not error")
assert.Equal(t, UnsubscribingState, s.State(), "State should be set correctly")
}

0 comments on commit 916ecb3

Please sign in to comment.