Skip to content

Commit

Permalink
Stream: Use types for message filters
Browse files Browse the repository at this point in the history
  • Loading branch information
gbjk committed Nov 22, 2024
1 parent 274ae8f commit 33bb31e
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 52 deletions.
82 changes: 41 additions & 41 deletions exchanges/gateio/gateio_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,18 +209,18 @@ func (g *Gateio) Setup(exch *config.Exchange) error {
}
// Spot connection
err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{
URL: gateioWebsocketEndpoint,
RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit),
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
Handler: g.WsHandleSpotData,
Subscriber: g.Subscribe,
Unsubscriber: g.Unsubscribe,
GenerateSubscriptions: g.generateSubscriptionsSpot,
Connector: g.WsConnectSpot,
Authenticate: g.authenticateSpot,
WrapperDefinedConnectionSignature: asset.Spot,
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
URL: gateioWebsocketEndpoint,
RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit),
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
Handler: g.WsHandleSpotData,
Subscriber: g.Subscribe,
Unsubscriber: g.Unsubscribe,
GenerateSubscriptions: g.generateSubscriptionsSpot,
Connector: g.WsConnectSpot,
Authenticate: g.authenticateSpot,
MessageFilter: stream.NewAssetFilter(asset.Spot),
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
})
if err != nil {
return err
Expand All @@ -234,12 +234,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error {
Handler: func(ctx context.Context, incoming []byte) error {
return g.WsHandleFuturesData(ctx, incoming, asset.Futures)
},
Subscriber: g.FuturesSubscribe,
Unsubscriber: g.FuturesUnsubscribe,
GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) },
Connector: g.WsFuturesConnect,
WrapperDefinedConnectionSignature: asset.USDTMarginedFutures,
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
Subscriber: g.FuturesSubscribe,
Unsubscriber: g.FuturesUnsubscribe,
GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) },
Connector: g.WsFuturesConnect,
MessageFilter: stream.NewAssetFilter(asset.USDTMarginedFutures),
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
})
if err != nil {
return err
Expand All @@ -254,12 +254,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error {
Handler: func(ctx context.Context, incoming []byte) error {
return g.WsHandleFuturesData(ctx, incoming, asset.Futures)
},
Subscriber: g.FuturesSubscribe,
Unsubscriber: g.FuturesUnsubscribe,
GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) },
Connector: g.WsFuturesConnect,
WrapperDefinedConnectionSignature: asset.CoinMarginedFutures,
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
Subscriber: g.FuturesSubscribe,
Unsubscriber: g.FuturesUnsubscribe,
GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) },
Connector: g.WsFuturesConnect,
MessageFilter: stream.NewAssetFilter(asset.CoinMarginedFutures),
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
})
if err != nil {
return err
Expand All @@ -275,30 +275,30 @@ func (g *Gateio) Setup(exch *config.Exchange) error {
Handler: func(ctx context.Context, incoming []byte) error {
return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures)
},
Subscriber: g.DeliveryFuturesSubscribe,
Unsubscriber: g.DeliveryFuturesUnsubscribe,
GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions,
Connector: g.WsDeliveryFuturesConnect,
WrapperDefinedConnectionSignature: asset.DeliveryFutures,
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
Subscriber: g.DeliveryFuturesSubscribe,
Unsubscriber: g.DeliveryFuturesUnsubscribe,
GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions,
Connector: g.WsDeliveryFuturesConnect,
MessageFilter: stream.NewAssetFilter(asset.DeliveryFutures),
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
})
if err != nil {
return err
}

// Futures connection - Options
return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{
URL: optionsWebsocketURL,
RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit),
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
Handler: g.WsHandleOptionsData,
Subscriber: g.OptionsSubscribe,
Unsubscriber: g.OptionsUnsubscribe,
GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions,
Connector: g.WsOptionsConnect,
WrapperDefinedConnectionSignature: asset.Options,
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
URL: optionsWebsocketURL,
RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit),
ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout,
ResponseMaxLimit: exch.WebsocketResponseMaxLimit,
Handler: g.WsHandleOptionsData,
Subscriber: g.OptionsSubscribe,
Unsubscriber: g.OptionsUnsubscribe,
GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions,
Connector: g.WsOptionsConnect,
MessageFilter: stream.NewAssetFilter(asset.Options),
BespokeGenerateMessageID: g.GenerateWebsocketMessageID,
})
}

Expand Down
25 changes: 25 additions & 0 deletions exchanges/stream/filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package stream

import "github.com/thrasher-corp/gocryptotrader/exchanges/asset"

type MessageFilter interface {
IsAppropriate(any) bool
}

type AssetFilter struct {
Asset asset.Item
}

var _ MessageFilter = &AssetFilter{}

func NewAssetFilter(a asset.Item) *AssetFilter {
return &AssetFilter{a}
}

func (f *AssetFilter) IsAppropriate(aAny any) bool {
a, ok := aAny.(asset.Item)
if !ok {
return false
}
return a == f.Asset
}
6 changes: 2 additions & 4 deletions exchanges/stream/stream_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,8 @@ type ConnectionSetup struct {
BespokeGenerateMessageID func(highPrecision bool) int64
// Authenticate will be called to authenticate the connection
Authenticate func(ctx context.Context, conn Connection) error
// WrapperDefinedConnectionSignature is any type that will match to a specific connection. This could be an asset
// type `asset.Spot`, a string type denoting the individual URL, an authenticated or unauthenticated string or a
// mixture of these.
WrapperDefinedConnectionSignature any
// MessageFilter allows message routing to the appropriate connection
MessageFilter MessageFilter
}

// ConnectionWrapper contains the connection setup details to be used when
Expand Down
9 changes: 2 additions & 7 deletions exchanges/stream/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/url"
"reflect"
"slices"
"sync"
"time"
Expand Down Expand Up @@ -266,15 +265,11 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error {
return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset)
}

if c.WrapperDefinedConnectionSignature != nil && !reflect.TypeOf(c.WrapperDefinedConnectionSignature).Comparable() {
return errWrapperDefinedConnectionSignatureNotComparable
}

for x := range w.connectionManager {
// Below allows for multiple connections to the same URL with different outbound request signatures. This
// allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on
// a spot connection.
if w.connectionManager[x].Setup.URL == c.URL && c.WrapperDefinedConnectionSignature == w.connectionManager[x].Setup.WrapperDefinedConnectionSignature {
if w.connectionManager[x].Setup.URL == c.URL && c.MessageFilter == w.connectionManager[x].Setup.MessageFilter {
return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication)
}
}
Expand Down Expand Up @@ -1287,7 +1282,7 @@ func (w *Websocket) GetConnection(connSignature any) (Connection, error) {
}

for _, wrapper := range w.connectionManager {
if wrapper.Setup.WrapperDefinedConnectionSignature == connSignature {
if wrapper.Setup.MessageFilter != nil && wrapper.Setup.MessageFilter.IsAppropriate(connSignature) {
if wrapper.Connection == nil {
return nil, fmt.Errorf("%s: %s %w: %v", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, connSignature)
}
Expand Down

0 comments on commit 33bb31e

Please sign in to comment.