Skip to content

Commit

Permalink
Subscriptions: Internalise Map and Subscription
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Feb 14, 2024
1 parent 946fbb4 commit 054830e
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 230 deletions.
177 changes: 49 additions & 128 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ const (
// Public errors
var (
ErrSubscriptionNotFound = errors.New("subscription not found")
ErrSubscribedAlready = errors.New("duplicate subscription")
ErrSubscriptionFailure = errors.New("subscription failure")
ErrSubscriptionNotSupported = errors.New("subscription channel not supported ")
ErrUnsubscribeFailure = errors.New("unsubscribe failure")
Expand Down Expand Up @@ -78,6 +77,7 @@ func New() *Websocket {
Subscribe: make(chan []subscription.Subscription),
Unsubscribe: make(chan []subscription.Subscription),
Match: NewMatch(),
subscriptions: subscription.Map{},
}
}

Expand Down Expand Up @@ -263,9 +263,10 @@ func (w *Websocket) Connect() error {
w.exchangeName)
}

w.subscriptionMutex.Lock()
w.subscriptions = subscription.Map{}
w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
return common.ErrNilPointer
}
w.subscriptions.Clear()

w.dataMonitor()
w.trafficMonitor()
Expand Down Expand Up @@ -295,17 +296,12 @@ func (w *Websocket) Connect() error {
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
if len(subs) == 0 {
return nil
}
err = w.checkSubscriptions(subs)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
err = w.Subscriber(subs)
if err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
if len(subs) != 0 {
if err = w.SubscribeToChannels(subs); err != nil {
return err
}
}

return nil
}

Expand Down Expand Up @@ -484,9 +480,7 @@ func (w *Websocket) Shutdown() error {
}

// flush any subscriptions from last connection if needed
w.subscriptionMutex.Lock()
w.subscriptions = subscription.Map{}
w.subscriptionMutex.Unlock()
w.subscriptions.Clear()

