Skip to content

Commit

Permalink
Subscriptions: Nil pointer protections
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Feb 19, 2024
1 parent abe5241 commit a9c1cae
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 7 deletions.
3 changes: 1 addition & 2 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ const (
// Public websocket errors
var (
ErrWebsocketNotEnabled = errors.New("websocket not enabled")
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
Expand Down Expand Up @@ -788,7 +787,7 @@ func (w *Websocket) UnsubscribeChannels(channels subscription.List) error {
}
for _, s := range channels {
if w.subscriptions.Get(s) == nil {
return fmt.Errorf("%s websocket: %w: %s", w.exchangeName, ErrSubscriptionNotFound, s)
return fmt.Errorf("%s websocket: %w: %s", w.exchangeName, subscription.ErrNotFound, s)
}
}
return w.Unsubscriber(channels)
Expand Down
4 changes: 2 additions & 2 deletions exchanges/stream/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ func TestSubscribeUnsubscribe(t *testing.T) {
subs, err := ws.GenerateSubs()
assert.NoError(t, err, "Generating test subscriptions should not error")
assert.ErrorIs(t, ws.UnsubscribeChannels(nil), errNoSubscriptionsSupplied, "Unsubscribing from nil should error")
assert.ErrorIs(t, ws.UnsubscribeChannels(subs), ErrSubscriptionNotFound, "Unsubscribing should error when not subscribed")
assert.ErrorIs(t, ws.UnsubscribeChannels(subs), subscription.ErrNotFound, "Unsubscribing should error when not subscribed")
assert.Nil(t, ws.GetSubscription(42), "GetSubscription on empty internal map should return")
assert.NoError(t, ws.SubscribeToChannels(subs), "Basic Subscribing should not error")
assert.Len(t, ws.GetSubscriptions(), 4, "Should have 4 subscriptions")
Expand Down Expand Up @@ -508,7 +508,7 @@ func TestResubscribe(t *testing.T) {

channel := subscription.List{{Channel: "resubTest"}}

assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), ErrSubscriptionNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.ErrorIs(t, ws.ResubscribeToChannel(channel[0]), subscription.ErrNotFound, "Resubscribe should error when channel isn't subscribed yet")
assert.NoError(t, ws.SubscribeToChannels(channel), "Subscribe should not error")
assert.NoError(t, ws.ResubscribeToChannel(channel[0]), "Resubscribe should not error now the channel is subscribed")
}
Expand Down
32 changes: 29 additions & 3 deletions exchanges/subscription/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package subscription
import (
"maps"
"sync"

"github.com/thrasher-corp/gocryptotrader/common"
)

// Store is a container of subscription pointers
Expand Down Expand Up @@ -30,6 +32,9 @@ func NewStoreFromList(s List) *Store {
// Key can be already set; if ommitted EnsureKeyed will be used
// Errors if it already exists
func (s *Store) Add(sub *Subscription) error {
if s == nil {
return common.ErrNilPointer
}
s.mu.Lock()
defer s.mu.Unlock()
key := sub.EnsureKeyed()
Expand All @@ -45,6 +50,9 @@ func (s *Store) Add(sub *Subscription) error {
// Get returns a pointer to a subscription or nil if not found
// If key implements MatchableKey then key.Match will be used
func (s *Store) Get(key any) *Subscription {
if s == nil {
return nil
}
s.mu.RLock()
defer s.mu.RUnlock()
return s.get(key)
Expand Down Expand Up @@ -72,20 +80,29 @@ func (s *Store) get(key any) *Subscription {
}

// Remove removes a subscription from the store
func (s *Store) Remove(sub *Subscription) {
func (s *Store) Remove(sub *Subscription) error {
if s == nil {
return common.ErrNilPointer
}
s.mu.Lock()
defer s.mu.Unlock()

if found := s.get(sub); found != nil {
delete(s.m, found.Key)
return nil
}

return ErrNotFound
}

// List returns a slice of Subscriptions pointers
func (s *Store) List() []*Subscription {
func (s *Store) List() List {
if s == nil {
return List{}
}
s.mu.RLock()
defer s.mu.RUnlock()
subs := make([]*Subscription, 0, len(s.m))
subs := make(List, 0, len(s.m))
for _, s := range s.m {
subs = append(subs, s)
}
Expand All @@ -94,6 +111,9 @@ func (s *Store) List() []*Subscription {

// Clear empties the subscription store
func (s *Store) Clear() {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
clear(s.m)
Expand All @@ -117,6 +137,9 @@ func (s *Store) match(key MatchableKey) *Subscription {
// The store Diff is invoked upon is read-lock protected
// The new store is assumed to be a new instance and enjoys no locking protection
func (s *Store) Diff(compare List) (added, removed List) {
if s == nil {
return
}
s.mu.RLock()
defer s.mu.RUnlock()
removedMap := maps.Clone(s.m)
Expand All @@ -137,6 +160,9 @@ func (s *Store) Diff(compare List) (added, removed List) {

// Len returns the number of subscriptions
func (s *Store) Len() int {
if s == nil {
return 0
}
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.m)
Expand Down
1 change: 1 addition & 0 deletions exchanges/subscription/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const (

// Public errors
var (
ErrNotFound = errors.New("subscription not found")
ErrNotSinglePair = errors.New("only single pair subscriptions expected")
ErrInStateAlready = errors.New("subscription already in state")
ErrInvalidState = errors.New("invalid subscription state")
Expand Down

0 comments on commit a9c1cae

Please sign in to comment.