Skip to content

Commit

Permalink
Websockets: Add key to websocket subscriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Sep 14, 2023
1 parent 9492412 commit e51317a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
7 changes: 7 additions & 0 deletions exchanges/stream/stream_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,15 @@ type Response struct {
Raw []byte
}

type defaultChannelKey struct {
Channel string
Currency currency.Pair
Asset asset.Item
}

// ChannelSubscription container for streaming subscriptions
type ChannelSubscription struct {
Key any
Channel string
Currency currency.Pair
Asset asset.Item
Expand Down
48 changes: 30 additions & 18 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ func (w *Websocket) Shutdown() error {

// flush any subscriptions from last connection if needed
w.subscriptionMutex.Lock()
w.subscriptions = nil
w.subscriptions = subscriptionMap{}
w.subscriptionMutex.Unlock()

close(w.ShutdownC)
Expand Down Expand Up @@ -521,7 +521,7 @@ func (w *Websocket) FlushChannels() error {
if len(newsubs) != 0 {
// Purge subscription list as there will be conflicts
w.subscriptionMutex.Lock()
w.subscriptions = nil
w.subscriptions = subscriptionMap{}
w.subscriptionMutex.Unlock()
return w.SubscribeToChannels(newsubs)
}
Expand Down Expand Up @@ -850,19 +850,19 @@ func (w *Websocket) GetChannelDifference(genSubs []ChannelSubscription) (sub, un
defer w.subscriptionMutex.Unlock()

oldsubs:
for x := range w.subscriptions {
for _, x := range w.subscriptions {
for y := range genSubs {
if w.subscriptions[x].Equal(&genSubs[y]) {
if x.Equal(&genSubs[y]) {
continue oldsubs
}
}
unsub = append(unsub, w.subscriptions[x])
unsub = append(unsub, x)
}

newsubs:
for x := range genSubs {
for y := range w.subscriptions {
if genSubs[x].Equal(&w.subscriptions[y]) {
for _, y := range w.subscriptions {
if genSubs[x].Equal(&y) {
continue newsubs
}
}
Expand All @@ -881,8 +881,8 @@ func (w *Websocket) UnsubscribeChannels(channels []ChannelSubscription) error {

channels:
for x := range channels {
for y := range w.subscriptions {
if channels[x].Equal(&w.subscriptions[y]) {
for _, y := range w.subscriptions {
if channels[x].Equal(&y) {
continue channels
}
}
Expand Down Expand Up @@ -912,8 +912,8 @@ func (w *Websocket) SubscribeToChannels(channels []ChannelSubscription) error {
}
w.subscriptionMutex.Lock()
for x := range channels {
for y := range w.subscriptions {
if channels[x].Equal(&w.subscriptions[y]) {
for _, y := range w.subscriptions {
if channels[x].Equal(&y) {
w.subscriptionMutex.Unlock()
return fmt.Errorf("%s websocket: %v already subscribed",
w.exchangeName,
Expand All @@ -932,7 +932,17 @@ func (w *Websocket) SubscribeToChannels(channels []ChannelSubscription) error {
// has been successfully subscribed
func (w *Websocket) AddSuccessfulSubscriptions(channels ...ChannelSubscription) {
w.subscriptionMutex.Lock()
w.subscriptions = append(w.subscriptions, channels...)
for i := range channels {
key := channels[i].Key
if key == nil {
key = defaultChannelKey{
Channel: channels[i].Channel,
Asset: channels[i].Asset,
Currency: channels[i].Currency,
}
}
w.subscriptions[key] = channels[i]
}
w.subscriptionMutex.Unlock()
}

Expand All @@ -942,11 +952,9 @@ func (w *Websocket) RemoveSuccessfulUnsubscriptions(channels ...ChannelSubscript
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
for x := range channels {
for y := range w.subscriptions {
if channels[x].Equal(&w.subscriptions[y]) {
w.subscriptions[y] = w.subscriptions[len(w.subscriptions)-1]
w.subscriptions[len(w.subscriptions)-1] = ChannelSubscription{}
w.subscriptions = w.subscriptions[:len(w.subscriptions)-1]
for _, y := range w.subscriptions {
if channels[x].Equal(&y) {
delete(w.subscriptions, y.Key)
break
}
}
Expand All @@ -964,7 +972,11 @@ func (w *ChannelSubscription) Equal(s *ChannelSubscription) bool {
func (w *Websocket) GetSubscriptions() []ChannelSubscription {
w.subscriptionMutex.Lock()
defer w.subscriptionMutex.Unlock()
return append(w.subscriptions[:0:0], w.subscriptions...)
subs := make([]ChannelSubscription, len(w.subscriptions))
for _, c := range w.subscriptions {
subs = append(subs, c)
}
return subs
}

// SetCanUseAuthenticatedEndpoints sets canUseAuthenticatedEndpoints val in
Expand Down
4 changes: 3 additions & 1 deletion exchanges/stream/websocket_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ const (
UnhandledMessage = " - Unhandled websocket message: "
)

type subscriptionMap map[any]ChannelSubscription

// Websocket defines a return type for websocket connections via the interface
// wrapper for routine processing
type Websocket struct {
Expand All @@ -47,7 +49,7 @@ type Websocket struct {
connector func() error

subscriptionMutex sync.Mutex
subscriptions []ChannelSubscription
subscriptions subscriptionMap
Subscribe chan []ChannelSubscription
Unsubscribe chan []ChannelSubscription

Expand Down

0 comments on commit e51317a

Please sign in to comment.