From 09a26f3380a49a66a839789e31d3c31fa2dc72f6 Mon Sep 17 00:00:00 2001 From: Gareth Kirwan Date: Sat, 17 Feb 2024 18:07:47 +0700 Subject: [PATCH] Subscriptions: Nil pointer protections --- exchanges/stream/websocket.go | 3 +-- exchanges/stream/websocket_test.go | 4 ++-- exchanges/subscription/store.go | 32 +++++++++++++++++++++++--- exchanges/subscription/subscription.go | 1 + 4 files changed, 33 insertions(+), 7 deletions(-) diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 5fa1edb1356..dc65232416b 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -26,7 +26,6 @@ const ( // Public errors var ( - ErrSubscriptionNotFound = errors.New("subscription not found") ErrSubscriptionFailure = errors.New("subscription failure") ErrSubscriptionNotSupported = errors.New("subscription channel not supported ") ErrUnsubscribeFailure = errors.New("unsubscribe failure") @@ -876,7 +875,7 @@ func (w *Websocket) UnsubscribeChannels(channels subscription.List) error { } for _, s := range channels { if w.subscriptions.Get(s) == nil { - return fmt.Errorf("%s websocket: %w: %s", w.exchangeName, ErrSubscriptionNotFound, s) + return fmt.Errorf("%s websocket: %w: %s", w.exchangeName, subscription.ErrNotFound, s) } } return w.Unsubscriber(channels) diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 7b215f51f58..3e0d032bbfa 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -523,7 +523,7 @@ func TestSubscribeUnsubscribe(t *testing.T) { subs, err := ws.GenerateSubs() assert.NoError(t, err, "Generating test subscriptions should not error") assert.ErrorIs(t, ws.UnsubscribeChannels(nil), errNoSubscriptionsSupplied, "Unsubscribing from nil should error") - assert.ErrorIs(t, ws.UnsubscribeChannels(subs), ErrSubscriptionNotFound, "Unsubscribing should error when not subscribed") + assert.ErrorIs(t, ws.UnsubscribeChannels(subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed") assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return") assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error") assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions") @@ -577,7 +577,7 @@ func TestResubscribe(t *testing.T) { channel := subscription.List{{Channel: "resubTest"}} - assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet") + assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet") assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error") assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed") } diff --git a/exchanges/subscription/store.go b/exchanges/subscription/store.go index 8e833c6fd65..93bd8b553a7 100644 --- a/exchanges/subscription/store.go +++ b/exchanges/subscription/store.go @@ -3,6 +3,8 @@ package subscription import ( "maps" "sync" + + "github.com/thrasher-corp/gocryptotrader/common" ) // Store is a container of subscription pointers @@ -30,6 +32,9 @@ func NewStoreFromList(s List) *Store { // Key can be already set; if ommitted EnsureKeyed will be used // Errors if it already exists func (s *Store) Add(sub *Subscription) error { + if s == nil { + return common.ErrNilPointer + } s.mu.Lock() defer s.mu.Unlock() key := sub.EnsureKeyed() @@ -45,6 +50,9 @@ func (s *Store) Add(sub *Subscription) error { // Get returns a pointer to a subscription or nil if not found // If key implements MatchableKey then key.Match will be used func (s *Store) Get(key any) *Subscription { + if s == nil { + return nil + } s.mu.RLock() defer s.mu.RUnlock() return s.get(key) @@ -72,20 +80,29 @@ func (s *Store) get(key any) *Subscription { } // Remove removes a subscription from the store -func (s *Store) Remove(sub *Subscription) { +func (s *Store) Remove(sub *Subscription) error { + if s == nil { + return common.ErrNilPointer + } s.mu.Lock() defer s.mu.Unlock() if found := s.get(sub); found != nil { delete(s.m, found.Key) + return nil } + + return ErrNotFound } // List returns a slice of Subscriptions pointers -func (s *Store) List() []*Subscription { +func (s *Store) List() List { + if s == nil { + return List{} + } s.mu.RLock() defer s.mu.RUnlock() - subs := make([]*Subscription, 0, len(s.m)) + subs := make(List, 0, len(s.m)) for _, s := range s.m { subs = append(subs, s) } @@ -94,6 +111,9 @@ func (s *Store) List() []*Subscription { // Clear empties the subscription store func (s *Store) Clear() { + if s == nil { + return + } s.mu.Lock() defer s.mu.Unlock() clear(s.m) @@ -117,6 +137,9 @@ func (s *Store) match(key MatchableKey) *Subscription { // The store Diff is invoked upon is read-lock protected // The new store is assumed to be a new instance and enjoys no locking protection func (s *Store) Diff(compare List) (added, removed List) { + if s == nil { + return + } s.mu.RLock() defer s.mu.RUnlock() removedMap := maps.Clone(s.m) @@ -137,6 +160,9 @@ func (s *Store) Diff(compare List) (added, removed List) { // Len returns the number of subscriptions func (s *Store) Len() int { + if s == nil { + return 0 + } s.mu.RLock() defer s.mu.RUnlock() return len(s.m) diff --git a/exchanges/subscription/subscription.go b/exchanges/subscription/subscription.go index cdcbdc1c5fe..80ec8fc4d10 100644 --- a/exchanges/subscription/subscription.go +++ b/exchanges/subscription/subscription.go @@ -27,6 +27,7 @@ const ( // Public errors var ( + ErrNotFound = errors.New("subscription not found") ErrNotSinglePair = errors.New("only single pair subscriptions expected") ErrInStateAlready = errors.New("subscription already in state") ErrInvalidState = errors.New("invalid subscription state")