diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 5291449b139..15bc0f985cf 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -209,18 +209,18 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Spot connection err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: gateioWebsocketEndpoint, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleSpotData, - Subscriber: g.Subscribe, - Unsubscriber: g.Unsubscribe, - GenerateSubscriptions: g.generateSubscriptionsSpot, - Connector: g.WsConnectSpot, - Authenticate: g.authenticateSpot, - WrapperDefinedConnectionSignature: asset.Spot, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + URL: gateioWebsocketEndpoint, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleSpotData, + Subscriber: g.Subscribe, + Unsubscriber: g.Unsubscribe, + GenerateSubscriptions: g.generateSubscriptionsSpot, + Connector: g.WsConnectSpot, + Authenticate: g.authenticateSpot, + MessageFilter: stream.NewAssetFilter(asset.Spot), + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -234,12 +234,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, - Connector: g.WsFuturesConnect, - WrapperDefinedConnectionSignature: asset.USDTMarginedFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: g.WsFuturesConnect, + MessageFilter: stream.NewAssetFilter(asset.USDTMarginedFutures), + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -254,12 +254,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, - Connector: g.WsFuturesConnect, - WrapperDefinedConnectionSignature: asset.CoinMarginedFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsFuturesConnect, + MessageFilter: stream.NewAssetFilter(asset.CoinMarginedFutures), + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -275,12 +275,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) }, - Subscriber: g.DeliveryFuturesSubscribe, - Unsubscriber: g.DeliveryFuturesUnsubscribe, - GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, - Connector: g.WsDeliveryFuturesConnect, - WrapperDefinedConnectionSignature: asset.DeliveryFutures, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.DeliveryFuturesSubscribe, + Unsubscriber: g.DeliveryFuturesUnsubscribe, + GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, + Connector: g.WsDeliveryFuturesConnect, + MessageFilter: stream.NewAssetFilter(asset.DeliveryFutures), + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -288,17 +288,17 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Options return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: optionsWebsocketURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleOptionsData, - Subscriber: g.OptionsSubscribe, - Unsubscriber: g.OptionsUnsubscribe, - GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, - Connector: g.WsOptionsConnect, - WrapperDefinedConnectionSignature: asset.Options, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + URL: optionsWebsocketURL, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleOptionsData, + Subscriber: g.OptionsSubscribe, + Unsubscriber: g.OptionsUnsubscribe, + GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + Connector: g.WsOptionsConnect, + MessageFilter: stream.NewAssetFilter(asset.Options), + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) } diff --git a/exchanges/stream/filter.go b/exchanges/stream/filter.go new file mode 100644 index 00000000000..fa17777366d --- /dev/null +++ b/exchanges/stream/filter.go @@ -0,0 +1,25 @@ +package stream + +import "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + +type MessageFilter interface { + IsAppropriate(any) bool +} + +type AssetFilter struct { + Asset asset.Item +} + +var _ MessageFilter = &AssetFilter{} + +func NewAssetFilter(a asset.Item) *AssetFilter { + return &AssetFilter{a} +} + +func (f *AssetFilter) IsAppropriate(aAny any) bool { + a, ok := aAny.(asset.Item) + if !ok { + return false + } + return a == f.Asset +} diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index bf72a42afda..18bc73b9c52 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -86,10 +86,8 @@ type ConnectionSetup struct { BespokeGenerateMessageID func(highPrecision bool) int64 // Authenticate will be called to authenticate the connection Authenticate func(ctx context.Context, conn Connection) error - // WrapperDefinedConnectionSignature is any type that will match to a specific connection. This could be an asset - // type `asset.Spot`, a string type denoting the individual URL, an authenticated or unauthenticated string or a - // mixture of these. - WrapperDefinedConnectionSignature any + // MessageFilter allows message routing to the appropriate connection + MessageFilter MessageFilter } // ConnectionWrapper contains the connection setup details to be used when diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 18bed4ee457..3fd0ea620a7 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/url" - "reflect" "slices" "sync" "time" @@ -266,15 +265,11 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) } - if c.WrapperDefinedConnectionSignature != nil && !reflect.TypeOf(c.WrapperDefinedConnectionSignature).Comparable() { - return errWrapperDefinedConnectionSignatureNotComparable - } - for x := range w.connectionManager { // Below allows for multiple connections to the same URL with different outbound request signatures. This // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on // a spot connection. - if w.connectionManager[x].Setup.URL == c.URL && c.WrapperDefinedConnectionSignature == w.connectionManager[x].Setup.WrapperDefinedConnectionSignature { + if w.connectionManager[x].Setup.URL == c.URL && c.MessageFilter == w.connectionManager[x].Setup.MessageFilter { return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) } } @@ -1287,7 +1282,7 @@ func (w *Websocket) GetConnection(connSignature any) (Connection, error) { } for _, wrapper := range w.connectionManager { - if wrapper.Setup.WrapperDefinedConnectionSignature == connSignature { + if wrapper.Setup.MessageFilter != nil && wrapper.Setup.MessageFilter.IsAppropriate(connSignature) { if wrapper.Connection == nil { return nil, fmt.Errorf("%s: %s %w: %v", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, connSignature) }