close(w.ShutdownC)
w.Wg.Wait()
Expand Down Expand Up @@ -870,40 +864,35 @@ func (w *Websocket) GetName() string {
// GetChannelDifference finds the difference between the subscribed channels
// and the new subscription list when pairs are disabled or enabled.
func (w *Websocket) GetChannelDifference(newSubs []subscription.Subscription) (sub, unsub []subscription.Subscription) {
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
if w.subscriptions == nil {
w.subscriptions = subscription.Map{}
}
return w.subscriptions.Diff(subscription.ListToMap(newSubs))
}

// UnsubscribeChannels unsubscribes from a websocket channel
func (w *Websocket) UnsubscribeChannels(channels []subscription.Subscription) error {
// UnsubscribeChannels unsubscribes from a list of websocket channel
func (w *Websocket) UnsubscribeChannels(channels []*subscription.Subscription) error {
if len(channels) == 0 {
return fmt.Errorf("%s websocket: %w", w.exchangeName, errNoSubscriptionsSupplied)
}
w.subscriptionMutex.RLock()

for i := range channels {
key := channels[i].EnsureKeyed()
if _, ok := w.subscriptions[key]; !ok {
w.subscriptionMutex.RUnlock()
return fmt.Errorf("%s websocket: %w: %+v", w.exchangeName, ErrSubscriptionNotFound, channels[i])
}
if w.subscriptions == nil {
return nil
}
w.subscriptionMutex.RUnlock()
return w.Unsubscriber(channels)
}

// ResubscribeToChannel resubscribes to channel
func (w *Websocket) ResubscribeToChannel(subscribedChannel *subscription.Subscription) error {
err := w.UnsubscribeChannels([]subscription.Subscription{*subscribedChannel})
func (w *Websocket) ResubscribeToChannel(s *subscription.Subscription) error {
err := w.UnsubscribeChannels(s)
if err != nil {
return err
}
return w.SubscribeToChannels([]subscription.Subscription{*subscribedChannel})
return w.SubscribeToChannels(s)
}

// SubscribeToChannels appends supplied channels to channelsToSubscribe
func (w *Websocket) SubscribeToChannels(channels []subscription.Subscription) error {
// SubscribeToChannels subscribes to websocket channels using the exchange specific Subscriber method
// Errors are returned for duplicates or exceeding max Subscriptions
func (w *Websocket) SubscribeToChannels(channels []*subscription.Subscription) error {
if err := w.checkSubscriptions(channels); err != nil {
return fmt.Errorf("%s websocket: %w", w.exchangeName, common.AppendError(ErrSubscriptionFailure, err))
}
Expand All @@ -914,105 +903,40 @@ func (w *Websocket) SubscribeToChannels(channels []subscription.Subscription) er
}

// AddSubscription adds a subscription to the subscription lists
// Unlike AddSubscriptions this method will error if the subscription already exists
func (w *Websocket) AddSubscription(c *subscription.Subscription) error {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
w.subscriptions = subscription.Map{}
}

key := c.EnsureKeyed()
if s := w.getSubscription(key); s != nil {
return ErrSubscribedAlready
if w == nil || key == nil {
return common.ErrNilPointer
}

n := *c // Fresh copy; we don't want to use the pointer we were given and allow encapsulation/locks to be bypassed
w.subscriptions[key] = &n

return nil
}

// SetSubscriptionState sets an existing subscription state
// returns an error if the subscription is not found, or the new state is already set
func (w *Websocket) SetSubscriptionState(c *subscription.Subscription, state subscription.State) error {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
if w.subscriptions == nil
w.subscriptions = subscription.Map{}

Check failure on line 911 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (ubuntu-latest, 386, true, true)

syntax error: cannot use assignment w.subscriptions = subscription.Map as value

Check failure on line 911 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (ubuntu-latest, amd64, true, false)

syntax error: cannot use assignment w.subscriptions = subscription.Map as value

Check failure on line 911 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (macos-13, amd64, true, true)

syntax error: cannot use assignment w.subscriptions = subscription.Map as value

Check failure on line 911 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (macos-latest, amd64, true, true)

syntax error: cannot use assignment w.subscriptions = subscription.Map as value

Check failure on line 911 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (windows-latest, amd64, true, true)

syntax error: cannot use assignment w.subscriptions = subscription.Map as value
}
key := c.EnsureKeyed()
p, ok := w.subscriptions[key]
if !ok {
return ErrSubscriptionNotFound
}
return p.SetState(state)
}

// AddSuccessfulSubscriptions adds subscriptions to the subscription lists that
// has been successfully subscribed
func (w *Websocket) AddSuccessfulSubscriptions(channels ...subscription.Subscription) {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
w.subscriptions = subscription.Map{}
}
for _, c := range channels {
key := c.EnsureKeyed()
c.SetState(subscription.SubscribedState)
w.subscriptions[key] = &c
}
return w.subscriptions.Add(c)

Check failure on line 913 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (ubuntu-latest, 386, true, true)

syntax error: non-declaration statement outside function body

Check failure on line 913 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (ubuntu-latest, amd64, true, false)

syntax error: non-declaration statement outside function body

Check failure on line 913 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (macos-13, amd64, true, true)

syntax error: non-declaration statement outside function body

Check failure on line 913 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (macos-latest, amd64, true, true)

syntax error: non-declaration statement outside function body

Check failure on line 913 in exchanges/stream/websocket.go

View workflow job for this annotation

GitHub Actions / GoCryptoTrader back-end (windows-latest, amd64, true, true)

syntax error: non-declaration statement outside function body
}

// RemoveSubscriptions removes subscriptions from the subscription list
func (w *Websocket) RemoveSubscriptions(channels ...subscription.Subscription) {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
if w.subscriptions == nil {
w.subscriptions = subscription.Map{}
}
for i := range channels {
key := channels[i].EnsureKeyed()
delete(w.subscriptions, key)
func (w *Websocket) RemoveSubscription(s *subscription.Subscription) {
if w == nil || w.subscriptions == nil || key == nil {
return
}
w.subscriptions.Remove(s)
}

// GetSubscription returns a pointer to a copy of the subscription at the key provided
// GetSubscription returns a subscription at the key provided
// returns nil if no subscription is at that key or the key is nil
// Keys can implement subscription.MatchableKey in order to provide custom matching logic
func (w *Websocket) GetSubscription(key any) *subscription.Subscription {
if key == nil || w == nil || w.subscriptions == nil {
if w == nil || w.subscriptions == nil || key == nil {
return nil
}
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
if s := w.getSubscription(key); s != nil {
return s
}
return nil
}

func (w *Websocket) getSubscription(key any) *subscription.Subscription {
if m, ok := key.(subscription.MatchableKey); ok {
return m.Match(w.subscriptions)
}

if s, ok := w.subscriptions[key]; ok {
return s
}

return nil
return w.subscriptions.Get(key)
}

// GetSubscriptions returns a new slice of the subscriptions
func (w *Websocket) GetSubscriptions() []subscription.Subscription {
w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()
subs := make([]subscription.Subscription, 0, len(w.subscriptions))
for _, c := range w.subscriptions {
subs = append(subs, *c)
}
return subs
func (w *Websocket) GetSubscriptions() []*subscription.Subscription {
if key == nil {
return nil
}
return w.subscriptions.List()
}

// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in
Expand Down Expand Up @@ -1054,28 +978,25 @@ func checkWebsocketURL(s string) error {
return nil
}

// checkSubscriptions checks subscriptions against the max subscription limit
// and if the subscription already exists.
func (w *Websocket) checkSubscriptions(subs []subscription.Subscription) error {
// checkSubscriptions checks subscriptions against the max subscription limit and if the subscription already exists
// The subscription state is not considered when counting existing subscriptions
func (w *Websocket) checkSubscriptions(subs []*subscription.Subscription) error {
if len(subs) == 0 {
return errNoSubscriptionsSupplied
}

w.subscriptionMutex.RLock()
defer w.subscriptionMutex.RUnlock()

if w.MaxSubscriptionsPerConnection > 0 && len(w.subscriptions)+len(subs) > w.MaxSubscriptionsPerConnection {
c := w.subscriptions.Len()
if w.MaxSubscriptionsPerConnection > 0 && existing+len(subs) > w.MaxSubscriptionsPerConnection {
return fmt.Errorf("%w: current subscriptions: %v, incoming subscriptions: %v, max subscriptions per connection: %v - please reduce enabled pairs",
errSubscriptionsExceedsLimit,
len(w.subscriptions),
existing,
len(subs),
w.MaxSubscriptionsPerConnection)
}

for i := range subs {
key := subs[i].EnsureKeyed()
if _, ok := w.subscriptions[key]; ok {
return fmt.Errorf("%w for %+v", ErrSubscribedAlready, subs[i])
for _, s := range subs {
if s := w.subscriptions.Get(s); s != nil {
return fmt.Errorf("%w for %s", subscription.ErrDuplicate, s)
}
}

Expand Down
6 changes: 3 additions & 3 deletions exchanges/stream/websocket_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type Websocket struct {
connector func() error

subscriptionMutex sync.RWMutex
subscriptions subscription.Map
subscriptions *subscription.Map
Subscribe chan []subscription.Subscription
Unsubscribe chan []subscription.Subscription

Expand All @@ -60,7 +60,7 @@ type Websocket struct {
Unsubscriber func([]subscription.Subscription) error
// GenerateSubs function for package defined websocket generate
// subscriptions functionality
GenerateSubs func() ([]subscription.Subscription, error)
GenerateSubs func() ([]*subscription.Subscription, error)

DataHandler chan interface{}
ToRoutine chan interface{}
Expand Down Expand Up @@ -109,7 +109,7 @@ type WebsocketSetup struct {
Connector func() error
Subscriber func([]subscription.Subscription) error
Unsubscriber func([]subscription.Subscription) error
GenerateSubscriptions func() ([]subscription.Subscription, error)
GenerateSubscriptions func() ([]*subscription.Subscription, error)
Features *protocol.Features

// Local orderbook buffer config values
Expand Down
1 change: 1 addition & 0 deletions exchanges/subscription/key.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package subscription
16 changes: 16 additions & 0 deletions exchanges/subscription/list.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package subscription

import "slices"

// List is a container of subscription pointers
type List []*Subscription

// Strings returns a sorted slice of subscriptions
func (l List) Strings() []string {
s := make([]string, len(l))
for i := range l {
s[i] = l[i].String()
}
slices.Sort(s)
return s
}
Loading

0 comments on commit 054830e

Please sign in to comment.