diff --git a/currency/pairs.go b/currency/pairs.go index 887b46fcbc9..a68edfa9eb8 100644 --- a/currency/pairs.go +++ b/currency/pairs.go @@ -52,6 +52,19 @@ func (p Pairs) Strings() []string { return list } +// String is a convenience method returning a comma-separated string of uppercase currencies using / as delimiter +func (p Pairs) String() string { + f := PairFormat{ + Delimiter: "/", + Uppercase: true, + } + l := make([]string, len(p)) + for i, pair := range p { + l[i] = f.Format(pair) + } + return strings.Join(l, ",") +} + // Join returns a comma separated list of currency pairs func (p Pairs) Join() string { return strings.Join(p.Strings(), ",") diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 04f72e0f1f5..e2f25056bb0 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "os" - "sort" "strconv" "strings" "sync" @@ -527,22 +526,18 @@ func TestResubscribe(t *testing.T) { assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed") } -func TestAddSubscription(t *testing.T) { - t.Fatal("Not implemented, along with others") -} - -// TestRemoveSubscriptions tests removing a subscription -func TestRemoveSubscriptions(t *testing.T) { +// TestSubscriptions tests adding, getting and removing subscriptions +func TestSubscriptions(t *testing.T) { t.Parallel() - ws := NewWebsocket() + w := NewWebsocket() - c := &subscription.Subscription{Key: 42, Channel: "Unite!"} - require.NoError(t, ws.AddSubscription(c), "Adding first subscription should not error") - assert.NotNil(t, ws.GetSubscription(42), "Added subscription should be findable") + c := &subscription.Subscription{Key: 42, Channel: subscription.TickerChannel} + require.NoError(t, w.AddSubscription(c), "Adding first subscription should not error") + assert.Same(t, c, w.GetSubscription(42), "Get Subscription should retrieve the same subscription") - err := ws.RemoveSubscriptions(subscription.List{c}) + err := w.RemoveSubscriptions(subscription.List{c}) require.NoError(t, err, "RemoveSubscriptions must not error") - assert.Nil(t, ws.GetSubscription(42), "Remove should have removed the sub") + assert.Nil(t, w.GetSubscription(42), "Remove should have removed the sub") } // TestConnectionMonitorNoConnection logic test @@ -567,7 +562,7 @@ func TestGetSubscription(t *testing.T) { w := NewWebsocket() assert.Nil(t, w.GetSubscription(nil), "GetSubscription with a nil key should return nil") s := &subscription.Subscription{Key: 42, Channel: "hello3"} - w.AddSubscription(s) + require.NoError(t, w.AddSubscription(s), "AddSubscription must not error") assert.Same(t, s, w.GetSubscription(42), "GetSubscription should delegate to the store") } @@ -577,10 +572,11 @@ func TestGetSubscriptions(t *testing.T) { assert.Nil(t, (*Websocket).GetSubscriptions(nil), "GetSubscription on a nil Websocket should return nil") assert.Nil(t, (&Websocket{}).GetSubscriptions(), "GetSubscription on a Websocket with no sub store should return nil") w := NewWebsocket() - w.AddSubscriptions(subscription.List{ + err := w.AddSubscriptions(subscription.List{ {Key: 42, Channel: "hello3"}, {Key: 45, Channel: "hello4"}, }) + require.NoError(t, err, "AddSubscriptions must not error") assert.Equal(t, "hello3", w.GetSubscriptions()[0].Channel, "GetSubscriptions should return the correct channel details") } @@ -913,48 +909,20 @@ func TestCheckWebsocketURL(t *testing.T) { assert.NoError(t, err, "checkWebsocketURL should not error") } +// TestGetChannelDifference exercises GetChannelDifference +// See subscription.TestStoreDiff for further testing func TestGetChannelDifference(t *testing.T) { t.Parallel() - web := Websocket{} - - newChans := subscription.List{ - {Channel: "Test1"}, - {Channel: "Test2"}, - {Channel: "Test3"}, - } - subs, unsubs := web.GetChannelDifference(newChans) - assert.Implements(t, (*subscription.MatchableKey)(nil), subs[0].Key, "Sub key must be matchable") - assert.Equal(t, 3, len(subs), "Should get the correct number of subs") - assert.Empty(t, unsubs, "Should get no unsubs") - - for _, s := range subs { - s.SetState(subscription.SubscribedState) - } - - web.AddSubscriptions(subs) - flushedSubs := subscription.List{ - {Channel: "Test2"}, - } - - subs, unsubs = web.GetChannelDifference(flushedSubs) - assert.Empty(t, subs, "Should get no subs") - assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") - - flushedSubs = subscription.List{ - {Channel: "Test2"}, - {Channel: "Test4"}, - } - - subs, unsubs = web.GetChannelDifference(flushedSubs) - if assert.Equal(t, 1, len(subs), "Should get the correct number of subs") { - assert.Equal(t, "Test4", subs[0].Channel, "Should subscribe to the right channel") - } - if assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") { - sort.Slice(unsubs, func(i, j int) bool { return unsubs[i].Channel <= unsubs[j].Channel }) - assert.Equal(t, "Test1", unsubs[0].Channel, "Should unsubscribe from the right channels") - assert.Equal(t, "Test3", unsubs[1].Channel, "Should unsubscribe from the right channels") - } + w := &Websocket{} + assert.NotPanics(t, func() { w.GetChannelDifference(subscription.List{}) }, "Should not panic when called without a store") + subs, unsubs := w.GetChannelDifference(subscription.List{{Channel: subscription.CandlesChannel}}) + require.Equal(t, 1, len(subs), "Should get the correct number of subs") + require.Empty(t, unsubs, "Should get no unsubs") + require.NoError(t, w.AddSubscriptions(subs), "AddSubscriptions must not error") + subs, unsubs = w.GetChannelDifference(subscription.List{{Channel: subscription.TickerChannel}}) + require.Equal(t, 1, len(subs), "Should get the correct number of subs") + assert.Equal(t, 1, len(unsubs), "Should get the correct number of unsubs") } // GenSubs defines a theoretical exchange with pair management @@ -1050,10 +1018,7 @@ func TestFlushChannels(t *testing.T) { w.GenerateSubs = newgen.generateSubs subs, err := w.GenerateSubs() require.NoError(t, err, "GenerateSubs must not error") - for _, s := range subs { - s.SetState(subscription.SubscribedState) - } - w.AddSubscriptions(subs) + require.NoError(t, w.AddSubscriptions(subs), "AddSubscriptions must not error") err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") w.features.FullPayloadSubscribe = false @@ -1061,16 +1026,18 @@ func TestFlushChannels(t *testing.T) { w.GenerateSubs = newgen.generateSubs w.subscriptions = subscription.NewStore() - w.subscriptions.Add(&subscription.Subscription{ + err = w.subscriptions.Add(&subscription.Subscription{ Key: 41, Channel: "match channel", Pairs: currency.Pairs{currency.NewPair(currency.BTC, currency.AUD)}, }) - w.subscriptions.Add(&subscription.Subscription{ + require.NoError(t, err, "AddSubscription must not error") + err = w.subscriptions.Add(&subscription.Subscription{ Key: 42, Channel: "unsub channel", Pairs: currency.Pairs{currency.NewPair(currency.THETA, currency.USDT)}, }) + require.NoError(t, err, "AddSubscription must not error") err = w.FlushChannels() assert.NoError(t, err, "FlushChannels should not error") diff --git a/exchanges/subscription/list_test.go b/exchanges/subscription/list_test.go new file mode 100644 index 00000000000..e2293e7ea9d --- /dev/null +++ b/exchanges/subscription/list_test.go @@ -0,0 +1,25 @@ +package subscription + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" +) + +func TestListStrings(t *testing.T) { + l := List{ + &Subscription{ + Channel: TickerChannel, + Asset: asset.Spot, + Pairs: currency.Pairs{ethusdcPair, btcusdtPair}, + }, + &Subscription{ + Channel: OrderbookChannel, + Pairs: currency.Pairs{ethusdcPair}, + }, + } + exp := []string{"orderbook ETH/USDC", "ticker spot ETH/USDC,BTC/USDT"} + assert.ElementsMatch(t, exp, l.Strings(), "String must return correct sorted list") +} diff --git a/exchanges/subscription/store.go b/exchanges/subscription/store.go index 18db7fe338b..e769011acba 100644 --- a/exchanges/subscription/store.go +++ b/exchanges/subscription/store.go @@ -46,6 +46,9 @@ func (s *Store) Add(sub *Subscription) error { // Add adds a subscription to the store // This method provides no locking protection func (s *Store) add(sub *Subscription) error { + if s.m == nil { + s.m = map[any]*Subscription{} + } key := sub.EnsureKeyed() if found := s.get(key); found != nil { return ErrDuplicate @@ -55,9 +58,10 @@ func (s *Store) add(sub *Subscription) error { } // Get returns a pointer to a subscription or nil if not found -// If key implements MatchableKey then key.Match will be used +// If the key passed in is a Subscription then its Key will be used; which may be a pointer to itself. +// If key implements MatchableKey then key.Match will be used; Note that *Subscription implements MatchableKey func (s *Store) Get(key any) *Subscription { - if s == nil { + if s == nil || s.m == nil { return nil } s.mu.RLock() @@ -69,8 +73,10 @@ func (s *Store) Get(key any) *Subscription { // If the key passed in is a Subscription then its Key will be used; which may be a pointer to itself. // If key implements MatchableKey then key.Match will be used; Note that *Subscription implements MatchableKey // This method provides no locking protection -// returned subscriptions are implicitly guaranteed to have a Key func (s *Store) get(key any) *Subscription { + if s.m == nil { + return nil + } switch v := key.(type) { case Subscription: key = v.EnsureKeyed() @@ -87,14 +93,16 @@ func (s *Store) get(key any) *Subscription { } // Remove removes a subscription from the store -func (s *Store) Remove(sub *Subscription) error { - if s == nil || sub == nil { +// If the key passed in is a Subscription then its Key will be used; which may be a pointer to itself. +// If key implements MatchableKey then key.Match will be used; Note that *Subscription implements MatchableKey +func (s *Store) Remove(key any) error { + if s == nil || key == nil { return common.ErrNilPointer } s.mu.Lock() defer s.mu.Unlock() - if found := s.get(sub); found != nil { + if found := s.get(key); found != nil { delete(s.m, found.Key) return nil } @@ -104,7 +112,7 @@ func (s *Store) Remove(sub *Subscription) error { // List returns a slice of Subscriptions pointers func (s *Store) List() List { - if s == nil { + if s == nil || s.m == nil { return List{} } s.mu.RLock() @@ -123,6 +131,9 @@ func (s *Store) Clear() { } s.mu.Lock() defer s.mu.Unlock() + if s.m == nil { + s.m = map[any]*Subscription{} + } clear(s.m) } @@ -167,7 +178,7 @@ func (s *Store) Diff(compare List) (added, removed List) { // Len returns the number of subscriptions func (s *Store) Len() int { - if s == nil { + if s == nil || s.m == nil { return 0 } s.mu.RLock() diff --git a/exchanges/subscription/store_test.go b/exchanges/subscription/store_test.go new file mode 100644 index 00000000000..452315271c5 --- /dev/null +++ b/exchanges/subscription/store_test.go @@ -0,0 +1,184 @@ +package subscription + +import ( + "maps" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" +) + +// TestNewStore exercises NewStore +func TestNewStore(t *testing.T) { + s := NewStore() + require.IsType(t, &Store{}, s, "Must return a store ref") + require.NotNil(t, s.m, "storage map must be initialised") +} + +// TestNewStoreFromList exercises NewStoreFromList +func TestNewStoreFromList(t *testing.T) { + s, err := NewStoreFromList(List{}) + assert.NoError(t, err, "Should not error on empty list") + require.IsType(t, &Store{}, s, "Must return a store ref") + l := List{ + {Channel: OrderbookChannel}, + {Channel: TickerChannel}, + } + s, err = NewStoreFromList(l) + assert.NoError(t, err, "Should not error on empty list") + assert.Len(t, s.m, 2, "Map should have 2 values") + assert.NotNil(t, s.get(l[0]), "Should be able to get a list element") + + l = append(l, &Subscription{Channel: OrderbookChannel}) + _, err = NewStoreFromList(l) + assert.ErrorIs(t, err, ErrDuplicate, "Should error correctly on duplicates") +} + +// TestAdd exercises Add and add methods +func TestAdd(t *testing.T) { + assert.ErrorIs(t, (*Store)(nil).Add(&Subscription{}), common.ErrNilPointer, "Should error nil pointer correctly") + assert.ErrorIs(t, (&Store{}).Add(nil), common.ErrNilPointer, "Should error nil pointer correctly") + assert.NoError(t, (&Store{}).Add(&Subscription{}), "Should with no map should not error or panic") + s := NewStore() + sub := &Subscription{Channel: TickerChannel} + require.NoError(t, s.Add(sub), "Should not error on a standard add") + assert.NotNil(t, s.get(sub), "Should have stored the sub") + assert.ErrorIs(t, s.Add(sub), ErrDuplicate, "Should error on duplicates") + assert.Same(t, sub.Key, sub, "Add should call EnsureKeyed") +} + +// HobbitKey is just a fixture for testing MatchableKey +type HobbitKey int + +// Match implements MatchableKey +// Returns true if the key provided is twice as big as the actual sub key +func (f HobbitKey) Match(key any) bool { + i, ok := key.(HobbitKey) + return ok && int(i)*2 == int(f) +} + +// TestGet exercises Get and get methods +func TestGet(t *testing.T) { + assert.Nil(t, (*Store)(nil).Get(&Subscription{}), "Should return nil when called on nil") + assert.Nil(t, (&Store{}).Get(&Subscription{}), "Should return nil when called with no subscription map") + s := NewStore() + exp := List{ + {Channel: OrderbookChannel}, + {Channel: TickerChannel}, + {Key: 42, Channel: CandlesChannel}, + {Key: HobbitKey(24), Channel: CandlesChannel}, + } + for _, sub := range exp { + require.NoError(t, s.Add(sub), "Adding subscription must not error)") + } + + assert.Nil(t, s.Get(Subscription{Channel: OrderbookChannel, Asset: asset.Spot}), "Should return nil for an unknown sub") + assert.Same(t, exp[0], s.Get(exp[0]), "Should return same pointer for known sub") + assert.Same(t, exp[1], s.Get(Subscription{Channel: TickerChannel}), "Should return pointer for known sub passed-by-value") + assert.Same(t, exp[2], s.Get(42), "Should return pointer for simple key lookup") + assert.Same(t, exp[3], s.Get(HobbitKey(48)), "Should use MatchableKey interface to find subs") + assert.Nil(t, s.Get(HobbitKey(24)), "Should use MatchableKey interface to find subs, therefore not find a HobbitKey 24") +} + +// TestRemove exercises the Remove method +func TestRemove(t *testing.T) { + assert.ErrorIs(t, (*Store)(nil).Remove(&Subscription{}), common.ErrNilPointer, "Should error correctly when called on nil") + assert.ErrorIs(t, (&Store{}).Remove(nil), common.ErrNilPointer, "Should error correctly when called passing nil") + assert.ErrorIs(t, (&Store{}).Remove(&Subscription{}), ErrNotFound, "Should error correctly when called with no subscription map") + s := NewStore() + require.NoError(t, s.Add(&Subscription{Key: HobbitKey(24), Channel: CandlesChannel}), "Adding subscription must not error") + assert.ErrorIs(t, s.Remove(HobbitKey(24)), ErrNotFound, "Should error correctly when called with a non-matching hobbitkey") + assert.NoError(t, s.Remove(HobbitKey(48)), "Should not error correctly when called matching hobbitkey") + assert.Nil(t, s.Get(HobbitKey(48)), "Should have removed the sub") + assert.ErrorIs(t, s.Remove(HobbitKey(48)), ErrNotFound, "Should error correctly when called twice on same key") +} + +// TestList exercises the List and Len methods +func TestList(t *testing.T) { + assert.Empty(t, (*Store)(nil).List(), "Should return an empty List when called on nil") + assert.Empty(t, (&Store{}).List(), "Should return an empty List when called on Store without map") + s := NewStore() + exp := List{ + {Channel: OrderbookChannel}, + {Channel: TickerChannel}, + {Key: 42, Channel: CandlesChannel}, + } + for _, sub := range exp { + require.NoError(t, s.Add(sub), "Adding subscription must not error)") + } + l := s.List() + require.Len(t, l, 3, "Must have 3 elements in the list") + assert.ElementsMatch(t, exp, l, "List Should have the same subscriptions") + + require.Equal(t, 3, s.Len(), "Len must return 3") + require.Equal(t, 0, (*Store)(nil).Len(), "Len must return 0 on a nil store") + require.Equal(t, 0, (&Store{}).Len(), "Len must return 0 on an uninitialized store") +} + +// TestStoreClear exercises the Clear method +func TestStoreClear(t *testing.T) { + assert.NotPanics(t, func() { (*Store)(nil).Clear() }, "Should not panic when called on nil") + s := &Store{} + assert.NotPanics(t, func() { s.Clear() }, "Should not panic when called with no subscription map") + assert.NotNil(t, s.m, "Should create a map when called on an empty Store") + require.NoError(t, s.Add(&Subscription{Key: HobbitKey(24), Channel: CandlesChannel}), "Adding subscription must not error") + require.Len(t, s.m, 1, "Must have a subscription") + s.Clear() + require.Empty(t, s.m, "Map must be empty after clearing") + assert.NotPanics(t, func() { s.Clear() }, "Should not panic when called on an empty map") +} + +// TestStoreDiff exercises the Diff method +func TestStoreDiff(t *testing.T) { + s := NewStore() + assert.NotPanics(t, func() { (*Store)(nil).Diff(List{}) }, "Should not panic when called on nil") + assert.NotPanics(t, func() { (&Store{}).Diff(List{}) }, "Should not panic when called with no subscription map") + subs, unsubs := s.Diff(List{{Channel: TickerChannel}, {Channel: CandlesChannel}, {Channel: OrderbookChannel}}) + assert.Equal(t, 3, len(subs), "Should get the correct number of subs") + assert.Empty(t, unsubs, "Should get no unsubs") + for _, sub := range subs { + require.NoError(t, s.add(sub), "add must not error") + } + assert.NotPanics(t, func() { s.Diff(nil) }, "Should not panic when called with nil list") + + subs, unsubs = s.Diff(List{{Channel: CandlesChannel}}) + assert.Empty(t, subs, "Should get no subs") + assert.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") + subs, unsubs = s.Diff(List{{Channel: TickerChannel}, {Channel: MyTradesChannel}}) + require.Equal(t, 1, len(subs), "Should get the correct number of subs") + assert.Equal(t, MyTradesChannel, subs[0].Channel, "Should get correct channels in sub") + require.Equal(t, 2, len(unsubs), "Should get the correct number of unsubs") + EqualLists(t, unsubs, List{{Channel: OrderbookChannel}, {Channel: CandlesChannel}}) +} + +func EqualLists(tb testing.TB, a, b List) { + tb.Helper() + // Must not use store.Diff directly + s, err := NewStoreFromList(a) + require.NoError(tb, err, "NewStoreFromList must not error") + missingMap := maps.Clone(s.m) + var added, missing List + for _, sub := range b { + if found := s.get(sub); found != nil { + delete(missingMap, found.Key) + } else { + added = append(added, sub) + } + } + for _, c := range missingMap { + missing = append(missing, c) + } + if len(added) > 0 || len(missing) > 0 { + fail := "Differences:" + if len(added) > 0 { + fail = fail + "\n + " + strings.Join(added.Strings(), "\n + ") + } + if len(missing) > 0 { + fail = fail + "\n - " + strings.Join(missing.Strings(), "\n - ") + } + assert.Fail(tb, fail, "Subscriptions should be equal") + } +} diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index 3e340ebbc1c..32e799ad44c 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -88,8 +88,8 @@ func (s *Subscription) SetState(state State) error { return nil } -// EnsureKeyed sets the default key on a channel if it doesn't have one -// Returns key for convenience +// EnsureKeyed returns the subscription key +// If no key exists then a pointer to the subscription itself will be used, since Subscriptions implement MatchableKey func (s *Subscription) EnsureKeyed() any { if s.Key == nil { s.Key = s @@ -103,17 +103,25 @@ func (s *Subscription) EnsureKeyed() any { // 2) >=1 pairs then Subscriptions which contain all the pairs match // Such that a subscription for all enabled pairs will be matched when seaching for any one pair func (s *Subscription) Match(key any) bool { - b, ok := key.(*Subscription) + var b *Subscription + switch v := key.(type) { + case *Subscription: + b = v + case Subscription: + b = &v + default: + return false + } + switch { - case !ok, - s.Channel != b.Channel, - s.Asset != b.Asset, - len(b.Pairs) == 0 && len(s.Pairs) != 0, + case b.Channel != s.Channel, + b.Asset != s.Asset, // len(b.Pairs) == 0 && len(s.Pairs) == 0: Okay; continue to next non-pairs check + len(b.Pairs) == 0 && len(s.Pairs) != 0, len(b.Pairs) != 0 && len(s.Pairs) == 0, len(b.Pairs) != 0 && s.Pairs.ContainsAll(b.Pairs, true) != nil, - s.Levels != b.Levels, - s.Interval != b.Interval: + b.Levels != s.Levels, + b.Interval != s.Interval: return false } diff --git a/exchanges/subscription/subscription_test.go b/exchanges/subscription/subscription_test.go index 38cabb8694a..b9a71b4ae6a 100644 --- a/exchanges/subscription/subscription_test.go +++ b/exchanges/subscription/subscription_test.go @@ -11,36 +11,67 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/kline" ) -// TestEnsureKeyed logic test -func TestEnsureKeyed(t *testing.T) { - t.Parallel() +var ( + btcusdtPair = currency.NewPair(currency.BTC, currency.USDT) + ethusdcPair = currency.NewPair(currency.ETH, currency.USDC) + ltcusdcPair = currency.NewPair(currency.LTC, currency.USDC) +) + +// TestSubscriptionString exercises the String method +func TestSubscriptionString(t *testing.T) { s := &Subscription{ Channel: "candles", Asset: asset.Spot, - Pairs: []currency.Pair{currency.NewPair(currency.BTC, currency.USDT)}, + Pairs: currency.Pairs{btcusdtPair, ethusdcPair.Format(currency.PairFormat{Delimiter: "/"})}, + } + assert.Equal(t, "candles spot BTC/USDT,ETH/USDC", s.String(), "Subscription String should return correct value") +} + +// TestState exercises the state getter +func TestState(t *testing.T) { + t.Parallel() + s := &Subscription{} + assert.Equal(t, InactiveState, s.State(), "State should return initial state") + s.state = SubscribedState + assert.Equal(t, SubscribedState, s.State(), "State should return correct state") +} + +// TestSetState exercises the state setter +func TestSetState(t *testing.T) { + t.Parallel() + + s := &Subscription{state: UnsubscribingState} + + for i := InactiveState; i <= UnsubscribingState; i++ { + assert.NoErrorf(t, s.SetState(i), "State should not error setting state %s", i) } + assert.ErrorIs(t, s.SetState(UnsubscribingState), ErrInStateAlready, "SetState should error on same state") + assert.ErrorIs(t, s.SetState(UnsubscribingState+1), ErrInvalidState, "Setting an invalid state should error") +} + +// TestEnsureKeyed exercises the key getter and ensures it sets a self-pointer key for non +func TestEnsureKeyed(t *testing.T) { + t.Parallel() + s := &Subscription{} k1, ok := s.EnsureKeyed().(*Subscription) if assert.True(t, ok, "EnsureKeyed should return a *Subscription") { - assert.Same(t, k1, s, "Key should point to the same struct") + assert.Same(t, s, k1, "Key should point to the same struct") } type platypus string s = &Subscription{ Key: platypus("Gerald"), Channel: "orderbook", - Asset: asset.Margin, - Pairs: []currency.Pair{currency.NewPair(currency.ETH, currency.USDC)}, - } - k2, ok := s.EnsureKeyed().(platypus) - if assert.True(t, ok, "EnsureKeyed should return a platypus") { - assert.Exactly(t, k2, s.Key, "ensureKeyed should set the same key") - assert.EqualValues(t, "Gerald", k2, "key should have the correct value") } + k2 := s.EnsureKeyed() + assert.IsType(t, platypus(""), k2, "EnsureKeyed should return a platypus") + assert.Equal(t, s.Key, k2, "Key should be the key provided") } -// TestMarshalling logic test -func TestMarshaling(t *testing.T) { +// TestSubscriptionMarshalling ensures json Marshalling is clean and concise +// Since there is no UnmarshalJSON, this just exercises the json field tags of Subscription, and regressions in conciseness +func TestSubscriptionMarshaling(t *testing.T) { t.Parallel() - j, err := json.Marshal(&Subscription{Channel: CandlesChannel}) + j, err := json.Marshal(&Subscription{Key: 42, Channel: CandlesChannel}) assert.NoError(t, err, "Marshalling should not error") assert.Equal(t, `{"enabled":false,"channel":"candles"}`, string(j), "Marshalling should be clean and concise") @@ -57,16 +88,44 @@ func TestMarshaling(t *testing.T) { assert.Equal(t, `{"enabled":true,"channel":"myTrades","authenticated":true}`, string(j), "Marshalling should be clean and concise") } -// TestSetState tests Subscription state changes -func TestSetState(t *testing.T) { +// TestSubscriptionMatch exercises the Subscription MatchableKey interface implementation +func TestSubscriptionMatch(t *testing.T) { t.Parallel() + require.Implements(t, (*MatchableKey)(nil), new(Subscription), "Must implement MatchableKey") + s := &Subscription{Channel: TickerChannel} + assert.NotNil(t, s.EnsureKeyed(), "EnsureKeyed should work") + assert.False(t, s.Match(42), "Match should reject an invalid key type") + try := &Subscription{Channel: OrderbookChannel} + require.False(t, s.Match(try), "Gate 1: Match must reject a bad Channel") + try = &Subscription{Channel: TickerChannel} + require.True(t, s.Match(Subscription{Channel: TickerChannel}), "Match must accept a pass-by-value subscription") + require.True(t, s.Match(try), "Gate 1: Match must accept a good Channel") + s.Asset = asset.Spot + require.False(t, s.Match(try), "Gate 2: Match must reject a bad Asset") + try.Asset = asset.Spot + require.True(t, s.Match(try), "Gate 2: Match must accept a good Asset") - s := &Subscription{Key: 42, Channel: "Gophers"} - assert.Equal(t, InactiveState, s.State(), "State should start as unknown") - require.NoError(t, s.SetState(SubscribingState), "SetState should not error") - assert.Equal(t, SubscribingState, s.State(), "State should be set correctly") - assert.ErrorIs(t, s.SetState(SubscribingState), ErrInStateAlready, "SetState should error on same state") - assert.ErrorIs(t, s.SetState(UnsubscribingState+1), ErrInvalidState, "Setting an invalid state should error") - require.NoError(t, s.SetState(UnsubscribingState), "SetState should not error") - assert.Equal(t, UnsubscribingState, s.State(), "State should be set correctly") + s.Pairs = currency.Pairs{btcusdtPair} + require.False(t, s.Match(try), "Gate 3: Match must reject a pair list when searching for no pairs") + try.Pairs = s.Pairs + s.Pairs = nil + require.False(t, s.Match(try), "Gate 4: Match must reject empty Pairs when searching for a list") + s.Pairs = try.Pairs + require.True(t, s.Match(try), "Gate 5: Match must accept matching pairs") + s.Pairs = currency.Pairs{ethusdcPair} + require.False(t, s.Match(try), "Gate 5: Match must reject mismatched pairs") + s.Pairs = currency.Pairs{btcusdtPair, ethusdcPair} + require.True(t, s.Match(try), "Gate 5: Match must accept one of the key pairs matching in sub pairs") + try.Pairs = currency.Pairs{btcusdtPair, ltcusdcPair} + require.False(t, s.Match(try), "Gate 5: Match must reject when sub pair list doesn't contain all key pairs") + s.Pairs = currency.Pairs{btcusdtPair, ethusdcPair, ltcusdcPair} + require.True(t, s.Match(try), "Gate 5: Match must accept all of the key pairs are contained in sub pairs") + s.Levels = 4 + require.False(t, s.Match(try), "Gate 6: Match must reject a bad Level") + try.Levels = 4 + require.True(t, s.Match(try), "Gate 6: Match must accept a good Level") + s.Interval = kline.FiveMin + require.False(t, s.Match(try), "Gate 7: Match must reject a bad Interval") + try.Interval = kline.FiveMin + require.True(t, s.Match(try), "Gate 7: Match must accept a good Inteval") } diff --git a/internal/testing/subscriptions/subscriptions.go b/internal/testing/subscriptions/subscriptions.go index 1604be279b0..9e1192866c4 100644 --- a/internal/testing/subscriptions/subscriptions.go +++ b/internal/testing/subscriptions/subscriptions.go @@ -9,10 +9,12 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" ) +// Equal is a utility function to compare subscription lists and show a pretty failure message +// It overcomes the verbose depth of assert.ElementsMatch spewConfig func Equal(tb testing.TB, a, b subscription.List) { tb.Helper() s, err := subscription.NewStoreFromList(a) - require.NoError(t, err, "NewStoreFromList must not error") + require.NoError(tb, err, "NewStoreFromList must not error") added, missing := s.Diff(b) if len(added) > 0 || len(missing) > 0 { fail := "Differences:"